Update model.py
Browse files
model.py
CHANGED
|
@@ -343,7 +343,7 @@ class StripedHyena(nn.Module):
|
|
| 343 |
from flashfftconv import FlashFFTConv
|
| 344 |
except:
|
| 345 |
raise ImportError
|
| 346 |
-
self.flash_fft = FlashFFTConv(2 * config.
|
| 347 |
else:
|
| 348 |
self.flash_fft = None
|
| 349 |
|
|
|
|
| 343 |
from flashfftconv import FlashFFTConv
|
| 344 |
except:
|
| 345 |
raise ImportError
|
| 346 |
+
self.flash_fft = FlashFFTConv(2 * config.max_seqlen, dtype=torch.bfloat16)
|
| 347 |
else:
|
| 348 |
self.flash_fft = None
|
| 349 |
|