from dataclasses import dataclass

import torch
from einops import rearrange
from jaxtyping import Float
from lpips import LPIPS
from torch import Tensor

from src.dataset.types import BatchedExample
from src.misc.nn_module_tools import convert_to_buffer
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians

from .loss import Loss


@dataclass
class LossLpipsCfg:
    weight: float
    apply_after_step: int
    conf: bool = False
    alpha: bool = False
    mask: bool = False


@dataclass
class LossLpipsCfgWrapper:
    lpips: LossLpipsCfg


class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]):
    lpips: LPIPS

    def __init__(self, cfg: LossLpipsCfgWrapper) -> None:
        super().__init__(cfg)

        self.lpips = LPIPS(net="vgg")
        # Freeze the LPIPS (VGG) weights: converting them to non-persistent
        # buffers keeps them out of the optimizer and out of checkpoints.
        convert_to_buffer(self.lpips, persistent=False)

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        # Map the context images from [-1, 1] to [0, 1] so they match the
        # rendered colors (LPIPS is called with normalize=True below).
        image = (batch["context"]["image"] + 1) / 2

        # Before the specified step, don't apply the loss.
        if global_step < self.cfg.apply_after_step:
            return torch.tensor(0, dtype=torch.float32, device=image.device)

        if self.cfg.mask or self.cfg.alpha or self.cfg.conf:
            # Pick the masking source: a dataset validity mask, the rendered
            # alpha map, or a depth-confidence mask.
            if self.cfg.mask:
                mask = batch["context"]["valid_mask"]
            elif self.cfg.alpha:
                mask = prediction.alpha
            elif self.cfg.conf:
                mask = depth_dict["conf_valid_mask"]

            b, v, c, h, w = prediction.color.shape
            # Broadcast the (b, v, h, w) mask across the channel dimension and
            # zero out invalid pixels in both the prediction and ground truth.
            expanded_mask = mask.unsqueeze(2).expand(-1, -1, c, -1, -1)
            masked_pred = prediction.color * expanded_mask
            masked_img = image * expanded_mask
            loss = self.lpips.forward(
                rearrange(masked_pred, "b v c h w -> (b v) c h w"),
                rearrange(masked_img, "b v c h w -> (b v) c h w"),
                normalize=True,
            )
        else:
            loss = self.lpips.forward(
                rearrange(prediction.color, "b v c h w -> (b v) c h w"),
                rearrange(image, "b v c h w -> (b v) c h w"),
                normalize=True,
            )

        # Guard against NaN/Inf values, which LPIPS can produce on degenerate
        # (e.g. fully masked) inputs.
        return self.cfg.weight * torch.nan_to_num(loss.mean(), nan=0.0, posinf=0.0, neginf=0.0)
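

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the training loop). It
# assumes DecoderOutput can be built with just a `color` tensor of shape
# (b, v, c, h, w) and that context images live in [-1, 1]; both are inferred
# from the forward pass above, and the config values are made up. Since this
# loss never reads `gaussians`, None stands in for a real Gaussians object.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    b, v, c, h, w = 1, 2, 3, 64, 64
    cfg = LossLpipsCfgWrapper(lpips=LossLpipsCfg(weight=0.05, apply_after_step=0))
    loss_fn = LossLpips(cfg)

    # Hypothetical inputs: rendered views in [0, 1], context images in [-1, 1].
    prediction = DecoderOutput(color=torch.rand(b, v, c, h, w))
    batch = {"context": {"image": torch.rand(b, v, c, h, w) * 2 - 1}}

    loss = loss_fn.forward(prediction, batch, gaussians=None, depth_dict=None, global_step=0)
    print(f"lpips loss: {loss.item():.4f}")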