File size: 1,250 Bytes
4c62147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

import torch
import torch.nn as nn

import torchtask


def add_parser_arguments(parser):
    """Register this criterion's command-line options on *parser*.

    All work is delegated to ``torchtask.criterion_template``; this module
    adds no options of its own.
    """
    template = torchtask.criterion_template
    template.add_parser_arguments(parser)



def harmonizer_loss():
    """Return the criterion *class* (not an instance) for this task.

    The class constructor takes a single ``args`` object, so callers
    presumably instantiate it themselves once arguments are parsed.
    """
    criterion_cls = HarmonizerLoss
    return criterion_cls


class AbsoluteLoss(nn.Module):
    """Element-wise smooth absolute error ``sqrt((pred - gt) ** 2 + epsilon)``.

    This is a Charbonnier-style stand-in for ``|pred - gt|`` that remains
    differentiable at zero. No reduction is applied: the result has the
    broadcast shape of ``pred`` and ``gt``.

    Args:
        epsilon: Constant added under the square root. It keeps the gradient
            finite where ``pred == gt`` and puts a floor of ``sqrt(epsilon)``
            on every element of the loss.
    """

    def __init__(self, epsilon=1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, pred, gt):
        diff = pred - gt
        return torch.sqrt(diff ** 2 + self.epsilon)


class HarmonizerLoss(torchtask.criterion_template.TaskCriterion):
    """Masked smooth-L1 plus weighted L2 loss, one value per sample.

    For every (prediction, target) pair:

    * the element-wise ``AbsoluteLoss`` and squared errors are multiplied by
      ``mask``;
    * each is summed over dims 1-3;
    * each is divided by the mask's own sum over the same dims.

    The L2 term is scaled by ``l2_weight`` and added to the L1 term.

    NOTE(review): if ``mask`` has fewer channels than the prediction
    (e.g. (B, 1, H, W) against (B, 3, H, W)), the numerator sums over every
    channel while the denominator only counts mask pixels. That makes the
    loss scale with the channel count. Confirm this is intended.
    """

    # Weight of the masked L2 term relative to the masked L1 term.
    l2_weight = 10
    # Added to the mask area so an all-zero mask does not divide by zero.
    mask_eps = 1e-6

    def __init__(self, args):
        super(HarmonizerLoss, self).__init__(args)

        self.l1 = AbsoluteLoss()
        self.l2 = nn.MSELoss(reduction='none')

    def forward(self, pred, gt, inp):
        """Compute a per-sample loss tensor for each predicted output.

        Args:
            pred: A one-element sequence whose single item is the sequence of
                predicted outputs.
            gt: Sequence of targets, paired positionally with the outputs.
            inp: An ``(x, mask)`` pair. Only ``mask`` is used here.

        Returns:
            A list with one tensor per output, each reduced over dims 1-3.

        Raises:
            ValueError: If the number of predicted outputs and targets differ.
        """
        pred_outputs, = pred
        _, mask = inp  # the input image itself does not enter the loss

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(pred_outputs) != len(gt):
            raise ValueError(
                'HarmonizerLoss: got {} predicted outputs but {} targets'.format(
                    len(pred_outputs), len(gt)))

        # The mask is shared by all outputs, so compute its normaliser once
        # rather than twice per iteration.
        reduce_dims = (1, 2, 3)
        mask_area = torch.sum(mask, dim=reduce_dims) + self.mask_eps

        image_losses = []
        for pred_, gt_ in zip(pred_outputs, gt):
            l1_loss = torch.sum(self.l1(pred_, gt_) * mask, dim=reduce_dims) / mask_area
            l2_loss = torch.sum(self.l2(pred_, gt_) * mask, dim=reduce_dims) / mask_area * self.l2_weight
            image_losses.append(l1_loss + l2_loss)

        return image_losses