Update modeling_mambavision.py

#2
by jeongminl - opened
Files changed (1) hide show
  1. modeling_mambavision.py +1 -1
modeling_mambavision.py CHANGED
@@ -778,6 +778,6 @@ class MambaVisionModelForImageClassification(PreTrainedModel):
778
  def forward(self, tensor, labels=None):
779
  logits = self.model(tensor)
780
  if labels is not None:
781
- loss = torch.nn.cross_entropy(logits, labels)
782
  return {"loss": loss, "logits": logits}
783
  return {"logits": logits}
 
778
  def forward(self, tensor, labels=None):
779
  logits = self.model(tensor)
780
  if labels is not None:
781
+ loss = torch.nn.functional.cross_entropy(logits, labels)
782
  return {"loss": loss, "logits": logits}
783
  return {"logits": logits}