# LLM-1B-Lab / llm_lab/model/attention.py
# Commit e072b51 (Vjeong): Replace F.scaled_dot_product_attention with explicit implementation
"""Grouped Query Attention (GQA)."""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from llm_lab.config import ModelConfig
from .rope import RotaryPositionalEmbedding
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention: multi-head attention with shared K/V heads.

    Rather than giving every query head its own key/value projection (MHA)
    or collapsing all of them onto a single shared pair (MQA), GQA splits
    the K/V heads into ``num_kv_heads`` groups, each serving
    ``num_kv_groups = num_heads // num_kv_heads`` query heads — a middle
    ground that keeps most of MHA's quality at a fraction of the KV memory.

    With num_heads=16 and num_kv_heads=4, query heads 0-3 read K/V head 0,
    heads 4-7 read K/V head 1, and so on.

    Scores follow the standard scaled dot-product formula:
        Attention(Q, K, V) = softmax(Q·K^T / √d_k) · V
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.num_kv_groups = config.num_kv_groups  # num_heads // num_kv_heads
        self.head_dim = config.head_dim

        q_out = config.num_heads * self.head_dim
        kv_out = config.num_kv_heads * self.head_dim
        # Queries keep the full head count; K/V use the reduced one — this
        # asymmetry is the whole point of GQA.
        self.q_proj = nn.Linear(config.hidden_dim, q_out, bias=False)
        self.k_proj = nn.Linear(config.hidden_dim, kv_out, bias=False)
        self.v_proj = nn.Linear(config.hidden_dim, kv_out, bias=False)
        # Final projection folds all head outputs back to the model width.
        self.o_proj = nn.Linear(q_out, config.hidden_dim, bias=False)

        # Rotary positional embedding, applied to Q/K inside forward().
        self.rope = RotaryPositionalEmbedding(
            dim=self.head_dim,
            max_seq_len=config.max_seq_len,
            theta=config.rope_theta,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        position_offset: int = 0,
    ) -> torch.Tensor:
        """Run grouped-query attention over ``x``.

        Args:
            x: Input activations, shape (batch_size, seq_len, hidden_dim).
            mask: Optional additive (seq_len, seq_len) mask; when omitted a
                causal mask is built on the fly.
            position_offset: Starting position for RoPE (used at inference
                time when earlier tokens have already been consumed).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).
        """
        bsz, q_len, _ = x.shape

        # Project, then split into heads: (bsz, n_heads, q_len, head_dim).
        def split_heads(t: torch.Tensor, n_heads: int) -> torch.Tensor:
            return t.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x), self.num_heads)
        k = split_heads(self.k_proj(x), self.num_kv_heads)
        v = split_heads(self.v_proj(x), self.num_kv_heads)

        # RoPE goes into Q and K only: position should shape *where* to
        # attend (Q·K), never *what* is retrieved (V).
        q, k = self.rope(q, k, position_offset)

        # Broadcast the reduced K/V heads up to the full query head count.
        if self.num_kv_groups > 1:
            k, v = self._repeat_kv(k), self._repeat_kv(v)

        # Scaled dot-product scores: (bsz, num_heads, q_len, q_len).
        # Dividing by √d_k keeps logits in a range where softmax still has
        # usable gradients.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Additive mask: caller-supplied, or an on-the-fly causal one that
        # sends future positions to -inf (zero weight after softmax), so
        # each token can only attend to itself and the past.
        if mask is None:
            mask = torch.triu(
                torch.full((q_len, q_len), float("-inf"), device=q.device, dtype=q.dtype),
                diagonal=1,
            )
        scores = scores + mask

        weights = F.softmax(scores, dim=-1)

        # Attention dropout as regularization — training mode only — so the
        # model cannot over-rely on any single token relationship.
        if self.training and self.config.dropout > 0.0:
            weights = F.dropout(weights, p=self.config.dropout)

        # Weighted sum of values: (bsz, num_heads, q_len, head_dim).
        out = torch.matmul(weights, v)

        # Re-join heads and project back to hidden_dim.
        out = out.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        return self.o_proj(out)

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Tile each K/V head ``num_kv_groups`` times along the head axis.

        (batch, num_kv_heads, seq_len, head_dim)
            -> (batch, num_heads, seq_len, head_dim)

        Each head is duplicated as a contiguous run, e.g. with groups=2:
        [kv0, kv1] -> [kv0, kv0, kv1, kv1].
        """
        # repeat_interleave along dim=1 duplicates each head consecutively,
        # producing the same layout as unsqueeze+expand+reshape.
        return torch.repeat_interleave(x, self.num_kv_groups, dim=1)