klemenk commited on
Commit
0651ca3
·
verified ·
1 Parent(s): a2afe6a

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +3 -3
modeling_auristream.py CHANGED
@@ -72,18 +72,18 @@ class AuriStream(PreTrainedModel):
72
  elif isinstance(module, nn.Embedding):
73
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
74
 
75
- def forward(self, input_ids=seq, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
76
  """
77
  Input: coch: torch.Tensor of shape (b, t)
78
  tgt_coch: torch.Tensor of shape (b, t) or None
79
  """
80
 
81
  # forward the GPT model itself
82
- tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
83
 
84
  # if wpe exists in self.transformer apply learned positional embedding
85
  if hasattr(self.transformer, 'wpe'):
86
- pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
87
  pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
88
  x = self.transformer.drop(tok_emb + pos_emb)
89
  else:
 
72
  elif isinstance(module, nn.Embedding):
73
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
74
 
75
+ def forward(self, input_ids, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
76
  """
77
  Input: coch: torch.Tensor of shape (b, t)
78
  tgt_coch: torch.Tensor of shape (b, t) or None
79
  """
80
 
81
  # forward the GPT model itself
82
+ tok_emb = self.transformer.wte(input_ids) # token embeddings of shape (b, t, n_embd)
83
 
84
  # if wpe exists in self.transformer apply learned positional embedding
85
  if hasattr(self.transformer, 'wpe'):
86
+ pos = torch.arange(0, input_ids.size(1), dtype=torch.long, device=input_ids.device)
87
  pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
88
  x = self.transformer.drop(tok_emb + pos_emb)
89
  else: