from model_components import ViT, MultiModalProjector
from decoder_language_model import DecoderLanguageModel
from constants import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import tokenizer, vocab_size


class VisionLanguageModel(nn.Module):
    """
    Vision-language model integrating a ViT image encoder, a multimodal
    projector, a CLIP-style contrastive objective, and a decoder with both a
    classification (language-modeling) head and a regression head. Handles a
    variable number of points per sample via padded regression targets and a
    masked regression loss.
    """

    def __init__(self,
                 n_embd=HIDDEN_DIM,
                 vocab_size=vocab_size,
                 img_size=IMAGE_SIZE,
                 patch_size=PATCH_SIZE,
                 num_heads=NUM_HEADS,
                 num_blks_vit=NUM_LAYERS,
                 num_blks_dec=NUM_LAYERS,
                 emb_dropout=DROPOUT,
                 blk_dropout=DROPOUT,
                 max_context=CONTEXT_LENGTH,
                 shared_embed_dim=SHARED_EMBED_DIM,
                 lambda_contrastive=LAMBDA_CONTRASTIVE,
                 lambda_regression=LAMBDA_REGRESSION,
                 max_points=MAX_POINTS):
        super().__init__()

        # Vision encoder: outputs (B, num_patches + 1, n_embd); index 0 is the CLS token.
        self.vision_encoder = ViT(
            img_size=img_size,
            patch_size=patch_size,
            num_hiddens=n_embd,
            num_heads=num_heads,
            num_blks=num_blks_vit,
            emb_dropout=emb_dropout,
            blk_dropout=blk_dropout
        )

        # Projects vision features into the decoder's embedding space.
        self.multimodal_projector = MultiModalProjector(
            image_embed_dim=n_embd,
            text_embed_dim=n_embd,
            dropout=emb_dropout
        )

        # CLIP-style contrastive heads with a learnable temperature,
        # initialized to log(1 / 0.07) as in CLIP.
        self.image_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
        self.text_contrastive_head = nn.Linear(n_embd, shared_embed_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))

        # Autoregressive decoder with a classification head (lm_head) and a
        # regression head for point coordinates.
        self.decoder = DecoderLanguageModel(
            n_embd=n_embd,
            vocab_size=vocab_size,
            num_heads=num_heads,
            n_layer=num_blks_dec,
            max_context=max_context,
            dropout=blk_dropout
        )

        self.n_embd = n_embd
        self.vocab_size = vocab_size
        self.num_patches = (img_size // patch_size) ** 2 + 1  # patches plus CLS token
        self.lambda_contrastive = lambda_contrastive
        self.lambda_regression = lambda_regression
        self.max_points = max_points

        self._resize_embeddings_if_needed(self.vocab_size)
        print("VisionLanguageModel initialized.")

    def _resize_embeddings_if_needed(self, current_vocab_size):
        """Resize decoder token embeddings if the vocab size changed after init."""
        decoder_embedding_size = self.decoder.token_embedding_table.num_embeddings
        if decoder_embedding_size != current_vocab_size:
            print(f"Resizing VLM decoder token embeddings from {decoder_embedding_size} to {current_vocab_size}")

            old_weight = self.decoder.token_embedding_table.weight.data

            new_embedding = nn.Embedding(current_vocab_size, self.n_embd).to(DEVICE)
            new_lm_head = nn.Linear(self.n_embd, current_vocab_size, bias=False).to(DEVICE)

            # Copy the rows that already exist into the new (shared) matrix so
            # resizing does not discard trained embeddings.
            num_to_copy = min(decoder_embedding_size, current_vocab_size)
            new_lm_head.weight.data[:num_to_copy] = old_weight[:num_to_copy]

            self.decoder.token_embedding_table = new_embedding
            self.decoder.lm_head = new_lm_head

            # Retie so the embedding table and lm_head share one weight matrix.
            self.decoder.token_embedding_table.weight = self.decoder.lm_head.weight
            print("VLM decoder embeddings resized and weights retied.")

    def _calculate_contrastive_loss(self, image_features, text_features):
        """Calculate the symmetric InfoNCE (CLIP-style) contrastive loss."""
        # Cosine similarities: L2-normalize, then scaled dot products.
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # The i-th image matches the i-th text, so the targets are the diagonal.
        labels = torch.arange(len(logits_per_image), device=logits_per_image.device)
        loss_i = F.cross_entropy(logits_per_image, labels)
        loss_t = F.cross_entropy(logits_per_text, labels)
        contrastive_loss = (loss_i + loss_t) / 2.0

        if torch.isnan(contrastive_loss):
            print("Warning: Contrastive loss is NaN.")
            return None

        return contrastive_loss
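
    # Worked shape example for the loss above (hypothetical B=2, dim=4):
    #   image/text features -> (2, 4) after projection and normalization
    #   logits_per_image    -> (2, 2); entry [i, j] scores image i vs. text j
    #   labels = [0, 1]     -> cross-entropy pulls matched (diagonal) pairs up
    #                          and pushes in-batch negatives down.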

    def forward(self,
                img_array,
                prompt_ids,
                prompt_attention_mask,
                target_ids,
                target_attention_mask,
                generative_targets=None,
                continuous_coords=None,
                coords_mask=None):
        """
        Main forward pass for training. Combines the generative classification
        loss with the weighted contrastive and masked regression losses.
        """
        # --- Vision encoding ---
        image_embeds_raw = self.vision_encoder(img_array)  # (B, N_img, C_img)
        B, N_img, C_img = image_embeds_raw.shape
        img_cls_token = image_embeds_raw[:, 0]  # CLS token summarizes the image

        # --- Contrastive features ---
        contrastive_loss = None
        image_features_contrast = self.image_contrastive_head(img_cls_token)
        with torch.no_grad():
            # Detached lookup and pooling: the contrastive objective trains the
            # projection heads, not the decoder's token embeddings.
            prompt_text_embeds_contrast = self.decoder.token_embedding_table(prompt_ids)
            prompt_lengths = prompt_attention_mask.sum(dim=1)
            last_token_indices = (prompt_lengths - 1).clamp(min=0)
            # C_img == n_embd here, so this matches the text embedding dimension.
            gather_indices = last_token_indices.view(B, 1, 1).expand(-1, -1, C_img)
            prompt_last_token_embed = prompt_text_embeds_contrast.gather(1, gather_indices).squeeze(1)
        text_features_contrast = self.text_contrastive_head(prompt_last_token_embed)
        contrastive_loss = self._calculate_contrastive_loss(image_features_contrast, text_features_contrast)

        # --- Decoder inputs: [image | prompt | target] along the time axis ---
        image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
        prompt_embeds_decoder = self.decoder.token_embedding_table(prompt_ids)
        target_embeds_decoder = self.decoder.token_embedding_table(target_ids)
        B, T_prompt, C = prompt_embeds_decoder.shape
        _, T_target, _ = target_embeds_decoder.shape

        combined_embeds = torch.cat([
            image_embeds_decoder, prompt_embeds_decoder, target_embeds_decoder
        ], dim=1)
        combined_attention_mask = torch.cat([
            torch.ones(B, N_img, dtype=torch.long, device=DEVICE),  # image tokens are never padding
            prompt_attention_mask,
            target_attention_mask
        ], dim=1)
        T_combined = combined_embeds.shape[1]

        # --- Classification targets: ignore image and prompt positions ---
        combined_class_targets = None
        if generative_targets is not None:
            combined_class_targets = torch.cat([
                torch.full((B, N_img + T_prompt), -100, dtype=torch.long, device=DEVICE),
                generative_targets
            ], dim=1)
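
        # Combined target layout (illustrative):
        #   positions: [ image (N_img) | prompt (T_prompt) | target (T_target) ]
        #   targets:   [     -100      |       -100        |     token ids     ]
        # -100 is the default ignore_index of F.cross_entropy, so only the
        # target span contributes to the classification loss.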

        # --- Decoder pass ---
        logits, class_loss, x_norm = self.decoder(
            combined_embeds,
            attention_mask=combined_attention_mask,
            targets=combined_class_targets
        )
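
        # Decoder output contract assumed here: logits (B, T_combined, vocab_size),
        # a scalar class_loss (or None when targets is None), and x_norm, the
        # post-layernorm hidden states (B, T_combined, C) fed to the regression head.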

        # --- Masked regression loss over padded point coordinates ---
        regression_loss = None
        regression_output = None
        if continuous_coords is not None and coords_mask is not None and x_norm is not None:
            # Regress from the hidden state of the token just before the last
            # non-padded target token (e.g. just before <eos>).
            target_lengths = target_attention_mask.sum(dim=1)
            relative_target_idx = (target_lengths - 2).clamp(min=0)
            absolute_idx = N_img + T_prompt + relative_target_idx
            absolute_idx = absolute_idx.clamp(max=T_combined - 1)

            gather_indices_reg = absolute_idx.view(B, 1, 1).expand(-1, -1, C)
            try:
                hidden_state_for_regression = x_norm.gather(1, gather_indices_reg).squeeze(1)
                regression_output_flat = self.decoder.regression_head(hidden_state_for_regression)
                regression_output = regression_output_flat.view(B, self.max_points, 2)

                # Per-coordinate L1 loss, zeroed at padded points.
                loss_per_coord = F.l1_loss(regression_output, continuous_coords, reduction='none')
                masked_loss = loss_per_coord * coords_mask.unsqueeze(-1)

                # Average over the 2 * (number of valid points) real coordinates.
                num_valid_coords = coords_mask.sum() * 2
                if num_valid_coords > 0:
                    regression_loss = masked_loss.sum() / num_valid_coords
                else:
                    regression_loss = torch.tensor(0.0, device=DEVICE)

                if torch.isnan(regression_loss):
                    print("Warning: Regression loss is NaN.")
                    regression_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)

            except Exception as e:
                print(f"Error during regression calculation: {e}")
                print(f"x_norm shape: {x_norm.shape}, absolute_idx: {absolute_idx}")
                regression_loss = None
                regression_output = None
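
        # Masked-loss sketch: with coords_mask (B, max_points) and outputs
        # (B, max_points, 2), e.g. mask [[1, 1, 0]] keeps 2 points (4 numbers),
        # so the sum of masked L1 terms is divided by 4, not by 3 * 2.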

        # --- Combine losses, skipping any that are missing or non-finite ---
        total_loss = torch.tensor(0.0, device=DEVICE)
        loss_log = {}

        if class_loss is not None and torch.isfinite(class_loss):
            total_loss += class_loss
            loss_log["class_loss"] = class_loss.item()
        else:
            loss_log["class_loss"] = float('nan')
            print(f"Warning: Invalid class_loss ({class_loss})")

        if contrastive_loss is not None and torch.isfinite(contrastive_loss):
            total_loss += self.lambda_contrastive * contrastive_loss
            loss_log["contrastive_loss"] = contrastive_loss.item()
        else:
            loss_log["contrastive_loss"] = float('nan')
            print(f"Warning: Invalid contrastive_loss ({contrastive_loss})")

        if regression_loss is not None and torch.isfinite(regression_loss):
            total_loss += self.lambda_regression * regression_loss
            loss_log["regression_loss"] = regression_loss.item()
        else:
            loss_log["regression_loss"] = float('nan')
            if regression_loss is not None:
                print(f"Warning: Invalid regression_loss ({regression_loss})")

        if not torch.isfinite(total_loss):
            print(f"Warning: Total loss became non-finite ({total_loss}). Replacing with zero.")
            total_loss = torch.tensor(0.0, device=DEVICE, requires_grad=True)

        class_loss_val = loss_log["class_loss"]
        contrastive_loss_val = loss_log["contrastive_loss"]
        regression_loss_val = loss_log["regression_loss"]

        return (logits, regression_output, total_loss,
                torch.tensor(class_loss_val), torch.tensor(contrastive_loss_val),
                torch.tensor(regression_loss_val))

    @torch.no_grad()
    def generate(self, img_array, idx_prompt, max_new_tokens,
                 temperature=1.0, top_k=None,
                 force_result_start=True):
        """
        Generate token sequences autoregressively from an image and a prompt,
        using the classification head (lm_head).

        Args:
            img_array (torch.Tensor): Input image tensor (B, 3, H, W); this
                implementation assumes B = 1.
            idx_prompt (torch.Tensor): Prompt token IDs (B, T_prompt).
            max_new_tokens (int): Maximum number of new tokens to generate.
            temperature (float): Softmax temperature; 1.0 leaves the logits
                unchanged, lower values sharpen the distribution, and 0.0
                falls back to greedy decoding.
            top_k (int | None): If set, restricts sampling to the top K most
                likely tokens.
            force_result_start (bool): If True, appends the <result_start>
                embedding after the prompt before the generation loop starts.

        Returns:
            torch.Tensor: Generated sequence IDs, including the prompt
                (B, T_prompt + T_generated).
        """
        self.eval()
        B = img_array.shape[0]
        if B > 1:
            # Batched generation is not supported; keep only the first sample.
            print("Warning: Generation function currently assumes batch size B=1.")
            img_array = img_array[:1]
            idx_prompt = idx_prompt[:1]
            B = 1

        # Encode the image and build the initial [image | prompt] prefix.
        image_embeds_raw = self.vision_encoder(img_array)
        image_embeds_decoder = self.multimodal_projector(image_embeds_raw)
        prompt_embeds_decoder = self.decoder.token_embedding_table(idx_prompt)

        current_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder], dim=1)
        generated_ids_list = []

        if force_result_start:
            try:
                result_start_token_id = tokenizer.encode("<result_start>", add_special_tokens=False)[0]
                result_start_embed = self.decoder.token_embedding_table(
                    torch.tensor([[result_start_token_id]], device=DEVICE)
                )
                current_embeds = torch.cat([current_embeds, result_start_embed], dim=1)
                generated_ids_list.append(torch.tensor([[result_start_token_id]], device=DEVICE))
            except Exception as e:
                print(f"Warning: Could not encode or add <result_start>: {e}")

        for _ in range(max_new_tokens):
            T_current = current_embeds.shape[1]

            # Crop to the decoder's context window; once it is full, the
            # oldest positions (the image tokens first) are dropped.
            if T_current > self.decoder.max_context:
                current_embeds = current_embeds[:, -self.decoder.max_context:, :]
                T_current = self.decoder.max_context

            pos = torch.arange(0, T_current, dtype=torch.long, device=DEVICE)
            pos = pos.clamp(max=self.decoder.max_context - 1)
            pos_emb = self.decoder.position_embedding_table(pos).unsqueeze(0)
            x = current_embeds + pos_emb
            attention_mask = torch.ones(B, T_current, device=DEVICE, dtype=torch.long)

            for block in self.decoder.blocks:
                x = block(x, attention_mask=attention_mask)

            # Only the last position is needed to predict the next token.
            x = self.decoder.ln_f(x[:, -1:, :])
            logits = self.decoder.lm_head(x).squeeze(1)

            if temperature == 0.0 or top_k == 1:
                # Greedy decoding; this also avoids dividing by a zero temperature.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            generated_ids_list.append(idx_next)

            # Stop at the end-of-sequence token, if the tokenizer defines one.
            if hasattr(tokenizer, 'eos_token_id') and idx_next.item() == tokenizer.eos_token_id:
                break

            next_token_embed = self.decoder.token_embedding_table(idx_next)
            current_embeds = torch.cat([current_embeds, next_token_embed], dim=1)
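
            # Note: each step re-runs the decoder over the entire prefix; a
            # key/value cache would avoid this recomputation but is not
            # implemented here.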

        # Assemble the full sequence: prompt plus everything generated
        # (including the forced <result_start>, if it was added).
        if generated_ids_list:
            generated_ids_tensor = torch.cat(generated_ids_list, dim=1)
            full_sequence_ids = torch.cat([idx_prompt, generated_ids_tensor], dim=1)
        else:
            full_sequence_ids = idx_prompt

        self.train()  # restore training mode
        return full_sequence_ids
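

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model. It assumes the
    # constants module defines IMAGE_SIZE, MAX_POINTS, and DEVICE as imported
    # above, and that the dummy shapes below fit within the decoder's context
    # window.
    model = VisionLanguageModel().to(DEVICE)

    B = 2
    img = torch.randn(B, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    prompt_ids = torch.randint(0, model.vocab_size, (B, 8), device=DEVICE)
    prompt_mask = torch.ones_like(prompt_ids)
    target_ids = torch.randint(0, model.vocab_size, (B, 12), device=DEVICE)
    target_mask = torch.ones_like(target_ids)
    gen_targets = target_ids.clone()  # toy targets; real data would shift/mask these
    coords = torch.rand(B, MAX_POINTS, 2, device=DEVICE)
    coords_mask = torch.zeros(B, MAX_POINTS, device=DEVICE)
    coords_mask[:, 0] = 1.0  # pretend each sample has exactly one valid point

    logits, reg_out, total_loss, cls_l, con_l, reg_l = model(
        img, prompt_ids, prompt_mask, target_ids, target_mask,
        generative_targets=gen_targets,
        continuous_coords=coords,
        coords_mask=coords_mask,
    )
    print("total loss:", total_loss.item(), "| logits:", tuple(logits.shape))

    generated = model.generate(img[:1], prompt_ids[:1], max_new_tokens=5, top_k=50)
    print("generated ids:", generated.tolist())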