| """Arkadiko V4 decoder. |
| |
| Dense transformer decoder with optional LASER2 cross-attention. |
| Three variants (set via config.cross_attention_mode): |
| - "per_layer": cross-attention at every decoder layer (Variant A) |
| - "none": pure decoder, no LASER2 (Variant B) |
| - "input_only": LASER2 added to token embeddings at input only (Variant C) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from arkadiko.embedding.rope import precompute_rotary_embeddings |
| from arkadiko.embedding.mlp import SwiGLU |
| from arkadiko.llm.attention import CausalSelfAttention, CrossAttention |
| from arkadiko.llm.config import V4Config |
|
|
|
|
def norm(x):
    """RMS-normalize x over its last dimension (no learned scale)."""
    return F.rms_norm(x, (x.size(-1),))
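
# Note: F.rms_norm requires PyTorch >= 2.4. Called without a weight, it
# computes x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps); the only
# learned norm scale in this module is `final_norm_gamma` on the final norm.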


class V4Block(nn.Module):
    """Decoder block: self-attn → (optional cross-attn) → FFN."""

    def __init__(self, config: V4Config):
        super().__init__()
        self.config = config
        self.use_cross_attn = (config.cross_attention_mode == "per_layer")

        self.self_attn = CausalSelfAttention(config)
        if self.use_cross_attn:
            self.cross_attn = CrossAttention(config)
        self.mlp = SwiGLU(config.n_embd, config.ffn_mult, hidden=config.ffn_hidden)

    def forward(self, x, cos, sin, encoder_hidden=None, encoder_pad_mask=None):
        # Pre-norm residual layout throughout.
        x = x + self.self_attn(norm(x), cos, sin)

        # Cross-attention runs only when encoder states are supplied, so a
        # per_layer block can still execute in pure-decoder mode.
        if self.use_cross_attn and encoder_hidden is not None:
            x = x + self.cross_attn(norm(x), encoder_hidden, encoder_pad_mask)

        x = x + self.mlp(norm(x))
        return x


class V4Decoder(nn.Module):
    """Arkadiko V4 decoder."""

    def __init__(self, config: V4Config):
        super().__init__()
        self.config = config

        self.wte = nn.Embedding(config.vocab_size, config.n_embd, padding_idx=config.pad_token_id)

        # Variant C: a single projection from LASER2 space to embedding space.
        if config.cross_attention_mode == "input_only":
            self.laser_input_proj = nn.Linear(config.laser_dim, config.n_embd, bias=False)

        self.blocks = nn.ModuleList([V4Block(config) for _ in range(config.n_layer)])
        self.final_norm_gamma = nn.Parameter(torch.ones(config.n_embd))

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Rotary tables are deterministic and recomputed here on every load,
        # so keep them out of the state dict.
        cos, sin = precompute_rotary_embeddings(
            config.max_seq_len, config.head_dim, config.rope_theta
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

        self.init_weights()

        # Tie after init_weights so lm_head shares the freshly initialized
        # embedding matrix; when untied, lm_head keeps nn.Linear's default init.
        if config.tied_embeddings:
            self.lm_head.weight = self.wte.weight

    def init_weights(self):
        std = self.config.init_std
        n_embd = self.config.n_embd
        # U(-s, s) has std s / sqrt(3), so s = sqrt(3) / sqrt(n_embd) gives
        # the usual 1/sqrt(n_embd) scale.
        s = 3**0.5 * n_embd**-0.5

        nn.init.normal_(self.wte.weight, mean=0.0, std=std)
        # normal_ overwrites the whole table; re-zero the pad row that
        # nn.Embedding(padding_idx=...) zeroed at construction, since its
        # gradient stays masked and it would otherwise keep random noise.
        with torch.no_grad():
            self.wte.weight[self.config.pad_token_id].zero_()

        for block in self.blocks:
            nn.init.uniform_(block.self_attn.c_q.weight, -s, s)
            nn.init.uniform_(block.self_attn.c_k.weight, -s, s)
            nn.init.uniform_(block.self_attn.c_v.weight, -s, s)
            # Zero-init residual projections: each block starts as an identity map.
            nn.init.zeros_(block.self_attn.c_proj.weight)

            if block.use_cross_attn:
                nn.init.uniform_(block.cross_attn.c_q.weight, -s, s)
                # K/V project from LASER2 states, so scale by laser_dim, not n_embd.
                s_laser = 3**0.5 * self.config.laser_dim**-0.5
                nn.init.uniform_(block.cross_attn.c_k.weight, -s_laser, s_laser)
                nn.init.uniform_(block.cross_attn.c_v.weight, -s_laser, s_laser)
                nn.init.zeros_(block.cross_attn.c_proj.weight)

            nn.init.uniform_(block.mlp.c_gate.weight, -s * 0.5, s * 0.5)
            nn.init.uniform_(block.mlp.c_up.weight, -s * 0.5, s * 0.5)
            nn.init.zeros_(block.mlp.c_proj.weight)

        # Variant C starts as a no-op: the zero projection contributes nothing
        # to the input embeddings until training moves it.
        if hasattr(self, "laser_input_proj"):
            nn.init.zeros_(self.laser_input_proj.weight)
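
    # Sanity note (illustrative): with residual projections and the LASER2
    # input projection zero-initialized, every block is an identity map at
    # init, so first-step logits depend only on wte (and lm_head when untied).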

    def forward(
        self,
        input_ids: torch.Tensor,
        encoder_hidden: torch.Tensor | None = None,
        encoder_pad_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
    ):
        """
        Args:
            input_ids: [B, T] decoder input tokens
            encoder_hidden: [B, T_enc, laser_dim] LASER2 output
                (used by the per_layer and input_only variants)
            encoder_pad_mask: [B, T_enc] bool, True = pad
            labels: [B, T] targets for cross-entropy loss (shifted by caller)

        Returns:
            dict with 'logits' [B, T, V] and, when labels are given, 'loss'
        """
        B, T = input_ids.shape

        x = self.wte(input_ids)

        # Variant C: masked mean-pool the LASER2 states into one vector per
        # sequence, project to n_embd, and add it at every position.
        if self.config.cross_attention_mode == "input_only" and encoder_hidden is not None:
            if encoder_pad_mask is not None:
                mask = (~encoder_pad_mask).to(encoder_hidden.dtype).unsqueeze(-1)
                laser_pool = (encoder_hidden * mask).sum(1) / mask.sum(1).clamp(min=1)
            else:
                laser_pool = encoder_hidden.mean(1)
            laser_proj = self.laser_input_proj(laser_pool.to(x.dtype))
            x = x + laser_proj.unsqueeze(1)
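
        # Shape walk-through (illustrative): encoder_hidden [B, T_enc, laser_dim]
        # -> laser_pool [B, laser_dim] -> laser_proj [B, n_embd] -> unsqueezed to
        # [B, 1, n_embd] and broadcast-added over all T positions of x.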

        for block in self.blocks:
            x = block(x, self.cos, self.sin,
                      encoder_hidden=encoder_hidden,
                      encoder_pad_mask=encoder_pad_mask)

        # Final RMSNorm with a learned scale, then project to vocabulary logits.
        x = norm(x) * self.final_norm_gamma
        logits = self.lm_head(x)

        out = {"logits": logits}

        if labels is not None:
            # Positions labeled with the pad id are excluded from the loss.
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                labels.view(-1),
                ignore_index=self.config.pad_token_id,
            )
            out["loss"] = loss

        return out

    def num_parameters(self, exclude_embedding: bool = False) -> int:
        # parameters() deduplicates shared tensors, so a tied lm_head/wte
        # matrix is counted once.
        n = sum(p.numel() for p in self.parameters())
        if exclude_embedding:
            n -= self.wte.weight.numel()
        return n
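

# ---------------------------------------------------------------------------
# Greedy-decoding sketch (illustrative, not part of the module API; `bos_id`
# and the V4Config construction are assumptions):
#
#     config = V4Config(cross_attention_mode="none")
#     model = V4Decoder(config).eval()
#     ids = torch.full((1, 1), bos_id, dtype=torch.long)
#     with torch.no_grad():
#         for _ in range(32):
#             next_id = model(ids)["logits"][:, -1].argmax(-1, keepdim=True)
#             ids = torch.cat([ids, next_id], dim=1)
# ---------------------------------------------------------------------------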