# arkadiko-v4-ablation: code/attention.py
"""V4 attention modules: causal self-attention (GQA) and cross-attention to LASER2."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from arkadiko.embedding.rope import apply_rotary_emb
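
# Both modules pull their hyperparameters off a `config` object. The attributes
# read below are: n_head, n_kv_head, head_dim, n_embd (CausalSelfAttention) and
# n_head, head_dim, n_embd, laser_dim (CrossAttention over LASER2 states).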
class CausalSelfAttention(nn.Module):
"""Causal multi-head attention with GQA, RoPE, and QK-norm."""
def __init__(self, config):
super().__init__()
self.n_head = config.n_head
self.n_kv_head = config.n_kv_head
self.head_dim = config.head_dim
self.n_embd = config.n_embd
assert config.n_head % config.n_kv_head == 0, "n_head must be divisible by n_kv_head"
assert self.n_head * self.head_dim == self.n_embd, \
f"n_head ({self.n_head}) * head_dim ({self.head_dim}) must equal n_embd ({self.n_embd})"
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos, sin):
B, T, C = x.shape
q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) # [B, H, T, D]
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) # [B, H_kv, T, D]
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
# QK-norm
q = F.rms_norm(q, (q.size(-1),))
k = F.rms_norm(k, (k.size(-1),))
# RoPE
cos_t = cos[:T].unsqueeze(0).unsqueeze(0) # [1, 1, T, D//2]
sin_t = sin[:T].unsqueeze(0).unsqueeze(0)
q = apply_rotary_emb(q, cos_t, sin_t)
k = apply_rotary_emb(k, cos_t, sin_t)
        # SDPA with native GQA support: enable_gqa lets the kernel group the n_head
        # query heads over the smaller set of n_kv_head KV heads. This flag needs a
        # recent PyTorch release (see the manual expansion sketch below for older versions).
y = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
is_causal=True,
enable_gqa=True,
) # [B, H, T, D]
y = y.transpose(1, 2).contiguous().view(B, T, C)
return self.c_proj(y)
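

# --- Illustrative GQA fallback (not used by the modules in this file) --------
# `enable_gqa=True` in F.scaled_dot_product_attention needs a recent PyTorch
# release; on older versions the same result can be obtained by explicitly
# repeating each KV head across its group of query heads before calling SDPA.
# This is a minimal sketch of that equivalence; `expand_kv_for_gqa` is a name
# introduced here for illustration only.
def expand_kv_for_gqa(k, v, n_head):
    """Expand [B, n_kv_head, T, D] keys/values to n_head heads for plain SDPA."""
    n_kv_head = k.size(1)
    assert n_head % n_kv_head == 0, "n_head must be divisible by n_kv_head"
    repeats = n_head // n_kv_head
    # repeat_interleave keeps the grouping order: KV head i serves query heads
    # i * repeats ... (i + 1) * repeats - 1.
    return k.repeat_interleave(repeats, dim=1), v.repeat_interleave(repeats, dim=1)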
class CrossAttention(nn.Module):
"""Cross-attention: decoder Q attends to encoder K/V (from LASER2).
No causality mask. No RoPE (encoder output is already positional).
"""
def __init__(self, config):
super().__init__()
self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_embd = config.n_embd
        self.laser_dim = config.laser_dim
        assert self.n_head * self.head_dim == self.n_embd, \
            f"n_head ({self.n_head}) * head_dim ({self.head_dim}) must equal n_embd ({self.n_embd})"
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.laser_dim, self.n_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.laser_dim, self.n_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, encoder_hidden, encoder_pad_mask=None):
"""
Args:
x: [B, T_dec, C] decoder hidden states
encoder_hidden: [B, T_enc, D_laser] LASER2 per-token output
encoder_pad_mask: [B, T_enc] bool, True = pad (ignore)
"""
B, T_dec, C = x.shape
T_enc = encoder_hidden.shape[1]
q = self.c_q(x).view(B, T_dec, self.n_head, self.head_dim).transpose(1, 2) # [B, H, T_dec, D]
k = self.c_k(encoder_hidden).view(B, T_enc, self.n_head, self.head_dim).transpose(1, 2)
v = self.c_v(encoder_hidden).view(B, T_enc, self.n_head, self.head_dim).transpose(1, 2)
q = F.rms_norm(q, (q.size(-1),))
k = F.rms_norm(k, (k.size(-1),))
# Encoder padding mask
attn_mask = None
if encoder_pad_mask is not None:
            # A boolean attn_mask for SDPA means True = attend, False = masked out
            # (a float mask would instead be added to the attention logits).
            # encoder_pad_mask is True at padding positions, so invert it.
            mask = ~encoder_pad_mask  # True = attend
            attn_mask = mask[:, None, None, :]  # [B, 1, 1, T_enc]; broadcasts over heads and query positions
y = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
is_causal=False,
) # [B, H, T_dec, D]
y = y.transpose(1, 2).contiguous().view(B, T_dec, C)
return self.c_proj(y)
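

# --- Minimal shape smoke test (illustrative sketch) ---------------------------
# Assumes a plain namespace carrying the config fields read above, and cos/sin
# tables shaped [T, head_dim // 2] as the comments in CausalSelfAttention.forward
# suggest. cos=1 / sin=0 turns a standard rotate-half RoPE into a no-op, so the
# test does not depend on the exact frequency schedule in arkadiko.embedding.rope.
# All sizes below are example values, not the ablation's real configuration.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(n_head=8, n_kv_head=2, head_dim=32, n_embd=256, laser_dim=1024)
    B, T_dec, T_enc = 2, 16, 24

    x = torch.randn(B, T_dec, cfg.n_embd)
    cos = torch.ones(T_dec, cfg.head_dim // 2)
    sin = torch.zeros(T_dec, cfg.head_dim // 2)
    assert CausalSelfAttention(cfg)(x, cos, sin).shape == (B, T_dec, cfg.n_embd)

    enc = torch.randn(B, T_enc, cfg.laser_dim)
    pad = torch.zeros(B, T_enc, dtype=torch.bool)
    pad[:, -4:] = True  # treat the last four encoder positions as padding
    assert CrossAttention(cfg)(x, enc, encoder_pad_mask=pad).shape == (B, T_dec, cfg.n_embd)
    print("attention smoke test passed")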