|
|
import torch |
|
|
|
|
|
from kornia.metrics import accuracy, AverageMeter, mean_iou |
|
|
|
|
|
from .trainer import Trainer |
|
|
|
|
|
|
|
|
class ImageClassifierTrainer(Trainer):
    """Module to be used for Image Classification purposes.

    The module subclasses :py:class:`~kornia.x.Trainer` and overrides the
    :py:func:`~kornia.x.Trainer.evaluate` function implementing a standard
    :py:func:`~kornia.metrics.accuracy` topk@[1, 5].

    .. seealso::
        Learn how to use this class in the following
        `example <https://github.com/kornia/kornia/blob/master/examples/train/image_classifier/>`__.
    """

    @torch.no_grad()
    def evaluate(self) -> dict:
        """Run one pass over ``self.valid_dataloader`` and collect validation metrics.

        Each batch yielded by the dataloader is indexed as ``(input, target)``, wrapped into a
        ``{"input": ..., "target": ...}`` dict, passed through :py:meth:`preprocess` and the
        ``on_before_model`` hook, fed to ``self.model`` (followed by the ``on_after_model`` hook),
        and scored with ``self.criterion`` plus top-1 / top-5 accuracy.

        Returns:
            dict mapping ``'losses'``, ``'top1'`` and ``'top5'`` to
            :py:class:`~kornia.metrics.AverageMeter` instances, each updated with the
            per-batch value weighted by the batch size.
        """
        self.model.eval()
        stats = {'losses': AverageMeter(), 'top1': AverageMeter(), 'top5': AverageMeter()}
        for sample_id, sample in enumerate(self.valid_dataloader):
            sample = {"input": sample[0], "target": sample[1]}

            # perform the preprocess and augmentations in batch, then run the same
            # pre/post model hooks used by the other trainers in this module
            sample = self.preprocess(sample)
            sample = self.on_before_model(sample)

            # forward
            out = self.model(sample["input"])
            self.on_after_model(out, sample)

            # loss and accuracy metrics
            val_loss = self.criterion(out, sample["target"])
            acc1, acc5 = accuracy(out.detach(), sample["target"], topk=(1, 5))
            batch_size: int = sample["input"].shape[0]
            stats['losses'].update(val_loss.item(), batch_size)
            stats['top1'].update(acc1.item(), batch_size)
            stats['top5'].update(acc5.item(), batch_size)

            if sample_id % 10 == 0:
                # each metric is reported as: latest batch value, then running average
                self._logger.info(
                    f"Test: {sample_id}/{len(self.valid_dataloader)} "
                    f"Loss: {stats['losses'].val:.2f} {stats['losses'].avg:.2f} "
                    f"Acc@1: {stats['top1'].val:.2f} {stats['top1'].avg:.2f} "
                    f"Acc@5: {stats['top5'].val:.2f} {stats['top5'].avg:.2f} "
                )

        return stats
|
|
|
|
|
|
|
|
class SemanticSegmentationTrainer(Trainer):
    """Module to be used for Semantic segmentation purposes.

    The module subclasses :py:class:`~kornia.x.Trainer` and overrides the
    :py:func:`~kornia.x.Trainer.evaluate` function implementing IoU :py:func:`~kornia.metrics.mean_iou`.

    .. seealso::
        Learn how to use this class in the following
        `example <https://github.com/kornia/kornia/blob/master/examples/train/semantic_segmentation/>`__.
    """

    @torch.no_grad()
    def evaluate(self) -> dict:
        """Run one pass over ``self.valid_dataloader`` and collect validation metrics.

        Each batch yielded by the dataloader is indexed as ``(input, target)``, wrapped into a
        ``{"input": ..., "target": ...}`` dict, passed through :py:meth:`preprocess` and the
        ``on_before_model`` hook, fed to ``self.model`` (followed by the ``on_after_model`` hook),
        and scored with ``self.criterion`` plus the mean IoU.

        The predicted label map is ``out.argmax(1)`` and the number of classes is taken from
        ``out.shape[1]``, i.e. the model output is treated as having its class scores on dim 1.

        Returns:
            dict mapping ``'losses'`` and ``'iou'`` to :py:class:`~kornia.metrics.AverageMeter`
            instances, each updated with the per-batch value weighted by the batch size.
        """
        self.model.eval()
        stats = {'losses': AverageMeter(), 'iou': AverageMeter()}
        for sample_id, sample in enumerate(self.valid_dataloader):
            sample = {"input": sample[0], "target": sample[1]}

            # perform the preprocess and augmentations in batch
            sample = self.preprocess(sample)
            sample = self.on_before_model(sample)

            # forward
            out = self.model(sample["input"])
            self.on_after_model(out, sample)

            # loss and IoU metrics
            val_loss = self.criterion(out, sample["target"])
            iou = mean_iou(out.argmax(1), sample["target"], out.shape[1]).mean()
            batch_size: int = sample["input"].shape[0]
            stats['losses'].update(val_loss.item(), batch_size)
            # convert to a Python float so the meter does not accumulate (device) tensors
            stats['iou'].update(iou.item(), batch_size)

            if sample_id % 10 == 0:
                # each metric is reported as: latest batch value, then running average
                self._logger.info(
                    f"Test: {sample_id}/{len(self.valid_dataloader)} "
                    f"Loss: {stats['losses'].val:.2f} {stats['losses'].avg:.2f} "
                    f"IoU: {stats['iou'].val:.2f} {stats['iou'].avg:.2f} "
                )

        return stats
|
|
|