Commit ·
1177ff0
1
Parent(s): 1f5f496
fix
Browse files- __pycache__/layers.cpython-312.pyc +0 -0
- convolve.py +1 -3
- stu.py +1 -1
__pycache__/layers.cpython-312.pyc
CHANGED
|
Binary files a/__pycache__/layers.cpython-312.pyc and b/__pycache__/layers.cpython-312.pyc differ
|
|
|
convolve.py
CHANGED
|
@@ -41,8 +41,6 @@ def flash_convolve(
|
|
| 41 |
u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True,
|
| 42 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 43 |
dtype = u.dtype # Store the original dtype
|
| 44 |
-
u = u.to(torch.float32)
|
| 45 |
-
v = v.to(torch.float32)
|
| 46 |
|
| 47 |
bsz, seq_len, d_in = u.shape
|
| 48 |
_, K = v.shape
|
|
@@ -50,7 +48,7 @@ def flash_convolve(
|
|
| 50 |
padded_len = nearest_power_of_two(seq_len, round_up=True)
|
| 51 |
pad_len = padded_len - seq_len
|
| 52 |
|
| 53 |
-
sgn = torch.full((1, 1, padded_len), 1, device=u.device, dtype=
|
| 54 |
sgn[:, :, 1::2] = -1
|
| 55 |
|
| 56 |
if use_approx:
|
|
|
|
| 41 |
u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True,
|
| 42 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 43 |
dtype = u.dtype # Store the original dtype
|
|
|
|
|
|
|
| 44 |
|
| 45 |
bsz, seq_len, d_in = u.shape
|
| 46 |
_, K = v.shape
|
|
|
|
| 48 |
padded_len = nearest_power_of_two(seq_len, round_up=True)
|
| 49 |
pad_len = padded_len - seq_len
|
| 50 |
|
| 51 |
+
sgn = torch.full((1, 1, padded_len), 1, device=u.device, dtype=dtype)
|
| 52 |
sgn[:, :, 1::2] = -1
|
| 53 |
|
| 54 |
if use_approx:
|
stu.py
CHANGED
|
@@ -30,7 +30,7 @@ class STU(nn.Module):
|
|
| 30 |
self.use_hankel_L = config.use_hankel_L
|
| 31 |
self.use_approx = config.use_approx
|
| 32 |
self.flash_fft = (
|
| 33 |
-
FlashFFTConv(self.n, dtype=
|
| 34 |
if config.use_flash_fft and flash_fft_available
|
| 35 |
else None
|
| 36 |
)
|
|
|
|
| 30 |
self.use_hankel_L = config.use_hankel_L
|
| 31 |
self.use_approx = config.use_approx
|
| 32 |
self.flash_fft = (
|
| 33 |
+
FlashFFTConv(self.n, dtype=torch_dtype)
|
| 34 |
if config.use_flash_fft and flash_fft_available
|
| 35 |
else None
|
| 36 |
)
|