"""Masked L1/L2 loss criterion (``HarmonizerLoss``) for the torchtask framework."""

import torch
import torch.nn as nn

import torchtask


def add_parser_arguments(parser):
    """Register this criterion's command-line arguments on *parser*.

    No criterion-specific options are added here; this simply delegates to
    ``torchtask.criterion_template.add_parser_arguments``.
    """
    torchtask.criterion_template.add_parser_arguments(parser)


def harmonizer_loss():
    """Return the ``HarmonizerLoss`` class itself (not an instance)."""
    return HarmonizerLoss


class AbsoluteLoss(nn.Module):
    """Element-wise smooth absolute error ``sqrt((pred - gt) ** 2 + epsilon)``.

    This is a differentiable approximation of ``|pred - gt|`` (Charbonnier
    penalty): ``epsilon`` keeps the value and gradient finite where
    ``pred == gt``. No reduction is applied; the output has the broadcast
    shape of ``pred`` and ``gt``.

    Args:
        epsilon: small constant added under the square root.
    """

    def __init__(self, epsilon=1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, pred, gt):
        """Return the un-reduced smooth absolute error between *pred* and *gt*."""
        return torch.sqrt((pred - gt) ** 2 + self.epsilon)


class HarmonizerLoss(torchtask.criterion_template.TaskCriterion):
    """Masked smooth-L1 plus weighted L2 loss, one value per sample per output.

    For every prediction/target pair, both element-wise errors are multiplied
    by ``mask``, summed over dims 1-3, and divided by the mask's sum over the
    same dims plus ``mask_eps``. The epsilon avoids division by zero for an
    all-zero mask. The L2 term is then scaled by ``l2_weight``.

    Args:
        args: parsed arguments, forwarded unchanged to ``TaskCriterion``.
        l2_weight: multiplier applied to the masked L2 term (default ``10``).
        mask_eps: added to the mask sum before dividing (default ``1e-6``).
    """

    def __init__(self, args, *, l2_weight=10, mask_eps=1e-6):
        super().__init__(args)
        self.l1 = AbsoluteLoss()
        self.l2 = nn.MSELoss(reduction='none')  # keep per-element errors for masking
        self.l2_weight = l2_weight
        self.mask_eps = mask_eps

    def forward(self, pred, gt, inp):
        """Compute the masked loss for each prediction.

        Args:
            pred: one-element sequence whose single item is the sequence of
                predictions.
            gt: sequence of targets, one per prediction.
            inp: two-element sequence ``(x, mask)``; only ``mask`` is used.
                The reductions over dims ``(1, 2, 3)`` imply 4-D tensors,
                presumably ``(N, C, H, W)`` (TODO confirm).

        Returns:
            list with one loss tensor per prediction, each reduced over
            dims 1-3 (i.e. one value per sample along dim 0).

        Raises:
            ValueError: if the numbers of predictions and targets differ.
        """
        pred_outputs, = pred
        _, mask = inp

        if len(pred_outputs) != len(gt):
            raise ValueError(
                f"expected as many targets as predictions, "
                f"got {len(gt)} targets for {len(pred_outputs)} predictions"
            )

        dims = (1, 2, 3)
        # The mask is shared by every output, so its normaliser is computed once.
        # NOTE(review): if mask has fewer channels than the predictions (e.g. 1
        # vs 3), it broadcasts in the numerators below but not in this
        # denominator, scaling the loss by the channel count -- confirm intended.
        mask_area = torch.sum(mask, dim=dims) + self.mask_eps

        image_losses = []
        for pred_, gt_ in zip(pred_outputs, gt):
            l1_loss = torch.sum(self.l1(pred_, gt_) * mask, dim=dims) / mask_area
            l2_loss = torch.sum(self.l2(pred_, gt_) * mask, dim=dims) / mask_area * self.l2_weight
            image_losses.append(l1_loss + l2_loss)
        return image_losses