saracandu commited on
Commit
107cfff
·
verified ·
1 Parent(s): b5b542f

Update modeling_stldec.py

Browse files
Files changed (1) hide show
  1. modeling_stldec.py +1 -1
modeling_stldec.py CHANGED
@@ -2139,7 +2139,7 @@ class STLForCausalLM(STLModel, GenerationMixin):
2139
  loss = None
2140
  if labels is not None:
2141
  labels = labels.to(logits.device)
2142
- loss_fct = CrossEntropyLoss()
2143
  loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
2144
 
2145
  if not return_dict:
 
2139
  loss = None
2140
  if labels is not None:
2141
  labels = labels.to(logits.device)
2142
+ loss_fct = nn.CrossEntropyLoss()
2143
  loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
2144
 
2145
  if not return_dict: