| | import os |
| | from importlib import import_module |
| |
|
| | import torch |
| | from torch import nn |
| | import torch.distributed as dist |
| |
|
| | import matplotlib.pyplot as plt |
| | plt.switch_backend('agg') |
| |
|
| | from .metric import PSNR, SSIM |
| |
|
| | from utils import interact |
| |
|
class Loss(torch.nn.modules.loss._Loss):
    """Composite training loss + evaluation-metric bookkeeper.

    Parses ``args.loss`` into a weighted sum of loss functions, tracks
    per-epoch running statistics for the 'train'/'val'/'test' modes,
    optionally measures PSNR/SSIM, supports distributed all-reduce of the
    statistics, and can save/load/plot the loss history.
    """

    def __init__(self, args, epoch=None, model=None, optimizer=None):
        """
        input:
            args.loss       use '+' to sum over different loss functions
                            use '*' to specify the loss weight

                example:
                    1*MSE+0.5*VGG54
                    loss = sum of MSE and VGG54(weight=0.5)

            args.measure    similar to args.loss, but without weight

                example:
                    MSE+PSNR
                    measure MSE and PSNR, independently
        """
        super(Loss, self).__init__()

        self.args = args

        self.rgb_range = args.rgb_range
        self.device_type = args.device_type
        # True once all_reduce() has synchronized stats across ranks
        self.synchronized = False

        self.epoch = args.start_epoch if epoch is None else epoch
        self.save_dir = args.save_dir
        self.save_name = os.path.join(self.save_dir, 'loss.pt')

        # mode flags; self.mode selects which stat bucket forward() fills
        self.validating = False
        self.testing = False
        self.mode = 'train'
        self.modes = ('train', 'val', 'test')

        # loss functions, their order, and their scalar weights
        self.loss = nn.ModuleDict()
        self.loss_types = []
        self.weight = {}

        # loss_stat[mode][loss_type][epoch] = accumulated (then normalized) value
        self.loss_stat = {mode: {} for mode in self.modes}

        for weighted_loss in args.loss.split('+'):
            # each term is 'weight*NAME', e.g. '0.5*VGG54'
            w, l = weighted_loss.split('*')
            l = l.upper()
            if l in ('ABS', 'L1'):
                loss_type = 'L1'
                func = nn.L1Loss()
            elif l in ('MSE', 'L2'):
                loss_type = 'L2'
                func = nn.MSELoss()
            elif l in ('ADV', 'GAN'):
                loss_type = 'ADV'
                m = import_module('loss.adversarial')
                func = getattr(m, 'Adversarial')(args, model, optimizer)
            else:
                loss_type = l
                # BUG FIX: original had `import_module*'loss.{}'...` which
                # multiplied the function object by a string (TypeError).
                m = import_module('loss.{}'.format(l.lower()))
                func = getattr(m, l)(args)

            self.loss_types += [loss_type]
            self.loss[loss_type] = func
            self.weight[loss_type] = float(w)

        print('Loss function: {}'.format(args.loss))

        # metrics are optional; 'none' (case-insensitive) disables them
        self.do_measure = args.metric.lower() != 'none'

        self.metric = nn.ModuleDict()
        self.metric_types = []
        self.metric_stat = {mode: {} for mode in self.modes}

        if self.do_measure:
            for metric_type in args.metric.split(','):
                metric_type = metric_type.upper()
                if metric_type == 'PSNR':
                    metric_func = PSNR()
                elif metric_type == 'SSIM':
                    metric_func = SSIM(args.device_type)
                else:
                    raise NotImplementedError

                self.metric_types += [metric_type]
                self.metric[metric_type] = metric_func

        print('Metrics: {}'.format(args.metric))

        # resume: pull prior stats from disk before (re-)registering buckets
        if args.start_epoch != 1:
            self.load(args.start_epoch - 1)

        # make sure every mode has a dict for every loss/metric type
        for mode in self.modes:
            for loss_type in self.loss:
                if loss_type not in self.loss_stat[mode]:
                    self.loss_stat[mode][loss_type] = {}

            if 'Total' not in self.loss_stat[mode]:
                self.loss_stat[mode]['Total'] = {}

            if self.do_measure:
                for metric_type in self.metric:
                    if metric_type not in self.metric_stat[mode]:
                        self.metric_stat[mode][metric_type] = {}

        # running sample counts for the current epoch (loss / metric)
        self.count = 0
        self.count_m = 0

        self.to(args.device, dtype=args.dtype)

    def train(self, mode=True):
        """Switch to 'train' mode (mode=True) or 'test' mode (mode=False).

        NOTE: mode=False deliberately maps to 'test', mirroring nn.Module's
        train/eval duality; use validate() for the 'val' bucket.
        """
        super(Loss, self).train(mode)
        if mode:
            self.validating = False
            self.testing = False
            self.mode = 'train'
        else:
            self.validating = False
            self.testing = True
            self.mode = 'test'

    def validate(self):
        """Switch stats accumulation to the 'val' bucket (eval mode)."""
        super(Loss, self).eval()

        self.validating = True
        self.testing = False
        self.mode = 'val'

    def test(self):
        """Switch stats accumulation to the 'test' bucket (eval mode)."""
        super(Loss, self).eval()

        self.validating = False
        self.testing = True
        self.mode = 'test'

    def forward(self, input, target):
        """Compute the weighted total loss and accumulate per-type stats.

        input/target may be tensors, lists/tuples of tensors (multi-scale),
        or dicts of tensors; list/dict entries are summed per loss type.
        NaN losses are excluded from the statistics but still returned.
        """
        self.synchronized = False

        loss = 0

        def _ms_forward(input, target, func):
            # apply func element-wise over multi-scale containers
            if isinstance(input, (list, tuple)):
                _loss = []
                for (input_i, target_i) in zip(input, target):
                    _loss += [func(input_i, target_i)]
                return sum(_loss)
            elif isinstance(input, dict):
                _loss = []
                for key in input:
                    _loss += [func(input[key], target[key])]
                return sum(_loss)
            else:
                return func(input, target)

        # first batch of the epoch: zero the accumulators
        if self.count == 0:
            for loss_type in self.loss_types:
                self.loss_stat[self.mode][loss_type][self.epoch] = 0
            self.loss_stat[self.mode]['Total'][self.epoch] = 0

        if isinstance(input, list):
            count = input[0].shape[0]
        else:
            count = input.shape[0]

        isnan = False
        for loss_type in self.loss_types:

            if loss_type == 'ADV':
                # adversarial loss only sees the finest scale
                _loss = self.loss[loss_type](input[0], target[0], self.training) * self.weight[loss_type]
            else:
                _loss = _ms_forward(input, target, self.loss[loss_type]) * self.weight[loss_type]

            if torch.isnan(_loss):
                isnan = True
            else:
                # accumulate batch-size-weighted sums; normalize() divides later
                self.loss_stat[self.mode][loss_type][self.epoch] += _loss.item() * count
                self.loss_stat[self.mode]['Total'][self.epoch] += _loss.item() * count

            loss += _loss

        if not isnan:
            self.count += count

        if not self.training and self.do_measure:
            self.measure(input, target)

        return loss

    def measure(self, input, target):
        """Accumulate metric statistics (eval only); recurses to the first
        element of list/tuple/dict inputs (the finest scale)."""
        if isinstance(input, (list, tuple)):
            self.measure(input[0], target[0])
            return
        elif isinstance(input, dict):
            first_key = list(input.keys())[0]
            self.measure(input[first_key], target[first_key])
            return
        else:
            pass

        # first measured batch of the epoch: zero the accumulators
        if self.count_m == 0:
            for metric_type in self.metric_stat[self.mode]:
                self.metric_stat[self.mode][metric_type][self.epoch] = 0

        if isinstance(input, list):
            count = input[0].shape[0]
        else:
            count = input.shape[0]

        for metric_type in self.metric_stat[self.mode]:

            # clamp (and quantize for 8-bit range) before measuring
            input = input.clamp(0, self.rgb_range)
            if self.rgb_range == 255:
                input.round_()

            _metric = self.metric[metric_type](input, target)
            self.metric_stat[self.mode][metric_type][self.epoch] += _metric.item() * count

        self.count_m += count

        return

    def normalize(self):
        """Divide accumulated sums by sample counts, syncing ranks first."""
        if self.args.distributed:
            dist.barrier()
            if not self.synchronized:
                self.all_reduce()

        if self.count > 0:
            for loss_type in self.loss_stat[self.mode]:
                self.loss_stat[self.mode][loss_type][self.epoch] /= self.count
            self.count = 0

        if self.count_m > 0:
            for metric_type in self.metric_stat[self.mode]:
                self.metric_stat[self.mode][metric_type][self.epoch] /= self.count_m
            self.count_m = 0

        return

    def all_reduce(self, epoch=None):
        """Sum loss/metric accumulators and counts across all ranks."""

        if epoch is None:
            epoch = self.epoch

        def _reduce_value(value, ReduceOp=dist.ReduceOp.SUM):
            # round-trip through a tensor so dist.all_reduce can operate
            value_tensor = torch.Tensor([value]).to(self.args.device, self.args.dtype, non_blocking=True)
            dist.all_reduce(value_tensor, ReduceOp, async_op=False)
            value = value_tensor.item()
            del value_tensor

            return value

        dist.barrier()
        if self.count > 0:
            self.count = _reduce_value(self.count, dist.ReduceOp.SUM)

            for loss_type in self.loss_stat[self.mode]:
                self.loss_stat[self.mode][loss_type][epoch] = _reduce_value(
                    self.loss_stat[self.mode][loss_type][epoch],
                    dist.ReduceOp.SUM
                )

        if self.count_m > 0:
            self.count_m = _reduce_value(self.count_m, dist.ReduceOp.SUM)

            for metric_type in self.metric_stat[self.mode]:
                self.metric_stat[self.mode][metric_type][epoch] = _reduce_value(
                    self.metric_stat[self.mode][metric_type][epoch],
                    dist.ReduceOp.SUM
                )

        self.synchronized = True

        return

    def print_metrics(self):
        """Print the formatted metric summary for the current mode/epoch."""
        print(self.get_metric_desc())
        return

    def get_last_loss(self):
        """Return the current epoch's 'Total' stat for the active mode.

        NOTE(review): this is the raw accumulated sum unless normalize()
        has already run for this epoch.
        """
        return self.loss_stat[self.mode]['Total'][self.epoch]

    def get_loss_desc(self):
        """Return a one-line progress string, e.g. 'Train Loss: 1.5'."""

        if self.mode == 'train':
            desc_prefix = 'Train'
        elif self.mode == 'val':
            desc_prefix = 'Validation'
        else:
            desc_prefix = 'Test'

        loss = self.loss_stat[self.mode]['Total'][self.epoch]
        if self.count > 0:
            # mid-epoch: show the running average
            loss /= self.count
        desc = '{} Loss: {:.1f}'.format(desc_prefix, loss)

        if self.mode in ('val', 'test'):
            metric_desc = self.get_metric_desc()
            desc = '{}{}'.format(desc, metric_desc)

        return desc

    def get_metric_desc(self):
        """Return a formatted ' NAME: value' string per tracked metric."""
        desc = ''
        for metric_type in self.metric_stat[self.mode]:
            measured = self.metric_stat[self.mode][metric_type][self.epoch]
            if self.count_m > 0:
                measured /= self.count_m

            if metric_type == 'PSNR':
                desc += ' {}: {:2.2f}'.format(metric_type, measured)
            elif metric_type == 'SSIM':
                desc += ' {}: {:1.4f}'.format(metric_type, measured)
            else:
                desc += ' {}: {:2.4f}'.format(metric_type, measured)

        return desc

    def step(self, plot_name=None):
        """End-of-epoch hook: normalize stats and refresh the plots."""
        self.normalize()
        self.plot(plot_name)
        if not self.training and self.do_measure:
            self.plot_metric()

        return

    def save(self):
        """Persist loss and metric statistics to <save_dir>/loss.pt."""
        state = {
            'loss_stat': self.loss_stat,
            'metric_stat': self.metric_stat,
        }
        torch.save(state, self.save_name)

        return

    def load(self, epoch=None):
        """Restore statistics from disk; optionally set the current epoch."""
        print('Loading loss record from {}'.format(self.save_name))
        if os.path.exists(self.save_name):
            state = torch.load(self.save_name, map_location=self.args.device)

            self.loss_stat = state['loss_stat']
            if 'metric_stat' in state:
                self.metric_stat = state['metric_stat']
            else:
                pass
        else:
            print('no loss record found for {}!'.format(self.save_name))

        if epoch is not None:
            self.epoch = epoch

        return

    def plot(self, plot_name=None, metric=False):
        """Plot the loss curve; optionally also the metric curve."""
        self.plot_loss(plot_name)

        if metric:
            self.plot_metric(plot_name)

        return

    def plot_loss(self, plot_name=None):
        """Save a per-type loss-vs-epoch plot (default <save_dir>/<mode>_loss.pdf)."""
        if plot_name is None:
            plot_name = os.path.join(self.save_dir, "{}_loss.pdf".format(self.mode))

        title = "{} loss".format(self.mode)

        fig = plt.figure()
        plt.title(title)
        plt.xlabel('epochs')
        plt.ylabel('loss')
        plt.grid(True, linestyle=':')

        for loss_type, loss_record in self.loss_stat[self.mode].items():
            # only plot epochs up to the current one
            axis = sorted([epoch for epoch in loss_record.keys() if epoch <= self.epoch])
            value = [self.loss_stat[self.mode][loss_type][epoch] for epoch in axis]
            label = loss_type

            plt.plot(axis, value, label=label)

        plt.xlim(0, self.epoch)
        plt.legend()
        plt.savefig(plot_name)
        plt.close(fig)

        return

    def plot_metric(self, plot_name=None):
        """Save a metric-vs-epoch plot; PSNR on the left axis, SSIM on a twin
        right axis. NOTE(review): metric types other than PSNR/SSIM would hit
        an undefined axis here — __init__ currently rejects them anyway."""
        if plot_name is None:
            plot_name = os.path.join(self.save_dir, "{}_metric.pdf".format(self.mode))

        title = "{} metrics".format(self.mode)

        fig, ax1 = plt.subplots()
        plt.title(title)
        plt.grid(True, linestyle=':')
        ax1.set_xlabel('epochs')

        plots = None
        for metric_type, metric_record in self.metric_stat[self.mode].items():
            axis = sorted([epoch for epoch in metric_record.keys() if epoch <= self.epoch])
            value = [metric_record[epoch] for epoch in axis]
            label = metric_type

            if metric_type == 'PSNR':
                ax = ax1
                color = 'C0'
            elif metric_type == 'SSIM':
                ax2 = ax1.twinx()
                ax = ax2
                color = 'C1'

            ax.set_ylabel(metric_type)
            if plots is None:
                plots = ax.plot(axis, value, label=label, color=color)
            else:
                plots += ax.plot(axis, value, label=label, color=color)

        labels = [plot.get_label() for plot in plots]
        plt.legend(plots, labels)
        plt.xlim(0, self.epoch)
        plt.savefig(plot_name)
        plt.close(fig)

        return

    def sort(self):
        """Re-key every stat dict in ascending epoch order; returns self."""
        for mode in self.modes:
            for loss_type, loss_epochs in self.loss_stat[mode].items():
                self.loss_stat[mode][loss_type] = {epoch: loss_epochs[epoch] for epoch in sorted(loss_epochs)}

            for metric_type, metric_epochs in self.metric_stat[mode].items():
                self.metric_stat[mode][metric_type] = {epoch: metric_epochs[epoch] for epoch in sorted(metric_epochs)}

        return self
| |
|