Update modeling_auristream.py
Browse files- modeling_auristream.py +1 -1
modeling_auristream.py
CHANGED
|
@@ -72,7 +72,7 @@ class AuriStream(PreTrainedModel):
|
|
| 72 |
elif isinstance(module, nn.Embedding):
|
| 73 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 74 |
|
| 75 |
-
def forward(self, seq, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
|
| 76 |
"""
|
| 77 |
Input: coch: torch.Tensor of shape (b, t)
|
| 78 |
tgt_coch: torch.Tensor of shape (b, t) or None
|
|
|
|
| 72 |
elif isinstance(module, nn.Embedding):
|
| 73 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 74 |
|
| 75 |
+
def forward(self, input_ids=seq, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
|
| 76 |
"""
|
| 77 |
Input: coch: torch.Tensor of shape (b, t)
|
| 78 |
tgt_coch: torch.Tensor of shape (b, t) or None
|