"""V4 attention modules: causal self-attention (GQA) and cross-attention to LASER2."""

import torch
import torch.nn as nn
import torch.nn.functional as F

from arkadiko.embedding.rope import apply_rotary_emb


class CausalSelfAttention(nn.Module):
    """Causal multi-head attention with GQA, RoPE, and QK-norm.

    Q uses ``n_head`` heads while K/V use ``n_kv_head`` heads (grouped-query
    attention); SDPA's ``enable_gqa=True`` broadcasts the KV heads to the query
    heads internally, so no explicit ``repeat_interleave`` is needed.
    """

    def __init__(self, config):
        """Build projections from ``config`` (needs n_head, n_kv_head, head_dim, n_embd).

        Raises:
            AssertionError: if n_head is not a multiple of n_kv_head, or if
                n_head * head_dim != n_embd (required by the output reshape).
        """
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.head_dim = config.head_dim
        self.n_embd = config.n_embd
        assert config.n_head % config.n_kv_head == 0, "n_head must be divisible by n_kv_head"
        assert self.n_head * self.head_dim == self.n_embd, \
            f"n_head ({self.n_head}) * head_dim ({self.head_dim}) must equal n_embd ({self.n_embd})"
        # Separate Q/K/V projections; K/V are narrower under GQA (n_kv_head heads).
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, cos, sin):
        """Apply causal self-attention.

        Args:
            x: [B, T, C] input hidden states (C == n_embd).
            cos, sin: precomputed RoPE tables; the first T rows are used.
                Indexed as cos[:T] and broadcast to [1, 1, T, D//2].

        Returns:
            [B, T, C] attended and projected hidden states.
        """
        B, T, C = x.shape
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)     # [B, H, T, D]
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)  # [B, H_kv, T, D]
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        # QK-norm: parameter-free RMS normalization over the head dimension,
        # applied before RoPE for attention-logit stability.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))

        # RoPE — rotate Q and K with position tables truncated to the current length.
        cos_t = cos[:T].unsqueeze(0).unsqueeze(0)  # [1, 1, T, D//2]
        sin_t = sin[:T].unsqueeze(0).unsqueeze(0)
        q = apply_rotary_emb(q, cos_t, sin_t)
        k = apply_rotary_emb(k, cos_t, sin_t)

        # SDPA with native GQA (repeats KV heads internally via stride tricks).
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            is_causal=True,
            enable_gqa=True,
        )  # [B, H, T, D]

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


class CrossAttention(nn.Module):
    """Cross-attention: decoder Q attends to encoder K/V (from LASER2).

    No causality mask. No RoPE (encoder output is already positional).
    """

    def __init__(self, config):
        """Build projections from ``config`` (needs n_head, head_dim, n_embd, laser_dim).

        Raises:
            AssertionError: if n_head * head_dim != n_embd. This mirrors the
                check in CausalSelfAttention; without it a bad config only
                surfaces later as an opaque reshape error in forward().
        """
        super().__init__()
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_embd = config.n_embd
        self.laser_dim = config.laser_dim
        assert self.n_head * self.head_dim == self.n_embd, \
            f"n_head ({self.n_head}) * head_dim ({self.head_dim}) must equal n_embd ({self.n_embd})"
        # Q projects from decoder width; K/V project from the LASER2 encoder width.
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.laser_dim, self.n_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.laser_dim, self.n_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, encoder_hidden, encoder_pad_mask=None):
        """
        Args:
            x: [B, T_dec, C] decoder hidden states
            encoder_hidden: [B, T_enc, D_laser] LASER2 per-token output
            encoder_pad_mask: [B, T_enc] bool, True = pad (ignore)

        Returns:
            [B, T_dec, C] decoder states after attending over the encoder.
        """
        B, T_dec, C = x.shape
        T_enc = encoder_hidden.shape[1]

        q = self.c_q(x).view(B, T_dec, self.n_head, self.head_dim).transpose(1, 2)  # [B, H, T_dec, D]
        k = self.c_k(encoder_hidden).view(B, T_enc, self.n_head, self.head_dim).transpose(1, 2)
        v = self.c_v(encoder_hidden).view(B, T_enc, self.n_head, self.head_dim).transpose(1, 2)

        # QK-norm, matching the self-attention path.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))

        # Encoder padding mask
        attn_mask = None
        if encoder_pad_mask is not None:
            # SDPA wants True = attend, False = mask OR additive mask
            # encoder_pad_mask: True where pad → we want to mask those out
            # NOTE(review): a row that is *entirely* padding masks every key and
            # SDPA then yields NaN for that row — callers presumably guarantee at
            # least one valid encoder token per sequence; confirm upstream.
            mask = ~encoder_pad_mask                # True = attend
            attn_mask = mask[:, None, None, :]      # [B, 1, 1, T_enc]

        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            is_causal=False,
        )  # [B, H, T_dec, D]

        y = y.transpose(1, 2).contiguous().view(B, T_dec, C)
        return self.c_proj(y)