import math import random from dataclasses import dataclass from typing import Literal, Optional, Any import torch import torch.nn.functional as F import torch.utils.checkpoint import torchvision.transforms as T from einops import rearrange from torch import nn, Tensor from optgs.dataset.data_types import BatchedExample, DataShim from optgs.dataset.data_types import BatchedViews from optgs.dataset.shims.patch_shim import apply_patch_shim from optgs.geometry.projection import project, sample_image_grid from optgs.misc.general_utils import SkipBatchException from optgs.misc.io import FrequencyScheduler from optgs.model.decoder.decoder import Decoder from optgs.model.encoder.layer import ResNetFeatureWarpper from optgs.model.types import Gaussians from optgs.scene_trainer.common.gaussian_adapter import build_covariance from optgs.scene_trainer.initializer import InitializerCfg, InitializerColmapCfg, InitializerEdgsCfg, \ InitializerRandomCfg, InitializerPointcloudCfg from optgs.scene_trainer.initializer import InitializerPlyCfg from optgs.scene_trainer.initializer.initializer_resplat import ResplatInitializerCfg from optgs.scene_trainer.optimizer.optimizer import OptimizerInput, LearnedOptimizer, OptimizerOutput, OptimizerState, \ OptimizerPreviousOutput, OptimizerCfg from optgs.scene_trainer.optimizer.optimizer_utils import Number3DGSCfg, Bool3DGSCfg from optgs.scene_trainer.optimizer.optimizer_utils import unpack_gaussians, \ get_visibility_contribution_from_gaussian_obj try: from optgs.model.encoder.point_transformer.layer import (PlainPointTransformer, SubsampleBlock, PointLinearWrapper, MultiScalePointTransformer, MultViewLowresAttn) except: pass try: from simple_knn._C import distCUDA2 except: pass from optgs.scene_trainer.optimizer.layer import CustomGroupNorm, AdamInputSmoothing, SlicedG3RNorm from optgs.scene_trainer.initializer.initializer import InitializerOutput from optgs.scene_trainer.optimizer.time_embed import get_embedder, TimeEncodingWrapper from optgs.loss.loss_depth_smooth import get_smooth_loss from optgs.scene_trainer.optimizer.optimizer_utils import ( inner_loss_for_input_gradients, chunk_index_iter, split_grads, get_gaussian_param_slices, get_gaussian_param_sizes, pack_gaussians, ) _IMAGENET_NORM = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) @dataclass class KnnBasedOptimizerCfg(OptimizerCfg): name: Literal["knn_based", "resplat_v1", "resplat_v2", "clogs", "l2s"] # TODO (release) remove clogs # iterative refine no_render_error: bool input_error_shallow_resnet_feature: bool input_error_resnet_feature_layers: int refine_sh_only: bool num_basic_refine_blocks: int num_refine_blocks: int concat_init_state: bool # always concat init state during updates replace_init_state: bool # always use the init state during updates state_channels: int refine_block_rmsnorm: bool refine_block_layernorm: bool pt_qk_norm: bool norm_pt_block: bool refine_gaussian_multiple: int # predict more gaussian residuals based on the previous gaussian center refine_residual_init_state: bool # add residual connection in the prediction head to the inital state clamp_refine_max_scale: float clamp_min_scale: float | int clamp_min_raw_scales: float | int clamp_max_raw_scales: float | int clamp_min_raw_opacities: float | int clamp_max_raw_opacities: float | int clamp_min_sh0: float | int clamp_max_sh0: float | int clamp_min_shs: float | int clamp_max_shs: float | int clamp_shs_soft: bool gaussian_head_multiple: int # use multiple non-weight sharing heads to predict multiple gaussians gradient_update_scale: float | int input_gradient_with_ssim_loss: bool update_attn_proj_channels: int | None update_no_knn_attn: bool update_no_tran_block_norm: bool update_tran_block_act: str | None multi_gaussian_scale_smaller: bool refine_condition_pt_feature: bool reinit_gaussian_when_refine_multiple: bool refine_same_num_points: bool # when init_gaussian_multiple > 1, refine directly works on it instead of subsampling points refine_knn_samples: int refine_multi_scale_pt: bool # KNN use_fused_attn: bool prune_invisible_gaussians: bool knn_idx_update_every: int # point transformer pt_heads: int # inputs input_alpha: bool input_depth: bool input_depth_smooth_error: bool # input error input_error: bool # render error as input to the refine head input_error_rgb_no_shuffle: bool # sample single pixel instead of pixel unshuffling input_error_add_rgb_feature: bool # resnet input_error_resnet_feature: bool input_error_cache_resnet_feature: bool input_error_no_freeze_resnet_feature: bool # number of views for render error input_error_num_views: int input_error_additional_cross_attn: bool input_error_num_intermediate_views: int # render error with remaining context views input_error_remain_context: bool input_error_merge_remain_context: bool input_error_warp_remain_context: bool input_error_random_num_remain_context: bool input_error_num_remain_context_test: int # render error mv attn input_error_mv_attn: bool input_error_mv_attn_blocks: int # refine global attention refine_with_mv_attn: bool refine_with_mv_attn_lowres: bool refine_no_mv_attn: bool # remove only the attn mv_attn_conv_with_norm: bool # unet-attn conv with norm refine_mv_shuffle_attn: bool # use pixel shuffle to save computation instead of unet refine_mv_attn_with_pos_enc: bool refine_shuffle_attn_no_norm: bool refine_mv_unimatch_attn: bool # input gradients input_gradient: bool input_gradient_log: bool input_gradient_log_clip_deltas: float | int input_gradient_scale: float | int input_gradient_same_loss: bool # use the same loss as the gaussian update input_gradient_loss_reduction: str scale_residual_grads: bool # sliding window window_local_refine: bool # refine each local window separately and then combine all windows window_global_refine: bool # refine all windows together window_local_global_refine: bool # first refine each window seprately, and then refine all windows together # sliding window update instead of update all gaussians together update_window_size: int local_gaussian_render: bool # time encoding use_time_encoding: bool time_encoding_max_steps: int train_global_update_only: bool # random size refine # update more for low resolution, less for high random_update_with_size: bool # amp use_amp: bool pt_head_amp: bool pt_update_amp: bool use_checkpointing: bool recurrent_use_checkpointing: bool # Debugging debug_refine_update_module: bool # Normalizing input input_gradient_normalize: bool input_gradient_normalize_type: str input_normalize_state: bool input_normalize_gaussians: bool # State scaling predict_state_scale: bool predict_state_scale_norm: bool # whether to normalize the state before scaling # Use optimizer without condition features init_state_wo_features: bool init_state_type: Literal["random", "constant"] init_state_scale: float | int opt_scales_before_act: bool # optimize scale before activation (raw -> exp -> scale -> log -> raw) # Preprocessing the init gaussians scale_initial_opacities: float | int # Experimental experimental_run: bool experimental_update: Bool3DGSCfg experimental_use_grads: bool experimental_use_norm_grads: Bool3DGSCfg experimental_lr: Number3DGSCfg # Deactivate gaussians local_prune_zero_radii: bool local_prune_low_weights: bool local_prune_low_weights_thresh: float | int update_only_nonzero_grad: bool # update learn residual state residual_state: bool # Update head update_head_layer_num: int update_head_concat_img: bool update_head_act: str | None # update_head activation to predict the deltas update_head_final_act: str | None # final activation in the update_head update_head_hidden_dim_matches: str # rebuttal or submission version update_head_scale_mag: bool # predict deltas as scale * 0.01 * jnp.exp(mag * 0.01) update_head_scalar_scale: bool # predict deltas as scalar * delta / norm(delta) update_head_scalar_scale_act: str # activation for the scalar scale output # Per-parameter-group update head (Feature A) update_head_per_param_heads: bool # separate heads per param group, each with own normalize+scale update_head_per_param_hidden_dim: int # hidden dim for per-param heads (SH head gets 2x) # Per-parameter scalar scales (Feature B) — requires update_head_scalar_scale=true update_head_per_param_scales: bool # per-group scalar scales instead of one global scalar # Config from initializer sh_d: int | None init_gaussian_param_num: int | None = None init_sh_d: int | None = None # Fow initialization from feed forward, gaussians are aligned with pixels. init_gaussian_multiple: int | None = None latent_downsample: int | None = None delta_adam_combine_step: int = 0 # combine deltas and adam updates def update(self, initializer_cfg: InitializerCfg): """ Update the optimizer config based on the initializer config""" # General settings self.init_gaussian_param_num = initializer_cfg.get_gaussian_param_num() self.init_sh_d = initializer_cfg.get_sh_d() if self.sh_d is None: # get sh_d from initializer if not set self.sh_d = initializer_cfg.get_sh_d() # Settings specific to DepthSplat initializer if isinstance(initializer_cfg, ResplatInitializerCfg): self.latent_downsample = initializer_cfg.latent_downsample self.init_gaussian_multiple = initializer_cfg.init_gaussian_multiple # update proj channels if self.refine_condition_pt_feature: self.condition_channels = initializer_cfg.gaussian_regressor_channels else: self.condition_channels = initializer_cfg.get_pt_in_channels() # Settings specific to Colmap initializer elif isinstance(initializer_cfg, (InitializerPlyCfg, InitializerColmapCfg, InitializerEdgsCfg, InitializerRandomCfg, InitializerPointcloudCfg)): # Since pixels and gaussians are not alligned, we can not use pixel attributes assert not self.input_error, "The error calculation assumes per pixel gaussians" assert not self.update_head_concat_img assert not self.input_alpha assert not self.local_gaussian_render, "The local rendering assumes per view gaussians" assert self.init_state_wo_features, "Colmap initializer does not have point features, init_state_wo_features must be set to True" self.init_gaussian_multiple = 1 self.latent_downsample = 1 else: raise ValueError(f"Unsupported initializer config type: {type(initializer_cfg)}") class KnnBasedOptimizerState: # TODO Naama: OptimizerState class already exists def __init__(self, state: torch.Tensor): self.state = state def clone(self, clone_mask: torch.Tensor, zero_t: bool) -> None: cloned_state = self.state[clone_mask] if zero_t: cloned_state = torch.zeros_like(cloned_state) self.state = torch.cat([self.state, cloned_state], dim=0) def split(self, split_mask, num_splits: int, zero_t: bool) -> None: states_to_split = self.state[split_mask] split_states = states_to_split.chunk(num_splits, dim=0) new_states = [] for i in range(num_splits): if zero_t: new_states.append(torch.zeros_like(split_states[i])) else: new_states.append(split_states[i]) self.state = torch.cat([self.state, *new_states], dim=0) def replace(self, from_indices: torch.Tensor, dest_indices: torch.Tensor, zero_t: bool) -> None: if zero_t: self.state[dest_indices] = 0.0 else: self.state[dest_indices] = self.state[from_indices] def prune(self, prune_mask: torch.Tensor) -> None: self.state = self.state[~prune_mask] def add(self, num_new: int) -> None: if num_new <= 0: return device = self.state.device dtype = self.state.dtype input_dim = self.state.shape[1:] self.state = torch.cat([self.state, torch.zeros((num_new, *input_dim), device=device, dtype=dtype)], dim=0) def extend(self, num_new): self.add(num_new) class Abs(nn.Module): def forward(self, x): return torch.abs(x) def get_activation_cls(activation: Optional[str] = None): if activation in ['none', None, 'identity']: return nn.Identity elif activation == 'tanh': return nn.Tanh elif activation == "gelu": return nn.GELU elif activation == 'sigmoid': return nn.Sigmoid elif activation == 'relu': return nn.ReLU elif activation == "softplus": return nn.Softplus elif activation == "abs": return Abs else: raise ValueError(f"Unsupported activation: {activation}") class KnnBasedOptimizer(LearnedOptimizer[KnnBasedOptimizerCfg]): OPTIMIZER_NAME = "knn_based" OPTIMIZER_NAME_ALIASES: tuple[str, ...] = () def __init__(self, cfg: KnnBasedOptimizerCfg, save_every: Optional[FrequencyScheduler] = None) -> None: valid = {self.OPTIMIZER_NAME, *self.OPTIMIZER_NAME_ALIASES} assert cfg.name in valid, f"Expected optimizer name {valid}, got {cfg.name}" super().__init__(cfg, save_every) if self.cfg.residual_state: assert not self.cfg.refine_residual_init_state # State channel self.state_channels = self.cfg.state_channels # time embedder if self.cfg.use_time_encoding: self.time_encoder_fn, self.time_embedding_dim = get_embedder(multires=6) else: self.time_encoder_fn = None self.time_embedding_dim = 0 # update_proj if not self.cfg.init_state_wo_features: self.update_proj = nn.Conv2d(self.cfg.condition_channels, self.state_channels, 1) channels, in_channels, update_gaussian_param_num, out_channels, error_features_channels = ( self.define_update_channels(self.cfg.init_gaussian_param_num)) self.error_features_channels = error_features_channels self.gaussian_param_num = out_channels if self.cfg.input_error: self.update_feature = self.get_input_error_feature_extractor() if self.cfg.input_error_add_rgb_feature: if self.cfg.init_gaussian_multiple == 4: # re10k self.update_rgb_error_proj = nn.Sequential( nn.Linear(3, error_features_channels), nn.LayerNorm(error_features_channels) ) else: self.update_rgb_error_proj = nn.Sequential( nn.Linear(3 * self.cfg.latent_downsample ** 2, error_features_channels), nn.LayerNorm(error_features_channels) ) self.update_input_norm = self.get_update_input_norm(in_channels) self.update_module = self.get_update_module(channels, in_channels) # predict multiple gaussians out_channels = out_channels * self.cfg.refine_gaussian_multiple if not self.cfg.refine_same_num_points: out_channels = out_channels * self.cfg.init_gaussian_multiple # make sure the input size of the gaussian head is updated accordingly if self.cfg.use_time_encoding: channels += self.time_embedding_dim # Compute per-param group dims (needed by per_param_heads and per_param_scales) if self.cfg.update_head_per_param_heads or self.cfg.update_head_per_param_scales: self._per_param_group_dims = self._compute_per_param_group_dims(out_channels) # Scaling state for update head if self.cfg.predict_state_scale: self.state_scale_head = self.get_state_scale_head(in_channels) self.update_head = self.get_update_head(in_channels, channels, out_channels) # multiple gaussian heads to predict multiple gaussians if self.cfg.gaussian_head_multiple > 1: self.update_head_list = self.get_update_head_list(channels, out_channels) # Define error calculation # add global attention to the render error if self.cfg.input_error and self.cfg.input_error_mv_attn: assert self.cfg.input_error_resnet_feature self.update_error_attn = nn.ModuleList([ MultViewLowresAttn(error_features_channels) for _ in range(self.cfg.input_error_mv_attn_blocks) ]) self.param_slices = get_gaussian_param_slices(self.cfg.sh_d) def _reset_knn_caches(self) -> None: """Invalidate cached KNN indices on all point-transformer sub-modules. Must be called whenever the number of Gaussians changes (e.g. after add_new) so the next forward recomputes KNN from scratch instead of using stale indices that index out-of-bounds into the grown point cloud. """ for module in self.modules(): if hasattr(module, "cache_knn_idx"): module.cache_knn_idx = None @property def adc_object_dict_to_adjust(self): if self.cfg.any_adc: object_dict: dict[str, Any] = {"depthsplat_state": None} # For ADC if self.cfg.input_gradient_normalize and self.cfg.input_gradient_normalize_type == "adam": object_dict.update(self.update_input_norm.subgroups_view(self.param_slices)) else: return None return object_dict def _compute_per_param_group_dims(self, out_channels): """Compute per-parameter-group output dimensions from total out_channels. Returns a dict {group_name: dim} in the same order as split_delta_gaussians. Accounts for no_refine_rotation, no_refine_mean, refine_sh_only, and multipliers. """ # TODO Naama: allow combination of no_refine_* p = get_gaussian_param_sizes(self.cfg.sh_d) all_params = [ ("means", "means"), ("scales", "scales"), ("rotations", "quats"), ("opacities", "opacities"), ("shs", "shs"), ] if self.cfg.refine_sh_only: excluded = {"means", "scales", "rotations", "opacities"} elif self.cfg.no_refine_rotation: excluded = {"rotations"} elif self.cfg.no_refine_mean: excluded = {"means"} else: excluded = set() multiplier = self.cfg.refine_gaussian_multiple if not self.cfg.refine_same_num_points: multiplier *= self.cfg.init_gaussian_multiple group_dims = {name: p[key] * multiplier for name, key in all_params if name not in excluded} assert sum(group_dims.values()) == out_channels, ( f"Per-param group dims {dict(group_dims)} sum={sum(group_dims.values())} != out_channels={out_channels}" ) return group_dims def _build_per_param_heads(self, channels, out_channels): """Build per-parameter-group heads (Feature A). Each head: Linear(channels, hidden) -> act -> Linear(hidden, dim+1) The +1 is a per-group scalar scale. Each head independently normalizes + scales. """ act_cls = get_activation_cls(self.cfg.update_head_act) hidden_dim = self.cfg.update_head_per_param_hidden_dim # Set up scale activation (shared across all per-param heads) scale_act_name = self.cfg.update_head_scalar_scale_act init_bias_map = {'softplus': -1, 'relu': 1e-8, 'abs': 1e-8} if scale_act_name not in init_bias_map: raise ValueError(f"Unsupported scalar_scale_act: {scale_act_name}") act_class = get_activation_cls(scale_act_name) self.scale_act = act_class(beta=1) if scale_act_name == 'softplus' else act_class() heads = nn.ModuleDict() for name, dim in self._per_param_group_dims.items(): # SH head gets 2x hidden dim (more outputs to predict) h = hidden_dim * 2 if name == "shs" else hidden_dim layers = [nn.Linear(channels, h), act_cls()] for _ in range(self.cfg.update_head_layer_num - 2): layers += [nn.Linear(h, h), act_cls()] layers.append(nn.Linear(h, dim + 1)) # +1 for scalar scale head = nn.Sequential(*layers) # Zero-init last layer (deltas start at 0) nn.init.zeros_(head[-1].weight) nn.init.zeros_(head[-1].bias) # Init scale bias nn.init.constant_(head[-1].bias[-1], init_bias_map[scale_act_name]) heads[name] = head return heads def get_update_head(self, in_channels, channels, out_channels): update_head_activation_cls = get_activation_cls(self.cfg.update_head_act) final_head_activation_cls = get_activation_cls(self.cfg.update_head_final_act) # skip connection to the image color if self.cfg.update_head_concat_img: channels += 3 * (self.cfg.latent_downsample ** 2) # Feature A: per-parameter-group heads (early return — builds ModuleDict instead of Sequential) if self.cfg.update_head_per_param_heads: assert not self.cfg.update_head_scale_mag, "update_head_scale_mag not supported with per_param_heads" assert not self.cfg.update_head_per_param_scales, "per_param_heads already includes per-group scales" return self._build_per_param_heads(channels, out_channels) # predict delta = scale * 0.01 * jnp.exp(mag * 0.01) if self.cfg.update_head_scale_mag: out_channels = out_channels * 2 if self.cfg.update_head_scalar_scale: if self.cfg.update_head_per_param_scales: # Feature B: one scalar scale per parameter group out_channels = out_channels + len(self._per_param_group_dims) else: out_channels = out_channels + 1 # Determine hidden layer size # TODO: update_head_hidden_dim_source should be "output" (out_channels). # Using "input" currently as default to reproduce rebuttal results. if self.cfg.update_head_hidden_dim_matches == "input": hidden_dim = channels # rebuttal version else: hidden_dim = out_channels # submitted version # Build update head layers_list = [ nn.Linear(channels, hidden_dim), update_head_activation_cls() ] for i in range(self.cfg.update_head_layer_num - 2): layers_list += [ nn.Linear(hidden_dim, hidden_dim), update_head_activation_cls(), ] layers_list += [ nn.Linear(hidden_dim, out_channels), final_head_activation_cls() ] update_head = nn.Sequential(*layers_list) # init the delta as 0 nn.init.zeros_(update_head[-2].weight) if final_head_activation_cls == torch.nn.Sigmoid: desired_init_delta = 0.005 bias = math.log(desired_init_delta / (1 - desired_init_delta)) # ~= -4.6 nn.init.constant_(update_head[-2].bias, bias) else: nn.init.zeros_(update_head[-2].bias) # Scalar scale output if self.cfg.update_head_scalar_scale: # Set the initial scale to very low number, to get the gradients flow init_bias_map = { 'softplus': -1, 'relu': 1e-8, 'abs': 1e-8, } act_name = self.cfg.update_head_scalar_scale_act if act_name not in init_bias_map: raise ValueError(f"Unsupported scalar_scale_out_act: {act_name}") # Initialize bias for scale output(s) if self.cfg.update_head_per_param_scales: num_groups = len(self._per_param_group_dims) for i in range(num_groups): nn.init.constant_(update_head[-2].bias[-(num_groups - i)], init_bias_map[act_name]) else: nn.init.constant_(update_head[-2].bias[-1], init_bias_map[act_name]) # Create activation act_class = get_activation_cls(act_name) self.scale_act = act_class(beta=1) if act_name == 'softplus' else act_class() return update_head def get_update_head_list(self, channels, out_channels): update_head_activation = get_activation_cls(self.cfg.update_head_act) final_head_activation = get_activation_cls(self.cfg.final_head_act) update_head_list = nn.ModuleList() for i in range(self.cfg.gaussian_head_multiple - 1): update_head_list.append( nn.Sequential( nn.Linear(channels, channels), update_head_activation(), nn.Linear(channels, out_channels), final_head_activation() ) ) # init the delta as 0 nn.init.zeros_(update_head_list[i][-2].weight) nn.init.zeros_(update_head_list[i][-2].bias) return update_head_list def get_update_input_norm(self, in_channels): if self.cfg.input_gradient_normalize: assert self.cfg.input_gradient, "for now we only normalize when using gradient as input" if self.cfg.input_gradient_normalize_type == 'layer': return nn.LayerNorm(in_channels) elif self.cfg.input_gradient_normalize_type == 'group': return CustomGroupNorm([self.gaussian_param_num, self.state_channels, self.gaussian_param_num]) elif self.cfg.input_gradient_normalize_type == 'batch': return nn.BatchNorm1d(in_channels, affine=False) elif self.cfg.input_gradient_normalize_type == 'g3r': return SlicedG3RNorm(in_channels, slice(-self.gaussian_param_num, None)) elif self.cfg.input_gradient_normalize_type == 'adam': assert not self.cfg.input_gradient_log and self.cfg.input_gradient_scale == 1 return AdamInputSmoothing(input_slice=slice(-self.gaussian_param_num, None)) else: raise ValueError(f"normalization type not supported {self.cfg.input_gradient_normalize_type}") else: return nn.Identity() def get_update_module(self, channels, in_channels): if not self.cfg.debug_refine_update_module: return None if self.cfg.refine_multi_scale_pt: update_module = nn.Sequential( PointLinearWrapper(in_channels, channels), MultiScalePointTransformer(channels, self.cfg.refine_knn_samples, subsample_method=self.cfg.subsample_method, attn_proj_channels=self.cfg.update_attn_proj_channels, ) ) else: update_module = nn.Sequential( PointLinearWrapper(in_channels, channels), PlainPointTransformer(channels, self.cfg.refine_knn_samples, num_blocks=self.cfg.num_basic_refine_blocks, qk_norm=self.cfg.pt_qk_norm, norm_pt_block=self.cfg.norm_pt_block, num_heads=self.cfg.pt_heads, no_rpe=True, no_attn=self.cfg.update_no_knn_attn, no_norm=self.cfg.update_no_tran_block_norm, act=self.cfg.update_tran_block_act, attn_proj_channels=self.cfg.update_attn_proj_channels, with_mv_attn=self.cfg.refine_with_mv_attn, with_mv_attn_lowres=self.cfg.refine_with_mv_attn_lowres, no_mv_attn=self.cfg.refine_no_mv_attn, conv_with_norm=self.cfg.mv_attn_conv_with_norm, mv_shuffle_attn=self.cfg.refine_mv_shuffle_attn, with_pos_enc=self.cfg.refine_mv_attn_with_pos_enc, shuffle_attn_no_norm=self.cfg.refine_shuffle_attn_no_norm, mv_unimatch_attn=self.cfg.refine_mv_unimatch_attn, use_checkpointing=self.cfg.use_checkpointing, use_fused_attn=self.cfg.use_fused_attn, knn_idx_update_every=self.cfg.knn_idx_update_every ) ) # Init normalization layers if self.cfg.input_normalize_state: for block in update_module[1].blocks: nn.init.zeros_(block.norm1.bias) nn.init.zeros_(block.norm2.bias) nn.init.ones_(block.norm1.weight) nn.init.ones_(block.norm2.weight) return update_module def get_state_scale_head(self, in_channels): state_scale_head = nn.Sequential( nn.Linear(in_channels, in_channels // 2), nn.ReLU(), nn.Linear(in_channels // 2, 1), nn.ReLU() ) # Init the scale to 1 # nn.init.zeros_(state_scale_head[-2].weight) nn.init.ones_(state_scale_head[-2].bias) return state_scale_head def define_update_channels(self, init_gaussian_param_num): if self.cfg.init_gaussian_multiple > 1: gaussian_param_num = init_gaussian_param_num // self.cfg.init_gaussian_multiple else: gaussian_param_num = init_gaussian_param_num # no pixel offset gaussian_param_num -= 2 # update position gaussian_param_num += 3 # SHs if self.cfg.sh_d != self.cfg.init_sh_d: gaussian_param_num += 3 * (self.cfg.sh_d - self.cfg.init_sh_d) # Get error channels if self.cfg.input_error: error_channels, error_feature_channels = self.define_error_channels() else: error_channels, error_feature_channels = 0, 0 # Get gradient channels if self.cfg.input_gradient: gradient_channels = gaussian_param_num * self.cfg.init_gaussian_multiple else: gradient_channels = 0 # final input channels input_signal_channels = gradient_channels + error_channels if self.cfg.refine_same_num_points: in_channels = (gaussian_param_num + self.state_channels + input_signal_channels) else: in_channels = (gaussian_param_num * self.cfg.init_gaussian_multiple + self.state_channels + input_signal_channels) if self.cfg.concat_init_state: in_channels += self.state_channels out_channels = gaussian_param_num if self.cfg.no_refine_mean: out_channels -= 3 channels = self.state_channels if self.cfg.input_alpha: # pixel shuffle the alpha channel to the latent resolution in_channels += self.cfg.latent_downsample ** 2 # alpha if self.cfg.input_depth or self.cfg.input_depth_smooth_error: # pixel shuffle the depth channel to the latent resolution in_channels += self.cfg.latent_downsample ** 2 # depth return channels, in_channels, gaussian_param_num, out_channels, error_feature_channels def define_error_channels(self): if self.cfg.no_render_error: error_channels = 0 else: if self.cfg.input_error_rgb_no_shuffle: error_channels = 3 else: error_channels = 3 * self.cfg.latent_downsample ** 2 if self.cfg.input_error_resnet_feature: # 3 scales: 1/2, 1/4, 1/8, channels: 64, 64, 128 if self.cfg.input_error_resnet_feature_layers in (18, 34): error_feature_channels = 64 + 64 if self.cfg.input_error_shallow_resnet_feature else 64 + 64 + 128 elif self.cfg.input_error_resnet_feature_layers == 50: error_feature_channels = 64 + 256 + 512 else: raise NotImplementedError error_channels = error_feature_channels else: error_feature_channels = 256 return error_channels, error_feature_channels def optimizer_preprocessing(self, optimizer_input: OptimizerInput, from_init: bool) -> None: if self.cfg.input_error_remain_context or self.cfg.input_error_merge_remain_context: assert self.cfg.input_error_cache_resnet_feature # Image dimensions context = optimizer_input.context b, v, _, h, w = context["image"].shape # Prepare Gaussians if from_init: # Scale initial opacities (in normal scale) # TODO Naama: add option to reset opacities and randomly reset/scale opacities of intermidiate updates opacities = optimizer_input.prev_output.gaussians.opacities # post activation, in [0, 1] scaled_opacities = opacities * self.cfg.scale_initial_opacities # default to 1.0 optimizer_input.prev_output.gaussians.opacities = scaled_opacities # Process shs shs = optimizer_input.prev_output.gaussians.harmonics # [B, N, 3, init_sh_d] init_sh_d = shs.shape[-1] if init_sh_d != self.cfg.sh_d: if init_sh_d > self.cfg.sh_d: shs = shs[:, :, :, :self.cfg.sh_d] # truncate [B, N, 3, sh_d] else: pad = self.cfg.sh_d - init_sh_d shs = F.pad(shs, (0, pad), "constant", 0) optimizer_input.prev_output.gaussians.harmonics = shs # Right now, this does not do anything, since we do not use windows local_window_update, test_window_size, window_end, window_start = self.get_window_size(v) optimizer_input.additional_info = local_window_update, test_window_size, window_end, window_start self.update_gaussians_for_window(v, h, w, optimizer_input) # Prepare state # Gaussians dimensions n = optimizer_input.prev_output.gaussians.means.shape[1] vector_state = self.get_vector_state(b, v, n, optimizer_input, from_init) if from_init: # Set everything so that the optimizer isn't aware whether it's a new scene # Convert InitializerOutput to OptimizerPreviousOutput optimizer_input.prev_output = OptimizerPreviousOutput(gaussians=optimizer_input.prev_output.gaussians, state=OptimizerState()) optimizer_input.prev_output.state.state = vector_state # init_state captures the scene-start state used by some experiments; # only set it on a fresh scene so replay-buffer resumes preserve the original value. if from_init: optimizer_input.prev_output.state.init_state = vector_state def update_gaussians_for_window(self, v, h, w, optimizer_input): # Get window parameters and set gaussians accordingly local_window_update, test_window_size, window_end, window_start = optimizer_input.additional_info if local_window_update and self.cfg.local_gaussian_render: init_gaussians = optimizer_input.prev_output.gaussians # select subset of gaussians init_gaussians_subset = select_gaussian_subset(init_gaussians, window_start, window_end, v=v, h=h // self.cfg.latent_downsample, w=w // self.cfg.latent_downsample, ) optimizer_input.prev_output.gaussians = init_gaussians_subset def _forward_impl( self, i: int, optimizer_input: OptimizerInput, optimizer_output: OptimizerOutput, full_context: BatchedViews, full_target: BatchedViews, **kwargs ) -> OptimizerOutput: # Timing self.iter_start.record() # Unpack iter_context: BatchedViews = optimizer_input.context target: BatchedViews = optimizer_input.target renderer: Decoder = optimizer_input.renderer b, v, _, h, w = iter_context["image"].shape assert b == 1, "Batch size > 1 not supported for post-processing" # Log number of gaussians self.nr_gaussians_log.append( optimizer_input.prev_output.gaussians.means.shape[1] ) # One optimization step res = self.apply_one_update_step( i, optimizer_input, optimizer_output ) updated_gaussians: Gaussians = res[0] state: Tensor = res[1] meta_for_adc: dict = res[2] updates: dict[str, Tensor] = res[3] grads_raw: Tensor | None = res[4] normalized_grads: Tensor | None = res[5] scaled_state: Tensor | None = res[6] gaussians_sel: Tensor | None = res[7] # Timing self._record_iter_timing() # Log stats if grads_raw is not None: grads = grads_raw # [B, G, D] nonzero_grads = (grads != 0).any(-1) # [B, G] # Filter out strictly zero gradients for logging grads = grads[nonzero_grads].unsqueeze(0) # [1, N_nonzero, D] assert nonzero_grads.shape[0] == 1 self.nr_nonzero_grad_log.append(nonzero_grads[0].sum().item()) # Local ADC # if optimizer_output.t == 500: # weight_vis_contribution, _ = get_visibility_contribution_from_gaussian_obj(iter_context, updated_gaussians) # [N] # prune_mask = weight_vis_contribution < 5 # print(f"Pruning {torch.sum(prune_mask)} gaussians out of {prune_mask.shape[0]} at iteration {i}") # updated_gaussians = updated_gaussians[:, ~prune_mask] # state = state[~prune_mask] # if self.cfg.normalize_update_input and self.cfg.normalize_update_input_type == "adam": # if not self.update_input_norm.is_reset(): # self.update_input_norm.prune(prune_mask) # Densification and Pruning if self.cfg.any_adc: n_before_adc = updated_gaussians.means.shape[1] # Prepare objects to adjust during ADC object_dict = self.adc_object_dict_to_adjust object_dict["depthsplat_state"] = KnnBasedOptimizerState(state) object_dict["depthsplat_init_state"] = KnnBasedOptimizerState(optimizer_input.prev_output.state.init_state) # Apply ADC self.apply_adc( i=i, v=v, h=h, w=w, adc_state=optimizer_input.prev_output.state.adc_state, gaussians=updated_gaussians, meta=meta_for_adc, object_dict_to_adjust=object_dict ) # Update state after ADC state = object_dict["depthsplat_state"].state optimizer_input.prev_output.state.init_state = object_dict["depthsplat_init_state"].state del object_dict["depthsplat_state"] if self.cfg.input_gradient_normalize and self.cfg.input_gradient_normalize_type == "adam": self.update_input_norm.aggregate_from_subgroups(object_dict, self.param_slices) # If N changed (add_new grew the population), stale KNN caches in the # point transformer modules would index out-of-bounds on the next forward # pass → CUDA illegal memory access. Reset them so they are recomputed. if updated_gaussians.means.shape[1] != n_before_adc: self._reset_knn_caches() # Save updated gaussians and state optimizer_input.prev_output.gaussians = updated_gaussians optimizer_input.prev_output.state.state = state if self.cfg.input_gradient_normalize_type == "adam": optimizer_input.prev_output.state.adam_state = self.update_input_norm.get_state() if self.training: optimizer_output.gaussian_list.append(updated_gaussians) # Info if not self.training and self.save_every(i + 1, tag="info"): # TODO Naama: review and refactor # save guassians optimizer_output.gaussian_list.append(updated_gaussians, detach_and_cpu=True, save_to_disk=False) # Save delta stats assert optimizer_output.info is not None # log updates # unpack shs shs = updates.pop("shs") # [1, N, 3*sh_d] assert shs.shape[0] == 1, "Batch size > 1 not supported" shs = shs.squeeze(0) # [N, 3*sh_d] shs = rearrange(shs, "n (c x) -> n c x", c=3, x=self.cfg.sh_d) # [N, 3, sh_d] updates["sh0s"] = shs[..., 0:1] if self.cfg.sh_d > 1: updates["shNs"] = shs[..., 1:] else: updates["shNs"] = None # log deltas if "deltas" not in optimizer_output.info: optimizer_output.info["deltas"] = [] optimizer_output.info["deltas"].append( {k: v.squeeze(0).cpu() if v is not None else None for k, v in updates.items()}) # Split each vector grad into gaussians components if grads_raw is not None: if gaussians_sel is not None: # Restore the zero gradients for tracking b, g_valid, d = grads_raw.shape g = state.shape[0] grads_raw_full = torch.zeros((b, g, d)) normalized_grads_full = torch.zeros((b, g, d)) grads_raw_full[:, gaussians_sel, :] = grads_raw.cpu() normalized_grads_full[:, gaussians_sel, :] = normalized_grads.cpu() grads_raw = grads_raw_full normalized_grads = normalized_grads_full grads_raw: dict[str, Tensor] = split_grads(grads_raw.cpu(), self.cfg) # Split each vector normalized_grads into gaussians components if normalized_grads is not None: normalized_grads: dict[str, Tensor] = split_grads(normalized_grads.cpu(), self.cfg) assert grads_raw["means"].shape == normalized_grads["means"].shape, \ f"Shape mismatch between grads and normalized_grads: {grads_raw['means'].shape} vs {normalized_grads['means'].shape}" # log states if scaled_state is not None: if "states_norms" not in optimizer_output.info: optimizer_output.info["states_norms"] = [] state_norm = torch.norm(scaled_state, dim=-1) # [B, N] optimizer_output.info["states_norms"].append(state_norm.cpu()) # log gradients if "grads" not in optimizer_output.info: optimizer_output.info["grads"] = [] optimizer_output.info["grads"].append(grads_raw) # log normalized gradients if "normalized_grads" not in optimizer_output.info: optimizer_output.info["normalized_grads"] = [] optimizer_output.info["normalized_grads"].append(normalized_grads) # Check if output_path in kwargs output_path = kwargs.get("output_path", None) scene_name = kwargs.get("scene_name", None) if self.cfg.any_adc: pass # Plot stats # self.plot_info(i, output_path=output_path, scene_name=scene_name) # Post-update context + target renders self._save_post_update_renders( i, optimizer_input, optimizer_output, updated_gaussians, full_context, full_target, ) # Optimizer output is being changed in place, but for clarity we return it return optimizer_output def apply_one_update_step( self, i, optimizer_input: OptimizerInput, optimizer_output: OptimizerOutput ) -> tuple[Gaussians, Tensor, dict, dict[str, Tensor], Tensor | None, Tensor | None, Tensor | None, Tensor | None]: # Unpacking context = optimizer_input.context target = optimizer_input.target renderer = optimizer_input.renderer debug_dict = optimizer_input.debug_dict num_refine = optimizer_input.num_refine gaussians = optimizer_input.prev_output.gaussians # Gaussian object of [B, N, C] state = optimizer_input.prev_output.state.state # [N, C] init_state = optimizer_input.prev_output.state.init_state # [N, C] local_window_update, test_window_size, window_end, window_start = optimizer_input.additional_info # Get input signal for the optimizer model (erros/gradients) self.decoder_event_start.record() input_signal, gaussian_grads_raw, gaussian_grads, grad_sign, context_render_output, means2d_grads = ( self.prepare_input_signal(context, i, gaussians, local_window_update, renderer, window_end, window_start, num_refine) ) self.decoder_event_end.record() # Preparing meta for ADC if means2d_grads is not None: means2d_grads = means2d_grads.detach() # [B, V, N, 2] meta_for_adc = { "visibility_filter": context_render_output.visibility_filter.detach(), # [B, V, N] "radii": context_render_output.radii.detach(), # [B, V, N, 1] "means_2d_grads": means2d_grads, # [B, V, N, 2] } # Handle zero gradient gaussians # We either prune them, or exclude them from the input/output update if self.cfg.update_only_nonzero_grad and gaussian_grads is not None: gaussian_grads, gaussian_grads_raw, gaussians, grad_sign, init_state, input_signal, state = ( self.handle_zero_grad_gaussians( context, context_render_output, gaussian_grads, gaussian_grads_raw, gaussians, grad_sign, init_state, input_signal, means2d_grads, meta_for_adc, optimizer_input, state) ) # For training, if the number of active gaussians is too high, skip this batch # TODO Naama: maybe sampling? active_gaussians_num = state.shape[0] if self.training: if active_gaussians_num > 100_000: print(f"Skipping batch at iteration {i} with {active_gaussians_num} active gaussians.") raise SkipBatchException() if active_gaussians_num < self.cfg.refine_knn_samples: print( f"Skipping batch at iteration {i} with only {active_gaussians_num} active gaussians (need >= {self.cfg.refine_knn_samples}).") raise SkipBatchException() # Training only: save the rendering of initialization for logging # Will not be used for loss calculation # TODO Naama: this cause to many confusion. Pull it out of this function if self.training and i == 0: # Append context images initialization assert context_render_output is not None optimizer_output.context_render_list.append(context_render_output, detach_and_cpu=False) # render target images initialization target_render_output = renderer.forward_batch_subset(gaussians, target) optimizer_output.target_render_list.append(target_render_output, detach_and_cpu=False) # Unpack Gaussians means, scales, rotations_unnorm, opacities_raw, shs = unpack_gaussians( gaussians, scales_log=self.cfg.opt_scales_before_act, opacities_logit=True, opacities_unsqueeze=True, detach=True, # stop gradient of last predictions scales_lims=(self.cfg.clamp_min_scale, self.cfg.clamp_refine_max_scale), raw_opacities_lims=(self.cfg.clamp_min_raw_opacities, self.cfg.clamp_max_raw_opacities) ) gaussians_concat = pack_gaussians(means, scales, rotations_unnorm, opacities_raw, shs) # [B, N, C] b, v, c, h, w = context["image"].shape latent_h = h // self.cfg.latent_downsample latent_w = w // self.cfg.latent_downsample # Debugging reprojection error if debug_dict is not None and (not self.training and self.save_every(i, tag="debug")): if "reprojection_error" in debug_dict: self.debug_reprojection_error(means, debug_dict, context, i, latent_h, latent_w) # prepare pt input point_cloud, tmp_batch_size = self.get_point_cloud(latent_h, latent_w, local_window_update, means, test_window_size, v) # Create offset directly on device to avoid CPU-GPU transfer offset = torch.arange(1, b + 1, device=state.device, dtype=torch.long) * tmp_batch_size # reshape tmp_gaussian = self.reshape_gaussians_to_nc(latent_h, latent_w, gaussians_concat, v) # [B, N, C] --> [BN, C] # add global attention to exchange info across views if self.cfg.input_error_mv_attn: input_signal = self.apply_global_attn(b, h, input_signal, latent_h, latent_w, local_window_update, test_window_size, v, w) tmp_input_signal = input_signal.reshape(-1, input_signal.shape[-1]) # [B, N, C] --> [BN, C] - faster than rearrange tmp_input_signal = self.append_to_input_signal(b, context, context_render_output, tmp_input_signal, v) # Normalize state before input it to the update module if self.cfg.input_normalize_state: state_norm = state.norm(dim=1, keepdim=True) / math.sqrt(state.shape[-1]) # [BG, 1] state = state / (state_norm + 1e-8) # [BG, C] normalized_input_signal = self.update_input_norm(tmp_input_signal) if self.cfg.input_normalize_gaussians: tmp_gaussian_mean = tmp_gaussian.mean() tmp_gaussian_std = tmp_gaussian.std() tmp_gaussian = (tmp_gaussian - tmp_gaussian_mean) / (tmp_gaussian_std + 1e-8) with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, dtype=torch.bfloat16): point_cloud, tmp_gaussian, state, update_input = self.prepare_update_input(b, i, init_state, normalized_input_signal, latent_h, latent_w, local_window_update, point_cloud, tmp_gaussian, # gradients/errors + additional pixel related quantities state, v, window_end, window_start) # if self.cfg.refine_with_mv_attn: # state = concat # for i in range(len(self.update_module)): # print(i, len(self.update_module), self.update_module[i]) # state = self.update_module[i]([point_cloud, state, offset]) # [N, C] # else: updated_state = self.apply_update_module(b, latent_h, latent_w, offset, point_cloud, update_input, v, state, i) # Hard coded extract normalized gradients if self.cfg.input_gradient and self.cfg.input_gradient_normalize: normalized_grads = normalized_input_signal else: normalized_grads = None # Recover the state norm if self.cfg.input_normalize_state: # state = state * state_std + state_mean updated_state = updated_state * state_norm # Predict a scale for the updtaed scale for the MLP deltas prediction # The updated state for the next stage remains the same if self.cfg.predict_state_scale: state_scale = self.state_scale_head(update_input.detach()) if self.cfg.predict_state_scale_norm: # Normalize the state vector state_scale = state_scale / (state_scale.norm(p=2, dim=1, keepdim=True) + 1e-8) else: state_scale = torch.tensor([1], device=state.device, dtype=state.dtype) updated_state_scaled = state_scale * updated_state # optionally append time encodiing to normalize input with TimeEncodingWrapper(self.cfg.use_time_encoding, self.time_encoder_fn, optimizer_output.t, self.cfg.time_encoding_max_steps, updated_state_scaled) as embedded_state: if self.cfg.use_time_encoding: assert not self.cfg.concat_init_state assert not self.cfg.replace_init_state # delta gaussian head delta_gaussians = self.apply_delta_gaussian_head(b, context, init_state, embedded_state, v) visibility_scale = None # disable for now delta_means, delta_opacities, delta_rotations, delta_scales, delta_shs, init_repeat, delta_gaussians = ( self.postprocess_deltas(b, delta_gaussians, gaussian_grads, gaussians_concat, grad_sign, latent_h, latent_w, local_window_update, normalized_grads, state, test_window_size, v, window_end, window_start, optimizer_output.t, optimizer_output.T, visibility_scale) ) means, opacities_raw, rotations_unnorm, scales, shs = self.repeat_gaussians(means, opacities_raw, rotations_unnorm, scales, shs) covariances, means, scales, rotations, rotations_unnorm, opacities_raw, shs = self.update_gaussians_params( delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs, means, scales, rotations_unnorm, opacities_raw, shs, init_repeat) # Recover the state in non valid gaussians (and grad for logging) if gaussians.sel is not None: sel = gaussians.sel # [B, G] full_state = optimizer_input.prev_output.state.state # Convert full state to the dtype of state full_state = full_state.to(state.dtype) # Use non-in-place index_put to avoid in-place modification of tensors # in the autograd computation graph (fixes version mismatch errors with stability loss) updated_state = full_state.index_put((sel,), updated_state) else: sel = None # update gaussians (only where mask is True) # Use view instead of rearrange for speed shs_reshaped = shs.view(shs.shape[0], shs.shape[1], 3, -1) gaussians = gaussians.update_object_by_curr_mask( means=means, covariances=covariances, harmonics=shs_reshaped, opacities=opacities_raw.squeeze(-1).sigmoid(), scales=scales, rotations=rotations, rotations_unnorm=rotations_unnorm, sel=None, deltas=delta_gaussians if self.training else None, gradients=gaussian_grads_raw if self.training else None, norm_gradients=normalized_grads.unsqueeze(0) if normalized_grads is not None and self.training else None ) updates = { "means": delta_means.detach(), "scales": delta_scales.detach(), "rotations": delta_rotations.detach(), "opacities": delta_opacities.detach(), "shs": delta_shs.detach() } grads_raw = gaussian_grads.detach() if gaussian_grads is not None else None grads_adam = normalized_grads.detach() if normalized_grads is not None else None return gaussians, updated_state, meta_for_adc, updates, grads_raw, grads_adam, updated_state_scaled, sel def postprocess_deltas(self, b, delta_gaussians, gaussian_grads, gaussians_concat, grad_sign, latent_h, latent_w, local_window_update, normalized_grads, state, test_window_size, v, window_end, window_start, t, T, visibility_scale): # Updates for gradient input (scale, log scale, ) delta_gaussians_raw = delta_gaussians delta_gaussians = self.update_delta_for_gradients_input(delta_gaussians_raw, grad_sign, normalized_grads, visibility_scale) # Rearrange back to [B, N, C] delta_gaussians, delta_gaussians_raw = self.rearrange_delta_gaussians(b, delta_gaussians, delta_gaussians_raw, latent_h, latent_w, local_window_update, gaussians_concat, test_window_size, v, window_end, window_start) # TODO Naama: shouldn't it be before rearranging? # multiple gaussian heads to predict multiple gaussians with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, dtype=torch.bfloat16): if self.cfg.gaussian_head_multiple > 1: num_additional_heads = self.cfg.gaussian_head_multiple - 1 delta_gaussian_list = [delta_gaussians] # list of [B, N, C] for i in range(num_additional_heads): curr_delta = self.update_head_list[i](state) curr_delta = rearrange(curr_delta, "(b n) c -> b n c", b=b) delta_gaussian_list.append(curr_delta) delta_gaussians = torch.cat(delta_gaussian_list, dim=1) # [B, K*N, C] # Experimental overide deltas if self.cfg.experimental_run: self.experimental_update_deltas(delta_gaussians, gaussian_grads, normalized_grads) # Split delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs, init_repeat = ( self.split_delta_gaussians(delta_gaussians) ) # Apply lr delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs = self.scale_deltas_with_lr( t, delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs ) # Linear combination with adam normalized gradients if self.cfg.delta_adam_combine_step > 0 and normalized_grads is not None: assert t <= T if t > self.cfg.delta_adam_combine_step: alpha = 0.0 beta = 1 - ((t - self.cfg.delta_adam_combine_step) / (T - self.cfg.delta_adam_combine_step)) ** alpha # Linear combination with adam normalized gradients # Use the inverse of the normalized gradients # TODO Naama: hard coded lr # means delta_means = beta * delta_means + (1 - beta) * -normalized_grads[ ..., self.param_slices["means"]] * 1.6e-4 # scales delta_scales = beta * delta_scales + (1 - beta) * -normalized_grads[ ..., self.param_slices["scales"]] * 5e-3 # rotations delta_rotations = beta * delta_rotations + (1 - beta) * -normalized_grads[ ..., self.param_slices["quats"]] * 1e-3 # opacities delta_opacities = beta * delta_opacities + (1 - beta) * -normalized_grads[ ..., self.param_slices["opacities"]] * 5e-2 # sh0 - use view instead of rearrange for speed delta_shs_bgdc = delta_shs.view(delta_shs.shape[0], delta_shs.shape[1], 3, -1) # [b, g, 3, c] delta_sh0 = delta_shs_bgdc[..., 0] # [b, g, 3] delta_shN = delta_shs_bgdc[..., 1:] # [b, g, 3, d-1] delta_shN = delta_shN.flatten(-2) # [b, g, 3*(d-1)] - faster than rearrange new_delta_sh0 = beta * delta_sh0 + (1 - beta) * -normalized_grads[ ..., self.param_slices["sh0"]] * 2.5e-3 new_delta_shN = beta * delta_shN + (1 - beta) * -normalized_grads[ ..., self.param_slices["shN"]] * 1.25e-4 new_delta_shN = new_delta_shN.view(new_delta_shN.shape[0], new_delta_shN.shape[1], 3, -1) # [b, g, 3, d-1] delta_shs[..., ::self.cfg.sh_d] = new_delta_sh0 # shN for i in range(1, self.cfg.sh_d): delta_shs[..., i::self.cfg.sh_d] = new_delta_shN[..., i - 1] return delta_means, delta_opacities, delta_rotations, delta_scales, delta_shs, init_repeat, delta_gaussians def handle_zero_grad_gaussians(self, context, context_render_output, gaussian_grads, gaussian_grads_raw, gaussians, grad_sign, init_state, input_signal, means2d_grads, meta_for_adc, optimizer_input, state): # Compute a mask for gaussian that did not contribute to any pixel of context views # Their gradients are strictly zero. # We don't want to prune them, as they might be relevant in other views (in dense views). if self.cfg.prune_invisible_gaussians: gaussian_grads, gaussians, grad_sign, input_signal, state = self.prune_invisible_gaussians(context, context_render_output, gaussian_grads, gaussian_grads_raw, gaussians, grad_sign, input_signal, means2d_grads, meta_for_adc, optimizer_input, state) else: assert not self.cfg.local_prune_zero_radii assert not self.cfg.local_prune_low_weights assert gaussian_grads.shape[0] == 1, "Batch size > 1 not supported with mask" # radii_mask = (context_render_output.radii != 0).all(1).all(-1) # [B, G] # valid_mask = valid_mask & radii_mask # only consider gaussians with non-zero radius as valid # radii = context_render_output.radii # [B, V, G, 2] # # # XOR on radii last dimension to find gaussians that have zero radius in only one dimension # assert ((radii[..., 0] == 0) ^ (radii[..., 1] == 0)).sum() == 0 # [B, V, G] # # # Check that all zero radius gaussians are in the zero gradient mask (but not necessarily the opposite) # zero_radius_mask = (radii == 0).any(1).any(-1) # [B, G] # zero_grad_mask = ~valid_mask # [B, G] # zero_radius_cnt = zero_radius_mask.sum() # zero_grad_of_zero_radii_cnt = zero_grad_mask[zero_radius_mask].sum() # assert zero_grad_of_zero_radii_cnt == zero_radius_cnt, (f"All zero radius gaussians should have zero " # f"gradients. Found {zero_radius_cnt} zero radius gaussians, but only {zero_grad_of_zero_radii_cnt} of " # f"them have zero gradients.") # print(f"Found {zero_grad_of_zero_radii_cnt} / {zero_radius_cnt} zero radius gaussians with zero gradients.") # Contribution of zero gradient gaussians # gaussian_grads_zero_radii = gaussian_grads[zero_radius_mask] # [G_zero_radius, D] # assert gaussian_grads_zero_radii.abs().sum() == 0, "Gaussians with zero radius should have zero gradients." # radii of zero gradient gaussians # radii_zero_grad = radii[:, :, zero_grad_mask[0]] # [G_zero_grad, V, 2] # zero_grad_radii_cont = radii_zero_grad.sum() # Compute [G] mask without materializing [B,G,D] bool # any() on floats treats nonzero as True valid_g = gaussian_grads[0].any(dim=-1) # [G] bool sel = None # if everything is valid, skip all slicing work if not valid_g.all(): sel = valid_g.nonzero(as_tuple=True)[0] # [G_valid] input_signal = input_signal[:, sel, :] # [B, G_valid, C] gaussian_grads = gaussian_grads[:, sel, :] # [B, G_valid, D] if gaussian_grads_raw is not None: gaussian_grads_raw = gaussian_grads_raw[:, sel, :] if grad_sign is not None: grad_sign = grad_sign[:, sel, :] state = state[sel, :] # [G_valid, C] init_state = init_state[sel, :] # [G_valid, C] valid_mask = valid_g.unsqueeze(0) # [1, G] gaussians.sel = sel if self.cfg.input_gradient_normalize_type == "adam": self.update_input_norm.sel = sel return gaussian_grads, gaussian_grads_raw, gaussians, grad_sign, init_state, input_signal, state def prune_invisible_gaussians(self, context, context_render_output, gaussian_grads, gaussian_grads_raw, gaussians, grad_sign, input_signal, means2d_grads, meta_for_adc, optimizer_input, state): # Get visible gaussians mask, based on the last rendering with torch.no_grad(): visible_mask = self.get_visible_gaussian_mask(gaussian_grads, gaussians, context_render_output.visibility_filter, context) # [B, N, 1] if visible_mask is None: return gaussian_grads, gaussians, grad_sign, input_signal, state assert visible_mask.shape[0] == 1 visible_mask = visible_mask[0, :, 0] # [N], squeeze batch and last dim # Apply mask gaussians = gaussians[:, visible_mask] state = state[visible_mask] input_signal = input_signal[:, visible_mask] # [B, N, C] if gaussian_grads is not None: gaussian_grads = gaussian_grads[:, visible_mask] # [B, N, C] if gaussian_grads_raw is not None: gaussian_grads_raw = gaussian_grads_raw[:, visible_mask] # [B, N, C] if grad_sign is not None: grad_sign = grad_sign[:, visible_mask] # [B, N, C] meta_for_adc["visibility_filter"] = context_render_output.visibility_filter[:, :, visible_mask] meta_for_adc["radii"] = context_render_output.radii[:, :, visible_mask] if means2d_grads is not None: meta_for_adc["means_2d_grads"] = means2d_grads[:, :, visible_mask] if self.cfg.input_gradient_normalize and self.cfg.input_gradient_normalize_type == "adam": if not self.update_input_norm.is_reset(): prune_mask = ~visible_mask self.update_input_norm.prune(prune_mask) # the prune fn will invert the mask again if self.cfg.any_adc: optimizer_input.prev_output.state.adc_state.external_pruning(visible_mask) return gaussian_grads, gaussians, grad_sign, input_signal, state def deactivate_updates(self, subset, gaussians, radii_vis_mask, deltas, gaussian_grads): """ Deactivate updates for gaussians that are not visible in any view """ visible_mask = self.get_visible_gaussian_mask(gaussian_grads, gaussians, radii_vis_mask, subset) deltas = deltas * visible_mask # [B, N, C] return deltas def get_visible_gaussian_mask(self, gaussian_grads, gaussians, radii_vis_mask, subset): """ Get mask for gaussians that are visible in at least one view. We calculate two criteria: 1. Whether the projected 2d radius is visible in at least one view. 2. Whether the gaussian has a non-zero weight contribution to the rendering. If neither pruning criterion is enabled, returns None. Args: gaussian_grads: [B, N, C] or None gaussians: Gaussians object radii_vis_mask: [B, V, N], bool subset: dict, context or target """ # If no pruning criteria are active, return None if not (self.cfg.local_prune_zero_radii or self.cfg.local_prune_low_weights): return None b, v, n = radii_vis_mask.shape # Criterion 1: Projected radius visibility if self.cfg.local_prune_zero_radii: radii_vis_mask = radii_vis_mask.any(dim=1).unsqueeze(-1) # [B, N, 1] else: radii_vis_mask = torch.ones((b, n, 1), dtype=torch.bool, device=radii_vis_mask.device) # Criterion 2: Weight contribution visibility if self.cfg.local_prune_low_weights: threshold = self.cfg.local_prune_low_weights_thresh weight_vis_contribution, _ = get_visibility_contribution_from_gaussian_obj(subset, gaussians) # [N] weight_cont_mask = (weight_vis_contribution > threshold).view(1, -1, 1) else: weight_cont_mask = torch.ones((b, n, 1), dtype=torch.bool, device=radii_vis_mask.device) visible_mask = radii_vis_mask & weight_cont_mask # [B, N, 1] return visible_mask def experimental_inplace_update_delta(self, deltas, grads, normalized_grads, cfg_attr): # Slicing of the gradients vector per parameter param_num = grads.shape[-1] assert param_num == 11 + self.cfg.sh_d * 3 param_slices = self.param_slices update = getattr(self.cfg.experimental_update, cfg_attr) if update: # Update this parameter use_norm_grad = getattr(self.cfg.experimental_use_norm_grads, cfg_attr) use_grad = self.cfg.experimental_use_grads and not use_norm_grad use_resplat = not use_grad and not use_norm_grad assert not (use_grad and use_norm_grad) if use_grad: # Use the inverse of the gradients # TODO Naama: hard coded learning rate for SGD deltas[..., param_slices[cfg_attr]] = -(grads[..., param_slices[cfg_attr]]).to(deltas.dtype) * 30 elif use_norm_grad: # Use the inverse of the normalized gradients updated_delta = -normalized_grads[..., param_slices[cfg_attr]] * getattr(self.cfg.experimental_lr, cfg_attr) deltas[..., param_slices[cfg_attr]] = updated_delta.to(deltas.dtype) else: # Use the network prediction (already negated before) pass else: # Do not update this parameter deltas[..., param_slices[cfg_attr]] = 0 def experimental_update_deltas(self, deltas, grads, normalized_grads): # Verify that at least one parameter is actually using norm_grads or grads override any_norm_grad = any( getattr(self.cfg.experimental_use_norm_grads, p) for p in self.cfg.experimental_update.param_names) any_grad = self.cfg.experimental_use_grads any_override = any_norm_grad or any_grad assert any_override, ( "experimental_run=true but no parameter has use_norm_grads or use_grads enabled. " "Check that experimental_use_norm_grads._base=true (it gates all other fields via property)." ) if any_norm_grad: assert normalized_grads is not None, ( "experimental_use_norm_grads is enabled but normalized_grads is None. " "Ensure input_gradient=true and input_gradient_normalize=true." ) for p in self.cfg.experimental_update.param_names: self.experimental_inplace_update_delta(deltas, grads, normalized_grads, p) def scale_deltas_with_lr(self, t, delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs): # Scale deltas with learning rates delta_means = delta_means * self.scheduler.get_lr(t, "means") delta_scales = delta_scales * self.scheduler.get_lr(t, "scales") if delta_rotations is not None: delta_rotations = delta_rotations * self.scheduler.get_lr(t, "rotations") delta_opacities = delta_opacities * self.scheduler.get_lr(t, "opacities") # Use view instead of rearrange for speed delta_shs = delta_shs.view(delta_shs.shape[0], delta_shs.shape[1], 3, -1) # [b, g, 3, c] delta_sh0 = delta_shs[..., 0] # [B, N, C] delta_shN = delta_shs[..., 1:] delta_sh0 = delta_sh0 * self.scheduler.get_lr(t, "sh0") delta_shN = delta_shN * self.scheduler.get_lr(t, "shN") delta_shs = torch.cat((delta_sh0.unsqueeze(-1), delta_shN), dim=-1) delta_shs = delta_shs.flatten(-2) # [b, g, d*c] - faster than rearrange return delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs def append_to_input_signal(self, b, context, context_render, tmp_input_signal, v): if self.cfg.input_alpha: render_alpha = rearrange(context_render.accumulated_alpha, "b v h w -> (b v) () h w") render_alpha = F.pixel_unshuffle(render_alpha, downscale_factor=self.cfg.latent_downsample) render_alpha = rearrange(render_alpha, "(b v) c h w -> (b v h w) c", b=b, v=v) tmp_input_signal = torch.cat((tmp_input_signal, render_alpha), dim=-1) if self.cfg.input_depth: render_depth = rearrange(context_render.depth, "b v h w -> (b v) () h w") render_depth = F.pixel_unshuffle(render_depth, downscale_factor=self.cfg.latent_downsample) render_depth = rearrange(render_depth, "(b v) c h w -> (b v h w) c", b=b, v=v) tmp_input_signal = torch.cat((tmp_input_signal, render_depth), dim=-1) if self.cfg.input_depth_smooth_error: disp = 1. / context_render.depth.clamp(min=1e-3, max=1000.) # [B, V, H, W] disp = rearrange(disp, "b v h w -> (b v) () h w") mean_disp = disp.mean(2, True).mean(3, True) norm_disp = disp / (mean_disp + 1e-7) tmp_imgs = rearrange(context["image"], "b v c h w -> (b v) c h w") depth_smooth_error = get_smooth_loss(norm_disp, tmp_imgs, no_mean=True) depth_smooth_error = F.pixel_unshuffle(depth_smooth_error, downscale_factor=self.cfg.latent_downsample) depth_smooth_error = rearrange(depth_smooth_error, "(b v) c h w -> (b v h w) c", b=b, v=v) tmp_input_signal = torch.cat((tmp_input_signal, depth_smooth_error), dim=-1) return tmp_input_signal def repeat_gaussians(self, prev_means, prev_opacities_raw, prev_rotations_unnorm, prev_scales, prev_shs): if self.cfg.gaussian_head_multiple > 1: # predict multiple gaussians for each point prev_means = prev_means.repeat(1, self.cfg.gaussian_head_multiple, 1) prev_scales = prev_scales.repeat(1, self.cfg.gaussian_head_multiple, 1) prev_rotations_unnorm = prev_rotations_unnorm.repeat(1, self.cfg.gaussian_head_multiple, 1) prev_opacities_raw = prev_opacities_raw.repeat(1, self.cfg.gaussian_head_multiple, 1) / self.cfg.gaussian_head_multiple # smaller opacities, important prev_shs = prev_shs.repeat(1, self.cfg.gaussian_head_multiple, 1) # NOTE: only repeat at the first iteration refine_repeat = self.cfg.refine_gaussian_multiple if refine_repeat > 1: # predict multiple gaussians for each point prev_means = prev_means.repeat(1, refine_repeat, 1) prev_scales = prev_scales.repeat(1, refine_repeat, 1) prev_rotations_unnorm = prev_rotations_unnorm.repeat(1, refine_repeat, 1) prev_opacities_raw = prev_opacities_raw.repeat(1, refine_repeat, 1) # smaller opacities, important prev_shs = prev_shs.repeat(1, refine_repeat, 1) return prev_means, prev_opacities_raw, prev_rotations_unnorm, prev_scales, prev_shs def split_delta_gaussians(self, delta_gaussians): delta_rotations = None if self.cfg.init_gaussian_multiple > 1 and not self.cfg.refine_same_num_points: init_repeat = self.cfg.init_gaussian_multiple else: init_repeat = 1 p = get_gaussian_param_sizes(self.cfg.sh_d) if self.cfg.refine_sh_only: delta_shs = delta_gaussians delta_means = delta_scales = delta_opacities = 0. elif self.cfg.no_refine_rotation: delta_means, delta_scales, delta_opacities, delta_shs = delta_gaussians.split( (p["means"] * init_repeat, p["scales"] * init_repeat, p["opacities"] * init_repeat, p["shs"] * init_repeat), dim=-1 ) elif self.cfg.no_refine_mean: delta_scales, delta_rotations, delta_opacities, delta_shs = delta_gaussians.split( (p["scales"] * init_repeat, p["quats"] * init_repeat, p["opacities"] * init_repeat, p["shs"] * init_repeat), dim=-1 ) delta_means = torch.zeros_like(delta_scales) else: delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs = delta_gaussians.split( (p["means"] * init_repeat, p["scales"] * init_repeat, p["quats"] * init_repeat, p["opacities"] * init_repeat, p["shs"] * init_repeat), dim=-1 ) if ( self.cfg.refine_gaussian_multiple > 1 or self.cfg.init_gaussian_multiple > 1) and not self.cfg.refine_same_num_points: delta_means = rearrange(delta_means, "b n (c k) -> b (n k) c", k=init_repeat) delta_scales = rearrange(delta_scales, "b n (c k) -> b (n k) c", k=init_repeat) delta_rotations = rearrange(delta_rotations, "b n (c k) -> b (n k) c", k=init_repeat) delta_opacities = rearrange(delta_opacities, "b n (c k) -> b (n k) c", k=init_repeat) delta_shs = rearrange(delta_shs, "b n (c k) -> b (n k) c", k=init_repeat) return delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs, init_repeat def rearrange_delta_gaussians(self, b, delta_gaussians, delta_gaussians_raw, latent_h, latent_w, local_window_update, prev_gaussians_concat, test_window_size, v, window_end, window_start): # [BV, C] # update gaussian parameters delta_gaussians = rearrange(delta_gaussians, "(b n) c -> b n c", b=b) delta_gaussians_raw = rearrange(delta_gaussians_raw, "(b n) c -> b n c", b=b) if local_window_update and not self.cfg.local_gaussian_render: # zero padding for non-updated gaussians # curr_v = self.cfg.update_window_size if self.training else test_window_size curr_v = test_window_size tmp_delta = rearrange(delta_gaussians, "b (v h w) c -> b v h w c", b=b, v=curr_v, h=latent_h, w=latent_w) all_delta = [] # padding if window_start > 0: tmp_size = rearrange(prev_gaussians_concat, "b (v h w) c -> b v h w c", b=b, v=v, h=latent_h, w=latent_w) pad_left = torch.zeros_like(tmp_size[:, :window_start, :, :, :], requires_grad=False) all_delta.append(pad_left) all_delta.append(tmp_delta) if window_end < v: tmp_size = rearrange(prev_gaussians_concat, "b (v h w) c -> b v h w c", b=b, v=v, h=latent_h, w=latent_w) pad_right = torch.zeros_like(tmp_size[:, window_end:, :, :, :], requires_grad=False) all_delta.append(pad_right) tmp_delta = torch.cat(all_delta, dim=1) # [B, V, H, W, C] delta_gaussians = rearrange(tmp_delta, "b v h w c -> b (v h w) c") # [B, N, C] return delta_gaussians, delta_gaussians_raw def update_gaussians_params(self, delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs, means, scales, rotations_unnorm, opacities_raw, shs, repeat): means = self.update_means(delta_means, means) # clamp the scale scales = self.update_scales(delta_scales, scales, repeat) if self.cfg.opt_scales_before_act: scales = scales.exp() if not self.cfg.no_refine_rotation: rotations, rotations_unnorm = self.update_rotations(delta_rotations, rotations_unnorm) else: rotations = F.normalize(rotations_unnorm, dim=-1) # compute covariance covariances = build_covariance(scales, rotations) # ([1, VHW, 3, 3]) opacities_raw = self.update_opacities(delta_opacities, opacities_raw, repeat) shs = self.update_shs(delta_shs, shs) return covariances, means, scales, rotations, rotations_unnorm, opacities_raw, shs def update_shs(self, delta_shs, prev_shs): shs = prev_shs + delta_shs # [B, N, 3*sh_d] if self.cfg.clamp_shs_soft: assert self.cfg.clamp_min_shs == -self.cfg.clamp_max_shs, "For soft clamp, min and max should be symmetric around 0" shs = torch.tanh(shs / self.cfg.clamp_max_shs) * self.cfg.clamp_max_shs else: shs = shs.clamp(min=self.cfg.clamp_min_shs, max=self.cfg.clamp_max_shs) return shs def update_opacities(self, delta_opacities, prev_opacities_raw, repeat): # update init opacities when predicting multiple gaussians if repeat > 1 and not self.cfg.multi_gaussian_scale_smaller and (self.cfg.init_gaussian_multiple == 1): # Given y = sigmoid(x), to get new x' such that sigmoid(x') = y / K: # The formula is: x' = x + log((1 - y) / (K - y)) # This adjusts x so that the sigmoid output is scaled down by a factor of K tmp_sigmoid = prev_opacities_raw.sigmoid() prev_opacities_raw = prev_opacities_raw + torch.log( (1 - tmp_sigmoid) / (repeat - tmp_sigmoid)) + delta_opacities else: prev_opacities_raw = prev_opacities_raw + delta_opacities # prev_opacities_raw = prev_opacities_raw.clamp(min=-5, max=5) return prev_opacities_raw @staticmethod def update_rotations(delta_rotations, prev_rotations_unnorm): assert delta_rotations is not None prev_rotations_unnorm = prev_rotations_unnorm + delta_rotations # normazlie prev_rotations = prev_rotations_unnorm / (prev_rotations_unnorm.norm(dim=-1, keepdim=True) + 1e-8) return prev_rotations, prev_rotations_unnorm def update_scales(self, delta_scales, prev_scales, repeat): if repeat > 1 and self.cfg.multi_gaussian_scale_smaller: # smaller initial scales new_scales = (prev_scales / repeat + delta_scales).clamp(min=self.cfg.gaussian_adapter.clamp_min_scale) else: new_scales = (prev_scales + delta_scales) if self.cfg.opt_scales_before_act: min_scale = self.cfg.clamp_min_raw_scales max_scale = self.cfg.clamp_max_raw_scales else: min_scale = self.cfg.clamp_min_scale max_scale = self.cfg.clamp_refine_max_scale new_scales = new_scales.clamp(min=min_scale) new_scales = new_scales.clamp(max=max_scale) return new_scales @staticmethod def update_means(delta_means, prev_means): prev_means = (prev_means + delta_means) return prev_means def _on_scene_start_impl(self, optimizer_input: OptimizerInput) -> None: # Reset the state if isinstance(optimizer_input.prev_output, InitializerOutput): # New scene from_init = True # Reset the optimizer state for a new scene # We cannot just use super().on_scene_start() because we need to process the InitializerOutput in case it # contain conditioning features self.reset_logs() if self.cfg.input_gradient_normalize_type == "adam": self.update_input_norm.reset() nr_gaussians = rearrange(optimizer_input.prev_output.gaussians.means, "b n c -> (b n) c").shape[0] param_num = self.gaussian_param_num self.update_input_norm.initialize(shape=(nr_gaussians, param_num), device=optimizer_input.prev_output.gaussians.means.device) # make sure xyz are contiguous optimizer_input.prev_output.gaussians.means = optimizer_input.prev_output.gaussians.means.contiguous() elif isinstance(optimizer_input.prev_output, OptimizerPreviousOutput): from_init = False if self.cfg.input_gradient_normalize_type == "adam": # Continuing previous optimization from replay buffer self.update_input_norm.update_state(optimizer_input.prev_output.state.adam_state) # TODO Naama: logs are not handled right now for continuing from replay buffer self.reset_logs() else: raise ValueError(f"Unknown prev_output type {type(optimizer_input.prev_output)}") # Preparing the input for a new scene (will handle both new scene and continuing from replay buffer) # Will convert init_output to prev_output internally self.optimizer_preprocessing(optimizer_input, from_init=from_init) # initialize adc state, after converting to prev_output if from_init and self.cfg.any_adc: self.initialize_adc_state(self.cfg, optimizer_input) def reshape_gaussians_to_nc(self, latent_h, latent_w, prev_gaussians_concat, v): if self.cfg.init_gaussian_multiple == 4 and not self.cfg.refine_same_num_points: # gaussians are with more points, reshape tmp_gaussian = rearrange(prev_gaussians_concat, "b (v h x w y) c -> (b v h w) (c x y)", v=v, h=latent_h, w=latent_w, x=2, y=2) elif self.cfg.init_gaussian_multiple == 16 and not self.cfg.refine_same_num_points: tmp_gaussian = rearrange(prev_gaussians_concat, "b (v h x w y) c -> (b v h w) (c x y)", v=v, h=latent_h, w=latent_w, x=4, y=4) else: tmp_gaussian = rearrange(prev_gaussians_concat, "b n c -> (b n) c") return tmp_gaussian def get_point_cloud(self, latent_h, latent_w, local_window_update, prev_means, test_window_size, v): # TODO: when the initial model predicts multiple gaussians, the number of points also increases if self.cfg.init_gaussian_multiple == 4 and not self.cfg.refine_same_num_points: point_cloud = rearrange(prev_means, "b (v h w) c -> b v h w c", v=v, h=latent_h * 2, w=latent_w * 2, ) tmp_batch_size = v * latent_h * latent_w # simply use uniform grid subsample of point cloud to reduce points point_cloud = point_cloud[:, :, ::2, ::2] point_cloud = rearrange(point_cloud, "b v h w c -> (b v h w) c") elif self.cfg.init_gaussian_multiple == 16 and not self.cfg.refine_same_num_points: point_cloud = rearrange(prev_means, "b (v h w) c -> b v h w c", v=v, h=latent_h * 4, w=latent_w * 4, ) tmp_batch_size = v * latent_h * latent_w # simply use uniform grid subsample of point cloud to reduce points point_cloud = point_cloud[:, :, ::4, ::4] point_cloud = rearrange(point_cloud, "b v h w c -> (b v h w) c") else: point_cloud = rearrange(prev_means, "b n c -> (b n) c") if local_window_update: tmp_batch_size = test_window_size * latent_h * latent_w else: tmp_batch_size = prev_means.shape[1] return point_cloud, tmp_batch_size def get_vector_state(self, b, v, n, optimizer_input, from_init): if from_init: # Starting a new scene directly from the initializer # State should not be provided # Create initial state # optimizer_input.prev_output is of type InitializerOutput if optimizer_input.prev_output.features is None or self.cfg.init_state_wo_features: # Creating state without initializer features assert self.cfg.init_state_wo_features with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, dtype=torch.bfloat16): dtype = torch.get_autocast_dtype('cuda') if self.cfg.init_state_type == "constant": state = torch.ones((b, n, self.cfg.state_channels), device=self.device, dtype=dtype) elif self.cfg.init_state_type == "random": state = torch.randn((b, n, self.cfg.state_channels), device=self.device, dtype=dtype) else: raise ValueError(f"Unknown init_state_type {self.cfg.init_state_type}") state = state * self.cfg.init_state_scale else: # Calculating state from initializer features state = self.get_state_from_condition_features(b, optimizer_input.prev_output.features, v) # [B, N, C] else: # Restarting optimizing a scene from a replay buffer state = optimizer_input.prev_output.state.state # TODO Naama: need to understand why rearrange here, perhaps something with pruning state = rearrange(state, "(b n) c -> b n c", b=b) # combine gaussians of all scnes in the batch [B*N, C] state = rearrange(state, "b n c -> (b n) c") # [B*N, C] # Do something with window size _, _, _, h, w = optimizer_input.context["image"].shape # [B, V, C, H, W] local_window_update, test_window_size, window_end, window_start = optimizer_input.additional_info # select initial state if local_window_update and self.cfg.local_gaussian_render: state = rearrange(state, "(b v h w) c -> b v h w c", b=b, v=v, h=h // self.cfg.latent_downsample, w=w // self.cfg.latent_downsample) state = state[:, window_start:window_end, :, :, :] state = rearrange(state, "b v h w c -> (b v h w) c") return state @staticmethod def _align_features(features, latent_h: int, latent_w: int) -> list: """Resize each feature map to (latent_h, latent_w) if needed and return as a list.""" out = [] vals = features.values() if isinstance(features, dict) else features for feat in vals: if feat.shape[-2:] != (latent_h, latent_w): feat = F.interpolate(feat, size=(latent_h, latent_w), mode='bilinear', align_corners=True) out.append(feat) return out def _get_latent_size(self, h: int, w: int) -> tuple[int, int]: """Compute latent (H, W) from image (H, W), accounting for init_gaussian_multiple upsampling.""" latent_h = h // self.cfg.latent_downsample latent_w = w // self.cfg.latent_downsample if self.cfg.init_gaussian_multiple == 4 and self.cfg.refine_same_num_points: latent_h *= 2 latent_w *= 2 elif self.cfg.init_gaussian_multiple == 16 and self.cfg.refine_same_num_points: latent_h *= 4 latent_w *= 4 return latent_h, latent_w def render_input_views_for_error_calc(self, context, local_window_update, prev_gaussians, renderer, window_end, window_start, num_refine, i): _, _, _, h, w = context["image"].shape # [B, V, C, H, W] render_res = (h, w) # Default rendering parameters input_info = context start = None end = None cfg = self.cfg # Use only first N views if cfg.input_error_num_views > 0: end = cfg.input_error_num_views # Local window update logic elif local_window_update: if i >= num_refine - 1: return None # Skip rendering on the last iteration start = window_start end = window_end # Final unified rendering call return renderer.forward_batch_subset( prev_gaussians, input_info, render_res, start=start, end=end, return_radii=False ) def get_state_from_condition_features(self, b, condition_features, v): with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, dtype=torch.bfloat16): if not self.cfg.pt_update_amp and condition_features.dtype == torch.bfloat16: condition_features = condition_features.float() state = self.update_proj(condition_features.detach()) # [B, C, H, W] if self.cfg.init_gaussian_multiple == 4 and self.cfg.refine_same_num_points: state = F.interpolate(state, scale_factor=2, mode='bilinear', align_corners=True) elif self.cfg.init_gaussian_multiple == 16 and self.cfg.refine_same_num_points: state = F.interpolate(state, scale_factor=4, mode='bilinear', align_corners=True) else: pass # Convert to vector of Gaussians per batch [B, N, C] state = rearrange(state, "(b v) c h w -> b (v h w) c", b=b, v=v) # N = v * h * w return state def get_window_size(self, v): test_window_size = None if self.cfg.update_window_size > 0: local_window_update = True # if self.training: # window_start = random.randint(0, v - self.cfg.update_window_size) # window_end = window_start + self.cfg.update_window_size # else: # fixed window at test time, uniformly move from left to right # TODO: loop closure, connect left and right if self.training: test_window_size = self.cfg.update_window_size window_start = random.randint(0, test_window_size) window_end = window_start + test_window_size if window_end == v: # restart window_start = random.randint(0, test_window_size) window_end = window_start + test_window_size else: # at least do a full pass of all input views # test_window_size = int(np.ceil(v / self.cfg.num_refine)) test_window_size = self.cfg.update_window_size window_start = 0 window_end = window_start + test_window_size else: local_window_update = False window_start = 0 window_end = v return local_window_update, test_window_size, window_end, window_start def prepare_update_input(self, b, i, init_state, input_signal, latent_h, latent_w, local_window_update, point_cloud, tmp_gaussian, state, v, window_end, window_start): if self.cfg.replace_init_state: state = init_state if self.cfg.no_render_error: update_input = torch.cat((tmp_gaussian, state), dim=-1) else: if local_window_update and not self.cfg.local_gaussian_render: # select local window tmp_gaussian = rearrange(tmp_gaussian, "(b v h w) c -> b v h w c", b=b, v=v, h=latent_h, w=latent_w) tmp_gaussian = tmp_gaussian[:, window_start:window_end, :, :, :] tmp_gaussian = rearrange(tmp_gaussian, "b v h w c -> (b v h w) c") if i == 0: state = rearrange(state, "(b v h w) c -> b v h w c", b=b, v=v, h=latent_h, w=latent_w) state = state[:, window_start:window_end, :, :, :] state = rearrange(state, "b v h w c -> (b v h w) c") # local point cloud point_cloud = rearrange(point_cloud, "(b v h w) c -> b v h w c", b=b, v=v, h=latent_h, w=latent_w) point_cloud = point_cloud[:, window_start:window_end, :, :, :] point_cloud = rearrange(point_cloud, "b v h w c -> (b v h w) c") update_input = torch.cat((tmp_gaussian, state, input_signal), dim=-1) if self.cfg.concat_init_state: update_input = torch.cat((update_input, init_state), dim=-1) return point_cloud, tmp_gaussian, state, update_input def apply_update_module(self, b, latent_h, latent_w, offset, point_cloud, update_input, v, state, iter): def recurrent_chunk(update_input, point_cloud, offset): pxo = self.update_module[0]([point_cloud, update_input, offset]) state = self.update_module[1](pxo, iter=iter, b=b, v=v, h=latent_h, w=latent_w) return state if self.cfg.use_checkpointing or self.cfg.recurrent_use_checkpointing: new_state = torch.utils.checkpoint.checkpoint( recurrent_chunk, update_input, point_cloud, offset, use_reentrant=False, ) else: new_state = recurrent_chunk(update_input, point_cloud, offset) if self.cfg.residual_state: new_state = new_state + state return new_state def apply_delta_gaussian_head(self, b, context, init_state, state, v): if self.cfg.update_head_concat_img: # pixel unshuffle image img_unshuffle = rearrange(context["image"], "b v c h w -> (b v) c h w") img_unshuffle = F.pixel_unshuffle(img_unshuffle, downscale_factor=self.cfg.latent_downsample) img_unshuffle = rearrange(img_unshuffle, "(b v) c h w -> (b v h w) c", b=b, v=v) head_input = torch.cat((state, img_unshuffle), dim=-1) else: if self.cfg.refine_residual_init_state: head_input = state + init_state else: head_input = state if self.cfg.update_head_per_param_heads: delta_gaussians = self._apply_per_param_heads(head_input) else: delta_gaussians = self.update_head(head_input) return delta_gaussians def _apply_per_param_heads(self, head_input): """Run per-parameter-group heads and concatenate results. Each head outputs [N, dim+1] where the last dim is the scalar scale. Per-group normalize + scale is applied independently. """ deltas = [] for name, dim in self._per_param_group_dims.items(): raw = self.update_head[name](head_input) # [N, dim+1] scale = self.scale_act(raw[:, -1:]) # [N, 1] delta = raw[:, :-1] # [N, dim] if dim > 1: delta = delta / (delta.norm(p=2, dim=-1, keepdim=True) + 1e-8) * scale else: # 1-d (e.g. opacities): no direction to normalize, just scale magnitude delta = delta * scale deltas.append(delta) return torch.cat(deltas, dim=-1) def apply_global_attn(self, b, h, input_signal, latent_h, latent_w, local_window_update, test_window_size, v, w): # TODO Naama: do we need local_window? assert self.cfg.input_error_resnet_feature assert self.cfg.input_error if self.cfg.input_gradient and self.cfg.input_error: input_render_error = input_signal[..., :self.error_features_channels] else: input_render_error = input_signal with torch.amp.autocast(device_type='cuda', enabled=self.cfg.use_amp, dtype=torch.bfloat16): for blk in self.update_error_attn: if self.cfg.refine_same_num_points: # no downsample, for re10k 256 input_render_error = blk(input_render_error, v=v, h=h, w=w) else: curr_v = test_window_size if local_window_update else v input_render_error = blk(input_render_error, v=curr_v, h=latent_h, w=latent_w) if self.cfg.input_gradient and self.cfg.input_error: input_signal[..., :self.error_features_channels] = input_render_error else: input_signal = input_render_error return input_signal def prepare_input_signal(self, context, i, gaussians, local_window_update, renderer, window_end, window_start, num_refine): # TODO Naama: review # make sure at least one of the following is True assert self.cfg.input_gradient or self.cfg.input_error input_view_features = None input_signal = None input_render_error = None context_render_output = None gaussian_grads_raw = None gaussian_grads = None grad_sign = None means2d_grads = None # calculate input gradients if self.cfg.input_gradient: gaussian_grads_raw, gaussian_grads, grad_sign, context_render_output, means2d_grads = ( self._calc_input_gradients(context, gaussians, renderer) ) input_signal = gaussian_grads_raw # When using gradients, context_render_output cannot be used for the meta-training, # because there was already one backward pass. # So we render again if in training. if context_render_output is None or self.training: context_render_output = self.render_input_views_for_error_calc(context, local_window_update, gaussians, renderer, window_end, window_start, num_refine, i) # calculate input rendering errors if self.cfg.input_error: if means2d_grads is None and self.cfg.need_2d_grads: raise NotImplementedError("Calculating 2dgrad for ADC is not implemented for error input alone") input_render_error = self._calc_input_errors(context, i, context_render_output, input_view_features, local_window_update, gaussians.means.detach(), window_end, window_start) input_signal = input_render_error if self.cfg.input_gradient and self.cfg.input_error: # Concatenate both gradients and errors input_signal = torch.cat((input_render_error, gaussian_grads), dim=-1) return input_signal, gaussian_grads_raw, gaussian_grads, grad_sign, context_render_output, means2d_grads def get_data_shim(self) -> DataShim: def data_shim(batch: BatchedExample) -> BatchedExample: batch = apply_patch_shim( batch, patch_size=self.cfg.shim_patch_size * self.cfg.downscale_factor, ) return batch return data_shim @property def sampler(self): return None def debug_reprojection_error(self, means, debug_dict, context, i, latent_h, latent_w): # Prepare means (remove singleton dim) means = rearrange(means, "b (v h w) c -> b v (h w) c", h=latent_h, w=latent_w) # [B, V, H*W, 3] # Expand extrinsics/intrinsics for broadcast extrinsics = context["extrinsics"].unsqueeze(2) # [B, V, 1, 4, 4] intrinsics = context["intrinsics"].unsqueeze(2) # [B, V, 1, 3, 3] # Project xy_ray_reconstructed, in_front = project(means, extrinsics, intrinsics) # [B, V, H*W, 2], [B, V, H*W] xy_ray, _ = sample_image_grid((latent_h, latent_w), xy_ray_reconstructed.device) # [B, V, H*W, 1, 2] xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy") xy_ray = xy_ray.squeeze(-2) # [B, V, H*W, 2] xy_ray_unnorm = xy_ray.clone() xy_ray_unnorm[..., 0] *= latent_w xy_ray_unnorm[..., 1] *= latent_h xy_ray_reconstructed_unnorm = xy_ray_reconstructed.clone() xy_ray_reconstructed_unnorm[..., 0] *= latent_w xy_ray_reconstructed_unnorm[..., 1] *= latent_h reprojection_error = (xy_ray_unnorm - xy_ray_reconstructed_unnorm).abs() if debug_dict["reprojection_error"] is None: # First iteration, first scene debug_dict["reprojection_error"] = [[]] elif i == 0: # New iteration, new scene debug_dict["reprojection_error"].append([]) debug_dict["reprojection_error"][-1].append(reprojection_error.detach().cpu()) # import matplotlib.pyplot as plt # plt.figure(figsize=(12, 6)) # plt.hist(reprojection_error.flatten().detach().cpu(), bins=100, range=(0, 10)) # plt.title(f"Reprojection Error - step {i}") # plt.xlabel("Error (pixels)") # plt.ylabel("Frequency") # plt.show() def _calc_input_errors(self, context, i, input_render, input_view_features, local_window_update, prev_means, window_end, window_start): b, v, _, h, w = context["image"].shape # Detach the last rendered object input_rgb = input_render.color.detach() # compute input view rendering error if self.cfg.input_error_resnet_feature: input0 = rearrange(input_rgb, "b v c h w -> (b v) c h w") if self.cfg.input_error_num_views > 0: gt_input = context["image"][:, :self.cfg.input_error_num_views, :, :, :] elif local_window_update: gt_input = context["image"][:, window_start:window_end, :, :, :] else: gt_input = context["image"] input1 = rearrange(gt_input, "b v c h w -> (b v) c h w") transform = _IMAGENET_NORM if input_view_features is None: assert i == 0 # first time: extract all features concat = torch.cat((input0, input1), dim=0) input_tensor = transform(concat) with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, dtype=torch.bfloat16): # Extract features with torch.no_grad(): features = self.update_feature(input_tensor) # align to the latent resolution latent_h, latent_w = self._get_latent_size(h, w) all_features = torch.cat(self._align_features(features, latent_h, latent_w), dim=1) render_view_features = all_features[:input0.shape[0]] input_view_features = all_features[input0.shape[0]:] else: # only extract render view features with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, dtype=torch.bfloat16): # Extract features with torch.no_grad(): features = self.update_feature(transform(input0)) # align to the latent resolution latent_h, latent_w = self._get_latent_size(h, w) render_view_features = torch.cat(self._align_features(features, latent_h, latent_w), dim=1) corr = render_view_features - input_view_features if self.cfg.input_error_num_views > 0: # pad to V views curr_v = self.cfg.input_error_num_views indices = torch.arange(v) * curr_v // v corr = rearrange(corr, "(b v) c h w -> b v c h w", b=b) corr = corr[torch.arange(b).unsqueeze(1), indices, :, :, :] input_render_error = rearrange(corr, "b v c h w -> b (v h w) c") else: input_render_error = rearrange(corr, "(b v) c h w -> b (v h w) c", b=b) else: input_render_error = (input_render.color - context["image"]).abs() # [B, V, 3, H, W] input_render_error = rearrange(input_render_error, "b v c h w -> (b v) c h w") if self.cfg.input_error_rgb_no_shuffle: # bilinear input_render_error = F.interpolate(input_render_error, scale_factor=1. / self.cfg.latent_downsample, mode='bilinear', align_corners=True) else: # pixel unshuffle # TODO: when fps is used, how to reshape the render error to make sure its somehow pixel aligned to the gaussians input_render_error = F.pixel_unshuffle(input_render_error, downscale_factor=self.cfg.latent_downsample) input_render_error = rearrange(input_render_error, "(b v) c h w -> b (v h w) c", b=b, v=v) # [B, N, C] # include both feature error and image error if self.cfg.input_error_add_rgb_feature: rgb_render_error = input_render.color - context["image"] rgb_render_error = rearrange(rgb_render_error, "b v c h w -> (b v) c h w") if self.cfg.input_error_rgb_no_shuffle: # bilinear rgb_render_error = F.interpolate(rgb_render_error, scale_factor=1. / self.cfg.latent_downsample, mode='bilinear', align_corners=True) else: # pixel unshuffle # TODO: when fps is used, how to reshape the render error to make sure its somehow pixel aligned to the gaussians rgb_render_error = F.pixel_unshuffle(rgb_render_error, downscale_factor=self.cfg.latent_downsample) rgb_render_error = rearrange(rgb_render_error, "(b v) c h w -> b (v h w) c", b=b, v=v) # [B, N, C] rgb_render_error = self.update_rgb_error_proj(rgb_render_error) input_render_error = input_render_error + rgb_render_error return input_render_error def get_input_error_feature_extractor(self): update_feature = None # resnet feature if self.cfg.input_error_resnet_feature: update_feature = ResNetFeatureWarpper( shallow_resnet_feature=self.cfg.input_error_shallow_resnet_feature) if self.cfg.input_error_no_freeze_resnet_feature: # remove unused layers # NOTE: layer 3 is also not used update_feature.layer3 = nn.Identity() update_feature.train() for params in update_feature.parameters(): params.requires_grad = True else: update_feature.eval() for params in update_feature.parameters(): params.requires_grad = False return update_feature def update_delta_for_gradients_input(self, delta_gaussians, grad_sign, normalized_grad, visibility_scale: Tensor | None = None): if self.cfg.input_gradient: delta_gaussians = delta_gaussians / self.cfg.input_gradient_scale if self.cfg.input_gradient_log: grad_sign = rearrange(grad_sign, "b n c -> (b n) c") # recover log scale for applying the deltas. # For loss calculation the delta should still be in log scale delta_gaussians = grad_sign * (delta_gaussians.exp() - 1e-8) if self.cfg.input_gradient_log_clip_deltas > 0: # clip the delta to avoid too large updates clip_value = self.cfg.input_gradient_log_clip_deltas clip_mask = delta_gaussians.abs() > clip_value delta_gaussians[clip_mask] = delta_gaussians[clip_mask].sign() * clip_value # TODO Naama: move these two, as they are not related to gradients if self.cfg.update_head_scale_mag: out_channels = delta_gaussians.shape[-1] param_num = out_channels / 2 assert param_num.is_integer() param_num = int(param_num) scale = delta_gaussians[:, :param_num] mag = delta_gaussians[:, param_num:] delta_gaussians = scale * 0.01 * torch.exp(mag * 0.01) if self.cfg.update_head_scalar_scale: if self.cfg.update_head_per_param_heads: # Already handled in _apply_per_param_heads — nothing to do here pass elif self.cfg.update_head_per_param_scales: # Feature B: per-group scalar scales num_groups = len(self._per_param_group_dims) scales = delta_gaussians[:, -num_groups:] # [G, num_groups] scales = self.scale_act(scales) deltas = delta_gaussians[:, :-num_groups] # [G, D] normalized_deltas = [] offset = 0 for i, (name, dim) in enumerate(self._per_param_group_dims.items()): group_delta = deltas[:, offset:offset + dim] # [G, dim] group_scale = scales[:, i:i + 1] # [G, 1] if dim > 1: group_delta = group_delta / (group_delta.norm(p=2, dim=-1, keepdim=True) + 1e-8) group_delta = group_delta * group_scale normalized_deltas.append(group_delta) offset += dim delta_gaussians = torch.cat(normalized_deltas, dim=-1) else: # Original global scalar scale scale = delta_gaussians[:, -1:] # [G, 1] scale = self.scale_act(scale) # make sure scale is positive deltas_unnorm = delta_gaussians[:, :-1] # [G, D] deltas_norm = deltas_unnorm / (deltas_unnorm.norm(p=2, dim=1, keepdim=True) + 1e-8) # [G, D] delta_gaussians = deltas_norm * scale if visibility_scale is not None: delta_gaussians = delta_gaussians * visibility_scale if self.cfg.scale_residual_grads: delta_gaussians = delta_gaussians * normalized_grad * self.cfg.gradient_update_scale # 1.0 # To match the default behavior of SGD, Adam, and other optimizers, deltas are negated. # SGD applies the gradients as `x = x - lr * grad`, while resaplt applies them as `x = x + lr * deltas`. delta_gaussians = -delta_gaussians return delta_gaussians def _calc_input_gradients(self, context, gaussians, renderer): assert not self.cfg.input_gradient_same_loss, "input_gradient_same_loss is not implemented" _, v, _, h, w = context["image"].shape with torch.enable_grad(): # Unpack gaussians means, scales, rotations_unnorm, opacities_raw, shs = unpack_gaussians( gaussians, scales_log=self.cfg.opt_scales_before_act, opacities_logit=True, opacities_unsqueeze=True, detach=True, clone=False, requires_grad=True, scales_lims=(self.cfg.clamp_min_scale, self.cfg.clamp_refine_max_scale), raw_opacities_lims=(self.cfg.clamp_min_raw_opacities, self.cfg.clamp_max_raw_opacities) ) # Create temporary Gaussians with same values but requires_grad=True grad_batch_size = self.cfg.input_gradients_chunk_size if grad_batch_size == -1: grad_batch_size = v gaussian_grads = 0 means2d_grads_chunks = [] nr_chunks = math.ceil(v / grad_batch_size) # Pre-compute shapes and config flags outside the loop shs_shape = (shs.shape[0], shs.shape[1], 3, -1) opt_scales_before_act = self.cfg.opt_scales_before_act # Pre-compute normalized rotations once (not in gradient inputs, so no grad needed) with torch.no_grad(): rotations = rotations_unnorm / (rotations_unnorm.norm(dim=-1, keepdim=True) + 1e-8) for chunk_idx, start, stop in chunk_index_iter(v, grad_batch_size): # zero grads means = means.detach().requires_grad_(True) scales = scales.detach().requires_grad_(True) rotations_unnorm = rotations_unnorm.detach().requires_grad_(True) opacities_raw = opacities_raw.detach().requires_grad_(True) shs = shs.detach().requires_grad_(True) # Apply activation to scales if needed (before calculating covariance) scales_act = scales.exp() if opt_scales_before_act else scales tmp_gaussians = Gaussians( means=means, covariances=None, harmonics=shs.view(shs_shape), opacities=torch.sigmoid(opacities_raw.squeeze(-1)), scales=scales_act, rotations=rotations, rotations_unnorm=rotations_unnorm, ) # render input views, calculate inner loss and backprop to get gradients context_render_output = renderer.forward_batch_subset( tmp_gaussians, context, start=start, end=stop, image_shape=(h, w), ) inputs = [means, scales, rotations_unnorm, opacities_raw, shs] if self.cfg.need_2d_grads: assert context_render_output.means2d is not None, "output_renderer.means2d is None" means2d = context_render_output.means2d # [B, V, N, 2] # means2d.retain_grad() # retain grad for means2d grads computation inputs.append(means2d) inner_loss = inner_loss_for_input_gradients( context["image"][:, start:stop], context_render_output, reduction=self.cfg.input_gradient_loss_reduction, with_ssim=self.cfg.input_gradient_with_ssim_loss, ) if self.cfg.opacity_reg_lambda > 0.0: inner_loss = inner_loss + self.cfg.opacity_reg_lambda * torch.sigmoid(opacities_raw).mean() grads = torch.autograd.grad(outputs=inner_loss, inputs=inputs, create_graph=False, retain_graph=False, ) gaussian_grads = gaussian_grads + torch.cat(grads[:5], dim=-1) # [B, G, D] assert not torch.isnan(gaussian_grads).any(), "NaN detected in gaussian_grads" if self.cfg.need_2d_grads: means2d_grads_chunks.append(grads[5]) # [B, V_chunk, N, 2] gaussian_grads = gaussian_grads / nr_chunks if self.cfg.need_2d_grads: means2d_grads = torch.cat(means2d_grads_chunks, dim=1) # [B, V, N, 2] if self.cfg.input_gradient_loss_reduction == "mean_pixels_sum_views": means2d_grads = means2d_grads / v else: means2d_grads = None gaussian_grads_raw = gaussian_grads * self.cfg.input_gradient_scale if self.cfg.input_gradient_log: # log gradients grads_sign = gaussian_grads.sign() gaussian_grads_raw = (gaussian_grads_raw.abs() + 1e-8).log() else: grads_sign = None # Detach gradients to avoid gradient flow through the input gaussian_grads = gaussian_grads.detach() gaussian_grads_raw = gaussian_grads_raw.detach() if grads_sign is not None: grads_sign = grads_sign.detach() # Returning also the render output, but it can only be used for visualization, # as we already backpropogate gradients through it return gaussian_grads_raw, gaussian_grads, grads_sign, context_render_output, means2d_grads def select_gaussian_subset(gaussians, window_start, window_end, v, h, w): """Select a subset of gaussians based on view window. Optimized to avoid rearrange overhead.""" b = gaussians.means.shape[0] hw = h * w window_v = window_end - window_start new_n = window_v * hw # Helper to slice view dimension efficiently using view+slice+reshape instead of rearrange def slice_tensor(t, extra_dims): # t shape: [b, v*h*w, *extra_dims] -> [b, window_v*h*w, *extra_dims] shape = (b, v, hw) + extra_dims new_shape = (b, new_n) + extra_dims return t.view(shape)[:, window_start:window_end, :].reshape(new_shape) means = slice_tensor(gaussians.means, (3,)) covariances = slice_tensor(gaussians.covariances, (3, 3)) if gaussians.covariances is not None else None shs = slice_tensor(gaussians.harmonics, gaussians.harmonics.shape[2:]) opacities = slice_tensor(gaussians.opacities.unsqueeze(-1), ()).squeeze(-1) scales = slice_tensor(gaussians.scales, (3,)) rotations = slice_tensor(gaussians.rotations, (4,)) if gaussians.rotations is not None else None rotations_unnorm = slice_tensor(gaussians.rotations_unnorm, (4,)) return Gaussians( means=means, covariances=covariances, harmonics=shs, opacities=opacities, scales=scales, rotations=rotations, rotations_unnorm=rotations_unnorm, ) def replace_window(original, window, window_start, window_end, dim=1): slices = [] if window_start > 0: # TODO: detach or not # slices.append(original[:, :window_start].detach()) slices.append(original[:, :window_start]) slices.append(window) if window_end < original.shape[dim]: # TODO: detach or not # slices.append(original[:, window_end:].detach()) slices.append(original[:, window_end:]) return torch.cat(slices, dim=dim) def freeze_batchnorm_layers(model): import torch.nn as nn for module in model.modules(): if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm3d): module.eval() # Set to evaluation mode for param in module.parameters(): param.requires_grad = False # Freeze parameters