Commit
·
7019440
1
Parent(s):
3511918
Renamed `targets` to `labels` in BharatAI's forward method
Browse files
model.py
CHANGED
|
@@ -116,7 +116,7 @@ class BharatAI(PreTrainedModel):
|
|
| 116 |
elif isinstance(module, nn.Embedding):
|
| 117 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 118 |
|
| 119 |
-
def forward(self, input_ids, index, targets=None):
|
| 120 |
B, T = index.shape
|
| 121 |
x = input_ids
|
| 122 |
|
|
@@ -128,13 +128,13 @@ class BharatAI(PreTrainedModel):
|
|
| 128 |
x = self.ln_f(x) # (B,T,C)
|
| 129 |
logits = self.lm_head(x) # (B,T,vocab_size)
|
| 130 |
|
| 131 |
-
if targets is None:
|
| 132 |
loss = None
|
| 133 |
else:
|
| 134 |
B, T, C = logits.shape
|
| 135 |
logits = logits.view(B*T, C)
|
| 136 |
-
targets = targets.view(B*T)
|
| 137 |
-
loss = F.cross_entropy(logits, targets)
|
| 138 |
|
| 139 |
return logits, loss
|
| 140 |
|
|
|
|
| 116 |
elif isinstance(module, nn.Embedding):
|
| 117 |
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 118 |
|
| 119 |
+
def forward(self, input_ids, index, labels=None): #, targets
|
| 120 |
B, T = index.shape
|
| 121 |
x = input_ids
|
| 122 |
|
|
|
|
| 128 |
x = self.ln_f(x) # (B,T,C)
|
| 129 |
logits = self.lm_head(x) # (B,T,vocab_size)
|
| 130 |
|
| 131 |
+
if labels is None:
|
| 132 |
loss = None
|
| 133 |
else:
|
| 134 |
B, T, C = logits.shape
|
| 135 |
logits = logits.view(B*T, C)
|
| 136 |
+
labels = labels.view(B*T)
|
| 137 |
+
loss = F.cross_entropy(logits, labels)
|
| 138 |
|
| 139 |
return logits, loss
|
| 140 |
|