""" Text Decoder Module for SEM V6 - Multi-Token Prediction. Converts latent representations to text using Llama-3 tokenizer. Uses multi-token prediction (MTP) for faster training and better representations. Architecture: latent (256) → N prediction heads → vocab_logits[N, vocab_size] → tokens → text Reference: - Multi-Token Prediction (Meta 2024): https://arxiv.org/abs/2404.19737 - Llama 3: https://llama.meta.com/ """ import os import glob import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional from transformers import AutoTokenizer class MultiTokenDecoder(nn.Module): """ Multi-token prediction decoder for SEM V6 with TIED EMBEDDINGS. Predicts N future tokens simultaneously from a latent vector, enabling faster training and richer representations. Architecture (per prediction head): latent (latent_dim) → shared_trunk (latent_dim → hidden_dim) → head_i (hidden_dim → embed_dim) → vocab_embedding.T (embed_dim → vocab_size) # TIED WEIGHTS The key innovation is TIED EMBEDDINGS: the vocabulary embedding matrix is shared between input and output projections. This: 1. Reduces parameters by 256x (from hidden→vocab to hidden→embed + embed→vocab shared) 2. Provides semantic guidance - vocabulary learns meaningful relationships 3. Matches standard LLM practice (GPT-2, Llama, etc.) Each head predicts token at position i (0 = next token, 1 = token after, etc.) Example: >>> decoder = MultiTokenDecoder(latent_dim=256, num_predict=4) >>> z = torch.randn(1, 256) >>> logits = decoder(z) # (1, 4, vocab_size) >>> text = decoder.generate(z, max_length=100) """ def __init__( self, latent_dim: int = 256, hidden_dim: int = 512, num_predict: int = 4, # Number of future tokens to predict tokenizer_name: str = "meta-llama/Meta-Llama-3-8B", device: str = "cuda", embed_dim: int = 512, # Embedding dimension for tied weights ) -> None: """ Initialize multi-token decoder with tied embeddings. Args: latent_dim: Dimension of input latent vectors hidden_dim: Hidden layer dimension num_predict: Number of future tokens to predict (default 4) tokenizer_name: HuggingFace tokenizer (default: Llama-3.2-1B) device: Computation device embed_dim: Vocabulary embedding dimension (default 512) Tied between input/output for semantic learning """ super().__init__() self.latent_dim = latent_dim self.hidden_dim = hidden_dim self.num_predict = num_predict self.device = torch.device(device) self.embed_dim = embed_dim # Resolve tokenizer path for offline use (prefer local cache) tokenizer_path = self._resolve_tokenizer_path(tokenizer_name) # Load tokenizer # Offline-safe: rely on local HF cache when network is unavailable. 
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            trust_remote_code=True,
            local_files_only=True,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Use len(tokenizer), not vocab_size - it includes special tokens (128256 vs 128000)
        self.vocab_size = len(self.tokenizer)

        # TIED EMBEDDING: shared vocabulary embedding matrix
        # This is the key fix - the vocabulary learns semantic relationships
        # The output projection uses embedding.weight.T (tied weights)
        self.vocab_embedding = nn.Embedding(self.vocab_size, embed_dim)

        # Scale init for the 512-dim bottleneck: std=0.02 gives σ_logit≈0.32 (near-uniform softmax)
        # Target σ_logit≈3.0 for meaningful discrimination over a 128k vocab
        embed_init_std = 3.0 / (embed_dim ** 0.5)  # ≈0.13 for embed_dim=512
        nn.init.normal_(self.vocab_embedding.weight, std=embed_init_std)

        # Shared trunk: latent → hidden representation
        self.trunk = nn.Sequential(
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )

        # Separate prediction heads for each future position
        # Head i predicts the token at offset i
        # NOW: hidden_dim → embed_dim (NOT vocab_size directly)
        self.heads = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.GELU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_dim, embed_dim),  # Output embed_dim, not vocab_size
                )
                for _ in range(num_predict)
            ]
        )

        # Output bias (optional, for the vocab projection)
        self.output_bias = nn.Parameter(torch.zeros(self.vocab_size))

        # Embedding trunk: embed_dim → hidden for the direct embedding path
        # Used by forward_from_embedding() for learnable-encoder training
        # Must be defined in __init__ (not lazily) so weights can be loaded from checkpoint
        self._embed_trunk = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )

        # Initialize head output layers
        for head in self.heads:
            nn.init.normal_(head[-1].weight, std=0.02)
            nn.init.zeros_(head[-1].bias)

        # Initialize _embed_trunk layers
        for layer in self._embed_trunk:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, std=0.02)
                nn.init.zeros_(layer.bias)

    @staticmethod
    def _resolve_tokenizer_path(tokenizer_name: str) -> str:
        """Prefer a local tokenizer snapshot to avoid network calls."""
        # Explicit override via env
        env_path = os.getenv("SEM_TOKENIZER_PATH")
        if env_path and os.path.isdir(env_path):
            return env_path

        # If tokenizer_name is already a local path, use it
        if os.path.isdir(tokenizer_name):
            return tokenizer_name

        # Try to resolve from the HF cache
        cache_root = os.path.expanduser("~/.cache/huggingface/hub")
        model_dir = f"models--{tokenizer_name.replace('/', '--')}"
        snapshot_glob = os.path.join(cache_root, model_dir, "snapshots", "*")
        snapshots = sorted(glob.glob(snapshot_glob), key=os.path.getmtime, reverse=True)
        if snapshots:
            return snapshots[0]

        # Fall back to the original name (may fail offline)
        return tokenizer_name

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Predict multiple future tokens from a latent using TIED EMBEDDINGS.

        Args:
            z: Latent tensor (batch, latent_dim) or (latent_dim,)

        Returns:
            logits: (batch, num_predict, vocab_size) or (num_predict, vocab_size)
        """
        squeeze = z.dim() == 1
        if squeeze:
            z = z.unsqueeze(0)

        # Shared representation
        h = self.trunk(z)  # (batch, hidden_dim)

        # Predict from each head using TIED EMBEDDINGS
        all_logits = []
        for head in self.heads:
            # Head outputs an embed_dim representation
            embed = head(h)  # (batch, embed_dim)
            # Project to vocab using the TIED embedding weights
            # logits = embed @ vocab_embedding.weight.T + bias
            logits = F.linear(embed, self.vocab_embedding.weight, self.output_bias)
            all_logits.append(logits)  # (batch, vocab_size)

        # Stack: (batch, num_predict, vocab_size)
        output = torch.stack(all_logits, dim=1)
        if squeeze:
            output = output.squeeze(0)
        return output

    def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Embed token IDs using the shared vocabulary embedding.

        This enables using the same embeddings for both input and output,
        which is the key to tied embeddings learning meaningful semantics.

        Args:
            token_ids: Token IDs (batch, seq_len) or (seq_len,)

        Returns:
            embeddings: Token embeddings (batch, seq_len, embed_dim) or (seq_len, embed_dim)
        """
        return self.vocab_embedding(token_ids)

    def forward_from_embedding(self, embed: torch.Tensor) -> torch.Tensor:
        """
        Predict tokens directly from embedding space (bypasses the latent→hidden trunk).

        This provides a direct path for text generation when using a learnable
        encoder with tied embeddings. The embedding is projected to vocab using
        the tied weights.

        Args:
            embed: Embedding tensor (batch, embed_dim) or (embed_dim,).
                This should be the pooled embedding from
                LearnableEncoder.encode_to_embedding()

        Returns:
            logits: (batch, num_predict, vocab_size) or (num_predict, vocab_size)
        """
        squeeze = embed.dim() == 1
        if squeeze:
            embed = embed.unsqueeze(0)

        # Shared representation from the embedding
        # Note: _embed_trunk is created in __init__ so its weights load from checkpoint
        h = self._embed_trunk(embed)  # (batch, hidden_dim)

        # Predict from each head using TIED EMBEDDINGS
        all_logits = []
        for head in self.heads:
            head_embed = head(h)  # (batch, embed_dim)
            logits = F.linear(head_embed, self.vocab_embedding.weight, self.output_bias)
            all_logits.append(logits)

        output = torch.stack(all_logits, dim=1)
        if squeeze:
            output = output.squeeze(0)
        return output

    def decode_tokens(self, token_ids: torch.Tensor) -> str:
        """Decode token IDs to text."""
        if token_ids.dim() > 1:
            token_ids = token_ids.flatten()
        return self.tokenizer.decode(token_ids.tolist(), skip_special_tokens=True)

    def encode_text(self, text: str) -> torch.Tensor:
        """Encode text to token IDs."""
        tokens = self.tokenizer.encode(text, return_tensors="pt")
        return tokens.to(self.device)

    @torch.no_grad()
    def generate(
        self,
        z: torch.Tensor,
        max_length: int = 100,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
    ) -> str:
        """
        Generate text using multi-token prediction.

        Generates num_predict tokens at a time for faster generation.

        Args:
            z: Latent vector (latent_dim,) or (1, latent_dim)
            max_length: Maximum tokens to generate
            temperature: Sampling temperature
            top_k: Top-k filtering
            top_p: Nucleus sampling

        Returns:
            Generated text string
        """
        if z.dim() == 1:
            z = z.unsqueeze(0)

        generated_ids = []
        current_z = z
        safe_temperature = max(temperature, 1e-4)

        while len(generated_ids) < max_length:
            # Get multi-token prediction
            logits = self.forward(current_z)  # (1, num_predict, vocab_size)

            # Sample from each position
            for i in range(self.num_predict):
                if len(generated_ids) >= max_length:
                    break

                pos_logits = logits[0, i] / safe_temperature

                # Top-k filtering
                if top_k > 0:
                    top_k_vals, _ = torch.topk(pos_logits, min(top_k, pos_logits.size(-1)))
                    pos_logits[pos_logits < top_k_vals[-1]] = float("-inf")

                # Top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_idx = torch.sort(pos_logits, descending=True)
                    cumsum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    mask = cumsum > top_p
                    # Shift the mask right so the first token above the threshold is kept
                    mask[1:] = mask[:-1].clone()
                    mask[0] = False
                    sorted_logits[mask] = float("-inf")
                    # Undo the sort: sorted position j maps back to original index sorted_idx[j]
                    pos_logits = sorted_logits.scatter(0, sorted_idx, sorted_logits)

                if torch.isneginf(pos_logits).all():
                    pos_logits = torch.zeros_like(pos_logits)

                pos_logits = torch.clamp(pos_logits, -50.0, 50.0)
                probs = F.softmax(pos_logits, dim=-1)
                probs = torch.where(torch.isfinite(probs), probs, torch.zeros_like(probs))
                probs = probs + 1e-8
                total = probs.sum()
                if not torch.isfinite(total) or total <= 0:
                    probs = torch.ones_like(probs) / probs.numel()
                else:
                    probs = probs / total

                token_id = torch.multinomial(probs, 1).item()

                # Stop at EOS
                if token_id == self.tokenizer.eos_token_id:
                    return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

                generated_ids.append(token_id)

            # For continuous generation, we'd update z based on the generated tokens.
            # This requires feeding tokens back through the encoder.
            # For now, we continue with the same z (limitation - a proper impl needs feedback).

        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)


def multi_token_prediction_loss(
    z_pred: torch.Tensor,
    target_text: list[str],
    decoder: MultiTokenDecoder,
    max_length: int = 128,
    loss_weight: float = 1.0,
    position_decay: float = 1.0,
) -> tuple[torch.Tensor, dict]:
    """
    Multi-token prediction loss following Meta's MTP approach.

    Trains the decoder to predict N future tokens from a latent representation.
    Each head predicts a different future position.

    Reference:
        - Meta (2024): Better & Faster LLMs via Multi-token Prediction
          https://arxiv.org/abs/2404.19737
        - NVIDIA Megatron: average across depths, apply a scaling factor
          https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/multi_token_prediction.html

    Args:
        z_pred: Predicted latent (batch, latent_dim)
        target_text: List of target text strings
        decoder: MultiTokenDecoder module
        max_length: Max tokenization length
        loss_weight: Scaling factor for the MTP loss (default 1.0; use 0.1 when
            combining with a main task loss, per NVIDIA Megatron)
        position_decay: Exponential decay factor for later positions (default 1.0 = no decay).
            Use <1.0 (e.g., 0.9) to weight earlier positions more heavily.

    Returns:
        loss: Combined cross-entropy loss (averaged across positions, scaled by loss_weight)
        metrics: Dict with per-position accuracy and loss breakdown
    """
    batch_size = z_pred.size(0)
    device = z_pred.device
    num_predict = decoder.num_predict

    # Tokenize targets
    encoded = decoder.tokenizer(
        target_text,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    target_ids = encoded["input_ids"].to(device)  # (batch, seq_len)
    seq_len = target_ids.size(1)

    # Get multi-token predictions
    logits = decoder(z_pred)  # (batch, num_predict, vocab_size)

    # Compute the loss for each prediction position
    # Per Meta: average across all depths
    # Per NVIDIA: average the MTP losses across all depths, then apply the scaling factor
    total_loss = 0.0
    total_weight = 0.0
    accuracies = []
    per_pos_losses = []

    num_positions = min(num_predict, seq_len)
    for i in range(num_positions):
        pos_logits = logits[:, i]       # (batch, vocab_size)
        pos_targets = target_ids[:, i]  # (batch,)

        # Cross-entropy for this position
        pos_loss = F.cross_entropy(pos_logits, pos_targets)

        # Optional position weighting (earlier positions = higher weight)
        pos_weight = position_decay ** i
        total_loss = total_loss + pos_loss * pos_weight
        total_weight += pos_weight
        per_pos_losses.append(pos_loss.item())

        # Accuracy
        with torch.no_grad():
            preds = pos_logits.argmax(dim=-1)
            acc = (preds == pos_targets).float().mean().item()
            accuracies.append(acc)

    # Average loss across positions (weighted if position_decay < 1.0)
    avg_loss = total_loss / total_weight if total_weight > 0 else total_loss

    # Apply the loss weight (scaling factor per NVIDIA Megatron)
    scaled_loss = avg_loss * loss_weight

    metrics = {
        "mtp_loss": scaled_loss.item(),
        "mtp_loss_unscaled": avg_loss.item(),
        "avg_accuracy": sum(accuracies) / len(accuracies) if accuracies else 0.0,
        **{f"acc_pos_{i}": acc for i, acc in enumerate(accuracies)},
        **{f"loss_pos_{i}": loss for i, loss in enumerate(per_pos_losses)},
    }

    return scaled_loss, metrics


# Aliases for backward compatibility
TextDecoder = MultiTokenDecoder
text_generation_loss = multi_token_prediction_loss
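

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the SEM V6 training pipeline).
    # Assumes the Llama-3 tokenizer files are available locally (HF cache or
    # SEM_TOKENIZER_PATH), since the decoder loads with local_files_only=True.
    # With untrained weights the generated text is meaningless; this only
    # exercises shapes and the loss/generation code paths on CPU.
    decoder = MultiTokenDecoder(latent_dim=256, num_predict=4, device="cpu")

    # Forward pass: latents → logits for num_predict future token positions
    z = torch.randn(2, 256)
    logits = decoder(z)
    print("logits shape:", tuple(logits.shape))  # (2, 4, vocab_size)

    # MTP loss against dummy targets (illustrative strings, not real training data)
    targets = ["hello world", "multi-token prediction"]
    loss, metrics = multi_token_prediction_loss(z, targets, decoder, loss_weight=0.1)
    print("mtp_loss:", loss.item(), "avg_accuracy:", metrics["avg_accuracy"])

    # Sample a short text from a single latent
    text = decoder.generate(z[0], max_length=16, temperature=0.8)
    print("sample:", text)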