|
|
import numpy as np |
|
|
import torch |
|
|
from base import BaseTrainer |
|
|
from torchvision.utils import make_grid |
|
|
from utils import MetricTracker |
|
|
|
|
|
|
|
|
class Trainer(BaseTrainer):
    """
    Trainer class

    Runs one training pass per epoch and, when a validation loader was
    supplied, a validation pass right after it. On selected epochs (see
    ``_is_adaptive_epoch``) the current model is first run over
    ``data_loader.inference`` and its argmax predictions are handed to
    ``data_loader.update_dataset`` before training resumes.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        """
        :param model: Network to train; invoked as ``model(data)``.
        :param criterion: Loss callable, invoked as ``criterion(output, target)``.
        :param metric_ftns: Metric callables, invoked as ``met(output, target)``;
            each one's ``__name__`` is used as its key in the metric trackers.
        :param optimizer: Optimizer stepped once per training batch.
        :param config: Configuration object; ``config['trainer']['adaptive_step']``
            must be present.
        :param device: Device every batch is moved to.
        :param data_loader: Training loader. Must expose ``batch_size``; the
            adaptive refresh additionally uses its ``inference`` attribute and
            ``update_dataset`` method.
        :param valid_data_loader: Optional validation loader; validation is
            skipped when it is ``None``.
        :param lr_scheduler: Optional scheduler, stepped once at the end of
            every epoch.
        :param len_epoch: Accepted for interface compatibility but not used:
            the epoch length is always taken from ``len(data_loader)``.
        """
        # NOTE(review): the base class presumably provides self.model,
        # self.criterion, self.metric_ftns, self.optimizer, self.logger and
        # self.writer, all of which are used below - confirm in BaseTrainer.
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        self.len_epoch = len(self.data_loader)
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # Log every floor(sqrt(batch_size)) batches.
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.adaptive_step = config['trainer']['adaptive_step']

        # Both trackers share the same keys: 'loss' plus one per metric.
        metric_names = [m.__name__ for m in self.metric_ftns]
        self.train_metrics = MetricTracker(
            'loss', *metric_names, writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *metric_names, writer=self.writer)

    def _is_adaptive_epoch(self, epoch):
        """
        Tell whether the dataset should be refreshed before training ``epoch``.

        :param epoch: Integer, current training epoch.
        :return: True on epochs ``step + 1``, ``2 * step + 1``, ... where
            ``step`` is ``self.adaptive_step``; always False when the step is
            ``None``, zero or negative.
        """
        step = self.adaptive_step
        if not step or step <= 0:
            # A non-positive step disables the refresh instead of dividing
            # by zero in the modulo below.
            return False
        # Same as ``epoch % step == 1`` for step >= 2, but also correct for
        # step == 1, where ``epoch % 1`` can never equal 1.
        return epoch > step and (epoch - 1) % step == 0

    def _refresh_dataset_from_predictions(self):
        """
        Run the model over ``data_loader.inference`` and feed the argmax
        predictions back through ``data_loader.update_dataset``.

        Leaves the model in eval mode and resets ``self.len_epoch`` to the
        loader's new length.
        """
        inference_loader = self.data_loader.inference
        dataset = inference_loader.dataset
        batch_size = inference_loader.batch_size

        self.model.eval()
        with torch.no_grad():
            # Targets yielded by the inference loader are not needed here.
            for batch_idx, (data, _) in enumerate(inference_loader):
                data = data.to(self.device)
                output = self.model(data)

                # Positions of this batch's samples within the dataset;
                # assumes the loader yields them in order with only the last
                # batch possibly short - TODO confirm against the loader.
                start = batch_size * batch_idx
                patch_idx = torch.arange(start, start + data.shape[0])
                pred = torch.argmax(output, dim=1)
                dataset.patches.store_data(patch_idx, [pred.unsqueeze(1)])

            preds = [dataset.patches.combine(idx, data_idx=0)[0].cpu()
                     for idx in range(len(dataset.data))]

        self.data_loader.update_dataset(preds)
        # The loader's length is re-read because update_dataset may change it.
        self.len_epoch = len(self.data_loader)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        if self._is_adaptive_epoch(epoch):
            self._refresh_dataset_from_predictions()

        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            # Read the scalar once; it is reused by the tracker and the log line.
            loss_value = loss.item()
            self.train_metrics.update('loss', loss_value)
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss_value))
                self.writer.add_image('input', make_grid(
                    data.cpu(), nrow=8, normalize=True))

            # Safety stop; with len_epoch == len(data_loader) the loop is
            # normally exhausted before this condition can hold.
            if batch_idx == self.len_epoch:
                break

        self.writer.next()
        self.train_metrics.add_scalers()
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            # Validation entries are merged in under a 'val_' prefix.
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(
                        met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(
                    data.cpu(), nrow=8, normalize=True))

        # Record a histogram of every model parameter.
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        self.writer.next('valid')
        self.valid_metrics.add_scalers()
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """
        Format the position inside the current epoch for log messages.

        :param batch_idx: Integer, zero-based index of the current batch.
        :return: String of the form ``[current/total (percent%)]``, counted in
            samples when the loader exposes ``n_samples`` and in batches
            otherwise.
        """
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
|
|
|