Spaces:
Sleeping
Sleeping
from dataclasses import dataclass
from typing import Literal

import torch
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor

from ...dataset import DatasetCfg
from ..types import Gaussians
from .decoder import DepthRenderingMode
from .diffgs_cuda_splatting import render_cuda
from .decoder import Decoder, DecoderOutput
@dataclass
class DiffgsDecoderSplattingCUDACfg:
    """Configuration for the ``diffgs`` CUDA splatting decoder.

    Without ``@dataclass`` this class would only carry annotations: it could
    not be constructed from keyword arguments and would have no value-based
    ``__eq__``/``__repr__``, so the decorator is required.

    Attributes:
        name: Decoder identifier; the only accepted value is ``"diffgs"``.
        scale_invariant: Forwarded unchanged as the ``scale_invariant``
            keyword of ``render_cuda`` when the decoder renders.
    """

    name: Literal["diffgs"]
    scale_invariant: bool
class DiffgsDecoderSplattingCUDA(Decoder[DiffgsDecoderSplattingCUDACfg]):
    """Render Gaussians with the ``diffgs`` CUDA rasterizer (``render_cuda``).

    Every view of a batch sample is rendered from that sample's single set of
    Gaussians. The rasterizer's inverse-depth output is converted back to
    depth before being returned.
    """

    # Constant RGB background, registered as a non-persistent buffer in __init__.
    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: DiffgsDecoderSplattingCUDACfg,
        dataset_cfg: DatasetCfg,
    ) -> None:
        """Store the config and register the dataset's background color.

        Args:
            cfg: Decoder configuration (``scale_invariant`` is read in forward).
            dataset_cfg: Dataset configuration; only ``background_color`` is
                read here.
        """
        super().__init__(cfg, dataset_cfg)
        # persistent=False: the buffer moves with the module across devices
        # but is not written into the state_dict.
        background = torch.tensor(dataset_cfg.background_color, dtype=torch.float32)
        self.register_buffer("background_color", background, persistent=False)

    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        """Render all (batch, view) cameras and return color and depth.

        Args:
            gaussians: Per-sample Gaussians (means, covariances, harmonics,
                opacities), shared by every view of the same sample.
            extrinsics: Per-view 4x4 camera matrices.
            intrinsics: Per-view 3x3 camera matrices.
            near: Per-view near plane values.
            far: Per-view far plane values.
            image_shape: Output resolution passed straight to ``render_cuda``
                (presumably ``(height, width)`` — confirm against the kernel).
            depth_mode: Accepted for interface compatibility; not read here.

        Returns:
            ``DecoderOutput(color, depth)`` where color is reshaped to
            ``(batch, view, c, h, w)``. Depth is ``1 / inverse_depth`` (inverse
            depth clamped to at least 1e-6), reshaped to
            ``(batch, view * c, h, w)``.
        """
        batch, views, _, _ = extrinsics.shape

        # The rasterizer consumes a flat list of cameras, so fold the view
        # axis into the batch axis for every per-camera input.
        flat_extrinsics = rearrange(extrinsics, "b v i j -> (b v) i j")
        flat_intrinsics = rearrange(intrinsics, "b v i j -> (b v) i j")
        flat_near = rearrange(near, "b v -> (b v)")
        flat_far = rearrange(far, "b v -> (b v)")
        backgrounds = repeat(
            self.background_color, "c -> (b v) c", b=batch, v=views
        )

        # Each sample's Gaussians are duplicated once per view of that sample.
        means = repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=views)
        covariances = repeat(
            gaussians.covariances, "b g i j -> (b v) g i j", v=views
        )
        harmonics = repeat(
            gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=views
        )
        opacities = repeat(gaussians.opacities, "b g -> (b v) g", v=views)

        rendered = render_cuda(
            flat_extrinsics,
            flat_intrinsics,
            flat_near,
            flat_far,
            image_shape,
            backgrounds,
            means,
            covariances,
            harmonics,
            opacities,
            scale_invariant=self.cfg.scale_invariant,
        )

        color = rearrange(
            rendered['image'], "(b v) c h w -> b v c h w", b=batch, v=views
        )
        # The depth channel is inverse depth with a single channel (c == 1).
        # Clamping before inverting avoids division by zero.
        inverse_depth = rearrange(
            rendered['depth'], "(b v) c h w -> b (v c) h w", b=batch, v=views
        )
        depth = 1. / inverse_depth.clamp(min=1e-6)

        return DecoderOutput(color, depth)