klemenk commited on
Commit
cce6654
·
verified ·
1 Parent(s): 0651ca3

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +3 -1
modeling_auristream.py CHANGED
@@ -72,7 +72,9 @@ class AuriStream(PreTrainedModel):
72
  elif isinstance(module, nn.Embedding):
73
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
74
 
75
- def forward(self, input_ids, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
 
 
76
  """
77
  Input: coch: torch.Tensor of shape (b, t)
78
  tgt_coch: torch.Tensor of shape (b, t) or None
 
72
  elif isinstance(module, nn.Embedding):
73
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
74
 
75
+ def forward(self, input_ids, tgt=None, output_logits=False,
76
+ output_hidden_states=False, return_dict=False,
77
+ up_until_layer=None, **kwargs):
78
  """
79
  Input: coch: torch.Tensor of shape (b, t)
80
  tgt_coch: torch.Tensor of shape (b, t) or None