from .modeling_llama import AdapterMLP, DEFAULT_SYSTEM_PROMPT, LlamaForCausalLM
from .configuration_llama import VLMConfig
from .configuration_clip import CLIPConfig
from .visual_modeling import CLIPModel
import torch
from torch import nn
from transformers import AutoProcessor

class AtriVLM(LlamaForCausalLM):
    def __init__(self, config: VLMConfig):
        super().__init__(config)
        if not config.special_token_map:
            raise ValueError("Special token map not found")
        self.image_start_token_id = config.special_token_map['Image'][1]
        self.image_end_token_id = config.special_token_map['Image_End'][1]
        self.caption_token_id = config.special_token_map['Caption'][1]
        self.image_token_id = config.special_token_map['Image_Token'][1]
        self.image_adapter = AdapterMLP(config)
        self.num_patches = config.num_patches
        self.processor = AutoProcessor.from_pretrained(config.pretrained_vision_model).image_processor
        self.img_place_holder = "<IMGPLH>"
        self.img_start_token = "<IMAGE>"
        self.img_end_token = "<IMAGE_END>"
        self.image_token = "<Image_Token>"
        if config.load_vision_model:
            if isinstance(config.visual_config, dict):
                self.visual = CLIPModel(CLIPConfig(**config.visual_config))
            else:
                self.visual = CLIPModel(config.visual_config)
        else:
            self.visual = None
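
        # Illustrative only: a hypothetical shape for config.special_token_map, as implied
        # by the [1] indexing above (names map to (token_string, token_id) pairs). The
        # Image/Image_End/Image_Token strings mirror the attributes set in this __init__;
        # the Caption string is a guess.
        #
        #   special_token_map = {
        #       "Image":       ("<IMAGE>",       image_start_id),
        #       "Image_End":   ("<IMAGE_END>",   image_end_id),
        #       "Caption":     ("<Caption>",     caption_id),
        #       "Image_Token": ("<Image_Token>", image_token_id),
        #   }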
    
    def forward(self, input_ids=None, encoded_image=None, labels=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        """
        Forward pass for the VLM model that combines image and text embeddings.

        Args:
            input_ids (torch.LongTensor): Input token ids of shape (batch_size, seq_len)
            encoded_image (torch.FloatTensor): Encoded image features of shape (batch_size, num_patches, hidden_dim)
            labels (torch.LongTensor): Labels for computing the language modeling loss
            past_key_values: Cached key/value states from a previous forward pass; when present, image embeddings are not re-injected
            attention_mask (torch.LongTensor): Attention mask of shape (batch_size, seq_len)

        Returns:
            The output of the underlying language model forward pass.
        """
        if not past_key_values and (encoded_image is not None):
            encoded_image = encoded_image.to(self.get_input_embeddings().weight.dtype)
            # Process image features through the adapter
            processed_image = self.image_adapter(encoded_image)

            # Get embeddings for all input tokens
            token_embeddings = self.get_input_embeddings()(input_ids)
            
            # Find positions of image tokens and replace them with processed image embeddings
            image_token_positions = (input_ids == self.image_token_id).nonzero(as_tuple=True)
            token_embeddings = token_embeddings.clone()  # clone before replacing the image-token slots in place
            token_embeddings[image_token_positions] = processed_image.reshape(-1, processed_image.size(-1))
        else:
            token_embeddings = self.get_input_embeddings()(input_ids)
        # Call the native forward method with the modified embeddings
        outputs = self._native_forward(
            inputs_embeds=token_embeddings,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )
        
        return outputs
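
    # Illustrative only: a minimal sketch of how forward() might be called during
    # training. The dataloader/batch names below are assumptions, not part of this
    # repo; only the keyword arguments mirror the signature above.
    #
    #   batch = next(iter(dataloader))
    #   outputs = model(
    #       input_ids=batch["input_ids"],          # (B, T), contains <Image_Token> runs
    #       encoded_image=batch["encoded_image"],  # (B, num_patches, vision_hidden)
    #       attention_mask=batch["attention_mask"],
    #       labels=batch["labels"],
    #   )
    #   outputs.loss.backward()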
    
    
    def prepare_input_ids_for_generation(self, prompts, images, tokenizer, system_prompt=DEFAULT_SYSTEM_PROMPT):
        """

        Prepare input ids and images for generation.

        

        Args:

            prompts (List[str]): List of text prompts

            images (List[Image]): List of images corresponding to prompts

            tokenizer: Tokenizer instance

            system_prompt (str): System prompt to be prepended

            

        Returns:

            dict: Contains input_ids, attention_mask, and processed images

        """
        # Process the images first
        processed_images = []
        for image in images:
            # Process image through vision encoder
            pixel_values = self.processor(image, return_tensors="pt")["pixel_values"]
            pixel_values = pixel_values.to(self.visual.vision_model.embeddings.patch_embedding.weight.device)
            image_features = self.visual.encode_image(pixel_values)
            processed_images.append(image_features)
        
        # Stack all processed images; leave as None when no images were provided
        if processed_images:
            processed_images = torch.cat(processed_images, dim=0)
        else:
            processed_images = None

        # Process each prompt
        formatted_prompts = []
        for prompt in prompts:
            # Replace image placeholder with tokens
            if self.img_place_holder in prompt:
                image_token_sequence = (
                    f"{self.img_start_token}" + 
                    f"{self.image_token}" * self.num_patches +
                    f"{self.img_end_token}"
                )
                formatted_prompt = prompt.replace(self.img_place_holder, image_token_sequence)
            else:
                formatted_prompt = prompt
                
            # Create conversation format
            conversation = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": formatted_prompt},
            ]
            
            # Apply chat template
            formatted_conversation = tokenizer.apply_chat_template(
                conversation, 
                tokenize=False,
                add_generation_prompt=True
            )
            formatted_prompts.append(formatted_conversation)
        
        # Tokenize all prompts together
        tokenized_output = tokenizer(
            formatted_prompts,
            padding=True,
            return_tensors="pt",
            padding_side="left"  # Use left padding since we're generating on the right
        )
        
        return {
            "input_ids": tokenized_output["input_ids"],
            "attention_mask": tokenized_output["attention_mask"],
            "encoded_image": processed_images,
        }
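
    # Illustrative only: a hypothetical call of the helper above. `tokenizer` and
    # `pil_image` are placeholders; the <IMGPLH> marker in the prompt is replaced by
    # <IMAGE> + num_patches * <Image_Token> + <IMAGE_END> before tokenization.
    #
    #   inputs = model.prepare_input_ids_for_generation(
    #       prompts=["Describe this picture: <IMGPLH>"],
    #       images=[pil_image],
    #       tokenizer=tokenizer,
    #   )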
    
    def prepare_for_generation(self, input_ids, encoded_image, **kwargs):
        """

        Prepare KV cache for generation by processing the image and initial tokens.

        

        Args:

            input_ids (torch.LongTensor): Input token ids of shape (batch_size, seq_len)

            encoded_image (torch.FloatTensor): Encoded image features of shape (batch_size, num_patches, hidden_dim)

            

        Returns:

            past_key_values: Tuple containing the key and value states to be used for subsequent generation

        """
        encoded_image = encoded_image.to(self.get_input_embeddings().weight.dtype)
        # Process image features through the adapter
        processed_image = self.image_adapter(encoded_image)
        
        # Get embeddings for all input tokens
        token_embeddings = self.get_input_embeddings()(input_ids)
        
        # Find positions of image tokens and replace them with processed image embeddings
        image_token_positions = (input_ids == self.image_token_id).nonzero(as_tuple=True)
        token_embeddings[image_token_positions] = processed_image.reshape(-1, processed_image.size(-1))
        
        # Forward pass with cache preparation
        outputs = self._native_forward(
            inputs_embeds=token_embeddings,
            use_cache=True,
            **kwargs
        )
        
        return outputs.past_key_values
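
# Illustrative only: a hypothetical end-to-end generation flow. The checkpoint path,
# tokenizer, and decoding loop are assumptions; this file only provides the two
# preparation helpers and forward().
#
#   model = AtriVLM.from_pretrained("path/to/checkpoint")
#   inputs = model.prepare_input_ids_for_generation(prompts, images, tokenizer)
#   past_key_values = model.prepare_for_generation(
#       inputs["input_ids"],
#       inputs["encoded_image"],
#       attention_mask=inputs["attention_mask"],
#   )
#   # past_key_values can then seed a decoding loop that calls forward() with
#   # past_key_values set, so the image embeddings are not re-injected each step.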