# pointing_models_from_scratch/decoder_language_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_components import Block
from constants import *
from utils import tokenizer, vocab_size


class DecoderLanguageModel(nn.Module):
"""
Transformer Decoder Language Model with optional coordinate regression head.
Processes a combined sequence of embeddings.
Outputs logits for token prediction and optionally regressed coordinates (for MAX_POINTS).
"""
    def __init__(self, n_embd=HIDDEN_DIM, vocab_size=vocab_size, num_heads=NUM_HEADS,
                 n_layer=NUM_LAYERS, max_context=CONTEXT_LENGTH, dropout=DROPOUT):
        super().__init__()
        # --- Input Embeddings ---
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(max_context, n_embd)
        self.dropout = nn.Dropout(dropout)
        # --- Transformer Blocks ---
        self.blocks = nn.ModuleList([
            Block(n_embd, num_heads, dropout, is_decoder=True)
            for _ in range(n_layer)
        ])
        # --- Final Layer Norm ---
        self.ln_f = nn.LayerNorm(n_embd)
        # --- Output Heads ---
        # 1. Head for token classification
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # 2. Head for direct coordinate regression (predicting MAX_POINTS * 2 values)
        self.regression_head = nn.Sequential(
            nn.Linear(n_embd, n_embd // 2),
            nn.GELU(),
            nn.Linear(n_embd // 2, MAX_POINTS * 2),  # Output MAX_POINTS * (x, y)
            nn.Sigmoid()  # Output activation in [0, 1]
        )
        # --- End Output Heads ---
        self.n_embd = n_embd
        self.max_context = max_context
        self.token_embedding_table.weight = self.lm_head.weight
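        # Weight tying: the token embedding and lm_head share one (vocab_size, n_embd)
        # matrix, so _init_weights below initializes the shared tensor for both roles.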
        self.apply(self._init_weights)
        print(f"DecoderLanguageModel initialized with {n_layer} layers.")
    def _init_weights(self, module):
        # N(0, 0.02) init for Linear/Embedding weights; LayerNorm gets unit weight, zero bias
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
    def forward(self, combined_embeds, attention_mask=None, targets=None):
        """
        Forward pass for training, or for inference where the loss is computed.
        Coordinate regression is handled outside this module by the VisionLanguageModel,
        which receives the normalized hidden states returned here.
        """
        # --- Input Validation & Processing ---
        if combined_embeds.ndim != 3:
            raise ValueError(f"DecoderLM received non-3D combined_embeds! Shape: {combined_embeds.shape}")
        B, T, C = combined_embeds.shape
        if T > self.max_context:
            # Keep only the most recent max_context positions
            print(f"WARNING (Decoder forward): Input sequence length {T} > max context {self.max_context}. Truncating.")
            combined_embeds = combined_embeds[:, -self.max_context:, :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.max_context:]
            if targets is not None:
                targets = targets[:, -self.max_context:]
            T = self.max_context
        # --- Positional Encoding ---
        pos = torch.arange(0, T, dtype=torch.long, device=combined_embeds.device)
        pos = pos.clamp(max=self.position_embedding_table.num_embeddings - 1)
        pos_emb = self.position_embedding_table(pos)  # Shape: (T, C)
        x = combined_embeds + pos_emb.unsqueeze(0)
        x = self.dropout(x)
        # --- Transformer Blocks ---
        for block in self.blocks:
            x = block(x, attention_mask=attention_mask)
        # --- Final Layer Norm ---
        x_norm = self.ln_f(x)  # Shape: (B, T, C) - passed out for the VLM regression head
        # --- Classification Head Output ---
        logits = self.lm_head(x_norm)  # Shape: (B, T, VocabSize)
        # --- Classification Loss Calculation ---
        class_loss = None
        if targets is not None:
            try:
                class_loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    targets.view(-1),
                    ignore_index=-100
                )
                if torch.isnan(class_loss):
                    print("Warning: class_loss is NaN.")
                    class_loss = None
            except Exception as e:
                print(f"Error calculating cross_entropy: {e}")
                print(f"Logits shape: {logits.shape}, Targets shape: {targets.shape}")
                class_loss = None
        # Return logits, class_loss, and the final normalized hidden states
        return logits, class_loss, x_norm
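    # A minimal sketch (an assumption, not the actual VisionLanguageModel code) of how
    # a caller might apply the regression head to the hidden states returned above:
    #
    #   logits, class_loss, x_norm = decoder(combined_embeds, attention_mask, targets)
    #   pooled = x_norm[:, -1, :]                 # (B, C): hidden state at the last position
    #   coords = decoder.regression_head(pooled)  # (B, MAX_POINTS * 2), values in [0, 1]
    #   coords = coords.view(-1, MAX_POINTS, 2)   # (B, MAX_POINTS, 2) -> one (x, y) per point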
    # --- Generation Method (example; the VisionLanguageModel may instead run the loop externally) ---
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressive generation from starting token IDs.

        NOTE: This version does not handle combined embeddings directly.
        The VisionLanguageModel should ideally use a method like
        generate_from_embeddings or implement the sampling loop externally.
        """
        was_training = self.training
        self.eval()
        for _ in range(max_new_tokens):
            # --- Context Management ---
            # Crop idx if it is longer than the context length
            idx_cond = idx if idx.size(1) <= self.max_context else idx[:, -self.max_context:]
            # --- Forward Pass ---
            # Token embeddings
            tok_embeds = self.token_embedding_table(idx_cond)  # (B, T, C)
            # Positional embeddings
            pos = torch.arange(0, idx_cond.size(1), dtype=torch.long, device=idx.device)
            pos = pos.clamp(max=self.max_context - 1)
            pos_emb = self.position_embedding_table(pos).unsqueeze(0)  # (1, T, C)
            x = self.dropout(tok_embeds + pos_emb)
            # Pass through blocks (no padding mask needed for a single unpadded sequence)
            for block in self.blocks:
                x = block(x, attention_mask=None)  # Causal mask is internal to block/head
            # Final layer norm and head for the last position only
            x = self.ln_f(x[:, -1:, :])  # (B, 1, C)
            logits = self.lm_head(x)  # (B, 1, V)
            logits = logits.squeeze(1)  # (B, V)
            # --- Sampling ---
            logits = logits / temperature
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # Append the sampled token
            idx = torch.cat((idx, idx_next), dim=1)
            # Stop early if every sequence in the batch produced EOS
            eos_id = getattr(tokenizer, 'eos_token_id', None)
            if eos_id is not None and (idx_next == eos_id).all():
                break
        self.train(was_training)
        return idx
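

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): assumes constants.py defines HIDDEN_DIM,
    # CONTEXT_LENGTH, MAX_POINTS, etc., that utils provides tokenizer and vocab_size as
    # imported above, and that Block accepts attention_mask as used in forward().
    model = DecoderLanguageModel()

    B, T = 2, 16
    combined_embeds = torch.randn(B, T, HIDDEN_DIM)
    attention_mask = torch.ones(B, T, dtype=torch.long)
    targets = torch.randint(0, vocab_size, (B, T))

    logits, class_loss, x_norm = model(combined_embeds, attention_mask, targets)
    print(f"logits: {tuple(logits.shape)}, loss: {class_loss}, hidden: {tuple(x_norm.shape)}")

    # Token-only autoregressive generation from a short prompt of random token IDs.
    prompt = torch.randint(0, vocab_size, (1, 4))
    generated = model.generate(prompt, max_new_tokens=8, temperature=1.0, top_k=50)
    print(f"generated token IDs: {generated.tolist()}")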