"""Arkadiko V4 decoder.
Dense transformer decoder with optional LASER2 cross-attention.
Three variants (set via config.cross_attention_mode):
- "per_layer": cross-attention at every decoder layer (Variant A)
- "none": pure decoder, no LASER2 (Variant B)
- "input_only": LASER2 added to token embeddings at input only (Variant C)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from arkadiko.embedding.rope import precompute_rotary_embeddings
from arkadiko.embedding.mlp import SwiGLU
from arkadiko.llm.attention import CausalSelfAttention, CrossAttention
from arkadiko.llm.config import V4Config


def norm(x):
    # Parameter-free RMSNorm over the last dimension; a learnable scale is applied
    # separately where needed (see final_norm_gamma).
    return F.rms_norm(x, (x.size(-1),))


class V4Block(nn.Module):
"""Decoder block: self-attn → (optional cross-attn) → FFN."""
def __init__(self, config: V4Config):
super().__init__()
self.config = config
self.use_cross_attn = (config.cross_attention_mode == "per_layer")
self.self_attn = CausalSelfAttention(config)
if self.use_cross_attn:
self.cross_attn = CrossAttention(config)
self.mlp = SwiGLU(config.n_embd, config.ffn_mult, hidden=config.ffn_hidden)

    def forward(self, x, cos, sin, encoder_hidden=None, encoder_pad_mask=None):
# Self-attention (pre-norm)
x = x + self.self_attn(norm(x), cos, sin)
# Cross-attention (if enabled and encoder output provided)
if self.use_cross_attn and encoder_hidden is not None:
x = x + self.cross_attn(norm(x), encoder_hidden, encoder_pad_mask)
# FFN
x = x + self.mlp(norm(x))
return x


class V4Decoder(nn.Module):
"""Arkadiko V4 decoder."""
def __init__(self, config: V4Config):
super().__init__()
self.config = config
# Token embedding
self.wte = nn.Embedding(config.vocab_size, config.n_embd, padding_idx=config.pad_token_id)
        # Input projection for LASER2 (used only in "input_only" mode)
if config.cross_attention_mode == "input_only":
self.laser_input_proj = nn.Linear(config.laser_dim, config.n_embd, bias=False)
# Decoder blocks
self.blocks = nn.ModuleList([V4Block(config) for _ in range(config.n_layer)])
self.final_norm_gamma = nn.Parameter(torch.ones(config.n_embd))
        # LM head (weight tied to wte below when config.tied_embeddings is set)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# RoPE buffers
cos, sin = precompute_rotary_embeddings(
config.max_seq_len, config.head_dim, config.rope_theta
)
self.register_buffer("cos", cos)
self.register_buffer("sin", sin)
self.init_weights()
if config.tied_embeddings:
self.lm_head.weight = self.wte.weight

    def init_weights(self):
std = self.config.init_std
n_embd = self.config.n_embd
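        # uniform(-s, s) has std s / sqrt(3), so this bound gives each projection
        # weight an elementwise std of n_embd ** -0.5.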
s = 3**0.5 * n_embd**-0.5
nn.init.normal_(self.wte.weight, mean=0.0, std=std)
for block in self.blocks:
nn.init.uniform_(block.self_attn.c_q.weight, -s, s)
nn.init.uniform_(block.self_attn.c_k.weight, -s, s)
nn.init.uniform_(block.self_attn.c_v.weight, -s, s)
            nn.init.zeros_(block.self_attn.c_proj.weight)  # residual branch starts as a no-op
if block.use_cross_attn:
                # Query projects from the decoder hidden dim, so reuse the decoder-dim scale
nn.init.uniform_(block.cross_attn.c_q.weight, -s, s)
# K/V project from laser_dim (1024) to decoder dim
s_laser = 3**0.5 * self.config.laser_dim**-0.5
nn.init.uniform_(block.cross_attn.c_k.weight, -s_laser, s_laser)
nn.init.uniform_(block.cross_attn.c_v.weight, -s_laser, s_laser)
nn.init.zeros_(block.cross_attn.c_proj.weight) # start as no-op
nn.init.uniform_(block.mlp.c_gate.weight, -s * 0.5, s * 0.5)
nn.init.uniform_(block.mlp.c_up.weight, -s * 0.5, s * 0.5)
            nn.init.zeros_(block.mlp.c_proj.weight)  # FFN output starts as a no-op
if hasattr(self, "laser_input_proj"):
nn.init.zeros_(self.laser_input_proj.weight)

    def forward(
self,
input_ids: torch.Tensor,
encoder_hidden: torch.Tensor | None = None,
encoder_pad_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
):
"""
Args:
input_ids: [B, T] decoder tokens (causal LM targets)
encoder_hidden: [B, T_enc, laser_dim] LASER2 output (per_layer or input_only)
encoder_pad_mask: [B, T_enc] bool, True = pad
labels: [B, T] targets for cross-entropy loss (shifted by caller)
Returns:
dict with 'logits' [B, T, V] and optionally 'loss'
"""
B, T = input_ids.shape
# Embeddings
x = self.wte(input_ids)
# Input-only LASER2 injection
if self.config.cross_attention_mode == "input_only" and encoder_hidden is not None:
# Mean-pool encoder output across time, broadcast to all positions
if encoder_pad_mask is not None:
mask = (~encoder_pad_mask).to(encoder_hidden.dtype).unsqueeze(-1)
laser_pool = (encoder_hidden * mask).sum(1) / mask.sum(1).clamp(min=1)
else:
laser_pool = encoder_hidden.mean(1)
laser_proj = self.laser_input_proj(laser_pool.to(x.dtype)) # [B, C]
x = x + laser_proj.unsqueeze(1) # broadcast to all decoder positions
# Decoder blocks
for block in self.blocks:
x = block(x, self.cos, self.sin,
encoder_hidden=encoder_hidden,
encoder_pad_mask=encoder_pad_mask)
# Final norm + LM head
x = norm(x) * self.final_norm_gamma
logits = self.lm_head(x)
out = {"logits": logits}
if labels is not None:
loss = F.cross_entropy(
logits.view(-1, self.config.vocab_size),
labels.view(-1),
                ignore_index=self.config.pad_token_id,  # padded label positions don't contribute to the loss
)
out["loss"] = loss
return out

    def num_parameters(self, exclude_embedding: bool = False) -> int:
n = sum(p.numel() for p in self.parameters())
if exclude_embedding:
n -= self.wte.weight.numel()
return n
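

if __name__ == "__main__":
    # Smoke-test sketch, not part of the model definition. The V4Config field
    # values below are hypothetical and only cover attributes referenced in this
    # file; the real config almost certainly needs more (e.g. attention heads),
    # so adjust the constructor call to the actual V4Config signature.
    config = V4Config(
        vocab_size=32000,
        n_embd=768,
        n_layer=12,
        head_dim=64,
        max_seq_len=2048,
        rope_theta=10000.0,
        laser_dim=1024,
        pad_token_id=0,
        init_std=0.02,
        ffn_mult=4,
        ffn_hidden=None,
        tied_embeddings=True,
        cross_attention_mode="per_layer",  # Variant A; "input_only" / "none" also valid
    )
    model = V4Decoder(config)

    B, T, T_enc = 2, 16, 8
    input_ids = torch.randint(0, config.vocab_size, (B, T))
    encoder_hidden = torch.randn(B, T_enc, config.laser_dim)    # stand-in for LASER2 states
    encoder_pad_mask = torch.zeros(B, T_enc, dtype=torch.bool)  # no padding
    # labels would normally be shifted by the caller; identity is fine for a shape check
    out = model(input_ids, encoder_hidden, encoder_pad_mask, labels=input_ids)
    print(out["logits"].shape, out["loss"].item())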