SteEsp committed on
Commit
78d2329
·
verified ·
1 Parent(s): 7e6090c

Add Docker-based Learn2Splat demo (viser GUI)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +32 -0
  2. .gitattributes +6 -0
  3. Dockerfile +76 -0
  4. LICENSE +21 -0
  5. README.md +19 -8
  6. demo.py +766 -0
  7. optgs/__init__.py +1 -0
  8. optgs/config.py +770 -0
  9. optgs/config/dataset/base.yaml +8 -0
  10. optgs/config/dataset/colmap.yaml +12 -0
  11. optgs/config/dataset/dl3dv.yaml +61 -0
  12. optgs/config/dataset/re10k.yaml +27 -0
  13. optgs/config/dataset/scannet.yaml +13 -0
  14. optgs/config/dataset/view_sampler/all.yaml +1 -0
  15. optgs/config/dataset/view_sampler/arbitrary.yaml +7 -0
  16. optgs/config/dataset/view_sampler/bounded.yaml +12 -0
  17. optgs/config/dataset/view_sampler/boundedv2.yaml +15 -0
  18. optgs/config/dataset/view_sampler/boundedv2_360.yaml +17 -0
  19. optgs/config/dataset/view_sampler/dense.yaml +6 -0
  20. optgs/config/dataset/view_sampler/evaluation.yaml +4 -0
  21. optgs/config/dataset/view_sampler/ids.yaml +4 -0
  22. optgs/config/dataset/view_sampler_dataset_specific_config/bounded_re10k.yaml +11 -0
  23. optgs/config/dataset/view_sampler_dataset_specific_config/boundedv2_dl3dv.yaml +14 -0
  24. optgs/config/dataset/view_sampler_dataset_specific_config/evaluation_dl3dv.yaml +5 -0
  25. optgs/config/dataset/view_sampler_dataset_specific_config/evaluation_re10k.yaml +5 -0
  26. optgs/config/experiment/re10k_unified.yaml +78 -0
  27. optgs/config/experiment/test_colmap.yaml +32 -0
  28. optgs/config/experiment/test_dl3dv.yaml +38 -0
  29. optgs/config/experiment/test_re10k.yaml +36 -0
  30. optgs/config/experiment/train_dl3dv.yaml +55 -0
  31. optgs/config/experiment/train_l2s_sparse_dl3dv.yaml +41 -0
  32. optgs/config/experiment/train_l2s_sparse_dl3dv_no_delta.yaml +35 -0
  33. optgs/config/experiment/train_l2s_sparse_dl3dv_no_loss.yaml +35 -0
  34. optgs/config/loss/deltas.yaml +6 -0
  35. optgs/config/loss/gaussians.yaml +6 -0
  36. optgs/config/loss/iso_scales.yaml +2 -0
  37. optgs/config/loss/lpips.yaml +4 -0
  38. optgs/config/loss/mse.yaml +2 -0
  39. optgs/config/loss/sgd.yaml +2 -0
  40. optgs/config/loss/sh0.yaml +2 -0
  41. optgs/config/loss/ssim.yaml +2 -0
  42. optgs/config/loss/stability.yaml +2 -0
  43. optgs/config/main.yaml +195 -0
  44. optgs/config/meta_trainer/test/postprocessing/adam.yaml +10 -0
  45. optgs/config/meta_trainer/test/postprocessing/base.yaml +24 -0
  46. optgs/config/meta_trainer/test/postprocessing/none.yaml +5 -0
  47. optgs/config/meta_trainer/test/postprocessing/sgd.yaml +7 -0
  48. optgs/config/meta_trainer/test/postprocessing/vanilla_3dgs.yaml +12 -0
  49. optgs/config/meta_trainer/test/postprocessing/vanilla_3dgs_sgd.yaml +12 -0
  50. optgs/config/meta_trainer/train/replay_buffer_cfg/default.yaml +12 -0
.dockerignore ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Build-context / image trimming for the Learn2Splat demo Space.
2
+ # The Dockerfile needs: demo.py, optgs/, submodules/, requirements.txt,
3
+ # pyproject.toml, LICENSE — keep those; drop everything below.
4
+
5
+ # Secrets — never copy into the image.
6
+ .env
7
+ .env.*
8
+ /wandb/
9
+
10
+ # Git + Python build droppings.
11
+ .git/
12
+ .gitignore
13
+ **/__pycache__/
14
+ **/*.pyc
15
+ **/*.egg-info/
16
+ submodules/*/build/
17
+
18
+ # Large runtime artefacts — fetched into the container on first run.
19
+ /data/
20
+ /checkpoints/
21
+ /results/
22
+
23
+ # Repo material the demo doesn't use.
24
+ /assets/
25
+ /docs/
26
+ /figures/
27
+ /tests/
28
+ /scripts/
29
+ /mlcloud_scripts/
30
+ /visualization/
31
+ /todo/
32
+ huggingface_space/
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ submodules/fused-ssim/images/albert.jpg filter=lfs diff=lfs merge=lfs -text
37
+ submodules/fused-ssim/images/inference_time.png filter=lfs diff=lfs merge=lfs -text
38
+ submodules/fused-ssim/images/inference_time_4090.png filter=lfs diff=lfs merge=lfs -text
39
+ submodules/fused-ssim/images/predicted.jpg filter=lfs diff=lfs merge=lfs -text
40
+ submodules/fused-ssim/images/training_time.png filter=lfs diff=lfs merge=lfs -text
41
+ submodules/fused-ssim/images/training_time_4090.png filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Learn2Splat — interactive demo for a Hugging Face Space (Docker SDK, GPU).
2
+ #
3
+ # Builds the optgs package + its CUDA extensions and runs demo.py's viser GUI:
4
+ # SfM-initialize a COLMAP scene, then refine the Gaussians with the learned
5
+ # optimizer live in the browser. Mirrors setup.sh, minus conda — the CUDA
6
+ # toolkit ships in the base image.
7
+ #
8
+ # Build context = the optgs repo root (see huggingface_space/DEPLOY.md).
9
+ # Hardware: pick a GPU in the Space settings — A10G (24 GB) recommended; the
10
+ # GUI holds the dense and sparse checkpoints in VRAM at once.
11
+
12
+ # CUDA 12.8 devel (nvcc + headers); Ubuntu 22.04 — the OS setup.sh is tested on.
13
+ # A devel base is required: gsplat / nerfacc JIT-compile CUDA on first use, so
14
+ # nvcc must also be present at runtime.
15
+ FROM nvidia/cuda:12.8.0-devel-ubuntu22.04
16
+
17
+ ENV DEBIAN_FRONTEND=noninteractive \
18
+ PYTHONUNBUFFERED=1 \
19
+ PIP_NO_CACHE_DIR=1 \
20
+ # Compile the CUDA extensions for every GPU a Space may run on
21
+ # (T4 7.5 · A100 8.0 · A10G 8.6 · L4/L40S 8.9 · H100 9.0). Trim this to
22
+ # your chosen GPU to shorten the build.
23
+ TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6 8.9 9.0+PTX"
24
+
25
+ # Build tools + extension headers (libglm-dev) and the OpenCV runtime libs
26
+ # (libgl1, libglib2.0-0 — optgs's COLMAP loader imports cv2).
27
+ RUN apt-get update && apt-get install -y --no-install-recommends \
28
+ python3 python3-dev python3-venv \
29
+ git build-essential ninja-build libglm-dev \
30
+ libgl1 libglib2.0-0 ca-certificates \
31
+ && rm -rf /var/lib/apt/lists/*
32
+
33
+ # HF Spaces convention: run as a non-root user (UID 1000).
34
+ RUN useradd -m -u 1000 user
35
+ USER user
36
+ ENV HOME=/home/user \
37
+ HF_HOME=/home/user/.cache/huggingface \
38
+ TORCH_HOME=/home/user/.cache/torch
39
+ WORKDIR /home/user/app
40
+
41
+ # All Python work happens in a venv on PATH (no system-Python writes).
42
+ RUN python3 -m venv /home/user/venv
43
+ ENV PATH=/home/user/venv/bin:$PATH
44
+ RUN pip install --upgrade pip setuptools wheel
45
+
46
+ # PyTorch (CUDA 12.8) — pinned to setup.sh.
47
+ RUN pip install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 \
48
+ --index-url https://download.pytorch.org/whl/cu128
49
+
50
+ # Python requirements (copied first so this layer caches across code edits).
51
+ COPY --chown=user:user requirements.txt .
52
+ RUN pip install -r requirements.txt
53
+
54
+ # gsplat + nerfacc — built from git against the torch installed above.
55
+ RUN pip install --no-build-isolation \
56
+ git+https://github.com/nerfstudio-project/nerfacc \
57
+ git+https://github.com/nerfstudio-project/gsplat.git
58
+
59
+ # The optgs repo.
60
+ COPY --chown=user:user . .
61
+
62
+ # CUDA-extension submodules, then optgs itself. pycolmap is the pure-Python
63
+ # COLMAP reader (no C++ build); the other four compile CUDA kernels.
64
+ RUN pip install submodules/pycolmap \
65
+ && pip install --no-build-isolation submodules/fused-ssim \
66
+ && pip install --no-build-isolation submodules/simple-knn \
67
+ && pip install --no-build-isolation submodules/pointops \
68
+ && pip install --no-build-isolation submodules/fused_knn_attn \
69
+ && pip install --no-build-isolation --no-deps -e .
70
+
71
+ # viser serves the GUI here — must equal app_port in README.md.
72
+ EXPOSE 7860
73
+
74
+ # client mode: viser ships the splats to the browser's WebGL renderer, so the
75
+ # GPU is used only for optimization. viser binds 0.0.0.0 by default.
76
+ CMD ["python", "demo.py", "--with-gui", "client", "--gui-port", "7860"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Naama Pearl and Stefano Esposito
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,24 @@
1
  ---
2
  title: Learn2Splat
3
- emoji: 😻
4
- colorFrom: indigo
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 6.14.0
8
- python_version: '3.13'
9
- app_file: app.py
10
  pinned: false
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Learn2Splat
3
+ emoji: 🪴
4
+ colorFrom: green
5
+ colorTo: indigo
6
+ sdk: docker
7
+ app_port: 7860
 
 
8
  pinned: false
9
+ short_description: Interactive demo of the Learn2Splat learned 3DGS optimizer
10
  ---
11
 
12
+ # Learn2Splat interactive demo
13
+
14
+ A learned optimizer for 3D Gaussian Splatting. This Space SfM-initializes a
15
+ COLMAP scene and refines the Gaussians live in your browser: pick the
16
+ Learn2Splat optimizer (dense or sparse checkpoint) or a 3DGS Adam baseline,
17
+ press **Start**, and watch the splats converge.
18
+
19
+ Runs `demo.py --with-gui client` from the
20
+ [Learn2Splat repository](https://github.com/autonomousvision/learn2splat);
21
+ the splats are drawn by viser's in-browser WebGL renderer.
22
+
23
+ > Requires GPU hardware. The demo holds two checkpoints in VRAM at once —
24
+ > an A10G (24 GB) is recommended.
demo.py ADDED
@@ -0,0 +1,766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """End-to-end OptGS demo on a COLMAP scene.
2
+
3
+ Main-codebase port of ``baselines/gsplat/examples/simple_trainer_optgs.py``:
4
+ same flow — SfM-initialize Gaussians, refine them with the learned optimizer
5
+ via the ``OptGS`` API, evaluate on held-out views — but using only the
6
+ ``optgs`` package (no gsplat / gsplat-examples dependency):
7
+
8
+ from optgs.experimental.api import OptGS
9
+
10
+ optgs = OptGS(checkpoint="hf://org/repo/model.ckpt", device="cuda")
11
+ optgs.initialize_from_tensors(gaussians, batched_views)
12
+ refined = optgs.optimize() # learned optimization
13
+
14
+ COLMAP loading uses ``optgs.dataset.colmap``; the SfM init builds an optgs
15
+ ``Gaussians`` directly via ``points_to_gaussians``; evaluation renders with
16
+ the optimizer's own decoder.
17
+
18
+ The scene is refined three ways and compared on held-out views: the learned
19
+ optimizer (Learn2Splat) with the *dense* and the *sparse* checkpoint, and a
20
+ 3DGS Adam baseline (gsplat hyperparameters). All run through the same
21
+ ``optimize()`` path with identical SfM init, view minibatches and step budget.
22
+ Each uses its checkpoint's gsplat renderer; ``--rasterize-mode`` / ``--eps2d``
23
+ pin one renderer across all runs.
24
+
25
+ Usage (run from the repo root, with ``optgs`` importable):
26
+
27
+ python demo.py # headless: dense + sparse checkpoints + an Adam baseline
28
+ python demo.py --with-gui server # interactive viser GUI (frames rendered by the decoder)
29
+ python demo.py --with-gui client # interactive viser GUI (viser's WebGL splat renderer)
30
+
31
+ The demo scene and the checkpoints are fetched from the Hugging Face Hub on
32
+ first run (cached under ./data and ./checkpoints). A CUDA device is required.
33
+ """
34
+
35
+ import warnings
36
+
37
+ # Demo: silence all Python warnings for clean output — in practice third-party UserWarnings
38
+ # (xFormers/flash-attn not installed, Hydra's _self_ notice, pointops' deprecated tensor constructors).
39
+ warnings.filterwarnings("ignore")
40
+
41
+ import json
42
+ import os
43
+ import time
44
+ from dataclasses import dataclass
45
+ from typing import Dict, List, Literal, Optional, Tuple
46
+
47
+ import imageio.v2 as imageio
48
+ import numpy as np
49
+ import torch
50
+ import torch.nn.functional as F
51
+ import tyro
52
+ from rich.console import Console
53
+ from rich.table import Table
54
+ from torch import Tensor
55
+
56
+ console = Console()
57
+
58
+ from optgs.dataset.colmap.utils import Dataset, Parser
59
+ from optgs.experimental.initializers_utils import knn, points_to_gaussians
60
+ from optgs.model.types import Gaussians
61
+ from optgs.scene_trainer.common.gaussian_adapter import build_covariance
62
+
63
+ # Camera near/far planes — inria's znear/zfar (also the optgs colmap-dataset
64
+ # constants). Fixed; not a user knob.
65
+ NEAR_PLANE = 0.01
66
+ FAR_PLANE = 100.0
67
+
68
+ # Spherical-harmonics DC -> RGB (3DGS convention: rgb = 0.5 + C0 * dc). Colours
69
+ # the splats for viser's client-side renderer.
70
+ SH_C0 = 0.28209479177387814
71
+
72
+ # The demo scene is fetched from this Hugging Face repo on first run. The repo
73
+ # mirrors the local layout, so e.g. ``data/mip360/garden`` in the repo lands at
74
+ # ``./data/mip360/garden``.
75
+ DEMO_DATA_REPO = "autonomousvision/learn2splat"
76
+
77
+ # Learned-optimizer checkpoints on the Hugging Face Hub. hf:// refs are fetched
78
+ # and cached under ./checkpoints on first use (see optgs.misc.hf_ckpt).
79
+ CHECKPOINTS = {
80
+ "dense": "hf://autonomousvision/learn2splat/dense/checkpoints/epoch_5-step_50000.ckpt",
81
+ "sparse": "hf://autonomousvision/learn2splat/sparse/checkpoints/epoch_9-step_90000.ckpt",
82
+ }
83
+
84
+
85
def ensure_data(data_dir: str) -> None:
    """Download the demo scene from the Hugging Face Hub if it is not present.

    Nothing is downloaded when ``data_dir`` already exists and is non-empty.
    Otherwise the files under ``data_dir`` are fetched from ``DEMO_DATA_REPO``
    into the current working directory (the repo mirrors the local layout).

    Args:
        data_dir: Scene directory relative to the current working directory,
            e.g. ``data/mip360/garden``. A leading ``./``, a trailing slash
            and OS-specific separators are tolerated.

    Raises:
        FileNotFoundError: if ``data_dir`` is still missing or empty after the
            download, i.e. the path matched nothing in the Hub repo.
    """

    def _populated(path: str) -> bool:
        # A scene counts as present only if the directory has at least one entry.
        return os.path.isdir(path) and bool(os.listdir(path))

    if _populated(data_dir):
        return
    from huggingface_hub import snapshot_download

    # allow_patterns are matched against repo-relative POSIX paths, so a
    # spelling like "./data/x/" or "data\\x" must be normalised or the glob
    # silently matches nothing.
    repo_path = os.path.normpath(data_dir).replace(os.sep, "/")

    console.print(
        f"[yellow]{data_dir}[/] not found — downloading from "
        f"[cyan]hf://{DEMO_DATA_REPO}[/] …"
    )
    snapshot_download(
        repo_id=DEMO_DATA_REPO,
        allow_patterns=[f"{repo_path}/**"],
        local_dir=".",
    )
    if not _populated(data_dir):
        # Fail here with a clear message instead of reporting success and
        # crashing later in the COLMAP loader.
        raise FileNotFoundError(
            f"{data_dir!r} is still missing or empty after downloading from "
            f"hf://{DEMO_DATA_REPO}; pass a path relative to the working "
            f"directory that exists in the repo"
        )
    console.print(f"[green]✓[/] scene ready at [yellow]{data_dir}[/]")
101
+
102
+
103
@dataclass
class Config:
    """Command-line configuration for the demo (one field per CLI option)."""

    # Path to the COLMAP dataset (expects images/ + sparse/0/).
    data_dir: str = "data/mip360/garden"
    # Downsample factor for the dataset.
    data_factor: int = 4
    # Global multiplier on scene-size-related parameters.
    global_scale: float = 1.0
    # Normalize the world space.
    normalize_world_space: bool = True
    # Every N images is a test image, held out for evaluation.
    test_every: int = 8
    # Directory to save renders / stats / the refined PLY.
    result_dir: str = "results/demo"
    # Random seed.
    seed: int = 42

    # --- Interactive GUI ---
    # Launch a viser GUI instead of the headless comparison. "server" renders
    # frames with the optgs decoder; "client" uses viser's built-in WebGL
    # Gaussian-splat renderer. Unset = headless run.
    with_gui: Optional[Literal["client", "server"]] = None
    # Port for the viser GUI web server (--with-gui only).
    gui_port: int = 8080

    # --- OptGS learned optimizer ---
    # Compute device (OptGS requires CUDA).
    device: str = "cuda"
    # Number of learned refinement steps.
    max_steps: int = 100
    # Views the optimizer sees per refinement step (the view minibatch).
    opt_batch_size: int = 8
    # View-minibatch sampling strategy: "random", "sequential", or "fps"
    # (farthest-point sampling over camera positions).
    opt_batch_strategy: Literal["random", "sequential", "fps"] = "fps"

    # --- gsplat renderer ---
    # rasterize_mode / eps2d: when set, applied to every run (dense, sparse,
    # Adam), overriding each checkpoint's decoder config so the comparison uses
    # one renderer. Left unset, each run uses its own checkpoint's value.
    rasterize_mode: Optional[Literal["classic", "antialiased"]] = None
    eps2d: Optional[float] = None

    # --- Initialization ---
    # Initialization strategy: "sfm" or "random". Typed as a Literal, like the
    # other enumerated options, so the accepted values are part of the
    # declaration; sfm_initialization rejects anything else.
    init_type: Literal["sfm", "random"] = "sfm"
    # Initial number of GSs. Ignored when init_type="sfm".
    init_num_pts: int = 100_000
    # Initial extent of GSs as a multiple of the scene extent (random init).
    init_extent: float = 3.0
    # Initial opacity / scale of each GS.
    init_opa: float = 0.1
    init_scale: float = 1.0
156
+
157
+
158
def scene_extent(parser: Parser, global_scale: float) -> float:
    """Return the scene-size scalar ``parser.scene_scale * 1.1 * global_scale``."""
    # The 1.1 factor is applied first, then the user's global multiplier.
    padded_scale = parser.scene_scale * 1.1
    return padded_scale * global_scale
161
+
162
+
163
def sfm_initialization(
    parser: Parser, cfg: Config, sh_degree: int, device: torch.device, dtype: torch.dtype
) -> Gaussians:
    """Build the initial optgs ``Gaussians`` (batch=1) from SfM or random points.

    ``cfg.init_type`` selects the seed points: ``"sfm"`` takes ``parser.points``
    coloured by ``parser.points_rgb / 255``; ``"random"`` draws
    ``cfg.init_num_pts`` points uniformly from a cube of half-width
    ``cfg.init_extent * scene_extent(...)`` with uniform random colours. The
    per-point parameters are then assembled through ``points_to_gaussians``.

    Raises:
        ValueError: if ``cfg.init_type`` is neither ``"sfm"`` nor ``"random"``.
    """
    if cfg.init_type not in ("sfm", "random"):
        raise ValueError(f"unknown init_type: {cfg.init_type!r} (sfm | random)")

    if cfg.init_type == "sfm":
        xyz = torch.from_numpy(parser.points).float()
        colours = torch.from_numpy(parser.points_rgb / 255.0).float()
    else:
        half_width = cfg.init_extent * scene_extent(parser, cfg.global_scale)
        unit_cube = torch.rand((cfg.init_num_pts, 3)) * 2 - 1  # uniform in [-1, 1)
        xyz = half_width * unit_cube
        colours = torch.rand((cfg.init_num_pts, 3))

    # Isotropic size per point: RMS distance to its 3 nearest neighbours
    # (column 0 of the knn result is the point itself, so it is dropped).
    neighbour_dists = knn(xyz, 4)[:, 1:]
    mean_sq_dist = (neighbour_dists ** 2).mean(dim=-1)
    iso_scale = torch.sqrt(mean_sq_dist) * cfg.init_scale
    scales = iso_scale.unsqueeze(-1).repeat(1, 3)
    opacities = torch.full((xyz.shape[0],), cfg.init_opa)

    # points_to_gaussians returns pre-activation params (log scales, logit
    # opacity, sh0/shN, random quats).
    params = points_to_gaussians(
        {"xyz": xyz, "rgb": colours, "scales": scales, "opacities": opacities},
        sh_degree=sh_degree,
        device=device,
    )

    sh_dc, sh_rest = params["sh0"], params["shN"]
    if sh_rest is None:
        sh_all = sh_dc
    else:
        sh_all = torch.cat([sh_dc, sh_rest], dim=1)  # [N, K, 3]
    sh_all = sh_all.permute(0, 2, 1)  # -> [N, 3, K]

    # Activate the raw parameters and derive the covariance from them.
    act_scales = torch.exp(params["scales_raw"])
    act_opacities = torch.sigmoid(params["opacities_raw"])
    unit_quats = F.normalize(params["rotations_unnorm"], dim=-1)
    cov = build_covariance(scale=act_scales, rotation_xyzw=unit_quats)

    fields = {
        "means": params["xyz"],
        "covariances": cov,
        "harmonics": sh_all,
        "opacities": act_opacities,
        "scales": act_scales,
        "rotations": unit_quats,
        "rotations_unnorm": params["rotations_unnorm"],
    }
    # Every field gets a leading batch dimension of 1 and is cast to `dtype`.
    return Gaussians(
        **{name: tensor.unsqueeze(0).to(dtype) for name, tensor in fields.items()}
    )
216
+
217
+
218
def collect_cameras(
    dataset: Dataset, indices: List[int]
) -> Tuple[Tensor, Tensor, Tensor]:
    """Gather the views at ``indices`` as stacked ``(camtoworlds, Ks, images)``.

    Each ``dataset[i]`` is read for its ``"image"``, ``"camtoworld"`` and
    ``"K"`` entries. Images are divided by 255 before stacking. Every selected
    view must have the same (H, W) as the first one.

    Raises:
        ValueError: if a view's (H, W) differs from the first selected view's.
    """
    expected_hw = None
    poses, intrinsics, frames = [], [], []

    for view_idx in indices:
        sample = dataset[view_idx]
        frame = sample["image"] / 255.0  # [H, W, 3], float
        frame_hw = frame.shape[:2]

        if expected_hw is None:
            expected_hw = frame_hw  # the first view fixes the resolution
        if frame_hw != expected_hw:
            raise ValueError(
                f"all views must share one (H, W); got {tuple(frame_hw)} "
                f"vs {tuple(expected_hw)}. Render the dataset at a single resolution."
            )

        poses.append(sample["camtoworld"])
        intrinsics.append(sample["K"])
        frames.append(frame)

    return tuple(torch.stack(group) for group in (poses, intrinsics, frames))
242
+
243
+
244
def build_batched_views(
    camtoworlds: Tensor,
    Ks: Tensor,
    images: Tensor,
    scene_scale: float,
    device: torch.device,
    dtype: torch.dtype,
) -> dict:
    """Pack COLMAP cameras and images into an optgs ``BatchedViews`` dict (batch=1).

    ``camtoworlds`` is passed through unchanged as the extrinsics (COLMAP
    camera->world is already optgs's convention). ``Ks`` is pixel-space and is
    normalized here: row 0 is divided by the image width, row 1 by the height.
    ``images`` is indexed as ``[V, H, W, C]`` and returned channel-first.
    """
    num_views, height, width = (images.shape[axis] for axis in range(3))

    # Normalize on a copy so the caller's intrinsics stay in pixel units.
    intrinsics = Ks.clone()
    intrinsics[:, 0, :].div_(width)  # normalized focal / principal point
    intrinsics[:, 1, :].div_(height)

    channel_first = images.permute(0, 3, 1, 2)  # [V, 3, H, W]
    placement = {"device": device, "dtype": dtype}

    def _batched(tensor: Tensor) -> Tensor:
        # Prepend the batch dimension, then move/cast to the target device/dtype.
        return tensor.unsqueeze(0).to(**placement)

    views = {
        "extrinsics": _batched(camtoworlds),
        "intrinsics": _batched(intrinsics),
        "image": _batched(channel_first),
    }
    views["near"] = torch.full((1, num_views), NEAR_PLANE, **placement)
    views["far"] = torch.full((1, num_views), FAR_PLANE, **placement)
    views["index"] = torch.arange(num_views, device=device).unsqueeze(0)
    views["scene_scale"] = torch.tensor([scene_scale], **placement)
    return views
278
+
279
+
280
@torch.no_grad()
def render_and_score(
    optgs,
    refined: Gaussians,
    val_bv: dict,
    val_images: Tensor,
    out_dir: str,
    device: torch.device,
) -> dict:
    """Render one optimizer's result on the held-out views; report mean PSNR.

    Saves a ``gt | pred`` strip per view under ``out_dir/renders``.

    Args:
        optgs: object whose ``decoder.forward`` renders the Gaussians.
        refined: Gaussians to evaluate.
        val_bv: batched-views dict; its ``extrinsics``, ``intrinsics``,
            ``near`` and ``far`` entries are passed to the decoder.
        val_images: ground-truth images, indexed ``[V, H, W, 3]``; the
            8-bit conversion below assumes values in [0, 1].
        out_dir: output directory; ``renders/`` is created inside it.
        device: device the ground-truth images are moved to for scoring.

    Returns:
        ``{"psnr": mean PSNR over the views, "num_views": number of views}``.
    """
    render_dir = os.path.join(out_dir, "renders")
    os.makedirs(render_dir, exist_ok=True)
    h, w = val_images.shape[1], val_images.shape[2]

    out = optgs.decoder.forward(
        refined, val_bv["extrinsics"], val_bv["intrinsics"],
        val_bv["near"], val_bv["far"], image_shape=(h, w),
    )
    colors = out.color[0].clamp(0.0, 1.0)  # [V, 3, H, W]

    psnrs = []
    for i in range(colors.shape[0]):
        gt = val_images[i].to(device)  # [H, W, 3]
        pred = colors[i].permute(1, 2, 0)
        # PSNR for a peak value of 1: -10 * log10(MSE).
        psnrs.append(-10.0 * torch.log10(torch.mean((pred - gt) ** 2)).item())

        canvas = torch.cat([gt, pred], dim=1).cpu().numpy()  # gt | pred
        # Round before the uint8 cast: astype() truncates, so a float value
        # such as 254.99998 (from x / 255 * 255) would be saved as 254.
        imageio.imwrite(
            os.path.join(render_dir, f"val_{i:04d}.png"),
            np.round(canvas * 255).astype(np.uint8),
        )

    return {"psnr": float(np.mean(psnrs)), "num_views": int(colors.shape[0])}
316
+
317
+
318
@torch.no_grad()
def render_view(
    optgs, gaussians: Gaussians, camera, height: int,
    device: torch.device, dtype: torch.dtype,
) -> np.ndarray:
    """Render ``gaussians`` from a viser camera into an ``[H, W, 3]`` uint8 image.

    The camera's ``(wxyz, position)`` is used directly as the camera-to-world
    transform handed to the optgs decoder (viser cameras follow OpenCV
    conventions, so no axis flip is applied). The image width is derived from
    ``height`` and ``camera.aspect``; intrinsics come from the vertical fov.
    """
    import viser.transforms as vtf

    from optgs.misc.image_io import prep_image

    img_h = int(height)
    img_w = max(1, round(img_h * camera.aspect))  # camera.aspect = width / height
    placement = {"device": device, "dtype": dtype}

    # Camera-to-world pose: rotation from the quaternion, translation from position.
    pose = torch.eye(4, **placement)
    pose[:3, :3] = torch.tensor(vtf.SO3(camera.wxyz).as_matrix(), **placement)
    pose[:3, 3] = torch.tensor(camera.position, **placement)

    # Normalized intrinsics from the vertical fov; the decoder un-normalizes by
    # the image width/height. The principal point sits at the image centre.
    focal_px = (img_h / 2.0) / float(np.tan(camera.fov / 2.0))
    intrinsics = torch.tensor(
        [
            [focal_px / img_w, 0.0, 0.5],
            [0.0, focal_px / img_h, 0.5],
            [0.0, 0.0, 1.0],
        ],
        **placement,
    )

    near, far = (
        torch.full((1, 1), plane, **placement) for plane in (NEAR_PLANE, FAR_PLANE)
    )
    rendered = optgs.decoder.forward(
        gaussians,
        pose[None, None],
        intrinsics[None, None],
        near,
        far,
        image_shape=(img_h, img_w),
    )
    return prep_image(rendered.color[0, 0])  # [H, W, 3] uint8
356
+
357
+
358
def gaussians_to_splat_data(gaussians: Gaussians) -> dict:
    """Convert an optgs ``Gaussians`` (batch=1) to numpy arrays for viser's splat viewer.

    Covariances are rebuilt from scale and rotation rather than read from
    ``Gaussians.covariances`` (the optimizer updates scale/rotation but may
    leave that optional field stale). Colours use only the SH DC term, since
    viser's renderer is not view-dependent.

    Returns:
        A dict of float32 arrays under the keys ``centers``, ``covariances``,
        ``rgbs`` and ``opacities``.
    """
    raw_scales = gaussians.scales[0]
    raw_opacities = gaussians.opacities[0]
    if gaussians.stores_activated:
        scales, opacities = raw_scales, raw_opacities
    else:
        # Stored pre-activation: apply exp / sigmoid before use.
        scales, opacities = torch.exp(raw_scales), torch.sigmoid(raw_opacities)

    unit_quats = F.normalize(gaussians.rotations_unnorm[0], dim=-1)
    cov = build_covariance(scale=scales, rotation_xyzw=unit_quats)

    # 3DGS convention: rgb = 0.5 + C0 * dc, clamped to the displayable range.
    dc_term = gaussians.harmonics[0, :, :, 0]
    colours = (0.5 + SH_C0 * dc_term).clamp(0.0, 1.0)

    tensors = {
        "centers": gaussians.means[0],  # (N, 3)
        "covariances": cov,  # (N, 3, 3)
        "rgbs": colours,  # (N, 3)
        "opacities": opacities.reshape(-1, 1),  # (N, 1)
    }
    return {
        key: value.detach().cpu().numpy().astype(np.float32)
        for key, value in tensors.items()
    }
383
+
384
+
385
+ def run_gui(
386
+ instances: dict,
387
+ gaussians: Gaussians,
388
+ train_bv: dict,
389
+ cfg: Config,
390
+ device: torch.device,
391
+ dtype: torch.dtype,
392
+ ) -> None:
393
+ """Interactive viser GUI: watch the optimization, pick an optimizer, reset.
394
+
395
+ The initialization is shown first; the user picks an optimizer — the
396
+ Learn2Splat learned optimizer (dense or sparse checkpoint) or a 3DGS Adam
397
+ baseline — and clicks Start; every optimizer step is rendered and displayed;
398
+ Reset restores the initialization. ``cfg.with_gui`` chooses the renderer —
399
+ "server" (optgs decoder, frames streamed as images) or "client" (viser's
400
+ WebGL splats).
401
+
402
+ ``instances`` maps "dense"/"sparse" to their initialized ``OptGS``.
403
+ """
404
+ import threading
405
+
406
+ import viser
407
+ import viser.transforms as vtf
408
+
409
+ from optgs.experimental.api.integration.config_bridge import build_adam_baseline
410
+
411
+ mode = cfg.with_gui # "server" | "client"
412
+ server = viser.ViserServer(port=cfg.gui_port)
413
+
414
+ # Optimizer dropdown label -> (instances key, whether to swap in Adam).
415
+ # "dense"/"sparse" run that checkpoint's own learned optimizer; "Adam" runs
416
+ # a 3DGS Adam baseline on the dense checkpoint's pipeline.
417
+ OPTIONS: Dict[str, Tuple[str, bool]] = {
418
+ "Learn2Splat (dense)": ("dense", False),
419
+ "Learn2Splat (sparse)": ("sparse", False),
420
+ "Adam (3DGS)": ("dense", True),
421
+ }
422
+
423
+ optimizer_dd = server.gui.add_dropdown("Optimizer", tuple(OPTIONS))
424
+
425
+ # Optimization controls — applied to the picked OptGS at Start; frozen
426
+ # while optimizing, unfrozen by Reset. opt_batch_size is capped at the
427
+ # number of training views (the per-step view minibatch can't exceed them).
428
+ n_train_views = int(train_bv["image"].shape[1])
429
+ max_steps_input = server.gui.add_number(
430
+ "Max steps", min=1, max=1000, step=1, initial_value=cfg.max_steps
431
+ )
432
+ batch_size_input = server.gui.add_number(
433
+ "Opt batch size", min=1, max=n_train_views, step=1,
434
+ initial_value=min(cfg.opt_batch_size, n_train_views),
435
+ )
436
+ strategy_dd = server.gui.add_dropdown(
437
+ "Opt batch strategy", ("random", "sequential", "fps"),
438
+ initial_value=cfg.opt_batch_strategy,
439
+ )
440
+ opt_controls = (max_steps_input, batch_size_input, strategy_dd)
441
+
442
+ start_btn = server.gui.add_button("Start optimization")
443
+ reset_btn = server.gui.add_button("Reset to initialization")
444
+ status = server.gui.add_markdown("**initialized** — pick an optimizer, then Start")
445
+ res_slider = (
446
+ server.gui.add_slider(
447
+ "Render height", min=240, max=1080, step=60, initial_value=540
448
+ )
449
+ if mode == "server"
450
+ else None
451
+ )
452
+
453
+ init_gaussians = gaussians.clone() # pristine copy, for Reset
454
+ current = init_gaussians # Gaussians currently displayed
455
+ active = instances["dense"] # OptGS used to render + to optimize next
456
+ gen = None # optimize_iter generator while running
457
+ last_cam_ts: dict = {} # client id -> last-rendered camera stamp
458
+ lock = threading.Lock()
459
+ state = {
460
+ "mode": "init", # "init" | "optimizing" | "done"
461
+ "step": 0,
462
+ "start": False,
463
+ "reset": False,
464
+ "rerender": False, # a GUI control changed -> re-render once
465
+ "selected": next(iter(OPTIONS)),
466
+ }
467
+
468
+ @start_btn.on_click
469
+ def _(_) -> None:
470
+ with lock:
471
+ if state["mode"] in ("init", "done"):
472
+ state["selected"] = optimizer_dd.value
473
+ state["start"] = True
474
+
475
+ @reset_btn.on_click
476
+ def _(_) -> None:
477
+ with lock:
478
+ state["reset"] = True
479
+
480
+ # The render-height slider only affects server-rendered frames; re-render
481
+ # on change so the new resolution takes effect without a camera move.
482
+ if res_slider is not None:
483
+
484
+ @res_slider.on_update
485
+ def _(_) -> None:
486
+ with lock:
487
+ state["rerender"] = True
488
+
489
+ # Frame newly-connected clients on the first training camera (viser and
490
+ # optgs share the OpenCV camera-to-world convention).
491
+ cam_extr = train_bv["extrinsics"][0, 0].detach().cpu().numpy()
492
+
493
+ @server.on_client_connect
494
+ def _(client) -> None:
495
+ try:
496
+ client.camera.position = cam_extr[:3, 3]
497
+ client.camera.wxyz = vtf.SO3.from_matrix(cam_extr[:3, :3]).wxyz
498
+ except Exception:
499
+ pass
500
+
501
+ if mode == "client": # show the initialization immediately
502
+ # Black backdrop for the WebGL splat renderer (viser's canvas is not
503
+ # black by default); on server.scene so late-joining clients get it.
504
+ server.scene.set_background_image(np.zeros((8, 8, 3), dtype=np.uint8))
505
+ server.scene.add_gaussian_splats(
506
+ "/optgs/splats", **gaussians_to_splat_data(current)
507
+ )
508
+
509
+ console.print(
510
+ f"[green]✓[/] viser GUI ([cyan]{mode}[/]) on port [cyan]{cfg.gui_port}[/]"
511
+ f" — forward the port over SSH and open the printed URL"
512
+ )
513
+
514
+ try:
515
+ while True:
516
+ changed = False
517
+
518
+ with lock:
519
+ do_reset, do_start = state["reset"], state["start"]
520
+ do_rerender = state["rerender"]
521
+ state["reset"] = state["start"] = state["rerender"] = False
522
+ selected = state["selected"]
523
+
524
+ if do_rerender:
525
+ changed = True # server mode re-renders every connected client
526
+
527
+ if do_reset:
528
+ if gen is not None:
529
+ gen.close() # runs optimize_iter's finally -> on_scene_end()
530
+ gen = None
531
+ current = init_gaussians
532
+ with lock:
533
+ state["mode"], state["step"] = "init", 0
534
+ optimizer_dd.disabled = start_btn.disabled = False
535
+ for c in opt_controls:
536
+ c.disabled = False
537
+ changed = True
538
+
539
+ if do_start and gen is None:
540
+ name, use_adam = OPTIONS[selected]
541
+ active = instances[name]
542
+ # Apply the GUI optimization controls before the run starts.
543
+ active.num_refine = int(max_steps_input.value)
544
+ active.opt_batch_size = int(batch_size_input.value)
545
+ active.opt_batch_strategy = strategy_dd.value
546
+ opt = (
547
+ build_adam_baseline(active.num_refine).to(device)
548
+ if use_adam
549
+ else None
550
+ )
551
+ gen = active.optimize_iter(optimizer=opt)
552
+ with lock:
553
+ state["mode"], state["step"] = "optimizing", 0
554
+ optimizer_dd.disabled = start_btn.disabled = True
555
+ for c in opt_controls:
556
+ c.disabled = True
557
+
558
+ if gen is not None:
559
+ try:
560
+ step, current = next(gen)
561
+ changed = True
562
+ with lock:
563
+ state["step"] = step + 1
564
+ except StopIteration:
565
+ gen = None
566
+ with lock:
567
+ state["mode"] = "done"
568
+ optimizer_dd.disabled = start_btn.disabled = False
569
+
570
+ if mode == "server":
571
+ for cid, client in server.get_clients().items():
572
+ try:
573
+ cam_ts = client.camera.update_timestamp
574
+ if last_cam_ts.get(cid) != cam_ts or changed:
575
+ last_cam_ts[cid] = cam_ts
576
+ image = render_view(
577
+ active, current, client.camera,
578
+ res_slider.value, device, dtype,
579
+ )
580
+ client.scene.set_background_image(image, format="jpeg")
581
+ except Exception:
582
+ continue # no camera message from this client yet
583
+ elif changed: # client mode — re-push splats when the Gaussians change
584
+ server.scene.add_gaussian_splats(
585
+ "/optgs/splats", **gaussians_to_splat_data(current)
586
+ )
587
+
588
+ with lock:
589
+ status.content = (
590
+ f"**{state['mode']}** — step "
591
+ f"{state['step']}/{active.num_refine} — "
592
+ f"{current.means.shape[1]} Gaussians"
593
+ )
594
+
595
+ if gen is None:
596
+ time.sleep(1 / 30) # idle: poll cameras at ~30 Hz
597
+ except KeyboardInterrupt:
598
+ if gen is not None:
599
+ gen.close()
600
+ console.print("\n[yellow]GUI stopped.[/]")
601
+
602
+
603
def main(cfg: Config) -> None:
    """Run the OptGS demo on one COLMAP scene.

    With ``cfg.with_gui`` set, both learned-optimizer checkpoints are built,
    initialized from one shared SfM initialization and handed to the viser GUI.
    Otherwise the headless comparison runs: Learn2Splat (dense, then sparse)
    followed by an Adam baseline, each scored on the held-out views, with a
    summary table printed and ``stats.json`` written to ``cfg.result_dir``.
    """
    # The scene must be on disk before anything below reads it.
    ensure_data(cfg.data_dir)

    from optgs.experimental.api import OptGS, OptGSError
    from optgs.experimental.api.integration.config_bridge import build_adam_baseline

    def make_optgs(ckpt_name: str):
        """Build the OptGS instance for one named checkpoint; exit(1) on OptGSError."""
        try:
            return OptGS(
                checkpoint=CHECKPOINTS[ckpt_name],
                device=cfg.device,
                num_refine=cfg.max_steps,
                opt_batch_size=cfg.opt_batch_size,
                opt_batch_strategy=cfg.opt_batch_strategy,
                rasterize_mode=cfg.rasterize_mode,
                eps2d=cfg.eps2d,
            )
        except OptGSError as e:
            console.print(f"[bold red]OptGS error ({ckpt_name}):[/] {e}")
            raise SystemExit(1)

    def reseed() -> None:
        """Reset the torch and numpy RNGs so every SfM init is identical."""
        torch.manual_seed(cfg.seed)
        np.random.seed(cfg.seed)

    def timed_optimize(model, **optimize_kwargs):
        """Run ``model.optimize`` and return ``(result, wall-clock seconds)``."""
        torch.cuda.synchronize()  # drain setup GPU work so it isn't timed
        started = time.time()
        outcome = model.optimize(**optimize_kwargs)
        torch.cuda.synchronize()
        return outcome, time.time() - started

    os.makedirs(cfg.result_dir, exist_ok=True)
    device = torch.device(cfg.device)
    dtype = torch.float32

    console.rule("[bold cyan]OptGS demo[/] · Learn2Splat vs Adam")

    # --- COLMAP scene, train/val split ---
    parser = Parser(
        data_dir=cfg.data_dir,
        factor=cfg.data_factor,
        normalize=cfg.normalize_world_space,
        verbose=False,
    )
    dataset = Dataset(parser)
    # Every `test_every`-th frame is held out for validation.
    all_idx = range(len(dataset))
    val_idx = [i for i in all_idx if i % cfg.test_every == 0]
    train_idx = [i for i in all_idx if i % cfg.test_every != 0]
    scene_scale = scene_extent(parser, cfg.global_scale)
    console.print(
        f"scene scale [cyan]{scene_scale:.4f}[/] · "
        f"train [cyan]{len(train_idx)}[/] · val [cyan]{len(val_idx)}[/]"
    )
    train_bv = build_batched_views(
        *collect_cameras(dataset, train_idx), scene_scale, device, dtype
    )

    # --- Interactive GUI: build both learned-optimizer checkpoints (dense and
    # sparse), initialize each, and hand off to the viser GUI instead of the
    # headless comparison. The GUI's Optimizer dropdown picks between them. ---
    if cfg.with_gui is not None:
        instances = {ckpt_name: make_optgs(ckpt_name) for ckpt_name in ("dense", "sparse")}

        # One SfM init shared by both checkpoints: dense and sparse get an
        # identical starting point, and the GUI shows a single initialization
        # regardless of which optimizer is picked.
        reseed()
        gaussians = sfm_initialization(
            parser, cfg, instances["dense"].sh_degree, device, dtype
        )
        for inst in instances.values():
            inst.initialize_from_tensors(gaussians, train_bv)

        run_gui(instances, gaussians, train_bv, cfg, device, dtype)
        return

    val_c2w, val_Ks, val_images = collect_cameras(dataset, val_idx)
    val_bv = build_batched_views(val_c2w, val_Ks, val_images, scene_scale, device, dtype)

    results: dict = {}

    def finish(optgs, refined, name: str, elapsed: float) -> None:
        """Persist + evaluate one run's result under results/demo/<name>/."""
        out_dir = os.path.join(cfg.result_dir, name)
        os.makedirs(out_dir, exist_ok=True)
        optgs.export_ply(os.path.join(out_dir, "point_cloud.ply"))
        ev = render_and_score(optgs, refined, val_bv, val_images, out_dir, device)
        results[name] = {
            "psnr": ev["psnr"],
            "time": elapsed,
            "num_views": ev["num_views"],
            "num_GS": int(refined.means.shape[1]),
        }
        console.print(
            f"[green]✓[/] [bold]{name}[/] — PSNR [cyan]{ev['psnr']:.3f}[/] · "
            f"[cyan]{elapsed:.1f}s[/] → [yellow]{out_dir}[/]"
        )

    # --- Learned optimizer (Learn2Splat): dense, then sparse ---
    optgs = None
    for name in ("dense", "sparse"):
        optgs = None  # free the previous instance before building the next
        torch.cuda.empty_cache()
        optgs = make_optgs(name)
        # Seed *after* construction so dense and sparse get an identical SfM init.
        reseed()
        gaussians = sfm_initialization(parser, cfg, optgs.sh_degree, device, dtype)
        optgs.initialize_from_tensors(gaussians, train_bv)

        refined, elapsed = timed_optimize(optgs)
        finish(optgs, refined, name, elapsed)

    # --- Fair Adam baseline: same SfM init / views / step budget / gsplat
    # renderer, run through the same optimize() path on the last OptGS
    # instance — only the update rule differs. ---
    adam = build_adam_baseline(optgs.num_refine).to(device)
    refined_adam, elapsed_adam = timed_optimize(optgs, optimizer=adam)
    finish(optgs, refined_adam, "adam", elapsed_adam)

    # --- Comparison table ---
    reference = results["dense"]
    raster_label = cfg.rasterize_mode or 'per-checkpoint'
    eps_label = cfg.eps2d if cfg.eps2d is not None else 'per-checkpoint'
    table = Table(
        title=(
            f"Novel-view PSNR · {reference['num_views']} held-out "
            f"views · {cfg.max_steps} steps · "
            f"{reference['num_GS']} Gaussians"
        ),
        title_style="bold",
        caption=(
            f"gsplat renderer · "
            f"rasterize_mode={raster_label} · "
            f"eps2d={eps_label}"
        ),
    )
    table.add_column("Optimizer")
    table.add_column("PSNR (dB)", justify="right")
    table.add_column("Time (s)", justify="right")
    best = max(results, key=lambda k: results[k]["psnr"])
    row_labels = {
        "dense": "Learn2Splat (dense)",
        "sparse": "Learn2Splat (sparse)",
        "adam": "Adam",
    }
    for key, label in row_labels.items():
        row = results[key]
        table.add_row(
            label,
            f"{row['psnr']:.3f}",
            f"{row['time']:.1f}",
            style="bold green" if key == best else None,  # highlight the best PSNR
        )
    console.print(table)

    stats_path = os.path.join(cfg.result_dir, "stats.json")
    with open(stats_path, "w") as f:
        json.dump(results, f, indent=2)
    console.print(f"[green]✓[/] results written to [yellow]{cfg.result_dir}[/]")
763
+
764
+
765
if __name__ == "__main__":
    # Parse the command line into a Config with tyro, then run the demo.
    cli_cfg = tyro.cli(Config)
    main(cli_cfg)
optgs/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """optgs — learned optimization for 3D Gaussian Splatting."""
optgs/config.py ADDED
@@ -0,0 +1,770 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from copy import deepcopy
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Literal, Optional, Type, TypeVar, Any, Callable
6
+
7
+ import hydra
8
+ import torch
9
+ from dacite import Config, from_dict, UnionMatchError
10
+ from hydra.core.global_hydra import GlobalHydra
11
+ from hydra.core.hydra_config import HydraConfig
12
+ from hydra.types import RunMode
13
+ from omegaconf import DictConfig
14
+ from omegaconf import OmegaConf
15
+ from pytorch_lightning.strategies import DDPStrategy, FSDPStrategy
16
+
17
+ from .config_migrate import migrate, CURRENT_CFG_VERSION
18
+ from .dataset.data_module import DataLoaderCfg, DatasetCfg
19
+ from .global_cfg import set_cfg
20
+ from .loss import LossCfgWrapper
21
+ from .misc.io import CustomPath
22
+ from .misc.io import cyan, read_omega_cfg
23
+ from .misc.checkpointing import find_latest_ckpt
24
+ from .misc.hf_ckpt import maybe_resolve_hf_ref
25
+ from .paths import CKPT_DIR, RESULTS_DIR
26
+ from .scene_trainer.scene_trainer_cfg import SceneTrainerCfg, MetaOptimizerCfg, TestCfg, TrainCfg
27
+
28
+
29
+ # In order to extract filename or dirname from a path in the config
30
def checkpoint_rel_dir(path):
    """Return, as a string, the run directory of a checkpoint relative to CKPT_DIR.

    The checkpoint is expected to live at
    ``<run_dir>/checkpoints/epoch_x-step_xxxxx.ckpt``, so the run directory is
    two levels above the file.
    """
    # `-` is CustomPath's own operator; presumably it strips the CKPT_DIR
    # prefix — confirm against CustomPath.
    ckpt_rel_path = CustomPath(path) - CKPT_DIR
    run_dir = ckpt_rel_path.parent.parent
    return str(run_dir)
34
+
35
+
36
def _parent_dir(path):
    """Return the parent directory of `path` as a string (for the `parent_dir` resolver)."""
    return str(CustomPath(path).parent)


# Custom OmegaConf resolvers usable from the YAML configs as
# ${checkpoint_rel_dir:...} and ${parent_dir:...}.
OmegaConf.register_new_resolver("checkpoint_rel_dir", checkpoint_rel_dir)
OmegaConf.register_new_resolver("parent_dir", _parent_dir)
38
+
39
+
40
@dataclass
class CheckpointingCfg:
    """Checkpoint loading/saving options.

    After construction, Hugging Face Hub references are resolved to local
    paths and a trailing ``last`` path component is replaced by the latest
    checkpoint found next to it.
    """

    load: Optional[str]  # Not a path, since it could be something like wandb://...
    every_n_train_steps: int
    save_top_k: int
    pretrained_model: Optional[str]
    pretrained_monodepth: Optional[str]
    pretrained_mvdepth: Optional[str]
    pretrained_depth: Optional[str]
    pretrained_scale_predictor: Optional[str]
    pretrained_depth_teacher: Optional[str]
    no_strict_load: bool
    resume: bool
    no_resume_upsampler: bool
    partial_load: bool
    freeze_mono_vit: bool
    pretrained_initializer: Optional[str]
    pretrained_optimizer: Optional[str]
    resume_update_module: str | None
    load_existing_cfg: bool

    def __post_init__(self):
        # Resolve any Hugging Face Hub references (hf://org/repo/file[@rev]) to
        # local cached paths so all downstream torch.load calls work unchanged.
        hf_resolvable = (
            "pretrained_model",
            "pretrained_optimizer",
            "pretrained_initializer",
            "pretrained_monodepth",
            "pretrained_mvdepth",
            "pretrained_depth",
            "pretrained_scale_predictor",
            "pretrained_depth_teacher",
            "resume_update_module",
        )
        for attr in hf_resolvable:
            current = getattr(self, attr)
            resolved = maybe_resolve_hf_ref(current)
            if resolved != current:
                setattr(self, attr, resolved)

        # A path whose final component is "last" means "the newest checkpoint
        # in that directory".
        last_resolvable = ("pretrained_model", "pretrained_optimizer", "pretrained_initializer")
        for attr in last_resolvable:
            path = getattr(self, attr)
            if path is None or Path(path).name != "last":
                continue
            try:
                resolved = find_latest_ckpt(Path(path).parent)
                setattr(self, attr, resolved)
                print(f"Replacing {attr} to last checkpoint: {resolved}")
            except Exception as e:
                # Best effort: keep the literal 'last' value and only warn.
                print(cyan(f"Warning: {e}. Continuing with 'last' as {attr}."))
81
+
82
+
83
@dataclass
class MetaTrainerCfg:
    """Meta-trainer settings plus the choice of distributed strategy."""

    max_steps: int
    val_check_interval: int | float | None
    gradient_clip_val: int | float | None
    num_sanity_val_steps: int
    num_nodes: int
    eval_index: str | None
    limit_test_batches: int | float
    limit_train_batches: int | float
    test: TestCfg
    train: TrainCfg

    def get_dist_strategy(self, scene_trainer_cfg: SceneTrainerCfg):
        """Pick the distributed strategy: a strategy-name string or a strategy object.

        Returns "auto" unless more than one CUDA device is visible; otherwise
        starts from 'ddp' and lets each later condition override the earlier ones.
        """
        from .scene_trainer.initializer.initializer_resplat import ResplatInitializerCfg

        chosen = "auto"
        if torch.cuda.device_count() > 1:
            chosen = 'ddp'
            # NOTE(review): this tests `scene_optimizer` against an *initializer*
            # config class and then reads `scene_initializer` — confirm that is
            # the intended field.
            if isinstance(scene_trainer_cfg.scene_optimizer, ResplatInitializerCfg):
                initializer_cfg = scene_trainer_cfg.scene_initializer
                if initializer_cfg.use_gt_depth:
                    chosen = 'ddp_find_unused_parameters_true'
                if initializer_cfg.use_checkpointing or initializer_cfg.init_use_checkpointing:
                    chosen = DDPStrategy(static_graph=True)
            if scene_trainer_cfg.use_fsdp:
                def only_wrap_trainable(module, recurse, nonwrapped_numel):
                    # Wrap a module only if it owns at least one trainable parameter.
                    return any(p.requires_grad for p in module.parameters())

                chosen = FSDPStrategy(auto_wrap_policy=only_wrap_trainable)
            if self.train.use_replay_buffer:
                # When resampling from the replay buffer,
                # we don't project the condition_features to state, so the update_proj is not used
                chosen = "ddp_find_unused_parameters_true"
        return chosen
117
+
118
+
119
@dataclass
class RootCfg:
    """Top-level typed configuration; in test mode it also derives `output_dir`."""

    wandb: dict
    mode: Literal["train", "test"]
    dataset: DatasetCfg
    data_loader: DataLoaderCfg
    scene_trainer: SceneTrainerCfg
    meta_optimizer: MetaOptimizerCfg  ## TODO Naama: should we move under meta trainer config?
    checkpointing: CheckpointingCfg
    meta_trainer: MetaTrainerCfg
    loss: list[LossCfgWrapper]
    seed: int
    use_plugins: bool
    output_dir: str
    version: int | None
    debug_cfg: bool

    def __post_init__(self):
        if self.mode == "test":
            self._setup_test_output_dir()

    def _setup_test_output_dir(self):
        """Fill in `output_dir` for test runs.

        A "placeholder" value is replaced by a directory derived from either the
        active post-processing settings or the pretrained checkpoint's run
        directory; an output_dir containing 'experimental' is renamed instead.
        """
        # Runs limited to a subset of test batches get their own results root.
        base_res_dir = RESULTS_DIR
        if self.meta_trainer.limit_test_batches != 1.0:
            base_res_dir = RESULTS_DIR + f"_{self.meta_trainer.limit_test_batches}_scenes"

        if self.output_dir == "placeholder":
            post = self.meta_trainer.test.postprocessing
            if post is not None and post.is_active:
                self.output_dir = (
                    base_res_dir
                    / "nonlearned"
                    / "vanilla_3dgs"
                    / post.name
                    / post.get_dir_name(with_name=False)
                )
            else:
                ckpt_path = self.checkpointing.pretrained_model or self.checkpointing.pretrained_optimizer
                pretrained_model_rel_dir = checkpoint_rel_dir(ckpt_path)
                self.output_dir = base_res_dir / "optgs" / pretrained_model_rel_dir
        elif 'experimental' in str(self.output_dir):  # TODO (release): remove
            self._setup_experimental_output_dir()

    def _setup_experimental_output_dir(self):
        """Rename an 'experimental' output_dir after the per-parameter update modes."""
        optimizer_cfg = self.scene_trainer.scene_optimizer
        assert optimizer_cfg.experimental_run
        param_names = optimizer_cfg.experimental_update.param_names

        resplat_str = []
        grad_str = []
        normgrad_str = []
        for p in param_names:
            update = getattr(optimizer_cfg.experimental_update, p)
            use_norm_grad = getattr(optimizer_cfg.experimental_use_norm_grads, p)
            use_grad = optimizer_cfg.experimental_use_grads and not use_norm_grad
            use_resplat = update and not use_grad and not use_norm_grad
            if update:
                assert use_grad ^ use_norm_grad ^ use_resplat, f"Invalid combination for {p}: use_resplat={use_resplat}, use_grad={use_grad}, use_norm_grad={use_norm_grad}"
            if use_resplat:
                resplat_str.append(p)
            if use_grad:
                grad_str.append(p)
            if use_norm_grad:
                normgrad_str.append(p)

        def collapse(selected):
            # A group that covers every parameter is abbreviated to "all".
            return ["all"] if len(selected) == len(param_names) else selected

        def segment(prefix, selected):
            # An empty group contributes an empty segment to the joined name.
            return prefix + "_".join(selected) if len(selected) > 0 else ""

        exp_name = "_".join([
            segment("resplat_", collapse(resplat_str)),
            segment("grad_", collapse(grad_str)),
            segment("normgrad_", collapse(normgrad_str)),
        ])

        renamed = str(self.output_dir).replace("experimental", f"experimental_{exp_name}")
        self.output_dir = Path(renamed)
        print(cyan(f"Experimental run, setting output_dir to {CustomPath(self.output_dir)}"))
196
+
197
+
198
# dacite type hooks applied on every typed-config load: a value destined for a
# `Path` field is converted by calling `Path(...)` on it.
TYPE_HOOKS = {Path: Path}

# Generic placeholder for the dataclass type that `load_typed_config` returns.
T = TypeVar("T")
203
+
204
+
205
def get_class_by_path(path: str):
    """Import and return the object named by a dotted path.

    Args:
        path: Fully qualified name such as ``"package.module.ClassName"``.

    Returns:
        The attribute ``ClassName`` of the imported module ``package.module``.

    Raises:
        ValueError: If ``path`` contains no dot, so it cannot be split into a
            module and an attribute name.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_path, dot, class_name = path.rpartition('.')
    if not dot:
        # Previously this surfaced as an opaque tuple-unpacking ValueError.
        raise ValueError(
            f"Expected a dotted path like 'package.module.ClassName', got {path!r}"
        )
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
209
+
210
+
211
def _diagnose_union_error(e: UnionMatchError, data: dict, dacite_config: Config) -> str:
    """Build a multi-line report explaining why `data` matched no union member.

    Each member type of the failing union is tried on its own with `from_dict`;
    the report lists the per-member error and, for dataclass members, which
    fields are missing from or extra in `data`.
    """
    import dataclasses
    import typing

    # Member types of the union that dacite failed to match.
    members = typing.get_args(e.field_type)
    if not members:
        return str(e)

    provided = set(data.keys()) if isinstance(data, dict) else set()
    report = [str(e), "", "Per-member diagnostics:"]
    for member_type in members:
        try:
            from_dict(member_type, data, config=dacite_config)
            report.append(f" {member_type.__name__}: matched OK (unexpected)")
        except Exception as member_err:
            report.append(f" {member_type.__name__}: {member_err}")
            # For dataclasses, also check for extra/missing fields
            if not dataclasses.is_dataclass(member_type):
                continue
            expected = {field.name for field in dataclasses.fields(member_type)}
            missing = expected - provided
            extra = provided - expected
            if missing:
                report.append(f" missing fields: {missing}")
            if extra:
                report.append(f" extra fields (ignored with strict=False): {extra}")
    return "\n".join(report)
238
+
239
+
240
def load_typed_config(
    cfg: DictConfig,
    data_class: Type[T],
    extra_type_hooks: Optional[dict] = None,
) -> T:
    """Convert an OmegaConf config into an instance of `data_class` via dacite.

    Args:
        cfg: The OmegaConf config to convert.
        data_class: The dataclass type to instantiate.
        extra_type_hooks: Optional dacite type hooks merged on top of
            TYPE_HOOKS. Defaults to no extra hooks.

    Returns:
        The populated `data_class` instance.

    Raises:
        UnionMatchError: Re-raised after printing a per-member diagnostic when
            a value matches no member of a union-typed field.
    """
    # `None` instead of a shared mutable `{}` default; callers that pass a dict
    # (or nothing) behave exactly as before.
    hooks = {**TYPE_HOOKS, **(extra_type_hooks or {})}
    dacite_config = Config(type_hooks=hooks)
    try:
        return from_dict(
            data_class,
            OmegaConf.to_container(cfg),
            config=dacite_config,
        )
    except UnionMatchError as e:
        diagnostic = _diagnose_union_error(e, e.value, dacite_config)
        print(f"\n{'='*60}\n"
              f"Current config: {e.value}\n"
              "\n"
              "\n"
              f"UnionMatchError diagnostic:\n{diagnostic}\n{'='*60}"
              f"\n",
              flush=True)
        raise
262
+
263
+
264
def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]:
    """Turn a ``{loss_name: loss_cfg}`` mapping into one LossCfgWrapper per entry."""

    # The dummy allows the union to be converted.
    @dataclass
    class Dummy:
        dummy: LossCfgWrapper

    wrappers = []
    for loss_name, loss_cfg in joined.items():
        # Each entry is loaded on its own as a single-key mapping.
        single = DictConfig({"dummy": {loss_name: loss_cfg}})
        wrappers.append(load_typed_config(single, Dummy).dummy)
    return wrappers
274
+
275
+
276
def universal_target_hook(cfg: dict, _: Type) -> Any:
    """Generic hook to construct config objects from `__target__`.

    Returns None when `cfg` is not a dict or has no `__target__` key, leaving
    the value for dacite to handle. Otherwise `__target__` (a dotted path or a
    type) selects the dataclass that the remaining keys are loaded into.
    """
    if not (isinstance(cfg, dict) and "__target__" in cfg):
        return None  # Let dacite handle it

    remaining = deepcopy(cfg)  # avoid mutating original
    target = remaining.pop("__target__")
    target_type = get_class_by_path(target) if isinstance(target, str) else target

    # Recursive load of the remaining keys; no extra type hooks are passed, so
    # only the default TYPE_HOOKS apply.
    return load_typed_config(DictConfig(remaining), target_type)
296
+
297
+
298
def make_target_hook_for_type(t: Type) -> Callable:
    """Return a one-argument hook that calls `universal_target_hook(cfg, t)`."""

    def target_hook(cfg):
        return universal_target_hook(cfg, t)

    return target_hook
300
+
301
+
302
def load_typed_root_config(cfg: DictConfig) -> RootCfg:
    """Load the root config into a typed RootCfg.

    Note: `cfg` is modified in place when the scene optimizer is the "none"
    placeholder (see below).
    """
    # scene_trainer/scene_optimizer=none loads a full dict from none.yaml;
    # dacite can't match that dict to the None arm of SceneOptimizerCfg | None.
    # Convert it to Python None here so dacite matches correctly.
    optimizer_key = "scene_trainer.scene_optimizer"
    optimizer_node = OmegaConf.select(cfg, optimizer_key)
    is_none_placeholder = (
        isinstance(optimizer_node, DictConfig)
        and OmegaConf.select(optimizer_node, "name") == "none"
    )
    if is_none_placeholder:
        # Struct mode is lifted only for the duration of the update.
        OmegaConf.set_struct(cfg, False)
        OmegaConf.update(cfg, optimizer_key, None, merge=False)
        OmegaConf.set_struct(cfg, True)

    loss_hooks = {list[LossCfgWrapper]: separate_loss_cfg_wrappers}
    return load_typed_config(cfg, RootCfg, loss_hooks)
317
+
318
+
319
def should_run(cfg_dict):
    """Return False only for a test run whose metric files already exist.

    The check applies when mode is "test" and
    `meta_trainer.test.skip_if_outputs_exist` is set: the run is skipped if
    `<output_dir>/metrics/` already holds a file matching `target_*_psnr.json`.
    """
    skip_existing = cfg_dict.mode == "test" and cfg_dict.meta_trainer.test.skip_if_outputs_exist
    if not skip_existing:
        return True

    output_dir = cfg_dict.output_dir
    if not output_dir.exists():
        return True

    metric_paths = list((output_dir / "metrics").glob("target_*_psnr.json"))
    if metric_paths:
        print(cyan(f"Test metrics already exist at {metric_paths}."))
        return False
    return True
331
+
332
+
333
def setup_cfg(cfg_dict):
    """Prepare the run configuration.

    Returns a tuple ``(typed_cfg, cfg_dict, eval_cfg)``, where `cfg_dict` is
    the config after merging with the one saved on disk.
    """
    # Get the original config from the output directory, when testing or resuming.
    cfg_dict = merge_config_from_file(cfg_dict)
    eval_cfg = get_eval_cfg(cfg_dict)
    typed_cfg = load_typed_root_config(cfg_dict)

    # Set global cfg object.
    set_cfg(cfg_dict)

    # Set up the output directory.
    setup_output_dir(typed_cfg, cfg_dict)

    return typed_cfg, cfg_dict, eval_cfg  # TODO Naama: why do we need both cfg and cfg_dict?
343
+
344
+
345
def flatten_wandb(cfg):
    """Recursively replace {'desc': ..., 'value': v} with v."""
    if isinstance(cfg, list):
        return [flatten_wandb(item) for item in cfg]
    if not isinstance(cfg, dict):
        return cfg
    # A wrapper is a dict holding exactly the two keys "desc" and "value".
    if cfg.keys() == {"desc", "value"}:
        return flatten_wandb(cfg["value"])
    return {key: flatten_wandb(item) for key, item in cfg.items()}
355
+
356
+
357
def _apply_cli_overrides(merged_cfg: DictConfig, orig_cli_cfg: DictConfig, raw_overrides: list[str]) -> DictConfig:
    """
    Re-apply CLI overrides onto merged_cfg after the checkpoint config has been merged in.

    Values are taken, already composed, from orig_cli_cfg; the raw override
    strings are parsed only to learn which keys were overridden. This correctly handles:
    - Group overrides (e.g. dataset/view_sampler=evaluation) → replace subtree from cli
    - Complex values (e.g. loss=[mse,ssim]) → replace subtree from cli
    - Interpolated values (e.g. output_dir=${...}) → take resolved value from cli
    - Defaults-list overrides (+experiment=re10k) → skip (already baked into orig_cli_cfg)

    merged_cfg is modified in place (struct mode is lifted during the updates
    and re-enabled at the end) and is also returned.
    """
    if not raw_overrides:
        return merged_cfg

    from hydra.core.override_parser.overrides_parser import OverridesParser
    parsed = OverridesParser.create().parse_overrides(raw_overrides)

    print(cyan(f"Re-applying {len(raw_overrides)} CLI overrides onto merged config."))
    OmegaConf.set_struct(merged_cfg, False)

    # Architecture subtrees: CLI group default fills in *new* fields only;
    # checkpoint values win for fields that already exist.
    ARCH_KEYS = {"scene_optimizer", "scene_initializer"}
    # Sub-keys within ARCH_KEYS where CLI should always win over checkpoint values.
    CLI_WINS_SUBKEYS = {"refiner"}

    for override in parsed:
        key = override.key_or_group
        dotkey = key.replace("/", ".")

        cli_val = OmegaConf.select(orig_cli_cfg, dotkey, default=None, throw_on_resolution_failure=False)
        if cli_val is None:
            # No direct config path — e.g. +experiment=re10k is a defaults-list override
            # whose effect is already baked into orig_cli_cfg; nothing to apply.
            # NOTE(review): a key whose composed CLI value is None is skipped
            # here as well — confirm that is intended.
            print(cyan(f" Skipping '{key}' (no direct config path in cli)"))
            continue

        # For architecture group overrides: fill in missing fields from CLI defaults
        # without overriding checkpoint values for fields that already exist.
        is_group_override = "/" in key or isinstance(cli_val, (DictConfig, dict, list))
        touches_arch = is_group_override and any(arch_key in dotkey for arch_key in ARCH_KEYS)
        if touches_arch:
            # If the override targets a CLI-wins sub-key directly, CLI wins entirely.
            if set(dotkey.split(".")) & CLI_WINS_SUBKEYS:
                OmegaConf.update(merged_cfg, dotkey, cli_val, merge=False)
                print(cyan(f" '{dotkey}': replace from cli (CLI wins)"))
                continue

            existing_val = OmegaConf.select(merged_cfg, dotkey, default=None)
            if existing_val is not None:
                # cli_val provides new defaults; existing_val (checkpoint) wins for shared fields
                new_val = OmegaConf.merge(cli_val, existing_val)
                # Re-apply CLI-wins sub-keys so they override checkpoint values.
                for subkey in CLI_WINS_SUBKEYS:
                    cli_subval = OmegaConf.select(cli_val, subkey, default=None)
                    if cli_subval is None:
                        continue
                    OmegaConf.set_struct(new_val, False)
                    OmegaConf.update(new_val, subkey, cli_subval, merge=False)
                    print(cyan(f" '{dotkey}.{subkey}': CLI override applied (CLI wins)"))
                OmegaConf.update(merged_cfg, dotkey, new_val, merge=False)
                print(cyan(f" '{dotkey}': fill-missing from cli (checkpoint values preserved)"))
                continue

        # Group overrides and complex values replace the whole subtree;
        # scalars are merged so sibling keys are preserved.
        action = 'replace' if is_group_override else 'update'
        print(cyan(f" '{dotkey}': {action} from cli"))
        OmegaConf.update(merged_cfg, dotkey, cli_val, merge=not is_group_override)

    OmegaConf.set_struct(merged_cfg, True)
    return merged_cfg
430
+
431
+
432
+ def _print_cfg_diff(before: dict, after: dict, prefix: str = "") -> None:
433
+ """Recursively print keys that differ between two plain-dict config snapshots."""
434
+ all_keys = set(before) | set(after)
435
+ diffs = []
436
+ for k in sorted(all_keys):
437
+ full_key = f"{prefix}.{k}" if prefix else k
438
+ b_val = before.get(k, "<missing>")
439
+ a_val = after.get(k, "<missing>")
440
+ if isinstance(b_val, dict) and isinstance(a_val, dict):
441
+ _print_cfg_diff(b_val, a_val, prefix=full_key)
442
+ elif b_val != a_val:
443
+ diffs.append((full_key, b_val, a_val))
444
+ for full_key, b_val, a_val in diffs:
445
+ print(cyan(f" [cfg diff] {full_key}: {b_val!r} → {a_val!r}"))
446
+
447
+
448
+ def _find_config_for_checkpoint(ckpt_path) -> Path | None:
449
+ """Return the config.yaml path for a given checkpoint, or None."""
450
+ p = Path(ckpt_path).parent.parent / "config.yaml"
451
+ if p.exists():
452
+ return p
453
+ # Fall back to wandb latest-run
454
+ p = Path(ckpt_path).parent.parent / "wandb" / "latest-run" / "files" / "config.yaml"
455
+ if p.exists():
456
+ return p
457
+ return None
458
+
459
+
460
def _load_checkpoint_cfg(config_path: Path) -> DictConfig:
    """Load, migrate, and (if from wandb) flatten a checkpoint config file.

    A config counts as coming from wandb when the string "wandb" appears
    anywhere in its path.
    """
    migrated = migrate(read_omega_cfg(config_path))
    if "wandb" not in str(config_path):
        return migrated
    # Resolve to plain containers, strip the {'desc', 'value'} wrappers, and
    # rebuild an OmegaConf config from the result.
    plain = OmegaConf.to_container(migrated, resolve=True)
    return OmegaConf.create(flatten_wandb(plain))
467
+
468
+
469
def _patch_scene_initializer(target_cfg: DictConfig, init_config_path: Path, context: str) -> None:
    """
    Load scene_trainer.scene_initializer from init_config_path and patch it into target_cfg in-place.
    target_cfg must not be struct-protected when this is called.

    `context` only prefixes the log message. Nothing is patched when the loaded
    config has no scene_trainer.scene_initializer entry.
    """
    initializer_key = "scene_trainer.scene_initializer"
    initializer_subcfg = OmegaConf.select(
        _load_checkpoint_cfg(init_config_path), initializer_key, default=None
    )
    if initializer_subcfg is None:
        print(cyan("pretrained_initializer config has no scene_trainer.scene_initializer key; skipping patch."))
        return

    print(cyan(f"{context}: patching scene_trainer.scene_initializer from pretrained_initializer config."))
    OmegaConf.update(target_cfg, initializer_key, initializer_subcfg, merge=True)
481
+
482
+
483
def _resolve_config_paths(cli_cfg) -> tuple[Path | None, Path | None]:
    """
    Determine which config files to load based on CLI checkpointing settings.

    Args:
        cli_cfg: CLI config. This function reads cli_cfg.mode, cli_cfg.output_dir and
            cli_cfg.checkpointing.{pretrained_model, pretrained_optimizer,
            pretrained_initializer, load_existing_cfg, resume}.

    Returns:
        config_path: main checkpoint config (optimizer + initializer architecture), or None
        initializer_config_path: separate initializer checkpoint config (overrides main for initializer), or None

    Priority for config_path:
        resume > pretrained_model > pretrained_optimizer (> pretrained_initializer sets initializer_config_path only)

    Notes:
        - When pretrained_model is set, pretrained_initializer is not consulted here.
        - Only config_path is checked for existence before returning;
          initializer_config_path is returned as produced by _find_config_for_checkpoint
          (possibly None), so callers check it themselves.
        - NOTE(review): the nesting of the pretrained_initializer check inside the
          pretrained_optimizer branch was reconstructed from a rendering without
          indentation; confirm it against the repository copy.
    """
    pretrained_model = cli_cfg.checkpointing.pretrained_model
    pretrained_optimizer = cli_cfg.checkpointing.pretrained_optimizer
    pretrained_initializer = cli_cfg.checkpointing.pretrained_initializer
    # Checkpoint configs are only looked up in test mode or when explicitly requested.
    should_load = cli_cfg.mode == "test" or cli_cfg.checkpointing.load_existing_cfg

    config_path = None
    initializer_config_path = None

    if pretrained_model is not None:
        # A full-model checkpoint: its config is the only one considered.
        if should_load:
            config_path = _find_config_for_checkpoint(pretrained_model)
            print(cyan(f"Loading config from pretrained_model checkpoint {config_path}"
                       if config_path else f"No config found for pretrained_model {pretrained_model}."))

    elif pretrained_optimizer is not None:
        # Optimizer checkpoint supplies the main config; an initializer checkpoint may
        # additionally supply a separate initializer config.
        if should_load:
            config_path = _find_config_for_checkpoint(pretrained_optimizer)
            print(cyan(f"Loading config from pretrained_optimizer checkpoint {config_path}"
                       if config_path else f"No config found for pretrained_optimizer {pretrained_optimizer}."))
            if pretrained_initializer is not None:
                initializer_config_path = _find_config_for_checkpoint(pretrained_initializer)
                print(cyan(f"Loading initializer config from pretrained_initializer checkpoint {initializer_config_path}"
                           if initializer_config_path else f"No config found for pretrained_initializer {pretrained_initializer}."))

    elif pretrained_initializer is not None:
        # Initializer checkpoint only: config_path stays None.
        if should_load:
            initializer_config_path = _find_config_for_checkpoint(pretrained_initializer)
            print(cyan(f"Loading initializer-only config from pretrained_initializer checkpoint {initializer_config_path}"
                       if initializer_config_path else f"No config found for pretrained_initializer {pretrained_initializer}."))

    else:
        print(cyan("No pretrained_model, pretrained_optimizer, or pretrained_initializer specified, using cli config only."))

    # Resume overrides config_path to point at the output directory's saved config.
    if cli_cfg.checkpointing.resume and cli_cfg.checkpointing.load_existing_cfg:
        config_path = Path(cli_cfg.output_dir) / "config.yaml"
        print(cyan(f"Resuming: loading config from cfg.output_dir {config_path}"))
    else:
        print(cyan("Not resuming.."))

    # A main config path that does not exist on disk is treated as "no config".
    if config_path is not None and not config_path.exists():
        print(cyan(f"Config file {config_path} does not exist. Continuing with cli config only."))
        config_path = None
    elif config_path is not None:
        print(cyan(f"Found config file {config_path}."))

    return config_path, initializer_config_path
541
+
542
+
543
def _merge_test_mode(
    cli_cfg: DictConfig,
    loaded_cfg: DictConfig,
    initializer_config_path: Path | None,
    pretrained_initializer: str | None,
) -> tuple[DictConfig, DictConfig]:
    """
    Build the test-mode config: the CLI config stays the base for every setting
    (dataset, test flags, ...) and only architecture sub-configs are patched in from
    checkpoint configs.

    What gets patched:
        * scene_trainer.scene_optimizer is merged in from loaded_cfg when present.
        * scene_trainer.scene_initializer is merged in from the config at
          initializer_config_path when that path is given and exists.
        * Otherwise the CLI scene_initializer is kept as-is. Falling back to the
          initializer stored in loaded_cfg is currently disabled (see the TODO below).

    Args:
        cli_cfg: CLI config; it is modified in-place and returned as the merged config.
        loaded_cfg: config loaded from the main checkpoint.
        initializer_config_path: config of a separate initializer checkpoint, or None.
        pretrained_initializer: value of checkpointing.pretrained_initializer; only
            used to decide whether the "no config file" message is printed.

    Returns:
        (merged_cfg, orig_cli_cfg). merged_cfg is cli_cfg itself with struct mode
        re-enabled; orig_cli_cfg is a copy taken before any patch so that
        _apply_cli_overrides can restore explicit CLI values.
    """
    OmegaConf.set_struct(cli_cfg, False)

    # Take the snapshot first: every patch below mutates cli_cfg directly, and the
    # override step needs the untouched CLI values.
    cli_snapshot = OmegaConf.to_container(cli_cfg, resolve=False, throw_on_missing=False)
    orig_cli_cfg = OmegaConf.create(cli_snapshot)

    # Optimizer architecture comes from the main checkpoint config.
    optimizer_patch = OmegaConf.select(loaded_cfg, "scene_trainer.scene_optimizer", default=None)
    if optimizer_patch is not None:
        print(cyan("Test mode: patching scene_trainer.scene_optimizer from checkpoint config."))
        OmegaConf.update(cli_cfg, "scene_trainer.scene_optimizer", optimizer_patch, merge=True)

    # Initializer architecture comes from the separate initializer config, if there is one.
    has_initializer_cfg = initializer_config_path is not None and initializer_config_path.exists()
    if has_initializer_cfg:
        _patch_scene_initializer(cli_cfg, initializer_config_path, context="Test mode")
    elif pretrained_initializer is not None:
        print(cyan("pretrained_initializer set but has no config file; using CLI scene_initializer config."))
    # TODO Naama
    # When pretrained_initializer is None nothing is patched. The intended fallback to
    # the optimizer checkpoint's initializer is disabled for now:
    # initializer_subcfg = OmegaConf.select(loaded_cfg, "scene_trainer.scene_initializer", default=None)
    # if initializer_subcfg is not None:
    #     print(cyan("Test mode: patching scene_trainer.scene_initializer from checkpoint config."))
    #     OmegaConf.update(merged_cfg, "scene_trainer.scene_initializer", initializer_subcfg, merge=True)

    OmegaConf.set_struct(cli_cfg, True)
    return cli_cfg, orig_cli_cfg
591
+
592
+
593
def _merge_train_mode(
    cli_cfg: DictConfig,
    loaded_cfg: DictConfig,
    initializer_config_path: Path | None,
) -> tuple[DictConfig, DictConfig]:
    """
    Build the train-mode config: values from the checkpoint config win over the CLI
    config wherever both define a field, and the CLI config supplies fields the
    checkpoint config does not have.

    If initializer_config_path is given and exists, the scene_trainer.scene_initializer
    node of that config replaces (not merges with) the one in loaded_cfg before the
    overall merge. loaded_cfg is modified in-place in that case.

    Args:
        cli_cfg: CLI config; struct mode is switched off on it before merging.
        loaded_cfg: config loaded from the main checkpoint.
        initializer_config_path: config of a separate initializer checkpoint, or None.

    Returns:
        (merged_cfg, orig_cli_cfg). merged_cfg is a new struct-protected config;
        orig_cli_cfg is a copy of cli_cfg taken before the merge, used by
        _apply_cli_overrides to restore explicit CLI values.
    """
    initializer_key = "scene_trainer.scene_initializer"

    if initializer_config_path is not None and initializer_config_path.exists():
        replacement = OmegaConf.select(
            _load_checkpoint_cfg(initializer_config_path), initializer_key, default=None
        )
        if replacement is None:
            print(cyan("pretrained_initializer config has no scene_trainer.scene_initializer key; skipping patch."))
        else:
            print(cyan("Replacing scene_trainer.scene_initializer in loaded config with initializer config."))
            OmegaConf.update(loaded_cfg, initializer_key, replacement, merge=False)

    # Copy of the CLI values as they were before the merge.
    cli_snapshot = OmegaConf.to_container(cli_cfg, resolve=False, throw_on_missing=False)
    orig_cli_cfg = OmegaConf.create(cli_snapshot)

    OmegaConf.set_struct(cli_cfg, False)
    # Later arguments take precedence, so loaded_cfg overrides cli_cfg on shared keys.
    merged_cfg = OmegaConf.merge(cli_cfg, loaded_cfg)
    OmegaConf.set_struct(merged_cfg, True)
    return merged_cfg, orig_cli_cfg
624
+
625
+
626
def merge_config_from_file(cli_cfg):
    """
    Combine the CLI config with the config saved next to a checkpoint, if one is found.

    Steps:
        1. _resolve_config_paths picks the main and the initializer config files.
        2. Without a main config file the CLI config is returned after its "version"
           is set to CURRENT_CFG_VERSION; an existing initializer config is still
           patched into scene_trainer.scene_initializer.
        3. Otherwise the main config is loaded and merged with the CLI config using
           the test-mode or train-mode strategy, depending on cli_cfg.mode.
        4. The Hydra task overrides are re-applied through _apply_cli_overrides.

    Args:
        cli_cfg: config composed from the command line.

    Returns:
        The config to use for the run.
    """
    ckpt_cfg_path, init_cfg_path = _resolve_config_paths(cli_cfg)

    if ckpt_cfg_path is None:
        # No main checkpoint config: keep the CLI config, optionally with the
        # initializer architecture patched in.
        print(cyan(f"No config file found, using cli config only. \n"
                   f"Setting config version to {CURRENT_CFG_VERSION}."))
        cli_cfg["version"] = CURRENT_CFG_VERSION
        have_init_cfg = init_cfg_path is not None and init_cfg_path.exists()
        if have_init_cfg:
            OmegaConf.set_struct(cli_cfg, False)
            _patch_scene_initializer(cli_cfg, init_cfg_path, context="No-checkpoint")
            OmegaConf.set_struct(cli_cfg, True)
        return cli_cfg

    print(cyan(f"Loading config from {ckpt_cfg_path}."))
    loaded_cfg = _load_checkpoint_cfg(ckpt_cfg_path)

    # Test: CLI is the base and only architecture sub-configs come from checkpoints.
    # Train: the checkpoint config takes priority and the CLI fills in missing fields.
    pretrained_initializer = cli_cfg.checkpointing.pretrained_initializer
    in_test_mode = cli_cfg.mode == "test"
    if in_test_mode:
        merged_cfg, orig_cli_cfg = _merge_test_mode(
            cli_cfg, loaded_cfg, init_cfg_path, pretrained_initializer
        )
    else:
        merged_cfg, orig_cli_cfg = _merge_train_mode(cli_cfg, loaded_cfg, init_cfg_path)

    # Values given explicitly on the command line must win over the loaded config.
    task_overrides = list(HydraConfig.get().overrides.task)
    return _apply_cli_overrides(merged_cfg, orig_cli_cfg, task_overrides)
660
+
661
+
662
class SkipRun(Exception):
    """
    Exception subclass that adds no behavior of its own.

    NOTE(review): judging by the name it signals that a run should be skipped; the
    raise/catch sites are outside this section.
    """
664
+
665
+
666
def setup_output_dir(cfg, cfg_dict):
    """
    Resolve the run's output directory, create it, and save cfg_dict into it.

    Args:
        cfg: config accessed by attribute (output_dir, mode, checkpointing,
            meta_trainer.test). Its output_dir and meta_trainer.test.output_path may
            be overwritten here.
        cfg_dict: config that is written to <output_dir>/config.yaml with
            OmegaConf.save. Its output_dir may be overwritten in the multirun case.

    NOTE(review): the block nesting below was reconstructed from a rendering without
    indentation; confirm it against the repository copy.
    """
    if cfg.output_dir != cfg_dict.output_dir:
        if "$" in str(cfg.output_dir):
            # interpolated value, not sure how to make it work.
            # Fall back to the value stored in cfg_dict.
            cfg.output_dir = CustomPath(cfg_dict.output_dir)
    output_dir = cfg.output_dir
    if output_dir is None:
        # No explicit directory: use the one Hydra reports for this run.
        output_dir = CustomPath(
            HydraConfig.get()["runtime"]["output_dir"]
        )
    else:  # for resuming
        output_dir = CustomPath(output_dir)
        output_dir.mkdir(exist_ok=True, parents=True)

    if HydraConfig.get().mode == RunMode.MULTIRUN and output_dir == "placeholder":
        # Hack to overcome multirun issues
        # TODO Naama, need to move to post_init of cfg
        # NOTE(review): this relies on CustomPath comparing equal to a plain string — confirm.
        output_dir = CustomPath(hydra.core.hydra_config.HydraConfig.get()["run"]["dir"])
        print(cyan(f"Multirun detected, setting output_dir to {CustomPath(output_dir):link}"))
        # save checkpoint path to a file for debugging
        ckpt_path = cfg.checkpointing.pretrained_model or cfg.checkpointing.pretrained_optimizer
        (output_dir / "ckpt_dir.txt").write_text(str(ckpt_path))
        # Keep both config objects in sync with the directory actually used.
        cfg_dict.output_dir = output_dir
        cfg.output_dir = output_dir
    output_dir.mkdir(exist_ok=True, parents=True)

    if cfg.mode == 'test':
        # An unset or placeholder test output path defaults to the run's output directory.
        if cfg.meta_trainer.test.output_path is None or str(cfg.meta_trainer.test.output_path) in ['placeholder', 'outputs/test']:
            cfg.meta_trainer.test.output_path = output_dir
        if cfg.meta_trainer.test.compute_scores:
            (cfg.meta_trainer.test.output_path / "metrics").mkdir(exist_ok=True, parents=True)
    print(cyan(f"Saving outputs to {CustomPath(output_dir):link}."))

    # Save the config to the output directory.
    cfg_dict_path = output_dir / "config.yaml"

    with open(cfg_dict_path, "w") as f:
        OmegaConf.save(cfg_dict, f)
704
+
705
+
706
def get_eval_cfg(cfg_dict):
    """
    Build the typed config used for periodic evaluation during training.

    Returns None unless cfg_dict["mode"] is "train" and
    meta_trainer.train.eval_model_every_n_val is positive. Otherwise a deep copy of
    cfg_dict is made, its dataset.view_sampler is replaced by an "evaluation" sampler
    whose index file is chosen from the dataset root (re10k, dl3dv, scannet or
    tartanair, tested in that order) and the number of context views, and the copy is
    passed to load_typed_root_config. A non-None meta_trainer.eval_index always
    replaces the chosen index file.

    Raises:
        ValueError: if cfg_dict has no "meta_trainer" entry, or if the number of
            context views has no index file for the detected dataset (for re10k this
            is tolerated when meta_trainer.eval_index is set).
        Exception: if the dataset root matches none of the known datasets.
    """
    if "meta_trainer" not in cfg_dict:
        raise ValueError("No trainer or meta_trainer in cfg_dict")
    meta_trainer_dict = cfg_dict["meta_trainer"]

    wants_eval = (
        cfg_dict["mode"] == "train"
        and meta_trainer_dict["train"]["eval_model_every_n_val"] > 0
    )
    if not wants_eval:
        return None

    eval_cfg_dict = deepcopy(cfg_dict)
    dataset_dir = str(cfg_dict["dataset"]["roots"]).lower()

    # (dataset marker, error for unsupported view counts, ((num_context_views, index file), ...))
    # Entries are tried in the listed order.
    index_tables = (
        ("re10k", "unsupported number of views for re10k", (
            (2, "assets/evaluation_index_re10k.json"),
            (4, "assets/re10k_start_0_distance_150_ctx_4v_tgt_6v.json"),
            (6, "assets/re10k_start_0_distance_200_ctx_6v_tgt_6v.json"),
        )),
        ("dl3dv", "unsupported number of views for dl3dv", (
            (6, "assets/dl3dv_start_0_distance_50_ctx_6v_tgt_8v.json"),
            (2, "assets/dl3dv_start_0_distance_20_ctx_2v_tgt_4v.json"),
            (8, "assets/dl3dv_evaluation/dl3dv_start_0_distance_40_ctx_8v_tgt_8v.json"),
            (16, "assets/dl3dv_evaluation/dl3dv_start_0_distance_80_ctx_16v_tgt_16v.json"),
            (32, "assets/dl3dv_evaluation/dl3dv_start_0_distance_160_ctx_32v_tgt_24v.json"),
            (64, "assets/dl3dv_benchmark/dl3dv_ctx_64v_tgt_every8th.json"),
            (-1, "assets/dl3dv_evaluation/dl3dv_start_0_distance_40_ctx_8v_tgt_8v.json"),
        )),
        ("scannet", "unsupported number of views for scannet", (
            (2, "assets/evaluation_index_scannet_view2.json"),
        )),
        ("tartanair", "unsupported number of views for tartanair", (
            (2, "assets/evaluation_index_tartanair_view2.json"),
        )),
    )

    for dataset_key, unsupported_msg, candidates in index_tables:
        if dataset_key in dataset_dir:
            break
    else:
        raise Exception("Fail to load eval index path")

    num_context_views = cfg_dict["dataset"]["view_sampler"]["num_context_views"]
    for views, index_file in candidates:
        if num_context_views == views:
            if dataset_key == "dl3dv" and views == -1:
                print("Setting manually eval_path, num_context_views remains -1 for dl3dv eval")
            eval_path = index_file
            break
    else:
        # Only re10k accepts an unlisted view count, and only when an explicit
        # eval_index is going to replace the index file below.
        if dataset_key == "re10k" and meta_trainer_dict["eval_index"] is not None:
            eval_path = None  # placeholder
        else:
            raise ValueError(unsupported_msg)

    eval_cfg_dict["dataset"]["view_sampler"] = {
        "name": "evaluation",
        "index_path": eval_path,
        "num_context_views": num_context_views,
    }

    # An explicitly configured eval index takes precedence over the table entry.
    explicit_index = meta_trainer_dict["eval_index"]
    if explicit_index is not None:
        eval_cfg_dict["dataset"]["view_sampler"]["index_path"] = explicit_index

    return load_typed_root_config(eval_cfg_dict)
optgs/config/dataset/base.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ image_shape: [0, 0]
2
+ background_color: [0.0, 0.0, 0.0]
3
+ cameras_are_circular: false
4
+ overfit_to_scene: null
5
+ opencv_pose_format: false
6
+ pose_align_middle_view: false
7
+
8
+ test_start_idx: 0
optgs/config/dataset/colmap.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - view_sampler: dense
4
+
5
+ name: colmap
6
+ roots: null
7
+ scene_name: null
8
+ normalize_world_space: false
9
+ subsample_factor: 8
10
+ symmetric_principal_point: false
11
+
12
+ crop_size: null
optgs/config/dataset/dl3dv.yaml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - view_sampler: boundedv2_360
4
+
5
+ name: dl3dv
6
+ roots: [datasets/dl3dv]
7
+ make_baseline_1: false
8
+ augment: true
9
+
10
+
11
+ image_shape: [270, 480]
12
+
13
+ baseline_epsilon: 1e-3
14
+ max_fov: 100.0
15
+
16
+ skip_bad_shape: true
17
+ near: -1.
18
+ far: -1.
19
+ baseline_scale_bounds: false
20
+ shuffle_val: true
21
+ test_len: -1
22
+ test_chunk_interval: 1
23
+ sort_target_index: true
24
+ sort_context_index: true
25
+
26
+ train_times_per_scene: 1
27
+ test_times_per_scene: 1
28
+ ori_image_shape: [270, 480]
29
+ overfit_max_views: 148
30
+ use_index_to_load_chunk: false
31
+
32
+ mix_tartanair: false
33
+ no_mix_test_set: true
34
+ load_depth: false
35
+ center_pose: false
36
+
37
+ pose_align_first_view: false
38
+
39
+ scale_extrinsics: 1.
40
+ metric_scale_align_dl3dv: false
41
+
42
+ # view filtering
43
+ min_views: 0
44
+ max_views: 0
45
+ highres: false
46
+
47
+ # mix re10k & dl3dv
48
+ mix_re10k: false
49
+ re10k_min_view_dist: 40
50
+ re10k_max_view_dist: 300
51
+
52
+ # load remaining context views
53
+ load_remain_context: false
54
+ num_remain_context: 8
55
+
56
+ # random crop in training
57
+ random_crop: false
58
+ min_size: null
59
+ max_size: null
60
+
61
+ index_name: index.json
optgs/config/dataset/re10k.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - view_sampler: bounded
4
+
5
+ name: re10k
6
+ roots: [datasets/re10k]
7
+ make_baseline_1: false
8
+ augment: true
9
+
10
+ image_shape: [180, 320]
11
+ highres: false
12
+
13
+ baseline_epsilon: 1e-3
14
+ max_fov: 100.0
15
+
16
+ skip_bad_shape: true
17
+ near: -1.
18
+ far: -1.
19
+ baseline_scale_bounds: true
20
+ shuffle_val: true
21
+ test_len: -1
22
+ test_chunk_interval: 1
23
+
24
+ use_index_to_load_chunk: false
25
+
26
+ average_pose: false
27
+ center_pose: false
optgs/config/dataset/scannet.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - view_sampler: ids
4
+
5
+ name: scannet
6
+ roots: datasets/quicksplat_spp_data_processed
7
+ scene_name: null
8
+ split: test
9
+ subsample_factor: 1
10
+ num_context_views: 100
11
+ filter_bad_frames: true
12
+
13
+ crop_size: null
optgs/config/dataset/view_sampler/all.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ name: all
optgs/config/dataset/view_sampler/arbitrary.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ name: arbitrary
2
+
3
+ num_target_views: 1
4
+ num_context_views: 2
5
+
6
+ # If you want to hard-code context views, do so here.
7
+ context_views: null
optgs/config/dataset/view_sampler/bounded.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: bounded
2
+
3
+ num_target_views: 1
4
+ num_context_views: 2
5
+
6
+ min_distance_between_context_views: 2
7
+ max_distance_between_context_views: 6
8
+ min_distance_to_context_views: 0
9
+
10
+ warm_up_steps: 0
11
+ initial_min_distance_between_context_views: 2
12
+ initial_max_distance_between_context_views: 6
optgs/config/dataset/view_sampler/boundedv2.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: boundedv2
2
+
3
+ num_target_views: 1
4
+ num_context_views: 2
5
+
6
+ min_distance_between_context_views: 2
7
+ max_distance_between_context_views: 6
8
+ max_distance_to_context_views: 0
9
+
10
+ context_gap_warm_up_steps: 0
11
+ target_gap_warm_up_steps: 0
12
+
13
+ initial_min_distance_between_context_views: 2
14
+ initial_max_distance_between_context_views: 6
15
+ initial_max_distance_to_context_views: 0
optgs/config/dataset/view_sampler/boundedv2_360.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: boundedv2
2
+
3
+ num_target_views: 4
4
+ num_context_views: 4
5
+
6
+ min_distance_between_context_views: 20
7
+ max_distance_between_context_views: 50
8
+ max_distance_to_context_views: 0
9
+
10
+ context_gap_warm_up_steps: 10000
11
+ target_gap_warm_up_steps: 0
12
+
13
+ initial_min_distance_between_context_views: 15
14
+ initial_max_distance_between_context_views: 30
15
+ initial_max_distance_to_context_views: 0
16
+ extra_views_sampling_strategy: farthest_point
17
+ target_views_replace_sample: false
optgs/config/dataset/view_sampler/dense.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ name: dense
2
+
3
+ target_every: 8
4
+ context_every: -1
5
+ num_target_views: -1
6
+ num_context_views: -1
optgs/config/dataset/view_sampler/evaluation.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: evaluation
2
+
3
+ index_path: assets/evaluation_index_re10k_video.json
4
+ num_context_views: 2
optgs/config/dataset/view_sampler/ids.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ name: ids
2
+
3
+ context_views_ids: []
4
+ target_views_ids: []
optgs/config/dataset/view_sampler_dataset_specific_config/bounded_re10k.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ dataset:
4
+ view_sampler:
5
+ min_distance_between_context_views: 45
6
+ max_distance_between_context_views: 135
7
+ min_distance_to_context_views: 0
8
+ warm_up_steps: 30000
9
+ initial_min_distance_between_context_views: 25
10
+ initial_max_distance_between_context_views: 45
11
+ num_target_views: 4
optgs/config/dataset/view_sampler_dataset_specific_config/boundedv2_dl3dv.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ dataset:
4
+ view_sampler:
5
+ min_distance_between_context_views: 20
6
+ max_distance_between_context_views: 50
7
+ max_distance_to_context_views: 0
8
+ context_gap_warm_up_steps: 10000
9
+ target_gap_warm_up_steps: 0
10
+ initial_min_distance_between_context_views: 15
11
+ initial_max_distance_between_context_views: 30
12
+ initial_max_distance_to_context_views: 0
13
+ extra_views_sampling_strategy: farthest_point
14
+ num_target_views: 4
optgs/config/dataset/view_sampler_dataset_specific_config/evaluation_dl3dv.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ dataset:
4
+ view_sampler:
5
+ index_path: assets/dl3dv_360_v5.json
optgs/config/dataset/view_sampler_dataset_specific_config/evaluation_re10k.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ dataset:
4
+ view_sampler:
5
+ index_path: assets/evaluation_index_re10k.json
optgs/config/experiment/re10k_unified.yaml ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: re10k
5
+ - override /scene_trainer/scene_initializer: resplat_v1
6
+ - override /scene_trainer/scene_optimizer: learn2splat
7
+ - override /loss: [mse, lpips]
8
+
9
+ wandb:
10
+ name: re10k
11
+ tags: [re10k, 256x256]
12
+
13
+ data_loader:
14
+ train:
15
+ batch_size: 14
16
+
17
+ meta_trainer:
18
+ max_steps: 300_001
19
+ num_nodes: 1
20
+ test:
21
+ eval_time_skip_steps: 5
22
+ compute_scores: true
23
+ compute_scores_metrics: [psnr,ssim,lpips]
24
+ metrics_batch_size: 32
25
+
26
+ scene_trainer:
27
+ initializer:
28
+ num_depth_candidates: 128
29
+ costvolume_unet_feat_dim: 128
30
+ costvolume_unet_channel_mult: [1,1,1]
31
+ costvolume_unet_attn_res: [4]
32
+ gaussians_per_pixel: 1
33
+ depth_unet_feat_dim: 32
34
+ depth_unet_attn_res: [16]
35
+ depth_unet_channel_mult: [1,1,1,1,1]
36
+ shim_patch_size: 16
37
+ use_fsdp: false
38
+ train_scene_init: false
39
+ train_scene_opt: false
40
+ num_update_steps: 0
41
+ iter_batch_size: -1
42
+ opt_batch_size: -1
43
+ train_min_refine: 0
44
+ train_max_refine: 0
45
+
46
+
47
+ # lpips loss
48
+ loss:
49
+ lpips:
50
+ apply_after_step: 0
51
+ weight: 0.5
52
+ perceptual_loss: true
53
+ deltas:
54
+ weight: 0.0
55
+ exclude_by_norm_grad: false
56
+ exclude_by_norm_grad_opposite: true
57
+ eps: 1e-8
58
+ apply_after_step: 10000000
59
+
60
+
61
+ dataset:
62
+ image_shape: [256, 256]
63
+ roots: [datasets/re10k]
64
+ near: 0.01
65
+ far: 100.
66
+ baseline_scale_bounds: false
67
+ make_baseline_1: false
68
+ train_times_per_scene: 1
69
+ highres: false
70
+ scannet: false
71
+ tartanair: false
72
+ load_depth: false
73
+ pose_align_first_view: false
74
+ scale_extrinsics: 1.
75
+ load_remain_context: false
76
+ pose_align_middle_view: false
77
+ overfit_to_scene: null
78
+ opencv_pose_format: false
optgs/config/experiment/test_colmap.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: colmap
5
+ - override /scene_trainer/scene_initializer: null # overridden by init_opts.sh
6
+ - override /scene_trainer/scene_optimizer: null # overridden by checkpoint (ours) or CLI (baselines)
7
+ - override /scene_trainer/decoder: gsplat
8
+ - override /loss: [mse]
9
+ - override /meta_trainer/test/postprocessing: none
10
+
11
+ mode: test
12
+
13
+ scene_trainer:
14
+ train_scene_init: false
15
+ train_scene_opt: false
16
+ opt_batch_strategy: fps
17
+
18
+ checkpointing:
19
+ pretrained_model: null
20
+ pretrained_depth: null
21
+
22
+ meta_trainer:
23
+ test:
24
+ compute_scores: true
25
+ skip_if_outputs_exist: true
26
+ save_cameras_json: false
27
+ save_render_image: false
28
+ save_gaussian: false
29
+ eval_initialization: false
30
+
31
+ output_dir: placeholder
32
+ log_slurm_id: true
optgs/config/experiment/test_dl3dv.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: dl3dv
5
+ - override /scene_trainer/scene_initializer: null # overridden by checkpoint (ours) or init_opts.sh
6
+ - override /scene_trainer/scene_optimizer: null # overridden by checkpoint (ours) or CLI (baselines)
7
+ - override /scene_trainer/decoder: gsplat
8
+ - override /meta_trainer/test/postprocessing: none
9
+
10
+ mode: test
11
+
12
+ dataset:
13
+ roots: [datasets/dl3dv-480p-chunks]
14
+ near: 0.01
15
+ far: 200.
16
+ opencv_pose_format: false
17
+ image_shape: [256, 448]
18
+
19
+ scene_trainer:
20
+ train_scene_init: false
21
+ train_scene_opt: false
22
+ opt_batch_strategy: fps
23
+
24
+ checkpointing:
25
+ pretrained_model: null
26
+ pretrained_depth: null
27
+
28
+ meta_trainer:
29
+ test:
30
+ compute_scores: true
31
+ skip_if_outputs_exist: false
32
+ save_cameras_json: false
33
+ save_render_image: false
34
+ save_gaussian: false
35
+ eval_initialization: false
36
+
37
+ output_dir: placeholder
38
+ log_slurm_id: true
optgs/config/experiment/test_re10k.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: re10k
5
+ - override /scene_trainer/scene_initializer: resplat_v1
6
+ - override /scene_trainer/scene_optimizer: knn_based
7
+ - override /scene_trainer/decoder: gsplat
8
+ - override /loss: [mse]
9
+ - override /meta_trainer/test/postprocessing: none
10
+
11
+ mode: test
12
+
13
+ dataset:
14
+ image_shape: [512, 960]
15
+ ori_image_shape: [512, 960]
16
+
17
+ scene_trainer:
18
+ train_scene_init: false
19
+ train_scene_opt: false
20
+ opt_batch_strategy: fps
21
+
22
+ checkpointing:
23
+ pretrained_model: null
24
+ pretrained_depth: null
25
+
26
+ meta_trainer:
27
+ test:
28
+ compute_scores: true
29
+ skip_if_outputs_exist: true
30
+ save_cameras_json: false
31
+ save_render_image: false
32
+ save_gaussian: false
33
+ eval_initialization: false
34
+
35
+ output_dir: placeholder
36
+ log_slurm_id: true
optgs/config/experiment/train_dl3dv.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # A shared config for training on dl3dv, used by the resplat initializer, the resplat optimizer, and the learn2splat optimizer.
4
+
5
+ defaults:
6
+ - override /dataset: dl3dv
7
+ - override /scene_trainer/scene_initializer: resplat_v1
8
+ - override /scene_trainer/scene_optimizer: learn2splat
9
+ - override /loss: [ mse, lpips ]
10
+ - override /dataset/view_sampler: boundedv2_360
11
+
12
+ wandb:
13
+ name: dl3dv
14
+ tags: [ dl3dv, 270x480 ]
15
+
16
+ data_loader:
17
+ train:
18
+ batch_size: 1
19
+
20
+ meta_trainer:
21
+ max_steps: 50_000
22
+ val_check_interval: 0.25
23
+ train:
24
+ l1_loss: true
25
+ depth_smooth_loss_weight: 0.0
26
+ test:
27
+ eval_time_skip_steps: 0
28
+ dec_chunk_size: 30
29
+ save_every_freq: [ 1, 10, 100, 500 ]
30
+ save_every_steps: [ 0, 10, 100, 1000 ]
31
+
32
+ # lpips loss
33
+ loss:
34
+ lpips:
35
+ apply_after_step: 0
36
+ weight: 0.5
37
+ perceptual_loss: true
38
+
39
+ dataset:
40
+ roots: [ datasets/dl3dv-480p-chunks ]
41
+ near: 0.01
42
+ far: 200.
43
+ min_size: [ 384,512 ]
44
+ max_size: [ 512,960 ]
45
+ image_shape: [ 256, 448 ]
46
+ view_sampler:
47
+ num_context_views: 8
48
+ num_target_views: 6
49
+ min_distance_between_context_views: 24
50
+ max_distance_between_context_views: 45
51
+ initial_min_distance_between_context_views: 20
52
+ initial_max_distance_between_context_views: 30
53
+
54
+ output_dir: placeholder
55
+ log_slurm_id: true
optgs/config/experiment/train_l2s_sparse_dl3dv.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - train_dl3dv
5
+ - override /meta_trainer/train/replay_buffer_cfg: default
6
+ - override /loss: [ mse, lpips, deltas ]
7
+
8
+ loss:
9
+ mse:
10
+ weight: 1.0
11
+ lpips:
12
+ apply_after_step: 0
13
+ weight: 0.5
14
+ perceptual_loss: true
15
+ deltas:
16
+ weight: 1
17
+ exclude_by_norm_grad: true
18
+ exclude_by_norm_grad_opposite: true
19
+ eps: 1e-8
20
+ apply_after_step: 100
21
+
22
+ meta_trainer:
23
+ train:
24
+ loss_on_input_views: true
25
+ loss_on_input_views_num: 4
26
+ use_replay_buffer: true
27
+
28
+ scene_trainer:
29
+ train_scene_opt: true
30
+ num_update_steps: 4
31
+ train_max_refine: 6
32
+ train_min_refine: 1
33
+
34
+ meta_optimizer:
35
+ lr: 1e-4
36
+ lr_monodepth: 0.0
37
+
38
+
39
+ checkpointing:
40
+ pretrained_initializer: checkpoints/optgs/unified-dl3dv-8views/init/checkpoints/epoch_20-step_100000.ckpt # resplat initializer
41
+ no_strict_load: false
optgs/config/experiment/train_l2s_sparse_dl3dv_no_delta.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - train_dl3dv
5
+ - override /meta_trainer/train/replay_buffer_cfg: default
6
+ - override /loss: [ mse, lpips ]
7
+
8
+ loss:
9
+ mse:
10
+ weight: 1.0
11
+ lpips:
12
+ apply_after_step: 0
13
+ weight: 0.5
14
+ perceptual_loss: true
15
+
16
+ meta_trainer:
17
+ train:
18
+ loss_on_input_views: true
19
+ loss_on_input_views_num: 4
20
+ use_replay_buffer: true
21
+
22
+ scene_trainer:
23
+ train_scene_opt: true
24
+ num_update_steps: 4
25
+ train_max_refine: 6
26
+ train_min_refine: 1
27
+
28
+ meta_optimizer:
29
+ lr: 1e-4
30
+ lr_monodepth: 0.0
31
+
32
+
33
+ checkpointing:
34
+ pretrained_initializer: checkpoints/optgs/unified-dl3dv-8views/init/checkpoints/epoch_20-step_100000.ckpt # resplat initializer
35
+ no_strict_load: false
optgs/config/experiment/train_l2s_sparse_dl3dv_no_loss.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - train_dl3dv
5
+ - override /meta_trainer/train/replay_buffer_cfg: default
6
+ - override /loss: [ mse, lpips ]
7
+
8
+ loss:
9
+ mse:
10
+ weight: 1.0
11
+ lpips:
12
+ apply_after_step: 0
13
+ weight: 0.5
14
+ perceptual_loss: true
15
+
16
+ meta_trainer:
17
+ train:
18
+ loss_on_input_views: true
19
+ loss_on_input_views_num: 4
20
+ use_replay_buffer: true
21
+
22
+ scene_trainer:
23
+ train_scene_opt: true
24
+ num_update_steps: 4
25
+ train_max_refine: 6
26
+ train_min_refine: 1
27
+
28
+ meta_optimizer:
29
+ lr: 1e-4
30
+ lr_monodepth: 0.0
31
+
32
+
33
+ checkpointing:
34
+ pretrained_initializer: checkpoints/optgs/unified-dl3dv-8views/init/checkpoints/epoch_20-step_100000.ckpt # resplat initializer
35
+ no_strict_load: false
optgs/config/loss/deltas.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ deltas:
2
+ weight: 1.0
3
+ exclude_by_norm_grad: false
4
+ exclude_by_norm_grad_opposite: true
5
+ eps: 0.1
6
+ apply_after_step: 100
optgs/config/loss/gaussians.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gaussians:
2
+ weight: 1.0
3
+ weight_scales: 0.01
4
+ weight_opacities: 0.0
5
+ weight_sh: 0.005
6
+ sh_alpha: 1.0 # 1.0 = uniform; >1.0 = penalize higher SH degrees more
optgs/config/loss/iso_scales.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ iso_scales:
2
+ weight: 1.0
optgs/config/loss/lpips.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ lpips:
2
+ weight: 0.05
3
+ apply_after_step: 150_000
4
+ perceptual_loss: false
optgs/config/loss/mse.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ mse:
2
+ weight: 1.0
optgs/config/loss/sgd.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ sgd:
2
+ weight: 1.0
optgs/config/loss/sh0.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ mse:
2
+ weight: 1.0
optgs/config/loss/ssim.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ssim:
2
+ weight: 0.2 # default in 3dgs
optgs/config/loss/stability.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ stability:
2
+ weight: 1.0
optgs/config/main.yaml ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - loss: [ mse ]
3
+ - dataset: null
4
+ - scene_trainer/scene_initializer: null
5
+ - scene_trainer/scene_optimizer: null
6
+ - scene_trainer/decoder: gsplat
7
+ - meta_trainer/test/postprocessing: none
8
+ - meta_trainer/train/replay_buffer_cfg: none
9
+
10
+ wandb:
11
+ project: placeholder
12
+ entity: placeholder
13
+ name: placeholder
14
+ mode: online
15
+ id: null
16
+ notes: null
17
+
18
+ mode: train
19
+
20
+ data_loader:
21
+ train:
22
+ num_workers: 10
23
+ persistent_workers: true
24
+ batch_size: 4
25
+ seed: 1234
26
+ test:
27
+ num_workers: 4
28
+ persistent_workers: false
29
+ batch_size: 1
30
+ seed: 2345
31
+ val:
32
+ num_workers: 1
33
+ persistent_workers: true
34
+ batch_size: 1
35
+ seed: 3456
36
+
37
+ meta_optimizer:
38
+ lr: 2.e-4
39
+ lr_monodepth: 2.e-6
40
+ lr_depth: 0.
41
+ warm_up_steps: 2000
42
+ weight_decay: 0.01
43
+ warm_up_ratio: 0.01
44
+ adamw_8bit: false
45
+
46
+ checkpointing:
47
+ load: null
48
+ every_n_train_steps: 1000
49
+ save_top_k: 5
50
+ pretrained_model: null
51
+ pretrained_model_rel_dir: ${checkpoint_rel_dir:${checkpointing.pretrained_model}}
52
+ pretrained_monodepth: null
53
+ pretrained_mvdepth: null
54
+ pretrained_depth: null
55
+ pretrained_scale_predictor: null
56
+ pretrained_depth_teacher: null
57
+ no_strict_load: false
58
+ resume: false
59
+ no_resume_upsampler: false
60
+ partial_load: false
61
+ freeze_mono_vit: false
62
+ resume_update_module: null
63
+ pretrained_initializer: null
64
+ pretrained_optimizer: null
65
+ load_existing_cfg: false
66
+
67
+ seed: 111123
68
+
69
+ meta_trainer:
70
+ max_steps: -1
71
+ val_check_interval: 0.5
72
+ gradient_clip_val: 0.5
73
+ num_sanity_val_steps: 2
74
+ eval_index: null
75
+ limit_test_batches: 1.0
76
+ limit_train_batches: 1.0
77
+ num_nodes: 1
78
+ train:
79
+ depth_mode: null
80
+ extended_visualization: false
81
+ print_log_every_n_steps: 100
82
+ eval_model_every_n_val: 2 # run quantitative evaluation every n validation runs
83
+ eval_data_length: 999999
84
+ eval_deterministic: false
85
+ eval_time_skip_steps: 3
86
+ eval_save_model: true
87
+ l1_loss: false
88
+ intermediate_loss_weight: 0.9
89
+ no_viz_video: false
90
+ eval_depth: false
91
+ train_ignore_large_loss: 0.
92
+ no_log_projections: true
93
+ no_log_video: true
94
+ depth_loss_weight: 0.
95
+ log_depth_loss: true
96
+ depth_smooth_loss_weight: 0.01
97
+ depth_smooth_loss_nonorm: false
98
+ depth_smooth_loss_weight_nvs: 0. # for novel views
99
+ monodepth_loss_weight: 0. # for monocular depth loss
100
+ depth_teacher_loss_weight: 0.
101
+ viz_depth_teacher: false
102
+ eval_render_depth: false
103
+ render_depth_loss_weight: 0.
104
+ viz_render_depth: false
105
+ use_gt_depth_range: false
106
+ depth_range_from_disparity: false
107
+ max_disparity: 128.
108
+ min_disparity: 4.
109
+ loss_on_input_views: false
110
+ loss_on_target_views: true
111
+ loss_on_input_views_num: 1
112
+ loss_on_target_views_num: -1
113
+ train_window_size: null
114
+ half_res_lpips_loss: false
115
+ viz_depth_separate: false
116
+ # L2 weight decay on Gaussian properties (meta-loss)
117
+ scale_l2_loss_weight: 0.
118
+ sh_l2_loss_weight: 0.
119
+ opacity_l2_loss_weight: 0.
120
+ use_replay_buffer: false
121
+ test:
122
+ output_path: null
123
+ compute_scores: true
124
+ compute_scores_metrics: [psnr,ssim,lpips]
125
+ metrics_batch_size: 32
126
+ eval_time_skip_steps: 0
127
+ eval_initialization: true
128
+ save_render_image: false
129
+ save_render_image_last_only: false
130
+ save_gt_image: false
131
+ save_render_depth: false
132
+ save_gt_depth: false
133
+ save_error_image: false
134
+ save_video: false
135
+ save_video_fixed_view: false
136
+ save_video_fixed_view_index: 0
137
+ save_video_fixed_view_duplicate: 0
138
+ save_video_fixed_iteration: false
139
+ save_video_fixed_iteration_indices: null
140
+ save_video_fixed_iteration_render_fixed_view: false
141
+ save_video_combined: false
142
+ save_video_combined_iterations: null
143
+ save_video_combined_fixed_iteration_length: 50
144
+ save_gaussian: false
145
+ save_poses: false
146
+ save_cameras_json: true
147
+ save_cameras_npz: true
148
+ save_point_cloud: false
149
+ render_chunk_size: null
150
+ dec_chunk_size: null
151
+ stablize_camera: false
152
+ stab_camera_kernel: 50
153
+ eval_context_views: false
154
+ inference_window_size: null
155
+ profile_model: false
156
+ save_colmap_train_test_views: false
157
+ ori_colmap_data_path: null
158
+ adam_optimizer_step: 0
159
+ save_at_iters: null
160
+ save_every_freq: null
161
+ save_every_steps: null
162
+ skip_if_outputs_exist: false
163
+ scenes_filter: null
164
+
165
+ experimental_add_noise_to_images: false
166
+ experimental_add_noise_to_images_std: null
167
+
168
+ scene_trainer:
169
+ use_fsdp: false
170
+ train_scene_init: false
171
+ train_scene_opt: false
172
+ train_min_refine: 0
173
+ train_max_refine: 0
174
+ num_update_steps: 0
175
+ iter_batch_size: -1
176
+ opt_batch_size: -1
177
+ opt_batch_size_min: 0
178
+ opt_batch_size_max: 0
179
+ opt_batch_strategy: random
180
+ sh_degree_interval: 0
181
+
182
+ output_dir: null
183
+
184
+ use_plugins: false
185
+
186
+ log_slurm_id: false
187
+
188
+ version: null
189
+
190
+ profiling:
191
+ # one of: none, basic, advanced, pytorch
192
+ # advanced profiling requires pytorch-lightning-2.5.3 (default: 2.4.0)
193
+ mode: none
194
+
195
+ debug_cfg: false
optgs/config/meta_trainer/test/postprocessing/adam.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+
4
+ name: adam
5
+ lr_data:
6
+ _base: 0.001
7
+ betas: [0.9, 0.999]
8
+ weight_decay: 0.0
9
+ amsgrad: false
10
+ eps: 1e-08
optgs/config/meta_trainer/test/postprocessing/base.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ steps: 2000
2
+ compute_metrics_every: 100
3
+ lr_data:
4
+ _base: 1
5
+ _means: 1
6
+ _scales: 1
7
+ _opacities: 1
8
+ _quats: 1
9
+ _sh0: 1
10
+ _shN: 1
11
+ scheduler: null
12
+ scheduler_warm_up_ratio: 0.01
13
+ prior_steps: 0
14
+
15
+ # Means LR scheduling (defaults match vanilla optimizer)
16
+ means_lr_final_ratio: 0.0625 # ratio of final/initial means LR (vanilla: 1e-5 / 1.6e-4)
17
+ means_lr_delay_mult: 0.01 # ramp-up delay multiplier (vanilla default)
18
+ means_lr_scale_by_scene_extent: true
19
+
20
+ # View chunking for gradient accumulation
21
+ chunk_size: -1 # -1 = all views at once
22
+
23
+ # ADC (Adaptive Density Control) - null = disabled
24
+ adc: null
optgs/config/meta_trainer/test/postprocessing/none.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+
4
+ name: none
5
+ steps: 0
optgs/config/meta_trainer/test/postprocessing/sgd.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+
4
+ name: sgd
5
+ momentum: 0.0
6
+ weight_decay: 0.0
7
+ nesterov: false
optgs/config/meta_trainer/test/postprocessing/vanilla_3dgs.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - adam
4
+
5
+ lr_data:
6
+ _base: 1
7
+ _means: 1.6e-4
8
+ _scales: 5e-3
9
+ _opacities: 5e-2
10
+ _quats: 1e-3
11
+ _sh0: 2.5e-3
12
+ _shN: 1.25e-4 # 2.5e-3 / 20
optgs/config/meta_trainer/test/postprocessing/vanilla_3dgs_sgd.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - sgd
4
+
5
+ lr_data:
6
+ _base: 1
7
+ _means: 1.6e-4
8
+ _scales: 5e-3
9
+ _opacities: 5e-2
10
+ _quats: 1e-3
11
+ _sh0: 2.5e-3
12
+ _shN: 1.25e-4 # 2.5e-3 / 20
optgs/config/meta_trainer/train/replay_buffer_cfg/default.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ capacity: 20
2
+ sample_batch_size: 1
3
+ sample_prob: 0.7
4
+ insert_prob: 0.7
5
+ return_prob: 0.99
6
+ simulate_ahead: true
7
+ simulate_ahead_min_steps: 1
8
+ simulate_ahead_max_steps: 50
9
+ simulate_ahead_grow: 10000
10
+ max_t: null
11
+ push_only_if_not_full: false
12
+ remove_strategy_when_full: oldest