# Source: depthsplat/src/model/decoder/decoder_splatting_cuda.py
# Uploaded by: Yeqing0814
# Commit message: "Upload folder using huggingface_hub"
# Commit: a6dd040 (verified)
from dataclasses import dataclass
from typing import Literal
import torch
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor
from ...dataset import DatasetCfg
from ..types import Gaussians
from .cuda_splatting import DepthRenderingMode, render_cuda, render_depth_cuda
from .decoder import Decoder, DecoderOutput
@dataclass
class DecoderSplattingCUDACfg:
    """Configuration record for the CUDA splatting decoder."""

    # Discriminator tag; the only accepted value is "splatting_cuda".
    name: Literal["splatting_cuda"]
class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]):
    """Decoder that renders Gaussians through the CUDA splatting helpers.

    Cameras arrive with separate ``batch`` and ``view`` axes. Both render
    paths merge those axes into one leading ``(batch * view)`` axis, tile the
    per-batch Gaussian tensors once per view so they line up with the
    flattened cameras, call the CUDA helper, and split the leading axis of
    the result back into ``batch`` and ``view``.
    """

    # Background color, registered as a module buffer in ``__init__``.
    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: DecoderSplattingCUDACfg,
        dataset_cfg: DatasetCfg,
    ) -> None:
        """Initialise the base decoder and register the background color.

        Args:
            cfg: Decoder configuration, forwarded to the base class.
            dataset_cfg: Dataset configuration; its ``background_color`` is
                converted to a float32 tensor and stored as a buffer.
        """
        super().__init__(cfg, dataset_cfg)
        background = torch.tensor(
            dataset_cfg.background_color,
            dtype=torch.float32,
        )
        # persistent=False keeps the buffer out of the module's state_dict.
        self.register_buffer("background_color", background, persistent=False)

    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        """Render color (and optionally depth) for every batch/view camera.

        Args:
            gaussians: Gaussian parameters; ``means``, ``covariances``,
                ``harmonics`` and ``opacities`` are read.
            extrinsics: Per-view 4x4 camera matrices.
            intrinsics: Per-view 3x3 camera matrices.
            near: Per-view near plane values.
            far: Per-view far plane values.
            image_shape: Passed through unchanged to the renderer.
            depth_mode: When not ``None``, depth is additionally rendered via
                ``render_depth`` using this mode.

        Returns:
            A ``DecoderOutput`` built from the color tensor (leading axes
            ``batch, view``) and either the depth tensor or ``None``.
        """
        num_batches, num_views, _, _ = extrinsics.shape

        # Merge (batch, view) into one leading axis,
        # e.g. [2, 8, 4, 4] -> [16, 4, 4].
        flat_extrinsics = rearrange(extrinsics, "b v i j -> (b v) i j")
        flat_intrinsics = rearrange(intrinsics, "b v i j -> (b v) i j")
        flat_near = rearrange(near, "b v -> (b v)")
        flat_far = rearrange(far, "b v -> (b v)")

        # One background color row per flattened camera.
        background = repeat(
            self.background_color,
            "c -> (b v) c",
            b=num_batches,
            v=num_views,
        )

        # Tile each batch's Gaussians once per view,
        # e.g. [2, 256000, 3] -> [16, 256000, 3].
        means = repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=num_views)
        covariances = repeat(
            gaussians.covariances, "b g i j -> (b v) g i j", v=num_views
        )
        harmonics = repeat(
            gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=num_views
        )
        opacities = repeat(gaussians.opacities, "b g -> (b v) g", v=num_views)

        rendered = render_cuda(
            flat_extrinsics,
            flat_intrinsics,
            flat_near,
            flat_far,
            image_shape,
            background,
            means,
            covariances,
            harmonics,
            opacities,
        )
        # Split the merged leading axis back into (batch, view).
        color = rearrange(
            rendered,
            "(b v) c h w -> b v c h w",
            b=num_batches,
            v=num_views,
        )

        if depth_mode is None:
            depth = None
        else:
            depth = self.render_depth(
                gaussians,
                extrinsics,
                intrinsics,
                near,
                far,
                image_shape,
                depth_mode,
            )
        return DecoderOutput(color, depth)

    def render_depth(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        mode: DepthRenderingMode = "depth",
    ) -> Float[Tensor, "batch view height width"]:
        """Render a depth map for every batch/view camera.

        Args:
            gaussians: Gaussian parameters; ``means``, ``covariances`` and
                ``opacities`` are read.
            extrinsics: Per-view 4x4 camera matrices.
            intrinsics: Per-view 3x3 camera matrices.
            near: Per-view near plane values.
            far: Per-view far plane values.
            image_shape: Passed through unchanged to the renderer.
            mode: Depth rendering mode forwarded to ``render_depth_cuda``.

        Returns:
            The renderer output with its leading axis split back into
            ``batch`` and ``view``.
        """
        num_batches, num_views, _, _ = extrinsics.shape

        # Merge (batch, view) into one leading axis for the CUDA helper.
        flat_extrinsics = rearrange(extrinsics, "b v i j -> (b v) i j")
        flat_intrinsics = rearrange(intrinsics, "b v i j -> (b v) i j")
        flat_near = rearrange(near, "b v -> (b v)")
        flat_far = rearrange(far, "b v -> (b v)")

        # Tile each batch's Gaussians once per view.
        means = repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=num_views)
        covariances = repeat(
            gaussians.covariances, "b g i j -> (b v) g i j", v=num_views
        )
        opacities = repeat(gaussians.opacities, "b g -> (b v) g", v=num_views)

        flat_depth = render_depth_cuda(
            flat_extrinsics,
            flat_intrinsics,
            flat_near,
            flat_far,
            image_shape,
            means,
            covariances,
            opacities,
            mode=mode,
        )
        return rearrange(
            flat_depth,
            "(b v) h w -> b v h w",
            b=num_batches,
            v=num_views,
        )