"""Hydra-free checkpoint -> optimizer construction.

Rebuilds the learned optimizer (architecture + weights) from a checkpoint
*without* running the Hydra app pipeline. The checkpoint config is read with
``_load_checkpoint_cfg`` + ``load_typed_config`` (both Hydra-free); the Hydra
coupling lives in ``setup_cfg`` / ``merge_config_from_file`` /
``setup_output_dir``, which we never call. The only Hydra use is a scoped,
best-effort ``hydra.compose`` of bundled default config groups, used to
backfill fields that older checkpoint configs lack (see
``_compose_default_group``).

All heavy imports are deferred into the functions so ``import optgs`` stays
cheap.
"""

from __future__ import annotations

from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING

from optgs.experimental.api.integration.scene_protocol import OptGSError

if TYPE_CHECKING:  # pragma: no cover - typing only
    from torch import nn

    from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizerCfg


@lru_cache(maxsize=8)
def _load_ckpt_cfg_cached(cfg_path_str: str):
    """Load + migrate a checkpoint config once per path (read-only callers).

    ``build_optimizer_cfg`` / ``build_decoder`` / ``get_scene_trainer_scalar``
    all need the same DictConfig; caching avoids re-parsing the file.

    The returned DictConfig is shared by every caller that hits the cache, so
    callers must treat it as read-only (merge into a new config rather than
    mutating it). Entries are keyed on the path string only, so edits made to
    the file after its first load are not picked up.
    """
    from optgs.config import _load_checkpoint_cfg  # Hydra-free

    return _load_checkpoint_cfg(Path(cfg_path_str))


def _load_ckpt_cfg(cfg_path: Path | str):
    """Return the (cached) checkpoint config for ``cfg_path``.

    The path is resolved to an absolute path before it becomes the cache key.
    Otherwise a relative path could return a stale entry for a different file
    after the working directory changes.
    """
    return _load_ckpt_cfg_cached(str(Path(cfg_path).resolve()))


def get_scene_trainer_scalar(cfg_path: Path, key: str, default):
    """Read ``scene_trainer.<key>`` from a checkpoint config (or ``default``).

    Used for scalars that live on the scene-trainer config (which cannot be
    built without Hydra) rather than on the optimizer cfg:
    ``num_update_steps``, ``iter_batch_size``, ``sh_degree_interval``.
    Returns ``default`` when the key is absent from the checkpoint config.
    """
    from omegaconf import OmegaConf

    cfg = _load_ckpt_cfg(cfg_path)
    return OmegaConf.select(cfg, f"scene_trainer.{key}", default=default)


def _optimizer_class_by_cfg_name():
    """Map a checkpoint config's ``scene_optimizer.name`` -> optimizer class.

    The registry (``SCENE_OPTIMIZERS``) keys on registry names (e.g.
    ``"depthsplat"``), but a checkpoint config's ``scene_optimizer.name`` is
    the cfg literal (``"knn_based"`` / ``"l2s"`` / ``"resplat_v1"`` /
    ``"resplat_v2"``). Each class asserts ``cfg.name`` is its
    ``OPTIMIZER_NAME`` or one of its ``OPTIMIZER_NAME_ALIASES`` (e.g. legacy
    ``"clogs"`` for ``Learn2SplatOptimizer``), so we dispatch on both.
    """
    from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizer
    from optgs.scene_trainer.optimizer.optimizer_learn2splat import (
        Learn2SplatOptimizer,
    )
    from optgs.scene_trainer.optimizer.optimizer_resplat import (
        ResplatOptimizerV1,
        ResplatOptimizerV2,
    )

    classes = (
        KnnBasedOptimizer,
        Learn2SplatOptimizer,
        ResplatOptimizerV1,
        ResplatOptimizerV2,
    )
    # Primary name plus any aliases; on a (theoretical) name collision the
    # later class in ``classes`` wins.
    return {
        name: cls
        for cls in classes
        for name in (cls.OPTIMIZER_NAME, *getattr(cls, "OPTIMIZER_NAME_ALIASES", ()))
    }


def _initializer_cfg_class(name: str):
    """Map ``scene_initializer.name`` -> its concrete typed Cfg dataclass.

    ``InitializerCfg`` is a PEP-604 union; dacite needs a concrete dataclass
    as the top-level target (a union is only resolvable as a *field* type).
    Keyed to match both ``SCENE_INITIALIZERS`` and each Cfg's ``name``
    Literal. Returns ``None`` for an unknown name.
    """
    from optgs.scene_trainer.initializer import (
        InitializerColmapCfg,
        InitializerEdgsCfg,
        InitializerPlyCfg,
        InitializerPointcloudCfg,
        InitializerRandomCfg,
        ResplatInitializerCfg,
    )

    return {
        "resplat_v1": ResplatInitializerCfg,
        "resplat_v2": ResplatInitializerCfg,
        "colmap": InitializerColmapCfg,
        "ply": InitializerPlyCfg,
        "edgs": InitializerEdgsCfg,
        "random": InitializerRandomCfg,
        "pointcloud": InitializerPointcloudCfg,
    }.get(name)


def _compose_default_group(group: str, value: str):
    """Hydra-compose the bundled default for ``scene_trainer.<group>=<value>``.

    Released checkpoints predate fields later added to the typed configs
    (e.g. ``scene_optimizer.refiner.fallback_means_lr``). The training/eval
    pipeline reconciles this by merging the checkpoint config over the
    *current* default config (config.py:merge_config_from_file). We mirror
    that: compose the bundled default for the group (e.g.
    ``scene_optimizer=knn_based`` -> base -> refiner:none, or
    ``scene_initializer=colmap``) so missing fields can be backfilled with
    current defaults while checkpoint values win for shared keys.

    Scoped use of ``hydra.compose`` (no ``@hydra.main`` / ``HydraConfig.get``,
    no app context); lazily imported so ``import optgs`` stays light. Note
    that the global Hydra instance is cleared both before and after composing,
    so any Hydra state the host process had initialized is discarded.

    Returns the composed ``scene_trainer.<group>`` node (freshly composed on
    every call, so callers may mutate it), or ``None`` if composition fails
    (the caller then falls back to a strict parse of the checkpoint config).
    """
    try:
        import optgs
        from hydra import compose, initialize_config_dir
        from hydra.core.global_hydra import GlobalHydra
        from omegaconf import OmegaConf

        config_dir = str(Path(optgs.__file__).resolve().parent / "config")
        # initialize_config_dir refuses to run while GlobalHydra is already
        # initialized, hence the clear() before and after.
        GlobalHydra.instance().clear()
        try:
            with initialize_config_dir(version_base=None, config_dir=config_dir):
                composed = compose(
                    config_name="main",
                    overrides=[f"scene_trainer/{group}={value}"],
                )
        finally:
            GlobalHydra.instance().clear()
        return OmegaConf.select(composed, f"scene_trainer.{group}")
    except Exception as e:  # noqa: BLE001 - best-effort backfill
        print(
            f"[optgs] warning: could not compose default scene_trainer.{group}"
            f"={value} for back-compat merge ({type(e).__name__}: {e}); "
            f"parsing checkpoint config as-is."
        )
        return None


def _merge_over_default(group: str, value: str, node):
    """Merge checkpoint ``node`` over the bundled default for ``group=value``.

    Mirrors config.py:merge_config_from_file's ``OmegaConf.merge``. Checkpoint
    values win for shared keys, and fields an older checkpoint lacks are
    backfilled from the current default. Returns ``node`` unchanged when the
    default cannot be composed.
    """
    from omegaconf import OmegaConf

    default = _compose_default_group(group, value)
    if default is None:
        return node
    # Disable struct mode so checkpoint keys that the current default lacks
    # are accepted by the merge instead of being rejected.
    OmegaConf.set_struct(default, False)
    return OmegaConf.merge(default, node)


def build_optimizer_cfg(cfg_path: Path) -> tuple["KnnBasedOptimizerCfg", int | None]:
    """Load a checkpoint's saved config and return its typed optimizer cfg.

    Returns ``(KnnBasedOptimizerCfg, num_update_steps)`` where
    ``num_update_steps`` (the per-scene optimization step count) is read from
    ``scene_trainer.num_update_steps`` if present (it is NOT part of the
    optimizer cfg), else ``None``.

    The checkpoint's ``scene_trainer.scene_initializer`` is also parsed and
    wired into the optimizer cfg via ``opt_cfg.update(...)``.

    Raises:
        OptGSError: if the config has no learned ``scene_optimizer`` or no
            ``scene_initializer``, names an unsupported initializer, or cannot
            be parsed into the typed configs.
    """
    from omegaconf import OmegaConf

    from optgs.config import load_typed_config
    from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizerCfg

    cfg = _load_ckpt_cfg(cfg_path)  # read_omega_cfg + migrate; NO Hydra

    so = OmegaConf.select(cfg, "scene_trainer.scene_optimizer")
    name = OmegaConf.select(cfg, "scene_trainer.scene_optimizer.name")
    if so is None or name in (None, "none"):
        raise OptGSError(
            f"checkpoint config at {cfg_path} has no learned scene_optimizer "
            f"(scene_trainer.scene_optimizer={name!r}). OptGS needs a learned "
            f"optimizer checkpoint (knn_based / clogs / resplat_v1 / resplat_v2)."
        )

    # Backfill fields a released (older) checkpoint config lacks with the
    # current defaults, then let checkpoint values win for shared keys.
    # NOTE(review): the knn_based default group is used for every optimizer
    # name (the parse target is KnnBasedOptimizerCfg); confirm this is the
    # intended backfill source for l2s / resplat checkpoints too.
    merged_so = _merge_over_default("scene_optimizer", "knn_based", so)
    try:
        opt_cfg = load_typed_config(merged_so, KnnBasedOptimizerCfg)
    except Exception as e:  # dacite/omegaconf errors -> actionable message
        raise OptGSError(
            f"failed to parse scene_optimizer from {cfg_path} into "
            f"KnnBasedOptimizerCfg ({type(e).__name__}: {e})."
        ) from e

    # Mirror SceneTrainerCfg (scene_trainer_cfg.py: scene_optimizer.update(
    # scene_initializer)): wire the checkpoint's initializer cfg into the
    # optimizer cfg so the runtime-only fields init_gaussian_param_num /
    # init_sh_d / sh_d — absent from every config file — are populated before
    # the optimizer nn.Module is built.
    si = OmegaConf.select(cfg, "scene_trainer.scene_initializer")
    si_name = OmegaConf.select(cfg, "scene_trainer.scene_initializer.name")
    if si is None or si_name in (None, "none"):
        raise OptGSError(
            f"checkpoint config at {cfg_path} has no scene_initializer "
            f"(name={si_name!r}); cannot derive init_gaussian_param_num "
            f"required to build the optimizer."
        )
    init_cls = _initializer_cfg_class(str(si_name))
    if init_cls is None:
        raise OptGSError(
            f"unsupported scene_initializer.name={si_name!r} in {cfg_path}; "
            f"cannot derive init_gaussian_param_num for the optimizer."
        )

    merged_si = _merge_over_default("scene_initializer", str(si_name), si)
    try:
        init_cfg = load_typed_config(merged_si, init_cls)
        opt_cfg.update(init_cfg)  # sets init_gaussian_param_num/init_sh_d/sh_d
    except Exception as e:
        raise OptGSError(
            f"failed to wire scene_initializer ({si_name!r}) into the "
            f"optimizer cfg from {cfg_path} ({type(e).__name__}: {e})."
        ) from e

    num_update_steps = OmegaConf.select(
        cfg, "scene_trainer.num_update_steps", default=None
    )
    return opt_cfg, num_update_steps


def build_decoder(
    cfg_path: Path, dataset_cfg: object, decoder_overrides: dict | None = None
) -> "nn.Module":
    """Build the renderer the checkpoint was trained with.

    Uses ``scene_trainer.decoder`` from the checkpoint config (NOT a hardcoded
    backend): the learned optimizer's in-loop render gradients must match the
    backend it trained with, and only the registered/available backends are
    usable (e.g. ``gsplat`` — the optgs default; the ``inria`` backend needs
    ``diff_gaussian_rasterization``, which is optional).

    ``dataset_cfg`` only needs a ``background_color`` attribute.
    ``decoder_overrides`` (e.g. ``rasterize_mode`` / ``eps2d``) take
    precedence over the checkpoint config. They are only applied when the
    checkpoint's decoder is ``gsplat``; for other backends they are ignored.

    Raises:
        OptGSError: if the config has no ``scene_trainer.decoder``, it cannot
            be parsed, or the decoder backend is not available here.
    """
    from omegaconf import OmegaConf

    from optgs.config import load_typed_config
    from optgs.model.decoder import DecoderCfg, get_decoder

    cfg = _load_ckpt_cfg(cfg_path)
    node = OmegaConf.select(cfg, "scene_trainer.decoder")
    if node is None:
        raise OptGSError(
            f"checkpoint config at {cfg_path} has no scene_trainer.decoder; "
            f"cannot rebuild the renderer the optimizer trained with."
        )

    # gsplat decoder rasterize_mode / eps2d, by precedence:
    #   caller override > checkpoint config > gsplat rasterization() default
    # (so an older checkpoint that omits a field behaves as plain gsplat would).
    if OmegaConf.select(node, "name") == "gsplat":
        import inspect

        try:
            from gsplat.rendering import rasterization
        except ImportError as e:
            raise OptGSError(
                f"decoder backend 'gsplat' is not available in this "
                f"environment ({type(e).__name__}: {e}). Install gsplat to "
                f"rebuild the renderer this checkpoint was trained with."
            ) from e

        params = inspect.signature(rasterization).parameters
        gsplat_defaults = {
            field: params[field].default
            for field in ("rasterize_mode", "eps2d")
            if field in params
        }
        node = OmegaConf.merge(
            OmegaConf.create(gsplat_defaults),
            node,
            OmegaConf.create(dict(decoder_overrides or {})),
        )

    try:
        decoder_cfg = load_typed_config(node, DecoderCfg)
    except Exception as e:
        raise OptGSError(
            f"failed to parse scene_trainer.decoder from {cfg_path} "
            f"({type(e).__name__}: {e})."
        ) from e

    try:
        return get_decoder(decoder_cfg, dataset_cfg)
    except (KeyError, ImportError) as e:
        raise OptGSError(
            f"decoder backend {decoder_cfg.name!r} is not available in this "
            f"environment ({type(e).__name__}: {e}). Install its backend "
            f"(e.g. diff_gaussian_rasterization for 'inria') or use a "
            f"checkpoint trained with the 'gsplat' decoder."
        ) from e


def _disabled_save_every():
    """Return a disabled ``FrequencyScheduler`` for an optimizer's ``save_every``.

    ``save_every`` (info/context/target/debug artifact dumps) is wired by
    SceneTrainer during training, and the optimizer calls it unconditionally.
    The API inference path has nothing to dump, so it installs a disabled
    scheduler instead of leaving the attribute ``None``.
    """
    from optgs.misc.io import FrequencyScheduler

    save_every = FrequencyScheduler(last_step=0)
    save_every.disable(True)
    return save_every


def build_optimizer(opt_cfg: "KnnBasedOptimizerCfg") -> "nn.Module":
    """Construct the concrete learned optimizer for ``opt_cfg`` (no weights).

    Dispatches on ``opt_cfg.name``, including legacy aliases. Load weights
    afterwards, e.g. with ``load_optimizer_state``.

    Raises:
        OptGSError: if ``opt_cfg.name`` matches no supported optimizer class.
    """
    mapping = _optimizer_class_by_cfg_name()
    cls = mapping.get(opt_cfg.name)
    if cls is None:
        raise OptGSError(
            f"unsupported scene_optimizer.name={opt_cfg.name!r}; OptGS supports "
            f"{sorted(mapping)}."
        )
    optimizer = cls(opt_cfg)
    optimizer.save_every = _disabled_save_every()
    return optimizer


def build_adam_baseline(num_refine: int) -> "nn.Module":
    """Build the codebase's 3DGS Adam optimizer for a fair baseline comparison.

    Uses the bundled ``scene_optimizer=3dgs`` config — gsplat's example
    hyperparameters (LRs, betas). Densification is disabled so the baseline
    refines the same fixed Gaussian set as the learned optimizer (a
    like-for-like update-rule comparison), and the means-LR decay horizon is
    set to ``num_refine``. Returns a ready-to-run ``AdamOptimizer``.

    Raises:
        OptGSError: if the bundled config cannot be composed or parsed.
    """
    from omegaconf import OmegaConf

    from optgs.config import load_typed_config
    from optgs.scene_trainer.optimizer.optimizer_adam import (
        AdamOptimizer,
        AdamOptimizerCfg,
    )

    composed = _compose_default_group("scene_optimizer", "3dgs")
    if composed is None:
        raise OptGSError(
            "could not Hydra-compose the bundled 'scene_optimizer=3dgs' config "
            "for the Adam baseline."
        )
    OmegaConf.set_struct(composed, False)

    # gsplat decays the means LR over the full step budget.
    composed.means_lr_max_steps = int(num_refine)

    # Disable densification — the baseline refines the same fixed Gaussian set
    # as the learned optimizer (a like-for-like comparison of the update rule).
    # ``.get`` tolerates a missing/null refiner node instead of crashing here.
    refiner = composed.get("refiner")
    if refiner is not None:
        for flag in ("do_densify", "do_prune", "do_opacity_reset"):
            if flag in refiner:
                refiner[flag] = False

    try:
        adam_cfg = load_typed_config(composed, AdamOptimizerCfg)
    except Exception as e:
        raise OptGSError(
            f"failed to parse the bundled '3dgs' config into AdamOptimizerCfg "
            f"({type(e).__name__}: {e})."
        ) from e

    optimizer = AdamOptimizer(adam_cfg)
    optimizer.save_every = _disabled_save_every()  # nothing to dump
    # AdamOptimizer is a NonlearnedOptimizer — already pinned to eval mode.
    return optimizer


# Module-attribute renames applied when the legacy Resplat encoder was split
# into separate initializer/optimizer modules (transcribed from
# optgs/main.py:load_optimizer).
_ORIG_OPTIMIZER_ATTR_RENAMES = {
    "render_error_mv_attn": "update_error_attn",
}


def _strip_prefix(key: str, prefix: str) -> str:
    """Return ``key`` without a leading ``prefix`` (unchanged if absent)."""
    return key[len(prefix):] if key.startswith(prefix) else key


def _rename_legacy_attr(key: str) -> str:
    """Apply the first matching ``_ORIG_OPTIMIZER_ATTR_RENAMES`` entry to ``key``.

    A rename matches the whole key or a leading ``<old>.`` module path, so
    ``render_error_mv_attn.weight`` -> ``update_error_attn.weight``.
    """
    for old, new in _ORIG_OPTIMIZER_ATTR_RENAMES.items():
        if key == old or key.startswith(old + "."):
            return new + key[len(old):]
    return key


def load_optimizer_state(
    optimizer: "nn.Module",
    ckpt_path: str,
    init_state_wo_features: bool,
    strict: bool,
) -> None:
    """Load optimizer weights from ``ckpt_path`` into ``optimizer``.

    Transcribes the prefix-stripping / legacy-rename / feature-drop logic from
    ``optgs/main.py:load_optimizer`` (we cannot call that function: it needs a
    full Hydra ``cfg`` and a ``scene_trainer``).

    Args:
        optimizer: module whose ``load_state_dict`` receives the weights.
        ckpt_path: checkpoint file; either a raw state dict or a dict with a
            ``"state_dict"`` entry (Lightning layout).
        init_state_wo_features: drop every key containing ``"update_proj"``
            before loading.
        strict: forwarded to ``optimizer.load_state_dict``.

    Raises:
        OptGSError: if no ``optimizer.*`` or legacy ``encoder.*`` keys exist.
    """
    import torch

    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints. Recent torch releases default to ``weights_only=True``,
    # which may reject checkpoints carrying non-tensor payloads; confirm
    # against the torch version in use.
    state = torch.load(ckpt_path, map_location="cpu")
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]

    # Strip the Lightning "scene_trainer." prefix if present — only as a
    # leading prefix, never elsewhere inside a key.
    state = {_strip_prefix(k, "scene_trainer."): v for k, v in state.items()}

    if any(k.startswith("optimizer.") for k in state):
        # Unified repo format: keys are optimizer.*
        osd = {
            k[len("optimizer."):]: v
            for k, v in state.items()
            if k.startswith("optimizer.")
        }
    else:
        # Legacy Resplat format: keys are encoder.* (before init/opt split),
        # with some submodules renamed since.
        osd = {
            _rename_legacy_attr(k[len("encoder."):]): v
            for k, v in state.items()
            if k.startswith("encoder.")
        }

    if not osd:
        raise OptGSError(
            f"no optimizer weights found in {ckpt_path} (looked for "
            f"'optimizer.*' or legacy 'encoder.*' keys)."
        )
    if init_state_wo_features:
        osd = {k: v for k, v in osd.items() if "update_proj" not in k}
    optimizer.load_state_dict(osd, strict=strict)