"""V4 attention modules: causal self-attention (GQA) and cross-attention to LASER2."""

import torch
import torch.nn as nn
import torch.nn.functional as F

from arkadiko.embedding.rope import apply_rotary_emb


class CausalSelfAttention(nn.Module):
    """Causal multi-head attention with GQA, RoPE, and QK-norm."""

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.head_dim = config.head_dim
        self.n_embd = config.n_embd
        assert config.n_head % config.n_kv_head == 0, "n_head must be divisible by n_kv_head"
        assert self.n_head * self.head_dim == self.n_embd, \
            f"n_head ({self.n_head}) * head_dim ({self.head_dim}) must equal n_embd ({self.n_embd})"

        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, cos, sin):
        B, T, C = x.shape

        q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)      # [B, H, T, D]
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)   # [B, H_kv, T, D]
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        # QK-norm: RMS-normalize Q and K along the head dimension
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))

        # RoPE
        cos_t = cos[:T].unsqueeze(0).unsqueeze(0)  # [1, 1, T, D//2]
        sin_t = sin[:T].unsqueeze(0).unsqueeze(0)
        q = apply_rotary_emb(q, cos_t, sin_t)
        k = apply_rotary_emb(k, cos_t, sin_t)

        # SDPA with native GQA support (KV heads are shared across groups of query heads internally)
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            is_causal=True,
            enable_gqa=True,
        )  # [B, H, T, D]

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


class CrossAttention(nn.Module):
    """Cross-attention: decoder Q attends to encoder K/V (from LASER2).

    No causal mask. No RoPE (the encoder output already carries positional information).
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_embd = config.n_embd
        self.laser_dim = config.laser_dim

        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.laser_dim, self.n_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.laser_dim, self.n_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, encoder_hidden, encoder_pad_mask=None):
        """
        Args:
            x: [B, T_dec, C] decoder hidden states
            encoder_hidden: [B, T_enc, D_laser] LASER2 per-token output
            encoder_pad_mask: [B, T_enc] bool, True = pad (ignore)
        """
        B, T_dec, C = x.shape
        T_enc = encoder_hidden.shape[1]

        q = self.c_q(x).view(B, T_dec, self.n_head, self.head_dim).transpose(1, 2)        # [B, H, T_dec, D]
        k = self.c_k(encoder_hidden).view(B, T_enc, self.n_head, self.head_dim).transpose(1, 2)
        v = self.c_v(encoder_hidden).view(B, T_enc, self.n_head, self.head_dim).transpose(1, 2)

        # QK-norm, as in the self-attention block
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))

        # Encoder padding mask
        attn_mask = None
        if encoder_pad_mask is not None:
            # SDPA boolean masks use True = attend, False = mask out (a float
            # additive mask is also accepted). encoder_pad_mask is True at pad
            # positions, so invert it.
            mask = ~encoder_pad_mask  # True = attend
            attn_mask = mask[:, None, None, :]  # [B, 1, 1, T_enc]
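            # e.g. a pad mask row [False, False, True] becomes the attend mask
            # [True, True, False], broadcast over all heads and query positions
            # by the [B, 1, 1, T_enc] shape.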

        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            is_causal=False,
        )  # [B, H, T_dec, D]

        y = y.transpose(1, 2).contiguous().view(B, T_dec, C)
        return self.c_proj(y)
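

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). _DemoConfig and its field values
# are assumptions standing in for the real arkadiko config, and the cos/sin
# tables are identity placeholders; in practice both come from the model
# config and the RoPE cache in arkadiko.embedding.rope.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class _DemoConfig:
        n_head: int = 8
        n_kv_head: int = 2
        head_dim: int = 64
        n_embd: int = 512
        laser_dim: int = 1024

    cfg = _DemoConfig()
    self_attn = CausalSelfAttention(cfg)
    cross_attn = CrossAttention(cfg)

    B, T_dec, T_enc = 2, 16, 24
    x = torch.randn(B, T_dec, cfg.n_embd)          # decoder hidden states
    enc = torch.randn(B, T_enc, cfg.laser_dim)     # LASER2 per-token output
    pad = torch.zeros(B, T_enc, dtype=torch.bool)  # True = pad position
    pad[:, -4:] = True                             # pretend the last 4 encoder tokens are padding

    # Identity rotary tables of shape [T, head_dim // 2] (cos=1, sin=0) so the
    # sketch runs without the real RoPE cache.
    cos = torch.ones(T_dec, cfg.head_dim // 2)
    sin = torch.zeros(T_dec, cfg.head_dim // 2)

    y_self = self_attn(x, cos, sin)       # [B, T_dec, n_embd]
    y_cross = cross_attn(x, enc, pad)     # [B, T_dec, n_embd]
    print(y_self.shape, y_cross.shape)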