KitsuVp commited on
Commit
817d821
·
verified ·
1 Parent(s): f5ba41e

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +2 -1
modeling_neollm.py CHANGED
@@ -588,7 +588,8 @@ class NeoLLMRotaryEmbedding(nn.Module):
588
  cos = emb.cos() * self.attention_scaling
589
  sin = emb.sin() * self.attention_scaling
590
 
591
- return cos, sin
 
592
 
593
 
594
  def rotate_half(x):
 
588
  cos = emb.cos() * self.attention_scaling
589
  sin = emb.sin() * self.attention_scaling
590
 
591
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
592
+
593
 
594
 
595
  def rotate_half(x):