Commit ·
3511918
1
Parent(s): ba1271a
added input_ids to BharatAI's forward method
Browse files
model.py
CHANGED
|
@@ -116,9 +116,9 @@ class BharatAI(PreTrainedModel):
|
|
| 116 |
elif isinstance(module, nn.Embedding):
|
| 117 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 118 |
|
| 119 |
-
def forward(self, index, targets=None):
|
| 120 |
B, T = index.shape
|
| 121 |
-
|
| 122 |
|
| 123 |
# idx and targets are both (B,T) tensor of integers
|
| 124 |
tok_emb = self.token_embedding_table(index) # (B,T,C)
|
|
|
|
| 116 |
elif isinstance(module, nn.Embedding):
|
| 117 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 118 |
|
| 119 |
+
def forward(self, input_ids, index, targets=None):
|
| 120 |
B, T = index.shape
|
| 121 |
+
x = input_ids
|
| 122 |
|
| 123 |
# idx and targets are both (B,T) tensor of integers
|
| 124 |
tok_emb = self.token_embedding_table(index) # (B,T,C)
|