import torch
def multi_scale_grad_loss(prediction, target, mask, num_scales=4):
    """Sum the masked gradient-matching loss over several resolutions.

    Evaluates ``grad_loss`` on the full-resolution inputs and on
    progressively coarser versions obtained by strided subsampling
    (stride ``2**scale`` along the last two dimensions), then returns
    the sum of all per-scale losses.

    Args:
        prediction: predicted maps; assumed shape (batch, H, W) — the
            spatial slicing below only touches dims 1 and 2 (TODO confirm).
        target: ground-truth maps, same shape as ``prediction``.
        mask: validity mask, same shape; nonzero where a pixel counts.
        num_scales: number of pyramid levels to evaluate (default 4,
            matching the original hard-coded behavior).

    Returns:
        Scalar tensor: sum of ``grad_loss`` over all scales.
    """
    total = 0
    for scale in range(num_scales):
        step = 2 ** scale  # subsampling stride doubles at each level
        total += grad_loss(
            prediction[:, ::step, ::step],
            target[:, ::step, ::step],
            mask[:, ::step, ::step],
        )
    return total
def grad_loss(prediction, target, mask):
    """Masked image-gradient matching loss for a batch of maps.

    Computes the L1 norm of the horizontal and vertical finite
    differences of the masked residual ``mask * (prediction - target)``,
    keeping only difference terms where both neighboring pixels are
    valid, and normalizes by the total number of valid pixels.

    Args:
        prediction: predicted maps; assumed (batch, H, W) — reductions
            below are over dims (1, 2) (TODO confirm with callers).
        target: ground-truth maps, same shape.
        mask: validity mask, same shape; nonzero marks valid pixels.

    Returns:
        Scalar tensor: summed gradient loss divided by the total valid
        pixel count across the batch.
    """
    valid_per_image = torch.sum(mask, (1, 2))
    residual = mask * (prediction - target)

    # Horizontal differences, valid only where both x-neighbors are valid.
    dx = (residual[:, :, 1:] - residual[:, :, :-1]).abs()
    dx = dx * (mask[:, :, 1:] * mask[:, :, :-1])

    # Vertical differences, valid only where both y-neighbors are valid.
    dy = (residual[:, 1:, :] - residual[:, :-1, :]).abs()
    dy = dy * (mask[:, 1:, :] * mask[:, :-1, :])

    per_image = dx.sum((1, 2)) + dy.sum((1, 2))
    return per_image.sum() / valid_per_image.sum()