| |
| |
| |
|
|
| import pdb |
| from tqdm import tqdm |
| from collections import defaultdict |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class Trainer(nn.Module):
    """Helper class to train a deep network.
    Overload this class `forward_backward` for your actual needs.

    Usage:
        train = Trainer(net, loader, loss, optimizer)
        for epoch in range(n_epochs):
            train()
    """

    def __init__(self, net, loader, loss, optimizer):
        """Store the network, data loader, loss function and optimizer.

        net:       the nn.Module being trained
        loader:    iterable yielding one batch of inputs per step
        loss:      loss function (kept as `self.loss_func` for subclasses)
        optimizer: torch optimizer driving `net`'s parameters
        """
        nn.Module.__init__(self)
        self.net = net
        self.loader = loader
        self.loss_func = loss
        self.optimizer = optimizer

    def iscuda(self):
        """Return True if the network's parameters live on a non-CPU device."""
        return next(self.net.parameters()).device != torch.device("cpu")

    def todevice(self, x):
        """Recursively move tensors to the network's device.

        `x` may be a tensor or an arbitrarily nested dict/list/tuple of
        tensors; the container structure is preserved (a tuple stays a
        tuple — the previous version silently turned tuples into lists).
        """
        if isinstance(x, dict):
            return {k: self.todevice(v) for k, v in x.items()}
        if isinstance(x, (tuple, list)):
            moved = [self.todevice(v) for v in x]
            return tuple(moved) if isinstance(x, tuple) else moved
        if self.iscuda():
            # non_blocking allows an async host->device copy on pinned memory
            return x.contiguous().cuda(non_blocking=True)
        return x.cpu()

    def __call__(self):
        """Run one training epoch over `self.loader`.

        Returns the mean of the 'loss' entries reported by
        `forward_backward`. Raises RuntimeError on a NaN loss or if the
        loader yielded no batches (the old code crashed with an opaque
        ZeroDivisionError in that case).
        """
        self.net.train()

        stats = defaultdict(list)

        for inputs in tqdm(self.loader):
            inputs = self.todevice(inputs)

            # compute gradient and do model update
            self.optimizer.zero_grad()
            loss, details = self.forward_backward(inputs)
            if torch.isnan(loss):
                raise RuntimeError("Loss is NaN")
            self.optimizer.step()

            for key, val in details.items():
                stats[key].append(val)

        if not stats:
            raise RuntimeError("No batches were processed (empty loader?)")

        def mean(lis):
            return sum(lis) / len(lis)

        print(" Summary of losses during this epoch:")
        for loss_name, vals in stats.items():
            # Compare the first ~10% of iterations against the last ~10%.
            N = 1 + len(vals) // 10
            print(f" - {loss_name:20}:", end="")
            print(
                f" {mean(vals[:N]):.3f} --> {mean(vals[-N:]):.3f} (avg: {mean(vals):.3f})"
            )
        return mean(stats["loss"])

    def forward_backward(self, inputs):
        """Compute the loss on `inputs`, call backward(), and return
        `(loss, details)` where `details` is a dict of scalar statistics.
        Must be overridden by subclasses."""
        raise NotImplementedError()
|
|