from dataclasses import dataclass
from typing import Literal

import torch
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor

from ...dataset import DatasetCfg
from ..types import Gaussians
from .decoder import Decoder, DecoderOutput, DepthRenderingMode
from .diffgs_cuda_splatting import render_cuda


@dataclass
class DiffgsDecoderSplattingCUDACfg:
    """Configuration for :class:`DiffgsDecoderSplattingCUDA`."""

    # Fixed tag "diffgs"; presumably used to pick this decoder from a config
    # union — confirm against the config loader.
    name: Literal["diffgs"]
    # Passed through unchanged to ``render_cuda(..., scale_invariant=...)``.
    scale_invariant: bool


class DiffgsDecoderSplattingCUDA(Decoder[DiffgsDecoderSplattingCUDACfg]):
    """Decoder that renders Gaussians with the ``diffgs`` CUDA rasterizer.

    Each (batch, view) pair is rendered independently: camera tensors are
    flattened over ``(batch, view)`` and each batch element's Gaussians are
    repeated once per view before calling ``render_cuda``.
    """

    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: DiffgsDecoderSplattingCUDACfg,
        dataset_cfg: DatasetCfg,
    ) -> None:
        super().__init__(cfg, dataset_cfg)
        # Registered as a non-persistent buffer so it moves with the module
        # on .to()/.cuda() but is not stored in the state dict.
        self.register_buffer(
            "background_color",
            torch.tensor(dataset_cfg.background_color, dtype=torch.float32),
            persistent=False,
        )

    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        """Render every view of every batch element and return color and depth.

        Args:
            gaussians: Per-batch-element Gaussians with ``means`` of shape
                ``(b, g, xyz)``, ``covariances`` ``(b, g, i, j)``,
                ``harmonics`` ``(b, g, c, d_sh)`` and ``opacities`` ``(b, g)``.
            extrinsics: Camera extrinsics, ``(b, v, 4, 4)``.
            intrinsics: Camera intrinsics, ``(b, v, 3, 3)``.
            near: Near plane per view, ``(b, v)``.
            far: Far plane per view, ``(b, v)``.
            image_shape: Output image size, forwarded to ``render_cuda``.
            depth_mode: Accepted for interface compatibility with ``Decoder``
                but ignored: depth is always derived from the rasterizer's
                inverse-depth output.

        Returns:
            ``DecoderOutput(color, depth)`` with color ``(b, v, c, h, w)`` and
            depth ``(b, v, h, w)``.

        Raises:
            einops.EinopsError: if the rasterizer's depth output does not
                have exactly one channel.
        """
        b, v, _, _ = extrinsics.shape
        out = render_cuda(
            rearrange(extrinsics, "b v i j -> (b v) i j"),
            rearrange(intrinsics, "b v i j -> (b v) i j"),
            rearrange(near, "b v -> (b v)"),
            rearrange(far, "b v -> (b v)"),
            image_shape,
            repeat(self.background_color, "c -> (b v) c", b=b, v=v),
            repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v),
            repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v),
            repeat(gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=v),
            repeat(gaussians.opacities, "b g -> (b v) g", v=v),
            scale_invariant=self.cfg.scale_invariant,
        )

        color = rearrange(out["image"], "(b v) c h w -> b v c h w", b=b, v=v)

        # The rasterizer's depth output is inverse depth with a single
        # channel. The explicit unit axis makes einops reject any other
        # channel count instead of silently folding it into the view axis.
        inverse_depth = rearrange(out["depth"], "(b v) 1 h w -> b v h w", b=b, v=v)
        # Clamp before inverting so zero inverse depth yields a large finite
        # depth (1 / 1e-6) rather than inf.
        depth = 1.0 / inverse_depth.clamp(min=1e-6)

        return DecoderOutput(color, depth)