Update model.py
Browse files
model.py
CHANGED
|
@@ -22,7 +22,7 @@ try:
|
|
| 22 |
except ImportError:
|
| 23 |
"could not import swap_mha_rope from positional_embeddings.py"
|
| 24 |
|
| 25 |
-
from flashfftconv import
|
| 26 |
|
| 27 |
# dummy import to force huggingface to bundle the tokenizer
|
| 28 |
from .tokenizer import ByteTokenizer
|
|
@@ -122,7 +122,7 @@ class ParallelHyenaFilter(nn.Module):
|
|
| 122 |
self.data_dtype = None
|
| 123 |
|
| 124 |
if self.use_flash_depthwise:
|
| 125 |
-
self.fir_fn =
|
| 126 |
channels=3 * self.hidden_size,
|
| 127 |
kernel_size=self.short_filter_length,
|
| 128 |
padding=self.short_filter_length - 1,
|
|
|
|
| 22 |
except ImportError:
|
| 23 |
"could not import swap_mha_rope from positional_embeddings.py"
|
| 24 |
|
| 25 |
+
from flashfftconv import FlashDepthWiseConv1d
|
| 26 |
|
| 27 |
# dummy import to force huggingface to bundle the tokenizer
|
| 28 |
from .tokenizer import ByteTokenizer
|
|
|
|
| 122 |
self.data_dtype = None
|
| 123 |
|
| 124 |
if self.use_flash_depthwise:
|
| 125 |
+
self.fir_fn = FlashDepthWiseConv1d(
|
| 126 |
channels=3 * self.hidden_size,
|
| 127 |
kernel_size=self.short_filter_length,
|
| 128 |
padding=self.short_filter_length - 1,
|