| """V4 attention modules: causal self-attention (GQA) and cross-attention to LASER2.""" |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from arkadiko.embedding.rope import apply_rotary_emb |
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Causal multi-head attention with GQA, RoPE, and QK-norm."""

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.head_dim = config.head_dim
        self.n_embd = config.n_embd
        assert config.n_head % config.n_kv_head == 0, "n_head must be divisible by n_kv_head"
        assert self.n_head * self.head_dim == self.n_embd, \
            f"n_head ({self.n_head}) * head_dim ({self.head_dim}) must equal n_embd ({self.n_embd})"

        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, cos, sin):
        B, T, C = x.shape

        q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        # QK-norm: RMS-normalize queries and keys along the head dimension.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))

        # RoPE: slice the precomputed tables to the sequence length and broadcast
        # over the batch and head dimensions.
        cos_t = cos[:T].unsqueeze(0).unsqueeze(0)
        sin_t = sin[:T].unsqueeze(0).unsqueeze(0)
        q = apply_rotary_emb(q, cos_t, sin_t)
        k = apply_rotary_emb(k, cos_t, sin_t)

        # enable_gqa lets SDPA share the n_kv_head K/V heads across all n_head
        # query heads instead of requiring K/V to be repeated to n_head first.
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            is_causal=True,
            enable_gqa=True,
        )

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
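# Shape walkthrough for the GQA path above (illustrative numbers, not the actual config):
# with n_embd=768, n_head=12, n_kv_head=4, head_dim=64, a [B, T, 768] input gives
# q: [B, 12, T, 64] and k, v: [B, 4, T, 64]; enable_gqa=True has SDPA serve each K/V
# head to 12 / 4 = 3 query heads, so K/V stay at n_kv_head heads throughout.
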
class CrossAttention(nn.Module):
    """Cross-attention: decoder Q attends to encoder K/V (from LASER2).

    No causal mask and no RoPE: the encoder output already carries its own
    positional information.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_embd = config.n_embd
        self.laser_dim = config.laser_dim

        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.laser_dim, self.n_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.laser_dim, self.n_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, encoder_hidden, encoder_pad_mask=None):
        """
        Args:
            x: [B, T_dec, C] decoder hidden states
            encoder_hidden: [B, T_enc, D_laser] LASER2 per-token output
            encoder_pad_mask: [B, T_enc] bool, True = pad (ignore)
        """
        B, T_dec, C = x.shape
        T_enc = encoder_hidden.shape[1]

        q = self.c_q(x).view(B, T_dec, self.n_head, self.head_dim).transpose(1, 2)
        k = self.c_k(encoder_hidden).view(B, T_enc, self.n_head, self.head_dim).transpose(1, 2)
        v = self.c_v(encoder_hidden).view(B, T_enc, self.n_head, self.head_dim).transpose(1, 2)

        # QK-norm, matching the self-attention path.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))

        # Convert the pad mask (True = pad) into an SDPA boolean mask (True = attend),
        # broadcast over heads and query positions: [B, T_enc] -> [B, 1, 1, T_enc].
        attn_mask = None
        if encoder_pad_mask is not None:
            mask = ~encoder_pad_mask
            attn_mask = mask[:, None, None, :]

        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            is_causal=False,
        )

        y = y.transpose(1, 2).contiguous().view(B, T_dec, C)
        return self.c_proj(y)
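
if __name__ == "__main__":
    # Minimal smoke test for CrossAttention. This is a sketch: the SimpleNamespace
    # config and its dimension values (n_embd=64, laser_dim=96, etc.) are illustrative
    # assumptions, not the project's actual config class or settings.
    from types import SimpleNamespace

    cfg = SimpleNamespace(n_head=4, n_kv_head=2, head_dim=16, n_embd=64, laser_dim=96)
    xattn = CrossAttention(cfg)

    B, T_dec, T_enc = 2, 5, 7
    x = torch.randn(B, T_dec, cfg.n_embd)
    enc = torch.randn(B, T_enc, cfg.laser_dim)
    # Mark the last two encoder positions of each sequence as padding (True = ignore).
    pad = torch.zeros(B, T_enc, dtype=torch.bool)
    pad[:, -2:] = True

    out = xattn(x, enc, encoder_pad_mask=pad)
    print(out.shape)  # expected: torch.Size([2, 5, 64])

    # CausalSelfAttention additionally needs RoPE cos/sin tables from
    # arkadiko.embedding.rope, so it is not exercised here.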