Vjeong Claude Sonnet 4.6 committed on
Commit
e072b51
·
1 Parent(s): fac7da2

Replace F.scaled_dot_product_attention with explicit implementation

Browse files

Expand the single library call into 5 visible steps (scale, mask,
softmax, dropout, value-multiply) so learners can inspect each stage
of Scaled Dot-Product Attention directly in the source code.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. llm_lab/model/attention.py +31 -7
llm_lab/model/attention.py CHANGED
@@ -1,5 +1,6 @@
1
  """Grouped Query Attention (GQA)."""
2
 
 
3
  from typing import Optional
4
 
5
  import torch
@@ -100,13 +101,36 @@ class GroupedQueryAttention(nn.Module):
100
  # ──────────────────────────────────────────────
101
  # Step 4: Scaled Dot-Product Attention
102
  # ──────────────────────────────────────────────
103
- # Uses PyTorch >= 2.0's optimized implementation (Flash Attention applied automatically)
104
- attn_out = F.scaled_dot_product_attention(
105
- q, k, v,
106
- attn_mask=mask,
107
- dropout_p=self.config.dropout if self.training else 0.0,
108
- is_causal=(mask is None), # apply automatic causal masking when no mask is provided
109
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  # β†’ (batch_size, num_heads, seq_len, head_dim)
111
 
112
  # ──────────────────────────────────────────────
 
1
  """Grouped Query Attention (GQA)."""
2
 
3
+ import math
4
  from typing import Optional
5
 
6
  import torch
 
101
  # ──────────────────────────────────────────────
102
  # Step 4: Scaled Dot-Product Attention
103
  # ──────────────────────────────────────────────
104
+ # Step 4-1: Compute scaled attention scores
105
+ # Q @ K^T β†’ (batch_size, num_heads, seq_len, seq_len)
106
+ # Dividing by √d_k prevents dot products from growing too large,
107
+ # which would push softmax into regions with vanishing gradients.
108
+ scale = math.sqrt(self.head_dim)
109
+ attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
110
+
111
+ # Step 4-2: Apply mask
112
+ # Causal mask fills future positions with -inf so they become 0 after softmax,
113
+ # ensuring the model can only attend to past and current tokens (autoregressive).
114
+ if mask is not None:
115
+ attn_scores = attn_scores + mask
116
+ else:
117
+ causal_mask = torch.triu(
118
+ torch.full((seq_len, seq_len), float("-inf"), device=q.device, dtype=q.dtype),
119
+ diagonal=1,
120
+ )
121
+ attn_scores = attn_scores + causal_mask
122
+
123
+ # Step 4-3: Softmax β†’ attention weights (probability distribution over keys)
124
+ attn_weights = F.softmax(attn_scores, dim=-1)
125
+
126
+ # Step 4-4: Dropout (only during training)
127
+ # Randomly zeroing some attention weights acts as regularization,
128
+ # preventing the model from relying too heavily on specific token relationships.
129
+ if self.training and self.config.dropout > 0.0:
130
+ attn_weights = F.dropout(attn_weights, p=self.config.dropout)
131
+
132
+ # Step 4-5: Weighted sum of values
133
+ attn_out = torch.matmul(attn_weights, v)
134
  # β†’ (batch_size, num_heads, seq_len, head_dim)
135
 
136
  # ──────────────────────────────────────────────