Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch.nn as nn | |
| import torch as th | |
| import numpy as np | |
| import logging | |
| from .vgg import VGGLossMasked | |
| logger = logging.getLogger("dva.{__name__}") | |
| class DCTLoss(nn.Module): | |
| def __init__(self, weights): | |
| super().__init__() | |
| self.weights = weights | |
| def forward(self, inputs, preds, iteration=None): | |
| loss_dict = {"loss_total": 0.0} | |
| target = inputs['gt'] | |
| recon = preds['recon'] | |
| posterior = preds['posterior'] | |
| fft_gt = th.view_as_real(th.fft.fft(target.reshape(target.shape[0], -1))) | |
| fft_recon = th.view_as_real(th.fft.fft(recon.reshape(recon.shape[0], -1))) | |
| loss_recon_dct_l1 = th.mean(th.abs(fft_gt - fft_recon)) | |
| loss_recon_l1 = th.mean(th.abs(target - recon)) | |
| loss_kl = posterior.kl().mean() | |
| loss_dict.update(loss_recon_l1=loss_recon_l1, loss_recon_dct_l1=loss_recon_dct_l1, loss_kl=loss_kl) | |
| loss_total = self.weights.recon * loss_recon_dct_l1 + self.weights.kl * loss_kl | |
| loss_dict["loss_total"] = loss_total | |
| return loss_total, loss_dict | |
| class VAESepL2Loss(nn.Module): | |
| def __init__(self, weights): | |
| super().__init__() | |
| self.weights = weights | |
| def forward(self, inputs, preds, iteration=None): | |
| loss_dict = {"loss_total": 0.0} | |
| target = inputs['gt'] | |
| recon = preds['recon'] | |
| posterior = preds['posterior'] | |
| recon_diff = (target - recon) ** 2 | |
| loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...]) | |
| loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...]) | |
| loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...]) | |
| loss_kl = posterior.kl().mean() | |
| loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl) | |
| loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1 | |
| if "kl" in self.weights: | |
| loss_total += self.weights.kl * loss_kl | |
| loss_dict["loss_total"] = loss_total | |
| return loss_total, loss_dict | |
| class VAESepLoss(nn.Module): | |
| def __init__(self, weights): | |
| super().__init__() | |
| self.weights = weights | |
| def forward(self, inputs, preds, iteration=None): | |
| loss_dict = {"loss_total": 0.0} | |
| target = inputs['gt'] | |
| recon = preds['recon'] | |
| posterior = preds['posterior'] | |
| recon_diff = th.abs(target - recon) | |
| loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...]) | |
| loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...]) | |
| loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...]) | |
| loss_kl = posterior.kl().mean() | |
| loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl) | |
| loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1 | |
| if "kl" in self.weights: | |
| loss_total += self.weights.kl * loss_kl | |
| loss_dict["loss_total"] = loss_total | |
| return loss_total, loss_dict | |
| class VAELoss(nn.Module): | |
| def __init__(self, weights): | |
| super().__init__() | |
| self.weights = weights | |
| def forward(self, inputs, preds, iteration=None): | |
| loss_dict = {"loss_total": 0.0} | |
| target = inputs['gt'] | |
| recon = preds['recon'] | |
| posterior = preds['posterior'] | |
| loss_recon_l1 = th.mean(th.abs(target - recon)) | |
| loss_kl = posterior.kl().mean() | |
| loss_dict.update(loss_recon_l1=loss_recon_l1, loss_kl=loss_kl) | |
| loss_total = self.weights.recon * loss_recon_l1 + self.weights.kl * loss_kl | |
| loss_dict["loss_total"] = loss_total | |
| return loss_total, loss_dict | |
| class PrimSDFLoss(nn.Module): | |
| def __init__(self, weights, shape_opt_steps=2000, tex_opt_steps=6000): | |
| super().__init__() | |
| self.weights = weights | |
| self.shape_opt_steps = shape_opt_steps | |
| self.tex_opt_steps = tex_opt_steps | |
| def forward(self, inputs, preds, iteration=None): | |
| loss_dict = {"loss_total": 0.0} | |
| if iteration < self.shape_opt_steps: | |
| target_sdf = inputs['sdf'] | |
| sdf = preds['sdf'] | |
| loss_sdf_l1 = th.mean(th.abs(sdf - target_sdf)) | |
| loss_dict.update(loss_sdf_l1=loss_sdf_l1) | |
| loss_total = self.weights.sdf_l1 * loss_sdf_l1 | |
| prim_scale = preds["prim_scale"] | |
| # we use 1/scale instead of the original 100/scale as our scale is normalized to [-1, 1] cube | |
| if "vol_sum" in self.weights: | |
| loss_prim_vol_sum = th.mean(th.sum(th.prod(1 / prim_scale, dim=-1), dim=-1)) | |
| loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum) | |
| loss_total += self.weights.vol_sum * loss_prim_vol_sum | |
| if iteration >= self.shape_opt_steps and iteration < self.tex_opt_steps: | |
| target_tex = inputs['tex'] | |
| tex = preds['tex'] | |
| loss_tex_l1 = th.mean(th.abs(tex - target_tex)) | |
| loss_dict.update(loss_tex_l1=loss_tex_l1) | |
| loss_total = ( | |
| self.weights.rgb_l1 * loss_tex_l1 | |
| ) | |
| if "mat_l1" in self.weights: | |
| target_mat = inputs['mat'] | |
| mat = preds['mat'] | |
| loss_mat_l1 = th.mean(th.abs(mat - target_mat)) | |
| loss_dict.update(loss_mat_l1=loss_mat_l1) | |
| loss_total += self.weights.mat_l1 * loss_mat_l1 | |
| if "grad_l2" in self.weights: | |
| loss_grad_l2 = th.mean((preds["grad"] - inputs["grad"]) ** 2) | |
| loss_total += self.weights.grad_l2 * loss_grad_l2 | |
| loss_dict.update(loss_grad_l2=loss_grad_l2) | |
| loss_dict["loss_total"] = loss_total | |
| return loss_total, loss_dict | |
| class TotalMVPLoss(nn.Module): | |
| def __init__(self, weights, assets=None): | |
| super().__init__() | |
| self.weights = weights | |
| if "vgg" in self.weights: | |
| self.vgg_loss = VGGLossMasked() | |
| def forward(self, inputs, preds, iteration=None): | |
| loss_dict = {"loss_total": 0.0} | |
| B = inputs["image"].shape | |
| # rgb | |
| target_rgb = inputs["image"].permute(0, 2, 3, 1) | |
| # removing the mask | |
| target_rgb = target_rgb * inputs["image_mask"][:, 0, :, :, np.newaxis] | |
| rgb = preds["rgb"] | |
| loss_rgb_mse = th.mean(((rgb - target_rgb) / 16.0) ** 2.0) | |
| loss_dict.update(loss_rgb_mse=loss_rgb_mse) | |
| alpha = preds["alpha"] | |
| # mask loss | |
| target_mask = inputs["image_mask"][:, 0].to(th.float32) | |
| loss_mask_mae = th.mean((target_mask - alpha).abs()) | |
| loss_dict.update(loss_mask_mae=loss_mask_mae) | |
| B = alpha.shape[0] | |
| # beta prior on opacity | |
| loss_alpha_prior = th.mean( | |
| th.log(0.1 + alpha.reshape(B, -1)) | |
| + th.log(0.1 + 1.0 - alpha.reshape(B, -1)) | |
| - -2.20727 | |
| ) | |
| loss_dict.update(loss_alpha_prior=loss_alpha_prior) | |
| prim_scale = preds["prim_scale"] | |
| loss_prim_vol_sum = th.mean(th.sum(th.prod(100.0 / prim_scale, dim=-1), dim=-1)) | |
| loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum) | |
| loss_total = ( | |
| self.weights.rgb_mse * loss_rgb_mse | |
| + self.weights.mask_mae * loss_mask_mae | |
| + self.weights.alpha_prior * loss_alpha_prior | |
| + self.weights.prim_vol_sum * loss_prim_vol_sum | |
| ) | |
| if "embs_l2" in self.weights: | |
| loss_embs_l2 = th.sum(th.norm(preds["embs"], dim=1)) | |
| loss_total += self.weights.embs_l2 * loss_embs_l2 | |
| loss_dict.update(loss_embs_l2=loss_embs_l2) | |
| if "vgg" in self.weights: | |
| loss_vgg = self.vgg_loss( | |
| rgb.permute(0, 3, 1, 2), | |
| target_rgb.permute(0, 3, 1, 2), | |
| inputs["image_mask"], | |
| ) | |
| loss_total += self.weights.vgg * loss_vgg | |
| loss_dict.update(loss_vgg=loss_vgg) | |
| if "prim_scale_var" in self.weights: | |
| log_prim_scale = th.log(prim_scale) | |
| # NOTE: should we detach this? | |
| log_prim_scale_mean = th.mean(log_prim_scale, dim=1, keepdim=True) | |
| loss_prim_scale_var = th.mean((log_prim_scale - log_prim_scale_mean) ** 2.0) | |
| loss_total += self.weights.prim_scale_var * loss_prim_scale_var | |
| loss_dict.update(loss_prim_scale_var=loss_prim_scale_var) | |
| loss_dict["loss_total"] = loss_total | |
| return loss_total, loss_dict | |
| def process_losses(loss_dict, reduce=True, detach=True): | |
| """Preprocess the dict of losses outputs.""" | |
| result = { | |
| k.replace("loss_", ""): v for k, v in loss_dict.items() if k.startswith("loss_") | |
| } | |
| if detach: | |
| result = {k: v.detach() for k, v in result.items()} | |
| if reduce: | |
| result = {k: float(v.mean().item()) for k, v in result.items()} | |
| return result | |