Update modeling_auristream.py
Browse files- modeling_auristream.py +3 -3
modeling_auristream.py
CHANGED
|
@@ -72,18 +72,18 @@ class AuriStream(PreTrainedModel):
|
|
| 72 |
elif isinstance(module, nn.Embedding):
|
| 73 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 74 |
|
| 75 |
-
def forward(self, input_ids
|
| 76 |
"""
|
| 77 |
Input: coch: torch.Tensor of shape (b, t)
|
| 78 |
tgt_coch: torch.Tensor of shape (b, t) or None
|
| 79 |
"""
|
| 80 |
|
| 81 |
# forward the GPT model itself
|
| 82 |
-
tok_emb = self.transformer.wte(
|
| 83 |
|
| 84 |
# if wpe exists in self.transformer, apply learned positional embedding
|
| 85 |
if hasattr(self.transformer, 'wpe'):
|
| 86 |
-
pos = torch.arange(0,
|
| 87 |
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
|
| 88 |
x = self.transformer.drop(tok_emb + pos_emb)
|
| 89 |
else:
|
|
|
|
| 72 |
elif isinstance(module, nn.Embedding):
|
| 73 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 74 |
|
| 75 |
+
def forward(self, input_ids, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
|
| 76 |
"""
|
| 77 |
Input: coch: torch.Tensor of shape (b, t)
|
| 78 |
tgt_coch: torch.Tensor of shape (b, t) or None
|
| 79 |
"""
|
| 80 |
|
| 81 |
# forward the GPT model itself
|
| 82 |
+
tok_emb = self.transformer.wte(input_ids) # token embeddings of shape (b, t, n_embd)
|
| 83 |
|
| 84 |
# if wpe exists in self.transformer, apply learned positional embedding
|
| 85 |
if hasattr(self.transformer, 'wpe'):
|
| 86 |
+
pos = torch.arange(0, input_ids.size(1), dtype=torch.long, device=input_ids.device)
|
| 87 |
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
|
| 88 |
x = self.transformer.drop(tok_emb + pos_emb)
|
| 89 |
else:
|