from io import BytesIO
from typing import Any, Dict, Optional

import torch
from PIL import Image
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import SiglipImageProcessor


class MultiModalTransformer(BaseTransformer):
    """sentence-transformers Transformer module that accepts interleaved
    text-and-image inputs, using a SigLIP image processor to turn images
    into ``pixel_values`` alongside the usual text tokenization."""

| def __init__( |
| self, |
| model_name_or_path: str, |
| cache_dir: Optional[str] = None, |
| tokenizer_args: Optional[Dict[str, Any]] = None, |
| **kwargs, |
| ): |
| super().__init__(model_name_or_path, **kwargs) |
        if tokenizer_args is None:
            tokenizer_args = {}
        # Image processor used by tokenize() to build pixel_values for images.
        self.processor = SiglipImageProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )
|
|
    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Forward only the keys the underlying model expects.
        trans_features = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
        }
        # Cast pixel_values to the model's dtype (e.g. float16/bfloat16)
        # when the batch contains images.
        if "pixel_values" in features:
            trans_features["pixel_values"] = features["pixel_values"].to(
                self.auto_model.dtype
            )

        sentence_embedding = self.auto_model(**trans_features, **kwargs)[
            "sentence_embedding"
        ]
        features.update({"sentence_embedding": sentence_embedding})
        return features
|
|
    def tokenize(self, texts: list[Dict] | list[str]) -> Dict[str, torch.Tensor]:
        # Special tokens marking where image embeddings are spliced into the text.
        img_start_token = "<|jasper_img_start|>"
        img_token = "<|jasper_img_token|>"
        img_end_token = "<|jasper_img_end|>"
        # Each image occupies a fixed budget of placeholder token positions.
        num_img_tokens = 300

        def process_text_item(item):
            # A plain string is text-only input with no images.
            if isinstance(item, str):
                return item, []
            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] == "image_bytes":
                    # Reserve placeholder tokens, then decode the image from raw bytes.
                    text += img_start_token + img_token * num_img_tokens + img_end_token
                    images.append(
                        Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    )
                elif sub_item["type"] == "image_path":
                    # Reserve placeholder tokens, then load the image from disk.
                    text += img_start_token + img_token * num_img_tokens + img_end_token
                    images.append(Image.open(sub_item["content"]).convert("RGB"))
                else:
                    raise ValueError(f"unknown data type {sub_item['type']}")
            return text, images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)
        # Tokenize the flattened text batch (image placeholder tokens included).
        ipt = self.tokenizer(
            all_texts,
            padding="longest",
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )
        # Preprocess all collected images into a single pixel_values batch.
        if all_images:
            ipt["pixel_values"] = self.processor(
                images=all_images, return_tensors="pt"
            )["pixel_values"]
        return ipt
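

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes this module is loaded
# as part of a SentenceTransformer checkpoint that registers
# MultiModalTransformer as its first module; the checkpoint name and image
# path below are placeholders, not real assets.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    # Placeholder checkpoint name; substitute a model whose configuration
    # points at this custom module.
    model = SentenceTransformer(
        "path/to/multimodal-checkpoint", trust_remote_code=True
    )

    # A mixed batch: a plain string plus an interleaved text+image item,
    # matching the two input formats accepted by tokenize() above.
    embeddings = model.encode(
        [
            "a plain text query",
            [
                {"type": "text", "content": "describe this image: "},
                {"type": "image_path", "content": "example.jpg"},  # placeholder path
            ],
        ]
    )
    print(embeddings.shape)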
|
|