import os
from datetime import timedelta
from typing import List, Optional, Union

import numpy as np
import torch
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import gather_object
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor

from .BaseEmbedder import BaseEmbedder


class ClipBgeEmbedder(BaseEmbedder):
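    """Embedder that concatenates CLIP image features with BGE-M3 text
    embeddings into a single fixed-size vector per request."""
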
def __init__(
self,
name: str,
output_path: str,
mm_pretrained: str = "openai/clip-vit-large-patch14",
txt_pretrained: str = "BAAI/bge-m3",
device: str = "cuda",
device_map: str = "",
) -> None:
super().__init__(name, output_path)
self.model = CLIPModel.from_pretrained(mm_pretrained)
self.processor = CLIPProcessor.from_pretrained(mm_pretrained)
self.text_model = SentenceTransformer(txt_pretrained)
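        # A one-year timeout keeps long-running ranks from tripping the default
        # distributed process-group timeout while other ranks are still embedding.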
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
self.accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
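        # Under a multi-process launch with no explicit device_map, pin this rank to its own GPU.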
if self.accelerator.num_processes > 1 and device_map == "":
self.device = torch.device(f"cuda:{self.accelerator.local_process_index}")
self.device_map = f"cuda:{self.accelerator.local_process_index}"
else:
self.device = torch.device(device)
self.device_map = device_map
self.model.to(self.device)
self.text_model.to(self.device)

    def embed_task(self, task: str, ignored_ids: Optional[Union[set, List]] = None):
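        """Embed every request of `task` (CLIP for images, BGE-M3 for text),
        gather results across ranks, and save the stacked matrix to
        `<output_path>/<task>_embed.npy` on the main process."""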
DATASET_PATH, DATASET_NAME, split, requests, task_obj, self.docs = BaseEmbedder.init_task(task, ignored_ids)
self.accelerator.wait_for_everyone()
with self.accelerator.split_between_processes(requests, apply_padding=False) as requests_split:
results = {"outputs": []}
for req in tqdm(requests_split, disable=not self.accelerator.is_main_process):
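                # Unpack the request; note that `task` and `split` shadow the outer names.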
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = req.args
visuals = [doc_to_visual(self.docs[doc_id])]
visuals = self.flatten(visuals)
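                # BGE-M3 text embedding; encode() returns a numpy array of shape (1, dim).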
text_embedding = self.text_model.encode([contexts])
text_embedding = torch.from_numpy(text_embedding).flatten()
if len(visuals) > 0:
img_inputs = self.processor(images=visuals, return_tensors="pt")
img_inputs = {k: v.to(self.device) for k, v in img_inputs.items()}
                    # For multiple images, average their feature vectors into one.
                    with torch.no_grad():
                        image_embedding = self.model.get_image_features(**img_inputs).mean(dim=0).cpu()
else:
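                    # No images: fall back to a zero vector matching CLIP's image-feature dim.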
image_embedding = torch.zeros(self.model.config.projection_dim)
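                # Final embedding = CLIP image features followed by BGE text features.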
                embedding = torch.cat([image_embedding, text_embedding])
results["outputs"].append(embedding)
results = [results]
self.accelerator.wait_for_everyone()
results_gathered = gather_object(results)
if self.accelerator.is_main_process:
outputs = []
for r in results_gathered:
outputs += r["outputs"]
results_gathered = torch.stack(outputs)
            # np.save accepts a path directly; avoid leaking an open file handle.
            np.save(os.path.join(self.output_path, f"{task}_embed.npy"), results_gathered.numpy())
return results_gathered
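

# Quick smoke test: embed two captions with CLIP's text tower on a GPU.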
if __name__ == "__main__":
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
text = ["a photo of a cat", "a photo of a dog"]
inputs = processor(text=text, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
outputs = model.get_text_features(**inputs)
print(outputs.mean(dim=0).shape)