import torch
import torch.nn as nn
import torch.nn.functional as F

from constants import *
from model_components import Block
from utils import tokenizer, vocab_size


class DecoderLanguageModel(nn.Module):
"""
Transformer Decoder Language Model with optional coordinate regression head.
Processes a combined sequence of embeddings.
Outputs logits for token prediction and optionally regressed coordinates (for MAX_POINTS).
"""
def __init__(self, n_embd=HIDDEN_DIM, vocab_size=vocab_size, num_heads=NUM_HEADS,
n_layer=NUM_LAYERS, max_context=CONTEXT_LENGTH, dropout=DROPOUT):
super().__init__()
# --- Input Embeddings ---
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
self.position_embedding_table = nn.Embedding(max_context, n_embd)
self.dropout = nn.Dropout(dropout)
# --- Transformer Blocks ---
self.blocks = nn.ModuleList([
Block(n_embd, num_heads, dropout, is_decoder=True)
for _ in range(n_layer)
])
# --- Final Layer Norm ---
self.ln_f = nn.LayerNorm(n_embd)
# --- Output Heads ---
# 1. Head for token classification
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
# 2. Head for direct coordinate regression (predicting MAX_POINTS * 2 values)
self.regression_head = nn.Sequential(
nn.Linear(n_embd, n_embd // 2),
nn.GELU(),
nn.Linear(n_embd // 2, MAX_POINTS * 2), # Output MAX_POINTS * (x, y)
nn.Sigmoid() # Output activation [0, 1]
)
# --- End Output Heads ---
self.n_embd = n_embd
self.max_context = max_context
        # Weight tying: the token embedding and the LM head share parameters.
        self.token_embedding_table.weight = self.lm_head.weight
self.apply(self._init_weights)
print(f"DecoderLanguageModel initialized with {n_layer} layers.")

    def _init_weights(self, module):
        # GPT-style initialization: normal(0, 0.02) for Linear/Embedding weights,
        # zeros for biases, ones for LayerNorm scale.
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)

    def forward(self, combined_embeds, attention_mask=None, targets=None):
        """
        Forward pass over a pre-combined sequence of embeddings.

        Returns token logits, the cross-entropy loss (when `targets` is provided),
        and the final normalized hidden states. Coordinate regression is applied
        outside this module by the VLM, using those hidden states.
        """
# --- Input Validation & Processing ---
if combined_embeds.ndim != 3:
raise ValueError(f"DecoderLM received non-3D combined_embeds! Shape: {combined_embeds.shape}")
B, T, C = combined_embeds.shape
if T > self.max_context:
            # Keep only the most recent max_context positions.
            print(f"WARNING (Decoder forward): Input sequence length {T} > max context {self.max_context}. Truncating.")
            combined_embeds = combined_embeds[:, -self.max_context:, :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.max_context:]
            if targets is not None:
                targets = targets[:, -self.max_context:]
            T = self.max_context
# --- Positional Encoding ---
pos = torch.arange(0, T, dtype=torch.long, device=combined_embeds.device)
pos = pos.clamp(max=self.position_embedding_table.num_embeddings - 1)
pos_emb = self.position_embedding_table(pos) # Shape: (T, C)
x = combined_embeds + pos_emb.unsqueeze(0)
x = self.dropout(x)
# --- Transformer Blocks ---
for block in self.blocks:
x = block(x, attention_mask=attention_mask)
# --- Final Layer Norm ---
x_norm = self.ln_f(x) # Shape: (B, T, C) - Pass this out for VLM regression head
# --- Classification Head Output ---
logits = self.lm_head(x_norm) # Shape: (B, T, VocabSize)
# --- Classification Loss Calculation ---
class_loss = None
if targets is not None:
            # Flatten to (B*T, V) and (B*T,); positions with target -100 are ignored.
            try:
class_loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-100
)
if torch.isnan(class_loss):
print("Warning: class_loss is NaN.")
class_loss = None
except Exception as e:
print(f"Error calculating cross_entropy: {e}")
print(f"Logits shape: {logits.shape}, Targets shape: {targets.shape}")
class_loss = None
# Return logits, class_loss, and the final normalized hidden states
return logits, class_loss, x_norm
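
    # Hedged usage sketch (not from the original file): the wrapping VLM is expected
    # to apply `self.regression_head` to the hidden states returned by forward(),
    # e.g. on the last position (this pooling choice is an assumption for illustration):
    #
    #     logits, class_loss, hidden = decoder(combined_embeds, attention_mask, targets)
    #     coords = decoder.regression_head(hidden[:, -1, :])  # (B, MAX_POINTS * 2), in [0, 1]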

    # --- Generation Method (example - if needed internally; otherwise the VLM handles it) ---
    # If the VLM needs this class to perform generation based on token IDs:
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
"""
Autoregressive generation based on starting token IDs.
NOTE: This version doesn't handle combined embeddings directly.
The VisionLanguageModel should ideally use a method like
generate_from_embeddings or implement the loop externally.
"""
        was_training = self.training
        self.eval()
for _ in range(max_new_tokens):
# --- Context Management ---
# Crop idx if longer than context length
idx_cond = idx if idx.size(1) <= self.max_context else idx[:, -self.max_context:]
# --- Forward Pass ---
# Get embeddings
tok_embeds = self.token_embedding_table(idx_cond) # (B, T, C)
# Get positional embeddings
pos = torch.arange(0, idx_cond.size(1), dtype=torch.long, device=idx.device)
pos = pos.clamp(max=self.max_context - 1)
pos_emb = self.position_embedding_table(pos).unsqueeze(0) # (1, T, C)
x = self.dropout(tok_embeds + pos_emb)
# Pass through blocks (no padding mask needed here as we handle single sequence)
for block in self.blocks:
x = block(x, attention_mask=None) # Causal mask is internal to block/head
# Final layer norm and head for the last token only
x = self.ln_f(x[:, -1:, :]) # (B, 1, C)
logits = self.lm_head(x) # (B, 1, V)
logits = logits.squeeze(1) # (B, V)
# --- Sampling ---
logits = logits / temperature
if top_k is not None and top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
# Append sampled token
idx = torch.cat((idx, idx_next), dim=1)
            # Stop early if every sequence in the batch just produced EOS.
            eos_id = getattr(tokenizer, 'eos_token_id', None)
            if eos_id is not None and (idx_next == eos_id).all():
                break
        # Restore the caller's train/eval mode rather than forcing train().
        self.train(was_training)
return idx
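

# --- Minimal smoke-test sketch (illustrative only, not part of the training pipeline) ---
# Assumes HIDDEN_DIM and vocab_size resolve from constants.py / utils.py as imported
# above; batch and sequence sizes below are arbitrary.
if __name__ == "__main__":
    model = DecoderLanguageModel()
    B, T = 2, 16
    dummy_embeds = torch.randn(B, T, HIDDEN_DIM)            # stand-in for combined embeddings
    dummy_targets = torch.randint(0, vocab_size, (B, T))    # stand-in next-token targets
    logits, loss, hidden = model(dummy_embeds, attention_mask=None, targets=dummy_targets)
    print(logits.shape, None if loss is None else loss.item(), hidden.shape)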