Replace F.scaled_dot_product_attention with explicit implementation
Expand the single library call into 5 visible steps (scale, mask,
softmax, dropout, value-multiply) so learners can inspect each stage
of Scaled Dot-Product Attention directly in the source code.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- llm_lab/model/attention.py +31 -7
llm_lab/model/attention.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
"""Grouped Query Attention (GQA)."""
|
| 2 |
|
|
|
|
| 3 |
from typing import Optional
|
| 4 |
|
| 5 |
import torch
|
|
@@ -100,13 +101,36 @@ class GroupedQueryAttention(nn.Module):
|
|
| 100 |
# ──────────────────────────────────────────────
|
| 101 |
# Step 4: Scaled Dot-Product Attention
|
| 102 |
# ──────────────────────────────────────────────
|
| 103 |
-
#
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
# → (batch_size, num_heads, seq_len, head_dim)
|
| 111 |
|
| 112 |
# ──────────────────────────────────────────────
|
|
|
|
| 1 |
"""Grouped Query Attention (GQA)."""
|
| 2 |
|
| 3 |
+
import math
|
| 4 |
from typing import Optional
|
| 5 |
|
| 6 |
import torch
|
|
|
|
| 101 |
# ──────────────────────────────────────────────
|
| 102 |
# Step 4: Scaled Dot-Product Attention
|
| 103 |
# ──────────────────────────────────────────────
|
| 104 |
+
# Step 4-1: Compute scaled attention scores
|
| 105 |
+
# Q @ K^T → (batch_size, num_heads, seq_len, seq_len)
|
| 106 |
+
# Dividing by √d_k prevents dot products from growing too large,
|
| 107 |
+
# which would push softmax into regions with vanishing gradients.
|
| 108 |
+
scale = math.sqrt(self.head_dim)
|
| 109 |
+
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
|
| 110 |
+
|
| 111 |
+
# Step 4-2: Apply mask
|
| 112 |
+
# Causal mask fills future positions with -inf so they become 0 after softmax,
|
| 113 |
+
# ensuring the model can only attend to past and current tokens (autoregressive).
|
| 114 |
+
if mask is not None:
|
| 115 |
+
attn_scores = attn_scores + mask
|
| 116 |
+
else:
|
| 117 |
+
causal_mask = torch.triu(
|
| 118 |
+
torch.full((seq_len, seq_len), float("-inf"), device=q.device, dtype=q.dtype),
|
| 119 |
+
diagonal=1,
|
| 120 |
+
)
|
| 121 |
+
attn_scores = attn_scores + causal_mask
|
| 122 |
+
|
| 123 |
+
# Step 4-3: Softmax β attention weights (probability distribution over keys)
|
| 124 |
+
attn_weights = F.softmax(attn_scores, dim=-1)
|
| 125 |
+
|
| 126 |
+
# Step 4-4: Dropout (only during training)
|
| 127 |
+
# Randomly zeroing some attention weights acts as regularization,
|
| 128 |
+
# preventing the model from relying too heavily on specific token relationships.
|
| 129 |
+
if self.training and self.config.dropout > 0.0:
|
| 130 |
+
attn_weights = F.dropout(attn_weights, p=self.config.dropout)
|
| 131 |
+
|
| 132 |
+
# Step 4-5: Weighted sum of values
|
| 133 |
+
attn_out = torch.matmul(attn_weights, v)
|
| 134 |
# → (batch_size, num_heads, seq_len, head_dim)
|
| 135 |
|
| 136 |
# ──────────────────────────────────────────────
|