# pointing_models_from_scratch / vision_language_model.py
from model_components import ViT, MultiModalProjector
from decoder_language_model import DecoderLanguageModel
from constants import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import tokenizer, vocab_size
class VisionLanguageModel(nn.Module):
"""
Vision Language Model integrating ViT, Projector, Contrastive Loss, Decoder (Class + Reg).
Handles multiple points via padded regression targets and masked loss.
"""
def __init__(self,
n_embd=HIDDEN_DIM,
vocab_size=vocab_size,
img_size=IMAGE_SIZE,
patch_size=PATCH_SIZE,
num_heads=NUM_HEADS,
num_blks_vit=NUM_LAYERS,
num_blks_dec=NUM_LAYERS,
emb_dropout=DROPOUT,
blk_dropout=DROPOUT,
max_context=CONTEXT_LENGTH,
shared_embed_dim=SHARED_EMBED_DIM,
lambda_contrastive=LAMBDA_CONTRASTIVE,
                 lambda_regression=LAMBDA_REGRESSION,
                 max_points=MAX_POINTS  # Maximum number of padded points per sample
):
super().__init__()
# --- Vision Backbone ---
self.vision_encoder = ViT(
img_size=img_size,
patch_size=patch_size,
num_hiddens=n_embd, # Assuming ViT output dim matches decoder embed dim
num_heads=num_heads,
num_blks=num_blks_vit,
emb_dropout=emb_dropout,
blk_dropout=blk_dropout
)
# --- Multimodal Components ---
self.multimodal_projector = MultiModalProjector(
image_embed_dim=n_embd, # Input from ViT
text_embed_dim=n_embd, # Output matches decoder dim
dropout=emb_dropout
)
self.image_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
self.text_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
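        # Learnable softmax temperature for the contrastive loss, stored in log
        # space (so its exp() stays positive) and initialized to 1/0.07 as in CLIP.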
self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))
# --- Text Decoder ---
        # DecoderLanguageModel includes a regression head that outputs MAX_POINTS * 2 values
self.decoder = DecoderLanguageModel(
n_embd=n_embd,
vocab_size=vocab_size,
num_heads=num_heads,
n_layer=num_blks_dec,
max_context=max_context,
dropout=blk_dropout # Use block dropout for decoder consistency
)
# --- Store Configuration ---
self.n_embd = n_embd
self.vocab_size = vocab_size
        self.num_patches = (img_size // patch_size)**2 + 1  # patch tokens + 1 CLS token
self.lambda_contrastive = lambda_contrastive
self.lambda_regression = lambda_regression
self.max_points = max_points # Store max points
self._resize_embeddings_if_needed(self.vocab_size)
print("VisionLanguageModel initialized.")
def _resize_embeddings_if_needed(self, current_vocab_size):
""" Resizes decoder token embeddings if vocab size changed after init. """
decoder_embedding_size = self.decoder.token_embedding_table.num_embeddings
if decoder_embedding_size != current_vocab_size:
print(f"Resizing VLM decoder token embeddings from {decoder_embedding_size} to {current_vocab_size}")
# Freeze original weights before replacing layers
self.decoder.token_embedding_table.weight.requires_grad = False
self.decoder.lm_head.weight.requires_grad = False
# Create new layers
new_embedding = nn.Embedding(current_vocab_size, self.n_embd).to(DEVICE)
new_lm_head = nn.Linear(self.n_embd, current_vocab_size, bias=False).to(DEVICE)
# Assign new layers
self.decoder.token_embedding_table = new_embedding
self.decoder.lm_head = new_lm_head
# Re-tie weights
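            # nn.Embedding(V, C).weight and nn.Linear(C, V, bias=False).weight are
            # both shaped (V, C), so one tensor can back both layers (weight tying).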
self.decoder.token_embedding_table.weight = self.decoder.lm_head.weight
print("VLM decoder embeddings resized and weights retied.")
def _calculate_contrastive_loss(self, image_features, text_features):
""" Calculates the symmetric InfoNCE loss. """
# Assumes features are already projected to shared_embed_dim
# image_features: (B, E)
# text_features: (B, E)
# Normalize features
image_features = F.normalize(image_features, dim=-1)
text_features = F.normalize(text_features, dim=-1)
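        # For B aligned (image, text) pairs the similarity matrix below is (B, B);
        # each row's positive pair sits on the diagonal, so both directional
        # cross-entropy losses use arange(B) as their class targets.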
# Cosine similarity as logits (using learnable temperature)
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# Calculate symmetric cross-entropy loss
labels = torch.arange(len(logits_per_image), device=logits_per_image.device)
loss_i = F.cross_entropy(logits_per_image, labels)
loss_t = F.cross_entropy(logits_per_text, labels)
contrastive_loss = (loss_i + loss_t) / 2.0
# Handle potential NaNs
if torch.isnan(contrastive_loss):
print("Warning: Contrastive loss is NaN.")
            return None  # Callers treat a None loss as "skip this term"
return contrastive_loss
def forward(self,
img_array,
prompt_ids,
prompt_attention_mask,
target_ids,
target_attention_mask,
generative_targets=None,
                continuous_coords=None,  # Padded coordinates, shape (B, MAX_POINTS, 2)
                coords_mask=None  # Mask of valid points, shape (B, MAX_POINTS)
):
"""
Main forward pass for training. Calculates combined loss with masked regression loss.
"""
# --- 1. Encode Image ---
image_embeds_raw = self.vision_encoder(img_array) # (B, N_img, C)
B, N_img, C_img = image_embeds_raw.shape
img_cls_token = image_embeds_raw[:, 0]
# --- 2. Contrastive Loss Path ---
        contrastive_loss = None
        # Project the image CLS token and a pooled prompt embedding into the shared
        # space, then compute the symmetric InfoNCE loss between them.
image_features_contrast = self.image_contrastive_head(img_cls_token)
        with torch.no_grad():  # Prompt embeddings are not trained through the contrastive path
prompt_text_embeds_contrast = self.decoder.token_embedding_table(prompt_ids)
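            # Pool the variable-length prompt to one vector per sample by taking
            # the embedding at the last non-padded position (index = length - 1).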
prompt_lengths = prompt_attention_mask.sum(dim=1)
last_token_indices = (prompt_lengths - 1).clamp(min=0)
gather_indices = last_token_indices.view(B, 1, 1).expand(-1, -1, C_img)
prompt_last_token_embed = prompt_text_embeds_contrast.gather(1, gather_indices).squeeze(1)
text_features_contrast = self.text_contrastive_head(prompt_last_token_embed)
contrastive_loss = self._calculate_contrastive_loss(image_features_contrast, text_features_contrast)
# --- 3. Generative / Regression Path ---
image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
prompt_embeds_decoder = self.decoder.token_embedding_table(prompt_ids)
target_embeds_decoder = self.decoder.token_embedding_table(target_ids)
B, T_prompt, C = prompt_embeds_decoder.shape
B, T_target, _ = target_embeds_decoder.shape
# Prepare combined input sequence and attention mask for the decoder
combined_embeds = torch.cat([
image_embeds_decoder, prompt_embeds_decoder, target_embeds_decoder
], dim=1)
combined_attention_mask = torch.cat([
torch.ones(B, N_img, dtype=torch.long, device=DEVICE),
prompt_attention_mask,
target_attention_mask
], dim=1)
T_combined = combined_embeds.shape[1]
# Prepare combined targets for the classification loss
combined_class_targets = None
if generative_targets is not None:
combined_class_targets = torch.cat([
torch.full((B, N_img + T_prompt), -100, dtype=torch.long, device=DEVICE),
generative_targets
], dim=1)
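            # (-100 is the default ignore_index of F.cross_entropy, so image and
            # prompt positions are excluded from the classification loss.)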
# --- Pass through Decoder ---
logits, class_loss, x_norm = self.decoder(
combined_embeds,
attention_mask=combined_attention_mask,
targets=combined_class_targets
)
# x_norm shape: (B, T_combined, C)
        # --- Calculate Regression Output & Loss (masked, multi-point) ---
regression_loss = None
regression_output = None
if continuous_coords is not None and coords_mask is not None and x_norm is not None:
# Strategy: Use hidden state corresponding to token *before* <result_end> (or <eos>)
# This single state predicts coordinates for *all* MAX_POINTS.
target_lengths = target_attention_mask.sum(dim=1) # Length of actual target tokens (B,)
# Index relative to start of *target sequence* is length - 2 (token before <eos>/<result_end>)
relative_target_idx = (target_lengths - 2).clamp(min=0)
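            # e.g. a target of length 5 ending in <eos>: relative index 3 picks the
            # hidden state of the token just before <eos>.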
# Absolute index in the combined sequence's hidden states (x_norm)
absolute_idx = N_img + T_prompt + relative_target_idx
absolute_idx = absolute_idx.clamp(max=T_combined - 1) # Clamp index
# Gather the hidden states at these specific indices
gather_indices_reg = absolute_idx.view(B, 1, 1).expand(-1, -1, C)
try:
hidden_state_for_regression = x_norm.gather(1, gather_indices_reg).squeeze(1) # Shape: (B, C)
# Pass through the regression head
regression_output_flat = self.decoder.regression_head(hidden_state_for_regression) # Shape: (B, MAX_POINTS * 2)
# Reshape to (B, MAX_POINTS, 2)
regression_output = regression_output_flat.view(B, self.max_points, 2)
# --- Calculate MASKED regression loss (L1 - Mean Absolute Error) ---
loss_per_coord = F.l1_loss(regression_output, continuous_coords, reduction='none') # (B, MAX_POINTS, 2)
# Apply mask (mask is (B, MAX_POINTS), need to broadcast to (B, MAX_POINTS, 2))
masked_loss = loss_per_coord * coords_mask.unsqueeze(-1)
# Sum loss over valid points and coordinates, divide by number of valid coordinates
num_valid_coords = coords_mask.sum() * 2 # Total number of valid x,y values in batch
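                # e.g. 2 valid points out of MAX_POINTS -> num_valid_coords = 4,
                # so the mean is taken over the 4 real x/y values only.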
if num_valid_coords > 0:
regression_loss = masked_loss.sum() / num_valid_coords
else:
regression_loss = torch.tensor(0.0, device=DEVICE) # No valid points in batch
if torch.isnan(regression_loss):
print("Warning: Regression loss is NaN.")
regression_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True) # Set to zero tensor if NaN
except Exception as e:
print(f"Error during regression calculation: {e}")
print(f"x_norm shape: {x_norm.shape}, absolute_idx: {absolute_idx}")
regression_loss = None
regression_output = None # Ensure output is None if error occurs
# --- 4. Combine All Losses ---
        # Accumulate weighted loss terms; gradients flow through whichever component
        # losses are added below (the initial zero tensor itself carries no grad).
        total_loss = torch.tensor(0.0, device=DEVICE)
        # Add valid losses with their respective weights
loss_log = {}
if class_loss is not None and torch.isfinite(class_loss):
total_loss += class_loss # Weight = 1.0 assumed
loss_log["class_loss"] = class_loss.item()
else:
# If class_loss is None or NaN/Inf, don't add it, log NaN
loss_log["class_loss"] = float('nan')
print(f"Warning: Invalid class_loss ({class_loss})")
if contrastive_loss is not None and torch.isfinite(contrastive_loss):
total_loss += self.lambda_contrastive * contrastive_loss
loss_log["contrastive_loss"] = contrastive_loss.item()
else:
loss_log["contrastive_loss"] = float('nan')
print(f"Warning: Invalid contrastive_loss ({contrastive_loss})")
if regression_loss is not None and torch.isfinite(regression_loss):
total_loss += self.lambda_regression * regression_loss
loss_log["regression_loss"] = regression_loss.item()
else:
loss_log["regression_loss"] = float('nan')
            # Only warn when the loss is genuinely invalid, not when it was
            # intentionally zeroed because the batch contained no valid points.
            if regression_loss is not None and not (regression_loss == 0.0 and num_valid_coords == 0):
                print(f"Warning: Invalid regression_loss ({regression_loss})")
# Handle case where total loss becomes NaN/Inf
if not torch.isfinite(total_loss):
print(f"Warning: Total loss became non-finite ({total_loss}). Setting to zero and clearing gradients.")
total_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
            # Skipping the optimizer step for such batches is handled in the training loop.
# Use the loss_log dictionary for clearer logging later
class_loss_val = loss_log["class_loss"]
contrastive_loss_val = loss_log["contrastive_loss"]
regression_loss_val = loss_log["regression_loss"]
# Return all relevant outputs (use scalar values for loss logging)
return logits, regression_output, total_loss, \
torch.tensor(class_loss_val), torch.tensor(contrastive_loss_val), torch.tensor(regression_loss_val)
# --- Generation Method ---
@torch.no_grad() # Ensure no gradients are computed during generation
def generate(self, img_array, idx_prompt, max_new_tokens,
                 temperature=1.0, top_k=None,  # temperature=0.0 or top_k=1 yields greedy decoding
force_result_start=True # Option to manually add <result_start>
):
"""
Generates token sequences autoregressively based on image and prompt.
Uses the classification head (lm_head).
Args:
img_array (torch.Tensor): Input image tensor (B, 3, H, W). B should be 1 for this impl.
idx_prompt (torch.Tensor): Input prompt token IDs (B, T_prompt).
max_new_tokens (int): Maximum number of new tokens to generate.
            temperature (float): Softmax temperature. 1.0 leaves the distribution
                unchanged, lower values sharpen it, and 0.0 selects greedy (argmax) decoding.
top_k (int | None): If set, restricts sampling to top K most likely tokens.
force_result_start (bool): If True, manually appends <result_start> embedding
after the prompt before starting generation loop.
Returns:
torch.Tensor: Generated sequence IDs, including the prompt (B, T_prompt + T_generated).
"""
self.eval() # Ensure model is in eval mode
B = img_array.shape[0]
if B > 1:
# This simplified generation loop assumes B=1 for clarity
# Batch generation requires careful handling of EOS and padding within the loop
print("Warning: Generation function currently assumes batch size B=1.")
# Process only the first item for now
img_array = img_array[:1]
idx_prompt = idx_prompt[:1]
B = 1
# --- 1. Prepare Initial Embeddings ---
image_embeds_raw = self.vision_encoder(img_array)
image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
prompt_embeds_decoder = self.decoder.token_embedding_table(idx_prompt)
# Initial sequence for the decoder loop
current_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder], dim=1)
generated_ids_list = [] # Store newly generated IDs as a list
# Manually add <result_start> if forced
if force_result_start:
try:
result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
result_start_embed = self.decoder.token_embedding_table(
torch.tensor([[result_start_token_id]], device=DEVICE)
)
current_embeds = torch.cat([current_embeds, result_start_embed], dim=1)
# Also store this token ID if we added it
generated_ids_list.append(torch.tensor([[result_start_token_id]], device=DEVICE))
except Exception as e:
print(f"Warning: Could not encode or add <result_start>: {e}")
# --- 2. Autoregressive Loop ---
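        # Note: each step re-runs attention over the full embedding sequence (no
        # KV cache), which is simple but costs O(T^2) per generated token.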
for _ in range(max_new_tokens):
T_current = current_embeds.shape[1]
# Context truncation
if T_current > self.decoder.max_context:
current_embeds = current_embeds[:, -self.decoder.max_context:, :]
T_current = self.decoder.max_context
# Prepare inputs for decoder blocks
pos = torch.arange(0, T_current, dtype=torch.long, device=DEVICE)
pos = pos.clamp(max=self.decoder.max_context - 1)
pos_emb = self.decoder.position_embedding_table(pos).unsqueeze(0)
x = current_embeds + pos_emb
attention_mask = torch.ones(B, T_current, device=DEVICE, dtype=torch.long) # No padding needed
# Pass through decoder blocks
for block in self.decoder.blocks:
x = block(x, attention_mask=attention_mask)
# Get logits for the last token
x = self.decoder.ln_f(x[:, -1:, :]) # (B, 1, C)
logits = self.decoder.lm_head(x) # (B, 1, V)
            logits = logits.squeeze(1)  # (B, V)
            # --- Sampling / Decoding ---
            if temperature == 0.0 or top_k == 1:
                # Greedy decoding: argmax over the raw logits. Dividing by a zero
                # temperature would produce inf/NaN logits, so skip scaling here.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)
            else:
                logits = logits / temperature  # Apply temperature (B, V)
                # Optional: Top-K filtering
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')  # Mask tokens below the k-th largest
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
# Append the generated token ID
generated_ids_list.append(idx_next)
# Stop if EOS is generated
if hasattr(tokenizer, 'eos_token_id') and idx_next.item() == tokenizer.eos_token_id:
break
# Prepare for next iteration
next_token_embed = self.decoder.token_embedding_table(idx_next)
current_embeds = torch.cat([current_embeds, next_token_embed], dim=1)
# --- 3. Combine results ---
if generated_ids_list:
generated_ids_tensor = torch.cat(generated_ids_list, dim=1) # (B, T_generated)
full_sequence_ids = torch.cat([idx_prompt, generated_ids_tensor], dim=1)
else:
full_sequence_ids = idx_prompt # Return only prompt if nothing generated
self.train() # Set model back to training mode
return full_sequence_ids
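
# --- Minimal usage sketch (illustrative only) ---
# A hedged smoke test, not part of the training pipeline. It assumes that the
# constants module defines IMAGE_SIZE and DEVICE (both imported above via
# `from constants import *`) and that `tokenizer` is a HuggingFace-style
# tokenizer exposing the same encode()/decode() API used in generate().
# The random image and toy prompt below are placeholders, not real data.
if __name__ == "__main__":
    model = VisionLanguageModel().to(DEVICE)
    img = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)  # fake RGB image
    prompt_ids = torch.tensor(
        [tokenizer.encode("point to the object", add_special_tokens=False)],
        device=DEVICE,
    )
    # Greedy decoding: temperature=0.0 takes the argmax at every step.
    generated = model.generate(img, prompt_ids, max_new_tokens=16, temperature=0.0)
    print("Generated token ids:", generated.tolist())
    print("Decoded:", tokenizer.decode(generated[0].tolist()))  # assumes HF-style decode()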