""" MoE Decoder — Mixture of Experts text decoder for VL-JEPA. Takes a predicted embedding from the JEPA predictor and autoregressively generates text output. Each transformer block's FFN is replaced with a MoE layer containing task-specialized experts. Only invoked when selective decoding detects a semantic shift — NOT on every frame. v2 additions: - apply_lora() / clear_lora(): Inject/remove per-camera LoRA adapters generated by the HyperNetwork via HyperMother orchestrator. - LoRA targets Q, V attention projections only (NOT MoE FFN gating). """ from __future__ import annotations from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from model.transformer import TransformerBlock from model.moe import MoELayer class MoEDecoder(nn.Module): """ MoE Decoder — generates text from predicted embedding. Architecture: Predicted embedding → prepend as first token → Token embedding + positional embedding → N × TransformerBlock(causal, MoE FFN) → LayerNorm → Linear head → logits The predicted embedding from JEPA predictor is used as the initial "thought" token, and the decoder autoregressively generates the text output. Args: hidden_dim: Transformer dimension (768) embed_dim: Input embedding dimension from predictor (1536) vocab_size: BPE vocabulary size (8192) num_heads: Number of attention heads (12) num_blocks: Number of transformer blocks (6) num_experts: Number of experts per MoE layer (5) top_k: Active experts per token (2) max_seq_len: Maximum output sequence length (512) dropout: Dropout rate """ def __init__( self, hidden_dim: int = 768, embed_dim: int = 1536, vocab_size: int = 8192, num_heads: int = 12, num_blocks: int = 6, num_experts: int = 5, top_k: int = 2, max_seq_len: int = 512, dropout: float = 0.1, ): super().__init__() self.hidden_dim = hidden_dim self.vocab_size = vocab_size self.max_seq_len = max_seq_len # Project predictor embedding to decoder dimension self.embed_proj = nn.Linear(embed_dim, hidden_dim) # Token and position embeddings self.token_embed = nn.Embedding(vocab_size, hidden_dim, padding_idx=0) self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len + 1, hidden_dim) * 0.02) # +1 for embedding token self.embed_dropout = nn.Dropout(dropout) # Transformer blocks with MoE FFN (causal attention for autoregressive generation) self.blocks = nn.ModuleList() self.moe_layers: list[MoELayer] = [] self._feature_gates_enabled = False for _ in range(num_blocks): moe = MoELayer(hidden_dim, num_experts, top_k, dropout=dropout) self.moe_layers.append(moe) block = TransformerBlock(hidden_dim, num_heads, dropout, mode="causal", ffn=moe) self.blocks.append(block) # Control signal for v2 feature gating (set externally by HyperMother) self._control_signal: Optional[torch.Tensor] = None self.norm = nn.LayerNorm(hidden_dim) # Output head: hidden_dim → vocab_size self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False) def forward( self, pred_embedding: torch.Tensor, target_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Args: pred_embedding: [batch, embed_dim] — predicted embedding from JEPA predictor target_ids: [batch, seq_len] — target token IDs for training (teacher forcing) Returns: logits: [batch, seq_len+1, vocab_size] — token predictions loss: scalar loss if target_ids provided, else None """ B = pred_embedding.shape[0] # Project embedding to decoder dimension: [B, embed_dim] → [B, 1, hidden_dim] embed_token = self.embed_proj(pred_embedding).unsqueeze(1) if target_ids is not None: # Training: teacher forcing token_embeds = self.token_embed(target_ids) # [B, T, hidden_dim] x = torch.cat([embed_token, token_embeds], dim=1) # [B, 1+T, hidden_dim] else: # Inference: start with just the embedding token x = embed_token # [B, 1, hidden_dim] T = x.shape[1] x = x + self.pos_embed[:, :T, :] x = self.embed_dropout(x) # Pass through MoE transformer blocks for block in self.blocks: x = block(x, control_signal=self._control_signal) x = self.norm(x) logits = self.lm_head(x) # [B, T, vocab_size] loss = None if target_ids is not None: # Shift logits and targets for next-token prediction # Logits from positions [0, ..., T-1] predict tokens at [1, ..., T] # Position 0 (embedding token) predicts first text token shift_logits = logits[:, :-1, :].contiguous() # [B, T, vocab] shift_targets = target_ids.contiguous() # [B, T] loss = F.cross_entropy( shift_logits.view(-1, self.vocab_size), shift_targets.view(-1), ignore_index=0, # ignore padding ) return logits, loss @torch.no_grad() def generate( self, pred_embedding: torch.Tensor, max_new_tokens: int = 256, temperature: float = 0.0, eos_id: int = 2, ) -> torch.Tensor: """ Autoregressive text generation. Args: pred_embedding: [batch, embed_dim] — predicted embedding max_new_tokens: Maximum tokens to generate temperature: Sampling temperature eos_id: End of sequence token ID Returns: [batch, generated_len] — generated token IDs """ B = pred_embedding.shape[0] device = pred_embedding.device # Start with just the embedding token embed_token = self.embed_proj(pred_embedding).unsqueeze(1) # [B, 1, hidden_dim] generated = torch.zeros(B, 0, dtype=torch.long, device=device) x = embed_token for _ in range(max_new_tokens): T = x.shape[1] pos_x = x + self.pos_embed[:, :T, :] h = pos_x for block in self.blocks: h = block(h, control_signal=self._control_signal) h = self.norm(h) # Get logits for last position only next_logits = self.lm_head(h[:, -1, :]) # [B, vocab] if temperature <= 0: next_token = next_logits.argmax(dim=-1, keepdim=True) else: probs = F.softmax(next_logits / temperature, dim=-1) next_token = torch.multinomial(probs, num_samples=1) # [B, 1] generated = torch.cat([generated, next_token], dim=1) # Check for EOS if (next_token == eos_id).all(): break # Append token embedding for next step next_embed = self.token_embed(next_token) # [B, 1, hidden_dim] x = torch.cat([x, next_embed], dim=1) return generated def get_all_load_balancing_data(self) -> list[tuple[torch.Tensor, torch.Tensor]]: """Collect load balancing data from all MoE layers.""" data = [] for moe in self.moe_layers: lb_data = moe.get_load_balancing_data() if lb_data is not None: data.append(lb_data) return data # -- Feature control (v2) ----------------------------------------------- def enable_feature_gates(self, control_dim: int = 256) -> None: """ Add FeatureControlGate to all transformer blocks. Called once during model initialization when v2 features are enabled. Safe to call multiple times — gates are only added if not already present. """ from model.feature_control import FeatureControlGate for block in self.blocks: if block.feature_gate is None: block.feature_gate = FeatureControlGate( dim=self.hidden_dim, control_dim=control_dim, ) self._feature_gates_enabled = True def set_control_signal(self, control_signal: Optional[torch.Tensor]) -> None: """ Set the control signal for feature gating. Called by HyperMother before running decoder forward/generate. Args: control_signal: [B, control_dim] from ConditionEncoder, or None to clear """ self._control_signal = control_signal # -- LoRA injection (v2) ------------------------------------------------- def apply_lora( self, lora_layers: list[dict[str, "nn.Module"]], ) -> None: """ Inject LoRA layers into all transformer blocks' attention modules. Args: lora_layers: List of dicts (one per block), mapping target name ("q", "v") to LoRALayer instances. Length must match self.blocks. """ assert len(lora_layers) == len(self.blocks), ( f"Expected {len(self.blocks)} LoRA layer dicts, got {len(lora_layers)}" ) for block, block_loras in zip(self.blocks, lora_layers): block.attn.set_lora( lora_q=block_loras.get("q"), lora_v=block_loras.get("v"), ) def clear_lora(self) -> None: """Remove all LoRA layers from all transformer blocks.""" for block in self.blocks: block.attn.clear_lora() @property def has_lora(self) -> bool: """Whether any block currently has LoRA active.""" return any(block.attn.has_lora for block in self.blocks) def apply_lora_from_flat( self, flat_params: torch.Tensor, config: "LoRAConfig", ) -> None: """ Convenience: create and inject LoRA layers from a flat parameter vector. This is the primary interface used by HyperMother — the HyperNetwork outputs a flat tensor, and this method handles reshaping and injection. Args: flat_params: Flat parameter tensor from HyperNetwork config: LoRA configuration """ from model.lora import LoRAInjector injector = LoRAInjector(config, len(self.blocks), self.hidden_dim) lora_layers = injector.create_lora_layers( flat_params, device=next(self.parameters()).device ) self.apply_lora(lora_layers)