Update modeling_stldec.py
Browse files- modeling_stldec.py +1 -1
modeling_stldec.py
CHANGED
|
@@ -2139,7 +2139,7 @@ class STLForCausalLM(STLModel, GenerationMixin):
|
|
| 2139 |
loss = None
|
| 2140 |
if labels is not None:
|
| 2141 |
labels = labels.to(logits.device)
|
| 2142 |
-
loss_fct = CrossEntropyLoss()
|
| 2143 |
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
| 2144 |
|
| 2145 |
if not return_dict:
|
|
|
|
| 2139 |
loss = None
|
| 2140 |
if labels is not None:
|
| 2141 |
labels = labels.to(logits.device)
|
| 2142 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 2143 |
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
| 2144 |
|
| 2145 |
if not return_dict:
|