File size: 4,095 Bytes
7f974df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
model/attention.py

Causal multi-head self-attention with rotary position embeddings (RoPE).

Architecture:
    Input x  (B, T, d_model)
      -> One fused linear projection producing Q, K, V  (bias set by config.bias)
      -> Reshape to (B, n_heads, T, head_dim)
      -> Apply RoPE to Q and K
      -> Scaled dot-product attention with causal mask
      -> Reshape back to (B, T, d_model)
      -> Output projection O  (bias set by config.bias)

Uses torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+). It
dispatches to a fused kernel such as FlashAttention when the inputs and
hardware support it, and otherwise falls back to another backend, which keeps
attention memory-efficient where possible.
The causal mask is handled by is_causal=True, so there is no need to
materialize an explicit O(T^2) mask tensor.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.config import ModelConfig
from model.rope import RoPECache, apply_rope


class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings (RoPE).

    Q, K and V come from a single fused linear layer. RoPE is applied to Q
    and K, and attention is computed by ``F.scaled_dot_product_attention``
    with ``is_causal=True``. The merged heads then pass through an output
    projection.

    Args:
        config: model configuration. Reads ``n_heads``, ``head_dim``,
            ``d_model``, ``dropout``, ``bias``, ``context_length`` and
            ``rope_theta``.

    Raises:
        ValueError: if ``n_heads * head_dim != d_model``. ``forward`` splits
            the fused projection into ``d_model``-wide chunks and views each
            as ``(n_heads, head_dim)``, which only works when the two sizes
            agree. Checking here replaces an opaque ``view`` error on the
            first forward pass with a clear construction-time message.
    """

    def __init__(self, config: "ModelConfig"):
        super().__init__()
        # Fail fast: with a mismatched config every forward pass would crash
        # inside .view(), far from the actual cause.
        if config.n_heads * config.head_dim != config.d_model:
            raise ValueError(
                "n_heads * head_dim must equal d_model, got "
                f"{config.n_heads} * {config.head_dim} != {config.d_model}"
            )

        self.n_heads   = config.n_heads
        self.head_dim  = config.head_dim
        self.d_model   = config.d_model
        # Same value as attn_dropout below; forward() reads attn_dropout.
        # Kept as a public attribute for backward compatibility.
        self.dropout   = config.dropout

        # Q, K, V projections fused into one matrix: one matmul instead of
        # three. Output is (B, T, 3 * d_model), split into Q/K/V in forward().
        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)

        # Output projection, applied after the heads are merged back.
        self.o_proj   = nn.Linear(config.d_model, config.d_model, bias=config.bias)

        # Dropout probability on the attention weights. Passed to sdpa as
        # dropout_p, and only while self.training is True.
        self.attn_dropout = config.dropout

        # RoPE cos/sin tables, sized for config.context_length positions.
        # Assigned as a submodule attribute. NOTE(review): if RoPECache keeps
        # its tables as registered buffers, they follow .to(device); confirm
        # in model.rope.
        self.rope = RoPECache(
            head_dim    = config.head_dim,
            max_seq_len = config.context_length,
            theta       = config.rope_theta,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention to a batch of sequences.

        Args:
            x : (B, T, d_model) input activations. NOTE(review): T presumably
                must not exceed config.context_length (the RoPE table size);
                that depends on RoPECache.get, so confirm there.

        Returns:
            out : (B, T, d_model)
        """
        B, T, C = x.shape                              # C = d_model

        # ---- QKV projection ---------------------------------------- #
        qkv = self.qkv_proj(x)                        # (B, T, 3*C)
        q, k, v = qkv.split(self.d_model, dim=-1)     # each: (B, T, C)

        # ---- Reshape to (B, n_heads, T, head_dim) ------------------ #
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # ---- Apply RoPE to Q and K (V is not rotated) -------------- #
        cos, sin = self.rope.get(T)                    # presumably (T, head_dim)
        q, k     = apply_rope(q, k, cos, sin)

        # ---- Scaled dot-product attention --------------------------- #
        # is_causal=True applies the causal mask inside the kernel, so no
        # (T, T) mask tensor is allocated. Dropout is disabled in eval mode.
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask   = None,
            dropout_p   = self.attn_dropout if self.training else 0.0,
            is_causal   = True,
        )                                              # (B, n_heads, T, head_dim)

        # ---- Merge heads ------------------------------------------- #
        # transpose() yields a non-contiguous tensor; .view() needs contiguous.
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C)

        # ---- Output projection ------------------------------------- #
        return self.o_proj(attn_out)                   # (B, T, d_model)


# ------------------------------------------------------------------ #
#  QUICK CHECK
# ------------------------------------------------------------------ #

if __name__ == "__main__":
    from model.config import SLLM_100M

    # Smoke test: build one attention layer from the SLLM_100M preset, push a
    # random batch through it, and check that the output keeps the input shape.
    batch_size, seq_len = 2, 64
    layer = CausalSelfAttention(SLLM_100M)

    n_params = sum(param.numel() for param in layer.parameters())
    print(f"Attention params : {n_params / 1e6:.2f}M")

    inputs = torch.randn(batch_size, seq_len, SLLM_100M.d_model)
    outputs = layer(inputs)

    print(f"Input  shape : {inputs.shape}")
    print(f"Output shape : {outputs.shape}")
    expected_shape = (batch_size, seq_len, SLLM_100M.d_model)
    assert outputs.shape == expected_shape, "Shape mismatch!"
    print("PASS")