mystorm commited on
Commit
8df2bdc
·
verified ·
1 Parent(s): 05b4b81

Upload 99 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. FastVGGT/.gitignore +160 -0
  3. FastVGGT/.vscode/launch.json +85 -0
  4. FastVGGT/README.md +163 -0
  5. FastVGGT/assets/attn_map.png +3 -0
  6. FastVGGT/assets/autolab_logo.png +3 -0
  7. FastVGGT/assets/maclab_logo.png +0 -0
  8. FastVGGT/assets/main.png +3 -0
  9. FastVGGT/assets/vs.png +3 -0
  10. FastVGGT/eval/__pycache__/base.cpython-310.pyc +0 -0
  11. FastVGGT/eval/__pycache__/criterion.cpython-310.pyc +0 -0
  12. FastVGGT/eval/__pycache__/data.cpython-310.pyc +0 -0
  13. FastVGGT/eval/__pycache__/data.cpython-37.pyc +0 -0
  14. FastVGGT/eval/__pycache__/utils.cpython-310.pyc +0 -0
  15. FastVGGT/eval/__pycache__/utils.cpython-37.pyc +0 -0
  16. FastVGGT/eval/base.py +273 -0
  17. FastVGGT/eval/criterion.py +534 -0
  18. FastVGGT/eval/data.py +338 -0
  19. FastVGGT/eval/dataset_utils/__init__.py +1 -0
  20. FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc +0 -0
  21. FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc +0 -0
  22. FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc +0 -0
  23. FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc +0 -0
  24. FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc +0 -0
  25. FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc +0 -0
  26. FastVGGT/eval/dataset_utils/corr.py +234 -0
  27. FastVGGT/eval/dataset_utils/cropping.py +140 -0
  28. FastVGGT/eval/dataset_utils/transforms.py +78 -0
  29. FastVGGT/eval/eval_7andN.py +497 -0
  30. FastVGGT/eval/eval_custom.py +467 -0
  31. FastVGGT/eval/eval_scannet.py +208 -0
  32. FastVGGT/eval/utils.py +142 -0
  33. FastVGGT/merging/__init__.py +3 -0
  34. FastVGGT/merging/__pycache__/__init__.cpython-310.pyc +0 -0
  35. FastVGGT/merging/__pycache__/merge.cpython-310.pyc +0 -0
  36. FastVGGT/merging/merge.py +370 -0
  37. FastVGGT/requirements.txt +15 -0
  38. FastVGGT/vggt/__init__.py +5 -0
  39. FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc +0 -0
  40. FastVGGT/vggt/dependency/__init__.py +5 -0
  41. FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc +0 -0
  42. FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc +0 -0
  43. FastVGGT/vggt/dependency/distortion.py +54 -0
  44. FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc +0 -0
  45. FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc +0 -0
  46. FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc +0 -0
  47. FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc +0 -0
  48. FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc +0 -0
  49. FastVGGT/vggt/heads/camera_head.py +149 -0
  50. FastVGGT/vggt/heads/dpt_head.py +598 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ FastVGGT/assets/attn_map.png filter=lfs diff=lfs merge=lfs -text
37
+ FastVGGT/assets/autolab_logo.png filter=lfs diff=lfs merge=lfs -text
38
+ FastVGGT/assets/main.png filter=lfs diff=lfs merge=lfs -text
39
+ FastVGGT/assets/vs.png filter=lfs diff=lfs merge=lfs -text
FastVGGT/.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .hydra/
2
+ output/
3
+ ckpt/
4
+ .vscode/
5
+ dependency/
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ **/__pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+ test_logs/
12
+ quick_start_logs/
13
+ logs/
14
+ *.pth
15
+ /data/
16
+ *.png
17
+ eval_results/
18
+ .vscode/
19
+ .curosr/
20
+
21
+ # C extensions
22
+ *.so
23
+ LightGlue/
24
+ # Distribution / packaging
25
+ .Python
26
+ build/
27
+ develop-eggs/
28
+ dist/
29
+ downloads/
30
+ eggs/
31
+ .eggs/
32
+ lib/
33
+ lib64/
34
+ parts/
35
+ sdist/
36
+ var/
37
+ wheels/
38
+ pip-wheel-metadata/
39
+ share/python-wheels/
40
+ *.egg-info/
41
+ .installed.cfg
42
+ *.egg
43
+ MANIFEST
44
+
45
+ # PyInstaller
46
+ # Usually these files are written by a python script from a template
47
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
48
+ *.manifest
49
+ *.spec
50
+
51
+ # Installer logs
52
+ pip-log.txt
53
+ pip-delete-this-directory.txt
54
+
55
+ # Unit test / coverage reports
56
+ htmlcov/
57
+ .tox/
58
+ .nox/
59
+ .coverage
60
+ .coverage.*
61
+ .cache
62
+ nosetests.xml
63
+ coverage.xml
64
+ *.cover
65
+ *.py,cover
66
+ .hypothesis/
67
+ .pytest_cache/
68
+ cover/
69
+
70
+ # Translations
71
+ *.mo
72
+ *.pot
73
+
74
+ # Django stuff:
75
+ *.log
76
+ local_settings.py
77
+ db.sqlite3
78
+ db.sqlite3-journal
79
+
80
+ # Flask stuff:
81
+ instance/
82
+ .webassets-cache
83
+
84
+ # Scrapy stuff:
85
+ .scrapy
86
+
87
+ # Sphinx documentation
88
+ docs/_build/
89
+
90
+ # PyBuilder
91
+ target/
92
+
93
+ # Jupyter Notebook
94
+ .ipynb_checkpoints
95
+
96
+ # IPython
97
+ profile_default/
98
+ ipython_config.py
99
+
100
+ # pyenv
101
+ .python-version
102
+
103
+ # pipenv
104
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
105
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
106
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
107
+ # install all needed dependencies.
108
+ #Pipfile.lock
109
+
110
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
111
+ __pypackages__/
112
+
113
+ # Celery stuff
114
+ celerybeat-schedule
115
+ celerybeat.pid
116
+
117
+ # SageMath parsed files
118
+ *.sage.py
119
+
120
+ # Environments
121
+ .env
122
+ .venv
123
+ env/
124
+ venv/
125
+ ENV/
126
+ env.bak/
127
+ venv.bak/
128
+
129
+ # Spyder project settings
130
+ .spyderproject
131
+ .spyproject
132
+
133
+ # Rope project settings
134
+ .ropeproject
135
+
136
+ # mkdocs documentation
137
+ /site
138
+
139
+ # mypy
140
+ .mypy_cache/
141
+ .dmypy.json
142
+ dmypy.json
143
+
144
+ # Pyre type checker
145
+ .pyre/
146
+
147
+ # pytype static type analyzer
148
+ .pytype/
149
+
150
+ # Profiling data
151
+ .prof
152
+
153
+ # Folder specific to your needs
154
+ **/tmp/
155
+ **/outputs/skyseg.onnx
156
+ skyseg.onnx
157
+
158
+ # pixi environments
159
+ .pixi
160
+ *.egg-info
FastVGGT/.vscode/launch.json ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+
8
+ {
9
+ "name": "launch",
10
+ "type": "debugpy",
11
+ "request": "launch",
12
+ "program": "/home/sy/code/vggt_0625/training/launch.py",
13
+ "console": "integratedTerminal",
14
+ "args": "${command:pickArgs}",
15
+ "env": {
16
+ "CUDA_VISIBLE_DEVICES": "3",
17
+ },
18
+ "cwd": "/home/sy/code/vggt_0625/training",
19
+ "justMyCode": true,
20
+ "python": "/home/sy/anaconda3/envs/vggt/bin/python"
21
+ }
22
+ ,{
23
+ "name": "train_scannet",
24
+ "type": "debugpy",
25
+ "request": "launch",
26
+ "program": "/home/sy/code/vggt_0625/training/launch_scannet.py",
27
+ "console": "integratedTerminal",
28
+ "args": [
29
+ // "--config_name", "scannet",
30
+ // "--exp_name", "scannet_exp001",
31
+ // "--resume_checkpoint_path", "/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt"
32
+ ],
33
+ "env": {
34
+ "CUDA_VISIBLE_DEVICES": "7",
35
+ "WORLD_SIZE": "1",
36
+ "RANK": "0",
37
+ "MASTER_ADDR": "localhost",
38
+ "MASTER_PORT": "12345"
39
+ },
40
+ "cwd": "/home/sy/code/vggt_0625/training",
41
+ "justMyCode": true,
42
+ "python": "/home/sy/anaconda3/envs/vggt/bin/python"
43
+ }
44
+ ,{
45
+ "name": "eval_scannet",
46
+ "type": "debugpy",
47
+ "request": "launch",
48
+ "program": "/home/sy/code/FastVGGT/eval/eval_scannet.py",
49
+ "console": "integratedTerminal",
50
+ "args": [
51
+ "--data_dir","/data/sy/scannetv2/process_scannet",
52
+ "--gt_ply_dir","/data/sy/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
53
+ "--output_path", "/home/sy/code/FastVGGT/eval_results",
54
+ "--merging", "0",
55
+ "--ckpt_path","/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt",
56
+ "--vis_attn_map"
57
+ ],
58
+ "env": {
59
+ "CUDA_VISIBLE_DEVICES": "2"
60
+ },
61
+ "justMyCode": true,
62
+ "python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
63
+ },
64
+ {
65
+ "name": "eval_cd",
66
+ "type": "debugpy",
67
+ "request": "launch",
68
+ "program": "/home/sy/code/FastVGGT/eval/eval_custom.py",
69
+ "console": "integratedTerminal",
70
+ "args": [
71
+ "--merging", "0",
72
+ // "--kf","10",
73
+ // "--output_dir","/home/sy/code/vggt_0625/eval_results_cd",
74
+ "--data_path","/data/sy/segment-102751/",
75
+ "--vis_attn_map"
76
+ ],
77
+ "env": {
78
+ "CUDA_VISIBLE_DEVICES": "3"
79
+ },
80
+ "justMyCode": true,
81
+ // "python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
82
+ }
83
+
84
+ ]
85
+ }
FastVGGT/README.md ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h2>⚡️ FastVGGT: Training-Free Acceleration of Visual Geometry Transformer</h2>
3
+
4
+ <p align="center">
5
+ <a href="https://arxiv.org/abs/2509.02560"><img src="https://img.shields.io/badge/arXiv-FastVGGT-red?logo=arxiv" alt="Paper PDF"></a>
6
+ <a href="https://mystorm16.github.io/fastvggt/"><img src="https://img.shields.io/badge/Project_Page-FastVGGT-yellow" alt="Project Page"></a>
7
+ </p>
8
+
9
+ <img src="assets/maclab_logo.png" alt="Maclab Logo" width="110" style="margin-right: 40px;">
10
+ <img src="assets/autolab_logo.png" alt="Autolab Logo" width="110">
11
+
12
+
13
+ **[Media Analytics & Computing Laboratory](https://mac.xmu.edu.cn/)**; **[AUTOLAB](https://zhipengzhang.cn/)**
14
+
15
+
16
+ [You Shen](https://mystorm16.github.io/), [Zhipeng Zhang](https://zhipengzhang.cn/), [Yansong Qu](https://quyans.github.io/), [Liujuan Cao](https://mac.xmu.edu.cn/ljcao/)
17
+ </div>
18
+
19
+
20
+ ## 📰 News
21
+ - [Sep 8, 2025] Added custom dataset evaluation.
22
+ - [Sep 3, 2025] Paper release.
23
+ - [Sep 2, 2025] Code release.
24
+
25
+ ## 🔭 Overview
26
+
27
+ FastVGGT observes **strong similarity** in attention maps and leverages it to design a training-free acceleration method for long-sequence 3D reconstruction, **achieving up to 4× faster inference without sacrificing accuracy.**
28
+
29
+ <img src="assets/main.png" alt="Autolab Logo" width="">
30
+
31
+
32
+ ## ⚙️ Environment Setup
33
+ First, create a virtual environment using Conda, clone this repository to your local machine, and install the required dependencies.
34
+
35
+
36
+ ```bash
37
+ conda create -n fastvggt python=3.10
38
+ conda activate fastvggt
39
+ git clone git@github.com:mystorm16/FastVGGT.git
40
+ cd FastVGGT
41
+ pip install -r requirements.txt
42
+ ```
43
+
44
+ Next, prepare the ScanNet dataset: http://www.scan-net.org/ScanNet/
45
+
46
+ Then, download the VGGT checkpoint (we use the checkpoint link provided in https://github.com/facebookresearch/vggt/tree/evaluation/evaluation):
47
+ ```bash
48
+ wget https://huggingface.co/facebook/VGGT_tracker_fixed/resolve/main/model_tracker_fixed_e20.pt
49
+ ```
50
+
51
+ Finally, configure the dataset path and VGGT checkpoint path. For example:
52
+ ```bash
53
+ parser.add_argument(
54
+ "--data_dir", type=Path, default="/data/scannetv2/process_scannet"
55
+ )
56
+ parser.add_argument(
57
+ "--gt_ply_dir",
58
+ type=Path,
59
+ default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
60
+ )
61
+ parser.add_argument(
62
+ "--ckpt_path",
63
+ type=str,
64
+ default="./ckpt/model_tracker_fixed_e20.pt",
65
+ )
66
+ ```
67
+
68
+
69
+ ## 💎 Observation
70
+
71
+ Note: A large number of input_frames may significantly slow down saving the visualization results. Please try using a smaller number first.
72
+ ```bash
73
+ python eval/eval_scannet.py --input_frame 30 --vis_attn_map --merging 0
74
+ ```
75
+
76
+ We observe that many token-level attention maps are highly similar in each block, motivating our optimization of the Global Attention module.
77
+
78
+ <img src="assets/attn_map.png" alt="Autolab Logo" width="">
79
+
80
+
81
+
82
+ ## 🏀 Evaluation
83
+ ### Custom Dataset
84
+ Please organize the data according to the following directory:
85
+ ```
86
+ <data_path>/
87
+ ├── images/
88
+ │ ├── 000000.jpg
89
+ │ ├── 000001.jpg
90
+ │ └── ...
91
+ ├── pose/ # Optional: Camera poses
92
+ │ ├── 000000.txt
93
+ │ ├── 000001.txt
94
+ │ └── ...
95
+ └── gt_ply/ # Optional: GT point cloud
96
+ └── scene_xxx.ply
97
+ ```
98
+ - Required: `images/`
99
+ - Additionally required when `--enable_evaluation` is enabled: `pose/` and `gt_ply/`
100
+
101
+ Inference only:
102
+
103
+ ```bash
104
+ python eval/eval_custom.py \
105
+ --data_path /path/to/your_dataset \
106
+ --output_path ./eval_results_custom \
107
+ --plot
108
+ ```
109
+
110
+ Inference + Evaluation (requires `pose/` and `gt_ply/`):
111
+
112
+ ```bash
113
+ python eval/eval_custom.py \
114
+ --data_path /path/to/your_dataset \
115
+ --enable_evaluation \
116
+ --output_path ./eval_results_custom \
117
+ --plot
118
+ ```
119
+
120
+ ### ScanNet
121
+ Evaluate FastVGGT on the ScanNet dataset with 1,000 input images. The **--merging** parameter specifies the block index at which the merging strategy is applied:
122
+
123
+ ```bash
124
+ python eval/eval_scannet.py --input_frame 1000 --merging 0
125
+ ```
126
+
127
+ Evaluate Baseline VGGT on the ScanNet dataset with 1,000 input images:
128
+ ```bash
129
+ python eval/eval_scannet.py --input_frame 1000
130
+ ```
131
+ <img src="assets/vs.png" alt="Autolab Logo" width="">
132
+
133
+ ### 7 Scenes & NRGBD
134
+ Evaluate across two datasets, sampling keyframes every 10 frames:
135
+ ```bash
136
+ python eval/eval_7andN.py --kf 10
137
+ ```
138
+
139
+ ## 🍺 Acknowledgements
140
+
141
+ - Thanks to these great repositories: [VGGT](https://github.com/facebookresearch/vggt), [Dust3r](https://github.com/naver/dust3r), [Fast3R](https://github.com/facebookresearch/fast3r), [CUT3R](https://github.com/CUT3R/CUT3R), [MV-DUSt3R+](https://github.com/facebookresearch/mvdust3r), [StreamVGGT](https://github.com/wzzheng/StreamVGGT), [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long), [ToMeSD](https://github.com/dbolya/tomesd) and many other inspiring works in the community.
142
+
143
+ - Special thanks to [Jianyuan Wang](https://jytime.github.io/) for his valuable discussions and suggestions on this work.
144
+
145
+ <!-- ## ✍️ Checklist
146
+
147
+ - [ ] Release the evaluation code on 7 Scenes / NRGBD -->
148
+
149
+
150
+ ## ⚖️ License
151
+ See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available.
152
+
153
+ ## Citation
154
+
155
+ If you find this project helpful, please consider citing the following paper:
156
+ ```
157
+ @article{shen2025fastvggt,
158
+ title={FastVGGT: Training-Free Acceleration of Visual Geometry Transformer},
159
+ author={Shen, You and Zhang, Zhipeng and Qu, Yansong and Cao, Liujuan},
160
+ journal={arXiv preprint arXiv:2509.02560},
161
+ year={2025}
162
+ }
163
+ ```
FastVGGT/assets/attn_map.png ADDED

Git LFS Details

  • SHA256: 8477957f593c203bcf41df91ac3ed0d22329e22250fdab9f8f8674340964242c
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
FastVGGT/assets/autolab_logo.png ADDED

Git LFS Details

  • SHA256: 4fcead3160cbf561c4385cc8b938a17a94652e3d849da6497f053d32d1245596
  • Pointer size: 132 Bytes
  • Size of remote file: 5.13 MB
FastVGGT/assets/maclab_logo.png ADDED
FastVGGT/assets/main.png ADDED

Git LFS Details

  • SHA256: eecacb414647f01dc8a52b4aba5ff2556733f46d1b9129613e3f59aceff69685
  • Pointer size: 131 Bytes
  • Size of remote file: 884 kB
FastVGGT/assets/vs.png ADDED

Git LFS Details

  • SHA256: dac1ea5397b9f985c6fb86fc5f321313cc5a372d0917b7c1c1c7dd2cea6bec5f
  • Pointer size: 131 Bytes
  • Size of remote file: 230 kB
FastVGGT/eval/__pycache__/base.cpython-310.pyc ADDED
Binary file (6.92 kB). View file
 
FastVGGT/eval/__pycache__/criterion.cpython-310.pyc ADDED
Binary file (13.6 kB). View file
 
FastVGGT/eval/__pycache__/data.cpython-310.pyc ADDED
Binary file (7.78 kB). View file
 
FastVGGT/eval/__pycache__/data.cpython-37.pyc ADDED
Binary file (8.03 kB). View file
 
FastVGGT/eval/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.99 kB). View file
 
FastVGGT/eval/__pycache__/utils.cpython-37.pyc ADDED
Binary file (4.32 kB). View file
 
FastVGGT/eval/base.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # base class for implementing datasets
6
+ # --------------------------------------------------------
7
+ import PIL
8
+ import numpy as np
9
+ import torch
10
+
11
+ from dataset_utils.transforms import ImgNorm
12
+ import dataset_utils.cropping as cropping
13
+ from utils import depthmap_to_absolute_camera_coordinates
14
+
15
+
16
+ class BaseStereoViewDataset:
17
+ """Define all basic options.
18
+
19
+ Usage:
20
+ class MyDataset (BaseStereoViewDataset):
21
+ def _get_views(self, idx, rng):
22
+ # overload here
23
+ views = []
24
+ views.append(dict(img=, ...))
25
+ return views
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ *, # only keyword arguments
31
+ split=None,
32
+ resolution=None, # square_size or (width, height) or list of [(width,height), ...]
33
+ transform=ImgNorm,
34
+ aug_crop=False,
35
+ seed=None,
36
+ ):
37
+ self.num_views = 2
38
+ self.split = split
39
+ self._set_resolutions(resolution)
40
+
41
+ self.transform = transform
42
+ if isinstance(transform, str):
43
+ transform = eval(transform)
44
+
45
+ self.aug_crop = aug_crop
46
+ self.seed = seed
47
+
48
+ def __len__(self):
49
+ return len(self.scenes)
50
+
51
+ def get_stats(self):
52
+ return f"{len(self)} pairs"
53
+
54
+ def __repr__(self):
55
+ resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
56
+ return (
57
+ f"""{type(self).__name__}({self.get_stats()},
58
+ {self.split=},
59
+ {self.seed=},
60
+ resolutions={resolutions_str},
61
+ {self.transform=})""".replace(
62
+ "self.", ""
63
+ )
64
+ .replace("\n", "")
65
+ .replace(" ", "")
66
+ )
67
+
68
+ def _get_views(self, idx, resolution, rng):
69
+ raise NotImplementedError()
70
+
71
+ def __getitem__(self, idx):
72
+ if isinstance(idx, tuple):
73
+ # the idx is specifying the aspect-ratio
74
+ idx, ar_idx = idx
75
+ else:
76
+ assert len(self._resolutions) == 1
77
+ ar_idx = 0
78
+
79
+ # set-up the rng
80
+ if self.seed: # reseed for each __getitem__
81
+ self._rng = np.random.default_rng(seed=self.seed + idx)
82
+ elif not hasattr(self, "_rng"):
83
+ seed = torch.initial_seed() # this is different for each dataloader process
84
+ self._rng = np.random.default_rng(seed=seed)
85
+
86
+ # over-loaded code
87
+ resolution = self._resolutions[
88
+ ar_idx
89
+ ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
90
+ views = self._get_views(idx, resolution, self._rng)
91
+
92
+ # check data-types
93
+ for v, view in enumerate(views):
94
+ assert (
95
+ "pts3d" not in view
96
+ ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
97
+ view["idx"] = v
98
+
99
+ # encode the image
100
+ width, height = view["img"].size
101
+ view["true_shape"] = np.int32((height, width))
102
+ view["img"] = self.transform(view["img"])
103
+
104
+ assert "camera_intrinsics" in view
105
+ if "camera_pose" not in view:
106
+ view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
107
+ else:
108
+ assert np.isfinite(
109
+ view["camera_pose"]
110
+ ).all(), f"NaN in camera pose for view {view_name(view)}"
111
+ assert "pts3d" not in view
112
+ assert "valid_mask" not in view
113
+ assert np.isfinite(
114
+ view["depthmap"]
115
+ ).all(), f"NaN in depthmap for view {view_name(view)}"
116
+ pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
117
+
118
+ view["pts3d"] = pts3d
119
+ view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
120
+
121
+ # check all datatypes
122
+ for key, val in view.items():
123
+ res, err_msg = is_good_type(key, val)
124
+ assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
125
+ K = view["camera_intrinsics"]
126
+ view["img_mask"] = True
127
+ view["ray_mask"] = False
128
+ view["ray_map"] = torch.full(
129
+ (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
130
+ )
131
+ view["update"] = True
132
+ view["reset"] = False
133
+
134
+ # last thing done!
135
+ for view in views:
136
+ # transpose to make sure all views are the same size
137
+ transpose_to_landscape(view)
138
+ # this allows to check whether the RNG is is the same state each time
139
+ view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
140
+ return views
141
+
142
+ def _set_resolutions(self, resolutions):
143
+ """Set the resolution(s) of the dataset.
144
+ Params:
145
+ - resolutions: int or tuple or list of tuples
146
+ """
147
+ assert resolutions is not None, "undefined resolution"
148
+
149
+ if not isinstance(resolutions, list):
150
+ resolutions = [resolutions]
151
+
152
+ self._resolutions = []
153
+ for resolution in resolutions:
154
+ if isinstance(resolution, int):
155
+ width = height = resolution
156
+ else:
157
+ width, height = resolution
158
+ assert isinstance(
159
+ width, int
160
+ ), f"Bad type for {width=} {type(width)=}, should be int"
161
+ assert isinstance(
162
+ height, int
163
+ ), f"Bad type for {height=} {type(height)=}, should be int"
164
+ assert width >= height
165
+ self._resolutions.append((width, height))
166
+
167
+ def _crop_resize_if_necessary(
168
+ self, image, depthmap, intrinsics, resolution, rng=None, info=None
169
+ ):
170
+ """This function:
171
+ - first downsizes the image with LANCZOS inteprolation,
172
+ which is better than bilinear interpolation in
173
+ """
174
+ if not isinstance(image, PIL.Image.Image):
175
+ image = PIL.Image.fromarray(image)
176
+
177
+ # downscale with lanczos interpolation so that image.size == resolution
178
+ # cropping centered on the principal point
179
+ W, H = image.size
180
+ cx, cy = intrinsics[:2, 2].round().astype(int)
181
+
182
+ # calculate min distance to margin
183
+ min_margin_x = min(cx, W - cx)
184
+ min_margin_y = min(cy, H - cy)
185
+ assert min_margin_x > W / 5, f"Bad principal point in view={info}"
186
+ assert min_margin_y > H / 5, f"Bad principal point in view={info}"
187
+
188
+ ## Center crop
189
+ # Crop on the principal point, make it always centered
190
+ # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
191
+ l, t = cx - min_margin_x, cy - min_margin_y
192
+ r, b = cx + min_margin_x, cy + min_margin_y
193
+ crop_bbox = (l, t, r, b)
194
+
195
+ image, depthmap, intrinsics = cropping.crop_image_depthmap(
196
+ image, depthmap, intrinsics, crop_bbox
197
+ )
198
+
199
+ # # transpose the resolution if necessary
200
+ W, H = image.size # new size
201
+ assert resolution[0] >= resolution[1]
202
+ if H > 1.1 * W:
203
+ # image is portrait mode
204
+ resolution = resolution[::-1]
205
+ elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
206
+ # image is square, so we chose (portrait, landscape) randomly
207
+ if rng.integers(2):
208
+ resolution = resolution[::-1]
209
+
210
+ # high-quality Lanczos down-scaling
211
+ target_resolution = np.array(resolution)
212
+ # # if self.aug_crop > 1:
213
+ # # target_resolution += rng.integers(0, self.aug_crop)
214
+ # if resolution != (224, 224):
215
+ # halfw, halfh = ((2*(W//2))//16)*8, ((2*(H//2))//16)*8
216
+ # ## Recale with max factor, so one of width or height might be larger than target_resolution
217
+ # image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, (2*halfw, 2*halfh))
218
+ # else:
219
+ image, depthmap, intrinsics = cropping.rescale_image_depthmap(
220
+ image, depthmap, intrinsics, target_resolution
221
+ )
222
+ # actual cropping (if necessary) with bilinear interpolation
223
+ # if resolution == (224, 224):
224
+ intrinsics2 = cropping.camera_matrix_of_crop(
225
+ intrinsics, image.size, resolution, offset_factor=0.5
226
+ )
227
+ crop_bbox = cropping.bbox_from_intrinsics_in_out(
228
+ intrinsics, intrinsics2, resolution
229
+ )
230
+ image, depthmap, intrinsics = cropping.crop_image_depthmap(
231
+ image, depthmap, intrinsics, crop_bbox
232
+ )
233
+ return image, depthmap, intrinsics
234
+
235
+
236
+ def is_good_type(key, v):
237
+ """returns (is_good, err_msg)"""
238
+ if isinstance(v, (str, int, tuple)):
239
+ return True, None
240
+ if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
241
+ return False, f"bad {v.dtype=}"
242
+ return True, None
243
+
244
+
245
+ def view_name(view, batch_index=None):
246
+ def sel(x):
247
+ return x[batch_index] if batch_index not in (None, slice(None)) else x
248
+
249
+ db = sel(view["dataset"])
250
+ label = sel(view["label"])
251
+ instance = sel(view["instance"])
252
+ return f"{db}/{label}/{instance}"
253
+
254
+
255
+ def transpose_to_landscape(view):
256
+ height, width = view["true_shape"]
257
+
258
+ if width < height:
259
+ # rectify portrait to landscape
260
+ assert view["img"].shape == (3, height, width)
261
+ view["img"] = view["img"].swapaxes(1, 2)
262
+
263
+ assert view["valid_mask"].shape == (height, width)
264
+ view["valid_mask"] = view["valid_mask"].swapaxes(0, 1)
265
+
266
+ assert view["depthmap"].shape == (height, width)
267
+ view["depthmap"] = view["depthmap"].swapaxes(0, 1)
268
+
269
+ assert view["pts3d"].shape == (height, width, 3)
270
+ view["pts3d"] = view["pts3d"].swapaxes(0, 1)
271
+
272
+ # transpose x and y pixels
273
+ view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]
FastVGGT/eval/criterion.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from copy import copy, deepcopy
4
+
5
+ from eval.dataset_utils.corr import geotrf, inv
6
+
7
+
8
+ def invalid_to_nans(arr, valid_mask, ndim=999):
9
+ if valid_mask is not None:
10
+ arr = arr.clone()
11
+ arr[~valid_mask] = float("nan")
12
+ if arr.ndim > ndim:
13
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
14
+ return arr
15
+
16
+
17
+ def invalid_to_zeros(arr, valid_mask, ndim=999):
18
+ if valid_mask is not None:
19
+ arr = arr.clone()
20
+ arr[~valid_mask] = 0
21
+ nnz = valid_mask.view(len(valid_mask), -1).sum(1)
22
+ else:
23
+ nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
24
+ if arr.ndim > ndim:
25
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
26
+ return arr, nnz
27
+
28
+
29
+ class BaseCriterion(nn.Module):
30
+ def __init__(self, reduction="mean"):
31
+ super().__init__()
32
+ self.reduction = reduction
33
+
34
+
35
+ class Criterion(nn.Module):
36
+ def __init__(self, criterion=None):
37
+ super().__init__()
38
+ assert isinstance(
39
+ criterion, BaseCriterion
40
+ ), f"{criterion} is not a proper criterion!"
41
+ self.criterion = copy(criterion)
42
+
43
+ def get_name(self):
44
+ return f"{type(self).__name__}({self.criterion})"
45
+
46
+ def with_reduction(self, mode="none"):
47
+ res = loss = deepcopy(self)
48
+ while loss is not None:
49
+ assert isinstance(loss, Criterion)
50
+ loss.criterion.reduction = mode # make it return the loss for each sample
51
+ loss = loss._loss2 # we assume loss is a Multiloss
52
+ return res
53
+
54
+
55
+ class MultiLoss(nn.Module):
56
+ """Easily combinable losses (also keep track of individual loss values):
57
+ loss = MyLoss1() + 0.1*MyLoss2()
58
+ Usage:
59
+ Inherit from this class and override get_name() and compute_loss()
60
+ """
61
+
62
+ def __init__(self):
63
+ super().__init__()
64
+ self._alpha = 1
65
+ self._loss2 = None
66
+
67
+ def compute_loss(self, *args, **kwargs):
68
+ raise NotImplementedError()
69
+
70
+ def get_name(self):
71
+ raise NotImplementedError()
72
+
73
+ def __mul__(self, alpha):
74
+ assert isinstance(alpha, (int, float))
75
+ res = copy(self)
76
+ res._alpha = alpha
77
+ return res
78
+
79
+ __rmul__ = __mul__ # same
80
+
81
+ def __add__(self, loss2):
82
+ assert isinstance(loss2, MultiLoss)
83
+ res = cur = copy(self)
84
+
85
+ while cur._loss2 is not None:
86
+ cur = cur._loss2
87
+ cur._loss2 = loss2
88
+ return res
89
+
90
+ def __repr__(self):
91
+ name = self.get_name()
92
+ if self._alpha != 1:
93
+ name = f"{self._alpha:g}*{name}"
94
+ if self._loss2:
95
+ name = f"{name} + {self._loss2}"
96
+ return name
97
+
98
+ def forward(self, *args, **kwargs):
99
+ loss = self.compute_loss(*args, **kwargs)
100
+ if isinstance(loss, tuple):
101
+ loss, details = loss
102
+ elif loss.ndim == 0:
103
+ details = {self.get_name(): float(loss)}
104
+ else:
105
+ details = {}
106
+ loss = loss * self._alpha
107
+
108
+ if self._loss2:
109
+ loss2, details2 = self._loss2(*args, **kwargs)
110
+ loss = loss + loss2
111
+ details |= details2
112
+
113
+ return loss, details
114
+
115
+
116
class LLoss(BaseCriterion):
    """L-norm loss"""

    def forward(self, a, b):
        # inputs must be matching point maps with 1..3 channels
        shapes_ok = a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
        assert shapes_ok, f"Bad shape = {a.shape}"

        dist = self.distance(a, b)

        mode = self.reduction
        if mode == "none":
            return dist
        elif mode == "sum":
            return dist.sum()
        elif mode == "mean":
            # empty input reduces to a zero scalar instead of NaN
            if dist.numel() > 0:
                return dist.mean()
            return dist.new_zeros(())
        raise ValueError(f"bad {self.reduction=} mode")

    def distance(self, a, b):
        # to be provided by subclasses (per-point distance map)
        raise NotImplementedError()
137
class L21Loss(LLoss):
    """Per-point Euclidean (L2) distance between 3-D point maps."""

    def distance(self, a, b):
        diff = a - b
        return diff.norm(dim=-1)
143
+
144
# Module-level singleton used as the default point-regression criterion.
L21 = L21Loss()
147
def get_pred_pts3d(gt, pred, use_pose=False):
    """Return the predicted 3-D points already expressed in the other view's frame."""
    assert use_pose is True  # only the pose-aligned branch is supported here
    pts_in_other = pred["pts3d_in_other_view"]
    return pts_in_other
152
def Sum(losses, masks, conf=None):
    """Either pass through per-pixel losses or reduce scalar losses to one value."""
    first = losses[0]
    if first.ndim == 0:
        # scalar losses: accumulate everything into a single global loss
        total = first
        for extra in losses[1:]:
            total = total + extra
        return total

    # per-pixel losses: defer the reduction to the caller
    if conf is None:
        return losses, masks
    return losses, masks, conf
166
def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
    """Compute a per-batch normalization factor for a list of point maps.

    pts: list of (B, H, W, 3) tensors; valids: matching validity masks.
    fix_first: if True, only the first point map contributes to the factor.
    """
    assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3
    assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split("_")  # e.g. "avg_dis" -> ("avg", "dis")

    nan_pts = []
    nnzs = []

    if norm_mode == "avg":
        # gather all points together (joint normalization)
        for i, pt in enumerate(pts):
            # zero-out invalid points so they don't contribute to the sum
            nan_pt, nnz = invalid_to_zeros(pt, valids[i], ndim=3)
            nan_pts.append(nan_pt)
            nnzs.append(nnz)

            if fix_first:
                break  # normalization is anchored on the first view only
        all_pts = torch.cat(nan_pts, dim=1)

        # compute distance to origin
        all_dis = all_pts.norm(dim=-1)
        if dis_mode == "dis":
            pass  # do nothing
        elif dis_mode == "log1p":
            all_dis = torch.log1p(all_dis)
        else:
            raise ValueError(f"bad {dis_mode=}")

        # average distance over valid points (eps avoids divide-by-zero)
        # NOTE(review): all_dis.sum(dim=1) is per-sample while the nnz
        # denominator is summed over the whole batch — confirm this global
        # denominator is intended rather than a per-sample count.
        norm_factor = all_dis.sum(dim=1) / (torch.cat(nnzs).sum() + 1e-8)
    else:
        raise ValueError(f"Not implemented {norm_mode=}")

    norm_factor = norm_factor.clip(min=1e-8)
    # broadcast the (B,) factor up to the point maps' rank (in-place unsqueeze)
    while norm_factor.ndim < pts[0].ndim:
        norm_factor.unsqueeze_(-1)

    return norm_factor
205
+
206
def normalize_pointcloud_t(
    pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False
):
    """Divide every point map in `pts` by a shared normalization factor.

    Returns (normalized point maps, factor used). The `gt` flag is kept for
    interface compatibility: both branches of the original perform exactly
    the same computation, so they are merged here.
    """
    norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
    res = [pt / norm_factor for pt in pts]
    return res, norm_factor
232
@torch.no_grad()
def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
    """Joint robust depth statistic (median or quantile) across all depth maps."""
    flat = []
    for idx, z in enumerate(zs):
        mask = None if valid_masks is None else valid_masks[idx]
        # invalid entries become NaN so nan-aware reductions skip them
        flat.append(invalid_to_nans(z, mask).reshape(len(z), -1))

    stacked = torch.cat(flat, dim=-1)

    # robust per-sample depth shift, ignoring NaNs
    if quantile == 0.5:
        return torch.nanmedian(stacked, dim=-1).values  # (B,)
    return torch.nanquantile(stacked, quantile, dim=-1)  # (B,)
250
+
251
@torch.no_grad()
def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
    """Median center and median radius of all point maps taken jointly."""
    gathered = []
    for idx, p in enumerate(pts):
        mask = None if valid_masks is None else valid_masks[idx]
        # invalid points -> NaN so nan-aware medians ignore them
        gathered.append(invalid_to_nans(p, mask).reshape(len(p), -1, 3))

    cloud = torch.cat(gathered, dim=1)

    # robust (median) center of the joint cloud
    _center = torch.nanmedian(cloud, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        _center[..., :2] = 0  # do not center X and Y

    # robust scale = median distance to the center (or to the origin)
    offsets = (cloud - _center) if center else cloud
    scale = torch.nanmedian(offsets.norm(dim=-1), dim=1).values
    return _center[:, None, :, :], scale[:, None, None, None]
273
+
274
class Regr3D_t(Criterion, MultiLoss):
    """Multi-view 3-D point regression loss.

    GT and predicted point maps are expressed in the first camera's frame and
    (optionally) jointly scale-normalized before being compared per frame.
    """

    def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True):
        super().__init__(criterion)
        self.norm_mode = norm_mode  # how to normalize point clouds (e.g. "avg_dis")
        self.gt_scale = gt_scale    # if True, keep GT at its native scale
        self.fix_first = fix_first  # anchor normalization on the first view only

    def get_all_pts3d_t(self, gts, preds, dist_clip=None):
        """Gather GT/pred point maps (in view-1's frame), validity masks and
        the normalization factors applied on each side."""
        # everything is normalized w.r.t. camera of view1
        in_camera1 = inv(gts[0]["camera_pose"])

        gt_pts = []
        valids = []
        pr_pts = []

        for i, gt in enumerate(gts):
            # in_camera1: Bs, 4, 4   gt['pts3d']: Bs, H, W, 3
            gt_pts.append(geotrf(in_camera1, gt["pts3d"]))
            valid = gt["valid_mask"].clone()

            if dist_clip is not None:
                # points that are too far-away == invalid
                dis = gt["pts3d"].norm(dim=-1)
                valid = valid & (dis <= dist_clip)

            valids.append(valid)
            pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True))

        if self.norm_mode:
            pr_pts, pr_factor = normalize_pointcloud_t(
                pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False
            )
        else:
            pr_factor = None

        if self.norm_mode and not self.gt_scale:
            gt_pts, gt_factor = normalize_pointcloud_t(
                gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True
            )
        else:
            gt_factor = None

        return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}

    def compute_frame_loss(self, gts, preds, **kw):
        """Per-frame regression loss, plus a penalty when the predicted scale
        factor exceeds the GT one. Returns (Sum(...), details, factor_loss)."""
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            self.get_all_pts3d_t(gts, preds, **kw)
        )

        # NOTE(review): this unpack expects pred_pts to be a (left, right)
        # pair, but get_all_pts3d_t above returns a single flat list —
        # confirm which caller supplies the paired structure.
        pred_pts_l, pred_pts_r = pred_pts

        loss_all = []
        mask_all = []
        conf_all = []

        # running scalars used only for logging (reference vs target views)
        loss_left = 0
        loss_right = 0
        pred_conf_l = 0
        pred_conf_r = 0

        for i in range(len(gt_pts)):

            # Left (Reference)
            if i != len(gt_pts) - 1:
                frame_loss = self.criterion(
                    pred_pts_l[i][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i][0]["conf"])

                # To compare target/reference loss
                if i != 0:
                    loss_left += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy().mean()

            # Right (Target)
            if i != 0:
                frame_loss = self.criterion(
                    pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i - 1][1]["conf"])

                # To compare target/reference loss
                if i != len(gt_pts) - 1:
                    loss_right += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_r += preds[i - 1][1]["conf"].cpu().detach().numpy().mean()

        # penalize predicted scale factors that overshoot the GT scale factor
        if pr_factor is not None and gt_factor is not None:
            filter_factor = pr_factor[pr_factor > gt_factor]
        else:
            filter_factor = []

        if len(filter_factor) > 0:
            factor_loss = (filter_factor - gt_factor).abs().mean()
        else:
            factor_loss = 0.0

        self_name = type(self).__name__
        details = {
            self_name + "_pts3d_1": float(loss_all[0].mean()),
            self_name + "_pts3d_2": float(loss_all[1].mean()),
            self_name + "loss_left": float(loss_left),
            self_name + "loss_right": float(loss_right),
            self_name + "conf_left": float(pred_conf_l),
            self_name + "conf_right": float(pred_conf_r),
        }

        return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss
395
+
396
class ConfLoss_t(MultiLoss):
    """Weighted regression by learned confidence.
    Assuming the input pixel_loss is a pixel-level regression loss.

    Principle:
        high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
        low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)

    alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha  # weight of the log-confidence regularizer
        # per-pixel losses are required so each pixel can be conf-weighted
        self.pixel_loss = pixel_loss.with_reduction("none")

    def get_name(self):
        return f"ConfLoss({self.pixel_loss})"

    def get_conf_log(self, x):
        # returns (conf, log(conf)); conf must be strictly positive
        return x, torch.log(x)

    def compute_frame_loss(self, gts, preds, **kw):
        # compute per-pixel loss
        (losses, masks, confs), details, loss_factor = (
            self.pixel_loss.compute_frame_loss(gts, preds, **kw)
        )

        # weight by confidence
        conf_losses = []
        conf_sum = 0
        for i in range(len(losses)):
            conf, log_conf = self.get_conf_log(confs[i][masks[i]])
            conf_sum += conf.mean()
            # loss * conf - alpha * log(conf): learned aleatoric weighting
            conf_loss = losses[i] * conf - self.alpha * log_conf
            # NOTE(review): an empty mask makes conf_loss the plain int 0,
            # which torch.stack below cannot handle — confirm masks always
            # contain at least one valid pixel.
            conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0
            conf_losses.append(conf_loss)

        conf_losses = torch.stack(conf_losses) * 2.0
        conf_loss_mean = conf_losses.mean()

        return (
            conf_loss_mean,
            dict(
                conf_loss_1=float(conf_losses[0]),
                conf_loss2=float(conf_losses[1]),
                conf_mean=conf_sum / len(losses),
                **details,
            ),
            loss_factor,
        )
449
+
450
class Regr3D_t_ShiftInv(Regr3D_t):
    """Same than Regr3D but invariant to depth shift."""

    def get_all_pts3d_t(self, gts, preds):
        # compute unnormalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )

        # extract the depth (z) channel of every view
        gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts]
        pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts]

        # compute median depth
        gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, None]

        # subtract the median depth (in place on the point maps)
        for i in range(len(gt_pts)):
            gt_pts[i][..., 2] -= gt_shift_z

        for i in range(len(pred_pts)):
            pred_pts[i][..., 2] -= pred_shift_z

        monitoring = dict(
            monitoring,
            gt_shift_z=gt_shift_z.mean().detach(),
            pred_shift_z=pred_shift_z.mean().detach(),
        )
        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
484
+
485
class Regr3D_t_ScaleInv(Regr3D_t):
    """Same than Regr3D but invariant to depth shift.
    if gt_scale == True: enforce the prediction to take the same scale than GT
    """

    def get_all_pts3d_t(self, gts, preds):
        # compute depth-normalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )

        # measure scene scale on a copy so in-place ops below cannot bias it
        pred_pts_all = [x.clone() for x in pred_pts]

        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks)

        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        if self.gt_scale:
            # bring predictions to the GT scale
            for i in range(len(pred_pts)):
                pred_pts[i] *= gt_scale / pred_scale

        else:
            # NOTE(review): this branch scales predictions by pred/gt and GT
            # by gt/pred (opposite directions), which diverges from the usual
            # DUSt3R formulation (GT only, by pred_scale/gt_scale). The diff
            # view also lost the indentation of the GT loop below — confirm
            # the intended structure before relying on this loss.
            for i in range(len(pred_pts)):
                pred_pts[i] *= pred_scale / gt_scale

            for i in range(len(gt_pts)):
                gt_pts[i] *= gt_scale / pred_scale

        monitoring = dict(
            monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()
        )

        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
531
+
532
class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
    # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv
    # (via the MRO: ScaleInv.get_all_pts3d_t's super() resolves to ShiftInv,
    # so the shift is removed before the scale is normalized)
    pass
FastVGGT/eval/data.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import os.path as osp
5
+ from collections import deque
6
+ from base import BaseStereoViewDataset
7
+ import dataset_utils.cropping as cropping
8
+ from vggt.utils.eval_utils import imread_cv2, shuffle_deque
9
+
10
+
11
class SevenScenes(BaseStereoViewDataset):
    """7-Scenes RGB-D evaluation dataset.

    Loads color/depth/pose triplets either from pre-defined evaluation
    tuples (`tuple_list`) or by walking full sequences under `ROOT`.
    """

    def __init__(
        self,
        num_seq=1,            # sequences sampled per scene when not using tuples
        num_frames=5,
        min_thresh=10,
        max_thresh=100,
        test_id=None,         # restrict to a single scene name
        full_video=False,
        tuple_list=None,      # pre-defined "scene idx idx ..." evaluation tuples
        seq_id=None,          # restrict to a single sequence id (e.g. "seq-01")
        rebuttal=False,
        shuffle_seed=-1,      # >= 0 enables deterministic frame shuffling
        kf_every=1,           # keyframe subsampling stride
        *args,
        ROOT,
        **kwargs,
    ):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every
        self.seq_id = seq_id
        self.rebuttal = rebuttal
        self.shuffle_seed = shuffle_seed

        # load all scenes
        self.load_all_tuples(tuple_list)
        self.load_all_scenes(ROOT)

    def __len__(self):
        if self.tuple_list is not None:
            return len(self.tuple_list)
        return len(self.scene_list) * self.num_seq

    def load_all_tuples(self, tuple_list):
        # Keep the evaluation tuples as-is, or None for full-sequence mode.
        if tuple_list is not None:
            self.tuple_list = tuple_list
        else:
            self.tuple_list = None

    def load_all_scenes(self, base_dir):
        """Build self.scene_list, either fixed (tuple mode) or from disk."""
        if self.tuple_list is not None:
            # Use pre-defined simplerecon scene_ids
            self.scene_list = [
                "stairs/seq-06",
                "stairs/seq-02",
                "pumpkin/seq-06",
                "chess/seq-01",
                "heads/seq-02",
                "fire/seq-02",
                "office/seq-03",
                "pumpkin/seq-03",
                "redkitchen/seq-07",
                "chess/seq-02",
                "office/seq-01",
                "redkitchen/seq-01",
                "fire/seq-01",
            ]
            print(f"Found {len(self.scene_list)} sequences in split {self.split}")
            return

        scenes = os.listdir(base_dir)

        file_split = {"train": "TrainSplit.txt", "test": "TestSplit.txt"}[self.split]

        self.scene_list = []
        for scene in scenes:
            if self.test_id is not None and scene != self.test_id:
                continue
            # read file split
            with open(osp.join(base_dir, scene, file_split)) as f:
                seq_ids = f.read().splitlines()

            for seq_id in seq_ids:
                # e.g. "sequence3" -> "seq-03": keep digits, zero-pad to 2 chars
                num_part = "".join(filter(str.isdigit, seq_id))
                seq_id = f"seq-{num_part.zfill(2)}"
                if self.seq_id is not None and seq_id != self.seq_id:
                    continue
                self.scene_list.append(f"{scene}/{seq_id}")

        print(f"Found {len(self.scene_list)} sequences in split {self.split}")

    def _get_views(self, idx, resolution, rng):
        """Load all views (image, depth, pose, intrinsics) for item `idx`."""
        if self.tuple_list is not None:
            line = self.tuple_list[idx].split(" ")
            scene_id = line[0]
            img_idxs = line[1:]

        else:
            scene_id = self.scene_list[idx // self.num_seq]
            seq_id = idx % self.num_seq  # NOTE(review): unused below — confirm

            data_path = osp.join(self.ROOT, scene_id)
            num_files = len([name for name in os.listdir(data_path) if "color" in name])
            img_idxs = [f"{i:06d}" for i in range(num_files)]
            img_idxs = img_idxs[:: self.kf_every]

        # Intrinsics used in SimpleRecon
        fx, fy, cx, cy = 525, 525, 320, 240
        intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)

        views = []
        imgs_idxs = deque(img_idxs)
        if self.shuffle_seed >= 0:
            imgs_idxs = shuffle_deque(imgs_idxs)

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()
            impath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.color.png")
            depthpath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.depth.proj.png")
            posepath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.pose.txt")

            rgb_image = imread_cv2(impath)

            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            # match RGB resolution to the depth map
            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))

            # 65535 marks missing depth in the export; convert mm -> m
            depthmap[depthmap == 65535] = 0
            depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0

            # discard implausible depths (> 10 m or ~0)
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0

            camera_pose = np.loadtxt(posepath).astype(np.float32)

            if resolution != (224, 224) or self.rebuttal:
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
                )
            else:
                # 224x224 path: resize to 512x384 first, then center-crop 224x224
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="7scenes",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )
        return views
+
180
+
181
class NRGBD(BaseStereoViewDataset):
    """Neural RGB-D evaluation dataset.

    Each scene folder contains images/imgN.png, depth/depthN.png and one
    poses.txt with stacked 4x4 camera-to-world matrices (OpenGL convention).
    """

    def __init__(
        self,
        num_seq=1,
        num_frames=5,
        min_thresh=10,
        max_thresh=100,
        test_id=None,         # restrict to a single scene name
        full_video=False,
        tuple_list=None,      # pre-defined "scene idx idx ..." evaluation tuples
        seq_id=None,
        rebuttal=False,
        shuffle_seed=-1,      # >= 0 enables deterministic frame shuffling
        kf_every=1,           # keyframe subsampling stride
        *args,
        ROOT,
        **kwargs,
    ):

        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every
        self.seq_id = seq_id
        self.rebuttal = rebuttal
        self.shuffle_seed = shuffle_seed

        # load all scenes
        self.load_all_tuples(tuple_list)
        self.load_all_scenes(ROOT)

    def __len__(self):
        if self.tuple_list is not None:
            return len(self.tuple_list)
        return len(self.scene_list) * self.num_seq

    def load_all_tuples(self, tuple_list):
        # Keep the evaluation tuples as-is, or None for full-sequence mode.
        if tuple_list is not None:
            self.tuple_list = tuple_list
        else:
            self.tuple_list = None

    def load_all_scenes(self, base_dir):
        # every directory under ROOT is a scene
        scenes = [
            d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
        ]

        if self.test_id is not None:
            self.scene_list = [self.test_id]
        else:
            self.scene_list = scenes

        print(f"Found {len(self.scene_list)} sequences in split {self.split}")

    def load_poses(self, path):
        """Parse poses.txt: stacked 4x4 matrices, one per frame.

        Frames whose first line contains 'nan' get an identity pose and are
        flagged invalid. Returns (poses (N,4,4) float32, list of bool).
        """
        file = open(path, "r")
        lines = file.readlines()
        file.close()
        poses = []
        valid = []
        lines_per_matrix = 4
        for i in range(0, len(lines), lines_per_matrix):
            if "nan" in lines[i]:
                valid.append(False)
                poses.append(np.eye(4, 4, dtype=np.float32).tolist())
            else:
                valid.append(True)
                pose_floats = [
                    [float(x) for x in line.split()]
                    for line in lines[i : i + lines_per_matrix]
                ]
                poses.append(pose_floats)

        return np.array(poses, dtype=np.float32), valid

    def _get_views(self, idx, resolution, rng):
        """Load all views (image, depth, pose, intrinsics) for item `idx`."""
        if self.tuple_list is not None:
            line = self.tuple_list[idx].split(" ")
            scene_id = line[0]
            img_idxs = line[1:]

        else:
            scene_id = self.scene_list[idx // self.num_seq]

            num_files = len(os.listdir(os.path.join(self.ROOT, scene_id, "images")))
            img_idxs = [f"{i}" for i in range(num_files)]
            # NOTE(review): slice step becomes 0 (ValueError) when a sequence
            # has fewer than 2 frames — confirm inputs always have >= 2 images.
            img_idxs = img_idxs[:: min(self.kf_every, len(img_idxs) // 2)]

        fx, fy, cx, cy = 554.2562584220408, 554.2562584220408, 320, 240
        intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)

        posepath = osp.join(self.ROOT, scene_id, f"poses.txt")
        camera_poses, valids = self.load_poses(posepath)

        imgs_idxs = deque(img_idxs)
        if self.shuffle_seed >= 0:
            imgs_idxs = shuffle_deque(imgs_idxs)
        views = []

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()

            impath = osp.join(self.ROOT, scene_id, "images", f"img{im_idx}.png")
            depthpath = osp.join(self.ROOT, scene_id, "depth", f"depth{im_idx}.png")

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            # depth stored in millimetres; discard implausible values
            depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0

            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))

            camera_pose = camera_poses[int(im_idx)]
            # gl to cv (flip the Y and Z columns, in place on the loaded poses)
            camera_pose[:, 1:3] *= -1.0
            if resolution != (224, 224) or self.rebuttal:
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
                )
            else:
                # 224x224 path: resize to 512x384 first, then center-crop 224x224
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="nrgbd",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )

        return views
FastVGGT/eval/dataset_utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (146 Bytes). View file
 
FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (140 Bytes). View file
 
FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc ADDED
Binary file (5.85 kB). View file
 
FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc ADDED
Binary file (4.29 kB). View file
 
FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc ADDED
Binary file (4.29 kB). View file
 
FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (2.18 kB). View file
 
FastVGGT/eval/dataset_utils/corr.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+
10
+ def todevice(batch, device, callback=None, non_blocking=False):
11
+ """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
12
+
13
+ batch: list, tuple, dict of tensors or other things
14
+ device: pytorch device or 'numpy'
15
+ callback: function that would be called on every sub-elements.
16
+ """
17
+ if callback:
18
+ batch = callback(batch)
19
+
20
+ if isinstance(batch, dict):
21
+ return {k: todevice(v, device) for k, v in batch.items()}
22
+
23
+ if isinstance(batch, (tuple, list)):
24
+ return type(batch)(todevice(x, device) for x in batch)
25
+
26
+ x = batch
27
+ if device == "numpy":
28
+ if isinstance(x, torch.Tensor):
29
+ x = x.detach().cpu().numpy()
30
+ elif x is not None:
31
+ if isinstance(x, np.ndarray):
32
+ x = torch.from_numpy(x)
33
+ if torch.is_tensor(x):
34
+ x = x.to(device, non_blocking=non_blocking)
35
+ return x
36
+
37
+
38
# Alias kept for backward compatibility with callers using the underscored name.
to_device = todevice  # alias


def to_numpy(x):
    """Recursively convert every tensor in `x` to a numpy array (see todevice)."""
    return todevice(x, "numpy")
44
+
45
+ def geotrf(Trf, pts, ncol=None, norm=False):
46
+ """Apply a geometric transformation to a list of 3-D points.
47
+
48
+ H: 3x3 or 4x4 projection matrix (typically a Homography)
49
+ p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
50
+
51
+ ncol: int. number of columns of the result (2 or 3)
52
+ norm: float. if != 0, the resut is projected on the z=norm plane.
53
+
54
+ Returns an array of projected 2d points.
55
+ """
56
+ assert Trf.ndim >= 2
57
+ if isinstance(Trf, np.ndarray):
58
+ pts = np.asarray(pts)
59
+ elif isinstance(Trf, torch.Tensor):
60
+ pts = torch.as_tensor(pts, dtype=Trf.dtype)
61
+
62
+ output_reshape = pts.shape[:-1]
63
+ ncol = ncol or pts.shape[-1]
64
+
65
+ if (
66
+ isinstance(Trf, torch.Tensor)
67
+ and isinstance(pts, torch.Tensor)
68
+ and Trf.ndim == 3
69
+ and pts.ndim == 4
70
+ ):
71
+ d = pts.shape[3]
72
+ if Trf.shape[-1] == d:
73
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
74
+ elif Trf.shape[-1] == d + 1:
75
+ pts = (
76
+ torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
77
+ + Trf[:, None, None, :d, d]
78
+ )
79
+ else:
80
+ raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
81
+ else:
82
+ if Trf.ndim >= 3:
83
+ n = Trf.ndim - 2
84
+ assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
85
+ Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
86
+
87
+ if pts.ndim > Trf.ndim:
88
+
89
+ pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
90
+ elif pts.ndim == 2:
91
+
92
+ pts = pts[:, None, :]
93
+
94
+ if pts.shape[-1] + 1 == Trf.shape[-1]:
95
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
96
+ pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
97
+ elif pts.shape[-1] == Trf.shape[-1]:
98
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
99
+ pts = pts @ Trf
100
+ else:
101
+ pts = Trf @ pts.T
102
+ if pts.ndim >= 2:
103
+ pts = pts.swapaxes(-1, -2)
104
+
105
+ if norm:
106
+ pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
107
+ if norm != 1:
108
+ pts *= norm
109
+
110
+ res = pts[..., :ncol].reshape(*output_reshape, ncol)
111
+ return res
112
+
113
+
114
+ def inv(mat):
115
+ """Invert a torch or numpy matrix"""
116
+ if isinstance(mat, torch.Tensor):
117
+ return torch.linalg.inv(mat)
118
+ if isinstance(mat, np.ndarray):
119
+ return np.linalg.inv(mat)
120
+ raise ValueError(f"bad matrix type = {type(mat)}")
121
+
122
+
123
+ def reproject_view(pts3d, view2):
124
+ shape = view2["pts3d"].shape[:2]
125
+ return reproject(
126
+ pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape
127
+ )
128
+
129
+
130
+ def reproject(pts3d, K, world2cam, shape):
131
+ H, W, THREE = pts3d.shape
132
+ assert THREE == 3
133
+
134
+ with np.errstate(divide="ignore", invalid="ignore"):
135
+ pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)
136
+
137
+ return (H, W), ravel_xy(pos, shape)
138
+
139
+
140
+ def ravel_xy(pos, shape):
141
+ H, W = shape
142
+ with np.errstate(invalid="ignore"):
143
+ qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
144
+ quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(
145
+ min=0, max=H - 1, out=qy
146
+ )
147
+ return quantized_pos
148
+
149
+
150
+ def unravel_xy(pos, shape):
151
+
152
+ return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
153
+
154
+
155
+ def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
156
+ is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2))
157
+ pos1 = is_reciprocal1.nonzero()[0]
158
+ pos2 = corres_1_to_2[pos1]
159
+ if ret_recip:
160
+ return is_reciprocal1, pos1, pos2
161
+ return pos1, pos2
162
+
163
+
164
+ def extract_correspondences_from_pts3d(
165
+ view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
166
+ ):
167
+ view1, view2 = to_numpy((view1, view2))
168
+
169
+ shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
170
+ shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)
171
+
172
+ is_reciprocal1, pos1, pos2 = reciprocal_1d(
173
+ corres1_to_2, corres2_to_1, ret_recip=True
174
+ )
175
+ is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))
176
+
177
+ if target_n_corres is None:
178
+ if ret_xy:
179
+ pos1 = unravel_xy(pos1, shape1)
180
+ pos2 = unravel_xy(pos2, shape2)
181
+ return pos1, pos2
182
+
183
+ available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
184
+ target_n_positives = int(target_n_corres * (1 - nneg))
185
+ n_positives = min(len(pos1), target_n_positives)
186
+ n_negatives = min(target_n_corres - n_positives, available_negatives)
187
+
188
+ if n_negatives + n_positives != target_n_corres:
189
+
190
+ n_positives = target_n_corres - n_negatives
191
+ assert n_positives <= len(pos1)
192
+
193
+ assert n_positives <= len(pos1)
194
+ assert n_positives <= len(pos2)
195
+ assert n_negatives <= (~is_reciprocal1).sum()
196
+ assert n_negatives <= (~is_reciprocal2).sum()
197
+ assert n_positives + n_negatives == target_n_corres
198
+
199
+ valid = np.ones(n_positives, dtype=bool)
200
+ if n_positives < len(pos1):
201
+
202
+ perm = rng.permutation(len(pos1))[:n_positives]
203
+ pos1 = pos1[perm]
204
+ pos2 = pos2[perm]
205
+
206
+ if n_negatives > 0:
207
+
208
+ def norm(p):
209
+ return p / p.sum()
210
+
211
+ pos1 = np.r_[
212
+ pos1,
213
+ rng.choice(
214
+ shape1[0] * shape1[1],
215
+ size=n_negatives,
216
+ replace=False,
217
+ p=norm(~is_reciprocal1),
218
+ ),
219
+ ]
220
+ pos2 = np.r_[
221
+ pos2,
222
+ rng.choice(
223
+ shape2[0] * shape2[1],
224
+ size=n_negatives,
225
+ replace=False,
226
+ p=norm(~is_reciprocal2),
227
+ ),
228
+ ]
229
+ valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]
230
+
231
+ if ret_xy:
232
+ pos1 = unravel_xy(pos1, shape1)
233
+ pos2 = unravel_xy(pos2, shape2)
234
+ return pos1, pos2, valid
FastVGGT/eval/dataset_utils/cropping.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+
6
+ import PIL.Image
7
+ import os
8
+
9
+ from utils import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics
10
+
11
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
12
+ import cv2 # noqa
13
+ import numpy as np # noqa
14
+
15
# Resolve resampling filters across Pillow versions: newer Pillow exposes them
# under Image.Resampling; older versions put them directly on the Image module.
try:
    lanczos = PIL.Image.Resampling.LANCZOS
    bicubic = PIL.Image.Resampling.BICUBIC
except AttributeError:
    # Fallback for older Pillow releases without the Resampling enum.
    lanczos = PIL.Image.LANCZOS
    bicubic = PIL.Image.BICUBIC
21
+
22
+
23
class ImageList:
    """Apply the same PIL operation uniformly to one image or a set of images."""

    def __init__(self, images):
        # Accept a lone image or any collection; coerce raw arrays to PIL images.
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        self.images = [
            img if isinstance(img, PIL.Image.Image) else PIL.Image.fromarray(img)
            for img in images
        ]

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        # Unwrap a singleton back to a plain image; otherwise return a tuple.
        if len(self.images) > 1:
            return tuple(self.images)
        return self.images[0]

    @property
    def size(self):
        # All wrapped images must share the same (W, H).
        sizes = [img.size for img in self.images]
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    def resize(self, *args, **kwargs):
        return ImageList(self._dispatch("resize", *args, **kwargs))

    def crop(self, *args, **kwargs):
        return ImageList(self._dispatch("crop", *args, **kwargs))

    def _dispatch(self, method_name, *args, **kwargs):
        # Forward the named PIL.Image method to every wrapped image.
        return [getattr(img, method_name)(*args, **kwargs) for img in self.images]
55
+
56
+
57
def rescale_image_depthmap(
    image, depthmap, camera_intrinsics, output_resolution, force=True
):
    """Jointly rescale a (image, depthmap)
    so that (out_width, out_height) >= output_res

    Image, depth map, and intrinsics are scaled by the same uniform factor so
    they stay geometrically consistent.

    Args:
        image: PIL image(s) or numpy array(s); wrapped in an ImageList.
        depthmap: optional HxW(-aligned) depth array, or None.
        camera_intrinsics: 3x3 intrinsics matrix (OpenCV convention).
        output_resolution: (out_width, out_height) lower bound on the result.
        force: when False, an image already smaller than the target is
            returned unchanged.

    Returns:
        (image, depthmap, camera_intrinsics) after rescaling.
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # Depth map must be pixel-aligned with the image: (H, W) == size[::-1].
        assert tuple(depthmap.shape[:2]) == image.size[::-1]

    assert output_resolution.shape == (2,)
    # Single uniform scale so both output dims reach the requested resolution;
    # the epsilon guards against floor() undershooting by one pixel below.
    scale_final = max(output_resolution / image.size) + 1e-8
    if scale_final >= 1 and not force:  # image is already smaller than what is asked
        return (image.to_pil(), depthmap, camera_intrinsics)
    output_resolution = np.floor(input_resolution * scale_final).astype(int)

    # Lanczos when shrinking, bicubic when enlarging.
    image = image.resize(
        output_resolution, resample=lanczos if scale_final < 1 else bicubic
    )
    if depthmap is not None:
        # Nearest-neighbour avoids blending depth values across object borders.
        depthmap = cv2.resize(
            depthmap,
            output_resolution,
            fx=scale_final,
            fy=scale_final,
            interpolation=cv2.INTER_NEAREST,
        )

    # Rescale intrinsics accordingly; margins are ~0 here, so no real crop offset.
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
    )

    return image.to_pil(), depthmap, camera_intrinsics
93
+
94
+
95
def camera_matrix_of_crop(
    input_camera_matrix,
    input_resolution,
    output_resolution,
    scaling=1,
    offset_factor=0.5,
    offset=None,
):
    """Derive the intrinsics of a view that is scaled then cropped.

    The crop margins (scaled input minus output size) must be non-negative.
    When `offset` is not supplied, the crop window sits at `offset_factor`
    of the available margin (0.5 = centered crop).
    """
    margins = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    # Convert to COLMAP convention, apply scale and principal-point shift,
    # then convert back to the OpenCV convention used by the callers.
    K_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
    K_colmap[:2, :] *= scaling
    K_colmap[:2, 2] -= offset
    return colmap_to_opencv_intrinsics(K_colmap)
115
+
116
+
117
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
    """
    Return a crop of the input view, keeping image, depth map, and
    intrinsics consistent. `crop_bbox` is (left, top, right, bottom).
    """
    left, top, right, bottom = crop_bbox

    cropped = ImageList(image).crop((left, top, right, bottom))
    depthmap = depthmap[top:bottom, left:right]

    # Shift the principal point so the intrinsics match the cropped frame.
    camera_intrinsics = camera_intrinsics.copy()
    camera_intrinsics[0, 2] -= left
    camera_intrinsics[1, 2] -= top

    return cropped.to_pil(), depthmap, camera_intrinsics
132
+
133
+
134
def bbox_from_intrinsics_in_out(
    input_camera_matrix, output_camera_matrix, output_resolution
):
    """Recover the (l, t, r, b) crop box implied by two intrinsics matrices.

    The crop offset is the shift of the principal point between the input
    and output matrices; the box extends by `output_resolution` from there.
    """
    out_width, out_height = output_resolution
    shift = input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]
    left, top = np.int32(np.round(shift))
    return (left, top, left + out_width, top + out_height)
FastVGGT/eval/dataset_utils/transforms.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+
6
+ import torchvision.transforms as tvf
7
+
8
# Standard preprocessing: PIL/ndarray -> float tensor in [0, 1], then map to [-1, 1].
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


# Color jitter followed by the same normalization; unlike SeqColorJitter below,
# tvf.ColorJitter re-draws its random parameters on every call.
ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
12
+
13
+
14
+ def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
15
+ if isinstance(value, (int, float)):
16
+ if value < 0:
17
+ raise ValueError(f"If is a single number, it must be non negative.")
18
+ value = [center - float(value), center + float(value)]
19
+ if clip_first_on_zero:
20
+ value[0] = max(value[0], 0.0)
21
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
22
+ value = [float(value[0]), float(value[1])]
23
+ else:
24
+ raise TypeError(f"should be a single number or a list/tuple with length 2.")
25
+
26
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
27
+ raise ValueError(f"values should be between {bound}, but got {value}.")
28
+
29
+ if value[0] == value[1] == center:
30
+ return None
31
+ else:
32
+ return tuple(value)
33
+
34
+
35
+ import torch
36
+ import torchvision.transforms.functional as F
37
+
38
+
39
def SeqColorJitter():
    """
    Return a color jitter transform with same random parameters

    Unlike ``ColorJitter`` above, which re-samples on every call, this factory
    draws the four jitter factors and their application order once, and returns
    a closure that re-applies the identical jitter to every image it receives —
    so all frames of a sequence get a consistent augmentation.
    """
    # Same default ranges as the module-level ColorJitter transform.
    brightness = _check_input(0.5)
    contrast = _check_input(0.5)
    saturation = _check_input(0.5)
    hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    # Random but fixed order of the four adjustments. Note: the order of the
    # RNG draws below (randperm, then four uniform_ calls) defines the random
    # stream and must not be changed.
    fn_idx = torch.randperm(4)
    brightness_factor = (
        None
        if brightness is None
        else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
    )
    contrast_factor = (
        None
        if contrast is None
        else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
    )
    saturation_factor = (
        None
        if saturation is None
        else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
    )
    hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

    def _color_jitter(img):
        # Apply the pre-drawn factors in the pre-drawn order, then normalize
        # to [-1, 1] with ImgNorm.
        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)
        return ImgNorm(img)

    return _color_jitter
FastVGGT/eval/eval_7andN.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ # Ensure project root is on sys.path for absolute imports like `vggt.*`
5
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
6
+ if ROOT_DIR not in sys.path:
7
+ sys.path.insert(0, ROOT_DIR)
8
+
9
+ import time
10
+ import torch
11
+ import argparse
12
+ import numpy as np
13
+ import open3d as o3d
14
+ import os.path as osp
15
+ from torch.utils.data import DataLoader
16
+ from torch.utils.data._utils.collate import default_collate
17
+ from tqdm import tqdm
18
+ from collections import defaultdict
19
+ import torchvision.transforms as transforms
20
+
21
+
22
def get_args_parser():
    """Build the CLI parser for the 3D reconstruction evaluation script."""
    p = argparse.ArgumentParser("3D Reconstruction evaluation", add_help=False)

    # Model / runtime options.
    p.add_argument(
        "--ckpt_path",
        type=str,
        default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
        help="ckpt name",
    )
    p.add_argument("--device", type=str, default="cuda:0", help="device")
    p.add_argument("--model_name", type=str, default="VGGT")
    p.add_argument("--conf_thresh", type=float, default=0.0, help="confidence threshold")

    # Output / evaluation options.
    p.add_argument(
        "--output_dir",
        type=str,
        default="/home/sy/code/FastVGGT/eval_results",
        help="value for outdir",
    )
    p.add_argument("--size", type=int, default=518)
    p.add_argument("--revisit", type=int, default=1, help="revisit times")
    p.add_argument("--freeze", action="store_true")
    p.add_argument("--use_proj", action="store_true")
    p.add_argument("--merging", type=int, default=0, help="VGGT aggregator merging steps")
    p.add_argument("--kf", type=int, default=2, help="key frame")
    return p
50
+
51
+
52
def main(args):
    """Run 3D reconstruction evaluation on the 7-Scenes and NRGBD datasets.

    For each sequence, the VGGT model predicts per-frame world points; the
    predictions are center-cropped, masked, optionally aligned to the ground
    truth (Umeyama similarity + ICP refinement), and scored with accuracy /
    completion / normal-consistency metrics. Per-scene lines are appended to
    ``logs.txt`` and aggregated means are written to ``logs_all.txt``.

    Args:
        args: parsed CLI namespace from get_args_parser().
    """
    from data import SevenScenes, NRGBD
    from utils import accuracy, completion

    # Map the requested image size to the (W, H) resolution the loaders use.
    if args.size == 512:
        resolution = (512, 384)
    elif args.size == 224:
        resolution = 224
    elif args.size == 518:
        resolution = (518, 392)
    else:
        raise NotImplementedError
    datasets_all = {
        "7scenes": SevenScenes(
            split="test",
            ROOT="/data/sy/7scenes",
            resolution=resolution,
            num_seq=1,
            full_video=True,
            kf_every=args.kf,
        ),  # 20),
        "NRGBD": NRGBD(
            split="test",
            ROOT="/data/sy/neural_rgbd_data",
            resolution=resolution,
            num_seq=1,
            full_video=True,
            kf_every=args.kf,
        ),
    }

    device = args.device
    model_name = args.model_name  # NOTE(review): not referenced below

    from vggt.models.vggt import VGGT
    from vggt.utils.pose_enc import pose_encoding_to_extri_intri
    from vggt.utils.geometry import unproject_depth_map_to_point_map
    from criterion import Regr3D_t_ScaleShiftInv, L21

    # Force use of bf16 data type (re-assigned per batch below by GPU capability).
    dtype = torch.bfloat16
    # Load VGGT model
    model = VGGT(merging=args.merging, enable_point=True)
    ckpt = torch.load(args.ckpt_path, map_location="cpu")

    # Load pre-trained weights.
    model.load_state_dict(
        ckpt, strict=False
    )  # Use strict=False due to enable_point=True difference

    model = model.cuda().eval()
    model = model.to(torch.bfloat16)

    del ckpt  # release the CPU copy of the checkpoint
    os.makedirs(osp.join(args.output_dir, f"{args.kf}"), exist_ok=True)

    # Scale/shift-invariant 3D regression criterion, used here only to bring
    # predictions and GT into a comparable frame (get_all_pts3d_t below).
    criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)

    with torch.no_grad():
        for name_data, dataset in datasets_all.items():
            save_path = osp.join(osp.join(args.output_dir, f"{args.kf}"), name_data)
            os.makedirs(save_path, exist_ok=True)
            log_file = osp.join(save_path, "logs.txt")

            # Running sums of per-scene metrics (mean and median variants).
            acc_all = 0
            acc_all_med = 0
            comp_all = 0
            comp_all_med = 0
            nc1_all = 0
            nc1_all_med = 0
            nc2_all = 0
            nc2_all_med = 0
            scene_infer_times = defaultdict(list)

            for data_idx in tqdm(range(len(dataset))):
                batch = default_collate([dataset[data_idx]])
                # Keys that stay on CPU (metadata / non-tensor entries).
                ignore_keys = set(
                    [
                        "depthmap",
                        "dataset",
                        "label",
                        "instance",
                        "idx",
                        "true_shape",
                        "rng",
                    ]
                )
                # Move every remaining per-view tensor to the target device.
                for view in batch:
                    for name in view.keys():  # pseudo_focal
                        if name in ignore_keys:
                            continue
                        if isinstance(view[name], tuple) or isinstance(
                            view[name], list
                        ):
                            view[name] = [
                                x.to(device, non_blocking=True) for x in view[name]
                            ]
                        else:
                            view[name] = view[name].to(device, non_blocking=True)

                pts_all = []
                pts_gt_all = []
                images_all = []
                masks_all = []
                conf_all = []
                in_camera1 = None

                # bf16 on Ampere (capability >= 8) and newer, fp16 otherwise.
                dtype = (
                    torch.bfloat16
                    if torch.cuda.get_device_capability()[0] >= 8
                    else torch.float16
                )
                with torch.cuda.amp.autocast(dtype=dtype):
                    # Undo the [-1, 1] normalization: the model expects [0, 1].
                    if isinstance(batch, dict) and "img" in batch:
                        batch["img"] = (batch["img"] + 1.0) / 2.0
                    elif isinstance(batch, list) and all(
                        isinstance(v, dict) and "img" in v for v in batch
                    ):
                        for view in batch:
                            view["img"] = (view["img"] + 1.0) / 2.0
                    # Gather all `img` tensors into a single tensor of shape [N, C, H, W]
                    imgs_tensor = torch.cat([v["img"] for v in batch], dim=0)

                with torch.cuda.amp.autocast(dtype=dtype):
                    with torch.no_grad():
                        # Synchronize around the forward pass for accurate timing.
                        torch.cuda.synchronize()
                        start = time.time()
                        preds = model(imgs_tensor)
                        torch.cuda.synchronize()
                        end = time.time()
                        inference_time_ms = (end - start) * 1000
                        print(f"Inference time: {inference_time_ms:.2f}ms")

                # Wrap model outputs per-view to align with batch later
                predictions = preds
                views = batch  # list[dict]
                if "pose_enc" in predictions:
                    B, S = predictions["pose_enc"].shape[:2]
                elif "world_points" in predictions:
                    B, S = predictions["world_points"].shape[:2]
                else:
                    raise KeyError(
                        "predictions is missing a key to infer sequence length"
                    )

                # Re-split stacked predictions into one dict per frame index s.
                ress = []
                for s in range(S):
                    res = {
                        "pts3d_in_other_view": predictions["world_points"][:, s],
                        "conf": predictions["world_points_conf"][:, s],
                        "depth": predictions["depth"][:, s],
                        "depth_conf": predictions["depth_conf"][:, s],
                        "camera_pose": predictions["pose_enc"][:, s, :],
                    }
                    if (
                        isinstance(views, list)
                        and s < len(views)
                        and "valid_mask" in views[s]
                    ):
                        res["valid_mask"] = views[s]["valid_mask"]
                    if "track" in predictions:
                        res.update(
                            {
                                "track": predictions["track"][:, s],
                                "vis": (
                                    predictions.get("vis", None)[:, s]
                                    if "vis" in predictions
                                    else None
                                ),
                                "track_conf": (
                                    predictions.get("conf", None)[:, s]
                                    if "conf" in predictions
                                    else None
                                ),
                            }
                        )
                    ress.append(res)

                preds = ress

                # When revisiting, score only the final pass over the sequence.
                valid_length = len(preds) // args.revisit
                if args.revisit > 1:
                    preds = preds[-valid_length:]
                    batch = batch[-valid_length:]

                # Evaluation
                print(f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}")
                gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
                    criterion.get_all_pts3d_t(batch, preds)
                )

                in_camera1 = None
                pts_all = []
                pts_gt_all = []
                images_all = []
                masks_all = []
                conf_all = []

                for j, view in enumerate(batch):
                    if in_camera1 is None:
                        in_camera1 = view["camera_pose"][0].cpu()

                    image = view["img"].permute(0, 2, 3, 1).cpu().numpy()[0]
                    mask = view["valid_mask"].cpu().numpy()[0]

                    pts = pred_pts[j].cpu().numpy()[0]
                    conf = preds[j]["conf"].cpu().data.numpy()[0]

                    # mask = mask & (conf > 1.8)

                    pts_gt = gt_pts[j].detach().cpu().numpy()[0]

                    # Center-crop every view to 224x224 before scoring.
                    H, W = image.shape[:2]
                    cx = W // 2
                    cy = H // 2
                    l, t = cx - 112, cy - 112
                    r, b = cx + 112, cy + 112
                    image = image[t:b, l:r]
                    mask = mask[t:b, l:r]
                    pts = pts[t:b, l:r]
                    pts_gt = pts_gt[t:b, l:r]

                    images_all.append(image[None, ...])
                    pts_all.append(pts[None, ...])
                    pts_gt_all.append(pts_gt[None, ...])
                    masks_all.append(mask[None, ...])
                    conf_all.append(conf[None, ...])

                images_all = np.concatenate(images_all, axis=0)
                pts_all = np.concatenate(pts_all, axis=0)
                pts_gt_all = np.concatenate(pts_gt_all, axis=0)
                masks_all = np.concatenate(masks_all, axis=0)

                scene_id = view["label"][0].rsplit("/", 1)[0]
                # Record average inference time per scene
                try:
                    scene_infer_times[scene_id].append(float(inference_time_ms))
                except Exception:
                    pass

                save_params = {}

                save_params["images_all"] = images_all
                save_params["pts_all"] = pts_all
                save_params["pts_gt_all"] = pts_gt_all
                save_params["masks_all"] = masks_all

                # Keep only valid pixels, then drop non-finite points.
                pts_all_masked = pts_all[masks_all > 0]
                pts_gt_all_masked = pts_gt_all[masks_all > 0]
                images_all_masked = images_all[masks_all > 0]

                mask = np.isfinite(pts_all_masked)
                pts_all_masked = pts_all_masked[mask]

                mask_gt = np.isfinite(pts_gt_all_masked)
                pts_gt_all_masked = pts_gt_all_masked[mask_gt]
                images_all_masked = images_all_masked[mask]

                # Reshape to point cloud (N, 3) before sampling
                pts_all_masked = pts_all_masked.reshape(-1, 3)
                pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3)
                images_all_masked = images_all_masked.reshape(-1, 3)

                # If number of points exceeds threshold, sample by points
                if pts_all_masked.shape[0] > 999999:
                    sample_indices = np.random.choice(
                        pts_all_masked.shape[0], 999999, replace=False
                    )
                    pts_all_masked = pts_all_masked[sample_indices]
                    images_all_masked = images_all_masked[sample_indices]

                # Apply the same sampling to GT point cloud
                if pts_gt_all_masked.shape[0] > 999999:
                    sample_indices_gt = np.random.choice(
                        pts_gt_all_masked.shape[0], 999999, replace=False
                    )
                    pts_gt_all_masked = pts_gt_all_masked[sample_indices_gt]

                if args.use_proj:

                    def umeyama_alignment(
                        src: np.ndarray, dst: np.ndarray, with_scale: bool = True
                    ):
                        """Closed-form similarity (s, R, t) mapping src onto dst."""
                        assert src.shape == dst.shape
                        N, dim = src.shape

                        mu_src = src.mean(axis=0)
                        mu_dst = dst.mean(axis=0)
                        src_c = src - mu_src
                        dst_c = dst - mu_dst

                        Sigma = dst_c.T @ src_c / N  # (3,3)

                        U, D, Vt = np.linalg.svd(Sigma)

                        # Reflection guard: keep det(R) = +1.
                        S = np.eye(dim)
                        if np.linalg.det(U) * np.linalg.det(Vt) < 0:
                            S[-1, -1] = -1

                        R = U @ S @ Vt

                        if with_scale:
                            var_src = (src_c**2).sum() / N
                            s = (D * S.diagonal()).sum() / var_src
                        else:
                            s = 1.0

                        t = mu_dst - s * R @ mu_src

                        return s, R, t

                    # NOTE(review): Umeyama assumes pred/GT points are index-
                    # aligned; the independent subsampling above can break that
                    # pairing when either cloud exceeds the cap — verify.
                    pts_all_masked = pts_all_masked.reshape(-1, 3)
                    pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3)
                    s, R, t = umeyama_alignment(
                        pts_all_masked, pts_gt_all_masked, with_scale=True
                    )
                    pts_all_aligned = (s * (R @ pts_all_masked.T)).T + t  # (N,3)
                    pts_all_masked = pts_all_aligned

                pcd = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(pts_all_masked)
                pcd.colors = o3d.utility.Vector3dVector(images_all_masked)

                pcd_gt = o3d.geometry.PointCloud()
                pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked)
                pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked)

                # Refine alignment with point-to-point ICP before scoring.
                trans_init = np.eye(4)

                threshold = 0.1
                reg_p2p = o3d.pipelines.registration.registration_icp(
                    pcd,
                    pcd_gt,
                    threshold,
                    trans_init,
                    o3d.pipelines.registration.TransformationEstimationPointToPoint(),
                )

                transformation = reg_p2p.transformation

                pcd = pcd.transform(transformation)
                pcd.estimate_normals()
                pcd_gt.estimate_normals()

                gt_normal = np.asarray(pcd_gt.normals)
                pred_normal = np.asarray(pcd.normals)

                acc, acc_med, nc1, nc1_med = accuracy(
                    pcd_gt.points, pcd.points, gt_normal, pred_normal
                )
                comp, comp_med, nc2, nc2_med = completion(
                    pcd_gt.points, pcd.points, gt_normal, pred_normal
                )
                # The exact line format below is what the module-level `regex`
                # parses during aggregation — keep them in sync.
                print(
                    f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}"
                )
                print(
                    f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}",
                    file=open(log_file, "a"),
                )

                acc_all += acc
                comp_all += comp
                nc1_all += nc1
                nc2_all += nc2

                acc_all_med += acc_med
                comp_all_med += comp_med
                nc1_all_med += nc1_med
                nc2_all_med += nc2_med

                # release cuda memory
                torch.cuda.empty_cache()

            # Get depth from pcd and run TSDFusion
            to_write = ""
            # Read the log file
            if os.path.exists(osp.join(save_path, "logs.txt")):
                with open(osp.join(save_path, "logs.txt"), "r") as f_sub:
                    to_write += f_sub.read()

            # Re-parse the accumulated per-scene lines with the module-level
            # `regex` and write aggregated means plus timing to logs_all.txt.
            with open(osp.join(save_path, f"logs_all.txt"), "w") as f:
                log_data = to_write
                metrics = defaultdict(list)
                for line in log_data.strip().split("\n"):
                    match = regex.match(line)
                    if match:
                        data = match.groupdict()
                        # Exclude 'scene_id' from metrics as it's an identifier
                        for key, value in data.items():
                            if key != "scene_id":
                                metrics[key].append(float(value))
                        metrics["nc"].append(
                            (float(data["nc1"]) + float(data["nc2"])) / 2
                        )
                        metrics["nc_med"].append(
                            (float(data["nc1_med"]) + float(data["nc2_med"])) / 2
                        )
                mean_metrics = {
                    metric: sum(values) / len(values)
                    for metric, values in metrics.items()
                }

                c_name = "mean"
                print_str = f"{c_name.ljust(20)}: "
                for m_name in mean_metrics:
                    print_num = np.mean(mean_metrics[m_name])
                    print_str = print_str + f"{m_name}: {print_num:.3f} | "
                print_str = print_str + "\n"
                # Summarize per-scene average inference time
                time_lines = []
                for sid, times in scene_infer_times.items():
                    if len(times) > 0:
                        time_lines.append(
                            f"Idx: {sid}, Time_avg_ms: {np.mean(times):.2f}"
                        )
                time_block = "\n".join(time_lines) + (
                    "\n" if len(time_lines) > 0 else ""
                )

                f.write(to_write + time_block + print_str)
473
+
474
+
475
from collections import defaultdict  # NOTE(review): re-import; already imported at top
import re

# Verbose regex that parses one per-scene metric line exactly as printed and
# logged by main(), e.g.
# "Idx: scene, Acc: 0.1, Comp: 0.2, NC1: 0.3, NC2: 0.4 - Acc_med: ...".
# Defined at module level (after main) and referenced from inside main(),
# which is safe because main() only runs after the module is fully loaded.
pattern = r"""
Idx:\s*(?P<scene_id>[^,]+),\s*
Acc:\s*(?P<acc>[^,]+),\s*
Comp:\s*(?P<comp>[^,]+),\s*
NC1:\s*(?P<nc1>[^,]+),\s*
NC2:\s*(?P<nc2>[^,]+)\s*-\s*
Acc_med:\s*(?P<acc_med>[^,]+),\s*
Compc_med:\s*(?P<comp_med>[^,]+),\s*
NC1c_med:\s*(?P<nc1_med>[^,]+),\s*
NC2c_med:\s*(?P<nc2_med>[^,]+)
"""

regex = re.compile(pattern, re.VERBOSE)
491
+
492
+
493
# Script entry point: parse CLI arguments and run the full evaluation.
if __name__ == "__main__":
    parser = get_args_parser()
    args = parser.parse_args()

    main(args)
FastVGGT/eval/eval_custom.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ import numpy as np
4
+ import torch
5
+ import os
6
+ import sys
7
+ import matplotlib.pyplot as plt
8
+ from scipy.spatial.transform import Rotation
9
+
10
+ # Ensure project root is in sys.path for absolute imports like `vggt.*`
11
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
12
+ if ROOT_DIR not in sys.path:
13
+ sys.path.insert(0, ROOT_DIR)
14
+
15
+ from vggt.models.vggt import VGGT
16
+ from vggt.utils.eval_utils import (
17
+ load_poses,
18
+ get_vgg_input_imgs,
19
+ get_sorted_image_paths,
20
+ build_frame_selection,
21
+ load_images_rgb,
22
+ infer_vggt_and_reconstruct,
23
+ evaluate_scene_and_save,
24
+ )
25
+
26
+ # Import pose visualization libraries (optional EVO support)
27
+ try:
28
+ from evo.core.trajectory import PoseTrajectory3D
29
+ import evo.tools.plot as plot
30
+
31
+ EVO_AVAILABLE = True
32
+ except ImportError:
33
+ # EVO is optional; we have a matplotlib-based fallback
34
+ EVO_AVAILABLE = False
35
+
36
+
37
def visualize_predicted_poses(
    all_cam_to_world_mat, frame_ids, output_scene_dir, scene_name="custom_dataset"
):
    """
    Visualize the predicted camera pose trajectory (no GT comparison required).

    Saves an XZ-plane projection of the camera positions to
    ``output_scene_dir / "predicted_trajectory.png"``. All failures are
    caught and reported — this is a best-effort diagnostic, never fatal.

    Args:
        all_cam_to_world_mat: List of camera-to-world transform matrices
        frame_ids: List of frame IDs (NOTE(review): currently unused here)
        output_scene_dir: Output directory (a Path; joined with the filename)
        scene_name: Scene name
    """
    # Provide basic pose visualization even without EVO
    if not EVO_AVAILABLE:
        print("⚠️ EVO not installed; using basic matplotlib visualization")

    try:
        # Convert to numpy array
        poses_est = np.array(all_cam_to_world_mat)

        if len(poses_est) < 2:
            print("⚠️ Not enough poses to generate trajectory plot")
            return

        print(f"🎨 Generating pose trajectory visualization...")

        # Extract translation part (the last column of each 4x4 pose).
        positions = poses_est[:, :3, 3]  # shape: (N, 3)

        # Create figure - show XZ-plane projection only
        fig, ax = plt.subplots(1, 1, figsize=(10, 8))

        # XZ-plane projection of the trajectory with start/end markers.
        ax.plot(
            positions[:, 0],
            positions[:, 2],
            "b-",
            linewidth=2,
            label="Predicted Trajectory",
        )
        ax.scatter(
            positions[0, 0], positions[0, 2], color="green", s=100, label="Start"
        )
        ax.scatter(positions[-1, 0], positions[-1, 2], color="red", s=100, label="End")
        ax.set_xlabel("X (m)")
        ax.set_ylabel("Z (m)")
        ax.set_title(f"{scene_name} - XZ-plane projection")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Save image and close the figure to free matplotlib resources.
        pose_plot_path = output_scene_dir / "predicted_trajectory.png"
        plt.savefig(pose_plot_path, dpi=300, bbox_inches="tight")
        plt.close()

        print(f"📊 Trajectory visualization saved: {pose_plot_path}")

    except Exception as e:
        # Visualization is optional: report the failure but never raise.
        print(f"⚠️ Failed to generate pose visualization: {e}")
        import traceback

        traceback.print_exc()
+
100
+
101
+ def main():
102
+ """
103
+ Evaluation script for a Custom Dataset.
104
+ Supports optional evaluation and custom dataset structure.
105
+ """
106
+ parser = argparse.ArgumentParser(
107
+ description="Run FastVGGT evaluation on a Custom Dataset"
108
+ )
109
+
110
+ # Required: dataset path
111
+ parser.add_argument(
112
+ "--data_path",
113
+ type=Path,
114
+ required=True,
115
+ help="Dataset path containing subfolders: color, depth, gt_ply, pose",
116
+ )
117
+
118
+ # Optional: enable evaluation
119
+ parser.add_argument(
120
+ "--enable_evaluation",
121
+ action="store_true",
122
+ help="Enable evaluation (requires pose and ply data)",
123
+ )
124
+
125
+ # Output path
126
+ parser.add_argument(
127
+ "--output_path",
128
+ type=Path,
129
+ default="./eval_results_custom",
130
+ help="Output path for evaluation results",
131
+ )
132
+
133
+ # Model parameters
134
+ parser.add_argument(
135
+ "--ckpt_path",
136
+ type=str,
137
+ default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
138
+ help="Model checkpoint file path",
139
+ )
140
+
141
+ parser.add_argument("--merging", type=int, default=0, help="Merging parameter")
142
+
143
+ # Processing parameters
144
+ parser.add_argument(
145
+ "--input_frame",
146
+ type=int,
147
+ default=200,
148
+ help="Maximum number of frames to process per scene",
149
+ )
150
+
151
+ parser.add_argument(
152
+ "--depth_conf_thresh",
153
+ type=float,
154
+ default=3.0,
155
+ help="Depth confidence threshold to filter low-confidence depth values",
156
+ )
157
+
158
+ # Evaluation parameters (only used when evaluation is enabled)
159
+ parser.add_argument(
160
+ "--chamfer_max_dist",
161
+ type=float,
162
+ default=0.5,
163
+ help="Maximum distance threshold used in Chamfer Distance computation",
164
+ )
165
+
166
+ parser.add_argument("--plot", action="store_true", help="Whether to generate plots")
167
+
168
+ parser.add_argument(
169
+ "--vis_attn_map",
170
+ action="store_true",
171
+ help="Visualize attention maps during inference",
172
+ )
173
+
174
+ args = parser.parse_args()
175
+ torch.manual_seed(33)
176
+
177
+ # Check data path exists
178
+ if not args.data_path.exists():
179
+ print(f"❌ Error: Data path does not exist: {args.data_path}")
180
+ return
181
+
182
+ # Check required subdirectories
183
+ color_dir = args.data_path / "images"
184
+ pose_dir = args.data_path / "pose"
185
+
186
+ if not color_dir.exists():
187
+ print(f"❌ Error: color directory does not exist: {color_dir}")
188
+ return
189
+
190
+ print(f"📁 Dataset path: {args.data_path}")
191
+ # print(f"🔧 Enable evaluation: {'Yes' if args.enable_evaluation else 'No'}")
192
+
193
+ # If evaluation is enabled, check pose and gt_ply directories
194
+ if args.enable_evaluation:
195
+ if not pose_dir.exists():
196
+ print(f"❌ Error: Evaluation requires pose directory: {pose_dir}")
197
+ return
198
+
199
+ gt_ply_dir = args.data_path / "gt_ply"
200
+ if not gt_ply_dir.exists():
201
+ print(f"❌ Error: Evaluation requires gt_ply directory: {gt_ply_dir}")
202
+ return
203
+ print(f"📊 Evaluation will use Ground Truth")
204
+ else:
205
+ print(f"🏃 Inference only, no evaluation")
206
+
207
+ # Create output directory
208
+ args.output_path.mkdir(parents=True, exist_ok=True)
209
+ output_scene_dir = args.output_path / "custom_dataset"
210
+
211
+ # Check if already processed
212
+ if (output_scene_dir / "metrics.json").exists() and args.enable_evaluation:
213
+ print(
214
+ f"⚠️ Results already exist, skipping: {output_scene_dir / 'metrics.json'}"
215
+ )
216
+ return
217
+
218
+ # Force use of bf16 dtype
219
+ dtype = torch.bfloat16
220
+
221
+ # Load VGGT model
222
+ print(f"🔄 Loading model: {args.ckpt_path}")
223
+ model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
224
+ ckpt = torch.load(args.ckpt_path, map_location="cpu")
225
+ incompat = model.load_state_dict(ckpt, strict=False)
226
+ # if incompat.missing_keys or incompat.unexpected_keys:
227
+ # print(f"⚠️ Partially incompatible keys when loading model: {incompat}")
228
+ model = model.cuda().eval()
229
+ model = model.to(torch.bfloat16)
230
+ print(f"✅ Model loaded")
231
+
232
+ # Load scene data
233
+ image_paths = get_sorted_image_paths(color_dir)
234
+ if len(image_paths) == 0:
235
+ print(f"❌ Error: No images found in {color_dir}")
236
+ return
237
+
238
+ print(f"🖼️ Found {len(image_paths)} images")
239
+
240
+ # Process pose data (if evaluation is enabled)
241
+ poses_gt = None
242
+ first_gt_pose = None
243
+ available_pose_frame_ids = None
244
+ c2ws = None
245
+
246
+ if args.enable_evaluation:
247
+ poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_dir)
248
+ if (
249
+ poses_gt is None
250
+ or first_gt_pose is None
251
+ or available_pose_frame_ids is None
252
+ ):
253
+ print(f"❌ Error: Failed to load pose data")
254
+ return
255
+ print(f"📐 Loaded {len(poses_gt)} poses")
256
+
257
+ # Frame selection
258
+ if args.enable_evaluation and available_pose_frame_ids is not None:
259
+ # Use pose data for frame selection
260
+ selected_frame_ids, selected_image_paths, selected_pose_indices = (
261
+ build_frame_selection(
262
+ image_paths, available_pose_frame_ids, args.input_frame
263
+ )
264
+ )
265
+ c2ws = poses_gt[selected_pose_indices]
266
+ image_paths = selected_image_paths
267
+ else:
268
+ # Simply take the first N frames
269
+ num_frames = min(len(image_paths), args.input_frame)
270
+ selected_frame_ids = list(range(num_frames))
271
+ image_paths = image_paths[:num_frames]
272
+
273
+ print(f"📋 Selected {len(image_paths)} frames for processing")
274
+
275
+ try:
276
+ # Load images
277
+ print(f"🔄 Loading images...")
278
+ images = load_images_rgb(image_paths)
279
+
280
+ if not images or len(images) < 3:
281
+ print(f"❌ Error: Not enough valid images (need at least 3)")
282
+ return
283
+
284
+ frame_ids = selected_frame_ids
285
+ images_array = np.stack(images)
286
+ vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
287
+ print(f"📐 Image patch dimensions: {patch_width}x{patch_height}")
288
+
289
+ # Update attention layer patch dimensions in the model
290
+ model.update_patch_dimensions(patch_width, patch_height)
291
+
292
+ # Inference + Reconstruction
293
+ print(f"🚀 Start inference and reconstruction...")
294
+ (
295
+ extrinsic_np,
296
+ intrinsic_np,
297
+ all_world_points,
298
+ all_point_colors,
299
+ all_cam_to_world_mat,
300
+ inference_time_ms,
301
+ ) = infer_vggt_and_reconstruct(
302
+ model, vgg_input, dtype, args.depth_conf_thresh, image_paths
303
+ )
304
+ print(f"⏱️ Inference time: {inference_time_ms:.2f}ms")
305
+
306
+ # Check results
307
+ if not all_cam_to_world_mat or not all_world_points:
308
+ print(f"❌ Error: Failed to obtain valid camera poses or point clouds")
309
+ return
310
+
311
+ # print(f"✅ Inference done, obtained {len(all_world_points)} point sets")
312
+
313
+ # Evaluation and saving
314
+ if args.enable_evaluation:
315
+ print(f"📊 Start evaluation...")
316
+ gt_ply_dir = args.data_path / "gt_ply"
317
+ metrics = evaluate_scene_and_save(
318
+ "custom_dataset",
319
+ c2ws,
320
+ first_gt_pose,
321
+ frame_ids,
322
+ all_cam_to_world_mat,
323
+ all_world_points,
324
+ output_scene_dir,
325
+ gt_ply_dir,
326
+ args.chamfer_max_dist,
327
+ inference_time_ms,
328
+ args.plot,
329
+ )
330
+ if metrics is not None:
331
+ print("📈 Evaluation results:")
332
+ for key, value in metrics.items():
333
+ if key in [
334
+ "chamfer_distance",
335
+ "ate",
336
+ "are",
337
+ "rpe_rot",
338
+ "rpe_trans",
339
+ "inference_time_ms",
340
+ ]:
341
+ print(f" {key}: {float(value):.4f}")
342
+
343
+ # Also visualize predicted poses in evaluation branch
344
+ if args.plot:
345
+ visualize_predicted_poses(
346
+ all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset"
347
+ )
348
+ else:
349
+ # Save reconstruction only, no evaluation
350
+ print(f"💾 Saving reconstruction...")
351
+ output_scene_dir.mkdir(parents=True, exist_ok=True)
352
+
353
+ # Save camera poses
354
+ poses_output_path = output_scene_dir / "estimated_poses.txt"
355
+ with open(poses_output_path, "w") as f:
356
+ for i, pose in enumerate(all_cam_to_world_mat):
357
+ f.write(f"# Frame {frame_ids[i]}\n")
358
+ for row in pose:
359
+ f.write(" ".join(map(str, row)) + "\n")
360
+ f.write("\n")
361
+
362
+ # Save point cloud
363
+ if all_world_points:
364
+ points_output_path = output_scene_dir / "reconstructed_points.ply"
365
+
366
+ # Merge all frames' point clouds and colors
367
+ try:
368
+ merged_point_cloud = np.vstack(all_world_points)
369
+ merged_colors = (
370
+ np.vstack(all_point_colors).astype(np.uint8)
371
+ if all_point_colors is not None and len(all_point_colors) > 0
372
+ else None
373
+ )
374
+ print(
375
+ f"📊 Merged point clouds: {len(all_world_points)} frames, total {len(merged_point_cloud)} points"
376
+ )
377
+
378
+ # If too many points, randomly sample 100000 points
379
+ max_points = 100000
380
+ if len(merged_point_cloud) > max_points:
381
+ print(
382
+ f"🔽 Too many points, randomly sampling {max_points} points..."
383
+ )
384
+ # Randomly choose indices
385
+ indices = np.random.choice(
386
+ len(merged_point_cloud), size=max_points, replace=False
387
+ )
388
+ merged_point_cloud = merged_point_cloud[indices]
389
+ if merged_colors is not None:
390
+ merged_colors = merged_colors[indices]
391
+ print(
392
+ f"✅ Sampling done, kept {len(merged_point_cloud)} points"
393
+ )
394
+
395
+ # Save as PLY (with color)
396
+ with open(points_output_path, "w") as f:
397
+ f.write("ply\n")
398
+ f.write("format ascii 1.0\n")
399
+ f.write(f"element vertex {len(merged_point_cloud)}\n")
400
+ f.write("property float x\n")
401
+ f.write("property float y\n")
402
+ f.write("property float z\n")
403
+ if merged_colors is not None:
404
+ f.write("property uchar red\n")
405
+ f.write("property uchar green\n")
406
+ f.write("property uchar blue\n")
407
+ f.write("end_header\n")
408
+ if merged_colors is None:
409
+ for point in merged_point_cloud:
410
+ if not (np.isnan(point).any() or np.isinf(point).any()):
411
+ f.write(
412
+ f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n"
413
+ )
414
+ else:
415
+ for point, color in zip(merged_point_cloud, merged_colors):
416
+ # Check point validity
417
+ if not (np.isnan(point).any() or np.isinf(point).any()):
418
+ r = int(np.clip(color[0], 0, 255))
419
+ g = int(np.clip(color[1], 0, 255))
420
+ b = int(np.clip(color[2], 0, 255))
421
+ f.write(
422
+ f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f} {r} {g} {b}\n"
423
+ )
424
+
425
+ print(f"💾 Point cloud saved to: {points_output_path}")
426
+
427
+ except Exception as e:
428
+ print(f"⚠️ Error saving point cloud: {e}")
429
+ # If merge fails, try to log per-frame info
430
+ print(f"🔍 Point cloud debug info:")
431
+ for i, frame_points in enumerate(all_world_points):
432
+ print(
433
+ f" Frame {i}: {frame_points.shape if hasattr(frame_points, 'shape') else type(frame_points)}"
434
+ )
435
+ if (
436
+ hasattr(frame_points, "shape")
437
+ and len(frame_points.shape) >= 2
438
+ ):
439
+ print(
440
+ f" Shape: {frame_points.shape}, Dtype: {frame_points.dtype}"
441
+ )
442
+ if frame_points.shape[0] > 0:
443
+ print(
444
+ f" Range: x[{np.min(frame_points[:, 0]):.3f}, {np.max(frame_points[:, 0]):.3f}] "
445
+ f"y[{np.min(frame_points[:, 1]):.3f}, {np.max(frame_points[:, 1]):.3f}] "
446
+ f"z[{np.min(frame_points[:, 2]):.3f}, {np.max(frame_points[:, 2]):.3f}]"
447
+ )
448
+
449
+ print(f"📁 Results saved to: {output_scene_dir}")
450
+
451
+ # Visualize predicted pose trajectory
452
+ if args.plot:
453
+ visualize_predicted_poses(
454
+ all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset"
455
+ )
456
+
457
+ print(f"🎉 Done!")
458
+
459
+ except Exception as e:
460
+ print(f"❌ Error occurred during processing: {e}")
461
+ import traceback
462
+
463
+ traceback.print_exc()
464
+
465
+
466
# Script entry point; `main` (defined above) parses CLI arguments, loads the
# VGGT checkpoint, runs inference/reconstruction and optionally evaluates.
if __name__ == "__main__":
    main()
FastVGGT/eval/eval_scannet.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ import numpy as np
4
+ import torch
5
+ import os
6
+ import sys
7
+
8
+ # Ensure project root is in sys.path for absolute imports like `vggt.*`
9
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
10
+ if ROOT_DIR not in sys.path:
11
+ sys.path.insert(0, ROOT_DIR)
12
+
13
+ from vggt.models.vggt import VGGT
14
+ from vggt.utils.eval_utils import (
15
+ load_poses,
16
+ get_vgg_input_imgs,
17
+ get_sorted_image_paths,
18
+ get_all_scenes,
19
+ build_frame_selection,
20
+ load_images_rgb,
21
+ infer_vggt_and_reconstruct,
22
+ evaluate_scene_and_save,
23
+ compute_average_metrics_and_save,
24
+ )
25
+
26
+
27
# Script entry point: batch-evaluate VGGT reconstruction on ScanNet scenes.
# For each scene it selects frames that have ground-truth poses, runs one
# inference pass, reconstructs a point cloud, computes Chamfer/trajectory
# metrics, and finally writes an averaged summary across all scenes.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", type=Path, default="/data/scannetv2/process_scannet"
    )
    parser.add_argument(
        "--gt_ply_dir",
        type=Path,
        default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
    )
    parser.add_argument("--output_path", type=Path, default="./eval_results")
    parser.add_argument("--merging", type=int, default=None)
    # NOTE(review): argparse `type=bool` treats any non-empty string as True
    # (`--plot False` still plots); an `action="store_true"`/`store_false`
    # pair or a str-to-bool converter would be safer — confirm intended CLI.
    parser.add_argument("--plot", type=bool, default=True)
    parser.add_argument(
        "--depth_conf_thresh",
        type=float,
        default=3.0,
        help="Depth confidence threshold for filtering low confidence depth values",
    )
    parser.add_argument(
        "--chamfer_max_dist",
        type=float,
        default=0.5,
        help="Maximum distance threshold in Chamfer Distance computation, distances exceeding this value will be clipped",
    )
    parser.add_argument(
        "--input_frame",
        type=int,
        default=200,
        help="Maximum number of frames selected for processing per scene",
    )
    parser.add_argument(
        "--num_scenes",
        type=int,
        default=50,
        help="Maximum number of scenes to evaluate",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="./ckpt/model_tracker_fixed_e20.pt",
        help="Path to the model checkpoint file",
    )
    parser.add_argument(
        "--vis_attn_map",
        action="store_true",
        help="Whether to visualize attention maps during inference",
    )
    args = parser.parse_args()
    torch.manual_seed(33)

    # Scene sampling
    scannet_scenes = get_all_scenes(args.data_dir, args.num_scenes)
    print(f"Evaluate {len(scannet_scenes)} scenes")

    # Accumulates per-scene metrics plus the final averaged summary.
    all_scenes_metrics = {"scenes": {}, "average": {}}
    # Force use of bf16 data type
    dtype = torch.bfloat16
    # Load VGGT model
    model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    # NOTE(review): `incompat` (missing/unexpected keys) is not inspected;
    # strict=False silently tolerates partial checkpoint mismatches.
    incompat = model.load_state_dict(ckpt, strict=False)
    model = model.cuda().eval()
    model = model.to(torch.bfloat16)

    # Process each scene
    for scene in scannet_scenes:
        scene_dir = args.data_dir / f"{scene}"
        output_scene_dir = args.output_path / f"input_frame_{args.input_frame}" / scene
        # Skip scenes that already have saved results (resumable runs).
        if (output_scene_dir / "metrics.json").exists():
            continue

        # Load scene data
        images_dir = scene_dir / "color"
        pose_path = scene_dir / "pose"
        image_paths = get_sorted_image_paths(images_dir)
        poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_path)
        if (
            poses_gt is None
            or first_gt_pose is None
            or available_pose_frame_ids is None
        ):
            print(f"Skipping scene {scene}: no pose data")
            continue

        # Frame filtering: keep at most --input_frame frames that have poses.
        selected_frame_ids, selected_image_paths, selected_pose_indices = (
            build_frame_selection(
                image_paths, available_pose_frame_ids, args.input_frame
            )
        )

        # Get corresponding poses
        c2ws = poses_gt[selected_pose_indices]
        image_paths = selected_image_paths

        if len(image_paths) == 0:
            print(f"No images found in {images_dir}")
            continue

        print("🚩Processing", scene, f"Found {len(image_paths)} images")
        all_cam_to_world_mat = []
        all_world_points = []

        try:
            # Load images
            images = load_images_rgb(image_paths)

            # The pipeline needs a minimum of 3 views to reconstruct.
            if not images or len(images) < 3:
                print(f"Skipping {scene}: insufficient valid images")
                continue

            frame_ids = selected_frame_ids
            images_array = np.stack(images)
            vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
            print(f"Patch dimensions: {patch_width}x{patch_height}")

            # Update model attention layers with dynamic patch dimensions
            model.update_patch_dimensions(patch_width, patch_height)

            # Inference + Reconstruction
            (
                extrinsic_np,
                intrinsic_np,
                all_world_points,
                all_point_colors,
                all_cam_to_world_mat,
                inference_time_ms,
            ) = infer_vggt_and_reconstruct(
                model, vgg_input, dtype, args.depth_conf_thresh, image_paths
            )
            print(f"Inference time: {inference_time_ms:.2f}ms")

            # Process results
            if not all_cam_to_world_mat or not all_world_points:
                print(
                    f"Skipping {scene}: failed to obtain valid camera poses or point clouds"
                )
                continue

            # Evaluate and save
            metrics = evaluate_scene_and_save(
                scene,
                c2ws,
                first_gt_pose,
                frame_ids,
                all_cam_to_world_mat,
                all_world_points,
                output_scene_dir,
                args.gt_ply_dir,
                args.chamfer_max_dist,
                inference_time_ms,
                args.plot,
            )
            if metrics is not None:
                # Keep only the scalar metrics of interest for the summary.
                all_scenes_metrics["scenes"][scene] = {
                    key: float(value)
                    for key, value in metrics.items()
                    if key
                    in [
                        "chamfer_distance",
                        "ate",
                        "are",
                        "rpe_rot",
                        "rpe_trans",
                        "inference_time_ms",
                    ]
                }
                print("Complete metrics", all_scenes_metrics["scenes"][scene])

        except Exception as e:
            # Best-effort batch processing: a failed scene is logged and skipped
            # so the remaining scenes are still evaluated.
            print(f"Error processing scene {scene}: {e}")
            import traceback

            traceback.print_exc()

    # Summarize average metrics and save
    compute_average_metrics_and_save(
        all_scenes_metrics,
        args.output_path,
        args.input_frame,
    )
FastVGGT/eval/utils.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.spatial import cKDTree as KDTree
3
+
4
+
5
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """Back-project a depth map into 3D camera-frame coordinates.

    Args:
        depthmap: (H, W) array of depth values along the optical axis.
        camera_intrinsics: 3x3 pinhole intrinsics matrix; skew terms must be 0.
        pseudo_focal: optional (H, W) per-pixel focal length overriding the
            focal entries of the intrinsics.

    Returns:
        An (H, W, 3) float32 point map in camera coordinates and a boolean
        (H, W) mask marking pixels with strictly positive depth.
    """
    K = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Only a skew-free pinhole model is supported.
    assert K[0, 1] == 0.0
    assert K[1, 0] == 0.0
    if pseudo_focal is None:
        fu, fv = K[0, 0], K[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu, cv = K[0, 2], K[1, 2]

    # Pixel grid: u runs along columns, v along rows.
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z_cam = depthmap
    X_cam = np.stack(
        ((u - cu) * z_cam / fu, (v - cv) * z_cam / fv, z_cam), axis=-1
    ).astype(np.float32)

    # Zero (or negative) depth marks an invalid measurement.
    return X_cam, depthmap > 0.0
35
+
36
+
37
def depthmap_to_absolute_camera_coordinates(
    depthmap, camera_intrinsics, camera_pose, **kw
):
    """Back-project a depth map into absolute (world) coordinates.

    Args:
        depthmap: (H, W) array of depth values.
        camera_intrinsics: 3x3 pinhole intrinsics matrix.
        camera_pose: 4x3 or 4x4 cam2world matrix, or None to stay in the
            camera frame.

    Returns:
        An (H, W, 3) point map in world coordinates and a boolean mask of
        valid (positive-depth) pixels.
    """
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    if camera_pose is None:
        # Without a pose the world frame coincides with the camera frame.
        return X_cam, valid_mask

    R_cam2world = camera_pose[:3, :3]
    t_cam2world = camera_pose[:3, 3]

    # Rotate every per-pixel point, then translate into the world frame.
    X_world = (
        np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
    )
    return X_world, valid_mask
61
+
62
+
63
def completion_ratio(gt_points, rec_points, dist_th=0.05):
    """Fraction of GT points with a reconstructed point within `dist_th`.

    Args:
        gt_points: (N, 3) ground-truth point cloud.
        rec_points: (M, 3) reconstructed point cloud.
        dist_th: distance threshold, same units as the point coordinates.

    Returns:
        Scalar in [0, 1]: mean of the indicator (nearest distance < dist_th).
    """
    gen_points_kd_tree = KDTree(rec_points)
    # workers=-1 parallelizes the nearest-neighbour query across all cores,
    # consistent with accuracy() and completion() below; results are identical.
    distances, _ = gen_points_kd_tree.query(gt_points, workers=-1)
    comp_ratio = np.mean((distances < dist_th).astype(np.float32))
    return comp_ratio
68
+
69
+
70
def accuracy(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """Accuracy: distance from each reconstructed point to its nearest GT point.

    Returns:
        (mean, median) of the nearest-neighbour distances. If both normal
        sets are provided, additionally returns the mean and median absolute
        cosine between matched normals, as a 4-tuple.
    """
    nn_dist, nn_idx = KDTree(gt_points).query(rec_points, workers=-1)
    mean_dist = np.mean(nn_dist)
    median_dist = np.median(nn_dist)

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    # Compare each reconstructed normal against the normal of its matched GT
    # point; the absolute value ignores normal orientation.
    cosines = np.abs(np.sum(gt_normals[nn_idx] * rec_normals, axis=-1))
    return mean_dist, median_dist, np.mean(cosines), np.median(cosines)
84
+
85
+
86
def completion(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """Completion: distance from each GT point to its nearest reconstructed point.

    Returns:
        (mean, median) of the nearest-neighbour distances. If both normal
        sets are provided, additionally returns the mean and median absolute
        cosine between matched normals, as a 4-tuple.
    """
    # Note: the tree is built over the *reconstructed* points and queried
    # with the GT points (the mirror image of accuracy()).
    nn_dist, nn_idx = KDTree(rec_points).query(gt_points, workers=-1)
    mean_dist = np.mean(nn_dist)
    median_dist = np.median(nn_dist)

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    # Orientation-agnostic normal agreement for each GT point's match.
    cosines = np.abs(np.sum(gt_normals * rec_normals[nn_idx], axis=-1))
    return mean_dist, median_dist, np.mean(cosines), np.median(cosines)
99
+
100
+
101
def compute_iou(pred_vox, target_vox):
    """Voxel-occupancy IoU between a predicted and a target voxel grid.

    Args:
        pred_vox: voxel grid exposing get_voxels(), where each voxel carries
            a `grid_index` triple (e.g. an open3d VoxelGrid).
        target_vox: same type as pred_vox.

    Returns:
        Intersection-over-union of the occupied voxel index sets as a float.
        Two empty grids are defined to have IoU 1.0 (perfect agreement);
        previously this case raised ZeroDivisionError.
    """
    # Get voxel indices
    v_pred_indices = [voxel.grid_index for voxel in pred_vox.get_voxels()]
    v_target_indices = [voxel.grid_index for voxel in target_vox.get_voxels()]

    # Convert to hashable tuples for set operations.
    v_pred_filled = set(tuple(np.round(x, 4)) for x in v_pred_indices)
    v_target_filled = set(tuple(np.round(x, 4)) for x in v_target_indices)

    # Compute intersection and union
    intersection = v_pred_filled & v_target_filled
    union = v_pred_filled | v_target_filled

    # Guard: both grids empty -> no voxels to disagree on.
    if not union:
        return 1.0
    return len(intersection) / len(union)
117
+
118
+
119
def colmap_to_opencv_intrinsics(K):
    """Convert intrinsics from the Colmap to the OpenCV pixel convention.

    The center of the top-left pixel sits at (0.5, 0.5) in Colmap but at
    (0, 0) in OpenCV, so the principal point shifts by half a pixel.
    The input matrix is left untouched; a shifted copy is returned.
    """
    shifted = K.copy()
    shifted[0, 2] = shifted[0, 2] - 0.5
    shifted[1, 2] = shifted[1, 2] - 0.5
    return shifted
130
+
131
+
132
def opencv_to_colmap_intrinsics(K):
    """Convert intrinsics from the OpenCV to the Colmap pixel convention.

    Inverse of colmap_to_opencv_intrinsics: the principal point is shifted
    by +0.5 pixel on both axes. The input matrix is left untouched; a
    shifted copy is returned.
    """
    shifted = K.copy()
    shifted[0, 2] = shifted[0, 2] + 0.5
    shifted[1, 2] = shifted[1, 2] + 0.5
    return shifted
FastVGGT/merging/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import merge
2
+
3
+ __all__ = ["merge"]
FastVGGT/merging/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (187 Bytes). View file
 
FastVGGT/merging/__pycache__/merge.cpython-310.pyc ADDED
Binary file (7.54 kB). View file
 
FastVGGT/merging/merge.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple, Callable, Optional, Union
3
+
4
+
5
@torch.jit.script
def fast_similarity_chunks(
    a: torch.Tensor, b_transposed: torch.Tensor, chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Chunked best-match search between src and dst token features.

    Computes a @ b_transposed in bfloat16, `chunk_size` src rows at a time,
    and returns for every src token its highest similarity score and the
    index of the best-matching dst token.

    Args:
        a: [B, num_src, C] src-token features.
        b_transposed: [B, C, num_dst] dst-token features (pre-transposed).
        chunk_size: number of src rows processed per matmul.

    Returns:
        (best scores [B, num_src] in a's dtype, best indices [B, num_src]).
    """
    batch, n_src, feat = a.shape
    out_dtype = a.dtype

    # bf16 halves memory traffic and speeds up the matmuls; results are cast
    # back to the caller's dtype.
    a_lp = a.to(torch.bfloat16)
    b_lp = b_transposed.to(torch.bfloat16)
    best_vals = torch.empty(batch, n_src, device=a.device, dtype=out_dtype)
    best_idx = torch.empty(batch, n_src, device=a.device, dtype=torch.long)

    # Walk the src axis in fixed-size windows.
    start = 0
    while start < n_src:
        stop = min(start + chunk_size, n_src)
        scores = torch.bmm(a_lp[:, start:stop, :], b_lp)
        vals, idx = torch.max(scores, dim=2)
        best_vals[:, start:stop] = vals.to(out_dtype)
        best_idx[:, start:stop] = idx
        start = stop
    return best_vals, best_idx
29
+
30
+
31
def do_nothing(
    x: torch.Tensor,
    extra_tensors=None,
    extra_tensors_2=None,
) -> Union[
    torch.Tensor,
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
    """Identity stand-in returned instead of merge/unmerge when r <= 0.

    Passes inputs through untouched, packing any supplied extras into a
    tuple so the shape of the result mirrors the real merge function.
    Note: `extra_tensors_2` is only propagated when `extra_tensors` is
    also provided.
    """
    if extra_tensors is None:
        return x
    if extra_tensors_2 is None:
        return x, extra_tensors
    return x, extra_tensors, extra_tensors_2
46
+
47
+
48
def token_merge_bipartite2d(
    metric: torch.Tensor,
    w: int,
    h: int,
    sx: int,
    sy: int,
    r: int,
    no_rand: bool = False,
    generator: Optional[torch.Generator] = None,
    enable_protection: bool = False,
) -> Tuple[Callable, Callable]:
    """
    Divide tokens into source (src) and destination (dst) groups, and merge r tokens from src to dst.
    dst tokens are selected by randomly choosing one token from each (sx, sy) region.
    Optionally protect the top 10% of tokens from merging based on importance scores.

    The token sequence is assumed to be a concatenation of per-image blocks of
    length (w*h + 5): 5 special/"cls" tokens followed by a w*h patch grid.
    The entire first image and every image's 5 special tokens are always kept
    as dst (never merged away).

    Args:
        - metric [B, N, C]: Tensor for similarity computation, B=batch size, N=token count, C=feature dimension
        - w: Image width in tokens
        - h: Image height in tokens
        - sx: dst stride in x dimension, must divide w evenly
        - sy: dst stride in y dimension, must divide h evenly
        - r: Number of tokens to remove through merging
        - no_rand: If True, disable randomness (use only top-left token)
        - generator: Random number generator if no_rand is False and not None
        - enable_protection: If True, enable importance protection feature

    Returns:
        - (merge, unmerge): Two functions for merging tokens and restoring pre-merge state
    """
    B, N, _ = metric.shape  # Batch size B, total tokens N
    if r <= 0:
        # Nothing to merge: hand back identity functions.
        return do_nothing, do_nothing

    gather = torch.gather

    # Each image contributes 5 special tokens plus a w*h patch grid.
    tokens_per_img = w * h + 5
    num_imgs = N // tokens_per_img
    assert tokens_per_img * num_imgs == N, "Token count doesn't match (w*h+5)*num_imgs"

    with torch.no_grad():
        # Determine whether to compute importance scores based on enable_protection
        if enable_protection:
            # Protect ~10% of tokens, spread evenly over the sequence.
            # NOTE(review): despite the docstring's "importance scores", the
            # protected set here is a uniform stride sample, not score-based.
            num_protected = int(N * 0.1)
            step = max(1, N // num_protected)
            protected_indices = torch.arange(0, N, step, device=metric.device)[
                :num_protected
            ]
        else:
            protected_indices = None
            num_protected = 0

        # Global idx_buffer_seq of length N; -1 indicates dst, 0 indicates src (maintain original logic)
        idx_buffer_seq = torch.zeros(N, device=metric.device, dtype=torch.int64)
        hsy, wsx = h // sy, w // sx  # Number of blocks within each image

        # Mark first image entirely as dst
        if num_imgs > 0:
            idx_buffer_seq[:tokens_per_img] = -1

        # Process other images - fully vectorized batch operations
        if num_imgs > 1:
            # The 5 special tokens of every non-first image are always dst.
            cls_indices = (
                torch.arange(1, num_imgs, device=metric.device) * tokens_per_img
            )
            cls_indices = cls_indices[:, None] + torch.arange(5, device=metric.device)
            idx_buffer_seq[cls_indices.flatten()] = -1
            effective_h = min(hsy * sy, h)
            effective_w = min(wsx * sx, w)
            effective_grid_size = effective_h * effective_w

            if no_rand:
                # Deterministic: no patch token becomes dst in later images.
                base_pattern = torch.zeros(
                    effective_grid_size, device=metric.device, dtype=torch.int64
                )
                grid_starts = (
                    torch.arange(1, num_imgs, device=metric.device) * tokens_per_img + 5
                )
                grid_indices = grid_starts[:, None] + torch.arange(
                    effective_grid_size, device=metric.device
                )
                idx_buffer_seq[grid_indices.flatten()] = base_pattern.repeat(
                    num_imgs - 1
                )
            else:
                # Randomly pick one dst token inside every (sy, sx) block of
                # every non-first image, all in one batched scatter.
                total_other_imgs = num_imgs - 1
                all_rand_idx = torch.randint(
                    sy * sx,
                    size=(total_other_imgs, hsy, wsx),
                    device=metric.device,
                    generator=generator,
                )

                scatter_src = -torch.ones(
                    total_other_imgs, hsy, wsx, device=metric.device, dtype=torch.int64
                )

                idx_buffer_batch = torch.zeros(
                    total_other_imgs,
                    hsy,
                    wsx,
                    sy * sx,
                    device=metric.device,
                    dtype=torch.int64,
                )
                idx_buffer_batch.scatter_(
                    dim=3,
                    index=all_rand_idx.unsqueeze(-1),
                    src=scatter_src.unsqueeze(-1),
                )

                # Unfold the flat (sy*sx) choice back into image layout.
                idx_buffer_batch = (
                    idx_buffer_batch.view(total_other_imgs, hsy, wsx, sy, sx)
                    .transpose(2, 3)
                    .reshape(total_other_imgs, hsy * sy, wsx * sx)
                )

                # Batch fill to target positions - still needs a small loop here, but operations are greatly reduced
                for i in range(total_other_imgs):
                    img_idx = i + 1
                    grid_start = img_idx * tokens_per_img + 5
                    flat_view = idx_buffer_batch[
                        i, :effective_h, :effective_w
                    ].flatten()
                    idx_buffer_seq[grid_start : grid_start + effective_grid_size] = (
                        flat_view
                    )

        # Stable argsort puts all dst tokens (-1) first, src tokens (0) after.
        rand_idx = idx_buffer_seq.reshape(1, -1, 1).argsort(dim=1)
        num_dst_orig = int((idx_buffer_seq == -1).sum())

        # Original src and dst indices
        a_idx_orig = rand_idx[:, num_dst_orig:, :]
        b_idx_orig = rand_idx[:, :num_dst_orig, :]
        a_idx = a_idx_orig
        b_idx = b_idx_orig

        if enable_protection:
            protected_idx = protected_indices.unsqueeze(0).unsqueeze(-1)
            num_protected_actual = protected_idx.shape[1]
        else:
            protected_idx = None
            num_protected_actual = 0

        num_src = a_idx.shape[1]
        num_dst = b_idx.shape[1]

        # Define an internal function to separate src, dst, and protected tokens
        def split(x):
            C = x.shape[-1]

            if enable_protection:
                src = gather(x, dim=1, index=a_idx.expand(B, num_src, C))
                dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
                protected = gather(
                    x, dim=1, index=protected_idx.expand(B, num_protected_actual, C)
                )
                return src, dst, protected
            else:
                src = gather(x, dim=1, index=a_idx.expand(B, num_src, C))
                dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
                return src, dst

        # Compute cosine similarity (normalize first then dot product)
        metric = metric / metric.norm(dim=-1, keepdim=True)
        if enable_protection:
            a, b, protected = split(metric)
        else:
            a, b = split(metric)

        r = min(a.shape[1], r)
        num_src_actual = a.shape[1]
        chunk_size = min(5000, num_src_actual)

        # NOTE(review): these two buffers are immediately overwritten by
        # fast_similarity_chunks below; kept as-is to preserve behavior.
        node_max = torch.empty(B, num_src_actual, device=a.device, dtype=a.dtype)
        node_idx = torch.empty(B, num_src_actual, device=a.device, dtype=torch.long)

        # For every src token find its most similar dst token.
        b_transposed = b.transpose(-1, -2)
        node_max, node_idx = fast_similarity_chunks(a, b_transposed, chunk_size)
        # Rank src tokens by best similarity, most similar first: those are
        # the cheapest to merge away.
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        # If protection is enabled, filter out protected tokens to ensure they are not merged
        if enable_protection:
            # NOTE(review): protection filtering uses batch element 0's
            # ordering for the whole batch — assumes B == 1 (or identical
            # layouts across the batch); confirm against callers.
            src_indices = a_idx[0, :, 0]
            protected_mask_src = torch.isin(src_indices, protected_indices)
            edge_flat = edge_idx[0, :, 0]
            valid_mask = ~protected_mask_src[edge_flat]
            valid_edges = edge_flat[valid_mask]

            valid_count = valid_edges.shape[0]
            r_actual = min(r, valid_count)

            unm_idx = valid_edges[r_actual:].unsqueeze(0).unsqueeze(-1)
            src_idx = valid_edges[:r_actual].unsqueeze(0).unsqueeze(-1)
        else:
            unm_idx = edge_idx[..., r:, :]
            src_idx = edge_idx[..., :r, :]
            r_actual = r

        # Get dst token indices corresponding to each src token to be merged
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
        r = r_actual

    # Define merge function to merge selected src tokens to corresponding dst tokens
    def merge(
        x: torch.Tensor,
        mode: str = "mean",
        extra_tensors=None,
        extra_tensors_2=None,
    ) -> Union[
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        if enable_protection:
            src, dst, protected = split(x)
        else:
            src, dst = split(x)

        n, t1, c = src.shape

        # Extract unmerged src tokens - using actual unm_idx size
        unm_len = unm_idx.shape[1]
        unm = gather(src, dim=-2, index=unm_idx.expand(n, unm_len, c))
        src_len = src_idx.shape[1]
        src = gather(src, dim=-2, index=src_idx.expand(n, src_len, c))
        # Fold each selected src token into its matched dst token.
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, src_len, c), src, reduce=mode)

        # ---------------- Extra tensor processing ----------------
        # Optional companion tensors (e.g. positional data) are merged with
        # the exact same src/dst assignment as the main tensor.
        merged_extra_1 = None
        merged_extra_2 = None
        if extra_tensors is not None:
            E_dim = extra_tensors.shape[-1]
            if enable_protection:
                src_e, dst_e, protected_e = split(extra_tensors)
            else:
                src_e, dst_e = split(extra_tensors)

            # Consistent with main tensor, only select r src tokens to be merged
            src_e_r = gather(src_e, dim=-2, index=src_idx.expand(n, src_len, E_dim))
            unm_e = gather(src_e, dim=-2, index=unm_idx.expand(n, unm_len, E_dim))

            dst_e = dst_e.scatter_reduce(
                -2, dst_idx.expand(n, src_len, E_dim), src_e_r, reduce=mode
            )
            if enable_protection:
                merged_extra_1 = torch.cat([unm_e, dst_e, protected_e], dim=1)
            else:
                merged_extra_1 = torch.cat([unm_e, dst_e], dim=1)

        if extra_tensors_2 is not None:
            E_dim_2 = extra_tensors_2.shape[-1]
            if enable_protection:
                src_e2, dst_e2, protected_e2 = split(extra_tensors_2)
            else:
                src_e2, dst_e2 = split(extra_tensors_2)

            src_e2_r = gather(src_e2, dim=-2, index=src_idx.expand(n, src_len, E_dim_2))
            unm_e2 = gather(src_e2, dim=-2, index=unm_idx.expand(n, unm_len, E_dim_2))

            dst_e2 = dst_e2.scatter_reduce(
                -2, dst_idx.expand(n, src_len, E_dim_2), src_e2_r, reduce=mode
            )
            if enable_protection:
                merged_extra_2 = torch.cat([unm_e2, dst_e2, protected_e2], dim=1)
            else:
                merged_extra_2 = torch.cat([unm_e2, dst_e2], dim=1)

        # Output layout: [unmerged src | merged dst | protected (if any)].
        if enable_protection:
            main_result = torch.cat([unm, dst, protected], dim=1)
        else:
            main_result = torch.cat([unm, dst], dim=1)

        if merged_extra_1 is not None and merged_extra_2 is not None:
            return main_result, merged_extra_1, merged_extra_2
        elif merged_extra_1 is not None:
            return main_result, merged_extra_1
        else:
            return main_result

    # Define unmerge function to restore pre-merge state (for decoder)
    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        dst_len = num_dst
        src_len = src_idx.shape[1]
        # Slice the merged layout back apart: [unm | dst | protected].
        unm = x[..., :unm_len, :]
        dst = x[..., unm_len : unm_len + dst_len, :]

        if enable_protection:
            protected = x[
                ..., unm_len + dst_len : unm_len + dst_len + num_protected_actual, :
            ]

        _, _, c = unm.shape
        # Merged src tokens are reconstructed as copies of their dst token.
        src = gather(dst, dim=-2, index=dst_idx.expand(B, src_len, c))
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        # Scatter every group back to its original sequence position.
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(
            dim=-2,
            index=gather(
                a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx
            ).expand(B, unm_len, c),
            src=unm,
        )

        out.scatter_(
            dim=-2,
            index=gather(
                a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx
            ).expand(B, src_len, c),
            src=src,
        )

        if enable_protection:
            out.scatter_(
                dim=-2,
                index=protected_idx.expand(B, num_protected_actual, c),
                src=protected,
            )

        return out

    return merge, unmerge
FastVGGT/requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.3.1
2
+ torchvision==0.18.1
3
+ numpy==1.26.1
4
+ Pillow
5
+ huggingface_hub
6
+ einops
7
+ safetensors
8
+ evo
9
+ open3d
10
+ matplotlib
11
+ scipy
12
+ opencv-python
13
+ scikit-image
14
+ tqdm
15
+
FastVGGT/vggt/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (132 Bytes). View file
 
FastVGGT/vggt/dependency/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file
 
FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc ADDED
Binary file (1.39 kB). View file
 
FastVGGT/vggt/dependency/distortion.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
def apply_distortion(points, distortion_params):
    """Apply lens distortion to normalized camera coordinates.

    NOTE: this is currently an identity placeholder — the input is returned
    unchanged (and unmodified). Implement the actual distortion model
    (e.g. Brown-Conrady) here if distortion support is needed.

    Args:
        points: Array of normalized camera coordinates.
        distortion_params: Distortion parameters (currently ignored).

    Returns:
        The input ``points``, untouched.
    """
    return points
24
+
25
+
26
def iterative_undistortion(points, distortion_params, max_iter=10):
    """Undo lens distortion via fixed-point iteration.

    NOTE: this is currently an identity placeholder — the input is returned
    unchanged and ``max_iter`` is ignored. Implement the iterative inverse of
    the distortion model here if undistortion support is needed.

    Args:
        points: Array of distorted normalized camera coordinates.
        distortion_params: Distortion parameters (currently ignored).
        max_iter: Maximum number of iterations (currently ignored).

    Returns:
        The input ``points``, untouched.
    """
    return points
40
+
41
+
42
def single_undistortion(points, distortion_params):
    """Undo lens distortion with a single correction step.

    NOTE: this is currently an identity placeholder — the input is returned
    unchanged. Implement the one-step inverse of the distortion model here if
    undistortion support is needed.

    Args:
        points: Array of distorted normalized camera coordinates.
        distortion_params: Distortion parameters (currently ignored).

    Returns:
        The input ``points``, untouched.
    """
    return points
FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc ADDED
Binary file (4.24 kB). View file
 
FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc ADDED
Binary file (3.1 kB). View file
 
FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc ADDED
Binary file (3.41 kB). View file
 
FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.18 kB). View file
 
FastVGGT/vggt/heads/camera_head.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from vggt.layers import Mlp
15
+ from vggt.layers.block import Block
16
+ from vggt.heads.head_act import activate_pose
17
+
18
+
19
+ class CameraHead(nn.Module):
20
+ """
21
+ CameraHead predicts camera parameters from token representations using iterative refinement.
22
+
23
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dim_in: int = 2048,
29
+ trunk_depth: int = 4,
30
+ pose_encoding_type: str = "absT_quaR_FoV",
31
+ num_heads: int = 16,
32
+ mlp_ratio: int = 4,
33
+ init_values: float = 0.01,
34
+ trans_act: str = "linear",
35
+ quat_act: str = "linear",
36
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
37
+ ):
38
+ super().__init__()
39
+
40
+ if pose_encoding_type == "absT_quaR_FoV":
41
+ self.target_dim = 9
42
+ else:
43
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
44
+
45
+ self.trans_act = trans_act
46
+ self.quat_act = quat_act
47
+ self.fl_act = fl_act
48
+ self.trunk_depth = trunk_depth
49
+
50
+ # Build the trunk using a sequence of transformer blocks.
51
+ self.trunk = nn.Sequential(
52
+ *[
53
+ Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
54
+ for _ in range(trunk_depth)
55
+ ]
56
+ )
57
+
58
+ # Normalizations for camera token and trunk output.
59
+ self.token_norm = nn.LayerNorm(dim_in)
60
+ self.trunk_norm = nn.LayerNorm(dim_in)
61
+
62
+ # Learnable empty camera pose token.
63
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
64
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
65
+
66
+ # Module for producing modulation parameters: shift, scale, and a gate.
67
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
68
+
69
+ # Adaptive layer normalization without affine parameters.
70
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
71
+ self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
72
+
73
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
74
+ """
75
+ Forward pass to predict camera parameters.
76
+
77
+ Args:
78
+ aggregated_tokens_list (list): List of token tensors from the network;
79
+ the last tensor is used for prediction.
80
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
81
+
82
+ Returns:
83
+ list: A list of predicted camera encodings (post-activation) from each iteration.
84
+ """
85
+ # Use tokens from the last block for camera prediction.
86
+ tokens = aggregated_tokens_list[-1]
87
+
88
+ # Extract the camera tokens
89
+ pose_tokens = tokens[:, :, 0]
90
+ pose_tokens = self.token_norm(pose_tokens)
91
+
92
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
93
+ return pred_pose_enc_list
94
+
95
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
96
+ """
97
+ Iteratively refine camera pose predictions.
98
+
99
+ Args:
100
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
101
+ num_iterations (int): Number of refinement iterations.
102
+
103
+ Returns:
104
+ list: List of activated camera encodings from each iteration.
105
+ """
106
+ B, S, C = pose_tokens.shape # S is expected to be 1.
107
+ pred_pose_enc = None
108
+ pred_pose_enc_list = []
109
+
110
+ for _ in range(num_iterations):
111
+ # Use a learned empty pose for the first iteration.
112
+ if pred_pose_enc is None:
113
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
114
+ else:
115
+ # Detach the previous prediction to avoid backprop through time.
116
+ pred_pose_enc = pred_pose_enc.detach()
117
+ module_input = self.embed_pose(pred_pose_enc)
118
+
119
+ # Generate modulation parameters and split them into shift, scale, and gate components.
120
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
121
+
122
+ # Adaptive layer normalization and modulation.
123
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
124
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
125
+
126
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
127
+ # Compute the delta update for the pose encoding.
128
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
129
+
130
+ if pred_pose_enc is None:
131
+ pred_pose_enc = pred_pose_enc_delta
132
+ else:
133
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
134
+
135
+ # Apply final activation functions for translation, quaternion, and field-of-view.
136
+ activated_pose = activate_pose(
137
+ pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
138
+ )
139
+ pred_pose_enc_list.append(activated_pose)
140
+
141
+ return pred_pose_enc_list
142
+
143
+
144
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Apply adaptive (FiLM-style) modulation to a tensor.

    Computes ``x * (1 + scale) + shift``; the ``1 +`` keeps the operation an
    identity when ``scale`` and ``shift`` are zero.
    """
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    scaled = x * (scale + 1)
    return scaled + shift
FastVGGT/vggt/heads/dpt_head.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import List, Dict, Tuple, Union, Optional
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from .head_act import activate_head
18
+ from .utils import create_uv_grid, position_grid_to_embed
19
+
20
+
21
class DPTHead(nn.Module):
    """
    DPT Head for dense prediction tasks.

    This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
    backbone and produces dense predictions by fusing multi-scale features.

    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        # NOTE(review): mutable default lists are shared across instances;
        # harmless while they are only read, but worth replacing with tuples.
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [0, 1, 2, 3],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super(DPTHead, self).__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx

        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers for upsampling feature maps: 4x up, 2x up, identity,
        # and 2x down — producing the multi-scale pyramid the fusion expects.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0],
                    out_channels=out_channels[0],
                    kernel_size=4,
                    stride=4,
                    padding=0,
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1],
                    out_channels=out_channels[1],
                    kernel_size=2,
                    stride=2,
                    padding=0,
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3],
                    out_channels=out_channels[3],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            ]
        )

        self.scratch = _make_scratch(out_channels, features, expand=False)

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = nn.Identity()  # Use Identity instead of None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        # The deepest fusion block has no lateral input to add.
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            # Feature mode: keep full channel width, no prediction head.
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
            )
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1,
                head_features_1 // 2,
                kernel_size=3,
                stride=1,
                padding=1,
            )
            conv2_in_channels = head_features_1 // 2

            # Final prediction head: 3x3 conv -> ReLU -> 1x1 conv to output_dim.
            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(
                    conv2_in_channels,
                    head_features_2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.ReLU(inplace=True),
                nn.Conv2d(
                    head_features_2, output_dim, kernel_size=1, stride=1, padding=0
                ),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.
        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or larger than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: Feature maps with shape [B, S, C, H, W]
                - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
        """
        B, S, _, H, W = images.shape

        # If frames_chunk_size is not specified or greater than S, process all frames at once
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Otherwise, process frames in chunks to manage memory usage
        assert frames_chunk_size > 0

        # Process frames in batches
        all_preds = []
        all_conf = []

        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)

            # Process batch of frames
            if self.feature_only:
                chunk_output = self._forward_impl(
                    aggregated_tokens_list,
                    images,
                    patch_start_idx,
                    frames_start_idx,
                    frames_end_idx,
                )
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = self._forward_impl(
                    aggregated_tokens_list,
                    images,
                    patch_start_idx,
                    frames_start_idx,
                    frames_end_idx,
                )
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)

        # Concatenate results along the sequence dimension
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        else:
            return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: Optional[int] = None,
        frames_end_idx: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Implementation of the forward pass through the DPT head.

        This method processes a specific chunk of frames from the sequence.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index for frames to process.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        if frames_start_idx is not None and frames_end_idx is not None:
            images = images[:, frames_start_idx:frames_end_idx].contiguous()

        B, S, _, H, W = images.shape

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        out = []
        dpt_idx = 0

        for layer_idx in self.intermediate_layer_idx:
            # Drop the non-patch (camera/register) tokens.
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            # Select frames if processing a chunk
            if frames_start_idx is not None and frames_end_idx is not None:
                x = x[:, frames_start_idx:frames_end_idx]

            # Fold batch and frame dims together for 2D conv processing.
            x = x.reshape(B * S, -1, x.shape[-1])

            x = self.norm(x)

            # Tokens -> spatial feature map [B*S, C, patch_h, patch_w].
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)
            dpt_idx += 1

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (
                int(patch_h * self.patch_size / self.down_ratio),
                int(patch_w * self.patch_size / self.down_ratio),
            ),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            # Restore the [B, S, ...] layout before returning features.
            return out.view(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(
            out, activation=self.activation, conf_activation=self.conf_activation
        )

        preds = preds.view(B, S, *preds.shape[1:])
        conf = conf.view(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(
        self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
    ) -> torch.Tensor:
        """
        Apply positional embedding to tensor x.

        A UV grid sized to x's spatial dims is embedded into x's channel
        count, scaled by ``ratio``, and added to x.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(
            patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
        )
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        # Scale down so the embedding perturbs rather than dominates features.
        pos_embed = pos_embed * ratio
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): List of feature maps from different layers
                (shallowest to deepest).

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down fusion: start from the deepest level, upsample to the next
        # level's size at each step. Intermediates are deleted eagerly to
        # limit peak memory.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
344
+
345
+
346
+ ################################################################################
347
+ # Modules
348
+ ################################################################################
349
+
350
+
351
def _make_fusion_block(
    features: int,
    size: Optional[int] = None,
    has_residual: bool = True,
    groups: int = 1,
) -> nn.Module:
    """Construct a FeatureFusionBlock configured for this DPT decoder.

    Args:
        features: Channel count of the fused feature maps.
        size: Optional fixed output size for upsampling inside the block.
        has_residual: Whether the block adds a lateral (residual) input.
        groups: Convolution groups used inside the block.

    Returns:
        A configured ``FeatureFusionBlock`` module.
    """
    activation = nn.ReLU(inplace=True)
    block = FeatureFusionBlock(
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
    )
    return block
368
+
369
+
370
+ def _make_scratch(
371
+ in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
372
+ ) -> nn.Module:
373
+ scratch = nn.Module()
374
+ out_shape1 = out_shape
375
+ out_shape2 = out_shape
376
+ out_shape3 = out_shape
377
+ if len(in_shape) >= 4:
378
+ out_shape4 = out_shape
379
+
380
+ if expand:
381
+ out_shape1 = out_shape
382
+ out_shape2 = out_shape * 2
383
+ out_shape3 = out_shape * 4
384
+ if len(in_shape) >= 4:
385
+ out_shape4 = out_shape * 8
386
+
387
+ scratch.layer1_rn = nn.Conv2d(
388
+ in_shape[0],
389
+ out_shape1,
390
+ kernel_size=3,
391
+ stride=1,
392
+ padding=1,
393
+ bias=False,
394
+ groups=groups,
395
+ )
396
+ scratch.layer2_rn = nn.Conv2d(
397
+ in_shape[1],
398
+ out_shape2,
399
+ kernel_size=3,
400
+ stride=1,
401
+ padding=1,
402
+ bias=False,
403
+ groups=groups,
404
+ )
405
+ scratch.layer3_rn = nn.Conv2d(
406
+ in_shape[2],
407
+ out_shape3,
408
+ kernel_size=3,
409
+ stride=1,
410
+ padding=1,
411
+ bias=False,
412
+ groups=groups,
413
+ )
414
+ if len(in_shape) >= 4:
415
+ scratch.layer4_rn = nn.Conv2d(
416
+ in_shape[3],
417
+ out_shape4,
418
+ kernel_size=3,
419
+ stride=1,
420
+ padding=1,
421
+ bias=False,
422
+ groups=groups,
423
+ )
424
+ return scratch
425
+
426
+
427
class ResidualConvUnit(nn.Module):
    """Residual convolution unit: two pre-activated 3x3 convs plus a skip.

    Computes ``x + conv2(act(conv1(act(x))))``. Attribute names (``conv1``,
    ``conv2``, ``norm1``, ``norm2``) are part of the checkpoint format and
    must not change.
    """

    def __init__(self, features, activation, bn, groups=1):
        """
        Args:
            features (int): Number of input/output channels.
            activation (nn.Module): Activation applied before each conv.
            bn (bool): Kept for API compatibility; norm layers stay disabled.
            groups (int): Groups for both convolutions.
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        def conv3x3():
            # Both convs share the same shape-preserving configuration.
            return nn.Conv2d(
                features,
                features,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
                groups=self.groups,
            )

        self.conv1 = conv3x3()
        self.conv2 = conv3x3()

        # Normalization is disabled; attributes kept so forward() stays generic.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation

    def forward(self, x):
        """Apply the residual unit.

        Args:
            x (Tensor): Input of shape [N, features, H, W].

        Returns:
            Tensor: Output with the same shape as the input.
        """
        residual = self.conv1(self.activation(x))
        if self.norm1 is not None:
            residual = self.norm1(residual)

        residual = self.conv2(self.activation(residual))
        if self.norm2 is not None:
            residual = self.norm2(residual)

        return residual + x
485
+
486
+
487
class FeatureFusionBlock(nn.Module):
    """Feature fusion block.

    Optionally adds a lateral input through a residual conv unit, refines the
    sum through a second residual unit, upsamples, and projects with a 1x1
    conv. Attribute names are part of the checkpoint format.
    """

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Init.

        Args:
            features (int): number of features
            activation (nn.Module): activation used inside the residual units.
            deconv (bool): stored but not used by this implementation.
            bn (bool): forwarded to the residual units (norms stay disabled there).
            expand (bool): if True, halve the channel count in the output conv.
            align_corners (bool): passed to the bilinear upsampling.
            size (tuple, optional): fixed upsample size; overridden by the
                ``size`` argument of forward() when given.
            has_residual (bool): whether a lateral input is expected/fused.
            groups (int): convolution groups.
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=self.groups,
        )

        if has_residual:
            # Residual unit applied to the lateral (skip) input only.
            self.resConfUnit1 = ResidualConvUnit(
                features, activation, bn, groups=self.groups
            )

        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(
            features, activation, bn, groups=self.groups
        )

        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Args:
            *xs: one tensor (top-down path) or two tensors
                (top-down path, lateral input) when ``has_residual``.
            size (tuple, optional): target spatial size for upsampling;
                falls back to ``self.size``, then to 2x upscaling.

        Returns:
            tensor: output
        """
        output = xs[0]

        if self.has_residual:
            res = self.resConfUnit1(xs[1])
            output = output + res

        output = self.resConfUnit2(output)

        # Resolve the upsampling target: explicit size > configured size > 2x.
        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = custom_interpolate(
            output, **modifier, mode="bilinear", align_corners=self.align_corners
        )
        output = self.out_conv(output)

        return output
566
+
567
+
568
def custom_interpolate(
    x: torch.Tensor,
    size: Optional[Tuple[int, int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Interpolate that sidesteps INT_MAX element-count overflows in
    nn.functional.interpolate by splitting oversized batches along dim 0.

    Args:
        x: Input of shape [N, C, H, W].
        size: Target (H, W); when None it is derived from ``scale_factor``.
        scale_factor: Spatial scaling used only when ``size`` is None.
        mode: Interpolation mode passed through to the backend.
        align_corners: Passed through to the backend.

    Returns:
        Interpolated tensor of shape [N, C, size[0], size[1]].
    """
    if size is None:
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    INT_MAX = 1610612736

    # Element count of the OUTPUT tensor, which is what overflows internally.
    total_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if total_elements <= INT_MAX:
        return nn.functional.interpolate(
            x, size=size, mode=mode, align_corners=align_corners
        )

    # Too large for a single call: split along the batch dim and reassemble.
    num_chunks = (total_elements // INT_MAX) + 1
    pieces = [
        nn.functional.interpolate(
            piece, size=size, mode=mode, align_corners=align_corners
        )
        for piece in torch.chunk(x, chunks=num_chunks, dim=0)
    ]
    return torch.cat(pieces, dim=0).contiguous()