Update model.py
Browse files
model.py
CHANGED
|
@@ -355,7 +355,7 @@ class StripedHyena(nn.Module):
|
|
| 355 |
self.gradient_checkpointing = False
|
| 356 |
self._gradient_checkpointing_func = None
|
| 357 |
|
| 358 |
-
def forward(self,
|
| 359 |
L = x.shape[1]
|
| 360 |
x = self.embedding_layer.embed(x)
|
| 361 |
if inference_params_dict is not None:
|
|
@@ -370,7 +370,7 @@ class StripedHyena(nn.Module):
|
|
| 370 |
x = self.unembed.unembed(x)
|
| 371 |
return x, inference_params_dict_out
|
| 372 |
|
| 373 |
-
def stateful_forward(self,
|
| 374 |
for block_idx, block in enumerate(self.blocks):
|
| 375 |
block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
|
| 376 |
inference_params = inference_params_dict[block_name]
|
|
@@ -378,7 +378,7 @@ class StripedHyena(nn.Module):
|
|
| 378 |
|
| 379 |
return x, inference_params_dict
|
| 380 |
|
| 381 |
-
def stateless_forward(self,
|
| 382 |
if type(padding_mask) == torch.Tensor:
|
| 383 |
x = x * padding_mask[..., None]
|
| 384 |
|
|
|
|
| 355 |
self.gradient_checkpointing = False
|
| 356 |
self._gradient_checkpointing_func = None
|
| 357 |
|
| 358 |
+
def forward(self, input_ids, inference_params_dict=None, padding_mask=None):
|
| 359 |
L = x.shape[1]
|
| 360 |
x = self.embedding_layer.embed(x)
|
| 361 |
if inference_params_dict is not None:
|
|
|
|
| 370 |
x = self.unembed.unembed(x)
|
| 371 |
return x, inference_params_dict_out
|
| 372 |
|
| 373 |
+
def stateful_forward(self, input_ids, inference_params_dict=None):
|
| 374 |
for block_idx, block in enumerate(self.blocks):
|
| 375 |
block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
|
| 376 |
inference_params = inference_params_dict[block_name]
|
|
|
|
| 378 |
|
| 379 |
return x, inference_params_dict
|
| 380 |
|
| 381 |
+
def stateless_forward(self, input_ids, padding_mask=None):
|
| 382 |
if type(padding_mask) == torch.Tensor:
|
| 383 |
x = x * padding_mask[..., None]
|
| 384 |
|