Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from typing import Literal | |
| import math | |
| import torch | |
| # import torch.nn.functional as F | |
| from einops import repeat | |
| from gsplat.rendering import rasterization | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from tqdm import tqdm | |
| from optgs.scene_trainer.gaussian_module import GaussiansModule | |
| from .decoder import Decoder, DecoderOutput | |
| from .decoder import DepthRenderingMode | |
| from ..types import Gaussians | |
| from ...dataset import DatasetCfg | |
@dataclass
class GSplatDecoderSplattingCUDACfg:
    """Configuration for the gsplat-based splatting decoder.

    Declared as a dataclass so the annotated fields become real constructor
    arguments and instance attributes.
    """

    # Discriminator that selects this decoder; the only accepted value is "gsplat".
    name: Literal["gsplat"]
    # Config-level choice of whether covariances are requested from the
    # gaussians; the decoder reads it only when its per-call override is None.
    use_covariances: bool
    # Forwarded unchanged to the rasterizer's `rasterize_mode` argument.
    rasterize_mode: Literal["antialiased", "classic"]
    # Forwarded unchanged to the rasterizer's `eps2d` argument.
    eps2d: float
class GSplatDecoderSplattingCUDA(Decoder[GSplatDecoderSplattingCUDACfg]):
    """Decoder that renders 3D Gaussians to images with gsplat's ``rasterization``.

    Accepts either a batched ``Gaussians`` container or an unbatched
    ``GaussiansModule`` and returns color, depth, accumulated alpha, projected
    2D means, a per-view visibility mask and the projected radii.
    """

    # Registered as a non-persistent buffer in __init__.
    # NOTE(review): the buffer is not passed to the rasterizer in this class.
    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: GSplatDecoderSplattingCUDACfg,
        dataset_cfg: DatasetCfg,
    ) -> None:
        """Store the config and register the dataset background color as a buffer."""
        super().__init__(cfg, dataset_cfg)
        self.register_buffer(
            "background_color",
            torch.tensor(dataset_cfg.background_color, dtype=torch.float32),
            persistent=False,
        )

    @staticmethod
    def _detach_to_cpu(tensors: tuple[Tensor, ...]) -> tuple[Tensor, ...]:
        """Return detached CPU copies of every tensor in ``tensors``."""
        return tuple(t.detach().cpu() for t in tensors)

    def forward(
        self,
        gaussians: Gaussians | GaussiansModule,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,  # unused: depth is always rendered
        return_radii: bool = False,  # unused: radii are always returned
        iter_batch_size: int = -1,  # <= 0 renders all views at once
        use_covariances: bool | None = False,  # None falls back to cfg.use_covariances
        to_cpu: bool = False,  # detach outputs and move them to cpu
    ) -> DecoderOutput:
        """Rasterize ``gaussians`` into every requested view.

        Args:
            gaussians: Either a batched ``Gaussians`` or an unbatched
                ``GaussiansModule`` (a batch dimension of 1 is added for it).
            extrinsics: Camera matrices; they are inverted to obtain the
                rasterizer's ``viewmats`` (presumably camera-to-world — confirm).
            intrinsics: Camera intrinsics; rows 0 and 1 are multiplied by
                ``width`` and ``height`` before rasterization (presumably they
                are normalized by image size — confirm).
            near: Not used; the rasterizer's default near plane applies.
            far: Not used; the rasterizer's default far plane applies.
            image_shape: ``(height, width)`` of the rendered images.
            depth_mode: Not used; rendering always uses ``render_mode="RGB+ED"``.
            return_radii: Not used; radii are always part of the output.
            iter_batch_size: Number of views rasterized per call. Any
                non-positive value, or a value covering all views, renders
                everything in a single call.
            use_covariances: Per-call override. ``None`` falls back to
                ``cfg.use_covariances``; because the default is ``False`` the
                config value is consulted only when ``None`` is passed.
                NOTE(review): the covariances are fetched and validated but are
                not forwarded to the rasterizer in this method.
            to_cpu: If true, every output tensor is detached and copied to the
                CPU (per chunk when rendering in chunks).

        Returns:
            A ``DecoderOutput`` holding color ``[B, V, 3, H, W]``, depth
            ``[B, V, H, W]``, accumulated alpha ``[B, V, H, W]``, ``means2d``,
            ``visibility_filter`` and ``radii``.

        Raises:
            ValueError: If ``gaussians`` is neither supported type, or if
                covariances are requested but ``Gaussians.covariances`` is None.
        """
        _use_covariances = self.cfg.use_covariances if use_covariances is None else use_covariances
        height, width = image_shape

        if isinstance(gaussians, GaussiansModule):
            # GaussiansModule carries no batch dimension; one is added below.
            means = gaussians.means
            # Stored xyzw (scalar last); the rasterizer wants wxyz (scalar first).
            quats = gaussians.rotations[:, [3, 0, 1, 2]]  # [N, 4]
            scales = gaussians.scales  # post-activation
            opacities = gaussians.opacities  # post-activation
            # Swap the last two axes so the SH coefficients come before rgb.
            colors = gaussians.harmonics.permute(0, 2, 1)  # [N, d_sh, 3]
            covars = gaussians.covariances if _use_covariances else None

            # Add the leading batch dimension expected downstream.
            means = means.unsqueeze(0)  # [1, N, 3]
            quats = quats.unsqueeze(0)  # [1, N, 4]
            scales = scales.unsqueeze(0)  # [1, N, 3]
            opacities = opacities.unsqueeze(0)
            colors = colors.unsqueeze(0)  # [1, N, d_sh, 3]
            if covars is not None:
                covars = covars.unsqueeze(0)  # [1, N, 3, 3]
        elif isinstance(gaussians, Gaussians):
            means = gaussians.means  # [B, N, 3]
            # Reordered to wxyz (scalar first). The quaternions are left
            # unnormalized; gsplat documents that normalization is not required.
            # NOTE(review): assumes rotations_unnorm is stored xyzw, like
            # GaussiansModule.rotations — confirm.
            quats = gaussians.rotations_unnorm[:, :, [3, 0, 1, 2]]  # [B, N, 4]
            scales = gaussians.scales  # [B, N, 3]
            opacities = gaussians.opacities  # [B, N]
            colors = gaussians.harmonics.permute(0, 1, 3, 2)  # [B, N, d_sh, 3]
            if _use_covariances:
                covars = gaussians.covariances  # [B, N, 3, 3]
                if covars is None:
                    raise ValueError("Covariances are set to be used, but gaussians.covariances is None.")
            else:
                covars = None
            if not gaussians.stores_activated:
                # Raw parameters: apply the activations here.
                scales = torch.exp(scales)  # [B, N, 3]
                opacities = torch.sigmoid(opacities)  # [B, N]
        else:
            raise ValueError(f"Unknown type of gaussians: {type(gaussians)}")

        # d_sh = (degree + 1) ** 2, so the degree is the integer square root minus one.
        sh_degree = math.isqrt(colors.shape[-2]) - 1
        viewmats = extrinsics.inverse()  # [B, V, 4, 4]
        # Scale the first two intrinsics rows by width/height via broadcasting,
        # which avoids cloning and editing in place.
        intrinsics_scaled = intrinsics * intrinsics.new_tensor([[[width], [height], [1]]])  # [B, V, 3, 3]

        def _render(viewmats, Ks):
            """Rasterize the prepared gaussians for one set of views."""
            render_colors, render_alphas, meta = rasterization(
                means=means,
                quats=quats,
                scales=scales,
                opacities=opacities,
                colors=colors,
                sh_degree=sh_degree,
                viewmats=viewmats,
                Ks=Ks,
                width=width,
                height=height,
                eps2d=self.cfg.eps2d,
                rasterize_mode=self.cfg.rasterize_mode,
                packed=False,
                render_mode="RGB+ED",
                # near_plane, far_plane, absgrad, sparse_grad, with_ut,
                # with_eval3d and covars are left at the rasterizer's defaults.
            )
            # The first three channels are rgb; the last one is the depth channel.
            color = render_colors[..., :3].permute(0, 1, 4, 2, 3)  # [B, V, 3, H, W]
            depth = render_colors[..., -1]  # [B, V, H, W]
            means2d = meta["means2d"]  # [B, V, N, 2]
            radii = meta["radii"]  # [B, V, N, 2]
            # A gaussian counts as visible in a view when both radii are positive.
            visibility_filter = (radii > 0).all(dim=-1)  # [B, V, N]
            return color, depth, render_alphas, means2d, visibility_filter, radii

        nr_views = extrinsics.shape[1]
        if iter_batch_size <= 0 or iter_batch_size >= nr_views:
            # One call covers every view. This also handles iter_batch_size == 0
            # (which would make range() raise) and skips a pointless
            # concatenation when a single chunk would hold all views.
            outputs = _render(viewmats, intrinsics_scaled)
            if to_cpu:
                outputs = self._detach_to_cpu(outputs)
        else:
            # Render view chunks to bound peak memory; slicing clamps the last
            # chunk to the remaining views automatically.
            chunk_outputs = []
            for start in tqdm(range(0, nr_views, iter_batch_size), desc="Rendering in batches"):
                stop = start + iter_batch_size
                chunk = _render(viewmats[:, start:stop], intrinsics_scaled[:, start:stop])
                if to_cpu:
                    chunk = self._detach_to_cpu(chunk)
                chunk_outputs.append(chunk)
            # Stitch every output back together along the view axis.
            outputs = tuple(torch.cat(parts, dim=1) for parts in zip(*chunk_outputs))

        color, depth, render_alphas, means2d, visibility_filter, radii = outputs
        return DecoderOutput(
            color,
            depth=depth,
            accumulated_alpha=render_alphas.squeeze(-1),  # [B, V, H, W]
            means2d=means2d,
            visibility_filter=visibility_filter,
            radii=radii,
        )