| """ |
| Text Decoder Module for SEM V6 - Multi-Token Prediction. |
| |
| Converts latent representations to text using Llama-3 tokenizer. |
| Uses multi-token prediction (MTP) for faster training and better representations. |
| |
| Architecture: |
| latent (256) → N prediction heads → vocab_logits[N, vocab_size] → tokens → text |
| |
| Reference: |
| - Multi-Token Prediction (Meta 2024): https://arxiv.org/abs/2404.19737 |
| - Llama 3: https://llama.meta.com/ |
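| |
| Example (illustrative; shapes assume the default configuration): |
| >>> dec = MultiTokenDecoder(latent_dim=256, num_predict=4, device="cpu") |
| >>> logits = dec(torch.randn(256)) # (4, vocab_size) |
| >>> text = dec.decode_tokens(logits.argmax(dim=-1)) # greedy decode of the 4 tokens |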
| """ |
|
|
| import os |
| import glob |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer |
|
|
|
|
| class MultiTokenDecoder(nn.Module): |
| """ |
| Multi-token prediction decoder for SEM V6 with TIED EMBEDDINGS. |
| |
| Predicts N future tokens simultaneously from a latent vector, |
| enabling faster training and richer representations. |
| |
| Architecture (per prediction head): |
| latent (latent_dim) |
| → shared_trunk (latent_dim → hidden_dim) |
| → head_i (hidden_dim → embed_dim) |
| → vocab_embedding.T (embed_dim → vocab_size) # TIED WEIGHTS |
| |
| The key innovation is TIED EMBEDDINGS: the vocabulary embedding matrix |
| is shared between input and output projections. This: |
| 1. Cuts per-head output parameters ~250x: heads project hidden→embed (512) instead of hidden→vocab (~128K for Llama-3); the embed→vocab step reuses the embedding matrix |
| 2. Provides semantic guidance - vocabulary learns meaningful relationships |
| 3. Matches standard LLM practice (GPT-2, Llama, etc.) |
| |
| Each head predicts token at position i (0 = next token, 1 = token after, etc.) |
| |
| Example: |
| >>> decoder = MultiTokenDecoder(latent_dim=256, num_predict=4) |
| >>> z = torch.randn(1, 256) |
| >>> logits = decoder(z) # (1, 4, vocab_size) |
| >>> text = decoder.generate(z, max_length=100) |
| """ |
|
|
| def __init__( |
| self, |
| latent_dim: int = 256, |
| hidden_dim: int = 512, |
| num_predict: int = 4, |
| tokenizer_name: str = "meta-llama/Meta-Llama-3-8B", |
| device: str = "cuda", |
| embed_dim: int = 512, |
| ) -> None: |
| """ |
| Initialize multi-token decoder with tied embeddings. |
| |
| Args: |
| latent_dim: Dimension of input latent vectors |
| hidden_dim: Hidden layer dimension |
| num_predict: Number of future tokens to predict (default 4) |
| tokenizer_name: HuggingFace tokenizer name or local path (default: Meta-Llama-3-8B) |
| device: Computation device |
| embed_dim: Vocabulary embedding dimension (default 512) |
| Tied between input/output for semantic learning |
| """ |
| super().__init__() |
| self.latent_dim = latent_dim |
| self.hidden_dim = hidden_dim |
| self.num_predict = num_predict |
| self.device = torch.device(device) |
| self.embed_dim = embed_dim |
|
|
| |
| # Prefer a local tokenizer snapshot so initialization works offline. |
| tokenizer_path = self._resolve_tokenizer_path(tokenizer_name) |
| self.tokenizer = AutoTokenizer.from_pretrained( |
| tokenizer_path, |
| trust_remote_code=True, |
| local_files_only=True, |
| ) |
| if self.tokenizer.pad_token is None: |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
| |
| self.vocab_size = len(self.tokenizer) |
| |
| # Tied vocabulary embedding: supplies input token embeddings and, via |
| # F.linear in forward(), the output projection to vocabulary logits. |
| self.vocab_embedding = nn.Embedding(self.vocab_size, embed_dim) |
| # Initialization std scaled by 1/sqrt(embed_dim). |
| embed_init_std = 3.0 / (embed_dim ** 0.5) |
| nn.init.normal_(self.vocab_embedding.weight, std=embed_init_std) |
| |
| # Shared trunk: LayerNorm + two-layer MLP mapping latent_dim → hidden_dim. |
| self.trunk = nn.Sequential( |
| nn.LayerNorm(latent_dim), |
| nn.Linear(latent_dim, hidden_dim), |
| nn.GELU(), |
| nn.Dropout(0.1), |
| nn.Linear(hidden_dim, hidden_dim), |
| nn.GELU(), |
| ) |
| |
| # One head per predicted position; each maps hidden_dim → embed_dim, |
| # and the tied embedding matrix maps embed_dim → vocab_size. |
| self.heads = nn.ModuleList( |
| [ |
| nn.Sequential( |
| nn.Linear(hidden_dim, hidden_dim), |
| nn.GELU(), |
| nn.Dropout(0.1), |
| nn.Linear(hidden_dim, embed_dim), |
| ) |
| for _ in range(num_predict) |
| ] |
| ) |
| |
| # Bias for the tied output projection (its weight is vocab_embedding.weight). |
| self.output_bias = nn.Parameter(torch.zeros(self.vocab_size)) |
| |
| # Separate trunk for forward_from_embedding(): maps embed_dim → hidden_dim. |
| self._embed_trunk = nn.Sequential( |
| nn.LayerNorm(embed_dim), |
| nn.Linear(embed_dim, hidden_dim), |
| nn.GELU(), |
| nn.Dropout(0.1), |
| nn.Linear(hidden_dim, hidden_dim), |
| nn.GELU(), |
| ) |
| |
| # Small-weight (std=0.02) initialization for head outputs and the embedding trunk. |
| for head in self.heads: |
| nn.init.normal_(head[-1].weight, std=0.02) |
| nn.init.zeros_(head[-1].bias) |
|
|
| |
| for layer in self._embed_trunk: |
| if isinstance(layer, nn.Linear): |
| nn.init.normal_(layer.weight, std=0.02) |
| nn.init.zeros_(layer.bias) |
|
|
| @staticmethod |
| def _resolve_tokenizer_path(tokenizer_name: str) -> str: |
| """Prefer local tokenizer snapshot to avoid network calls.""" |
| |
| # 1. Explicit override via environment variable. |
| env_path = os.getenv("SEM_TOKENIZER_PATH") |
| if env_path and os.path.isdir(env_path): |
| return env_path |
| |
| # 2. The given name is already a local directory. |
| if os.path.isdir(tokenizer_name): |
| return tokenizer_name |
| |
| # 3. Newest snapshot in the local HuggingFace cache. |
| cache_root = os.path.expanduser("~/.cache/huggingface/hub") |
| model_dir = f"models--{tokenizer_name.replace('/', '--')}" |
| snapshot_glob = os.path.join(cache_root, model_dir, "snapshots", "*") |
| snapshots = sorted(glob.glob(snapshot_glob), key=os.path.getmtime, reverse=True) |
| if snapshots: |
| return snapshots[0] |
| |
| # 4. Fall back to the given name. |
| return tokenizer_name |
|
|
| def forward(self, z: torch.Tensor) -> torch.Tensor: |
| """ |
| Predict multiple future tokens from latent using TIED EMBEDDINGS. |
| |
| Args: |
| z: Latent tensor (batch, latent_dim) or (latent_dim,) |
| |
| Returns: |
| logits: (batch, num_predict, vocab_size) or (num_predict, vocab_size) |
| """ |
| squeeze = z.dim() == 1 |
| if squeeze: |
| z = z.unsqueeze(0) |
|
|
| |
| h = self.trunk(z) |
|
|
| |
| all_logits = [] |
| for head in self.heads: |
| |
| embed = head(h) |
| |
| # Tied output projection: logits = embed @ vocab_embedding.weight.T + bias. |
| logits = F.linear(embed, self.vocab_embedding.weight, self.output_bias) |
| all_logits.append(logits) |
|
|
| |
| output = torch.stack(all_logits, dim=1) |
|
|
| if squeeze: |
| output = output.squeeze(0) |
|
|
| return output |
|
|
| def embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: |
| """ |
| Embed token IDs using the shared vocabulary embedding. |
| |
| This enables using the same embeddings for both input and output, |
| which is the key to tied embeddings learning meaningful semantics. |
| |
| Args: |
| token_ids: Token IDs (batch, seq_len) or (seq_len,) |
| |
| Returns: |
| embeddings: Token embeddings (batch, seq_len, embed_dim) or (seq_len, embed_dim) |
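| |
| Example (illustrative; assumes `decoder` and the ids share a device): |
| >>> ids = decoder.encode_text("hello world") # (1, seq_len) |
| >>> emb = decoder.embed_tokens(ids) # (1, seq_len, embed_dim) |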
| """ |
| return self.vocab_embedding(token_ids) |
|
|
| def forward_from_embedding(self, embed: torch.Tensor) -> torch.Tensor: |
| """ |
| Predict tokens directly from embedding space (bypasses latent→hidden trunk). |
| |
| This provides a direct path for text generation when using learnable encoder |
| with tied embeddings. The embedding is projected to vocab using tied weights. |
| |
| Args: |
| embed: Embedding tensor (batch, embed_dim) or (embed_dim,) |
| This should be the pooled embedding from LearnableEncoder.encode_to_embedding() |
| |
| Returns: |
| logits: (batch, num_predict, vocab_size) or (num_predict, vocab_size) |
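| |
| Example (illustrative; mean-pooled token embeddings stand in for the encoder output): |
| >>> ids = decoder.encode_text("hello world")[0] # (seq_len,) |
| >>> pooled = decoder.embed_tokens(ids).mean(dim=0) # (embed_dim,) |
| >>> logits = decoder.forward_from_embedding(pooled) # (num_predict, vocab_size) |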
| """ |
| squeeze = embed.dim() == 1 |
| if squeeze: |
| embed = embed.unsqueeze(0) |
|
|
| |
| |
| h = self._embed_trunk(embed) |
|
|
| |
| all_logits = [] |
| for head in self.heads: |
| head_embed = head(h) |
| logits = F.linear(head_embed, self.vocab_embedding.weight, self.output_bias) |
| all_logits.append(logits) |
|
|
| output = torch.stack(all_logits, dim=1) |
|
|
| if squeeze: |
| output = output.squeeze(0) |
|
|
| return output |
|
|
| def decode_tokens(self, token_ids: torch.Tensor) -> str: |
| """Decode token IDs to text.""" |
| if token_ids.dim() > 1: |
| token_ids = token_ids.flatten() |
| return self.tokenizer.decode(token_ids.tolist(), skip_special_tokens=True) |
|
|
| def encode_text(self, text: str) -> torch.Tensor: |
| """Encode text to token IDs.""" |
| tokens = self.tokenizer.encode(text, return_tensors="pt") |
| return tokens.to(self.device) |
|
|
| @torch.no_grad() |
| def generate( |
| self, |
| z: torch.Tensor, |
| max_length: int = 100, |
| temperature: float = 0.8, |
| top_k: int = 50, |
| top_p: float = 0.9, |
| ) -> str: |
| """ |
| Generate text using multi-token prediction. |
| |
| Generates num_predict tokens at a time for faster generation. |
| |
| Args: |
| z: Latent vector (latent_dim,) or (1, latent_dim) |
| max_length: Maximum tokens to generate |
| temperature: Sampling temperature |
| top_k: Top-k filtering |
| top_p: Nucleus sampling |
| |
| Returns: |
| Generated text string |
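| |
| Example (illustrative): |
| >>> text = decoder.generate(torch.randn(256), max_length=32, temperature=0.7) |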
| """ |
| if z.dim() == 1: |
| z = z.unsqueeze(0) |
|
|
| generated_ids = [] |
| current_z = z |
| safe_temperature = max(temperature, 1e-4) |
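| # NOTE: current_z is not updated inside the loop, so each outer iteration |
| # re-samples from the same num_predict per-position distributions. |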
|
|
| while len(generated_ids) < max_length: |
| |
| logits = self.forward(current_z) |
|
|
| |
| for i in range(self.num_predict): |
| if len(generated_ids) >= max_length: |
| break |
|
|
| pos_logits = logits[0, i] / safe_temperature |
|
|
| |
| if top_k > 0: |
| top_k_vals, _ = torch.topk( |
| pos_logits, min(top_k, pos_logits.size(-1)) |
| ) |
| pos_logits[pos_logits < top_k_vals[-1]] = float("-inf") |
|
|
| |
| if top_p < 1.0: |
| sorted_logits, sorted_idx = torch.sort(pos_logits, descending=True) |
| cumsum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
| mask = cumsum > top_p |
| mask[1:] = mask[:-1].clone() |
| mask[0] = False |
| sorted_logits[mask] = float("-inf") |
| # Map the filtered logits back to their original vocabulary positions. |
| pos_logits = sorted_logits.scatter( |
| 0, sorted_idx, sorted_logits |
| ) |
|
|
| if torch.isneginf(pos_logits).all(): |
| pos_logits = torch.zeros_like(pos_logits) |
|
|
| pos_logits = torch.clamp(pos_logits, -50.0, 50.0) |
| probs = F.softmax(pos_logits, dim=-1) |
| probs = torch.where( |
| torch.isfinite(probs), probs, torch.zeros_like(probs) |
| ) |
| probs = probs + 1e-8 |
| total = probs.sum() |
| if not torch.isfinite(total) or total <= 0: |
| probs = torch.ones_like(probs) / probs.numel() |
| else: |
| probs = probs / total |
| token_id = torch.multinomial(probs, 1).item() |
|
|
| |
| if token_id == self.tokenizer.eos_token_id: |
| return self.tokenizer.decode( |
| generated_ids, skip_special_tokens=True |
| ) |
|
|
| generated_ids.append(token_id) |
|
|
| |
| |
| |
|
|
| return self.tokenizer.decode(generated_ids, skip_special_tokens=True) |
|
|
|
|
| def multi_token_prediction_loss( |
| z_pred: torch.Tensor, |
| target_text: list[str], |
| decoder: MultiTokenDecoder, |
| max_length: int = 128, |
| loss_weight: float = 1.0, |
| position_decay: float = 1.0, |
| ) -> tuple[torch.Tensor, dict]: |
| """ |
| Multi-token prediction loss following Meta's MTP approach. |
| |
| Trains decoder to predict N future tokens from latent representation. |
| Each head predicts a different future position. |
| |
| Reference: |
| - Meta (2024): Better & Faster LLMs via Multi-token Prediction |
| https://arxiv.org/abs/2404.19737 |
| - NVIDIA Megatron: Average across depths, apply scaling factor |
| https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/multi_token_prediction.html |
| |
| Args: |
| z_pred: Predicted latent (batch, latent_dim) |
| target_text: List of target text strings |
| decoder: MultiTokenDecoder module |
| max_length: Max tokenization length |
| loss_weight: Scaling factor for MTP loss (default 1.0, use 0.1 when |
| combining with main task loss per NVIDIA Megatron) |
| position_decay: Exponential decay factor for later positions (default 1.0 = no decay) |
| Use <1.0 (e.g., 0.9) to weight earlier positions more heavily |
| |
| Returns: |
| loss: Combined cross-entropy loss (averaged across positions, scaled by loss_weight) |
| metrics: Dict with per-position accuracy and loss breakdown |
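| |
| Example (illustrative; assumes a constructed decoder): |
| >>> z = torch.randn(8, 256, requires_grad=True) |
| >>> loss, metrics = multi_token_prediction_loss(z, ["target text"] * 8, decoder) |
| >>> loss.backward() |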
| """ |
| batch_size = z_pred.size(0) |
| device = z_pred.device |
| num_predict = decoder.num_predict |
|
|
| |
| # Tokenize targets (padded to the longest sequence in the batch). |
| encoded = decoder.tokenizer( |
| target_text, |
| padding=True, |
| truncation=True, |
| max_length=max_length, |
| return_tensors="pt", |
| ) |
| target_ids = encoded["input_ids"].to(device) |
| seq_len = target_ids.size(1) |
|
|
| |
| # One forward pass produces logits for all num_predict positions. |
| logits = decoder(z_pred) |
|
|
| |
| |
| |
| total_loss = 0.0 |
| total_weight = 0.0 |
| accuracies = [] |
| per_pos_losses = [] |
|
|
| num_positions = min(num_predict, seq_len) |
| # Per-position cross-entropy, weighted by position_decay**i. |
| for i in range(num_positions): |
| pos_logits = logits[:, i] |
| pos_targets = target_ids[:, i] |
|
|
| |
| pos_loss = F.cross_entropy(pos_logits, pos_targets) |
|
|
| |
| pos_weight = position_decay**i |
| total_loss = total_loss + pos_loss * pos_weight |
| total_weight += pos_weight |
|
|
| per_pos_losses.append(pos_loss.item()) |
|
|
| |
| with torch.no_grad(): |
| preds = pos_logits.argmax(dim=-1) |
| acc = (preds == pos_targets).float().mean().item() |
| accuracies.append(acc) |
|
|
| |
| avg_loss = total_loss / total_weight if total_weight > 0 else total_loss |
|
|
| |
| scaled_loss = avg_loss * loss_weight |
|
|
| metrics = { |
| "mtp_loss": scaled_loss.item(), |
| "mtp_loss_unscaled": avg_loss.item(), |
| "avg_accuracy": sum(accuracies) / len(accuracies) if accuracies else 0.0, |
| **{f"acc_pos_{i}": acc for i, acc in enumerate(accuracies)}, |
| **{f"loss_pos_{i}": loss for i, loss in enumerate(per_pos_losses)}, |
| } |
|
|
| return scaled_loss, metrics |
|
|
|
|
| |
| # Convenience aliases. |
| TextDecoder = MultiTokenDecoder |
| text_generation_loss = multi_token_prediction_loss |
|
|
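| |
| |
| # Illustrative smoke test, not part of the SEM training pipeline. It assumes a |
| # local Llama-3 tokenizer snapshot is available (e.g. via SEM_TOKENIZER_PATH). |
| if __name__ == "__main__": |
|     decoder = MultiTokenDecoder(device="cpu") |
|     z = torch.randn(2, 256) |
|     logits = decoder(z) |
|     print("logits:", tuple(logits.shape)) # (batch, num_predict, vocab_size) |
|     loss, metrics = multi_token_prediction_loss(z, ["hello world", "multi token"], decoder) |
|     print("mtp_loss:", round(metrics["mtp_loss"], 3), "avg_acc:", round(metrics["avg_accuracy"], 3)) |
|     print("sample:", decoder.generate(z[0], max_length=16)) |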