sharifIslam commited on
Commit
71466b0
·
1 Parent(s): 407c655

Add dust3r source for HF (no binaries)

Browse files
Files changed (48) hide show
  1. .gitignore +6 -1
  2. app.py +5 -0
  3. dust3r +0 -1
  4. dust3r/.gitignore +132 -0
  5. dust3r/.gitmodules +3 -0
  6. dust3r/LICENSE +7 -0
  7. dust3r/NOTICE +13 -0
  8. dust3r/README.md +299 -0
  9. dust3r/datasets_preprocess/path_to_root.py +13 -0
  10. dust3r/datasets_preprocess/preprocess_co3d.py +295 -0
  11. dust3r/demo.py +283 -0
  12. dust3r/dust3r/__init__.py +2 -0
  13. dust3r/dust3r/cloud_opt/__init__.py +29 -0
  14. dust3r/dust3r/cloud_opt/base_opt.py +375 -0
  15. dust3r/dust3r/cloud_opt/commons.py +90 -0
  16. dust3r/dust3r/cloud_opt/init_im_poses.py +312 -0
  17. dust3r/dust3r/cloud_opt/optimizer.py +230 -0
  18. dust3r/dust3r/cloud_opt/pair_viewer.py +125 -0
  19. dust3r/dust3r/datasets/__init__.py +42 -0
  20. dust3r/dust3r/datasets/base/__init__.py +2 -0
  21. dust3r/dust3r/datasets/base/base_stereo_view_dataset.py +220 -0
  22. dust3r/dust3r/datasets/base/batched_sampler.py +74 -0
  23. dust3r/dust3r/datasets/base/easy_dataset.py +157 -0
  24. dust3r/dust3r/datasets/co3d.py +146 -0
  25. dust3r/dust3r/datasets/utils/__init__.py +2 -0
  26. dust3r/dust3r/datasets/utils/cropping.py +119 -0
  27. dust3r/dust3r/datasets/utils/transforms.py +11 -0
  28. dust3r/dust3r/heads/__init__.py +19 -0
  29. dust3r/dust3r/heads/dpt_head.py +115 -0
  30. dust3r/dust3r/heads/linear_head.py +41 -0
  31. dust3r/dust3r/heads/postprocess.py +58 -0
  32. dust3r/dust3r/image_pairs.py +83 -0
  33. dust3r/dust3r/inference.py +165 -0
  34. dust3r/dust3r/losses.py +297 -0
  35. dust3r/dust3r/model.py +166 -0
  36. dust3r/dust3r/optim_factory.py +14 -0
  37. dust3r/dust3r/patch_embed.py +70 -0
  38. dust3r/dust3r/post_process.py +60 -0
  39. dust3r/dust3r/utils/__init__.py +2 -0
  40. dust3r/dust3r/utils/device.py +76 -0
  41. dust3r/dust3r/utils/geometry.py +361 -0
  42. dust3r/dust3r/utils/image.py +104 -0
  43. dust3r/dust3r/utils/misc.py +121 -0
  44. dust3r/dust3r/utils/path_to_croco.py +19 -0
  45. dust3r/dust3r/viz.py +320 -0
  46. dust3r/requirements.txt +12 -0
  47. dust3r/train.py +383 -0
  48. setup.sh +6 -2
.gitignore CHANGED
@@ -16,5 +16,10 @@ results/
16
  # Gradio
17
  .gradio/
18
 
19
- # OS
 
 
 
 
 
20
  .DS_Store
 
16
  # Gradio
17
  .gradio/
18
 
19
+ dust3r/
20
+ results/
21
+ dust3r/assets/*.jpg
22
+ dust3r/croco/assets/*.png
23
+ results/object.glb
24
+
25
  .DS_Store
app.py CHANGED
@@ -1,6 +1,11 @@
1
  import os
2
  import sys
3
  import torch
 
 
 
 
 
4
 
5
  sys.path.append(os.path.join(os.path.dirname(__file__), 'dust3r'))
6
 
 
1
  import os
2
  import sys
3
  import torch
4
+ import subprocess
5
+
6
+ if not os.path.exists('dust3r/.git'):
7
+ print("Initializing dust3r submodule...")
8
+ subprocess.run(['git', 'submodule', 'update', '--init', '--recursive'], check=False)
9
 
10
  sys.path.append(os.path.join(os.path.dirname(__file__), 'dust3r'))
11
 
dust3r DELETED
@@ -1 +0,0 @@
1
- Subproject commit 78e55fd11ef6d838fc3d8c5c5a52b32eac426e09
 
 
dust3r/.gitignore ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/
2
+ checkpoints/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ pip-wheel-metadata/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
dust3r/.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "croco"]
2
+ path = croco
3
+ url = https://github.com/naver/croco
dust3r/LICENSE ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2
+
3
+ A summary of the CC BY-NC-SA 4.0 license is located here:
4
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
5
+
6
+ The CC BY-NC-SA 4.0 license is located here:
7
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
dust3r/NOTICE ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DUSt3R
2
+ Copyright 2024-present NAVER Corp.
3
+
4
+ This project contains subcomponents with separate copyright notices and license terms.
5
+ Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
6
+
7
+ ====
8
+
9
+ naver/croco
10
+ https://github.com/naver/croco/
11
+
12
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0
13
+
dust3r/README.md ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DUSt3R
2
+
3
+ Official implementation of `DUSt3R: Geometric 3D Vision Made Easy`
4
+ [[Project page](https://dust3r.europe.naverlabs.com/)], [[DUSt3R arxiv](https://arxiv.org/abs/2312.14132)]
5
+
6
+ ![Example of reconstruction from two images](assets/pipeline1.jpg)
7
+
8
+ ![High level overview of DUSt3R capabilities](assets/dust3r_archi.jpg)
9
+
10
+ ```bibtex
11
+ @misc{wang2023dust3r,
12
+ title={DUSt3R: Geometric 3D Vision Made Easy},
13
+ author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},
14
+ year={2023},
15
+ eprint={2312.14132},
16
+ archivePrefix={arXiv},
17
+ primaryClass={cs.CV}
18
+ }
19
+ ```
20
+
21
+ ## Table of Contents
22
+ - [DUSt3R](#dust3r)
23
+ - [License](#license)
24
+ - [Get Started](#get-started)
25
+ - [Installation](#installation)
26
+ - [Checkpoints](#checkpoints)
27
+ - [Interactive demo](#interactive-demo)
28
+ - [Usage](#usage)
29
+ - [Training](#training)
30
+ - [Demo](#demo)
31
+ - [Our Hyperparameters](#our-hyperparameters)
32
+
33
+ ## License
34
+ The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information.
35
+ ```python
36
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
37
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
38
+ ```
39
+
40
+ ## Get Started
41
+
42
+ ### Installation
43
+
44
+ 1. Clone DUSt3R
45
+ ```bash
46
+ git clone --recursive https://github.com/naver/dust3r
47
+ cd dust3r
48
+ # if you have already cloned dust3r:
49
+ # git submodule update --init --recursive
50
+ ```
51
+
52
+ 2. Create the environment, here we show an example using conda.
53
+ ```bash
54
+ conda create -n dust3r python=3.11 cmake=3.14.0
55
+ conda activate dust3r
56
+ conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia # use the correct version of cuda for your system
57
+ pip install -r requirements.txt
58
+ ```
59
+
60
+
61
+ 3. Optional, compile the cuda kernels for RoPE (as in CroCo v2)
62
+ ```bash
63
+ # DUST3R relies on RoPE positional embeddings for which you can compile some cuda kernels for faster runtime.
64
+ cd croco/models/curope/
65
+ python setup.py build_ext --inplace
66
+ cd ../../../
67
+ ```
68
+
69
+ 4. Download pre-trained model
70
+ ```bash
71
+ mkdir -p checkpoints/
72
+ wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints/
73
+ ```
74
+
75
+ ### Checkpoints
76
+
77
+ We provide several pre-trained models:
78
+
79
+ | Modelname | Training resolutions | Head | Encoder | Decoder |
80
+ |-------------|----------------------|------|---------|---------|
81
+ | [`DUSt3R_ViTLarge_BaseDecoder_224_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth) | 224x224 | Linear | ViT-L | ViT-B |
82
+ | [`DUSt3R_ViTLarge_BaseDecoder_512_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_linear.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | Linear | ViT-L | ViT-B |
83
+ | [`DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | DPT | ViT-L | ViT-B |
84
+
85
+ You can check the hyperparameters we used to train these models in the [section: Our Hyperparameters](#our-hyperparameters)
86
+
87
+ ### Interactive demo
88
+ In this demo, you should be able to run DUSt3R on your machine to reconstruct a scene.
89
+ First select images that depict the same scene.
90
+
91
+ You can adjust the global alignment schedule and its number of iterations.
92
+ Note: if you selected one or two images, the global alignment procedure will be skipped (mode=GlobalAlignerMode.PairViewer)
93
+ Hit "Run" and wait.
94
+ When the global alignment ends, the reconstruction appears.
95
+ Use the slider "min_conf_thr" to show or remove low confidence areas.
96
+
97
+ ```bash
98
+ python3 demo.py --weights checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
99
+
100
+ # Use --image_size to select the correct resolution for your checkpoint. 512 (default) or 224
101
+ # Use --local_network to make it accessible on the local network, or --server_name to specify the url manually
102
+ # Use --server_port to change the port, by default it will search for an available port starting at 7860
103
+ # Use --device to use a different device, by default it's "cuda"
104
+ ```
105
+
106
+ ![demo](assets/demo.jpg)
107
+
108
+ ## Usage
109
+
110
+ ```python
111
+ from dust3r.inference import inference, load_model
112
+ from dust3r.utils.image import load_images
113
+ from dust3r.image_pairs import make_pairs
114
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
115
+
116
+ if __name__ == '__main__':
117
+ model_path = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
118
+ device = 'cuda'
119
+ batch_size = 1
120
+ schedule = 'cosine'
121
+ lr = 0.01
122
+ niter = 300
123
+
124
+ model = load_model(model_path, device)
125
+ # load_images can take a list of images or a directory
126
+ images = load_images(['croco/assets/Chateau1.png', 'croco/assets/Chateau2.png'], size=512)
127
+ pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
128
+ output = inference(pairs, model, device, batch_size=batch_size)
129
+
130
+ # at this stage, you have the raw dust3r predictions
131
+ view1, pred1 = output['view1'], output['pred1']
132
+ view2, pred2 = output['view2'], output['pred2']
133
+ # here, view1, pred1, view2, pred2 are dicts of lists of len(2)
134
+ # -> because we symmetrize we have (im1, im2) and (im2, im1) pairs
135
+ # in each view you have:
136
+ # an integer image identifier: view1['idx'] and view2['idx']
137
+ # the img: view1['img'] and view2['img']
138
+ # the image shape: view1['true_shape'] and view2['true_shape']
139
+ # an instance string output by the dataloader: view1['instance'] and view2['instance']
140
+ # pred1 and pred2 contains the confidence values: pred1['conf'] and pred2['conf']
141
+ # pred1 contains 3D points for view1['img'] in view1['img'] space: pred1['pts3d']
142
+ # pred2 contains 3D points for view2['img'] in view1['img'] space: pred2['pts3d_in_other_view']
143
+
144
+ # next we'll use the global_aligner to align the predictions
145
+ # depending on your task, you may be fine with the raw output and not need it
146
+ # with only two input images, you could use GlobalAlignerMode.PairViewer: it would just convert the output
147
+ # if using GlobalAlignerMode.PairViewer, no need to run compute_global_alignment
148
+ scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
149
+ loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
150
+
151
+ # retrieve useful values from scene:
152
+ imgs = scene.imgs
153
+ focals = scene.get_focals()
154
+ poses = scene.get_im_poses()
155
+ pts3d = scene.get_pts3d()
156
+ confidence_masks = scene.get_masks()
157
+
158
+ # visualize reconstruction
159
+ scene.show()
160
+
161
+ # find 2D-2D matches between the two images
162
+ from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
163
+ pts2d_list, pts3d_list = [], []
164
+ for i in range(2):
165
+ conf_i = confidence_masks[i].cpu().numpy()
166
+ pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i]) # imgs[i].shape[:2] = (H, W)
167
+ pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
168
+ reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)
169
+ print(f'found {num_matches} matches')
170
+ matches_im1 = pts2d_list[1][reciprocal_in_P2]
171
+ matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
172
+
173
+ # visualize a few matches
174
+ import numpy as np
175
+ from matplotlib import pyplot as pl
176
+ n_viz = 10
177
+ match_idx_to_viz = np.round(np.linspace(0, num_matches-1, n_viz)).astype(int)
178
+ viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
179
+
180
+ H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
181
+ img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
182
+ img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
183
+ img = np.concatenate((img0, img1), axis=1)
184
+ pl.figure()
185
+ pl.imshow(img)
186
+ cmap = pl.get_cmap('jet')
187
+ for i in range(n_viz):
188
+ (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
189
+ pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)
190
+ pl.show(block=True)
191
+
192
+ ```
193
+ ![matching example on croco pair](assets/matching.jpg)
194
+
195
+ ## Training
196
+ In this section, we present a short demonstration to get started with training DUSt3R. At the moment, we haven't released the training datasets, so we're going to download and prepare a subset of [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE) and launch the training code on it.
197
+ The demo model will be trained for a few epochs on a very small dataset. It will not be very good.
198
+
199
+ ### Demo
200
+
201
+ ```bash
202
+
203
+ # download and prepare the co3d subset
204
+ mkdir -p data/co3d_subset
205
+ cd data/co3d_subset
206
+ git clone https://github.com/facebookresearch/co3d
207
+ cd co3d
208
+ python3 ./co3d/download_dataset.py --download_folder ../ --single_sequence_subset
209
+ rm ../*.zip
210
+ cd ../../..
211
+
212
+ python3 datasets_preprocess/preprocess_co3d.py --co3d_dir data/co3d_subset --output_dir data/co3d_subset_processed --single_sequence_subset
213
+
214
+ # download the pretrained croco v2 checkpoint
215
+ mkdir -p checkpoints/
216
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth -P checkpoints/
217
+
218
+ # the training of dust3r is done in 3 steps.
219
+ # for this example we'll do fewer epochs, for the actual hyperparameters we used in the paper, see the next section: "Our Hyperparameters"
220
+ # step 1 - train dust3r for 224 resolution
221
+ torchrun --nproc_per_node=4 train.py \
222
+ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter)" \
223
+ --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=224, seed=777)" \
224
+ --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
225
+ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
226
+ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
227
+ --pretrained checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth \
228
+ --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 16 --accum_iter 1 \
229
+ --save_freq 1 --keep_freq 5 --eval_freq 1 \
230
+ --output_dir checkpoints/dust3r_demo_224
231
+
232
+ # step 2 - train dust3r for 512 resolution
233
+ torchrun --nproc_per_node=4 train.py \
234
+ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
235
+ --test_dataset="100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
236
+ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
237
+ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
238
+ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
239
+ --pretrained='checkpoints/dust3r_demo_224/checkpoint-best.pth' \
240
+ --lr=0.0001 --min_lr=1e-06 --warmup_epochs 1 --epochs 10 --batch_size 4 --accum_iter 4 \
241
+ --save_freq 1 --keep_freq 5 --eval_freq 1 \
242
+ --output_dir checkpoints/dust3r_demo_512
243
+
244
+ # step 3 - train dust3r for 512 resolution with dpt
245
+ torchrun --nproc_per_node=4 train.py \
246
+ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
247
+ --test_dataset="100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
248
+ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
249
+ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
250
+ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
251
+ --pretrained='checkpoints/dust3r_demo_512/checkpoint-best.pth' \
252
+ --lr=0.0001 --min_lr=1e-06 --warmup_epochs 1 --epochs 10 --batch_size 2 --accum_iter 8 \
253
+ --save_freq 1 --keep_freq 5 --eval_freq 1 \
254
+ --output_dir checkpoints/dust3r_demo_512dpt
255
+
256
+ ```
257
+
258
+ ### Our Hyperparameters
259
+ We didn't release the training datasets, but here are the commands we used for training our models:
260
+
261
+ ```bash
262
+ # NOTE: ROOT path omitted for datasets
263
+ # 224 linear
264
+ torchrun --nproc_per_node 4 train.py \
265
+ --train_dataset=" + 100_000 @ Habitat512(1_000_000, split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ BlendedMVS(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ MegaDepthDense(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ ARKitScenes(aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ Co3d_v3(split='train', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ ScanNetpp(split='train', aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ Waymo(aug_crop=128, resolution=224, transform=ColorJitter) " \
266
+ --test_dataset=" Habitat512(1_000, split='val', resolution=224, seed=777) + 1_000 @ BlendedMVS(split='val', resolution=224, seed=777) + 1_000 @ MegaDepthDense(split='val', resolution=224, seed=777) + 1_000 @ Co3d_v3(split='test', mask_bg='rand', resolution=224, seed=777) " \
267
+ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
268
+ --test_criterion='Regr3D_ScaleShiftInv(L21, gt_scale=True)' \
269
+ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
270
+ --pretrained="checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \
271
+ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=10 --epochs=100 --batch_size=16 --accum_iter=1 \
272
+ --save_freq=5 --keep_freq=10 --eval_freq=1 \
273
+ --output_dir='checkpoints/dust3r_224'
274
+
275
+ # 512 linear
276
+ torchrun --nproc_per_node 8 train.py \
277
+ --train_dataset=" + 10_000 @ Habitat512(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepthDense(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d_v3(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Waymo(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
278
+ --test_dataset=" Habitat512(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepthDense(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d_v3(split='test', resolution=(512,384), seed=777) " \
279
+ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
280
+ --test_criterion='Regr3D_ScaleShiftInv(L21, gt_scale=True)' \
281
+ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
282
+ --pretrained='checkpoints/dust3r_224/checkpoint-best.pth' \
283
+ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=20 --epochs=200 --batch_size=4 --accum_iter=2 \
284
+ --save_freq=10 --keep_freq=10 --eval_freq=1 --print_freq=10 \
285
+ --output_dir='checkpoints/dust3r_512'
286
+
287
+ # 512 dpt
288
+ torchrun --nproc_per_node 8 train.py \
289
+ --train_dataset=" + 10_000 @ Habitat512(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepthDense(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d_v3(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Waymo(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
290
+ --test_dataset=" Habitat512(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepthDense(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d_v3(split='test', resolution=(512,384), seed=777) " \
291
+ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
292
+ --test_criterion='Regr3D_ScaleShiftInv(L21, gt_scale=True)' \
293
+ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
294
+ --pretrained='checkpoints/dust3r_512/checkpoint-best.pth' \
295
+ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=15 --epochs=90 --batch_size=2 --accum_iter=4 \
296
+ --save_freq=5 --keep_freq=10 --eval_freq=1 --print_freq=10 \
297
+ --output_dir='checkpoints/dust3r_512dpt'
298
+
299
+ ```
dust3r/datasets_preprocess/path_to_root.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # DUSt3R repo root import
6
+ # --------------------------------------------------------
7
+
8
+ import sys
9
+ import os.path as path
10
+ HERE_PATH = path.normpath(path.dirname(__file__))
11
+ DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../'))
12
+ # workaround for sibling import
13
+ sys.path.insert(0, DUST3R_REPO_PATH)
dust3r/datasets_preprocess/preprocess_co3d.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
3
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ #
5
+ # --------------------------------------------------------
6
+ # Script to pre-process the CO3D dataset.
7
+ # Usage:
8
+ # python3 datasets_preprocess/preprocess_co3d.py --co3d_dir /path/to/co3d
9
+ # --------------------------------------------------------
10
+
11
+ import argparse
12
+ import random
13
+ import gzip
14
+ import json
15
+ import os
16
+ import os.path as osp
17
+
18
+ import torch
19
+ import PIL.Image
20
+ import numpy as np
21
+ import cv2
22
+
23
+ from tqdm.auto import tqdm
24
+ import matplotlib.pyplot as plt
25
+
26
+ import path_to_root # noqa
27
+ import dust3r.datasets.utils.cropping as cropping # noqa
28
+
29
+
30
# the 51 CO3D object categories processed by this script
CATEGORIES = [
    "apple", "backpack", "ball", "banana", "baseballbat", "baseballglove",
    "bench", "bicycle", "book", "bottle", "bowl", "broccoli", "cake", "car", "carrot",
    "cellphone", "chair", "couch", "cup", "donut", "frisbee", "hairdryer", "handbag",
    "hotdog", "hydrant", "keyboard", "kite", "laptop", "microwave",
    "motorcycle",
    "mouse", "orange", "parkingmeter", "pizza", "plant", "remote", "sandwich",
    "skateboard", "stopsign",
    "suitcase", "teddybear", "toaster", "toilet", "toybus",
    "toyplane", "toytrain", "toytruck", "tv",
    "umbrella", "vase", "wineglass",
]
CATEGORIES_IDX = {cat: i for i, cat in enumerate(CATEGORIES)}  # for seeding

# categories used for the single-sequence subset; three categories are excluded
# NOTE(review): presumably because they lack the manyview_dev split — confirm against CO3D
SINGLE_SEQUENCE_CATEGORIES = sorted(set(CATEGORIES) - set(["microwave", "stopsign", "tv"]))
45
+
46
+
47
def get_parser():
    """Build the command-line parser for the CO3D pre-processing script.

    Returns:
        argparse.ArgumentParser: parser exposing dataset locations, sequence
        sampling options and the output resolution.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--category", type=str, default=None)
    p.add_argument('--single_sequence_subset', default=False, action='store_true',
                   help="prepare the single_sequence_subset instead.")
    p.add_argument("--output_dir", type=str, default="data/co3d_processed")
    p.add_argument("--co3d_dir", type=str, required=True)
    p.add_argument("--num_sequences_per_object", type=int, default=50)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--min_quality", type=float, default=0.5, help="Minimum viewpoint quality score.")
    p.add_argument("--img_size", type=int, default=512,
                   help=("lower dimension will be >= img_size * 3/4, and max dimension will be >= img_size"))
    return p
61
+
62
+
63
def convert_ndc_to_pinhole(focal_length, principal_point, image_size):
    """Convert NDC-style camera parameters to a pinhole intrinsics matrix.

    Args:
        focal_length: (fx, fy) in NDC units.
        principal_point: (px, py) in NDC units.
        image_size: (H, W) of the image in pixels.

    Returns:
        np.ndarray: 3x3 float32 intrinsics matrix K.
    """
    f_ndc = np.array(focal_length)
    pp_ndc = np.array(principal_point)
    # image_size is (H, W); pixel coordinates are handled in (W, H) order
    half_wh = np.array([image_size[1], image_size[0]]) / 2
    # NDC units are scaled by half the smaller image dimension
    rescale = half_wh.min()
    pp_px = half_wh - pp_ndc * rescale
    f_px = f_ndc * rescale
    K = np.array([[f_px[0], 0.0, pp_px[0]],
                  [0.0, f_px[1], pp_px[1]],
                  [0.0, 0.0, 1.0]], dtype=np.float32)
    return K
75
+
76
+
77
def opencv_from_cameras_projection(R, T, focal, p0, image_size):
    """Convert a PyTorch3D-convention camera to OpenCV convention.

    Args:
        R: (3, 3) rotation matrix (numpy).
        T: (3,) translation (numpy).
        focal: (2,) NDC focal length (numpy).
        p0: (2,) NDC principal point (numpy).
        image_size: (2,) image size as (H, W) (numpy).

    Returns:
        tuple of torch tensors: (R, tvec, camera_matrix) in OpenCV convention,
        with camera_matrix a 3x3 pixel-space intrinsics matrix.
    """
    # add a leading batch dimension of 1
    R = torch.from_numpy(R)[None, :, :]
    T = torch.from_numpy(T)[None, :]
    focal = torch.from_numpy(focal)[None, :]
    p0 = torch.from_numpy(p0)[None, :]
    image_size = torch.from_numpy(image_size)[None, :]

    # flip the x/y axis sign conventions (PyTorch3D -> OpenCV)
    R_cv = R.clone()
    t_cv = T.clone()
    t_cv[:, :2] *= -1
    R_cv[:, :, :2] *= -1
    R_cv = R_cv.permute(0, 2, 1)

    # retype correctly and flip (H, W) -> (W, H)
    image_size_wh = image_size.to(R_cv).flip(dims=(1,))

    # NDC to screen conversion: NDC units scale with half the smaller dimension
    scale = image_size_wh.min(dim=1, keepdim=True)[0] / 2.0
    scale = scale.expand(-1, 2)
    c0 = image_size_wh / 2.0

    principal_point = -p0 * scale + c0
    focal_length = focal * scale

    camera_matrix = torch.zeros_like(R_cv)
    camera_matrix[:, 0, 0] = focal_length[:, 0]
    camera_matrix[:, 1, 1] = focal_length[:, 1]
    camera_matrix[:, :2, 2] = principal_point
    camera_matrix[:, 2, 2] = 1.0
    # strip the batch dimension again
    return R_cv[0], t_cv[0], camera_matrix[0]
110
+
111
+
112
def get_set_list(category_dir, split, is_single_sequence_subset=False):
    """Collect the (sequence_name, frame_number, filepath) entries of a split.

    Reads every matching json file under `category_dir`/set_lists and
    concatenates the entries listed for `split`.
    """
    set_lists_dir = osp.join(category_dir, "set_lists")
    listfiles = os.listdir(set_lists_dir)
    # not all objects have manyview_dev
    wanted = "manyview_dev" if is_single_sequence_subset else "fewview_train"
    subset_list_files = [fname for fname in listfiles if wanted in fname]

    sequences_all = []
    for fname in subset_list_files:
        with open(osp.join(set_lists_dir, fname)) as f:
            sequences_all.extend(json.load(f)[split])
    return sequences_all
127
+
128
+
129
def prepare_sequences(category, co3d_dir, output_dir, img_size, split, min_quality, max_num_sequences_per_object,
                      seed, is_single_sequence_subset=False):
    """Pre-process all selected sequences of one CO3D category.

    For each frame of the sampled sequences: crop the image so it is centered
    on the principal point, rescale it towards `img_size`, and write the
    cropped RGB / depth / mask images plus a .npz with the adjusted
    intrinsics, camera pose and maximum depth.

    Args:
        category: CO3D category name (e.g. "apple").
        co3d_dir: root directory of the raw CO3D dataset.
        output_dir: root directory where processed data is written.
        img_size: target resolution (see get_parser's --img_size help).
        split: 'train' or 'test'.
        min_quality: minimum viewpoint_quality_score for a sequence to be kept.
        max_num_sequences_per_object: cap on sequences sampled per category.
        seed: RNG seed used for the sequence subsampling.
        is_single_sequence_subset: use the manyview_dev subset instead.

    Returns:
        dict: selected sequence name -> list of processed frame indices.
    """
    random.seed(seed)
    category_dir = osp.join(co3d_dir, category)
    category_output_dir = osp.join(output_dir, category)  # NOTE(review): unused below
    sequences_all = get_set_list(category_dir, split, is_single_sequence_subset)
    sequences_numbers = sorted(set(seq_name for seq_name, _, _ in sequences_all))

    frame_file = osp.join(category_dir, "frame_annotations.jgz")
    sequence_file = osp.join(category_dir, "sequence_annotations.jgz")

    with gzip.open(frame_file, "r") as fin:
        frame_data = json.loads(fin.read())
    with gzip.open(sequence_file, "r") as fin:
        sequence_data = json.loads(fin.read())

    # index frame annotations by sequence name, then frame number
    frame_data_processed = {}
    for f_data in frame_data:
        sequence_name = f_data["sequence_name"]
        frame_data_processed.setdefault(sequence_name, {})[f_data["frame_number"]] = f_data

    # keep only sequences whose viewpoint quality exceeds the threshold
    good_quality_sequences = set()
    for seq_data in sequence_data:
        if seq_data["viewpoint_quality_score"] > min_quality:
            good_quality_sequences.add(seq_data["sequence_name"])

    sequences_numbers = [seq_name for seq_name in sequences_numbers if seq_name in good_quality_sequences]
    # deterministic subsampling (random.seed above)
    if len(sequences_numbers) < max_num_sequences_per_object:
        selected_sequences_numbers = sequences_numbers
    else:
        selected_sequences_numbers = random.sample(sequences_numbers, max_num_sequences_per_object)

    selected_sequences_numbers_dict = {seq_name: [] for seq_name in selected_sequences_numbers}
    sequences_all = [(seq_name, frame_number, filepath)
                     for seq_name, frame_number, filepath in sequences_all
                     if seq_name in selected_sequences_numbers_dict]

    for seq_name, frame_number, filepath in tqdm(sequences_all):
        # the frame index is encoded in the filename: .../frameXXXXXX.jpg
        frame_idx = int(filepath.split('/')[-1][5:-4])
        selected_sequences_numbers_dict[seq_name].append(frame_idx)
        mask_path = filepath.replace("images", "masks").replace(".jpg", ".png")
        frame_data = frame_data_processed[seq_name][frame_number]
        focal_length = frame_data["viewpoint"]["focal_length"]
        principal_point = frame_data["viewpoint"]["principal_point"]
        image_size = frame_data["image"]["size"]
        # NOTE(review): K is computed but unused; the intrinsics actually used
        # come from opencv_from_cameras_projection below
        K = convert_ndc_to_pinhole(focal_length, principal_point, image_size)
        R, tvec, camera_intrinsics = opencv_from_cameras_projection(np.array(frame_data["viewpoint"]["R"]),
                                                                    np.array(frame_data["viewpoint"]["T"]),
                                                                    np.array(focal_length),
                                                                    np.array(principal_point),
                                                                    np.array(image_size))

        frame_data = frame_data_processed[seq_name][frame_number]
        depth_path = os.path.join(co3d_dir, frame_data["depth"]["path"])
        assert frame_data["depth"]["scale_adjustment"] == 1.0
        image_path = os.path.join(co3d_dir, filepath)
        mask_path_full = os.path.join(co3d_dir, mask_path)

        input_rgb_image = PIL.Image.open(image_path).convert('RGB')
        input_mask = plt.imread(mask_path_full)

        with PIL.Image.open(depth_path) as depth_pil:
            # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
            # we cast it to uint16, then reinterpret as float16, then cast to float32
            input_depthmap = (
                np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
                .astype(np.float32)
                .reshape((depth_pil.size[1], depth_pil.size[0])))
        # stack depth and mask so both go through the same crop/rescale
        depth_mask = np.stack((input_depthmap, input_mask), axis=-1)
        H, W = input_depthmap.shape

        camera_intrinsics = camera_intrinsics.numpy()
        cx, cy = camera_intrinsics[:2, 2].round().astype(int)
        min_margin_x = min(cx, W-cx)
        min_margin_y = min(cy, H-cy)

        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)
        input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap(
            input_rgb_image, depth_mask, camera_intrinsics, crop_bbox)

        # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384
        scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8
        output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)
        if max(output_resolution) < img_size:
            # let's put the max dimension to img_size
            scale_final = (img_size / max(H, W)) + 1e-8
            output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)

        input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap(
            input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution)
        input_depthmap = depth_mask[:, :, 0]
        input_mask = depth_mask[:, :, 1]

        # generate and adjust camera pose: build world-to-cam, invert to cam-to-world
        camera_pose = np.eye(4, dtype=np.float32)
        camera_pose[:3, :3] = R
        camera_pose[:3, 3] = tvec
        camera_pose = np.linalg.inv(camera_pose)

        # save crop images and depth, metadata
        save_img_path = os.path.join(output_dir, filepath)
        save_depth_path = os.path.join(output_dir, frame_data["depth"]["path"])
        save_mask_path = os.path.join(output_dir, mask_path)
        os.makedirs(os.path.split(save_img_path)[0], exist_ok=True)
        os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True)
        os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True)

        input_rgb_image.save(save_img_path)
        # depth is normalized to uint16; the true maximum is stored in the .npz
        scaled_depth_map = (input_depthmap / np.max(input_depthmap) * 65535).astype(np.uint16)
        cv2.imwrite(save_depth_path, scaled_depth_map)
        cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8))

        save_meta_path = save_img_path.replace('jpg', 'npz')
        np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics,
                 camera_pose=camera_pose, maximum_depth=np.max(input_depthmap))

    return selected_sequences_numbers_dict
249
+
250
+
251
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    # never write the processed data over the raw dataset
    assert args.co3d_dir != args.output_dir
    if args.category is None:
        if args.single_sequence_subset:
            categories = SINGLE_SEQUENCE_CATEGORIES
        else:
            categories = CATEGORIES
    else:
        categories = [args.category]
    os.makedirs(args.output_dir, exist_ok=True)

    for split in ['train', 'test']:
        selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json')
        # the global selection file doubles as a completion marker for the split
        if os.path.isfile(selected_sequences_path):
            continue

        all_selected_sequences = {}
        for category in categories:
            category_output_dir = osp.join(args.output_dir, category)
            os.makedirs(category_output_dir, exist_ok=True)
            category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json')
            # resume support: skip categories already processed in a prior run
            if os.path.isfile(category_selected_sequences_path):
                with open(category_selected_sequences_path, 'r') as fid:
                    category_selected_sequences = json.load(fid)
            else:
                print(f"Processing {split} - category = {category}")
                # seed is offset per category so each category samples
                # a different (but reproducible) subset of sequences
                category_selected_sequences = prepare_sequences(
                    category=category,
                    co3d_dir=args.co3d_dir,
                    output_dir=args.output_dir,
                    img_size=args.img_size,
                    split=split,
                    min_quality=args.min_quality,
                    max_num_sequences_per_object=args.num_sequences_per_object,
                    seed=args.seed + CATEGORIES_IDX[category],
                    is_single_sequence_subset=args.single_sequence_subset
                )
                with open(category_selected_sequences_path, 'w') as file:
                    json.dump(category_selected_sequences, file)

            all_selected_sequences[category] = category_selected_sequences
        with open(selected_sequences_path, 'w') as file:
            json.dump(all_selected_sequences, file)
dust3r/demo.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
3
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ #
5
+ # --------------------------------------------------------
6
+ # gradio demo
7
+ # --------------------------------------------------------
8
+ import argparse
9
+ import gradio
10
+ import os
11
+ import torch
12
+ import numpy as np
13
+ import tempfile
14
+ import functools
15
+ import trimesh
16
+ import copy
17
+ from scipy.spatial.transform import Rotation
18
+
19
+ from dust3r.inference import inference, load_model
20
+ from dust3r.image_pairs import make_pairs
21
+ from dust3r.utils.image import load_images, rgb
22
+ from dust3r.utils.device import to_numpy
23
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
24
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
25
+
26
+ import matplotlib.pyplot as pl
27
+ pl.ion()
28
+
29
torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
# batch size used for dust3r pairwise inference (see get_reconstructed_scene)
batch_size = 1
31
+
32
+
33
def get_args_parser():
    """Build the CLI parser for the gradio demo.

    --local_network and --server_name are mutually exclusive ways to choose
    the address the app binds to.
    """
    parser = argparse.ArgumentParser()
    url_group = parser.add_mutually_exclusive_group()
    url_group.add_argument("--local_network", action='store_true', default=False,
                           help="make app accessible on local network: address will be set to 0.0.0.0")
    url_group.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1")
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    parser.add_argument("--server_port", type=int,
                        help=("will start gradio app on this port (if available). "
                              "If None, will search for an available port starting at 7860."),
                        default=None)
    parser.add_argument("--weights", type=str, required=True, help="path to the model weights")
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
    return parser
47
+
48
+
49
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False, transparent_cams=False):
    """Export the reconstruction (points or meshes + camera frusta) to `outdir`/scene.glb.

    Args:
        imgs: per-view RGB images; indexable by the per-view masks.
        pts3d: per-view 3D points with the same spatial layout as imgs.
        mask: per-view validity masks used to select points/colors.
        focals: per-view focal lengths.
        cams2world: per-view 4x4 camera-to-world poses.
        cam_size: screen size of the camera frusta in the exported scene.
        cam_color: single color, or a per-camera list; defaults to cycling CAM_COLORS.
        as_pointcloud: export a single point cloud instead of per-view meshes.
        transparent_cams: if True, camera frusta are not textured with the images.

    Returns:
        str: path of the written .glb file.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    focals = to_numpy(focals)
    cams2world = to_numpy(cams2world)

    scene = trimesh.Scene()

    # full pointcloud
    if as_pointcloud:
        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
        pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
        scene.add_geometry(pct)
    else:
        # one textured mesh per view, merged into a single Trimesh
        meshes = []
        for i in range(len(imgs)):
            meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
        mesh = trimesh.Trimesh(**cat_meshes(meshes))
        scene.add_geometry(mesh)

    # add each camera
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else imgs[i], focals[i],
                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)

    # re-express the scene relative to the first camera, with an OpenGL-style
    # 180-degree rotation around y applied for the viewer
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
    outfile = os.path.join(outdir, 'scene.glb')
    print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)
    return outfile
89
+
90
+
91
def get_3D_model_from_scene(outdir, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05):
    """
    extract 3D_model (glb file) from a reconstructed scene
    """
    if scene is None:
        return None

    # optional post-processing of the reconstruction
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # gather the optimized quantities from the scene
    rgbimg = scene.imgs
    focals = scene.get_focals().cpu()
    cams2world = scene.get_im_poses().cpu()

    # 3D pointcloud from depthmap, poses and intrinsics
    pts3d = to_numpy(scene.get_pts3d())
    # the threshold is expressed in the scene's confidence space via conf_trf
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    msk = to_numpy(scene.get_masks())

    return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world,
                                        as_pointcloud=as_pointcloud,
                                        transparent_cams=transparent_cams, cam_size=cam_size)
114
+
115
+
116
def get_reconstructed_scene(outdir, model, device, image_size, filelist, schedule, niter, min_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                            scenegraph_type, winsize, refid):
    """
    from a list of images, run dust3r inference, global aligner.
    then run get_3D_model_from_scene

    Returns (scene, glb_path, gallery_images) where gallery_images interleaves
    rgb / depth / confidence renderings per view.
    """
    imgs = load_images(filelist, size=image_size)
    # dust3r works on image pairs: duplicate a single input image
    if len(imgs) == 1:
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
    # encode the window size / reference id into the scene-graph spec string
    if scenegraph_type == "swin":
        scenegraph_type = scenegraph_type + "-" + str(winsize)
    elif scenegraph_type == "oneref":
        scenegraph_type = scenegraph_type + "-" + str(refid)

    pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=batch_size)

    # with exactly two views there is nothing to globally optimize: use PairViewer
    mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
    scene = global_aligner(output, device=device, mode=mode)
    lr = 0.01

    if mode == GlobalAlignerMode.PointCloudOptimizer:
        loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)

    outfile = get_3D_model_from_scene(outdir, scene, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size)

    # also return rgb, depth and confidence imgs
    # depth is normalized with the max value for all images
    # we apply the jet colormap on the confidence maps
    rgbimg = scene.imgs
    depths = to_numpy(scene.get_depthmaps())
    confs = to_numpy([c for c in scene.im_conf])
    cmap = pl.get_cmap('jet')
    depths_max = max([d.max() for d in depths])
    depths = [d/depths_max for d in depths]
    confs_max = max([d.max() for d in confs])
    confs = [cmap(d/confs_max) for d in confs]

    # interleave rgb / depth / confidence per view for the gradio gallery
    imgs = []
    for i in range(len(rgbimg)):
        imgs.append(rgbimg[i])
        imgs.append(rgb(depths[i]))
        imgs.append(rgb(confs[i]))

    return scene, outfile, imgs
164
+
165
+
166
def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
    """Refresh the two scene-graph sliders when the inputs or graph type change.

    The window-size slider is only shown for 'swin' graphs and the reference-id
    slider only for 'oneref' graphs; both ranges are derived from the number of
    uploaded files.
    """
    num_files = len(inputfiles) if inputfiles is not None else 1
    max_winsize = max(1, (num_files - 1)//2)
    show_winsize = scenegraph_type == "swin"
    show_refid = scenegraph_type == "oneref"
    winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
                            minimum=1, maximum=max_winsize, step=1, visible=show_winsize)
    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=num_files-1, step=1, visible=show_refid)
    return winsize, refid
185
+
186
+
187
def main_demo(tmpdirname, model, device, image_size, server_name, server_port):
    """Build the gradio UI and launch it (blocking).

    Args:
        tmpdirname: directory where exported .glb scenes are written.
        model: loaded DUSt3R model.
        device: pytorch device string.
        image_size: inference resolution (512 or 224).
        server_name, server_port: forwarded to demo.launch().
    """
    # bind the heavy arguments once; gradio callbacks then only pass UI values
    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, image_size)
    model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname)
    with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="DUSt3R Demo") as demo:
        # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
        scene = gradio.State(None)
        gradio.HTML('<h2 style="text-align: center;">DUSt3R Demo</h2>')
        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")
            with gradio.Row():
                schedule = gradio.Dropdown(["linear", "cosine"],
                                           value='linear', label="schedule", info="For global alignment!")
                niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
                                      label="num_iterations", info="For global alignment!")
                scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
                                                  value='complete', label="Scenegraph",
                                                  info="Define how to make pairs",
                                                  interactive=True)
                # hidden by default; set_scenegraph_options toggles their visibility
                winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                        minimum=1, maximum=1, step=1, visible=False)
                refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)

            run_btn = gradio.Button("Run")

            with gradio.Row():
                # adjust the confidence threshold
                min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
                # adjust the camera size in the output pointcloud
                cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
            with gradio.Row():
                as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
                # two post process implemented
                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()
            outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%")

            # events: keep the scene-graph sliders in sync with the inputs
            scenegraph_type.change(set_scenegraph_options,
                                   inputs=[inputfiles, winsize, refid, scenegraph_type],
                                   outputs=[winsize, refid])
            inputfiles.change(set_scenegraph_options,
                              inputs=[inputfiles, winsize, refid, scenegraph_type],
                              outputs=[winsize, refid])
            # full pipeline (inference + alignment) only runs on the Run button
            run_btn.click(fn=recon_fun,
                          inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
                                  mask_sky, clean_depth, transparent_cams, cam_size,
                                  scenegraph_type, winsize, refid],
                          outputs=[scene, outmodel, outgallery])
            # the remaining controls only re-export the stored scene (no new inference)
            min_conf_thr.release(fn=model_from_scene_fun,
                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                         clean_depth, transparent_cams, cam_size],
                                 outputs=outmodel)
            cam_size.change(fn=model_from_scene_fun,
                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                    clean_depth, transparent_cams, cam_size],
                            outputs=outmodel)
            as_pointcloud.change(fn=model_from_scene_fun,
                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                         clean_depth, transparent_cams, cam_size],
                                 outputs=outmodel)
            mask_sky.change(fn=model_from_scene_fun,
                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                    clean_depth, transparent_cams, cam_size],
                            outputs=outmodel)
            clean_depth.change(fn=model_from_scene_fun,
                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                       clean_depth, transparent_cams, cam_size],
                               outputs=outmodel)
            transparent_cams.change(model_from_scene_fun,
                                    inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                            clean_depth, transparent_cams, cam_size],
                                    outputs=outmodel)
    demo.launch(share=False, server_name=server_name, server_port=server_port)
263
+
264
+
265
if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()

    # optionally redirect all temporary files (including exported scenes)
    if args.tmp_dir is not None:
        tmp_path = args.tmp_dir
        os.makedirs(tmp_path, exist_ok=True)
        tempfile.tempdir = tmp_path

    if args.server_name is not None:
        server_name = args.server_name
    else:
        # --local_network exposes the app on all interfaces
        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'

    model = load_model(args.weights, args.device)
    # dust3r will write the 3D model inside tmpdirname
    with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
        print('Outputing stuff in', tmpdirname)
        main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port)
dust3r/dust3r/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/dust3r/cloud_opt/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # global alignment optimization wrapper function
6
+ # --------------------------------------------------------
7
+ from enum import Enum
8
+
9
+ from .optimizer import PointCloudOptimizer
10
+ from .pair_viewer import PairViewer
11
+
12
+
13
class GlobalAlignerMode(Enum):
    # gradient-based global optimization over all pairwise predictions
    PointCloudOptimizer = "PointCloudOptimizer"
    # lightweight backend for exactly two views (see PairViewer)
    PairViewer = "PairViewer"
16
+
17
+
18
def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw):
    """Build the global-alignment network matching `mode` from raw dust3r output.

    Args:
        dust3r_output: dict with keys 'view1', 'view2', 'pred1', 'pred2'.
        device: device the optimizer is moved to.
        mode: which GlobalAlignerMode backend to instantiate.
        **optim_kw: forwarded to the backend constructor.
    """
    # extract all inputs
    views = [dust3r_output[k] for k in ('view1', 'view2', 'pred1', 'pred2')]
    # pick the backend class
    if mode == GlobalAlignerMode.PointCloudOptimizer:
        net_cls = PointCloudOptimizer
    elif mode == GlobalAlignerMode.PairViewer:
        net_cls = PairViewer
    else:
        raise NotImplementedError(f'Unknown mode {mode}')
    return net_cls(*views, **optim_kw).to(device)
dust3r/dust3r/cloud_opt/base_opt.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Base class for the global alignement procedure
6
+ # --------------------------------------------------------
7
+ from copy import deepcopy
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import roma
13
+ from copy import deepcopy
14
+ import tqdm
15
+
16
+ from dust3r.utils.geometry import inv, geotrf
17
+ from dust3r.utils.device import to_numpy
18
+ from dust3r.utils.image import rgb
19
+ from dust3r.viz import SceneViz, segment_sky, auto_cam_size
20
+ from dust3r.optim_factory import adjust_learning_rate_by_lr
21
+
22
+ from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
23
+ cosine_schedule, linear_schedule, get_conf_trf)
24
+ import dust3r.cloud_opt.init_im_poses as init_fun
25
+
26
+
27
+ class BasePCOptimizer (nn.Module):
28
+ """ Optimize a global scene, given a list of pairwise observations.
29
+ Graph node: images
30
+ Graph edges: observations = (pred1, pred2)
31
+ """
32
+
33
    def __init__(self, *args, **kwargs):
        """Either copy-construct from another optimizer (single positional
        argument, no kwargs) or build from raw views/predictions through
        _init_from_views.
        """
        if len(args) == 1 and len(kwargs) == 0:
            other = deepcopy(args[0])
            # NOTE(review): 'pw_adaptors' is listed twice — harmless, the dict
            # comprehension below de-duplicates keys
            attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
                       min_conf_thr conf_thr conf_i conf_j im_conf
                       base_scale norm_pw_scale POSE_DIM pw_poses
                       pw_adaptors pw_adaptors has_im_poses rand_pose imgs'''.split()
            # assumes `other` supports item access (other[k]) for these
            # attributes — TODO confirm against the subclass implementations
            self.__dict__.update({k: other[k] for k in attrs})
        else:
            self._init_from_views(*args, **kwargs)
43
+
44
    def _init_from_views(self, view1, view2, pred1, pred2,
                         dist='l1',
                         conf='log',
                         min_conf_thr=3,
                         base_scale=0.5,
                         allow_pw_adaptors=False,
                         pw_break=20,
                         rand_pose=torch.randn,
                         iterationsCount=None):
        """Build the optimization graph from batched pairwise predictions.

        Args:
            view1, view2: dicts of batched input views; 'idx' holds each view's
                image index, 'img' (optional) its RGB tensor.
            pred1: dict with 'pts3d' and 'conf' for the first view of each pair.
            pred2: dict with 'pts3d_in_other_view' and 'conf' for the second view.
            dist: key into ALL_DISTS selecting the pairwise distance function.
            conf: name of the confidence transform (see get_conf_trf).
            min_conf_thr: confidence threshold kept for later masking.
            base_scale: base of the pairwise scale parametrization.
            allow_pw_adaptors: whether the per-pair xy/z adaptors are optimized.
            pw_break: scale-break constant for pairwise poses.
            rand_pose: callable used to initialize pose parameters.
            iterationsCount: unused here; accepted for interface compatibility.
        """
        super().__init__()
        if not isinstance(view1['idx'], list):
            view1['idx'] = view1['idx'].tolist()
        if not isinstance(view2['idx'], list):
            view2['idx'] = view2['idx'].tolist()
        # one graph edge (i, j) per predicted image pair
        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
        # symmetrized means every (i, j) also appears as (j, i)
        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
        self.dist = ALL_DISTS[dist]

        self.n_imgs = self._check_edges()

        # input data
        pred1_pts = pred1['pts3d']
        pred2_pts = pred2['pts3d_in_other_view']
        # frozen per-edge point predictions, keyed by "i_j" edge strings
        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)

        # work in log-scale with conf
        pred1_conf = pred1['conf']
        pred2_conf = pred2['conf']
        self.min_conf_thr = min_conf_thr
        self.conf_trf = get_conf_trf(conf)

        self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
        self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)

        # pairwise pose parameters
        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7
        # one (1 + POSE_DIM)-vector per edge — presumably scale + pose; confirm layout
        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))  # pairwise poses
        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # slight xy/z adaptation
        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
        self.has_im_poses = False
        self.rand_pose = rand_pose

        # possibly store images for show_pointcloud
        self.imgs = None
        if 'img' in view1 and 'img' in view2:
            imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
            # each image may appear in several edges; later writes are identical
            for v in range(len(self.edges)):
                idx = view1['idx'][v]
                imgs[idx] = view1['img'][v]
                idx = view2['idx'][v]
                imgs[idx] = view2['img'][v]
            self.imgs = rgb(imgs)
102
+
103
+ @property
104
+ def n_edges(self):
105
+ return len(self.edges)
106
+
107
+ @property
108
+ def str_edges(self):
109
+ return [edge_str(i, j) for i, j in self.edges]
110
+
111
+ @property
112
+ def imsizes(self):
113
+ return [(w, h) for h, w in self.imshapes]
114
+
115
+ @property
116
+ def device(self):
117
+ return next(iter(self.parameters())).device
118
+
119
+ def state_dict(self, trainable=True):
120
+ all_params = super().state_dict()
121
+ return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}
122
+
123
+ def load_state_dict(self, data):
124
+ return super().load_state_dict(self.state_dict(trainable=False) | data)
125
+
126
+ def _check_edges(self):
127
+ indices = sorted({i for edge in self.edges for i in edge})
128
+ assert indices == list(range(len(indices))), 'bad pair indices: missing values '
129
+ return len(indices)
130
+
131
+ @torch.no_grad()
132
+ def _compute_img_conf(self, pred1_conf, pred2_conf):
133
+ im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
134
+ for e, (i, j) in enumerate(self.edges):
135
+ im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
136
+ im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
137
+ return im_conf
138
+
139
+ def get_adaptors(self):
140
+ adapt = self.pw_adaptors
141
+ adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1) # (scale_xy, scale_xy, scale_z)
142
+ if self.norm_pw_scale: # normalize so that the product == 1
143
+ adapt = adapt - adapt.mean(dim=1, keepdim=True)
144
+ return (adapt / self.pw_break).exp()
145
+
146
+ def _get_poses(self, poses):
147
+ # normalize rotation
148
+ Q = poses[:, :4]
149
+ T = signed_expm1(poses[:, 4:7])
150
+ RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
151
+ return RT
152
+
153
+ def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
154
+ # all poses == cam-to-world
155
+ pose = poses[idx]
156
+ if not (pose.requires_grad or force):
157
+ return pose
158
+
159
+ if R.shape == (4, 4):
160
+ assert T is None
161
+ T = R[:3, 3]
162
+ R = R[:3, :3]
163
+
164
+ if R is not None:
165
+ pose.data[0:4] = roma.rotmat_to_unitquat(R)
166
+ if T is not None:
167
+ pose.data[4:7] = signed_log1p(T / (scale or 1)) # translation is function of scale
168
+
169
+ if scale is not None:
170
+ assert poses.shape[-1] in (8, 13)
171
+ pose.data[-1] = np.log(float(scale))
172
+ return pose
173
+
174
+ def get_pw_norm_scale_factor(self):
175
+ if self.norm_pw_scale:
176
+ # normalize scales so that things cannot go south
177
+ # we want that exp(scale) ~= self.base_scale
178
+ return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
179
+ else:
180
+ return 1 # don't norm scale for known poses
181
+
182
+ def get_pw_scale(self):
183
+ scale = self.pw_poses[:, -1].exp() # (n_edges,)
184
+ scale = scale * self.get_pw_norm_scale_factor()
185
+ return scale
186
+
187
+ def get_pw_poses(self): # cam to world
188
+ RT = self._get_poses(self.pw_poses)
189
+ scaled_RT = RT.clone()
190
+ scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1) # scale the rotation AND translation
191
+ return scaled_RT
192
+
193
+ def get_masks(self):
194
+ return [(conf > self.min_conf_thr) for conf in self.im_conf]
195
+
196
+ def depth_to_pts3d(self):
197
+ raise NotImplementedError()
198
+
199
+ def get_pts3d(self, raw=False):
200
+ res = self.depth_to_pts3d()
201
+ if not raw:
202
+ res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
203
+ return res
204
+
205
+ def _set_focal(self, idx, focal, force=False):
206
+ raise NotImplementedError()
207
+
208
+ def get_focals(self):
209
+ raise NotImplementedError()
210
+
211
+ def get_known_focal_mask(self):
212
+ raise NotImplementedError()
213
+
214
+ def get_principal_points(self):
215
+ raise NotImplementedError()
216
+
217
+ def get_conf(self, mode=None):
218
+ trf = self.conf_trf if mode is None else get_conf_trf(mode)
219
+ return [trf(c) for c in self.im_conf]
220
+
221
+ def get_im_poses(self):
222
+ raise NotImplementedError()
223
+
224
+ def _set_depthmap(self, idx, depth, force=False):
225
+ raise NotImplementedError()
226
+
227
+ def get_depthmaps(self, raw=False):
228
+ raise NotImplementedError()
229
+
230
+ @torch.no_grad()
231
+ def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
232
+ """ Method:
233
+ 1) express all 3d points in each camera coordinate frame
234
+ 2) if they're in front of a depthmap --> then lower their confidence
235
+ """
236
+ assert 0 <= tol < 1
237
+ cams = inv(self.get_im_poses())
238
+ K = self.get_intrinsics()
239
+ depthmaps = self.get_depthmaps()
240
+ res = deepcopy(self)
241
+
242
+ for i, pts3d in enumerate(self.depth_to_pts3d()):
243
+ for j in range(self.n_imgs):
244
+ if i == j:
245
+ continue
246
+
247
+ # project 3dpts in other view
248
+ Hi, Wi = self.imshapes[i]
249
+ Hj, Wj = self.imshapes[j]
250
+ proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
251
+ proj_depth = proj[:, :, 2]
252
+ u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)
253
+
254
+ # check which points are actually in the visible cone
255
+ msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
256
+ msk_j = v[msk_i], u[msk_i]
257
+
258
+ # find bad points = those in front but less confident
259
+ bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]
260
+ ) & (res.im_conf[i][msk_i] < res.im_conf[j][msk_j])
261
+
262
+ bad_msk_i = msk_i.clone()
263
+ bad_msk_i[msk_i] = bad_points
264
+ res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_(max=max_bad_conf)
265
+
266
+ return res
267
+
268
+ def forward(self, ret_details=False):
269
+ pw_poses = self.get_pw_poses() # cam-to-world
270
+ pw_adapt = self.get_adaptors()
271
+ proj_pts3d = self.get_pts3d()
272
+ # pre-compute pixel weights
273
+ weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
274
+ weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}
275
+
276
+ loss = 0
277
+ if ret_details:
278
+ details = -torch.ones((self.n_imgs, self.n_imgs))
279
+
280
+ for e, (i, j) in enumerate(self.edges):
281
+ i_j = edge_str(i, j)
282
+ # distance in image i and j
283
+ aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
284
+ aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
285
+ li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
286
+ lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
287
+ loss = loss + li + lj
288
+
289
+ if ret_details:
290
+ details[i, j] = li + lj
291
+ loss /= self.n_edges # average over all pairs
292
+
293
+ if ret_details:
294
+ return loss, details
295
+ return loss
296
+
297
+ def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
298
+ if init is None:
299
+ pass
300
+ elif init == 'msp' or init == 'mst':
301
+ init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
302
+ elif init == 'known_poses':
303
+ init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP)
304
+ else:
305
+ raise ValueError(f'bad value for {init=}')
306
+
307
+ global_alignment_loop(self, **kw)
308
+
309
+ @torch.no_grad()
310
+ def mask_sky(self):
311
+ res = deepcopy(self)
312
+ for i in range(self.n_imgs):
313
+ sky = segment_sky(self.imgs[i])
314
+ res.im_conf[i][sky] = 0
315
+ return res
316
+
317
+ def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
318
+ viz = SceneViz()
319
+ if self.imgs is None:
320
+ colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
321
+ colors = list(map(tuple, colors.tolist()))
322
+ for n in range(self.n_imgs):
323
+ viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
324
+ else:
325
+ viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
326
+ colors = np.random.randint(256, size=(self.n_imgs, 3))
327
+
328
+ # camera poses
329
+ im_poses = to_numpy(self.get_im_poses())
330
+ if cam_size is None:
331
+ cam_size = auto_cam_size(im_poses)
332
+ viz.add_cameras(im_poses, self.get_focals(), colors=colors,
333
+ images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
334
+ if show_pw_cams:
335
+ pw_poses = self.get_pw_poses()
336
+ viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)
337
+
338
+ if show_pw_pts3d:
339
+ pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
340
+ viz.add_pointcloud(pts, (128, 0, 128))
341
+
342
+ viz.show(**kw)
343
+ return viz
344
+
345
+
346
def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6, verbose=False):
    """Run plain gradient descent on all trainable parameters of `net`.

    net: a BasePCOptimizer-like module whose forward() returns the scalar loss.
    lr/lr_min: start/end learning rates; schedule: 'cosine' or 'linear';
    niter: number of optimization steps; verbose: print trainable param names.
    Returns `net` (also when there is nothing to optimize).
    Raises ValueError for an unknown schedule.
    """
    params = [p for p in net.parameters() if p.requires_grad]
    if not params:
        return net  # nothing to optimize

    if verbose:
        print([name for name, value in net.named_parameters() if value.requires_grad])

    # FIX: validate the schedule up-front instead of failing after the first
    # optimizer state has already been created and one step attempted.
    if schedule == 'cosine':
        schedule_fn = cosine_schedule
    elif schedule == 'linear':
        schedule_fn = linear_schedule
    else:
        raise ValueError(f'bad lr {schedule=}')

    lr_base = lr
    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))

    with tqdm.tqdm(total=niter) as bar:
        while bar.n < bar.total:
            t = bar.n / bar.total  # progress in [0, 1)
            lr = schedule_fn(t, lr_base, lr_min)
            adjust_learning_rate_by_lr(optimizer, lr)

            optimizer.zero_grad()
            loss = net()
            loss.backward()
            optimizer.step()
            bar.set_postfix_str(f'{lr=:g} loss={float(loss):g}')
            bar.update()

    # FIX: previously only the empty-params early exit returned net; return it
    # on every path for a consistent, chainable API.
    return net
dust3r/dust3r/cloud_opt/commons.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utility functions for global alignment
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+
11
+
12
def edge_str(i, j):
    """Canonical string key 'i_j' identifying the directed edge (i, j)."""
    return '%s_%s' % (i, j)
14
+
15
+
16
def i_j_ij(ij):
    """Expand an (i, j) pair into (its string key, the pair itself)."""
    i, j = ij
    return f'{i}_{j}', ij
18
+
19
+
20
def edge_conf(conf_i, conf_j, edge):
    """Scalar score of an edge: product of the mean confidences of its two maps."""
    mean_i = conf_i[edge].mean()
    mean_j = conf_j[edge].mean()
    return float(mean_i * mean_j)
22
+
23
+
24
def compute_edge_scores(edges, conf_i, conf_j):
    """Map each (i, j) pair to its confidence score (product of mean confs).

    edges: iterable of (string_key, (i, j)) tuples, e.g. map(i_j_ij, pairs).
    """
    scores = {}
    for key, (i, j) in edges:
        scores[(i, j)] = float(conf_i[key].mean() * conf_j[key].mean())
    return scores
26
+
27
+
28
def NoGradParamDict(x):
    """Wrap a dict of tensors into an nn.ParameterDict excluded from gradients."""
    assert isinstance(x, dict)
    frozen = nn.ParameterDict(x)
    return frozen.requires_grad_(False)
31
+
32
+
33
def get_imshapes(edges, pred_i, pred_j):
    """Recover the (H, W) shape of every image from the pairwise predictions.

    Cross-checks that all edges touching the same image agree on its shape.
    """
    n_imgs = 1 + max(max(e) for e in edges)
    imshapes = [None] * n_imgs
    for e, (i, j) in enumerate(edges):
        for idx, pred in ((i, pred_i[e]), (j, pred_j[e])):
            shape = tuple(pred.shape[0:2])
            if imshapes[idx]:
                assert imshapes[idx] == shape, f'incorrect shape for image {idx}'
            imshapes[idx] = shape
    return imshapes
46
+
47
+
48
def get_conf_trf(mode):
    """Return the confidence-remapping function selected by `mode`.

    Supported: 'log', 'sqrt', 'm1' (x-1), 'id'/'none' (identity).
    Raises ValueError for anything else.
    """
    transforms = {
        'log': lambda x: x.log(),
        'sqrt': lambda x: x.sqrt(),
        'm1': lambda x: x - 1,
        'id': lambda x: x,
        'none': lambda x: x,
    }
    try:
        return transforms[mode]
    except KeyError:
        raise ValueError(f'bad mode for {mode=}') from None
60
+
61
+
62
def l2_dist(a, b, weight):
    """Weighted squared Euclidean distance along the last axis."""
    diff = a - b
    return diff.square().sum(dim=-1) * weight


def l1_dist(a, b, weight):
    """Weighted Euclidean (L2-norm) distance along the last axis."""
    return (a - b).norm(dim=-1) * weight


# registry used by BasePCOptimizer(dist=...)
ALL_DISTS = {'l1': l1_dist, 'l2': l2_dist}
71
+
72
+
73
def signed_log1p(x):
    """Symmetric log1p: log1p(|x|) carrying the sign of x (odd function)."""
    return torch.sign(x) * torch.log1p(x.abs())
76
+
77
+
78
def signed_expm1(x):
    """Inverse of signed_log1p: expm1(|x|) carrying the sign of x."""
    return x.sign() * torch.expm1(x.abs())
81
+
82
+
83
def cosine_schedule(t, lr_start, lr_end):
    """Cosine-anneal from lr_start (at t=0) down to lr_end (at t=1)."""
    assert 0 <= t <= 1
    w = (1 + np.cos(np.pi * t)) / 2  # 1 -> 0 as t goes 0 -> 1
    return lr_end + w * (lr_start - lr_end)
86
+
87
+
88
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from lr_start (at t=0) to lr_end (at t=1)."""
    assert 0 <= t <= 1
    return (1 - t) * lr_start + t * lr_end
dust3r/dust3r/cloud_opt/init_im_poses.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Initialization functions for global alignment
6
+ # --------------------------------------------------------
7
+ from functools import cache
8
+
9
+ import numpy as np
10
+ import scipy.sparse as sp
11
+ import torch
12
+ import cv2
13
+ import roma
14
+ from tqdm import tqdm
15
+
16
+ from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
17
+ from dust3r.post_process import estimate_focal_knowing_depth
18
+ from dust3r.viz import to_numpy
19
+
20
+ from dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores
21
+
22
+
23
@torch.no_grad()
def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
    """Initialize pairwise poses and depthmaps when ALL image poses are known.

    For each edge, the relative camera of view j is recovered with RANSAC-PnP,
    then the (P1, P2) pair is sim3-aligned onto the two known cameras; the
    best-scoring edge per image also provides its initial depthmap.
    Raises AssertionError if any pose or focal is unknown.
    """
    device = self.device

    # indices of known poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    assert nkp == self.n_imgs, 'not all poses are known'

    # get all focals
    nkf, _, im_focals = get_known_focals(self)
    assert nkf == self.n_imgs
    im_pp = self.get_principal_points()

    best_depthmaps = {}
    # init all pairwise poses
    for e, (i, j) in enumerate(tqdm(self.edges)):
        i_j = edge_str(i, j)

        # find relative pose for this pair
        P1 = torch.eye(4, device=device)
        # threshold is lowered if needed so the mask is never empty
        msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
        _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),
                         pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)

        # align the two predicted camera with the two gt cameras
        s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
        # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
        # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

        # remember if this is a good depthmap
        score = float(self.conf_i[i_j].mean())
        if score > best_depthmaps.get(i, (0,))[0]:
            best_depthmaps[i] = score, i_j, s

    # init all image poses
    for n in range(self.n_imgs):
        assert known_poses_msk[n]
        _, i_j, scale = best_depthmaps[n]
        # depth = z-coordinate of the pointmap, rescaled by the pair's scale
        depth = self.pred_i[i_j][:, :, 2]
        self._set_depthmap(n, depth * scale)
64
+
65
+
66
@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
    """ Init all camera poses (image-wise and pairwise poses) given
    an initial set of pairwise estimations.

    Thin wrapper: builds the MST initialization, then writes it into self.
    """
    pts3d, _, im_focals, im_poses = minimum_spanning_tree(
        self.imshapes, self.edges, self.pred_i, self.pred_j,
        self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr,
        self.device, has_im_poses=self.has_im_poses, **kw)
    return init_from_pts3d(self, pts3d, im_focals, im_poses)
77
+
78
+
79
def init_from_pts3d(self, pts3d, im_focals, im_poses):
    """Initialize all optimized variables from per-image world pointmaps.

    pts3d: list of (H, W, 3) pointmaps in a common frame; im_focals: list of
    floats or None; im_poses: (n_imgs, 4, 4) cam-to-world. If >1 poses were
    preset by the user, everything is first rigidly aligned onto them.
    """
    # init poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    if nkp == 1:
        raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)

    # set all pairwise poses
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            # per-pixel depth = z of the pointmap expressed in the camera frame
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None:
                self._set_focal(i, im_focals[i])

    print(' init loss =', float(self()))
119
+
120
+
121
def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
                          device, has_im_poses=True, niter_PnP=10):
    """Initialize per-image pointmaps, poses and focals by walking a confidence MST.

    Edges are scored by confidence; a maximum spanning tree (minimum tree of
    negated scores) is traversed from the strongest edge outwards, rigidly
    registering each new image onto the cloud reconstructed so far.

    Returns (pts3d, msp_edges, im_focals, im_poses); the last two are None
    when has_im_poses is False.
    """
    n_imgs = len(imshapes)
    # negate the scores: scipy computes a MINIMUM spanning tree
    sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j))
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * len(imshapes)

    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges, strongest last
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge
    score, i, j = todo.pop()
    print(f' init edge ({i}*,{j}*) {score=}')
    i_j = edge_str(i, j)
    pts3d[i] = pred_i[i_j].clone()
    pts3d[j] = pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        # BUGFIX: compute the key of the CURRENT edge before any use; the focal
        # estimation below previously reused the stale key of the previous
        # iteration, estimating image i's focal from an unrelated pointmap.
        i_j = edge_str(i, j)

        if im_focals[i] is None:
            im_focals[i] = estimate_focal(pred_i[i_j])

        if i in done:
            print(f' init edge ({i},{j}*) {score=}')
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            print(f' init edge ({i}*,{j}) {score=}')
            assert i not in done
            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # neither endpoint is anchored yet: re-queue and try again later
            todo.insert(0, (score, i, j))

    if has_im_poses:
        # complete all missing informations
        pair_scores = list(sparse_graph.values())  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])

        # images still without a pose get one from RANSAC-PnP on their pointmap
        for i in range(n_imgs):
            if im_poses[i] is None:
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                im_poses[i] = torch.eye(4, device=device)  # last-resort fallback
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses
205
+
206
+
207
def dict_to_sparse_graph(dic):
    """Convert {(i, j): value} into an (n, n) scipy DOK sparse matrix."""
    n_imgs = 1 + max(max(edge) for edge in dic)
    graph = sp.dok_array((n_imgs, n_imgs))
    for (i, j), value in dic.items():
        graph[i, j] = value
    return graph
213
+
214
+
215
def rigid_points_registration(pts1, pts2, conf):
    """Weighted similarity (sim3) registration of pts1 onto pts2 via roma.

    pts1, pts2: (..., 3) pointmaps; conf: per-point weights (same leading shape).
    Returns (s, R, T) such that s * (R @ p1) + T ~= p2.
    """
    R, T, s = roma.rigid_points_registration(
        pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True)
    return s, R, T  # return un-scaled (R, T)
219
+
220
+
221
def sRT_to_4x4(scale, R, T, device):
    """Pack (scale, rotation R, translation T) into a homogeneous 4x4 matrix.

    The scale is folded into the 3x3 rotation block; T is used as-is.
    """
    mat = torch.eye(4, device=device)
    mat[:3, :3] = scale * R
    mat[:3, 3] = T.ravel()
    return mat
226
+
227
+
228
def estimate_focal(pts3d_i, pp=None):
    """Estimate one image's focal length (in pixels) from its pointmap.

    pts3d_i: (H, W, 3) pointmap in the camera frame; pp defaults to the image
    center. Delegates to the Weiszfeld solver in dust3r.post_process.
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
    focal = estimate_focal_knowing_depth(
        pts3d_i.unsqueeze(0), pp.unsqueeze(0),
        focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5).ravel()
    return float(focal)
236
+
237
+
238
@cache
def pixel_grid(H, W):
    """Return an (H, W, 2) float32 grid of (x, y) pixel coordinates (memoized)."""
    xs, ys = np.meshgrid(np.arange(W), np.arange(H))
    return np.stack((xs, ys), axis=-1).astype(np.float32)
241
+
242
+
243
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Recover a cam-to-world pose (and focal, if unknown) with RANSAC-PnP.

    pts3d: (H, W, 3) pointmap; msk: (H, W) bool mask of usable pixels;
    focal: known focal in pixels, or None to sweep a geometric range of
    candidates and keep the one with most inliers.
    Returns (focal, 4x4 cam-to-world tensor) or None on failure.
    """
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        S = max(W, H)
        # sweep plausible focal values on a log scale
        tentative_focals = np.geomspace(S/2, S*3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W/2, H/2)
    else:
        pp = to_numpy(pp)

    best = 0,  # (inlier_count[, R, T, focal]) — starts with no solution
    for focal in tentative_focals:
        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

        success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                                    iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
        if not success:
            continue

        score = len(inliers)
        if success and score > best[0]:
            best = score, R, T, focal

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # world to cam (rotation vector -> matrix)
    R, T = map(torch.from_numpy, (R, T))
    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world
284
+
285
+
286
def get_known_poses(self):
    """Return (#known, bool mask, poses) for image poses frozen by the user.

    A pose counts as known when its parameter has requires_grad=False.
    Returns (0, None, None) if the scene has no per-image poses at all.
    """
    if not self.has_im_poses:
        return 0, None, None
    known_msk = torch.tensor([not p.requires_grad for p in self.im_poses])
    return known_msk.sum(), known_msk, self.get_im_poses()
293
+
294
+
295
def get_known_focals(self):
    """Return (#known, bool mask, focals) for focals frozen by the user.

    Returns (0, None, None) if the scene has no per-image poses at all.
    """
    if not self.has_im_poses:
        return 0, None, None
    known_msk = self.get_known_focal_mask()
    return known_msk.sum(), known_msk, self.get_focals()
302
+
303
+
304
def align_multiple_poses(src_poses, target_poses):
    """Similarity (sim3) alignment between two sets of 4x4 cam-to-world poses.

    Registers both the camera centers and points slightly ahead along each
    optical (z) axis, so orientation — not just position — constrains the fit.
    Returns (s, R, T) mapping src onto target.
    """
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def center_and_z(poses):
        # centers plus points eps ahead along z; eps scales with the scene
        eps = get_med_dist_between_poses(poses) / 100
        return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2]))
    R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True)
    return s, R, T
dust3r/dust3r/cloud_opt/optimizer.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Main class for the implementation of the global alignment
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from dust3r.cloud_opt.base_opt import BasePCOptimizer
12
+ from dust3r.utils.geometry import xy_grid, geotrf
13
+ from dust3r.utils.device import to_cpu, to_numpy
14
+
15
+
16
class PointCloudOptimizer(BasePCOptimizer):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)

    Concrete optimizer: parameterizes per-image log-depthmaps, poses, focals
    and principal points, all stacked (zero-padded to the largest image) so
    the loss is computed in one batched pass.
    """

    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
        super().__init__(*args, **kwargs)

        self.has_im_poses = True  # by definition of this class
        self.focal_break = focal_break

        # adding thing to optimize
        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
        self.im_focals = nn.ParameterList(torch.FloatTensor(
            [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes)  # camera intrinsics
        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
        self.im_pp.requires_grad_(optimize_pp)

        self.imshape = self.imshapes[0]
        im_areas = [h*w for h, w in self.imshapes]
        self.max_area = max(im_areas)  # padding target for the stacked tensors

        # adding thing to optimize
        self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
        self.im_poses = ParameterStack(self.im_poses, is_param=True)
        self.im_focals = ParameterStack(self.im_focals, is_param=True)
        self.im_pp = ParameterStack(self.im_pp, is_param=True)
        self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
        self.register_buffer('_grid', ParameterStack(
            [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))

        # pre-compute pixel weights
        self.register_buffer('_weight_i', ParameterStack(
            [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
        self.register_buffer('_weight_j', ParameterStack(
            [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))

        # precompute aa
        self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
        self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
        self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
        self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
        self.total_area_j = sum([im_areas[j] for i, j in self.edges])

    def _check_all_imgs_are_selected(self, msk):
        # presets must cover every image exactly once
        assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'

    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
        """Freeze all image poses to user-provided cam-to-world matrices."""
        self._check_all_imgs_are_selected(pose_msk)

        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
            known_poses = [known_poses]
        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            print(f' (setting pose #{idx} = {pose[:3,3]})')
            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))

        # normalize scale if there's less than 1 known pose
        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
        self.norm_pw_scale = (n_known_poses <= 1)

        self.im_poses.requires_grad_(False)
        # NOTE(review): this unconditionally overrides the value computed just
        # above — confirm the conditional assignment is intentional dead code.
        self.norm_pw_scale = False

    def preset_focal(self, known_focals, msk=None):
        """Freeze all focals to user-provided values (in pixels)."""
        self._check_all_imgs_are_selected(msk)

        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            print(f' (setting focal #{idx} = {focal})')
            self._no_grad(self._set_focal(idx, focal))

        self.im_focals.requires_grad_(False)

    def preset_principal_point(self, known_pp, msk=None):
        """Freeze all principal points to user-provided (x, y) values."""
        self._check_all_imgs_are_selected(msk)

        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            print(f' (setting principal point #{idx} = {pp})')
            self._no_grad(self._set_principal_point(idx, pp))

        self.im_pp.requires_grad_(False)

    def _no_grad(self, tensor):
        # sanity guard: presetting only makes sense on still-trainable params
        assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'

    def _set_focal(self, idx, focal, force=False):
        """Write focal #idx (stored as focal_break * log(focal))."""
        param = self.im_focals[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = self.focal_break * np.log(focal)
        return param

    def get_focals(self):
        # invert the log parameterization used in _set_focal
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_break).exp()

    def get_known_focal_mask(self):
        return torch.tensor([not (p.requires_grad) for p in self.im_focals])

    def _set_principal_point(self, idx, pp, force=False):
        """Write principal point #idx (stored as offset from image center, /10)."""
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
        return param

    def get_principal_points(self):
        # image center + learned offset (scaled back up by 10)
        return self._pp + 10 * self.im_pp

    def get_intrinsics(self):
        """(n_imgs, 3, 3) pinhole intrinsics built from focals and principal points."""
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().flatten()
        K[:, 0, 0] = K[:, 1, 1] = focals
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):  # cam to world
        cam2world = self._get_poses(self.im_poses)
        return cam2world

    def _set_depthmap(self, idx, depth, force=False):
        """Write depthmap #idx (stored flat, zero-padded, in log-space)."""
        depth = _ravel_hw(depth, self.max_area)

        param = self.im_depthmaps[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self, raw=False):
        """Depthmaps; raw=False un-pads and reshapes each back to (H, W)."""
        res = self.im_depthmaps.exp()
        if not raw:
            res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def depth_to_pts3d(self):
        """Back-project all depthmaps to world-frame 3D points (batched)."""
        # Get depths and  projection params if not provided
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps(raw=True)

        # get pointmaps in camera frame
        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
        # project to world frame
        return geotrf(im_poses, rel_ptmaps)

    def get_pts3d(self, raw=False):
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def forward(self):
        """Batched alignment loss over all edges (see BasePCOptimizer.forward)."""
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors().unsqueeze(1)
        proj_pts3d = self.get_pts3d(raw=True)

        # rotate pairwise prediction according to pw_poses
        aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
        aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)

        # compute the loss
        li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
        lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j

        return li + lj
184
+
185
+
186
+ def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
187
+ pp = pp.unsqueeze(1)
188
+ focal = focal.unsqueeze(1)
189
+ assert focal.shape == (len(depth), 1, 1)
190
+ assert pp.shape == (len(depth), 1, 2)
191
+ assert pixel_grid.shape == depth.shape + (2,)
192
+ depth = depth.unsqueeze(-1)
193
+ return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)
194
+
195
+
196
+ def ParameterStack(params, keys=None, is_param=None, fill=0):
197
+ if keys is not None:
198
+ params = [params[k] for k in keys]
199
+
200
+ if fill > 0:
201
+ params = [_ravel_hw(p, fill) for p in params]
202
+
203
+ requires_grad = params[0].requires_grad
204
+ assert all(p.requires_grad == requires_grad for p in params)
205
+
206
+ params = torch.stack(list(params)).float().detach()
207
+ if is_param or requires_grad:
208
+ params = nn.Parameter(params)
209
+ params.requires_grad_(requires_grad)
210
+ return params
211
+
212
+
213
+ def _ravel_hw(tensor, fill=0):
214
+ # ravel H,W
215
+ tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
216
+
217
+ if len(tensor) < fill:
218
+ tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
219
+ return tensor
220
+
221
+
222
def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
    """Plausible focal bounds in pixels, centered on a ~60-degree field of view."""
    base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # focal for a 60° FOV
    return minf * base, maxf * base
225
+
226
+
227
def apply_mask(img, msk):
    """Return a copy of `img` with masked pixels zeroed (input left untouched)."""
    out = img.copy()
    out[msk] = 0
    return out
dust3r/dust3r/cloud_opt/pair_viewer.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Dummy optimizer for visualizing pairs
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import cv2
11
+
12
+ from dust3r.cloud_opt.base_opt import BasePCOptimizer
13
+ from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates
14
+ from dust3r.cloud_opt.commons import edge_str
15
+ from dust3r.post_process import estimate_focal_knowing_depth
16
+
17
+
18
class PairViewer (BasePCOptimizer):
    """
    Dummy optimizer for visualizing a single symmetrized image pair.

    No optimization is performed: focals, principal points, the relative pose
    (via PnP-RANSAC) and depthmaps are computed directly from the raw network
    predictions.  The point cloud is expressed in the camera frame of whichever
    view has the most confident predictions.  All parameters are frozen.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.is_symmetrized and self.n_edges == 2
        self.has_im_poses = True

        # compute all parameters directly from raw input
        self.focals = []
        self.pp = []
        rel_poses = []
        confs = []
        for i in range(self.n_imgs):
            # overall confidence of edge i -> (1-i), used to pick the reference view
            conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean())
            print(f' - {conf=:.3} for edge {i}-{1-i}')
            confs.append(conf)

            # focal estimated from the predicted pointmap, pp fixed at the image center
            H, W = self.imshapes[i]
            pts3d = self.pred_i[edge_str(i, 1-i)]
            pp = torch.tensor((W/2, H/2))
            focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld'))
            self.focals.append(focal)
            self.pp.append(pp)

            # estimate the pose of pts1 in image 2
            pixels = np.mgrid[:W, :H].T.astype(np.float32)
            pts3d = self.pred_j[edge_str(1-i, i)].numpy()
            assert pts3d.shape[:2] == (H, W)
            msk = self.get_masks()[i].numpy()
            K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

            try:
                res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                         iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
                success, R, T, inliers = res
                assert success

                R = cv2.Rodrigues(R)[0]  # world to cam
                pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]])  # cam to world
            except Exception:
                # PnP can fail (too few valid points, degenerate geometry).
                # Best-effort fallback to identity so visualization still works.
                # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
                pose = np.eye(4)
            rel_poses.append(torch.from_numpy(pose.astype(np.float32)))

        # let's use the pair with the most confidence
        if confs[0] > confs[1]:
            # ptcloud is expressed in camera1
            self.im_poses = [torch.eye(4), rel_poses[1]]  # I, cam2-to-cam1
            self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]
        else:
            # ptcloud is expressed in camera2
            self.im_poses = [rel_poses[0], torch.eye(4)]  # I, cam1-to-cam2
            self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]

        # freeze everything: this "optimizer" has nothing to optimize
        self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)
        self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
        self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)
        self.depth = nn.ParameterList(self.depth)
        for p in self.parameters():
            p.requires_grad = False

    def _set_depthmap(self, idx, depth, force=False):
        """Ignored: depthmaps are fixed in this dummy optimizer."""
        print('_set_depthmap is ignored in PairViewer')
        return

    def get_depthmaps(self, raw=False):
        """Return the per-view depthmaps, moved to the current device."""
        depth = [d.to(self.device) for d in self.depth]
        return depth

    def _set_focal(self, idx, focal, force=False):
        self.focals[idx] = focal

    def get_focals(self):
        return self.focals

    def get_known_focal_mask(self):
        # all focals are frozen, so this is all-True in practice
        return torch.tensor([not (p.requires_grad) for p in self.focals])

    def get_principal_points(self):
        return self.pp

    def get_intrinsics(self):
        """Assemble 3x3 pinhole intrinsics from the stored focals and pp."""
        focals = self.get_focals()
        pps = self.get_principal_points()
        K = torch.zeros((len(focals), 3, 3), device=self.device)
        for i in range(len(focals)):
            K[i, 0, 0] = K[i, 1, 1] = focals[i]
            K[i, :2, 2] = pps[i]
            K[i, 2, 2] = 1
        return K

    def get_im_poses(self):
        return self.im_poses

    def depth_to_pts3d(self):
        """Unproject each depthmap to world coordinates using pose + intrinsics."""
        pts3d = []
        for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):
            pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
                                                             intrinsics.cpu().numpy(),
                                                             im_pose.cpu().numpy())
            pts3d.append(torch.from_numpy(pts).to(device=self.device))
        return pts3d

    def forward(self):
        # nothing to optimize: report a NaN loss
        return float('nan')
dust3r/dust3r/datasets/__init__.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ from .utils.transforms import *
4
+ from .base.batched_sampler import BatchedRandomSampler # noqa: F401
5
+ from .co3d import Co3d # noqa: F401
6
+
7
+
8
def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    """Build a torch DataLoader for `dataset`, distributed-training aware.

    `dataset` may be an already-built dataset object or a python expression
    string (e.g. "100 @ Co3d(...)") that is evaluated here.  Datasets exposing
    make_sampler() (aspect-ratio batching) get their custom sampler; otherwise
    a standard Distributed/Random/Sequential sampler is used.
    """
    import torch
    from croco.utils.misc import get_world_size, get_rank

    # pytorch dataset
    if isinstance(dataset, str):
        # NOTE(review): eval() of a config-provided expression string —
        # never pass untrusted input here
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    try:
        # datasets with aspect-ratio pools provide their own batched sampler
        sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
                                       rank=rank, drop_last=drop_last)
    except (AttributeError, NotImplementedError):
        # not avail for this dataset
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.DistributedSampler(
                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
            )
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader
dust3r/dust3r/datasets/base/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/dust3r/datasets/base/base_stereo_view_dataset.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # base class for implementing datasets
6
+ # --------------------------------------------------------
7
+ import PIL
8
+ import numpy as np
9
+ import torch
10
+
11
+ from dust3r.datasets.base.easy_dataset import EasyDataset
12
+ from dust3r.datasets.utils.transforms import ImgNorm
13
+ from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
14
+ import dust3r.datasets.utils.cropping as cropping
15
+
16
+
17
class BaseStereoViewDataset (EasyDataset):
    """ Define all basic options.

    Usage:
        class MyDataset (BaseStereoViewDataset):
            def _get_views(self, idx, rng):
                # overload here
                views = []
                views.append(dict(img=, ...))
                return views
    """

    def __init__(self, *,  # only keyword arguments
                 split=None,
                 resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
                 transform=ImgNorm,
                 aug_crop=False,
                 seed=None):
        self.num_views = 2
        self.split = split
        self._set_resolutions(resolution)

        # BUGFIX: a string transform (e.g. from a config file) must be
        # eval()'ed *before* being stored; previously the eval result was
        # assigned to the local only and self.transform stayed a raw string.
        if isinstance(transform, str):
            transform = eval(transform)
        self.transform = transform

        self.aug_crop = aug_crop
        self.seed = seed

    def __len__(self):
        return len(self.scenes)

    def get_stats(self):
        """Short human-readable summary of the dataset size."""
        return f"{len(self)} pairs"

    def __repr__(self):
        resolutions_str = '['+';'.join(f'{w}x{h}' for w, h in self._resolutions)+']'
        return f"""{type(self).__name__}({self.get_stats()},
{self.split=},
{self.seed=},
resolutions={resolutions_str},
{self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '')

    def _get_views(self, idx, resolution, rng):
        """Subclasses return a list of `self.num_views` view dicts here."""
        raise NotImplementedError()

    def __getitem__(self, idx):
        if isinstance(idx, tuple):
            # the idx is specifying the aspect-ratio
            idx, ar_idx = idx
        else:
            assert len(self._resolutions) == 1
            ar_idx = 0

        # set-up the rng
        if self.seed:  # reseed for each __getitem__
            self._rng = np.random.default_rng(seed=self.seed + idx)
        elif not hasattr(self, '_rng'):
            seed = torch.initial_seed()  # this is different for each dataloader process
            self._rng = np.random.default_rng(seed=seed)

        # over-loaded code
        resolution = self._resolutions[ar_idx]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
        views = self._get_views(idx, resolution, self._rng)
        assert len(views) == self.num_views

        # check data-types
        for v, view in enumerate(views):
            assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
            view['idx'] = (idx, ar_idx, v)

            # encode the image
            width, height = view['img'].size
            view['true_shape'] = np.int32((height, width))
            view['img'] = self.transform(view['img'])

            assert 'camera_intrinsics' in view
            if 'camera_pose' not in view:
                view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
            else:
                assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
            assert 'pts3d' not in view
            assert 'valid_mask' not in view
            assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
            # pts3d are always derived here from intrinsics + depthmap + pose
            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)

            view['pts3d'] = pts3d
            view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)

            # check all datatypes
            for key, val in view.items():
                res, err_msg = is_good_type(key, val)
                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
            # (removed unused local `K = view['camera_intrinsics']`)

        # last thing done!
        for view in views:
            # transpose to make sure all views are the same size
            transpose_to_landscape(view)
            # this allows to check whether the RNG is in the same state each time
            view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
        return views

    def _set_resolutions(self, resolutions):
        """Validate and store the list of (width, height) training resolutions."""
        assert resolutions is not None, 'undefined resolution'

        if not isinstance(resolutions, list):
            resolutions = [resolutions]

        self._resolutions = []
        for resolution in resolutions:
            if isinstance(resolution, int):
                width = height = resolution
            else:
                width, height = resolution
            assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'
            assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'
            assert width >= height
            self._resolutions.append((width, height))

    def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None):
        """Center the view on its principal point, then rescale and crop it to `resolution`.

        Downscaling uses LANCZOS interpolation for the image (better than
        bilinear) and nearest-neighbour for the depthmap.
        Returns the cropped (image, depthmap, intrinsics).
        """
        if not isinstance(image, PIL.Image.Image):
            image = PIL.Image.fromarray(image)

        # downscale with lanczos interpolation so that image.size == resolution
        # cropping centered on the principal point
        W, H = image.size
        cx, cy = intrinsics[:2, 2].round().astype(int)
        min_margin_x = min(cx, W-cx)
        min_margin_y = min(cy, H-cy)
        assert min_margin_x > W/5, f'Bad principal point in view={info}'
        assert min_margin_y > H/5, f'Bad principal point in view={info}'
        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)
        image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)

        # transpose the resolution if necessary
        W, H = image.size  # new size
        assert resolution[0] >= resolution[1]
        if H > 1.1*W:
            # image is portrait mode
            resolution = resolution[::-1]
        elif 0.9 < H/W < 1.1 and resolution[0] != resolution[1]:
            # image is square, so we chose (portrait, landscape) randomly
            if rng.integers(2):
                resolution = resolution[::-1]

        # high-quality Lanczos down-scaling
        target_resolution = np.array(resolution)
        if self.aug_crop > 1:
            # random enlargement of the target so the final crop jitters
            target_resolution += rng.integers(0, self.aug_crop)
        image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)

        # actual cropping (if necessary) with bilinear interpolation
        intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)
        crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
        image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)

        return image, depthmap, intrinsics2
182
+
183
+
184
def is_good_type(key, v):
    """Check that `v` has a dtype supported by the batching pipeline.

    Returns (is_good, err_msg); err_msg is None when the type is acceptable.
    """
    if isinstance(v, (str, int, tuple)):
        return True, None
    allowed_dtypes = (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8)
    if v.dtype in allowed_dtypes:
        return True, None
    return False, f"bad {v.dtype=}"
192
+
193
+
194
def view_name(view, batch_index=None):
    """Build a readable 'dataset/label/instance' identifier for a view.

    If `batch_index` is an actual index (not None or slice(None)), the
    corresponding element of each batched field is selected first.
    """
    def pick(field):
        value = view[field]
        if batch_index is None or batch_index == slice(None):
            return value
        return value[batch_index]
    return f"{pick('dataset')}/{pick('label')}/{pick('instance')}"
200
+
201
+
202
def transpose_to_landscape(view):
    """Transpose a portrait-mode view in place so that width >= height.

    Swaps the spatial axes of img / valid_mask / depthmap / pts3d and
    permutes the x/y rows of the intrinsics accordingly.  Landscape views
    are left untouched.
    """
    height, width = view['true_shape']
    if width >= height:
        return  # already landscape: nothing to do

    # rectify portrait to landscape
    assert view['img'].shape == (3, height, width)
    view['img'] = view['img'].swapaxes(1, 2)

    assert view['valid_mask'].shape == (height, width)
    view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)

    assert view['depthmap'].shape == (height, width)
    view['depthmap'] = view['depthmap'].swapaxes(0, 1)

    assert view['pts3d'].shape == (height, width, 3)
    view['pts3d'] = view['pts3d'].swapaxes(0, 1)

    # transpose x and y pixels in the intrinsics
    view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]
dust3r/dust3r/datasets/base/batched_sampler.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Random sampling under a constraint
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
class BatchedRandomSampler:
    """ Random sampling under a constraint: each sample in the batch has the same feature,
    which is chosen randomly from a known pool of 'features' for each batch.

    For instance, the 'feature' could be the image aspect-ratio.

    The index returned is a tuple (sample_idx, feat_idx).
    This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
    """

    def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True):
        # pool_size: number of distinct features (e.g. aspect-ratios) to draw from
        self.batch_size = batch_size
        self.pool_size = pool_size

        self.len_dataset = N = len(dataset)
        # total number of indices served across all ranks
        # (truncated to full global batches when drop_last)
        self.total_size = round_by(N, batch_size*world_size) if drop_last else N
        assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode'

        # distributed sampler
        self.world_size = world_size
        self.rank = rank
        self.epoch = None  # must be set via set_epoch() before iterating in distributed mode

    def __len__(self):
        # number of indices served by this rank
        return self.total_size // self.world_size

    def set_epoch(self, epoch):
        # the epoch seeds the shuffle, keeping all ranks in sync
        self.epoch = epoch

    def __iter__(self):
        # prepare RNG
        if self.epoch is None:
            assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used'
            # non-distributed: any random seed works
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
        else:
            seed = self.epoch + 777
        rng = np.random.default_rng(seed=seed)

        # random indices (will restart from 0 if not drop_last)
        sample_idxs = np.arange(self.total_size)
        rng.shuffle(sample_idxs)

        # random feat_idxs (same across each batch)
        n_batches = (self.total_size+self.batch_size-1) // self.batch_size
        feat_idxs = rng.integers(self.pool_size, size=n_batches)
        feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
        feat_idxs = feat_idxs.ravel()[:self.total_size]

        # put them together
        idxs = np.c_[sample_idxs, feat_idxs]  # shape = (total_size, 2)

        # Distributed sampler: we select a subset of batches
        # make sure the slice for each node is aligned with batch_size
        size_per_proc = self.batch_size * ((self.total_size + self.world_size *
                                            self.batch_size-1) // (self.world_size * self.batch_size))
        idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc]

        yield from (tuple(idx) for idx in idxs)
69
+
70
+
71
def round_by(total, multiple, up=False):
    """Round `total` down (or up when `up=True`) to the nearest multiple of `multiple`."""
    if up:
        total += multiple - 1
    return (total // multiple) * multiple
dust3r/dust3r/datasets/base/easy_dataset.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # A dataset base class that you can easily resize and combine.
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ from dust3r.datasets.base.batched_sampler import BatchedRandomSampler
9
+
10
+
11
class EasyDataset:
    """ a dataset that you can easily resize and combine.
    Examples:
    ---------
        2 * dataset ==> duplicate each element 2x

        10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)

        dataset1 + dataset2 ==> concatenate datasets
    """

    def __add__(self, other):
        # dataset1 + dataset2 => concatenation
        return CatDataset([self, other])

    def __rmul__(self, factor):
        # n * dataset => each element repeated n times
        return MulDataset(factor, self)

    def __rmatmul__(self, factor):
        # n @ dataset => resized to exactly n elements (epoch-seeded shuffle)
        return ResizedDataset(factor, self)

    def set_epoch(self, epoch):
        pass  # nothing to do by default

    def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True):
        # one "feature" per aspect-ratio: each batch then uses a single resolution
        if not (shuffle):
            raise NotImplementedError()  # cannot deal yet
        num_of_aspect_ratios = len(self._resolutions)
        return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last)
39
+
40
+
41
class MulDataset (EasyDataset):
    """ Wrapper that artificially multiplies the size of a dataset.

    Each element of the base dataset is served `multiplicator` times in a row.
    """
    multiplicator: int

    def __init__(self, multiplicator, dataset):
        assert isinstance(multiplicator, int) and multiplicator > 0
        self.multiplicator = multiplicator
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset) * self.multiplicator

    def __repr__(self):
        return f'{self.multiplicator}*{repr(self.dataset)}'

    def __getitem__(self, idx):
        # a tuple index carries an extra aspect-ratio index to forward as-is
        if isinstance(idx, tuple):
            sample_idx, ar_idx = idx
            return self.dataset[sample_idx // self.multiplicator, ar_idx]
        return self.dataset[idx // self.multiplicator]

    @property
    def _resolutions(self):
        return self.dataset._resolutions
67
+
68
+
69
class ResizedDataset (EasyDataset):
    """ Wrapper that artificially changes the size of a dataset.

    set_epoch() must be called before indexing: it draws an epoch-seeded
    random permutation of the base dataset, repeated until `new_size` items.
    """
    new_size: int

    def __init__(self, new_size, dataset):
        assert isinstance(new_size, int) and new_size > 0
        self.new_size = new_size
        self.dataset = dataset

    def __len__(self):
        return self.new_size

    def __repr__(self):
        # group digits by thousands with '_' for readability
        return f'{self.new_size:_d} @ {repr(self.dataset)}'

    def set_epoch(self, epoch):
        # this random shuffle only depends on the epoch => reproducible
        rng = np.random.default_rng(seed=epoch+777)

        # shuffle all indices
        perm = rng.permutation(len(self.dataset))

        # rotary extension until target size is met
        n_repeats = 1 + (len(self) - 1) // len(self.dataset)
        extended = np.concatenate([perm] * n_repeats)
        self._idxs_mapping = extended[:self.new_size]

        assert len(self._idxs_mapping) == self.new_size

    def __getitem__(self, idx):
        assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()'
        if isinstance(idx, tuple):
            sample_idx, ar_idx = idx
            return self.dataset[self._idxs_mapping[sample_idx], ar_idx]
        return self.dataset[self._idxs_mapping[idx]]

    @property
    def _resolutions(self):
        return self.dataset._resolutions
113
+
114
+
115
class CatDataset (EasyDataset):
    """ Concatenation of several datasets, indexed as one flat dataset. """

    def __init__(self, datasets):
        for dataset in datasets:
            assert isinstance(dataset, EasyDataset)
        self.datasets = datasets
        # cumulative sizes for fast global-idx -> (dataset, local-idx) lookup
        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])

    def __len__(self):
        return self._cum_sizes[-1]

    def __repr__(self):
        # remove uselessly long transform
        clutter = ',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))'
        return ' + '.join(repr(dataset).replace(clutter, '') for dataset in self.datasets)

    def set_epoch(self, epoch):
        for dataset in self.datasets:
            dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        ar_idx = None
        if isinstance(idx, tuple):
            idx, ar_idx = idx

        if not (0 <= idx < len(self)):
            raise IndexError()

        # locate the sub-dataset containing this global index
        db_idx = np.searchsorted(self._cum_sizes, idx, 'right')
        dataset = self.datasets[db_idx]
        start = self._cum_sizes[db_idx - 1] if db_idx > 0 else 0
        new_idx = idx - start

        if ar_idx is not None:
            new_idx = (new_idx, ar_idx)
        return dataset[new_idx]

    @property
    def _resolutions(self):
        # all sub-datasets must share the same resolution pool
        common = self.datasets[0]._resolutions
        for dataset in self.datasets[1:]:
            assert tuple(dataset._resolutions) == tuple(common)
        return common
dust3r/dust3r/datasets/co3d.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Dataloader for preprocessed Co3d_v2
6
+ # dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International
7
+ # See datasets_preprocess/preprocess_co3d.py
8
+ # --------------------------------------------------------
9
+ import os.path as osp
10
+ import json
11
+ import itertools
12
+ from collections import deque
13
+
14
+ import cv2
15
+ import numpy as np
16
+
17
+ from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
18
+ from dust3r.utils.image import imread_cv2
19
+
20
+
21
class Co3d(BaseStereoViewDataset):
    """Stereo-view dataset over preprocessed Co3d_v2 scenes.

    Each item is a pair of views of the same (category, sequence) scene,
    with frame indices differing by a multiple of 5 (up to 30) plus jitter.
    """

    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
        # ROOT: directory produced by datasets_preprocess/preprocess_co3d.py
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        # mask_bg: True = always mask background, 'rand' = mask it half the time
        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg

        # load all scenes
        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
            self.scenes = json.load(f)
        # drop empty categories, then flatten to {(category, sequence): frame list}
        self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
        self.scenes = {(k, k2): v2 for k, v in self.scenes.items()
                       for k2, v2 in v.items()}
        self.scene_list = list(self.scenes.keys())

        # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees)
        # we prepare all combinations such that i-j = +/- [5, 10, .., 90] degrees
        self.combinations = [(i, j)
                             for i, j in itertools.combinations(range(100), 2)
                             if 0 < abs(i-j) <= 30 and abs(i-j) % 5 == 0]

        # per-scene, per-resolution flags marking frames found to have no valid depth
        self.invalidate = {scene: {} for scene in self.scene_list}

    def __len__(self):
        return len(self.scene_list) * len(self.combinations)

    def _get_views(self, idx, resolution, rng):
        # choose a scene
        obj, instance = self.scene_list[idx // len(self.combinations)]
        image_pool = self.scenes[obj, instance]
        im1_idx, im2_idx = self.combinations[idx % len(self.combinations)]

        # add a bit of randomness
        last = len(image_pool)-1

        if resolution not in self.invalidate[obj, instance]:  # flag invalid images
            self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]

        # decide now if we mask the bg
        mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))

        views = []
        # jitter each frame index by +/-4 and clamp to the pool bounds
        imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]]
        imgs_idxs = deque(imgs_idxs)
        while len(imgs_idxs) > 0:  # some images (few) have zero depth
            im_idx = imgs_idxs.pop()

            if self.invalidate[obj, instance][resolution][im_idx]:
                # search for a valid image, walking in a random direction
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, len(image_pool)):
                    tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
                    if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
                        im_idx = tentative_im_idx
                        break

            view_idx = image_pool[im_idx]

            impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')

            # load camera params
            input_metadata = np.load(impath.replace('jpg', 'npz'))
            camera_pose = input_metadata['camera_pose'].astype(np.float32)
            intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)

            # load image and depth (depth stored as uint16, scaled by maximum_depth)
            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED)
            depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])

            if mask_bg:
                # load object mask
                maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')
                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
                maskmap = (maskmap / 255.0) > 0.1

                # update the depthmap with mask
                depthmap *= maskmap

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0:
                # problem, invalidate image and retry
                self.invalidate[obj, instance][resolution][im_idx] = True
                imgs_idxs.append(im_idx)
                continue

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset='Co3d_v2',
                label=osp.join(obj, instance),
                instance=osp.split(impath)[1],
            ))
        return views
+ return views
120
+
121
+
122
if __name__ == "__main__":
    # quick visual sanity-check: sample random pairs and display each one
    # as a colored point cloud with its two camera frusta
    from dust3r.datasets.base.base_stereo_view_dataset import view_name
    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb

    dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16)

    for idx in np.random.permutation(len(dataset)):
        views = dataset[idx]
        assert len(views) == 2
        print(view_name(views[0]), view_name(views[1]))
        viz = SceneViz()
        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
        cam_size = max(auto_cam_size(poses), 0.001)
        for view_idx in [0, 1]:
            pts3d = views[view_idx]['pts3d']
            valid_mask = views[view_idx]['valid_mask']
            colors = rgb(views[view_idx]['img'])
            viz.add_pointcloud(pts3d, colors, valid_mask)
            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
                           focal=views[view_idx]['camera_intrinsics'][0, 0],
                           color=(idx*255, (1 - idx)*255, 0),
                           image=colors,
                           cam_size=cam_size)
        viz.show()
dust3r/dust3r/datasets/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/dust3r/datasets/utils/cropping.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # cropping utilities
6
+ # --------------------------------------------------------
7
+ import PIL.Image
8
+ import os
9
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
10
+ import cv2 # noqa
11
+ import numpy as np # noqa
12
+ from dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa
13
+ try:
14
+ lanczos = PIL.Image.Resampling.LANCZOS
15
+ except AttributeError:
16
+ lanczos = PIL.Image.LANCZOS
17
+
18
+
19
class ImageList:
    """ Convenience class to apply the same operation to a whole set of images. """

    def __init__(self, images):
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        # normalize every entry to a PIL image
        self.images = [img if isinstance(img, PIL.Image.Image) else PIL.Image.fromarray(img)
                       for img in images]

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        # unwrap single-image lists for caller convenience
        if len(self.images) == 1:
            return self.images[0]
        return tuple(self.images)

    @property
    def size(self):
        # all images must share the same (W, H)
        sizes = [img.size for img in self.images]
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    def resize(self, *args, **kwargs):
        return ImageList(self._dispatch('resize', *args, **kwargs))

    def crop(self, *args, **kwargs):
        return ImageList(self._dispatch('crop', *args, **kwargs))

    def _dispatch(self, method, *args, **kwargs):
        # call the same PIL method on every image
        return [getattr(img, method)(*args, **kwargs) for img in self.images]
52
+
53
+
54
def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution):
    """ Jointly rescale a (image, depthmap)
        so that (out_width, out_height) >= output_res

    Keeps the aspect ratio: the smallest isotropic scale is chosen such that
    the result fully contains `output_resolution`.  The intrinsics are
    updated accordingly.  `depthmap` may be None (or a mask instead).
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # can also use this with masks instead of depthmaps
        assert tuple(depthmap.shape[:2]) == image.size[::-1]
    assert output_resolution.shape == (2,)
    # define output resolution (+eps guards against rounding just below target)
    scale_final = max(output_resolution / image.size) + 1e-8
    output_resolution = np.floor(input_resolution * scale_final).astype(int)

    # first rescale the image so that it contains the crop
    image = image.resize(output_resolution, resample=lanczos)
    if depthmap is not None:
        # nearest-neighbour: depth values must not be blended across edges
        depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final,
                              fy=scale_final, interpolation=cv2.INTER_NEAREST)

    # no offset here; simple rescaling
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)

    return image.to_pil(), depthmap, camera_intrinsics
80
+
81
+
82
def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
    """Intrinsics after rescaling by `scaling` and cropping to `output_resolution`.

    The crop origin defaults to `offset_factor` of the available margins
    (0.5 => centered crop); pass `offset` explicitly to override it.
    """
    # Margins available to offset the crop origin
    margins = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    # Scale then shift the principal point in COLMAP convention, convert back
    colmap_matrix = opencv_to_colmap_intrinsics(input_camera_matrix)
    colmap_matrix[:2, :] *= scaling
    colmap_matrix[:2, 2] -= offset
    return colmap_to_opencv_intrinsics(colmap_matrix)
96
+
97
+
98
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
    """Crop an (image, depthmap) pair and shift the intrinsics accordingly.

    crop_bbox: (left, top, right, bottom) in pixels.
    """
    left, top, right, bottom = crop_bbox

    image = ImageList(image).crop((left, top, right, bottom))
    depthmap = depthmap[top:bottom, left:right]

    # shift the principal point by the crop origin
    camera_intrinsics = camera_intrinsics.copy()
    camera_intrinsics[0, 2] -= left
    camera_intrinsics[1, 2] -= top

    return image.to_pil(), depthmap, camera_intrinsics
113
+
114
+
115
def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
    """Deduce the crop bbox (l, t, r, b) mapping input intrinsics to output intrinsics."""
    out_width, out_height = output_resolution
    # the crop origin is the shift of the principal point between the two matrices
    shift = input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]
    l, t = np.int32(np.round(shift))
    return (l, t, l + out_width, t + out_height)
dust3r/dust3r/datasets/utils/transforms.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUST3R default transforms
# --------------------------------------------------------
import torchvision.transforms as tvf
from dust3r.utils.image import ImgNorm

# define the standard image transforms
# Color jitter (brightness/contrast/saturation 0.5, hue 0.1) followed by the
# standard DUSt3R normalization (ImgNorm); used as training-time augmentation.
ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
dust3r/dust3r/heads/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # head factory
6
+ # --------------------------------------------------------
7
+ from .linear_head import LinearPts3d
8
+ from .dpt_head import create_dpt_head
9
+
10
+
11
def head_factory(head_type, output_mode, net, has_conf=False):
    """Build a prediction head for the decoder.

    Supported combinations are ('linear', 'pts3d') and ('dpt', 'pts3d');
    anything else raises NotImplementedError.
    """
    if output_mode == 'pts3d':
        if head_type == 'linear':
            return LinearPts3d(net, has_conf)
        if head_type == 'dpt':
            return create_dpt_head(net, has_conf=has_conf)
    raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
dust3r/dust3r/heads/dpt_head.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # dpt head implementation for DUST3R
6
+ # Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
7
+ # or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
8
+ # the forward function also takes as input a dictionnary img_info with key "height" and "width"
9
+ # for PixelwiseTask, the output will be of dimension B x num_channels x H x W
10
+ # --------------------------------------------------------
11
+ from einops import rearrange
12
+ from typing import List
13
+ import torch
14
+ import torch.nn as nn
15
+ from dust3r.heads.postprocess import postprocess
16
+ import dust3r.utils.path_to_croco # noqa: F401
17
+ from models.dpt_block import DPTOutputAdapter # noqa
18
+
19
+
20
class DPTOutputAdapter_fix(DPTOutputAdapter):
    """
    Adapt croco's DPTOutputAdapter implementation for dust3r:
    remove duplicated weights, and fix forward for dust3r
    """

    def init(self, dim_tokens_enc=768):
        # Run croco's initialization, then drop the act_*_postprocess modules:
        # they duplicate entries already held in self.act_postprocess, so
        # keeping both would store the same weights twice in checkpoints.
        super().init(dim_tokens_enc)
        # these are duplicated weights
        del self.act_1_postprocess
        del self.act_2_postprocess
        del self.act_3_postprocess
        del self.act_4_postprocess

    def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
        """Fuse the hooked token maps into one dense output map.

        Args:
            encoder_tokens: list of per-layer token tensors (B x N x C).
            image_size: (H, W) of the input image; falls back to self.image_size.

        Returns:
            Dense prediction map produced by self.head (B x num_channels x H' x W').
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        # H, W = input_info['image_size']
        image_size = self.image_size if image_size is None else image_size
        H, W = image_size
        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Hook decoder onto 4 layers from specified ViT layers
        layers = [encoder_tokens[hook] for hook in self.hooks]

        # Extract only task-relevant tokens and ignore global tokens.
        layers = [self.adapt_tokens(l) for l in layers]

        # Reshape tokens to spatial representation
        layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]

        layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
        # Project layers to chosen feature dim
        layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]

        # Fuse layers using refinement stages, coarse to fine; the slice on
        # path_4 crops it to the spatial size of the next level before fusion.
        path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        # Output head
        out = self.head(path_1)

        return out
66
+
67
+
68
class PixelwiseTaskWithDPT(nn.Module):
    """ DPT module for dust3r, can return 3D points + confidence for all pixels"""

    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
                 output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
        super(PixelwiseTaskWithDPT, self).__init__()
        self.return_all_layers = True  # backbone needs to return all layers
        # postprocess: callable mapping the raw B x C x H x W map to a
        # dict(pts3d[, conf]); depth_mode/conf_mode are forwarded to it.
        self.postprocess = postprocess
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        assert n_cls_token == 0, "Not implemented"
        dpt_args = dict(output_width_ratio=output_width_ratio,
                        num_channels=num_channels,
                        **kwargs)
        if hooks_idx is not None:
            dpt_args.update(hooks=hooks_idx)
        self.dpt = DPTOutputAdapter_fix(**dpt_args)
        # initialize immediately when the token dimensions are already known
        dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
        self.dpt.init(**dpt_init_args)

    def forward(self, x, img_info):
        # x: list of token tensors (one per hooked layer); img_info: (H, W)
        out = self.dpt(x, image_size=(img_info[0], img_info[1]))
        if self.postprocess:
            out = self.postprocess(out, self.depth_mode, self.conf_mode)
        return out
94
+
95
+
96
def create_dpt_head(net, has_conf=False):
    """
    return PixelwiseTaskWithDPT for given net params
    """
    # the hooks below index decoder layers up to net.dec_depth
    assert net.dec_depth > 9
    l2 = net.dec_depth
    feature_dim = 256
    last_dim = feature_dim//2
    out_nchan = 3  # (x, y, z) per pixel; +1 channel if has_conf
    ed = net.enc_embed_dim
    dd = net.dec_embed_dim
    # hook the encoder output (index 0) plus decoder layers at 1/2, 3/4 and
    # full depth; token dims are [enc, dec, dec, dec] accordingly
    return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
                                feature_dim=feature_dim,
                                last_dim=last_dim,
                                hooks_idx=[0, l2*2//4, l2*3//4, l2],
                                dim_tokens=[ed, dd, dd, dd],
                                postprocess=postprocess,
                                depth_mode=net.depth_mode,
                                conf_mode=net.conf_mode,
                                head_type='regression')
dust3r/dust3r/heads/linear_head.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # linear head implementation for DUST3R
6
+ # --------------------------------------------------------
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from dust3r.heads.postprocess import postprocess
10
+
11
+
12
class LinearPts3d (nn.Module):
    """
    Linear head for dust3r
    Each token outputs: - 16x16 3D points (+ confidence)
    """

    def __init__(self, net, has_conf=False):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.has_conf = has_conf

        # one linear layer maps each decoder token to a full patch of
        # (3 + has_conf) channels x patch_size x patch_size values
        self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)

    def setup(self, croconet):
        # nothing to set up for the linear head (kept for interface parity)
        pass

    def forward(self, decout, img_shape):
        """decout: list of decoder token tensors; img_shape: (H, W). Returns dict(pts3d[, conf])."""
        H, W = img_shape
        tokens = decout[-1]  # only the last decoder layer is used
        B, S, D = tokens.shape

        # extract 3D points
        feat = self.proj(tokens)  # B,S,D
        # tokens -> channel-first grid of patches, then unshuffle to full resolution
        feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
        feat = F.pixel_shuffle(feat, self.patch_size)  # B,3,H,W

        # permute + norm depth
        return postprocess(feat, self.depth_mode, self.conf_mode)
dust3r/dust3r/heads/postprocess.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # post process function for all heads: extract 3D points/confidence from output
6
+ # --------------------------------------------------------
7
+ import torch
8
+
9
+
10
def postprocess(out, depth_mode, conf_mode):
    """Extract 3D points (and optionally confidence) from a head output.

    Args:
        out: raw head output, B x C x H x W (C = 3, or 4 with confidence).
        depth_mode: (mode, vmin, vmax) controlling point regression.
        conf_mode: (mode, vmin, vmax) for confidence, or None to skip it.

    Returns:
        dict with 'pts3d' (B,H,W,3) and, when conf_mode is set, 'conf' (B,H,W).
    """
    fmap = out.permute(0, 2, 3, 1)  # channels last: B,H,W,C
    result = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))
    if conf_mode is not None:
        result['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
    return result


def reg_dense_depth(xyz, mode):
    """Map raw (x, y, z) activations to 3D points according to `mode`."""
    mode, vmin, vmax = mode

    # only the unbounded setting is supported here
    no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
    assert no_bounds

    if mode == 'linear':
        return xyz if no_bounds else xyz.clip(min=vmin, max=vmax)

    # split into direction and distance-to-origin, then re-expand the
    # distance through the chosen non-linearity
    d = xyz.norm(dim=-1, keepdim=True)
    xyz = xyz / d.clip(min=1e-8)

    if mode == 'square':
        return xyz * d.square()
    if mode == 'exp':
        return xyz * torch.expm1(d)
    raise ValueError(f'bad {mode=}')


def reg_dense_conf(x, mode):
    """Map raw confidence activations into [vmin, vmax] according to `mode`."""
    mode, vmin, vmax = mode
    if mode == 'exp':
        return vmin + x.exp().clip(max=vmax-vmin)
    if mode == 'sigmoid':
        return (vmax - vmin) * torch.sigmoid(x) + vmin
    raise ValueError(f'bad {mode=}')
dust3r/dust3r/image_pairs.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilities needed to load image pairs
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True):
    """Build the list of image pairs to feed the pairwise network.

    Args:
        imgs: list of images (any objects; typically dicts carrying an 'idx').
        scene_graph: pairing strategy:
            'complete'   -> every unordered pair (i, j) with i > j
            'swin[-k]'   -> sliding window of size k (default 3) with loop closure
            'oneref[-r]' -> every image paired with reference image r (default 0)
            'pairs'      -> consecutive disjoint pairs (0,1), (2,3), ...
        prefilter: optional 'seqN' / 'cycN' string to drop pairs whose frames
            are more than N apart (cyclically for 'cyc').
        symmetrize: if True, every pair is also added in reversed order.

    Returns:
        list of (img_a, img_b) tuples.
    """
    pairs = []

    if scene_graph == 'complete':  # complete graph
        for i in range(len(imgs)):
            for j in range(i):
                pairs.append((imgs[i], imgs[j]))

    elif scene_graph.startswith('swin'):
        winsize = int(scene_graph.split('-')[1]) if '-' in scene_graph else 3
        for i in range(len(imgs)):
            # BUGFIX: start at 1 (not 0) so an image is never paired with
            # itself — a (img, img) pair is degenerate for stereo matching.
            for j in range(1, winsize + 1):
                idx = (i + j) % len(imgs)  # explicit loop closure
                pairs.append((imgs[i], imgs[idx]))

    elif scene_graph.startswith('oneref'):
        refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0
        for j in range(len(imgs)):
            if j != refid:
                pairs.append((imgs[refid], imgs[j]))

    elif scene_graph == 'pairs':
        assert len(imgs) % 2 == 0
        for i in range(0, len(imgs), 2):
            pairs.append((imgs[i], imgs[i+1]))

    if symmetrize:
        pairs += [(img2, img1) for img1, img2 in pairs]

    # now, remove edges
    if isinstance(prefilter, str) and prefilter.startswith('seq'):
        pairs = filter_pairs_seq(pairs, int(prefilter[3:]))

    if isinstance(prefilter, str) and prefilter.startswith('cyc'):
        pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)

    return pairs
48
+
49
+
50
def sel(x, kept):
    """Index-select `kept` entries from a dict (recursively), tensor/array, or sequence."""
    if isinstance(x, dict):
        return {key: sel(value, kept) for key, value in x.items()}
    if isinstance(x, (torch.Tensor, np.ndarray)):
        return x[kept]
    if isinstance(x, (tuple, list)):
        return type(x)([x[i] for i in kept])
    # any other type falls through and yields None (original behavior)
57
+
58
+
59
+ def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
60
+ # number of images
61
+ n = max(max(e) for e in edges)+1
62
+
63
+ kept = []
64
+ for e, (i, j) in enumerate(edges):
65
+ dis = abs(i-j)
66
+ if cyclic:
67
+ dis = min(dis, abs(i+n-j), abs(i-n-j))
68
+ if dis <= seq_dis_thr:
69
+ kept.append(e)
70
+ return kept
71
+
72
+
73
def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
    """Drop pairs whose two images are more than `seq_dis_thr` frames apart."""
    edges = [(img_a['idx'], img_b['idx']) for img_a, img_b in pairs]
    keep = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    return [pairs[i] for i in keep]
77
+
78
+
79
def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
    """Filter batched views/predictions, keeping entries whose frames are close in the sequence."""
    edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges')
    # apply the same selection to every structure
    return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)
dust3r/dust3r/inference.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilities needed for the inference
6
+ # --------------------------------------------------------
7
+ import tqdm
8
+ import torch
9
+ from dust3r.utils.device import to_cpu, collate_with_cat
10
+ from dust3r.model import AsymmetricCroCo3DStereo, inf # noqa: F401, needed when loading the model
11
+ from dust3r.utils.misc import invalid_to_nans
12
+ from dust3r.utils.geometry import depthmap_to_pts3d, geotrf
13
+
14
+
15
def load_model(model_path, device):
    """Load a DUSt3R checkpoint and instantiate the network on `device`.

    The model constructor call is stored as a string in ckpt['args'].model;
    it is patched to use the simple patch embedding and to force
    landscape_only=False before being evaluated.
    """
    print('... loading model from', model_path)
    ckpt = torch.load(model_path, map_location='cpu')
    args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if 'landscape_only' not in args:
        # append the flag just before the closing parenthesis of the ctor string
        args = args[:-1] + ', landscape_only=False)'
    else:
        args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
    assert "landscape_only=False" in args
    print(f"instantiating : {args}")
    # SECURITY: eval() executes arbitrary code stored in the checkpoint —
    # only load checkpoints from trusted sources.
    net = eval(args)
    # strict=False: report (not fail on) missing/unexpected keys
    print(net.load_state_dict(ckpt['model'], strict=False))
    return net.to(device)
28
+
29
+
30
+ def _interleave_imgs(img1, img2):
31
+ res = {}
32
+ for key, value1 in img1.items():
33
+ value2 = img2[key]
34
+ if isinstance(value1, torch.Tensor):
35
+ value = torch.stack((value1, value2), dim=1).flatten(0, 1)
36
+ else:
37
+ value = [x for pair in zip(value1, value2) for x in pair]
38
+ res[key] = value
39
+ return res
40
+
41
+
42
def make_batch_symmetric(batch):
    """Symmetrize a (view1, view2) batch so each pair also appears swapped."""
    view1, view2 = batch
    return _interleave_imgs(view1, view2), _interleave_imgs(view2, view1)
46
+
47
+
48
def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    """Run the model on one (view1, view2) batch and optionally compute the loss.

    Args:
        batch: (view1, view2) dicts of tensors/lists.
        model: network taking (view1, view2) and returning (pred1, pred2).
        criterion: loss callable, or None for inference only (loss=None).
        device: target device for the known tensor fields.
        symmetrize_batch: if True, duplicate every pair in both orders.
        use_amp: enable CUDA automatic mixed precision for the forward pass.
        ret: if set, return only this key of the result dict.

    Returns:
        dict(view1, view2, pred1, pred2, loss), or result[ret] when ret is given.
    """
    view1, view2 = batch
    # move every known tensor field to the target device
    for view in batch:
        for name in 'img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres'.split():  # pseudo_focal
            if name not in view:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        view1, view2 = make_batch_symmetric(batch)

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = model(view1, view2)

        # loss is supposed to be symmetric
        # (computed in full precision even when AMP is on)
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None

    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result
68
+
69
+
70
@torch.no_grad()
def inference(pairs, model, device, batch_size=8):
    """Run pairwise inference over all image pairs.

    Args:
        pairs: list of (view1, view2) dicts (see make_pairs / load_images).
        model: DUSt3R network.
        device: device used for the forward passes.
        batch_size: pairs per forward pass; forced to 1 when image sizes differ.

    Returns:
        Collated result dict (view1, view2, pred1, pred2, loss=None) on CPU.
    """
    print(f'>> Inference with model on {len(pairs)} image pairs')
    result = []

    # first, check if all images have the same size
    multiple_shapes = not (check_if_same_size(pairs))
    if multiple_shapes:  # force bs=1
        batch_size = 1

    for i in tqdm.trange(0, len(pairs), batch_size):
        res = loss_of_one_batch(collate_with_cat(pairs[i:i+batch_size]), model, None, device)
        # move results to CPU immediately to avoid accumulating GPU memory
        result.append(to_cpu(res))

    # lists=True keeps per-pair lists when shapes cannot be concatenated
    result = collate_with_cat(result, lists=multiple_shapes)

    torch.cuda.empty_cache()
    return result
88
+
89
+
90
def check_if_same_size(pairs):
    """True if every first view shares one (H, W) and every second view shares one (H, W)."""
    first_shapes = [v1['img'].shape[-2:] for v1, _ in pairs]
    second_shapes = [v2['img'].shape[-2:] for _, v2 in pairs]
    same_first = all(s == first_shapes[0] for s in first_shapes)
    same_second = all(s == second_shapes[0] for s in second_shapes)
    return same_first and same_second
94
+
95
+
96
def get_pred_pts3d(gt, pred, use_pose=False):
    """Extract the predicted 3D point map from a prediction dict.

    Three head-output formats are supported:
      - 'depth' + 'pseudo_focal': points re-projected from the depth map
        (principal point taken from GT intrinsics when available);
      - 'pts3d': points already expressed in this camera's frame;
      - 'pts3d_in_other_view': points already in the other camera's frame
        (requires use_pose=True and is returned unchanged).
    When use_pose is True, points are additionally mapped through
    pred['camera_pose'].
    """
    if 'depth' in pred and 'pseudo_focal' in pred:
        try:
            pp = gt['camera_intrinsics'][..., :2, 2]
        except KeyError:
            pp = None
        pts3d = depthmap_to_pts3d(**pred, pp=pp)

    elif 'pts3d' in pred:
        # pts3d from my camera
        pts3d = pred['pts3d']

    elif 'pts3d_in_other_view' in pred:
        # pts3d from the other camera, already transformed
        assert use_pose is True
        return pred['pts3d_in_other_view']  # return!

    if use_pose:
        camera_pose = pred.get('camera_pose')
        assert camera_pose is not None
        pts3d = geotrf(camera_pose, pts3d)

    return pts3d
119
+
120
+
121
def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None):
    """Estimate, per batch element, the scale s minimizing || pr - s * gt ||.

    Args:
        gt_pts1, gt_pts2: GT point maps (B,H,W,3); gt_pts2 may be None.
        pr_pts1, pr_pts2: predicted point maps with matching shapes.
        fit_mode: 'avg*' (L2 closed form), 'median*', or 'weiszfeld*'
            (robust iterative re-weighted least squares); a '*stop_grad'
            suffix detaches the result from the autograd graph.
        valid1, valid2: validity masks; invalid pixels become NaN and are
            ignored by the nan-aware reductions.

    Returns:
        (B,) tensor of scales, clipped to >= 1e-3.
    """
    assert gt_pts1.ndim == pr_pts1.ndim == 4
    assert gt_pts1.shape == pr_pts1.shape
    if gt_pts2 is not None:
        assert gt_pts2.ndim == pr_pts2.ndim == 4
        assert gt_pts2.shape == pr_pts2.shape

    # concat the pointcloud
    nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
    nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None

    pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
    pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None

    all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1
    all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1

    # per-point dot products used by the closed-form scale estimates
    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
    dot_gt_gt = all_gt.square().sum(dim=-1)

    if fit_mode.startswith('avg'):
        # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1)
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
    elif fit_mode.startswith('median'):
        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
    elif fit_mode.startswith('weiszfeld'):
        # init scaling with l2 closed form
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
        # iterative re-weighted least-squares
        for iter in range(10):
            # re-weighting by inverse of distance
            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
            # print(dis.nanmean(-1))
            # NOTE: clip_ is in-place on dis (dis is a fresh tensor, so safe)
            w = dis.clip_(min=1e-8).reciprocal()
            # update the scaling with the new weights
            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
    else:
        raise ValueError(f'bad {fit_mode=}')

    if fit_mode.endswith('stop_grad'):
        scaling = scaling.detach()

    scaling = scaling.clip(min=1e-3)
    # assert scaling.isfinite().all(), bb()
    return scaling
dust3r/dust3r/losses.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Implementation of DUSt3R training losses
6
+ # --------------------------------------------------------
7
+ from copy import copy, deepcopy
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from dust3r.inference import get_pred_pts3d, find_opt_scaling
12
+ from dust3r.utils.geometry import inv, geotrf, normalize_pointcloud
13
+ from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale
14
+
15
+
16
def Sum(*losses_and_masks):
    """Sum scalar losses; pass pixel-level (loss, mask) pairs through unchanged."""
    first_loss, _ = losses_and_masks[0]
    if first_loss.ndim > 0:
        # pixel-level losses: the caller handles per-pixel aggregation
        return losses_and_masks
    # scalar losses: accumulate all of them
    total = first_loss
    for other_loss, _ in losses_and_masks[1:]:
        total = total + other_loss
    return total
26
+
27
+
28
class LLoss (nn.Module):
    """Base class for L-norm losses over batched point maps.

    Subclasses implement distance(a, b); reduction is 'mean', 'sum' or 'none'.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, a, b):
        assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}'
        dist = self.distance(a, b)
        assert dist.ndim == a.ndim-1  # distance collapses the last (coordinate) axis
        if self.reduction == 'none':
            return dist
        if self.reduction == 'sum':
            return dist.sum()
        if self.reduction == 'mean':
            # empty input -> zero loss instead of NaN
            return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
        raise ValueError(f'bad {self.reduction=} mode')

    def distance(self, a, b):
        # to be provided by subclasses
        raise NotImplementedError()
50
+
51
+
52
class L21Loss (LLoss):
    """Per-pixel Euclidean distance between two 3D point maps."""

    def distance(self, a, b):
        # L2 norm over the last (xyz) axis -> one value per pixel
        return torch.norm(a - b, dim=-1)


L21 = L21Loss()  # default pixel-level criterion shared by the losses below
60
+
61
+
62
class Criterion (nn.Module):
    """Wrapper holding a pixel-level LLoss criterion.

    Combined with MultiLoss (see Regr3D), it lets losses be chained while
    with_reduction() switches every criterion in the chain to per-pixel
    ('none') reduction.
    """

    def __init__(self, criterion=None):
        super().__init__()
        # BUGFIX: the original assert message appended an undefined bb() call,
        # which raised NameError instead of AssertionError when it failed.
        assert isinstance(criterion, LLoss), f'{criterion} is not a proper criterion!'
        self.criterion = copy(criterion)

    def get_name(self):
        return f'{type(self).__name__}({self.criterion})'

    def with_reduction(self, mode):
        """Return a deep copy whose whole loss chain uses per-pixel reduction.

        NOTE(review): `mode` is accepted but unused (reduction is always set
        to 'none'); kept for interface compatibility with callers.
        """
        res = loss = deepcopy(self)
        while loss is not None:
            assert isinstance(loss, Criterion)
            loss.criterion.reduction = 'none'  # make it return the loss for each sample
            # walking _loss2 assumes the instance is also a MultiLoss
            loss = loss._loss2  # we assume loss is a Multiloss
        return res
78
+
79
+
80
class MultiLoss (nn.Module):
    """ Easily combinable losses (also keep track of individual loss values):
        loss = MyLoss1() + 0.1*MyLoss2()
    Usage:
        Inherit from this class and override get_name() and compute_loss()
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1      # multiplicative weight of this loss term
        self._loss2 = None   # next term in the chain (linked list), or None

    def compute_loss(self, *args, **kwargs):
        # subclasses return either a loss tensor, or a (loss, details_dict) tuple
        raise NotImplementedError()

    def get_name(self):
        raise NotImplementedError()

    def __mul__(self, alpha):
        # alpha * loss -> shallow copy carrying the weight
        assert isinstance(alpha, (int, float))
        res = copy(self)
        res._alpha = alpha
        return res
    __rmul__ = __mul__  # same

    def __add__(self, loss2):
        # loss1 + loss2 -> append loss2 at the end of the chain.
        # NOTE(review): copy(self) is shallow, so if self already has a chain
        # the tail nodes are shared with the original and mutated here.
        assert isinstance(loss2, MultiLoss)
        res = cur = copy(self)
        # find the end of the chain
        while cur._loss2 is not None:
            cur = cur._loss2
        cur._loss2 = loss2
        return res

    def __repr__(self):
        name = self.get_name()
        if self._alpha != 1:
            name = f'{self._alpha:g}*{name}'
        if self._loss2:
            name = f'{name} + {self._loss2}'
        return name

    def forward(self, *args, **kwargs):
        # evaluate this term...
        loss = self.compute_loss(*args, **kwargs)
        if isinstance(loss, tuple):
            loss, details = loss
        elif loss.ndim == 0:
            details = {self.get_name(): float(loss)}
        else:
            details = {}
        loss = loss * self._alpha

        # ...then recurse into the rest of the chain and accumulate
        if self._loss2:
            loss2, details2 = self._loss2(*args, **kwargs)
            loss = loss + loss2
            details |= details2

        return loss, details
138
+
139
+
140
class Regr3D (Criterion, MultiLoss):
    """ Ensure that all 3D points are correct.
    Asymmetric loss: view1 is supposed to be the anchor.

    P1 = RT1 @ D1
    P2 = RT2 @ D2
    loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1)
    loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2)
          = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2)
    """

    def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False):
        super().__init__(criterion)
        # norm_mode: pointcloud normalization mode, or falsy to disable
        self.norm_mode = norm_mode
        # gt_scale=True: GT keeps its own scale (only predictions normalized)
        self.gt_scale = gt_scale

    def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None):
        """Return GT and predicted point maps in view1's frame, plus validity masks."""
        # everything is normalized w.r.t. camera of view1
        in_camera1 = inv(gt1['camera_pose'])
        gt_pts1 = geotrf(in_camera1, gt1['pts3d'])  # B,H,W,3
        gt_pts2 = geotrf(in_camera1, gt2['pts3d'])  # B,H,W,3

        valid1 = gt1['valid_mask'].clone()
        valid2 = gt2['valid_mask'].clone()

        if dist_clip is not None:
            # points that are too far-away == invalid
            dis1 = gt_pts1.norm(dim=-1)  # (B, H, W)
            dis2 = gt_pts2.norm(dim=-1)  # (B, H, W)
            valid1 = valid1 & (dis1 <= dist_clip)
            valid2 = valid2 & (dis2 <= dist_clip)

        pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False)
        pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True)

        # normalize 3d points
        if self.norm_mode:
            pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2)
        if self.norm_mode and not self.gt_scale:
            gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2)

        return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, {}

    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \
            self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw)
        # loss on img1 side
        l1 = self.criterion(pred_pts1[mask1], gt_pts1[mask1])
        # loss on gt2 side
        l2 = self.criterion(pred_pts2[mask2], gt_pts2[mask2])
        self_name = type(self).__name__
        details = {self_name+'_pts3d_1': float(l1.mean()), self_name+'_pts3d_2': float(l2.mean())}
        return Sum((l1, mask1), (l2, mask2)), (details | monitoring)
193
+
194
+
195
class ConfLoss (MultiLoss):
    """ Weighted regression by learned confidence.
    Assuming the input pixel_loss is a pixel-level regression loss.

    Principle:
        high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
        low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)

    alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        # the wrapped loss must return one value per pixel (reduction='none')
        self.pixel_loss = pixel_loss.with_reduction('none')

    def get_name(self):
        return f'ConfLoss({self.pixel_loss})'

    def get_conf_log(self, x):
        return x, torch.log(x)

    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
        # compute per-pixel loss
        ((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
        # BUGFIX: builtin print() has no 'force' keyword (that belongs to
        # patched distributed-training print wrappers), so the original code
        # raised TypeError exactly when these warnings should have been shown.
        if loss1.numel() == 0:
            print('NO VALID POINTS in img1')
        if loss2.numel() == 0:
            print('NO VALID POINTS in img2')

        # weight by confidence
        conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1])
        conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2])
        conf_loss1 = loss1 * conf1 - self.alpha * log_conf1
        conf_loss2 = loss2 * conf2 - self.alpha * log_conf2

        # average + nan protection (in case of no valid pixels at all)
        conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0
        conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0

        # NOTE(review): key naming is inconsistent ('conf_loss_1' vs 'conf_loss2')
        # but kept as-is since downstream logging may rely on these exact keys.
        return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details)
237
+
238
+
239
class Regr3D_ShiftInv (Regr3D):
    """ Same as Regr3D but invariant to depth shift.
    """

    def get_all_pts3d(self, gt1, gt2, pred1, pred2):
        # compute unnormalized points
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \
            super().get_all_pts3d(gt1, gt2, pred1, pred2)

        # compute median depth
        gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2]
        pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2]
        gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None]

        # subtract the median depth; the -= acts in-place on views of the
        # point maps, so the full tensors returned below are shifted too
        gt_z1 -= gt_shift_z
        gt_z2 -= gt_shift_z
        pred_z1 -= pred_shift_z
        pred_z2 -= pred_shift_z

        # monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach())
        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring
262
+
263
+
264
class Regr3D_ScaleInv (Regr3D):
    """ Same as Regr3D but invariant to depth scale.
    if gt_scale == True: enforce the prediction to take the same scale than GT
    """

    def get_all_pts3d(self, gt1, gt2, pred1, pred2):
        # compute depth-normalized points
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = super().get_all_pts3d(gt1, gt2, pred1, pred2)

        # measure scene scale
        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2)

        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        if self.gt_scale:
            # rescale predictions to the GT scale (in-place on the point maps)
            pred_pts1 *= gt_scale / pred_scale
            pred_pts2 *= gt_scale / pred_scale
            # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean())
        else:
            # make both sides scale-free (in-place on the point maps)
            gt_pts1 /= gt_scale
            gt_pts2 /= gt_scale
            pred_pts1 /= pred_scale
            pred_pts2 /= pred_scale
            # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach())

        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring
293
+
294
+
295
class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv):
    # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv
    # (via the MRO: ScaleInv.get_all_pts3d -> ShiftInv.get_all_pts3d -> Regr3D.get_all_pts3d)
    pass
dust3r/dust3r/model.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # DUSt3R model class
6
+ # --------------------------------------------------------
7
+ from copy import deepcopy
8
+ import torch
9
+
10
+ from .utils.misc import fill_default_args, freeze_all_params, is_symmetrized, interleave, transpose_to_landscape
11
+ from .heads import head_factory
12
+ from dust3r.patch_embed import get_patch_embed
13
+
14
+ import dust3r.utils.path_to_croco # noqa: F401
15
+ from models.croco import CroCoNet # noqa
16
+ inf = float('inf')
17
+
18
+
19
class AsymmetricCroCo3DStereo (CroCoNet):
    """ Two siamese encoders, followed by two decoders.
    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).

    Args:
        output_mode: what the downstream heads predict (default 'pts3d').
        head_type: head implementation name, resolved by head_factory (e.g. 'linear').
        depth_mode: (activation, min, max) spec used to post-process predicted depth.
        conf_mode: (activation, min, max) spec for the confidence output; truthy => heads output confidence.
        freeze: which submodules to freeze ('none', 'mask' or 'encoder').
        landscape_only: if True, heads always run in landscape orientation
            (transpose_to_landscape transposes inputs/outputs as needed).
        patch_embed_cls: patch-embedding implementation name.  # PatchEmbedDust3R or ManyAR_PatchEmbed
        **croco_kwargs: forwarded to the CroCoNet base constructor.
    """

    def __init__(self,
                 output_mode='pts3d',
                 head_type='linear',
                 depth_mode=('exp', -inf, inf),
                 conf_mode=('exp', 1, inf),
                 freeze='none',
                 landscape_only=True,
                 patch_embed_cls='PatchEmbedDust3R',  # PatchEmbedDust3R or ManyAR_PatchEmbed
                 **croco_kwargs):
        # must be set before super().__init__() because CroCoNet's constructor
        # calls _set_patch_embed(), which reads self.patch_embed_cls
        self.patch_embed_cls = patch_embed_cls
        # remember the full (defaults included) croco config, e.g. for checkpointing
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)

        # dust3r specific initialization: a second, independent decoder for view2
        self.dec_blocks2 = deepcopy(self.dec_blocks)
        self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)
        self.set_freeze(freeze)

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        # override CroCoNet's patch embed with the dust3r variant selected by name
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)

    def load_state_dict(self, ckpt, **kw):
        """Load a checkpoint; if it was trained with a single decoder (e.g. a vanilla
        CroCo checkpoint), duplicate all decoder weights for the second decoder."""
        new_ckpt = dict(ckpt)
        if not any(k.startswith('dec_blocks2') for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith('dec_blocks'):
                    new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        """Freeze (requires_grad=False) the submodules selected by `freeze`.
        Raises KeyError for an unknown value."""
        self.freeze = freeze
        to_be_frozen = {
            'none': [],
            'mask': [self.mask_token],
            'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """ No prediction head """
        # deliberately disables CroCoNet's own prediction head; dust3r uses
        # set_downstream_head() instead
        return

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size,
                            **kw):
        """Allocate the two downstream heads and wrap them so they always operate in landscape."""
        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \
            f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        # allocate heads (one per view)
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        # magic wrapper: transposes portrait inputs to landscape before the head if activated
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)

    def _encode_image(self, image, true_shape):
        """Run one image through patch-embed + encoder. Returns (tokens, positions, None)."""
        # embed the image into patches (x has size B x Npatches x C)
        x, pos = self.patch_embed(image, true_shape=true_shape)

        # add positional embedding without cls token
        # (dust3r relies on the per-token `pos` passed to each block instead)
        assert self.enc_pos_embed is None

        # now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            x = blk(x, pos)

        x = self.enc_norm(x)
        return x, pos, None

    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
        """Encode both images; when they share a spatial size, batch them in one encoder pass."""
        if img1.shape[-2:] == img2.shape[-2:]:
            out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0),
                                             torch.cat((true_shape1, true_shape2), dim=0))
            out, out2 = out.chunk(2, dim=0)
            pos, pos2 = pos.chunk(2, dim=0)
        else:
            out, pos, _ = self._encode_image(img1, true_shape1)
            out2, pos2, _ = self._encode_image(img2, true_shape2)
        return out, out2, pos, pos2

    def _encode_symmetrized(self, view1, view2):
        """Encode a pair of views; if the batch is symmetrized (each pair appears
        in both orders), encode only every other pair and interleave the results."""
        img1 = view1['img']
        img2 = view2['img']
        B = img1.shape[0]
        # Recover true_shape when available, otherwise assume that the img shape is the true one
        shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))
        shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))
        # warning! maybe the images have different portrait/landscape orientations

        if is_symmetrized(view1, view2):
            # computing half of forward pass
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2])
            feat1, feat2 = interleave(feat1, feat2)
            pos1, pos2 = interleave(pos1, pos2)
        else:
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2)

        return (shape1, shape2), (feat1, feat2), (pos1, pos2)

    def _decoder(self, f1, pos1, f2, pos2):
        """Run the two cross-attending decoders; returns per-layer outputs for both views."""
        final_output = [(f1, f2)]  # before projection

        # project to decoder dim
        f1 = self.decoder_embed(f1)
        f2 = self.decoder_embed(f2)

        final_output.append((f1, f2))
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # img1 side: blk1 sees (f1, f2); [::+1] keeps the order
            f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
            # img2 side: blk2 sees (f2, f1); [::-1] swaps the order
            f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
            # store the result
            final_output.append((f1, f2))

        # normalize last output
        del final_output[1]  # duplicate with final_output[0]
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
        return zip(*final_output)

    def _downstream_head(self, head_num, decout, img_shape):
        # decout: list of per-layer decoder outputs, each (B, S, D)
        B, S, D = decout[-1].shape
        # img_shape = tuple(map(int, img_shape))
        head = getattr(self, f'head{head_num}')
        return head(decout, img_shape)

    def forward(self, view1, view2):
        """Predict pointmaps for both views, expressed in view1's camera frame.

        view1/view2: dicts with at least 'img' (B,3,H,W) and optionally 'true_shape' (B,2).
        Returns (res1, res2): res1 has 'pts3d' (+confidence), res2 has 'pts3d_in_other_view'.
        """
        # encode the two images --> B,S,D
        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2)

        # combine all ref images into object-centric representation
        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)

        # heads run in full precision even under autocast
        with torch.cuda.amp.autocast(enabled=False):
            res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

        res2['pts3d_in_other_view'] = res2.pop('pts3d')  # predict view2's pts3d in view1's frame
        return res1, res2
dust3r/dust3r/optim_factory.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # optimization functions
6
+ # --------------------------------------------------------
7
+
8
+
9
def adjust_learning_rate_by_lr(optimizer, lr):
    """Write the base learning rate `lr` into every param group of `optimizer`,
    multiplying by the group's 'lr_scale' when one is defined."""
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
dust3r/dust3r/patch_embed.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # PatchEmbed implementation for DUST3R,
6
+ # in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio
7
+ # --------------------------------------------------------
8
+ import torch
9
+ import dust3r.utils.path_to_croco # noqa: F401
10
+ from models.blocks import PatchEmbed # noqa
11
+
12
+
13
def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
    """Instantiate a patch-embedding module by name.

    Args:
        patch_embed_cls: 'PatchEmbedDust3R' or 'ManyAR_PatchEmbed'.
        img_size, patch_size, enc_embed_dim: forwarded to the constructor
            (the input channel count is fixed to 3).
    Returns:
        the instantiated patch-embedding module.
    """
    # explicit registry instead of eval() on the name string: safer and greppable
    known_classes = {
        'PatchEmbedDust3R': PatchEmbedDust3R,
        'ManyAR_PatchEmbed': ManyAR_PatchEmbed,
    }
    assert patch_embed_cls in known_classes, f'unknown {patch_embed_cls=}'
    patch_embed = known_classes[patch_embed_cls](img_size, patch_size, 3, enc_embed_dim)
    return patch_embed
17
+
18
+
19
class PatchEmbedDust3R(PatchEmbed):
    """Patch embedding that also returns per-token integer positions."""

    def forward(self, x, **kw):
        B, C, H, W = x.shape
        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        # convolutional projection to patch tokens, then token positions from the grid size
        feat = self.proj(x)
        pos = self.position_getter(B, feat.size(2), feat.size(3), feat.device)
        if self.flatten:
            feat = feat.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return self.norm(feat), pos
30
+
31
+
32
class ManyAR_PatchEmbed (PatchEmbed):
    """ Handle images with non-square aspect ratio.
        All images in the same batch have the same aspect ratio.
        true_shape = [(height, width) ...] indicates the actual shape of each image.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        # stash embed_dim before super().__init__() so forward() can size its output buffer
        self.embed_dim = embed_dim
        super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)

    def forward(self, img, true_shape):
        """Embed a batch stored in landscape orientation; images whose true_shape is
        portrait are transposed before projection so their tokens/positions match
        their real orientation.

        Args:
            img: (B, C, H, W) batch, always stored with W >= H.
            true_shape: (B, 2) tensor of per-image (height, width).
        Returns:
            x: (B, n_tokens, embed_dim) patch tokens.
            pos: (B, n_tokens, 2) int64 token positions.
        """
        B, C, H, W = img.shape
        assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'
        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}"

        # size expressed in tokens
        # NOTE(review): W is divided by patch_size[0] and H by patch_size[1]; this is
        # only equivalent to the usual (h, w) convention for square patches — confirm
        # before using non-square patch sizes.
        W //= self.patch_size[0]
        H //= self.patch_size[1]
        n_tokens = H * W

        height, width = true_shape.T
        is_landscape = (width >= height)
        is_portrait = ~is_landscape

        # allocate result (same device/dtype as img for x; int64 for positions)
        x = img.new_zeros((B, n_tokens, self.embed_dim))
        pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)

        # linear projection, transposed if necessary; the masked in-place writes keep
        # each image's tokens at its original batch index
        x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()
        x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()

        # token positions follow the true orientation: (H, W) grid for landscape,
        # (W, H) grid for portrait
        pos[is_landscape] = self.position_getter(1, H, W, pos.device)
        pos[is_portrait] = self.position_getter(1, W, H, pos.device)

        x = self.norm(x)
        return x, pos
dust3r/dust3r/post_process.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilities for interpreting the DUST3R output
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ from dust3r.utils.geometry import xy_grid
10
+
11
+
12
def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0.5, max_focal=3.5):
    """ Reprojection method, for when the absolute depth is known:
    1) estimate the camera focal using a robust estimator
    2) reproject points onto true rays, minimizing a certain error

    Args:
        pts3d: (B, H, W, 3) predicted pointmap in the camera frame.
        pp: principal point, broadcastable to (B, 1, 2).
        focal_mode: 'median' (per-pixel focal votes, median) or 'weiszfeld'
            (iteratively re-weighted least squares).
        min_focal, max_focal: clipping range expressed as multiples of the focal
            corresponding to a 60 degree field of view.
    Returns:
        (B,) tensor of estimated focal lengths in pixels.
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3

    # centered pixel grid
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2)  # B,HW,2
    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == 'median':
        with torch.no_grad():
            # direct estimation of focal: pinhole projection gives u = f*x/z,
            # so each pixel votes f = u*z/x (and f = v*z/y)
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y

            # assume square pixels, hence same focal for X and Y
            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
            focal = torch.nanmedian(f_votes, dim=-1).values

    elif focal_mode == 'weiszfeld':
        # init focal with l2 closed form
        # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0)  # homogeneous (x,y,1)

        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)

        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)

        # iterative re-weighted least-squares (Weiszfeld iterations)
        for _ in range(10):  # renamed from `iter`, which shadowed the builtin
            # re-weighting by inverse of distance
            dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
            w = dis.clip(min=1e-8).reciprocal()
            # update the scaling with the new weights
            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f'bad {focal_mode=}')

    # clip around the focal of a 60-degree FOV camera (size / 1.1547005383792515)
    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))
    focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base)
    return focal
dust3r/dust3r/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/dust3r/utils/device.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilitary functions for DUSt3R
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
def todevice(batch, device, callback=None, non_blocking=False):
    ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: function that would be called on every sub-elements.
    non_blocking: forwarded to Tensor.to() for asynchronous transfers.
    '''
    if callback:
        batch = callback(batch)

    # fix: previously the recursive calls dropped `callback` and `non_blocking`,
    # so the callback only ran on the top-level object (contradicting the docstring)
    # and nested tensors were always transferred in blocking mode.
    if isinstance(batch, dict):
        return {k: todevice(v, device, callback=callback, non_blocking=non_blocking)
                for k, v in batch.items()}

    if isinstance(batch, (tuple, list)):
        return type(batch)(todevice(x, device, callback=callback, non_blocking=non_blocking)
                           for x in batch)

    x = batch
    if device == 'numpy':
        # torch -> numpy conversion (everything else passes through unchanged)
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        # numpy -> torch first, then move any tensor to the requested device
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
37
+
38
+
39
to_device = todevice  # alias

# convenience wrappers around todevice() for the most common targets
def to_numpy(x): return todevice(x, 'numpy')
def to_cpu(x): return todevice(x, 'cpu')
def to_cuda(x): return todevice(x, 'cuda')
45
+
46
+
47
def collate_with_cat(whatever, lists=False):
    """Recursively collate a batch structure.

    Dicts are collated per key; sequences of tensors/arrays are concatenated
    (or kept as flat lists when lists=True); sequences of scalars/strings pass
    through unchanged; sequences of tuples/dicts are collated element-wise.
    """
    if isinstance(whatever, dict):
        return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}

    elif isinstance(whatever, (tuple, list)):
        if not whatever:
            return whatever
        first = whatever[0]
        container = type(whatever)

        if first is None:
            return None
        if isinstance(first, (bool, float, int, str)):
            return whatever
        if isinstance(first, tuple):
            # transpose, then collate each column
            return container(collate_with_cat(group, lists=lists) for group in zip(*whatever))
        if isinstance(first, dict):
            return {key: collate_with_cat([entry[key] for entry in whatever], lists=lists) for key in first}

        if isinstance(first, torch.Tensor):
            return listify(whatever) if lists else torch.cat(whatever)
        if isinstance(first, np.ndarray):
            return listify(whatever) if lists else torch.cat([torch.from_numpy(arr) for arr in whatever])

        # otherwise, we just chain the sub-sequences together
        return sum(whatever, container())
73
+
74
+
75
def listify(elems):
    """Flatten one level: concatenate the sub-sequences of `elems` into a single list."""
    flat = []
    for group in elems:
        flat.extend(group)
    return flat
dust3r/dust3r/utils/geometry.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # geometry utilitary functions
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import numpy as np
9
+ from scipy.spatial import cKDTree as KDTree
10
+
11
+ from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
12
+ from dust3r.utils.device import to_numpy
13
+
14
+
15
+ def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
16
+ """ Output a (H,W,2) array of int32
17
+ with output[j,i,0] = i + origin[0]
18
+ output[j,i,1] = j + origin[1]
19
+ """
20
+ if device is None:
21
+ # numpy
22
+ arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
23
+ else:
24
+ # torch
25
+ arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
26
+ meshgrid, stack = torch.meshgrid, torch.stack
27
+ ones = lambda *a: torch.ones(*a, device=device)
28
+
29
+ tw, th = [arange(o, o+s, **arange_kw) for s, o in zip((W, H), origin)]
30
+ grid = meshgrid(tw, th, indexing='xy')
31
+ if homogeneous:
32
+ grid = grid + (ones((H, W)),)
33
+ if unsqueeze is not None:
34
+ grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
35
+ if cat_dim is not None:
36
+ grid = stack(grid, cat_dim)
37
+ return grid
38
+
39
+
40
def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a list of 3-D points.

    Trf: 3x3 or 4x4 projection matrix (typically a Homography), possibly batched
    pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the result is projected on the z=norm plane.

    Returns an array of projected points, reshaped like the input.
    """
    assert Trf.ndim >= 2
    # coerce pts into the same array family (and, for torch, the same dtype) as Trf
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code for the common batched-image case: Trf (B,d,d or d+1), pts (B,H,W,d)
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # pure linear transform
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d+1:
            # affine transform: linear part plus translation column
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim-2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1]+1 == Trf.shape[-1]:
            # homogeneous Trf applied to non-homogeneous points: x @ R^T + t
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            # fallback: transform column vectors directly
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
102
+
103
+
104
def inv(mat):
    """Return the inverse of a square matrix, dispatching on its type (torch or numpy)."""
    if isinstance(mat, torch.Tensor):
        result = torch.linalg.inv(mat)
    elif isinstance(mat, np.ndarray):
        result = np.linalg.inv(mat)
    else:
        raise ValueError(f'bad matrix type = {type(mat)}')
    return result
112
+
113
+
114
def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
    """
    Back-project a batched depthmap to 3D camera coordinates using per-pixel focals.

    Args:
        - depth (BxHxW array, or BxHxWxn for n depth hypotheses):
        - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
        - pp: optional (B,2) principal points; defaults to the image center.
    Returns:
        pointmap of absolute coordinates (BxHxWx3 array, or BxHxWx3xn)
    """

    if len(depth.shape) == 4:
        B, H, W, n = depth.shape
    else:
        B, H, W = depth.shape
        n = None

    # split the focal into X and Y components depending on its layout
    if len(pseudo_focal.shape) == 3:  # [B,H,W]
        pseudo_focalx = pseudo_focaly = pseudo_focal
    elif len(pseudo_focal.shape) == 4:  # [B,2,H,W] or [B,1,H,W]
        pseudo_focalx = pseudo_focal[:, 0]
        if pseudo_focal.shape[1] == 2:
            pseudo_focaly = pseudo_focal[:, 1]
        else:
            pseudo_focaly = pseudo_focalx
    else:
        raise NotImplementedError("Error, unknown input focal shape format.")

    assert pseudo_focalx.shape == depth.shape[:3]
    assert pseudo_focaly.shape == depth.shape[:3]
    # pixel grid, shaped (1,H,W) each for broadcasting over the batch
    grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]

    # set principal point
    if pp is None:
        grid_x = grid_x - (W-1)/2
        grid_y = grid_y - (H-1)/2
    else:
        grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
        grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]

    # standard pinhole back-projection: X = z*(u-cx)/fx, Y = z*(v-cy)/fy, Z = z
    if n is None:
        pts3d = torch.empty((B, H, W, 3), device=depth.device)
        pts3d[..., 0] = depth * grid_x / pseudo_focalx
        pts3d[..., 1] = depth * grid_y / pseudo_focaly
        pts3d[..., 2] = depth
    else:
        # multi-hypothesis depth: broadcast the grid over the trailing dimension
        pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
        pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
        pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
        pts3d[..., 2, :] = depth
    return pts3d
163
+
164
+
165
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Back-project a depthmap into 3D camera coordinates.

    Args:
        depthmap: (H, W) array of per-pixel depth values.
        camera_intrinsics: 3x3 pinhole intrinsics matrix (must have no skew).
        pseudo_focal: optional (H, W) per-pixel focal override.
    Returns:
        (H, W, 3) float32 pointmap in camera coordinates, and a boolean mask
        of valid (depth > 0) pixels.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # the back-projection below assumes zero skew
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0

    if pseudo_focal is None:
        fu, fv = camera_intrinsics[0, 0], camera_intrinsics[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu, cv = camera_intrinsics[0, 2], camera_intrinsics[1, 2]

    # pixel grid, then standard pinhole back-projection
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z_cam = depthmap
    X_cam = np.stack(((u - cu) * z_cam / fu,
                      (v - cv) * z_cam / fv,
                      z_cam), axis=-1).astype(np.float32)

    # Mask for valid coordinates
    valid_mask = (depthmap > 0.0)
    return X_cam, valid_mask
198
+
199
+
200
def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw):
    """
    Back-project a depthmap to world coordinates.

    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a 4x3 or 4x4 cam2world matrix
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    # split the cam2world pose into rotation and translation
    R_cam2world = camera_pose[:3, :3]
    t_cam2world = camera_pose[:3, 3]

    # Express in absolute coordinates: X_world = R @ X_cam + t, applied per pixel
    # (invalid-depth pixels are transformed too; use valid_mask to filter them)
    X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
    return X_world, valid_mask
218
+
219
+
220
def colmap_to_opencv_intrinsics(K):
    """
    Convert intrinsics from the Colmap to the OpenCV convention.
    The center of the top-left pixel is (0.5, 0.5) in Colmap but (0, 0) in OpenCV,
    so the principal point is shifted by half a pixel. The input matrix is not modified.
    """
    K = K.copy()
    K[:2, 2] -= 0.5  # shift both cx and cy at once
    return K
231
+
232
+
233
def opencv_to_colmap_intrinsics(K):
    """
    Convert intrinsics from the OpenCV to the Colmap convention.
    The center of the top-left pixel is (0, 0) in OpenCV but (0.5, 0.5) in Colmap,
    so the principal point is shifted by half a pixel. The input matrix is not modified.
    """
    K = K.copy()
    K[:2, 2] += 0.5  # shift both cx and cy at once
    return K
244
+
245
+
246
def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None):
    """ renorm pointmaps pts1, pts2 with norm_mode

    norm_mode is '<mode>_<dis>': mode selects the statistic ('avg', 'median', 'sqrt'),
    dis selects a distance warp for 'avg' ('dis', 'log1p', 'warp-log1p').
    pts2 may be None; valid1/valid2 are boolean masks of pixels to include.
    Returns pts1/factor, or a tuple (pts1/factor, pts2/factor) when pts2 is given.
    """
    assert pts1.ndim >= 3 and pts1.shape[-1] == 3
    assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split('_')

    if norm_mode == 'avg':
        # gather all points together (joint normalization);
        # invalid points are zeroed so they contribute 0 to the sum below
        nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
        nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1

        # compute distance to origin
        all_dis = all_pts.norm(dim=-1)
        if dis_mode == 'dis':
            pass  # do nothing
        elif dis_mode == 'log1p':
            all_dis = torch.log1p(all_dis)
        elif dis_mode == 'warp-log1p':
            # actually warp input points before normalizing them
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            # the first W1*H1 entries of warp_factor belong to pts1, the rest to pts2
            H1, W1 = pts1.shape[1:-1]
            pts1 = pts1 * warp_factor[:, :W1*H1].view(-1, H1, W1, 1)
            if pts2 is not None:
                H2, W2 = pts2.shape[1:-1]
                pts2 = pts2 * warp_factor[:, W1*H1:].view(-1, H2, W2, 1)
            all_dis = log_dis  # this is their true distance afterwards
        else:
            raise ValueError(f'bad {dis_mode=}')

        # mean over valid points only (invalid ones were zeroed)
        norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
    else:
        # gather all points together (joint normalization);
        # invalid points become NaN so the nan-aware reductions ignore them
        nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
        nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1

        # compute distance to origin
        all_dis = all_pts.norm(dim=-1)

        if norm_mode == 'avg':
            # NOTE(review): unreachable here ('avg' is handled in the branch above);
            # kept for safety
            norm_factor = all_dis.nanmean(dim=1)
        elif norm_mode == 'median':
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == 'sqrt':
            norm_factor = all_dis.sqrt().nanmean(dim=1)**2
        else:
            raise ValueError(f'bad {norm_mode=}')

    norm_factor = norm_factor.clip(min=1e-8)
    # broadcast the per-batch factor over the spatial dimensions
    while norm_factor.ndim < pts1.ndim:
        norm_factor.unsqueeze_(-1)

    res = pts1 / norm_factor
    if pts2 is not None:
        res = (res, pts2 / norm_factor)
    return res
305
+
306
+
307
@torch.no_grad()
def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
    """Per-batch robust depth statistic (median by default) over the two masked
    depth maps taken jointly. z2 may be None; gradients are disabled."""
    # invalid entries become NaN so the nan-aware reductions ignore them
    flat1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
    if z2 is None:
        all_z = flat1
    else:
        flat2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1)
        all_z = torch.cat((flat1, flat2), dim=-1)

    # median (or requested quantile) depth, ignoring NaNs
    if quantile == 0.5:
        return torch.nanmedian(all_z, dim=-1).values  # (B,)
    return torch.nanquantile(all_z, quantile, dim=-1)  # (B,)
320
+
321
+
322
@torch.no_grad()
def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):
    """Robust center and scale of the joint point cloud (both views together).

    Center is the per-batch median point; scale is the median distance to that
    center (or to the origin when center=False). z_only keeps only the Z
    component of the center. Returns (center (B,1,1,3), scale (B,1,1,1)).
    Gradients are disabled.
    """
    # set invalid points to NaN so the nan-aware medians ignore them
    _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
    _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None
    _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1

    # compute median center
    _center = torch.nanmedian(_pts, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        _center[..., :2] = 0  # do not center X and Y

    # compute median norm (distance to the center, or to the origin)
    _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
    scale = torch.nanmedian(_norm, dim=1).values
    return _center[:, None, :, :], scale[:, None, None, None]
338
+
339
+
340
def find_reciprocal_matches(P1, P2):
    """
    Mutual nearest-neighbour matching between two point sets.

    returns 3 values:
    1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
    2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
    3 - reciprocal_in_P2.sum(): the number of matches
    """
    kd_P1 = KDTree(P1)
    kd_P2 = KDTree(P2)

    # nearest neighbour in each direction
    _, nn1_in_P2 = kd_P2.query(P1, workers=8)
    _, nn2_in_P1 = kd_P1.query(P2, workers=8)

    # a match is reciprocal when following the NN both ways lands back on the start
    reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)))
    reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)))
    assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
    return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
357
+
358
+
359
def get_med_dist_between_poses(poses):
    """Median pairwise distance between the camera centers (translation columns)
    of the given cam2world poses."""
    from scipy.spatial.distance import pdist
    return np.median(pdist([to_numpy(p[:3, 3]) for p in poses]))
dust3r/dust3r/utils/image.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilitary functions about images (loading/converting...)
6
+ # --------------------------------------------------------
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+ import PIL.Image
11
+ import torchvision.transforms as tvf
12
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
13
+ import cv2 # noqa
14
+
15
+ ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
16
+
17
+
18
def imread_cv2(path, options=cv2.IMREAD_COLOR):
    """ Open an image or a depthmap with opencv-python.

    path: file to read; '.exr' files are forced to IMREAD_ANYDEPTH (float depthmaps).
    options: cv2 imread flag (default: 8-bit color).
    Returns an RGB image (channels converted from OpenCV's BGR), or the raw
    array for single-channel data. Raises IOError if the file cannot be loaded.
    """
    if path.endswith(('.exr', 'EXR')):
        options = cv2.IMREAD_ANYDEPTH
    img = cv2.imread(path, options)
    if img is None:
        raise IOError(f'Could not load image={path} with {options=}')
    if img.ndim == 3:
        # OpenCV loads BGR; the rest of the pipeline expects RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
29
+
30
+
31
def rgb(ftensor, true_shape=None):
    """Convert an image tensor/array (or a list of them) to a float RGB array
    in [0, 1] with channels last. uint8 inputs are divided by 255; float inputs
    are assumed normalized to [-1, 1] and mapped back."""
    if isinstance(ftensor, list):
        return [rgb(x, true_shape=true_shape) for x in ftensor]
    if isinstance(ftensor, torch.Tensor):
        ftensor = ftensor.detach().cpu().numpy()  # H,W,3
    # move channels last if they come first
    if ftensor.ndim == 3 and ftensor.shape[0] == 3:
        ftensor = ftensor.transpose(1, 2, 0)
    elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
        ftensor = ftensor.transpose(0, 2, 3, 1)
    if true_shape is not None:
        # crop away any padding beyond the true image size
        H, W = true_shape
        ftensor = ftensor[:H, :W]
    img = np.float32(ftensor) / 255 if ftensor.dtype == np.uint8 else ftensor * 0.5 + 0.5
    return img.clip(min=0, max=1)
48
+
49
+
50
def _resize_pil_image(img, long_edge_size):
    """Resize a PIL image so its long edge equals `long_edge_size`, keeping the
    aspect ratio. Uses LANCZOS when downscaling, BICUBIC otherwise."""
    S = max(img.size)
    if S > long_edge_size:
        interp = PIL.Image.LANCZOS
    elif S <= long_edge_size:
        interp = PIL.Image.BICUBIC
    new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size)
    return img.resize(new_size, interp)
58
+
59
+
60
def load_images(folder_or_list, size, square_ok=False):
    """ open and convert all images in a list or folder to proper input format for DUSt3R

    Args:
        folder_or_list: a directory path, or a list of image file paths.
        size: target resolution; 224 resizes the short side then center-crops a
            square, any other value resizes the long side and crops both sides
            to multiples of 16.
        square_ok: when size != 224, allow a square output (otherwise square
            inputs are forced to a 4:3 crop).
    Returns:
        list of dicts with keys: img (1,3,H,W normalized tensor),
        true_shape (1,2 int32 (H,W)), idx (int), instance (str).
    Raises:
        ValueError for a bad `folder_or_list`, AssertionError if no image is found.
    """
    if isinstance(folder_or_list, str):
        print(f'>> Loading images from {folder_or_list}')
        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))

    elif isinstance(folder_or_list, list):
        print(f'>> Loading a list of {len(folder_or_list)} images')
        root, folder_content = '', folder_or_list

    else:
        raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')

    imgs = []
    for path in folder_content:
        if not path.endswith(('.jpg', '.jpeg', '.png', '.JPG')):
            continue  # skip non-image files
        img = PIL.Image.open(os.path.join(root, path)).convert('RGB')
        W1, H1 = img.size
        if size == 224:
            # resize short side to 224 (then crop)
            img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
        else:
            # resize long side to `size`
            img = _resize_pil_image(img, size)
        W, H = img.size
        cx, cy = W//2, H//2
        if size == 224:
            # largest centered square crop
            half = min(cx, cy)
            img = img.crop((cx-half, cy-half, cx+half, cy+half))
        else:
            # crop so both sides are multiples of 16
            halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
            if not (square_ok) and W == H:
                halfh = 3*halfw/4  # force a 4:3 aspect for square inputs
            img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))

        W2, H2 = img.size
        print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
        imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
            [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))

    assert imgs, 'no images found at '+root  # fixed typo in message ('foud')
    print(f' (Found {len(imgs)} images)')
    return imgs
dust3r/dust3r/utils/misc.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilitary functions for DUSt3R
6
+ # --------------------------------------------------------
7
+ import torch
8
+
9
+
10
def fill_default_args(kwargs, func):
    """Complete `kwargs` in place with the default values of `func`'s signature.

    Keys already present are kept untouched; only missing parameters that
    have a default are added. Returns the (mutated) kwargs dict.
    """
    import inspect  # a bit hacky but it works reliably
    for name, param in inspect.signature(func).parameters.items():
        if param.default is not inspect.Parameter.empty:
            kwargs.setdefault(name, param.default)
    return kwargs
20
+
21
+
22
def freeze_all_params(modules):
    """Disable gradient computation for every parameter of each entry in `modules`.

    Each entry may be an nn.Module or directly a parameter/tensor.
    """
    for module in modules:
        try:
            named = module.named_parameters()
        except AttributeError:
            # module is directly a parameter
            module.requires_grad = False
        else:
            for _, param in named:
                param.requires_grad = False
30
+
31
+
32
def is_symmetrized(gt1, gt2):
    """Return True if the batch is made of symmetrized pairs.

    Consecutive items (2i, 2i+1) must be mirrored across the two views:
    gt1[2i] == gt2[2i+1] and gt1[2i+1] == gt2[2i].
    A batch of size 1 can never be symmetrized.
    """
    x, y = gt1['instance'], gt2['instance']
    if len(x) == 1 and len(y) == 1:
        return False  # special case of batchsize 1
    return all(x[i] == y[i+1] and x[i+1] == y[i] for i in range(0, len(x), 2))
41
+
42
+
43
def flip(tensor):
    """Swap consecutive pairs so that tensor[0::2] <=> tensor[1::2]."""
    odd_items = tensor[1::2]
    even_items = tensor[0::2]
    return torch.stack((odd_items, even_items), dim=1).flatten(0, 1)
46
+
47
+
48
def interleave(tensor1, tensor2):
    """Interleave two tensors along the batch dimension.

    Returns (t1[0], t2[0], t1[1], t2[1], ...) and its mirror
    (t2[0], t1[0], t2[1], t1[1], ...).
    """
    def _mix(a, b):
        return torch.stack((a, b), dim=1).flatten(0, 1)
    return _mix(tensor1, tensor2), _mix(tensor2, tensor1)
52
+
53
+
54
def transpose_to_landscape(head, activate=True):
    """ Predict in the correct aspect-ratio,
    then transpose the result in landscape
    and stack everything back together.

    `head(decout, (H, W))` is assumed to return a dict of (B, H, W, ...) tensors;
    wrapper_yes handles batches mixing portrait and landscape images by running
    the head twice and scattering the results back — TODO confirm against callers.
    """
    def wrapper_no(decout, true_shape):
        # simple path: every image in the batch must share the exact same (H, W)
        B = len(true_shape)
        assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'
        H, W = true_shape[0].cpu().tolist()
        res = head(decout, (H, W))
        return res

    def wrapper_yes(decout, true_shape):
        B = len(true_shape)
        # by definition, the batch is in landscape mode so W >= H
        H, W = int(true_shape.min()), int(true_shape.max())

        height, width = true_shape.T
        is_landscape = (width >= height)
        is_portrait = ~is_landscape

        # true_shape = true_shape.cpu()
        # homogeneous batches: a single head call suffices
        if is_landscape.all():
            return head(decout, (H, W))
        if is_portrait.all():
            # predict in portrait shape, then swap the spatial axes back
            return transposed(head(decout, (W, H)))

        # batch is a mix of both portrait & landscape
        def selout(ar): return [d[ar] for d in decout]
        l_result = head(selout(is_landscape), (H, W))
        p_result = transposed(head(selout(is_portrait), (W, H)))

        # allocate full result and scatter each sub-batch back to its slots
        result = {}
        for k in l_result | p_result:
            x = l_result[k].new(B, *l_result[k].shape[1:])
            x[is_landscape] = l_result[k]
            x[is_portrait] = p_result[k]
            result[k] = x

        return result

    return wrapper_yes if activate else wrapper_no
97
+
98
+
99
def transposed(dic):
    """Swap the two spatial axes (dims 1 and 2) of every tensor in `dic`."""
    out = {}
    for key, val in dic.items():
        out[key] = val.swapaxes(1, 2)
    return out
101
+
102
+
103
def invalid_to_nans(arr, valid_mask, ndim=999):
    """Return a copy of `arr` with invalid entries replaced by NaN.

    If `arr` has more than `ndim` dimensions, the extra leading spatial
    dimensions are flattened together (all but the last axis) so that at
    most `ndim` remain. `arr` itself is never modified.
    """
    if valid_mask is not None:
        masked = arr.clone()
        masked[~valid_mask] = float('nan')
        arr = masked
    extra_dims = arr.ndim - ndim
    if extra_dims > 0:
        arr = arr.flatten(-2 - extra_dims, -2)
    return arr
110
+
111
+
112
def invalid_to_zeros(arr, valid_mask, ndim=999):
    """Return (arr with invalid entries zeroed, valid-point count per image).

    When `valid_mask` is None everything counts as valid. Flattening of
    extra dimensions follows the same rule as invalid_to_nans.
    `arr` itself is never modified.
    """
    if valid_mask is None:
        # number of point per image when everything is valid
        nnz = arr.numel() // len(arr) if len(arr) else 0
    else:
        zeroed = arr.clone()
        zeroed[~valid_mask] = 0
        arr = zeroed
        nnz = valid_mask.view(len(valid_mask), -1).sum(1)
    extra_dims = arr.ndim - ndim
    if extra_dims > 0:
        arr = arr.flatten(-2 - extra_dims, -2)
    return arr, nnz
dust3r/dust3r/utils/path_to_croco.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # CroCo submodule import
6
+ # --------------------------------------------------------
7
+
8
import sys
import os.path as path
# Locate the CroCo submodule relative to this file (dust3r/dust3r/utils -> repo root/croco)
HERE_PATH = path.normpath(path.dirname(__file__))
CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco'))
CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models')
# check the presence of models directory in repo to be sure its cloned
if path.isdir(CROCO_MODELS_PATH):
    # workaround for sibling import
    sys.path.insert(0, CROCO_REPO_PATH)
else:
    raise ImportError(f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n "
                      "Did you forget to run 'git submodule update --init --recursive' ?")
dust3r/dust3r/viz.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Visualization utilities using trimesh
6
+ # --------------------------------------------------------
7
+ import PIL.Image
8
+ import numpy as np
9
+ from scipy.spatial.transform import Rotation
10
+ import torch
11
+
12
+ from dust3r.utils.geometry import geotrf, get_med_dist_between_poses
13
+ from dust3r.utils.device import to_numpy
14
+ from dust3r.utils.image import rgb
15
+
16
+ try:
17
+ import trimesh
18
+ except ImportError:
19
+ print('/!\\ module trimesh is not installed, cannot visualize results /!\\')
20
+
21
+
22
+ def cat_3d(vecs):
23
+ if isinstance(vecs, (np.ndarray, torch.Tensor)):
24
+ vecs = [vecs]
25
+ return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)])
26
+
27
+
28
+ def show_raw_pointcloud(pts3d, colors, point_size=2):
29
+ scene = trimesh.Scene()
30
+
31
+ pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))
32
+ scene.add_geometry(pct)
33
+
34
+ scene.show(line_settings={'point_size': point_size})
35
+
36
+
37
+ def pts3d_to_trimesh(img, pts3d, valid=None):
38
+ H, W, THREE = img.shape
39
+ assert THREE == 3
40
+ assert img.shape == pts3d.shape
41
+
42
+ vertices = pts3d.reshape(-1, 3)
43
+
44
+ # make squares: each pixel == 2 triangles
45
+ idx = np.arange(len(vertices)).reshape(H, W)
46
+ idx1 = idx[:-1, :-1].ravel() # top-left corner
47
+ idx2 = idx[:-1, +1:].ravel() # right-left corner
48
+ idx3 = idx[+1:, :-1].ravel() # bottom-left corner
49
+ idx4 = idx[+1:, +1:].ravel() # bottom-right corner
50
+ faces = np.concatenate((
51
+ np.c_[idx1, idx2, idx3],
52
+ np.c_[idx3, idx2, idx1], # same triangle, but backward (cheap solution to cancel face culling)
53
+ np.c_[idx2, idx3, idx4],
54
+ np.c_[idx4, idx3, idx2], # same triangle, but backward (cheap solution to cancel face culling)
55
+ ), axis=0)
56
+
57
+ # prepare triangle colors
58
+ face_colors = np.concatenate((
59
+ img[:-1, :-1].reshape(-1, 3),
60
+ img[:-1, :-1].reshape(-1, 3),
61
+ img[+1:, +1:].reshape(-1, 3),
62
+ img[+1:, +1:].reshape(-1, 3)
63
+ ), axis=0)
64
+
65
+ # remove invalid faces
66
+ if valid is not None:
67
+ assert valid.shape == (H, W)
68
+ valid_idxs = valid.ravel()
69
+ valid_faces = valid_idxs[faces].all(axis=-1)
70
+ faces = faces[valid_faces]
71
+ face_colors = face_colors[valid_faces]
72
+
73
+ assert len(faces) == len(face_colors)
74
+ return dict(vertices=vertices, face_colors=face_colors, faces=faces)
75
+
76
+
77
+ def cat_meshes(meshes):
78
+ vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])
79
+ n_vertices = np.cumsum([0]+[len(v) for v in vertices])
80
+ for i in range(len(faces)):
81
+ faces[i][:] += n_vertices[i]
82
+
83
+ vertices = np.concatenate(vertices)
84
+ colors = np.concatenate(colors)
85
+ faces = np.concatenate(faces)
86
+ return dict(vertices=vertices, face_colors=colors, faces=faces)
87
+
88
+
89
+ def show_duster_pairs(view1, view2, pred1, pred2):
90
+ import matplotlib.pyplot as pl
91
+ pl.ion()
92
+
93
+ for e in range(len(view1['instance'])):
94
+ i = view1['idx'][e]
95
+ j = view2['idx'][e]
96
+ img1 = rgb(view1['img'][e])
97
+ img2 = rgb(view2['img'][e])
98
+ conf1 = pred1['conf'][e].squeeze()
99
+ conf2 = pred2['conf'][e].squeeze()
100
+ score = conf1.mean()*conf2.mean()
101
+ print(f">> Showing pair #{e} {i}-{j} {score=:g}")
102
+ pl.clf()
103
+ pl.subplot(221).imshow(img1)
104
+ pl.subplot(223).imshow(img2)
105
+ pl.subplot(222).imshow(conf1, vmin=1, vmax=30)
106
+ pl.subplot(224).imshow(conf2, vmin=1, vmax=30)
107
+ pts1 = pred1['pts3d'][e]
108
+ pts2 = pred2['pts3d_in_other_view'][e]
109
+ pl.subplots_adjust(0, 0, 1, 1, 0, 0)
110
+ if input('show pointcloud? (y/n) ') == 'y':
111
+ show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5)
112
+
113
+
114
+ def auto_cam_size(im_poses):
115
+ return 0.1 * get_med_dist_between_poses(im_poses)
116
+
117
+
118
+ class SceneViz:
119
+ def __init__(self):
120
+ self.scene = trimesh.Scene()
121
+
122
+ def add_pointcloud(self, pts3d, color, mask=None):
123
+ pts3d = to_numpy(pts3d)
124
+ mask = to_numpy(mask)
125
+ if mask is None:
126
+ mask = [slice(None)] * len(pts3d)
127
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
128
+ pct = trimesh.PointCloud(pts.reshape(-1, 3))
129
+
130
+ if isinstance(color, (list, np.ndarray, torch.Tensor)):
131
+ color = to_numpy(color)
132
+ col = np.concatenate([p[m] for p, m in zip(color, mask)])
133
+ assert col.shape == pts.shape
134
+ pct.visual.vertex_colors = uint8(col.reshape(-1, 3))
135
+ else:
136
+ assert len(color) == 3
137
+ pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape)
138
+
139
+ self.scene.add_geometry(pct)
140
+ return self
141
+
142
+ def add_camera(self, pose_c2w, focal=None, color=(0, 0, 0), image=None, imsize=None, cam_size=0.03):
143
+ pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image))
144
+ add_scene_cam(self.scene, pose_c2w, color, image, focal, screen_width=cam_size)
145
+ return self
146
+
147
+ def add_cameras(self, poses, focals=None, images=None, imsizes=None, colors=None, **kw):
148
+ def get(arr, idx): return None if arr is None else arr[idx]
149
+ for i, pose_c2w in enumerate(poses):
150
+ self.add_camera(pose_c2w, get(focals, i), image=get(images, i),
151
+ color=get(colors, i), imsize=get(imsizes, i), **kw)
152
+ return self
153
+
154
+ def show(self, point_size=2):
155
+ self.scene.show(line_settings={'point_size': point_size})
156
+
157
+
158
+ def show_raw_pointcloud_with_cams(imgs, pts3d, mask, focals, cams2world,
159
+ point_size=2, cam_size=0.05, cam_color=None):
160
+ """ Visualization of a pointcloud with cameras
161
+ imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...]
162
+ pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...]
163
+ focals = (N,) or N-size list of [focal, ...]
164
+ cams2world = (N,4,4) or N-size list of [(4,4), ...]
165
+ """
166
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
167
+ pts3d = to_numpy(pts3d)
168
+ imgs = to_numpy(imgs)
169
+ focals = to_numpy(focals)
170
+ cams2world = to_numpy(cams2world)
171
+
172
+ scene = trimesh.Scene()
173
+
174
+ # full pointcloud
175
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
176
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
177
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
178
+ scene.add_geometry(pct)
179
+
180
+ # add each camera
181
+ for i, pose_c2w in enumerate(cams2world):
182
+ if isinstance(cam_color, list):
183
+ camera_edge_color = cam_color[i]
184
+ else:
185
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
186
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
187
+ imgs[i] if i < len(imgs) else None, focals[i], screen_width=cam_size)
188
+
189
+ scene.show(line_settings={'point_size': point_size})
190
+
191
+
192
+ def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03):
193
+
194
+ if image is not None:
195
+ H, W, THREE = image.shape
196
+ assert THREE == 3
197
+ if image.dtype != np.uint8:
198
+ image = np.uint8(255*image)
199
+ elif imsize is not None:
200
+ W, H = imsize
201
+ elif focal is not None:
202
+ H = W = focal / 1.1
203
+ else:
204
+ H = W = 1
205
+
206
+ if focal is None:
207
+ focal = min(H, W) * 1.1 # default value
208
+ elif isinstance(focal, np.ndarray):
209
+ focal = focal[0]
210
+
211
+ # create fake camera
212
+ height = focal * screen_width / H
213
+ width = screen_width * 0.5**0.5
214
+ rot45 = np.eye(4)
215
+ rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
216
+ rot45[2, 3] = -height # set the tip of the cone = optical center
217
+ aspect_ratio = np.eye(4)
218
+ aspect_ratio[0, 0] = W/H
219
+ transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
220
+ cam = trimesh.creation.cone(width, height, sections=4) # , transform=transform)
221
+
222
+ # this is the image
223
+ if image is not None:
224
+ vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
225
+ faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
226
+ img = trimesh.Trimesh(vertices=vertices, faces=faces)
227
+ uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
228
+ img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
229
+ scene.add_geometry(img)
230
+
231
+ # this is the camera mesh
232
+ rot2 = np.eye(4)
233
+ rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
234
+ vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]
235
+ vertices = geotrf(transform, vertices)
236
+ faces = []
237
+ for face in cam.faces:
238
+ if 0 in face:
239
+ continue
240
+ a, b, c = face
241
+ a2, b2, c2 = face + len(cam.vertices)
242
+ a3, b3, c3 = face + 2*len(cam.vertices)
243
+
244
+ # add 3 pseudo-edges
245
+ faces.append((a, b, b2))
246
+ faces.append((a, a2, c))
247
+ faces.append((c2, b, c))
248
+
249
+ faces.append((a, b, b3))
250
+ faces.append((a, a3, c))
251
+ faces.append((c3, b, c))
252
+
253
+ # no culling
254
+ faces += [(c, b, a) for a, b, c in faces]
255
+
256
+ cam = trimesh.Trimesh(vertices=vertices, faces=faces)
257
+ cam.visual.face_colors[:, :3] = edge_color
258
+ scene.add_geometry(cam)
259
+
260
+
261
+ def cat(a, b):
262
+ return np.concatenate((a.reshape(-1, 3), b.reshape(-1, 3)))
263
+
264
+
265
+ OPENGL = np.array([[1, 0, 0, 0],
266
+ [0, -1, 0, 0],
267
+ [0, 0, -1, 0],
268
+ [0, 0, 0, 1]])
269
+
270
+
271
+ CAM_COLORS = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204),
272
+ (128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)]
273
+
274
+
275
+ def uint8(colors):
276
+ if not isinstance(colors, np.ndarray):
277
+ colors = np.array(colors)
278
+ if np.issubdtype(colors.dtype, np.floating):
279
+ colors *= 255
280
+ assert 0 <= colors.min() and colors.max() < 256
281
+ return np.uint8(colors)
282
+
283
+
284
+ def segment_sky(image):
285
+ import cv2
286
+ from scipy import ndimage
287
+
288
+ # Convert to HSV
289
+ image = to_numpy(image)
290
+ if np.issubdtype(image.dtype, np.floating):
291
+ image = np.uint8(255*image.clip(min=0, max=1))
292
+ hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
293
+
294
+ # Define range for blue color and create mask
295
+ lower_blue = np.array([0, 0, 100])
296
+ upper_blue = np.array([30, 255, 255])
297
+ mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)
298
+
299
+ # add luminous gray
300
+ mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)
301
+ mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)
302
+ mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)
303
+
304
+ # Morphological operations
305
+ kernel = np.ones((5, 5), np.uint8)
306
+ mask2 = ndimage.binary_opening(mask, structure=kernel)
307
+
308
+ # keep only largest CC
309
+ _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8)
310
+ cc_sizes = stats[1:, cv2.CC_STAT_AREA]
311
+ order = cc_sizes.argsort()[::-1] # bigger first
312
+ i = 0
313
+ selection = []
314
+ while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:
315
+ selection.append(1 + order[i])
316
+ i += 1
317
+ mask3 = np.in1d(labels, selection).reshape(labels.shape)
318
+
319
+ # Apply mask
320
+ return torch.from_numpy(mask3)
dust3r/requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ roma
4
+ gradio
5
+ matplotlib
6
+ tqdm
7
+ opencv-python
8
+ scipy
9
+ einops
10
+ trimesh
11
+ tensorboard
12
+ pyglet<2
dust3r/train.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
3
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ #
5
+ # --------------------------------------------------------
6
+ # training code for DUSt3R
7
+ # --------------------------------------------------------
8
+ # References:
9
+ # MAE: https://github.com/facebookresearch/mae
10
+ # DeiT: https://github.com/facebookresearch/deit
11
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
12
+ # --------------------------------------------------------
13
+ import argparse
14
+ import datetime
15
+ import json
16
+ import numpy as np
17
+ import os
18
+ import sys
19
+ import time
20
+ import math
21
+ from collections import defaultdict
22
+ from pathlib import Path
23
+ from typing import Sized
24
+
25
+ import torch
26
+ import torch.backends.cudnn as cudnn
27
+ from torch.utils.tensorboard import SummaryWriter
28
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
29
+
30
+ from dust3r.model import AsymmetricCroCo3DStereo, inf # noqa: F401, needed when loading the model
31
+ from dust3r.datasets import get_data_loader # noqa
32
+ from dust3r.losses import * # noqa: F401, needed when loading the model
33
+ from dust3r.inference import loss_of_one_batch # noqa
34
+
35
+ import dust3r.utils.path_to_croco # noqa: F401
36
+ import croco.utils.misc as misc # noqa
37
+ from croco.utils.misc import NativeScalerWithGradNormCount as NativeScaler # noqa
38
+
39
+
40
def get_args_parser():
    """Build the command-line argument parser for DUSt3R training."""
    parser = argparse.ArgumentParser('DUST3R training', add_help=False)
    # model and criterion (these strings are later eval()'d in main())
    parser.add_argument('--model', default="AsymmetricCroCo3DStereo(patch_embed_cls='ManyAR_PatchEmbed')",
                        type=str, help="string containing the model to build")
    parser.add_argument('--pretrained', default=None, help='path of a starting checkpoint')
    parser.add_argument('--train_criterion', default="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)",
                        type=str, help="train criterion")
    parser.add_argument('--test_criterion', default=None, type=str, help="test criterion")

    # dataset
    parser.add_argument('--train_dataset', required=True, type=str, help="training set")
    parser.add_argument('--test_dataset', default='[None]', type=str, help="testing set")

    # training
    parser.add_argument('--seed', default=0, type=int, help="Random seed")
    parser.add_argument('--batch_size', default=64, type=int,
                        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus")
    parser.add_argument('--accum_iter', default=1, type=int,
                        help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)")
    parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler")

    parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)")
    parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR')

    parser.add_argument('--amp', type=int, default=0,
                        choices=[0, 1], help="Use Automatic Mixed Precision for pretraining")

    # others
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--eval_freq', type=int, default=1, help='Test loss evaluation frequency')
    parser.add_argument('--save_freq', default=1, type=int,
                        help='frequence (number of epochs) to save checkpoint in checkpoint-last.pth')
    parser.add_argument('--keep_freq', default=20, type=int,
                        help='frequence (number of epochs) to save checkpoint in checkpoint-%d.pth')
    parser.add_argument('--print_freq', default=20, type=int,
                        help='frequence (number of iterations) to print infos while training')

    # output dir
    parser.add_argument('--output_dir', default='./output/', type=str, help="path where to save the output")
    return parser
90
+
91
+
92
def main(args):
    """Full training entry point: distributed setup, data, model, optimizer,
    then the train/eval/checkpoint loop.

    Bugfixes vs. the previous version:
      - `test_criterion` fell back to `args.criterion`, an attribute that the
        parser never defines (AttributeError whenever --test_criterion is
        unset); it now falls back to `args.train_criterion`.
      - the "Building test dataset" log line printed args.train_dataset.
    """
    misc.init_distributed_mode(args)
    global_rank = misc.get_rank()
    world_size = misc.get_world_size()

    print("output_dir: "+args.output_dir)
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # auto resume from the last checkpoint if one exists
    last_ckpt_fname = os.path.join(args.output_dir, 'checkpoint-last.pth')
    args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # fix the seed (per-rank offset so each process gets a distinct stream)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # training dataset and loader
    print('Building train dataset {:s}'.format(args.train_dataset))
    data_loader_train = build_dataset(args.train_dataset, args.batch_size, args.num_workers, test=False)
    print('Building test dataset {:s}'.format(args.test_dataset))
    data_loader_test = {dataset.split('(')[0]: build_dataset(dataset, args.batch_size, args.num_workers, test=True)
                        for dataset in args.test_dataset.split('+')}

    # model & criteria
    # NOTE: args.model / criteria are eval()'d — only safe with trusted CLI input
    print('Loading model: {:s}'.format(args.model))
    model = eval(args.model)
    print(f'>> Creating train criterion = {args.train_criterion}')
    train_criterion = eval(args.train_criterion).to(device)
    print(f'>> Creating test criterion = {args.test_criterion or args.train_criterion}')
    test_criterion = eval(args.test_criterion or args.train_criterion).to(device)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    if args.pretrained and not args.resume:
        print('Loading pretrained: ', args.pretrained)
        ckpt = torch.load(args.pretrained, map_location=device)
        print(model.load_state_dict(ckpt['model'], strict=False))
        del ckpt  # in case it occupies memory

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256
    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    def write_log_stats(epoch, train_stats, test_stats):
        # append a one-line json summary for this epoch (main process only)
        if misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()

            log_stats = dict(epoch=epoch, **{f'train_{k}': v for k, v in train_stats.items()})
            for test_name in data_loader_test:
                if test_name not in test_stats:
                    continue
                log_stats.update({test_name+'_'+k: v for k, v in test_stats[test_name].items()})

            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    def save_model(epoch, fname, best_so_far):
        misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,
                        loss_scaler=loss_scaler, epoch=epoch, fname=fname, best_so_far=best_so_far)

    best_so_far = misc.load_model(args=args, model_without_ddp=model_without_ddp,
                                  optimizer=optimizer, loss_scaler=loss_scaler)
    if best_so_far is None:
        best_so_far = float('inf')
    if global_rank == 0 and args.output_dir is not None:
        log_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        log_writer = None

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    train_stats = test_stats = {}
    for epoch in range(args.start_epoch, args.epochs+1):

        # Save immediately the last checkpoint
        if epoch > args.start_epoch:
            if args.save_freq and epoch % args.save_freq == 0 or epoch == args.epochs:
                save_model(epoch-1, 'last', best_so_far)

        # Test on multiple datasets
        new_best = False
        if (epoch > 0 and args.eval_freq > 0 and epoch % args.eval_freq == 0):
            test_stats = {}
            for test_name, testset in data_loader_test.items():
                stats = test_one_epoch(model, test_criterion, testset,
                                       device, epoch, log_writer=log_writer, args=args, prefix=test_name)
                test_stats[test_name] = stats

                # Save best of all (tracked on the median test loss)
                if stats['loss_med'] < best_so_far:
                    best_so_far = stats['loss_med']
                    new_best = True

        # Save more stuff
        write_log_stats(epoch, train_stats, test_stats)

        if epoch > args.start_epoch:
            if args.keep_freq and epoch % args.keep_freq == 0:
                save_model(epoch-1, str(epoch), best_so_far)
            if new_best:
                save_model(epoch-1, 'best', best_so_far)
        if epoch >= args.epochs:
            break  # exit after writing last test to disk

        # Train
        train_stats = train_one_epoch(
            model, train_criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    save_final_model(args, args.epochs, model_without_ddp, best_so_far=best_so_far)
237
+
238
+
239
def save_final_model(args, epoch, model_without_ddp, best_so_far=None):
    """Save the final model (moved to CPU) plus metadata to checkpoint-final.pth
    via misc.save_on_master (i.e. only on the main process)."""
    output_dir = Path(args.output_dir)
    checkpoint_path = output_dir / 'checkpoint-final.pth'
    to_save = {
        'args': args,
        # accept either an already-extracted state dict or a live module
        'model': model_without_ddp if isinstance(model_without_ddp, dict) else model_without_ddp.cpu().state_dict(),
        'epoch': epoch
    }
    if best_so_far is not None:
        to_save['best_so_far'] = best_so_far
    print(f'>> Saving model to {checkpoint_path} ...')
    misc.save_on_master(to_save, checkpoint_path)
251
+
252
+
253
def build_dataset(dataset, batch_size, num_workers, test=False):
    """Instantiate a train/test data loader from its string description.

    Training loaders are shuffled and drop the last partial batch;
    test loaders keep the natural order and every sample.
    """
    split = 'Test' if test else 'Train'
    print(f'Building {split} Data loader for dataset: ', dataset)
    loader = get_data_loader(dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_mem=True,
                             shuffle=not (test),
                             drop_last=not (test))

    print(f"{split} dataset length: ", len(loader))
    return loader
265
+
266
+
267
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Sized, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    args,
                    log_writer=None):
    """Run one training epoch with gradient accumulation and AMP loss scaling.

    Returns a dict mapping meter name -> globally-averaged value.
    """
    assert torch.backends.cuda.matmul.allow_tf32 == True

    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    accum_iter = args.accum_iter

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # let the dataset/sampler reshuffle deterministically for this epoch
    if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'):
        data_loader.dataset.set_epoch(epoch)
    if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'):
        data_loader.sampler.set_epoch(epoch)

    optimizer.zero_grad()

    for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        # fractional epoch, used as the x-axis for the lr schedule and logging
        epoch_f = epoch + data_iter_step / len(data_loader)

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            misc.adjust_learning_rate(optimizer, epoch_f, args)

        loss_tuple = loss_of_one_batch(batch, model, criterion, device,
                                       symmetrize_batch=True,
                                       use_amp=bool(args.amp), ret='loss')
        loss, loss_details = loss_tuple  # criterion returns two values
        loss_value = float(loss)

        if not math.isfinite(loss_value):
            # NOTE(review): `force=True` presumably relies on the print override
            # installed by misc.init_distributed_mode — confirm
            print("Loss is {}, stopping training".format(loss_value), force=True)
            sys.exit(1)

        loss /= accum_iter
        # backward + (deferred) optimizer step through the AMP loss scaler
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        del loss
        del batch

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(epoch=epoch_f)
        metric_logger.update(lr=lr)
        metric_logger.update(loss=loss_value, **loss_details)

        if (data_iter_step + 1) % accum_iter == 0 and ((data_iter_step + 1) % (accum_iter * args.print_freq)) == 0:
            loss_value_reduce = misc.all_reduce_mean(loss_value)  # MUST BE EXECUTED BY ALL NODES
            if log_writer is None:
                continue
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int(epoch_f * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('train_lr', lr, epoch_1000x)
            log_writer.add_scalar('train_iter', epoch_1000x, epoch_1000x)
            for name, val in loss_details.items():
                log_writer.add_scalar('train_'+name, val, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
339
+
340
+
341
@torch.no_grad()
def test_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                   data_loader: Sized, device: torch.device, epoch: int,
                   args, log_writer=None, prefix='test'):
    """Evaluate the model on one test loader.

    Returns a dict with the global average ('*_avg') and the median ('*_med')
    of every tracked meter.
    """

    model.eval()
    metric_logger = misc.MetricLogger(delimiter=" ")
    # effectively unbounded window so the median covers the whole epoch
    metric_logger.meters = defaultdict(lambda: misc.SmoothedValue(window_size=9**9))
    header = 'Test Epoch: [{}]'.format(epoch)

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # let the dataset/sampler pick this epoch's samples deterministically
    if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'):
        data_loader.dataset.set_epoch(epoch)
    if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'):
        data_loader.sampler.set_epoch(epoch)

    for _, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        loss_tuple = loss_of_one_batch(batch, model, criterion, device,
                                       symmetrize_batch=True,
                                       use_amp=bool(args.amp), ret='loss')
        loss_value, loss_details = loss_tuple  # criterion returns two values
        metric_logger.update(loss=float(loss_value), **loss_details)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # export both the running average and the median of every meter
    aggs = [('avg', 'global_avg'), ('med', 'median')]
    results = {f'{k}_{tag}': getattr(meter, attr) for k, meter in metric_logger.meters.items() for tag, attr in aggs}

    if log_writer is not None:
        for name, val in results.items():
            log_writer.add_scalar(prefix+'_'+name, val, 1000*epoch)

    return results
378
+
379
+
380
if __name__ == '__main__':
    # parse CLI arguments and launch training
    args = get_args_parser()
    args = args.parse_args()
    main(args)
setup.sh CHANGED
@@ -4,10 +4,13 @@ set -e
4
 
5
  cd "$(dirname "$0")"
6
 
7
- if [ ! -d "dust3r" ]; then
8
- git clone -b dev --recursive https://github.com/camenduru/dust3r
 
 
9
  fi
10
 
 
11
  python -m pip install -r requirements.txt
12
 
13
  if [[ "$(uname)" == "Linux" ]]; then
@@ -18,6 +21,7 @@ mkdir -p dust3r/checkpoints
18
 
19
  WEIGHTS=dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
20
  if [ ! -f "$WEIGHTS" ]; then
 
21
  wget https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth \
22
  -P dust3r/checkpoints
23
  fi
 
4
 
5
  cd "$(dirname "$0")"
6
 
7
+ echo "Initializing dust3r submodule..."
8
+ if [ ! -d "dust3r/.git" ]; then
9
+ git submodule add -f https://github.com/camenduru/dust3r.git dust3r 2>/dev/null || true
10
+ git submodule update --init --recursive
11
  fi
12
 
13
+ echo "Installing Python dependencies..."
14
  python -m pip install -r requirements.txt
15
 
16
  if [[ "$(uname)" == "Linux" ]]; then
 
21
 
22
  WEIGHTS=dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
23
  if [ ! -f "$WEIGHTS" ]; then
24
+ echo "Downloading model weights..."
25
  wget https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth \
26
  -P dust3r/checkpoints
27
  fi