from dataclasses import dataclass
from math import isqrt, sqrt
from typing import Literal

import torch
import torchvision
from einops import rearrange, repeat
from gsplat import rasterization
from jaxtyping import Float
from torch import Tensor

from ...misc.utils import vis_depth_map
from ..types import Gaussians
from .decoder import Decoder, DecoderOutput
|
|
# Allowed values for the ``depth_mode`` argument accepted by the decoder's
# ``rendering_fn`` / ``forward`` methods.
DepthRenderingMode = Literal["depth", "disparity", "relative_disparity", "log"]
|
|
@dataclass
class DecoderSplattingCUDACfg:
    """Configuration for the ``splatting_cuda`` decoder."""

    # Literal tag; the only accepted value is "splatting_cuda".
    name: Literal["splatting_cuda"]
    # Background colour components; the decoder converts this list to a
    # float32 tensor and annotates the result as having 3 entries.
    background_color: list[float]
    # Copied onto the decoder instance as ``make_scale_invariant``.
    make_scale_invariant: bool
|
|
|
|
class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]):
    """Decoder that renders Gaussians to images through gsplat's ``rasterization``."""

    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: DecoderSplattingCUDACfg,
    ) -> None:
        """Store the config flags and register the background colour buffer.

        Args:
            cfg: Decoder configuration; ``background_color`` is converted to a
                float32 tensor and registered as a non-persistent buffer.
        """
        super().__init__(cfg)
        self.make_scale_invariant = cfg.make_scale_invariant
        self.register_buffer(
            "background_color",
            torch.tensor(cfg.background_color, dtype=torch.float32),
            persistent=False,
        )

    def rendering_fn(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
        cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
    ) -> DecoderOutput:
        """Rasterize every (batch, view) pair and collect colour, depth and alpha.

        Args:
            gaussians: Scene Gaussians; ``means``, ``opacities``, ``rotations``,
                ``scales``, ``harmonics`` and ``covariances`` are read.
            extrinsics: Per-view 4x4 matrices. Each is inverted before being
                passed to gsplat as the view matrix (so presumably
                camera-to-world -- confirm against the caller).
            intrinsics: Per-view 3x3 matrices, treated as normalized: row 0 is
                multiplied by the image width and row 1 by the image height.
            near: Unused by this implementation; kept for interface parity.
            far: Unused by this implementation; kept for interface parity.
            image_shape: ``(height, width)`` of the rendered images.
            depth_mode: Unused by this implementation.
            cam_rot_delta: Unused by this implementation.
            cam_trans_delta: Unused by this implementation.

        Returns:
            A ``DecoderOutput`` built from the stacked colour, depth and alpha
            renders, with ``lod_rendering=None``. Colour is clamped to
            ``[0, 1]`` and permuted to channels-first.
        """
        batch_size, num_views, _, _ = intrinsics.shape
        height, width = image_shape

        # Swap the last two harmonics axes before handing them to gsplat.
        # NOTE(review): assumes harmonics is (batch, gaussian, 3, d_sh) so the
        # result is (batch, gaussian, d_sh, 3) -- confirm against Gaussians.
        sh_coeffs = gaussians.harmonics.permute(0, 1, 3, 2).contiguous()
        # The SH degree follows from the coefficient count and is identical for
        # every batch element; isqrt keeps the computation in exact integers.
        sh_degree = isqrt(sh_coeffs.shape[-2]) - 1
        # One background row per rasterization call (a single camera per call).
        background = self.background_color.unsqueeze(0)

        rendered_imgs, rendered_depths, rendered_alphas = [], [], []
        for b in range(batch_size):
            means_b = gaussians.means[b].float()
            sh_b = sh_coeffs[b].float()
            covars_b = gaussians.covariances[b].float()
            scales_b = gaussians.scales[b].float()
            quats_b = gaussians.rotations[b].float()
            # Flatten to 1-D; a bare squeeze() would collapse a single-Gaussian
            # scene to a 0-D tensor.
            opacities_b = gaussians.opacities[b].reshape(-1).float()
            world_to_cam_b = extrinsics[b].float().inverse()

            # Scale the normalized intrinsics to pixel units on a copy so the
            # caller's tensor is never modified.
            normalized_b = intrinsics[b].float()
            pixel_intr_b = normalized_b.clone()
            pixel_intr_b[:, 0] = normalized_b[:, 0] * width
            pixel_intr_b[:, 1] = normalized_b[:, 1] * height

            view_imgs, view_depths, view_alphas = [], [], []
            for v in range(num_views):
                # NOTE(review): both scales/quats and covars are supplied; per
                # the gsplat docs covars should take precedence -- confirm.
                render, alpha, _ = rasterization(
                    means_b,
                    quats_b,
                    scales_b,
                    opacities_b,
                    sh_b,
                    world_to_cam_b[v : v + 1],
                    pixel_intr_b[v : v + 1],
                    width,
                    height,
                    sh_degree=sh_degree,
                    render_mode="RGB+D",
                    packed=False,
                    near_plane=1e-10,
                    backgrounds=background,
                    radius_clip=0.1,
                    covars=covars_b,
                    rasterize_mode="classic",
                )
                # "RGB+D" packs colour and depth along the last axis.
                colour, depth = torch.split(render, [3, 1], dim=-1)
                view_imgs.append(colour.clamp(0.0, 1.0).permute(0, 3, 1, 2))
                view_depths.append(depth)
                view_alphas.append(alpha)

            rendered_imgs.append(torch.cat(view_imgs, dim=0))
            # Drop only the trailing channel axis. A bare squeeze() would also
            # remove the view axis when there is a single view (or a spatial
            # axis of size 1), changing the output rank.
            rendered_depths.append(torch.cat(view_depths, dim=0).squeeze(-1))
            rendered_alphas.append(torch.cat(view_alphas, dim=0).squeeze(-1))

        return DecoderOutput(
            torch.stack(rendered_imgs),
            torch.stack(rendered_depths),
            torch.stack(rendered_alphas),
            lod_rendering=None,
        )

    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
        cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
    ) -> DecoderOutput:
        """Render the Gaussians; delegates to ``rendering_fn`` with all arguments."""
        return self.rendering_fn(
            gaussians,
            extrinsics,
            intrinsics,
            near,
            far,
            image_shape,
            depth_mode,
            cam_rot_delta,
            cam_trans_delta,
        )
|
|
|
|