File size: 5,302 Bytes
2df56dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from io import BytesIO
from typing import Any, Dict, Optional, List
import torch
from PIL import Image
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import AutoModelForVision2Seq, AutoProcessor
class MultiModalTransformer(BaseTransformer):
    """Sentence-Transformers module wrapping a Vision2Seq (Qwen2-VL-style) model
    for multimodal (text + image) embedding.

    Text and images are rendered into a chat template, encoded by the shared
    processor, and pooled to one embedding per input (last non-padding token
    under right padding).
    """

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        min_image_tokens: int = 256,
        max_image_tokens: int = 1280,
        max_length: int = 1800,
        **kwargs,
    ):
        """
        Args:
            model_name_or_path: HF hub id or local path of the Vision2Seq model.
            cache_dir: Optional HF cache directory (forwarded by the base class
                to ``_load_model``).
            tokenizer_args: Accepted for interface compatibility; only
                ``trust_remote_code`` is stripped here.
                NOTE(review): the dict is not forwarded anywhere after the pop —
                confirm whether that is intentional.
            min_image_tokens / max_image_tokens: Bounds on visual tokens per
                image; one token corresponds to a 28x28 pixel patch in
                Qwen2-VL-style processors.
            max_length: Stored on the instance; tokenization itself uses
                ``self.max_seq_length`` (presumably provided by the
                sentence-transformers base class — verify).
        """
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        tokenizer_args.pop("trust_remote_code", None)

        # Convert token budgets to pixel budgets (28x28 pixels per image token).
        min_pixels = min_image_tokens * 28 * 28
        max_pixels = max_image_tokens * 28 * 28
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
        )
        # Right padding so the last-token pooling in forward() lands on real tokens.
        self.processor.tokenizer.padding_side = 'right'
        self.sep = ' '
        self.max_length = max_length
        self.normalize = True

    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
        """Load the underlying Vision2Seq model in fp16.

        Fix: ``cache_dir`` was previously accepted but never forwarded, so the
        base class's cache directory setting was silently ignored.
        """
        model_args.pop("trust_remote_code", None)
        self.auto_model = AutoModelForVision2Seq.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
            **model_args,
        )

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Embed a tokenized batch; adds ``token_embeddings`` to ``features``.

        Image patch embeddings from the vision tower are scattered into the
        text embedding sequence at the positions of the image placeholder
        tokens, then the language model runs on ``inputs_embeds`` only.
        """
        if features.get("inputs_embeds", None) is None:
            features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
        if features.get("pixel_values", None) is not None:
            features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
            image_embeds = self.auto_model.visual(
                features["pixel_values"], grid_thw=features["image_grid_thw"]
            )
            # Replace each image-placeholder token embedding with its visual embedding.
            image_mask = features["input_ids"] == self.auto_model.config.image_token_id
            features["inputs_embeds"][image_mask] = image_embeds
            features.pop("pixel_values")
            features.pop("image_grid_thw")
        # The LM consumes inputs_embeds; input_ids must not be passed alongside.
        features.pop("input_ids")
        outputs = self.auto_model.model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            # **kwargs
        )
        pooling_mask = features["attention_mask"] if features.get("pooling_mask", None) is None else features["pooling_mask"]
        # If every sequence's final position is unmasked, the batch is left-padded.
        left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])  # TODO
        if left_padding:
            # NOTE(review): this keeps the full hidden-state sequence rather than
            # selecting position -1 — presumably pooled downstream; confirm.
            embeddings = outputs.last_hidden_state
        else:
            # Right padding: pick each sequence's last non-padding position.
            sequence_lengths = pooling_mask.sum(dim=1) - 1
            embeddings = outputs.last_hidden_state[torch.arange(
                outputs.last_hidden_state.shape[0], device=outputs.last_hidden_state.device
            ), sequence_lengths]
        features.update({"token_embeddings": embeddings})
        return features

    def tokenize(self, texts: List[List[Dict[str, Image.Image]]] | List[str]) -> Dict[str, torch.Tensor]:
        """Render inputs into the chat template and encode them with the processor.

        Each item is either a plain string or a dict with optional ``"image"``
        (bytes, file path, or PIL image) and optional ``"text"`` keys.

        Fixes:
        - The instruction branch interpolated an undefined name ``input_str``
          (NameError at runtime); it now interpolates the remaining ``text``.
        - ``isinstance(img, Image)`` tested against the PIL.Image *module*
          (TypeError); it now tests against the ``Image.Image`` class.
        """
        split_token = "<|im_end|>\n"

        def process_text_item(item):
            # Returns (templated_text, pil_image_or_None) for a single input item.
            if isinstance(item, str):
                return item, None
            text, img = "", None
            if "image" in item:
                # Placeholder tokens are expanded to patch embeddings in forward().
                text += "<|vision_start|><|image_pad|><|vision_end|>"
                img = item["image"]
                if isinstance(img, bytes):
                    img = Image.open(BytesIO(img)).convert("RGB")
                elif isinstance(img, str):
                    img = Image.open(img).convert("RGB")
                elif not isinstance(img, Image.Image):
                    raise ValueError(f"Unknown image type {type(img)}")
            if "text" in item:
                text += item["text"].lstrip()
            if split_token in text:
                # Input already carries an instruction turn: keep it, wrap the rest.
                instruction, text = text.split(split_token, 1)
                text = f'{instruction}{split_token}<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
            else:
                text = f"<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
            return text, img

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.append(images)

        if any(img is not None for img in all_images):
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
        else:
            # Text-only batch: omit images entirely so the processor skips the vision path.
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
        return inputs
|