"""
model/attention.py

Causal Multi-Head Self-Attention with RoPE.

Architecture:
    Input x (B, T, d_model)
    -> Fused linear projection to Q, K, V (bias controlled by config.bias)
    -> Reshape to (B, n_heads, T, head_dim)
    -> Apply RoPE to Q and K
    -> Scaled dot-product attention with causal mask
    -> Reshape back to (B, T, d_model)
    -> Output projection O (bias controlled by config.bias)

Uses torch.nn.functional.scaled_dot_product_attention (Flash Attention when
available via PyTorch 2.0+) for memory-efficient attention. The causal mask
is handled by is_causal=True — no need to materialize an explicit O(T^2)
mask tensor.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.config import ModelConfig
from model.rope import RoPECache, apply_rope


class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Args:
        config: model configuration. Fields read here: ``n_heads``,
            ``head_dim``, ``d_model``, ``dropout``, ``bias``,
            ``context_length`` and ``rope_theta``.

    Raises:
        ValueError: if ``n_heads * head_dim != d_model``. The per-head
            reshape in ``forward`` cannot work in that case, so the error
            is raised up front instead of as an opaque ``view`` failure.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        # Fail fast: forward() splits the d_model channels of each of Q, K, V
        # into n_heads heads of head_dim each, which requires an exact match.
        if config.n_heads * config.head_dim != config.d_model:
            raise ValueError(
                "n_heads * head_dim must equal d_model, got "
                f"{config.n_heads} * {config.head_dim} != {config.d_model}"
            )

        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.d_model = config.d_model
        self.dropout = config.dropout

        # Q, K, V projections fused into one matrix for efficiency.
        # Output: (B, T, 3 * d_model), later split into three (B, T, d_model).
        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)

        # Output projection
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)

        # Dropout rate on attention probabilities, passed to SDPA as dropout_p
        # during training only. Same value as self.dropout; both attributes
        # are kept so existing readers of either name keep working.
        self.attn_dropout = config.dropout

        # RoPE cos/sin cache covering up to context_length positions.
        # NOTE(review): assumed RoPECache stores its tables as registered
        # buffers so they follow .to(device) — confirm in model/rope.py.
        self.rope = RoPECache(
            head_dim=config.head_dim,
            max_seq_len=config.context_length,
            theta=config.rope_theta,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention to ``x``.

        Args:
            x: (B, T, d_model) input activations.

        Returns:
            (B, T, d_model) tensor produced by the output projection.
        """
        B, T, C = x.shape  # C = d_model

        # ---- QKV projection ---------------------------------------- #
        qkv = self.qkv_proj(x)                      # (B, T, 3*C)
        q, k, v = qkv.split(self.d_model, dim=-1)   # each: (B, T, C)

        # ---- Reshape to (B, n_heads, T, head_dim) ------------------ #
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # ---- Apply RoPE to Q and K --------------------------------- #
        # Only Q and K are rotated; V is left untouched.
        cos, sin = self.rope.get(T)  # presumably (T, head_dim) — see model/rope.py
        q, k = apply_rope(q, k, cos, sin)

        # ---- Scaled dot-product attention (Flash Attention) -------- #
        # is_causal=True handles the causal mask internally — no mask alloc.
        # dropout_p only applies during training.
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=True,
        )  # (B, n_heads, T, head_dim)

        # ---- Merge heads ------------------------------------------- #
        # contiguous() is needed before view() because transpose() leaves a
        # non-contiguous stride layout.
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C)

        # ---- Output projection ------------------------------------- #
        return self.o_proj(attn_out)  # (B, T, d_model)


# ------------------------------------------------------------------ #
#  QUICK CHECK
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    from model.config import SLLM_100M

    cfg = SLLM_100M
    attn = CausalSelfAttention(cfg)
    print(f"Attention params : {sum(p.numel() for p in attn.parameters())/1e6:.2f}M")

    B, T = 2, 64
    x = torch.randn(B, T, cfg.d_model)
    out = attn(x)
    print(f"Input shape  : {x.shape}")
    print(f"Output shape : {out.shape}")
    assert out.shape == (B, T, cfg.d_model), "Shape mismatch!"
    print("PASS")