Vjeong Claude Sonnet 4.6 committed on
Commit
331cfcd
·
1 Parent(s): baf4768

Fix dtype mismatch in RoPE cos/sin for mixed precision training


Cast cos/sin buffers (float32) to input dtype in _apply_rotation
to prevent implicit upcasting when q/k are bf16/fp16.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1)
  1. llm_lab/model/rope.py +3 -2
llm_lab/model/rope.py CHANGED
```diff
@@ -93,8 +93,9 @@ class RotaryPositionalEmbedding(nn.Module):
     x_odd = x[..., 1::2]  # odd indices

     # Adjust dimensions for broadcasting: (seq_len, dim/2) → (1, 1, seq_len, dim/2)
-    cos = cos.unsqueeze(0).unsqueeze(0)
-    sin = sin.unsqueeze(0).unsqueeze(0)
+    # Cast to input dtype to support bf16/fp16 mixed precision training
+    cos = cos.unsqueeze(0).unsqueeze(0).to(x.dtype)
+    sin = sin.unsqueeze(0).unsqueeze(0).to(x.dtype)

     # Apply rotation
     rotated_even = x_even * cos - x_odd * sin
```
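For context, a minimal self-contained sketch of how this fix behaves end to end. The module below is a hypothetical reconstruction (only `_apply_rotation` and the `cos`/`sin` buffers appear in the diff; the `__init__` frequencies, `max_seq_len`, and the odd-half rotation are assumptions following the standard RoPE formulation). The key point matches the commit: the buffers are registered in float32 for precision, and casting them to `x.dtype` at use time keeps the output in bf16/fp16 instead of silently upcasting the multiplication to float32.

```python
import torch
import torch.nn as nn


class RotaryPositionalEmbedding(nn.Module):
    """Minimal RoPE sketch illustrating the dtype-cast fix (reconstruction)."""

    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        # Standard RoPE frequency table, kept in float32 for precision.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.outer(t, inv_freq)  # (max_seq_len, dim/2)
        self.register_buffer("cos", freqs.cos(), persistent=False)
        self.register_buffer("sin", freqs.sin(), persistent=False)

    def _apply_rotation(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[-2]
        x_even = x[..., 0::2]  # even indices
        x_odd = x[..., 1::2]   # odd indices

        # Adjust dimensions for broadcasting: (seq_len, dim/2) → (1, 1, seq_len, dim/2)
        # Cast to input dtype to support bf16/fp16 mixed precision training;
        # without .to(x.dtype), float32 buffers would upcast the product.
        cos = self.cos[:seq_len].unsqueeze(0).unsqueeze(0).to(x.dtype)
        sin = self.sin[:seq_len].unsqueeze(0).unsqueeze(0).to(x.dtype)

        # Apply rotation
        rotated_even = x_even * cos - x_odd * sin
        rotated_odd = x_even * sin + x_odd * cos

        out = torch.empty_like(x)
        out[..., 0::2] = rotated_even
        out[..., 1::2] = rotated_odd
        return out


rope = RotaryPositionalEmbedding(dim=64)
q = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)  # (batch, heads, seq, dim)
out = rope._apply_rotation(q)
print(out.dtype)  # torch.bfloat16 — no implicit upcast to float32
```

Keeping the buffers in float32 while casting at use time is a common pattern: the angle table retains full precision across long sequences, and only the final broadcasted multiply runs in the reduced precision of the activations.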