import os
import warnings

import torch
import torch.cuda.amp as amp
import torchvision
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import data.common
from utils import interact, MultiSaver


class Trainer():
    """Run the train / validate / test loops for a single model.

    The trainer holds no learning logic of its own.  It moves batches to the
    configured device and dtype, calls the model and the criterion, steps the
    optimizer through an AMP gradient scaler, and delegates checkpointing to
    the ``model``, ``optimizer`` and ``criterion`` objects it was given.
    """

    # Progress-bar layout shared by the training and evaluation loops.
    _BAR_FORMAT = '{desc}|{bar}{r_bar}'

    def __init__(self, args, model, criterion, optimizer, loaders):
        """Store the collaborators and prepare output directories and loggers.

        Args:
            args: Run configuration.  The attributes read by this class
                include ``start_epoch``, ``save_dir``, ``device``, ``dtype``,
                ``precision``, ``n_scales``, ``demo``, ``demo_output_dir``,
                ``launched``, ``rank``, ``init_scale`` and ``amp``.
            model: Object called on input batches; must also provide
                ``train``/``eval``/``to``/``save``/``load``.
            criterion: Loss/metric tracker called as ``criterion(output, target)``.
            optimizer: Optimizer wrapper; the object passed to the gradient
                scaler is its ``G`` attribute.
            loaders: Mapping from mode name (e.g. ``'train'``) to a loader.
        """
        print('===> Initializing trainer')
        self.args = args
        self.mode = 'train'
        self.epoch = args.start_epoch
        self.save_dir = args.save_dir

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.loaders = loaders

        self.do_train = args.do_train
        self.do_validate = args.do_validate
        self.do_test = args.do_test

        self.device = args.device
        self.dtype = args.dtype
        # Evaluation runs in fp32 for 'single' precision, fp16 otherwise.
        self.dtype_eval = torch.float32 if args.precision == 'single' else torch.float16
        # Length of the per-output loss accumulator used in train().
        self.recurrence = args.n_scales

        # Demo runs may redirect results to a user-chosen directory.
        if self.args.demo and self.args.demo_output_dir:
            self.result_dir = self.args.demo_output_dir
        else:
            self.result_dir = os.path.join(self.save_dir, 'result')
        os.makedirs(self.result_dir, exist_ok=True)
        print('results are saved in {}'.format(self.result_dir))

        self.imsaver = MultiSaver(self.result_dir)

        # True for every launched process except rank 0; such processes skip
        # progress bars, console output and TensorBoard logging.
        self.is_slave = self.args.launched and self.args.rank != 0

        self.scaler = amp.GradScaler(
            init_scale=self.args.init_scale,
            enabled=self.args.amp
        )
        if not self.is_slave:
            # NOTE: the attribute name (sic) is kept because external code
            # may reference it.
            self.writter = SummaryWriter(f"runs/{args.save_dir}")

    def _progress(self):
        """Return the current mode's loader, wrapped in tqdm unless ``is_slave``."""
        loader = self.loaders[self.mode]
        if self.is_slave:
            return loader
        return tqdm(loader, ncols=80, smoothing=0, bar_format=self._BAR_FORMAT)

    def _log_epoch(self, epoch, inputs, target, output, loss_sums):
        """Write the last batch's images and the summed losses to TensorBoard."""
        rgb_range = self.args.rgb_range
        for i, single_output in enumerate(output):
            grid = torchvision.utils.make_grid(single_output)
            self.writter.add_image(f"Output{i}", grid / rgb_range, epoch)
            self.writter.add_scalar(f"Loss{i}", loss_sums[i], epoch)
        self.writter.add_image(
            "Input", torchvision.utils.make_grid(inputs[0]) / rgb_range, epoch)
        self.writter.add_image(
            "Target", torchvision.utils.make_grid(target[0]) / rgb_range, epoch)

    def save(self, epoch=None):
        """Checkpoint state when ``epoch`` is a multiple of ``args.save_every``.

        Model and optimizer are saved only while in ``'train'`` mode; the
        criterion is saved in every mode.

        Args:
            epoch: Epoch number to save under; defaults to ``self.epoch``.
        """
        epoch = self.epoch if epoch is None else epoch
        if epoch % self.args.save_every == 0:
            if self.mode == 'train':
                self.model.save(epoch)
                self.optimizer.save(epoch)
            self.criterion.save()

        return

    def load(self, epoch=None, pretrained=None):
        """Restore model, optimizer and criterion state for ``epoch``.

        Args:
            epoch: Epoch to load; defaults to ``args.load_epoch``.
            pretrained: Forwarded unchanged to ``model.load``.
        """
        if epoch is None:
            epoch = self.args.load_epoch
        self.epoch = epoch
        self.model.load(epoch, pretrained)
        self.optimizer.load(epoch)
        self.criterion.load(epoch)

        return

    def train(self, epoch):
        """Run one training epoch, then log, schedule and checkpoint.

        Args:
            epoch: Number of the epoch being trained.
        """
        self.mode = 'train'
        self.epoch = epoch

        self.model.train()
        self.model.to(dtype=self.dtype)

        self.criterion.train()
        self.criterion.epoch = epoch

        if not self.is_slave:
            print('[Epoch {} / lr {:.2e}]'.format(
                epoch, self.optimizer.get_lr()
            ))

        if self.args.distributed:
            self.loaders[self.mode].sampler.set_epoch(epoch)

        tq = self._progress()

        # Running sum of criterion.buffer[i] over the epoch, one slot per output.
        loss_sums = [0.0] * self.recurrence
        # Stay None when the loader yields nothing, so logging can be skipped.
        inputs = target = output = None

        torch.set_grad_enabled(True)
        for batch in tq:
            self.optimizer.zero_grad()

            inputs, target = data.common.to(
                batch[0], batch[1], device=self.device, dtype=self.dtype)

            with amp.autocast(self.args.amp):
                output = self.model(inputs)
                loss = self.criterion(output, target)

            for i in range(self.recurrence):
                loss_sums[i] += self.criterion.buffer[i]

            self.scaler.scale(loss).backward()
            if self.args.clip > 0:
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitude.
                self.scaler.unscale_(self.optimizer.G)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.scaler.step(self.optimizer.G)
            self.scaler.update()

            if isinstance(tq, tqdm):
                tq.set_description(self.criterion.get_loss_desc())

        if not self.is_slave and output is not None:
            self._log_epoch(epoch, inputs, target, output, loss_sums)

        self.criterion.normalize()
        if isinstance(tq, tqdm):
            # Redraw the finished bar with the description taken after normalize().
            tq.set_description(self.criterion.get_loss_desc())
            tq.display(pos=-1)

        self.criterion.step()
        self.optimizer.schedule(self.criterion.get_last_loss())

        if self.args.rank == 0:
            self.save(epoch)

        return

    def evaluate(self, epoch, mode='val'):
        """Run the model over the ``mode`` loader without gradients.

        Losses are accumulated unless a batch's target is a ``BoolTensor``
        (from then on loss computation is disabled for the rest of the run).
        Results are written through ``self.imsaver`` according to
        ``args.save_results``.

        Gradient computation is left globally disabled on return.

        Args:
            epoch: Epoch number the evaluation is recorded under.
            mode: Loader key; ``'val'`` and ``'test'`` switch the criterion
                into the matching state, ``'demo'`` strips padding from the
                first output.

        Raises:
            TypeError: If results are to be saved and the model output is
                neither a list/tuple nor a tensor.
        """
        self.mode = mode
        self.epoch = epoch

        self.model.eval()
        self.model.to(dtype=self.dtype_eval)

        if mode == 'val':
            self.criterion.validate()
        elif mode == 'test':
            self.criterion.test()
        self.criterion.epoch = epoch

        self.imsaver.join_background()

        tq = self._progress()

        compute_loss = True
        torch.set_grad_enabled(False)
        for batch in tq:
            inputs, target = data.common.to(
                batch[0], batch[1], device=self.device, dtype=self.dtype_eval)
            with amp.autocast(self.args.amp):
                output = self.model(inputs)

            if mode == 'demo':
                # NOTE(review): padding is removed from output[0], but the
                # tensor saved below is output[-1]; confirm both refer to the
                # intended result when the model returns several outputs.
                pad_width = batch[2]
                output[0], _ = data.common.pad(output[0], pad_width=pad_width, negative=True)

            # A boolean target appears to mark "no ground truth" - TODO confirm
            # against the dataset code.
            if isinstance(batch[1], torch.BoolTensor):
                compute_loss = False

            if compute_loss:
                self.criterion(output, target)
                if isinstance(tq, tqdm):
                    tq.set_description(self.criterion.get_loss_desc())

            if self.args.save_results != 'none':
                if isinstance(output, (list, tuple)):
                    result = output[-1]
                elif isinstance(output, torch.Tensor):
                    result = output
                else:
                    raise TypeError(
                        'unsupported model output type: {}'.format(type(output).__name__))

                names = batch[-1]

                # 'part' keeps every sample whose index is a multiple of 10;
                # when no loss is computed everything is saved.
                if self.args.save_results == 'part' and compute_loss:
                    indices = batch[-2]
                    save_ids = [
                        position for position, sample_idx in enumerate(indices)
                        if sample_idx % 10 == 0
                    ]

                    result = result[save_ids]
                    names = [names[position] for position in save_ids]

                self.imsaver.save_image(result, names)

        if compute_loss:
            self.criterion.normalize()
            if isinstance(tq, tqdm):
                # Redraw the finished bar with the description taken after normalize().
                tq.set_description(self.criterion.get_loss_desc())
                tq.display(pos=-1)

            self.criterion.step()
            if self.args.rank == 0:
                self.save()

        self.imsaver.end_background()

    def validate(self, epoch):
        """Evaluate on the ``'val'`` loader."""
        self.evaluate(epoch, 'val')
        return

    def test(self, epoch):
        """Evaluate on the ``'test'`` loader."""
        self.evaluate(epoch, 'test')
        return

    def fill_evaluation(self, epoch, mode=None, force=False):
        """Evaluate ``epoch`` if its loss or any metric is not yet recorded.

        This is best-effort: a failure while loading or evaluating is reported
        as a warning and otherwise ignored.

        Args:
            epoch: Epoch to check; values ``<= 0`` are ignored.
            mode: Evaluation mode to switch to; ``None`` keeps ``self.mode``.
            force: Evaluate even when the statistics are already present.
        """
        if epoch <= 0:
            return

        if mode is not None:
            self.mode = mode

        do_eval = force
        if not force:
            loss_missing = epoch not in self.criterion.loss_stat[self.mode]['Total']

            # Look up by self.mode, not the raw argument, so the default
            # mode=None checks the currently active mode.
            metric_missing = any(
                epoch not in self.criterion.metric_stat[self.mode][metric_type]
                for metric_type in self.criterion.metric
            )

            do_eval = loss_missing or metric_missing

        if do_eval:
            try:
                self.load(epoch)
                self.evaluate(epoch, self.mode)
            except Exception as err:
                # Only ordinary errors are tolerated; KeyboardInterrupt and
                # SystemExit still propagate.
                warnings.warn(
                    'evaluation of epoch {} ({}) skipped: {!r}'.format(
                        epoch, self.mode, err)
                )

        return