Hannes Kuchelmeister
commited on
Commit
·
90cf80d
1
Parent(s):
ee280e5
fix predictions to not use argmax
Browse files
models/src/models/focus_module.py
CHANGED
|
@@ -86,7 +86,7 @@ class FocusLitModule(LightningModule):
|
|
| 86 |
y = batch["focus_value"]
|
| 87 |
logits = self.forward(x)
|
| 88 |
loss = self.criterion(logits, y)
|
| 89 |
-
preds = torch.
|
| 90 |
return loss, preds, y
|
| 91 |
|
| 92 |
def training_step(self, batch: Any, batch_idx: int):
|
|
|
|
| 86 |
y = batch["focus_value"]
|
| 87 |
logits = self.forward(x)
|
| 88 |
loss = self.criterion(logits, y)
|
| 89 |
+
preds = torch.squeeze(logits)
|
| 90 |
return loss, preds, y
|
| 91 |
|
| 92 |
def training_step(self, batch: Any, batch_idx: int):
|