yagizdevre commited on
Commit
f223616
·
1 Parent(s): 111f34f
Files changed (1) hide show
  1. attn.py +1 -1
attn.py CHANGED
@@ -35,7 +35,7 @@ def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
35
  extra_slopes = self._generate_slopes(2 * n)
36
  extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
37
  slopes = slopes_power_of_two + extra_slopes_trunc
38
- slopes = torch.tensor(slopes, device=self.device)
39
  slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
40
  return slopes
41
 
 
35
  extra_slopes = self._generate_slopes(2 * n)
36
  extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
37
  slopes = slopes_power_of_two + extra_slopes_trunc
38
+ slopes = torch.tensor(slopes, device=self.device, dtype=torch.float32)
39
  slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
40
  return slopes
41