Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from functools import partial | |
| from typing import Literal, List, Optional | |
| import torch | |
| from torch import Tensor | |
| from optgs.dataset.data_types import BatchedViews | |
| from optgs.misc.general_utils import get_expon_lr_func | |
| from optgs.misc.io import FrequencyScheduler | |
| from optgs.model.decoder.decoder import Decoder | |
| from optgs.model.types import Gaussians | |
| from optgs.scene_trainer.initializer import InitializerCfg | |
| from optgs.scene_trainer.optimizer.layer import AdamInputSmoothing | |
| from optgs.scene_trainer.optimizer.optimizer import ( | |
| OptimizerInput, | |
| OptimizerOutput, | |
| OptimizerCfg, NonlearnedOptimizer, | |
| ) | |
| from optgs.scene_trainer.optimizer.optimizer_utils import ( | |
| calc_input_gradients, | |
| squeeze_grad_dict, | |
| smooth_grads, | |
| ) | |
@dataclass
class AdamOptimizerCfg(OptimizerCfg):
    """Configuration of the non-learned Adam optimizer (3DGS baseline).

    Every per-group learning rate below is multiplied by ``base_lr`` at each
    step. The means learning rate additionally follows a schedule built by
    ``get_expon_lr_func`` from ``means_lr_init`` / ``means_lr_final`` /
    ``means_lr_delay_mult`` / ``means_lr_max_steps``. Both endpoints are
    multiplied by the scene scale when the scene starts.
    """

    name: Literal["adam"]

    # Adam hyper-parameters.
    # (beta1, beta2), e.g. [0.9, 0.999]; only the first two entries are read.
    betas: List[float | int]
    eps: float
    # NOTE(review): not read by AdamOptimizer in this file -- confirm whether
    # weight decay is meant to be applied somewhere.
    weight_decay: float

    # Learning rates (each one is multiplied by base_lr).
    base_lr: int | float
    means_lr_init: float
    means_lr_final: float
    means_lr_delay_mult: float
    means_lr_max_steps: int  # should be equal to total optimization steps
    scales_lr: float
    rotations_lr: float
    opacities_lr: float
    sh0s_lr: float
    shNs_lr: float  # 20x smaller than sh0s_lr in the original 3DGS paper

    def update(self, initializer_cfg: InitializerCfg):
        """Adapt this config to the initializer config.

        Currently a no-op: no Adam setting depends on ``initializer_cfg``.
        """
        pass
class AdamOptimizer(NonlearnedOptimizer[AdamOptimizerCfg]):
    """Non-learned Adam optimizer over Gaussian parameters (3DGS baseline).

    On the first step the Gaussians are switched in place to a raw
    parameterization (log scales, logit opacities). Every step then:

    * computes input gradients with ``calc_input_gradients``;
    * smooths them with one ``AdamInputSmoothing`` per parameter group;
    * applies ``-lr * smoothed_grad`` in place.

    When ``cfg.any_adc`` is set, ``apply_adc`` (densification / pruning) is
    invoked after every update step.

    NOTE: evaluation-only (3DGS baseline); not used during meta-training.
    """

    def __init__(
        self, cfg: AdamOptimizerCfg, save_every: Optional[FrequencyScheduler] = None
    ) -> None:
        super().__init__(cfg, save_every)
        # Per-group Adam smoothers; created in _on_scene_start_impl.
        self.smoothers: Optional[dict] = None
        # Step index -> means learning rate; created in _on_scene_start_impl.
        self.means_lr_scheduler = None
        # Reused across steps (radii, visibility buffers). Cleared when ADC
        # changes the number of Gaussians.
        self._meta_bufs: dict = {}

    def _on_scene_start_impl(self, optimizer_input: OptimizerInput) -> None:
        """Create per-group Adam state and the means LR schedule for a new scene.

        Raises:
            AssertionError: if the scene batch size is not 1.
        """
        super()._on_scene_start_impl(optimizer_input)

        context = optimizer_input.context
        assert (
            context["extrinsics"].shape[0] == context["intrinsics"].shape[0] == 1
        ), "scene batch size > 1 not supported yet..."

        gaussians = optimizer_input.prev_output.gaussians
        harmonics = gaussians.harmonics
        device = gaussians.means.device
        make_smoother = partial(
            AdamInputSmoothing,
            beta1=self.cfg.betas[0],
            beta2=self.cfg.betas[1],
            eps=self.cfg.eps,
            device=device,
        )
        # One smoother per parameter group; shapes drop the leading scene-batch dim.
        self.smoothers = {
            "means": make_smoother(shape=gaussians.means.shape[1:]),
            "scales": make_smoother(shape=gaussians.scales.shape[1:]),
            "rotations": make_smoother(shape=gaussians.rotations.shape[1:]),
            "opacities": make_smoother(shape=gaussians.opacities.shape[1:]),
            # First SH coefficient (last-axis index 0).
            "sh0s": make_smoother(shape=harmonics[..., :, :1].shape[1:]),
            # Remaining SH coefficients exist only when the last axis has > 1 entry.
            "shNs": (
                make_smoother(shape=harmonics[..., :, 1:].shape[1:])
                if harmonics.shape[-1] > 1
                else None
            ),
        }

        # The scene scale rescales both endpoints of the means LR schedule;
        # a missing scale (None) means 1.0.
        scene_scale = context["scene_scale"]
        scale = 1.0 if scene_scale is None else scene_scale[0].item()
        self.means_lr_scheduler = get_expon_lr_func(
            lr_init=self.cfg.means_lr_init * scale,
            lr_final=self.cfg.means_lr_final * scale,
            lr_delay_mult=self.cfg.means_lr_delay_mult,
            max_steps=self.cfg.means_lr_max_steps,
        )

    def on_scene_end(self) -> None:
        """Drop all per-scene state (smoothers, LR schedule, cached buffers)."""
        super().on_scene_end()
        self.smoothers = None
        self.means_lr_scheduler = None
        self._meta_bufs.clear()

    def _forward_impl(
        self,
        i: int,
        optimizer_input: OptimizerInput,
        optimizer_output: OptimizerOutput,
        full_context: BatchedViews,
        full_target: BatchedViews,
        **kwargs,
    ) -> OptimizerOutput:
        """Run optimization step ``i`` for the current scene.

        Performs one Adam update and optional ADC, logs per-step statistics,
        and saves post-update renders.

        Args:
            i: Step index. ``i == 0`` triggers the raw-parameter conversion in
                ``apply_one_update_step``.
            optimizer_input: Current views, renderer and previous output. Its
                ``prev_output.gaussians`` is replaced by the updated Gaussians.
            optimizer_output: Output container, mutated in place.
            full_context: Full context views, forwarded to the post-update renders.
            full_target: Full target views, forwarded to the post-update renders.
            **kwargs: ``sh_degree`` is forwarded to the update step. Other keys
                are not read by this method.

        Returns:
            ``optimizer_output`` (the same object, returned for clarity).
        """
        self.iter_start.record()

        iter_context: BatchedViews = optimizer_input.context
        b, v, _, h, w = iter_context["image"].shape
        assert b == 1, "Batch size > 1 not supported for post-processing"

        # Number of Gaussians before this step.
        self.nr_gaussians_log.append(optimizer_input.prev_output.gaussians.means.shape[1])

        (
            gaussians,
            meta_for_adc,
            updates,
            grads_raw,
            normalized_grads,
            learning_rates,
        ) = self.apply_one_update_step(
            i, optimizer_input, optimizer_output, sh_degree=kwargs.get("sh_degree", None)
        )

        # Densification and pruning.
        if self.cfg.any_adc:
            self.apply_adc(
                i=i, v=v, h=h, w=w,
                adc_state=optimizer_input.prev_output.state.adc_state,
                gaussians=gaussians,
                meta=meta_for_adc,
                object_dict_to_adjust=self.smoothers,
            )
            self._refresh_leaves_if_resized(gaussians)

        self._record_iter_timing()

        # TODO Naama: we can log stats with save_every, but need to change stuff later.
        if grads_raw is not None:  # and self.save_every(i + 1, tag="info"):
            self._log_nonzero_grads(grads_raw)

        # The updated Gaussians become the input of the next step.
        optimizer_input.prev_output.gaussians = gaussians

        if self.save_every(i + 1, tag="info"):
            optimizer_output.gaussian_list.append(
                gaussians, detach_and_cpu=True, save_to_disk=False, no_cache=False
            )
            info = optimizer_output.info
            assert info is not None
            # Per-step tensor stats, copied to CPU (None entries skipped).
            for key, tensors in (
                ("deltas", updates),
                ("grads", grads_raw),
                ("normalized_grads", normalized_grads),
            ):
                if key not in info:
                    info[key] = []
                info[key].append(self._to_cpu_dict(tensors))
            if "learning_rates" not in info:
                info["learning_rates"] = []
            info["learning_rates"].append(learning_rates)

        self._save_post_update_renders(
            i, optimizer_input, optimizer_output, gaussians,
            full_context, full_target,
        )
        # Optimizer output is changed in place, but is returned for clarity.
        return optimizer_output

    def _refresh_leaves_if_resized(self, gaussians: Gaussians) -> None:
        """Invalidate cached buffers and re-create parameters as leaves after a resize.

        ADC can change the number of Gaussians N, which invalidates
        ``_meta_bufs``. ``torch.cat`` (used by add_new/relocate) produces a
        non-leaf even with ``requires_grad=True``, so backward() would never
        populate ``.grad``; ``detach()`` cuts the grad_fn first. A missing
        cached ``"N"`` is treated as a size change.
        """
        if self._meta_bufs.get("N") == gaussians.means.shape[1]:
            return
        self._meta_bufs.clear()
        # TODO Naama: need to think if the detach is necessary (was added during mcmc implementation)
        gaussians.means = gaussians.means.detach().requires_grad_(True)
        gaussians.scales = gaussians.scales.detach().requires_grad_(True)
        gaussians.rotations_unnorm = gaussians.rotations_unnorm.detach().requires_grad_(True)
        gaussians.opacities = gaussians.opacities.detach().requires_grad_(True)
        gaussians.harmonics = gaussians.harmonics.detach().requires_grad_(True)

    def _log_nonzero_grads(self, grads_raw: dict[str, Tensor]) -> None:
        """Append the number of rows with any non-zero gradient entry to nr_nonzero_grad_log.

        A row counts if any parameter group has a non-zero gradient for it.
        The leading dim of every gradient is assumed to index Gaussians
        (TODO confirm against squeeze_grad_dict). Appends 0 when there are no
        gradients at all.
        """
        present = [g for g in grads_raw.values() if g is not None]
        if not present:
            self.nr_nonzero_grad_log.append(0)
            return
        num_rows = present[0].shape[0]
        per_group = torch.stack(
            [(g.reshape(num_rows, -1) != 0).any(dim=-1) for g in present]
        )  # [num_groups, num_rows]
        self.nr_nonzero_grad_log.append(per_group.any(dim=0).sum().item())

    @staticmethod
    def _to_cpu_dict(tensors: Optional[dict]) -> dict:
        """Return a copy of ``tensors`` with non-None values moved to CPU ({} for None)."""
        if tensors is None:
            return {}
        return {key: t.cpu() for key, t in tensors.items() if t is not None}

    def apply_one_update_step(
        self, i, optimizer_input: OptimizerInput, optimizer_output: OptimizerOutput, sh_degree: int | None = None
    ) -> tuple[Gaussians, dict | None, dict, dict[str, Tensor], dict[str, Tensor], dict[str, float]]:
        """Apply one Adam step to ``optimizer_input.prev_output.gaussians`` in place.

        Args:
            i: Step index. On ``i == 0`` the Gaussians must store activated
                values. They are converted in place to raw values (log scales,
                logit opacities) and marked ``requires_grad``.
            optimizer_input: Provides the context views, renderer and Gaussians.
            optimizer_output: Not used here (kept for interface compatibility).
            sh_degree: Forwarded to ``calc_input_gradients``.

        Returns:
            A tuple ``(gaussians, meta_for_adc, updates, grads_raw, grads_adam,
            learning_rates)``:

            * ``updates`` holds each group's applied delta, or None if the
              group was not refined.
            * ``grads_adam`` holds the Adam-smoothed gradients.
            * ``learning_rates`` holds the per-group learning rates used.

        Raises:
            AssertionError: on a wrong activation state, missing per-scene
                state, or a scene batch size other than 1.
        """
        iter_context = optimizer_input.context
        renderer = optimizer_input.renderer
        gaussians = optimizer_input.prev_output.gaussians

        if i == 0:
            assert gaussians.stores_activated, "Gaussians must store activated values."
            # Switch to raw values in place (avoids allocating new tensors).
            gaussians.scales.log_()  # [B, N, 3]
            gaussians.opacities.logit_()
            gaussians.stores_activated = False
            # Enable requires_grad once: .grad buffers persist across steps, so
            # backward() reuses them instead of allocating new tensors each call.
            gaussians.means.requires_grad_(True)
            gaussians.scales.requires_grad_(True)
            gaussians.rotations_unnorm.requires_grad_(True)
            gaussians.opacities.requires_grad_(True)
            gaussians.harmonics.requires_grad_(True)
        else:
            assert not gaussians.stores_activated, "Gaussians must not store activated values."

        # Learning rates. NOTE: not scaled by the number of views in the batch.
        # TODO Naama: use current cfg field lr_scheduler, which also defines the lr per param
        assert self.means_lr_scheduler is not None, "means_lr_scheduler is not initialized"
        cfg = self.cfg
        learning_rates = {
            "means": self.means_lr_scheduler(i) * cfg.base_lr,
            "scales": cfg.scales_lr * cfg.base_lr,
            "rotations": cfg.rotations_lr * cfg.base_lr,
            "opacities": cfg.opacities_lr * cfg.base_lr,
            "sh0s": cfg.sh0s_lr * cfg.base_lr,
            "shNs": cfg.shNs_lr * cfg.base_lr,
        }

        assert (
            iter_context["extrinsics"].shape[0] == iter_context["intrinsics"].shape[0] == 1
        ), "scene batch size > 1 not supported for yet..."

        means = gaussians.means  # [B, N, 3]
        rotations_unnorm = gaussians.rotations_unnorm  # [B, N, 4]
        scales_raw = gaussians.scales  # [B, N, 3]
        opacities_raw = gaussians.opacities  # [B, N]
        shs = gaussians.harmonics  # [B, N, 3, sh_d]

        self.decoder_event_start.record()
        loss, grads_raw, meta_for_adc = calc_input_gradients(
            iter_context,
            means,
            scales_raw,
            rotations_unnorm,
            opacities_raw,
            shs,
            renderer,
            need_2d_grads=cfg.need_2d_grads,
            chunk_size=cfg.input_gradients_chunk_size,
            any_adc=cfg.any_adc,
            sh_degree=sh_degree,
            meta_bufs=self._meta_bufs,
            opacity_reg_lambda=cfg.opacity_reg_lambda,
        )
        self.decoder_event_end.record()

        grads_raw = squeeze_grad_dict(grads_raw)
        assert self.smoothers is not None, "Smoothers not initialized"
        grads_adam = smooth_grads(grads_raw, self.smoothers)

        # Deltas (-lr * smoothed grad). A group whose no_refine flag is set keeps
        # a None delta. Contiguous groups are batched with _foreach_mul to
        # reduce kernel launches.
        updates: dict = dict.fromkeys(("means", "scales", "rotations", "opacities", "sh0s", "shNs"))
        skip_flags = {
            "means": cfg.no_refine_mean,
            "scales": cfg.no_refine_scale,
            "rotations": cfg.no_refine_rotation,
            "opacities": cfg.no_refine_opacity,
        }
        active = [group for group, skip in skip_flags.items() if not skip]
        if active:
            products = torch._foreach_mul(
                [grads_adam[group] for group in active],
                [-learning_rates[group] for group in active],
            )
            updates.update(zip(active, products))
        # SH deltas stay separate (they are applied to non-contiguous slice views).
        if not cfg.no_refine_sh0:
            updates["sh0s"] = -learning_rates["sh0s"] * grads_adam["sh0s"]
        if grads_adam["shNs"] is not None and not cfg.no_refine_shN:
            updates["shNs"] = -learning_rates["shNs"] * grads_adam["shNs"]

        # Step. The parameters are leaves that require grad, so the in-place
        # update must not be recorded by autograd. no_grad() makes this
        # independent of the caller's grad mode (no-op if already disabled).
        with torch.no_grad():
            params_and_deltas = [
                (param, updates[group])
                for param, group in zip(
                    (means, scales_raw, rotations_unnorm, opacities_raw),
                    ("means", "scales", "rotations", "opacities"),
                )
                if updates[group] is not None
            ]
            if params_and_deltas:
                torch._foreach_add_(
                    [param for param, _ in params_and_deltas],
                    [delta for _, delta in params_and_deltas],
                )
            self.safe_inplace_update(updates["sh0s"], shs[..., 0:1])
            self.safe_inplace_update(updates["shNs"], shs[..., 1:])

        # Re-assign (means/scales/rotations/harmonics are the same objects; the
        # in-place ops above already updated their storage). opacities_raw is a
        # view -- do NOT reassign gaussians.opacities here, as that would replace
        # the persistent leaf with a non-leaf view and break retain_grad() on
        # subsequent steps.
        gaussians.means = means
        gaussians.scales = scales_raw
        gaussians.rotations_unnorm = rotations_unnorm
        gaussians.harmonics = shs

        return gaussians, meta_for_adc, updates, grads_raw, grads_adam, learning_rates

    @staticmethod
    def safe_inplace_update(delta_means: Tensor | None, means: Tensor):
        """Add ``delta_means`` to ``means`` in place; do nothing when the delta is None."""
        if delta_means is not None:
            means += delta_means