Upload 99 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- FastVGGT/.gitignore +160 -0
- FastVGGT/.vscode/launch.json +85 -0
- FastVGGT/README.md +163 -0
- FastVGGT/assets/attn_map.png +3 -0
- FastVGGT/assets/autolab_logo.png +3 -0
- FastVGGT/assets/maclab_logo.png +0 -0
- FastVGGT/assets/main.png +3 -0
- FastVGGT/assets/vs.png +3 -0
- FastVGGT/eval/__pycache__/base.cpython-310.pyc +0 -0
- FastVGGT/eval/__pycache__/criterion.cpython-310.pyc +0 -0
- FastVGGT/eval/__pycache__/data.cpython-310.pyc +0 -0
- FastVGGT/eval/__pycache__/data.cpython-37.pyc +0 -0
- FastVGGT/eval/__pycache__/utils.cpython-310.pyc +0 -0
- FastVGGT/eval/__pycache__/utils.cpython-37.pyc +0 -0
- FastVGGT/eval/base.py +273 -0
- FastVGGT/eval/criterion.py +534 -0
- FastVGGT/eval/data.py +338 -0
- FastVGGT/eval/dataset_utils/__init__.py +1 -0
- FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc +0 -0
- FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc +0 -0
- FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc +0 -0
- FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc +0 -0
- FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc +0 -0
- FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc +0 -0
- FastVGGT/eval/dataset_utils/corr.py +234 -0
- FastVGGT/eval/dataset_utils/cropping.py +140 -0
- FastVGGT/eval/dataset_utils/transforms.py +78 -0
- FastVGGT/eval/eval_7andN.py +497 -0
- FastVGGT/eval/eval_custom.py +467 -0
- FastVGGT/eval/eval_scannet.py +208 -0
- FastVGGT/eval/utils.py +142 -0
- FastVGGT/merging/__init__.py +3 -0
- FastVGGT/merging/__pycache__/__init__.cpython-310.pyc +0 -0
- FastVGGT/merging/__pycache__/merge.cpython-310.pyc +0 -0
- FastVGGT/merging/merge.py +370 -0
- FastVGGT/requirements.txt +15 -0
- FastVGGT/vggt/__init__.py +5 -0
- FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc +0 -0
- FastVGGT/vggt/dependency/__init__.py +5 -0
- FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc +0 -0
- FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc +0 -0
- FastVGGT/vggt/dependency/distortion.py +54 -0
- FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc +0 -0
- FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc +0 -0
- FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc +0 -0
- FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc +0 -0
- FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc +0 -0
- FastVGGT/vggt/heads/camera_head.py +149 -0
- FastVGGT/vggt/heads/dpt_head.py +598 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
FastVGGT/assets/attn_map.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
FastVGGT/assets/autolab_logo.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
FastVGGT/assets/main.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
FastVGGT/assets/vs.png filter=lfs diff=lfs merge=lfs -text
|
FastVGGT/.gitignore
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.hydra/
|
| 2 |
+
output/
|
| 3 |
+
ckpt/
|
| 4 |
+
.vscode/
|
| 5 |
+
dependency/
|
| 6 |
+
# Byte-compiled / optimized / DLL files
|
| 7 |
+
__pycache__/
|
| 8 |
+
**/__pycache__/
|
| 9 |
+
*.py[cod]
|
| 10 |
+
*$py.class
|
| 11 |
+
test_logs/
|
| 12 |
+
quick_start_logs/
|
| 13 |
+
logs/
|
| 14 |
+
*.pth
|
| 15 |
+
/data/
|
| 16 |
+
*.png
|
| 17 |
+
eval_results/
|
| 18 |
+
.vscode/
|
| 19 |
+
.curosr/
|
| 20 |
+
|
| 21 |
+
# C extensions
|
| 22 |
+
*.so
|
| 23 |
+
LightGlue/
|
| 24 |
+
# Distribution / packaging
|
| 25 |
+
.Python
|
| 26 |
+
build/
|
| 27 |
+
develop-eggs/
|
| 28 |
+
dist/
|
| 29 |
+
downloads/
|
| 30 |
+
eggs/
|
| 31 |
+
.eggs/
|
| 32 |
+
lib/
|
| 33 |
+
lib64/
|
| 34 |
+
parts/
|
| 35 |
+
sdist/
|
| 36 |
+
var/
|
| 37 |
+
wheels/
|
| 38 |
+
pip-wheel-metadata/
|
| 39 |
+
share/python-wheels/
|
| 40 |
+
*.egg-info/
|
| 41 |
+
.installed.cfg
|
| 42 |
+
*.egg
|
| 43 |
+
MANIFEST
|
| 44 |
+
|
| 45 |
+
# PyInstaller
|
| 46 |
+
# Usually these files are written by a python script from a template
|
| 47 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 48 |
+
*.manifest
|
| 49 |
+
*.spec
|
| 50 |
+
|
| 51 |
+
# Installer logs
|
| 52 |
+
pip-log.txt
|
| 53 |
+
pip-delete-this-directory.txt
|
| 54 |
+
|
| 55 |
+
# Unit test / coverage reports
|
| 56 |
+
htmlcov/
|
| 57 |
+
.tox/
|
| 58 |
+
.nox/
|
| 59 |
+
.coverage
|
| 60 |
+
.coverage.*
|
| 61 |
+
.cache
|
| 62 |
+
nosetests.xml
|
| 63 |
+
coverage.xml
|
| 64 |
+
*.cover
|
| 65 |
+
*.py,cover
|
| 66 |
+
.hypothesis/
|
| 67 |
+
.pytest_cache/
|
| 68 |
+
cover/
|
| 69 |
+
|
| 70 |
+
# Translations
|
| 71 |
+
*.mo
|
| 72 |
+
*.pot
|
| 73 |
+
|
| 74 |
+
# Django stuff:
|
| 75 |
+
*.log
|
| 76 |
+
local_settings.py
|
| 77 |
+
db.sqlite3
|
| 78 |
+
db.sqlite3-journal
|
| 79 |
+
|
| 80 |
+
# Flask stuff:
|
| 81 |
+
instance/
|
| 82 |
+
.webassets-cache
|
| 83 |
+
|
| 84 |
+
# Scrapy stuff:
|
| 85 |
+
.scrapy
|
| 86 |
+
|
| 87 |
+
# Sphinx documentation
|
| 88 |
+
docs/_build/
|
| 89 |
+
|
| 90 |
+
# PyBuilder
|
| 91 |
+
target/
|
| 92 |
+
|
| 93 |
+
# Jupyter Notebook
|
| 94 |
+
.ipynb_checkpoints
|
| 95 |
+
|
| 96 |
+
# IPython
|
| 97 |
+
profile_default/
|
| 98 |
+
ipython_config.py
|
| 99 |
+
|
| 100 |
+
# pyenv
|
| 101 |
+
.python-version
|
| 102 |
+
|
| 103 |
+
# pipenv
|
| 104 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 105 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 106 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 107 |
+
# install all needed dependencies.
|
| 108 |
+
#Pipfile.lock
|
| 109 |
+
|
| 110 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 111 |
+
__pypackages__/
|
| 112 |
+
|
| 113 |
+
# Celery stuff
|
| 114 |
+
celerybeat-schedule
|
| 115 |
+
celerybeat.pid
|
| 116 |
+
|
| 117 |
+
# SageMath parsed files
|
| 118 |
+
*.sage.py
|
| 119 |
+
|
| 120 |
+
# Environments
|
| 121 |
+
.env
|
| 122 |
+
.venv
|
| 123 |
+
env/
|
| 124 |
+
venv/
|
| 125 |
+
ENV/
|
| 126 |
+
env.bak/
|
| 127 |
+
venv.bak/
|
| 128 |
+
|
| 129 |
+
# Spyder project settings
|
| 130 |
+
.spyderproject
|
| 131 |
+
.spyproject
|
| 132 |
+
|
| 133 |
+
# Rope project settings
|
| 134 |
+
.ropeproject
|
| 135 |
+
|
| 136 |
+
# mkdocs documentation
|
| 137 |
+
/site
|
| 138 |
+
|
| 139 |
+
# mypy
|
| 140 |
+
.mypy_cache/
|
| 141 |
+
.dmypy.json
|
| 142 |
+
dmypy.json
|
| 143 |
+
|
| 144 |
+
# Pyre type checker
|
| 145 |
+
.pyre/
|
| 146 |
+
|
| 147 |
+
# pytype static type analyzer
|
| 148 |
+
.pytype/
|
| 149 |
+
|
| 150 |
+
# Profiling data
|
| 151 |
+
.prof
|
| 152 |
+
|
| 153 |
+
# Folder specific to your needs
|
| 154 |
+
**/tmp/
|
| 155 |
+
**/outputs/skyseg.onnx
|
| 156 |
+
skyseg.onnx
|
| 157 |
+
|
| 158 |
+
# pixi environments
|
| 159 |
+
.pixi
|
| 160 |
+
*.egg-info
|
FastVGGT/.vscode/launch.json
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
// Use IntelliSense to learn about possible attributes.
|
| 3 |
+
// Hover to view descriptions of existing attributes.
|
| 4 |
+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
| 5 |
+
"version": "0.2.0",
|
| 6 |
+
"configurations": [
|
| 7 |
+
|
| 8 |
+
{
|
| 9 |
+
"name": "launch",
|
| 10 |
+
"type": "debugpy",
|
| 11 |
+
"request": "launch",
|
| 12 |
+
"program": "/home/sy/code/vggt_0625/training/launch.py",
|
| 13 |
+
"console": "integratedTerminal",
|
| 14 |
+
"args": "${command:pickArgs}",
|
| 15 |
+
"env": {
|
| 16 |
+
"CUDA_VISIBLE_DEVICES": "3",
|
| 17 |
+
},
|
| 18 |
+
"cwd": "/home/sy/code/vggt_0625/training",
|
| 19 |
+
"justMyCode": true,
|
| 20 |
+
"python": "/home/sy/anaconda3/envs/vggt/bin/python"
|
| 21 |
+
}
|
| 22 |
+
,{
|
| 23 |
+
"name": "train_scannet",
|
| 24 |
+
"type": "debugpy",
|
| 25 |
+
"request": "launch",
|
| 26 |
+
"program": "/home/sy/code/vggt_0625/training/launch_scannet.py",
|
| 27 |
+
"console": "integratedTerminal",
|
| 28 |
+
"args": [
|
| 29 |
+
// "--config_name", "scannet",
|
| 30 |
+
// "--exp_name", "scannet_exp001",
|
| 31 |
+
// "--resume_checkpoint_path", "/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt"
|
| 32 |
+
],
|
| 33 |
+
"env": {
|
| 34 |
+
"CUDA_VISIBLE_DEVICES": "7",
|
| 35 |
+
"WORLD_SIZE": "1",
|
| 36 |
+
"RANK": "0",
|
| 37 |
+
"MASTER_ADDR": "localhost",
|
| 38 |
+
"MASTER_PORT": "12345"
|
| 39 |
+
},
|
| 40 |
+
"cwd": "/home/sy/code/vggt_0625/training",
|
| 41 |
+
"justMyCode": true,
|
| 42 |
+
"python": "/home/sy/anaconda3/envs/vggt/bin/python"
|
| 43 |
+
}
|
| 44 |
+
,{
|
| 45 |
+
"name": "eval_scannet",
|
| 46 |
+
"type": "debugpy",
|
| 47 |
+
"request": "launch",
|
| 48 |
+
"program": "/home/sy/code/FastVGGT/eval/eval_scannet.py",
|
| 49 |
+
"console": "integratedTerminal",
|
| 50 |
+
"args": [
|
| 51 |
+
"--data_dir","/data/sy/scannetv2/process_scannet",
|
| 52 |
+
"--gt_ply_dir","/data/sy/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
|
| 53 |
+
"--output_path", "/home/sy/code/FastVGGT/eval_results",
|
| 54 |
+
"--merging", "0",
|
| 55 |
+
"--ckpt_path","/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt",
|
| 56 |
+
"--vis_attn_map"
|
| 57 |
+
],
|
| 58 |
+
"env": {
|
| 59 |
+
"CUDA_VISIBLE_DEVICES": "2"
|
| 60 |
+
},
|
| 61 |
+
"justMyCode": true,
|
| 62 |
+
"python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"name": "eval_cd",
|
| 66 |
+
"type": "debugpy",
|
| 67 |
+
"request": "launch",
|
| 68 |
+
"program": "/home/sy/code/FastVGGT/eval/eval_custom.py",
|
| 69 |
+
"console": "integratedTerminal",
|
| 70 |
+
"args": [
|
| 71 |
+
"--merging", "0",
|
| 72 |
+
// "--kf","10",
|
| 73 |
+
// "--output_dir","/home/sy/code/vggt_0625/eval_results_cd",
|
| 74 |
+
"--data_path","/data/sy/segment-102751/",
|
| 75 |
+
"--vis_attn_map"
|
| 76 |
+
],
|
| 77 |
+
"env": {
|
| 78 |
+
"CUDA_VISIBLE_DEVICES": "3"
|
| 79 |
+
},
|
| 80 |
+
"justMyCode": true,
|
| 81 |
+
// "python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
]
|
| 85 |
+
}
|
FastVGGT/README.md
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<h2>⚡️ FastVGGT: Training-Free Acceleration of Visual Geometry Transformer</h2>
|
| 3 |
+
|
| 4 |
+
<p align="center">
|
| 5 |
+
<a href="https://arxiv.org/abs/2509.02560"><img src="https://img.shields.io/badge/arXiv-FastVGGT-red?logo=arxiv" alt="Paper PDF"></a>
|
| 6 |
+
<a href="https://mystorm16.github.io/fastvggt/"><img src="https://img.shields.io/badge/Project_Page-FastVGGT-yellow" alt="Project Page"></a>
|
| 7 |
+
</p>
|
| 8 |
+
|
| 9 |
+
<img src="assets/maclab_logo.png" alt="Maclab Logo" width="110" style="margin-right: 40px;">
|
| 10 |
+
<img src="assets/autolab_logo.png" alt="Autolab Logo" width="110">
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
**[Media Analytics & Computing Laboratory](https://mac.xmu.edu.cn/)**; **[AUTOLAB](https://zhipengzhang.cn/)**
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
[You Shen](https://mystorm16.github.io/), [Zhipeng Zhang](https://zhipengzhang.cn/), [Yansong Qu](https://quyans.github.io/), [Liujuan Cao](https://mac.xmu.edu.cn/ljcao/)
|
| 17 |
+
</div>
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
## 📰 News
|
| 21 |
+
- [Sep 8, 2025] Added custom dataset evaluation.
|
| 22 |
+
- [Sep 3, 2025] Paper release.
|
| 23 |
+
- [Sep 2, 2025] Code release.
|
| 24 |
+
|
| 25 |
+
## 🔭 Overview
|
| 26 |
+
|
| 27 |
+
FastVGGT observes **strong similarity** in attention maps and leverages it to design a training-free acceleration method for long-sequence 3D reconstruction, **achieving up to 4× faster inference without sacrificing accuracy.**
|
| 28 |
+
|
| 29 |
+
<img src="assets/main.png" alt="Autolab Logo" width="">
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
## ⚙️ Environment Setup
|
| 33 |
+
First, create a virtual environment using Conda, clone this repository to your local machine, and install the required dependencies.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
conda create -n fastvggt python=3.10
|
| 38 |
+
conda activate fastvggt
|
| 39 |
+
git clone git@github.com:mystorm16/FastVGGT.git
|
| 40 |
+
cd FastVGGT
|
| 41 |
+
pip install -r requirements.txt
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
Next, prepare the ScanNet dataset: http://www.scan-net.org/ScanNet/
|
| 45 |
+
|
| 46 |
+
Then, download the VGGT checkpoint (we use the checkpoint link provided in https://github.com/facebookresearch/vggt/tree/evaluation/evaluation):
|
| 47 |
+
```bash
|
| 48 |
+
wget https://huggingface.co/facebook/VGGT_tracker_fixed/resolve/main/model_tracker_fixed_e20.pt
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
Finally, configure the dataset path and VGGT checkpoint path. For example:
|
| 52 |
+
```bash
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--data_dir", type=Path, default="/data/scannetv2/process_scannet"
|
| 55 |
+
)
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
"--gt_ply_dir",
|
| 58 |
+
type=Path,
|
| 59 |
+
default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
|
| 60 |
+
)
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--ckpt_path",
|
| 63 |
+
type=str,
|
| 64 |
+
default="./ckpt/model_tracker_fixed_e20.pt",
|
| 65 |
+
)
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
## 💎 Observation
|
| 70 |
+
|
| 71 |
+
Note: A large number of input_frames may significantly slow down saving the visualization results. Please try using a smaller number first.
|
| 72 |
+
```bash
|
| 73 |
+
python eval/eval_scannet.py --input_frame 30 --vis_attn_map --merging 0
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
We observe that many token-level attention maps are highly similar in each block, motivating our optimization of the Global Attention module.
|
| 77 |
+
|
| 78 |
+
<img src="assets/attn_map.png" alt="Autolab Logo" width="">
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
## 🏀 Evaluation
|
| 83 |
+
### Custom Dataset
|
| 84 |
+
Please organize the data according to the following directory:
|
| 85 |
+
```
|
| 86 |
+
<data_path>/
|
| 87 |
+
├── images/
|
| 88 |
+
│ ├── 000000.jpg
|
| 89 |
+
│ ├── 000001.jpg
|
| 90 |
+
│ └── ...
|
| 91 |
+
├── pose/ # Optional: Camera poses
|
| 92 |
+
│ ├── 000000.txt
|
| 93 |
+
│ ├── 000001.txt
|
| 94 |
+
│ └── ...
|
| 95 |
+
└── gt_ply/ # Optional: GT point cloud
|
| 96 |
+
└── scene_xxx.ply
|
| 97 |
+
```
|
| 98 |
+
- Required: `images/`
|
| 99 |
+
- Additionally required when `--enable_evaluation` is enabled: `pose/` and `gt_ply/`
|
| 100 |
+
|
| 101 |
+
Inference only:
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
python eval/eval_custom.py \
|
| 105 |
+
--data_path /path/to/your_dataset \
|
| 106 |
+
--output_path ./eval_results_custom \
|
| 107 |
+
--plot
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
Inference + Evaluation (requires `pose/` and `gt_ply/`):
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
python eval/eval_custom.py \
|
| 114 |
+
--data_path /path/to/your_dataset \
|
| 115 |
+
--enable_evaluation \
|
| 116 |
+
--output_path ./eval_results_custom \
|
| 117 |
+
--plot
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
### ScanNet
|
| 121 |
+
Evaluate FastVGGT on the ScanNet dataset with 1,000 input images. The **--merging** parameter specifies the block index at which the merging strategy is applied:
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
python eval/eval_scannet.py --input_frame 1000 --merging 0
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
Evaluate Baseline VGGT on the ScanNet dataset with 1,000 input images:
|
| 128 |
+
```bash
|
| 129 |
+
python eval/eval_scannet.py --input_frame 1000
|
| 130 |
+
```
|
| 131 |
+
<img src="assets/vs.png" alt="Autolab Logo" width="">
|
| 132 |
+
|
| 133 |
+
### 7 Scenes & NRGBD
|
| 134 |
+
Evaluate across two datasets, sampling keyframes every 10 frames:
|
| 135 |
+
```bash
|
| 136 |
+
python eval/eval_7andN.py --kf 10
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
## 🍺 Acknowledgements
|
| 140 |
+
|
| 141 |
+
- Thanks to these great repositories: [VGGT](https://github.com/facebookresearch/vggt), [Dust3r](https://github.com/naver/dust3r), [Fast3R](https://github.com/facebookresearch/fast3r), [CUT3R](https://github.com/CUT3R/CUT3R), [MV-DUSt3R+](https://github.com/facebookresearch/mvdust3r), [StreamVGGT](https://github.com/wzzheng/StreamVGGT), [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long), [ToMeSD](https://github.com/dbolya/tomesd) and many other inspiring works in the community.
|
| 142 |
+
|
| 143 |
+
- Special thanks to [Jianyuan Wang](https://jytime.github.io/) for his valuable discussions and suggestions on this work.
|
| 144 |
+
|
| 145 |
+
<!-- ## ✍️ Checklist
|
| 146 |
+
|
| 147 |
+
- [ ] Release the evaluation code on 7 Scenes / NRGBD -->
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
## ⚖️ License
|
| 151 |
+
See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available.
|
| 152 |
+
|
| 153 |
+
## Citation
|
| 154 |
+
|
| 155 |
+
If you find this project helpful, please consider citing the following paper:
|
| 156 |
+
```
|
| 157 |
+
@article{shen2025fastvggt,
|
| 158 |
+
title={FastVGGT: Training-Free Acceleration of Visual Geometry Transformer},
|
| 159 |
+
author={Shen, You and Zhang, Zhipeng and Qu, Yansong and Cao, Liujuan},
|
| 160 |
+
journal={arXiv preprint arXiv:2509.02560},
|
| 161 |
+
year={2025}
|
| 162 |
+
}
|
| 163 |
+
```
|
FastVGGT/assets/attn_map.png
ADDED
|
Git LFS Details
|
FastVGGT/assets/autolab_logo.png
ADDED
|
Git LFS Details
|
FastVGGT/assets/maclab_logo.png
ADDED
|
FastVGGT/assets/main.png
ADDED
|
Git LFS Details
|
FastVGGT/assets/vs.png
ADDED
|
Git LFS Details
|
FastVGGT/eval/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (6.92 kB). View file
|
|
|
FastVGGT/eval/__pycache__/criterion.cpython-310.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
FastVGGT/eval/__pycache__/data.cpython-310.pyc
ADDED
|
Binary file (7.78 kB). View file
|
|
|
FastVGGT/eval/__pycache__/data.cpython-37.pyc
ADDED
|
Binary file (8.03 kB). View file
|
|
|
FastVGGT/eval/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (3.99 kB). View file
|
|
|
FastVGGT/eval/__pycache__/utils.cpython-37.pyc
ADDED
|
Binary file (4.32 kB). View file
|
|
|
FastVGGT/eval/base.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# base class for implementing datasets
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import PIL
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from dataset_utils.transforms import ImgNorm
|
| 12 |
+
import dataset_utils.cropping as cropping
|
| 13 |
+
from utils import depthmap_to_absolute_camera_coordinates
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseStereoViewDataset:
|
| 17 |
+
"""Define all basic options.
|
| 18 |
+
|
| 19 |
+
Usage:
|
| 20 |
+
class MyDataset (BaseStereoViewDataset):
|
| 21 |
+
def _get_views(self, idx, rng):
|
| 22 |
+
# overload here
|
| 23 |
+
views = []
|
| 24 |
+
views.append(dict(img=, ...))
|
| 25 |
+
return views
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
*, # only keyword arguments
|
| 31 |
+
split=None,
|
| 32 |
+
resolution=None, # square_size or (width, height) or list of [(width,height), ...]
|
| 33 |
+
transform=ImgNorm,
|
| 34 |
+
aug_crop=False,
|
| 35 |
+
seed=None,
|
| 36 |
+
):
|
| 37 |
+
self.num_views = 2
|
| 38 |
+
self.split = split
|
| 39 |
+
self._set_resolutions(resolution)
|
| 40 |
+
|
| 41 |
+
self.transform = transform
|
| 42 |
+
if isinstance(transform, str):
|
| 43 |
+
transform = eval(transform)
|
| 44 |
+
|
| 45 |
+
self.aug_crop = aug_crop
|
| 46 |
+
self.seed = seed
|
| 47 |
+
|
| 48 |
+
def __len__(self):
|
| 49 |
+
return len(self.scenes)
|
| 50 |
+
|
| 51 |
+
def get_stats(self):
|
| 52 |
+
return f"{len(self)} pairs"
|
| 53 |
+
|
| 54 |
+
def __repr__(self):
|
| 55 |
+
resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
|
| 56 |
+
return (
|
| 57 |
+
f"""{type(self).__name__}({self.get_stats()},
|
| 58 |
+
{self.split=},
|
| 59 |
+
{self.seed=},
|
| 60 |
+
resolutions={resolutions_str},
|
| 61 |
+
{self.transform=})""".replace(
|
| 62 |
+
"self.", ""
|
| 63 |
+
)
|
| 64 |
+
.replace("\n", "")
|
| 65 |
+
.replace(" ", "")
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def _get_views(self, idx, resolution, rng):
|
| 69 |
+
raise NotImplementedError()
|
| 70 |
+
|
| 71 |
+
def __getitem__(self, idx):
|
| 72 |
+
if isinstance(idx, tuple):
|
| 73 |
+
# the idx is specifying the aspect-ratio
|
| 74 |
+
idx, ar_idx = idx
|
| 75 |
+
else:
|
| 76 |
+
assert len(self._resolutions) == 1
|
| 77 |
+
ar_idx = 0
|
| 78 |
+
|
| 79 |
+
# set-up the rng
|
| 80 |
+
if self.seed: # reseed for each __getitem__
|
| 81 |
+
self._rng = np.random.default_rng(seed=self.seed + idx)
|
| 82 |
+
elif not hasattr(self, "_rng"):
|
| 83 |
+
seed = torch.initial_seed() # this is different for each dataloader process
|
| 84 |
+
self._rng = np.random.default_rng(seed=seed)
|
| 85 |
+
|
| 86 |
+
# over-loaded code
|
| 87 |
+
resolution = self._resolutions[
|
| 88 |
+
ar_idx
|
| 89 |
+
] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
|
| 90 |
+
views = self._get_views(idx, resolution, self._rng)
|
| 91 |
+
|
| 92 |
+
# check data-types
|
| 93 |
+
for v, view in enumerate(views):
|
| 94 |
+
assert (
|
| 95 |
+
"pts3d" not in view
|
| 96 |
+
), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
|
| 97 |
+
view["idx"] = v
|
| 98 |
+
|
| 99 |
+
# encode the image
|
| 100 |
+
width, height = view["img"].size
|
| 101 |
+
view["true_shape"] = np.int32((height, width))
|
| 102 |
+
view["img"] = self.transform(view["img"])
|
| 103 |
+
|
| 104 |
+
assert "camera_intrinsics" in view
|
| 105 |
+
if "camera_pose" not in view:
|
| 106 |
+
view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
|
| 107 |
+
else:
|
| 108 |
+
assert np.isfinite(
|
| 109 |
+
view["camera_pose"]
|
| 110 |
+
).all(), f"NaN in camera pose for view {view_name(view)}"
|
| 111 |
+
assert "pts3d" not in view
|
| 112 |
+
assert "valid_mask" not in view
|
| 113 |
+
assert np.isfinite(
|
| 114 |
+
view["depthmap"]
|
| 115 |
+
).all(), f"NaN in depthmap for view {view_name(view)}"
|
| 116 |
+
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
|
| 117 |
+
|
| 118 |
+
view["pts3d"] = pts3d
|
| 119 |
+
view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
|
| 120 |
+
|
| 121 |
+
# check all datatypes
|
| 122 |
+
for key, val in view.items():
|
| 123 |
+
res, err_msg = is_good_type(key, val)
|
| 124 |
+
assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
|
| 125 |
+
K = view["camera_intrinsics"]
|
| 126 |
+
view["img_mask"] = True
|
| 127 |
+
view["ray_mask"] = False
|
| 128 |
+
view["ray_map"] = torch.full(
|
| 129 |
+
(6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
|
| 130 |
+
)
|
| 131 |
+
view["update"] = True
|
| 132 |
+
view["reset"] = False
|
| 133 |
+
|
| 134 |
+
# last thing done!
|
| 135 |
+
for view in views:
|
| 136 |
+
# transpose to make sure all views are the same size
|
| 137 |
+
transpose_to_landscape(view)
|
| 138 |
+
# this allows to check whether the RNG is is the same state each time
|
| 139 |
+
view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
|
| 140 |
+
return views
|
| 141 |
+
|
| 142 |
+
def _set_resolutions(self, resolutions):
|
| 143 |
+
"""Set the resolution(s) of the dataset.
|
| 144 |
+
Params:
|
| 145 |
+
- resolutions: int or tuple or list of tuples
|
| 146 |
+
"""
|
| 147 |
+
assert resolutions is not None, "undefined resolution"
|
| 148 |
+
|
| 149 |
+
if not isinstance(resolutions, list):
|
| 150 |
+
resolutions = [resolutions]
|
| 151 |
+
|
| 152 |
+
self._resolutions = []
|
| 153 |
+
for resolution in resolutions:
|
| 154 |
+
if isinstance(resolution, int):
|
| 155 |
+
width = height = resolution
|
| 156 |
+
else:
|
| 157 |
+
width, height = resolution
|
| 158 |
+
assert isinstance(
|
| 159 |
+
width, int
|
| 160 |
+
), f"Bad type for {width=} {type(width)=}, should be int"
|
| 161 |
+
assert isinstance(
|
| 162 |
+
height, int
|
| 163 |
+
), f"Bad type for {height=} {type(height)=}, should be int"
|
| 164 |
+
assert width >= height
|
| 165 |
+
self._resolutions.append((width, height))
|
| 166 |
+
|
| 167 |
+
def _crop_resize_if_necessary(
|
| 168 |
+
self, image, depthmap, intrinsics, resolution, rng=None, info=None
|
| 169 |
+
):
|
| 170 |
+
"""This function:
|
| 171 |
+
- first downsizes the image with LANCZOS inteprolation,
|
| 172 |
+
which is better than bilinear interpolation in
|
| 173 |
+
"""
|
| 174 |
+
if not isinstance(image, PIL.Image.Image):
|
| 175 |
+
image = PIL.Image.fromarray(image)
|
| 176 |
+
|
| 177 |
+
# downscale with lanczos interpolation so that image.size == resolution
|
| 178 |
+
# cropping centered on the principal point
|
| 179 |
+
W, H = image.size
|
| 180 |
+
cx, cy = intrinsics[:2, 2].round().astype(int)
|
| 181 |
+
|
| 182 |
+
# calculate min distance to margin
|
| 183 |
+
min_margin_x = min(cx, W - cx)
|
| 184 |
+
min_margin_y = min(cy, H - cy)
|
| 185 |
+
assert min_margin_x > W / 5, f"Bad principal point in view={info}"
|
| 186 |
+
assert min_margin_y > H / 5, f"Bad principal point in view={info}"
|
| 187 |
+
|
| 188 |
+
## Center crop
|
| 189 |
+
# Crop on the principal point, make it always centered
|
| 190 |
+
# the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
|
| 191 |
+
l, t = cx - min_margin_x, cy - min_margin_y
|
| 192 |
+
r, b = cx + min_margin_x, cy + min_margin_y
|
| 193 |
+
crop_bbox = (l, t, r, b)
|
| 194 |
+
|
| 195 |
+
image, depthmap, intrinsics = cropping.crop_image_depthmap(
|
| 196 |
+
image, depthmap, intrinsics, crop_bbox
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# # transpose the resolution if necessary
|
| 200 |
+
W, H = image.size # new size
|
| 201 |
+
assert resolution[0] >= resolution[1]
|
| 202 |
+
if H > 1.1 * W:
|
| 203 |
+
# image is portrait mode
|
| 204 |
+
resolution = resolution[::-1]
|
| 205 |
+
elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
|
| 206 |
+
# image is square, so we chose (portrait, landscape) randomly
|
| 207 |
+
if rng.integers(2):
|
| 208 |
+
resolution = resolution[::-1]
|
| 209 |
+
|
| 210 |
+
# high-quality Lanczos down-scaling
|
| 211 |
+
target_resolution = np.array(resolution)
|
| 212 |
+
# # if self.aug_crop > 1:
|
| 213 |
+
# # target_resolution += rng.integers(0, self.aug_crop)
|
| 214 |
+
# if resolution != (224, 224):
|
| 215 |
+
# halfw, halfh = ((2*(W//2))//16)*8, ((2*(H//2))//16)*8
|
| 216 |
+
# ## Recale with max factor, so one of width or height might be larger than target_resolution
|
| 217 |
+
# image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, (2*halfw, 2*halfh))
|
| 218 |
+
# else:
|
| 219 |
+
image, depthmap, intrinsics = cropping.rescale_image_depthmap(
|
| 220 |
+
image, depthmap, intrinsics, target_resolution
|
| 221 |
+
)
|
| 222 |
+
# actual cropping (if necessary) with bilinear interpolation
|
| 223 |
+
# if resolution == (224, 224):
|
| 224 |
+
intrinsics2 = cropping.camera_matrix_of_crop(
|
| 225 |
+
intrinsics, image.size, resolution, offset_factor=0.5
|
| 226 |
+
)
|
| 227 |
+
crop_bbox = cropping.bbox_from_intrinsics_in_out(
|
| 228 |
+
intrinsics, intrinsics2, resolution
|
| 229 |
+
)
|
| 230 |
+
image, depthmap, intrinsics = cropping.crop_image_depthmap(
|
| 231 |
+
image, depthmap, intrinsics, crop_bbox
|
| 232 |
+
)
|
| 233 |
+
return image, depthmap, intrinsics
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def is_good_type(key, v):
    """Check whether *v* has an allowed type/dtype for a view entry.

    Returns a tuple ``(is_good, err_msg)`` where ``err_msg`` is None on
    success. Plain python scalars/containers are always accepted; array-like
    values must carry one of the whitelisted dtypes. *key* is unused but kept
    for interface compatibility with callers.
    """
    if isinstance(v, (str, int, tuple)):
        return True, None
    allowed = (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8)
    if v.dtype in allowed:
        return True, None
    return False, f"bad {v.dtype=}"
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def view_name(view, batch_index=None):
    """Build a human-readable identifier ``dataset/label/instance`` for a view.

    When *batch_index* selects a single element (anything other than None or
    the full slice), each field is indexed before formatting.
    """
    pick_one = batch_index not in (None, slice(None))

    def sel(field):
        value = view[field]
        return value[batch_index] if pick_one else value

    return f"{sel('dataset')}/{sel('label')}/{sel('instance')}"
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def transpose_to_landscape(view):
    """Transpose all image-shaped entries of *view* in place so the view
    ends up in landscape orientation (width >= height).

    No-op when the view is already landscape. Each entry's shape is checked
    right before it is swapped, so a malformed view fails loudly.
    """
    height, width = view["true_shape"]
    if width >= height:
        return  # already landscape, nothing to do

    def _swap(key, expected_shape, axes):
        # validate then transpose a single entry in place
        assert view[key].shape == expected_shape
        view[key] = view[key].swapaxes(*axes)

    _swap("img", (3, height, width), (1, 2))
    _swap("valid_mask", (height, width), (0, 1))
    _swap("depthmap", (height, width), (0, 1))
    _swap("pts3d", (height, width, 3), (0, 1))

    # swapping x/y pixels means swapping the first two rows of K
    view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]
|
FastVGGT/eval/criterion.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from copy import copy, deepcopy
|
| 4 |
+
|
| 5 |
+
from eval.dataset_utils.corr import geotrf, inv
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def invalid_to_nans(arr, valid_mask, ndim=999):
    """Return a copy of *arr* whose entries outside *valid_mask* are NaN.

    When the tensor has more than *ndim* dimensions, the extra leading
    spatial dimensions are flattened together (the last, channel, dimension
    is kept intact). The input tensor itself is never modified.
    """
    if valid_mask is not None:
        arr = arr.clone()
        arr[~valid_mask] = float("nan")
    extra = arr.ndim - ndim
    if extra > 0:
        arr = arr.flatten(-2 - extra, -2)
    return arr
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def invalid_to_zeros(arr, valid_mask, ndim=999):
    """Zero out entries of *arr* outside *valid_mask* (on a copy).

    Returns ``(arr, nnz)`` where *nnz* is the number of valid points per
    batch element: a tensor when a mask is given, otherwise a plain int
    equal to the number of points per image (0 for an empty batch).
    """
    if valid_mask is None:
        nnz = arr.numel() // len(arr) if len(arr) else 0  # points per image
    else:
        arr = arr.clone()
        arr[~valid_mask] = 0
        nnz = valid_mask.view(len(valid_mask), -1).sum(1)
    extra = arr.ndim - ndim
    if extra > 0:
        arr = arr.flatten(-2 - extra, -2)
    return arr, nnz
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class BaseCriterion(nn.Module):
    """Root class for all pixel criteria; only stores the reduction mode."""

    def __init__(self, reduction="mean"):
        super().__init__()
        # one of "none" / "sum" / "mean", interpreted by subclasses
        self.reduction = reduction
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Criterion(nn.Module):
    """Wrapper that owns a (copied) pixel-level BaseCriterion."""

    def __init__(self, criterion=None):
        super().__init__()
        assert isinstance(
            criterion, BaseCriterion
        ), f"{criterion} is not a proper criterion!"
        # keep our own shallow copy so reduction changes stay local
        self.criterion = copy(criterion)

    def get_name(self):
        """Human-readable name, e.g. ``Regr3D(L21Loss())``."""
        return f"{type(self).__name__}({self.criterion})"

    def with_reduction(self, mode="none"):
        """Deep-copy self (and any chained losses) with *mode* reduction."""
        head = deepcopy(self)
        node = head
        while node is not None:
            assert isinstance(node, Criterion)
            # make the pixel criterion return per-sample/per-pixel values
            node.criterion.reduction = mode
            node = node._loss2  # follow the MultiLoss chain, if any
        return head
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class MultiLoss(nn.Module):
    """Composable loss terms that also track their individual values.

    Example::

        loss = MyLoss1() + 0.1*MyLoss2()

    Subclasses must override :meth:`get_name` and :meth:`compute_loss`.
    Terms are chained through the private ``_loss2`` attribute and each
    term carries its own scalar weight ``_alpha``.
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1
        self._loss2 = None

    def compute_loss(self, *args, **kwargs):
        raise NotImplementedError()

    def get_name(self):
        raise NotImplementedError()

    def __mul__(self, alpha):
        """Return a copy of this term weighted by *alpha*."""
        assert isinstance(alpha, (int, float))
        weighted = copy(self)
        weighted._alpha = alpha
        return weighted

    __rmul__ = __mul__  # scalar*loss behaves like loss*scalar

    def __add__(self, loss2):
        """Append *loss2* at the tail of a (copied) chain of terms."""
        assert isinstance(loss2, MultiLoss)
        chain = tail = copy(self)
        while tail._loss2 is not None:
            tail = tail._loss2
        tail._loss2 = loss2
        return chain

    def __repr__(self):
        name = self.get_name()
        if self._alpha != 1:
            name = f"{self._alpha:g}*{name}"
        if self._loss2:
            name = f"{name} + {self._loss2}"
        return name

    def forward(self, *args, **kwargs):
        """Evaluate this term (scaled by its alpha) plus any chained terms.

        Returns ``(loss, details)`` where *details* maps term names to
        their unscaled float values (only filled for scalar losses).
        """
        out = self.compute_loss(*args, **kwargs)
        if isinstance(out, tuple):
            loss, details = out
        else:
            loss = out
            details = {self.get_name(): float(loss)} if loss.ndim == 0 else {}
        loss = loss * self._alpha

        if self._loss2:
            extra_loss, extra_details = self._loss2(*args, **kwargs)
            loss = loss + extra_loss
            details |= extra_details

        return loss, details
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class LLoss(BaseCriterion):
    """Generic L-norm loss; subclasses define the per-point distance."""

    def forward(self, a, b):
        """Distance between *a* and *b*, reduced according to ``self.reduction``."""
        shape_ok = a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
        assert shape_ok, f"Bad shape = {a.shape}"
        dist = self.distance(a, b)

        reduction = self.reduction
        if reduction == "none":
            return dist
        if reduction == "sum":
            return dist.sum()
        if reduction == "mean":
            # mean of an empty tensor would be NaN — return zero instead
            if dist.numel() == 0:
                return dist.new_zeros(())
            return dist.mean()
        raise ValueError(f"bad {self.reduction=} mode")

    def distance(self, a, b):
        raise NotImplementedError()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class L21Loss(LLoss):
    """Euclidean (L2) distance between corresponding 3D points."""

    def distance(self, a, b):
        # per-point L2 norm over the channel dimension
        return (a - b).norm(dim=-1)


# module-level singleton, handy as a default pixel criterion
L21 = L21Loss()
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def get_pred_pts3d(gt, pred, use_pose=False):
    """Fetch the predicted 3D points expressed in the reference view.

    Only the pose-aligned variant is supported here, hence the assert.
    *gt* is unused but kept for interface compatibility with callers.
    """
    assert use_pose is True
    return pred["pts3d_in_other_view"]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def Sum(losses, masks, conf=None):
    """Combine per-view losses.

    If the first loss is per-pixel (ndim > 0), the raw lists are handed back
    untouched (optionally together with confidences) so the caller can do
    its own weighting. If it is a scalar, all losses are summed into one.
    """
    first, _ = losses[0], masks[0]
    if first.ndim > 0:
        # per-pixel losses: defer the reduction to the caller
        return (losses, masks, conf) if conf is not None else (losses, masks)
    # scalar losses: accumulate everything into a single value
    total = first
    for extra in losses[1:]:
        total = total + extra
    return total
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
    """Compute a per-sample normalization factor for a list of point maps.

    Args:
        pts: list of point maps, each of shape (B, ..., 3).
        norm_mode: "<mode>_<dis>" string; only mode "avg" is implemented,
            with distance transform "dis" (identity) or "log1p".
        valids: list of validity masks matching *pts*. NOTE(review): despite
            the default, None would crash (``valids[i]`` is indexed
            unconditionally) — confirm callers always pass masks.
        fix_first: when True, only the first point map contributes.

    Returns:
        A tensor broadcastable against ``pts[0]`` (trailing singleton
        dimensions are appended as needed).
    """
    assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3
    assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split("_")

    nan_pts = []
    nnzs = []

    if norm_mode == "avg":
        # gather all points together (joint normalization)

        for i, pt in enumerate(pts):
            # zero-out invalid points so they do not contribute to the sum
            nan_pt, nnz = invalid_to_zeros(pt, valids[i], ndim=3)
            nan_pts.append(nan_pt)
            nnzs.append(nnz)

            if fix_first:
                break
        all_pts = torch.cat(nan_pts, dim=1)

        # compute distance to origin
        all_dis = all_pts.norm(dim=-1)
        if dis_mode == "dis":
            pass  # do nothing
        elif dis_mode == "log1p":
            all_dis = torch.log1p(all_dis)
        else:
            raise ValueError(f"bad {dis_mode=}")

        # NOTE(review): the denominator pools valid counts over the batch
        # dimension too (torch.cat(nnzs).sum() is a scalar), while the
        # numerator is per-sample — confirm this global averaging is intended
        # (it is exact only for batch size 1).
        norm_factor = all_dis.sum(dim=1) / (torch.cat(nnzs).sum() + 1e-8)
    else:
        raise ValueError(f"Not implemented {norm_mode=}")

    # clamp away zero to avoid division-by-zero downstream, then append
    # trailing singleton dims so the factor broadcasts against (B, ..., 3)
    norm_factor = norm_factor.clip(min=1e-8)
    while norm_factor.ndim < pts[0].ndim:
        norm_factor.unsqueeze_(-1)

    return norm_factor
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def normalize_pointcloud_t(
    pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False
):
    """Divide every point map in *pts* by a common normalization factor.

    The factor comes from :func:`get_norm_factor`. The *gt* flag is kept for
    interface compatibility only: in the original implementation both the
    ``gt`` and non-``gt`` branches performed the exact same computation, so
    they are unified here.

    Returns ``(normalized_pts, norm_factor)``.
    """
    norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
    res = [pt / norm_factor for pt in pts]
    return res, norm_factor
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
@torch.no_grad()
def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
    """Joint per-sample depth statistic across all views.

    Invalid depths are NaN-ed out, every view is flattened and concatenated,
    then the requested quantile (median by default) is taken while ignoring
    NaNs. Returns a tensor of shape (B,).
    """
    flat = []
    for i, z in enumerate(zs):
        mask = None if valid_masks is None else valid_masks[i]
        flat.append(invalid_to_nans(z, mask).reshape(len(z), -1))
    joint = torch.cat(flat, dim=-1)

    if quantile == 0.5:
        return torch.nanmedian(joint, dim=-1).values
    return torch.nanquantile(joint, quantile, dim=-1)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
@torch.no_grad()
def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
    """Robust (median-based) center and scale of a set of point maps.

    All views are pooled per batch element; invalid points become NaN and
    are ignored by the nan-aware medians.

    Returns ``(center, scale)`` shaped (B,1,1,3) and (B,1,1,1) so that they
    broadcast directly against (B,H,W,3) point maps.
    """
    pooled = []
    for i, pt in enumerate(pts):
        mask = None if valid_masks is None else valid_masks[i]
        pooled.append(invalid_to_nans(pt, mask).reshape(len(pt), -1, 3))
    pooled = torch.cat(pooled, dim=1)

    # per-sample median point, (B,1,3)
    med_center = torch.nanmedian(pooled, dim=1, keepdim=True).values
    if z_only:
        med_center[..., :2] = 0  # keep X/Y untouched, only center depth

    # median distance to the (possibly disabled) center
    offsets = pooled - med_center if center else pooled
    scale = torch.nanmedian(offsets.norm(dim=-1), dim=1).values
    return med_center[:, None, :, :], scale[:, None, None, None]
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class Regr3D_t(Criterion, MultiLoss):
    """3D point regression loss over a sequence of frames.

    Ground-truth points are expressed in the camera frame of the first view;
    predictions and ground truth can be normalized by an average-distance
    factor (``norm_mode``) before comparison.
    """

    def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True):
        # criterion: pixel-level BaseCriterion (e.g. L21)
        # norm_mode: mode string forwarded to normalize_pointcloud_t, falsy to skip
        # gt_scale: when True, the GT keeps its own (un-normalized) scale
        # fix_first: normalize using only the first frame's points
        super().__init__(criterion)
        self.norm_mode = norm_mode
        self.gt_scale = gt_scale
        self.fix_first = fix_first

    def get_all_pts3d_t(self, gts, preds, dist_clip=None):
        """Gather GT/pred point maps, validity masks and normalization factors.

        Returns (gt_pts, pr_pts, gt_factor, pr_factor, valids, monitoring).
        """
        # everything is normalized w.r.t. camera of view1
        in_camera1 = inv(gts[0]["camera_pose"])

        gt_pts = []
        valids = []
        pr_pts = []

        for i, gt in enumerate(gts):
            # in_camera1: Bs, 4, 4 gt['pts3d']: Bs, H, W, 3
            gt_pts.append(geotrf(in_camera1, gt["pts3d"]))
            valid = gt["valid_mask"].clone()

            if dist_clip is not None:
                # points that are too far-away == invalid
                dis = gt["pts3d"].norm(dim=-1)
                valid = valid & (dis <= dist_clip)

            valids.append(valid)
            pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True))
            # if i != len(gts)-1:
            #     pr_pts_l.append(get_pred_pts3d(gt, preds[i][0], use_pose=(i!=0)))

            # if i != 0:
            #     pr_pts_r.append(get_pred_pts3d(gt, preds[i-1][1], use_pose=(i!=0)))

        # pr_pts = (pr_pts_l, pr_pts_r)

        if self.norm_mode:
            pr_pts, pr_factor = normalize_pointcloud_t(
                pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False
            )
        else:
            pr_factor = None

        if self.norm_mode and not self.gt_scale:
            gt_pts, gt_factor = normalize_pointcloud_t(
                gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True
            )
        else:
            gt_factor = None

        return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}

    def compute_frame_loss(self, gts, preds, **kw):
        """Per-frame regression losses plus a scale-factor penalty.

        Returns (Sum(...) result, details dict, factor_loss).

        NOTE(review): the unpacking below assumes ``pred_pts`` has exactly
        two entries (a left/right pair), while get_all_pts3d_t above builds
        one entry per frame — confirm which caller convention is in effect.
        Likewise preds[i][0]/preds[i][1] index each prediction as a pair.
        """
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            self.get_all_pts3d_t(gts, preds, **kw)
        )

        pred_pts_l, pred_pts_r = pred_pts

        loss_all = []
        mask_all = []
        conf_all = []

        loss_left = 0
        loss_right = 0
        pred_conf_l = 0
        pred_conf_r = 0

        for i in range(len(gt_pts)):

            # Left (Reference)
            if i != len(gt_pts) - 1:
                frame_loss = self.criterion(
                    pred_pts_l[i][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i][0]["conf"])

                # To compare target/reference loss
                if i != 0:
                    loss_left += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy().mean()

            # Right (Target)
            if i != 0:
                frame_loss = self.criterion(
                    pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i - 1][1]["conf"])

                # To compare target/reference loss
                if i != len(gt_pts) - 1:
                    loss_right += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_r += preds[i - 1][1]["conf"].cpu().detach().numpy().mean()

        # penalize predicted scale factors that exceed the GT factor
        if pr_factor is not None and gt_factor is not None:
            filter_factor = pr_factor[pr_factor > gt_factor]
        else:
            filter_factor = []

        if len(filter_factor) > 0:
            factor_loss = (filter_factor - gt_factor).abs().mean()
        else:
            factor_loss = 0.0

        self_name = type(self).__name__
        details = {
            self_name + "_pts3d_1": float(loss_all[0].mean()),
            self_name + "_pts3d_2": float(loss_all[1].mean()),
            self_name + "loss_left": float(loss_left),
            self_name + "loss_right": float(loss_right),
            self_name + "conf_left": float(pred_conf_l),
            self_name + "conf_right": float(pred_conf_r),
        }

        return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
class ConfLoss_t(MultiLoss):
    """Weighted regression by learned confidence.

    Assuming the input pixel_loss is a pixel-level regression loss.

    Principle:
        high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
        low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)

    alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha  # weight of the -log(conf) regularizer
        # ask the wrapped loss for un-reduced (per-pixel) values
        self.pixel_loss = pixel_loss.with_reduction("none")

    def get_name(self):
        return f"ConfLoss({self.pixel_loss})"

    def get_conf_log(self, x):
        """Return ``(conf, log(conf))``; *x* is assumed strictly positive."""
        return x, torch.log(x)

    def compute_frame_loss(self, gts, preds, **kw):
        """Confidence-weighted mean of the per-pixel losses of every frame.

        Returns ``(conf_loss_mean, details, loss_factor)`` mirroring the
        wrapped pixel loss's compute_frame_loss contract.
        """
        # compute per-pixel loss
        (losses, masks, confs), details, loss_factor = (
            self.pixel_loss.compute_frame_loss(gts, preds, **kw)
        )

        # weight by confidence
        conf_losses = []
        conf_sum = 0
        for i in range(len(losses)):
            conf, log_conf = self.get_conf_log(confs[i][masks[i]])
            conf_sum += conf.mean()
            conf_loss = losses[i] * conf - self.alpha * log_conf
            # BUGFIX: fall back to a zero *tensor* (not the int 0) when the
            # mask is empty; torch.stack below raises a TypeError on a list
            # mixing tensors and python ints.
            conf_loss = (
                conf_loss.mean() if conf_loss.numel() > 0 else conf_loss.new_zeros(())
            )
            conf_losses.append(conf_loss)

        conf_losses = torch.stack(conf_losses) * 2.0
        conf_loss_mean = conf_losses.mean()

        # NOTE(review): key names 'conf_loss_1' vs 'conf_loss2' are
        # inconsistent but kept as-is — downstream logging may rely on them.
        return (
            conf_loss_mean,
            dict(
                conf_loss_1=float(conf_losses[0]),
                conf_loss2=float(conf_losses[1]),
                conf_mean=conf_sum / len(losses),
                **details,
            ),
            loss_factor,
        )
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
class Regr3D_t_ShiftInv(Regr3D_t):
    """Same than Regr3D but invariant to depth shift."""

    def get_all_pts3d_t(self, gts, preds):
        """Subtract the joint median depth from GT and predicted points.

        NOTE(review): this override drops the parent's ``dist_clip``
        keyword — callers passing it through compute_frame_loss(**kw)
        would fail; confirm intended.
        """
        # compute unnormalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )

        # pred_pts_l, pred_pts_r = pred_pts
        gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts]

        pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts]
        # pred_zs.append(pred_pts_r[-1][..., 2])

        # compute median depth  (shape (B,1,1) so it broadcasts over H, W)
        gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, None]

        # subtract the median depth (in-place on the z channel)
        for i in range(len(gt_pts)):
            gt_pts[i][..., 2] -= gt_shift_z

        for i in range(len(pred_pts)):
            # for j in range(len(pred_pts[i])):
            pred_pts[i][..., 2] -= pred_shift_z

        monitoring = dict(
            monitoring,
            gt_shift_z=gt_shift_z.mean().detach(),
            pred_shift_z=pred_shift_z.mean().detach(),
        )
        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class Regr3D_t_ScaleInv(Regr3D_t):
    """Same than Regr3D but invariant to depth shift.
    if gt_scale == True: enforce the prediction to take the same scale than GT
    """

    def get_all_pts3d_t(self, gts, preds):
        """Rescale GT/predicted points by their joint median scales.

        Returns the same 6-tuple as the parent, with rescaled point maps.
        """
        # compute depth-normalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )

        # measure scene scale

        # pred_pts_l, pred_pts_r = pred_pts

        pred_pts_all = [
            x.clone() for x in pred_pts
        ]  # [pred_pt for pred_pt in pred_pts_l]
        # pred_pts_all.append(pred_pts_r[-1])

        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks)

        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        # subtract the median depth
        if self.gt_scale:
            for i in range(len(pred_pts)):
                # for j in range(len(pred_pts[i])):
                pred_pts[i] *= gt_scale / pred_scale

        else:
            # NOTE(review): here predictions are scaled by pred/gt while GT
            # is scaled by gt/pred — the opposite directions look suspicious
            # (one would expect both mapped to a common scale); confirm
            # against the training recipe before changing.
            for i in range(len(pred_pts)):
                # for j in range(len(pred_pts[i])):
                pred_pts[i] *= pred_scale / gt_scale

            for i in range(len(gt_pts)):
                gt_pts[i] *= gt_scale / pred_scale

        monitoring = dict(
            monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()
        )

        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
    # MRO: ScaleInv.get_all_pts3d_t runs and its super() call reaches
    # ShiftInv first — i.e. the shift is removed, then the scale is applied.
    pass
|
FastVGGT/eval/data.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os.path as osp
|
| 5 |
+
from collections import deque
|
| 6 |
+
from base import BaseStereoViewDataset
|
| 7 |
+
import dataset_utils.cropping as cropping
|
| 8 |
+
from vggt.utils.eval_utils import imread_cv2, shuffle_deque
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SevenScenes(BaseStereoViewDataset):
    """7-Scenes evaluation dataset loader.

    Each view bundles an RGB image, a depth map (millimeters on disk),
    a 4x4 camera pose and pinhole intrinsics. Frames come either from an
    explicit *tuple_list* (scene id + frame indices per line) or from
    scanning ``ROOT`` with the train/test split files.
    """

    def __init__(
        self,
        num_seq=1,
        num_frames=5,
        min_thresh=10,
        max_thresh=100,
        test_id=None,
        full_video=False,
        tuple_list=None,
        seq_id=None,
        rebuttal=False,
        shuffle_seed=-1,
        kf_every=1,
        *args,
        ROOT,
        **kwargs,
    ):
        # ROOT must be set before super().__init__ in case the base class uses it
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id          # restrict to a single scene, e.g. "chess"
        self.full_video = full_video
        self.kf_every = kf_every        # keyframe subsampling stride
        self.seq_id = seq_id            # restrict to a single sequence, e.g. "seq-01"
        self.rebuttal = rebuttal
        self.shuffle_seed = shuffle_seed  # >= 0 enables deterministic shuffling

        # load all scenes
        self.load_all_tuples(tuple_list)
        self.load_all_scenes(ROOT)

    def __len__(self):
        # one item per tuple line, or num_seq samples per scanned sequence
        if self.tuple_list is not None:
            return len(self.tuple_list)
        return len(self.scene_list) * self.num_seq

    def load_all_tuples(self, tuple_list):
        """Store the optional pre-defined (scene, frame indices) tuples."""
        if tuple_list is not None:
            self.tuple_list = tuple_list
            # with open(tuple_path) as f:
            #     self.tuple_list = f.read().splitlines()

        else:
            self.tuple_list = None

    def load_all_scenes(self, base_dir):
        """Populate ``self.scene_list`` with "scene/seq-XX" identifiers."""

        if self.tuple_list is not None:
            # Use pre-defined simplerecon scene_ids
            self.scene_list = [
                "stairs/seq-06",
                "stairs/seq-02",
                "pumpkin/seq-06",
                "chess/seq-01",
                "heads/seq-02",
                "fire/seq-02",
                "office/seq-03",
                "pumpkin/seq-03",
                "redkitchen/seq-07",
                "chess/seq-02",
                "office/seq-01",
                "redkitchen/seq-01",
                "fire/seq-01",
            ]
            print(f"Found {len(self.scene_list)} sequences in split {self.split}")
            return

        scenes = os.listdir(base_dir)

        file_split = {"train": "TrainSplit.txt", "test": "TestSplit.txt"}[self.split]

        self.scene_list = []
        for scene in scenes:
            if self.test_id is not None and scene != self.test_id:
                continue
            # read file split
            with open(osp.join(base_dir, scene, file_split)) as f:
                seq_ids = f.read().splitlines()

            for seq_id in seq_ids:
                # seq is string, take the int part and make it 01, 02, 03
                # seq_id = 'seq-{:2d}'.format(int(seq_id))
                num_part = "".join(filter(str.isdigit, seq_id))
                seq_id = f"seq-{num_part.zfill(2)}"
                if self.seq_id is not None and seq_id != self.seq_id:
                    continue
                self.scene_list.append(f"{scene}/{seq_id}")

        print(f"Found {len(self.scene_list)} sequences in split {self.split}")

    def _get_views(self, idx, resolution, rng):
        """Load every (sub-sampled) frame of the selected sequence as a view dict."""

        if self.tuple_list is not None:
            # tuple line: "<scene_id> <idx> <idx> ..."
            line = self.tuple_list[idx].split(" ")
            scene_id = line[0]
            img_idxs = line[1:]

        else:
            scene_id = self.scene_list[idx // self.num_seq]
            seq_id = idx % self.num_seq

            # enumerate frames from the color images, then keep every kf_every-th
            data_path = osp.join(self.ROOT, scene_id)
            num_files = len([name for name in os.listdir(data_path) if "color" in name])
            img_idxs = [f"{i:06d}" for i in range(num_files)]
            img_idxs = img_idxs[:: self.kf_every]

        # Intrinsics used in SimpleRecon
        fx, fy, cx, cy = 525, 525, 320, 240
        intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)

        views = []
        imgs_idxs = deque(img_idxs)
        if self.shuffle_seed >= 0:
            imgs_idxs = shuffle_deque(imgs_idxs)

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()
            impath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.color.png")
            depthpath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.depth.proj.png")
            posepath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.pose.txt")

            rgb_image = imread_cv2(impath)

            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            # RGB and depth may have different resolutions: align RGB to depth
            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))

            # 65535 marks invalid depth in the projected depth PNGs
            depthmap[depthmap == 65535] = 0
            depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0

            # drop implausible depths (> 10 m or < 1 mm)
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0

            camera_pose = np.loadtxt(posepath).astype(np.float32)

            if resolution != (224, 224) or self.rebuttal:
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
                )
            else:
                # 224x224: resize to 512x384 first, then center-crop 224x224
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="7scenes",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )
        return views
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class NRGBD(BaseStereoViewDataset):
    """Neural RGB-D evaluation dataset.

    Expects the layout ``<ROOT>/<scene>/images/img{i}.png``,
    ``<ROOT>/<scene>/depth/depth{i}.png`` (depth stored in millimeters) and a
    single ``<ROOT>/<scene>/poses.txt`` holding one 4x4 camera-to-world matrix
    per 4 lines (OpenGL convention, converted to OpenCV on load).
    """

    def __init__(
        self,
        num_seq=1,
        num_frames=5,
        min_thresh=10,
        max_thresh=100,
        test_id=None,
        full_video=False,
        tuple_list=None,
        seq_id=None,
        rebuttal=False,
        shuffle_seed=-1,
        kf_every=1,
        *args,
        ROOT,
        **kwargs,
    ):
        # ROOT is assigned before super().__init__ in case the base class reads it.
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every
        self.seq_id = seq_id
        self.rebuttal = rebuttal
        self.shuffle_seed = shuffle_seed

        # load all scenes
        self.load_all_tuples(tuple_list)
        self.load_all_scenes(ROOT)

    def __len__(self):
        # With an explicit tuple list each tuple is one sample; otherwise each
        # scene contributes num_seq samples.
        if self.tuple_list is not None:
            return len(self.tuple_list)
        return len(self.scene_list) * self.num_seq

    def load_all_tuples(self, tuple_list):
        # Optional list of pre-defined "scene_id idx idx ..." evaluation tuples
        # (None means: enumerate frames from disk in _get_views).
        self.tuple_list = tuple_list

    def load_all_scenes(self, base_dir):
        """Collect scene directories under base_dir (or only test_id if given)."""
        scenes = [
            d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
        ]
        if self.test_id is not None:
            self.scene_list = [self.test_id]
        else:
            self.scene_list = scenes
        print(f"Found {len(self.scene_list)} sequences in split {self.split}")

    def load_poses(self, path):
        """Parse poses.txt: one 4x4 camera-to-world matrix per 4 text lines.

        Matrices whose first line contains 'nan' are replaced by the identity
        and flagged invalid.

        Returns:
            (poses, valid): float32 array of shape (N, 4, 4) and a list of bools.
        """
        # Fix: use a context manager so the file handle is closed even on error.
        with open(path, "r") as f:
            lines = f.readlines()
        poses = []
        valid = []
        lines_per_matrix = 4
        for i in range(0, len(lines), lines_per_matrix):
            # NOTE(review): only the first row is checked for 'nan' — assumes
            # invalid poses are entirely NaN; confirm against the data format.
            if "nan" in lines[i]:
                valid.append(False)
                poses.append(np.eye(4, 4, dtype=np.float32).tolist())
            else:
                valid.append(True)
                pose_floats = [
                    [float(x) for x in line.split()]
                    for line in lines[i : i + lines_per_matrix]
                ]
                poses.append(pose_floats)

        return np.array(poses, dtype=np.float32), valid

    def _get_views(self, idx, resolution, rng):
        """Load one sequence of views as dicts (img, depthmap, pose, intrinsics)."""
        if self.tuple_list is not None:
            line = self.tuple_list[idx].split(" ")
            scene_id = line[0]
            img_idxs = line[1:]
        else:
            scene_id = self.scene_list[idx // self.num_seq]

            num_files = len(os.listdir(os.path.join(self.ROOT, scene_id, "images")))
            img_idxs = [f"{i}" for i in range(num_files)]
            # Keep every kf_every-th frame. Fix: clamp the slice step to >= 1 so
            # sequences with fewer than 2 frames do not raise (step 0 is invalid).
            img_idxs = img_idxs[:: max(1, min(self.kf_every, len(img_idxs) // 2))]

        # Fixed pinhole intrinsics shared by all NRGBD scenes (640x480 frames).
        fx, fy, cx, cy = 554.2562584220408, 554.2562584220408, 320, 240
        intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)

        posepath = osp.join(self.ROOT, scene_id, "poses.txt")
        camera_poses, valids = self.load_poses(posepath)

        imgs_idxs = deque(img_idxs)
        if self.shuffle_seed >= 0:
            imgs_idxs = shuffle_deque(imgs_idxs)
        views = []

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()

            impath = osp.join(self.ROOT, scene_id, "images", f"img{im_idx}.png")
            depthpath = osp.join(self.ROOT, scene_id, "depth", f"depth{im_idx}.png")

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            # Millimeters -> meters; zero out implausible values (> 10 m or ~0).
            depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0

            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))

            camera_pose = camera_poses[int(im_idx)]
            # gl to cv: flip the Y and Z camera axes.
            camera_pose[:, 1:3] *= -1.0
            if resolution != (224, 224) or self.rebuttal:
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
                )
            else:
                # 224x224 target: resize to 512x384 first, then center-crop 224x224.
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="nrgbd",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )

        return views
|
FastVGGT/eval/dataset_utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (146 Bytes). View file
|
|
|
FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (140 Bytes). View file
|
|
|
FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc
ADDED
|
Binary file (5.85 kB). View file
|
|
|
FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc
ADDED
|
Binary file (4.29 kB). View file
|
|
|
FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc
ADDED
|
Binary file (4.29 kB). View file
|
|
|
FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc
ADDED
|
Binary file (2.18 kB). View file
|
|
|
FastVGGT/eval/dataset_utils/corr.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: function applied to the top-level `batch` before the transfer.
    non_blocking: forwarded to `Tensor.to` (allows asynchronous copies).
    """
    if callback:
        batch = callback(batch)

    # Recurse into containers. Fix: forward non_blocking — it was silently
    # dropped for nested structures before.
    if isinstance(batch, dict):
        return {
            k: todevice(v, device, non_blocking=non_blocking) for k, v in batch.items()
        }

    if isinstance(batch, (tuple, list)):
        return type(batch)(todevice(x, device, non_blocking=non_blocking) for x in batch)

    x = batch
    if device == "numpy":
        # Torch tensors become numpy arrays; everything else is returned as-is.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        # Numpy arrays are promoted to tensors, then moved to the device.
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
to_device = todevice # alias
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def to_numpy(x):
    """Recursively convert all torch tensors inside ``x`` to numpy arrays."""
    return todevice(x, "numpy")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a list of 3-D points.

    H: 3x3 or 4x4 projection matrix (typically a Homography)
    p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the resut is projected on the z=norm plane.

    Returns an array of projected 2d points.
    """
    assert Trf.ndim >= 2
    # Coerce pts to the same array family (numpy vs torch) as the transform.
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # Fast path: batched torch transform (B,d,d)/(B,d+1,d+1) on (B,H,W,d) points.
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # Pure linear part, no translation.
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # Homogeneous transform: linear part plus translation column.
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        if Trf.ndim >= 3:
            # Flatten leading batch dims of the transform; pts must share them.
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Collapse all per-batch point dims into one.
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Insert a singleton "points" dimension.
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            # Fallback: left-multiply, then restore points-last layout.
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def inv(mat):
    """Invert a matrix, dispatching on its type (numpy array or torch tensor)."""
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def reproject_view(pts3d, view2):
    """Reproject world-space points ``pts3d`` into the image plane of ``view2``.

    Uses view2's intrinsics and the inverse of its camera-to-world pose
    (i.e. the world-to-camera transform). Returns what ``reproject`` returns.
    """
    shape = view2["pts3d"].shape[:2]
    return reproject(
        pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape
    )
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def reproject(pts3d, K, world2cam, shape):
    """Project an (H, W, 3) map of world points into pixel indices of a target view.

    K: 3x3 intrinsics; world2cam: 4x4 (only the top 3 rows are used).
    Returns ((H, W), flat_idx) where flat_idx ravels the projected (x, y)
    positions into the target image of size ``shape``.
    """
    H, W, THREE = pts3d.shape
    assert THREE == 3

    # Points behind the camera / at z=0 divide by zero; silenced on purpose.
    with np.errstate(divide="ignore", invalid="ignore"):
        pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)

    return (H, W), ravel_xy(pos, shape)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def ravel_xy(pos, shape):
    """Quantize (x, y) positions and ravel them into 1-D indices of an (H, W) grid.

    Coordinates are rounded to the nearest integer and clipped into the image,
    then flattened row-major as ``x + W * y``.
    """
    H, W = shape
    # NaNs from invalid projections turn into garbage ints; the cast warning
    # is silenced deliberately (callers filter by reciprocity afterwards).
    with np.errstate(invalid="ignore"):
        xy = pos.reshape(-1, 2).round().astype(np.int32)
        col = np.clip(xy[:, 0], 0, W - 1)
        row = np.clip(xy[:, 1], 0, H - 1)
        return col + W * row
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def unravel_xy(pos, shape):
    """Convert raveled 1-D indices back to (x, y) pixel coordinates.

    NOTE(review): relies on ``np.unravel_index`` returning views into one
    stacked base array (the ``.base`` trick, then column-reversed to get
    x-before-y). This is fragile across numpy versions — confirm before
    upgrading numpy.
    """
    return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
    """Keep only correspondences that agree in both directions.

    A position i in image 1 is reciprocal when mapping it to image 2 and back
    lands on i again. Returns (pos1, pos2), optionally preceded by the boolean
    reciprocity mask over image 1.
    """
    identity = np.arange(len(corres_1_to_2))
    is_reciprocal1 = corres_2_to_1[corres_1_to_2] == identity
    pos1 = np.flatnonzero(is_reciprocal1)
    pos2 = corres_1_to_2[pos1]
    if ret_recip:
        return is_reciprocal1, pos1, pos2
    return pos1, pos2
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def extract_correspondences_from_pts3d(
    view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
):
    """Sample pixel correspondences between two views from their 3-D point maps.

    Cross-reprojects each view's pts3d into the other and keeps mutually
    consistent (reciprocal) matches as positives; non-reciprocal pixels can be
    sampled as negatives (fraction ``nneg`` of ``target_n_corres``).
    Returns (pos1, pos2) when target_n_corres is None, else (pos1, pos2, valid)
    where ``valid`` marks positives. Positions are (x, y) pairs when ret_xy,
    otherwise raveled 1-D indices.
    """
    view1, view2 = to_numpy((view1, view2))

    shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
    shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)

    # Reciprocity masks in both directions; reciprocal pairs are the positives.
    is_reciprocal1, pos1, pos2 = reciprocal_1d(
        corres1_to_2, corres2_to_1, ret_recip=True
    )
    is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))

    if target_n_corres is None:
        # No subsampling requested: return every reciprocal correspondence.
        if ret_xy:
            pos1 = unravel_xy(pos1, shape1)
            pos2 = unravel_xy(pos2, shape2)
        return pos1, pos2

    # Split the budget between positives and negatives.
    available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
    target_n_positives = int(target_n_corres * (1 - nneg))
    n_positives = min(len(pos1), target_n_positives)
    n_negatives = min(target_n_corres - n_positives, available_negatives)

    if n_negatives + n_positives != target_n_corres:
        # Not enough negatives available: top up with extra positives.
        n_positives = target_n_corres - n_negatives
        assert n_positives <= len(pos1)

    assert n_positives <= len(pos1)
    assert n_positives <= len(pos2)
    assert n_negatives <= (~is_reciprocal1).sum()
    assert n_negatives <= (~is_reciprocal2).sum()
    assert n_positives + n_negatives == target_n_corres

    valid = np.ones(n_positives, dtype=bool)
    if n_positives < len(pos1):
        # Randomly subsample the positives down to the budget.
        perm = rng.permutation(len(pos1))[:n_positives]
        pos1 = pos1[perm]
        pos2 = pos2[perm]

    if n_negatives > 0:
        # Draw negatives uniformly among non-reciprocal pixels of each view.
        def norm(p):
            return p / p.sum()

        pos1 = np.r_[
            pos1,
            rng.choice(
                shape1[0] * shape1[1],
                size=n_negatives,
                replace=False,
                p=norm(~is_reciprocal1),
            ),
        ]
        pos2 = np.r_[
            pos2,
            rng.choice(
                shape2[0] * shape2[1],
                size=n_negatives,
                replace=False,
                p=norm(~is_reciprocal2),
            ),
        ]
        valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]

    if ret_xy:
        pos1 = unravel_xy(pos1, shape1)
        pos2 = unravel_xy(pos2, shape2)
    return pos1, pos2, valid
|
FastVGGT/eval/dataset_utils/cropping.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
|
| 6 |
+
import PIL.Image
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
from utils import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics
|
| 10 |
+
|
| 11 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 12 |
+
import cv2 # noqa
|
| 13 |
+
import numpy as np # noqa
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
lanczos = PIL.Image.Resampling.LANCZOS
|
| 17 |
+
bicubic = PIL.Image.Resampling.BICUBIC
|
| 18 |
+
except AttributeError:
|
| 19 |
+
lanczos = PIL.Image.LANCZOS
|
| 20 |
+
bicubic = PIL.Image.BICUBIC
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ImageList:
    """Apply the same PIL operation uniformly to one image or a set of images."""

    def __init__(self, images):
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        # Coerce every element to a PIL image (numpy arrays go through fromarray).
        self.images = [
            img if isinstance(img, PIL.Image.Image) else PIL.Image.fromarray(img)
            for img in images
        ]

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        """Return a single PIL image, or a tuple when more than one is held."""
        if len(self.images) > 1:
            return tuple(self.images)
        return self.images[0]

    @property
    def size(self):
        """Common (W, H) of all held images; asserts they all agree."""
        first = self.images[0].size
        assert all(first == im.size for im in self.images)
        return first

    def resize(self, *args, **kwargs):
        return ImageList(self._dispatch("resize", *args, **kwargs))

    def crop(self, *args, **kwargs):
        return ImageList(self._dispatch("crop", *args, **kwargs))

    def _dispatch(self, method, *args, **kwargs):
        # Forward the named PIL method to every image and collect the results.
        return [getattr(im, method)(*args, **kwargs) for im in self.images]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def rescale_image_depthmap(
    image, depthmap, camera_intrinsics, output_resolution, force=True
):
    """Jointly rescale a (image, depthmap) pair and its intrinsics
    so that (out_width, out_height) >= output_res.

    image: PIL image(s) or array(s); depthmap: optional (H, W) array resized
    with nearest-neighbor to keep depth values exact; camera_intrinsics: 3x3.
    With force=False, an image already smaller than the target is untouched.
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # Depth must match the image (note: depth shape is (H, W), size is (W, H)).
        assert tuple(depthmap.shape[:2]) == image.size[::-1]

    assert output_resolution.shape == (2,)
    # Epsilon guarantees both target dimensions are covered after flooring.
    scale_final = max(output_resolution / image.size) + 1e-8
    if scale_final >= 1 and not force:  # image is already smaller than what is asked
        return (image.to_pil(), depthmap, camera_intrinsics)
    output_resolution = np.floor(input_resolution * scale_final).astype(int)

    # Lanczos for downscaling, bicubic for upscaling.
    image = image.resize(
        output_resolution, resample=lanczos if scale_final < 1 else bicubic
    )
    if depthmap is not None:
        depthmap = cv2.resize(
            depthmap,
            output_resolution,
            fx=scale_final,
            fy=scale_final,
            interpolation=cv2.INTER_NEAREST,
        )

    # Rescale focal lengths / principal point to the new resolution.
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
    )

    return image.to_pil(), depthmap, camera_intrinsics
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def camera_matrix_of_crop(
    input_camera_matrix,
    input_resolution,
    output_resolution,
    scaling=1,
    offset_factor=0.5,
    offset=None,
):
    """Intrinsics for a scaled-then-cropped view.

    The input image is scaled by ``scaling``, then a window of
    ``output_resolution`` is taken; ``offset`` (default: offset_factor of the
    margins, i.e. a centered crop) shifts the principal point accordingly.
    The scaling is applied in COLMAP convention and converted back to OpenCV.
    """
    margins = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    # NOTE(review): assumes opencv_to_colmap_intrinsics returns a new matrix
    # (the input is not meant to be mutated) — confirm in utils.
    output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
    output_camera_matrix_colmap[:2, :] *= scaling
    output_camera_matrix_colmap[:2, 2] -= offset
    output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)

    return output_camera_matrix
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
    """
    Return a crop of the input view.

    crop_bbox is (l, t, r, b) in pixels; the principal point of the intrinsics
    is shifted by the crop offset (focal lengths are unchanged).
    """
    left, top, right, bottom = crop_bbox

    cropped = ImageList(image).crop((left, top, right, bottom))
    cropped_depth = depthmap[top:bottom, left:right]

    # Copy before editing so the caller's intrinsics stay intact.
    K = camera_intrinsics.copy()
    K[0, 2] -= left
    K[1, 2] -= top

    return cropped.to_pil(), cropped_depth, K
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def bbox_from_intrinsics_in_out(
    input_camera_matrix, output_camera_matrix, output_resolution
):
    """Derive the (l, t, r, b) crop box that maps the input principal point
    onto the output one, for a crop of size ``output_resolution``."""
    out_width, out_height = output_resolution
    # The crop origin is the shift between the two principal points.
    shift = input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]
    l, t = np.int32(np.round(shift))
    return (l, t, l + out_width, t + out_height)
|
FastVGGT/eval/dataset_utils/transforms.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
|
| 6 |
+
import torchvision.transforms as tvf
|
| 7 |
+
|
| 8 |
+
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
|
| 15 |
+
if isinstance(value, (int, float)):
|
| 16 |
+
if value < 0:
|
| 17 |
+
raise ValueError(f"If is a single number, it must be non negative.")
|
| 18 |
+
value = [center - float(value), center + float(value)]
|
| 19 |
+
if clip_first_on_zero:
|
| 20 |
+
value[0] = max(value[0], 0.0)
|
| 21 |
+
elif isinstance(value, (tuple, list)) and len(value) == 2:
|
| 22 |
+
value = [float(value[0]), float(value[1])]
|
| 23 |
+
else:
|
| 24 |
+
raise TypeError(f"should be a single number or a list/tuple with length 2.")
|
| 25 |
+
|
| 26 |
+
if not bound[0] <= value[0] <= value[1] <= bound[1]:
|
| 27 |
+
raise ValueError(f"values should be between {bound}, but got {value}.")
|
| 28 |
+
|
| 29 |
+
if value[0] == value[1] == center:
|
| 30 |
+
return None
|
| 31 |
+
else:
|
| 32 |
+
return tuple(value)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torchvision.transforms.functional as F
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def SeqColorJitter():
    """
    Return a color jitter transform with same random parameters.

    The jitter factors and the order of the four operations are drawn once
    here, so the returned callable applies the *identical* augmentation to
    every image it is given (useful for consistent jitter across a sequence).
    The result is finally normalized with ImgNorm.
    """
    # Ranges match torchvision's ColorJitter(0.5, 0.5, 0.5, 0.1) above.
    brightness = _check_input(0.5)
    contrast = _check_input(0.5)
    saturation = _check_input(0.5)
    hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    # Random application order, fixed for the lifetime of the closure.
    fn_idx = torch.randperm(4)
    brightness_factor = (
        None
        if brightness is None
        else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
    )
    contrast_factor = (
        None
        if contrast is None
        else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
    )
    saturation_factor = (
        None
        if saturation is None
        else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
    )
    hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

    def _color_jitter(img):
        # Apply the four adjustments in the pre-drawn order, then normalize.
        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)
        return ImgNorm(img)

    return _color_jitter
|
FastVGGT/eval/eval_7andN.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
# Ensure project root is on sys.path for absolute imports like `vggt.*`
|
| 5 |
+
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
|
| 6 |
+
if ROOT_DIR not in sys.path:
|
| 7 |
+
sys.path.insert(0, ROOT_DIR)
|
| 8 |
+
|
| 9 |
+
import time
|
| 10 |
+
import torch
|
| 11 |
+
import argparse
|
| 12 |
+
import numpy as np
|
| 13 |
+
import open3d as o3d
|
| 14 |
+
import os.path as osp
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
from torch.utils.data._utils.collate import default_collate
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
import torchvision.transforms as transforms
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_args_parser():
    """Build the CLI argument parser for 3D reconstruction evaluation."""
    p = argparse.ArgumentParser("3D Reconstruction evaluation", add_help=False)
    p.add_argument(
        "--ckpt_path",
        type=str,
        default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
        help="ckpt name",
    )
    p.add_argument("--device", type=str, default="cuda:0", help="device")
    p.add_argument("--model_name", type=str, default="VGGT")
    p.add_argument(
        "--conf_thresh", type=float, default=0.0, help="confidence threshold"
    )
    p.add_argument(
        "--output_dir",
        type=str,
        default="/home/sy/code/FastVGGT/eval_results",
        help="value for outdir",
    )
    p.add_argument("--size", type=int, default=518)
    p.add_argument("--revisit", type=int, default=1, help="revisit times")
    p.add_argument("--freeze", action="store_true")
    p.add_argument("--use_proj", action="store_true")
    p.add_argument(
        "--merging", type=int, default=0, help="VGGT aggregator merging steps"
    )
    p.add_argument("--kf", type=int, default=2, help="key frame")
    return p
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def main(args):
    """Evaluate VGGT point-map reconstruction on the 7-Scenes and NRGBD datasets.

    Each scene is collated into one batch, run through the model once (timed),
    aligned to the ground-truth cloud (optional Umeyama pre-alignment, then
    point-to-point ICP) and scored with accuracy / completion /
    normal-consistency metrics.  Per-scene results are appended to
    ``logs.txt`` and an aggregate summary (plus per-scene average inference
    time) is written to ``logs_all.txt`` under ``args.output_dir/<kf>/<dataset>``.

    Args:
        args: Parsed CLI namespace from ``get_args_parser()``.
    """
    from data import SevenScenes, NRGBD
    from utils import accuracy, completion

    # Map the requested square size onto the (W, H) resolution used by the loaders.
    if args.size == 512:
        resolution = (512, 384)
    elif args.size == 224:
        resolution = 224
    elif args.size == 518:
        resolution = (518, 392)
    else:
        raise NotImplementedError

    datasets_all = {
        "7scenes": SevenScenes(
            split="test",
            ROOT="/data/sy/7scenes",
            resolution=resolution,
            num_seq=1,
            full_video=True,
            kf_every=args.kf,
        ),
        "NRGBD": NRGBD(
            split="test",
            ROOT="/data/sy/neural_rgbd_data",
            resolution=resolution,
            num_seq=1,
            full_video=True,
            kf_every=args.kf,
        ),
    }

    device = args.device

    from vggt.models.vggt import VGGT
    from vggt.utils.pose_enc import pose_encoding_to_extri_intri
    from vggt.utils.geometry import unproject_depth_map_to_point_map
    from criterion import Regr3D_t_ScaleShiftInv, L21

    # Load the VGGT model for bf16 inference.  strict=False because
    # enable_point=True adds heads absent from the released checkpoint.
    model = VGGT(merging=args.merging, enable_point=True)
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt, strict=False)
    model = model.cuda().eval()
    model = model.to(torch.bfloat16)
    del ckpt

    os.makedirs(osp.join(args.output_dir, f"{args.kf}"), exist_ok=True)

    criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)

    with torch.no_grad():
        for name_data, dataset in datasets_all.items():
            save_path = osp.join(osp.join(args.output_dir, f"{args.kf}"), name_data)
            os.makedirs(save_path, exist_ok=True)
            log_file = osp.join(save_path, "logs.txt")

            # Running sums kept for parity with the original script (the
            # summary below is recomputed from the log file, not from these).
            acc_all = 0
            acc_all_med = 0
            comp_all = 0
            comp_all_med = 0
            nc1_all = 0
            nc1_all_med = 0
            nc2_all = 0
            nc2_all_med = 0
            scene_infer_times = defaultdict(list)

            for data_idx in tqdm(range(len(dataset))):
                batch = default_collate([dataset[data_idx]])
                # Keys that are metadata / CPU-only and must not be moved to GPU.
                ignore_keys = set(
                    [
                        "depthmap",
                        "dataset",
                        "label",
                        "instance",
                        "idx",
                        "true_shape",
                        "rng",
                    ]
                )
                for view in batch:
                    for name in view.keys():
                        if name in ignore_keys:
                            continue
                        if isinstance(view[name], (tuple, list)):
                            view[name] = [
                                x.to(device, non_blocking=True) for x in view[name]
                            ]
                        else:
                            view[name] = view[name].to(device, non_blocking=True)

                # bf16 on Ampere+ (compute capability >= 8), fp16 otherwise.
                dtype = (
                    torch.bfloat16
                    if torch.cuda.get_device_capability()[0] >= 8
                    else torch.float16
                )
                with torch.cuda.amp.autocast(dtype=dtype):
                    # Images arrive in [-1, 1]; VGGT expects [0, 1].
                    if isinstance(batch, dict) and "img" in batch:
                        batch["img"] = (batch["img"] + 1.0) / 2.0
                    elif isinstance(batch, list) and all(
                        isinstance(v, dict) and "img" in v for v in batch
                    ):
                        for view in batch:
                            view["img"] = (view["img"] + 1.0) / 2.0
                    # Gather all `img` tensors into a single [N, C, H, W] tensor.
                    imgs_tensor = torch.cat([v["img"] for v in batch], dim=0)

                    # Timed forward pass.  (Fix: the original nested a second,
                    # redundant autocast + no_grad context here.)
                    torch.cuda.synchronize()
                    start = time.time()
                    preds = model(imgs_tensor)
                    torch.cuda.synchronize()
                    end = time.time()
                    inference_time_ms = (end - start) * 1000
                    print(f"Inference time: {inference_time_ms:.2f}ms")

                    # Re-shape model outputs into one dict per view so they
                    # align with `batch` downstream.
                    predictions = preds
                    views = batch  # list[dict]
                    if "pose_enc" in predictions:
                        B, S = predictions["pose_enc"].shape[:2]
                    elif "world_points" in predictions:
                        B, S = predictions["world_points"].shape[:2]
                    else:
                        raise KeyError(
                            "predictions is missing a key to infer sequence length"
                        )

                    ress = []
                    for s in range(S):
                        res = {
                            "pts3d_in_other_view": predictions["world_points"][:, s],
                            "conf": predictions["world_points_conf"][:, s],
                            "depth": predictions["depth"][:, s],
                            "depth_conf": predictions["depth_conf"][:, s],
                            "camera_pose": predictions["pose_enc"][:, s, :],
                        }
                        if (
                            isinstance(views, list)
                            and s < len(views)
                            and "valid_mask" in views[s]
                        ):
                            res["valid_mask"] = views[s]["valid_mask"]
                        if "track" in predictions:
                            res.update(
                                {
                                    "track": predictions["track"][:, s],
                                    "vis": (
                                        predictions.get("vis", None)[:, s]
                                        if "vis" in predictions
                                        else None
                                    ),
                                    "track_conf": (
                                        predictions.get("conf", None)[:, s]
                                        if "conf" in predictions
                                        else None
                                    ),
                                }
                            )
                        ress.append(res)

                    preds = ress

                # When revisiting, only the final pass over the sequence counts.
                valid_length = len(preds) // args.revisit
                if args.revisit > 1:
                    preds = preds[-valid_length:]
                    batch = batch[-valid_length:]

                # Evaluation
                print(f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}")
                gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
                    criterion.get_all_pts3d_t(batch, preds)
                )

                in_camera1 = None
                pts_all = []
                pts_gt_all = []
                images_all = []
                masks_all = []
                conf_all = []

                for j, view in enumerate(batch):
                    if in_camera1 is None:
                        in_camera1 = view["camera_pose"][0].cpu()

                    image = view["img"].permute(0, 2, 3, 1).cpu().numpy()[0]
                    mask = view["valid_mask"].cpu().numpy()[0]

                    pts = pred_pts[j].cpu().numpy()[0]
                    conf = preds[j]["conf"].cpu().data.numpy()[0]

                    # mask = mask & (conf > 1.8)

                    pts_gt = gt_pts[j].detach().cpu().numpy()[0]

                    # Center-crop every map to 224x224 around the image center.
                    H, W = image.shape[:2]
                    cx = W // 2
                    cy = H // 2
                    l, t = cx - 112, cy - 112
                    r, b = cx + 112, cy + 112
                    image = image[t:b, l:r]
                    mask = mask[t:b, l:r]
                    pts = pts[t:b, l:r]
                    pts_gt = pts_gt[t:b, l:r]

                    images_all.append(image[None, ...])
                    pts_all.append(pts[None, ...])
                    pts_gt_all.append(pts_gt[None, ...])
                    masks_all.append(mask[None, ...])
                    conf_all.append(conf[None, ...])

                images_all = np.concatenate(images_all, axis=0)
                pts_all = np.concatenate(pts_all, axis=0)
                pts_gt_all = np.concatenate(pts_gt_all, axis=0)
                masks_all = np.concatenate(masks_all, axis=0)

                scene_id = view["label"][0].rsplit("/", 1)[0]
                # Record average inference time per scene
                try:
                    scene_infer_times[scene_id].append(float(inference_time_ms))
                except Exception:
                    pass

                # Keep valid pixels only, then drop non-finite coordinates.
                pts_all_masked = pts_all[masks_all > 0]
                pts_gt_all_masked = pts_gt_all[masks_all > 0]
                images_all_masked = images_all[masks_all > 0]

                mask = np.isfinite(pts_all_masked)
                pts_all_masked = pts_all_masked[mask]

                mask_gt = np.isfinite(pts_gt_all_masked)
                pts_gt_all_masked = pts_gt_all_masked[mask_gt]
                images_all_masked = images_all_masked[mask]

                # Reshape to point cloud (N, 3) before sampling
                pts_all_masked = pts_all_masked.reshape(-1, 3)
                pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3)
                images_all_masked = images_all_masked.reshape(-1, 3)

                # If number of points exceeds threshold, sample by points
                if pts_all_masked.shape[0] > 999999:
                    sample_indices = np.random.choice(
                        pts_all_masked.shape[0], 999999, replace=False
                    )
                    pts_all_masked = pts_all_masked[sample_indices]
                    images_all_masked = images_all_masked[sample_indices]

                # Apply the same sampling to GT point cloud
                if pts_gt_all_masked.shape[0] > 999999:
                    sample_indices_gt = np.random.choice(
                        pts_gt_all_masked.shape[0], 999999, replace=False
                    )
                    pts_gt_all_masked = pts_gt_all_masked[sample_indices_gt]

                if args.use_proj:

                    def umeyama_alignment(
                        src: np.ndarray, dst: np.ndarray, with_scale: bool = True
                    ):
                        """Closed-form similarity (s, R, t) mapping src onto dst."""
                        assert src.shape == dst.shape
                        N, dim = src.shape

                        mu_src = src.mean(axis=0)
                        mu_dst = dst.mean(axis=0)
                        src_c = src - mu_src
                        dst_c = dst - mu_dst

                        Sigma = dst_c.T @ src_c / N  # (3,3) cross-covariance

                        U, D, Vt = np.linalg.svd(Sigma)

                        # Reflection guard: force det(R) = +1.
                        S = np.eye(dim)
                        if np.linalg.det(U) * np.linalg.det(Vt) < 0:
                            S[-1, -1] = -1

                        R = U @ S @ Vt

                        if with_scale:
                            var_src = (src_c**2).sum() / N
                            s = (D * S.diagonal()).sum() / var_src
                        else:
                            s = 1.0

                        t = mu_dst - s * R @ mu_src

                        return s, R, t

                    pts_all_masked = pts_all_masked.reshape(-1, 3)
                    pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3)
                    s, R, t = umeyama_alignment(
                        pts_all_masked, pts_gt_all_masked, with_scale=True
                    )
                    pts_all_aligned = (s * (R @ pts_all_masked.T)).T + t  # (N,3)
                    pts_all_masked = pts_all_aligned

                pcd = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(pts_all_masked)
                pcd.colors = o3d.utility.Vector3dVector(images_all_masked)

                pcd_gt = o3d.geometry.PointCloud()
                pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked)
                pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked)

                # Refine the alignment with point-to-point ICP.
                trans_init = np.eye(4)
                threshold = 0.1
                reg_p2p = o3d.pipelines.registration.registration_icp(
                    pcd,
                    pcd_gt,
                    threshold,
                    trans_init,
                    o3d.pipelines.registration.TransformationEstimationPointToPoint(),
                )

                transformation = reg_p2p.transformation

                pcd = pcd.transform(transformation)
                pcd.estimate_normals()
                pcd_gt.estimate_normals()

                gt_normal = np.asarray(pcd_gt.normals)
                pred_normal = np.asarray(pcd.normals)

                acc, acc_med, nc1, nc1_med = accuracy(
                    pcd_gt.points, pcd.points, gt_normal, pred_normal
                )
                comp, comp_med, nc2, nc2_med = completion(
                    pcd_gt.points, pcd.points, gt_normal, pred_normal
                )

                result_line = (
                    f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - "
                    f"Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}"
                )
                print(result_line)
                # Fix: close the log file deterministically; the original used
                # `print(..., file=open(log_file, "a"))`, leaking a handle per scene.
                with open(log_file, "a") as lf:
                    print(result_line, file=lf)

                acc_all += acc
                comp_all += comp
                nc1_all += nc1
                nc2_all += nc2

                acc_all_med += acc_med
                comp_all_med += comp_med
                nc1_all_med += nc1_med
                nc2_all_med += nc2_med

                # release cuda memory
                torch.cuda.empty_cache()

            # Aggregate per-scene log lines (parsed back with the module-level
            # `regex`) into logs_all.txt.
            to_write = ""
            if os.path.exists(osp.join(save_path, "logs.txt")):
                with open(osp.join(save_path, "logs.txt"), "r") as f_sub:
                    to_write += f_sub.read()

            with open(osp.join(save_path, "logs_all.txt"), "w") as f:
                log_data = to_write
                metrics = defaultdict(list)
                for line in log_data.strip().split("\n"):
                    match = regex.match(line)
                    if match:
                        data = match.groupdict()
                        # Exclude 'scene_id' from metrics as it's an identifier
                        for key, value in data.items():
                            if key != "scene_id":
                                metrics[key].append(float(value))
                        metrics["nc"].append(
                            (float(data["nc1"]) + float(data["nc2"])) / 2
                        )
                        metrics["nc_med"].append(
                            (float(data["nc1_med"]) + float(data["nc2_med"])) / 2
                        )
                mean_metrics = {
                    metric: sum(values) / len(values)
                    for metric, values in metrics.items()
                }

                c_name = "mean"
                print_str = f"{c_name.ljust(20)}: "
                for m_name in mean_metrics:
                    print_num = np.mean(mean_metrics[m_name])
                    print_str = print_str + f"{m_name}: {print_num:.3f} | "
                print_str = print_str + "\n"

                # Summarize per-scene average inference time
                time_lines = []
                for sid, times in scene_infer_times.items():
                    if len(times) > 0:
                        time_lines.append(
                            f"Idx: {sid}, Time_avg_ms: {np.mean(times):.2f}"
                        )
                time_block = "\n".join(time_lines) + (
                    "\n" if len(time_lines) > 0 else ""
                )

                f.write(to_write + time_block + print_str)
| 474 |
+
|
| 475 |
+
from collections import defaultdict
import re

# Parses one per-scene result line emitted by main(), e.g.
#   "Idx: scene0, Acc: 0.1, ... NC2: 0.4 - Acc_med: 0.5, ..., NC2c_med: 0.8"
# re.VERBOSE lets the pattern be laid out one field per line; layout
# whitespace inside the pattern is ignored by the engine.
pattern = r"""
    Idx:\s*(?P<scene_id>[^,]+),\s*
    Acc:\s*(?P<acc>[^,]+),\s*
    Comp:\s*(?P<comp>[^,]+),\s*
    NC1:\s*(?P<nc1>[^,]+),\s*
    NC2:\s*(?P<nc2>[^,]+)\s*-\s*
    Acc_med:\s*(?P<acc_med>[^,]+),\s*
    Compc_med:\s*(?P<comp_med>[^,]+),\s*
    NC1c_med:\s*(?P<nc1_med>[^,]+),\s*
    NC2c_med:\s*(?P<nc2_med>[^,]+)
"""

# Compiled once at import time; main() reads this module-level name.
regex = re.compile(pattern, flags=re.VERBOSE)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
# Script entry point: build the CLI parser and launch the evaluation.
if __name__ == "__main__":
    cli_args = get_args_parser().parse_args()
    main(cli_args)
|
FastVGGT/eval/eval_custom.py
ADDED
|
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from scipy.spatial.transform import Rotation
|
| 9 |
+
|
| 10 |
+
# Ensure project root is in sys.path for absolute imports like `vggt.*`
|
| 11 |
+
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
|
| 12 |
+
if ROOT_DIR not in sys.path:
|
| 13 |
+
sys.path.insert(0, ROOT_DIR)
|
| 14 |
+
|
| 15 |
+
from vggt.models.vggt import VGGT
|
| 16 |
+
from vggt.utils.eval_utils import (
|
| 17 |
+
load_poses,
|
| 18 |
+
get_vgg_input_imgs,
|
| 19 |
+
get_sorted_image_paths,
|
| 20 |
+
build_frame_selection,
|
| 21 |
+
load_images_rgb,
|
| 22 |
+
infer_vggt_and_reconstruct,
|
| 23 |
+
evaluate_scene_and_save,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# Import pose visualization libraries (optional EVO support)
|
| 27 |
+
try:
|
| 28 |
+
from evo.core.trajectory import PoseTrajectory3D
|
| 29 |
+
import evo.tools.plot as plot
|
| 30 |
+
|
| 31 |
+
EVO_AVAILABLE = True
|
| 32 |
+
except ImportError:
|
| 33 |
+
# EVO is optional; we have a matplotlib-based fallback
|
| 34 |
+
EVO_AVAILABLE = False
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def visualize_predicted_poses(
    all_cam_to_world_mat, frame_ids, output_scene_dir, scene_name="custom_dataset"
):
    """
    Visualize the predicted camera pose trajectory (no GT comparison required).

    Args:
        all_cam_to_world_mat: List of 4x4 camera-to-world transform matrices
        frame_ids: List of frame IDs (currently unused; kept for API compatibility)
        output_scene_dir: Output directory (pathlib.Path)
        scene_name: Scene name used in the plot title
    """
    # Provide basic pose visualization even without EVO
    if not EVO_AVAILABLE:
        print("⚠️ EVO not installed; using basic matplotlib visualization")

    try:
        # Convert to numpy array of shape (N, 4, 4)
        poses_est = np.array(all_cam_to_world_mat)

        if len(poses_est) < 2:
            print("⚠️ Not enough poses to generate trajectory plot")
            return

        print(f"🎨 Generating pose trajectory visualization...")

        # Extract translation part
        positions = poses_est[:, :3, 3]  # shape: (N, 3)

        # Create figure - show XZ-plane projection only
        fig, ax = plt.subplots(1, 1, figsize=(10, 8))
        try:
            # XZ-plane projection
            ax.plot(
                positions[:, 0],
                positions[:, 2],
                "b-",
                linewidth=2,
                label="Predicted Trajectory",
            )
            ax.scatter(
                positions[0, 0], positions[0, 2], color="green", s=100, label="Start"
            )
            ax.scatter(positions[-1, 0], positions[-1, 2], color="red", s=100, label="End")
            ax.set_xlabel("X (m)")
            ax.set_ylabel("Z (m)")
            ax.set_title(f"{scene_name} - XZ-plane projection")
            ax.legend()
            ax.grid(True, alpha=0.3)

            # Save image
            pose_plot_path = output_scene_dir / "predicted_trajectory.png"
            fig.savefig(pose_plot_path, dpi=300, bbox_inches="tight")
        finally:
            # Fix: always release this specific figure, even when plotting or
            # saving raises (the original only closed the *current* figure on
            # the success path, leaking figures on error).
            plt.close(fig)

        print(f"📊 Trajectory visualization saved: {pose_plot_path}")

    except Exception as e:
        print(f"⚠️ Failed to generate pose visualization: {e}")
        import traceback

        traceback.print_exc()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _write_ply(points_output_path, merged_point_cloud, merged_colors):
    """Write a point cloud (optionally colored) to an ASCII PLY file.

    Fix over the original inline writer: NaN/Inf points are dropped BEFORE
    the header is written, so the declared ``element vertex`` count always
    matches the number of vertex rows actually emitted (the original skipped
    invalid points after declaring the full count, producing invalid PLY).
    """
    valid = np.isfinite(merged_point_cloud).all(axis=1)
    merged_point_cloud = merged_point_cloud[valid]
    if merged_colors is not None:
        merged_colors = merged_colors[valid]

    with open(points_output_path, "w") as f:
        f.write("ply\n")
        f.write("format ascii 1.0\n")
        f.write(f"element vertex {len(merged_point_cloud)}\n")
        f.write("property float x\n")
        f.write("property float y\n")
        f.write("property float z\n")
        if merged_colors is not None:
            f.write("property uchar red\n")
            f.write("property uchar green\n")
            f.write("property uchar blue\n")
        f.write("end_header\n")
        if merged_colors is None:
            for point in merged_point_cloud:
                f.write(f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n")
        else:
            for point, color in zip(merged_point_cloud, merged_colors):
                r = int(np.clip(color[0], 0, 255))
                g = int(np.clip(color[1], 0, 255))
                b = int(np.clip(color[2], 0, 255))
                f.write(
                    f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f} {r} {g} {b}\n"
                )


def main():
    """
    Evaluation script for a Custom Dataset.
    Supports optional evaluation and custom dataset structure.
    """
    parser = argparse.ArgumentParser(
        description="Run FastVGGT evaluation on a Custom Dataset"
    )

    # Required: dataset path
    parser.add_argument(
        "--data_path",
        type=Path,
        required=True,
        help="Dataset path containing subfolders: color, depth, gt_ply, pose",
    )

    # Optional: enable evaluation
    parser.add_argument(
        "--enable_evaluation",
        action="store_true",
        help="Enable evaluation (requires pose and ply data)",
    )

    # Output path
    parser.add_argument(
        "--output_path",
        type=Path,
        default="./eval_results_custom",
        help="Output path for evaluation results",
    )

    # Model parameters
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
        help="Model checkpoint file path",
    )

    parser.add_argument("--merging", type=int, default=0, help="Merging parameter")

    # Processing parameters
    parser.add_argument(
        "--input_frame",
        type=int,
        default=200,
        help="Maximum number of frames to process per scene",
    )

    parser.add_argument(
        "--depth_conf_thresh",
        type=float,
        default=3.0,
        help="Depth confidence threshold to filter low-confidence depth values",
    )

    # Evaluation parameters (only used when evaluation is enabled)
    parser.add_argument(
        "--chamfer_max_dist",
        type=float,
        default=0.5,
        help="Maximum distance threshold used in Chamfer Distance computation",
    )

    parser.add_argument("--plot", action="store_true", help="Whether to generate plots")

    parser.add_argument(
        "--vis_attn_map",
        action="store_true",
        help="Visualize attention maps during inference",
    )

    args = parser.parse_args()
    torch.manual_seed(33)

    # Check data path exists
    if not args.data_path.exists():
        print(f"❌ Error: Data path does not exist: {args.data_path}")
        return

    # Check required subdirectories.
    # NOTE(review): the --data_path help text says "color" but the code reads
    # the "images" subfolder — confirm which layout is intended.
    color_dir = args.data_path / "images"
    pose_dir = args.data_path / "pose"

    if not color_dir.exists():
        print(f"❌ Error: color directory does not exist: {color_dir}")
        return

    print(f"📁 Dataset path: {args.data_path}")

    # If evaluation is enabled, check pose and gt_ply directories
    if args.enable_evaluation:
        if not pose_dir.exists():
            print(f"❌ Error: Evaluation requires pose directory: {pose_dir}")
            return

        gt_ply_dir = args.data_path / "gt_ply"
        if not gt_ply_dir.exists():
            print(f"❌ Error: Evaluation requires gt_ply directory: {gt_ply_dir}")
            return
        print(f"📊 Evaluation will use Ground Truth")
    else:
        print(f"🏃 Inference only, no evaluation")

    # Create output directory
    args.output_path.mkdir(parents=True, exist_ok=True)
    output_scene_dir = args.output_path / "custom_dataset"

    # Check if already processed
    if (output_scene_dir / "metrics.json").exists() and args.enable_evaluation:
        print(
            f"⚠️ Results already exist, skipping: {output_scene_dir / 'metrics.json'}"
        )
        return

    # Force use of bf16 dtype
    dtype = torch.bfloat16

    # Load VGGT model (strict=False: checkpoint may lack heads added by flags)
    print(f"🔄 Loading model: {args.ckpt_path}")
    model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    incompat = model.load_state_dict(ckpt, strict=False)
    model = model.cuda().eval()
    model = model.to(torch.bfloat16)
    print(f"✅ Model loaded")

    # Load scene data
    image_paths = get_sorted_image_paths(color_dir)
    if len(image_paths) == 0:
        print(f"❌ Error: No images found in {color_dir}")
        return

    print(f"🖼️ Found {len(image_paths)} images")

    # Process pose data (if evaluation is enabled)
    poses_gt = None
    first_gt_pose = None
    available_pose_frame_ids = None
    c2ws = None

    if args.enable_evaluation:
        poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_dir)
        if (
            poses_gt is None
            or first_gt_pose is None
            or available_pose_frame_ids is None
        ):
            print(f"❌ Error: Failed to load pose data")
            return
        print(f"📐 Loaded {len(poses_gt)} poses")

    # Frame selection
    if args.enable_evaluation and available_pose_frame_ids is not None:
        # Use pose data for frame selection
        selected_frame_ids, selected_image_paths, selected_pose_indices = (
            build_frame_selection(
                image_paths, available_pose_frame_ids, args.input_frame
            )
        )
        c2ws = poses_gt[selected_pose_indices]
        image_paths = selected_image_paths
    else:
        # Simply take the first N frames
        num_frames = min(len(image_paths), args.input_frame)
        selected_frame_ids = list(range(num_frames))
        image_paths = image_paths[:num_frames]

    print(f"📋 Selected {len(image_paths)} frames for processing")

    try:
        # Load images
        print(f"🔄 Loading images...")
        images = load_images_rgb(image_paths)

        if not images or len(images) < 3:
            print(f"❌ Error: Not enough valid images (need at least 3)")
            return

        frame_ids = selected_frame_ids
        images_array = np.stack(images)
        vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
        print(f"📐 Image patch dimensions: {patch_width}x{patch_height}")

        # Update attention layer patch dimensions in the model
        model.update_patch_dimensions(patch_width, patch_height)

        # Inference + Reconstruction
        print(f"🚀 Start inference and reconstruction...")
        (
            extrinsic_np,
            intrinsic_np,
            all_world_points,
            all_point_colors,
            all_cam_to_world_mat,
            inference_time_ms,
        ) = infer_vggt_and_reconstruct(
            model, vgg_input, dtype, args.depth_conf_thresh, image_paths
        )
        print(f"⏱️ Inference time: {inference_time_ms:.2f}ms")

        # Check results
        if not all_cam_to_world_mat or not all_world_points:
            print(f"❌ Error: Failed to obtain valid camera poses or point clouds")
            return

        # Evaluation and saving
        if args.enable_evaluation:
            print(f"📊 Start evaluation...")
            gt_ply_dir = args.data_path / "gt_ply"
            metrics = evaluate_scene_and_save(
                "custom_dataset",
                c2ws,
                first_gt_pose,
                frame_ids,
                all_cam_to_world_mat,
                all_world_points,
                output_scene_dir,
                gt_ply_dir,
                args.chamfer_max_dist,
                inference_time_ms,
                args.plot,
            )
            if metrics is not None:
                print("📈 Evaluation results:")
                for key, value in metrics.items():
                    if key in [
                        "chamfer_distance",
                        "ate",
                        "are",
                        "rpe_rot",
                        "rpe_trans",
                        "inference_time_ms",
                    ]:
                        print(f" {key}: {float(value):.4f}")

            # Also visualize predicted poses in evaluation branch
            if args.plot:
                visualize_predicted_poses(
                    all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset"
                )
        else:
            # Save reconstruction only, no evaluation
            print(f"💾 Saving reconstruction...")
            output_scene_dir.mkdir(parents=True, exist_ok=True)

            # Save camera poses
            poses_output_path = output_scene_dir / "estimated_poses.txt"
            with open(poses_output_path, "w") as f:
                for i, pose in enumerate(all_cam_to_world_mat):
                    f.write(f"# Frame {frame_ids[i]}\n")
                    for row in pose:
                        f.write(" ".join(map(str, row)) + "\n")
                    f.write("\n")

            # Save point cloud
            if all_world_points:
                points_output_path = output_scene_dir / "reconstructed_points.ply"

                # Merge all frames' point clouds and colors
                try:
                    merged_point_cloud = np.vstack(all_world_points)
                    merged_colors = (
                        np.vstack(all_point_colors).astype(np.uint8)
                        if all_point_colors is not None and len(all_point_colors) > 0
                        else None
                    )
                    print(
                        f"📊 Merged point clouds: {len(all_world_points)} frames, total {len(merged_point_cloud)} points"
                    )

                    # If too many points, randomly sample 100000 points
                    max_points = 100000
                    if len(merged_point_cloud) > max_points:
                        print(
                            f"🔽 Too many points, randomly sampling {max_points} points..."
                        )
                        indices = np.random.choice(
                            len(merged_point_cloud), size=max_points, replace=False
                        )
                        merged_point_cloud = merged_point_cloud[indices]
                        if merged_colors is not None:
                            merged_colors = merged_colors[indices]
                        print(
                            f"✅ Sampling done, kept {len(merged_point_cloud)} points"
                        )

                    # Fix: invalid points are filtered inside _write_ply BEFORE
                    # the header is written, keeping the vertex count consistent.
                    _write_ply(points_output_path, merged_point_cloud, merged_colors)

                    print(f"💾 Point cloud saved to: {points_output_path}")

                except Exception as e:
                    print(f"⚠️ Error saving point cloud: {e}")
                    # If merge fails, try to log per-frame info
                    print(f"🔍 Point cloud debug info:")
                    for i, frame_points in enumerate(all_world_points):
                        print(
                            f" Frame {i}: {frame_points.shape if hasattr(frame_points, 'shape') else type(frame_points)}"
                        )
                        if (
                            hasattr(frame_points, "shape")
                            and len(frame_points.shape) >= 2
                        ):
                            print(
                                f" Shape: {frame_points.shape}, Dtype: {frame_points.dtype}"
                            )
                            if frame_points.shape[0] > 0:
                                print(
                                    f" Range: x[{np.min(frame_points[:, 0]):.3f}, {np.max(frame_points[:, 0]):.3f}] "
                                    f"y[{np.min(frame_points[:, 1]):.3f}, {np.max(frame_points[:, 1]):.3f}] "
                                    f"z[{np.min(frame_points[:, 2]):.3f}, {np.max(frame_points[:, 2]):.3f}]"
                                )

            print(f"📁 Results saved to: {output_scene_dir}")

            # Visualize predicted pose trajectory
            if args.plot:
                visualize_predicted_poses(
                    all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset"
                )

            print(f"🎉 Done!")

    except Exception as e:
        print(f"❌ Error occurred during processing: {e}")
        import traceback

        traceback.print_exc()
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
if __name__ == "__main__":
|
| 467 |
+
main()
|
FastVGGT/eval/eval_scannet.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
# Ensure project root is in sys.path for absolute imports like `vggt.*`
|
| 9 |
+
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))  # repo root: parent of this eval/ directory
if ROOT_DIR not in sys.path:
    # Prepend so the in-repo `vggt` package shadows any installed copy.
    sys.path.insert(0, ROOT_DIR)
|
| 12 |
+
|
| 13 |
+
from vggt.models.vggt import VGGT
|
| 14 |
+
from vggt.utils.eval_utils import (
|
| 15 |
+
load_poses,
|
| 16 |
+
get_vgg_input_imgs,
|
| 17 |
+
get_sorted_image_paths,
|
| 18 |
+
get_all_scenes,
|
| 19 |
+
build_frame_selection,
|
| 20 |
+
load_images_rgb,
|
| 21 |
+
infer_vggt_and_reconstruct,
|
| 22 |
+
evaluate_scene_and_save,
|
| 23 |
+
compute_average_metrics_and_save,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` is a known pitfall: ``bool("False")`` is True,
        so any non-empty string (including "False") used to enable plotting.
        Treat the common false spellings as False instead.
        """
        return str(value).strip().lower() not in ("false", "0", "no", "off", "")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", type=Path, default="/data/scannetv2/process_scannet"
    )
    parser.add_argument(
        "--gt_ply_dir",
        type=Path,
        default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
    )
    parser.add_argument("--output_path", type=Path, default="./eval_results")
    parser.add_argument("--merging", type=int, default=None)
    # Was `type=bool`, which parsed "--plot False" as True; see _str2bool.
    parser.add_argument("--plot", type=_str2bool, default=True)
    parser.add_argument(
        "--depth_conf_thresh",
        type=float,
        default=3.0,
        help="Depth confidence threshold for filtering low confidence depth values",
    )
    parser.add_argument(
        "--chamfer_max_dist",
        type=float,
        default=0.5,
        help="Maximum distance threshold in Chamfer Distance computation, distances exceeding this value will be clipped",
    )
    parser.add_argument(
        "--input_frame",
        type=int,
        default=200,
        help="Maximum number of frames selected for processing per scene",
    )
    parser.add_argument(
        "--num_scenes",
        type=int,
        default=50,
        help="Maximum number of scenes to evaluate",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="./ckpt/model_tracker_fixed_e20.pt",
        help="Path to the model checkpoint file",
    )
    parser.add_argument(
        "--vis_attn_map",
        action="store_true",
        help="Whether to visualize attention maps during inference",
    )
    args = parser.parse_args()
    torch.manual_seed(33)

    # Scene sampling (get_all_scenes caps the list at args.num_scenes).
    scannet_scenes = get_all_scenes(args.data_dir, args.num_scenes)
    print(f"Evaluate {len(scannet_scenes)} scenes")

    all_scenes_metrics = {"scenes": {}, "average": {}}
    # Force use of bf16 data type for inference.
    dtype = torch.bfloat16
    # Load VGGT model; strict=False tolerates checkpoint/model key mismatches
    # (the incompatible-keys report is kept for debugging).
    model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    incompat = model.load_state_dict(ckpt, strict=False)
    model = model.cuda().eval()
    model = model.to(torch.bfloat16)

    # Process each scene independently; failures are logged and skipped.
    for scene in scannet_scenes:
        scene_dir = args.data_dir / f"{scene}"
        output_scene_dir = args.output_path / f"input_frame_{args.input_frame}" / scene
        # Resume support: skip scenes that already have saved metrics.
        if (output_scene_dir / "metrics.json").exists():
            continue

        # Load scene data: color frames and per-frame GT camera poses.
        images_dir = scene_dir / "color"
        pose_path = scene_dir / "pose"
        image_paths = get_sorted_image_paths(images_dir)
        poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_path)
        if (
            poses_gt is None
            or first_gt_pose is None
            or available_pose_frame_ids is None
        ):
            print(f"Skipping scene {scene}: no pose data")
            continue

        # Frame filtering: keep at most args.input_frame frames that have poses.
        selected_frame_ids, selected_image_paths, selected_pose_indices = (
            build_frame_selection(
                image_paths, available_pose_frame_ids, args.input_frame
            )
        )

        # Get GT cam-to-world poses corresponding to the selected frames.
        c2ws = poses_gt[selected_pose_indices]
        image_paths = selected_image_paths

        if len(image_paths) == 0:
            print(f"No images found in {images_dir}")
            continue

        print("🚩Processing", scene, f"Found {len(image_paths)} images")
        all_cam_to_world_mat = []
        all_world_points = []

        try:
            # Load images as RGB arrays.
            images = load_images_rgb(image_paths)

            # Need at least 3 valid frames for a meaningful reconstruction.
            if not images or len(images) < 3:
                print(f"Skipping {scene}: insufficient valid images")
                continue

            frame_ids = selected_frame_ids
            images_array = np.stack(images)
            vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
            print(f"Patch dimensions: {patch_width}x{patch_height}")

            # Update model attention layers with dynamic patch dimensions.
            model.update_patch_dimensions(patch_width, patch_height)

            # Inference + reconstruction (depth filtered at depth_conf_thresh).
            (
                extrinsic_np,
                intrinsic_np,
                all_world_points,
                all_point_colors,
                all_cam_to_world_mat,
                inference_time_ms,
            ) = infer_vggt_and_reconstruct(
                model, vgg_input, dtype, args.depth_conf_thresh, image_paths
            )
            print(f"Inference time: {inference_time_ms:.2f}ms")

            # Process results: both poses and points are required downstream.
            if not all_cam_to_world_mat or not all_world_points:
                print(
                    f"Skipping {scene}: failed to obtain valid camera poses or point clouds"
                )
                continue

            # Evaluate against GT and save per-scene artifacts/metrics.
            metrics = evaluate_scene_and_save(
                scene,
                c2ws,
                first_gt_pose,
                frame_ids,
                all_cam_to_world_mat,
                all_world_points,
                output_scene_dir,
                args.gt_ply_dir,
                args.chamfer_max_dist,
                inference_time_ms,
                args.plot,
            )
            if metrics is not None:
                # Keep only the scalar summary metrics for aggregation.
                all_scenes_metrics["scenes"][scene] = {
                    key: float(value)
                    for key, value in metrics.items()
                    if key
                    in [
                        "chamfer_distance",
                        "ate",
                        "are",
                        "rpe_rot",
                        "rpe_trans",
                        "inference_time_ms",
                    ]
                }
                print("Complete metrics", all_scenes_metrics["scenes"][scene])

        except Exception as e:
            # Best-effort batch evaluation: log the failure and move on.
            print(f"Error processing scene {scene}: {e}")
            import traceback

            traceback.print_exc()

    # Summarize average metrics across scenes and save.
    compute_average_metrics_and_save(
        all_scenes_metrics,
        args.output_path,
        args.input_frame,
    )
|
FastVGGT/eval/utils.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.spatial import cKDTree as KDTree
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """Back-project a depth map into camera-frame 3D coordinates.

    Args:
        depthmap: (H, W) array of depths along the camera z axis.
        camera_intrinsics: 3x3 pinhole intrinsics matrix; skew must be zero.
        pseudo_focal: optional (H, W) per-pixel focal length that overrides
            both focal entries of the intrinsics.

    Returns:
        A (H, W, 3) float32 point map in camera coordinates and a boolean
        (H, W) mask marking pixels with strictly positive depth.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # The pinhole back-projection below does not model skew.
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0
    if pseudo_focal is None:
        focal_x = camera_intrinsics[0, 0]
        focal_y = camera_intrinsics[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        focal_x = focal_y = pseudo_focal
    center_x = camera_intrinsics[0, 2]
    center_y = camera_intrinsics[1, 2]

    # Pixel-coordinate grids: cols vary along x, rows along y.
    cols, rows = np.meshgrid(np.arange(W), np.arange(H))
    x_cam = (cols - center_x) * depthmap / focal_x
    y_cam = (rows - center_y) * depthmap / focal_y
    points = np.stack((x_cam, y_cam, depthmap), axis=-1).astype(np.float32)

    return points, depthmap > 0.0
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def depthmap_to_absolute_camera_coordinates(
    depthmap, camera_intrinsics, camera_pose, **kw
):
    """Back-project a depth map into world coordinates.

    Args:
        depthmap: (H, W) depth array.
        camera_intrinsics: 3x3 pinhole intrinsics matrix.
        camera_pose: 4x3 or 4x4 cam-to-world matrix, or None to stay in the
            camera frame.

    Returns:
        A (H, W, 3) point map in world coordinates (camera coordinates when
        ``camera_pose`` is None) and a boolean validity mask.
    """
    cam_points, valid_mask = depthmap_to_camera_coordinates(
        depthmap, camera_intrinsics
    )

    if camera_pose is None:
        return cam_points, valid_mask

    rotation = camera_pose[:3, :3]
    translation = camera_pose[:3, 3]
    # Rotate every pixel's camera-frame point, then translate into the world.
    world_points = (
        np.einsum("ik, vuk -> vui", rotation, cam_points)
        + translation[None, None, :]
    )
    return world_points, valid_mask
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def completion_ratio(gt_points, rec_points, dist_th=0.05):
|
| 64 |
+
gen_points_kd_tree = KDTree(rec_points)
|
| 65 |
+
distances, _ = gen_points_kd_tree.query(gt_points)
|
| 66 |
+
comp_ratio = np.mean((distances < dist_th).astype(np.float32))
|
| 67 |
+
return comp_ratio
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def accuracy(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """Accuracy: distance from each reconstructed point to its nearest GT point.

    Returns (mean, median) of those distances. When both normal arrays are
    provided, additionally returns the mean and median absolute cosine
    between each reconstructed normal and the normal of its nearest GT point.
    """
    nn_dists, nn_idx = KDTree(gt_points).query(rec_points, workers=-1)
    mean_dist = np.mean(nn_dists)
    median_dist = np.median(nn_dists)

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    # |cos| between matched normals: orientation-agnostic alignment score.
    cosines = np.abs(np.sum(gt_normals[nn_idx] * rec_normals, axis=-1))
    return mean_dist, median_dist, np.mean(cosines), np.median(cosines)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def completion(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """Completion: distance from each GT point to its nearest reconstructed point.

    Returns (mean, median) of those distances. When both normal arrays are
    provided, additionally returns the mean and median absolute cosine
    between each GT normal and the normal of its nearest reconstructed point.
    """
    # Note: the tree is built over the reconstruction and queried with GT.
    nn_dists, nn_idx = KDTree(rec_points).query(gt_points, workers=-1)
    mean_dist = np.mean(nn_dists)
    median_dist = np.median(nn_dists)

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    cosines = np.abs(np.sum(gt_normals * rec_normals[nn_idx], axis=-1))
    return mean_dist, median_dist, np.mean(cosines), np.median(cosines)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def compute_iou(pred_vox, target_vox):
    """Intersection-over-union between two occupied-voxel sets.

    Args:
        pred_vox / target_vox: voxel grids exposing ``get_voxels()`` whose
            items carry a ``grid_index`` attribute (open3d VoxelGrid-style
            objects — assumed from usage; confirm against callers).

    Returns:
        IoU of the two occupied index sets as a float. Returns 0.0 when both
        grids are empty (the previous version raised ZeroDivisionError).
    """
    # Build occupied-cell index sets; rounding guards against non-integer
    # grid indices comparing unequal due to float noise.
    pred_filled = set(
        tuple(np.round(v.grid_index, 4)) for v in pred_vox.get_voxels()
    )
    target_filled = set(
        tuple(np.round(v.grid_index, 4)) for v in target_vox.get_voxels()
    )

    union = pred_filled | target_filled
    if not union:
        # Both grids empty: IoU is undefined; report 0.0 rather than crash.
        return 0.0
    intersection = pred_filled & target_filled
    return len(intersection) / len(union)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def colmap_to_opencv_intrinsics(K):
    """Convert COLMAP-convention intrinsics to OpenCV convention.

    The center of the top-left pixel is (0.5, 0.5) in COLMAP but (0, 0) in
    OpenCV, so the principal point shifts by half a pixel. The input matrix
    is left unmodified; a shifted copy is returned.
    """
    shifted = K.copy()
    shifted[:2, 2] -= 0.5
    return shifted
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def opencv_to_colmap_intrinsics(K):
    """Convert OpenCV-convention intrinsics to COLMAP convention.

    Inverse of :func:`colmap_to_opencv_intrinsics`: shifts the principal
    point by +0.5 pixel in both axes. The input matrix is left unmodified;
    a shifted copy is returned.
    """
    shifted = K.copy()
    shifted[:2, 2] += 0.5
    return shifted
|
FastVGGT/merging/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Expose the token-merging implementation as `merging.merge`.
from . import merge

__all__ = ["merge"]
|
FastVGGT/merging/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (187 Bytes). View file
|
|
|
FastVGGT/merging/__pycache__/merge.cpython-310.pyc
ADDED
|
Binary file (7.54 kB). View file
|
|
|
FastVGGT/merging/merge.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Tuple, Callable, Optional, Union
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@torch.jit.script
def fast_similarity_chunks(
    a: torch.Tensor, b_transposed: torch.Tensor, chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """For each row of `a`, find the max dot-product score against `b`.

    Args:
        a: [B, num_src, C] source features.
        b_transposed: [B, C, num_dst] destination features, pre-transposed
            for bmm.
        chunk_size: number of src rows scored per bmm, bounding peak memory.

    Returns:
        (node_max, node_idx): per-src best score (in `a`'s dtype) and the
        index of the best-matching dst column, both [B, num_src].
    """

    B, num_src, C = a.shape
    original_dtype = a.dtype

    # Convert to bf16 for computation to improve performance and reduce memory usage
    a_bf16 = a.to(torch.bfloat16)
    b_transposed_bf16 = b_transposed.to(torch.bfloat16)
    # Output buffers filled chunk-by-chunk below.
    node_max = torch.empty(B, num_src, device=a.device, dtype=original_dtype)
    node_idx = torch.empty(B, num_src, device=a.device, dtype=torch.long)

    # Process in chunks so only a [B, chunk, num_dst] score tile is live at once.
    for i in range(0, num_src, chunk_size):
        end_i = min(i + chunk_size, num_src)
        a_chunk = a_bf16[:, i:end_i, :]  # [B, chunk_size, C]
        scores_chunk = torch.bmm(a_chunk, b_transposed_bf16)
        chunk_max_bf16, chunk_idx = torch.max(scores_chunk, dim=2)
        # Cast the scores back so callers see the input dtype.
        chunk_max = chunk_max_bf16.to(original_dtype)
        node_max[:, i:end_i] = chunk_max
        node_idx[:, i:end_i] = chunk_idx
    return node_max, node_idx
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def do_nothing(
    x: torch.Tensor,
    extra_tensors=None,
    extra_tensors_2=None,
) -> Union[
    torch.Tensor,
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
    """Identity fallback returned when no merging is to be performed.

    Mirrors the merge/unmerge call shapes: echoes back `x` together with
    whichever extra tensors were supplied. Note the original contract:
    when ``extra_tensors`` is absent, ``extra_tensors_2`` is ignored and
    only `x` is returned.
    """
    if extra_tensors is None:
        return x
    if extra_tensors_2 is None:
        return x, extra_tensors
    return x, extra_tensors, extra_tensors_2
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def token_merge_bipartite2d(
    metric: torch.Tensor,
    w: int,
    h: int,
    sx: int,
    sy: int,
    r: int,
    no_rand: bool = False,
    generator: Optional[torch.Generator] = None,
    enable_protection: bool = False,
) -> Tuple[Callable, Callable]:
    """
    Divide tokens into source (src) and destination (dst) groups, and merge r tokens from src to dst.
    dst tokens are selected by randomly choosing one token from each (sx, sy) region.
    Optionally protect a subset of tokens from merging (evenly spaced ~10% of
    all token indices — despite the original comment, selection is by stride,
    not by an importance score).

    Token layout assumption: the sequence is `num_imgs` images of
    (w*h + 5) tokens each — 5 leading special tokens then a w*h patch grid.

    Args:
        - metric [B, N, C]: Tensor for similarity computation, B=batch size, N=token count, C=feature dimension
        - w: Image width in tokens
        - h: Image height in tokens
        - sx: dst stride in x dimension, must divide w evenly
        - sy: dst stride in y dimension, must divide h evenly
        - r: Number of tokens to remove through merging
        - no_rand: If True, disable randomness (use only top-left token)
        - generator: Random number generator if no_rand is False and not None
        - enable_protection: If True, enable importance protection feature

    Returns:
        - (merge, unmerge): Two functions for merging tokens and restoring pre-merge state

    Changes vs. the original implementation (behavior-preserving):
        - removed two dead `torch.empty` pre-allocations that were
          immediately overwritten by `fast_similarity_chunks`;
        - removed the unused `a_idx_orig`/`b_idx_orig` aliases and the dead
          trailing `r = r_actual` rebind (closures only read src/unm/dst idx).
    """
    B, N, _ = metric.shape  # Batch size B, total tokens N
    if r <= 0:
        # Nothing to merge: return identity callables.
        return do_nothing, do_nothing

    gather = torch.gather

    tokens_per_img = w * h + 5
    num_imgs = N // tokens_per_img
    assert tokens_per_img * num_imgs == N, "Token count doesn't match (w*h+5)*num_imgs"

    with torch.no_grad():
        # Select protected token indices (evenly spaced) when protection is on.
        if enable_protection:
            num_protected = int(N * 0.1)
            step = max(1, N // num_protected)
            protected_indices = torch.arange(0, N, step, device=metric.device)[
                :num_protected
            ]
        else:
            protected_indices = None
            num_protected = 0

        # Global idx_buffer_seq of length N; -1 indicates dst, 0 indicates src.
        idx_buffer_seq = torch.zeros(N, device=metric.device, dtype=torch.int64)
        hsy, wsx = h // sy, w // sx  # Number of (sy, sx) blocks per image

        # Mark first image entirely as dst.
        if num_imgs > 0:
            idx_buffer_seq[:tokens_per_img] = -1

        # Process other images - fully vectorized batch operations.
        if num_imgs > 1:
            # The 5 special tokens of every image are always dst.
            cls_indices = (
                torch.arange(1, num_imgs, device=metric.device) * tokens_per_img
            )
            cls_indices = cls_indices[:, None] + torch.arange(5, device=metric.device)
            idx_buffer_seq[cls_indices.flatten()] = -1
            effective_h = min(hsy * sy, h)
            effective_w = min(wsx * sx, w)
            effective_grid_size = effective_h * effective_w

            if no_rand:
                # Deterministic: no grid token becomes dst in the other images.
                base_pattern = torch.zeros(
                    effective_grid_size, device=metric.device, dtype=torch.int64
                )
                grid_starts = (
                    torch.arange(1, num_imgs, device=metric.device) * tokens_per_img + 5
                )
                grid_indices = grid_starts[:, None] + torch.arange(
                    effective_grid_size, device=metric.device
                )
                idx_buffer_seq[grid_indices.flatten()] = base_pattern.repeat(
                    num_imgs - 1
                )
            else:
                # One random dst cell per (sy, sx) block of each other image.
                total_other_imgs = num_imgs - 1
                all_rand_idx = torch.randint(
                    sy * sx,
                    size=(total_other_imgs, hsy, wsx),
                    device=metric.device,
                    generator=generator,
                )

                scatter_src = -torch.ones(
                    total_other_imgs, hsy, wsx, device=metric.device, dtype=torch.int64
                )

                idx_buffer_batch = torch.zeros(
                    total_other_imgs,
                    hsy,
                    wsx,
                    sy * sx,
                    device=metric.device,
                    dtype=torch.int64,
                )
                idx_buffer_batch.scatter_(
                    dim=3,
                    index=all_rand_idx.unsqueeze(-1),
                    src=scatter_src.unsqueeze(-1),
                )

                # Un-block: (img, block_y, block_x, sy, sx) -> (img, H, W).
                idx_buffer_batch = (
                    idx_buffer_batch.view(total_other_imgs, hsy, wsx, sy, sx)
                    .transpose(2, 3)
                    .reshape(total_other_imgs, hsy * sy, wsx * sx)
                )

                # Batch fill to target positions - small per-image loop only.
                for i in range(total_other_imgs):
                    img_idx = i + 1
                    grid_start = img_idx * tokens_per_img + 5
                    flat_view = idx_buffer_batch[
                        i, :effective_h, :effective_w
                    ].flatten()
                    idx_buffer_seq[grid_start : grid_start + effective_grid_size] = (
                        flat_view
                    )

        # Stable argsort groups all dst (-1) positions before src (0).
        rand_idx = idx_buffer_seq.reshape(1, -1, 1).argsort(dim=1)
        num_dst_orig = int((idx_buffer_seq == -1).sum())

        # src indices (a_idx) and dst indices (b_idx) into the token sequence.
        a_idx = rand_idx[:, num_dst_orig:, :]
        b_idx = rand_idx[:, :num_dst_orig, :]

        if enable_protection:
            protected_idx = protected_indices.unsqueeze(0).unsqueeze(-1)
            num_protected_actual = protected_idx.shape[1]
        else:
            protected_idx = None
            num_protected_actual = 0

        num_src = a_idx.shape[1]
        num_dst = b_idx.shape[1]

        # Separate src, dst (and protected) token rows of a [B, N, C] tensor.
        def split(x):
            C = x.shape[-1]

            if enable_protection:
                src = gather(x, dim=1, index=a_idx.expand(B, num_src, C))
                dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
                protected = gather(
                    x, dim=1, index=protected_idx.expand(B, num_protected_actual, C)
                )
                return src, dst, protected
            else:
                src = gather(x, dim=1, index=a_idx.expand(B, num_src, C))
                dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
                return src, dst

        # Compute cosine similarity (normalize first then dot product).
        metric = metric / metric.norm(dim=-1, keepdim=True)
        if enable_protection:
            a, b, protected = split(metric)
        else:
            a, b = split(metric)

        r = min(a.shape[1], r)
        num_src_actual = a.shape[1]
        chunk_size = min(5000, num_src_actual)

        # Best-matching dst for every src, computed chunked in bf16.
        b_transposed = b.transpose(-1, -2)
        node_max, node_idx = fast_similarity_chunks(a, b_transposed, chunk_size)
        # Rank src tokens by best similarity, most similar first.
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        # If protection is enabled, drop protected tokens from the merge set.
        # NOTE: this path uses only batch element 0's ordering for all of B.
        if enable_protection:
            src_indices = a_idx[0, :, 0]
            protected_mask_src = torch.isin(src_indices, protected_indices)
            edge_flat = edge_idx[0, :, 0]
            valid_mask = ~protected_mask_src[edge_flat]
            valid_edges = edge_flat[valid_mask]

            valid_count = valid_edges.shape[0]
            r_actual = min(r, valid_count)

            unm_idx = valid_edges[r_actual:].unsqueeze(0).unsqueeze(-1)
            src_idx = valid_edges[:r_actual].unsqueeze(0).unsqueeze(-1)
        else:
            unm_idx = edge_idx[..., r:, :]
            src_idx = edge_idx[..., :r, :]
            r_actual = r

        # dst token index targeted by each src token selected for merging.
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    # Merge selected src tokens into their dst tokens; output layout is
    # [unmerged src | dst | (protected)].
    def merge(
        x: torch.Tensor,
        mode: str = "mean",
        extra_tensors=None,
        extra_tensors_2=None,
    ) -> Union[
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        if enable_protection:
            src, dst, protected = split(x)
        else:
            src, dst = split(x)

        n, _, c = src.shape

        # Extract unmerged src tokens - using actual unm_idx size.
        unm_len = unm_idx.shape[1]
        unm = gather(src, dim=-2, index=unm_idx.expand(n, unm_len, c))
        src_len = src_idx.shape[1]
        src = gather(src, dim=-2, index=src_idx.expand(n, src_len, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, src_len, c), src, reduce=mode)

        # ---------------- Extra tensor processing ----------------
        merged_extra_1 = None
        merged_extra_2 = None
        if extra_tensors is not None:
            E_dim = extra_tensors.shape[-1]
            if enable_protection:
                src_e, dst_e, protected_e = split(extra_tensors)
            else:
                src_e, dst_e = split(extra_tensors)

            # Consistent with main tensor, only select r src tokens to be merged.
            src_e_r = gather(src_e, dim=-2, index=src_idx.expand(n, src_len, E_dim))
            unm_e = gather(src_e, dim=-2, index=unm_idx.expand(n, unm_len, E_dim))

            dst_e = dst_e.scatter_reduce(
                -2, dst_idx.expand(n, src_len, E_dim), src_e_r, reduce=mode
            )
            if enable_protection:
                merged_extra_1 = torch.cat([unm_e, dst_e, protected_e], dim=1)
            else:
                merged_extra_1 = torch.cat([unm_e, dst_e], dim=1)

        if extra_tensors_2 is not None:
            E_dim_2 = extra_tensors_2.shape[-1]
            if enable_protection:
                src_e2, dst_e2, protected_e2 = split(extra_tensors_2)
            else:
                src_e2, dst_e2 = split(extra_tensors_2)

            src_e2_r = gather(src_e2, dim=-2, index=src_idx.expand(n, src_len, E_dim_2))
            unm_e2 = gather(src_e2, dim=-2, index=unm_idx.expand(n, unm_len, E_dim_2))

            dst_e2 = dst_e2.scatter_reduce(
                -2, dst_idx.expand(n, src_len, E_dim_2), src_e2_r, reduce=mode
            )
            if enable_protection:
                merged_extra_2 = torch.cat([unm_e2, dst_e2, protected_e2], dim=1)
            else:
                merged_extra_2 = torch.cat([unm_e2, dst_e2], dim=1)

        if enable_protection:
            main_result = torch.cat([unm, dst, protected], dim=1)
        else:
            main_result = torch.cat([unm, dst], dim=1)

        if merged_extra_1 is not None and merged_extra_2 is not None:
            return main_result, merged_extra_1, merged_extra_2
        elif merged_extra_1 is not None:
            return main_result, merged_extra_1
        else:
            return main_result

    # Restore the pre-merge token layout (merged src tokens are approximated
    # by a copy of the dst token they were merged into).
    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        dst_len = num_dst
        src_len = src_idx.shape[1]
        unm = x[..., :unm_len, :]
        dst = x[..., unm_len : unm_len + dst_len, :]

        if enable_protection:
            protected = x[
                ..., unm_len + dst_len : unm_len + dst_len + num_protected_actual, :
            ]

        _, _, c = unm.shape
        src = gather(dst, dim=-2, index=dst_idx.expand(B, src_len, c))
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        # Scatter dst, unmerged src, reconstructed src (and protected) back to
        # their original sequence positions.
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(
            dim=-2,
            index=gather(
                a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx
            ).expand(B, unm_len, c),
            src=unm,
        )

        out.scatter_(
            dim=-2,
            index=gather(
                a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx
            ).expand(B, src_len, c),
            src=src,
        )

        if enable_protection:
            out.scatter_(
                dim=-2,
                index=protected_idx.expand(B, num_protected_actual, c),
                src=protected,
            )

        return out

    return merge, unmerge
|
FastVGGT/requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.3.1
|
| 2 |
+
torchvision==0.18.1
|
| 3 |
+
numpy==1.26.1
|
| 4 |
+
Pillow
|
| 5 |
+
huggingface_hub
|
| 6 |
+
einops
|
| 7 |
+
safetensors
|
| 8 |
+
evo
|
| 9 |
+
open3d
|
| 10 |
+
matplotlib
|
| 11 |
+
scipy
|
| 12 |
+
opencv-python
|
| 13 |
+
scikit-image
|
| 14 |
+
tqdm
|
| 15 |
+
|
FastVGGT/vggt/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (132 Bytes). View file
|
|
|
FastVGGT/vggt/dependency/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (143 Bytes). View file
|
|
|
FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc
ADDED
|
Binary file (1.39 kB). View file
|
|
|
FastVGGT/vggt/dependency/distortion.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def apply_distortion(points, distortion_params):
    """
    Apply lens distortion to normalized camera coordinates.

    Args:
        points: Array of normalized camera coordinates.
        distortion_params: Distortion parameters (currently unused).

    Returns:
        The coordinates unchanged: distortion is not implemented yet, so this
        function is a deliberate identity placeholder.
    """
    # Intentional no-op — plug a real distortion model in here if it is ever needed.
    distorted = points
    return distorted
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def iterative_undistortion(points, distortion_params, max_iter=10):
    """
    Remove distortion from normalized camera coordinates via an iterative solve.

    Args:
        points: Array of distorted normalized camera coordinates.
        distortion_params: Distortion parameters (currently unused).
        max_iter: Maximum number of iterations (currently unused).

    Returns:
        The coordinates unchanged: undistortion is not implemented yet, so this
        function is a deliberate identity placeholder.
    """
    # Intentional no-op — the fixed-point undistortion loop would go here.
    undistorted = points
    return undistorted
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def single_undistortion(points, distortion_params):
    """
    Remove distortion from normalized camera coordinates in a single step.

    Args:
        points: Array of distorted normalized camera coordinates.
        distortion_params: Distortion parameters (currently unused).

    Returns:
        The coordinates unchanged: undistortion is not implemented yet, so this
        function is a deliberate identity placeholder.
    """
    # Intentional no-op — a one-shot (non-iterative) undistortion would go here.
    undistorted = points
    return undistorted
|
FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc
ADDED
|
Binary file (4.24 kB). View file
|
|
|
FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc
ADDED
|
Binary file (3.1 kB). View file
|
|
|
FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc
ADDED
|
Binary file (3.41 kB). View file
|
|
|
FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (3.18 kB). View file
|
|
|
FastVGGT/vggt/heads/camera_head.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from vggt.layers import Mlp
|
| 15 |
+
from vggt.layers.block import Block
|
| 16 |
+
from vggt.heads.head_act import activate_pose
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CameraHead(nn.Module):
    """
    CameraHead predicts camera parameters from token representations using iterative refinement.

    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            # presumably 9 = 3 (absT) + 4 (quaR) + 2 (FoV) — confirm against the pose-encoding utils.
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token (used as the iteration-0 pose estimate).
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens: the first entry along the token axis of each frame.
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape  # S is expected to be 1.
        pred_pose_enc = None
        pred_pose_enc_list = []

        for _ in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation, with a residual connection
            # back to the un-modulated tokens.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            pose_tokens_modulated = self.trunk(pose_tokens_modulated)
            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            # Accumulate deltas across iterations (first iteration sets the base).
            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
            )
            pred_pose_enc_list.append(activated_pose)

        return pred_pose_enc_list
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Apply affine (scale-and-shift) modulation: ``x * (1 + scale) + shift``.
    """
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    scaled = x * (scale + 1)
    return scaled + shift
|
FastVGGT/vggt/heads/dpt_head.py
ADDED
|
@@ -0,0 +1,598 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
from typing import List, Dict, Tuple, Union, Optional
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from .head_act import activate_head
|
| 18 |
+
from .utils import create_uv_grid, position_grid_to_embed
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DPTHead(nn.Module):
    """
    DPT Head for dense prediction tasks.

    This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
    backbone and produces dense predictions by fusing multi-scale features.

    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        # NOTE(review): mutable default lists are shared across instances;
        # harmless here because they are only read, never mutated.
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [0, 1, 2, 3],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super(DPTHead, self).__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx

        # Shared LayerNorm applied to the tokens of every selected layer.
        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers for upsampling feature maps.
        # Relative to the patch grid: 4x up, 2x up, identity, 2x down.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0],
                    out_channels=out_channels[0],
                    kernel_size=4,
                    stride=4,
                    padding=0,
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1],
                    out_channels=out_channels[1],
                    kernel_size=2,
                    stride=2,
                    padding=0,
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3],
                    out_channels=out_channels[3],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            ]
        )

        self.scratch = _make_scratch(out_channels, features, expand=False)

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = nn.Identity()  # Use Identity instead of None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        # Deepest fusion block receives no lateral (residual) input.
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            # Feature mode: keep channel count, no prediction head is built.
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
            )
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1,
                head_features_1 // 2,
                kernel_size=3,
                stride=1,
                padding=1,
            )
            conv2_in_channels = head_features_1 // 2

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(
                    conv2_in_channels,
                    head_features_2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.ReLU(inplace=True),
                nn.Conv2d(
                    head_features_2, output_dim, kernel_size=1, stride=1, padding=0
                ),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.
        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or larger than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: Feature maps with shape [B, S, C, H, W]
                - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
        """
        B, S, _, H, W = images.shape

        # If frames_chunk_size is not specified or greater than S, process all frames at once
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Otherwise, process frames in chunks to manage memory usage
        assert frames_chunk_size > 0

        # Process frames in batches
        all_preds = []
        all_conf = []

        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)

            # Process batch of frames
            if self.feature_only:
                chunk_output = self._forward_impl(
                    aggregated_tokens_list,
                    images,
                    patch_start_idx,
                    frames_start_idx,
                    frames_end_idx,
                )
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = self._forward_impl(
                    aggregated_tokens_list,
                    images,
                    patch_start_idx,
                    frames_start_idx,
                    frames_end_idx,
                )
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)

        # Concatenate results along the sequence dimension
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        else:
            return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: Optional[int] = None,
        frames_end_idx: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Implementation of the forward pass through the DPT head.

        This method processes a specific chunk of frames from the sequence.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index for frames to process.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        if frames_start_idx is not None and frames_end_idx is not None:
            images = images[:, frames_start_idx:frames_end_idx].contiguous()

        B, S, _, H, W = images.shape

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        out = []
        dpt_idx = 0

        for layer_idx in self.intermediate_layer_idx:
            # Drop the non-patch tokens (e.g. camera/register) before reshaping to a grid.
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            # Select frames if processing a chunk
            if frames_start_idx is not None and frames_end_idx is not None:
                x = x[:, frames_start_idx:frames_end_idx]

            # Fold batch and sequence dims so each frame becomes one sample.
            x = x.reshape(B * S, -1, x.shape[-1])

            x = self.norm(x)

            # Tokens -> [B*S, C, patch_h, patch_w] feature map.
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)
            dpt_idx += 1

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (
                int(patch_h * self.patch_size / self.down_ratio),
                int(patch_w * self.patch_size / self.down_ratio),
            ),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out.view(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(
            out, activation=self.activation, conf_activation=self.conf_activation
        )

        # Unfold the B*S dim back into [B, S, ...].
        preds = preds.view(B, S, *preds.shape[1:])
        conf = conf.view(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(
        self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
    ) -> torch.Tensor:
        """
        Apply positional embedding to tensor x.

        The embedding is built from a UV grid matching x's spatial size, scaled by
        `ratio`, and added to x (broadcast over the batch dimension).
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(
            patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
        )
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): List of feature maps from different layers.

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse coarsest-to-finest; free each level as soon as it is consumed to cap memory.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
################################################################################
|
| 347 |
+
# Modules
|
| 348 |
+
################################################################################
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def _make_fusion_block(
    features: int,
    size: Optional[int] = None,
    has_residual: bool = True,
    groups: int = 1,
) -> nn.Module:
    """Build a FeatureFusionBlock preconfigured the way DPTHead uses it.

    Fixed choices: ReLU activation, no deconv, no batch-norm, no channel
    expansion, align_corners=True. Only `size`, `has_residual`, and `groups`
    vary per call site.
    """
    activation = nn.ReLU(inplace=True)
    return FeatureFusionBlock(
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
    )
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def _make_scratch(
|
| 371 |
+
in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
|
| 372 |
+
) -> nn.Module:
|
| 373 |
+
scratch = nn.Module()
|
| 374 |
+
out_shape1 = out_shape
|
| 375 |
+
out_shape2 = out_shape
|
| 376 |
+
out_shape3 = out_shape
|
| 377 |
+
if len(in_shape) >= 4:
|
| 378 |
+
out_shape4 = out_shape
|
| 379 |
+
|
| 380 |
+
if expand:
|
| 381 |
+
out_shape1 = out_shape
|
| 382 |
+
out_shape2 = out_shape * 2
|
| 383 |
+
out_shape3 = out_shape * 4
|
| 384 |
+
if len(in_shape) >= 4:
|
| 385 |
+
out_shape4 = out_shape * 8
|
| 386 |
+
|
| 387 |
+
scratch.layer1_rn = nn.Conv2d(
|
| 388 |
+
in_shape[0],
|
| 389 |
+
out_shape1,
|
| 390 |
+
kernel_size=3,
|
| 391 |
+
stride=1,
|
| 392 |
+
padding=1,
|
| 393 |
+
bias=False,
|
| 394 |
+
groups=groups,
|
| 395 |
+
)
|
| 396 |
+
scratch.layer2_rn = nn.Conv2d(
|
| 397 |
+
in_shape[1],
|
| 398 |
+
out_shape2,
|
| 399 |
+
kernel_size=3,
|
| 400 |
+
stride=1,
|
| 401 |
+
padding=1,
|
| 402 |
+
bias=False,
|
| 403 |
+
groups=groups,
|
| 404 |
+
)
|
| 405 |
+
scratch.layer3_rn = nn.Conv2d(
|
| 406 |
+
in_shape[2],
|
| 407 |
+
out_shape3,
|
| 408 |
+
kernel_size=3,
|
| 409 |
+
stride=1,
|
| 410 |
+
padding=1,
|
| 411 |
+
bias=False,
|
| 412 |
+
groups=groups,
|
| 413 |
+
)
|
| 414 |
+
if len(in_shape) >= 4:
|
| 415 |
+
scratch.layer4_rn = nn.Conv2d(
|
| 416 |
+
in_shape[3],
|
| 417 |
+
out_shape4,
|
| 418 |
+
kernel_size=3,
|
| 419 |
+
stride=1,
|
| 420 |
+
padding=1,
|
| 421 |
+
bias=False,
|
| 422 |
+
groups=groups,
|
| 423 |
+
)
|
| 424 |
+
return scratch
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
class ResidualConvUnit(nn.Module):
    """Pre-activation residual block: two 3x3 convs plus a skip connection."""

    def __init__(self, features, activation, bn, groups=1):
        """Init.

        Args:
            features (int): number of input/output channels.
            activation: activation module applied before each convolution.
            bn (bool): stored for interface compatibility; the norm slots
                below are left as ``None`` and never populated here.
            groups (int): conv groups used by both convolutions.
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        def _make_conv():
            # Both convs are identical 3x3, stride-1, padded, biased convs.
            return nn.Conv2d(
                features,
                features,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
                groups=self.groups,
            )

        self.conv1 = _make_conv()
        self.conv2 = _make_conv()

        # Placeholders kept for API parity; no normalization is instantiated.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input feature map.

        Returns:
            tensor: input plus the two-conv residual branch.
        """
        residual = self.conv1(self.activation(x))
        if self.norm1 is not None:
            residual = self.norm1(residual)

        residual = self.conv2(self.activation(residual))
        if self.norm2 is not None:
            residual = self.norm2(residual)

        return residual + x
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
class FeatureFusionBlock(nn.Module):
    """Feature fusion block: merge a skip feature, refine, upsample, project."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Init.

        Args:
            features (int): number of input channels.
            activation: activation module for the residual units.
            deconv (bool): stored flag; not used by this forward pass.
            bn (bool): forwarded to the residual conv units.
            expand (bool): if True, halve the channels in the output conv.
            align_corners (bool): interpolation flag for upsampling.
            size: fixed upsampling target; can be overridden per forward call.
            has_residual (bool): whether a second (skip) input is fused.
            groups (int): conv groups.
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand

        out_features = features // 2 if self.expand == True else features

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=self.groups,
        )

        # resConfUnit1 only exists when a skip input is fused in forward.
        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(
                features, activation, bn, groups=self.groups
            )
        self.has_residual = has_residual

        self.resConfUnit2 = ResidualConvUnit(
            features, activation, bn, groups=self.groups
        )

        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Args:
            *xs: decoder feature first; optional skip feature second.
            size: per-call upsampling target overriding ``self.size``.

        Returns:
            tensor: fused, refined, upsampled, and projected feature map.
        """
        output = xs[0]

        if self.has_residual:
            output = output + self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        # Upsampling target priority: call-site size > configured size > 2x.
        if size is not None:
            modifier = {"size": size}
        elif self.size is not None:
            modifier = {"size": self.size}
        else:
            modifier = {"scale_factor": 2}

        output = custom_interpolate(
            output, **modifier, mode="bilinear", align_corners=self.align_corners
        )
        return self.out_conv(output)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def custom_interpolate(
    x: torch.Tensor,
    size: Optional[Tuple[int, int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.

    When the output would exceed the element budget, the batch is split along
    dim 0 and each chunk is interpolated separately, then re-concatenated.
    """
    if size is None:
        # Derive the explicit target size from scale_factor (must be given).
        height, width = x.shape[-2], x.shape[-1]
        size = (int(height * scale_factor), int(width * scale_factor))

    # Per-call element budget (1.5 * 2^30, matching the original constant).
    INT_MAX = 1610612736

    # Elements the full output would contain: N * C * H_out * W_out.
    total_elements = x.shape[0] * x.shape[1] * size[0] * size[1]

    if total_elements <= INT_MAX:
        return nn.functional.interpolate(
            x, size=size, mode=mode, align_corners=align_corners
        )

    num_chunks = (total_elements // INT_MAX) + 1
    pieces = [
        nn.functional.interpolate(
            piece, size=size, mode=mode, align_corners=align_corners
        )
        for piece in torch.chunk(x, chunks=num_chunks, dim=0)
    ]
    return torch.cat(pieces, dim=0).contiguous()
|