File size: 6,331 Bytes
f5b9ac5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 | """Arkadiko V4 decoder.
Dense transformer decoder with optional LASER2 conditioning.
Three variants (set via config.cross_attention_mode):
- "per_layer": cross-attention to the LASER2 output in every decoder layer (Variant A)
- "none": pure decoder, no LASER2 (Variant B)
- "input_only": the LASER2 output is mean-pooled over time, projected to the
  decoder width, and added to every token embedding at the input only (Variant C)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from arkadiko.embedding.rope import precompute_rotary_embeddings
from arkadiko.embedding.mlp import SwiGLU
from arkadiko.llm.attention import CausalSelfAttention, CrossAttention
from arkadiko.llm.config import V4Config
def norm(x):
return F.rms_norm(x, (x.size(-1),))
class V4Block(nn.Module):
    """One pre-norm decoder layer.

    Three residual sub-layers run in order, each fed an RMS-normed copy of the
    running hidden state:
      1. causal self-attention (given the rotary ``cos``/``sin`` tables),
      2. cross-attention to the encoder output, built only when
         ``config.cross_attention_mode == "per_layer"``,
      3. the ``SwiGLU`` MLP.
    """

    def __init__(self, config: V4Config):
        super().__init__()
        self.config = config
        mode = config.cross_attention_mode
        self.use_cross_attn = mode == "per_layer"
        self.self_attn = CausalSelfAttention(config)
        if self.use_cross_attn:
            # Allocated only in per_layer mode, so the other variants carry no
            # unused cross-attention weights.
            self.cross_attn = CrossAttention(config)
        self.mlp = SwiGLU(config.n_embd, config.ffn_mult, hidden=config.ffn_hidden)

    def forward(self, x, cos, sin, encoder_hidden=None, encoder_pad_mask=None):
        """Apply the layer to hidden states ``x`` and return the updated states.

        Args:
            x: decoder hidden states.
            cos, sin: rotary tables, passed straight to self-attention.
            encoder_hidden: encoder output used by cross-attention. It is
                ignored when this block has no cross-attention.
            encoder_pad_mask: encoder padding mask forwarded to
                cross-attention (the decoder documents True = pad).
        """
        hidden = x + self.self_attn(norm(x), cos, sin)

        # Cross-attention runs only when the block owns one AND an encoder
        # output was supplied. Otherwise this sub-layer is skipped entirely.
        attend_encoder = self.use_cross_attn and encoder_hidden is not None
        if attend_encoder:
            hidden = hidden + self.cross_attn(norm(hidden), encoder_hidden, encoder_pad_mask)

        return hidden + self.mlp(norm(hidden))
class V4Decoder(nn.Module):
    """Arkadiko V4 decoder: token embedding -> n_layer V4Blocks -> RMS norm -> LM head.

    How the optional LASER2 encoder output is used depends on
    ``config.cross_attention_mode``:
      - ``"per_layer"``: ``encoder_hidden`` is passed to every block's cross-attention.
      - ``"input_only"``: a mean-pooled, linearly projected copy of
        ``encoder_hidden`` is added to every token embedding.
      - any other mode (e.g. ``"none"``): ``encoder_hidden`` has no effect.
    """

    def __init__(self, config: V4Config):
        super().__init__()
        self.config = config
        # Token embedding. nn.Embedding gives no gradient to the padding_idx
        # row through this lookup.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd, padding_idx=config.pad_token_id)
        # Projection of the pooled LASER2 vector to the decoder width
        # (input_only mode only).
        if config.cross_attention_mode == "input_only":
            self.laser_input_proj = nn.Linear(config.laser_dim, config.n_embd, bias=False)
        # Decoder blocks.
        self.blocks = nn.ModuleList([V4Block(config) for _ in range(config.n_layer)])
        # Learnable gain applied after the final (parameter-free) RMS norm.
        self.final_norm_gamma = nn.Parameter(torch.ones(config.n_embd))
        # LM head. Its weight is replaced by wte's below when tied_embeddings is set.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # RoPE tables, kept as buffers so they move with the module across devices.
        cos, sin = precompute_rotary_embeddings(
            config.max_seq_len, config.head_dim, config.rope_theta
        )
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)
        self.init_weights()
        # Tie after init_weights so the shared matrix keeps wte's initialization.
        if config.tied_embeddings:
            self.lm_head.weight = self.wte.weight

    def init_weights(self):
        """(Re-)initialize the decoder's weights in place.

        - wte: N(0, config.init_std). The padding_idx row is then zeroed again,
          because the normal_ fill overwrites the zero pad vector that
          nn.Embedding sets up at construction.
        - Self-attention Q/K/V: uniform in [-s, s] with s = sqrt(3 / n_embd),
          i.e. std 1/sqrt(n_embd). MLP gate/up use half that bound.
        - Cross-attention Q: same bound as self-attention. Cross-attention K/V
          read the LASER2 output, so their bound uses laser_dim instead of n_embd.
        - Every c_proj and laser_input_proj: zeros, so these branches start
          out contributing nothing. (laser_input_proj has no bias; whether
          c_proj has one is not visible here.)
        - lm_head is not touched here. When untied it keeps nn.Linear's default init.
        """
        std = self.config.init_std
        n_embd = self.config.n_embd
        s = 3**0.5 * n_embd**-0.5
        # Same uniform-bound rule, but fan-in is the LASER2 width.
        s_laser = 3**0.5 * self.config.laser_dim**-0.5

        nn.init.normal_(self.wte.weight, mean=0.0, std=std)
        pad_idx = self.wte.padding_idx
        if pad_idx is not None:
            # Restore the zero pad embedding that normal_ just overwrote.
            with torch.no_grad():
                self.wte.weight[pad_idx].zero_()

        for block in self.blocks:
            nn.init.uniform_(block.self_attn.c_q.weight, -s, s)
            nn.init.uniform_(block.self_attn.c_k.weight, -s, s)
            nn.init.uniform_(block.self_attn.c_v.weight, -s, s)
            nn.init.zeros_(block.self_attn.c_proj.weight)
            if block.use_cross_attn:
                # Query comes from the decoder stream (n_embd wide).
                nn.init.uniform_(block.cross_attn.c_q.weight, -s, s)
                # Keys/values are projected from the LASER2 output (laser_dim wide).
                nn.init.uniform_(block.cross_attn.c_k.weight, -s_laser, s_laser)
                nn.init.uniform_(block.cross_attn.c_v.weight, -s_laser, s_laser)
                nn.init.zeros_(block.cross_attn.c_proj.weight)  # start as no-op
            nn.init.uniform_(block.mlp.c_gate.weight, -s * 0.5, s * 0.5)
            nn.init.uniform_(block.mlp.c_up.weight, -s * 0.5, s * 0.5)
            nn.init.zeros_(block.mlp.c_proj.weight)

        if hasattr(self, "laser_input_proj"):
            nn.init.zeros_(self.laser_input_proj.weight)

    def forward(
        self,
        input_ids: torch.Tensor,
        encoder_hidden: torch.Tensor | None = None,
        encoder_pad_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Run the decoder and optionally compute the LM loss.

        Args:
            input_ids: [B, T] decoder tokens (causal LM targets).
            encoder_hidden: [B, T_enc, laser_dim] LASER2 output (used in
                per_layer or input_only mode).
            encoder_pad_mask: [B, T_enc] bool, True = pad.
            labels: [B, T] targets for cross-entropy loss (shifted by the
                caller). May be non-contiguous, e.g. a slice ``ids[:, 1:]``.
                Positions equal to config.pad_token_id are ignored.

        Returns:
            dict with 'logits' [B, T, V] and, when labels is given, 'loss'.

        Raises:
            ValueError: if input_ids is not 2-D.
        """
        if input_ids.dim() != 2:
            raise ValueError(
                f"input_ids must be 2-D [B, T], got shape {tuple(input_ids.shape)}"
            )

        x = self.wte(input_ids)

        # Input-only LASER2 injection: one pooled vector per sequence,
        # broadcast to every decoder position.
        if self.config.cross_attention_mode == "input_only" and encoder_hidden is not None:
            if encoder_pad_mask is not None:
                keep = (~encoder_pad_mask).to(encoder_hidden.dtype).unsqueeze(-1)
                # Masked mean. The clamp avoids 0/0 when a row is all padding
                # (its pooled vector is then zeros).
                laser_pool = (encoder_hidden * keep).sum(1) / keep.sum(1).clamp(min=1)
            else:
                laser_pool = encoder_hidden.mean(1)
            laser_proj = self.laser_input_proj(laser_pool.to(x.dtype))  # [B, C]
            x = x + laser_proj.unsqueeze(1)

        # Blocks decide internally whether to use the encoder output.
        for block in self.blocks:
            x = block(
                x,
                self.cos,
                self.sin,
                encoder_hidden=encoder_hidden,
                encoder_pad_mask=encoder_pad_mask,
            )

        x = norm(x) * self.final_norm_gamma
        logits = self.lm_head(x)
        out = {"logits": logits}

        if labels is not None:
            # Use reshape, not view: shifted labels are commonly a
            # non-contiguous slice, which .view() cannot flatten.
            out["loss"] = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
                ignore_index=self.config.pad_token_id,
            )
        return out

    def num_parameters(self, exclude_embedding: bool = False) -> int:
        """Count unique parameters, optionally excluding the token embedding.

        ``self.parameters()`` yields a shared tensor only once. With tied
        embeddings, excluding the embedding therefore also excludes the LM
        head's (shared) weight.
        """
        n = sum(p.numel() for p in self.parameters())
        if exclude_embedding:
            n -= self.wte.weight.numel()
        return n
|