import torch
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    CLIPTextModel,
    CLIPTokenizer,
)

from .utils import freeze


class ClipTextEmbedder:
    """Frozen CLIP text encoder; returns the pooled (EOS-token) embedding."""

    def __init__(self, conf, device):
        self.model = CLIPTextModel.from_pretrained(conf.checkpoint_path).to(device)
        self.model = freeze(self.model)
        self.tokenizer = CLIPTokenizer.from_pretrained(conf.checkpoint_path)
        self.max_length = conf.max_length

    def __call__(self, texts):
        # Pad/truncate every prompt to a fixed max_length so batches stay rectangular.
        inputs = self.tokenizer(
            texts,
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=True,
            padding="max_length",
            return_tensors="pt",
        ).to(self.model.device)
        with torch.no_grad():
            pooled_embed = self.model(**inputs)["pooler_output"]
        return pooled_embed  # (batch, hidden_size)
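
# Usage sketch (not part of the original module): it assumes a config object
# exposing `checkpoint_path` and `max_length`, e.g. a types.SimpleNamespace,
# and an illustrative CLIP checkpoint name.
#
#   from types import SimpleNamespace
#   conf = SimpleNamespace(checkpoint_path="openai/clip-vit-large-patch14", max_length=77)
#   clip = ClipTextEmbedder(conf, device="cuda")
#   pooled = clip(["a cat on a windowsill"])  # -> (1, hidden_size) tensor
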
class Qwen2_5_VLTextEmbedder:
    # NOTE: the template strings below are kept byte-for-byte (including the
    # "promt"/"scren" misspellings) because the hard-coded `crop_start` offsets
    # count the tokens of this exact preamble; editing the text would shift them.
    PROMPT_TEMPLATE = {
        "template": {
            "video": (
                "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
                "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
                "Describe the location of the video, main characters or objects and their action.",
                "Describe the dynamism of the video and presented actions.",
                "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
                "Describe the visual effects, postprocessing and transitions if they are presented in the video.",
                "Pay attention to the order of key actions shown in the scene.<|im_end|>",
                "<|im_start|>user\n{}<|im_end|>",
            ),
            "image": (
                "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>",
                "<|im_start|>user\n{}<|im_end|>",
            ),
        },
        # Number of prompt-template tokens to crop off the front of each sequence.
        "crop_start": {"video": 129, "image": 41},
    }

    def __init__(self, conf, device):
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            conf.checkpoint_path,
            dtype=torch.bfloat16,
            device_map=device,
        )
        self.model = freeze(self.model)
        # dynamic=True avoids recompilation when the padded sequence length varies.
        self.model = torch.compile(self.model, dynamic=True)
        self.processor = AutoProcessor.from_pretrained(conf.checkpoint_path, use_fast=True)
        self.max_length = conf.max_length

    def __call__(self, texts, type_of_content="video"):
        prompt_template = "\n".join(self.PROMPT_TEMPLATE["template"][type_of_content])
        crop_start = self.PROMPT_TEMPLATE["crop_start"][type_of_content]

        full_texts = [prompt_template.format(text) for text in texts]
        # Budget max_length for the user text itself, on top of the template tokens.
        max_length = self.max_length + crop_start

        inputs = self.processor(
            text=full_texts,
            images=None,
            videos=None,
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
            padding=True,
        ).to(self.model.device)
        with torch.no_grad():
            # Take the last hidden layer and drop the instruction preamble.
            embeds = self.model(
                input_ids=inputs["input_ids"],
                return_dict=True,
                output_hidden_states=True,
            )["hidden_states"][-1][:, crop_start:]
        attention_mask = inputs["attention_mask"][:, crop_start:]
        # Pack the batch: keep only non-padding tokens, flattened across samples.
        embeds = embeds[attention_mask.bool()]
        # Cumulative sequence lengths, shape (batch + 1,), delimiting each sample
        # in the packed tensor (the layout used by variable-length attention kernels).
        cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
        cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(
            dtype=torch.int32
        )
        return embeds, cu_seqlens
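
# Usage sketch (assumptions: the checkpoint name below is illustrative; the
# real path and max_length come from the caller's config):
#
#   from types import SimpleNamespace
#   conf = SimpleNamespace(checkpoint_path="Qwen/Qwen2.5-VL-7B-Instruct", max_length=256)
#   qwen = Qwen2_5_VLTextEmbedder(conf, device="cuda")
#   embeds, cu_seqlens = qwen(["a timelapse of clouds over a city"], type_of_content="video")
#   # embeds: (total_non_pad_tokens, hidden), packed across the batch;
#   # cu_seqlens: (batch + 1,) int32 offsets, e.g. for flash-attn's varlen API.
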
class Kandinsky5TextEmbedder:
    """Bundles the Qwen2.5-VL sequence embedder with the pooled CLIP embedder."""

    def __init__(self, conf, device="cpu"):
        self.embedder = Qwen2_5_VLTextEmbedder(conf.qwen, device)
        self.clip_embedder = ClipTextEmbedder(conf.clip, device)
        self.conf = conf

    def encode(self, texts, type_of_content="image"):
        text_embeds, cu_seqlens = self.embedder(texts, type_of_content=type_of_content)
        pooled_embed = self.clip_embedder(texts)
        return {"text_embeds": text_embeds, "pooled_embed": pooled_embed}, cu_seqlens

    def to(self, device):
        self.embedder.model = self.embedder.model.to(device)
        self.clip_embedder.model = self.clip_embedder.model.to(device)
        return self


def get_text_embedder(conf, device="cpu"):
    return Kandinsky5TextEmbedder(conf, device)
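
if __name__ == "__main__":
    # Smoke-test sketch, not part of the original module. The checkpoint names
    # are placeholder assumptions; substitute the paths your config provides.
    # Run as a module (e.g. `python -m <package>.<this_module>`) because of the
    # relative `freeze` import above.
    from types import SimpleNamespace

    conf = SimpleNamespace(
        qwen=SimpleNamespace(checkpoint_path="Qwen/Qwen2.5-VL-7B-Instruct", max_length=256),
        clip=SimpleNamespace(checkpoint_path="openai/clip-vit-large-patch14", max_length=77),
    )
    embedder = get_text_embedder(conf, device="cuda")
    out, cu_seqlens = embedder.encode(["a red fox in the snow"], type_of_content="image")
    print(out["text_embeds"].shape, out["pooled_embed"].shape, cu_seqlens)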