Learn2Splat / optgs /model /decoder /gsplat_decoder_splatting_cuda.py
SteEsp's picture
Add Docker-based Learn2Splat demo (viser GUI)
78d2329 verified
from dataclasses import dataclass
from typing import Literal
import math
import torch
# import torch.nn.functional as F
from einops import repeat
from gsplat.rendering import rasterization
from jaxtyping import Float
from torch import Tensor
from tqdm import tqdm
from optgs.scene_trainer.gaussian_module import GaussiansModule
from .decoder import Decoder, DecoderOutput
from .decoder import DepthRenderingMode
from ..types import Gaussians
from ...dataset import DatasetCfg
@dataclass
class GSplatDecoderSplattingCUDACfg:
    """Configuration for the gsplat-backed CUDA splatting decoder."""

    # Decoder selector; annotated as the single literal "gsplat".
    name: Literal["gsplat"]
    # Value the decoder's forward() falls back to when its own
    # ``use_covariances`` argument is None.
    use_covariances: bool
    # Forwarded to gsplat's ``rasterization`` as ``rasterize_mode``.
    rasterize_mode: Literal["antialiased", "classic"]
    # Forwarded to gsplat's ``rasterization`` as ``eps2d``.
    eps2d: float
class GSplatDecoderSplattingCUDA(Decoder[GSplatDecoderSplattingCUDACfg]):
    """Decoder that renders 3D Gaussians with gsplat's ``rasterization``.

    Accepts either an unbatched ``GaussiansModule`` (a batch dimension of 1 is
    added) or a batched ``Gaussians`` container, and renders color, depth and
    accumulated alpha for every requested view.
    """

    # NOTE(review): registered as a non-persistent buffer in __init__, but the
    # forward() visible here never passes it to ``rasterization`` -- confirm
    # whether the background is meant to be composited elsewhere.
    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: GSplatDecoderSplattingCUDACfg,
        dataset_cfg: DatasetCfg,
    ) -> None:
        """Store the config and register the dataset background color buffer."""
        super().__init__(cfg, dataset_cfg)
        self.register_buffer(
            "background_color",
            torch.tensor(dataset_cfg.background_color, dtype=torch.float32),
            persistent=False,
        )

    def forward(
        self,
        gaussians: Gaussians | GaussiansModule,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,  # always render depth
        return_radii: bool = False,  # always return radii
        iter_batch_size: int = -1,  # -1 to render all views at once
        use_covariances: bool | None = False,  # override cfg (None -> use cfg)
        to_cpu: bool = False,  # move outputs to cpu
    ) -> DecoderOutput:
        """Rasterize the Gaussians into the given views.

        Args:
            gaussians: Either a ``GaussiansModule`` (no batch dimension; one is
                added here) or a batched ``Gaussians`` container. For
                ``Gaussians`` with ``stores_activated`` false, scales are
                passed through ``exp`` and opacities through ``sigmoid``.
            extrinsics: Camera matrices; they are inverted to obtain the view
                matrices handed to gsplat (presumably camera-to-world --
                TODO confirm).
            intrinsics: Camera intrinsics. Row 0 is multiplied by the image
                width and row 1 by the image height before rendering
                (presumably they arrive normalized by image size -- TODO
                confirm).
            near: Unused; gsplat's default near plane applies.
            far: Unused; gsplat's default far plane applies.
            image_shape: ``(height, width)`` of the rendered images.
            depth_mode: Unused; depth is always rendered (``"RGB+ED"``).
            return_radii: Unused; radii are always returned.
            iter_batch_size: Number of views rendered per rasterization call.
                Any negative value renders all views in one call. Must not
                be 0.
            use_covariances: Whether to fetch (and, for ``Gaussians``, require)
                covariances. ``None`` defers to ``cfg.use_covariances``.
                NOTE(review): the default is ``False``, so the cfg value is
                only consulted when ``None`` is passed explicitly; the default
                is kept for backward compatibility.
            to_cpu: If true, every output tensor is detached and moved to the
                CPU (per chunk when rendering in chunks).

        Returns:
            A ``DecoderOutput`` with color, depth, accumulated alpha, 2D means,
            a visibility mask (all radii components > 0) and radii.

        Raises:
            ValueError: If ``gaussians`` has an unsupported type, if
                covariances are requested for ``Gaussians`` but are ``None``,
                or if ``iter_batch_size`` is 0.
        """
        if iter_batch_size == 0:
            # range() with a zero step would fail later with an opaque message.
            raise ValueError(
                "iter_batch_size must be a positive chunk size, or negative to "
                "render all views at once; got 0."
            )

        _use_covariances = self.cfg.use_covariances if use_covariances is None else use_covariances
        height, width = image_shape

        if isinstance(gaussians, GaussiansModule):
            # No batch dimension on the module: add a leading one everywhere.
            means = gaussians.means.unsqueeze(0)  # [1, N, 3]
            # Stored as xyzw (scalar last); gsplat wants wxyz (scalar first).
            quats = gaussians.rotations[:, [3, 0, 1, 2]].unsqueeze(0)  # [1, N, 4]
            scales = gaussians.scales.unsqueeze(0)  # post-activation, [1, N, 3]
            opacities = gaussians.opacities.unsqueeze(0)  # post-activation
            # permute swaps the last two axes so the SH axis precedes the
            # channel axis: [N, d_sh, 3] -> [1, N, d_sh, 3] after unsqueeze.
            colors = gaussians.harmonics.permute(0, 2, 1).unsqueeze(0)
            covars = gaussians.covariances if _use_covariances else None
            if covars is not None:
                covars = covars.unsqueeze(0)  # [1, N, 3, 3]
        elif isinstance(gaussians, Gaussians):
            means = gaussians.means  # [B, N, 3]
            # NOTE(review): assumed to be stored as xyzw (scalar last), since
            # the same xyzw -> wxyz reorder as above is applied -- confirm
            # against the Gaussians type. Left unnormalized on purpose: per the
            # original author's note, rasterization normalizes internally.
            quats = gaussians.rotations_unnorm[:, :, [3, 0, 1, 2]]  # [B, N, 4]
            scales = gaussians.scales  # [B, N, 3]
            opacities = gaussians.opacities
            colors = gaussians.harmonics.permute(0, 1, 3, 2)  # [B, N, d_sh, 3]
            if _use_covariances:
                covars = gaussians.covariances  # [B, N, 3, 3]
                if covars is None:
                    raise ValueError("Covariances are set to be used, but gaussians.covariances is None.")
            else:
                covars = None
            if not gaussians.stores_activated:
                # Raw parameters: apply the activations here.
                scales = torch.exp(scales)
                opacities = torch.sigmoid(opacities)
        else:
            raise ValueError(f"Unknown type of gaussians: {type(gaussians)}")

        # NOTE(review): ``covars`` is fetched/validated above but deliberately
        # not forwarded to ``rasterization`` (gsplat default is used), so
        # rendering always goes through quats + scales.

        # d_sh = (degree + 1) ** 2
        sh_degree = math.isqrt(colors.shape[-2]) - 1
        viewmats = extrinsics.inverse()  # [B, V, 4, 4]
        # The [1, 3, 1] factor broadcasts per row: row 0 * width, row 1 * height.
        intrinsics_scaled = intrinsics * intrinsics.new_tensor([[[width], [height], [1]]])  # [B, V, 3, 3]

        def _render(viewmats, Ks):
            """Rasterize one group of views and unpack gsplat's outputs."""
            # near_plane / far_plane, absgrad, sparse_grad, with_ut,
            # with_eval3d and covars are intentionally left at gsplat defaults.
            render_colors, render_alphas, meta = rasterization(
                means=means,
                quats=quats,
                scales=scales,
                opacities=opacities,
                colors=colors,
                sh_degree=sh_degree,
                viewmats=viewmats,
                Ks=Ks,
                width=width,
                height=height,
                eps2d=self.cfg.eps2d,
                rasterize_mode=self.cfg.rasterize_mode,
                packed=False,
                render_mode="RGB+ED",
            )
            color = render_colors[..., :3].permute(0, 1, 4, 2, 3)  # [B, V, 3, H, W]
            depth = render_colors[..., -1]  # last channel is depth, [B, V, H, W]
            means2d = meta["means2d"]  # [B, V, N, 2]
            radii = meta["radii"]  # [B, V, N, 2]
            visibility_filter = torch.all(radii > 0, dim=-1)  # [B, V, N]
            outputs = (color, depth, render_alphas, means2d, visibility_filter, radii)
            if to_cpu:
                # Offload right away so chunked rendering frees device memory
                # between chunks.
                outputs = tuple(t.detach().cpu() for t in outputs)
            return outputs

        nr_views = extrinsics.shape[1]
        if iter_batch_size < 0:
            # Render all views at once.
            color, depth, render_alphas, means2d, visibility_filter, radii = _render(viewmats, intrinsics_scaled)
        else:
            # Render in chunks along the view axis to save memory; slicing
            # clamps at the end, so the last chunk may be shorter.
            chunk_outputs = [
                _render(
                    viewmats[:, start : start + iter_batch_size],  # [B, v, 4, 4]
                    intrinsics_scaled[:, start : start + iter_batch_size],  # [B, v, 3, 3]
                )
                for start in tqdm(range(0, nr_views, iter_batch_size), desc="Rendering in batches")
            ]
            # Concatenate each of the six outputs along the view axis.
            color, depth, render_alphas, means2d, visibility_filter, radii = (
                torch.cat([chunk[k] for chunk in chunk_outputs], dim=1) for k in range(6)
            )

        return DecoderOutput(
            color,
            depth=depth,
            accumulated_alpha=render_alphas.squeeze(-1),  # [B, V, H, W]
            means2d=means2d,
            visibility_filter=visibility_filter,
            radii=radii,
        )