from model_components import ViT, MultiModalProjector
from decoder_language_model import DecoderLanguageModel
from constants import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import tokenizer, vocab_size


class VisionLanguageModel(nn.Module):
    """
    Vision Language Model integrating ViT, Projector, Contrastive Loss, Decoder (Class + Reg).
    Handles multiple points via padded regression targets and masked loss.
    """
    def __init__(self,
                 n_embd=HIDDEN_DIM,
                 vocab_size=vocab_size,
                 img_size=IMAGE_SIZE,
                 patch_size=PATCH_SIZE,
                 num_heads=NUM_HEADS,
                 num_blks_vit=NUM_LAYERS,
                 num_blks_dec=NUM_LAYERS,
                 emb_dropout=DROPOUT,
                 blk_dropout=DROPOUT,
                 max_context=CONTEXT_LENGTH,
                 shared_embed_dim=SHARED_EMBED_DIM,
                 lambda_contrastive=LAMBDA_CONTRASTIVE,
                 lambda_regression=LAMBDA_REGRESSION,
                 max_points=MAX_POINTS
                 ):
        super().__init__()

        # --- Vision Backbone ---
        self.vision_encoder = ViT(
            img_size=img_size,
            patch_size=patch_size,
            num_hiddens=n_embd, # Assuming ViT output dim matches decoder embed dim
            num_heads=num_heads,
            num_blks=num_blks_vit,
            emb_dropout=emb_dropout,
            blk_dropout=blk_dropout
        )

        # --- Multimodal Components ---
        self.multimodal_projector = MultiModalProjector(
            image_embed_dim=n_embd, # Input from ViT
            text_embed_dim=n_embd,  # Output matches decoder dim
            dropout=emb_dropout
        )
        self.image_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
        self.text_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
        # Learnable temperature, initialized to log(1/0.07) as in CLIP
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))

        # --- Text Decoder ---
        # DecoderLanguageModel's regression head outputs MAX_POINTS * 2 values (x, y per point)
        self.decoder = DecoderLanguageModel(
            n_embd=n_embd,
            vocab_size=vocab_size,
            num_heads=num_heads,
            n_layer=num_blks_dec,
            max_context=max_context,
            dropout=blk_dropout # Use block dropout for decoder consistency
        )

        # --- Store Configuration ---
        self.n_embd = n_embd
        self.vocab_size = vocab_size
        self.num_patches = (img_size // patch_size)**2 + 1  # Patch tokens plus the CLS token
        self.lambda_contrastive = lambda_contrastive
        self.lambda_regression = lambda_regression
        self.max_points = max_points # Store max points

        self._resize_embeddings_if_needed(self.vocab_size)
        print("VisionLanguageModel initialized.")


    def _resize_embeddings_if_needed(self, current_vocab_size):
        """ Resizes decoder token embeddings if vocab size changed after init. """
        decoder_embedding_size = self.decoder.token_embedding_table.num_embeddings
        if decoder_embedding_size != current_vocab_size:
            print(f"Resizing VLM decoder token embeddings from {decoder_embedding_size} to {current_vocab_size}")
            # Freeze original weights before replacing layers
            self.decoder.token_embedding_table.weight.requires_grad = False
            self.decoder.lm_head.weight.requires_grad = False
            # Create new layers
            new_embedding = nn.Embedding(current_vocab_size, self.n_embd).to(DEVICE)
            new_lm_head = nn.Linear(self.n_embd, current_vocab_size, bias=False).to(DEVICE)
            # Assign new layers
            self.decoder.token_embedding_table = new_embedding
            self.decoder.lm_head = new_lm_head
            # Re-tie weights: the embedding table shares the lm_head's (V, C) weight matrix
            self.decoder.token_embedding_table.weight = self.decoder.lm_head.weight
            print("VLM decoder embeddings resized and weights retied.")


    def _calculate_contrastive_loss(self, image_features, text_features):
        """ Calculates the symmetric InfoNCE loss. """
        # Assumes features are already projected to shared_embed_dim
        # image_features: (B, E)
        # text_features: (B, E)

        # Normalize features
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        # Cosine similarity as logits (using learnable temperature)
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # Calculate symmetric cross-entropy loss
        labels = torch.arange(len(logits_per_image), device=logits_per_image.device)
        loss_i = F.cross_entropy(logits_per_image, labels)
        loss_t = F.cross_entropy(logits_per_text, labels)
        contrastive_loss = (loss_i + loss_t) / 2.0
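        # For reference, the quantity above is the symmetric InfoNCE objective:
        #   L = 1/2 * [CE(s * I @ T^T, y) + CE(s * T @ I^T, y)]
        # where s = exp(logit_scale), I and T are the normalized image/text
        # features, and y = [0..B-1] places matched pairs on the diagonal.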

        # Guard against NaNs (e.g. from a degenerate batch)
        if torch.isnan(contrastive_loss):
            print("Warning: Contrastive loss is NaN.")
            return None  # Caller treats None as "skip this loss term"

        return contrastive_loss

    def forward(self,
                img_array,
                prompt_ids,
                prompt_attention_mask,
                target_ids,
                target_attention_mask,
                generative_targets=None,
                continuous_coords=None, # Now expects shape (B, MAX_POINTS, 2), padded
                coords_mask=None        # Mask for valid points (B, MAX_POINTS)
                ):
        """
        Main forward pass for training. Calculates combined loss with masked regression loss.
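
        The combined objective assembled below is:
            total_loss = class_loss
                         + lambda_contrastive * contrastive_loss
                         + lambda_regression * regression_loss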
        """

        # --- 1. Encode Image ---
        image_embeds_raw = self.vision_encoder(img_array) # (B, N_img, C)
        B, N_img, C_img = image_embeds_raw.shape
        img_cls_token = image_embeds_raw[:, 0]

        # --- 2. Contrastive Loss Path ---
        # Project the image CLS token and the last prompt-token embedding into the shared space
        image_features_contrast = self.image_contrastive_head(img_cls_token)
        with torch.no_grad():  # The contrastive path does not backpropagate into the prompt token embeddings
            prompt_text_embeds_contrast = self.decoder.token_embedding_table(prompt_ids)
        prompt_lengths = prompt_attention_mask.sum(dim=1)
        last_token_indices = (prompt_lengths - 1).clamp(min=0)
        gather_indices = last_token_indices.view(B, 1, 1).expand(-1, -1, C_img)
        prompt_last_token_embed = prompt_text_embeds_contrast.gather(1, gather_indices).squeeze(1)
        text_features_contrast = self.text_contrastive_head(prompt_last_token_embed)
        contrastive_loss = self._calculate_contrastive_loss(image_features_contrast, text_features_contrast)


        # --- 3. Generative / Regression Path ---
        image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
        prompt_embeds_decoder = self.decoder.token_embedding_table(prompt_ids)
        target_embeds_decoder = self.decoder.token_embedding_table(target_ids)
        B, T_prompt, C = prompt_embeds_decoder.shape
        B, T_target, _ = target_embeds_decoder.shape

        # Prepare combined input sequence and attention mask for the decoder
        combined_embeds = torch.cat([
            image_embeds_decoder, prompt_embeds_decoder, target_embeds_decoder
        ], dim=1)
        combined_attention_mask = torch.cat([
            torch.ones(B, N_img, dtype=torch.long, device=DEVICE),
            prompt_attention_mask,
            target_attention_mask
        ], dim=1)
        T_combined = combined_embeds.shape[1]

        # Prepare combined targets for the classification loss
        combined_class_targets = None
        if generative_targets is not None:
            combined_class_targets = torch.cat([
                torch.full((B, N_img + T_prompt), -100, dtype=torch.long, device=DEVICE),
                generative_targets
            ], dim=1)
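        # The -100 fill covers the image and prompt positions; assuming the decoder's
        # cross-entropy uses PyTorch's default ignore_index=-100, only the target
        # tokens contribute to class_loss.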

        # --- Pass through Decoder ---
        logits, class_loss, x_norm = self.decoder(
            combined_embeds,
            attention_mask=combined_attention_mask,
            targets=combined_class_targets
        )
        # x_norm shape: (B, T_combined, C)

        # --- Calculate Regression Output & Loss (Modified for multiple points) ---
        regression_loss = None
        regression_output = None
        if continuous_coords is not None and coords_mask is not None and x_norm is not None:
            # Strategy: Use hidden state corresponding to token *before* <result_end> (or <eos>)
            # This single state predicts coordinates for *all* MAX_POINTS.
            target_lengths = target_attention_mask.sum(dim=1) # Length of actual target tokens (B,)
            # Index relative to start of *target sequence* is length - 2 (token before <eos>/<result_end>)
            relative_target_idx = (target_lengths - 2).clamp(min=0)
            # Absolute index in the combined sequence's hidden states (x_norm)
            absolute_idx = N_img + T_prompt + relative_target_idx
            absolute_idx = absolute_idx.clamp(max=T_combined - 1) # Clamp index
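            # Illustrative arithmetic with hypothetical sizes: for N_img = 65 image
            # tokens, T_prompt = 10, and a target of length 7, relative_target_idx
            # = 7 - 2 = 5 (the second-to-last target token), so absolute_idx
            # = 65 + 10 + 5 = 80 within the combined hidden-state sequence.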

            # Gather the hidden states at these specific indices
            gather_indices_reg = absolute_idx.view(B, 1, 1).expand(-1, -1, C)
            try:
                hidden_state_for_regression = x_norm.gather(1, gather_indices_reg).squeeze(1) # Shape: (B, C)
                # Pass through the regression head
                regression_output_flat = self.decoder.regression_head(hidden_state_for_regression) # Shape: (B, MAX_POINTS * 2)
                # Reshape to (B, MAX_POINTS, 2)
                regression_output = regression_output_flat.view(B, self.max_points, 2)

                # --- Calculate MASKED regression loss (L1 - Mean Absolute Error) ---
                loss_per_coord = F.l1_loss(regression_output, continuous_coords, reduction='none') # (B, MAX_POINTS, 2)
                # Apply mask (mask is (B, MAX_POINTS), need to broadcast to (B, MAX_POINTS, 2))
                masked_loss = loss_per_coord * coords_mask.unsqueeze(-1)
                # Sum loss over valid points and coordinates, divide by number of valid coordinates
                num_valid_coords = coords_mask.sum() * 2 # Total number of valid x,y values in batch
                if num_valid_coords > 0:
                    regression_loss = masked_loss.sum() / num_valid_coords
                else:
                    regression_loss = torch.tensor(0.0, device=DEVICE) # No valid points in batch

                if torch.isnan(regression_loss):
                    print("Warning: Regression loss is NaN.")
                    regression_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True) # Set to zero tensor if NaN


            except Exception as e:
                print(f"Error during regression calculation: {e}")
                print(f"x_norm shape: {x_norm.shape}, absolute_idx: {absolute_idx}")
                regression_loss = None
                regression_output = None  # Ensure output is None if an error occurs


        # --- 4. Combine All Losses ---
        total_loss = torch.tensor(0.0, device=DEVICE)  # Accumulator; gains grad tracking once a grad-bearing loss is added
        # Add valid losses with their respective weights
        loss_log = {}
        if class_loss is not None and torch.isfinite(class_loss):
            total_loss += class_loss  # Generative loss carries an implicit weight of 1.0
            loss_log["class_loss"] = class_loss.item()
        else:
            # If class_loss is None or NaN/Inf, don't add it, log NaN
            loss_log["class_loss"] = float('nan')
            print(f"Warning: Invalid class_loss ({class_loss})")


        if contrastive_loss is not None and torch.isfinite(contrastive_loss):
            total_loss += self.lambda_contrastive * contrastive_loss
            loss_log["contrastive_loss"] = contrastive_loss.item()
        else:
            loss_log["contrastive_loss"] = float('nan')
            print(f"Warning: Invalid contrastive_loss ({contrastive_loss})")


        if regression_loss is not None and torch.isfinite(regression_loss):
            total_loss += self.lambda_regression * regression_loss
            loss_log["regression_loss"] = regression_loss.item()
        else:
            loss_log["regression_loss"] = float('nan')
            # Don't warn when the loss was intentionally zeroed (no valid points in the batch)
            if regression_loss is not None and not (regression_loss == 0.0 and num_valid_coords == 0):
                print(f"Warning: Invalid regression_loss ({regression_loss})")


        # Handle case where total loss becomes NaN/Inf
        if not torch.isfinite(total_loss):
            print(f"Warning: Total loss became non-finite ({total_loss}). Setting to zero and clearing gradients.")
            total_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
            # It might be safer to skip the optimizer step entirely here, handled in training loop

        # Use the loss_log dictionary for clearer logging later
        class_loss_val = loss_log["class_loss"]
        contrastive_loss_val = loss_log["contrastive_loss"]
        regression_loss_val = loss_log["regression_loss"]

        # Return all relevant outputs (use scalar values for loss logging)
        return (logits, regression_output, total_loss,
                torch.tensor(class_loss_val), torch.tensor(contrastive_loss_val),
                torch.tensor(regression_loss_val))


    # --- Generation Method ---
    @torch.no_grad() # Ensure no gradients are computed during generation
    def generate(self, img_array, idx_prompt, max_new_tokens,
                 temperature=1.0, top_k=None,  # temperature=0.0 or top_k=1 yields greedy decoding
                 force_result_start=True # Option to manually add <result_start>
                 ):
        """
        Generates token sequences autoregressively based on image and prompt.
        Uses the classification head (lm_head).

        Args:
            img_array (torch.Tensor): Input image tensor (B, 3, H, W). B should be 1 for this impl.
            idx_prompt (torch.Tensor): Input prompt token IDs (B, T_prompt).
            max_new_tokens (int): Maximum number of new tokens to generate.
            temperature (float): Softmax temperature. 1.0 means no change. Lower values make it sharper.
            top_k (int | None): If set, restricts sampling to top K most likely tokens.
            force_result_start (bool): If True, manually appends <result_start> embedding
                                       after the prompt before starting generation loop.

        Returns:
            torch.Tensor: Generated sequence IDs, including the prompt (B, T_prompt + T_generated).
        """
        self.eval() # Ensure model is in eval mode
        B = img_array.shape[0]
        if B > 1:
            # This simplified generation loop assumes B=1 for clarity
            # Batch generation requires careful handling of EOS and padding within the loop
            print("Warning: Generation function currently assumes batch size B=1.")
            # Process only the first item for now
            img_array = img_array[:1]
            idx_prompt = idx_prompt[:1]
            B = 1

        # --- 1. Prepare Initial Embeddings ---
        image_embeds_raw = self.vision_encoder(img_array)
        image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
        prompt_embeds_decoder = self.decoder.token_embedding_table(idx_prompt)

        # Initial sequence for the decoder loop
        current_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder], dim=1)
        generated_ids_list = [] # Store newly generated IDs as a list

        # Manually add <result_start> if forced
        if force_result_start:
            try:
                result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
                result_start_embed = self.decoder.token_embedding_table(
                    torch.tensor([[result_start_token_id]], device=DEVICE)
                )
                current_embeds = torch.cat([current_embeds, result_start_embed], dim=1)
                # Also store this token ID so it appears in the returned sequence
                generated_ids_list.append(torch.tensor([[result_start_token_id]], device=DEVICE))
            except Exception as e:
                print(f"Warning: Could not encode or add <result_start>: {e}")


        # --- 2. Autoregressive Loop ---
        for _ in range(max_new_tokens):
            T_current = current_embeds.shape[1]

            # Context truncation
            if T_current > self.decoder.max_context:
                current_embeds = current_embeds[:, -self.decoder.max_context:, :]
                T_current = self.decoder.max_context

            # Prepare inputs for decoder blocks
            pos = torch.arange(0, T_current, dtype=torch.long, device=DEVICE)
            pos = pos.clamp(max=self.decoder.max_context - 1)
            pos_emb = self.decoder.position_embedding_table(pos).unsqueeze(0)
            x = current_embeds + pos_emb
            attention_mask = torch.ones(B, T_current, device=DEVICE, dtype=torch.long) # No padding needed

            # Pass through decoder blocks
            for block in self.decoder.blocks:
                x = block(x, attention_mask=attention_mask)

            # Get logits for the last token
            x = self.decoder.ln_f(x[:, -1:, :]) # (B, 1, C)
            logits = self.decoder.lm_head(x)    # (B, 1, V)
            logits = logits.squeeze(1)  # (B, V)
            if temperature > 0:
                logits = logits / temperature  # Skip scaling at temperature == 0 to avoid division by zero (greedy path)

            # --- Sampling / Decoding ---
            # Optional: Top-K filtering
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf') # Apply mask

            # Get probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample the next token ID; for deterministic (greedy) output use argmax
            if temperature == 0.0 or top_k == 1:  # Greedy condition
                idx_next = torch.argmax(probs, dim=-1, keepdim=True)
            else:
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # Append the generated token ID
            generated_ids_list.append(idx_next)

            # Stop if EOS is generated
            if hasattr(tokenizer, 'eos_token_id') and idx_next.item() == tokenizer.eos_token_id:
                break

            # Prepare for next iteration
            next_token_embed = self.decoder.token_embedding_table(idx_next)
            current_embeds = torch.cat([current_embeds, next_token_embed], dim=1)


        # --- 3. Combine results ---
        if generated_ids_list:
            generated_ids_tensor = torch.cat(generated_ids_list, dim=1) # (B, T_generated)
            full_sequence_ids = torch.cat([idx_prompt, generated_ids_tensor], dim=1)
        else:
            full_sequence_ids = idx_prompt # Return only prompt if nothing generated

        self.train() # Set model back to training mode
        return full_sequence_ids
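

# Minimal smoke-test sketch, not part of the original training pipeline. It assumes
# the sibling modules (model_components, decoder_language_model, constants, utils)
# are importable and that constants define DEVICE, IMAGE_SIZE, and MAX_POINTS as
# used above; all shapes and token IDs below are illustrative dummies.
if __name__ == "__main__":
    model = VisionLanguageModel().to(DEVICE)

    B, T_prompt, T_target = 1, 8, 6
    img = torch.randn(B, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    prompt_ids = torch.randint(0, vocab_size, (B, T_prompt), device=DEVICE)
    prompt_mask = torch.ones(B, T_prompt, dtype=torch.long, device=DEVICE)
    target_ids = torch.randint(0, vocab_size, (B, T_target), device=DEVICE)
    target_mask = torch.ones(B, T_target, dtype=torch.long, device=DEVICE)
    # Real data would shift the targets by one position and pad with -100
    gen_targets = target_ids.clone()
    coords = torch.rand(B, MAX_POINTS, 2, device=DEVICE)
    coords_mask = torch.zeros(B, MAX_POINTS, device=DEVICE)
    coords_mask[:, 0] = 1.0  # Pretend exactly one valid point per sample

    logits, reg_out, total_loss, cls_l, con_l, reg_l = model(
        img, prompt_ids, prompt_mask, target_ids, target_mask,
        generative_targets=gen_targets,
        continuous_coords=coords,
        coords_mask=coords_mask,
    )
    print(f"total_loss={total_loss.item():.4f}, logits shape={tuple(logits.shape)}")

    # Greedy generation for a handful of tokens
    out_ids = model.generate(img, prompt_ids, max_new_tokens=5, temperature=0.0)
    print(f"generated sequence shape: {tuple(out_ids.shape)}")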