Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +26 -35
- .gitignore +40 -0
- 2403.20309v6.pdf +3 -0
- 2601.09499v1.pdf +3 -0
- README.md +113 -12
- app.py +751 -0
- gs/.gitattributes +2 -0
- gs/.gitignore +13 -0
- gs/backward.py +1084 -0
- gs/config.py +151 -0
- gs/create_training_video.py +73 -0
- gs/dataset_reader.py +1 -0
- gs/forward.py +804 -0
- gs/lib64 +1 -0
- gs/loss.py +303 -0
- gs/optimizer.py +399 -0
- gs/render.py +141 -0
- gs/scheduler.py +28 -0
- gs/train.py +1044 -0
- gs/train_colmap.py +1586 -0
- gs/train_vdpm.py +712 -0
- gs/training_progress.mp4 +3 -0
- gs/utils/analyze_scales.py +100 -0
- gs/utils/camera_utils.py +215 -0
- gs/utils/check_opacities.py +21 -0
- gs/utils/math_utils.py +111 -0
- gs/utils/plot_loss_log.py +35 -0
- gs/utils/point_cloud_utils.py +160 -0
- gs/utils/wp_utils.py +45 -0
- requirements.txt +26 -0
- vdpm/.gitignore +132 -0
- vdpm/.gitmodules +0 -0
- vdpm/.gradio/certificate.pem +31 -0
- vdpm/LICENSE +22 -0
- vdpm/LICENSE-VGGT +115 -0
- vdpm/README.md +44 -0
- vdpm/check_model_size.py +85 -0
- vdpm/configs/config.yaml +50 -0
- vdpm/configs/model/dpm.yaml +3 -0
- vdpm/configs/visualise.yaml +13 -0
- vdpm/dpm/aggregator.py +366 -0
- vdpm/dpm/decoder.py +416 -0
- vdpm/dpm/model.py +149 -0
- vdpm/examples/videos/camel.mp4 +3 -0
- vdpm/examples/videos/car.mp4 +3 -0
- vdpm/examples/videos/figure1.mp4 +3 -0
- vdpm/examples/videos/figure2.mp4 +3 -0
- vdpm/examples/videos/figure3.mp4 +3 -0
- vdpm/examples/videos/goldfish.mp4 +3 -0
- vdpm/examples/videos/horse.mp4 +3 -0
.gitattributes CHANGED
@@ -1,35 +1,26 @@
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+# Auto detect text files and perform LF normalization
+* text=auto
+2403.20309v6.pdf filter=lfs diff=lfs merge=lfs -text
+2601.09499v1.pdf filter=lfs diff=lfs merge=lfs -text
+gs/training_progress.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/camel.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/car.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/figure1.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/figure2.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/figure3.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/goldfish.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/horse.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/paragliding.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/pstudio.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/stroller.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/swing.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/tennis.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/examples/videos/tesla.mp4 filter=lfs diff=lfs merge=lfs -text
+vdpm/input_images_20260128_014417_015976/images/000000.png filter=lfs diff=lfs merge=lfs -text
+vdpm/input_images_20260128_014417_015976/images/000001.png filter=lfs diff=lfs merge=lfs -text
+vdpm/input_images_20260128_014417_015976/images/000002.png filter=lfs diff=lfs merge=lfs -text
+vdpm/input_images_20260128_014417_015976/images/000003.png filter=lfs diff=lfs merge=lfs -text
+vdpm/input_images_20260128_014417_015976/output_4d.npz filter=lfs diff=lfs merge=lfs -text
+vdpm/input_images_20260128_014417_015976/poses.npz filter=lfs diff=lfs merge=lfs -text
+vdpm/input_images_20260128_014417_015976/reconstruction_data.zip filter=lfs diff=lfs merge=lfs -text
+vdpm/input_images_20260128_014417_015976/tracks.npz filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,40 @@
# User requested ignores
output/
mv-video/

# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Virtual Environments
venv/
env/
ENV/
env.bak/
venv.bak/

# VS Code
.vscode/

# Gradio
.gradio/
2403.20309v6.pdf ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd8415f171a0353126dcb1029126f48b805c3c24f65706bcc930c55dfc5dcc2e
size 8417471
2601.09499v1.pdf ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:09bad1eec73fad7ab1cc4d9c4da01305d3c4cebb3094c06924cd4df088065738
size 13097134
README.md CHANGED
@@ -1,12 +1,113 @@
----
-title: 4dgs

---
title: 4dgs-dpm
app_file: app.py
sdk: gradio
sdk_version: 5.17.1
---
# DPM-Splat: Video → 4D Gaussian Splats

End-to-end pipeline combining **V-DPM** (Video Dynamic Point Maps) with **3D Gaussian Splatting** for dynamic 4D scene reconstruction from multi-view video.

## Features

- **Feed-forward reconstruction**: No per-scene optimization needed for the initial point cloud
- **Multi-view support**: 1-4 synchronized video inputs
- **Temporal consistency**: Dynamic point tracking across frames
- **Memory efficient**: BF16/FP16 quantization, flash attention support
- **Co-visibility filtering**: Reduces redundant points (InstantSplat-inspired)
- **Gradio demo**: Easy-to-use web interface

## Demo

Run the interactive demo:
```bash
python app.py
```

Or try the hosted version on [Hugging Face Spaces](https://huggingface.co/spaces/YOUR_USERNAME/dpm-splat).

## Installation

```bash
# Create environment
conda create -n 4dgs-dpm python=3.10
conda activate 4dgs-dpm

# Install PyTorch with CUDA
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118

# Install dependencies
pip install -r requirements.txt
```

## Usage

### Web Interface (Recommended)
```bash
python app.py
```
Upload videos, adjust settings, and download the results as a ZIP.

### Command Line
```bash
# Run VDPM inference
python vdpm/visualise.py --input mv-video/your-video --output output/vdpm

# Train 3DGS from VDPM output
python -m gs.train_vdpm --input output/vdpm --output output/splats --iterations 1000
```

## Pipeline

1. **Video Processing**: Extract and interleave frames from multi-view videos
2. **VDPM Inference**: Generate dynamic point maps and camera poses using the VGGT backbone
3. **3DGS Training**: Train per-frame Gaussian splats initialized from the point maps
4. **Animation Rendering**: Generate a GIF from an interpolated camera viewpoint (see the sketch below)
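For orientation, here is a minimal sketch of how these four stages chain together using the functions defined in `app.py`; the video paths, run directory, and progress callback are illustrative placeholders, and `run_pipeline` in the Gradio app wraps the same calls:

```python
from pathlib import Path

from app import (process_videos, run_vdpm_inference,
                 run_3dgs_training, render_animation_gif)
from train_vdpm import load_vdpm_data  # importable because app.py adds gs/ to sys.path

run_dir = Path("output/pipeline/run_demo")  # placeholder run directory
run_dir.mkdir(parents=True, exist_ok=True)
progress = lambda frac, desc=None: print(desc or f"{frac:.0%}")  # stand-in for gr.Progress

# 1. Video processing: decode and interleave the synchronized views
image_paths, num_views = process_videos(["cam0.mp4", "cam1.mp4"], run_dir)

# 2. VDPM inference: writes tracks.npz / poses.npz into run_dir
num_timesteps, num_views = run_vdpm_inference(run_dir, progress)

# 3. Per-frame 3DGS training: one splats/frame_XXXX.ply per timestep
ply_files = run_3dgs_training(run_dir, run_dir / "splats", 1000, 0.0, progress)

# 4. Animation rendering from the average camera
data = load_vdpm_data(str(run_dir))
render_animation_gif(ply_files, data, run_dir / "animation.gif", progress)
```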
## Output

The pipeline generates the following files (a loading sketch follows the list):
- `splats/frame_XXXX.ply` - Gaussian splat for each timestep
- `renders/` - Training progress images
- `animation.gif` - Rendered animation from average camera
- `tracks.npz` - 3D point tracks
- `poses.npz` - Camera poses
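The two `.npz` archives can be inspected directly; the keys below are the ones `app.py` writes (array shapes depend on the run, and the path is a placeholder):

```python
import numpy as np

run = "output/pipeline/run_TIMESTAMP"  # substitute a real run directory

tracks = np.load(f"{run}/tracks.npz")
print(tracks["world_points"].shape)       # per-timestep 3D point maps
print(tracks["world_points_conf"].shape)  # per-point confidence values
print(int(tracks["num_timesteps"]), int(tracks["num_views"]))

poses = np.load(f"{run}/poses.npz")
print(poses["pose_enc"].shape)  # encoded camera poses, decoded by gs/train_vdpm.py
```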
## Requirements

- NVIDIA GPU with 8GB+ VRAM (tested on RTX 3070 Ti)
- CUDA 11.8+
- Python 3.10+

## TO-DO

- [x] VGGT Quantization (BF16/FP16)
- [x] Co-visibility check to reduce points
- [x] Dynamic point tracking
- [x] Per-frame 3DGS training
- [x] Gradio demo with GIF rendering
- [ ] Flash Attention for VGGT
- [ ] Dynamic/Static segmentation
- [ ] 3DGS with dynamic deformation field
- [ ] 4DGS primitive support

## Citation

```bibtex
@misc{dpmsplat2026,
  title={DPM-Splat: Video to 4D Gaussian Splats via Dynamic Point Maps},
  author={Your Name},
  year={2026},
  url={https://github.com/YOUR_USERNAME/4dgs-dpm}
}
```

## Acknowledgements

- [VGGT](https://github.com/facebookresearch/vggt) - Visual Geometry Grounded Transformer
- [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting)
- [NVIDIA Warp](https://github.com/NVIDIA/warp)
app.py ADDED
@@ -0,0 +1,751 @@
"""
DPM-Splat: End-to-end pipeline for Video → 4D Gaussian Splats
Combines VDPM inference with 3DGS training in a single Gradio interface.
"""

import os
import sys
import shutil
import zipfile
import gc
import json
import glob
import time
from pathlib import Path
from datetime import datetime

import cv2
import numpy as np
import gradio as gr
import torch
import imageio

# Set memory optimization
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Add paths
sys.path.insert(0, str(Path(__file__).parent / "vdpm"))
sys.path.insert(0, str(Path(__file__).parent / "gs"))

# Check GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"✓ GPU: {gpu_name} ({gpu_mem:.1f} GB)")
else:
    print("⚠ No GPU detected - running on CPU (will be slow)")

# Configuration
VIDEO_SAMPLE_HZ = 1.0
MAX_FRAMES = 8 if device == "cuda" else 4

# Global model cache
_vdpm_model = None


def get_vdpm_model():
    """Load and cache the VDPM model"""
    global _vdpm_model

    if _vdpm_model is not None:
        print("✓ Using cached VDPM model")
        return _vdpm_model

    print("Loading VDPM model...")
    sys.stdout.flush()

    from hydra import compose, initialize
    from hydra.core.global_hydra import GlobalHydra
    from dpm.model import VDPM

    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()

    with initialize(config_path="vdpm/configs"):
        cfg = compose(config_name="visualise")

    model = VDPM(cfg).to(device)

    # Load weights
    cache_dir = os.path.expanduser("~/.cache/vdpm")
    os.makedirs(cache_dir, exist_ok=True)
    model_path = os.path.join(cache_dir, "vdpm_model.pt")

    _URL = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"

    if not os.path.exists(model_path):
        print("Downloading VDPM model...")
        sd = torch.hub.load_state_dict_from_url(_URL, file_name="vdpm_model.pt", progress=True, map_location=device)
        torch.save(sd, model_path)
    else:
        print(f"✓ Loading cached model from {model_path}")
        sd = torch.load(model_path, map_location=device)

    model.load_state_dict(sd, strict=True)
    model.eval()

    # Use half precision
    if device == "cuda":
        if torch.cuda.get_device_capability()[0] >= 8:
            model = model.to(torch.bfloat16)
            print("✓ Using BF16 precision")
        else:
            model = model.half()
            print("✓ Using FP16 precision")

    _vdpm_model = model
    return model


def process_videos(video_files, target_dir):
    """Extract and interleave frames from uploaded videos"""
    images_dir = target_dir / "images"
    images_dir.mkdir(parents=True, exist_ok=True)

    num_views = len(video_files)
    captures = []
    intervals = []

    for vid_obj in video_files:
        video_path = vid_obj.name if hasattr(vid_obj, 'name') else str(vid_obj)
        vs = cv2.VideoCapture(video_path)
        fps = float(vs.get(cv2.CAP_PROP_FPS) or 30.0)
        interval = max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1)
        captures.append(vs)
        intervals.append(interval)

    # Interleave frames
    frame_num = 0
    step_count = 0
    active = True
    image_paths = []

    while active:
        active = False
        for i, vs in enumerate(captures):
            if not vs.isOpened():
                continue
            ret, frame = vs.read()
            if ret:
                active = True
                if step_count % intervals[i] == 0:
                    out_path = images_dir / f"{frame_num:06d}.png"
                    cv2.imwrite(str(out_path), frame)
                    image_paths.append(str(out_path))
                    frame_num += 1
            else:
                vs.release()
        step_count += 1

    for vs in captures:
        if vs.isOpened():
            vs.release()

    # Save metadata
    meta = {"num_views": num_views}
    with open(target_dir / "meta.json", "w") as f:
        json.dump(meta, f)

    return image_paths, num_views


def run_vdpm_inference(target_dir, progress):
    """Run VDPM inference"""
    from vggt.utils.load_fn import load_and_preprocess_images

    model = get_vdpm_model()

    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    if not image_names:
        raise ValueError("No images found")

    # Load metadata
    meta_path = target_dir / "meta.json"
    num_views = 1
    if meta_path.exists():
        with open(meta_path) as f:
            num_views = json.load(f).get("num_views", 1)

    # Limit frames
    if len(image_names) > MAX_FRAMES:
        limit = (MAX_FRAMES // num_views) * num_views
        if limit == 0:
            limit = num_views
        print(f"⚠ Limiting to {limit} frames")
        image_names = image_names[:limit]

    progress(0.15, desc=f"Loading {len(image_names)} images...")
    images = load_and_preprocess_images(image_names).to(device)

    # Construct views
    views = []
    for i in range(len(image_names)):
        t_idx = i // num_views
        cam_idx = i % num_views
        views.append({
            "img": images[i].unsqueeze(0),
            "view_idxs": torch.tensor([[cam_idx, t_idx]], device=device, dtype=torch.long)
        })

    progress(0.2, desc="Running VDPM forward pass...")
    print(f"Running inference on {len(image_names)} images...")
    sys.stdout.flush()

    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            predictions = model.inference(views=views)

    # Extract results
    pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]]

    pose_enc = None
    if "pose_enc" in predictions:
        pose_enc = predictions["pose_enc"].detach().cpu().numpy()

    del predictions
    torch.cuda.empty_cache()

    world_points_raw = np.concatenate(pts_list, axis=0)
    world_points_conf_raw = np.concatenate(conf_list, axis=0)

    T = world_points_raw.shape[0]
    S = world_points_raw.shape[1]
    num_timesteps = T

    # Process multi-view
    if num_views > 1 and S == num_views * T:
        world_points_list = []
        world_points_conf_list = []
        for t in range(T):
            start_idx = t * num_views
            end_idx = start_idx + num_views
            world_points_list.append(world_points_raw[t, start_idx:end_idx])
            world_points_conf_list.append(world_points_conf_raw[t, start_idx:end_idx])
        world_points = np.stack(world_points_list, axis=0)
        world_points_conf = np.stack(world_points_conf_list, axis=0)
    else:
        if world_points_raw.ndim == 5 and world_points_raw.shape[0] == 1:
            world_points = world_points_raw[0]
            world_points_conf = world_points_conf_raw[0]
        else:
            world_points = world_points_raw
            world_points_conf = world_points_conf_raw

    progress(0.35, desc="Saving VDPM outputs...")

    # Save outputs
    np.savez_compressed(
        target_dir / "tracks.npz",
        world_points=world_points,
        world_points_conf=world_points_conf,
        num_views=num_views,
        num_timesteps=num_timesteps
    )

    if pose_enc is not None:
        np.savez_compressed(target_dir / "poses.npz", pose_enc=pose_enc)

    print(f"✓ VDPM complete: {num_timesteps} timesteps, {num_views} views")
    sys.stdout.flush()

    return num_timesteps, num_views


def run_3dgs_training(target_dir, output_dir, iterations, conf_threshold, progress):
    """Run 3DGS training"""
    import warp as wp
    from train_vdpm import load_vdpm_data, VDPM3DGSTrainer
    from forward import render_gaussians
    from loss import l1_loss, compute_image_gradients
    from backward import backward
    from optimizer import adam_update
    from config import DEVICE

    wp.init()

    data = load_vdpm_data(str(target_dir))
    num_timesteps = data['T']

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    all_ply_files = []

    for frame_idx in range(num_timesteps):
        frame_progress = 0.4 + (0.5 * frame_idx / num_timesteps)
        progress(frame_progress, desc=f"Training frame {frame_idx + 1}/{num_timesteps}...")

        print(f"\n{'='*50}")
        print(f"[Frame {frame_idx + 1}/{num_timesteps}]")
        print(f"{'='*50}")
        sys.stdout.flush()

        trainer = VDPM3DGSTrainer(
            data=data,
            frame_idx=frame_idx,
            output_path=str(output_path),
            conf_threshold=conf_threshold
        )

        # Training loop with progress
        print(f"Training for {iterations} iterations...")
        sys.stdout.flush()

        trainer.save(0)  # Initial state

        for it in range(iterations):
            trainer.zero_grad()

            cam_idx = np.random.randint(len(trainer.cameras))
            camera = trainer.cameras[cam_idx]
            target = trainer.images[cam_idx]

            rendered, depth, trainer.intermediate_buffers = render_gaussians(
                background=np.array(trainer.config['background_color'], dtype=np.float32),
                means3D=trainer.params['positions'].numpy(),
                colors=None,
                opacity=trainer.params['opacities'].numpy(),
                scales=trainer.params['scales'].numpy(),
                rotations=trainer.params['rotations'].numpy(),
                scale_modifier=1.0,
                viewmatrix=camera['world_to_camera'],
                projmatrix=camera['full_proj_matrix'],
                tan_fovx=camera['tan_fovx'],
                tan_fovy=camera['tan_fovy'],
                image_height=camera['height'],
                image_width=camera['width'],
                sh=trainer.params['shs'].numpy(),
                degree=3,
                campos=camera['camera_center'],
                prefiltered=False,
                antialiasing=True,
            )

            target_wp = wp.array(target.astype(np.float32), dtype=wp.vec3, device=DEVICE)
            loss = l1_loss(rendered, target_wp)
            trainer.losses.append(loss)

            pixel_grad_buffer = compute_image_gradients(rendered, target_wp, lambda_dssim=0)

            view_matrix = wp.mat44(camera['world_to_camera'].flatten())
            proj_matrix = wp.mat44(camera['full_proj_matrix'].flatten())
            campos = wp.vec3(camera['camera_center'][0], camera['camera_center'][1], camera['camera_center'][2])

            geom_buffer = {
                'radii': trainer.intermediate_buffers['radii'],
                'means2D': trainer.intermediate_buffers['points_xy_image'],
                'conic_opacity': trainer.intermediate_buffers['conic_opacity'],
                'rgb': trainer.intermediate_buffers['colors'],
                'clamped': trainer.intermediate_buffers['clamped_state']
            }
            binning_buffer = {'point_list': trainer.intermediate_buffers['point_list']}
            img_buffer = {
                'ranges': trainer.intermediate_buffers['ranges'],
                'final_Ts': trainer.intermediate_buffers['final_Ts'],
                'n_contrib': trainer.intermediate_buffers['n_contrib']
            }

            gradients = backward(
                background=np.array(trainer.config['background_color'], dtype=np.float32),
                means3D=trainer.params['positions'],
                dL_dpixels=pixel_grad_buffer,
                opacity=trainer.params['opacities'],
                shs=trainer.params['shs'],
                scales=trainer.params['scales'],
                rotations=trainer.params['rotations'],
                scale_modifier=trainer.config['scale_modifier'],
                viewmatrix=view_matrix,
                projmatrix=proj_matrix,
                tan_fovx=camera['tan_fovx'],
                tan_fovy=camera['tan_fovy'],
                image_height=camera['height'],
                image_width=camera['width'],
                campos=campos,
                radii=trainer.intermediate_buffers['radii'],
                means2D=trainer.intermediate_buffers['points_xy_image'],
                conic_opacity=trainer.intermediate_buffers['conic_opacity'],
                rgb=trainer.intermediate_buffers['colors'],
                cov3Ds=trainer.intermediate_buffers['cov3Ds'],
                clamped=trainer.intermediate_buffers['clamped_state'],
                geom_buffer=geom_buffer,
                binning_buffer=binning_buffer,
                img_buffer=img_buffer,
                degree=trainer.config['sh_degree'],
                debug=False
            )

            wp.copy(trainer.grads['positions'], gradients['dL_dmean3D'])
            wp.copy(trainer.grads['scales'], gradients['dL_dscale'])
            wp.copy(trainer.grads['rotations'], gradients['dL_drot'])
            wp.copy(trainer.grads['opacities'], gradients['dL_dopacity'])
            wp.copy(trainer.grads['shs'], gradients['dL_dshs'])

            lr = 0.001 * (0.1 ** (it / iterations))
            wp.launch(adam_update, dim=trainer.num_points, inputs=[
                trainer.params['positions'], trainer.params['scales'],
                trainer.params['rotations'], trainer.params['opacities'], trainer.params['shs'],
                trainer.grads['positions'], trainer.grads['scales'],
                trainer.grads['rotations'], trainer.grads['opacities'], trainer.grads['shs'],
                trainer.adam_m['positions'], trainer.adam_m['scales'],
                trainer.adam_m['rotations'], trainer.adam_m['opacities'], trainer.adam_m['shs'],
                trainer.adam_v['positions'], trainer.adam_v['scales'],
                trainer.adam_v['rotations'], trainer.adam_v['opacities'], trainer.adam_v['shs'],
                trainer.num_points, lr, lr*5, lr*5, lr*2, lr*5,
                0.9, 0.999, 1e-8, it
            ])

            # Progress logging
            if (it + 1) % 100 == 0:
                print(f"  Iter {it+1}/{iterations} | Loss: {loss:.4f}")
                sys.stdout.flush()

            # Checkpoints
            if (it + 1) % 500 == 0 or it == iterations - 1:
                trainer.save(it + 1)

        ply_path = trainer.save_final()
        all_ply_files.append(str(ply_path))
        print(f"✓ Frame {frame_idx} complete: {ply_path}")
        sys.stdout.flush()

    return all_ply_files


def render_animation_gif(ply_files, data, output_path, progress, fps=10):
    """
    Render a GIF animation from an average camera position across all frames.

    Args:
        ply_files: List of PLY file paths for each frame
        data: VDPM data dict with camera info
        output_path: Path to save the GIF
        progress: Gradio progress callback
        fps: Frames per second for GIF
    """
    import warp as wp
    from forward import render_gaussians
    from utils.point_cloud_utils import load_ply
    from utils.math_utils import projection_matrix
    from train_vdpm import decode_poses

    if not ply_files:
        return None

    print("Rendering animation GIF...")
    sys.stdout.flush()

    # Get image dimensions
    images = data['images']
    img_H, img_W = images.shape[1:3]

    # Decode poses to get all cameras
    pose_enc = data.get('pose_enc')
    if pose_enc is not None:
        extrinsics, intrinsics = decode_poses(pose_enc, (img_H, img_W))
    else:
        # Fallback
        N = data['T'] * data['V']
        extrinsics = np.tile(np.eye(4, dtype=np.float32), (N, 1, 1))
        fx = fy = max(img_H, img_W)
        K = np.array([[fx, 0, img_W/2], [0, fy, img_H/2], [0, 0, 1]], dtype=np.float32)
        intrinsics = np.tile(K, (N, 1, 1))

    # Compute average camera position
    camera_centers = []
    for i in range(len(extrinsics)):
        R = extrinsics[i][:3, :3]
        t = extrinsics[i][:3, 3]
        center = -R.T @ t
        camera_centers.append(center)

    avg_center = np.mean(camera_centers, axis=0)

    # Use first camera's orientation and intrinsics as base
    R = extrinsics[0][:3, :3]
    intrinsic = intrinsics[0]
    fx, fy = intrinsic[0, 0], intrinsic[1, 1]

    # Compute translation for average position
    t = -R @ avg_center

    # Build camera matrices (transposed for Warp/OpenGL)
    world_to_camera = np.eye(4, dtype=np.float32)
    world_to_camera[:3, :3] = R
    world_to_camera[:3, 3] = t
    world_to_camera = world_to_camera.T

    fov_x = 2 * np.arctan(img_W / (2 * fx))
    fov_y = 2 * np.arctan(img_H / (2 * fy))

    proj_matrix = projection_matrix(fovx=fov_x, fovy=fov_y, znear=0.01, zfar=100.0).T
    full_proj_matrix = world_to_camera @ proj_matrix

    tan_fovx = np.tan(fov_x / 2)
    tan_fovy = np.tan(fov_y / 2)

    # Render each frame
    rendered_frames = []
    background = np.array([1.0, 1.0, 1.0], dtype=np.float32)  # White background

    for i, ply_path in enumerate(ply_files):
        if not Path(ply_path).exists():
            continue

        progress(0.9 + 0.05 * (i / len(ply_files)), desc=f"Rendering GIF frame {i+1}/{len(ply_files)}...")

        # Load PLY
        ply_data = load_ply(ply_path)

        positions = ply_data['positions']
        scales = ply_data['scales']
        rotations = ply_data['rotations']
        opacities = ply_data['opacities']
        shs = ply_data['shs']

        # Render
        rendered, _, _ = render_gaussians(
            background=background,
            means3D=positions,
            colors=None,
            opacity=opacities,
            scales=scales,
            rotations=rotations,
            scale_modifier=1.0,
            viewmatrix=world_to_camera,
            projmatrix=full_proj_matrix,
            tan_fovx=tan_fovx,
            tan_fovy=tan_fovy,
            image_height=img_H,
            image_width=img_W,
            sh=shs,
            degree=3,
            campos=avg_center,
            prefiltered=False,
            antialiasing=True,
        )

        # Convert to numpy
        rendered_np = wp.to_torch(rendered).cpu().numpy()
        rendered_np = np.clip(rendered_np * 255, 0, 255).astype(np.uint8)
        rendered_frames.append(rendered_np)

    if not rendered_frames:
        return None

    # Save GIF
    gif_path = Path(output_path)
    imageio.mimsave(str(gif_path), rendered_frames, fps=fps, loop=0)
    print(f"✓ Animation GIF saved: {gif_path}")
    sys.stdout.flush()

    return str(gif_path)


def run_pipeline(video_files, iterations, conf_threshold, progress=gr.Progress()):
    """Run the full VDPM → 3DGS pipeline"""

    if not video_files:
        return None, None, None, "❌ Please upload video file(s)"

    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = Path(f"output/pipeline/run_{timestamp}")
    run_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Step 1: Process videos
        progress(0.05, desc="Processing uploaded videos...")
        print("=" * 50)
        print("Processing Videos")
        print("=" * 50)
        sys.stdout.flush()

        image_paths, num_views = process_videos(video_files, run_dir)
        print(f"✓ Extracted {len(image_paths)} frames from {num_views} videos")
        sys.stdout.flush()

        # Step 2: VDPM inference
        progress(0.1, desc="Running VDPM inference...")
        print("=" * 50)
        print("Running VDPM Inference")
        print("=" * 50)
        sys.stdout.flush()

        num_timesteps, num_views = run_vdpm_inference(run_dir, progress)

        # Clear VRAM before 3DGS training
        global _vdpm_model
        _vdpm_model = None
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
            print(f"✓ Cleared VRAM: {torch.cuda.memory_allocated()/1024**3:.2f} GB in use")
        sys.stdout.flush()

        # Step 3: 3DGS training
        progress(0.4, desc="Training 3D Gaussian Splats...")
        print("=" * 50)
        print("Training 3D Gaussian Splats")
        print("=" * 50)
        sys.stdout.flush()

        splat_dir = run_dir / "splats"
        all_ply_files = run_3dgs_training(
            run_dir, splat_dir, int(iterations), float(conf_threshold), progress
        )

        # Step 4: Render animation GIF from average camera
        progress(0.9, desc="Rendering animation GIF...")
        print("=" * 50)
        print("Rendering Animation GIF")
        print("=" * 50)
        sys.stdout.flush()

        gif_path = None
        if all_ply_files:
            from train_vdpm import load_vdpm_data
            data = load_vdpm_data(str(run_dir))
            gif_path = render_animation_gif(
                all_ply_files, data, run_dir / "animation.gif", progress
            )

        # Step 5: Package results
        progress(0.95, desc="Packaging results...")

        zip_path = run_dir / "results.zip"
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            # Add PLY files
            for ply in all_ply_files:
                if ply and Path(ply).exists():
                    zf.write(ply, f"splats/{Path(ply).name}")

            # Add all checkpoint renders
            for render_dir in splat_dir.glob("frame_*/iter_*"):
                for img in render_dir.glob("*.png"):
                    rel_path = img.relative_to(splat_dir)
                    zf.write(img, f"renders/{rel_path}")

            # Add VDPM data to root
            for f in ["tracks.npz", "poses.npz", "meta.json"]:
                fp = run_dir / f
                if fp.exists():
                    zf.write(fp, f)

            # Add input images
            images_dir = run_dir / "images"
            if images_dir.exists():
                for img in images_dir.glob("*"):
                    zf.write(img, f"images/{img.name}")

            # Add animation GIF
            if gif_path and Path(gif_path).exists():
                zf.write(gif_path, "animation.gif")

        progress(1.0, desc="Complete!")

        # Return first PLY for preview
        preview_ply = all_ply_files[0] if all_ply_files else None

        status = f"""✅ Pipeline Complete!

📊 Results:
• {len(all_ply_files)} PLY files generated
• {num_timesteps} timesteps × {num_views} views
• Animation GIF rendered

📁 Output: {run_dir}
📦 Download the ZIP for all files"""

        return preview_ply, str(zip_path), gif_path, status

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, None, f"❌ Error: {str(e)}"


# ===== Gradio Interface =====
with gr.Blocks(title="DPM-Splat: 4D Gaussian Splatting", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🎬 DPM-Splat: Video → 4D Gaussian Splats

    End-to-end pipeline combining **V-DPM** (Video Dynamic Point Maps) with **3D Gaussian Splatting**.
    Upload multi-view synchronized videos to generate temporally consistent 4D reconstructions.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.File(
                label="📹 Upload Videos",
                file_count="multiple",
                file_types=[".mp4", ".mov", ".avi", ".webm"]
            )

            gr.Markdown("*Upload 1-4 synchronized video files for best results*")

            with gr.Accordion("⚙️ Settings", open=True):
                iterations = gr.Slider(
                    minimum=0, maximum=10000, value=1000, step=100,
                    label="Training Iterations",
                    info="0 = export raw point cloud only, more = better quality"
                )
                conf_threshold = gr.Slider(
                    minimum=0, maximum=100, value=0, step=5,
                    label="Confidence Threshold (%)",
                    info="0% keeps all points, higher = filter low confidence"
                )

            run_btn = gr.Button("🚀 Run Pipeline", variant="primary", size="lg")

            status_text = gr.Textbox(
                label="Status",
                interactive=False,
                lines=6,
                value="Upload videos and click 'Run Pipeline' to begin."
            )

        with gr.Column(scale=2):
            with gr.Row():
                model_viewer = gr.Model3D(
                    label="3D Preview (First Frame)",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    height=400
                )
                gif_viewer = gr.Image(
                    label="🎞️ Animation (Average Camera)",
                    height=400
                )
            download_btn = gr.File(label="📦 Download Results (ZIP)")

            gr.Markdown("""
            ---
            ### 📋 Output Contents

            The downloaded ZIP contains:
            - `splats/frame_XXXX.ply` - Gaussian splat for each timestep
            - `renders/` - Training progress images (target vs rendered)
            - `animation.gif` - Rendered animation from average camera
            - `tracks.npz` - 3D point tracks
            - `poses.npz` - Camera poses
            - `images/` - Input frames

            **Local runs**: Results saved to `output/pipeline/run_TIMESTAMP/`
            """)

    run_btn.click(
        fn=run_pipeline,
        inputs=[video_input, iterations, conf_threshold],
        outputs=[model_viewer, download_btn, gif_viewer, status_text]
    )

if __name__ == "__main__":
    app.queue().launch(share=True, show_error=True)
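Since the Gradio UI funnels everything through `run_pipeline`, the app can also be driven headlessly; a minimal sketch, where the video paths and the progress callback are placeholders:

```python
from app import run_pipeline

progress = lambda frac, desc=None: print(desc or f"{frac:.0%}")  # replaces gr.Progress()
preview_ply, zip_path, gif_path, status = run_pipeline(
    ["cam0.mp4", "cam1.mp4"], iterations=1000, conf_threshold=0.0, progress=progress
)
print(status)  # summary string; zip_path bundles splats, renders, and the GIF
```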
gs/.gitattributes ADDED
@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto
gs/.gitignore ADDED
@@ -0,0 +1,13 @@
# Created by venv; see https://docs.python.org/3/library/venv.html

data/*
output/*
lib/*
lib64/*
data_/*
colmap_0/*
bin/*
share/*
__pycache__/*
utils/__pycache__/*
.DS_Store
gs/backward.py ADDED
@@ -0,0 +1,1084 @@
import warp as wp
import math
from utils.wp_utils import to_warp_array, wp_vec3_mul_element, wp_vec3_add_element, wp_vec3_sqrt, wp_vec3_div_element, wp_vec3_clamp
from config import *  # Assuming TILE_M, TILE_N, VEC6, DEVICE are defined here

# Initialize Warp if not already done elsewhere
# wp.init()

# --- Spherical Harmonics Constants ---
SH_C0 = 0.28209479177387814
SH_C1 = 0.4886025119029199

@wp.func
def dnormvdv(v: wp.vec3, dv: wp.vec3) -> wp.vec3:
    """
    Computes the gradient of normalize(v) with respect to v, scaled by dv.
    This is a direct port of the CUDA implementation.

    Args:
        v: The input vector to be normalized
        dv: The gradient vector to scale the result by

    Returns:
        The gradient vector
    """
    sum2 = v[0] * v[0] + v[1] * v[1] + v[2] * v[2]

    # Avoid division by zero
    if sum2 < 1e-10:
        return wp.vec3(0.0, 0.0, 0.0)

    invsum32 = 1.0 / wp.sqrt(sum2 * sum2 * sum2)

    result = wp.vec3(
        ((sum2 - v[0] * v[0]) * dv[0] - v[1] * v[0] * dv[1] - v[2] * v[0] * dv[2]) * invsum32,
        (-v[0] * v[1] * dv[0] + (sum2 - v[1] * v[1]) * dv[1] - v[2] * v[1] * dv[2]) * invsum32,
        (-v[0] * v[2] * dv[0] - v[1] * v[2] * dv[1] + (sum2 - v[2] * v[2]) * dv[2]) * invsum32
    )

    return result
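# Illustrative aside (not part of the committed file): dnormvdv applies the
# Jacobian of normalize(v), J = (|v|^2 I - v v^T) / |v|^3, to dv. Since J is
# symmetric, it can be sanity-checked against a central difference of
# f(v) = dot(normalize(v), dv), e.g. in NumPy:
#
#   import numpy as np
#   v, dv = np.array([0.3, -1.2, 0.8]), np.array([0.5, 0.1, -0.7])
#   sum2 = v @ v
#   analytic = (sum2 * dv - v * (v @ dv)) / sum2 ** 1.5  # same formula as above
#   eps = 1e-6
#   numeric = np.array([
#       ((v + eps * e) / np.linalg.norm(v + eps * e) @ dv
#        - (v - eps * e) / np.linalg.norm(v - eps * e) @ dv) / (2.0 * eps)
#       for e in np.eye(3)])
#   assert np.allclose(analytic, numeric, atol=1e-5)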
# --- Backward Kernels ---
@wp.kernel
def sh_backward_kernel(
    # --- Inputs ---
    num_points: int,                         # Number of Gaussian points
    degree: int,                             # SH degree used in forward
    means: wp.array(dtype=wp.vec3),          # 3D positions (N, 3)
    shs: wp.array(dtype=wp.vec3),            # Flattened SH coeffs (N * 16, 3)
    radii: wp.array(dtype=int),              # Radii computed in forward (N,) - used for skipping
    campos: wp.vec3,                         # Camera position (3,)
    clamped_state: wp.array(dtype=wp.vec3),  # Clamping state {0,1} from forward pass (N, 3)
    dL_dcolor: wp.array(dtype=wp.vec3),      # Grad L w.r.t. *final* gaussian color (N, 3)

    # --- Outputs (Accumulate) ---
    dL_dmeans: wp.array(dtype=wp.vec3),      # Accumulate mean grads here (N, 3)
    dL_dshs: wp.array(dtype=wp.vec3)         # Accumulate SH grads here (N * 16, 3)
):
    idx = wp.tid()

    if idx >= num_points or radii[idx] <= 0:  # Skip if not rendered
        return

    mean = means[idx]
    base_sh_idx = idx * 16

    # --- Recompute view direction ---
    dir_orig = mean - campos
    dir_len = wp.length(dir_orig)
    # Skip if direction length is too small (matches CUDA implementation)
    if dir_len < 1e-8:
        return

    # Normalize direction
    dir = dir_orig / dir_len
    x = dir[0]; y = dir[1]; z = dir[2]

    # --- Apply clamping mask to input gradient ---
    dL_dRGB = dL_dcolor[idx]
    dL_dRGB = wp_vec3_mul_element(dL_dRGB, wp_vec3_add_element(wp.vec3(1.0, 1.0, 1.0), -1.0 * clamped_state[idx]))

    # Initialize gradients w.r.t. direction components (dRawColor/ddir)
    dRGBdx = wp.vec3(0.0, 0.0, 0.0)
    dRGBdy = wp.vec3(0.0, 0.0, 0.0)
    dRGBdz = wp.vec3(0.0, 0.0, 0.0)

    # --- Degree 0 ---
    # Direct assignment for clarity (matching CUDA style)
    dRGBdsh0 = SH_C0
    dL_dshs[base_sh_idx] = dRGBdsh0 * dL_dRGB

    # --- Degree 1 ---
    if degree > 0:
        sh1 = shs[base_sh_idx + 1]
        sh2 = shs[base_sh_idx + 2]
        sh3 = shs[base_sh_idx + 3]

        # Exactly match CUDA computation order
        dRGBdsh1 = -SH_C1 * y
        dRGBdsh2 = SH_C1 * z
        dRGBdsh3 = -SH_C1 * x

        dL_dshs[base_sh_idx + 1] = dRGBdsh1 * dL_dRGB
        dL_dshs[base_sh_idx + 2] = dRGBdsh2 * dL_dRGB
        dL_dshs[base_sh_idx + 3] = dRGBdsh3 * dL_dRGB

        # Gradient components w.r.t. direction
        dRGBdx = -SH_C1 * sh3
        dRGBdy = -SH_C1 * sh1
        dRGBdz = SH_C1 * sh2

        # --- Degree 2 ---
        if degree > 1:
            xx = x*x; yy = y*y; zz = z*z
            xy = x*y; yz = y*z; xz = x*z

            sh4 = shs[base_sh_idx + 4]; sh5 = shs[base_sh_idx + 5]
            sh6 = shs[base_sh_idx + 6]; sh7 = shs[base_sh_idx + 7]
            sh8 = shs[base_sh_idx + 8]

            # Hardcoded C2 values (same as CUDA SH_C2)
            C2_0 = 1.0925484305920792
            C2_1 = -1.0925484305920792
            C2_2 = 0.31539156525252005
            C2_3 = -1.0925484305920792
            C2_4 = 0.5462742152960396

            # Compute gradients for degree 2 (matching CUDA)
            dRGBdsh4 = C2_0 * xy
            dRGBdsh5 = C2_1 * yz
            dRGBdsh6 = C2_2 * (2.0 * zz - xx - yy)
            dRGBdsh7 = C2_3 * xz
            dRGBdsh8 = C2_4 * (xx - yy)

            dL_dshs[base_sh_idx + 4] = dRGBdsh4 * dL_dRGB
            dL_dshs[base_sh_idx + 5] = dRGBdsh5 * dL_dRGB
            dL_dshs[base_sh_idx + 6] = dRGBdsh6 * dL_dRGB
            dL_dshs[base_sh_idx + 7] = dRGBdsh7 * dL_dRGB
            dL_dshs[base_sh_idx + 8] = dRGBdsh8 * dL_dRGB

            # Accumulate gradients w.r.t. direction (exactly matching CUDA)
            dRGBdx += C2_0 * y * sh4 + C2_2 * 2.0 * -x * sh6 + C2_3 * z * sh7 + C2_4 * 2.0 * x * sh8
            dRGBdy += C2_0 * x * sh4 + C2_1 * z * sh5 + C2_2 * 2.0 * -y * sh6 + C2_4 * 2.0 * -y * sh8
            dRGBdz += C2_1 * y * sh5 + C2_2 * 2.0 * 2.0 * z * sh6 + C2_3 * x * sh7

            # --- Degree 3 ---
            if degree > 2:
                sh9 = shs[base_sh_idx + 9]; sh10 = shs[base_sh_idx + 10]
                sh11 = shs[base_sh_idx + 11]; sh12 = shs[base_sh_idx + 12]
                sh13 = shs[base_sh_idx + 13]; sh14 = shs[base_sh_idx + 14]
                sh15 = shs[base_sh_idx + 15]

                # Hardcoded C3 values (same as CUDA SH_C3)
                C3_0 = -0.5900435899266435
                C3_1 = 2.890611442640554
                C3_2 = -0.4570457994644658
                C3_3 = 0.3731763325901154
                C3_4 = -0.4570457994644658
                C3_5 = 1.445305721320277
                C3_6 = -0.5900435899266435

                # Direct computation of degree 3 gradients (matching CUDA)
                dRGBdsh9 = C3_0 * y * (3.0 * xx - yy)
                dRGBdsh10 = C3_1 * xy * z
                dRGBdsh11 = C3_2 * y * (4.0 * zz - xx - yy)
                dRGBdsh12 = C3_3 * z * (2.0 * zz - 3.0 * xx - 3.0 * yy)
                dRGBdsh13 = C3_4 * x * (4.0 * zz - xx - yy)
                dRGBdsh14 = C3_5 * z * (xx - yy)
                dRGBdsh15 = C3_6 * x * (xx - 3.0 * yy)

                dL_dshs[base_sh_idx + 9] = dRGBdsh9 * dL_dRGB
                dL_dshs[base_sh_idx + 10] = dRGBdsh10 * dL_dRGB
                dL_dshs[base_sh_idx + 11] = dRGBdsh11 * dL_dRGB
                dL_dshs[base_sh_idx + 12] = dRGBdsh12 * dL_dRGB
                dL_dshs[base_sh_idx + 13] = dRGBdsh13 * dL_dRGB
                dL_dshs[base_sh_idx + 14] = dRGBdsh14 * dL_dRGB
                dL_dshs[base_sh_idx + 15] = dRGBdsh15 * dL_dRGB

                # Accumulate dRGBdx (matching CUDA's expression structure)
                dRGBdx += (
                    C3_0 * sh9 * 3.0 * 2.0 * xy +
                    C3_1 * sh10 * yz +
                    C3_2 * sh11 * -2.0 * xy +
                    C3_3 * sh12 * -3.0 * 2.0 * xz +
                    C3_4 * sh13 * (-3.0 * xx + 4.0 * zz - yy) +
                    C3_5 * sh14 * 2.0 * xz +
                    C3_6 * sh15 * 3.0 * (xx - yy)
                )

                # Accumulate dRGBdy (matching CUDA's expression structure)
                dRGBdy += (
                    C3_0 * sh9 * 3.0 * (xx - yy) +
                    C3_1 * sh10 * xz +
                    C3_2 * sh11 * (-3.0 * yy + 4.0 * zz - xx) +
                    C3_3 * sh12 * -3.0 * 2.0 * yz +
                    C3_4 * sh13 * -2.0 * xy +
                    C3_5 * sh14 * -2.0 * yz +
                    C3_6 * sh15 * -3.0 * 2.0 * xy
                )

                # Accumulate dRGBdz (matching CUDA's expression structure)
                dRGBdz += (
| 202 |
+
C3_1 * sh10 * xy +
|
| 203 |
+
C3_2 * sh11 * 4.0 * 2.0 * yz +
|
| 204 |
+
C3_3 * sh12 * 3.0 * (2.0 * zz - xx - yy) +
|
| 205 |
+
C3_4 * sh13 * 4.0 * 2.0 * xz +
|
| 206 |
+
C3_5 * sh14 * (xx - yy)
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# --- Compute gradient w.r.t. view direction (dL/ddir) ---
|
| 210 |
+
dL_ddir = wp.vec3(wp.dot(dRGBdx, dL_dRGB),
|
| 211 |
+
wp.dot(dRGBdy, dL_dRGB),
|
| 212 |
+
wp.dot(dRGBdz, dL_dRGB))
|
| 213 |
+
|
| 214 |
+
# --- Propagate gradient from direction to mean position (dL/dmean) ---
|
| 215 |
+
dL_dmeans_local = dnormvdv(dir_orig, dL_ddir)
|
| 216 |
+
|
| 217 |
+
# --- Accumulate gradients to global arrays ---
|
| 218 |
+
dL_dmeans[idx] += dL_dmeans_local
|
| 219 |
+
|
| 220 |
+
|
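# Illustrative host-side sketch (not called anywhere in this module): one way to
# drive sh_backward_kernel for N points with degree-3 SH. All input buffers are
# assumed to be wp.arrays already produced by the forward pass; only the two
# output buffers are allocated here. The launch mirrors Step 3 of
# backward_preprocess below.
def _example_launch_sh_backward(means, shs, radii, campos, clamped_state, dL_dcolor):
    n = means.shape[0]
    dL_dmeans = wp.zeros(n, dtype=wp.vec3, device=DEVICE)
    dL_dshs = wp.zeros(n * 16, dtype=wp.vec3, device=DEVICE)
    wp.launch(
        kernel=sh_backward_kernel,
        dim=n,
        inputs=[n, 3, means, shs, radii, campos, clamped_state,
                dL_dcolor, dL_dmeans, dL_dshs],
        device=DEVICE,
    )
    return dL_dmeans, dL_dshs
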
@wp.kernel
def compute_cov2d_backward_kernel(
    # --- Inputs ---
    num_points: int,                        # Number of Gaussian points
    means: wp.array(dtype=wp.vec3),         # 3D positions (N, 3)
    cov3Ds: wp.array(dtype=VEC6),           # Packed 3D cov (N, 6)
    radii: wp.array(dtype=int),             # Radii computed in forward (N,) - used for skipping
    h_x: float, h_y: float,                 # Focal lengths
    tan_fovx: float, tan_fovy: float,       # Tangent of FOV
    view_matrix: wp.mat44,                  # World->View matrix (4, 4)
    dL_dconics: wp.array(dtype=wp.vec4),    # Grad L w.r.t. conic (a, b, c) packed into a vec4 (N, 4)

    # --- Outputs (Accumulate) ---
    dL_dmeans: wp.array(dtype=wp.vec3),     # Accumulate mean grads here (N, 3)
    dL_dcov3Ds: wp.array(dtype=VEC6)        # Accumulate 3D cov grads here (N, 6)
):
    idx = wp.tid()
    if idx >= num_points or radii[idx] <= 0:  # Skip if not rendered
        # Zero out dL_dcov3Ds to ensure we don't keep old values
        dL_dcov3Ds[idx] = VEC6(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        return

    mean = means[idx]
    cov3D_packed = cov3Ds[idx]  # VEC6

    dL_dconic = wp.vec3(dL_dconics[idx][0], dL_dconics[idx][1], dL_dconics[idx][3])

    t = wp.vec4(mean[0], mean[1], mean[2], 1.0) * view_matrix

    limx = 1.3 * tan_fovx
    limy = 1.3 * tan_fovy
    tz = t[2]
    inv_tz = 1.0 / tz
    txtz = t[0] * inv_tz
    tytz = t[1] * inv_tz

    x_clamped_flag = (txtz < -limx) or (txtz > limx)
    y_clamped_flag = (tytz < -limy) or (tytz > limy)
    x_grad_mul = 1.0 - float(x_clamped_flag)  # 1.0 if not clamped, 0.0 if clamped
    y_grad_mul = 1.0 - float(y_clamped_flag)

    tx = wp.min(limx, wp.max(-limx, txtz)) * tz
    ty = wp.min(limy, wp.max(-limy, tytz)) * tz
    inv_tz2 = inv_tz * inv_tz
    inv_tz3 = inv_tz2 * inv_tz

    J00 = h_x * inv_tz
    J11 = h_y * inv_tz
    J02 = -h_x * tx * inv_tz2
    J12 = -h_y * ty * inv_tz2

    J = wp.transpose(wp.mat33(
        J00, 0.0, J02,
        0.0, J11, J12,
        0.0, 0.0, 0.0
    ))

    W = wp.mat33(
        view_matrix[0, 0], view_matrix[0, 1], view_matrix[0, 2],
        view_matrix[1, 0], view_matrix[1, 1], view_matrix[1, 2],
        view_matrix[2, 0], view_matrix[2, 1], view_matrix[2, 2]
    )

    T = W * J
    c0 = cov3D_packed[0]; c1 = cov3D_packed[1]; c2 = cov3D_packed[2]
    c11 = cov3D_packed[3]; c12 = cov3D_packed[4]; c22 = cov3D_packed[5]
    Vrk = wp.mat33(c0, c1, c2, c1, c11, c12, c2, c12, c22)  # Assumes VEC6 stores upper triangle row-wise

    cov2D_mat = wp.transpose(T) * wp.transpose(Vrk) * T

    a_noblr = cov2D_mat[0, 0]
    b_noblr = cov2D_mat[0, 1]
    c_noblr = cov2D_mat[1, 1]
    a = a_noblr + 0.3  # Low-pass blur added in the forward pass
    b = b_noblr
    c = c_noblr + 0.3

    denom = a * c - b * b
    dL_da = 0.0; dL_db = 0.0; dL_dc = 0.0

    # --- Calculate Gradients ---
    if denom != 0.0:
        # Small epsilon keeps the squared denominator away from zero
        denom2inv = 1.0 / (denom * denom + 1e-7)
        dL_da = denom2inv * (-c * c * dL_dconic[0] + 2.0 * b * c * dL_dconic[1] + (denom - a * c) * dL_dconic[2])
        dL_dc = denom2inv * (-a * a * dL_dconic[2] + 2.0 * a * b * dL_dconic[1] + (denom - a * c) * dL_dconic[0])
        dL_db = denom2inv * 2.0 * (b * c * dL_dconic[0] - (denom + 2.0 * b * b) * dL_dconic[1] + a * b * dL_dconic[2])

    dL_dcov3Ds[idx] = VEC6(
        # Packed upper-triangle elements c00, c01, c02, c11, c12, c22
        T[0][0] * T[0][0] * dL_da + T[0][0] * T[0][1] * dL_db + T[0][1] * T[0][1] * dL_dc,  # c00
        2.0 * T[0][0] * T[1][0] * dL_da + (T[0][0] * T[1][1] + T[1][0] * T[0][1]) * dL_db + 2.0 * T[0][1] * T[1][1] * dL_dc,  # c01
        2.0 * T[0][0] * T[2][0] * dL_da + (T[0][0] * T[2][1] + T[2][0] * T[0][1]) * dL_db + 2.0 * T[0][1] * T[2][1] * dL_dc,  # c02
        T[1][0] * T[1][0] * dL_da + T[1][0] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc,  # c11
        2.0 * T[2][0] * T[1][0] * dL_da + (T[1][0] * T[2][1] + T[2][0] * T[1][1]) * dL_db + 2.0 * T[1][1] * T[2][1] * dL_dc,  # c12
        T[2][0] * T[2][0] * dL_da + T[2][0] * T[2][1] * dL_db + T[2][1] * T[2][1] * dL_dc   # c22
    )

    dL_dT00 = 2.0 * (T[0][0] * Vrk[0][0] + T[1][0] * Vrk[1][0] + T[2][0] * Vrk[2][0]) * dL_da + \
              (T[0][1] * Vrk[0][0] + T[1][1] * Vrk[1][0] + T[2][1] * Vrk[2][0]) * dL_db
    dL_dT01 = 2.0 * (T[0][0] * Vrk[0][1] + T[1][0] * Vrk[1][1] + T[2][0] * Vrk[2][1]) * dL_da + \
              (T[0][1] * Vrk[0][1] + T[1][1] * Vrk[1][1] + T[2][1] * Vrk[2][1]) * dL_db
    dL_dT02 = 2.0 * (T[0][0] * Vrk[0][2] + T[1][0] * Vrk[1][2] + T[2][0] * Vrk[2][2]) * dL_da + \
              (T[0][1] * Vrk[0][2] + T[1][1] * Vrk[1][2] + T[2][1] * Vrk[2][2]) * dL_db
    dL_dT10 = 2.0 * (T[0][1] * Vrk[0][0] + T[1][1] * Vrk[1][0] + T[2][1] * Vrk[2][0]) * dL_dc + \
              (T[0][0] * Vrk[0][0] + T[1][0] * Vrk[1][0] + T[2][0] * Vrk[2][0]) * dL_db
    dL_dT11 = 2.0 * (T[0][1] * Vrk[0][1] + T[1][1] * Vrk[1][1] + T[2][1] * Vrk[2][1]) * dL_dc + \
              (T[0][0] * Vrk[0][1] + T[1][0] * Vrk[1][1] + T[2][0] * Vrk[2][1]) * dL_db
    dL_dT12 = 2.0 * (T[0][1] * Vrk[0][2] + T[1][1] * Vrk[1][2] + T[2][1] * Vrk[2][2]) * dL_dc + \
              (T[0][0] * Vrk[0][2] + T[1][0] * Vrk[1][2] + T[2][0] * Vrk[2][2]) * dL_db

    dL_dJ00 = W[0, 0] * dL_dT00 + W[1, 0] * dL_dT01 + W[2, 0] * dL_dT02
    dL_dJ02 = W[0, 2] * dL_dT00 + W[1, 2] * dL_dT01 + W[2, 2] * dL_dT02
    dL_dJ11 = W[0, 1] * dL_dT10 + W[1, 1] * dL_dT11 + W[2, 1] * dL_dT12
    dL_dJ12 = W[0, 2] * dL_dT10 + W[1, 2] * dL_dT11 + W[2, 2] * dL_dT12

    dL_dtx = -h_x * inv_tz2 * dL_dJ02
    dL_dty = -h_y * inv_tz2 * dL_dJ12
    dL_dtz = -h_x * inv_tz2 * dL_dJ00 - h_y * inv_tz2 * dL_dJ11 + \
             2.0 * h_x * tx * inv_tz3 * dL_dJ02 + 2.0 * h_y * ty * inv_tz3 * dL_dJ12

    dL_dt = wp.vec3(dL_dtx * x_grad_mul, dL_dty * y_grad_mul, dL_dtz)

    dL_dmean_from_cov = wp.vec4(dL_dt[0], dL_dt[1], dL_dt[2], 1.0) * wp.transpose(view_matrix)
    dL_dmeans[idx] += wp.vec3(dL_dmean_from_cov[0], dL_dmean_from_cov[1], dL_dmean_from_cov[2])

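# Note on the inverse-matrix gradient above: the forward pass stores the conic
# as the inverse of the blurred 2D covariance,
#   conic = inv([[a, b], [b, c]]) = [[c, -b], [-b, a]] / (a*c - b*b),
# so dL_da / dL_db / dL_dc are the quotient-rule derivatives of those three
# entries contracted with dL_dconic. The extra factor of 2 on the b terms
# follows the CUDA reference, which counts the symmetric off-diagonal of the
# covariance once in the packed layout but twice in the quadratic form.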
@wp.kernel
def compute_cov3d_backward_kernel(
    # --- Inputs ---
    num_points: int,                        # Number of Gaussian points
    scales: wp.array(dtype=wp.vec3),        # Scale parameters (N, 3)
    rotations: wp.array(dtype=wp.vec4),     # Quaternions (x, y, z, w) (N, 4)
    radii: wp.array(dtype=int),             # Radii computed in forward (N,) - used for skipping
    scale_modifier: float,                  # Global scale modifier
    dL_dcov3Ds: wp.array(dtype=VEC6),       # Grad L w.r.t packed 3D cov (N, 6)

    # --- Outputs ---
    dL_dscales: wp.array(dtype=wp.vec3),    # Write scale grads here (N, 3)
    dL_drots: wp.array(dtype=wp.vec4)       # Write rot grads here (N, 4)
):
    idx = wp.tid()
    # Skip if not rendered OR if grad input is zero (e.g., from compute_cov2d_backward)
    if idx >= num_points or radii[idx] <= 0:
        dL_dscales[idx] = wp.vec3(0.0, 0.0, 0.0)
        dL_drots[idx] = wp.vec4(0.0, 0.0, 0.0, 0.0)
        return

    # --- Recompute intermediates ---
    scale_vec = scales[idx]
    rot_quat = rotations[idx]  # (x, y, z, w) in Warp

    # Extract quaternion components to match CUDA convention (r, x, y, z)
    r = rot_quat[3]  # Real part is w in Warp
    x = rot_quat[0]
    y = rot_quat[1]
    z = rot_quat[2]

    # 1. Construct rotation matrix R manually as in CUDA
    R = wp.mat33(
        1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - r * z), 2.0 * (x * z + r * y),
        2.0 * (x * y + r * z), 1.0 - 2.0 * (x * x + z * z), 2.0 * (y * z - r * x),
        2.0 * (x * z - r * y), 2.0 * (y * z + r * x), 1.0 - 2.0 * (x * x + y * y)
    )

    # 2. Create scaling matrix S
    s_vec = scale_modifier * scale_vec
    S = wp.mat33(
        s_vec[0], 0.0, 0.0,
        0.0, s_vec[1], 0.0,
        0.0, 0.0, s_vec[2]
    )

    # 3. M = S * R (match CUDA multiplication order)
    M = S * R

    # --- Extract gradient w.r.t. 3D covariance ---
    dL_dcov3D_packed = dL_dcov3Ds[idx]

    # Convert per-element covariance loss gradients to matrix form
    dL_dSigma = wp.mat33(
        dL_dcov3D_packed[0], 0.5 * dL_dcov3D_packed[1], 0.5 * dL_dcov3D_packed[2],
        0.5 * dL_dcov3D_packed[1], dL_dcov3D_packed[3], 0.5 * dL_dcov3D_packed[4],
        0.5 * dL_dcov3D_packed[2], 0.5 * dL_dcov3D_packed[4], dL_dcov3D_packed[5]
    )

    # --- Calculate Gradients ---
    # 1. Gradient w.r.t. M: dL/dM = 2 * M * dL/dSigma
    dL_dM = 2.0 * M * dL_dSigma

    # 2. Transpose of matrices for gradient calculations
    Rt = wp.transpose(R)
    dL_dMt = wp.transpose(dL_dM)

    # 3. Gradient w.r.t. scales - matching CUDA directly
    dL_dscale = wp.vec3(
        wp.dot(Rt[0], dL_dMt[0]),
        wp.dot(Rt[1], dL_dMt[1]),
        wp.dot(Rt[2], dL_dMt[2])
    )
    dL_dscales[idx] = dL_dscale * scale_modifier

    # 4. Scale dL_dMt by scale factors for quaternion gradient calculation
    dL_dMt_scaled = wp.mat33(
        dL_dMt[0, 0] * s_vec[0], dL_dMt[0, 1] * s_vec[0], dL_dMt[0, 2] * s_vec[0],
        dL_dMt[1, 0] * s_vec[1], dL_dMt[1, 1] * s_vec[1], dL_dMt[1, 2] * s_vec[1],
        dL_dMt[2, 0] * s_vec[2], dL_dMt[2, 1] * s_vec[2], dL_dMt[2, 2] * s_vec[2]
    )

    # 5. Gradients of loss w.r.t. quaternion components
    dL_dr = 2.0 * (z * (dL_dMt_scaled[0, 1] - dL_dMt_scaled[1, 0]) +
                   y * (dL_dMt_scaled[2, 0] - dL_dMt_scaled[0, 2]) +
                   x * (dL_dMt_scaled[1, 2] - dL_dMt_scaled[2, 1]))

    dL_dx = 2.0 * (y * (dL_dMt_scaled[1, 0] + dL_dMt_scaled[0, 1]) +
                   z * (dL_dMt_scaled[2, 0] + dL_dMt_scaled[0, 2]) +
                   r * (dL_dMt_scaled[1, 2] - dL_dMt_scaled[2, 1])) - \
            4.0 * x * (dL_dMt_scaled[2, 2] + dL_dMt_scaled[1, 1])

    dL_dy = 2.0 * (x * (dL_dMt_scaled[1, 0] + dL_dMt_scaled[0, 1]) +
                   r * (dL_dMt_scaled[2, 0] - dL_dMt_scaled[0, 2]) +
                   z * (dL_dMt_scaled[1, 2] + dL_dMt_scaled[2, 1])) - \
            4.0 * y * (dL_dMt_scaled[2, 2] + dL_dMt_scaled[0, 0])

    dL_dz = 2.0 * (r * (dL_dMt_scaled[0, 1] - dL_dMt_scaled[1, 0]) +
                   x * (dL_dMt_scaled[2, 0] + dL_dMt_scaled[0, 2]) +
                   y * (dL_dMt_scaled[1, 2] + dL_dMt_scaled[2, 1])) - \
            4.0 * z * (dL_dMt_scaled[1, 1] + dL_dMt_scaled[0, 0])

    # 6. Convert back to Warp's quaternion ordering (x, y, z, r/w)
    dL_drots[idx] = wp.vec4(dL_dx, dL_dy, dL_dz, dL_dr)

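# Quaternion layout sanity note: Warp stores rotations as (x, y, z, w) while
# the CUDA reference derives its formulas for (r, x, y, z) with r the real
# part; the kernel converts on entry and back on exit, so dL_drots lines up
# component-for-component with the stored rotations array. A quick check at the
# identity rotation (x, y, z, w) = (0, 0, 0, 1): R reduces to the identity
# matrix, M = S, and the scale gradient collapses to the diagonal of dL_dMt,
# i.e. dL_dscales[idx] = scale_modifier *
# (dL_dMt[0, 0], dL_dMt[1, 1], dL_dMt[2, 2]) -- a convenient unit-test case.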
@wp.kernel
def wp_render_backward_kernel(
    # --- Inputs ---
    # Tile/Range data
    ranges: wp.array(dtype=wp.vec2i),         # Range of point indices for each tile (start, end)
    point_list: wp.array(dtype=int),          # Sorted point indices

    # Image parameters
    W: int,                                   # Image width
    H: int,                                   # Image height
    bg_color: wp.vec3,                        # Background color
    tile_grid: wp.vec3,                       # Tile grid dimensions

    # Gaussian parameters
    points_xy_image: wp.array(dtype=wp.vec2), # 2D projected positions
    conic_opacity: wp.array(dtype=wp.vec4),   # Conic matrices and opacities (a, b, c, opacity)
    colors: wp.array(dtype=wp.vec3),          # RGB colors

    # Forward pass results
    final_Ts: wp.array2d(dtype=float),        # Final transparency values
    n_contrib: wp.array2d(dtype=int),         # Number of Gaussians contributing to each pixel
    dL_dpixels: wp.array2d(dtype=wp.vec3),    # Gradient of loss w.r.t. output pixels

    # --- Outputs ---
    dL_dmean2D: wp.array(dtype=wp.vec3),      # Gradient w.r.t. 2D mean positions
    dL_dconic2D: wp.array(dtype=wp.vec4),     # Gradient w.r.t. conic matrices
    dL_dopacity: wp.array(dtype=float),       # Gradient w.r.t. opacity
    dL_dcolors: wp.array(dtype=wp.vec3),      # Gradient w.r.t. colors
):
    """
    Backward version of the rendering procedure, computing gradients of the loss with respect
    to Gaussian parameters based on gradients of the loss with respect to output pixels.

    This kernel is launched per pixel and processes Gaussians in back-to-front order,
    similar to the forward rendering pass but accumulating gradients.
    """
    # Get pixel coordinates
    tile_x, tile_y, tid_x, tid_y = wp.tid()

    # Calculate pixel position
    pix_x = tile_x * TILE_M + tid_x
    pix_y = tile_y * TILE_N + tid_y

    # Skip if pixel is outside image bounds
    inside = (pix_x < W) and (pix_y < H)
    if not inside:
        return

    # Convert to float coordinates for calculations
    pixf_x = float(pix_x)
    pixf_y = float(pix_y)

    # Get tile range (start/end indices in point_list)
    tile_id = tile_y * int(tile_grid[0]) + tile_x

    range_start = ranges[tile_id][0]
    range_end = ranges[tile_id][1]

    # Get final transparency value and number of contributors from forward pass
    T_final = final_Ts[pix_y, pix_x]
    last_contributor = n_contrib[pix_y, pix_x]

    # first_kept = max(range_start, range_end - last_contributor)  # = range_end-N
    last_kept = min(range_end, range_start + last_contributor)

    # Initialize working variables
    T = T_final                          # Current accumulated transparency
    accum_rec = wp.vec3(0.0, 0.0, 0.0)   # Accumulated color
    last_alpha = float(0.0)              # Alpha from the last processed Gaussian
    last_color = wp.vec3(0.0, 0.0, 0.0)  # Color from the last processed Gaussian

    # Get gradients
    dL_dpixel = dL_dpixels[pix_y, pix_x]

    # Gradient of pixel coordinate w.r.t. normalized screen-space coordinates
    ddelx_dx = 0.5 * float(W)
    ddely_dy = 0.5 * float(H)

    for i in range(last_kept - 1, range_start - 1, -1):
        gaussian_id = point_list[i]
        xy = points_xy_image[gaussian_id]
        con_o = conic_opacity[gaussian_id]  # (a, b, c, opacity)
        color = colors[gaussian_id]

        # Compute distance to pixel center
        d_x = xy[0] - pixf_x
        d_y = xy[1] - pixf_y

        # Compute Gaussian power
        power = -0.5 * (con_o[0] * d_x * d_x + con_o[2] * d_y * d_y) - con_o[1] * d_x * d_y

        # Skip if power is positive (too far away)
        if power > 0.0:
            continue

        # Compute Gaussian value and alpha
        G = wp.exp(power)
        alpha = wp.min(0.99, con_o[3] * G)

        # Skip if alpha is too small
        if alpha < (1.0 / 255.0):
            continue

        T = T / (1.0 - alpha)

        # Gradient factor for color contribution
        dchannel_dcolor = alpha * T

        # Compute gradient w.r.t. alpha
        dL_dalpha = 0.0

        # Update color accumulation and compute color gradients
        accum_rec = last_alpha * last_color + (1.0 - last_alpha) * accum_rec
        dL_dchannel = dL_dpixel
        last_color = color

        dL_dalpha = wp.dot(color - accum_rec, dL_dpixel)
        wp.atomic_add(dL_dcolors, gaussian_id, dchannel_dcolor * dL_dchannel)

        # Scale dL_dalpha by T
        dL_dalpha *= T
        last_alpha = alpha

        # Account for background color contribution
        bg_dot_dpixel = wp.dot(bg_color, dL_dpixel)
        dL_dalpha += (-T_final / (1.0 - alpha)) * bg_dot_dpixel

        # Helpful temporary variables
        dL_dG = con_o[3] * dL_dalpha
        gdx = G * d_x
        gdy = G * d_y
        dG_ddelx = -gdx * con_o[0] - gdy * con_o[1]
        dG_ddely = -gdy * con_o[2] - gdx * con_o[1]

        # Update gradients w.r.t. 2D mean position
        wp.atomic_add(dL_dmean2D, gaussian_id, wp.vec3(
            dL_dG * dG_ddelx * ddelx_dx,
            dL_dG * dG_ddely * ddely_dy,
            0.0
        ))

        # Update gradients w.r.t. 2D conic matrix
        wp.atomic_add(dL_dconic2D, gaussian_id, wp.vec4(
            -0.5 * gdx * d_x * dL_dG,
            -0.5 * gdx * d_y * dL_dG,
            0.0,
            -0.5 * gdy * d_y * dL_dG
        ))

        # Update gradients w.r.t. opacity
        wp.atomic_add(dL_dopacity, gaussian_id, G * dL_dalpha)

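# Compositing recurrence used above, written out: the forward pass produces
#   C = sum_i c_i * alpha_i * T_i + T_final * bg,  with  T_i = prod_{j<i} (1 - alpha_j),
# so walking the Gaussians back to front lets the kernel rebuild each T_i from
# T_final via T_i = T_{i+1} / (1 - alpha_i). Differentiating C gives
#   dC/dc_i     = alpha_i * T_i                               (dchannel_dcolor)
#   dC/dalpha_i = (c_i - S_i) * T_i - T_final / (1 - alpha_i) * bg,
# where S_i (accum_rec) is the color already composited behind Gaussian i;
# contracting both with dL_dpixel yields exactly the dL_dalpha accumulation in
# the loop.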
@wp.kernel
def compute_projection_backward_kernel(
    # --- Inputs ---
    num_points: int,                        # Number of Gaussian points
    means: wp.array(dtype=wp.vec3),         # 3D positions (N, 3)
    radii: wp.array(dtype=int),             # Radii computed in forward (N,) - used for skipping
    proj_matrix: wp.mat44,                  # Projection matrix (4, 4)
    dL_dmean2D: wp.array(dtype=wp.vec3),    # Grad of loss w.r.t. 2D projected means (N, 3; z unused)

    # --- Outputs (Accumulate) ---
    dL_dmeans: wp.array(dtype=wp.vec3)      # Accumulate mean grads here (N, 3)
):
    """Compute gradients of 3D means due to projection to 2D.

    This kernel handles the gradient propagation from 2D projected positions
    back to 3D positions, based on the projection matrix.
    """
    idx = wp.tid()
    if idx >= num_points or radii[idx] <= 0:  # Skip if not rendered
        return

    # Get 3D mean and 2D mean gradient
    mean3D = means[idx]
    dL_dmean2D_val = dL_dmean2D[idx]

    # Compute homogeneous coordinates
    m_hom = wp.vec4(mean3D[0], mean3D[1], mean3D[2], 1.0)
    m_hom = m_hom * proj_matrix

    # Division by w (perspective division)
    m_w = 1.0 / (m_hom[3] + 0.0000001)

    # Compute gradient of loss w.r.t. 3D means due to 2D mean gradients
    # Following the chain rule through the perspective projection
    mul1 = (proj_matrix[0, 0] * mean3D[0] + proj_matrix[1, 0] * mean3D[1] +
            proj_matrix[2, 0] * mean3D[2] + proj_matrix[3, 0]) * m_w * m_w

    mul2 = (proj_matrix[0, 1] * mean3D[0] + proj_matrix[1, 1] * mean3D[1] +
            proj_matrix[2, 1] * mean3D[2] + proj_matrix[3, 1]) * m_w * m_w

    dL_dmean = wp.vec3(0.0, 0.0, 0.0)

    # x component of gradient
    dL_dmean[0] = (proj_matrix[0, 0] * m_w - proj_matrix[0, 3] * mul1) * dL_dmean2D_val[0] + \
                  (proj_matrix[0, 1] * m_w - proj_matrix[0, 3] * mul2) * dL_dmean2D_val[1]

    # y component of gradient
    dL_dmean[1] = (proj_matrix[1, 0] * m_w - proj_matrix[1, 3] * mul1) * dL_dmean2D_val[0] + \
                  (proj_matrix[1, 1] * m_w - proj_matrix[1, 3] * mul2) * dL_dmean2D_val[1]

    # z component of gradient
    dL_dmean[2] = (proj_matrix[2, 0] * m_w - proj_matrix[2, 3] * mul1) * dL_dmean2D_val[0] + \
                  (proj_matrix[2, 1] * m_w - proj_matrix[2, 3] * mul2) * dL_dmean2D_val[1]

    dL_dmeans[idx] += dL_dmean

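# Chain rule spelled out for the projection above: with the row-vector
# convention hom = (x, y, z, 1) * P and ndc_x = hom_x / hom_w, the quotient
# rule gives
#   d(ndc_x)/d(mean_k) = P[k, 0] / hom_w - P[k, 3] * hom_x / hom_w^2,
# which is exactly the (proj_matrix[k, 0] * m_w - proj_matrix[k, 3] * mul1)
# pattern in the three component gradients; mul1 and mul2 cache hom_x * m_w^2
# and hom_y * m_w^2 so each is computed once.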
def backward_preprocess(
    # Camera and model parameters
    num_points: int,
    means: wp.array(dtype=wp.vec3),          # 3D means
    means_2d: wp.array(dtype=wp.vec2),       # 2D means
    radii: wp.array(dtype=int),              # Computed radii
    sh_coeffs: wp.array(dtype=wp.vec3),      # SH coefficients
    scales: wp.array(dtype=wp.vec3),         # Scale parameters
    rotations: wp.array(dtype=wp.vec4),      # Rotation quaternions
    viewmatrix: wp.mat44,                    # Camera view matrix
    projmatrix: wp.mat44,                    # Camera projection matrix
    fov_x: float,                            # Tangent of horizontal FOV (forwarded as tan_fovx)
    fov_y: float,                            # Tangent of vertical FOV (forwarded as tan_fovy)
    focal_x: float,
    focal_y: float,

    # Intermediate data from forward
    cov3Ds: wp.array(dtype=wp.mat33),        # 3D covariance matrices (or VEC6 depending on packing)
    conic_opacity: wp.array(dtype=wp.vec4),  # 2D conics and opacity
    campos: wp.vec3,                         # Camera position (3,)
    clamped: wp.array(dtype=wp.uint32),      # Clamping states

    # Incoming gradients from render backward
    dL_dmean2D: wp.array(dtype=wp.vec3),     # Grad of loss w.r.t. 2D means
    dL_dconic: wp.array(dtype=wp.vec4),      # Grad of loss w.r.t. 2D conics
    dL_dopacity: wp.array(dtype=float),      # Grad of loss w.r.t. opacity
    dL_dcolors: wp.array(dtype=wp.vec3),     # Grad of loss w.r.t. colors

    # Output gradient buffers
    dL_dmeans: wp.array(dtype=wp.vec3),      # Output grad for 3D means
    dL_dsh: wp.array(dtype=wp.vec3),         # Output grad for SH coeffs
    dL_dscales: wp.array(dtype=wp.vec3),     # Output grad for scales
    dL_drots: wp.array(dtype=wp.vec4),       # Output grad for rotations

    # Optional parameters
    scale_modifier: float = 1.0,
    sh_degree: int = 3
):
    """
    Orchestrates the backward pass for 3D Gaussian Splatting by coordinating several kernel calls.
    """
    # Create buffer for 3D covariance gradients
    dL_dcov3D = wp.zeros(num_points, dtype=VEC6, device=DEVICE)

    # Step 1: Compute gradients for 2D covariance (conic matrix)
    # This also computes gradients w.r.t. 3D means due to conic computation
    wp.launch(
        kernel=compute_cov2d_backward_kernel,
        dim=num_points,
        inputs=[
            num_points,    # P
            means,         # means3D
            cov3Ds,        # cov3Ds
            radii,         # radii
            focal_x,       # focal_x
            focal_y,       # focal_y
            fov_x,         # tan_fovx
            fov_y,         # tan_fovy
            viewmatrix,    # viewmatrix
            dL_dconic,     # dL_dconic
            dL_dmeans,     # dL_dmean3D (outputs)
            dL_dcov3D      # dL_dcov3D (outputs)
        ],
        device=DEVICE
    )

    # Step 2: Compute gradients for 3D means due to projection
    wp.launch(
        kernel=compute_projection_backward_kernel,
        dim=num_points,
        inputs=[
            num_points,
            means,
            radii,
            projmatrix,
            dL_dmean2D,
            dL_dmeans      # Accumulate to final means gradients
        ],
        device=DEVICE
    )

    # Step 3: Compute gradients for SH coefficients
    wp.launch(
        kernel=sh_backward_kernel,
        dim=num_points,
        inputs=[
            num_points,
            sh_degree,
            means,
            sh_coeffs,
            radii,
            campos,
            clamped,
            dL_dcolors,
            dL_dmeans,
            dL_dsh
        ],
        device=DEVICE
    )

    # Step 4: Compute gradients for scales and rotations
    wp.launch(
        kernel=compute_cov3d_backward_kernel,
        dim=num_points,
        inputs=[
            num_points,
            scales,
            rotations,
            radii,
            scale_modifier,
            dL_dcov3D,
            dL_dscales,    # Output scale gradients
            dL_drots       # Output rotation gradients
        ],
        device=DEVICE
    )

    return dL_dmeans, dL_dsh, dL_dscales, dL_drots

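# Ordering note for the four launches above: steps 1, 2 and 3 all accumulate
# into the same dL_dmeans buffer with +=, so the caller must hand in a
# zero-initialized array (backward() below allocates it with wp.zeros), and
# step 4 consumes the dL_dcov3D buffer that step 1 fills, so the cov2d launch
# cannot be reordered after the cov3d launch.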
def backward_render(
    ranges,
    point_list,
    width,
    height,
    bg_color,
    tile_grid,
    points_xy_image,
    conic_opacity,
    colors,
    final_Ts,
    n_contrib,
    dL_dpixels,
    dL_dmean2D,
    dL_dconic2D,
    dL_dopacity,
    dL_dcolors,
):
    """
    Orchestrates the backward rendering process by launching the backward kernel.

    Args:
        ranges: Range of point indices for each tile
        point_list: Sorted list of point indices
        width, height: Image dimensions
        bg_color: Background color
        tile_grid: Tile grid dimensions (x, y, 1)
        points_xy_image: 2D positions of Gaussians
        conic_opacity: Conic matrices and opacities
        colors: RGB colors
        final_Ts: Final transparency values from forward pass
        n_contrib: Number of contributors per pixel
        dL_dpixels: Gradient of loss w.r.t. output pixels
        dL_dmean2D: Output gradient w.r.t. 2D mean positions
        dL_dconic2D: Output gradient w.r.t. conic matrices
        dL_dopacity: Output gradient w.r.t. opacity
        dL_dcolors: Output gradient w.r.t. colors
    """
    # Calculate tile grid dimensions
    tile_grid_x = (width + TILE_M - 1) // TILE_M
    tile_grid_y = (height + TILE_N - 1) // TILE_N

    # Launch the backward rendering kernel (one thread per pixel)
    wp.launch(
        kernel=wp_render_backward_kernel,
        dim=(tile_grid_x, tile_grid_y, TILE_M, TILE_N),
        inputs=[
            ranges,
            point_list,
            width,
            height,
            bg_color,
            tile_grid,
            points_xy_image,
            conic_opacity,
            colors,
            final_Ts,
            n_contrib,
            dL_dpixels,
            dL_dmean2D,
            dL_dconic2D,
            dL_dopacity,
            dL_dcolors,
        ],
    )

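# Launch geometry for the call above: dim=(tile_grid_x, tile_grid_y, TILE_M,
# TILE_N) makes wp.tid() inside the kernel return (tile_x, tile_y, tid_x,
# tid_y), i.e. one thread per pixel. For example, a 640x480 image with 16x16
# tiles launches a (40, 30, 16, 16) grid = 307,200 threads, exactly the pixel
# count; out-of-bounds threads on ragged edge tiles return early.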
def backward(
    # --- Core parameters ---
    background,
    means3D,
    dL_dpixels,
    # --- Model parameters ---
    opacity=None,
    shs=None,
    scales=None,
    rotations=None,
    scale_modifier=1.0,
    # --- Camera parameters ---
    viewmatrix=None,
    projmatrix=None,
    tan_fovx=0.5,
    tan_fovy=0.5,
    image_height=256,
    image_width=256,
    campos=None,
    # --- Forward output buffers ---
    radii=None,
    means2D=None,
    conic_opacity=None,
    rgb=None,
    clamped=None,
    cov3Ds=None,
    # --- Internal state buffers ---
    geom_buffer=None,
    binning_buffer=None,
    img_buffer=None,
    # --- Algorithm parameters ---
    degree=3,
    debug=False,
):
    """
    Main backward function for 3D Gaussian Splatting.

    This function orchestrates the entire backward pass by calling two main sub-functions:
    1. backward_render: Computes gradients w.r.t. 2D parameters (mean2D, conic, opacity, color)
    2. backward_preprocess: Computes gradients w.r.t. 3D parameters
       (mean3D, cov3D, SH coefficients, scales, rotations)

    Args:
        background: Background color as numpy array, torch tensor, or wp.vec3 (3,)
        means3D: 3D positions as numpy array, torch tensor, or wp.array (N, 3)
        dL_dpixels: Gradient of loss w.r.t. output pixels (H, W, 3)
        opacity: Opacity values (N, 1) or (N,)
        shs: Spherical harmonics coefficients (N, D, 3) or flattened (N*D, 3)
        scales: Scale parameters (N, 3)
        rotations: Rotation quaternions (N, 4)
        scale_modifier: Global scale modifier (float)
        viewmatrix: View matrix (4, 4)
        projmatrix: Projection matrix (4, 4)
        tan_fovx: Tangent of x field of view
        tan_fovy: Tangent of y field of view
        image_height: Image height
        image_width: Image width
        campos: Camera position (3,)
        radii: Computed radii from forward pass (N,)
        means2D: 2D projected positions from forward pass (N, 2)
        conic_opacity: Conic matrices + opacity from forward pass (N, 4)
        rgb: RGB colors from forward pass (N, 3)
        clamped: Clamping state from forward pass (N, 3)
        cov3Ds: 3D covariance matrices from forward pass (N, 6)
        geom_buffer: Dictionary holding geometric state
        binning_buffer: Dictionary holding binning state
        img_buffer: Dictionary holding image state
        degree: SH degree (0-3)
        debug: Enable debug output

    Returns:
        dict: Dictionary containing gradients for all model parameters:
            - dL_dmean3D: Gradient w.r.t. 3D positions (N, 3)
            - dL_dcolor: Gradient w.r.t. colors (N, 3)
            - dL_dshs: Gradient w.r.t. SH coefficients (N*D, 3)
            - dL_dopacity: Gradient w.r.t. opacity (N,)
            - dL_dscale: Gradient w.r.t. scales (N, 3)
            - dL_drot: Gradient w.r.t. rotations (N, 4)
    """
    # Calculate focal lengths from FoV
    focal_y = image_height / (2.0 * tan_fovy)
    focal_x = image_width / (2.0 * tan_fovx)

    # Convert inputs to warp arrays
    background_warp = background if isinstance(background, wp.vec3) else wp.vec3(background[0], background[1], background[2])
    means3D_warp = to_warp_array(means3D, wp.vec3)
    dL_dpixels_warp = to_warp_array(dL_dpixels, wp.vec3) if not isinstance(dL_dpixels, wp.array) else dL_dpixels

    # Get number of points
    num_points = means3D_warp.shape[0]

    # Convert optional parameters if provided
    opacity_warp = to_warp_array(opacity, float, flatten=True) if opacity is not None else None

    # SH coefficients need special handling for flattening
    if shs is not None:
        sh_data = shs.reshape(-1, 3) if hasattr(shs, 'reshape') and shs.ndim > 2 else shs
        shs_warp = to_warp_array(sh_data, wp.vec3)
    else:
        shs_warp = None

    # Handle other model parameters
    scales_warp = to_warp_array(scales, wp.vec3) if scales is not None else None

    # Handle rotations differently based on shape (matrices vs quaternions)
    if rotations is not None:
        rot_shape = rotations.shape[-1] if hasattr(rotations, 'shape') else rotations.size(-1)
        if rot_shape == 4:  # Quaternions
            rotations_warp = to_warp_array(rotations, wp.vec4)
        else:  # 3x3 matrices
            rotations_warp = to_warp_array(rotations, wp.mat33)
    else:
        rotations_warp = None

    # Handle camera parameters
    viewmatrix_warp = viewmatrix if isinstance(viewmatrix, wp.mat44) else wp.mat44(viewmatrix.flatten())
    projmatrix_warp = projmatrix if isinstance(projmatrix, wp.mat44) else wp.mat44(projmatrix.flatten())
    campos_warp = campos if isinstance(campos, wp.vec3) else wp.vec3(campos[0], campos[1], campos[2])

    # --- Extract data from buffer dictionaries if provided ---
    if img_buffer is not None:
        ranges = img_buffer.get('ranges')
        final_Ts = img_buffer.get('final_Ts')
        n_contrib = img_buffer.get('n_contrib')

    if binning_buffer is not None:
        point_list = binning_buffer.get('point_list')

    if geom_buffer is not None:
        # Use internal data if not provided directly
        if radii is None:
            radii = geom_buffer.get('radii')
        if means2D is None:
            means2D = geom_buffer.get('means2D')
        if conic_opacity is None:
            conic_opacity = geom_buffer.get('conic_opacity')
        if rgb is None:
            rgb = geom_buffer.get('rgb')
        if clamped is None:
            clamped = geom_buffer.get('clamped_state')

    # Convert forward pass outputs to warp arrays if they're not already
    radii_warp = to_warp_array(radii, int) if radii is not None else None
    means2D_warp = to_warp_array(means2D, wp.vec2) if means2D is not None else None
    conic_opacity_warp = to_warp_array(conic_opacity, wp.vec4) if conic_opacity is not None else None
    rgb_warp = to_warp_array(rgb, wp.vec3) if rgb is not None else None
    clamped_warp = to_warp_array(clamped, wp.uint32) if clamped is not None else None

    # --- Initialize output gradient arrays ---
    dL_dmean2D = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)
    dL_dconic = wp.zeros(num_points, dtype=wp.vec4, device=DEVICE)
    dL_dopacity = wp.zeros(num_points, dtype=float, device=DEVICE)
    dL_dcolor = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)

    dL_dmean3D = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)
    # Returned for completeness below; stays zero, since backward_preprocess
    # keeps its intermediate cov3D gradients internal
    dL_dcov3D = wp.zeros(num_points, dtype=VEC6, device=DEVICE)

    # SH gradients depend on degree
    max_sh_coeffs = 16 if degree >= 3 else (degree + 1) * (degree + 1)
    dL_dsh = wp.zeros(num_points * max_sh_coeffs, dtype=wp.vec3, device=DEVICE)

    dL_dscale = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)
    dL_drot = wp.zeros(num_points, dtype=wp.vec4, device=DEVICE)

    tile_grid = wp.vec3((image_width + TILE_M - 1) // TILE_M,
                        (image_height + TILE_N - 1) // TILE_N,
                        1)

    # --- Step 1: Compute loss gradients w.r.t. 2D parameters ---
    backward_render(
        ranges=ranges,
        point_list=point_list,
        width=image_width,
        height=image_height,
        bg_color=background_warp,
        tile_grid=tile_grid,
        points_xy_image=means2D_warp,
        conic_opacity=conic_opacity_warp,
        colors=rgb_warp,
        final_Ts=final_Ts,
        n_contrib=n_contrib,
        dL_dpixels=dL_dpixels_warp,
        dL_dmean2D=dL_dmean2D,
        dL_dconic2D=dL_dconic,
        dL_dopacity=dL_dopacity,
        dL_dcolors=dL_dcolor,
    )

    # --- Step 2: Compute gradients for 3D parameters ---
    backward_preprocess(
        num_points=num_points,
        means=means3D_warp,
        means_2d=means2D_warp,
        radii=radii_warp,
        sh_coeffs=shs_warp,
        scales=scales_warp,
        rotations=rotations_warp,
        viewmatrix=viewmatrix_warp,
        projmatrix=projmatrix_warp,
        fov_x=tan_fovx,
        fov_y=tan_fovy,
        focal_x=focal_x,
        focal_y=focal_y,
        cov3Ds=cov3Ds,
        conic_opacity=conic_opacity_warp,
        campos=campos_warp,
        clamped=clamped_warp,
        dL_dmean2D=dL_dmean2D,
        dL_dconic=dL_dconic,
        dL_dopacity=dL_dopacity,
        dL_dcolors=dL_dcolor,
        dL_dmeans=dL_dmean3D,
        dL_dsh=dL_dsh,
        dL_dscales=dL_dscale,
        dL_drots=dL_drot,
        sh_degree=degree
    )

    # Return all gradients in a dictionary for easy access
    return {
        'dL_dmean3D': dL_dmean3D,
        'dL_dcolor': dL_dcolor,
        'dL_dshs': dL_dsh,
        'dL_dopacity': dL_dopacity,
        'dL_dscale': dL_dscale,
        'dL_drot': dL_drot,
        # Include 2D gradients for completeness
        'dL_dmean2D': dL_dmean2D,
        'dL_dconic': dL_dconic,
        'dL_dcov3D': dL_dcov3D
    }

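# End-to-end sketch (hedged; the exact forward() output names depend on
# gs/forward.py): a typical call site feeds the forward-pass state buffers
# straight back in and hands the resulting gradients to the optimizer:
#
#   grads = backward(
#       background=np.array([1.0, 1.0, 1.0], dtype=np.float32),
#       means3D=means3D, shs=shs, opacity=opacity,
#       scales=scales, rotations=rotations,
#       viewmatrix=view, projmatrix=proj,
#       tan_fovx=tan_fovx, tan_fovy=tan_fovy,
#       image_height=H, image_width=W, campos=campos,
#       cov3Ds=cov3Ds, geom_buffer=geom, binning_buffer=binning,
#       img_buffer=img, dL_dpixels=dL_dpixels, degree=3,
#   )
#   mean_step = grads['dL_dmean3D']  # etc. for the other parameter groups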
gs/config.py
ADDED
@@ -0,0 +1,151 @@
"""
Configuration settings and constants for 3D Gaussian Splatting with NeRF datasets.
"""
import warp as wp
import numpy as np
import random

SEED = 42
random.seed(SEED)

# Warp data types and constants (keep capitalized as they are types)
WP_FLOAT16 = wp.float16
WP_FLOAT32 = wp.float32
WP_INT = wp.int32
WP_VEC2 = wp.vec2
WP_VEC2H = wp.vec2h
VEC6 = wp.types.vector(length=6, dtype=WP_FLOAT32)
DEVICE = "cuda"  # Use "cpu" or "cuda"

TILE_M = wp.constant(16)
TILE_N = wp.constant(16)
TILE_THREADS = wp.constant(256)


class GaussianParams:
    """Parameters for 3D Gaussian Splatting."""

    # Training parameters
    num_iterations = 3 * 7000       # Default number of training iterations
    num_points = 5000               # Initial number of Gaussian points

    # Learning rate scheduler configuration
    use_lr_scheduler = True
    lr_scheduler_config = {
        'lr_pos': 1e-2,             # Initial learning rate for positions
        'lr_scale': 5e-3,           # Initial learning rate for scales
        'lr_rot': 5e-3,             # Initial learning rate for rotations
        'lr_sh': 2e-3,              # Initial learning rate for spherical harmonics
        'lr_opac': 5e-3,            # Initial learning rate for opacities
        'final_lr_factor': 0.01     # Final LR will be 1% of initial LR
    }

    # Optimization parameters
    densification_interval = 100    # Perform densification every N iterations
    pruning_interval = 100          # Perform pruning every N iterations
    opacity_reset_interval = 3000
    save_interval = 300             # Save checkpoint every N iterations
    adam_beta1 = 0.9                # Adam optimizer beta1 parameter
    adam_beta2 = 0.999              # Adam optimizer beta2 parameter
    adam_epsilon = 1e-8             # Adam optimizer epsilon parameter

    densify_grad_threshold = 0.0002
    cull_opacity_threshold = 0.005
    start_prune_iter = 500
    end_prune_iter = 15000
    percent_dense = 0.01
    max_allowed_prune_ratio = 1.0   # No limit on pruning ratio

    # Gaussian parameters
    initial_scale = 0.1             # Initial scale for Gaussian points
    scale_modifier = 1.0            # Scaling factor for Gaussian splats
    sh_degree = 3                   # Spherical harmonics degree

    # Scene parameters
    scene_scale = 1.0               # Scale factor for the scene
    background_color = [1.0, 1.0, 1.0]  # White background for NeRF synthetic

    # Loss parameters
    lambda_dssim = 0.0              # Weight for SSIM loss (1.0 means only SSIM, 0.0 means only L1)

    # Depth loss parameters
    depth_l1_weight_init = 0.0      # Initial weight for depth L1 loss
    depth_l1_weight_final = 0.0     # Final weight for depth L1 loss
    depth_l1_delay_steps = 0        # Number of steps to delay depth loss
    depth_l1_delay_mult = 0.0       # Multiplier for delay rate

    near = 0.01                     # Default near clipping plane
    far = 100.0                     # Default far clipping plane

    @classmethod
    def get_depth_l1_weight(cls, step):
        """Compute the depth L1 loss weight for the current step.

        Args:
            step (int): Current training step

        Returns:
            float: Weight for depth L1 loss
        """
        if step < 0 or (cls.depth_l1_weight_init == 0.0 and cls.depth_l1_weight_final == 0.0):
            # Disable depth loss
            return 0.0

        if cls.depth_l1_delay_steps > 0:
            # Sine ramp from delay_mult up to 1 over the delay window
            delay_rate = cls.depth_l1_delay_mult + (1 - cls.depth_l1_delay_mult) * np.sin(
                0.5 * np.pi * np.clip(step / cls.depth_l1_delay_steps, 0, 1)
            )
        else:
            delay_rate = 1.0

        # Logarithmic interpolation between initial and final weights
        t = np.clip(step / cls.num_iterations, 0, 1)
        log_lerp = np.exp(np.log(cls.depth_l1_weight_init) * (1 - t) + np.log(cls.depth_l1_weight_final) * t)

        return delay_rate * log_lerp

    @classmethod
    def update(cls, **kwargs):
        """Update parameters with new values."""
        for key, value in kwargs.items():
            if hasattr(cls, key):
                setattr(cls, key, value)
            else:
                raise ValueError(f"Unknown parameter: {key}")

    @classmethod
    def get_config_dict(cls):
        """Get parameters as a dictionary."""
        return {
            'num_iterations': cls.num_iterations,
            'num_points': cls.num_points,
            'densification_interval': cls.densification_interval,
            'pruning_interval': cls.pruning_interval,
            'scale_modifier': cls.scale_modifier,
            'sh_degree': cls.sh_degree,
            'background_color': cls.background_color,
            'save_interval': cls.save_interval,
            'adam_beta1': cls.adam_beta1,
            'adam_beta2': cls.adam_beta2,
            'adam_epsilon': cls.adam_epsilon,
            'initial_scale': cls.initial_scale,
            'scene_scale': cls.scene_scale,
            'near': cls.near,
            'far': cls.far,
            'lambda_dssim': cls.lambda_dssim,
            'depth_l1_weight_init': cls.depth_l1_weight_init,
            'depth_l1_weight_final': cls.depth_l1_weight_final,
            'depth_l1_delay_steps': cls.depth_l1_delay_steps,
            'depth_l1_delay_mult': cls.depth_l1_delay_mult,
            'densify_grad_threshold': cls.densify_grad_threshold,
            'cull_opacity_threshold': cls.cull_opacity_threshold,
            'start_prune_iter': cls.start_prune_iter,
            'end_prune_iter': cls.end_prune_iter,
            'use_lr_scheduler': cls.use_lr_scheduler,
            'lr_scheduler_config': cls.lr_scheduler_config,
            'max_allowed_prune_ratio': cls.max_allowed_prune_ratio,
        }
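# Usage sketch: hyperparameters are plain class attributes, so a training
# script can override them in place and snapshot the result, e.g.
#
#   GaussianParams.update(num_points=10000, sh_degree=2)
#   cfg = GaussianParams.get_config_dict()
#   w = GaussianParams.get_depth_l1_weight(step=1000)  # 0.0 with the defaults above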
gs/create_training_video.py
ADDED
@@ -0,0 +1,73 @@
import os
import cv2
import numpy as np
import glob
from tqdm import tqdm

def create_training_video(input_pattern, output_path, fps=10, reverse=False):
    """
    Create a video from training iteration images.

    Args:
        input_pattern: Pattern to match image files (e.g., 'output/steak_is/point_cloud/iteration_*/rendered_view.png')
        output_path: Path to save the output video
        fps: Frames per second for the output video
        reverse: If True, order images from the latest iteration to the earliest
    """
    # Find all matching image files and sort them by iteration number
    image_files = sorted(glob.glob(input_pattern),
                         key=lambda x: int(x.split('iteration_')[1].split('/')[0]),
                         reverse=reverse)

    if not image_files:
        print(f"No images found matching pattern: {input_pattern}")
        return

    print(f"Found {len(image_files)} image files")

    # Read first image to get dimensions
    first_img = cv2.imread(image_files[0])
    h, w, _ = first_img.shape

    # Create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

    # Add each image to the video
    for img_path in tqdm(image_files, desc="Creating video"):
        img = cv2.imread(img_path)

        # Optionally add iteration number as text overlay
        iteration = int(img_path.split('iteration_')[1].split('/')[0])
        cv2.putText(img, f"Iteration {iteration}", (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

        video.write(img)

    # Release the video writer
    video.release()
    print(f"Video created successfully: {output_path}")

# Command-line entry point
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Create a video from training iteration images')
    parser.add_argument('--input', default='output/steak_is/point_cloud/iteration_*/rendered_view.png',
                        help='Pattern to match image files')
    parser.add_argument('--output', default='training_progress.mp4',
                        help='Path to save the output video')
    parser.add_argument('--fps', type=int, default=10,
                        help='Frames per second for the output video')
    parser.add_argument('--reverse', action='store_true',
                        help='Reverse the order of images (show latest first)')

    args = parser.parse_args()

    create_training_video(args.input, args.output, args.fps, reverse=args.reverse)
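# Example invocation (the default --input pattern targets the steak_is scene):
#
#   python gs/create_training_video.py \
#       --input 'output/steak_is/point_cloud/iteration_*/rendered_view.png' \
#       --output training_progress.mp4 --fps 10
#
# Add --reverse to play from the latest checkpoint back to the first.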
gs/dataset_reader.py
ADDED
@@ -0,0 +1 @@

gs/forward.py
ADDED
@@ -0,0 +1,804 @@
import warp as wp
from utils.wp_utils import to_warp_array
from config import *

# Initialize Warp
wp.init()
print("Warp devices:", wp.get_devices())

# Define spherical harmonics constants
SH_C0 = 0.28209479177387814
SH_C1 = 0.4886025119029199


# Native CUDA snippet for bit reinterpretation
float_to_uint32_snippet = """
return reinterpret_cast<uint32_t&>(x);
"""

@wp.func_native(float_to_uint32_snippet)
def float_bits_to_uint32(x: float) -> wp.uint32:
    ...

@wp.func
def ndc2pix(x: float, size: float) -> float:
    return ((x + 1.0) * size - 1.0) * 0.5

@wp.func
def get_rect(p: wp.vec2, max_radius: float, tile_grid: wp.vec3):
    # Extract grid dimensions
    grid_size_x = tile_grid[0]
    grid_size_y = tile_grid[1]

    rect_min_x = wp.min(wp.int32(grid_size_x), wp.int32(wp.max(wp.int32(0), wp.int32((p[0] - max_radius) / float(TILE_M)))))
    rect_min_y = wp.min(wp.int32(grid_size_y), wp.int32(wp.max(wp.int32(0), wp.int32((p[1] - max_radius) / float(TILE_N)))))

    rect_max_x = wp.min(wp.int32(grid_size_x), wp.int32(wp.max(wp.int32(0), wp.int32((p[0] + max_radius + float(TILE_M) - 1.0) / float(TILE_M)))))
    rect_max_y = wp.min(wp.int32(grid_size_y), wp.int32(wp.max(wp.int32(0), wp.int32((p[1] + max_radius + float(TILE_N) - 1.0) / float(TILE_N)))))

    return rect_min_x, rect_min_y, rect_max_x, rect_max_y


@wp.func
def compute_cov2d(p_orig: wp.vec3, cov3d: VEC6, view_matrix: wp.mat44,
                  tan_fovx: float, tan_fovy: float, width: float, height: float) -> wp.vec3:
    t = wp.vec4(p_orig[0], p_orig[1], p_orig[2], 1.0) * view_matrix
    limx = 1.3 * tan_fovx
    limy = 1.3 * tan_fovy
    # Clamp X/Y to stay inside the frustum
    txtz = t[0] / t[2]
    tytz = t[1] / t[2]
    t[0] = min(limx, max(-limx, txtz)) * t[2]
    t[1] = min(limy, max(-limy, tytz)) * t[2]

    focal_x = width / (2.0 * tan_fovx)
    focal_y = height / (2.0 * tan_fovy)
    # Jacobian of the projective transform
    J = wp.mat33(
        focal_x / t[2], 0.0, -(focal_x * t[0]) / (t[2] * t[2]),
        0.0, focal_y / t[2], -(focal_y * t[1]) / (t[2] * t[2]),
        0.0, 0.0, 0.0
    )

    W = wp.mat33(
        view_matrix[0, 0], view_matrix[0, 1], view_matrix[0, 2],
        view_matrix[1, 0], view_matrix[1, 1], view_matrix[1, 2],
        view_matrix[2, 0], view_matrix[2, 1], view_matrix[2, 2]
    )

    T = J * W

    Vrk = wp.mat33(
        cov3d[0], cov3d[1], cov3d[2],
        cov3d[1], cov3d[3], cov3d[4],
        cov3d[2], cov3d[4], cov3d[5]
    )

    cov = T * wp.transpose(Vrk) * wp.transpose(T)

    return wp.vec3(cov[0, 0], cov[0, 1], cov[1, 1])

@wp.func
def compute_cov3d(scale: wp.vec3, scale_mod: float, rot: wp.vec4) -> VEC6:
    # Create scaling matrix with modifier applied
    S = wp.mat33(
        scale_mod * scale[0], 0.0, 0.0,
        0.0, scale_mod * scale[1], 0.0,
        0.0, 0.0, scale_mod * scale[2]
    )
    R = wp.quat_to_matrix(wp.quaternion(rot[0], rot[1], rot[2], rot[3]))
    M = R * S

    # Compute 3D covariance matrix: Sigma = M * M^T
    sigma = M * wp.transpose(M)

    return VEC6(sigma[0, 0], sigma[0, 1], sigma[0, 2], sigma[1, 1], sigma[1, 2], sigma[2, 2])

@wp.kernel
def wp_preprocess(
    orig_points: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    scale_modifier: float,
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),
    degree: int,
    clamped: bool,
    view_matrix: wp.mat44,
    proj_matrix: wp.mat44,
    cam_pos: wp.vec3,
    W: int,
    H: int,
    tan_fovx: float,
    tan_fovy: float,
    focal_x: float,
    focal_y: float,
    radii: wp.array(dtype=int),
    points_xy_image: wp.array(dtype=wp.vec2),
    depths: wp.array(dtype=float),
    cov3Ds: wp.array(dtype=VEC6),
    rgb: wp.array(dtype=wp.vec3),
    conic_opacity: wp.array(dtype=wp.vec4),
    tile_grid: wp.vec3,
    tiles_touched: wp.array(dtype=int),
    clamped_state: wp.array(dtype=wp.vec3),
    prefiltered: bool,
    antialiasing: bool
):
    # One thread per Gaussian
    i = wp.tid()

    p_orig = orig_points[i]
    p_view = wp.vec4(p_orig[0], p_orig[1], p_orig[2], 1.0) * view_matrix

    # Near-plane culling
    if p_view[2] < 0.2:
        return

    p_hom = wp.vec4(p_orig[0], p_orig[1], p_orig[2], 1.0) * proj_matrix
    p_w = 1.0 / (p_hom[3] + 0.0000001)
    p_proj = wp.vec3(p_hom[0] * p_w, p_hom[1] * p_w, p_hom[2] * p_w)

    cov3d = compute_cov3d(scales[i], scale_modifier, rotations[i])
    cov3Ds[i] = cov3d
    # Compute 2D covariance matrix
    cov2d = compute_cov2d(p_orig, cov3d, view_matrix, tan_fovx, tan_fovy, float(W), float(H))

    # Constants
    h_var = 0.3
    W_float = float(W)
    H_float = float(H)
    C = 3  # RGB channels

    # Add blur/antialiasing factor to covariance
    det_cov = cov2d[0] * cov2d[2] - cov2d[1] * cov2d[1]
    cov_with_blur = wp.vec3(cov2d[0] + h_var, cov2d[1], cov2d[2] + h_var)
    det_cov_plus_h_cov = cov_with_blur[0] * cov_with_blur[2] - cov_with_blur[1] * cov_with_blur[1]

    # Invert covariance (EWA algorithm)
    det = det_cov_plus_h_cov
    if det == 0.0:
        return

    det_inv = 1.0 / det
    conic = wp.vec3(
        cov_with_blur[2] * det_inv,
        -cov_with_blur[1] * det_inv,
        cov_with_blur[0] * det_inv
    )
    # Compute eigenvalues of covariance matrix to find screen-space extent
    mid = 0.5 * (cov_with_blur[0] + cov_with_blur[2])
    lambda1 = mid + wp.sqrt(wp.max(0.1, mid * mid - det))
    lambda2 = mid - wp.sqrt(wp.max(0.1, mid * mid - det))
    my_radius = wp.ceil(3.0 * wp.sqrt(wp.max(lambda1, lambda2)))
    # Convert to pixel coordinates
    point_image = wp.vec2(ndc2pix(p_proj[0], W_float), ndc2pix(p_proj[1], H_float))

    # Get rectangle of affected tiles
    rect_min_x, rect_min_y, rect_max_x, rect_max_y = get_rect(point_image, my_radius, tile_grid)

    # Skip if rectangle has 0 area
    if (rect_max_x - rect_min_x) * (rect_max_y - rect_min_y) == 0:
        return

    # Compute color from spherical harmonics
    pos = p_orig
    dir_orig = pos - cam_pos
    dir = wp.normalize(dir_orig)
    x, y, z = dir[0], dir[1], dir[2]

    # Base offset for this Gaussian's SH coefficients
    base_idx = i * 16  # assuming degree 3 (16 coefficients)

    # Start with the DC component (degree 0)
    result = SH_C0 * shs[base_idx]

    # Add higher degree terms if requested
    if degree > 0:
        # Degree 1 terms
        result = result - SH_C1 * y * shs[base_idx + 1] + SH_C1 * z * shs[base_idx + 2] - SH_C1 * x * shs[base_idx + 3]

        if degree > 1:
            # Degree 2 terms
            xx = x * x
            yy = y * y
            zz = z * z
            xy = x * y
            yz = y * z
            xz = x * z

            # Degree 2 terms with hardcoded constants
            result = result + 1.0925484305920792 * xy * shs[base_idx + 4]
            result = result + (-1.0925484305920792) * yz * shs[base_idx + 5]
            result = result + 0.31539156525252005 * (2.0 * zz - xx - yy) * shs[base_idx + 6]
            result = result + (-1.0925484305920792) * xz * shs[base_idx + 7]
            result = result + 0.5462742152960396 * (xx - yy) * shs[base_idx + 8]

            if degree > 2:
                # Degree 3 terms with hardcoded constants
                result = result + (-0.5900435899266435) * y * (3.0 * xx - yy) * shs[base_idx + 9]
                result = result + 2.890611442640554 * xy * z * shs[base_idx + 10]
                result = result + (-0.4570457994644658) * y * (4.0 * zz - xx - yy) * shs[base_idx + 11]
                result = result + 0.3731763325901154 * z * (2.0 * zz - 3.0 * xx - 3.0 * yy) * shs[base_idx + 12]
                result = result + (-0.4570457994644658) * x * (4.0 * zz - xx - yy) * shs[base_idx + 13]
                result = result + 1.445305721320277 * z * (xx - yy) * shs[base_idx + 14]
                result = result + (-0.5900435899266435) * x * (xx - 3.0 * yy) * shs[base_idx + 15]

    result = result + wp.vec3(0.5, 0.5, 0.5)

    # Track which color channels are clamped (using wp.vec3 instead of separate uint32 values):
    # store 1.0 if clamped, 0.0 if not, with separate assignments instead of conditional expressions
    r_clamped = 0.0
    g_clamped = 0.0
    b_clamped = 0.0

    if result[0] < 0.0:
        r_clamped = 1.0
    if result[1] < 0.0:
        g_clamped = 1.0
    if result[2] < 0.0:
        b_clamped = 1.0

    clamped_state[i] = wp.vec3(r_clamped, g_clamped, b_clamped)

    if clamped:
        # RGB colors are clamped to positive values
        result = wp.vec3(
            wp.max(result[0], 0.0),
            wp.max(result[1], 0.0),
            wp.max(result[2], 0.0)
        )

    rgb[i] = result

    # Store computed data
    depths[i] = p_view[2]
    radii[i] = int(my_radius)
    points_xy_image[i] = point_image

    # Pack conic and opacity into a single vec4
    conic_opacity[i] = wp.vec4(conic[0], conic[1], conic[2], opacities[i])
    # Store tile information
    tiles_touched[i] = (rect_max_y - rect_min_y) * (rect_max_x - rect_min_x)

@wp.kernel
def wp_render_gaussians(
    # Output buffers
    rendered_image: wp.array2d(dtype=wp.vec3),
    depth_image: wp.array2d(dtype=float),

    # Tile data
    ranges: wp.array(dtype=wp.vec2i),
    point_list: wp.array(dtype=int),

    # Image parameters
    W: int,
    H: int,

    # Gaussian data
    points_xy_image: wp.array(dtype=wp.vec2),
    colors: wp.array(dtype=wp.vec3),
    conic_opacity: wp.array(dtype=wp.vec4),
    depths: wp.array(dtype=float),

    # Background color
    background: wp.vec3,

    # Tile grid info
    tile_grid: wp.vec3,

    # Track additional data
    final_Ts: wp.array2d(dtype=float),
    n_contrib: wp.array2d(dtype=int),
):
    tile_x, tile_y, tid_x, tid_y = wp.tid()

    # Skip tiles beyond the vertical extent of the grid
    if tile_y >= (H + TILE_N - 1) // TILE_N:
        return

    # Calculate pixel boundaries for this tile
    pix_min_x = tile_x * TILE_M
    pix_min_y = tile_y * TILE_N
    pix_max_x = wp.min(pix_min_x + TILE_M, W)
    pix_max_y = wp.min(pix_min_y + TILE_N, H)

    # Calculate pixel position for this thread
    pix_x = pix_min_x + tid_x
    pix_y = pix_min_y + tid_y

    # Check if this thread processes a valid pixel
    inside = (pix_x < W) and (pix_y < H)
    if not inside:
        return

    pixf_x = float(pix_x)
    pixf_y = float(pix_y)

    # Get start/end range of IDs to process for this tile
    tile_id = tile_y * int(tile_grid[0]) + tile_x
    range_start = ranges[tile_id][0]
    range_end = ranges[tile_id][1]

    # Initialize blending variables
    T = float(1.0)  # Transmittance
    r, g, b = float(0.0), float(0.0), float(0.0)  # Accumulated color
    expected_inv_depth = float(0.0)  # For depth calculation

    # Track the number of contributors to this pixel
    contributor_count = int(0)
    last_contributor = int(0)

    # Iterate over all Gaussians influencing this tile
    for i in range(range_start, range_end):
        # Get Gaussian ID
        gaussian_id = point_list[i]

        # Get Gaussian data
        xy = points_xy_image[gaussian_id]
        con_o = conic_opacity[gaussian_id]
        color = colors[gaussian_id]

        # Compute distance to Gaussian center
        d_x = xy[0] - pixf_x
        d_y = xy[1] - pixf_y

        # Increment contributor count for this pixel
        contributor_count += 1

        # Compute Gaussian power (exponent)
        power = -0.5 * (con_o[0] * d_x * d_x + con_o[2] * d_y * d_y) - con_o[1] * d_x * d_y

        # Skip if power is positive (too far away)
        if power > 0.0:
            continue

        # Compute alpha from power and opacity
        alpha = wp.min(0.99, con_o[3] * wp.exp(power))

        # Skip if alpha is too small
        if alpha < (1.0 / 255.0):
            continue

        # Test if we're close to fully opaque
        test_T = T * (1.0 - alpha)
        if test_T < 0.0001:
            break  # Early termination if pixel is almost opaque

        # Accumulate color contribution
        r += color[0] * alpha * T
        g += color[1] * alpha * T
        b += color[2] * alpha * T

        # Accumulate inverse depth
        expected_inv_depth += (1.0 / depths[gaussian_id]) * alpha * T

        # Update transmittance
        T = test_T

        last_contributor = contributor_count

    # Store final transmittance (T) and contributor count
    final_Ts[pix_y, pix_x] = T
    n_contrib[pix_y, pix_x] = last_contributor

    # Write final color to output buffer (color + background)
    rendered_image[pix_y, pix_x] = wp.vec3(
        r + T * background[0],
        g + T * background[1],
        b + T * background[2]
    )

    # Write depth to output buffer
    depth_image[pix_y, pix_x] = expected_inv_depth

@wp.kernel
def wp_duplicate_with_keys(
    points_xy_image: wp.array(dtype=wp.vec2),
    depths: wp.array(dtype=float),
    point_offsets: wp.array(dtype=int),
    point_list_keys_unsorted: wp.array(dtype=wp.int64),
    point_list_unsorted: wp.array(dtype=int),
    radii: wp.array(dtype=int),
    tile_grid: wp.vec3
):
    tid = wp.tid()

    if tid >= points_xy_image.shape[0]:
        return

    r = radii[tid]
    if r <= 0:
        return

    # Find the global offset into key/value buffers
    offset = 0
    if tid > 0:
        offset = point_offsets[tid - 1]

    pos = points_xy_image[tid]
    depth_val = depths[tid]

    rect_min_x, rect_min_y, rect_max_x, rect_max_y = get_rect(pos, float(r), tile_grid)

    for y in range(rect_min_y, rect_max_y):
        for x in range(rect_min_x, rect_max_x):
            tile_id = y * int(tile_grid[0]) + x
            # Convert to int64 to avoid overflow during bit shift
            tile_id_64 = wp.int64(tile_id)
            shifted = tile_id_64 << wp.int64(32)
            depth_bits = wp.int64(float_bits_to_uint32(depth_val))
            # Combine tile ID (high bits) and depth (low bits) into a single key
            key = wp.int64(shifted) | depth_bits

            point_list_keys_unsorted[offset] = key
            point_list_unsorted[offset] = tid
            offset += 1

@wp.kernel
def wp_identify_tile_ranges(
    num_rendered: int,
    point_list_keys: wp.array(dtype=wp.int64),
    ranges: wp.array(dtype=wp.vec2i)  # Each range is (start, end)
):
    idx = wp.tid()

    if idx >= num_rendered:
        return

    key = point_list_keys[idx]
    curr_tile = int(key >> wp.int64(32))

    # Set start of range if first element or tile changed
    if idx == 0:
        ranges[curr_tile][0] = 0
    else:
        prev_key = point_list_keys[idx - 1]
        prev_tile = int(prev_key >> wp.int64(32))
        if curr_tile != prev_tile:
            ranges[prev_tile][1] = idx
            ranges[curr_tile][0] = idx

    # Set end of range if last element
    if idx == num_rendered - 1:
        ranges[curr_tile][1] = num_rendered


@wp.kernel
def wp_prefix_sum(input_array: wp.array(dtype=int),
                  output_array: wp.array(dtype=int)):
    tid = wp.tid()

    if tid == 0:
        output_array[0] = input_array[0]

    # Perform a serial (single-thread) inclusive prefix sum
    for i in range(1, input_array.shape[0]):
        output_array[i] = output_array[i - 1] + input_array[i]


@wp.kernel
def wp_copy_int64(src: wp.array(dtype=wp.int64), dst: wp.array(dtype=wp.int64), count: int):
    i = wp.tid()
    if i < count:
        dst[i] = src[i]

@wp.kernel
def wp_copy_int(src: wp.array(dtype=int), dst: wp.array(dtype=int), count: int):
    i = wp.tid()
    if i < count:
        dst[i] = src[i]

@wp.kernel
def track_pixel_stats(
    rendered_image: wp.array2d(dtype=wp.vec3),
    depth_image: wp.array2d(dtype=float),
    background: wp.vec3,
    final_Ts: wp.array2d(dtype=float),
    n_contrib: wp.array2d(dtype=int),
    W: int,
    H: int
):
    """Kernel to track final transparency values and contributor counts for each pixel."""
    x, y = wp.tid()

    if x >= W or y >= H:
        return

    # Get the rendered pixel
    pixel = rendered_image[y, x]

    # Approximate alpha transparency by checking for background contribution:
    # if the pixel has no contribution from the background, final_T should be close to 0;
    # if it is mostly background, final_T will be close to 1
    diff_r = abs(pixel[0] - background[0])
    diff_g = abs(pixel[1] - background[1])
    diff_b = abs(pixel[2] - background[2])
    has_content = (diff_r > 0.01) or (diff_g > 0.01) or (diff_b > 0.01)

    if has_content:
        # Approximate final_T; in a real scenario this is already tracked during
        # rendering, this just makes sure it is populated for existing renderings
        if final_Ts[y, x] == 0.0:
            # If final_Ts has not been set during rendering, approximate it:
            # a higher difference from the background means a lower T
            max_diff = max(diff_r, max(diff_g, diff_b))
            final_Ts[y, x] = 1.0 - min(0.99, max_diff)

        # Set n_contrib to 1 if we know the pixel has content but no contributor count
        if n_contrib[y, x] == 0:
            n_contrib[y, x] = 1

def render_gaussians(
    background,
    means3D,
    colors=None,
    opacity=None,
    scales=None,
    rotations=None,
    scale_modifier=1.0,
    viewmatrix=None,
    projmatrix=None,
    tan_fovx=0.5,
    tan_fovy=0.5,
    image_height=256,
    image_width=256,
    sh=None,
    degree=3,
    campos=None,
    prefiltered=False,
    antialiasing=False,
    clamped=True,
    debug=False,
):
    """Render 3D Gaussians using Warp.

    Args:
        background: Background color tensor of shape (3,)
        means3D: 3D positions tensor of shape (N, 3)
        colors: Optional RGB colors tensor of shape (N, 3)
        opacity: Opacity values tensor of shape (N, 1) or (N,)
        scales: Scales tensor of shape (N, 3)
        rotations: Rotation quaternions of shape (N, 4)
        scale_modifier: Global scale modifier (float)
        viewmatrix: View matrix tensor of shape (4, 4)
        projmatrix: Projection matrix tensor of shape (4, 4)
        tan_fovx: Tangent of the horizontal field of view
        tan_fovy: Tangent of the vertical field of view
        image_height: Height of the output image
        image_width: Width of the output image
        sh: Spherical harmonics coefficients tensor of shape (N, D, 3)
        degree: Degree of spherical harmonics
        campos: Camera position tensor of shape (3,)
        prefiltered: Whether input Gaussians are prefiltered
        antialiasing: Whether to apply antialiasing
        clamped: Whether to clamp the colors
        debug: Whether to print debug information

    Returns:
        Tuple of (rendered_image, depth_image, intermediate_buffers)
    """
    rendered_image = wp.zeros((image_height, image_width), dtype=wp.vec3, device=DEVICE)
    depth_image = wp.zeros((image_height, image_width), dtype=float, device=DEVICE)

    # Create additional buffers for tracking transparency and contributors
    final_Ts = wp.zeros((image_height, image_width), dtype=float, device=DEVICE)
    n_contrib = wp.zeros((image_height, image_width), dtype=int, device=DEVICE)

    background_warp = wp.vec3(background[0], background[1], background[2])
    points_warp = to_warp_array(means3D, wp.vec3)
    # SH coefficients should be shape (n, 16, 3);
    # convert to a flattened array but preserve the structure
    sh_data = sh.reshape(-1, 3) if hasattr(sh, 'reshape') else sh
    shs_warp = to_warp_array(sh_data, wp.vec3)

    # Handle other parameters
    opacities_warp = to_warp_array(opacity, float, flatten=True)
    scales_warp = to_warp_array(scales, wp.vec3)
    rotations_warp = to_warp_array(rotations, wp.vec4)

    # Handle camera parameters
    view_matrix_warp = wp.mat44(viewmatrix.flatten()) if not isinstance(viewmatrix, wp.mat44) else viewmatrix
    proj_matrix_warp = wp.mat44(projmatrix.flatten()) if not isinstance(projmatrix, wp.mat44) else projmatrix
    campos_warp = wp.vec3(campos[0], campos[1], campos[2]) if not isinstance(campos, wp.vec3) else campos

    # Calculate tile grid for spatial optimization
    tile_grid = wp.vec3((image_width + TILE_M - 1) // TILE_M,
                        (image_height + TILE_N - 1) // TILE_N,
                        1)

    # Preallocate buffers for preprocessed data
    num_points = points_warp.shape[0]
    radii = wp.zeros(num_points, dtype=int, device=DEVICE)
    points_xy_image = wp.zeros(num_points, dtype=wp.vec2, device=DEVICE)
    depths = wp.zeros(num_points, dtype=float, device=DEVICE)
    cov3Ds = wp.zeros(num_points, dtype=VEC6, device=DEVICE)
    rgb = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)
    conic_opacity = wp.zeros(num_points, dtype=wp.vec4, device=DEVICE)
    tiles_touched = wp.zeros(num_points, dtype=int, device=DEVICE)

    # Buffer to track which color channels are clamped
    clamped_state = wp.zeros(num_points, dtype=wp.vec3, device=DEVICE)

    if debug:
        print(f"\nWARP RENDERING: {image_width}x{image_height} image, {num_points} gaussians")
        print(f"Colors: {'from SH' if colors is None else 'provided'}, SH degree: {degree}")
        print(f"Antialiasing: {antialiasing}, Prefiltered: {prefiltered}")

    # Launch preprocessing kernel
    wp.launch(
        kernel=wp_preprocess,
        dim=(num_points,),
        inputs=[
            points_warp,                      # orig_points
            scales_warp,                      # scales
            scale_modifier,                   # scale_modifier
            rotations_warp,                   # rotations
            opacities_warp,                   # opacities
            shs_warp,                         # shs
            degree,                           # degree
            clamped,                          # clamped
            view_matrix_warp,                 # view_matrix
            proj_matrix_warp,                 # proj_matrix
            campos_warp,                      # cam_pos
            image_width,                      # W
            image_height,                     # H
            tan_fovx,                         # tan_fovx
            tan_fovy,                         # tan_fovy
            image_width / (2.0 * tan_fovx),   # focal_x
            image_height / (2.0 * tan_fovy),  # focal_y
            radii,                            # radii
            points_xy_image,                  # points_xy_image
            depths,                           # depths
            cov3Ds,                           # cov3Ds
            rgb,                              # rgb
            conic_opacity,                    # conic_opacity
            tile_grid,                        # tile_grid
            tiles_touched,                    # tiles_touched
            clamped_state,                    # clamped_state (wp.vec3 per Gaussian)
            prefiltered,                      # prefiltered
            antialiasing                      # antialiasing
        ],
    )
    point_offsets = wp.zeros(num_points, dtype=int, device=DEVICE)
    wp.launch(
        kernel=wp_prefix_sum,
        dim=1,
        inputs=[
            tiles_touched,
            point_offsets
        ]
    )
    num_rendered = int(wp.to_torch(point_offsets)[-1].item())  # total number of duplicated entries
    if num_rendered > (1 << 30):
        # radix sort needs 2x memory
        raise ValueError("Number of rendered points exceeds the maximum supported by Warp.")

    point_list_keys_unsorted = wp.zeros(num_rendered, dtype=wp.int64, device=DEVICE)
    point_list_unsorted = wp.zeros(num_rendered, dtype=int, device=DEVICE)
    point_list_keys = wp.zeros(num_rendered, dtype=wp.int64, device=DEVICE)
    point_list = wp.zeros(num_rendered, dtype=int, device=DEVICE)
    wp.launch(
        kernel=wp_duplicate_with_keys,
        dim=num_points,
        inputs=[
            points_xy_image,
            depths,
            point_offsets,
            point_list_keys_unsorted,
            point_list_unsorted,
            radii,
            tile_grid
        ]
    )
    # Warp's radix sort needs buffers sized for 2x the element count
    point_list_keys_unsorted_padded = wp.zeros(num_rendered * 2, dtype=wp.int64, device=DEVICE)
    point_list_unsorted_padded = wp.zeros(num_rendered * 2, dtype=int, device=DEVICE)

    # Copy data to padded arrays
    wp.copy(point_list_keys_unsorted_padded, point_list_keys_unsorted)
    wp.copy(point_list_unsorted_padded, point_list_unsorted)
    wp.utils.radix_sort_pairs(
        point_list_keys_unsorted_padded,  # keys to sort
        point_list_unsorted_padded,       # values to sort along with keys
        num_rendered                      # number of elements to sort
    )

    wp.launch(
        kernel=wp_copy_int64,
        dim=num_rendered,
        inputs=[
            point_list_keys_unsorted_padded,
            point_list_keys,
            num_rendered
        ]
    )

    wp.launch(
        kernel=wp_copy_int,
        dim=num_rendered,
        inputs=[
            point_list_unsorted_padded,
            point_list,
            num_rendered
        ]
    )

    tile_count = int(tile_grid[0] * tile_grid[1])
    ranges = wp.zeros(tile_count, dtype=wp.vec2i, device=DEVICE)  # each is (start, end)

    if num_rendered > 0:
        wp.launch(
            kernel=wp_identify_tile_ranges,
            dim=num_rendered,
            inputs=[
                num_rendered,
                point_list_keys,
                ranges
            ]
        )

    wp.launch(
        kernel=wp_render_gaussians,
        dim=(int(tile_grid[0]), int(tile_grid[1]), TILE_M, TILE_N),
        inputs=[
            rendered_image,    # Output color image
            depth_image,       # Output depth image
            ranges,            # Tile ranges
            point_list,        # Sorted point indices
            image_width,       # Image width
            image_height,      # Image height
            points_xy_image,   # 2D points
            rgb,               # Precomputed colors
            conic_opacity,     # Conic matrices and opacities
            depths,            # Depth values
            background_warp,   # Background color
            tile_grid,         # Tile grid configuration
            final_Ts,          # Final transparency values
            n_contrib,         # Number of contributors per pixel
        ]
    )

    # Launch the pixel stats tracking kernel as a fallback to make sure
    # final_Ts and n_contrib are populated for already-rendered pixels
    wp.launch(
        kernel=track_pixel_stats,
        dim=(image_width, image_height),
        inputs=[
            rendered_image,
            depth_image,
            background_warp,
            final_Ts,
            n_contrib,
            image_width,
            image_height
        ]
    )

    return rendered_image, depth_image, {
        "radii": radii,
        "point_offsets": point_offsets,
        "points_xy_image": points_xy_image,
        "depths": depths,
        "colors": rgb,
        "cov3Ds": cov3Ds,
        "conic_opacity": conic_opacity,
        "point_list": point_list,
        "ranges": ranges,
        "final_Ts": final_Ts,            # final transmittance per pixel
        "n_contrib": n_contrib,          # contributor count per pixel
        "clamped_state": clamped_state   # per-channel clamp flags
    }
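For orientation, a minimal sketch of driving `render_gaussians` end to end follows. It assumes the repo's `config` module supplies `DEVICE`, `TILE_M`, `TILE_N`, and `VEC6`, and it uses identity camera matrices purely as placeholders, so the output is only a smoke test, not a meaningful image; none of the values below come from the original code.

import numpy as np
from forward import render_gaussians  # assumes gs/ is on the import path

N = 100
means3D = np.random.uniform(-1.0, 1.0, (N, 3)).astype(np.float32)
means3D[:, 2] += 3.0  # keep points past the 0.2 near-plane cull
scales = np.full((N, 3), 0.05, dtype=np.float32)
# Identity rotation in Warp's (x, y, z, w) quaternion layout
rotations = np.tile(np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (N, 1))
opacity = np.full(N, 0.8, dtype=np.float32)
sh = np.zeros((N, 16, 3), dtype=np.float32)
sh[:, 0, :] = np.random.rand(N, 3)  # DC band only

image, depth, buffers = render_gaussians(
    background=np.zeros(3, dtype=np.float32),
    means3D=means3D, opacity=opacity, scales=scales, rotations=rotations,
    viewmatrix=np.eye(4, dtype=np.float32),   # placeholder camera
    projmatrix=np.eye(4, dtype=np.float32),   # placeholder projection
    sh=sh, degree=0, campos=np.zeros(3, dtype=np.float32),
    image_height=256, image_width=256,
)
print(image.numpy().shape, int(buffers["radii"].numpy().max()))

With degree=0 only the DC spherical-harmonic band contributes, but the (N, 16, 3) layout still matches the kernel's fixed base_idx = i * 16 stride.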
gs/lib64
ADDED
@@ -0,0 +1 @@
lib
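One detail of gs/forward.py worth spelling out: wp_duplicate_with_keys packs each (tile, depth) pair into a single int64 so that one radix sort yields, per tile, a front-to-back ordering. Because the depths are positive, their raw IEEE-754 bit patterns sort in the same order as their values. A small CPU-side sketch of the same packing (make_key is a hypothetical helper, not part of the code above):

import numpy as np

def make_key(tile_id, depth):
    # Same layout as the kernel: high 32 bits = tile ID,
    # low 32 bits = raw bit pattern of the float32 depth
    depth_bits = np.frombuffer(np.float32(depth).tobytes(), dtype=np.uint32)[0]
    return (np.int64(tile_id) << np.int64(32)) | np.int64(depth_bits)

entries = [(1, 5.0), (0, 2.5), (1, 0.7), (0, 9.0)]
order = np.argsort([make_key(t, d) for t, d in entries])
print(order)  # [1 3 2 0]: tile 0 (depth 2.5, then 9.0), then tile 1 (0.7, then 5.0)

This is why wp_identify_tile_ranges only needs to scan the sorted keys for tile-ID changes to recover each tile's contiguous (start, end) range.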
gs/loss.py
ADDED
@@ -0,0 +1,303 @@
import warp as wp
import numpy as np
from config import DEVICE
from utils.wp_utils import wp_vec3_mul_element

# Constants for SSIM calculation
C1 = 0.01 ** 2
C2 = 0.03 ** 2
WINDOW_SIZE = 11

@wp.kernel
def l1_loss_kernel(
    rendered: wp.array2d(dtype=wp.vec3),
    target: wp.array2d(dtype=wp.vec3),
    loss_buffer: wp.array(dtype=float),
    width: int,
    height: int
):
    i, j = wp.tid()
    if i >= width or j >= height:
        return

    # Compute L1 difference for each pixel component
    rendered_pixel = rendered[j, i]
    target_pixel = target[j, i]
    diff = wp.abs(rendered_pixel - target_pixel)
    l1_diff = diff[0] + diff[1] + diff[2]

    # Atomic add to loss buffer
    wp.atomic_add(loss_buffer, 0, l1_diff)


@wp.kernel
def gaussian_kernel(
    kernel: wp.array(dtype=float),
    sigma: float,
    kernel_size: int
):
    i = wp.tid()
    if i >= kernel_size:
        return

    center = kernel_size // 2
    x = i - center
    kernel[i] = wp.exp(-1.0 * float(x * x) / (2.0 * sigma * sigma))

@wp.kernel
def ssim_kernel(
    rendered: wp.array2d(dtype=wp.vec3),
    target: wp.array2d(dtype=wp.vec3),
    gaussian_weights: wp.array(dtype=float),
    ssim_buffer: wp.array(dtype=float),
    width: int,
    height: int,
    window_size: int
):
    i, j = wp.tid()
    if i >= width or j >= height:
        return

    # Constants for numerical stability
    c1 = 0.01 * 0.01
    c2 = 0.03 * 0.03

    # SSIM is computed in a local window around each pixel
    half_window = window_size // 2

    # Initialize accumulators
    mu1 = wp.vec3(0.0, 0.0, 0.0)
    mu2 = wp.vec3(0.0, 0.0, 0.0)
    sigma1 = wp.vec3(0.0, 0.0, 0.0)
    sigma2 = wp.vec3(0.0, 0.0, 0.0)
    sigma12 = wp.vec3(0.0, 0.0, 0.0)
    weight_sum = float(0.0)

    # Calculate weighted means and variances over the window
    for y in range(max(0, j - half_window), min(height, j + half_window + 1)):
        for x in range(max(0, i - half_window), min(width, i + half_window + 1)):
            # Get Gaussian weight for this position
            wy = abs(y - j)
            wx = abs(x - i)
            if wx <= half_window and wy <= half_window:
                w = gaussian_weights[wx] * gaussian_weights[wy]

                # Get pixels
                p1 = rendered[y, x]
                p2 = target[y, x]

                # Accumulate weighted values
                mu1 += p1 * w
                mu2 += p2 * w
                sigma1 += wp_vec3_mul_element(p1, p1) * w
                sigma2 += wp_vec3_mul_element(p2, p2) * w
                sigma12 += wp_vec3_mul_element(p1, p2) * w
                weight_sum += w

    # Normalize by weights
    if weight_sum > 0.0:
        mu1 /= weight_sum
        mu2 /= weight_sum
        sigma1 /= weight_sum
        sigma2 /= weight_sum
        sigma12 /= weight_sum

    # Calculate variance and covariance
    sigma1 = sigma1 - wp_vec3_mul_element(mu1, mu1)
    sigma2 = sigma2 - wp_vec3_mul_element(mu2, mu2)
    sigma12 = sigma12 - wp_vec3_mul_element(mu1, mu2)

    # Calculate SSIM for each channel
    ssim_r = ((2.0 * mu1[0] * mu2[0] + c1) * (2.0 * sigma12[0] + c2)) / ((mu1[0] * mu1[0] + mu2[0] * mu2[0] + c1) * (sigma1[0] + sigma2[0] + c2))
    ssim_g = ((2.0 * mu1[1] * mu2[1] + c1) * (2.0 * sigma12[1] + c2)) / ((mu1[1] * mu1[1] + mu2[1] * mu2[1] + c1) * (sigma1[1] + sigma2[1] + c2))
    ssim_b = ((2.0 * mu1[2] * mu2[2] + c1) * (2.0 * sigma12[2] + c2)) / ((mu1[2] * mu1[2] + mu2[2] * mu2[2] + c1) * (sigma1[2] + sigma2[2] + c2))

    # Average SSIM across channels
    ssim_val = (ssim_r + ssim_g + ssim_b) / 3.0

    # Atomic add to SSIM buffer
    wp.atomic_add(ssim_buffer, 0, ssim_val)

@wp.kernel
def backprop_l1_pixel_gradients(
    rendered: wp.array2d(dtype=wp.vec3),
    target: wp.array2d(dtype=wp.vec3),
    pixel_grad: wp.array2d(dtype=wp.vec3),
    width: int,
    height: int,
    l1_weight: float
):
    i, j = wp.tid()
    if i >= width or j >= height:
        return

    # Compute gradient (sign function for L1 loss)
    rendered_pixel = rendered[j, i]
    target_pixel = target[j, i]

    # Sign function for L1 gradient
    l1_grad = wp.vec3(
        l1_weight * wp.sign(rendered_pixel[0] - target_pixel[0]),
        l1_weight * wp.sign(rendered_pixel[1] - target_pixel[1]),
        l1_weight * wp.sign(rendered_pixel[2] - target_pixel[2])
    )

    # Store L1 gradients
    pixel_grad[j, i] = l1_grad

def l1_loss(rendered, target):
    """Compute L1 loss between rendered and target images"""
    height, width = rendered.shape[0], rendered.shape[1]

    # Create device arrays if not already
    if not isinstance(rendered, wp.array):
        d_rendered = wp.array(rendered, dtype=wp.vec3, device=DEVICE)
    else:
        d_rendered = rendered

    if not isinstance(target, wp.array):
        d_target = wp.array(target, dtype=wp.vec3, device=DEVICE)
    else:
        d_target = target

    # Create loss buffer
    loss_buffer = wp.zeros(1, dtype=float, device=DEVICE)

    # Compute loss
    wp.launch(
        kernel=l1_loss_kernel,
        dim=(width, height),
        inputs=[d_rendered, d_target, loss_buffer, width, height]
    )

    # Get loss value, normalized by pixel count and channels
    loss = float(loss_buffer.numpy()[0]) / (width * height * 3)
    return loss

def ssim(rendered, target):
    """Compute SSIM between rendered and target images"""
    height, width = rendered.shape[0], rendered.shape[1]

    # Create device arrays if not already
    if not isinstance(rendered, wp.array):
        d_rendered = wp.array(rendered, dtype=wp.vec3, device=DEVICE)
    else:
        d_rendered = rendered

    if not isinstance(target, wp.array):
        d_target = wp.array(target, dtype=wp.vec3, device=DEVICE)
    else:
        d_target = target

    # Precompute Gaussian kernel
    kernel_size = WINDOW_SIZE
    gaussian_weights = wp.zeros(kernel_size, dtype=float, device=DEVICE)
    wp.launch(
        gaussian_kernel,
        dim=kernel_size,
        inputs=[gaussian_weights, 1.5, kernel_size]
    )

    # Create SSIM buffer
    ssim_buffer = wp.zeros(1, dtype=float, device=DEVICE)

    # Compute SSIM
    wp.launch(
        ssim_kernel,
        dim=(width, height),
        inputs=[d_rendered, d_target, gaussian_weights, ssim_buffer, width, height, kernel_size]
    )

    # Get SSIM value (average over all pixels)
    ssim_val = float(ssim_buffer.numpy()[0]) / (width * height)
    return ssim_val

def compute_image_gradients(rendered, target, lambda_dssim=0.2):
    """Compute gradients for combined L1 and SSIM loss"""
    height, width = rendered.shape[0], rendered.shape[1]

    # Create device arrays if not already
    if not isinstance(rendered, wp.array):
        d_rendered = wp.array(rendered, dtype=wp.vec3, device=DEVICE)
    else:
        d_rendered = rendered

    if not isinstance(target, wp.array):
        d_target = wp.array(target, dtype=wp.vec3, device=DEVICE)
    else:
        d_target = target

    # Create gradient buffer
    pixel_grad = wp.zeros((height, width), dtype=wp.vec3, device=DEVICE)

    # Compute L1 loss gradient
    l1_weight = (1.0 - lambda_dssim) / (height * width * 3.0)
    wp.launch(
        backprop_l1_pixel_gradients,
        dim=(width, height),
        inputs=[d_rendered, d_target, pixel_grad, width, height, l1_weight]
    )

    # TODO: Add SSIM gradient
    return pixel_grad


@wp.kernel
def depth_loss_kernel(
    rendered_depth: wp.array2d(dtype=float),
    target_depth: wp.array2d(dtype=float),
    depth_mask: wp.array2d(dtype=float),
    loss_buffer: wp.array(dtype=float),
    width: int,
    height: int
):
    i, j = wp.tid()
    if i >= width or j >= height:
        return

    # Get depths and mask
    rendered_inv_depth = rendered_depth[j, i]
    target_inv_depth = target_depth[j, i]
    mask = depth_mask[j, i]

    # Compute L1 difference for inverse depths
    diff = wp.abs(rendered_inv_depth - target_inv_depth) * mask

    # Atomic add to loss buffer
    wp.atomic_add(loss_buffer, 0, diff)

def depth_loss(rendered_depth, target_depth, depth_mask):
    """Compute L1 loss between rendered and target inverse depths"""
    height, width = rendered_depth.shape[0], rendered_depth.shape[1]

    # Create device arrays if not already
    if not isinstance(rendered_depth, wp.array):
        d_rendered_depth = wp.array(rendered_depth, dtype=float, device=DEVICE)
    else:
        d_rendered_depth = rendered_depth

    if not isinstance(target_depth, wp.array):
        d_target_depth = wp.array(target_depth, dtype=float, device=DEVICE)
    else:
        d_target_depth = target_depth

    if not isinstance(depth_mask, wp.array):
        d_depth_mask = wp.array(depth_mask, dtype=float, device=DEVICE)
    else:
        d_depth_mask = depth_mask

    # Create loss buffer
    loss_buffer = wp.zeros(1, dtype=float, device=DEVICE)

    # Compute loss
    wp.launch(
        kernel=depth_loss_kernel,
        dim=(width, height),
        inputs=[d_rendered_depth, d_target_depth, d_depth_mask, loss_buffer, width, height]
    )

    # Get loss value, normalized by pixel count
    loss = float(loss_buffer.numpy()[0]) / (width * height)
    return loss
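A minimal sketch of how these pieces might be combined into the usual 3DGS objective, L = (1 - λ)·L1 + λ·(1 - SSIM) with λ = 0.2. The random images below are placeholders; only the function names come from the file above:

import numpy as np
import warp as wp
from loss import l1_loss, ssim, compute_image_gradients

H, W = 64, 64
rendered = wp.array(np.random.rand(H, W, 3).astype(np.float32), dtype=wp.vec3)
target = wp.array(np.random.rand(H, W, 3).astype(np.float32), dtype=wp.vec3)

lambda_dssim = 0.2
total = (1.0 - lambda_dssim) * l1_loss(rendered, target) + lambda_dssim * (1.0 - ssim(rendered, target))

# Per-pixel gradient of the L1 term (the SSIM gradient is still a TODO above)
pixel_grad = compute_image_gradients(rendered, target, lambda_dssim)
print("loss:", total, "grad shape:", pixel_grad.shape)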
gs/optimizer.py
ADDED
@@ -0,0 +1,399 @@
| 1 |
+
import warp as wp
|
| 2 |
+
from utils.wp_utils import to_warp_array, wp_vec3_mul_element, wp_vec3_add_element, wp_vec3_sqrt, wp_vec3_div_element, wp_vec3_clamp
|
| 3 |
+
from config import *
|
| 4 |
+
|
| 5 |
+
@wp.kernel
|
| 6 |
+
def adam_update(
|
| 7 |
+
# Parameters
|
| 8 |
+
positions: wp.array(dtype=wp.vec3),
|
| 9 |
+
scales: wp.array(dtype=wp.vec3),
|
| 10 |
+
rotations: wp.array(dtype=wp.vec4),
|
| 11 |
+
opacities: wp.array(dtype=float),
|
| 12 |
+
shs: wp.array(dtype=wp.vec3),
|
| 13 |
+
|
| 14 |
+
# Gradients
|
| 15 |
+
pos_grads: wp.array(dtype=wp.vec3),
|
| 16 |
+
scale_grads: wp.array(dtype=wp.vec3),
|
| 17 |
+
rot_grads: wp.array(dtype=wp.vec4),
|
| 18 |
+
opacity_grads: wp.array(dtype=float),
|
| 19 |
+
sh_grads: wp.array(dtype=wp.vec3),
|
| 20 |
+
|
| 21 |
+
# First moments (m)
|
| 22 |
+
m_positions: wp.array(dtype=wp.vec3),
|
| 23 |
+
m_scales: wp.array(dtype=wp.vec3),
|
| 24 |
+
m_rotations: wp.array(dtype=wp.vec4),
|
| 25 |
+
m_opacities: wp.array(dtype=float),
|
| 26 |
+
m_shs: wp.array(dtype=wp.vec3),
|
| 27 |
+
|
| 28 |
+
# Second moments (v)
|
| 29 |
+
v_positions: wp.array(dtype=wp.vec3),
|
| 30 |
+
v_scales: wp.array(dtype=wp.vec3),
|
| 31 |
+
v_rotations: wp.array(dtype=wp.vec4),
|
| 32 |
+
v_opacities: wp.array(dtype=float),
|
| 33 |
+
v_shs: wp.array(dtype=wp.vec3),
|
| 34 |
+
|
| 35 |
+
num_points: int,
|
| 36 |
+
lr_pos: float,
|
| 37 |
+
lr_scale: float,
|
| 38 |
+
lr_rot: float,
|
| 39 |
+
lr_opac: float,
|
| 40 |
+
lr_sh: float,
|
| 41 |
+
beta1: float,
|
| 42 |
+
beta2: float,
|
| 43 |
+
epsilon: float,
|
| 44 |
+
iteration: int
|
| 45 |
+
):
|
| 46 |
+
i = wp.tid()
|
| 47 |
+
if i >= num_points:
|
| 48 |
+
return
|
| 49 |
+
|
| 50 |
+
# Bias correction terms
|
| 51 |
+
bias_correction1 = 1.0 - wp.pow(beta1, float(iteration + 1))
|
| 52 |
+
bias_correction2 = 1.0 - wp.pow(beta2, float(iteration + 1))
|
| 53 |
+
|
| 54 |
+
# Update positions
|
| 55 |
+
m_positions[i] = beta1 * m_positions[i] + (1.0 - beta1) * pos_grads[i]
|
| 56 |
+
# Use the helper function for element-wise multiplication
|
| 57 |
+
v_positions[i] = beta2 * v_positions[i] + (1.0 - beta2) * wp_vec3_mul_element(pos_grads[i], pos_grads[i])
|
| 58 |
+
# Use distinct names for corrected moments per parameter type
|
| 59 |
+
m_pos_corrected = m_positions[i] / bias_correction1
|
| 60 |
+
v_pos_corrected = v_positions[i] / bias_correction2
|
| 61 |
+
# Use the helper function for element-wise sqrt and division
|
| 62 |
+
denominator_pos = wp_vec3_sqrt(v_pos_corrected) + wp.vec3(epsilon, epsilon, epsilon)
|
| 63 |
+
positions[i] = positions[i] - lr_pos * wp_vec3_div_element(m_pos_corrected, denominator_pos)
|
| 64 |
+
|
| 65 |
+
# Update scales (with some constraints to keep them positive)
|
| 66 |
+
m_scales[i] = beta1 * m_scales[i] + (1.0 - beta1) * scale_grads[i]
|
| 67 |
+
# Use the helper function for element-wise multiplication
|
| 68 |
+
v_scales[i] = beta2 * v_scales[i] + (1.0 - beta2) * wp_vec3_mul_element(scale_grads[i], scale_grads[i])
|
| 69 |
+
# Use distinct names for corrected moments per parameter type
|
| 70 |
+
m_scale_corrected = m_scales[i] / bias_correction1
|
| 71 |
+
v_scale_corrected = v_scales[i] / bias_correction2
|
| 72 |
+
# Use the helper function for element-wise sqrt and division
|
| 73 |
+
denominator_scale = wp_vec3_sqrt(v_scale_corrected) + wp.vec3(epsilon, epsilon, epsilon)
|
| 74 |
+
scale_update = lr_scale * wp_vec3_div_element(m_scale_corrected, denominator_scale)
|
| 75 |
+
scales[i] = wp.vec3(
|
| 76 |
+
wp.max(scales[i][0] - scale_update[0], 0.001),
|
| 77 |
+
wp.max(scales[i][1] - scale_update[1], 0.001),
|
| 78 |
+
wp.max(scales[i][2] - scale_update[2], 0.001)
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Update rotations
|
| 82 |
+
m_rotations[i] = beta1 * m_rotations[i] + (1.0 - beta1) * rot_grads[i]
|
| 83 |
+
# Element-wise multiplication for quaternions
|
| 84 |
+
v_rotations[i] = beta2 * v_rotations[i] + (1.0 - beta2) * wp.vec4(
|
| 85 |
+
rot_grads[i][0] * rot_grads[i][0],
|
| 86 |
+
rot_grads[i][1] * rot_grads[i][1],
|
| 87 |
+
rot_grads[i][2] * rot_grads[i][2],
|
| 88 |
+
rot_grads[i][3] * rot_grads[i][3]
|
| 89 |
+
)
|
| 90 |
+
m_rot_corrected = m_rotations[i] / bias_correction1
|
| 91 |
+
v_rot_corrected = v_rotations[i] / bias_correction2
|
| 92 |
+
# Element-wise sqrt and division for quaternions
|
| 93 |
+
denominator_rot = wp.vec4(
|
| 94 |
+
wp.sqrt(v_rot_corrected[0]) + epsilon,
|
| 95 |
+
wp.sqrt(v_rot_corrected[1]) + epsilon,
|
| 96 |
+
wp.sqrt(v_rot_corrected[2]) + epsilon,
|
| 97 |
+
wp.sqrt(v_rot_corrected[3]) + epsilon
|
| 98 |
+
)
|
| 99 |
+
rot_update = wp.vec4(
|
| 100 |
+
lr_rot * m_rot_corrected[0] / denominator_rot[0],
|
| 101 |
+
lr_rot * m_rot_corrected[1] / denominator_rot[1],
|
| 102 |
+
lr_rot * m_rot_corrected[2] / denominator_rot[2],
|
| 103 |
+
lr_rot * m_rot_corrected[3] / denominator_rot[3]
|
| 104 |
+
)
|
| 105 |
+
rotations[i] = rotations[i] - rot_update
|
| 106 |
+
|
| 107 |
+
# Normalize quaternion to ensure it's a valid rotation
|
| 108 |
+
quat_length = wp.sqrt(rotations[i][0]*rotations[i][0] +
|
| 109 |
+
rotations[i][1]*rotations[i][1] +
|
| 110 |
+
rotations[i][2]*rotations[i][2] +
|
| 111 |
+
rotations[i][3]*rotations[i][3])
|
| 112 |
+
|
| 113 |
+
if quat_length > 0.0:
|
| 114 |
+
rotations[i] = wp.vec4(
|
| 115 |
+
rotations[i][0] / quat_length,
|
| 116 |
+
rotations[i][1] / quat_length,
|
| 117 |
+
rotations[i][2] / quat_length,
|
| 118 |
+
rotations[i][3] / quat_length
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# Update opacity (with clamping to [0,1])
|
| 122 |
+
m_opacities[i] = beta1 * m_opacities[i] + (1.0 - beta1) * opacity_grads[i]
|
| 123 |
+
# Opacity is scalar, direct multiplication is fine
|
| 124 |
+
v_opacities[i] = beta2 * v_opacities[i] + (1.0 - beta2) * (opacity_grads[i] * opacity_grads[i])
|
| 125 |
+
# Use distinct names for corrected moments per parameter type
|
| 126 |
+
m_opacity_corrected = m_opacities[i] / bias_correction1
|
| 127 |
+
v_opacity_corrected = v_opacities[i] / bias_correction2
|
| 128 |
+
# Opacity is scalar, direct wp.sqrt is fine here
|
| 129 |
+
opacity_update = lr_opac * m_opacity_corrected / (wp.sqrt(v_opacity_corrected) + epsilon)
|
| 130 |
+
opacities[i] = wp.max(wp.min(opacities[i] - opacity_update, 1.0), 0.0)
|
| 131 |
+
|
| 132 |
+
# Update SH coefficients
|
| 133 |
+
for j in range(16):
|
| 134 |
+
idx = i * 16 + j
|
| 135 |
+
m_shs[idx] = beta1 * m_shs[idx] + (1.0 - beta1) * sh_grads[idx]
|
| 136 |
+
# Use the helper function for element-wise multiplication
|
| 137 |
+
v_shs[idx] = beta2 * v_shs[idx] + (1.0 - beta2) * wp_vec3_mul_element(sh_grads[idx], sh_grads[idx])
|
| 138 |
+
# Use distinct names for corrected moments per parameter type
|
| 139 |
+
m_sh_corrected = m_shs[idx] / bias_correction1
|
| 140 |
+
v_sh_corrected = v_shs[idx] / bias_correction2
|
| 141 |
+
# Use the helper function for element-wise sqrt and division
|
| 142 |
+
denominator_sh = wp_vec3_sqrt(v_sh_corrected) + wp.vec3(epsilon, epsilon, epsilon)
|
| 143 |
+
shs[idx] = shs[idx] - lr_sh * wp_vec3_div_element(m_sh_corrected, denominator_sh)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@wp.kernel
def reset_opacities(
    opacities: wp.array(dtype=float),
    max_opacity: float,
    num_points: int
):
    """Reset opacities to prevent oversaturation."""
    i = wp.tid()
    if i >= num_points:
        return

    # Reset opacity to the given target value (the caller passes a small one)
    opacities[i] = max_opacity


@wp.kernel
def reset_densification_stats(
    xyz_gradient_accum: wp.array(dtype=float),
    denom: wp.array(dtype=float),
    max_radii2D: wp.array(dtype=float),
    num_points: int
):
    """Reset densification statistics after parameter count changes."""
    i = wp.tid()
    if i >= num_points:
        return

    xyz_gradient_accum[i] = 0.0
    denom[i] = 0.0
    max_radii2D[i] = 0.0


@wp.kernel
def mark_split_candidates(
    grads: wp.array(dtype=float),
    scales: wp.array(dtype=wp.vec3),
    grad_threshold: float,
    scene_extent: float,
    percent_dense: float,
    split_mask: wp.array(dtype=int),
    num_points: int
):
    """Mark large Gaussians with high gradients for splitting."""
    i = wp.tid()
    if i >= num_points:
        return

    # Check if the gradient exceeds the threshold
    high_grad = grads[i] >= grad_threshold

    # Check if the Gaussian is large (max scale > threshold)
    max_scale = wp.max(wp.max(scales[i][0], scales[i][1]), scales[i][2])
    scale_threshold = percent_dense * scene_extent
    large_gaussian = max_scale > scale_threshold

    # Mark for splitting only if both conditions are met
    if high_grad and large_gaussian:
        split_mask[i] = 1
    else:
        split_mask[i] = 0


@wp.kernel
def mark_clone_candidates(
    grads: wp.array(dtype=float),
    scales: wp.array(dtype=wp.vec3),
    grad_threshold: float,
    scene_extent: float,
    percent_dense: float,
    clone_mask: wp.array(dtype=int),
    num_points: int
):
    """Mark small Gaussians with high gradients for cloning."""
    i = wp.tid()
    if i >= num_points:
        return

    # Check if the gradient exceeds the threshold
    high_grad = grads[i] >= grad_threshold

    # Check if the Gaussian is small (max scale <= threshold)
    max_scale = wp.max(wp.max(scales[i][0], scales[i][1]), scales[i][2])
    scale_threshold = percent_dense * scene_extent
    small_gaussian = max_scale <= scale_threshold

    # Mark for cloning only if both conditions are met
    if high_grad and small_gaussian:
        clone_mask[i] = 1
    else:
        clone_mask[i] = 0


@wp.kernel
def split_gaussians(
    split_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),
    N_split: int,
    scale_factor: float,
    offset: int,
    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3)
):
    """Split large Gaussians into multiple smaller ones."""
    i = wp.tid()
    num_points = positions.shape[0]  # arrays expose .shape inside Warp kernels

    # Copy the original Gaussians first
    if i < num_points:
        out_positions[i] = positions[i]
        out_scales[i] = scales[i]
        out_rotations[i] = rotations[i]
        out_opacities[i] = opacities[i]

        # Copy SH coefficients
        for j in range(16):
            out_shs[i * 16 + j] = shs[i * 16 + j]

    # Handle splits
    if i >= num_points:
        return

    if split_mask[i] == 1:
        # prefix_sum is an exclusive scan of split_mask, so prefix_sum[i]
        # is this candidate's index among all split candidates
        split_idx = prefix_sum[i]

        # Create N_split new Gaussians
        for j in range(N_split):
            new_idx = offset + split_idx * N_split + j
            if new_idx < out_positions.shape[0]:
                # Scale down the original Gaussian
                scaled_scales = wp.vec3(
                    scales[i][0] * scale_factor,
                    scales[i][1] * scale_factor,
                    scales[i][2] * scale_factor
                )

                # Add a small random positional offset
                random_offset = wp.vec3(
                    (wp.randf(wp.uint32(new_idx * 3)) * 2.0 - 1.0) * 0.01,
                    (wp.randf(wp.uint32(new_idx * 3 + 1)) * 2.0 - 1.0) * 0.01,
                    (wp.randf(wp.uint32(new_idx * 3 + 2)) * 2.0 - 1.0) * 0.01
                )

                out_positions[new_idx] = positions[i] + random_offset
                out_scales[new_idx] = scaled_scales
                out_rotations[new_idx] = rotations[i]
                out_opacities[new_idx] = opacities[i]

                # Copy SH coefficients
                for k in range(16):
                    out_shs[new_idx * 16 + k] = shs[i * 16 + k]


@wp.kernel
def clone_gaussians(
    clone_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),  # shape: [N * 16]
    noise_scale: float,
    offset: int,  # where to start writing new points (= original point count)
    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3),
):
    """Clone marked Gaussians, jittering the copies slightly."""
    i = wp.tid()
    if i >= offset:
        return

    # Copy the original Gaussian to out[i]
    out_positions[i] = positions[i]
    out_scales[i] = scales[i]
    out_rotations[i] = rotations[i]
    out_opacities[i] = opacities[i]
    for j in range(16):
        out_shs[i * 16 + j] = shs[i * 16 + j]

    if clone_mask[i] == 1:
        # Exclusive prefix sum gives this clone's slot among all clones
        base_idx = prefix_sum[i] + offset
        pos = positions[i]
        scale = scales[i]
        rot = rotations[i]
        opac = opacities[i]

        # Jitter the clone's position slightly
        noise = wp.vec3(
            wp.randf(wp.uint32(i * 3)) * noise_scale,
            wp.randf(wp.uint32(i * 3 + 1)) * noise_scale,
            wp.randf(wp.uint32(i * 3 + 2)) * noise_scale
        )

        out_positions[base_idx] = pos + noise
        out_scales[base_idx] = scale
        out_rotations[base_idx] = rot
        out_opacities[base_idx] = opac

        for j in range(16):
            out_shs[base_idx * 16 + j] = shs[i * 16 + j]


@wp.kernel
def prune_gaussians(
    opacities: wp.array(dtype=float),
    opacity_threshold: float,
    valid_mask: wp.array(dtype=int),
    num_points: int
):
    """Mark Gaussians for keeping (1) or removal (0) based on opacity."""
    i = wp.tid()
    if i >= num_points:
        return
    if opacities[i] > opacity_threshold:
        valid_mask[i] = 1
    else:
        valid_mask[i] = 0


@wp.kernel
def compact_gaussians(
    valid_mask: wp.array(dtype=int),
    prefix_sum: wp.array(dtype=int),
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),  # shape: [N * 16]
    out_positions: wp.array(dtype=wp.vec3),
    out_scales: wp.array(dtype=wp.vec3),
    out_rotations: wp.array(dtype=wp.vec4),
    out_opacities: wp.array(dtype=float),
    out_shs: wp.array(dtype=wp.vec3)
):
    """Scatter surviving Gaussians to their compacted indices."""
    i = wp.tid()
    if valid_mask[i] == 0:
        return

    # The exclusive prefix sum of valid_mask gives the compacted write index
    new_i = prefix_sum[i]

    out_positions[new_i] = positions[i]
    out_scales[new_i] = scales[i]
    out_rotations[new_i] = rotations[i]
    out_opacities[new_i] = opacities[i]

    for j in range(16):
        out_shs[new_i * 16 + j] = shs[i * 16 + j]

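The densification and pruning kernels above all share one mask → exclusive-scan → scatter pattern. A minimal sketch of that pattern, using a toy mask rather than the trainer's real buffers (the toy values are illustrative, not from the repo):

import warp as wp

wp.init()

mask = wp.array([0, 1, 0, 1, 1], dtype=int)         # e.g. "keep" flags
prefix = wp.zeros_like(mask)
wp.utils.array_scan(mask, prefix, inclusive=False)  # exclusive scan -> [0, 0, 1, 1, 2]

# With an exclusive scan the last element excludes the final mask entry,
# so the total survivor count is prefix[-1] + mask[-1] (here 2 + 1 = 3).
total = int(prefix.numpy()[-1] + mask.numpy()[-1])
# For every i with mask[i] == 1, prefix[i] is its compacted write index.
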
gs/render.py
ADDED
@@ -0,0 +1,141 @@
import numpy as np
import warp as wp
import matplotlib.pyplot as plt
import math
from forward import render_gaussians
from utils.math_utils import world_to_view, projection_matrix

# Initialize Warp
wp.init()

def setup_example_scene(image_width=1800, image_height=1800, fovx=45.0, fovy=45.0, znear=0.01, zfar=100.0):
    """Setup an example scene with camera and Gaussians for testing and debugging."""
    # Camera setup (row-vector convention: matrices are stored transposed)
    T = np.array([0, 0, 5], dtype=np.float32)
    R = np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]], dtype=np.float32)
    world_to_camera = np.eye(4, dtype=np.float32)
    world_to_camera[:3, :3] = R
    world_to_camera[:3, 3] = T
    world_to_camera = world_to_camera.T

    # Compute matrices
    view_matrix = world_to_view(R=R, t=T)
    proj_matrix = projection_matrix(fovx=fovx, fovy=fovy, znear=znear, zfar=zfar).T
    full_proj_matrix = world_to_camera @ proj_matrix

    camera_center = np.linalg.inv(world_to_camera)[3, :3]

    # Compute FOV parameters (fovx/fovy are given in degrees, so convert
    # to radians before taking the tangent)
    tan_fovx = math.tan(math.radians(fovx) * 0.5)
    tan_fovy = math.tan(math.radians(fovy) * 0.5)

    focal_x = image_width / (2 * tan_fovx)
    focal_y = image_height / (2 * tan_fovy)

    camera_params = {
        'R': R,
        'T': T,
        'camera_center': camera_center,
        'view_matrix': view_matrix,
        'proj_matrix': proj_matrix,
        'world_to_camera': world_to_camera,
        'full_proj_matrix': full_proj_matrix,
        'tan_fovx': tan_fovx,
        'tan_fovy': tan_fovy,
        'focal_x': focal_x,
        'focal_y': focal_y,
        'width': image_width,
        'height': image_height
    }

    # Gaussian setup - three points in a row, middle one raised
    pts = np.array([[-5, 0, -10], [0, 2, -10], [5, 0, -10]], dtype=np.float32)
    n = len(pts)

    # Hard-coded SH coefficients for reproducible debugging
    shs = np.array([[0.71734341, 0.91905449, 0.49961076],
                    [0.08068483, 0.82132256, 0.01301602],
                    [0.8335743, 0.31798138, 0.19709007],
                    [0.82589597, 0.28206231, 0.790489],
                    [0.24008527, 0.21312673, 0.53132892],
                    [0.19493135, 0.37989934, 0.61886235],
                    [0.98106522, 0.28960672, 0.57313965],
                    [0.92623716, 0.46034381, 0.5485369],
                    [0.81660616, 0.7801104, 0.27813915],
                    [0.96114063, 0.69872817, 0.68313804],
                    [0.95464185, 0.21984855, 0.92912192],
                    [0.23503135, 0.29786121, 0.24999751],
                    [0.29844887, 0.6327788, 0.05423596],
                    [0.08934335, 0.11851827, 0.04186001],
                    [0.59331831, 0.919777, 0.71364335],
                    [0.83377388, 0.40242542, 0.8792624]] * n).reshape(n, 16, 3)

    opacities = np.ones((n, 1), dtype=np.float32)

    # Random anisotropic scales (each axis in [0.2, 1.7])
    scales = (0.2 + 1.5 * np.random.rand(n, 3)).astype(np.float32)

    # Random rotations as unit quaternions
    q = np.random.randn(n, 4).astype(np.float32)
    rotations = q / np.linalg.norm(q, axis=1, keepdims=True)

    colors = np.ones((n, 3), dtype=np.float32)

    return pts, shs, scales, colors, rotations, opacities, camera_params

if __name__ == "__main__":
    # Setup rendering parameters
    image_width = 1800
    image_height = 1800
    background = np.array([0.0, 0.0, 0.0], dtype=np.float32)  # Black background
    scale_modifier = 1.0
    sh_degree = 3
    prefiltered = False
    antialiasing = False
    clamped = True

    # Create the example scene
    pts, shs, scales, colors, rotations, opacities, camera_params = setup_example_scene(
        image_width=image_width,
        image_height=image_height
    )
    n = len(pts)
    print(f"Created example scene with {n} Gaussians")

    # Call the Gaussian rasterizer
    rendered_image, depth_image, _ = render_gaussians(
        background=background,
        means3D=pts,
        colors=colors,
        opacity=opacities,
        scales=scales,
        rotations=rotations,
        scale_modifier=scale_modifier,
        viewmatrix=camera_params['view_matrix'],
        projmatrix=camera_params['full_proj_matrix'],
        tan_fovx=camera_params['tan_fovx'],
        tan_fovy=camera_params['tan_fovy'],
        image_height=image_height,
        image_width=image_width,
        sh=shs,
        degree=sh_degree,
        campos=camera_params['camera_center'],
        prefiltered=prefiltered,
        antialiasing=antialiasing,
        clamped=clamped,
        debug=False
    )

    print("Rendering completed")

    # Copy the rendered image from device to host
    rendered_array = wp.to_torch(rendered_image).cpu().numpy()

    # Display and save using matplotlib
    plt.figure(figsize=(10, 10))
    plt.imshow(rendered_array)
    plt.axis('off')
    plt.savefig("example_render.png", bbox_inches='tight', dpi=150)
    print("Rendered image saved to example_render.png")
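Since the camera matrices in render.py are stored transposed (row-vector convention), a world point is projected as [x y z 1] @ full_proj_matrix. A small sketch for sanity-checking that convention; the NDC-to-pixel mapping shown here is an assumption about what the rasterizer in forward.py does, not taken from it:

import numpy as np

def project_point(p_world, full_proj_matrix, width, height):
    """Project a world point using the row-vector convention above."""
    p = np.append(np.asarray(p_world, dtype=np.float32), 1.0) @ full_proj_matrix
    ndc = p[:3] / p[3]                  # perspective divide
    px = (ndc[0] + 1.0) * 0.5 * width   # NDC [-1, 1] -> pixel x (assumed mapping)
    py = (ndc[1] + 1.0) * 0.5 * height  # NDC [-1, 1] -> pixel y (assumed mapping)
    return px, py, p[3]
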
gs/scheduler.py
ADDED
@@ -0,0 +1,28 @@
import math

class LRScheduler:
    """Simple exponential decay learning rate scheduler."""

    def __init__(self, initial_lr, final_lr_factor=0.01):
        """
        Args:
            initial_lr: Starting learning rate
            final_lr_factor: Final LR as a fraction of the initial LR
                (e.g. 0.01 means final_lr = 0.01 * initial_lr)
        """
        self.initial_lr = initial_lr
        self.final_lr = initial_lr * final_lr_factor

    def get_lr(self, iteration, total_iterations):
        """Get the learning rate for a given iteration using exponential decay."""
        if total_iterations <= 1:
            return self.initial_lr

        # Fractional progress through training, clamped to [0, 1]
        progress = iteration / (total_iterations - 1)
        progress = min(progress, 1.0)

        # Exponential interpolation: lr = initial * (final/initial)^progress
        lr_ratio = self.final_lr / self.initial_lr
        current_lr = self.initial_lr * (lr_ratio ** progress)

        return current_lr
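A quick usage sketch of LRScheduler (the numbers are just what the decay formula yields for these hypothetical settings):

sched = LRScheduler(initial_lr=1e-2, final_lr_factor=0.01)
for it in (0, 5000, 9999):
    print(it, sched.get_lr(it, total_iterations=10000))
# 0    -> 1.0e-2   (0.01 * 0.01^0)
# 5000 -> ~1.0e-3  (0.01 * 0.01^(5000/9999))
# 9999 -> 1.0e-4   (0.01 * 0.01^1)
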
gs/train.py
ADDED
@@ -0,0 +1,1044 @@
import os
import numpy as np
import matplotlib.pyplot as plt
import warp as wp
import imageio
import json
from tqdm import tqdm
from pathlib import Path
import argparse

from forward import render_gaussians
from backward import backward
from optimizer import prune_gaussians, adam_update, clone_gaussians, compact_gaussians, mark_split_candidates, mark_clone_candidates, split_gaussians, reset_opacities, reset_densification_stats
from config import *
from utils.camera_utils import load_camera
from utils.point_cloud_utils import save_ply
from loss import l1_loss, compute_image_gradients
from scheduler import LRScheduler

# Initialize Warp
wp.init()

# Kernels for parameter updates
@wp.kernel
def init_gaussian_params(
    positions: wp.array(dtype=wp.vec3),
    scales: wp.array(dtype=wp.vec3),
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    shs: wp.array(dtype=wp.vec3),
    num_points: int,
    init_scale: float
):
    i = wp.tid()
    if i >= num_points:
        return

    # Initialize positions with random values in [-1.3, 1.3] per axis
    offset = wp.vec3(
        wp.randf(wp.uint32(i * 3)) * 2.6 - 1.3,
        wp.randf(wp.uint32(i * 3 + 1)) * 2.6 - 1.3,
        wp.randf(wp.uint32(i * 3 + 2)) * 2.6 - 1.3
    )
    positions[i] = offset

    # Initialize scales
    scales[i] = wp.vec3(init_scale, init_scale, init_scale)

    # Initialize rotations to the identity quaternion
    rotations[i] = wp.vec4(1.0, 0.0, 0.0, 0.0)

    # Initialize opacities
    opacities[i] = 0.1

    # Initialize SH coefficients (just the DC term for now)
    for j in range(16):  # degree=3, 16 coefficients in total
        idx = i * 16 + j
        if j == 0:
            shs[idx] = wp.vec3(-0.007, -0.007, -0.007)
        else:
            shs[idx] = wp.vec3(0.0, 0.0, 0.0)

@wp.kernel
def zero_gradients(
    pos_grad: wp.array(dtype=wp.vec3),
    scale_grad: wp.array(dtype=wp.vec3),
    rot_grad: wp.array(dtype=wp.vec4),
    opacity_grad: wp.array(dtype=float),
    sh_grad: wp.array(dtype=wp.vec3),
    num_points: int
):
    i = wp.tid()
    if i >= num_points:
        return

    pos_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    scale_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    rot_grad[i] = wp.vec4(0.0, 0.0, 0.0, 0.0)
    opacity_grad[i] = 0.0

    # Zero SH gradients
    for j in range(16):
        idx = i * 16 + j
        sh_grad[idx] = wp.vec3(0.0, 0.0, 0.0)

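# The trainer below ties the pipeline together: render_gaussians (forward
# pass), backward (gradients), adam_update and the mark/scan/compact kernels
# from optimizer.py (parameter updates, densification, pruning), with
# per-parameter learning rates optionally decayed by LRScheduler.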
class NeRFGaussianSplattingTrainer:
    def __init__(self, dataset_path, output_path, config=None):
        """Initialize the 3D Gaussian Splatting trainer using pure Warp for a NeRF dataset."""
        self.dataset_path = Path(dataset_path)
        self.output_path = Path(output_path)
        self.output_path.mkdir(parents=True, exist_ok=True)

        # Initialize configuration from GaussianParams
        self.config = GaussianParams.get_config_dict()

        if config is not None:
            self.config.update(config)

        # Initialize learning rate scheduler
        self.lr_scheduler = self.create_lr_scheduler()
        print(f"Learning rate scheduler: {'Enabled' if self.lr_scheduler else 'Disabled'}")

        # For tracking learning rates
        self.learning_rate_history = {
            'positions': [],
            'scales': [],
            'rotations': [],
            'shs': [],
            'opacities': []
        }
        # Load NeRF dataset
        print(f"Loading NeRF dataset from {self.dataset_path}")
        self.cameras, self.image_paths = self.load_nerf_data("train")
        self.val_cameras, self.val_image_paths = self.load_nerf_data("val")
        self.test_cameras, self.test_image_paths = self.load_nerf_data("test")
        print(f"Loaded {len(self.cameras)} train cameras and {len(self.image_paths)} train images")
        print(f"Loaded {len(self.val_cameras)} val cameras and {len(self.val_image_paths)} val images")
        print(f"Loaded {len(self.test_cameras)} test cameras and {len(self.test_image_paths)} test images")

        # Calculate scene extent for densification
        self.scene_extent = self.calculate_scene_extent()
        print(f"Calculated scene extent: {self.scene_extent}")

        # Initialize parameters
        self.num_points = self.config['num_points']
        self.params = self.initialize_parameters()

        # Create gradient arrays
        self.grads = self.create_gradient_arrays()

        # Create optimizer state
        self.adam_m = self.create_gradient_arrays()  # First moment
        self.adam_v = self.create_gradient_arrays()  # Second moment

        # Initialize densification state tracking
        self.init_densification_state()

        # For tracking loss
        self.losses = []

        # Initialize intermediate buffers dictionary
        self.intermediate_buffers = {}

        # Track iteration for opacity reset
        self.opacity_reset_at = -32768

    def create_lr_scheduler(self):
        """Create simple learning rate schedulers for each parameter type."""
        if not self.config['use_lr_scheduler']:
            return None

        config = self.config['lr_scheduler_config']
        final_factor = config['final_lr_factor']

        schedulers = {
            'positions': LRScheduler(config['lr_pos'], final_factor),
            'scales': LRScheduler(config['lr_scale'], final_factor),
            'rotations': LRScheduler(config['lr_rot'], final_factor),
            'shs': LRScheduler(config['lr_sh'], final_factor),
            'opacities': LRScheduler(config['lr_opac'], final_factor)
        }

        return schedulers

    def initialize_parameters(self):
        """Initialize Gaussian parameters."""
        positions = wp.zeros(self.num_points, dtype=wp.vec3)
        scales = wp.zeros(self.num_points, dtype=wp.vec3)
        rotations = wp.zeros(self.num_points, dtype=wp.vec4)
        opacities = wp.zeros(self.num_points, dtype=float)
        shs = wp.zeros(self.num_points * 16, dtype=wp.vec3)  # 16 coeffs per point
        # Launch kernel to initialize parameters
        wp.launch(
            init_gaussian_params,
            dim=self.num_points,
            inputs=[positions, scales, rotations, opacities, shs, self.num_points, self.config['initial_scale']]
        )

        # Return parameters as a dictionary
        return {
            'positions': positions,
            'scales': scales,
            'rotations': rotations,
            'opacities': opacities,
            'shs': shs
        }

    def create_gradient_arrays(self):
        """Create arrays for gradients or optimizer state."""
        positions = wp.zeros(self.num_points, dtype=wp.vec3)
        scales = wp.zeros(self.num_points, dtype=wp.vec3)
        rotations = wp.zeros(self.num_points, dtype=wp.vec4)
        opacities = wp.zeros(self.num_points, dtype=float)
        shs = wp.zeros(self.num_points * 16, dtype=wp.vec3)

        # Return a dictionary of arrays
        return {
            'positions': positions,
            'scales': scales,
            'rotations': rotations,
            'opacities': opacities,
            'shs': shs
        }

    def calculate_scene_extent(self):
        """Calculate the extent of the scene based on camera positions."""
        if not self.cameras:
            return 1.0  # Default fallback

        # Extract camera positions
        camera_positions = []
        for camera in self.cameras:
            camera_positions.append(camera['camera_center'])

        camera_positions = np.array(camera_positions)

        # Centroid of all camera positions
        scene_center = np.mean(camera_positions, axis=0)

        # Maximum distance from any camera to the scene center
        max_distance_to_center = 0.0
        for pos in camera_positions:
            distance = np.linalg.norm(pos - scene_center)
            max_distance_to_center = max(max_distance_to_center, distance)

        # The scene extent is the radius of this bounding sphere,
        # clamped to at least 1.0 in case the computed extent is too small
        extent = max_distance_to_center * self.config.get('camera_extent_factor', 1.0)
        return max(extent, 1.0)

    def init_densification_state(self):
        """Initialize state tracking for densification."""
        self.xyz_gradient_accum = wp.zeros(self.num_points, dtype=float, device=DEVICE)
        self.denom = wp.zeros(self.num_points, dtype=float, device=DEVICE)
        self.max_radii2D = wp.zeros(self.num_points, dtype=float, device=DEVICE)

    def load_nerf_data(self, datasplit):
        """Load camera parameters and images from a NeRF dataset."""
        # Read the transforms JSON for this split
        transforms_path = self.dataset_path / f"transforms_{datasplit}.json"
        if not transforms_path.exists():
            raise FileNotFoundError(f"No transforms_{datasplit}.json found in {self.dataset_path}")

        with open(transforms_path, 'r') as f:
            transforms = json.load(f)

        # Get image dimensions from the first image if available
        first_frame = transforms['frames'][0]
        first_img_path = str(self.dataset_path / f"{first_frame['file_path']}.png")
        if os.path.exists(first_img_path):
            # Load the first image to get dimensions
            img = imageio.imread(first_img_path)
            width = img.shape[1]
            height = img.shape[0]
            print(f"Using image dimensions from dataset: {width}x{height}")
        else:
            # Use default dimensions from config if the image is not found
            width = self.config['width']
            height = self.config['height']
            print(f"Using default dimensions: {width}x{height}")

        # Update config with the actual dimensions
        self.config['width'] = width
        self.config['height'] = height

        self.config['camera_angle_x'] = transforms['camera_angle_x']

        # Calculate focal length from the horizontal field of view
        focal = 0.5 * width / np.tan(0.5 * self.config['camera_angle_x'])

        cameras = []
        image_paths = []

        # Process each frame
        for i, frame in enumerate(transforms['frames']):
            camera_info = {
                "camera_id": i,
                "camera_to_world": frame['transform_matrix'],
                "width": width,
                "height": height,
                "focal": focal,
            }

            # Load camera parameters using the existing helper
            camera_params = load_camera(camera_info)

            if camera_params is not None:
                cameras.append(camera_params)
                image_paths.append(str(self.dataset_path / f"{frame['file_path']}.png"))

        return cameras, image_paths

    def load_image(self, path):
        """Load an image as a numpy array."""
        if os.path.exists(path):
            img = imageio.imread(path)
            # Convert to float and normalize to [0, 1]
            img_np = img.astype(np.float32) / 255.0
            # Ensure the image is RGB (discard alpha channel if present)
            if img_np.shape[2] == 4:
                img_np = img_np[:, :, :3]  # Keep only R, G, B channels
            return img_np
        else:
            raise FileNotFoundError(f"Image not found: {path}")

    def zero_grad(self):
        """Zero out all gradients."""
        wp.launch(
            zero_gradients,
            dim=self.num_points,
            inputs=[
                self.grads['positions'],
                self.grads['scales'],
                self.grads['rotations'],
                self.grads['opacities'],
                self.grads['shs'],
                self.num_points
            ]
        )

    def densification_and_pruning(self, iteration):
        """Perform densification (clone/split) and pruning of Gaussians."""

        densify_from_iter = self.config.get('densify_from_iter', 500)
        densify_until_iter = self.config.get('densify_until_iter', 15000)
        densification_interval = self.config.get('densification_interval', 100)
        opacity_reset_interval = self.config.get('opacity_reset_interval', 3000)

        # Densify only inside the configured iteration range, every
        # densification_interval iterations
        if iteration > densify_from_iter and iteration < densify_until_iter and iteration % densification_interval == 0:
            print(f"Iteration {iteration}: Performing densification and pruning")

            # Simplified implementation: use position-gradient norms as a
            # proxy for view-space gradients
            pos_grads = self.grads['positions']
            avg_grads = wp.zeros(self.num_points, dtype=float, device=DEVICE)

            @wp.kernel
            def compute_grad_norms(pos_grad: wp.array(dtype=wp.vec3),
                                   grad_norms: wp.array(dtype=float),
                                   num_points: int):
                i = wp.tid()
                if i >= num_points:
                    return
                grad_norms[i] = wp.length(pos_grad[i])

            wp.launch(compute_grad_norms, dim=self.num_points,
                      inputs=[pos_grads, avg_grads, self.num_points])

            # Configuration
            grad_threshold = self.config.get('densify_grad_threshold', 0.0002)
            percent_dense = self.config.get('percent_dense', 0.01)

            # --- Step 1: Clone small Gaussians with high gradients ---
            clone_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)
            wp.launch(
                mark_clone_candidates,
                dim=self.num_points,
                inputs=[
                    avg_grads,
                    self.params['scales'],
                    grad_threshold,
                    self.scene_extent,
                    percent_dense,
                    clone_mask,
                    self.num_points
                ]
            )

            # Perform cloning. array_scan with inclusive=False is an
            # exclusive scan, so the total count is the last scan value
            # plus the last mask entry.
            clone_prefix_sum = wp.zeros_like(clone_mask)
            wp.utils.array_scan(clone_mask, clone_prefix_sum, inclusive=False)
            total_to_clone = int(clone_prefix_sum.numpy()[-1] + clone_mask.numpy()[-1])

            if total_to_clone > 0:
                print(f"[Clone] Cloning {total_to_clone} small Gaussians")
                N = self.num_points
                new_N = N + total_to_clone

                # Allocate output arrays
                out_params = {
                    'positions': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'scales': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'rotations': wp.zeros(new_N, dtype=wp.vec4, device=DEVICE),
                    'opacities': wp.zeros(new_N, dtype=float, device=DEVICE),
                    'shs': wp.zeros(new_N * 16, dtype=wp.vec3, device=DEVICE)
                }

                # Clone Gaussians
                wp.launch(
                    clone_gaussians,
                    dim=N,
                    inputs=[
                        clone_mask,
                        clone_prefix_sum,
                        self.params['positions'],
                        self.params['scales'],
                        self.params['rotations'],
                        self.params['opacities'],
                        self.params['shs'],
                        0.01,  # noise_scale
                        N,     # offset
                        out_params['positions'],
                        out_params['scales'],
                        out_params['rotations'],
                        out_params['opacities'],
                        out_params['shs']
                    ]
                )

                # Update parameters and state (moments are reset on resize)
                self.params = out_params
                self.num_points = new_N
                self.grads = self.create_gradient_arrays()
                self.adam_m = self.create_gradient_arrays()
                self.adam_v = self.create_gradient_arrays()

            # If cloning grew the point set, pad avg_grads so the split
            # kernel below does not read past its end (new points start
            # with zero accumulated gradient)
            if self.num_points > avg_grads.shape[0]:
                padded = wp.zeros(self.num_points, dtype=float, device=DEVICE)
                wp.copy(padded, avg_grads, count=avg_grads.shape[0])
                avg_grads = padded

            # --- Step 2: Split large Gaussians with high gradients ---
            split_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)
            wp.launch(
                mark_split_candidates,
                dim=self.num_points,
                inputs=[
                    avg_grads,
                    self.params['scales'],
                    grad_threshold,
                    self.scene_extent,
                    percent_dense,
                    split_mask,
                    self.num_points
                ]
            )

            # Perform splitting (same exclusive-scan counting as above)
            split_prefix_sum = wp.zeros_like(split_mask)
            wp.utils.array_scan(split_mask, split_prefix_sum, inclusive=False)
            total_to_split = int(split_prefix_sum.numpy()[-1] + split_mask.numpy()[-1])

            if total_to_split > 0:
                print(f"[Split] Splitting {total_to_split} large Gaussians")
                N = self.num_points
                N_split = 2  # Split each marked Gaussian into 2
                new_N = N + total_to_split * N_split

                # Allocate output arrays
                out_params = {
                    'positions': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'scales': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'rotations': wp.zeros(new_N, dtype=wp.vec4, device=DEVICE),
                    'opacities': wp.zeros(new_N, dtype=float, device=DEVICE),
                    'shs': wp.zeros(new_N * 16, dtype=wp.vec3, device=DEVICE)
                }

                # Split Gaussians
                wp.launch(
                    split_gaussians,
                    dim=N,
                    inputs=[
                        split_mask,
                        split_prefix_sum,
                        self.params['positions'],
                        self.params['scales'],
                        self.params['rotations'],
                        self.params['opacities'],
                        self.params['shs'],
                        N_split,  # Number of splits per Gaussian
                        0.8,      # scale_factor
                        N,        # offset
                        out_params['positions'],
                        out_params['scales'],
                        out_params['rotations'],
                        out_params['opacities'],
                        out_params['shs']
                    ]
                )

                # Update parameters and state
                self.params = out_params
                self.num_points = new_N
                self.grads = self.create_gradient_arrays()
                self.adam_m = self.create_gradient_arrays()
                self.adam_v = self.create_gradient_arrays()

                # Remove the original (pre-split) Gaussians
                prune_filter = wp.zeros(self.num_points, dtype=int, device=DEVICE)

                @wp.kernel
                def mark_split_originals_for_removal(
                    split_mask: wp.array(dtype=int),
                    prune_filter: wp.array(dtype=int),
                    offset: int,
                    num_points: int
                ):
                    i = wp.tid()
                    if i >= num_points:
                        return
                    if i < offset and split_mask[i] == 1:
                        prune_filter[i] = 1  # Mark for removal
                    else:
                        prune_filter[i] = 0  # Keep

                wp.launch(mark_split_originals_for_removal, dim=self.num_points,
                          inputs=[split_mask, prune_filter, N, self.num_points])

                # Invert the removal mask to get a keep mask
                valid_mask = wp.zeros_like(prune_filter)

                @wp.kernel
                def invert_mask(prune: wp.array(dtype=int), valid: wp.array(dtype=int), n: int):
                    i = wp.tid()
                    if i >= n:
                        return
                    valid[i] = 1 - prune[i]

                wp.launch(invert_mask, dim=self.num_points,
                          inputs=[prune_filter, valid_mask, self.num_points])

                # Count valid points and compact (exclusive scan again)
                prefix_sum = wp.zeros_like(valid_mask)
                wp.utils.array_scan(valid_mask, prefix_sum, inclusive=False)
                valid_count = int(prefix_sum.numpy()[-1] + valid_mask.numpy()[-1])

                if valid_count < self.num_points:
                    print(f"[Split] Removing {self.num_points - valid_count} original split Gaussians")

                    # Allocate compacted output
                    compact_params = {
                        'positions': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                        'scales': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                        'rotations': wp.zeros(valid_count, dtype=wp.vec4, device=DEVICE),
                        'opacities': wp.zeros(valid_count, dtype=float, device=DEVICE),
                        'shs': wp.zeros(valid_count * 16, dtype=wp.vec3, device=DEVICE)
                    }

                    wp.launch(
                        compact_gaussians,
                        dim=self.num_points,
                        inputs=[
                            valid_mask,
                            prefix_sum,
                            self.params['positions'],
                            self.params['scales'],
                            self.params['rotations'],
                            self.params['opacities'],
                            self.params['shs'],
                            compact_params['positions'],
                            compact_params['scales'],
                            compact_params['rotations'],
                            compact_params['opacities'],
                            compact_params['shs']
                        ]
                    )

                    # Update parameters and state
                    self.params = compact_params
                    self.num_points = valid_count
                    self.grads = self.create_gradient_arrays()
                    self.adam_m = self.create_gradient_arrays()
                    self.adam_v = self.create_gradient_arrays()

            # --- Step 3: Enhanced pruning ---
            print(f"[Prune] Performing enhanced pruning")

            valid_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)

            # Use opacity-based pruning for now
            wp.launch(
                prune_gaussians,
                dim=self.num_points,
                inputs=[
                    self.params['opacities'],
                    self.config.get('cull_opacity_threshold', 0.005),
                    valid_mask,
                    self.num_points
                ]
            )

            # Count valid points (exclusive scan again)
            prefix_sum = wp.zeros_like(valid_mask)
            wp.utils.array_scan(valid_mask, prefix_sum, inclusive=False)
            valid_count = int(prefix_sum.numpy()[-1] + valid_mask.numpy()[-1])

            # Check pruning constraints
            min_valid_points = self.config.get('min_valid_points', 1000)
            max_valid_points = self.config.get('max_valid_points', 1000000)
            max_prune_ratio = self.config.get('max_allowed_prune_ratio', 0.5)

            prune_count = self.num_points - valid_count
            prune_ratio = prune_count / self.num_points if self.num_points > 0 else 0

            if (valid_count >= min_valid_points and
                valid_count <= max_valid_points and
                prune_ratio <= max_prune_ratio and
                valid_count < self.num_points):

                print(f"[Prune] Compacting from {self.num_points} → {valid_count} points")

                # Allocate compacted output
                out_params = {
                    'positions': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                    'scales': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                    'rotations': wp.zeros(valid_count, dtype=wp.vec4, device=DEVICE),
                    'opacities': wp.zeros(valid_count, dtype=float, device=DEVICE),
                    'shs': wp.zeros(valid_count * 16, dtype=wp.vec3, device=DEVICE)
                }

                wp.launch(
                    compact_gaussians,
                    dim=self.num_points,
                    inputs=[
                        valid_mask,
                        prefix_sum,
                        self.params['positions'],
                        self.params['scales'],
                        self.params['rotations'],
                        self.params['opacities'],
                        self.params['shs'],
                        out_params['positions'],
                        out_params['scales'],
                        out_params['rotations'],
                        out_params['opacities'],
                        out_params['shs']
                    ]
                )

                # Update parameters and state
                self.params = out_params
                self.num_points = valid_count
                self.grads = self.create_gradient_arrays()
                self.adam_m = self.create_gradient_arrays()
                self.adam_v = self.create_gradient_arrays()
            else:
                print(f"[Prune] Skipping pruning: valid={valid_count}, ratio={prune_ratio:.3f}")

        # Opacity reset - logic follows the reference implementation: reset
        # periodically, and once at densify_from_iter if the background is white
        background_is_white = all(c == 1.0 for c in self.config['background_color'])
        should_reset_opacity = (
            iteration % opacity_reset_interval == 0 or
            (background_is_white and iteration == densify_from_iter)
        )

        if should_reset_opacity:
            print(f"Iteration {iteration}: Resetting opacities")
            wp.launch(
                reset_opacities,
                dim=self.num_points,
                inputs=[
                    self.params['opacities'],
                    0.01,  # reset target opacity
                    self.num_points
                ]
            )

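    # Densification recap: small Gaussians with large positional gradients
    # are cloned, large ones with large gradients are split into N_split
    # smaller copies (the originals are then removed), Gaussians whose
    # opacity drops below cull_opacity_threshold are pruned, and opacities
    # are periodically reset to a small value so pruning can reconsider them.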
    def optimizer_step(self, iteration):
        """Perform an Adam optimization step."""

        # Get learning rates from the scheduler or fall back to config defaults
        if self.lr_scheduler:
            lr_pos = self.lr_scheduler['positions'].get_lr(iteration, self.config['num_iterations'])
            lr_scale = self.lr_scheduler['scales'].get_lr(iteration, self.config['num_iterations'])
            lr_rot = self.lr_scheduler['rotations'].get_lr(iteration, self.config['num_iterations'])
            lr_sh = self.lr_scheduler['shs'].get_lr(iteration, self.config['num_iterations'])
            lr_opac = self.lr_scheduler['opacities'].get_lr(iteration, self.config['num_iterations'])

            # Track learning rate history
            self.learning_rate_history['positions'].append(lr_pos)
            self.learning_rate_history['scales'].append(lr_scale)
            self.learning_rate_history['rotations'].append(lr_rot)
            self.learning_rate_history['shs'].append(lr_sh)
            self.learning_rate_history['opacities'].append(lr_opac)

            # Log learning rates occasionally
            if iteration % 1000 == 0:
                print(f"Iteration {iteration} learning rates:")
                print(f"  positions: {lr_pos:.6f}")
                print(f"  scales: {lr_scale:.6f}")
                print(f"  rotations: {lr_rot:.6f}")
                print(f"  shs: {lr_sh:.6f}")
                print(f"  opacities: {lr_opac:.6f}")
        else:
            # Use static learning rates from config
            lr_pos = self.config['lr_pos']
            lr_scale = self.config['lr_scale']
            lr_rot = self.config['lr_rot']
            lr_sh = self.config['lr_sh']
            lr_opac = self.config['lr_opac']

        wp.launch(
            adam_update,
            dim=self.num_points,
            inputs=[
                # Parameters
                self.params['positions'],
                self.params['scales'],
                self.params['rotations'],
                self.params['opacities'],
                self.params['shs'],

                # Gradients
                self.grads['positions'],
                self.grads['scales'],
                self.grads['rotations'],
                self.grads['opacities'],
                self.grads['shs'],

                # First moments (m)
                self.adam_m['positions'],
                self.adam_m['scales'],
                self.adam_m['rotations'],
                self.adam_m['opacities'],
                self.adam_m['shs'],

                # Second moments (v)
                self.adam_v['positions'],
                self.adam_v['scales'],
                self.adam_v['rotations'],
                self.adam_v['opacities'],
                self.adam_v['shs'],

                # Optimizer parameters with per-type dynamic learning rates
                self.num_points,
                lr_pos,
                lr_scale,
                lr_rot,
                lr_sh,
                lr_opac,
                self.config['adam_beta1'],
                self.config['adam_beta2'],
                self.config['adam_epsilon'],
                iteration  # Adam timestep used for bias correction
            ]
        )

    def save_checkpoint(self, iteration):
        """Save the current point cloud and training state."""
        checkpoint_dir = self.output_path / "point_cloud" / f"iteration_{iteration}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Save point cloud as PLY
        ply_path = checkpoint_dir / "point_cloud.ply"
        save_ply(self.params, ply_path, self.num_points)

        # Save loss history
        loss_path = self.output_path / "loss.txt"
        with open(loss_path, 'w') as f:
            for loss in self.losses:
                f.write(f"{loss}\n")

        # Save loss plot
        plt.figure(figsize=(10, 5))
        plt.plot(self.losses)
        plt.title('Training Loss')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.savefig(self.output_path / "loss_plot.png")
        plt.close()

        # Save a rendered view
        camera_idx = 0  # Front view
        rendered_image, _, _ = render_gaussians(
            background=np.array(self.config['background_color'], dtype=np.float32),
            means3D=self.params['positions'].numpy(),
            colors=None,  # Use SH coefficients instead
            opacity=self.params['opacities'].numpy(),
            scales=self.params['scales'].numpy(),
            rotations=self.params['rotations'].numpy(),
            scale_modifier=self.config['scale_modifier'],
            viewmatrix=self.cameras[camera_idx]['world_to_camera'],
            projmatrix=self.cameras[camera_idx]['full_proj_matrix'],
            tan_fovx=self.cameras[camera_idx]['tan_fovx'],
            tan_fovy=self.cameras[camera_idx]['tan_fovy'],
            image_height=self.cameras[camera_idx]['height'],
            image_width=self.cameras[camera_idx]['width'],
            sh=self.params['shs'].numpy(),  # Pass SH coefficients
            degree=self.config['sh_degree'],
            campos=self.cameras[camera_idx]['camera_center'],
            prefiltered=False,
            antialiasing=True,
            clamped=True
        )
        # Save the rendered view as an image
        rendered_array = wp.to_torch(rendered_image).cpu().numpy()
        # Handle a (3, H, W) layout by transposing to (H, W, 3)
        if rendered_array.shape[0] == 3 and len(rendered_array.shape) == 3:
            rendered_array = np.transpose(rendered_array, (1, 2, 0))
        img8 = (np.clip(rendered_array, 0, 1) * 255).astype(np.uint8)
        imageio.imwrite(checkpoint_dir / "rendered_view.png", img8)


    def debug_log_and_save_images(
        self,
        rendered_image,  # np.float32 H×W×3 (range 0-1)
        target_image,    # np.float32
        depth_image,     # wp.array2d(float) – optional but unused here
        camera_idx: int,
        it: int
    ):

        # ------ quick numeric read-out -----------------------------------
        radii = wp.to_torch(self.intermediate_buffers["radii"]).cpu().numpy()
        alphas = wp.to_torch(self.intermediate_buffers["conic_opacity"]).cpu().numpy()[:, 3]
        offs = wp.to_torch(self.intermediate_buffers["point_offsets"]).cpu().numpy()
        num_dup = int(offs[-1]) if len(offs) else 0
        r_med = np.median(radii[radii > 0]) if (radii > 0).any() else 0

        # Count visible Gaussians
        xy_image = wp.to_torch(self.intermediate_buffers["points_xy_image"]).cpu().numpy()
        W = self.cameras[camera_idx]['width']
        H = self.cameras[camera_idx]['height']
        visible_gaussians = np.sum(
            (xy_image[:, 0] >= 0) & (xy_image[:, 0] < W) &
            (xy_image[:, 1] >= 0) & (xy_image[:, 1] < H) &
            np.isfinite(xy_image).all(axis=1) &
            (radii > 0)  # Only count Gaussians with positive radius
        )

        print(
            f"[it {it:05d}] dup={num_dup:<6} "
            f"r_med={r_med:5.1f} α∈[{alphas.min():.3f},"
            f"{np.median(alphas):.3f},{alphas.max():.3f}] "
            f"visible={visible_gaussians}/{len(xy_image)}"
        )

        # ------ save render / target PNG ---------------------------------
        def save_rgb(arr_f32, stem):
            # Handle a (3, H, W) layout by transposing to (H, W, 3)
            if arr_f32.shape[0] == 3 and len(arr_f32.shape) == 3:
                arr_f32 = np.transpose(arr_f32, (1, 2, 0))
            img8 = (np.clip(arr_f32, 0, 1) * 255).astype(np.uint8)
            imageio.imwrite(self.output_path / f"{stem}_{it:06d}.png", img8)

        save_rgb(rendered_image if isinstance(rendered_image, np.ndarray) else wp.to_torch(rendered_image).cpu().numpy(), "render")
        save_rgb(target_image, "target")

        # ------ make 2-D projection scatter ------------------------------
        xy = wp.to_torch(self.intermediate_buffers["points_xy_image"]).cpu().numpy()
        depth = wp.to_torch(self.intermediate_buffers["depths"]).cpu().numpy()
        H, W = self.config["height"], self.config["width"]

        mask = (
            (xy[:, 0] >= 0) & (xy[:, 0] < W) &
            (xy[:, 1] >= 0) & (xy[:, 1] < H) &
            np.isfinite(xy).all(axis=1) &
            (radii > 0)  # Only include Gaussians with positive radius
        )
        if mask.any():
            plt.figure(figsize=(6, 6))
            plt.scatter(xy[mask, 0], xy[mask, 1],
                        s=4, c=depth[mask], cmap="turbo", alpha=.7)
            plt.gca().invert_yaxis()
            plt.xlim(0, W); plt.ylim(H, 0)
            plt.title(f"Projected Gaussians (iter {it}): {np.sum(mask)}/{len(xy)} visible")
            plt.colorbar(label="depth(z)")
            plt.tight_layout()
            plt.savefig(self.output_path / f"proj_{it:06d}.png", dpi=250)
            plt.close()

            # Depth histogram
            plt.figure(figsize=(5, 3))
            plt.hist(depth[mask], bins=40, color="steelblue")
            plt.xlabel("depth (camera-z)")
            plt.ylabel("count")
            plt.title(f"Depth hist – {mask.sum()} pts")
            plt.tight_layout()
            plt.savefig(self.output_path / f"depth_hist_{it:06d}.png", dpi=250)
            plt.close()


    def train(self):
        """Train the 3D Gaussian Splatting model."""
        num_iterations = self.config['num_iterations']

        # Main training loop
        with tqdm(total=num_iterations) as pbar:
            for iteration in range(num_iterations):
                # Select a random camera and corresponding image
                camera_idx = np.random.randint(0, len(self.cameras))
                image_path = self.image_paths[camera_idx]
                target_image = self.load_image(image_path)

                # Zero gradients
                self.zero_grad()
                # Render the view
                rendered_image, depth_image, self.intermediate_buffers = render_gaussians(
                    background=np.array(self.config['background_color'], dtype=np.float32),
                    means3D=self.params['positions'].numpy(),
                    colors=None,  # Use SH coefficients instead
                    opacity=self.params['opacities'].numpy(),
                    scales=self.params['scales'].numpy(),
                    rotations=self.params['rotations'].numpy(),
                    scale_modifier=self.config['scale_modifier'],
                    viewmatrix=self.cameras[camera_idx]['world_to_camera'],
                    projmatrix=self.cameras[camera_idx]['full_proj_matrix'],
                    tan_fovx=self.cameras[camera_idx]['tan_fovx'],
                    tan_fovy=self.cameras[camera_idx]['tan_fovy'],
                    image_height=self.cameras[camera_idx]['height'],
                    image_width=self.cameras[camera_idx]['width'],
                    sh=self.params['shs'].numpy(),  # Pass SH coefficients
                    degree=self.config['sh_degree'],
                    campos=self.cameras[camera_idx]['camera_center'],
                    prefiltered=False,
                    antialiasing=False,
                    clamped=True
                )

                radii = wp.to_torch(self.intermediate_buffers["radii"]).cpu().numpy()
                np_rendered_image = wp.to_torch(rendered_image).cpu().numpy()
                np_rendered_image = np_rendered_image.transpose(2, 0, 1)

                if iteration % self.config['save_interval'] == 0:
                    self.debug_log_and_save_images(np_rendered_image, target_image, depth_image, camera_idx, iteration)

                # Calculate L1 loss
                l1_val = l1_loss(rendered_image, target_image)

                # # Calculate SSIM, not used
                # ssim_val = ssim(rendered_image, target_image)
                # # Combined loss with weighted SSIM
                # lambda_dssim = self.config['lambda_dssim']
                # # loss = (1 - λ) * L1 + λ * (1 - SSIM)
                # loss = (1.0 - lambda_dssim) * l1_val + lambda_dssim * (1.0 - ssim_val)

                loss = l1_val
                self.losses.append(loss)
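                # [Sketch] For reference: with a mean-L1 loss the per-pixel gradient
                # reduces to the sign of the residual. Assuming 1/N normalization
                # (the project's actual l1_loss / compute_image_gradients live in
                # gs/loss.py and may normalize differently):
                #     loss = np.abs(rendered - target).mean()
                #     dL_dpixel = np.sign(rendered - target) / (rendered - target).size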
                # Compute pixel gradients for image loss (dL/dColor)
                pixel_grad_buffer = compute_image_gradients(
                    rendered_image, target_image, lambda_dssim=0
                )

                # Prepare camera parameters
                camera = self.cameras[camera_idx]
                view_matrix = wp.mat44(camera['world_to_camera'].flatten())
                proj_matrix = wp.mat44(camera['full_proj_matrix'].flatten())
                campos = wp.vec3(camera['camera_center'][0], camera['camera_center'][1], camera['camera_center'][2])

                # Create appropriate buffer dictionaries for the backward pass
                geom_buffer = {
                    'radii': self.intermediate_buffers['radii'],
                    'means2D': self.intermediate_buffers['points_xy_image'],
                    'conic_opacity': self.intermediate_buffers['conic_opacity'],
                    'rgb': self.intermediate_buffers['colors'],
                    'clamped': self.intermediate_buffers['clamped_state']
                }

                binning_buffer = {
                    'point_list': self.intermediate_buffers['point_list']
                }

                img_buffer = {
                    'ranges': self.intermediate_buffers['ranges'],
                    'final_Ts': self.intermediate_buffers['final_Ts'],
                    'n_contrib': self.intermediate_buffers['n_contrib']
                }

                gradients = backward(
                    # Core parameters
                    background=np.array(self.config['background_color'], dtype=np.float32),
                    means3D=self.params['positions'],
                    dL_dpixels=pixel_grad_buffer,

                    # Model parameters (pass directly from self.params)
                    opacity=self.params['opacities'],
                    shs=self.params['shs'],
                    scales=self.params['scales'],
                    rotations=self.params['rotations'],
                    scale_modifier=self.config['scale_modifier'],

                    # Camera parameters
                    viewmatrix=view_matrix,
                    projmatrix=proj_matrix,
                    tan_fovx=camera['tan_fovx'],
                    tan_fovy=camera['tan_fovy'],
                    image_height=camera['height'],
                    image_width=camera['width'],
                    campos=campos,

                    # Forward output buffers
                    radii=self.intermediate_buffers['radii'],
                    means2D=self.intermediate_buffers['points_xy_image'],
                    conic_opacity=self.intermediate_buffers['conic_opacity'],
                    rgb=self.intermediate_buffers['colors'],
                    cov3Ds=self.intermediate_buffers['cov3Ds'],
                    clamped=self.intermediate_buffers['clamped_state'],

                    # Internal state buffers
                    geom_buffer=geom_buffer,
                    binning_buffer=binning_buffer,
                    img_buffer=img_buffer,

                    # Algorithm parameters
                    degree=self.config['sh_degree'],
                    debug=False
                )

                # Copy gradients from the backward result to the optimizer's gradient buffers
                wp.copy(self.grads['positions'], gradients['dL_dmean3D'])
                wp.copy(self.grads['scales'], gradients['dL_dscale'])
                wp.copy(self.grads['rotations'], gradients['dL_drot'])
                wp.copy(self.grads['opacities'], gradients['dL_dopacity'])
                wp.copy(self.grads['shs'], gradients['dL_dshs'])

                # Update parameters
                self.optimizer_step(iteration)

                # Update progress bar
                pbar.update(1)
                pbar.set_description(f"Loss: {loss:.6f}")

                self.densification_and_pruning(iteration)

                # Save checkpoint
                if iteration % self.config['save_interval'] == 0 or iteration == num_iterations - 1:
                    self.save_checkpoint(iteration)

        print("Training complete!")

def main():
    parser = argparse.ArgumentParser(description="Train 3D Gaussian Splatting model with NeRF dataset")
    parser.add_argument("--dataset", type=str, default="./data/nerf_synthetic/lego",
                        help="Path to NeRF dataset directory (default: Lego dataset)")
    parser.add_argument("--output", type=str, default="./output", help="Output directory")

    args = parser.parse_args()

    # Create trainer and start training
    trainer = NeRFGaussianSplattingTrainer(
        dataset_path=args.dataset,
        output_path=args.output,
    )

    trainer.train()


if __name__ == "__main__":
    main()
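# Typical invocation, using the defaults above:
#     python gs/train.py --dataset ./data/nerf_synthetic/lego --output ./output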
gs/train_colmap.py
ADDED
@@ -0,0 +1,1586 @@
import os
import numpy as np
import matplotlib.pyplot as plt
import warp as wp
import imageio
import json
from tqdm import tqdm
from pathlib import Path
import argparse

from forward import render_gaussians
from backward import backward
from optimizer import prune_gaussians, adam_update, clone_gaussians, compact_gaussians, mark_split_candidates, mark_clone_candidates, split_gaussians, reset_opacities, reset_densification_stats
from config import *
from utils.camera_utils import load_camera, load_camera_colmap
from utils.point_cloud_utils import save_ply
from loss import l1_loss, compute_image_gradients
from scheduler import LRScheduler
from utils.math_utils import quaternion_to_rotation_matrix
from plyfile import PlyData, PlyElement
from scipy.spatial import cKDTree

# Initialize Warp
wp.init()

# Kernels for parameter updates
@wp.kernel
def init_gaussian_params(
    # positions: wp.array(dtype=wp.vec3),
    # scales: wp.array(dtype=wp.vec3),  # Keep as input, but it will be pre-filled
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    # shs: wp.array(dtype=wp.vec3),
    num_points: int
    # init_scale: float  # Removed: scales are now pre-computed on the host
):
    i = wp.tid()
    if i >= num_points:
        return

    # Positions are loaded from points3D.ply on the host; random init kept for reference:
    # offset = wp.vec3(
    #     (wp.randf(wp.uint32(i * 3)) * 2.6 - 1.3),
    #     (wp.randf(wp.uint32(i * 3 + 1)) * 2.6 - 1.3),
    #     (wp.randf(wp.uint32(i * 3 + 2)) * 2.6 - 1.3)
    # )
    # positions[i] = offset

    # Scales are pre-calculated on the host (see initialize_parameters):
    # scales[i] = wp.vec3(init_scale, init_scale, init_scale)

    # Initialize rotations to the identity quaternion (w, x, y, z)
    rotations[i] = wp.vec4(1.0, 0.0, 0.0, 0.0)

    # Initialize opacities
    opacities[i] = 0.1

    # SH coefficients are initialized on the host; kernel-side init kept for reference:
    # for j in range(16):  # degree=3, total 16 coefficients
    #     idx = i * 16 + j
    #     if j == 0:
    #         shs[idx] = wp.vec3(-0.007, -0.007, -0.007)
    #     else:
    #         shs[idx] = wp.vec3(0.0, 0.0, 0.0)


@wp.kernel
def zero_gradients(
    pos_grad: wp.array(dtype=wp.vec3),
    scale_grad: wp.array(dtype=wp.vec3),
    rot_grad: wp.array(dtype=wp.vec4),
    opacity_grad: wp.array(dtype=float),
    sh_grad: wp.array(dtype=wp.vec3),
    num_points: int
):
    i = wp.tid()
    if i >= num_points:
        return

    pos_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    scale_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    rot_grad[i] = wp.vec4(0.0, 0.0, 0.0, 0.0)
    opacity_grad[i] = 0.0

    # Zero SH gradients
    for j in range(16):
        idx = i * 16 + j
        sh_grad[idx] = wp.vec3(0.0, 0.0, 0.0)
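# [Sketch] The wp.kernel / wp.launch pattern used throughout this file, as a
# self-contained toy example (hypothetical kernel, not part of the trainer):
#
#     @wp.kernel
#     def scale_points(points: wp.array(dtype=wp.vec3), factor: float, n: int):
#         i = wp.tid()          # one thread per array element
#         if i >= n:            # guard against out-of-range thread indices
#             return
#         points[i] = points[i] * factor
#
#     pts = wp.zeros(8, dtype=wp.vec3)
#     wp.launch(scale_points, dim=8, inputs=[pts, 2.0, 8])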

class NeRFGaussianSplattingTrainer:
    def __init__(self, dataset_path, output_path, config=None):
        """Initialize the 3D Gaussian Splatting trainer using pure Warp for a COLMAP dataset."""
        self.dataset_path = Path(dataset_path)
        self.output_path = Path(output_path)

        # Create output directories
        self.output_path.mkdir(parents=True, exist_ok=True)
        (self.output_path / "proj").mkdir(exist_ok=True)
        (self.output_path / "render").mkdir(exist_ok=True)
        (self.output_path / "target").mkdir(exist_ok=True)
        (self.output_path / "depth_hist").mkdir(exist_ok=True)
        (self.output_path / "point_cloud").mkdir(exist_ok=True)

        # Initialize configuration from GaussianParams
        self.config = GaussianParams.get_config_dict()

        if config is not None:
            self.config.update(config)

        # Set default number of points (will be updated if points3D.ply is loaded)
        self.num_points = self.config.get('num_points', 50000)

        # Initialize learning rate scheduler
        self.lr_scheduler = self.create_lr_scheduler()
        print(f"Learning rate scheduler: {'Enabled' if self.lr_scheduler else 'Disabled'}")

        # For tracking learning rates
        self.learning_rate_history = {
            'positions': [],
            'scales': [],
            'rotations': [],
            'shs': [],
            'opacities': []
        }

        # Load dataset
        print(f"Loading COLMAP dataset from {self.dataset_path}")
        self.cameras, self.image_paths = self.load_colmap("train")
        self.test_cameras, self.test_image_paths = self.load_colmap("test")

        print(f"Loaded {len(self.cameras)} train cameras and {len(self.image_paths)} train images")
        print(f"Loaded {len(self.test_cameras)} test cameras and {len(self.test_image_paths)} test images")

        # Calculate scene extent for densification
        self.scene_extent = self.calculate_scene_extent()
        print(f"Calculated scene extent: {self.scene_extent}")

        # Initialize parameters (this may update self.num_points if points3D.ply is found)
        self.params = self.initialize_parameters()
        print(f"Initialized {self.num_points} Gaussians")

        # Create gradient arrays
        self.grads = self.create_gradient_arrays()

        # Create optimizer state
        self.adam_m = self.create_gradient_arrays()
        self.adam_v = self.create_gradient_arrays()

        # Initialize densification state tracking
        self.init_densification_state()

        # For tracking loss
        self.losses = []

        # Initialize intermediate buffers dictionary
        self.intermediate_buffers = {}

        # Track iteration for opacity reset
        self.opacity_reset_at = -32768

        # Optional sanity check after loading data:
        # self.visualize_camera_points_alignment()

    def create_lr_scheduler(self):
        """Create simple learning rate schedulers for each parameter type."""
        if not self.config['use_lr_scheduler']:
            return None

        config = self.config['lr_scheduler_config']
        final_factor = config['final_lr_factor']

        schedulers = {
            'positions': LRScheduler(config['lr_pos'], final_factor),
            'scales': LRScheduler(config['lr_scale'], final_factor),
            'rotations': LRScheduler(config['lr_rot'], final_factor),
            'shs': LRScheduler(config['lr_sh'], final_factor),
            'opacities': LRScheduler(config['lr_opac'], final_factor)
        }

        return schedulers
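    # [Sketch] scheduler.py is not shown in this diff hunk; assuming LRScheduler
    # interpolates exponentially from the base rate down to base * final_lr_factor,
    # an equivalent implementation would be roughly:
    #
    #     class LRScheduler:
    #         def __init__(self, base_lr, final_factor):
    #             self.base_lr = base_lr
    #             self.final_factor = final_factor
    #
    #         def get_lr(self, iteration, num_iterations):
    #             # base_lr at step 0, base_lr * final_factor at the last step
    #             t = min(max(iteration / max(num_iterations, 1), 0.0), 1.0)
    #             return self.base_lr * (self.final_factor ** t)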

    def initialize_parameters(self):
        """Initialize Gaussian parameters using points3D.ply if available."""
        # Try to load points from points3D.ply
        points3d_path = self.dataset_path / "sparse/0/points3D.ply"
        initial_positions_np = None
        initial_colors_np = None

        if points3d_path.exists():
            try:
                plydata = PlyData.read(str(points3d_path))
                vertices = plydata['vertex']
                if 'x' in vertices and 'y' in vertices and 'z' in vertices:
                    positions_data = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
                    initial_positions_np = positions_data.astype(np.float32)

                    if 'red' in vertices and 'green' in vertices and 'blue' in vertices:
                        colors_data = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T
                        initial_colors_np = (colors_data / 255.0).astype(np.float32)
                    else:
                        print("Warning: Color attributes (red, green, blue) not found in points3D.ply.")

                    # Update num_points based on loaded points
                    self.num_points = len(initial_positions_np)
                    print(f"Loaded {self.num_points} points from points3D.ply")
            except Exception as e:
                print(f"Warning: Could not load points3D.ply: {e}")
                initial_positions_np = None
                initial_colors_np = None

        if initial_positions_np is None:
            # Fallback if points3D.ply is not loaded or doesn't have positions
            print(f"Warning: Initial positions not loaded. Initializing {self.num_points} positions to zeros.")
            # self.num_points is already set from config, or updated if the PLY was partially read
            initial_positions_np = np.zeros((self.num_points, 3), dtype=np.float32)

        # Initialize scales from nearest-neighbour distances
        scales_np = np.zeros((self.num_points, 3), dtype=np.float32)
        if initial_positions_np is not None and self.num_points > 3:  # need enough points for the kNN query
            try:
                print("Calculating initial scales using cKDTree...")
                kdtree = cKDTree(initial_positions_np)
                k = 2  # query returns the point itself (index 0) plus its nearest neighbour (index 1)
                distances, _ = kdtree.query(initial_positions_np, k=k, workers=-1)  # Use all available cores

                # distances[:, 0] is the distance to self (0.0), so we use distances[:, 1:]
                radius_np = np.mean(distances[:, 1:], axis=1)
                scales_np = np.tile(radius_np[:, np.newaxis], (1, 3))
                print(f"Initial scales calculated. Min radius: {radius_np.min()}, Max radius: {radius_np.max()}, Mean radius: {radius_np.mean()}")
            except Exception as e:
                print(f"Error during cKDTree scale initialization: {e}. Falling back to default scale.")
                default_scale_val = self.config['initial_scale']
                scales_np = np.full((self.num_points, 3), default_scale_val, dtype=np.float32)
        else:
            default_scale_val = self.config['initial_scale']
            print(f"Not enough points for cKDTree or initial_positions_np is None. Using default scale: {default_scale_val}")
            scales_np = np.full((self.num_points, 3), default_scale_val, dtype=np.float32)

        # Initialize arrays with proper size
        positions = wp.array(initial_positions_np, dtype=wp.vec3, device=DEVICE)
        scales = wp.array(scales_np, dtype=wp.vec3, device=DEVICE)  # Calculated or default scales
        rotations = wp.zeros(self.num_points, dtype=wp.vec4, device=DEVICE)
        opacities = wp.zeros(self.num_points, dtype=float, device=DEVICE)

        C0 = 0.28209479177387814  # Constant for Y₀₀
        shs_np_data = np.zeros((self.num_points * 16, 3), dtype=np.float32)
        if initial_colors_np is not None and initial_colors_np.shape[0] == self.num_points:
            shs_np_data[::16] = (initial_colors_np - 0.5) / C0
        else:
            # Default to gray if colors are not available or mismatch
            gray_color_sh = (np.array([0.5, 0.5, 0.5]) - 0.5) / C0
            shs_np_data[::16] = np.tile(gray_color_sh, (self.num_points, 1))
        shs = wp.array(shs_np_data, dtype=wp.vec3, device=DEVICE)

        # Launch kernel to initialize rotations and opacities;
        # scales and shs are already initialized on the host.
        wp.launch(
            init_gaussian_params,
            dim=self.num_points,
            inputs=[rotations, opacities, self.num_points]
        )

        return {
            'positions': positions,
            'scales': scales,
            'rotations': rotations,
            'opacities': opacities,
            'shs': shs
        }
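    # [Sketch] The DC spherical-harmonics initialization above inverts the band-0
    # evaluation color = 0.5 + C0 * sh0, so the stored coefficients reproduce the
    # input point colors exactly; a quick round-trip check:
    #
    #     C0 = 0.28209479177387814          # Y_0^0 = 1 / (2 * sqrt(pi))
    #     rgb = np.array([0.2, 0.6, 0.9], dtype=np.float32)
    #     sh0 = (rgb - 0.5) / C0            # as in initialize_parameters
    #     assert np.allclose(0.5 + C0 * sh0, rgb)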

    def create_gradient_arrays(self):
        """Create arrays for gradients or optimizer state."""
        positions = wp.zeros(self.num_points, dtype=wp.vec3)
        scales = wp.zeros(self.num_points, dtype=wp.vec3)
        rotations = wp.zeros(self.num_points, dtype=wp.vec4)
        opacities = wp.zeros(self.num_points, dtype=float)
        shs = wp.zeros(self.num_points * 16, dtype=wp.vec3)

        # Return a dictionary of arrays
        return {
            'positions': positions,
            'scales': scales,
            'rotations': rotations,
            'opacities': opacities,
            'shs': shs
        }

    def calculate_scene_extent(self):
        """Calculate the extent of the scene based on camera positions."""
        if not self.cameras:
            return 1.0  # Default fallback

        # Extract camera positions
        camera_positions = []
        for camera in self.cameras:
            camera_positions.append(camera['camera_center'])

        camera_positions = np.array(camera_positions)

        # Calculate the centroid of all camera positions
        scene_center = np.mean(camera_positions, axis=0)

        # Calculate the maximum distance from any camera to the scene center
        max_distance_to_center = 0.0
        for pos in camera_positions:
            distance = np.linalg.norm(pos - scene_center)
            max_distance_to_center = max(max_distance_to_center, distance)

        # The scene extent is the radius of the camera bounding sphere,
        # optionally scaled by a config factor and clamped to at least 1.0
        extent = max_distance_to_center * self.config.get('camera_extent_factor', 1.0)
        return max(extent, 1.0)

    def init_densification_state(self):
        """Initialize state tracking for densification."""
        self.xyz_gradient_accum = wp.zeros(self.num_points, dtype=float, device=DEVICE)
        self.denom = wp.zeros(self.num_points, dtype=float, device=DEVICE)
        self.max_radii2D = wp.zeros(self.num_points, dtype=float, device=DEVICE)
+
def load_colmap(self, datasplit="train", llffhold=8):
|
| 326 |
+
colmap_dir = self.dataset_path / "sparse/0"
|
| 327 |
+
images_dir = self.dataset_path / "images"
|
| 328 |
+
intrinsics = {}
|
| 329 |
+
|
| 330 |
+
with open(colmap_dir / "cameras.txt") as f:
|
| 331 |
+
for line in f:
|
| 332 |
+
if line.startswith("#"): continue
|
| 333 |
+
vals = line.strip().split()
|
| 334 |
+
if len(vals) < 4: continue
|
| 335 |
+
|
| 336 |
+
cam_id, model, w, h = int(vals[0]), vals[1], int(vals[2]), int(vals[3])
|
| 337 |
+
|
| 338 |
+
if model == "PINHOLE":
|
| 339 |
+
# PINHOLE has 4 parameters: fx, fy, cx, cy
|
| 340 |
+
if len(vals) >= 8: # 4 basic + 4 params
|
| 341 |
+
fx, fy, cx, cy = float(vals[4]), float(vals[5]), float(vals[6]), float(vals[7])
|
| 342 |
+
else:
|
| 343 |
+
continue
|
| 344 |
+
elif model == "SIMPLE_PINHOLE":
|
| 345 |
+
# SIMPLE_PINHOLE has 3 parameters: f, cx, cy
|
| 346 |
+
if len(vals) >= 7: # 4 basic + 3 params
|
| 347 |
+
f, cx, cy = float(vals[4]), float(vals[5]), float(vals[6])
|
| 348 |
+
fx = fy = f # Same focal length for both axes
|
| 349 |
+
else:
|
| 350 |
+
continue
|
| 351 |
+
else:
|
| 352 |
+
print(f"Unsupported camera model: {model}")
|
| 353 |
+
continue
|
| 354 |
+
|
| 355 |
+
intrinsics[cam_id] = (fx, fy, w, h, cx, cy)
|
| 356 |
+
|
| 357 |
+
extrinsics = []
|
| 358 |
+
with open(colmap_dir / "images.txt") as f:
|
| 359 |
+
for line in f:
|
| 360 |
+
if line.startswith("#"): continue
|
| 361 |
+
parts = line.strip().split()
|
| 362 |
+
if len(parts) < 10: continue
|
| 363 |
+
|
| 364 |
+
# COLMAP images.txt format: IMAGE_ID QW QX QY QZ TX TY TZ CAMERA_ID NAME
|
| 365 |
+
img_id, qw, qx, qy, qz, tx, ty, tz, cam_id, img_name = parts[:10]
|
| 366 |
+
cam_id = int(cam_id)
|
| 367 |
+
|
| 368 |
+
if cam_id not in intrinsics:
|
| 369 |
+
print(f"Warning: Camera ID {cam_id} not found in intrinsics")
|
| 370 |
+
continue
|
| 371 |
+
|
| 372 |
+
fx, fy, w, h, cx, cy = intrinsics[cam_id]
|
| 373 |
+
|
| 374 |
+
# Fix quaternion order and normalize
|
| 375 |
+
q = np.array([float(qw), float(qx), float(qy), float(qz)])
|
| 376 |
+
q = q / np.linalg.norm(q) # Normalize if needed
|
| 377 |
+
|
| 378 |
+
t = np.array([float(tx), float(ty), float(tz)])
|
| 379 |
+
R = quaternion_to_rotation_matrix(q)
|
| 380 |
+
|
| 381 |
+
# # Convert from COLMAP's world-to-camera to camera-to-world
|
| 382 |
+
# c2w = np.eye(4, dtype=np.float32)
|
| 383 |
+
# c2w[:3, :3] = R.T
|
| 384 |
+
# c2w[:3, 3] = -R.T @ t
|
| 385 |
+
|
| 386 |
+
cam_info = {
|
| 387 |
+
"camera_id": int(img_id),
|
| 388 |
+
#"camera_to_world": c2w,
|
| 389 |
+
"width": w,
|
| 390 |
+
"height": h,
|
| 391 |
+
"fx": fx,
|
| 392 |
+
"fy": fy,
|
| 393 |
+
"cx": cx,
|
| 394 |
+
"cy": cy,
|
| 395 |
+
"R": R,
|
| 396 |
+
"T": t
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
camera = load_camera_colmap(cam_info)
|
| 400 |
+
if camera:
|
| 401 |
+
extrinsics.append((camera, str(images_dir / img_name)))
|
| 402 |
+
|
| 403 |
+
# Split data based on datasplit parameter
|
| 404 |
+
if datasplit == "train":
|
| 405 |
+
selected = [c for i, c in enumerate(extrinsics) if i % llffhold != 0]
|
| 406 |
+
elif datasplit == "test":
|
| 407 |
+
selected = [c for i, c in enumerate(extrinsics) if i % llffhold == 0]
|
| 408 |
+
else:
|
| 409 |
+
selected = extrinsics
|
| 410 |
+
|
| 411 |
+
if selected:
|
| 412 |
+
cameras, image_paths = zip(*selected)
|
| 413 |
+
width = cameras[0]['width']
|
| 414 |
+
height = cameras[0]['height']
|
| 415 |
+
fx = cameras[0]['fx']
|
| 416 |
+
fy = cameras[0]['fy']
|
| 417 |
+
|
| 418 |
+
# Calculate field of view
|
| 419 |
+
camera_angle_x = 2 * np.arctan(0.5 * width / fx)
|
| 420 |
+
camera_angle_y = 2 * np.arctan(0.5 * height / fy)
|
| 421 |
+
|
| 422 |
+
self.config['width'] = width
|
| 423 |
+
self.config['height'] = height
|
| 424 |
+
self.config['fx'] = fx
|
| 425 |
+
self.config['fy'] = fy
|
| 426 |
+
self.config['focal'] = fx # Use fx as primary focal length
|
| 427 |
+
|
| 428 |
+
return list(cameras), list(image_paths)
|
| 429 |
+
|
| 430 |
+
return [], []
|
| 431 |
+
|
| 432 |
+
|
| 433 |
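    # [Sketch] The field-of-view values above follow the standard pinhole relations;
    # in isolation (the intrinsics here are made-up example values):
    #
    #     fx, fy, width, height = 1111.0, 1111.0, 800, 600
    #     camera_angle_x = 2.0 * np.arctan(0.5 * width / fx)
    #     camera_angle_y = 2.0 * np.arctan(0.5 * height / fy)
    #     # the rasterizer consumes the half-angle tangents directly:
    #     tan_fovx = np.tan(0.5 * camera_angle_x)   # == 0.5 * width / fx
    #     tan_fovy = np.tan(0.5 * camera_angle_y)   # == 0.5 * height / fy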
+
def load_image(self, path):
|
| 434 |
+
"""Load an image as a numpy array."""
|
| 435 |
+
if os.path.exists(path):
|
| 436 |
+
img = imageio.imread(path)
|
| 437 |
+
# Convert to float and normalize to [0, 1]
|
| 438 |
+
img_np = img.astype(np.float32) / 255.0
|
| 439 |
+
# Ensure image is RGB (discard alpha channel if present)
|
| 440 |
+
if img_np.shape[2] == 4:
|
| 441 |
+
img_np = img_np[:, :, :3] # Keep only R, G, B channels
|
| 442 |
+
return img_np
|
| 443 |
+
else:
|
| 444 |
+
raise FileNotFoundError(f"Image not found: {path}")
|
| 445 |
+
|
| 446 |
+
def zero_grad(self):
|
| 447 |
+
"""Zero out all gradients."""
|
| 448 |
+
wp.launch(
|
| 449 |
+
zero_gradients,
|
| 450 |
+
dim=self.num_points,
|
| 451 |
+
inputs=[
|
| 452 |
+
self.grads['positions'],
|
| 453 |
+
self.grads['scales'],
|
| 454 |
+
self.grads['rotations'],
|
| 455 |
+
self.grads['opacities'],
|
| 456 |
+
self.grads['shs'],
|
| 457 |
+
self.num_points
|
| 458 |
+
]
|
| 459 |
+
)
|
| 460 |
+
|

    def densification_and_pruning(self, iteration):
        """Perform densification (clone/split) and pruning of Gaussians."""

        # Densification schedule
        densify_from_iter = self.config.get('densify_from_iter', 500)
        densify_until_iter = self.config.get('densify_until_iter', 15000)
        densification_interval = self.config.get('densification_interval', 100)
        opacity_reset_interval = self.config.get('opacity_reset_interval', 3000)

        # Only densify within the configured iteration range, at the given interval
        if iteration > densify_from_iter and iteration < densify_until_iter and iteration % densification_interval == 0:
            print(f"Iteration {iteration}: Performing densification and pruning")

            # Simplified: use position-gradient norms as a proxy for view-space gradients
            pos_grads = self.grads['positions']
            avg_grads = wp.zeros(self.num_points, dtype=float, device=DEVICE)

            @wp.kernel
            def compute_grad_norms(pos_grad: wp.array(dtype=wp.vec3),
                                   grad_norms: wp.array(dtype=float),
                                   num_points: int):
                i = wp.tid()
                if i >= num_points:
                    return
                grad_norms[i] = wp.length(pos_grad[i])

            wp.launch(compute_grad_norms, dim=self.num_points,
                      inputs=[pos_grads, avg_grads, self.num_points])

            # Configuration
            grad_threshold = self.config.get('densify_grad_threshold', 0.0002)
            percent_dense = self.config.get('percent_dense', 0.01)

            # --- Step 1: Clone small Gaussians with high gradients ---
            clone_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)
            wp.launch(
                mark_clone_candidates,
                dim=self.num_points,
                inputs=[
                    avg_grads,
                    self.params['scales'],
                    grad_threshold,
                    self.scene_extent,
                    percent_dense,
                    clone_mask,
                    self.num_points
                ]
            )

            # Perform cloning
            clone_prefix_sum = wp.zeros_like(clone_mask)
            wp.utils.array_scan(clone_mask, clone_prefix_sum, inclusive=False)
            total_to_clone = int(clone_prefix_sum.numpy()[-1])

            if total_to_clone > 0:
                print(f"[Clone] Cloning {total_to_clone} small Gaussians")
                N = self.num_points
                new_N = N + total_to_clone

                # Allocate output arrays
                out_params = {
                    'positions': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'scales': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'rotations': wp.zeros(new_N, dtype=wp.vec4, device=DEVICE),
                    'opacities': wp.zeros(new_N, dtype=float, device=DEVICE),
                    'shs': wp.zeros(new_N * 16, dtype=wp.vec3, device=DEVICE)
                }

                # Clone Gaussians
                wp.launch(
                    clone_gaussians,
                    dim=N,
                    inputs=[
                        clone_mask,
                        clone_prefix_sum,
                        self.params['positions'],
                        self.params['scales'],
                        self.params['rotations'],
                        self.params['opacities'],
                        self.params['shs'],
                        0.01,  # noise_scale
                        N,     # offset
                        out_params['positions'],
                        out_params['scales'],
                        out_params['rotations'],
                        out_params['opacities'],
                        out_params['shs']
                    ]
                )

                # Update parameters and state
                self.params = out_params
                self.num_points = new_N
                self.grads = self.create_gradient_arrays()
                self.adam_m = self.create_gradient_arrays()
                self.adam_v = self.create_gradient_arrays()

            # --- Step 2: Split large Gaussians with high gradients ---
            split_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)
            wp.launch(
                mark_split_candidates,
                dim=self.num_points,
                inputs=[
                    avg_grads,
                    self.params['scales'],
                    grad_threshold,
                    self.scene_extent,
                    percent_dense,
                    split_mask,
                    self.num_points
                ]
            )

            # Perform splitting
            split_prefix_sum = wp.zeros_like(split_mask)
            wp.utils.array_scan(split_mask, split_prefix_sum, inclusive=False)
            total_to_split = int(split_prefix_sum.numpy()[-1])

            if total_to_split > 0:
                print(f"[Split] Splitting {total_to_split} large Gaussians")
                N = self.num_points
                N_split = 2  # Split each Gaussian into 2
                new_N = N + total_to_split * N_split

                # Allocate output arrays
                out_params = {
                    'positions': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'scales': wp.zeros(new_N, dtype=wp.vec3, device=DEVICE),
                    'rotations': wp.zeros(new_N, dtype=wp.vec4, device=DEVICE),
                    'opacities': wp.zeros(new_N, dtype=float, device=DEVICE),
                    'shs': wp.zeros(new_N * 16, dtype=wp.vec3, device=DEVICE)
                }

                # Split Gaussians
                wp.launch(
                    split_gaussians,
                    dim=N,
                    inputs=[
                        split_mask,
                        split_prefix_sum,
                        self.params['positions'],
                        self.params['scales'],
                        self.params['rotations'],
                        self.params['opacities'],
                        self.params['shs'],
                        N_split,  # Number of splits per Gaussian
                        0.8,      # scale_factor
                        N,        # offset
                        out_params['positions'],
                        out_params['scales'],
                        out_params['rotations'],
                        out_params['opacities'],
                        out_params['shs']
                    ]
                )

                # Update parameters and state
                self.params = out_params
                self.num_points = new_N
                self.grads = self.create_gradient_arrays()
                self.adam_m = self.create_gradient_arrays()
                self.adam_v = self.create_gradient_arrays()

                # Remove original split Gaussians
                prune_filter = wp.zeros(self.num_points, dtype=int, device=DEVICE)

                @wp.kernel
                def mark_split_originals_for_removal(
                    split_mask: wp.array(dtype=int),
                    prune_filter: wp.array(dtype=int),
                    offset: int,
                    num_points: int
                ):
                    i = wp.tid()
                    if i >= num_points:
                        return
                    if i < offset and split_mask[i] == 1:
                        prune_filter[i] = 1  # Mark for removal
                    else:
                        prune_filter[i] = 0  # Keep

                wp.launch(mark_split_originals_for_removal, dim=self.num_points,
                          inputs=[split_mask, prune_filter, N, self.num_points])

                # Invert mask to get valid mask
                valid_mask = wp.zeros_like(prune_filter)

                @wp.kernel
                def invert_mask(prune: wp.array(dtype=int), valid: wp.array(dtype=int), n: int):
                    i = wp.tid()
                    if i >= n:
                        return
                    valid[i] = 1 - prune[i]

                wp.launch(invert_mask, dim=self.num_points,
                          inputs=[prune_filter, valid_mask, self.num_points])

                # Count valid points and compact
                prefix_sum = wp.zeros_like(valid_mask)
                wp.utils.array_scan(valid_mask, prefix_sum, inclusive=False)
                valid_count = int(prefix_sum.numpy()[-1])

                if valid_count < self.num_points:
                    print(f"[Split] Removing {self.num_points - valid_count} original split Gaussians")

                    # Allocate compacted output
                    compact_params = {
                        'positions': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                        'scales': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                        'rotations': wp.zeros(valid_count, dtype=wp.vec4, device=DEVICE),
                        'opacities': wp.zeros(valid_count, dtype=float, device=DEVICE),
                        'shs': wp.zeros(valid_count * 16, dtype=wp.vec3, device=DEVICE)
                    }

                    wp.launch(
                        compact_gaussians,
                        dim=self.num_points,
                        inputs=[
                            valid_mask,
                            prefix_sum,
                            self.params['positions'],
                            self.params['scales'],
                            self.params['rotations'],
                            self.params['opacities'],
                            self.params['shs'],
                            compact_params['positions'],
                            compact_params['scales'],
                            compact_params['rotations'],
                            compact_params['opacities'],
                            compact_params['shs']
                        ]
                    )

                    # Update parameters and state
                    self.params = compact_params
                    self.num_points = valid_count
                    self.grads = self.create_gradient_arrays()
                    self.adam_m = self.create_gradient_arrays()
                    self.adam_v = self.create_gradient_arrays()

            # --- Step 3: Enhanced pruning ---
            print(f"[Prune] Performing enhanced pruning")

            valid_mask = wp.zeros(self.num_points, dtype=int, device=DEVICE)

            # Use opacity-based pruning for now
            wp.launch(
                prune_gaussians,
                dim=self.num_points,
                inputs=[
                    self.params['opacities'],
                    self.config.get('cull_opacity_threshold', 0.005),
                    valid_mask,
                    self.num_points
                ]
            )

            # Count valid points
            prefix_sum = wp.zeros_like(valid_mask)
            wp.utils.array_scan(valid_mask, prefix_sum, inclusive=False)
            valid_count = int(prefix_sum.numpy()[-1])

            # Check pruning constraints
            min_valid_points = self.config.get('min_valid_points', 1000)
            max_valid_points = self.config.get('max_valid_points', 1000000)
            max_prune_ratio = self.config.get('max_allowed_prune_ratio', 0.5)

            prune_count = self.num_points - valid_count
            prune_ratio = prune_count / self.num_points if self.num_points > 0 else 0

            if (valid_count >= min_valid_points and
                valid_count <= max_valid_points and
                prune_ratio <= max_prune_ratio and
                valid_count < self.num_points):

                print(f"[Prune] Compacting from {self.num_points} → {valid_count} points")

                # Allocate compacted output
                out_params = {
                    'positions': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                    'scales': wp.zeros(valid_count, dtype=wp.vec3, device=DEVICE),
                    'rotations': wp.zeros(valid_count, dtype=wp.vec4, device=DEVICE),
                    'opacities': wp.zeros(valid_count, dtype=float, device=DEVICE),
                    'shs': wp.zeros(valid_count * 16, dtype=wp.vec3, device=DEVICE)
                }

                wp.launch(
                    compact_gaussians,
                    dim=self.num_points,
                    inputs=[
                        valid_mask,
                        prefix_sum,
                        self.params['positions'],
                        self.params['scales'],
                        self.params['rotations'],
                        self.params['opacities'],
                        self.params['shs'],
                        out_params['positions'],
                        out_params['scales'],
                        out_params['rotations'],
                        out_params['opacities'],
                        out_params['shs']
                    ]
                )

                # Update parameters and state
                self.params = out_params
                self.num_points = valid_count
                self.grads = self.create_gradient_arrays()
                self.adam_m = self.create_gradient_arrays()
                self.adam_v = self.create_gradient_arrays()
            else:
                print(f"[Prune] Skipping pruning: valid={valid_count}, ratio={prune_ratio:.3f}")

        # Opacity reset - logic follows the reference implementation
        background_is_white = all(c == 1.0 for c in self.config['background_color'])
        should_reset_opacity = (
            iteration % opacity_reset_interval == 0 or
            (background_is_white and iteration == densify_from_iter)
        )

        if should_reset_opacity:
            print(f"Iteration {iteration}: Resetting opacities")
            wp.launch(
                reset_opacities,
                dim=self.num_points,
                inputs=[
                    self.params['opacities'],
                    0.01,  # max_opacity
                    self.num_points
                ]
            )
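    # [Sketch] The clone/split/prune bookkeeping above repeatedly uses the same
    # mask -> exclusive prefix sum -> scatter pattern; a NumPy mirror of the
    # compaction step:
    #
    #     values = np.array([10, 11, 12, 13, 14])
    #     keep = np.array([1, 0, 1, 1, 0])                      # valid_mask
    #     prefix = np.concatenate(([0], np.cumsum(keep)[:-1]))  # exclusive scan
    #     out = np.empty(int(keep.sum()), dtype=values.dtype)
    #     for i in range(len(values)):                          # per-thread scatter
    #         if keep[i]:
    #             out[prefix[i]] = values[i]
    #     # out == [10, 12, 13]
    #
    # Note: for an exclusive scan the total count is prefix[-1] + mask[-1], not
    # prefix[-1] alone, so reading only the last prefix entry (as the valid_count
    # computations above do) undercounts by one whenever the final element is kept.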

    def optimizer_step(self, iteration):
        """Perform an Adam optimization step."""

        # Get learning rates from the scheduler, or fall back to static config values
        if self.lr_scheduler:
            lr_pos = self.lr_scheduler['positions'].get_lr(iteration, self.config['num_iterations'])
            lr_scale = self.lr_scheduler['scales'].get_lr(iteration, self.config['num_iterations'])
            lr_rot = self.lr_scheduler['rotations'].get_lr(iteration, self.config['num_iterations'])
            lr_sh = self.lr_scheduler['shs'].get_lr(iteration, self.config['num_iterations'])
            lr_opac = self.lr_scheduler['opacities'].get_lr(iteration, self.config['num_iterations'])

            # Track learning rate history
            self.learning_rate_history['positions'].append(lr_pos)
            self.learning_rate_history['scales'].append(lr_scale)
            self.learning_rate_history['rotations'].append(lr_rot)
            self.learning_rate_history['shs'].append(lr_sh)
            self.learning_rate_history['opacities'].append(lr_opac)

            # Log learning rates occasionally
            if iteration % 1000 == 0:
                print(f"Iteration {iteration} learning rates:")
                print(f"  positions: {lr_pos:.6f}")
                print(f"  scales: {lr_scale:.6f}")
                print(f"  rotations: {lr_rot:.6f}")
                print(f"  shs: {lr_sh:.6f}")
                print(f"  opacities: {lr_opac:.6f}")
        else:
            # Use static learning rates from config
            lr_pos = self.config['lr_pos']
            lr_scale = self.config['lr_scale']
            lr_rot = self.config['lr_rot']
            lr_sh = self.config['lr_sh']
            lr_opac = self.config['lr_opac']

        wp.launch(
            adam_update,
            dim=self.num_points,
            inputs=[
                # Parameters
                self.params['positions'],
                self.params['scales'],
                self.params['rotations'],
                self.params['opacities'],
                self.params['shs'],

                # Gradients
                self.grads['positions'],
                self.grads['scales'],
                self.grads['rotations'],
                self.grads['opacities'],
                self.grads['shs'],

                # First moments (m)
                self.adam_m['positions'],
                self.adam_m['scales'],
                self.adam_m['rotations'],
                self.adam_m['opacities'],
                self.adam_m['shs'],

                # Second moments (v)
                self.adam_v['positions'],
                self.adam_v['scales'],
                self.adam_v['rotations'],
                self.adam_v['opacities'],
                self.adam_v['shs'],

                # Optimizer parameters with per-group learning rates
                self.num_points,
                lr_pos,    # positions
                lr_scale,  # scales
                lr_rot,    # rotations
                lr_sh,     # SH coefficients
                lr_opac,   # opacities
                self.config['adam_beta1'],
                self.config['adam_beta2'],
                self.config['adam_epsilon'],
                iteration
            ]
        )
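    # [Sketch] The adam_update kernel applies the standard Adam rule per parameter
    # group; in scalar form (the textbook update, not a transcription of
    # gs/optimizer.py):
    #
    #     m = beta1 * m + (1.0 - beta1) * grad
    #     v = beta2 * v + (1.0 - beta2) * grad * grad
    #     m_hat = m / (1.0 - beta1 ** t)                 # bias correction, t >= 1
    #     v_hat = v / (1.0 - beta2 ** t)
    #     param -= lr * m_hat / (np.sqrt(v_hat) + eps)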

    def save_checkpoint(self, iteration):
        """Save the current point cloud and training state."""
        checkpoint_dir = self.output_path / "point_cloud" / f"iteration_{iteration}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Save point cloud as PLY
        ply_path = checkpoint_dir / "point_cloud.ply"
        save_ply(self.params, ply_path, self.num_points)

        # Save loss history
        loss_path = self.output_path / "loss.txt"
        with open(loss_path, 'w') as f:
            for loss in self.losses:
                f.write(f"{loss}\n")

        # Save loss plot
        plt.figure(figsize=(10, 5))
        plt.plot(self.losses)
        plt.title('Training Loss')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.savefig(self.output_path / "loss_plot.png")
        plt.close()

        # Save a rendered view
        camera_idx = 0  # Front view
        rendered_image, _, _ = render_gaussians(
            background=np.array(self.config['background_color'], dtype=np.float32),
            means3D=self.params['positions'].numpy(),
            colors=None,  # Use SH coefficients instead
            opacity=self.params['opacities'].numpy(),
            scales=self.params['scales'].numpy(),
            rotations=self.params['rotations'].numpy(),
            scale_modifier=self.config['scale_modifier'],
            viewmatrix=self.cameras[camera_idx]['world_to_camera'],
            projmatrix=self.cameras[camera_idx]['full_proj_matrix'],
            tan_fovx=self.cameras[camera_idx]['tan_fovx'],
            tan_fovy=self.cameras[camera_idx]['tan_fovy'],
            image_height=self.cameras[camera_idx]['height'],
            image_width=self.cameras[camera_idx]['width'],
            sh=self.params['shs'].numpy(),  # Pass SH coefficients
            degree=self.config['sh_degree'],
            campos=self.cameras[camera_idx]['camera_center'],
            prefiltered=False,
            antialiasing=True,
            clamped=True
        )
        # Save rendered view as image
        rendered_array = wp.to_torch(rendered_image).cpu().numpy()
        # Handle case where rendered_array has shape (3, H, W) - transpose to (H, W, 3)
        if rendered_array.shape[0] == 3 and len(rendered_array.shape) == 3:
            rendered_array = np.transpose(rendered_array, (1, 2, 0))
        img8 = (np.clip(rendered_array, 0, 1) * 255).astype(np.uint8)
        imageio.imwrite(checkpoint_dir / "rendered_view.png", img8)
| 932 |
+
    def debug_log_and_save_images(
        self,
        rendered_image,   # np.float32 H×W×3 (range 0-1)
        target_image,     # np.float32
        depth_image,      # wp.array2d(float) – optional but unused here
        camera_idx: int,
        it: int
    ):
        # ------ quick numeric read-out -----------------------------------
        radii = wp.to_torch(self.intermediate_buffers["radii"]).cpu().numpy()
        alphas = wp.to_torch(self.intermediate_buffers["conic_opacity"]).cpu().numpy()[:, 3]
        offs = wp.to_torch(self.intermediate_buffers["point_offsets"]).cpu().numpy()
        num_dup = int(offs[-1]) if len(offs) else 0
        r_med = np.median(radii[radii > 0]) if (radii > 0).any() else 0

        # Count visible Gaussians
        xy_image = wp.to_torch(self.intermediate_buffers["points_xy_image"]).cpu().numpy()
        W = self.cameras[camera_idx]['width']
        H = self.cameras[camera_idx]['height']
        visible_gaussians = np.sum(
            (xy_image[:, 0] >= 0) & (xy_image[:, 0] < W) &
            (xy_image[:, 1] >= 0) & (xy_image[:, 1] < H) &
            np.isfinite(xy_image).all(axis=1) &
            (radii > 0)  # only count Gaussians with positive radius
        )

        print(
            f"[it {it:05d}] cam={camera_idx:02d} dup={num_dup:<6} "
            f"r_med={r_med:5.1f} α∈[{alphas.min():.3f},"
            f"{np.median(alphas):.3f},{alphas.max():.3f}] "
            f"visible={visible_gaussians}/{len(xy_image)}"
        )

        # ------ save render / target PNG ---------------------------------
        def save_rgb(arr_f32, stem):
            # If the array is (3, H, W), transpose to (H, W, 3)
            if arr_f32.shape[0] == 3 and len(arr_f32.shape) == 3:
                arr_f32 = np.transpose(arr_f32, (1, 2, 0))
            img8 = (np.clip(arr_f32, 0, 1) * 255).astype(np.uint8)
            # Include camera index in the filename
            imageio.imwrite(self.output_path / f"{stem}" / f"{stem}_{it:06d}_cam{camera_idx:02d}.png", img8)

        save_rgb(rendered_image if isinstance(rendered_image, np.ndarray) else wp.to_torch(rendered_image).cpu().numpy(), "render")
        save_rgb(target_image, "target")

        # ------ make 2-D projection scatter ------------------------------
        xy = wp.to_torch(self.intermediate_buffers["points_xy_image"]).cpu().numpy()
        depth = wp.to_torch(self.intermediate_buffers["depths"]).cpu().numpy()
        H, W = self.config["height"], self.config["width"]

        mask = (
            (xy[:, 0] >= 0) & (xy[:, 0] < W) &
            (xy[:, 1] >= 0) & (xy[:, 1] < H) &
            np.isfinite(xy).all(axis=1) &
            (radii > 0)  # only include Gaussians with positive radius
        )
        if mask.any():
            plt.figure(figsize=(6, 6))
            plt.scatter(xy[mask, 0], xy[mask, 1],
                        s=4, c=depth[mask], cmap="turbo", alpha=.7)
            plt.gca().invert_yaxis()
            plt.xlim(0, W); plt.ylim(H, 0)
            plt.title(f"Projected Gaussians (cam {camera_idx}, iter {it}): {np.sum(mask)}/{len(xy)} visible")
            plt.colorbar(label="depth(z)")
            plt.tight_layout()
            # Include camera index in the filename
            plt.savefig(self.output_path / 'proj' / f"proj_{it:06d}_cam{camera_idx:02d}.png", dpi=250)
            plt.close()

            # Depth histogram
            plt.figure(figsize=(5, 3))
            plt.hist(depth[mask], bins=40, color="steelblue")
            plt.xlabel("depth (camera-z)")
            plt.ylabel("count")
            plt.title(f"Depth hist – cam {camera_idx}, {mask.sum()} pts")
            plt.tight_layout()
            # Include camera index in the filename (otherwise cameras overwrite each other)
            plt.savefig(self.output_path / 'depth_hist' / f"depth_hist_{it:06d}_cam{camera_idx:02d}.png", dpi=250)
            plt.close()

    def train(self):
        """Train the 3D Gaussian Splatting model."""
        num_iterations = self.config['num_iterations']

        # Main training loop
        with tqdm(total=num_iterations) as pbar:
            for iteration in range(num_iterations):
                # Select a random camera and its corresponding image
                camera_idx = np.random.randint(0, len(self.cameras))
                image_path = self.image_paths[camera_idx]
                target_image = self.load_image(image_path)

                # Zero gradients
                self.zero_grad()

                # Render the view
                rendered_image, depth_image, self.intermediate_buffers = render_gaussians(
                    background=np.array(self.config['background_color'], dtype=np.float32),
                    means3D=self.params['positions'].numpy(),
                    colors=None,  # use SH coefficients instead
                    opacity=self.params['opacities'].numpy(),
                    scales=self.params['scales'].numpy(),
                    rotations=self.params['rotations'].numpy(),
                    scale_modifier=self.config['scale_modifier'],
                    viewmatrix=self.cameras[camera_idx]['world_to_camera'],
                    projmatrix=self.cameras[camera_idx]['full_proj_matrix'],
                    tan_fovx=self.cameras[camera_idx]['tan_fovx'],
                    tan_fovy=self.cameras[camera_idx]['tan_fovy'],
                    image_height=self.cameras[camera_idx]['height'],
                    image_width=self.cameras[camera_idx]['width'],
                    sh=self.params['shs'].numpy(),  # pass SH coefficients
                    degree=self.config['sh_degree'],
                    campos=self.cameras[camera_idx]['camera_center'],
                    prefiltered=False,
                    antialiasing=False,
                    clamped=True
                )

                radii = wp.to_torch(self.intermediate_buffers["radii"]).cpu().numpy()
                np_rendered_image = wp.to_torch(rendered_image).cpu().numpy()
                np_rendered_image = np_rendered_image.transpose(2, 0, 1)

                # Save debug output on a sparse schedule (denser early-iteration
                # and save_interval-based variants were tried and left disabled)
                if (
                    iteration < 10 or
                    iteration % 1000 == 0 or
                    iteration == num_iterations - 1
                ):
                    self.debug_log_and_save_images(np_rendered_image, target_image, depth_image, camera_idx, iteration)

                # Calculate L1 loss
                l1_val = l1_loss(rendered_image, target_image)

                # SSIM is currently disabled; with it enabled the combined loss
                # would be: loss = (1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM)
                loss = l1_val
                self.losses.append(loss)

                # Compute pixel gradients for the image loss (dL/dColor)
                pixel_grad_buffer = compute_image_gradients(
                    rendered_image, target_image, lambda_dssim=0
                )

                # Prepare camera parameters
                camera = self.cameras[camera_idx]
                view_matrix = wp.mat44(camera['world_to_camera'].flatten())
                proj_matrix = wp.mat44(camera['full_proj_matrix'].flatten())
                campos = wp.vec3(camera['camera_center'][0], camera['camera_center'][1], camera['camera_center'][2])

                # Create the buffer dictionaries expected by the backward pass
                geom_buffer = {
                    'radii': self.intermediate_buffers['radii'],
                    'means2D': self.intermediate_buffers['points_xy_image'],
                    'conic_opacity': self.intermediate_buffers['conic_opacity'],
                    'rgb': self.intermediate_buffers['colors'],
                    'clamped': self.intermediate_buffers['clamped_state']
                }

                binning_buffer = {
                    'point_list': self.intermediate_buffers['point_list']
                }

                img_buffer = {
                    'ranges': self.intermediate_buffers['ranges'],
                    'final_Ts': self.intermediate_buffers['final_Ts'],
                    'n_contrib': self.intermediate_buffers['n_contrib']
                }

                gradients = backward(
                    # Core parameters
                    background=np.array(self.config['background_color'], dtype=np.float32),
                    means3D=self.params['positions'],
                    dL_dpixels=pixel_grad_buffer,

                    # Model parameters (pass directly from self.params)
                    opacity=self.params['opacities'],
                    shs=self.params['shs'],
                    scales=self.params['scales'],
                    rotations=self.params['rotations'],
                    scale_modifier=self.config['scale_modifier'],

                    # Camera parameters
                    viewmatrix=view_matrix,
                    projmatrix=proj_matrix,
                    tan_fovx=camera['tan_fovx'],
                    tan_fovy=camera['tan_fovy'],
                    image_height=camera['height'],
                    image_width=camera['width'],
                    campos=campos,

                    # Forward output buffers
                    radii=self.intermediate_buffers['radii'],
                    means2D=self.intermediate_buffers['points_xy_image'],
                    conic_opacity=self.intermediate_buffers['conic_opacity'],
                    rgb=self.intermediate_buffers['colors'],
                    cov3Ds=self.intermediate_buffers['cov3Ds'],
                    clamped=self.intermediate_buffers['clamped_state'],

                    # Internal state buffers
                    geom_buffer=geom_buffer,
                    binning_buffer=binning_buffer,
                    img_buffer=img_buffer,

                    # Algorithm parameters
                    degree=self.config['sh_degree'],
                    debug=False
                )

                # Copy gradients from the backward result to the optimizer's gradient buffers
                wp.copy(self.grads['positions'], gradients['dL_dmean3D'])
                wp.copy(self.grads['scales'], gradients['dL_dscale'])
                wp.copy(self.grads['rotations'], gradients['dL_drot'])
                wp.copy(self.grads['opacities'], gradients['dL_dopacity'])
                wp.copy(self.grads['shs'], gradients['dL_dshs'])

                # Update parameters
                self.optimizer_step(iteration)

                # Update progress bar
                pbar.update(1)
                pbar.set_description(f"Loss: {loss:.6f}")

                self.densification_and_pruning(iteration)

                # Save a checkpoint on the same sparse schedule as the debug output
                if (
                    iteration < 10 or
                    iteration % 1000 == 0 or
                    iteration == num_iterations - 1
                ):
                    self.save_checkpoint(iteration)

        print("Training complete!")

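A note on re-enabling the SSIM term: the pixel gradients must match the loss, so the same weight would also need to be passed to compute_image_gradients (above it is hard-coded to lambda_dssim=0 for pure L1). A minimal standalone sketch of the weighted combination from the disabled block, assuming the ssim() helper referenced there returns a similarity in [0, 1]:

def combined_loss(l1_val: float, ssim_val: float, lambda_dssim: float) -> float:
    # (1 - λ) * L1 + λ * (1 - SSIM), as in the commented-out block above
    return (1.0 - lambda_dssim) * l1_val + lambda_dssim * (1.0 - ssim_val)

# A perfect SSIM score leaves only the down-weighted L1 term:
assert abs(combined_loss(0.5, 1.0, 0.2) - 0.4) < 1e-9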
    def visualize_camera_points_alignment_interactive(self):
        """Create an interactive 3D visualization with camera frustums and colored points."""
        try:
            import plotly.graph_objects as go
            from plotly.subplots import make_subplots
            import plotly.express as px
        except ImportError:
            print("plotly not found. Install with: pip install plotly")
            return

        # Get data
        camera_positions = np.array([cam['camera_center'] for cam in self.cameras])
        points_np = wp.to_torch(self.params['positions']).cpu().numpy()

        # Get SH coefficients for colors
        shs_np = wp.to_torch(self.params['shs']).cpu().numpy()

        # Extract base colors from SH coefficients
        C0 = 0.28209479177387814  # normalization constant for Y_00
        point_colors = np.zeros((len(points_np), 3), dtype=np.float32)

        # Use only the DC component (first SH coefficient) of each point
        for i in range(len(points_np)):
            sh_dc = shs_np[i * 16]  # first SH coefficient for this point
            rgb = sh_dc * C0 + 0.5
            point_colors[i] = np.clip(rgb, 0, 1)

        # Subsample points for better performance
        max_points = 5000
        if len(points_np) > max_points:
            indices = np.random.choice(len(points_np), max_points, replace=False)
            points_sample = points_np[indices]
            colors_sample = point_colors[indices]
        else:
            points_sample = points_np
            colors_sample = point_colors

        # Convert colors to rgb strings for plotly
        colors_hex = [f'rgb({int(r*255)},{int(g*255)},{int(b*255)})' for r, g, b in colors_sample]

        # Calculate extents
        cam_extent = np.max(np.abs(camera_positions))
        points_extent = np.max(np.abs(points_sample))

        print(f"Camera extent: {cam_extent:.3f}")
        print(f"Points extent: {points_extent:.3f}")
        print(f"Scale ratio: {cam_extent/points_extent:.3f}")

        # Create figure
        fig = make_subplots(
            rows=1, cols=1,
            specs=[[{"type": "scene"}]],
            subplot_titles=['3D Scene with Camera Frustums']
        )

        # Colors for cameras
        n_cameras = len(camera_positions)
        camera_colors = px.colors.qualitative.Bold[:min(n_cameras, 10)]
        if n_cameras > 10:
            camera_colors = camera_colors * (n_cameras // 10 + 1)

        # Add point cloud with actual colors
        fig.add_trace(
            go.Scatter3d(
                x=points_sample[:, 0],
                y=points_sample[:, 1],
                z=points_sample[:, 2],
                mode='markers',
                marker=dict(
                    size=1.5,
                    color=colors_hex,
                    opacity=0.7
                ),
                name='Point Cloud',
                hovertemplate='Point<br>X: %{x:.3f}<br>Y: %{y:.3f}<br>Z: %{z:.3f}<extra></extra>'
            )
        )

        # Create camera frustums
        frustum_scale = cam_extent * 0.2

        for i, (cam, pos) in enumerate(zip(self.cameras, camera_positions)):
            color = camera_colors[i % len(camera_colors)]

            # Camera coordinate system in world space, from the camera-to-world matrix
            c2w = cam['camera_to_world']
            right = c2w[:3, 0]    # x-axis
            up = c2w[:3, 1]       # y-axis
            # The camera looks along +Z in camera space (dust3r/COLMAP convention),
            # i.e. the third column of c2w; using -Z here flips the frustums.
            forward = c2w[:3, 2]

            # Debug the first few cameras
            if i < 3:
                print(f"Camera {i} coordinate system:")
                print(f"  Position: {pos}")
                print(f"  Right: {right}")
                print(f"  Up: {up}")
                print(f"  Forward: {forward}")

            # Frustum opening angles
            tan_fov_x = cam['tan_fovx']
            tan_fov_y = cam['tan_fovy']

            # Frustum corners at near and far planes
            near_dist = frustum_scale * 0.1
            far_dist = frustum_scale

            # Near plane corners
            tl_near = pos + forward * near_dist - right * near_dist * tan_fov_x + up * near_dist * tan_fov_y
            tr_near = pos + forward * near_dist + right * near_dist * tan_fov_x + up * near_dist * tan_fov_y
            bl_near = pos + forward * near_dist - right * near_dist * tan_fov_x - up * near_dist * tan_fov_y
            br_near = pos + forward * near_dist + right * near_dist * tan_fov_x - up * near_dist * tan_fov_y

            # Far plane corners
            tl_far = pos + forward * far_dist - right * far_dist * tan_fov_x + up * far_dist * tan_fov_y
            tr_far = pos + forward * far_dist + right * far_dist * tan_fov_x + up * far_dist * tan_fov_y
            bl_far = pos + forward * far_dist - right * far_dist * tan_fov_x - up * far_dist * tan_fov_y
            br_far = pos + forward * far_dist + right * far_dist * tan_fov_x - up * far_dist * tan_fov_y

            # Camera position marker
            fig.add_trace(
                go.Scatter3d(
                    x=[pos[0]],
                    y=[pos[1]],
                    z=[pos[2]],
                    mode='markers',
                    marker=dict(
                        size=8,
                        color=color,
                        symbol='diamond',
                    ),
                    name=f'Camera {i}',
                    hovertemplate=f'Camera {i}<br>X: %{{x:.3f}}<br>Y: %{{y:.3f}}<br>Z: %{{z:.3f}}<extra></extra>'
                )
            )

            # Draw frustum edges
            lines_x = []
            lines_y = []
            lines_z = []

            # Helper to add one line segment
            def add_line(p1, p2):
                lines_x.extend([p1[0], p2[0], None])
                lines_y.extend([p1[1], p2[1], None])
                lines_z.extend([p1[2], p2[2], None])

            # Near plane
            add_line(tl_near, tr_near)
            add_line(tr_near, br_near)
            add_line(br_near, bl_near)
            add_line(bl_near, tl_near)

            # Far plane
            add_line(tl_far, tr_far)
            add_line(tr_far, br_far)
            add_line(br_far, bl_far)
            add_line(bl_far, tl_far)

            # Connecting edges
            add_line(tl_near, tl_far)
            add_line(tr_near, tr_far)
            add_line(bl_near, bl_far)
            add_line(br_near, br_far)

            # Camera to near plane corners
            add_line(pos, tl_near)
            add_line(pos, tr_near)
            add_line(pos, bl_near)
            add_line(pos, br_near)

            # Add frustum lines
            fig.add_trace(
                go.Scatter3d(
                    x=lines_x,
                    y=lines_y,
                    z=lines_z,
                    mode='lines',
                    line=dict(
                        color=color,
                        width=2
                    ),
                    name=f'Frustum {i}',
                    showlegend=False,
                    hoverinfo='none'
                )
            )

            # Per-camera coordinate axes
            axis_length = frustum_scale * 0.15

            # Right direction (X axis) – red
            fig.add_trace(
                go.Scatter3d(
                    x=[pos[0], pos[0] + right[0] * axis_length],
                    y=[pos[1], pos[1] + right[1] * axis_length],
                    z=[pos[2], pos[2] + right[2] * axis_length],
                    mode='lines',
                    line=dict(color='red', width=4),
                    name='X (Right)' if i == 0 else '',
                    showlegend=i == 0,
                    hoverinfo='none'
                )
            )

            # Up direction (Y axis) – green
            fig.add_trace(
                go.Scatter3d(
                    x=[pos[0], pos[0] + up[0] * axis_length],
                    y=[pos[1], pos[1] + up[1] * axis_length],
                    z=[pos[2], pos[2] + up[2] * axis_length],
                    mode='lines',
                    line=dict(color='green', width=4),
                    name='Y (Up)' if i == 0 else '',
                    showlegend=i == 0,
                    hoverinfo='none'
                )
            )

            # Forward direction (Z axis) – blue
            fig.add_trace(
                go.Scatter3d(
                    x=[pos[0], pos[0] + forward[0] * axis_length],
                    y=[pos[1], pos[1] + forward[1] * axis_length],
                    z=[pos[2], pos[2] + forward[2] * axis_length],
                    mode='lines',
                    line=dict(color='blue', width=4),
                    name='Z (Forward)' if i == 0 else '',
                    showlegend=i == 0,
                    hoverinfo='none'
                )
            )

            # Add markers at the end of each axis for better direction visibility
            # (note: axis_color avoids shadowing the per-camera `color` above)
            for axis_dir, axis_color, axis_name in [
                (right, 'red', 'X'),
                (up, 'green', 'Y'),
                (forward, 'blue', 'Z')
            ]:
                end_point = pos + axis_dir * axis_length

                fig.add_trace(
                    go.Scatter3d(
                        x=[end_point[0]],
                        y=[end_point[1]],
                        z=[end_point[2]],
                        mode='markers',
                        marker=dict(
                            size=6,
                            color=axis_color,
                            symbol='circle'
                        ),
                        name=f'{axis_name} Axis End' if i == 0 else '',
                        showlegend=False,
                        hoverinfo='none'
                    )
                )

        # Update layout for better visualization
        fig.update_layout(
            title=dict(
                text=f'Interactive Camera Frustums and Point Cloud<br>'
                     f'<sub>Cameras: {n_cameras}, Points: {len(points_sample)}/{len(points_np)}, '
                     f'Scale ratio: {cam_extent/points_extent:.2f}</sub>',
                x=0.5
            ),
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z',
                aspectmode='data',  # 'cube' or 'data'
                camera=dict(
                    eye=dict(x=1.8, y=1.8, z=1.8)  # default view
                ),
                annotations=[
                    dict(
                        showarrow=False,
                        x=0.05,
                        y=0.05,
                        z=0.05,
                        text="Camera axes:<br>Red: X (right)<br>Green: Y (up)<br>Blue: Z (forward)",
                        xanchor="left",
                        xshift=10,
                        opacity=0.8,
                        font=dict(size=14)
                    )
                ]
            ),
            height=900,
            width=1000,
            margin=dict(l=0, r=0, t=50, b=0)
        )

        # Add axis legend
        for color, name in [('red', 'X (Right)'), ('green', 'Y (Up)'), ('blue', 'Z (Forward)')]:
            fig.add_trace(
                go.Scatter3d(
                    x=[None], y=[None], z=[None],
                    mode='lines',
                    line=dict(color=color, width=6),
                    name=name,
                    showlegend=True
                )
            )

        # Save interactive HTML
        html_path = self.output_path / 'camera_frustums_visualization.html'
        fig.write_html(str(html_path))
        print(f"Interactive visualization saved to: {html_path}")

        # Show in browser if possible
        try:
            fig.show()
        except Exception as e:
            print(f"Could not display in browser: {e}")
            print(f"Open {html_path} in your browser to view the interactive plot")

        return fig

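The per-point DC-extraction loop in the method above is O(N) interpreted Python. Because only the l=0 band is read and the SH buffer uses the (N*16, 3) layout seen throughout this file, the same colors come out of one vectorized step; a standalone sketch (not part of the file itself):

import numpy as np
C0 = 0.28209479177387814
shs_np = np.random.randn(4 * 16, 3).astype(np.float32)       # (N*16, 3) layout, N = 4
point_colors = np.clip(shs_np[::16] * C0 + 0.5, 0.0, 1.0)    # (N, 3), DC band only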
    def debug_camera_and_points_alignment(self):
        """Enhanced debug function with scale analysis."""
        # Get camera positions
        camera_positions = np.array([cam['camera_center'] for cam in self.cameras])

        # Get point cloud positions
        points_np = wp.to_torch(self.params['positions']).cpu().numpy()

        # Calculate detailed statistics
        cam_stats = {
            'min': np.min(camera_positions, axis=0),
            'max': np.max(camera_positions, axis=0),
            'mean': np.mean(camera_positions, axis=0),
            'extent': np.max(np.abs(camera_positions))
        }

        points_stats = {
            'min': np.min(points_np, axis=0),
            'max': np.max(points_np, axis=0),
            'mean': np.mean(points_np, axis=0),
            'extent': np.max(np.abs(points_np))
        }

        print("=== ALIGNMENT DEBUG ===")
        print(f"Cameras ({len(camera_positions)}):")
        print(f"  Min: [{cam_stats['min'][0]:8.3f}, {cam_stats['min'][1]:8.3f}, {cam_stats['min'][2]:8.3f}]")
        print(f"  Max: [{cam_stats['max'][0]:8.3f}, {cam_stats['max'][1]:8.3f}, {cam_stats['max'][2]:8.3f}]")
        print(f"  Mean:[{cam_stats['mean'][0]:8.3f}, {cam_stats['mean'][1]:8.3f}, {cam_stats['mean'][2]:8.3f}]")
        print(f"  Extent: {cam_stats['extent']:.3f}")

        print(f"\nPoints ({len(points_np)}):")
        print(f"  Min: [{points_stats['min'][0]:8.3f}, {points_stats['min'][1]:8.3f}, {points_stats['min'][2]:8.3f}]")
        print(f"  Max: [{points_stats['max'][0]:8.3f}, {points_stats['max'][1]:8.3f}, {points_stats['max'][2]:8.3f}]")
        print(f"  Mean:[{points_stats['mean'][0]:8.3f}, {points_stats['mean'][1]:8.3f}, {points_stats['mean'][2]:8.3f}]")
        print(f"  Extent: {points_stats['extent']:.3f}")

        # Scale analysis
        scale_ratio = cam_stats['extent'] / points_stats['extent'] if points_stats['extent'] > 0 else float('inf')
        print(f"\nScale Analysis:")
        print(f"  Scale ratio (cam/points): {scale_ratio:.3f}")

        if scale_ratio > 10:
            print("  ⚠️  WARNING: Cameras much larger than points - may need to scale points up")
            print(f"  Suggested point scale factor: {scale_ratio/10:.3f}")
        elif scale_ratio < 0.1:
            print("  ⚠️  WARNING: Points much larger than cameras - may need to scale points down")
            print(f"  Suggested point scale factor: {scale_ratio*10:.3f}")
        else:
            print("  ✅ Scale ratio looks reasonable")

        # Distance analysis
        center_distance = np.linalg.norm(cam_stats['mean'] - points_stats['mean'])
        print(f"\nCenter separation: {center_distance:.3f}")
        if center_distance > max(cam_stats['extent'], points_stats['extent']):
            print("  ⚠️  WARNING: Camera and point centers are far apart - possible coordinate system issue")

        return cam_stats, points_stats

def main():
    parser = argparse.ArgumentParser(description="Train 3D Gaussian Splatting model with COLMAP")
    parser.add_argument("--dataset", type=str, default="./data_/scenes/steak_is",
                        help="Path to dataset directory")
    parser.add_argument("--output", type=str, default="./output/steak_is", help="Output directory")

    args = parser.parse_args()

    # Create trainer and start training
    trainer = NeRFGaussianSplattingTrainer(
        dataset_path=args.dataset,
        output_path=args.output,
    )

    # Debug alignment
    trainer.debug_camera_and_points_alignment()

    # Create interactive visualization
    trainer.visualize_camera_points_alignment_interactive()

    # Start training
    trainer.train()


if __name__ == "__main__":
    main()
gs/train_vdpm.py
ADDED
@@ -0,0 +1,712 @@
"""
Train 3D Gaussian Splatting from VDPM Output

Loads VDPM reconstruction (tracks.npz, poses.npz, images/) and trains 3DGS.
Supports per-frame training or combined multi-timestep training.

Usage (from 4dgs-dpm root):
    python -m gs.train_vdpm --input ./vdpm/input_images_XXXX --output ./output/vdpm_scene
    python -m gs.train_vdpm --input ./vdpm/input_images_XXXX --output ./output --frame 0

Or directly:
    cd gs
    python train_vdpm.py --input ../vdpm/input_images_XXXX --output ./output
"""

import os
import sys
import json
import argparse
import numpy as np
from pathlib import Path
from scipy.spatial import cKDTree
import imageio
import torch
import matplotlib.pyplot as plt

# Ensure gs/ modules are importable when running from the repo root
_gs_dir = Path(__file__).parent.resolve()
if str(_gs_dir) not in sys.path:
    sys.path.insert(0, str(_gs_dir))

import warp as wp
from tqdm import tqdm

from forward import render_gaussians
from backward import backward
from optimizer import adam_update, prune_gaussians
from config import GaussianParams, DEVICE
from utils.point_cloud_utils import save_ply
from utils.math_utils import world_to_view, projection_matrix
from loss import l1_loss, compute_image_gradients

wp.init()

def decode_poses(pose_enc: np.ndarray, image_hw: tuple) -> tuple:
    """
    Decode VGGT pose encodings to extrinsic and intrinsic matrices.

    Args:
        pose_enc: (1, N, 9) pose encoding from VDPM
        image_hw: (H, W) image dimensions

    Returns:
        extrinsics: (N, 4, 4) world-to-camera matrices
        intrinsics: (N, 3, 3) camera intrinsic matrices
    """
    try:
        from vggt.utils.pose_enc import pose_encoding_to_extri_intri

        pose_enc_t = torch.from_numpy(pose_enc).float()
        extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc_t, image_hw)

        # extrinsic is (1, N, 3, 4) camera-from-world
        extrinsic = extrinsic[0].numpy()     # (N, 3, 4)
        intrinsics = intrinsic[0].numpy()    # (N, 3, 3)

        # Add a homogeneous bottom row to each extrinsic
        N = extrinsic.shape[0]
        bottom = np.array([0, 0, 0, 1], dtype=np.float32).reshape(1, 1, 4)
        bottom = np.tile(bottom, (N, 1, 1))
        extrinsics_4x4 = np.concatenate([extrinsic, bottom], axis=1)  # (N, 4, 4)

        return extrinsics_4x4, intrinsics

    except ImportError:
        print("Warning: vggt not available. Using identity poses.")
        N = pose_enc.shape[1]
        extrinsics = np.tile(np.eye(4, dtype=np.float32), (N, 1, 1))

        H, W = image_hw
        fx = fy = max(H, W)
        cx, cy = W / 2, H / 2
        intrinsic = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
        intrinsics = np.tile(intrinsic, (N, 1, 1))

        return extrinsics, intrinsics

def load_vdpm_data(input_path: str) -> dict:
    """Load all VDPM outputs from a directory."""
    input_path = Path(input_path)

    # Load tracks/points
    tracks_path = input_path / "tracks.npz"
    output_path = input_path / "output_4d.npz"

    if tracks_path.exists():
        data = np.load(tracks_path)
    elif output_path.exists():
        data = np.load(output_path)
    else:
        raise FileNotFoundError(f"No tracks.npz or output_4d.npz in {input_path}")

    world_points = data['world_points']
    world_points_conf = data['world_points_conf']
    num_views = int(data.get('num_views', 1))
    num_timesteps = int(data.get('num_timesteps', world_points.shape[0]))

    # Handle multi-view format
    if world_points.ndim == 5:
        T, V, H, W, _ = world_points.shape
        print(f"Multi-view: {T} timesteps × {V} views × {H}×{W}")
    else:
        T, H, W, _ = world_points.shape
        V = 1
        world_points = world_points[:, np.newaxis, :, :, :]
        world_points_conf = world_points_conf[:, np.newaxis, :, :]
        print(f"Single-view: {T} timesteps × {H}×{W}")

    # Load poses
    poses_path = input_path / "poses.npz"
    pose_enc = None
    if poses_path.exists():
        pose_data = np.load(poses_path)
        pose_enc = pose_data.get('pose_enc')
        print(f"Loaded poses: {pose_enc.shape if pose_enc is not None else 'None'}")

    # Load images
    images_dir = input_path / "images"
    image_paths = sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
    images = []
    for img_path in image_paths:
        img = imageio.imread(img_path)
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)
        elif img.shape[-1] == 4:
            img = img[..., :3]
        images.append(img.astype(np.float32) / 255.0)
    images = np.stack(images, axis=0)  # (N, H, W, 3)
    print(f"Loaded {len(images)} images, shape: {images.shape}")

    # Load metadata
    meta_path = input_path / "meta.json"
    meta = {}
    if meta_path.exists():
        with open(meta_path) as f:
            meta = json.load(f)

    return {
        'world_points': world_points,            # (T, V, H, W, 3)
        'world_points_conf': world_points_conf,  # (T, V, H, W)
        'pose_enc': pose_enc,
        'images': images,                        # (N, H, W, 3)
        'num_views': num_views,
        'num_timesteps': num_timesteps,
        'T': T, 'V': V, 'H': H, 'W': W,
        'meta': meta,
    }

def extract_frame_pointcloud(data: dict, frame_idx: int, conf_threshold: float = 50.0):
    """
    Extract point cloud and colors for a specific frame.

    Returns:
        positions: (N, 3) XYZ
        colors: (N, 3) RGB [0,1]
        confidence: (N,) confidence scores
    """
    T, V, H, W = data['T'], data['V'], data['H'], data['W']

    # Get points for this frame (merge all views)
    pts = data['world_points'][frame_idx]        # (V, H, W, 3)
    conf = data['world_points_conf'][frame_idx]  # (V, H, W)

    # Flatten
    pts_flat = pts.reshape(-1, 3)    # (V*H*W, 3)
    conf_flat = conf.reshape(-1)     # (V*H*W,)

    # Get colors from images.
    # Images are interleaved: [cam0_t0, cam1_t0, cam0_t1, cam1_t1, ...]
    start_idx = frame_idx * V
    end_idx = start_idx + V
    frame_images = data['images'][start_idx:end_idx]  # (V, H_img, W_img, 3)

    # Resize images to match the point cloud resolution if needed
    img_H, img_W = frame_images.shape[1:3]
    if img_H != H or img_W != W:
        from scipy.ndimage import zoom
        scale_h = H / img_H
        scale_w = W / img_W
        resized = []
        for v in range(V):
            img_v = zoom(frame_images[v], (scale_h, scale_w, 1), order=1)
            resized.append(img_v)
        frame_images = np.stack(resized, axis=0)

    colors_flat = frame_images.reshape(-1, 3)  # (V*H*W, 3)

    # Filter by confidence
    if conf_threshold > 0:
        thresh = np.percentile(conf_flat, conf_threshold)
        mask = (conf_flat >= thresh) & (conf_flat > 1e-5)
    else:
        mask = conf_flat > 1e-5

    # Also filter NaN/Inf
    valid_pts = np.isfinite(pts_flat).all(axis=1)
    mask = mask & valid_pts

    return pts_flat[mask], colors_flat[mask], conf_flat[mask]

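Note that conf_threshold above is a percentile, not an absolute confidence value: conf_threshold=50.0 keeps roughly the most confident half of the pixels. A standalone check of that filtering logic:

import numpy as np
conf = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
thresh = np.percentile(conf, 50.0)        # median -> 0.45
mask = (conf >= thresh) & (conf > 1e-5)
assert mask.sum() == 4                    # the top half survives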
def build_cameras(data: dict, frame_idx: int) -> list:
    """
    Build camera dictionaries for training from VDPM data.

    Returns a list of camera dicts compatible with 3DGS training.
    """
    T, V, H, W = data['T'], data['V'], data['H'], data['W']
    images = data['images']
    img_H, img_W = images.shape[1:3]

    # Decode poses
    pose_enc = data.get('pose_enc')
    if pose_enc is not None:
        extrinsics, intrinsics = decode_poses(pose_enc, (img_H, img_W))
    else:
        # Fallback: identity poses with reasonable intrinsics
        N = T * V
        extrinsics = np.tile(np.eye(4, dtype=np.float32), (N, 1, 1))
        fx = fy = max(img_H, img_W)
        cx, cy = img_W / 2, img_H / 2
        K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
        intrinsics = np.tile(K, (N, 1, 1))

    cameras = []

    # Get camera indices for this frame
    for v in range(V):
        img_idx = frame_idx * V + v

        if img_idx >= len(extrinsics):
            continue

        extrinsic = extrinsics[img_idx]   # (4, 4) camera-from-world
        intrinsic = intrinsics[img_idx]   # (3, 3)

        # Extract components
        R = extrinsic[:3, :3]
        t = extrinsic[:3, 3]

        fx, fy = intrinsic[0, 0], intrinsic[1, 1]
        cx, cy = intrinsic[0, 2], intrinsic[1, 2]

        # Camera center in world coordinates
        camera_center = -R.T @ t

        # FOV from intrinsics
        fov_x = 2 * np.arctan(img_W / (2 * fx))
        fov_y = 2 * np.arctan(img_H / (2 * fy))

        # Build matrices exactly like render.py does for Warp/OpenGL compatibility.
        # Warp uses the column-major (OpenGL) convention, so matrices must be transposed.
        world_to_camera = np.eye(4, dtype=np.float32)
        world_to_camera[:3, :3] = R
        world_to_camera[:3, 3] = t
        world_to_camera = world_to_camera.T  # transpose for Warp/OpenGL

        # Projection matrix (transposed for Warp/OpenGL)
        near, far = 0.01, 100.0
        proj_matrix = projection_matrix(fovx=fov_x, fovy=fov_y, znear=near, zfar=far).T

        # Full projection = view @ proj (in the transposed convention)
        full_proj_matrix = world_to_camera @ proj_matrix

        cameras.append({
            'camera_id': img_idx,
            'width': img_W,
            'height': img_H,
            'world_to_camera': world_to_camera,  # transposed for Warp/OpenGL
            'camera_to_world': np.linalg.inv(world_to_camera),
            'camera_center': camera_center,
            'full_proj_matrix': full_proj_matrix,
            'tan_fovx': np.tan(fov_x / 2),
            'tan_fovy': np.tan(fov_y / 2),
            'fx': fx, 'fy': fy,
            'cx': cx, 'cy': cy,
        })

    return cameras

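Because both world_to_camera and proj_matrix are stored transposed, full_proj_matrix equals (P·V)ᵀ, so a world point projects as a row vector: x_clip = [x, y, z, 1] @ full_proj_matrix, which matches P·V·x in the usual column-vector notation. A quick sanity-check sketch (project_point is a hypothetical helper, not part of the training code):

import numpy as np

def project_point(cam: dict, p_world: np.ndarray) -> np.ndarray:
    p_h = np.append(p_world.astype(np.float32), 1.0)  # homogeneous coordinates
    clip = p_h @ cam['full_proj_matrix']              # row-vector convention
    return clip[:3] / clip[3]                         # NDC, in [-1, 1] if visible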
@wp.kernel
def init_rotations_opacities(
    rotations: wp.array(dtype=wp.vec4),
    opacities: wp.array(dtype=float),
    num_points: int
):
    i = wp.tid()
    if i >= num_points:
        return
    rotations[i] = wp.vec4(1.0, 0.0, 0.0, 0.0)  # identity quaternion
    opacities[i] = 0.5


@wp.kernel
def zero_gradients(
    pos_grad: wp.array(dtype=wp.vec3),
    scale_grad: wp.array(dtype=wp.vec3),
    rot_grad: wp.array(dtype=wp.vec4),
    opacity_grad: wp.array(dtype=float),
    sh_grad: wp.array(dtype=wp.vec3),
    num_points: int
):
    i = wp.tid()
    if i >= num_points:
        return

    pos_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    scale_grad[i] = wp.vec3(0.0, 0.0, 0.0)
    rot_grad[i] = wp.vec4(0.0, 0.0, 0.0, 0.0)
    opacity_grad[i] = 0.0

    for j in range(16):
        idx = i * 16 + j
        sh_grad[idx] = wp.vec3(0.0, 0.0, 0.0)

class VDPM3DGSTrainer:
    """Train 3DGS from VDPM point cloud initialization."""

    def __init__(self, data: dict, frame_idx: int, output_path: str, conf_threshold: float = 50.0):
        self.output_path = Path(output_path)
        self.output_path.mkdir(parents=True, exist_ok=True)

        # Create output directories for renders
        self.render_dir = self.output_path / f"frame_{frame_idx}" / "renders"
        self.render_dir.mkdir(parents=True, exist_ok=True)

        self.config = GaussianParams.get_config_dict()
        self.frame_idx = frame_idx

        # Extract point cloud for this frame
        print(f"Extracting point cloud for frame {frame_idx}...")
        positions, colors, confidence = extract_frame_pointcloud(data, frame_idx, conf_threshold)
        self.num_points = len(positions)
        print(f"Got {self.num_points} points")

        # Build cameras
        self.cameras = build_cameras(data, frame_idx)
        print(f"Built {len(self.cameras)} cameras")

        # Store images for this frame
        V = data['V']
        start_idx = frame_idx * V
        end_idx = start_idx + V
        self.images = data['images'][start_idx:end_idx]

        # Initialize Gaussian parameters and Adam state
        self.params = self._init_params(positions, colors)
        self.grads = self._create_grad_arrays()
        self.adam_m = self._create_grad_arrays()
        self.adam_v = self._create_grad_arrays()

        self.losses = []
        self.intermediate_buffers = {}

    def _init_params(self, positions: np.ndarray, colors: np.ndarray):
        """Initialize Gaussian parameters from the point cloud."""
        N = self.num_points

        # Positions
        positions_wp = wp.array(positions.astype(np.float32), dtype=wp.vec3, device=DEVICE)

        # Scales from KNN: average distance to the 3 nearest neighbours
        if N > 3:
            tree = cKDTree(positions)
            dists, _ = tree.query(positions, k=4)       # k=4 includes the point itself
            avg_dist = np.mean(dists[:, 1:], axis=1)    # drop the self-distance column
            scales_np = np.clip(avg_dist, 0.001, 1.0)[:, np.newaxis] * np.ones((N, 3))
        else:
            scales_np = np.full((N, 3), 0.01, dtype=np.float32)
        scales_wp = wp.array(scales_np.astype(np.float32), dtype=wp.vec3, device=DEVICE)

        # Rotations and opacities
        rotations_wp = wp.zeros(N, dtype=wp.vec4, device=DEVICE)
        opacities_wp = wp.zeros(N, dtype=float, device=DEVICE)
        wp.launch(init_rotations_opacities, dim=N, inputs=[rotations_wp, opacities_wp, N])

        # SH coefficients from colors (DC band only)
        C0 = 0.28209479177387814
        shs_np = np.zeros((N * 16, 3), dtype=np.float32)
        shs_np[::16] = (colors - 0.5) / C0
        shs_wp = wp.array(shs_np, dtype=wp.vec3, device=DEVICE)

        return {
            'positions': positions_wp,
            'scales': scales_wp,
            'rotations': rotations_wp,
            'opacities': opacities_wp,
            'shs': shs_wp,
        }

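The DC initialization above is the inverse of the evaluation used at render time (and in the visualization in train_colmap.py): rgb = C0 * sh_dc + 0.5. A standalone round-trip check:

import numpy as np
C0 = 0.28209479177387814
colors = np.random.rand(5, 3).astype(np.float32)
sh_dc = (colors - 0.5) / C0
assert np.allclose(C0 * sh_dc + 0.5, colors, atol=1e-6)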
    def _create_grad_arrays(self):
        N = self.num_points
        return {
            'positions': wp.zeros(N, dtype=wp.vec3, device=DEVICE),
            'scales': wp.zeros(N, dtype=wp.vec3, device=DEVICE),
            'rotations': wp.zeros(N, dtype=wp.vec4, device=DEVICE),
            'opacities': wp.zeros(N, dtype=float, device=DEVICE),
            'shs': wp.zeros(N * 16, dtype=wp.vec3, device=DEVICE),
        }

    def zero_grad(self):
        wp.launch(zero_gradients, dim=self.num_points, inputs=[
            self.grads['positions'], self.grads['scales'],
            self.grads['rotations'], self.grads['opacities'],
            self.grads['shs'], self.num_points
        ])

    def train(self, num_iterations: int = 3000):
        """Train the 3DGS model."""
        print(f"Training for {num_iterations} iterations...")

        # Save iteration 0 (initial state before training)
        print("Saving initial state (iteration 0)...")
        self.save(0)

        with tqdm(total=num_iterations) as pbar:
            for it in range(num_iterations):
                self.zero_grad()

                # Pick a random camera
                cam_idx = np.random.randint(len(self.cameras))
                camera = self.cameras[cam_idx]
                target = self.images[cam_idx]

                # Render
                rendered, depth, self.intermediate_buffers = render_gaussians(
                    background=np.array(self.config['background_color'], dtype=np.float32),
                    means3D=self.params['positions'].numpy(),
                    colors=None,
                    opacity=self.params['opacities'].numpy(),
                    scales=self.params['scales'].numpy(),
                    rotations=self.params['rotations'].numpy(),
                    scale_modifier=1.0,
                    viewmatrix=camera['world_to_camera'],
                    projmatrix=camera['full_proj_matrix'],
                    tan_fovx=camera['tan_fovx'],
                    tan_fovy=camera['tan_fovy'],
                    image_height=camera['height'],
                    image_width=camera['width'],
                    sh=self.params['shs'].numpy(),
                    degree=3,
                    campos=camera['camera_center'],
                    prefiltered=False,
                    antialiasing=True,
                )

                # Compute loss
                rendered_np = wp.to_torch(rendered).cpu().numpy()
                if rendered_np.shape[0] == 3:
                    rendered_np = np.transpose(rendered_np, (1, 2, 0))

                target_wp = wp.array(target.astype(np.float32), dtype=wp.vec3, device=DEVICE)
                loss = l1_loss(rendered, target_wp)
                self.losses.append(loss)

                # Compute pixel gradients for the backward pass
                pixel_grad_buffer = compute_image_gradients(
                    rendered, target_wp, lambda_dssim=0
                )

                # Prepare camera parameters
                view_matrix = wp.mat44(camera['world_to_camera'].flatten())
                proj_matrix = wp.mat44(camera['full_proj_matrix'].flatten())
                campos = wp.vec3(camera['camera_center'][0], camera['camera_center'][1], camera['camera_center'][2])

                # Prepare buffers for the backward pass
                geom_buffer = {
                    'radii': self.intermediate_buffers['radii'],
                    'means2D': self.intermediate_buffers['points_xy_image'],
                    'conic_opacity': self.intermediate_buffers['conic_opacity'],
                    'rgb': self.intermediate_buffers['colors'],
                    'clamped': self.intermediate_buffers['clamped_state']
                }

                binning_buffer = {
                    'point_list': self.intermediate_buffers['point_list']
                }

                img_buffer = {
                    'ranges': self.intermediate_buffers['ranges'],
                    'final_Ts': self.intermediate_buffers['final_Ts'],
                    'n_contrib': self.intermediate_buffers['n_contrib']
                }

                # Backward pass
                gradients = backward(
                    # Core parameters
                    background=np.array(self.config['background_color'], dtype=np.float32),
                    means3D=self.params['positions'],
                    dL_dpixels=pixel_grad_buffer,

                    # Model parameters
                    opacity=self.params['opacities'],
                    shs=self.params['shs'],
                    scales=self.params['scales'],
                    rotations=self.params['rotations'],
                    scale_modifier=self.config['scale_modifier'],

                    # Camera parameters
                    viewmatrix=view_matrix,
                    projmatrix=proj_matrix,
                    tan_fovx=camera['tan_fovx'],
                    tan_fovy=camera['tan_fovy'],
                    image_height=camera['height'],
                    image_width=camera['width'],
                    campos=campos,

                    # Forward output buffers
                    radii=self.intermediate_buffers['radii'],
                    means2D=self.intermediate_buffers['points_xy_image'],
                    conic_opacity=self.intermediate_buffers['conic_opacity'],
                    rgb=self.intermediate_buffers['colors'],
                    cov3Ds=self.intermediate_buffers['cov3Ds'],
                    clamped=self.intermediate_buffers['clamped_state'],

                    # Internal state buffers
                    geom_buffer=geom_buffer,
                    binning_buffer=binning_buffer,
                    img_buffer=img_buffer,

                    # Algorithm parameters
                    degree=self.config['sh_degree'],
                    debug=False
                )

                # Copy gradients to optimizer buffers
                wp.copy(self.grads['positions'], gradients['dL_dmean3D'])
                wp.copy(self.grads['scales'], gradients['dL_dscale'])
                wp.copy(self.grads['rotations'], gradients['dL_drot'])
                wp.copy(self.grads['opacities'], gradients['dL_dopacity'])
                wp.copy(self.grads['shs'], gradients['dL_dshs'])

                # Optimizer step with an exponentially decaying base learning rate
                lr = 0.001 * (0.1 ** (it / num_iterations))
                wp.launch(adam_update, dim=self.num_points, inputs=[
                    self.params['positions'], self.params['scales'],
                    self.params['rotations'], self.params['opacities'], self.params['shs'],
                    self.grads['positions'], self.grads['scales'],
                    self.grads['rotations'], self.grads['opacities'], self.grads['shs'],
                    self.adam_m['positions'], self.adam_m['scales'],
                    self.adam_m['rotations'], self.adam_m['opacities'], self.adam_m['shs'],
                    self.adam_v['positions'], self.adam_v['scales'],
                    self.adam_v['rotations'], self.adam_v['opacities'], self.adam_v['shs'],
                    self.num_points, lr, lr*5, lr*5, lr*2, lr*5,
                    0.9, 0.999, 1e-8, it
                ])

                pbar.set_postfix(loss=f"{loss:.4f}")
                pbar.update(1)

                # Save checkpoint
                if (it + 1) % 500 == 0 or it == num_iterations - 1:
                    self.save(it + 1)

        print("Training complete!")

def save(self, iteration: int):
|
| 573 |
+
"""Save checkpoint with rendered images."""
|
| 574 |
+
ckpt_dir = self.output_path / f"frame_{self.frame_idx}" / f"iter_{iteration}"
|
| 575 |
+
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 576 |
+
|
| 577 |
+
# Save PLY
|
| 578 |
+
save_ply(self.params, ckpt_dir / "point_cloud.ply", self.num_points)
|
| 579 |
+
|
| 580 |
+
# Render and save images for all cameras
|
| 581 |
+
for cam_idx, camera in enumerate(self.cameras):
|
| 582 |
+
target = self.images[cam_idx]
|
| 583 |
+
|
| 584 |
+
rendered, depth, _ = render_gaussians(
|
| 585 |
+
background=np.array(self.config['background_color'], dtype=np.float32),
|
| 586 |
+
means3D=self.params['positions'].numpy(),
|
| 587 |
+
colors=None,
|
| 588 |
+
opacity=self.params['opacities'].numpy(),
|
| 589 |
+
scales=self.params['scales'].numpy(),
|
| 590 |
+
rotations=self.params['rotations'].numpy(),
|
| 591 |
+
scale_modifier=1.0,
|
| 592 |
+
viewmatrix=camera['world_to_camera'],
|
| 593 |
+
projmatrix=camera['full_proj_matrix'],
|
| 594 |
+
tan_fovx=camera['tan_fovx'],
|
| 595 |
+
tan_fovy=camera['tan_fovy'],
|
| 596 |
+
image_height=camera['height'],
|
| 597 |
+
image_width=camera['width'],
|
| 598 |
+
sh=self.params['shs'].numpy(),
|
| 599 |
+
degree=3,
|
| 600 |
+
campos=camera['camera_center'],
|
| 601 |
+
prefiltered=False,
|
| 602 |
+
antialiasing=True,
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
# Convert rendered to numpy
|
| 606 |
+
rendered_np = wp.to_torch(rendered).cpu().numpy()
|
| 607 |
+
if rendered_np.shape[0] == 3:
|
| 608 |
+
rendered_np = np.transpose(rendered_np, (1, 2, 0))
|
| 609 |
+
|
| 610 |
+
# Save rendered image
|
| 611 |
+
rendered_uint8 = (np.clip(rendered_np, 0, 1) * 255).astype(np.uint8)
|
| 612 |
+
imageio.imwrite(ckpt_dir / f"render_cam{cam_idx}.png", rendered_uint8)
|
| 613 |
+
|
| 614 |
+
# Save target image
|
| 615 |
+
target_uint8 = (np.clip(target, 0, 1) * 255).astype(np.uint8)
|
| 616 |
+
imageio.imwrite(ckpt_dir / f"target_cam{cam_idx}.png", target_uint8)
|
| 617 |
+
|
| 618 |
+
# Save side-by-side comparison
|
| 619 |
+
comparison = np.concatenate([target_uint8, rendered_uint8], axis=1)
|
| 620 |
+
imageio.imwrite(ckpt_dir / f"compare_cam{cam_idx}.png", comparison)
|
| 621 |
+
|
| 622 |
+
# Save loss plot
|
| 623 |
+
if len(self.losses) > 0:
|
| 624 |
+
plt.figure(figsize=(10, 5))
|
| 625 |
+
plt.plot(self.losses)
|
| 626 |
+
plt.title(f'Training Loss - Frame {self.frame_idx}')
|
| 627 |
+
plt.xlabel('Iteration')
|
| 628 |
+
plt.ylabel('L1 Loss')
|
| 629 |
+
plt.grid(True)
|
| 630 |
+
plt.savefig(ckpt_dir / "loss_plot.png")
|
| 631 |
+
plt.close()
|
| 632 |
+
|
| 633 |
+
print(f"Saved checkpoint to {ckpt_dir}")
|
| 634 |
+
|
| 635 |
+
def save_final(self):
|
| 636 |
+
"""Save final PLY to flat output structure for easy loading."""
|
| 637 |
+
final_path = self.output_path / f"frame_{self.frame_idx:04d}.ply"
|
| 638 |
+
save_ply(self.params, final_path, self.num_points)
|
| 639 |
+
return final_path
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def train_single_frame(data: dict, frame_idx: int, output_path: str,
|
| 643 |
+
conf_threshold: float, iterations: int) -> str:
|
| 644 |
+
"""Train a single frame and return the output PLY path."""
|
| 645 |
+
trainer = VDPM3DGSTrainer(
|
| 646 |
+
data=data,
|
| 647 |
+
frame_idx=frame_idx,
|
| 648 |
+
output_path=output_path,
|
| 649 |
+
conf_threshold=conf_threshold,
|
| 650 |
+
)
|
| 651 |
+
trainer.train(num_iterations=iterations)
|
| 652 |
+
return trainer.save_final()
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
def main():
|
| 656 |
+
parser = argparse.ArgumentParser(description="Train 3DGS from VDPM output")
|
| 657 |
+
parser.add_argument("--input", "-i", required=True, help="Path to VDPM output directory")
|
| 658 |
+
parser.add_argument("--output", "-o", required=True, help="Output directory")
|
| 659 |
+
parser.add_argument("--frame", "-f", type=int, default=None,
|
| 660 |
+
help="Single frame index to train (default: train ALL frames)")
|
| 661 |
+
parser.add_argument("--conf", type=float, default=50.0, help="Confidence threshold percentile")
|
| 662 |
+
parser.add_argument("--iterations", "-n", type=int, default=3000, help="Training iterations per frame")
|
| 663 |
+
|
| 664 |
+
args = parser.parse_args()
|
| 665 |
+
|
| 666 |
+
# Load data
|
| 667 |
+
print(f"Loading VDPM data from {args.input}...")
|
| 668 |
+
data = load_vdpm_data(args.input)
|
| 669 |
+
|
| 670 |
+
num_timesteps = data['T']
|
| 671 |
+
print(f"Found {num_timesteps} timesteps in data")
|
| 672 |
+
|
| 673 |
+
output_path = Path(args.output)
|
| 674 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 675 |
+
|
| 676 |
+
if args.frame is not None:
|
| 677 |
+
# Train single frame
|
| 678 |
+
if args.frame >= num_timesteps:
|
| 679 |
+
raise ValueError(f"Frame {args.frame} out of range (0-{num_timesteps-1})")
|
| 680 |
+
|
| 681 |
+
print(f"\n{'='*60}")
|
| 682 |
+
print(f"Training frame {args.frame}/{num_timesteps-1}")
|
| 683 |
+
print(f"{'='*60}")
|
| 684 |
+
|
| 685 |
+
ply_path = train_single_frame(
|
| 686 |
+
data, args.frame, args.output, args.conf, args.iterations
|
| 687 |
+
)
|
| 688 |
+
print(f"\n✓ Saved: {ply_path}")
|
| 689 |
+
else:
|
| 690 |
+
# Train ALL frames
|
| 691 |
+
print(f"\n{'='*60}")
|
| 692 |
+
print(f"Training ALL {num_timesteps} frames")
|
| 693 |
+
print(f"Output: {output_path}/frame_XXXX.ply")
|
| 694 |
+
print(f"{'='*60}")
|
| 695 |
+
|
| 696 |
+
ply_paths = []
|
| 697 |
+
for frame_idx in range(num_timesteps):
|
| 698 |
+
print(f"\n[Frame {frame_idx+1}/{num_timesteps}]")
|
| 699 |
+
ply_path = train_single_frame(
|
| 700 |
+
data, frame_idx, args.output, args.conf, args.iterations
|
| 701 |
+
)
|
| 702 |
+
ply_paths.append(ply_path)
|
| 703 |
+
|
| 704 |
+
print(f"\n{'='*60}")
|
| 705 |
+
print(f"✓ Training complete! Generated {len(ply_paths)} PLY files:")
|
| 706 |
+
for p in ply_paths:
|
| 707 |
+
print(f" {p}")
|
| 708 |
+
print(f"{'='*60}")
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
if __name__ == "__main__":
|
| 712 |
+
main()
|
gs/training_progress.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:052cff51cf3d320a299faea9c795fad53349ed193b8bc222ee37860afd7def99
size 306145
gs/utils/analyze_scales.py
ADDED
@@ -0,0 +1,100 @@
import numpy as np
import matplotlib.pyplot as plt
from plyfile import PlyData, PlyElement
import argparse
from pathlib import Path
import os

def analyze_scales(input_ply, output_ply=None, threshold=None, show_plot=True):
    """
    Analyze scales in a PLY file and optionally filter out large splats.

    Args:
        input_ply (str): Path to input PLY file
        output_ply (str, optional): Path to save filtered PLY file
        threshold (float, optional): Maximum scale value to keep
        show_plot (bool): Whether to display the histogram plot
    """
    # Convert input path to an absolute path if it's relative
    repo_root = Path(__file__).parent.parent  # Go up one level from utils
    input_ply = Path(repo_root) / input_ply if not os.path.isabs(input_ply) else Path(input_ply)

    if not input_ply.exists():
        raise FileNotFoundError(f"Could not find PLY file: {input_ply}")

    print(f"Reading PLY file: {input_ply}")
    plydata = PlyData.read(str(input_ply))
    vertex_data = plydata['vertex']

    # Extract scale values - assuming log-space encoding in the PLY
    scales = np.vstack([
        np.exp(vertex_data['scale_0']),
        np.exp(vertex_data['scale_1']),
        np.exp(vertex_data['scale_2'])
    ]).T

    # Calculate statistics
    max_scales = np.max(scales, axis=1)
    mean_scale = np.mean(max_scales)
    median_scale = np.median(max_scales)

    print("Scale statistics:")
    print(f"Mean scale: {mean_scale:.6f}")
    print(f"Median scale: {median_scale:.6f}")
    print(f"Min scale: {np.min(max_scales):.6f}")
    print(f"Max scale: {np.max(max_scales):.6f}")

    # Plot histogram
    if show_plot:
        plt.figure(figsize=(10, 6))
        plt.hist(max_scales, bins=100, edgecolor='black')
        plt.title('Histogram of Maximum Scales per Gaussian')
        plt.xlabel('Scale')
        plt.ylabel('Count')
        if threshold is not None:
            plt.axvline(x=threshold, color='r', linestyle='--',
                        label=f'Threshold ({threshold})')
            plt.legend()
        plt.savefig(Path(input_ply).parent / 'scale_histogram.png')
        plt.show()

    # Filter and save a new PLY if a threshold is provided
    if threshold is not None and output_ply is not None:
        # Create mask for Gaussians to keep
        keep_mask = max_scales <= threshold
        num_removed = np.sum(~keep_mask)
        print(f"Removing {num_removed} Gaussians ({(num_removed/len(keep_mask))*100:.2f}%)")

        # Create new vertex data with the filtered Gaussians
        new_vertex = []
        for i, keep in enumerate(keep_mask):
            if keep:
                new_vertex.append(tuple(vertex_data[i]))

        # Create new PLY file (note: PlyElement.describe, not PlyData.describe)
        new_vertex_array = np.array(
            new_vertex,
            dtype=vertex_data.dtype
        )
        new_vertex_element = PlyElement.describe(new_vertex_array, 'vertex')
        PlyData([new_vertex_element], text=True).write(output_ply)
        print(f"Saved filtered PLY to: {output_ply}")

def main():
    parser = argparse.ArgumentParser(description='Analyze and filter Gaussian scales in a PLY file')
    parser.add_argument('input_ply', help='Input PLY file path')
    parser.add_argument('--output', '-o', help='Output PLY file path')
    parser.add_argument('--threshold', '-t', type=float, help='Maximum scale threshold')
    parser.add_argument('--no-plot', action='store_true', help='Disable histogram plot')

    args = parser.parse_args()

    analyze_scales(
        args.input_ply,
        args.output,
        args.threshold,
        show_plot=not args.no_plot
    )

if __name__ == "__main__":
    main()
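# Example invocation (paths are illustrative):
#   python gs/utils/analyze_scales.py output/point_cloud.ply --threshold 0.5 -o filtered.ply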
gs/utils/camera_utils.py
ADDED
@@ -0,0 +1,215 @@
import os
import json
import numpy as np

from utils.math_utils import world_to_view, projection_matrix

# Y down, Z forward
def load_camera(camera_info):
    """Load camera parameters from a camera info dictionary"""
    # Extract camera parameters
    camera_id = camera_info["camera_id"]
    camera_to_world = np.asarray(camera_info["camera_to_world"], dtype=np.float64)

    # Change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
    camera_to_world[:3, 1:3] *= -1

    # Calculate world to camera transform
    world_to_camera = np.linalg.inv(camera_to_world).astype(np.float32)

    # Extract rotation and translation
    R = world_to_camera[:3, :3]
    T = world_to_camera[:3, 3]

    world_to_camera[3, 3] = 1.
    world_to_camera = world_to_camera.T

    width = camera_info.get("width")
    height = camera_info.get("height")
    fx = camera_info.get("focal")
    fy = camera_info.get("focal")
    cx = width / 2
    cy = height / 2

    # Calculate field of view from focal length
    fovx = 2 * np.arctan(width / (2 * fx))
    fovy = 2 * np.arctan(height / (2 * fy))
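    # Sanity check on the pinhole relation above: with width == 2*fx,
    # fovx = 2*atan(1.0) = 90 degrees; shorter focals give wider fields of view.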

    # Create view matrix
    view_matrix = world_to_view(R=R, t=T)

    # Create projection matrix
    znear = 0.01
    zfar = 100.0
    proj_matrix = projection_matrix(fovx=fovx, fovy=fovy, znear=znear, zfar=zfar).T
    full_proj_matrix = world_to_camera @ proj_matrix
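    # Both matrices are stored transposed, so this composes to (proj @ w2c).T and
    # points transform as row vectors (p_ndc = p_world @ full_proj_matrix), which
    # matches the row-vector convention used by the original 3DGS code.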

    # Calculate other parameters
    tan_fovx = np.tan(fovx * 0.5)
    tan_fovy = np.tan(fovy * 0.5)

    camera_center = np.linalg.inv(world_to_camera)[3, :3]

    # Handle camera type and distortion
    camera_model = camera_info.get("camera_model", "OPENCV")
    if camera_model == "OPENCV" or camera_model is None:
        camera_type = 0  # PERSPECTIVE
    elif camera_model == "OPENCV_FISHEYE":
        camera_type = 1  # FISHEYE
    else:
        raise ValueError(f"Unsupported camera_model '{camera_model}'")

    # Get distortion parameters
    distortion_params = []
    for param_name in ["k1", "k2", "p1", "p2", "k3", "k4"]:
        distortion_params.append(camera_info.get(param_name, 0.0))

    camera_params = {
        'R': R,
        'T': T,
        'camera_center': camera_center,
        'view_matrix': view_matrix,
        'proj_matrix': proj_matrix,
        'full_proj_matrix': full_proj_matrix,
        'tan_fovx': tan_fovx,
        'tan_fovy': tan_fovy,
        'fx': fx,
        'fy': fy,
        'cx': cx,
        'cy': cy,
        'width': width,
        'height': height,
        'camera_to_world': camera_to_world,
        'world_to_camera': world_to_camera,
        'camera_type': camera_type,
        'distortion_params': np.array(distortion_params, dtype=np.float32)
    }

    return camera_params

def load_camera_from_json(input_path, camera_id=0):
    """Load camera parameters from a cameras.json file"""
    camera_file = os.path.join(os.path.dirname(input_path), "cameras.json")
    if not os.path.exists(camera_file):
        print(f"Warning: No cameras.json found in {os.path.dirname(input_path)}, using default camera")
        return None

    try:
        with open(camera_file, 'r') as f:
            cameras = json.load(f)

        # Find the camera with the specified ID, or use the first one
        camera = next((cam for cam in cameras if cam["id"] == camera_id), cameras[0])

        # Use load_camera to process the camera parameters
        return load_camera(camera)

    except Exception as e:
        print(f"Error loading camera from cameras.json: {e}")
        return None

def load_camera_colmap(cam_info):
    """
    Load camera from COLMAP format (dust3r output) with exact compatibility to the original load_camera.

    Args:
        cam_info: Dictionary containing:
            - width, height: image dimensions
            - fx, fy: focal lengths
            - cx, cy: principal point
            - camera_id: unique identifier
            - R: rotation matrix (world-to-camera rotation)
            - T: translation vector (world-to-camera translation)
            - Optional: camera_model, distortion params
    """
    # Extract camera parameters
    camera_id = cam_info["camera_id"]

    # Use provided R and T directly (COLMAP convention - world to camera)
    R = cam_info['R']
    T = cam_info['T']  # This is the world-to-camera translation

    # Build world-to-camera matrix
    world_to_camera = np.eye(4, dtype=np.float64)
    world_to_camera[:3, :3] = R
    world_to_camera[:3, 3] = T

    # Invert to get camera-to-world
    camera_to_world = np.linalg.inv(world_to_camera).astype(np.float64)

    # IMPORTANT FIX: Ensure the Z direction is correctly oriented for the COLMAP convention.
    # COLMAP uses +Z forward, so there is no need to flip the Z axis.
    # If frustums are still backwards, uncomment this line:
    # camera_to_world[:3, 2] *= -1  # Flip Z axis if needed

    # Recalculate world_to_camera after any modifications
    world_to_camera = np.linalg.inv(camera_to_world).astype(np.float32)

    # Extract intrinsics
    width = cam_info.get("width")
    height = cam_info.get("height")
    fx = cam_info.get("fx", cam_info.get("focal", width * 0.7))
    fy = cam_info.get("fy", cam_info.get("focal", height * 0.7))
    cx = cam_info.get("cx", width / 2)
    cy = cam_info.get("cy", height / 2)

    # Calculate field of view from focal length
    fovx = 2 * np.arctan(width / (2 * fx))
    fovy = 2 * np.arctan(height / (2 * fy))

    # Create view matrix using the original R and T
    view_matrix = world_to_view(R=R, t=T)

    # Create projection matrix
    znear = 0.01
    zfar = 100.0
    proj_matrix = projection_matrix(fovx=fovx, fovy=fovy, znear=znear, zfar=zfar).T
    full_proj_matrix = world_to_camera @ proj_matrix

    # Calculate other parameters
    tan_fovx = np.tan(fovx * 0.5)
    tan_fovy = np.tan(fovy * 0.5)

    # IMPORTANT FIX: Correctly calculate the camera center
    camera_center = camera_to_world[:3, 3]  # Extract translation from the c2w matrix

    # Handle camera type and distortion
    camera_model = cam_info.get("camera_model", "OPENCV")
    if camera_model == "OPENCV" or camera_model is None:
        camera_type = 0  # PERSPECTIVE
    elif camera_model == "OPENCV_FISHEYE":
        camera_type = 1  # FISHEYE
    else:
        camera_type = 0  # Default to PERSPECTIVE

    # Get distortion parameters
    distortion_params = []
    for param_name in ["k1", "k2", "p1", "p2", "k3", "k4"]:
        distortion_params.append(cam_info.get(param_name, 0.0))

    # Return camera parameters
    camera_params = {
        'R': R,
        'T': T,
        'camera_center': camera_center,
        'view_matrix': view_matrix,
        'proj_matrix': proj_matrix,
        'full_proj_matrix': full_proj_matrix,
        'tan_fovx': tan_fovx,
        'tan_fovy': tan_fovy,
        'fx': fx,
        'fy': fy,
        'cx': cx,
        'cy': cy,
        'width': width,
        'height': height,
        'camera_to_world': camera_to_world,
        'world_to_camera': world_to_camera,
        'camera_type': camera_type,
        'distortion_params': np.array(distortion_params, dtype=np.float32)
    }

    return camera_params
gs/utils/check_opacities.py
ADDED
@@ -0,0 +1,21 @@
import numpy as np
from plyfile import PlyData
import os

# Path for output folder that's one level above utils
ply_path = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                        'output', 'point_cloud', 'iteration_13999', 'point_cloud.ply')
# load the PLY
ply = PlyData.read(ply_path)
opacities = np.array(ply['vertex']['opacity'])

# compute statistics
min_o, max_o, mean_o = opacities.min(), opacities.max(), opacities.mean()
near_zero = np.sum(opacities < 1e-3)
near_one = np.sum(opacities > 0.999)

print(f'Loaded {len(opacities)} splats')
print(f'Opacity range: min={min_o:.6f}, max={max_o:.6f}, mean={mean_o:.6f}')
print(f'Count near-zero (<1e-3): {near_zero}')
print(f'Count near-one (>0.999): {near_one}')
print('Sample opacities:', opacities[:100])
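# Note: values are printed exactly as stored in the PLY; if your pipeline stores
# opacity logits rather than post-activation values, apply 1/(1+exp(-x)) before
# interpreting the statistics above.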
gs/utils/math_utils.py
ADDED
@@ -0,0 +1,111 @@
import math
import numpy as np
import warp as wp


def world_to_view(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0):
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)
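# world_to_view mirrors getWorld2View2 from the reference 3DGS code: R is stored
# transposed inside Rt, and translate/scale optionally recenter and rescale the
# camera center before the matrix is inverted back to world-to-view.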

def projection_matrix(fovx, fovy, znear, zfar):
    tanHalfFovY = math.tan(fovy / 2)
    tanHalfFovX = math.tan(fovx / 2)

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    P = np.zeros((4, 4))

    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P

def matrix_to_quaternion(matrix):
    """
    Convert a 3x3 rotation matrix to a quaternion in (x, y, z, w) format.

    Args:
        matrix: 3x3 rotation matrix

    Returns:
        Quaternion as (x, y, z, w) in a numpy array of shape (4,)
    """
    # Ensure the input is a proper rotation matrix
    # (a simple check that can be helpful during debugging)
    if np.abs(np.linalg.det(matrix) - 1.0) > 1e-5:
        print(f"Warning: Input matrix determinant is not 1: {np.linalg.det(matrix)}")

    trace = np.trace(matrix)
    if trace > 0:
        S = 2.0 * np.sqrt(trace + 1.0)
        w = 0.25 * S
        x = (matrix[2, 1] - matrix[1, 2]) / S
        y = (matrix[0, 2] - matrix[2, 0]) / S
        z = (matrix[1, 0] - matrix[0, 1]) / S
    elif matrix[0, 0] > matrix[1, 1] and matrix[0, 0] > matrix[2, 2]:
        S = 2.0 * np.sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2])
        w = (matrix[2, 1] - matrix[1, 2]) / S
        x = 0.25 * S
        y = (matrix[0, 1] + matrix[1, 0]) / S
        z = (matrix[0, 2] + matrix[2, 0]) / S
    elif matrix[1, 1] > matrix[2, 2]:
        S = 2.0 * np.sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2])
        w = (matrix[0, 2] - matrix[2, 0]) / S
        x = (matrix[0, 1] + matrix[1, 0]) / S
        y = 0.25 * S
        z = (matrix[1, 2] + matrix[2, 1]) / S
    else:
        S = 2.0 * np.sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1])
        w = (matrix[1, 0] - matrix[0, 1]) / S
        x = (matrix[0, 2] + matrix[2, 0]) / S
        y = (matrix[1, 2] + matrix[2, 1]) / S
        z = 0.25 * S

    # Return as (x, y, z, w) to match Warp's convention
    return np.array([x, y, z, w], dtype=np.float32)


def quaternion_to_rotation_matrix(q):
    w, x, y, z = q
    return np.array([
        [1 - 2*y**2 - 2*z**2, 2*x*y - 2*z*w, 2*x*z + 2*y*w],
        [2*x*y + 2*z*w, 1 - 2*x**2 - 2*z**2, 2*y*z - 2*x*w],
        [2*x*z - 2*y*w, 2*y*z + 2*x*w, 1 - 2*x**2 - 2*y**2]
    ], dtype=np.float32)
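# Caution: matrix_to_quaternion above returns (x, y, z, w), while this helper
# unpacks its input as (w, x, y, z); round-tripping between the two requires
# reordering the components, e.g.:
#   q_xyzw = matrix_to_quaternion(R)
#   R2 = quaternion_to_rotation_matrix(np.roll(q_xyzw, 1))  # roll -> (w, x, y, z)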

# def quaternion_to_rotation_matrix(q):
#     """Convert quaternion to rotation matrix with swapped X and Z axes."""
#     qw, qx, qy, qz = q
#
#     # Original conversion
#     R = np.array([
#         [1 - 2*qy*qy - 2*qz*qz, 2*qx*qy - 2*qz*qw, 2*qx*qz + 2*qy*qw],
#         [2*qx*qy + 2*qz*qw, 1 - 2*qx*qx - 2*qz*qz, 2*qy*qz - 2*qx*qw],
#         [2*qx*qz - 2*qy*qw, 2*qy*qz + 2*qx*qw, 1 - 2*qx*qx - 2*qy*qy]
#     ])
#
#     # Swap X and Z axes (columns and rows)
#     R_fixed = R.copy()
#     R_fixed[:, [0, 2]] = R[:, [2, 0]]  # Swap columns 0 and 2
#     R_fixed[[0, 2], :] = R_fixed[[2, 0], :]  # Swap rows 0 and 2
#
#     return R_fixed
gs/utils/plot_loss_log.py
ADDED
@@ -0,0 +1,35 @@
import numpy as np
import matplotlib.pyplot as plt

def plot_loss_log(loss_file="output/steak/loss.txt"):
    """Plot training loss values on a logarithmic scale from loss.txt"""

    # Load loss values from the txt file
    with open(loss_file, 'r') as f:
        losses = [float(line.strip()) for line in f if line.strip()]

    # Create figure with log scale
    plt.figure(figsize=(12, 6))
    plt.semilogy(losses, label='Training Loss')

    # Customize plot
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.xlabel('Iteration')
    plt.ylabel('Loss (log scale)')
    plt.title('Training Loss over Time (Log Scale)')
    plt.legend()

    # Save plot
    output_path = loss_file.replace('.txt', '_plot_log.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Saved loss plot to: {output_path}")
    print("Loss statistics:")
    print(f"  Min: {min(losses):.6f}")
    print(f"  Max: {max(losses):.6f}")
    print(f"  Mean: {np.mean(losses):.6f}")
    print(f"  Final: {losses[-1]:.6f}")

if __name__ == "__main__":
    plot_loss_log()
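# Usage: python gs/utils/plot_loss_log.py
# (reads output/steak/loss.txt by default; pass a different path to plot_loss_log)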
gs/utils/point_cloud_utils.py
ADDED
@@ -0,0 +1,160 @@
import math
import numpy as np
from plyfile import PlyData, PlyElement
import os
import warp as wp


def load_ply(filepath):
    """
    Load a Gaussian splat PLY file.

    Returns dict with: positions, scales, rotations, opacities, shs
    """
    plydata = PlyData.read(filepath)
    vertex = plydata['vertex']

    num_points = len(vertex)

    # Load positions
    positions = np.stack([
        vertex['x'], vertex['y'], vertex['z']
    ], axis=-1).astype(np.float32)

    # Load scales (stored in log space)
    scales = np.stack([
        np.exp(vertex['scale_0']),
        np.exp(vertex['scale_1']),
        np.exp(vertex['scale_2'])
    ], axis=-1).astype(np.float32)

    # Load opacities
    opacities = vertex['opacity'].astype(np.float32).reshape(-1, 1)

    # Load rotations (quaternion)
    rotations = np.stack([
        vertex['rot_0'], vertex['rot_1'], vertex['rot_2'], vertex['rot_3']
    ], axis=-1).astype(np.float32)

    # Load SH coefficients
    # DC term
    sh_dc = np.stack([
        vertex['f_dc_0'], vertex['f_dc_1'], vertex['f_dc_2']
    ], axis=-1).astype(np.float32)

    # Rest of the SH coefficients
    sh_rest = []
    for i in range(45):
        sh_rest.append(vertex[f'f_rest_{i}'])
    sh_rest = np.stack(sh_rest, axis=-1).astype(np.float32)  # (N, 45)
    sh_rest = sh_rest.reshape(num_points, 15, 3)  # (N, 15, 3)

    # Combine into the (N*16, 3) format expected by the renderer
    shs = np.zeros((num_points * 16, 3), dtype=np.float32)
    for i in range(num_points):
        shs[i * 16] = sh_dc[i]
        for j in range(15):
            shs[i * 16 + j + 1] = sh_rest[i, j]
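    # A vectorized equivalent of the loop above (same (N*16, 3) row layout):
    #   shs = np.concatenate([sh_dc[:, None, :], sh_rest], axis=1).reshape(-1, 3)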

    return {
        'positions': positions,
        'scales': scales,
        'rotations': rotations,
        'opacities': opacities,
        'shs': shs,
        'num_points': num_points
    }


# Function to save a point cloud to a PLY file
def save_ply(params, filepath, num_points, colors=None):
    # Get numpy arrays
    positions = params['positions'].numpy()
    scales = params['scales'].numpy()
    rotations = params['rotations'].numpy()
    opacities = params['opacities'].numpy()
    shs = params['shs'].numpy()

    # Handle colors - either provided or computed from SH coefficients
    if colors is not None:
        # Use provided colors
        if hasattr(colors, 'numpy'):
            colors_np = colors.numpy()
        else:
            colors_np = colors
    else:
        # Compute colors from SH coefficients (DC term only for simplicity)
        # The SH DC coefficients are stored in the first coefficient (index 0)
        colors_np = np.zeros((num_points, 3), dtype=np.float32)
        for i in range(num_points):
            # Get the DC term from the SH coefficients
            sh_dc = shs[i * 16]  # First SH coefficient contains the DC term
            # Convert from SH to RGB (simplified - just use the DC term)
            colors_np[i] = np.clip(sh_dc + 0.5, 0.0, 1.0)  # Add 0.5 offset and clamp

    # Create vertex data
    vertex_data = []
    for i in range(num_points):
        # Basic properties (opacity may be stored as a length-1 array, so cast it)
        vertex = (
            positions[i][0], positions[i][1], positions[i][2],
            np.log(scales[i][0]), np.log(scales[i][1]), np.log(scales[i][2]),  # Log-space encoding
            float(opacities[i])
        )

        # Add rotation quaternion elements
        quat = rotations[i]
        rot_elements = (quat[0], quat[1], quat[2], quat[3])  # x, y, z, w
        vertex += rot_elements

        # Add RGB colors (convert to 0-255 range)
        color_255 = (
            int(np.clip(colors_np[i][0] * 255, 0, 255)),
            int(np.clip(colors_np[i][1] * 255, 0, 255)),
            int(np.clip(colors_np[i][2] * 255, 0, 255))
        )
        vertex += color_255

        # Add SH coefficients
        sh_dc = tuple(shs[i * 16][j] for j in range(3))
        vertex += sh_dc

        # Add remaining SH coefficients
        sh_rest = []
        for j in range(1, 16):
            for c in range(3):
                sh_rest.append(shs[i * 16 + j][c])
        vertex += tuple(sh_rest)

        vertex_data.append(vertex)

    # Define the structure of the PLY file
    vertex_type = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('scale_0', 'f4'), ('scale_1', 'f4'), ('scale_2', 'f4'),
        ('opacity', 'f4')
    ]

    # Add rotation quaternion elements
    vertex_type.extend([('rot_0', 'f4'), ('rot_1', 'f4'), ('rot_2', 'f4'), ('rot_3', 'f4')])

    # Add RGB color fields
    vertex_type.extend([('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])

    # Add SH coefficients
    vertex_type.extend([('f_dc_0', 'f4'), ('f_dc_1', 'f4'), ('f_dc_2', 'f4')])

    # Add remaining SH coefficients
    for i in range(45):  # 15 coeffs * 3 channels
        vertex_type.append((f'f_rest_{i}', 'f4'))

    vertex_array = np.array(vertex_data, dtype=vertex_type)
    el = PlyElement.describe(vertex_array, 'vertex')

    # Create the directory if it doesn't exist
    os.makedirs(os.path.dirname(filepath), exist_ok=True)

    # Save the PLY file
    PlyData([el], text=False).write(filepath)
    print(f"Point cloud saved to {filepath}")
gs/utils/wp_utils.py
ADDED
@@ -0,0 +1,45 @@
import warp as wp
from config import DEVICE


@wp.func
def wp_vec3_mul_element(a: wp.vec3, b: wp.vec3) -> wp.vec3:
    return wp.vec3(a[0] * b[0], a[1] * b[1], a[2] * b[2])

# Element-wise vector square root helper function
@wp.func
def wp_vec3_sqrt(a: wp.vec3) -> wp.vec3:
    return wp.vec3(wp.sqrt(a[0]), wp.sqrt(a[1]), wp.sqrt(a[2]))

# Element-wise vector division helper function
@wp.func
def wp_vec3_div_element(a: wp.vec3, b: wp.vec3) -> wp.vec3:
    # Add a small epsilon to the denominator to prevent division by zero
    # (although Adam's epsilon should mostly handle this)
    safe_b = wp.vec3(b[0] + 1e-9, b[1] + 1e-9, b[2] + 1e-9)
    return wp.vec3(a[0] / safe_b[0], a[1] / safe_b[1], a[2] / safe_b[2])

@wp.func
def wp_vec3_add_element(a: wp.vec3, b: wp.vec3) -> wp.vec3:
    return wp.vec3(a[0] + b[0], a[1] + b[1], a[2] + b[2])

@wp.func
def wp_vec3_clamp(x: wp.vec3, min_val: float, max_val: float) -> wp.vec3:
    return wp.vec3(
        wp.clamp(x[0], min_val, max_val),
        wp.clamp(x[1], min_val, max_val),
        wp.clamp(x[2], min_val, max_val)
    )

def to_warp_array(data, dtype, shape_check=None, flatten=False):
    if isinstance(data, wp.array):
        return data
    if data is None:
        return None
    # Convert torch tensor to numpy if needed
    if hasattr(data, 'cpu') and hasattr(data, 'numpy'):
        data = data.cpu().numpy()
    if flatten and data.ndim == 2 and data.shape[1] == 1:
        data = data.flatten()
    return wp.array(data, dtype=dtype, device=DEVICE)
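# Example (shapes are illustrative): convert an (N, 3) float32 numpy array of
# positions into a Warp vec3 array on the configured device:
#   positions_wp = to_warp_array(np.zeros((100, 3), dtype=np.float32), dtype=wp.vec3)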
requirements.txt
ADDED
@@ -0,0 +1,26 @@
# Install PyTorch with CUDA 11.8 separately using:
# pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
# torch==2.5.1
# torchvision==0.20.1

warp-lang==1.7.0
numpy==1.26.3
imageio==2.34.1
plyfile
roma
gradio==5.17.1
pydantic==2.10.6
matplotlib==3.9.2
tqdm==4.66.5
opencv-python
pypng
scipy
einops
trimesh
pyglet<2
viser
jaxtyping
hydra-submitit-launcher
scikit-learn
plotly
git+https://github.com/facebookresearch/vggt.git@44b3afb
vdpm/.gitignore
ADDED
@@ -0,0 +1,132 @@
data/
checkpoints/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
vdpm/.gitmodules
ADDED
File without changes
vdpm/.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
vdpm/LICENSE
ADDED
@@ -0,0 +1,22 @@
MIT License

Copyright (c) 2025 Eldar Insafutdinov, Edgar Sucar

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
vdpm/LICENSE-VGGT
ADDED
@@ -0,0 +1,115 @@
VGGT License

v1 Last Updated: July 29, 2025

“Acceptable Use Policy” means the Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.

“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.

“Documentation” means the specifications, manuals and documentation accompanying
Research Materials distributed by Meta.

“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).

“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.

By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.

1. License Rights and Redistribution.

a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.

b. Redistribution and Use.

i. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.

ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.

iii. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.

2. User Support. Your use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.

3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.

4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.

5. Intellectual Property.

a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.

b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.

6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement.

7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.

8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.

Acceptable Use Policy

Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.

As part of this mission, Meta makes certain research materials available for use in accordance with this Agreement (including the Acceptable Use Policy). Meta is committed to promoting the safe and responsible use of such research materials.

Prohibited Uses

You agree you will not use, or allow others to use, Research Materials to:

1. Violate the law or others’ rights, including to:
Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
Violence or terrorism
Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
Human trafficking, exploitation, and sexual violence
The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
Sexual solicitation
Any other criminal activity

Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals

Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services

Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices

Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws

Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using Research Materials

Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system

2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:

Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic in Arms Regulations (ITAR) maintained by the United States Department of State

Guns and illegal weapons (including weapon development)

Illegal drugs and regulated/controlled substances
Operation of critical infrastructure, transportation technologies, or heavy machinery

Self-harm or harm to others, including suicide, cutting, and eating disorders
Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual

3. Intentionally deceive or mislead others, including use of Research Materials related to the following:

Generating, promoting, or furthering fraud or the creation or promotion of disinformation
Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content

Generating, promoting, or further distributing spam

Impersonating another individual without consent, authorization, or legal right

Representing that outputs of research materials or outputs from technology using Research Materials are human-generated

Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement

4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
vdpm/README.md
ADDED
@@ -0,0 +1,44 @@
---
title: vdpm
app_file: gradio_demo.py
sdk: gradio
sdk_version: 5.17.1
---
<div align="center">
<h1>V-DPM: 4D Video Reconstruction with Dynamic Point Maps</h1>

<a href="https://www.robots.ox.ac.uk/~vgg/research/vdpm/"><img src="https://img.shields.io/badge/Project_Page-green" alt="Project Page"></a>
<a href="https://huggingface.co/spaces/edgarsucar/vdpm"><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>

**[Visual Geometry Group, University of Oxford](https://www.robots.ox.ac.uk/~vgg/)**

[Edgar Sucar](https://edgarsucar.github.io/)\*, [Eldar Insafutdinov](https://eldar.insafutdinov.com/)\*, [Zihang Lai](https://scholar.google.com/citations?user=31eXgMYAAAAJ), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/)
</div>

## Setup

First, clone the repository and set up a virtual environment with [uv](https://github.com/astral-sh/uv):

```bash
git clone git@github.com:eldar/vdpm.git
cd vdpm
uv venv --python 3.12
. .venv/bin/activate

# Install PyTorch with CUDA 11.8 first
uv pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118

# Then install remaining dependencies
uv pip install -r requirements.txt
```

## Viser demo
```bash
python visualise.py ++vis.input_video=examples/videos/camel.mp4
```

## Gradio demo
```bash
python gradio_demo.py
```
vdpm/check_model_size.py
ADDED
@@ -0,0 +1,85 @@
import torch
import sys
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent))

def check_model_memory():
    # Simple config object
    class SimpleConfig:
        class ModelConfig:
            decoder_depth = 4
        model = ModelConfig()

    cfg = SimpleConfig()

    # Import after path is set
    from dpm.model import VDPM

    # Create model on CPU first to count parameters
    print("Creating model...")
    model = VDPM(cfg)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\n{'='*60}")
    print("MODEL SIZE ANALYSIS FOR RTX 3070 Ti (8GB)")
    print(f"{'='*60}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print("\nEstimated model weights memory:")
    print(f"  - FP32 (float32): {total_params * 4 / 1024**3:.2f} GB")
    print(f"  - FP16 (float16): {total_params * 2 / 1024**3:.2f} GB")
    print(f"  - BF16 (bfloat16): {total_params * 2 / 1024**3:.2f} GB")
    print(f"  - INT8 (quantized): {total_params * 1 / 1024**3:.2f} GB  <-- RECOMMENDED for 8GB GPU")

    # Estimate activation memory for typical input
    batch_size = 1
    num_frames = 5  # typical video length
    img_size = 518
    print(f"\nEstimated activation memory (batch={batch_size}, frames={num_frames}, img_size={img_size}):")

    # Input images: [B, S, 3, H, W]
    input_mem = batch_size * num_frames * 3 * img_size * img_size * 4 / 1024**3
    print(f"  - Input images (FP32): {input_mem:.2f} GB")

    # Rough estimate for activations (can be 2-4x model size during forward pass)
    activation_mem_estimate = total_params * 2 * 3 / 1024**3  # conservative estimate
    print(f"  - Activations (estimate): {activation_mem_estimate:.2f} GB")

    # Calculate total for different precision modes
    total_fp16 = (total_params * 2 / 1024**3) + input_mem + activation_mem_estimate
    total_int8 = (total_params * 1 / 1024**3) + input_mem + (activation_mem_estimate * 0.6)  # INT8 reduces activations too

    print("\nTotal estimated GPU memory needed:")
    print(f"  - With FP16/BF16: {total_fp16:.2f} GB")
    print(f"  - With INT8 quantization: {total_int8:.2f} GB  <-- FITS IN 8GB!")
    print("Your RTX 3070 Ti has: 8 GB VRAM")

    if total_int8 <= 8:
        print("\n✓ With INT8 quantization, model will fit in GPU memory!")
        print("  Set USE_QUANTIZATION = True in gradio_demo.py")
    elif total_fp16 > 8:
        print(f"\n⚠️ WARNING: Even with INT8 ({total_int8:.2f} GB), memory is tight")
        print("  Recommendations:")
        print("  1. Use INT8 quantization (USE_QUANTIZATION = True)")
        print(f"  2. Reduce number of input frames to {num_frames} or fewer")
        print("  3. Clear CUDA cache between batches")
    else:
        print("\n✓ Model should fit with FP16!")

    print(f"{'='*60}\n")

    # Check actual GPU memory if CUDA available
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
        print(f"Current GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"Current GPU memory cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

if __name__ == "__main__":
    check_model_memory()
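
The arithmetic the script prints reduces to bytes-per-parameter times parameter count. A minimal standalone sketch, assuming a hypothetical 1-billion-parameter model (the real count is whatever the script above reports):

```python
# Sketch of the weight-memory estimate used in check_model_size.py.
# The parameter count here is an illustrative assumption, not measured.
params = 1_000_000_000

bytes_per_param = {"fp32": 4, "fp16/bf16": 2, "int8": 1}
for name, nbytes in bytes_per_param.items():
    print(f"{name}: {params * nbytes / 1024**3:.2f} GB")
# fp32: 3.73 GB, fp16/bf16: 1.86 GB, int8: 0.93 GB
```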
vdpm/configs/config.yaml
ADDED
@@ -0,0 +1,50 @@
defaults:
  - _self_
  - hydra: defaults
  - model: dpm

config:
  exp_name: "debug"
  file: "config.yaml"

data_loader:
  batch_size: 2
  num_workers: 8
  dynamic_batch: false

train:
  logging: true
  num_gpus: 4
  amp: bfloat16
  amp_dpt: false
  dry_run: false
  camera_loss_lambda: 5.0

optimiser:
  lr: 0.00005 # absolute lr
  blr: 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
  start_epoch:
  epochs: 70
  accum_iter: 1
  warmup_epochs: 3
  min_lr: 1e-06

run:
  resume: false
  dirpath: null
  debug: false
  random_seed: 42
  git_hash: null
  log_frequency: 250
  training_progress_bar: false
  save_freq: 5
  eval_freq: 1
  keep_freq: 5
  print_freq: 20
  num_keep_ckpts: 5
  # Old Dust3r params
  world_size: -1
  local_rank: -1
  dist_url: "env://"
  seed: 0
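
The `blr` comment documents the standard base-learning-rate scaling rule. A minimal sketch of that formula with this config's values, assuming total_batch_size = batch_size × num_gpus × accum_iter (a common convention; the training code may compute it differently):

```python
# absolute_lr = base_lr * total_batch_size / 256, per the optimiser comment above.
blr = 1.5e-4
batch_size, num_gpus, accum_iter = 2, 4, 1        # values from this config
total_batch_size = batch_size * num_gpus * accum_iter   # 8
absolute_lr = blr * total_batch_size / 256
print(f"{absolute_lr:.2e}")                        # 4.69e-06
```

Since `lr` is pinned explicitly here (5e-5), the derived value would only matter if the absolute learning rate were left unset.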
vdpm/configs/model/dpm.yaml
ADDED
@@ -0,0 +1,3 @@
name: dpm-video
pretrained: /work/eldar/models/vggt/VGGT-1B.pt
decoder_depth: 4
vdpm/configs/visualise.yaml
ADDED
@@ -0,0 +1,13 @@
defaults:
  - _self_
  - model: dpm

hydra:
  output_subdir: null # Disable saving of config files.
  job:
    chdir: False

vis:
  port: 8080
  input_video:
vdpm/dpm/aggregator.py
ADDED
@@ -0,0 +1,366 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE-VGGT file in the root directory of this source tree.

import logging
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from typing import Tuple, List

from vggt.layers import PatchEmbed
from vggt.layers.block import Block
from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2

logger = logging.getLogger(__name__)

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class Aggregator(nn.Module):
    """
    The Aggregator applies alternating-attention over input frames,
    as described in VGGT: Visual Geometry Grounded Transformer.

    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed, e.g. "conv" or "dinov2_vitl14_reg".
        aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
        aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. -1 to disable.
        init_values (float): Init scale for layer scale.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()

        self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

        # Initialize rotary position embedding if frequency > 0
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        self.frame_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.global_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        # Note: We have two camera tokens, one for the first frame and one for the rest
        # The same applies for register tokens
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))

        # The patch tokens start after the camera and register tokens
        self.patch_start_idx = 1 + num_register_tokens

        self.time_conditioning_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.patch_start_idx += 1

        # Initialize parameters with small values
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)

        # Register normalization constants as buffers
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 1, 3, 1, 1),
                persistent=False,
            )

        self.use_reentrant = False  # hardcoded to False

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer. If 'conv', we use a
        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
        """

        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

            # Disable gradient updates for mask token
            if hasattr(self.patch_embed, "mask_token"):
                self.patch_embed.mask_token.requires_grad_(False)

    def forward(
        self,
        images: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width

        Returns:
            (list[torch.Tensor], int):
                The list of outputs from the attention blocks,
                and the patch_start_idx indicating where patch tokens begin.
        """
        B, S, C_in, H, W = images.shape

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images and reshape for patch embed
        images = (images - self._resnet_mean) / self._resnet_std

        # Reshape to [B*S, C, H, W] for patch embedding
        images = images.view(B * S, C_in, H, W)
        patch_tokens = self.patch_embed(images)

        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        _, P, C = patch_tokens.shape

        # Expand camera and register tokens to match batch size and sequence length
        camera_token = slice_expand_and_flatten(self.camera_token, B, S)
        register_token = slice_expand_and_flatten(self.register_token, B, S)
        # do something similar for time_conditioning_token
        time_conditioning_token = slice_expand_and_flatten_single(self.time_conditioning_token, B, S)
        # Concatenate special tokens with patch tokens
        tokens = torch.cat([camera_token, time_conditioning_token, register_token, patch_tokens], dim=1)

        pos = None
        if self.rope is not None:
            pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)

            if self.patch_start_idx > 0:
                # do not use position embedding for special tokens (camera and register tokens)
                # so set pos to 0 for the special tokens
                pos = pos + 1
                pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
                pos = torch.cat([pos_special, pos], dim=1)

        # update P because we added special tokens
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []

        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S, P, C, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            for i in range(len(frame_intermediates)):
                # concat frame and global intermediates, [B x S x P x 2C]
                concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                output_list.append(concat_inter)

        del concat_inter
        del frame_intermediates
        del global_intermediates
        return output_list, self.patch_start_idx

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.global_blocks[global_idx](tokens, pos=pos)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates


def slice_expand_and_flatten(token_tensor, B, S):
    """
    Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
    1) Uses the first position (index=0) for the first frame only
    2) Uses the second position (index=1) for all remaining frames (S-1 frames)
    3) Expands both to match batch size B
    4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
       followed by (S-1) second-position tokens
    5) Flattens to (B*S, X, C) for processing

    Returns:
        torch.Tensor: Processed tokens with shape (B*S, X, C)
    """

    # Slice out the "query" tokens => shape (1, 1, ...)
    query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
    # Slice out the "other" tokens => shape (1, S-1, ...)
    others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
    # Concatenate => shape (B, S, ...)
    combined = torch.cat([query, others], dim=1)

    # Finally flatten => shape (B*S, ...)
    combined = combined.view(B * S, *combined.shape[2:])
    return combined


def slice_expand_and_flatten_single(token_tensor, B, S):
    """
    Processes a single shared token with shape (1, 1, C) for multi-frame processing:
    1) Broadcasts the same token to every frame of every sequence
    2) Expands to match batch size B and sequence length S, giving (B, S, C)
    3) Flattens to (B*S, 1, C) for processing

    Returns:
        torch.Tensor: Processed tokens with shape (B*S, 1, C)
    """

    # Broadcast the single token to every frame => shape (B, S, ...)
    token = token_tensor.expand(B, S, *token_tensor.shape[2:])

    # Finally flatten => shape (B*S, 1, ...)
    token = token.view(B * S, 1, *token.shape[2:])
    return token
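
A minimal shape check for `slice_expand_and_flatten`, assuming the repository root and the `vggt` package are importable; it verifies that frame 0 receives the index-0 token and all later frames share the index-1 token:

```python
# Sketch only: requires the repo root (and vggt) on PYTHONPATH.
import torch
from dpm.aggregator import slice_expand_and_flatten

B, S, X, C = 2, 5, 1, 1024
token = torch.randn(1, 2, X, C)             # index 0: first frame, index 1: the rest
out = slice_expand_and_flatten(token, B, S)
assert out.shape == (B * S, X, C)

out = out.view(B, S, X, C)
assert torch.equal(out[:, 0], token[0, 0].expand(B, X, C))   # first-frame token
assert torch.equal(out[:, 1], token[0, 1].expand(B, X, C))   # shared token for frames 1..S-1
```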
vdpm/dpm/decoder.py
ADDED
@@ -0,0 +1,416 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE-VGGT file in the root directory of this source tree.

import logging
import torch
from torch import nn, Tensor
from torch.utils.checkpoint import checkpoint
from typing import List, Callable
from dataclasses import dataclass

from einops import repeat

from vggt.layers.block import drop_add_residual_stochastic_depth
from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter

from vggt.layers.attention import Attention
from vggt.layers.drop_path import DropPath
from vggt.layers.layer_scale import LayerScale
from vggt.layers.mlp import Mlp

logger = logging.getLogger(__name__)


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class ConditionalBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()

        self.norm1 = norm_layer(dim, elementwise_affine=False)
        self.modulation = Modulation(dim, double=False)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, cond=None, is_global=False) -> Tensor:
        B, S = cond.shape[:2]
        C = x.shape[-1]
        if is_global:
            P = x.shape[1] // S
        cond = cond.view(B * S, C)
        mod, _ = self.modulation(cond)

        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            """
            conditional attention following the DiT implementation from Flux
            https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py#L194-L239
            """
            def prepare_for_mod(y):
                """reshape to modulate the patch tokens with the correct conditioning one"""
                return y.view(B, S, P, C).view(B * S, P, C) if is_global else y

            def restore_after_mod(y):
                """reshape back to global sequence"""
                return y.view(B, S, P, C).view(B, S * P, C) if is_global else y

            x = prepare_for_mod(x)
            x = (1 + mod.scale) * self.norm1(x) + mod.shift
            x = restore_after_mod(x)

            x = self.attn(x, pos=pos)

            x = prepare_for_mod(x)
            x = mod.gate * x
            x = restore_after_mod(x)

            return x

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                pos=pos,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x


class Decoder(nn.Module):
    """Attention blocks after the encoder, applied per DPT input feature
    to generate point maps at a given time.
    """

    def __init__(
        self,
        cfg,
        dim_in: int,
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        patch_size=14,
        embed_dim=1024,
        depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        block_fn=ConditionalBlock,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()
        self.cfg = cfg
        self.intermediate_layer_idx = intermediate_layer_idx

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        self.rope = (
            RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        )
        self.position_getter = PositionGetter() if self.rope is not None else None

        self.dim_in = dim_in

        self.old_decoder = False
        if self.old_decoder:
            self.frame_blocks = nn.ModuleList(
                [
                    block_fn(
                        dim=embed_dim*2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(depth)
                ]
            )
            self.global_blocks = nn.ModuleList(
                [
                    block_fn(
                        dim=embed_dim*2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(depth)
                ]
            )
        else:
            depths = [depth]
            self.frame_blocks = nn.ModuleList([
                nn.ModuleList([
                    block_fn(
                        dim=embed_dim*2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(d)
                ])
                for d in depths
            ])

            self.global_blocks = nn.ModuleList([
                nn.ModuleList([
                    block_fn(
                        dim=embed_dim*2,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(d)
                ])
                for d in depths
            ])

        self.use_reentrant = False  # hardcoded to False

    def get_condition_tokens(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        cond_view_idxs: torch.Tensor
    ):
        # Use tokens from the last block for conditioning
        tokens_last = aggregated_tokens_list[-1]  # [B S N_tok D]
        # Extract the camera tokens
        cond_token_idx = 1
        camera_tokens = tokens_last[:, :, [cond_token_idx]]  # [B S 1 D]

        cond_view_idxs = cond_view_idxs.to(camera_tokens.device)
        cond_view_idxs = repeat(
            cond_view_idxs,
            "b s -> b s c d",
            c=camera_tokens.shape[2],
            d=camera_tokens.shape[3],
        )
        cond_tokens = torch.gather(camera_tokens, 1, cond_view_idxs)

        return cond_tokens

    def forward(
        self,
        images: torch.Tensor,
        aggregated_tokens_list: List[torch.Tensor],
        patch_start_idx: int,
        cond_view_idxs: torch.Tensor,
    ):
        B, S, _, H, W = images.shape

        cond_tokens = self.get_condition_tokens(
            aggregated_tokens_list, cond_view_idxs
        )

        input_tokens = []
        for k, layer_idx in enumerate(self.intermediate_layer_idx):
            layer_tokens = aggregated_tokens_list[layer_idx].clone()
            input_tokens.append(layer_tokens)

        _, _, P, C = input_tokens[0].shape

        pos = None
        if self.rope is not None:
            pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device=images.device
            )
            if patch_start_idx > 0:
                # do not use position embedding for special tokens (camera and register tokens)
                # so set pos to 0 for the special tokens
                pos = pos + 1
                pos_special = torch.zeros(B * S, patch_start_idx, 2).to(images.device).to(pos.dtype)
                pos = torch.cat([pos_special, pos], dim=1)

        frame_idx = 0
        global_idx = 0
        depth = len(self.frame_blocks[0])
        N = len(input_tokens)
        # stack all intermediate layer tokens along the batch dimension;
        # they are all processed by the same decoder
        s_tokens = torch.cat(input_tokens)
        s_cond_tokens = torch.cat([cond_tokens] * N, dim=0)
        s_pos = torch.cat([pos] * N, dim=0)

        # perform time conditioned attention
        for _ in range(depth):
            for attn_type in self.aa_order:
                token_idx = 0

                if attn_type == "frame":
                    s_tokens, frame_idx, _ = self._process_frame_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, frame_idx, pos=s_pos, token_idx=token_idx
                    )
                elif attn_type == "global":
                    s_tokens, global_idx, _ = self._process_global_attention(
                        s_tokens, s_cond_tokens, B * N, S, P, C, global_idx, pos=s_pos, token_idx=token_idx
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")
        processed = [t.view(B, S, P, C) for t in s_tokens.split(B, dim=0)]

        return processed

    def _process_frame_attention(self, tokens, cond_tokens, B, S, P, C, frame_idx, pos=None, token_idx=0):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = []
        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                tokens = checkpoint(self.frame_blocks[token_idx][frame_idx], tokens, pos, cond_tokens, use_reentrant=self.use_reentrant)
            else:
                if self.old_decoder:
                    tokens = self.frame_blocks[frame_idx](tokens, pos=pos, cond=cond_tokens)
                else:
                    tokens = self.frame_blocks[0][frame_idx](tokens, pos=pos, cond=cond_tokens)

            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, cond_tokens, B, S, P, C, global_idx, pos=None, token_idx=0):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            if self.training:
                tokens = checkpoint(self.global_blocks[token_idx][global_idx], tokens, pos, cond_tokens, True, use_reentrant=self.use_reentrant)
            else:
                if self.old_decoder:
                    tokens = self.global_blocks[global_idx](tokens, pos=pos, cond=cond_tokens, is_global=True)
                else:
                    tokens = self.global_blocks[0][global_idx](tokens, pos=pos, cond=cond_tokens, is_global=True)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates
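
The shift/scale/gate path in `ConditionalBlock.forward` follows DiT/Flux-style modulation. A self-contained sketch of just that arithmetic in plain torch (names are illustrative, not the repo's API; the attention call itself is omitted):

```python
# Sketch of the adaLN-style modulation: shift/scale applied before attention,
# gate applied to the attention output before the residual add.
import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 64
lin = nn.Linear(dim, 3 * dim)                      # mirrors Modulation(dim, double=False).lin
norm = nn.LayerNorm(dim, elementwise_affine=False)

cond = torch.randn(8, dim)                         # one conditioning vector per frame
shift, scale, gate = lin(F.silu(cond))[:, None, :].chunk(3, dim=-1)

x = torch.randn(8, 100, dim)                       # (B*S, P, C) patch tokens
x_mod = (1 + scale) * norm(x) + shift              # modulated input to attention
attn_out = x_mod                                   # attention omitted in this sketch
residual = gate * attn_out                         # gated residual added back to x
print(residual.shape)                              # torch.Size([8, 100, 64])
```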
vdpm/dpm/model.py
ADDED
@@ -0,0 +1,149 @@
import torch
import torch.nn as nn

from vggt.heads.camera_head import CameraHead
from vggt.heads.dpt_head import DPTHead

from .aggregator import Aggregator
from .decoder import Decoder


def freeze_all_params(modules):
    for module in modules:
        try:
            for n, param in module.named_parameters():
                param.requires_grad = False
        except AttributeError:
            # module is directly a parameter
            module.requires_grad = False


class VDPM(nn.Module):
    def __init__(self, cfg, img_size=518, patch_size=14, embed_dim=1024):
        super().__init__()
        self.cfg = cfg

        self.aggregator = Aggregator(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
        )
        self.decoder = Decoder(
            cfg,
            dim_in=2*embed_dim,
            embed_dim=embed_dim,
            depth=cfg.model.decoder_depth
        )
        self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")

        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        self.set_freeze()

    def set_freeze(self):
        to_be_frozen = [self.aggregator.patch_embed]
        freeze_all_params(to_be_frozen)

    def forward(
        self,
        views, autocast_dpt=None
    ):
        images = torch.stack([view["img"] for view in views], dim=1)
        aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        res_dynamic = dict()

        if self.decoder is not None:
            cond_view_idxs = torch.stack([view["view_idxs"][:, 1] for view in views], dim=1)
            decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)

        if autocast_dpt is None:
            autocast_dpt = torch.amp.autocast("cuda", enabled=False)

        with autocast_dpt:
            pts3d, pts3d_conf = self.point_head(
                aggregated_tokens_list, images, patch_start_idx
            )

            padded_decoded_tokens = [None] * len(aggregated_tokens_list)
            for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
                padded_decoded_tokens[layer_idx] = decoded_tokens[idx]
            pts3d_dyn, pts3d_dyn_conf = self.point_head(
                padded_decoded_tokens, images, patch_start_idx
            )

        res_dynamic |= {
            "pts3d": pts3d_dyn,
            "conf": pts3d_dyn_conf
        }

        pose_enc_list = self.camera_head(aggregated_tokens_list)
        res_dynamic |= {"pose_enc_list": pose_enc_list}

        res_static = dict(
            pts3d=pts3d,
            conf=pts3d_conf
        )
        return res_static, res_dynamic

    def inference(
        self,
        views,
        images=None,
        num_timesteps=None
    ):
        autocast_amp = torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)

        if images is None:
            images = torch.stack([view["img"] for view in views], dim=1)

        with autocast_amp:
            aggregated_tokens_list, patch_start_idx = self.aggregator(images)
        S = images.shape[1]

        # Determine number of timesteps to query
        if num_timesteps is None:
            # Default to S if not specified (legacy behavior),
            # but if views has indices, try to infer the max time
            if views is not None and "view_idxs" in views[0]:
                try:
                    all_idxs = torch.cat([v["view_idxs"][:, 1] for v in views])
                    num_timesteps = int(all_idxs.max().item()) + 1
                except Exception:
                    num_timesteps = S
            else:
                num_timesteps = S

        predictions = dict()
        pointmaps = []
        ones = torch.ones(1, S, dtype=torch.int64)
        for time_ in range(num_timesteps):
            cond_view_idxs = ones * time_

            with autocast_amp:
                decoded_tokens = self.decoder(images, aggregated_tokens_list, patch_start_idx, cond_view_idxs)
            padded_decoded_tokens = [None] * len(aggregated_tokens_list)
            for idx, layer_idx in enumerate(self.point_head.intermediate_layer_idx):
                padded_decoded_tokens[layer_idx] = decoded_tokens[idx]

            pts3d, pts3d_conf = self.point_head(
                padded_decoded_tokens, images, patch_start_idx
            )

            pointmaps.append(dict(
                pts3d=pts3d,
                conf=pts3d_conf
            ))

        pose_enc_list = self.camera_head(aggregated_tokens_list)
        predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration
        predictions["pose_enc_list"] = pose_enc_list
        predictions["pointmaps"] = pointmaps
        return predictions

    def load_state_dict(self, ckpt, is_VGGT_static=False, **kw):
        # don't load these VGGT heads as they are not needed
        exclude = ["depth_head", "track_head"]
        ckpt = {k: v for k, v in ckpt.items() if k.split('.')[0] not in exclude}
        return super().load_state_dict(ckpt, **kw)
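
The overridden `load_state_dict` drops the unused VGGT head weights by their top-level key. A toy illustration of that filtering, independent of any real checkpoint:

```python
# Entries whose top-level module is depth_head or track_head are dropped
# before the state dict is handed to nn.Module.load_state_dict.
ckpt = {
    "aggregator.camera_token": 0,
    "depth_head.proj.weight": 1,   # excluded
    "track_head.mlp.bias": 2,      # excluded
    "point_head.head.weight": 3,
}
exclude = ["depth_head", "track_head"]
kept = {k: v for k, v in ckpt.items() if k.split(".")[0] not in exclude}
print(sorted(kept))  # ['aggregator.camera_token', 'point_head.head.weight']
```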
vdpm/examples/videos/camel.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3db92c240efbd1b97a466565988a9a06687fd422086656dc0a29e12c5b99b9bb
size 1301172
vdpm/examples/videos/car.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd74efdb4d4d59fc17356fefa5dadd4c5b787641c98ce3172ecd8e5a180e76a6
size 1015132
vdpm/examples/videos/figure1.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ae285726e5d247e904bb1ea7887ee96733c0beea913b421abba39150a3299cd5
size 465850
vdpm/examples/videos/figure2.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5b2b030dd564cffbb9b2795e7fcdf97fa50e3a518df5b71dfb3dfb36f431dfa4
size 516209
vdpm/examples/videos/figure3.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a4144a53f14bd2dc671376d26ecbb42b06c9b8810e1700f21a16d3e11dfbf5c
size 559096
vdpm/examples/videos/goldfish.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28912e59d0d9e6b20d26973efee4806e89e115c7f1e63aec7206384ac3d0bf78
size 668862
vdpm/examples/videos/horse.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8227c7d901a936aeab6a2b41f104dd17e5544315d4cde7dac37f5787319947e7
size 1223145