Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from jaxtyping import Float | |
| from lpips import LPIPS | |
| from torch import Tensor | |
| from src.dataset.types import BatchedExample | |
| from src.misc.nn_module_tools import convert_to_buffer | |
| from src.model.decoder.decoder import DecoderOutput | |
| from src.model.types import Gaussians | |
| from .loss import Loss | |
@dataclass
class LossLODCfg:
    """Configuration for LossLOD: relative weights of the two loss terms."""

    # Weight applied to the per-level MSE term.
    mse_weight: float
    # Weight applied to the per-level LPIPS term.
    lpips_weight: float
@dataclass
class LossLODCfgWrapper:
    """Wrapper holding the LOD loss config under its field name.

    NOTE(review): presumably the surrounding config system selects the loss
    by this field name ("lod") — confirm against the loss registry.
    """

    lod: LossLODCfg
| WEIGHT_LEVEL_MAPPING = {0: 0.1, 1: 0.1, 2: 0.2, 3: 0.6} | |
class LossLOD(Loss[LossLODCfg, LossLODCfgWrapper]):
    """Weighted MSE + LPIPS loss, averaged over the levels of an LOD rendering."""

    lpips: LPIPS

    def __init__(self, cfg: LossLODCfgWrapper) -> None:
        super().__init__(cfg)
        # LPIPS carries pretrained VGG weights; register them as non-persistent
        # buffers so they follow the module's device/dtype but are not stored
        # in checkpoints.
        self.lpips = LPIPS(net="vgg")
        convert_to_buffer(self.lpips, persistent=False)

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Return the mean over LOD levels of (mse_weight * MSE + lpips_weight * LPIPS).

        Each level is compared against the target images resized (bilinear) to
        that level's render resolution. Non-finite per-level losses are zeroed
        so a single exploding render cannot poison the total.

        Raises ZeroDivisionError if `prediction.lod_rendering` is empty.
        """
        # Target images; assumed (batch, view, channel, height, width) — the
        # flatten(0, 1) below merges the batch and view dims. TODO confirm.
        image = batch["target"]["image"]
        lod_rendering = prediction.lod_rendering

        loss = 0.0
        for level in lod_rendering.keys():
            # (b, v, c, h, w) -> (b*v, c, h, w)
            rendered_imgs = lod_rendering[level]["rendered_imgs"].flatten(0, 1)
            _h, _w = rendered_imgs.shape[2:]
            # Resize ground truth to this level's render resolution. No clone
            # needed: F.interpolate does not mutate its input.
            resized_image = F.interpolate(
                image.flatten(0, 1),
                size=(_h, _w),
                mode="bilinear",
                align_corners=False,
            )
            # F.mse_loss == ((x - y) ** 2).mean(); nan_to_num once per term
            # replaces NaN/±inf with 0 (previously applied redundantly twice).
            level_mse_loss = torch.nan_to_num(
                F.mse_loss(rendered_imgs, resized_image),
                nan=0.0, posinf=0.0, neginf=0.0,
            )
            # LPIPS expects inputs in [0, 1] when normalize=True.
            level_lpips_loss = torch.nan_to_num(
                self.lpips.forward(rendered_imgs, resized_image, normalize=True).mean(),
                nan=0.0, posinf=0.0, neginf=0.0,
            )
            loss = loss + (
                level_mse_loss * self.cfg.mse_weight
                + level_lpips_loss * self.cfg.lpips_weight
            )
        # Uniform average across levels. NOTE(review): WEIGHT_LEVEL_MAPPING
        # (defined above this class) is never applied here — confirm intent.
        return loss / len(lod_rendering)