Inmental's picture
Upload folder using huggingface_hub
4c62147 verified
import torch
import torch.nn as nn
import torchtask
def add_parser_arguments(parser):
    """Register the criterion's command-line options on ``parser``.

    This module adds no options of its own; registration is delegated
    entirely to ``torchtask.criterion_template.add_parser_arguments``.
    """
    template = torchtask.criterion_template
    template.add_parser_arguments(parser)
def harmonizer_loss():
    """Return the ``HarmonizerLoss`` class itself (not an instance).

    Callers are expected to instantiate it with their ``args``.
    """
    criterion_cls = HarmonizerLoss
    return criterion_cls
class AbsoluteLoss(nn.Module):
    """Element-wise smooth absolute difference: ``sqrt((pred - gt)**2 + epsilon)``.

    Approximates ``|pred - gt|`` while staying differentiable where
    ``pred == gt``. No reduction is applied, so the result has the
    broadcast shape of ``pred`` and ``gt``.

    Args:
        epsilon: Constant added under the square root.
    """

    def __init__(self, epsilon=1e-6):
        super().__init__()
        # Keeps the sqrt argument strictly positive, so the gradient is finite at zero error.
        self.epsilon = epsilon

    def forward(self, pred, gt):
        """Return the unreduced smooth-L1 map between ``pred`` and ``gt``."""
        residual = pred - gt
        return (residual.pow(2) + self.epsilon).sqrt()
class HarmonizerLoss(torchtask.criterion_template.TaskCriterion):
    """Masked smooth-L1 plus 10x-weighted masked MSE, computed per prediction.

    For each (prediction, target) pair, the per-element smooth-L1
    (``AbsoluteLoss``) and squared-error maps are multiplied by ``mask``,
    summed over dims 1-3, and divided by the mask's sum over dims 1-3
    (plus 1e-6 to avoid division by zero).
    """

    def __init__(self, args):
        super(HarmonizerLoss, self).__init__(args)
        self.l1 = AbsoluteLoss()
        # reduction='none' keeps the per-element map so it can be masked before reducing.
        self.l2 = nn.MSELoss(reduction='none')

    def forward(self, pred, gt, inp):
        """Compute one masked loss per output.

        Args:
            pred: A 1-tuple whose single element is a sequence of predicted tensors.
            gt: A sequence of target tensors, the same length as the predictions.
            inp: A ``(x, mask)`` pair. Only ``mask`` is used here; ``x`` is ignored.
                Presumably the tensors are (B, C, H, W) and ``mask`` broadcasts
                against them -- TODO confirm with the data pipeline.

        Returns:
            A list with one loss tensor per (prediction, target) pair, each
            reduced over dims 1-3 (presumably shape (B,)).

        Raises:
            AssertionError: If the numbers of predictions and targets differ.
        """
        pred_outputs, = pred
        _, mask = inp
        assert len(pred_outputs) == len(gt), (
            f"got {len(pred_outputs)} predictions but {len(gt)} targets")

        # The normalizer depends only on the mask, so compute it once rather
        # than twice per output. The epsilon guards against an all-zero mask.
        mask_area = torch.sum(mask, dim=(1, 2, 3)) + 1e-6

        image_losses = []
        for pred_, gt_ in zip(pred_outputs, gt):
            l1_loss = torch.sum(self.l1(pred_, gt_) * mask, dim=(1, 2, 3)) / mask_area
            # The squared-error term is weighted 10x relative to the smooth-L1 term.
            l2_loss = torch.sum(self.l2(pred_, gt_) * mask, dim=(1, 2, 3)) / mask_area * 10
            image_losses.append(l1_loss + l2_loss)
        return image_losses