Update modeling_mambavision.py
#2
by
jeongminl
- opened
- modeling_mambavision.py +1 -1
modeling_mambavision.py
CHANGED
|
@@ -778,6 +778,6 @@ class MambaVisionModelForImageClassification(PreTrainedModel):
|
|
| 778 |
def forward(self, tensor, labels=None):
|
| 779 |
logits = self.model(tensor)
|
| 780 |
if labels is not None:
|
| 781 |
-
loss = torch.nn.cross_entropy(logits, labels)
|
| 782 |
return {"loss": loss, "logits": logits}
|
| 783 |
return {"logits": logits}
|
|
|
|
| 778 |
def forward(self, tensor, labels=None):
|
| 779 |
logits = self.model(tensor)
|
| 780 |
if labels is not None:
|
| 781 |
+
loss = torch.nn.functional.cross_entropy(logits, labels)
|
| 782 |
return {"loss": loss, "logits": logits}
|
| 783 |
return {"logits": logits}
|