|
|
import logging
from typing import Callable, Dict, Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

try:
    from accelerate import Accelerator
except ImportError:
    Accelerator = None

from kornia.metrics import AverageMeter

from .utils import Configuration, TrainerState
|
|
|
|
|
# Names of the hooks a user is allowed to override through the ``callbacks``
# argument of ``Trainer``; anything outside this list raises ``ValueError``.
_pipeline_hooks = [
    "preprocess",
    "augmentations",
    "evaluate",
    "fit",
    "fit_epoch",
]
_event_hooks = [
    "on_epoch_start",
    "on_before_model",
    "on_after_model",
    "on_checkpoint",
    "on_epoch_end",
]
callbacks_whitelist = _pipeline_hooks + _event_hooks
|
|
|
|
|
|
|
|
class Trainer:
    """Base class to train the different models in kornia.

    .. warning::
        The API is experimental and subject to be modified based on the needs of kornia models.

    Args:
        model: the nn.Module to be optimized.
        train_dataloader: the data loader used in the training loop.
        valid_dataloader: the data loader used in the validation loop.
        criterion: the nn.Module with the function that computes the loss.
        optimizer: the torch optimizer object to be used during the optimization.
        scheduler: the torch scheduler object defining the scheduling strategy.
        accelerator: the Accelerator object to distribute the training.
        config: a TrainerConfiguration structure containing the experiment hyper parameters.
        callbacks: a dictionary containing the pointers to the functions to overrides. The
            main supported hooks are ``evaluate``, ``preprocess``, ``augmentations`` and ``fit``.

    .. important::
        The API heavily relies on `accelerate <https://github.com/huggingface/accelerate/>`_.
        In order to use it, you must: ``pip install kornia[x]``

    .. seealso::
        Learn how to use the API in our documentation
        `here <https://kornia.readthedocs.io/en/latest/get-started/training.html>`_.
    """

    def __init__(
        self,
        model: nn.Module,
        train_dataloader: DataLoader,
        valid_dataloader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.CosineAnnealingLR,
        config: Configuration,
        callbacks: Optional[Dict[str, Callable]] = None,
    ) -> None:
        # the whole training pipeline depends on accelerate; fail fast with a
        # hint on how to install the optional extra.
        if Accelerator is None:
            raise ModuleNotFoundError(
                "accelerate library is not installed: pip install kornia[x]")
        self.accelerator = Accelerator()

        # let accelerate wrap the model, loaders and optimizer so they land on
        # the device/distributed setup it selects; the criterion is only moved.
        self.model = self.accelerator.prepare(model)
        self.train_dataloader = self.accelerator.prepare(train_dataloader)
        self.valid_dataloader = self.accelerator.prepare(valid_dataloader)
        self.criterion = criterion.to(self.device)
        self.optimizer = self.accelerator.prepare(optimizer)
        self.scheduler = scheduler
        self.config = config

        # avoid a shared mutable default argument: an omitted ``callbacks``
        # simply means "no custom hooks".
        if callbacks is None:
            callbacks = {}

        # install the user provided hooks, rejecting unknown names.
        # NOTE(review): ``setattr`` on the class (not the instance) makes each
        # override visible to every Trainer instance in the process -- confirm
        # this sharing is intended before changing it.
        for fn_name, fn in callbacks.items():
            if fn_name not in callbacks_whitelist:
                raise ValueError(f"Not supported: {fn_name}.")
            setattr(Trainer, fn_name, fn)

        # hyper-parameters
        self.num_epochs = config.num_epochs

        self.state = TrainerState.STARTING

        self._logger = logging.getLogger('train')

    @property
    def device(self) -> torch.device:
        """Device selected by the accelerator."""
        return self.accelerator.device

    def backward(self, loss: torch.Tensor) -> None:
        """Backpropagate ``loss`` through the accelerator (handles any
        scaling/distributed logic it is configured with)."""
        self.accelerator.backward(loss)

    def fit_epoch(self, epoch: int) -> None:
        """Run one full training pass over ``self.train_dataloader``.

        Args:
            epoch: zero-based index of the current epoch, used for logging.
        """
        self.model.train()
        losses = AverageMeter()
        for sample_id, sample in enumerate(self.train_dataloader):
            # each batch is expected to be an (input, target) pair
            # -- TODO confirm against the dataset used.
            sample = {"input": sample[0], "target": sample[1]}
            self.optimizer.zero_grad()

            # user-overridable hooks run before the forward pass
            sample = self.preprocess(sample)
            sample = self.augmentations(sample)
            sample = self.on_before_model(sample)

            output = self.model(sample["input"])
            self.on_after_model(output, sample)
            loss = self.criterion(output, sample["target"])
            self.backward(loss)
            self.optimizer.step()

            # weight the running average by the batch size
            losses.update(loss.item(), sample["target"].shape[0])

            if sample_id % 50 == 0:
                self._logger.info(
                    f"Train: {epoch + 1}/{self.num_epochs} "
                    f"Sample: {sample_id + 1}/{len(self.train_dataloader)} "
                    f"Loss: {losses.val:.3f} {losses.avg:.3f}"
                )

    def fit(self) -> None:
        """Main loop: alternate training and validation for ``num_epochs``.

        After each epoch the model is evaluated, checkpointed and the learning
        rate scheduler is stepped; ``on_epoch_end`` may set the state to
        ``TERMINATE`` to stop early.
        """
        for epoch in range(self.num_epochs):
            # NOTE(review): ``on_epoch_start`` is whitelisted and defined below
            # but never invoked here -- confirm whether it should be called.
            self.state = TrainerState.TRAINING
            self.fit_epoch(epoch)

            self.state = TrainerState.VALIDATE
            valid_stats = self.evaluate()

            self.on_checkpoint(self.model, epoch, valid_stats)

            self.on_epoch_end()
            if self.state == TrainerState.TERMINATE:
                break

            self.scheduler.step()

    def evaluate(self):
        """Validation loop; override via the ``evaluate`` callback."""
        ...

    def on_epoch_start(self, *args, **kwargs):
        """Hook reserved for per-epoch setup; override via callbacks."""
        ...

    def preprocess(self, x: dict) -> dict:
        """Identity preprocessing hook; override via the ``preprocess`` callback."""
        return x

    def augmentations(self, x: dict) -> dict:
        """Identity augmentation hook; override via the ``augmentations`` callback."""
        return x

    def on_before_model(self, x: dict) -> dict:
        """Identity hook run right before the forward pass."""
        return x

    def on_after_model(self, output: torch.Tensor, sample: dict):
        """Hook run right after the forward pass; no-op by default."""
        ...

    def on_checkpoint(self, *args, **kwargs):
        """Hook for saving checkpoints after validation; no-op by default."""
        ...

    def on_epoch_end(self, *args, **kwargs):
        """Hook run at the end of each epoch; may set ``TrainerState.TERMINATE``."""
        ...
|
|
|