Commit ·
f223616
1
Parent(s): 111f34f
fix
Browse files
attn.py
CHANGED
|
@@ -35,7 +35,7 @@ def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
|
|
| 35 |
extra_slopes = self._generate_slopes(2 * n)
|
| 36 |
extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
|
| 37 |
slopes = slopes_power_of_two + extra_slopes_trunc
|
| 38 |
-
slopes = torch.tensor(slopes, device=self.device)
|
| 39 |
slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
|
| 40 |
return slopes
|
| 41 |
|
|
|
|
| 35 |
extra_slopes = self._generate_slopes(2 * n)
|
| 36 |
extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
|
| 37 |
slopes = slopes_power_of_two + extra_slopes_trunc
|
| 38 |
+
slopes = torch.tensor(slopes, device=self.device, dtype=torch.float32)
|
| 39 |
slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
|
| 40 |
return slopes
|
| 41 |
|