yagizdevre committed on
Commit
4f58476
·
1 Parent(s): d75d7ea

Patch Fix

Browse files
Files changed (1) hide show
  1. attn.py +3 -2
attn.py CHANGED
@@ -52,14 +52,15 @@ class Attention(nn.Module):
52
  # Get slopes for the nearest power of two
53
  n = nearest_power_of_two(n_heads, round_up=False)
54
  slopes_power_of_two = self._generate_slopes(n)
55
-
56
  # Generate extra slopes
57
  extra_slopes = self._generate_slopes(2 * n)
58
  extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
59
  slopes = slopes_power_of_two + extra_slopes_trunc
60
  slopes = torch.tensor(slopes, device=self.device)
61
  slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
62
- return slopes
 
63
 
64
  def forward(self, x):
65
  bsz, seq_len, d_in = x.size()
 
52
  # Get slopes for the nearest power of two
53
  n = nearest_power_of_two(n_heads, round_up=False)
54
  slopes_power_of_two = self._generate_slopes(n)
55
+
56
  # Generate extra slopes
57
  extra_slopes = self._generate_slopes(2 * n)
58
  extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
59
  slopes = slopes_power_of_two + extra_slopes_trunc
60
  slopes = torch.tensor(slopes, device=self.device)
61
  slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
62
+ return slopes.to(torch.float32) # Ensure slopes are in float32
63
+
64
 
65
  def forward(self, x):
66
  bsz, seq_len, d_in = x.size()