| from dataclasses import dataclass |
| from typing import Literal |
|
|
| from jaxtyping import Float |
| from torch import Tensor |
| import torch |
| import torch.nn.functional as F |
| from src.dataset.types import BatchedExample |
| from src.model.decoder.decoder import DecoderOutput |
| from src.model.types import Gaussians |
| from .loss import Loss |
|
|
|
|
@dataclass
class LossOpacityCfg:
    """Configuration for the opacity loss term."""

    # Scalar multiplier applied to the computed loss before it is returned.
    weight: float
    # Loss-formulation selector. NOTE(review): the visible `LossOpacity.forward`
    # never reads this field — confirm whether branching on it was intended.
    type: Literal["exp", "mean", "exp+mean"] = "exp+mean"
|
|
|
|
@dataclass
class LossOpacityCfgWrapper:
    """Wraps `LossOpacityCfg` under the `opacity` attribute (presumably so the
    loss registry can key configs by name — verify against the `Loss` base)."""

    # The opacity-loss configuration itself.
    opacity: LossOpacityCfg
|
|
|
|
class LossOpacity(Loss[LossOpacityCfg, LossOpacityCfgWrapper]):
    """MSE penalty pushing rendered alpha toward the context validity mask."""

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the weighted opacity loss.

        Args:
            prediction: Decoder output; only ``prediction.alpha`` is used.
            batch: Batched example; reads ``batch["context"]["valid_mask"]``.
                NOTE(review): assumes the mask broadcasts against
                ``prediction.alpha`` — confirm shapes at the call site.
            gaussians: Unused by this loss.
            depth_dict: Unused by this loss.
            global_step: Unused by this loss.

        Returns:
            Scalar tensor ``cfg.weight * MSE(alpha, valid_mask)``, with any
            NaN/inf replaced by 0 so a degenerate batch cannot poison training.
        """
        alpha = prediction.alpha
        valid_mask = batch['context']['valid_mask'].float()
        # reduction='mean' is equivalent to the previous
        # reduction='none' followed by .mean() — one fused call.
        opacity_loss = F.mse_loss(alpha, valid_mask, reduction='mean')
        # NOTE(review): cfg.type ("exp" / "mean" / "exp+mean") is never
        # consulted here — confirm whether a branch on it was intended.
        # Sanitize non-finite values so the optimizer never sees NaN/inf.
        return self.cfg.weight * torch.nan_to_num(
            opacity_loss, nan=0.0, posinf=0.0, neginf=0.0
        )
|
|