from model_components import ViT, MultiModalProjector from decoder_language_model import DecoderLanguageModel from constants import * import torch import torch.nn as nn import torch.nn.functional as F from utils import tokenizer, vocab_size class VisionLanguageModel(nn.Module): """ Vision Language Model integrating ViT, Projector, Contrastive Loss, Decoder (Class + Reg). Handles multiple points via padded regression targets and masked loss. """ def __init__(self, n_embd=HIDDEN_DIM, vocab_size=vocab_size, img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, num_heads=NUM_HEADS, num_blks_vit=NUM_LAYERS, num_blks_dec=NUM_LAYERS, emb_dropout=DROPOUT, blk_dropout=DROPOUT, max_context=CONTEXT_LENGTH, shared_embed_dim=SHARED_EMBED_DIM, lambda_contrastive=LAMBDA_CONTRASTIVE, lambda_regression=LAMBDA_REGRESSION, # Use the updated constant max_points = MAX_POINTS # Store max points ): super().__init__() # --- Vision Backbone --- self.vision_encoder = ViT( img_size=img_size, patch_size=patch_size, num_hiddens=n_embd, # Assuming ViT output dim matches decoder embed dim num_heads=num_heads, num_blks=num_blks_vit, emb_dropout=emb_dropout, blk_dropout=blk_dropout ) # --- Multimodal Components --- self.multimodal_projector = MultiModalProjector( image_embed_dim=n_embd, # Input from ViT text_embed_dim=n_embd, # Output matches decoder dim dropout=emb_dropout ) self.image_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False) self.text_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False) self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07))) # --- Text Decoder --- # DecoderLanguageModel now has regression head outputting MAX_POINTS*2 self.decoder = DecoderLanguageModel( n_embd=n_embd, vocab_size=vocab_size, num_heads=num_heads, n_layer=num_blks_dec, max_context=max_context, dropout=blk_dropout # Use block dropout for decoder consistency ) # --- Store Configuration --- self.n_embd = n_embd self.vocab_size = vocab_size self.num_patches = (img_size // patch_size)**2 + 1 self.lambda_contrastive = lambda_contrastive self.lambda_regression = lambda_regression self.max_points = max_points # Store max points self._resize_embeddings_if_needed(self.vocab_size) print("VisionLanguageModel initialized.") def _resize_embeddings_if_needed(self, current_vocab_size): """ Resizes decoder token embeddings if vocab size changed after init. """ decoder_embedding_size = self.decoder.token_embedding_table.num_embeddings if decoder_embedding_size != current_vocab_size: print(f"Resizing VLM decoder token embeddings from {decoder_embedding_size} to {current_vocab_size}") # Freeze original weights before replacing layers self.decoder.token_embedding_table.weight.requires_grad = False self.decoder.lm_head.weight.requires_grad = False # Create new layers new_embedding = nn.Embedding(current_vocab_size, self.n_embd).to(DEVICE) new_lm_head = nn.Linear(self.n_embd, current_vocab_size, bias=False).to(DEVICE) # Assign new layers self.decoder.token_embedding_table = new_embedding self.decoder.lm_head = new_lm_head # Re-tie weights self.decoder.token_embedding_table.weight = self.decoder.lm_head.weight print("VLM decoder embeddings resized and weights retied.") def _calculate_contrastive_loss(self, image_features, text_features): """ Calculates the symmetric InfoNCE loss. """ # Assumes features are already projected to shared_embed_dim # image_features: (B, E) # text_features: (B, E) # Normalize features image_features = F.normalize(image_features, dim=-1) text_features = F.normalize(text_features, dim=-1) # Cosine similarity as logits (using learnable temperature) logit_scale = self.logit_scale.exp() logits_per_image = logit_scale * image_features @ text_features.t() logits_per_text = logits_per_image.t() # Calculate symmetric cross-entropy loss labels = torch.arange(len(logits_per_image), device=logits_per_image.device) loss_i = F.cross_entropy(logits_per_image, labels) loss_t = F.cross_entropy(logits_per_text, labels) contrastive_loss = (loss_i + loss_t) / 2.0 # Handle potential NaNs if torch.isnan(contrastive_loss): print("Warning: Contrastive loss is NaN.") return None # Return None or zero tensor return contrastive_loss def forward(self, img_array, prompt_ids, prompt_attention_mask, target_ids, target_attention_mask, generative_targets=None, continuous_coords=None, # Now expects shape (B, MAX_POINTS, 2), padded coords_mask=None # Mask for valid points (B, MAX_POINTS) ): """ Main forward pass for training. Calculates combined loss with masked regression loss. """ # --- 1. Encode Image --- image_embeds_raw = self.vision_encoder(img_array) # (B, N_img, C) B, N_img, C_img = image_embeds_raw.shape img_cls_token = image_embeds_raw[:, 0] # --- 2. Contrastive Loss Path --- contrastive_loss = None # ... (contrastive loss calculation - same as before) ... image_features_contrast = self.image_contrastive_head(img_cls_token) with torch.no_grad(): # Keep no_grad here for efficiency if prompt embeddings aren't trained via contrastive prompt_text_embeds_contrast = self.decoder.token_embedding_table(prompt_ids) prompt_lengths = prompt_attention_mask.sum(dim=1) last_token_indices = (prompt_lengths - 1).clamp(min=0) gather_indices = last_token_indices.view(B, 1, 1).expand(-1, -1, C_img) prompt_last_token_embed = prompt_text_embeds_contrast.gather(1, gather_indices).squeeze(1) text_features_contrast = self.text_contrastive_head(prompt_last_token_embed) contrastive_loss = self._calculate_contrastive_loss(image_features_contrast, text_features_contrast) # --- 3. Generative / Regression Path --- image_embeds_decoder = self.multimodal_projector(image_embeds_raw) prompt_embeds_decoder = self.decoder.token_embedding_table(prompt_ids) target_embeds_decoder = self.decoder.token_embedding_table(target_ids) B, T_prompt, C = prompt_embeds_decoder.shape B, T_target, _ = target_embeds_decoder.shape # Prepare combined input sequence and attention mask for the decoder combined_embeds = torch.cat([ image_embeds_decoder, prompt_embeds_decoder, target_embeds_decoder ], dim=1) combined_attention_mask = torch.cat([ torch.ones(B, N_img, dtype=torch.long, device=DEVICE), prompt_attention_mask, target_attention_mask ], dim=1) T_combined = combined_embeds.shape[1] # Prepare combined targets for the classification loss combined_class_targets = None if generative_targets is not None: combined_class_targets = torch.cat([ torch.full((B, N_img + T_prompt), -100, dtype=torch.long, device=DEVICE), generative_targets ], dim=1) # --- Pass through Decoder --- logits, class_loss, x_norm = self.decoder( combined_embeds, attention_mask=combined_attention_mask, targets=combined_class_targets ) # x_norm shape: (B, T_combined, C) # --- Calculate Regression Output & Loss (Modified for multiple points) --- regression_loss = None regression_output = None if continuous_coords is not None and coords_mask is not None and x_norm is not None: # Strategy: Use hidden state corresponding to token *before* (or ) # This single state predicts coordinates for *all* MAX_POINTS. target_lengths = target_attention_mask.sum(dim=1) # Length of actual target tokens (B,) # Index relative to start of *target sequence* is length - 2 (token before /) relative_target_idx = (target_lengths - 2).clamp(min=0) # Absolute index in the combined sequence's hidden states (x_norm) absolute_idx = N_img + T_prompt + relative_target_idx absolute_idx = absolute_idx.clamp(max=T_combined - 1) # Clamp index # Gather the hidden states at these specific indices gather_indices_reg = absolute_idx.view(B, 1, 1).expand(-1, -1, C) try: hidden_state_for_regression = x_norm.gather(1, gather_indices_reg).squeeze(1) # Shape: (B, C) # Pass through the regression head regression_output_flat = self.decoder.regression_head(hidden_state_for_regression) # Shape: (B, MAX_POINTS * 2) # Reshape to (B, MAX_POINTS, 2) regression_output = regression_output_flat.view(B, self.max_points, 2) # --- Calculate MASKED regression loss (L1 - Mean Absolute Error) --- loss_per_coord = F.l1_loss(regression_output, continuous_coords, reduction='none') # (B, MAX_POINTS, 2) # Apply mask (mask is (B, MAX_POINTS), need to broadcast to (B, MAX_POINTS, 2)) masked_loss = loss_per_coord * coords_mask.unsqueeze(-1) # Sum loss over valid points and coordinates, divide by number of valid coordinates num_valid_coords = coords_mask.sum() * 2 # Total number of valid x,y values in batch if num_valid_coords > 0: regression_loss = masked_loss.sum() / num_valid_coords else: regression_loss = torch.tensor(0.0, device=DEVICE) # No valid points in batch if torch.isnan(regression_loss): print("Warning: Regression loss is NaN.") regression_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True) # Set to zero tensor if NaN except Exception as e: print(f"Error during regression calculation: {e}") print(f"x_norm shape: {x_norm.shape}, absolute_idx: {absolute_idx}") regression_loss = None regression_output = None # Ensure output is None if error occurs # --- 4. Combine All Losses --- total_loss = torch.tensor(0.0, device=DEVICE) # Ensure requires_grad=True # Add valid losses with their respective weights loss_log = {} if class_loss is not None and torch.isfinite(class_loss): total_loss += class_loss # Weight = 1.0 assumed loss_log["class_loss"] = class_loss.item() else: # If class_loss is None or NaN/Inf, don't add it, log NaN loss_log["class_loss"] = float('nan') print(f"Warning: Invalid class_loss ({class_loss})") if contrastive_loss is not None and torch.isfinite(contrastive_loss): total_loss += self.lambda_contrastive * contrastive_loss loss_log["contrastive_loss"] = contrastive_loss.item() else: loss_log["contrastive_loss"] = float('nan') print(f"Warning: Invalid contrastive_loss ({contrastive_loss})") if regression_loss is not None and torch.isfinite(regression_loss): total_loss += self.lambda_regression * regression_loss loss_log["regression_loss"] = regression_loss.item() else: loss_log["regression_loss"] = float('nan') # Don't print warning if it was intentionally set to 0 due to no valid points if regression_loss is not None and not (regression_loss == 0.0 and num_valid_coords == 0): print(f"Warning: Invalid regression_loss ({regression_loss})") # Handle case where total loss becomes NaN/Inf if not torch.isfinite(total_loss): print(f"Warning: Total loss became non-finite ({total_loss}). Setting to zero and clearing gradients.") total_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True) # It might be safer to skip the optimizer step entirely here, handled in training loop # Use the loss_log dictionary for clearer logging later class_loss_val = loss_log["class_loss"] contrastive_loss_val = loss_log["contrastive_loss"] regression_loss_val = loss_log["regression_loss"] # Return all relevant outputs (use scalar values for loss logging) return logits, regression_output, total_loss, \ torch.tensor(class_loss_val), torch.tensor(contrastive_loss_val), torch.tensor(regression_loss_val) # --- Generation Method --- @torch.no_grad() # Ensure no gradients are computed during generation def generate(self, img_array, idx_prompt, max_new_tokens, temperature=1.0, top_k=None, # Default to greedy if temp=1, top_k=None force_result_start=True # Option to manually add ): """ Generates token sequences autoregressively based on image and prompt. Uses the classification head (lm_head). Args: img_array (torch.Tensor): Input image tensor (B, 3, H, W). B should be 1 for this impl. idx_prompt (torch.Tensor): Input prompt token IDs (B, T_prompt). max_new_tokens (int): Maximum number of new tokens to generate. temperature (float): Softmax temperature. 1.0 means no change. Lower values make it sharper. top_k (int | None): If set, restricts sampling to top K most likely tokens. force_result_start (bool): If True, manually appends embedding after the prompt before starting generation loop. Returns: torch.Tensor: Generated sequence IDs, including the prompt (B, T_prompt + T_generated). """ self.eval() # Ensure model is in eval mode B = img_array.shape[0] if B > 1: # This simplified generation loop assumes B=1 for clarity # Batch generation requires careful handling of EOS and padding within the loop print("Warning: Generation function currently assumes batch size B=1.") # Process only the first item for now img_array = img_array[:1] idx_prompt = idx_prompt[:1] B = 1 # --- 1. Prepare Initial Embeddings --- image_embeds_raw = self.vision_encoder(img_array) image_embeds_decoder = self.multimodal_projector(image_embeds_raw) prompt_embeds_decoder = self.decoder.token_embedding_table(idx_prompt) # Initial sequence for the decoder loop current_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder], dim=1) generated_ids_list = [] # Store newly generated IDs as a list # Manually add if forced if force_result_start: try: result_start_token_id = tokenizer.encode("", add_special_tokens=False)[0] result_start_embed = self.decoder.token_embedding_table( torch.tensor([[result_start_token_id]], device=DEVICE) ) current_embeds = torch.cat([current_embeds, result_start_embed], dim=1) # Also store this token ID if we added it generated_ids_list.append(torch.tensor([[result_start_token_id]], device=DEVICE)) except Exception as e: print(f"Warning: Could not encode or add : {e}") # --- 2. Autoregressive Loop --- for _ in range(max_new_tokens): T_current = current_embeds.shape[1] # Context truncation if T_current > self.decoder.max_context: current_embeds = current_embeds[:, -self.decoder.max_context:, :] T_current = self.decoder.max_context # Prepare inputs for decoder blocks pos = torch.arange(0, T_current, dtype=torch.long, device=DEVICE) pos = pos.clamp(max=self.decoder.max_context - 1) pos_emb = self.decoder.position_embedding_table(pos).unsqueeze(0) x = current_embeds + pos_emb attention_mask = torch.ones(B, T_current, device=DEVICE, dtype=torch.long) # No padding needed # Pass through decoder blocks for block in self.decoder.blocks: x = block(x, attention_mask=attention_mask) # Get logits for the last token x = self.decoder.ln_f(x[:, -1:, :]) # (B, 1, C) logits = self.decoder.lm_head(x) # (B, 1, V) logits = logits.squeeze(1) / temperature # Apply temperature (B, V) # --- Sampling / Decoding --- # Optional: Top-K filtering if top_k is not None and top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') # Apply mask # Get probabilities probs = F.softmax(logits, dim=-1) # Sample next token ID # For deterministic output (greedy), use torch.argmax instead of multinomial if temperature == 0.0 or top_k == 1: # Greedy condition idx_next = torch.argmax(probs, dim=-1, keepdim=True) else: idx_next = torch.multinomial(probs, num_samples=1) # (B, 1) # Append the generated token ID generated_ids_list.append(idx_next) # Stop if EOS is generated if hasattr(tokenizer, 'eos_token_id') and idx_next.item() == tokenizer.eos_token_id: break # Prepare for next iteration next_token_embed = self.decoder.token_embedding_table(idx_next) current_embeds = torch.cat([current_embeds, next_token_embed], dim=1) # --- 3. Combine results --- if generated_ids_list: generated_ids_tensor = torch.cat(generated_ids_list, dim=1) # (B, T_generated) full_sequence_ids = torch.cat([idx_prompt, generated_ids_tensor], dim=1) else: full_sequence_ids = idx_prompt # Return only prompt if nothing generated self.train() # Set model back to training mode return full_sequence_ids