"""Arkadiko V4 decoder. Dense transformer decoder with optional LASER2 cross-attention. Three variants (set via config.cross_attention_mode): - "per_layer": cross-attention at every decoder layer (Variant A) - "none": pure decoder, no LASER2 (Variant B) - "input_only": LASER2 added to token embeddings at input only (Variant C) """ import torch import torch.nn as nn import torch.nn.functional as F from arkadiko.embedding.rope import precompute_rotary_embeddings from arkadiko.embedding.mlp import SwiGLU from arkadiko.llm.attention import CausalSelfAttention, CrossAttention from arkadiko.llm.config import V4Config def norm(x): return F.rms_norm(x, (x.size(-1),)) class V4Block(nn.Module): """Decoder block: self-attn → (optional cross-attn) → FFN.""" def __init__(self, config: V4Config): super().__init__() self.config = config self.use_cross_attn = (config.cross_attention_mode == "per_layer") self.self_attn = CausalSelfAttention(config) if self.use_cross_attn: self.cross_attn = CrossAttention(config) self.mlp = SwiGLU(config.n_embd, config.ffn_mult, hidden=config.ffn_hidden) def forward(self, x, cos, sin, encoder_hidden=None, encoder_pad_mask=None): # Self-attention (pre-norm) x = x + self.self_attn(norm(x), cos, sin) # Cross-attention (if enabled and encoder output provided) if self.use_cross_attn and encoder_hidden is not None: x = x + self.cross_attn(norm(x), encoder_hidden, encoder_pad_mask) # FFN x = x + self.mlp(norm(x)) return x class V4Decoder(nn.Module): """Arkadiko V4 decoder.""" def __init__(self, config: V4Config): super().__init__() self.config = config # Token embedding self.wte = nn.Embedding(config.vocab_size, config.n_embd, padding_idx=config.pad_token_id) # Input projection for LASER2 (for input_only mode) if config.cross_attention_mode == "input_only": self.laser_input_proj = nn.Linear(config.laser_dim, config.n_embd, bias=False) # Decoder blocks self.blocks = nn.ModuleList([V4Block(config) for _ in range(config.n_layer)]) self.final_norm_gamma = nn.Parameter(torch.ones(config.n_embd)) # LM head (tied to wte if config says so) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # RoPE buffers cos, sin = precompute_rotary_embeddings( config.max_seq_len, config.head_dim, config.rope_theta ) self.register_buffer("cos", cos) self.register_buffer("sin", sin) self.init_weights() if config.tied_embeddings: self.lm_head.weight = self.wte.weight def init_weights(self): std = self.config.init_std n_embd = self.config.n_embd s = 3**0.5 * n_embd**-0.5 nn.init.normal_(self.wte.weight, mean=0.0, std=std) for block in self.blocks: nn.init.uniform_(block.self_attn.c_q.weight, -s, s) nn.init.uniform_(block.self_attn.c_k.weight, -s, s) nn.init.uniform_(block.self_attn.c_v.weight, -s, s) nn.init.zeros_(block.self_attn.c_proj.weight) if block.use_cross_attn: # Cross-attention inputs scaled for decoder dim nn.init.uniform_(block.cross_attn.c_q.weight, -s, s) # K/V project from laser_dim (1024) to decoder dim s_laser = 3**0.5 * self.config.laser_dim**-0.5 nn.init.uniform_(block.cross_attn.c_k.weight, -s_laser, s_laser) nn.init.uniform_(block.cross_attn.c_v.weight, -s_laser, s_laser) nn.init.zeros_(block.cross_attn.c_proj.weight) # start as no-op nn.init.uniform_(block.mlp.c_gate.weight, -s * 0.5, s * 0.5) nn.init.uniform_(block.mlp.c_up.weight, -s * 0.5, s * 0.5) nn.init.zeros_(block.mlp.c_proj.weight) if hasattr(self, "laser_input_proj"): nn.init.zeros_(self.laser_input_proj.weight) def forward( self, input_ids: torch.Tensor, encoder_hidden: torch.Tensor | None = None, 
encoder_pad_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, ): """ Args: input_ids: [B, T] decoder tokens (causal LM targets) encoder_hidden: [B, T_enc, laser_dim] LASER2 output (per_layer or input_only) encoder_pad_mask: [B, T_enc] bool, True = pad labels: [B, T] targets for cross-entropy loss (shifted by caller) Returns: dict with 'logits' [B, T, V] and optionally 'loss' """ B, T = input_ids.shape # Embeddings x = self.wte(input_ids) # Input-only LASER2 injection if self.config.cross_attention_mode == "input_only" and encoder_hidden is not None: # Mean-pool encoder output across time, broadcast to all positions if encoder_pad_mask is not None: mask = (~encoder_pad_mask).to(encoder_hidden.dtype).unsqueeze(-1) laser_pool = (encoder_hidden * mask).sum(1) / mask.sum(1).clamp(min=1) else: laser_pool = encoder_hidden.mean(1) laser_proj = self.laser_input_proj(laser_pool.to(x.dtype)) # [B, C] x = x + laser_proj.unsqueeze(1) # broadcast to all decoder positions # Decoder blocks for block in self.blocks: x = block(x, self.cos, self.sin, encoder_hidden=encoder_hidden, encoder_pad_mask=encoder_pad_mask) # Final norm + LM head x = norm(x) * self.final_norm_gamma logits = self.lm_head(x) out = {"logits": logits} if labels is not None: loss = F.cross_entropy( logits.view(-1, self.config.vocab_size), labels.view(-1), ignore_index=self.config.pad_token_id, ) out["loss"] = loss return out def num_parameters(self, exclude_embedding: bool = False) -> int: n = sum(p.numel() for p in self.parameters()) if exclude_embedding: n -= self.wte.weight.numel() return n
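

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It assumes a
# V4Config instance is already constructed; the exact V4Config constructor
# arguments live in arkadiko.llm.config and are not shown here, so the kwargs
# below are placeholders. Tensor shapes follow the forward() docstring above.
#
#     config = V4Config(cross_attention_mode="per_layer", ...)  # hypothetical kwargs
#     model = V4Decoder(config)
#
#     input_ids = torch.randint(0, config.vocab_size, (2, 128))   # [B, T]
#     encoder_hidden = torch.randn(2, 64, config.laser_dim)       # [B, T_enc, laser_dim]
#     encoder_pad_mask = torch.zeros(2, 64, dtype=torch.bool)     # True = pad
#
#     out = model(input_ids, encoder_hidden, encoder_pad_mask, labels=input_ids)
#     out["logits"]  # [B, T, vocab_size]
#     out["loss"]    # scalar cross-entropy, pad positions ignored
#
# With cross_attention_mode="none", omit the encoder tensors entirely; with
# "input_only", pass them as above and the pooled LASER2 vector is added to
# the token embeddings instead of being cross-attended per layer.
# -----------------------------------------------------------------------------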