from io import BytesIO
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import AutoModelForVision2Seq, AutoProcessor

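
# MultiModalTransformer wraps a Qwen2-VL style vision-language backbone as a
# sentence-transformers module: it loads the model and processor, tokenizes
# mixed text/image inputs into chat-formatted sequences, and pools the last
# non-padded token of each sequence as its embedding.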
class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        min_image_tokens: int = 256,
        max_image_tokens: int = 1280,
        max_length: int = 1800,
        **kwargs,
    ):
        super().__init__(model_name_or_path, cache_dir=cache_dir, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        tokenizer_args.pop("trust_remote_code", None)

        # Convert the image-token budget into a pixel budget for the processor:
        # each visual token corresponds to a 28x28 pixel patch.
        min_pixels = min_image_tokens * 28 * 28
        max_pixels = max_image_tokens * 28 * 28
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
            cache_dir=cache_dir,
            **tokenizer_args,
        )
        self.processor.tokenizer.padding_side = 'right'
        self.sep = ' '
        self.max_length = max_length
        self.normalize = True

    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
        model_args.pop("trust_remote_code", None)
        # Load the vision-to-sequence backbone in half precision.
        self.auto_model = AutoModelForVision2Seq.from_pretrained(
            model_name_or_path, torch_dtype=torch.float16, cache_dir=cache_dir, **model_args
        )

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Embed the text tokens, then splice the visual embeddings into the
        # positions of the image placeholder tokens.
        if features.get("inputs_embeds", None) is None:
            features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
        if features.get("pixel_values", None) is not None:
            features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
            image_embeds = self.auto_model.visual(
                features["pixel_values"], grid_thw=features["image_grid_thw"]
            )
            image_mask = features["input_ids"] == self.auto_model.config.image_token_id
            features["inputs_embeds"][image_mask] = image_embeds
            features.pop("pixel_values")
            features.pop("image_grid_thw")
        features.pop("input_ids")

        outputs = self.auto_model.model(
            **features,
            return_dict=True,
            output_hidden_states=True,
        )

        # Pool the hidden state of the last non-padded token of each sequence.
        pooling_mask = features["attention_mask"] if features.get("pooling_mask", None) is None else features["pooling_mask"]
        left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])
        if left_padding:
            # With left padding the last position is a real token for every sequence.
            embeddings = outputs.last_hidden_state[:, -1]
        else:
            sequence_lengths = pooling_mask.sum(dim=1) - 1
            embeddings = outputs.last_hidden_state[torch.arange(
                outputs.last_hidden_state.shape[0], device=outputs.last_hidden_state.device
            ), sequence_lengths]
        features.update({"token_embeddings": embeddings})
        return features

    def tokenize(self, texts: List[Dict[str, Any]] | List[str]) -> Dict[str, torch.Tensor]:
        split_token = "<|im_end|>\n"

        def process_text_item(item):
            # Plain strings are passed through unchanged; dict items may carry
            # an "image" (bytes, file path, or PIL.Image) and/or a "text" field.
            if isinstance(item, str):
                return item, None

            text, img = "", None
            if "image" in item:
                text += "<|vision_start|><|image_pad|><|vision_end|>"
                img = item["image"]
                if isinstance(img, bytes):
                    img = Image.open(BytesIO(img)).convert("RGB")
                elif isinstance(img, str):
                    img = Image.open(img).convert("RGB")
                elif not isinstance(img, Image.Image):
                    raise ValueError(f"Unknown image type {type(img)}")
            if "text" in item:
                text += item["text"].lstrip()

            # Wrap the content in the chat template; if an instruction prefix is
            # already present, keep it in front of the user turn.
            if split_token in text:
                instruction, text = text.split(split_token, 1)
                text = f'{instruction}{split_token}<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
            else:
                text = f"<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
            return text, img

        all_texts, all_images = [], []
        for item in texts:
            text, image = process_text_item(item)
            all_texts.append(text)
            all_images.append(image)

        if all_images != [None] * len(all_images):  # at least one item carries an image
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
        return inputs
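
# Minimal usage sketch, not part of the module itself. Assumptions: the
# checkpoint name and image path below are illustrative placeholders, and the
# chosen checkpoint follows the Qwen2-VL layout this module expects (a `visual`
# tower, `config.image_token_id`, and the chat-template tokens used above).
if __name__ == "__main__":
    module = MultiModalTransformer("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
    # Text-only batch: plain strings are wrapped in the chat template.
    text_batch = module.tokenize(["a photo of a cat", "a photo of a dog"])
    # Multimodal batch: every item should carry an image so the processor
    # receives a consistent image list.
    image_batch = module.tokenize([{"image": "cat.jpg", "text": "a cat"}])
    with torch.no_grad():
        out = module(text_batch)
    print(out["token_embeddings"].shape)  # (batch_size, hidden_size)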