Spaces:
Sleeping
Sleeping
| """Hydra-free checkpoint -> optimizer construction. | |
| Rebuilds the learned optimizer (architecture + weights) from a checkpoint | |
| *without* going through Hydra. Only ``_load_checkpoint_cfg`` + | |
| ``load_typed_config`` are used (both Hydra-free); the Hydra coupling lives in | |
| ``setup_cfg`` / ``merge_config_from_file`` / ``setup_output_dir`` which we never | |
| call. All heavy imports are deferred into the functions so ``import optgs`` | |
| stays cheap. | |
| """ | |
| from __future__ import annotations | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING | |
| from optgs.experimental.api.integration.scene_protocol import OptGSError | |
| def _load_ckpt_cfg_cached(cfg_path_str: str): | |
| """Load + migrate a checkpoint config once per path (read-only callers). | |
| ``build_optimizer_cfg`` / ``build_decoder`` / ``get_scene_trainer_scalar`` | |
| all need the same DictConfig; caching avoids re-parsing the file. | |
| """ | |
| from optgs.config import _load_checkpoint_cfg # Hydra-free | |
| return _load_checkpoint_cfg(Path(cfg_path_str)) | |
def get_scene_trainer_scalar(cfg_path: Path, key: str, default):
    """Return ``scene_trainer.<key>`` from a checkpoint's saved config.

    Some per-scene scalars live on the scene-trainer config rather than on
    the optimizer cfg, and that config is not rebuilt on the Hydra-free path.
    These scalars are therefore read straight off the raw checkpoint config.
    Typical keys: ``num_update_steps``, ``iter_batch_size``,
    ``sh_degree_interval``.

    Args:
        cfg_path: Checkpoint config file.
        key: Key relative to ``scene_trainer`` (e.g. ``"num_update_steps"``).
        default: Value returned when the key is not present.
    """
    from omegaconf import OmegaConf

    ckpt_cfg = _load_ckpt_cfg_cached(str(cfg_path))
    dotted_key = f"scene_trainer.{key}"
    return OmegaConf.select(ckpt_cfg, dotted_key, default=default)
| if TYPE_CHECKING: # pragma: no cover - typing only | |
| from torch import nn | |
| from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizerCfg | |
def _optimizer_class_by_cfg_name():
    """Return a ``{cfg name: optimizer class}`` map for learned optimizers.

    ``SCENE_OPTIMIZERS`` is keyed by registry names (e.g. ``"depthsplat"``).
    A checkpoint's ``scene_optimizer.name`` is instead the cfg literal
    (``"knn_based"`` / ``"l2s"`` / ``"resplat_v1"`` / ``"resplat_v2"``).

    Each class accepts its own ``OPTIMIZER_NAME`` plus any
    ``OPTIMIZER_NAME_ALIASES`` (e.g. the legacy ``"clogs"`` for
    ``Learn2SplatOptimizer``), so every one of those names maps to the class.
    Classes listed later win if two classes ever claim the same name.
    """
    from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizer
    from optgs.scene_trainer.optimizer.optimizer_learn2splat import (
        Learn2SplatOptimizer,
    )
    from optgs.scene_trainer.optimizer.optimizer_resplat import (
        ResplatOptimizerV1,
        ResplatOptimizerV2,
    )

    supported = (
        KnnBasedOptimizer,
        Learn2SplatOptimizer,
        ResplatOptimizerV1,
        ResplatOptimizerV2,
    )
    return {
        accepted_name: optimizer_cls
        for optimizer_cls in supported
        for accepted_name in (
            optimizer_cls.OPTIMIZER_NAME,
            *getattr(optimizer_cls, "OPTIMIZER_NAME_ALIASES", ()),
        )
    }
def _initializer_cfg_class(name: str):
    """Return the concrete typed Cfg dataclass for ``scene_initializer.name``.

    ``InitializerCfg`` is a PEP-604 union. dacite can only resolve a union
    when it appears as a *field* type, so the top-level parse target has to
    be a concrete dataclass; this picks it. Keys match both
    ``SCENE_INITIALIZERS`` and each Cfg's ``name`` Literal. Unknown names
    yield ``None``.
    """
    from optgs.scene_trainer.initializer import (
        InitializerColmapCfg,
        InitializerEdgsCfg,
        InitializerPlyCfg,
        InitializerPointcloudCfg,
        InitializerRandomCfg,
        ResplatInitializerCfg,
    )

    # Both Resplat generations share a single initializer cfg class.
    by_name = dict.fromkeys(("resplat_v1", "resplat_v2"), ResplatInitializerCfg)
    by_name.update(
        colmap=InitializerColmapCfg,
        ply=InitializerPlyCfg,
        edgs=InitializerEdgsCfg,
        random=InitializerRandomCfg,
        pointcloud=InitializerPointcloudCfg,
    )
    return by_name.get(name)
| def _compose_default_group(group: str, value: str): | |
| """Hydra-compose the bundled default for ``scene_trainer.<group>=<value>``. | |
| Released checkpoints predate fields later added to the typed configs | |
| (e.g. ``scene_optimizer.refiner.fallback_means_lr``). The training/eval | |
| pipeline reconciles this by merging the checkpoint config over the | |
| *current* default config (config.py:merge_config_from_file). We mirror | |
| that: compose the bundled default for the group (e.g. | |
| ``scene_optimizer=knn_based`` -> base -> refiner:none, or | |
| ``scene_initializer=colmap``) so missing fields can be backfilled with | |
| current defaults while checkpoint values win for shared keys. | |
| Scoped use of ``hydra.compose`` (no ``@hydra.main`` / ``HydraConfig.get``, | |
| no app context); lazily imported so ``import optgs`` stays light. Returns | |
| ``None`` if composition fails (caller falls back to a strict parse). | |
| """ | |
| try: | |
| import optgs | |
| from hydra import compose, initialize_config_dir | |
| from hydra.core.global_hydra import GlobalHydra | |
| from omegaconf import OmegaConf | |
| config_dir = str(Path(optgs.__file__).resolve().parent / "config") | |
| GlobalHydra.instance().clear() | |
| try: | |
| with initialize_config_dir(version_base=None, config_dir=config_dir): | |
| composed = compose( | |
| config_name="main", | |
| overrides=[f"scene_trainer/{group}={value}"], | |
| ) | |
| finally: | |
| GlobalHydra.instance().clear() | |
| return OmegaConf.select(composed, f"scene_trainer.{group}") | |
| except Exception as e: # noqa: BLE001 - best-effort backfill | |
| print( | |
| f"[optgs] warning: could not compose default scene_trainer.{group}" | |
| f"={value} for back-compat merge ({type(e).__name__}: {e}); " | |
| f"parsing checkpoint config as-is." | |
| ) | |
| return None | |
def _merge_over_bundled_default(group: str, value: str, node):
    """Backfill ``node`` with the bundled default for ``scene_trainer.<group>=<value>``.

    Mirrors config.py:merge_config_from_file's ``OmegaConf.merge``: the
    current default is the base, and ``node`` (the checkpoint values) wins for
    every key both define. If the default cannot be composed, ``node`` is
    returned unchanged and the caller parses it as-is.
    """
    from omegaconf import OmegaConf

    default = _compose_default_group(group, value)
    if default is None:
        return node
    # Composed Hydra configs are struct-mode; lift that so checkpoint keys the
    # current default does not declare can still be merged in.
    OmegaConf.set_struct(default, False)
    return OmegaConf.merge(default, node)


def build_optimizer_cfg(cfg_path: Path) -> tuple["KnnBasedOptimizerCfg", int | None]:
    """Load a checkpoint's saved config and return its typed optimizer cfg.

    Returns ``(KnnBasedOptimizerCfg, num_update_steps)``. ``num_update_steps``
    is the per-scene optimization step count, read from
    ``scene_trainer.num_update_steps`` if present, else ``None``; it is NOT
    part of the optimizer cfg.

    Steps:

    1. Merge the optimizer and initializer nodes over the current bundled
       defaults, so older checkpoints that lack newer fields still parse.
    2. Parse the initializer into its concrete cfg class.
    3. Wire it into the optimizer cfg via ``opt_cfg.update``, mirroring
       SceneTrainerCfg.

    Raises:
        OptGSError: if the checkpoint has no learned ``scene_optimizer`` or
            no ``scene_initializer``, if the initializer name is unsupported,
            or if either node fails to parse/wire into its typed cfg.
    """
    from omegaconf import OmegaConf
    from optgs.config import load_typed_config
    from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizerCfg

    cfg = _load_ckpt_cfg_cached(str(cfg_path))  # read_omega_cfg + migrate; NO Hydra

    so = OmegaConf.select(cfg, "scene_trainer.scene_optimizer")
    name = OmegaConf.select(cfg, "scene_trainer.scene_optimizer.name")
    if so is None or name in (None, "none"):
        raise OptGSError(
            f"checkpoint config at {cfg_path} has no learned scene_optimizer "
            f"(scene_trainer.scene_optimizer={name!r}). OptGS needs a learned "
            f"optimizer checkpoint (knn_based / l2s (legacy: clogs) / "
            f"resplat_v1 / resplat_v2)."
        )

    # Backfill fields a released (older) checkpoint config lacks with the
    # current defaults; checkpoint values win for shared keys. The knn_based
    # default is used whatever ``name`` is. Presumably that is because the
    # typed target below is always KnnBasedOptimizerCfg.
    merged_so = _merge_over_bundled_default("scene_optimizer", "knn_based", so)
    try:
        opt_cfg = load_typed_config(merged_so, KnnBasedOptimizerCfg)
    except Exception as e:  # dacite/omegaconf errors -> actionable message
        raise OptGSError(
            f"failed to parse scene_optimizer from {cfg_path} into "
            f"KnnBasedOptimizerCfg ({type(e).__name__}: {e})."
        ) from e

    # Mirror SceneTrainerCfg (scene_trainer_cfg.py: scene_optimizer.update(
    # scene_initializer)): wire the checkpoint's initializer cfg into the
    # optimizer cfg. This populates the runtime-only fields
    # init_gaussian_param_num / init_sh_d / sh_d, which are absent from every
    # config file, before the optimizer nn.Module is built.
    si = OmegaConf.select(cfg, "scene_trainer.scene_initializer")
    si_name = OmegaConf.select(cfg, "scene_trainer.scene_initializer.name")
    if si is None or si_name in (None, "none"):
        raise OptGSError(
            f"checkpoint config at {cfg_path} has no scene_initializer "
            f"(name={si_name!r}); cannot derive init_gaussian_param_num "
            f"required to build the optimizer."
        )
    init_cls = _initializer_cfg_class(str(si_name))
    if init_cls is None:
        raise OptGSError(
            f"unsupported scene_initializer.name={si_name!r} in {cfg_path}; "
            f"cannot derive init_gaussian_param_num for the optimizer."
        )

    merged_si = _merge_over_bundled_default("scene_initializer", str(si_name), si)
    try:
        init_cfg = load_typed_config(merged_si, init_cls)
        opt_cfg.update(init_cfg)  # sets init_gaussian_param_num/init_sh_d/sh_d
    except Exception as e:
        raise OptGSError(
            f"failed to wire scene_initializer ({si_name!r}) into the "
            f"optimizer cfg from {cfg_path} ({type(e).__name__}: {e})."
        ) from e

    num_update_steps = OmegaConf.select(
        cfg, "scene_trainer.num_update_steps", default=None
    )
    return opt_cfg, num_update_steps
def build_decoder(
    cfg_path: Path, dataset_cfg: object, decoder_overrides: dict | None = None
) -> "nn.Module":
    """Build the renderer the checkpoint was trained with.

    Uses ``scene_trainer.decoder`` from the checkpoint config (NOT a hardcoded
    backend). The learned optimizer's in-loop render gradients must match the
    backend it trained with, and only registered/available backends are
    usable: e.g. ``gsplat``, the optgs default; the ``inria`` backend needs
    ``diff_gaussian_rasterization``, which is optional.

    Args:
        cfg_path: Checkpoint config file.
        dataset_cfg: Passed through to ``get_decoder``; only needs a
            ``background_color`` attribute.
        decoder_overrides: Decoder fields (e.g. ``rasterize_mode`` /
            ``eps2d``) that take precedence over the checkpoint config. They
            are merged only when the checkpoint's decoder is ``gsplat``;
            other backends ignore them.

    Returns:
        The decoder module built by ``get_decoder``.

    Raises:
        OptGSError: if the config has no ``scene_trainer.decoder``, if the
            node cannot be parsed into ``DecoderCfg``, or if the backend is
            not available here. Unavailable covers gsplat itself failing to
            import, and ``get_decoder`` raising KeyError/ImportError.
    """
    from omegaconf import OmegaConf
    from optgs.config import load_typed_config
    from optgs.model.decoder import DecoderCfg, get_decoder

    cfg = _load_ckpt_cfg_cached(str(cfg_path))
    node = OmegaConf.select(cfg, "scene_trainer.decoder")
    if node is None:
        raise OptGSError(
            f"checkpoint config at {cfg_path} has no scene_trainer.decoder; "
            f"cannot rebuild the renderer the optimizer trained with."
        )

    # gsplat decoder rasterize_mode / eps2d, by precedence:
    #   caller override > checkpoint config > gsplat rasterization() default
    # (so an older checkpoint that omits a field behaves as plain gsplat would).
    if OmegaConf.select(node, "name") == "gsplat":
        import inspect

        # Report a missing gsplat like any other unavailable backend instead
        # of letting a raw ModuleNotFoundError escape.
        try:
            from gsplat.rendering import rasterization
        except ImportError as e:
            raise OptGSError(
                f"decoder backend 'gsplat' is not available in this "
                f"environment ({type(e).__name__}: {e}). Install gsplat to "
                f"rebuild the renderer this checkpoint was trained with."
            ) from e

        params = inspect.signature(rasterization).parameters
        gsplat_defaults = {
            f: params[f].default for f in ("rasterize_mode", "eps2d") if f in params
        }
        node = OmegaConf.merge(
            OmegaConf.create(gsplat_defaults),
            node,
            OmegaConf.create(dict(decoder_overrides or {})),
        )

    try:
        decoder_cfg = load_typed_config(node, DecoderCfg)
    except Exception as e:
        raise OptGSError(
            f"failed to parse scene_trainer.decoder from {cfg_path} "
            f"({type(e).__name__}: {e})."
        ) from e

    try:
        return get_decoder(decoder_cfg, dataset_cfg)
    except (KeyError, ImportError) as e:
        raise OptGSError(
            f"decoder backend {decoder_cfg.name!r} is not available in this "
            f"environment ({type(e).__name__}: {e}). Install its backend "
            f"(e.g. diff_gaussian_rasterization for 'inria') or use a "
            f"checkpoint trained with the 'gsplat' decoder."
        ) from e
def build_optimizer(opt_cfg: "KnnBasedOptimizerCfg") -> "nn.Module":
    """Instantiate the learned optimizer class named by ``opt_cfg.name``.

    No weights are loaded here. The instance comes back with a permanently
    disabled ``save_every`` scheduler.

    Raises:
        OptGSError: if ``opt_cfg.name`` is neither a supported optimizer
            name nor one of their aliases.
    """
    from optgs.misc.io import FrequencyScheduler

    registry = _optimizer_class_by_cfg_name()
    optimizer_cls = registry.get(opt_cfg.name)
    if optimizer_cls is None:
        raise OptGSError(
            f"unsupported scene_optimizer.name={opt_cfg.name!r}; OptGS supports "
            f"{sorted(registry)}."
        )

    optimizer = optimizer_cls(opt_cfg)

    # During training SceneTrainer wires up save_every (the optimizer's
    # info/context/target/debug artifact dumps). The optimizer calls it
    # unconditionally, so this inference path -- which has nothing to dump --
    # installs a disabled scheduler rather than leaving it None.
    never_dump = FrequencyScheduler(last_step=0)
    never_dump.disable(True)
    optimizer.save_every = never_dump
    return optimizer
def build_adam_baseline(num_refine: int) -> "nn.Module":
    """Build the codebase's 3DGS Adam optimizer as a fair baseline.

    Starts from the bundled ``scene_optimizer=3dgs`` config, i.e. gsplat's
    example hyperparameters (LRs, betas), then makes two changes:

    - densification is switched off, so the baseline refines the same fixed
      Gaussian set as the learned optimizer (a like-for-like comparison of
      the update rule);
    - the means-LR decay horizon is set to ``num_refine``.

    Returns:
        A ready-to-run ``AdamOptimizer``.

    Raises:
        OptGSError: if the bundled config cannot be composed or parsed into
            ``AdamOptimizerCfg``.
    """
    from omegaconf import OmegaConf
    from optgs.config import load_typed_config
    from optgs.misc.io import FrequencyScheduler
    from optgs.scene_trainer.optimizer.optimizer_adam import (
        AdamOptimizer,
        AdamOptimizerCfg,
    )

    adam_node = _compose_default_group("scene_optimizer", "3dgs")
    if adam_node is None:
        raise OptGSError(
            "could not Hydra-compose the bundled 'scene_optimizer=3dgs' config "
            "for the Adam baseline."
        )
    OmegaConf.set_struct(adam_node, False)

    # Like gsplat, decay the means LR over the whole step budget.
    adam_node.means_lr_max_steps = int(num_refine)

    # No densification: refine exactly the Gaussians the learned optimizer gets.
    densify_flags = ("do_densify", "do_prune", "do_opacity_reset")
    for flag in densify_flags:
        if flag not in adam_node.refiner:
            continue
        adam_node.refiner[flag] = False

    try:
        adam_cfg = load_typed_config(adam_node, AdamOptimizerCfg)
    except Exception as e:
        raise OptGSError(
            f"failed to parse the bundled '3dgs' config into AdamOptimizerCfg "
            f"({type(e).__name__}: {e})."
        ) from e

    baseline = AdamOptimizer(adam_cfg)
    # Inference only: nothing to dump, so install a disabled save_every.
    never_dump = FrequencyScheduler(last_step=0)
    never_dump.disable(True)
    baseline.save_every = never_dump
    # AdamOptimizer is a NonlearnedOptimizer — already pinned to eval mode.
    return baseline
# Module-attribute renames applied when the legacy Resplat encoder was split
# into separate initializer/optimizer modules (transcribed from
# optgs/main.py:load_optimizer).
_ORIG_OPTIMIZER_ATTR_RENAMES = {
    "render_error_mv_attn": "update_error_attn",
}


def load_optimizer_state(
    optimizer: "nn.Module",
    ckpt_path: str,
    init_state_wo_features: bool,
    strict: bool,
) -> None:
    """Load optimizer weights from ``ckpt_path`` into ``optimizer``.

    Transcribes the prefix-stripping / legacy-rename / feature-drop logic from
    ``optgs/main.py:load_optimizer``. That function cannot be called here: it
    needs a full Hydra ``cfg`` and a ``scene_trainer``.

    Key handling, in order:

    1. Unwrap a ``{"state_dict": ...}`` checkpoint container if present.
    2. Strip a *leading* ``scene_trainer.`` from every key.
    3. Keep ``optimizer.*`` keys (unified format) or, if there are none,
       legacy ``encoder.*`` keys, with that prefix removed.
    4. Rename legacy top-level attributes via ``_ORIG_OPTIMIZER_ATTR_RENAMES``.
    5. If ``init_state_wo_features``, drop every key containing
       ``update_proj``.

    Args:
        optimizer: Module to load the weights into.
        ckpt_path: Checkpoint file passed to ``torch.load``.
        init_state_wo_features: Skip the ``update_proj`` weights.
        strict: Forwarded to ``optimizer.load_state_dict``.

    Raises:
        OptGSError: if the checkpoint is not a dict, or it holds no
            ``optimizer.*`` / ``encoder.*`` weights.
        RuntimeError: from ``load_state_dict`` when ``strict`` and the keys
            do not match.
    """
    import torch

    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (weights_only defaults to True only from PyTorch 2.6 on).
    state = torch.load(ckpt_path, map_location="cpu")
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    if not isinstance(state, dict):
        raise OptGSError(
            f"unrecognized checkpoint format in {ckpt_path}: expected a state "
            f"dict, got {type(state).__name__}."
        )

    # Strip the Lightning "scene_trainer." prefix if present. Only a leading
    # occurrence is removed: str.replace would also mangle module names that
    # merely contain it (e.g. "my_scene_trainer.weight" -> "my_weight").
    lightning_prefix = "scene_trainer."
    state = {
        (k[len(lightning_prefix):] if k.startswith(lightning_prefix) else k): v
        for k, v in state.items()
    }

    # Unified repo format keys are optimizer.*; legacy Resplat checkpoints
    # (before the initializer/optimizer split) store them under encoder.*.
    if any(k.startswith("optimizer.") for k in state):
        source_prefix = "optimizer."
    else:
        source_prefix = "encoder."
    osd = {
        k[len(source_prefix):]: v
        for k, v in state.items()
        if k.startswith(source_prefix)
    }

    renamed = {}
    for k, v in osd.items():
        for old, new in _ORIG_OPTIMIZER_ATTR_RENAMES.items():
            # Match whole attribute names only (exact key or "old." prefix).
            if k == old or k.startswith(old + "."):
                k = new + k[len(old):]
                break
        renamed[k] = v
    osd = renamed

    if not osd:
        raise OptGSError(
            f"no optimizer weights found in {ckpt_path} (looked for "
            f"'optimizer.*' or legacy 'encoder.*' keys)."
        )
    if init_state_wo_features:
        osd = {k: v for k, v in osd.items() if "update_proj" not in k}
    optimizer.load_state_dict(osd, strict=strict)