# Learn2Splat / optgs/model/decoder/diffgs_decoder_splatting_cuda.py
# Committed by SteEsp — 78d2329: "Add Docker-based Learn2Splat demo (viser GUI)"
from dataclasses import dataclass
from typing import Literal
import torch
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor
from ...dataset import DatasetCfg
from ..types import Gaussians
from .decoder import DepthRenderingMode
from .diffgs_cuda_splatting import render_cuda
from .decoder import Decoder, DecoderOutput
@dataclass
class DiffgsDecoderSplattingCUDACfg:
    """Configuration for the ``DiffgsDecoderSplattingCUDA`` decoder."""

    # Decoder identifier. It is typed as Literal["diffgs"]; the dataclass itself
    # does not enforce this at runtime.
    name: Literal["diffgs"]
    # Passed unchanged to render_cuda as its ``scale_invariant`` keyword argument.
    scale_invariant: bool
class DiffgsDecoderSplattingCUDA(Decoder[DiffgsDecoderSplattingCUDACfg]):
    """Decoder that rasterizes 3D Gaussians through the CUDA ``render_cuda`` call.

    Every (batch, view) pair is rendered as an independent camera. The camera
    tensors are flattened over the batch and view axes, and each scene's
    Gaussians are tiled once per view before the renderer is invoked.
    """

    # 3-element background color read from the dataset config. It is registered
    # as a non-persistent buffer, so it follows .to()/.cuda() but is not saved
    # in the state dict.
    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: DiffgsDecoderSplattingCUDACfg,
        dataset_cfg: DatasetCfg,
    ) -> None:
        super().__init__(cfg, dataset_cfg)
        background = torch.tensor(dataset_cfg.background_color, dtype=torch.float32)
        # persistent=False keeps the color out of saved checkpoints.
        self.register_buffer("background_color", background, persistent=False)

    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        """Render every view of every batch element.

        Args:
            gaussians: Per-scene Gaussians. ``means``, ``covariances``,
                ``harmonics`` and ``opacities`` are all indexed ``b g ...``.
            extrinsics: Per-view 4x4 camera matrices.
            intrinsics: Per-view 3x3 camera matrices.
            near: Per-view near distances, passed to the renderer.
            far: Per-view far distances, passed to the renderer.
            image_shape: Output image size, passed to ``render_cuda`` unchanged.
            depth_mode: Accepted for interface compatibility; unused here.

        Returns:
            ``DecoderOutput(color, depth)``. ``color`` is ``b v c h w``.
            ``depth`` is the reciprocal of the renderer's inverse depth,
            laid out as ``b (v c) h w`` (i.e. ``b v h w`` when c == 1).
        """
        batch, views, _, _ = extrinsics.shape

        # Fold the view axis into the batch axis: render_cuda sees batch*views cameras.
        flat_extrinsics = rearrange(extrinsics, "b v i j -> (b v) i j")
        flat_intrinsics = rearrange(intrinsics, "b v i j -> (b v) i j")
        flat_near = rearrange(near, "b v -> (b v)")
        flat_far = rearrange(far, "b v -> (b v)")
        backgrounds = repeat(self.background_color, "c -> (b v) c", b=batch, v=views)

        # All views of a scene share its Gaussians, so tile them once per view.
        tiled_means = repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=views)
        tiled_covariances = repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=views)
        tiled_harmonics = repeat(
            gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=views
        )
        tiled_opacities = repeat(gaussians.opacities, "b g -> (b v) g", v=views)

        rendered = render_cuda(
            flat_extrinsics,
            flat_intrinsics,
            flat_near,
            flat_far,
            image_shape,
            backgrounds,
            tiled_means,
            tiled_covariances,
            tiled_harmonics,
            tiled_opacities,
            scale_invariant=self.cfg.scale_invariant,
        )

        color = rearrange(
            rendered['image'], "(b v) c h w -> b v c h w", b=batch, v=views
        )

        # The renderer's depth output is inverse depth with a single channel
        # (c == 1), so merging (v c) gives one map per view. It is clamped
        # before inversion to avoid dividing by zero, which caps depth at 1e6.
        inverse_depth = rearrange(
            rendered['depth'], "(b v) c h w -> b (v c) h w", b=batch, v=views
        )
        depth = 1.0 / inverse_depth.clamp(min=1e-6)

        return DecoderOutput(color, depth)