dxm21 commited on
Commit
8c48cce
·
verified ·
1 Parent(s): 0325522

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.gitattributes CHANGED
@@ -1,35 +1,26 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
3
+ 2403.20309v6.pdf filter=lfs diff=lfs merge=lfs -text
4
+ 2601.09499v1.pdf filter=lfs diff=lfs merge=lfs -text
5
+ gs/training_progress.mp4 filter=lfs diff=lfs merge=lfs -text
6
+ vdpm/examples/videos/camel.mp4 filter=lfs diff=lfs merge=lfs -text
7
+ vdpm/examples/videos/car.mp4 filter=lfs diff=lfs merge=lfs -text
8
+ vdpm/examples/videos/figure1.mp4 filter=lfs diff=lfs merge=lfs -text
9
+ vdpm/examples/videos/figure2.mp4 filter=lfs diff=lfs merge=lfs -text
10
+ vdpm/examples/videos/figure3.mp4 filter=lfs diff=lfs merge=lfs -text
11
+ vdpm/examples/videos/goldfish.mp4 filter=lfs diff=lfs merge=lfs -text
12
+ vdpm/examples/videos/horse.mp4 filter=lfs diff=lfs merge=lfs -text
13
+ vdpm/examples/videos/paragliding.mp4 filter=lfs diff=lfs merge=lfs -text
14
+ vdpm/examples/videos/pstudio.mp4 filter=lfs diff=lfs merge=lfs -text
15
+ vdpm/examples/videos/stroller.mp4 filter=lfs diff=lfs merge=lfs -text
16
+ vdpm/examples/videos/swing.mp4 filter=lfs diff=lfs merge=lfs -text
17
+ vdpm/examples/videos/tennis.mp4 filter=lfs diff=lfs merge=lfs -text
18
+ vdpm/examples/videos/tesla.mp4 filter=lfs diff=lfs merge=lfs -text
19
+ vdpm/input_images_20260128_014417_015976/images/000000.png filter=lfs diff=lfs merge=lfs -text
20
+ vdpm/input_images_20260128_014417_015976/images/000001.png filter=lfs diff=lfs merge=lfs -text
21
+ vdpm/input_images_20260128_014417_015976/images/000002.png filter=lfs diff=lfs merge=lfs -text
22
+ vdpm/input_images_20260128_014417_015976/images/000003.png filter=lfs diff=lfs merge=lfs -text
23
+ vdpm/input_images_20260128_014417_015976/output_4d.npz filter=lfs diff=lfs merge=lfs -text
24
+ vdpm/input_images_20260128_014417_015976/poses.npz filter=lfs diff=lfs merge=lfs -text
25
+ vdpm/input_images_20260128_014417_015976/reconstruction_data.zip filter=lfs diff=lfs merge=lfs -text
26
+ vdpm/input_images_20260128_014417_015976/tracks.npz filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # User requested ignores
2
+ output/
3
+ mv-video/
4
+
5
+ # Python
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+ *.so
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # Virtual Environments
30
+ venv/
31
+ env/
32
+ ENV/
33
+ env.bak/
34
+ venv.bak/
35
+
36
+ # VS Code
37
+ .vscode/
38
+
39
+ # Gradio
40
+ .gradio/
2403.20309v6.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd8415f171a0353126dcb1029126f48b805c3c24f65706bcc930c55dfc5dcc2e
3
+ size 8417471
2601.09499v1.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09bad1eec73fad7ab1cc4d9c4da01305d3c4cebb3094c06924cd4df088065738
3
+ size 13097134
README.md CHANGED
@@ -1,12 +1,113 @@
1
- ---
2
- title: 4dgs Dpm
3
- emoji: 🌍
4
- colorFrom: red
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 6.4.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: 4dgs-dpm
3
+ app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 5.17.1
6
+ ---
7
+ # DPM-Splat: Video → 4D Gaussian Splats
8
+
9
+ End-to-end pipeline combining **V-DPM** (Video Dynamic Point Maps) with **3D Gaussian Splatting** for dynamic 4D scene reconstruction from multi-view video.
10
+
11
+ ![Pipeline](https://img.shields.io/badge/Pipeline-VDPM%20→%203DGS-blue)
12
+ ![License](https://img.shields.io/badge/License-MIT-green)
13
+
14
+ ## Features
15
+
16
+ - **Feed-forward reconstruction**: No per-scene optimization needed for initial point cloud
17
+ - **Multi-view support**: 1-4 synchronized video inputs
18
+ - **Temporal consistency**: Dynamic point tracking across frames
19
+ - **Memory efficient**: BF16/FP16 quantization, flash attention support
20
+ - **Co-visibility filtering**: Reduces redundant points (InstantSplat-inspired)
21
+ - **Gradio demo**: Easy-to-use web interface
22
+
23
+ ## Demo
24
+
25
+ Run the interactive demo:
26
+ ```bash
27
+ python app.py
28
+ ```
29
+
30
+ Or try the hosted version on [Hugging Face Spaces](https://huggingface.co/spaces/YOUR_USERNAME/dpm-splat)
31
+
32
+ ## Installation
33
+
34
+ ```bash
35
+ # Create environment
36
+ conda create -n 4dgs-dpm python=3.10
37
+ conda activate 4dgs-dpm
38
+
39
+ # Install PyTorch with CUDA
40
+ pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
41
+
42
+ # Install dependencies
43
+ pip install -r requirements.txt
44
+ ```
45
+
46
+ ## Usage
47
+
48
+ ### Web Interface (Recommended)
49
+ ```bash
50
+ python app.py
51
+ ```
52
+ Upload videos, adjust settings, and download results as ZIP.
53
+
54
+ ### Command Line
55
+ ```bash
56
+ # Run VDPM inference
57
+ python vdpm/visualise.py --input mv-video/your-video --output output/vdpm
58
+
59
+ # Train 3DGS from VDPM output
60
+ python -m gs.train_vdpm --input output/vdpm --output output/splats --iterations 1000
61
+ ```
62
+
63
+ ## Pipeline
64
+
65
+ 1. **Video Processing**: Extract and interleave frames from multi-view videos
66
+ 2. **VDPM Inference**: Generate dynamic point maps and camera poses using VGGT backbone
67
+ 3. **3DGS Training**: Train per-frame Gaussian splats initialized from point maps
68
+ 4. **Animation Rendering**: Generate GIF from interpolated camera viewpoint
69
+
70
+ ## Output
71
+
72
+ The pipeline generates:
73
+ - `splats/frame_XXXX.ply` - Gaussian splat for each timestep
74
+ - `renders/` - Training progress images
75
+ - `animation.gif` - Rendered animation from average camera
76
+ - `tracks.npz` - 3D point tracks
77
+ - `poses.npz` - Camera poses
78
+
79
+ ## Requirements
80
+
81
+ - NVIDIA GPU with 8GB+ VRAM (tested on RTX 3070 Ti)
82
+ - CUDA 11.8+
83
+ - Python 3.10+
84
+
85
+ ## TO-DO
86
+
87
+ - [x] VGGT Quantization (BF16/FP16)
88
+ - [x] Co-visibility check to reduce points
89
+ - [x] Dynamic point tracking
90
+ - [x] Per-frame 3DGS training
91
+ - [x] Gradio demo with GIF rendering
92
+ - [ ] Flash Attention for VGGT
93
+ - [ ] Dynamic/Static segmentation
94
+ - [ ] 3DGS with dynamic deformation field
95
+ - [ ] 4DGS primitive support
96
+
97
+ ## Citation
98
+
99
+ ```bibtex
100
+ @misc{dpmsplat2026,
101
+ title={DPM-Splat: Video to 4D Gaussian Splats via Dynamic Point Maps},
102
+ author={Your Name},
103
+ year={2026},
104
+ url={https://github.com/YOUR_USERNAME/4dgs-dpm}
105
+ }
106
+ ```
107
+
108
+ ## Acknowledgements
109
+
110
+ - [VGGT](https://github.com/facebookresearch/vggt) - Visual Geometry Grounded Transformer
111
+ - [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting)
112
+ - [NVIDIA Warp](https://github.com/NVIDIA/warp)
113
+
app.py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DPM-Splat: End-to-end pipeline for Video → 4D Gaussian Splats
3
+ Combines VDPM inference with 3DGS training in a single Gradio interface.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import shutil
9
+ import zipfile
10
+ import gc
11
+ import json
12
+ import glob
13
+ import time
14
+ from pathlib import Path
15
+ from datetime import datetime
16
+
17
+ import cv2
18
+ import numpy as np
19
+ import gradio as gr
20
+ import torch
21
+ import imageio
22
+
23
+ # Set memory optimization
24
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
25
+
26
+ # Add paths
27
+ sys.path.insert(0, str(Path(__file__).parent / "vdpm"))
28
+ sys.path.insert(0, str(Path(__file__).parent / "gs"))
29
+
30
+ # Check GPU availability
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
+
33
+ if device == "cuda":
34
+ torch.backends.cuda.matmul.allow_tf32 = True
35
+ torch.backends.cudnn.allow_tf32 = True
36
+ gpu_name = torch.cuda.get_device_name(0)
37
+ gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
38
+ print(f"✓ GPU: {gpu_name} ({gpu_mem:.1f} GB)")
39
+ else:
40
+ print("⚠ No GPU detected - running on CPU (will be slow)")
41
+
42
+ # Configuration
43
+ VIDEO_SAMPLE_HZ = 1.0
44
+ MAX_FRAMES = 8 if device == "cuda" else 4
45
+
46
+ # Global model cache
47
+ _vdpm_model = None
48
+
49
+
50
def get_vdpm_model():
    """Load the VDPM model, downloading its weights on first use.

    The instance is cached in the module-level ``_vdpm_model`` global, so
    repeated calls within one process reuse the same model. Weights are
    cached on disk under ``~/.cache/vdpm``. On CUDA devices the model is
    cast to BF16 (compute capability >= 8) or FP16 to reduce VRAM usage.

    Returns:
        The VDPM model in eval mode, placed on the global ``device``.
    """
    global _vdpm_model

    if _vdpm_model is not None:
        print("✓ Using cached VDPM model")
        return _vdpm_model

    print("Loading VDPM model...")
    sys.stdout.flush()

    # Deferred imports: hydra and the VDPM package are heavy and only
    # needed on the first load.
    from hydra import compose, initialize
    from hydra.core.global_hydra import GlobalHydra
    from dpm.model import VDPM

    # Hydra refuses to re-initialize; clear any prior global state first.
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()

    with initialize(config_path="vdpm/configs"):
        cfg = compose(config_name="visualise")

    model = VDPM(cfg).to(device)

    # Load weights (downloaded once, then served from the local cache).
    cache_dir = os.path.expanduser("~/.cache/vdpm")
    os.makedirs(cache_dir, exist_ok=True)
    model_path = os.path.join(cache_dir, "vdpm_model.pt")

    _URL = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"

    if not os.path.exists(model_path):
        # Fix: dropped the stray f-prefix on a string with no placeholders.
        print("Downloading VDPM model...")
        sd = torch.hub.load_state_dict_from_url(_URL, file_name="vdpm_model.pt", progress=True, map_location=device)
        torch.save(sd, model_path)
    else:
        print(f"✓ Loading cached model from {model_path}")
        sd = torch.load(model_path, map_location=device)

    model.load_state_dict(sd, strict=True)
    model.eval()

    # Half precision: BF16 where supported (Ampere+), otherwise FP16.
    if device == "cuda":
        if torch.cuda.get_device_capability()[0] >= 8:
            model = model.to(torch.bfloat16)
            print("✓ Using BF16 precision")
        else:
            model = model.half()
            print("✓ Using FP16 precision")

    _vdpm_model = model
    return model
102
+
103
+
104
def process_videos(video_files, target_dir):
    """Decode the uploaded videos and write interleaved frames to disk.

    Frames are sampled at roughly ``VIDEO_SAMPLE_HZ`` per video and written
    as ``images/NNNNNN.png`` in round-robin (view-interleaved) order. A
    ``meta.json`` recording the number of views is written alongside.

    Args:
        video_files: Uploaded file objects (or path-likes), one per view.
        target_dir: Run directory (``Path``) to populate.

    Returns:
        Tuple of (list of saved frame paths, number of views).
    """
    images_dir = target_dir / "images"
    images_dir.mkdir(parents=True, exist_ok=True)

    num_views = len(video_files)
    readers = []
    sample_steps = []

    for upload in video_files:
        src = upload.name if hasattr(upload, 'name') else str(upload)
        cap = cv2.VideoCapture(src)
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 30.0)
        # Keep one frame out of every `step` to hit the target sample rate.
        step = max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1)
        readers.append(cap)
        sample_steps.append(step)

    # Round-robin across all readers until every stream is exhausted, so
    # saved frames alternate between views.
    frame_num = 0
    step_count = 0
    image_paths = []
    any_alive = True

    while any_alive:
        any_alive = False
        for view, cap in enumerate(readers):
            if not cap.isOpened():
                continue
            ok, frame = cap.read()
            if not ok:
                cap.release()
                continue
            any_alive = True
            if step_count % sample_steps[view] == 0:
                out_path = images_dir / f"{frame_num:06d}.png"
                cv2.imwrite(str(out_path), frame)
                image_paths.append(str(out_path))
                frame_num += 1
        step_count += 1

    for cap in readers:
        if cap.isOpened():
            cap.release()

    # Persist the view count for the downstream inference step.
    with open(target_dir / "meta.json", "w") as f:
        json.dump({"num_views": num_views}, f)

    return image_paths, num_views
154
+
155
+
156
def run_vdpm_inference(target_dir, progress):
    """Run VDPM inference over the extracted frames in ``target_dir``.

    Loads frames from ``target_dir/images``, groups them into (camera, time)
    views, runs the cached VDPM model, and writes ``tracks.npz`` (world
    points + confidences) and, when available, ``poses.npz`` (camera pose
    encodings) back into ``target_dir``.

    Args:
        target_dir: Run directory (``Path``) with ``images/`` and ``meta.json``.
        progress: Gradio progress callback.

    Returns:
        Tuple (num_timesteps, num_views).

    Raises:
        ValueError: If no images are found under ``target_dir/images``.
    """
    from vggt.utils.load_fn import load_and_preprocess_images

    model = get_vdpm_model()

    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    if not image_names:
        raise ValueError("No images found")

    # Number of camera views, recorded by process_videos().
    meta_path = target_dir / "meta.json"
    num_views = 1
    if meta_path.exists():
        with open(meta_path) as f:
            num_views = json.load(f).get("num_views", 1)

    # Cap the frame count, rounding down to a multiple of num_views so each
    # timestep keeps a complete set of views.
    if len(image_names) > MAX_FRAMES:
        limit = (MAX_FRAMES // num_views) * num_views
        if limit == 0:
            limit = num_views
        print(f"⚠ Limiting to {limit} frames")
        image_names = image_names[:limit]

    progress(0.15, desc=f"Loading {len(image_names)} images...")
    images = load_and_preprocess_images(image_names).to(device)

    # One view dict per image: frames were interleaved per camera, so the
    # camera index cycles and the time index advances every num_views frames.
    views = []
    for i in range(len(image_names)):
        t_idx = i // num_views
        cam_idx = i % num_views
        views.append({
            "img": images[i].unsqueeze(0),
            "view_idxs": torch.tensor([[cam_idx, t_idx]], device=device, dtype=torch.long)
        })

    progress(0.2, desc="Running VDPM forward pass...")
    print(f"Running inference on {len(image_names)} images...")
    sys.stdout.flush()

    with torch.no_grad():
        # Fix: the original hard-coded autocast('cuda'), which warns and
        # silently disables autocast on CPU-only hosts. Use the actual
        # device and only enable mixed precision on CUDA.
        with torch.amp.autocast(device_type=device, enabled=(device == "cuda")):
            predictions = model.inference(views=views)

    # Extract results and move them off the GPU immediately.
    pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]]

    pose_enc = None
    if "pose_enc" in predictions:
        pose_enc = predictions["pose_enc"].detach().cpu().numpy()

    # Free the prediction graph before the large numpy concatenations.
    del predictions
    torch.cuda.empty_cache()

    world_points_raw = np.concatenate(pts_list, axis=0)
    world_points_conf_raw = np.concatenate(conf_list, axis=0)

    T = world_points_raw.shape[0]
    S = world_points_raw.shape[1]
    num_timesteps = T

    # Multi-view case: each timestep's row contains all views' maps stacked
    # along axis 1; slice out the block belonging to that timestep.
    if num_views > 1 and S == num_views * T:
        world_points_list = []
        world_points_conf_list = []
        for t in range(T):
            start_idx = t * num_views
            end_idx = start_idx + num_views
            world_points_list.append(world_points_raw[t, start_idx:end_idx])
            world_points_conf_list.append(world_points_conf_raw[t, start_idx:end_idx])
        world_points = np.stack(world_points_list, axis=0)
        world_points_conf = np.stack(world_points_conf_list, axis=0)
    else:
        # Single-view: drop a leading singleton batch dim if present.
        if world_points_raw.ndim == 5 and world_points_raw.shape[0] == 1:
            world_points = world_points_raw[0]
            world_points_conf = world_points_conf_raw[0]
        else:
            world_points = world_points_raw
            world_points_conf = world_points_conf_raw

    progress(0.35, desc="Saving VDPM outputs...")

    # Save outputs for the 3DGS training stage.
    np.savez_compressed(
        target_dir / "tracks.npz",
        world_points=world_points,
        world_points_conf=world_points_conf,
        num_views=num_views,
        num_timesteps=num_timesteps
    )

    if pose_enc is not None:
        np.savez_compressed(target_dir / "poses.npz", pose_enc=pose_enc)

    print(f"✓ VDPM complete: {num_timesteps} timesteps, {num_views} views")
    sys.stdout.flush()

    return num_timesteps, num_views
257
+
258
+
259
def run_3dgs_training(target_dir, output_dir, iterations, conf_threshold, progress):
    """Train one 3D Gaussian splat per timestep from the VDPM outputs.

    For each timestep a ``VDPM3DGSTrainer`` is built from the point maps,
    then optimized for ``iterations`` steps with a manual render → L1 loss →
    backward → Adam update loop implemented on NVIDIA Warp, and the final
    splat is exported as a PLY.

    Args:
        target_dir: Run directory holding the VDPM outputs (tracks/poses).
        output_dir: Directory to write per-frame PLYs and checkpoints into.
        iterations: Optimization steps per frame (0 exports the raw init).
        conf_threshold: Point-confidence percentile filter for the trainer.
        progress: Gradio progress callback (mapped to the 0.4–0.9 range).

    Returns:
        List of PLY file paths, one per trained timestep.
    """
    import warp as wp
    from train_vdpm import load_vdpm_data, VDPM3DGSTrainer
    # Fix: these five imports used to live inside the innermost training
    # loop and were re-executed every iteration; hoisted here (imports are
    # cached, so behavior is identical — this is purely idiomatic).
    from forward import render_gaussians
    from loss import l1_loss, compute_image_gradients
    from backward import backward
    from optimizer import adam_update
    from config import DEVICE

    wp.init()

    data = load_vdpm_data(str(target_dir))
    num_timesteps = data['T']

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    all_ply_files = []

    for frame_idx in range(num_timesteps):
        # Training covers the 0.4–0.9 span of the overall pipeline progress.
        frame_progress = 0.4 + (0.5 * frame_idx / num_timesteps)
        progress(frame_progress, desc=f"Training frame {frame_idx + 1}/{num_timesteps}...")

        print(f"\n{'='*50}")
        print(f"[Frame {frame_idx + 1}/{num_timesteps}]")
        print(f"{'='*50}")
        sys.stdout.flush()

        trainer = VDPM3DGSTrainer(
            data=data,
            frame_idx=frame_idx,
            output_path=str(output_path),
            conf_threshold=conf_threshold
        )

        print(f"Training for {iterations} iterations...")
        sys.stdout.flush()

        trainer.save(0)  # Initial state

        for it in range(iterations):
            trainer.zero_grad()

            # Optimize against a randomly chosen training view each step.
            cam_idx = np.random.randint(len(trainer.cameras))
            camera = trainer.cameras[cam_idx]
            target = trainer.images[cam_idx]

            # Forward rasterization; intermediate buffers are kept for the
            # backward pass below.
            rendered, depth, trainer.intermediate_buffers = render_gaussians(
                background=np.array(trainer.config['background_color'], dtype=np.float32),
                means3D=trainer.params['positions'].numpy(),
                colors=None,
                opacity=trainer.params['opacities'].numpy(),
                scales=trainer.params['scales'].numpy(),
                rotations=trainer.params['rotations'].numpy(),
                scale_modifier=1.0,
                viewmatrix=camera['world_to_camera'],
                projmatrix=camera['full_proj_matrix'],
                tan_fovx=camera['tan_fovx'],
                tan_fovy=camera['tan_fovy'],
                image_height=camera['height'],
                image_width=camera['width'],
                sh=trainer.params['shs'].numpy(),
                degree=3,
                campos=camera['camera_center'],
                prefiltered=False,
                antialiasing=True,
            )

            target_wp = wp.array(target.astype(np.float32), dtype=wp.vec3, device=DEVICE)
            loss = l1_loss(rendered, target_wp)
            trainer.losses.append(loss)

            # Per-pixel dL/dpixel buffer (pure L1; SSIM term disabled).
            pixel_grad_buffer = compute_image_gradients(rendered, target_wp, lambda_dssim=0)

            view_matrix = wp.mat44(camera['world_to_camera'].flatten())
            proj_matrix = wp.mat44(camera['full_proj_matrix'].flatten())
            campos = wp.vec3(camera['camera_center'][0], camera['camera_center'][1], camera['camera_center'][2])

            # Repackage the rasterizer's intermediates into the buffer dicts
            # expected by backward().
            geom_buffer = {
                'radii': trainer.intermediate_buffers['radii'],
                'means2D': trainer.intermediate_buffers['points_xy_image'],
                'conic_opacity': trainer.intermediate_buffers['conic_opacity'],
                'rgb': trainer.intermediate_buffers['colors'],
                'clamped': trainer.intermediate_buffers['clamped_state']
            }
            binning_buffer = {'point_list': trainer.intermediate_buffers['point_list']}
            img_buffer = {
                'ranges': trainer.intermediate_buffers['ranges'],
                'final_Ts': trainer.intermediate_buffers['final_Ts'],
                'n_contrib': trainer.intermediate_buffers['n_contrib']
            }

            gradients = backward(
                background=np.array(trainer.config['background_color'], dtype=np.float32),
                means3D=trainer.params['positions'],
                dL_dpixels=pixel_grad_buffer,
                opacity=trainer.params['opacities'],
                shs=trainer.params['shs'],
                scales=trainer.params['scales'],
                rotations=trainer.params['rotations'],
                scale_modifier=trainer.config['scale_modifier'],
                viewmatrix=view_matrix,
                projmatrix=proj_matrix,
                tan_fovx=camera['tan_fovx'],
                tan_fovy=camera['tan_fovy'],
                image_height=camera['height'],
                image_width=camera['width'],
                campos=campos,
                radii=trainer.intermediate_buffers['radii'],
                means2D=trainer.intermediate_buffers['points_xy_image'],
                conic_opacity=trainer.intermediate_buffers['conic_opacity'],
                rgb=trainer.intermediate_buffers['colors'],
                cov3Ds=trainer.intermediate_buffers['cov3Ds'],
                clamped=trainer.intermediate_buffers['clamped_state'],
                geom_buffer=geom_buffer,
                binning_buffer=binning_buffer,
                img_buffer=img_buffer,
                degree=trainer.config['sh_degree'],
                debug=False
            )

            wp.copy(trainer.grads['positions'], gradients['dL_dmean3D'])
            wp.copy(trainer.grads['scales'], gradients['dL_dscale'])
            wp.copy(trainer.grads['rotations'], gradients['dL_drot'])
            wp.copy(trainer.grads['opacities'], gradients['dL_dopacity'])
            wp.copy(trainer.grads['shs'], gradients['dL_dshs'])

            # Exponentially decayed LR (×0.1 over the run); per-parameter
            # multipliers follow below in the adam_update launch.
            lr = 0.001 * (0.1 ** (it / iterations))
            wp.launch(adam_update, dim=trainer.num_points, inputs=[
                trainer.params['positions'], trainer.params['scales'],
                trainer.params['rotations'], trainer.params['opacities'], trainer.params['shs'],
                trainer.grads['positions'], trainer.grads['scales'],
                trainer.grads['rotations'], trainer.grads['opacities'], trainer.grads['shs'],
                trainer.adam_m['positions'], trainer.adam_m['scales'],
                trainer.adam_m['rotations'], trainer.adam_m['opacities'], trainer.adam_m['shs'],
                trainer.adam_v['positions'], trainer.adam_v['scales'],
                trainer.adam_v['rotations'], trainer.adam_v['opacities'], trainer.adam_v['shs'],
                trainer.num_points, lr, lr*5, lr*5, lr*2, lr*5,
                0.9, 0.999, 1e-8, it
            ])

            # Progress logging
            if (it + 1) % 100 == 0:
                print(f"  Iter {it+1}/{iterations} | Loss: {loss:.4f}")
                sys.stdout.flush()

            # Checkpoints
            if (it + 1) % 500 == 0 or it == iterations - 1:
                trainer.save(it + 1)

        ply_path = trainer.save_final()
        all_ply_files.append(str(ply_path))
        print(f"✓ Frame {frame_idx} complete: {ply_path}")
        sys.stdout.flush()

    return all_ply_files
418
+
419
+
420
def render_animation_gif(ply_files, data, output_path, progress, fps=10):
    """Render the per-frame splats into a looping GIF.

    A single virtual camera is used for all frames: the mean of every
    estimated camera center, paired with the first camera's orientation and
    intrinsics. Each PLY is rasterized against a white background and the
    frames are assembled with imageio.

    Args:
        ply_files: List of PLY file paths for each frame
        data: VDPM data dict with camera info
        output_path: Path to save the GIF
        progress: Gradio progress callback
        fps: Frames per second for GIF
    """
    import warp as wp
    from forward import render_gaussians
    from utils.point_cloud_utils import load_ply
    from utils.math_utils import projection_matrix
    from train_vdpm import decode_poses

    if not ply_files:
        return None

    print("Rendering animation GIF...")
    sys.stdout.flush()

    images = data['images']
    img_H, img_W = images.shape[1:3]

    # Recover cameras from pose encodings; fall back to identity extrinsics
    # with a crude focal estimate when no poses were produced.
    pose_enc = data.get('pose_enc')
    if pose_enc is not None:
        extrinsics, intrinsics = decode_poses(pose_enc, (img_H, img_W))
    else:
        N = data['T'] * data['V']
        extrinsics = np.tile(np.eye(4, dtype=np.float32), (N, 1, 1))
        fx = fy = max(img_H, img_W)
        K = np.array([[fx, 0, img_W/2], [0, fy, img_H/2], [0, 0, 1]], dtype=np.float32)
        intrinsics = np.tile(K, (N, 1, 1))

    # Average camera center over every view (camera center = -R^T t).
    centers = [-E[:3, :3].T @ E[:3, 3] for E in extrinsics]
    avg_center = np.mean(centers, axis=0)

    # Borrow the first camera's orientation and intrinsics.
    R = extrinsics[0][:3, :3]
    intrinsic = intrinsics[0]
    fx, fy = intrinsic[0, 0], intrinsic[1, 1]

    # Translation placing that orientation at the averaged center.
    t = -R @ avg_center

    # Camera matrices, transposed for the Warp/OpenGL convention.
    world_to_camera = np.eye(4, dtype=np.float32)
    world_to_camera[:3, :3] = R
    world_to_camera[:3, 3] = t
    world_to_camera = world_to_camera.T

    fov_x = 2 * np.arctan(img_W / (2 * fx))
    fov_y = 2 * np.arctan(img_H / (2 * fy))

    proj_matrix = projection_matrix(fovx=fov_x, fovy=fov_y, znear=0.01, zfar=100.0).T
    full_proj_matrix = world_to_camera @ proj_matrix

    tan_fovx = np.tan(fov_x / 2)
    tan_fovy = np.tan(fov_y / 2)

    background = np.array([1.0, 1.0, 1.0], dtype=np.float32)  # White background
    rendered_frames = []

    for i, ply_path in enumerate(ply_files):
        if not Path(ply_path).exists():
            continue

        progress(0.9 + 0.05 * (i / len(ply_files)), desc=f"Rendering GIF frame {i+1}/{len(ply_files)}...")

        # Load this frame's splat parameters from disk.
        splat = load_ply(ply_path)

        rendered, _, _ = render_gaussians(
            background=background,
            means3D=splat['positions'],
            colors=None,
            opacity=splat['opacities'],
            scales=splat['scales'],
            rotations=splat['rotations'],
            scale_modifier=1.0,
            viewmatrix=world_to_camera,
            projmatrix=full_proj_matrix,
            tan_fovx=tan_fovx,
            tan_fovy=tan_fovy,
            image_height=img_H,
            image_width=img_W,
            sh=splat['shs'],
            degree=3,
            campos=avg_center,
            prefiltered=False,
            antialiasing=True,
        )

        # Warp buffer → uint8 RGB frame.
        frame = wp.to_torch(rendered).cpu().numpy()
        rendered_frames.append(np.clip(frame * 255, 0, 255).astype(np.uint8))

    if not rendered_frames:
        return None

    gif_path = Path(output_path)
    imageio.mimsave(str(gif_path), rendered_frames, fps=fps, loop=0)
    print(f"✓ Animation GIF saved: {gif_path}")
    sys.stdout.flush()

    return str(gif_path)
548
+
549
+
550
def run_pipeline(video_files, iterations, conf_threshold, progress=gr.Progress()):
    """Run the full VDPM → 3DGS pipeline"""

    def _banner(title):
        # Section banner to keep the console log scannable.
        print("=" * 50)
        print(title)
        print("=" * 50)
        sys.stdout.flush()

    if not video_files:
        return None, None, None, "❌ Please upload video file(s)"

    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = Path(f"output/pipeline/run_{timestamp}")
    run_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Step 1: Process videos
        progress(0.05, desc="Processing uploaded videos...")
        _banner("Processing Videos")

        image_paths, num_views = process_videos(video_files, run_dir)
        print(f"✓ Extracted {len(image_paths)} frames from {num_views} videos")
        sys.stdout.flush()

        # Step 2: VDPM inference
        progress(0.1, desc="Running VDPM inference...")
        _banner("Running VDPM Inference")

        num_timesteps, num_views = run_vdpm_inference(run_dir, progress)

        # Drop the cached model so 3DGS training has the VRAM to itself.
        global _vdpm_model
        _vdpm_model = None
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
            print(f"✓ Cleared VRAM: {torch.cuda.memory_allocated()/1024**3:.2f} GB in use")
            sys.stdout.flush()

        # Step 3: 3DGS training
        progress(0.4, desc="Training 3D Gaussian Splats...")
        _banner("Training 3D Gaussian Splats")

        splat_dir = run_dir / "splats"
        all_ply_files = run_3dgs_training(
            run_dir, splat_dir, int(iterations), float(conf_threshold), progress
        )

        # Step 4: Render animation GIF from average camera
        progress(0.9, desc="Rendering animation GIF...")
        _banner("Rendering Animation GIF")

        gif_path = None
        if all_ply_files:
            from train_vdpm import load_vdpm_data
            data = load_vdpm_data(str(run_dir))
            gif_path = render_animation_gif(
                all_ply_files, data, run_dir / "animation.gif", progress
            )

        # Step 5: Package results
        progress(0.95, desc="Packaging results...")

        zip_path = run_dir / "results.zip"
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            # Trained splats.
            for ply in all_ply_files:
                if ply and Path(ply).exists():
                    zf.write(ply, f"splats/{Path(ply).name}")

            # Checkpoint renders from every frame/iteration.
            for render_dir in splat_dir.glob("frame_*/iter_*"):
                for img in render_dir.glob("*.png"):
                    rel_path = img.relative_to(splat_dir)
                    zf.write(img, f"renders/{rel_path}")

            # VDPM outputs at the archive root.
            for name in ["tracks.npz", "poses.npz", "meta.json"]:
                src = run_dir / name
                if src.exists():
                    zf.write(src, name)

            # Original extracted frames.
            images_dir = run_dir / "images"
            if images_dir.exists():
                for img in images_dir.glob("*"):
                    zf.write(img, f"images/{img.name}")

            # The rendered animation, when available.
            if gif_path and Path(gif_path).exists():
                zf.write(gif_path, "animation.gif")

        progress(1.0, desc="Complete!")

        # First frame's splat is used for the 3D preview widget.
        preview_ply = all_ply_files[0] if all_ply_files else None

        status = f"""✅ Pipeline Complete!

📊 Results:
• {len(all_ply_files)} PLY files generated
• {num_timesteps} timesteps × {num_views} views
• Animation GIF rendered

📁 Output: {run_dir}
📦 Download the ZIP for all files"""

        return preview_ply, str(zip_path), gif_path, status

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, None, f"❌ Error: {str(e)}"
674
+
675
+
676
# ===== Gradio Interface =====
# Two-column layout: inputs/settings on the left, 3D preview + GIF + download
# on the right. Widget creation order defines the on-page layout.
with gr.Blocks(title="DPM-Splat: 4D Gaussian Splatting", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🎬 DPM-Splat: Video → 4D Gaussian Splats

    End-to-end pipeline combining **V-DPM** (Video Dynamic Point Maps) with **3D Gaussian Splatting**.
    Upload multi-view synchronized videos to generate temporally consistent 4D reconstructions.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Multiple uploads = multiple synchronized camera views.
            video_input = gr.File(
                label="📹 Upload Videos",
                file_count="multiple",
                file_types=[".mp4", ".mov", ".avi", ".webm"]
            )

            gr.Markdown("*Upload 1-4 synchronized video files for best results*")

            with gr.Accordion("⚙️ Settings", open=True):
                # Passed through to run_3dgs_training as `iterations`.
                iterations = gr.Slider(
                    minimum=0, maximum=10000, value=1000, step=100,
                    label="Training Iterations",
                    info="0 = export raw point cloud only, more = better quality"
                )
                # Passed through to run_3dgs_training as `conf_threshold`.
                conf_threshold = gr.Slider(
                    minimum=0, maximum=100, value=0, step=5,
                    label="Confidence Threshold (%)",
                    info="0% keeps all points, higher = filter low confidence"
                )

            run_btn = gr.Button("🚀 Run Pipeline", variant="primary", size="lg")

            # Mirrors the status string returned by run_pipeline.
            status_text = gr.Textbox(
                label="Status",
                interactive=False,
                lines=6,
                value="Upload videos and click 'Run Pipeline' to begin."
            )

        with gr.Column(scale=2):
            with gr.Row():
                # Shows the first frame's PLY returned by run_pipeline.
                model_viewer = gr.Model3D(
                    label="3D Preview (First Frame)",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    height=400
                )
                gif_viewer = gr.Image(
                    label="🎞️ Animation (Average Camera)",
                    height=400
                )
            download_btn = gr.File(label="📦 Download Results (ZIP)")

    gr.Markdown("""
    ---
    ### 📋 Output Contents

    The downloaded ZIP contains:
    - `splats/frame_XXXX.ply` - Gaussian splat for each timestep
    - `renders/` - Training progress images (target vs rendered)
    - `animation.gif` - Rendered animation from average camera
    - `tracks.npz` - 3D point tracks
    - `poses.npz` - Camera poses
    - `images/` - Input frames

    **Local runs**: Results saved to `output/pipeline/run_TIMESTAMP/`
    """)

    # Output order matches run_pipeline's 4-tuple:
    # (preview_ply, zip_path, gif_path, status).
    run_btn.click(
        fn=run_pipeline,
        inputs=[video_input, iterations, conf_threshold],
        outputs=[model_viewer, download_btn, gif_viewer, status_text]
    )

if __name__ == "__main__":
    # queue() serializes GPU jobs; share=True exposes a public Gradio link.
    app.queue().launch(share=True, show_error=True)
gs/.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
gs/.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by venv; see https://docs.python.org/3/library/venv.html
2
+
3
+ data/*
4
+ output/*
5
+ lib/*
6
+ lib64/*
7
+ data_/*
8
+ colmap_0/*
9
+ bin/*
10
+ share/*
11
+ __pycache__/*
12
+ utils/__pycache__/*
13
+ .DS_Store
gs/backward.py ADDED
@@ -0,0 +1,1084 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warp as wp
2
+ import math
3
+ from utils.wp_utils import to_warp_array, wp_vec3_mul_element, wp_vec3_add_element, wp_vec3_sqrt, wp_vec3_div_element, wp_vec3_clamp
4
+ from config import * # Assuming TILE_M, TILE_N, VEC6, DEVICE are defined here
5
+
6
+ # Initialize Warp if not already done elsewhere
7
+ # wp.init()
8
+
9
+ # --- Spherical Harmonics Constants ---
10
+ SH_C0 = 0.28209479177387814
11
+ SH_C1 = 0.4886025119029199
12
+
13
+ @wp.func
14
+ def dnormvdv(v: wp.vec3, dv: wp.vec3) -> wp.vec3:
15
+ """
16
+ Computes the gradient of normalize(v) with respect to v, scaled by dv.
17
+ This is a direct port of the CUDA implementation.
18
+
19
+ Args:
20
+ v: The input vector to be normalized
21
+ dv: The gradient vector to scale the result by
22
+
23
+ Returns:
24
+ The gradient vector
25
+ """
26
+ sum2 = v[0] * v[0] + v[1] * v[1] + v[2] * v[2]
27
+
28
+ # Avoid division by zero
29
+ if sum2 < 1e-10:
30
+ return wp.vec3(0.0, 0.0, 0.0)
31
+
32
+ invsum32 = 1.0 / wp.sqrt(sum2 * sum2 * sum2)
33
+
34
+ result = wp.vec3(
35
+ ((sum2 - v[0] * v[0]) * dv[0] - v[1] * v[0] * dv[1] - v[2] * v[0] * dv[2]) * invsum32,
36
+ (-v[0] * v[1] * dv[0] + (sum2 - v[1] * v[1]) * dv[1] - v[2] * v[1] * dv[2]) * invsum32,
37
+ (-v[0] * v[2] * dv[0] - v[1] * v[2] * dv[1] + (sum2 - v[2] * v[2]) * dv[2]) * invsum32
38
+ )
39
+
40
+ return result
41
+
42
+ # --- Backward Kernels ---
43
+ @wp.kernel
44
+ def sh_backward_kernel(
45
+ # --- Inputs ---
46
+ num_points: int, # Number of Gaussian points
47
+ degree: int, # SH degree used in forward
48
+ means: wp.array(dtype=wp.vec3), # 3D positions (N, 3)
49
+ shs: wp.array(dtype=wp.vec3), # Flattened SH coeffs (N * 16, 3)
50
+ radii: wp.array(dtype=int), # Radii computed in forward (N,) - used for skipping
51
+ campos: wp.vec3, # Camera position (3,)
52
+ clamped_state: wp.array(dtype=wp.vec3), # Clamping state {0,1} from forward pass (N, 3)
53
+ dL_dcolor: wp.array(dtype=wp.vec3), # Grad L w.r.t. *final* gaussian color (N, 3)
54
+
55
+ # --- Outputs (Accumulate) ---
56
+ dL_dmeans: wp.array(dtype=wp.vec3), # Accumulate mean grads here (N, 3)
57
+ dL_dshs: wp.array(dtype=wp.vec3) # Accumulate SH grads here (N * 16, 3)
58
+ ):
59
+ idx = wp.tid()
60
+
61
+ if idx >= num_points or radii[idx] <= 0: # Skip if not rendered
62
+ return
63
+
64
+ mean = means[idx]
65
+ base_sh_idx = idx * 16
66
+
67
+ # --- Recompute view direction ---
68
+ dir_orig = mean - campos
69
+ dir_len = wp.length(dir_orig)
70
+ # Skip if direction length is too small (matches CUDA implementation)
71
+ if dir_len < 1e-8:
72
+ return
73
+
74
+ # Normalize direction
75
+ dir = dir_orig / dir_len
76
+ x = dir[0]; y = dir[1]; z = dir[2]
77
+
78
+ # --- Apply clamping mask to input gradient ---
79
+ dL_dRGB = dL_dcolor[idx]
80
+ dL_dRGB = wp_vec3_mul_element(dL_dRGB, wp_vec3_add_element(wp.vec3(1.0, 1.0, 1.0), -1.0 * clamped_state[idx]))
81
+
82
+ # Initialize gradients w.r.t. direction components (dRawColor/ddir)
83
+ dRGBdx = wp.vec3(0.0, 0.0, 0.0)
84
+ dRGBdy = wp.vec3(0.0, 0.0, 0.0)
85
+ dRGBdz = wp.vec3(0.0, 0.0, 0.0)
86
+
87
+ # --- Degree 0 ---
88
+ # Direct assignment for clarity (matching CUDA style)
89
+ dRGBdsh0 = SH_C0
90
+ dL_dshs[base_sh_idx] = dRGBdsh0 * dL_dRGB
91
+
92
+ # --- Degree 1 ---
93
+ if degree > 0:
94
+ sh1 = shs[base_sh_idx + 1]
95
+ sh2 = shs[base_sh_idx + 2]
96
+ sh3 = shs[base_sh_idx + 3]
97
+
98
+ # Exactly match CUDA computation order
99
+ dRGBdsh1 = -SH_C1 * y
100
+ dRGBdsh2 = SH_C1 * z
101
+ dRGBdsh3 = -SH_C1 * x
102
+
103
+ dL_dshs[base_sh_idx + 1] = dRGBdsh1 * dL_dRGB
104
+ dL_dshs[base_sh_idx + 2] = dRGBdsh2 * dL_dRGB
105
+ dL_dshs[base_sh_idx + 3] = dRGBdsh3 * dL_dRGB
106
+
107
+ # Gradient components w.r.t. direction
108
+ dRGBdx = -SH_C1 * sh3
109
+ dRGBdy = -SH_C1 * sh1
110
+ dRGBdz = SH_C1 * sh2
111
+ # --- Degree 2 ---
112
+ if degree > 1:
113
+ xx = x*x; yy = y*y; zz = z*z
114
+ xy = x*y; yz = y*z; xz = x*z
115
+
116
+ sh4 = shs[base_sh_idx + 4]; sh5 = shs[base_sh_idx + 5]
117
+ sh6 = shs[base_sh_idx + 6]; sh7 = shs[base_sh_idx + 7]
118
+ sh8 = shs[base_sh_idx + 8]
119
+
120
+ # Hardcoded C2 values (same as CUDA SH_C2)
121
+ C2_0 = 1.0925484305920792
122
+ C2_1 = -1.0925484305920792
123
+ C2_2 = 0.31539156525252005
124
+ C2_3 = -1.0925484305920792
125
+ C2_4 = 0.5462742152960396
126
+
127
+ # Compute gradients for degree 2 (matching CUDA)
128
+ dRGBdsh4 = C2_0 * xy
129
+ dRGBdsh5 = C2_1 * yz
130
+ dRGBdsh6 = C2_2 * (2.0 * zz - xx - yy)
131
+ dRGBdsh7 = C2_3 * xz
132
+ dRGBdsh8 = C2_4 * (xx - yy)
133
+
134
+ dL_dshs[base_sh_idx + 4] = dRGBdsh4 * dL_dRGB
135
+ dL_dshs[base_sh_idx + 5] = dRGBdsh5 * dL_dRGB
136
+ dL_dshs[base_sh_idx + 6] = dRGBdsh6 * dL_dRGB
137
+ dL_dshs[base_sh_idx + 7] = dRGBdsh7 * dL_dRGB
138
+ dL_dshs[base_sh_idx + 8] = dRGBdsh8 * dL_dRGB
139
+
140
+ # Accumulate gradients w.r.t. direction (exactly matching CUDA)
141
+ dRGBdx += C2_0 * y * sh4 + C2_2 * 2.0 * -x * sh6 + C2_3 * z * sh7 + C2_4 * 2.0 * x * sh8
142
+ dRGBdy += C2_0 * x * sh4 + C2_1 * z * sh5 + C2_2 * 2.0 * -y * sh6 + C2_4 * 2.0 * -y * sh8
143
+ dRGBdz += C2_1 * y * sh5 + C2_2 * 2.0 * 2.0 * z * sh6 + C2_3 * x * sh7
144
+
145
+ # --- Degree 3 ---
146
+ if degree > 2:
147
+ sh9 = shs[base_sh_idx + 9]; sh10 = shs[base_sh_idx + 10]
148
+ sh11 = shs[base_sh_idx + 11]; sh12 = shs[base_sh_idx + 12]
149
+ sh13 = shs[base_sh_idx + 13]; sh14 = shs[base_sh_idx + 14]
150
+ sh15 = shs[base_sh_idx + 15]
151
+
152
+ # Hardcoded C3 values (same as CUDA SH_C3)
153
+ C3_0 = -0.5900435899266435
154
+ C3_1 = 2.890611442640554
155
+ C3_2 = -0.4570457994644658
156
+ C3_3 = 0.3731763325901154
157
+ C3_4 = -0.4570457994644658
158
+ C3_5 = 1.445305721320277
159
+ C3_6 = -0.5900435899266435
160
+
161
+ # Direct computation of degree 3 gradients (matching CUDA)
162
+ dRGBdsh9 = C3_0 * y * (3.0 * xx - yy)
163
+ dRGBdsh10 = C3_1 * xy * z
164
+ dRGBdsh11 = C3_2 * y * (4.0 * zz - xx - yy)
165
+ dRGBdsh12 = C3_3 * z * (2.0 * zz - 3.0 * xx - 3.0 * yy)
166
+ dRGBdsh13 = C3_4 * x * (4.0 * zz - xx - yy)
167
+ dRGBdsh14 = C3_5 * z * (xx - yy)
168
+ dRGBdsh15 = C3_6 * x * (xx - 3.0 * yy)
169
+
170
+ dL_dshs[base_sh_idx + 9] = dRGBdsh9 * dL_dRGB
171
+ dL_dshs[base_sh_idx + 10] = dRGBdsh10 * dL_dRGB
172
+ dL_dshs[base_sh_idx + 11] = dRGBdsh11 * dL_dRGB
173
+ dL_dshs[base_sh_idx + 12] = dRGBdsh12 * dL_dRGB
174
+ dL_dshs[base_sh_idx + 13] = dRGBdsh13 * dL_dRGB
175
+ dL_dshs[base_sh_idx + 14] = dRGBdsh14 * dL_dRGB
176
+ dL_dshs[base_sh_idx + 15] = dRGBdsh15 * dL_dRGB
177
+
178
+ # Accumulate dRGBdx (matching CUDA's expression structure)
179
+ dRGBdx += (
180
+ C3_0 * sh9 * 3.0 * 2.0 * xy +
181
+ C3_1 * sh10 * yz +
182
+ C3_2 * sh11 * -2.0 * xy +
183
+ C3_3 * sh12 * -3.0 * 2.0 * xz +
184
+ C3_4 * sh13 * (-3.0 * xx + 4.0 * zz - yy) +
185
+ C3_5 * sh14 * 2.0 * xz +
186
+ C3_6 * sh15 * 3.0 * (xx - yy)
187
+ )
188
+
189
+ # Accumulate dRGBdy (matching CUDA's expression structure)
190
+ dRGBdy += (
191
+ C3_0 * sh9 * 3.0 * (xx - yy) +
192
+ C3_1 * sh10 * xz +
193
+ C3_2 * sh11 * (-3.0 * yy + 4.0 * zz - xx) +
194
+ C3_3 * sh12 * -3.0 * 2.0 * yz +
195
+ C3_4 * sh13 * -2.0 * xy +
196
+ C3_5 * sh14 * -2.0 * yz +
197
+ C3_6 * sh15 * -3.0 * 2.0 * xy
198
+ )
199
+
200
+ # Accumulate dRGBdz (matching CUDA's expression structure)
201
+ dRGBdz += (
202
+ C3_1 * sh10 * xy +
203
+ C3_2 * sh11 * 4.0 * 2.0 * yz +
204
+ C3_3 * sh12 * 3.0 * (2.0 * zz - xx - yy) +
205
+ C3_4 * sh13 * 4.0 * 2.0 * xz +
206
+ C3_5 * sh14 * (xx - yy)
207
+ )
208
+
209
+ # --- Compute gradient w.r.t. view direction (dL/ddir) ---
210
+ dL_ddir = wp.vec3(wp.dot(dRGBdx, dL_dRGB),
211
+ wp.dot(dRGBdy, dL_dRGB),
212
+ wp.dot(dRGBdz, dL_dRGB))
213
+
214
+ # --- Propagate gradient from direction to mean position (dL/dmean) ---
215
+ dL_dmeans_local = dnormvdv(dir_orig, dL_ddir)
216
+
217
+ # --- Accumulate gradients to global arrays ---
218
+ dL_dmeans[idx] += dL_dmeans_local
219
+
220
+
221
+ @wp.kernel
222
+ def compute_cov2d_backward_kernel(
223
+ # --- Inputs ---
224
+ num_points: int, # Number of Gaussian points
225
+ means: wp.array(dtype=wp.vec3), # 3D positions (N, 3)
226
+ cov3Ds: wp.array(dtype=VEC6), # Packed 3D cov (N, 6)
227
+ radii: wp.array(dtype=int), # Radii computed in forward (N,) - used for skipping
228
+ h_x: float, h_y: float, # Focal lengths
229
+ tan_fovx: float, tan_fovy: float, # Tangent of FOV
230
+ view_matrix: wp.mat44, # World->View matrix (4, 4)
231
+ dL_dconics: wp.array(dtype=wp.vec4), # Grad L w.r.t. conic (a, b, c) (N, 3)
232
+
233
+ # --- Outputs (Accumulate) ---
234
+ dL_dmeans: wp.array(dtype=wp.vec3), # Accumulate mean grads here (N, 3)
235
+ dL_dcov3Ds: wp.array(dtype=VEC6) # Accumulate 3D cov grads here (N, 6)
236
+ ):
237
+ idx = wp.tid()
238
+ if idx >= num_points or radii[idx] <= 0: # Skip if not rendered
239
+ # Zero out dL_dcov3Ds to ensure we don't keep old values
240
+ dL_dcov3Ds[idx] = VEC6(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
241
+ return
242
+
243
+ mean = means[idx]
244
+ cov3D_packed = cov3Ds[idx] # VEC6
245
+
246
+ dL_dconic = wp.vec3(dL_dconics[idx][0], dL_dconics[idx][1], dL_dconics[idx][3])
247
+
248
+
249
+ t = wp.vec4(mean[0], mean[1], mean[2], 1.0) * view_matrix
250
+
251
+ limx = 1.3 * tan_fovx
252
+ limy = 1.3 * tan_fovy
253
+ tz = t[2]
254
+ inv_tz = 1.0 / tz
255
+ txtz = t[0] * inv_tz
256
+ tytz = t[1] * inv_tz
257
+
258
+ x_clamped_flag = (txtz < -limx) or (txtz > limx)
259
+ y_clamped_flag = (tytz < -limy) or (tytz > limy)
260
+ x_grad_mul = 1.0 - float(x_clamped_flag) # 1.0 if not clamped, 0.0 if clamped
261
+ y_grad_mul = 1.0 - float(y_clamped_flag)
262
+
263
+ tx = wp.min(limx, wp.max(-limx, txtz)) * tz
264
+ ty = wp.min(limy, wp.max(-limy, tytz)) * tz
265
+ inv_tz2 = inv_tz * inv_tz
266
+ inv_tz3 = inv_tz2 * inv_tz
267
+
268
+ J00 = h_x * inv_tz
269
+ J11 = h_y * inv_tz
270
+ J02 = -h_x * tx * inv_tz2
271
+ J12 = -h_y * ty * inv_tz2
272
+
273
+ J = wp.transpose(wp.mat33(
274
+ J00, 0.0, J02,
275
+ 0.0, J11, J12,
276
+ 0.0, 0.0, 0.0
277
+ ))
278
+
279
+
280
+
281
+ W = wp.mat33(
282
+ view_matrix[0,0], view_matrix[0,1], view_matrix[0,2],
283
+ view_matrix[1,0], view_matrix[1,1], view_matrix[1,2],
284
+ view_matrix[2,0], view_matrix[2,1], view_matrix[2,2]
285
+ )
286
+
287
+ T = W * J
288
+ c0 = cov3D_packed[0]; c1 = cov3D_packed[1]; c2 = cov3D_packed[2]
289
+ c11 = cov3D_packed[3]; c12 = cov3D_packed[4]; c22 = cov3D_packed[5]
290
+ Vrk = wp.mat33(c0, c1, c2, c1, c11, c12, c2, c12, c22) # Assumes VEC6 stores upper triangle row-wise
291
+
292
+ cov2D_mat = wp.transpose(T) * wp.transpose(Vrk) * T
293
+
294
+ a_noblr = cov2D_mat[0,0]
295
+ b_noblr = cov2D_mat[0,1]
296
+ c_noblr = cov2D_mat[1,1]
297
+ a = a_noblr + 0.3
298
+ b = b_noblr
299
+ c = c_noblr + 0.3
300
+
301
+ denom = a * c - b * b
302
+ dL_da = 0.0; dL_db = 0.0; dL_dc = 0.0
303
+
304
+ # --- Calculate Gradients ---
305
+ if denom != 0.0:
306
+ # Use a small epsilon to prevent division by zero
307
+ denom2inv = 1.0 / (denom * denom + 1e-7)
308
+ dL_da = denom2inv * (-c * c * dL_dconic[0] + 2.0 * b * c * dL_dconic[1] + (denom - a * c) * dL_dconic[2])
309
+ dL_dc = denom2inv * (-a * a * dL_dconic[2] + 2.0 * a * b * dL_dconic[1] + (denom - a * c) * dL_dconic[0])
310
+ dL_db = denom2inv * 2.0 * (b * c * dL_dconic[0] - (denom + 2.0 * b * b) * dL_dconic[1] + a * b * dL_dconic[2])
311
+
312
+ dL_dcov3Ds[idx] = VEC6(
313
+ # Diagonal elements
314
+ T[0][0] * T[0][0] * dL_da + T[0][0] * T[0][1] * dL_db + T[0][1] * T[0][1] * dL_dc, # c00
315
+ 2.0 * T[0][0] * T[1][0] * dL_da + (T[0][0] * T[1][1] + T[1][0] * T[0][1]) * dL_db + 2.0 * T[0][1] * T[1][1] * dL_dc, # c01
316
+ 2.0 * T[0][0] * T[2][0] * dL_da + (T[0][0] * T[2][1] + T[2][0] * T[0][1]) * dL_db + 2.0 * T[0][1] * T[2][1] * dL_dc, # c02
317
+ T[1][0] * T[1][0] * dL_da + T[1][0] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc, # c11
318
+ 2.0 * T[2][0] * T[1][0] * dL_da + (T[1][0] * T[2][1] + T[2][0] * T[1][1]) * dL_db + 2.0 * T[1][1] * T[2][1] * dL_dc, # c12
319
+ T[2][0] * T[2][0] * dL_da + T[2][0] * T[2][1] * dL_db + T[2][1] * T[2][1] * dL_dc # c22
320
+ )
321
+
322
+ dL_dT00 = 2.0 * (T[0][0] * Vrk[0][0] + T[1][0] * Vrk[1][0] + T[2][0] * Vrk[2][0]) * dL_da + \
323
+ (T[0][1] * Vrk[0][0] + T[1][1] * Vrk[1][0] + T[2][1] * Vrk[2][0]) * dL_db
324
+ dL_dT01 = 2.0 * (T[0][0] * Vrk[0][1] + T[1][0] * Vrk[1][1] + T[2][0] * Vrk[2][1]) * dL_da + \
325
+ (T[0][1] * Vrk[0][1] + T[1][1] * Vrk[1][1] + T[2][1] * Vrk[2][1]) * dL_db
326
+ dL_dT02 = 2.0 * (T[0][0] * Vrk[0][2] + T[1][0] * Vrk[1][2] + T[2][0] * Vrk[2][2]) * dL_da + \
327
+ (T[0][1] * Vrk[0][2] + T[1][1] * Vrk[1][2] + T[2][1] * Vrk[2][2]) * dL_db
328
+ dL_dT10 = 2.0 * (T[0][1] * Vrk[0][0] + T[1][1] * Vrk[1][0] + T[2][1] * Vrk[2][0]) * dL_dc + \
329
+ (T[0][0] * Vrk[0][0] + T[1][0] * Vrk[1][0] + T[2][0] * Vrk[2][0]) * dL_db
330
+ dL_dT11 = 2.0 * (T[0][1] * Vrk[0][1] + T[1][1] * Vrk[1][1] + T[2][1] * Vrk[2][1]) * dL_dc + \
331
+ (T[0][0] * Vrk[0][1] + T[1][0] * Vrk[1][1] + T[2][0] * Vrk[2][1]) * dL_db
332
+ dL_dT12 = 2.0 * (T[0][1] * Vrk[0][2] + T[1][1] * Vrk[1][2] + T[2][1] * Vrk[2][2]) * dL_dc + \
333
+ (T[0][0] * Vrk[0][2] + T[1][0] * Vrk[1][2] + T[2][0] * Vrk[2][2]) * dL_db
334
+
335
+ dL_dJ00 = W[0,0] * dL_dT00 + W[1,0] * dL_dT01 + W[2,0] * dL_dT02
336
+ dL_dJ02 = W[0,2] * dL_dT00 + W[1,2] * dL_dT01 + W[2,2] * dL_dT02
337
+ dL_dJ11 = W[0,1] * dL_dT10 + W[1,1] * dL_dT11 + W[2,1] * dL_dT12
338
+ dL_dJ12 = W[0,2] * dL_dT10 + W[1,2] * dL_dT11 + W[2,2] * dL_dT12
339
+
340
+ dL_dtx = -h_x * inv_tz2 * dL_dJ02
341
+ dL_dty = -h_y * inv_tz2 * dL_dJ12
342
+ dL_dtz = -h_x * inv_tz2 * dL_dJ00 - h_y * inv_tz2 * dL_dJ11 + \
343
+ 2.0 * h_x * tx * inv_tz3 * dL_dJ02 + 2.0 * h_y * ty * inv_tz3 * dL_dJ12
344
+
345
+ dL_dt = wp.vec3(dL_dtx * x_grad_mul, dL_dty * y_grad_mul, dL_dtz)
346
+
347
+ dL_dmean_from_cov = wp.vec4(dL_dt[0], dL_dt[1], dL_dt[2], 1.0) * wp.transpose(view_matrix)
348
+ dL_dmeans[idx] += wp.vec3(dL_dmean_from_cov[0], dL_dmean_from_cov[1], dL_dmean_from_cov[2])
349
+
350
+
351
+ @wp.kernel
352
+ def compute_cov3d_backward_kernel(
353
+ # --- Inputs ---
354
+ num_points: int, # Number of Gaussian points
355
+ scales: wp.array(dtype=wp.vec3), # Scale parameters (N, 3)
356
+ rotations: wp.array(dtype=wp.vec4), # Quaternions (x, y, z, w) (N, 4)
357
+ radii: wp.array(dtype=int), # Radii computed in forward (N,) - used for skipping
358
+ scale_modifier: float, # Global scale modifier
359
+ dL_dcov3Ds: wp.array(dtype=VEC6), # Grad L w.r.t packed 3D cov (N, 6)
360
+
361
+ # --- Outputs ---
362
+ dL_dscales: wp.array(dtype=wp.vec3), # Write scale grads here (N, 3)
363
+ dL_drots: wp.array(dtype=wp.vec4) # Write rot grads here (N, 4)
364
+ ):
365
+ idx = wp.tid()
366
+ # Skip if not rendered OR if grad input is zero (e.g., from compute_cov2d_backward)
367
+ if idx >= num_points or radii[idx] <= 0:
368
+ dL_dscales[idx] = wp.vec3(0.0, 0.0, 0.0)
369
+ dL_drots[idx] = wp.vec4(0.0, 0.0, 0.0, 0.0)
370
+ return
371
+
372
+ # --- Recompute intermediates ---
373
+ scale_vec = scales[idx]
374
+ rot_quat = rotations[idx] # (x, y, z, w) in Warp
375
+
376
+ # Extract quaternion components to match CUDA convention (r, x, y, z)
377
+ r = rot_quat[3] # Real part is w in Warp
378
+ x = rot_quat[0]
379
+ y = rot_quat[1]
380
+ z = rot_quat[2]
381
+
382
+ # 1. Construct rotation matrix R manually as in CUDA
383
+ R = wp.mat33(
384
+ 1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - r * z), 2.0 * (x * z + r * y),
385
+ 2.0 * (x * y + r * z), 1.0 - 2.0 * (x * x + z * z), 2.0 * (y * z - r * x),
386
+ 2.0 * (x * z - r * y), 2.0 * (y * z + r * x), 1.0 - 2.0 * (x * x + y * y)
387
+ )
388
+
389
+ # 2. Create scaling matrix S
390
+ s_vec = scale_modifier * scale_vec
391
+ S = wp.mat33(
392
+ s_vec[0], 0.0, 0.0,
393
+ 0.0, s_vec[1], 0.0,
394
+ 0.0, 0.0, s_vec[2]
395
+ )
396
+
397
+ # 3. M = S * R (match CUDA multiplication order)
398
+ M = S * R
399
+
400
+ # --- Extract gradient w.r.t. 3D covariance ---
401
+ dL_dcov3D_packed = dL_dcov3Ds[idx]
402
+
403
+
404
+ # Convert per-element covariance loss gradients to matrix form
405
+ dL_dSigma = wp.mat33(
406
+ dL_dcov3D_packed[0], 0.5 * dL_dcov3D_packed[1], 0.5 * dL_dcov3D_packed[2],
407
+ 0.5 * dL_dcov3D_packed[1], dL_dcov3D_packed[3], 0.5 * dL_dcov3D_packed[4],
408
+ 0.5 * dL_dcov3D_packed[2], 0.5 * dL_dcov3D_packed[4], dL_dcov3D_packed[5]
409
+ )
410
+
411
+ # --- Calculate Gradients ---
412
+ # 1. Gradient w.r.t. M: dL/dM = 2 * M * dL/dSigma
413
+ dL_dM = 2.0 * M * dL_dSigma
414
+
415
+ # 2. Transpose of matrices for gradient calculations
416
+ Rt = wp.transpose(R)
417
+ dL_dMt = wp.transpose(dL_dM)
418
+
419
+ # 3. Gradient w.r.t. scales - matching CUDA directly
420
+ dL_dscale = wp.vec3(
421
+ wp.dot(Rt[0], dL_dMt[0]),
422
+ wp.dot(Rt[1], dL_dMt[1]),
423
+ wp.dot(Rt[2], dL_dMt[2])
424
+ )
425
+ dL_dscales[idx] = dL_dscale * scale_modifier
426
+
427
+ # 4. Scale dL_dMt by scale factors for quaternion gradient calculation
428
+ dL_dMt_scaled = wp.mat33(
429
+ dL_dMt[0, 0] * s_vec[0], dL_dMt[0, 1] * s_vec[0], dL_dMt[0, 2] * s_vec[0],
430
+ dL_dMt[1, 0] * s_vec[1], dL_dMt[1, 1] * s_vec[1], dL_dMt[1, 2] * s_vec[1],
431
+ dL_dMt[2, 0] * s_vec[2], dL_dMt[2, 1] * s_vec[2], dL_dMt[2, 2] * s_vec[2]
432
+ )
433
+
434
+ # 5. Gradients of loss w.r.t. quaternion components
435
+ dL_dr = 2.0 * (z * (dL_dMt_scaled[0, 1] - dL_dMt_scaled[1, 0]) +
436
+ y * (dL_dMt_scaled[2, 0] - dL_dMt_scaled[0, 2]) +
437
+ x * (dL_dMt_scaled[1, 2] - dL_dMt_scaled[2, 1]))
438
+
439
+ dL_dx = 2.0 * (y * (dL_dMt_scaled[1, 0] + dL_dMt_scaled[0, 1]) +
440
+ z * (dL_dMt_scaled[2, 0] + dL_dMt_scaled[0, 2]) +
441
+ r * (dL_dMt_scaled[1, 2] - dL_dMt_scaled[2, 1])) - \
442
+ 4.0 * x * (dL_dMt_scaled[2, 2] + dL_dMt_scaled[1, 1])
443
+
444
+ dL_dy = 2.0 * (x * (dL_dMt_scaled[1, 0] + dL_dMt_scaled[0, 1]) +
445
+ r * (dL_dMt_scaled[2, 0] - dL_dMt_scaled[0, 2]) +
446
+ z * (dL_dMt_scaled[1, 2] + dL_dMt_scaled[2, 1])) - \
447
+ 4.0 * y * (dL_dMt_scaled[2, 2] + dL_dMt_scaled[0, 0])
448
+
449
+ dL_dz = 2.0 * (r * (dL_dMt_scaled[0, 1] - dL_dMt_scaled[1, 0]) +
450
+ x * (dL_dMt_scaled[2, 0] + dL_dMt_scaled[0, 2]) +
451
+ y * (dL_dMt_scaled[1, 2] + dL_dMt_scaled[2, 1])) - \
452
+ 4.0 * z * (dL_dMt_scaled[1, 1] + dL_dMt_scaled[0, 0])
453
+
454
+ # 6. Convert back to Warp's quaternion ordering (x, y, z, r/w)
455
+ dL_drots[idx] = wp.vec4(dL_dx, dL_dy, dL_dz, dL_dr)
456
+
457
+ @wp.kernel
458
+ def wp_render_backward_kernel(
459
+ # --- Inputs ---
460
+ # Tile/Range data
461
+ ranges: wp.array(dtype=wp.vec2i), # Range of point indices for each tile (start, end)
462
+ point_list: wp.array(dtype=int), # Sorted point indices
463
+
464
+ # Image parameters
465
+ W: int, # Image width
466
+ H: int, # Image height
467
+ bg_color: wp.vec3, # Background color
468
+ tile_grid: wp.vec3, # Tile grid dimensions
469
+
470
+ # Gaussian parameters
471
+ points_xy_image: wp.array(dtype=wp.vec2), # 2D projected positions
472
+ conic_opacity: wp.array(dtype=wp.vec4), # Conic matrices and opacities (a, b, c, opacity)
473
+ colors: wp.array(dtype=wp.vec3), # RGB colors
474
+
475
+ # Forward pass results
476
+ final_Ts: wp.array2d(dtype=float), # Final transparency values
477
+ n_contrib: wp.array2d(dtype=int), # Number of Gaussians contributing to each pixel
478
+ dL_dpixels: wp.array2d(dtype=wp.vec3), # Gradient of loss w.r.t. output pixels
479
+
480
+ # --- Outputs ---
481
+ dL_dmean2D: wp.array(dtype=wp.vec3), # Gradient w.r.t. 2D mean positions
482
+ dL_dconic2D: wp.array(dtype=wp.vec4), # Gradient w.r.t. conic matrices
483
+ dL_dopacity: wp.array(dtype=float), # Gradient w.r.t. opacity
484
+ dL_dcolors: wp.array(dtype=wp.vec3), # Gradient w.r.t. colors
485
+ ):
486
+ """
487
+ Backward version of the rendering procedure, computing gradients of the loss with respect
488
+ to Gaussian parameters based on gradients of the loss with respect to output pixels.
489
+
490
+ This kernel is launched per pixel and processes Gaussians in back-to-front order,
491
+ similar to the forward rendering pass but accumulating gradients.
492
+ """
493
+ # Get pixel coordinates
494
+ tile_x, tile_y, tid_x, tid_y = wp.tid()
495
+
496
+ # Calculate pixel position
497
+ pix_x = tile_x * TILE_M + tid_x
498
+ pix_y = tile_y * TILE_N + tid_y
499
+
500
+ # Skip if pixel is outside image bounds
501
+ inside = (pix_x < W) and (pix_y < H)
502
+ if not inside:
503
+ return
504
+
505
+ # Convert to float coordinates for calculations
506
+ pixf_x = float(pix_x)
507
+ pixf_y = float(pix_y)
508
+
509
+ # Get tile range (start/end indices in point_list)
510
+ tile_id = tile_y * int(tile_grid[0]) + tile_x
511
+
512
+ range_start = ranges[tile_id][0]
513
+ range_end = ranges[tile_id][1]
514
+
515
+ # Get final transparency value and number of contributors from forward pass
516
+ T_final = final_Ts[pix_y, pix_x]
517
+ last_contributor = n_contrib[pix_y, pix_x]
518
+
519
+ # first_kept = max(range_start, range_end - last_contributor) # = range_end-N
520
+ last_kept = min(range_end, range_start + last_contributor)
521
+
522
+ # Initialize working variables
523
+ T = T_final # Current accumulated transparency
524
+ accum_rec = wp.vec3(0.0, 0.0, 0.0) # Accumulated color
525
+ last_alpha = float(0.0) # Alpha from the last processed Gaussian
526
+ last_color = wp.vec3(0.0, 0.0, 0.0) # Color from the last processed Gaussian
527
+
528
+ # Get gradients
529
+ dL_dpixel = dL_dpixels[pix_y, pix_x]
530
+
531
+ # Gradient of pixel coordinate w.r.t. normalized screen-space coordinates
532
+ ddelx_dx = 0.5 * float(W)
533
+ ddely_dy = 0.5 * float(H)
534
+ for i in range(last_kept - 1, range_start - 1, -1):
535
+ gaussian_id = point_list[i]
536
+ xy = points_xy_image[gaussian_id]
537
+ con_o = conic_opacity[gaussian_id] # (a, b, c, opacity)
538
+ color = colors[gaussian_id]
539
+
540
+ # Compute distance to pixel center
541
+ d_x = xy[0] - pixf_x
542
+ d_y = xy[1] - pixf_y
543
+
544
+ # Compute Gaussian power
545
+ power = -0.5 * (con_o[0] * d_x * d_x + con_o[2] * d_y * d_y) - con_o[1] * d_x * d_y
546
+
547
+ # Skip if power is positive (too far away)
548
+ if power > 0.0:
549
+ continue
550
+
551
+ # Compute Gaussian value and alpha
552
+ G = wp.exp(power)
553
+ alpha = wp.min(0.99, con_o[3] * G)
554
+
555
+ # Skip if alpha is too small
556
+ if alpha < (1.0 / 255.0):
557
+ continue
558
+
559
+ T = T / (1.0 - alpha)
560
+
561
+ # Gradient factor for color contribution
562
+ dchannel_dcolor = alpha * T
563
+
564
+ # Compute gradient w.r.t. alpha
565
+ dL_dalpha = 0.0
566
+
567
+ # Update color accumulation and compute color gradients
568
+ accum_rec = last_alpha * last_color + (1.0 - last_alpha) * accum_rec
569
+ dL_dchannel = dL_dpixel
570
+ last_color = color
571
+
572
+ dL_dalpha = wp.dot(color - accum_rec, dL_dpixel)
573
+ wp.atomic_add(dL_dcolors, gaussian_id, dchannel_dcolor * dL_dchannel)
574
+
575
+ # Scale dL_dalpha by T
576
+ dL_dalpha *= T
577
+ last_alpha = alpha
578
+
579
+ # Account for background color contribution
580
+ bg_dot_dpixel = wp.dot(bg_color, dL_dpixel)
581
+ dL_dalpha += (-T_final / (1.0 - alpha)) * bg_dot_dpixel
582
+
583
+ # Helpful temporary variables
584
+ dL_dG = con_o[3] * dL_dalpha
585
+ gdx = G * d_x
586
+ gdy = G * d_y
587
+ dG_ddelx = -gdx * con_o[0] - gdy * con_o[1]
588
+ dG_ddely = -gdy * con_o[2] - gdx * con_o[1]
589
+
590
+
591
+ # Update gradients w.r.t. 2D mean position
592
+ wp.atomic_add(dL_dmean2D, gaussian_id, wp.vec3(
593
+ dL_dG * dG_ddelx * ddelx_dx,
594
+ dL_dG * dG_ddely * ddely_dy,
595
+ 0.0
596
+ ))
597
+
598
+ # Update gradients w.r.t. 2D conic matrix
599
+ wp.atomic_add(dL_dconic2D, gaussian_id, wp.vec4(
600
+ -0.5 * gdx * d_x * dL_dG,
601
+ -0.5 * gdx * d_y * dL_dG,
602
+ 0.0,
603
+ -0.5 * gdy * d_y * dL_dG
604
+ ))
605
+
606
+ # Update gradients w.r.t. opacity
607
+ wp.atomic_add(dL_dopacity, gaussian_id, G * dL_dalpha)
608
+
609
+ @wp.kernel
610
+ def compute_projection_backward_kernel(
611
+ # --- Inputs ---
612
+ num_points: int, # Number of Gaussian points
613
+ means: wp.array(dtype=wp.vec3), # 3D positions (N, 3)
614
+ radii: wp.array(dtype=int), # Radii computed in forward (N,) - used for skipping
615
+ proj_matrix: wp.mat44, # Projection matrix (4, 4)
616
+ dL_dmean2D: wp.array(dtype=wp.vec3), # Grad of loss w.r.t. 2D projected means (N, 2)
617
+
618
+ # --- Outputs (Accumulate) ---
619
+ dL_dmeans: wp.array(dtype=wp.vec3) # Accumulate mean grads here (N, 3)
620
+ ):
621
+ """Compute gradients of 3D means due to projection to 2D.
622
+
623
+ This kernel handles the gradient propagation from 2D projected positions
624
+ back to 3D positions, based on the projection matrix.
625
+ """
626
+ idx = wp.tid()
627
+ if idx >= num_points or radii[idx] <= 0: # Skip if not rendered
628
+ return
629
+
630
+ # Get 3D mean and 2D mean gradient
631
+ mean3D = means[idx]
632
+ dL_dmean2D_val = dL_dmean2D[idx]
633
+
634
+ # Compute homogeneous coordinates
635
+ m_hom = wp.vec4(mean3D[0], mean3D[1], mean3D[2], 1.0)
636
+ m_hom = m_hom * proj_matrix
637
+
638
+ # Division by w (perspective division)
639
+ m_w = 1.0 / (m_hom[3] + 0.0000001)
640
+
641
+ # Compute gradient of loss w.r.t. 3D means due to 2D mean gradients
642
+ # Following the chain rule through the perspective projection
643
+ mul1 = (proj_matrix[0, 0] * mean3D[0] + proj_matrix[1, 0] * mean3D[1] +
644
+ proj_matrix[2, 0] * mean3D[2] + proj_matrix[3, 0]) * m_w * m_w
645
+
646
+ mul2 = (proj_matrix[0, 1] * mean3D[0] + proj_matrix[1, 1] * mean3D[1] +
647
+ proj_matrix[2, 1] * mean3D[2] + proj_matrix[3, 1]) * m_w * m_w
648
+
649
+ dL_dmean = wp.vec3(0.0, 0.0, 0.0)
650
+
651
+ # x component of gradient
652
+ dL_dmean[0] = (proj_matrix[0, 0] * m_w - proj_matrix[0, 3] * mul1) * dL_dmean2D_val[0] + \
653
+ (proj_matrix[0, 1] * m_w - proj_matrix[0, 3] * mul2) * dL_dmean2D_val[1]
654
+
655
+ # y component of gradient
656
+ dL_dmean[1] = (proj_matrix[1, 0] * m_w - proj_matrix[1, 3] * mul1) * dL_dmean2D_val[0] + \
657
+ (proj_matrix[1, 1] * m_w - proj_matrix[1, 3] * mul2) * dL_dmean2D_val[1]
658
+
659
+ # z component of gradient
660
+ dL_dmean[2] = (proj_matrix[2, 0] * m_w - proj_matrix[2, 3] * mul1) * dL_dmean2D_val[0] + \
661
+ (proj_matrix[2, 1] * m_w - proj_matrix[2, 3] * mul2) * dL_dmean2D_val[1]
662
+
663
+
664
+ dL_dmeans[idx] += dL_dmean
665
+
666
+ def backward_preprocess(
667
+ # Camera and model parameters
668
+ num_points: int,
669
+ means: wp.array(dtype=wp.vec3), # 3D means
670
+ means_2d: wp.array(dtype=wp.vec2), # 2D means
671
+ radii: wp.array(dtype=int), # Computed radii
672
+ sh_coeffs: wp.array(dtype=wp.vec3), # SH coefficients
673
+ scales: wp.array(dtype=wp.vec3), # Scale parameters
674
+ rotations: wp.array(dtype=wp.vec4), # Rotation quaternions
675
+ viewmatrix: wp.mat44, # Camera view matrix
676
+ projmatrix: wp.mat44, # Camera projection matrix
677
+ fov_x: float, # Camera horizontal FOV
678
+ fov_y: float, # Camera vertical FOV
679
+ focal_x: float,
680
+ focal_y: float,
681
+
682
+ # Intermediate data from forward
683
+ cov3Ds: wp.array(dtype=wp.mat33), # 3D covariance matrices (or VEC6 depending on packing)
684
+ conic_opacity: wp.array(dtype=wp.vec4), # 2D conics and opacity
685
+ campos: wp.array(dtype=wp.vec3), # View directions (should be campos)
686
+ clamped: wp.array(dtype=wp.uint32), # Clamping states
687
+
688
+ # Incoming gradients from render backward
689
+ dL_dmean2D: wp.array(dtype=wp.vec3), # Grad of loss w.r.t. 2D means
690
+ dL_dconic: wp.array(dtype=wp.vec4), # Grad of loss w.r.t. 2D conics
691
+ dL_dopacity: wp.array(dtype=float), # Grad of loss w.r.t. opacity
692
+ dL_dcolors: wp.array(dtype=wp.vec3), # Grad of loss w.r.t. colors
693
+
694
+ # Output gradient buffers
695
+ dL_dmeans: wp.array(dtype=wp.vec3), # Output grad for 3D means
696
+ dL_dsh: wp.array(dtype=wp.vec3), # Output grad for SH coeffs
697
+ dL_dscales: wp.array(dtype=wp.vec3), # Output grad for scales
698
+ dL_drots: wp.array(dtype=wp.vec4), # Output grad for rotations
699
+
700
+ # Optional parameters
701
+ scale_modifier: float = 1.0,
702
+ sh_degree: int = 3
703
+ ):
704
+ """
705
+ Orchestrates the backward pass for 3D Gaussian Splatting by coordinating several kernel calls.
706
+ """
707
+ # Create buffer for 3D covariance gradients
708
+ dL_dcov3D = wp.zeros(num_points, dtype=VEC6, device=DEVICE)
709
+ # Step 1: Compute gradients for 2D covariance (conic matrix)
710
+ # This also computes gradients w.r.t. 3D means due to conic computation
711
+ wp.launch(
712
+ kernel=compute_cov2d_backward_kernel,
713
+ dim=num_points,
714
+ inputs=[
715
+ num_points, # P
716
+ means, # means3D
717
+ cov3Ds, # cov3Ds
718
+ radii, # radii
719
+ focal_x, # focal_x
720
+ focal_y, # focal_y
721
+ fov_x, # tan_fovx
722
+ fov_y, # tan_fovy
723
+ viewmatrix, # viewmatrix
724
+ dL_dconic, # dL_dconic
725
+ dL_dmeans, # dL_dmean3D (outputs)
726
+ dL_dcov3D # dL_dcov3D (outputs)
727
+ ],
728
+ device=DEVICE
729
+ )
730
+
731
+ dL_dmeans_np = dL_dmeans.numpy()
732
+ # Step 2: Compute gradients for 3D means due to projection
733
+ wp.launch(
734
+ kernel=compute_projection_backward_kernel,
735
+ dim=num_points,
736
+ inputs=[
737
+ num_points,
738
+ means,
739
+ radii,
740
+ projmatrix,
741
+ dL_dmean2D,
742
+ dL_dmeans # Accumulate to final means gradients
743
+ ],
744
+ device=DEVICE
745
+ )
746
+
747
+ # Step 3: Compute gradients for SH coefficients
748
+ wp.launch(
749
+ kernel=sh_backward_kernel,
750
+ dim=num_points,
751
+ inputs=[
752
+ num_points,
753
+ sh_degree,
754
+ means,
755
+ sh_coeffs,
756
+ radii,
757
+ campos,
758
+ clamped,
759
+ dL_dcolors,
760
+ dL_dmeans,
761
+ dL_dsh
762
+ ],
763
+
764
+ device=DEVICE
765
+ )
766
+ dL_dmeans_np = dL_dmeans.numpy()
767
+ # Step 4: Compute gradients for scales and rotations
768
+ wp.launch(
769
+ kernel=compute_cov3d_backward_kernel,
770
+ dim=num_points,
771
+ inputs=[
772
+ num_points,
773
+ scales,
774
+ rotations,
775
+ radii,
776
+ scale_modifier,
777
+ dL_dcov3D,
778
+ dL_dscales, # Output scale gradients
779
+ dL_drots # Output rotation gradients
780
+ ],
781
+ device=DEVICE
782
+ )
783
+
784
+ return dL_dmeans, dL_dsh, dL_dscales, dL_drots
785
+
786
def backward_render(
    ranges,
    point_list,
    width,
    height,
    bg_color,
    tile_grid,
    points_xy_image,
    conic_opacity,
    colors,
    final_Ts,
    n_contrib,
    dL_dpixels,
    dL_dmean2D,
    dL_dconic2D,
    dL_dopacity,
    dL_dcolors,
):
    """
    Orchestrates the backward rendering process by launching the backward kernel.

    Args:
        ranges: Range of point indices for each tile
        point_list: Sorted list of point indices
        width, height: Image dimensions
        bg_color: Background color
        tile_grid: Tile grid dimensions (passed through to the kernel)
        points_xy_image: 2D positions of Gaussians
        conic_opacity: Conic matrices and opacities
        colors: RGB colors
        final_Ts: Final transparency values from forward pass
        n_contrib: Number of contributors per pixel
        dL_dpixels: Gradient of loss w.r.t. output pixels
        dL_dmean2D: Output gradient w.r.t. 2D mean positions
        dL_dconic2D: Output gradient w.r.t. conic matrices
        dL_dopacity: Output gradient w.r.t. opacity
        dL_dcolors: Output gradient w.r.t. colors
    """
    # Ceil-divide so partially covered border tiles are included.
    tile_grid_x = (width + TILE_M - 1) // TILE_M
    tile_grid_y = (height + TILE_N - 1) // TILE_N

    # NOTE: a previous `ranges.numpy()` debug call was removed here; it
    # forced a device->host copy on every backward pass and its result
    # was never used.

    # Launch the backward rendering kernel: one thread per pixel, grouped
    # by (tile_x, tile_y, thread_x_in_tile, thread_y_in_tile).
    wp.launch(
        kernel=wp_render_backward_kernel,
        dim=(tile_grid_x, tile_grid_y, TILE_M, TILE_N),
        inputs=[
            ranges,
            point_list,
            width,
            height,
            bg_color,
            tile_grid,
            points_xy_image,
            conic_opacity,
            colors,
            final_Ts,
            n_contrib,
            dL_dpixels,
            dL_dmean2D,
            dL_dconic2D,
            dL_dopacity,
            dL_dcolors,
        ],
    )
850
+
851
def backward(
    # --- Core parameters ---
    background,
    means3D,
    dL_dpixels,
    # --- Model parameters ---
    opacity=None,
    shs=None,
    scales=None,
    rotations=None,
    scale_modifier=1.0,
    # --- Camera parameters ---
    viewmatrix=None,
    projmatrix=None,
    tan_fovx=0.5,
    tan_fovy=0.5,
    image_height=256,
    image_width=256,
    campos=None,
    # --- Forward output buffers ---
    radii=None,
    means2D=None,
    conic_opacity=None,
    rgb=None,
    clamped=None,
    cov3Ds=None,
    # --- Internal state buffers ---
    geom_buffer=None,
    binning_buffer=None,
    img_buffer=None,
    # --- Algorithm parameters ---
    degree=3,
    debug=False,
):
    """
    Main backward function for 3D Gaussian Splatting.

    This function orchestrates the entire backward pass by calling two main sub-functions:
    1. backward_render: Computes gradients w.r.t. 2D parameters (mean2D, conic, opacity, color)
    2. backward_preprocess: Computes gradients w.r.t. 3D parameters
       (mean3D, cov3D, SH coefficients, scales, rotations)

    Args:
        background: Background color as numpy array, torch tensor, or wp.vec3 (3,)
        means3D: 3D positions as numpy array, torch tensor, or wp.array (N, 3)
        dL_dpixels: Gradient of loss w.r.t. output pixels (H, W, 3)
        opacity: Opacity values (N, 1) or (N,)
        shs: Spherical harmonics coefficients (N, D, 3) or flattened (N*D, 3)
        scales: Scale parameters (N, 3)
        rotations: Rotation quaternions (N, 4)
        scale_modifier: Global scale modifier (float)
        viewmatrix: View matrix (4, 4)
        projmatrix: Projection matrix (4, 4)
        tan_fovx: Tangent of x field of view
        tan_fovy: Tangent of y field of view
        image_height: Image height
        image_width: Image width
        campos: Camera position (3,)
        radii: Computed radii from forward pass (N,)
        means2D: 2D projected positions from forward pass (N, 2)
        conic_opacity: Conic matrices + opacity from forward pass (N, 4)
        rgb: RGB colors from forward pass (N, 3)
        clamped: Clamping state from forward pass (N, 3)
        cov3Ds: 3D covariance matrices from forward pass (N, 6)
        geom_buffer: Dictionary holding geometric state
        binning_buffer: Dictionary holding binning state
        img_buffer: Dictionary holding image state
        degree: SH degree (0-3)
        debug: Enable debug output

    Returns:
        dict: Dictionary containing gradients for all model parameters:
            - dL_dmean3D: Gradient w.r.t. 3D positions (N, 3)
            - dL_dcolor: Gradient w.r.t. colors (N, 3)
            - dL_dshs: Gradient w.r.t. SH coefficients (N*D, 3)
            - dL_dopacity: Gradient w.r.t. opacity (N,)
            - dL_dscale: Gradient w.r.t. scales (N, 3)
            - dL_drot: Gradient w.r.t. rotations (N, 4)
    """
    # Calculate focal lengths from FoV
    focal_y = image_height / (2.0 * tan_fovy)
    focal_x = image_width / (2.0 * tan_fovx)

    # Convert inputs to warp arrays
    background_warp = background if isinstance(background, wp.vec3) else wp.vec3(background[0], background[1], background[2])
    means3D_warp = to_warp_array(means3D, wp.vec3)
    dL_dpixels_warp = to_warp_array(dL_dpixels, wp.vec3) if not isinstance(dL_dpixels, wp.array) else dL_dpixels

    # Get number of points
    num_points = means3D_warp.shape[0]

    # Convert optional parameters if provided
    # NOTE(review): opacity_warp is currently not forwarded to any kernel in
    # this function; kept for interface stability — confirm whether
    # backward_preprocess should consume it.
    opacity_warp = to_warp_array(opacity, float, flatten=True) if opacity is not None else None

    # SH coefficients need special handling for flattening
    if shs is not None:
        sh_data = shs.reshape(-1, 3) if hasattr(shs, 'reshape') and shs.ndim > 2 else shs
        shs_warp = to_warp_array(sh_data, wp.vec3)
    else:
        shs_warp = None

    # Handle other model parameters
    scales_warp = to_warp_array(scales, wp.vec3) if scales is not None else None

    # Handle rotations differently based on shape (matrices vs quaternions)
    if rotations is not None:
        rot_shape = rotations.shape[-1] if hasattr(rotations, 'shape') else rotations.size(-1)
        if rot_shape == 4:  # Quaternions
            rotations_warp = to_warp_array(rotations, wp.vec4)
        else:  # 3x3 matrices
            rotations_warp = to_warp_array(rotations, wp.mat33)
    else:
        rotations_warp = None

    # Handle camera parameters
    viewmatrix_warp = viewmatrix if isinstance(viewmatrix, wp.mat44) else wp.mat44(viewmatrix.flatten())
    projmatrix_warp = projmatrix if isinstance(projmatrix, wp.mat44) else wp.mat44(projmatrix.flatten())
    campos_warp = campos if isinstance(campos, wp.vec3) else wp.vec3(campos[0], campos[1], campos[2])

    # --- Extract data from buffer dictionaries if provided ---
    # Pre-initialize so a missing buffer yields an explicit None downstream
    # instead of a NameError at the backward_render call site.
    ranges = None
    final_Ts = None
    n_contrib = None
    point_list = None

    if img_buffer is not None:
        ranges = img_buffer.get('ranges')
        final_Ts = img_buffer.get('final_Ts')
        n_contrib = img_buffer.get('n_contrib')

    if binning_buffer is not None:
        point_list = binning_buffer.get('point_list')

    if geom_buffer is not None:
        # Use internal data if not provided directly
        if radii is None:
            radii = geom_buffer.get('radii')
        if means2D is None:
            means2D = geom_buffer.get('means2D')
        if conic_opacity is None:
            conic_opacity = geom_buffer.get('conic_opacity')
        if rgb is None:
            rgb = geom_buffer.get('rgb')
        if clamped is None:
            clamped = geom_buffer.get('clamped_state')

    # Convert forward pass outputs to warp arrays if they're not already
    radii_warp = to_warp_array(radii, int) if radii is not None else None
    means2D_warp = to_warp_array(means2D, wp.vec2) if means2D is not None else None
    conic_opacity_warp = to_warp_array(conic_opacity, wp.vec4) if conic_opacity is not None else None
    rgb_warp = to_warp_array(rgb, wp.vec3) if rgb is not None else None
    clamped_warp = to_warp_array(clamped, wp.uint32) if clamped is not None else None

    # --- Initialize output gradient arrays ---
    dL_dmean2D = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)
    dL_dconic = wp.zeros(num_points, dtype=wp.vec4, device=DEVICE)
    dL_dopacity = wp.zeros(num_points, dtype=float, device=DEVICE)
    dL_dcolor = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)

    dL_dmean3D = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)
    # NOTE(review): backward_preprocess allocates its own internal cov3D
    # gradient buffer, so this array is never written and is returned as
    # zeros — kept only for API compatibility of the returned dict.
    dL_dcov3D = wp.zeros(num_points, dtype=VEC6, device=DEVICE)

    # SH gradients depend on degree: (degree+1)^2 coefficients, capped at 16
    max_sh_coeffs = 16 if degree >= 3 else (degree + 1) * (degree + 1)
    dL_dsh = wp.zeros(num_points * max_sh_coeffs, dtype=wp.vec3, device=DEVICE)

    dL_dscale = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)
    dL_drot = wp.zeros(num_points, dtype=wp.vec4, device=DEVICE)

    # Tile grid (ceil division so border tiles are covered)
    tile_grid = wp.vec3((image_width + TILE_M - 1) // TILE_M,
                        (image_height + TILE_N - 1) // TILE_N,
                        1)

    # --- Step 1: Compute loss gradients w.r.t. 2D parameters ---
    backward_render(
        ranges=ranges,
        point_list=point_list,
        width=image_width,
        height=image_height,
        bg_color=background_warp,
        tile_grid=tile_grid,
        points_xy_image=means2D_warp,
        conic_opacity=conic_opacity_warp,
        colors=rgb_warp,
        final_Ts=final_Ts,
        n_contrib=n_contrib,
        dL_dpixels=dL_dpixels_warp,
        dL_dmean2D=dL_dmean2D,
        dL_dconic2D=dL_dconic,
        dL_dopacity=dL_dopacity,
        dL_dcolors=dL_dcolor,
    )

    # --- Step 2: Compute gradients for 3D parameters ---
    backward_preprocess(
        num_points=num_points,
        means=means3D_warp,
        means_2d=means2D_warp,
        radii=radii_warp,
        sh_coeffs=shs_warp,
        scales=scales_warp,
        rotations=rotations_warp,
        viewmatrix=viewmatrix_warp,
        projmatrix=projmatrix_warp,
        fov_x=tan_fovx,
        fov_y=tan_fovy,
        focal_x=focal_x,
        focal_y=focal_y,
        cov3Ds=cov3Ds,
        conic_opacity=conic_opacity_warp,
        campos=campos_warp,
        clamped=clamped_warp,
        dL_dmean2D=dL_dmean2D,
        dL_dconic=dL_dconic,
        dL_dopacity=dL_dopacity,
        dL_dcolors=dL_dcolor,
        dL_dmeans=dL_dmean3D,
        dL_dsh=dL_dsh,
        dL_dscales=dL_dscale,
        dL_drots=dL_drot,
        sh_degree=degree
    )

    # Return all gradients in a dictionary for easy access
    return {
        'dL_dmean3D': dL_dmean3D,
        'dL_dcolor': dL_dcolor,
        'dL_dshs': dL_dsh,
        'dL_dopacity': dL_dopacity,
        'dL_dscale': dL_dscale,
        'dL_drot': dL_drot,
        # Include 2D gradients for completeness
        'dL_dmean2D': dL_dmean2D,
        'dL_dconic': dL_dconic,
        # Always zero-filled — see NOTE(review) at allocation above.
        'dL_dcov3D': dL_dcov3D
    }
1084
+
gs/config.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Configuration settings and constants for 3D Gaussian Splatting with NeRF datasets.
"""
import warp as wp
import numpy as np
import random

# Fixed RNG seed for reproducible runs (seeds Python's `random` only).
SEED = 42
random.seed(SEED)

# Warp data types and constants (keep capitalized as they are types)
WP_FLOAT16 = wp.float16
WP_FLOAT32 = wp.float32
WP_INT = wp.int32
WP_VEC2 = wp.vec2
WP_VEC2H = wp.vec2h
# 6-component float vector: a symmetric 3x3 covariance stored as its
# upper triangle (xx, xy, xz, yy, yz, zz).
VEC6 = wp.types.vector(length=6, dtype=WP_FLOAT32)
DEVICE = "cuda"  # "cpu"  # Use "cpu" or "cuda"

# Screen-space tile dimensions used by the rasterizer kernels.
TILE_M = wp.constant(16)   # tile width in pixels
TILE_N = wp.constant(16)   # tile height in pixels
TILE_THREADS = wp.constant(256)  # threads per tile = TILE_M * TILE_N
23
+
24
+
25
class GaussianParams:
    """Parameters for 3D Gaussian Splatting.

    All settings are class attributes so they act as a global, mutable
    configuration; use ``update(**kwargs)`` to override values and
    ``get_config_dict()`` to snapshot them.
    """

    # Training parameters
    num_iterations = 3*7000//1  # Default number of training iterations
    num_points = 5000           # Initial number of Gaussian points

    # Simple learning rate scheduler configuration
    use_lr_scheduler = True
    # Learning rate scheduler configuration
    lr_scheduler_config = {
        'lr_pos': 1e-2,          # Initial learning rate for positions
        'lr_scale': 5e-3,        # Initial learning rate for scales
        'lr_rot': 5e-3,          # Initial learning rate for rotations
        'lr_sh': 2e-3,           # Initial learning rate for spherical harmonics
        'lr_opac': 5e-3,         # Initial learning rate for opacities
        'final_lr_factor': 0.01  # Final LR will be 1% of initial LR
    }

    # Optimization parameters
    densification_interval = 100  # Perform densification every N iterations
    pruning_interval = 100        # Perform pruning every N iterations
    opacity_reset_interval = 3000
    save_interval = 300           # Save checkpoint every N iterations
    adam_beta1 = 0.9              # Adam optimizer beta1 parameter
    adam_beta2 = 0.999            # Adam optimizer beta2 parameter
    adam_epsilon = 1e-8           # Adam optimizer epsilon parameter

    densify_grad_threshold = 0.0002
    cull_opacity_threshold = 0.005
    start_prune_iter = 500
    end_prune_iter = 15000
    percent_dense = 0.01
    max_allowed_prune_ratio = 1.0  # no limit on pruning ratio

    # Gaussian parameters
    initial_scale = 0.1   # Initial scale for Gaussian points
    scale_modifier = 1.0  # Scaling factor for Gaussian splats
    sh_degree = 3         # Spherical harmonics degree

    # Scene parameters
    scene_scale = 1.0  # Scale factor for the scene
    background_color = [1.0, 1.0, 1.0]  # White background for NeRF synthetic

    # Loss parameters
    lambda_dssim = 0.0  # Weight for SSIM loss (1.0 means only SSIM, 0.0 means only L1)

    # Depth loss parameters
    depth_l1_weight_init = 0.0   # Initial weight for depth L1 loss
    depth_l1_weight_final = 0.0  # Final weight for depth L1 loss
    depth_l1_delay_steps = 0     # Number of steps to delay depth loss
    depth_l1_delay_mult = 0.0    # Multiplier for delay rate

    near = 0.01  # Default near clipping plane
    far = 100.0  # Default far clipping plane

    @classmethod
    def get_depth_l1_weight(cls, step):
        """Compute the depth L1 loss weight for the current step.

        Args:
            step (int): Current training step

        Returns:
            float: Weight for depth L1 loss
        """
        if step < 0 or (cls.depth_l1_weight_init <= 0.0 and cls.depth_l1_weight_final <= 0.0):
            # Disable depth loss
            return 0.0

        if cls.depth_l1_delay_steps > 0:
            # A kind of reverse cosine decay
            delay_rate = cls.depth_l1_delay_mult + (1 - cls.depth_l1_delay_mult) * np.sin(
                0.5 * np.pi * np.clip(step / cls.depth_l1_delay_steps, 0, 1)
            )
        else:
            delay_rate = 1.0

        # Logarithmic interpolation between initial and final weights.
        # Clamp both weights to a tiny epsilon before taking the log:
        # previously, np.log(0.0) produced -inf (and NaN at t == 1) when
        # exactly one of the two weights was zero.
        eps = 1e-12
        w_init = max(cls.depth_l1_weight_init, eps)
        w_final = max(cls.depth_l1_weight_final, eps)
        t = np.clip(step / cls.num_iterations, 0, 1)
        log_lerp = np.exp(np.log(w_init) * (1 - t) + np.log(w_final) * t)

        return delay_rate * log_lerp

    @classmethod
    def update(cls, **kwargs):
        """Update parameters with new values.

        Raises:
            ValueError: If a keyword does not match an existing attribute.
        """
        for key, value in kwargs.items():
            if hasattr(cls, key):
                setattr(cls, key, value)
            else:
                raise ValueError(f"Unknown parameter: {key}")

    @classmethod
    def get_config_dict(cls):
        """Get parameters as a dictionary."""
        return {
            'num_iterations': cls.num_iterations,
            'num_points': cls.num_points,
            'densification_interval': cls.densification_interval,
            'pruning_interval': cls.pruning_interval,
            'scale_modifier': cls.scale_modifier,
            'sh_degree': cls.sh_degree,
            'background_color': cls.background_color,
            'save_interval': cls.save_interval,
            'adam_beta1': cls.adam_beta1,
            'adam_beta2': cls.adam_beta2,
            'adam_epsilon': cls.adam_epsilon,
            'initial_scale': cls.initial_scale,
            'scene_scale': cls.scene_scale,
            'near': cls.near,
            'far': cls.far,
            'lambda_dssim': cls.lambda_dssim,
            'depth_l1_weight_init': cls.depth_l1_weight_init,
            'depth_l1_weight_final': cls.depth_l1_weight_final,
            'depth_l1_delay_steps': cls.depth_l1_delay_steps,
            'depth_l1_delay_mult': cls.depth_l1_delay_mult,
            'densify_grad_threshold': cls.densify_grad_threshold,
            'cull_opacity_threshold': cls.cull_opacity_threshold,
            'start_prune_iter': cls.start_prune_iter,
            'end_prune_iter': cls.end_prune_iter,
            'use_lr_scheduler': cls.use_lr_scheduler,
            'lr_scheduler_config': cls.lr_scheduler_config,
            'max_allowed_prune_ratio': cls.max_allowed_prune_ratio,
        }
gs/create_training_video.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import glob
5
+ from tqdm import tqdm
6
+
7
def create_training_video(input_pattern, output_path, fps=10):
    """
    Create a video from training iteration images.

    Args:
        input_pattern: Glob pattern matching image files
            (e.g. 'output/steak_is/point_cloud/iteration_*/rendered_view.png'),
            or an already-ordered sequence of image paths.
        output_path: Path to save the output video
        fps: Frames per second for the output video
    """
    import re

    def iteration_number(path):
        # Extract the iteration index with a regex; unlike splitting on '/',
        # this also works with Windows '\\' separators and trailing filenames.
        match = re.search(r"iteration_(\d+)", path)
        return int(match.group(1)) if match else -1

    # Accept either a glob pattern (original behavior) or a pre-built list,
    # as passed by the --reverse path in __main__.
    if isinstance(input_pattern, str):
        image_files = sorted(glob.glob(input_pattern), key=iteration_number)
    else:
        image_files = list(input_pattern)

    if not image_files:
        print(f"No images found matching pattern: {input_pattern}")
        return

    print(f"Found {len(image_files)} image files")

    # Read first image to get dimensions; bail out clearly if unreadable
    # (cv2.imread returns None instead of raising).
    first_img = cv2.imread(image_files[0])
    if first_img is None:
        print(f"Could not read first image: {image_files[0]}")
        return
    h, w, _ = first_img.shape

    # Create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

    try:
        # Add each image to the video
        for img_path in tqdm(image_files, desc="Creating video"):
            img = cv2.imread(img_path)
            if img is None:
                # Skip unreadable frames instead of crashing mid-video.
                continue

            # Optionally add iteration number as text overlay
            iteration = iteration_number(img_path)
            cv2.putText(img, f"Iteration {iteration}", (20, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            video.write(img)
    finally:
        # Always release the writer so a partially written video is flushed.
        video.release()
    print(f"Video created successfully: {output_path}")
48
+
49
# Add a simple UI to select images and set options
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Create a video from training iteration images')
    parser.add_argument('--input', default='output/steak_is/point_cloud/iteration_*/rendered_view.png',
                        help='Pattern to match image files')
    parser.add_argument('--output', default='training_progress.mp4',
                        help='Path to save the output video')
    parser.add_argument('--fps', type=int, default=10,
                        help='Frames per second for the output video')
    parser.add_argument('--reverse', action='store_true',
                        help='Reverse the order of images (show latest first)')

    args = parser.parse_args()

    if args.reverse:
        # Find all matching image files and sort them in reverse order
        # by the numeric iteration index embedded in each path.
        image_files = sorted(glob.glob(args.input),
                             key=lambda x: int(x.split('iteration_')[1].split('/')[0]),
                             reverse=True)
        if image_files:
            # NOTE(review): this passes a *list* of paths, but
            # create_training_video's documented parameter is a glob
            # pattern string — confirm the function accepts sequences.
            create_training_video(image_files, args.output, args.fps)
    else:
        create_training_video(args.input, args.output, args.fps)
gs/dataset_reader.py ADDED
@@ -0,0 +1 @@
 
 
1
+
gs/forward.py ADDED
@@ -0,0 +1,804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import warp as wp
from utils.wp_utils import to_warp_array
from config import *
# Initialize Warp at import time (module has an intentional side effect).
wp.init()
print("Warp devices:", wp.get_devices())
# Define spherical harmonics constants
SH_C0 = 0.28209479177387814  # degree-0 SH basis constant, 1/(2*sqrt(pi))
SH_C1 = 0.4886025119029199   # degree-1 SH basis constant


# NOTE(review): duplicate of the `import warp as wp` at the top of the
# file — harmless but redundant.
import warp as wp
13
+
14
# Define the CUDA code snippets for bit reinterpretation
# (reinterprets the raw bits of a float as an unsigned 32-bit integer,
# without numeric conversion).
float_to_uint32_snippet = """
return reinterpret_cast<uint32_t&>(x);
"""

# Native warp function: the body is supplied by the C++/CUDA snippet
# above, so the Python body is intentionally just `...`.
@wp.func_native(float_to_uint32_snippet)
def float_bits_to_uint32(x: float) -> wp.uint32:
    ...
22
+
23
@wp.func
def ndc2pix(x: float, size: float) -> float:
    # Map a normalized-device coordinate in [-1, 1] to a continuous
    # pixel coordinate in [-0.5, size - 0.5].
    return ((x + 1.0) * size - 1.0) * 0.5
26
+
27
@wp.func
def get_rect(p: wp.vec2, max_radius: float, tile_grid: wp.vec3):
    # Compute the (min, max) tile rectangle touched by a splat centered at
    # pixel `p` with screen-space radius `max_radius`, clamped to the grid.
    # Extract grid dimensions
    grid_size_x = tile_grid[0]
    grid_size_y = tile_grid[1]

    # Inclusive lower-left tile: floor((p - r) / tile_size), clamped to [0, grid).
    rect_min_x = wp.min(wp.int32(grid_size_x), wp.int32(wp.max(wp.int32(0), wp.int32((p[0] - max_radius) / float(TILE_M)))))
    rect_min_y = wp.min(wp.int32(grid_size_y), wp.int32(wp.max(wp.int32(0), wp.int32((p[1] - max_radius) / float(TILE_N)))))

    # Exclusive upper-right tile: ceil((p + r) / tile_size) via the
    # (x + tile - 1) / tile trick, clamped the same way.
    rect_max_x = wp.min(wp.int32(grid_size_x), wp.int32(wp.max(wp.int32(0), wp.int32((p[0] + max_radius + float(TILE_M) - 1.0) / float(TILE_M)))))
    rect_max_y = wp.min(wp.int32(grid_size_y), wp.int32(wp.max(wp.int32(0), wp.int32((p[1] + max_radius + float(TILE_N) - 1.0) / float(TILE_N)))))

    return rect_min_x, rect_min_y, rect_max_x, rect_max_y
41
+
42
+
43
@wp.func
def compute_cov2d(p_orig: wp.vec3, cov3d: VEC6, view_matrix: wp.mat44,
                  tan_fovx: float, tan_fovy: float, width: float, height: float) -> wp.vec3:
    # Project a 3D covariance (upper triangle in `cov3d`) into screen space
    # (EWA splatting). Returns the symmetric 2D covariance as (xx, xy, yy).

    # Position in camera space.
    t = wp.vec4(p_orig[0], p_orig[1], p_orig[2], 1.0) * view_matrix
    # 1.3 factor widens the clamp slightly beyond the view frustum.
    limx = 1.3 * tan_fovx
    limy = 1.3 * tan_fovy
    # Clamp X/Y to stay inside frustum
    txtz = t[0] / t[2]
    tytz = t[1] / t[2]
    t[0] = min(limx, max(-limx, txtz)) * t[2]
    t[1] = min(limy, max(-limy, tytz)) * t[2]

    focal_x = width / (2.0 * tan_fovx)
    focal_y = height / (2.0 * tan_fovy)
    # compute Jacobian of the perspective projection at t
    # (third row zeroed: only the 2D screen covariance is needed).
    J = wp.mat33(
        focal_x / t[2], 0.0, -(focal_x * t[0]) / (t[2] * t[2]),
        0.0, focal_y / t[2], -(focal_y * t[1]) / (t[2] * t[2]),
        0.0, 0.0, 0.0
    )

    # Rotational part of the view matrix (world -> camera).
    W = wp.mat33(
        view_matrix[0, 0], view_matrix[0, 1], view_matrix[0, 2],
        view_matrix[1, 0], view_matrix[1, 1], view_matrix[1, 2],
        view_matrix[2, 0], view_matrix[2, 1], view_matrix[2, 2]
    )

    T = J * W

    # Expand the packed upper triangle into the full symmetric 3x3 matrix.
    Vrk = wp.mat33(
        cov3d[0], cov3d[1], cov3d[2],
        cov3d[1], cov3d[3], cov3d[4],
        cov3d[2], cov3d[4], cov3d[5]
    )

    # cov = T * Vrk^T * T^T (Vrk is symmetric, so transposing it is a no-op).
    cov = T * wp.transpose(Vrk) * wp.transpose(T)

    return wp.vec3(cov[0, 0], cov[0, 1], cov[1, 1])
82
+
83
@wp.func
def compute_cov3d(scale: wp.vec3, scale_mod: float, rot: wp.vec4) -> VEC6:
    # Build a 3D covariance from per-axis scales and a rotation quaternion,
    # returned as the upper triangle (xx, xy, xz, yy, yz, zz).
    # Create scaling matrix with modifier applied
    S = wp.mat33(
        scale_mod * scale[0], 0.0, 0.0,
        0.0, scale_mod * scale[1], 0.0,
        0.0, 0.0, scale_mod * scale[2]
    )
    # NOTE(review): quaternion components are passed as (rot[0..3]) directly;
    # confirm the stored layout matches wp.quaternion's (x, y, z, w) order.
    R = wp.quat_to_matrix(wp.quaternion(rot[0], rot[1], rot[2], rot[3]))
    M = R * S

    # Compute 3D covariance matrix: Sigma = M * M^T
    sigma = M * wp.transpose(M)

    return VEC6(sigma[0, 0], sigma[0, 1], sigma[0, 2], sigma[1, 1], sigma[1, 2], sigma[2, 2])
98
+
99
@wp.kernel
def wp_preprocess(
    # Per-Gaussian model parameters
    orig_points: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    scale_modifier: float,
    rotations: wp.array(dtype=wp.vec4),

    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),   # flattened SH coeffs, 16 vec3 per Gaussian
    degree: int,                    # active SH degree (0-3)
    clamped: bool,                  # if True, clamp negative RGB to 0

    # Camera parameters
    view_matrix: wp.mat44,
    proj_matrix: wp.mat44,
    cam_pos: wp.vec3,

    W: int,
    H: int,

    tan_fovx: float,
    tan_fovy: float,

    focal_x: float,
    focal_y: float,

    # Per-Gaussian outputs
    radii: wp.array(dtype=int),
    points_xy_image: wp.array(dtype=wp.vec2),
    depths: wp.array(dtype=float),
    cov3Ds: wp.array(dtype=VEC6),
    rgb: wp.array(dtype=wp.vec3),
    conic_opacity: wp.array(dtype=wp.vec4),
    tile_grid: wp.vec3,
    tiles_touched: wp.array(dtype=int),
    clamped_state: wp.array(dtype=wp.vec3),

    # NOTE(review): prefiltered and antialiasing are never referenced in
    # this kernel body — confirm whether they are placeholders.
    prefiltered: bool,
    antialiasing: bool
):
    # One thread per Gaussian.
    # Get thread indices
    i = wp.tid()

    # For each Gaussian
    p_orig = orig_points[i]
    p_view = wp.vec4(p_orig[0], p_orig[1], p_orig[2], 1.0) * view_matrix

    # Near-plane cull: skip Gaussians closer than 0.2 in view depth.
    if p_view[2] < 0.2:
        return

    p_hom = wp.vec4(p_orig[0], p_orig[1], p_orig[2], 1.0) * proj_matrix

    # Perspective divide with a small epsilon to avoid division by zero.
    p_w = 1.0 / (p_hom[3] + 0.0000001)
    p_proj = wp.vec3(p_hom[0] * p_w, p_hom[1] * p_w, p_hom[2] * p_w)

    cov3d = compute_cov3d(scales[i], scale_modifier, rotations[i])

    cov3Ds[i] = cov3d
    # Compute 2D covariance matrix
    cov2d = compute_cov2d(p_orig, cov3d, view_matrix, tan_fovx, tan_fovy, float(W), float(H))

    # Constants
    h_var = 0.3  # screen-space dilation added to the diagonal
    W_float = float(W)
    H_float = float(H)
    C = 3  # RGB channels

    # Add blur/antialiasing factor to covariance
    # (det_cov, the undilated determinant, is computed but unused below —
    # presumably left over from an antialiasing correction; verify.)
    det_cov = cov2d[0] * cov2d[2] - cov2d[1] * cov2d[1]
    cov_with_blur = wp.vec3(cov2d[0] + h_var, cov2d[1], cov2d[2] + h_var)
    det_cov_plus_h_cov = cov_with_blur[0] * cov_with_blur[2] - cov_with_blur[1] * cov_with_blur[1]

    # Invert covariance (EWA algorithm)
    det = det_cov_plus_h_cov
    if det == 0.0:
        return

    # Conic = inverse of the dilated 2D covariance, stored as (a, b, c)
    # for the quadratic form a*dx^2 + 2*b*dx*dy + c*dy^2.
    det_inv = 1.0 / det
    conic = wp.vec3(
        cov_with_blur[2] * det_inv,
        -cov_with_blur[1] * det_inv,
        cov_with_blur[0] * det_inv
    )
    # Compute eigenvalues of covariance matrix to find screen-space extent
    mid = 0.5 * (cov_with_blur[0] + cov_with_blur[2])
    lambda1 = mid + wp.sqrt(wp.max(0.1, mid * mid - det))
    lambda2 = mid - wp.sqrt(wp.max(0.1, mid * mid - det))
    # 3-sigma radius of the largest eigenvalue, rounded up.
    my_radius = wp.ceil(3.0 * wp.sqrt(wp.max(lambda1, lambda2)))
    # Convert to pixel coordinates
    point_image = wp.vec2(ndc2pix(p_proj[0], W_float), ndc2pix(p_proj[1], H_float))

    # Get rectangle of affected tiles
    rect_min_x, rect_min_y, rect_max_x, rect_max_y = get_rect(point_image, my_radius, tile_grid)

    # Skip if rectangle has 0 area
    if (rect_max_x - rect_min_x) * (rect_max_y - rect_min_y) == 0:
        return
    # Compute color from spherical harmonics, evaluated along the
    # normalized view direction from the camera to the Gaussian.
    pos = p_orig
    dir_orig = pos - cam_pos
    dir = wp.normalize(dir_orig)
    x, y, z = dir[0], dir[1], dir[2]

    # Base offset for this Gaussian's SH coefficients
    base_idx = i * 16  # assuming degree 3 (16 coefficients)

    # Start with the DC component (degree 0)
    result = SH_C0 * shs[base_idx]

    # Add higher degree terms if requested
    if degree > 0:
        # Degree 1 terms
        result = result - SH_C1 * y * shs[base_idx + 1] + SH_C1 * z * shs[base_idx + 2] - SH_C1 * x * shs[base_idx + 3]

        if degree > 1:
            # Degree 2 terms
            xx = x*x
            yy = y*y
            zz = z*z
            xy = x*y
            yz = y*z
            xz = x*z

            # Degree 2 terms with hardcoded SH basis constants
            result = result + 1.0925484305920792 * xy * shs[base_idx + 4]
            result = result + (-1.0925484305920792) * yz * shs[base_idx + 5]
            result = result + 0.31539156525252005 * (2.0 * zz - xx - yy) * shs[base_idx + 6]
            result = result + (-1.0925484305920792) * xz * shs[base_idx + 7]
            result = result + 0.5462742152960396 * (xx - yy) * shs[base_idx + 8]

            if degree > 2:
                # Degree 3 terms with hardcoded SH basis constants
                result = result + (-0.5900435899266435) * y * (3.0 * xx - yy) * shs[base_idx + 9]
                result = result + 2.890611442640554 * xy * z * shs[base_idx + 10]
                result = result + (-0.4570457994644658) * y * (4.0 * zz - xx - yy) * shs[base_idx + 11]
                result = result + 0.3731763325901154 * z * (2.0 * zz - 3.0 * xx - 3.0 * yy) * shs[base_idx + 12]
                result = result + (-0.4570457994644658) * x * (4.0 * zz - xx - yy) * shs[base_idx + 13]
                result = result + 1.445305721320277 * z * (xx - yy) * shs[base_idx + 14]
                result = result + (-0.5900435899266435) * x * (xx - 3.0 * yy) * shs[base_idx + 15]

    # Shift SH output from roughly [-0.5, 0.5] to [0, 1].
    result = result + wp.vec3(0.5, 0.5, 0.5)

    # Track which color channels are clamped (using wp.vec3 instead of
    # separate uint32 values): store 1.0 if clamped, 0.0 if not. The
    # backward pass uses this to zero gradients for clamped channels.
    r_clamped = 0.0
    g_clamped = 0.0
    b_clamped = 0.0

    if result[0] < 0.0:
        r_clamped = 1.0
    if result[1] < 0.0:
        g_clamped = 1.0
    if result[2] < 0.0:
        b_clamped = 1.0

    clamped_state[i] = wp.vec3(r_clamped, g_clamped, b_clamped)

    if clamped:
        # RGB colors are clamped to positive values
        result = wp.vec3(
            wp.max(result[0], 0.0),
            wp.max(result[1], 0.0),
            wp.max(result[2], 0.0)
        )

    rgb[i] = result

    # Store computed data
    depths[i] = p_view[2]
    radii[i] = int(my_radius)
    points_xy_image[i] = point_image

    # Pack conic and opacity into single vec4
    conic_opacity[i] = wp.vec4(conic[0], conic[1], conic[2], opacities[i])
    # Store tile information (number of tiles this Gaussian overlaps)
    tiles_touched[i] = (rect_max_y - rect_min_y) * (rect_max_x - rect_min_x)
274
+
275
@wp.kernel
def wp_render_gaussians(
    # Output buffers
    rendered_image: wp.array2d(dtype=wp.vec3),
    depth_image: wp.array2d(dtype=float),

    # Tile data
    ranges: wp.array(dtype=wp.vec2i),
    point_list: wp.array(dtype=int),

    # Image parameters
    W: int,
    H: int,

    # Gaussian data
    points_xy_image: wp.array(dtype=wp.vec2),
    colors: wp.array(dtype=wp.vec3),
    conic_opacity: wp.array(dtype=wp.vec4),
    depths: wp.array(dtype=float),

    # Background color
    background: wp.vec3,

    # Tile grid info
    tile_grid: wp.vec3,

    # Track additional data
    final_Ts: wp.array2d(dtype=float),
    n_contrib: wp.array2d(dtype=int),
):
    """Alpha-blend the depth-sorted Gaussians of each tile, one thread per pixel.

    Launched with dim=(tiles_x, tiles_y, TILE_M, TILE_N): the first two tid
    components select the tile, the last two the pixel within it.
    ``ranges[tile_id]`` is the [start, end) slice of ``point_list`` holding the
    indices of the Gaussians overlapping that tile, already sorted
    front-to-back by the radix sort in the host driver.

    Outputs per pixel: blended color plus background weighted by the residual
    transmittance, accumulated inverse depth, the final transmittance in
    ``final_Ts`` and the index of the last blended Gaussian in ``n_contrib``.
    """
    tile_x, tile_y, tid_x, tid_y = wp.tid()

    # Guard against the launch grid over-covering the image in y.
    # NOTE(review): there is no matching tile_x guard — presumably the launch
    # dim in x always equals tile_grid[0]; confirm against the caller.
    if tile_y >= (H + TILE_N - 1) // TILE_N:
        return

    # Calculate pixel boundaries for this tile
    pix_min_x = tile_x * TILE_M
    pix_min_y = tile_y * TILE_N
    pix_max_x = wp.min(pix_min_x + TILE_M, W)
    pix_max_y = wp.min(pix_min_y + TILE_N, H)

    # Calculate pixel position for this thread
    pix_x = pix_min_x + tid_x
    pix_y = pix_min_y + tid_y

    # Check if this thread processes a valid pixel (edge tiles overhang)
    inside = (pix_x < W) and (pix_y < H)
    if not inside:
        return

    pixf_x = float(pix_x)
    pixf_y = float(pix_y)

    # Get start/end range of IDs to process for this tile
    tile_id = tile_y * int(tile_grid[0]) + tile_x
    range_start = ranges[tile_id][0]
    range_end = ranges[tile_id][1]

    # Initialize blending variables
    T = float(1.0)  # Transmittance (fraction of light still reaching the eye)
    r, g, b = float(0.0), float(0.0), float(0.0)  # Accumulated color
    expected_inv_depth = float(0.0)  # For depth calculation

    # Track the number of contributors to this pixel
    contributor_count = int(0)
    last_contributor = int(0)

    # Iterate over all Gaussians influencing this tile, front to back
    for i in range(range_start, range_end):
        # Get Gaussian ID
        gaussian_id = point_list[i]

        # Get Gaussian data
        xy = points_xy_image[gaussian_id]
        con_o = conic_opacity[gaussian_id]  # (conic_a, conic_b, conic_c, opacity)
        color = colors[gaussian_id]

        # Compute distance to Gaussian center
        d_x = xy[0] - pixf_x
        d_y = xy[1] - pixf_y

        # Increment contributor count for this pixel
        contributor_count += 1

        # Compute Gaussian power (exponent of the 2D Gaussian at this pixel)
        power = -0.5 * (con_o[0] * d_x * d_x + con_o[2] * d_y * d_y) - con_o[1] * d_x * d_y

        # Skip if power is positive (numerically degenerate conic)
        if power > 0.0:
            continue

        # Compute alpha from power and opacity (capped below full opacity)
        alpha = wp.min(0.99, con_o[3] * wp.exp(power))

        # Skip if alpha is too small to matter at 8-bit precision
        if alpha < (1.0 / 255.0):
            continue

        # Test if we're close to fully opaque
        test_T = T * (1.0 - alpha)
        if test_T < 0.0001:
            break  # Early termination if pixel is almost opaque

        # Accumulate color contribution (front-to-back "over" compositing)
        r += color[0] * alpha * T
        g += color[1] * alpha * T
        b += color[2] * alpha * T

        # Accumulate inverse depth with the same blend weights
        expected_inv_depth += (1.0 / depths[gaussian_id]) * alpha * T

        # Update transmittance
        T = test_T

        last_contributor = contributor_count

    # Store final transmittance (T) and contributor count
    final_Ts[pix_y, pix_x] = T
    n_contrib[pix_y, pix_x] = last_contributor

    # Write final color to output buffer (color + residual background)
    rendered_image[pix_y, pix_x] = wp.vec3(
        r + T * background[0],
        g + T * background[1],
        b + T * background[2]
    )

    # Write accumulated inverse depth to output buffer
    depth_image[pix_y, pix_x] = expected_inv_depth
407
+
408
@wp.kernel
def wp_duplicate_with_keys(
    points_xy_image: wp.array(dtype=wp.vec2),
    depths: wp.array(dtype=float),
    point_offsets: wp.array(dtype=int),
    point_list_keys_unsorted: wp.array(dtype=wp.int64),
    point_list_unsorted: wp.array(dtype=int),
    radii: wp.array(dtype=int),
    tile_grid: wp.vec3
):
    """Emit one (key, gaussian_id) pair for every tile each visible Gaussian touches.

    The 64-bit key packs the tile id in the high 32 bits and the depth's
    IEEE-754 bit pattern in the low 32 bits, so one radix sort over the keys
    groups entries by tile and orders them front-to-back within each tile.
    NOTE(review): the bit-pattern trick matches numeric ordering only for
    non-negative depths — confirm culling guarantees depth > 0 upstream.
    """
    tid = wp.tid()

    if tid >= points_xy_image.shape[0]:
        return

    # Gaussians culled during preprocessing have radius <= 0 and emit nothing.
    r = radii[tid]
    if r <= 0:
        return

    # Find the global offset into key/value buffers.
    # point_offsets is the inclusive prefix sum of tiles_touched, so the
    # previous element is this Gaussian's starting write position.
    offset = 0
    if tid > 0:
        offset = point_offsets[tid - 1]

    pos = points_xy_image[tid]
    depth_val = depths[tid]

    # Tile-space bounding rectangle of this Gaussian's screen footprint.
    rect_min_x, rect_min_y, rect_max_x, rect_max_y = get_rect(pos, float(r), tile_grid)

    for y in range(rect_min_y, rect_max_y):
        for x in range(rect_min_x, rect_max_x):
            tile_id = y * int(tile_grid[0]) + x
            # Convert to int64 to avoid overflow during bit shift
            tile_id_64 = wp.int64(tile_id)
            shifted = tile_id_64 << wp.int64(32)
            depth_bits = wp.int64(float_bits_to_uint32(depth_val))
            # Combine tile ID and depth into single key
            key = wp.int64(shifted) | depth_bits

            point_list_keys_unsorted[offset] = key
            point_list_unsorted[offset] = tid
            offset += 1
450
+
451
@wp.kernel
def wp_identify_tile_ranges(
    num_rendered: int,
    point_list_keys: wp.array(dtype=wp.int64),
    ranges: wp.array(dtype=wp.vec2i)  # Each range is (start, end)
):
    """Find each tile's [start, end) slice in the sorted key list.

    Keys carry the tile id in their high 32 bits, so after sorting every
    tile's entries are contiguous; a boundary exists wherever adjacent keys
    have different tile ids. ``ranges`` must arrive zero-initialized: tiles
    with no entries keep (0, 0), which the renderer treats as empty.
    """
    idx = wp.tid()

    if idx >= num_rendered:
        return

    key = point_list_keys[idx]
    # Recover the tile id from the high 32 bits of the key.
    curr_tile = int(key >> wp.int64(32))

    # Set start of range if first element or tile changed
    if idx == 0:
        ranges[curr_tile][0] = 0
    else:
        prev_key = point_list_keys[idx - 1]
        prev_tile = int(prev_key >> wp.int64(32))
        if curr_tile != prev_tile:
            # Close the previous tile's range and open this one's.
            ranges[prev_tile][1] = idx
            ranges[curr_tile][0] = idx

    # Set end of range if last element
    if idx == num_rendered - 1:
        ranges[curr_tile][1] = num_rendered
478
+
479
+
480
@wp.kernel
def wp_prefix_sum(input_array: wp.array(dtype=int),
                  output_array: wp.array(dtype=int)):
    """Serial inclusive prefix sum: output[i] = input[0] + ... + input[i].

    This is O(n) sequential work and MUST be launched with dim=1.
    NOTE(review): only the first assignment is guarded by ``tid == 0`` — the
    loop below runs on every thread, so a launch with dim > 1 would race.
    """
    tid = wp.tid()

    if tid == 0:
        output_array[0] = input_array[0]

    # Perform prefix sum; each element depends on the one just written,
    # hence the single-thread requirement.
    for i in range(1, input_array.shape[0]):
        output_array[i] = output_array[i-1] + input_array[i]
491
+
492
+
493
@wp.kernel
def wp_copy_int64(src: wp.array(dtype=wp.int64), dst: wp.array(dtype=wp.int64), count: int):
    """Copy the first `count` int64 elements of `src` into `dst`, one per thread."""
    tid = wp.tid()
    if tid >= count:
        return
    dst[tid] = src[tid]
498
+
499
@wp.kernel
def wp_copy_int(src: wp.array(dtype=int), dst: wp.array(dtype=int), count: int):
    """Copy the first `count` int elements of `src` into `dst`, one per thread."""
    tid = wp.tid()
    if tid >= count:
        return
    dst[tid] = src[tid]
504
+
505
@wp.kernel
def track_pixel_stats(
    rendered_image: wp.array2d(dtype=wp.vec3),
    depth_image: wp.array2d(dtype=float),
    background: wp.vec3,
    final_Ts: wp.array2d(dtype=float),
    n_contrib: wp.array2d(dtype=int),
    W: int,
    H: int
):
    """Fallback pass: populate final_Ts / n_contrib for pixels the renderer left at 0.

    This is a heuristic, not an exact computation: a pixel "has content" if it
    differs from the background by more than 0.01 in any channel, and its
    transmittance is then approximated from that difference. Pixels already
    written by wp_render_gaussians (final_Ts != 0) are left untouched.
    """
    x, y = wp.tid()

    if x >= W or y >= H:
        return

    # Get the rendered pixel
    pixel = rendered_image[y, x]

    # Calculate approximate alpha transparency by checking for background contribution
    # If the pixel has no contribution from background, final_T should be close to 0
    # If it's mostly background, final_T will be close to 1
    diff_r = abs(pixel[0] - background[0])
    diff_g = abs(pixel[1] - background[1])
    diff_b = abs(pixel[2] - background[2])
    has_content = (diff_r > 0.01) or (diff_g > 0.01) or (diff_b > 0.01)

    if has_content:
        # Approximate final_T - in a real scenario this should already be tracked during rendering
        # We're just making sure it's populated for existing renderings
        if final_Ts[y, x] == 0.0:
            # If final_Ts hasn't been set during rendering, approximate it
            # Higher difference from background means lower T
            max_diff = max(diff_r, max(diff_g, diff_b))
            final_Ts[y, x] = 1.0 - min(0.99, max_diff)

        # Set n_contrib to 1 if we know the pixel has content but no contributor count
        if n_contrib[y, x] == 0:
            n_contrib[y, x] = 1
544
+
545
def render_gaussians(
    background,
    means3D,
    colors=None,
    opacity=None,
    scales=None,
    rotations=None,
    scale_modifier=1.0,
    viewmatrix=None,
    projmatrix=None,
    tan_fovx=0.5,
    tan_fovy=0.5,
    image_height=256,
    image_width=256,
    sh=None,
    degree=3,
    campos=None,
    prefiltered=False,
    antialiasing=False,
    clamped=True,
    debug=False,
):
    """Render 3D Gaussians using Warp.

    Pipeline (mirrors the tile-based 3DGS rasterizer):
      1. wp_preprocess: project Gaussians, compute radii/conics/colors.
      2. wp_prefix_sum over tiles_touched -> per-Gaussian output offsets.
      3. wp_duplicate_with_keys: one (tile|depth) key per touched tile.
      4. radix sort of the keys -> per-tile, front-to-back ordering.
      5. wp_identify_tile_ranges: per-tile [start, end) slices.
      6. wp_render_gaussians: per-pixel alpha blending.
      7. track_pixel_stats: fallback population of final_Ts / n_contrib.

    Args:
        background: Background color tensor of shape (3,)
        means3D: 3D positions tensor of shape (N, 3)
        colors: Optional RGB colors tensor of shape (N, 3); only consulted for
            the debug summary here — shading comes from the SH coefficients
            inside wp_preprocess
        opacity: Opacity values tensor of shape (N, 1) or (N,)
        scales: Scales tensor of shape (N, 3)
        rotations: Rotation quaternions of shape (N, 4)
        scale_modifier: Global scale modifier (float)
        viewmatrix: View matrix tensor of shape (4, 4)
        projmatrix: Projection matrix tensor of shape (4, 4)
        tan_fovx: Tangent of the horizontal field of view
        tan_fovy: Tangent of the vertical field of view
        image_height: Height of the output image
        image_width: Width of the output image
        sh: Spherical harmonics coefficients tensor of shape (N, D, 3)
        degree: Degree of spherical harmonics
        campos: Camera position tensor of shape (3,)
        prefiltered: Whether input Gaussians are prefiltered
        antialiasing: Whether to apply antialiasing
        clamped: Whether to clamp the colors
        debug: Whether to print debug information

    Returns:
        Tuple of (rendered_image, depth_image, intermediate_buffers)

    Note:
        Requires torch for the prefix-sum readback (``wp.to_torch``).
    """
    rendered_image = wp.zeros((image_height, image_width), dtype=wp.vec3, device=DEVICE)
    depth_image = wp.zeros((image_height, image_width), dtype=float, device=DEVICE)

    # Create additional buffers for tracking transparency and contributors
    final_Ts = wp.zeros((image_height, image_width), dtype=float, device=DEVICE)
    n_contrib = wp.zeros((image_height, image_width), dtype=int, device=DEVICE)

    background_warp = wp.vec3(background[0], background[1], background[2])
    points_warp = to_warp_array(means3D, wp.vec3)
    # SH coefficients should be shape (n, 16, 3)
    # Convert to a flattened array but preserve the structure
    sh_data = sh.reshape(-1, 3) if hasattr(sh, 'reshape') else sh
    shs_warp = to_warp_array(sh_data, wp.vec3)

    # Handle other parameters
    opacities_warp = to_warp_array(opacity, float, flatten=True)
    scales_warp = to_warp_array(scales, wp.vec3)
    rotations_warp = to_warp_array(rotations, wp.vec4)

    # Handle camera parameters (accept either wp types or flattenable arrays)
    view_matrix_warp = wp.mat44(viewmatrix.flatten()) if not isinstance(viewmatrix, wp.mat44) else viewmatrix
    proj_matrix_warp = wp.mat44(projmatrix.flatten()) if not isinstance(projmatrix, wp.mat44) else projmatrix
    campos_warp = wp.vec3(campos[0], campos[1], campos[2]) if not isinstance(campos, wp.vec3) else campos

    # Calculate tile grid for spatial optimization (ceil-divide image by tile size)
    tile_grid = wp.vec3((image_width + TILE_M - 1) // TILE_M,
                        (image_height + TILE_N - 1) // TILE_N,
                        1)

    # Preallocate buffers for preprocessed data (one entry per Gaussian)
    num_points = points_warp.shape[0]
    radii = wp.zeros(num_points, dtype=int, device=DEVICE)
    points_xy_image = wp.zeros(num_points, dtype=wp.vec2, device=DEVICE)
    depths = wp.zeros(num_points, dtype=float, device=DEVICE)
    cov3Ds = wp.zeros(num_points, dtype=VEC6, device=DEVICE)
    rgb = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)
    conic_opacity = wp.zeros(num_points, dtype=wp.vec4, device=DEVICE)
    tiles_touched = wp.zeros(num_points, dtype=int, device=DEVICE)

    # Add clamped_state buffer to track which color channels are clamped
    clamped_state = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)

    if debug:
        print(f"\nWARP RENDERING: {image_width}x{image_height} image, {num_points} gaussians")
        print(f"Colors: {'from SH' if colors is None else 'provided'}, SH degree: {degree}")
        print(f"Antialiasing: {antialiasing}, Prefiltered: {prefiltered}")

    # Launch preprocessing kernel: projection, covariance, SH shading, tile counts
    wp.launch(
        kernel=wp_preprocess,
        dim=(num_points,),
        inputs=[
            points_warp,  # orig_points
            scales_warp,  # scales
            scale_modifier,  # scale_modifier
            rotations_warp,  # rotations_quat
            opacities_warp,  # opacities
            shs_warp,  # shs
            degree,
            clamped,  # clamped
            view_matrix_warp,  # view_matrix
            proj_matrix_warp,  # proj_matrix
            campos_warp,  # cam_pos
            image_width,  # W
            image_height,  # H
            tan_fovx,  # tan_fovx
            tan_fovy,  # tan_fovy
            image_width / (2.0 * tan_fovx),  # focal_x
            image_height / (2.0 * tan_fovy),  # focal_y
            radii,  # radii
            points_xy_image,  # points_xy_image
            depths,  # depths
            cov3Ds,  # cov3Ds
            rgb,  # rgb
            conic_opacity,  # conic_opacity
            tile_grid,  # tile_grid
            tiles_touched,  # tiles_touched
            clamped_state,  # clamped_state - now using wp.vec3
            prefiltered,  # prefiltered
            antialiasing  # antialiasing
        ],
    )
    # Inclusive prefix sum of per-Gaussian tile counts gives each Gaussian's
    # write offset for key duplication (wp_prefix_sum is serial, hence dim=1).
    point_offsets = wp.zeros(num_points, dtype=int, device=DEVICE)
    wp.launch(
        kernel=wp_prefix_sum,
        dim=1,
        inputs=[
            tiles_touched,
            point_offsets
        ]
    )
    num_rendered = int(wp.to_torch(point_offsets)[-1].item())  # total number of duplicated entries
    if num_rendered > (1 << 30):
        # radix sort needs 2x memory
        raise ValueError("Number of rendered points exceeds the maximum supported by Warp.")

    point_list_keys_unsorted = wp.zeros(num_rendered, dtype=wp.int64, device=DEVICE)
    point_list_unsorted = wp.zeros(num_rendered, dtype=int, device=DEVICE)
    point_list_keys = wp.zeros(num_rendered, dtype=wp.int64, device=DEVICE)
    point_list = wp.zeros(num_rendered, dtype=int, device=DEVICE)
    wp.launch(
        kernel=wp_duplicate_with_keys,
        dim=num_points,
        inputs=[
            points_xy_image,
            depths,
            point_offsets,
            point_list_keys_unsorted,
            point_list_unsorted,
            radii,
            tile_grid
        ]
    )
    # wp.utils.radix_sort_pairs requires arrays with 2x capacity as scratch
    # space, so copy into padded buffers before sorting.
    point_list_keys_unsorted_padded = wp.zeros(num_rendered * 2, dtype=wp.int64, device=DEVICE)
    point_list_unsorted_padded = wp.zeros(num_rendered * 2, dtype=int, device=DEVICE)

    # Copy data to padded arrays
    wp.copy(point_list_keys_unsorted_padded, point_list_keys_unsorted)
    wp.copy(point_list_unsorted_padded, point_list_unsorted)
    wp.utils.radix_sort_pairs(
        point_list_keys_unsorted_padded,  # keys to sort
        point_list_unsorted_padded,  # values to sort along with keys
        num_rendered  # number of elements to sort
    )

    # Copy the sorted first halves back into tight buffers.
    wp.launch(
        kernel=wp_copy_int64,
        dim=num_rendered,
        inputs=[
            point_list_keys_unsorted_padded,
            point_list_keys,
            num_rendered
        ]
    )

    wp.launch(
        kernel=wp_copy_int,
        dim=num_rendered,
        inputs=[
            point_list_unsorted_padded,
            point_list,
            num_rendered
        ]
    )

    # Per-tile [start, end) ranges into the sorted point list; must be
    # zero-initialized so empty tiles read as (0, 0).
    tile_count = int(tile_grid[0] * tile_grid[1])
    ranges = wp.zeros(tile_count, dtype=wp.vec2i, device=DEVICE)  # each is (start, end)

    if num_rendered > 0:
        wp.launch(
            kernel=wp_identify_tile_ranges,
            dim=num_rendered,
            inputs=[
                num_rendered,
                point_list_keys,
                ranges
            ]
        )

    # Main rasterization: one thread per pixel, grouped by tile.
    wp.launch(
        kernel=wp_render_gaussians,
        dim=(int(tile_grid[0]), int(tile_grid[1]), TILE_M, TILE_N),
        inputs=[
            rendered_image,  # Output color image
            depth_image,  # Output depth image
            ranges,  # Tile ranges
            point_list,  # Sorted point indices
            image_width,  # Image width
            image_height,  # Image height
            points_xy_image,  # 2D points
            rgb,  # Precomputed colors
            conic_opacity,  # Conic matrices and opacities
            depths,  # Depth values
            background_warp,  # Background color
            tile_grid,  # Tile grid configuration
            final_Ts,  # Final transparency values
            n_contrib,  # Number of contributors per pixel
        ]
    )

    # Launch the pixel stats tracking kernel as a fallback
    # to make sure final_Ts and n_contrib are populated
    # This is especially important for existing rendered pixels
    wp.launch(
        kernel=track_pixel_stats,
        dim=(image_width, image_height),
        inputs=[
            rendered_image,
            depth_image,
            background_warp,
            final_Ts,
            n_contrib,
            image_width,
            image_height
        ]
    )

    return rendered_image, depth_image, {
        "radii": radii,
        "point_offsets": point_offsets,
        "points_xy_image": points_xy_image,
        "depths": depths,
        "colors": rgb,
        "cov3Ds": cov3Ds,
        "conic_opacity": conic_opacity,
        "point_list": point_list,
        "ranges": ranges,
        "final_Ts": final_Ts,  # Add final_Ts to intermediate buffers
        "n_contrib": n_contrib,  # Add contributor count to intermediate buffers
        "clamped_state": clamped_state  # Add clamped state to intermediate buffers
    }
gs/lib64 ADDED
@@ -0,0 +1 @@
 
 
1
+ lib
gs/loss.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warp as wp
2
+ import numpy as np
3
+ from config import DEVICE
4
+ from utils.wp_utils import wp_vec3_mul_element
5
+
6
+ # Constants for SSIM calculation
7
+ C1 = 0.01 ** 2
8
+ C2 = 0.03 ** 2
9
+ WINDOW_SIZE = 11
10
+
11
@wp.kernel
def l1_loss_kernel(
    rendered: wp.array2d(dtype=wp.vec3),
    target: wp.array2d(dtype=wp.vec3),
    loss_buffer: wp.array(dtype=float),
    width: int,
    height: int
):
    """Accumulate the summed per-channel |rendered - target| into loss_buffer[0].

    Launched with dim=(width, height); images are indexed [row, col] = [j, i].
    The caller normalizes the accumulated sum (see l1_loss).
    """
    i, j = wp.tid()
    if i >= width or j >= height:
        return

    # Compute L1 difference for each pixel component
    rendered_pixel = rendered[j, i]
    target_pixel = target[j, i]
    diff = wp.abs(rendered_pixel - target_pixel)
    l1_diff = diff[0] + diff[1] + diff[2]

    # Atomic add into the single global accumulator cell
    wp.atomic_add(loss_buffer, 0, l1_diff)
31
+
32
+
33
@wp.kernel
def gaussian_kernel(
    kernel: wp.array(dtype=float),
    sigma: float,
    kernel_size: int
):
    """Fill `kernel` with unnormalized 1D Gaussian weights.

    kernel[i] = exp(-(i - center)^2 / (2 * sigma^2)) with
    center = kernel_size // 2, so the peak value 1.0 sits at the center
    index. No normalization is applied here: ssim_kernel divides by the
    accumulated weight sum instead.
    """
    i = wp.tid()
    if i >= kernel_size:
        return

    center = kernel_size // 2
    x = i - center
    kernel[i] = wp.exp(-1.0 * float(x * x) / (2.0 * sigma * sigma))
46
+
47
@wp.kernel
def ssim_kernel(
    rendered: wp.array2d(dtype=wp.vec3),
    target: wp.array2d(dtype=wp.vec3),
    gaussian_weights: wp.array(dtype=float),
    ssim_buffer: wp.array(dtype=float),
    width: int,
    height: int,
    window_size: int
):
    """Accumulate the per-pixel SSIM (averaged over RGB) into ssim_buffer[0].

    For each pixel, Gaussian-weighted means, variances and covariance are
    computed over a window of `window_size`, then the standard SSIM formula
    is applied per channel. The caller divides the accumulated sum by the
    pixel count to obtain the mean SSIM.

    `gaussian_weights` is the 1D window produced by `gaussian_kernel`: its
    peak is at index window_size // 2 and it is symmetric, so the weight for
    a pixel at distance d from the center is gaussian_weights[center + d].
    The 2D weight is the separable product of the two 1D lookups.
    """
    i, j = wp.tid()
    if i >= width or j >= height:
        return

    # Constants for numerical stability (standard SSIM C1/C2 with L = 1)
    c1 = 0.01 * 0.01
    c2 = 0.03 * 0.03

    # We'll compute SSIM in a local window around each pixel
    half_window = window_size // 2

    # Initialize accumulators
    mu1 = wp.vec3(0.0, 0.0, 0.0)
    mu2 = wp.vec3(0.0, 0.0, 0.0)
    sigma1 = wp.vec3(0.0, 0.0, 0.0)
    sigma2 = wp.vec3(0.0, 0.0, 0.0)
    sigma12 = wp.vec3(0.0, 0.0, 0.0)
    weight_sum = float(0.0)

    # Calculate weighted means and variances over the (image-clamped) window
    for y in range(max(0, j - half_window), min(height, j + half_window + 1)):
        for x in range(max(0, i - half_window), min(width, i + half_window + 1)):
            # Distance from the window center along each axis
            wy = abs(y - j)
            wx = abs(x - i)
            if wx <= half_window and wy <= half_window:
                # BUGFIX: the weight table peaks at index half_window, so a
                # distance d maps to index half_window + d (symmetric table).
                # Indexing by the raw distance inverted the window, weighting
                # the periphery most and the center pixel least.
                w = gaussian_weights[half_window + wx] * gaussian_weights[half_window + wy]

                # Get pixels
                p1 = rendered[y, x]
                p2 = target[y, x]

                # Accumulate weighted values
                mu1 += p1 * w
                mu2 += p2 * w
                sigma1 += wp_vec3_mul_element(p1, p1) * w
                sigma2 += wp_vec3_mul_element(p2, p2) * w
                sigma12 += wp_vec3_mul_element(p1, p2) * w
                weight_sum += w

    # Normalize by weights (window may be clipped at image borders)
    if weight_sum > 0.0:
        mu1 /= weight_sum
        mu2 /= weight_sum
        sigma1 /= weight_sum
        sigma2 /= weight_sum
        sigma12 /= weight_sum

    # Calculate variance and covariance: E[X^2] - E[X]^2, E[XY] - E[X]E[Y]
    sigma1 = sigma1 - wp_vec3_mul_element(mu1, mu1)
    sigma2 = sigma2 - wp_vec3_mul_element(mu2, mu2)
    sigma12 = sigma12 - wp_vec3_mul_element(mu1, mu2)

    # Calculate SSIM for each channel
    ssim_r = ((2.0 * mu1[0] * mu2[0] + c1) * (2.0 * sigma12[0] + c2)) / ((mu1[0] * mu1[0] + mu2[0] * mu2[0] + c1) * (sigma1[0] + sigma2[0] + c2))
    ssim_g = ((2.0 * mu1[1] * mu2[1] + c1) * (2.0 * sigma12[1] + c2)) / ((mu1[1] * mu1[1] + mu2[1] * mu2[1] + c1) * (sigma1[1] + sigma2[1] + c2))
    ssim_b = ((2.0 * mu1[2] * mu2[2] + c1) * (2.0 * sigma12[2] + c2)) / ((mu1[2] * mu1[2] + mu2[2] * mu2[2] + c1) * (sigma1[2] + sigma2[2] + c2))

    # Average SSIM across channels
    ssim_val = (ssim_r + ssim_g + ssim_b) / 3.0

    # Atomic add to SSIM buffer
    wp.atomic_add(ssim_buffer, 0, ssim_val)
120
+
121
@wp.kernel
def backprop_l1_pixel_gradients(
    rendered: wp.array2d(dtype=wp.vec3),
    target: wp.array2d(dtype=wp.vec3),
    pixel_grad: wp.array2d(dtype=wp.vec3),
    width: int,
    height: int,
    l1_weight: float
):
    """Write the per-pixel gradient of the weighted L1 loss into pixel_grad.

    d|r - t|/dr = sign(r - t) per channel, scaled by l1_weight (the caller
    folds the loss normalization into this weight). Overwrites pixel_grad
    rather than accumulating into it.
    """
    i, j = wp.tid()
    if i >= width or j >= height:
        return

    # Compute gradient (sign function for L1 loss)
    rendered_pixel = rendered[j, i]
    target_pixel = target[j, i]

    # Sign function for L1 gradient
    l1_grad = wp.vec3(
        l1_weight * wp.sign(rendered_pixel[0] - target_pixel[0]),
        l1_weight * wp.sign(rendered_pixel[1] - target_pixel[1]),
        l1_weight * wp.sign(rendered_pixel[2] - target_pixel[2])
    )

    # Store L1 gradients
    pixel_grad[j, i] = l1_grad
147
+
148
def l1_loss(rendered, target):
    """Compute the mean L1 loss between rendered and target images.

    Args:
        rendered: (H, W) wp.array2d of wp.vec3, or host data convertible
            to one.
        target: Same shape/type as `rendered`.

    Returns:
        float: sum of per-channel absolute differences, normalized by
        width * height * 3 (mean over pixels and channels).
    """
    height, width = rendered.shape[0], rendered.shape[1]

    # Upload to device only when the caller passed host data.
    d_rendered = rendered if isinstance(rendered, wp.array) else wp.array(rendered, dtype=wp.vec3, device=DEVICE)
    d_target = target if isinstance(target, wp.array) else wp.array(target, dtype=wp.vec3, device=DEVICE)

    # Single-cell accumulator; the kernel atomically adds per-pixel sums.
    loss_buffer = wp.zeros(1, dtype=float, device=DEVICE)

    wp.launch(
        kernel=l1_loss_kernel,
        dim=(width, height),
        inputs=[d_rendered, d_target, loss_buffer, width, height]
    )

    # Single device->host readback, normalized by pixel and channel count.
    # (The previous version kept an unused `np_loss_buffer` local that forced
    # a second redundant .numpy() transfer.)
    return float(loss_buffer.numpy()[0]) / (width * height * 3)
177
+
178
def ssim(rendered, target):
    """Compute the mean SSIM between rendered and target images.

    Args:
        rendered: (H, W) wp.array2d of wp.vec3, or host data convertible
            to one.
        target: Same shape/type as `rendered`.

    Returns:
        float: per-pixel SSIM (averaged over RGB channels) averaged over
        all pixels.
    """
    height, width = rendered.shape[0], rendered.shape[1]

    # Upload to device only when the caller passed host data.
    d_rendered = rendered if isinstance(rendered, wp.array) else wp.array(rendered, dtype=wp.vec3, device=DEVICE)
    d_target = target if isinstance(target, wp.array) else wp.array(target, dtype=wp.vec3, device=DEVICE)

    # Precompute the 1D Gaussian window (sigma = 1.5, the standard SSIM setup)
    kernel_size = WINDOW_SIZE
    gaussian_weights = wp.zeros(kernel_size, dtype=float, device=DEVICE)
    wp.launch(
        gaussian_kernel,
        dim=kernel_size,
        inputs=[gaussian_weights, 1.5, kernel_size]
    )

    # Single-cell accumulator; the kernel atomically adds per-pixel SSIM.
    # (Removed the unused `pixel_count` device allocation the previous
    # version created but never passed to any kernel.)
    ssim_buffer = wp.zeros(1, dtype=float, device=DEVICE)

    wp.launch(
        ssim_kernel,
        dim=(width, height),
        inputs=[d_rendered, d_target, gaussian_weights, ssim_buffer, width, height, kernel_size]
    )

    # Average over all pixels (single device->host readback).
    return float(ssim_buffer.numpy()[0]) / (width * height)
216
+
217
def compute_image_gradients(rendered, target, lambda_dssim=0.2):
    """Compute per-pixel gradients of the combined L1 + SSIM loss.

    Only the L1 term is implemented so far; the SSIM gradient remains a TODO.

    Args:
        rendered: (H, W) wp.array2d of wp.vec3 (or host data convertible).
        target: Same shape/type as `rendered`.
        lambda_dssim: Weight of the (not yet implemented) SSIM term; the L1
            term is scaled by (1 - lambda_dssim).

    Returns:
        (H, W) wp.array2d of wp.vec3 holding the per-pixel loss gradient.
    """
    height = rendered.shape[0]
    width = rendered.shape[1]

    # Promote host data to device arrays; device arrays pass through as-is.
    d_rendered = rendered if isinstance(rendered, wp.array) else wp.array(rendered, dtype=wp.vec3, device=DEVICE)
    d_target = target if isinstance(target, wp.array) else wp.array(target, dtype=wp.vec3, device=DEVICE)

    # Output gradient image.
    pixel_grad = wp.zeros((height, width), dtype=wp.vec3, device=DEVICE)

    # L1 weight folds in the mean normalization over pixels and channels.
    l1_weight = (1.0 - lambda_dssim) / (height * width * 3.0)
    wp.launch(
        backprop_l1_pixel_gradients,
        dim=(width, height),
        inputs=[d_rendered, d_target, pixel_grad, width, height, l1_weight]
    )

    # TODO: add the SSIM gradient contribution (lambda_dssim term).
    return pixel_grad
245
+
246
+
247
@wp.kernel
def depth_loss_kernel(
    rendered_depth: wp.array2d(dtype=float),
    target_depth: wp.array2d(dtype=float),
    depth_mask: wp.array2d(dtype=float),
    loss_buffer: wp.array(dtype=float),
    width: int,
    height: int
):
    """Accumulate the masked L1 inverse-depth difference into loss_buffer[0].

    `depth_mask` is a per-pixel multiplicative weight (0.0 disables a pixel).
    Per the variable names, both maps hold inverse depths; the L1 difference
    is taken directly on the stored values.
    """
    i, j = wp.tid()
    if i >= width or j >= height:
        return

    # Get depths and mask
    rendered_inv_depth = rendered_depth[j, i]
    target_inv_depth = target_depth[j, i]
    mask = depth_mask[j, i]

    # Compute L1 difference for inverse depths, weighted by the mask
    diff = wp.abs(rendered_inv_depth - target_inv_depth) * mask

    # Atomic add into the single global accumulator cell
    wp.atomic_add(loss_buffer, 0, diff)
270
+
271
def depth_loss(rendered_depth, target_depth, depth_mask):
    """Masked mean L1 loss between rendered and target inverse-depth maps.

    Args:
        rendered_depth: (H, W) float wp.array2d (or host data convertible).
        target_depth: Same shape/type as `rendered_depth`.
        depth_mask: (H, W) float per-pixel weights (0 disables a pixel).

    Returns:
        float: accumulated masked L1 difference normalized by pixel count.
    """
    height = rendered_depth.shape[0]
    width = rendered_depth.shape[1]

    # Promote host data to device arrays; device arrays pass through as-is.
    d_rendered_depth = rendered_depth if isinstance(rendered_depth, wp.array) else wp.array(rendered_depth, dtype=float, device=DEVICE)
    d_target_depth = target_depth if isinstance(target_depth, wp.array) else wp.array(target_depth, dtype=float, device=DEVICE)
    d_depth_mask = depth_mask if isinstance(depth_mask, wp.array) else wp.array(depth_mask, dtype=float, device=DEVICE)

    # Single-cell accumulator filled by atomic adds in the kernel.
    loss_buffer = wp.zeros(1, dtype=float, device=DEVICE)

    wp.launch(
        kernel=depth_loss_kernel,
        dim=(width, height),
        inputs=[d_rendered_depth, d_target_depth, d_depth_mask, loss_buffer, width, height]
    )

    # Normalize the accumulated sum by the pixel count.
    return float(loss_buffer.numpy()[0]) / (width * height)
gs/optimizer.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warp as wp
2
+ from utils.wp_utils import to_warp_array, wp_vec3_mul_element, wp_vec3_add_element, wp_vec3_sqrt, wp_vec3_div_element, wp_vec3_clamp
3
+ from config import *
4
+
5
+ @wp.kernel
6
+ def adam_update(
7
+ # Parameters
8
+ positions: wp.array(dtype=wp.vec3),
9
+ scales: wp.array(dtype=wp.vec3),
10
+ rotations: wp.array(dtype=wp.vec4),
11
+ opacities: wp.array(dtype=float),
12
+ shs: wp.array(dtype=wp.vec3),
13
+
14
+ # Gradients
15
+ pos_grads: wp.array(dtype=wp.vec3),
16
+ scale_grads: wp.array(dtype=wp.vec3),
17
+ rot_grads: wp.array(dtype=wp.vec4),
18
+ opacity_grads: wp.array(dtype=float),
19
+ sh_grads: wp.array(dtype=wp.vec3),
20
+
21
+ # First moments (m)
22
+ m_positions: wp.array(dtype=wp.vec3),
23
+ m_scales: wp.array(dtype=wp.vec3),
24
+ m_rotations: wp.array(dtype=wp.vec4),
25
+ m_opacities: wp.array(dtype=float),
26
+ m_shs: wp.array(dtype=wp.vec3),
27
+
28
+ # Second moments (v)
29
+ v_positions: wp.array(dtype=wp.vec3),
30
+ v_scales: wp.array(dtype=wp.vec3),
31
+ v_rotations: wp.array(dtype=wp.vec4),
32
+ v_opacities: wp.array(dtype=float),
33
+ v_shs: wp.array(dtype=wp.vec3),
34
+
35
+ num_points: int,
36
+ lr_pos: float,
37
+ lr_scale: float,
38
+ lr_rot: float,
39
+ lr_opac: float,
40
+ lr_sh: float,
41
+ beta1: float,
42
+ beta2: float,
43
+ epsilon: float,
44
+ iteration: int
45
+ ):
46
+ i = wp.tid()
47
+ if i >= num_points:
48
+ return
49
+
50
+ # Bias correction terms
51
+ bias_correction1 = 1.0 - wp.pow(beta1, float(iteration + 1))
52
+ bias_correction2 = 1.0 - wp.pow(beta2, float(iteration + 1))
53
+
54
+ # Update positions
55
+ m_positions[i] = beta1 * m_positions[i] + (1.0 - beta1) * pos_grads[i]
56
+ # Use the helper function for element-wise multiplication
57
+ v_positions[i] = beta2 * v_positions[i] + (1.0 - beta2) * wp_vec3_mul_element(pos_grads[i], pos_grads[i])
58
+ # Use distinct names for corrected moments per parameter type
59
+ m_pos_corrected = m_positions[i] / bias_correction1
60
+ v_pos_corrected = v_positions[i] / bias_correction2
61
+ # Use the helper function for element-wise sqrt and division
62
+ denominator_pos = wp_vec3_sqrt(v_pos_corrected) + wp.vec3(epsilon, epsilon, epsilon)
63
+ positions[i] = positions[i] - lr_pos * wp_vec3_div_element(m_pos_corrected, denominator_pos)
64
+
65
+ # Update scales (with some constraints to keep them positive)
66
+ m_scales[i] = beta1 * m_scales[i] + (1.0 - beta1) * scale_grads[i]
67
+ # Use the helper function for element-wise multiplication
68
+ v_scales[i] = beta2 * v_scales[i] + (1.0 - beta2) * wp_vec3_mul_element(scale_grads[i], scale_grads[i])
69
+ # Use distinct names for corrected moments per parameter type
70
+ m_scale_corrected = m_scales[i] / bias_correction1
71
+ v_scale_corrected = v_scales[i] / bias_correction2
72
+ # Use the helper function for element-wise sqrt and division
73
+ denominator_scale = wp_vec3_sqrt(v_scale_corrected) + wp.vec3(epsilon, epsilon, epsilon)
74
+ scale_update = lr_scale * wp_vec3_div_element(m_scale_corrected, denominator_scale)
75
+ scales[i] = wp.vec3(
76
+ wp.max(scales[i][0] - scale_update[0], 0.001),
77
+ wp.max(scales[i][1] - scale_update[1], 0.001),
78
+ wp.max(scales[i][2] - scale_update[2], 0.001)
79
+ )
80
+
81
+ # Update rotations
82
+ m_rotations[i] = beta1 * m_rotations[i] + (1.0 - beta1) * rot_grads[i]
83
+ # Element-wise multiplication for quaternions
84
+ v_rotations[i] = beta2 * v_rotations[i] + (1.0 - beta2) * wp.vec4(
85
+ rot_grads[i][0] * rot_grads[i][0],
86
+ rot_grads[i][1] * rot_grads[i][1],
87
+ rot_grads[i][2] * rot_grads[i][2],
88
+ rot_grads[i][3] * rot_grads[i][3]
89
+ )
90
+ m_rot_corrected = m_rotations[i] / bias_correction1
91
+ v_rot_corrected = v_rotations[i] / bias_correction2
92
+ # Element-wise sqrt and division for quaternions
93
+ denominator_rot = wp.vec4(
94
+ wp.sqrt(v_rot_corrected[0]) + epsilon,
95
+ wp.sqrt(v_rot_corrected[1]) + epsilon,
96
+ wp.sqrt(v_rot_corrected[2]) + epsilon,
97
+ wp.sqrt(v_rot_corrected[3]) + epsilon
98
+ )
99
+ rot_update = wp.vec4(
100
+ lr_rot * m_rot_corrected[0] / denominator_rot[0],
101
+ lr_rot * m_rot_corrected[1] / denominator_rot[1],
102
+ lr_rot * m_rot_corrected[2] / denominator_rot[2],
103
+ lr_rot * m_rot_corrected[3] / denominator_rot[3]
104
+ )
105
+ rotations[i] = rotations[i] - rot_update
106
+
107
+ # Normalize quaternion to ensure it's a valid rotation
108
+ quat_length = wp.sqrt(rotations[i][0]*rotations[i][0] +
109
+ rotations[i][1]*rotations[i][1] +
110
+ rotations[i][2]*rotations[i][2] +
111
+ rotations[i][3]*rotations[i][3])
112
+
113
+ if quat_length > 0.0:
114
+ rotations[i] = wp.vec4(
115
+ rotations[i][0] / quat_length,
116
+ rotations[i][1] / quat_length,
117
+ rotations[i][2] / quat_length,
118
+ rotations[i][3] / quat_length
119
+ )
120
+
121
+ # Update opacity (with clamping to [0,1])
122
+ m_opacities[i] = beta1 * m_opacities[i] + (1.0 - beta1) * opacity_grads[i]
123
+ # Opacity is scalar, direct multiplication is fine
124
+ v_opacities[i] = beta2 * v_opacities[i] + (1.0 - beta2) * (opacity_grads[i] * opacity_grads[i])
125
+ # Use distinct names for corrected moments per parameter type
126
+ m_opacity_corrected = m_opacities[i] / bias_correction1
127
+ v_opacity_corrected = v_opacities[i] / bias_correction2
128
+ # Opacity is scalar, direct wp.sqrt is fine here
129
+ opacity_update = lr_opac * m_opacity_corrected / (wp.sqrt(v_opacity_corrected) + epsilon)
130
+ opacities[i] = wp.max(wp.min(opacities[i] - opacity_update, 1.0), 0.0)
131
+
132
+ # Update SH coefficients
133
+ for j in range(16):
134
+ idx = i * 16 + j
135
+ m_shs[idx] = beta1 * m_shs[idx] + (1.0 - beta1) * sh_grads[idx]
136
+ # Use the helper function for element-wise multiplication
137
+ v_shs[idx] = beta2 * v_shs[idx] + (1.0 - beta2) * wp_vec3_mul_element(sh_grads[idx], sh_grads[idx])
138
+ # Use distinct names for corrected moments per parameter type
139
+ m_sh_corrected = m_shs[idx] / bias_correction1
140
+ v_sh_corrected = v_shs[idx] / bias_correction2
141
+ # Use the helper function for element-wise sqrt and division
142
+ denominator_sh = wp_vec3_sqrt(v_sh_corrected) + wp.vec3(epsilon, epsilon, epsilon)
143
+ shs[idx] = shs[idx] - lr_sh * wp_vec3_div_element(m_sh_corrected, denominator_sh)
144
+
145
+
146
@wp.kernel
def reset_opacities(
    opacities: wp.array(dtype=float),
    max_opacity: float,
    num_points: int
):
    """Reset opacities to prevent oversaturation.

    Writes ``max_opacity`` into every entry of ``opacities``.
    NOTE(review): the parameter is named ``max_opacity`` while the inline
    comment calls it "a small value" — confirm callers pass the intended
    reset constant (the reference 3DGS resets to min(opacity, small_value)).
    """
    i = wp.tid()
    # Guard against launch grids padded beyond the point count.
    if i >= num_points:
        return

    # Reset opacity to a small value
    opacities[i] = max_opacity
159
+
160
@wp.kernel
def reset_densification_stats(
    xyz_gradient_accum: wp.array(dtype=float),
    denom: wp.array(dtype=float),
    max_radii2D: wp.array(dtype=float),
    num_points: int
):
    """Reset densification statistics after parameter count changes.

    Zeroes, per Gaussian: the accumulated position-gradient magnitude,
    its accumulation count, and the largest observed 2D radius.
    """
    i = wp.tid()
    # Guard against launch grids padded beyond the point count.
    if i >= num_points:
        return

    xyz_gradient_accum[i] = 0.0
    denom[i] = 0.0
    max_radii2D[i] = 0.0
175
+
176
+
177
@wp.kernel
def mark_split_candidates(
    grads: wp.array(dtype=float),
    scales: wp.array(dtype=wp.vec3),
    grad_threshold: float,
    scene_extent: float,
    percent_dense: float,
    split_mask: wp.array(dtype=int),
    num_points: int
):
    """Mark large Gaussians with high gradients for splitting.

    Writes 1 into ``split_mask[i]`` when Gaussian i has both a gradient
    magnitude >= ``grad_threshold`` and a maximum axis scale strictly
    greater than ``percent_dense * scene_extent``; writes 0 otherwise.
    Every in-range entry of the mask is always written.
    """
    i = wp.tid()
    if i >= num_points:
        return

    # Check if gradient exceeds threshold
    high_grad = grads[i] >= grad_threshold

    # Check if Gaussian is large (max scale > threshold)
    max_scale = wp.max(wp.max(scales[i][0], scales[i][1]), scales[i][2])
    scale_threshold = percent_dense * scene_extent
    large_gaussian = max_scale > scale_threshold

    # Mark for splitting if both conditions are met
    if (high_grad and large_gaussian):
        split_mask[i] = 1
    else:
        split_mask[i] = 0
205
+
206
@wp.kernel
def mark_clone_candidates(
    grads: wp.array(dtype=float),
    scales: wp.array(dtype=wp.vec3),
    grad_threshold: float,
    scene_extent: float,
    percent_dense: float,
    clone_mask: wp.array(dtype=int),
    num_points: int
):
    """Mark small Gaussians with high gradients for cloning.

    Complement of ``mark_split_candidates``: writes 1 into
    ``clone_mask[i]`` when the gradient magnitude is >= ``grad_threshold``
    AND the maximum axis scale is <= ``percent_dense * scene_extent``
    (small Gaussian); writes 0 otherwise. Together the two kernels
    partition high-gradient Gaussians by size.
    """
    i = wp.tid()
    if i >= num_points:
        return

    # Check if gradient exceeds threshold
    high_grad = grads[i] >= grad_threshold

    # Check if Gaussian is small (max scale <= threshold)
    max_scale = wp.max(wp.max(scales[i][0], scales[i][1]), scales[i][2])
    scale_threshold = percent_dense * scene_extent
    small_gaussian = max_scale <= scale_threshold

    # Mark for cloning if both conditions are met
    if (high_grad and small_gaussian):
        clone_mask[i] = 1
    else:
        clone_mask[i] = 0
234
+
235
@wp.kernel
def split_gaussians(
    split_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),
    N_split: int,
    scale_factor: float,
    offset: int,
    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3)
):
    """Split large Gaussians into multiple smaller ones.

    Thread i first copies Gaussian i unchanged into the output arrays,
    then, when ``split_mask[i]`` is set, writes ``N_split`` shrunken,
    position-jittered copies starting at output index
    ``offset + prefix_sum[i] * N_split``. ``prefix_sum`` is the exclusive
    scan of ``split_mask`` (number of marked Gaussians before i), and
    ``offset`` is the original point count.

    NOTE(review): ``len()`` on a wp.array inside a kernel — confirm the
    Warp version in use supports it (``arr.shape[0]`` is the documented
    form).
    """
    i = wp.tid()

    # Copy original Gaussians first
    if i < len(positions):
        out_positions[i] = positions[i]
        out_scales[i] = scales[i]
        out_rotations[i] = rotations[i]
        out_opacities[i] = opacities[i]

        # Copy SH coefficients (16 vec3 coefficients per Gaussian, flattened)
        for j in range(16):
            out_shs[i * 16 + j] = shs[i * 16 + j]

    # Handle splits
    if i >= len(positions):
        return

    if split_mask[i] == 1:
        # Find where to write new Gaussians
        split_idx = prefix_sum[i]

        # Create N_split new Gaussians
        for j in range(N_split):
            new_idx = offset + split_idx * N_split + j
            if new_idx < len(out_positions):
                # Scale down the original Gaussian
                scaled_scales = wp.vec3(
                    scales[i][0] * scale_factor,
                    scales[i][1] * scale_factor,
                    scales[i][2] * scale_factor
                )

                # Add small random offset for position
                # (seeded by output index -> deterministic per slot)
                random_offset = wp.vec3(
                    ((wp.randf(wp.uint32(new_idx * 3))) * 2.0 - 1.0) * 0.01,
                    ((wp.randf(wp.uint32(new_idx * 3 + 1))) * 2.0 - 1.0) * 0.01,
                    ((wp.randf(wp.uint32(new_idx * 3 + 2))) * 2.0 - 1.0) * 0.01
                )

                out_positions[new_idx] = positions[i] + random_offset
                out_scales[new_idx] = scaled_scales
                out_rotations[new_idx] = rotations[i]
                # NOTE(review): opacity is copied unchanged; the reference
                # 3DGS split rescales opacity — confirm this is intentional.
                out_opacities[new_idx] = opacities[i]

                # Copy SH coefficients
                for k in range(16):
                    out_shs[new_idx * 16 + k] = shs[i * 16 + k]
301
+
302
+
303
@wp.kernel
def clone_gaussians(
    clone_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),  # shape: [N * 16]

    noise_scale: float,
    offset: int,  # where to start writing new points
    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3),
):
    """Duplicate marked Gaussians with a small positional jitter.

    Thread i copies Gaussian i unchanged into slot i of the output arrays;
    when ``clone_mask[i]`` is set it additionally writes a noised copy at
    ``offset + prefix_sum[i]``. ``offset`` is the original point count and
    doubles as the launch bound; ``prefix_sum`` is the exclusive scan of
    ``clone_mask``.
    """
    i = wp.tid()
    # ``offset`` == original number of Gaussians, so this is the bound check.
    if i >= offset:
        return

    # Copy original to out[i]
    out_positions[i] = positions[i]
    out_scales[i] = scales[i]
    out_rotations[i] = rotations[i]
    out_opacities[i] = opacities[i]
    for j in range(16):
        out_shs[i * 16 + j] = shs[i * 16 + j]

    if clone_mask[i] == 1:
        # Destination slot for the clone, after all original Gaussians.
        base_idx = prefix_sum[i] + offset
        pos = positions[i]
        scale = scales[i]
        rot = rotations[i]
        opac = opacities[i]


        # Index-seeded jitter -> deterministic per source Gaussian.
        noise = wp.vec3(
            wp.randf(wp.uint32(i * 3)) * noise_scale,
            wp.randf(wp.uint32(i * 3 + 1)) * noise_scale,
            wp.randf(wp.uint32(i * 3 + 2)) * noise_scale
        )

        out_positions[base_idx] = pos + noise
        out_scales[base_idx] = scale
        out_rotations[base_idx] = rot
        out_opacities[base_idx] = opac

        # The clone shares the source's SH coefficients.
        for j in range(16):
            out_shs[base_idx * 16 + j] = shs[i * 16 + j]
354
+
355
@wp.kernel
def prune_gaussians(
    opacities: wp.array(dtype=float),
    opacity_threshold: float,
    valid_mask: wp.array(dtype=int),
    num_points: int
):
    """Mark Gaussians to keep (1) or prune (0) by opacity.

    A Gaussian survives only when its opacity is strictly greater than
    ``opacity_threshold``. The mask feeds the prefix scan used by
    ``compact_gaussians``.
    """
    i = wp.tid()
    if i >= num_points:
        return
    # Mark Gaussians for keeping or removal
    if opacities[i] > opacity_threshold:
        valid_mask[i] = 1
    else:
        valid_mask[i] = 0
370
+
371
@wp.kernel
def compact_gaussians(
    valid_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),  # shape: [N * 16]

    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3)
):
    """Scatter surviving Gaussians into a dense, compacted layout.

    Thread i is a no-op for pruned Gaussians; otherwise it copies Gaussian
    i to output slot ``prefix_sum[i]`` (exclusive scan of ``valid_mask``,
    i.e. the number of survivors before i).

    Note: there is no explicit bound check — the kernel relies on being
    launched with dim == number of input Gaussians.
    """
    i = wp.tid()
    if valid_mask[i] == 0:
        return

    new_i = prefix_sum[i]

    out_positions[new_i] = positions[i]
    out_scales[new_i] = scales[i]
    out_rotations[new_i] = rotations[i]
    out_opacities[new_i] = opacities[i]

    # SH coefficients move with their Gaussian (16 vec3 per point).
    for j in range(16):
        out_shs[new_i * 16 + j] = shs[i * 16 + j]
gs/render.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import warp as wp
3
+ import matplotlib.pyplot as plt
4
+ import math
5
+ from forward import render_gaussians
6
+ from utils.math_utils import world_to_view, projection_matrix
7
+
8
+ # Initialize Warp
9
+ wp.init()
10
+
11
def setup_example_scene(image_width=1800, image_height=1800, fovx=45.0, fovy=45.0, znear=0.01, zfar=100.0):
    """Setup example scene with camera and Gaussians for testing and debugging.

    Returns:
        Tuple ``(pts, shs, scales, colors, rotations, opacities,
        camera_params)``: three Gaussians on a line at z = -10, hard-coded
        SH coefficients, random scales/rotations, and a dict of camera
        matrices, intrinsics and image dimensions.

    NOTE(review): fovx/fovy default to 45.0, which reads as degrees, but
    they are fed straight into math.tan below, which expects radians —
    confirm the unit convention here and in projection_matrix().
    """
    # Camera setup
    T = np.array([0, 0, 5], dtype=np.float32)
    R = np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]], dtype=np.float32)
    world_to_camera = np.eye(4, dtype=np.float32)
    world_to_camera[:3, :3] = R
    world_to_camera[:3, 3] = T
    # Transposed here: the rest of the pipeline uses the row-vector
    # (points @ matrix) convention.
    world_to_camera = world_to_camera.T

    # Compute matrices
    view_matrix = world_to_view(R=R, t=T)
    proj_matrix = projection_matrix(fovx=fovx, fovy=fovy, znear=znear, zfar=zfar).T
    full_proj_matrix = world_to_camera @ proj_matrix

    # Camera position in world space (translation row of the inverse).
    camera_center = np.linalg.inv(world_to_camera)[3, :3]

    # Compute FOV parameters
    tan_fovx = math.tan(fovx * 0.5)
    tan_fovy = math.tan(fovy * 0.5)

    focal_x = image_width / (2 * tan_fovx)
    focal_y = image_height / (2 * tan_fovy)

    camera_params = {
        'R': R,
        'T': T,
        'camera_center': camera_center,
        'view_matrix': view_matrix,
        'proj_matrix': proj_matrix,
        'world_to_camera': world_to_camera,
        'full_proj_matrix': full_proj_matrix,
        'tan_fovx': tan_fovx,
        'tan_fovy': tan_fovy,
        'focal_x': focal_x,
        'focal_y': focal_y,
        'width': image_width,
        'height': image_height
    }

    # Gaussian setup - 3 points in a line
    pts = np.array([[-5, 0, -10], [0, 2, -10], [5, 0, -10]], dtype=np.float32)
    n = len(pts)

    # Hard-coded SHs for debugging (the same 16 coefficients repeated for
    # each of the n Gaussians, reshaped to (n, 16, 3)).
    shs = np.array([[0.71734341, 0.91905449, 0.49961076],
                    [0.08068483, 0.82132256, 0.01301602],
                    [0.8335743, 0.31798138, 0.19709007],
                    [0.82589597, 0.28206231, 0.790489 ],
                    [0.24008527, 0.21312673, 0.53132892],
                    [0.19493135, 0.37989934, 0.61886235],
                    [0.98106522, 0.28960672, 0.57313965],
                    [0.92623716, 0.46034381, 0.5485369 ],
                    [0.81660616, 0.7801104, 0.27813915],
                    [0.96114063, 0.69872817, 0.68313804],
                    [0.95464185, 0.21984855, 0.92912192],
                    [0.23503135, 0.29786121, 0.24999751],
                    [0.29844887, 0.6327788, 0.05423596],
                    [0.08934335, 0.11851827, 0.04186001],
                    [0.59331831, 0.919777, 0.71364335],
                    [0.83377388, 0.40242542, 0.8792624 ]]*n).reshape(n, 16, 3)



    opacities = np.ones((n, 1), dtype=np.float32)

    # Random anisotropic scales (each axis in [0.2, 1.7))
    scales = (0.2 + 1.5 * np.random.rand(n, 3)).astype(np.float32)

    # Random rotations as unit quaternions
    q = np.random.randn(n, 4).astype(np.float32)
    rotations = q / np.linalg.norm(q, axis=1, keepdims=True)

    colors = np.ones((n, 3), dtype=np.float32)

    return pts, shs, scales, colors, rotations, opacities, camera_params
87
+
88
if __name__ == "__main__":
    # --- Rendering configuration ---------------------------------------
    image_width = 1800
    image_height = 1800
    background = np.array([0.0, 0.0, 0.0], dtype=np.float32)  # Black background
    scale_modifier = 1.0
    sh_degree = 3
    prefiltered = False
    antialiasing = False
    clamped = True

    # --- Build the synthetic debug scene -------------------------------
    pts, shs, scales, colors, rotations, opacities, camera_params = setup_example_scene(
        image_width=image_width,
        image_height=image_height,
    )
    n = len(pts)
    print(f"Created example scene with {n} Gaussians")

    # --- Rasterize the Gaussians for the configured camera -------------
    rendered_image, depth_image, _ = render_gaussians(
        background=background,
        means3D=pts,
        colors=colors,
        opacity=opacities,
        scales=scales,
        rotations=rotations,
        scale_modifier=scale_modifier,
        viewmatrix=camera_params['view_matrix'],
        projmatrix=camera_params['full_proj_matrix'],
        tan_fovx=camera_params['tan_fovx'],
        tan_fovy=camera_params['tan_fovy'],
        image_height=image_height,
        image_width=image_width,
        sh=shs,
        degree=sh_degree,
        campos=camera_params['camera_center'],
        prefiltered=prefiltered,
        antialiasing=antialiasing,
        clamped=clamped,
        debug=False,
    )

    print("Rendering completed")

    # Bring the frame back to host memory for plotting.
    rendered_array = wp.to_torch(rendered_image).cpu().numpy()

    # Display and save the frame with matplotlib.
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(rendered_array)
    plt.axis('off')
    plt.savefig("example_render.png", bbox_inches='tight', dpi=150)
    print("Rendered image saved to example_render.png")
gs/scheduler.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
class LRScheduler:
    """Exponentially decays a learning rate from an initial to a final value."""

    def __init__(self, initial_lr, final_lr_factor=0.01):
        """
        Args:
            initial_lr: Starting learning rate
            final_lr_factor: Final LR as fraction of initial (e.g., 0.01 means final_lr = 0.01 * initial_lr)
        """
        self.initial_lr = initial_lr
        self.final_lr = initial_lr * final_lr_factor

    def get_lr(self, iteration, total_iterations):
        """Return the decayed learning rate for the given iteration."""
        # Degenerate schedules never decay.
        if total_iterations <= 1:
            return self.initial_lr

        # Fractional progress through training, clamped to [0, 1].
        t = min(iteration / (total_iterations - 1), 1.0)

        # Geometric interpolation: lr = initial * (final / initial) ** t
        decay_ratio = self.final_lr / self.initial_lr
        return self.initial_lr * (decay_ratio ** t)
gs/train.py ADDED
@@ -0,0 +1,1044 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import warp as wp
5
+ import imageio
6
+ import json
7
+ from tqdm import tqdm
8
+ from pathlib import Path
9
+ import argparse
10
+
11
+ from forward import render_gaussians
12
+ from backward import backward
13
+ from optimizer import prune_gaussians, adam_update, clone_gaussians, compact_gaussians, mark_split_candidates, mark_clone_candidates, split_gaussians, reset_opacities, reset_densification_stats
14
+ from config import *
15
+ from utils.camera_utils import load_camera
16
+ from utils.point_cloud_utils import save_ply
17
+ from loss import l1_loss, compute_image_gradients
18
+ from scheduler import LRScheduler
19
+
20
+ # Initialize Warp
21
+ wp.init()
22
+
23
+ # Kernels for parameter updates
24
@wp.kernel
def init_gaussian_params(
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),
    num_points: int,
    init_scale: float
):
    """Initialize all Gaussian parameters for training from scratch.

    Positions are placed uniformly in [-1.3, 1.3]^3 (index-seeded, so the
    layout is deterministic), scales are isotropic at ``init_scale``,
    rotations start at the identity quaternion, opacities at 0.1, and SH
    coefficients are zero except for a small constant DC term.
    """
    i = wp.tid()
    if i >= num_points:
        return

    # Initialize positions with random values
    # Generate random positions using warp random (seeded by point index
    # -> deterministic across runs)
    offset = wp.vec3(
        (wp.randf(wp.uint32(i * 3)) * 2.6 - 1.3),
        (wp.randf(wp.uint32(i * 3 + 1)) * 2.6 - 1.3),
        (wp.randf(wp.uint32(i * 3 + 2)) * 2.6 - 1.3)
    )
    # camera_center
    positions[i] = offset

    # Initialize scales (isotropic)
    scales[i] = wp.vec3(init_scale, init_scale, init_scale)

    # Initialize rotations to the identity quaternion (1, 0, 0, 0)
    rotations[i] = wp.vec4(1.0, 0.0, 0.0, 0.0)

    # Initialize opacities
    opacities[i] = 0.1

    # Initialize SH coefficients (just DC term for now)
    for j in range(16):  # degree=3, total 16 coefficients
        idx = i * 16 + j
        # DC term gets a small constant; all higher-order terms start at 0.
        # NOTE(review): the constant is negative (-0.007) although an older
        # comment claimed a "positive bias" — confirm the intended sign.
        if j == 0:
            shs[idx] = wp.vec3(-0.007, -0.007, -0.007)
        else:
            shs[idx] = wp.vec3(0.0, 0.0, 0.0)
65
+
66
@wp.kernel
def zero_gradients(
    pos_grad: wp.array(dtype=wp.vec3),
    scale_grad: wp.array(dtype=wp.vec3),
    rot_grad: wp.array(dtype=wp.vec4),
    opacity_grad: wp.array(dtype=float),
    sh_grad: wp.array(dtype=wp.vec3),
    num_points: int
):
    """Zero every per-Gaussian gradient buffer before a backward pass."""
    i = wp.tid()
    if i >= num_points:
        return

    pos_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    scale_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    rot_grad[i] = wp.vec4(0.0, 0.0, 0.0, 0.0)
    opacity_grad[i] = 0.0

    # Zero SH gradients (16 vec3 coefficients per Gaussian, flattened)
    for j in range(16):
        idx = i * 16 + j
        sh_grad[idx] = wp.vec3(0.0, 0.0, 0.0)
88
+
89
+
90
+
91
+ class NeRFGaussianSplattingTrainer:
92
    def __init__(self, dataset_path, output_path, config=None):
        """Initialize the 3D Gaussian Splatting trainer using pure Warp for NeRF dataset.

        Args:
            dataset_path: Directory containing transforms_{split}.json files
                and the referenced images.
            output_path: Directory for training outputs; created if missing.
            config: Optional dict of overrides merged over the
                GaussianParams defaults.
        """
        self.dataset_path = Path(dataset_path)
        self.output_path = Path(output_path)
        self.output_path.mkdir(parents=True, exist_ok=True)

        # Initialize configuration from GaussianParams defaults, then apply
        # caller overrides.
        self.config = GaussianParams.get_config_dict()

        if config is not None:
            self.config.update(config)

        # Initialize learning rate scheduler (None when disabled in config)
        self.lr_scheduler = self.create_lr_scheduler()
        print(f"Learning rate scheduler: {'Enabled' if self.lr_scheduler else 'Disabled'}")

        # For tracking learning rates over training, per parameter group
        self.learning_rate_history = {
            'positions': [],
            'scales': [],
            'rotations': [],
            'shs': [],
            'opacities': []
        }
        # Load NeRF dataset (all three splits up front)
        print(f"Loading NeRF dataset from {self.dataset_path}")
        self.cameras, self.image_paths = self.load_nerf_data("train")
        self.val_cameras, self.val_image_paths = self.load_nerf_data("val")
        self.test_cameras, self.test_image_paths = self.load_nerf_data("test")
        print(f"Loaded {len(self.cameras)} train cameras and {len(self.image_paths)} train images")
        print(f"Loaded {len(self.val_cameras)} val cameras and {len(self.val_image_paths)} val images")
        print(f"Loaded {len(self.test_cameras)} test cameras and {len(self.test_image_paths)} test images")

        # Calculate scene extent for densification thresholds
        self.scene_extent = self.calculate_scene_extent()
        print(f"Calculated scene extent: {self.scene_extent}")

        # Initialize Gaussian parameters on-device
        self.num_points = self.config['num_points']
        self.params = self.initialize_parameters()

        # Create gradient arrays (same layout as the parameters)
        self.grads = self.create_gradient_arrays()

        # Create optimizer state: Adam first and second moments
        self.adam_m = self.create_gradient_arrays()  # First moment
        self.adam_v = self.create_gradient_arrays()  # Second moment

        # Initialize densification state tracking
        self.init_densification_state()

        # For tracking loss across iterations
        self.losses = []

        # Initialize intermediate buffers dictionary (filled lazily)
        self.intermediate_buffers = {}

        # Track iteration for opacity reset; -32768 is a sentinel meaning
        # "no reset has happened yet".
        self.opacity_reset_at = -32768
151
+
152
+ def create_lr_scheduler(self):
153
+ """Create simple learning rate schedulers for each parameter type."""
154
+ if not self.config['use_lr_scheduler']:
155
+ return None
156
+
157
+ config = self.config['lr_scheduler_config']
158
+ final_factor = config['final_lr_factor']
159
+
160
+ schedulers = {
161
+ 'positions': LRScheduler(config['lr_pos'], final_factor),
162
+ 'scales': LRScheduler(config['lr_scale'], final_factor),
163
+ 'rotations': LRScheduler(config['lr_rot'], final_factor),
164
+ 'shs': LRScheduler(config['lr_sh'], final_factor),
165
+ 'opacities': LRScheduler(config['lr_opac'], final_factor)
166
+ }
167
+
168
+ return schedulers
169
+
170
+ def initialize_parameters(self):
171
+ """Initialize Gaussian parameters."""
172
+ positions = wp.zeros(self.num_points, dtype=wp.vec3)
173
+ scales = wp.zeros(self.num_points, dtype=wp.vec3)
174
+ rotations = wp.zeros(self.num_points, dtype=wp.vec4)
175
+ opacities = wp.zeros(self.num_points, dtype=float)
176
+ shs = wp.zeros(self.num_points * 16, dtype=wp.vec3) # 16 coeffs per point
177
+ # Launch kernel to initialize parameters
178
+ wp.launch(
179
+ init_gaussian_params,
180
+ dim=self.num_points,
181
+ inputs=[positions, scales, rotations, opacities, shs, self.num_points, self.config['initial_scale']]
182
+ )
183
+
184
+ # Return parameters as dictionary
185
+ return {
186
+ 'positions': positions,
187
+ 'scales': scales,
188
+ 'rotations': rotations,
189
+ 'opacities': opacities,
190
+ 'shs': shs
191
+ }
192
+
193
+ def create_gradient_arrays(self):
194
+ """Create arrays for gradients or optimizer state."""
195
+ positions = wp.zeros(self.num_points, dtype=wp.vec3)
196
+ scales = wp.zeros(self.num_points, dtype=wp.vec3)
197
+ rotations = wp.zeros(self.num_points, dtype=wp.vec4)
198
+ opacities = wp.zeros(self.num_points, dtype=float)
199
+ shs = wp.zeros(self.num_points * 16, dtype=wp.vec3)
200
+
201
+ # Return a dictionary of arrays
202
+ return {
203
+ 'positions': positions,
204
+ 'scales': scales,
205
+ 'rotations': rotations,
206
+ 'opacities': opacities,
207
+ 'shs': shs
208
+ }
209
+
210
+ def calculate_scene_extent(self):
211
+ """Calculate the extent of the scene based on camera positions."""
212
+ if not self.cameras:
213
+ return 1.0 # Default fallback
214
+
215
+ # Extract camera positions
216
+ camera_positions = []
217
+ for camera in self.cameras:
218
+ camera_positions.append(camera['camera_center'])
219
+
220
+ camera_positions = np.array(camera_positions)
221
+
222
+ # Calculate the centroid of all camera positions
223
+ scene_center = np.mean(camera_positions, axis=0)
224
+
225
+ # Calculate the maximum distance from any camera to the scene center
226
+ max_distance_to_center = 0.0
227
+ for pos in camera_positions:
228
+ distance = np.linalg.norm(pos - scene_center)
229
+ max_distance_to_center = max(max_distance_to_center, distance)
230
+
231
+ # The scene extent is the radius of the bounding sphere
232
+ # Use default factor if extent is too small
233
+ extent = max_distance_to_center * self.config.get('camera_extent_factor', 1.0)
234
+ return max(extent, 1.0)
235
+
236
+ def init_densification_state(self):
237
+ """Initialize state tracking for densification."""
238
+ self.xyz_gradient_accum = wp.zeros(self.num_points, dtype=float, device=DEVICE)
239
+ self.denom = wp.zeros(self.num_points, dtype=float, device=DEVICE)
240
+ self.max_radii2D = wp.zeros(self.num_points, dtype=float, device=DEVICE)
241
+
242
+ def load_nerf_data(self, datasplit):
243
+ """Load camera parameters and images from a NeRF dataset."""
244
+ # Read transforms_train.json
245
+ transforms_path = self.dataset_path / f"transforms_{datasplit}.json"
246
+ if not transforms_path.exists():
247
+ raise FileNotFoundError(f"No transforms_train.json found in {self.dataset_path}")
248
+
249
+ with open(transforms_path, 'r') as f:
250
+ transforms = json.load(f)
251
+
252
+ # Get image dimensions from the first image if available
253
+ first_frame = transforms['frames'][0]
254
+ first_img_path = str(self.dataset_path / f"{first_frame['file_path']}.png")
255
+ if os.path.exists(first_img_path):
256
+ # Load first image to get dimensions
257
+ img = imageio.imread(first_img_path)
258
+ width = img.shape[1]
259
+ height = img.shape[0]
260
+ print(f"Using image dimensions from dataset: {width}x{height}")
261
+ else:
262
+ # Use default dimensions from config if image not found
263
+ width = self.config['width']
264
+ height = self.config['height']
265
+ print(f"Using default dimensions: {width}x{height}")
266
+
267
+ # Update config with actual dimensions
268
+ self.config['width'] = width
269
+ self.config['height'] = height
270
+
271
+ self.config['camera_angle_x'] = transforms['camera_angle_x']
272
+
273
+ # Calculate focal length
274
+ focal = 0.5 * width / np.tan(0.5 * self.config['camera_angle_x'])
275
+
276
+ cameras = []
277
+ image_paths = []
278
+
279
+
280
+ # Process each frame
281
+ for i, frame in enumerate(transforms['frames']):
282
+ camera_info = {
283
+ "camera_id": i,
284
+ "camera_to_world": frame['transform_matrix'],
285
+ "width": width,
286
+ "height": height,
287
+ "focal": focal,
288
+ }
289
+
290
+ # Load camera parameters using existing function
291
+ camera_params = load_camera(camera_info)
292
+
293
+
294
+ if camera_params is not None:
295
+ cameras.append(camera_params)
296
+ image_paths.append(str(self.dataset_path / f"{frame['file_path']}.png"))
297
+
298
+ return cameras, image_paths
299
+
300
+ def load_image(self, path):
301
+ """Load an image as a numpy array."""
302
+ if os.path.exists(path):
303
+ img = imageio.imread(path)
304
+ # Convert to float and normalize to [0, 1]
305
+ img_np = img.astype(np.float32) / 255.0
306
+ # Ensure image is RGB (discard alpha channel if present)
307
+ if img_np.shape[2] == 4:
308
+ img_np = img_np[:, :, :3] # Keep only R, G, B channels
309
+ return img_np
310
+ else:
311
+ raise FileNotFoundError(f"Image not found: {path}")
312
+
313
+ def zero_grad(self):
314
+ """Zero out all gradients."""
315
+ wp.launch(
316
+ zero_gradients,
317
+ dim=self.num_points,
318
+ inputs=[
319
+ self.grads['positions'],
320
+ self.grads['scales'],
321
+ self.grads['rotations'],
322
+ self.grads['opacities'],
323
+ self.grads['shs'],
324
+ self.num_points
325
+ ]
326
+ )
327
+
328
+ def densification_and_pruning(self, iteration):
329
+ """Perform sophisticated densification and pruning of Gaussians."""
330
+
331
+ # Check if we should do densification
332
+ densify_from_iter = self.config.get('densify_from_iter', 500)
333
+ densify_until_iter = self.config.get('densify_until_iter', 15000)
334
+ densification_interval = self.config.get('densification_interval', 100)
335
+ opacity_reset_interval = self.config.get('opacity_reset_interval', 3000)
336
+
337
+ # Skip densification if outside iteration range
338
+ if iteration > densify_from_iter and iteration < densify_until_iter and iteration % densification_interval == 0:
339
+ print(f"Iteration {iteration}: Performing sophisticated densification and pruning")
340
+
341
+ # For simplified implementation, use position gradients as proxy for viewspace gradients
342
+ pos_grads = self.grads['positions']
343
+ avg_grads = wp.zeros(self.num_points, dtype=float, device=DEVICE)
344
+
345
+ @wp.kernel
346
+ def compute_grad_norms(pos_grad: wp.array(dtype=wp.vec3),
347
+ grad_norms: wp.array(dtype=float),
348
+ num_points: int):
349
+ i = wp.tid()
350
+ if i >= num_points:
351
+ return
352
+ grad_norms[i] = wp.length(pos_grad[i])
353
+
354
+ wp.launch(compute_grad_norms, dim=self.num_points,
355
+ inputs=[pos_grads, avg_grads, self.num_points])
356
+
357
+ # Configuration
358
+ grad_threshold = self.config.get('densify_grad_threshold', 0.0002)
359
+ percent_dense = self.config.get('percent_dense', 0.01)
360
+
361
+ # --- Step 1: Clone small Gaussians with high gradients ---
362
+ clone_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)
363
+ wp.launch(
364
+ mark_clone_candidates,
365
+ dim=self.num_points,
366
+ inputs=[
367
+ avg_grads,
368
+ self.params['scales'],
369
+ grad_threshold,
370
+ self.scene_extent,
371
+ percent_dense,
372
+ clone_mask,
373
+ self.num_points
374
+ ]
375
+ )
376
+
377
+ # Perform cloning
378
+ clone_prefix_sum = wp.zeros_like(clone_mask)
379
+ wp.utils.array_scan(clone_mask, clone_prefix_sum, inclusive=False)
380
+ total_to_clone = int(clone_prefix_sum.numpy()[-1])
381
+
382
+ if total_to_clone > 0:
383
+ print(f"[Clone] Cloning {total_to_clone} small Gaussians")
384
+ N = self.num_points
385
+ new_N = N + total_to_clone
386
+
387
+ # Allocate output arrays
388
+ out_params = {
389
+ 'positions': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
390
+ 'scales': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
391
+ 'rotations': wp.zeros(new_N, dtype=wp.vec4, device=DEVICE),
392
+ 'opacities': wp.zeros(new_N, dtype=float, device=DEVICE),
393
+ 'shs': wp.zeros(new_N * 16, dtype=wp.vec3, device=DEVICE)
394
+ }
395
+
396
+ # Clone Gaussians
397
+ wp.launch(
398
+ clone_gaussians,
399
+ dim=N,
400
+ inputs=[
401
+ clone_mask,
402
+ clone_prefix_sum,
403
+ self.params['positions'],
404
+ self.params['scales'],
405
+ self.params['rotations'],
406
+ self.params['opacities'],
407
+ self.params['shs'],
408
+ 0.01, # noise_scale
409
+ N, # offset
410
+ out_params['positions'],
411
+ out_params['scales'],
412
+ out_params['rotations'],
413
+ out_params['opacities'],
414
+ out_params['shs']
415
+ ]
416
+ )
417
+
418
+ # Update parameters and state
419
+ self.params = out_params
420
+ self.num_points = new_N
421
+ self.grads = self.create_gradient_arrays()
422
+ self.adam_m = self.create_gradient_arrays()
423
+ self.adam_v = self.create_gradient_arrays()
424
+
425
+ # --- Step 2: Split large Gaussians with high gradients ---
426
+ split_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)
427
+ wp.launch(
428
+ mark_split_candidates,
429
+ dim=self.num_points,
430
+ inputs=[
431
+ avg_grads,
432
+ self.params['scales'],
433
+ grad_threshold,
434
+ self.scene_extent,
435
+ percent_dense,
436
+ split_mask,
437
+ self.num_points
438
+ ]
439
+ )
440
+
441
+ # Perform splitting
442
+ split_prefix_sum = wp.zeros_like(split_mask)
443
+ wp.utils.array_scan(split_mask, split_prefix_sum, inclusive=False)
444
+ total_to_split = int(split_prefix_sum.numpy()[-1])
445
+
446
+ if total_to_split > 0:
447
+ print(f"[Split] Splitting {total_to_split} large Gaussians")
448
+ N = self.num_points
449
+ N_split = 2 # Split each Gaussian into 2
450
+ new_N = N + total_to_split * N_split
451
+
452
+ # Allocate output arrays
453
+ out_params = {
454
+ 'positions': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
455
+ 'scales': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
456
+ 'rotations': wp.zeros(new_N, dtype=wp.vec4, device=DEVICE),
457
+ 'opacities': wp.zeros(new_N, dtype=float, device=DEVICE),
458
+ 'shs': wp.zeros(new_N * 16, dtype=wp.vec3, device=DEVICE)
459
+ }
460
+
461
+ # Split Gaussians
462
+ wp.launch(
463
+ split_gaussians,
464
+ dim=N,
465
+ inputs=[
466
+ split_mask,
467
+ split_prefix_sum,
468
+ self.params['positions'],
469
+ self.params['scales'],
470
+ self.params['rotations'],
471
+ self.params['opacities'],
472
+ self.params['shs'],
473
+ N_split, # Number of splits per Gaussian
474
+ 0.8, # scale_factor
475
+ N, # offset
476
+ out_params['positions'],
477
+ out_params['scales'],
478
+ out_params['rotations'],
479
+ out_params['opacities'],
480
+ out_params['shs']
481
+ ]
482
+ )
483
+
484
+ # Update parameters and state
485
+ self.params = out_params
486
+ self.num_points = new_N
487
+ self.grads = self.create_gradient_arrays()
488
+ self.adam_m = self.create_gradient_arrays()
489
+ self.adam_v = self.create_gradient_arrays()
490
+
491
+ # Remove original split Gaussians
492
+ prune_filter = wp.zeros(self.num_points, dtype=int, device=DEVICE)
493
+
494
+ @wp.kernel
495
+ def mark_split_originals_for_removal(
496
+ split_mask: wp.array(dtype=int),
497
+ prune_filter: wp.array(dtype=int),
498
+ offset: int,
499
+ num_points: int
500
+ ):
501
+ i = wp.tid()
502
+ if i >= num_points:
503
+ return
504
+ if i < offset and split_mask[i] == 1:
505
+ prune_filter[i] = 1 # Mark for removal
506
+ else:
507
+ prune_filter[i] = 0 # Keep
508
+
509
+ wp.launch(mark_split_originals_for_removal, dim=self.num_points,
510
+ inputs=[split_mask, prune_filter, N, self.num_points])
511
+
512
+ # Invert mask to get valid mask
513
+ valid_mask = wp.zeros_like(prune_filter)
514
+
515
+ @wp.kernel
516
+ def invert_mask(prune: wp.array(dtype=int), valid: wp.array(dtype=int), n: int):
517
+ i = wp.tid()
518
+ if i >= n:
519
+ return
520
+ valid[i] = 1 - prune[i]
521
+
522
+ wp.launch(invert_mask, dim=self.num_points,
523
+ inputs=[prune_filter, valid_mask, self.num_points])
524
+
525
+ # Count valid points and compact
526
+ prefix_sum = wp.zeros_like(valid_mask)
527
+ wp.utils.array_scan(valid_mask, prefix_sum, inclusive=False)
528
+ valid_count = int(prefix_sum.numpy()[-1])
529
+
530
+ if valid_count < self.num_points:
531
+ print(f"[Split] Removing {self.num_points - valid_count} original split Gaussians")
532
+
533
+ # Allocate compacted output
534
+ compact_params = {
535
+ 'positions': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
536
+ 'scales': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
537
+ 'rotations': wp.zeros(valid_count, dtype=wp.vec4, device=DEVICE),
538
+ 'opacities': wp.zeros(valid_count, dtype=float, device=DEVICE),
539
+ 'shs': wp.zeros(valid_count * 16, dtype=wp.vec3, device=DEVICE)
540
+ }
541
+
542
+ wp.launch(
543
+ compact_gaussians,
544
+ dim=self.num_points,
545
+ inputs=[
546
+ valid_mask,
547
+ prefix_sum,
548
+ self.params['positions'],
549
+ self.params['scales'],
550
+ self.params['rotations'],
551
+ self.params['opacities'],
552
+ self.params['shs'],
553
+ compact_params['positions'],
554
+ compact_params['scales'],
555
+ compact_params['rotations'],
556
+ compact_params['opacities'],
557
+ compact_params['shs']
558
+ ]
559
+ )
560
+
561
+ # Update parameters and state
562
+ self.params = compact_params
563
+ self.num_points = valid_count
564
+ self.grads = self.create_gradient_arrays()
565
+ self.adam_m = self.create_gradient_arrays()
566
+ self.adam_v = self.create_gradient_arrays()
567
+
568
+ # --- Step 3: Enhanced Pruning ---
569
+ print(f"[Prune] Performing enhanced pruning")
570
+
571
+ valid_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)
572
+
573
+ # Use opacity-based pruning for now
574
+ wp.launch(
575
+ prune_gaussians,
576
+ dim=self.num_points,
577
+ inputs=[
578
+ self.params['opacities'],
579
+ self.config.get('cull_opacity_threshold', 0.005),
580
+ valid_mask,
581
+ self.num_points
582
+ ]
583
+ )
584
+
585
+ # Count valid points
586
+ prefix_sum = wp.zeros_like(valid_mask)
587
+ wp.utils.array_scan(valid_mask, prefix_sum, inclusive=False)
588
+ valid_count = int(prefix_sum.numpy()[-1])
589
+
590
+ # Check pruning constraints
591
+ min_valid_points = self.config.get('min_valid_points', 1000)
592
+ max_valid_points = self.config.get('max_valid_points', 1000000)
593
+ max_prune_ratio = self.config.get('max_allowed_prune_ratio', 0.5)
594
+
595
+ prune_count = self.num_points - valid_count
596
+ prune_ratio = prune_count / self.num_points if self.num_points > 0 else 0
597
+
598
+ if (valid_count >= min_valid_points and
599
+ valid_count <= max_valid_points and
600
+ prune_ratio <= max_prune_ratio and
601
+ valid_count < self.num_points):
602
+
603
+ print(f"[Prune] Compacting from {self.num_points} → {valid_count} points")
604
+
605
+ # Allocate compacted output
606
+ out_params = {
607
+ 'positions': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
608
+ 'scales': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
609
+ 'rotations': wp.zeros(valid_count, dtype=wp.vec4, device=DEVICE),
610
+ 'opacities': wp.zeros(valid_count, dtype=float, device=DEVICE),
611
+ 'shs': wp.zeros(valid_count * 16, dtype=wp.vec3, device=DEVICE)
612
+ }
613
+
614
+ wp.launch(
615
+ compact_gaussians,
616
+ dim=self.num_points,
617
+ inputs=[
618
+ valid_mask,
619
+ prefix_sum,
620
+ self.params['positions'],
621
+ self.params['scales'],
622
+ self.params['rotations'],
623
+ self.params['opacities'],
624
+ self.params['shs'],
625
+ out_params['positions'],
626
+ out_params['scales'],
627
+ out_params['rotations'],
628
+ out_params['opacities'],
629
+ out_params['shs']
630
+ ]
631
+ )
632
+
633
+ # Update parameters and state
634
+ self.params = out_params
635
+ self.num_points = valid_count
636
+ self.grads = self.create_gradient_arrays()
637
+ self.adam_m = self.create_gradient_arrays()
638
+ self.adam_v = self.create_gradient_arrays()
639
+ else:
640
+ print(f"[Prune] Skipping pruning: valid={valid_count}, ratio={prune_ratio:.3f}")
641
+
642
+
643
+ # Opacity reset - updated logic to match reference implementation
644
+ background_is_white = all(c == 1.0 for c in self.config['background_color'])
645
+ should_reset_opacity = (
646
+ iteration % opacity_reset_interval == 0 or
647
+ (background_is_white and iteration == densify_from_iter)
648
+ )
649
+
650
+ if should_reset_opacity:
651
+ print(f"Iteration {iteration}: Resetting opacities")
652
+ wp.launch(
653
+ reset_opacities,
654
+ dim=self.num_points,
655
+ inputs=[
656
+ self.params['opacities'],
657
+ 0.01, # max_opacity
658
+ self.num_points
659
+ ]
660
+ )
661
+
662
+
663
    def optimizer_step(self, iteration):
        """Perform one Adam optimization step over all Gaussian parameters.

        Resolves per-parameter-group learning rates (scheduled if a scheduler
        exists, otherwise static values from the config), records them for
        later inspection, then launches a single fused `adam_update` kernel
        that updates positions, scales, rotations, opacities and SH
        coefficients in place.

        Args:
            iteration: Current training iteration; also forwarded to the
                kernel for Adam bias correction.
        """

        # Get learning rates from scheduler or use config defaults
        if self.lr_scheduler:
            lr_pos = self.lr_scheduler['positions'].get_lr(iteration, self.config['num_iterations'])
            lr_scale = self.lr_scheduler['scales'].get_lr(iteration, self.config['num_iterations'])
            lr_rot = self.lr_scheduler['rotations'].get_lr(iteration, self.config['num_iterations'])
            lr_sh = self.lr_scheduler['shs'].get_lr(iteration, self.config['num_iterations'])
            lr_opac = self.lr_scheduler['opacities'].get_lr(iteration, self.config['num_iterations'])

            # Track learning rate history (only recorded on the scheduled path)
            self.learning_rate_history['positions'].append(lr_pos)
            self.learning_rate_history['scales'].append(lr_scale)
            self.learning_rate_history['rotations'].append(lr_rot)
            self.learning_rate_history['shs'].append(lr_sh)
            self.learning_rate_history['opacities'].append(lr_opac)

            # Log learning rates occasionally
            if iteration % 1000 == 0:
                print(f"Iteration {iteration} learning rates:")
                print(f"  positions: {lr_pos:.6f}")
                print(f"  scales: {lr_scale:.6f}")
                print(f"  rotations: {lr_rot:.6f}")
                print(f"  shs: {lr_sh:.6f}")
                print(f"  opacities: {lr_opac:.6f}")
        else:
            # Use static learning rates from config
            lr_pos = self.config['lr_pos']
            lr_scale = self.config['lr_scale']
            lr_rot = self.config['lr_rot']
            lr_sh = self.config['lr_sh']
            lr_opac = self.config['lr_opac']

        # NOTE: the input list below is positional — its order must match the
        # `adam_update` kernel signature exactly.
        wp.launch(
            adam_update,
            dim=self.num_points,
            inputs=[
                # Parameters
                self.params['positions'],
                self.params['scales'],
                self.params['rotations'],
                self.params['opacities'],
                self.params['shs'],

                # Gradients
                self.grads['positions'],
                self.grads['scales'],
                self.grads['rotations'],
                self.grads['opacities'],
                self.grads['shs'],

                # First moments (m)
                self.adam_m['positions'],
                self.adam_m['scales'],
                self.adam_m['rotations'],
                self.adam_m['opacities'],
                self.adam_m['shs'],

                # Second moments (v)
                self.adam_v['positions'],
                self.adam_v['scales'],
                self.adam_v['rotations'],
                self.adam_v['opacities'],
                self.adam_v['shs'],

                # Optimizer parameters with dynamic learning rates
                self.num_points,
                lr_pos,     # Dynamic learning rate for positions
                lr_scale,   # Dynamic learning rate for scales
                lr_rot,     # Dynamic learning rate for rotations
                lr_sh,      # Dynamic learning rate for SH coefficients
                lr_opac,    # Dynamic learning rate for opacities
                self.config['adam_beta1'],
                self.config['adam_beta2'],
                self.config['adam_epsilon'],
                iteration
            ]
        )
742
+
743
    def save_checkpoint(self, iteration):
        """Save the current point cloud and training state.

        Writes, under the output directory:
        - `point_cloud/iteration_{iteration}/point_cloud.ply` — current Gaussians
        - `loss.txt` and `loss_plot.png` — full loss history so far
        - `point_cloud/iteration_{iteration}/rendered_view.png` — render from
          camera 0 for visual progress checks

        Args:
            iteration: Training iteration used to name the checkpoint folder.
        """
        checkpoint_dir = self.output_path / "point_cloud" / f"iteration_{iteration}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Save point cloud as PLY
        ply_path = checkpoint_dir / "point_cloud.ply"
        save_ply(self.params, ply_path, self.num_points)

        # Save loss history (one value per line, overwritten each checkpoint)
        loss_path = self.output_path / "loss.txt"
        with open(loss_path, 'w') as f:
            for loss in self.losses:
                f.write(f"{loss}\n")

        # Save loss plot
        plt.figure(figsize=(10, 5))
        plt.plot(self.losses)
        plt.title('Training Loss')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.savefig(self.output_path / "loss_plot.png")
        plt.close()

        # Save a rendered view (antialiasing enabled here, unlike the training
        # renders in train(), which use antialiasing=False)
        camera_idx = 0  # Front view
        rendered_image, _, _ = render_gaussians(
            background=np.array(self.config['background_color'], dtype=np.float32),
            means3D=self.params['positions'].numpy(),
            colors=None,  # Use SH coefficients instead
            opacity=self.params['opacities'].numpy(),
            scales=self.params['scales'].numpy(),
            rotations=self.params['rotations'].numpy(),
            scale_modifier=self.config['scale_modifier'],
            viewmatrix=self.cameras[camera_idx]['world_to_camera'],
            projmatrix=self.cameras[camera_idx]['full_proj_matrix'],
            tan_fovx=self.cameras[camera_idx]['tan_fovx'],
            tan_fovy=self.cameras[camera_idx]['tan_fovy'],
            image_height=self.cameras[camera_idx]['height'],
            image_width=self.cameras[camera_idx]['width'],
            sh=self.params['shs'].numpy(),  # Pass SH coefficients
            degree=self.config['sh_degree'],
            campos=self.cameras[camera_idx]['camera_center'],
            prefiltered=False,
            antialiasing=True,
            clamped=True
        )
        # Save rendered view as image
        rendered_array = wp.to_torch(rendered_image).cpu().numpy()
        # Handle case where rendered_array has shape (3, H, W) - transpose to (H, W, 3)
        if rendered_array.shape[0] == 3 and len(rendered_array.shape) == 3:
            rendered_array = np.transpose(rendered_array, (1, 2, 0))
        img8 = (np.clip(rendered_array, 0, 1) * 255).astype(np.uint8)
        imageio.imwrite(checkpoint_dir / "rendered_view.png", img8)
797
+
798
+
799
    def debug_log_and_save_images(
        self,
        rendered_image,   # np.float32 H×W×3 (range 0-1); NOTE(review): train() passes (3, H, W) — save_rgb handles both layouts
        target_image,     # np.float32
        depth_image,      # wp.array2d(float) – optional but unused here
        camera_idx: int,
        it: int
    ):
        """Print per-iteration rasterizer diagnostics and dump debug images.

        Reads `self.intermediate_buffers` (populated by the most recent
        forward pass) to report duplicate counts, radii, opacity statistics
        and visible-Gaussian counts, then writes render/target PNGs plus a
        projected-point scatter and a depth histogram to the output directory.
        """

        # ------ quick numeric read-out -----------------------------------
        radii = wp.to_torch(self.intermediate_buffers["radii"]).cpu().numpy()
        # conic_opacity stores (conic_a, conic_b, conic_c, alpha); take alpha.
        alphas = wp.to_torch(self.intermediate_buffers["conic_opacity"]).cpu().numpy()[:, 3]
        offs = wp.to_torch(self.intermediate_buffers["point_offsets"]).cpu().numpy()
        num_dup = int(offs[-1]) if len(offs) else 0
        r_med = np.median(radii[radii > 0]) if (radii > 0).any() else 0

        # Count visible Gaussians
        xy_image = wp.to_torch(self.intermediate_buffers["points_xy_image"]).cpu().numpy()
        W = self.cameras[camera_idx]['width']
        H = self.cameras[camera_idx]['height']
        visible_gaussians = np.sum(
            (xy_image[:, 0] >= 0) & (xy_image[:, 0] < W) &
            (xy_image[:, 1] >= 0) & (xy_image[:, 1] < H) &
            np.isfinite(xy_image).all(axis=1) &
            (radii > 0)  # Only count Gaussians with positive radius
        )

        print(
            f"[it {it:05d}]  dup={num_dup:<6}  "
            f"r_med={r_med:5.1f}  α∈[{alphas.min():.3f},"
            f"{np.median(alphas):.3f},{alphas.max():.3f}]  "
            f"visible={visible_gaussians}/{len(xy_image)}"
        )

        # ------ save render / target PNG ---------------------------------
        def save_rgb(arr_f32, stem):
            # Handle case where arr_f32 has shape (3, H, W) - transpose to (H, W, 3)
            if arr_f32.shape[0] == 3 and len(arr_f32.shape) == 3:
                arr_f32 = np.transpose(arr_f32, (1, 2, 0))
            img8 = (np.clip(arr_f32, 0, 1) * 255).astype(np.uint8)
            imageio.imwrite(self.output_path / f"{stem}_{it:06d}.png", img8)

        save_rgb(rendered_image if isinstance(rendered_image, np.ndarray) else wp.to_torch(rendered_image).cpu().numpy(), "render")
        save_rgb(target_image, "target")

        # ------ make 2-D projection scatter ------------------------------
        xy = wp.to_torch(self.intermediate_buffers["points_xy_image"]).cpu().numpy()
        depth = wp.to_torch(self.intermediate_buffers["depths"]).cpu().numpy()
        H, W = self.config["height"], self.config["width"]

        mask = (
            (xy[:, 0] >= 0) & (xy[:, 0] < W) &
            (xy[:, 1] >= 0) & (xy[:, 1] < H) &
            np.isfinite(xy).all(axis=1) &
            (radii > 0)  # Only include Gaussians with positive radius
        )
        if mask.any():
            plt.figure(figsize=(6, 6))
            plt.scatter(xy[mask, 0], xy[mask, 1],
                        s=4, c=depth[mask], cmap="turbo", alpha=.7)
            plt.gca().invert_yaxis()
            plt.xlim(0, W); plt.ylim(H, 0)
            plt.title(f"Projected Gaussians (iter {it}): {np.sum(mask)}/{len(xy)} visible")
            plt.colorbar(label="depth(z)")
            plt.tight_layout()
            plt.savefig(self.output_path / f"proj_{it:06d}.png", dpi=250)
            plt.close()

            # depth histogram
            plt.figure(figsize=(5, 3))
            plt.hist(depth[mask], bins=40, color="steelblue")
            plt.xlabel("depth (camera-z)")
            plt.ylabel("count")
            plt.title(f"Depth hist – {mask.sum()} pts")
            plt.tight_layout()
            plt.savefig(self.output_path / f"depth_hist_{it:06d}.png", dpi=250)
            plt.close()
876
+
877
    def train(self):
        """Train the 3D Gaussian Splatting model.

        Main optimization loop. Per iteration: pick a random training camera,
        render the current Gaussians, compute an L1 image loss, run the
        custom backward pass to fill gradient buffers, apply an Adam step,
        then run densification/pruning and periodic checkpointing.
        """
        num_iterations = self.config['num_iterations']

        # Main training loop
        with tqdm(total=num_iterations) as pbar:
            for iteration in range(num_iterations):
                # Select a random camera and corresponding image
                camera_idx = np.random.randint(0, len(self.cameras))
                image_path = self.image_paths[camera_idx]
                target_image = self.load_image(image_path)

                # Zero gradients
                self.zero_grad()
                # Render the view; intermediate_buffers are kept for the
                # backward pass and for debug logging below.
                rendered_image, depth_image, self.intermediate_buffers = render_gaussians(
                    background=np.array(self.config['background_color'], dtype=np.float32),
                    means3D=self.params['positions'].numpy(),
                    colors=None,  # Use SH coefficients instead
                    opacity=self.params['opacities'].numpy(),
                    scales=self.params['scales'].numpy(),
                    rotations=self.params['rotations'].numpy(),
                    scale_modifier=self.config['scale_modifier'],
                    viewmatrix=self.cameras[camera_idx]['world_to_camera'],
                    projmatrix=self.cameras[camera_idx]['full_proj_matrix'],
                    tan_fovx=self.cameras[camera_idx]['tan_fovx'],
                    tan_fovy=self.cameras[camera_idx]['tan_fovy'],
                    image_height=self.cameras[camera_idx]['height'],
                    image_width=self.cameras[camera_idx]['width'],
                    sh=self.params['shs'].numpy(),  # Pass SH coefficients
                    degree=self.config['sh_degree'],
                    campos=self.cameras[camera_idx]['camera_center'],
                    prefiltered=False,
                    antialiasing=False,
                    clamped=True
                )

                radii = wp.to_torch(self.intermediate_buffers["radii"]).cpu().numpy()
                np_rendered_image = wp.to_torch(rendered_image).cpu().numpy()
                # (H, W, 3) -> (3, H, W) for the debug image dump
                np_rendered_image = np_rendered_image.transpose(2, 0, 1)

                if iteration % self.config['save_interval'] == 0:
                    self.debug_log_and_save_images(np_rendered_image, target_image, depth_image, camera_idx, iteration)

                # Calculate L1 loss
                l1_val = l1_loss(rendered_image, target_image)

                # # Calculate SSIM, not used
                # ssim_val = ssim(rendered_image, target_image)
                # # Combined loss with weighted SSIM
                # lambda_dssim = self.config['lambda_dssim']
                # # loss = (1 - λ) * L1 + λ * (1 - SSIM)
                # loss = (1.0 - lambda_dssim) * l1_val + lambda_dssim * (1.0 - ssim_val)

                loss = l1_val
                self.losses.append(loss)
                # Compute pixel gradients for image loss (dL/dColor)
                pixel_grad_buffer = compute_image_gradients(
                    rendered_image, target_image, lambda_dssim=0
                )

                # Prepare camera parameters
                camera = self.cameras[camera_idx]
                view_matrix = wp.mat44(camera['world_to_camera'].flatten())
                proj_matrix = wp.mat44(camera['full_proj_matrix'].flatten())
                campos = wp.vec3(camera['camera_center'][0], camera['camera_center'][1], camera['camera_center'][2])

                # Create appropriate buffer dictionaries for the backward pass
                geom_buffer = {
                    'radii': self.intermediate_buffers['radii'],
                    'means2D': self.intermediate_buffers['points_xy_image'],
                    'conic_opacity': self.intermediate_buffers['conic_opacity'],
                    'rgb': self.intermediate_buffers['colors'],
                    'clamped': self.intermediate_buffers['clamped_state']
                }

                binning_buffer = {
                    'point_list': self.intermediate_buffers['point_list']
                }

                img_buffer = {
                    'ranges': self.intermediate_buffers['ranges'],
                    'final_Ts': self.intermediate_buffers['final_Ts'],
                    'n_contrib': self.intermediate_buffers['n_contrib']
                }

                gradients = backward(
                    # Core parameters
                    background=np.array(self.config['background_color'], dtype=np.float32),
                    means3D=self.params['positions'],
                    dL_dpixels=pixel_grad_buffer,

                    # Model parameters (pass directly from self.params)
                    opacity=self.params['opacities'],
                    shs=self.params['shs'],
                    scales=self.params['scales'],
                    rotations=self.params['rotations'],
                    scale_modifier=self.config['scale_modifier'],

                    # Camera parameters
                    viewmatrix=view_matrix,
                    projmatrix=proj_matrix,
                    tan_fovx=camera['tan_fovx'],
                    tan_fovy=camera['tan_fovy'],
                    image_height=camera['height'],
                    image_width=camera['width'],
                    campos=campos,

                    # Forward output buffers
                    radii=self.intermediate_buffers['radii'],
                    means2D=self.intermediate_buffers['points_xy_image'],
                    conic_opacity=self.intermediate_buffers['conic_opacity'],
                    rgb=self.intermediate_buffers['colors'],
                    cov3Ds=self.intermediate_buffers['cov3Ds'],
                    clamped=self.intermediate_buffers['clamped_state'],

                    # Internal state buffers
                    geom_buffer=geom_buffer,
                    binning_buffer=binning_buffer,
                    img_buffer=img_buffer,

                    # Algorithm parameters
                    degree=self.config['sh_degree'],
                    debug=False
                )

                # 3. Copy gradients from backward result to the optimizer's gradient buffers
                wp.copy(self.grads['positions'], gradients['dL_dmean3D'])
                wp.copy(self.grads['scales'], gradients['dL_dscale'])
                wp.copy(self.grads['rotations'], gradients['dL_drot'])
                wp.copy(self.grads['opacities'], gradients['dL_dopacity'])
                wp.copy(self.grads['shs'], gradients['dL_dshs'])

                # Update parameters
                self.optimizer_step(iteration)

                # Update progress bar
                pbar.update(1)
                pbar.set_description(f"Loss: {loss:.6f}")

                # Densify/prune AFTER the optimizer step, using this
                # iteration's gradients as the densification signal.
                self.densification_and_pruning(iteration)

                # Save checkpoint
                if iteration % self.config['save_interval'] == 0 or iteration == num_iterations - 1:
                    self.save_checkpoint(iteration)

        print("Training complete!")
1024
+
1025
+
1026
def main():
    """Command-line entry point: parse arguments and run training."""
    arg_parser = argparse.ArgumentParser(
        description="Train 3D Gaussian Splatting model with NeRF dataset"
    )
    arg_parser.add_argument(
        "--dataset",
        type=str,
        default="./data/nerf_synthetic/lego",
        help="Path to NeRF dataset directory (default: Lego dataset)",
    )
    arg_parser.add_argument(
        "--output",
        type=str,
        default="./output",
        help="Output directory",
    )
    opts = arg_parser.parse_args()

    # Build the trainer and run the full optimization loop.
    trainer = NeRFGaussianSplattingTrainer(
        dataset_path=opts.dataset,
        output_path=opts.output,
    )
    trainer.train()
1041
+
1042
+
1043
+ if __name__ == "__main__":
1044
+ main()
gs/train_colmap.py ADDED
@@ -0,0 +1,1586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import warp as wp
5
+ import imageio
6
+ import json
7
+ from tqdm import tqdm
8
+ from pathlib import Path
9
+ import argparse
10
+
11
+ from forward import render_gaussians
12
+ from backward import backward
13
+ from optimizer import prune_gaussians, adam_update, clone_gaussians, compact_gaussians, mark_split_candidates, mark_clone_candidates, split_gaussians, reset_opacities, reset_densification_stats
14
+ from config import *
15
+ from utils.camera_utils import load_camera, load_camera_colmap
16
+ from utils.point_cloud_utils import save_ply
17
+ from loss import l1_loss, compute_image_gradients
18
+ from scheduler import LRScheduler
19
+ from utils.math_utils import quaternion_to_rotation_matrix
20
+ from plyfile import PlyData, PlyElement
21
+ from scipy.spatial import cKDTree # Add this import
22
+ # Initialize Warp
23
+ wp.init()
24
+
25
+ # Kernels for parameter updates
26
# One-time initialization of per-Gaussian rotation and opacity defaults.
# Positions, scales and SH coefficients are intentionally NOT set here: they
# are pre-filled on the host (e.g. from the COLMAP points3D.ply point cloud),
# so this kernel only initializes the remaining parameters.
@wp.kernel
def init_gaussian_params(
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    num_points: int
):
    i = wp.tid()
    # Guard against launch dims exceeding the logical point count.
    if i >= num_points:
        return

    # Identity quaternion (w=1, x=y=z=0): no initial rotation.
    rotations[i] = wp.vec4(1.0, 0.0, 0.0, 0.0)

    # Low starting opacity; optimization/densification raises it as needed.
    opacities[i] = 0.1
67
+
68
# Reset all per-Gaussian gradient buffers to zero before a backward pass.
# SH gradients use a flattened layout: 16 vec3 coefficients per Gaussian.
@wp.kernel
def zero_gradients(
    pos_grad: wp.array(dtype=wp.vec3),
    scale_grad: wp.array(dtype=wp.vec3),
    rot_grad: wp.array(dtype=wp.vec4),
    opacity_grad: wp.array(dtype=float),
    sh_grad: wp.array(dtype=wp.vec3),
    num_points: int
):
    i = wp.tid()
    if i >= num_points:
        return

    pos_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    scale_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    rot_grad[i] = wp.vec4(0.0, 0.0, 0.0, 0.0)
    opacity_grad[i] = 0.0

    # Zero SH gradients (stride 16 per Gaussian in the flat sh_grad array)
    for j in range(16):
        idx = i * 16 + j
        sh_grad[idx] = wp.vec3(0.0, 0.0, 0.0)
90
+
91
+
92
+
93
+ class NeRFGaussianSplattingTrainer:
94
    def __init__(self, dataset_path, output_path, config=None):
        """Initialize the 3D Gaussian Splatting trainer using pure Warp for NeRF dataset.

        Args:
            dataset_path: Root of the COLMAP dataset (expects sparse/0/ with
                reconstruction files).
            output_path: Directory for checkpoints, debug images and plots;
                created if missing.
            config: Optional dict of overrides merged on top of the defaults
                from GaussianParams.get_config_dict().
        """
        self.dataset_path = Path(dataset_path)
        self.output_path = Path(output_path)

        # Create output directories
        self.output_path.mkdir(parents=True, exist_ok=True)
        (self.output_path / "proj").mkdir(exist_ok=True)
        (self.output_path / "render").mkdir(exist_ok=True)
        (self.output_path / "target").mkdir(exist_ok=True)
        (self.output_path / "depth_hist").mkdir(exist_ok=True)
        (self.output_path / "point_cloud").mkdir(exist_ok=True)

        # Initialize configuration from GaussianParams
        self.config = GaussianParams.get_config_dict()

        if config is not None:
            self.config.update(config)

        # Set default number of points (will be updated if points3D.ply is loaded)
        self.num_points = self.config.get('num_points', 50000)

        # Initialize learning rate scheduler
        self.lr_scheduler = self.create_lr_scheduler()
        print(f"Learning rate scheduler: {'Enabled' if self.lr_scheduler else 'Disabled'}")

        # For tracking learning rates (appended per step in optimizer_step)
        self.learning_rate_history = {
            'positions': [],
            'scales': [],
            'rotations': [],
            'shs': [],
            'opacities': []
        }

        # Load dataset
        print(f"Loading COLMAP dataset from {self.dataset_path}")
        self.cameras, self.image_paths = self.load_colmap("train")
        self.test_cameras, self.test_image_paths = self.load_colmap("test")

        print(f"Loaded {len(self.cameras)} train cameras and {len(self.image_paths)} train images")
        print(f"Loaded {len(self.test_cameras)} test cameras and {len(self.test_image_paths)} test images")

        # Calculate scene extent for densification
        self.scene_extent = self.calculate_scene_extent()
        print(f"Calculated scene extent: {self.scene_extent}")

        # Initialize parameters (this may update self.num_points if points3D.ply is found)
        self.params = self.initialize_parameters()
        print(f"Initialized {self.num_points} Gaussians")

        # Create gradient arrays
        self.grads = self.create_gradient_arrays()

        # Create optimizer state (Adam first/second moments, same layout as grads)
        self.adam_m = self.create_gradient_arrays()
        self.adam_v = self.create_gradient_arrays()

        # Initialize densification state tracking
        self.init_densification_state()

        # For tracking loss
        self.losses = []

        # Intermediate rasterizer buffers from the latest forward pass;
        # populated by render_gaussians() during training.
        self.intermediate_buffers = {}

        # Track iteration for opacity reset (sentinel = "never reset yet")
        self.opacity_reset_at = -32768


        # Call after loading data
        #self.visualize_camera_points_alignment()
167
+
168
+ def create_lr_scheduler(self):
169
+ """Create simple learning rate schedulers for each parameter type."""
170
+ if not self.config['use_lr_scheduler']:
171
+ return None
172
+
173
+ config = self.config['lr_scheduler_config']
174
+ final_factor = config['final_lr_factor']
175
+
176
+ schedulers = {
177
+ 'positions': LRScheduler(config['lr_pos'], final_factor),
178
+ 'scales': LRScheduler(config['lr_scale'], final_factor),
179
+ 'rotations': LRScheduler(config['lr_rot'], final_factor),
180
+ 'shs': LRScheduler(config['lr_sh'], final_factor),
181
+ 'opacities': LRScheduler(config['lr_opac'], final_factor)
182
+ }
183
+
184
+ return schedulers
185
+
186
    def initialize_parameters(self):
        """Initialize Gaussian parameters, seeding from points3D.ply when available.

        Positions and colors are read from ``<dataset>/sparse/0/points3D.ply``
        if present; otherwise positions fall back to zeros and colors to gray.
        Per-point scales are derived from nearest-neighbor distances.

        Returns:
            dict of warp arrays: 'positions' (vec3), 'scales' (vec3),
            'rotations' (vec4), 'opacities' (float), 'shs' (num_points*16 vec3).

        Side effects:
            May overwrite ``self.num_points`` with the PLY point count.
        """
        # Try to load the sparse reconstruction's point cloud.
        points3d_path = self.dataset_path / "sparse/0/points3D.ply"
        initial_positions_np = None  # (N, 3) float32 once loaded
        initial_colors_np = None     # (N, 3) float32 in [0, 1] once loaded

        if points3d_path.exists():
            try:
                plydata = PlyData.read(str(points3d_path))
                vertices = plydata['vertex']
                if 'x' in vertices and 'y' in vertices and 'z' in vertices:
                    positions_data = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
                    initial_positions_np = positions_data.astype(np.float32)

                    if 'red' in vertices and 'green' in vertices and 'blue' in vertices:
                        # PLY stores 8-bit colors; normalize to [0, 1].
                        colors_data = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T
                        initial_colors_np = (colors_data / 255.0).astype(np.float32)
                    else:
                        print("Warning: Color attributes (red, green, blue) not found in points3D.ply.")

                    # The configured num_points is superseded by the PLY count.
                    self.num_points = len(initial_positions_np)
                    print(f"Loaded {self.num_points} points from points3D.ply")
            except Exception as e:
                # Any parse failure falls back to the zero/gray defaults below.
                print(f"Warning: Could not load points3D.ply: {e}")
                initial_positions_np = None
                initial_colors_np = None

        if initial_positions_np is None:
            # Fallback if points3D.ply is not loaded or doesn't have positions.
            print(f"Warning: Initial positions not loaded. Initializing {self.num_points} positions to zeros (or expect random init if uncommented in kernel).")
            # self.num_points is already set from config or updated if PLY was partially read.
            initial_positions_np = np.zeros((self.num_points, 3), dtype=np.float32)

        # Derive initial per-point scales from nearest-neighbor distances.
        scales_np = np.zeros((self.num_points, 3), dtype=np.float32)
        if initial_positions_np is not None and self.num_points > 3:  # kdtree query needs k <= num_points
            try:
                print("Calculating initial scales using cKDTree...")
                kdtree = cKDTree(initial_positions_np)
                # k=2 returns the self-point plus its single nearest neighbor.
                # NOTE(review): the original comment claimed "3 nearest neighbors";
                # with k=2 the mean below is simply the 1-NN distance — confirm
                # whether k=4 (reference 3DGS behavior) was intended.
                k = 2
                distances, _ = kdtree.query(initial_positions_np, k=k, workers=-1)  # all CPU cores

                # distances[:, 0] is the distance to self (0.0), so skip it.
                radius_np = np.mean(distances[:, 1:], axis=1)
                # Isotropic scale: replicate the radius across x/y/z.
                scales_np = np.tile(radius_np[:, np.newaxis], (1, 3))
                print(f"Initial scales calculated. Min radius: {radius_np.min()}, Max radius: {radius_np.max()}, Mean radius: {radius_np.mean()}")
            except Exception as e:
                print(f"Error during cKDTree scale initialization: {e}. Falling back to default scale.")
                default_scale_val = self.config['initial_scale']
                scales_np = np.full((self.num_points, 3), default_scale_val, dtype=np.float32)
        else:
            default_scale_val = self.config['initial_scale']
            print(f"Not enough points for cKDTree or initial_positions_np is None. Using default scale: {default_scale_val}")
            scales_np = np.full((self.num_points, 3), default_scale_val, dtype=np.float32)

        # Upload to device arrays of the final size.
        positions = wp.array(initial_positions_np, dtype=wp.vec3, device=DEVICE)
        scales = wp.array(scales_np, dtype=wp.vec3, device=DEVICE)  # calculated or default scales
        rotations = wp.zeros(self.num_points, dtype=wp.vec4, device=DEVICE)
        opacities = wp.zeros(self.num_points, dtype=float, device=DEVICE)

        # Spherical-harmonics DC setup: color = C0 * sh_dc + 0.5, inverted here.
        C0 = 0.28209479177387814  # Y_00 basis constant
        shs_np_data = np.zeros((self.num_points * 16, 3), dtype=np.float32)
        if initial_colors_np is not None and initial_colors_np.shape[0] == self.num_points:
            # Every 16th row is the DC coefficient of one Gaussian.
            shs_np_data[::16] = (initial_colors_np - 0.5) / C0
        else:
            # Default to gray if colors are not available or mismatch.
            gray_color_sh = (np.array([0.5, 0.5, 0.5]) - 0.5) / C0
            shs_np_data[::16] = np.tile(gray_color_sh, (self.num_points, 1))
        shs = wp.array(shs_np_data, dtype=wp.vec3, device=DEVICE)

        # Kernel fills only rotations and opacities; scales and shs were
        # already initialized host-side above.
        wp.launch(
            init_gaussian_params,
            dim=self.num_points,
            inputs=[rotations, opacities, self.num_points]
        )

        return {
            'positions': positions,
            'scales': scales,
            'rotations': rotations,
            'opacities': opacities,
            'shs': shs
        }
276
+ def create_gradient_arrays(self):
277
+ """Create arrays for gradients or optimizer state."""
278
+ positions = wp.zeros(self.num_points, dtype=wp.vec3)
279
+ scales = wp.zeros(self.num_points, dtype=wp.vec3)
280
+ rotations = wp.zeros(self.num_points, dtype=wp.vec4)
281
+ opacities = wp.zeros(self.num_points, dtype=float)
282
+ shs = wp.zeros(self.num_points * 16, dtype=wp.vec3)
283
+
284
+ # Return a dictionary of arrays
285
+ return {
286
+ 'positions': positions,
287
+ 'scales': scales,
288
+ 'rotations': rotations,
289
+ 'opacities': opacities,
290
+ 'shs': shs
291
+ }
292
+
293
+ def calculate_scene_extent(self):
294
+ """Calculate the extent of the scene based on camera positions."""
295
+ if not self.cameras:
296
+ return 1.0 # Default fallback
297
+
298
+ # Extract camera positions
299
+ camera_positions = []
300
+ for camera in self.cameras:
301
+ camera_positions.append(camera['camera_center'])
302
+
303
+ camera_positions = np.array(camera_positions)
304
+
305
+ # Calculate the centroid of all camera positions
306
+ scene_center = np.mean(camera_positions, axis=0)
307
+
308
+ # Calculate the maximum distance from any camera to the scene center
309
+ max_distance_to_center = 0.0
310
+ for pos in camera_positions:
311
+ distance = np.linalg.norm(pos - scene_center)
312
+ max_distance_to_center = max(max_distance_to_center, distance)
313
+
314
+ # The scene extent is the radius of the bounding sphere
315
+ # Use default factor if extent is too small
316
+ extent = max_distance_to_center * self.config.get('camera_extent_factor', 1.0)
317
+ return max(extent, 1.0)
318
+
319
    def init_densification_state(self):
        """(Re)allocate per-Gaussian accumulator buffers for densification.

        NOTE(review): the accumulators are only allocated here; their update
        sites are not in this chunk. Names suggest xyz_gradient_accum holds a
        running gradient-magnitude sum, denom its event count, and max_radii2D
        the largest observed screen-space radius — verify against the kernels
        that write them.
        """
        self.xyz_gradient_accum = wp.zeros(self.num_points, dtype=float, device=DEVICE)
        self.denom = wp.zeros(self.num_points, dtype=float, device=DEVICE)
        self.max_radii2D = wp.zeros(self.num_points, dtype=float, device=DEVICE)
325
    def load_colmap(self, datasplit="train", llffhold=8):
        """Parse a COLMAP text export and build camera/image pairs for a split.

        Reads ``sparse/0/cameras.txt`` (intrinsics; PINHOLE and SIMPLE_PINHOLE
        only) and ``sparse/0/images.txt`` (extrinsics), then applies an
        LLFF-style holdout: every ``llffhold``-th image goes to "test", the
        rest to "train"; any other ``datasplit`` value returns everything.

        Side effects:
            On success, writes width/height/fx/fy/focal of the FIRST selected
            camera into ``self.config``.

        Returns:
            (cameras, image_paths) lists, or ([], []) when nothing matched.
        """
        colmap_dir = self.dataset_path / "sparse/0"
        images_dir = self.dataset_path / "images"
        intrinsics = {}  # cam_id -> (fx, fy, w, h, cx, cy)

        with open(colmap_dir / "cameras.txt") as f:
            for line in f:
                if line.startswith("#"): continue
                vals = line.strip().split()
                if len(vals) < 4: continue

                cam_id, model, w, h = int(vals[0]), vals[1], int(vals[2]), int(vals[3])

                if model == "PINHOLE":
                    # PINHOLE has 4 parameters: fx, fy, cx, cy
                    if len(vals) >= 8:  # 4 basic + 4 params
                        fx, fy, cx, cy = float(vals[4]), float(vals[5]), float(vals[6]), float(vals[7])
                    else:
                        continue
                elif model == "SIMPLE_PINHOLE":
                    # SIMPLE_PINHOLE has 3 parameters: f, cx, cy
                    if len(vals) >= 7:  # 4 basic + 3 params
                        # NOTE(review): rebinding `f` shadows the file handle;
                        # iteration continues (the loop holds the iterator) but
                        # the name is misleading.
                        f, cx, cy = float(vals[4]), float(vals[5]), float(vals[6])
                        fx = fy = f  # same focal length for both axes
                    else:
                        continue
                else:
                    print(f"Unsupported camera model: {model}")
                    continue

                intrinsics[cam_id] = (fx, fy, w, h, cx, cy)

        extrinsics = []
        with open(colmap_dir / "images.txt") as f:
            # NOTE(review): standard COLMAP images.txt alternates image lines
            # with POINTS2D lines. The `len(parts) < 10` guard only skips
            # SHORT points lines; a points line with >=4 features has >=12
            # tokens and would reach int(parts[8]) with a float string,
            # raising an uncaught ValueError. Works only if the export has
            # empty POINTS2D lines — confirm for the target datasets.
            for line in f:
                if line.startswith("#"): continue
                parts = line.strip().split()
                if len(parts) < 10: continue

                # COLMAP images.txt format: IMAGE_ID QW QX QY QZ TX TY TZ CAMERA_ID NAME
                img_id, qw, qx, qy, qz, tx, ty, tz, cam_id, img_name = parts[:10]
                cam_id = int(cam_id)

                if cam_id not in intrinsics:
                    print(f"Warning: Camera ID {cam_id} not found in intrinsics")
                    continue

                fx, fy, w, h, cx, cy = intrinsics[cam_id]

                # Normalize the (w, x, y, z) quaternion before conversion.
                q = np.array([float(qw), float(qx), float(qy), float(qz)])
                q = q / np.linalg.norm(q)

                t = np.array([float(tx), float(ty), float(tz)])
                R = quaternion_to_rotation_matrix(q)

                # R, T are COLMAP's world-to-camera convention; any conversion
                # to camera-to-world is left to load_camera_colmap.
                cam_info = {
                    "camera_id": int(img_id),
                    "width": w,
                    "height": h,
                    "fx": fx,
                    "fy": fy,
                    "cx": cx,
                    "cy": cy,
                    "R": R,
                    "T": t
                }

                camera = load_camera_colmap(cam_info)
                if camera:
                    extrinsics.append((camera, str(images_dir / img_name)))

        # LLFF-style holdout split: every llffhold-th image is held out.
        if datasplit == "train":
            selected = [c for i, c in enumerate(extrinsics) if i % llffhold != 0]
        elif datasplit == "test":
            selected = [c for i, c in enumerate(extrinsics) if i % llffhold == 0]
        else:
            selected = extrinsics

        if selected:
            cameras, image_paths = zip(*selected)
            width = cameras[0]['width']
            height = cameras[0]['height']
            fx = cameras[0]['fx']
            fy = cameras[0]['fy']

            # Field of view of the first camera.
            # NOTE(review): camera_angle_x/y are computed but never used or
            # stored — dead locals kept for reference?
            camera_angle_x = 2 * np.arctan(0.5 * width / fx)
            camera_angle_y = 2 * np.arctan(0.5 * height / fy)

            # Cache the first camera's geometry globally; assumes all
            # selected cameras share the same intrinsics.
            self.config['width'] = width
            self.config['height'] = height
            self.config['fx'] = fx
            self.config['fy'] = fy
            self.config['focal'] = fx  # fx as primary focal length

            return list(cameras), list(image_paths)

        return [], []
433
+ def load_image(self, path):
434
+ """Load an image as a numpy array."""
435
+ if os.path.exists(path):
436
+ img = imageio.imread(path)
437
+ # Convert to float and normalize to [0, 1]
438
+ img_np = img.astype(np.float32) / 255.0
439
+ # Ensure image is RGB (discard alpha channel if present)
440
+ if img_np.shape[2] == 4:
441
+ img_np = img_np[:, :, :3] # Keep only R, G, B channels
442
+ return img_np
443
+ else:
444
+ raise FileNotFoundError(f"Image not found: {path}")
445
+
446
    def zero_grad(self):
        """Reset all gradient buffers to zero before rendering/backprop.

        Launches the ``zero_gradients`` warp kernel (defined elsewhere in
        this file) over every Gaussian, clearing all five gradient arrays
        in a single pass.
        """
        wp.launch(
            zero_gradients,
            dim=self.num_points,
            inputs=[
                self.grads['positions'],
                self.grads['scales'],
                self.grads['rotations'],
                self.grads['opacities'],
                self.grads['shs'],
                self.num_points
            ]
        )
461
    def densification_and_pruning(self, iteration):
        """Adaptively clone/split/prune Gaussians and periodically reset opacity.

        Inside the configured iteration window (``densify_from_iter`` <
        iteration < ``densify_until_iter``, every ``densification_interval``
        iterations) this:
          1. clones small Gaussians whose position-gradient norm exceeds the
             threshold (under-reconstruction),
          2. splits oversized high-gradient Gaussians into two smaller ones
             and removes the originals,
          3. prunes near-transparent Gaussians (opacity-based, with safety
             bounds on how much may be removed).
        Independently, opacities are reset on ``opacity_reset_interval``.
        Every step that changes the point count reallocates the gradient and
        Adam moment buffers (optimizer state is NOT carried over).

        Args:
            iteration: Current training iteration, drives the schedule.
        """

        # Schedule parameters (defaults follow the reference 3DGS setup).
        densify_from_iter = self.config.get('densify_from_iter', 500)
        densify_until_iter = self.config.get('densify_until_iter', 15000)
        densification_interval = self.config.get('densification_interval', 100)
        opacity_reset_interval = self.config.get('opacity_reset_interval', 3000)

        # Densify only inside the window and on the configured interval.
        if iteration > densify_from_iter and iteration < densify_until_iter and iteration % densification_interval == 0:
            print(f"Iteration {iteration}: Performing sophisticated densification and pruning")

            # Simplification: the accumulated view-space gradient of the
            # reference implementation is approximated by the current 3D
            # position-gradient norm.
            pos_grads = self.grads['positions']
            avg_grads = wp.zeros(self.num_points, dtype=float, device=DEVICE)

            # Per-point L2 norm of the position gradient.
            @wp.kernel
            def compute_grad_norms(pos_grad: wp.array(dtype=wp.vec3),
                                   grad_norms: wp.array(dtype=float),
                                   num_points: int):
                i = wp.tid()
                if i >= num_points:
                    return
                grad_norms[i] = wp.length(pos_grad[i])

            wp.launch(compute_grad_norms, dim=self.num_points,
                      inputs=[pos_grads, avg_grads, self.num_points])

            # Densification thresholds.
            grad_threshold = self.config.get('densify_grad_threshold', 0.0002)
            percent_dense = self.config.get('percent_dense', 0.01)

            # --- Step 1: Clone small Gaussians with high gradients ---
            clone_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)
            wp.launch(
                mark_clone_candidates,
                dim=self.num_points,
                inputs=[
                    avg_grads,
                    self.params['scales'],
                    grad_threshold,
                    self.scene_extent,
                    percent_dense,
                    clone_mask,
                    self.num_points
                ]
            )

            # Exclusive prefix sum turns the mask into destination indices.
            # NOTE(review): with an EXCLUSIVE scan, prefix_sum[-1] omits the
            # last element's mask value — if the final Gaussian is marked,
            # this total is one short. Verify against clone_gaussians'
            # indexing (an inclusive scan or `+ mask[-1]` may be intended).
            clone_prefix_sum = wp.zeros_like(clone_mask)
            wp.utils.array_scan(clone_mask, clone_prefix_sum, inclusive=False)
            total_to_clone = int(clone_prefix_sum.numpy()[-1])

            if total_to_clone > 0:
                print(f"[Clone] Cloning {total_to_clone} small Gaussians")
                N = self.num_points
                new_N = N + total_to_clone

                # Allocate output arrays sized for originals + clones.
                out_params = {
                    'positions': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'scales': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'rotations': wp.zeros(new_N, dtype=wp.vec4, device=DEVICE),
                    'opacities': wp.zeros(new_N, dtype=float, device=DEVICE),
                    'shs': wp.zeros(new_N * 16, dtype=wp.vec3, device=DEVICE)
                }

                # Copy originals and append jittered clones.
                wp.launch(
                    clone_gaussians,
                    dim=N,
                    inputs=[
                        clone_mask,
                        clone_prefix_sum,
                        self.params['positions'],
                        self.params['scales'],
                        self.params['rotations'],
                        self.params['opacities'],
                        self.params['shs'],
                        0.01,  # noise_scale for the cloned positions
                        N,     # offset where clones start in the output
                        out_params['positions'],
                        out_params['scales'],
                        out_params['rotations'],
                        out_params['opacities'],
                        out_params['shs']
                    ]
                )

                # Swap in the grown arrays; optimizer state starts from zero.
                self.params = out_params
                self.num_points = new_N
                self.grads = self.create_gradient_arrays()
                self.adam_m = self.create_gradient_arrays()
                self.adam_v = self.create_gradient_arrays()

            # --- Step 2: Split large Gaussians with high gradients ---
            # NOTE(review): avg_grads was sized/computed BEFORE cloning; if
            # cloning grew num_points, the kernel below reads past the old
            # gradient data for the appended clones — confirm intent.
            split_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)
            wp.launch(
                mark_split_candidates,
                dim=self.num_points,
                inputs=[
                    avg_grads,
                    self.params['scales'],
                    grad_threshold,
                    self.scene_extent,
                    percent_dense,
                    split_mask,
                    self.num_points
                ]
            )

            # Same exclusive-scan caveat as for cloning above.
            split_prefix_sum = wp.zeros_like(split_mask)
            wp.utils.array_scan(split_mask, split_prefix_sum, inclusive=False)
            total_to_split = int(split_prefix_sum.numpy()[-1])

            if total_to_split > 0:
                print(f"[Split] Splitting {total_to_split} large Gaussians")
                N = self.num_points
                N_split = 2  # children per split Gaussian
                new_N = N + total_to_split * N_split

                # Allocate output arrays sized for originals + children.
                out_params = {
                    'positions': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'scales': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'rotations': wp.zeros(new_N, dtype=wp.vec4, device=DEVICE),
                    'opacities': wp.zeros(new_N, dtype=float, device=DEVICE),
                    'shs': wp.zeros(new_N * 16, dtype=wp.vec3, device=DEVICE)
                }

                # Copy originals and append the shrunken children.
                wp.launch(
                    split_gaussians,
                    dim=N,
                    inputs=[
                        split_mask,
                        split_prefix_sum,
                        self.params['positions'],
                        self.params['scales'],
                        self.params['rotations'],
                        self.params['opacities'],
                        self.params['shs'],
                        N_split,  # number of children per split Gaussian
                        0.8,      # scale_factor applied to children
                        N,        # offset where children start in the output
                        out_params['positions'],
                        out_params['scales'],
                        out_params['rotations'],
                        out_params['opacities'],
                        out_params['shs']
                    ]
                )

                # Swap in the grown arrays; optimizer state starts from zero.
                self.params = out_params
                self.num_points = new_N
                self.grads = self.create_gradient_arrays()
                self.adam_m = self.create_gradient_arrays()
                self.adam_v = self.create_gradient_arrays()

                # The parents of split Gaussians must now be removed (their
                # children replace them).
                prune_filter = wp.zeros(self.num_points, dtype=int, device=DEVICE)

                # Mark indices < offset whose split_mask was set.
                @wp.kernel
                def mark_split_originals_for_removal(
                    split_mask: wp.array(dtype=int),
                    prune_filter: wp.array(dtype=int),
                    offset: int,
                    num_points: int
                ):
                    i = wp.tid()
                    if i >= num_points:
                        return
                    if i < offset and split_mask[i] == 1:
                        prune_filter[i] = 1  # mark parent for removal
                    else:
                        prune_filter[i] = 0  # keep

                wp.launch(mark_split_originals_for_removal, dim=self.num_points,
                          inputs=[split_mask, prune_filter, N, self.num_points])

                # Invert removal mask into a keep mask for compaction.
                valid_mask = wp.zeros_like(prune_filter)

                @wp.kernel
                def invert_mask(prune: wp.array(dtype=int), valid: wp.array(dtype=int), n: int):
                    i = wp.tid()
                    if i >= n:
                        return
                    valid[i] = 1 - prune[i]

                wp.launch(invert_mask, dim=self.num_points,
                          inputs=[prune_filter, valid_mask, self.num_points])

                # Count survivors and compact (same exclusive-scan caveat).
                prefix_sum = wp.zeros_like(valid_mask)
                wp.utils.array_scan(valid_mask, prefix_sum, inclusive=False)
                valid_count = int(prefix_sum.numpy()[-1])

                if valid_count < self.num_points:
                    print(f"[Split] Removing {self.num_points - valid_count} original split Gaussians")

                    # Allocate compacted output.
                    compact_params = {
                        'positions': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                        'scales': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                        'rotations': wp.zeros(valid_count, dtype=wp.vec4, device=DEVICE),
                        'opacities': wp.zeros(valid_count, dtype=float, device=DEVICE),
                        'shs': wp.zeros(valid_count * 16, dtype=wp.vec3, device=DEVICE)
                    }

                    wp.launch(
                        compact_gaussians,
                        dim=self.num_points,
                        inputs=[
                            valid_mask,
                            prefix_sum,
                            self.params['positions'],
                            self.params['scales'],
                            self.params['rotations'],
                            self.params['opacities'],
                            self.params['shs'],
                            compact_params['positions'],
                            compact_params['scales'],
                            compact_params['rotations'],
                            compact_params['opacities'],
                            compact_params['shs']
                        ]
                    )

                    # Swap in the compacted arrays; optimizer state resets.
                    self.params = compact_params
                    self.num_points = valid_count
                    self.grads = self.create_gradient_arrays()
                    self.adam_m = self.create_gradient_arrays()
                    self.adam_v = self.create_gradient_arrays()

            # --- Step 3: Enhanced Pruning ---
            print(f"[Prune] Performing enhanced pruning")

            valid_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)

            # Keep only Gaussians above the opacity cull threshold.
            wp.launch(
                prune_gaussians,
                dim=self.num_points,
                inputs=[
                    self.params['opacities'],
                    self.config.get('cull_opacity_threshold', 0.005),
                    valid_mask,
                    self.num_points
                ]
            )

            # Count survivors (same exclusive-scan caveat as above).
            prefix_sum = wp.zeros_like(valid_mask)
            wp.utils.array_scan(valid_mask, prefix_sum, inclusive=False)
            valid_count = int(prefix_sum.numpy()[-1])

            # Safety bounds: never prune below min_valid_points, above
            # max_valid_points, or more than max_allowed_prune_ratio at once.
            min_valid_points = self.config.get('min_valid_points', 1000)
            max_valid_points = self.config.get('max_valid_points', 1000000)
            max_prune_ratio = self.config.get('max_allowed_prune_ratio', 0.5)

            prune_count = self.num_points - valid_count
            prune_ratio = prune_count / self.num_points if self.num_points > 0 else 0

            if (valid_count >= min_valid_points and
                valid_count <= max_valid_points and
                prune_ratio <= max_prune_ratio and
                valid_count < self.num_points):

                print(f"[Prune] Compacting from {self.num_points} → {valid_count} points")

                # Allocate compacted output.
                out_params = {
                    'positions': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                    'scales': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                    'rotations': wp.zeros(valid_count, dtype=wp.vec4, device=DEVICE),
                    'opacities': wp.zeros(valid_count, dtype=float, device=DEVICE),
                    'shs': wp.zeros(valid_count * 16, dtype=wp.vec3, device=DEVICE)
                }

                wp.launch(
                    compact_gaussians,
                    dim=self.num_points,
                    inputs=[
                        valid_mask,
                        prefix_sum,
                        self.params['positions'],
                        self.params['scales'],
                        self.params['rotations'],
                        self.params['opacities'],
                        self.params['shs'],
                        out_params['positions'],
                        out_params['scales'],
                        out_params['rotations'],
                        out_params['opacities'],
                        out_params['shs']
                    ]
                )

                # Swap in the compacted arrays; optimizer state resets.
                self.params = out_params
                self.num_points = valid_count
                self.grads = self.create_gradient_arrays()
                self.adam_m = self.create_gradient_arrays()
                self.adam_v = self.create_gradient_arrays()
            else:
                print(f"[Prune] Skipping pruning: valid={valid_count}, ratio={prune_ratio:.3f}")

        # Opacity reset: periodic, plus one extra reset at the start of
        # densification when the background is white (matches the reference
        # implementation's handling of white-background scenes).
        background_is_white = all(c == 1.0 for c in self.config['background_color'])
        should_reset_opacity = (
            iteration % opacity_reset_interval == 0 or
            (background_is_white and iteration == densify_from_iter)
        )

        if should_reset_opacity:
            print(f"Iteration {iteration}: Resetting opacities")
            wp.launch(
                reset_opacities,
                dim=self.num_points,
                inputs=[
                    self.params['opacities'],
                    0.01,  # max_opacity clamp after reset
                    self.num_points
                ]
            )
796
    def optimizer_step(self, iteration):
        """Apply one Adam update to all Gaussian parameter groups.

        Learning rates come from the per-group schedulers when enabled
        (also recorded in ``self.learning_rate_history`` and logged every
        1000 iterations), otherwise from the static config values. The
        actual update runs in the ``adam_update`` warp kernel.

        Args:
            iteration: Current training iteration; fed to the schedulers
                and the kernel (presumably for Adam bias correction —
                confirm in the kernel).
        """

        # Resolve learning rates: scheduled or static.
        if self.lr_scheduler:
            lr_pos = self.lr_scheduler['positions'].get_lr(iteration, self.config['num_iterations'])
            lr_scale = self.lr_scheduler['scales'].get_lr(iteration, self.config['num_iterations'])
            lr_rot = self.lr_scheduler['rotations'].get_lr(iteration, self.config['num_iterations'])
            lr_sh = self.lr_scheduler['shs'].get_lr(iteration, self.config['num_iterations'])
            lr_opac = self.lr_scheduler['opacities'].get_lr(iteration, self.config['num_iterations'])

            # Record the schedule for later plotting/inspection.
            self.learning_rate_history['positions'].append(lr_pos)
            self.learning_rate_history['scales'].append(lr_scale)
            self.learning_rate_history['rotations'].append(lr_rot)
            self.learning_rate_history['shs'].append(lr_sh)
            self.learning_rate_history['opacities'].append(lr_opac)

            # Periodic learning-rate log.
            if iteration % 1000 == 0:
                print(f"Iteration {iteration} learning rates:")
                print(f"  positions: {lr_pos:.6f}")
                print(f"  scales: {lr_scale:.6f}")
                print(f"  rotations: {lr_rot:.6f}")
                print(f"  shs: {lr_sh:.6f}")
                print(f"  opacities: {lr_opac:.6f}")
        else:
            # Static learning rates from config.
            lr_pos = self.config['lr_pos']
            lr_scale = self.config['lr_scale']
            lr_rot = self.config['lr_rot']
            lr_sh = self.config['lr_sh']
            lr_opac = self.config['lr_opac']

        # Single fused kernel: Adam moments + parameter update for all
        # five parameter groups. Argument order must match the kernel
        # signature exactly.
        wp.launch(
            adam_update,
            dim=self.num_points,
            inputs=[
                # Parameters
                self.params['positions'],
                self.params['scales'],
                self.params['rotations'],
                self.params['opacities'],
                self.params['shs'],

                # Gradients
                self.grads['positions'],
                self.grads['scales'],
                self.grads['rotations'],
                self.grads['opacities'],
                self.grads['shs'],

                # First moments (m)
                self.adam_m['positions'],
                self.adam_m['scales'],
                self.adam_m['rotations'],
                self.adam_m['opacities'],
                self.adam_m['shs'],

                # Second moments (v)
                self.adam_v['positions'],
                self.adam_v['scales'],
                self.adam_v['rotations'],
                self.adam_v['opacities'],
                self.adam_v['shs'],

                # Optimizer hyperparameters with the resolved learning rates
                self.num_points,
                lr_pos,    # positions
                lr_scale,  # scales
                lr_rot,    # rotations
                lr_sh,     # SH coefficients
                lr_opac,   # opacities
                self.config['adam_beta1'],
                self.config['adam_beta2'],
                self.config['adam_epsilon'],
                iteration
            ]
        )
876
    def save_checkpoint(self, iteration):
        """Persist the point cloud, loss history/plot, and one rendered view.

        Writes into ``<output>/point_cloud/iteration_<N>/``:
          - point_cloud.ply (current Gaussian parameters)
          - rendered_view.png (render from camera 0)
        and refreshes ``<output>/loss.txt`` and ``<output>/loss_plot.png``.

        Args:
            iteration: Training iteration used to name the checkpoint dir.
        """
        checkpoint_dir = self.output_path / "point_cloud" / f"iteration_{iteration}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Save point cloud as PLY.
        ply_path = checkpoint_dir / "point_cloud.ply"
        save_ply(self.params, ply_path, self.num_points)

        # Save loss history (one value per line, overwritten each time).
        loss_path = self.output_path / "loss.txt"
        with open(loss_path, 'w') as f:
            for loss in self.losses:
                f.write(f"{loss}\n")

        # Save loss plot.
        plt.figure(figsize=(10, 5))
        plt.plot(self.losses)
        plt.title('Training Loss')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.savefig(self.output_path / "loss_plot.png")
        plt.close()

        # Render a reference view from the first training camera.
        # NOTE(review): "front view" is an assumption — camera 0 is simply
        # the first camera the COLMAP loader kept.
        camera_idx = 0
        # NOTE(review): antialiasing=True here but False in the training
        # loop's render call — confirm the checkpoint render is meant to
        # differ from what was optimized.
        rendered_image, _, _ = render_gaussians(
            background=np.array(self.config['background_color'], dtype=np.float32),
            means3D=self.params['positions'].numpy(),
            colors=None,  # colors come from SH coefficients below
            opacity=self.params['opacities'].numpy(),
            scales=self.params['scales'].numpy(),
            rotations=self.params['rotations'].numpy(),
            scale_modifier=self.config['scale_modifier'],
            viewmatrix=self.cameras[camera_idx]['world_to_camera'],
            projmatrix=self.cameras[camera_idx]['full_proj_matrix'],
            tan_fovx=self.cameras[camera_idx]['tan_fovx'],
            tan_fovy=self.cameras[camera_idx]['tan_fovy'],
            image_height=self.cameras[camera_idx]['height'],
            image_width=self.cameras[camera_idx]['width'],
            sh=self.params['shs'].numpy(),  # SH coefficients
            degree=self.config['sh_degree'],
            campos=self.cameras[camera_idx]['camera_center'],
            prefiltered=False,
            antialiasing=True,
            clamped=True
        )
        # Save rendered view as an 8-bit PNG.
        rendered_array = wp.to_torch(rendered_image).cpu().numpy()
        # Handle channel-first (3, H, W) output — transpose to (H, W, 3).
        if rendered_array.shape[0] == 3 and len(rendered_array.shape) == 3:
            rendered_array = np.transpose(rendered_array, (1, 2, 0))
        img8 = (np.clip(rendered_array, 0, 1) * 255).astype(np.uint8)
        imageio.imwrite(checkpoint_dir / "rendered_view.png", img8)
932
+ def debug_log_and_save_images(
933
+ self,
934
+ rendered_image, # np.float32 H×W×3 (range 0-1)
935
+ target_image, # np.float32
936
+ depth_image, # wp.array2d(float) – optional but unused here
937
+ camera_idx: int,
938
+ it: int
939
+ ):
940
+ # ------ quick numeric read-out -----------------------------------
941
+ radii = wp.to_torch(self.intermediate_buffers["radii"]).cpu().numpy()
942
+ alphas = wp.to_torch(self.intermediate_buffers["conic_opacity"]).cpu().numpy()[:, 3]
943
+ offs = wp.to_torch(self.intermediate_buffers["point_offsets"]).cpu().numpy()
944
+ num_dup = int(offs[-1]) if len(offs) else 0
945
+ r_med = np.median(radii[radii > 0]) if (radii > 0).any() else 0
946
+
947
+ # Count visible Gaussians
948
+ xy_image = wp.to_torch(self.intermediate_buffers["points_xy_image"]).cpu().numpy()
949
+ W = self.cameras[camera_idx]['width']
950
+ H = self.cameras[camera_idx]['height']
951
+ visible_gaussians = np.sum(
952
+ (xy_image[:, 0] >= 0) & (xy_image[:, 0] < W) &
953
+ (xy_image[:, 1] >= 0) & (xy_image[:, 1] < H) &
954
+ np.isfinite(xy_image).all(axis=1) &
955
+ (radii > 0) # Only count Gaussians with positive radius
956
+ )
957
+
958
+ print(
959
+ f"[it {it:05d}] cam={camera_idx:02d} dup={num_dup:<6} "
960
+ f"r_med={r_med:5.1f} α∈[{alphas.min():.3f},"
961
+ f"{np.median(alphas):.3f},{alphas.max():.3f}] "
962
+ f"visible={visible_gaussians}/{len(xy_image)}"
963
+ )
964
+
965
+ # ------ save render / target PNG ---------------------------------
966
+ def save_rgb(arr_f32, stem):
967
+ # Handle case where arr_f32 has shape (3, H, W) - transpose to (H, W, 3)
968
+ if arr_f32.shape[0] == 3 and len(arr_f32.shape) == 3:
969
+ arr_f32 = np.transpose(arr_f32, (1, 2, 0))
970
+ img8 = (np.clip(arr_f32, 0, 1) * 255).astype(np.uint8)
971
+ # Include camera index in the filename
972
+ imageio.imwrite(self.output_path / f"{stem}" / f"{stem}_{it:06d}_cam{camera_idx:02d}.png", img8)
973
+
974
+ save_rgb(rendered_image if isinstance(rendered_image, np.ndarray) else wp.to_torch(rendered_image).cpu().numpy(), "render")
975
+ save_rgb(target_image, "target")
976
+
977
+ # ------ make 2-D projection scatter ------------------------------
978
+ xy = wp.to_torch(self.intermediate_buffers["points_xy_image"]).cpu().numpy()
979
+ depth = wp.to_torch(self.intermediate_buffers["depths"]).cpu().numpy()
980
+ H, W = self.config["height"], self.config["width"]
981
+
982
+ mask = (
983
+ (xy[:, 0] >= 0) & (xy[:, 0] < W) &
984
+ (xy[:, 1] >= 0) & (xy[:, 1] < H) &
985
+ np.isfinite(xy).all(axis=1) &
986
+ (radii > 0) # Only include Gaussians with positive radius
987
+ )
988
+ if mask.any():
989
+ plt.figure(figsize=(6, 6))
990
+ plt.scatter(xy[mask, 0], xy[mask, 1],
991
+ s=4, c=depth[mask], cmap="turbo", alpha=.7)
992
+ plt.gca().invert_yaxis()
993
+ plt.xlim(0, W); plt.ylim(H, 0)
994
+ plt.title(f"Projected Gaussians (cam {camera_idx}, iter {it}): {np.sum(mask)}/{len(xy)} visible")
995
+ plt.colorbar(label="depth(z)")
996
+ plt.tight_layout()
997
+ # Include camera index in the filename
998
+ plt.savefig(self.output_path / 'proj' / f"proj_{it:06d}_cam{camera_idx:02d}.png", dpi=250)
999
+ plt.close()
1000
+
1001
+ # depth histogram
1002
+ plt.figure(figsize=(5, 3))
1003
+ plt.hist(depth[mask], bins=40, color="steelblue")
1004
+ plt.xlabel("depth (camera-z)")
1005
+ plt.ylabel("count")
1006
+ plt.title(f"Depth hist – cam {camera_idx}, {mask.sum()} pts")
1007
+ plt.tight_layout()
1008
+ # Include camera index in the filename
1009
+ plt.savefig(self.output_path / 'depth_hist' / f"depth_hist_{it:06d}.png", dpi=250)
1010
+ plt.close()
1011
+
1012
+ def train(self):
1013
+ """Train the 3D Gaussian Splatting model."""
1014
+ num_iterations = self.config['num_iterations']
1015
+
1016
+ # Main training loop
1017
+ with tqdm(total=num_iterations) as pbar:
1018
+ for iteration in range(num_iterations):
1019
+ # Select a random camera and corresponding image
1020
+ camera_idx = np.random.randint(0, len(self.cameras))
1021
+ image_path = self.image_paths[camera_idx]
1022
+ target_image = self.load_image(image_path)
1023
+
1024
+ # Zero gradients
1025
+ self.zero_grad()
1026
+ # Render the view
1027
+ rendered_image, depth_image, self.intermediate_buffers = render_gaussians(
1028
+ background=np.array(self.config['background_color'], dtype=np.float32),
1029
+ means3D=self.params['positions'].numpy(),
1030
+ colors=None, # Use SH coefficients instead
1031
+ opacity=self.params['opacities'].numpy(),
1032
+ scales=self.params['scales'].numpy(),
1033
+ rotations=self.params['rotations'].numpy(),
1034
+ scale_modifier=self.config['scale_modifier'],
1035
+ viewmatrix=self.cameras[camera_idx]['world_to_camera'],
1036
+ projmatrix=self.cameras[camera_idx]['full_proj_matrix'],
1037
+ tan_fovx=self.cameras[camera_idx]['tan_fovx'],
1038
+ tan_fovy=self.cameras[camera_idx]['tan_fovy'],
1039
+ image_height=self.cameras[camera_idx]['height'],
1040
+ image_width=self.cameras[camera_idx]['width'],
1041
+ sh=self.params['shs'].numpy(), # Pass SH coefficients
1042
+ degree=self.config['sh_degree'],
1043
+ campos=self.cameras[camera_idx]['camera_center'],
1044
+ prefiltered=False,
1045
+ antialiasing=False,
1046
+ clamped=True
1047
+ )
1048
+
1049
+ radii = wp.to_torch(self.intermediate_buffers["radii"]).cpu().numpy()
1050
+ np_rendered_image = wp.to_torch(rendered_image).cpu().numpy()
1051
+ np_rendered_image = np_rendered_image.transpose(2, 0, 1)
1052
+
1053
+ #if iteration % self.config['save_interval'] == 0:
1054
+ if (
1055
+ iteration < 10 or
1056
+ #(iteration < 50 and iteration % 5 == 0) or
1057
+ #(iteration < 100 and iteration % 10 == 0) or
1058
+ #(iteration < 1000 and iteration % 100 == 0) or
1059
+ (iteration % 1000 == 0) or
1060
+ (iteration == num_iterations - 1)
1061
+ ):
1062
+ self.debug_log_and_save_images(np_rendered_image, target_image, depth_image, camera_idx, iteration)
1063
+
1064
+ # Calculate L1 loss
1065
+ l1_val = l1_loss(rendered_image, target_image)
1066
+
1067
+ # # Calculate SSIM, not used
1068
+ # ssim_val = ssim(rendered_image, target_image)
1069
+ # # Combined loss with weighted SSIM
1070
+ # lambda_dssim = self.config['lambda_dssim']
1071
+ # # loss = (1 - λ) * L1 + λ * (1 - SSIM)
1072
+ # loss = (1.0 - lambda_dssim) * l1_val + lambda_dssim * (1.0 - ssim_val)
1073
+
1074
+ loss = l1_val
1075
+ self.losses.append(loss)
1076
+ # Compute pixel gradients for image loss (dL/dColor)
1077
+ pixel_grad_buffer = compute_image_gradients(
1078
+ rendered_image, target_image, lambda_dssim=0
1079
+ )
1080
+
1081
+ # Prepare camera parameters
1082
+ camera = self.cameras[camera_idx]
1083
+ view_matrix = wp.mat44(camera['world_to_camera'].flatten())
1084
+ proj_matrix = wp.mat44(camera['full_proj_matrix'].flatten())
1085
+ campos = wp.vec3(camera['camera_center'][0], camera['camera_center'][1], camera['camera_center'][2])
1086
+
1087
+ # Create appropriate buffer dictionaries for the backward pass
1088
+ geom_buffer = {
1089
+ 'radii': self.intermediate_buffers['radii'],
1090
+ 'means2D': self.intermediate_buffers['points_xy_image'],
1091
+ 'conic_opacity': self.intermediate_buffers['conic_opacity'],
1092
+ 'rgb': self.intermediate_buffers['colors'],
1093
+ 'clamped': self.intermediate_buffers['clamped_state']
1094
+ }
1095
+
1096
+ binning_buffer = {
1097
+ 'point_list': self.intermediate_buffers['point_list']
1098
+ }
1099
+
1100
+ img_buffer = {
1101
+ 'ranges': self.intermediate_buffers['ranges'],
1102
+ 'final_Ts': self.intermediate_buffers['final_Ts'],
1103
+ 'n_contrib': self.intermediate_buffers['n_contrib']
1104
+ }
1105
+
1106
+ gradients = backward(
1107
+ # Core parameters
1108
+ background=np.array(self.config['background_color'], dtype=np.float32),
1109
+ means3D=self.params['positions'],
1110
+ dL_dpixels=pixel_grad_buffer,
1111
+
1112
+ # Model parameters (pass directly from self.params)
1113
+ opacity=self.params['opacities'],
1114
+ shs=self.params['shs'],
1115
+ scales=self.params['scales'],
1116
+ rotations=self.params['rotations'],
1117
+ scale_modifier=self.config['scale_modifier'],
1118
+
1119
+ # Camera parameters
1120
+ viewmatrix=view_matrix,
1121
+ projmatrix=proj_matrix,
1122
+ tan_fovx=camera['tan_fovx'],
1123
+ tan_fovy=camera['tan_fovy'],
1124
+ image_height=camera['height'],
1125
+ image_width=camera['width'],
1126
+ campos=campos,
1127
+
1128
+ # Forward output buffers
1129
+ radii=self.intermediate_buffers['radii'],
1130
+ means2D=self.intermediate_buffers['points_xy_image'],
1131
+ conic_opacity=self.intermediate_buffers['conic_opacity'],
1132
+ rgb=self.intermediate_buffers['colors'],
1133
+ cov3Ds=self.intermediate_buffers['cov3Ds'],
1134
+ clamped=self.intermediate_buffers['clamped_state'],
1135
+
1136
+ # Internal state buffers
1137
+ geom_buffer=geom_buffer,
1138
+ binning_buffer=binning_buffer,
1139
+ img_buffer=img_buffer,
1140
+
1141
+ # Algorithm parameters
1142
+ degree=self.config['sh_degree'],
1143
+ debug=False
1144
+ )
1145
+
1146
+ # 3. Copy gradients from backward result to the optimizer's gradient buffers
1147
+ wp.copy(self.grads['positions'], gradients['dL_dmean3D'])
1148
+ wp.copy(self.grads['scales'], gradients['dL_dscale'])
1149
+ wp.copy(self.grads['rotations'], gradients['dL_drot'])
1150
+ wp.copy(self.grads['opacities'], gradients['dL_dopacity'])
1151
+ wp.copy(self.grads['shs'], gradients['dL_dshs'])
1152
+
1153
+ # Update parameters
1154
+ self.optimizer_step(iteration)
1155
+
1156
+ # Update progress bar
1157
+ pbar.update(1)
1158
+ pbar.set_description(f"Loss: {loss:.6f}")
1159
+
1160
+ self.densification_and_pruning(iteration)
1161
+
1162
+ # Save checkpoint
1163
+ #if iteration % self.config['save_interval'] == 0 or iteration == num_iterations - 1:
1164
+
1165
+ if (
1166
+ iteration < 10 or
1167
+ # (iteration < 50 and iteration % 5 == 0) or
1168
+ # (iteration < 100 and iteration % 10 == 0) or
1169
+ # (iteration < 1000 and iteration % 100 == 0) or
1170
+ (iteration % 1000 == 0) or
1171
+ (iteration == num_iterations - 1)
1172
+ ):
1173
+ self.save_checkpoint(iteration)
1174
+ print("Training complete!")
1175
+
1176
+
1177
+ def visualize_camera_points_alignment_interactive(self):
1178
+ """Create an interactive 3D visualization with camera frustums and colored points"""
1179
+ try:
1180
+ import plotly.graph_objects as go
1181
+ from plotly.subplots import make_subplots
1182
+ import plotly.express as px
1183
+ except ImportError:
1184
+ print("plotly not found. Install with: pip install plotly")
1185
+ return
1186
+
1187
+ # Get data
1188
+ camera_positions = np.array([cam['camera_center'] for cam in self.cameras])
1189
+ points_np = wp.to_torch(self.params['positions']).cpu().numpy()
1190
+
1191
+ # Get SH coefficients for colors
1192
+ shs_np = wp.to_torch(self.params['shs']).cpu().numpy()
1193
+
1194
+ # Extract base colors from SH coefficients
1195
+ C0 = 0.28209479177387814 # Normalization constant for Y_00
1196
+ point_colors = np.zeros((len(points_np), 3), dtype=np.float32)
1197
+
1198
+ # Get only the DC component (first SH coefficient) for each point
1199
+ for i in range(len(points_np)):
1200
+ sh_dc = shs_np[i * 16] # First SH coefficient for each point
1201
+ rgb = sh_dc * C0 + 0.5
1202
+ point_colors[i] = np.clip(rgb, 0, 1)
1203
+
1204
+ # Sample points for better performance
1205
+ max_points = 5000
1206
+ if len(points_np) > max_points:
1207
+ indices = np.random.choice(len(points_np), max_points, replace=False)
1208
+ points_sample = points_np[indices]
1209
+ colors_sample = point_colors[indices]
1210
+ else:
1211
+ points_sample = points_np
1212
+ colors_sample = point_colors
1213
+
1214
+ # Convert colors to hex format for plotly
1215
+ colors_hex = [f'rgb({int(r*255)},{int(g*255)},{int(b*255)})' for r, g, b in colors_sample]
1216
+
1217
+ # Calculate extents
1218
+ cam_extent = np.max(np.abs(camera_positions))
1219
+ points_extent = np.max(np.abs(points_sample))
1220
+
1221
+ print(f"Camera extent: {cam_extent:.3f}")
1222
+ print(f"Points extent: {points_extent:.3f}")
1223
+ print(f"Scale ratio: {cam_extent/points_extent:.3f}")
1224
+
1225
+ # Create figure
1226
+ fig = make_subplots(
1227
+ rows=1, cols=1,
1228
+ specs=[[{"type": "scene"}]],
1229
+ subplot_titles=['3D Scene with Camera Frustums']
1230
+ )
1231
+
1232
+ # Colors for cameras
1233
+ n_cameras = len(camera_positions)
1234
+ camera_colors = px.colors.qualitative.Bold[:min(n_cameras, 10)]
1235
+ if n_cameras > 10:
1236
+ camera_colors = camera_colors * (n_cameras // 10 + 1)
1237
+
1238
+ # Add point cloud with actual colors
1239
+ fig.add_trace(
1240
+ go.Scatter3d(
1241
+ x=points_sample[:, 0],
1242
+ y=points_sample[:, 1],
1243
+ z=points_sample[:, 2],
1244
+ mode='markers',
1245
+ marker=dict(
1246
+ size=1.5,
1247
+ color=colors_hex,
1248
+ opacity=0.7
1249
+ ),
1250
+ name='Point Cloud',
1251
+ hovertemplate='Point<br>X: %{x:.3f}<br>Y: %{y:.3f}<br>Z: %{z:.3f}<extra></extra>'
1252
+ )
1253
+ )
1254
+
1255
+ # Create camera frustums
1256
+ frustum_scale = cam_extent * 0.2 # Increased frustum size
1257
+
1258
+ for i, (cam, pos) in enumerate(zip(self.cameras, camera_positions)):
1259
+ color = camera_colors[i % len(camera_colors)]
1260
+
1261
+ # Extract camera parameters correctly
1262
+ c2w = cam['camera_to_world']
1263
+
1264
+ # Camera coordinate system in world space
1265
+ right = c2w[:3, 0] # x-axis
1266
+ up = c2w[:3, 1] # y-axis
1267
+
1268
+ # IMPORTANT FIX: FLIP THE DIRECTION for correct frustum orientation
1269
+ # The camera looks along the +Z axis in camera space, which is the 3rd column of c2w
1270
+ forward = c2w[:3, 2] # Use +Z for dust3r/COLMAP convention
1271
+
1272
+ # Debug first few cameras
1273
+ if i < 3:
1274
+ print(f"Camera {i} coordinate system:")
1275
+ print(f" Position: {pos}")
1276
+ print(f" Right: {right}")
1277
+ print(f" Up: {up}")
1278
+ print(f" Forward: {forward}")
1279
+
1280
+ # Calculate frustum parameters
1281
+ aspect_ratio = cam['width'] / cam['height']
1282
+ tan_fov_x = cam['tan_fovx']
1283
+ tan_fov_y = cam['tan_fovy']
1284
+
1285
+ # Frustum corners at near and far planes
1286
+ near_dist = frustum_scale * 0.1
1287
+ far_dist = frustum_scale
1288
+
1289
+ # Near plane corners
1290
+ tl_near = pos + forward * near_dist - right * near_dist * tan_fov_x + up * near_dist * tan_fov_y
1291
+ tr_near = pos + forward * near_dist + right * near_dist * tan_fov_x + up * near_dist * tan_fov_y
1292
+ bl_near = pos + forward * near_dist - right * near_dist * tan_fov_x - up * near_dist * tan_fov_y
1293
+ br_near = pos + forward * near_dist + right * near_dist * tan_fov_x - up * near_dist * tan_fov_y
1294
+
1295
+ # Far plane corners
1296
+ tl_far = pos + forward * far_dist - right * far_dist * tan_fov_x + up * far_dist * tan_fov_y
1297
+ tr_far = pos + forward * far_dist + right * far_dist * tan_fov_x + up * far_dist * tan_fov_y
1298
+ bl_far = pos + forward * far_dist - right * far_dist * tan_fov_x - up * far_dist * tan_fov_y
1299
+ br_far = pos + forward * far_dist + right * far_dist * tan_fov_x - up * far_dist * tan_fov_y
1300
+
1301
+ # Camera position marker
1302
+ fig.add_trace(
1303
+ go.Scatter3d(
1304
+ x=[pos[0]],
1305
+ y=[pos[1]],
1306
+ z=[pos[2]],
1307
+ mode='markers',
1308
+ marker=dict(
1309
+ size=8, # Larger marker
1310
+ color=color,
1311
+ symbol='diamond',
1312
+ ),
1313
+ name=f'Camera {i}',
1314
+ hovertemplate=f'Camera {i}<br>X: %{{x:.3f}}<br>Y: %{{y:.3f}}<br>Z: %{{z:.3f}}<extra></extra>'
1315
+ )
1316
+ )
1317
+
1318
+ # Draw frustum edges
1319
+ lines_x = []
1320
+ lines_y = []
1321
+ lines_z = []
1322
+
1323
+ # Helper to add a line
1324
+ def add_line(p1, p2):
1325
+ lines_x.extend([p1[0], p2[0], None])
1326
+ lines_y.extend([p1[1], p2[1], None])
1327
+ lines_z.extend([p1[2], p2[2], None])
1328
+
1329
+ # Near plane
1330
+ add_line(tl_near, tr_near)
1331
+ add_line(tr_near, br_near)
1332
+ add_line(br_near, bl_near)
1333
+ add_line(bl_near, tl_near)
1334
+
1335
+ # Far plane
1336
+ add_line(tl_far, tr_far)
1337
+ add_line(tr_far, br_far)
1338
+ add_line(br_far, bl_far)
1339
+ add_line(bl_far, tl_far)
1340
+
1341
+ # Connecting edges
1342
+ add_line(tl_near, tl_far)
1343
+ add_line(tr_near, tr_far)
1344
+ add_line(bl_near, bl_far)
1345
+ add_line(br_near, br_far)
1346
+
1347
+ # Camera to near plane corners
1348
+ add_line(pos, tl_near)
1349
+ add_line(pos, tr_near)
1350
+ add_line(pos, bl_near)
1351
+ add_line(pos, br_near)
1352
+
1353
+ # Add frustum lines
1354
+ fig.add_trace(
1355
+ go.Scatter3d(
1356
+ x=lines_x,
1357
+ y=lines_y,
1358
+ z=lines_z,
1359
+ mode='lines',
1360
+ line=dict(
1361
+ color=color,
1362
+ width=2
1363
+ ),
1364
+ name=f'Frustum {i}',
1365
+ showlegend=False,
1366
+ hoverinfo='none'
1367
+ )
1368
+ )
1369
+
1370
+ # Add coordinate system axes with LARGER SIZE
1371
+ axis_length = frustum_scale * 0.15 # Increased from 0.05 to 0.15
1372
+
1373
+ # Right direction (X axis) - Red
1374
+ fig.add_trace(
1375
+ go.Scatter3d(
1376
+ x=[pos[0], pos[0] + right[0] * axis_length],
1377
+ y=[pos[1], pos[1] + right[1] * axis_length],
1378
+ z=[pos[2], pos[2] + right[2] * axis_length],
1379
+ mode='lines',
1380
+ line=dict(color='red', width=4), # Thicker line
1381
+ name='X (Right)' if i == 0 else '',
1382
+ showlegend=i==0,
1383
+ hoverinfo='none'
1384
+ )
1385
+ )
1386
+
1387
+ # Up direction (Y axis) - Green
1388
+ fig.add_trace(
1389
+ go.Scatter3d(
1390
+ x=[pos[0], pos[0] + up[0] * axis_length],
1391
+ y=[pos[1], pos[1] + up[1] * axis_length],
1392
+ z=[pos[2], pos[2] + up[2] * axis_length],
1393
+ mode='lines',
1394
+ line=dict(color='green', width=4), # Thicker line
1395
+ name='Y (Up)' if i == 0 else '',
1396
+ showlegend=i==0,
1397
+ hoverinfo='none'
1398
+ )
1399
+ )
1400
+
1401
+ # Forward direction (Z axis) - Blue
1402
+ fig.add_trace(
1403
+ go.Scatter3d(
1404
+ x=[pos[0], pos[0] + forward[0] * axis_length],
1405
+ y=[pos[1], pos[1] + forward[1] * axis_length],
1406
+ z=[pos[2], pos[2] + forward[2] * axis_length],
1407
+ mode='lines',
1408
+ line=dict(color='blue', width=4), # Thicker line
1409
+ name='Z (Forward)' if i == 0 else '',
1410
+ showlegend=i==0,
1411
+ hoverinfo='none'
1412
+ )
1413
+ )
1414
+
1415
+ # Add cones at the end of each axis for better direction visibility
1416
+ cone_size = axis_length * 0.2 # Size of cone relative to axis length
1417
+
1418
+ for axis_dir, color, axis_name in [
1419
+ (right, 'red', 'X'),
1420
+ (up, 'green', 'Y'),
1421
+ (forward, 'blue', 'Z')
1422
+ ]:
1423
+ # End point of axis
1424
+ end_point = pos + axis_dir * axis_length
1425
+
1426
+ # Add a sphere marker at axis end
1427
+ fig.add_trace(
1428
+ go.Scatter3d(
1429
+ x=[end_point[0]],
1430
+ y=[end_point[1]],
1431
+ z=[end_point[2]],
1432
+ mode='markers',
1433
+ marker=dict(
1434
+ size=6, # Size of endpoint marker
1435
+ color=color,
1436
+ symbol='circle'
1437
+ ),
1438
+ name=f'{axis_name} Axis End' if i == 0 else '',
1439
+ showlegend=False,
1440
+ hoverinfo='none'
1441
+ )
1442
+ )
1443
+
1444
+ # Update layout for better visualization
1445
+ fig.update_layout(
1446
+ title=dict(
1447
+ text=f'Interactive Camera Frustums and Point Cloud<br>'
1448
+ f'<sub>Cameras: {n_cameras}, Points: {len(points_sample)}/{len(points_np)}, '
1449
+ f'Scale ratio: {cam_extent/points_extent:.2f}</sub>',
1450
+ x=0.5
1451
+ ),
1452
+ scene=dict(
1453
+ xaxis_title='X',
1454
+ yaxis_title='Y',
1455
+ zaxis_title='Z',
1456
+ aspectmode='data', # 'cube' or 'data'
1457
+ camera=dict(
1458
+ eye=dict(x=1.8, y=1.8, z=1.8) # Adjusted default view
1459
+ ),
1460
+ annotations=[
1461
+ dict(
1462
+ showarrow=False,
1463
+ x=0.05,
1464
+ y=0.05,
1465
+ z=0.05,
1466
+ text="Camera axes:<br>Red: X (right)<br>Green: Y (up)<br>Blue: Z (forward)",
1467
+ xanchor="left",
1468
+ xshift=10,
1469
+ opacity=0.8,
1470
+ font=dict(size=14)
1471
+ )
1472
+ ]
1473
+ ),
1474
+ height=900,
1475
+ width=1000,
1476
+ margin=dict(l=0, r=0, t=50, b=0)
1477
+ )
1478
+
1479
+ # Add axis legend
1480
+ for color, name in [('red', 'X (Right)'), ('green', 'Y (Up)'), ('blue', 'Z (Forward)')]:
1481
+ fig.add_trace(
1482
+ go.Scatter3d(
1483
+ x=[None], y=[None], z=[None],
1484
+ mode='lines',
1485
+ line=dict(color=color, width=6),
1486
+ name=name,
1487
+ showlegend=True
1488
+ )
1489
+ )
1490
+
1491
+ # Save interactive HTML
1492
+ html_path = self.output_path / 'camera_frustums_visualization.html'
1493
+ fig.write_html(str(html_path))
1494
+ print(f"Interactive visualization saved to: {html_path}")
1495
+
1496
+ # Show in browser if possible
1497
+ try:
1498
+ fig.show()
1499
+ except Exception as e:
1500
+ print(f"Could not display in browser: {e}")
1501
+ print(f"Open {html_path} in your browser to view the interactive plot")
1502
+
1503
+ return fig
1504
+ def debug_camera_and_points_alignment(self):
1505
+ """Enhanced debug function with scale analysis"""
1506
+ # Get camera positions
1507
+ camera_positions = np.array([cam['camera_center'] for cam in self.cameras])
1508
+
1509
+ # Get point cloud positions
1510
+ points_np = wp.to_torch(self.params['positions']).cpu().numpy()
1511
+
1512
+ # Calculate detailed statistics
1513
+ cam_stats = {
1514
+ 'min': np.min(camera_positions, axis=0),
1515
+ 'max': np.max(camera_positions, axis=0),
1516
+ 'mean': np.mean(camera_positions, axis=0),
1517
+ 'extent': np.max(np.abs(camera_positions))
1518
+ }
1519
+
1520
+ points_stats = {
1521
+ 'min': np.min(points_np, axis=0),
1522
+ 'max': np.max(points_np, axis=0),
1523
+ 'mean': np.mean(points_np, axis=0),
1524
+ 'extent': np.max(np.abs(points_np))
1525
+ }
1526
+
1527
+ print("=== ALIGNMENT DEBUG ===")
1528
+ print(f"Cameras ({len(camera_positions)}):")
1529
+ print(f" Min: [{cam_stats['min'][0]:8.3f}, {cam_stats['min'][1]:8.3f}, {cam_stats['min'][2]:8.3f}]")
1530
+ print(f" Max: [{cam_stats['max'][0]:8.3f}, {cam_stats['max'][1]:8.3f}, {cam_stats['max'][2]:8.3f}]")
1531
+ print(f" Mean:[{cam_stats['mean'][0]:8.3f}, {cam_stats['mean'][1]:8.3f}, {cam_stats['mean'][2]:8.3f}]")
1532
+ print(f" Extent: {cam_stats['extent']:.3f}")
1533
+
1534
+ print(f"\nPoints ({len(points_np)}):")
1535
+ print(f" Min: [{points_stats['min'][0]:8.3f}, {points_stats['min'][1]:8.3f}, {points_stats['min'][2]:8.3f}]")
1536
+ print(f" Max: [{points_stats['max'][0]:8.3f}, {points_stats['max'][1]:8.3f}, {points_stats['max'][2]:8.3f}]")
1537
+ print(f" Mean:[{points_stats['mean'][0]:8.3f}, {points_stats['mean'][1]:8.3f}, {points_stats['mean'][2]:8.3f}]")
1538
+ print(f" Extent: {points_stats['extent']:.3f}")
1539
+
1540
+ # Scale analysis
1541
+ scale_ratio = cam_stats['extent'] / points_stats['extent'] if points_stats['extent'] > 0 else float('inf')
1542
+ print(f"\nScale Analysis:")
1543
+ print(f" Scale ratio (cam/points): {scale_ratio:.3f}")
1544
+
1545
+ if scale_ratio > 10:
1546
+ print(" ⚠️ WARNING: Cameras much larger than points - may need to scale points up")
1547
+ print(f" Suggested point scale factor: {scale_ratio/10:.3f}")
1548
+ elif scale_ratio < 0.1:
1549
+ print(" ⚠️ WARNING: Points much larger than cameras - may need to scale points down")
1550
+ print(f" Suggested point scale factor: {scale_ratio*10:.3f}")
1551
+ else:
1552
+ print(" ✅ Scale ratio looks reasonable")
1553
+
1554
+ # Distance analysis
1555
+ center_distance = np.linalg.norm(cam_stats['mean'] - points_stats['mean'])
1556
+ print(f"\nCenter separation: {center_distance:.3f}")
1557
+ if center_distance > max(cam_stats['extent'], points_stats['extent']):
1558
+ print(" ⚠️ WARNING: Camera and point centers are far apart - possible coordinate system issue")
1559
+
1560
+ return cam_stats, points_stats
1561
def main():
    """CLI entry point: parse arguments, run alignment diagnostics and the
    interactive visualization, then train the 3DGS model."""
    arg_parser = argparse.ArgumentParser(description="Train 3D Gaussian Splatting model with Colmap")
    arg_parser.add_argument("--dataset", type=str, default="./data_/scenes/steak_is",
                            help="Path to NeRF dataset directory (default: Lego dataset)")
    arg_parser.add_argument("--output", type=str, default="./output/steak_is", help="Output directory")
    cli_args = arg_parser.parse_args()

    # Create trainer and start training
    trainer = NeRFGaussianSplattingTrainer(
        dataset_path=cli_args.dataset,
        output_path=cli_args.output,
    )

    # Sanity-check camera/point alignment before spending time training
    trainer.debug_camera_and_points_alignment()

    # Create interactive visualization
    trainer.visualize_camera_points_alignment_interactive()

    # Start training
    trainer.train()


if __name__ == "__main__":
    main()
gs/train_vdpm.py ADDED
@@ -0,0 +1,712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train 3D Gaussian Splatting from VDPM Output
3
+
4
+ Loads VDPM reconstruction (tracks.npz, poses.npz, images/) and trains 3DGS.
5
+ Supports per-frame training or combined multi-timestep training.
6
+
7
+ Usage (from 4dgs-dpm root):
8
+ python -m gs.train_vdpm --input ./vdpm/input_images_XXXX --output ./output/vdpm_scene
9
+ python -m gs.train_vdpm --input ./vdpm/input_images_XXXX --output ./output --frame 0
10
+
11
+ Or directly:
12
+ cd gs
13
+ python train_vdpm.py --input ../vdpm/input_images_XXXX --output ./output
14
+ """
15
+
16
+ import os
17
+ import sys
18
+ import json
19
+ import argparse
20
+ import numpy as np
21
+ from pathlib import Path
22
+ from scipy.spatial import cKDTree
23
+ import imageio
24
+ import torch
25
+ import matplotlib.pyplot as plt
26
+
27
+ # Ensure gs/ modules are importable when running from root
28
+ _gs_dir = Path(__file__).parent.resolve()
29
+ if str(_gs_dir) not in sys.path:
30
+ sys.path.insert(0, str(_gs_dir))
31
+
32
+ import warp as wp
33
+ from tqdm import tqdm
34
+
35
+ from forward import render_gaussians
36
+ from backward import backward
37
+ from optimizer import adam_update, prune_gaussians
38
+ from config import GaussianParams, DEVICE
39
+ from utils.point_cloud_utils import save_ply
40
+ from utils.math_utils import world_to_view, projection_matrix
41
+ from loss import l1_loss, compute_image_gradients
42
+
43
+ wp.init()
44
+
45
+
46
def decode_poses(pose_enc: np.ndarray, image_hw: tuple) -> tuple:
    """
    Decode VGGT pose encodings to extrinsic and intrinsic matrices.

    Falls back to identity extrinsics and a plausible pinhole intrinsic
    when the optional `vggt` package is not installed.

    Args:
        pose_enc: (1, N, 9) pose encoding from VDPM
        image_hw: (H, W) image dimensions

    Returns:
        extrinsics: (N, 4, 4) world-to-camera matrices
        intrinsics: (N, 3, 3) camera intrinsic matrices
    """
    # Keep the try-block minimal: only the optional import should trigger
    # the fallback. Previously the whole decode was inside the try, so an
    # ImportError raised *inside* vggt's decoding would silently produce
    # identity poses instead of surfacing.
    try:
        from vggt.utils.pose_enc import pose_encoding_to_extri_intri
    except ImportError:
        print("Warning: vggt not available. Using identity poses.")
        N = pose_enc.shape[1]
        extrinsics = np.tile(np.eye(4, dtype=np.float32), (N, 1, 1))

        H, W = image_hw
        fx = fy = max(H, W)  # rough focal guess: the longer image side
        cx, cy = W / 2, H / 2
        intrinsic = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
        intrinsics = np.tile(intrinsic, (N, 1, 1))

        return extrinsics, intrinsics

    pose_enc_t = torch.from_numpy(pose_enc).float()
    extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc_t, image_hw)

    # extrinsic is (1, N, 3, 4) camera-from-world
    extrinsic = extrinsic[0].numpy()    # (N, 3, 4)
    intrinsics = intrinsic[0].numpy()   # (N, 3, 3)

    # Append the homogeneous [0, 0, 0, 1] row to each 3x4 extrinsic
    N = extrinsic.shape[0]
    bottom = np.array([0, 0, 0, 1], dtype=np.float32).reshape(1, 1, 4)
    bottom = np.tile(bottom, (N, 1, 1))
    extrinsics_4x4 = np.concatenate([extrinsic, bottom], axis=1)  # (N, 4, 4)

    return extrinsics_4x4, intrinsics
88
+
89
+
90
def load_vdpm_data(input_path: str) -> dict:
    """Load all VDPM outputs from a directory.

    Expects `tracks.npz` (or `output_4d.npz`), an `images/` directory of
    PNG/JPG frames, and optionally `poses.npz` and `meta.json`.

    Args:
        input_path: directory containing the VDPM reconstruction.

    Returns:
        dict with keys:
            world_points: (T, V, H, W, 3) float array
            world_points_conf: (T, V, H, W) confidence per point
            pose_enc: (1, N, 9) pose encodings, or None if absent
            images: (N, H, W, 3) float32 images in [0, 1]
            num_views, num_timesteps, T, V, H, W, meta

    Raises:
        FileNotFoundError: if no point file or no images are found.
    """
    input_path = Path(input_path)

    # Load tracks/points
    tracks_path = input_path / "tracks.npz"
    output_path = input_path / "output_4d.npz"

    if tracks_path.exists():
        data = np.load(tracks_path)
    elif output_path.exists():
        data = np.load(output_path)
    else:
        raise FileNotFoundError(f"No tracks.npz or output_4d.npz in {input_path}")

    world_points = data['world_points']
    world_points_conf = data['world_points_conf']
    num_views = int(data.get('num_views', 1))
    num_timesteps = int(data.get('num_timesteps', world_points.shape[0]))

    # Normalize to the multi-view layout (T, V, H, W, ...)
    if world_points.ndim == 5:
        T, V, H, W, _ = world_points.shape
        print(f"Multi-view: {T} timesteps × {V} views × {H}×{W}")
    else:
        T, H, W, _ = world_points.shape
        V = 1
        world_points = world_points[:, np.newaxis, :, :, :]
        world_points_conf = world_points_conf[:, np.newaxis, :, :]
        print(f"Single-view: {T} timesteps × {H}×{W}")

    # Load camera pose encodings (optional)
    poses_path = input_path / "poses.npz"
    pose_enc = None
    if poses_path.exists():
        pose_data = np.load(poses_path)
        pose_enc = pose_data.get('pose_enc')
        print(f"Loaded poses: {pose_enc.shape if pose_enc is not None else 'None'}")

    # Load images; grayscale is expanded to RGB, alpha is dropped
    images_dir = input_path / "images"
    image_paths = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
    if not image_paths:
        # Fail early with a clear message; np.stack on an empty list would
        # otherwise raise an opaque ValueError below.
        raise FileNotFoundError(f"No .png/.jpg images found in {images_dir}")
    images = []
    for img_path in image_paths:
        img = imageio.imread(img_path)
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)
        elif img.shape[-1] == 4:
            img = img[..., :3]
        images.append(img.astype(np.float32) / 255.0)
    images = np.stack(images, axis=0)  # (N, H, W, 3)
    print(f"Loaded {len(images)} images, shape: {images.shape}")

    # Load metadata (optional)
    meta_path = input_path / "meta.json"
    meta = {}
    if meta_path.exists():
        with open(meta_path) as f:
            meta = json.load(f)

    return {
        'world_points': world_points,            # (T, V, H, W, 3)
        'world_points_conf': world_points_conf,  # (T, V, H, W)
        'pose_enc': pose_enc,
        'images': images,                        # (N, H, W, 3)
        'num_views': num_views,
        'num_timesteps': num_timesteps,
        'T': T, 'V': V, 'H': H, 'W': W,
        'meta': meta,
    }
160
+
161
+
162
def extract_frame_pointcloud(data: dict, frame_idx: int, conf_threshold: float = 50.0):
    """
    Extract point cloud and colors for a specific frame.

    Args:
        data: dict produced by load_vdpm_data().
        frame_idx: timestep index into the (T, V, H, W, ...) arrays.
        conf_threshold: percentile (0-100) of the per-frame confidence
            distribution used as the keep threshold; 0 keeps every
            point with non-negligible confidence.

    Returns:
        positions: (N, 3) XYZ
        colors: (N, 3) RGB [0,1]
        confidence: (N,) confidence scores
    """
    V, H, W = data['V'], data['H'], data['W']

    # Points for this frame, all views merged
    pts = data['world_points'][frame_idx]        # (V, H, W, 3)
    conf = data['world_points_conf'][frame_idx]  # (V, H, W)

    # Flatten across views and pixels
    pts_flat = pts.reshape(-1, 3)   # (V*H*W, 3)
    conf_flat = conf.reshape(-1)    # (V*H*W,)

    # Colors come from the input frames. Images are interleaved:
    # [cam0_t0, cam1_t0, cam0_t1, cam1_t1, ...]
    start_idx = frame_idx * V
    frame_images = data['images'][start_idx:start_idx + V]  # (V, H_img, W_img, 3)

    # Resize images to the point-map resolution if they differ
    img_H, img_W = frame_images.shape[1:3]
    if img_H != H or img_W != W:
        from scipy.ndimage import zoom
        scale_h = H / img_H
        scale_w = W / img_W
        frame_images = np.stack(
            [zoom(frame_images[v], (scale_h, scale_w, 1), order=1) for v in range(V)],
            axis=0,
        )

    colors_flat = frame_images.reshape(-1, 3)  # (V*H*W, 3)

    # Keep points above the requested confidence percentile; always drop
    # near-zero confidence.
    if conf_threshold > 0:
        thresh = np.percentile(conf_flat, conf_threshold)
        mask = (conf_flat >= thresh) & (conf_flat > 1e-5)
    else:
        mask = conf_flat > 1e-5

    # Also drop NaN/Inf positions
    mask &= np.isfinite(pts_flat).all(axis=1)

    return pts_flat[mask], colors_flat[mask], conf_flat[mask]
213
+
214
+
215
def build_cameras(data: dict, frame_idx: int) -> list:
    """
    Build camera dictionaries for training from VDPM data.

    One dict is produced per view of the requested timestep, containing
    the (transposed, Warp/OpenGL-convention) view and projection matrices
    plus intrinsics-derived fields.

    Args:
        data: dict produced by load_vdpm_data().
        frame_idx: timestep index; image index is frame_idx * V + v.

    Returns:
        list of camera dicts compatible with 3DGS training.
    """
    T, V, H, W = data['T'], data['V'], data['H'], data['W']
    images = data['images']
    img_H, img_W = images.shape[1:3]

    # Decode poses
    pose_enc = data.get('pose_enc')
    if pose_enc is not None:
        extrinsics, intrinsics = decode_poses(pose_enc, (img_H, img_W))
    else:
        # Fallback: identity poses with reasonable intrinsics
        N = T * V
        extrinsics = np.tile(np.eye(4, dtype=np.float32), (N, 1, 1))
        fx = fy = max(img_H, img_W)
        cx, cy = img_W / 2, img_H / 2
        K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
        intrinsics = np.tile(K, (N, 1, 1))

    cameras = []

    # Get camera indices for this frame
    for v in range(V):
        img_idx = frame_idx * V + v

        if img_idx >= len(extrinsics):
            continue

        extrinsic = extrinsics[img_idx]  # (4, 4) camera-from-world
        intrinsic = intrinsics[img_idx]  # (3, 3)

        # Extract components
        R = extrinsic[:3, :3]
        t = extrinsic[:3, 3]

        fx, fy = intrinsic[0, 0], intrinsic[1, 1]
        cx, cy = intrinsic[0, 2], intrinsic[1, 2]

        # Camera center in world coords: C = -R^T t
        camera_center = -R.T @ t

        # FOV from intrinsics
        fov_x = 2 * np.arctan(img_W / (2 * fx))
        fov_y = 2 * np.arctan(img_H / (2 * fy))

        # Build matrices exactly like render.py does for Warp/OpenGL compatibility
        # Warp uses column-major (OpenGL convention), so matrices must be transposed
        world_to_camera = np.eye(4, dtype=np.float32)
        world_to_camera[:3, :3] = R
        world_to_camera[:3, 3] = t
        world_to_camera = world_to_camera.T  # Transpose for Warp/OpenGL!

        # Projection matrix (transposed for Warp/OpenGL)
        near, far = 0.01, 100.0
        proj_matrix = projection_matrix(fovx=fov_x, fovy=fov_y, znear=near, zfar=far).T

        # Full projection = view @ proj (row-vector convention on the
        # transposed matrices)
        full_proj_matrix = world_to_camera @ proj_matrix

        cameras.append({
            'camera_id': img_idx,
            'width': img_W,
            'height': img_H,
            'world_to_camera': world_to_camera,  # Transposed for Warp/OpenGL
            # NOTE(review): inverse of the *transposed* w2c, i.e. (c2w)^T —
            # consumers indexing columns of this matrix should confirm the
            # convention matches.
            'camera_to_world': np.linalg.inv(world_to_camera),
            'camera_center': camera_center,
            'full_proj_matrix': full_proj_matrix,
            'tan_fovx': np.tan(fov_x / 2),
            'tan_fovy': np.tan(fov_y / 2),
            'fx': fx, 'fy': fy,
            'cx': cx, 'cy': cy,
        })

    return cameras
293
+
294
+
295
# Kernel: initialize each Gaussian with an identity quaternion and a 0.5
# opacity. One thread per Gaussian; the num_points guard protects against
# over-provisioned launch dimensions.
@wp.kernel
def init_rotations_opacities(
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    num_points: int
):
    i = wp.tid()
    if i >= num_points:
        return
    rotations[i] = wp.vec4(1.0, 0.0, 0.0, 0.0)  # identity rotation quaternion
    opacities[i] = 0.5
306
+
307
+
308
@wp.kernel
def zero_gradients(
    pos_grad: wp.array(dtype=wp.vec3),
    scale_grad: wp.array(dtype=wp.vec3),
    rot_grad: wp.array(dtype=wp.vec4),
    opacity_grad: wp.array(dtype=float),
    sh_grad: wp.array(dtype=wp.vec3),
    num_points: int
):
    # One thread per Gaussian: clear every per-parameter gradient buffer
    # so the next backward pass starts from zero.
    i = wp.tid()
    if i >= num_points:
        return

    pos_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    scale_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    rot_grad[i] = wp.vec4(0.0, 0.0, 0.0, 0.0)
    opacity_grad[i] = 0.0

    # sh_grad is flattened as 16 SH coefficient vectors per Gaussian
    # (same N*16 layout used for the 'shs' parameter array).
    for j in range(16):
        idx = i * 16 + j
        sh_grad[idx] = wp.vec3(0.0, 0.0, 0.0)
329
+
330
+
331
class VDPM3DGSTrainer:
    """Train 3DGS from VDPM point cloud initialization.

    Owns the per-Gaussian parameter arrays (positions, scales, rotations,
    opacities, SH coefficients) as Warp arrays, plus matching gradient and
    Adam first/second-moment buffers, and runs a render -> backward ->
    Adam-update loop against the V views of a single VDPM timestep.
    """

    def __init__(self, data: dict, frame_idx: int, output_path: str, conf_threshold: float = 50.0):
        """Set up output directories, extract the frame's point cloud,
        build cameras, and initialize parameters and optimizer state.

        Args:
            data: VDPM data dict. Must provide 'V' (views per frame) and
                'images' (flattened so image index = frame_idx * V + view),
                plus whatever extract_frame_pointcloud / build_cameras read.
            frame_idx: Index of the timestep to train.
            output_path: Root output directory (created if missing).
            conf_threshold: Confidence percentile used to filter points.
        """
        self.output_path = Path(output_path)
        self.output_path.mkdir(parents=True, exist_ok=True)

        # Create output directories for renders
        self.render_dir = self.output_path / f"frame_{frame_idx}" / "renders"
        self.render_dir.mkdir(parents=True, exist_ok=True)

        self.config = GaussianParams.get_config_dict()
        self.frame_idx = frame_idx

        # Extract point cloud for this frame
        print(f"Extracting point cloud for frame {frame_idx}...")
        # NOTE(review): `confidence` is returned but never used here.
        positions, colors, confidence = extract_frame_pointcloud(data, frame_idx, conf_threshold)
        self.num_points = len(positions)
        print(f"Got {self.num_points} points")

        # Build cameras
        self.cameras = build_cameras(data, frame_idx)
        print(f"Built {len(self.cameras)} cameras")

        # Store the V target images belonging to this frame.
        V = data['V']
        start_idx = frame_idx * V
        end_idx = start_idx + V
        self.images = data['images'][start_idx:end_idx]

        # Initialize Gaussian parameters, gradients, and Adam moments
        # (grads / adam_m / adam_v share the same array layout as params).
        self.params = self._init_params(positions, colors)
        self.grads = self._create_grad_arrays()
        self.adam_m = self._create_grad_arrays()
        self.adam_v = self._create_grad_arrays()

        self.losses = []
        self.intermediate_buffers = {}

    def _init_params(self, positions: np.ndarray, colors: np.ndarray):
        """Initialize Gaussian parameters from point cloud."""
        N = self.num_points

        # Positions
        positions_wp = wp.array(positions.astype(np.float32), dtype=wp.vec3, device=DEVICE)

        # Scales from KNN: mean distance to the 3 nearest neighbors
        # (query k=4 because the first hit is the point itself).
        if N > 3:
            tree = cKDTree(positions)
            dists, _ = tree.query(positions, k=4)
            avg_dist = np.mean(dists[:, 1:], axis=1)
            scales_np = np.clip(avg_dist, 0.001, 1.0)[:, np.newaxis] * np.ones((N, 3))
        else:
            scales_np = np.full((N, 3), 0.01, dtype=np.float32)
        scales_wp = wp.array(scales_np.astype(np.float32), dtype=wp.vec3, device=DEVICE)

        # Rotations and opacities, filled on-device by the init kernel.
        rotations_wp = wp.zeros(N, dtype=wp.vec4, device=DEVICE)
        opacities_wp = wp.zeros(N, dtype=float, device=DEVICE)
        wp.launch(init_rotations_opacities, dim=N, inputs=[rotations_wp, opacities_wp, N])

        # SH coefficients from colors: only the DC term (every 16th entry)
        # is set; C0 is the zeroth-order spherical-harmonic constant.
        C0 = 0.28209479177387814
        shs_np = np.zeros((N * 16, 3), dtype=np.float32)
        shs_np[::16] = (colors - 0.5) / C0
        shs_wp = wp.array(shs_np, dtype=wp.vec3, device=DEVICE)

        return {
            'positions': positions_wp,
            'scales': scales_wp,
            'rotations': rotations_wp,
            'opacities': opacities_wp,
            'shs': shs_wp,
        }

    def _create_grad_arrays(self):
        """Allocate a zeroed set of arrays mirroring the parameter layout
        (used for gradients and for the two Adam moment buffers)."""
        N = self.num_points
        return {
            'positions': wp.zeros(N, dtype=wp.vec3, device=DEVICE),
            'scales': wp.zeros(N, dtype=wp.vec3, device=DEVICE),
            'rotations': wp.zeros(N, dtype=wp.vec4, device=DEVICE),
            'opacities': wp.zeros(N, dtype=float, device=DEVICE),
            'shs': wp.zeros(N * 16, dtype=wp.vec3, device=DEVICE),
        }

    def zero_grad(self):
        """Reset all gradient buffers to zero on-device."""
        wp.launch(zero_gradients, dim=self.num_points, inputs=[
            self.grads['positions'], self.grads['scales'],
            self.grads['rotations'], self.grads['opacities'],
            self.grads['shs'], self.num_points
        ])

    def train(self, num_iterations: int = 3000):
        """Train the 3DGS model.

        Each iteration: pick a random camera, render, compute L1 loss and
        pixel gradients, run the backward pass, then apply one Adam step
        with a decaying learning rate. Checkpoints every 500 iterations.
        """
        print(f"Training for {num_iterations} iterations...")

        # Save iteration 0 (initial state before training)
        print("Saving initial state (iteration 0)...")
        self.save(0)

        with tqdm(total=num_iterations) as pbar:
            for it in range(num_iterations):
                self.zero_grad()

                # Pick a random camera
                cam_idx = np.random.randint(len(self.cameras))
                camera = self.cameras[cam_idx]
                target = self.images[cam_idx]

                # Render
                rendered, depth, self.intermediate_buffers = render_gaussians(
                    background=np.array(self.config['background_color'], dtype=np.float32),
                    means3D=self.params['positions'].numpy(),
                    colors=None,
                    opacity=self.params['opacities'].numpy(),
                    scales=self.params['scales'].numpy(),
                    rotations=self.params['rotations'].numpy(),
                    scale_modifier=1.0,
                    viewmatrix=camera['world_to_camera'],
                    projmatrix=camera['full_proj_matrix'],
                    tan_fovx=camera['tan_fovx'],
                    tan_fovy=camera['tan_fovy'],
                    image_height=camera['height'],
                    image_width=camera['width'],
                    sh=self.params['shs'].numpy(),
                    degree=3,
                    campos=camera['camera_center'],
                    prefiltered=False,
                    antialiasing=True,
                )

                # Compute loss
                # NOTE(review): rendered_np is computed here but unused —
                # l1_loss consumes `rendered` directly. Candidate for removal.
                rendered_np = wp.to_torch(rendered).cpu().numpy()
                if rendered_np.shape[0] == 3:
                    rendered_np = np.transpose(rendered_np, (1, 2, 0))

                target_wp = wp.array(target.astype(np.float32), dtype=wp.vec3, device=DEVICE)
                loss = l1_loss(rendered, target_wp)
                self.losses.append(loss)

                # Compute pixel gradients for backward pass
                pixel_grad_buffer = compute_image_gradients(
                    rendered, target_wp, lambda_dssim=0
                )

                # Prepare camera parameters
                view_matrix = wp.mat44(camera['world_to_camera'].flatten())
                proj_matrix = wp.mat44(camera['full_proj_matrix'].flatten())
                campos = wp.vec3(camera['camera_center'][0], camera['camera_center'][1], camera['camera_center'][2])

                # Prepare buffers for backward pass (repackage the forward
                # pass intermediates under the names backward() expects).
                geom_buffer = {
                    'radii': self.intermediate_buffers['radii'],
                    'means2D': self.intermediate_buffers['points_xy_image'],
                    'conic_opacity': self.intermediate_buffers['conic_opacity'],
                    'rgb': self.intermediate_buffers['colors'],
                    'clamped': self.intermediate_buffers['clamped_state']
                }

                binning_buffer = {
                    'point_list': self.intermediate_buffers['point_list']
                }

                img_buffer = {
                    'ranges': self.intermediate_buffers['ranges'],
                    'final_Ts': self.intermediate_buffers['final_Ts'],
                    'n_contrib': self.intermediate_buffers['n_contrib']
                }

                # Backward pass
                gradients = backward(
                    # Core parameters
                    background=np.array(self.config['background_color'], dtype=np.float32),
                    means3D=self.params['positions'],
                    dL_dpixels=pixel_grad_buffer,

                    # Model parameters
                    opacity=self.params['opacities'],
                    shs=self.params['shs'],
                    scales=self.params['scales'],
                    rotations=self.params['rotations'],
                    scale_modifier=self.config['scale_modifier'],

                    # Camera parameters
                    viewmatrix=view_matrix,
                    projmatrix=proj_matrix,
                    tan_fovx=camera['tan_fovx'],
                    tan_fovy=camera['tan_fovy'],
                    image_height=camera['height'],
                    image_width=camera['width'],
                    campos=campos,

                    # Forward output buffers
                    radii=self.intermediate_buffers['radii'],
                    means2D=self.intermediate_buffers['points_xy_image'],
                    conic_opacity=self.intermediate_buffers['conic_opacity'],
                    rgb=self.intermediate_buffers['colors'],
                    cov3Ds=self.intermediate_buffers['cov3Ds'],
                    clamped=self.intermediate_buffers['clamped_state'],

                    # Internal state buffers
                    geom_buffer=geom_buffer,
                    binning_buffer=binning_buffer,
                    img_buffer=img_buffer,

                    # Algorithm parameters
                    degree=self.config['sh_degree'],
                    debug=False
                )

                # Copy gradients to optimizer buffers
                wp.copy(self.grads['positions'], gradients['dL_dmean3D'])
                wp.copy(self.grads['scales'], gradients['dL_dscale'])
                wp.copy(self.grads['rotations'], gradients['dL_drot'])
                wp.copy(self.grads['opacities'], gradients['dL_dopacity'])
                wp.copy(self.grads['shs'], gradients['dL_dshs'])

                # Optimizer step: exponentially decaying base LR (0.001
                # down to 0.0001 over the run), with per-parameter scaling
                # in the adam_update argument list below.
                lr = 0.001 * (0.1 ** (it / num_iterations))
                wp.launch(adam_update, dim=self.num_points, inputs=[
                    self.params['positions'], self.params['scales'],
                    self.params['rotations'], self.params['opacities'], self.params['shs'],
                    self.grads['positions'], self.grads['scales'],
                    self.grads['rotations'], self.grads['opacities'], self.grads['shs'],
                    self.adam_m['positions'], self.adam_m['scales'],
                    self.adam_m['rotations'], self.adam_m['opacities'], self.adam_m['shs'],
                    self.adam_v['positions'], self.adam_v['scales'],
                    self.adam_v['rotations'], self.adam_v['opacities'], self.adam_v['shs'],
                    self.num_points, lr, lr*5, lr*5, lr*2, lr*5,
                    0.9, 0.999, 1e-8, it
                ])

                pbar.set_postfix(loss=f"{loss:.4f}")
                pbar.update(1)

                # Save checkpoint
                if (it + 1) % 500 == 0 or it == num_iterations - 1:
                    self.save(it + 1)

        print("Training complete!")

    def save(self, iteration: int):
        """Save checkpoint with rendered images.

        Writes a PLY of the current Gaussians, per-camera render/target/
        comparison PNGs, and a loss plot into
        output/frame_<idx>/iter_<iteration>/.
        """
        ckpt_dir = self.output_path / f"frame_{self.frame_idx}" / f"iter_{iteration}"
        ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Save PLY
        save_ply(self.params, ckpt_dir / "point_cloud.ply", self.num_points)

        # Render and save images for all cameras
        for cam_idx, camera in enumerate(self.cameras):
            target = self.images[cam_idx]

            # NOTE(review): depth is unused here.
            rendered, depth, _ = render_gaussians(
                background=np.array(self.config['background_color'], dtype=np.float32),
                means3D=self.params['positions'].numpy(),
                colors=None,
                opacity=self.params['opacities'].numpy(),
                scales=self.params['scales'].numpy(),
                rotations=self.params['rotations'].numpy(),
                scale_modifier=1.0,
                viewmatrix=camera['world_to_camera'],
                projmatrix=camera['full_proj_matrix'],
                tan_fovx=camera['tan_fovx'],
                tan_fovy=camera['tan_fovy'],
                image_height=camera['height'],
                image_width=camera['width'],
                sh=self.params['shs'].numpy(),
                degree=3,
                campos=camera['camera_center'],
                prefiltered=False,
                antialiasing=True,
            )

            # Convert rendered to numpy; transpose CHW -> HWC if needed.
            rendered_np = wp.to_torch(rendered).cpu().numpy()
            if rendered_np.shape[0] == 3:
                rendered_np = np.transpose(rendered_np, (1, 2, 0))

            # Save rendered image
            rendered_uint8 = (np.clip(rendered_np, 0, 1) * 255).astype(np.uint8)
            imageio.imwrite(ckpt_dir / f"render_cam{cam_idx}.png", rendered_uint8)

            # Save target image (assumed float in [0, 1] — same convention
            # as the render; TODO confirm against the loader)
            target_uint8 = (np.clip(target, 0, 1) * 255).astype(np.uint8)
            imageio.imwrite(ckpt_dir / f"target_cam{cam_idx}.png", target_uint8)

            # Save side-by-side comparison
            comparison = np.concatenate([target_uint8, rendered_uint8], axis=1)
            imageio.imwrite(ckpt_dir / f"compare_cam{cam_idx}.png", comparison)

        # Save loss plot
        if len(self.losses) > 0:
            plt.figure(figsize=(10, 5))
            plt.plot(self.losses)
            plt.title(f'Training Loss - Frame {self.frame_idx}')
            plt.xlabel('Iteration')
            plt.ylabel('L1 Loss')
            plt.grid(True)
            plt.savefig(ckpt_dir / "loss_plot.png")
            plt.close()

        print(f"Saved checkpoint to {ckpt_dir}")

    def save_final(self):
        """Save final PLY to flat output structure for easy loading.

        Returns the pathlib.Path of the written frame_XXXX.ply.
        """
        final_path = self.output_path / f"frame_{self.frame_idx:04d}.ply"
        save_ply(self.params, final_path, self.num_points)
        return final_path
640
+
641
+
642
def train_single_frame(data: dict, frame_idx: int, output_path: str,
                       conf_threshold: float, iterations: int) -> Path:
    """Train a single frame and return the output PLY path.

    Args:
        data: Loaded VDPM data dict (as produced by load_vdpm_data).
        frame_idx: Timestep index to train.
        output_path: Root output directory for checkpoints and the final PLY.
        conf_threshold: Confidence percentile for point filtering.
        iterations: Number of optimization iterations.

    Returns:
        Path to the flat ``frame_XXXX.ply`` written by ``save_final()``.
        (The annotation previously claimed ``str``, but ``save_final``
        returns a ``pathlib.Path``.)
    """
    trainer = VDPM3DGSTrainer(
        data=data,
        frame_idx=frame_idx,
        output_path=output_path,
        conf_threshold=conf_threshold,
    )
    trainer.train(num_iterations=iterations)
    return trainer.save_final()
653
+
654
+
655
def main():
    """CLI entry point: train 3DGS for one frame or for every frame."""
    parser = argparse.ArgumentParser(description="Train 3DGS from VDPM output")
    parser.add_argument("--input", "-i", required=True, help="Path to VDPM output directory")
    parser.add_argument("--output", "-o", required=True, help="Output directory")
    parser.add_argument("--frame", "-f", type=int, default=None,
                        help="Single frame index to train (default: train ALL frames)")
    parser.add_argument("--conf", type=float, default=50.0, help="Confidence threshold percentile")
    parser.add_argument("--iterations", "-n", type=int, default=3000, help="Training iterations per frame")

    args = parser.parse_args()

    # Load data
    print(f"Loading VDPM data from {args.input}...")
    data = load_vdpm_data(args.input)

    total_frames = data['T']
    print(f"Found {total_frames} timesteps in data")

    out_dir = Path(args.output)
    out_dir.mkdir(parents=True, exist_ok=True)

    rule = "=" * 60

    if args.frame is None:
        # No frame given: sweep every timestep in order.
        print(f"\n{rule}")
        print(f"Training ALL {total_frames} frames")
        print(f"Output: {out_dir}/frame_XXXX.ply")
        print(rule)

        results = []
        for idx in range(total_frames):
            print(f"\n[Frame {idx+1}/{total_frames}]")
            results.append(
                train_single_frame(data, idx, args.output, args.conf, args.iterations)
            )

        print(f"\n{rule}")
        print(f"✓ Training complete! Generated {len(results)} PLY files:")
        for ply in results:
            print(f"  {ply}")
        print(rule)
    else:
        # Train only the requested frame, after bounds-checking it.
        if args.frame >= total_frames:
            raise ValueError(f"Frame {args.frame} out of range (0-{total_frames-1})")

        print(f"\n{rule}")
        print(f"Training frame {args.frame}/{total_frames-1}")
        print(rule)

        result = train_single_frame(
            data, args.frame, args.output, args.conf, args.iterations
        )
        print(f"\n✓ Saved: {result}")


if __name__ == "__main__":
    main()
gs/training_progress.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:052cff51cf3d320a299faea9c795fad53349ed193b8bc222ee37860afd7def99
3
+ size 306145
gs/utils/analyze_scales.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from plyfile import PlyData, PlyElement
7
+
8
def analyze_scales(input_ply, output_ply=None, threshold=None, show_plot=True):
    """
    Analyze scales in a PLY file and optionally filter out large splats.

    Args:
        input_ply (str): Path to input PLY file
        output_ply (str, optional): Path to save filtered PLY file
        threshold (float, optional): Maximum scale value to keep
        show_plot (bool): Whether to display the histogram plot

    Raises:
        FileNotFoundError: If the input PLY cannot be found.
    """
    # Convert input path to absolute path if it's relative
    repo_root = Path(__file__).parent.parent  # Go up one level from utils
    input_ply = Path(repo_root) / input_ply if not os.path.isabs(input_ply) else Path(input_ply)

    if not input_ply.exists():
        raise FileNotFoundError(f"Could not find PLY file: {input_ply}")

    print(f"Reading PLY file: {input_ply}")
    plydata = PlyData.read(str(input_ply))
    vertex_data = plydata['vertex']

    # Extract scale values - assuming log-space encoding in PLY
    scales = np.vstack([
        np.exp(vertex_data['scale_0']),
        np.exp(vertex_data['scale_1']),
        np.exp(vertex_data['scale_2'])
    ]).T

    # Per-Gaussian statistic: the largest of its three axis scales.
    max_scales = np.max(scales, axis=1)
    mean_scale = np.mean(max_scales)
    median_scale = np.median(max_scales)

    print(f"Scale statistics:")
    print(f"Mean scale: {mean_scale:.6f}")
    print(f"Median scale: {median_scale:.6f}")
    print(f"Min scale: {np.min(max_scales):.6f}")
    print(f"Max scale: {np.max(max_scales):.6f}")

    # Plot histogram
    if show_plot:
        plt.figure(figsize=(10, 6))
        plt.hist(max_scales, bins=100, edgecolor='black')
        plt.title('Histogram of Maximum Scales per Gaussian')
        plt.xlabel('Scale')
        plt.ylabel('Count')
        if threshold is not None:
            plt.axvline(x=threshold, color='r', linestyle='--',
                        label=f'Threshold ({threshold})')
            plt.legend()
        plt.savefig(Path(input_ply).parent / 'scale_histogram.png')
        plt.show()

    # Filter and save new PLY if threshold is provided
    if threshold is not None and output_ply is not None:
        # Create mask for Gaussians to keep
        keep_mask = max_scales <= threshold
        num_removed = np.sum(~keep_mask)
        print(f"Removing {num_removed} Gaussians ({(num_removed/len(keep_mask))*100:.2f}%)")

        # Vectorized filtering: index the element's underlying structured
        # array with the boolean mask instead of rebuilding rows in a
        # Python loop. This also fixes two bugs in the previous version:
        # `describe` lives on PlyElement (not PlyData), and
        # `vertex_data.dtype` is a method, not a dtype attribute.
        new_vertex_array = vertex_data.data[keep_mask]
        new_vertex_element = PlyElement.describe(new_vertex_array, 'vertex')
        PlyData([new_vertex_element], text=True).write(output_ply)
        print(f"Saved filtered PLY to: {output_ply}")
82
+
83
def main():
    """Parse CLI arguments and run the scale analysis/filter."""
    parser = argparse.ArgumentParser(description='Analyze and filter Gaussian scales in PLY file')
    parser.add_argument('input_ply', help='Input PLY file path')
    parser.add_argument('--output', '-o', help='Output PLY file path')
    parser.add_argument('--threshold', '-t', type=float, help='Maximum scale threshold')
    parser.add_argument('--no-plot', action='store_true', help='Disable histogram plot')
    args = parser.parse_args()

    plot_enabled = not args.no_plot
    analyze_scales(args.input_ply, args.output, args.threshold, show_plot=plot_enabled)


if __name__ == "__main__":
    main()
gs/utils/camera_utils.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+
5
+ from utils.math_utils import world_to_view, projection_matrix
6
+
7
+ # Y down, Z forward
8
def load_camera(camera_info):
    """Load camera parameters from camera info dictionary.

    Expects an OpenGL/Blender-style ``camera_to_world`` (Y up, Z back)
    plus ``width``, ``height``, ``focal`` and optional camera-model /
    distortion fields. Returns a dict of render-ready parameters; note
    that ``world_to_camera`` (and the matrices derived from it) are
    stored TRANSPOSED, i.e. column-major for Warp/OpenGL consumption.
    """
    # Extract camera parameters
    # NOTE(review): camera_id is read but never used or returned.
    camera_id = camera_info["camera_id"]
    camera_to_world = np.asarray(camera_info["camera_to_world"], dtype=np.float64)

    # Change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
    camera_to_world[:3, 1:3] *= -1

    # Calculate world to camera transform
    world_to_camera = np.linalg.inv(camera_to_world).astype(np.float32)


    # Extract rotation and translation.
    # NOTE: these are views taken BEFORE the transpose below, so they hold
    # the row-major world->camera rotation/translation.
    R = world_to_camera[:3, :3]
    T = world_to_camera[:3, 3]


    world_to_camera[3, 3] = 1.
    # Transpose for Warp/OpenGL column-major convention.
    world_to_camera = world_to_camera.T


    width = camera_info.get("width")
    height = camera_info.get("height")
    # Single 'focal' value is used for both axes (square pixels assumed);
    # principal point defaults to the image center.
    fx = camera_info.get("focal")
    fy = camera_info.get("focal")
    cx = width / 2
    cy = height / 2

    # Calculate field of view from focal length
    fovx = 2 * np.arctan(width / (2 * fx))
    fovy = 2 * np.arctan(height / (2 * fy))

    # Create view matrix
    view_matrix = world_to_view(R=R, t=T)

    # Create projection matrix (also transposed, matching world_to_camera)
    znear = 0.01
    zfar = 100.0
    proj_matrix = projection_matrix(fovx=fovx, fovy=fovy, znear=znear, zfar=zfar).T
    full_proj_matrix = world_to_camera @ proj_matrix

    # Calculate other parameters
    tan_fovx = np.tan(fovx * 0.5)
    tan_fovy = np.tan(fovy * 0.5)

    # world_to_camera is transposed here, so its inverse is the transposed
    # camera_to_world: row 3 (first three entries) is the camera center.
    camera_center = np.linalg.inv(world_to_camera)[3, :3]

    # Handle camera type and distortion
    camera_model = camera_info.get("camera_model", "OPENCV")
    if camera_model == "OPENCV" or camera_model is None:
        camera_type = 0  # PERSPECTIVE
    elif camera_model == "OPENCV_FISHEYE":
        camera_type = 1  # FISHEYE
    else:
        raise ValueError(f"Unsupported camera_model '{camera_model}'")

    # Get distortion parameters (missing ones default to 0.0)
    distortion_params = []
    for param_name in ["k1", "k2", "p1", "p2", "k3", "k4"]:
        distortion_params.append(camera_info.get(param_name, 0.0))

    camera_params = {
        'R': R,
        'T': T,
        'camera_center': camera_center,
        'view_matrix': view_matrix,
        'proj_matrix': proj_matrix,
        'full_proj_matrix': full_proj_matrix,
        'tan_fovx': tan_fovx,
        'tan_fovy': tan_fovy,
        'fx': fx,
        'fy': fy,
        'cx': cx,
        'cy': cy,
        'width': width,
        'height': height,
        'camera_to_world': camera_to_world,
        'world_to_camera': world_to_camera,
        'camera_type': camera_type,
        'distortion_params': np.array(distortion_params, dtype=np.float32)
    }

    return camera_params
92
+
93
def load_camera_from_json(input_path, camera_id=0):
    """Load camera parameters from camera.json file"""
    base_dir = os.path.dirname(input_path)
    camera_file = os.path.join(base_dir, "cameras.json")

    if not os.path.exists(camera_file):
        print(f"Warning: No cameras.json found in {os.path.dirname(input_path)}, using default camera")
        return None

    try:
        with open(camera_file, 'r') as f:
            camera_list = json.load(f)

        # Prefer the entry whose id matches; otherwise fall back to the first.
        selected = camera_list[0]
        for cam in camera_list:
            if cam["id"] == camera_id:
                selected = cam
                break

        # Delegate parameter processing to load_camera
        return load_camera(selected)

    except Exception as e:
        # Best-effort loader: report and signal "no camera" to the caller.
        print(f"Error loading camera from cameras.json: {e}")
        return None
113
+
114
def load_camera_colmap(cam_info):
    """
    Load camera from COLMAP format (dust3r output) with exact compatibility to original load_camera.

    Args:
        cam_info: Dictionary containing:
            - width, height: image dimensions
            - fx, fy: focal lengths
            - cx, cy: principal point
            - camera_id: unique identifier
            - R: rotation matrix (world-to-camera rotation)
            - T: translation vector (world-to-camera translation)
            - Optional: camera_model, distortion params

    Returns:
        Dict of render-ready camera parameters; ``world_to_camera`` is
        TRANSPOSED (column-major) to match load_camera's Warp/OpenGL
        convention — NOTE(review): see the transpose caveat below.
    """
    # Extract camera parameters
    # NOTE(review): camera_id is read but never used or returned.
    camera_id = cam_info["camera_id"]

    # Use provided R and T directly (COLMAP convention - world to camera)
    R = cam_info['R']
    T = cam_info['T']  # This is world-to-camera translation

    # Build world-to-camera matrix
    world_to_camera = np.eye(4, dtype=np.float64)
    world_to_camera[:3, :3] = R
    world_to_camera[:3, 3] = T

    # Invert to get camera-to-world
    camera_to_world = np.linalg.inv(world_to_camera).astype(np.float64)

    # IMPORTANT FIX: Ensure Z direction is correctly oriented for COLMAP convention
    # COLMAP uses +Z forward, so no need to flip Z axis
    # If frustums are still backwards, uncomment this line:
    # camera_to_world[:3, 2] *= -1  # Flip Z axis if needed

    # Recalculate world_to_camera after any modifications
    # NOTE(review): unlike load_camera, this matrix is NOT transposed
    # before use below — confirm downstream consumers expect row-major
    # here, or whether the .T step was dropped unintentionally.
    world_to_camera = np.linalg.inv(camera_to_world).astype(np.float32)

    # Extract intrinsics (fall back to 'focal', then to a 0.7 * dimension
    # heuristic when no focal information is present)
    width = cam_info.get("width")
    height = cam_info.get("height")
    fx = cam_info.get("fx", cam_info.get("focal", width * 0.7))
    fy = cam_info.get("fy", cam_info.get("focal", height * 0.7))
    cx = cam_info.get("cx", width / 2)
    cy = cam_info.get("cy", height / 2)

    # Calculate field of view from focal length
    fovx = 2 * np.arctan(width / (2 * fx))
    fovy = 2 * np.arctan(height / (2 * fy))

    # Create view matrix using the original R and T
    view_matrix = world_to_view(R=R, t=T)

    # Create projection matrix (transposed, like load_camera)
    znear = 0.01
    zfar = 100.0
    proj_matrix = projection_matrix(fovx=fovx, fovy=fovy, znear=znear, zfar=zfar).T
    full_proj_matrix = world_to_camera @ proj_matrix

    # Calculate other parameters
    tan_fovx = np.tan(fovx * 0.5)
    tan_fovy = np.tan(fovy * 0.5)

    # IMPORTANT FIX: Correctly calculate camera center
    camera_center = camera_to_world[:3, 3]  # Extract translation from c2w matrix

    # Handle camera type and distortion
    camera_model = cam_info.get("camera_model", "OPENCV")
    if camera_model == "OPENCV" or camera_model is None:
        camera_type = 0  # PERSPECTIVE
    elif camera_model == "OPENCV_FISHEYE":
        camera_type = 1  # FISHEYE
    else:
        camera_type = 0  # Default to PERSPECTIVE

    # Get distortion parameters (missing ones default to 0.0)
    distortion_params = []
    for param_name in ["k1", "k2", "p1", "p2", "k3", "k4"]:
        distortion_params.append(cam_info.get(param_name, 0.0))

    # Return camera parameters
    camera_params = {
        'R': R,
        'T': T,
        'camera_center': camera_center,
        'view_matrix': view_matrix,
        'proj_matrix': proj_matrix,
        'full_proj_matrix': full_proj_matrix,
        'tan_fovx': tan_fovx,
        'tan_fovy': tan_fovy,
        'fx': fx,
        'fy': fy,
        'cx': cx,
        'cy': cy,
        'width': width,
        'height': height,
        'camera_to_world': camera_to_world,
        'world_to_camera': world_to_camera,
        'camera_type': camera_type,
        'distortion_params': np.array(distortion_params, dtype=np.float32)
    }

    return camera_params
gs/utils/check_opacities.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Quick diagnostic script: load a trained Gaussian-splat PLY and summarize
# the distribution of its per-splat 'opacity' values.
import numpy as np
from plyfile import PlyData
import os

# Path for output folder that's one level above utils
ply_path = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                        'output', 'point_cloud', 'iteration_13999', 'point_cloud.ply')
# load the PLY
ply = PlyData.read(ply_path)
opacities = np.array(ply['vertex']['opacity'])

# compute statistics
# NOTE(review): assumes 'opacity' is stored as activated values in [0, 1];
# if the PLY stores pre-sigmoid logits the near-zero/near-one counts below
# are not meaningful — confirm against the writer.
min_o, max_o, mean_o = opacities.min(), opacities.max(), opacities.mean()
near_zero = np.sum(opacities < 1e-3)
near_one = np.sum(opacities > 0.999)

print(f'Loaded {len(opacities)} splats')
print(f'Opacity range: min={min_o:.6f}, max={max_o:.6f}, mean={mean_o:.6f}')
print(f'Count near-zero (<1e-3): {near_zero}')
print(f'Count near-one (>0.999): {near_one}')
print('Sample opacities:', opacities[:100])
gs/utils/math_utils.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import math
4
+ import warp as wp
5
+
6
+
7
def world_to_view(R, t, translate=None, scale=1.0):
    """Build a 4x4 world-to-view matrix from a rotation and translation.

    Args:
        R: 3x3 rotation; transposed into the view matrix (3DGS convention).
        t: 3-vector translation for the view matrix.
        translate: Optional 3-vector offset added to the camera center
            before scaling (scene recentering). Defaults to no offset.
        scale: Scalar applied to the (recentered) camera center.

    Returns:
        (4, 4) float32 world-to-view matrix.
    """
    # Fix: avoid a mutable ndarray default argument; None means "no offset".
    if translate is None:
        translate = np.array([0.0, 0.0, 0.0])

    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    # Round-trip through camera-to-world so the recenter/scale is applied
    # to the camera center rather than to the raw translation column.
    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)
19
+
20
def projection_matrix(fovx, fovy, znear, zfar):
    """Build a 4x4 perspective projection matrix (+Z forward, row-major).

    The near-plane half extents are derived from the horizontal/vertical
    fields of view; depth maps the [znear, zfar] range per the standard
    3DGS projection convention.
    """
    half_w = math.tan(fovx / 2) * znear
    half_h = math.tan(fovy / 2) * znear
    left, right = -half_w, half_w
    bottom, top = -half_h, half_h

    proj = np.zeros((4, 4))
    proj[0, 0] = 2.0 * znear / (right - left)
    proj[1, 1] = 2.0 * znear / (top - bottom)
    # Off-center terms are zero for this symmetric frustum but kept for
    # generality.
    proj[0, 2] = (right + left) / (right - left)
    proj[1, 2] = (top + bottom) / (top - bottom)
    proj[2, 2] = zfar / (zfar - znear)
    proj[2, 3] = -(zfar * znear) / (zfar - znear)
    proj[3, 2] = 1.0
    return proj
41
+
42
def matrix_to_quaternion(matrix):
    """
    Convert a 3x3 rotation matrix to a quaternion in (x, y, z, w) format.

    Uses the numerically stable branch selection (Shepperd's method):
    the branch is chosen from the trace / largest diagonal entry so we
    never divide by a small quantity.

    Args:
        matrix: 3x3 rotation matrix

    Returns:
        Quaternion as (x, y, z, w) in numpy array of shape (4,)
    """
    # Debug aid: a proper rotation matrix has determinant 1.
    if np.abs(np.linalg.det(matrix) - 1.0) > 1e-5:
        print(f"Warning: Input matrix determinant is not 1: {np.linalg.det(matrix)}")

    m = matrix
    trace = np.trace(m)

    if trace > 0:
        s = 2.0 * np.sqrt(trace + 1.0)
        w = 0.25 * s
        x = (m[2, 1] - m[1, 2]) / s
        y = (m[0, 2] - m[2, 0]) / s
        z = (m[1, 0] - m[0, 1]) / s
    elif m[0, 0] > m[1, 1] and m[0, 0] > m[2, 2]:
        s = 2.0 * np.sqrt(1.0 + m[0, 0] - m[1, 1] - m[2, 2])
        w = (m[2, 1] - m[1, 2]) / s
        x = 0.25 * s
        y = (m[0, 1] + m[1, 0]) / s
        z = (m[0, 2] + m[2, 0]) / s
    elif m[1, 1] > m[2, 2]:
        s = 2.0 * np.sqrt(1.0 + m[1, 1] - m[0, 0] - m[2, 2])
        w = (m[0, 2] - m[2, 0]) / s
        x = (m[0, 1] + m[1, 0]) / s
        y = 0.25 * s
        z = (m[1, 2] + m[2, 1]) / s
    else:
        s = 2.0 * np.sqrt(1.0 + m[2, 2] - m[0, 0] - m[1, 1])
        w = (m[1, 0] - m[0, 1]) / s
        x = (m[0, 2] + m[2, 0]) / s
        y = (m[1, 2] + m[2, 1]) / s
        z = 0.25 * s

    # (x, y, z, w) ordering to match Warp's convention.
    return np.array([x, y, z, w], dtype=np.float32)
85
+
86
+
87
def quaternion_to_rotation_matrix(q):
    """Convert a (w, x, y, z) quaternion to a 3x3 float32 rotation matrix.

    NOTE(review): the component order here is w-first, while
    matrix_to_quaternion in this module returns (x, y, z, w); callers
    converting in both directions must reorder.
    """
    w, x, y, z = q
    row0 = [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w]
    row1 = [2 * x * y + 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * x * w]
    row2 = [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y]
    return np.array([row0, row1, row2], dtype=np.float32)
94
+
95
+ # def quaternion_to_rotation_matrix(q):
96
+ # """Convert quaternion to rotation matrix with swapped X and Z axes."""
97
+ # qw, qx, qy, qz = q
98
+
99
+ # # Original conversion
100
+ # R = np.array([
101
+ # [1 - 2*qy*qy - 2*qz*qz, 2*qx*qy - 2*qz*qw, 2*qx*qz + 2*qy*qw],
102
+ # [2*qx*qy + 2*qz*qw, 1 - 2*qx*qx - 2*qz*qz, 2*qy*qz - 2*qx*qw],
103
+ # [2*qx*qz - 2*qy*qw, 2*qy*qz + 2*qx*qw, 1 - 2*qx*qx - 2*qy*qy]
104
+ # ])
105
+
106
+ # # Swap X and Z axes (columns and rows)
107
+ # R_fixed = R.copy()
108
+ # R_fixed[:, [0, 2]] = R[:, [2, 0]] # Swap columns 0 and 2
109
+ # R_fixed[[0, 2], :] = R_fixed[[2, 0], :] # Swap rows 0 and 2
110
+
111
+ # return R_fixed
gs/utils/plot_loss_log.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+
4
def plot_loss_log(loss_file="output/steak/loss.txt"):
    """Plot training loss values on a logarithmic scale from loss.txt"""

    # One loss value per non-blank line.
    with open(loss_file, 'r') as fh:
        loss_values = [float(raw.strip()) for raw in fh if raw.strip()]

    # Log-scale plot of the loss curve.
    plt.figure(figsize=(12, 6))
    plt.semilogy(loss_values, label='Training Loss')
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.xlabel('Iteration')
    plt.ylabel('Loss (log scale)')
    plt.title('Training Loss over Time (Log Scale)')
    plt.legend()

    # Save next to the input file, swapping the extension.
    output_path = loss_file.replace('.txt', '_plot_log.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Saved loss plot to: {output_path}")
    print(f"Loss statistics:")
    print(f"  Min: {min(loss_values):.6f}")
    print(f"  Max: {max(loss_values):.6f}")
    print(f"  Mean: {np.mean(loss_values):.6f}")
    print(f"  Final: {loss_values[-1]:.6f}")


if __name__ == "__main__":
    plot_loss_log()
gs/utils/point_cloud_utils.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ from plyfile import PlyData, PlyElement
4
+ import math
5
+ import os
6
+ import warp as wp
7
+
8
+
9
def load_ply(filepath):
    """
    Load a Gaussian splat PLY file.

    Expects the standard 3DGS vertex layout: x/y/z positions, log-space
    scale_0..2, opacity, rot_0..3 quaternion, f_dc_0..2 and f_rest_0..44
    spherical-harmonic coefficients.

    Args:
        filepath: Path to the PLY file.

    Returns:
        dict with: positions (N, 3), scales (N, 3, linear space),
        rotations (N, 4), opacities (N, 1), shs (N*16, 3), num_points.
    """
    plydata = PlyData.read(filepath)
    vertex = plydata['vertex']

    num_points = len(vertex)

    # Load positions
    positions = np.stack([
        vertex['x'], vertex['y'], vertex['z']
    ], axis=-1).astype(np.float32)

    # Load scales (stored in log space; convert back to linear).
    scales = np.stack([
        np.exp(vertex['scale_0']),
        np.exp(vertex['scale_1']),
        np.exp(vertex['scale_2'])
    ], axis=-1).astype(np.float32)

    # Load opacities
    opacities = vertex['opacity'].astype(np.float32).reshape(-1, 1)

    # Load rotations (quaternion)
    rotations = np.stack([
        vertex['rot_0'], vertex['rot_1'], vertex['rot_2'], vertex['rot_3']
    ], axis=-1).astype(np.float32)

    # SH DC term (degree 0).
    sh_dc = np.stack([
        vertex['f_dc_0'], vertex['f_dc_1'], vertex['f_dc_2']
    ], axis=-1).astype(np.float32)

    # Higher-order SH coefficients: 45 scalars per point -> (N, 15, 3).
    # NOTE: the (15, 3) reshape assumes coefficient-major f_rest layout,
    # which is exactly how save_ply in this file writes them back out.
    sh_rest = np.stack(
        [vertex[f'f_rest_{i}'] for i in range(45)], axis=-1
    ).astype(np.float32).reshape(num_points, 15, 3)

    # Interleave DC and rest into the (N*16, 3) layout expected by the
    # renderer — vectorized instead of a per-point Python loop.
    shs = np.concatenate([sh_dc[:, None, :], sh_rest], axis=1)
    shs = np.ascontiguousarray(
        shs.reshape(num_points * 16, 3), dtype=np.float32
    )

    return {
        'positions': positions,
        'scales': scales,
        'rotations': rotations,
        'opacities': opacities,
        'shs': shs,
        'num_points': num_points
    }
68
+
69
+
70
# Function to save point cloud to PLY file
def save_ply(params, filepath, num_points, colors=None):
    """Save a Gaussian point cloud to a binary PLY file.

    Mirrors load_ply: scales are written back in log space and SH
    coefficients in the same coefficient-major f_rest order.

    Args:
        params: Dict of arrays exposing .numpy() with keys 'positions'
            (N, 3), 'scales' (N, 3, linear), 'rotations' (N, 4),
            'opacities' (N,) or (N, 1), and 'shs' (N*16, 3).
        filepath: Output path; parent directories are created if needed.
        num_points: Number of points N to write.
        colors: Optional (N, 3) RGB array in [0, 1]. If omitted, colors
            are derived from the SH DC term.
    """
    # Get numpy arrays
    positions = params['positions'].numpy()
    scales = params['scales'].numpy()
    rotations = params['rotations'].numpy()
    opacities = params['opacities'].numpy()
    shs = params['shs'].numpy()

    # Handle colors - either provided or computed from SH coefficients
    if colors is not None:
        # Use provided colors (accept torch/warp arrays or plain numpy).
        colors_np = colors.numpy() if hasattr(colors, 'numpy') else colors
    else:
        # Compute colors from the SH DC term (every 16th row of shs),
        # vectorized instead of a per-point Python loop. The +0.5 offset
        # is the usual DC-band SH-to-RGB approximation.
        sh_dc_all = shs.reshape(num_points, 16, 3)[:, 0, :]
        colors_np = np.clip(sh_dc_all + 0.5, 0.0, 1.0).astype(np.float32)

    # Create vertex data
    vertex_data = []
    for i in range(num_points):
        # Basic properties. Scales go back to log space to mirror load_ply.
        # Opacity is flattened to a true scalar so both (N,) and (N, 1)
        # layouts work; a length-1 array inside the tuple would break the
        # structured-array construction below.
        vertex = (
            positions[i][0], positions[i][1], positions[i][2],
            np.log(scales[i][0]), np.log(scales[i][1]), np.log(scales[i][2]),
            float(np.ravel(opacities[i])[0])
        )

        # Add rotation quaternion elements
        quat = rotations[i]
        vertex += (quat[0], quat[1], quat[2], quat[3])

        # Add RGB colors (convert to 0-255 range)
        vertex += (
            int(np.clip(colors_np[i][0] * 255, 0, 255)),
            int(np.clip(colors_np[i][1] * 255, 0, 255)),
            int(np.clip(colors_np[i][2] * 255, 0, 255))
        )

        # Add SH DC coefficients
        vertex += tuple(shs[i * 16][j] for j in range(3))

        # Add remaining SH coefficients (coefficient-major order, matching
        # the reshape in load_ply).
        sh_rest = []
        for j in range(1, 16):
            for c in range(3):
                sh_rest.append(shs[i * 16 + j][c])
        vertex += tuple(sh_rest)

        vertex_data.append(vertex)

    # Define the structure of the PLY file; field order must match the
    # tuple construction above.
    vertex_type = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('scale_0', 'f4'), ('scale_1', 'f4'), ('scale_2', 'f4'),
        ('opacity', 'f4')
    ]

    # Add rotation quaternion elements
    vertex_type.extend([('rot_0', 'f4'), ('rot_1', 'f4'), ('rot_2', 'f4'), ('rot_3', 'f4')])

    # Add RGB color fields
    vertex_type.extend([('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])

    # Add SH coefficients
    vertex_type.extend([('f_dc_0', 'f4'), ('f_dc_1', 'f4'), ('f_dc_2', 'f4')])

    # Add remaining SH coefficients
    for i in range(45):  # 15 coeffs * 3 channels
        vertex_type.append((f'f_rest_{i}', 'f4'))

    vertex_array = np.array(vertex_data, dtype=vertex_type)
    el = PlyElement.describe(vertex_array, 'vertex')

    # Create the directory only when the path actually has one;
    # os.makedirs('') raises FileNotFoundError for bare filenames.
    dirname = os.path.dirname(filepath)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    # Save the PLY file
    PlyData([el], text=False).write(filepath)
    print(f"Point cloud saved to {filepath}")
gs/utils/wp_utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warp as wp
2
+ from config import DEVICE
3
+
4
+
5
@wp.func
def wp_vec3_mul_element(a: wp.vec3, b: wp.vec3) -> wp.vec3:
    # Hadamard (component-wise) product of two vec3s.
    px = a[0] * b[0]
    py = a[1] * b[1]
    pz = a[2] * b[2]
    return wp.vec3(px, py, pz)
8
+
9
# Element-wise square root of a vec3.
@wp.func
def wp_vec3_sqrt(a: wp.vec3) -> wp.vec3:
    rx = wp.sqrt(a[0])
    ry = wp.sqrt(a[1])
    rz = wp.sqrt(a[2])
    return wp.vec3(rx, ry, rz)
13
+
14
# Element-wise division of two vec3s.
@wp.func
def wp_vec3_div_element(a: wp.vec3, b: wp.vec3) -> wp.vec3:
    # A small epsilon on each denominator component guards against
    # division by zero (Adam's epsilon should mostly handle this already).
    dx = b[0] + 1e-9
    dy = b[1] + 1e-9
    dz = b[2] + 1e-9
    return wp.vec3(a[0] / dx, a[1] / dy, a[2] / dz)
21
+
22
@wp.func
def wp_vec3_add_element(a: wp.vec3, b: wp.vec3) -> wp.vec3:
    # Component-wise sum, kept for symmetry with the other helpers.
    sx = a[0] + b[0]
    sy = a[1] + b[1]
    sz = a[2] + b[2]
    return wp.vec3(sx, sy, sz)
25
+
26
@wp.func
def wp_vec3_clamp(x: wp.vec3, min_val: float, max_val: float) -> wp.vec3:
    # Clamp every component of x into [min_val, max_val].
    cx = wp.clamp(x[0], min_val, max_val)
    cy = wp.clamp(x[1], min_val, max_val)
    cz = wp.clamp(x[2], min_val, max_val)
    return wp.vec3(cx, cy, cz)
33
+
34
def to_warp_array(data, dtype, shape_check=None, flatten=False):
    """Convert array-like data (numpy / torch / warp) to a warp array on DEVICE.

    Args:
        data: Input array. Existing wp.arrays are returned unchanged;
            objects exposing .cpu() and .numpy() (torch tensors) are
            moved to host memory first. None passes through as None.
        dtype: Warp element dtype for the resulting array.
        shape_check: Currently unused; kept for interface compatibility.
        flatten: If True, squeeze (N, 1)-shaped inputs down to (N,).

    Returns:
        A wp.array on DEVICE, or None if data is None.
    """
    if data is None:
        return None
    if isinstance(data, wp.array):
        return data
    # Torch tensors expose both .cpu() and .numpy(); bring them to host.
    is_torch_like = hasattr(data, 'cpu') and hasattr(data, 'numpy')
    if is_torch_like:
        data = data.cpu().numpy()
    needs_squeeze = flatten and data.ndim == 2 and data.shape[1] == 1
    if needs_squeeze:
        data = data.flatten()
    return wp.array(data, dtype=dtype, device=DEVICE)
45
+
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Install PyTorch with CUDA 11.8 separately using:
2
+ # pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
3
+ # torch==2.5.1
4
+ # torchvision==0.20.1
5
+
6
+ warp-lang==1.7.0
7
+ numpy==1.26.3
8
+ imageio==2.34.1
9
+ plyfile
10
+ roma
11
+ gradio==5.17.1
12
+ pydantic==2.10.6
13
+ matplotlib==3.9.2
14
+ tqdm==4.66.5
15
+ opencv-python
16
+ pypng
17
+ scipy
18
+ einops
19
+ trimesh
20
+ pyglet<2
21
+ viser
22
+ jaxtyping
23
+ hydra-submitit-launcher
24
+ scikit-learn
25
+ plotly
26
+ git+https://github.com/facebookresearch/vggt.git@44b3afb
vdpm/.gitignore ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/
2
+ checkpoints/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ pip-wheel-metadata/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
vdpm/.gitmodules ADDED
File without changes
vdpm/.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
vdpm/LICENSE ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Eldar Insafutdinov, Edgar Sucar
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
vdpm/LICENSE-VGGT ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ VGGT License
2
+
3
+ v1 Last Updated: July 29, 2025
4
+
5
+ “Acceptable Use Policy” means the Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.
6
+
7
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.
8
+
9
+
10
+ “Documentation” means the specifications, manuals and documentation accompanying
11
+ Research Materials distributed by Meta.
12
+
13
+
14
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
15
+
16
+ “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
17
+ “Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.
18
+
19
+ By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.
20
+
21
+
22
+ 1. License Rights and Redistribution.
23
+
24
+
25
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.
26
+
27
+ b. Redistribution and Use.
28
+
29
+
30
+ i. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
31
+
32
+
33
+ ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.
34
+
35
+
36
+ iii. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
37
+ 2. User Support. Your use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
38
+
39
+
40
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
41
+
42
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
43
+
44
+ 5. Intellectual Property.
45
+
46
+
47
+ a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
48
+
49
+ b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
50
+
51
+ 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement.
52
+
53
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
54
+
55
+
56
+ 8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
57
+
58
+
59
+ Acceptable Use Policy
60
+
61
+ Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.
62
+
63
+ As part of this mission, Meta makes certain research materials available for use in accordance with this Agreement (including the Acceptable Use Policy). Meta is committed to promoting the safe and responsible use of such research materials.
64
+
65
+ Prohibited Uses
66
+
67
+ You agree you will not use, or allow others to use, Research Materials to:
68
+
69
+ Violate the law or others’ rights, including to:
70
+ Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
71
+ Violence or terrorism
72
+ Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
73
+ Human trafficking, exploitation, and sexual violence
74
+ The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
75
+ Sexual solicitation
76
+ Any other criminal activity
77
+
78
+ Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
79
+
80
+ Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
81
+
82
+ Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
83
+
84
+ Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
85
+
86
+ Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using Research Materials
87
+
88
+ Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
89
+
90
+ 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:
91
+
92
+ Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
93
+
94
+ Guns and illegal weapons (including weapon development)
95
+
96
+ Illegal drugs and regulated/controlled substances
97
+ Operation of critical infrastructure, transportation technologies, or heavy machinery
98
+
99
+ Self-harm or harm to others, including suicide, cutting, and eating disorders
100
+ Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
101
+
102
+ 3. Intentionally deceive or mislead others, including use of Research Materials related to the following:
103
+
104
+ Generating, promoting, or furthering fraud or the creation or promotion of disinformation
105
+ Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
106
+
107
+ Generating, promoting, or further distributing spam
108
+
109
+ Impersonating another individual without consent, authorization, or legal right
110
+
111
+ Representing that outputs of research materials or outputs from technology using Research Materials are human-generated
112
+
113
+ Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
114
+
115
+ 4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
vdpm/README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: vdpm
3
+ app_file: gradio_demo.py
4
+ sdk: gradio
5
+ sdk_version: 5.17.1
6
+ ---
7
+ <div align="center">
8
+ <h1>V-DPM: 4D Video Reconstruction with Dynamic Point Maps</h1>
9
+
10
+ <a href="https://www.robots.ox.ac.uk/~vgg/research/vdpm/"><img src="https://img.shields.io/badge/Project_Page-green" alt="Project Page"></a>
11
+ <a href="https://huggingface.co/spaces/edgarsucar/vdpm"><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>
12
+
13
+ **[Visual Geometry Group, University of Oxford](https://www.robots.ox.ac.uk/~vgg/)**
14
+
15
+
16
+ [Edgar Sucar](https://edgarsucar.github.io/)\*, [Eldar Insafutdinov](https://eldar.insafutdinov.com/)\*, [Zihang Lai](https://scholar.google.com/citations?user=31eXgMYAAAAJ), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/)
17
+ </div>
18
+
19
+ ## Setup
20
+
21
+ First, clone the repository and set up a virtual environment with [uv](https://github.com/astral-sh/uv):
22
+
23
+ ```bash
24
+ git clone git@github.com:eldar/vdpm.git
25
+ cd vdpm
26
+ uv venv --python 3.12
27
+ . .venv/bin/activate
28
+
29
+ # Install PyTorch with CUDA 11.8 first
30
+ uv pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
31
+
32
+ # Then install remaining dependencies
33
+ uv pip install -r requirements.txt
34
+ ```
35
+
36
+ ## Viser demo
37
+ ```bash
38
+ python visualise.py ++vis.input_video=examples/videos/camel.mp4
39
+ ```
40
+
41
+ ## Gradio demo
42
+ ```bash
43
+ python gradio_demo.py
44
+ ```
vdpm/check_model_size.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ # Add parent directory to path
6
+ sys.path.insert(0, str(Path(__file__).parent))
7
+
8
def check_model_memory():
    """Estimate VDPM model memory requirements for an 8 GB GPU.

    Builds the model (on CPU), counts its parameters, and prints rough
    weight/activation memory estimates for FP32/FP16/BF16/INT8 along
    with a fit recommendation. Purely informational; no return value.
    """
    # Simple config object: minimal stand-in for the config object the
    # VDPM constructor expects (only decoder_depth is read here).
    class SimpleConfig:
        class ModelConfig:
            decoder_depth = 4
        model = ModelConfig()

    cfg = SimpleConfig()

    # Import after path is set (sys.path is adjusted at module level).
    from dpm.model import VDPM

    # Create model on CPU first to count parameters
    print("Creating model...")
    model = VDPM(cfg)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\n{'='*60}")
    print(f"MODEL SIZE ANALYSIS FOR RTX 3070 Ti (8GB)")
    print(f"{'='*60}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    # Weight memory: bytes-per-parameter times parameter count.
    print(f"\nEstimated model weights memory:")
    print(f"  - FP32 (float32): {total_params * 4 / 1024**3:.2f} GB")
    print(f"  - FP16 (float16): {total_params * 2 / 1024**3:.2f} GB")
    print(f"  - BF16 (bfloat16): {total_params * 2 / 1024**3:.2f} GB")
    print(f"  - INT8 (quantized): {total_params * 1 / 1024**3:.2f} GB  <-- RECOMMENDED for 8GB GPU")

    # Estimate activation memory for typical input
    batch_size = 1
    num_frames = 5  # typical video length
    img_size = 518
    print(f"\nEstimated activation memory (batch={batch_size}, frames={num_frames}, img_size={img_size}):")

    # Input images: [B, S, 3, H, W] at 4 bytes per FP32 element.
    input_mem = batch_size * num_frames * 3 * img_size * img_size * 4 / 1024**3
    print(f"  - Input images (FP32): {input_mem:.2f} GB")

    # Rough estimate for activations (can be 2-4x model size during forward pass)
    activation_mem_estimate = total_params * 2 * 3 / 1024**3  # conservative estimate
    print(f"  - Activations (estimate): {activation_mem_estimate:.2f} GB")

    # Calculate total for different precision modes
    total_fp16 = (total_params * 2 / 1024**3) + input_mem + activation_mem_estimate
    total_int8 = (total_params * 1 / 1024**3) + input_mem + (activation_mem_estimate * 0.6)  # INT8 reduces activations too

    print(f"\nTotal estimated GPU memory needed:")
    print(f"  - With FP16/BF16: {total_fp16:.2f} GB")
    print(f"  - With INT8 quantization: {total_int8:.2f} GB  <-- FITS IN 8GB!")
    print(f"Your RTX 3070 Ti has: 8 GB VRAM")

    # Recommendation: prefer INT8 if it fits, warn when even INT8 is
    # tight, otherwise FP16 suffices.
    if total_int8 <= 8:
        print(f"\n✓ With INT8 quantization, model will fit in GPU memory!")
        print(f"  Set USE_QUANTIZATION = True in gradio_demo.py")
    elif total_fp16 > 8:
        print(f"\n⚠️ WARNING: Even with INT8 ({total_int8:.2f} GB), memory is tight")
        print(f"  Recommendations:")
        print(f"  1. Use INT8 quantization (USE_QUANTIZATION = True)")
        print(f"  2. Reduce number of input frames to {num_frames} or fewer")
        print(f"  3. Clear CUDA cache between batches")
    else:
        print(f"\n✓ Model should fit with FP16!")

    print(f"{'='*60}\n")

    # Check actual GPU memory if CUDA available
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
        print(f"Current GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Current GPU memory cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
82
+
83
# Run the memory analysis when executed as a script.
if __name__ == "__main__":
    check_model_memory()
85
+
vdpm/configs/config.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - hydra: defaults
4
+ - model: dpm
5
+
6
+ config:
7
+ exp_name: "debug"
8
+ file: "config.yaml"
9
+
10
+ data_loader:
11
+ batch_size: 2
12
+ num_workers: 8
13
+ dynamic_batch: false
14
+
15
+ train:
16
+ logging: true
17
+ num_gpus: 4
18
+ amp: bfloat16
19
+ amp_dpt: false
20
+ dry_run: false
21
+ camera_loss_lambda: 5.0
22
+
23
+ optimiser:
24
+ lr: 0.00005 # absolute lr
25
+ blr: 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
26
+ start_epoch:
27
+ epochs: 70
28
+ accum_iter: 1
29
+ warmup_epochs: 3
30
+ min_lr: 1e-06
31
+
32
+ run:
33
+ resume: false
34
+ dirpath: null
35
+ debug: false
36
+ random_seed: 42
37
+ git_hash: null
38
+ log_frequency: 250
39
+ training_progress_bar: false
40
+ save_freq: 5
41
+ eval_freq: 1
42
+ keep_freq: 5
43
+ print_freq: 20
44
+ num_keep_ckpts: 5
45
+ # Old Dust3r params
46
+ world_size: -1
47
+ local_rank: -1
48
+ dist_url: "env://"
49
+ seed: 0
50
+
vdpm/configs/model/dpm.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ name: dpm-video
2
+ pretrained: /work/eldar/models/vggt/VGGT-1B.pt
3
+ decoder_depth: 4
vdpm/configs/visualise.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - _self_
3
+ - model: dpm
4
+
5
+ hydra:
6
+ output_subdir: null # Disable saving of config files.
7
+ job:
8
+ chdir: False
9
+
10
+ vis:
11
+ port: 8080
12
+ input_video:
13
+
vdpm/dpm/aggregator.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE-VGGT file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.utils.checkpoint import checkpoint
12
+ from typing import Optional, Tuple, Union, List, Dict, Any
13
+
14
+ from vggt.layers import PatchEmbed
15
+ from vggt.layers.block import Block
16
+ from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
17
+ from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
22
+ _RESNET_STD = [0.229, 0.224, 0.225]
23
+
24
+
25
+ class Aggregator(nn.Module):
26
+ """
27
+ The Aggregator applies alternating-attention over input frames,
28
+ as described in VGGT: Visual Geometry Grounded Transformer.
29
+
30
+
31
+ Args:
32
+ img_size (int): Image size in pixels.
33
+ patch_size (int): Size of each patch for PatchEmbed.
34
+ embed_dim (int): Dimension of the token embeddings.
35
+ depth (int): Number of blocks.
36
+ num_heads (int): Number of attention heads.
37
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
38
+ num_register_tokens (int): Number of register tokens.
39
+ block_fn (nn.Module): The block type used for attention (Block by default).
40
+ qkv_bias (bool): Whether to include bias in QKV projections.
41
+ proj_bias (bool): Whether to include bias in the output projection.
42
+ ffn_bias (bool): Whether to include bias in MLP layers.
43
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
44
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
45
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
46
+ qk_norm (bool): Whether to apply QK normalization.
47
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
48
+ init_values (float): Init scale for layer scale.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ img_size=518,
54
+ patch_size=14,
55
+ embed_dim=1024,
56
+ depth=24,
57
+ num_heads=16,
58
+ mlp_ratio=4.0,
59
+ num_register_tokens=4,
60
+ block_fn=Block,
61
+ qkv_bias=True,
62
+ proj_bias=True,
63
+ ffn_bias=True,
64
+ patch_embed="dinov2_vitl14_reg",
65
+ aa_order=["frame", "global"],
66
+ aa_block_size=1,
67
+ qk_norm=True,
68
+ rope_freq=100,
69
+ init_values=0.01,
70
+ ):
71
+ super().__init__()
72
+
73
+ self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
74
+
75
+ # Initialize rotary position embedding if frequency > 0
76
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
77
+ self.position_getter = PositionGetter() if self.rope is not None else None
78
+
79
+ self.frame_blocks = nn.ModuleList(
80
+ [
81
+ block_fn(
82
+ dim=embed_dim,
83
+ num_heads=num_heads,
84
+ mlp_ratio=mlp_ratio,
85
+ qkv_bias=qkv_bias,
86
+ proj_bias=proj_bias,
87
+ ffn_bias=ffn_bias,
88
+ init_values=init_values,
89
+ qk_norm=qk_norm,
90
+ rope=self.rope,
91
+ )
92
+ for _ in range(depth)
93
+ ]
94
+ )
95
+
96
+ self.global_blocks = nn.ModuleList(
97
+ [
98
+ block_fn(
99
+ dim=embed_dim,
100
+ num_heads=num_heads,
101
+ mlp_ratio=mlp_ratio,
102
+ qkv_bias=qkv_bias,
103
+ proj_bias=proj_bias,
104
+ ffn_bias=ffn_bias,
105
+ init_values=init_values,
106
+ qk_norm=qk_norm,
107
+ rope=self.rope,
108
+ )
109
+ for _ in range(depth)
110
+ ]
111
+ )
112
+
113
+ self.depth = depth
114
+ self.aa_order = aa_order
115
+ self.patch_size = patch_size
116
+ self.aa_block_size = aa_block_size
117
+
118
+ # Validate that depth is divisible by aa_block_size
119
+ if self.depth % self.aa_block_size != 0:
120
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
121
+
122
+ self.aa_block_num = self.depth // self.aa_block_size
123
+
124
+ # Note: We have two camera tokens, one for the first frame and one for the rest
125
+ # The same applies for register tokens
126
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
127
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
128
+
129
+ # The patch tokens start after the camera and register tokens
130
+ self.patch_start_idx = 1 + num_register_tokens
131
+
132
+ self.time_conditioning_token = nn.Parameter(torch.randn(1, 1, embed_dim))
133
+ self.patch_start_idx += 1
134
+
135
+ # Initialize parameters with small values
136
+ nn.init.normal_(self.camera_token, std=1e-6)
137
+ nn.init.normal_(self.register_token, std=1e-6)
138
+
139
+ # Register normalization constants as buffers
140
+ for name, value in (
141
+ ("_resnet_mean", _RESNET_MEAN),
142
+ ("_resnet_std", _RESNET_STD),
143
+ ):
144
+ self.register_buffer(
145
+ name,
146
+ torch.FloatTensor(value).view(1, 1, 3, 1, 1),
147
+ persistent=False,
148
+ )
149
+
150
+ self.use_reentrant = False # hardcoded to False
151
+
152
+ def __build_patch_embed__(
153
+ self,
154
+ patch_embed,
155
+ img_size,
156
+ patch_size,
157
+ num_register_tokens,
158
+ interpolate_antialias=True,
159
+ interpolate_offset=0.0,
160
+ block_chunks=0,
161
+ init_values=1.0,
162
+ embed_dim=1024,
163
+ ):
164
+ """
165
+ Build the patch embed layer. If 'conv', we use a
166
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
167
+ """
168
+
169
+ if "conv" in patch_embed:
170
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
171
+ else:
172
+ vit_models = {
173
+ "dinov2_vitl14_reg": vit_large,
174
+ "dinov2_vitb14_reg": vit_base,
175
+ "dinov2_vits14_reg": vit_small,
176
+ "dinov2_vitg2_reg": vit_giant2,
177
+ }
178
+
179
+ self.patch_embed = vit_models[patch_embed](
180
+ img_size=img_size,
181
+ patch_size=patch_size,
182
+ num_register_tokens=num_register_tokens,
183
+ interpolate_antialias=interpolate_antialias,
184
+ interpolate_offset=interpolate_offset,
185
+ block_chunks=block_chunks,
186
+ init_values=init_values,
187
+ )
188
+
189
+ # Disable gradient updates for mask token
190
+ if hasattr(self.patch_embed, "mask_token"):
191
+ self.patch_embed.mask_token.requires_grad_(False)
192
+
193
+ def forward(
194
+ self,
195
+ images: torch.Tensor,
196
+ ) -> Tuple[List[torch.Tensor], int]:
197
+ """
198
+ Args:
199
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
200
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
201
+
202
+ Returns:
203
+ (list[torch.Tensor], int):
204
+ The list of outputs from the attention blocks,
205
+ and the patch_start_idx indicating where patch tokens begin.
206
+ """
207
+ B, S, C_in, H, W = images.shape
208
+
209
+ if C_in != 3:
210
+ raise ValueError(f"Expected 3 input channels, got {C_in}")
211
+
212
+ # Normalize images and reshape for patch embed
213
+ images = (images - self._resnet_mean) / self._resnet_std
214
+
215
+ # Reshape to [B*S, C, H, W] for patch embedding
216
+ images = images.view(B * S, C_in, H, W)
217
+ patch_tokens = self.patch_embed(images)
218
+
219
+ if isinstance(patch_tokens, dict):
220
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
221
+
222
+ _, P, C = patch_tokens.shape
223
+
224
+ # Expand camera and register tokens to match batch size and sequence length
225
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
226
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
227
+ # do something similar for time_conditioning_token
228
+ time_conditioning_token = slice_expand_and_flatten_single(self.time_conditioning_token, B, S)
229
+ # Concatenate special tokens with patch tokens
230
+ tokens = torch.cat([camera_token, time_conditioning_token, register_token, patch_tokens], dim=1)
231
+
232
+ pos = None
233
+ if self.rope is not None:
234
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
235
+
236
+ if self.patch_start_idx > 0:
237
+ # do not use position embedding for special tokens (camera and register tokens)
238
+ # so set pos to 0 for the special tokens
239
+ pos = pos + 1
240
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
241
+ pos = torch.cat([pos_special, pos], dim=1)
242
+
243
+ # update P because we added special tokens
244
+ _, P, C = tokens.shape
245
+
246
+ frame_idx = 0
247
+ global_idx = 0
248
+ output_list = []
249
+
250
+ for _ in range(self.aa_block_num):
251
+ for attn_type in self.aa_order:
252
+ if attn_type == "frame":
253
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
254
+ tokens, B, S, P, C, frame_idx, pos=pos
255
+ )
256
+ elif attn_type == "global":
257
+ tokens, global_idx, global_intermediates = self._process_global_attention(
258
+ tokens, B, S, P, C, global_idx, pos=pos
259
+ )
260
+ else:
261
+ raise ValueError(f"Unknown attention type: {attn_type}")
262
+
263
+ for i in range(len(frame_intermediates)):
264
+ # concat frame and global intermediates, [B x S x P x 2C]
265
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
266
+ output_list.append(concat_inter)
267
+
268
+ del concat_inter
269
+ del frame_intermediates
270
+ del global_intermediates
271
+ return output_list, self.patch_start_idx
272
+
273
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
274
+ """
275
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
276
+ """
277
+ # If needed, reshape tokens or positions:
278
+ if tokens.shape != (B * S, P, C):
279
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
280
+
281
+ if pos is not None and pos.shape != (B * S, P, 2):
282
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
283
+
284
+ intermediates = []
285
+
286
+ # by default, self.aa_block_size=1, which processes one block at a time
287
+ for _ in range(self.aa_block_size):
288
+ if self.training:
289
+ tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
290
+ else:
291
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
292
+ frame_idx += 1
293
+ intermediates.append(tokens.view(B, S, P, C))
294
+
295
+ return tokens, frame_idx, intermediates
296
+
297
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
298
+ """
299
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
300
+ """
301
+ if tokens.shape != (B, S * P, C):
302
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
303
+
304
+ if pos is not None and pos.shape != (B, S * P, 2):
305
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
306
+
307
+ intermediates = []
308
+
309
+ # by default, self.aa_block_size=1, which processes one block at a time
310
+ for _ in range(self.aa_block_size):
311
+ if self.training:
312
+ tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
313
+ else:
314
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
315
+ global_idx += 1
316
+ intermediates.append(tokens.view(B, S, P, C))
317
+
318
+ return tokens, global_idx, intermediates
319
+
320
+
321
+ def slice_expand_and_flatten(token_tensor, B, S):
322
+ """
323
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
324
+ 1) Uses the first position (index=0) for the first frame only
325
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
326
+ 3) Expands both to match batch size B
327
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
328
+ followed by (S-1) second-position tokens
329
+ 5) Flattens to (B*S, X, C) for processing
330
+
331
+ Returns:
332
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
333
+ """
334
+
335
+ # Slice out the "query" tokens => shape (1, 1, ...)
336
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
337
+ # Slice out the "other" tokens => shape (1, S-1, ...)
338
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
339
+ # Concatenate => shape (B, S, ...)
340
+ combined = torch.cat([query, others], dim=1)
341
+
342
+ # Finally flatten => shape (B*S, ...)
343
+ combined = combined.view(B * S, *combined.shape[2:])
344
+ return combined
345
+
346
+
347
+ def slice_expand_and_flatten_single(token_tensor, B, S):
348
+ """
349
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
350
+ 1) Uses the first position (index=0) for the first frame only
351
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
352
+ 3) Expands both to match batch size B
353
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
354
+ followed by (S-1) second-position tokens
355
+ 5) Flattens to (B*S, X, C) for processing
356
+
357
+ Returns:
358
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
359
+ """
360
+
361
+ # Slice out the "query" tokens => shape (1, 1, ...)
362
+ token = token_tensor.expand(B, S, *token_tensor.shape[2:])
363
+
364
+ # Finally flatten => shape (B*S, ...)
365
+ token = token.view(B * S, 1, *token.shape[2:])
366
+ return token
vdpm/dpm/decoder.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE-VGGT file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import torch
9
+ from torch import nn, Tensor
10
+ from torch.utils.checkpoint import checkpoint
11
+ from typing import List, Callable
12
+ from dataclasses import dataclass
13
+
14
+ from einops import repeat
15
+
16
+ from vggt.layers.block import drop_add_residual_stochastic_depth
17
+ from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
18
+
19
+ from vggt.layers.attention import Attention
20
+ from vggt.layers.drop_path import DropPath
21
+ from vggt.layers.layer_scale import LayerScale
22
+ from vggt.layers.mlp import Mlp
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ @dataclass
28
+ class ModulationOut:
29
+ shift: Tensor
30
+ scale: Tensor
31
+ gate: Tensor
32
+
33
+
34
+ class Modulation(nn.Module):
35
+ def __init__(self, dim: int, double: bool):
36
+ super().__init__()
37
+ self.is_double = double
38
+ self.multiplier = 6 if double else 3
39
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
40
+
41
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
42
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
43
+
44
+ return (
45
+ ModulationOut(*out[:3]),
46
+ ModulationOut(*out[3:]) if self.is_double else None,
47
+ )
48
+
49
+
50
+ class ConditionalBlock(nn.Module):
51
+ def __init__(
52
+ self,
53
+ dim: int,
54
+ num_heads: int,
55
+ mlp_ratio: float = 4.0,
56
+ qkv_bias: bool = True,
57
+ proj_bias: bool = True,
58
+ ffn_bias: bool = True,
59
+ drop: float = 0.0,
60
+ attn_drop: float = 0.0,
61
+ init_values=None,
62
+ drop_path: float = 0.0,
63
+ act_layer: Callable[..., nn.Module] = nn.GELU,
64
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
65
+ attn_class: Callable[..., nn.Module] = Attention,
66
+ ffn_layer: Callable[..., nn.Module] = Mlp,
67
+ qk_norm: bool = False,
68
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
69
+ rope=None,
70
+ ) -> None:
71
+ super().__init__()
72
+
73
+ self.norm1 = norm_layer(dim, elementwise_affine=False)
74
+ self.modulation = Modulation(dim, double=False)
75
+
76
+ self.attn = attn_class(
77
+ dim,
78
+ num_heads=num_heads,
79
+ qkv_bias=qkv_bias,
80
+ proj_bias=proj_bias,
81
+ attn_drop=attn_drop,
82
+ proj_drop=drop,
83
+ qk_norm=qk_norm,
84
+ fused_attn=fused_attn,
85
+ rope=rope,
86
+ )
87
+
88
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
89
+
90
+ self.norm2 = norm_layer(dim)
91
+ mlp_hidden_dim = int(dim * mlp_ratio)
92
+ self.mlp = ffn_layer(
93
+ in_features=dim,
94
+ hidden_features=mlp_hidden_dim,
95
+ act_layer=act_layer,
96
+ drop=drop,
97
+ bias=ffn_bias,
98
+ )
99
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
100
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
101
+
102
+ self.sample_drop_ratio = drop_path
103
+
104
+ def forward(self, x: Tensor, pos=None, cond=None, is_global=False) -> Tensor:
105
+ B, S = cond.shape[:2]
106
+ C = x.shape[-1]
107
+ if is_global:
108
+ P = x.shape[1] // S
109
+ cond = cond.view(B * S, C)
110
+ mod, _ = self.modulation(cond)
111
+
112
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
113
+ """
114
+ conditional attention following DiT implementation from Flux
115
+ https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py#L194-L239
116
+ """
117
+ def prepare_for_mod(y):
118
+ """reshape to modulate the patch tokens with correct conditioning one"""
119
+ return y.view(B, S, P, C).view(B * S, P, C) if is_global else y
120
+ def restore_after_mod(y):
121
+ """reshape back to global sequence"""
122
+ return y.view(B, S, P, C).view(B, S * P, C) if is_global else y
123
+
124
+ x = prepare_for_mod(x)
125
+ x = (1 + mod.scale) * self.norm1(x) + mod.shift
126
+ x = restore_after_mod(x)
127
+
128
+ x = self.attn(x, pos=pos)
129
+
130
+ x = prepare_for_mod(x)
131
+ x = mod.gate * x
132
+ x = restore_after_mod(x)
133
+
134
+ return x
135
+
136
+ def ffn_residual_func(x: Tensor) -> Tensor:
137
+ return self.ls2(self.mlp(self.norm2(x)))
138
+
139
+ if self.training and self.sample_drop_ratio > 0.1:
140
+ # the overhead is compensated only for a drop path rate larger than 0.1
141
+ x = drop_add_residual_stochastic_depth(
142
+ x,
143
+ pos=pos,
144
+ residual_func=attn_residual_func,
145
+ sample_drop_ratio=self.sample_drop_ratio,
146
+ )
147
+ x = drop_add_residual_stochastic_depth(
148
+ x,
149
+ residual_func=ffn_residual_func,
150
+ sample_drop_ratio=self.sample_drop_ratio,
151
+ )
152
+ elif self.training and self.sample_drop_ratio > 0.0:
153
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
154
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
155
+ else:
156
+ x = x + attn_residual_func(x, pos=pos)
157
+ x = x + ffn_residual_func(x)
158
+ return x
159
+
160
+
161
+ class Decoder(nn.Module):
162
+ """Attention blocks after encoder per DPT input feature
163
+ to generate point maps at a given time.
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ cfg,
169
+ dim_in: int,
170
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
171
+ patch_size=14,
172
+ embed_dim=1024,
173
+ depth=2,
174
+ num_heads=16,
175
+ mlp_ratio=4.0,
176
+ block_fn=ConditionalBlock,
177
+ qkv_bias=True,
178
+ proj_bias=True,
179
+ ffn_bias=True,
180
+ aa_order=["frame", "global"],
181
+ aa_block_size=1,
182
+ qk_norm=True,
183
+ rope_freq=100,
184
+ init_values=0.01,
185
+ ):
186
+ super().__init__()
187
+ self.cfg = cfg
188
+ self.intermediate_layer_idx = intermediate_layer_idx
189
+
190
+ self.depth = depth
191
+ self.aa_order = aa_order
192
+ self.patch_size = patch_size
193
+ self.aa_block_size = aa_block_size
194
+
195
+ # Validate that depth is divisible by aa_block_size
196
+ if self.depth % self.aa_block_size != 0:
197
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
198
+
199
+ self.aa_block_num = self.depth // self.aa_block_size
200
+
201
+ self.rope = (
202
+ RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
203
+ )
204
+ self.position_getter = PositionGetter() if self.rope is not None else None
205
+
206
+ self.dim_in = dim_in
207
+
208
+ self.old_decoder = False
209
+ if self.old_decoder:
210
+ self.frame_blocks = nn.ModuleList(
211
+ [
212
+ block_fn(
213
+ dim=embed_dim*2,
214
+ num_heads=num_heads,
215
+ mlp_ratio=mlp_ratio,
216
+ qkv_bias=qkv_bias,
217
+ proj_bias=proj_bias,
218
+ ffn_bias=ffn_bias,
219
+ init_values=init_values,
220
+ qk_norm=qk_norm,
221
+ rope=self.rope,
222
+ )
223
+ for _ in range(depth)
224
+ ]
225
+ )
226
+ self.global_blocks = nn.ModuleList(
227
+ [
228
+ block_fn(
229
+ dim=embed_dim*2,
230
+ num_heads=num_heads,
231
+ mlp_ratio=mlp_ratio,
232
+ qkv_bias=qkv_bias,
233
+ proj_bias=proj_bias,
234
+ ffn_bias=ffn_bias,
235
+ init_values=init_values,
236
+ qk_norm=qk_norm,
237
+ rope=self.rope,
238
+ )
239
+ for _ in range(depth)
240
+ ]
241
+ )
242
+ else:
243
+ depths = [depth]
244
+ self.frame_blocks = nn.ModuleList([
245
+ nn.ModuleList([
246
+ block_fn(
247
+ dim=embed_dim*2,
248
+ num_heads=num_heads,
249
+ mlp_ratio=mlp_ratio,
250
+ qkv_bias=qkv_bias,
251
+ proj_bias=proj_bias,
252
+ ffn_bias=ffn_bias,
253
+ init_values=init_values,
254
+ qk_norm=qk_norm,
255
+ rope=self.rope,
256
+ )
257
+ for _ in range(d)
258
+ ])
259
+ for d in depths
260
+ ])
261
+
262
+ self.global_blocks = nn.ModuleList([
263
+ nn.ModuleList([
264
+ block_fn(
265
+ dim=embed_dim*2,
266
+ num_heads=num_heads,
267
+ mlp_ratio=mlp_ratio,
268
+ qkv_bias=qkv_bias,
269
+ proj_bias=proj_bias,
270
+ ffn_bias=ffn_bias,
271
+ init_values=init_values,
272
+ qk_norm=qk_norm,
273
+ rope=self.rope,
274
+ )
275
+ for _ in range(d)
276
+ ])
277
+ for d in depths
278
+ ])
279
+
280
+ self.use_reentrant = False # hardcoded to False
281
+
282
+ def get_condition_tokens(
283
+ self,
284
+ aggregated_tokens_list: List[torch.Tensor],
285
+ cond_view_idxs: torch.Tensor
286
+ ):
287
+ # Use tokens from the last block for conditioning
288
+ tokens_last = aggregated_tokens_list[-1] # [B S N_tok D]
289
+ # Extract the camera tokens
290
+ cond_token_idx = 1
291
+ camera_tokens = tokens_last[:, :, [cond_token_idx]] # [B S D]
292
+
293
+ cond_view_idxs = cond_view_idxs.to(camera_tokens.device)
294
+ cond_view_idxs = repeat(
295
+ cond_view_idxs,
296
+ "b s -> b s c d",
297
+ c=camera_tokens.shape[2],
298
+ d=camera_tokens.shape[3],
299
+ )
300
+ cond_tokens = torch.gather(camera_tokens, 1, cond_view_idxs)
301
+
302
+ return cond_tokens
303
+
304
+ def forward(
305
+ self,
306
+ images: torch.Tensor,
307
+ aggregated_tokens_list: List[torch.Tensor],
308
+ patch_start_idx: int,
309
+ cond_view_idxs: torch.Tensor,
310
+ ):
311
+ B, S, _, H, W = images.shape
312
+
313
+ cond_tokens = self.get_condition_tokens(
314
+ aggregated_tokens_list, cond_view_idxs
315
+ )
316
+
317
+ input_tokens = []
318
+ for k, layer_idx in enumerate(self.intermediate_layer_idx):
319
+ layer_tokens = aggregated_tokens_list[layer_idx].clone()
320
+ input_tokens.append(layer_tokens)
321
+
322
+ _, _, P, C = input_tokens[0].shape
323
+
324
+ pos = None
325
+ if self.rope is not None:
326
+ pos = self.position_getter(
327
+ B * S, H // self.patch_size, W // self.patch_size, device=images.device
328
+ )
329
+ if patch_start_idx > 0:
330
+ # do not use position embedding for special tokens (camera and register tokens)
331
+ # so set pos to 0 for the special tokens
332
+ pos = pos + 1
333
+ pos_special = torch.zeros(B * S, patch_start_idx, 2).to(images.device).to(pos.dtype)
334
+ pos = torch.cat([pos_special, pos], dim=1)
335
+
336
+ frame_idx = 0
337
+ global_idx = 0
338
+ depth = len(self.frame_blocks[0])
339
+ N = len(input_tokens)
340
+ # stack all intermediate layer tokens along batch dimension
341
+ # they are all processed by the same decoder
342
+ s_tokens = torch.cat(input_tokens)
343
+ s_cond_tokens = torch.cat([cond_tokens] * N, dim=0)
344
+ s_pos = torch.cat([pos] * N, dim=0)
345
+
346
+ # perform time conditioned attention
347
+ for _ in range(depth):
348
+ for attn_type in self.aa_order:
349
+ token_idx = 0
350
+
351
+ if attn_type == "frame":
352
+ s_tokens, frame_idx, _ = self._process_frame_attention(
353
+ s_tokens, s_cond_tokens, B * N, S, P, C, frame_idx, pos=s_pos, token_idx=token_idx
354
+ )
355
+ elif attn_type == "global":
356
+ s_tokens, global_idx, _ = self._process_global_attention(
357
+ s_tokens, s_cond_tokens, B * N, S, P, C, global_idx, pos=s_pos, token_idx=token_idx
358
+ )
359
+ else:
360
+ raise ValueError(f"Unknown attention type: {attn_type}")
361
+ processed = [t.view(B, S, P, C) for t in s_tokens.split(B, dim=0)]
362
+
363
+ return processed
364
+
365
+ def _process_frame_attention(self, tokens, cond_tokens, B, S, P, C, frame_idx, pos=None, token_idx=0):
366
+ """
367
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
368
+ """
369
+ # If needed, reshape tokens or positions:
370
+ if tokens.shape != (B * S, P, C):
371
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
372
+
373
+ if pos is not None and pos.shape != (B * S, P, 2):
374
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
375
+
376
+ intermediates = []
377
+ # by default, self.aa_block_size=1, which processes one block at a time
378
+ for _ in range(self.aa_block_size):
379
+ if self.training:
380
+ tokens = checkpoint(self.frame_blocks[token_idx][frame_idx], tokens, pos, cond_tokens, use_reentrant=self.use_reentrant)
381
+ else:
382
+ if self.old_decoder:
383
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos, cond=cond_tokens)
384
+ else:
385
+ tokens = self.frame_blocks[0][frame_idx](tokens, pos=pos, cond=cond_tokens)
386
+
387
+ frame_idx += 1
388
+ intermediates.append(tokens.view(B, S, P, C))
389
+
390
+ return tokens, frame_idx, intermediates
391
+
392
+ def _process_global_attention(self, tokens, cond_tokens, B, S, P, C, global_idx, pos=None, token_idx=0):
393
+ """
394
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
395
+ """
396
+ if tokens.shape != (B, S * P, C):
397
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
398
+
399
+ if pos is not None and pos.shape != (B, S * P, 2):
400
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
401
+
402
+ intermediates = []
403
+
404
+ # by default, self.aa_block_size=1, which processes one block at a time
405
+ for _ in range(self.aa_block_size):
406
+ if self.training:
407
+ tokens = checkpoint(self.global_blocks[token_idx][global_idx], tokens, pos, cond_tokens, True, use_reentrant=self.use_reentrant)
408
+ else:
409
+ if self.old_decoder:
410
+ tokens = self.global_blocks[global_idx](tokens, pos=pos, cond=cond_tokens, is_global=True)
411
+ else:
412
+ tokens = self.global_blocks[0][global_idx](tokens, pos=pos, cond=cond_tokens, is_global=True)
413
+ global_idx += 1
414
+ intermediates.append(tokens.view(B, S, P, C))
415
+
416
+ return tokens, global_idx, intermediates
vdpm/dpm/model.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from vggt.heads.camera_head import CameraHead
5
+ from vggt.heads.dpt_head import DPTHead
6
+
7
+ from .aggregator import Aggregator
8
+ from .decoder import Decoder
9
+
10
+
11
+ def freeze_all_params(modules):
12
+ for module in modules:
13
+ try:
14
+ for n, param in module.named_parameters():
15
+ param.requires_grad = False
16
+ except AttributeError:
17
+ # module is directly a parameter
18
+ module.requires_grad = False
19
+
20
+
21
+ class VDPM(nn.Module):
22
+ def __init__(self, cfg, img_size=518, patch_size=14, embed_dim=1024):
23
+ super().__init__()
24
+ self.cfg = cfg
25
+
26
+ self.aggregator = Aggregator(
27
+ img_size=img_size,
28
+ patch_size=patch_size,
29
+ embed_dim=embed_dim,
30
+ )
31
+ self.decoder = Decoder(
32
+ cfg,
33
+ dim_in=2*embed_dim,
34
+ embed_dim=embed_dim,
35
+ depth=cfg.model.decoder_depth
36
+ )
37
+ self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
38
+
39
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
40
+ self.set_freeze()
41
+
42
+ def set_freeze(self):
43
+ to_be_frozen = [self.aggregator.patch_embed]
44
+ freeze_all_params(to_be_frozen)
45
+
46
+ def forward(
47
+ self,
48
+ views, autocast_dpt=None
49
+ ):
50
+ images = torch.stack([view["img"] for view in views], dim=1)
51
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images)
52
+
53
+ res_dynamic = dict()
54
+
55
+ if self.decoder is not None:
56
+ cond_view_idxs = torch.stack([view["view_idxs"][:, 1] for view in views], dim=1)
57
+ decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)
58
+
59
+ if autocast_dpt is None:
60
+ autocast_dpt = torch.amp.autocast("cuda", enabled=False)
61
+
62
+ with autocast_dpt:
63
+ pts3d, pts3d_conf = self.point_head(
64
+ aggregated_tokens_list, images, patch_start_idx
65
+ )
66
+
67
+ padded_decoded_tokens = [None] * len(aggregated_tokens_list)
68
+ for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
69
+ padded_decoded_tokens[layer_idx] = decoded_tokens[idx]
70
+ pts3d_dyn, pts3d_dyn_conf = self.point_head(
71
+ padded_decoded_tokens, images, patch_start_idx
72
+ )
73
+
74
+ res_dynamic |= {
75
+ "pts3d": pts3d_dyn,
76
+ "conf": pts3d_dyn_conf
77
+ }
78
+
79
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
80
+ res_dynamic |= {"pose_enc_list": pose_enc_list}
81
+
82
+ res_static = dict(
83
+ pts3d=pts3d,
84
+ conf=pts3d_conf
85
+ )
86
+ return res_static, res_dynamic
87
+
88
+ def inference(
89
+ self,
90
+ views,
91
+ images=None,
92
+ num_timesteps=None
93
+ ):
94
+ autocast_amp = torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
95
+
96
+ if images is None:
97
+ images = torch.stack([view["img"] for view in views], dim=1)
98
+
99
+ with autocast_amp:
100
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images)
101
+ S = images.shape[1]
102
+
103
+ # Determine number of timesteps to query
104
+ if num_timesteps is None:
105
+ # Default to S if not specified (legacy behavior)
106
+ # But if views has indices, try to infer max time
107
+ if views is not None and "view_idxs" in views[0]:
108
+ try:
109
+ all_idxs = torch.cat([v["view_idxs"][:, 1] for v in views])
110
+ num_timesteps = int(all_idxs.max().item()) + 1
111
+ except:
112
+ num_timesteps = S
113
+ else:
114
+ num_timesteps = S
115
+
116
+ predictions = dict()
117
+ pointmaps = []
118
+ ones = torch.ones(1, S, dtype=torch.int64)
119
+ for time_ in range(num_timesteps):
120
+ cond_view_idxs = ones * time_
121
+
122
+ with autocast_amp:
123
+ decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)
124
+ padded_decoded_tokens = [None] * len(aggregated_tokens_list)
125
+ for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
126
+ padded_decoded_tokens[layer_idx] = decoded_tokens[idx]
127
+
128
+ # ... existing code ...
129
+
130
+ pts3d, pts3d_conf = self.point_head(
131
+ padded_decoded_tokens, images, patch_start_idx
132
+ )
133
+
134
+ pointmaps.append(dict(
135
+ pts3d=pts3d,
136
+ conf=pts3d_conf
137
+ ))
138
+
139
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
140
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
141
+ predictions["pose_enc_list"] = pose_enc_list
142
+ predictions["pointmaps"] = pointmaps
143
+ return predictions
144
+
145
+ def load_state_dict(self, ckpt, is_VGGT_static=False, **kw):
146
+ # don't load these VGGT heads as not needed
147
+ exclude = ["depth_head", "track_head"]
148
+ ckpt = {k:v for k, v in ckpt.items() if k.split('.')[0] not in exclude}
149
+ return super().load_state_dict(ckpt, **kw)
vdpm/examples/videos/camel.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3db92c240efbd1b97a466565988a9a06687fd422086656dc0a29e12c5b99b9bb
3
+ size 1301172
vdpm/examples/videos/car.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd74efdb4d4d59fc17356fefa5dadd4c5b787641c98ce3172ecd8e5a180e76a6
3
+ size 1015132
vdpm/examples/videos/figure1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae285726e5d247e904bb1ea7887ee96733c0beea913b421abba39150a3299cd5
3
+ size 465850
vdpm/examples/videos/figure2.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b2b030dd564cffbb9b2795e7fcdf97fa50e3a518df5b71dfb3dfb36f431dfa4
3
+ size 516209
vdpm/examples/videos/figure3.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a4144a53f14bd2dc671376d26ecbb42b06c9b8810e1700f21a16d3e11dfbf5c
3
+ size 559096
vdpm/examples/videos/goldfish.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28912e59d0d9e6b20d26973efee4806e89e115c7f1e63aec7206384ac3d0bf78
3
+ size 668862
vdpm/examples/videos/horse.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8227c7d901a936aeab6a2b41f104dd17e5544315d4cde7dac37f5787319947e7
3
+ size 1223145