from io import BytesIO
from typing import Any, Dict, List, Optional, Union
import torch
from PIL import Image
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import AutoModelForVision2Seq, AutoProcessor


class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        min_image_tokens: int = 256,
        max_image_tokens: int = 1280,
        max_length: int = 1800,
        **kwargs,
    ):
        super().__init__(model_name_or_path, cache_dir=cache_dir, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        tokenizer_args.pop("trust_remote_code", None)

        # Initialize processor
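        # One image token covers a 28x28 pixel area in Qwen2-VL's processor
        # (14x14 ViT patches merged 2x2), so token budgets map to pixel budgets.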
        min_pixels = min_image_tokens * 28 * 28
        max_pixels = max_image_tokens * 28 * 28
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
        )
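        # Right padding pairs with the last-token gathering in forward().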
        self.processor.tokenizer.padding_side = 'right'
        self.sep = ' '
        self.max_length = max_length
        self.normalize = True

    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
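        # Overrides BaseTransformer._load_model so the backbone is loaded as a
        # vision-to-sequence model rather than a plain text encoder.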
        model_args.pop("trust_remote_code", None)
        self.auto_model = AutoModelForVision2Seq.from_pretrained(
            model_name_or_path, torch_dtype=torch.float16, **model_args
        )

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:       
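        # Build text embeddings manually so the <|image_pad|> positions can be
        # replaced with vision-encoder outputs before running the language model.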
        if features.get("inputs_embeds", None) is None:
            features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
            if features.get("pixel_values", None) is not None:
                features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
                image_embeds = self.auto_model.visual(
                    features["pixel_values"], grid_thw=features["image_grid_thw"]
                )
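                # Scatter visual embeddings into the <|image_pad|> token slots.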
                image_mask = features["input_ids"] == self.auto_model.config.image_token_id
                features["inputs_embeds"][image_mask] = image_embeds
                features.pop("pixel_values")
                features.pop("image_grid_thw")
        features.pop("input_ids")       
        outputs = self.auto_model.model(
            **features,
            return_dict=True,
            output_hidden_states=True,
        )
        if features.get("pooling_mask") is not None:
            pooling_mask = features["pooling_mask"]
        else:
            pooling_mask = features["attention_mask"]
        # If every sequence attends at the final position, the batch is
        # left-padded and the last position holds each sequence's final token.
        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]
        if left_padding:
            embeddings = outputs.last_hidden_state[:, -1]
        else:
            # Right padding: gather each sequence's last non-padded position.
            sequence_lengths = pooling_mask.sum(dim=1) - 1
            embeddings = outputs.last_hidden_state[torch.arange(
                outputs.last_hidden_state.shape[0], device=outputs.last_hidden_state.device
            ), sequence_lengths]
        features.update({"token_embeddings": embeddings})
        return features 

    def tokenize(self, texts: List[Union[str, Dict[str, Any]]]) -> Dict[str, torch.Tensor]:
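        # Items are plain strings or dicts with optional "image" (bytes, file
        # path, or PIL.Image) and "text" keys, formatted with the chat template.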
        split_token = "<|im_end|>\n"
        def process_text_item(item):
            if isinstance(item, str):
                return item, None

            text, img = "", None
            if "image" in item:
                text += "<|vision_start|><|image_pad|><|vision_end|>"
                img = item["image"]
                if isinstance(img, bytes):
                    img = Image.open(BytesIO(img)).convert("RGB")
                elif isinstance(img, str):
                    img = Image.open(img).convert("RGB")
                elif not isinstance(img, Image.Image):
                    raise ValueError(f"Unknown image type {type(img)}")
            if "text" in item:
                text += item["text"].lstrip()
            if split_token in text:
                instruction, text = text.split(split_token, 1)
                text = f"{instruction}{split_token}<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
            else:
                text = f"<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
            return text, img

        all_texts, all_images = [], []
        for item in texts:
            text, img = process_text_item(item)
            all_texts.append(text)
            all_images.append(img)
        
        if any(img is not None for img in all_images):
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
        return inputs
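

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The checkpoint
    # path and image file below are placeholders: substitute a real
    # Qwen2-VL-based embedding checkpoint and image. In real use, move the
    # tokenized batch to the model's device before calling forward().
    model = MultiModalTransformer("path/to/qwen2-vl-checkpoint")  # hypothetical path

    batch = model.tokenize([
        "What is shown in this picture?",
        {"image": "example.jpg", "text": "A reference document."},  # hypothetical file
    ])
    with torch.no_grad():
        features = model.forward(dict(batch))
    # One embedding per input when right padding triggers last-token gathering.
    print(features["token_embeddings"].shape)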