Fix dtype mismatch in RoPE cos/sin for mixed precision training
Cast cos/sin buffers (float32) to input dtype in _apply_rotation
to prevent implicit upcasting when q/k are bf16/fp16.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- llm_lab/model/rope.py +3 -2
llm_lab/model/rope.py
CHANGED
|
@@ -93,8 +93,9 @@ class RotaryPositionalEmbedding(nn.Module):
|
|
| 93 |
x_odd = x[..., 1::2] # odd indices
|
| 94 |
|
| 95 |
# Adjust dimensions for broadcasting: (seq_len, dim/2) → (1, 1, seq_len, dim/2)
|
| 96 |
-
|
| 97 |
-
|
|
|
|
| 98 |
|
| 99 |
# Apply rotation
|
| 100 |
rotated_even = x_even * cos - x_odd * sin
|
|
|
|
| 93 |
x_odd = x[..., 1::2] # odd indices
|
| 94 |
|
| 95 |
# Adjust dimensions for broadcasting: (seq_len, dim/2) → (1, 1, seq_len, dim/2)
|
| 96 |
+
# Cast to input dtype to support bf16/fp16 mixed precision training
|
| 97 |
+
cos = cos.unsqueeze(0).unsqueeze(0).to(x.dtype)
|
| 98 |
+
sin = sin.unsqueeze(0).unsqueeze(0).to(x.dtype)
|
| 99 |
|
| 100 |
# Apply rotation
|
| 101 |
rotated_even = x_even * cos - x_odd * sin
|