Spaces:
Sleeping
Sleeping
| import warnings | |
| from dataclasses import dataclass | |
| import torch.nn.functional as F | |
| import math | |
| from typing import List, Tuple, Iterator | |
| from typing import Literal, Generic, TypeVar | |
| from fused_ssim import allowed_padding, FusedSSIMMap | |
| import torch | |
| from gsplat import fully_fused_projection, isect_tiles, isect_offset_encode, rasterize_to_indices_in_range | |
| from nerfacc import accumulate_along_rays, render_weight_from_alpha | |
| from torch import Tensor | |
| from optgs.model.decoder.decoder import Decoder, DecoderOutput | |
| from optgs.model.types import Gaussians | |
| from einops import rearrange | |
| from tqdm import tqdm | |
| import gc | |
| import torch.autograd.profiler as profiler | |
| from optgs.misc.memory_profiler import profile_gpu_memory, report_gpu_tensors | |
# Type parameter for the generic per-attribute config classes in this module.
T = TypeVar("T")
# Set to True to route each renderer.forward call in calc_input_gradients through
# profile_gpu_memory (GPU memory profiling).
GPU_MEM_PROFILING: bool = False
def split_grads(grads_tensor, cfg):
    """Split a packed per-Gaussian gradient tensor into named parameter groups.

    Packed layout along the last dimension:
    ``[means(3) | scales(3) | rotations(4) | opacities(1) | shs(3 * cfg.sh_d)]``.
    A leading batch dimension of size 1 is accepted and squeezed away.

    Args:
        grads_tensor: Tensor of shape [N, D] or [1, N, D].
        cfg: object exposing ``sh_d`` (SH coefficients per colour channel).

    Returns:
        dict with keys "means", "scales", "rotations", "opacities",
        "sh0s" ([N, 3, 1]) and "shNs" ([N, 3, sh_d - 1], or None when ``cfg.sh_d <= 1``).
    """
    assert isinstance(grads_tensor, Tensor), "grads_tensor is not a Tensor"
    if grads_tensor.ndim == 3:
        # Only a singleton batch axis is supported; drop it.
        assert grads_tensor.shape[0] == 1, "Batch size > 1 not supported for grads_tensor with ndim 3"
        grads_tensor = grads_tensor.squeeze(0)  # [N, D]

    sh_d = cfg.sh_d
    group_sizes = (3, 3, 4, 1, 3 * sh_d)
    means, scales, rotations, opacities, flat_shs = torch.split(grads_tensor, group_sizes, dim=-1)

    # [N, 3 * sh_d] -> [N, 3, sh_d]: colour channel is the outer factor, SH index the inner one.
    shs = rearrange(flat_shs, "n (c x) -> n c x", c=3, x=sh_d)
    return {
        "means": means,
        "scales": scales,
        "rotations": rotations,
        "opacities": opacities,
        "sh0s": shs[..., 0:1],
        "shNs": shs[..., 1:] if sh_d > 1 else None,
    }
def inner_loss_for_input_gradients(
    gt_images,
    output_renderer: DecoderOutput,
    reduction: str = "mean",
    with_ssim: bool = True,
    ssim_weight: float = 0.2,
) -> Tensor:
    """Photometric loss ``(1 - w) * L1 + w * (1 - SSIM)`` between rendered and GT views.

    Args:
        gt_images: ground-truth images of shape [1, V, C, H, W] (batch size must be 1),
            with exactly the same shape as ``output_renderer.color``.
        output_renderer: renderer output; ``color`` holds the rendered views.
        reduction: "mean" (over all elements), "sum", or "mean_pixels_sum_views"
            (mean over C, H, W per view, then summed over views).
        with_ssim: if False, only the reduced L1 term is returned.
        ssim_weight: weight ``w`` of the SSIM term; the L1 term gets ``1 - w``.
            The default 0.2 reproduces the previous fixed 0.8 / 0.2 mix.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``reduction`` is not one of the supported names.
    """
    # Only batch size 1 is supported.
    assert gt_images.shape[0] == 1
    assert gt_images.shape == output_renderer.color.shape
    if reduction not in ("mean", "sum", "mean_pixels_sum_views"):
        # Fail fast, before any per-pixel work is done.
        raise ValueError(f"Unknown reduction: {reduction!r}")

    abs_err = (output_renderer.color - gt_images).abs()  # [B, V, C, H, W]
    if reduction == "mean":
        l1_loss = abs_err.mean()
    elif reduction == "sum":
        l1_loss = abs_err.sum()
    else:  # "mean_pixels_sum_views": [B, V, C, H, W] -> [B, V] -> [B] -> scalar
        l1_loss = abs_err.mean(dim=(-1, -2, -3)).sum(dim=-1).mean()
    if not with_ssim:
        return l1_loss

    # Inference-mode tensors are cloned into normal tensors; presumably the fused SSIM
    # autograd function saves its inputs for backward, which inference tensors forbid.
    gt_for_ssim = gt_images.clone() if gt_images.is_inference() else gt_images
    ssim_loss = fused_ssim_with_reduction(
        rearrange(output_renderer.color, "b v c h w -> (b v) c h w"),
        rearrange(gt_for_ssim, "b v c h w -> (b v) c h w"),
        padding="valid",
        reduction=reduction,
        loss=True,  # reduce 1 - ssim_map, i.e. the SSIM loss
    )
    return (1.0 - ssim_weight) * l1_loss + ssim_weight * ssim_loss
def squeeze_grad_dict(grad_dict):
    """Apply ``squeeze(0)`` to every non-None value of ``grad_dict``.

    The dict is modified in place and also returned; ``None`` entries are left as-is.
    """
    present = [name for name, grad in grad_dict.items() if grad is not None]
    for name in present:
        grad_dict[name] = grad_dict[name].squeeze(0)
    return grad_dict
def smooth_grads(grads: dict, smoothers: dict) -> dict:
    """Apply per-key smoothing callables to a dict of gradients.

    Only keys that have an entry in ``smoothers`` appear in the result (in ``grads``
    order). A ``None`` gradient maps to ``None`` without calling its smoother.
    """
    return {
        name: None if grad is None else smoothers[name](grad)
        for name, grad in grads.items()
        if name in smoothers
    }
def chunk_ranges(v: int, chunk_size: int) -> List[Tuple[int, int]]:
    """
    Partition ``[0, v)`` into consecutive half-open ``(start, stop)`` index ranges.

    Every range spans ``chunk_size`` indices except possibly the last one, which is
    shorter when ``v`` is not a multiple of ``chunk_size``. ``v <= 0`` gives ``[]``.

    Example: chunk_ranges(10, 4) -> [(0,4),(4,8),(8,10)]

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be > 0")
    return [(lo, min(lo + chunk_size, v)) for lo in range(0, v, chunk_size)]
def chunk_slices(v: int, chunk_size: int, dim: int = 1) -> List[slice]:
    """
    Return one ``slice(start, stop)`` per chunk of ``[0, v)``, as given by ``chunk_ranges``.

    A slice object carries no axis. To cut along axis ``dim``, the caller must place it
    in that position, e.g. ``tensor[(slice(None),) * dim + (s,)]``. ``dim`` itself is not
    used by this function and is kept only for call-site compatibility.

    Raises:
        ValueError: propagated from ``chunk_ranges`` when ``chunk_size`` is not positive.
    """
    return [slice(*bounds) for bounds in chunk_ranges(v, chunk_size)]
def chunk_index_iter(v: int, chunk_size: int) -> Iterator[Tuple[int,int,int]]:
    """
    Lazily yield ``(chunk_idx, start, stop)`` for each chunk of ``[0, v)``.

    Because this is a generator, an invalid ``chunk_size`` (ValueError from
    ``chunk_ranges``) surfaces only when iteration starts.
    """
    yield from ((idx, *bounds) for idx, bounds in enumerate(chunk_ranges(v, chunk_size)))
def fused_ssim_with_reduction(img1, img2, padding="same", train=True, reduction="mean", loss=False):
    """SSIM between two image batches via ``FusedSSIMMap``, with a selectable reduction.

    Args:
        img1, img2: image batches fed to the fused SSIM kernel (the resulting map is
            annotated [v, c, h, w] below).
        padding: must be one of ``fused_ssim.allowed_padding``.
        train: forwarded unchanged to ``FusedSSIMMap``.
        reduction: "mean", "sum", or "mean_pixels_sum_views"
            (mean over c, h, w, then sum over v).
        loss: if True, reduce ``1 - ssim_map`` (the SSIM loss) instead of the SSIM map.

    Raises:
        ValueError: for an unsupported ``reduction``; checked before the kernel runs.
    """
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    assert padding in allowed_padding
    if reduction not in ("mean", "sum", "mean_pixels_sum_views"):
        # Validate up front so a typo does not cost a full SSIM kernel launch.
        raise ValueError(f"Unsupported reduction: {reduction}")

    # Hand contiguous memory to the fused kernel. The img2 call guards against strided
    # ground-truth views; .contiguous() is a no-op when the tensor is already contiguous.
    img1 = img1.contiguous()
    img2 = img2.contiguous()
    ssim_map = FusedSSIMMap.apply(C1, C2, img1, img2, padding, train)  # [v c h w]
    if loss:
        ssim_map = 1 - ssim_map

    if reduction == "mean":
        return ssim_map.mean()
    if reduction == "sum":
        return ssim_map.sum()
    # "mean_pixels_sum_views": mean over (c, h, w), then sum over views.
    return ssim_map.mean(dim=(-1, -2, -3)).sum(dim=-1)
def _get_adc_meta_buffers(meta_bufs, any_adc, need_2d_grads, b, v, N, device):
    """Return ``(radii, visibility, means2d_grads)`` output buffers for ADC metadata.

    Buffers cached in ``meta_bufs`` are reused when they were allocated for the same
    ``(N, v)``. Otherwise fresh buffers are allocated and, if ``meta_bufs`` is a dict,
    stored there for the next call. All three are None when ``any_adc`` is False;
    ``means2d_grads`` is None when 2D grads were never requested for this cache entry.
    """
    if not any_adc:
        return None, None, None
    if meta_bufs is not None and meta_bufs.get("N") == N and meta_bufs.get("v") == v:
        radii_all = meta_bufs["radii"]
        visibility_all = meta_bufs["visibility"]
        means2d_grads_all = meta_bufs.get("means2d_grads")
        if need_2d_grads and means2d_grads_all is None:
            # The cached entry was created without 2D grads; add that buffer lazily.
            means2d_grads_all = torch.empty((b, v, N, 2), dtype=torch.float32, device=device)
            meta_bufs["means2d_grads"] = means2d_grads_all
        return radii_all, visibility_all, means2d_grads_all

    radii_all = torch.empty((b, v, N, 2), dtype=torch.float32, device=device)
    visibility_all = torch.empty((b, v, N), dtype=torch.bool, device=device)
    means2d_grads_all = (
        torch.empty((b, v, N, 2), dtype=torch.float32, device=device)
        if need_2d_grads else None
    )
    if meta_bufs is not None:
        meta_bufs.update({"N": N, "v": v, "radii": radii_all,
                          "visibility": visibility_all, "means2d_grads": means2d_grads_all})
    return radii_all, visibility_all, means2d_grads_all


def calc_input_gradients(
    iter_context,
    prev_means,
    prev_scales_raw,
    prev_rotations_unnorm,
    prev_opacities_raw,  # [B, N] — may be a non-leaf view of gaussians.opacities
    prev_shs,  # [B, N, 3, sh_d]
    renderer: Decoder,
    need_2d_grads: bool,
    chunk_size: int | None,
    any_adc: bool = True,
    sh_degree: int | None = None,
    meta_bufs: dict | None = None,  # mutable dict populated/reused across calls for radii & visibility
    loss_reduction: str = "mean",
    loss_with_ssim: bool = True,
    opacity_reg_lambda: float = 0.0,  # L1 opacity regularization weight (3DGS-MCMC)
) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor | None] | None]:
    """Render the Gaussians into the context views and return the loss and its input gradients.

    The views of ``iter_context`` ("image", "extrinsics", "intrinsics", "near", "far",
    indexed [:, start:stop]) are processed in chunks. Gradients with respect to the raw
    (pre-activation) parameters are obtained per chunk via ``torch.autograd.grad`` and
    averaged over the chunks.

    Side effect: any ``prev_*`` tensor that does not yet require grad gets
    ``requires_grad_(True)``.

    Args:
        iter_context: dict of view tensors; "image" has shape [B, V, C, H, W] with B == 1.
        prev_means, prev_scales_raw, prev_rotations_unnorm, prev_opacities_raw, prev_shs:
            Gaussian parameters; sigmoid / exp / normalize are applied here before rendering.
        renderer: decoder whose ``forward`` renders the Gaussians.
        need_2d_grads: also return gradients w.r.t. the renderer's ``means2d``
            (only collected when ``any_adc`` is True).
        chunk_size: number of views per chunk; ``-1`` or ``None`` renders all views at once.
        any_adc: collect radii / visibility (/ 2D grads) metadata.
        sh_degree: if given, only the first ``(sh_degree + 1) ** 2`` SH coefficients are rendered.
        meta_bufs: optional dict caching the metadata buffers across calls. When reused, the
            buffers returned in ``meta_for_adc`` are overwritten in place by the next call.
        loss_reduction, loss_with_ssim: forwarded to ``inner_loss_for_input_gradients``.
        opacity_reg_lambda: weight of ``mean(sigmoid(opacities_raw))`` added to the
            differentiated loss (not to the returned loss).

    Returns:
        loss: photometric loss averaged over chunks (detached; no opacity regularization).
        grads: dict "means", "scales", "rotations", "opacities", "sh0s", "shNs"
            ("shNs" is None when there is a single SH coefficient), averaged over chunks.
        meta_for_adc: dict "visibility_filter", "radii", "means_2d_grads" when ``any_adc``,
            otherwise None.

    Raises:
        ValueError: if ``iter_context`` contains no views.
    """
    b, v, _, h, w = iter_context["image"].shape
    assert b == 1, "Batch size > 1 not supported for post-processing"
    if v == 0:
        raise ValueError("iter_context contains no views; nothing to differentiate")
    if chunk_size is None or chunk_size == -1:
        # None / -1: render every view in a single chunk.
        chunk_size = v
    nr_chunks = math.ceil(v / chunk_size)
    N = prev_means.shape[1]
    device = prev_means.device

    # --- Grad setup ---
    # Gradients come from torch.autograd.grad, so .grad buffers are never touched. The list
    # order defines the autograd.grad input order and thus the order of the returned grads.
    _leaf_params = [prev_means, prev_scales_raw, prev_rotations_unnorm, prev_opacities_raw, prev_shs]
    n_leaf = len(_leaf_params)
    for t in _leaf_params:
        if not t.requires_grad:
            t.requires_grad_(True)

    radii_all, visibility_all, means2d_grads_all = _get_adc_meta_buffers(
        meta_bufs, any_adc, need_2d_grads, b, v, N, device)
    # 2D grads are only stored when ADC metadata is collected; skip computing them otherwise.
    want_2d_grads = need_2d_grads and any_adc

    # --- Forward + autograd.grad loop ---
    accumulated_grads: list[Tensor] | None = None
    loss_total: Tensor | None = None
    with torch.enable_grad():
        assert not torch.is_inference_mode_enabled()
        for _chunk_idx, start, stop in tqdm(chunk_index_iter(v, chunk_size), disable=nr_chunks <= 1,
                                            desc="Computing input gradients in chunks"):
            image_chunk = iter_context["image"][:, start:stop]
            extrinsics_chunk = iter_context["extrinsics"][:, start:stop]
            intrinsics_chunk = iter_context["intrinsics"][:, start:stop]
            near_chunk = iter_context["near"][:, start:stop]
            far_chunk = iter_context["far"][:, start:stop]

            # Activations are rebuilt every chunk: autograd.grad(retain_graph=False) frees the
            # previous chunk's graph, including these nodes.
            prev_opacities = torch.sigmoid(prev_opacities_raw)
            prev_scales = torch.exp(prev_scales_raw)
            prev_rotations = F.normalize(prev_rotations_unnorm, dim=-1)
            if sh_degree is not None:
                prev_shs_for_render = prev_shs[..., :(sh_degree + 1) ** 2]
            else:
                prev_shs_for_render = prev_shs
            tmp_gaussians = Gaussians(
                means=prev_means,
                covariances=None,
                harmonics=prev_shs_for_render,
                opacities=prev_opacities,
                scales=prev_scales,
                rotations=prev_rotations,
                rotations_unnorm=prev_rotations_unnorm,
                stores_activated=True,
            )
            if GPU_MEM_PROFILING:
                output_renderer: DecoderOutput = profile_gpu_memory(
                    fn=renderer.forward, gaussians=tmp_gaussians,
                    extrinsics=extrinsics_chunk, intrinsics=intrinsics_chunk,
                    near=near_chunk, far=far_chunk, image_shape=(h, w))
            else:
                output_renderer: DecoderOutput = renderer.forward(
                    gaussians=tmp_gaussians,
                    extrinsics=extrinsics_chunk, intrinsics=intrinsics_chunk,
                    near=near_chunk, far=far_chunk, image_shape=(h, w))

            loss = inner_loss_for_input_gradients(image_chunk, output_renderer,
                                                  reduction=loss_reduction, with_ssim=loss_with_ssim)
            # L1 opacity regularization (3DGS-MCMC) folded into the differentiated loss only.
            grad_loss = loss
            if opacity_reg_lambda > 0.0:
                grad_loss = loss + opacity_reg_lambda * prev_opacities.mean()

            grad_inputs = list(_leaf_params)
            if want_2d_grads:
                assert output_renderer.means2d is not None
                grad_inputs.append(output_renderer.means2d)
            chunk_grads = torch.autograd.grad(grad_loss, grad_inputs,
                                              create_graph=False, retain_graph=False)

            param_grads = [g.detach() for g in chunk_grads[:n_leaf]]
            if accumulated_grads is None:
                accumulated_grads = param_grads
            else:
                accumulated_grads = [a + g for a, g in zip(accumulated_grads, param_grads)]
            # Track the loss of every chunk, not just the last one.
            chunk_loss = loss.detach()
            loss_total = chunk_loss if loss_total is None else loss_total + chunk_loss

            # Store per-chunk metadata into the preallocated [B, V, ...] buffers.
            if any_adc:
                radii_all[:, start:stop] = output_renderer.radii
                visibility_all[:, start:stop] = output_renderer.visibility_filter
                if want_2d_grads:
                    means2d_grads_all[:, start:stop] = chunk_grads[n_leaf].detach()

    # --- Average grads and loss over chunks (uniform weight per chunk) ---
    if nr_chunks > 1:
        inv = 1.0 / nr_chunks
        accumulated_grads = [g * inv for g in accumulated_grads]
        loss_total = loss_total * inv

    means_grads, scales_raw_grads, rotations_unnorm_grads, opacities_raw_grads, harmonics_grads = accumulated_grads
    sh0s_grads = harmonics_grads[..., 0:1]
    shNs_grads = harmonics_grads[..., 1:] if harmonics_grads.shape[-1] > 1 else None
    grads = {
        "means": means_grads,
        "scales": scales_raw_grads,
        "rotations": rotations_unnorm_grads,
        "opacities": opacities_raw_grads,
        "sh0s": sh0s_grads,
        "shNs": shNs_grads,
    }
    meta_for_adc = {
        "visibility_filter": visibility_all,
        "radii": radii_all,
        "means_2d_grads": means2d_grads_all if need_2d_grads else None,
    } if any_adc else None
    return loss_total, grads, meta_for_adc
def unpack_gaussians(
    gaussians: Gaussians,
    scales_log: bool,
    opacities_logit: bool,
    opacities_unsqueeze: bool,
    detach: bool = True,
    clone: bool = False,
    requires_grad: bool = False,
    scales_lims: tuple | None = None,  # post activation (1e-6, 3)
    raw_opacities_lims: tuple | None = None,  # pre activation (-7, 7)
):
    """Extract Gaussian parameters, optionally mapping scales / opacities to pre-activation space.

    TODO Naama: fix this
    ``scales_lims`` clamps the activated scales (before the optional log).
    ``raw_opacities_lims`` clamps the opacity logits and only applies when
    ``opacities_logit`` is set.

    Args:
        gaussians: source Gaussians; ``gaussians.sel``, when not None, selects a subset
            along the Gaussian axis (dim 1).
        scales_log: return ``log(scales + 1e-8)`` instead of the scales.
        opacities_logit: return ``logit(opacities)`` instead of the opacities.
        opacities_unsqueeze: append a trailing singleton dim to the opacities.
        detach / clone / requires_grad: applied, in that order, to every returned tensor.

    Returns:
        ``(means, scales, rotations_unnorm, opacities_raw, shs)``, where ``shs`` has its
        last two dims flattened into one.
    """
    # Scales: clamp in activated space, then optionally move to log space.
    scales = gaussians.scales  # [B, N, 3]
    if scales_lims is not None:
        scales = torch.clamp(scales, scales_lims[0], scales_lims[1])
    if scales_log:
        scales = torch.log(scales + 1e-8)

    # Opacities: optionally invert the sigmoid; eps keeps the logit finite at 0 / 1.
    if not opacities_logit:
        opacities_raw = gaussians.opacities  # [B, N]
    else:
        opacities_raw = torch.logit(gaussians.opacities, eps=1e-7)  # [B, N]
        if raw_opacities_lims is not None:
            opacities_raw = torch.clamp(opacities_raw, raw_opacities_lims[0], raw_opacities_lims[1])
    if opacities_unsqueeze:
        opacities_raw = opacities_raw.unsqueeze(-1)  # [B, N, 1]

    # Rotations stay unnormalized because the unnormalized quaternions are what gets refined.
    params = [
        gaussians.means,  # [B, N, 3]
        scales,
        gaussians.rotations_unnorm,  # [B, N, 4]
        opacities_raw,
        gaussians.harmonics.flatten(-2),  # SH: last two dims merged, e.g. [B, N, 3, 9] -> [B, N, 27]
    ]

    # TODO Naama: move method to Gaussians class
    if gaussians.sel is not None:
        sel = gaussians.sel  # [B, N]
        params = [p[:, sel] for p in params]
    if detach:
        params = [p.detach() for p in params]
    if clone:
        params = [p.clone() for p in params]
    means, scales, rotations_unnorm, opacities_raw, shs = params

    if requires_grad:
        for tensor in (means, opacities_raw, rotations_unnorm, scales, shs):
            tensor.requires_grad_(True)

    # NOTE: initialising several Gaussians per point (repeat + opacity rescaling) was only
    # sketched as commented-out code and is not implemented.
    return means, scales, rotations_unnorm, opacities_raw, shs
def get_gaussian_param_slices(sh_d: int) -> dict:
    """Index slices of each parameter group inside the packed per-Gaussian vector.

    Layout (must match pack_gaussians):
        [means(3) | scales(3) | quats(4) | opacities(1) | shs(3*sh_d)]
    The SH block is channel-major (3 channels x sh_d coefficients), so "sh0" is a strided
    slice picking coefficient 0 of each channel and "shN" lists the remaining indices.
    """
    layout = (("means", 3), ("scales", 3), ("quats", 4), ("opacities", 1))
    slices = {}
    offset = 0
    for name, size in layout:
        slices[name] = slice(offset, offset + size)
        offset += size

    # `offset` now points at the first SH entry.
    sh_stop = offset + 3 * sh_d
    slices["sh0"] = slice(offset, sh_stop, sh_d)
    slices["shN"] = [idx for idx in range(offset, sh_stop) if (idx - offset) % sh_d != 0]
    return slices
def get_gaussian_param_sizes(sh_d: int) -> dict:
    """Element count of each parameter group in the packed per-Gaussian vector.

    Layout matches pack_gaussians / get_gaussian_param_slices:
        [means(3) | scales(3) | quats(4) | opacities(1) | shs(3*sh_d)]
    """
    return dict(means=3, scales=3, quats=4, opacities=1, shs=3 * sh_d)
def pack_gaussians(
    means: Tensor,
    scales: Tensor,
    rotations_unnorm: Tensor,
    opacities_raw: Tensor,
    shs: Tensor,
) -> Tensor:
    """Pack Gaussian parameter groups into one [B, N, C] tensor along the last dimension.

    Layout (must match get_gaussian_param_slices):
        [means(3) | scales(3) | quats(4) | opacities(1) | shs(3*sh_d)]
    All inputs must share their leading dimensions; ``opacities_raw`` therefore needs a
    trailing singleton dimension.
    """
    parts = [means, scales, rotations_unnorm, opacities_raw, shs]
    return torch.cat(parts, dim=-1)
def get_visibility_contribution_from_gaussian_obj(views_info, gaussians, image_shape=None, render_image=False) -> tuple[Tensor, dict]:
    """
    Per-Gaussian visibility contribution over a set of views, read from a Gaussians object.

    Args:
        views_info: dict containing:
            "extrinsics": [B, V, 4, 4]; inverted here to obtain the view matrices.
            "intrinsics": [B, V, 3, 3]; rows 0 / 1 are multiplied by width / height here,
                so they are presumably normalized intrinsics (TODO confirm).
            "image": [B, V, C, H, W]; only its H, W are read, and only when image_shape is None.
            "near", "far": [B, 1]; only element [0, 0] is used.
        gaussians: Gaussians with batch size 1 providing ``means`` [B, N, 3],
            ``rotations_unnorm`` [B, N, 4] (reordered xyzw -> wxyz here),
            ``scales`` [B, N, 3] and ``opacities`` [B, N].
        image_shape: optional ``(width, height)`` -- note the width-first order.
        render_image: not used by this function.

    Returns:
        ``(weights, info)`` from get_gaussians_visibility_contribution, where
        ``weights[k] = sum_{c,i,j}^{C, H, W} w_{k,c,i,j}`` is the visibility contribution
        of the k-th Gaussian.
    """
    # TODO Naama: check visibility for both context and target views
    assert gaussians.means.shape[0] == 1

    # Gaussian attributes of the single batch element.
    means = gaussians.means[0]  # [N, 3]
    # Unnormalized quaternions are passed through (presumably normalized by the rasterizer).
    quats = gaussians.rotations_unnorm[0][:, [3, 0, 1, 2]]  # [N, 4], xyzw -> wxyz
    scales = gaussians.scales[0]  # [N, 3]
    opacities = gaussians.opacities[0]  # [N]

    # Cameras: view matrices and pixel-space intrinsics.
    viewmats = views_info["extrinsics"][0].inverse()  # [V, 4, 4]
    Ks = views_info["intrinsics"][0].clone()  # [V, 3, 3]; cloned so scaling below stays local
    if image_shape is None:
        height, width = views_info["image"].shape[-2:]
    else:
        width, height = image_shape
    Ks[:, 0] *= width
    Ks[:, 1] *= height
    near_plane = views_info["near"][0, 0].item()
    far_plane = views_info["far"][0, 0].item()

    with torch.no_grad():
        return get_gaussians_visibility_contribution(
            means=means,
            quats=quats,
            scales=scales,
            opacities=opacities,
            viewmats=viewmats,
            Ks=Ks,
            width=width,
            height=height,
            near_plane=near_plane,
            far_plane=far_plane,
            eps2d=0.1,
            rasterize_mode="antialiased",
        )
def get_gaussians_visibility_contribution(
    means: Tensor,  # [N, 3]
    quats: Tensor,  # [N, 4]
    scales: Tensor,  # [N, 3]
    opacities: Tensor,  # [N]
    viewmats: Tensor,  # [V, 4, 4]
    Ks: Tensor,  # [V, 3, 3]
    width: int,
    height: int,
    # set these as in your render function
    near_plane: float = 0.01,
    far_plane: float = 1e10,
    eps2d: float = 0.3,
    tile_size: int = 16,
    rasterize_mode: Literal["classic", "antialiased"] = "antialiased",
    batch_per_iter: int = 100,
) -> tuple[Tensor, dict]:
    """
    Visibility contribution of every Gaussian, summed over all views and pixels.

    ``out[k] = sum_{c,i,j}^{C, H, W} w_{k,c,i,j}``, where w is the rasterization weight of
    Gaussian k at pixel (i, j) of view c.

    Returns:
        ``(out, info)``: ``out`` has shape [N]; ``info`` holds "alphas" ([V, H, W, 1]
        accumulated opacity), "weights_per_view" ([V, N]) and the projection outputs
        "radii", "means2d", "conics", "depths" (shape [V, N, ...]; only entries with
        radii > 0 are valid).
    """
    n_gauss = means.shape[0]
    n_views = viewmats.shape[0]
    assert means.shape == (n_gauss, 3), means.shape
    assert quats.shape == (n_gauss, 4), quats.shape
    assert scales.shape == (n_gauss, 3), scales.shape
    assert opacities.shape == (n_gauss,), opacities.shape
    assert viewmats.shape == (n_views, 4, 4), viewmats.shape
    assert Ks.shape == (n_views, 3, 3), Ks.shape

    # Project the 3D Gaussians into every view.
    radii, means2d, depths, conics, compensations = fully_fused_projection(
        means=means,
        covars=None,
        quats=quats,
        scales=scales,
        viewmats=viewmats,
        Ks=Ks,
        width=width,
        height=height,
        eps2d=eps2d,
        near_plane=near_plane,
        far_plane=far_plane,
        calc_compensations=(rasterize_mode == "antialiased"),
    )

    # Per-view opacities; in antialiased mode they are scaled by the compensation factors.
    view_opacities = opacities.repeat(n_views, 1)  # [V, N]
    if compensations is not None:
        view_opacities = view_opacities * compensations

    # Tile / Gaussian intersections, sorted and encoded into per-tile offsets.
    n_tiles_x = math.ceil(width / float(tile_size))
    n_tiles_y = math.ceil(height / float(tile_size))
    _tiles_per_gauss, isect_ids, flatten_ids = isect_tiles(
        means2d,
        radii,
        depths,
        tile_size,
        n_tiles_x,
        n_tiles_y,
        packed=False,
        n_images=n_views,
        image_ids=None,
        gaussian_ids=None,
    )
    isect_offsets = isect_offset_encode(isect_ids, n_views, n_tiles_x, n_tiles_y)

    totals, render_alphas, weights_per_view = _gaussians_vis_contribution(
        means2d,
        conics,
        view_opacities,
        width,
        height,
        tile_size,
        isect_offsets,
        flatten_ids,
        batch_per_iter=batch_per_iter,
    )  # totals: [N]
    info = {
        "alphas": render_alphas,
        "radii": radii,
        "means2d": means2d,
        "conics": conics,
        "depths": depths,
        "weights_per_view": weights_per_view,
    }
    return totals, info
def _gaussians_vis_contribution(
    means2d: Tensor,  # [V, N, 2]
    conics: Tensor,  # [V, N, 3]
    opacities: Tensor,  # [V, N]
    image_width: int,
    image_height: int,
    tile_size: int,
    isect_offsets: Tensor,  # [V, tile_height, tile_width]
    flatten_ids: Tensor,  # [n_isects]
    batch_per_iter: int = 100,
):
    """Accumulate per-Gaussian compositing weights by rasterizing in batches of Gaussians.

    Per tile, Gaussians are processed in batches of ``tile_size ** 2``, with ``batch_per_iter``
    batches per loop step. The per-pixel transmittance left by earlier steps is carried
    into the next step.

    Returns:
        gaussian_weights: [N] sum over views and pixels of each Gaussian's compositing weight.
        render_alphas: [V, H, W, 1] accumulated opacity per pixel.
        gaussian_weights_per_view: [V, N] the same weights, kept separately per view.
    """
    V, N = means2d.shape[:2]
    n_isects = len(flatten_ids)
    device = means2d.device

    render_alphas = torch.zeros((V, image_height, image_width, 1), device=device)
    gaussian_weights = torch.zeros(N, dtype=opacities.dtype, device=device)
    gaussian_weights_per_view = torch.zeros((V, N), dtype=opacities.dtype, device=device)

    # Number of Gaussian batches needed by the most crowded tile.
    block_size = tile_size * tile_size
    isect_offsets_fl = torch.cat(
        [isect_offsets.flatten(), torch.tensor([n_isects], device=device)]
    )
    max_range = (isect_offsets_fl[1:] - isect_offsets_fl[:-1]).max().item()
    num_batches = (max_range + block_size - 1) // block_size
    total_pixels = V * image_height * image_width

    for step in range(0, num_batches, batch_per_iter):
        # Transmittance remaining at each pixel after all Gaussians of earlier steps.
        transmittances = 1.0 - render_alphas[..., 0]  # [V, H, W]
        gs_ids, image_ids, indices, pixel_ids, weights = get_m_intersection_weights(batch_per_iter, conics, flatten_ids,
                                                                                   image_height, image_width,
                                                                                   isect_offsets, means2d, opacities,
                                                                                   step, tile_size, total_pixels,
                                                                                   transmittances)
        # The helper composites each ray starting from transmittance 1 for this step only.
        # Scale by the transmittance already consumed by earlier steps to get true weights.
        weights = weights * transmittances.reshape(-1)[indices]

        # Scatter-add. The same (view, Gaussian) pair occurs once per covered pixel, and plain
        # `x[idx] += w` keeps only one write per duplicate index, so accumulate explicitly.
        gaussian_weights.index_add_(0, gs_ids, weights)
        gaussian_weights_per_view.index_put_((image_ids, gs_ids), weights, accumulate=True)

        # The weights already include the prefix transmittance, so no extra scaling is needed.
        alphas = accumulate_along_rays(
            weights, None, ray_indices=indices, n_rays=total_pixels
        )
        render_alphas.add_(alphas.reshape(V, image_height, image_width, 1))
    return gaussian_weights, render_alphas, gaussian_weights_per_view
def get_m_intersection_weights(range_size, conics, flatten_ids, image_height, image_width, isect_offsets, means2d,
                               opacities, step, tile_size, total_pixels, transmittances):
    """Rasterize Gaussian batches ``[step, step + range_size)`` and weight each intersection.

    Returns:
        ``(gs_ids, image_ids, indices, pixel_ids, weights)``, each of length M (the number of
        Gaussian / pixel / view intersections found). ``indices`` is the flattened ray index
        ``image_id * H * W + pixel_id``. ``weights`` come from nerfacc's
        ``render_weight_from_alpha`` along each ray, called without a prefix transmittance.
    """
    # M intersections, each a (gs_id, pixel_id, image_id) triple.
    gs_ids, pixel_ids, image_ids = rasterize_to_indices_in_range(
        step,
        step + range_size,
        transmittances,
        means2d,
        conics,
        opacities,
        image_width,
        image_height,
        tile_size,
        isect_offsets,
        flatten_ids,
    )  # [M], [M], [M]

    # Gaussian falloff at the pixel centre: 0.5 * d^T [[a, b], [b, c]] d with conic (a, b, c).
    px = pixel_ids % image_width
    py = pixel_ids // image_width
    centres = torch.stack([px, py], dim=-1) + 0.5  # [M, 2]
    dx, dy = (centres - means2d[image_ids, gs_ids]).unbind(-1)  # [M] each
    conic = conics[image_ids, gs_ids]  # [M, 3]
    sigmas = 0.5 * (conic[:, 0] * dx ** 2 + conic[:, 2] * dy ** 2) + conic[:, 1] * dx * dy  # [M]

    # No 0.999 clamp on alpha here; values above 1 only trigger a warning.
    alphas = opacities[image_ids, gs_ids] * torch.exp(-sigmas)
    if (alphas > 1).any():
        warnings.warn(f"Not all alphas <= 1, max alpha: {alphas.max().item()}")

    indices = image_ids * image_height * image_width + pixel_ids  # [M]
    weights, _ = render_weight_from_alpha(
        alphas, ray_indices=indices, n_rays=total_pixels
    )  # [M]
    return gs_ids, image_ids, indices, pixel_ids, weights
class Base3DGSAttributeCfg(Generic[T]):
    """Per-Gaussian-attribute settings scaled by a shared ``base`` value.

    Each attribute accessor (means, scales, opacities, quats/rotations, sh0, shN) is a
    read-only property returning ``base * <attribute factor>``. Concrete subclasses
    redeclare the field types explicitly (see the dacite note on them).
    """
    _base: T
    _means: T
    _scales: T
    _opacities: T
    _quats: T
    _sh0: T
    _shN: T

    @property
    def base(self) -> T:
        return self._base

    @property
    def means(self) -> T:
        return self.base * self._means

    @property
    def scales(self) -> T:
        return self.base * self._scales

    @property
    def opacities(self) -> T:
        return self.base * self._opacities

    @property
    def quats(self) -> T:
        return self.base * self._quats

    @property
    def rotations(self) -> T:
        # Alias of quats.
        return self.quats

    @property
    def sh0(self) -> T:
        return self.base * self._sh0

    @property
    def shN(self) -> T:
        return self.base * self._shN

    @property
    def param_names(self) -> list[str]:
        # Names of the per-attribute properties above (rotations is an alias and not listed).
        return ['means', 'scales', 'quats', 'opacities', 'sh0', 'shN']

    def dict(self):
        """Map each name in ``param_names`` to its scaled value."""
        return {name: getattr(self, name) for name in self.param_names}


@dataclass
class Bool3DGSCfg(Base3DGSAttributeCfg[bool]):
    """Boolean per-attribute switches; an attribute is on only if both ``_base`` and its flag are set."""
    # Config loading via dacite doesn't seem to support generic type, so need to write types explicitly
    _base: bool
    _means: bool
    _scales: bool
    _opacities: bool
    _quats: bool
    _sh0: bool
    _shN: bool

    @property
    def all_true(self) -> bool:
        # True when every parameter group in param_names is switched on.
        return all(getattr(self, attr) for attr in self.param_names)

    def __str__(self):
        if self.all_true:
            return "all"
        return "_".join(attr for attr in self.param_names if getattr(self, attr))
| class Number3DGSCfg(Base3DGSAttributeCfg[float | int]): | |
| # Config loading via dacite doesn't seem to support generic type, so need to write types explicitly | |
| _base: float | int | |
| _means: float | int | |
| _scales: float | int | |
| _opacities: float | int | |
| _quats: float | int | |
| _sh0: float | int | |
| _shN: float | int | |