Update rope.py
Browse files
rope.py
CHANGED
|
@@ -57,4 +57,4 @@ def apply_rotary_emb(x, freqs_cis, position_ids, num_heads, rot_dim=None, interl
|
|
| 57 |
yi = xr * sin + xi * cos
|
| 58 |
y = torch.stack((yr, yi), dim=-1).flatten(-2) # (B,H,T,rd)
|
| 59 |
|
| 60 |
-
return torch.cat([y.to(x.dtype), x_pass], dim=-1)
|
|
|
|
| 57 |
yi = xr * sin + xi * cos
|
| 58 |
y = torch.stack((yr, yi), dim=-1).flatten(-2) # (B,H,T,rd)
|
| 59 |
|
| 60 |
+
return torch.cat([y.to(x.dtype), x_pass], dim=-1)
|