Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| import torch | |
| from src.dataset.types import BatchedExample | |
| from src.model.decoder.decoder import DecoderOutput | |
| from src.model.types import Gaussians | |
| from .loss import Loss | |
@dataclass
class LossMseCfg:
    """Configuration for the MSE image loss.

    The three boolean flags choose which pixel mask the loss applies. They are
    checked in the order ``mask``, ``alpha``, ``conf``; the first one that is
    set wins, and when none is set every pixel contributes.
    """

    # Scalar multiplier applied to the final mean-squared error.
    weight: float
    # Mask pixels with ``depth_dict['conf_valid_mask']``.
    conf: bool = False
    # Mask pixels with ``batch['context']['valid_mask']``.
    mask: bool = False
    # Mask pixels with ``prediction.alpha``.
    alpha: bool = False
@dataclass
class LossMseCfgWrapper:
    """Wrapper that stores the MSE loss configuration under the ``mse`` field."""

    # NOTE(review): presumably the field name doubles as the loss's key in the
    # experiment config -- confirm against the config loader.
    mse: LossMseCfg
class LossMse(Loss[LossMseCfg, LossMseCfgWrapper]):
    """Weighted mean-squared error between rendered colors and context images."""

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the masked MSE loss scaled by ``cfg.weight``.

        Args:
            prediction: Decoder output; ``color`` and ``alpha`` are read.
            batch: Batch whose ``context`` images, indexed by
                ``batch["using_index"]``, serve as ground truth.
            gaussians: Unused by this loss.
            depth_dict: Must contain ``conf_valid_mask`` when ``cfg.conf``
                selects the confidence mask; otherwise it is not read.
            global_step: Unused by this loss.

        Returns:
            ``cfg.weight`` times the mean squared error over the selected
            pixels, with a NaN or infinite mean replaced by 0.

        Raises:
            TypeError: If the confidence mask is selected but ``depth_dict``
                is ``None``.
        """
        alpha = prediction.alpha

        # Pick the pixel mask. Precedence: mask > alpha > conf > all pixels.
        # Each source is looked up only when its flag is active, so a batch
        # without ``valid_mask`` still works when that mask is not requested.
        if self.cfg.mask:
            mask = batch["context"]["valid_mask"]
        elif self.cfg.alpha:
            # NOTE(review): alpha is used directly as an index below, so it
            # presumably has to be a boolean tensor -- confirm with the decoder.
            mask = alpha
        elif self.cfg.conf:
            if depth_dict is None:
                raise TypeError(
                    "depth_dict must not be None when cfg.conf is enabled; "
                    "it has to provide 'conf_valid_mask'"
                )
            mask = depth_dict["conf_valid_mask"]
        else:
            mask = torch.ones_like(alpha, dtype=torch.bool)

        # Move axis 2 to the end before masking (presumably channels-last, so
        # the mask indexes the leading axes -- confirm the tensor layout).
        pred_img = prediction.color.permute(0, 1, 3, 4, 2)[mask]

        # (x + 1) / 2 rescales the ground truth; presumably the stored images
        # are in [-1, 1] and the rendered colors in [0, 1] -- confirm.
        gt_img = batch["context"]["image"][:, batch["using_index"]]
        gt_img = ((gt_img + 1) / 2).permute(0, 1, 3, 4, 2)[mask]

        delta = pred_img - gt_img

        # A non-finite mean (e.g. NaN, presumably when the mask selects no
        # pixels) is mapped to 0 so it does not propagate into the total loss.
        mse = torch.nan_to_num((delta**2).mean(), nan=0.0, posinf=0.0, neginf=0.0)
        return self.cfg.weight * mse