Spaces:
Sleeping
Sleeping
File size: 2,316 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | from dataclasses import dataclass
from typing import Literal
import torch
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor
from ...dataset import DatasetCfg
from ..types import Gaussians
from .decoder import DepthRenderingMode
from .diffgs_cuda_splatting import render_cuda
from .decoder import Decoder, DecoderOutput
@dataclass
class DiffgsDecoderSplattingCUDACfg:
    """Configuration for the ``diffgs`` CUDA splatting decoder."""

    # Discriminator tag; only the literal "diffgs" is accepted by the type.
    name: Literal["diffgs"]
    # Forwarded verbatim to ``render_cuda(..., scale_invariant=...)``.
    scale_invariant: bool
class DiffgsDecoderSplattingCUDA(Decoder[DiffgsDecoderSplattingCUDACfg]):
    """Decoder that rasterizes Gaussians through ``render_cuda``.

    Every camera view of a batch element is rendered against that batch
    element's full set of Gaussians; the actual rasterization is delegated to
    ``render_cuda`` from ``.diffgs_cuda_splatting``.
    """

    # Background color built from ``dataset_cfg.background_color`` (3 values).
    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: DiffgsDecoderSplattingCUDACfg,
        dataset_cfg: DatasetCfg,
    ) -> None:
        """Store the configs and register the background color buffer.

        Args:
            cfg: Decoder configuration; ``cfg.scale_invariant`` is passed to
                ``render_cuda`` on every forward call.
            dataset_cfg: Dataset configuration supplying ``background_color``.
        """
        super().__init__(cfg, dataset_cfg)
        background = torch.tensor(dataset_cfg.background_color, dtype=torch.float32)
        # Registered as a buffer so it moves with the module across devices, but
        # non-persistent so it is not saved in the state dict (it is rebuilt
        # from the dataset config here anyway).
        self.register_buffer("background_color", background, persistent=False)

    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        """Render color and depth for every (batch, view) camera.

        Args:
            gaussians: Per-batch Gaussians. The einops patterns below imply
                ``means`` is (b, g, xyz), ``covariances`` (b, g, i, j),
                ``harmonics`` (b, g, c, d_sh) and ``opacities`` (b, g).
            extrinsics: Camera extrinsics, (batch, view, 4, 4).
            intrinsics: Camera intrinsics, (batch, view, 3, 3).
            near: Near plane per camera, (batch, view).
            far: Far plane per camera, (batch, view).
            image_shape: Output image size, forwarded unchanged to
                ``render_cuda``.
            depth_mode: Accepted for interface compatibility but not read;
                depth is always produced.

        Returns:
            ``DecoderOutput(color, depth)``. ``color`` is (batch, view, c, h, w).
            ``depth`` is the reciprocal of the renderer's depth output after
            clamping it to at least 1e-6, laid out as (batch, view * c, h, w).
        """
        batch, views, _, _ = extrinsics.shape

        # render_cuda consumes a flat list of cameras: fold batch and view
        # into a single leading axis.
        flat_extrinsics = rearrange(extrinsics, "b v i j -> (b v) i j")
        flat_intrinsics = rearrange(intrinsics, "b v i j -> (b v) i j")
        flat_near = rearrange(near, "b v -> (b v)")
        flat_far = rearrange(far, "b v -> (b v)")
        backgrounds = repeat(self.background_color, "c -> (b v) c", b=batch, v=views)

        # Duplicate each batch element's Gaussians once per view so every
        # flattened camera has its own copy to render.
        means = repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=views)
        covariances = repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=views)
        harmonics = repeat(gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=views)
        opacities = repeat(gaussians.opacities, "b g -> (b v) g", v=views)

        rendered = render_cuda(
            flat_extrinsics,
            flat_intrinsics,
            flat_near,
            flat_far,
            image_shape,
            backgrounds,
            means,
            covariances,
            harmonics,
            opacities,
            scale_invariant=self.cfg.scale_invariant,
        )

        color = rearrange(rendered['image'], "(b v) c h w -> b v c h w", b=batch, v=views)

        # The renderer's depth channel is inverse depth (c == 1). Clamping
        # before inverting turns zero/negative values into 1e6 instead of inf.
        inverse_depth = rearrange(
            rendered['depth'], "(b v) c h w -> b (v c) h w", b=batch, v=views
        )
        depth = 1.0 / inverse_depth.clamp(min=1e-6)

        return DecoderOutput(color, depth)
|