File size: 850 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch

def multi_scale_grad_loss(prediction, target, mask, scales=4):
    """Sum ``grad_loss`` over a pyramid built by strided subsampling.

    At pyramid level ``s`` (``0 <= s < scales``) every ``2**s``-th element
    along dims 1 and 2 of each input is kept, and ``grad_loss`` is evaluated
    on the subsampled tensors. The per-level losses are summed.

    Args:
        prediction: Tensor indexed as ``[batch, dim1, dim2]`` (presumably
            ``(B, H, W)`` — TODO confirm against callers).
        target: Tensor with the same shape as ``prediction``.
        mask: Tensor with the same shape; presumably a validity mask.
        scales: Number of pyramid levels. Defaults to 4, which was the
            previously hard-coded value, so existing callers are unaffected.

    Returns:
        The sum of the per-level ``grad_loss`` values (the int ``0`` when
        ``scales`` is 0).
    """
    total = 0
    for scale in range(scales):
        # Plain strided slicing (no averaging): level s keeps every 2**s-th row/column.
        step = 2 ** scale
        total += grad_loss(
            prediction[:, ::step, ::step],
            target[:, ::step, ::step],
            mask[:, ::step, ::step],
        )
    return total

def grad_loss(prediction, target, mask):
    """Masked first-order gradient-matching loss.

    The residual ``prediction - target`` is multiplied by ``mask``. Absolute
    differences between neighbouring residual entries are then taken along
    the last dim ("x") and along dim 1 ("y"). Each difference counts only
    where both neighbours are inside the mask (product of the two mask
    entries). The summed differences over the whole batch are divided by the
    sum of ``mask`` over the whole batch.

    Args:
        prediction: Tensor indexed as ``[batch, dim1, dim2]`` (presumably
            ``(B, H, W)`` — TODO confirm against callers).
        target: Tensor with the same shape as ``prediction``.
        mask: Tensor with the same shape; presumably a binary validity mask
            (float or bool) — TODO confirm.

    Returns:
        Scalar tensor. When ``mask`` sums to exactly zero, returns 0 (with
        finite gradients) instead of NaN.
    """
    diff = torch.mul(mask, prediction - target)

    # Differences along the last dim; both neighbours must be masked in.
    mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
    grad_x = torch.mul(mask_x, torch.abs(diff[:, :, 1:] - diff[:, :, :-1]))

    # Differences along dim 1; both neighbours must be masked in.
    mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
    grad_y = torch.mul(mask_y, torch.abs(diff[:, 1:, :] - diff[:, :-1, :]))

    image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
    divisor = torch.sum(mask)

    # An all-zero mask (easy to hit at coarse subsampled scales) would give
    # 0/0 = NaN in both the forward and backward pass. The numerator is
    # already 0 in that case, so dividing by 1 instead yields 0. torch.where
    # avoids a data-dependent Python branch (no host/device sync).
    safe_divisor = torch.where(divisor != 0, divisor, torch.ones_like(divisor))
    return torch.sum(image_loss) / safe_divisor