Fix device and type in RotaryPosinionalEmbeddings
Browse files- encoder.py +2 -2
encoder.py
CHANGED
|
@@ -354,9 +354,9 @@ class RotaryPositionalEmbedding(PositionalEncoding):
|
|
| 354 |
return None
|
| 355 |
positions = torch.arange(0, length, dtype=torch.float32, device=device)
|
| 356 |
inv_freq = 1.0 / (
|
| 357 |
-
self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
|
| 358 |
)
|
| 359 |
-
t = torch.arange(length, device=positions.device
|
| 360 |
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
| 361 |
emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
|
| 362 |
return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])
|
|
|
|
| 354 |
return None
|
| 355 |
positions = torch.arange(0, length, dtype=torch.float32, device=device)
|
| 356 |
inv_freq = 1.0 / (
|
| 357 |
+
self.base ** (torch.arange(0, self.dim, 2, device=positions.device).float() / self.dim)
|
| 358 |
)
|
| 359 |
+
t = torch.arange(length, device=positions.device, dtype=inv_freq.dtype)
|
| 360 |
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
| 361 |
emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
|
| 362 |
return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])
|