KrushiJethe commited on
Commit
3511918
·
1 Parent(s): ba1271a

added input_ids to BharatAI's forward method

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -116,9 +116,9 @@ class BharatAI(PreTrainedModel):
116
  elif isinstance(module, nn.Embedding):
117
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
118
 
119
- def forward(self, index, targets=None):
120
  B, T = index.shape
121
-
122
 
123
  # idx and targets are both (B,T) tensor of integers
124
  tok_emb = self.token_embedding_table(index) # (B,T,C)
 
116
  elif isinstance(module, nn.Embedding):
117
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
118
 
119
+ def forward(self, input_ids, index, targets=None):
120
  B, T = index.shape
121
+ x = input_ids
122
 
123
  # idx and targets are both (B,T) tensor of integers
124
  tok_emb = self.token_embedding_table(index) # (B,T,C)