Update model.py
Browse files
model.py
CHANGED
|
@@ -350,6 +350,8 @@ class StripedHyena(nn.Module):
|
|
| 350 |
self.blocks = nn.ModuleList(
|
| 351 |
get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
|
| 352 |
)
|
|
|
|
|
|
|
| 353 |
|
| 354 |
def forward(self, x, inference_params_dict=None, padding_mask=None):
|
| 355 |
L = x.shape[1]
|
|
|
|
| 350 |
self.blocks = nn.ModuleList(
|
| 351 |
get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
|
| 352 |
)
|
| 353 |
+
self.gradient_checkpointing = False
|
| 354 |
+
self._gradient_checkpointing_func = None
|
| 355 |
|
| 356 |
def forward(self, x, inference_params_dict=None, padding_mask=None):
|
| 357 |
L = x.shape[1]
|