from typing import List, cast
import os

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
import spaces

from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
model_name = "vidore/colpali-v1.2"
device = get_torch_device("cuda")

# Load the ColPali model and its processor once at module import time so that
# every ColpaliManager instance shares the same weights.
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()

processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
class ColpaliManager:

    def __init__(self, device="cuda", model_name="vidore/colpali-v1.2"):
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")

        # The model and processor are loaded at module level (above); the
        # per-instance loading below is kept for reference but disabled.
        # self.device = get_torch_device(device)
        # self.model = ColPali.from_pretrained(
        #     model_name,
        #     torch_dtype=torch.bfloat16,
        #     device_map=self.device,
        # ).eval()
        # self.processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
    def get_images(self, paths: list[str]) -> List[Image.Image]:
        return [Image.open(path) for path in paths]
    @spaces.GPU  # ZeroGPU: request a GPU for the duration of this call (hence the `spaces` import)
    def process_images(self, image_paths: list[str], batch_size: int = 5):
        """Embed page images with ColPali and return one NumPy array per image."""
        print(f"Processing {len(image_paths)} image_paths")

        images = self.get_images(image_paths)

        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: processor.process_images(x),
        )

        ds: List[torch.Tensor] = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
                embeddings_doc = model(**batch_doc)
            ds.extend(list(torch.unbind(embeddings_doc.to(device))))

        ds_np = [d.float().cpu().numpy() for d in ds]
        return ds_np
    @spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
    def process_text(self, texts: list[str]):
        """Embed text queries with ColPali and return one NumPy array per query."""
        print(f"Processing {len(texts)} texts")

        dataloader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: processor.process_queries(x),
        )

        qs: List[torch.Tensor] = []
        for batch_query in dataloader:
            with torch.no_grad():
                batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
                embeddings_query = model(**batch_query)
            qs.extend(list(torch.unbind(embeddings_query.to(device))))

        qs_np = [q.float().cpu().numpy() for q in qs]
        return qs_np