File size: 17,007 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
"""Hydra-free checkpoint -> optimizer construction.

Rebuilds the learned optimizer (architecture + weights) from a checkpoint
*without* going through Hydra. Only ``_load_checkpoint_cfg`` +
``load_typed_config`` are used (both Hydra-free); the Hydra coupling lives in
``setup_cfg`` / ``merge_config_from_file`` / ``setup_output_dir`` which we never
call. All heavy imports are deferred into the functions so ``import optgs``
stays cheap.
"""

from __future__ import annotations

from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING

from optgs.experimental.api.integration.scene_protocol import OptGSError


@lru_cache(maxsize=8)
def _load_ckpt_cfg_cached(cfg_path_str: str):
    """Parse (and migrate) the checkpoint config at ``cfg_path_str``, memoized.

    ``build_optimizer_cfg``, ``build_decoder`` and ``get_scene_trainer_scalar``
    all read the same DictConfig, so it is parsed once per path string and the
    *same* object is handed back on every later call. Callers must therefore
    treat the result as read-only. The cache key is the exact string, so two
    different spellings of one path are parsed and cached separately.
    """
    from optgs.config import _load_checkpoint_cfg  # Hydra-free

    config_file = Path(cfg_path_str)
    return _load_checkpoint_cfg(config_file)


def get_scene_trainer_scalar(cfg_path: Path, key: str, default):
    """Return ``scene_trainer.<key>`` from the checkpoint config at ``cfg_path``.

    ``default`` is returned when the key is not found. This is meant for
    scalars stored on the checkpoint's scene-trainer section rather than on the
    optimizer cfg (the scene-trainer config itself is not rebuilt on the
    Hydra-free path), e.g. ``num_update_steps``, ``iter_batch_size``,
    ``sh_degree_interval``.
    """
    from omegaconf import OmegaConf

    dotted_key = f"scene_trainer.{key}"
    checkpoint_cfg = _load_ckpt_cfg_cached(str(cfg_path))
    return OmegaConf.select(checkpoint_cfg, dotted_key, default=default)

if TYPE_CHECKING:  # pragma: no cover - typing only
    from torch import nn

    from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizerCfg


def _optimizer_class_by_cfg_name():
    """Return a dict mapping a config's ``scene_optimizer.name`` to its class.

    ``SCENE_OPTIMIZERS`` is keyed by registry names (e.g. ``"depthsplat"``),
    whereas a checkpoint config stores the cfg literal (``"knn_based"``,
    ``"l2s"``, ``"resplat_v1"``, ``"resplat_v2"``). Every class accepts its
    ``OPTIMIZER_NAME`` plus any ``OPTIMIZER_NAME_ALIASES`` (such as the legacy
    ``"clogs"`` for ``Learn2SplatOptimizer``), so both are registered here.
    If two classes claim the same name, the later class in the tuple wins.
    """
    from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizer
    from optgs.scene_trainer.optimizer.optimizer_learn2splat import (
        Learn2SplatOptimizer,
    )
    from optgs.scene_trainer.optimizer.optimizer_resplat import (
        ResplatOptimizerV1,
        ResplatOptimizerV2,
    )

    supported = (
        KnnBasedOptimizer,
        Learn2SplatOptimizer,
        ResplatOptimizerV1,
        ResplatOptimizerV2,
    )
    return {
        accepted_name: optimizer_cls
        for optimizer_cls in supported
        for accepted_name in (
            optimizer_cls.OPTIMIZER_NAME,
            *getattr(optimizer_cls, "OPTIMIZER_NAME_ALIASES", ()),
        )
    }


def _initializer_cfg_class(name: str):
    """Return the concrete typed Cfg dataclass for ``scene_initializer.name``.

    ``InitializerCfg`` is a PEP-604 union, and dacite can only resolve a union
    when it appears as a *field* type; the top-level parse target has to be a
    concrete dataclass. Keys agree with both ``SCENE_INITIALIZERS`` and each
    Cfg's ``name`` Literal. Returns ``None`` for an unknown name.
    """
    from optgs.scene_trainer.initializer import (
        InitializerColmapCfg,
        InitializerEdgsCfg,
        InitializerPlyCfg,
        InitializerPointcloudCfg,
        InitializerRandomCfg,
        ResplatInitializerCfg,
    )

    # Both Resplat generations share one initializer cfg.
    cfg_by_name = dict.fromkeys(("resplat_v1", "resplat_v2"), ResplatInitializerCfg)
    cfg_by_name.update(
        colmap=InitializerColmapCfg,
        ply=InitializerPlyCfg,
        edgs=InitializerEdgsCfg,
        random=InitializerRandomCfg,
        pointcloud=InitializerPointcloudCfg,
    )
    return cfg_by_name.get(name)


def _compose_default_group(group: str, value: str):
    """Hydra-compose the bundled default for ``scene_trainer.<group>=<value>``.

    Released checkpoints predate fields later added to the typed configs
    (e.g. ``scene_optimizer.refiner.fallback_means_lr``). The training/eval
    pipeline reconciles this by merging the checkpoint config over the
    *current* default config (config.py:merge_config_from_file). We mirror
    that: compose the bundled default for the group (e.g.
    ``scene_optimizer=knn_based`` -> base -> refiner:none, or
    ``scene_initializer=colmap``) so missing fields can be backfilled with
    current defaults while checkpoint values win for shared keys.

    Scoped use of ``hydra.compose`` (no ``@hydra.main`` / ``HydraConfig.get``,
    no app context); lazily imported so ``import optgs`` stays light. Returns
    ``None`` if composition fails (caller falls back to a strict parse).
    """
    try:
        import optgs
        from hydra import compose, initialize_config_dir
        from hydra.core.global_hydra import GlobalHydra
        from omegaconf import OmegaConf

        config_dir = str(Path(optgs.__file__).resolve().parent / "config")
        GlobalHydra.instance().clear()
        try:
            with initialize_config_dir(version_base=None, config_dir=config_dir):
                composed = compose(
                    config_name="main",
                    overrides=[f"scene_trainer/{group}={value}"],
                )
        finally:
            GlobalHydra.instance().clear()
        return OmegaConf.select(composed, f"scene_trainer.{group}")
    except Exception as e:  # noqa: BLE001 - best-effort backfill
        print(
            f"[optgs] warning: could not compose default scene_trainer.{group}"
            f"={value} for back-compat merge ({type(e).__name__}: {e}); "
            f"parsing checkpoint config as-is."
        )
        return None


def _merge_over_default(group: str, value: str, node):
    """Backfill ``node`` with the bundled default for ``scene_trainer.<group>=<value>``.

    Mirrors config.py:merge_config_from_file's ``OmegaConf.merge``. Fields
    missing from an older checkpoint come from the current default, and
    checkpoint values (``node``) win for shared keys. If the default cannot be
    composed (``_compose_default_group`` returned ``None``), ``node`` is
    returned unchanged so the caller parses it strictly.
    """
    from omegaconf import OmegaConf

    default = _compose_default_group(group, value)
    if default is None:
        return node
    # Relax struct mode so keys present only in the checkpoint can be merged in.
    OmegaConf.set_struct(default, False)
    return OmegaConf.merge(default, node)


def build_optimizer_cfg(cfg_path: Path) -> tuple["KnnBasedOptimizerCfg", int | None]:
    """Load a checkpoint's saved config and return its typed optimizer cfg.

    Args:
        cfg_path: Path to the checkpoint's saved config file.

    Returns:
        ``(opt_cfg, num_update_steps)``.

        - ``opt_cfg`` is the ``KnnBasedOptimizerCfg`` parsed from
          ``scene_trainer.scene_optimizer``. It is backfilled with current
          defaults and wired with the checkpoint's initializer cfg.
        - ``num_update_steps`` is ``scene_trainer.num_update_steps`` (the
          per-scene optimization step count, which is NOT part of the
          optimizer cfg), or ``None`` if absent.

    Raises:
        OptGSError: if the config has no learned scene_optimizer or no
            scene_initializer, if the initializer name is unsupported, or if
            either section fails to parse.
    """
    from omegaconf import OmegaConf

    from optgs.config import load_typed_config
    from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizerCfg

    cfg = _load_ckpt_cfg_cached(str(cfg_path))  # read_omega_cfg + migrate; NO Hydra
    so = OmegaConf.select(cfg, "scene_trainer.scene_optimizer")
    name = OmegaConf.select(cfg, "scene_trainer.scene_optimizer.name")
    if so is None or name in (None, "none"):
        raise OptGSError(
            f"checkpoint config at {cfg_path} has no learned scene_optimizer "
            f"(scene_trainer.scene_optimizer={name!r}). OptGS needs a learned "
            f"optimizer checkpoint (knn_based / l2s (legacy: clogs) / "
            f"resplat_v1 / resplat_v2)."
        )
    # Backfill fields that a released (older) checkpoint config lacks.
    # NOTE(review): the defaults always come from the knn_based group, whatever
    # ``name`` is. Presumably this is because every learned optimizer is parsed
    # into KnnBasedOptimizerCfg below; confirm for l2s / resplat checkpoints.
    merged_so = _merge_over_default("scene_optimizer", "knn_based", so)
    try:
        opt_cfg = load_typed_config(merged_so, KnnBasedOptimizerCfg)
    except Exception as e:  # dacite/omegaconf errors -> actionable message
        raise OptGSError(
            f"failed to parse scene_optimizer from {cfg_path} into "
            f"KnnBasedOptimizerCfg ({type(e).__name__}: {e})."
        ) from e

    # Mirror SceneTrainerCfg (scene_trainer_cfg.py: scene_optimizer.update(
    # scene_initializer)). Wiring the checkpoint's initializer cfg into the
    # optimizer cfg populates the runtime-only fields init_gaussian_param_num /
    # init_sh_d / sh_d, which are absent from every config file, before the
    # optimizer nn.Module is built.
    si = OmegaConf.select(cfg, "scene_trainer.scene_initializer")
    si_name = OmegaConf.select(cfg, "scene_trainer.scene_initializer.name")
    if si is None or si_name in (None, "none"):
        raise OptGSError(
            f"checkpoint config at {cfg_path} has no scene_initializer "
            f"(name={si_name!r}); cannot derive init_gaussian_param_num "
            f"required to build the optimizer."
        )
    init_cls = _initializer_cfg_class(str(si_name))
    if init_cls is None:
        raise OptGSError(
            f"unsupported scene_initializer.name={si_name!r} in {cfg_path}; "
            f"cannot derive init_gaussian_param_num for the optimizer."
        )
    merged_si = _merge_over_default("scene_initializer", str(si_name), si)
    try:
        init_cfg = load_typed_config(merged_si, init_cls)
        opt_cfg.update(init_cfg)  # sets init_gaussian_param_num/init_sh_d/sh_d
    except Exception as e:
        raise OptGSError(
            f"failed to wire scene_initializer ({si_name!r}) into the "
            f"optimizer cfg from {cfg_path} ({type(e).__name__}: {e})."
        ) from e

    num_update_steps = OmegaConf.select(
        cfg, "scene_trainer.num_update_steps", default=None
    )
    return opt_cfg, num_update_steps


def build_decoder(
    cfg_path: Path, dataset_cfg: object, decoder_overrides: dict | None = None
) -> "nn.Module":
    """Build the renderer the checkpoint was trained with.

    Uses ``scene_trainer.decoder`` from the checkpoint config, NOT a hardcoded
    backend. The learned optimizer's in-loop render gradients must match the
    backend it trained with. Only registered/available backends are usable:
    ``gsplat`` is the optgs default, while the ``inria`` backend needs the
    optional ``diff_gaussian_rasterization`` package.

    Args:
        cfg_path: Checkpoint config file to read ``scene_trainer.decoder`` from.
        dataset_cfg: Only needs a ``background_color`` attribute.
        decoder_overrides: Decoder fields (e.g. ``rasterize_mode`` / ``eps2d``)
            that take precedence over the checkpoint config. They are applied
            only when the checkpoint's decoder is ``gsplat``; for any other
            backend they are ignored.

    Raises:
        OptGSError: if the config has no decoder section, the decoder cfg fails
            to parse, or the decoder backend is not available.
    """
    import inspect

    from omegaconf import OmegaConf

    from optgs.config import load_typed_config
    from optgs.model.decoder import DecoderCfg, get_decoder

    cfg = _load_ckpt_cfg_cached(str(cfg_path))
    node = OmegaConf.select(cfg, "scene_trainer.decoder")
    if node is None:
        raise OptGSError(
            f"checkpoint config at {cfg_path} has no scene_trainer.decoder; "
            f"cannot rebuild the renderer the optimizer trained with."
        )
    # gsplat decoder rasterize_mode / eps2d, by precedence:
    #   caller override  >  checkpoint config  >  gsplat rasterization() default
    # (so an older checkpoint that omits a field behaves as plain gsplat would).
    if OmegaConf.select(node, "name") == "gsplat":
        try:
            from gsplat.rendering import rasterization
        except ImportError as e:
            # Surface a missing backend the same way get_decoder failures are.
            raise OptGSError(
                f"decoder backend 'gsplat' is not available in this "
                f"environment ({type(e).__name__}: {e}). Install gsplat or use "
                f"a checkpoint trained with an available decoder."
            ) from e

        params = inspect.signature(rasterization).parameters
        # Skip parameters that have no default (Parameter.empty is not a
        # config value).
        library_defaults = {
            field: params[field].default
            for field in ("rasterize_mode", "eps2d")
            if field in params and params[field].default is not inspect.Parameter.empty
        }
        node = OmegaConf.merge(
            OmegaConf.create(library_defaults),
            node,
            OmegaConf.create(dict(decoder_overrides or {})),
        )
    try:
        decoder_cfg = load_typed_config(node, DecoderCfg)
    except Exception as e:
        raise OptGSError(
            f"failed to parse scene_trainer.decoder from {cfg_path} "
            f"({type(e).__name__}: {e})."
        ) from e
    try:
        return get_decoder(decoder_cfg, dataset_cfg)
    except (KeyError, ImportError) as e:
        raise OptGSError(
            f"decoder backend {decoder_cfg.name!r} is not available in this "
            f"environment ({type(e).__name__}: {e}). Install its backend "
            f"(e.g. diff_gaussian_rasterization for 'inria') or use a "
            f"checkpoint trained with the 'gsplat' decoder."
        ) from e


def build_optimizer(opt_cfg: "KnnBasedOptimizerCfg") -> "nn.Module":
    """Instantiate the learned optimizer class selected by ``opt_cfg.name``.

    No weights are loaded here. Raises ``OptGSError`` listing the supported
    names when ``opt_cfg.name`` has no matching optimizer class.
    """
    from optgs.misc.io import FrequencyScheduler

    registry = _optimizer_class_by_cfg_name()
    optimizer_cls = registry.get(opt_cfg.name)
    if optimizer_cls is None:
        supported = sorted(registry)
        raise OptGSError(
            f"unsupported scene_optimizer.name={opt_cfg.name!r}; OptGS supports "
            f"{supported}."
        )

    optimizer = optimizer_cls(opt_cfg)
    # During training SceneTrainer wires ``save_every`` (info/context/target/
    # debug artifact dumps), and the optimizer calls it unconditionally. The
    # API inference path has nothing to dump, so it installs a disabled
    # scheduler rather than leaving the attribute unset.
    disabled_scheduler = FrequencyScheduler(last_step=0)
    disabled_scheduler.disable(True)
    optimizer.save_every = disabled_scheduler
    return optimizer


def build_adam_baseline(num_refine: int) -> "nn.Module":
    """Return a ready-to-run 3DGS ``AdamOptimizer`` for a fair baseline.

    The hyperparameters (LRs, betas) come from the bundled
    ``scene_optimizer=3dgs`` config, i.e. gsplat's example settings.
    Densification, pruning and opacity reset are switched off, so the baseline
    refines the same fixed Gaussian set as the learned optimizer. This makes it
    a like-for-like comparison of the update rule. The means-LR decay horizon
    is set to ``num_refine``.

    Raises:
        OptGSError: if the bundled config cannot be composed or parsed.
    """
    from omegaconf import OmegaConf

    from optgs.config import load_typed_config
    from optgs.misc.io import FrequencyScheduler
    from optgs.scene_trainer.optimizer.optimizer_adam import (
        AdamOptimizer,
        AdamOptimizerCfg,
    )

    adam_node = _compose_default_group("scene_optimizer", "3dgs")
    if adam_node is None:
        raise OptGSError(
            "could not Hydra-compose the bundled 'scene_optimizer=3dgs' config "
            "for the Adam baseline."
        )
    OmegaConf.set_struct(adam_node, False)

    # Decay the means LR over the whole refinement budget, as gsplat does.
    adam_node.means_lr_max_steps = int(num_refine)

    # Keep the Gaussian set fixed: turn off every topology-changing step the
    # refiner config exposes.
    topology_flags = ("do_densify", "do_prune", "do_opacity_reset")
    present_flags = [flag for flag in topology_flags if flag in adam_node.refiner]
    for flag in present_flags:
        adam_node.refiner[flag] = False

    try:
        adam_cfg = load_typed_config(adam_node, AdamOptimizerCfg)
    except Exception as e:
        raise OptGSError(
            f"failed to parse the bundled '3dgs' config into AdamOptimizerCfg "
            f"({type(e).__name__}: {e})."
        ) from e

    adam = AdamOptimizer(adam_cfg)
    # Nothing to dump on the inference path: install a disabled scheduler.
    scheduler = FrequencyScheduler(last_step=0)
    scheduler.disable(True)
    adam.save_every = scheduler
    # AdamOptimizer is a NonlearnedOptimizer — already pinned to eval mode.
    return adam


# Module-attribute renames applied when the legacy Resplat encoder was split
# into separate initializer/optimizer modules (transcribed from
# optgs/main.py:load_optimizer).
_ORIG_OPTIMIZER_ATTR_RENAMES = {
    "render_error_mv_attn": "update_error_attn",
}


def load_optimizer_state(
    optimizer: "nn.Module",
    ckpt_path: str,
    init_state_wo_features: bool,
    strict: bool,
) -> None:
    """Load optimizer weights from ``ckpt_path`` into ``optimizer``.

    Transcribes the prefix-stripping / legacy-rename / feature-drop logic from
    ``optgs/main.py:load_optimizer``. That function cannot be called here
    because it needs a full Hydra ``cfg`` and a ``scene_trainer``.

    Args:
        optimizer: Module whose ``load_state_dict`` receives the weights.
        ckpt_path: Checkpoint file. It holds either a raw state dict or a dict
            with the state dict under ``"state_dict"``.
        init_state_wo_features: If true, drop every key containing
            ``"update_proj"`` before loading.
        strict: Forwarded to ``optimizer.load_state_dict``.

    Raises:
        OptGSError: if neither ``optimizer.*`` nor legacy ``encoder.*`` keys
            are present.
    """
    import torch

    def strip_prefix(key: str, prefix: str) -> str:
        # Remove ``prefix`` only when it is at the start of ``key``.
        return key[len(prefix):] if key.startswith(prefix) else key

    def rename_legacy(key: str) -> str:
        # Map a pre-split module attribute (and its sub-keys) to its new name.
        for old, new in _ORIG_OPTIMIZER_ATTR_RENAMES.items():
            if key == old or key.startswith(old + "."):
                return new + key[len(old):]
        return key

    # NOTE(review): depending on the torch version's ``weights_only`` default,
    # torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    state = torch.load(ckpt_path, map_location="cpu")
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    # Strip the Lightning "scene_trainer." prefix if present. Only a *leading*
    # prefix is removed, so keys that merely contain "scene_trainer."
    # elsewhere keep their names.
    state = {strip_prefix(k, "scene_trainer."): v for k, v in state.items()}

    if any(k.startswith("optimizer.") for k in state):
        # Unified repo format: keys are optimizer.*
        osd = {
            strip_prefix(k, "optimizer."): v
            for k, v in state.items()
            if k.startswith("optimizer.")
        }
    else:
        # Legacy Resplat format: keys are encoder.* (from before the init/opt
        # split), with the module-attribute renames from that split applied.
        osd = {
            rename_legacy(strip_prefix(k, "encoder.")): v
            for k, v in state.items()
            if k.startswith("encoder.")
        }

    if not osd:
        raise OptGSError(
            f"no optimizer weights found in {ckpt_path} (looked for "
            f"'optimizer.*' or legacy 'encoder.*' keys)."
        )

    if init_state_wo_features:
        osd = {k: v for k, v in osd.items() if "update_proj" not in k}

    optimizer.load_state_dict(osd, strict=strict)