KrushiJethe commited on
Commit
7019440
·
1 Parent(s): 3511918

renamed targets to labels in bharatAI's forward method

Browse files
Files changed (1) hide show
  1. model.py +4 -4
model.py CHANGED
@@ -116,7 +116,7 @@ class BharatAI(PreTrainedModel):
116
  elif isinstance(module, nn.Embedding):
117
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
118
 
119
- def forward(self, input_ids, index, targets=None):
120
  B, T = index.shape
121
  x = input_ids
122
 
@@ -128,13 +128,13 @@ class BharatAI(PreTrainedModel):
128
  x = self.ln_f(x) # (B,T,C)
129
  logits = self.lm_head(x) # (B,T,vocab_size)
130
 
131
- if targets is None:
132
  loss = None
133
  else:
134
  B, T, C = logits.shape
135
  logits = logits.view(B*T, C)
136
- targets = targets.view(B*T)
137
- loss = F.cross_entropy(logits, targets)
138
 
139
  return logits, loss
140
 
 
116
  elif isinstance(module, nn.Embedding):
117
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
118
 
119
+ def forward(self, input_ids, index, labels=None): #, targets
120
  B, T = index.shape
121
  x = input_ids
122
 
 
128
  x = self.ln_f(x) # (B,T,C)
129
  logits = self.lm_head(x) # (B,T,vocab_size)
130
 
131
+ if labels is None:
132
  loss = None
133
  else:
134
  B, T, C = logits.shape
135
  logits = logits.view(B*T, C)
136
+ labels = labels.view(B*T)
137
+ loss = F.cross_entropy(logits, labels)
138
 
139
  return logits, loss
140