arcisvlm / model /decoder.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
10.8 kB
"""
MoE Decoder — Mixture of Experts text decoder for VL-JEPA.
Takes a predicted embedding from the JEPA predictor and autoregressively generates
text output. Each transformer block's FFN is replaced with a MoE layer containing
task-specialized experts.
Only invoked when selective decoding detects a semantic shift — NOT on every frame.
v2 additions:
- apply_lora() / clear_lora(): inject/remove per-camera LoRA adapters
  generated by the HyperNetwork via the HyperMother orchestrator.
- LoRA targets only the Q and V attention projections (NOT the MoE FFN gating).
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.transformer import TransformerBlock
from model.moe import MoELayer
class MoEDecoder(nn.Module):
    """
    MoE Decoder — generates text from a predicted embedding.

    Architecture:
        Predicted embedding → prepend as first token
        → Token embedding + positional embedding
        → N × TransformerBlock(causal, MoE FFN)
        → LayerNorm → Linear head → logits

    The predicted embedding from the JEPA predictor is used as the initial
    "thought" token, and the decoder autoregressively generates the text output.

    v2 hooks (driven externally, e.g. by HyperMother):
        - enable_feature_gates() / set_control_signal(): per-block feature gating.
        - apply_lora() / apply_lora_from_flat() / clear_lora(): LoRA adapters on
          the Q/V attention projections.

    Args:
        hidden_dim: Transformer dimension (768)
        embed_dim: Input embedding dimension from predictor (1536)
        vocab_size: BPE vocabulary size (8192); token id 0 is padding
        num_heads: Number of attention heads (12)
        num_blocks: Number of transformer blocks (6)
        num_experts: Number of experts per MoE layer (5)
        top_k: Active experts per token (2)
        max_seq_len: Maximum output sequence length (512)
        dropout: Dropout rate
    """

    def __init__(
        self,
        hidden_dim: int = 768,
        embed_dim: int = 1536,
        vocab_size: int = 8192,
        num_heads: int = 12,
        num_blocks: int = 6,
        num_experts: int = 5,
        top_k: int = 2,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        # Project predictor embedding to decoder dimension
        self.embed_proj = nn.Linear(embed_dim, hidden_dim)

        # Token and position embeddings; the +1 position is for the prepended
        # embedding token.
        self.token_embed = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len + 1, hidden_dim) * 0.02)
        self.embed_dropout = nn.Dropout(dropout)

        # Transformer blocks with MoE FFN (causal attention for autoregressive
        # generation). moe_layers keeps plain references to each block's MoE FFN
        # for load-balancing collection. The modules are passed into
        # TransformerBlock(ffn=moe) and presumably registered as submodules
        # there — verify in model/transformer.py.
        self.blocks = nn.ModuleList()
        self.moe_layers: list[MoELayer] = []
        self._feature_gates_enabled = False
        for _ in range(num_blocks):
            moe = MoELayer(hidden_dim, num_experts, top_k, dropout=dropout)
            self.moe_layers.append(moe)
            block = TransformerBlock(hidden_dim, num_heads, dropout, mode="causal", ffn=moe)
            self.blocks.append(block)

        # Control signal for v2 feature gating (set externally by HyperMother
        # through set_control_signal()).
        self._control_signal: Optional[torch.Tensor] = None

        self.norm = nn.LayerNorm(hidden_dim)
        # Output head: hidden_dim → vocab_size
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(
        self,
        pred_embedding: torch.Tensor,
        target_ids: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Run the decoder over the embedding token (plus targets when teacher forcing).

        Args:
            pred_embedding: [batch, embed_dim] — predicted embedding from JEPA predictor
            target_ids: [batch, seq_len] — target token IDs for training (teacher
                forcing). Id 0 is treated as padding and ignored by the loss.
                seq_len must not exceed max_seq_len.

        Returns:
            logits: [batch, seq_len+1, vocab_size] — token predictions
                ([batch, 1, vocab_size] when target_ids is None)
            loss: scalar cross-entropy loss if target_ids provided, else None

        Raises:
            ValueError: if target_ids is longer than the positional table allows.
        """
        # [B, embed_dim] → [B, 1, hidden_dim]; the predicted embedding is position 0.
        embed_token = self.embed_proj(pred_embedding).unsqueeze(1)

        if target_ids is not None:
            # One positional slot is reserved for the embedding token.
            max_text_len = self.pos_embed.shape[1] - 1
            if target_ids.shape[1] > max_text_len:
                raise ValueError(
                    f"target_ids has {target_ids.shape[1]} tokens but the decoder "
                    f"supports at most {max_text_len} (max_seq_len)"
                )
            # Training: teacher forcing
            token_embeds = self.token_embed(target_ids)  # [B, T, hidden_dim]
            x = torch.cat([embed_token, token_embeds], dim=1)  # [B, 1+T, hidden_dim]
        else:
            # Inference: just the embedding token
            x = embed_token  # [B, 1, hidden_dim]

        T = x.shape[1]
        x = x + self.pos_embed[:, :T, :]
        x = self.embed_dropout(x)

        # Pass through MoE transformer blocks
        for block in self.blocks:
            x = block(x, control_signal=self._control_signal)

        x = self.norm(x)
        logits = self.lm_head(x)  # [B, T, vocab_size]

        loss = None
        if target_ids is not None:
            # Position i (0 = embedding token) predicts target token i, so dropping
            # the final position aligns logits [B, T, V] with target_ids [B, T].
            shift_logits = logits[:, :-1, :].contiguous()
            shift_targets = target_ids.contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_targets.view(-1),
                ignore_index=0,  # ignore padding
            )
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        pred_embedding: torch.Tensor,
        max_new_tokens: int = 256,
        temperature: float = 0.0,
        eos_id: int = 2,
        pad_id: int = 0,
    ) -> torch.Tensor:
        """
        Autoregressive text generation.

        Each step re-runs the full prefix through all blocks (no KV cache).
        Generation stops once every sequence in the batch has emitted eos_id.
        A sequence that has already finished gets pad_id for each later step.

        Args:
            pred_embedding: [batch, embed_dim] — predicted embedding
            max_new_tokens: Maximum tokens to generate. Capped at the size of the
                positional table (max_seq_len + 1) so positions never run out.
            temperature: Sampling temperature; <= 0 selects greedy argmax.
            eos_id: End of sequence token ID
            pad_id: Token emitted by a sequence after its EOS (default 0, the
                embedding padding index).

        Returns:
            [batch, generated_len] — generated token IDs
        """
        B = pred_embedding.shape[0]
        device = pred_embedding.device
        # Step k reads positional embeddings 0..k, so the table size bounds the steps.
        max_steps = min(max_new_tokens, self.pos_embed.shape[1])

        # Start with just the embedding token
        x = self.embed_proj(pred_embedding).unsqueeze(1)  # [B, 1, hidden_dim]
        generated = torch.zeros(B, 0, dtype=torch.long, device=device)
        finished = torch.zeros(B, 1, dtype=torch.bool, device=device)

        for _ in range(max_steps):
            T = x.shape[1]
            h = x + self.pos_embed[:, :T, :]
            for block in self.blocks:
                h = block(h, control_signal=self._control_signal)
            h = self.norm(h)

            # Logits for the last position only
            next_logits = self.lm_head(h[:, -1, :])  # [B, vocab]
            if temperature <= 0:
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(next_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # [B, 1]

            # Sequences that already emitted EOS only produce padding from now on.
            next_token = next_token.masked_fill(finished, pad_id)
            generated = torch.cat([generated, next_token], dim=1)

            # Stop only when every sequence has finished, not merely when all
            # sequences happen to emit EOS at the same step.
            finished = finished | (next_token == eos_id)
            if finished.all():
                break

            # Append token embedding for the next step
            x = torch.cat([x, self.token_embed(next_token)], dim=1)

        return generated

    def get_all_load_balancing_data(self) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Collect load-balancing data from every MoE layer that reports some (skips None)."""
        data = []
        for moe in self.moe_layers:
            lb_data = moe.get_load_balancing_data()
            if lb_data is not None:
                data.append(lb_data)
        return data

    # -- Feature control (v2) -----------------------------------------------

    def enable_feature_gates(self, control_dim: int = 256) -> None:
        """
        Add a FeatureControlGate to every transformer block.

        Called once when v2 features are enabled. Safe to call multiple
        times: a gate is only added to blocks that do not have one yet. New
        gates are created on the same device and dtype as the output head's
        weight, so this also works after the model has been moved or cast.

        Args:
            control_dim: Dimension of the control signal the gates consume.
        """
        from model.feature_control import FeatureControlGate

        ref = self.lm_head.weight
        for block in self.blocks:
            if block.feature_gate is None:
                gate = FeatureControlGate(
                    dim=self.hidden_dim,
                    control_dim=control_dim,
                )
                block.feature_gate = gate.to(device=ref.device, dtype=ref.dtype)
        self._feature_gates_enabled = True

    def set_control_signal(self, control_signal: Optional[torch.Tensor]) -> None:
        """
        Set the control signal passed to every block in forward/generate.

        Called by HyperMother before running decoder forward/generate.

        Args:
            control_signal: [B, control_dim] from ConditionEncoder, or None to clear
        """
        self._control_signal = control_signal

    # -- LoRA injection (v2) -------------------------------------------------

    def apply_lora(
        self,
        lora_layers: list[dict[str, "nn.Module"]],
    ) -> None:
        """
        Inject LoRA layers into every transformer block's attention module.

        Args:
            lora_layers: List of dicts (one per block), mapping target name
                ("q", "v") to LoRALayer instances. A missing key passes None
                for that projection. Length must match self.blocks.

        Raises:
            ValueError: if the number of dicts differs from the number of blocks.
        """
        # Explicit check (not assert) so it survives `python -O`.
        if len(lora_layers) != len(self.blocks):
            raise ValueError(
                f"Expected {len(self.blocks)} LoRA layer dicts, got {len(lora_layers)}"
            )
        for block, block_loras in zip(self.blocks, lora_layers):
            block.attn.set_lora(
                lora_q=block_loras.get("q"),
                lora_v=block_loras.get("v"),
            )

    def clear_lora(self) -> None:
        """Remove all LoRA layers from all transformer blocks."""
        for block in self.blocks:
            block.attn.clear_lora()

    @property
    def has_lora(self) -> bool:
        """Whether any block currently has LoRA active."""
        return any(block.attn.has_lora for block in self.blocks)

    def apply_lora_from_flat(
        self,
        flat_params: torch.Tensor,
        config: "LoRAConfig",
    ) -> None:
        """
        Create and inject LoRA layers from a flat parameter vector.

        This is the primary interface used by HyperMother: the HyperNetwork
        outputs a flat tensor, and this method reshapes it (via LoRAInjector)
        into per-block LoRA layers and injects them.

        Args:
            flat_params: Flat parameter tensor from HyperNetwork
            config: LoRA configuration
        """
        from model.lora import LoRAInjector

        injector = LoRAInjector(config, len(self.blocks), self.hidden_dim)
        lora_layers = injector.create_lora_layers(
            flat_params, device=next(self.parameters()).device
        )
        self.apply_lora(lora_layers)