import os
import sys

# Make the repository's packages importable when this file is loaded directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from dataclasses import dataclass, fields

import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from src.dataset.types import BatchedExample
from src.loss.depth_anything.dpt import DepthAnything
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians

from .loss import Loss
@dataclass
class LossDepthCfg:
    weight: float
    sigma_image: float | None
    use_second_derivative: bool


@dataclass
class LossDepthCfgWrapper:
    depth: LossDepthCfg
class LossDepth(Loss[LossDepthCfg, LossDepthCfgWrapper]):
    def __init__(self, cfg: LossDepthCfgWrapper) -> None:
        super().__init__(cfg)

        # Extract the configuration from the wrapper.
        (field,) = fields(type(cfg))
        self.cfg = getattr(cfg, field.name)
        self.name = field.name

        model_configs = {
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        }
        encoder = 'vits'  # or 'vitb', 'vitl'

        # Depth Anything provides monocular disparity pseudo-ground-truth.
        # It is frozen: used only as a supervision signal, never trained.
        depth_anything = DepthAnything(model_configs[encoder])
        depth_anything.load_state_dict(
            torch.load(f'src/loss/depth_anything/depth_anything_{encoder}14.pth')
        )
        self.depth_anything = depth_anything
        for param in self.depth_anything.parameters():
            param.requires_grad = False
    def disp_rescale(self, disp: Float[Tensor, "B H W"]) -> Float[Tensor, "B HW"]:
        # Normalize disparity per sample: subtract the median and divide by the
        # mean absolute deviation, so prediction and pseudo-ground-truth are
        # compared in a scale- and shift-invariant way.
        disp = disp.flatten(1, 2)  # (B, H*W)
        disp_median = torch.median(disp, dim=-1, keepdim=True)[0]  # (B, 1)
        disp_var = (disp - disp_median).abs().mean(dim=-1, keepdim=True)  # (B, 1)
        disp = (disp - disp_median) / (disp_var + 1e-6)
        return disp
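    # A quick sanity check of disp_rescale (hypothetical values, not from the
    # original code): for disp = [[1., 2., 6.]], the median is 2 and the mean
    # absolute deviation is (1 + 0 + 4) / 3 = 5/3, so the output is roughly
    # [-0.6, 0.0, 2.4] regardless of the disparity's global scale or shift.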
    def smooth_l1_loss(self, pred, target, beta=1.0, reduction='none'):
        # Standard smooth-L1 (Huber) loss: quadratic for |diff| < beta, linear
        # beyond, so outliers in the pseudo-depth do not dominate the gradient.
        diff = pred - target
        abs_diff = torch.abs(diff)
        loss = torch.where(abs_diff < beta, 0.5 * diff ** 2 / beta, abs_diff - 0.5 * beta)
        if reduction == 'mean':
            return loss.mean()
        elif reduction == 'sum':
            return loss.sum()
        elif reduction == 'none':
            return loss
        else:
            raise ValueError("Invalid reduction type. Choose from 'mean', 'sum', or 'none'.")
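    # Note: this matches torch.nn.functional.smooth_l1_loss with the same beta;
    # it is presumably reimplemented here so per-element losses can be combined
    # with confidence weights before reduction (see ctx_depth_loss below).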
    def ctx_depth_loss(self,
                       depth_map: Float[Tensor, "B V H W C"],
                       depth_conf: Float[Tensor, "B V H W"],
                       batch: BatchedExample,
                       cxt_depth_weight: float = 0.01,
                       alpha: float = 0.2):
        B, V, _, H, W = batch["context"]["image"].shape
        ctx_imgs = batch["context"]["image"].view(B * V, 3, H, W).float()

        # Pseudo-ground-truth disparity from the frozen Depth Anything model.
        da_output = self.depth_anything(ctx_imgs)
        da_output = self.disp_rescale(da_output)

        # Predicted disparity from the context-view depth maps.
        disp_context = 1.0 / depth_map.flatten(0, 1).squeeze(-1).clamp(min=1e-3)  # (B*V, H, W)
        context_output = self.disp_rescale(disp_context)

        depth_conf = depth_conf.flatten(0, 1).flatten(1, 2)  # (B*V, H*W)

        # Confidence-weighted smooth-L1 with a log-confidence regularizer that
        # prevents the model from driving all confidences toward zero.
        return cxt_depth_weight * (
            self.smooth_l1_loss(context_output * depth_conf, da_output * depth_conf, reduction='none')
            - alpha * torch.log(depth_conf)
        ).mean()
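    # This term resembles the confidence-aware regression loss popularized by
    # DUSt3R (conf-weighted error minus alpha * log(conf)), though here the
    # confidence scales both operands of the smooth-L1 rather than the loss.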
    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        # Supervise rendered target-view depth against Depth Anything's
        # monocular disparity, after scale/shift normalization of both.
        target_imgs = batch["target"]["image"]
        B, V, _, H, W = target_imgs.shape
        target_imgs = target_imgs.view(B * V, 3, H, W)

        da_output = self.depth_anything(target_imgs.float())
        da_output = self.disp_rescale(da_output)

        disp_gs = 1.0 / prediction.depth.flatten(0, 1).clamp(min=1e-3).float()
        gs_output = self.disp_rescale(disp_gs)

        return self.cfg.weight * torch.nan_to_num(F.smooth_l1_loss(da_output, gs_output), nan=0.0)
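# A minimal usage sketch (hypothetical config values; assumes DecoderOutput
# carries `depth` of shape (B, V, H, W) and the batch carries context/target
# images of shape (B, V, 3, H, W), per the annotations above):
#
#     cfg = LossDepthCfgWrapper(
#         depth=LossDepthCfg(weight=0.05, sigma_image=None, use_second_derivative=False)
#     )
#     loss_fn = LossDepth(cfg)
#     depth_loss = loss_fn(prediction, batch, gaussians, global_step=0)
#
# The returned scalar is already scaled by cfg.depth.weight, so it can be
# added directly to the other loss terms.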