import numpy as np
import torch
from base import BaseTrainer
from torchvision.utils import make_grid
from utils import MetricTracker


class Trainer(BaseTrainer):
    """
    Trainer class
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        # Honour an explicit len_epoch if given; otherwise run full epochs.
        self.len_epoch = len_epoch if len_epoch is not None else len(self.data_loader)
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.adaptive_step = config['trainer']['adaptive_step']

        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        # Every `adaptive_step` epochs, run inference over the inference loader
        # and feed the predictions back into the dataset before training.
        if epoch > self.adaptive_step and epoch % self.adaptive_step == 1:
            dataset = self.data_loader.inference.dataset
            self.model.eval()
            with torch.no_grad():
                for batch_idx, (data, target) in enumerate(self.data_loader.inference):
                    data, target = data.to(self.device), target.to(self.device)
                    output = self.model(data)
                    # Map this batch back to its global patch indices; the last
                    # batch may be smaller, hence the data.shape[0] bound.
                    batch_size = self.data_loader.inference.batch_size
                    patch_idx = torch.arange(
                        batch_size * batch_idx,
                        batch_size * batch_idx + data.shape[0])
                    pred = torch.argmax(output, dim=1)
                    dataset.patches.store_data(patch_idx, [pred.unsqueeze(1)])
            # Reassemble per-sample predictions from the stored patches and
            # rebuild the training set from them.
            preds = [dataset.patches.combine(idx, data_idx=0)[0].cpu()
                     for idx in range(len(dataset.data))]
            self.data_loader.update_dataset(preds)
            self.len_epoch = len(self.data_loader)

        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
                self.writer.add_image('input', make_grid(
                    data.cpu(), nrow=8, normalize=True))

            if batch_idx == self.len_epoch:
                break
        self.writer.next()
        self.train_metrics.add_scalers()
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(
                        met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(
                    data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        self.writer.next('valid')
        self.valid_metrics.add_scalers()
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
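
# Example usage (a minimal sketch; assumes `config`, `model`, `criterion`,
# `metric_ftns`, `optimizer`, `device`, and the data loaders are constructed
# elsewhere, e.g. in a train.py entry script, and that BaseTrainer provides
# the `train()` loop that calls `_train_epoch` each epoch):
#
#     trainer = Trainer(model, criterion, metric_ftns, optimizer, config, device,
#                       data_loader,
#                       valid_data_loader=valid_data_loader,
#                       lr_scheduler=lr_scheduler)
#     trainer.train()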