| """ |
| model/attention.py |
| |
| Causal Multi-Head Self-Attention with RoPE. |
| |
| Architecture: |
| Input x (B, T, d_model) |
| -> Fused linear projection to Q, K, V (bias controlled by config.bias) |
| -> Reshape to (B, n_heads, T, head_dim) |
| -> Apply RoPE to Q and K |
| -> Scaled dot-product attention with causal mask |
| -> Reshape back to (B, T, d_model) |
| -> Output projection O (bias controlled by config.bias) |
| |
| Uses torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+), |
| which dispatches to a Flash Attention kernel when one is available, for |
| memory-efficient attention. The causal mask is requested via is_causal=True, |
| so there is no need to build an explicit O(T^2) mask tensor ourselves. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from model.config import ModelConfig |
| from model.rope import RoPECache, apply_rope |
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings (RoPE).

    The input is projected to Q, K and V with a single fused linear layer.
    RoPE is applied to Q and K, and attention is computed with
    ``F.scaled_dot_product_attention(..., is_causal=True)``. The heads are then
    merged and passed through an output projection. Both linear layers use
    ``bias=config.bias``.

    Args:
        config: model configuration. Fields read here: ``n_heads``,
            ``head_dim``, ``d_model``, ``dropout``, ``bias``,
            ``context_length`` and ``rope_theta``.

    Raises:
        ValueError: if ``n_heads * head_dim != d_model``. The fused projection
            produces ``d_model`` features each for Q, K and V. These must
            reshape exactly into ``n_heads`` heads of size ``head_dim``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        # Fail fast on an inconsistent config. Otherwise the mismatch only
        # surfaces later as an opaque .view() shape error inside forward().
        if config.n_heads * config.head_dim != config.d_model:
            raise ValueError(
                "n_heads * head_dim must equal d_model, got "
                f"{config.n_heads} * {config.head_dim} != {config.d_model}"
            )

        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.d_model = config.d_model
        self.dropout = config.dropout

        # Fused Q/K/V projection: one matmul yields 3 * d_model features,
        # which forward() splits into Q, K and V along the last dimension.
        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)

        # Output projection, applied after the heads are merged back together.
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)

        # Dropout probability handed to SDPA. This is a float, not an
        # nn.Dropout module; forward() only uses it in training mode.
        self.attn_dropout = config.dropout

        # RoPE cos/sin provider, sized for up to context_length positions.
        # forward() asks it for the first T positions.
        # NOTE(review): if RoPECache is not an nn.Module that registers its
        # tables as buffers, calling .to(device) on this module will not move
        # them. Confirm in model/rope.py.
        self.rope = RoPECache(
            head_dim=config.head_dim,
            max_seq_len=config.context_length,
            theta=config.rope_theta,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            x: input of shape (B, T, d_model).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, C = x.shape

        # Single fused projection, split into Q, K, V of d_model features each.
        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.d_model, dim=-1)

        # (B, T, d_model) -> (B, T, n_heads, head_dim) -> (B, n_heads, T, head_dim)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # Position information is injected into Q and K only; V is left as is.
        cos, sin = self.rope.get(T)
        q, k = apply_rope(q, k, cos, sin)

        # is_causal=True lets SDPA apply the causal mask itself, so no (T, T)
        # mask tensor is passed. Attention dropout is active only while training.
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=True,
        )

        # (B, n_heads, T, head_dim) -> (B, T, n_heads, head_dim) -> (B, T, d_model).
        # The transpose leaves the tensor non-contiguous, so copy before .view().
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C)

        return self.o_proj(attn_out)
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Smoke test: build the layer from the reference config, push a random
    # batch through it, and check that the output shape matches the input.
    from model.config import SLLM_100M

    config = SLLM_100M
    layer = CausalSelfAttention(config)
    n_params = sum(param.numel() for param in layer.parameters())
    print(f"Attention params : {n_params/1e6:.2f}M")

    batch, seq_len = 2, 64
    inputs = torch.randn(batch, seq_len, config.d_model)
    outputs = layer(inputs)

    print(f"Input shape : {inputs.shape}")
    print(f"Output shape : {outputs.shape}")
    expected_shape = (batch, seq_len, config.d_model)
    assert outputs.shape == expected_shape, "Shape mismatch!"
    print("PASS")
|
|