klemenk commited on
Commit
a2afe6a
·
verified ·
1 Parent(s): f163227

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +1 -1
modeling_auristream.py CHANGED
@@ -72,7 +72,7 @@ class AuriStream(PreTrainedModel):
72
  elif isinstance(module, nn.Embedding):
73
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
74
 
75
- def forward(self, seq, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
76
  """
77
  Input: coch: torch.Tensor of shape (b, t)
78
  tgt_coch: torch.Tensor of shape (b, t) or None
 
72
  elif isinstance(module, nn.Embedding):
73
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
74
 
75
+ def forward(self, input_ids=seq, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
76
  """
77
  Input: coch: torch.Tensor of shape (b, t)
78
  tgt_coch: torch.Tensor of shape (b, t) or None