import torch
import torch.nn as nn
import torch.nn.functional as F

from model_components import ViT, MultiModalProjector
from decoder_language_model import DecoderLanguageModel
from constants import *  # HIDDEN_DIM, IMAGE_SIZE, PATCH_SIZE, NUM_HEADS, NUM_LAYERS, DROPOUT, CONTEXT_LENGTH, SHARED_EMBED_DIM, LAMBDA_CONTRASTIVE, LAMBDA_REGRESSION, MAX_POINTS, DEVICE
from utils import tokenizer, vocab_size

class VisionLanguageModel(nn.Module):
"""
Vision Language Model integrating ViT, Projector, Contrastive Loss, Decoder (Class + Reg).
Handles multiple points via padded regression targets and masked loss.
"""
def __init__(self,
n_embd=HIDDEN_DIM,
vocab_size=vocab_size,
img_size=IMAGE_SIZE,
patch_size=PATCH_SIZE,
num_heads=NUM_HEADS,
num_blks_vit=NUM_LAYERS,
num_blks_dec=NUM_LAYERS,
emb_dropout=DROPOUT,
blk_dropout=DROPOUT,
max_context=CONTEXT_LENGTH,
shared_embed_dim=SHARED_EMBED_DIM,
lambda_contrastive=LAMBDA_CONTRASTIVE,
                 lambda_regression=LAMBDA_REGRESSION,
                 max_points=MAX_POINTS
):
super().__init__()
# --- Vision Backbone ---
self.vision_encoder = ViT(
img_size=img_size,
patch_size=patch_size,
num_hiddens=n_embd, # Assuming ViT output dim matches decoder embed dim
num_heads=num_heads,
num_blks=num_blks_vit,
emb_dropout=emb_dropout,
blk_dropout=blk_dropout
)
# --- Multimodal Components ---
self.multimodal_projector = MultiModalProjector(
image_embed_dim=n_embd, # Input from ViT
text_embed_dim=n_embd, # Output matches decoder dim
dropout=emb_dropout
)
self.image_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
self.text_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
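        # Learnable temperature for the contrastive logits, initialised to 1/0.07 as in CLIP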
self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))
# --- Text Decoder ---
        # DecoderLanguageModel's regression head outputs MAX_POINTS * 2 values
self.decoder = DecoderLanguageModel(
n_embd=n_embd,
vocab_size=vocab_size,
num_heads=num_heads,
n_layer=num_blks_dec,
max_context=max_context,
dropout=blk_dropout # Use block dropout for decoder consistency
)
# --- Store Configuration ---
self.n_embd = n_embd
self.vocab_size = vocab_size
        self.num_patches = (img_size // patch_size) ** 2 + 1  # patch tokens plus the CLS token
self.lambda_contrastive = lambda_contrastive
self.lambda_regression = lambda_regression
        self.max_points = max_points  # Maximum number of coordinate points predicted per image
self._resize_embeddings_if_needed(self.vocab_size)
print("VisionLanguageModel initialized.")
    def _resize_embeddings_if_needed(self, current_vocab_size):
        """ Resizes the decoder token embeddings if the vocab size changed after init. """
        decoder_embedding_size = self.decoder.token_embedding_table.num_embeddings
        if decoder_embedding_size != current_vocab_size:
            print(f"Resizing VLM decoder token embeddings from {decoder_embedding_size} to {current_vocab_size}")
            # The old embedding/lm_head layers are dropped from the module, so they
            # receive no gradients; simply replace them with freshly initialised layers.
            new_embedding = nn.Embedding(current_vocab_size, self.n_embd).to(DEVICE)
            new_lm_head = nn.Linear(self.n_embd, current_vocab_size, bias=False).to(DEVICE)
            self.decoder.token_embedding_table = new_embedding
            self.decoder.lm_head = new_lm_head
            # Re-tie weights so the input embedding and output projection share parameters
            self.decoder.token_embedding_table.weight = self.decoder.lm_head.weight
            print("VLM decoder embeddings resized and weights retied.")
def _calculate_contrastive_loss(self, image_features, text_features):
""" Calculates the symmetric InfoNCE loss. """
# Assumes features are already projected to shared_embed_dim
# image_features: (B, E)
# text_features: (B, E)
# Normalize features
image_features = F.normalize(image_features, dim=-1)
text_features = F.normalize(text_features, dim=-1)
# Cosine similarity as logits (using learnable temperature)
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# Calculate symmetric cross-entropy loss
labels = torch.arange(len(logits_per_image), device=logits_per_image.device)
loss_i = F.cross_entropy(logits_per_image, labels)
loss_t = F.cross_entropy(logits_per_text, labels)
contrastive_loss = (loss_i + loss_t) / 2.0
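        # With batch size B, the logits form a (B, B) similarity matrix and the
        # arange labels mark the matched (image_i, text_i) pairs on the diagonal,
        # so the symmetric cross-entropy pulls matched pairs together and pushes
        # apart the B - 1 mismatched pairings in each row and column.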
# Handle potential NaNs
if torch.isnan(contrastive_loss):
print("Warning: Contrastive loss is NaN.")
            return None  # Caller treats a missing term as "skip this loss"
return contrastive_loss
def forward(self,
img_array,
prompt_ids,
prompt_attention_mask,
target_ids,
target_attention_mask,
generative_targets=None,
                continuous_coords=None,  # Padded coordinates, shape (B, MAX_POINTS, 2)
coords_mask=None # Mask for valid points (B, MAX_POINTS)
):
"""
Main forward pass for training. Calculates combined loss with masked regression loss.
"""
# --- 1. Encode Image ---
image_embeds_raw = self.vision_encoder(img_array) # (B, N_img, C)
B, N_img, C_img = image_embeds_raw.shape
img_cls_token = image_embeds_raw[:, 0]
# --- 2. Contrastive Loss Path ---
contrastive_loss = None
        image_features_contrast = self.image_contrastive_head(img_cls_token)
        # Prompt token embeddings are pooled without gradient tracking; they are
        # not trained through the contrastive path.
        with torch.no_grad():
prompt_text_embeds_contrast = self.decoder.token_embedding_table(prompt_ids)
prompt_lengths = prompt_attention_mask.sum(dim=1)
last_token_indices = (prompt_lengths - 1).clamp(min=0)
            gather_indices = last_token_indices.view(B, 1, 1).expand(-1, -1, self.n_embd)  # decoder embed dim (equals C_img here)
prompt_last_token_embed = prompt_text_embeds_contrast.gather(1, gather_indices).squeeze(1)
text_features_contrast = self.text_contrastive_head(prompt_last_token_embed)
contrastive_loss = self._calculate_contrastive_loss(image_features_contrast, text_features_contrast)
# --- 3. Generative / Regression Path ---
image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
prompt_embeds_decoder = self.decoder.token_embedding_table(prompt_ids)
target_embeds_decoder = self.decoder.token_embedding_table(target_ids)
B, T_prompt, C = prompt_embeds_decoder.shape
B, T_target, _ = target_embeds_decoder.shape
# Prepare combined input sequence and attention mask for the decoder
combined_embeds = torch.cat([
image_embeds_decoder, prompt_embeds_decoder, target_embeds_decoder
], dim=1)
combined_attention_mask = torch.cat([
torch.ones(B, N_img, dtype=torch.long, device=DEVICE),
prompt_attention_mask,
target_attention_mask
], dim=1)
T_combined = combined_embeds.shape[1]
# Prepare combined targets for the classification loss
combined_class_targets = None
if generative_targets is not None:
combined_class_targets = torch.cat([
torch.full((B, N_img + T_prompt), -100, dtype=torch.long, device=DEVICE),
generative_targets
], dim=1)
# --- Pass through Decoder ---
logits, class_loss, x_norm = self.decoder(
combined_embeds,
attention_mask=combined_attention_mask,
targets=combined_class_targets
)
# x_norm shape: (B, T_combined, C)
        # --- Calculate Regression Output & Loss (masked over multiple points) ---
        regression_loss = None
        regression_output = None
        num_valid_coords = 0  # Defined up front so the logging checks below cannot raise a NameError
if continuous_coords is not None and coords_mask is not None and x_norm is not None:
# Strategy: Use hidden state corresponding to token *before* <result_end> (or <eos>)
# This single state predicts coordinates for *all* MAX_POINTS.
target_lengths = target_attention_mask.sum(dim=1) # Length of actual target tokens (B,)
# Index relative to start of *target sequence* is length - 2 (token before <eos>/<result_end>)
relative_target_idx = (target_lengths - 2).clamp(min=0)
# Absolute index in the combined sequence's hidden states (x_norm)
absolute_idx = N_img + T_prompt + relative_target_idx
absolute_idx = absolute_idx.clamp(max=T_combined - 1) # Clamp index
# Gather the hidden states at these specific indices
gather_indices_reg = absolute_idx.view(B, 1, 1).expand(-1, -1, C)
try:
hidden_state_for_regression = x_norm.gather(1, gather_indices_reg).squeeze(1) # Shape: (B, C)
# Pass through the regression head
regression_output_flat = self.decoder.regression_head(hidden_state_for_regression) # Shape: (B, MAX_POINTS * 2)
# Reshape to (B, MAX_POINTS, 2)
regression_output = regression_output_flat.view(B, self.max_points, 2)
# --- Calculate MASKED regression loss (L1 - Mean Absolute Error) ---
loss_per_coord = F.l1_loss(regression_output, continuous_coords, reduction='none') # (B, MAX_POINTS, 2)
# Apply mask (mask is (B, MAX_POINTS), need to broadcast to (B, MAX_POINTS, 2))
masked_loss = loss_per_coord * coords_mask.unsqueeze(-1)
# Sum loss over valid points and coordinates, divide by number of valid coordinates
num_valid_coords = coords_mask.sum() * 2 # Total number of valid x,y values in batch
if num_valid_coords > 0:
regression_loss = masked_loss.sum() / num_valid_coords
else:
regression_loss = torch.tensor(0.0, device=DEVICE) # No valid points in batch
                if not torch.isfinite(regression_loss):
                    print("Warning: Regression loss is non-finite.")
                    regression_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)  # Zero out a NaN/Inf loss
except Exception as e:
print(f"Error during regression calculation: {e}")
print(f"x_norm shape: {x_norm.shape}, absolute_idx: {absolute_idx}")
regression_loss = None
regression_output = None # Ensure output is None if error occurs
# --- 4. Combine All Losses ---
        total_loss = torch.tensor(0.0, device=DEVICE)  # Becomes grad-carrying once a valid loss term is added
# Add valid losses with their respective weights
loss_log = {}
if class_loss is not None and torch.isfinite(class_loss):
total_loss += class_loss # Weight = 1.0 assumed
loss_log["class_loss"] = class_loss.item()
else:
# If class_loss is None or NaN/Inf, don't add it, log NaN
loss_log["class_loss"] = float('nan')
print(f"Warning: Invalid class_loss ({class_loss})")
if contrastive_loss is not None and torch.isfinite(contrastive_loss):
total_loss += self.lambda_contrastive * contrastive_loss
loss_log["contrastive_loss"] = contrastive_loss.item()
else:
loss_log["contrastive_loss"] = float('nan')
print(f"Warning: Invalid contrastive_loss ({contrastive_loss})")
if regression_loss is not None and torch.isfinite(regression_loss):
total_loss += self.lambda_regression * regression_loss
loss_log["regression_loss"] = regression_loss.item()
else:
loss_log["regression_loss"] = float('nan')
# Don't print warning if it was intentionally set to 0 due to no valid points
if regression_loss is not None and not (regression_loss == 0.0 and num_valid_coords == 0):
print(f"Warning: Invalid regression_loss ({regression_loss})")
        # Handle the case where the total loss becomes NaN/Inf
        if not torch.isfinite(total_loss):
            print(f"Warning: Total loss became non-finite ({total_loss}). Replacing with a zero tensor.")
            total_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)
            # Skipping the optimizer step for such batches is handled in the training loop.
# Use the loss_log dictionary for clearer logging later
class_loss_val = loss_log["class_loss"]
contrastive_loss_val = loss_log["contrastive_loss"]
regression_loss_val = loss_log["regression_loss"]
        # Return all relevant outputs (plain scalar tensors for loss logging)
        return (logits, regression_output, total_loss,
                torch.tensor(class_loss_val), torch.tensor(contrastive_loss_val),
                torch.tensor(regression_loss_val))
# --- Generation Method ---
@torch.no_grad() # Ensure no gradients are computed during generation
def generate(self, img_array, idx_prompt, max_new_tokens,
                 temperature=1.0, top_k=None,  # temperature <= 0.0 selects greedy decoding
force_result_start=True # Option to manually add <result_start>
):
"""
Generates token sequences autoregressively based on image and prompt.
Uses the classification head (lm_head).
Args:
img_array (torch.Tensor): Input image tensor (B, 3, H, W). B should be 1 for this impl.
idx_prompt (torch.Tensor): Input prompt token IDs (B, T_prompt).
max_new_tokens (int): Maximum number of new tokens to generate.
            temperature (float): Softmax temperature. 1.0 leaves the distribution
                unchanged; lower values sharpen it; <= 0.0 selects greedy (argmax) decoding.
top_k (int | None): If set, restricts sampling to top K most likely tokens.
force_result_start (bool): If True, manually appends <result_start> embedding
after the prompt before starting generation loop.
Returns:
torch.Tensor: Generated sequence IDs, including the prompt (B, T_prompt + T_generated).
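        Example (illustrative sketch; `model`, `img`, and `prompt_ids` are assumed
        to be prepared by the caller):
            out_ids = model.generate(img, prompt_ids, max_new_tokens=32,
                                     temperature=0.0)  # greedy decoding
            print(tokenizer.decode(out_ids[0].tolist()))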
"""
        was_training = self.training  # Remember the caller's mode so it can be restored
        self.eval()
B = img_array.shape[0]
if B > 1:
# This simplified generation loop assumes B=1 for clarity
# Batch generation requires careful handling of EOS and padding within the loop
print("Warning: Generation function currently assumes batch size B=1.")
# Process only the first item for now
img_array = img_array[:1]
idx_prompt = idx_prompt[:1]
B = 1
# --- 1. Prepare Initial Embeddings ---
image_embeds_raw = self.vision_encoder(img_array)
image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
prompt_embeds_decoder = self.decoder.token_embedding_table(idx_prompt)
# Initial sequence for the decoder loop
current_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder], dim=1)
generated_ids_list = [] # Store newly generated IDs as a list
# Manually add <result_start> if forced
if force_result_start:
try:
result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
result_start_embed = self.decoder.token_embedding_table(
torch.tensor([[result_start_token_id]], device=DEVICE)
)
current_embeds = torch.cat([current_embeds, result_start_embed], dim=1)
# Also store this token ID if we added it
generated_ids_list.append(torch.tensor([[result_start_token_id]], device=DEVICE))
except Exception as e:
print(f"Warning: Could not encode or add <result_start>: {e}")
# --- 2. Autoregressive Loop ---
for _ in range(max_new_tokens):
T_current = current_embeds.shape[1]
            # Context truncation: keep only the most recent max_context embeddings.
            # Note that for very long generations this can drop image tokens.
            if T_current > self.decoder.max_context:
                current_embeds = current_embeds[:, -self.decoder.max_context:, :]
                T_current = self.decoder.max_context
# Prepare inputs for decoder blocks
pos = torch.arange(0, T_current, dtype=torch.long, device=DEVICE)
pos = pos.clamp(max=self.decoder.max_context - 1)
pos_emb = self.decoder.position_embedding_table(pos).unsqueeze(0)
x = current_embeds + pos_emb
attention_mask = torch.ones(B, T_current, device=DEVICE, dtype=torch.long) # No padding needed
# Pass through decoder blocks
for block in self.decoder.blocks:
x = block(x, attention_mask=attention_mask)
# Get logits for the last token
x = self.decoder.ln_f(x[:, -1:, :]) # (B, 1, C)
logits = self.decoder.lm_head(x) # (B, 1, V)
            logits = logits.squeeze(1)  # (B, V)
            # --- Sampling / Decoding ---
            if temperature <= 0.0 or top_k == 1:
                # Greedy decoding on the raw logits. (Dividing by a temperature of
                # 0.0 would produce inf/NaN logits, so the division is skipped.)
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)
            else:
                logits = logits / temperature  # <1.0 sharpens, >1.0 flattens
                # Optional: Top-K filtering
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')  # Mask everything below the K-th logit
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
# Append the generated token ID
generated_ids_list.append(idx_next)
# Stop if EOS is generated
            if getattr(tokenizer, 'eos_token_id', None) is not None and idx_next.item() == tokenizer.eos_token_id:
break
# Prepare for next iteration
next_token_embed = self.decoder.token_embedding_table(idx_next)
current_embeds = torch.cat([current_embeds, next_token_embed], dim=1)
# --- 3. Combine results ---
if generated_ids_list:
generated_ids_tensor = torch.cat(generated_ids_list, dim=1) # (B, T_generated)
full_sequence_ids = torch.cat([idx_prompt, generated_ids_tensor], dim=1)
else:
full_sequence_ids = idx_prompt # Return only prompt if nothing generated
        if was_training:
            self.train()  # Restore the caller's training mode
        return full_sequence_ids
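

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch. Assumptions: `constants` defines IMAGE_SIZE,
# MAX_POINTS and DEVICE as used above, and the sequence lengths below fit
# within CONTEXT_LENGTH. Targets and coordinates are random and purely
# illustrative, not a training recipe.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = VisionLanguageModel().to(DEVICE)
    B = 2
    img = torch.randn(B, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    prompt_ids = torch.randint(0, vocab_size, (B, 8), device=DEVICE)
    prompt_mask = torch.ones(B, 8, dtype=torch.long, device=DEVICE)
    target_ids = torch.randint(0, vocab_size, (B, 12), device=DEVICE)
    target_mask = torch.ones(B, 12, dtype=torch.long, device=DEVICE)
    gen_targets = target_ids.clone()                      # next-token targets (illustrative)
    coords = torch.rand(B, MAX_POINTS, 2, device=DEVICE)  # padded (x, y) pairs
    coords_mask = torch.zeros(B, MAX_POINTS, device=DEVICE)
    coords_mask[:, 0] = 1.0                               # pretend one valid point per image

    logits, reg_out, total_loss, cls_l, con_l, reg_l = model(
        img, prompt_ids, prompt_mask, target_ids, target_mask,
        generative_targets=gen_targets,
        continuous_coords=coords,
        coords_mask=coords_mask,
    )
    print(f"total_loss={total_loss.item():.4f}")
    if reg_out is not None:
        print(f"regression output shape: {tuple(reg_out.shape)}")  # (B, MAX_POINTS, 2)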