yagizdevre commited on
Commit
009ff4d
·
1 Parent(s): 800682b
Files changed (1) hide show
  1. attn.py +1 -0
attn.py CHANGED
@@ -37,6 +37,7 @@ def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
37
  slopes = slopes_power_of_two + extra_slopes_trunc
38
  slopes = torch.tensor(slopes, device=self.device, dtype=torch.float32)
39
  slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
 
40
  return slopes
41
 
42
 
 
37
  slopes = slopes_power_of_two + extra_slopes_trunc
38
  slopes = torch.tensor(slopes, device=self.device, dtype=torch.float32)
39
  slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
40
+ slopes = slopes.to(torch.float32)
41
  return slopes
42
 
43