Spaces:
Sleeping
Sleeping
File size: 8,548 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 | from dataclasses import dataclass
from typing import Literal
import math
import torch
# import torch.nn.functional as F
from einops import repeat
from gsplat.rendering import rasterization
from jaxtyping import Float
from torch import Tensor
from tqdm import tqdm
from optgs.scene_trainer.gaussian_module import GaussiansModule
from .decoder import Decoder, DecoderOutput
from .decoder import DepthRenderingMode
from ..types import Gaussians
from ...dataset import DatasetCfg
@dataclass()
class GSplatDecoderSplattingCUDACfg:
    """Settings for the gsplat-backed CUDA splatting decoder."""

    # Name tag of this decoder configuration; the type only admits "gsplat".
    name: Literal["gsplat"]
    # Default for the decoder's ``use_covariances`` forward argument.
    use_covariances: bool
    # Forwarded verbatim to gsplat's ``rasterization(rasterize_mode=...)``.
    rasterize_mode: Literal["antialiased", "classic"]
    # Forwarded verbatim to gsplat's ``rasterization(eps2d=...)``.
    eps2d: float
class GSplatDecoderSplattingCUDA(Decoder[GSplatDecoderSplattingCUDACfg]):
    """Decoder that splats Gaussians with gsplat's CUDA ``rasterization``.

    Accepts either a batched ``Gaussians`` container or an unbatched
    ``GaussiansModule`` and always rasterizes in ``"RGB+ED"`` mode. Each call
    therefore yields color, expected depth, accumulated alpha, projected 2D
    means, a per-view visibility mask and the projected radii.
    """

    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: GSplatDecoderSplattingCUDACfg,
        dataset_cfg: DatasetCfg,
    ) -> None:
        super().__init__(cfg, dataset_cfg)
        # Non-persistent buffer: it moves with the module across devices but is
        # not stored in ``state_dict``.
        # NOTE(review): ``forward`` does not pass this to ``rasterization``
        # (no ``backgrounds=`` argument), so it has no effect on the renders
        # produced here. Confirm whether that is intended.
        self.register_buffer(
            "background_color",
            torch.tensor(dataset_cfg.background_color, dtype=torch.float32),
            persistent=False,
        )

    @staticmethod
    def _prepare_gaussians(
        gaussians: Gaussians | GaussiansModule,
        use_covariances: bool,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor | None]:
        """Convert ``gaussians`` into batched, activated gsplat inputs.

        Returns ``(means, quats, scales, opacities, colors, covars)``, each with
        a leading batch dimension. Quaternions are reordered from xyzw (scalar
        last) to wxyz (scalar first). SH coefficients are laid out as
        ``[..., N, d_sh, 3]``. ``covars`` is ``None`` unless ``use_covariances``.

        Raises:
            ValueError: If covariances are requested but ``gaussians.covariances``
                is ``None``, or if ``gaussians`` has an unsupported type.
        """
        missing_covars_msg = "Covariances are set to be used, but gaussians.covariances is None."

        if isinstance(gaussians, GaussiansModule):
            # No batch dimension here; scales / opacities are post-activation.
            means = gaussians.means  # [N, 3]
            quats = gaussians.rotations[:, [3, 0, 1, 2]]  # xyzw -> wxyz, [N, 4]
            scales = gaussians.scales  # [N, 3]
            opacities = gaussians.opacities  # gsplat expects [..., N]
            colors = gaussians.harmonics.permute(0, 2, 1)  # [N, d_sh, 3]
            covars = gaussians.covariances if use_covariances else None
            if use_covariances and covars is None:
                raise ValueError(missing_covars_msg)
            # Add the size-1 batch dimension expected by gsplat.
            means, quats, scales, opacities, colors = (
                t.unsqueeze(0) for t in (means, quats, scales, opacities, colors)
            )
            if covars is not None:
                covars = covars.unsqueeze(0)  # [1, N, 3, 3]
            return means, quats, scales, opacities, colors, covars

        if isinstance(gaussians, Gaussians):
            means = gaussians.means  # [B, N, 3]
            # Unnormalized xyzw -> wxyz; rasterization normalizes internally.
            quats = gaussians.rotations_unnorm[:, :, [3, 0, 1, 2]]  # [B, N, 4]
            scales = gaussians.scales  # [B, N, 3]
            opacities = gaussians.opacities  # [B, N]
            colors = gaussians.harmonics.permute(0, 1, 3, 2)  # [B, N, d_sh, 3]
            covars = gaussians.covariances if use_covariances else None  # [B, N, 3, 3]
            if use_covariances and covars is None:
                raise ValueError(missing_covars_msg)
            if not gaussians.stores_activated:
                # Stored values are pre-activation: apply exp / sigmoid here.
                scales = torch.exp(scales)
                opacities = torch.sigmoid(opacities)
            return means, quats, scales, opacities, colors, covars

        raise ValueError(f"Unknown type of gaussians: {type(gaussians)}")

    @staticmethod
    def _maybe_to_cpu(outputs: tuple[Tensor, ...], to_cpu: bool) -> tuple[Tensor, ...]:
        """Return ``outputs`` unchanged, or detached and moved to the CPU if ``to_cpu``."""
        if not to_cpu:
            return outputs
        return tuple(t.detach().cpu() for t in outputs)

    def forward(
        self,
        gaussians: Gaussians | GaussiansModule,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,  # always render depth
        return_radii: bool = False,  # always return radii
        iter_batch_size: int = -1,  # <= 0 to render all views at once
        use_covariances: bool | None = None,  # None -> use cfg.use_covariances
        to_cpu: bool = False,  # move outputs to cpu
    ) -> DecoderOutput:
        """Rasterize ``gaussians`` into every camera view.

        Args:
            gaussians: Batched ``Gaussians`` or an unbatched ``GaussiansModule``
                (rendered with batch size 1).
            extrinsics: Camera poses. They are inverted to obtain gsplat's
                world-to-camera ``viewmats``, so presumably camera-to-world.
            intrinsics: Intrinsics normalized by image size. Row 0 is scaled by
                ``width`` and row 1 by ``height`` before rasterization.
            near: Unused; gsplat's default near plane is applied.
            far: Unused; gsplat's default far plane is applied.
            image_shape: ``(height, width)`` of the rendered images.
            depth_mode: Unused; expected depth ("ED") is always rendered.
            return_radii: Unused; radii are always returned.
            iter_batch_size: Number of views per rasterization call. Values
                <= 0 render all views in a single call.
            use_covariances: Overrides ``cfg.use_covariances`` when not ``None``.
                When enabled, the covariance matrices are passed to gsplat as
                ``covars``.
            to_cpu: Detach and move each rasterization result to the CPU.

        Returns:
            ``DecoderOutput`` with color ``[B, V, 3, H, W]``, depth
            ``[B, V, H, W]``, accumulated alpha ``[B, V, H, W]``, means2d
            ``[B, V, N, 2]``, visibility filter ``[B, V, N]`` (True where all
            radii are > 0) and radii ``[B, V, N, 2]``.

        Raises:
            ValueError: See ``_prepare_gaussians``.
        """
        if use_covariances is None:
            use_covariances = self.cfg.use_covariances
        height, width = image_shape

        means, quats, scales, opacities, colors, covars = self._prepare_gaussians(
            gaussians, use_covariances
        )

        # d_sh = (degree + 1) ** 2
        sh_degree = math.isqrt(colors.shape[-2]) - 1
        viewmats = extrinsics.inverse()  # [B, V, 4, 4]
        # Scale normalized intrinsics to pixels: row 0 by width, row 1 by height.
        intrinsics_scaled = intrinsics * intrinsics.new_tensor([[[width], [height], [1]]])  # [B, V, 3, 3]

        def _render(
            view_matrices: Tensor, Ks: Tensor
        ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
            """Rasterize one group of views; returns the six per-view outputs."""
            render_colors, render_alphas, meta = rasterization(
                means=means,
                quats=quats,
                scales=scales,
                opacities=opacities,
                colors=colors,
                sh_degree=sh_degree,
                viewmats=view_matrices,
                Ks=Ks,
                width=width,
                height=height,
                # near/far planes, absgrad, sparse_grad, with_ut and with_eval3d
                # are left at gsplat's defaults.
                eps2d=self.cfg.eps2d,
                rasterize_mode=self.cfg.rasterize_mode,
                packed=False,
                render_mode="RGB+ED",
                covars=covars,  # None unless covariances were requested
            )
            color = render_colors[..., :3].permute(0, 1, 4, 2, 3)  # [B, V, 3, H, W]
            depth = render_colors[..., -1]  # [B, V, H, W]
            means2d = meta["means2d"]  # [B, V, N, 2]
            radii = meta["radii"]  # [B, V, N, 2]
            visibility_filter = torch.all(radii > 0, dim=-1)  # [B, V, N]
            return color, depth, render_alphas, means2d, visibility_filter, radii

        nr_views = extrinsics.shape[1]
        if iter_batch_size <= 0:
            # Render every view in a single rasterization call.
            outputs = self._maybe_to_cpu(_render(viewmats, intrinsics_scaled), to_cpu)
        else:
            # Render in chunks of views to bound peak memory.
            chunk_outputs = []
            for start in tqdm(range(0, nr_views, iter_batch_size), desc="Rendering in batches"):
                stop = start + iter_batch_size  # slicing clamps the final, shorter chunk
                chunk = _render(viewmats[:, start:stop], intrinsics_scaled[:, start:stop])
                chunk_outputs.append(self._maybe_to_cpu(chunk, to_cpu))
            # Every output has the view axis at dim 1; concatenate along it.
            outputs = tuple(torch.cat(parts, dim=1) for parts in zip(*chunk_outputs))

        color, depth, render_alphas, means2d, visibility_filter, radii = outputs
        return DecoderOutput(
            color,
            depth=depth,
            accumulated_alpha=render_alphas.squeeze(-1),  # [B, V, H, W]
            means2d=means2d,
            visibility_filter=visibility_filter,
            radii=radii,
        )
|