Commit 4f58476 · 1 Parent(s): d75d7ea
Patch Fix
Browse files
attn.py CHANGED
|
@@ -52,14 +52,15 @@ class Attention(nn.Module):
|
|
| 52 |
# Get slopes for the nearest power of two
|
| 53 |
n = nearest_power_of_two(n_heads, round_up=False)
|
| 54 |
slopes_power_of_two = self._generate_slopes(n)
|
| 55 |
-
|
| 56 |
# Generate extra slopes
|
| 57 |
extra_slopes = self._generate_slopes(2 * n)
|
| 58 |
extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
|
| 59 |
slopes = slopes_power_of_two + extra_slopes_trunc
|
| 60 |
slopes = torch.tensor(slopes, device=self.device)
|
| 61 |
slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
|
| 62 |
-
return slopes
|
|
|
|
| 63 |
|
| 64 |
def forward(self, x):
|
| 65 |
bsz, seq_len, d_in = x.size()
|
|
|
|
| 52 |
# Get slopes for the nearest power of two
|
| 53 |
n = nearest_power_of_two(n_heads, round_up=False)
|
| 54 |
slopes_power_of_two = self._generate_slopes(n)
|
| 55 |
+
|
| 56 |
# Generate extra slopes
|
| 57 |
extra_slopes = self._generate_slopes(2 * n)
|
| 58 |
extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
|
| 59 |
slopes = slopes_power_of_two + extra_slopes_trunc
|
| 60 |
slopes = torch.tensor(slopes, device=self.device)
|
| 61 |
slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
|
| 62 |
+
return slopes.to(torch.float32) # Ensure slopes are in float32
|
| 63 |
+
|
| 64 |
|
| 65 |
def forward(self, x):
|
| 66 |
bsz, seq_len, d_in = x.size()
|