|
|
from model_components import Block |
|
|
from constants import * |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from utils import tokenizer, vocab_size |
|
|
|
|
|
class DecoderLanguageModel(nn.Module):
    """
    Transformer Decoder Language Model with optional coordinate regression head.

    Processes a combined sequence of embeddings (e.g. projected image patches
    concatenated with token embeddings, supplied by a surrounding VLM) through
    a stack of causal decoder ``Block``s.

    Outputs logits for token prediction and exposes a ``regression_head`` that
    maps hidden states to MAX_POINTS (x, y) pairs in [0, 1]; the regression
    head is built here but applied *outside* this module by the VLM.
    """

    def __init__(self, n_embd=HIDDEN_DIM, vocab_size=vocab_size, num_heads=NUM_HEADS,
                 n_layer=NUM_LAYERS, max_context=CONTEXT_LENGTH, dropout=DROPOUT):
        """
        Args:
            n_embd: Hidden/embedding dimension.
            vocab_size: Token vocabulary size (defaults to the tokenizer's).
            num_heads: Attention heads per Block.
            n_layer: Number of decoder Blocks.
            max_context: Maximum sequence length (positional table size).
            dropout: Dropout probability applied to embeddings and inside Blocks.
        """
        super().__init__()

        # Token / positional embedding tables plus input dropout.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(max_context, n_embd)
        self.dropout = nn.Dropout(dropout)

        # Causal decoder stack.
        self.blocks = nn.ModuleList([
            Block(n_embd, num_heads, dropout, is_decoder=True)
            for _ in range(n_layer)
        ])

        # Final layer norm applied before the output heads.
        self.ln_f = nn.LayerNorm(n_embd)

        # Language-model head; weight-tied with the token embedding below.
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        # Coordinate regression head: hidden state -> MAX_POINTS (x, y) pairs,
        # squashed to [0, 1] by the Sigmoid. NOTE: not invoked in forward();
        # the VLM applies it externally to the returned hidden states.
        self.regression_head = nn.Sequential(
            nn.Linear(n_embd, n_embd // 2),
            nn.GELU(),
            nn.Linear(n_embd // 2, MAX_POINTS * 2),
            nn.Sigmoid()
        )

        self.n_embd = n_embd
        self.max_context = max_context

        # Weight tying: embedding and LM head share a single parameter tensor.
        self.token_embedding_table.weight = self.lm_head.weight
        self.apply(self._init_weights)
        print(f"DecoderLanguageModel initialized with {n_layer} layers.")

    def _init_weights(self, module):
        """GPT-2-style init: N(0, 0.02) for weights, zeros for biases,
        ones/zeros for LayerNorm scale/shift."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, combined_embeds, attention_mask=None, targets=None):
        """
        Forward pass for training or inference where loss is calculated.
        Regression output is now handled *outside* this module by VLM.

        Args:
            combined_embeds: (B, T, C) pre-embedded input sequence.
            attention_mask: Optional (B, T) mask forwarded to each Block.
            targets: Optional (B, T) token ids for cross-entropy; positions
                with -100 are ignored.

        Returns:
            (logits, class_loss, x_norm) where logits is (B, T, vocab_size),
            class_loss is a scalar tensor or None, and x_norm is the final
            layer-normed hidden state (B, T, C) for external heads.
        """
        if combined_embeds.ndim != 3:
            raise ValueError(f"DecoderLM received non-3D combined_embeds! Shape: {combined_embeds.shape}")
        B, T, C = combined_embeds.shape

        # If the sequence exceeds the context window, keep the most recent
        # max_context positions (and keep mask/targets aligned with it).
        if T > self.max_context:
            print(f"WARNING (Decoder forward): Input sequence length {T} > max context {self.max_context}. Truncating.")
            combined_embeds = combined_embeds[:, -self.max_context:, :]
            if attention_mask is not None: attention_mask = attention_mask[:, -self.max_context:]
            if targets is not None: targets = targets[:, -self.max_context:]
            T = self.max_context

        # Add learned positional embeddings (clamp is a belt-and-braces guard
        # against indexing past the table after truncation).
        pos = torch.arange(0, T, dtype=torch.long, device=combined_embeds.device)
        pos = pos.clamp(max=self.position_embedding_table.num_embeddings - 1)
        pos_emb = self.position_embedding_table(pos)
        x = combined_embeds + pos_emb.unsqueeze(0)
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x, attention_mask=attention_mask)

        x_norm = self.ln_f(x)
        logits = self.lm_head(x_norm)

        class_loss = None
        if targets is not None:
            try:
                # Flatten (B, T, V) -> (B*T, V) for token-level cross-entropy;
                # -100 marks padded / non-text positions to skip.
                class_loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    targets.view(-1),
                    ignore_index=-100
                )
                if torch.isnan(class_loss):
                    print("Warning: class_loss is NaN.")
                    class_loss = None
            except Exception as e:
                # Best-effort: surface the mismatch but keep training alive.
                print(f"Error calculating cross_entropy: {e}")
                print(f"Logits shape: {logits.shape}, Targets shape: {targets.shape}")
                class_loss = None

        return logits, class_loss, x_norm

    @torch.no_grad()  # generation needs no graph; avoids memory growth per step
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressive generation based on starting token IDs.

        NOTE: This version doesn't handle combined embeddings directly.
        The VisionLanguageModel should ideally use a method like
        generate_from_embeddings or implement the loop externally.

        Args:
            idx: (B, T0) starting token ids.
            max_new_tokens: Maximum number of tokens to append.
            temperature: Softmax temperature (>0); higher = more random.
            top_k: If set and > 0, sample only from the top-k logits.

        Returns:
            (B, T0 + n) token ids, n <= max_new_tokens (stops early if every
            sequence in the batch emits the tokenizer's EOS token).
        """
        was_training = self.training  # restore the caller's mode afterwards
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Condition on at most the last max_context tokens.
                idx_cond = idx if idx.size(1) <= self.max_context else idx[:, -self.max_context:]

                tok_embeds = self.token_embedding_table(idx_cond)
                pos = torch.arange(0, idx_cond.size(1), dtype=torch.long, device=idx.device)
                pos = pos.clamp(max=self.max_context - 1)
                pos_emb = self.position_embedding_table(pos).unsqueeze(0)
                x = self.dropout(tok_embeds + pos_emb)

                for block in self.blocks:
                    x = block(x, attention_mask=None)

                # Only the last position is needed for the next-token logits.
                x = self.ln_f(x[:, -1:, :])
                logits = self.lm_head(x)
                logits = logits.squeeze(1)

                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    # Mask everything below the k-th largest logit.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

                idx = torch.cat((idx, idx_next), dim=1)

                # Stop once every sequence produced EOS. Guard against a
                # tokenizer whose eos_token_id attribute exists but is None.
                eos_id = getattr(tokenizer, 'eos_token_id', None)
                if eos_id is not None and (idx_next == eos_id).all():
                    break
        finally:
            # Previously this called self.train() unconditionally, silently
            # putting an eval-mode model back into training mode.
            self.train(was_training)
        return idx
|
|
|
|
|
|