| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn.functional as F |
| from einops import rearrange |
| from jaxtyping import Float |
| from lpips import LPIPS |
| from torch import Tensor |
|
|
| from src.dataset.types import BatchedExample |
| from src.misc.nn_module_tools import convert_to_buffer |
| from src.model.decoder.decoder import DecoderOutput |
| from src.model.types import Gaussians |
| from .loss import Loss |
|
|
|
|
@dataclass
class LossLODCfg:
    # Scalar weight applied to the per-level MSE term of the loss.
    mse_weight: float
    # Scalar weight applied to the per-level LPIPS (perceptual) term.
    lpips_weight: float
|
|
@dataclass
class LossLODCfgWrapper:
    # Wrapper keying the config under "lod" — presumably matches the loss
    # name used by the config system; verify against the config loader.
    lod: LossLODCfg
|
|
# Per-level loss weights keyed by LOD level index (values sum to 1.0).
# NOTE(review): not referenced anywhere in this file — LossLOD.forward
# averages levels uniformly instead. Confirm whether these weights were
# meant to be applied per level.
WEIGHT_LEVEL_MAPPING = {0: 0.1, 1: 0.1, 2: 0.2, 3: 0.6}
|
|
class LossLOD(Loss[LossLODCfg, LossLODCfgWrapper]):
    """Weighted MSE + LPIPS loss computed per level-of-detail rendering.

    Each LOD level is compared against the ground-truth image resized to
    that level's resolution; the per-level losses are averaged uniformly.
    """

    lpips: LPIPS

    def __init__(self, cfg: LossLODCfgWrapper) -> None:
        super().__init__(cfg)

        self.lpips = LPIPS(net="vgg")
        # LPIPS weights are fixed; register them as non-persistent buffers
        # so they are moved with the module but kept out of checkpoints.
        convert_to_buffer(self.lpips, persistent=False)

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Return the mean over LOD levels of (mse_weight * MSE + lpips_weight * LPIPS).

        Args:
            prediction: Decoder output; `prediction.lod_rendering` maps each
                level to a dict holding "rendered_imgs" of shape
                (batch, view, channels, height, width) — TODO confirm layout.
            batch: Batched example providing the target images.
            gaussians: Unused here; kept for the common Loss interface.
            global_step: Unused here; kept for the common Loss interface.
        """
        image = batch["target"]["image"]

        def _safe_mse(x: Tensor, y: Tensor) -> Tensor:
            # Mean squared error with NaN/Inf zeroed so one bad rendering
            # cannot poison the whole training step. (Original applied
            # nan_to_num twice and a redundant double .mean(); once each
            # is equivalent.)
            delta = x - y
            return torch.nan_to_num((delta**2).mean(), nan=0.0, posinf=0.0, neginf=0.0)

        lod_rendering = prediction.lod_rendering
        if not lod_rendering:
            # Guard: the original raised ZeroDivisionError on an empty dict.
            return torch.zeros((), device=image.device)

        loss = torch.zeros((), device=image.device)
        for level in lod_rendering:
            # Merge (batch, view) into one leading dimension for the metrics.
            rendered_imgs = lod_rendering[level]["rendered_imgs"].flatten(0, 1)
            _h, _w = rendered_imgs.shape[2:]
            # Resize ground truth to this level's resolution. No clone needed:
            # F.interpolate does not modify its input.
            resized_image = F.interpolate(
                image.flatten(0, 1), size=(_h, _w), mode="bilinear", align_corners=False
            )
            level_mse_loss = _safe_mse(rendered_imgs, resized_image)
            level_lpips_loss = torch.nan_to_num(
                self.lpips.forward(rendered_imgs, resized_image, normalize=True).mean(),
                nan=0.0,
                posinf=0.0,
                neginf=0.0,
            )

            # NOTE(review): WEIGHT_LEVEL_MAPPING (module level) is never used;
            # levels are weighted uniformly here. Confirm that is intended.
            loss = loss + level_mse_loss * self.cfg.mse_weight + level_lpips_loss * self.cfg.lpips_weight
        return loss / len(lod_rendering)
|
|