Commit
·
009ff4d
1
Parent(s):
800682b
fix
Browse files
attn.py
CHANGED
|
@@ -37,6 +37,7 @@ def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
|
|
| 37 |
slopes = slopes_power_of_two + extra_slopes_trunc
|
| 38 |
slopes = torch.tensor(slopes, device=self.device, dtype=torch.float32)
|
| 39 |
slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
|
|
|
|
| 40 |
return slopes
|
| 41 |
|
| 42 |
|
|
|
|
| 37 |
slopes = slopes_power_of_two + extra_slopes_trunc
|
| 38 |
slopes = torch.tensor(slopes, device=self.device, dtype=torch.float32)
|
| 39 |
slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
|
| 40 |
+
slopes = slopes.to(torch.float32)
|
| 41 |
return slopes
|
| 42 |
|
| 43 |
|