brian4dwell committed on
Commit 9d31508 · 1 Parent(s): 594b88c

add stream3r

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. .gitattributes +3 -0
  2. LICENSE +13 -0
  3. assets/pipeline.png +3 -0
  4. assets/teaser_dynamic.gif +3 -0
  5. configs/__init__.py +7 -0
  6. configs/callbacks/default.yaml +22 -0
  7. configs/callbacks/early_stopping.yaml +15 -0
  8. configs/callbacks/model_checkpoint.yaml +17 -0
  9. configs/callbacks/model_summary.yaml +5 -0
  10. configs/callbacks/none.yaml +0 -0
  11. configs/callbacks/rich_progress_bar.yaml +18 -0
  12. configs/data/multiview_dust3r.yaml +25 -0
  13. configs/debug/ddp_debug.yaml +48 -0
  14. configs/debug/default.yaml +35 -0
  15. configs/debug/fdr.yaml +9 -0
  16. configs/debug/limit.yaml +12 -0
  17. configs/debug/overfit.yaml +13 -0
  18. configs/debug/profiler.yaml +12 -0
  19. configs/eval.yaml +19 -0
  20. configs/experiment/stream3r/stream3r.yaml +125 -0
  21. configs/extras/default.yaml +8 -0
  22. configs/hparams_search/mnist_optuna.yaml +52 -0
  23. configs/hydra/default.yaml +19 -0
  24. configs/hydra/launcher/fair_a100.yaml +43 -0
  25. configs/local/.gitkeep +0 -0
  26. configs/logger/aim.yaml +28 -0
  27. configs/logger/comet.yaml +12 -0
  28. configs/logger/csv.yaml +7 -0
  29. configs/logger/many_loggers.yaml +9 -0
  30. configs/logger/mlflow.yaml +12 -0
  31. configs/logger/neptune.yaml +9 -0
  32. configs/logger/tensorboard.yaml +10 -0
  33. configs/logger/wandb.yaml +16 -0
  34. configs/model/stream3r.yaml +42 -0
  35. configs/paths/default.yaml +21 -0
  36. configs/train.yaml +49 -0
  37. configs/trainer/cpu.yaml +5 -0
  38. configs/trainer/ddp.yaml +12 -0
  39. configs/trainer/ddp_eval.yaml +16 -0
  40. configs/trainer/ddp_sim.yaml +7 -0
  41. configs/trainer/deepspeed_stage_2.yaml +9 -0
  42. configs/trainer/default.yaml +30 -0
  43. configs/trainer/gpu.yaml +5 -0
  44. configs/trainer/mps.yaml +5 -0
  45. eval/monodepth/eval_metrics.py +211 -0
  46. eval/monodepth/launch.py +146 -0
  47. eval/monodepth/metadata.py +187 -0
  48. eval/monodepth/run.sh +20 -0
  49. eval/monodepth/tools.py +399 -0
  50. eval/mv_recon/base.py +274 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,13 @@
+ S-Lab License 1.0
+ Copyright 2025 S-Lab
+
+ Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+ IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
assets/pipeline.png ADDED

Git LFS Details

  • SHA256: 099a2c82e37b04878112826abcc85b02cf86e0bd059824dfe98e4a99782d6aac
  • Pointer size: 131 Bytes
  • Size of remote file: 655 kB
assets/teaser_dynamic.gif ADDED

Git LFS Details

  • SHA256: eb25ab7cf2e3dcff862a3e8e82657dbba7fc0cbc36856f315d2b6e25f9bb9d72
  • Pointer size: 132 Bytes
  • Size of remote file: 2.33 MB
configs/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # this file is needed here to include configs when building project as a package
configs/callbacks/default.yaml ADDED
@@ -0,0 +1,22 @@
+ defaults:
+   - model_checkpoint
+   - early_stopping
+   - model_summary
+   - rich_progress_bar
+   - _self_
+
+ model_checkpoint:
+   dirpath: ${paths.output_dir}/checkpoints
+   filename: "epoch_{epoch:03d}"
+   monitor: "val/loss"
+   mode: "min"
+   save_last: True
+   auto_insert_metric_name: False
+
+ early_stopping:
+   monitor: "val/loss"
+   patience: 100
+   mode: "min"
+
+ model_summary:
+   max_depth: -1
configs/callbacks/early_stopping.yaml ADDED
@@ -0,0 +1,15 @@
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
+
+ early_stopping:
+   _target_: lightning.pytorch.callbacks.EarlyStopping
+   monitor: ??? # quantity to be monitored, must be specified !!!
+   min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
+   patience: 3 # number of checks with no improvement after which training will be stopped
+   verbose: False # verbosity mode
+   mode: "min" # "max" means higher metric value is better, can be also "min"
+   strict: True # whether to crash the training if monitor is not found in the validation metrics
+   check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
+   stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
+   divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
+   check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
+   # log_rank_zero_only: False # this keyword argument isn't available in stable version
configs/callbacks/model_checkpoint.yaml ADDED
@@ -0,0 +1,17 @@
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
+
+ model_checkpoint:
+   _target_: lightning.pytorch.callbacks.ModelCheckpoint
+   dirpath: null # directory to save the model file
+   filename: null # checkpoint filename
+   monitor: 'val/loss' # name of the logged metric which determines when model is improving
+   verbose: False # verbosity mode
+   save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+   save_top_k: 1 # save k best models (determined by above metric)
+   mode: "min" # "max" means higher metric value is better, can be also "min"
+   auto_insert_metric_name: False # when True, the checkpoints filenames will contain the metric name
+   save_weights_only: False # if True, then only the model’s weights will be saved
+   every_n_train_steps: null # number of training steps between checkpoints
+   train_time_interval: null # checkpoints are monitored at the specified time interval
+   every_n_epochs: 20 # number of epochs between checkpoints
+   save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
configs/callbacks/model_summary.yaml ADDED
@@ -0,0 +1,5 @@
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
+
+ model_summary:
+   _target_: lightning.pytorch.callbacks.RichModelSummary
+   max_depth: 1 # the maximum depth of layer nesting that the summary will include
configs/callbacks/none.yaml ADDED
File without changes
configs/callbacks/rich_progress_bar.yaml ADDED
@@ -0,0 +1,18 @@
+ # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
+
+ rich_progress_bar:
+   _target_: lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar
+   refresh_rate: 1
+   leave: false
+   theme:
+     _target_: lightning.pytorch.callbacks.progress.rich_progress.RichProgressBarTheme
+     description: green_yellow
+     progress_bar: green1
+     progress_bar_finished: green1
+     progress_bar_pulse: "#6206E0"
+     batch_progress: green_yellow
+     time: blue
+     processing_speed: cyan
+     metrics: grey82
+     metrics_text_delimiter: " "
+     metrics_format: .4g
configs/data/multiview_dust3r.yaml ADDED
@@ -0,0 +1,25 @@
+ # Define the common data root and number of views
+ data_root: /path/to/dust3r_data
+ num_views: 4
+ num_views_val: 10
+
+ data_module:
+   _target_: stream3r.data.multiview_dust3r_datamodule.MultiViewDUSt3RDataModule
+   train_datasets:
+     - 80_000 @ Co3d_Multiview(split='train', num_views=${data.num_views}, window_degree_range=360, num_samples_per_window=100, ROOT='${data.data_root}/co3d_50_seqs_per_category_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)
+     - 80_000 @ MegaDepth_Multiview(split='train', num_views=${data.num_views}, window_size=${python_eval:"${data.num_views} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/megadepth_processed', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)
+     - 80_000 @ ScanNetpp_Multiview(split='train', num_views=${data.num_views}, window_size=${python_eval:"${data.num_views} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/scannetpp_processed', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)
+     - 80_000 @ ARKitScenes_Multiview(split='train', num_views=${data.num_views}, num_samples_per_window=10, ROOT='${data.data_root}/arkitscenes_processed', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)
+     - 80_000 @ Habitat_Multiview(1_000_000, split='train', num_views=${data.num_views}, ROOT='${data.data_root}/habitat_processed', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)
+   validation_datasets:
+     - 100 @ Co3d_Multiview(split='test', num_views=${data.num_views_val}, window_degree_range=360, num_samples_per_window=100, ROOT='${data.data_root}/co3d_50_seqs_per_category_subset_processed', resolution=(512, 384), seed=777)
+     - 100 @ MegaDepth_Multiview(split='val', num_views=${data.num_views_val}, window_size=${python_eval:"${data.num_views_val} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/megadepth_processed', resolution=(512, 336), seed=777)
+     - 100 @ ScanNetpp_Multiview(split='train', num_views=${data.num_views_val}, window_size=${python_eval:"${data.num_views_val} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/scannetpp_processed', resolution=(512, 384), seed=777)
+     - 100 @ ARKitScenes_Multiview(split='train', num_views=${data.num_views_val}, num_samples_per_window=10, ROOT='${data.data_root}/arkitscenes_processed', resolution=(512, 384), seed=777)
+     - 100 @ Habitat_Multiview(100_000, split='val', num_views=${data.num_views_val}, ROOT='${data.data_root}/habitat_processed', resolution=(512,384), seed=777)
+   batch_size_per_device: 6
+   batch_size_per_device_val: 4
+   num_workers: 6
+   pin_memory: True
configs/debug/ddp_debug.yaml ADDED
@@ -0,0 +1,48 @@
+ # @package _global_
+
+ #use a smaller dataset for faster initializations
+ defaults:
+   - override /data: multiview_dust3r_tiny
+   - override /logger:
+       - csv
+       - wandb
+
+ # overwrite task name so debugging logs are stored in separate folder
+ task_name: "debug"
+
+ logger:
+   wandb:
+     name: ${paths.run_folder_name}
+
+ # ckpt_path: /some/random/path
+
+ extras:
+   ignore_warnings: False
+   enforce_tags: False
+
+ # sets level of all command line loggers to 'DEBUG'
+ # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
+ hydra:
+   job_logging:
+     root:
+       level: DEBUG
+
+ model:
+   net:
+     random_image_idx_embedding: true
+
+ data:
+   num_views: 4
+   data_module:
+     num_workers: 0 # debuggers don't like multiprocessing
+     pin_memory: false # disable gpu memory pin
+     batch_size_per_device: 6
+
+ trainer:
+   log_every_n_steps: 1
+   devices: auto
+   # fast_dev_run: 1
+   limit_train_batches: 1
+   limit_val_batches: 10000
+   precision: 32
configs/debug/default.yaml ADDED
@@ -0,0 +1,35 @@
+ # @package _global_
+
+ # default debugging setup, runs 1 full epoch
+ # other debugging configs can inherit from this one
+
+ # overwrite task name so debugging logs are stored in separate folder
+ task_name: "debug"
+
+ # disable callbacks and loggers during debugging
+ callbacks: null
+ logger: null
+
+ extras:
+   ignore_warnings: False
+   enforce_tags: False
+
+ # sets level of all command line loggers to 'DEBUG'
+ # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
+ hydra:
+   job_logging:
+     root:
+       level: DEBUG
+
+ # use this to also set hydra loggers to 'DEBUG'
+ # verbose: True
+
+ trainer:
+   max_epochs: 1
+   accelerator: cpu # debuggers don't like gpus
+   devices: 1 # debuggers don't like multiprocessing
+   detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
+
+ data:
+   num_workers: 0 # debuggers don't like multiprocessing
+   pin_memory: False # disable gpu memory pin
configs/debug/fdr.yaml ADDED
@@ -0,0 +1,9 @@
+ # @package _global_
+
+ # runs 1 train, 1 validation and 1 test step
+
+ defaults:
+   - default
+
+ trainer:
+   fast_dev_run: true
configs/debug/limit.yaml ADDED
@@ -0,0 +1,12 @@
+ # @package _global_
+
+ # uses only 1% of the training data and 5% of validation/test data
+
+ defaults:
+   - default
+
+ trainer:
+   max_epochs: 3
+   limit_train_batches: 0.01
+   limit_val_batches: 0.05
+   limit_test_batches: 0.05
configs/debug/overfit.yaml ADDED
@@ -0,0 +1,13 @@
+ # @package _global_
+
+ # overfits to 3 batches
+
+ defaults:
+   - default
+
+ trainer:
+   max_epochs: 20
+   overfit_batches: 3
+
+ # model ckpt and early stopping need to be disabled during overfitting
+ callbacks: null
configs/debug/profiler.yaml ADDED
@@ -0,0 +1,12 @@
+ # @package _global_
+
+ # runs with execution time profiling
+
+ defaults:
+   - default
+
+ trainer:
+   max_epochs: 1
+   profiler: "simple"
+   # profiler: "advanced"
+   # profiler: "pytorch"
configs/eval.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - _self_
+   - data: multiview_dust3r
+   - model: stream3r
+   - logger: many_loggers
+   - trainer: ddp_eval
+   - paths: default
+   - extras: default
+   - hydra: default
+   - eval: default
+
+ task_name: "eval"
+
+ tags: ["eval"]
+
+ # passing checkpoint path is necessary for evaluation
+ ckpt_path: ???
configs/experiment/stream3r/stream3r.yaml ADDED
@@ -0,0 +1,125 @@
+ # @package _global_
+
+ defaults:
+   - override /model: stream3r
+
+ # seed for random number generators in pytorch, numpy and python.random
+ seed: 42
+
+
+ tags: ["train", "stream3r"]
+
+ task_name: stream3r
+ slurm_job_id: 99999 # must set in the command line
+
+ # ckpt_path: /path/to/resume.ckpt # uncomment to resume training from a checkpoint
+
+ paths:
+   run_folder_name: ${task_name}_${slurm_job_id}
+
+ logger:
+   wandb:
+     name: ${task_name}_${slurm_job_id}
+     project: stream3r
+
+ data:
+   data_scaling: 1.0
+   data_root: /data
+   num_views: 24
+   resolution:
+     - [518, 392]
+     - [518, 378]
+     - [518, 336]
+     - [518, 294]
+     - [518, 252]
+     - [518, 210]
+     - [518, 140]
+     - [378, 518]
+     - [336, 518]
+     - [294, 518]
+     - [252, 518]
+     - [224, 224]
+   allow_repeat: true
+   n_corres_train: 0
+   data_module:
+     _target_: stream3r.data.multiview_dust3r_datamodule.MultiViewDUSt3RDataModule
+     pin_memory: true
+     num_workers: 16
+     num_workers_val: 1 # have to be a low number when using DeepSpeed ZeRO-2
+     batch_size_per_device: 1
+     batch_size_per_device_val: 1
+     train_datasets:
+       - 44800 @ Co3d_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT='${data.data_root}/processed_co3d', aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 56000 @ WildRGBD_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_wildrgbd_mp", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 22400 @ ARKitScenesHighRes_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_arkitscene_highres", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 38400 @ ScanNet_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_scannet/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 16800 @ ScanNetpp_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_scannetpp/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 84000 @ MapFree_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_mapfree/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 20000 @ Waymo_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_waymo/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 56000 @ TartanAir_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_tartanair/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 9400 @ Spring(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_spring/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 36000 @ BEDLAM_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_bedlam/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 28800 @ MP3D_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_mp3d/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 14400 @ UASOL_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_uasol/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 1400 @ MVS_Synth_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_mvs_synth", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 7200 @ PointOdyssey_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_pointodyssey", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 11200 @ HyperSim_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_hypersim_new", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 22400 @ BlendedMVS_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_blendedmvs/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 22400 @ MegaDepth_Multi(allow_repeat=${data.allow_repeat}, split="train", ROOT="${data.data_root}/processed_megadepth", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 5600 @ VirtualKITTI2_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_vkitti", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 168 @ UnReal4K_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_unreal4k/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 74000 @ DL3DV_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_dl3dv_ours_parts/processed_dl3dv_ours", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
+       - 36000 @ DynamicReplica(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_dynamic_replica/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train})
+
+ model:
+   pretrained: weights/vggt/model.pt
+   net:
+     freeze: encoder
+
+   scheduler:
+     warmup_start_lr: 1e-6
+     warmup_epochs: 1
+
+   train_criterion:
+     _target_: stream3r.loss.losses.CausalLoss
+     gradient_loss: grad
+     is_metric: false
+
+   validation_criterion:
+     _target_: stream3r.loss.losses.CausalLoss
+     gradient_loss: grad
+     is_metric: false
+
+   optimizer:
+     _target_: torch.optim.AdamW
+     _partial_: true
+     lr: 1e-5
+     betas:
+       - 0.9
+       - 0.95
+     weight_decay: 0.05
+
+ trainer:
+   devices: auto
+   max_epochs: 500
+   accumulate_grad_batches: 4
+   strategy:
+     _target_: lightning.pytorch.strategies.DeepSpeedStrategy
+     timeout:
+       _target_: datetime.timedelta
+       minutes: 80
+   plugins: null
+   limit_val_batches: 0
+   precision: bf16-mixed
+   log_every_n_steps: 20
+
+ callbacks:
+   model_checkpoint:
+     every_n_train_steps: 2000
+     every_n_epochs: null
+     save_top_k: -1
+     filename: "{epoch:03d}-{step:08d}"
+     save_last: false
+     monitor: "train/loss"
+   early_stopping:
+     monitor: "train/loss"
configs/extras/default.yaml ADDED
@@ -0,0 +1,8 @@
+ # disable python warnings if they annoy you
+ ignore_warnings: False
+
+ # ask user for tags if none are provided in the config
+ enforce_tags: True
+
+ # pretty print config tree at the start of the run using Rich library
+ print_config: True
configs/hparams_search/mnist_optuna.yaml ADDED
@@ -0,0 +1,52 @@
+ # @package _global_
+
+ # example hyperparameter optimization of some experiment with Optuna:
+ # python train.py -m hparams_search=mnist_optuna experiment=example
+
+ defaults:
+   - override /hydra/sweeper: optuna
+
+ # choose metric which will be optimized by Optuna
+ # make sure this is the correct name of some metric logged in lightning module!
+ optimized_metric: "val/acc_best"
+
+ # here we define Optuna hyperparameter search
+ # it optimizes for value returned from function with @hydra.main decorator
+ # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
+ hydra:
+   mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
+
+   sweeper:
+     _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
+
+     # storage URL to persist optimization results
+     # for example, you can use SQLite if you set 'sqlite:///example.db'
+     storage: null
+
+     # name of the study to persist optimization results
+     study_name: null
+
+     # number of parallel workers
+     n_jobs: 1
+
+     # 'minimize' or 'maximize' the objective
+     direction: maximize
+
+     # total number of runs that will be executed
+     n_trials: 20
+
+     # choose Optuna hyperparameter sampler
+     # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
+     # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
+     sampler:
+       _target_: optuna.samplers.TPESampler
+       seed: 1234
+       n_startup_trials: 10 # number of random sampling runs before optimization starts
+
+     # define hyperparameter search space
+     params:
+       model.optimizer.lr: interval(0.0001, 0.1)
+       data.batch_size: choice(32, 64, 128, 256)
+       model.net.lin1_size: choice(64, 128, 256)
+       model.net.lin2_size: choice(64, 128, 256)
+       model.net.lin3_size: choice(32, 64, 128, 256)
configs/hydra/default.yaml ADDED
@@ -0,0 +1,19 @@
+ # https://hydra.cc/docs/configure_hydra/intro/
+
+ # enable color logging
+ defaults:
+   - override hydra_logging: colorlog
+   - override job_logging: colorlog
+
+ # output directory, generated dynamically on each run
+ run:
+   dir: ${paths.log_dir}/${task_name}/runs/${paths.run_folder_name}
+ sweep:
+   dir: ${paths.log_dir}/${task_name}/multiruns/${paths.run_folder_name}
+   subdir: ${hydra.job.num}
+
+ job_logging:
+   handlers:
+     file:
+       # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
+       filename: ${hydra.runtime.output_dir}/${task_name}.log
configs/hydra/launcher/fair_a100.yaml ADDED
@@ -0,0 +1,43 @@
+ defaults:
+   - submitit_slurm
+
+ # see: https://github.com/facebookresearch/hydra/blob/main/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/config.py
+ _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher
+ submitit_folder: ${hydra.sweep.dir}/.submitit/%j
+ name: ${hydra.job.name}
+ timeout_min: 20160 # 14 days : 60 * 24 * 14
+ account: cortex
+ qos: cortex_high
+ comment: "multiview_dust3r experiment"
+ nodes: 1
+ gres: "gpu:8"
+ tasks_per_node: 8
+ cpus_per_task: 12
+ signal_delay_s: 120 # USR1 signal delay (seconds) before timeout
+ max_num_timeout: 0 # number of times the job can be restarted after timeout
+ array_parallelism: 256 # Maximum number of jobs running in parallel
+
+ # Useful to add parameters which are not currently available in the plugin.
+ # Eg: {"mail-user": "blublu@fb.com", "mail-type": "BEGIN"}
+ additional_parameters:
+   mail-user: "jianingy@meta.com"
+   mail-type: "BEGIN,END"
+   output: "/path/to/slurm_out/%x-%j.out"
+
+ setup: # A list of commands to run in sbatch befure running srun
+   - echo "Begin setting up env on head node ($HOSTNAME)..."
+   - echo $(env | grep SLURM)
+   - export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+   - export MASTER_PORT=9929
+   - export RDZV_ID=$SLURM_JOBID
+   - export OMP_NUM_THREADS=12
+   - . /path/to/miniforge3/etc/profile.d/conda.sh # activate conda
+   - conda activate dust3r
+   - cd /path/to/project # cd to the project directory
+   - export NCCL_DEBUG=INFO
+   - export PYTHONFAULTHANDLER=1
+   - export TORCH_DISTRIBUTED_DEBUG=INFO
+   - echo "env setup on head node ($HOSTNAME) finished, starting srun..."
+
+ srun_args:
+   - "--cpu-bind=none" # This is critical to ensure dataloaders uses all CPUs!
configs/local/.gitkeep ADDED
File without changes
configs/logger/aim.yaml ADDED
@@ -0,0 +1,28 @@
+ # https://aimstack.io/
+
+ # example usage in lightning module:
+ # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
+
+ # open the Aim UI with the following command (run in the folder containing the `.aim` folder):
+ # `aim up`
+
+ aim:
+   _target_: aim.pytorch_lightning.AimLogger
+   repo: ${paths.root_dir} # .aim folder will be created here
+   # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
+
+   # aim allows to group runs under experiment name
+   experiment: null # any string, set to "default" if not specified
+
+   train_metric_prefix: "train/"
+   val_metric_prefix: "val/"
+   test_metric_prefix: "test/"
+
+   # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
+   system_tracking_interval: 10 # set to null to disable system metrics tracking
+
+   # enable/disable logging of system params such as installed packages, git info, env vars, etc.
+   log_system_params: true
+
+   # enable/disable tracking console logs (default value is true)
+   capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
configs/logger/comet.yaml ADDED
@@ -0,0 +1,12 @@
+ # https://www.comet.ml
+
+ comet:
+   _target_: lightning.pytorch.loggers.comet.CometLogger
+   api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
+   save_dir: "${paths.output_dir}"
+   project_name: "lightning-hydra-template"
+   rest_api_key: null
+   # experiment_name: ""
+   experiment_key: null # set to resume experiment
+   offline: False
+   prefix: ""
configs/logger/csv.yaml ADDED
@@ -0,0 +1,7 @@
+ # csv logger built in lightning
+
+ csv:
+   _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
+   save_dir: "${paths.output_dir}"
+   name: "csv/"
+   prefix: ""
configs/logger/many_loggers.yaml ADDED
@@ -0,0 +1,9 @@
+ # train with many loggers at once
+
+ defaults:
+   # - comet
+   - csv
+   # - mlflow
+   # - neptune
+   - tensorboard
+   - wandb
configs/logger/mlflow.yaml ADDED
@@ -0,0 +1,12 @@
+ # https://mlflow.org
+
+ mlflow:
+   _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
+   # experiment_name: ""
+   # run_name: ""
+   tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
+   tags: null
+   # save_dir: "./mlruns"
+   prefix: ""
+   artifact_location: null
+   # run_id: ""
configs/logger/neptune.yaml ADDED
@@ -0,0 +1,9 @@
+ # https://neptune.ai
+
+ neptune:
+   _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
+   api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
+   project: username/lightning-hydra-template
+   # name: ""
+   log_model_checkpoints: True
+   prefix: ""
configs/logger/tensorboard.yaml ADDED
@@ -0,0 +1,10 @@
+ # https://www.tensorflow.org/tensorboard/
+
+ tensorboard:
+   _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+   save_dir: "${paths.output_dir}/tensorboard/"
+   name: null
+   log_graph: False
+   default_hp_metric: True
+   prefix: ""
+   # version: ""
configs/logger/wandb.yaml ADDED
@@ -0,0 +1,16 @@
+ # https://wandb.ai
+
+ wandb:
+   _target_: lightning.pytorch.loggers.wandb.WandbLogger
+   name: null # name of the run (normally generated by wandb)
+   save_dir: "${paths.output_dir}"
+   offline: False
+   id: null # pass correct id to resume experiment!
+   anonymous: null # enable anonymous logging
+   project: "stream3r"
+   log_model: False # upload lightning ckpts
+   prefix: "" # a string to put at the beginning of metric keys
+   # entity: "" # set to name of your wandb team
+   group: ""
+   tags: []
+   job_type: ""
configs/model/stream3r.yaml ADDED
@@ -0,0 +1,42 @@
+ _target_: stream3r.models.multiview_dust3r_module.MultiViewDUSt3RLitModule
+
+ pretrained: null
+ resume_from_checkpoint: ${ckpt_path}
+
+ eval_use_pts3d_from_local_head: true
+
+ train_criterion:
+   _target_: stream3r.loss.losses.CausalLoss
+
+ validation_criterion:
+   _target_: stream3r.loss.losses.CausalLoss
+
+ optimizer:
+   _target_: torch.optim.AdamW
+   _partial_: true
+   lr: 1e-4
+   betas:
+     - 0.9
+     - 0.95
+   weight_decay: 0.05
+
+ # scheduler:
+ #   _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ #   _partial_: true
+ #   mode: min
+ #   factor: 0.1
+ #   patience: 10
+
+ scheduler:
+   _target_: pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR
+   _partial_: true
+   warmup_epochs: 10
+   max_epochs: ${trainer.max_epochs}
+   eta_min: 1e-06
+
+ net:
+   _target_: stream3r.models.stream3r.STream3R
+
+
+ # compile model for faster training with pytorch 2.0
+ compile: false
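
The optimizer and scheduler blocks above use Hydra's `_partial_: true`, so instantiating them yields `functools.partial` objects that the LightningModule can later bind to the network's parameters. A minimal sketch of that standard Hydra pattern (the tiny config and `torch.nn.Linear` stand-in below are illustrative, not code from this commit):

```python
import torch
from omegaconf import OmegaConf
from hydra.utils import instantiate

# Illustrative config mirroring the optimizer block above.
cfg = OmegaConf.create({
    "optimizer": {
        "_target_": "torch.optim.AdamW",
        "_partial_": True,
        "lr": 1e-4,
        "betas": [0.9, 0.95],
        "weight_decay": 0.05,
    }
})

model = torch.nn.Linear(8, 8)  # stand-in for the STream3R network
opt_partial = instantiate(cfg.optimizer)        # functools.partial(AdamW, lr=1e-4, ...)
optimizer = opt_partial(params=model.parameters())
print(type(optimizer).__name__)                 # AdamW
```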
configs/paths/default.yaml ADDED
@@ -0,0 +1,21 @@
+ # path to root directory
+ # this requires PROJECT_ROOT environment variable to exist
+ # you can replace it with "." if you want the root to be the current working directory
+ # root_dir: R${oc.env:PROJECT_ROOT}
+ root_dir: .
+
+ # path to data directory
+ data_dir: ${paths.root_dir}/data/
+
+ # path to logging directory
+ log_dir: ${paths.root_dir}/logs/
+
+ # path to output directory, created dynamically by hydra
+ # path generation pattern is specified in `configs/hydra/default.yaml`
+ # use it to store all files generated during the run, like ckpts and metrics
+ output_dir: ${hydra:runtime.output_dir}
+
+ # path to working directory
+ work_dir: ${hydra:runtime.cwd}
+
+ run_folder_name: ${now:%Y-%m-%d}_${now:%H-%M-%S}
configs/train.yaml ADDED
@@ -0,0 +1,49 @@
+ # @package _global_
+
+ # specify here default configuration
+ # order of defaults determines the order in which configs override each other
+ defaults:
+   - _self_
+   - data: multiview_dust3r
+   - model: stream3r
+   - callbacks: default
+   - logger: many_loggers # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+   - trainer: ddp
+   - paths: default
+   - extras: default
+   - hydra: default
+
+   # experiment configs allow for version control of specific hyperparameters
+   # e.g. best hyperparameters for given model and datamodule
+   - experiment: null
+
+   # config for hyperparameter optimization
+   - hparams_search: null
+
+   # optional local config for machine/user specific settings
+   # it's optional since it doesn't need to exist and is excluded from version control
+   - optional local: default
+
+   # debugging config (enable through command line, e.g. `python train.py debug=default)
+   - debug: null
+
+ # task name, determines output directory path
+ task_name: "train"
+
+ # tags to help you identify your experiments
+ # you can overwrite this in experiment configs
+ # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
+ tags: ["dev"]
+
+ # set False to skip model training
+ train: True
+
+ # evaluate on test set, using best model weights achieved during training
+ # lightning chooses best weights based on the metric specified in checkpoint callback
+ test: True
+
+ # simply provide checkpoint path to resume training
+ ckpt_path: null
+
+ # seed for random number generators in pytorch, numpy and python.random
+ seed: 42
configs/trainer/cpu.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - default
+
+ accelerator: cpu
+ devices: 1
configs/trainer/ddp.yaml ADDED
@@ -0,0 +1,12 @@
+ defaults:
+   - default
+
+ # strategy: ddp
+ strategy: ddp_find_unused_parameters_true
+
+ accelerator: gpu
+ devices: auto
+ num_nodes: 1
+ sync_batchnorm: true
+
+ use_distributed_sampler: false
configs/trainer/ddp_eval.yaml ADDED
@@ -0,0 +1,16 @@
+ defaults:
+   - default
+
+ # strategy: ddp
+ strategy:
+   _target_: lightning.pytorch.strategies.DDPStrategy
+   timeout:
+     _target_: datetime.timedelta
+     minutes: 30
+
+ accelerator: gpu
+ devices: auto
+ num_nodes: 1
+ sync_batchnorm: true
+
+ use_distributed_sampler: false
configs/trainer/ddp_sim.yaml ADDED
@@ -0,0 +1,7 @@
+ defaults:
+   - default
+
+ # simulate DDP on CPU, useful for debugging
+ accelerator: cpu
+ devices: 2
+ strategy: ddp_spawn
configs/trainer/deepspeed_stage_2.yaml ADDED
@@ -0,0 +1,9 @@
+ defaults:
+   - default
+
+ # strategy: deepspeed_stage_2
+ strategy: deepspeed_stage_2
+
+ accelerator: gpu
+ devices: auto
+ num_nodes: 1
configs/trainer/default.yaml ADDED
@@ -0,0 +1,30 @@
+ _target_: lightning.pytorch.trainer.Trainer
+ _convert_: partial
+
+ default_root_dir: ${paths.output_dir}
+
+ min_epochs: 1 # prevents early stopping
+ max_epochs: 100
+
+ accelerator: cpu
+ devices: 1
+
+ # mixed precision for extra speed-up
+ # precision: 16
+
+ # perform a validation loop every N training epochs
+ check_val_every_n_epoch: 1
+
+ # set True to to ensure deterministic results
+ # makes training slower but gives more reproducibility than just setting seeds
+ deterministic: False
+
+ plugins:
+   - _target_: lightning.pytorch.plugins.environments.SLURMEnvironment
+     auto_requeue: true # auto-resubmit the job when it is preempted by slurm
+     requeue_signal: ${python_eval:"signal.SIGUSR1"} # singal code is platform dependent, so it has to be decided at runtime
+     # requeue_signal:
+     #   _target_: signal.Signals
+     #   _args_:
+     #     - 10 # SIGUSR1, see: https://chromium.googlesource.com/chromiumos/docs/+/master/constants/signals.md
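
The `${python_eval:...}` interpolation used above (and in the data configs) is not a built-in Hydra or OmegaConf resolver; the repo presumably registers a custom resolver with that name before composing configs. A minimal sketch of how such a resolver could be registered (the `eval`-based implementation is an assumption, not code from this commit):

```python
import signal  # imported so expressions like "signal.SIGUSR1" can resolve
from omegaconf import OmegaConf

# Hypothetical registration of the "python_eval" resolver referenced in the configs.
if not OmegaConf.has_resolver("python_eval"):
    OmegaConf.register_new_resolver("python_eval", lambda expr: eval(expr))

cfg = OmegaConf.create({"requeue_signal": '${python_eval:"signal.SIGUSR1"}'})
print(cfg.requeue_signal)  # Signals.SIGUSR1
```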
configs/trainer/gpu.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - default
+
+ accelerator: gpu
+ devices: 1
configs/trainer/mps.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - default
+
+ accelerator: mps
+ devices: 1
eval/monodepth/eval_metrics.py ADDED
@@ -0,0 +1,211 @@
+ import os
+ import sys
+
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+ from eval.monodepth.tools import depth_evaluation
+ import numpy as np
+ import json
+ from tqdm import tqdm
+ import glob
+ import cv2
+ from eval.monodepth.metadata import dataset_metadata
+ import argparse
+ from PIL import Image
+
+ TAG_FLOAT = 202021.25
+
+
+ def depth_read_sintel(filename):
+     """Read depth data from file, return as numpy array."""
+     f = open(filename, "rb")
+     check = np.fromfile(f, dtype=np.float32, count=1)[0]
+     assert (
+         check == TAG_FLOAT
+     ), " depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format(
+         TAG_FLOAT, check
+     )
+     width = np.fromfile(f, dtype=np.int32, count=1)[0]
+     height = np.fromfile(f, dtype=np.int32, count=1)[0]
+     size = width * height
+     assert (
+         width > 0 and height > 0 and size > 1 and size < 100000000
+     ), " depth_read:: Wrong input size (width = {0}, height = {1}).".format(
+         width, height
+     )
+     depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width))
+     return depth
+
+
+ def depth_read_bonn(filename):
+     # loads depth map D from png file
+     # and returns it as a numpy array
+     depth_png = np.asarray(Image.open(filename))
+     # make sure we have a proper 16bit depth map here.. not 8bit!
+     assert np.max(depth_png) > 255
+     depth = depth_png.astype(np.float64) / 5000.0
+     depth[depth_png == 0] = -1.0
+     return depth
+
+
+ def depth_read_kitti(filename):
+     # loads depth map D from png file
+     # and returns it as a numpy array,
+     # for details see readme.txt
+     img_pil = Image.open(filename)
+     depth_png = np.array(img_pil, dtype=int)
+     # make sure we have a proper 16bit depth map here.. not 8bit!
+     assert np.max(depth_png) > 255
+
+     depth = depth_png.astype(float) / 256.0
+     depth[depth_png == 0] = -1.0
+     return depth
+
+
+ def get_gt_depth(filename, dataset):
+     if dataset == "sintel":
+         return depth_read_sintel(filename)
+     elif dataset == "bonn":
+         return depth_read_bonn(filename)
+     elif dataset == "kitti":
+         return depth_read_kitti(filename)
+     elif dataset == "nyu":
+         return np.load(filename)
+     else:
+         raise NotImplementedError
+
+
+ def get_args_parser():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="",
+         help="value for outdir",
+     )
+     parser.add_argument(
+         "--eval_dataset", type=str, default="nyu", choices=list(dataset_metadata.keys())
+     )
+     return parser
+
+
+ def main(args):
+     if args.eval_dataset == "nyu":
+         depth_pathes = glob.glob("data/nyu-v2/val/nyu_depths/*.npy")
+         depth_pathes = sorted(depth_pathes)
+         pred_pathes = glob.glob(
+             f"{args.output_dir}/*.npy"
+         )  # TODO: update the path to your prediction
+         pred_pathes = sorted(pred_pathes)
+     elif args.eval_dataset == "sintel":
+         pred_pathes = glob.glob(
+             f"{args.output_dir}/*/*.npy"
+         )  # TODO: update the path to your prediction
+         pred_pathes = sorted(pred_pathes)
+         full = len(pred_pathes) > 643
+         if full:
+             depth_pathes = glob.glob(f"data/sintel/training/depth/*/*.dpt")
+             depth_pathes = sorted(depth_pathes)
+         else:
+             seq_list = [
+                 "alley_2",
+                 "ambush_4",
+                 "ambush_5",
+                 "ambush_6",
+                 "cave_2",
+                 "cave_4",
+                 "market_2",
+                 "market_5",
+                 "market_6",
+                 "shaman_3",
+                 "sleeping_1",
+                 "sleeping_2",
+                 "temple_2",
+                 "temple_3",
+             ]
+             depth_pathes_folder = [
+                 f"data/sintel/training/depth/{seq}" for seq in seq_list
+             ]
+             depth_pathes = []
+             for depth_pathes_folder_i in depth_pathes_folder:
+                 depth_pathes += glob.glob(depth_pathes_folder_i + "/*.dpt")
+             depth_pathes = sorted(depth_pathes)
+     elif args.eval_dataset == "bonn":
+         seq_list = ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"]
+         img_pathes_folder = [
+             f"data/bonn/rgbd_bonn_dataset/rgbd_bonn_{seq}/rgb_110/*.png"
+             for seq in seq_list
+         ]
+         img_pathes = []
+         for img_pathes_folder_i in img_pathes_folder:
+             img_pathes += glob.glob(img_pathes_folder_i)
+         img_pathes = sorted(img_pathes)
+         depth_pathes_folder = [
+             f"data/bonn/rgbd_bonn_dataset/rgbd_bonn_{seq}/depth_110/*.png"
+             for seq in seq_list
+         ]
+         depth_pathes = []
+         for depth_pathes_folder_i in depth_pathes_folder:
+             depth_pathes += glob.glob(depth_pathes_folder_i)
+         depth_pathes = sorted(depth_pathes)
+         pred_pathes = glob.glob(
+             f"{args.output_dir}/*/*.npy"
+         )  # TODO: update the path to your prediction
+         pred_pathes = sorted(pred_pathes)
+     elif args.eval_dataset == "kitti":
+         depth_pathes = glob.glob(
+             "data/kitti/depth_selection/val_selection_cropped/groundtruth_depth_gathered/*/*.png"
+         )
+         depth_pathes = sorted(depth_pathes)
+         pred_pathes = glob.glob(
+             f"{args.output_dir}/*/*depth.npy"
+         )  # TODO: update the path to your prediction
+         pred_pathes = sorted(pred_pathes)
+     else:
+         raise NotImplementedError
+
+     gathered_depth_metrics = []
+     for idx in tqdm(range(len(depth_pathes))):
+         pred_depth = np.load(pred_pathes[idx])
+         gt_depth = get_gt_depth(depth_pathes[idx], args.eval_dataset)
+         pred_depth = cv2.resize(
+             pred_depth,
+             (gt_depth.shape[1], gt_depth.shape[0]),
+             interpolation=cv2.INTER_CUBIC,
+         )
+         if args.eval_dataset == "nyu":
+             depth_results, error_map, depth_predict, depth_gt = depth_evaluation(
+                 pred_depth, gt_depth, max_depth=None, lr=1e-3
+             )
+         elif args.eval_dataset == "sintel":
+             depth_results, error_map, depth_predict, depth_gt = depth_evaluation(
+                 pred_depth, gt_depth, max_depth=70, use_gpu=True, post_clip_max=70
+             )
+         elif args.eval_dataset == "bonn":
+             depth_results, error_map, depth_predict, depth_gt = depth_evaluation(
+                 pred_depth, gt_depth, max_depth=70, use_gpu=True
+             )
+         elif args.eval_dataset == "kitti":
+             depth_results, error_map, depth_predict, depth_gt = depth_evaluation(
+                 pred_depth, gt_depth, max_depth=None, use_gpu=True
+             )
+         gathered_depth_metrics.append(depth_results)
+
+     depth_log_path = os.path.join(args.output_dir, "metric.json")
+     average_metrics = {
+         key: np.average(
+             [metrics[key] for metrics in gathered_depth_metrics],
+             weights=[metrics["valid_pixels"] for metrics in gathered_depth_metrics],
+         )
+         for key in gathered_depth_metrics[0].keys()
+         if key != "valid_pixels"
+     }
+     print(f"{args.eval_dataset} - Average depth evaluation metrics:", average_metrics)
+     with open(depth_log_path, "w") as f:
+         f.write(json.dumps(average_metrics))
+
+
+ if __name__ == "__main__":
+     args = get_args_parser()
+     args = args.parse_args()
+     main(args)
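
`depth_evaluation` is imported from `eval/monodepth/tools.py`, whose body is cut off in this view. For orientation only: monodepth benchmarks of this kind typically align the prediction to the ground truth with a least-squares scale and shift over valid pixels before computing absolute relative error and the δ<1.25 accuracy. A self-contained sketch of that common procedure (an assumption, not the repo's actual implementation):

```python
import numpy as np

def align_and_score(pred: np.ndarray, gt: np.ndarray, max_depth=None):
    """Least-squares scale/shift alignment plus standard depth metrics (illustrative only)."""
    valid = gt > 0
    if max_depth is not None:
        valid &= gt < max_depth
    p, g = pred[valid].astype(np.float64), gt[valid].astype(np.float64)
    # Solve min_{s,t} || s * p + t - g ||^2 over the valid pixels.
    A = np.stack([p, np.ones_like(p)], axis=1)
    (s, t), *_ = np.linalg.lstsq(A, g, rcond=None)
    p_aligned = np.clip(s * p + t, 1e-6, None)
    abs_rel = np.mean(np.abs(p_aligned - g) / g)
    delta1 = np.mean(np.maximum(p_aligned / g, g / p_aligned) < 1.25)
    return {"Abs Rel": abs_rel, "delta < 1.25": delta1, "valid_pixels": int(valid.sum())}
```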
eval/monodepth/launch.py ADDED
@@ -0,0 +1,146 @@
+ import torch
+ import numpy as np
+ import matplotlib
+ import numpy as np
+ import cv2
+ import argparse
+ from pathlib import Path
+ from tqdm import tqdm
+ import os
+ import sys
+
+ from stream3r.models.stream3r import STream3R
+ from stream3r.dust3r.utils.device import collate_with_cat
+ from stream3r.dust3r.utils.image import load_images_for_eval as load_images
+ from stream3r.utils.utils import ImgDust3r2Stream3r
+
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+ from eval.monodepth.metadata import dataset_metadata
+
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ # avoid high cpu usage
+ os.environ["OMP_NUM_THREADS"] = "1"
+ os.environ["MKL_NUM_THREADS"] = "1"
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
+ os.environ["OPENBLAS_NUM_THREADS"] = "1"
+ torch.set_num_threads(1)
+ # ===========================================
+
+
+ def colorize_depth(depth: np.ndarray,
+                    mask: np.ndarray = None,
+                    normalize: bool = True,
+                    cmap: str = 'Spectral') -> np.ndarray:
+     if mask is None:
+         depth = np.where(depth > 0, depth, np.nan)
+     else:
+         depth = np.where((depth > 0) & mask, depth, np.nan)
+     disp = 1 / depth
+     if normalize:
+         min_disp, max_disp = np.nanquantile(disp,
+                                             0.001), np.nanquantile(disp, 0.99)
+         disp = (disp - min_disp) / (max_disp - min_disp)
+     colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], 0)
+     colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+     return colored
+
+
+ def get_args_parser():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--device",
+                         type=str,
+                         default="cuda",
+                         help="pytorch device")
+     parser.add_argument("--output_dir",
+                         type=str,
+                         default="",
+                         help="value for outdir")
+     parser.add_argument("--no_crop",
+                         type=bool,
+                         default=True,
+                         help="whether to crop input data")
+     parser.add_argument("--full_seq",
+                         type=bool,
+                         default=False,
+                         help="whether to use all seqs")
+     parser.add_argument("--seq_list", default=None)
+
+     parser.add_argument("--eval_dataset",
+                         type=str,
+                         default="nyu",
+                         choices=list(dataset_metadata.keys()))
+     return parser
+
+
+ def eval_mono_depth_estimation(args, model, device):
+     metadata = dataset_metadata.get(args.eval_dataset)
+     if metadata is None:
+         raise ValueError(f"Unknown dataset: {args.eval_dataset}")
+
+     img_path = metadata.get("img_path")
+     if "img_path_func" in metadata:
+         img_path = metadata["img_path_func"](args)
+
+     process_func = metadata.get("process_func")
+     if process_func is None:
+         raise ValueError(
+             f"No processing function defined for dataset: {args.eval_dataset}")
+
+     for filelist, save_dir in process_func(args, img_path):
+         Path(save_dir).mkdir(parents=True, exist_ok=True)
+         eval_mono_depth(args, model, device, filelist, save_dir=save_dir)
+
+
+ def eval_mono_depth(args, model, device, filelist, save_dir=None):
+     for file in tqdm(filelist):
+         file = [file]
+         images = load_images(
+             file,
+             size=518,
+             verbose=True,
+             crop=False,
+             patch_size=14,
+         )
+
+         images = collate_with_cat([tuple(images)])
+         images = torch.stack([view["img"] for view in images], dim=1)
+         images = ImgDust3r2Stream3r(images).to(device)
+
+         with torch.no_grad():
+             predictions = model(images)
+
+         depth_map = predictions['depth'][0,0].squeeze(-1).cpu()
+
+         if save_dir is not None:
+             # save the depth map to the save_dir as npy
+             np.save(
+                 f"{save_dir}/{file[0].split('/')[-1].replace('.png','depth.npy')}",
+                 depth_map.cpu().numpy(),
+             )
+             depth_map = colorize_depth(depth_map)
+             cv2.imwrite(
+                 f"{save_dir}/{file[0].split('/')[-1].replace('.png','depth.jpg')}",
+                 depth_map,
+             )
+
+
+ def main():
+     args = get_args_parser()
+     args = args.parse_args()
+
+     if args.eval_dataset == "sintel":
+         args.full_seq = True
+     else:
+         args.full_seq = False
+
+     model = STream3R.from_pretrained("yslan/STream3R").to(args.device)
+     model.eval()
+
+     eval_mono_depth_estimation(args, model, args.device)
+
+
+ if __name__ == "__main__":
+     main()
eval/monodepth/metadata.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ from tqdm import tqdm
4
+
5
+ # Define the merged dataset metadata dictionary
6
+ dataset_metadata = {
7
+ "sun_rgbd": {
8
+ "img_path": "data/sun_rgbd/image/test",
9
+ "mask_path": None,
10
+ },
11
+ "davis": {
12
+ "img_path": "data/davis/DAVIS/JPEGImages/480p",
13
+ "mask_path": "data/davis/DAVIS/masked_images/480p",
14
+ "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
15
+ "gt_traj_func": lambda img_path, anno_path, seq: None,
16
+ "traj_format": None,
17
+ "seq_list": None,
18
+ "full_seq": True,
19
+ "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq),
20
+ "skip_condition": None,
21
+ "process_func": None, # Not used in mono depth estimation
22
+ },
23
+ "kitti": {
24
+ "img_path": "data/kitti/depth_selection/val_selection_cropped/image_gathered", # Default path
25
+ "mask_path": None,
26
+ "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
27
+ "gt_traj_func": lambda img_path, anno_path, seq: None,
28
+ "traj_format": None,
29
+ "seq_list": None,
30
+ "full_seq": True,
31
+ "mask_path_seq_func": lambda mask_path, seq: None,
32
+ "skip_condition": None,
33
+ "process_func": lambda args, img_path: process_kitti(args, img_path),
34
+ },
35
+ "bonn": {
36
+ "img_path": "data/bonn/rgbd_bonn_dataset",
37
+ "mask_path": None,
38
+ "dir_path_func": lambda img_path, seq: os.path.join(
39
+ img_path, f"rgbd_bonn_{seq}", "rgb_110"
40
+ ),
41
+ "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
42
+ img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt"
43
+ ),
44
+ "traj_format": "tum",
45
+ "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
46
+ "full_seq": False,
47
+ "mask_path_seq_func": lambda mask_path, seq: None,
48
+ "skip_condition": None,
49
+ "process_func": lambda args, img_path: process_bonn(args, img_path),
50
+ },
51
+ "nyu": {
52
+ "img_path": "data/nyu-v2/val/nyu_images",
53
+ "mask_path": None,
54
+ "process_func": lambda args, img_path: process_nyu(args, img_path),
55
+ },
56
+ "scannet": {
57
+ "img_path": "data/scannetv2",
58
+ "mask_path": None,
59
+ "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
60
+ "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
61
+ img_path, seq, "pose_90.txt"
62
+ ),
63
+ "traj_format": "replica",
64
+ "seq_list": None,
65
+ "full_seq": True,
66
+ "mask_path_seq_func": lambda mask_path, seq: None,
67
+ "skip_condition": None, # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)),
68
+ "process_func": lambda args, img_path: process_scannet(args, img_path),
69
+ },
70
+ "tum": {
71
+ "img_path": "data/tum",
72
+ "mask_path": None,
73
+ "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"),
74
+ "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
75
+ img_path, seq, "groundtruth_90.txt"
76
+ ),
77
+ "traj_format": "tum",
78
+ "seq_list": None,
79
+ "full_seq": True,
80
+ "mask_path_seq_func": lambda mask_path, seq: None,
81
+ "skip_condition": None,
82
+ "process_func": None,
83
+ },
84
+ "sintel": {
85
+ "img_path": "data/sintel/training/final",
86
+ "anno_path": "data/sintel/training/camdata_left",
87
+ "mask_path": None,
88
+ "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
89
+ "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq),
90
+ "traj_format": None,
91
+ "seq_list": [
92
+ "alley_2",
93
+ "ambush_4",
94
+ "ambush_5",
95
+ "ambush_6",
96
+ "cave_2",
97
+ "cave_4",
98
+ "market_2",
99
+ "market_5",
100
+ "market_6",
101
+ "shaman_3",
102
+ "sleeping_1",
103
+ "sleeping_2",
104
+ "temple_2",
105
+ "temple_3",
106
+ ],
107
+ "full_seq": False,
108
+ "mask_path_seq_func": lambda mask_path, seq: None,
109
+ "skip_condition": None,
110
+ "process_func": lambda args, img_path: process_sintel(args, img_path),
111
+ },
112
+ }
113
+
114
+
115
+ # Define processing functions for each dataset
116
+ def process_kitti(args, img_path):
117
+ for dir in tqdm(sorted(glob.glob(f"{img_path}/*"))):
118
+ filelist = sorted(glob.glob(f"{dir}/*.png"))
119
+ save_dir = f"{args.output_dir}/{os.path.basename(dir)}"
120
+ yield filelist, save_dir
121
+
122
+
123
+ def process_bonn(args, img_path):
124
+ if args.full_seq:
125
+ for dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
126
+ filelist = sorted(glob.glob(f"{dir}/rgb/*.png"))
127
+ save_dir = f"{args.output_dir}/{os.path.basename(os.path.dirname(dir))}"
128
+ yield filelist, save_dir
129
+ else:
130
+ seq_list = (
131
+ ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"]
132
+ if args.seq_list is None
133
+ else args.seq_list
134
+ )
135
+ for seq in tqdm(seq_list):
136
+ filelist = sorted(glob.glob(f"{img_path}/rgbd_bonn_{seq}/rgb_110/*.png"))
137
+ save_dir = f"{args.output_dir}/{seq}"
138
+ yield filelist, save_dir
139
+
140
+
141
+ def process_sunrgbd(args, img_path):
142
+ filelist = sorted(glob.glob(f"{img_path}/*.jpg"))
143
+ save_dir = f"{args.output_dir}"
144
+ yield filelist, save_dir
145
+
146
+
147
+ def process_nyu(args, img_path):
148
+ filelist = sorted(glob.glob(f"{img_path}/*.png"))
149
+ save_dir = f"{args.output_dir}"
150
+ yield filelist, save_dir
151
+
152
+
153
+ def process_scannet(args, img_path):
154
+ seq_list = sorted(glob.glob(f"{img_path}/*"))
155
+ for seq in tqdm(seq_list):
156
+ filelist = sorted(glob.glob(f"{seq}/color_90/*.jpg"))
157
+ save_dir = f"{args.output_dir}/{os.path.basename(seq)}"
158
+ yield filelist, save_dir
159
+
160
+
161
+ def process_sintel(args, img_path):
162
+ if args.full_seq:
163
+ for dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
164
+ filelist = sorted(glob.glob(f"{dir}/*.png"))
165
+ save_dir = f"{args.output_dir}/{os.path.basename(os.path.dirname(dir))}"
166
+ yield filelist, save_dir
167
+ else:
168
+ seq_list = [
169
+ "alley_2",
170
+ "ambush_4",
171
+ "ambush_5",
172
+ "ambush_6",
173
+ "cave_2",
174
+ "cave_4",
175
+ "market_2",
176
+ "market_5",
177
+ "market_6",
178
+ "shaman_3",
179
+ "sleeping_1",
180
+ "sleeping_2",
181
+ "temple_2",
182
+ "temple_3",
183
+ ]
184
+ for seq in tqdm(seq_list):
185
+ filelist = sorted(glob.glob(f"{img_path}/{seq}/*.png"))
186
+ save_dir = f"{args.output_dir}/{seq}"
187
+ yield filelist, save_dir
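
Each process_* helper above is a generator that yields (filelist, save_dir) pairs, which lets a driver iterate every dataset through the same interface. A minimal consumption sketch (the args namespace and import path below are illustrative; the real CLI is defined in eval/monodepth/launch.py and may differ):

import argparse
from eval.monodepth.metadata import dataset_metadata  # the module defined above

def iter_sequences(dataset_name, args):
    # Yield (filelist, save_dir) for every sequence of the chosen dataset.
    meta = dataset_metadata[dataset_name]
    process_func = meta.get("process_func")
    if process_func is None:
        raise ValueError(f"{dataset_name} defines no per-sequence process_func")
    yield from process_func(args, meta["img_path"])

# Hypothetical arguments; launch.py defines the actual flags.
args = argparse.Namespace(
    output_dir="eval_results/monodepth/stream3r/bonn",
    full_seq=False,
    seq_list=None,
)
for filelist, save_dir in iter_sequences("bonn", args):
    print(save_dir, len(filelist), "frames")
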
eval/monodepth/run.sh ADDED
@@ -0,0 +1,20 @@
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ workdir='.'
5
+
6
+ datasets=('sintel' 'bonn' 'kitti' 'nyu')
7
+ model_name='stream3r'
8
+
9
+ for data in "${datasets[@]}"; do
10
+ output_dir="${workdir}/eval_results/monodepth/${model_name}/${data}"
11
+ echo "$output_dir"
12
+
13
+ python eval/monodepth/launch.py \
14
+ --output_dir="$output_dir" \
15
+ --eval_dataset="$data"
16
+
17
+ python eval/monodepth/eval_metrics.py \
18
+ --output_dir "$output_dir" \
19
+ --eval_dataset "$data"
20
+ done
eval/monodepth/tools.py ADDED
@@ -0,0 +1,399 @@
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ import glob
5
+ import argparse
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ from copy import deepcopy
9
+ from scipy.optimize import minimize
10
+ import os
11
+ from collections import defaultdict
12
+
13
+
14
+ def group_by_directory(pathes, idx=-1):
15
+ """
16
+ Groups file paths by one directory component of their parent path (with the default idx=-1, the immediate parent directory).
17
+
18
+ Parameters:
19
+ - pathes (list): List of file paths.
20
+
21
+ Returns:
22
+ - dict: A dictionary whose keys are the selected directory names and whose values are lists of file paths.
23
+ """
24
+ grouped_pathes = defaultdict(list)
25
+
26
+ for path in pathes:
27
+ # Select the grouping directory (the immediate parent directory when idx=-1)
28
+ dir_name = os.path.dirname(path).split("/")[idx]
29
+ grouped_pathes[dir_name].append(path)
30
+
31
+ return grouped_pathes
32
+
33
+
34
+ def depth2disparity(depth, return_mask=False):
35
+ if isinstance(depth, torch.Tensor):
36
+ disparity = torch.zeros_like(depth)
37
+ elif isinstance(depth, np.ndarray):
38
+ disparity = np.zeros_like(depth)
39
+ non_negtive_mask = depth > 0
40
+ disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask]
41
+ if return_mask:
42
+ return disparity, non_negtive_mask
43
+ else:
44
+ return disparity
45
+
46
+
47
+ def absolute_error_loss(params, predicted_depth, ground_truth_depth):
48
+ s, t = params
49
+
50
+ predicted_aligned = s * predicted_depth + t
51
+
52
+ abs_error = np.abs(predicted_aligned - ground_truth_depth)
53
+ return np.sum(abs_error)
54
+
55
+
56
+ def absolute_value_scaling(predicted_depth, ground_truth_depth, s=1, t=0):
57
+ predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1)
58
+ ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1)
59
+
60
+ initial_params = [s, t] # s = 1, t = 0
61
+
62
+ result = minimize(
63
+ absolute_error_loss,
64
+ initial_params,
65
+ args=(predicted_depth_np, ground_truth_depth_np),
66
+ )
67
+
68
+ s, t = result.x
69
+ return s, t
70
+
71
+
72
+ def absolute_value_scaling2(
73
+ predicted_depth,
74
+ ground_truth_depth,
75
+ s_init=1.0,
76
+ t_init=0.0,
77
+ lr=1e-4,
78
+ max_iters=1000,
79
+ tol=1e-6,
80
+ ):
81
+ # Initialize s and t as torch tensors with requires_grad=True
82
+ s = torch.tensor(
83
+ [s_init],
84
+ requires_grad=True,
85
+ device=predicted_depth.device,
86
+ dtype=predicted_depth.dtype,
87
+ )
88
+ t = torch.tensor(
89
+ [t_init],
90
+ requires_grad=True,
91
+ device=predicted_depth.device,
92
+ dtype=predicted_depth.dtype,
93
+ )
94
+
95
+ optimizer = torch.optim.Adam([s, t], lr=lr)
96
+
97
+ prev_loss = None
98
+
99
+ for i in range(max_iters):
100
+ optimizer.zero_grad()
101
+
102
+ # Compute predicted aligned depth
103
+ predicted_aligned = s * predicted_depth + t
104
+
105
+ # Compute absolute error
106
+ abs_error = torch.abs(predicted_aligned - ground_truth_depth)
107
+
108
+ # Compute loss
109
+ loss = torch.sum(abs_error)
110
+
111
+ # Backpropagate
112
+ loss.backward()
113
+
114
+ # Update parameters
115
+ optimizer.step()
116
+
117
+ # Check convergence
118
+ if prev_loss is not None and torch.abs(prev_loss - loss) < tol:
119
+ break
120
+
121
+ prev_loss = loss.item()
122
+
123
+ return s.detach().item(), t.detach().item()
124
+
125
+
126
+ def depth_evaluation(
127
+ predicted_depth_original,
128
+ ground_truth_depth_original,
129
+ max_depth=80,
130
+ custom_mask=None,
131
+ post_clip_min=None,
132
+ post_clip_max=None,
133
+ pre_clip_min=None,
134
+ pre_clip_max=None,
135
+ align_with_lstsq=False,
136
+ align_with_lad=False,
137
+ align_with_lad2=False,
138
+ metric_scale=False,
139
+ lr=1e-4,
140
+ max_iters=1000,
141
+ use_gpu=False,
142
+ align_with_scale=False,
143
+ disp_input=False,
144
+ ):
145
+ """
146
+ Evaluate the depth map using various metrics and return a depth error parity map, with an option for least squares alignment.
147
+
148
+ Args:
149
+ predicted_depth (numpy.ndarray or torch.Tensor): The predicted depth map.
150
+ ground_truth_depth (numpy.ndarray or torch.Tensor): The ground truth depth map.
151
+ max_depth (float): The maximum depth value to consider. Default is 80 meters.
152
+ align_with_lstsq (bool): If True, perform least squares alignment of the predicted depth with ground truth.
153
+
154
+ Returns:
155
+ dict: A dictionary containing the evaluation metrics.
156
+ torch.Tensor: The depth error parity map, followed by the aligned predicted depth map and the masked ground-truth depth map.
157
+ """
158
+ if isinstance(predicted_depth_original, np.ndarray):
159
+ predicted_depth_original = torch.from_numpy(predicted_depth_original)
160
+ if isinstance(ground_truth_depth_original, np.ndarray):
161
+ ground_truth_depth_original = torch.from_numpy(ground_truth_depth_original)
162
+ if custom_mask is not None and isinstance(custom_mask, np.ndarray):
163
+ custom_mask = torch.from_numpy(custom_mask)
164
+
165
+ # if the dimension is 3, flatten to 2d along the batch dimension
166
+ if predicted_depth_original.dim() == 3:
167
+ _, h, w = predicted_depth_original.shape
168
+ predicted_depth_original = predicted_depth_original.view(-1, w)
169
+ ground_truth_depth_original = ground_truth_depth_original.view(-1, w)
170
+ if custom_mask is not None:
171
+ custom_mask = custom_mask.view(-1, w)
172
+
173
+ # put to device
174
+ if use_gpu:
175
+ predicted_depth_original = predicted_depth_original.cuda()
176
+ ground_truth_depth_original = ground_truth_depth_original.cuda()
177
+
178
+ # Filter out depths greater than max_depth
179
+ if max_depth is not None:
180
+ mask = (ground_truth_depth_original > 0) & (
181
+ ground_truth_depth_original < max_depth
182
+ )
183
+ else:
184
+ mask = ground_truth_depth_original > 0
185
+ predicted_depth = predicted_depth_original[mask]
186
+ ground_truth_depth = ground_truth_depth_original[mask]
187
+
188
+ # Clip the depth values
189
+ if pre_clip_min is not None:
190
+ predicted_depth = torch.clamp(predicted_depth, min=pre_clip_min)
191
+ if pre_clip_max is not None:
192
+ predicted_depth = torch.clamp(predicted_depth, max=pre_clip_max)
193
+
194
+ if disp_input: # align the pred to gt in the disparity space
195
+ real_gt = ground_truth_depth.clone()
196
+ ground_truth_depth = 1 / (ground_truth_depth + 1e-8)
197
+
198
+ # various alignment methods
199
+ if metric_scale:
200
+ predicted_depth = predicted_depth
201
+ elif align_with_lstsq:
202
+ # Convert to numpy for lstsq
203
+ predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1, 1)
204
+ ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1, 1)
205
+
206
+ # Add a column of ones for the shift term
207
+ A = np.hstack([predicted_depth_np, np.ones_like(predicted_depth_np)])
208
+
209
+ # Solve for scale (s) and shift (t) using least squares
210
+ result = np.linalg.lstsq(A, ground_truth_depth_np, rcond=None)
211
+ s, t = result[0][0], result[0][1]
212
+
213
+ # convert to torch tensor
214
+ s = torch.tensor(s, device=predicted_depth_original.device)
215
+ t = torch.tensor(t, device=predicted_depth_original.device)
216
+
217
+ # Apply scale and shift
218
+ predicted_depth = s * predicted_depth + t
219
+ elif align_with_lad:
220
+ s, t = absolute_value_scaling(
221
+ predicted_depth,
222
+ ground_truth_depth,
223
+ s=torch.median(ground_truth_depth) / torch.median(predicted_depth),
224
+ )
225
+ predicted_depth = s * predicted_depth + t
226
+ elif align_with_lad2:
227
+ s_init = (
228
+ torch.median(ground_truth_depth) / torch.median(predicted_depth)
229
+ ).item()
230
+ s, t = absolute_value_scaling2(
231
+ predicted_depth,
232
+ ground_truth_depth,
233
+ s_init=s_init,
234
+ lr=lr,
235
+ max_iters=max_iters,
236
+ )
237
+ predicted_depth = s * predicted_depth + t
238
+ elif align_with_scale:
239
+ # Initialize the scale factor 's' from the ratio of mean depths; it is refined below with Weiszfeld iterations
240
+ dot_pred_gt = torch.nanmean(ground_truth_depth)
241
+ dot_pred_pred = torch.nanmean(predicted_depth)
242
+ s = dot_pred_gt / dot_pred_pred
243
+
244
+ # Iterative reweighted least squares using the Weiszfeld method
245
+ for _ in range(10):
246
+ # Compute residuals between scaled predictions and ground truth
247
+ residuals = s * predicted_depth - ground_truth_depth
248
+ abs_residuals = (
249
+ residuals.abs() + 1e-8
250
+ ) # Add small constant to avoid division by zero
251
+
252
+ # Compute weights inversely proportional to the residuals
253
+ weights = 1.0 / abs_residuals
254
+
255
+ # Update 's' using weighted sums
256
+ weighted_dot_pred_gt = torch.sum(
257
+ weights * predicted_depth * ground_truth_depth
258
+ )
259
+ weighted_dot_pred_pred = torch.sum(weights * predicted_depth**2)
260
+ s = weighted_dot_pred_gt / weighted_dot_pred_pred
261
+
262
+ # Optionally clip 's' to prevent extreme scaling
263
+ s = s.clamp(min=1e-3)
264
+
265
+ # Detach 's' if you want to stop gradients from flowing through it
266
+ s = s.detach()
267
+
268
+ # Apply the scale factor to the predicted depth
269
+ predicted_depth = s * predicted_depth
270
+
271
+ else:
272
+ # Align the predicted depth with the ground truth using median scaling
273
+ scale_factor = torch.median(ground_truth_depth) / torch.median(predicted_depth)
274
+ predicted_depth *= scale_factor
275
+
276
+ if disp_input:
277
+ # convert back to depth
278
+ ground_truth_depth = real_gt
279
+ predicted_depth = depth2disparity(predicted_depth)
280
+
281
+ # Clip the predicted depth values
282
+ if post_clip_min is not None:
283
+ predicted_depth = torch.clamp(predicted_depth, min=post_clip_min)
284
+ if post_clip_max is not None:
285
+ predicted_depth = torch.clamp(predicted_depth, max=post_clip_max)
286
+
287
+ if custom_mask is not None:
288
+ assert custom_mask.shape == ground_truth_depth_original.shape
289
+ mask_within_mask = custom_mask.cpu()[mask]
290
+ predicted_depth = predicted_depth[mask_within_mask]
291
+ ground_truth_depth = ground_truth_depth[mask_within_mask]
292
+
293
+ # Calculate the metrics
294
+ abs_rel = torch.mean(
295
+ torch.abs(predicted_depth - ground_truth_depth) / ground_truth_depth
296
+ ).item()
297
+ sq_rel = torch.mean(
298
+ ((predicted_depth - ground_truth_depth) ** 2) / ground_truth_depth
299
+ ).item()
300
+
301
+ # Correct RMSE calculation
302
+ rmse = torch.sqrt(torch.mean((predicted_depth - ground_truth_depth) ** 2)).item()
303
+
304
+ # Clip the depth values to avoid log(0)
305
+ predicted_depth = torch.clamp(predicted_depth, min=1e-5)
306
+ log_rmse = torch.sqrt(
307
+ torch.mean((torch.log(predicted_depth) - torch.log(ground_truth_depth)) ** 2)
308
+ ).item()
309
+
310
+ # Calculate the accuracy thresholds
311
+ max_ratio = torch.maximum(
312
+ predicted_depth / ground_truth_depth, ground_truth_depth / predicted_depth
313
+ )
314
+ threshold_0 = torch.mean((max_ratio < 1.0).float()).item()
315
+ threshold_1 = torch.mean((max_ratio < 1.25).float()).item()
316
+ threshold_2 = torch.mean((max_ratio < 1.25**2).float()).item()
317
+ threshold_3 = torch.mean((max_ratio < 1.25**3).float()).item()
318
+
319
+ # Compute the depth error parity map
320
+ if metric_scale:
321
+ predicted_depth_original = predicted_depth_original
322
+ if disp_input:
323
+ predicted_depth_original = depth2disparity(predicted_depth_original)
324
+ depth_error_parity_map = (
325
+ torch.abs(predicted_depth_original - ground_truth_depth_original)
326
+ / ground_truth_depth_original
327
+ )
328
+ elif align_with_lstsq or align_with_lad or align_with_lad2:
329
+ predicted_depth_original = predicted_depth_original * s + t
330
+ if disp_input:
331
+ predicted_depth_original = depth2disparity(predicted_depth_original)
332
+ depth_error_parity_map = (
333
+ torch.abs(predicted_depth_original - ground_truth_depth_original)
334
+ / ground_truth_depth_original
335
+ )
336
+ elif align_with_scale:
337
+ predicted_depth_original = predicted_depth_original * s
338
+ if disp_input:
339
+ predicted_depth_original = depth2disparity(predicted_depth_original)
340
+ depth_error_parity_map = (
341
+ torch.abs(predicted_depth_original - ground_truth_depth_original)
342
+ / ground_truth_depth_original
343
+ )
344
+ else:
345
+ predicted_depth_original = predicted_depth_original * scale_factor
346
+ if disp_input:
347
+ predicted_depth_original = depth2disparity(predicted_depth_original)
348
+ depth_error_parity_map = (
349
+ torch.abs(predicted_depth_original - ground_truth_depth_original)
350
+ / ground_truth_depth_original
351
+ )
352
+
353
+ # Reshape the depth_error_parity_map back to the original image size
354
+ depth_error_parity_map_full = torch.zeros_like(ground_truth_depth_original)
355
+ depth_error_parity_map_full = torch.where(
356
+ mask, depth_error_parity_map, depth_error_parity_map_full
357
+ )
358
+
359
+ predict_depth_map_full = predicted_depth_original
360
+ gt_depth_map_full = torch.zeros_like(ground_truth_depth_original)
361
+ gt_depth_map_full = torch.where(
362
+ mask, ground_truth_depth_original, gt_depth_map_full
363
+ )
364
+
365
+ num_valid_pixels = (
366
+ torch.sum(mask).item()
367
+ if custom_mask is None
368
+ else torch.sum(mask_within_mask).item()
369
+ )
370
+ if num_valid_pixels == 0:
371
+ (
372
+ abs_rel,
373
+ sq_rel,
374
+ rmse,
375
+ log_rmse,
376
+ threshold_0,
377
+ threshold_1,
378
+ threshold_2,
379
+ threshold_3,
380
+ ) = (0, 0, 0, 0, 0, 0, 0, 0)
381
+
382
+ results = {
383
+ "Abs Rel": abs_rel,
384
+ "Sq Rel": sq_rel,
385
+ "RMSE": rmse,
386
+ "Log RMSE": log_rmse,
387
+ "δ < 1.": threshold_0,
388
+ "δ < 1.25": threshold_1,
389
+ "δ < 1.25^2": threshold_2,
390
+ "δ < 1.25^3": threshold_3,
391
+ "valid_pixels": num_valid_pixels,
392
+ }
393
+
394
+ return (
395
+ results,
396
+ depth_error_parity_map_full,
397
+ predict_depth_map_full,
398
+ gt_depth_map_full,
399
+ )
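
A quick sanity check of the alignment logic in depth_evaluation is to pass a prediction that differs from the ground truth only by a scale and a shift; with align_with_lstsq=True the recovered metrics should then be near perfect. A self-contained sketch on synthetic tensors (assumes the repository root is on PYTHONPATH; not part of the evaluation pipeline):

import torch
from eval.monodepth.tools import depth_evaluation

# Synthetic ground-truth depth in roughly [1, 10] metres, two 64x64 frames.
gt = torch.rand(2, 64, 64) * 9.0 + 1.0
# Prediction differs from the ground truth by an affine transform: gt = 2.5 * pred + 0.3.
pred = (gt - 0.3) / 2.5

results, err_map, pred_aligned, gt_masked = depth_evaluation(
    pred, gt, max_depth=80, align_with_lstsq=True
)
print(results["Abs Rel"], results["δ < 1.25"])  # expect ~0.0 and ~1.0
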
eval/mv_recon/base.py ADDED
@@ -0,0 +1,274 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # base class for implementing datasets
6
+ # --------------------------------------------------------
7
+ import PIL
8
+ import numpy as np
9
+ import torch
10
+
11
+ from stream3r.dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
12
+
13
+ from eval.mv_recon.dataset_utils.transforms import ImgNorm
14
+ import eval.mv_recon.dataset_utils.cropping as cropping
15
+
16
+
17
+ class BaseStereoViewDataset:
18
+ """Define all basic options.
19
+
20
+ Usage:
21
+ class MyDataset (BaseStereoViewDataset):
22
+ def _get_views(self, idx, rng):
23
+ # overload here
24
+ views = []
25
+ views.append(dict(img=, ...))
26
+ return views
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ *, # only keyword arguments
32
+ split=None,
33
+ resolution=None, # square_size or (width, height) or list of [(width,height), ...]
34
+ transform=ImgNorm,
35
+ aug_crop=False,
36
+ seed=None,
37
+ ):
38
+ self.num_views = 2
39
+ self.split = split
40
+ self._set_resolutions(resolution)
41
+
42
+ if isinstance(transform, str):
43
+ transform = eval(transform)
44
+ self.transform = transform
45
+
46
+ self.aug_crop = aug_crop
47
+ self.seed = seed
48
+
49
+ def __len__(self):
50
+ return len(self.scenes)
51
+
52
+ def get_stats(self):
53
+ return f"{len(self)} pairs"
54
+
55
+ def __repr__(self):
56
+ resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
57
+ return (
58
+ f"""{type(self).__name__}({self.get_stats()},
59
+ {self.split=},
60
+ {self.seed=},
61
+ resolutions={resolutions_str},
62
+ {self.transform=})""".replace(
63
+ "self.", ""
64
+ )
65
+ .replace("\n", "")
66
+ .replace(" ", "")
67
+ )
68
+
69
+ def _get_views(self, idx, resolution, rng):
70
+ raise NotImplementedError()
71
+
72
+ def __getitem__(self, idx):
73
+ if isinstance(idx, tuple):
74
+ # the idx is specifying the aspect-ratio
75
+ idx, ar_idx = idx
76
+ else:
77
+ assert len(self._resolutions) == 1
78
+ ar_idx = 0
79
+
80
+ # set-up the rng
81
+ if self.seed: # reseed for each __getitem__
82
+ self._rng = np.random.default_rng(seed=self.seed + idx)
83
+ elif not hasattr(self, "_rng"):
84
+ seed = torch.initial_seed() # this is different for each dataloader process
85
+ self._rng = np.random.default_rng(seed=seed)
86
+
87
+ # over-loaded code
88
+ resolution = self._resolutions[
89
+ ar_idx
90
+ ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
91
+ views = self._get_views(idx, resolution, self._rng)
92
+
93
+ # check data-types
94
+ for v, view in enumerate(views):
95
+ assert (
96
+ "pts3d" not in view
97
+ ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
98
+ view["idx"] = v
99
+
100
+ # encode the image
101
+ width, height = view["img"].size
102
+ view["true_shape"] = np.int32((height, width))
103
+ view["img"] = self.transform(view["img"])
104
+
105
+ assert "camera_intrinsics" in view
106
+ if "camera_pose" not in view:
107
+ view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
108
+ else:
109
+ assert np.isfinite(
110
+ view["camera_pose"]
111
+ ).all(), f"NaN in camera pose for view {view_name(view)}"
112
+ assert "pts3d" not in view
113
+ assert "valid_mask" not in view
114
+ assert np.isfinite(
115
+ view["depthmap"]
116
+ ).all(), f"NaN in depthmap for view {view_name(view)}"
117
+ pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
118
+
119
+ view["pts3d"] = pts3d
120
+ view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
121
+
122
+ # check all datatypes
123
+ for key, val in view.items():
124
+ res, err_msg = is_good_type(key, val)
125
+ assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
126
+ K = view["camera_intrinsics"]
127
+ view["img_mask"] = True
128
+ view["ray_mask"] = False
129
+ view["ray_map"] = torch.full(
130
+ (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
131
+ )
132
+ view["update"] = True
133
+ view["reset"] = False
134
+
135
+ # last thing done!
136
+ for view in views:
137
+ # transpose to make sure all views are the same size
138
+ transpose_to_landscape(view)
139
+ # this allows checking whether the RNG is in the same state each time
140
+ view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
141
+ return views
142
+
143
+ def _set_resolutions(self, resolutions):
144
+ """Set the resolution(s) of the dataset.
145
+ Params:
146
+ - resolutions: int or tuple or list of tuples
147
+ """
148
+ assert resolutions is not None, "undefined resolution"
149
+
150
+ if not isinstance(resolutions, list):
151
+ resolutions = [resolutions]
152
+
153
+ self._resolutions = []
154
+ for resolution in resolutions:
155
+ if isinstance(resolution, int):
156
+ width = height = resolution
157
+ else:
158
+ width, height = resolution
159
+ assert isinstance(
160
+ width, int
161
+ ), f"Bad type for {width=} {type(width)=}, should be int"
162
+ assert isinstance(
163
+ height, int
164
+ ), f"Bad type for {height=} {type(height)=}, should be int"
165
+ assert width >= height
166
+ self._resolutions.append((width, height))
167
+
168
+ def _crop_resize_if_necessary(
169
+ self, image, depthmap, intrinsics, resolution, rng=None, info=None
170
+ ):
171
+ """This function:
172
+ - first downsizes the image with LANCZOS interpolation,
173
+ which is better than bilinear interpolation in terms of image quality, then crops it to the requested resolution.
174
+ """
175
+ if not isinstance(image, PIL.Image.Image):
176
+ image = PIL.Image.fromarray(image)
177
+
178
+ # downscale with lanczos interpolation so that image.size == resolution
179
+ # cropping centered on the principal point
180
+ W, H = image.size
181
+ cx, cy = intrinsics[:2, 2].round().astype(int)
182
+
183
+ # calculate min distance to margin
184
+ min_margin_x = min(cx, W - cx)
185
+ min_margin_y = min(cy, H - cy)
186
+ assert min_margin_x > W / 5, f"Bad principal point in view={info}"
187
+ assert min_margin_y > H / 5, f"Bad principal point in view={info}"
188
+
189
+ ## Center crop
190
+ # Crop on the principal point, make it always centered
191
+ # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
192
+ l, t = cx - min_margin_x, cy - min_margin_y
193
+ r, b = cx + min_margin_x, cy + min_margin_y
194
+ crop_bbox = (l, t, r, b)
195
+
196
+ image, depthmap, intrinsics = cropping.crop_image_depthmap(
197
+ image, depthmap, intrinsics, crop_bbox
198
+ )
199
+
200
+ # transpose the resolution if necessary
201
+ W, H = image.size # new size
202
+ assert resolution[0] >= resolution[1]
203
+ if H > 1.1 * W:
204
+ # image is portrait mode
205
+ resolution = resolution[::-1]
206
+ elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
207
+ # image is roughly square, so we choose (portrait, landscape) randomly
208
+ if rng.integers(2):
209
+ resolution = resolution[::-1]
210
+
211
+ # high-quality Lanczos down-scaling
212
+ target_resolution = np.array(resolution)
213
+ # # if self.aug_crop > 1:
214
+ # # target_resolution += rng.integers(0, self.aug_crop)
215
+ # if resolution != (224, 224):
216
+ # halfw, halfh = ((2*(W//2))//16)*8, ((2*(H//2))//16)*8
217
+ # ## Rescale with max factor, so one of width or height might be larger than target_resolution
218
+ # image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, (2*halfw, 2*halfh))
219
+ # else:
220
+ image, depthmap, intrinsics = cropping.rescale_image_depthmap(
221
+ image, depthmap, intrinsics, target_resolution
222
+ )
223
+ # actual cropping (if necessary) with bilinear interpolation
224
+ # if resolution == (224, 224):
225
+ intrinsics2 = cropping.camera_matrix_of_crop(
226
+ intrinsics, image.size, resolution, offset_factor=0.5
227
+ )
228
+ crop_bbox = cropping.bbox_from_intrinsics_in_out(
229
+ intrinsics, intrinsics2, resolution
230
+ )
231
+ image, depthmap, intrinsics = cropping.crop_image_depthmap(
232
+ image, depthmap, intrinsics, crop_bbox
233
+ )
234
+ return image, depthmap, intrinsics
235
+
236
+
237
+ def is_good_type(key, v):
238
+ """returns (is_good, err_msg)"""
239
+ if isinstance(v, (str, int, tuple)):
240
+ return True, None
241
+ if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
242
+ return False, f"bad {v.dtype=}"
243
+ return True, None
244
+
245
+
246
+ def view_name(view, batch_index=None):
247
+ def sel(x):
248
+ return x[batch_index] if batch_index not in (None, slice(None)) else x
249
+
250
+ db = sel(view["dataset"])
251
+ label = sel(view["label"])
252
+ instance = sel(view["instance"])
253
+ return f"{db}/{label}/{instance}"
254
+
255
+
256
+ def transpose_to_landscape(view):
257
+ height, width = view["true_shape"]
258
+
259
+ if width < height:
260
+ # rectify portrait to landscape
261
+ assert view["img"].shape == (3, height, width)
262
+ view["img"] = view["img"].swapaxes(1, 2)
263
+
264
+ assert view["valid_mask"].shape == (height, width)
265
+ view["valid_mask"] = view["valid_mask"].swapaxes(0, 1)
266
+
267
+ assert view["depthmap"].shape == (height, width)
268
+ view["depthmap"] = view["depthmap"].swapaxes(0, 1)
269
+
270
+ assert view["pts3d"].shape == (height, width, 3)
271
+ view["pts3d"] = view["pts3d"].swapaxes(0, 1)
272
+
273
+ # transpose x and y pixels
274
+ view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]