Update modeling_auristream.py
Browse files- modeling_auristream.py +3 -1
modeling_auristream.py
CHANGED
|
@@ -72,7 +72,9 @@ class AuriStream(PreTrainedModel):
|
|
| 72 |
elif isinstance(module, nn.Embedding):
|
| 73 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 74 |
|
| 75 |
-
def forward(self, input_ids, tgt=None, output_logits=False,
|
|
|
|
|
|
|
| 76 |
"""
|
| 77 |
Input: coch: torch.Tensor of shape (b, t)
|
| 78 |
tgt_coch: torch.Tensor of shape (b, t) or None
|
|
|
|
| 72 |
elif isinstance(module, nn.Embedding):
|
| 73 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 74 |
|
| 75 |
+
def forward(self, input_ids, tgt=None, output_logits=False,
|
| 76 |
+
output_hidden_states=False, return_dict=False,
|
| 77 |
+
up_until_layer=None, **kwargs):
|
| 78 |
"""
|
| 79 |
Input: coch: torch.Tensor of shape (b, t)
|
| 80 |
tgt_coch: torch.Tensor of shape (b, t) or None
|