import numpy as np
import torch
from torchvision.utils import make_grid

from base import BaseTrainer
from utils import MetricTracker
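
# Note: `base` and `utils` above are project-local modules (assumed to provide
# the BaseTrainer skeleton and the MetricTracker helper used below); the
# remaining imports are third-party.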

class Trainer(BaseTrainer):
    """
    Trainer class: runs the training/validation loops and, every
    `adaptive_step` epochs, feeds the model's own predictions back into the
    training dataset before continuing.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        # Fall back to one full pass over the data loader when no explicit
        # `len_epoch` is given.
        self.len_epoch = len(self.data_loader) if len_epoch is None else len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # Log roughly every sqrt(batch_size) batches.
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.adaptive_step = config['trainer']['adaptive_step']

        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
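
    # A minimal sketch of the config entry this class reads; the surrounding
    # layout is an assumption, only `trainer.adaptive_step` is required here:
    #
    #     "trainer": {
    #         "adaptive_step": 5
    #     }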

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        if epoch > self.adaptive_step and epoch % self.adaptive_step == 1:
            dataset = self.data_loader.inference.dataset
            self.model.eval()
            with torch.no_grad():
                for batch_idx, (data, target) in enumerate(self.data_loader.inference):
                    data, target = data.to(self.device), target.to(self.device)
                    output = self.model(data)
                    # Global indices of the patches in this batch; the last
                    # batch may hold fewer than `batch_size` patches.
                    batch_size = self.data_loader.inference.batch_size
                    patch_idx = torch.arange(
                        batch_size * batch_idx, batch_size * batch_idx + data.shape[0])
                    pred = torch.argmax(output, dim=1)
                    dataset.patches.store_data(patch_idx, [pred.unsqueeze(1)])
            # Reassemble the stored patch predictions into one prediction per
            # sample and push them back into the training dataset.
            preds = [dataset.patches.combine(idx, data_idx=0)[0].cpu()
                     for idx in range(len(dataset.data))]
            self.data_loader.update_dataset(preds)
            self.len_epoch = len(self.data_loader)

        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss.item()))
                self.writer.add_image('input', make_grid(
                    data.cpu(), nrow=8, normalize=True))
            if batch_idx == self.len_epoch:
                # Supports iteration-based training with a data loader that
                # yields more than `len_epoch` batches.
                break

        # `writer.next()` and `add_scalers()` follow the project's custom
        # writer/MetricTracker API from `utils`.
        self.writer.next()
        self.train_metrics.add_scalers()
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log
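
    # The log returned above is a flat dict of epoch averages, e.g. (keys
    # depend on the configured metric functions; the values are illustrative):
    #
    #     {'loss': 0.42, 'accuracy': 0.91, 'val_loss': 0.47, 'val_accuracy': 0.89}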

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(
                        met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(
                    data.cpu(), nrow=8, normalize=True))

        # Add histogram of model parameters to the tensorboard.
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        self.writer.next('valid')
        self.valid_metrics.add_scalers()
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        # Render progress as e.g. '[320/1600 (20%)]'.
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
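

# A minimal usage sketch, assuming the surrounding project supplies `config`,
# `model`, `criterion`, `metric_ftns`, `optimizer`, `device` and the data
# loaders, and that `BaseTrainer` exposes a `train()` entry point driving
# `_train_epoch`:
#
#     trainer = Trainer(model, criterion, metric_ftns, optimizer,
#                       config=config, device=device,
#                       data_loader=data_loader,
#                       valid_data_loader=valid_data_loader,
#                       lr_scheduler=lr_scheduler)
#     trainer.train()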