| from dataclasses import dataclass |
|
|
| import torch |
| from einops import reduce |
| from jaxtyping import Float |
| from torch import Tensor |
|
|
| from src.dataset.types import BatchedExample |
| from src.model.decoder.decoder import DecoderOutput |
| from src.model.types import Gaussians |
| from .loss import Loss |
| from typing import Generic, TypeVar |
| from dataclasses import fields |
| import torch.nn.functional as F |
| import sys |
| from pytorch3d.loss import chamfer_distance |
| import os |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| |
| from src.misc.utils import vis_depth_map |
|
|
# Generic type variables; T_wrapper annotates the wrapper config accepted by
# LossChamferDistance.__init__.
T_cfg, T_wrapper = TypeVar("T_cfg"), TypeVar("T_wrapper")
|
|
|
|
| @dataclass |
| class LossChamferDistanceCfg: |
| weight: float |
| down_sample_ratio: float |
| sigma_image: float | None |
|
|
|
|
@dataclass()
class LossChamferDistanceCfgWrapper:
    """Single-field wrapper around the Chamfer-distance config.

    LossChamferDistance.__init__ unpacks the one field of its wrapper config and
    uses that field's name ("chamfer_distance") as the loss name.
    """

    chamfer_distance: LossChamferDistanceCfg
|
|
class LossChamferDistance(Loss[LossChamferDistanceCfg, LossChamferDistanceCfgWrapper]):
    """Chamfer-distance loss between confident predicted points and Gaussian means.

    For each batch element, the predicted points (``pts_all`` restricted to
    ``conf_mask``) are compared with the Gaussian means whose last coordinate
    has absolute value below 1e2. Both clouds are randomly subsampled by
    ``cfg.down_sample_ratio`` before ``pytorch3d.loss.chamfer_distance`` is
    applied. The per-element distances are averaged over the batch and scaled
    by ``cfg.weight``.
    """

    def __init__(self, cfg: T_wrapper) -> None:
        super().__init__(cfg)

        # The wrapper dataclass must have exactly one field: unpack it to get the
        # inner config, and reuse the field's name as the loss name.
        (field,) = fields(type(cfg))
        self.cfg = getattr(cfg, field.name)
        self.name = field.name

    @staticmethod
    def _random_subsample(points: Tensor, ratio: float) -> Tensor:
        """Return ``int(len(points) * ratio)`` randomly chosen rows of ``points``."""
        count = points.shape[0]
        keep = int(count * ratio)
        # Build the permutation on the points' device so the index tensor does
        # not need a host->device copy.
        order = torch.randperm(count, device=points.device)
        return points[order[:keep]]

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the weighted, batch-averaged Chamfer distance.

        Args:
            prediction: Unused here; kept for the shared loss interface.
            batch: Unused here; kept for the shared loss interface.
            gaussians: Its ``means`` are indexed per batch element and filtered
                on their last coordinate. Presumably shaped (b, N, 3);
                TODO confirm.
            depth_dict: Must contain ``["distill_infos"]["pts_all"]`` with five
                dims (b, v, h, w, C), and ``["distill_infos"]["conf_mask"]``,
                which is reshaped to (b, v, h, w) and used as a boolean mask.
            global_step: Unused here.

        Returns:
            Scalar tensor ``cfg.weight * sum_b(chamfer_b) / b``. A batch element
            contributes 0 if its subsampled clouds are empty or its distance is
            NaN.
        """
        distill_infos = depth_dict["distill_infos"]
        pts_all = distill_infos["pts_all"]
        b, v, h, w, _ = pts_all.shape

        # reshape (rather than view) also accepts non-contiguous inputs.
        pred_pts = pts_all.reshape(b, v, h, w, -1)
        conf_mask = distill_infos["conf_mask"].reshape(b, v, h, w)

        gaussian_means = gaussians.means
        # Keep only Gaussians whose last coordinate has magnitude below 1e2.
        pts_mask = gaussian_means[..., -1].abs() < 1e2

        ratio = self.cfg.down_sample_ratio
        # Tensor accumulator: the return value is always a tensor, even if every
        # batch element is skipped below.
        total = pred_pts.new_zeros(())
        for b_idx in range(b):
            batch_pts = self._random_subsample(pred_pts[b_idx][conf_mask[b_idx]], ratio)
            batch_gaussian = self._random_subsample(
                gaussian_means[b_idx][pts_mask[b_idx]], ratio
            )
            # Subsampling (or masking) can leave a cloud empty; skip the element
            # rather than calling chamfer_distance on an empty cloud.
            if batch_pts.shape[0] == 0 or batch_gaussian.shape[0] == 0:
                continue
            cd_loss = chamfer_distance(batch_pts.unsqueeze(0), batch_gaussian.unsqueeze(0))[0]
            # Sanitize per element, so one NaN does not wipe out the whole batch.
            total = total + torch.nan_to_num(cd_loss, nan=0.0)

        # max(b, 1) guards the degenerate empty-batch case against division by zero.
        return self.cfg.weight * total / max(b, 1)