# src/sem_v6/modules/text_decoder.py
"""
Text Decoder Module for SEM V6 - Multi-Token Prediction.
Converts latent representations to text using Llama-3 tokenizer.
Uses multi-token prediction (MTP) for faster training and better representations.
Architecture:
latent (256) → N prediction heads → vocab_logits[N, vocab_size] → tokens → text
Reference:
- Multi-Token Prediction (Meta 2024): https://arxiv.org/abs/2404.19737
- Llama 3: https://llama.meta.com/
"""
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
class MultiTokenDecoder(nn.Module):
"""
Multi-token prediction decoder for SEM V6 with TIED EMBEDDINGS.
Predicts N future tokens simultaneously from a latent vector,
enabling faster training and richer representations.
Architecture (per prediction head):
latent (latent_dim)
→ shared_trunk (latent_dim → hidden_dim)
→ head_i (hidden_dim → embed_dim)
→ vocab_embedding.T (embed_dim → vocab_size) # TIED WEIGHTS
The key innovation is TIED EMBEDDINGS: the vocabulary embedding matrix
is shared between input and output projections. This:
    1. Cuts output-projection parameters by ~250x (vocab_size / embed_dim;
       see the worked count below)
2. Provides semantic guidance - vocabulary learns meaningful relationships
3. Matches standard LLM practice (GPT-2, Llama, etc.)
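    Worked count for the defaults: an untied 512 -> 128,256 output layer costs
    ~65.7M parameters per head, while the tied path costs only ~0.26M
    (512 -> 512) per head plus the single shared embedding matrix.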
Each head predicts token at position i (0 = next token, 1 = token after, etc.)
Example:
>>> decoder = MultiTokenDecoder(latent_dim=256, num_predict=4)
>>> z = torch.randn(1, 256)
>>> logits = decoder(z) # (1, 4, vocab_size)
>>> text = decoder.generate(z, max_length=100)
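        >>> # Tied-embedding round trip (shapes shown for the defaults):
        >>> emb = decoder.embed_tokens(torch.tensor([[1, 2, 3]]))      # (1, 3, embed_dim)
        >>> logits2 = decoder.forward_from_embedding(emb.mean(dim=1))  # (1, num_predict, vocab_size)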
"""
def __init__(
self,
latent_dim: int = 256,
hidden_dim: int = 512,
num_predict: int = 4, # Number of future tokens to predict
tokenizer_name: str = "meta-llama/Meta-Llama-3-8B",
device: str = "cuda",
embed_dim: int = 512, # Embedding dimension for tied weights
) -> None:
"""
Initialize multi-token decoder with tied embeddings.
Args:
latent_dim: Dimension of input latent vectors
hidden_dim: Hidden layer dimension
num_predict: Number of future tokens to predict (default 4)
            tokenizer_name: HuggingFace tokenizer repo or local path
                (default: meta-llama/Meta-Llama-3-8B)
device: Computation device
embed_dim: Vocabulary embedding dimension (default 512)
Tied between input/output for semantic learning
"""
super().__init__()
self.latent_dim = latent_dim
self.hidden_dim = hidden_dim
self.num_predict = num_predict
self.device = torch.device(device)
self.embed_dim = embed_dim
# Resolve tokenizer path for offline use (prefer local cache)
tokenizer_path = self._resolve_tokenizer_path(tokenizer_name)
        # Load tokenizer.
        # local_files_only=True: never hit the network; requires a cached or local copy.
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
trust_remote_code=True,
local_files_only=True,
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Use len(tokenizer) not vocab_size - includes special tokens (128256 vs 128000)
self.vocab_size = len(self.tokenizer)
        # TIED EMBEDDING: shared vocabulary embedding matrix.
        # The output projection reuses embedding.weight.T (tied weights), so the
        # vocabulary embedding learns semantic relationships from both directions.
self.vocab_embedding = nn.Embedding(self.vocab_size, embed_dim)
        # With the default std=0.02 init, the 512-dim bottleneck gives σ_logit≈0.32,
        # i.e. a near-uniform softmax over the 128k vocab. Scale the init so
        # σ_logit lands near 3.0, enough for meaningful discrimination.
embed_init_std = 3.0 / (embed_dim ** 0.5) # ≈0.13 for embed_dim=512
nn.init.normal_(self.vocab_embedding.weight, std=embed_init_std)
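        # Rough sanity check, assuming head outputs with roughly unit-variance
        # entries: each logit is a dot product over embed_dim terms, so
        # σ_logit ≈ sqrt(embed_dim) * embed_init_std = sqrt(512) * 0.13 ≈ 3.0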
# Shared trunk: latent → hidden representation
self.trunk = nn.Sequential(
nn.LayerNorm(latent_dim),
nn.Linear(latent_dim, hidden_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
)
        # Separate prediction heads, one per future position.
        # Head i predicts the token at offset i. Each head outputs embed_dim
        # (not vocab_size), so logits come from the tied embedding projection.
self.heads = nn.ModuleList(
[
nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(
hidden_dim, embed_dim
), # Output embed_dim, not vocab_size
)
for _ in range(num_predict)
]
)
        # Learned bias added to the tied vocab projection (see F.linear in forward)
self.output_bias = nn.Parameter(torch.zeros(self.vocab_size))
# Embedding trunk: embed_dim → hidden for direct embedding path
# This is used by forward_from_embedding() for learnable encoder training
# Must be defined in __init__ (not lazily) so weights can be loaded from checkpoint
self._embed_trunk = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, hidden_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
)
# Initialize head output layers
for head in self.heads:
nn.init.normal_(head[-1].weight, std=0.02)
nn.init.zeros_(head[-1].bias)
# Initialize _embed_trunk layers
for layer in self._embed_trunk:
if isinstance(layer, nn.Linear):
nn.init.normal_(layer.weight, std=0.02)
nn.init.zeros_(layer.bias)
@staticmethod
def _resolve_tokenizer_path(tokenizer_name: str) -> str:
"""Prefer local tokenizer snapshot to avoid network calls."""
# Explicit override via env
env_path = os.getenv("SEM_TOKENIZER_PATH")
if env_path and os.path.isdir(env_path):
return env_path
# If tokenizer_name is already a local path, use it
if os.path.isdir(tokenizer_name):
return tokenizer_name
# Try to resolve from HF cache
cache_root = os.path.expanduser("~/.cache/huggingface/hub")
model_dir = f"models--{tokenizer_name.replace('/', '--')}"
snapshot_glob = os.path.join(cache_root, model_dir, "snapshots", "*")
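        # e.g. ~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/<hash>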
snapshots = sorted(glob.glob(snapshot_glob), key=os.path.getmtime, reverse=True)
if snapshots:
return snapshots[0]
# Fallback to original name (may fail offline)
return tokenizer_name
def forward(self, z: torch.Tensor) -> torch.Tensor:
"""
Predict multiple future tokens from latent using TIED EMBEDDINGS.
Args:
z: Latent tensor (batch, latent_dim) or (latent_dim,)
Returns:
logits: (batch, num_predict, vocab_size) or (num_predict, vocab_size)
"""
squeeze = z.dim() == 1
if squeeze:
z = z.unsqueeze(0)
# Shared representation
h = self.trunk(z) # (batch, hidden_dim)
# Predict from each head using TIED EMBEDDINGS
all_logits = []
for head in self.heads:
# Head outputs embed_dim representation
embed = head(h) # (batch, embed_dim)
# Project to vocab using TIED embedding weights
# logits = embed @ vocab_embedding.weight.T + bias
logits = F.linear(embed, self.vocab_embedding.weight, self.output_bias)
all_logits.append(logits) # (batch, vocab_size)
# Stack: (batch, num_predict, vocab_size)
output = torch.stack(all_logits, dim=1)
if squeeze:
output = output.squeeze(0)
return output
def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
"""
Embed token IDs using the shared vocabulary embedding.
This enables using the same embeddings for both input and output,
which is the key to tied embeddings learning meaningful semantics.
Args:
token_ids: Token IDs (batch, seq_len) or (seq_len,)
Returns:
embeddings: Token embeddings (batch, seq_len, embed_dim) or (seq_len, embed_dim)
"""
return self.vocab_embedding(token_ids)
def forward_from_embedding(self, embed: torch.Tensor) -> torch.Tensor:
"""
Predict tokens directly from embedding space (bypasses latent→hidden trunk).
This provides a direct path for text generation when using learnable encoder
with tied embeddings. The embedding is projected to vocab using tied weights.
Args:
embed: Embedding tensor (batch, embed_dim) or (embed_dim,)
This should be the pooled embedding from LearnableEncoder.encode_to_embedding()
Returns:
logits: (batch, num_predict, vocab_size) or (num_predict, vocab_size)
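
        Example (shape check with a random pooled embedding):
            >>> e = torch.randn(1, 512)                  # embed_dim=512 default
            >>> decoder.forward_from_embedding(e).shape  # (1, num_predict, vocab_size)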
"""
squeeze = embed.dim() == 1
if squeeze:
embed = embed.unsqueeze(0)
        # Shared representation from the embedding path.
        # _embed_trunk is created in __init__ so its weights load from checkpoints.
h = self._embed_trunk(embed) # (batch, hidden_dim)
# Predict from each head using TIED EMBEDDINGS
all_logits = []
for head in self.heads:
head_embed = head(h) # (batch, embed_dim)
logits = F.linear(head_embed, self.vocab_embedding.weight, self.output_bias)
all_logits.append(logits)
output = torch.stack(all_logits, dim=1)
if squeeze:
output = output.squeeze(0)
return output
def decode_tokens(self, token_ids: torch.Tensor) -> str:
"""Decode token IDs to text."""
if token_ids.dim() > 1:
token_ids = token_ids.flatten()
return self.tokenizer.decode(token_ids.tolist(), skip_special_tokens=True)
def encode_text(self, text: str) -> torch.Tensor:
"""Encode text to token IDs."""
tokens = self.tokenizer.encode(text, return_tensors="pt")
return tokens.to(self.device)
@torch.no_grad()
def generate(
self,
z: torch.Tensor,
max_length: int = 100,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 0.9,
) -> str:
"""
Generate text using multi-token prediction.
Generates num_predict tokens at a time for faster generation.
Args:
z: Latent vector (latent_dim,) or (1, latent_dim)
max_length: Maximum tokens to generate
temperature: Sampling temperature
top_k: Top-k filtering
top_p: Nucleus sampling
Returns:
Generated text string
"""
if z.dim() == 1:
z = z.unsqueeze(0)
generated_ids = []
current_z = z
safe_temperature = max(temperature, 1e-4)
while len(generated_ids) < max_length:
# Get multi-token prediction
logits = self.forward(current_z) # (1, num_predict, vocab_size)
# Sample from each position
for i in range(self.num_predict):
if len(generated_ids) >= max_length:
break
pos_logits = logits[0, i] / safe_temperature
# Top-k filtering
if top_k > 0:
top_k_vals, _ = torch.topk(
pos_logits, min(top_k, pos_logits.size(-1))
)
pos_logits[pos_logits < top_k_vals[-1]] = float("-inf")
# Top-p filtering
if top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(pos_logits, descending=True)
cumsum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
mask = cumsum > top_p
mask[1:] = mask[:-1].clone()
mask[0] = False
sorted_logits[mask] = float("-inf")
                    # Un-sort back to vocab order: position sorted_idx[j]
                    # receives sorted_logits[j] (scatter with the sort indices)
                    pos_logits = pos_logits.scatter(0, sorted_idx, sorted_logits)
if torch.isneginf(pos_logits).all():
pos_logits = torch.zeros_like(pos_logits)
pos_logits = torch.clamp(pos_logits, -50.0, 50.0)
probs = F.softmax(pos_logits, dim=-1)
probs = torch.where(
torch.isfinite(probs), probs, torch.zeros_like(probs)
)
probs = probs + 1e-8
total = probs.sum()
if not torch.isfinite(total) or total <= 0:
probs = torch.ones_like(probs) / probs.numel()
else:
probs = probs / total
token_id = torch.multinomial(probs, 1).item()
# Stop at EOS
if token_id == self.tokenizer.eos_token_id:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True
)
generated_ids.append(token_id)
            # z is never updated from the generated tokens, so each loop iteration
            # resamples the same num_predict-token distribution. Proper
            # autoregressive continuation would feed the tokens back through the
            # encoder to refresh z -- a known limitation of this decoder.
return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
def multi_token_prediction_loss(
z_pred: torch.Tensor,
target_text: list[str],
decoder: MultiTokenDecoder,
max_length: int = 128,
loss_weight: float = 1.0,
position_decay: float = 1.0,
) -> tuple[torch.Tensor, dict]:
"""
Multi-token prediction loss following Meta's MTP approach.
Trains decoder to predict N future tokens from latent representation.
Each head predicts a different future position.
Reference:
- Meta (2024): Better & Faster LLMs via Multi-token Prediction
https://arxiv.org/abs/2404.19737
- NVIDIA Megatron: Average across depths, apply scaling factor
https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/multi_token_prediction.html
Args:
z_pred: Predicted latent (batch, latent_dim)
target_text: List of target text strings
decoder: MultiTokenDecoder module
max_length: Max tokenization length
loss_weight: Scaling factor for MTP loss (default 1.0, use 0.1 when
combining with main task loss per NVIDIA Megatron)
position_decay: Exponential decay factor for later positions (default 1.0 = no decay)
Use <1.0 (e.g., 0.9) to weight earlier positions more heavily
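            e.g. position_decay=0.9 with 4 heads yields weights [1.0, 0.9, 0.81, 0.729]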
Returns:
loss: Combined cross-entropy loss (averaged across positions, scaled by loss_weight)
metrics: Dict with per-position accuracy and loss breakdown
"""
device = z_pred.device
num_predict = decoder.num_predict
# Tokenize targets
encoded = decoder.tokenizer(
target_text,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
target_ids = encoded["input_ids"].to(device) # (batch, seq_len)
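    # Note: with add_special_tokens left at its default, the Llama-3 tokenizer
    # prepends <|begin_of_text|>, so head 0 typically targets the BOS token.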
seq_len = target_ids.size(1)
# Get multi-token predictions
logits = decoder(z_pred) # (batch, num_predict, vocab_size)
# Compute loss for each prediction position
# Per Meta: average across all depths
# Per NVIDIA: compute average of MTP losses across all depths, multiply by scaling factor
total_loss = 0.0
total_weight = 0.0
accuracies = []
per_pos_losses = []
num_positions = min(num_predict, seq_len)
for i in range(num_positions):
pos_logits = logits[:, i] # (batch, vocab_size)
pos_targets = target_ids[:, i] # (batch,)
# Cross-entropy for this position
pos_loss = F.cross_entropy(pos_logits, pos_targets)
# Optional position weighting (earlier positions = higher weight)
pos_weight = position_decay**i
total_loss = total_loss + pos_loss * pos_weight
total_weight += pos_weight
per_pos_losses.append(pos_loss.item())
# Accuracy
with torch.no_grad():
preds = pos_logits.argmax(dim=-1)
acc = (preds == pos_targets).float().mean().item()
accuracies.append(acc)
# Average loss across positions (weighted if position_decay < 1.0)
avg_loss = total_loss / total_weight if total_weight > 0 else total_loss
# Apply loss weight (scaling factor per NVIDIA Megatron)
scaled_loss = avg_loss * loss_weight
metrics = {
"mtp_loss": scaled_loss.item(),
"mtp_loss_unscaled": avg_loss.item(),
"avg_accuracy": sum(accuracies) / len(accuracies) if accuracies else 0.0,
**{f"acc_pos_{i}": acc for i, acc in enumerate(accuracies)},
**{f"loss_pos_{i}": loss for i, loss in enumerate(per_pos_losses)},
}
return scaled_loss, metrics
# Alias for backward compatibility
TextDecoder = MultiTokenDecoder
text_generation_loss = multi_token_prediction_loss
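

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the training
    # pipeline). Assumes a locally cached Llama-3 tokenizer, or
    # SEM_TOKENIZER_PATH pointing at one; everything runs on CPU.
    decoder = MultiTokenDecoder(device="cpu")
    z = torch.randn(2, 256)
    print("logits:", tuple(decoder(z).shape))  # (2, 4, vocab_size)
    loss, metrics = multi_token_prediction_loss(
        z, ["hello world", "multi-token prediction"], decoder
    )
    print("mtp_loss:", metrics["mtp_loss"], "avg_acc:", metrics["avg_accuracy"])
    print("sample:", decoder.generate(z[0], max_length=16))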