jaj
modelLM.py CHANGED (+6 -1)
@@ -22,10 +22,15 @@ class OBILanguageModel(PreTrainedModel):
         self.ln2 = nn.LayerNorm(config.hidden_size)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)  # Use length of SentencePiece vocab
 
-    def forward(self, idx, targets=None):
+    def forward(self, idx, attention_mask=None, targets=None):
         tok_emb = self.token_embedding_table(idx)
         pos_emb = self.position_embedding_table(torch.arange(idx.size(1), device='cpu'))
         x = tok_emb + pos_emb
+
+        # Assuming you need to add attention_mask here
+        if attention_mask is not None:
+            x *= attention_mask.unsqueeze(-1)
+
         x = self.transformer(x, x)
         x = self.ln1(x)
         x = self.ln2(x)
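What the new masking step does, as a minimal self-contained sketch (the shapes and values below are hypothetical, not taken from this repo): attention_mask.unsqueeze(-1) turns the (batch, seq_len) mask into (batch, seq_len, 1), which broadcasts over the hidden dimension and zeroes the embeddings of padded positions before they reach the transformer.

import torch

# Minimal sketch of the masking step added above; batch/seq_len/hidden_size
# are hypothetical values, not taken from this repo.
batch, seq_len, hidden_size = 1, 5, 8
x = torch.randn(batch, seq_len, hidden_size)           # stands in for tok_emb + pos_emb
attention_mask = torch.tensor([[1., 1., 1., 0., 0.]])  # 1 = real token, 0 = padding

x *= attention_mask.unsqueeze(-1)                      # (1, 5, 1) broadcasts over hidden_size
print(x[0, 3:].abs().sum())                            # tensor(0.) -- the padded rows are zeroed

One design note: zeroing the embeddings blanks the padded inputs, but nn.Transformer will still attend over those positions. If the goal is to exclude padding from attention itself, the usual mechanism is to pass src_key_padding_mask=(attention_mask == 0) to the transformer call.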