| import torch | |
| def multi_scale_grad_loss(prediction, target, mask): | |
| total = 0 | |
| for scale in range(4): | |
| step = pow(2, scale) | |
| total += grad_loss(prediction[:, ::step, ::step], | |
| target[:, ::step, ::step], | |
| mask[:, ::step, ::step]) | |
| return total | |
| def grad_loss(prediction, target, mask): | |
| M = torch.sum(mask, (1, 2)) | |
| diff = prediction - target | |
| diff = torch.mul(mask, diff) | |
| grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) | |
| mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) | |
| grad_x = torch.mul(mask_x, grad_x) | |
| grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) | |
| mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) | |
| grad_y = torch.mul(mask_y, grad_y) | |
| image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) | |
| return torch.sum(image_loss) / torch.sum(M) |