File size: 1,250 Bytes
4c62147 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import torch
import torch.nn as nn
import torchtask
def add_parser_arguments(parser):
    """Register this criterion's command-line options on *parser*.

    All options come from the shared criterion template; this module adds no
    extra options of its own.
    """
    template = torchtask.criterion_template
    template.add_parser_arguments(parser)
def harmonizer_loss():
    """Return the harmonizer criterion *class* (not an instance).

    The caller instantiates it with the run arguments, e.g.
    ``harmonizer_loss()(args)``.
    """
    criterion_cls = HarmonizerLoss
    return criterion_cls
class AbsoluteLoss(nn.Module):
    """Element-wise smooth absolute error (Charbonnier penalty).

    Computes ``sqrt((pred - gt) ** 2 + epsilon)`` for every element. Away from
    zero this tracks ``|pred - gt|``, but it stays differentiable at zero.
    No reduction is applied: the result has the broadcast shape of ``pred``
    and ``gt``.
    """

    def __init__(self, epsilon=1e-6):
        super().__init__()
        # Added under the square root so the gradient stays finite when
        # pred == gt.
        self.epsilon = epsilon

    def forward(self, pred, gt):
        diff = pred - gt
        return torch.sqrt(diff.pow(2) + self.epsilon)
class HarmonizerLoss(torchtask.criterion_template.TaskCriterion):
    """Masked reconstruction loss for harmonized image predictions.

    For every predicted image, the loss is the sum of two terms, each taken as
    a mask-weighted average: the smooth absolute error (``AbsoluteLoss``) and
    ``L2_WEIGHT`` times the squared error. One loss tensor is returned per
    prediction.
    """

    # Weight of the masked squared-error term relative to the masked L1 term.
    L2_WEIGHT = 10
    # Added to the mask area so the normalisation stays finite for an empty mask.
    MASK_EPS = 1e-6

    def __init__(self, args):
        super(HarmonizerLoss, self).__init__(args)
        self.l1 = AbsoluteLoss()
        # Element-wise squared error; masking and reduction happen in forward().
        self.l2 = nn.MSELoss(reduction='none')

    def forward(self, pred, gt, inp):
        """Compute the masked loss for each predicted image.

        Args:
            pred: one-element sequence whose only item is the sequence of
                predicted images.
            gt: sequence of ground-truth images, one per prediction.
            inp: ``(x, mask)`` pair; only ``mask`` is used here. It is reduced
                over dims 1-3, so it is presumably 4-D (e.g. B x C x H x W) --
                TODO confirm against the data pipeline.

        Returns:
            list with one tensor per prediction, each reduced over dims 1-3
            (presumably one value per batch sample).

        Raises:
            ValueError: if the numbers of predictions and ground truths differ.
        """
        pred_outputs, = pred
        _, mask = inp

        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if len(pred_outputs) != len(gt):
            raise ValueError(
                'HarmonizerLoss: got {} predictions but {} ground truths'.format(
                    len(pred_outputs), len(gt)))

        # The normaliser depends only on the mask, so compute it once instead of
        # twice per prediction.
        # NOTE(review): if the mask has fewer channels than the images (e.g. 1 vs 3),
        # it broadcasts in the numerators below but is counted once here, which
        # scales the loss by the channel count -- confirm this is intended.
        mask_area = torch.sum(mask, dim=(1, 2, 3)) + self.MASK_EPS

        image_losses = []
        for pred_, gt_ in zip(pred_outputs, gt):
            l1_loss = torch.sum(self.l1(pred_, gt_) * mask, dim=(1, 2, 3)) / mask_area
            l2_loss = (torch.sum(self.l2(pred_, gt_) * mask, dim=(1, 2, 3)) / mask_area
                       * self.L2_WEIGHT)
            image_losses.append(l1_loss + l2_loss)
        return image_losses
|