mystorm committed on
Commit
3bead73
·
verified ·
1 Parent(s): 8df2bdc

Delete FastVGGT

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. FastVGGT/.gitignore +0 -160
  2. FastVGGT/.vscode/launch.json +0 -85
  3. FastVGGT/README.md +0 -163
  4. FastVGGT/assets/attn_map.png +0 -3
  5. FastVGGT/assets/autolab_logo.png +0 -3
  6. FastVGGT/assets/maclab_logo.png +0 -0
  7. FastVGGT/assets/main.png +0 -3
  8. FastVGGT/assets/vs.png +0 -3
  9. FastVGGT/eval/__pycache__/base.cpython-310.pyc +0 -0
  10. FastVGGT/eval/__pycache__/criterion.cpython-310.pyc +0 -0
  11. FastVGGT/eval/__pycache__/data.cpython-310.pyc +0 -0
  12. FastVGGT/eval/__pycache__/data.cpython-37.pyc +0 -0
  13. FastVGGT/eval/__pycache__/utils.cpython-310.pyc +0 -0
  14. FastVGGT/eval/__pycache__/utils.cpython-37.pyc +0 -0
  15. FastVGGT/eval/base.py +0 -273
  16. FastVGGT/eval/criterion.py +0 -534
  17. FastVGGT/eval/data.py +0 -338
  18. FastVGGT/eval/dataset_utils/__init__.py +0 -1
  19. FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc +0 -0
  20. FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc +0 -0
  21. FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc +0 -0
  22. FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc +0 -0
  23. FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc +0 -0
  24. FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc +0 -0
  25. FastVGGT/eval/dataset_utils/corr.py +0 -234
  26. FastVGGT/eval/dataset_utils/cropping.py +0 -140
  27. FastVGGT/eval/dataset_utils/transforms.py +0 -78
  28. FastVGGT/eval/eval_7andN.py +0 -497
  29. FastVGGT/eval/eval_custom.py +0 -467
  30. FastVGGT/eval/eval_scannet.py +0 -208
  31. FastVGGT/eval/utils.py +0 -142
  32. FastVGGT/merging/__init__.py +0 -3
  33. FastVGGT/merging/__pycache__/__init__.cpython-310.pyc +0 -0
  34. FastVGGT/merging/__pycache__/merge.cpython-310.pyc +0 -0
  35. FastVGGT/merging/merge.py +0 -370
  36. FastVGGT/requirements.txt +0 -15
  37. FastVGGT/vggt/__init__.py +0 -5
  38. FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc +0 -0
  39. FastVGGT/vggt/dependency/__init__.py +0 -5
  40. FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc +0 -0
  41. FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc +0 -0
  42. FastVGGT/vggt/dependency/distortion.py +0 -54
  43. FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc +0 -0
  44. FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc +0 -0
  45. FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc +0 -0
  46. FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc +0 -0
  47. FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc +0 -0
  48. FastVGGT/vggt/heads/camera_head.py +0 -149
  49. FastVGGT/vggt/heads/dpt_head.py +0 -598
  50. FastVGGT/vggt/heads/head_act.py +0 -125
FastVGGT/.gitignore DELETED
@@ -1,160 +0,0 @@
1
- .hydra/
2
- output/
3
- ckpt/
4
- .vscode/
5
- dependency/
6
- # Byte-compiled / optimized / DLL files
7
- __pycache__/
8
- **/__pycache__/
9
- *.py[cod]
10
- *$py.class
11
- test_logs/
12
- quick_start_logs/
13
- logs/
14
- *.pth
15
- /data/
16
- *.png
17
- eval_results/
18
- .vscode/
19
- .curosr/
20
-
21
- # C extensions
22
- *.so
23
- LightGlue/
24
- # Distribution / packaging
25
- .Python
26
- build/
27
- develop-eggs/
28
- dist/
29
- downloads/
30
- eggs/
31
- .eggs/
32
- lib/
33
- lib64/
34
- parts/
35
- sdist/
36
- var/
37
- wheels/
38
- pip-wheel-metadata/
39
- share/python-wheels/
40
- *.egg-info/
41
- .installed.cfg
42
- *.egg
43
- MANIFEST
44
-
45
- # PyInstaller
46
- # Usually these files are written by a python script from a template
47
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
48
- *.manifest
49
- *.spec
50
-
51
- # Installer logs
52
- pip-log.txt
53
- pip-delete-this-directory.txt
54
-
55
- # Unit test / coverage reports
56
- htmlcov/
57
- .tox/
58
- .nox/
59
- .coverage
60
- .coverage.*
61
- .cache
62
- nosetests.xml
63
- coverage.xml
64
- *.cover
65
- *.py,cover
66
- .hypothesis/
67
- .pytest_cache/
68
- cover/
69
-
70
- # Translations
71
- *.mo
72
- *.pot
73
-
74
- # Django stuff:
75
- *.log
76
- local_settings.py
77
- db.sqlite3
78
- db.sqlite3-journal
79
-
80
- # Flask stuff:
81
- instance/
82
- .webassets-cache
83
-
84
- # Scrapy stuff:
85
- .scrapy
86
-
87
- # Sphinx documentation
88
- docs/_build/
89
-
90
- # PyBuilder
91
- target/
92
-
93
- # Jupyter Notebook
94
- .ipynb_checkpoints
95
-
96
- # IPython
97
- profile_default/
98
- ipython_config.py
99
-
100
- # pyenv
101
- .python-version
102
-
103
- # pipenv
104
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
105
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
106
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
107
- # install all needed dependencies.
108
- #Pipfile.lock
109
-
110
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow
111
- __pypackages__/
112
-
113
- # Celery stuff
114
- celerybeat-schedule
115
- celerybeat.pid
116
-
117
- # SageMath parsed files
118
- *.sage.py
119
-
120
- # Environments
121
- .env
122
- .venv
123
- env/
124
- venv/
125
- ENV/
126
- env.bak/
127
- venv.bak/
128
-
129
- # Spyder project settings
130
- .spyderproject
131
- .spyproject
132
-
133
- # Rope project settings
134
- .ropeproject
135
-
136
- # mkdocs documentation
137
- /site
138
-
139
- # mypy
140
- .mypy_cache/
141
- .dmypy.json
142
- dmypy.json
143
-
144
- # Pyre type checker
145
- .pyre/
146
-
147
- # pytype static type analyzer
148
- .pytype/
149
-
150
- # Profiling data
151
- .prof
152
-
153
- # Folder specific to your needs
154
- **/tmp/
155
- **/outputs/skyseg.onnx
156
- skyseg.onnx
157
-
158
- # pixi environments
159
- .pixi
160
- *.egg-info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/.vscode/launch.json DELETED
@@ -1,85 +0,0 @@
1
- {
2
- // Use IntelliSense to learn about possible attributes.
3
- // Hover to view descriptions of existing attributes.
4
- // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
- "version": "0.2.0",
6
- "configurations": [
7
-
8
- {
9
- "name": "launch",
10
- "type": "debugpy",
11
- "request": "launch",
12
- "program": "/home/sy/code/vggt_0625/training/launch.py",
13
- "console": "integratedTerminal",
14
- "args": "${command:pickArgs}",
15
- "env": {
16
- "CUDA_VISIBLE_DEVICES": "3",
17
- },
18
- "cwd": "/home/sy/code/vggt_0625/training",
19
- "justMyCode": true,
20
- "python": "/home/sy/anaconda3/envs/vggt/bin/python"
21
- }
22
- ,{
23
- "name": "train_scannet",
24
- "type": "debugpy",
25
- "request": "launch",
26
- "program": "/home/sy/code/vggt_0625/training/launch_scannet.py",
27
- "console": "integratedTerminal",
28
- "args": [
29
- // "--config_name", "scannet",
30
- // "--exp_name", "scannet_exp001",
31
- // "--resume_checkpoint_path", "/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt"
32
- ],
33
- "env": {
34
- "CUDA_VISIBLE_DEVICES": "7",
35
- "WORLD_SIZE": "1",
36
- "RANK": "0",
37
- "MASTER_ADDR": "localhost",
38
- "MASTER_PORT": "12345"
39
- },
40
- "cwd": "/home/sy/code/vggt_0625/training",
41
- "justMyCode": true,
42
- "python": "/home/sy/anaconda3/envs/vggt/bin/python"
43
- }
44
- ,{
45
- "name": "eval_scannet",
46
- "type": "debugpy",
47
- "request": "launch",
48
- "program": "/home/sy/code/FastVGGT/eval/eval_scannet.py",
49
- "console": "integratedTerminal",
50
- "args": [
51
- "--data_dir","/data/sy/scannetv2/process_scannet",
52
- "--gt_ply_dir","/data/sy/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
53
- "--output_path", "/home/sy/code/FastVGGT/eval_results",
54
- "--merging", "0",
55
- "--ckpt_path","/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt",
56
- "--vis_attn_map"
57
- ],
58
- "env": {
59
- "CUDA_VISIBLE_DEVICES": "2"
60
- },
61
- "justMyCode": true,
62
- "python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
63
- },
64
- {
65
- "name": "eval_cd",
66
- "type": "debugpy",
67
- "request": "launch",
68
- "program": "/home/sy/code/FastVGGT/eval/eval_custom.py",
69
- "console": "integratedTerminal",
70
- "args": [
71
- "--merging", "0",
72
- // "--kf","10",
73
- // "--output_dir","/home/sy/code/vggt_0625/eval_results_cd",
74
- "--data_path","/data/sy/segment-102751/",
75
- "--vis_attn_map"
76
- ],
77
- "env": {
78
- "CUDA_VISIBLE_DEVICES": "3"
79
- },
80
- "justMyCode": true,
81
- // "python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
82
- }
83
-
84
- ]
85
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/README.md DELETED
@@ -1,163 +0,0 @@
1
- <div align="center">
2
- <h2>⚡️ FastVGGT: Training-Free Acceleration of Visual Geometry Transformer</h2>
3
-
4
- <p align="center">
5
- <a href="https://arxiv.org/abs/2509.02560"><img src="https://img.shields.io/badge/arXiv-FastVGGT-red?logo=arxiv" alt="Paper PDF"></a>
6
- <a href="https://mystorm16.github.io/fastvggt/"><img src="https://img.shields.io/badge/Project_Page-FastVGGT-yellow" alt="Project Page"></a>
7
- </p>
8
-
9
- <img src="assets/maclab_logo.png" alt="Maclab Logo" width="110" style="margin-right: 40px;">
10
- <img src="assets/autolab_logo.png" alt="Autolab Logo" width="110">
11
-
12
-
13
- **[Media Analytics & Computing Laboratory](https://mac.xmu.edu.cn/)**; **[AUTOLAB](https://zhipengzhang.cn/)**
14
-
15
-
16
- [You Shen](https://mystorm16.github.io/), [Zhipeng Zhang](https://zhipengzhang.cn/), [Yansong Qu](https://quyans.github.io/), [Liujuan Cao](https://mac.xmu.edu.cn/ljcao/)
17
- </div>
18
-
19
-
20
- ## 📰 News
21
- - [Sep 8, 2025] Added custom dataset evaluation.
22
- - [Sep 3, 2025] Paper release.
23
- - [Sep 2, 2025] Code release.
24
-
25
- ## 🔭 Overview
26
-
27
- FastVGGT observes **strong similarity** in attention maps and leverages it to design a training-free acceleration method for long-sequence 3D reconstruction, **achieving up to 4× faster inference without sacrificing accuracy.**
28
-
29
- <img src="assets/main.png" alt="FastVGGT overview" width="">
30
-
31
-
32
- ## ⚙️ Environment Setup
33
- First, create a virtual environment using Conda, clone this repository to your local machine, and install the required dependencies.
34
-
35
-
36
- ```bash
37
- conda create -n fastvggt python=3.10
38
- conda activate fastvggt
39
- git clone git@github.com:mystorm16/FastVGGT.git
40
- cd FastVGGT
41
- pip install -r requirements.txt
42
- ```
43
-
44
- Next, prepare the ScanNet dataset: http://www.scan-net.org/ScanNet/
45
-
46
- Then, download the VGGT checkpoint (we use the checkpoint link provided in https://github.com/facebookresearch/vggt/tree/evaluation/evaluation):
47
- ```bash
48
- wget https://huggingface.co/facebook/VGGT_tracker_fixed/resolve/main/model_tracker_fixed_e20.pt
49
- ```
50
-
51
- Finally, configure the dataset path and VGGT checkpoint path. For example:
52
- ```python
53
- parser.add_argument(
54
- "--data_dir", type=Path, default="/data/scannetv2/process_scannet"
55
- )
56
- parser.add_argument(
57
- "--gt_ply_dir",
58
- type=Path,
59
- default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
60
- )
61
- parser.add_argument(
62
- "--ckpt_path",
63
- type=str,
64
- default="./ckpt/model_tracker_fixed_e20.pt",
65
- )
66
- ```
67
-
68
-
69
- ## 💎 Observation
70
-
71
- Note: A large number of input_frames may significantly slow down saving the visualization results. Please try using a smaller number first.
72
- ```bash
73
- python eval/eval_scannet.py --input_frame 30 --vis_attn_map --merging 0
74
- ```
75
-
76
- We observe that many token-level attention maps are highly similar in each block, motivating our optimization of the Global Attention module.
77
-
78
- <img src="assets/attn_map.png" alt="Attention map visualization" width="">
79
-
80
-
81
-
82
- ## 🏀 Evaluation
83
- ### Custom Dataset
84
- Please organize the data according to the following directory:
85
- ```
86
- <data_path>/
87
- ├── images/
88
- │ ├── 000000.jpg
89
- │ ├── 000001.jpg
90
- │ └── ...
91
- ├── pose/ # Optional: Camera poses
92
- │ ├── 000000.txt
93
- │ ├── 000001.txt
94
- │ └── ...
95
- └── gt_ply/ # Optional: GT point cloud
96
- └── scene_xxx.ply
97
- ```
98
- - Required: `images/`
99
- - Additionally required when `--enable_evaluation` is enabled: `pose/` and `gt_ply/`
100
-
101
- Inference only:
102
-
103
- ```bash
104
- python eval/eval_custom.py \
105
- --data_path /path/to/your_dataset \
106
- --output_path ./eval_results_custom \
107
- --plot
108
- ```
109
-
110
- Inference + Evaluation (requires `pose/` and `gt_ply/`):
111
-
112
- ```bash
113
- python eval/eval_custom.py \
114
- --data_path /path/to/your_dataset \
115
- --enable_evaluation \
116
- --output_path ./eval_results_custom \
117
- --plot
118
- ```
119
-
120
- ### ScanNet
121
- Evaluate FastVGGT on the ScanNet dataset with 1,000 input images. The **--merging** parameter specifies the block index at which the merging strategy is applied:
122
-
123
- ```bash
124
- python eval/eval_scannet.py --input_frame 1000 --merging 0
125
- ```
126
-
127
- Evaluate Baseline VGGT on the ScanNet dataset with 1,000 input images:
128
- ```bash
129
- python eval/eval_scannet.py --input_frame 1000
130
- ```
131
- <img src="assets/vs.png" alt="FastVGGT vs. baseline VGGT comparison" width="">
132
-
133
- ### 7 Scenes & NRGBD
134
- Evaluate across two datasets, sampling keyframes every 10 frames:
135
- ```bash
136
- python eval/eval_7andN.py --kf 10
137
- ```
138
-
139
- ## 🍺 Acknowledgements
140
-
141
- - Thanks to these great repositories: [VGGT](https://github.com/facebookresearch/vggt), [Dust3r](https://github.com/naver/dust3r), [Fast3R](https://github.com/facebookresearch/fast3r), [CUT3R](https://github.com/CUT3R/CUT3R), [MV-DUSt3R+](https://github.com/facebookresearch/mvdust3r), [StreamVGGT](https://github.com/wzzheng/StreamVGGT), [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long), [ToMeSD](https://github.com/dbolya/tomesd) and many other inspiring works in the community.
142
-
143
- - Special thanks to [Jianyuan Wang](https://jytime.github.io/) for his valuable discussions and suggestions on this work.
144
-
145
- <!-- ## ✍️ Checklist
146
-
147
- - [ ] Release the evaluation code on 7 Scenes / NRGBD -->
148
-
149
-
150
- ## ⚖️ License
151
- See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available.
152
-
153
- ## Citation
154
-
155
- If you find this project helpful, please consider citing the following paper:
156
- ```
157
- @article{shen2025fastvggt,
158
- title={FastVGGT: Training-Free Acceleration of Visual Geometry Transformer},
159
- author={Shen, You and Zhang, Zhipeng and Qu, Yansong and Cao, Liujuan},
160
- journal={arXiv preprint arXiv:2509.02560},
161
- year={2025}
162
- }
163
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/assets/attn_map.png DELETED

Git LFS Details

  • SHA256: 8477957f593c203bcf41df91ac3ed0d22329e22250fdab9f8f8674340964242c
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
FastVGGT/assets/autolab_logo.png DELETED

Git LFS Details

  • SHA256: 4fcead3160cbf561c4385cc8b938a17a94652e3d849da6497f053d32d1245596
  • Pointer size: 132 Bytes
  • Size of remote file: 5.13 MB
FastVGGT/assets/maclab_logo.png DELETED
Binary file (4.8 kB)
 
FastVGGT/assets/main.png DELETED

Git LFS Details

  • SHA256: eecacb414647f01dc8a52b4aba5ff2556733f46d1b9129613e3f59aceff69685
  • Pointer size: 131 Bytes
  • Size of remote file: 884 kB
FastVGGT/assets/vs.png DELETED

Git LFS Details

  • SHA256: dac1ea5397b9f985c6fb86fc5f321313cc5a372d0917b7c1c1c7dd2cea6bec5f
  • Pointer size: 131 Bytes
  • Size of remote file: 230 kB
FastVGGT/eval/__pycache__/base.cpython-310.pyc DELETED
Binary file (6.92 kB)
 
FastVGGT/eval/__pycache__/criterion.cpython-310.pyc DELETED
Binary file (13.6 kB)
 
FastVGGT/eval/__pycache__/data.cpython-310.pyc DELETED
Binary file (7.78 kB)
 
FastVGGT/eval/__pycache__/data.cpython-37.pyc DELETED
Binary file (8.03 kB)
 
FastVGGT/eval/__pycache__/utils.cpython-310.pyc DELETED
Binary file (3.99 kB)
 
FastVGGT/eval/__pycache__/utils.cpython-37.pyc DELETED
Binary file (4.32 kB)
 
FastVGGT/eval/base.py DELETED
@@ -1,273 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # base class for implementing datasets
6
- # --------------------------------------------------------
7
- import PIL
8
- import numpy as np
9
- import torch
10
-
11
- from dataset_utils.transforms import ImgNorm
12
- import dataset_utils.cropping as cropping
13
- from utils import depthmap_to_absolute_camera_coordinates
14
-
15
-
16
- class BaseStereoViewDataset:
17
- """Define all basic options.
18
-
19
- Usage:
20
- class MyDataset (BaseStereoViewDataset):
21
- def _get_views(self, idx, rng):
22
- # overload here
23
- views = []
24
- views.append(dict(img=, ...))
25
- return views
26
- """
27
-
28
- def __init__(
29
- self,
30
- *, # only keyword arguments
31
- split=None,
32
- resolution=None, # square_size or (width, height) or list of [(width,height), ...]
33
- transform=ImgNorm,
34
- aug_crop=False,
35
- seed=None,
36
- ):
37
- self.num_views = 2
38
- self.split = split
39
- self._set_resolutions(resolution)
40
-
41
- self.transform = transform
42
- if isinstance(transform, str):
43
- transform = eval(transform)
44
-
45
- self.aug_crop = aug_crop
46
- self.seed = seed
47
-
48
- def __len__(self):
49
- return len(self.scenes)
50
-
51
- def get_stats(self):
52
- return f"{len(self)} pairs"
53
-
54
- def __repr__(self):
55
- resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
56
- return (
57
- f"""{type(self).__name__}({self.get_stats()},
58
- {self.split=},
59
- {self.seed=},
60
- resolutions={resolutions_str},
61
- {self.transform=})""".replace(
62
- "self.", ""
63
- )
64
- .replace("\n", "")
65
- .replace(" ", "")
66
- )
67
-
68
- def _get_views(self, idx, resolution, rng):
69
- raise NotImplementedError()
70
-
71
- def __getitem__(self, idx):
72
- if isinstance(idx, tuple):
73
- # the idx is specifying the aspect-ratio
74
- idx, ar_idx = idx
75
- else:
76
- assert len(self._resolutions) == 1
77
- ar_idx = 0
78
-
79
- # set-up the rng
80
- if self.seed: # reseed for each __getitem__
81
- self._rng = np.random.default_rng(seed=self.seed + idx)
82
- elif not hasattr(self, "_rng"):
83
- seed = torch.initial_seed() # this is different for each dataloader process
84
- self._rng = np.random.default_rng(seed=seed)
85
-
86
- # over-loaded code
87
- resolution = self._resolutions[
88
- ar_idx
89
- ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
90
- views = self._get_views(idx, resolution, self._rng)
91
-
92
- # check data-types
93
- for v, view in enumerate(views):
94
- assert (
95
- "pts3d" not in view
96
- ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
97
- view["idx"] = v
98
-
99
- # encode the image
100
- width, height = view["img"].size
101
- view["true_shape"] = np.int32((height, width))
102
- view["img"] = self.transform(view["img"])
103
-
104
- assert "camera_intrinsics" in view
105
- if "camera_pose" not in view:
106
- view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
107
- else:
108
- assert np.isfinite(
109
- view["camera_pose"]
110
- ).all(), f"NaN in camera pose for view {view_name(view)}"
111
- assert "pts3d" not in view
112
- assert "valid_mask" not in view
113
- assert np.isfinite(
114
- view["depthmap"]
115
- ).all(), f"NaN in depthmap for view {view_name(view)}"
116
- pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
117
-
118
- view["pts3d"] = pts3d
119
- view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
120
-
121
- # check all datatypes
122
- for key, val in view.items():
123
- res, err_msg = is_good_type(key, val)
124
- assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
125
- K = view["camera_intrinsics"]
126
- view["img_mask"] = True
127
- view["ray_mask"] = False
128
- view["ray_map"] = torch.full(
129
- (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
130
- )
131
- view["update"] = True
132
- view["reset"] = False
133
-
134
- # last thing done!
135
- for view in views:
136
- # transpose to make sure all views are the same size
137
- transpose_to_landscape(view)
138
- # this allows checking whether the RNG is in the same state each time
139
- view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
140
- return views
141
-
142
- def _set_resolutions(self, resolutions):
143
- """Set the resolution(s) of the dataset.
144
- Params:
145
- - resolutions: int or tuple or list of tuples
146
- """
147
- assert resolutions is not None, "undefined resolution"
148
-
149
- if not isinstance(resolutions, list):
150
- resolutions = [resolutions]
151
-
152
- self._resolutions = []
153
- for resolution in resolutions:
154
- if isinstance(resolution, int):
155
- width = height = resolution
156
- else:
157
- width, height = resolution
158
- assert isinstance(
159
- width, int
160
- ), f"Bad type for {width=} {type(width)=}, should be int"
161
- assert isinstance(
162
- height, int
163
- ), f"Bad type for {height=} {type(height)=}, should be int"
164
- assert width >= height
165
- self._resolutions.append((width, height))
166
-
167
- def _crop_resize_if_necessary(
168
- self, image, depthmap, intrinsics, resolution, rng=None, info=None
169
- ):
170
- """This function:
171
- - first downsizes the image with LANCZOS interpolation,
172
- which is better than bilinear interpolation in
173
- """
174
- if not isinstance(image, PIL.Image.Image):
175
- image = PIL.Image.fromarray(image)
176
-
177
- # downscale with lanczos interpolation so that image.size == resolution
178
- # cropping centered on the principal point
179
- W, H = image.size
180
- cx, cy = intrinsics[:2, 2].round().astype(int)
181
-
182
- # calculate min distance to margin
183
- min_margin_x = min(cx, W - cx)
184
- min_margin_y = min(cy, H - cy)
185
- assert min_margin_x > W / 5, f"Bad principal point in view={info}"
186
- assert min_margin_y > H / 5, f"Bad principal point in view={info}"
187
-
188
- ## Center crop
189
- # Crop on the principal point, make it always centered
190
- # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
191
- l, t = cx - min_margin_x, cy - min_margin_y
192
- r, b = cx + min_margin_x, cy + min_margin_y
193
- crop_bbox = (l, t, r, b)
194
-
195
- image, depthmap, intrinsics = cropping.crop_image_depthmap(
196
- image, depthmap, intrinsics, crop_bbox
197
- )
198
-
199
- # # transpose the resolution if necessary
200
- W, H = image.size # new size
201
- assert resolution[0] >= resolution[1]
202
- if H > 1.1 * W:
203
- # image is portrait mode
204
- resolution = resolution[::-1]
205
- elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
206
- # image is square, so we chose (portrait, landscape) randomly
207
- if rng.integers(2):
208
- resolution = resolution[::-1]
209
-
210
- # high-quality Lanczos down-scaling
211
- target_resolution = np.array(resolution)
212
- # # if self.aug_crop > 1:
213
- # # target_resolution += rng.integers(0, self.aug_crop)
214
- # if resolution != (224, 224):
215
- # halfw, halfh = ((2*(W//2))//16)*8, ((2*(H//2))//16)*8
216
- # ## Rescale with max factor, so one of width or height might be larger than target_resolution
217
- # image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, (2*halfw, 2*halfh))
218
- # else:
219
- image, depthmap, intrinsics = cropping.rescale_image_depthmap(
220
- image, depthmap, intrinsics, target_resolution
221
- )
222
- # actual cropping (if necessary) with bilinear interpolation
223
- # if resolution == (224, 224):
224
- intrinsics2 = cropping.camera_matrix_of_crop(
225
- intrinsics, image.size, resolution, offset_factor=0.5
226
- )
227
- crop_bbox = cropping.bbox_from_intrinsics_in_out(
228
- intrinsics, intrinsics2, resolution
229
- )
230
- image, depthmap, intrinsics = cropping.crop_image_depthmap(
231
- image, depthmap, intrinsics, crop_bbox
232
- )
233
- return image, depthmap, intrinsics
234
-
235
-
236
- def is_good_type(key, v):
237
- """returns (is_good, err_msg)"""
238
- if isinstance(v, (str, int, tuple)):
239
- return True, None
240
- if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
241
- return False, f"bad {v.dtype=}"
242
- return True, None
243
-
244
-
245
- def view_name(view, batch_index=None):
246
- def sel(x):
247
- return x[batch_index] if batch_index not in (None, slice(None)) else x
248
-
249
- db = sel(view["dataset"])
250
- label = sel(view["label"])
251
- instance = sel(view["instance"])
252
- return f"{db}/{label}/{instance}"
253
-
254
-
255
- def transpose_to_landscape(view):
256
- height, width = view["true_shape"]
257
-
258
- if width < height:
259
- # rectify portrait to landscape
260
- assert view["img"].shape == (3, height, width)
261
- view["img"] = view["img"].swapaxes(1, 2)
262
-
263
- assert view["valid_mask"].shape == (height, width)
264
- view["valid_mask"] = view["valid_mask"].swapaxes(0, 1)
265
-
266
- assert view["depthmap"].shape == (height, width)
267
- view["depthmap"] = view["depthmap"].swapaxes(0, 1)
268
-
269
- assert view["pts3d"].shape == (height, width, 3)
270
- view["pts3d"] = view["pts3d"].swapaxes(0, 1)
271
-
272
- # transpose x and y pixels
273
- view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/eval/criterion.py DELETED
@@ -1,534 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from copy import copy, deepcopy
4
-
5
- from eval.dataset_utils.corr import geotrf, inv
6
-
7
-
8
- def invalid_to_nans(arr, valid_mask, ndim=999):
9
- if valid_mask is not None:
10
- arr = arr.clone()
11
- arr[~valid_mask] = float("nan")
12
- if arr.ndim > ndim:
13
- arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
14
- return arr
15
-
16
-
17
- def invalid_to_zeros(arr, valid_mask, ndim=999):
18
- if valid_mask is not None:
19
- arr = arr.clone()
20
- arr[~valid_mask] = 0
21
- nnz = valid_mask.view(len(valid_mask), -1).sum(1)
22
- else:
23
- nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
24
- if arr.ndim > ndim:
25
- arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
26
- return arr, nnz
27
-
28
-
29
- class BaseCriterion(nn.Module):
30
- def __init__(self, reduction="mean"):
31
- super().__init__()
32
- self.reduction = reduction
33
-
34
-
35
- class Criterion(nn.Module):
36
- def __init__(self, criterion=None):
37
- super().__init__()
38
- assert isinstance(
39
- criterion, BaseCriterion
40
- ), f"{criterion} is not a proper criterion!"
41
- self.criterion = copy(criterion)
42
-
43
- def get_name(self):
44
- return f"{type(self).__name__}({self.criterion})"
45
-
46
- def with_reduction(self, mode="none"):
47
- res = loss = deepcopy(self)
48
- while loss is not None:
49
- assert isinstance(loss, Criterion)
50
- loss.criterion.reduction = mode # make it return the loss for each sample
51
- loss = loss._loss2 # we assume loss is a Multiloss
52
- return res
53
-
54
-
55
- class MultiLoss(nn.Module):
56
- """Easily combinable losses (also keep track of individual loss values):
57
- loss = MyLoss1() + 0.1*MyLoss2()
58
- Usage:
59
- Inherit from this class and override get_name() and compute_loss()
60
- """
61
-
62
- def __init__(self):
63
- super().__init__()
64
- self._alpha = 1
65
- self._loss2 = None
66
-
67
- def compute_loss(self, *args, **kwargs):
68
- raise NotImplementedError()
69
-
70
- def get_name(self):
71
- raise NotImplementedError()
72
-
73
- def __mul__(self, alpha):
74
- assert isinstance(alpha, (int, float))
75
- res = copy(self)
76
- res._alpha = alpha
77
- return res
78
-
79
- __rmul__ = __mul__ # same
80
-
81
- def __add__(self, loss2):
82
- assert isinstance(loss2, MultiLoss)
83
- res = cur = copy(self)
84
-
85
- while cur._loss2 is not None:
86
- cur = cur._loss2
87
- cur._loss2 = loss2
88
- return res
89
-
90
- def __repr__(self):
91
- name = self.get_name()
92
- if self._alpha != 1:
93
- name = f"{self._alpha:g}*{name}"
94
- if self._loss2:
95
- name = f"{name} + {self._loss2}"
96
- return name
97
-
98
- def forward(self, *args, **kwargs):
99
- loss = self.compute_loss(*args, **kwargs)
100
- if isinstance(loss, tuple):
101
- loss, details = loss
102
- elif loss.ndim == 0:
103
- details = {self.get_name(): float(loss)}
104
- else:
105
- details = {}
106
- loss = loss * self._alpha
107
-
108
- if self._loss2:
109
- loss2, details2 = self._loss2(*args, **kwargs)
110
- loss = loss + loss2
111
- details |= details2
112
-
113
- return loss, details
114
-
115
-
116
- class LLoss(BaseCriterion):
117
- """L-norm loss"""
118
-
119
- def forward(self, a, b):
120
- assert (
121
- a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
122
- ), f"Bad shape = {a.shape}"
123
- dist = self.distance(a, b)
124
-
125
- if self.reduction == "none":
126
- return dist
127
- if self.reduction == "sum":
128
- return dist.sum()
129
- if self.reduction == "mean":
130
- return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
131
- raise ValueError(f"bad {self.reduction=} mode")
132
-
133
- def distance(self, a, b):
134
- raise NotImplementedError()
135
-
136
-
137
- class L21Loss(LLoss):
138
- """Euclidean distance between 3d points"""
139
-
140
- def distance(self, a, b):
141
- return torch.norm(a - b, dim=-1) # normalized L2 distance
142
-
143
-
144
- L21 = L21Loss()
145
-
146
-
147
- def get_pred_pts3d(gt, pred, use_pose=False):
148
- assert use_pose is True
149
- return pred["pts3d_in_other_view"] # return!
150
-
151
-
152
- def Sum(losses, masks, conf=None):
153
- loss, mask = losses[0], masks[0]
154
- if loss.ndim > 0:
155
- # we are actually returning the loss for every pixels
156
- if conf is not None:
157
- return losses, masks, conf
158
- return losses, masks
159
- else:
160
- # we are returning the global loss
161
- for loss2 in losses[1:]:
162
- loss = loss + loss2
163
- return loss
164
-
165
-
166
- def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
167
- assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3
168
- assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3)
169
- norm_mode, dis_mode = norm_mode.split("_")
170
-
171
- nan_pts = []
172
- nnzs = []
173
-
174
- if norm_mode == "avg":
175
- # gather all points together (joint normalization)
176
-
177
- for i, pt in enumerate(pts):
178
- nan_pt, nnz = invalid_to_zeros(pt, valids[i], ndim=3)
179
- nan_pts.append(nan_pt)
180
- nnzs.append(nnz)
181
-
182
- if fix_first:
183
- break
184
- all_pts = torch.cat(nan_pts, dim=1)
185
-
186
- # compute distance to origin
187
- all_dis = all_pts.norm(dim=-1)
188
- if dis_mode == "dis":
189
- pass # do nothing
190
- elif dis_mode == "log1p":
191
- all_dis = torch.log1p(all_dis)
192
- else:
193
- raise ValueError(f"bad {dis_mode=}")
194
-
195
- norm_factor = all_dis.sum(dim=1) / (torch.cat(nnzs).sum() + 1e-8)
196
- else:
197
- raise ValueError(f"Not implemented {norm_mode=}")
198
-
199
- norm_factor = norm_factor.clip(min=1e-8)
200
- while norm_factor.ndim < pts[0].ndim:
201
- norm_factor.unsqueeze_(-1)
202
-
203
- return norm_factor
204
-
205
-
206
- def normalize_pointcloud_t(
207
- pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False
208
- ):
209
- if gt:
210
- norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
211
- res = []
212
-
213
- for i, pt in enumerate(pts):
214
- res.append(pt / norm_factor)
215
-
216
- else:
217
- # pts_l, pts_r = pts
218
- # use pts_l and pts_r[-1] as pts to normalize
219
- norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
220
-
221
- res = []
222
-
223
- for i in range(len(pts)):
224
- res.append(pts[i] / norm_factor)
225
- # res_r.append(pts_r[i] / norm_factor)
226
-
227
- # res = [res_l, res_r]
228
-
229
- return res, norm_factor
230
-
231
-
232
- @torch.no_grad()
233
- def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
234
- # set invalid points to NaN
235
- _zs = []
236
- for i in range(len(zs)):
237
- valid_mask = valid_masks[i] if valid_masks is not None else None
238
- _z = invalid_to_nans(zs[i], valid_mask).reshape(len(zs[i]), -1)
239
- _zs.append(_z)
240
-
241
- _zs = torch.cat(_zs, dim=-1)
242
-
243
- # compute median depth overall (ignoring nans)
244
- if quantile == 0.5:
245
- shift_z = torch.nanmedian(_zs, dim=-1).values
246
- else:
247
- shift_z = torch.nanquantile(_zs, quantile, dim=-1)
248
- return shift_z # (B,)
249
-
250
-
251
- @torch.no_grad()
252
- def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
253
- # set invalid points to NaN
254
-
255
- _pts = []
256
- for i in range(len(pts)):
257
- valid_mask = valid_masks[i] if valid_masks is not None else None
258
- _pt = invalid_to_nans(pts[i], valid_mask).reshape(len(pts[i]), -1, 3)
259
- _pts.append(_pt)
260
-
261
- _pts = torch.cat(_pts, dim=1)
262
-
263
- # compute median center
264
- _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
265
- if z_only:
266
- _center[..., :2] = 0 # do not center X and Y
267
-
268
- # compute median norm
269
- _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
270
- scale = torch.nanmedian(_norm, dim=1).values
271
- return _center[:, None, :, :], scale[:, None, None, None]
272
-
273
-
274
- class Regr3D_t(Criterion, MultiLoss):
275
- def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True):
276
- super().__init__(criterion)
277
- self.norm_mode = norm_mode
278
- self.gt_scale = gt_scale
279
- self.fix_first = fix_first
280
-
281
- def get_all_pts3d_t(self, gts, preds, dist_clip=None):
282
- # everything is normalized w.r.t. camera of view1
283
- in_camera1 = inv(gts[0]["camera_pose"])
284
-
285
- gt_pts = []
286
- valids = []
287
- pr_pts = []
288
-
289
- for i, gt in enumerate(gts):
290
- # in_camera1: Bs, 4, 4 gt['pts3d']: Bs, H, W, 3
291
- gt_pts.append(geotrf(in_camera1, gt["pts3d"]))
292
- valid = gt["valid_mask"].clone()
293
-
294
- if dist_clip is not None:
295
- # points that are too far-away == invalid
296
- dis = gt["pts3d"].norm(dim=-1)
297
- valid = valid & (dis <= dist_clip)
298
-
299
- valids.append(valid)
300
- pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True))
301
- # if i != len(gts)-1:
302
- # pr_pts_l.append(get_pred_pts3d(gt, preds[i][0], use_pose=(i!=0)))
303
-
304
- # if i != 0:
305
- # pr_pts_r.append(get_pred_pts3d(gt, preds[i-1][1], use_pose=(i!=0)))
306
-
307
- # pr_pts = (pr_pts_l, pr_pts_r)
308
-
309
- if self.norm_mode:
310
- pr_pts, pr_factor = normalize_pointcloud_t(
311
- pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False
312
- )
313
- else:
314
- pr_factor = None
315
-
316
- if self.norm_mode and not self.gt_scale:
317
- gt_pts, gt_factor = normalize_pointcloud_t(
318
- gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True
319
- )
320
- else:
321
- gt_factor = None
322
-
323
- return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}
324
-
325
- def compute_frame_loss(self, gts, preds, **kw):
326
- gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
327
- self.get_all_pts3d_t(gts, preds, **kw)
328
- )
329
-
330
- pred_pts_l, pred_pts_r = pred_pts
331
-
332
- loss_all = []
333
- mask_all = []
334
- conf_all = []
335
-
336
- loss_left = 0
337
- loss_right = 0
338
- pred_conf_l = 0
339
- pred_conf_r = 0
340
-
341
- for i in range(len(gt_pts)):
342
-
343
- # Left (Reference)
344
- if i != len(gt_pts) - 1:
345
- frame_loss = self.criterion(
346
- pred_pts_l[i][masks[i]], gt_pts[i][masks[i]]
347
- )
348
-
349
- loss_all.append(frame_loss)
350
- mask_all.append(masks[i])
351
- conf_all.append(preds[i][0]["conf"])
352
-
353
- # To compare target/reference loss
354
- if i != 0:
355
- loss_left += frame_loss.cpu().detach().numpy().mean()
356
- pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy().mean()
357
-
358
- # Right (Target)
359
- if i != 0:
360
- frame_loss = self.criterion(
361
- pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]]
362
- )
363
-
364
- loss_all.append(frame_loss)
365
- mask_all.append(masks[i])
366
- conf_all.append(preds[i - 1][1]["conf"])
367
-
368
- # To compare target/reference loss
369
- if i != len(gt_pts) - 1:
370
- loss_right += frame_loss.cpu().detach().numpy().mean()
371
- pred_conf_r += preds[i - 1][1]["conf"].cpu().detach().numpy().mean()
372
-
373
- if pr_factor is not None and gt_factor is not None:
374
- filter_factor = pr_factor[pr_factor > gt_factor]
375
- else:
376
- filter_factor = []
377
-
378
- if len(filter_factor) > 0:
379
- factor_loss = (filter_factor - gt_factor).abs().mean()
380
- else:
381
- factor_loss = 0.0
382
-
383
- self_name = type(self).__name__
384
- details = {
385
- self_name + "_pts3d_1": float(loss_all[0].mean()),
386
- self_name + "_pts3d_2": float(loss_all[1].mean()),
387
- self_name + "loss_left": float(loss_left),
388
- self_name + "loss_right": float(loss_right),
389
- self_name + "conf_left": float(pred_conf_l),
390
- self_name + "conf_right": float(pred_conf_r),
391
- }
392
-
393
- return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss
394
-
395
-
396
- class ConfLoss_t(MultiLoss):
397
- """Weighted regression by learned confidence.
398
- Assuming the input pixel_loss is a pixel-level regression loss.
399
-
400
- Principle:
401
- high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
402
- low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)
403
-
404
- alpha: hyperparameter
405
- """
406
-
407
- def __init__(self, pixel_loss, alpha=1):
408
- super().__init__()
409
- assert alpha > 0
410
- self.alpha = alpha
411
- self.pixel_loss = pixel_loss.with_reduction("none")
412
-
413
- def get_name(self):
414
- return f"ConfLoss({self.pixel_loss})"
415
-
416
- def get_conf_log(self, x):
417
- return x, torch.log(x)
418
-
419
- def compute_frame_loss(self, gts, preds, **kw):
420
- # compute per-pixel loss
421
- (losses, masks, confs), details, loss_factor = (
422
- self.pixel_loss.compute_frame_loss(gts, preds, **kw)
423
- )
424
-
425
- # weight by confidence
426
- conf_losses = []
427
- conf_sum = 0
428
- for i in range(len(losses)):
429
- conf, log_conf = self.get_conf_log(confs[i][masks[i]])
430
- conf_sum += conf.mean()
431
- conf_loss = losses[i] * conf - self.alpha * log_conf
432
- conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0
433
- conf_losses.append(conf_loss)
434
-
435
- conf_losses = torch.stack(conf_losses) * 2.0
436
- conf_loss_mean = conf_losses.mean()
437
-
438
- return (
439
- conf_loss_mean,
440
- dict(
441
- conf_loss_1=float(conf_losses[0]),
442
- conf_loss2=float(conf_losses[1]),
443
- conf_mean=conf_sum / len(losses),
444
- **details,
445
- ),
446
- loss_factor,
447
- )
448
-
449
-
450
- class Regr3D_t_ShiftInv(Regr3D_t):
451
- """Same than Regr3D but invariant to depth shift."""
452
-
453
- def get_all_pts3d_t(self, gts, preds):
454
- # compute unnormalized points
455
- gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
456
- super().get_all_pts3d_t(gts, preds)
457
- )
458
-
459
- # pred_pts_l, pred_pts_r = pred_pts
460
- gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts]
461
-
462
- pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts]
463
- # pred_zs.append(pred_pts_r[-1][..., 2])
464
-
465
- # compute median depth
466
- gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None]
467
- pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, None]
468
-
469
- # subtract the median depth
470
- for i in range(len(gt_pts)):
471
- gt_pts[i][..., 2] -= gt_shift_z
472
-
473
- for i in range(len(pred_pts)):
474
- # for j in range(len(pred_pts[i])):
475
- pred_pts[i][..., 2] -= pred_shift_z
476
-
477
- monitoring = dict(
478
- monitoring,
479
- gt_shift_z=gt_shift_z.mean().detach(),
480
- pred_shift_z=pred_shift_z.mean().detach(),
481
- )
482
- return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
483
-
484
-
485
- class Regr3D_t_ScaleInv(Regr3D_t):
486
- """Same than Regr3D but invariant to depth shift.
487
- if gt_scale == True: enforce the prediction to take the same scale than GT
488
- """
489
-
490
- def get_all_pts3d_t(self, gts, preds):
491
- # compute depth-normalized points
492
- gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
493
- super().get_all_pts3d_t(gts, preds)
494
- )
495
-
496
- # measure scene scale
497
-
498
- # pred_pts_l, pred_pts_r = pred_pts
499
-
500
- pred_pts_all = [
501
- x.clone() for x in pred_pts
502
- ] # [pred_pt for pred_pt in pred_pts_l]
503
- # pred_pts_all.append(pred_pts_r[-1])
504
-
505
- _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
506
- _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks)
507
-
508
- # prevent predictions to be in a ridiculous range
509
- pred_scale = pred_scale.clip(min=1e-3, max=1e3)
510
-
511
- # subtract the median depth
512
- if self.gt_scale:
513
- for i in range(len(pred_pts)):
514
- # for j in range(len(pred_pts[i])):
515
- pred_pts[i] *= gt_scale / pred_scale
516
-
517
- else:
518
- for i in range(len(pred_pts)):
519
- # for j in range(len(pred_pts[i])):
520
- pred_pts[i] *= pred_scale / gt_scale
521
-
522
- for i in range(len(gt_pts)):
523
- gt_pts[i] *= gt_scale / pred_scale
524
-
525
- monitoring = dict(
526
- monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()
527
- )
528
-
529
- return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
530
-
531
-
532
- class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
533
- # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv
534
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/eval/data.py DELETED
@@ -1,338 +0,0 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import os.path as osp
5
- from collections import deque
6
- from base import BaseStereoViewDataset
7
- import dataset_utils.cropping as cropping
8
- from vggt.utils.eval_utils import imread_cv2, shuffle_deque
9
-
10
-
11
- class SevenScenes(BaseStereoViewDataset):
12
- def __init__(
13
- self,
14
- num_seq=1,
15
- num_frames=5,
16
- min_thresh=10,
17
- max_thresh=100,
18
- test_id=None,
19
- full_video=False,
20
- tuple_list=None,
21
- seq_id=None,
22
- rebuttal=False,
23
- shuffle_seed=-1,
24
- kf_every=1,
25
- *args,
26
- ROOT,
27
- **kwargs,
28
- ):
29
- self.ROOT = ROOT
30
- super().__init__(*args, **kwargs)
31
- self.num_seq = num_seq
32
- self.num_frames = num_frames
33
- self.max_thresh = max_thresh
34
- self.min_thresh = min_thresh
35
- self.test_id = test_id
36
- self.full_video = full_video
37
- self.kf_every = kf_every
38
- self.seq_id = seq_id
39
- self.rebuttal = rebuttal
40
- self.shuffle_seed = shuffle_seed
41
-
42
- # load all scenes
43
- self.load_all_tuples(tuple_list)
44
- self.load_all_scenes(ROOT)
45
-
46
- def __len__(self):
47
- if self.tuple_list is not None:
48
- return len(self.tuple_list)
49
- return len(self.scene_list) * self.num_seq
50
-
51
- def load_all_tuples(self, tuple_list):
52
- if tuple_list is not None:
53
- self.tuple_list = tuple_list
54
- # with open(tuple_path) as f:
55
- # self.tuple_list = f.read().splitlines()
56
-
57
- else:
58
- self.tuple_list = None
59
-
60
- def load_all_scenes(self, base_dir):
61
-
62
- if self.tuple_list is not None:
63
- # Use pre-defined simplerecon scene_ids
64
- self.scene_list = [
65
- "stairs/seq-06",
66
- "stairs/seq-02",
67
- "pumpkin/seq-06",
68
- "chess/seq-01",
69
- "heads/seq-02",
70
- "fire/seq-02",
71
- "office/seq-03",
72
- "pumpkin/seq-03",
73
- "redkitchen/seq-07",
74
- "chess/seq-02",
75
- "office/seq-01",
76
- "redkitchen/seq-01",
77
- "fire/seq-01",
78
- ]
79
- print(f"Found {len(self.scene_list)} sequences in split {self.split}")
80
- return
81
-
82
- scenes = os.listdir(base_dir)
83
-
84
- file_split = {"train": "TrainSplit.txt", "test": "TestSplit.txt"}[self.split]
85
-
86
- self.scene_list = []
87
- for scene in scenes:
88
- if self.test_id is not None and scene != self.test_id:
89
- continue
90
- # read file split
91
- with open(osp.join(base_dir, scene, file_split)) as f:
92
- seq_ids = f.read().splitlines()
93
-
94
- for seq_id in seq_ids:
95
- # seq is string, take the int part and make it 01, 02, 03
96
- # seq_id = 'seq-{:2d}'.format(int(seq_id))
97
- num_part = "".join(filter(str.isdigit, seq_id))
98
- seq_id = f"seq-{num_part.zfill(2)}"
99
- if self.seq_id is not None and seq_id != self.seq_id:
100
- continue
101
- self.scene_list.append(f"{scene}/{seq_id}")
102
-
103
- print(f"Found {len(self.scene_list)} sequences in split {self.split}")
104
-
105
- def _get_views(self, idx, resolution, rng):
106
-
107
- if self.tuple_list is not None:
108
- line = self.tuple_list[idx].split(" ")
109
- scene_id = line[0]
110
- img_idxs = line[1:]
111
-
112
- else:
113
- scene_id = self.scene_list[idx // self.num_seq]
114
- seq_id = idx % self.num_seq
115
-
116
- data_path = osp.join(self.ROOT, scene_id)
117
- num_files = len([name for name in os.listdir(data_path) if "color" in name])
118
- img_idxs = [f"{i:06d}" for i in range(num_files)]
119
- img_idxs = img_idxs[:: self.kf_every]
120
-
121
- # Intrinsics used in SimpleRecon
122
- fx, fy, cx, cy = 525, 525, 320, 240
123
- intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
124
-
125
- views = []
126
- imgs_idxs = deque(img_idxs)
127
- if self.shuffle_seed >= 0:
128
- imgs_idxs = shuffle_deque(imgs_idxs)
129
-
130
- while len(imgs_idxs) > 0:
131
- im_idx = imgs_idxs.popleft()
132
- impath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.color.png")
133
- depthpath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.depth.proj.png")
134
- posepath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.pose.txt")
135
-
136
- rgb_image = imread_cv2(impath)
137
-
138
- depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
139
- rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))
140
-
141
- depthmap[depthmap == 65535] = 0
142
- depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0
143
-
144
- depthmap[depthmap > 10] = 0
145
- depthmap[depthmap < 1e-3] = 0
146
-
147
- camera_pose = np.loadtxt(posepath).astype(np.float32)
148
-
149
- if resolution != (224, 224) or self.rebuttal:
150
- rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
151
- rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
152
- )
153
- else:
154
- rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
155
- rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
156
- )
157
- W, H = rgb_image.size
158
- cx = W // 2
159
- cy = H // 2
160
- l, t = cx - 112, cy - 112
161
- r, b = cx + 112, cy + 112
162
- crop_bbox = (l, t, r, b)
163
- rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
164
- rgb_image, depthmap, intrinsics, crop_bbox
165
- )
166
-
167
- views.append(
168
- dict(
169
- img=rgb_image,
170
- depthmap=depthmap,
171
- camera_pose=camera_pose,
172
- camera_intrinsics=intrinsics,
173
- dataset="7scenes",
174
- label=osp.join(scene_id, im_idx),
175
- instance=impath,
176
- )
177
- )
178
- return views
179
-
180
-
181
- class NRGBD(BaseStereoViewDataset):
182
- def __init__(
183
- self,
184
- num_seq=1,
185
- num_frames=5,
186
- min_thresh=10,
187
- max_thresh=100,
188
- test_id=None,
189
- full_video=False,
190
- tuple_list=None,
191
- seq_id=None,
192
- rebuttal=False,
193
- shuffle_seed=-1,
194
- kf_every=1,
195
- *args,
196
- ROOT,
197
- **kwargs,
198
- ):
199
-
200
- self.ROOT = ROOT
201
- super().__init__(*args, **kwargs)
202
- self.num_seq = num_seq
203
- self.num_frames = num_frames
204
- self.max_thresh = max_thresh
205
- self.min_thresh = min_thresh
206
- self.test_id = test_id
207
- self.full_video = full_video
208
- self.kf_every = kf_every
209
- self.seq_id = seq_id
210
- self.rebuttal = rebuttal
211
- self.shuffle_seed = shuffle_seed
212
-
213
- # load all scenes
214
- self.load_all_tuples(tuple_list)
215
- self.load_all_scenes(ROOT)
216
-
217
- def __len__(self):
218
- if self.tuple_list is not None:
219
- return len(self.tuple_list)
220
- return len(self.scene_list) * self.num_seq
221
-
222
- def load_all_tuples(self, tuple_list):
223
- if tuple_list is not None:
224
- self.tuple_list = tuple_list
225
- # with open(tuple_path) as f:
226
- # self.tuple_list = f.read().splitlines()
227
-
228
- else:
229
- self.tuple_list = None
230
-
231
- def load_all_scenes(self, base_dir):
232
-
233
- scenes = [
234
- d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
235
- ]
236
-
237
- if self.test_id is not None:
238
- self.scene_list = [self.test_id]
239
-
240
- else:
241
- self.scene_list = scenes
242
-
243
- print(f"Found {len(self.scene_list)} sequences in split {self.split}")
244
-
245
- def load_poses(self, path):
246
- file = open(path, "r")
247
- lines = file.readlines()
248
- file.close()
249
- poses = []
250
- valid = []
251
- lines_per_matrix = 4
252
- for i in range(0, len(lines), lines_per_matrix):
253
- if "nan" in lines[i]:
254
- valid.append(False)
255
- poses.append(np.eye(4, 4, dtype=np.float32).tolist())
256
- else:
257
- valid.append(True)
258
- pose_floats = [
259
- [float(x) for x in line.split()]
260
- for line in lines[i : i + lines_per_matrix]
261
- ]
262
- poses.append(pose_floats)
263
-
264
- return np.array(poses, dtype=np.float32), valid
265
-
266
- def _get_views(self, idx, resolution, rng):
267
-
268
- if self.tuple_list is not None:
269
- line = self.tuple_list[idx].split(" ")
270
- scene_id = line[0]
271
- img_idxs = line[1:]
272
-
273
- else:
274
- scene_id = self.scene_list[idx // self.num_seq]
275
-
276
- num_files = len(os.listdir(os.path.join(self.ROOT, scene_id, "images")))
277
- img_idxs = [f"{i}" for i in range(num_files)]
278
- img_idxs = img_idxs[:: min(self.kf_every, len(img_idxs) // 2)]
279
-
280
- fx, fy, cx, cy = 554.2562584220408, 554.2562584220408, 320, 240
281
- intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
282
-
283
- posepath = osp.join(self.ROOT, scene_id, f"poses.txt")
284
- camera_poses, valids = self.load_poses(posepath)
285
-
286
- imgs_idxs = deque(img_idxs)
287
- if self.shuffle_seed >= 0:
288
- imgs_idxs = shuffle_deque(imgs_idxs)
289
- views = []
290
-
291
- while len(imgs_idxs) > 0:
292
- im_idx = imgs_idxs.popleft()
293
-
294
- impath = osp.join(self.ROOT, scene_id, "images", f"img{im_idx}.png")
295
- depthpath = osp.join(self.ROOT, scene_id, "depth", f"depth{im_idx}.png")
296
-
297
- rgb_image = imread_cv2(impath)
298
- depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
299
- depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0
300
- depthmap[depthmap > 10] = 0
301
- depthmap[depthmap < 1e-3] = 0
302
-
303
- rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))
304
-
305
- camera_pose = camera_poses[int(im_idx)]
306
- # gl to cv
307
- camera_pose[:, 1:3] *= -1.0
308
- if resolution != (224, 224) or self.rebuttal:
309
- rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
310
- rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
311
- )
312
- else:
313
- rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
314
- rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
315
- )
316
- W, H = rgb_image.size
317
- cx = W // 2
318
- cy = H // 2
319
- l, t = cx - 112, cy - 112
320
- r, b = cx + 112, cy + 112
321
- crop_bbox = (l, t, r, b)
322
- rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
323
- rgb_image, depthmap, intrinsics, crop_bbox
324
- )
325
-
326
- views.append(
327
- dict(
328
- img=rgb_image,
329
- depthmap=depthmap,
330
- camera_pose=camera_pose,
331
- camera_intrinsics=intrinsics,
332
- dataset="nrgbd",
333
- label=osp.join(scene_id, im_idx),
334
- instance=impath,
335
- )
336
- )
337
-
338
- return views
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/eval/dataset_utils/__init__.py DELETED
@@ -1 +0,0 @@
1
-
 
 
FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (146 Bytes)
 
FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc DELETED
Binary file (140 Bytes)
 
FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc DELETED
Binary file (5.85 kB)
 
FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc DELETED
Binary file (4.29 kB)
 
FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc DELETED
Binary file (4.29 kB)
 
FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc DELETED
Binary file (2.18 kB)
 
FastVGGT/eval/dataset_utils/corr.py DELETED
@@ -1,234 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
-
6
- import numpy as np
7
- import torch
8
-
9
-
10
- def todevice(batch, device, callback=None, non_blocking=False):
11
- """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
12
-
13
- batch: list, tuple, dict of tensors or other things
14
- device: pytorch device or 'numpy'
15
- callback: function that would be called on every sub-elements.
16
- """
17
- if callback:
18
- batch = callback(batch)
19
-
20
- if isinstance(batch, dict):
21
- return {k: todevice(v, device) for k, v in batch.items()}
22
-
23
- if isinstance(batch, (tuple, list)):
24
- return type(batch)(todevice(x, device) for x in batch)
25
-
26
- x = batch
27
- if device == "numpy":
28
- if isinstance(x, torch.Tensor):
29
- x = x.detach().cpu().numpy()
30
- elif x is not None:
31
- if isinstance(x, np.ndarray):
32
- x = torch.from_numpy(x)
33
- if torch.is_tensor(x):
34
- x = x.to(device, non_blocking=non_blocking)
35
- return x
36
-
37
-
38
- to_device = todevice # alias
39
-
40
-
41
- def to_numpy(x):
42
- return todevice(x, "numpy")
43
-
44
-
45
- def geotrf(Trf, pts, ncol=None, norm=False):
46
- """Apply a geometric transformation to a list of 3-D points.
47
-
48
- H: 3x3 or 4x4 projection matrix (typically a Homography)
49
- p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
50
-
51
- ncol: int. number of columns of the result (2 or 3)
52
- norm: float. if != 0, the resut is projected on the z=norm plane.
53
-
54
- Returns an array of projected 2d points.
55
- """
56
- assert Trf.ndim >= 2
57
- if isinstance(Trf, np.ndarray):
58
- pts = np.asarray(pts)
59
- elif isinstance(Trf, torch.Tensor):
60
- pts = torch.as_tensor(pts, dtype=Trf.dtype)
61
-
62
- output_reshape = pts.shape[:-1]
63
- ncol = ncol or pts.shape[-1]
64
-
65
- if (
66
- isinstance(Trf, torch.Tensor)
67
- and isinstance(pts, torch.Tensor)
68
- and Trf.ndim == 3
69
- and pts.ndim == 4
70
- ):
71
- d = pts.shape[3]
72
- if Trf.shape[-1] == d:
73
- pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
74
- elif Trf.shape[-1] == d + 1:
75
- pts = (
76
- torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
77
- + Trf[:, None, None, :d, d]
78
- )
79
- else:
80
- raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
81
- else:
82
- if Trf.ndim >= 3:
83
- n = Trf.ndim - 2
84
- assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
85
- Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
86
-
87
- if pts.ndim > Trf.ndim:
88
-
89
- pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
90
- elif pts.ndim == 2:
91
-
92
- pts = pts[:, None, :]
93
-
94
- if pts.shape[-1] + 1 == Trf.shape[-1]:
95
- Trf = Trf.swapaxes(-1, -2) # transpose Trf
96
- pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
97
- elif pts.shape[-1] == Trf.shape[-1]:
98
- Trf = Trf.swapaxes(-1, -2) # transpose Trf
99
- pts = pts @ Trf
100
- else:
101
- pts = Trf @ pts.T
102
- if pts.ndim >= 2:
103
- pts = pts.swapaxes(-1, -2)
104
-
105
- if norm:
106
- pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
107
- if norm != 1:
108
- pts *= norm
109
-
110
- res = pts[..., :ncol].reshape(*output_reshape, ncol)
111
- return res
112
-
113
-
114
- def inv(mat):
115
- """Invert a torch or numpy matrix"""
116
- if isinstance(mat, torch.Tensor):
117
- return torch.linalg.inv(mat)
118
- if isinstance(mat, np.ndarray):
119
- return np.linalg.inv(mat)
120
- raise ValueError(f"bad matrix type = {type(mat)}")
121
-
122
-
123
- def reproject_view(pts3d, view2):
124
- shape = view2["pts3d"].shape[:2]
125
- return reproject(
126
- pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape
127
- )
128
-
129
-
130
- def reproject(pts3d, K, world2cam, shape):
131
- H, W, THREE = pts3d.shape
132
- assert THREE == 3
133
-
134
- with np.errstate(divide="ignore", invalid="ignore"):
135
- pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)
136
-
137
- return (H, W), ravel_xy(pos, shape)
138
-
139
-
140
- def ravel_xy(pos, shape):
141
- H, W = shape
142
- with np.errstate(invalid="ignore"):
143
- qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
144
- quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(
145
- min=0, max=H - 1, out=qy
146
- )
147
- return quantized_pos
148
-
149
-
150
- def unravel_xy(pos, shape):
151
-
152
- return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
153
-
154
-
155
- def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
156
- is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2))
157
- pos1 = is_reciprocal1.nonzero()[0]
158
- pos2 = corres_1_to_2[pos1]
159
- if ret_recip:
160
- return is_reciprocal1, pos1, pos2
161
- return pos1, pos2
162
-
163
-
164
- def extract_correspondences_from_pts3d(
165
- view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
166
- ):
167
- view1, view2 = to_numpy((view1, view2))
168
-
169
- shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
170
- shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)
171
-
172
- is_reciprocal1, pos1, pos2 = reciprocal_1d(
173
- corres1_to_2, corres2_to_1, ret_recip=True
174
- )
175
- is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))
176
-
177
- if target_n_corres is None:
178
- if ret_xy:
179
- pos1 = unravel_xy(pos1, shape1)
180
- pos2 = unravel_xy(pos2, shape2)
181
- return pos1, pos2
182
-
183
- available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
184
- target_n_positives = int(target_n_corres * (1 - nneg))
185
- n_positives = min(len(pos1), target_n_positives)
186
- n_negatives = min(target_n_corres - n_positives, available_negatives)
187
-
188
- if n_negatives + n_positives != target_n_corres:
189
-
190
- n_positives = target_n_corres - n_negatives
191
- assert n_positives <= len(pos1)
192
-
193
- assert n_positives <= len(pos1)
194
- assert n_positives <= len(pos2)
195
- assert n_negatives <= (~is_reciprocal1).sum()
196
- assert n_negatives <= (~is_reciprocal2).sum()
197
- assert n_positives + n_negatives == target_n_corres
198
-
199
- valid = np.ones(n_positives, dtype=bool)
200
- if n_positives < len(pos1):
201
-
202
- perm = rng.permutation(len(pos1))[:n_positives]
203
- pos1 = pos1[perm]
204
- pos2 = pos2[perm]
205
-
206
- if n_negatives > 0:
207
-
208
- def norm(p):
209
- return p / p.sum()
210
-
211
- pos1 = np.r_[
212
- pos1,
213
- rng.choice(
214
- shape1[0] * shape1[1],
215
- size=n_negatives,
216
- replace=False,
217
- p=norm(~is_reciprocal1),
218
- ),
219
- ]
220
- pos2 = np.r_[
221
- pos2,
222
- rng.choice(
223
- shape2[0] * shape2[1],
224
- size=n_negatives,
225
- replace=False,
226
- p=norm(~is_reciprocal2),
227
- ),
228
- ]
229
- valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]
230
-
231
- if ret_xy:
232
- pos1 = unravel_xy(pos1, shape1)
233
- pos2 = unravel_xy(pos2, shape2)
234
- return pos1, pos2, valid
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/eval/dataset_utils/cropping.py DELETED
@@ -1,140 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
-
6
- import PIL.Image
7
- import os
8
-
9
- from utils import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics
10
-
11
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
12
- import cv2 # noqa
13
- import numpy as np # noqa
14
-
15
- try:
16
- lanczos = PIL.Image.Resampling.LANCZOS
17
- bicubic = PIL.Image.Resampling.BICUBIC
18
- except AttributeError:
19
- lanczos = PIL.Image.LANCZOS
20
- bicubic = PIL.Image.BICUBIC
21
-
22
-
23
- class ImageList:
24
- """Convenience class to aply the same operation to a whole set of images."""
25
-
26
- def __init__(self, images):
27
- if not isinstance(images, (tuple, list, set)):
28
- images = [images]
29
- self.images = []
30
- for image in images:
31
- if not isinstance(image, PIL.Image.Image):
32
- image = PIL.Image.fromarray(image)
33
- self.images.append(image)
34
-
35
- def __len__(self):
36
- return len(self.images)
37
-
38
- def to_pil(self):
39
- return tuple(self.images) if len(self.images) > 1 else self.images[0]
40
-
41
- @property
42
- def size(self):
43
- sizes = [im.size for im in self.images]
44
- assert all(sizes[0] == s for s in sizes)
45
- return sizes[0]
46
-
47
- def resize(self, *args, **kwargs):
48
- return ImageList(self._dispatch("resize", *args, **kwargs))
49
-
50
- def crop(self, *args, **kwargs):
51
- return ImageList(self._dispatch("crop", *args, **kwargs))
52
-
53
- def _dispatch(self, func, *args, **kwargs):
54
- return [getattr(im, func)(*args, **kwargs) for im in self.images]
55
-
56
-
57
- def rescale_image_depthmap(
58
- image, depthmap, camera_intrinsics, output_resolution, force=True
59
- ):
60
- """Jointly rescale a (image, depthmap)
61
- so that (out_width, out_height) >= output_res
62
- """
63
- image = ImageList(image)
64
- input_resolution = np.array(image.size) # (W,H)
65
- output_resolution = np.array(output_resolution)
66
- if depthmap is not None:
67
-
68
- assert tuple(depthmap.shape[:2]) == image.size[::-1]
69
-
70
- assert output_resolution.shape == (2,)
71
- scale_final = max(output_resolution / image.size) + 1e-8
72
- if scale_final >= 1 and not force: # image is already smaller than what is asked
73
- return (image.to_pil(), depthmap, camera_intrinsics)
74
- output_resolution = np.floor(input_resolution * scale_final).astype(int)
75
-
76
- image = image.resize(
77
- output_resolution, resample=lanczos if scale_final < 1 else bicubic
78
- )
79
- if depthmap is not None:
80
- depthmap = cv2.resize(
81
- depthmap,
82
- output_resolution,
83
- fx=scale_final,
84
- fy=scale_final,
85
- interpolation=cv2.INTER_NEAREST,
86
- )
87
-
88
- camera_intrinsics = camera_matrix_of_crop(
89
- camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
90
- )
91
-
92
- return image.to_pil(), depthmap, camera_intrinsics
93
-
94
-
95
- def camera_matrix_of_crop(
96
- input_camera_matrix,
97
- input_resolution,
98
- output_resolution,
99
- scaling=1,
100
- offset_factor=0.5,
101
- offset=None,
102
- ):
103
-
104
- margins = np.asarray(input_resolution) * scaling - output_resolution
105
- assert np.all(margins >= 0.0)
106
- if offset is None:
107
- offset = offset_factor * margins
108
-
109
- output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
110
- output_camera_matrix_colmap[:2, :] *= scaling
111
- output_camera_matrix_colmap[:2, 2] -= offset
112
- output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
113
-
114
- return output_camera_matrix
115
-
116
-
117
- def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
118
- """
119
- Return a crop of the input view.
120
- """
121
- image = ImageList(image)
122
- l, t, r, b = crop_bbox
123
-
124
- image = image.crop((l, t, r, b))
125
- depthmap = depthmap[t:b, l:r]
126
-
127
- camera_intrinsics = camera_intrinsics.copy()
128
- camera_intrinsics[0, 2] -= l
129
- camera_intrinsics[1, 2] -= t
130
-
131
- return image.to_pil(), depthmap, camera_intrinsics
132
-
133
-
134
- def bbox_from_intrinsics_in_out(
135
- input_camera_matrix, output_camera_matrix, output_resolution
136
- ):
137
- out_width, out_height = output_resolution
138
- l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
139
- crop_bbox = (l, t, l + out_width, t + out_height)
140
- return crop_bbox
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/eval/dataset_utils/transforms.py DELETED
@@ -1,78 +0,0 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
-
6
- import torchvision.transforms as tvf
7
-
8
- ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
9
-
10
-
11
- ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
12
-
13
-
14
- def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
15
- if isinstance(value, (int, float)):
16
- if value < 0:
17
- raise ValueError(f"If is a single number, it must be non negative.")
18
- value = [center - float(value), center + float(value)]
19
- if clip_first_on_zero:
20
- value[0] = max(value[0], 0.0)
21
- elif isinstance(value, (tuple, list)) and len(value) == 2:
22
- value = [float(value[0]), float(value[1])]
23
- else:
24
- raise TypeError(f"should be a single number or a list/tuple with length 2.")
25
-
26
- if not bound[0] <= value[0] <= value[1] <= bound[1]:
27
- raise ValueError(f"values should be between {bound}, but got {value}.")
28
-
29
- if value[0] == value[1] == center:
30
- return None
31
- else:
32
- return tuple(value)
33
-
34
-
35
- import torch
36
- import torchvision.transforms.functional as F
37
-
38
-
39
- def SeqColorJitter():
40
- """
41
- Return a color jitter transform with same random parameters
42
- """
43
- brightness = _check_input(0.5)
44
- contrast = _check_input(0.5)
45
- saturation = _check_input(0.5)
46
- hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
47
-
48
- fn_idx = torch.randperm(4)
49
- brightness_factor = (
50
- None
51
- if brightness is None
52
- else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
53
- )
54
- contrast_factor = (
55
- None
56
- if contrast is None
57
- else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
58
- )
59
- saturation_factor = (
60
- None
61
- if saturation is None
62
- else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
63
- )
64
- hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
65
-
66
- def _color_jitter(img):
67
- for fn_id in fn_idx:
68
- if fn_id == 0 and brightness_factor is not None:
69
- img = F.adjust_brightness(img, brightness_factor)
70
- elif fn_id == 1 and contrast_factor is not None:
71
- img = F.adjust_contrast(img, contrast_factor)
72
- elif fn_id == 2 and saturation_factor is not None:
73
- img = F.adjust_saturation(img, saturation_factor)
74
- elif fn_id == 3 and hue_factor is not None:
75
- img = F.adjust_hue(img, hue_factor)
76
- return ImgNorm(img)
77
-
78
- return _color_jitter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/eval/eval_7andN.py DELETED
@@ -1,497 +0,0 @@
1
- import os
2
- import sys
3
-
4
- # Ensure project root is on sys.path for absolute imports like `vggt.*`
5
- ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
6
- if ROOT_DIR not in sys.path:
7
- sys.path.insert(0, ROOT_DIR)
8
-
9
- import time
10
- import torch
11
- import argparse
12
- import numpy as np
13
- import open3d as o3d
14
- import os.path as osp
15
- from torch.utils.data import DataLoader
16
- from torch.utils.data._utils.collate import default_collate
17
- from tqdm import tqdm
18
- from collections import defaultdict
19
- import torchvision.transforms as transforms
20
-
21
-
22
- def get_args_parser():
23
- parser = argparse.ArgumentParser("3D Reconstruction evaluation", add_help=False)
24
- parser.add_argument(
25
- "--ckpt_path",
26
- type=str,
27
- default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
28
- help="ckpt name",
29
- )
30
- parser.add_argument("--device", type=str, default="cuda:0", help="device")
31
- parser.add_argument("--model_name", type=str, default="VGGT")
32
- parser.add_argument(
33
- "--conf_thresh", type=float, default=0.0, help="confidence threshold"
34
- )
35
- parser.add_argument(
36
- "--output_dir",
37
- type=str,
38
- default="/home/sy/code/FastVGGT/eval_results",
39
- help="value for outdir",
40
- )
41
- parser.add_argument("--size", type=int, default=518)
42
- parser.add_argument("--revisit", type=int, default=1, help="revisit times")
43
- parser.add_argument("--freeze", action="store_true")
44
- parser.add_argument("--use_proj", action="store_true")
45
- parser.add_argument(
46
- "--merging", type=int, default=0, help="VGGT aggregator merging steps"
47
- )
48
- parser.add_argument("--kf", type=int, default=2, help="key frame")
49
- return parser
50
-
51
-
52
- def main(args):
53
- from data import SevenScenes, NRGBD
54
- from utils import accuracy, completion
55
-
56
- if args.size == 512:
57
- resolution = (512, 384)
58
- elif args.size == 224:
59
- resolution = 224
60
- elif args.size == 518:
61
- resolution = (518, 392)
62
- else:
63
- raise NotImplementedError
64
- datasets_all = {
65
- "7scenes": SevenScenes(
66
- split="test",
67
- ROOT="/data/sy/7scenes",
68
- resolution=resolution,
69
- num_seq=1,
70
- full_video=True,
71
- kf_every=args.kf,
72
- ), # 20),
73
- "NRGBD": NRGBD(
74
- split="test",
75
- ROOT="/data/sy/neural_rgbd_data",
76
- resolution=resolution,
77
- num_seq=1,
78
- full_video=True,
79
- kf_every=args.kf,
80
- ),
81
- }
82
-
83
- device = args.device
84
- model_name = args.model_name
85
-
86
- from vggt.models.vggt import VGGT
87
- from vggt.utils.pose_enc import pose_encoding_to_extri_intri
88
- from vggt.utils.geometry import unproject_depth_map_to_point_map
89
- from criterion import Regr3D_t_ScaleShiftInv, L21
90
-
91
- # Force use of bf16 data type
92
- dtype = torch.bfloat16
93
- # Load VGGT model
94
- model = VGGT(merging=args.merging, enable_point=True)
95
- ckpt = torch.load(args.ckpt_path, map_location="cpu")
96
-
97
- # ✅ Fix: load pre-trained weights
98
- model.load_state_dict(
99
- ckpt, strict=False
100
- ) # Use strict=False due to enable_point=True difference
101
-
102
- model = model.cuda().eval()
103
- model = model.to(torch.bfloat16)
104
-
105
- del ckpt
106
- os.makedirs(osp.join(args.output_dir, f"{args.kf}"), exist_ok=True)
107
-
108
- criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)
109
-
110
- with torch.no_grad():
111
- for name_data, dataset in datasets_all.items():
112
- save_path = osp.join(osp.join(args.output_dir, f"{args.kf}"), name_data)
113
- os.makedirs(save_path, exist_ok=True)
114
- log_file = osp.join(save_path, "logs.txt")
115
-
116
- acc_all = 0
117
- acc_all_med = 0
118
- comp_all = 0
119
- comp_all_med = 0
120
- nc1_all = 0
121
- nc1_all_med = 0
122
- nc2_all = 0
123
- nc2_all_med = 0
124
- scene_infer_times = defaultdict(list)
125
-
126
- for data_idx in tqdm(range(len(dataset))):
127
- batch = default_collate([dataset[data_idx]])
128
- ignore_keys = set(
129
- [
130
- "depthmap",
131
- "dataset",
132
- "label",
133
- "instance",
134
- "idx",
135
- "true_shape",
136
- "rng",
137
- ]
138
- )
139
- for view in batch:
140
- for name in view.keys(): # pseudo_focal
141
- if name in ignore_keys:
142
- continue
143
- if isinstance(view[name], tuple) or isinstance(
144
- view[name], list
145
- ):
146
- view[name] = [
147
- x.to(device, non_blocking=True) for x in view[name]
148
- ]
149
- else:
150
- view[name] = view[name].to(device, non_blocking=True)
151
-
152
- pts_all = []
153
- pts_gt_all = []
154
- images_all = []
155
- masks_all = []
156
- conf_all = []
157
- in_camera1 = None
158
-
159
- dtype = (
160
- torch.bfloat16
161
- if torch.cuda.get_device_capability()[0] >= 8
162
- else torch.float16
163
- )
164
- with torch.cuda.amp.autocast(dtype=dtype):
165
- if isinstance(batch, dict) and "img" in batch:
166
- batch["img"] = (batch["img"] + 1.0) / 2.0
167
- elif isinstance(batch, list) and all(
168
- isinstance(v, dict) and "img" in v for v in batch
169
- ):
170
- for view in batch:
171
- view["img"] = (view["img"] + 1.0) / 2.0
172
- # Gather all `img` tensors into a single tensor of shape [N, C, H, W]
173
- imgs_tensor = torch.cat([v["img"] for v in batch], dim=0)
174
-
175
- with torch.cuda.amp.autocast(dtype=dtype):
176
- with torch.no_grad():
177
- torch.cuda.synchronize()
178
- start = time.time()
179
- preds = model(imgs_tensor)
180
- torch.cuda.synchronize()
181
- end = time.time()
182
- inference_time_ms = (end - start) * 1000
183
- print(f"Inference time: {inference_time_ms:.2f}ms")
184
-
185
- # Wrap model outputs per-view to align with batch later
186
- predictions = preds
187
- views = batch # list[dict]
188
- if "pose_enc" in predictions:
189
- B, S = predictions["pose_enc"].shape[:2]
190
- elif "world_points" in predictions:
191
- B, S = predictions["world_points"].shape[:2]
192
- else:
193
- raise KeyError(
194
- "predictions is missing a key to infer sequence length"
195
- )
196
-
197
- ress = []
198
- for s in range(S):
199
- res = {
200
- "pts3d_in_other_view": predictions["world_points"][:, s],
201
- "conf": predictions["world_points_conf"][:, s],
202
- "depth": predictions["depth"][:, s],
203
- "depth_conf": predictions["depth_conf"][:, s],
204
- "camera_pose": predictions["pose_enc"][:, s, :],
205
- }
206
- if (
207
- isinstance(views, list)
208
- and s < len(views)
209
- and "valid_mask" in views[s]
210
- ):
211
- res["valid_mask"] = views[s]["valid_mask"]
212
- if "track" in predictions:
213
- res.update(
214
- {
215
- "track": predictions["track"][:, s],
216
- "vis": (
217
- predictions.get("vis", None)[:, s]
218
- if "vis" in predictions
219
- else None
220
- ),
221
- "track_conf": (
222
- predictions.get("conf", None)[:, s]
223
- if "conf" in predictions
224
- else None
225
- ),
226
- }
227
- )
228
- ress.append(res)
229
-
230
- preds = ress
231
-
232
- valid_length = len(preds) // args.revisit
233
- if args.revisit > 1:
234
- preds = preds[-valid_length:]
235
- batch = batch[-valid_length:]
236
-
237
- # Evaluation
238
- print(f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}")
239
- gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
240
- criterion.get_all_pts3d_t(batch, preds)
241
- )
242
-
243
- in_camera1 = None
244
- pts_all = []
245
- pts_gt_all = []
246
- images_all = []
247
- masks_all = []
248
- conf_all = []
249
-
250
- for j, view in enumerate(batch):
251
- if in_camera1 is None:
252
- in_camera1 = view["camera_pose"][0].cpu()
253
-
254
- image = view["img"].permute(0, 2, 3, 1).cpu().numpy()[0]
255
- mask = view["valid_mask"].cpu().numpy()[0]
256
-
257
- pts = pred_pts[j].cpu().numpy()[0]
258
- conf = preds[j]["conf"].cpu().data.numpy()[0]
259
-
260
- # mask = mask & (conf > 1.8)
261
-
262
- pts_gt = gt_pts[j].detach().cpu().numpy()[0]
263
-
264
- H, W = image.shape[:2]
265
- cx = W // 2
266
- cy = H // 2
267
- l, t = cx - 112, cy - 112
268
- r, b = cx + 112, cy + 112
269
- image = image[t:b, l:r]
270
- mask = mask[t:b, l:r]
271
- pts = pts[t:b, l:r]
272
- pts_gt = pts_gt[t:b, l:r]
273
-
274
- images_all.append(image[None, ...])
275
- pts_all.append(pts[None, ...])
276
- pts_gt_all.append(pts_gt[None, ...])
277
- masks_all.append(mask[None, ...])
278
- conf_all.append(conf[None, ...])
279
-
280
- images_all = np.concatenate(images_all, axis=0)
281
- pts_all = np.concatenate(pts_all, axis=0)
282
- pts_gt_all = np.concatenate(pts_gt_all, axis=0)
283
- masks_all = np.concatenate(masks_all, axis=0)
284
-
285
- scene_id = view["label"][0].rsplit("/", 1)[0]
286
- # Record average inference time per scene
287
- try:
288
- scene_infer_times[scene_id].append(float(inference_time_ms))
289
- except Exception:
290
- pass
291
-
292
- save_params = {}
293
-
294
- save_params["images_all"] = images_all
295
- save_params["pts_all"] = pts_all
296
- save_params["pts_gt_all"] = pts_gt_all
297
- save_params["masks_all"] = masks_all
298
-
299
- pts_all_masked = pts_all[masks_all > 0]
300
- pts_gt_all_masked = pts_gt_all[masks_all > 0]
301
- images_all_masked = images_all[masks_all > 0]
302
-
303
- mask = np.isfinite(pts_all_masked)
304
- pts_all_masked = pts_all_masked[mask]
305
-
306
- mask_gt = np.isfinite(pts_gt_all_masked)
307
- pts_gt_all_masked = pts_gt_all_masked[mask_gt]
308
- images_all_masked = images_all_masked[mask]
309
-
310
- # Reshape to point cloud (N, 3) before sampling
311
- pts_all_masked = pts_all_masked.reshape(-1, 3)
312
- pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3)
313
- images_all_masked = images_all_masked.reshape(-1, 3)
314
-
315
- # If number of points exceeds threshold, sample by points
316
- if pts_all_masked.shape[0] > 999999:
317
- sample_indices = np.random.choice(
318
- pts_all_masked.shape[0], 999999, replace=False
319
- )
320
- pts_all_masked = pts_all_masked[sample_indices]
321
- images_all_masked = images_all_masked[sample_indices]
322
-
323
- # Apply the same sampling to GT point cloud
324
- if pts_gt_all_masked.shape[0] > 999999:
325
- sample_indices_gt = np.random.choice(
326
- pts_gt_all_masked.shape[0], 999999, replace=False
327
- )
328
- pts_gt_all_masked = pts_gt_all_masked[sample_indices_gt]
329
-
330
- if args.use_proj:
331
-
332
- def umeyama_alignment(
333
- src: np.ndarray, dst: np.ndarray, with_scale: bool = True
334
- ):
335
- assert src.shape == dst.shape
336
- N, dim = src.shape
337
-
338
- mu_src = src.mean(axis=0)
339
- mu_dst = dst.mean(axis=0)
340
- src_c = src - mu_src
341
- dst_c = dst - mu_dst
342
-
343
- Sigma = dst_c.T @ src_c / N # (3,3)
344
-
345
- U, D, Vt = np.linalg.svd(Sigma)
346
-
347
- S = np.eye(dim)
348
- if np.linalg.det(U) * np.linalg.det(Vt) < 0:
349
- S[-1, -1] = -1
350
-
351
- R = U @ S @ Vt
352
-
353
- if with_scale:
354
- var_src = (src_c**2).sum() / N
355
- s = (D * S.diagonal()).sum() / var_src
356
- else:
357
- s = 1.0
358
-
359
- t = mu_dst - s * R @ mu_src
360
-
361
- return s, R, t
362
-
363
- pts_all_masked = pts_all_masked.reshape(-1, 3)
364
- pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3)
365
- s, R, t = umeyama_alignment(
366
- pts_all_masked, pts_gt_all_masked, with_scale=True
367
- )
368
- pts_all_aligned = (s * (R @ pts_all_masked.T)).T + t # (N,3)
369
- pts_all_masked = pts_all_aligned
370
-
371
- pcd = o3d.geometry.PointCloud()
372
- pcd.points = o3d.utility.Vector3dVector(pts_all_masked)
373
- pcd.colors = o3d.utility.Vector3dVector(images_all_masked)
374
-
375
- pcd_gt = o3d.geometry.PointCloud()
376
- pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked)
377
- pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked)
378
-
379
- trans_init = np.eye(4)
380
-
381
- threshold = 0.1
382
- reg_p2p = o3d.pipelines.registration.registration_icp(
383
- pcd,
384
- pcd_gt,
385
- threshold,
386
- trans_init,
387
- o3d.pipelines.registration.TransformationEstimationPointToPoint(),
388
- )
389
-
390
- transformation = reg_p2p.transformation
391
-
392
- pcd = pcd.transform(transformation)
393
- pcd.estimate_normals()
394
- pcd_gt.estimate_normals()
395
-
396
- gt_normal = np.asarray(pcd_gt.normals)
397
- pred_normal = np.asarray(pcd.normals)
398
-
399
- acc, acc_med, nc1, nc1_med = accuracy(
400
- pcd_gt.points, pcd.points, gt_normal, pred_normal
401
- )
402
- comp, comp_med, nc2, nc2_med = completion(
403
- pcd_gt.points, pcd.points, gt_normal, pred_normal
404
- )
405
- print(
406
- f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}"
407
- )
408
- print(
409
- f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}",
410
- file=open(log_file, "a"),
411
- )
412
-
413
- acc_all += acc
414
- comp_all += comp
415
- nc1_all += nc1
416
- nc2_all += nc2
417
-
418
- acc_all_med += acc_med
419
- comp_all_med += comp_med
420
- nc1_all_med += nc1_med
421
- nc2_all_med += nc2_med
422
-
423
- # release cuda memory
424
- torch.cuda.empty_cache()
425
-
426
- # Get depth from pcd and run TSDFusion
427
- to_write = ""
428
- # Read the log file
429
- if os.path.exists(osp.join(save_path, "logs.txt")):
430
- with open(osp.join(save_path, "logs.txt"), "r") as f_sub:
431
- to_write += f_sub.read()
432
-
433
- with open(osp.join(save_path, f"logs_all.txt"), "w") as f:
434
- log_data = to_write
435
- metrics = defaultdict(list)
436
- for line in log_data.strip().split("\n"):
437
- match = regex.match(line)
438
- if match:
439
- data = match.groupdict()
440
- # Exclude 'scene_id' from metrics as it's an identifier
441
- for key, value in data.items():
442
- if key != "scene_id":
443
- metrics[key].append(float(value))
444
- metrics["nc"].append(
445
- (float(data["nc1"]) + float(data["nc2"])) / 2
446
- )
447
- metrics["nc_med"].append(
448
- (float(data["nc1_med"]) + float(data["nc2_med"])) / 2
449
- )
450
- mean_metrics = {
451
- metric: sum(values) / len(values)
452
- for metric, values in metrics.items()
453
- }
454
-
455
- c_name = "mean"
456
- print_str = f"{c_name.ljust(20)}: "
457
- for m_name in mean_metrics:
458
- print_num = np.mean(mean_metrics[m_name])
459
- print_str = print_str + f"{m_name}: {print_num:.3f} | "
460
- print_str = print_str + "\n"
461
- # Summarize per-scene average inference time
462
- time_lines = []
463
- for sid, times in scene_infer_times.items():
464
- if len(times) > 0:
465
- time_lines.append(
466
- f"Idx: {sid}, Time_avg_ms: {np.mean(times):.2f}"
467
- )
468
- time_block = "\n".join(time_lines) + (
469
- "\n" if len(time_lines) > 0 else ""
470
- )
471
-
472
- f.write(to_write + time_block + print_str)
473
-
474
-
475
- from collections import defaultdict
476
- import re
477
-
478
- pattern = r"""
479
- Idx:\s*(?P<scene_id>[^,]+),\s*
480
- Acc:\s*(?P<acc>[^,]+),\s*
481
- Comp:\s*(?P<comp>[^,]+),\s*
482
- NC1:\s*(?P<nc1>[^,]+),\s*
483
- NC2:\s*(?P<nc2>[^,]+)\s*-\s*
484
- Acc_med:\s*(?P<acc_med>[^,]+),\s*
485
- Compc_med:\s*(?P<comp_med>[^,]+),\s*
486
- NC1c_med:\s*(?P<nc1_med>[^,]+),\s*
487
- NC2c_med:\s*(?P<nc2_med>[^,]+)
488
- """
489
-
490
- regex = re.compile(pattern, re.VERBOSE)
491
-
492
-
493
- if __name__ == "__main__":
494
- parser = get_args_parser()
495
- args = parser.parse_args()
496
-
497
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/eval/eval_custom.py DELETED
@@ -1,467 +0,0 @@
1
- import argparse
2
- from pathlib import Path
3
- import numpy as np
4
- import torch
5
- import os
6
- import sys
7
- import matplotlib.pyplot as plt
8
- from scipy.spatial.transform import Rotation
9
-
10
- # Ensure project root is in sys.path for absolute imports like `vggt.*`
11
- ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
12
- if ROOT_DIR not in sys.path:
13
- sys.path.insert(0, ROOT_DIR)
14
-
15
- from vggt.models.vggt import VGGT
16
- from vggt.utils.eval_utils import (
17
- load_poses,
18
- get_vgg_input_imgs,
19
- get_sorted_image_paths,
20
- build_frame_selection,
21
- load_images_rgb,
22
- infer_vggt_and_reconstruct,
23
- evaluate_scene_and_save,
24
- )
25
-
26
- # Import pose visualization libraries (optional EVO support)
27
- try:
28
- from evo.core.trajectory import PoseTrajectory3D
29
- import evo.tools.plot as plot
30
-
31
- EVO_AVAILABLE = True
32
- except ImportError:
33
- # EVO is optional; we have a matplotlib-based fallback
34
- EVO_AVAILABLE = False
35
-
36
-
37
def visualize_predicted_poses(
    all_cam_to_world_mat, frame_ids, output_scene_dir, scene_name="custom_dataset"
):
    """
    Plot the predicted camera trajectory (no GT comparison required).

    The translation column of every pose is projected onto the XZ plane and
    saved as ``predicted_trajectory.png`` inside ``output_scene_dir``. Any
    failure is reported on stdout (with a traceback) instead of being raised,
    so a plotting problem never aborts the surrounding evaluation run.

    Args:
        all_cam_to_world_mat: List of camera-to-world transform matrices;
            indexed as ``[:, :3, 3]``, so each needs at least 3 rows and 4 columns.
        frame_ids: List of frame IDs (currently unused by the plot).
        output_scene_dir: Output directory; created if it does not exist.
        scene_name: Scene name used in the plot title.
    """
    # Provide basic pose visualization even without EVO
    if not EVO_AVAILABLE:
        print("⚠️ EVO not installed; using basic matplotlib visualization")

    fig = None
    try:
        # Convert to numpy array
        poses_est = np.array(all_cam_to_world_mat)

        if len(poses_est) < 2:
            print("⚠️ Not enough poses to generate trajectory plot")
            return

        print("🎨 Generating pose trajectory visualization...")

        # Extract translation part
        positions = poses_est[:, :3, 3]  # shape: (N, 3)

        # Create figure - show XZ-plane projection only
        fig, ax = plt.subplots(1, 1, figsize=(10, 8))

        # XZ-plane projection
        ax.plot(
            positions[:, 0],
            positions[:, 2],
            "b-",
            linewidth=2,
            label="Predicted Trajectory",
        )
        ax.scatter(
            positions[0, 0], positions[0, 2], color="green", s=100, label="Start"
        )
        ax.scatter(positions[-1, 0], positions[-1, 2], color="red", s=100, label="End")
        ax.set_xlabel("X (m)")
        ax.set_ylabel("Z (m)")
        ax.set_title(f"{scene_name} - XZ-plane projection")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Save image; make sure the target directory exists first so the
        # plot is not lost when nothing else has created it yet.
        output_scene_dir = Path(output_scene_dir)
        output_scene_dir.mkdir(parents=True, exist_ok=True)
        pose_plot_path = output_scene_dir / "predicted_trajectory.png"
        fig.savefig(pose_plot_path, dpi=300, bbox_inches="tight")

        print(f"📊 Trajectory visualization saved: {pose_plot_path}")

    except Exception as e:
        print(f"⚠️ Failed to generate pose visualization: {e}")
        import traceback

        traceback.print_exc()
    finally:
        # Always release this specific figure, even when plotting or saving
        # failed; otherwise it stays registered with pyplot and leaks.
        if fig is not None:
            plt.close(fig)
99
-
100
-
101
# Metrics echoed to the console after an evaluation run.
_REPORTED_METRICS = (
    "chamfer_distance",
    "ate",
    "are",
    "rpe_rot",
    "rpe_trans",
    "inference_time_ms",
)

# Upper bound on the number of points written to the exported PLY file.
_MAX_EXPORT_POINTS = 100000


def _build_arg_parser():
    """Create the command-line parser for the custom-dataset script."""
    parser = argparse.ArgumentParser(
        description="Run FastVGGT evaluation on a Custom Dataset"
    )

    # Required: dataset path
    parser.add_argument(
        "--data_path",
        type=Path,
        required=True,
        help=(
            "Dataset path containing an 'images' subfolder "
            "(plus 'pose' and 'gt_ply' when --enable_evaluation is set)"
        ),
    )

    # Optional: enable evaluation
    parser.add_argument(
        "--enable_evaluation",
        action="store_true",
        help="Enable evaluation (requires pose and ply data)",
    )

    # Output path
    parser.add_argument(
        "--output_path",
        type=Path,
        default="./eval_results_custom",
        help="Output path for evaluation results",
    )

    # Model parameters
    # NOTE(review): the default checkpoint is an absolute path on one
    # developer machine; pass --ckpt_path explicitly everywhere else.
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
        help="Model checkpoint file path",
    )

    parser.add_argument("--merging", type=int, default=0, help="Merging parameter")

    # Processing parameters
    parser.add_argument(
        "--input_frame",
        type=int,
        default=200,
        help="Maximum number of frames to process per scene",
    )

    parser.add_argument(
        "--depth_conf_thresh",
        type=float,
        default=3.0,
        help="Depth confidence threshold to filter low-confidence depth values",
    )

    # Evaluation parameters (only used when evaluation is enabled)
    parser.add_argument(
        "--chamfer_max_dist",
        type=float,
        default=0.5,
        help="Maximum distance threshold used in Chamfer Distance computation",
    )

    parser.add_argument("--plot", action="store_true", help="Whether to generate plots")

    parser.add_argument(
        "--vis_attn_map",
        action="store_true",
        help="Visualize attention maps during inference",
    )
    return parser


def _write_ascii_ply(path, points, colors=None):
    """Write points (and optional per-point RGB colors) as an ASCII PLY file.

    Rows containing NaN/Inf are dropped *before* the header is written, so
    the ``element vertex`` count always equals the number of vertex lines.
    Returns the number of vertices written.
    """
    finite = np.isfinite(points).all(axis=1)
    points = points[finite]
    if colors is not None:
        colors = np.clip(colors[finite], 0, 255).astype(np.uint8)

    with open(path, "w") as f:
        f.write("ply\n")
        f.write("format ascii 1.0\n")
        f.write(f"element vertex {len(points)}\n")
        f.write("property float x\n")
        f.write("property float y\n")
        f.write("property float z\n")
        if colors is not None:
            f.write("property uchar red\n")
            f.write("property uchar green\n")
            f.write("property uchar blue\n")
        f.write("end_header\n")
        if colors is None:
            f.writelines(f"{p[0]:.6f} {p[1]:.6f} {p[2]:.6f}\n" for p in points)
        else:
            f.writelines(
                f"{p[0]:.6f} {p[1]:.6f} {p[2]:.6f} {int(c[0])} {int(c[1])} {int(c[2])}\n"
                for p, c in zip(points, colors)
            )
    return len(points)


def _save_reconstruction(
    output_scene_dir, frame_ids, all_cam_to_world_mat, all_world_points, all_point_colors
):
    """Save estimated poses and the merged point cloud (inference-only mode)."""
    print("💾 Saving reconstruction...")
    output_scene_dir.mkdir(parents=True, exist_ok=True)

    # Save camera poses: one "# Frame <id>" header followed by the matrix rows
    poses_output_path = output_scene_dir / "estimated_poses.txt"
    with open(poses_output_path, "w") as f:
        for frame_id, pose in zip(frame_ids, all_cam_to_world_mat):
            f.write(f"# Frame {frame_id}\n")
            for row in pose:
                f.write(" ".join(map(str, row)) + "\n")
            f.write("\n")

    # Save point cloud
    if all_world_points:
        points_output_path = output_scene_dir / "reconstructed_points.ply"

        try:
            # Merge all frames' point clouds and colors
            merged_point_cloud = np.vstack(all_world_points)
            merged_colors = (
                np.vstack(all_point_colors).astype(np.uint8)
                if all_point_colors is not None and len(all_point_colors) > 0
                else None
            )
            print(
                f"📊 Merged point clouds: {len(all_world_points)} frames, total {len(merged_point_cloud)} points"
            )

            # Drop invalid points up front so the sampling budget is spent
            # on points that will actually be written.
            finite = np.isfinite(merged_point_cloud).all(axis=1)
            merged_point_cloud = merged_point_cloud[finite]
            if merged_colors is not None:
                merged_colors = merged_colors[finite]

            # If too many points, randomly sample a fixed budget
            max_points = _MAX_EXPORT_POINTS
            if len(merged_point_cloud) > max_points:
                print(f"🔽 Too many points, randomly sampling {max_points} points...")
                indices = np.random.choice(
                    len(merged_point_cloud), size=max_points, replace=False
                )
                merged_point_cloud = merged_point_cloud[indices]
                if merged_colors is not None:
                    merged_colors = merged_colors[indices]
                print(f"✅ Sampling done, kept {len(merged_point_cloud)} points")

            # Save as PLY (with color)
            _write_ascii_ply(points_output_path, merged_point_cloud, merged_colors)

            print(f"💾 Point cloud saved to: {points_output_path}")

        except Exception as e:
            print(f"⚠️ Error saving point cloud: {e}")
            # If merge fails, try to log per-frame info
            print("🔍 Point cloud debug info:")
            for i, frame_points in enumerate(all_world_points):
                print(
                    f" Frame {i}: {frame_points.shape if hasattr(frame_points, 'shape') else type(frame_points)}"
                )
                if hasattr(frame_points, "shape") and len(frame_points.shape) >= 2:
                    print(
                        f" Shape: {frame_points.shape}, Dtype: {frame_points.dtype}"
                    )
                    if frame_points.shape[0] > 0:
                        print(
                            f" Range: x[{np.min(frame_points[:, 0]):.3f}, {np.max(frame_points[:, 0]):.3f}] "
                            f"y[{np.min(frame_points[:, 1]):.3f}, {np.max(frame_points[:, 1]):.3f}] "
                            f"z[{np.min(frame_points[:, 2]):.3f}, {np.max(frame_points[:, 2]):.3f}]"
                        )

    print(f"📁 Results saved to: {output_scene_dir}")


def main():
    """
    Evaluation script for a Custom Dataset.

    Runs VGGT inference on the images found in ``<data_path>/images``. With
    ``--enable_evaluation`` the result is scored against ``<data_path>/pose``
    and ``<data_path>/gt_ply``; otherwise the estimated poses and a
    (sub-sampled) colored point cloud are written to the output directory.
    Problems are reported on stdout and the function returns early instead
    of raising.
    """
    args = _build_arg_parser().parse_args()
    torch.manual_seed(33)

    # Check data path exists
    if not args.data_path.exists():
        print(f"❌ Error: Data path does not exist: {args.data_path}")
        return

    # Check required subdirectories
    color_dir = args.data_path / "images"
    pose_dir = args.data_path / "pose"
    gt_ply_dir = args.data_path / "gt_ply"

    if not color_dir.exists():
        print(f"❌ Error: color directory does not exist: {color_dir}")
        return

    print(f"📁 Dataset path: {args.data_path}")

    # If evaluation is enabled, check pose and gt_ply directories
    if args.enable_evaluation:
        if not pose_dir.exists():
            print(f"❌ Error: Evaluation requires pose directory: {pose_dir}")
            return
        if not gt_ply_dir.exists():
            print(f"❌ Error: Evaluation requires gt_ply directory: {gt_ply_dir}")
            return
        print("📊 Evaluation will use Ground Truth")
    else:
        print("🏃 Inference only, no evaluation")

    # Create output directory
    args.output_path.mkdir(parents=True, exist_ok=True)
    output_scene_dir = args.output_path / "custom_dataset"

    # Check if already processed
    if (output_scene_dir / "metrics.json").exists() and args.enable_evaluation:
        print(
            f"⚠️ Results already exist, skipping: {output_scene_dir / 'metrics.json'}"
        )
        return

    # Force use of bf16 dtype
    dtype = torch.bfloat16

    # Load VGGT model; strict=False tolerates missing/unexpected keys
    print(f"🔄 Loading model: {args.ckpt_path}")
    model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt, strict=False)
    model = model.cuda().eval()
    model = model.to(dtype)
    print("✅ Model loaded")

    # Load scene data
    image_paths = get_sorted_image_paths(color_dir)
    if len(image_paths) == 0:
        print(f"❌ Error: No images found in {color_dir}")
        return

    print(f"🖼️ Found {len(image_paths)} images")

    # Process pose data (if evaluation is enabled)
    poses_gt = None
    first_gt_pose = None
    available_pose_frame_ids = None
    c2ws = None

    if args.enable_evaluation:
        poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_dir)
        if (
            poses_gt is None
            or first_gt_pose is None
            or available_pose_frame_ids is None
        ):
            print("❌ Error: Failed to load pose data")
            return
        print(f"📐 Loaded {len(poses_gt)} poses")

    # Frame selection
    if args.enable_evaluation and available_pose_frame_ids is not None:
        # Use pose data for frame selection
        selected_frame_ids, selected_image_paths, selected_pose_indices = (
            build_frame_selection(
                image_paths, available_pose_frame_ids, args.input_frame
            )
        )
        c2ws = poses_gt[selected_pose_indices]
        image_paths = selected_image_paths
    else:
        # Simply take the first N frames
        num_frames = min(len(image_paths), args.input_frame)
        selected_frame_ids = list(range(num_frames))
        image_paths = image_paths[:num_frames]

    print(f"📋 Selected {len(image_paths)} frames for processing")

    try:
        # Load images
        print("🔄 Loading images...")
        images = load_images_rgb(image_paths)

        if not images or len(images) < 3:
            print("❌ Error: Not enough valid images (need at least 3)")
            return

        frame_ids = selected_frame_ids
        images_array = np.stack(images)
        vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
        print(f"📐 Image patch dimensions: {patch_width}x{patch_height}")

        # Update attention layer patch dimensions in the model
        model.update_patch_dimensions(patch_width, patch_height)

        # Inference + Reconstruction (extrinsics/intrinsics are not needed here)
        print("🚀 Start inference and reconstruction...")
        (
            _extrinsic_np,
            _intrinsic_np,
            all_world_points,
            all_point_colors,
            all_cam_to_world_mat,
            inference_time_ms,
        ) = infer_vggt_and_reconstruct(
            model, vgg_input, dtype, args.depth_conf_thresh, image_paths
        )
        print(f"⏱️ Inference time: {inference_time_ms:.2f}ms")

        # Check results
        if not all_cam_to_world_mat or not all_world_points:
            print("❌ Error: Failed to obtain valid camera poses or point clouds")
            return

        # Evaluation and saving
        if args.enable_evaluation:
            print("📊 Start evaluation...")
            metrics = evaluate_scene_and_save(
                "custom_dataset",
                c2ws,
                first_gt_pose,
                frame_ids,
                all_cam_to_world_mat,
                all_world_points,
                output_scene_dir,
                gt_ply_dir,
                args.chamfer_max_dist,
                inference_time_ms,
                args.plot,
            )
            if metrics is not None:
                print("📈 Evaluation results:")
                for key, value in metrics.items():
                    if key in _REPORTED_METRICS:
                        print(f" {key}: {float(value):.4f}")
        else:
            # Save reconstruction only, no evaluation
            _save_reconstruction(
                output_scene_dir,
                frame_ids,
                all_cam_to_world_mat,
                all_world_points,
                all_point_colors,
            )

        # Visualize predicted pose trajectory (both branches)
        if args.plot:
            visualize_predicted_poses(
                all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset"
            )

        print("🎉 Done!")

    except Exception as e:
        print(f"❌ Error occurred during processing: {e}")
        import traceback

        traceback.print_exc()


if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/eval/eval_scannet.py DELETED
@@ -1,208 +0,0 @@
1
- import argparse
2
- from pathlib import Path
3
- import numpy as np
4
- import torch
5
- import os
6
- import sys
7
-
8
- # Ensure project root is in sys.path for absolute imports like `vggt.*`
9
- ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
10
- if ROOT_DIR not in sys.path:
11
- sys.path.insert(0, ROOT_DIR)
12
-
13
- from vggt.models.vggt import VGGT
14
- from vggt.utils.eval_utils import (
15
- load_poses,
16
- get_vgg_input_imgs,
17
- get_sorted_image_paths,
18
- get_all_scenes,
19
- build_frame_selection,
20
- load_images_rgb,
21
- infer_vggt_and_reconstruct,
22
- evaluate_scene_and_save,
23
- compute_average_metrics_and_save,
24
- )
25
-
26
-
27
def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no".

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because every non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y"):
        return True
    if text in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", type=Path, default="/data/scannetv2/process_scannet"
    )
    parser.add_argument(
        "--gt_ply_dir",
        type=Path,
        default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
    )
    parser.add_argument("--output_path", type=Path, default="./eval_results")
    parser.add_argument("--merging", type=int, default=None)
    # Still accepts "--plot True" / "--plot False", but now parses the value
    # instead of treating every non-empty string as True.
    parser.add_argument("--plot", type=_str2bool, default=True)
    parser.add_argument(
        "--depth_conf_thresh",
        type=float,
        default=3.0,
        help="Depth confidence threshold for filtering low confidence depth values",
    )
    parser.add_argument(
        "--chamfer_max_dist",
        type=float,
        default=0.5,
        help="Maximum distance threshold in Chamfer Distance computation, distances exceeding this value will be clipped",
    )
    parser.add_argument(
        "--input_frame",
        type=int,
        default=200,
        help="Maximum number of frames selected for processing per scene",
    )
    parser.add_argument(
        "--num_scenes",
        type=int,
        default=50,
        help="Maximum number of scenes to evaluate",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="./ckpt/model_tracker_fixed_e20.pt",
        help="Path to the model checkpoint file",
    )
    parser.add_argument(
        "--vis_attn_map",
        action="store_true",
        help="Whether to visualize attention maps during inference",
    )
    args = parser.parse_args()
    torch.manual_seed(33)

    # Scene sampling
    scannet_scenes = get_all_scenes(args.data_dir, args.num_scenes)
    print(f"Evaluate {len(scannet_scenes)} scenes")

    all_scenes_metrics = {"scenes": {}, "average": {}}
    # Metrics copied from each scene's result into the summary.
    reported_metrics = (
        "chamfer_distance",
        "ate",
        "are",
        "rpe_rot",
        "rpe_trans",
        "inference_time_ms",
    )
    # Force use of bf16 data type
    dtype = torch.bfloat16
    # Load VGGT model; strict=False tolerates missing/unexpected keys
    model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt, strict=False)
    model = model.cuda().eval()
    model = model.to(dtype)

    # Process each scene
    for scene in scannet_scenes:
        scene_dir = args.data_dir / f"{scene}"
        output_scene_dir = args.output_path / f"input_frame_{args.input_frame}" / scene
        # Scenes that already have a metrics file are not recomputed
        if (output_scene_dir / "metrics.json").exists():
            continue

        # Load scene data
        images_dir = scene_dir / "color"
        pose_path = scene_dir / "pose"
        image_paths = get_sorted_image_paths(images_dir)
        poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_path)
        if (
            poses_gt is None
            or first_gt_pose is None
            or available_pose_frame_ids is None
        ):
            print(f"Skipping scene {scene}: no pose data")
            continue

        # Frame filtering
        selected_frame_ids, selected_image_paths, selected_pose_indices = (
            build_frame_selection(
                image_paths, available_pose_frame_ids, args.input_frame
            )
        )

        # Get corresponding poses
        c2ws = poses_gt[selected_pose_indices]
        image_paths = selected_image_paths

        if len(image_paths) == 0:
            print(f"No images found in {images_dir}")
            continue

        print("🚩Processing", scene, f"Found {len(image_paths)} images")

        try:
            # Load images
            images = load_images_rgb(image_paths)

            if not images or len(images) < 3:
                print(f"Skipping {scene}: insufficient valid images")
                continue

            frame_ids = selected_frame_ids
            images_array = np.stack(images)
            vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
            print(f"Patch dimensions: {patch_width}x{patch_height}")

            # Update model attention layers with dynamic patch dimensions
            model.update_patch_dimensions(patch_width, patch_height)

            # Inference + Reconstruction (extrinsics/intrinsics/colors unused here)
            (
                _extrinsic_np,
                _intrinsic_np,
                all_world_points,
                _all_point_colors,
                all_cam_to_world_mat,
                inference_time_ms,
            ) = infer_vggt_and_reconstruct(
                model, vgg_input, dtype, args.depth_conf_thresh, image_paths
            )
            print(f"Inference time: {inference_time_ms:.2f}ms")

            # Process results
            if not all_cam_to_world_mat or not all_world_points:
                print(
                    f"Skipping {scene}: failed to obtain valid camera poses or point clouds"
                )
                continue

            # Evaluate and save
            metrics = evaluate_scene_and_save(
                scene,
                c2ws,
                first_gt_pose,
                frame_ids,
                all_cam_to_world_mat,
                all_world_points,
                output_scene_dir,
                args.gt_ply_dir,
                args.chamfer_max_dist,
                inference_time_ms,
                args.plot,
            )
            if metrics is not None:
                all_scenes_metrics["scenes"][scene] = {
                    key: float(value)
                    for key, value in metrics.items()
                    if key in reported_metrics
                }
                print("Complete metrics", all_scenes_metrics["scenes"][scene])

        except Exception as e:
            # One failing scene must not abort the whole benchmark run
            print(f"Error processing scene {scene}: {e}")
            import traceback

            traceback.print_exc()

    # Summarize average metrics and save
    compute_average_metrics_and_save(
        all_scenes_metrics,
        args.output_path,
        args.input_frame,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/eval/utils.py DELETED
@@ -1,142 +0,0 @@
1
- import numpy as np
2
- from scipy.spatial import cKDTree as KDTree
3
-
4
-
5
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Back-project a depth map into 3D points expressed in the camera frame.

    Args:
        - depthmap (HxW array): per-pixel depth.
        - camera_intrinsics: a 3x3 matrix; its skew entries [0, 1] and [1, 0]
          must be zero (asserted).
        - pseudo_focal (optional HxW array): per-pixel focal length used for
          both axes in place of the focal lengths from the intrinsics.
    Returns:
        pointmap of camera-frame coordinates (HxWx3 float32 array), and a mask
        specifying valid pixels (depth > 0).
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    height, width = depthmap.shape

    # Only skew-free pinhole intrinsics are supported.
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0

    if pseudo_focal is not None:
        assert pseudo_focal.shape == (height, width)
        focal_u = focal_v = pseudo_focal
    else:
        focal_u = camera_intrinsics[0, 0]
        focal_v = camera_intrinsics[1, 1]
    center_u = camera_intrinsics[0, 2]
    center_v = camera_intrinsics[1, 2]

    # Pixel grid: rows index v (height), columns index u (width).
    grid_v, grid_u = np.indices((height, width))

    depth = depthmap
    cam_x = (grid_u - center_u) * depth / focal_u
    cam_y = (grid_v - center_v) * depth / focal_v
    points_cam = np.stack((cam_x, cam_y, depth), axis=-1).astype(np.float32)

    return points_cam, depthmap > 0.0
35
-
36
-
37
def depthmap_to_absolute_camera_coordinates(
    depthmap, camera_intrinsics, camera_pose, **kw
):
    """
    Back-project a depth map and move the points into world coordinates.

    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a cam2world matrix with at least 3 rows and 4 columns
          (rotation in [:3, :3], translation in [:3, 3]), or None to keep the
          points in the camera frame.
        - **kw: accepted but not used by this function.
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
    """
    points_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    # Without a pose there is nothing to transform.
    if camera_pose is None:
        return points_cam, valid_mask

    rotation = camera_pose[:3, :3]
    translation = camera_pose[:3, 3]

    # Rotate every pixel's point, then shift by the camera position.
    points_world = (
        np.einsum("ik, vuk -> vui", rotation, points_cam) + translation[None, None, :]
    )
    return points_world, valid_mask
61
-
62
-
63
def completion_ratio(gt_points, rec_points, dist_th=0.05):
    """Return the fraction of GT points whose nearest reconstructed point is closer than ``dist_th``."""
    rec_tree = KDTree(rec_points)
    nearest_dist, _ = rec_tree.query(gt_points)
    covered = nearest_dist < dist_th
    return np.mean(covered.astype(np.float32))
68
-
69
-
70
def accuracy(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """
    Measure how close the reconstructed points are to the ground truth.

    For every reconstructed point the distance to its nearest GT point is
    taken. Returns ``(mean, median)`` of those distances; when both normal
    arrays are given, the mean and median absolute dot product between each
    reconstructed normal and the normal of its nearest GT point are appended,
    giving a 4-tuple.
    """
    gt_tree = KDTree(gt_points)
    nearest_dist, nearest_gt = gt_tree.query(rec_points, workers=-1)
    mean_dist = np.mean(nearest_dist)
    median_dist = np.median(nearest_dist)

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    # The sign of a normal is ambiguous, hence the absolute value.
    agreement = np.abs(np.sum(gt_normals[nearest_gt] * rec_normals, axis=-1))
    return mean_dist, median_dist, np.mean(agreement), np.median(agreement)
84
-
85
-
86
def completion(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """
    Measure how well the reconstruction covers the ground truth.

    For every GT point the distance to its nearest reconstructed point is
    taken. Returns ``(mean, median)`` of those distances; when both normal
    arrays are given, the mean and median absolute dot product between each
    GT normal and the normal of its nearest reconstructed point are appended,
    giving a 4-tuple.
    """
    rec_tree = KDTree(rec_points)
    nearest_dist, nearest_rec = rec_tree.query(gt_points, workers=-1)
    mean_dist = np.mean(nearest_dist)
    median_dist = np.median(nearest_dist)

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    # The sign of a normal is ambiguous, hence the absolute value.
    agreement = np.abs(np.sum(gt_normals * rec_normals[nearest_rec], axis=-1))
    return mean_dist, median_dist, np.mean(agreement), np.median(agreement)
99
-
100
-
101
def compute_iou(pred_vox, target_vox):
    """
    Compute the intersection-over-union of the occupied cells of two voxel grids.

    Args:
        pred_vox: predicted voxel grid; must provide ``get_voxels()`` whose
            items expose a ``grid_index`` attribute.
        target_vox: ground-truth voxel grid with the same interface.
    Returns:
        IoU in [0, 1]. When both grids are empty the union is empty and 0.0
        is returned instead of raising ``ZeroDivisionError``.
    """

    def _occupied_cells(grid):
        # Tuples are hashable, so the indices can be compared with set operations.
        return {tuple(np.round(voxel.grid_index, 4)) for voxel in grid.get_voxels()}

    pred_cells = _occupied_cells(pred_vox)
    target_cells = _occupied_cells(target_vox)

    union = pred_cells | target_cells
    if not union:
        # No occupied cell on either side: nothing overlaps.
        return 0.0

    return len(pred_cells & target_cells) / len(union)
117
-
118
-
119
def colmap_to_opencv_intrinsics(K):
    """
    Convert camera intrinsics from the Colmap to the OpenCV pixel convention.

    The center of the top-left pixel is by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    so the principal point is shifted by -0.5 on both axes. The input matrix
    is left untouched; a modified copy is returned.
    """
    shifted = K.copy()
    for axis in (0, 1):
        # Column 2 holds the principal point (cx in row 0, cy in row 1).
        shifted[axis, 2] -= 0.5
    return shifted
130
-
131
-
132
def opencv_to_colmap_intrinsics(K):
    """
    Convert camera intrinsics from the OpenCV to the Colmap pixel convention.

    The center of the top-left pixel is by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    so the principal point is shifted by +0.5 on both axes. The input matrix
    is left untouched; a modified copy is returned.
    """
    shifted = K.copy()
    for axis in (0, 1):
        # Column 2 holds the principal point (cx in row 0, cy in row 1).
        shifted[axis, 2] += 0.5
    return shifted
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/merging/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from . import merge
2
-
3
- __all__ = ["merge"]
 
 
 
 
FastVGGT/merging/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (187 Bytes)
 
FastVGGT/merging/__pycache__/merge.cpython-310.pyc DELETED
Binary file (7.54 kB)
 
FastVGGT/merging/merge.py DELETED
@@ -1,370 +0,0 @@
1
- import torch
2
- from typing import Tuple, Callable, Optional, Union
3
-
4
-
5
@torch.jit.script
def fast_similarity_chunks(
    a: torch.Tensor, b_transposed: torch.Tensor, chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    For every row of ``a`` find the best-scoring column of ``b_transposed``.

    The score matrix ``a @ b_transposed`` is never materialized as a whole:
    rows of ``a`` are processed ``chunk_size`` at a time and only the per-row
    maximum and its position are kept.

    Args:
        a: Tensor unpacked as ``[B, num_src, C]``.
        b_transposed: Right-hand operand of the batched matmul
            (presumably ``[B, C, num_dst]`` -- confirm against the caller).
        chunk_size: Number of rows of ``a`` scored per iteration. Values
            below 1 are treated as 1, so an empty ``a`` combined with a
            computed chunk size of 0 no longer produces an invalid range step.

    Returns:
        ``(node_max, node_idx)``, both ``[B, num_src]``: the maximum score per
        row (cast back to the dtype of ``a``) and the index of that maximum
        along the last dimension of the score matrix (``torch.long``).
    """
    B, num_src, C = a.shape
    original_dtype = a.dtype

    # Guard against a zero/negative step, which is not a valid range() stride.
    step = max(chunk_size, 1)

    # Convert to bf16 for computation to improve performance and reduce memory usage
    a_bf16 = a.to(torch.bfloat16)
    b_transposed_bf16 = b_transposed.to(torch.bfloat16)
    node_max = torch.empty(B, num_src, device=a.device, dtype=original_dtype)
    node_idx = torch.empty(B, num_src, device=a.device, dtype=torch.long)

    # Process in chunks
    for i in range(0, num_src, step):
        end_i = min(i + step, num_src)
        a_chunk = a_bf16[:, i:end_i, :]  # [B, rows in this chunk, C]
        scores_chunk = torch.bmm(a_chunk, b_transposed_bf16)
        chunk_max_bf16, chunk_idx = torch.max(scores_chunk, dim=2)
        # Scores are computed in bf16; only the reduced maxima are cast back.
        node_max[:, i:end_i] = chunk_max_bf16.to(original_dtype)
        node_idx[:, i:end_i] = chunk_idx
    return node_max, node_idx
29
-
30
-
31
def do_nothing(
    x: torch.Tensor,
    extra_tensors=None,
    extra_tensors_2=None,
) -> Union[
    torch.Tensor,
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
    """
    Identity stand-in for the merge/unmerge callables.

    Returns the arguments unchanged, with the same arity rules as a real
    merge: ``x`` alone, ``(x, extra_tensors)`` or
    ``(x, extra_tensors, extra_tensors_2)``. ``extra_tensors_2`` is only
    returned when ``extra_tensors`` is given as well.
    """
    if extra_tensors is None:
        return x
    if extra_tensors_2 is None:
        return x, extra_tensors
    return x, extra_tensors, extra_tensors_2
46
-
47
-
48
def token_merge_bipartite2d(
    metric: torch.Tensor,
    w: int,
    h: int,
    sx: int,
    sy: int,
    r: int,
    no_rand: bool = False,
    generator: Optional[torch.Generator] = None,
    enable_protection: bool = False,
) -> Tuple[Callable, Callable]:
    """
    Divide tokens into source (src) and destination (dst) groups, and merge r tokens from src to dst.

    The token sequence is treated as ``num_imgs`` images of ``w * h + 5``
    tokens each (5 leading special tokens followed by the ``h x w`` grid).
    All tokens of the first image and the 5 special tokens of every other
    image are always dst. In the remaining grids one dst token is drawn at
    random inside each (sy, sx) block unless ``no_rand`` is set.

    Args:
        - metric [B, N, C]: Tensor for similarity computation, B=batch size, N=token count, C=feature dimension
        - w: Image width in tokens
        - h: Image height in tokens
        - sx: dst stride in x dimension, must divide w evenly
        - sy: dst stride in y dimension, must divide h evenly
        - r: Number of tokens to remove through merging; ``r <= 0`` disables merging
        - no_rand: If True, no dst token is drawn inside the grids of the
          images after the first one (all of their grid tokens stay src)
        - generator: Random number generator forwarded to ``torch.randint`` when no_rand is False
        - enable_protection: If True, ``int(N * 0.1)`` evenly spaced token
          positions are excluded from merging and carried through separately.
          NOTE(review): with protection enabled the merge plan is derived from
          batch element 0 only and reused for every batch element.

    Returns:
        - (merge, unmerge): Two functions for merging tokens and restoring pre-merge state

    Raises:
        AssertionError: if N is not a multiple of ``w * h + 5``.
    """
    B, N, _ = metric.shape  # Batch size B, total tokens N
    if r <= 0:
        return do_nothing, do_nothing

    gather = torch.gather

    tokens_per_img = w * h + 5
    num_imgs = N // tokens_per_img
    assert tokens_per_img * num_imgs == N, "Token count doesn't match (w*h+5)*num_imgs"

    with torch.no_grad():
        if enable_protection:
            num_protected = int(N * 0.1)
            # For N < 10 num_protected is 0; avoid dividing by it. The slice
            # below then simply yields an empty set of protected positions.
            step = max(1, N // max(1, num_protected))
            protected_indices = torch.arange(0, N, step, device=metric.device)[
                :num_protected
            ]
        else:
            protected_indices = None

        # Global idx_buffer_seq of length N; -1 indicates dst, 0 indicates src
        idx_buffer_seq = torch.zeros(N, device=metric.device, dtype=torch.int64)
        hsy, wsx = h // sy, w // sx  # Number of blocks within each image

        # Mark first image entirely as dst
        if num_imgs > 0:
            idx_buffer_seq[:tokens_per_img] = -1

        # Process other images
        if num_imgs > 1:
            # The 5 special tokens at the start of every later image are dst.
            cls_indices = (
                torch.arange(1, num_imgs, device=metric.device) * tokens_per_img
            )
            cls_indices = cls_indices[:, None] + torch.arange(5, device=metric.device)
            idx_buffer_seq[cls_indices.flatten()] = -1
            effective_h = min(hsy * sy, h)
            effective_w = min(wsx * sx, w)
            effective_grid_size = effective_h * effective_w

            if no_rand:
                # All-zero pattern: every covered grid token is marked src.
                base_pattern = torch.zeros(
                    effective_grid_size, device=metric.device, dtype=torch.int64
                )
                grid_starts = (
                    torch.arange(1, num_imgs, device=metric.device) * tokens_per_img + 5
                )
                grid_indices = grid_starts[:, None] + torch.arange(
                    effective_grid_size, device=metric.device
                )
                idx_buffer_seq[grid_indices.flatten()] = base_pattern.repeat(
                    num_imgs - 1
                )
            else:
                total_other_imgs = num_imgs - 1
                # One random position inside each (sy, sx) block becomes dst.
                all_rand_idx = torch.randint(
                    sy * sx,
                    size=(total_other_imgs, hsy, wsx),
                    device=metric.device,
                    generator=generator,
                )

                scatter_src = -torch.ones(
                    total_other_imgs, hsy, wsx, device=metric.device, dtype=torch.int64
                )

                idx_buffer_batch = torch.zeros(
                    total_other_imgs,
                    hsy,
                    wsx,
                    sy * sx,
                    device=metric.device,
                    dtype=torch.int64,
                )
                idx_buffer_batch.scatter_(
                    dim=3,
                    index=all_rand_idx.unsqueeze(-1),
                    src=scatter_src.unsqueeze(-1),
                )

                # Rearrange the per-block layout into an image-shaped grid.
                idx_buffer_batch = (
                    idx_buffer_batch.view(total_other_imgs, hsy, wsx, sy, sx)
                    .transpose(2, 3)
                    .reshape(total_other_imgs, hsy * sy, wsx * sx)
                )

                # NOTE(review): the flattened grid is written contiguously, which
                # matches the row-major token layout only when effective_w == w,
                # i.e. when sx divides w as the docstring requires.
                for i in range(total_other_imgs):
                    img_idx = i + 1
                    grid_start = img_idx * tokens_per_img + 5
                    flat_view = idx_buffer_batch[
                        i, :effective_h, :effective_w
                    ].flatten()
                    idx_buffer_seq[grid_start : grid_start + effective_grid_size] = (
                        flat_view
                    )

        # Sorting puts the dst markers (-1) in front of the src markers (0).
        rand_idx = idx_buffer_seq.reshape(1, -1, 1).argsort(dim=1)
        num_dst_orig = int((idx_buffer_seq == -1).sum())

        # src and dst token positions in the original sequence
        a_idx = rand_idx[:, num_dst_orig:, :]
        b_idx = rand_idx[:, :num_dst_orig, :]

        if enable_protection:
            protected_idx = protected_indices.unsqueeze(0).unsqueeze(-1)
            num_protected_actual = protected_idx.shape[1]
        else:
            protected_idx = None
            num_protected_actual = 0

        num_src = a_idx.shape[1]
        num_dst = b_idx.shape[1]

        def split(x):
            """Gather (src, dst) from x, plus the protected tokens when enabled."""
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, num_src, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            if enable_protection:
                protected = gather(
                    x, dim=1, index=protected_idx.expand(B, num_protected_actual, C)
                )
                return src, dst, protected
            return src, dst

        # Compute cosine similarity (normalize first then dot product)
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)[:2]

        r = min(a.shape[1], r)
        # Keep the chunk size at >= 1 even when there is no src token at all
        # (e.g. a single image, where every token is dst).
        chunk_size = max(1, min(5000, a.shape[1]))

        b_transposed = b.transpose(-1, -2)
        node_max, node_idx = fast_similarity_chunks(a, b_transposed, chunk_size)
        # src tokens ordered from most to least similar to their best dst.
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        # If protection is enabled, filter out protected tokens to ensure they are not merged
        if enable_protection:
            src_indices = a_idx[0, :, 0]
            protected_mask_src = torch.isin(src_indices, protected_indices)
            edge_flat = edge_idx[0, :, 0]
            valid_mask = ~protected_mask_src[edge_flat]
            valid_edges = edge_flat[valid_mask]

            r_actual = min(r, valid_edges.shape[0])

            unm_idx = valid_edges[r_actual:].unsqueeze(0).unsqueeze(-1)
            src_idx = valid_edges[:r_actual].unsqueeze(0).unsqueeze(-1)
        else:
            unm_idx = edge_idx[..., r:, :]
            src_idx = edge_idx[..., :r, :]

        # Get dst token indices corresponding to each src token to be merged
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    unm_len = unm_idx.shape[1]
    src_len = src_idx.shape[1]

    def merge(
        x: torch.Tensor,
        mode: str = "mean",
        extra_tensors=None,
        extra_tensors_2=None,
    ) -> Union[
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        """
        Merge the selected src tokens of ``x`` into their dst tokens.

        The result is laid out as ``[unmerged src, dst, protected]`` along
        dim 1 (``protected`` only with protection enabled). ``mode`` is passed
        to ``scatter_reduce`` as the reduction. The optional extra tensors are
        merged with the same plan; ``extra_tensors_2`` is only returned when
        ``extra_tensors`` is given as well.
        """

        def merge_one(t):
            parts = split(t)
            src_t, dst_t = parts[0], parts[1]
            ch = t.shape[-1]
            kept = gather(src_t, dim=-2, index=unm_idx.expand(B, unm_len, ch))
            moved = gather(src_t, dim=-2, index=src_idx.expand(B, src_len, ch))
            dst_t = dst_t.scatter_reduce(
                -2, dst_idx.expand(B, src_len, ch), moved, reduce=mode
            )
            # parts[2:] holds the protected tokens, or nothing at all.
            return torch.cat([kept, dst_t, *parts[2:]], dim=1)

        main_result = merge_one(x)
        if extra_tensors is None:
            return main_result
        merged_extra_1 = merge_one(extra_tensors)
        if extra_tensors_2 is None:
            return main_result, merged_extra_1
        return main_result, merged_extra_1, merge_one(extra_tensors_2)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """
        Scatter a merged tensor back to the original N token positions.

        Merged src positions receive a copy of the dst token they were merged
        into; positions are written in the order dst, unmerged src, merged
        src, protected.
        """
        unm = x[..., :unm_len, :]
        dst = x[..., unm_len : unm_len + num_dst, :]

        if enable_protection:
            protected = x[
                ..., unm_len + num_dst : unm_len + num_dst + num_protected_actual, :
            ]

        _, _, c = unm.shape
        src = gather(dst, dim=-2, index=dst_idx.expand(B, src_len, c))
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(
            dim=-2,
            index=gather(
                a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx
            ).expand(B, unm_len, c),
            src=unm,
        )

        out.scatter_(
            dim=-2,
            index=gather(
                a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx
            ).expand(B, src_len, c),
            src=src,
        )

        if enable_protection:
            out.scatter_(
                dim=-2,
                index=protected_idx.expand(B, num_protected_actual, c),
                src=protected,
            )

        return out

    return merge, unmerge
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/requirements.txt DELETED
@@ -1,15 +0,0 @@
1
- torch==2.3.1
2
- torchvision==0.18.1
3
- numpy==1.26.1
4
- Pillow
5
- huggingface_hub
6
- einops
7
- safetensors
8
- evo
9
- open3d
10
- matplotlib
11
- scipy
12
- opencv-python
13
- scikit-image
14
- tqdm
15
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/vggt/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
 
 
 
 
 
 
FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (132 Bytes)
 
FastVGGT/vggt/dependency/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
 
 
 
 
 
 
FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (143 Bytes)
 
FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc DELETED
Binary file (1.39 kB)
 
FastVGGT/vggt/dependency/distortion.py DELETED
@@ -1,54 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import numpy as np
8
- import torch
9
-
10
-
11
- def apply_distortion(points, distortion_params):
12
- """
13
- Apply distortion to normalized camera coordinates.
14
-
15
- Args:
16
- points: Array of normalized camera coordinates
17
- distortion_params: Distortion parameters
18
-
19
- Returns:
20
- Distorted coordinates
21
- """
22
- # Simple passthrough for now - implement actual distortion if needed
23
- return points
24
-
25
-
26
- def iterative_undistortion(points, distortion_params, max_iter=10):
27
- """
28
- Remove distortion from normalized camera coordinates using iterative method.
29
-
30
- Args:
31
- points: Array of distorted normalized camera coordinates
32
- distortion_params: Distortion parameters
33
- max_iter: Maximum number of iterations
34
-
35
- Returns:
36
- Undistorted coordinates
37
- """
38
- # Simple passthrough for now - implement actual undistortion if needed
39
- return points
40
-
41
-
42
- def single_undistortion(points, distortion_params):
43
- """
44
- Remove distortion from normalized camera coordinates using single step.
45
-
46
- Args:
47
- points: Array of distorted normalized camera coordinates
48
- distortion_params: Distortion parameters
49
-
50
- Returns:
51
- Undistorted coordinates
52
- """
53
- # Simple passthrough for now - implement actual undistortion if needed
54
- return points
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc DELETED
Binary file (4.24 kB)
 
FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc DELETED
Binary file (12.8 kB)
 
FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc DELETED
Binary file (3.1 kB)
 
FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc DELETED
Binary file (3.41 kB)
 
FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc DELETED
Binary file (3.18 kB)
 
FastVGGT/vggt/heads/camera_head.py DELETED
@@ -1,149 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import math
8
- import numpy as np
9
-
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
-
14
- from vggt.layers import Mlp
15
- from vggt.layers.block import Block
16
- from vggt.heads.head_act import activate_pose
17
-
18
-
19
- class CameraHead(nn.Module):
20
- """
21
- CameraHead predicts camera parameters from token representations using iterative refinement.
22
-
23
- It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
24
- """
25
-
26
- def __init__(
27
- self,
28
- dim_in: int = 2048,
29
- trunk_depth: int = 4,
30
- pose_encoding_type: str = "absT_quaR_FoV",
31
- num_heads: int = 16,
32
- mlp_ratio: int = 4,
33
- init_values: float = 0.01,
34
- trans_act: str = "linear",
35
- quat_act: str = "linear",
36
- fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
37
- ):
38
- super().__init__()
39
-
40
- if pose_encoding_type == "absT_quaR_FoV":
41
- self.target_dim = 9
42
- else:
43
- raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
44
-
45
- self.trans_act = trans_act
46
- self.quat_act = quat_act
47
- self.fl_act = fl_act
48
- self.trunk_depth = trunk_depth
49
-
50
- # Build the trunk using a sequence of transformer blocks.
51
- self.trunk = nn.Sequential(
52
- *[
53
- Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
54
- for _ in range(trunk_depth)
55
- ]
56
- )
57
-
58
- # Normalizations for camera token and trunk output.
59
- self.token_norm = nn.LayerNorm(dim_in)
60
- self.trunk_norm = nn.LayerNorm(dim_in)
61
-
62
- # Learnable empty camera pose token.
63
- self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
64
- self.embed_pose = nn.Linear(self.target_dim, dim_in)
65
-
66
- # Module for producing modulation parameters: shift, scale, and a gate.
67
- self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
68
-
69
- # Adaptive layer normalization without affine parameters.
70
- self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
71
- self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
72
-
73
- def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
74
- """
75
- Forward pass to predict camera parameters.
76
-
77
- Args:
78
- aggregated_tokens_list (list): List of token tensors from the network;
79
- the last tensor is used for prediction.
80
- num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
81
-
82
- Returns:
83
- list: A list of predicted camera encodings (post-activation) from each iteration.
84
- """
85
- # Use tokens from the last block for camera prediction.
86
- tokens = aggregated_tokens_list[-1]
87
-
88
- # Extract the camera tokens
89
- pose_tokens = tokens[:, :, 0]
90
- pose_tokens = self.token_norm(pose_tokens)
91
-
92
- pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
93
- return pred_pose_enc_list
94
-
95
- def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
96
- """
97
- Iteratively refine camera pose predictions.
98
-
99
- Args:
100
- pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
101
- num_iterations (int): Number of refinement iterations.
102
-
103
- Returns:
104
- list: List of activated camera encodings from each iteration.
105
- """
106
- B, S, C = pose_tokens.shape # S is expected to be 1.
107
- pred_pose_enc = None
108
- pred_pose_enc_list = []
109
-
110
- for _ in range(num_iterations):
111
- # Use a learned empty pose for the first iteration.
112
- if pred_pose_enc is None:
113
- module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
114
- else:
115
- # Detach the previous prediction to avoid backprop through time.
116
- pred_pose_enc = pred_pose_enc.detach()
117
- module_input = self.embed_pose(pred_pose_enc)
118
-
119
- # Generate modulation parameters and split them into shift, scale, and gate components.
120
- shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
121
-
122
- # Adaptive layer normalization and modulation.
123
- pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
124
- pose_tokens_modulated = pose_tokens_modulated + pose_tokens
125
-
126
- pose_tokens_modulated = self.trunk(pose_tokens_modulated)
127
- # Compute the delta update for the pose encoding.
128
- pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
129
-
130
- if pred_pose_enc is None:
131
- pred_pose_enc = pred_pose_enc_delta
132
- else:
133
- pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
134
-
135
- # Apply final activation functions for translation, quaternion, and field-of-view.
136
- activated_pose = activate_pose(
137
- pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
138
- )
139
- pred_pose_enc_list.append(activated_pose)
140
-
141
- return pred_pose_enc_list
142
-
143
-
144
- def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
145
- """
146
- Modulate the input tensor using scaling and shifting parameters.
147
- """
148
- # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
149
- return x * (1 + scale) + shift
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/vggt/heads/dpt_head.py DELETED
@@ -1,598 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
- # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
-
10
-
11
- import os
12
- from typing import List, Dict, Tuple, Union, Optional
13
-
14
- import torch
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
- from .head_act import activate_head
18
- from .utils import create_uv_grid, position_grid_to_embed
19
-
20
-
21
class DPTHead(nn.Module):
    """
    DPT head for dense prediction tasks.

    This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413). Patch tokens taken from several transformer layers are projected
    to feature maps, resized to different scales, fused by a stack of refinement blocks and decoded into
    dense predictions.

    Args:
        dim_in (int): Input dimension (channels) of the tokens.
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
            ``None`` (the default) selects ``[256, 512, 1024, 1024]``.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
            ``None`` (the default) selects ``[0, 1, 2, 3]``. Exactly four indices are consumed by
            ``scratch_forward``.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers and
            activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: Optional[List[int]] = None,
        intermediate_layer_idx: Optional[List[int]] = None,
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super().__init__()

        # Resolve the defaults here instead of in the signature so that no list
        # object is shared between instances.
        if out_channels is None:
            out_channels = [256, 512, 1024, 1024]
        if intermediate_layer_idx is None:
            intermediate_layer_idx = [0, 1, 2, 3]

        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        # Store a private copy so later mutation by the caller cannot affect this head.
        self.intermediate_layer_idx = list(intermediate_layer_idx)

        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers: x4 upsampling, x2 upsampling, identity, x2 downsampling.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0],
                    out_channels=out_channels[0],
                    kernel_size=4,
                    stride=4,
                    padding=0,
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1],
                    out_channels=out_channels[1],
                    kernel_size=2,
                    stride=2,
                    padding=0,
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3],
                    out_channels=out_channels[3],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            ]
        )

        self.scratch = _make_scratch(out_channels, features, expand=False)

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = nn.Identity()  # Use Identity instead of None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        # The deepest block has no skip input, so it is built without the residual unit.
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
            )
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1,
                head_features_1 // 2,
                kernel_size=3,
                stride=1,
                padding=1,
            )
            conv2_in_channels = head_features_1 // 2

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(
                    conv2_in_channels,
                    head_features_2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.ReLU(inplace=True),
                nn.Conv2d(
                    head_features_2, output_dim, kernel_size=1, stride=1, padding=0
                ),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or not smaller than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: feature maps with shape [B, S, C, H_out, W_out].
                - Otherwise: the (predictions, confidence) pair produced by ``activate_head``,
                  with the leading dimension unfolded into [B, S, ...].
                H_out and W_out are the patch-aligned image size divided by ``down_ratio``.

        Raises:
            ValueError: If chunking is required and ``frames_chunk_size`` is not positive.
        """
        B, S, _, H, W = images.shape

        # If frames_chunk_size is not specified or not smaller than S, process all frames at once.
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Otherwise, process frames in chunks to manage memory usage. Validate explicitly:
        # an assert would be stripped under ``python -O``.
        if frames_chunk_size <= 0:
            raise ValueError(
                f"frames_chunk_size must be a positive integer, got {frames_chunk_size}"
            )

        all_preds = []
        all_conf = []

        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)

            chunk_output = self._forward_impl(
                aggregated_tokens_list,
                images,
                patch_start_idx,
                frames_start_idx,
                frames_end_idx,
            )
            if self.feature_only:
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = chunk_output
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)

        # Concatenate results along the sequence dimension.
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: Optional[int] = None,
        frames_end_idx: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Implementation of the forward pass through the DPT head.

        Processes either the whole sequence or, when both frame indices are given,
        the frames in ``[frames_start_idx, frames_end_idx)``.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index (exclusive) for frames to process.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        process_chunk = frames_start_idx is not None and frames_end_idx is not None
        if process_chunk:
            images = images[:, frames_start_idx:frames_end_idx].contiguous()

        # From here on S is the number of frames actually being processed.
        B, S, _, H, W = images.shape

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        out = []
        for dpt_idx, layer_idx in enumerate(self.intermediate_layer_idx):
            # Drop the non-patch tokens that precede patch_start_idx.
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            # Select frames if processing a chunk.
            if process_chunk:
                x = x[:, frames_start_idx:frames_end_idx]

            # Fold frames into the batch dimension: [B * S, tokens, channels].
            x = x.reshape(B * S, -1, x.shape[-1])

            x = self.norm(x)

            # Tokens -> channel-first feature map on the patch grid.
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (
                int(patch_h * self.patch_size / self.down_ratio),
                int(patch_w * self.patch_size / self.down_ratio),
            ),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out.view(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(
            out, activation=self.activation, conf_activation=self.conf_activation
        )

        # Unfold the frame dimension that was merged into the batch above.
        preds = preds.view(B, S, *preds.shape[1:])
        conf = conf.view(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(
        self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
    ) -> torch.Tensor:
        """
        Add a positional embedding, scaled by ``ratio``, to the feature map ``x``.

        The embedding is built from a UV grid sized like the last two dimensions of ``x``,
        using the image aspect ratio ``W / H``, and is broadcast over the first dimension.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(
            patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
        )
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        # Move the embedding channels first and expand to the batch size of x.
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): Exactly four feature maps, one per intermediate layer.

        Returns:
            Tensor: Fused feature map after ``output_conv1``.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse from the last layer to the first; each step resizes to the spatial size of
        # the next map. Intermediate references are dropped as soon as they are consumed.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
344
-
345
-
346
- ################################################################################
347
- # Modules
348
- ################################################################################
349
-
350
-
351
def _make_fusion_block(
    features: int,
    size: Optional[int] = None,
    has_residual: bool = True,
    groups: int = 1,
) -> nn.Module:
    """
    Build a ``FeatureFusionBlock`` with the fixed settings used by the DPT head.

    The block always uses an in-place ReLU, no deconvolution, no batch norm, no channel
    expansion and ``align_corners=True``; only ``size``, ``has_residual`` and ``groups``
    are configurable.
    """
    fixed_options = dict(
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
    )
    relu = nn.ReLU(inplace=True)
    return FeatureFusionBlock(
        features,
        relu,
        size=size,
        has_residual=has_residual,
        groups=groups,
        **fixed_options,
    )
368
-
369
-
370
- def _make_scratch(
371
- in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
372
- ) -> nn.Module:
373
- scratch = nn.Module()
374
- out_shape1 = out_shape
375
- out_shape2 = out_shape
376
- out_shape3 = out_shape
377
- if len(in_shape) >= 4:
378
- out_shape4 = out_shape
379
-
380
- if expand:
381
- out_shape1 = out_shape
382
- out_shape2 = out_shape * 2
383
- out_shape3 = out_shape * 4
384
- if len(in_shape) >= 4:
385
- out_shape4 = out_shape * 8
386
-
387
- scratch.layer1_rn = nn.Conv2d(
388
- in_shape[0],
389
- out_shape1,
390
- kernel_size=3,
391
- stride=1,
392
- padding=1,
393
- bias=False,
394
- groups=groups,
395
- )
396
- scratch.layer2_rn = nn.Conv2d(
397
- in_shape[1],
398
- out_shape2,
399
- kernel_size=3,
400
- stride=1,
401
- padding=1,
402
- bias=False,
403
- groups=groups,
404
- )
405
- scratch.layer3_rn = nn.Conv2d(
406
- in_shape[2],
407
- out_shape3,
408
- kernel_size=3,
409
- stride=1,
410
- padding=1,
411
- bias=False,
412
- groups=groups,
413
- )
414
- if len(in_shape) >= 4:
415
- scratch.layer4_rn = nn.Conv2d(
416
- in_shape[3],
417
- out_shape4,
418
- kernel_size=3,
419
- stride=1,
420
- padding=1,
421
- bias=False,
422
- groups=groups,
423
- )
424
- return scratch
425
-
426
-
427
class ResidualConvUnit(nn.Module):
    """Residual unit: two (activation -> 3x3 conv) stages followed by a skip connection."""

    def __init__(self, features, activation, bn, groups=1):
        """Set up the unit.

        Args:
            features (int): number of input and output channels of both convolutions
            activation: callable applied before each convolution
            bn: stored on the instance; no normalization layers are created here
            groups (int): ``groups`` argument of both convolutions
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        def conv3x3():
            # Shape-preserving 3x3 convolution with bias.
            return nn.Conv2d(
                features,
                features,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
                groups=self.groups,
            )

        self.conv1 = conv3x3()
        self.conv2 = conv3x3()

        # Optional normalization slots; they stay disabled in this implementation.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation

    def forward(self, x):
        """Run both stages on ``x`` and add ``x`` back.

        Args:
            x (tensor): input

        Returns:
            tensor: output, same shape as the input
        """
        # NOTE(review): if ``activation`` works in place, the first call modifies ``x``
        # before it is used for the skip connection below.
        out = x
        stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        for conv, norm in stages:
            out = conv(self.activation(out))
            if norm is not None:
                out = norm(out)

        return out + x
485
-
486
-
487
class FeatureFusionBlock(nn.Module):
    """Feature fusion block: optional skip fusion, residual refinement, resize, 1x1 projection."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Set up the block.

        Args:
            features (int): number of input channels
            activation: activation callable handed to the residual units
            deconv (bool): stored on the instance; not otherwise used in this class
            bn (bool): forwarded to the residual units
            expand (bool): if true, the output projection halves the channel count
            align_corners (bool): ``align_corners`` argument used when resizing
            size: default output size for ``forward`` when no ``size`` is passed there
            has_residual (bool): whether ``forward`` expects and fuses a second (skip) input
            groups (int): ``groups`` argument of the convolutions
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand
        out_features = features // 2 if self.expand else features

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=self.groups,
        )

        # The first residual unit only exists when a skip input will be fused.
        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(
                features, activation, bn, groups=self.groups
            )

        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(
            features, activation, bn, groups=self.groups
        )

        self.size = size

    def forward(self, *xs, size=None):
        """Fuse the inputs, refine, resize and project.

        Args:
            *xs: ``xs[0]`` is the main feature map; ``xs[1]`` is the skip feature map and is
                required when the block was built with ``has_residual=True``.
            size: target spatial size; falls back to ``self.size`` and, if that is also
                None, to upsampling by a factor of 2.

        Returns:
            tensor: output

        Raises:
            IndexError: if the block has a residual branch but no skip input was passed.
        """
        output = xs[0]

        if self.has_residual:
            if len(xs) < 2:
                raise IndexError(
                    "FeatureFusionBlock with has_residual=True expects a second (skip) input"
                )
            output = output + self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        # An explicit call-time size wins over the constructor size; the default is x2.
        if size is not None:
            modifier = {"size": size}
        elif self.size is not None:
            modifier = {"size": self.size}
        else:
            modifier = {"scale_factor": 2}

        output = custom_interpolate(
            output, **modifier, mode="bilinear", align_corners=self.align_corners
        )
        output = self.out_conv(output)

        return output
566
-
567
-
568
def custom_interpolate(
    x: torch.Tensor,
    size: Optional[Tuple[int, int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Wrapper around ``nn.functional.interpolate`` that avoids INT_MAX issues on very large outputs.

    When ``size`` is omitted it is derived from the last two dimensions of ``x`` and
    ``scale_factor``. If the output (first two dimensions of ``x`` times the target size)
    would exceed the element limit, ``x`` is split along dim 0, each piece is interpolated
    separately and the results are concatenated.
    """
    if size is None:
        in_h, in_w = x.shape[-2], x.shape[-1]
        size = (int(in_h * scale_factor), int(in_w * scale_factor))

    max_elements = 1610612736
    output_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    # Common case: small enough to interpolate in a single call.
    if output_elements <= max_elements:
        return nn.functional.interpolate(
            x, size=size, mode=mode, align_corners=align_corners
        )

    num_pieces = (output_elements // max_elements) + 1
    resized_pieces = []
    for piece in torch.chunk(x, chunks=num_pieces, dim=0):
        resized_pieces.append(
            nn.functional.interpolate(
                piece, size=size, mode=mode, align_corners=align_corners
            )
        )
    return torch.cat(resized_pieces, dim=0).contiguous()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FastVGGT/vggt/heads/head_act.py DELETED
@@ -1,125 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
- import torch
9
- import torch.nn.functional as F
10
-
11
-
12
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
    """
    Apply a separate activation to each component of an encoded pose.

    The last dimension is split as ``[:3]`` (translation), ``[3:7]`` (quaternion) and
    ``[7:]`` (focal length or field of view); the activated pieces are concatenated back
    in the same order.

    Args:
        pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
        trans_act: Activation type for translation component
        quat_act: Activation type for quaternion component
        fl_act: Activation type for focal length component

    Returns:
        Activated pose parameters tensor
    """
    components = (
        (pred_pose_enc[..., :3], trans_act),
        (pred_pose_enc[..., 3:7], quat_act),
        (pred_pose_enc[..., 7:], fl_act),  # focal length, or fov
    )
    activated = [base_pose_act(values, act) for values, act in components]
    return torch.cat(activated, dim=-1)
36
-
37
-
38
def base_pose_act(pose_enc, act_type="linear"):
    """
    Apply one of the basic activation functions to pose parameters.

    Args:
        pose_enc: Tensor containing encoded pose parameters
        act_type: Activation type ("linear", "inv_log", "exp", "relu")

    Returns:
        Activated pose parameters; for "linear" the input is returned unchanged

    Raises:
        ValueError: If ``act_type`` is not one of the supported names
    """
    supported = ("linear", "inv_log", "exp", "relu")
    if act_type not in supported:
        raise ValueError(f"Unknown act_type: {act_type}")

    if act_type == "linear":
        return pose_enc
    if act_type == "exp":
        return torch.exp(pose_enc)
    if act_type == "relu":
        return F.relu(pose_enc)
    # Only "inv_log" is left.
    return inverse_log_transform(pose_enc)
59
-
60
-
61
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
    """
    Process network output to extract 3D points and confidence values.

    Args:
        out: Network output tensor (B, C, H, W); the last channel is the confidence logit,
            the remaining C - 1 channels are the point prediction
        activation: Activation type for 3D points ("norm_exp", "norm", "exp", "relu",
            "inv_log", "xy_inv_log", "sigmoid" or "linear")
        conf_activation: Activation type for confidence values ("expp1", "expp0" or "sigmoid")

    Returns:
        Tuple of (3D points tensor of shape (B, H, W, C - 1), confidence tensor of shape (B, H, W))

    Raises:
        ValueError: If ``activation`` or ``conf_activation`` is not a supported name
    """
    # Move channels from dim 1 to the last dimension => (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)

    # Split into xyz (first C-1 channels) and confidence (last channel)
    xyz = fmap[:, :, :, :-1]
    conf = fmap[:, :, :, -1]

    if activation == "norm_exp":
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        xyz_normed = xyz / d
        pts3d = xyz_normed * torch.expm1(d)
    elif activation == "norm":
        # Clamp the norm, as in "norm_exp", so an all-zero vector yields zeros and not NaN.
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif activation == "xy_inv_log":
        # Requires exactly three point channels: (x, y) are scaled by the activated z.
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    if conf_activation == "expp1":
        conf_out = 1 + conf.exp()
    elif conf_activation == "expp0":
        conf_out = conf.exp()
    elif conf_activation == "sigmoid":
        conf_out = torch.sigmoid(conf)
    else:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")

    return pts3d, conf_out
113
-
114
-
115
def inverse_log_transform(y):
    """
    Apply the sign-preserving inverse log transform: sign(y) * (exp(|y|) - 1).

    Args:
        y: Input tensor

    Returns:
        Transformed tensor with the same shape as ``y``
    """
    magnitude = torch.expm1(torch.abs(y))
    return torch.sign(y) * magnitude