Delete FastVGGT
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- FastVGGT/.gitignore +0 -160
- FastVGGT/.vscode/launch.json +0 -85
- FastVGGT/README.md +0 -163
- FastVGGT/assets/attn_map.png +0 -3
- FastVGGT/assets/autolab_logo.png +0 -3
- FastVGGT/assets/maclab_logo.png +0 -0
- FastVGGT/assets/main.png +0 -3
- FastVGGT/assets/vs.png +0 -3
- FastVGGT/eval/__pycache__/base.cpython-310.pyc +0 -0
- FastVGGT/eval/__pycache__/criterion.cpython-310.pyc +0 -0
- FastVGGT/eval/__pycache__/data.cpython-310.pyc +0 -0
- FastVGGT/eval/__pycache__/data.cpython-37.pyc +0 -0
- FastVGGT/eval/__pycache__/utils.cpython-310.pyc +0 -0
- FastVGGT/eval/__pycache__/utils.cpython-37.pyc +0 -0
- FastVGGT/eval/base.py +0 -273
- FastVGGT/eval/criterion.py +0 -534
- FastVGGT/eval/data.py +0 -338
- FastVGGT/eval/dataset_utils/__init__.py +0 -1
- FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc +0 -0
- FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc +0 -0
- FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc +0 -0
- FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc +0 -0
- FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc +0 -0
- FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc +0 -0
- FastVGGT/eval/dataset_utils/corr.py +0 -234
- FastVGGT/eval/dataset_utils/cropping.py +0 -140
- FastVGGT/eval/dataset_utils/transforms.py +0 -78
- FastVGGT/eval/eval_7andN.py +0 -497
- FastVGGT/eval/eval_custom.py +0 -467
- FastVGGT/eval/eval_scannet.py +0 -208
- FastVGGT/eval/utils.py +0 -142
- FastVGGT/merging/__init__.py +0 -3
- FastVGGT/merging/__pycache__/__init__.cpython-310.pyc +0 -0
- FastVGGT/merging/__pycache__/merge.cpython-310.pyc +0 -0
- FastVGGT/merging/merge.py +0 -370
- FastVGGT/requirements.txt +0 -15
- FastVGGT/vggt/__init__.py +0 -5
- FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc +0 -0
- FastVGGT/vggt/dependency/__init__.py +0 -5
- FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc +0 -0
- FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc +0 -0
- FastVGGT/vggt/dependency/distortion.py +0 -54
- FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc +0 -0
- FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc +0 -0
- FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc +0 -0
- FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc +0 -0
- FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc +0 -0
- FastVGGT/vggt/heads/camera_head.py +0 -149
- FastVGGT/vggt/heads/dpt_head.py +0 -598
- FastVGGT/vggt/heads/head_act.py +0 -125
FastVGGT/.gitignore
DELETED
|
@@ -1,160 +0,0 @@
|
|
| 1 |
-
.hydra/
|
| 2 |
-
output/
|
| 3 |
-
ckpt/
|
| 4 |
-
.vscode/
|
| 5 |
-
dependency/
|
| 6 |
-
# Byte-compiled / optimized / DLL files
|
| 7 |
-
__pycache__/
|
| 8 |
-
**/__pycache__/
|
| 9 |
-
*.py[cod]
|
| 10 |
-
*$py.class
|
| 11 |
-
test_logs/
|
| 12 |
-
quick_start_logs/
|
| 13 |
-
logs/
|
| 14 |
-
*.pth
|
| 15 |
-
/data/
|
| 16 |
-
*.png
|
| 17 |
-
eval_results/
|
| 18 |
-
.vscode/
|
| 19 |
-
.curosr/
|
| 20 |
-
|
| 21 |
-
# C extensions
|
| 22 |
-
*.so
|
| 23 |
-
LightGlue/
|
| 24 |
-
# Distribution / packaging
|
| 25 |
-
.Python
|
| 26 |
-
build/
|
| 27 |
-
develop-eggs/
|
| 28 |
-
dist/
|
| 29 |
-
downloads/
|
| 30 |
-
eggs/
|
| 31 |
-
.eggs/
|
| 32 |
-
lib/
|
| 33 |
-
lib64/
|
| 34 |
-
parts/
|
| 35 |
-
sdist/
|
| 36 |
-
var/
|
| 37 |
-
wheels/
|
| 38 |
-
pip-wheel-metadata/
|
| 39 |
-
share/python-wheels/
|
| 40 |
-
*.egg-info/
|
| 41 |
-
.installed.cfg
|
| 42 |
-
*.egg
|
| 43 |
-
MANIFEST
|
| 44 |
-
|
| 45 |
-
# PyInstaller
|
| 46 |
-
# Usually these files are written by a python script from a template
|
| 47 |
-
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 48 |
-
*.manifest
|
| 49 |
-
*.spec
|
| 50 |
-
|
| 51 |
-
# Installer logs
|
| 52 |
-
pip-log.txt
|
| 53 |
-
pip-delete-this-directory.txt
|
| 54 |
-
|
| 55 |
-
# Unit test / coverage reports
|
| 56 |
-
htmlcov/
|
| 57 |
-
.tox/
|
| 58 |
-
.nox/
|
| 59 |
-
.coverage
|
| 60 |
-
.coverage.*
|
| 61 |
-
.cache
|
| 62 |
-
nosetests.xml
|
| 63 |
-
coverage.xml
|
| 64 |
-
*.cover
|
| 65 |
-
*.py,cover
|
| 66 |
-
.hypothesis/
|
| 67 |
-
.pytest_cache/
|
| 68 |
-
cover/
|
| 69 |
-
|
| 70 |
-
# Translations
|
| 71 |
-
*.mo
|
| 72 |
-
*.pot
|
| 73 |
-
|
| 74 |
-
# Django stuff:
|
| 75 |
-
*.log
|
| 76 |
-
local_settings.py
|
| 77 |
-
db.sqlite3
|
| 78 |
-
db.sqlite3-journal
|
| 79 |
-
|
| 80 |
-
# Flask stuff:
|
| 81 |
-
instance/
|
| 82 |
-
.webassets-cache
|
| 83 |
-
|
| 84 |
-
# Scrapy stuff:
|
| 85 |
-
.scrapy
|
| 86 |
-
|
| 87 |
-
# Sphinx documentation
|
| 88 |
-
docs/_build/
|
| 89 |
-
|
| 90 |
-
# PyBuilder
|
| 91 |
-
target/
|
| 92 |
-
|
| 93 |
-
# Jupyter Notebook
|
| 94 |
-
.ipynb_checkpoints
|
| 95 |
-
|
| 96 |
-
# IPython
|
| 97 |
-
profile_default/
|
| 98 |
-
ipython_config.py
|
| 99 |
-
|
| 100 |
-
# pyenv
|
| 101 |
-
.python-version
|
| 102 |
-
|
| 103 |
-
# pipenv
|
| 104 |
-
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 105 |
-
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 106 |
-
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 107 |
-
# install all needed dependencies.
|
| 108 |
-
#Pipfile.lock
|
| 109 |
-
|
| 110 |
-
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 111 |
-
__pypackages__/
|
| 112 |
-
|
| 113 |
-
# Celery stuff
|
| 114 |
-
celerybeat-schedule
|
| 115 |
-
celerybeat.pid
|
| 116 |
-
|
| 117 |
-
# SageMath parsed files
|
| 118 |
-
*.sage.py
|
| 119 |
-
|
| 120 |
-
# Environments
|
| 121 |
-
.env
|
| 122 |
-
.venv
|
| 123 |
-
env/
|
| 124 |
-
venv/
|
| 125 |
-
ENV/
|
| 126 |
-
env.bak/
|
| 127 |
-
venv.bak/
|
| 128 |
-
|
| 129 |
-
# Spyder project settings
|
| 130 |
-
.spyderproject
|
| 131 |
-
.spyproject
|
| 132 |
-
|
| 133 |
-
# Rope project settings
|
| 134 |
-
.ropeproject
|
| 135 |
-
|
| 136 |
-
# mkdocs documentation
|
| 137 |
-
/site
|
| 138 |
-
|
| 139 |
-
# mypy
|
| 140 |
-
.mypy_cache/
|
| 141 |
-
.dmypy.json
|
| 142 |
-
dmypy.json
|
| 143 |
-
|
| 144 |
-
# Pyre type checker
|
| 145 |
-
.pyre/
|
| 146 |
-
|
| 147 |
-
# pytype static type analyzer
|
| 148 |
-
.pytype/
|
| 149 |
-
|
| 150 |
-
# Profiling data
|
| 151 |
-
.prof
|
| 152 |
-
|
| 153 |
-
# Folder specific to your needs
|
| 154 |
-
**/tmp/
|
| 155 |
-
**/outputs/skyseg.onnx
|
| 156 |
-
skyseg.onnx
|
| 157 |
-
|
| 158 |
-
# pixi environments
|
| 159 |
-
.pixi
|
| 160 |
-
*.egg-info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/.vscode/launch.json
DELETED
|
@@ -1,85 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
// Use IntelliSense to learn about possible attributes.
|
| 3 |
-
// Hover to view descriptions of existing attributes.
|
| 4 |
-
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
| 5 |
-
"version": "0.2.0",
|
| 6 |
-
"configurations": [
|
| 7 |
-
|
| 8 |
-
{
|
| 9 |
-
"name": "launch",
|
| 10 |
-
"type": "debugpy",
|
| 11 |
-
"request": "launch",
|
| 12 |
-
"program": "/home/sy/code/vggt_0625/training/launch.py",
|
| 13 |
-
"console": "integratedTerminal",
|
| 14 |
-
"args": "${command:pickArgs}",
|
| 15 |
-
"env": {
|
| 16 |
-
"CUDA_VISIBLE_DEVICES": "3",
|
| 17 |
-
},
|
| 18 |
-
"cwd": "/home/sy/code/vggt_0625/training",
|
| 19 |
-
"justMyCode": true,
|
| 20 |
-
"python": "/home/sy/anaconda3/envs/vggt/bin/python"
|
| 21 |
-
}
|
| 22 |
-
,{
|
| 23 |
-
"name": "train_scannet",
|
| 24 |
-
"type": "debugpy",
|
| 25 |
-
"request": "launch",
|
| 26 |
-
"program": "/home/sy/code/vggt_0625/training/launch_scannet.py",
|
| 27 |
-
"console": "integratedTerminal",
|
| 28 |
-
"args": [
|
| 29 |
-
// "--config_name", "scannet",
|
| 30 |
-
// "--exp_name", "scannet_exp001",
|
| 31 |
-
// "--resume_checkpoint_path", "/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt"
|
| 32 |
-
],
|
| 33 |
-
"env": {
|
| 34 |
-
"CUDA_VISIBLE_DEVICES": "7",
|
| 35 |
-
"WORLD_SIZE": "1",
|
| 36 |
-
"RANK": "0",
|
| 37 |
-
"MASTER_ADDR": "localhost",
|
| 38 |
-
"MASTER_PORT": "12345"
|
| 39 |
-
},
|
| 40 |
-
"cwd": "/home/sy/code/vggt_0625/training",
|
| 41 |
-
"justMyCode": true,
|
| 42 |
-
"python": "/home/sy/anaconda3/envs/vggt/bin/python"
|
| 43 |
-
}
|
| 44 |
-
,{
|
| 45 |
-
"name": "eval_scannet",
|
| 46 |
-
"type": "debugpy",
|
| 47 |
-
"request": "launch",
|
| 48 |
-
"program": "/home/sy/code/FastVGGT/eval/eval_scannet.py",
|
| 49 |
-
"console": "integratedTerminal",
|
| 50 |
-
"args": [
|
| 51 |
-
"--data_dir","/data/sy/scannetv2/process_scannet",
|
| 52 |
-
"--gt_ply_dir","/data/sy/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
|
| 53 |
-
"--output_path", "/home/sy/code/FastVGGT/eval_results",
|
| 54 |
-
"--merging", "0",
|
| 55 |
-
"--ckpt_path","/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt",
|
| 56 |
-
"--vis_attn_map"
|
| 57 |
-
],
|
| 58 |
-
"env": {
|
| 59 |
-
"CUDA_VISIBLE_DEVICES": "2"
|
| 60 |
-
},
|
| 61 |
-
"justMyCode": true,
|
| 62 |
-
"python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
|
| 63 |
-
},
|
| 64 |
-
{
|
| 65 |
-
"name": "eval_cd",
|
| 66 |
-
"type": "debugpy",
|
| 67 |
-
"request": "launch",
|
| 68 |
-
"program": "/home/sy/code/FastVGGT/eval/eval_custom.py",
|
| 69 |
-
"console": "integratedTerminal",
|
| 70 |
-
"args": [
|
| 71 |
-
"--merging", "0",
|
| 72 |
-
// "--kf","10",
|
| 73 |
-
// "--output_dir","/home/sy/code/vggt_0625/eval_results_cd",
|
| 74 |
-
"--data_path","/data/sy/segment-102751/",
|
| 75 |
-
"--vis_attn_map"
|
| 76 |
-
],
|
| 77 |
-
"env": {
|
| 78 |
-
"CUDA_VISIBLE_DEVICES": "3"
|
| 79 |
-
},
|
| 80 |
-
"justMyCode": true,
|
| 81 |
-
// "python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
|
| 82 |
-
}
|
| 83 |
-
|
| 84 |
-
]
|
| 85 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/README.md
DELETED
|
@@ -1,163 +0,0 @@
|
|
| 1 |
-
<div align="center">
|
| 2 |
-
<h2>⚡️ FastVGGT: Training-Free Acceleration of Visual Geometry Transformer</h2>
|
| 3 |
-
|
| 4 |
-
<p align="center">
|
| 5 |
-
<a href="https://arxiv.org/abs/2509.02560"><img src="https://img.shields.io/badge/arXiv-FastVGGT-red?logo=arxiv" alt="Paper PDF"></a>
|
| 6 |
-
<a href="https://mystorm16.github.io/fastvggt/"><img src="https://img.shields.io/badge/Project_Page-FastVGGT-yellow" alt="Project Page"></a>
|
| 7 |
-
</p>
|
| 8 |
-
|
| 9 |
-
<img src="assets/maclab_logo.png" alt="Maclab Logo" width="110" style="margin-right: 40px;">
|
| 10 |
-
<img src="assets/autolab_logo.png" alt="Autolab Logo" width="110">
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
**[Media Analytics & Computing Laboratory](https://mac.xmu.edu.cn/)**; **[AUTOLAB](https://zhipengzhang.cn/)**
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
[You Shen](https://mystorm16.github.io/), [Zhipeng Zhang](https://zhipengzhang.cn/), [Yansong Qu](https://quyans.github.io/), [Liujuan Cao](https://mac.xmu.edu.cn/ljcao/)
|
| 17 |
-
</div>
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
## 📰 News
|
| 21 |
-
- [Sep 8, 2025] Added custom dataset evaluation.
|
| 22 |
-
- [Sep 3, 2025] Paper release.
|
| 23 |
-
- [Sep 2, 2025] Code release.
|
| 24 |
-
|
| 25 |
-
## 🔭 Overview
|
| 26 |
-
|
| 27 |
-
FastVGGT observes **strong similarity** in attention maps and leverages it to design a training-free acceleration method for long-sequence 3D reconstruction, **achieving up to 4× faster inference without sacrificing accuracy.**
|
| 28 |
-
|
| 29 |
-
<img src="assets/main.png" alt="FastVGGT overview" width="">
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
## ⚙️ Environment Setup
|
| 33 |
-
First, create a virtual environment using Conda, clone this repository to your local machine, and install the required dependencies.
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
```bash
|
| 37 |
-
conda create -n fastvggt python=3.10
|
| 38 |
-
conda activate fastvggt
|
| 39 |
-
git clone git@github.com:mystorm16/FastVGGT.git
|
| 40 |
-
cd FastVGGT
|
| 41 |
-
pip install -r requirements.txt
|
| 42 |
-
```
|
| 43 |
-
|
| 44 |
-
Next, prepare the ScanNet dataset: http://www.scan-net.org/ScanNet/
|
| 45 |
-
|
| 46 |
-
Then, download the VGGT checkpoint (we use the checkpoint link provided in https://github.com/facebookresearch/vggt/tree/evaluation/evaluation):
|
| 47 |
-
```bash
|
| 48 |
-
wget https://huggingface.co/facebook/VGGT_tracker_fixed/resolve/main/model_tracker_fixed_e20.pt
|
| 49 |
-
```
|
| 50 |
-
|
| 51 |
-
Finally, configure the dataset path and VGGT checkpoint path. For example:
|
| 52 |
-
```python
|
| 53 |
-
parser.add_argument(
|
| 54 |
-
"--data_dir", type=Path, default="/data/scannetv2/process_scannet"
|
| 55 |
-
)
|
| 56 |
-
parser.add_argument(
|
| 57 |
-
"--gt_ply_dir",
|
| 58 |
-
type=Path,
|
| 59 |
-
default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
|
| 60 |
-
)
|
| 61 |
-
parser.add_argument(
|
| 62 |
-
"--ckpt_path",
|
| 63 |
-
type=str,
|
| 64 |
-
default="./ckpt/model_tracker_fixed_e20.pt",
|
| 65 |
-
)
|
| 66 |
-
```
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
## 💎 Observation
|
| 70 |
-
|
| 71 |
-
Note: A large number of input_frames may significantly slow down saving the visualization results. Please try using a smaller number first.
|
| 72 |
-
```bash
|
| 73 |
-
python eval/eval_scannet.py --input_frame 30 --vis_attn_map --merging 0
|
| 74 |
-
```
|
| 75 |
-
|
| 76 |
-
We observe that many token-level attention maps are highly similar in each block, motivating our optimization of the Global Attention module.
|
| 77 |
-
|
| 78 |
-
<img src="assets/attn_map.png" alt="Attention map visualization" width="">
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
## 🏀 Evaluation
|
| 83 |
-
### Custom Dataset
|
| 84 |
-
Please organize the data according to the following directory:
|
| 85 |
-
```
|
| 86 |
-
<data_path>/
|
| 87 |
-
├── images/
|
| 88 |
-
│ ├── 000000.jpg
|
| 89 |
-
│ ├── 000001.jpg
|
| 90 |
-
│ └── ...
|
| 91 |
-
├── pose/ # Optional: Camera poses
|
| 92 |
-
│ ├── 000000.txt
|
| 93 |
-
│ ├── 000001.txt
|
| 94 |
-
│ └── ...
|
| 95 |
-
└── gt_ply/ # Optional: GT point cloud
|
| 96 |
-
└── scene_xxx.ply
|
| 97 |
-
```
|
| 98 |
-
- Required: `images/`
|
| 99 |
-
- Additionally required when `--enable_evaluation` is enabled: `pose/` and `gt_ply/`
|
| 100 |
-
|
| 101 |
-
Inference only:
|
| 102 |
-
|
| 103 |
-
```bash
|
| 104 |
-
python eval/eval_custom.py \
|
| 105 |
-
--data_path /path/to/your_dataset \
|
| 106 |
-
--output_path ./eval_results_custom \
|
| 107 |
-
--plot
|
| 108 |
-
```
|
| 109 |
-
|
| 110 |
-
Inference + Evaluation (requires `pose/` and `gt_ply/`):
|
| 111 |
-
|
| 112 |
-
```bash
|
| 113 |
-
python eval/eval_custom.py \
|
| 114 |
-
--data_path /path/to/your_dataset \
|
| 115 |
-
--enable_evaluation \
|
| 116 |
-
--output_path ./eval_results_custom \
|
| 117 |
-
--plot
|
| 118 |
-
```
|
| 119 |
-
|
| 120 |
-
### ScanNet
|
| 121 |
-
Evaluate FastVGGT on the ScanNet dataset with 1,000 input images. The **--merging** parameter specifies the block index at which the merging strategy is applied:
|
| 122 |
-
|
| 123 |
-
```bash
|
| 124 |
-
python eval/eval_scannet.py --input_frame 1000 --merging 0
|
| 125 |
-
```
|
| 126 |
-
|
| 127 |
-
Evaluate Baseline VGGT on the ScanNet dataset with 1,000 input images:
|
| 128 |
-
```bash
|
| 129 |
-
python eval/eval_scannet.py --input_frame 1000
|
| 130 |
-
```
|
| 131 |
-
<img src="assets/vs.png" alt="FastVGGT vs. baseline VGGT comparison" width="">
|
| 132 |
-
|
| 133 |
-
### 7 Scenes & NRGBD
|
| 134 |
-
Evaluate across two datasets, sampling keyframes every 10 frames:
|
| 135 |
-
```bash
|
| 136 |
-
python eval/eval_7andN.py --kf 10
|
| 137 |
-
```
|
| 138 |
-
|
| 139 |
-
## 🍺 Acknowledgements
|
| 140 |
-
|
| 141 |
-
- Thanks to these great repositories: [VGGT](https://github.com/facebookresearch/vggt), [Dust3r](https://github.com/naver/dust3r), [Fast3R](https://github.com/facebookresearch/fast3r), [CUT3R](https://github.com/CUT3R/CUT3R), [MV-DUSt3R+](https://github.com/facebookresearch/mvdust3r), [StreamVGGT](https://github.com/wzzheng/StreamVGGT), [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long), [ToMeSD](https://github.com/dbolya/tomesd) and many other inspiring works in the community.
|
| 142 |
-
|
| 143 |
-
- Special thanks to [Jianyuan Wang](https://jytime.github.io/) for his valuable discussions and suggestions on this work.
|
| 144 |
-
|
| 145 |
-
<!-- ## ✍️ Checklist
|
| 146 |
-
|
| 147 |
-
- [ ] Release the evaluation code on 7 Scenes / NRGBD -->
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
## ⚖️ License
|
| 151 |
-
See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available.
|
| 152 |
-
|
| 153 |
-
## Citation
|
| 154 |
-
|
| 155 |
-
If you find this project helpful, please consider citing the following paper:
|
| 156 |
-
```
|
| 157 |
-
@article{shen2025fastvggt,
|
| 158 |
-
title={FastVGGT: Training-Free Acceleration of Visual Geometry Transformer},
|
| 159 |
-
author={Shen, You and Zhang, Zhipeng and Qu, Yansong and Cao, Liujuan},
|
| 160 |
-
journal={arXiv preprint arXiv:2509.02560},
|
| 161 |
-
year={2025}
|
| 162 |
-
}
|
| 163 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/assets/attn_map.png
DELETED
Git LFS Details
|
FastVGGT/assets/autolab_logo.png
DELETED
Git LFS Details
|
FastVGGT/assets/maclab_logo.png
DELETED
|
Binary file (4.8 kB)
|
|
|
FastVGGT/assets/main.png
DELETED
Git LFS Details
|
FastVGGT/assets/vs.png
DELETED
Git LFS Details
|
FastVGGT/eval/__pycache__/base.cpython-310.pyc
DELETED
|
Binary file (6.92 kB)
|
|
|
FastVGGT/eval/__pycache__/criterion.cpython-310.pyc
DELETED
|
Binary file (13.6 kB)
|
|
|
FastVGGT/eval/__pycache__/data.cpython-310.pyc
DELETED
|
Binary file (7.78 kB)
|
|
|
FastVGGT/eval/__pycache__/data.cpython-37.pyc
DELETED
|
Binary file (8.03 kB)
|
|
|
FastVGGT/eval/__pycache__/utils.cpython-310.pyc
DELETED
|
Binary file (3.99 kB)
|
|
|
FastVGGT/eval/__pycache__/utils.cpython-37.pyc
DELETED
|
Binary file (4.32 kB)
|
|
|
FastVGGT/eval/base.py
DELETED
|
@@ -1,273 +0,0 @@
|
|
| 1 |
-
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
-
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
-
#
|
| 4 |
-
# --------------------------------------------------------
|
| 5 |
-
# base class for implementing datasets
|
| 6 |
-
# --------------------------------------------------------
|
| 7 |
-
import PIL
|
| 8 |
-
import numpy as np
|
| 9 |
-
import torch
|
| 10 |
-
|
| 11 |
-
from dataset_utils.transforms import ImgNorm
|
| 12 |
-
import dataset_utils.cropping as cropping
|
| 13 |
-
from utils import depthmap_to_absolute_camera_coordinates
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class BaseStereoViewDataset:
|
| 17 |
-
"""Define all basic options.
|
| 18 |
-
|
| 19 |
-
Usage:
|
| 20 |
-
class MyDataset (BaseStereoViewDataset):
|
| 21 |
-
def _get_views(self, idx, rng):
|
| 22 |
-
# overload here
|
| 23 |
-
views = []
|
| 24 |
-
views.append(dict(img=, ...))
|
| 25 |
-
return views
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
def __init__(
|
| 29 |
-
self,
|
| 30 |
-
*, # only keyword arguments
|
| 31 |
-
split=None,
|
| 32 |
-
resolution=None, # square_size or (width, height) or list of [(width,height), ...]
|
| 33 |
-
transform=ImgNorm,
|
| 34 |
-
aug_crop=False,
|
| 35 |
-
seed=None,
|
| 36 |
-
):
|
| 37 |
-
self.num_views = 2
|
| 38 |
-
self.split = split
|
| 39 |
-
self._set_resolutions(resolution)
|
| 40 |
-
|
| 41 |
-
self.transform = transform
|
| 42 |
-
if isinstance(transform, str):
|
| 43 |
-
transform = eval(transform)
|
| 44 |
-
|
| 45 |
-
self.aug_crop = aug_crop
|
| 46 |
-
self.seed = seed
|
| 47 |
-
|
| 48 |
-
def __len__(self):
|
| 49 |
-
return len(self.scenes)
|
| 50 |
-
|
| 51 |
-
def get_stats(self):
|
| 52 |
-
return f"{len(self)} pairs"
|
| 53 |
-
|
| 54 |
-
def __repr__(self):
|
| 55 |
-
resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
|
| 56 |
-
return (
|
| 57 |
-
f"""{type(self).__name__}({self.get_stats()},
|
| 58 |
-
{self.split=},
|
| 59 |
-
{self.seed=},
|
| 60 |
-
resolutions={resolutions_str},
|
| 61 |
-
{self.transform=})""".replace(
|
| 62 |
-
"self.", ""
|
| 63 |
-
)
|
| 64 |
-
.replace("\n", "")
|
| 65 |
-
.replace(" ", "")
|
| 66 |
-
)
|
| 67 |
-
|
| 68 |
-
def _get_views(self, idx, resolution, rng):
|
| 69 |
-
raise NotImplementedError()
|
| 70 |
-
|
| 71 |
-
def __getitem__(self, idx):
|
| 72 |
-
if isinstance(idx, tuple):
|
| 73 |
-
# the idx is specifying the aspect-ratio
|
| 74 |
-
idx, ar_idx = idx
|
| 75 |
-
else:
|
| 76 |
-
assert len(self._resolutions) == 1
|
| 77 |
-
ar_idx = 0
|
| 78 |
-
|
| 79 |
-
# set-up the rng
|
| 80 |
-
if self.seed: # reseed for each __getitem__
|
| 81 |
-
self._rng = np.random.default_rng(seed=self.seed + idx)
|
| 82 |
-
elif not hasattr(self, "_rng"):
|
| 83 |
-
seed = torch.initial_seed() # this is different for each dataloader process
|
| 84 |
-
self._rng = np.random.default_rng(seed=seed)
|
| 85 |
-
|
| 86 |
-
# over-loaded code
|
| 87 |
-
resolution = self._resolutions[
|
| 88 |
-
ar_idx
|
| 89 |
-
] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
|
| 90 |
-
views = self._get_views(idx, resolution, self._rng)
|
| 91 |
-
|
| 92 |
-
# check data-types
|
| 93 |
-
for v, view in enumerate(views):
|
| 94 |
-
assert (
|
| 95 |
-
"pts3d" not in view
|
| 96 |
-
), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
|
| 97 |
-
view["idx"] = v
|
| 98 |
-
|
| 99 |
-
# encode the image
|
| 100 |
-
width, height = view["img"].size
|
| 101 |
-
view["true_shape"] = np.int32((height, width))
|
| 102 |
-
view["img"] = self.transform(view["img"])
|
| 103 |
-
|
| 104 |
-
assert "camera_intrinsics" in view
|
| 105 |
-
if "camera_pose" not in view:
|
| 106 |
-
view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
|
| 107 |
-
else:
|
| 108 |
-
assert np.isfinite(
|
| 109 |
-
view["camera_pose"]
|
| 110 |
-
).all(), f"NaN in camera pose for view {view_name(view)}"
|
| 111 |
-
assert "pts3d" not in view
|
| 112 |
-
assert "valid_mask" not in view
|
| 113 |
-
assert np.isfinite(
|
| 114 |
-
view["depthmap"]
|
| 115 |
-
).all(), f"NaN in depthmap for view {view_name(view)}"
|
| 116 |
-
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
|
| 117 |
-
|
| 118 |
-
view["pts3d"] = pts3d
|
| 119 |
-
view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
|
| 120 |
-
|
| 121 |
-
# check all datatypes
|
| 122 |
-
for key, val in view.items():
|
| 123 |
-
res, err_msg = is_good_type(key, val)
|
| 124 |
-
assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
|
| 125 |
-
K = view["camera_intrinsics"]
|
| 126 |
-
view["img_mask"] = True
|
| 127 |
-
view["ray_mask"] = False
|
| 128 |
-
view["ray_map"] = torch.full(
|
| 129 |
-
(6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
|
| 130 |
-
)
|
| 131 |
-
view["update"] = True
|
| 132 |
-
view["reset"] = False
|
| 133 |
-
|
| 134 |
-
# last thing done!
|
| 135 |
-
for view in views:
|
| 136 |
-
# transpose to make sure all views are the same size
|
| 137 |
-
transpose_to_landscape(view)
|
| 138 |
-
# this allows checking whether the RNG is in the same state each time
|
| 139 |
-
view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
|
| 140 |
-
return views
|
| 141 |
-
|
| 142 |
-
def _set_resolutions(self, resolutions):
|
| 143 |
-
"""Set the resolution(s) of the dataset.
|
| 144 |
-
Params:
|
| 145 |
-
- resolutions: int or tuple or list of tuples
|
| 146 |
-
"""
|
| 147 |
-
assert resolutions is not None, "undefined resolution"
|
| 148 |
-
|
| 149 |
-
if not isinstance(resolutions, list):
|
| 150 |
-
resolutions = [resolutions]
|
| 151 |
-
|
| 152 |
-
self._resolutions = []
|
| 153 |
-
for resolution in resolutions:
|
| 154 |
-
if isinstance(resolution, int):
|
| 155 |
-
width = height = resolution
|
| 156 |
-
else:
|
| 157 |
-
width, height = resolution
|
| 158 |
-
assert isinstance(
|
| 159 |
-
width, int
|
| 160 |
-
), f"Bad type for {width=} {type(width)=}, should be int"
|
| 161 |
-
assert isinstance(
|
| 162 |
-
height, int
|
| 163 |
-
), f"Bad type for {height=} {type(height)=}, should be int"
|
| 164 |
-
assert width >= height
|
| 165 |
-
self._resolutions.append((width, height))
|
| 166 |
-
|
| 167 |
-
def _crop_resize_if_necessary(
|
| 168 |
-
self, image, depthmap, intrinsics, resolution, rng=None, info=None
|
| 169 |
-
):
|
| 170 |
-
"""This function:
|
| 171 |
-
- first downsizes the image with LANCZOS interpolation,
|
| 172 |
-
which is better than bilinear interpolation in
|
| 173 |
-
"""
|
| 174 |
-
if not isinstance(image, PIL.Image.Image):
|
| 175 |
-
image = PIL.Image.fromarray(image)
|
| 176 |
-
|
| 177 |
-
# downscale with lanczos interpolation so that image.size == resolution
|
| 178 |
-
# cropping centered on the principal point
|
| 179 |
-
W, H = image.size
|
| 180 |
-
cx, cy = intrinsics[:2, 2].round().astype(int)
|
| 181 |
-
|
| 182 |
-
# calculate min distance to margin
|
| 183 |
-
min_margin_x = min(cx, W - cx)
|
| 184 |
-
min_margin_y = min(cy, H - cy)
|
| 185 |
-
assert min_margin_x > W / 5, f"Bad principal point in view={info}"
|
| 186 |
-
assert min_margin_y > H / 5, f"Bad principal point in view={info}"
|
| 187 |
-
|
| 188 |
-
## Center crop
|
| 189 |
-
# Crop on the principal point, make it always centered
|
| 190 |
-
# the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
|
| 191 |
-
l, t = cx - min_margin_x, cy - min_margin_y
|
| 192 |
-
r, b = cx + min_margin_x, cy + min_margin_y
|
| 193 |
-
crop_bbox = (l, t, r, b)
|
| 194 |
-
|
| 195 |
-
image, depthmap, intrinsics = cropping.crop_image_depthmap(
|
| 196 |
-
image, depthmap, intrinsics, crop_bbox
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
# # transpose the resolution if necessary
|
| 200 |
-
W, H = image.size # new size
|
| 201 |
-
assert resolution[0] >= resolution[1]
|
| 202 |
-
if H > 1.1 * W:
|
| 203 |
-
# image is portrait mode
|
| 204 |
-
resolution = resolution[::-1]
|
| 205 |
-
elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
|
| 206 |
-
# image is square, so we chose (portrait, landscape) randomly
|
| 207 |
-
if rng.integers(2):
|
| 208 |
-
resolution = resolution[::-1]
|
| 209 |
-
|
| 210 |
-
# high-quality Lanczos down-scaling
|
| 211 |
-
target_resolution = np.array(resolution)
|
| 212 |
-
# # if self.aug_crop > 1:
|
| 213 |
-
# # target_resolution += rng.integers(0, self.aug_crop)
|
| 214 |
-
# if resolution != (224, 224):
|
| 215 |
-
# halfw, halfh = ((2*(W//2))//16)*8, ((2*(H//2))//16)*8
|
| 216 |
-
# ## Rescale with max factor, so one of width or height might be larger than target_resolution
|
| 217 |
-
# image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, (2*halfw, 2*halfh))
|
| 218 |
-
# else:
|
| 219 |
-
image, depthmap, intrinsics = cropping.rescale_image_depthmap(
|
| 220 |
-
image, depthmap, intrinsics, target_resolution
|
| 221 |
-
)
|
| 222 |
-
# actual cropping (if necessary) with bilinear interpolation
|
| 223 |
-
# if resolution == (224, 224):
|
| 224 |
-
intrinsics2 = cropping.camera_matrix_of_crop(
|
| 225 |
-
intrinsics, image.size, resolution, offset_factor=0.5
|
| 226 |
-
)
|
| 227 |
-
crop_bbox = cropping.bbox_from_intrinsics_in_out(
|
| 228 |
-
intrinsics, intrinsics2, resolution
|
| 229 |
-
)
|
| 230 |
-
image, depthmap, intrinsics = cropping.crop_image_depthmap(
|
| 231 |
-
image, depthmap, intrinsics, crop_bbox
|
| 232 |
-
)
|
| 233 |
-
return image, depthmap, intrinsics
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
def is_good_type(key, v):
|
| 237 |
-
"""returns (is_good, err_msg)"""
|
| 238 |
-
if isinstance(v, (str, int, tuple)):
|
| 239 |
-
return True, None
|
| 240 |
-
if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
|
| 241 |
-
return False, f"bad {v.dtype=}"
|
| 242 |
-
return True, None
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
def view_name(view, batch_index=None):
|
| 246 |
-
def sel(x):
|
| 247 |
-
return x[batch_index] if batch_index not in (None, slice(None)) else x
|
| 248 |
-
|
| 249 |
-
db = sel(view["dataset"])
|
| 250 |
-
label = sel(view["label"])
|
| 251 |
-
instance = sel(view["instance"])
|
| 252 |
-
return f"{db}/{label}/{instance}"
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
def transpose_to_landscape(view):
|
| 256 |
-
height, width = view["true_shape"]
|
| 257 |
-
|
| 258 |
-
if width < height:
|
| 259 |
-
# rectify portrait to landscape
|
| 260 |
-
assert view["img"].shape == (3, height, width)
|
| 261 |
-
view["img"] = view["img"].swapaxes(1, 2)
|
| 262 |
-
|
| 263 |
-
assert view["valid_mask"].shape == (height, width)
|
| 264 |
-
view["valid_mask"] = view["valid_mask"].swapaxes(0, 1)
|
| 265 |
-
|
| 266 |
-
assert view["depthmap"].shape == (height, width)
|
| 267 |
-
view["depthmap"] = view["depthmap"].swapaxes(0, 1)
|
| 268 |
-
|
| 269 |
-
assert view["pts3d"].shape == (height, width, 3)
|
| 270 |
-
view["pts3d"] = view["pts3d"].swapaxes(0, 1)
|
| 271 |
-
|
| 272 |
-
# transpose x and y pixels
|
| 273 |
-
view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/eval/criterion.py
DELETED
|
@@ -1,534 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
from copy import copy, deepcopy
|
| 4 |
-
|
| 5 |
-
from eval.dataset_utils.corr import geotrf, inv
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def invalid_to_nans(arr, valid_mask, ndim=999):
    """Return ``arr`` with every entry outside ``valid_mask`` replaced by NaN.

    Args:
        arr: tensor to mask.
        valid_mask: boolean mask used to index ``arr``; when None, ``arr`` is
            passed through without being copied.
        ndim: maximum number of dimensions of the result. If the tensor has
            more, the surplus dimensions sitting right before the last one
            are flattened into a single dimension.

    Returns:
        The masked (and possibly flattened) tensor. The input is never
        modified in place: masking is done on a clone.
    """
    out = arr
    if valid_mask is not None:
        out = arr.clone()  # keep the caller's tensor untouched
        out[~valid_mask] = float("nan")

    surplus = out.ndim - ndim
    if surplus > 0:
        out = out.flatten(-2 - surplus, -2)
    return out
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def invalid_to_zeros(arr, valid_mask, ndim=999):
    """Zero out the entries of ``arr`` outside ``valid_mask`` and count the rest.

    Args:
        arr: tensor to mask.
        valid_mask: boolean mask used to index ``arr``, or None to keep
            everything (in which case ``arr`` is not copied).
        ndim: maximum number of dimensions of the returned tensor; surplus
            dimensions right before the last one are flattened together.

    Returns:
        ``(arr, nnz)``. With a mask, ``nnz`` is a tensor holding the number of
        True mask entries per element of the first dimension. Without a mask,
        ``nnz`` is the Python int ``arr.numel() // len(arr)`` (0 when ``arr``
        is empty).
    """
    if valid_mask is None:
        # number of point per image
        nnz = arr.numel() // len(arr) if len(arr) else 0
    else:
        arr = arr.clone()  # do not modify the caller's tensor
        arr[~valid_mask] = 0
        nnz = valid_mask.view(len(valid_mask), -1).sum(1)

    surplus = arr.ndim - ndim
    if surplus > 0:
        arr = arr.flatten(-2 - surplus, -2)
    return arr, nnz
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
class BaseCriterion(nn.Module):
    """Base class of the elementary criteria: it only records a reduction mode."""

    def __init__(self, reduction="mean"):
        """Store ``reduction``; subclasses decide how to interpret it."""
        super().__init__()
        self.reduction = reduction
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
class Criterion(nn.Module):
    """Module holding a (shallow) copy of an elementary ``BaseCriterion``."""

    def __init__(self, criterion=None):
        """Check that ``criterion`` is a ``BaseCriterion`` and keep a copy of it."""
        super().__init__()
        assert isinstance(
            criterion, BaseCriterion
        ), f"{criterion} is not a proper criterion!"
        self.criterion = copy(criterion)

    def get_name(self):
        """Return ``ClassName(criterion)``."""
        cls_name = type(self).__name__
        return f"{cls_name}({self.criterion})"

    def with_reduction(self, mode="none"):
        """Return a deep copy of this loss whose criteria use reduction ``mode``.

        The copy is walked through its ``_loss2`` links, so the object is
        expected to also provide ``_loss2`` (as ``MultiLoss`` does); each
        link must itself be a ``Criterion``.
        """
        result = deepcopy(self)
        node = result
        while node is not None:
            assert isinstance(node, Criterion)
            # make it return the loss for each sample
            node.criterion.reduction = mode
            # we assume loss is a Multiloss
            node = node._loss2
        return result
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
class MultiLoss(nn.Module):
    """Losses that can be combined with ``+`` and scaled with ``*``.

    Example:
        loss = MyLoss1() + 0.1*MyLoss2()

    Individual loss values are reported in the ``details`` dict returned by
    ``forward``. To define a loss, inherit from this class and override
    ``get_name()`` and ``compute_loss()``.
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1  # weight applied to this loss
        self._loss2 = None  # next loss of the sum, if any

    def compute_loss(self, *args, **kwargs):
        """Compute the loss; must be overridden."""
        raise NotImplementedError()

    def get_name(self):
        """Return the display name of the loss; must be overridden."""
        raise NotImplementedError()

    def __mul__(self, alpha):
        """Return a shallow copy of this loss weighted by the number ``alpha``."""
        assert isinstance(alpha, (int, float))
        scaled = copy(self)
        scaled._alpha = alpha
        return scaled

    __rmul__ = __mul__  # same

    def __add__(self, loss2):
        """Return a shallow copy of this loss with ``loss2`` appended to its chain."""
        assert isinstance(loss2, MultiLoss)
        head = copy(self)

        # walk to the last loss of the chain and hook loss2 there
        tail = head
        while tail._loss2 is not None:
            tail = tail._loss2
        tail._loss2 = loss2
        return head

    def __repr__(self):
        text = self.get_name()
        if self._alpha != 1:
            text = f"{self._alpha:g}*{text}"
        if self._loss2:
            text = f"{text} + {self._loss2}"
        return text

    def forward(self, *args, **kwargs):
        """Evaluate this loss and the chained ones.

        Returns:
            ``(loss, details)`` where ``loss`` is the weighted sum over the
            chain and ``details`` merges the per-loss detail dicts.
        """
        out = self.compute_loss(*args, **kwargs)
        if isinstance(out, tuple):
            loss, details = out
        else:
            loss = out
            # only a scalar loss can be logged as a float
            details = {self.get_name(): float(loss)} if loss.ndim == 0 else {}
        loss = loss * self._alpha

        if self._loss2:
            other_loss, other_details = self._loss2(*args, **kwargs)
            loss = loss + other_loss
            details |= other_details

        return loss, details
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
class LLoss(BaseCriterion):
    """L-norm loss"""

    def forward(self, a, b):
        """Return the distance between ``a`` and ``b``, reduced per ``self.reduction``.

        Both inputs must share the same shape, have at least 2 dimensions and
        a last dimension of size 1, 2 or 3.

        Raises:
            ValueError: if ``self.reduction`` is not "none", "sum" or "mean".
        """
        assert (
            a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
        ), f"Bad shape = {a.shape}"
        dist = self.distance(a, b)

        if self.reduction == "none":
            return dist
        elif self.reduction == "sum":
            return dist.sum()
        elif self.reduction == "mean":
            # the mean of an empty tensor is NaN: return a zero scalar instead
            if dist.numel() > 0:
                return dist.mean()
            return dist.new_zeros(())
        raise ValueError(f"bad {self.reduction=} mode")

    def distance(self, a, b):
        """Return the element-wise distance; must be overridden."""
        raise NotImplementedError()
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
class L21Loss(LLoss):
    """Euclidean distance between 3d points"""

    def distance(self, a, b):
        """Return the L2 norm of ``a - b`` taken over the last dimension."""
        delta = a - b
        return torch.norm(delta, dim=-1)


# Shared instance (default "mean" reduction).
L21 = L21Loss()
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
def get_pred_pts3d(gt, pred, use_pose=False):
    """Return the predicted point map stored under ``pred["pts3d_in_other_view"]``.

    ``gt`` is not used. ``use_pose`` must be exactly True; any other value,
    including the default, fails the assertion.
    """
    assert use_pose is True
    points = pred["pts3d_in_other_view"]
    return points
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
def Sum(losses, masks, conf=None):
    """Combine a list of losses, depending on whether they are already reduced.

    If the first loss is not a scalar (``ndim > 0``), the inputs are returned
    untouched as ``(losses, masks)`` or ``(losses, masks, conf)`` when
    ``conf`` is given. Otherwise all the losses are added together and the
    total is returned.
    """
    first, _ = losses[0], masks[0]

    if first.ndim == 0:
        # we are returning the global loss
        total = first
        for extra in losses[1:]:
            total = total + extra
        return total

    # we are actually returning the loss for every pixels
    if conf is None:
        return losses, masks
    return losses, masks, conf
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
    """Compute the normalization factor of a list of point maps.

    Args:
        pts: sequence of point maps whose last dimension is 3; ``pts[1]`` is
            inspected as well, so at least two entries are required (the
            second may be None).
        norm_mode: ``"<norm>_<dis>"``. Only ``norm == "avg"`` is implemented;
            ``dis`` is ``"dis"`` (plain distance to the origin) or ``"log1p"``.
        valids: per-view validity masks, indexed like ``pts``.
        fix_first: when True, only the first view contributes to the factor.

    Returns:
        The factor, clipped to at least 1e-8 and unsqueezed with trailing
        singleton dimensions until it has as many dimensions as ``pts[0]``.

    Raises:
        ValueError: on an unsupported ``norm_mode``.
    """
    assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3
    assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split("_")

    if norm_mode != "avg":
        raise ValueError(f"Not implemented {norm_mode=}")

    # gather all points together (joint normalization)
    zeroed_pts = []
    counts = []
    for idx, pt in enumerate(pts):
        flat_pt, nnz = invalid_to_zeros(pt, valids[idx], ndim=3)
        zeroed_pts.append(flat_pt)
        counts.append(nnz)
        if fix_first:
            break

    # compute distance to origin
    all_dis = torch.cat(zeroed_pts, dim=1).norm(dim=-1)
    if dis_mode == "log1p":
        all_dis = torch.log1p(all_dis)
    elif dis_mode != "dis":
        raise ValueError(f"bad {dis_mode=}")

    # NOTE(review): the denominator is the valid count summed over every
    # gathered view AND every batch element, while the numerator is summed
    # per batch element; confirm this is intended for batch sizes > 1.
    norm_factor = all_dis.sum(dim=1) / (torch.cat(counts).sum() + 1e-8)

    norm_factor = norm_factor.clip(min=1e-8)
    for _ in range(pts[0].ndim - norm_factor.ndim):
        norm_factor.unsqueeze_(-1)

    return norm_factor
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
def normalize_pointcloud_t(
    pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False
):
    """Divide every point map of ``pts`` by their common normalization factor.

    The factor comes from ``get_norm_factor(pts, norm_mode, valids, fix_first)``.
    ``gt`` is accepted for the caller's convenience but does not change the
    computation: ground truth and predictions are normalized the same way.

    Returns:
        ``(res, norm_factor)`` where ``res`` is the list of normalized maps.
    """
    norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
    res = [pt / norm_factor for pt in pts]
    return res, norm_factor
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
@torch.no_grad()
def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
    """Return a quantile of the valid depths, computed jointly over all views.

    Args:
        zs: sequence of depth tensors; each is reshaped to ``(len(z), -1)``
            and the views are concatenated along the last dimension.
        valid_masks: optional per-view masks; invalid entries are ignored.
        quantile: quantile to compute; 0.5 uses ``torch.nanmedian``, any
            other value uses ``torch.nanquantile``.

    Returns:
        One value per element of the first dimension, i.e. shape (B,).
    """
    # set invalid points to NaN
    flat_views = []
    for idx, z in enumerate(zs):
        mask = None if valid_masks is None else valid_masks[idx]
        flat_views.append(invalid_to_nans(z, mask).reshape(len(z), -1))
    joint = torch.cat(flat_views, dim=-1)

    # compute median depth overall (ignoring nans)
    if quantile == 0.5:
        return torch.nanmedian(joint, dim=-1).values  # (B,)
    return torch.nanquantile(joint, quantile, dim=-1)  # (B,)
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
@torch.no_grad()
def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
    """Return the median center and median norm of the valid points of all views.

    Args:
        pts: sequence of point maps; each is reshaped to ``(len(pt), -1, 3)``
            and the views are concatenated along dimension 1.
        valid_masks: optional per-view masks; invalid points are ignored.
        z_only: when True, the first two coordinates of the center are zeroed.
        center: when True, norms are measured relative to the center,
            otherwise relative to the origin.

    Returns:
        ``(center, scale)`` with an extra singleton dimension inserted so the
        shapes are (B, 1, 1, 3) and (B, 1, 1, 1).
    """
    # set invalid points to NaN
    flat_views = []
    for idx, pt in enumerate(pts):
        mask = None if valid_masks is None else valid_masks[idx]
        flat_views.append(invalid_to_nans(pt, mask).reshape(len(pt), -1, 3))
    joint = torch.cat(flat_views, dim=1)

    # compute median center
    median_center = torch.nanmedian(joint, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        median_center[..., :2] = 0  # do not center X and Y

    # compute median norm
    offsets = joint - median_center if center else joint
    scale = torch.nanmedian(offsets.norm(dim=-1), dim=1).values
    return median_center[:, None, :, :], scale[:, None, None, None]
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
class Regr3D_t(Criterion, MultiLoss):
    """3D regression loss over a sequence of views.

    Ground-truth points are expressed in the camera frame of the first view
    and compared with the predictions stored under "pts3d_in_other_view".

    Attributes (set in ``__init__``):
        norm_mode: normalization mode forwarded to ``normalize_pointcloud_t``;
            a falsy value disables normalization.
        gt_scale: when True, the ground truth is left un-normalized.
        fix_first: forwarded to ``normalize_pointcloud_t``.
    """

    def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True):
        super().__init__(criterion)
        self.norm_mode = norm_mode
        self.gt_scale = gt_scale
        self.fix_first = fix_first

    def get_all_pts3d_t(self, gts, preds, dist_clip=None):
        """Collect (optionally normalized) GT and predicted points for all views.

        Args:
            gts: per-view dicts providing "camera_pose", "pts3d", "valid_mask".
            preds: per-view dicts providing "pts3d_in_other_view".
            dist_clip: if given, GT points farther than this from the origin
                are marked invalid.

        Returns:
            ``(gt_pts, pr_pts, gt_factor, pr_factor, valids, monitoring)``;
            a factor is None when the matching normalization was skipped and
            ``monitoring`` is an empty dict.
        """
        # everything is normalized w.r.t. camera of view1
        in_camera1 = inv(gts[0]["camera_pose"])

        gt_pts = []
        valids = []
        pr_pts = []

        for i, gt in enumerate(gts):
            # in_camera1: Bs, 4, 4 gt['pts3d']: Bs, H, W, 3
            gt_pts.append(geotrf(in_camera1, gt["pts3d"]))
            valid = gt["valid_mask"].clone()

            if dist_clip is not None:
                # points that are too far-away == invalid
                dis = gt["pts3d"].norm(dim=-1)
                valid = valid & (dis <= dist_clip)

            valids.append(valid)
            pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True))

        # predictions are normalized whenever a norm mode is set
        if self.norm_mode:
            pr_pts, pr_factor = normalize_pointcloud_t(
                pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False
            )
        else:
            pr_factor = None

        # the ground truth is normalized too, unless its scale must be kept
        if self.norm_mode and not self.gt_scale:
            gt_pts, gt_factor = normalize_pointcloud_t(
                gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True
            )
        else:
            gt_factor = None

        return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}

    def compute_frame_loss(self, gts, preds, **kw):
        """Compute per-frame losses for a left (reference) / right (target) layout.

        NOTE(review): ``get_all_pts3d_t`` returns one prediction tensor per
        view, while this method unpacks that result into a (left, right) pair
        and reads ``preds[i][0]`` / ``preds[i - 1][1]``. This looks like a
        leftover from a pairwise prediction layout; confirm it is still
        usable before relying on it.

        Returns:
            ``(Sum(loss_all, mask_all, conf_all), details, factor_loss)``.
        """
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            self.get_all_pts3d_t(gts, preds, **kw)
        )

        pred_pts_l, pred_pts_r = pred_pts

        loss_all = []
        mask_all = []
        conf_all = []

        # running sums, only used for logging
        loss_left = 0
        loss_right = 0
        pred_conf_l = 0
        pred_conf_r = 0

        for i in range(len(gt_pts)):

            # Left (Reference)
            if i != len(gt_pts) - 1:
                frame_loss = self.criterion(
                    pred_pts_l[i][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i][0]["conf"])

                # To compare target/reference loss
                if i != 0:
                    loss_left += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy().mean()

            # Right (Target)
            if i != 0:
                frame_loss = self.criterion(
                    pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i - 1][1]["conf"])

                # To compare target/reference loss
                if i != len(gt_pts) - 1:
                    loss_right += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_r += preds[i - 1][1]["conf"].cpu().detach().numpy().mean()

        # only the prediction factors exceeding the GT factor are kept
        if pr_factor is not None and gt_factor is not None:
            filter_factor = pr_factor[pr_factor > gt_factor]
        else:
            filter_factor = []

        if len(filter_factor) > 0:
            factor_loss = (filter_factor - gt_factor).abs().mean()
        else:
            factor_loss = 0.0

        self_name = type(self).__name__
        details = {
            self_name + "_pts3d_1": float(loss_all[0].mean()),
            self_name + "_pts3d_2": float(loss_all[1].mean()),
            self_name + "loss_left": float(loss_left),
            self_name + "loss_right": float(loss_right),
            self_name + "conf_left": float(pred_conf_l),
            self_name + "conf_right": float(pred_conf_r),
        }

        return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
class ConfLoss_t(MultiLoss):
    """Weighted regression by learned confidence.
    Assuming the input pixel_loss is a pixel-level regression loss.

    Principle:
        high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
        low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)

    alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        """Wrap ``pixel_loss`` (switched to reduction "none"); ``alpha`` must be > 0."""
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        self.pixel_loss = pixel_loss.with_reduction("none")

    def get_name(self):
        return f"ConfLoss({self.pixel_loss})"

    def get_conf_log(self, x):
        """Return the confidence ``x`` together with its logarithm."""
        return x, torch.log(x)

    def compute_frame_loss(self, gts, preds, **kw):
        """Weight the per-pixel losses of ``pixel_loss`` by the predicted confidence.

        Returns:
            ``(mean_conf_loss, details, loss_factor)`` where ``details`` holds
            the first two confidence losses, the mean confidence and the
            details reported by ``pixel_loss``.
        """
        # compute per-pixel loss
        (losses, masks, confs), details, loss_factor = (
            self.pixel_loss.compute_frame_loss(gts, preds, **kw)
        )

        # weight by confidence
        conf_losses = []
        conf_sum = 0
        for i in range(len(losses)):
            conf, log_conf = self.get_conf_log(confs[i][masks[i]])
            conf_sum += conf.mean()
            conf_loss = losses[i] * conf - self.alpha * log_conf
            # An empty selection must still yield a tensor: torch.stack below
            # rejects Python numbers, so fall back to a zero scalar tensor.
            if conf_loss.numel() > 0:
                conf_loss = conf_loss.mean()
            else:
                conf_loss = conf_loss.new_zeros(())
            conf_losses.append(conf_loss)

        conf_losses = torch.stack(conf_losses) * 2.0
        conf_loss_mean = conf_losses.mean()

        return (
            conf_loss_mean,
            dict(
                conf_loss_1=float(conf_losses[0]),
                conf_loss2=float(conf_losses[1]),
                conf_mean=conf_sum / len(losses),
                **details,
            ),
            loss_factor,
        )
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
class Regr3D_t_ShiftInv(Regr3D_t):
    """Same as Regr3D_t but invariant to depth shift."""

    def get_all_pts3d_t(self, gts, preds):
        """Return the points of the parent class with their median depth removed.

        The median z is computed jointly over all views, separately for the
        ground truth and for the predictions, and subtracted in place from
        the z coordinate. Both medians are added to ``monitoring``.
        """
        # compute unnormalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )

        gt_depths = [pt[..., 2] for pt in gt_pts]
        pred_depths = [pt[..., 2] for pt in pred_pts]

        # compute median depth
        gt_shift_z = get_joint_pointcloud_depth(gt_depths, masks)[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(pred_depths, masks)[:, None, None]

        # subtract the median depth
        for idx in range(len(gt_pts)):
            gt_pts[idx][..., 2] -= gt_shift_z
        for idx in range(len(pred_pts)):
            pred_pts[idx][..., 2] -= pred_shift_z

        monitoring = dict(
            monitoring,
            gt_shift_z=gt_shift_z.mean().detach(),
            pred_shift_z=pred_shift_z.mean().detach(),
        )
        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
class Regr3D_t_ScaleInv(Regr3D_t):
    """Variant of Regr3D_t that rescales the point clouds by their median scale.

    if gt_scale == True: enforce the prediction to take the same scale than GT
    """

    def get_all_pts3d_t(self, gts, preds):
        """Return the points of the parent class after in-place rescaling.

        The scale of the ground truth and of (a copy of) the predictions is
        measured jointly over all views. With ``self.gt_scale`` the
        predictions are multiplied by ``gt_scale / pred_scale``; otherwise
        the predictions are multiplied by ``pred_scale / gt_scale`` and the
        ground truth by ``gt_scale / pred_scale``. Both scales are added to
        ``monitoring``.
        """
        # compute depth-normalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )

        # measure scene scale (on copies of the predictions)
        pred_copies = [pt.clone() for pt in pred_pts]
        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_copies, masks)

        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        # rescale the point clouds in place
        if self.gt_scale:
            for idx in range(len(pred_pts)):
                pred_pts[idx] *= gt_scale / pred_scale
        else:
            for idx in range(len(pred_pts)):
                pred_pts[idx] *= pred_scale / gt_scale
            for idx in range(len(gt_pts)):
                gt_pts[idx] *= gt_scale / pred_scale

        monitoring = dict(
            monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()
        )

        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
    """Combination of both variants through the method resolution order.

    ``Regr3D_t_ScaleInv.get_all_pts3d_t`` runs first and its ``super()`` call
    reaches ``Regr3D_t_ShiftInv``, so the shift is removed before the scale
    is handled.
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/eval/data.py
DELETED
|
@@ -1,338 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import cv2
|
| 3 |
-
import numpy as np
|
| 4 |
-
import os.path as osp
|
| 5 |
-
from collections import deque
|
| 6 |
-
from base import BaseStereoViewDataset
|
| 7 |
-
import dataset_utils.cropping as cropping
|
| 8 |
-
from vggt.utils.eval_utils import imread_cv2, shuffle_deque
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class SevenScenes(BaseStereoViewDataset):
    """Dataset over the 7-Scenes sequences stored under ``ROOT``.

    Each item is the list of views of one sequence (or of one explicit tuple
    of frames when ``tuple_list`` is given). Every view is a dict with the
    image, depth map, camera pose and intrinsics.
    """

    def __init__(
        self,
        num_seq=1,
        num_frames=5,
        min_thresh=10,
        max_thresh=100,
        test_id=None,
        full_video=False,
        tuple_list=None,
        seq_id=None,
        rebuttal=False,
        shuffle_seed=-1,
        kf_every=1,
        *args,
        ROOT,
        **kwargs,
    ):
        """Store the options and build the scene list.

        Args:
            num_seq: number of dataset items generated per scene.
            test_id: if set, only the scene with this name is kept.
            tuple_list: optional list of strings "<scene_id> <idx> <idx> ...";
                when given, items are these tuples.
            seq_id: if set, only sequences with this id ("seq-XX") are kept.
            rebuttal: when True, the 224x224 center-crop path is never used.
            shuffle_seed: frames are shuffled when this is >= 0.
            kf_every: keep one frame out of ``kf_every``.
            ROOT: dataset root directory (keyword-only).
            num_frames, min_thresh, max_thresh, full_video: stored as
                attributes; not read by this class.
            *args, **kwargs: forwarded to the base class.
        """
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every
        self.seq_id = seq_id
        self.rebuttal = rebuttal
        self.shuffle_seed = shuffle_seed

        # load all scenes
        self.load_all_tuples(tuple_list)
        self.load_all_scenes(ROOT)

    def __len__(self):
        if self.tuple_list is not None:
            return len(self.tuple_list)
        return len(self.scene_list) * self.num_seq

    def load_all_tuples(self, tuple_list):
        """Remember the explicit frame tuples (None when not provided)."""
        self.tuple_list = tuple_list

    def load_all_scenes(self, base_dir):
        """Fill ``self.scene_list`` with "<scene>/<seq-XX>" entries.

        With a tuple list, a fixed list of sequences is used. Otherwise the
        split file of every scene directory in ``base_dir`` is read.
        """
        if self.tuple_list is not None:
            # Use pre-defined simplerecon scene_ids
            self.scene_list = [
                "stairs/seq-06",
                "stairs/seq-02",
                "pumpkin/seq-06",
                "chess/seq-01",
                "heads/seq-02",
                "fire/seq-02",
                "office/seq-03",
                "pumpkin/seq-03",
                "redkitchen/seq-07",
                "chess/seq-02",
                "office/seq-01",
                "redkitchen/seq-01",
                "fire/seq-01",
            ]
            print(f"Found {len(self.scene_list)} sequences in split {self.split}")
            return

        scenes = os.listdir(base_dir)

        file_split = {"train": "TrainSplit.txt", "test": "TestSplit.txt"}[self.split]

        self.scene_list = []
        for scene in scenes:
            if self.test_id is not None and scene != self.test_id:
                continue
            # read file split
            with open(osp.join(base_dir, scene, file_split)) as f:
                seq_ids = f.read().splitlines()

            for seq_id in seq_ids:
                # seq is string, take the int part and make it 01, 02, 03
                num_part = "".join(filter(str.isdigit, seq_id))
                seq_id = f"seq-{num_part.zfill(2)}"
                if self.seq_id is not None and seq_id != self.seq_id:
                    continue
                self.scene_list.append(f"{scene}/{seq_id}")

        print(f"Found {len(self.scene_list)} sequences in split {self.split}")

    def _get_views(self, idx, resolution, rng):
        """Load every selected frame of item ``idx`` and return the views.

        Args:
            idx: dataset index.
            resolution: target resolution forwarded to the crop/resize helper.
            rng: random generator forwarded to the crop/resize helper.

        Returns:
            A list of dicts with keys img, depthmap, camera_pose,
            camera_intrinsics, dataset, label and instance.
        """
        if self.tuple_list is not None:
            line = self.tuple_list[idx].split(" ")
            scene_id = line[0]
            img_idxs = line[1:]
        else:
            scene_id = self.scene_list[idx // self.num_seq]

            data_path = osp.join(self.ROOT, scene_id)
            num_files = len([name for name in os.listdir(data_path) if "color" in name])
            img_idxs = [f"{i:06d}" for i in range(num_files)]
            img_idxs = img_idxs[:: self.kf_every]

        # Intrinsics used in SimpleRecon
        fx, fy, cx, cy = 525, 525, 320, 240
        intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)

        views = []
        imgs_idxs = deque(img_idxs)
        if self.shuffle_seed >= 0:
            imgs_idxs = shuffle_deque(imgs_idxs)

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()
            impath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.color.png")
            depthpath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.depth.proj.png")
            posepath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.pose.txt")

            rgb_image = imread_cv2(impath)

            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            # bring the color image to the depth map resolution
            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))

            # 65535 is treated as "no depth"
            depthmap[depthmap == 65535] = 0
            # nan= is passed by keyword: the second positional parameter of
            # np.nan_to_num is `copy`, not the replacement value.
            # Division by 1000 presumably converts millimetres to metres.
            depthmap = np.nan_to_num(depthmap.astype(np.float32), nan=0.0) / 1000.0

            # discard depths outside of [1e-3, 10]
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0

            camera_pose = np.loadtxt(posepath).astype(np.float32)

            if resolution != (224, 224) or self.rebuttal:
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
                )
            else:
                # resize to 512x384 first, then take a 224x224 center crop
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="7scenes",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )
        return views
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
class NRGBD(BaseStereoViewDataset):
|
| 182 |
-
def __init__(
|
| 183 |
-
self,
|
| 184 |
-
num_seq=1,
|
| 185 |
-
num_frames=5,
|
| 186 |
-
min_thresh=10,
|
| 187 |
-
max_thresh=100,
|
| 188 |
-
test_id=None,
|
| 189 |
-
full_video=False,
|
| 190 |
-
tuple_list=None,
|
| 191 |
-
seq_id=None,
|
| 192 |
-
rebuttal=False,
|
| 193 |
-
shuffle_seed=-1,
|
| 194 |
-
kf_every=1,
|
| 195 |
-
*args,
|
| 196 |
-
ROOT,
|
| 197 |
-
**kwargs,
|
| 198 |
-
):
|
| 199 |
-
|
| 200 |
-
self.ROOT = ROOT
|
| 201 |
-
super().__init__(*args, **kwargs)
|
| 202 |
-
self.num_seq = num_seq
|
| 203 |
-
self.num_frames = num_frames
|
| 204 |
-
self.max_thresh = max_thresh
|
| 205 |
-
self.min_thresh = min_thresh
|
| 206 |
-
self.test_id = test_id
|
| 207 |
-
self.full_video = full_video
|
| 208 |
-
self.kf_every = kf_every
|
| 209 |
-
self.seq_id = seq_id
|
| 210 |
-
self.rebuttal = rebuttal
|
| 211 |
-
self.shuffle_seed = shuffle_seed
|
| 212 |
-
|
| 213 |
-
# load all scenes
|
| 214 |
-
self.load_all_tuples(tuple_list)
|
| 215 |
-
self.load_all_scenes(ROOT)
|
| 216 |
-
|
| 217 |
-
def __len__(self):
|
| 218 |
-
if self.tuple_list is not None:
|
| 219 |
-
return len(self.tuple_list)
|
| 220 |
-
return len(self.scene_list) * self.num_seq
|
| 221 |
-
|
| 222 |
-
def load_all_tuples(self, tuple_list):
|
| 223 |
-
if tuple_list is not None:
|
| 224 |
-
self.tuple_list = tuple_list
|
| 225 |
-
# with open(tuple_path) as f:
|
| 226 |
-
# self.tuple_list = f.read().splitlines()
|
| 227 |
-
|
| 228 |
-
else:
|
| 229 |
-
self.tuple_list = None
|
| 230 |
-
|
| 231 |
-
def load_all_scenes(self, base_dir):
|
| 232 |
-
|
| 233 |
-
scenes = [
|
| 234 |
-
d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
|
| 235 |
-
]
|
| 236 |
-
|
| 237 |
-
if self.test_id is not None:
|
| 238 |
-
self.scene_list = [self.test_id]
|
| 239 |
-
|
| 240 |
-
else:
|
| 241 |
-
self.scene_list = scenes
|
| 242 |
-
|
| 243 |
-
print(f"Found {len(self.scene_list)} sequences in split {self.split}")
|
| 244 |
-
|
| 245 |
-
def load_poses(self, path):
|
| 246 |
-
file = open(path, "r")
|
| 247 |
-
lines = file.readlines()
|
| 248 |
-
file.close()
|
| 249 |
-
poses = []
|
| 250 |
-
valid = []
|
| 251 |
-
lines_per_matrix = 4
|
| 252 |
-
for i in range(0, len(lines), lines_per_matrix):
|
| 253 |
-
if "nan" in lines[i]:
|
| 254 |
-
valid.append(False)
|
| 255 |
-
poses.append(np.eye(4, 4, dtype=np.float32).tolist())
|
| 256 |
-
else:
|
| 257 |
-
valid.append(True)
|
| 258 |
-
pose_floats = [
|
| 259 |
-
[float(x) for x in line.split()]
|
| 260 |
-
for line in lines[i : i + lines_per_matrix]
|
| 261 |
-
]
|
| 262 |
-
poses.append(pose_floats)
|
| 263 |
-
|
| 264 |
-
return np.array(poses, dtype=np.float32), valid
|
| 265 |
-
|
| 266 |
-
def _get_views(self, idx, resolution, rng):
|
| 267 |
-
|
| 268 |
-
if self.tuple_list is not None:
|
| 269 |
-
line = self.tuple_list[idx].split(" ")
|
| 270 |
-
scene_id = line[0]
|
| 271 |
-
img_idxs = line[1:]
|
| 272 |
-
|
| 273 |
-
else:
|
| 274 |
-
scene_id = self.scene_list[idx // self.num_seq]
|
| 275 |
-
|
| 276 |
-
num_files = len(os.listdir(os.path.join(self.ROOT, scene_id, "images")))
|
| 277 |
-
img_idxs = [f"{i}" for i in range(num_files)]
|
| 278 |
-
img_idxs = img_idxs[:: min(self.kf_every, len(img_idxs) // 2)]
|
| 279 |
-
|
| 280 |
-
fx, fy, cx, cy = 554.2562584220408, 554.2562584220408, 320, 240
|
| 281 |
-
intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
|
| 282 |
-
|
| 283 |
-
posepath = osp.join(self.ROOT, scene_id, f"poses.txt")
|
| 284 |
-
camera_poses, valids = self.load_poses(posepath)
|
| 285 |
-
|
| 286 |
-
imgs_idxs = deque(img_idxs)
|
| 287 |
-
if self.shuffle_seed >= 0:
|
| 288 |
-
imgs_idxs = shuffle_deque(imgs_idxs)
|
| 289 |
-
views = []
|
| 290 |
-
|
| 291 |
-
while len(imgs_idxs) > 0:
|
| 292 |
-
im_idx = imgs_idxs.popleft()
|
| 293 |
-
|
| 294 |
-
impath = osp.join(self.ROOT, scene_id, "images", f"img{im_idx}.png")
|
| 295 |
-
depthpath = osp.join(self.ROOT, scene_id, "depth", f"depth{im_idx}.png")
|
| 296 |
-
|
| 297 |
-
rgb_image = imread_cv2(impath)
|
| 298 |
-
depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
|
| 299 |
-
depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0
|
| 300 |
-
depthmap[depthmap > 10] = 0
|
| 301 |
-
depthmap[depthmap < 1e-3] = 0
|
| 302 |
-
|
| 303 |
-
rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))
|
| 304 |
-
|
| 305 |
-
camera_pose = camera_poses[int(im_idx)]
|
| 306 |
-
# gl to cv
|
| 307 |
-
camera_pose[:, 1:3] *= -1.0
|
| 308 |
-
if resolution != (224, 224) or self.rebuttal:
|
| 309 |
-
rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
|
| 310 |
-
rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
|
| 311 |
-
)
|
| 312 |
-
else:
|
| 313 |
-
rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
|
| 314 |
-
rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
|
| 315 |
-
)
|
| 316 |
-
W, H = rgb_image.size
|
| 317 |
-
cx = W // 2
|
| 318 |
-
cy = H // 2
|
| 319 |
-
l, t = cx - 112, cy - 112
|
| 320 |
-
r, b = cx + 112, cy + 112
|
| 321 |
-
crop_bbox = (l, t, r, b)
|
| 322 |
-
rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
|
| 323 |
-
rgb_image, depthmap, intrinsics, crop_bbox
|
| 324 |
-
)
|
| 325 |
-
|
| 326 |
-
views.append(
|
| 327 |
-
dict(
|
| 328 |
-
img=rgb_image,
|
| 329 |
-
depthmap=depthmap,
|
| 330 |
-
camera_pose=camera_pose,
|
| 331 |
-
camera_intrinsics=intrinsics,
|
| 332 |
-
dataset="nrgbd",
|
| 333 |
-
label=osp.join(scene_id, im_idx),
|
| 334 |
-
instance=impath,
|
| 335 |
-
)
|
| 336 |
-
)
|
| 337 |
-
|
| 338 |
-
return views
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/eval/dataset_utils/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
|
|
|
|
|
|
FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (146 Bytes)
|
|
|
FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc
DELETED
|
Binary file (140 Bytes)
|
|
|
FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc
DELETED
|
Binary file (5.85 kB)
|
|
|
FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc
DELETED
|
Binary file (4.29 kB)
|
|
|
FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc
DELETED
|
Binary file (4.29 kB)
|
|
|
FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc
DELETED
|
Binary file (2.18 kB)
|
|
|
FastVGGT/eval/dataset_utils/corr.py
DELETED
|
@@ -1,234 +0,0 @@
|
|
| 1 |
-
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
-
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
-
#
|
| 4 |
-
# --------------------------------------------------------
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def todevice(batch, device, callback=None, non_blocking=False):
|
| 11 |
-
"""Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
|
| 12 |
-
|
| 13 |
-
batch: list, tuple, dict of tensors or other things
|
| 14 |
-
device: pytorch device or 'numpy'
|
| 15 |
-
callback: function that would be called on every sub-elements.
|
| 16 |
-
"""
|
| 17 |
-
if callback:
|
| 18 |
-
batch = callback(batch)
|
| 19 |
-
|
| 20 |
-
if isinstance(batch, dict):
|
| 21 |
-
return {k: todevice(v, device) for k, v in batch.items()}
|
| 22 |
-
|
| 23 |
-
if isinstance(batch, (tuple, list)):
|
| 24 |
-
return type(batch)(todevice(x, device) for x in batch)
|
| 25 |
-
|
| 26 |
-
x = batch
|
| 27 |
-
if device == "numpy":
|
| 28 |
-
if isinstance(x, torch.Tensor):
|
| 29 |
-
x = x.detach().cpu().numpy()
|
| 30 |
-
elif x is not None:
|
| 31 |
-
if isinstance(x, np.ndarray):
|
| 32 |
-
x = torch.from_numpy(x)
|
| 33 |
-
if torch.is_tensor(x):
|
| 34 |
-
x = x.to(device, non_blocking=non_blocking)
|
| 35 |
-
return x
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
to_device = todevice # alias
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def to_numpy(x):
|
| 42 |
-
return todevice(x, "numpy")
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def geotrf(Trf, pts, ncol=None, norm=False):
|
| 46 |
-
"""Apply a geometric transformation to a list of 3-D points.
|
| 47 |
-
|
| 48 |
-
H: 3x3 or 4x4 projection matrix (typically a Homography)
|
| 49 |
-
p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
|
| 50 |
-
|
| 51 |
-
ncol: int. number of columns of the result (2 or 3)
|
| 52 |
-
norm: float. if != 0, the resut is projected on the z=norm plane.
|
| 53 |
-
|
| 54 |
-
Returns an array of projected 2d points.
|
| 55 |
-
"""
|
| 56 |
-
assert Trf.ndim >= 2
|
| 57 |
-
if isinstance(Trf, np.ndarray):
|
| 58 |
-
pts = np.asarray(pts)
|
| 59 |
-
elif isinstance(Trf, torch.Tensor):
|
| 60 |
-
pts = torch.as_tensor(pts, dtype=Trf.dtype)
|
| 61 |
-
|
| 62 |
-
output_reshape = pts.shape[:-1]
|
| 63 |
-
ncol = ncol or pts.shape[-1]
|
| 64 |
-
|
| 65 |
-
if (
|
| 66 |
-
isinstance(Trf, torch.Tensor)
|
| 67 |
-
and isinstance(pts, torch.Tensor)
|
| 68 |
-
and Trf.ndim == 3
|
| 69 |
-
and pts.ndim == 4
|
| 70 |
-
):
|
| 71 |
-
d = pts.shape[3]
|
| 72 |
-
if Trf.shape[-1] == d:
|
| 73 |
-
pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
|
| 74 |
-
elif Trf.shape[-1] == d + 1:
|
| 75 |
-
pts = (
|
| 76 |
-
torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
|
| 77 |
-
+ Trf[:, None, None, :d, d]
|
| 78 |
-
)
|
| 79 |
-
else:
|
| 80 |
-
raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
|
| 81 |
-
else:
|
| 82 |
-
if Trf.ndim >= 3:
|
| 83 |
-
n = Trf.ndim - 2
|
| 84 |
-
assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
|
| 85 |
-
Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
|
| 86 |
-
|
| 87 |
-
if pts.ndim > Trf.ndim:
|
| 88 |
-
|
| 89 |
-
pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
|
| 90 |
-
elif pts.ndim == 2:
|
| 91 |
-
|
| 92 |
-
pts = pts[:, None, :]
|
| 93 |
-
|
| 94 |
-
if pts.shape[-1] + 1 == Trf.shape[-1]:
|
| 95 |
-
Trf = Trf.swapaxes(-1, -2) # transpose Trf
|
| 96 |
-
pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
|
| 97 |
-
elif pts.shape[-1] == Trf.shape[-1]:
|
| 98 |
-
Trf = Trf.swapaxes(-1, -2) # transpose Trf
|
| 99 |
-
pts = pts @ Trf
|
| 100 |
-
else:
|
| 101 |
-
pts = Trf @ pts.T
|
| 102 |
-
if pts.ndim >= 2:
|
| 103 |
-
pts = pts.swapaxes(-1, -2)
|
| 104 |
-
|
| 105 |
-
if norm:
|
| 106 |
-
pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
|
| 107 |
-
if norm != 1:
|
| 108 |
-
pts *= norm
|
| 109 |
-
|
| 110 |
-
res = pts[..., :ncol].reshape(*output_reshape, ncol)
|
| 111 |
-
return res
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def inv(mat):
|
| 115 |
-
"""Invert a torch or numpy matrix"""
|
| 116 |
-
if isinstance(mat, torch.Tensor):
|
| 117 |
-
return torch.linalg.inv(mat)
|
| 118 |
-
if isinstance(mat, np.ndarray):
|
| 119 |
-
return np.linalg.inv(mat)
|
| 120 |
-
raise ValueError(f"bad matrix type = {type(mat)}")
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
def reproject_view(pts3d, view2):
|
| 124 |
-
shape = view2["pts3d"].shape[:2]
|
| 125 |
-
return reproject(
|
| 126 |
-
pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
def reproject(pts3d, K, world2cam, shape):
|
| 131 |
-
H, W, THREE = pts3d.shape
|
| 132 |
-
assert THREE == 3
|
| 133 |
-
|
| 134 |
-
with np.errstate(divide="ignore", invalid="ignore"):
|
| 135 |
-
pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)
|
| 136 |
-
|
| 137 |
-
return (H, W), ravel_xy(pos, shape)
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
def ravel_xy(pos, shape):
|
| 141 |
-
H, W = shape
|
| 142 |
-
with np.errstate(invalid="ignore"):
|
| 143 |
-
qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
|
| 144 |
-
quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(
|
| 145 |
-
min=0, max=H - 1, out=qy
|
| 146 |
-
)
|
| 147 |
-
return quantized_pos
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
def unravel_xy(pos, shape):
|
| 151 |
-
|
| 152 |
-
return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
|
| 156 |
-
is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2))
|
| 157 |
-
pos1 = is_reciprocal1.nonzero()[0]
|
| 158 |
-
pos2 = corres_1_to_2[pos1]
|
| 159 |
-
if ret_recip:
|
| 160 |
-
return is_reciprocal1, pos1, pos2
|
| 161 |
-
return pos1, pos2
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def extract_correspondences_from_pts3d(
|
| 165 |
-
view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
|
| 166 |
-
):
|
| 167 |
-
view1, view2 = to_numpy((view1, view2))
|
| 168 |
-
|
| 169 |
-
shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
|
| 170 |
-
shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)
|
| 171 |
-
|
| 172 |
-
is_reciprocal1, pos1, pos2 = reciprocal_1d(
|
| 173 |
-
corres1_to_2, corres2_to_1, ret_recip=True
|
| 174 |
-
)
|
| 175 |
-
is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))
|
| 176 |
-
|
| 177 |
-
if target_n_corres is None:
|
| 178 |
-
if ret_xy:
|
| 179 |
-
pos1 = unravel_xy(pos1, shape1)
|
| 180 |
-
pos2 = unravel_xy(pos2, shape2)
|
| 181 |
-
return pos1, pos2
|
| 182 |
-
|
| 183 |
-
available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
|
| 184 |
-
target_n_positives = int(target_n_corres * (1 - nneg))
|
| 185 |
-
n_positives = min(len(pos1), target_n_positives)
|
| 186 |
-
n_negatives = min(target_n_corres - n_positives, available_negatives)
|
| 187 |
-
|
| 188 |
-
if n_negatives + n_positives != target_n_corres:
|
| 189 |
-
|
| 190 |
-
n_positives = target_n_corres - n_negatives
|
| 191 |
-
assert n_positives <= len(pos1)
|
| 192 |
-
|
| 193 |
-
assert n_positives <= len(pos1)
|
| 194 |
-
assert n_positives <= len(pos2)
|
| 195 |
-
assert n_negatives <= (~is_reciprocal1).sum()
|
| 196 |
-
assert n_negatives <= (~is_reciprocal2).sum()
|
| 197 |
-
assert n_positives + n_negatives == target_n_corres
|
| 198 |
-
|
| 199 |
-
valid = np.ones(n_positives, dtype=bool)
|
| 200 |
-
if n_positives < len(pos1):
|
| 201 |
-
|
| 202 |
-
perm = rng.permutation(len(pos1))[:n_positives]
|
| 203 |
-
pos1 = pos1[perm]
|
| 204 |
-
pos2 = pos2[perm]
|
| 205 |
-
|
| 206 |
-
if n_negatives > 0:
|
| 207 |
-
|
| 208 |
-
def norm(p):
|
| 209 |
-
return p / p.sum()
|
| 210 |
-
|
| 211 |
-
pos1 = np.r_[
|
| 212 |
-
pos1,
|
| 213 |
-
rng.choice(
|
| 214 |
-
shape1[0] * shape1[1],
|
| 215 |
-
size=n_negatives,
|
| 216 |
-
replace=False,
|
| 217 |
-
p=norm(~is_reciprocal1),
|
| 218 |
-
),
|
| 219 |
-
]
|
| 220 |
-
pos2 = np.r_[
|
| 221 |
-
pos2,
|
| 222 |
-
rng.choice(
|
| 223 |
-
shape2[0] * shape2[1],
|
| 224 |
-
size=n_negatives,
|
| 225 |
-
replace=False,
|
| 226 |
-
p=norm(~is_reciprocal2),
|
| 227 |
-
),
|
| 228 |
-
]
|
| 229 |
-
valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]
|
| 230 |
-
|
| 231 |
-
if ret_xy:
|
| 232 |
-
pos1 = unravel_xy(pos1, shape1)
|
| 233 |
-
pos2 = unravel_xy(pos2, shape2)
|
| 234 |
-
return pos1, pos2, valid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/eval/dataset_utils/cropping.py
DELETED
|
@@ -1,140 +0,0 @@
|
|
| 1 |
-
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
-
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
-
#
|
| 4 |
-
# --------------------------------------------------------
|
| 5 |
-
|
| 6 |
-
import PIL.Image
|
| 7 |
-
import os
|
| 8 |
-
|
| 9 |
-
from utils import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics
|
| 10 |
-
|
| 11 |
-
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 12 |
-
import cv2 # noqa
|
| 13 |
-
import numpy as np # noqa
|
| 14 |
-
|
| 15 |
-
try:
|
| 16 |
-
lanczos = PIL.Image.Resampling.LANCZOS
|
| 17 |
-
bicubic = PIL.Image.Resampling.BICUBIC
|
| 18 |
-
except AttributeError:
|
| 19 |
-
lanczos = PIL.Image.LANCZOS
|
| 20 |
-
bicubic = PIL.Image.BICUBIC
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class ImageList:
|
| 24 |
-
"""Convenience class to aply the same operation to a whole set of images."""
|
| 25 |
-
|
| 26 |
-
def __init__(self, images):
|
| 27 |
-
if not isinstance(images, (tuple, list, set)):
|
| 28 |
-
images = [images]
|
| 29 |
-
self.images = []
|
| 30 |
-
for image in images:
|
| 31 |
-
if not isinstance(image, PIL.Image.Image):
|
| 32 |
-
image = PIL.Image.fromarray(image)
|
| 33 |
-
self.images.append(image)
|
| 34 |
-
|
| 35 |
-
def __len__(self):
|
| 36 |
-
return len(self.images)
|
| 37 |
-
|
| 38 |
-
def to_pil(self):
|
| 39 |
-
return tuple(self.images) if len(self.images) > 1 else self.images[0]
|
| 40 |
-
|
| 41 |
-
@property
|
| 42 |
-
def size(self):
|
| 43 |
-
sizes = [im.size for im in self.images]
|
| 44 |
-
assert all(sizes[0] == s for s in sizes)
|
| 45 |
-
return sizes[0]
|
| 46 |
-
|
| 47 |
-
def resize(self, *args, **kwargs):
|
| 48 |
-
return ImageList(self._dispatch("resize", *args, **kwargs))
|
| 49 |
-
|
| 50 |
-
def crop(self, *args, **kwargs):
|
| 51 |
-
return ImageList(self._dispatch("crop", *args, **kwargs))
|
| 52 |
-
|
| 53 |
-
def _dispatch(self, func, *args, **kwargs):
|
| 54 |
-
return [getattr(im, func)(*args, **kwargs) for im in self.images]
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def rescale_image_depthmap(
|
| 58 |
-
image, depthmap, camera_intrinsics, output_resolution, force=True
|
| 59 |
-
):
|
| 60 |
-
"""Jointly rescale a (image, depthmap)
|
| 61 |
-
so that (out_width, out_height) >= output_res
|
| 62 |
-
"""
|
| 63 |
-
image = ImageList(image)
|
| 64 |
-
input_resolution = np.array(image.size) # (W,H)
|
| 65 |
-
output_resolution = np.array(output_resolution)
|
| 66 |
-
if depthmap is not None:
|
| 67 |
-
|
| 68 |
-
assert tuple(depthmap.shape[:2]) == image.size[::-1]
|
| 69 |
-
|
| 70 |
-
assert output_resolution.shape == (2,)
|
| 71 |
-
scale_final = max(output_resolution / image.size) + 1e-8
|
| 72 |
-
if scale_final >= 1 and not force: # image is already smaller than what is asked
|
| 73 |
-
return (image.to_pil(), depthmap, camera_intrinsics)
|
| 74 |
-
output_resolution = np.floor(input_resolution * scale_final).astype(int)
|
| 75 |
-
|
| 76 |
-
image = image.resize(
|
| 77 |
-
output_resolution, resample=lanczos if scale_final < 1 else bicubic
|
| 78 |
-
)
|
| 79 |
-
if depthmap is not None:
|
| 80 |
-
depthmap = cv2.resize(
|
| 81 |
-
depthmap,
|
| 82 |
-
output_resolution,
|
| 83 |
-
fx=scale_final,
|
| 84 |
-
fy=scale_final,
|
| 85 |
-
interpolation=cv2.INTER_NEAREST,
|
| 86 |
-
)
|
| 87 |
-
|
| 88 |
-
camera_intrinsics = camera_matrix_of_crop(
|
| 89 |
-
camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
return image.to_pil(), depthmap, camera_intrinsics
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def camera_matrix_of_crop(
|
| 96 |
-
input_camera_matrix,
|
| 97 |
-
input_resolution,
|
| 98 |
-
output_resolution,
|
| 99 |
-
scaling=1,
|
| 100 |
-
offset_factor=0.5,
|
| 101 |
-
offset=None,
|
| 102 |
-
):
|
| 103 |
-
|
| 104 |
-
margins = np.asarray(input_resolution) * scaling - output_resolution
|
| 105 |
-
assert np.all(margins >= 0.0)
|
| 106 |
-
if offset is None:
|
| 107 |
-
offset = offset_factor * margins
|
| 108 |
-
|
| 109 |
-
output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
|
| 110 |
-
output_camera_matrix_colmap[:2, :] *= scaling
|
| 111 |
-
output_camera_matrix_colmap[:2, 2] -= offset
|
| 112 |
-
output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
|
| 113 |
-
|
| 114 |
-
return output_camera_matrix
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
|
| 118 |
-
"""
|
| 119 |
-
Return a crop of the input view.
|
| 120 |
-
"""
|
| 121 |
-
image = ImageList(image)
|
| 122 |
-
l, t, r, b = crop_bbox
|
| 123 |
-
|
| 124 |
-
image = image.crop((l, t, r, b))
|
| 125 |
-
depthmap = depthmap[t:b, l:r]
|
| 126 |
-
|
| 127 |
-
camera_intrinsics = camera_intrinsics.copy()
|
| 128 |
-
camera_intrinsics[0, 2] -= l
|
| 129 |
-
camera_intrinsics[1, 2] -= t
|
| 130 |
-
|
| 131 |
-
return image.to_pil(), depthmap, camera_intrinsics
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def bbox_from_intrinsics_in_out(
|
| 135 |
-
input_camera_matrix, output_camera_matrix, output_resolution
|
| 136 |
-
):
|
| 137 |
-
out_width, out_height = output_resolution
|
| 138 |
-
l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
|
| 139 |
-
crop_bbox = (l, t, l + out_width, t + out_height)
|
| 140 |
-
return crop_bbox
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/eval/dataset_utils/transforms.py
DELETED
|
@@ -1,78 +0,0 @@
|
|
| 1 |
-
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
-
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
-
#
|
| 4 |
-
# --------------------------------------------------------
|
| 5 |
-
|
| 6 |
-
import torchvision.transforms as tvf
|
| 7 |
-
|
| 8 |
-
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
|
| 15 |
-
if isinstance(value, (int, float)):
|
| 16 |
-
if value < 0:
|
| 17 |
-
raise ValueError(f"If is a single number, it must be non negative.")
|
| 18 |
-
value = [center - float(value), center + float(value)]
|
| 19 |
-
if clip_first_on_zero:
|
| 20 |
-
value[0] = max(value[0], 0.0)
|
| 21 |
-
elif isinstance(value, (tuple, list)) and len(value) == 2:
|
| 22 |
-
value = [float(value[0]), float(value[1])]
|
| 23 |
-
else:
|
| 24 |
-
raise TypeError(f"should be a single number or a list/tuple with length 2.")
|
| 25 |
-
|
| 26 |
-
if not bound[0] <= value[0] <= value[1] <= bound[1]:
|
| 27 |
-
raise ValueError(f"values should be between {bound}, but got {value}.")
|
| 28 |
-
|
| 29 |
-
if value[0] == value[1] == center:
|
| 30 |
-
return None
|
| 31 |
-
else:
|
| 32 |
-
return tuple(value)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
import torch
|
| 36 |
-
import torchvision.transforms.functional as F
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def SeqColorJitter():
|
| 40 |
-
"""
|
| 41 |
-
Return a color jitter transform with same random parameters
|
| 42 |
-
"""
|
| 43 |
-
brightness = _check_input(0.5)
|
| 44 |
-
contrast = _check_input(0.5)
|
| 45 |
-
saturation = _check_input(0.5)
|
| 46 |
-
hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
|
| 47 |
-
|
| 48 |
-
fn_idx = torch.randperm(4)
|
| 49 |
-
brightness_factor = (
|
| 50 |
-
None
|
| 51 |
-
if brightness is None
|
| 52 |
-
else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
|
| 53 |
-
)
|
| 54 |
-
contrast_factor = (
|
| 55 |
-
None
|
| 56 |
-
if contrast is None
|
| 57 |
-
else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
|
| 58 |
-
)
|
| 59 |
-
saturation_factor = (
|
| 60 |
-
None
|
| 61 |
-
if saturation is None
|
| 62 |
-
else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
|
| 63 |
-
)
|
| 64 |
-
hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
|
| 65 |
-
|
| 66 |
-
def _color_jitter(img):
|
| 67 |
-
for fn_id in fn_idx:
|
| 68 |
-
if fn_id == 0 and brightness_factor is not None:
|
| 69 |
-
img = F.adjust_brightness(img, brightness_factor)
|
| 70 |
-
elif fn_id == 1 and contrast_factor is not None:
|
| 71 |
-
img = F.adjust_contrast(img, contrast_factor)
|
| 72 |
-
elif fn_id == 2 and saturation_factor is not None:
|
| 73 |
-
img = F.adjust_saturation(img, saturation_factor)
|
| 74 |
-
elif fn_id == 3 and hue_factor is not None:
|
| 75 |
-
img = F.adjust_hue(img, hue_factor)
|
| 76 |
-
return ImgNorm(img)
|
| 77 |
-
|
| 78 |
-
return _color_jitter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/eval/eval_7andN.py
DELETED
|
@@ -1,497 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
|
| 4 |
-
# Ensure project root is on sys.path for absolute imports like `vggt.*`
|
| 5 |
-
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
|
| 6 |
-
if ROOT_DIR not in sys.path:
|
| 7 |
-
sys.path.insert(0, ROOT_DIR)
|
| 8 |
-
|
| 9 |
-
import time
|
| 10 |
-
import torch
|
| 11 |
-
import argparse
|
| 12 |
-
import numpy as np
|
| 13 |
-
import open3d as o3d
|
| 14 |
-
import os.path as osp
|
| 15 |
-
from torch.utils.data import DataLoader
|
| 16 |
-
from torch.utils.data._utils.collate import default_collate
|
| 17 |
-
from tqdm import tqdm
|
| 18 |
-
from collections import defaultdict
|
| 19 |
-
import torchvision.transforms as transforms
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def get_args_parser():
|
| 23 |
-
parser = argparse.ArgumentParser("3D Reconstruction evaluation", add_help=False)
|
| 24 |
-
parser.add_argument(
|
| 25 |
-
"--ckpt_path",
|
| 26 |
-
type=str,
|
| 27 |
-
default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
|
| 28 |
-
help="ckpt name",
|
| 29 |
-
)
|
| 30 |
-
parser.add_argument("--device", type=str, default="cuda:0", help="device")
|
| 31 |
-
parser.add_argument("--model_name", type=str, default="VGGT")
|
| 32 |
-
parser.add_argument(
|
| 33 |
-
"--conf_thresh", type=float, default=0.0, help="confidence threshold"
|
| 34 |
-
)
|
| 35 |
-
parser.add_argument(
|
| 36 |
-
"--output_dir",
|
| 37 |
-
type=str,
|
| 38 |
-
default="/home/sy/code/FastVGGT/eval_results",
|
| 39 |
-
help="value for outdir",
|
| 40 |
-
)
|
| 41 |
-
parser.add_argument("--size", type=int, default=518)
|
| 42 |
-
parser.add_argument("--revisit", type=int, default=1, help="revisit times")
|
| 43 |
-
parser.add_argument("--freeze", action="store_true")
|
| 44 |
-
parser.add_argument("--use_proj", action="store_true")
|
| 45 |
-
parser.add_argument(
|
| 46 |
-
"--merging", type=int, default=0, help="VGGT aggregator merging steps"
|
| 47 |
-
)
|
| 48 |
-
parser.add_argument("--kf", type=int, default=2, help="key frame")
|
| 49 |
-
return parser
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def _umeyama_alignment(src, dst, with_scale=True):
    """Estimate the similarity transform that maps ``src`` onto ``dst``.

    Args:
        src: ``(N, dim)`` array of source points.
        dst: ``(N, dim)`` array of target points; row ``i`` is treated as the
            correspondence of ``src[i]``.
        with_scale: When False the scale is fixed to 1.0.

    Returns:
        Tuple ``(s, R, t)`` such that ``s * R @ src[i] + t`` approximates
        ``dst[i]``.
    """
    assert src.shape == dst.shape
    num_points, dim = src.shape

    mu_src = src.mean(axis=0)
    mu_dst = dst.mean(axis=0)
    src_c = src - mu_src
    dst_c = dst - mu_dst

    sigma = dst_c.T @ src_c / num_points  # (dim, dim) cross-covariance

    U, D, Vt = np.linalg.svd(sigma)

    # Flip the last axis when needed so that R is a proper rotation (det = +1).
    S = np.eye(dim)
    if np.linalg.det(U) * np.linalg.det(Vt) < 0:
        S[-1, -1] = -1

    R = U @ S @ Vt

    if with_scale:
        var_src = (src_c**2).sum() / num_points
        s = (D * S.diagonal()).sum() / var_src
    else:
        s = 1.0

    t = mu_dst - s * R @ mu_src
    return s, R, t


def _summarize_scene_logs(log_text, scene_infer_times):
    """Build the summary appended to ``logs_all.txt``.

    Args:
        log_text: Full content of ``logs.txt``; every line matched by the
            module-level ``regex`` contributes to the averaged metrics.
        scene_infer_times: Mapping ``scene_id -> list of inference times (ms)``.

    Returns:
        A string holding one ``Time_avg_ms`` line per scene followed by the
        ``mean`` metrics line.
    """
    metrics = defaultdict(list)
    for line in log_text.strip().split("\n"):
        match = regex.match(line)
        if not match:
            continue
        data = match.groupdict()
        # 'scene_id' is an identifier, not a metric.
        for key, value in data.items():
            if key != "scene_id":
                metrics[key].append(float(value))
        metrics["nc"].append((float(data["nc1"]) + float(data["nc2"])) / 2)
        metrics["nc_med"].append(
            (float(data["nc1_med"]) + float(data["nc2_med"])) / 2
        )

    mean_metrics = {
        metric: sum(values) / len(values) for metric, values in metrics.items()
    }

    print_str = f"{'mean'.ljust(20)}: "
    for m_name, m_value in mean_metrics.items():
        print_str = print_str + f"{m_name}: {m_value:.3f} | "
    print_str = print_str + "\n"

    time_lines = [
        f"Idx: {sid}, Time_avg_ms: {np.mean(times):.2f}"
        for sid, times in scene_infer_times.items()
        if len(times) > 0
    ]
    time_block = "\n".join(time_lines) + ("\n" if time_lines else "")

    return time_block + print_str


def main(args):
    """Evaluate VGGT point-map reconstruction on the 7-Scenes and NRGBD test splits.

    For every sample of every dataset the model is run once, the predicted and
    ground-truth point maps are center-cropped, masked, optionally aligned with
    Umeyama (``--use_proj``), refined with ICP, and scored with ``accuracy`` /
    ``completion``. One line per sample is appended to
    ``<output_dir>/<kf>/<dataset>/logs.txt``; after each dataset a
    ``logs_all.txt`` with per-scene timing and mean metrics is written.

    Args:
        args: Parsed namespace from ``get_args_parser``. Fields read here:
            ``size``, ``kf``, ``device``, ``merging``, ``ckpt_path``,
            ``output_dir``, ``revisit`` and ``use_proj``.

    Raises:
        NotImplementedError: If ``args.size`` is not 224, 512 or 518.
        ValueError: If a collated sample is not a list of view dicts with ``img``.
        KeyError: If the model output has neither ``pose_enc`` nor ``world_points``.
    """
    from data import SevenScenes, NRGBD
    from utils import accuracy, completion
    from vggt.models.vggt import VGGT
    from criterion import Regr3D_t_ScaleShiftInv, L21

    resolutions = {512: (512, 384), 224: 224, 518: (518, 392)}
    if args.size not in resolutions:
        raise NotImplementedError(
            f"unsupported --size {args.size}; expected one of {sorted(resolutions)}"
        )
    resolution = resolutions[args.size]

    common_kwargs = dict(
        split="test",
        resolution=resolution,
        num_seq=1,
        full_video=True,
        kf_every=args.kf,
    )
    datasets_all = {
        "7scenes": SevenScenes(ROOT="/data/sy/7scenes", **common_kwargs),
        "NRGBD": NRGBD(ROOT="/data/sy/neural_rgbd_data", **common_kwargs),
    }

    device = args.device

    # Load VGGT model. strict=False because enable_point=True changes the
    # set of expected parameters relative to the checkpoint.
    model = VGGT(merging=args.merging, enable_point=True)
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt, strict=False)
    del ckpt

    # Place the model on the same device the inputs are moved to, in bf16.
    model = model.to(device).eval()
    model = model.to(torch.bfloat16)

    run_dir = osp.join(args.output_dir, f"{args.kf}")
    os.makedirs(run_dir, exist_ok=True)

    criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)

    # Autocast dtype depends only on the GPU, so compute it once.
    amp_dtype = (
        torch.bfloat16
        if torch.cuda.get_device_capability()[0] >= 8
        else torch.float16
    )

    # View entries that are left where they are instead of being moved to `device`.
    ignore_keys = {
        "depthmap",
        "dataset",
        "label",
        "instance",
        "idx",
        "true_shape",
        "rng",
    }
    crop_half = 112  # half side of the central crop that is evaluated
    max_points = 999999  # cap on points per cloud before ICP

    with torch.no_grad():
        for name_data, dataset in datasets_all.items():
            save_path = osp.join(run_dir, name_data)
            os.makedirs(save_path, exist_ok=True)
            log_file = osp.join(save_path, "logs.txt")

            scene_infer_times = defaultdict(list)

            for data_idx in tqdm(range(len(dataset))):
                batch = default_collate([dataset[data_idx]])
                for view in batch:
                    for name in view.keys():  # pseudo_focal
                        if name in ignore_keys:
                            continue
                        if isinstance(view[name], (tuple, list)):
                            view[name] = [
                                x.to(device, non_blocking=True) for x in view[name]
                            ]
                        else:
                            view[name] = view[name].to(device, non_blocking=True)

                # The loop above already requires a sequence of view dicts, so
                # fail with a clear message instead of a later NameError.
                if not (
                    isinstance(batch, list)
                    and all(isinstance(v, dict) and "img" in v for v in batch)
                ):
                    raise ValueError(
                        "expected the collated sample to be a list of view dicts "
                        "that each contain an 'img' entry"
                    )

                with torch.cuda.amp.autocast(dtype=amp_dtype):
                    # Map images from [-1, 1] to [0, 1].
                    for view in batch:
                        view["img"] = (view["img"] + 1.0) / 2.0
                    # Gather all `img` tensors into a single tensor of shape [N, C, H, W]
                    imgs_tensor = torch.cat([v["img"] for v in batch], dim=0)

                    # Synchronize around the forward pass so the wall-clock
                    # time covers the GPU work.
                    torch.cuda.synchronize()
                    start = time.time()
                    predictions = model(imgs_tensor)
                    torch.cuda.synchronize()
                    end = time.time()
                inference_time_ms = (end - start) * 1000
                print(f"Inference time: {inference_time_ms:.2f}ms")

                # Wrap model outputs per-view to align with batch later
                if "pose_enc" in predictions:
                    num_views = predictions["pose_enc"].shape[1]
                elif "world_points" in predictions:
                    num_views = predictions["world_points"].shape[1]
                else:
                    raise KeyError(
                        "predictions is missing a key to infer sequence length"
                    )

                preds = []
                for s in range(num_views):
                    res = {
                        "pts3d_in_other_view": predictions["world_points"][:, s],
                        "conf": predictions["world_points_conf"][:, s],
                        "depth": predictions["depth"][:, s],
                        "depth_conf": predictions["depth_conf"][:, s],
                        "camera_pose": predictions["pose_enc"][:, s, :],
                    }
                    if s < len(batch) and "valid_mask" in batch[s]:
                        res["valid_mask"] = batch[s]["valid_mask"]
                    if "track" in predictions:
                        res["track"] = predictions["track"][:, s]
                        res["vis"] = (
                            predictions["vis"][:, s] if "vis" in predictions else None
                        )
                        res["track_conf"] = (
                            predictions["conf"][:, s]
                            if "conf" in predictions
                            else None
                        )
                    preds.append(res)

                # With revisits only the last pass over the sequence is scored.
                if args.revisit > 1:
                    valid_length = len(preds) // args.revisit
                    preds = preds[-valid_length:]
                    batch = batch[-valid_length:]

                # Evaluation
                print(f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}")
                gt_pts, pred_pts, _gt_factor, _pr_factor, _masks, _monitoring = (
                    criterion.get_all_pts3d_t(batch, preds)
                )

                images_all = []
                pts_all = []
                pts_gt_all = []
                masks_all = []

                for j, view in enumerate(batch):
                    image = view["img"].permute(0, 2, 3, 1).cpu().numpy()[0]
                    mask = view["valid_mask"].cpu().numpy()[0]
                    pts = pred_pts[j].cpu().numpy()[0]
                    pts_gt = gt_pts[j].detach().cpu().numpy()[0]

                    # Keep only the central (2 * crop_half)^2 window.
                    H, W = image.shape[:2]
                    cx = W // 2
                    cy = H // 2
                    l, t = cx - crop_half, cy - crop_half
                    r, b = cx + crop_half, cy + crop_half

                    images_all.append(image[t:b, l:r])
                    masks_all.append(mask[t:b, l:r])
                    pts_all.append(pts[t:b, l:r])
                    pts_gt_all.append(pts_gt[t:b, l:r])

                images_all = np.stack(images_all, axis=0)
                pts_all = np.stack(pts_all, axis=0)
                pts_gt_all = np.stack(pts_gt_all, axis=0)
                masks_all = np.stack(masks_all, axis=0)

                # The scene id is the label of the last view without its final
                # path component.
                scene_id = batch[-1]["label"][0].rsplit("/", 1)[0]
                # Record inference time per scene
                scene_infer_times[scene_id].append(float(inference_time_ms))

                valid = masks_all > 0
                pts_all_masked = pts_all[valid]  # (N, 3)
                pts_gt_all_masked = pts_gt_all[valid]  # (N, 3)
                images_all_masked = images_all[valid]  # (N, 3)

                # Drop whole points (rows) with any non-finite coordinate so
                # that points and colors stay aligned and 3-column.
                finite_pred = np.isfinite(pts_all_masked).all(axis=-1)
                pts_all_masked = pts_all_masked[finite_pred]
                images_all_masked = images_all_masked[finite_pred]

                finite_gt = np.isfinite(pts_gt_all_masked).all(axis=-1)
                pts_gt_all_masked = pts_gt_all_masked[finite_gt]

                # If number of points exceeds threshold, sample by points
                if pts_all_masked.shape[0] > max_points:
                    sample_indices = np.random.choice(
                        pts_all_masked.shape[0], max_points, replace=False
                    )
                    pts_all_masked = pts_all_masked[sample_indices]
                    images_all_masked = images_all_masked[sample_indices]

                # The GT cloud is subsampled independently of the prediction.
                if pts_gt_all_masked.shape[0] > max_points:
                    sample_indices_gt = np.random.choice(
                        pts_gt_all_masked.shape[0], max_points, replace=False
                    )
                    pts_gt_all_masked = pts_gt_all_masked[sample_indices_gt]

                if args.use_proj:
                    # NOTE(review): Umeyama treats row i of both arrays as a
                    # correspondence, but prediction and GT are filtered and
                    # subsampled independently above, so rows are not
                    # guaranteed to correspond (or to have equal counts).
                    # Confirm the intended protocol before relying on --use_proj.
                    s, R, t = _umeyama_alignment(
                        pts_all_masked, pts_gt_all_masked, with_scale=True
                    )
                    pts_all_masked = (s * (R @ pts_all_masked.T)).T + t  # (N,3)

                pcd = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(pts_all_masked)
                pcd.colors = o3d.utility.Vector3dVector(images_all_masked)

                pcd_gt = o3d.geometry.PointCloud()
                pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked)
                pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked)

                # Refine the alignment with point-to-point ICP starting from
                # the identity.
                trans_init = np.eye(4)
                threshold = 0.1
                reg_p2p = o3d.pipelines.registration.registration_icp(
                    pcd,
                    pcd_gt,
                    threshold,
                    trans_init,
                    o3d.pipelines.registration.TransformationEstimationPointToPoint(),
                )

                pcd = pcd.transform(reg_p2p.transformation)
                pcd.estimate_normals()
                pcd_gt.estimate_normals()

                gt_normal = np.asarray(pcd_gt.normals)
                pred_normal = np.asarray(pcd.normals)

                acc, acc_med, nc1, nc1_med = accuracy(
                    pcd_gt.points, pcd.points, gt_normal, pred_normal
                )
                comp, comp_med, nc2, nc2_med = completion(
                    pcd_gt.points, pcd.points, gt_normal, pred_normal
                )

                result_line = (
                    f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, "
                    f"NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, "
                    f"Compc_med: {comp_med}, NC1c_med: {nc1_med}, "
                    f"NC2c_med: {nc2_med}"
                )
                print(result_line)
                # Close the log after every sample instead of leaking a handle.
                with open(log_file, "a") as log_f:
                    print(result_line, file=log_f)

                # release cuda memory
                torch.cuda.empty_cache()

            # logs.txt is opened in append mode above, so lines left by earlier
            # runs in the same output directory are part of this summary too.
            to_write = ""
            if os.path.exists(log_file):
                with open(log_file, "r") as f_sub:
                    to_write += f_sub.read()

            summary = _summarize_scene_logs(to_write, scene_infer_times)
            with open(osp.join(save_path, "logs_all.txt"), "w") as f:
                f.write(to_write + summary)
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
from collections import defaultdict
|
| 476 |
-
import re
|
| 477 |
-
|
| 478 |
-
# Verbose-mode pattern for one per-sample result line of logs.txt, i.e. the
# "Idx: ..., Acc: ..., Comp: ..., NC1: ..., NC2: ... - Acc_med: ..., ..." line
# that main() prints. Each named group captures the raw text of one field up to
# the next comma; main() converts every group except `scene_id` with float().
# Because re.VERBOSE is used, literal whitespace and newlines inside the
# pattern are ignored and spacing is matched explicitly with `\s*`.
pattern = r"""
Idx:\s*(?P<scene_id>[^,]+),\s*
Acc:\s*(?P<acc>[^,]+),\s*
Comp:\s*(?P<comp>[^,]+),\s*
NC1:\s*(?P<nc1>[^,]+),\s*
NC2:\s*(?P<nc2>[^,]+)\s*-\s*
Acc_med:\s*(?P<acc_med>[^,]+),\s*
Compc_med:\s*(?P<comp_med>[^,]+),\s*
NC1c_med:\s*(?P<nc1_med>[^,]+),\s*
NC2c_med:\s*(?P<nc2_med>[^,]+)
"""

# Compiled once at import time; main() looks this name up when it summarizes logs.
regex = re.compile(pattern, re.VERBOSE)


# Script entry point: parse the command line and run the evaluation.
if __name__ == "__main__":
    parser = get_args_parser()
    args = parser.parse_args()

    main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/eval/eval_custom.py
DELETED
|
@@ -1,467 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
import os
|
| 6 |
-
import sys
|
| 7 |
-
import matplotlib.pyplot as plt
|
| 8 |
-
from scipy.spatial.transform import Rotation
|
| 9 |
-
|
| 10 |
-
# Ensure project root is in sys.path for absolute imports like `vggt.*`
|
| 11 |
-
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
|
| 12 |
-
if ROOT_DIR not in sys.path:
|
| 13 |
-
sys.path.insert(0, ROOT_DIR)
|
| 14 |
-
|
| 15 |
-
from vggt.models.vggt import VGGT
|
| 16 |
-
from vggt.utils.eval_utils import (
|
| 17 |
-
load_poses,
|
| 18 |
-
get_vgg_input_imgs,
|
| 19 |
-
get_sorted_image_paths,
|
| 20 |
-
build_frame_selection,
|
| 21 |
-
load_images_rgb,
|
| 22 |
-
infer_vggt_and_reconstruct,
|
| 23 |
-
evaluate_scene_and_save,
|
| 24 |
-
)
|
| 25 |
-
|
| 26 |
-
# Import pose visualization libraries (optional EVO support)
|
| 27 |
-
try:
|
| 28 |
-
from evo.core.trajectory import PoseTrajectory3D
|
| 29 |
-
import evo.tools.plot as plot
|
| 30 |
-
|
| 31 |
-
EVO_AVAILABLE = True
|
| 32 |
-
except ImportError:
|
| 33 |
-
# EVO is optional; we have a matplotlib-based fallback
|
| 34 |
-
EVO_AVAILABLE = False
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def visualize_predicted_poses(
    all_cam_to_world_mat, frame_ids, output_scene_dir, scene_name="custom_dataset"
):
    """
    Plot the predicted camera trajectory (no GT comparison required).

    The translation column of every camera-to-world matrix is projected onto
    the XZ plane and saved as ``predicted_trajectory.png`` inside
    ``output_scene_dir``. Any failure is reported on stdout together with a
    traceback and is not re-raised.

    Args:
        all_cam_to_world_mat: List of camera-to-world transform matrices; the
            code indexes them as ``[:, :3, 3]``.
        frame_ids: List of frame IDs. Not used by the plot; kept so existing
            callers keep working.
        output_scene_dir: Output directory (``str`` or ``Path``); created if
            it does not exist.
        scene_name: Scene name used in the plot title.
    """
    # Provide basic pose visualization even without EVO
    if not EVO_AVAILABLE:
        print("⚠️ EVO not installed; using basic matplotlib visualization")

    try:
        # Convert to numpy array
        poses_est = np.array(all_cam_to_world_mat)

        if len(poses_est) < 2:
            print("⚠️ Not enough poses to generate trajectory plot")
            return

        print("🎨 Generating pose trajectory visualization...")

        # Extract translation part
        positions = poses_est[:, :3, 3]  # shape: (N, 3)

        # Accept plain strings as well and make sure the target folder exists
        # before matplotlib tries to write into it.
        output_scene_dir = Path(output_scene_dir)
        output_scene_dir.mkdir(parents=True, exist_ok=True)
        pose_plot_path = output_scene_dir / "predicted_trajectory.png"

        # Create figure - show XZ-plane projection only
        fig, ax = plt.subplots(1, 1, figsize=(10, 8))
        try:
            ax.plot(
                positions[:, 0],
                positions[:, 2],
                "b-",
                linewidth=2,
                label="Predicted Trajectory",
            )
            ax.scatter(
                positions[0, 0], positions[0, 2], color="green", s=100, label="Start"
            )
            ax.scatter(
                positions[-1, 0], positions[-1, 2], color="red", s=100, label="End"
            )
            ax.set_xlabel("X (m)")
            ax.set_ylabel("Z (m)")
            ax.set_title(f"{scene_name} - XZ-plane projection")
            ax.legend()
            ax.grid(True, alpha=0.3)

            # Save image
            fig.savefig(pose_plot_path, dpi=300, bbox_inches="tight")
        finally:
            # Close this specific figure even when plotting or saving fails,
            # so figures do not pile up across calls.
            plt.close(fig)

        print(f"📊 Trajectory visualization saved: {pose_plot_path}")

    except Exception as e:
        # Visualization is best-effort: report the problem and carry on.
        print(f"⚠️ Failed to generate pose visualization: {e}")
        import traceback

        traceback.print_exc()
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def _write_ascii_ply(path, points, colors=None):
    """Write points (and optional per-point RGB colors) as an ASCII PLY file.

    Points with any NaN/inf coordinate are removed before the header is
    written, so the declared ``element vertex`` count always equals the number
    of vertex lines in the file.

    Args:
        path: Destination file path.
        points: ``(N, >=3)`` array; the first three columns are written.
        colors: Optional ``(N, >=3)`` array; the first three columns are
            clipped to ``[0, 255]`` and written as integers.

    Returns:
        Number of vertices written.
    """
    points = np.asarray(points)
    finite = np.isfinite(points).all(axis=1)
    points = points[finite]
    if colors is not None:
        colors = np.asarray(colors)[finite]

    with open(path, "w") as f:
        f.write("ply\n")
        f.write("format ascii 1.0\n")
        f.write(f"element vertex {len(points)}\n")
        f.write("property float x\n")
        f.write("property float y\n")
        f.write("property float z\n")
        if colors is not None:
            f.write("property uchar red\n")
            f.write("property uchar green\n")
            f.write("property uchar blue\n")
        f.write("end_header\n")

        if colors is None:
            f.writelines(
                f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n" for point in points
            )
        else:
            for point, color in zip(points, colors):
                r = int(np.clip(color[0], 0, 255))
                g = int(np.clip(color[1], 0, 255))
                b = int(np.clip(color[2], 0, 255))
                f.write(
                    f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f} {r} {g} {b}\n"
                )

    return len(points)


def _save_point_cloud(all_world_points, all_point_colors, points_output_path):
    """Merge per-frame points/colors, subsample to 100000 points and save a PLY.

    Failures are reported on stdout together with per-frame debug information
    and are not re-raised.
    """
    try:
        # Merge all frames' point clouds and colors
        merged_point_cloud = np.vstack(all_world_points)
        merged_colors = (
            np.vstack(all_point_colors).astype(np.uint8)
            if all_point_colors is not None and len(all_point_colors) > 0
            else None
        )
        print(
            f"📊 Merged point clouds: {len(all_world_points)} frames, total {len(merged_point_cloud)} points"
        )

        # If too many points, randomly sample 100000 points
        max_points = 100000
        if len(merged_point_cloud) > max_points:
            print(f"🔽 Too many points, randomly sampling {max_points} points...")
            indices = np.random.choice(
                len(merged_point_cloud), size=max_points, replace=False
            )
            merged_point_cloud = merged_point_cloud[indices]
            if merged_colors is not None:
                merged_colors = merged_colors[indices]
            print(f"✅ Sampling done, kept {len(merged_point_cloud)} points")

        # Save as PLY (with color)
        _write_ascii_ply(points_output_path, merged_point_cloud, merged_colors)
        print(f"💾 Point cloud saved to: {points_output_path}")

    except Exception as e:
        print(f"⚠️ Error saving point cloud: {e}")
        # If merge fails, try to log per-frame info
        print("🔍 Point cloud debug info:")
        for i, frame_points in enumerate(all_world_points):
            print(
                f" Frame {i}: {frame_points.shape if hasattr(frame_points, 'shape') else type(frame_points)}"
            )
            if hasattr(frame_points, "shape") and len(frame_points.shape) >= 2:
                print(
                    f" Shape: {frame_points.shape}, Dtype: {frame_points.dtype}"
                )
                if frame_points.shape[0] > 0:
                    print(
                        f" Range: x[{np.min(frame_points[:, 0]):.3f}, {np.max(frame_points[:, 0]):.3f}] "
                        f"y[{np.min(frame_points[:, 1]):.3f}, {np.max(frame_points[:, 1]):.3f}] "
                        f"z[{np.min(frame_points[:, 2]):.3f}, {np.max(frame_points[:, 2]):.3f}]"
                    )


def main():
    """
    Evaluation script for a Custom Dataset.

    Parses the command line, loads the VGGT checkpoint, runs inference and
    reconstruction on the frames found in ``<data_path>/images`` and then
    either evaluates against ground truth (``--enable_evaluation``; needs
    ``<data_path>/pose`` and ``<data_path>/gt_ply``) or only saves the
    estimated poses and the reconstructed point cloud. Problems are reported
    on stdout and the function returns early; nothing is raised to the caller
    from the processing stage.
    """
    parser = argparse.ArgumentParser(
        description="Run FastVGGT evaluation on a Custom Dataset"
    )

    # Required: dataset path
    parser.add_argument(
        "--data_path",
        type=Path,
        required=True,
        help=(
            "Dataset path containing subfolder 'images' "
            "(plus 'pose' and 'gt_ply' when --enable_evaluation is set)"
        ),
    )

    # Optional: enable evaluation
    parser.add_argument(
        "--enable_evaluation",
        action="store_true",
        help="Enable evaluation (requires pose and ply data)",
    )

    # Output path
    parser.add_argument(
        "--output_path",
        type=Path,
        default="./eval_results_custom",
        help="Output path for evaluation results",
    )

    # Model parameters
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
        help="Model checkpoint file path",
    )

    parser.add_argument("--merging", type=int, default=0, help="Merging parameter")

    # Processing parameters
    parser.add_argument(
        "--input_frame",
        type=int,
        default=200,
        help="Maximum number of frames to process per scene",
    )

    parser.add_argument(
        "--depth_conf_thresh",
        type=float,
        default=3.0,
        help="Depth confidence threshold to filter low-confidence depth values",
    )

    # Evaluation parameters (only used when evaluation is enabled)
    parser.add_argument(
        "--chamfer_max_dist",
        type=float,
        default=0.5,
        help="Maximum distance threshold used in Chamfer Distance computation",
    )

    parser.add_argument("--plot", action="store_true", help="Whether to generate plots")

    parser.add_argument(
        "--vis_attn_map",
        action="store_true",
        help="Visualize attention maps during inference",
    )

    args = parser.parse_args()
    torch.manual_seed(33)

    # Check data path exists
    if not args.data_path.exists():
        print(f"❌ Error: Data path does not exist: {args.data_path}")
        return

    # Check required subdirectories
    color_dir = args.data_path / "images"
    pose_dir = args.data_path / "pose"
    gt_ply_dir = args.data_path / "gt_ply"

    if not color_dir.exists():
        print(f"❌ Error: images directory does not exist: {color_dir}")
        return

    print(f"📁 Dataset path: {args.data_path}")

    # If evaluation is enabled, check pose and gt_ply directories
    if args.enable_evaluation:
        if not pose_dir.exists():
            print(f"❌ Error: Evaluation requires pose directory: {pose_dir}")
            return
        if not gt_ply_dir.exists():
            print(f"❌ Error: Evaluation requires gt_ply directory: {gt_ply_dir}")
            return
        print("📊 Evaluation will use Ground Truth")
    else:
        print("🏃 Inference only, no evaluation")

    # Create output directory
    args.output_path.mkdir(parents=True, exist_ok=True)
    output_scene_dir = args.output_path / "custom_dataset"

    # Check if already processed
    if (output_scene_dir / "metrics.json").exists() and args.enable_evaluation:
        print(
            f"⚠️ Results already exist, skipping: {output_scene_dir / 'metrics.json'}"
        )
        return

    # Force use of bf16 dtype
    dtype = torch.bfloat16

    # Load VGGT model; strict=False tolerates key mismatches with the checkpoint.
    print(f"🔄 Loading model: {args.ckpt_path}")
    model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt, strict=False)
    model = model.cuda().eval()
    model = model.to(torch.bfloat16)
    print("✅ Model loaded")

    # Load scene data
    image_paths = get_sorted_image_paths(color_dir)
    if len(image_paths) == 0:
        print(f"❌ Error: No images found in {color_dir}")
        return

    print(f"🖼️ Found {len(image_paths)} images")

    # Ground-truth poses are only needed (and loaded) when evaluating.
    first_gt_pose = None
    c2ws = None

    if args.enable_evaluation:
        poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_dir)
        if (
            poses_gt is None
            or first_gt_pose is None
            or available_pose_frame_ids is None
        ):
            print("❌ Error: Failed to load pose data")
            return
        print(f"📐 Loaded {len(poses_gt)} poses")

        # Use pose data for frame selection
        selected_frame_ids, selected_image_paths, selected_pose_indices = (
            build_frame_selection(
                image_paths, available_pose_frame_ids, args.input_frame
            )
        )
        c2ws = poses_gt[selected_pose_indices]
        image_paths = selected_image_paths
    else:
        # Simply take the first N frames
        num_frames = min(len(image_paths), args.input_frame)
        selected_frame_ids = list(range(num_frames))
        image_paths = image_paths[:num_frames]

    print(f"📋 Selected {len(image_paths)} frames for processing")

    try:
        # Load images
        print("🔄 Loading images...")
        images = load_images_rgb(image_paths)

        if not images or len(images) < 3:
            print("❌ Error: Not enough valid images (need at least 3)")
            return

        frame_ids = selected_frame_ids
        images_array = np.stack(images)
        vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
        print(f"📐 Image patch dimensions: {patch_width}x{patch_height}")

        # Update attention layer patch dimensions in the model
        model.update_patch_dimensions(patch_width, patch_height)

        # Inference + Reconstruction
        print("🚀 Start inference and reconstruction...")
        (
            extrinsic_np,
            intrinsic_np,
            all_world_points,
            all_point_colors,
            all_cam_to_world_mat,
            inference_time_ms,
        ) = infer_vggt_and_reconstruct(
            model, vgg_input, dtype, args.depth_conf_thresh, image_paths
        )
        print(f"⏱️ Inference time: {inference_time_ms:.2f}ms")

        # Check results
        if not all_cam_to_world_mat or not all_world_points:
            print("❌ Error: Failed to obtain valid camera poses or point clouds")
            return

        # Evaluation and saving
        if args.enable_evaluation:
            print("📊 Start evaluation...")
            metrics = evaluate_scene_and_save(
                "custom_dataset",
                c2ws,
                first_gt_pose,
                frame_ids,
                all_cam_to_world_mat,
                all_world_points,
                output_scene_dir,
                gt_ply_dir,
                args.chamfer_max_dist,
                inference_time_ms,
                args.plot,
            )
            if metrics is not None:
                print("📈 Evaluation results:")
                reported_keys = (
                    "chamfer_distance",
                    "ate",
                    "are",
                    "rpe_rot",
                    "rpe_trans",
                    "inference_time_ms",
                )
                for key, value in metrics.items():
                    if key in reported_keys:
                        print(f" {key}: {float(value):.4f}")
        else:
            # Save reconstruction only, no evaluation
            print("💾 Saving reconstruction...")
            output_scene_dir.mkdir(parents=True, exist_ok=True)

            # Save camera poses: one "# Frame <id>" header, then the matrix rows.
            poses_output_path = output_scene_dir / "estimated_poses.txt"
            with open(poses_output_path, "w") as f:
                for i, pose in enumerate(all_cam_to_world_mat):
                    f.write(f"# Frame {frame_ids[i]}\n")
                    for row in pose:
                        f.write(" ".join(map(str, row)) + "\n")
                    f.write("\n")

            # Save point cloud
            if all_world_points:
                _save_point_cloud(
                    all_world_points,
                    all_point_colors,
                    output_scene_dir / "reconstructed_points.ply",
                )

            print(f"📁 Results saved to: {output_scene_dir}")

        # Visualize predicted pose trajectory (both branches)
        if args.plot:
            visualize_predicted_poses(
                all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset"
            )

        print("🎉 Done!")

    except Exception as e:
        # Top-level boundary of the script: report and print the traceback.
        print(f"❌ Error occurred during processing: {e}")
        import traceback

        traceback.print_exc()
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
if __name__ == "__main__":
|
| 467 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/eval/eval_scannet.py
DELETED
|
@@ -1,208 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
import os
|
| 6 |
-
import sys
|
| 7 |
-
|
| 8 |
-
# Ensure project root is in sys.path for absolute imports like `vggt.*`
|
| 9 |
-
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
|
| 10 |
-
if ROOT_DIR not in sys.path:
|
| 11 |
-
sys.path.insert(0, ROOT_DIR)
|
| 12 |
-
|
| 13 |
-
from vggt.models.vggt import VGGT
|
| 14 |
-
from vggt.utils.eval_utils import (
|
| 15 |
-
load_poses,
|
| 16 |
-
get_vgg_input_imgs,
|
| 17 |
-
get_sorted_image_paths,
|
| 18 |
-
get_all_scenes,
|
| 19 |
-
build_frame_selection,
|
| 20 |
-
load_images_rgb,
|
| 21 |
-
infer_vggt_and_reconstruct,
|
| 22 |
-
evaluate_scene_and_save,
|
| 23 |
-
compute_average_metrics_and_save,
|
| 24 |
-
)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` is a known argparse pitfall: ``bool("False")`` is True
        because any non-empty string is truthy. The empty string still maps
        to False, as it did with plain ``bool``.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    # ------------------------------------------------------------------
    # Command-line interface
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", type=Path, default="/data/scannetv2/process_scannet"
    )
    parser.add_argument(
        "--gt_ply_dir",
        type=Path,
        default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
    )
    parser.add_argument("--output_path", type=Path, default="./eval_results")
    parser.add_argument("--merging", type=int, default=None)
    parser.add_argument("--plot", type=_str2bool, default=True)
    parser.add_argument(
        "--depth_conf_thresh",
        type=float,
        default=3.0,
        help="Depth confidence threshold for filtering low confidence depth values",
    )
    parser.add_argument(
        "--chamfer_max_dist",
        type=float,
        default=0.5,
        help="Maximum distance threshold in Chamfer Distance computation, distances exceeding this value will be clipped",
    )
    parser.add_argument(
        "--input_frame",
        type=int,
        default=200,
        help="Maximum number of frames selected for processing per scene",
    )
    parser.add_argument(
        "--num_scenes",
        type=int,
        default=50,
        help="Maximum number of scenes to evaluate",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="./ckpt/model_tracker_fixed_e20.pt",
        help="Path to the model checkpoint file",
    )
    parser.add_argument(
        "--vis_attn_map",
        action="store_true",
        help="Whether to visualize attention maps during inference",
    )
    args = parser.parse_args()
    torch.manual_seed(33)

    # Scene sampling
    scannet_scenes = get_all_scenes(args.data_dir, args.num_scenes)
    print(f"Evaluate {len(scannet_scenes)} scenes")

    all_scenes_metrics = {"scenes": {}, "average": {}}
    # Force use of bf16 data type
    dtype = torch.bfloat16

    # Load VGGT model
    model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    # strict=False tolerates key mismatches, so report them instead of
    # silently running with partially initialised weights.
    incompat = model.load_state_dict(ckpt, strict=False)
    if incompat.missing_keys or incompat.unexpected_keys:
        print(
            f"Checkpoint key mismatch: {len(incompat.missing_keys)} missing, "
            f"{len(incompat.unexpected_keys)} unexpected"
        )
    model = model.cuda().eval()
    model = model.to(dtype)

    # Metrics copied from each scene's result into the summary.
    reported_keys = (
        "chamfer_distance",
        "ate",
        "are",
        "rpe_rot",
        "rpe_trans",
        "inference_time_ms",
    )

    # Process each scene
    for scene in scannet_scenes:
        scene_dir = args.data_dir / f"{scene}"
        output_scene_dir = args.output_path / f"input_frame_{args.input_frame}" / scene
        # Skip scenes that already have a metrics file.
        if (output_scene_dir / "metrics.json").exists():
            continue

        # Load scene data
        images_dir = scene_dir / "color"
        pose_path = scene_dir / "pose"
        image_paths = get_sorted_image_paths(images_dir)
        poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_path)
        if (
            poses_gt is None
            or first_gt_pose is None
            or available_pose_frame_ids is None
        ):
            print(f"Skipping scene {scene}: no pose data")
            continue

        # Frame filtering
        selected_frame_ids, selected_image_paths, selected_pose_indices = (
            build_frame_selection(
                image_paths, available_pose_frame_ids, args.input_frame
            )
        )

        # Get corresponding poses
        c2ws = poses_gt[selected_pose_indices]
        image_paths = selected_image_paths

        if len(image_paths) == 0:
            print(f"No images found in {images_dir}")
            continue

        print("🚩Processing", scene, f"Found {len(image_paths)} images")
        all_cam_to_world_mat = []
        all_world_points = []

        try:
            # Load images
            images = load_images_rgb(image_paths)

            if not images or len(images) < 3:
                print(f"Skipping {scene}: insufficient valid images")
                continue

            frame_ids = selected_frame_ids
            images_array = np.stack(images)
            vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
            print(f"Patch dimensions: {patch_width}x{patch_height}")

            # Update model attention layers with dynamic patch dimensions
            model.update_patch_dimensions(patch_width, patch_height)

            # Inference + Reconstruction
            (
                extrinsic_np,
                intrinsic_np,
                all_world_points,
                all_point_colors,
                all_cam_to_world_mat,
                inference_time_ms,
            ) = infer_vggt_and_reconstruct(
                model, vgg_input, dtype, args.depth_conf_thresh, image_paths
            )
            print(f"Inference time: {inference_time_ms:.2f}ms")

            # Process results
            if not all_cam_to_world_mat or not all_world_points:
                print(
                    f"Skipping {scene}: failed to obtain valid camera poses or point clouds"
                )
                continue

            # Evaluate and save
            metrics = evaluate_scene_and_save(
                scene,
                c2ws,
                first_gt_pose,
                frame_ids,
                all_cam_to_world_mat,
                all_world_points,
                output_scene_dir,
                args.gt_ply_dir,
                args.chamfer_max_dist,
                inference_time_ms,
                args.plot,
            )
            if metrics is not None:
                all_scenes_metrics["scenes"][scene] = {
                    key: float(value)
                    for key, value in metrics.items()
                    if key in reported_keys
                }
                print("Complete metrics", all_scenes_metrics["scenes"][scene])

        except Exception as e:
            # Best-effort evaluation: one failing scene must not abort the run.
            print(f"Error processing scene {scene}: {e}")
            import traceback

            traceback.print_exc()

    # Summarize average metrics and save
    compute_average_metrics_and_save(
        all_scenes_metrics,
        args.output_path,
        args.input_frame,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/eval/utils.py
DELETED
|
@@ -1,142 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
from scipy.spatial import cKDTree as KDTree
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Back-project a depth map into 3D points in the camera frame.

    Args:
        - depthmap (HxW array): per-pixel depth along the camera z axis
        - camera_intrinsics: a 3x3 matrix; its skew entries [0, 1] and [1, 0] must be zero
        - pseudo_focal (HxW array, optional): per-pixel focal length that
          replaces both focal entries of the intrinsics when given
    Returns:
        pointmap of camera-frame coordinates (HxWx3 float32 array), and a
        boolean HxW mask that is True where the depth is strictly positive.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Only skew-free intrinsics are supported.
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0

    if pseudo_focal is not None:
        assert pseudo_focal.shape == (H, W)
        focal_u = focal_v = pseudo_focal
    else:
        focal_u = camera_intrinsics[0, 0]
        focal_v = camera_intrinsics[1, 1]
    center_u = camera_intrinsics[0, 2]
    center_v = camera_intrinsics[1, 2]

    # Pixel grid: `cols` holds the x index and `rows` the y index of every pixel.
    cols, rows = np.meshgrid(np.arange(W), np.arange(H))
    points = np.stack(
        (
            (cols - center_u) * depthmap / focal_u,
            (rows - center_v) * depthmap / focal_v,
            depthmap,
        ),
        axis=-1,
    ).astype(np.float32)

    return points, depthmap > 0.0
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def depthmap_to_absolute_camera_coordinates(
    depthmap, camera_intrinsics, camera_pose, **kw
):
    """
    Back-project a depth map and move the points into the world frame.

    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a cam2world matrix whose [:3, :3] block is read as the
          rotation and whose [:3, 3] column is read as the translation, or
          None to keep the points in the camera frame
        - **kw: accepted but not used by this function
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
    """
    points_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    # Without a pose the camera-frame points are returned unchanged.
    if camera_pose is None:
        return points_cam, valid_mask

    rotation = camera_pose[:3, :3]
    translation = camera_pose[:3, 3]

    # Rotate every pixel's 3-vector, then shift by the camera position.
    points_world = (
        np.einsum("ik, vuk -> vui", rotation, points_cam) + translation[None, None, :]
    )
    return points_world, valid_mask
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def completion_ratio(gt_points, rec_points, dist_th=0.05):
    """Return the fraction of ground-truth points that have a reconstructed
    point closer than `dist_th` (nearest-neighbour distance via a KD-tree
    built on `rec_points`)."""
    nearest_dist, _ = KDTree(rec_points).query(gt_points)
    covered = nearest_dist < dist_th
    return np.mean(covered.astype(np.float32))
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def accuracy(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """Measure how close the reconstruction lies to the ground truth.

    For every reconstructed point the nearest ground-truth point is looked up
    in a KD-tree built on `gt_points`.

    Returns:
        (mean, median) of those distances; when both normal arrays are given,
        additionally the mean and median absolute dot product between each
        reconstructed normal and the normal of its nearest ground-truth point.
    """
    nearest_dist, nearest_gt = KDTree(gt_points).query(rec_points, workers=-1)
    mean_dist = np.mean(nearest_dist)
    median_dist = np.median(nearest_dist)

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    # The absolute value makes the score independent of normal orientation sign.
    alignment = np.abs(np.sum(gt_normals[nearest_gt] * rec_normals, axis=-1))
    return mean_dist, median_dist, np.mean(alignment), np.median(alignment)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def completion(gt_points, rec_points, gt_normals=None, rec_normals=None):
    """Measure how well the reconstruction covers the ground truth.

    For every ground-truth point the nearest reconstructed point is looked up
    in a KD-tree built on `rec_points`.

    Returns:
        (mean, median) of those distances; when both normal arrays are given,
        additionally the mean and median absolute dot product between each
        ground-truth normal and the normal of its nearest reconstructed point.
    """
    nearest_dist, nearest_rec = KDTree(rec_points).query(gt_points, workers=-1)
    mean_dist = np.mean(nearest_dist)
    median_dist = np.median(nearest_dist)

    if gt_normals is None or rec_normals is None:
        return mean_dist, median_dist

    # The absolute value makes the score independent of normal orientation sign.
    alignment = np.abs(np.sum(gt_normals * rec_normals[nearest_rec], axis=-1))
    return mean_dist, median_dist, np.mean(alignment), np.median(alignment)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def compute_iou(pred_vox, target_vox):
    """Return the intersection-over-union of two voxel grids.

    Both arguments must provide ``get_voxels()`` returning objects with a
    ``grid_index`` attribute. Two voxels count as the same cell when their
    grid indices, rounded to 4 decimals, are equal.

    Returns:
        ``len(intersection) / len(union)`` as a float, or ``0.0`` when both
        grids are empty (previously this raised ZeroDivisionError).
    """

    def _occupied_cells(grid):
        # Tuples are hashable, so the cells can go into a set.
        return {tuple(np.round(voxel.grid_index, 4)) for voxel in grid.get_voxels()}

    pred_cells = _occupied_cells(pred_vox)
    target_cells = _occupied_cells(target_vox)

    union = pred_cells | target_cells
    if not union:
        # Nothing occupied on either side: report no overlap instead of dividing by zero.
        return 0.0
    return len(pred_cells & target_cells) / len(union)
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def colmap_to_opencv_intrinsics(K):
    """
    Convert camera intrinsics from the Colmap to the OpenCV pixel convention.

    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV

    so the principal point K[0, 2], K[1, 2] is shifted by -0.5.

    Args:
        K: intrinsics array providing ``copy()`` and a ``dtype``; it is not modified.
    Returns:
        A shifted copy of K. Inputs whose dtype is not floating or complex are
        promoted to float64, because writing the half-pixel shift back into an
        integer array would truncate it.
    """
    K = K.copy()
    if not np.issubdtype(K.dtype, np.inexact):
        K = K.astype(np.float64)
    K[0, 2] -= 0.5
    K[1, 2] -= 0.5
    return K
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def opencv_to_colmap_intrinsics(K):
    """
    Convert camera intrinsics from the OpenCV to the Colmap pixel convention.

    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV

    so the principal point K[0, 2], K[1, 2] is shifted by +0.5.

    Args:
        K: intrinsics array providing ``copy()`` and a ``dtype``; it is not modified.
    Returns:
        A shifted copy of K. Inputs whose dtype is not floating or complex are
        promoted to float64, because writing the half-pixel shift back into an
        integer array would truncate it (leaving K unchanged).
    """
    K = K.copy()
    if not np.issubdtype(K.dtype, np.inexact):
        K = K.astype(np.float64)
    K[0, 2] += 0.5
    K[1, 2] += 0.5
    return K
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/merging/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from . import merge
|
| 2 |
-
|
| 3 |
-
__all__ = ["merge"]
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/merging/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (187 Bytes)
|
|
|
FastVGGT/merging/__pycache__/merge.cpython-310.pyc
DELETED
|
Binary file (7.54 kB)
|
|
|
FastVGGT/merging/merge.py
DELETED
|
@@ -1,370 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from typing import Tuple, Callable, Optional, Union
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
@torch.jit.script
def fast_similarity_chunks(
    a: torch.Tensor, b_transposed: torch.Tensor, chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """For every row of `a`, find the best-matching column of `b_transposed`.

    The score matrix ``a @ b_transposed`` is evaluated in bfloat16, one slice
    of at most `chunk_size` rows at a time, so the full matrix is never held
    in memory at once.

    Args:
        a: [B, num_src, C] tensor.
        b_transposed: tensor that can be batch-multiplied with `a`
            (the caller passes the dst features transposed to [B, C, num_dst]).
        chunk_size: number of rows of `a` processed per step. Values below 1
            fall back to a single chunk over all rows (a step of 0 would make
            ``range`` raise).

    Returns:
        (node_max, node_idx): both [B, num_src]. `node_max` is the highest
        score per row in the dtype of `a`; `node_idx` is the index of that
        score along the last dimension of the score matrix (int64).
    """
    B, num_src, C = a.shape
    original_dtype = a.dtype

    # Guard against a zero/negative step, e.g. when the caller derives the
    # chunk size from an empty src set.
    if chunk_size < 1:
        chunk_size = max(num_src, 1)

    # Convert to bf16 for computation to improve performance and reduce memory usage
    a_bf16 = a.to(torch.bfloat16)
    b_transposed_bf16 = b_transposed.to(torch.bfloat16)
    node_max = torch.empty(B, num_src, device=a.device, dtype=original_dtype)
    node_idx = torch.empty(B, num_src, device=a.device, dtype=torch.long)

    # Process in chunks
    for i in range(0, num_src, chunk_size):
        end_i = min(i + chunk_size, num_src)
        a_chunk = a_bf16[:, i:end_i, :]  # [B, chunk_size, C]
        scores_chunk = torch.bmm(a_chunk, b_transposed_bf16)
        chunk_max_bf16, chunk_idx = torch.max(scores_chunk, dim=2)
        # Hand results back in the caller's dtype.
        chunk_max = chunk_max_bf16.to(original_dtype)
        node_max[:, i:end_i] = chunk_max
        node_idx[:, i:end_i] = chunk_idx
    return node_max, node_idx
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def do_nothing(
    x: torch.Tensor,
    extra_tensors=None,
    extra_tensors_2=None,
) -> Union[
    torch.Tensor,
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
    """Identity stand-in for the merge/unmerge callables.

    Returns its inputs untouched: `x` alone when `extra_tensors` is None
    (even if `extra_tensors_2` is given), ``(x, extra_tensors)`` when only
    `extra_tensors` is given, and all three when both extras are given.
    """
    if extra_tensors is None:
        return x
    if extra_tensors_2 is None:
        return x, extra_tensors
    return x, extra_tensors, extra_tensors_2
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def token_merge_bipartite2d(
    metric: torch.Tensor,
    w: int,
    h: int,
    sx: int,
    sy: int,
    r: int,
    no_rand: bool = False,
    generator: Optional[torch.Generator] = None,
    enable_protection: bool = False,
) -> Tuple[Callable, Callable]:
    """
    Divide tokens into source (src) and destination (dst) groups, and merge r tokens from src to dst.

    The token sequence is treated as ``num_imgs`` consecutive images of
    ``w * h + 5`` tokens each. Every token of the first image and the first 5
    tokens of every other image are always dst. In the remaining grid tokens
    one dst token is chosen at random per (sx, sy) region unless `no_rand` is
    set. The r src tokens with the highest cosine similarity to their
    best-matching dst token are merged into that dst token.

    Args:
        - metric [B, N, C]: Tensor for similarity computation, B=batch size, N=token count, C=feature dimension
        - w: Image width in tokens
        - h: Image height in tokens
        - sx: dst stride in x dimension, must divide w evenly
        - sy: dst stride in y dimension, must divide h evenly
        - r: Number of tokens to remove through merging; ``r <= 0`` disables merging
        - no_rand: If True, disable randomness. NOTE(review): in this
          implementation no grid token of the non-first images is marked as
          dst in that case - confirm this is intended.
        - generator: Random number generator if no_rand is False and not None
        - enable_protection: If True, an evenly strided set of about 10% of
          the token indices is excluded from the merge candidates and appended
          as a separate group to the merged output

    Returns:
        - (merge, unmerge): Two functions for merging tokens and restoring pre-merge state.
          `merge` returns tokens ordered as [unmerged src, dst, (protected)];
          `unmerge` expects that layout and returns a [B, N, C] tensor.

    Raises:
        AssertionError: if N is not a multiple of ``w * h + 5``.
    """
    B, N, _ = metric.shape  # Batch size B, total tokens N
    if r <= 0:
        return do_nothing, do_nothing

    gather = torch.gather

    tokens_per_img = w * h + 5
    num_imgs = N // tokens_per_img
    assert tokens_per_img * num_imgs == N, "Token count doesn't match (w*h+5)*num_imgs"

    with torch.no_grad():
        # Pick the protected token indices (evenly strided over the sequence).
        if enable_protection:
            num_protected = int(N * 0.1)
            # For N < 10 nothing is protected; max(1, ...) avoids dividing by zero.
            step = max(1, N // max(1, num_protected))
            protected_indices = torch.arange(0, N, step, device=metric.device)[
                :num_protected
            ]
        else:
            protected_indices = None

        # Global idx_buffer_seq of length N; -1 indicates dst, 0 indicates src
        idx_buffer_seq = torch.zeros(N, device=metric.device, dtype=torch.int64)
        hsy, wsx = h // sy, w // sx  # Number of blocks within each image

        # Mark first image entirely as dst
        if num_imgs > 0:
            idx_buffer_seq[:tokens_per_img] = -1

        # Process other images - fully vectorized batch operations
        if num_imgs > 1:
            # The first 5 tokens of every other image are always dst.
            cls_indices = (
                torch.arange(1, num_imgs, device=metric.device) * tokens_per_img
            )
            cls_indices = cls_indices[:, None] + torch.arange(5, device=metric.device)
            idx_buffer_seq[cls_indices.flatten()] = -1
            effective_h = min(hsy * sy, h)
            effective_w = min(wsx * sx, w)
            effective_grid_size = effective_h * effective_w

            if no_rand:
                # All-zero pattern: every covered grid token stays src.
                base_pattern = torch.zeros(
                    effective_grid_size, device=metric.device, dtype=torch.int64
                )
                grid_starts = (
                    torch.arange(1, num_imgs, device=metric.device) * tokens_per_img + 5
                )
                grid_indices = grid_starts[:, None] + torch.arange(
                    effective_grid_size, device=metric.device
                )
                idx_buffer_seq[grid_indices.flatten()] = base_pattern.repeat(
                    num_imgs - 1
                )
            else:
                total_other_imgs = num_imgs - 1
                # One random position inside every (sy, sx) block of every image.
                all_rand_idx = torch.randint(
                    sy * sx,
                    size=(total_other_imgs, hsy, wsx),
                    device=metric.device,
                    generator=generator,
                )

                scatter_src = -torch.ones(
                    total_other_imgs, hsy, wsx, device=metric.device, dtype=torch.int64
                )

                idx_buffer_batch = torch.zeros(
                    total_other_imgs,
                    hsy,
                    wsx,
                    sy * sx,
                    device=metric.device,
                    dtype=torch.int64,
                )
                # Write -1 (dst) at the chosen position of each block.
                idx_buffer_batch.scatter_(
                    dim=3,
                    index=all_rand_idx.unsqueeze(-1),
                    src=scatter_src.unsqueeze(-1),
                )

                # Rearrange the per-block layout into a 2D (rows, cols) grid.
                idx_buffer_batch = (
                    idx_buffer_batch.view(total_other_imgs, hsy, wsx, sy, sx)
                    .transpose(2, 3)
                    .reshape(total_other_imgs, hsy * sy, wsx * sx)
                )

                # NOTE(review): the flattened pattern is written contiguously,
                # so it matches the image rows only when sx divides w (and
                # sy divides h), as the Args section requires - confirm callers
                # respect this.
                for i in range(total_other_imgs):
                    img_idx = i + 1
                    grid_start = img_idx * tokens_per_img + 5
                    flat_view = idx_buffer_batch[
                        i, :effective_h, :effective_w
                    ].flatten()
                    idx_buffer_seq[grid_start : grid_start + effective_grid_size] = (
                        flat_view
                    )

        # Sorting puts the -1 (dst) entries first, the 0 (src) entries after.
        rand_idx = idx_buffer_seq.reshape(1, -1, 1).argsort(dim=1)
        num_dst_orig = int((idx_buffer_seq == -1).sum())

        # src and dst token indices into the original sequence
        a_idx = rand_idx[:, num_dst_orig:, :]
        b_idx = rand_idx[:, :num_dst_orig, :]

        if enable_protection:
            protected_idx = protected_indices.unsqueeze(0).unsqueeze(-1)
            num_protected_actual = protected_idx.shape[1]
        else:
            protected_idx = None
            num_protected_actual = 0

        num_src = a_idx.shape[1]
        num_dst = b_idx.shape[1]

        # Separate src, dst, and (optionally) protected tokens
        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, num_src, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            if enable_protection:
                protected = gather(
                    x, dim=1, index=protected_idx.expand(B, num_protected_actual, C)
                )
                return src, dst, protected
            return src, dst

        # Compute cosine similarity (normalize first then dot product)
        metric = metric / metric.norm(dim=-1, keepdim=True)
        parts = split(metric)
        a, b = parts[0], parts[1]

        r = min(a.shape[1], r)
        # Keep the chunk size >= 1: with no src tokens (single image) a chunk
        # size of 0 would be an invalid range step in the similarity loop.
        chunk_size = max(1, min(5000, a.shape[1]))

        b_transposed = b.transpose(-1, -2)
        node_max, node_idx = fast_similarity_chunks(a, b_transposed, chunk_size)
        # src tokens ordered from most to least similar to their best dst.
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        # If protection is enabled, filter out protected tokens to ensure they are not merged
        if enable_protection:
            src_indices = a_idx[0, :, 0]
            protected_mask_src = torch.isin(src_indices, protected_indices)
            edge_flat = edge_idx[0, :, 0]
            valid_mask = ~protected_mask_src[edge_flat]
            valid_edges = edge_flat[valid_mask]

            r_actual = min(r, valid_edges.shape[0])

            unm_idx = valid_edges[r_actual:].unsqueeze(0).unsqueeze(-1)
            src_idx = valid_edges[:r_actual].unsqueeze(0).unsqueeze(-1)
        else:
            unm_idx = edge_idx[..., r:, :]
            src_idx = edge_idx[..., :r, :]

        # Get dst token indices corresponding to each src token to be merged
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    # Merge selected src tokens into their corresponding dst tokens
    def merge(
        x: torch.Tensor,
        mode: str = "mean",
        extra_tensors=None,
        extra_tensors_2=None,
    ) -> Union[
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        n = x.shape[0]
        unm_len = unm_idx.shape[1]
        src_len = src_idx.shape[1]

        def _merge_one(tensor):
            # Apply the precomputed src->dst assignment to one tensor.
            dim = tensor.shape[-1]
            groups = split(tensor)
            src, dst = groups[0], groups[1]
            unm = gather(src, dim=-2, index=unm_idx.expand(n, unm_len, dim))
            merged_src = gather(src, dim=-2, index=src_idx.expand(n, src_len, dim))
            dst = dst.scatter_reduce(
                -2, dst_idx.expand(n, src_len, dim), merged_src, reduce=mode
            )
            if enable_protection:
                return torch.cat([unm, dst, groups[2]], dim=1)
            return torch.cat([unm, dst], dim=1)

        main_result = _merge_one(x)

        # ---------------- Extra tensor processing ----------------
        # Extra tensors go through exactly the same merge as the main tensor.
        merged_extra_1 = None
        merged_extra_2 = None
        if extra_tensors is not None:
            merged_extra_1 = _merge_one(extra_tensors)
        if extra_tensors_2 is not None:
            merged_extra_2 = _merge_one(extra_tensors_2)

        if merged_extra_1 is not None and merged_extra_2 is not None:
            return main_result, merged_extra_1, merged_extra_2
        elif merged_extra_1 is not None:
            return main_result, merged_extra_1
        else:
            return main_result

    # Restore the pre-merge token layout (for decoder)
    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        dst_len = num_dst
        src_len = src_idx.shape[1]
        unm = x[..., :unm_len, :]
        dst = x[..., unm_len : unm_len + dst_len, :]

        if enable_protection:
            protected = x[
                ..., unm_len + dst_len : unm_len + dst_len + num_protected_actual, :
            ]

        _, _, c = unm.shape
        # Every merged src token is restored as a copy of its dst token.
        src = gather(dst, dim=-2, index=dst_idx.expand(B, src_len, c))
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(
            dim=-2,
            index=gather(
                a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx
            ).expand(B, unm_len, c),
            src=unm,
        )

        out.scatter_(
            dim=-2,
            index=gather(
                a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx
            ).expand(B, src_len, c),
            src=src,
        )

        if enable_protection:
            # Protected tokens are written last, at their original positions.
            out.scatter_(
                dim=-2,
                index=protected_idx.expand(B, num_protected_actual, c),
                src=protected,
            )

        return out

    return merge, unmerge
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/requirements.txt
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
torch==2.3.1
|
| 2 |
-
torchvision==0.18.1
|
| 3 |
-
numpy==1.26.1
|
| 4 |
-
Pillow
|
| 5 |
-
huggingface_hub
|
| 6 |
-
einops
|
| 7 |
-
safetensors
|
| 8 |
-
evo
|
| 9 |
-
open3d
|
| 10 |
-
matplotlib
|
| 11 |
-
scipy
|
| 12 |
-
opencv-python
|
| 13 |
-
scikit-image
|
| 14 |
-
tqdm
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/vggt/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (132 Bytes)
|
|
|
FastVGGT/vggt/dependency/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (143 Bytes)
|
|
|
FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc
DELETED
|
Binary file (1.39 kB)
|
|
|
FastVGGT/vggt/dependency/distortion.py
DELETED
|
@@ -1,54 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def apply_distortion(points, distortion_params):
|
| 12 |
-
"""
|
| 13 |
-
Apply distortion to normalized camera coordinates.
|
| 14 |
-
|
| 15 |
-
Args:
|
| 16 |
-
points: Array of normalized camera coordinates
|
| 17 |
-
distortion_params: Distortion parameters
|
| 18 |
-
|
| 19 |
-
Returns:
|
| 20 |
-
Distorted coordinates
|
| 21 |
-
"""
|
| 22 |
-
# Simple passthrough for now - implement actual distortion if needed
|
| 23 |
-
return points
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def iterative_undistortion(points, distortion_params, max_iter=10):
|
| 27 |
-
"""
|
| 28 |
-
Remove distortion from normalized camera coordinates using iterative method.
|
| 29 |
-
|
| 30 |
-
Args:
|
| 31 |
-
points: Array of distorted normalized camera coordinates
|
| 32 |
-
distortion_params: Distortion parameters
|
| 33 |
-
max_iter: Maximum number of iterations
|
| 34 |
-
|
| 35 |
-
Returns:
|
| 36 |
-
Undistorted coordinates
|
| 37 |
-
"""
|
| 38 |
-
# Simple passthrough for now - implement actual undistortion if needed
|
| 39 |
-
return points
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def single_undistortion(points, distortion_params):
|
| 43 |
-
"""
|
| 44 |
-
Remove distortion from normalized camera coordinates using single step.
|
| 45 |
-
|
| 46 |
-
Args:
|
| 47 |
-
points: Array of distorted normalized camera coordinates
|
| 48 |
-
distortion_params: Distortion parameters
|
| 49 |
-
|
| 50 |
-
Returns:
|
| 51 |
-
Undistorted coordinates
|
| 52 |
-
"""
|
| 53 |
-
# Simple passthrough for now - implement actual undistortion if needed
|
| 54 |
-
return points
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc
DELETED
|
Binary file (4.24 kB)
|
|
|
FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc
DELETED
|
Binary file (12.8 kB)
|
|
|
FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc
DELETED
|
Binary file (3.1 kB)
|
|
|
FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc
DELETED
|
Binary file (3.41 kB)
|
|
|
FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc
DELETED
|
Binary file (3.18 kB)
|
|
|
FastVGGT/vggt/heads/camera_head.py
DELETED
|
@@ -1,149 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import math
|
| 8 |
-
import numpy as np
|
| 9 |
-
|
| 10 |
-
import torch
|
| 11 |
-
import torch.nn as nn
|
| 12 |
-
import torch.nn.functional as F
|
| 13 |
-
|
| 14 |
-
from vggt.layers import Mlp
|
| 15 |
-
from vggt.layers.block import Block
|
| 16 |
-
from vggt.heads.head_act import activate_pose
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class CameraHead(nn.Module):
|
| 20 |
-
"""
|
| 21 |
-
CameraHead predicts camera parameters from token representations using iterative refinement.
|
| 22 |
-
|
| 23 |
-
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
|
| 24 |
-
"""
|
| 25 |
-
|
| 26 |
-
def __init__(
|
| 27 |
-
self,
|
| 28 |
-
dim_in: int = 2048,
|
| 29 |
-
trunk_depth: int = 4,
|
| 30 |
-
pose_encoding_type: str = "absT_quaR_FoV",
|
| 31 |
-
num_heads: int = 16,
|
| 32 |
-
mlp_ratio: int = 4,
|
| 33 |
-
init_values: float = 0.01,
|
| 34 |
-
trans_act: str = "linear",
|
| 35 |
-
quat_act: str = "linear",
|
| 36 |
-
fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
|
| 37 |
-
):
|
| 38 |
-
super().__init__()
|
| 39 |
-
|
| 40 |
-
if pose_encoding_type == "absT_quaR_FoV":
|
| 41 |
-
self.target_dim = 9
|
| 42 |
-
else:
|
| 43 |
-
raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
|
| 44 |
-
|
| 45 |
-
self.trans_act = trans_act
|
| 46 |
-
self.quat_act = quat_act
|
| 47 |
-
self.fl_act = fl_act
|
| 48 |
-
self.trunk_depth = trunk_depth
|
| 49 |
-
|
| 50 |
-
# Build the trunk using a sequence of transformer blocks.
|
| 51 |
-
self.trunk = nn.Sequential(
|
| 52 |
-
*[
|
| 53 |
-
Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
|
| 54 |
-
for _ in range(trunk_depth)
|
| 55 |
-
]
|
| 56 |
-
)
|
| 57 |
-
|
| 58 |
-
# Normalizations for camera token and trunk output.
|
| 59 |
-
self.token_norm = nn.LayerNorm(dim_in)
|
| 60 |
-
self.trunk_norm = nn.LayerNorm(dim_in)
|
| 61 |
-
|
| 62 |
-
# Learnable empty camera pose token.
|
| 63 |
-
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
|
| 64 |
-
self.embed_pose = nn.Linear(self.target_dim, dim_in)
|
| 65 |
-
|
| 66 |
-
# Module for producing modulation parameters: shift, scale, and a gate.
|
| 67 |
-
self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
|
| 68 |
-
|
| 69 |
-
# Adaptive layer normalization without affine parameters.
|
| 70 |
-
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
|
| 71 |
-
self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
|
| 72 |
-
|
| 73 |
-
def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
|
| 74 |
-
"""
|
| 75 |
-
Forward pass to predict camera parameters.
|
| 76 |
-
|
| 77 |
-
Args:
|
| 78 |
-
aggregated_tokens_list (list): List of token tensors from the network;
|
| 79 |
-
the last tensor is used for prediction.
|
| 80 |
-
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
|
| 81 |
-
|
| 82 |
-
Returns:
|
| 83 |
-
list: A list of predicted camera encodings (post-activation) from each iteration.
|
| 84 |
-
"""
|
| 85 |
-
# Use tokens from the last block for camera prediction.
|
| 86 |
-
tokens = aggregated_tokens_list[-1]
|
| 87 |
-
|
| 88 |
-
# Extract the camera tokens
|
| 89 |
-
pose_tokens = tokens[:, :, 0]
|
| 90 |
-
pose_tokens = self.token_norm(pose_tokens)
|
| 91 |
-
|
| 92 |
-
pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
|
| 93 |
-
return pred_pose_enc_list
|
| 94 |
-
|
| 95 |
-
def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
|
| 96 |
-
"""
|
| 97 |
-
Iteratively refine camera pose predictions.
|
| 98 |
-
|
| 99 |
-
Args:
|
| 100 |
-
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
|
| 101 |
-
num_iterations (int): Number of refinement iterations.
|
| 102 |
-
|
| 103 |
-
Returns:
|
| 104 |
-
list: List of activated camera encodings from each iteration.
|
| 105 |
-
"""
|
| 106 |
-
B, S, C = pose_tokens.shape # S is expected to be 1.
|
| 107 |
-
pred_pose_enc = None
|
| 108 |
-
pred_pose_enc_list = []
|
| 109 |
-
|
| 110 |
-
for _ in range(num_iterations):
|
| 111 |
-
# Use a learned empty pose for the first iteration.
|
| 112 |
-
if pred_pose_enc is None:
|
| 113 |
-
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
|
| 114 |
-
else:
|
| 115 |
-
# Detach the previous prediction to avoid backprop through time.
|
| 116 |
-
pred_pose_enc = pred_pose_enc.detach()
|
| 117 |
-
module_input = self.embed_pose(pred_pose_enc)
|
| 118 |
-
|
| 119 |
-
# Generate modulation parameters and split them into shift, scale, and gate components.
|
| 120 |
-
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
|
| 121 |
-
|
| 122 |
-
# Adaptive layer normalization and modulation.
|
| 123 |
-
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
|
| 124 |
-
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
|
| 125 |
-
|
| 126 |
-
pose_tokens_modulated = self.trunk(pose_tokens_modulated)
|
| 127 |
-
# Compute the delta update for the pose encoding.
|
| 128 |
-
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
|
| 129 |
-
|
| 130 |
-
if pred_pose_enc is None:
|
| 131 |
-
pred_pose_enc = pred_pose_enc_delta
|
| 132 |
-
else:
|
| 133 |
-
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
|
| 134 |
-
|
| 135 |
-
# Apply final activation functions for translation, quaternion, and field-of-view.
|
| 136 |
-
activated_pose = activate_pose(
|
| 137 |
-
pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
|
| 138 |
-
)
|
| 139 |
-
pred_pose_enc_list.append(activated_pose)
|
| 140 |
-
|
| 141 |
-
return pred_pose_enc_list
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
| 145 |
-
"""
|
| 146 |
-
Modulate the input tensor using scaling and shifting parameters.
|
| 147 |
-
"""
|
| 148 |
-
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
| 149 |
-
return x * (1 + scale) + shift
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/vggt/heads/dpt_head.py
DELETED
|
@@ -1,598 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
import os
|
| 12 |
-
from typing import List, Dict, Tuple, Union, Optional
|
| 13 |
-
|
| 14 |
-
import torch
|
| 15 |
-
import torch.nn as nn
|
| 16 |
-
import torch.nn.functional as F
|
| 17 |
-
from .head_act import activate_head
|
| 18 |
-
from .utils import create_uv_grid, position_grid_to_embed
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class DPTHead(nn.Module):
|
| 22 |
-
"""
|
| 23 |
-
DPT Head for dense prediction tasks.
|
| 24 |
-
|
| 25 |
-
This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
|
| 26 |
-
(https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
|
| 27 |
-
backbone and produces dense predictions by fusing multi-scale features.
|
| 28 |
-
|
| 29 |
-
Args:
|
| 30 |
-
dim_in (int): Input dimension (channels).
|
| 31 |
-
patch_size (int, optional): Patch size. Default is 14.
|
| 32 |
-
output_dim (int, optional): Number of output channels. Default is 4.
|
| 33 |
-
activation (str, optional): Activation type. Default is "inv_log".
|
| 34 |
-
conf_activation (str, optional): Confidence activation type. Default is "expp1".
|
| 35 |
-
features (int, optional): Feature channels for intermediate representations. Default is 256.
|
| 36 |
-
out_channels (List[int], optional): Output channels for each intermediate layer.
|
| 37 |
-
intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
|
| 38 |
-
pos_embed (bool, optional): Whether to use positional embedding. Default is True.
|
| 39 |
-
feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
|
| 40 |
-
down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
|
| 41 |
-
"""
|
| 42 |
-
|
| 43 |
-
def __init__(
|
| 44 |
-
self,
|
| 45 |
-
dim_in: int,
|
| 46 |
-
patch_size: int = 14,
|
| 47 |
-
output_dim: int = 4,
|
| 48 |
-
activation: str = "inv_log",
|
| 49 |
-
conf_activation: str = "expp1",
|
| 50 |
-
features: int = 256,
|
| 51 |
-
out_channels: List[int] = [256, 512, 1024, 1024],
|
| 52 |
-
intermediate_layer_idx: List[int] = [0, 1, 2, 3],
|
| 53 |
-
pos_embed: bool = True,
|
| 54 |
-
feature_only: bool = False,
|
| 55 |
-
down_ratio: int = 1,
|
| 56 |
-
) -> None:
|
| 57 |
-
super(DPTHead, self).__init__()
|
| 58 |
-
self.patch_size = patch_size
|
| 59 |
-
self.activation = activation
|
| 60 |
-
self.conf_activation = conf_activation
|
| 61 |
-
self.pos_embed = pos_embed
|
| 62 |
-
self.feature_only = feature_only
|
| 63 |
-
self.down_ratio = down_ratio
|
| 64 |
-
self.intermediate_layer_idx = intermediate_layer_idx
|
| 65 |
-
|
| 66 |
-
self.norm = nn.LayerNorm(dim_in)
|
| 67 |
-
|
| 68 |
-
# Projection layers for each output channel from tokens.
|
| 69 |
-
self.projects = nn.ModuleList(
|
| 70 |
-
[
|
| 71 |
-
nn.Conv2d(
|
| 72 |
-
in_channels=dim_in,
|
| 73 |
-
out_channels=oc,
|
| 74 |
-
kernel_size=1,
|
| 75 |
-
stride=1,
|
| 76 |
-
padding=0,
|
| 77 |
-
)
|
| 78 |
-
for oc in out_channels
|
| 79 |
-
]
|
| 80 |
-
)
|
| 81 |
-
|
| 82 |
-
# Resize layers for upsampling feature maps.
|
| 83 |
-
self.resize_layers = nn.ModuleList(
|
| 84 |
-
[
|
| 85 |
-
nn.ConvTranspose2d(
|
| 86 |
-
in_channels=out_channels[0],
|
| 87 |
-
out_channels=out_channels[0],
|
| 88 |
-
kernel_size=4,
|
| 89 |
-
stride=4,
|
| 90 |
-
padding=0,
|
| 91 |
-
),
|
| 92 |
-
nn.ConvTranspose2d(
|
| 93 |
-
in_channels=out_channels[1],
|
| 94 |
-
out_channels=out_channels[1],
|
| 95 |
-
kernel_size=2,
|
| 96 |
-
stride=2,
|
| 97 |
-
padding=0,
|
| 98 |
-
),
|
| 99 |
-
nn.Identity(),
|
| 100 |
-
nn.Conv2d(
|
| 101 |
-
in_channels=out_channels[3],
|
| 102 |
-
out_channels=out_channels[3],
|
| 103 |
-
kernel_size=3,
|
| 104 |
-
stride=2,
|
| 105 |
-
padding=1,
|
| 106 |
-
),
|
| 107 |
-
]
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
self.scratch = _make_scratch(out_channels, features, expand=False)
|
| 111 |
-
|
| 112 |
-
# Attach additional modules to scratch.
|
| 113 |
-
self.scratch.stem_transpose = nn.Identity() # Use Identity instead of None
|
| 114 |
-
self.scratch.refinenet1 = _make_fusion_block(features)
|
| 115 |
-
self.scratch.refinenet2 = _make_fusion_block(features)
|
| 116 |
-
self.scratch.refinenet3 = _make_fusion_block(features)
|
| 117 |
-
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
|
| 118 |
-
|
| 119 |
-
head_features_1 = features
|
| 120 |
-
head_features_2 = 32
|
| 121 |
-
|
| 122 |
-
if feature_only:
|
| 123 |
-
self.scratch.output_conv1 = nn.Conv2d(
|
| 124 |
-
head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
|
| 125 |
-
)
|
| 126 |
-
else:
|
| 127 |
-
self.scratch.output_conv1 = nn.Conv2d(
|
| 128 |
-
head_features_1,
|
| 129 |
-
head_features_1 // 2,
|
| 130 |
-
kernel_size=3,
|
| 131 |
-
stride=1,
|
| 132 |
-
padding=1,
|
| 133 |
-
)
|
| 134 |
-
conv2_in_channels = head_features_1 // 2
|
| 135 |
-
|
| 136 |
-
self.scratch.output_conv2 = nn.Sequential(
|
| 137 |
-
nn.Conv2d(
|
| 138 |
-
conv2_in_channels,
|
| 139 |
-
head_features_2,
|
| 140 |
-
kernel_size=3,
|
| 141 |
-
stride=1,
|
| 142 |
-
padding=1,
|
| 143 |
-
),
|
| 144 |
-
nn.ReLU(inplace=True),
|
| 145 |
-
nn.Conv2d(
|
| 146 |
-
head_features_2, output_dim, kernel_size=1, stride=1, padding=0
|
| 147 |
-
),
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
def forward(
|
| 151 |
-
self,
|
| 152 |
-
aggregated_tokens_list: List[torch.Tensor],
|
| 153 |
-
images: torch.Tensor,
|
| 154 |
-
patch_start_idx: int,
|
| 155 |
-
frames_chunk_size: int = 8,
|
| 156 |
-
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 157 |
-
"""
|
| 158 |
-
Forward pass through the DPT head, supports processing by chunking frames.
|
| 159 |
-
Args:
|
| 160 |
-
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 161 |
-
images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
|
| 162 |
-
patch_start_idx (int): Starting index for patch tokens in the token sequence.
|
| 163 |
-
Used to separate patch tokens from other tokens (e.g., camera or register tokens).
|
| 164 |
-
frames_chunk_size (int, optional): Number of frames to process in each chunk.
|
| 165 |
-
If None or larger than S, all frames are processed at once. Default: 8.
|
| 166 |
-
|
| 167 |
-
Returns:
|
| 168 |
-
Tensor or Tuple[Tensor, Tensor]:
|
| 169 |
-
- If feature_only=True: Feature maps with shape [B, S, C, H, W]
|
| 170 |
-
- Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
|
| 171 |
-
"""
|
| 172 |
-
B, S, _, H, W = images.shape
|
| 173 |
-
|
| 174 |
-
# If frames_chunk_size is not specified or greater than S, process all frames at once
|
| 175 |
-
if frames_chunk_size is None or frames_chunk_size >= S:
|
| 176 |
-
return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
|
| 177 |
-
|
| 178 |
-
# Otherwise, process frames in chunks to manage memory usage
|
| 179 |
-
assert frames_chunk_size > 0
|
| 180 |
-
|
| 181 |
-
# Process frames in batches
|
| 182 |
-
all_preds = []
|
| 183 |
-
all_conf = []
|
| 184 |
-
|
| 185 |
-
for frames_start_idx in range(0, S, frames_chunk_size):
|
| 186 |
-
frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
|
| 187 |
-
|
| 188 |
-
# Process batch of frames
|
| 189 |
-
if self.feature_only:
|
| 190 |
-
chunk_output = self._forward_impl(
|
| 191 |
-
aggregated_tokens_list,
|
| 192 |
-
images,
|
| 193 |
-
patch_start_idx,
|
| 194 |
-
frames_start_idx,
|
| 195 |
-
frames_end_idx,
|
| 196 |
-
)
|
| 197 |
-
all_preds.append(chunk_output)
|
| 198 |
-
else:
|
| 199 |
-
chunk_preds, chunk_conf = self._forward_impl(
|
| 200 |
-
aggregated_tokens_list,
|
| 201 |
-
images,
|
| 202 |
-
patch_start_idx,
|
| 203 |
-
frames_start_idx,
|
| 204 |
-
frames_end_idx,
|
| 205 |
-
)
|
| 206 |
-
all_preds.append(chunk_preds)
|
| 207 |
-
all_conf.append(chunk_conf)
|
| 208 |
-
|
| 209 |
-
# Concatenate results along the sequence dimension
|
| 210 |
-
if self.feature_only:
|
| 211 |
-
return torch.cat(all_preds, dim=1)
|
| 212 |
-
else:
|
| 213 |
-
return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
|
| 214 |
-
|
| 215 |
-
def _forward_impl(
|
| 216 |
-
self,
|
| 217 |
-
aggregated_tokens_list: List[torch.Tensor],
|
| 218 |
-
images: torch.Tensor,
|
| 219 |
-
patch_start_idx: int,
|
| 220 |
-
frames_start_idx: Optional[int] = None,
|
| 221 |
-
frames_end_idx: Optional[int] = None,
|
| 222 |
-
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 223 |
-
"""
|
| 224 |
-
Implementation of the forward pass through the DPT head.
|
| 225 |
-
|
| 226 |
-
This method processes a specific chunk of frames from the sequence.
|
| 227 |
-
|
| 228 |
-
Args:
|
| 229 |
-
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 230 |
-
images (Tensor): Input images with shape [B, S, 3, H, W].
|
| 231 |
-
patch_start_idx (int): Starting index for patch tokens.
|
| 232 |
-
frames_start_idx (int, optional): Starting index for frames to process.
|
| 233 |
-
frames_end_idx (int, optional): Ending index for frames to process.
|
| 234 |
-
|
| 235 |
-
Returns:
|
| 236 |
-
Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
|
| 237 |
-
"""
|
| 238 |
-
if frames_start_idx is not None and frames_end_idx is not None:
|
| 239 |
-
images = images[:, frames_start_idx:frames_end_idx].contiguous()
|
| 240 |
-
|
| 241 |
-
B, S, _, H, W = images.shape
|
| 242 |
-
|
| 243 |
-
patch_h, patch_w = H // self.patch_size, W // self.patch_size
|
| 244 |
-
|
| 245 |
-
out = []
|
| 246 |
-
dpt_idx = 0
|
| 247 |
-
|
| 248 |
-
for layer_idx in self.intermediate_layer_idx:
|
| 249 |
-
x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
|
| 250 |
-
|
| 251 |
-
# Select frames if processing a chunk
|
| 252 |
-
if frames_start_idx is not None and frames_end_idx is not None:
|
| 253 |
-
x = x[:, frames_start_idx:frames_end_idx]
|
| 254 |
-
|
| 255 |
-
x = x.reshape(B * S, -1, x.shape[-1])
|
| 256 |
-
|
| 257 |
-
x = self.norm(x)
|
| 258 |
-
|
| 259 |
-
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
| 260 |
-
|
| 261 |
-
x = self.projects[dpt_idx](x)
|
| 262 |
-
if self.pos_embed:
|
| 263 |
-
x = self._apply_pos_embed(x, W, H)
|
| 264 |
-
x = self.resize_layers[dpt_idx](x)
|
| 265 |
-
|
| 266 |
-
out.append(x)
|
| 267 |
-
dpt_idx += 1
|
| 268 |
-
|
| 269 |
-
# Fuse features from multiple layers.
|
| 270 |
-
out = self.scratch_forward(out)
|
| 271 |
-
# Interpolate fused output to match target image resolution.
|
| 272 |
-
out = custom_interpolate(
|
| 273 |
-
out,
|
| 274 |
-
(
|
| 275 |
-
int(patch_h * self.patch_size / self.down_ratio),
|
| 276 |
-
int(patch_w * self.patch_size / self.down_ratio),
|
| 277 |
-
),
|
| 278 |
-
mode="bilinear",
|
| 279 |
-
align_corners=True,
|
| 280 |
-
)
|
| 281 |
-
|
| 282 |
-
if self.pos_embed:
|
| 283 |
-
out = self._apply_pos_embed(out, W, H)
|
| 284 |
-
|
| 285 |
-
if self.feature_only:
|
| 286 |
-
return out.view(B, S, *out.shape[1:])
|
| 287 |
-
|
| 288 |
-
out = self.scratch.output_conv2(out)
|
| 289 |
-
preds, conf = activate_head(
|
| 290 |
-
out, activation=self.activation, conf_activation=self.conf_activation
|
| 291 |
-
)
|
| 292 |
-
|
| 293 |
-
preds = preds.view(B, S, *preds.shape[1:])
|
| 294 |
-
conf = conf.view(B, S, *conf.shape[1:])
|
| 295 |
-
return preds, conf
|
| 296 |
-
|
| 297 |
-
def _apply_pos_embed(
|
| 298 |
-
self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
|
| 299 |
-
) -> torch.Tensor:
|
| 300 |
-
"""
|
| 301 |
-
Apply positional embedding to tensor x.
|
| 302 |
-
"""
|
| 303 |
-
patch_w = x.shape[-1]
|
| 304 |
-
patch_h = x.shape[-2]
|
| 305 |
-
pos_embed = create_uv_grid(
|
| 306 |
-
patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
|
| 307 |
-
)
|
| 308 |
-
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
|
| 309 |
-
pos_embed = pos_embed * ratio
|
| 310 |
-
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
| 311 |
-
return x + pos_embed
|
| 312 |
-
|
| 313 |
-
def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
|
| 314 |
-
"""
|
| 315 |
-
Forward pass through the fusion blocks.
|
| 316 |
-
|
| 317 |
-
Args:
|
| 318 |
-
features (List[Tensor]): List of feature maps from different layers.
|
| 319 |
-
|
| 320 |
-
Returns:
|
| 321 |
-
Tensor: Fused feature map.
|
| 322 |
-
"""
|
| 323 |
-
layer_1, layer_2, layer_3, layer_4 = features
|
| 324 |
-
|
| 325 |
-
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 326 |
-
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 327 |
-
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 328 |
-
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 329 |
-
|
| 330 |
-
out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 331 |
-
del layer_4_rn, layer_4
|
| 332 |
-
|
| 333 |
-
out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 334 |
-
del layer_3_rn, layer_3
|
| 335 |
-
|
| 336 |
-
out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 337 |
-
del layer_2_rn, layer_2
|
| 338 |
-
|
| 339 |
-
out = self.scratch.refinenet1(out, layer_1_rn)
|
| 340 |
-
del layer_1_rn, layer_1
|
| 341 |
-
|
| 342 |
-
out = self.scratch.output_conv1(out)
|
| 343 |
-
return out
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
################################################################################
|
| 347 |
-
# Modules
|
| 348 |
-
################################################################################
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
def _make_fusion_block(
|
| 352 |
-
features: int,
|
| 353 |
-
size: Optional[int] = None,
|
| 354 |
-
has_residual: bool = True,
|
| 355 |
-
groups: int = 1,
|
| 356 |
-
) -> nn.Module:
|
| 357 |
-
return FeatureFusionBlock(
|
| 358 |
-
features,
|
| 359 |
-
nn.ReLU(inplace=True),
|
| 360 |
-
deconv=False,
|
| 361 |
-
bn=False,
|
| 362 |
-
expand=False,
|
| 363 |
-
align_corners=True,
|
| 364 |
-
size=size,
|
| 365 |
-
has_residual=has_residual,
|
| 366 |
-
groups=groups,
|
| 367 |
-
)
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
def _make_scratch(
|
| 371 |
-
in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
|
| 372 |
-
) -> nn.Module:
|
| 373 |
-
scratch = nn.Module()
|
| 374 |
-
out_shape1 = out_shape
|
| 375 |
-
out_shape2 = out_shape
|
| 376 |
-
out_shape3 = out_shape
|
| 377 |
-
if len(in_shape) >= 4:
|
| 378 |
-
out_shape4 = out_shape
|
| 379 |
-
|
| 380 |
-
if expand:
|
| 381 |
-
out_shape1 = out_shape
|
| 382 |
-
out_shape2 = out_shape * 2
|
| 383 |
-
out_shape3 = out_shape * 4
|
| 384 |
-
if len(in_shape) >= 4:
|
| 385 |
-
out_shape4 = out_shape * 8
|
| 386 |
-
|
| 387 |
-
scratch.layer1_rn = nn.Conv2d(
|
| 388 |
-
in_shape[0],
|
| 389 |
-
out_shape1,
|
| 390 |
-
kernel_size=3,
|
| 391 |
-
stride=1,
|
| 392 |
-
padding=1,
|
| 393 |
-
bias=False,
|
| 394 |
-
groups=groups,
|
| 395 |
-
)
|
| 396 |
-
scratch.layer2_rn = nn.Conv2d(
|
| 397 |
-
in_shape[1],
|
| 398 |
-
out_shape2,
|
| 399 |
-
kernel_size=3,
|
| 400 |
-
stride=1,
|
| 401 |
-
padding=1,
|
| 402 |
-
bias=False,
|
| 403 |
-
groups=groups,
|
| 404 |
-
)
|
| 405 |
-
scratch.layer3_rn = nn.Conv2d(
|
| 406 |
-
in_shape[2],
|
| 407 |
-
out_shape3,
|
| 408 |
-
kernel_size=3,
|
| 409 |
-
stride=1,
|
| 410 |
-
padding=1,
|
| 411 |
-
bias=False,
|
| 412 |
-
groups=groups,
|
| 413 |
-
)
|
| 414 |
-
if len(in_shape) >= 4:
|
| 415 |
-
scratch.layer4_rn = nn.Conv2d(
|
| 416 |
-
in_shape[3],
|
| 417 |
-
out_shape4,
|
| 418 |
-
kernel_size=3,
|
| 419 |
-
stride=1,
|
| 420 |
-
padding=1,
|
| 421 |
-
bias=False,
|
| 422 |
-
groups=groups,
|
| 423 |
-
)
|
| 424 |
-
return scratch
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
class ResidualConvUnit(nn.Module):
    """Two activation + 3x3 conv stages with a skip connection around them."""

    def __init__(self, features, activation, bn, groups=1):
        """Set up the two convolutions and store the activation.

        Args:
            features (int): number of input and output channels of both convs
            activation: callable applied before each convolution
            bn: stored on the instance as ``self.bn``; no norm layer is
                created from it here (``norm1``/``norm2`` are set to None)
            groups (int): ``groups`` argument of both convolutions
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        # Both convs share the same shape-preserving configuration.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_kwargs)
        self.conv2 = nn.Conv2d(features, features, **conv_kwargs)

        # Optional normalisation slots; forward() skips them while they are None.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation

    def forward(self, x):
        """Run both stages and add the input back.

        Args:
            x (tensor): input

        Returns:
            tensor: output of the second stage plus ``x``
        """
        out = x
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            out = conv(self.activation(out))
            if norm is not None:
                out = norm(out)
        return out + x
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
class FeatureFusionBlock(nn.Module):
    """Fuse up to two feature maps, resize the result and project it with a 1x1 conv."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Create the residual units and the output projection.

        Args:
            features (int): number of input channels
            activation: activation passed to each ``ResidualConvUnit``
            deconv: stored as ``self.deconv``; not otherwise used in this class
            bn: forwarded to each ``ResidualConvUnit``
            expand: when equal to ``True`` the output projection halves the
                channel count (``features // 2``)
            align_corners: forwarded to the interpolation in ``forward``
            size: default output size used by ``forward`` when none is given
            has_residual: whether a second input is refined and added in
                ``forward`` (``resConfUnit1`` is only created when truthy)
            groups (int): ``groups`` argument of the convolutions
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand

        # Equality comparison kept on purpose so non-bool values of
        # `expand` behave exactly as before.
        out_features = features // 2 if self.expand == True else features

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=self.groups,
        )

        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
        self.has_residual = has_residual

        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.size = size

    def forward(self, *xs, size=None):
        """Fuse the inputs, resize, and apply the output projection.

        Args:
            *xs: ``xs[0]`` is the main input; ``xs[1]`` is read only when
                ``has_residual`` is truthy.
            size: explicit output size; falls back to ``self.size`` and then
                to a 2x upscale when both are None.

        Returns:
            tensor: output
        """
        fused = xs[0]
        if self.has_residual:
            fused = fused + self.resConfUnit1(xs[1])
        fused = self.resConfUnit2(fused)

        # A call-time size wins over the constructor default; with neither,
        # the feature map is upscaled by a factor of 2.
        target = self.size if size is None else size
        resize_kwargs = {"scale_factor": 2} if target is None else {"size": target}

        fused = custom_interpolate(
            fused, **resize_kwargs, mode="bilinear", align_corners=self.align_corners
        )
        return self.out_conv(fused)
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
def custom_interpolate(
    x: torch.Tensor,
    size: Optional[Tuple[int, int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.

    When the requested output would exceed a fixed element budget, the input
    is split along dim 0, each piece is interpolated separately, and the
    pieces are concatenated again.

    Args:
        x: Input tensor; its last two dims are resized.
        size: Target (height, width). When None it is derived from
            ``scale_factor`` and the last two dims of ``x``.
        scale_factor: Only used when ``size`` is None.
        mode: Interpolation mode forwarded to ``nn.functional.interpolate``.
        align_corners: Forwarded to ``nn.functional.interpolate``.

    Returns:
        The resized tensor.
    """
    if size is None:
        in_h, in_w = x.shape[-2], x.shape[-1]
        size = (int(in_h * scale_factor), int(in_w * scale_factor))

    # Element budget above which the work is split into chunks.
    max_elements = 1610612736
    out_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if out_elements <= max_elements:
        return nn.functional.interpolate(
            x, size=size, mode=mode, align_corners=align_corners
        )

    num_chunks = (out_elements // max_elements) + 1
    resized = []
    for piece in torch.chunk(x, chunks=num_chunks, dim=0):
        resized.append(
            nn.functional.interpolate(
                piece, size=size, mode=mode, align_corners=align_corners
            )
        )
    return torch.cat(resized, dim=0).contiguous()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FastVGGT/vggt/heads/head_act.py
DELETED
|
@@ -1,125 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
    """
    Apply a separate activation to each component of an encoded pose.

    The last dimension is split as: first 3 entries (translation), next 4
    (quaternion), everything from index 7 on (focal length, or fov).

    Args:
        pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
        trans_act: Activation type for translation component
        quat_act: Activation type for quaternion component
        fl_act: Activation type for focal length component

    Returns:
        Tensor with the activated components concatenated back along the
        last dimension in the same order
    """
    translation = base_pose_act(pred_pose_enc[..., :3], trans_act)
    rotation = base_pose_act(pred_pose_enc[..., 3:7], quat_act)
    focal = base_pose_act(pred_pose_enc[..., 7:], fl_act)  # or fov

    return torch.cat([translation, rotation, focal], dim=-1)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def base_pose_act(pose_enc, act_type="linear"):
    """
    Apply the activation selected by ``act_type`` to ``pose_enc``.

    Args:
        pose_enc: Tensor containing encoded pose parameters
        act_type: Activation type ("linear", "inv_log", "exp", "relu")

    Returns:
        Activated pose parameters; for "linear" the input is returned as is

    Raises:
        ValueError: If ``act_type`` is not one of the supported names
    """
    if act_type == "linear":
        return pose_enc
    if act_type == "inv_log":
        return inverse_log_transform(pose_enc)
    if act_type == "exp":
        return torch.exp(pose_enc)
    if act_type == "relu":
        return F.relu(pose_enc)
    raise ValueError(f"Unknown act_type: {act_type}")
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
    """
    Process network output to extract 3D points and confidence values.

    The last channel of ``out`` is treated as the confidence logit and all
    preceding channels as the point prediction.

    Args:
        out: Network output tensor (B, C, H, W)
        activation: Activation type for 3D points. One of "norm_exp",
            "norm", "exp", "relu", "inv_log", "xy_inv_log", "sigmoid",
            "linear". "xy_inv_log" splits the point channels into 2 + 1,
            so it needs exactly three of them.
        conf_activation: Activation type for confidence values. One of
            "expp1" (1 + exp), "expp0" (exp), "sigmoid".

    Returns:
        Tuple of (3D points tensor of shape (B, H, W, C-1),
        confidence tensor of shape (B, H, W))

    Raises:
        ValueError: If ``activation`` or ``conf_activation`` is unknown
    """
    # Move channels from dim 1 to the last dimension: (B, C, H, W) => (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)

    # Split into xyz (first C-1 channels) and confidence (last channel)
    xyz = fmap[:, :, :, :-1]
    conf = fmap[:, :, :, -1]

    if activation == "norm_exp":
        # Keep the direction, remap the length d to expm1(d).
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        xyz_normed = xyz / d
        pts3d = xyz_normed * torch.expm1(d)
    elif activation == "norm":
        # Clamp the norm exactly like "norm_exp" does, so an all-zero
        # vector maps to zeros instead of NaN (0 / 0).
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif activation == "xy_inv_log":
        # z is activated first, then used to scale the raw xy channels.
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    if conf_activation == "expp1":
        conf_out = 1 + conf.exp()
    elif conf_activation == "expp0":
        conf_out = conf.exp()
    elif conf_activation == "sigmoid":
        conf_out = torch.sigmoid(conf)
    else:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")

    return pts3d, conf_out
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def inverse_log_transform(y):
    """
    Apply inverse log transform: sign(y) * (exp(|y|) - 1)

    Args:
        y: Input tensor

    Returns:
        Transformed tensor
    """
    # expm1 computes exp(v) - 1 in a single call.
    magnitude = torch.expm1(torch.abs(y))
    return torch.sign(y) * magnitude
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|