Hannes Kuchelmeister committed on
Commit ·
754b856
1
Parent(s): 90cf80d
rename metric to mae
Browse files
models/src/models/focus_module.py
CHANGED
|
@@ -71,12 +71,12 @@ class FocusLitModule(LightningModule):
|
|
| 71 |
|
| 72 |
# use separate metric instance for train, val and test step
|
| 73 |
# to ensure a proper reduction over the epoch
|
| 74 |
-
self.
|
| 75 |
-
self.
|
| 76 |
-
self.
|
| 77 |
|
| 78 |
# for logging best so far validation MAE
|
| 79 |
-
self.
|
| 80 |
|
| 81 |
def forward(self, x: torch.Tensor):
|
| 82 |
return self.model(x)
|
|
@@ -93,9 +93,9 @@ class FocusLitModule(LightningModule):
|
|
| 93 |
loss, preds, targets = self.step(batch)
|
| 94 |
|
| 95 |
# log train metrics
|
| 96 |
-
|
| 97 |
self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
| 98 |
-
self.log("train/
|
| 99 |
|
| 100 |
# we can return here dict with any tensors
|
| 101 |
# and then read it in some callback or in `training_epoch_end()` below
|
|
@@ -110,26 +110,26 @@ class FocusLitModule(LightningModule):
|
|
| 110 |
loss, preds, targets = self.step(batch)
|
| 111 |
|
| 112 |
# log val metrics
|
| 113 |
-
|
| 114 |
self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
| 115 |
-
self.log("val/
|
| 116 |
|
| 117 |
return {"loss": loss, "preds": preds, "targets": targets}
|
| 118 |
|
| 119 |
def validation_epoch_end(self, outputs: List[Any]):
|
| 120 |
-
|
| 121 |
-
self.
|
| 122 |
self.log(
|
| 123 |
-
"val/
|
| 124 |
)
|
| 125 |
|
| 126 |
def test_step(self, batch: Any, batch_idx: int):
|
| 127 |
loss, preds, targets = self.step(batch)
|
| 128 |
|
| 129 |
# log test metrics
|
| 130 |
-
|
| 131 |
self.log("test/loss", loss, on_step=False, on_epoch=True)
|
| 132 |
-
self.log("test/
|
| 133 |
|
| 134 |
return {"loss": loss, "preds": preds, "targets": targets}
|
| 135 |
|
|
@@ -138,9 +138,9 @@ class FocusLitModule(LightningModule):
|
|
| 138 |
|
| 139 |
def on_epoch_end(self):
|
| 140 |
# reset metrics at the end of every epoch
|
| 141 |
-
self.
|
| 142 |
-
self.
|
| 143 |
-
self.
|
| 144 |
|
| 145 |
def configure_optimizers(self):
|
| 146 |
"""Choose what optimizers and learning-rate schedulers.
|
|
|
|
| 71 |
|
| 72 |
# use separate metric instance for train, val and test step
|
| 73 |
# to ensure a proper reduction over the epoch
|
| 74 |
+
self.train_mae = MeanAbsoluteError()
|
| 75 |
+
self.val_mae = MeanAbsoluteError()
|
| 76 |
+
self.test_mae = MeanAbsoluteError()
|
| 77 |
|
| 78 |
# for logging best so far validation MAE
|
| 79 |
+
self.val_mae_best = MinMetric()
|
| 80 |
|
| 81 |
def forward(self, x: torch.Tensor):
|
| 82 |
return self.model(x)
|
|
|
|
| 93 |
loss, preds, targets = self.step(batch)
|
| 94 |
|
| 95 |
# log train metrics
|
| 96 |
+
mae = self.train_mae(preds, targets)
|
| 97 |
self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
| 98 |
+
self.log("train/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
|
| 99 |
|
| 100 |
# we can return here dict with any tensors
|
| 101 |
# and then read it in some callback or in `training_epoch_end()` below
|
|
|
|
| 110 |
loss, preds, targets = self.step(batch)
|
| 111 |
|
| 112 |
# log val metrics
|
| 113 |
+
mae = self.val_mae(preds, targets)
|
| 114 |
self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
|
| 115 |
+
self.log("val/mae", mae, on_step=False, on_epoch=True, prog_bar=True)
|
| 116 |
|
| 117 |
return {"loss": loss, "preds": preds, "targets": targets}
|
| 118 |
|
| 119 |
def validation_epoch_end(self, outputs: List[Any]):
|
| 120 |
+
mae = self.val_mae.compute()  # get val MAE from current epoch
|
| 121 |
+
self.val_mae_best.update(mae)
|
| 122 |
self.log(
|
| 123 |
+
"val/mae_best", self.val_mae_best.compute(), on_epoch=True, prog_bar=True
|
| 124 |
)
|
| 125 |
|
| 126 |
def test_step(self, batch: Any, batch_idx: int):
|
| 127 |
loss, preds, targets = self.step(batch)
|
| 128 |
|
| 129 |
# log test metrics
|
| 130 |
+
mae = self.test_mae(preds, targets)
|
| 131 |
self.log("test/loss", loss, on_step=False, on_epoch=True)
|
| 132 |
+
self.log("test/mae", mae, on_step=False, on_epoch=True)
|
| 133 |
|
| 134 |
return {"loss": loss, "preds": preds, "targets": targets}
|
| 135 |
|
|
|
|
| 138 |
|
| 139 |
def on_epoch_end(self):
|
| 140 |
# reset metrics at the end of every epoch
|
| 141 |
+
self.train_mae.reset()
|
| 142 |
+
self.test_mae.reset()
|
| 143 |
+
self.val_mae.reset()
|
| 144 |
|
| 145 |
def configure_optimizers(self):
|
| 146 |
"""Choose what optimizers and learning-rate schedulers.
|