Spaces:
Runtime error
Runtime error
Commit ·
71466b0
1
Parent(s): 407c655
Add dust3r source for HF (no binaries)
Browse files- .gitignore +6 -1
- app.py +5 -0
- dust3r +0 -1
- dust3r/.gitignore +132 -0
- dust3r/.gitmodules +3 -0
- dust3r/LICENSE +7 -0
- dust3r/NOTICE +13 -0
- dust3r/README.md +299 -0
- dust3r/datasets_preprocess/path_to_root.py +13 -0
- dust3r/datasets_preprocess/preprocess_co3d.py +295 -0
- dust3r/demo.py +283 -0
- dust3r/dust3r/__init__.py +2 -0
- dust3r/dust3r/cloud_opt/__init__.py +29 -0
- dust3r/dust3r/cloud_opt/base_opt.py +375 -0
- dust3r/dust3r/cloud_opt/commons.py +90 -0
- dust3r/dust3r/cloud_opt/init_im_poses.py +312 -0
- dust3r/dust3r/cloud_opt/optimizer.py +230 -0
- dust3r/dust3r/cloud_opt/pair_viewer.py +125 -0
- dust3r/dust3r/datasets/__init__.py +42 -0
- dust3r/dust3r/datasets/base/__init__.py +2 -0
- dust3r/dust3r/datasets/base/base_stereo_view_dataset.py +220 -0
- dust3r/dust3r/datasets/base/batched_sampler.py +74 -0
- dust3r/dust3r/datasets/base/easy_dataset.py +157 -0
- dust3r/dust3r/datasets/co3d.py +146 -0
- dust3r/dust3r/datasets/utils/__init__.py +2 -0
- dust3r/dust3r/datasets/utils/cropping.py +119 -0
- dust3r/dust3r/datasets/utils/transforms.py +11 -0
- dust3r/dust3r/heads/__init__.py +19 -0
- dust3r/dust3r/heads/dpt_head.py +115 -0
- dust3r/dust3r/heads/linear_head.py +41 -0
- dust3r/dust3r/heads/postprocess.py +58 -0
- dust3r/dust3r/image_pairs.py +83 -0
- dust3r/dust3r/inference.py +165 -0
- dust3r/dust3r/losses.py +297 -0
- dust3r/dust3r/model.py +166 -0
- dust3r/dust3r/optim_factory.py +14 -0
- dust3r/dust3r/patch_embed.py +70 -0
- dust3r/dust3r/post_process.py +60 -0
- dust3r/dust3r/utils/__init__.py +2 -0
- dust3r/dust3r/utils/device.py +76 -0
- dust3r/dust3r/utils/geometry.py +361 -0
- dust3r/dust3r/utils/image.py +104 -0
- dust3r/dust3r/utils/misc.py +121 -0
- dust3r/dust3r/utils/path_to_croco.py +19 -0
- dust3r/dust3r/viz.py +320 -0
- dust3r/requirements.txt +12 -0
- dust3r/train.py +383 -0
- setup.sh +6 -2
.gitignore
CHANGED
|
@@ -16,5 +16,10 @@ results/
|
|
| 16 |
# Gradio
|
| 17 |
.gradio/
|
| 18 |
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
.DS_Store
|
|
|
|
| 16 |
# Gradio
|
| 17 |
.gradio/
|
| 18 |
|
| 19 |
+
dust3r/
|
| 20 |
+
results/
|
| 21 |
+
dust3r/assets/*.jpg
|
| 22 |
+
dust3r/croco/assets/*.png
|
| 23 |
+
results/object.glb
|
| 24 |
+
|
| 25 |
.DS_Store
|
app.py
CHANGED
|
@@ -1,6 +1,11 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
sys.path.append(os.path.join(os.path.dirname(__file__), 'dust3r'))
|
| 6 |
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import torch
|
| 4 |
+
import subprocess
|
| 5 |
+
|
| 6 |
+
if not os.path.exists('dust3r/.git'):
|
| 7 |
+
print("Initializing dust3r submodule...")
|
| 8 |
+
subprocess.run(['git', 'submodule', 'update', '--init', '--recursive'], check=False)
|
| 9 |
|
| 10 |
sys.path.append(os.path.join(os.path.dirname(__file__), 'dust3r'))
|
| 11 |
|
dust3r
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
Subproject commit 78e55fd11ef6d838fc3d8c5c5a52b32eac426e09
|
|
|
|
|
|
dust3r/.gitignore
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/
|
| 2 |
+
checkpoints/
|
| 3 |
+
|
| 4 |
+
# Byte-compiled / optimized / DLL files
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
*$py.class
|
| 8 |
+
|
| 9 |
+
# C extensions
|
| 10 |
+
*.so
|
| 11 |
+
|
| 12 |
+
# Distribution / packaging
|
| 13 |
+
.Python
|
| 14 |
+
build/
|
| 15 |
+
develop-eggs/
|
| 16 |
+
dist/
|
| 17 |
+
downloads/
|
| 18 |
+
eggs/
|
| 19 |
+
.eggs/
|
| 20 |
+
lib/
|
| 21 |
+
lib64/
|
| 22 |
+
parts/
|
| 23 |
+
sdist/
|
| 24 |
+
var/
|
| 25 |
+
wheels/
|
| 26 |
+
pip-wheel-metadata/
|
| 27 |
+
share/python-wheels/
|
| 28 |
+
*.egg-info/
|
| 29 |
+
.installed.cfg
|
| 30 |
+
*.egg
|
| 31 |
+
MANIFEST
|
| 32 |
+
|
| 33 |
+
# PyInstaller
|
| 34 |
+
# Usually these files are written by a python script from a template
|
| 35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 36 |
+
*.manifest
|
| 37 |
+
*.spec
|
| 38 |
+
|
| 39 |
+
# Installer logs
|
| 40 |
+
pip-log.txt
|
| 41 |
+
pip-delete-this-directory.txt
|
| 42 |
+
|
| 43 |
+
# Unit test / coverage reports
|
| 44 |
+
htmlcov/
|
| 45 |
+
.tox/
|
| 46 |
+
.nox/
|
| 47 |
+
.coverage
|
| 48 |
+
.coverage.*
|
| 49 |
+
.cache
|
| 50 |
+
nosetests.xml
|
| 51 |
+
coverage.xml
|
| 52 |
+
*.cover
|
| 53 |
+
*.py,cover
|
| 54 |
+
.hypothesis/
|
| 55 |
+
.pytest_cache/
|
| 56 |
+
|
| 57 |
+
# Translations
|
| 58 |
+
*.mo
|
| 59 |
+
*.pot
|
| 60 |
+
|
| 61 |
+
# Django stuff:
|
| 62 |
+
*.log
|
| 63 |
+
local_settings.py
|
| 64 |
+
db.sqlite3
|
| 65 |
+
db.sqlite3-journal
|
| 66 |
+
|
| 67 |
+
# Flask stuff:
|
| 68 |
+
instance/
|
| 69 |
+
.webassets-cache
|
| 70 |
+
|
| 71 |
+
# Scrapy stuff:
|
| 72 |
+
.scrapy
|
| 73 |
+
|
| 74 |
+
# Sphinx documentation
|
| 75 |
+
docs/_build/
|
| 76 |
+
|
| 77 |
+
# PyBuilder
|
| 78 |
+
target/
|
| 79 |
+
|
| 80 |
+
# Jupyter Notebook
|
| 81 |
+
.ipynb_checkpoints
|
| 82 |
+
|
| 83 |
+
# IPython
|
| 84 |
+
profile_default/
|
| 85 |
+
ipython_config.py
|
| 86 |
+
|
| 87 |
+
# pyenv
|
| 88 |
+
.python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 98 |
+
__pypackages__/
|
| 99 |
+
|
| 100 |
+
# Celery stuff
|
| 101 |
+
celerybeat-schedule
|
| 102 |
+
celerybeat.pid
|
| 103 |
+
|
| 104 |
+
# SageMath parsed files
|
| 105 |
+
*.sage.py
|
| 106 |
+
|
| 107 |
+
# Environments
|
| 108 |
+
.env
|
| 109 |
+
.venv
|
| 110 |
+
env/
|
| 111 |
+
venv/
|
| 112 |
+
ENV/
|
| 113 |
+
env.bak/
|
| 114 |
+
venv.bak/
|
| 115 |
+
|
| 116 |
+
# Spyder project settings
|
| 117 |
+
.spyderproject
|
| 118 |
+
.spyproject
|
| 119 |
+
|
| 120 |
+
# Rope project settings
|
| 121 |
+
.ropeproject
|
| 122 |
+
|
| 123 |
+
# mkdocs documentation
|
| 124 |
+
/site
|
| 125 |
+
|
| 126 |
+
# mypy
|
| 127 |
+
.mypy_cache/
|
| 128 |
+
.dmypy.json
|
| 129 |
+
dmypy.json
|
| 130 |
+
|
| 131 |
+
# Pyre type checker
|
| 132 |
+
.pyre/
|
dust3r/.gitmodules
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "croco"]
|
| 2 |
+
path = croco
|
| 3 |
+
url = https://github.com/naver/croco
|
dust3r/LICENSE
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
|
| 2 |
+
|
| 3 |
+
A summary of the CC BY-NC-SA 4.0 license is located here:
|
| 4 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 5 |
+
|
| 6 |
+
The CC BY-NC-SA 4.0 license is located here:
|
| 7 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
dust3r/NOTICE
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DUSt3R
|
| 2 |
+
Copyright 2024-present NAVER Corp.
|
| 3 |
+
|
| 4 |
+
This project contains subcomponents with separate copyright notices and license terms.
|
| 5 |
+
Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
|
| 6 |
+
|
| 7 |
+
====
|
| 8 |
+
|
| 9 |
+
naver/croco
|
| 10 |
+
https://github.com/naver/croco/
|
| 11 |
+
|
| 12 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0
|
| 13 |
+
|
dust3r/README.md
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DUSt3R
|
| 2 |
+
|
| 3 |
+
Official implementation of `DUSt3R: Geometric 3D Vision Made Easy`
|
| 4 |
+
[[Project page](https://dust3r.europe.naverlabs.com/)], [[DUSt3R arxiv](https://arxiv.org/abs/2312.14132)]
|
| 5 |
+
|
| 6 |
+

|
| 7 |
+
|
| 8 |
+

|
| 9 |
+
|
| 10 |
+
```bibtex
|
| 11 |
+
@misc{wang2023dust3r,
|
| 12 |
+
title={DUSt3R: Geometric 3D Vision Made Easy},
|
| 13 |
+
author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},
|
| 14 |
+
year={2023},
|
| 15 |
+
eprint={2312.14132},
|
| 16 |
+
archivePrefix={arXiv},
|
| 17 |
+
primaryClass={cs.CV}
|
| 18 |
+
}
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## Table of Contents
|
| 22 |
+
- [DUSt3R](#dust3r)
|
| 23 |
+
- [License](#license)
|
| 24 |
+
- [Get Started](#get-started)
|
| 25 |
+
- [Installation](#installation)
|
| 26 |
+
- [Checkpoints](#checkpoints)
|
| 27 |
+
- [Interactive demo](#interactive-demo)
|
| 28 |
+
- [Usage](#usage)
|
| 29 |
+
- [Training](#training)
|
| 30 |
+
- [Demo](#demo)
|
| 31 |
+
- [Our Hyperparameters](#our-hyperparameters)
|
| 32 |
+
|
| 33 |
+
## License
|
| 34 |
+
The code is distributed under the CC BY-NC-SA 4.0 License. See See [LICENSE](LICENSE) for more information.
|
| 35 |
+
```python
|
| 36 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 37 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## Get Started
|
| 41 |
+
|
| 42 |
+
### Installation
|
| 43 |
+
|
| 44 |
+
1. Clone DUSt3R
|
| 45 |
+
```bash
|
| 46 |
+
git clone --recursive https://github.com/naver/dust3r
|
| 47 |
+
cd dust3r
|
| 48 |
+
# if you have already cloned dust3r:
|
| 49 |
+
# git submodule update --init --recursive
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
2. Create the environment, here we show an example using conda.
|
| 53 |
+
```bash
|
| 54 |
+
conda create -n dust3r python=3.11 cmake=3.14.0
|
| 55 |
+
conda activate dust3r
|
| 56 |
+
conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia # use the correct version of cuda for your system
|
| 57 |
+
pip install -r requirements.txt
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
3. Optional, compile the cuda kernels for RoPE (as in CroCo v2)
|
| 62 |
+
```bash
|
| 63 |
+
# DUST3R relies on RoPE positional embeddings for which you can compile some cuda kernels for faster runtime.
|
| 64 |
+
cd croco/models/curope/
|
| 65 |
+
python setup.py build_ext --inplace
|
| 66 |
+
cd ../../../
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
4. Download pre-trained model
|
| 70 |
+
```bash
|
| 71 |
+
mkdir -p checkpoints/
|
| 72 |
+
wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints/
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
### Checkpoints
|
| 76 |
+
|
| 77 |
+
We provide several pre-trained models:
|
| 78 |
+
|
| 79 |
+
| Modelname | Training resolutions | Head | Encoder | Decoder |
|
| 80 |
+
|-------------|----------------------|------|---------|---------|
|
| 81 |
+
| [`DUSt3R_ViTLarge_BaseDecoder_224_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth) | 224x224 | Linear | ViT-L | ViT-B |
|
| 82 |
+
| [`DUSt3R_ViTLarge_BaseDecoder_512_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_linear.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | Linear | ViT-L | ViT-B |
|
| 83 |
+
| [`DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | DPT | ViT-L | ViT-B |
|
| 84 |
+
|
| 85 |
+
You can check the hyperparameters we used to train these models in the [section: Our Hyperparameters](#our-hyperparameters)
|
| 86 |
+
|
| 87 |
+
### Interactive demo
|
| 88 |
+
In this demo, you should be able run DUSt3R on your machine to reconstruct a scene.
|
| 89 |
+
First select images that depicts the same scene.
|
| 90 |
+
|
| 91 |
+
You can ajust the global alignment schedule and its number of iterations.
|
| 92 |
+
Note: if you selected one or two images, the global alignment procedure will be skipped (mode=GlobalAlignerMode.PairViewer)
|
| 93 |
+
Hit "Run" and wait.
|
| 94 |
+
When the global alignment ends, the reconstruction appears.
|
| 95 |
+
Use the slider "min_conf_thr" to show or remove low confidence areas.
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
python3 demo.py --weights checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
|
| 99 |
+
|
| 100 |
+
# Use --image_size to select the correct resolution for your checkpoint. 512 (default) or 224
|
| 101 |
+
# Use --local_network to make it accessible on the local network, or --server_name to specify the url manually
|
| 102 |
+
# Use --server_port to change the port, by default it will search for an available port starting at 7860
|
| 103 |
+
# Use --device to use a different device, by default it's "cuda"
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+

|
| 107 |
+
|
| 108 |
+
## Usage
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
from dust3r.inference import inference, load_model
|
| 112 |
+
from dust3r.utils.image import load_images
|
| 113 |
+
from dust3r.image_pairs import make_pairs
|
| 114 |
+
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
| 115 |
+
|
| 116 |
+
if __name__ == '__main__':
|
| 117 |
+
model_path = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
|
| 118 |
+
device = 'cuda'
|
| 119 |
+
batch_size = 1
|
| 120 |
+
schedule = 'cosine'
|
| 121 |
+
lr = 0.01
|
| 122 |
+
niter = 300
|
| 123 |
+
|
| 124 |
+
model = load_model(model_path, device)
|
| 125 |
+
# load_images can take a list of images or a directory
|
| 126 |
+
images = load_images(['croco/assets/Chateau1.png', 'croco/assets/Chateau2.png'], size=512)
|
| 127 |
+
pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
|
| 128 |
+
output = inference(pairs, model, device, batch_size=batch_size)
|
| 129 |
+
|
| 130 |
+
# at this stage, you have the raw dust3r predictions
|
| 131 |
+
view1, pred1 = output['view1'], output['pred1']
|
| 132 |
+
view2, pred2 = output['view2'], output['pred2']
|
| 133 |
+
# here, view1, pred1, view2, pred2 are dicts of lists of len(2)
|
| 134 |
+
# -> because we symmetrize we have (im1, im2) and (im2, im1) pairs
|
| 135 |
+
# in each view you have:
|
| 136 |
+
# an integer image identifier: view1['idx'] and view2['idx']
|
| 137 |
+
# the img: view1['img'] and view2['img']
|
| 138 |
+
# the image shape: view1['true_shape'] and view2['true_shape']
|
| 139 |
+
# an instance string output by the dataloader: view1['instance'] and view2['instance']
|
| 140 |
+
# pred1 and pred2 contains the confidence values: pred1['conf'] and pred2['conf']
|
| 141 |
+
# pred1 contains 3D points for view1['img'] in view1['img'] space: pred1['pts3d']
|
| 142 |
+
# pred2 contains 3D points for view2['img'] in view1['img'] space: pred2['pts3d_in_other_view']
|
| 143 |
+
|
| 144 |
+
# next we'll use the global_aligner to align the predictions
|
| 145 |
+
# depending on your task, you may be fine with the raw output and not need it
|
| 146 |
+
# with only two input images, you could use GlobalAlignerMode.PairViewer: it would just convert the output
|
| 147 |
+
# if using GlobalAlignerMode.PairViewer, no need to run compute_global_alignment
|
| 148 |
+
scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
|
| 149 |
+
loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
|
| 150 |
+
|
| 151 |
+
# retrieve useful values from scene:
|
| 152 |
+
imgs = scene.imgs
|
| 153 |
+
focals = scene.get_focals()
|
| 154 |
+
poses = scene.get_im_poses()
|
| 155 |
+
pts3d = scene.get_pts3d()
|
| 156 |
+
confidence_masks = scene.get_masks()
|
| 157 |
+
|
| 158 |
+
# visualize reconstruction
|
| 159 |
+
scene.show()
|
| 160 |
+
|
| 161 |
+
# find 2D-2D matches between the two images
|
| 162 |
+
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
|
| 163 |
+
pts2d_list, pts3d_list = [], []
|
| 164 |
+
for i in range(2):
|
| 165 |
+
conf_i = confidence_masks[i].cpu().numpy()
|
| 166 |
+
pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i]) # imgs[i].shape[:2] = (H, W)
|
| 167 |
+
pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
|
| 168 |
+
reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)
|
| 169 |
+
print(f'found {num_matches} matches')
|
| 170 |
+
matches_im1 = pts2d_list[1][reciprocal_in_P2]
|
| 171 |
+
matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
|
| 172 |
+
|
| 173 |
+
# visualize a few matches
|
| 174 |
+
import numpy as np
|
| 175 |
+
from matplotlib import pyplot as pl
|
| 176 |
+
n_viz = 10
|
| 177 |
+
match_idx_to_viz = np.round(np.linspace(0, num_matches-1, n_viz)).astype(int)
|
| 178 |
+
viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
|
| 179 |
+
|
| 180 |
+
H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
|
| 181 |
+
img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
|
| 182 |
+
img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
|
| 183 |
+
img = np.concatenate((img0, img1), axis=1)
|
| 184 |
+
pl.figure()
|
| 185 |
+
pl.imshow(img)
|
| 186 |
+
cmap = pl.get_cmap('jet')
|
| 187 |
+
for i in range(n_viz):
|
| 188 |
+
(x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
|
| 189 |
+
pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)
|
| 190 |
+
pl.show(block=True)
|
| 191 |
+
|
| 192 |
+
```
|
| 193 |
+

|
| 194 |
+
|
| 195 |
+
## Training
|
| 196 |
+
In this section, we present propose a short demonstration to get started with training DUSt3R. At the moment, we didn't release the training datasets, so we're going to download and prepare a subset of [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE) and launch the training code on it.
|
| 197 |
+
The demo model will be trained for a few epochs on a very small dataset. It will not be very good.
|
| 198 |
+
|
| 199 |
+
### Demo
|
| 200 |
+
|
| 201 |
+
```bash
|
| 202 |
+
|
| 203 |
+
# download and prepare the co3d subset
|
| 204 |
+
mkdir -p data/co3d_subset
|
| 205 |
+
cd data/co3d_subset
|
| 206 |
+
git clone https://github.com/facebookresearch/co3d
|
| 207 |
+
cd co3d
|
| 208 |
+
python3 ./co3d/download_dataset.py --download_folder ../ --single_sequence_subset
|
| 209 |
+
rm ../*.zip
|
| 210 |
+
cd ../../..
|
| 211 |
+
|
| 212 |
+
python3 datasets_preprocess/preprocess_co3d.py --co3d_dir data/co3d_subset --output_dir data/co3d_subset_processed --single_sequence_subset
|
| 213 |
+
|
| 214 |
+
# download the pretrained croco v2 checkpoint
|
| 215 |
+
mkdir -p checkpoints/
|
| 216 |
+
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth -P checkpoints/
|
| 217 |
+
|
| 218 |
+
# the training of dust3r is done in 3 steps.
|
| 219 |
+
# for this example we'll do fewer epochs, for the actual hyperparameters we used in the paper, see the next section: "Our Hyperparameters"
|
| 220 |
+
# step 1 - train dust3r for 224 resolution
|
| 221 |
+
torchrun --nproc_per_node=4 train.py \
|
| 222 |
+
--train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter)" \
|
| 223 |
+
--test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=224, seed=777)" \
|
| 224 |
+
--model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 225 |
+
--train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 226 |
+
--test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
|
| 227 |
+
--pretrained checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth \
|
| 228 |
+
--lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 16 --accum_iter 1 \
|
| 229 |
+
--save_freq 1 --keep_freq 5 --eval_freq 1 \
|
| 230 |
+
--output_dir checkpoints/dust3r_demo_224
|
| 231 |
+
|
| 232 |
+
# step 2 - train dust3r for 512 resolution
|
| 233 |
+
torchrun --nproc_per_node=4 train.py \
|
| 234 |
+
--train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
|
| 235 |
+
--test_dataset="100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
|
| 236 |
+
--model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 237 |
+
--train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 238 |
+
--test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
|
| 239 |
+
--pretrained='checkpoints/dust3r_demo_224/checkpoint-best.pth' \
|
| 240 |
+
--lr=0.0001 --min_lr=1e-06 --warmup_epochs 1 --epochs 10 --batch_size 4 --accum_iter 4 \
|
| 241 |
+
--save_freq 1 --keep_freq 5 --eval_freq 1 \
|
| 242 |
+
--output_dir checkpoints/dust3r_demo_512
|
| 243 |
+
|
| 244 |
+
# step 3 - train dust3r for 512 resolution with dpt
|
| 245 |
+
torchrun --nproc_per_node=4 train.py \
|
| 246 |
+
--train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
|
| 247 |
+
--test_dataset="100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
|
| 248 |
+
--model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 249 |
+
--train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 250 |
+
--test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
|
| 251 |
+
--pretrained='checkpoints/dust3r_demo_512/checkpoint-best.pth' \
|
| 252 |
+
--lr=0.0001 --min_lr=1e-06 --warmup_epochs 1 --epochs 10 --batch_size 2 --accum_iter 8 \
|
| 253 |
+
--save_freq 1 --keep_freq 5 --eval_freq 1 \
|
| 254 |
+
--output_dir checkpoints/dust3r_demo_512dpt
|
| 255 |
+
|
| 256 |
+
```
|
| 257 |
+
|
| 258 |
+
### Our Hyperparameters
|
| 259 |
+
We didn't release the training datasets, but here are the commands we used for training our models:
|
| 260 |
+
|
| 261 |
+
```bash
|
| 262 |
+
# NOTE: ROOT path omitted for datasets
|
| 263 |
+
# 224 linear
|
| 264 |
+
torchrun --nproc_per_node 4 train.py \
|
| 265 |
+
--train_dataset=" + 100_000 @ Habitat512(1_000_000, split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ BlendedMVS(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ MegaDepthDense(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ ARKitScenes(aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ Co3d_v3(split='train', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ ScanNetpp(split='train', aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ Waymo(aug_crop=128, resolution=224, transform=ColorJitter) " \
|
| 266 |
+
--test_dataset=" Habitat512(1_000, split='val', resolution=224, seed=777) + 1_000 @ BlendedMVS(split='val', resolution=224, seed=777) + 1_000 @ MegaDepthDense(split='val', resolution=224, seed=777) + 1_000 @ Co3d_v3(split='test', mask_bg='rand', resolution=224, seed=777) " \
|
| 267 |
+
--train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 268 |
+
--test_criterion='Regr3D_ScaleShiftInv(L21, gt_scale=True)' \
|
| 269 |
+
--model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 270 |
+
--pretrained="checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \
|
| 271 |
+
--lr=0.0001 --min_lr=1e-06 --warmup_epochs=10 --epochs=100 --batch_size=16 --accum_iter=1 \
|
| 272 |
+
--save_freq=5 --keep_freq=10 --eval_freq=1 \
|
| 273 |
+
--output_dir='checkpoints/dust3r_224'
|
| 274 |
+
|
| 275 |
+
# 512 linear
|
| 276 |
+
torchrun --nproc_per_node 8 train.py \
|
| 277 |
+
--train_dataset=" + 10_000 @ Habitat512(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepthDense(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d_v3(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Waymo(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
|
| 278 |
+
--test_dataset=" Habitat512(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepthDense(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d_v3(split='test', resolution=(512,384), seed=777) " \
|
| 279 |
+
--train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 280 |
+
--test_criterion='Regr3D_ScaleShiftInv(L21, gt_scale=True)' \
|
| 281 |
+
--model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 282 |
+
--pretrained='checkpoints/dust3r_224/checkpoint-best.pth' \
|
| 283 |
+
--lr=0.0001 --min_lr=1e-06 --warmup_epochs=20 --epochs=200 --batch_size=4 --accum_iter=2 \
|
| 284 |
+
--save_freq=10 --keep_freq=10 --eval_freq=1 --print_freq=10 \
|
| 285 |
+
--output_dir='checkpoints/dust3r_512'
|
| 286 |
+
|
| 287 |
+
# 512 dpt
|
| 288 |
+
torchrun --nproc_per_node 8 train.py \
|
| 289 |
+
--train_dataset=" + 10_000 @ Habitat512(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepthDense(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d_v3(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Waymo(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
|
| 290 |
+
--test_dataset=" Habitat512(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepthDense(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d_v3(split='test', resolution=(512,384), seed=777) " \
|
| 291 |
+
--train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
|
| 292 |
+
--test_criterion='Regr3D_ScaleShiftInv(L21, gt_scale=True)' \
|
| 293 |
+
--model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
|
| 294 |
+
--pretrained='checkpoints/dust3r_512/checkpoint-best.pth' \
|
| 295 |
+
--lr=0.0001 --min_lr=1e-06 --warmup_epochs=15 --epochs=90 --batch_size=2 --accum_iter=4 \
|
| 296 |
+
--save_freq=5 --keep_freq=10 --eval_freq=1 --print_freq=10 \
|
| 297 |
+
--output_dir='checkpoints/dust3r_512dpt'
|
| 298 |
+
|
| 299 |
+
```
|
dust3r/datasets_preprocess/path_to_root.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUSt3R repo root import
# --------------------------------------------------------

import sys
import os.path as path
# Absolute, normalized directory of this file (datasets_preprocess/).
HERE_PATH = path.normpath(path.dirname(__file__))
# Repository root is one directory up from datasets_preprocess/.
DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../'))
# workaround for sibling import: prepend the repo root to sys.path so that
# scripts run directly from this directory can `import dust3r...`
sys.path.insert(0, DUST3R_REPO_PATH)
|
dust3r/datasets_preprocess/preprocess_co3d.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 3 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 4 |
+
#
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
# Script to pre-process the CO3D dataset.
|
| 7 |
+
# Usage:
|
| 8 |
+
# python3 datasets_preprocess/preprocess_co3d.py --co3d_dir /path/to/co3d
|
| 9 |
+
# --------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import random
|
| 13 |
+
import gzip
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
import os.path as osp
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import PIL.Image
|
| 20 |
+
import numpy as np
|
| 21 |
+
import cv2
|
| 22 |
+
|
| 23 |
+
from tqdm.auto import tqdm
|
| 24 |
+
import matplotlib.pyplot as plt
|
| 25 |
+
|
| 26 |
+
import path_to_root # noqa
|
| 27 |
+
import dust3r.datasets.utils.cropping as cropping # noqa
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# CO3D object categories handled by this preprocessing script.
CATEGORIES = [
    "apple", "backpack", "ball", "banana", "baseballbat", "baseballglove",
    "bench", "bicycle", "book", "bottle", "bowl", "broccoli", "cake", "car", "carrot",
    "cellphone", "chair", "couch", "cup", "donut", "frisbee", "hairdryer", "handbag",
    "hotdog", "hydrant", "keyboard", "kite", "laptop", "microwave",
    "motorcycle",
    "mouse", "orange", "parkingmeter", "pizza", "plant", "remote", "sandwich",
    "skateboard", "stopsign",
    "suitcase", "teddybear", "toaster", "toilet", "toybus",
    "toyplane", "toytrain", "toytruck", "tv",
    "umbrella", "vase", "wineglass",
]
# Stable category -> index mapping, used to derive a per-category random seed.
CATEGORIES_IDX = {cat: i for i, cat in enumerate(CATEGORIES)}  # for seeding

# Categories used for the single-sequence ("manyview") subset; three categories
# are excluded from the full list above.
SINGLE_SEQUENCE_CATEGORIES = sorted(set(CATEGORIES) - set(["microwave", "stopsign", "tv"]))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_parser():
    """Build the command-line argument parser for the CO3D preprocessing script.

    Returns:
        argparse.ArgumentParser: parser with all preprocessing options;
        only --co3d_dir is mandatory.
    """
    p = argparse.ArgumentParser()
    # restrict processing to a single category (default: all categories)
    p.add_argument("--category", type=str, default=None)
    p.add_argument('--single_sequence_subset', default=False, action='store_true',
                   help="prepare the single_sequence_subset instead.")
    p.add_argument("--output_dir", type=str, default="data/co3d_processed")
    p.add_argument("--co3d_dir", type=str, required=True)
    p.add_argument("--num_sequences_per_object", type=int, default=50)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--min_quality", type=float, default=0.5, help="Minimum viewpoint quality score.")

    p.add_argument("--img_size", type=int, default=512,
                   help=("lower dimension will be >= img_size * 3/4, and max dimension will be >= img_size"))
    return p
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def convert_ndc_to_pinhole(focal_length, principal_point, image_size):
    """Convert NDC-space camera parameters to a pinhole intrinsics matrix.

    Args:
        focal_length: NDC focal length, sequence of 2 floats (fx, fy).
        principal_point: NDC principal point, sequence of 2 floats.
        image_size: image size as (height, width).

    Returns:
        np.ndarray: 3x3 float32 intrinsics matrix K in pixel units.
    """
    f_ndc = np.array(focal_length)
    pp_ndc = np.array(principal_point)
    # note: image_size is (H, W); intrinsics work in (W, H) order
    wh = np.array([image_size[1], image_size[0]])
    half_wh = wh / 2
    # NDC units are scaled by half the smaller image dimension
    ndc_to_px = half_wh.min()
    cx, cy = half_wh - pp_ndc * ndc_to_px
    fx, fy = f_ndc * ndc_to_px
    return np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def opencv_from_cameras_projection(R, T, focal, p0, image_size):
    """Convert a PyTorch3D-style camera (NDC focal / principal point) to
    OpenCV convention.

    Args:
        R, T: numpy rotation matrix (3,3) and translation (3,).
        focal, p0: NDC focal length (2,) and principal point (2,).
        image_size: numpy array (height, width).

    Returns:
        (R, tvec, camera_matrix) torch tensors, batch dimension removed.
    """
    # promote everything to batched torch tensors (batch of 1)
    R = torch.from_numpy(R)[None, :, :]
    T = torch.from_numpy(T)[None, :]
    focal = torch.from_numpy(focal)[None, :]
    p0 = torch.from_numpy(p0)[None, :]
    image_size = torch.from_numpy(image_size)[None, :]

    # clone before the in-place sign flips below: from_numpy shares memory
    # with the caller's arrays, so this avoids mutating the inputs
    R_pytorch3d = R.clone()
    T_pytorch3d = T.clone()
    focal_pytorch3d = focal
    p0_pytorch3d = p0
    # flip the first two axes to move between axis conventions
    T_pytorch3d[:, :2] *= -1
    R_pytorch3d[:, :, :2] *= -1
    tvec = T_pytorch3d
    # transpose the rotation (PyTorch3D stores the transpose of OpenCV's R)
    R = R_pytorch3d.permute(0, 2, 1)

    # Retype the image_size correctly and flip to width, height.
    image_size_wh = image_size.to(R).flip(dims=(1,))

    # NDC to screen conversion.
    # NDC units correspond to half the smaller image dimension in pixels
    scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0
    scale = scale.expand(-1, 2)
    c0 = image_size_wh / 2.0

    principal_point = -p0_pytorch3d * scale + c0
    focal_length = focal_pytorch3d * scale

    # assemble the 3x3 pinhole intrinsics matrix
    camera_matrix = torch.zeros_like(R)
    camera_matrix[:, :2, 2] = principal_point
    camera_matrix[:, 2, 2] = 1.0
    camera_matrix[:, 0, 0] = focal_length[:, 0]
    camera_matrix[:, 1, 1] = focal_length[:, 1]
    # drop the singleton batch dimension
    return R[0], tvec[0], camera_matrix[0]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_set_list(category_dir, split, is_single_sequence_subset=False):
    """Collect the (sequence_name, frame_number, filepath) entries for a split.

    Scans <category_dir>/set_lists for the relevant json files
    ("manyview_dev" files for the single-sequence subset, "fewview_train"
    files otherwise) and concatenates their entries for `split`.

    Args:
        category_dir: path to one CO3D category directory.
        split: key inside each set-list json (e.g. 'train' or 'test').
        is_single_sequence_subset: select the manyview_dev subset instead.

    Returns:
        list of [sequence_name, frame_number, filepath] entries.
    """
    listfiles = os.listdir(osp.join(category_dir, "set_lists"))
    if is_single_sequence_subset:
        # not all objects have manyview_dev
        subset_list_files = [f for f in listfiles if "manyview_dev" in f]
    else:
        # fix: this was a pointless f-string (no placeholders)
        subset_list_files = [f for f in listfiles if "fewview_train" in f]

    sequences_all = []
    for subset_list_file in subset_list_files:
        with open(osp.join(category_dir, "set_lists", subset_list_file)) as f:
            subset_lists_data = json.load(f)
        sequences_all.extend(subset_lists_data[split])

    return sequences_all
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def prepare_sequences(category, co3d_dir, output_dir, img_size, split, min_quality, max_num_sequences_per_object,
                      seed, is_single_sequence_subset=False):
    """Pre-process one CO3D category: select good sequences, crop/rescale
    images, depth maps and masks, and save them with camera metadata.

    Args:
        category: CO3D category name.
        co3d_dir: root of the raw CO3D dataset.
        output_dir: root of the processed output.
        img_size: target max dimension (lower dim targets img_size * 3/4).
        split: 'train' or 'test'.
        min_quality: minimum viewpoint_quality_score for a sequence to be kept.
        max_num_sequences_per_object: cap on sequences sampled per category.
        seed: seed for the per-category random sampling.
        is_single_sequence_subset: use the manyview_dev subset.

    Returns:
        dict mapping each selected sequence name to the list of selected
        frame indices.
    """
    random.seed(seed)
    category_dir = osp.join(co3d_dir, category)
    sequences_all = get_set_list(category_dir, split, is_single_sequence_subset)
    sequences_numbers = sorted(set(seq_name for seq_name, _, _ in sequences_all))

    frame_file = osp.join(category_dir, "frame_annotations.jgz")
    sequence_file = osp.join(category_dir, "sequence_annotations.jgz")

    with gzip.open(frame_file, "r") as fin:
        frame_data = json.loads(fin.read())
    with gzip.open(sequence_file, "r") as fin:
        sequence_data = json.loads(fin.read())

    # index frame annotations by sequence name, then frame number
    frame_data_processed = {}
    for f_data in frame_data:
        sequence_name = f_data["sequence_name"]
        frame_data_processed.setdefault(sequence_name, {})[f_data["frame_number"]] = f_data

    # keep only sequences whose viewpoint quality is above the threshold
    good_quality_sequences = set()
    for seq_data in sequence_data:
        if seq_data["viewpoint_quality_score"] > min_quality:
            good_quality_sequences.add(seq_data["sequence_name"])

    sequences_numbers = [seq_name for seq_name in sequences_numbers if seq_name in good_quality_sequences]
    # randomly subsample sequences when there are more than the cap
    if len(sequences_numbers) < max_num_sequences_per_object:
        selected_sequences_numbers = sequences_numbers
    else:
        selected_sequences_numbers = random.sample(sequences_numbers, max_num_sequences_per_object)

    selected_sequences_numbers_dict = {seq_name: [] for seq_name in selected_sequences_numbers}
    sequences_all = [(seq_name, frame_number, filepath)
                     for seq_name, frame_number, filepath in sequences_all
                     if seq_name in selected_sequences_numbers_dict]

    for seq_name, frame_number, filepath in tqdm(sequences_all):
        # frame index is encoded in the filename as frame<NNNNNN>.jpg
        frame_idx = int(filepath.split('/')[-1][5:-4])
        selected_sequences_numbers_dict[seq_name].append(frame_idx)
        mask_path = filepath.replace("images", "masks").replace(".jpg", ".png")
        frame_data = frame_data_processed[seq_name][frame_number]
        focal_length = frame_data["viewpoint"]["focal_length"]
        principal_point = frame_data["viewpoint"]["principal_point"]
        image_size = frame_data["image"]["size"]
        # (removed an unused call to convert_ndc_to_pinhole here: its result
        # was never read; opencv_from_cameras_projection computes everything)
        R, tvec, camera_intrinsics = opencv_from_cameras_projection(np.array(frame_data["viewpoint"]["R"]),
                                                                    np.array(frame_data["viewpoint"]["T"]),
                                                                    np.array(focal_length),
                                                                    np.array(principal_point),
                                                                    np.array(image_size))

        depth_path = os.path.join(co3d_dir, frame_data["depth"]["path"])
        assert frame_data["depth"]["scale_adjustment"] == 1.0
        image_path = os.path.join(co3d_dir, filepath)
        mask_path_full = os.path.join(co3d_dir, mask_path)

        input_rgb_image = PIL.Image.open(image_path).convert('RGB')
        input_mask = plt.imread(mask_path_full)

        with PIL.Image.open(depth_path) as depth_pil:
            # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
            # we cast it to uint16, then reinterpret as float16, then cast to float32
            input_depthmap = (
                np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
                .astype(np.float32)
                .reshape((depth_pil.size[1], depth_pil.size[0])))
        # stack depth + mask so they go through cropping/rescaling together
        depth_mask = np.stack((input_depthmap, input_mask), axis=-1)
        H, W = input_depthmap.shape

        camera_intrinsics = camera_intrinsics.numpy()
        cx, cy = camera_intrinsics[:2, 2].round().astype(int)
        min_margin_x = min(cx, W-cx)
        min_margin_y = min(cy, H-cy)

        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)
        input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap(
            input_rgb_image, depth_mask, camera_intrinsics, crop_bbox)

        # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384
        scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8
        output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)
        if max(output_resolution) < img_size:
            # let's put the max dimension to img_size
            scale_final = (img_size / max(H, W)) + 1e-8
            output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)

        input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap(
            input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution)
        input_depthmap = depth_mask[:, :, 0]
        input_mask = depth_mask[:, :, 1]

        # generate and adjust camera pose: world-to-cam (R, tvec) -> cam-to-world
        camera_pose = np.eye(4, dtype=np.float32)
        camera_pose[:3, :3] = R
        camera_pose[:3, 3] = tvec
        camera_pose = np.linalg.inv(camera_pose)

        # save crop images and depth, metadata
        save_img_path = os.path.join(output_dir, filepath)
        save_depth_path = os.path.join(output_dir, frame_data["depth"]["path"])
        save_mask_path = os.path.join(output_dir, mask_path)
        os.makedirs(os.path.split(save_img_path)[0], exist_ok=True)
        os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True)
        os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True)

        input_rgb_image.save(save_img_path)
        # depth is re-quantized to uint16 over [0, max]; the max is stored in
        # the metadata so the absolute depth can be recovered
        scaled_depth_map = (input_depthmap / np.max(input_depthmap) * 65535).astype(np.uint16)
        cv2.imwrite(save_depth_path, scaled_depth_map)
        cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8))

        # NOTE(review): str.replace substitutes every 'jpg' occurrence in the
        # path, not just the extension — fine as long as no directory is
        # named 'jpg'; confirm against the CO3D layout
        save_meta_path = save_img_path.replace('jpg', 'npz')
        np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics,
                 camera_pose=camera_pose, maximum_depth=np.max(input_depthmap))

    return selected_sequences_numbers_dict
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    # refuse to process the dataset in place
    assert args.co3d_dir != args.output_dir
    # resolve which categories to process
    if args.category is None:
        if args.single_sequence_subset:
            categories = SINGLE_SEQUENCE_CATEGORIES
        else:
            categories = CATEGORIES
    else:
        categories = [args.category]
    os.makedirs(args.output_dir, exist_ok=True)

    for split in ['train', 'test']:
        # global per-split index; its presence marks the split as done
        selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json')
        if os.path.isfile(selected_sequences_path):
            continue

        all_selected_sequences = {}
        for category in categories:
            category_output_dir = osp.join(args.output_dir, category)
            os.makedirs(category_output_dir, exist_ok=True)
            category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json')
            if os.path.isfile(category_selected_sequences_path):
                # category already processed: reuse its saved selection
                with open(category_selected_sequences_path, 'r') as fid:
                    category_selected_sequences = json.load(fid)
            else:
                print(f"Processing {split} - category = {category}")
                # per-category seed keeps sampling reproducible and distinct
                category_selected_sequences = prepare_sequences(
                    category=category,
                    co3d_dir=args.co3d_dir,
                    output_dir=args.output_dir,
                    img_size=args.img_size,
                    split=split,
                    min_quality=args.min_quality,
                    max_num_sequences_per_object=args.num_sequences_per_object,
                    seed=args.seed + CATEGORIES_IDX[category],
                    is_single_sequence_subset=args.single_sequence_subset
                )
                with open(category_selected_sequences_path, 'w') as file:
                    json.dump(category_selected_sequences, file)

            all_selected_sequences[category] = category_selected_sequences
        # write the aggregated index last, so an interrupted run restarts
        # from the per-category files rather than skipping the split
        with open(selected_sequences_path, 'w') as file:
            json.dump(all_selected_sequences, file)
|
dust3r/demo.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 3 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 4 |
+
#
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
# gradio demo
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
import argparse
|
| 9 |
+
import gradio
|
| 10 |
+
import os
|
| 11 |
+
import torch
|
| 12 |
+
import numpy as np
|
| 13 |
+
import tempfile
|
| 14 |
+
import functools
|
| 15 |
+
import trimesh
|
| 16 |
+
import copy
|
| 17 |
+
from scipy.spatial.transform import Rotation
|
| 18 |
+
|
| 19 |
+
from dust3r.inference import inference, load_model
|
| 20 |
+
from dust3r.image_pairs import make_pairs
|
| 21 |
+
from dust3r.utils.image import load_images, rgb
|
| 22 |
+
from dust3r.utils.device import to_numpy
|
| 23 |
+
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
| 24 |
+
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
| 25 |
+
|
| 26 |
+
import matplotlib.pyplot as pl
|
| 27 |
+
# interactive matplotlib mode, so figure creation does not block the server
pl.ion()

torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
# number of image pairs per inference forward pass
batch_size = 1
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_args_parser():
    """Build the command-line argument parser for the gradio demo.

    --local_network and --server_name are mutually exclusive ways of
    choosing the bind address.
    """
    parser = argparse.ArgumentParser()
    # bind-address options: at most one of the two may be given
    url_group = parser.add_mutually_exclusive_group()
    url_group.add_argument("--local_network", action='store_true', default=False,
                           help="make app accessible on local network: address will be set to 0.0.0.0")
    url_group.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1")
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
                                                         "If None, will search for an available port starting at 7860."),
                        default=None)
    parser.add_argument("--weights", type=str, required=True, help="path to the model weights")
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
    return parser
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False, transparent_cams=False):
    """Export a reconstructed scene (points/mesh + camera frusta) to a .glb.

    Args:
        outdir: directory where 'scene.glb' is written.
        imgs: per-view RGB images.
        pts3d: per-view 3D point maps; mask selects the valid points.
        focals, cams2world: per-view focal lengths and cam-to-world poses.
        cam_size: screen size of the camera frusta.
        cam_color: single color, or a per-camera list; None cycles CAM_COLORS.
        as_pointcloud: export a point cloud instead of a mesh.
        transparent_cams: do not texture the camera frusta with the images.

    Returns:
        path of the exported glb file.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    focals = to_numpy(focals)
    cams2world = to_numpy(cams2world)

    scene = trimesh.Scene()

    # full pointcloud
    if as_pointcloud:
        # concatenate the masked points/colors of every view into one cloud
        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
        pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
        scene.add_geometry(pct)
    else:
        # one textured mesh per view, merged into a single trimesh
        meshes = []
        for i in range(len(imgs)):
            meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
        mesh = trimesh.Trimesh(**cat_meshes(meshes))
        scene.add_geometry(mesh)

    # add each camera
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else imgs[i], focals[i],
                      # shape[1::-1] is (width, height) of the image
                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)

    # re-express the scene in the first camera's frame, with OpenGL axes and
    # a 180-degree flip around y so the default viewer orientation is upright
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
    outfile = os.path.join(outdir, 'scene.glb')
    print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)
    return outfile
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_3D_model_from_scene(outdir, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05):
    """
    Extract a 3D model (glb file) from a reconstructed scene.

    Returns the path of the exported glb, or None if no scene is available.
    The remaining parameters are forwarded to _convert_scene_output_to_glb.
    """
    if scene is None:
        return None
    # post processes
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # get optimized values from scene
    rgbimg = scene.imgs
    focals = scene.get_focals().cpu()
    cams2world = scene.get_im_poses().cpu()
    # 3D pointcloud from depthmap, poses and intrinsics
    pts3d = to_numpy(scene.get_pts3d())
    # map the user threshold through the scene's confidence transform before
    # computing the masks (side effect: mutates scene.min_conf_thr)
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    msk = to_numpy(scene.get_masks())
    return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
                                        transparent_cams=transparent_cams, cam_size=cam_size)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def get_reconstructed_scene(outdir, model, device, image_size, filelist, schedule, niter, min_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                            scenegraph_type, winsize, refid):
    """
    From a list of images, run dust3r inference and the global aligner,
    then run get_3D_model_from_scene.

    Returns (scene, glb_path, gallery_images) where gallery_images
    interleaves rgb, depth and confidence renderings for each view.
    """
    imgs = load_images(filelist, size=image_size)
    if len(imgs) == 1:
        # duplicate a single image so there is at least one pair to match
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
    # encode the window size / reference id into the scene-graph spec string
    if scenegraph_type == "swin":
        scenegraph_type = scenegraph_type + "-" + str(winsize)
    elif scenegraph_type == "oneref":
        scenegraph_type = scenegraph_type + "-" + str(refid)

    pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=batch_size)

    # two images can be displayed directly; more require global optimization
    mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
    scene = global_aligner(output, device=device, mode=mode)
    lr = 0.01

    if mode == GlobalAlignerMode.PointCloudOptimizer:
        loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)

    outfile = get_3D_model_from_scene(outdir, scene, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size)

    # also return rgb, depth and confidence imgs
    # depth is normalized with the max value for all images
    # we apply the jet colormap on the confidence maps
    rgbimg = scene.imgs
    depths = to_numpy(scene.get_depthmaps())
    confs = to_numpy([c for c in scene.im_conf])
    cmap = pl.get_cmap('jet')
    depths_max = max([d.max() for d in depths])
    depths = [d/depths_max for d in depths]
    confs_max = max([d.max() for d in confs])
    confs = [cmap(d/confs_max) for d in confs]

    # gallery layout: for each view, [rgb, depth, confidence]
    imgs = []
    for i in range(len(rgbimg)):
        imgs.append(rgbimg[i])
        imgs.append(rgb(depths[i]))
        imgs.append(rgb(confs[i]))

    return scene, outfile, imgs
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
    """Refresh the two scene-graph sliders when inputs or graph type change.

    Args:
        inputfiles: current list of uploaded files (or None).
        winsize, refid: current slider components (values unused; replaced).
        scenegraph_type: one of "complete", "swin", "oneref".

    Returns:
        (winsize, refid) rebuilt gradio.Slider components; only the slider
        relevant to the selected scenegraph_type is visible.
    """
    num_files = len(inputfiles) if inputfiles is not None else 1
    max_winsize = max(1, (num_files - 1)//2)
    # the three original branches built identical sliders and differed only
    # in visibility — compute the visibility flags once instead
    show_winsize = scenegraph_type == "swin"
    show_refid = scenegraph_type == "oneref"
    winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
                            minimum=1, maximum=max_winsize, step=1, visible=show_winsize)
    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=num_files-1, step=1, visible=show_refid)
    return winsize, refid
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def main_demo(tmpdirname, model, device, image_size, server_name, server_port):
    """Build the gradio UI and launch the demo server (blocking).

    Args:
        tmpdirname: directory where glb files are written.
        model: loaded dust3r model.
        device: pytorch device string.
        image_size: input image size (512 or 224).
        server_name, server_port: bind address and port for gradio.
    """
    # bind the fixed arguments so the gradio callbacks only take UI inputs
    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, image_size)
    model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname)
    with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="DUSt3R Demo") as demo:
        # the scene state is saved so that conf_thr, cam_size... can be
        # changed without rerunning the inference
        scene = gradio.State(None)
        gradio.HTML('<h2 style="text-align: center;">DUSt3R Demo</h2>')
        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")
            with gradio.Row():
                schedule = gradio.Dropdown(["linear", "cosine"],
                                           value='linear', label="schedule", info="For global alignment!")
                niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
                                      label="num_iterations", info="For global alignment!")
                scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
                                                  value='complete', label="Scenegraph",
                                                  info="Define how to make pairs",
                                                  interactive=True)
                # both sliders start hidden; set_scenegraph_options shows the
                # relevant one depending on the selected scenegraph type
                winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                        minimum=1, maximum=1, step=1, visible=False)
                refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)

            run_btn = gradio.Button("Run")

            with gradio.Row():
                # adjust the confidence threshold
                min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
                # adjust the camera size in the output pointcloud
                cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
            with gradio.Row():
                as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
                # two post process implemented
                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()
            outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%")

            # events
            # keep the scene-graph sliders consistent with the inputs
            scenegraph_type.change(set_scenegraph_options,
                                   inputs=[inputfiles, winsize, refid, scenegraph_type],
                                   outputs=[winsize, refid])
            inputfiles.change(set_scenegraph_options,
                              inputs=[inputfiles, winsize, refid, scenegraph_type],
                              outputs=[winsize, refid])
            # full reconstruction: inference + global alignment + export
            run_btn.click(fn=recon_fun,
                          inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
                                  mask_sky, clean_depth, transparent_cams, cam_size,
                                  scenegraph_type, winsize, refid],
                          outputs=[scene, outmodel, outgallery])
            # the remaining events only re-export the stored scene
            min_conf_thr.release(fn=model_from_scene_fun,
                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                         clean_depth, transparent_cams, cam_size],
                                 outputs=outmodel)
            cam_size.change(fn=model_from_scene_fun,
                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                    clean_depth, transparent_cams, cam_size],
                            outputs=outmodel)
            as_pointcloud.change(fn=model_from_scene_fun,
                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                         clean_depth, transparent_cams, cam_size],
                                 outputs=outmodel)
            mask_sky.change(fn=model_from_scene_fun,
                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                    clean_depth, transparent_cams, cam_size],
                            outputs=outmodel)
            clean_depth.change(fn=model_from_scene_fun,
                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                       clean_depth, transparent_cams, cam_size],
                               outputs=outmodel)
            transparent_cams.change(model_from_scene_fun,
                                    inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                            clean_depth, transparent_cams, cam_size],
                                    outputs=outmodel)
    demo.launch(share=False, server_name=server_name, server_port=server_port)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()

    # redirect all tempfile usage (including the glb output dir) if requested
    if args.tmp_dir is not None:
        tmp_path = args.tmp_dir
        os.makedirs(tmp_path, exist_ok=True)
        tempfile.tempdir = tmp_path

    # resolve the bind address: explicit name > local network > loopback
    if args.server_name is not None:
        server_name = args.server_name
    else:
        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'

    model = load_model(args.weights, args.device)
    # dust3r will write the 3D model inside tmpdirname
    with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
        print('Outputing stuff in', tmpdirname)
        main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port)
|
dust3r/dust3r/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
dust3r/dust3r/cloud_opt/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# global alignment optimization wrapper function
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
from enum import Enum
|
| 8 |
+
|
| 9 |
+
from .optimizer import PointCloudOptimizer
|
| 10 |
+
from .pair_viewer import PairViewer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class GlobalAlignerMode(Enum):
    """Available global-alignment backends."""
    PointCloudOptimizer = "PointCloudOptimizer"
    PairViewer = "PairViewer"


def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw):
    """Build the scene-alignment module matching `mode` from raw dust3r inference output.

    `dust3r_output` must expose the keys 'view1', 'view2', 'pred1', 'pred2';
    extra keyword arguments are forwarded to the chosen optimizer's constructor.
    """
    # unpack the four inference products
    view1 = dust3r_output['view1']
    view2 = dust3r_output['view2']
    pred1 = dust3r_output['pred1']
    pred2 = dust3r_output['pred2']

    # dispatch on the requested alignment mode
    if mode == GlobalAlignerMode.PointCloudOptimizer:
        return PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
    if mode == GlobalAlignerMode.PairViewer:
        return PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device)
    raise NotImplementedError(f'Unknown mode {mode}')
|
dust3r/dust3r/cloud_opt/base_opt.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Base class for the global alignement procedure
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import roma
|
| 13 |
+
from copy import deepcopy
|
| 14 |
+
import tqdm
|
| 15 |
+
|
| 16 |
+
from dust3r.utils.geometry import inv, geotrf
|
| 17 |
+
from dust3r.utils.device import to_numpy
|
| 18 |
+
from dust3r.utils.image import rgb
|
| 19 |
+
from dust3r.viz import SceneViz, segment_sky, auto_cam_size
|
| 20 |
+
from dust3r.optim_factory import adjust_learning_rate_by_lr
|
| 21 |
+
|
| 22 |
+
from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
|
| 23 |
+
cosine_schedule, linear_schedule, get_conf_trf)
|
| 24 |
+
import dust3r.cloud_opt.init_im_poses as init_fun
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BasePCOptimizer (nn.Module):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)

    Stores the raw pairwise predictions (3D points + confidences) as frozen
    parameters and learns one similarity transform (quaternion + translation
    + log-scale) per edge.  Subclasses supply the per-image parametrization
    (poses, depthmaps, intrinsics) by overriding the NotImplementedError
    accessors below.
    """

    def __init__(self, *args, **kwargs):
        # Two construction modes:
        #   BasePCOptimizer(other) -> copy a fixed list of attributes from `other`
        #   any other signature    -> forwarded to _init_from_views()
        if len(args) == 1 and len(kwargs) == 0:
            other = deepcopy(args[0])
            # NOTE(review): this list says 'conf_thr' while the attribute set in
            # _init_from_views is 'conf_trf', repeats 'pw_adaptors', and indexes
            # `other[k]` -- confirm this copy path is exercised anywhere.
            attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
                        min_conf_thr conf_thr conf_i conf_j im_conf
                        base_scale norm_pw_scale POSE_DIM pw_poses
                        pw_adaptors pw_adaptors has_im_poses rand_pose imgs'''.split()
            self.__dict__.update({k: other[k] for k in attrs})
        else:
            self._init_from_views(*args, **kwargs)

    def _init_from_views(self, view1, view2, pred1, pred2,
                         dist='l1',
                         conf='log',
                         min_conf_thr=3,
                         base_scale=0.5,
                         allow_pw_adaptors=False,
                         pw_break=20,
                         rand_pose=torch.randn,
                         iterationsCount=None):
        """Build the scene graph and all parameters from raw dust3r pair predictions.

        view1/view2 carry per-pair image indices (and optionally images);
        pred1 holds pts3d/conf in view1's frame, pred2 holds
        pts3d_in_other_view/conf (i.e. view2's points expressed in view1's frame).
        """
        super().__init__()
        if not isinstance(view1['idx'], list):
            view1['idx'] = view1['idx'].tolist()
        if not isinstance(view2['idx'], list):
            view2['idx'] = view2['idx'].tolist()
        # one directed edge (i, j) per predicted pair
        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
        # symmetrized <=> every (i, j) also appears as (j, i)
        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
        self.dist = ALL_DISTS[dist]

        self.n_imgs = self._check_edges()

        # input data
        pred1_pts = pred1['pts3d']
        pred2_pts = pred2['pts3d_in_other_view']
        # frozen per-edge point maps, keyed by the 'i_j' string
        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)

        # work in log-scale with conf
        pred1_conf = pred1['conf']
        pred2_conf = pred2['conf']
        self.min_conf_thr = min_conf_thr
        self.conf_trf = get_conf_trf(conf)

        self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
        self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
        # per-image confidence = max over all edges touching that image
        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)

        # pairwise pose parameters
        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7  # quaternion (4) + translation (3); +1 slot below for log-scale
        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))  # pairwise poses
        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # slight xy/z adaptation
        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
        self.has_im_poses = False
        self.rand_pose = rand_pose

        # possibly store images for show_pointcloud
        self.imgs = None
        if 'img' in view1 and 'img' in view2:
            imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
            for v in range(len(self.edges)):
                idx = view1['idx'][v]
                imgs[idx] = view1['img'][v]
                idx = view2['idx'][v]
                imgs[idx] = view2['img'][v]
            self.imgs = rgb(imgs)

    @property
    def n_edges(self):
        # number of pairwise observations
        return len(self.edges)

    @property
    def str_edges(self):
        # edges as 'i_j' string keys (indexing pred_i/pred_j/conf_i/conf_j)
        return [edge_str(i, j) for i, j in self.edges]

    @property
    def imsizes(self):
        # (width, height) per image -- note the swap from the (H, W) imshapes
        return [(w, h) for h, w in self.imshapes]

    @property
    def device(self):
        return next(iter(self.parameters())).device

    def state_dict(self, trainable=True):
        """Split the state into trainable params (default) vs frozen prediction buffers.

        NOTE(review): signature differs from nn.Module.state_dict(); callers must
        use the keyword form.
        """
        all_params = super().state_dict()
        # prefixes below mark the frozen inputs; '!=' selects one side or the other
        return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}

    def load_state_dict(self, data):
        # merge the frozen buffers back in so nn.Module sees a complete dict
        return super().load_state_dict(self.state_dict(trainable=False) | data)

    def _check_edges(self):
        """Verify image indices form a contiguous 0..n-1 range; return n."""
        indices = sorted({i for edge in self.edges for i in edge})
        assert indices == list(range(len(indices))), 'bad pair indices: missing values '
        return len(indices)

    @torch.no_grad()
    def _compute_img_conf(self, pred1_conf, pred2_conf):
        """Per-image confidence map = pixelwise max over all edges touching the image."""
        im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
        for e, (i, j) in enumerate(self.edges):
            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
        return im_conf

    def get_adaptors(self):
        """Per-edge (x, y, z) scale adaptors, exp-mapped and optionally normalized."""
        adapt = self.pw_adaptors
        adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1)  # (scale_xy, scale_xy, scale_z)
        if self.norm_pw_scale:  # normalize so that the product == 1
            adapt = adapt - adapt.mean(dim=1, keepdim=True)
        return (adapt / self.pw_break).exp()

    def _get_poses(self, poses):
        """Decode raw (quat, log-translation) parameters into 4x4 rigid transforms."""
        # normalize rotation
        Q = poses[:, :4]
        T = signed_expm1(poses[:, 4:7])  # undo the signed_log1p applied in _set_pose
        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
        return RT

    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
        """Write a cam-to-world pose (and optional scale) into the raw parameter row."""
        # all poses == cam-to-world
        pose = poses[idx]
        if not (pose.requires_grad or force):
            return pose  # frozen pose: leave untouched unless forced

        if R.shape == (4, 4):
            # R given as a full homogeneous matrix: split it
            assert T is None
            T = R[:3, 3]
            R = R[:3, :3]

        if R is not None:
            pose.data[0:4] = roma.rotmat_to_unitquat(R)
        if T is not None:
            pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is function of scale

        if scale is not None:
            assert poses.shape[-1] in (8, 13)
            pose.data[-1] = np.log(float(scale))  # stored as log-scale
        return pose

    def get_pw_norm_scale_factor(self):
        if self.norm_pw_scale:
            # normalize scales so that things cannot go south
            # we want that exp(scale) ~= self.base_scale
            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
        else:
            return 1  # don't norm scale for known poses

    def get_pw_scale(self):
        scale = self.pw_poses[:, -1].exp()  # (n_edges,)
        scale = scale * self.get_pw_norm_scale_factor()
        return scale

    def get_pw_poses(self):  # cam to world
        RT = self._get_poses(self.pw_poses)
        scaled_RT = RT.clone()
        scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)  # scale the rotation AND translation
        return scaled_RT

    def get_masks(self):
        # boolean per-pixel validity masks from the confidence threshold
        return [(conf > self.min_conf_thr) for conf in self.im_conf]

    def depth_to_pts3d(self):
        # subclass responsibility: world-frame 3D points from depthmaps
        raise NotImplementedError()

    def get_pts3d(self, raw=False):
        """World-frame 3D points; reshaped to (H, W, 3) per image unless raw=True."""
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def _set_focal(self, idx, focal, force=False):
        raise NotImplementedError()

    def get_focals(self):
        raise NotImplementedError()

    def get_known_focal_mask(self):
        raise NotImplementedError()

    def get_principal_points(self):
        raise NotImplementedError()

    def get_conf(self, mode=None):
        """Per-image confidence maps, remapped by `mode` (defaults to the ctor's conf transform)."""
        trf = self.conf_trf if mode is None else get_conf_trf(mode)
        return [trf(c) for c in self.im_conf]

    def get_im_poses(self):
        raise NotImplementedError()

    def _set_depthmap(self, idx, depth, force=False):
        raise NotImplementedError()

    def get_depthmaps(self, raw=False):
        raise NotImplementedError()

    @torch.no_grad()
    def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
        """ Method:
        1) express all 3d points in each camera coordinate frame
        2) if they're in front of a depthmap --> then lower their confidence

        Returns a modified deep copy; `self` is left untouched.
        """
        assert 0 <= tol < 1
        cams = inv(self.get_im_poses())  # world-to-cam
        K = self.get_intrinsics()
        depthmaps = self.get_depthmaps()
        res = deepcopy(self)

        for i, pts3d in enumerate(self.depth_to_pts3d()):
            for j in range(self.n_imgs):
                if i == j:
                    continue

                # project 3dpts in other view
                Hi, Wi = self.imshapes[i]
                Hj, Wj = self.imshapes[j]
                proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
                proj_depth = proj[:, :, 2]
                u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)

                # check which points are actually in the visible cone
                msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
                msk_j = v[msk_i], u[msk_i]

                # find bad points = those in front but less confident
                bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]
                              ) & (res.im_conf[i][msk_i] < res.im_conf[j][msk_j])

                bad_msk_i = msk_i.clone()
                bad_msk_i[msk_i] = bad_points
                res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_(max=max_bad_conf)

        return res

    def forward(self, ret_details=False):
        """Alignment loss: confidence-weighted distance between the globally
        parametrized point clouds and each edge's (transformed) prediction."""
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()
        # pre-compute pixel weights
        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}

        loss = 0
        if ret_details:
            details = -torch.ones((self.n_imgs, self.n_imgs))

        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)
            # distance in image i and j
            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj

            if ret_details:
                details[i, j] = li + lj
        loss /= self.n_edges  # average over all pairs

        if ret_details:
            return loss, details
        return loss

    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
        """Optionally initialize the scene ('mst'/'msp' or 'known_poses'),
        then run the gradient-descent alignment loop."""
        if init is None:
            pass  # keep current (random) initialization
        elif init == 'msp' or init == 'mst':
            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
        elif init == 'known_poses':
            init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP)
        else:
            raise ValueError(f'bad value for {init=}')

        global_alignment_loop(self, **kw)

    @torch.no_grad()
    def mask_sky(self):
        """Return a copy where sky pixels (per segment_sky) get zero confidence."""
        res = deepcopy(self)
        for i in range(self.n_imgs):
            sky = segment_sky(self.imgs[i])
            res.im_conf[i][sky] = 0
        return res

    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
        """Visualize the scene (point clouds + cameras) with SceneViz.

        NOTE(review): `pw_poses` is only bound inside the show_pw_cams branch;
        calling with show_pw_pts3d=True and show_pw_cams=False would raise
        NameError -- confirm intended usage.
        """
        viz = SceneViz()
        if self.imgs is None:
            # no images available: color each point cloud with a random color
            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
            colors = list(map(tuple, colors.tolist()))
            for n in range(self.n_imgs):
                viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
        else:
            viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
            colors = np.random.randint(256, size=(self.n_imgs, 3))

        # camera poses
        im_poses = to_numpy(self.get_im_poses())
        if cam_size is None:
            cam_size = auto_cam_size(im_poses)
        viz.add_cameras(im_poses, self.get_focals(), colors=colors,
                        images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
        if show_pw_cams:
            pw_poses = self.get_pw_poses()
            viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)

        if show_pw_pts3d:
            pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
            viz.add_pointcloud(pts, (128, 0, 128))

        viz.show(**kw)
        return viz
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6, verbose=False):
    """Run gradient-descent global alignment on `net` and return it.

    Parameters
    ----------
    net : BasePCOptimizer
        The scene to optimize; its forward() returns the alignment loss.
    lr, lr_min : float
        Initial and final learning rates of the schedule.
    niter : int
        Number of optimization steps.
    schedule : str
        'cosine' or 'linear' learning-rate decay; anything else raises ValueError.
    verbose : bool
        Print the names of the trainable parameters before optimizing.

    Returns
    -------
    The (optimized in-place) `net`.
    Fix: the original only returned `net` on the no-trainable-params early exit
    and implicitly returned None otherwise; both paths now return `net`.
    """
    params = [p for p in net.parameters() if p.requires_grad]
    if not params:
        return net  # nothing to optimize (e.g. a frozen PairViewer scene)

    if verbose:
        print([name for name, value in net.named_parameters() if value.requires_grad])

    lr_base = lr
    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))

    with tqdm.tqdm(total=niter) as bar:
        while bar.n < bar.total:
            t = bar.n / bar.total  # progress in [0, 1)

            # anneal the learning rate for this step
            if schedule == 'cosine':
                lr = cosine_schedule(t, lr_base, lr_min)
            elif schedule == 'linear':
                lr = linear_schedule(t, lr_base, lr_min)
            else:
                raise ValueError(f'bad lr {schedule=}')
            adjust_learning_rate_by_lr(optimizer, lr)

            optimizer.zero_grad()
            loss = net()
            loss.backward()
            optimizer.step()
            loss = float(loss)
            bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
            bar.update()

    return net
|
dust3r/dust3r/cloud_opt/commons.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# utility functions for global alignment
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def edge_str(i, j):
    """Canonical string key 'i_j' used to index the per-edge dictionaries."""
    return '_'.join((str(i), str(j)))


def i_j_ij(ij):
    """Turn an edge tuple into a (string key, (i, j)) pair."""
    i, j = ij
    return edge_str(i, j), ij


def edge_conf(conf_i, conf_j, edge):
    """Scalar score for an edge: product of the mean confidences of its two sides."""
    score = conf_i[edge].mean() * conf_j[edge].mean()
    return float(score)


def compute_edge_scores(edges, conf_i, conf_j):
    """Map each (i, j) edge to its confidence score.

    `edges` yields (string_key, (i, j)) pairs, e.g. from map(i_j_ij, ...).
    """
    scores = {}
    for key, (i, j) in edges:
        scores[(i, j)] = edge_conf(conf_i, conf_j, key)
    return scores
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def NoGradParamDict(x):
    """Wrap a dict of tensors in an nn.ParameterDict with gradient tracking disabled."""
    assert isinstance(x, dict)
    pdict = nn.ParameterDict(x)
    return pdict.requires_grad_(False)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_imshapes(edges, pred_i, pred_j):
    """Recover every image's (H, W) from the per-edge point maps.

    Each edge e = (i, j) carries a map for image i (pred_i[e]) and one for
    image j (pred_j[e]); all edges touching the same image must agree.
    """
    n_imgs = 1 + max(max(i, j) for (i, j) in edges)
    imshapes = [None] * n_imgs
    for e, (i, j) in enumerate(edges):
        shape_i = tuple(pred_i[e].shape[:2])
        shape_j = tuple(pred_j[e].shape[:2])
        if imshapes[i]:
            assert imshapes[i] == shape_i, f'incorrect shape for image {i}'
        if imshapes[j]:
            assert imshapes[j] == shape_j, f'incorrect shape for image {j}'
        imshapes[i], imshapes[j] = shape_i, shape_j
    return imshapes
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_conf_trf(mode):
    """Return the confidence-remapping function selected by `mode`.

    Supported modes: 'log', 'sqrt', 'm1' (subtract one), 'id'/'none' (identity).
    """
    transforms = {
        'log': lambda x: x.log(),
        'sqrt': lambda x: x.sqrt(),
        'm1': lambda x: x - 1,
        'id': lambda x: x,
        'none': lambda x: x,
    }
    try:
        return transforms[mode]
    except KeyError:
        raise ValueError(f'bad mode for {mode=}') from None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def l2_dist(a, b, weight):
    """Confidence-weighted squared Euclidean distance along the last axis."""
    diff = a - b
    return diff.square().sum(dim=-1) * weight


def l1_dist(a, b, weight):
    """Confidence-weighted Euclidean (L2-norm) distance along the last axis."""
    diff = a - b
    return diff.norm(dim=-1) * weight


# registry consumed by BasePCOptimizer(dist=...)
ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def signed_log1p(x):
    """Odd-symmetric log1p: sign(x) * log(1 + |x|).  Inverse of signed_expm1."""
    return x.sign() * x.abs().log1p()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def signed_expm1(x):
    """Odd-symmetric expm1: sign(x) * (exp(|x|) - 1).  Inverse of signed_log1p."""
    return x.sign() * x.abs().expm1()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def cosine_schedule(t, lr_start, lr_end):
    """Cosine-annealed interpolation from lr_start (t=0) to lr_end (t=1).

    Raises ValueError for t outside [0, 1].  (Previously an `assert`, which is
    silently stripped under `python -O`; explicit validation is kept in all modes.)
    """
    if not 0 <= t <= 1:
        raise ValueError(f'invalid progress {t=}, expected a value in [0, 1]')
    return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from lr_start (t=0) to lr_end (t=1).

    Raises ValueError for t outside [0, 1].  (Previously an `assert`, which is
    silently stripped under `python -O`; explicit validation is kept in all modes.)
    """
    if not 0 <= t <= 1:
        raise ValueError(f'invalid progress {t=}, expected a value in [0, 1]')
    return lr_start + (lr_end - lr_start) * t
|
dust3r/dust3r/cloud_opt/init_im_poses.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Initialization functions for global alignment
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
from functools import cache
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import scipy.sparse as sp
|
| 11 |
+
import torch
|
| 12 |
+
import cv2
|
| 13 |
+
import roma
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
|
| 17 |
+
from dust3r.post_process import estimate_focal_knowing_depth
|
| 18 |
+
from dust3r.viz import to_numpy
|
| 19 |
+
|
| 20 |
+
from dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@torch.no_grad()
def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
    """Initialize a scene whose camera poses (and focals) are ALL known.

    For every edge, solves a PnP pose from the pairwise prediction, aligns the
    (identity, PnP) camera pair onto the two ground-truth cameras, and writes
    the resulting similarity into the pairwise pose parameters.  Each image's
    depthmap is then taken from its highest-confidence edge, rescaled accordingly.
    """
    device = self.device

    # indices of known poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    assert nkp == self.n_imgs, 'not all poses are known'

    # get all focals
    nkf, _, im_focals = get_known_focals(self)
    assert nkf == self.n_imgs
    im_pp = self.get_principal_points()

    best_depthmaps = {}
    # init all pairwise poses
    for e, (i, j) in enumerate(tqdm(self.edges)):
        i_j = edge_str(i, j)

        # find relative pose for this pair
        P1 = torch.eye(4, device=device)
        # keep only confident pixels, but never drop everything: the threshold is
        # capped just below the minimum confidence of this edge
        msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
        # NOTE(review): focal/pp of image i are used here because pred_j is
        # expressed in view i's frame -- confirm against fast_pnp's contract.
        _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),
                         pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)

        # align the two predicted camera with the two gt cameras
        s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
        # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
        # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

        # remember if this is a good depthmap
        score = float(self.conf_i[i_j].mean())
        if score > best_depthmaps.get(i, (0,))[0]:
            best_depthmaps[i] = score, i_j, s

    # init all image poses
    for n in range(self.n_imgs):
        assert known_poses_msk[n]
        _, i_j, scale = best_depthmaps[n]
        # z-coordinate of the prediction = depth in the camera frame
        depth = self.pred_i[i_j][:, :, 2]
        self._set_depthmap(n, depth * scale)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
    """Initialize all camera poses (image-wise and pairwise) from the pairwise
    predictions by chaining them along a maximum-confidence spanning tree,
    then write the resulting scene state into `self`.
    """
    pts3d, _, im_focals, im_poses = minimum_spanning_tree(
        self.imshapes, self.edges,
        self.pred_i, self.pred_j, self.conf_i, self.conf_j,
        self.im_conf, self.min_conf_thr,
        self.device, has_im_poses=self.has_im_poses, **kw)

    return init_from_pts3d(self, pts3d, im_focals, im_poses)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def init_from_pts3d(self, pts3d, im_focals, im_poses):
    """Write an initial scene state (world points, focals, camera poses) into `self`.

    Mutates `pts3d` in place (alignment + rescaling) and sets the pairwise poses,
    depthmaps, image poses and focals of the optimizer.
    """
    # init poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    if nkp == 1:
        raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)

    # set all pairwise poses
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            # depth = z-coordinate of the world points expressed in the camera frame
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None:
                self._set_focal(i, im_focals[i])

    print(' init loss =', float(self()))
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
                          device, has_im_poses=True, niter_PnP=10):
    """Initialize per-image pointmaps, poses and focals by traversing a
    maximum-confidence spanning tree of the pairwise graph.

    Returns (pts3d, msp_edges, im_focals, im_poses); the latter two are None
    when has_im_poses is False.
    """
    n_imgs = len(imshapes)
    # edge scores are confidences: negate so that the *minimum* spanning tree keeps the strongest edges
    sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j))
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * len(imshapes)

    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges, weakest first (we pop from the end)
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge
    score, i, j = todo.pop()
    print(f' init edge ({i}*,{j}*) {score=}')
    i_j = edge_str(i, j)
    pts3d[i] = pred_i[i_j].clone()
    pts3d[j] = pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        # FIX: the edge key must refer to the edge just popped. Previously it
        # was computed only inside the branches below, so the focal estimate
        # right here used a stale `i_j` from the previous iteration and could
        # read the wrong pointmap for image i.
        i_j = edge_str(i, j)

        if im_focals[i] is None:
            # focal only depends on the pointmap in camera-i's own frame,
            # so it can be estimated even before the edge is registered
            im_focals[i] = estimate_focal(pred_i[i_j])

        if i in done:
            print(f' init edge ({i},{j}*) {score=}')
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)  # pose without the scale component

        elif j in done:
            print(f' init edge ({i}*,{j}) {score=}')
            assert i not in done
            # align pred[j] with pts3d[j], and then set i accordingly
            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # neither endpoint is placed yet: let's try again later
            todo.insert(0, (score, i, j))

    if has_im_poses:
        # complete all missing informations
        pair_scores = list(sparse_graph.values())  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])

        for i in range(n_imgs):
            if im_poses[i] is None:
                # recover the remaining poses with RANSAC-PnP on confident pixels
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                im_poses[i] = torch.eye(4, device=device)  # last-resort fallback
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def dict_to_sparse_graph(dic):
    """Convert an {(i, j): score} dict into a scipy DOK sparse adjacency matrix."""
    n_imgs = 1 + max(max(edge) for edge in dic)
    graph = sp.dok_array((n_imgs, n_imgs))
    for (i, j), score in dic.items():
        graph[i, j] = score
    return graph
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def rigid_points_registration(pts1, pts2, conf):
    """Weighted similarity registration of pts1 onto pts2.

    Returns (s, R, T) where R and T are the *un-scaled* rotation/translation.
    """
    flat1 = pts1.reshape(-1, 3)
    flat2 = pts2.reshape(-1, 3)
    R, T, s = roma.rigid_points_registration(
        flat1, flat2, weights=conf.ravel(), compute_scaling=True)
    return s, R, T
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def sRT_to_4x4(scale, R, T, device):
    """Assemble a 4x4 similarity transform from scale, 3x3 rotation R and translation T."""
    mat = torch.eye(4, device=device)
    mat[:3, :3] = scale * R
    mat[:3, 3] = T.ravel()  # translation is not scaled
    return mat
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def estimate_focal(pts3d_i, pp=None):
    """Estimate a single focal length (in pixels) from one (H, W, 3) pointmap.

    The principal point pp defaults to the image center.
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device)
    focal = estimate_focal_knowing_depth(
        pts3d_i.unsqueeze(0), pp.unsqueeze(0),
        focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5).ravel()
    return float(focal)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@cache
def pixel_grid(H, W):
    """Return an (H, W, 2) float32 grid of (x, y) pixel coordinates (cached per shape)."""
    xs, ys = np.meshgrid(np.arange(W, dtype=np.float32),
                         np.arange(H, dtype=np.float32))  # default 'xy' indexing
    return np.stack((xs, ys), axis=-1)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Recover a cam-to-world pose (and the focal, if unknown) with RANSAC-PnP.

    pts3d: (H, W, 3) pointmap; msk: boolean mask of pixels to use.
    If focal is None, 21 candidate focals are tried and the one with the most
    RANSAC inliers wins. Returns (focal, cam2world 4x4) or None on failure.
    """
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        # sweep plausible focals around the image size (log-spaced)
        S = max(W, H)
        tentative_focals = np.geomspace(S/2, S*3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W/2, H/2)  # default principal point: image center
    else:
        pp = to_numpy(pp)

    best = 0,  # (score, R, T, focal); score 0 means nothing found yet
    for focal in tentative_focals:
        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

        success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                                    iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
        if not success:
            continue

        # score = number of RANSAC inliers for this candidate focal
        score = len(inliers)
        if success and score > best[0]:
            best = score, R, T, focal

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # rotation vector -> matrix; world to cam
    R, T = map(torch.from_numpy, (R, T))
    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def get_known_poses(self):
    """Return (num_known, mask, poses) for camera poses that are frozen (preset).

    When the scene has no image poses at all, returns (0, None, None).
    """
    if not self.has_im_poses:
        return 0, None, None
    known_poses_msk = torch.tensor([not p.requires_grad for p in self.im_poses])
    known_poses = self.get_im_poses()
    return known_poses_msk.sum(), known_poses_msk, known_poses
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def get_known_focals(self):
    """Return (num_known, mask, focals) for focal lengths that are frozen (preset).

    When the scene has no image poses at all, returns (0, None, None).
    """
    if not self.has_im_poses:
        return 0, None, None
    known_focal_msk = self.get_known_focal_mask()
    known_focals = self.get_focals()
    return known_focal_msk.sum(), known_focal_msk, known_focals
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def align_multiple_poses(src_poses, target_poses):
    """Find the similarity transform (s, R, T) aligning src_poses onto target_poses.

    Both inputs are (N, 4, 4) cam-to-world matrices. The registration uses the
    camera centers plus a point slightly ahead along each camera's z-axis, so
    that orientation (not just position) constrains the fit.
    """
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def center_and_z(poses):
        # eps sets how far along z the auxiliary points are placed,
        # relative to the typical inter-camera distance
        eps = get_med_dist_between_poses(poses) / 100
        return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2]))
    R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True)
    return s, R, T
|
dust3r/dust3r/cloud_opt/optimizer.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Main class for the implementation of the global alignment
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from dust3r.cloud_opt.base_opt import BasePCOptimizer
|
| 12 |
+
from dust3r.utils.geometry import xy_grid, geotrf
|
| 13 |
+
from dust3r.utils.device import to_cpu, to_numpy
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PointCloudOptimizer(BasePCOptimizer):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)

    Per-image variables (log-depthmaps, poses, log-focals, principal points)
    are re-packed into flat, zero-padded parameter stacks so that the loss in
    forward() runs in a single batched pass.
    """

    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
        super().__init__(*args, **kwargs)

        self.has_im_poses = True  # by definition of this class
        self.focal_break = focal_break  # scale applied to log-focals (conditions their gradients)

        # adding things to optimize (one parameter per image)
        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
        self.im_focals = nn.ParameterList(torch.FloatTensor(
            [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes)  # camera intrinsics: scaled log-focal
        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics: principal-point offset
        self.im_pp.requires_grad_(optimize_pp)

        self.imshape = self.imshapes[0]
        im_areas = [h*w for h, w in self.imshapes]
        self.max_area = max(im_areas)  # all flat per-image tensors are zero-padded to this length

        # re-pack the per-image parameters as single stacked Parameters (padded to max_area)
        self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
        self.im_poses = ParameterStack(self.im_poses, is_param=True)
        self.im_focals = ParameterStack(self.im_focals, is_param=True)
        self.im_pp = ParameterStack(self.im_pp, is_param=True)
        self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
        self.register_buffer('_grid', ParameterStack(
            [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))

        # pre-compute pixel weights (confidence-derived, one per edge endpoint)
        self.register_buffer('_weight_i', ParameterStack(
            [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
        self.register_buffer('_weight_j', ParameterStack(
            [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))

        # precompute stacked pairwise predictions and edge-index tensors
        self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
        self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
        self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
        self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
        self.total_area_j = sum([im_areas[j] for i, j in self.edges])

    def _check_all_imgs_are_selected(self, msk):
        # the preset_* methods require a value for every image
        assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'

    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
        """Freeze the camera poses to known cam-to-world 4x4 matrices."""
        self._check_all_imgs_are_selected(pose_msk)

        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
            known_poses = [known_poses]
        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            print(f' (setting pose #{idx} = {pose[:3,3]})')
            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))

        # normalize scale if there's less than 1 known pose
        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
        self.norm_pw_scale = (n_known_poses <= 1)

        self.im_poses.requires_grad_(False)
        # NOTE(review): this unconditionally overrides the value computed just above —
        # confirm whether pairwise-scale normalization should really always be disabled here.
        self.norm_pw_scale = False

    def preset_focal(self, known_focals, msk=None):
        """Freeze the focal lengths to known values (in pixels)."""
        self._check_all_imgs_are_selected(msk)

        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            print(f' (setting focal #{idx} = {focal})')
            self._no_grad(self._set_focal(idx, focal))

        self.im_focals.requires_grad_(False)

    def preset_principal_point(self, known_pp, msk=None):
        """Freeze the principal points to known (x, y) values (in pixels)."""
        self._check_all_imgs_are_selected(msk)

        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            print(f' (setting principal point #{idx} = {pp})')
            self._no_grad(self._set_principal_point(idx, pp))

        self.im_pp.requires_grad_(False)

    def _no_grad(self, tensor):
        # sanity check: the _set_* call must have written into a live (trainable) parameter
        assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'

    def _set_focal(self, idx, focal, force=False):
        # stores the focal as a scaled log-focal (see get_focals for the inverse)
        param = self.im_focals[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = self.focal_break * np.log(focal)
        return param

    def get_focals(self):
        """Return the per-image focal lengths in pixels."""
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_break).exp()

    def get_known_focal_mask(self):
        # True where the focal is frozen (preset), False where it is being optimized
        return torch.tensor([not (p.requires_grad) for p in self.im_focals])

    def _set_principal_point(self, idx, pp, force=False):
        # stored as an offset from the image center, divided by 10 for conditioning
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
        return param

    def get_principal_points(self):
        """Return the per-image principal points in pixels."""
        return self._pp + 10 * self.im_pp

    def get_intrinsics(self):
        """Return (n_imgs, 3, 3) pinhole intrinsics built from focals and principal points."""
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().flatten()
        K[:, 0, 0] = K[:, 1, 1] = focals
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):  # cam to world
        cam2world = self._get_poses(self.im_poses)
        return cam2world

    def _set_depthmap(self, idx, depth, force=False):
        depth = _ravel_hw(depth, self.max_area)

        param = self.im_depthmaps[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            # padded zeros yield log(0) = -inf, clamped back to 0
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self, raw=False):
        """Return depthmaps; raw=True keeps the flat, padded (n_imgs, max_area) form."""
        res = self.im_depthmaps.exp()
        if not raw:
            # trim the padding and restore each image's own (H, W) shape
            res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def depth_to_pts3d(self):
        """Back-project the current depthmaps to world-frame 3D points (flat padded form)."""
        # Get depths and projection params if not provided
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps(raw=True)

        # get pointmaps in camera frame
        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
        # project to world frame
        return geotrf(im_poses, rel_ptmaps)

    def get_pts3d(self, raw=False):
        """Return world-frame pointmaps; raw=True keeps the flat, padded form."""
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def forward(self):
        """Compute the global alignment loss: confidence-weighted 3D distance between
        the back-projected scene points and the pairwise predictions."""
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors().unsqueeze(1)
        proj_pts3d = self.get_pts3d(raw=True)

        # rotate pairwise prediction according to pw_poses
        aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
        aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)

        # compute the loss, normalized by the total pixel area of each side
        li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
        lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j

        return li + lj
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
|
| 187 |
+
pp = pp.unsqueeze(1)
|
| 188 |
+
focal = focal.unsqueeze(1)
|
| 189 |
+
assert focal.shape == (len(depth), 1, 1)
|
| 190 |
+
assert pp.shape == (len(depth), 1, 2)
|
| 191 |
+
assert pixel_grid.shape == depth.shape + (2,)
|
| 192 |
+
depth = depth.unsqueeze(-1)
|
| 193 |
+
return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def ParameterStack(params, keys=None, is_param=None, fill=0):
    """Stack per-image tensors into one tensor, optionally padded and wrapped as a Parameter.

    params may be a sequence, or a mapping indexed by `keys`. When fill > 0, each
    entry is flattened and zero-padded to length `fill` first. The result is a
    plain tensor unless is_param is truthy or the inputs require grad.
    """
    if keys is not None:
        params = [params[k] for k in keys]
    if fill > 0:
        params = [_ravel_hw(p, fill) for p in params]

    grad_flag = params[0].requires_grad
    assert all(p.requires_grad == grad_flag for p in params)

    stacked = torch.stack(list(params)).float().detach()
    if not (is_param or grad_flag):
        return stacked
    out = nn.Parameter(stacked)
    out.requires_grad_(grad_flag)
    return out
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _ravel_hw(tensor, fill=0):
|
| 214 |
+
# ravel H,W
|
| 215 |
+
tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
|
| 216 |
+
|
| 217 |
+
if len(tensor) < fill:
|
| 218 |
+
tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
|
| 219 |
+
return tensor
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
    """Return (min, max) focal bounds in pixels for an HxW image.

    The reference focal corresponds to a 60-degree field of view along the
    larger image dimension; the bounds are minf/maxf times that value.
    """
    fov60 = 2 * np.tan(np.deg2rad(60) / 2)  # ~1.1547005383792515
    focal_base = max(H, W) / fov60
    return minf * focal_base, maxf * focal_base
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def apply_mask(img, msk):
    """Return a copy of `img` with the masked pixels zeroed out (input left untouched)."""
    out = img.copy()
    out[msk] = 0
    return out
|
dust3r/dust3r/cloud_opt/pair_viewer.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Dummy optimizer for visualizing pairs
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import cv2
|
| 11 |
+
|
| 12 |
+
from dust3r.cloud_opt.base_opt import BasePCOptimizer
|
| 13 |
+
from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates
|
| 14 |
+
from dust3r.cloud_opt.commons import edge_str
|
| 15 |
+
from dust3r.post_process import estimate_focal_knowing_depth
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class PairViewer (BasePCOptimizer):
    """Dummy, non-trainable optimizer for visualizing a single symmetrized image pair.

    All parameters (poses, focals, principal points, depths) are computed
    directly from the raw pairwise predictions in __init__; forward() returns
    NaN since there is nothing to optimize.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.is_symmetrized and self.n_edges == 2
        self.has_im_poses = True

        # compute all parameters directly from raw input
        self.focals = []
        self.pp = []
        rel_poses = []
        confs = []
        for i in range(self.n_imgs):
            conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean())
            print(f' - {conf=:.3} for edge {i}-{1-i}')
            confs.append(conf)

            H, W = self.imshapes[i]
            pts3d = self.pred_i[edge_str(i, 1-i)]
            pp = torch.tensor((W/2, H/2))
            focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld'))
            self.focals.append(focal)
            self.pp.append(pp)

            # estimate the pose of pts1 in image 2
            pixels = np.mgrid[:W, :H].T.astype(np.float32)
            pts3d = self.pred_j[edge_str(1-i, i)].numpy()
            assert pts3d.shape[:2] == (H, W)
            msk = self.get_masks()[i].numpy()
            K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

            try:
                res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                         iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
                success, R, T, inliers = res
                assert success

                R = cv2.Rodrigues(R)[0]  # world to cam
                pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]])  # cam to world
            # FIX: was a bare `except:`, which also swallows KeyboardInterrupt/SystemExit.
            # Exception still covers cv2 errors and the AssertionError above.
            except Exception:
                pose = np.eye(4)  # fallback: identity pose when PnP fails
            rel_poses.append(torch.from_numpy(pose.astype(np.float32)))

        # let's use the pair with the most confidence
        if confs[0] > confs[1]:
            # ptcloud is expressed in camera1
            self.im_poses = [torch.eye(4), rel_poses[1]]  # I, cam2-to-cam1
            self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]
        else:
            # ptcloud is expressed in camera2
            self.im_poses = [rel_poses[0], torch.eye(4)]  # I, cam1-to-cam2
            self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]

        self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)
        self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
        self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)
        self.depth = nn.ParameterList(self.depth)
        for p in self.parameters():
            p.requires_grad = False

    def _set_depthmap(self, idx, depth, force=False):
        # everything is frozen in this viewer; setting depth is a no-op
        print('_set_depthmap is ignored in PairViewer')
        return

    def get_depthmaps(self, raw=False):
        # `raw` is accepted for interface compatibility but has no effect here
        depth = [d.to(self.device) for d in self.depth]
        return depth

    def _set_focal(self, idx, focal, force=False):
        self.focals[idx] = focal

    def get_focals(self):
        return self.focals

    def get_known_focal_mask(self):
        return torch.tensor([not (p.requires_grad) for p in self.focals])

    def get_principal_points(self):
        return self.pp

    def get_intrinsics(self):
        """Return (n_imgs, 3, 3) pinhole intrinsics built from the stored focals/pp."""
        focals = self.get_focals()
        pps = self.get_principal_points()
        K = torch.zeros((len(focals), 3, 3), device=self.device)
        for i in range(len(focals)):
            K[i, 0, 0] = K[i, 1, 1] = focals[i]
            K[i, :2, 2] = pps[i]
            K[i, 2, 2] = 1
        return K

    def get_im_poses(self):
        return self.im_poses

    def depth_to_pts3d(self):
        """Back-project the stored depthmaps to world-frame pointmaps."""
        pts3d = []
        for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):
            pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
                                                             intrinsics.cpu().numpy(),
                                                             im_pose.cpu().numpy())
            pts3d.append(torch.from_numpy(pts).to(device=self.device))
        return pts3d

    def forward(self):
        # nothing to optimize
        return float('nan')
|
dust3r/dust3r/datasets/__init__.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
from .utils.transforms import *
|
| 4 |
+
from .base.batched_sampler import BatchedRandomSampler # noqa: F401
|
| 5 |
+
from .co3d import Co3d # noqa: F401
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    """Build a (possibly distributed) PyTorch DataLoader for `dataset`.

    `dataset` may be a dataset object, or a string expression that evaluates
    to one (used by config-driven training scripts). The dataset's own
    make_sampler() is preferred when implemented; otherwise a standard
    Distributed/Random/Sequential sampler is used.
    """
    import torch
    from croco.utils.misc import get_world_size, get_rank

    # pytorch dataset
    if isinstance(dataset, str):
        # SECURITY: eval() of a caller-provided string — only trusted config strings must reach here
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    try:
        # dataset-provided sampler, when the dataset implements one
        sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
                                       rank=rank, drop_last=drop_last)
    except (AttributeError, NotImplementedError):
        # not avail for this dataset
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.DistributedSampler(
                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
            )
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader
|
dust3r/dust3r/datasets/base/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
dust3r/dust3r/datasets/base/base_stereo_view_dataset.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# base class for implementing datasets
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import PIL
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from dust3r.datasets.base.easy_dataset import EasyDataset
|
| 12 |
+
from dust3r.datasets.utils.transforms import ImgNorm
|
| 13 |
+
from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
|
| 14 |
+
import dust3r.datasets.utils.cropping as cropping
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BaseStereoViewDataset (EasyDataset):
|
| 18 |
+
""" Define all basic options.
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
class MyDataset (BaseStereoViewDataset):
|
| 22 |
+
def _get_views(self, idx, rng):
|
| 23 |
+
# overload here
|
| 24 |
+
views = []
|
| 25 |
+
views.append(dict(img=, ...))
|
| 26 |
+
return views
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, *, # only keyword arguments
|
| 30 |
+
split=None,
|
| 31 |
+
resolution=None, # square_size or (width, height) or list of [(width,height), ...]
|
| 32 |
+
transform=ImgNorm,
|
| 33 |
+
aug_crop=False,
|
| 34 |
+
seed=None):
|
| 35 |
+
self.num_views = 2
|
| 36 |
+
self.split = split
|
| 37 |
+
self._set_resolutions(resolution)
|
| 38 |
+
|
| 39 |
+
self.transform = transform
|
| 40 |
+
if isinstance(transform, str):
|
| 41 |
+
transform = eval(transform)
|
| 42 |
+
|
| 43 |
+
self.aug_crop = aug_crop
|
| 44 |
+
self.seed = seed
|
| 45 |
+
|
| 46 |
+
    def __len__(self):
        # one item per entry of self.scenes (populated by the subclass) — assumes
        # subclasses define self.scenes; TODO confirm for each subclass
        return len(self.scenes)
|
| 48 |
+
|
| 49 |
+
    def get_stats(self):
        """Short human-readable summary of the dataset size (used by __repr__)."""
        return f"{len(self)} pairs"
|
| 51 |
+
|
| 52 |
+
    def __repr__(self):
        # compact single-line description; the trailing replace() calls strip
        # the 'self.' prefixes, newlines and spaces from the f-string template
        resolutions_str = '['+';'.join(f'{w}x{h}' for w, h in self._resolutions)+']'
        return f"""{type(self).__name__}({self.get_stats()},
{self.split=},
{self.seed=},
resolutions={resolutions_str},
{self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '')
|
| 59 |
+
|
| 60 |
+
    def _get_views(self, idx, resolution, rng):
        """Return the list of views for item `idx` — must be overloaded by subclasses."""
        raise NotImplementedError()
|
| 62 |
+
|
| 63 |
+
def __getitem__(self, idx):
|
| 64 |
+
if isinstance(idx, tuple):
|
| 65 |
+
# the idx is specifying the aspect-ratio
|
| 66 |
+
idx, ar_idx = idx
|
| 67 |
+
else:
|
| 68 |
+
assert len(self._resolutions) == 1
|
| 69 |
+
ar_idx = 0
|
| 70 |
+
|
| 71 |
+
# set-up the rng
|
| 72 |
+
if self.seed: # reseed for each __getitem__
|
| 73 |
+
self._rng = np.random.default_rng(seed=self.seed + idx)
|
| 74 |
+
elif not hasattr(self, '_rng'):
|
| 75 |
+
seed = torch.initial_seed() # this is different for each dataloader process
|
| 76 |
+
self._rng = np.random.default_rng(seed=seed)
|
| 77 |
+
|
| 78 |
+
# over-loaded code
|
| 79 |
+
resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
|
| 80 |
+
views = self._get_views(idx, resolution, self._rng)
|
| 81 |
+
assert len(views) == self.num_views
|
| 82 |
+
|
| 83 |
+
# check data-types
|
| 84 |
+
for v, view in enumerate(views):
|
| 85 |
+
assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
|
| 86 |
+
view['idx'] = (idx, ar_idx, v)
|
| 87 |
+
|
| 88 |
+
# encode the image
|
| 89 |
+
width, height = view['img'].size
|
| 90 |
+
view['true_shape'] = np.int32((height, width))
|
| 91 |
+
view['img'] = self.transform(view['img'])
|
| 92 |
+
|
| 93 |
+
assert 'camera_intrinsics' in view
|
| 94 |
+
if 'camera_pose' not in view:
|
| 95 |
+
view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
|
| 96 |
+
else:
|
| 97 |
+
assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
|
| 98 |
+
assert 'pts3d' not in view
|
| 99 |
+
assert 'valid_mask' not in view
|
| 100 |
+
assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
|
| 101 |
+
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
|
| 102 |
+
|
| 103 |
+
view['pts3d'] = pts3d
|
| 104 |
+
view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
|
| 105 |
+
|
| 106 |
+
# check all datatypes
|
| 107 |
+
for key, val in view.items():
|
| 108 |
+
res, err_msg = is_good_type(key, val)
|
| 109 |
+
assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
|
| 110 |
+
K = view['camera_intrinsics']
|
| 111 |
+
|
| 112 |
+
# last thing done!
|
| 113 |
+
for view in views:
|
| 114 |
+
# transpose to make sure all views are the same size
|
| 115 |
+
transpose_to_landscape(view)
|
| 116 |
+
# this allows to check whether the RNG is is the same state each time
|
| 117 |
+
view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
|
| 118 |
+
return views
|
| 119 |
+
|
| 120 |
+
def _set_resolutions(self, resolutions):
|
| 121 |
+
assert resolutions is not None, 'undefined resolution'
|
| 122 |
+
|
| 123 |
+
if not isinstance(resolutions, list):
|
| 124 |
+
resolutions = [resolutions]
|
| 125 |
+
|
| 126 |
+
self._resolutions = []
|
| 127 |
+
for resolution in resolutions:
|
| 128 |
+
if isinstance(resolution, int):
|
| 129 |
+
width = height = resolution
|
| 130 |
+
else:
|
| 131 |
+
width, height = resolution
|
| 132 |
+
assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'
|
| 133 |
+
assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'
|
| 134 |
+
assert width >= height
|
| 135 |
+
self._resolutions.append((width, height))
|
| 136 |
+
|
| 137 |
+
def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None):
|
| 138 |
+
""" This function:
|
| 139 |
+
- first downsizes the image with LANCZOS inteprolation,
|
| 140 |
+
which is better than bilinear interpolation in
|
| 141 |
+
"""
|
| 142 |
+
if not isinstance(image, PIL.Image.Image):
|
| 143 |
+
image = PIL.Image.fromarray(image)
|
| 144 |
+
|
| 145 |
+
# downscale with lanczos interpolation so that image.size == resolution
|
| 146 |
+
# cropping centered on the principal point
|
| 147 |
+
W, H = image.size
|
| 148 |
+
cx, cy = intrinsics[:2, 2].round().astype(int)
|
| 149 |
+
min_margin_x = min(cx, W-cx)
|
| 150 |
+
min_margin_y = min(cy, H-cy)
|
| 151 |
+
assert min_margin_x > W/5, f'Bad principal point in view={info}'
|
| 152 |
+
assert min_margin_y > H/5, f'Bad principal point in view={info}'
|
| 153 |
+
# the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
|
| 154 |
+
l, t = cx - min_margin_x, cy - min_margin_y
|
| 155 |
+
r, b = cx + min_margin_x, cy + min_margin_y
|
| 156 |
+
crop_bbox = (l, t, r, b)
|
| 157 |
+
image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
|
| 158 |
+
|
| 159 |
+
# transpose the resolution if necessary
|
| 160 |
+
W, H = image.size # new size
|
| 161 |
+
assert resolution[0] >= resolution[1]
|
| 162 |
+
if H > 1.1*W:
|
| 163 |
+
# image is portrait mode
|
| 164 |
+
resolution = resolution[::-1]
|
| 165 |
+
elif 0.9 < H/W < 1.1 and resolution[0] != resolution[1]:
|
| 166 |
+
# image is square, so we chose (portrait, landscape) randomly
|
| 167 |
+
if rng.integers(2):
|
| 168 |
+
resolution = resolution[::-1]
|
| 169 |
+
|
| 170 |
+
# high-quality Lanczos down-scaling
|
| 171 |
+
target_resolution = np.array(resolution)
|
| 172 |
+
if self.aug_crop > 1:
|
| 173 |
+
target_resolution += rng.integers(0, self.aug_crop)
|
| 174 |
+
image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
|
| 175 |
+
|
| 176 |
+
# actual cropping (if necessary) with bilinear interpolation
|
| 177 |
+
intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)
|
| 178 |
+
crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
|
| 179 |
+
image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
|
| 180 |
+
|
| 181 |
+
return image, depthmap, intrinsics2
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def is_good_type(key, v):
    """ Validate the type/dtype of a view entry.

    Returns (is_good, err_msg): plain str/int/tuple values always pass;
    array-like values pass only with an allowed dtype. `key` is unused but
    kept for interface compatibility.
    """
    if isinstance(v, (str, int, tuple)):
        return True, None
    allowed_dtypes = (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8)
    if v.dtype in allowed_dtypes:
        return True, None
    return False, f"bad {v.dtype=}"
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def view_name(view, batch_index=None):
    """ Human-readable identifier "dataset/label/instance" for a view dict.

    If the view holds batched values, `batch_index` selects one element;
    None or slice(None) keeps the values as-is.
    """
    if batch_index in (None, slice(None)):
        def sel(x):
            return x
    else:
        def sel(x):
            return x[batch_index]
    return f"{sel(view['dataset'])}/{sel(view['label'])}/{sel(view['instance'])}"
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def transpose_to_landscape(view):
    """ In-place: if the view is portrait (width < height), swap the x/y axes
    of every image-like field so everything becomes landscape. Landscape views
    are left untouched.
    """
    height, width = view['true_shape']
    if width >= height:
        return  # already landscape: nothing to do

    # rectify portrait to landscape
    assert view['img'].shape == (3, height, width)
    view['img'] = view['img'].swapaxes(1, 2)

    assert view['valid_mask'].shape == (height, width)
    view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)

    assert view['depthmap'].shape == (height, width)
    view['depthmap'] = view['depthmap'].swapaxes(0, 1)

    assert view['pts3d'].shape == (height, width, 3)
    view['pts3d'] = view['pts3d'].swapaxes(0, 1)

    # swapping x and y pixels also swaps the first two rows of the intrinsics
    view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]
|
dust3r/dust3r/datasets/base/batched_sampler.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Random sampling under a constraint
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BatchedRandomSampler:
    """ Random sampling under a constraint: each sample in the batch has the same feature,
    which is chosen randomly from a known pool of 'features' for each batch.

    For instance, the 'feature' could be the image aspect-ratio.

    The index returned is a tuple (sample_idx, feat_idx).
    This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
    """

    def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True):
        self.batch_size = batch_size
        self.pool_size = pool_size

        self.len_dataset = N = len(dataset)
        # with drop_last, truncate to a multiple of batch_size*world_size so
        # every rank gets the same number of complete batches
        self.total_size = round_by(N, batch_size*world_size) if drop_last else N
        assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode'

        # distributed sampler
        self.world_size = world_size
        self.rank = rank
        self.epoch = None  # must be set via set_epoch() before iterating in distributed mode

    def __len__(self):
        # number of indices yielded on this rank
        return self.total_size // self.world_size

    def set_epoch(self, epoch):
        # seeds __iter__ deterministically; identical across ranks so each rank
        # draws the same global permutation and slices its own part
        self.epoch = epoch

    def __iter__(self):
        # prepare RNG
        if self.epoch is None:
            assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used'
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
        else:
            seed = self.epoch + 777
        rng = np.random.default_rng(seed=seed)

        # random indices (will restart from 0 if not drop_last)
        sample_idxs = np.arange(self.total_size)
        rng.shuffle(sample_idxs)

        # random feat_idxs (same across each batch)
        n_batches = (self.total_size+self.batch_size-1) // self.batch_size
        feat_idxs = rng.integers(self.pool_size, size=n_batches)
        # broadcast one feature per batch to all its batch_size samples
        feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
        feat_idxs = feat_idxs.ravel()[:self.total_size]

        # put them together
        idxs = np.c_[sample_idxs, feat_idxs]  # shape = (total_size, 2)

        # Distributed sampler: we select a subset of batches
        # make sure the slice for each node is aligned with batch_size
        size_per_proc = self.batch_size * ((self.total_size + self.world_size *
                                            self.batch_size-1) // (self.world_size * self.batch_size))
        idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc]

        yield from (tuple(idx) for idx in idxs)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def round_by(total, multiple, up=False):
    """ Round `total` down to a multiple of `multiple` (or up if up=True). """
    if up:
        total += multiple - 1
    return multiple * (total // multiple)
|
dust3r/dust3r/datasets/base/easy_dataset.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# A dataset base class that you can easily resize and combine.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import numpy as np
|
| 8 |
+
from dust3r.datasets.base.batched_sampler import BatchedRandomSampler
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class EasyDataset:
    """ a dataset that you can easily resize and combine.
    Examples:
    ---------
        2 * dataset ==> duplicate each element 2x

        10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)

        dataset1 + dataset2 ==> concatenate datasets
    """

    def __add__(self, other):
        # dataset1 + dataset2 => concatenation
        return CatDataset([self, other])

    def __rmul__(self, factor):
        # N * dataset => each element repeated N times
        return MulDataset(factor, self)

    def __rmatmul__(self, factor):
        # N @ dataset => exactly N samples (random index mapping, set per epoch)
        return ResizedDataset(factor, self)

    def set_epoch(self, epoch):
        pass  # nothing to do by default

    def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True):
        # Build a BatchedRandomSampler so that every batch shares one
        # aspect-ratio index drawn from self._resolutions.
        if not (shuffle):
            raise NotImplementedError()  # cannot deal yet
        num_of_aspect_ratios = len(self._resolutions)
        return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class MulDataset (EasyDataset):
    """ Artificially augment the size of a dataset: every element of the
    wrapped dataset appears `multiplicator` times. """
    multiplicator: int

    def __init__(self, multiplicator, dataset):
        assert isinstance(multiplicator, int) and multiplicator > 0
        self.multiplicator = multiplicator
        self.dataset = dataset

    def __len__(self):
        return self.multiplicator * len(self.dataset)

    def __repr__(self):
        return f'{self.multiplicator}*{repr(self.dataset)}'

    def __getitem__(self, idx):
        # plain index: just undo the multiplication
        if not isinstance(idx, tuple):
            return self.dataset[idx // self.multiplicator]
        # (sample_idx, aspect-ratio idx): remap only the sample index
        sample_idx, feat_idx = idx
        return self.dataset[sample_idx // self.multiplicator, feat_idx]

    @property
    def _resolutions(self):
        # expose the wrapped dataset's resolutions unchanged
        return self.dataset._resolutions
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ResizedDataset (EasyDataset):
    """ Artifically changing the size of a dataset.
    """
    # target number of samples exposed by __len__
    new_size: int

    def __init__(self, new_size, dataset):
        assert isinstance(new_size, int) and new_size > 0
        self.new_size = new_size
        self.dataset = dataset

    def __len__(self):
        return self.new_size

    def __repr__(self):
        # insert '_' thousand-separators, e.g. 1000000 -> 1_000_000
        size_str = str(self.new_size)
        for i in range((len(size_str)-1) // 3):
            sep = -4*i-3
            size_str = size_str[:sep] + '_' + size_str[sep:]
        return f'{size_str} @ {repr(self.dataset)}'

    def set_epoch(self, epoch):
        # this random shuffle only depends on the epoch
        rng = np.random.default_rng(seed=epoch+777)

        # shuffle all indices
        perm = rng.permutation(len(self.dataset))

        # rotary extension until target size is met
        shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset)))
        self._idxs_mapping = shuffled_idxs[:self.new_size]

        assert len(self._idxs_mapping) == self.new_size

    def __getitem__(self, idx):
        assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()'
        if isinstance(idx, tuple):
            # (sample_idx, aspect-ratio idx): remap only the sample index
            idx, other = idx
            return self.dataset[self._idxs_mapping[idx], other]
        else:
            return self.dataset[self._idxs_mapping[idx]]

    @property
    def _resolutions(self):
        # expose the wrapped dataset's resolutions unchanged
        return self.dataset._resolutions
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class CatDataset (EasyDataset):
    """ Concatenation of several datasets
    """

    def __init__(self, datasets):
        for dataset in datasets:
            assert isinstance(dataset, EasyDataset)
        self.datasets = datasets
        # cumulative end-offsets of each sub-dataset, used to route global indices
        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])

    def __len__(self):
        return self._cum_sizes[-1]

    def __repr__(self):
        # remove uselessly long transform
        return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets)

    def set_epoch(self, epoch):
        # propagate to every sub-dataset (e.g. ResizedDataset reshuffles here)
        for dataset in self.datasets:
            dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        other = None
        if isinstance(idx, tuple):
            # (sample_idx, aspect-ratio idx)
            idx, other = idx

        if not (0 <= idx < len(self)):
            raise IndexError()

        # locate the sub-dataset owning this global index
        db_idx = np.searchsorted(self._cum_sizes, idx, 'right')
        dataset = self.datasets[db_idx]
        new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)

        if other is not None:
            new_idx = (new_idx, other)
        return dataset[new_idx]

    @property
    def _resolutions(self):
        # all sub-datasets must agree on the exact same resolution list
        resolutions = self.datasets[0]._resolutions
        for dataset in self.datasets[1:]:
            assert tuple(dataset._resolutions) == tuple(resolutions)
        return resolutions
|
dust3r/dust3r/datasets/co3d.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Dataloader for preprocessed Co3d_v2
|
| 6 |
+
# dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International
|
| 7 |
+
# See datasets_preprocess/preprocess_co3d.py
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
import os.path as osp
|
| 10 |
+
import json
|
| 11 |
+
import itertools
|
| 12 |
+
from collections import deque
|
| 13 |
+
|
| 14 |
+
import cv2
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
|
| 18 |
+
from dust3r.utils.image import imread_cv2
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Co3d(BaseStereoViewDataset):
    """ Stereo-pair dataset over preprocessed CO3D-v2 dumps
    (see datasets_preprocess/preprocess_co3d.py).

    Each item is a pair of views of the same object instance, selected
    from frame-index pairs at most 30 frames apart (multiples of 5).
    """

    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
        # ROOT: directory produced by the preprocessing script
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        # mask_bg: True -> always zero background depth with the object mask,
        # 'rand' -> do so for a random half of samples, False -> never
        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg

        # load all scenes
        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
            self.scenes = json.load(f)
        self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
        # flatten {category: {sequence: frames}} into {(category, sequence): frames}
        self.scenes = {(k, k2): v2 for k, v in self.scenes.items()
                       for k2, v2 in v.items()}
        self.scene_list = list(self.scenes.keys())

        # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees)
        # we prepare all combinations such that i-j = +/- [5, 10, .., 90] degrees
        self.combinations = [(i, j)
                             for i, j in itertools.combinations(range(100), 2)
                             if 0 < abs(i-j) <= 30 and abs(i-j) % 5 == 0]

        # per-(scene, resolution) flags marking frames whose depth turned out empty
        self.invalidate = {scene: {} for scene in self.scene_list}

    def __len__(self):
        # every scene contributes every admissible frame-index pair
        return len(self.scene_list) * len(self.combinations)

    def _get_views(self, idx, resolution, rng):
        # choose a scene
        obj, instance = self.scene_list[idx // len(self.combinations)]
        image_pool = self.scenes[obj, instance]
        im1_idx, im2_idx = self.combinations[idx % len(self.combinations)]

        # add a bit of randomness
        last = len(image_pool)-1

        if resolution not in self.invalidate[obj, instance]:  # flag invalid images
            self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]

        # decide now if we mask the bg
        mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))

        views = []
        # jitter each frame index by +/-4, clamped to the pool bounds
        imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]]
        imgs_idxs = deque(imgs_idxs)
        while len(imgs_idxs) > 0:  # some images (few) have zero depth
            im_idx = imgs_idxs.pop()

            if self.invalidate[obj, instance][resolution][im_idx]:
                # search for a valid image, scanning in a random direction
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, len(image_pool)):
                    tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
                    if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
                        im_idx = tentative_im_idx
                        break

            view_idx = image_pool[im_idx]

            impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')

            # load camera params
            input_metadata = np.load(impath.replace('jpg', 'npz'))
            camera_pose = input_metadata['camera_pose'].astype(np.float32)
            intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)

            # load image and depth
            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED)
            # depth png stores uint16 values rescaled by the per-frame maximum depth
            depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])

            if mask_bg:
                # load object mask
                maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')
                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
                maskmap = (maskmap / 255.0) > 0.1

                # update the depthmap with mask
                depthmap *= maskmap

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0:
                # problem, invalidate image and retry
                self.invalidate[obj, instance][resolution][im_idx] = True
                imgs_idxs.append(im_idx)
                continue

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset='Co3d_v2',
                label=osp.join(obj, instance),
                instance=osp.split(impath)[1],
            ))
        return views
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
if __name__ == "__main__":
    # Quick visual sanity-check: iterate random pairs and display point clouds
    # and cameras. Requires the preprocessed CO3D subset on disk.
    from dust3r.datasets.base.base_stereo_view_dataset import view_name
    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb

    dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16)

    for idx in np.random.permutation(len(dataset)):
        views = dataset[idx]
        assert len(views) == 2
        print(view_name(views[0]), view_name(views[1]))
        viz = SceneViz()
        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
        cam_size = max(auto_cam_size(poses), 0.001)
        for view_idx in [0, 1]:
            pts3d = views[view_idx]['pts3d']
            valid_mask = views[view_idx]['valid_mask']
            colors = rgb(views[view_idx]['img'])
            viz.add_pointcloud(pts3d, colors, valid_mask)
            # BUGFIX: color the two cameras with the *view* index (0 -> green,
            # 1 -> red). Previously the dataset index `idx` was used, which can
            # be in the thousands and yields out-of-range RGB values.
            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
                           focal=views[view_idx]['camera_intrinsics'][0, 0],
                           color=(view_idx*255, (1 - view_idx)*255, 0),
                           image=colors,
                           cam_size=cam_size)
        viz.show()
|
dust3r/dust3r/datasets/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
dust3r/dust3r/datasets/utils/cropping.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# croppping utilities
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import PIL.Image
|
| 8 |
+
import os
|
| 9 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 10 |
+
import cv2 # noqa
|
| 11 |
+
import numpy as np # noqa
|
| 12 |
+
from dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa
|
| 13 |
+
try:
|
| 14 |
+
lanczos = PIL.Image.Resampling.LANCZOS
|
| 15 |
+
except AttributeError:
|
| 16 |
+
lanczos = PIL.Image.LANCZOS
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ImageList:
    """ Convenience wrapper that applies the same operation to one image
    or a whole collection of images. """

    def __init__(self, images):
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        # coerce every element to a PIL image
        self.images = [im if isinstance(im, PIL.Image.Image) else PIL.Image.fromarray(im)
                       for im in images]

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        # unwrap a single image, keep collections as a tuple
        if len(self.images) > 1:
            return tuple(self.images)
        return self.images[0]

    @property
    def size(self):
        # all images must share the same (W, H)
        sizes = [im.size for im in self.images]
        assert all(sz == sizes[0] for sz in sizes)
        return sizes[0]

    def resize(self, *args, **kwargs):
        return ImageList(self._dispatch('resize', *args, **kwargs))

    def crop(self, *args, **kwargs):
        return ImageList(self._dispatch('crop', *args, **kwargs))

    def _dispatch(self, func, *args, **kwargs):
        # apply the named PIL method to every image
        return [getattr(im, func)(*args, **kwargs) for im in self.images]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution):
    """ Jointly rescale a (image, depthmap)
    so that (out_width, out_height) >= output_res

    The image is Lanczos-resampled, the depthmap nearest-neighbour-resampled
    (depth values must not be blended), and the intrinsics scaled accordingly.
    `depthmap` may be None, or a mask instead of a depthmap.
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # can also use this with masks instead of depthmaps
        assert tuple(depthmap.shape[:2]) == image.size[::-1]
    assert output_resolution.shape == (2,)
    # define output resolution: smallest uniform scale that covers the target
    scale_final = max(output_resolution / image.size) + 1e-8
    output_resolution = np.floor(input_resolution * scale_final).astype(int)

    # first rescale the image so that it contains the crop
    image = image.resize(output_resolution, resample=lanczos)
    if depthmap is not None:
        depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final,
                              fy=scale_final, interpolation=cv2.INTER_NEAREST)

    # no offset here; simple rescaling
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)

    return image.to_pil(), depthmap, camera_intrinsics
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
    """ Intrinsics of a view scaled by `scaling` then cropped to
    `output_resolution`. The crop origin is `offset` (defaults to
    offset_factor * the leftover margins, i.e. a centered crop). """
    # Margins to offset the origin
    margins = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    # Generate new camera parameters, working in the COLMAP convention
    colmap_matrix = opencv_to_colmap_intrinsics(input_camera_matrix)
    colmap_matrix[:2, :] *= scaling
    colmap_matrix[:2, 2] -= offset
    return colmap_to_opencv_intrinsics(colmap_matrix)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
    """
    Return a crop of the input view.

    `crop_bbox` is (left, top, right, bottom) in pixels; the image, depthmap
    and intrinsics are cropped jointly (principal point shifted, focals kept).
    """
    image = ImageList(image)
    l, t, r, b = crop_bbox

    image = image.crop((l, t, r, b))
    depthmap = depthmap[t:b, l:r]

    # shift the principal point by the crop origin
    camera_intrinsics = camera_intrinsics.copy()
    camera_intrinsics[0, 2] -= l
    camera_intrinsics[1, 2] -= t

    return image.to_pil(), depthmap, camera_intrinsics
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
    """Crop bbox (l, t, r, b) that maps the input intrinsics onto the output ones.

    The origin is the (rounded) difference between the two principal points.
    """
    width, height = output_resolution
    left, top = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
    return (left, top, left + width, top + height)
|
dust3r/dust3r/datasets/utils/transforms.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUST3R default transforms
# --------------------------------------------------------
import torchvision.transforms as tvf
from dust3r.utils.image import ImgNorm

# define the standard image transforms
# ColorJitter: random photometric augmentation (brightness/contrast/saturation
# 0.5, hue 0.1) followed by the standard DUSt3R normalization.
ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
|
dust3r/dust3r/heads/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# head factory
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
from .linear_head import LinearPts3d
|
| 8 |
+
from .dpt_head import create_dpt_head
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def head_factory(head_type, output_mode, net, has_conf=False):
    """Build a prediction head for the decoder.

    Args:
        head_type: 'linear' or 'dpt'.
        output_mode: only 'pts3d' is supported.
        net: parent network providing embedding dims / patch size / modes.
        has_conf: whether the head also predicts a confidence channel.

    Raises:
        NotImplementedError: for any unsupported combination.
    """
    if head_type == 'linear' and output_mode == 'pts3d':
        return LinearPts3d(net, has_conf)
    elif head_type == 'dpt' and output_mode == 'pts3d':
        return create_dpt_head(net, has_conf=has_conf)
    else:
        raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
|
dust3r/dust3r/heads/dpt_head.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# dpt head implementation for DUST3R
|
| 6 |
+
# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
|
| 7 |
+
# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
|
| 8 |
+
# the forward function also takes as input a dictionnary img_info with key "height" and "width"
|
| 9 |
+
# for PixelwiseTask, the output will be of dimension B x num_channels x H x W
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from typing import List
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from dust3r.heads.postprocess import postprocess
|
| 16 |
+
import dust3r.utils.path_to_croco # noqa: F401
|
| 17 |
+
from models.dpt_block import DPTOutputAdapter # noqa
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DPTOutputAdapter_fix(DPTOutputAdapter):
    """
    Adapt croco's DPTOutputAdapter implementation for dust3r:
    remove duplicated weights, and fix forward for dust3r.
    """

    def init(self, dim_tokens_enc=768):
        super().init(dim_tokens_enc)
        # these are duplicated weights (act_postprocess already holds them)
        del self.act_1_postprocess
        del self.act_2_postprocess
        del self.act_3_postprocess
        del self.act_4_postprocess

    def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
        """Fuse hooked token layers into a dense prediction map.

        Args:
            encoder_tokens: list of (B, N, C) token tensors, one per layer;
                indexed by self.hooks.
            image_size: (H, W) of the target image; falls back to
                self.image_size when None.
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        # H, W = input_info['image_size']
        image_size = self.image_size if image_size is None else image_size
        H, W = image_size
        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Hook decoder onto 4 layers from specified ViT layers
        layers = [encoder_tokens[hook] for hook in self.hooks]

        # Extract only task-relevant tokens and ignore global tokens.
        layers = [self.adapt_tokens(l) for l in layers]

        # Reshape tokens to spatial representation
        layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]

        layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
        # Project layers to chosen feature dim
        layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]

        # Fuse layers using refinement stages (coarsest first); the crop on
        # path_4 aligns its spatial size with layers[2] before fusion
        path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        # Output head
        out = self.head(path_1)

        return out
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class PixelwiseTaskWithDPT(nn.Module):
    """ DPT module for dust3r, can return 3D points + confidence for all pixels"""

    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
                 output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
        super(PixelwiseTaskWithDPT, self).__init__()
        self.return_all_layers = True  # backbone needs to return all layers
        self.postprocess = postprocess  # callable applied to raw output (or None)
        self.depth_mode = depth_mode  # forwarded to postprocess
        self.conf_mode = conf_mode  # forwarded to postprocess

        assert n_cls_token == 0, "Not implemented"
        dpt_args = dict(output_width_ratio=output_width_ratio,
                        num_channels=num_channels,
                        **kwargs)
        if hooks_idx is not None:
            dpt_args.update(hooks=hooks_idx)
        self.dpt = DPTOutputAdapter_fix(**dpt_args)
        # dimension the adapter immediately when token dims are known
        dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
        self.dpt.init(**dpt_init_args)

    def forward(self, x, img_info):
        """Run the DPT on the token list `x`; img_info = (H, W) of the image."""
        out = self.dpt(x, image_size=(img_info[0], img_info[1]))
        if self.postprocess:
            out = self.postprocess(out, self.depth_mode, self.conf_mode)
        return out
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def create_dpt_head(net, has_conf=False):
    """
    Return a PixelwiseTaskWithDPT configured from the decoder geometry of `net`.

    Hooks are placed at the encoder output and at 1/2, 3/4 and the last
    decoder layer; the head outputs 3 channels (xyz) plus an optional
    confidence channel.
    """
    assert net.dec_depth > 9
    depth = net.dec_depth
    feat_dim = 256
    return PixelwiseTaskWithDPT(
        num_channels=3 + has_conf,
        feature_dim=feat_dim,
        last_dim=feat_dim // 2,
        hooks_idx=[0, depth * 2 // 4, depth * 3 // 4, depth],
        dim_tokens=[net.enc_embed_dim, net.dec_embed_dim, net.dec_embed_dim, net.dec_embed_dim],
        postprocess=postprocess,
        depth_mode=net.depth_mode,
        conf_mode=net.conf_mode,
        head_type='regression')
|
dust3r/dust3r/heads/linear_head.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# linear head implementation for DUST3R
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from dust3r.heads.postprocess import postprocess
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LinearPts3d (nn.Module):
    """
    Linear head for dust3r
    Each token outputs: - 16x16 3D points (+ confidence)

    Each decoder token is linearly projected to a patch_size x patch_size
    patch of (x, y, z[, conf]) values, then the patches are re-assembled
    into a full-resolution map with pixel shuffle.
    """

    def __init__(self, net, has_conf=False):
        super().__init__()
        # patch size of the encoder's patch embedding (assumed square)
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.has_conf = has_conf

        # one linear layer maps each token to a full patch of outputs
        self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)

    def setup(self, croconet):
        # nothing to set up; kept for interface compatibility
        pass

    def forward(self, decout, img_shape):
        """Predict a dense pointmap (and confidence) from decoder outputs.

        Args:
            decout: list of decoder layer outputs; only the last (B, S, D) is used.
            img_shape: (H, W) of the target image.
        """
        H, W = img_shape
        tokens = decout[-1]
        B, S, D = tokens.shape

        # extract 3D points
        feat = self.proj(tokens)  # B,S,D
        feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
        feat = F.pixel_shuffle(feat, self.patch_size)  # B,3,H,W

        # permute + norm depth
        return postprocess(feat, self.depth_mode, self.conf_mode)
|
dust3r/dust3r/heads/postprocess.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# post process function for all heads: extract 3D points/confidence from output
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def postprocess(out, depth_mode, conf_mode):
    """
    Extract 3D points / confidence from a prediction head output.

    `out` is (B, C, H, W); channels 0..2 are xyz, channel 3 (if present and
    conf_mode is set) is the confidence logit.
    """
    fmap = out.permute(0, 2, 3, 1)  # channels-last: B,H,W,C
    result = {'pts3d': reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode)}
    if conf_mode is not None:
        result['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
    return result
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def reg_dense_depth(xyz, mode):
    """
    Map raw network output to 3D points according to mode = (name, vmin, vmax).

    'linear' returns the values as-is (bounds must be infinite); 'square'
    and 'exp' keep the direction and re-scale the radial distance.
    """
    mode, vmin, vmax = mode

    unbounded = (vmin == -float('inf')) and (vmax == float('inf'))
    assert unbounded  # bounded ranges are not supported

    if mode == 'linear':
        if unbounded:
            return xyz  # [-inf, +inf]
        return xyz.clip(min=vmin, max=vmax)

    # split into unit direction and distance to origin
    dist = xyz.norm(dim=-1, keepdim=True)
    xyz = xyz / dist.clip(min=1e-8)

    if mode == 'square':
        return xyz * dist.square()

    if mode == 'exp':
        return xyz * torch.expm1(dist)

    raise ValueError(f'bad {mode=}')
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def reg_dense_conf(x, mode):
    """
    Map raw confidence logits into [vmin, vmax] according to mode = (name, vmin, vmax).
    """
    mode, vmin, vmax = mode
    if mode == 'exp':
        # 1 + exp(x) style confidence, clipped at vmax
        return vmin + x.exp().clip(max=vmax - vmin)
    if mode == 'sigmoid':
        # affine sigmoid into the [vmin, vmax] range
        return vmin + (vmax - vmin) * torch.sigmoid(x)
    raise ValueError(f'bad {mode=}')
|
dust3r/dust3r/image_pairs.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# utilities needed to load image pairs
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True):
    """Build (view_i, view_j) pairs from `imgs` according to `scene_graph`.

    scene_graph: 'complete' (all unordered pairs), 'swin[-k]' (sliding
    window with loop closure, default k=3), 'oneref[-i]' (all pairs against
    reference i, default 0), or 'pairs' (consecutive disjoint pairs).
    symmetrize adds every pair in the reverse direction; prefilter
    'seqN'/'cycN' drops pairs more than N frames apart (cyclically for cyc).
    """
    n = len(imgs)
    pairs = []

    if scene_graph == 'complete':  # complete graph
        pairs = [(imgs[a], imgs[b]) for a in range(n) for b in range(a)]

    elif scene_graph.startswith('swin'):
        win = int(scene_graph.split('-')[1]) if '-' in scene_graph else 3
        # windowed pairs with explicit loop closure (wraps around the end)
        pairs = [(imgs[a], imgs[(a + k) % n]) for a in range(n) for k in range(win)]

    elif scene_graph.startswith('oneref'):
        ref = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0
        pairs = [(imgs[ref], imgs[b]) for b in range(n) if b != ref]

    elif scene_graph == 'pairs':
        assert n % 2 == 0
        pairs = [(imgs[a], imgs[a + 1]) for a in range(0, n, 2)]

    if symmetrize:
        pairs = pairs + [(b, a) for a, b in pairs]

    # now, remove edges
    if isinstance(prefilter, str):
        if prefilter.startswith('seq'):
            pairs = filter_pairs_seq(pairs, int(prefilter[3:]))
        elif prefilter.startswith('cyc'):
            pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)

    return pairs
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def sel(x, kept):
    """Select the `kept` indices of x (recursing into dict values).

    Supports tensors/ndarrays (fancy indexing) and tuples/lists (rebuilt
    with the same type). Other types fall through and yield None.
    """
    if isinstance(x, dict):
        return {key: sel(val, kept) for key, val in x.items()}
    if isinstance(x, (torch.Tensor, np.ndarray)):
        return x[kept]
    if isinstance(x, (tuple, list)):
        return type(x)(x[idx] for idx in kept)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
|
| 60 |
+
# number of images
|
| 61 |
+
n = max(max(e) for e in edges)+1
|
| 62 |
+
|
| 63 |
+
kept = []
|
| 64 |
+
for e, (i, j) in enumerate(edges):
|
| 65 |
+
dis = abs(i-j)
|
| 66 |
+
if cyclic:
|
| 67 |
+
dis = min(dis, abs(i+n-j), abs(i-n-j))
|
| 68 |
+
if dis <= seq_dis_thr:
|
| 69 |
+
kept.append(e)
|
| 70 |
+
return kept
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
    """Drop pairs whose 'idx' fields are more than seq_dis_thr frames apart."""
    edge_list = [(a['idx'], b['idx']) for a, b in pairs]
    keep = _filter_edges_seq(edge_list, seq_dis_thr, cyclic=cyclic)
    return [pairs[k] for k in keep]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
    """Filter views and predictions, keeping entries whose frame indices are close."""
    edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges')
    return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)
|
dust3r/dust3r/inference.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# utilities needed for the inference
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import tqdm
|
| 8 |
+
import torch
|
| 9 |
+
from dust3r.utils.device import to_cpu, collate_with_cat
|
| 10 |
+
from dust3r.model import AsymmetricCroCo3DStereo, inf # noqa: F401, needed when loading the model
|
| 11 |
+
from dust3r.utils.misc import invalid_to_nans
|
| 12 |
+
from dust3r.utils.geometry import depthmap_to_pts3d, geotrf
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_model(model_path, device):
    """Load a DUSt3R checkpoint and instantiate the network on `device`.

    NOTE(review): this uses torch.load (pickle-based) and eval() on the
    constructor string stored in the checkpoint -- only load checkpoints
    from trusted sources.
    """
    print('... loading model from', model_path)
    ckpt = torch.load(model_path, map_location='cpu')
    # the checkpoint stores the model constructor call as a string; swap the
    # patch-embed class so arbitrary aspect ratios are supported at inference
    args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    # force landscape_only=False, whether the flag is present or not
    if 'landscape_only' not in args:
        args = args[:-1] + ', landscape_only=False)'
    else:
        args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
    assert "landscape_only=False" in args
    print(f"instantiating : {args}")
    net = eval(args)  # constructor string comes from the checkpoint (trusted input only)
    # strict=False: tolerate missing/unexpected keys; result is printed for inspection
    print(net.load_state_dict(ckpt['model'], strict=False))
    return net.to(device)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _interleave_imgs(img1, img2):
|
| 31 |
+
res = {}
|
| 32 |
+
for key, value1 in img1.items():
|
| 33 |
+
value2 = img2[key]
|
| 34 |
+
if isinstance(value1, torch.Tensor):
|
| 35 |
+
value = torch.stack((value1, value2), dim=1).flatten(0, 1)
|
| 36 |
+
else:
|
| 37 |
+
value = [x for pair in zip(value1, value2) for x in pair]
|
| 38 |
+
res[key] = value
|
| 39 |
+
return res
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def make_batch_symmetric(batch):
    """Duplicate every pair in both directions: (v1, v2) and (v2, v1), interleaved."""
    v1, v2 = batch
    return _interleave_imgs(v1, v2), _interleave_imgs(v2, v1)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    """Run the model on one (view1, view2) batch and optionally compute the loss.

    Args:
        batch: pair of view dicts.
        model: network taking (view1, view2) and returning (pred1, pred2).
        criterion: loss callable, or None to skip the loss (loss=None).
        device: target device for the tensor entries of each view.
        symmetrize_batch: if True, duplicate every pair in both directions.
        use_amp: enable autocast for the forward pass (loss is always fp32).
        ret: if given, return only this key of the result dict.

    Returns:
        dict(view1, view2, pred1, pred2, loss), or result[ret] if ret is set.
    """
    view1, view2 = batch
    # move the known tensor fields of each view to the target device in-place
    for view in batch:
        for name in 'img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres'.split():  # pseudo_focal
            if name not in view:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        view1, view2 = make_batch_symmetric(batch)

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = model(view1, view2)

        # loss is supposed to be symmetric
        # (computed in full precision even when AMP is enabled)
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None

    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@torch.no_grad()
def inference(pairs, model, device, batch_size=8):
    """Run pairwise inference over all image pairs (gradient-free).

    If the pairs do not all share the same image size, the batch size is
    forced to 1 so each pair can be processed individually. Results are
    moved to CPU and concatenated (or kept as lists for mixed shapes).
    """
    print(f'>> Inference with model on {len(pairs)} image pairs')
    result = []

    # first, check if all images have the same size
    multiple_shapes = not (check_if_same_size(pairs))
    if multiple_shapes:  # force bs=1
        batch_size = 1

    for i in tqdm.trange(0, len(pairs), batch_size):
        res = loss_of_one_batch(collate_with_cat(pairs[i:i+batch_size]), model, None, device)
        result.append(to_cpu(res))

    # concatenate per-batch results (lists when shapes differ)
    result = collate_with_cat(result, lists=multiple_shapes)

    torch.cuda.empty_cache()
    return result
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def check_if_same_size(pairs):
    """True iff all first images share one (H, W) and all second images share one (H, W)."""
    first_shapes = [a['img'].shape[-2:] for a, b in pairs]
    second_shapes = [b['img'].shape[-2:] for a, b in pairs]
    same_first = all(s == first_shapes[0] for s in first_shapes)
    same_second = all(s == second_shapes[0] for s in second_shapes)
    return same_first and same_second
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_pred_pts3d(gt, pred, use_pose=False):
    """Extract predicted 3D points from `pred`.

    Supports three prediction formats, checked in order:
    depth + pseudo_focal (pointmap recovered via unprojection), direct
    'pts3d', or 'pts3d_in_other_view' (already expressed in the other
    camera, returned as-is; requires use_pose=True). When use_pose is set,
    the points are transformed by pred['camera_pose'].
    """
    if 'depth' in pred and 'pseudo_focal' in pred:
        # principal point from GT intrinsics when available
        try:
            pp = gt['camera_intrinsics'][..., :2, 2]
        except KeyError:
            pp = None
        pts = depthmap_to_pts3d(**pred, pp=pp)

    elif 'pts3d' in pred:
        # pts3d expressed in this camera
        pts = pred['pts3d']

    elif 'pts3d_in_other_view' in pred:
        # already transformed into the other camera: nothing more to do
        assert use_pose is True
        return pred['pts3d_in_other_view']

    if use_pose:
        camera_pose = pred.get('camera_pose')
        assert camera_pose is not None
        pts = geotrf(camera_pose, pts)

    return pts
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None):
    """Estimate a per-batch scale s aligning GT points to predicted points.

    Solves (approximately) min_s || pr - s * gt || over the valid pixels of
    both views. fit_mode selects the estimator: 'avg' (closed-form L2),
    'median', or 'weiszfeld' (robust L1 via iteratively re-weighted least
    squares); a '*_stop_grad' suffix detaches the result from autograd.

    Returns:
        (B,) tensor of scales, clipped to >= 1e-3.
    """
    assert gt_pts1.ndim == pr_pts1.ndim == 4
    assert gt_pts1.shape == pr_pts1.shape
    if gt_pts2 is not None:
        assert gt_pts2.ndim == pr_pts2.ndim == 4
        assert gt_pts2.shape == pr_pts2.shape

    # concat the pointcloud (invalid pixels become NaN so nan-reductions skip them)
    nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
    nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None

    pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
    pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None

    all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1
    all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1

    # per-point dot products used by the closed-form L2 solution
    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
    dot_gt_gt = all_gt.square().sum(dim=-1)

    if fit_mode.startswith('avg'):
        # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1)
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
    elif fit_mode.startswith('median'):
        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
    elif fit_mode.startswith('weiszfeld'):
        # init scaling with l2 closed form
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
        # iterative re-weighted least-squares (fixed 10 iterations)
        for iter in range(10):
            # re-weighting by inverse of distance
            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
            # print(dis.nanmean(-1))
            w = dis.clip_(min=1e-8).reciprocal()
            # update the scaling with the new weights
            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
    else:
        raise ValueError(f'bad {fit_mode=}')

    if fit_mode.endswith('stop_grad'):
        scaling = scaling.detach()

    # avoid degenerate/negative scales
    scaling = scaling.clip(min=1e-3)
    # assert scaling.isfinite().all(), bb()
    return scaling
|
dust3r/dust3r/losses.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Implementation of DUSt3R training losses
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
from copy import copy, deepcopy
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from dust3r.inference import get_pred_pts3d, find_opt_scaling
|
| 12 |
+
from dust3r.utils.geometry import inv, geotrf, normalize_pointcloud
|
| 13 |
+
from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def Sum(*losses_and_masks):
    """Sum scalar losses; pass per-pixel losses through untouched.

    Each argument is a (loss, mask) pair. If the first loss is per-pixel
    (ndim > 0), the whole argument list is returned for later handling;
    otherwise the scalar losses are accumulated (masks ignored).
    """
    first_loss, _first_mask = losses_and_masks[0]
    if first_loss.ndim > 0:
        # we are actually returning the loss for every pixel
        return losses_and_masks
    # scalar case: accumulate the global loss
    total = first_loss
    for other_loss, _mask in losses_and_masks[1:]:
        total = total + other_loss
    return total
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LLoss (nn.Module):
    """ L-norm loss

    Base class: subclasses implement `distance(a, b)` returning a per-point
    distance tensor (one dimension less than the inputs). `reduction`
    controls aggregation: 'none' | 'sum' | 'mean'.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, a, b):
        # inputs are (..., D) point tensors with D in {1, 2, 3}
        assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}'
        dist = self.distance(a, b)
        assert dist.ndim == a.ndim-1  # one dimension less
        if self.reduction == 'none':
            return dist
        if self.reduction == 'sum':
            return dist.sum()
        if self.reduction == 'mean':
            # empty input yields 0 instead of NaN
            return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
        raise ValueError(f'bad {self.reduction=} mode')

    def distance(self, a, b):
        # subclass hook: per-point distance between a and b
        raise NotImplementedError()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class L21Loss (LLoss):
    """ Euclidean distance between 3d points """

    def distance(self, a, b):
        return torch.norm(a - b, dim=-1)  # normalized L2 distance


# shared default criterion instance (mean reduction)
L21 = L21Loss()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class Criterion (nn.Module):
    """Wrapper holding a pixel-level criterion (an LLoss instance).

    Subclasses combine this with MultiLoss; `with_reduction` walks the
    MultiLoss chain and switches every criterion to 'none' reduction so
    per-pixel losses can be post-processed (e.g. confidence weighting).
    """

    def __init__(self, criterion=None):
        super().__init__()
        # BUGFIX: the assert message previously ended with `+ bb()`, an
        # undefined name -- a failing assertion raised NameError instead of
        # the intended AssertionError with its message.
        assert isinstance(criterion, LLoss), f'{criterion} is not a proper criterion!'
        self.criterion = copy(criterion)

    def get_name(self):
        return f'{type(self).__name__}({self.criterion})'

    def with_reduction(self, mode):
        # `mode` is currently unused: the chain is always set to 'none'
        res = loss = deepcopy(self)
        while loss is not None:
            assert isinstance(loss, Criterion)
            loss.criterion.reduction = 'none'  # make it return the loss for each sample
            loss = loss._loss2  # we assume loss is a Multiloss
        return res
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class MultiLoss (nn.Module):
    """ Easily combinable losses (also keep track of individual loss values):
        loss = MyLoss1() + 0.1*MyLoss2()
    Usage:
        Inherit from this class and override get_name() and compute_loss()

    Combined losses form a singly-linked chain via `_loss2`; each node
    carries its own multiplicative weight `_alpha`.
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1  # multiplicative weight of this term
        self._loss2 = None  # next loss in the chain, or None

    def compute_loss(self, *args, **kwargs):
        # subclass hook: return a loss, or (loss, details_dict)
        raise NotImplementedError()

    def get_name(self):
        # subclass hook: short display name for this loss term
        raise NotImplementedError()

    def __mul__(self, alpha):
        # `alpha * loss` returns a shallow copy with the weight set
        assert isinstance(alpha, (int, float))
        res = copy(self)
        res._alpha = alpha
        return res
    __rmul__ = __mul__  # same

    def __add__(self, loss2):
        # `loss1 + loss2` appends loss2 at the end of the chain (head is copied)
        assert isinstance(loss2, MultiLoss)
        res = cur = copy(self)
        # find the end of the chain
        while cur._loss2 is not None:
            cur = cur._loss2
        cur._loss2 = loss2
        return res

    def __repr__(self):
        name = self.get_name()
        if self._alpha != 1:
            name = f'{self._alpha:g}*{name}'
        if self._loss2:
            name = f'{name} + {self._loss2}'
        return name

    def forward(self, *args, **kwargs):
        """Evaluate the whole chain; returns (total_loss, details_dict)."""
        loss = self.compute_loss(*args, **kwargs)
        if isinstance(loss, tuple):
            loss, details = loss
        elif loss.ndim == 0:
            details = {self.get_name(): float(loss)}
        else:
            details = {}
        loss = loss * self._alpha

        # recurse into the rest of the chain and merge its details
        if self._loss2:
            loss2, details2 = self._loss2(*args, **kwargs)
            loss = loss + loss2
            details |= details2

        return loss, details
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class Regr3D (Criterion, MultiLoss):
    """ Ensure that all 3D points are correct.
    Asymmetric loss: view1 is supposed to be the anchor.

    P1 = RT1 @ D1
    P2 = RT2 @ D2
    loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1)
    loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2)
          = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2)
    """

    def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False):
        super().__init__(criterion)
        self.norm_mode = norm_mode  # point-cloud normalization scheme (e.g. 'avg_dis')
        self.gt_scale = gt_scale    # if True, keep the ground-truth scale as the reference

    def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None):
        """Return (gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, monitoring),
        with every point cloud expressed in camera 1's coordinate frame."""
        world_to_cam1 = inv(gt1['camera_pose'])
        gt_pts1 = geotrf(world_to_cam1, gt1['pts3d'])  # B,H,W,3
        gt_pts2 = geotrf(world_to_cam1, gt2['pts3d'])  # B,H,W,3

        valid1 = gt1['valid_mask'].clone()
        valid2 = gt2['valid_mask'].clone()

        if dist_clip is not None:
            # mark points that are too far away as invalid
            valid1 = valid1 & (gt_pts1.norm(dim=-1) <= dist_clip)
            valid2 = valid2 & (gt_pts2.norm(dim=-1) <= dist_clip)

        pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False)
        pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True)

        # normalize both 3d point clouds the same way
        if self.norm_mode:
            pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2)
        if self.norm_mode and not self.gt_scale:
            gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2)

        return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, {}

    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \
            self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw)
        # per-pixel regression error, one value per valid pixel, on each side
        err_img1 = self.criterion(pred_pts1[mask1], gt_pts1[mask1])
        err_img2 = self.criterion(pred_pts2[mask2], gt_pts2[mask2])
        name = type(self).__name__
        details = {name + '_pts3d_1': float(err_img1.mean()),
                   name + '_pts3d_2': float(err_img2.mean())}
        return Sum((err_img1, mask1), (err_img2, mask2)), (details | monitoring)
| 195 |
+
class ConfLoss (MultiLoss):
    """ Weighted regression by learned confidence.
    Assumes the wrapped pixel_loss is a pixel-level regression loss.

    Principle:
        high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
        low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)

    alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        # wrapped loss must produce one value per valid pixel (no reduction)
        self.pixel_loss = pixel_loss.with_reduction('none')

    def get_name(self):
        return f'ConfLoss({self.pixel_loss})'

    def get_conf_log(self, x):
        return x, torch.log(x)

    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
        # per-pixel losses and validity masks for both views
        ((pix1, msk1), (pix2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
        # NOTE(review): `force=True` relies on the builtin print being patched
        # (e.g. by a distributed-training setup); with the vanilla builtin this
        # call would raise TypeError — confirm in the runtime environment.
        if pix1.numel() == 0:
            print('NO VALID POINTS in img1', force=True)
        if pix2.numel() == 0:
            print('NO VALID POINTS in img2', force=True)

        # weight each pixel by its predicted confidence
        conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1])
        conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2])
        conf_loss1 = pix1 * conf1 - self.alpha * log_conf1
        conf_loss2 = pix2 * conf2 - self.alpha * log_conf2

        # reduce to scalars, guarding against batches with no valid pixels
        conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0
        conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0

        return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details)
| 238 |
+
|
| 239 |
+
class Regr3D_ShiftInv (Regr3D):
    """ Same as Regr3D, but invariant to a global depth shift. """

    def get_all_pts3d(self, gt1, gt2, pred1, pred2):
        # unnormalized points from the parent class
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \
            super().get_all_pts3d(gt1, gt2, pred1, pred2)

        # per-batch median depth across both views, shaped (B,1,1) for broadcasting
        gt_shift = get_joint_pointcloud_depth(gt_pts1[..., 2], gt_pts2[..., 2], mask1, mask2)[:, None, None]
        pred_shift = get_joint_pointcloud_depth(pred_pts1[..., 2], pred_pts2[..., 2], mask1, mask2)[:, None, None]

        # subtract it in-place from the z coordinate of every point
        gt_pts1[..., 2] -= gt_shift
        gt_pts2[..., 2] -= gt_shift
        pred_pts1[..., 2] -= pred_shift
        pred_pts2[..., 2] -= pred_shift

        # monitoring = dict(monitoring, gt_shift_z=gt_shift.mean().detach(), pred_shift_z=pred_shift.mean().detach())
        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring
| 263 |
+
|
| 264 |
+
class Regr3D_ScaleInv (Regr3D):
    """ Same as Regr3D, but invariant to a global scale change.
    if gt_scale == True: enforce the prediction to take the same scale as GT
    """

    def get_all_pts3d(self, gt1, gt2, pred1, pred2):
        # depth-normalized points from the parent class
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = super().get_all_pts3d(gt1, gt2, pred1, pred2)

        # estimate the overall scene scale for GT and prediction
        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2)

        # keep the predicted scale within a sane range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        if self.gt_scale:
            # rescale the prediction (in-place) so it matches the GT scale
            pred_pts1 *= gt_scale / pred_scale
            pred_pts2 *= gt_scale / pred_scale
            # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean())
        else:
            # normalize both sides (in-place) to unit scale
            gt_pts1 /= gt_scale
            gt_pts2 /= gt_scale
            pred_pts1 /= pred_scale
            pred_pts2 /= pred_scale
            # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach())

        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring
| 294 |
+
|
| 295 |
+
class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv):
    """Invariant to both depth shift and scene scale.

    By MRO, Regr3D_ShiftInv's shift removal is applied first,
    then Regr3D_ScaleInv's rescaling.
    """
    pass
|
dust3r/dust3r/model.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# DUSt3R model class
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from .utils.misc import fill_default_args, freeze_all_params, is_symmetrized, interleave, transpose_to_landscape
|
| 11 |
+
from .heads import head_factory
|
| 12 |
+
from dust3r.patch_embed import get_patch_embed
|
| 13 |
+
|
| 14 |
+
import dust3r.utils.path_to_croco # noqa: F401
|
| 15 |
+
from models.croco import CroCoNet # noqa
|
| 16 |
+
inf = float('inf')
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AsymmetricCroCo3DStereo (CroCoNet):
    """ Two siamese encoders, followed by two decoders.
    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).
    """

    def __init__(self,
                 output_mode='pts3d',
                 head_type='linear',
                 depth_mode=('exp', -inf, inf),
                 conf_mode=('exp', 1, inf),
                 freeze='none',
                 landscape_only=True,
                 patch_embed_cls='PatchEmbedDust3R',  # PatchEmbedDust3R or ManyAR_PatchEmbed
                 **croco_kwargs):
        # must be set before super().__init__(): CroCoNet's constructor calls
        # _set_patch_embed(), which reads this attribute
        self.patch_embed_cls = patch_embed_cls
        # record the fully-defaulted CroCo constructor arguments (e.g. for checkpointing)
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)

        # dust3r specific initialization: an independent second decoder for view 2
        self.dec_blocks2 = deepcopy(self.dec_blocks)
        self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)
        self.set_freeze(freeze)

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        # override of CroCoNet's hook: select the patch-embed class by name
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)

    def load_state_dict(self, ckpt, **kw):
        # duplicate all weights for the second decoder if not present,
        # so plain CroCo checkpoints can initialize both decoders
        new_ckpt = dict(ckpt)
        if not any(k.startswith('dec_blocks2') for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith('dec_blocks'):
                    new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        # freeze: 'none' | 'mask' | 'encoder' -- which parameter groups to freeze
        self.freeze = freeze
        to_be_frozen = {
            'none': [],
            'mask': [self.mask_token],
            'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """ No prediction head """
        return

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size,
                            **kw):
        """Allocate the two regression heads (one per view) and wrap them so
        they operate on landscape-oriented features when landscape_only is set."""
        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \
            f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        # allocate heads
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        # magic wrapper
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)

    def _encode_image(self, image, true_shape):
        """Run one image through patch embedding + transformer encoder."""
        # embed the image into patches  (x has size B x Npatches x C)
        x, pos = self.patch_embed(image, true_shape=true_shape)

        # add positional embedding without cls token
        assert self.enc_pos_embed is None

        # now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            x = blk(x, pos)

        x = self.enc_norm(x)
        return x, pos, None

    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
        # when both images have the same spatial size, encode them in a single
        # batched pass; otherwise encode each separately
        if img1.shape[-2:] == img2.shape[-2:]:
            out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0),
                                             torch.cat((true_shape1, true_shape2), dim=0))
            out, out2 = out.chunk(2, dim=0)
            pos, pos2 = pos.chunk(2, dim=0)
        else:
            out, pos, _ = self._encode_image(img1, true_shape1)
            out2, pos2, _ = self._encode_image(img2, true_shape2)
        return out, out2, pos, pos2

    def _encode_symmetrized(self, view1, view2):
        """Encode a pair of views; when the batch is symmetrized (each pair
        appears as both (a,b) and (b,a)), only half the encoder work is done
        and the results are interleaved back."""
        img1 = view1['img']
        img2 = view2['img']
        B = img1.shape[0]
        # Recover true_shape when available, otherwise assume that the img shape is the true one
        shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))
        shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))
        # warning! maybe the images have different portrait/landscape orientations

        if is_symmetrized(view1, view2):
            # computing half of forward pass!'
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2])
            feat1, feat2 = interleave(feat1, feat2)
            pos1, pos2 = interleave(pos1, pos2)
        else:
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2)

        return (shape1, shape2), (feat1, feat2), (pos1, pos2)

    def _decoder(self, f1, pos1, f2, pos2):
        """Run the two decoders in lock-step with cross-view attention;
        returns, per view, the sequence of intermediate features
        (kept for DPT-style heads that consume multiple layers)."""
        final_output = [(f1, f2)]  # before projection

        # project to decoder dim
        f1 = self.decoder_embed(f1)
        f2 = self.decoder_embed(f2)

        final_output.append((f1, f2))
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # img1 side
            f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
            # img2 side (arguments reversed so it cross-attends to view 1)
            f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
            # store the result
            final_output.append((f1, f2))

        # normalize last output
        del final_output[1]  # duplicate with final_output[0]
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
        return zip(*final_output)

    def _downstream_head(self, head_num, decout, img_shape):
        """Apply head 1 or 2 to the list of decoder outputs."""
        B, S, D = decout[-1].shape
        # img_shape = tuple(map(int, img_shape))
        head = getattr(self, f'head{head_num}')
        return head(decout, img_shape)

    def forward(self, view1, view2):
        # encode the two images --> B,S,D
        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2)

        # combine all ref images into object-centric representation
        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)

        # heads run in full fp32 precision even under mixed-precision training
        with torch.cuda.amp.autocast(enabled=False):
            res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

        res2['pts3d_in_other_view'] = res2.pop('pts3d')  # predict view2's pts3d in view1's frame
        return res1, res2
|
dust3r/dust3r/optim_factory.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# optimization functions
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def adjust_learning_rate_by_lr(optimizer, lr):
    """Set every param group's learning rate to `lr`, honoring a per-group 'lr_scale'."""
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
|
dust3r/dust3r/patch_embed.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# PatchEmbed implementation for DUST3R,
|
| 6 |
+
# in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
import torch
|
| 9 |
+
import dust3r.utils.path_to_croco # noqa: F401
|
| 10 |
+
from models.blocks import PatchEmbed # noqa
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
    """Instantiate a patch-embedding module by class name.

    patch_embed_cls: 'PatchEmbedDust3R' or 'ManyAR_PatchEmbed'
    img_size, patch_size, enc_embed_dim: forwarded to the constructor
    (input channels are fixed to 3 = RGB).
    Raises AssertionError for an unknown class name.
    """
    # explicit dispatch table instead of eval() on an arbitrary string
    known_classes = {
        'PatchEmbedDust3R': PatchEmbedDust3R,
        'ManyAR_PatchEmbed': ManyAR_PatchEmbed,
    }
    assert patch_embed_cls in known_classes
    return known_classes[patch_embed_cls](img_size, patch_size, 3, enc_embed_dim)
+
|
| 18 |
+
|
| 19 |
+
class PatchEmbedDust3R(PatchEmbed):
    """Standard patch embedding that also returns per-token (x, y) positions."""

    def forward(self, x, **kw):
        B, C, H, W = x.shape
        # both spatial dimensions must tile exactly into patches
        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        x = self.proj(x)
        pos = self.position_getter(B, x.size(2), x.size(3), x.device)
        if self.flatten:
            # BCHW -> BNC
            x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, pos
|
| 31 |
+
|
| 32 |
+
class ManyAR_PatchEmbed (PatchEmbed):
    """ Handle images with non-square aspect ratio.
    All images in the same batch have the same aspect ratio.
    true_shape = [(height, width) ...] indicates the actual shape of each image.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        # stored before super().__init__() because forward() needs it to
        # allocate the output buffer
        self.embed_dim = embed_dim
        super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)

    def forward(self, img, true_shape):
        # img: (B, C, H, W) always in landscape orientation;
        # true_shape: (B, 2) per-image (height, width) before any rotation
        B, C, H, W = img.shape
        assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'
        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}"

        # size expressed in tokens
        # NOTE(review): W is divided by patch_size[0] and H by patch_size[1];
        # the indices look swapped, which is harmless for square patches —
        # confirm before using non-square patch sizes.
        W //= self.patch_size[0]
        H //= self.patch_size[1]
        n_tokens = H * W

        height, width = true_shape.T
        is_landscape = (width >= height)
        is_portrait = ~is_landscape

        # allocate result
        x = img.new_zeros((B, n_tokens, self.embed_dim))
        pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)

        # linear projection, transposed if necessary (portrait images are
        # projected as if rotated into landscape orientation)
        x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()
        x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()

        # token positions, with H/W swapped accordingly for portrait images
        pos[is_landscape] = self.position_getter(1, H, W, pos.device)
        pos[is_portrait] = self.position_getter(1, W, H, pos.device)

        x = self.norm(x)
        return x, pos
|
dust3r/dust3r/post_process.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# utilities for interpreting the DUST3R output
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from dust3r.utils.geometry import xy_grid
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0.5, max_focal=3.5):
    """ Reprojection method, for when the absolute depth is known:
        1) estimate the camera focal using a robust estimator
        2) reproject points onto true rays, minimizing a certain error

    pts3d: (B, H, W, 3) points expressed in the camera frame.
    pp: principal point, broadcastable to (B, 2) pixel coordinates.
    focal_mode: 'median' (per-pixel focal votes + robust median) or
                'weiszfeld' (L1 fit via iteratively-reweighted least squares).
    min_focal / max_focal: clipping bounds, relative to the focal of a
                           60-degree field of view.
    Returns a (B,) tensor of focal lengths in pixels.
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3

    # centered pixel grid
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2)  # B,HW,2
    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == 'median':
        with torch.no_grad():
            # direct estimation of focal: each pixel votes f = u*z/x (resp. v*z/y)
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y

            # assume square pixels, hence same focal for X and Y
            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
            # nanmedian ignores the NaNs produced by divisions by zero above
            focal = torch.nanmedian(f_votes, dim=-1).values

    elif focal_mode == 'weiszfeld':
        # init focal with l2 closed form
        # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0)  # homogeneous (x,y,1)

        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)

        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)

        # iterative re-weighted least-squares (Weiszfeld-style updates)
        for iter in range(10):
            # re-weighting by inverse of distance
            dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
            # print(dis.nanmean(-1))
            w = dis.clip(min=1e-8).reciprocal()
            # update the scaling with the new weights
            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f'bad {focal_mode=}')

    # clip relative to the focal corresponding to a 60-degree field of view
    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
    focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base)
    # print(focal)
    return focal
|
dust3r/dust3r/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
dust3r/dust3r/utils/device.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# utilitary functions for DUSt3R
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def todevice(batch, device, callback=None, non_blocking=False):
    ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: function that would be called on every sub-elements.
    non_blocking: forwarded to Tensor.to() for async host-to-device copies.
    '''
    if callback:
        batch = callback(batch)

    if isinstance(batch, dict):
        # bug fix: propagate callback and non_blocking into the recursion
        # (previously both were silently dropped below the top level)
        return {k: todevice(v, device, callback, non_blocking) for k, v in batch.items()}

    if isinstance(batch, (tuple, list)):
        # preserve the original container type
        return type(batch)(todevice(x, device, callback, non_blocking) for x in batch)

    x = batch
    if device == 'numpy':
        # convert tensors to numpy arrays; everything else passes through unchanged
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        # promote numpy arrays to tensors, then move any tensor to the target device
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
to_device = todevice  # backward-compatible alias


def to_numpy(x):
    """Convert every tensor inside x to a numpy array."""
    return todevice(x, 'numpy')


def to_cpu(x):
    """Move every tensor inside x to the CPU."""
    return todevice(x, 'cpu')


def to_cuda(x):
    """Move every tensor inside x to the default CUDA device."""
    return todevice(x, 'cuda')
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def collate_with_cat(whatever, lists=False):
    """Recursively collate a nested structure of samples.

    Dicts are collated per key; a sequence of tuples/dicts is collated
    element-wise; tensors and numpy arrays are concatenated along dim 0
    (or flattened into one python list when `lists` is True); scalars and
    strings pass through unchanged; other sequences are chained together.
    """
    if isinstance(whatever, dict):
        return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}

    elif isinstance(whatever, (tuple, list)):
        if len(whatever) == 0:
            return whatever
        first = whatever[0]
        container = type(whatever)

        if first is None:
            return None
        if isinstance(first, (bool, float, int, str)):
            return whatever
        if isinstance(first, tuple):
            # transpose: sequence of tuples -> container of collated columns
            return container(collate_with_cat(col, lists=lists) for col in zip(*whatever))
        if isinstance(first, dict):
            return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in first}

        if isinstance(first, torch.Tensor):
            return listify(whatever) if lists else torch.cat(whatever)
        if isinstance(first, np.ndarray):
            return listify(whatever) if lists else torch.cat([torch.from_numpy(a) for a in whatever])

        # otherwise, we just chain the sub-sequences together
        return sum(whatever, container())
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def listify(elems):
    """Flatten one level of nesting: concatenate all sub-iterables into a single list."""
    flat = []
    for group in elems:
        flat.extend(group)
    return flat
|
dust3r/dust3r/utils/geometry.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# geometry utilitary functions
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy.spatial import cKDTree as KDTree
|
| 10 |
+
|
| 11 |
+
from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
|
| 12 |
+
from dust3r.utils.device import to_numpy
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
|
| 16 |
+
""" Output a (H,W,2) array of int32
|
| 17 |
+
with output[j,i,0] = i + origin[0]
|
| 18 |
+
output[j,i,1] = j + origin[1]
|
| 19 |
+
"""
|
| 20 |
+
if device is None:
|
| 21 |
+
# numpy
|
| 22 |
+
arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
|
| 23 |
+
else:
|
| 24 |
+
# torch
|
| 25 |
+
arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
|
| 26 |
+
meshgrid, stack = torch.meshgrid, torch.stack
|
| 27 |
+
ones = lambda *a: torch.ones(*a, device=device)
|
| 28 |
+
|
| 29 |
+
tw, th = [arange(o, o+s, **arange_kw) for s, o in zip((W, H), origin)]
|
| 30 |
+
grid = meshgrid(tw, th, indexing='xy')
|
| 31 |
+
if homogeneous:
|
| 32 |
+
grid = grid + (ones((H, W)),)
|
| 33 |
+
if unsqueeze is not None:
|
| 34 |
+
grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
|
| 35 |
+
if cat_dim is not None:
|
| 36 |
+
grid = stack(grid, cat_dim)
|
| 37 |
+
return grid
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a list of 3-D points.

    Trf: 3x3 or 4x4 projection matrix (typically a Homography), optionally
         batched with leading dimensions; numpy array or torch tensor.
    pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the result is projected on the z=norm plane.

    Returns an array of projected points with shape pts.shape[:-1] + (ncol,).
    """
    assert Trf.ndim >= 2
    # coerce pts to the same framework (and dtype, for torch) as Trf
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary; remember the original leading shape
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code path: batched torch transform of (B,H,W,d) pointmaps
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # pure linear transform (e.g. rotation)
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d+1:
            # affine transform: rotate with the d x d block, then translate
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        # generic path
        if Trf.ndim >= 3:
            # flatten any leading batch dimensions of Trf
            n = Trf.ndim-2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1]+1 == Trf.shape[-1]:
            # homogeneous transform applied to non-homogeneous points
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            # fallback: transform the transposed point list directly
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        # perspective division by the last coordinate
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def inv(mat):
    """Return the inverse of a square matrix, for numpy arrays or torch tensors.

    Raises ValueError for any other input type.
    """
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
    """
    Back-project a batch of depth maps into camera-frame 3D pointmaps.

    Args:
        - depth (BxHxW array, optionally BxHxWxn for n candidates):
        - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W] per-pixel focal(s)
        - pp: optional (B,2) principal points; the image center is used when None
    Returns:
        pointmap of absolute coordinates (BxHxWx3 array, or BxHxWx3xn)
    """

    # depth may carry a trailing dimension of n depth candidates per pixel
    if len(depth.shape) == 4:
        B, H, W, n = depth.shape
    else:
        B, H, W = depth.shape
        n = None

    # normalize pseudo_focal into separate per-axis (B,H,W) maps
    if len(pseudo_focal.shape) == 3:  # [B,H,W]
        pseudo_focalx = pseudo_focaly = pseudo_focal
    elif len(pseudo_focal.shape) == 4:  # [B,2,H,W] or [B,1,H,W]
        pseudo_focalx = pseudo_focal[:, 0]
        if pseudo_focal.shape[1] == 2:
            pseudo_focaly = pseudo_focal[:, 1]
        else:
            pseudo_focaly = pseudo_focalx
    else:
        raise NotImplementedError("Error, unknown input focal shape format.")

    assert pseudo_focalx.shape == depth.shape[:3]
    assert pseudo_focaly.shape == depth.shape[:3]
    # pixel-coordinate grids with a broadcastable batch axis of size 1
    grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]

    # set principal point (subtract it so grid_x/grid_y become u-cx / v-cy)
    if pp is None:
        grid_x = grid_x - (W-1)/2
        grid_y = grid_y - (H-1)/2
    else:
        grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
        grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]

    # pinhole back-projection: X = d*(u-cx)/fx, Y = d*(v-cy)/fy, Z = d
    if n is None:
        pts3d = torch.empty((B, H, W, 3), device=depth.device)
        pts3d[..., 0] = depth * grid_x / pseudo_focalx
        pts3d[..., 1] = depth * grid_y / pseudo_focaly
        pts3d[..., 2] = depth
    else:
        # candidate dimension broadcast over the trailing axis
        pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
        pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
        pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
        pts3d[..., 2, :] = depth
    return pts3d
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """Back-project a depth map into camera-frame 3D points.

    Args:
        depthmap: (H, W) array of depth values; non-positive means invalid.
        camera_intrinsics: 3x3 pinhole matrix (skew terms must be zero).
        pseudo_focal: optional (H, W) per-pixel focal override.

    Returns:
        (X_cam, valid_mask): (H, W, 3) float32 points and (H, W) bool mask.
    """
    K = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # only a skew-free pinhole model is supported
    assert K[0, 1] == 0.0
    assert K[1, 0] == 0.0

    if pseudo_focal is None:
        fx, fy = K[0, 0], K[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fx = fy = pseudo_focal
    cx, cy = K[0, 2], K[1, 2]

    # per-pixel image coordinates
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z = depthmap
    # pinhole back-projection: X = z*(u-cx)/fx, Y = z*(v-cy)/fy, Z = z
    X_cam = np.stack(((u - cx) * z / fx, (v - cy) * z / fy, z), axis=-1).astype(np.float32)

    # depth <= 0 marks missing measurements
    valid_mask = (depthmap > 0.0)
    return X_cam, valid_mask
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw):
    """Back-project a depth map and express the points in world coordinates.

    Args:
        depthmap: (H, W) depth array (non-positive = invalid).
        camera_intrinsics: 3x3 pinhole matrix.
        camera_pose: 4x3 or 4x4 cam2world matrix.

    Returns:
        (X_world, valid_mask): (H, W, 3) world-frame points and (H, W) bool mask.
    """
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    rotation = camera_pose[:3, :3]
    translation = camera_pose[:3, 3]

    # rotate each pixel's camera-frame point, then translate (invalid depths included)
    X_world = np.einsum("ik, vuk -> vui", rotation, X_cam) + translation[None, None, :]
    return X_world, valid_mask
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def colmap_to_opencv_intrinsics(K):
    """Shift the principal point from Colmap's convention to OpenCV's.

    Colmap places the center of the top-left pixel at (0.5, 0.5),
    OpenCV at (0, 0). The input matrix is not modified.
    """
    out = K.copy()
    for axis in (0, 1):
        out[axis, 2] -= 0.5
    return out
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def opencv_to_colmap_intrinsics(K):
    """Shift the principal point from OpenCV's convention to Colmap's.

    OpenCV places the center of the top-left pixel at (0, 0),
    Colmap at (0.5, 0.5). The input matrix is not modified.
    """
    out = K.copy()
    for axis in (0, 1):
        out[axis, 2] += 0.5
    return out
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None):
    """ renorm pointmaps pts1, pts2 with norm_mode

    norm_mode is '<mode>_<dis>', e.g. 'avg_dis': the normalization factor is
    the <mode> (avg/median/sqrt) of the per-point <dis> (dis/log1p/warp-log1p)
    distances to the origin, computed jointly over both pointmaps.
    Invalid points (per valid1/valid2 masks) are excluded from the statistic.
    Returns the normalized pts1, or a tuple (pts1, pts2) when pts2 is given.
    """
    assert pts1.ndim >= 3 and pts1.shape[-1] == 3
    assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split('_')

    if norm_mode == 'avg':
        # gather all points together (joint normalization);
        # invalid points become zeros so they don't affect the sum, and
        # nnz counts how many valid points each batch element contributes
        nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
        nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1

        # compute distance to origin
        all_dis = all_pts.norm(dim=-1)
        if dis_mode == 'dis':
            pass  # do nothing
        elif dis_mode == 'log1p':
            all_dis = torch.log1p(all_dis)
        elif dis_mode == 'warp-log1p':
            # actually warp input points before normalizing them
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            # the first W1*H1 entries of the flattened cloud belong to pts1
            H1, W1 = pts1.shape[1:-1]
            pts1 = pts1 * warp_factor[:, :W1*H1].view(-1, H1, W1, 1)
            if pts2 is not None:
                H2, W2 = pts2.shape[1:-1]
                pts2 = pts2 * warp_factor[:, W1*H1:].view(-1, H2, W2, 1)
            all_dis = log_dis  # this is their true distance afterwards
        else:
            raise ValueError(f'bad {dis_mode=}')

        # mean distance over valid points only (zeros contribute nothing)
        norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
    else:
        # gather all points together (joint normalization);
        # invalid points become NaNs so nan-aware reductions skip them
        nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
        nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1

        # compute distance to origin
        all_dis = all_pts.norm(dim=-1)

        if norm_mode == 'avg':
            # NOTE(review): unreachable — 'avg' is handled by the branch above
            norm_factor = all_dis.nanmean(dim=1)
        elif norm_mode == 'median':
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == 'sqrt':
            norm_factor = all_dis.sqrt().nanmean(dim=1)**2
        else:
            raise ValueError(f'bad {norm_mode=}')

    # avoid division by zero, then broadcast the factor over all point dims
    norm_factor = norm_factor.clip(min=1e-8)
    while norm_factor.ndim < pts1.ndim:
        norm_factor.unsqueeze_(-1)

    res = pts1 / norm_factor
    if pts2 is not None:
        res = (res, pts2 / norm_factor)
    return res
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
@torch.no_grad()
def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
    """Per-batch depth quantile computed jointly over both views.

    Invalid pixels (per the masks) are excluded via NaNs. Returns a (B,) tensor.
    """
    # set invalid points to NaN and flatten each view per batch element
    flat1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
    if z2 is None:
        depths = flat1
    else:
        flat2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1)
        depths = torch.cat((flat1, flat2), dim=-1)

    # the median has a dedicated nan-aware implementation
    if quantile == 0.5:
        return torch.nanmedian(depths, dim=-1).values
    return torch.nanquantile(depths, quantile, dim=-1)  # (B,)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@torch.no_grad()
def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):
    """Robust (median-based) center and scale of one or two joint pointmaps.

    Returns (_center, scale) with shapes (B,1,1,3) and (B,1,1,1), suitable for
    broadcasting against (B,H,W,3) pointmaps. Invalid points are ignored.
    """
    # set invalid points to NaN and flatten spatial dims per batch element
    flat1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
    if pts2 is None:
        all_pts = flat1
    else:
        flat2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3)
        all_pts = torch.cat((flat1, flat2), dim=1)

    # component-wise median center
    _center = torch.nanmedian(all_pts, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        _center[..., :2] = 0  # do not center X and Y

    # median distance to the center (or to the origin when center=False)
    offsets = (all_pts - _center) if center else all_pts
    scale = torch.nanmedian(offsets.norm(dim=-1), dim=1).values
    return _center[:, None, :, :], scale[:, None, None, None]
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def find_reciprocal_matches(P1, P2):
    """
    Mutual nearest-neighbour matching between two point sets.

    returns 3 values:
    1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
    2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
    3 - reciprocal_in_P2.sum(): the number of matches
    """
    tree_P1 = KDTree(P1)
    tree_P2 = KDTree(P2)

    # nearest neighbour of every P1 point in P2, and vice versa
    _, nn1_in_P2 = tree_P2.query(P1, workers=8)
    _, nn2_in_P1 = tree_P1.query(P2, workers=8)

    # a match is reciprocal when following the NN links round-trips to the start
    reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)))
    reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)))
    assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
    return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def get_med_dist_between_poses(poses):
    """Median pairwise distance between camera centers (translation columns)."""
    from scipy.spatial.distance import pdist
    centers = [to_numpy(p[:3, 3]) for p in poses]
    return np.median(pdist(centers))
|
dust3r/dust3r/utils/image.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# utilitary functions about images (loading/converting...)
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
import PIL.Image
|
| 11 |
+
import torchvision.transforms as tvf
|
| 12 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 13 |
+
import cv2 # noqa
|
| 14 |
+
|
| 15 |
+
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def imread_cv2(path, options=cv2.IMREAD_COLOR):
    """ Open an image or a depthmap with opencv-python.

    Args:
        path: file path; '.exr' files are loaded with IMREAD_ANYDEPTH.
        options: cv2 imread flag (overridden for EXR files).

    Returns:
        RGB image array (for 3-channel inputs) or raw single-channel array.

    Raises:
        IOError: if the file cannot be loaded.
    """
    # BUGFIX: the previous check endswith(('.exr', 'EXR')) also matched any
    # filename merely ending in 'EXR' (no dot); match the extension
    # case-insensitively instead.
    if path.lower().endswith('.exr'):
        options = cv2.IMREAD_ANYDEPTH
    img = cv2.imread(path, options)
    if img is None:
        raise IOError(f'Could not load image={path} with {options=}')
    if img.ndim == 3:
        # OpenCV loads BGR; convert to the RGB convention used everywhere else
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def rgb(ftensor, true_shape=None):
    """Convert a (possibly normalized) image tensor into a displayable
    float RGB array clipped to [0, 1]. Lists are converted element-wise.
    """
    if isinstance(ftensor, list):
        return [rgb(item, true_shape=true_shape) for item in ftensor]
    if isinstance(ftensor, torch.Tensor):
        ftensor = ftensor.detach().cpu().numpy()  # H,W,3

    # move channels last if the input is CHW or BCHW
    if ftensor.ndim == 3 and ftensor.shape[0] == 3:
        ftensor = ftensor.transpose(1, 2, 0)
    elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
        ftensor = ftensor.transpose(0, 2, 3, 1)

    if true_shape is not None:
        # drop any padding beyond the true image extent
        H, W = true_shape
        ftensor = ftensor[:H, :W]

    if ftensor.dtype == np.uint8:
        img = np.float32(ftensor) / 255
    else:
        # undo the [-1, 1] normalization applied at loading time
        img = (ftensor * 0.5) + 0.5
    return img.clip(min=0, max=1)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _resize_pil_image(img, long_edge_size):
    """Resize so the long edge equals long_edge_size, keeping aspect ratio.

    Lanczos is used when shrinking, bicubic when growing.
    """
    current = max(img.size)
    interp = PIL.Image.LANCZOS if current > long_edge_size else PIL.Image.BICUBIC
    new_size = tuple(int(round(side * long_edge_size / current)) for side in img.size)
    return img.resize(new_size, interp)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def load_images(folder_or_list, size, square_ok=False):
    """ Open and convert all images in a list or folder to proper input format for DUSt3R.

    Args:
        folder_or_list: directory path, or an explicit list of image paths.
        size: target resolution. 224 resizes the short side then center-crops
              to a square; any other value resizes the long side to `size`.
        square_ok: when False and size != 224, square images are cropped to 4:3.

    Returns:
        list of dicts with keys 'img' (normalized 1xCxHxW tensor),
        'true_shape' (1x2 int32 array [H, W]), 'idx' and 'instance'.

    Raises:
        ValueError: for an unsupported folder_or_list type.
        AssertionError: when no loadable image is found.
    """
    if isinstance(folder_or_list, str):
        print(f'>> Loading images from {folder_or_list}')
        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))

    elif isinstance(folder_or_list, list):
        print(f'>> Loading a list of {len(folder_or_list)} images')
        root, folder_content = '', folder_or_list

    else:
        raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')

    imgs = []
    for path in folder_content:
        # only common raster formats are accepted; others are silently skipped
        if not path.endswith(('.jpg', '.jpeg', '.png', '.JPG')):
            continue
        img = PIL.Image.open(os.path.join(root, path)).convert('RGB')
        W1, H1 = img.size
        if size == 224:
            # resize short side to 224 (then crop)
            img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
        else:
            # resize long side to `size`
            img = _resize_pil_image(img, size)
        W, H = img.size
        cx, cy = W//2, H//2
        if size == 224:
            # center square crop
            half = min(cx, cy)
            img = img.crop((cx-half, cy-half, cx+half, cy+half))
        else:
            # crop to dimensions that are multiples of 16 around the center
            halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
            if not (square_ok) and W == H:
                halfh = 3*halfw/4
            img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))

        W2, H2 = img.size
        print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
        imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
            [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))

    # BUGFIX: typo in the original message ("foud" -> "found")
    assert imgs, 'no images found at '+root
    print(f' (Found {len(imgs)} images)')
    return imgs
|
dust3r/dust3r/utils/misc.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# utilitary functions for DUSt3R
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def fill_default_args(kwargs, func):
    """Complete kwargs in place with func's declared default values.

    Existing keys are left untouched; the (mutated) dict is returned.
    """
    import inspect  # a bit hacky but it works reliably
    for name, param in inspect.signature(func).parameters.items():
        if param.default is not inspect.Parameter.empty:
            kwargs.setdefault(name, param.default)
    return kwargs
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def freeze_all_params(modules):
    """Disable gradients for every parameter of the given modules.

    Entries that are not modules (bare parameters/tensors) are frozen directly.
    """
    for module in modules:
        try:
            named = module.named_parameters()
        except AttributeError:
            # module is directly a parameter
            module.requires_grad = False
        else:
            for _, param in named:
                param.requires_grad = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def is_symmetrized(gt1, gt2):
    """True iff the batch interleaves each pair (i, j) with its swap (j, i).

    A batch of size 1 cannot be symmetrized and always returns False.
    """
    x = gt1['instance']
    y = gt2['instance']
    if len(x) == len(y) and len(x) == 1:
        return False  # special case of batchsize 1
    return all(x[i] == y[i + 1] and x[i + 1] == y[i]
               for i in range(0, len(x), 2))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def flip(tensor):
    """ flip so that tensor[0::2] <=> tensor[1::2] """
    paired = torch.stack((tensor[1::2], tensor[0::2]), dim=1)
    return paired.flatten(0, 1)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def interleave(tensor1, tensor2):
    """Return the two possible element-wise interleavings of tensor1/tensor2."""
    def weave(a, b):
        # a0, b0, a1, b1, ...
        return torch.stack((a, b), dim=1).flatten(0, 1)
    return weave(tensor1, tensor2), weave(tensor2, tensor1)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def transpose_to_landscape(head, activate=True):
    """ Predict in the correct aspect-ratio,
    then transpose the result in landscape
    and stack everything back together.

    Returns a wrapper around `head`; with activate=False the wrapper only
    asserts that all images in the batch share the same true_shape.
    """
    def wrapper_no(decout, true_shape):
        # trivial wrapper: the whole batch must share one (H, W)
        B = len(true_shape)
        assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'
        H, W = true_shape[0].cpu().tolist()
        res = head(decout, (H, W))
        return res

    def wrapper_yes(decout, true_shape):
        B = len(true_shape)
        # by definition, the batch is in landscape mode so W >= H
        H, W = int(true_shape.min()), int(true_shape.max())

        height, width = true_shape.T
        is_landscape = (width >= height)
        is_portrait = ~is_landscape

        # true_shape = true_shape.cpu()
        # homogeneous batches take a single head call
        if is_landscape.all():
            return head(decout, (H, W))
        if is_portrait.all():
            # predict in portrait orientation, then transpose back to landscape
            return transposed(head(decout, (W, H)))

        # batch is a mix of both portraint & landscape
        def selout(ar): return [d[ar] for d in decout]
        l_result = head(selout(is_landscape), (H, W))
        p_result = transposed(head(selout(is_portrait), (W, H)))

        # allocate full result and scatter each sub-batch back into place
        result = {}
        for k in l_result | p_result:  # dict union of keys (Python 3.9+)
            x = l_result[k].new(B, *l_result[k].shape[1:])
            x[is_landscape] = l_result[k]
            x[is_portrait] = p_result[k]
            result[k] = x

        return result

    return wrapper_yes if activate else wrapper_no
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def transposed(dic):
    """Swap the two spatial axes (dims 1 and 2) of every tensor in the dict."""
    out = {}
    for key, value in dic.items():
        out[key] = value.swapaxes(1, 2)
    return out
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def invalid_to_nans(arr, valid_mask, ndim=999):
    """Replace entries outside valid_mask by NaN (on a copy), then flatten
    trailing spatial dimensions so the result has at most `ndim` dims."""
    if valid_mask is not None:
        cleaned = arr.clone()
        cleaned[~valid_mask] = float('nan')
        arr = cleaned
    extra = arr.ndim - ndim
    if extra > 0:
        # merge the spatial dims just before the last (channel) dim
        arr = arr.flatten(-2 - extra, -2)
    return arr
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def invalid_to_zeros(arr, valid_mask, ndim=999):
    """Zero out entries outside valid_mask (on a copy) and return
    (arr, nnz) where nnz is the per-item count of valid entries.
    Trailing spatial dims are flattened so the result has at most `ndim` dims."""
    if valid_mask is not None:
        cleaned = arr.clone()
        cleaned[~valid_mask] = 0
        arr = cleaned
        nnz = valid_mask.view(len(valid_mask), -1).sum(1)
    else:
        # number of point per image
        # NOTE(review): this counts every element per item, including any
        # coordinate channels — kept as-is to match existing callers.
        nnz = arr.numel() // len(arr) if len(arr) else 0
    extra = arr.ndim - ndim
    if extra > 0:
        arr = arr.flatten(-2 - extra, -2)
    return arr, nnz
|
dust3r/dust3r/utils/path_to_croco.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# CroCo submodule import
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
import sys
import os.path as path

# Absolute directory containing this file.
HERE_PATH = path.normpath(path.dirname(__file__))
# The croco submodule is expected two levels up, next to the dust3r package.
CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco'))
CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models')
# check the presence of models directory in repo to be sure its cloned
if path.isdir(CROCO_MODELS_PATH):
    # workaround for sibling import: make `models` importable from croco
    sys.path.insert(0, CROCO_REPO_PATH)
else:
    # fail fast at import time with an actionable message
    raise ImportError(f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n "
                      "Did you forget to run 'git submodule update --init --recursive' ?")
|
dust3r/dust3r/viz.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Visualization utilities using trimesh
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import PIL.Image
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy.spatial.transform import Rotation
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from dust3r.utils.geometry import geotrf, get_med_dist_between_poses
|
| 13 |
+
from dust3r.utils.device import to_numpy
|
| 14 |
+
from dust3r.utils.image import rgb
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
import trimesh
|
| 18 |
+
except ImportError:
|
| 19 |
+
print('/!\\ module trimesh is not installed, cannot visualize results /!\\')
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def cat_3d(vecs):
    """Concatenate one or many point arrays into a single (N, 3) numpy array."""
    if isinstance(vecs, (np.ndarray, torch.Tensor)):
        vecs = [vecs]
    chunks = [part.reshape(-1, 3) for part in to_numpy(vecs)]
    return np.concatenate(chunks)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def show_raw_pointcloud(pts3d, colors, point_size=2):
    """Open an interactive trimesh viewer showing the colored point cloud.

    pts3d/colors may be single arrays or lists of arrays (see cat_3d).
    """
    cloud = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))
    scene = trimesh.Scene()
    scene.add_geometry(cloud)

    # blocks until the viewer window is closed
    scene.show(line_settings={'point_size': point_size})
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def pts3d_to_trimesh(img, pts3d, valid=None):
    """Triangulate a dense (H, W, 3) pointmap into mesh data colored by img.

    Each pixel quad becomes two triangles, each duplicated with reversed
    winding so the mesh is visible from both sides (cancels face culling).
    Faces touching an invalid vertex (per `valid`) are removed.
    Returns a dict with keys 'vertices', 'faces' and 'face_colors'.
    """
    H, W, C = img.shape
    assert C == 3
    assert img.shape == pts3d.shape

    vertices = pts3d.reshape(-1, 3)

    # vertex index of every pixel, and the four corners of each pixel quad
    index_grid = np.arange(len(vertices)).reshape(H, W)
    tl = index_grid[:-1, :-1].ravel()  # top-left
    tr = index_grid[:-1, +1:].ravel()  # top-right
    bl = index_grid[+1:, :-1].ravel()  # bottom-left
    br = index_grid[+1:, +1:].ravel()  # bottom-right
    faces = np.concatenate((
        np.c_[tl, tr, bl],
        np.c_[bl, tr, tl],  # same triangle, but backward (cheap solution to cancel face culling)
        np.c_[tr, bl, br],
        np.c_[br, bl, tr],  # same triangle, but backward (cheap solution to cancel face culling)
    ), axis=0)

    # per-face colors, following the same 4-group ordering as `faces`
    tl_colors = img[:-1, :-1].reshape(-1, 3)
    br_colors = img[+1:, +1:].reshape(-1, 3)
    face_colors = np.concatenate((tl_colors, tl_colors, br_colors, br_colors), axis=0)

    # remove faces that reference an invalid vertex
    if valid is not None:
        assert valid.shape == (H, W)
        keep = valid.ravel()[faces].all(axis=-1)
        faces = faces[keep]
        face_colors = face_colors[keep]

    assert len(faces) == len(face_colors)
    return dict(vertices=vertices, face_colors=face_colors, faces=faces)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def cat_meshes(meshes):
    """Merge several mesh dicts into one, offsetting face indices accordingly.

    NOTE: the face-index arrays of the input meshes are shifted in place.
    """
    vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])
    # running vertex count gives each mesh's index offset
    offsets = np.cumsum([0] + [len(v) for v in vertices])
    for face_arr, offset in zip(faces, offsets):
        face_arr[:] += offset

    return dict(vertices=np.concatenate(vertices),
                face_colors=np.concatenate(colors),
                faces=np.concatenate(faces))
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def show_duster_pairs(view1, view2, pred1, pred2):
    """Interactively inspect predicted image pairs.

    For each pair: plots both images and confidence maps, then optionally
    opens the corresponding raw 3D point cloud in a trimesh viewer.
    """
    import matplotlib.pyplot as pl
    pl.ion()

    for e in range(len(view1['instance'])):
        i = view1['idx'][e]
        j = view2['idx'][e]
        img1 = rgb(view1['img'][e])
        img2 = rgb(view2['img'][e])
        conf1 = pred1['conf'][e].squeeze()
        conf2 = pred2['conf'][e].squeeze()
        # crude pair-quality score: product of mean confidences
        score = conf1.mean()*conf2.mean()
        print(f">> Showing pair #{e} {i}-{j} {score=:g}")
        pl.clf()
        pl.subplot(221).imshow(img1)
        pl.subplot(223).imshow(img2)
        pl.subplot(222).imshow(conf1, vmin=1, vmax=30)
        pl.subplot(224).imshow(conf2, vmin=1, vmax=30)
        pts1 = pred1['pts3d'][e]
        pts2 = pred2['pts3d_in_other_view'][e]
        pl.subplots_adjust(0, 0, 1, 1, 0, 0)
        if input('show pointcloud? (y/n) ') == 'y':
            # BUGFIX: the original called an undefined helper `cat(...)`
            # (NameError at runtime); show_raw_pointcloud already concatenates
            # lists of arrays via cat_3d.
            show_raw_pointcloud([pts1, pts2], [img1, img2], point_size=5)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def auto_cam_size(im_poses):
    """Heuristic camera-glyph size: 10% of the median inter-camera distance."""
    return 0.1 * get_med_dist_between_poses(im_poses)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class SceneViz:
    """Incremental builder for a 3D trimesh scene of pointclouds and cameras.

    All add_* methods return ``self`` so calls can be chained.
    """

    def __init__(self):
        self.scene = trimesh.Scene()

    def add_pointcloud(self, pts3d, color, mask=None):
        """Add one merged pointcloud built from per-image points.

        ``color`` is either a uniform RGB triplet, or per-point colors with
        the same layout as ``pts3d`` (masked identically).
        """
        pts3d = to_numpy(pts3d)
        mask = to_numpy(mask)
        if mask is None:
            # no mask given: keep every point of every image
            mask = [slice(None)] * len(pts3d)
        valid_pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        cloud = trimesh.PointCloud(valid_pts.reshape(-1, 3))

        if isinstance(color, (list, np.ndarray, torch.Tensor)):
            # per-point colors
            color = to_numpy(color)
            valid_col = np.concatenate([c[m] for c, m in zip(color, mask)])
            assert valid_col.shape == valid_pts.shape
            cloud.visual.vertex_colors = uint8(valid_col.reshape(-1, 3))
        else:
            # one uniform color for the whole cloud
            assert len(color) == 3
            cloud.visual.vertex_colors = np.broadcast_to(uint8(color), valid_pts.shape)

        self.scene.add_geometry(cloud)
        return self

    def add_camera(self, pose_c2w, focal=None, color=(0, 0, 0), image=None, imsize=None, cam_size=0.03):
        """Add one camera frustum at the given camera-to-world pose."""
        pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image))
        add_scene_cam(self.scene, pose_c2w, color, image, focal, screen_width=cam_size)
        return self

    def add_cameras(self, poses, focals=None, images=None, imsizes=None, colors=None, **kw):
        """Add several cameras at once; any per-camera attribute may be None."""
        def get(arr, idx):
            return None if arr is None else arr[idx]
        for i, pose_c2w in enumerate(poses):
            self.add_camera(pose_c2w, get(focals, i), image=get(images, i),
                            color=get(colors, i), imsize=get(imsizes, i), **kw)
        return self

    def show(self, point_size=2):
        """Open the interactive trimesh viewer."""
        self.scene.show(line_settings={'point_size': point_size})
def show_raw_pointcloud_with_cams(imgs, pts3d, mask, focals, cams2world,
                                  point_size=2, cam_size=0.05, cam_color=None):
    """ Visualization of a pointcloud with cameras
    imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...]
    pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...]
    focals = (N,) or N-size list of [focal, ...]
    cams2world = (N,4,4) or N-size list of [(4,4), ...]
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d, imgs, focals, cams2world = map(to_numpy, (pts3d, imgs, focals, cams2world))

    scene = trimesh.Scene()

    # merge every valid 3D point (and its pixel color) into a single cloud
    all_pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
    all_col = np.concatenate([im[m] for im, m in zip(imgs, mask)])
    cloud = trimesh.PointCloud(all_pts.reshape(-1, 3), colors=all_col.reshape(-1, 3))
    scene.add_geometry(cloud)

    # one frustum per camera; edge color from cam_color (scalar or per-camera
    # list) or the cycling default palette
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      imgs[i] if i < len(imgs) else None, focals[i], screen_width=cam_size)

    scene.show(line_settings={'point_size': point_size})
def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03):
    """Add a camera frustum (and optionally its image plane) to a trimesh scene.

    scene: trimesh.Scene to extend
    pose_c2w: (4,4) camera-to-world matrix
    edge_color: RGB triplet for the frustum edges
    image: optional (H,W,3) image textured onto the frustum base
    imsize: optional (W,H), used when no image is given
    focal: optional focal length (scalar or ndarray, first entry used)
    screen_width: frustum size in scene units
    """
    # figure out an (H, W) image size from whatever information is available
    if image is not None:
        H, W, THREE = image.shape
        assert THREE == 3
        if image.dtype != np.uint8:
            image = np.uint8(255*image)
    elif imsize is not None:
        W, H = imsize
    elif focal is not None:
        # derive a plausible square size from the focal (inverse of default below)
        H = W = focal / 1.1
    else:
        H = W = 1

    if focal is None:
        focal = min(H, W) * 1.1  # default value
    elif isinstance(focal, np.ndarray):
        focal = focal[0]

    # create fake camera
    height = focal * screen_width / H
    width = screen_width * 0.5**0.5
    # roll the 4-sided cone by 45 deg so its square base is axis-aligned
    rot45 = np.eye(4)
    rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
    rot45[2, 3] = -height  # set the tip of the cone = optical center
    aspect_ratio = np.eye(4)
    aspect_ratio[0, 0] = W/H
    # full placement: camera pose * OpenGL axis flip * aspect scaling * 45-deg roll
    transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
    cam = trimesh.creation.cone(width, height, sections=4)  # , transform=transform)

    # this is the image
    if image is not None:
        # NOTE(review): indices [4, 5, 1, 3] are assumed to pick the 4 base
        # corners of the 4-section cone in winding order — verify against trimesh
        vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
        faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
        img = trimesh.Trimesh(vertices=vertices, faces=faces)
        uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
        img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
        scene.add_geometry(img)

    # this is the camera mesh
    # stack three slightly offset copies of the cone (scaled / rotated by 2 deg)
    # and stitch thin triangles between them to fake thick wireframe edges
    rot2 = np.eye(4)
    rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
    vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]
    vertices = geotrf(transform, vertices)
    faces = []
    for face in cam.faces:
        if 0 in face:
            continue  # faces touching vertex 0 are skipped (presumably the apex faces)
        a, b, c = face
        a2, b2, c2 = face + len(cam.vertices)
        a3, b3, c3 = face + 2*len(cam.vertices)

        # add 3 pseudo-edges
        faces.append((a, b, b2))
        faces.append((a, a2, c))
        faces.append((c2, b, c))

        faces.append((a, b, b3))
        faces.append((a, a3, c))
        faces.append((c3, b, c))

    # no culling
    faces += [(c, b, a) for a, b, c in faces]

    cam = trimesh.Trimesh(vertices=vertices, faces=faces)
    cam.visual.face_colors[:, :3] = edge_color
    scene.add_geometry(cam)
def cat(a, b):
    """Flatten two point/color arrays to (N,3) and stack them into one."""
    flat_a = a.reshape(-1, 3)
    flat_b = b.reshape(-1, 3)
    return np.concatenate((flat_a, flat_b))
|
| 264 |
+
|
| 265 |
+
OPENGL = np.array([[1, 0, 0, 0],
|
| 266 |
+
[0, -1, 0, 0],
|
| 267 |
+
[0, 0, -1, 0],
|
| 268 |
+
[0, 0, 0, 1]])
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
CAM_COLORS = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204),
|
| 272 |
+
(128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)]
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def uint8(colors):
    """Convert colors to np.uint8.

    Float inputs are assumed to lie in [0, 1] and are rescaled to [0, 255];
    integer inputs must already be in [0, 255] (asserted).

    Returns a np.uint8 array. The input is never modified: the previous
    implementation used ``colors *= 255`` which scaled a caller-supplied
    float ndarray in place, clobbering the caller's data.
    """
    colors = np.asarray(colors)
    if np.issubdtype(colors.dtype, np.floating):
        colors = colors * 255  # out-of-place: do not mutate the caller's array
    assert 0 <= colors.min() and colors.max() < 256
    return np.uint8(colors)
def segment_sky(image):
    """Heuristic sky segmentation.

    image: (H,W,3) array (float in [0,1] or uint8).
    Returns a boolean torch tensor of shape (H,W), True on sky pixels.
    """
    import cv2
    from scipy import ndimage

    # Convert to HSV
    image = to_numpy(image)
    if np.issubdtype(image.dtype, np.floating):
        image = np.uint8(255*image.clip(min=0, max=1))
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # Define range for blue color and create mask
    lower_blue = np.array([0, 0, 100])
    upper_blue = np.array([30, 255, 255])
    mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)

    # add luminous gray: low saturation combined with high value, at three
    # progressively looser saturation / stricter value thresholds
    mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)
    mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)
    mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)

    # Morphological operations: opening removes small isolated detections
    kernel = np.ones((5, 5), np.uint8)
    mask2 = ndimage.binary_opening(mask, structure=kernel)

    # keep only largest CC (plus any component at least half its size)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8)
    cc_sizes = stats[1:, cv2.CC_STAT_AREA]
    order = cc_sizes.argsort()[::-1]  # bigger first
    i = 0
    selection = []
    while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:
        selection.append(1 + order[i])  # +1 because label 0 is the background
        i += 1
    mask3 = np.in1d(labels, selection).reshape(labels.shape)

    # Apply mask
    return torch.from_numpy(mask3)
dust3r/requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
roma
|
| 4 |
+
gradio
|
| 5 |
+
matplotlib
|
| 6 |
+
tqdm
|
| 7 |
+
opencv-python
|
| 8 |
+
scipy
|
| 9 |
+
einops
|
| 10 |
+
trimesh
|
| 11 |
+
tensorboard
|
| 12 |
+
pyglet<2
|
dust3r/train.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 3 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 4 |
+
#
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
# training code for DUSt3R
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
# References:
|
| 9 |
+
# MAE: https://github.com/facebookresearch/mae
|
| 10 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 11 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 12 |
+
# --------------------------------------------------------
|
| 13 |
+
import argparse
|
| 14 |
+
import datetime
|
| 15 |
+
import json
|
| 16 |
+
import numpy as np
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
import time
|
| 20 |
+
import math
|
| 21 |
+
from collections import defaultdict
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Sized
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.backends.cudnn as cudnn
|
| 27 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 28 |
+
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
|
| 29 |
+
|
| 30 |
+
from dust3r.model import AsymmetricCroCo3DStereo, inf # noqa: F401, needed when loading the model
|
| 31 |
+
from dust3r.datasets import get_data_loader # noqa
|
| 32 |
+
from dust3r.losses import * # noqa: F401, needed when loading the model
|
| 33 |
+
from dust3r.inference import loss_of_one_batch # noqa
|
| 34 |
+
|
| 35 |
+
import dust3r.utils.path_to_croco # noqa: F401
|
| 36 |
+
import croco.utils.misc as misc # noqa
|
| 37 |
+
from croco.utils.misc import NativeScalerWithGradNormCount as NativeScaler # noqa
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_args_parser():
    """Build the command-line argument parser for DUSt3R training."""
    p = argparse.ArgumentParser('DUST3R training', add_help=False)

    # model and criteria (evaluated as python expressions by main())
    p.add_argument('--model', default="AsymmetricCroCo3DStereo(patch_embed_cls='ManyAR_PatchEmbed')",
                   type=str, help="string containing the model to build")
    p.add_argument('--pretrained', default=None, help='path of a starting checkpoint')
    p.add_argument('--train_criterion', default="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)",
                   type=str, help="train criterion")
    p.add_argument('--test_criterion', default=None, type=str, help="test criterion")

    # datasets
    p.add_argument('--train_dataset', required=True, type=str, help="training set")
    p.add_argument('--test_dataset', default='[None]', type=str, help="testing set")

    # optimization hyper-parameters
    p.add_argument('--seed', default=0, type=int, help="Random seed")
    p.add_argument('--batch_size', default=64, type=int,
                   help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus")
    p.add_argument('--accum_iter', default=1, type=int,
                   help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)")
    p.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler")

    p.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)")
    p.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
    p.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
                   help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    p.add_argument('--min_lr', type=float, default=0., metavar='LR',
                   help='lower lr bound for cyclic schedulers that hit 0')
    p.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR')

    p.add_argument('--amp', type=int, default=0,
                   choices=[0, 1], help="Use Automatic Mixed Precision for pretraining")

    # runtime / distributed settings
    p.add_argument('--num_workers', default=8, type=int)
    p.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    p.add_argument('--local_rank', default=-1, type=int)
    p.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # checkpointing / logging frequencies
    p.add_argument('--eval_freq', type=int, default=1, help='Test loss evaluation frequency')
    p.add_argument('--save_freq', default=1, type=int,
                   help='frequence (number of epochs) to save checkpoint in checkpoint-last.pth')
    p.add_argument('--keep_freq', default=20, type=int,
                   help='frequence (number of epochs) to save checkpoint in checkpoint-%d.pth')
    p.add_argument('--print_freq', default=20, type=int,
                   help='frequence (number of iterations) to print infos while training')

    # output dir
    p.add_argument('--output_dir', default='./output/', type=str, help="path where to save the output")
    return p
def main(args):
    """DUSt3R training driver.

    Sets up (optional) distributed training, datasets, model, criteria and
    optimizer, then loops over epochs: evaluate -> checkpoint -> train,
    finally writing checkpoint-final.pth.
    """
    misc.init_distributed_mode(args)
    global_rank = misc.get_rank()

    print("output_dir: "+args.output_dir)
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # auto resume from the last checkpoint when one exists
    last_ckpt_fname = os.path.join(args.output_dir, 'checkpoint-last.pth')
    args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # fix the seed (shifted by rank so each process gets different randomness)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # training dataset and loader
    print('Building train dataset {:s}'.format(args.train_dataset))
    data_loader_train = build_dataset(args.train_dataset, args.batch_size, args.num_workers, test=False)
    # bugfix: this log line previously printed args.train_dataset
    print('Building test dataset {:s}'.format(args.test_dataset))
    data_loader_test = {dataset.split('(')[0]: build_dataset(dataset, args.batch_size, args.num_workers, test=True)
                        for dataset in args.test_dataset.split('+')}

    # model and criteria
    # SECURITY: eval() on command-line strings executes arbitrary code;
    # only run this script with trusted arguments.
    print('Loading model: {:s}'.format(args.model))
    model = eval(args.model)
    print(f'>> Creating train criterion = {args.train_criterion}')
    train_criterion = eval(args.train_criterion).to(device)
    print(f'>> Creating test criterion = {args.test_criterion or args.train_criterion}')
    # bugfix: the fallback previously read args.criterion, which does not exist
    # (crashed whenever --test_criterion was omitted)
    test_criterion = eval(args.test_criterion or args.train_criterion).to(device)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    if args.pretrained and not args.resume:
        print('Loading pretrained: ', args.pretrained)
        ckpt = torch.load(args.pretrained, map_location=device)
        print(model.load_state_dict(ckpt['model'], strict=False))
        del ckpt  # in case it occupies memory

    # scale the learning rate with the effective (global) batch size
    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256
    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    def write_log_stats(epoch, train_stats, test_stats):
        # append one json line per epoch to log.txt (main process only)
        if misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()

            log_stats = dict(epoch=epoch, **{f'train_{k}': v for k, v in train_stats.items()})
            for test_name in data_loader_test:
                if test_name not in test_stats:
                    continue
                log_stats.update({test_name+'_'+k: v for k, v in test_stats[test_name].items()})

            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    def save_model(epoch, fname, best_so_far):
        # checkpoint the full training state (model, optimizer, scaler, ...)
        misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,
                        loss_scaler=loss_scaler, epoch=epoch, fname=fname, best_so_far=best_so_far)

    best_so_far = misc.load_model(args=args, model_without_ddp=model_without_ddp,
                                  optimizer=optimizer, loss_scaler=loss_scaler)
    if best_so_far is None:
        best_so_far = float('inf')
    if global_rank == 0 and args.output_dir is not None:
        log_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        log_writer = None

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    train_stats = test_stats = {}
    for epoch in range(args.start_epoch, args.epochs+1):

        # Save immediately the last checkpoint
        if epoch > args.start_epoch:
            if args.save_freq and epoch % args.save_freq == 0 or epoch == args.epochs:
                save_model(epoch-1, 'last', best_so_far)

        # Test on multiple datasets
        new_best = False
        if (epoch > 0 and args.eval_freq > 0 and epoch % args.eval_freq == 0):
            test_stats = {}
            for test_name, testset in data_loader_test.items():
                stats = test_one_epoch(model, test_criterion, testset,
                                       device, epoch, log_writer=log_writer, args=args, prefix=test_name)
                test_stats[test_name] = stats

                # Save best of all
                if stats['loss_med'] < best_so_far:
                    best_so_far = stats['loss_med']
                    new_best = True

        # Save more stuff
        write_log_stats(epoch, train_stats, test_stats)

        if epoch > args.start_epoch:
            if args.keep_freq and epoch % args.keep_freq == 0:
                save_model(epoch-1, str(epoch), best_so_far)
            if new_best:
                save_model(epoch-1, 'best', best_so_far)
        if epoch >= args.epochs:
            break  # exit after writing last test to disk

        # Train
        train_stats = train_one_epoch(
            model, train_criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    save_final_model(args, args.epochs, model_without_ddp, best_so_far=best_so_far)
def save_final_model(args, epoch, model_without_ddp, best_so_far=None):
    """Write the final checkpoint (checkpoint-final.pth) via the master process.

    ``model_without_ddp`` may already be a state dict (saved as-is) or a
    module (moved to CPU and converted to a state dict).
    """
    checkpoint_path = Path(args.output_dir) / 'checkpoint-final.pth'
    if isinstance(model_without_ddp, dict):
        model_state = model_without_ddp
    else:
        model_state = model_without_ddp.cpu().state_dict()
    to_save = {'args': args, 'model': model_state, 'epoch': epoch}
    if best_so_far is not None:
        to_save['best_so_far'] = best_so_far
    print(f'>> Saving model to {checkpoint_path} ...')
    misc.save_on_master(to_save, checkpoint_path)
def build_dataset(dataset, batch_size, num_workers, test=False):
    """Instantiate a train or test data loader from a dataset string.

    Training loaders shuffle and drop the last incomplete batch;
    test loaders do neither.
    """
    split = 'Test' if test else 'Train'
    print(f'Building {split} Data loader for dataset: ', dataset)
    loader = get_data_loader(dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_mem=True,
                             shuffle=not test,
                             drop_last=not test)

    print(f"{split} dataset length: ", len(loader))
    return loader
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Sized, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    args,
                    log_writer=None):
    """Run one training epoch and return averaged metrics.

    Uses a per-iteration LR schedule and gradient accumulation over
    args.accum_iter steps. Returns {metric_name: global_avg}.
    """
    assert torch.backends.cuda.matmul.allow_tf32 == True

    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    accum_iter = args.accum_iter

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # propagate the epoch so dataset/sampler can reshuffle deterministically
    if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'):
        data_loader.dataset.set_epoch(epoch)
    if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'):
        data_loader.sampler.set_epoch(epoch)

    optimizer.zero_grad()

    for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        # fractional epoch, used both by the LR schedule and for logging
        epoch_f = epoch + data_iter_step / len(data_loader)

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            misc.adjust_learning_rate(optimizer, epoch_f, args)

        loss_tuple = loss_of_one_batch(batch, model, criterion, device,
                                       symmetrize_batch=True,
                                       use_amp=bool(args.amp), ret='loss')
        loss, loss_details = loss_tuple  # criterion returns two values
        loss_value = float(loss)

        if not math.isfinite(loss_value):
            # NOTE(review): builtin print has no 'force' kwarg — assumes
            # misc.init_distributed_mode patched print (as in MAE/CroCo); verify
            print("Loss is {}, stopping training".format(loss_value), force=True)
            sys.exit(1)

        # scale down so the accumulated gradient matches a full-batch step
        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        del loss
        del batch

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(epoch=epoch_f)
        metric_logger.update(lr=lr)
        metric_logger.update(loss=loss_value, **loss_details)

        if (data_iter_step + 1) % accum_iter == 0 and ((data_iter_step + 1) % (accum_iter * args.print_freq)) == 0:
            loss_value_reduce = misc.all_reduce_mean(loss_value)  # MUST BE EXECUTED BY ALL NODES
            if log_writer is None:
                continue
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int(epoch_f * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('train_lr', lr, epoch_1000x)
            log_writer.add_scalar('train_iter', epoch_1000x, epoch_1000x)
            for name, val in loss_details.items():
                log_writer.add_scalar('train_'+name, val, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def test_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                   data_loader: Sized, device: torch.device, epoch: int,
                   args, log_writer=None, prefix='test'):
    """Evaluate the model over one pass of `data_loader` (no gradients).

    Returns a dict {f'{metric}_avg': ..., f'{metric}_med': ...} and, when a
    tensorboard writer is given, logs every value under `prefix`.
    """
    model.eval()
    metric_logger = misc.MetricLogger(delimiter="  ")
    # huge window so median/avg are computed over the entire epoch
    metric_logger.meters = defaultdict(lambda: misc.SmoothedValue(window_size=9**9))
    header = 'Test Epoch: [{}]'.format(epoch)

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # keep per-epoch shuffling consistent with training
    if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'):
        data_loader.dataset.set_epoch(epoch)
    if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'):
        data_loader.sampler.set_epoch(epoch)

    for _, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        loss_tuple = loss_of_one_batch(batch, model, criterion, device,
                                       symmetrize_batch=True,
                                       use_amp=bool(args.amp), ret='loss')
        loss_value, loss_details = loss_tuple  # criterion returns two values
        metric_logger.update(loss=float(loss_value), **loss_details)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # expose both the mean and the median of every metric
    aggs = [('avg', 'global_avg'), ('med', 'median')]
    results = {f'{k}_{tag}': getattr(meter, attr) for k, meter in metric_logger.meters.items() for tag, attr in aggs}

    if log_writer is not None:
        for name, val in results.items():
            log_writer.add_scalar(prefix+'_'+name, val, 1000*epoch)

    return results
if __name__ == '__main__':
    # parse CLI arguments and launch training
    parser = get_args_parser()
    main(parser.parse_args())
setup.sh
CHANGED
|
@@ -4,10 +4,13 @@ set -e
|
|
| 4 |
|
| 5 |
cd "$(dirname "$0")"
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
fi
|
| 10 |
|
|
|
|
| 11 |
python -m pip install -r requirements.txt
|
| 12 |
|
| 13 |
if [[ "$(uname)" == "Linux" ]]; then
|
|
@@ -18,6 +21,7 @@ mkdir -p dust3r/checkpoints
|
|
| 18 |
|
| 19 |
WEIGHTS=dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
|
| 20 |
if [ ! -f "$WEIGHTS" ]; then
|
|
|
|
| 21 |
wget https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth \
|
| 22 |
-P dust3r/checkpoints
|
| 23 |
fi
|
|
|
|
| 4 |
|
| 5 |
cd "$(dirname "$0")"
|
| 6 |
|
| 7 |
+
echo "Initializing dust3r submodule..."
|
| 8 |
+
if [ ! -d "dust3r/.git" ]; then
|
| 9 |
+
git submodule add -f https://github.com/camenduru/dust3r.git dust3r 2>/dev/null || true
|
| 10 |
+
git submodule update --init --recursive
|
| 11 |
fi
|
| 12 |
|
| 13 |
+
echo "Installing Python dependencies..."
|
| 14 |
python -m pip install -r requirements.txt
|
| 15 |
|
| 16 |
if [[ "$(uname)" == "Linux" ]]; then
|
|
|
|
| 21 |
|
| 22 |
WEIGHTS=dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
|
| 23 |
if [ ! -f "$WEIGHTS" ]; then
|
| 24 |
+
echo "Downloading model weights..."
|
| 25 |
wget https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth \
|
| 26 |
-P dust3r/checkpoints
|
| 27 |
fi
|