Spaces:
Paused
Paused
| # | |
| # Copyright (C) 2023, Inria | |
| # GRAPHDECO research group, https://team.inria.fr/graphdeco | |
| # All rights reserved. | |
| # | |
| # This software is free for non-commercial, research and evaluation use | |
| # under the terms of the LICENSE.md file. | |
| # | |
| # For inquiries contact george.drettakis@inria.fr | |
| # | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.autograd import Variable | |
| from math import exp | |
| from torch import Tensor, nn | |
| from typing import Dict, Literal, Optional, Tuple, cast | |
| from jaxtyping import Bool, Float | |
| L1Loss = nn.L1Loss | |
| MSELoss = nn.MSELoss | |
| def l1_loss(network_output, gt, mask=None): | |
| l1 = torch.abs((network_output - gt)) | |
| if mask is not None: | |
| l1 = l1[:, mask] | |
| return l1.mean() | |
| def l2_loss(network_output, gt): | |
| return ((network_output - gt) ** 2).mean() | |
| def gaussian(window_size, sigma): | |
| gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) | |
| return gauss / gauss.sum() | |
| def create_window(window_size, channel): | |
| _1D_window = gaussian(window_size, 1.5).unsqueeze(1) | |
| _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) | |
| window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) | |
| return window | |
| def ssim(img1, img2, window_size=11, size_average=True): | |
| channel = img1.size(-3) | |
| window = create_window(window_size, channel) | |
| if img1.is_cuda: | |
| window = window.cuda(img1.get_device()) | |
| window = window.type_as(img1) | |
| return _ssim(img1, img2, window, window_size, channel, size_average) | |
| def _ssim(img1, img2, window, window_size, channel, size_average=True): | |
| mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) | |
| mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) | |
| mu1_sq = mu1.pow(2) | |
| mu2_sq = mu2.pow(2) | |
| mu1_mu2 = mu1 * mu2 | |
| sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq | |
| sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq | |
| sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 | |
| C1 = 0.01 ** 2 | |
| C2 = 0.03 ** 2 | |
| ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) | |
| if size_average: | |
| return ssim_map.mean() | |
| else: | |
| return ssim_map.mean(1).mean(1).mean(1) | |
| def ssim_loss(img1, img2, window_size=11, size_average=True, mask=None): | |
| channel = img1.size(-3) | |
| window = create_window(window_size, channel) | |
| if img1.is_cuda: | |
| window = window.cuda(img1.get_device()) | |
| window = window.type_as(img1) | |
| return _ssim_loss(img1, img2, window, window_size, channel, size_average, mask) | |
| def _ssim_loss(img1, img2, window, window_size, channel, size_average=True, mask=None): | |
| mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) | |
| mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) | |
| mu1_sq = mu1.pow(2) | |
| mu2_sq = mu2.pow(2) | |
| mu1_mu2 = mu1 * mu2 | |
| sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq | |
| sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq | |
| sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 | |
| C1 = 0.01 ** 2 | |
| C2 = 0.03 ** 2 | |
| ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) | |
| ssim_map = 1 - ssim_map | |
| if mask is not None: | |
| ssim_map = ssim_map[:, mask] | |
| if size_average: | |
| return ssim_map.mean() | |
| else: | |
| return ssim_map.mean(1).mean(1).mean(1) | |
| def masked_reduction( | |
| input_tensor: Float[Tensor, "1 32 mult"], | |
| mask: Bool[Tensor, "1 32 mult"], | |
| reduction_type: Literal["image", "batch"], | |
| ) -> Tensor: | |
| """ | |
| Whether to consolidate the input_tensor across the batch or across the image | |
| Args: | |
| input_tensor: input tensor | |
| mask: mask tensor | |
| reduction_type: either "batch" or "image" | |
| Returns: | |
| input_tensor: reduced input_tensor | |
| """ | |
| if reduction_type == "batch": | |
| # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0) | |
| divisor = torch.sum(mask) | |
| if divisor == 0: | |
| return torch.tensor(0, device=input_tensor.device) | |
| input_tensor = torch.sum(input_tensor) / divisor | |
| elif reduction_type == "image": | |
| # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0) | |
| valid = mask.nonzero() | |
| input_tensor[valid] = input_tensor[valid] / mask[valid] | |
| input_tensor = torch.mean(input_tensor) | |
| return input_tensor | |
| def normalized_depth_scale_and_shift( | |
| prediction: Float[Tensor, "1 32 mult"], target: Float[Tensor, "1 32 mult"], mask: Bool[Tensor, "1 32 mult"] | |
| ): | |
| """ | |
| More info here: https://arxiv.org/pdf/2206.00665.pdf supplementary section A2 Depth Consistency Loss | |
| This function computes scale/shift required to normalizes predicted depth map, | |
| to allow for using normalized depth maps as input from monocular depth estimation networks. | |
| These networks are trained such that they predict normalized depth maps. | |
| Solves for scale/shift using a least squares approach with a closed form solution: | |
| Based on: | |
| https://github.com/autonomousvision/monosdf/blob/d9619e948bf3d85c6adec1a643f679e2e8e84d4b/code/model/loss.py#L7 | |
| Args: | |
| prediction: predicted depth map | |
| target: ground truth depth map | |
| mask: mask of valid pixels | |
| Returns: | |
| scale and shift for depth prediction | |
| """ | |
| # system matrix: A = [[a_00, a_01], [a_10, a_11]] | |
| a_00 = torch.sum(mask * prediction * prediction, (1, 2)) | |
| a_01 = torch.sum(mask * prediction, (1, 2)) | |
| a_11 = torch.sum(mask, (1, 2)) | |
| # right hand side: b = [b_0, b_1] | |
| b_0 = torch.sum(mask * prediction * target, (1, 2)) | |
| b_1 = torch.sum(mask * target, (1, 2)) | |
| # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b | |
| scale = torch.zeros_like(b_0) | |
| shift = torch.zeros_like(b_1) | |
| det = a_00 * a_11 - a_01 * a_01 | |
| valid = det.nonzero() | |
| scale[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] | |
| shift[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] | |
| return scale, shift | |
| class MiDaSMSELoss(nn.Module): | |
| """ | |
| data term from MiDaS paper | |
| """ | |
| def __init__(self, reduction_type: Literal["image", "batch"] = "batch"): | |
| super().__init__() | |
| self.reduction_type: Literal["image", "batch"] = reduction_type | |
| # reduction here is different from the image/batch-based reduction. This is either "mean" or "sum" | |
| self.mse_loss = MSELoss(reduction="none") | |
| def forward( | |
| self, | |
| prediction: Float[Tensor, "1 32 mult"], | |
| target: Float[Tensor, "1 32 mult"], | |
| mask: Bool[Tensor, "1 32 mult"], | |
| ) -> Float[Tensor, "0"]: | |
| """ | |
| Args: | |
| prediction: predicted depth map | |
| target: ground truth depth map | |
| mask: mask of valid pixels | |
| Returns: | |
| mse loss based on reduction function | |
| """ | |
| summed_mask = torch.sum(mask, (1, 2)) | |
| image_loss = torch.sum(self.mse_loss(prediction, target) * mask, (1, 2)) | |
| # multiply by 2 magic number? | |
| image_loss = masked_reduction(image_loss, 2 * summed_mask, self.reduction_type) | |
| return image_loss | |
| class GradientLoss(nn.Module): | |
| """ | |
| multiscale, scale-invariant gradient matching term to the disparity space. | |
| This term biases discontinuities to be sharp and to coincide with discontinuities in the ground truth | |
| More info here https://arxiv.org/pdf/1907.01341.pdf Equation 11 | |
| """ | |
| def __init__(self, scales: int = 4, reduction_type: Literal["image", "batch"] = "batch"): | |
| """ | |
| Args: | |
| scales: number of scales to use | |
| reduction_type: either "batch" or "image" | |
| """ | |
| super().__init__() | |
| self.reduction_type: Literal["image", "batch"] = reduction_type | |
| self.__scales = scales | |
| def forward( | |
| self, | |
| prediction: Float[Tensor, "1 32 mult"], | |
| target: Float[Tensor, "1 32 mult"], | |
| mask: Bool[Tensor, "1 32 mult"], | |
| ) -> Float[Tensor, "0"]: | |
| """ | |
| Args: | |
| prediction: predicted depth map | |
| target: ground truth depth map | |
| mask: mask of valid pixels | |
| Returns: | |
| gradient loss based on reduction function | |
| """ | |
| assert self.__scales >= 1 | |
| total = 0.0 | |
| for scale in range(self.__scales): | |
| step = pow(2, scale) | |
| grad_loss = self.gradient_loss( | |
| prediction[:, ::step, ::step], | |
| target[:, ::step, ::step], | |
| mask[:, ::step, ::step], | |
| ) | |
| total += grad_loss | |
| assert isinstance(total, Tensor) | |
| return total | |
| def gradient_loss( | |
| self, | |
| prediction: Float[Tensor, "1 32 mult"], | |
| target: Float[Tensor, "1 32 mult"], | |
| mask: Bool[Tensor, "1 32 mult"], | |
| ) -> Float[Tensor, "0"]: | |
| """ | |
| multiscale, scale-invariant gradient matching term to the disparity space. | |
| This term biases discontinuities to be sharp and to coincide with discontinuities in the ground truth | |
| More info here https://arxiv.org/pdf/1907.01341.pdf Equation 11 | |
| Args: | |
| prediction: predicted depth map | |
| target: ground truth depth map | |
| reduction: reduction function, either reduction_batch_based or reduction_image_based | |
| Returns: | |
| gradient loss based on reduction function | |
| """ | |
| summed_mask = torch.sum(mask, (1, 2)) | |
| diff = prediction - target | |
| diff = torch.mul(mask, diff) | |
| grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) | |
| mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) | |
| grad_x = torch.mul(mask_x, grad_x) | |
| grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) | |
| mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) | |
| grad_y = torch.mul(mask_y, grad_y) | |
| image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) | |
| image_loss = masked_reduction(image_loss, summed_mask, self.reduction_type) | |
| return image_loss | |
| class ScaleAndShiftInvariantLoss(nn.Module): | |
| """ | |
| Scale and shift invariant loss as described in | |
| "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer" | |
| https://arxiv.org/pdf/1907.01341.pdf | |
| """ | |
| def __init__(self, alpha: float = 0.5, scales: int = 4, reduction_type: Literal["image", "batch"] = "batch"): | |
| """ | |
| Args: | |
| alpha: weight of the regularization term | |
| scales: number of scales to use | |
| reduction_type: either "batch" or "image" | |
| """ | |
| super().__init__() | |
| self.__data_loss = MiDaSMSELoss(reduction_type=reduction_type) | |
| self.__regularization_loss = GradientLoss(scales=scales, reduction_type=reduction_type) | |
| self.__alpha = alpha | |
| self.__prediction_ssi = None | |
| def forward( | |
| self, | |
| prediction: Float[Tensor, "1 32 mult"], | |
| target: Float[Tensor, "1 32 mult"], | |
| mask: Bool[Tensor, "1 32 mult"], | |
| ) -> Float[Tensor, "0"]: | |
| """ | |
| Args: | |
| prediction: predicted depth map (unnormalized) | |
| target: ground truth depth map (normalized) | |
| mask: mask of valid pixels | |
| Returns: | |
| scale and shift invariant loss | |
| """ | |
| scale, shift = normalized_depth_scale_and_shift(prediction, target, mask) | |
| self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) | |
| total = self.__data_loss(self.__prediction_ssi, target, mask) | |
| # if self.__alpha > 0: | |
| # total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask) | |
| return total | |
| def __get_prediction_ssi(self): | |
| """ | |
| scale and shift invariant prediction | |
| from https://arxiv.org/pdf/1907.01341.pdf equation 1 | |
| """ | |
| return self.__prediction_ssi | |
| prediction_ssi = property(__get_prediction_ssi) |