File size: 19,523 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
"""Public API: use optgs's learned optimizer in external 3DGS codebases.

Typical inria (graphdeco-inria/gaussian-splatting) integration — replace the
hand-written training loop with three lines::

    from optgs.experimental.api import OptGS

    gaussians = GaussianModel(sh_degree)        # set up as usual (SfM init)
    scene = Scene(dataset, gaussians)
    optgs = OptGS(checkpoint="hf://org/repo/model.ckpt", device="cuda")
    optgs.initialize(scene)                      # ingest scene + build optimizer
    optgs.optimize(scene)                        # learned optimization, written back in place
    scene.save(iteration)                        # proceed as normal

Full-replacement semantics: ``optimize`` overwrites ``scene.gaussians`` in
place and nulls the inria Adam optimizer + densification accumulators. If you
later want to resume inria Adam, call ``gaussians.training_setup(...)`` again.

For non-inria codebases use :meth:`OptGS.initialize_from_ply` /
:meth:`OptGS.initialize_from_tensors` + :meth:`OptGS.export_ply`.

External SfM scenes carry no optgs encoder features, so checkpoints trained
with ``init_state_wo_features=False`` are coerced at construction (with a
warning): the feature-conditioned ``update_proj`` weights are dropped and the
optimizer state is initialized standard-normal.
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Sequence

import torch

from optgs.experimental.api.integration.scene_protocol import OptGSError

if TYPE_CHECKING:  # pragma: no cover - typing only
    from optgs.model.types import Gaussians

# Public surface of this module. OptGSError is defined in
# integration.scene_protocol and re-exported here so callers can catch it
# without importing the integration subpackage.
__all__ = ["OptGS", "OptGSError"]


class OptGS:
    """Facade around the learned per-scene optimizer.

    Lifecycle: construct with a checkpoint (rebuilds the learned optimizer and
    the renderer from the checkpoint's training config), ingest a scene with
    one of the ``initialize*`` methods, run :meth:`optimize` (or step through
    :meth:`optimize_iter`), and optionally persist with :meth:`export_ply`.
    """

    def __init__(
        self,
        checkpoint: str,
        *,
        device: str | torch.device = "cuda",
        num_refine: int | None = None,
        iter_batch_size: int | None = None,
        opt_batch_size: int | None = None,
        opt_batch_strategy: str | None = None,
        background_color: Sequence[float] | None = None,
        rasterize_mode: str | None = None,
        eps2d: float | None = None,
        strict_load: bool = True,
    ) -> None:
        """Load the learned optimizer and renderer described by ``checkpoint``.

        Args:
            checkpoint: ``hf://org/repo/file`` reference or local checkpoint
                path. The training ``config.yaml`` must be findable (HF
                sibling file, or located relative to the local checkpoint).
            device: torch device; must be CUDA.
            num_refine: number of optimization steps. Defaults to the config's
                ``scene_trainer.num_update_steps``.
            iter_batch_size: render-batching size. Defaults to the config's
                ``scene_trainer.iter_batch_size`` (-1 = all views per step).
            opt_batch_size: views fed to the optimizer per step. Defaults to
                the config's ``scene_trainer.opt_batch_size`` (-1 = all).
            opt_batch_strategy: ``'random'``, ``'sequential'`` or ``'fps'``.
                Defaults to the config's ``scene_trainer.opt_batch_strategy``,
                falling back to ``'random'``.
            background_color: renderer background; ``[0, 0, 0]`` if omitted.
            rasterize_mode, eps2d: optional overrides of the checkpoint's
                decoder config.
            strict_load: forwarded as ``strict`` when loading optimizer state.

        Raises:
            OptGSError: empty checkpoint, non-CUDA device, missing config,
                undeterminable ``num_refine``, or unsupported batch strategy.
        """
        if not checkpoint:
            raise OptGSError(
                "OptGS(checkpoint=...) is required (an 'hf://org/repo/file' "
                "reference or a local checkpoint path)."
            )
        self.device = torch.device(device)
        if self.device.type != "cuda":
            raise OptGSError(
                "OptGS requires a CUDA device (the learned optimizer uses "
                "CUDA/KNN kernels). Pass device='cuda'."
            )
        # float32 only — the learned optimizer's CUDA/KNN kernels and the
        # gsplat rasterizer require it (and the checkpoint trained with it).
        self.dtype = torch.float32
        self.iter_batch_size = iter_batch_size
        self.opt_batch_size = opt_batch_size
        self.opt_batch_strategy = opt_batch_strategy

        from optgs.config import _find_config_for_checkpoint
        from optgs.experimental.api.integration.config_bridge import (
            build_decoder,
            build_optimizer,
            build_optimizer_cfg,
            get_scene_trainer_scalar,
            load_optimizer_state,
        )
        from optgs.misc.hf_ckpt import hf_sibling_config, maybe_resolve_hf_ref

        local_ckpt = maybe_resolve_hf_ref(checkpoint)
        # For hf:// refs, hf_hub_download fetches only the ckpt; pull the
        # sibling config.yaml so the architecture can be rebuilt.
        cfg_path = hf_sibling_config(checkpoint) or _find_config_for_checkpoint(local_ckpt)
        if cfg_path is None:
            raise OptGSError(
                f"no config.yaml found next to checkpoint {local_ckpt!r} "
                f"(looked for <ckpt>/../../config.yaml and the wandb "
                f"latest-run fallback). OptGS needs the training config to "
                f"rebuild the optimizer architecture."
            )

        opt_cfg, num_update_steps = build_optimizer_cfg(cfg_path)

        if not getattr(opt_cfg, "init_state_wo_features", False):
            warnings.warn(
                "this checkpoint was trained WITH encoder features "
                "(scene_trainer.scene_optimizer.init_state_wo_features=False). "
                "External SfM/inria scenes carry no optgs encoder features; "
                "proceeding with init_state_wo_features=True — the "
                "feature-conditioned update_proj weights are dropped and the "
                "initial optimizer state is set to a standard-normal random "
                "vector (init_state_type='random', init_state_scale=1.0)."
            )
            opt_cfg.init_state_wo_features = True
            opt_cfg.init_state_type = "random"
            opt_cfg.init_state_scale = 1.0

        optimizer = build_optimizer(opt_cfg)  # asserts cfg.name; nn.Module
        load_optimizer_state(
            optimizer, local_ckpt, init_state_wo_features=True, strict=strict_load
        )
        self.optimizer = optimizer.to(device=self.device, dtype=self.dtype).eval()

        from types import SimpleNamespace

        bg = list(background_color) if background_color is not None else [0.0, 0.0, 0.0]
        # Build the renderer the checkpoint trained with (gsplat by default;
        # NOT a hardcoded backend — see build_decoder). rasterize_mode / eps2d,
        # when given, override the checkpoint's decoder config.
        decoder_overrides = {
            k: v
            for k, v in (("rasterize_mode", rasterize_mode), ("eps2d", eps2d))
            if v is not None
        }
        self.decoder = build_decoder(
            cfg_path, SimpleNamespace(background_color=bg), decoder_overrides
        ).to(self.device)

        resolved = num_refine if num_refine is not None else num_update_steps
        if resolved is None:
            raise OptGSError(
                "num_refine could not be determined: pass OptGS(num_refine=...) "
                "or use a checkpoint whose config has "
                "scene_trainer.num_update_steps."
            )
        self.num_refine = int(resolved)

        # Render-batching size: user override, else the checkpoint's
        # scene_trainer.iter_batch_size (-1 = render all views per step).
        if self.iter_batch_size is None:
            self.iter_batch_size = int(
                get_scene_trainer_scalar(cfg_path, "iter_batch_size", -1)
            )

        # Per-step view minibatch — opt_batch_size views are fed to the
        # optimizer each step (the checkpoint's scene_trainer.opt_batch_size /
        # opt_batch_strategy, i.e. the regime it was trained with). -1 = all.
        if self.opt_batch_size is None:
            self.opt_batch_size = int(
                get_scene_trainer_scalar(cfg_path, "opt_batch_size", -1)
            )
        if self.opt_batch_strategy is None:
            self.opt_batch_strategy = str(
                get_scene_trainer_scalar(cfg_path, "opt_batch_strategy", "random")
            )
        if self.opt_batch_strategy not in ("random", "sequential", "fps"):
            raise OptGSError(
                f"opt_batch_strategy={self.opt_batch_strategy!r} is not supported "
                f"by the API (supported: 'random', 'sequential', 'fps'). Pass "
                f"OptGS(opt_batch_strategy='random')."
            )

        self._opt_cfg = opt_cfg
        # SH degree the checkpoint's Gaussians use — derived from the optimizer
        # cfg's init_sh_d (= (sh_degree + 1) ** 2, set by opt_cfg.update from the
        # initializer cfg). API consumers build/render Gaussians with this; it is
        # dictated by the checkpoint, not a free choice.
        self.sh_degree = int(round(opt_cfg.init_sh_d ** 0.5)) - 1
        self._initialized = False
        self._scene_ref = None
        self._context = None
        self._init_output = None
        self._refined: "Gaussians | None" = None

    # ------------------------------------------------------------------
    # Ingest
    # ------------------------------------------------------------------

    def initialize(self, scene: object) -> "OptGS":
        """Ingest an already-initialized inria-style scene.

        This does NOT run optgs's learned Initializer — the scene already has
        Gaussians (e.g. from SfM / inria ``create_from_pcd``). Any previously
        refined result is discarded, so :meth:`export_ply` cannot export a
        stale scene.
        """
        from optgs.experimental.api.integration.inria_bridge import (
            batched_views_from_cameras,
            optgs_gaussians_from_inria_model,
        )
        from optgs.experimental.api.integration.scene_protocol import (
            assert_scene_protocol,
        )
        from optgs.scene_trainer.initializer.initializer import InitializerOutput

        assert_scene_protocol(scene)
        g = optgs_gaussians_from_inria_model(
            scene.gaussians, device=self.device, dtype=self.dtype
        )
        self._init_output = InitializerOutput(gaussians=g, features=None, depths=None)
        self._context = batched_views_from_cameras(
            list(scene.getTrainCameras()),
            scene_scale=float(scene.cameras_extent),
            device=self.device,
            dtype=self.dtype,
        )
        self._scene_ref = scene
        self._refined = None
        self._initialized = True
        return self

    def initialize_from_ply(
        self,
        ply_path: str,
        cameras: Sequence[object],
        *,
        sh_degree: int,
        scene_scale: float,
    ) -> "OptGS":
        """Low-level ingest for non-inria codebases (no inria ``Scene``).

        ``cameras`` is a sequence of inria-``Camera``-like objects (``R``,
        ``T``, ``FoVx``, ``FoVy``, ``image_width``, ``image_height``,
        ``original_image``). Any previously refined result is discarded.
        """
        from optgs.experimental.api.integration.inria_bridge import (
            batched_views_from_cameras,
            optgs_gaussians_from_ply,
        )
        from optgs.scene_trainer.initializer.initializer import InitializerOutput

        g = optgs_gaussians_from_ply(
            ply_path, sh_degree=sh_degree, device=self.device, dtype=self.dtype
        )
        self._init_output = InitializerOutput(gaussians=g, features=None, depths=None)
        self._context = batched_views_from_cameras(
            list(cameras), scene_scale=scene_scale, device=self.device, dtype=self.dtype
        )
        self._scene_ref = None
        self._refined = None
        self._initialized = True
        return self

    def initialize_from_tensors(self, gaussians: object, batched_views: object) -> "OptGS":
        """Low-level ingest from optgs-native objects (power users).

        ``gaussians``: an optgs ``Gaussians`` (batch=1, post-activation).
        ``batched_views``: an optgs ``BatchedViews`` or a dict accepted by
        ``BatchedViews.from_dict``. Any previously refined result is discarded.

        Raises:
            OptGSError: ``gaussians`` is not an optgs ``Gaussians``.
        """
        from optgs.dataset.data_types import BatchedViews
        from optgs.model.types import Gaussians
        from optgs.scene_trainer.initializer.initializer import InitializerOutput

        if not isinstance(gaussians, Gaussians):
            raise OptGSError(
                "initialize_from_tensors expects an optgs Gaussians instance "
                "(use initialize_from_ply for raw 3DGS PLY input)."
            )
        bv = (
            batched_views
            if isinstance(batched_views, BatchedViews)
            else BatchedViews.from_dict(batched_views)
        )
        self._init_output = InitializerOutput(
            gaussians=gaussians.to(device=self.device, dtype=self.dtype),
            features=None,
            depths=None,
        )
        self._context = bv
        self._scene_ref = None
        self._refined = None
        self._initialized = True
        return self

    # ------------------------------------------------------------------
    # Optimize
    # ------------------------------------------------------------------

    def _view_minibatch(self, views):
        """Sample the next per-step view minibatch from ``views``.

        Mirrors SceneTrainer's viewpoint-stack cycling: views are drawn
        ``opt_batch_size`` at a time and the stack is refilled once exhausted,
        so every view is seen before any repeats. ``random``/``sequential`` take
        the front of the (shuffled/ordered) stack; ``fps`` picks a
        farthest-point spread over the remaining views' camera positions.
        Returns ``views`` unchanged when ``opt_batch_size`` is <= 0 or already
        covers the whole scene.
        """
        v = views.image.shape[1]
        bs = self.opt_batch_size
        if bs <= 0 or bs >= v:
            return views

        views.reset_viewpoint_stack_if_needed(self.opt_batch_strategy, bs)
        stack = views.viewpoint_stack  # [B, V_stack]

        if self.opt_batch_strategy == "fps":
            from optgs.dataset.view_sampler.view_sampler_bounded_v2 import (
                farthest_point_sample,
            )

            b = stack.shape[0]
            arange = torch.arange(b, device=stack.device)[:, None]
            # FPS over the camera positions of the views still in the stack.
            positions = views.extrinsics[arange, stack][:, :, :3, 3]  # [B, V_stack, 3]
            local = farthest_point_sample(positions, bs, first_idx_strategy="random")
            idx = stack[arange, local]  # [B, bs]
            keep = ~(stack.unsqueeze(-1) == idx.unsqueeze(1)).any(-1)  # [B, V_stack]
            views.viewpoint_stack = stack[keep].view(b, -1)
        else:  # random / sequential — take the front of the stack
            idx = stack[:, :bs]
            views.viewpoint_stack = stack[:, bs:]
        return views.batchify_views(idx)

    def _begin_scene(self, opt):
        """Build the optimizer input for the ingested scene and start ``opt``.

        Returns ``(inp, out)``: ``inp.prev_output`` has been converted by
        ``opt.on_scene_start`` (InitializerOutput -> OptimizerPreviousOutput),
        and ``out`` is an empty ``OptimizerOutput`` with ``T = num_refine``.

        Raises:
            OptGSError: ``on_scene_start`` did not produce an
                ``OptimizerPreviousOutput``.
        """
        from optgs.scene_trainer.optimizer.optimizer import (
            OptimizerInput,
            OptimizerOutput,
            OptimizerPreviousOutput,
        )

        inp = OptimizerInput(
            context=self._context,
            renderer=self.decoder,
            prev_output=self._init_output,
            num_refine=self.num_refine,
            iter_batch_size=self.iter_batch_size,
            target=self._context,
        )
        opt.validate_input(inp)
        opt.on_scene_start(inp)  # InitializerOutput -> OptimizerPreviousOutput (+ADC)
        if not isinstance(inp.prev_output, OptimizerPreviousOutput):
            raise OptGSError(
                "optimizer.on_scene_start did not produce an "
                f"OptimizerPreviousOutput (got {type(inp.prev_output)})."
            )

        out = OptimizerOutput.empty(t=0)
        out.T = self.num_refine
        return inp, out

    def _run_step(self, opt, step, inp, out):
        """Run one optimization step of ``opt`` and return its new output."""
        # Feed the optimizer a fresh view minibatch each step (the regime it
        # was trained with); full_context/full_target stay the whole scene.
        batch = self._view_minibatch(self._context)
        inp.context = batch
        inp.target = batch
        out = opt(
            step, inp, out, full_context=self._context, full_target=self._context
        )
        out.t = (out.t or 0) + 1
        return out

    @staticmethod
    def _end_scene(opt) -> None:
        """Wait for queued CUDA work, then let ``opt`` tear down per-scene state."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        opt.on_scene_end()

    @torch.no_grad()
    def optimize(self, scene: object | None = None, *, optimizer=None):
        """Run the learned optimization.

        inria path: refined Gaussians are written back into ``scene.gaussians``
        in place and ``scene.gaussians`` is returned. Low-level path: the
        refined optgs ``Gaussians`` is returned (use :meth:`export_ply` to
        persist). Passing a scene other than the last ingested one ingests it
        first via :meth:`initialize`.

        ``optimizer`` swaps in a different optgs ``Optimizer`` (e.g. an Adam
        baseline) — running the *same* per-scene pipeline (init, view minibatch,
        step budget, renderer) with another update rule, i.e. a fair
        comparison. Defaults to the checkpoint's learned optimizer.

        ``optimizer.on_scene_end()`` runs even if a step raises.

        Raises:
            OptGSError: nothing has been ingested, or the optimizer's
                ``on_scene_start`` produced an unexpected output type.
        """
        if scene is not None and scene is not self._scene_ref:
            self.initialize(scene)
        if not self._initialized:
            raise OptGSError("call initialize(scene) before optimize().")

        opt = optimizer if optimizer is not None else self.optimizer
        inp, out = self._begin_scene(opt)

        steps = range(self.num_refine)
        try:
            from tqdm import tqdm
        except ImportError:  # progress bar is optional
            pass
        else:
            steps = tqdm(steps, desc=f"optimize[{type(opt).__name__}]")

        try:
            for step in steps:
                out = self._run_step(opt, step, inp, out)
        finally:
            # Release per-scene optimizer state even when a step fails, so the
            # optimizer can be reused (same guarantee as optimize_iter).
            self._end_scene(opt)

        final = inp.prev_output.gaussians
        self._refined = final

        if self._scene_ref is not None:
            from optgs.experimental.api.integration.inria_bridge import (
                write_back_to_inria_model,
            )

            write_back_to_inria_model(self._scene_ref.gaussians, final)
            return self._scene_ref.gaussians
        return final

    def optimize_iter(self, *, optimizer=None):
        """Generator form of :meth:`optimize`: yields ``(step, gaussians)`` after
        each optimization step.

        Lets a caller drive the learned optimization one step at a time and
        render the Gaussians in between — used by ``demo.py``'s ``--with-gui``.
        ``on_scene_end()`` runs even if the caller closes the generator early
        (e.g. a GUI Reset). Unlike :meth:`optimize`, results are NOT written
        back into an inria ``scene.gaussians``; the latest Gaussians are kept
        for :meth:`export_ply`.

        Gradient tracking is disabled only while an optimization step runs;
        the caller's own grad mode between steps is left untouched.

        Raises:
            OptGSError: immediately (not on first ``next()``) if nothing has
                been ingested.
        """
        if not self._initialized:
            raise OptGSError("call initialize(...) before optimize_iter().")

        opt = optimizer if optimizer is not None else self.optimizer
        return self._optimize_iter_steps(opt)

    @torch.no_grad()
    def _optimize_iter_steps(self, opt):
        """Generator body of :meth:`optimize_iter`.

        Decorating a generator function with ``torch.no_grad()`` re-enters the
        no-grad mode around every resume and restores the caller's mode at each
        ``yield``; a ``with torch.no_grad():`` block spanning a ``yield`` would
        instead leave grad tracking disabled in the caller between steps.
        """
        inp, out = self._begin_scene(opt)
        try:
            for step in range(self.num_refine):
                out = self._run_step(opt, step, inp, out)
                yield step, inp.prev_output.gaussians
        finally:
            self._end_scene(opt)
            self._refined = inp.prev_output.gaussians

    def export_ply(self, path: str) -> None:
        """Write the most recently refined Gaussians to a 3DGS PLY at ``path``.

        Raises:
            OptGSError: no refined result exists for the currently ingested
                scene (call :meth:`optimize` / :meth:`optimize_iter` first).
        """
        if self._refined is None:
            raise OptGSError("nothing to export — call optimize() first.")
        from pathlib import Path

        from optgs.model.ply_export import save_gaussian_ply

        save_gaussian_ply(self._refined, save_path=Path(path))