sllm / model /attention.py
geeteshcodes's picture
Initial commit
7f974df verified
"""
model/attention.py
Causal Multi-Head Self-Attention with RoPE.
Architecture:
Input x (B, T, d_model)
-> Linear projections Q, K, V (no bias)
-> Reshape to (B, n_heads, T, head_dim)
-> Apply RoPE to Q and K
-> Scaled dot-product attention with causal mask
-> Reshape back to (B, T, d_model)
-> Output projection O (no bias)
Uses torch.nn.functional.scaled_dot_product_attention (Flash Attention
when available via PyTorch 2.0+) for memory-efficient attention.
The causal mask is handled by is_causal=True — no need to materialize
an explicit O(T^2) mask tensor.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.config import ModelConfig
from model.rope import RoPECache, apply_rope
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    The layer projects the input to Q, K and V with one fused linear layer,
    splits the result into ``n_heads`` heads, rotates Q and K with RoPE, runs
    causal scaled dot-product attention, merges the heads and applies an
    output projection. Whether the two linear layers carry a bias is decided
    by ``config.bias``.
    """

    def __init__(self, config: ModelConfig):
        """Build the projections and the RoPE cache from ``config``.

        Args:
            config: Supplies ``n_heads``, ``head_dim``, ``d_model``,
                ``dropout``, ``bias``, ``context_length`` and ``rope_theta``.

        Raises:
            ValueError: If ``n_heads * head_dim`` differs from ``d_model``.
        """
        super().__init__()

        # forward() cuts the fused projection into d_model-wide chunks and
        # views each one as (n_heads, head_dim). Reject an inconsistent
        # config here with a readable message instead of letting the first
        # forward pass die on an opaque view() shape error.
        if config.n_heads * config.head_dim != config.d_model:
            raise ValueError(
                f"d_model ({config.d_model}) must equal n_heads * head_dim "
                f"({config.n_heads} * {config.head_dim})"
            )

        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.d_model = config.d_model
        self.dropout = config.dropout

        # Q, K and V share one weight matrix: a single matmul yields
        # (B, T, 3 * d_model), which forward() then splits three ways.
        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)

        # Output projection applied after the heads are merged.
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)

        # Dropout probability handed to scaled_dot_product_attention.
        # Same value as self.dropout; both attributes are kept so existing
        # readers of either name keep working.
        self.attn_dropout = config.dropout

        # Source of the cos/sin tables used by apply_rope().
        # NOTE(review): presumably RoPECache registers its tables as buffers
        # so they follow .to(device) — confirm in model/rope.py.
        self.rope = RoPECache(
            head_dim=config.head_dim,
            max_seq_len=config.context_length,
            theta=config.rope_theta,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            x: Input of shape (B, T, d_model).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, C = x.shape  # C = d_model

        # ---- QKV projection ---------------------------------------- #
        qkv = self.qkv_proj(x)  # (B, T, 3*C)
        q, k, v = qkv.split(self.d_model, dim=-1)  # each: (B, T, C)

        # ---- Reshape to (B, n_heads, T, head_dim) ------------------ #
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # ---- Apply RoPE to Q and K --------------------------------- #
        # V is left unrotated. NOTE(review): cos/sin are presumably
        # (T, head_dim) — confirm against RoPECache.get.
        cos, sin = self.rope.get(T)
        q, k = apply_rope(q, k, cos, sin)

        # ---- Scaled dot-product attention -------------------------- #
        # is_causal=True makes sdpa apply the causal mask itself, so no
        # explicit mask tensor is passed. Dropout is switched off outside
        # training by passing 0.0.
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=True,
        )  # (B, n_heads, T, head_dim)

        # ---- Merge heads ------------------------------------------- #
        # transpose() leaves a non-contiguous tensor; view() needs
        # contiguous memory, hence the explicit contiguous().
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C)

        # ---- Output projection ------------------------------------- #
        return self.o_proj(attn_out)  # (B, T, d_model)
# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Smoke test: build the layer from the SLLM_100M config, run one random
    # batch through it, and check that the output shape equals the input shape.
    from model.config import SLLM_100M

    cfg = SLLM_100M
    attn = CausalSelfAttention(cfg)

    param_count = sum(p.numel() for p in attn.parameters())
    print(f"Attention params : {param_count / 1e6:.2f}M")

    B, T = 2, 64
    x = torch.randn(B, T, cfg.d_model)
    out = attn(x)

    print(f"Input shape : {x.shape}")
    print(f"Output shape : {out.shape}")

    expected_shape = (B, T, cfg.d_model)
    assert out.shape == expected_shape, "Shape mismatch!"
    print("PASS")