Marmik commited on
Commit
ac87ecf
·
verified ·
1 Parent(s): 054320c

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_tiny_mixtral.py +2 -1
modeling_tiny_mixtral.py CHANGED
@@ -190,7 +190,8 @@ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str)
190
  x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
191
  x_out=x_out.reshape(*x.shape)
192
 
193
- return x_out.type_as(x).to(device)
 
194
 
195
 
196
 
 
190
  x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
191
  x_out=x_out.reshape(*x.shape)
192
 
193
+ # Keep the output on the same device as the input, not the device parameter
194
+ return x_out.type_as(x)
195
 
196
 
197