Spaces:
Configuration error
Configuration error
Commit
·
9d31508
1
Parent(s):
594b88c
add stream3r
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- LICENSE +13 -0
- assets/pipeline.png +3 -0
- assets/teaser_dynamic.gif +3 -0
- configs/__init__.py +7 -0
- configs/callbacks/default.yaml +22 -0
- configs/callbacks/early_stopping.yaml +15 -0
- configs/callbacks/model_checkpoint.yaml +17 -0
- configs/callbacks/model_summary.yaml +5 -0
- configs/callbacks/none.yaml +0 -0
- configs/callbacks/rich_progress_bar.yaml +18 -0
- configs/data/multiview_dust3r.yaml +25 -0
- configs/debug/ddp_debug.yaml +48 -0
- configs/debug/default.yaml +35 -0
- configs/debug/fdr.yaml +9 -0
- configs/debug/limit.yaml +12 -0
- configs/debug/overfit.yaml +13 -0
- configs/debug/profiler.yaml +12 -0
- configs/eval.yaml +19 -0
- configs/experiment/stream3r/stream3r.yaml +125 -0
- configs/extras/default.yaml +8 -0
- configs/hparams_search/mnist_optuna.yaml +52 -0
- configs/hydra/default.yaml +19 -0
- configs/hydra/launcher/fair_a100.yaml +43 -0
- configs/local/.gitkeep +0 -0
- configs/logger/aim.yaml +28 -0
- configs/logger/comet.yaml +12 -0
- configs/logger/csv.yaml +7 -0
- configs/logger/many_loggers.yaml +9 -0
- configs/logger/mlflow.yaml +12 -0
- configs/logger/neptune.yaml +9 -0
- configs/logger/tensorboard.yaml +10 -0
- configs/logger/wandb.yaml +16 -0
- configs/model/stream3r.yaml +42 -0
- configs/paths/default.yaml +21 -0
- configs/train.yaml +49 -0
- configs/trainer/cpu.yaml +5 -0
- configs/trainer/ddp.yaml +12 -0
- configs/trainer/ddp_eval.yaml +16 -0
- configs/trainer/ddp_sim.yaml +7 -0
- configs/trainer/deepspeed_stage_2.yaml +9 -0
- configs/trainer/default.yaml +30 -0
- configs/trainer/gpu.yaml +5 -0
- configs/trainer/mps.yaml +5 -0
- eval/monodepth/eval_metrics.py +211 -0
- eval/monodepth/launch.py +146 -0
- eval/monodepth/metadata.py +187 -0
- eval/monodepth/run.sh +20 -0
- eval/monodepth/tools.py +399 -0
- eval/mv_recon/base.py +274 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
S-Lab License 1.0
|
| 2 |
+
Copyright 2025 S-Lab
|
| 3 |
+
|
| 4 |
+
Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
| 5 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
| 6 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
| 7 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
| 8 |
+
|
| 9 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
|
| 10 |
+
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
| 11 |
+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 12 |
+
|
| 13 |
+
In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
|
assets/pipeline.png
ADDED
|
Git LFS Details
|
assets/teaser_dynamic.gif
ADDED
|
Git LFS Details
|
configs/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# this file is needed here to include configs when building project as a package
|
configs/callbacks/default.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model_checkpoint
|
| 3 |
+
- early_stopping
|
| 4 |
+
- model_summary
|
| 5 |
+
- rich_progress_bar
|
| 6 |
+
- _self_
|
| 7 |
+
|
| 8 |
+
model_checkpoint:
|
| 9 |
+
dirpath: ${paths.output_dir}/checkpoints
|
| 10 |
+
filename: "epoch_{epoch:03d}"
|
| 11 |
+
monitor: "val/loss"
|
| 12 |
+
mode: "min"
|
| 13 |
+
save_last: True
|
| 14 |
+
auto_insert_metric_name: False
|
| 15 |
+
|
| 16 |
+
early_stopping:
|
| 17 |
+
monitor: "val/loss"
|
| 18 |
+
patience: 100
|
| 19 |
+
mode: "min"
|
| 20 |
+
|
| 21 |
+
model_summary:
|
| 22 |
+
max_depth: -1
|
configs/callbacks/early_stopping.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
|
| 2 |
+
|
| 3 |
+
early_stopping:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.EarlyStopping
|
| 5 |
+
monitor: ??? # quantity to be monitored, must be specified !!!
|
| 6 |
+
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
|
| 7 |
+
patience: 3 # number of checks with no improvement after which training will be stopped
|
| 8 |
+
verbose: False # verbosity mode
|
| 9 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 10 |
+
strict: True # whether to crash the training if monitor is not found in the validation metrics
|
| 11 |
+
check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
|
| 12 |
+
stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
|
| 13 |
+
divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
|
| 14 |
+
check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
|
| 15 |
+
# log_rank_zero_only: False # this keyword argument isn't available in stable version
|
configs/callbacks/model_checkpoint.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
|
| 2 |
+
|
| 3 |
+
model_checkpoint:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.ModelCheckpoint
|
| 5 |
+
dirpath: null # directory to save the model file
|
| 6 |
+
filename: null # checkpoint filename
|
| 7 |
+
monitor: 'val/loss' # name of the logged metric which determines when model is improving
|
| 8 |
+
verbose: False # verbosity mode
|
| 9 |
+
save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt
|
| 10 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
| 11 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 12 |
+
auto_insert_metric_name: False # when True, the checkpoints filenames will contain the metric name
|
| 13 |
+
save_weights_only: False # if True, then only the model’s weights will be saved
|
| 14 |
+
every_n_train_steps: null # number of training steps between checkpoints
|
| 15 |
+
train_time_interval: null # checkpoints are monitored at the specified time interval
|
| 16 |
+
every_n_epochs: 20 # number of epochs between checkpoints
|
| 17 |
+
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
|
configs/callbacks/model_summary.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
|
| 2 |
+
|
| 3 |
+
model_summary:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.RichModelSummary
|
| 5 |
+
max_depth: 1 # the maximum depth of layer nesting that the summary will include
|
configs/callbacks/none.yaml
ADDED
|
File without changes
|
configs/callbacks/rich_progress_bar.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
|
| 2 |
+
|
| 3 |
+
rich_progress_bar:
|
| 4 |
+
_target_: lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar
|
| 5 |
+
refresh_rate: 1
|
| 6 |
+
leave: false
|
| 7 |
+
theme:
|
| 8 |
+
_target_: lightning.pytorch.callbacks.progress.rich_progress.RichProgressBarTheme
|
| 9 |
+
description: green_yellow
|
| 10 |
+
progress_bar: green1
|
| 11 |
+
progress_bar_finished: green1
|
| 12 |
+
progress_bar_pulse: "#6206E0"
|
| 13 |
+
batch_progress: green_yellow
|
| 14 |
+
time: blue
|
| 15 |
+
processing_speed: cyan
|
| 16 |
+
metrics: grey82
|
| 17 |
+
metrics_text_delimiter: " "
|
| 18 |
+
metrics_format: .4g
|
configs/data/multiview_dust3r.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Define the common data root and number of views
|
| 2 |
+
data_root: /path/to/dust3r_data
|
| 3 |
+
num_views: 4
|
| 4 |
+
num_views_val: 10
|
| 5 |
+
|
| 6 |
+
data_module:
|
| 7 |
+
_target_: stream3r.data.multiview_dust3r_datamodule.MultiViewDUSt3RDataModule
|
| 8 |
+
train_datasets:
|
| 9 |
+
- 80_000 @ Co3d_Multiview(split='train', num_views=${data.num_views}, window_degree_range=360, num_samples_per_window=100, ROOT='${data.data_root}/co3d_50_seqs_per_category_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)
|
| 10 |
+
- 80_000 @ MegaDepth_Multiview(split='train', num_views=${data.num_views}, window_size=${python_eval:"${data.num_views} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/megadepth_processed', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)
|
| 11 |
+
- 80_000 @ ScanNetpp_Multiview(split='train', num_views=${data.num_views}, window_size=${python_eval:"${data.num_views} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/scannetpp_processed', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)
|
| 12 |
+
- 80_000 @ ARKitScenes_Multiview(split='train', num_views=${data.num_views}, num_samples_per_window=10, ROOT='${data.data_root}/arkitscenes_processed', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)
|
| 13 |
+
- 80_000 @ Habitat_Multiview(1_000_000, split='train', num_views=${data.num_views}, ROOT='${data.data_root}/habitat_processed', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)
|
| 14 |
+
validation_datasets:
|
| 15 |
+
- 100 @ Co3d_Multiview(split='test', num_views=${data.num_views_val}, window_degree_range=360, num_samples_per_window=100, ROOT='${data.data_root}/co3d_50_seqs_per_category_subset_processed', resolution=(512, 384), seed=777)
|
| 16 |
+
- 100 @ MegaDepth_Multiview(split='val', num_views=${data.num_views_val}, window_size=${python_eval:"${data.num_views_val} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/megadepth_processed', resolution=(512, 336), seed=777)
|
| 17 |
+
- 100 @ ScanNetpp_Multiview(split='train', num_views=${data.num_views_val}, window_size=${python_eval:"${data.num_views_val} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/scannetpp_processed', resolution=(512, 384), seed=777)
|
| 18 |
+
- 100 @ ARKitScenes_Multiview(split='train', num_views=${data.num_views_val}, num_samples_per_window=10, ROOT='${data.data_root}/arkitscenes_processed', resolution=(512, 384), seed=777)
|
| 19 |
+
- 100 @ Habitat_Multiview(100_000, split='val', num_views=${data.num_views_val}, ROOT='${data.data_root}/habitat_processed', resolution=(512,384), seed=777)
|
| 20 |
+
batch_size_per_device: 6
|
| 21 |
+
batch_size_per_device_val: 4
|
| 22 |
+
num_workers: 6
|
| 23 |
+
pin_memory: True
|
| 24 |
+
|
| 25 |
+
|
configs/debug/ddp_debug.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
#use a smaller dataset for faster initializations
|
| 4 |
+
defaults:
|
| 5 |
+
- override /data: multiview_dust3r_tiny
|
| 6 |
+
- override /logger:
|
| 7 |
+
- csv
|
| 8 |
+
- wandb
|
| 9 |
+
|
| 10 |
+
# overwrite task name so debugging logs are stored in separate folder
|
| 11 |
+
task_name: "debug"
|
| 12 |
+
|
| 13 |
+
logger:
|
| 14 |
+
wandb:
|
| 15 |
+
name: ${paths.run_folder_name}
|
| 16 |
+
|
| 17 |
+
# ckpt_path: /some/random/path
|
| 18 |
+
|
| 19 |
+
extras:
|
| 20 |
+
ignore_warnings: False
|
| 21 |
+
enforce_tags: False
|
| 22 |
+
|
| 23 |
+
# sets level of all command line loggers to 'DEBUG'
|
| 24 |
+
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
|
| 25 |
+
hydra:
|
| 26 |
+
job_logging:
|
| 27 |
+
root:
|
| 28 |
+
level: DEBUG
|
| 29 |
+
|
| 30 |
+
model:
|
| 31 |
+
net:
|
| 32 |
+
random_image_idx_embedding: true
|
| 33 |
+
|
| 34 |
+
data:
|
| 35 |
+
num_views: 4
|
| 36 |
+
data_module:
|
| 37 |
+
num_workers: 0 # debuggers don't like multiprocessing
|
| 38 |
+
pin_memory: false # disable gpu memory pin
|
| 39 |
+
batch_size_per_device: 6
|
| 40 |
+
|
| 41 |
+
trainer:
|
| 42 |
+
log_every_n_steps: 1
|
| 43 |
+
devices: auto
|
| 44 |
+
# fast_dev_run: 1
|
| 45 |
+
limit_train_batches: 1
|
| 46 |
+
limit_val_batches: 10000
|
| 47 |
+
precision: 32
|
| 48 |
+
|
configs/debug/default.yaml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# default debugging setup, runs 1 full epoch
|
| 4 |
+
# other debugging configs can inherit from this one
|
| 5 |
+
|
| 6 |
+
# overwrite task name so debugging logs are stored in separate folder
|
| 7 |
+
task_name: "debug"
|
| 8 |
+
|
| 9 |
+
# disable callbacks and loggers during debugging
|
| 10 |
+
callbacks: null
|
| 11 |
+
logger: null
|
| 12 |
+
|
| 13 |
+
extras:
|
| 14 |
+
ignore_warnings: False
|
| 15 |
+
enforce_tags: False
|
| 16 |
+
|
| 17 |
+
# sets level of all command line loggers to 'DEBUG'
|
| 18 |
+
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
|
| 19 |
+
hydra:
|
| 20 |
+
job_logging:
|
| 21 |
+
root:
|
| 22 |
+
level: DEBUG
|
| 23 |
+
|
| 24 |
+
# use this to also set hydra loggers to 'DEBUG'
|
| 25 |
+
# verbose: True
|
| 26 |
+
|
| 27 |
+
trainer:
|
| 28 |
+
max_epochs: 1
|
| 29 |
+
accelerator: cpu # debuggers don't like gpus
|
| 30 |
+
devices: 1 # debuggers don't like multiprocessing
|
| 31 |
+
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
|
| 32 |
+
|
| 33 |
+
data:
|
| 34 |
+
num_workers: 0 # debuggers don't like multiprocessing
|
| 35 |
+
pin_memory: False # disable gpu memory pin
|
configs/debug/fdr.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# runs 1 train, 1 validation and 1 test step
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
fast_dev_run: true
|
configs/debug/limit.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# uses only 1% of the training data and 5% of validation/test data
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
max_epochs: 3
|
| 10 |
+
limit_train_batches: 0.01
|
| 11 |
+
limit_val_batches: 0.05
|
| 12 |
+
limit_test_batches: 0.05
|
configs/debug/overfit.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# overfits to 3 batches
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
max_epochs: 20
|
| 10 |
+
overfit_batches: 3
|
| 11 |
+
|
| 12 |
+
# model ckpt and early stopping need to be disabled during overfitting
|
| 13 |
+
callbacks: null
|
configs/debug/profiler.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# runs with execution time profiling
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
max_epochs: 1
|
| 10 |
+
profiler: "simple"
|
| 11 |
+
# profiler: "advanced"
|
| 12 |
+
# profiler: "pytorch"
|
configs/eval.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- _self_
|
| 5 |
+
- data: multiview_dust3r
|
| 6 |
+
- model: stream3r
|
| 7 |
+
- logger: many_loggers
|
| 8 |
+
- trainer: ddp_eval
|
| 9 |
+
- paths: default
|
| 10 |
+
- extras: default
|
| 11 |
+
- hydra: default
|
| 12 |
+
- eval: default
|
| 13 |
+
|
| 14 |
+
task_name: "eval"
|
| 15 |
+
|
| 16 |
+
tags: ["eval"]
|
| 17 |
+
|
| 18 |
+
# passing checkpoint path is necessary for evaluation
|
| 19 |
+
ckpt_path: ???
|
configs/experiment/stream3r/stream3r.yaml
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /model: stream3r
|
| 5 |
+
|
| 6 |
+
# seed for random number generators in pytorch, numpy and python.random
|
| 7 |
+
seed: 42
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
tags: ["train", "stream3r"]
|
| 11 |
+
|
| 12 |
+
task_name: stream3r
|
| 13 |
+
slurm_job_id: 99999 # must set in the command line
|
| 14 |
+
|
| 15 |
+
# ckpt_path: /path/to/resume.ckpt # uncomment to resume training from a checkpoint
|
| 16 |
+
|
| 17 |
+
paths:
|
| 18 |
+
run_folder_name: ${task_name}_${slurm_job_id}
|
| 19 |
+
|
| 20 |
+
logger:
|
| 21 |
+
wandb:
|
| 22 |
+
name: ${task_name}_${slurm_job_id}
|
| 23 |
+
project: stream3r
|
| 24 |
+
|
| 25 |
+
data:
|
| 26 |
+
data_scaling: 1.0
|
| 27 |
+
data_root: /data
|
| 28 |
+
num_views: 24
|
| 29 |
+
resolution:
|
| 30 |
+
- [518, 392]
|
| 31 |
+
- [518, 378]
|
| 32 |
+
- [518, 336]
|
| 33 |
+
- [518, 294]
|
| 34 |
+
- [518, 252]
|
| 35 |
+
- [518, 210]
|
| 36 |
+
- [518, 140]
|
| 37 |
+
- [378, 518]
|
| 38 |
+
- [336, 518]
|
| 39 |
+
- [294, 518]
|
| 40 |
+
- [252, 518]
|
| 41 |
+
- [224, 224]
|
| 42 |
+
allow_repeat: true
|
| 43 |
+
n_corres_train: 0
|
| 44 |
+
data_module:
|
| 45 |
+
_target_: stream3r.data.multiview_dust3r_datamodule.MultiViewDUSt3RDataModule
|
| 46 |
+
pin_memory: true
|
| 47 |
+
num_workers: 16
|
| 48 |
+
num_workers_val: 1 # have to be a low number when using DeepSpeed ZeRO-2
|
| 49 |
+
batch_size_per_device: 1
|
| 50 |
+
batch_size_per_device_val: 1
|
| 51 |
+
train_datasets:
|
| 52 |
+
- 44800 @ Co3d_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT='${data.data_root}/processed_co3d', aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 53 |
+
- 56000 @ WildRGBD_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_wildrgbd_mp", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 54 |
+
- 22400 @ ARKitScenesHighRes_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_arkitscene_highres", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 55 |
+
- 38400 @ ScanNet_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_scannet/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 56 |
+
- 16800 @ ScanNetpp_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_scannetpp/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 57 |
+
- 84000 @ MapFree_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_mapfree/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 58 |
+
- 20000 @ Waymo_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_waymo/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 59 |
+
- 56000 @ TartanAir_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_tartanair/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 60 |
+
- 9400 @ Spring(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_spring/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 61 |
+
- 36000 @ BEDLAM_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_bedlam/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 62 |
+
- 28800 @ MP3D_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_mp3d/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 63 |
+
- 14400 @ UASOL_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_uasol/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 64 |
+
- 1400 @ MVS_Synth_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_mvs_synth", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 65 |
+
- 7200 @ PointOdyssey_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_pointodyssey", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 66 |
+
- 11200 @ HyperSim_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_hypersim_new", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 67 |
+
- 22400 @ BlendedMVS_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_blendedmvs/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 68 |
+
- 22400 @ MegaDepth_Multi(allow_repeat=${data.allow_repeat}, split="train", ROOT="${data.data_root}/processed_megadepth", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 69 |
+
- 5600 @ VirtualKITTI2_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_vkitti", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 70 |
+
- 168 @ UnReal4K_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_unreal4k/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 71 |
+
- 74000 @ DL3DV_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_dl3dv_ours_parts/processed_dl3dv_ours", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) +
|
| 72 |
+
- 36000 @ DynamicReplica(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_dynamic_replica/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train})
|
| 73 |
+
|
| 74 |
+
model:
|
| 75 |
+
pretrained: weights/vggt/model.pt
|
| 76 |
+
net:
|
| 77 |
+
freeze: encoder
|
| 78 |
+
|
| 79 |
+
scheduler:
|
| 80 |
+
warmup_start_lr: 1e-6
|
| 81 |
+
warmup_epochs: 1
|
| 82 |
+
|
| 83 |
+
train_criterion:
|
| 84 |
+
_target_: stream3r.loss.losses.CausalLoss
|
| 85 |
+
gradient_loss: grad
|
| 86 |
+
is_metric: false
|
| 87 |
+
|
| 88 |
+
validation_criterion:
|
| 89 |
+
_target_: stream3r.loss.losses.CausalLoss
|
| 90 |
+
gradient_loss: grad
|
| 91 |
+
is_metric: false
|
| 92 |
+
|
| 93 |
+
optimizer:
|
| 94 |
+
_target_: torch.optim.AdamW
|
| 95 |
+
_partial_: true
|
| 96 |
+
lr: 1e-5
|
| 97 |
+
betas:
|
| 98 |
+
- 0.9
|
| 99 |
+
- 0.95
|
| 100 |
+
weight_decay: 0.05
|
| 101 |
+
|
| 102 |
+
trainer:
|
| 103 |
+
devices: auto
|
| 104 |
+
max_epochs: 500
|
| 105 |
+
accumulate_grad_batches: 4
|
| 106 |
+
strategy:
|
| 107 |
+
_target_: lightning.pytorch.strategies.DeepSpeedStrategy
|
| 108 |
+
timeout:
|
| 109 |
+
_target_: datetime.timedelta
|
| 110 |
+
minutes: 80
|
| 111 |
+
plugins: null
|
| 112 |
+
limit_val_batches: 0
|
| 113 |
+
precision: bf16-mixed
|
| 114 |
+
log_every_n_steps: 20
|
| 115 |
+
|
| 116 |
+
callbacks:
|
| 117 |
+
model_checkpoint:
|
| 118 |
+
every_n_train_steps: 2000
|
| 119 |
+
every_n_epochs: null
|
| 120 |
+
save_top_k: -1
|
| 121 |
+
filename: "{epoch:03d}-{step:08d}"
|
| 122 |
+
save_last: false
|
| 123 |
+
monitor: "train/loss"
|
| 124 |
+
early_stopping:
|
| 125 |
+
monitor: "train/loss"
|
configs/extras/default.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# disable python warnings if they annoy you
|
| 2 |
+
ignore_warnings: False
|
| 3 |
+
|
| 4 |
+
# ask user for tags if none are provided in the config
|
| 5 |
+
enforce_tags: True
|
| 6 |
+
|
| 7 |
+
# pretty print config tree at the start of the run using Rich library
|
| 8 |
+
print_config: True
|
configs/hparams_search/mnist_optuna.yaml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# example hyperparameter optimization of some experiment with Optuna:
|
| 4 |
+
# python train.py -m hparams_search=mnist_optuna experiment=example
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /hydra/sweeper: optuna
|
| 8 |
+
|
| 9 |
+
# choose metric which will be optimized by Optuna
|
| 10 |
+
# make sure this is the correct name of some metric logged in lightning module!
|
| 11 |
+
optimized_metric: "val/acc_best"
|
| 12 |
+
|
| 13 |
+
# here we define Optuna hyperparameter search
|
| 14 |
+
# it optimizes for value returned from function with @hydra.main decorator
|
| 15 |
+
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
|
| 16 |
+
hydra:
|
| 17 |
+
mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
|
| 18 |
+
|
| 19 |
+
sweeper:
|
| 20 |
+
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
|
| 21 |
+
|
| 22 |
+
# storage URL to persist optimization results
|
| 23 |
+
# for example, you can use SQLite if you set 'sqlite:///example.db'
|
| 24 |
+
storage: null
|
| 25 |
+
|
| 26 |
+
# name of the study to persist optimization results
|
| 27 |
+
study_name: null
|
| 28 |
+
|
| 29 |
+
# number of parallel workers
|
| 30 |
+
n_jobs: 1
|
| 31 |
+
|
| 32 |
+
# 'minimize' or 'maximize' the objective
|
| 33 |
+
direction: maximize
|
| 34 |
+
|
| 35 |
+
# total number of runs that will be executed
|
| 36 |
+
n_trials: 20
|
| 37 |
+
|
| 38 |
+
# choose Optuna hyperparameter sampler
|
| 39 |
+
# you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
|
| 40 |
+
# docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
|
| 41 |
+
sampler:
|
| 42 |
+
_target_: optuna.samplers.TPESampler
|
| 43 |
+
seed: 1234
|
| 44 |
+
n_startup_trials: 10 # number of random sampling runs before optimization starts
|
| 45 |
+
|
| 46 |
+
# define hyperparameter search space
|
| 47 |
+
params:
|
| 48 |
+
model.optimizer.lr: interval(0.0001, 0.1)
|
| 49 |
+
data.batch_size: choice(32, 64, 128, 256)
|
| 50 |
+
model.net.lin1_size: choice(64, 128, 256)
|
| 51 |
+
model.net.lin2_size: choice(64, 128, 256)
|
| 52 |
+
model.net.lin3_size: choice(32, 64, 128, 256)
|
configs/hydra/default.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://hydra.cc/docs/configure_hydra/intro/
|
| 2 |
+
|
| 3 |
+
# enable color logging
|
| 4 |
+
defaults:
|
| 5 |
+
- override hydra_logging: colorlog
|
| 6 |
+
- override job_logging: colorlog
|
| 7 |
+
|
| 8 |
+
# output directory, generated dynamically on each run
|
| 9 |
+
run:
|
| 10 |
+
dir: ${paths.log_dir}/${task_name}/runs/${paths.run_folder_name}
|
| 11 |
+
sweep:
|
| 12 |
+
dir: ${paths.log_dir}/${task_name}/multiruns/${paths.run_folder_name}
|
| 13 |
+
subdir: ${hydra.job.num}
|
| 14 |
+
|
| 15 |
+
job_logging:
|
| 16 |
+
handlers:
|
| 17 |
+
file:
|
| 18 |
+
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
|
| 19 |
+
filename: ${hydra.runtime.output_dir}/${task_name}.log
|
configs/hydra/launcher/fair_a100.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- submitit_slurm
|
| 3 |
+
|
| 4 |
+
# see: https://github.com/facebookresearch/hydra/blob/main/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/config.py
|
| 5 |
+
_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher
|
| 6 |
+
submitit_folder: ${hydra.sweep.dir}/.submitit/%j
|
| 7 |
+
name: ${hydra.job.name}
|
| 8 |
+
timeout_min: 20160 # 14 days : 60 * 24 * 14
|
| 9 |
+
account: cortex
|
| 10 |
+
qos: cortex_high
|
| 11 |
+
comment: "multiview_dust3r experiment"
|
| 12 |
+
nodes: 1
|
| 13 |
+
gres: "gpu:8"
|
| 14 |
+
tasks_per_node: 8
|
| 15 |
+
cpus_per_task: 12
|
| 16 |
+
signal_delay_s: 120 # USR1 signal delay (seconds) before timeout
|
| 17 |
+
max_num_timeout: 0 # number of times the job can be restarted after timeout
|
| 18 |
+
array_parallelism: 256 # Maximum number of jobs running in parallel
|
| 19 |
+
|
| 20 |
+
# Useful to add parameters which are not currently available in the plugin.
|
| 21 |
+
# Eg: {"mail-user": "blublu@fb.com", "mail-type": "BEGIN"}
|
| 22 |
+
additional_parameters:
|
| 23 |
+
mail-user: "jianingy@meta.com"
|
| 24 |
+
mail-type: "BEGIN,END"
|
| 25 |
+
output: "/path/to/slurm_out/%x-%j.out"
|
| 26 |
+
|
| 27 |
+
setup: # A list of commands to run in sbatch befure running srun
|
| 28 |
+
- echo "Begin setting up env on head node ($HOSTNAME)..."
|
| 29 |
+
- echo $(env | grep SLURM)
|
| 30 |
+
- export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
|
| 31 |
+
- export MASTER_PORT=9929
|
| 32 |
+
- export RDZV_ID=$SLURM_JOBID
|
| 33 |
+
- export OMP_NUM_THREADS=12
|
| 34 |
+
- . /path/to/miniforge3/etc/profile.d/conda.sh # activate conda
|
| 35 |
+
- conda activate dust3r
|
| 36 |
+
- cd /path/to/project # cd to the project directory
|
| 37 |
+
- export NCCL_DEBUG=INFO
|
| 38 |
+
- export PYTHONFAULTHANDLER=1
|
| 39 |
+
- export TORCH_DISTRIBUTED_DEBUG=INFO
|
| 40 |
+
- echo "env setup on head node ($HOSTNAME) finished, starting srun..."
|
| 41 |
+
|
| 42 |
+
srun_args:
|
| 43 |
+
- "--cpu-bind=none" # This is critical to ensure dataloaders uses all CPUs!
|
configs/local/.gitkeep
ADDED
|
File without changes
|
configs/logger/aim.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://aimstack.io/
|
| 2 |
+
|
| 3 |
+
# example usage in lightning module:
|
| 4 |
+
# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
|
| 5 |
+
|
| 6 |
+
# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
|
| 7 |
+
# `aim up`
|
| 8 |
+
|
| 9 |
+
aim:
|
| 10 |
+
_target_: aim.pytorch_lightning.AimLogger
|
| 11 |
+
repo: ${paths.root_dir} # .aim folder will be created here
|
| 12 |
+
# repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
|
| 13 |
+
|
| 14 |
+
# aim allows to group runs under experiment name
|
| 15 |
+
experiment: null # any string, set to "default" if not specified
|
| 16 |
+
|
| 17 |
+
train_metric_prefix: "train/"
|
| 18 |
+
val_metric_prefix: "val/"
|
| 19 |
+
test_metric_prefix: "test/"
|
| 20 |
+
|
| 21 |
+
# sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
|
| 22 |
+
system_tracking_interval: 10 # set to null to disable system metrics tracking
|
| 23 |
+
|
| 24 |
+
# enable/disable logging of system params such as installed packages, git info, env vars, etc.
|
| 25 |
+
log_system_params: true
|
| 26 |
+
|
| 27 |
+
# enable/disable tracking console logs (default value is true)
|
| 28 |
+
capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
|
configs/logger/comet.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://www.comet.ml
|
| 2 |
+
|
| 3 |
+
comet:
|
| 4 |
+
_target_: lightning.pytorch.loggers.comet.CometLogger
|
| 5 |
+
api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
|
| 6 |
+
save_dir: "${paths.output_dir}"
|
| 7 |
+
project_name: "lightning-hydra-template"
|
| 8 |
+
rest_api_key: null
|
| 9 |
+
# experiment_name: ""
|
| 10 |
+
experiment_key: null # set to resume experiment
|
| 11 |
+
offline: False
|
| 12 |
+
prefix: ""
|
configs/logger/csv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# csv logger built in lightning
|
| 2 |
+
|
| 3 |
+
csv:
|
| 4 |
+
_target_: lightning.pytorch.loggers.csv_logs.CSVLogger
|
| 5 |
+
save_dir: "${paths.output_dir}"
|
| 6 |
+
name: "csv/"
|
| 7 |
+
prefix: ""
|
configs/logger/many_loggers.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# train with many loggers at once
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
# - comet
|
| 5 |
+
- csv
|
| 6 |
+
# - mlflow
|
| 7 |
+
# - neptune
|
| 8 |
+
- tensorboard
|
| 9 |
+
- wandb
|
configs/logger/mlflow.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://mlflow.org
|
| 2 |
+
|
| 3 |
+
mlflow:
|
| 4 |
+
_target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
|
| 5 |
+
# experiment_name: ""
|
| 6 |
+
# run_name: ""
|
| 7 |
+
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
|
| 8 |
+
tags: null
|
| 9 |
+
# save_dir: "./mlruns"
|
| 10 |
+
prefix: ""
|
| 11 |
+
artifact_location: null
|
| 12 |
+
# run_id: ""
|
configs/logger/neptune.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://neptune.ai
|
| 2 |
+
|
| 3 |
+
neptune:
|
| 4 |
+
_target_: lightning.pytorch.loggers.neptune.NeptuneLogger
|
| 5 |
+
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
|
| 6 |
+
project: username/lightning-hydra-template
|
| 7 |
+
# name: ""
|
| 8 |
+
log_model_checkpoints: True
|
| 9 |
+
prefix: ""
|
configs/logger/tensorboard.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://www.tensorflow.org/tensorboard/
|
| 2 |
+
|
| 3 |
+
tensorboard:
|
| 4 |
+
_target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
|
| 5 |
+
save_dir: "${paths.output_dir}/tensorboard/"
|
| 6 |
+
name: null
|
| 7 |
+
log_graph: False
|
| 8 |
+
default_hp_metric: True
|
| 9 |
+
prefix: ""
|
| 10 |
+
# version: ""
|
configs/logger/wandb.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://wandb.ai
|
| 2 |
+
|
| 3 |
+
wandb:
|
| 4 |
+
_target_: lightning.pytorch.loggers.wandb.WandbLogger
|
| 5 |
+
name: null # name of the run (normally generated by wandb)
|
| 6 |
+
save_dir: "${paths.output_dir}"
|
| 7 |
+
offline: False
|
| 8 |
+
id: null # pass correct id to resume experiment!
|
| 9 |
+
anonymous: null # enable anonymous logging
|
| 10 |
+
project: "stream3r"
|
| 11 |
+
log_model: False # upload lightning ckpts
|
| 12 |
+
prefix: "" # a string to put at the beginning of metric keys
|
| 13 |
+
# entity: "" # set to name of your wandb team
|
| 14 |
+
group: ""
|
| 15 |
+
tags: []
|
| 16 |
+
job_type: ""
|
configs/model/stream3r.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: stream3r.models.multiview_dust3r_module.MultiViewDUSt3RLitModule
|
| 2 |
+
|
| 3 |
+
pretrained: null
|
| 4 |
+
resume_from_checkpoint: ${ckpt_path}
|
| 5 |
+
|
| 6 |
+
eval_use_pts3d_from_local_head: true
|
| 7 |
+
|
| 8 |
+
train_criterion:
|
| 9 |
+
_target_: stream3r.loss.losses.CausalLoss
|
| 10 |
+
|
| 11 |
+
validation_criterion:
|
| 12 |
+
_target_: stream3r.loss.losses.CausalLoss
|
| 13 |
+
|
| 14 |
+
optimizer:
|
| 15 |
+
_target_: torch.optim.AdamW
|
| 16 |
+
_partial_: true
|
| 17 |
+
lr: 1e-4
|
| 18 |
+
betas:
|
| 19 |
+
- 0.9
|
| 20 |
+
- 0.95
|
| 21 |
+
weight_decay: 0.05
|
| 22 |
+
|
| 23 |
+
# scheduler:
|
| 24 |
+
# _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
|
| 25 |
+
# _partial_: true
|
| 26 |
+
# mode: min
|
| 27 |
+
# factor: 0.1
|
| 28 |
+
# patience: 10
|
| 29 |
+
|
| 30 |
+
scheduler:
|
| 31 |
+
_target_: pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR
|
| 32 |
+
_partial_: true
|
| 33 |
+
warmup_epochs: 10
|
| 34 |
+
max_epochs: ${trainer.max_epochs}
|
| 35 |
+
eta_min: 1e-06
|
| 36 |
+
|
| 37 |
+
net:
|
| 38 |
+
_target_: stream3r.models.stream3r.STream3R
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# compile model for faster training with pytorch 2.0
|
| 42 |
+
compile: false
|
configs/paths/default.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# path to root directory
|
| 2 |
+
# this requires PROJECT_ROOT environment variable to exist
|
| 3 |
+
# you can replace it with "." if you want the root to be the current working directory
|
| 4 |
+
# root_dir: R${oc.env:PROJECT_ROOT}
|
| 5 |
+
root_dir: .
|
| 6 |
+
|
| 7 |
+
# path to data directory
|
| 8 |
+
data_dir: ${paths.root_dir}/data/
|
| 9 |
+
|
| 10 |
+
# path to logging directory
|
| 11 |
+
log_dir: ${paths.root_dir}/logs/
|
| 12 |
+
|
| 13 |
+
# path to output directory, created dynamically by hydra
|
| 14 |
+
# path generation pattern is specified in `configs/hydra/default.yaml`
|
| 15 |
+
# use it to store all files generated during the run, like ckpts and metrics
|
| 16 |
+
output_dir: ${hydra:runtime.output_dir}
|
| 17 |
+
|
| 18 |
+
# path to working directory
|
| 19 |
+
work_dir: ${hydra:runtime.cwd}
|
| 20 |
+
|
| 21 |
+
run_folder_name: ${now:%Y-%m-%d}_${now:%H-%M-%S}
|
configs/train.yaml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# specify here default configuration
|
| 4 |
+
# order of defaults determines the order in which configs override each other
|
| 5 |
+
defaults:
|
| 6 |
+
- _self_
|
| 7 |
+
- data: multiview_dust3r
|
| 8 |
+
- model: stream3r
|
| 9 |
+
- callbacks: default
|
| 10 |
+
- logger: many_loggers # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
|
| 11 |
+
- trainer: ddp
|
| 12 |
+
- paths: default
|
| 13 |
+
- extras: default
|
| 14 |
+
- hydra: default
|
| 15 |
+
|
| 16 |
+
# experiment configs allow for version control of specific hyperparameters
|
| 17 |
+
# e.g. best hyperparameters for given model and datamodule
|
| 18 |
+
- experiment: null
|
| 19 |
+
|
| 20 |
+
# config for hyperparameter optimization
|
| 21 |
+
- hparams_search: null
|
| 22 |
+
|
| 23 |
+
# optional local config for machine/user specific settings
|
| 24 |
+
# it's optional since it doesn't need to exist and is excluded from version control
|
| 25 |
+
- optional local: default
|
| 26 |
+
|
| 27 |
+
# debugging config (enable through command line, e.g. `python train.py debug=default)
|
| 28 |
+
- debug: null
|
| 29 |
+
|
| 30 |
+
# task name, determines output directory path
|
| 31 |
+
task_name: "train"
|
| 32 |
+
|
| 33 |
+
# tags to help you identify your experiments
|
| 34 |
+
# you can overwrite this in experiment configs
|
| 35 |
+
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
|
| 36 |
+
tags: ["dev"]
|
| 37 |
+
|
| 38 |
+
# set False to skip model training
|
| 39 |
+
train: True
|
| 40 |
+
|
| 41 |
+
# evaluate on test set, using best model weights achieved during training
|
| 42 |
+
# lightning chooses best weights based on the metric specified in checkpoint callback
|
| 43 |
+
test: True
|
| 44 |
+
|
| 45 |
+
# simply provide checkpoint path to resume training
|
| 46 |
+
ckpt_path: null
|
| 47 |
+
|
| 48 |
+
# seed for random number generators in pytorch, numpy and python.random
|
| 49 |
+
seed: 42
|
configs/trainer/cpu.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
accelerator: cpu
|
| 5 |
+
devices: 1
|
configs/trainer/ddp.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
# strategy: ddp
|
| 5 |
+
strategy: ddp_find_unused_parameters_true
|
| 6 |
+
|
| 7 |
+
accelerator: gpu
|
| 8 |
+
devices: auto
|
| 9 |
+
num_nodes: 1
|
| 10 |
+
sync_batchnorm: true
|
| 11 |
+
|
| 12 |
+
use_distributed_sampler: false
|
configs/trainer/ddp_eval.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
# strategy: ddp
|
| 5 |
+
strategy:
|
| 6 |
+
_target_: lightning.pytorch.strategies.DDPStrategy
|
| 7 |
+
timeout:
|
| 8 |
+
_target_: datetime.timedelta
|
| 9 |
+
minutes: 30
|
| 10 |
+
|
| 11 |
+
accelerator: gpu
|
| 12 |
+
devices: auto
|
| 13 |
+
num_nodes: 1
|
| 14 |
+
sync_batchnorm: true
|
| 15 |
+
|
| 16 |
+
use_distributed_sampler: false
|
configs/trainer/ddp_sim.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
# simulate DDP on CPU, useful for debugging
|
| 5 |
+
accelerator: cpu
|
| 6 |
+
devices: 2
|
| 7 |
+
strategy: ddp_spawn
|
configs/trainer/deepspeed_stage_2.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
# strategy: deepspeed_stage_2
|
| 5 |
+
strategy: deepspeed_stage_2
|
| 6 |
+
|
| 7 |
+
accelerator: gpu
|
| 8 |
+
devices: auto
|
| 9 |
+
num_nodes: 1
|
configs/trainer/default.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: lightning.pytorch.trainer.Trainer
|
| 2 |
+
_convert_: partial
|
| 3 |
+
|
| 4 |
+
default_root_dir: ${paths.output_dir}
|
| 5 |
+
|
| 6 |
+
min_epochs: 1 # prevents early stopping
|
| 7 |
+
max_epochs: 100
|
| 8 |
+
|
| 9 |
+
accelerator: cpu
|
| 10 |
+
devices: 1
|
| 11 |
+
|
| 12 |
+
# mixed precision for extra speed-up
|
| 13 |
+
# precision: 16
|
| 14 |
+
|
| 15 |
+
# perform a validation loop every N training epochs
|
| 16 |
+
check_val_every_n_epoch: 1
|
| 17 |
+
|
| 18 |
+
# set True to to ensure deterministic results
|
| 19 |
+
# makes training slower but gives more reproducibility than just setting seeds
|
| 20 |
+
deterministic: False
|
| 21 |
+
|
| 22 |
+
plugins:
|
| 23 |
+
- _target_: lightning.pytorch.plugins.environments.SLURMEnvironment
|
| 24 |
+
auto_requeue: true # auto-resubmit the job when it is preempted by slurm
|
| 25 |
+
requeue_signal: ${python_eval:"signal.SIGUSR1"} # singal code is platform dependent, so it has to be decided at runtime
|
| 26 |
+
# requeue_signal:
|
| 27 |
+
# _target_: signal.Signals
|
| 28 |
+
# _args_:
|
| 29 |
+
# - 10 # SIGUSR1, see: https://chromium.googlesource.com/chromiumos/docs/+/master/constants/signals.md
|
| 30 |
+
|
configs/trainer/gpu.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
accelerator: gpu
|
| 5 |
+
devices: 1
|
configs/trainer/mps.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default
|
| 3 |
+
|
| 4 |
+
accelerator: mps
|
| 5 |
+
devices: 1
|
eval/monodepth/eval_metrics.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
| 5 |
+
from eval.monodepth.tools import depth_evaluation
|
| 6 |
+
import numpy as np
|
| 7 |
+
import json
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import glob
|
| 10 |
+
import cv2
|
| 11 |
+
from eval.monodepth.metadata import dataset_metadata
|
| 12 |
+
import argparse
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
TAG_FLOAT = 202021.25
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def depth_read_sintel(filename):
|
| 19 |
+
"""Read depth data from file, return as numpy array."""
|
| 20 |
+
f = open(filename, "rb")
|
| 21 |
+
check = np.fromfile(f, dtype=np.float32, count=1)[0]
|
| 22 |
+
assert (
|
| 23 |
+
check == TAG_FLOAT
|
| 24 |
+
), " depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format(
|
| 25 |
+
TAG_FLOAT, check
|
| 26 |
+
)
|
| 27 |
+
width = np.fromfile(f, dtype=np.int32, count=1)[0]
|
| 28 |
+
height = np.fromfile(f, dtype=np.int32, count=1)[0]
|
| 29 |
+
size = width * height
|
| 30 |
+
assert (
|
| 31 |
+
width > 0 and height > 0 and size > 1 and size < 100000000
|
| 32 |
+
), " depth_read:: Wrong input size (width = {0}, height = {1}).".format(
|
| 33 |
+
width, height
|
| 34 |
+
)
|
| 35 |
+
depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width))
|
| 36 |
+
return depth
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def depth_read_bonn(filename):
|
| 40 |
+
# loads depth map D from png file
|
| 41 |
+
# and returns it as a numpy array
|
| 42 |
+
depth_png = np.asarray(Image.open(filename))
|
| 43 |
+
# make sure we have a proper 16bit depth map here.. not 8bit!
|
| 44 |
+
assert np.max(depth_png) > 255
|
| 45 |
+
depth = depth_png.astype(np.float64) / 5000.0
|
| 46 |
+
depth[depth_png == 0] = -1.0
|
| 47 |
+
return depth
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def depth_read_kitti(filename):
|
| 51 |
+
# loads depth map D from png file
|
| 52 |
+
# and returns it as a numpy array,
|
| 53 |
+
# for details see readme.txt
|
| 54 |
+
img_pil = Image.open(filename)
|
| 55 |
+
depth_png = np.array(img_pil, dtype=int)
|
| 56 |
+
# make sure we have a proper 16bit depth map here.. not 8bit!
|
| 57 |
+
assert np.max(depth_png) > 255
|
| 58 |
+
|
| 59 |
+
depth = depth_png.astype(float) / 256.0
|
| 60 |
+
depth[depth_png == 0] = -1.0
|
| 61 |
+
return depth
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_gt_depth(filename, dataset):
|
| 65 |
+
if dataset == "sintel":
|
| 66 |
+
return depth_read_sintel(filename)
|
| 67 |
+
elif dataset == "bonn":
|
| 68 |
+
return depth_read_bonn(filename)
|
| 69 |
+
elif dataset == "kitti":
|
| 70 |
+
return depth_read_kitti(filename)
|
| 71 |
+
elif dataset == "nyu":
|
| 72 |
+
return np.load(filename)
|
| 73 |
+
else:
|
| 74 |
+
raise NotImplementedError
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_args_parser():
|
| 78 |
+
parser = argparse.ArgumentParser()
|
| 79 |
+
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--output_dir",
|
| 82 |
+
type=str,
|
| 83 |
+
default="",
|
| 84 |
+
help="value for outdir",
|
| 85 |
+
)
|
| 86 |
+
parser.add_argument(
|
| 87 |
+
"--eval_dataset", type=str, default="nyu", choices=list(dataset_metadata.keys())
|
| 88 |
+
)
|
| 89 |
+
return parser
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def main(args):
|
| 93 |
+
if args.eval_dataset == "nyu":
|
| 94 |
+
depth_pathes = glob.glob("data/nyu-v2/val/nyu_depths/*.npy")
|
| 95 |
+
depth_pathes = sorted(depth_pathes)
|
| 96 |
+
pred_pathes = glob.glob(
|
| 97 |
+
f"{args.output_dir}/*.npy"
|
| 98 |
+
) # TODO: update the path to your prediction
|
| 99 |
+
pred_pathes = sorted(pred_pathes)
|
| 100 |
+
elif args.eval_dataset == "sintel":
|
| 101 |
+
pred_pathes = glob.glob(
|
| 102 |
+
f"{args.output_dir}/*/*.npy"
|
| 103 |
+
) # TODO: update the path to your prediction
|
| 104 |
+
pred_pathes = sorted(pred_pathes)
|
| 105 |
+
full = len(pred_pathes) > 643
|
| 106 |
+
if full:
|
| 107 |
+
depth_pathes = glob.glob(f"data/sintel/training/depth/*/*.dpt")
|
| 108 |
+
depth_pathes = sorted(depth_pathes)
|
| 109 |
+
else:
|
| 110 |
+
seq_list = [
|
| 111 |
+
"alley_2",
|
| 112 |
+
"ambush_4",
|
| 113 |
+
"ambush_5",
|
| 114 |
+
"ambush_6",
|
| 115 |
+
"cave_2",
|
| 116 |
+
"cave_4",
|
| 117 |
+
"market_2",
|
| 118 |
+
"market_5",
|
| 119 |
+
"market_6",
|
| 120 |
+
"shaman_3",
|
| 121 |
+
"sleeping_1",
|
| 122 |
+
"sleeping_2",
|
| 123 |
+
"temple_2",
|
| 124 |
+
"temple_3",
|
| 125 |
+
]
|
| 126 |
+
depth_pathes_folder = [
|
| 127 |
+
f"data/sintel/training/depth/{seq}" for seq in seq_list
|
| 128 |
+
]
|
| 129 |
+
depth_pathes = []
|
| 130 |
+
for depth_pathes_folder_i in depth_pathes_folder:
|
| 131 |
+
depth_pathes += glob.glob(depth_pathes_folder_i + "/*.dpt")
|
| 132 |
+
depth_pathes = sorted(depth_pathes)
|
| 133 |
+
elif args.eval_dataset == "bonn":
|
| 134 |
+
seq_list = ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"]
|
| 135 |
+
img_pathes_folder = [
|
| 136 |
+
f"data/bonn/rgbd_bonn_dataset/rgbd_bonn_{seq}/rgb_110/*.png"
|
| 137 |
+
for seq in seq_list
|
| 138 |
+
]
|
| 139 |
+
img_pathes = []
|
| 140 |
+
for img_pathes_folder_i in img_pathes_folder:
|
| 141 |
+
img_pathes += glob.glob(img_pathes_folder_i)
|
| 142 |
+
img_pathes = sorted(img_pathes)
|
| 143 |
+
depth_pathes_folder = [
|
| 144 |
+
f"data/bonn/rgbd_bonn_dataset/rgbd_bonn_{seq}/depth_110/*.png"
|
| 145 |
+
for seq in seq_list
|
| 146 |
+
]
|
| 147 |
+
depth_pathes = []
|
| 148 |
+
for depth_pathes_folder_i in depth_pathes_folder:
|
| 149 |
+
depth_pathes += glob.glob(depth_pathes_folder_i)
|
| 150 |
+
depth_pathes = sorted(depth_pathes)
|
| 151 |
+
pred_pathes = glob.glob(
|
| 152 |
+
f"{args.output_dir}/*/*.npy"
|
| 153 |
+
) # TODO: update the path to your prediction
|
| 154 |
+
pred_pathes = sorted(pred_pathes)
|
| 155 |
+
elif args.eval_dataset == "kitti":
|
| 156 |
+
depth_pathes = glob.glob(
|
| 157 |
+
"data/kitti/depth_selection/val_selection_cropped/groundtruth_depth_gathered/*/*.png"
|
| 158 |
+
)
|
| 159 |
+
depth_pathes = sorted(depth_pathes)
|
| 160 |
+
pred_pathes = glob.glob(
|
| 161 |
+
f"{args.output_dir}/*/*depth.npy"
|
| 162 |
+
) # TODO: update the path to your prediction
|
| 163 |
+
pred_pathes = sorted(pred_pathes)
|
| 164 |
+
else:
|
| 165 |
+
raise NotImplementedError
|
| 166 |
+
|
| 167 |
+
gathered_depth_metrics = []
|
| 168 |
+
for idx in tqdm(range(len(depth_pathes))):
|
| 169 |
+
pred_depth = np.load(pred_pathes[idx])
|
| 170 |
+
gt_depth = get_gt_depth(depth_pathes[idx], args.eval_dataset)
|
| 171 |
+
pred_depth = cv2.resize(
|
| 172 |
+
pred_depth,
|
| 173 |
+
(gt_depth.shape[1], gt_depth.shape[0]),
|
| 174 |
+
interpolation=cv2.INTER_CUBIC,
|
| 175 |
+
)
|
| 176 |
+
if args.eval_dataset == "nyu":
|
| 177 |
+
depth_results, error_map, depth_predict, depth_gt = depth_evaluation(
|
| 178 |
+
pred_depth, gt_depth, max_depth=None, lr=1e-3
|
| 179 |
+
)
|
| 180 |
+
elif args.eval_dataset == "sintel":
|
| 181 |
+
depth_results, error_map, depth_predict, depth_gt = depth_evaluation(
|
| 182 |
+
pred_depth, gt_depth, max_depth=70, use_gpu=True, post_clip_max=70
|
| 183 |
+
)
|
| 184 |
+
elif args.eval_dataset == "bonn":
|
| 185 |
+
depth_results, error_map, depth_predict, depth_gt = depth_evaluation(
|
| 186 |
+
pred_depth, gt_depth, max_depth=70, use_gpu=True
|
| 187 |
+
)
|
| 188 |
+
elif args.eval_dataset == "kitti":
|
| 189 |
+
depth_results, error_map, depth_predict, depth_gt = depth_evaluation(
|
| 190 |
+
pred_depth, gt_depth, max_depth=None, use_gpu=True
|
| 191 |
+
)
|
| 192 |
+
gathered_depth_metrics.append(depth_results)
|
| 193 |
+
|
| 194 |
+
depth_log_path = os.path.join(args.output_dir, "metric.json")
|
| 195 |
+
average_metrics = {
|
| 196 |
+
key: np.average(
|
| 197 |
+
[metrics[key] for metrics in gathered_depth_metrics],
|
| 198 |
+
weights=[metrics["valid_pixels"] for metrics in gathered_depth_metrics],
|
| 199 |
+
)
|
| 200 |
+
for key in gathered_depth_metrics[0].keys()
|
| 201 |
+
if key != "valid_pixels"
|
| 202 |
+
}
|
| 203 |
+
print(f"{args.eval_dataset} - Average depth evaluation metrics:", average_metrics)
|
| 204 |
+
with open(depth_log_path, "w") as f:
|
| 205 |
+
f.write(json.dumps(average_metrics))
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
if __name__ == "__main__":
|
| 209 |
+
args = get_args_parser()
|
| 210 |
+
args = args.parse_args()
|
| 211 |
+
main(args)
|
eval/monodepth/launch.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
import argparse
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
from stream3r.models.stream3r import STream3R
|
| 13 |
+
from stream3r.dust3r.utils.device import collate_with_cat
|
| 14 |
+
from stream3r.dust3r.utils.image import load_images_for_eval as load_images
|
| 15 |
+
from stream3r.utils.utils import ImgDust3r2Stream3r
|
| 16 |
+
|
| 17 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
| 18 |
+
from eval.monodepth.metadata import dataset_metadata
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 22 |
+
|
| 23 |
+
# avoid high cpu usage
|
| 24 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
| 25 |
+
os.environ["MKL_NUM_THREADS"] = "1"
|
| 26 |
+
os.environ["NUMEXPR_NUM_THREADS"] = "1"
|
| 27 |
+
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
| 28 |
+
torch.set_num_threads(1)
|
| 29 |
+
# ===========================================
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def colorize_depth(depth: np.ndarray,
|
| 33 |
+
mask: np.ndarray = None,
|
| 34 |
+
normalize: bool = True,
|
| 35 |
+
cmap: str = 'Spectral') -> np.ndarray:
|
| 36 |
+
if mask is None:
|
| 37 |
+
depth = np.where(depth > 0, depth, np.nan)
|
| 38 |
+
else:
|
| 39 |
+
depth = np.where((depth > 0) & mask, depth, np.nan)
|
| 40 |
+
disp = 1 / depth
|
| 41 |
+
if normalize:
|
| 42 |
+
min_disp, max_disp = np.nanquantile(disp,
|
| 43 |
+
0.001), np.nanquantile(disp, 0.99)
|
| 44 |
+
disp = (disp - min_disp) / (max_disp - min_disp)
|
| 45 |
+
colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], 0)
|
| 46 |
+
colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
|
| 47 |
+
return colored
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_args_parser():
|
| 51 |
+
parser = argparse.ArgumentParser()
|
| 52 |
+
|
| 53 |
+
parser.add_argument("--device",
|
| 54 |
+
type=str,
|
| 55 |
+
default="cuda",
|
| 56 |
+
help="pytorch device")
|
| 57 |
+
parser.add_argument("--output_dir",
|
| 58 |
+
type=str,
|
| 59 |
+
default="",
|
| 60 |
+
help="value for outdir")
|
| 61 |
+
parser.add_argument("--no_crop",
|
| 62 |
+
type=bool,
|
| 63 |
+
default=True,
|
| 64 |
+
help="whether to crop input data")
|
| 65 |
+
parser.add_argument("--full_seq",
|
| 66 |
+
type=bool,
|
| 67 |
+
default=False,
|
| 68 |
+
help="whether to use all seqs")
|
| 69 |
+
parser.add_argument("--seq_list", default=None)
|
| 70 |
+
|
| 71 |
+
parser.add_argument("--eval_dataset",
|
| 72 |
+
type=str,
|
| 73 |
+
default="nyu",
|
| 74 |
+
choices=list(dataset_metadata.keys()))
|
| 75 |
+
return parser
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def eval_mono_depth_estimation(args, model, device):
|
| 79 |
+
metadata = dataset_metadata.get(args.eval_dataset)
|
| 80 |
+
if metadata is None:
|
| 81 |
+
raise ValueError(f"Unknown dataset: {args.eval_dataset}")
|
| 82 |
+
|
| 83 |
+
img_path = metadata.get("img_path")
|
| 84 |
+
if "img_path_func" in metadata:
|
| 85 |
+
img_path = metadata["img_path_func"](args)
|
| 86 |
+
|
| 87 |
+
process_func = metadata.get("process_func")
|
| 88 |
+
if process_func is None:
|
| 89 |
+
raise ValueError(
|
| 90 |
+
f"No processing function defined for dataset: {args.eval_dataset}")
|
| 91 |
+
|
| 92 |
+
for filelist, save_dir in process_func(args, img_path):
|
| 93 |
+
Path(save_dir).mkdir(parents=True, exist_ok=True)
|
| 94 |
+
eval_mono_depth(args, model, device, filelist, save_dir=save_dir)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def eval_mono_depth(args, model, device, filelist, save_dir=None):
|
| 98 |
+
for file in tqdm(filelist):
|
| 99 |
+
file = [file]
|
| 100 |
+
images = load_images(
|
| 101 |
+
file,
|
| 102 |
+
size=518,
|
| 103 |
+
verbose=True,
|
| 104 |
+
crop=False,
|
| 105 |
+
patch_size=14,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
images = collate_with_cat([tuple(images)])
|
| 109 |
+
images = torch.stack([view["img"] for view in images], dim=1)
|
| 110 |
+
images = ImgDust3r2Stream3r(images).to(device)
|
| 111 |
+
|
| 112 |
+
with torch.no_grad():
|
| 113 |
+
predictions = model(images)
|
| 114 |
+
|
| 115 |
+
depth_map = predictions['depth'][0,0].squeeze(-1).cpu()
|
| 116 |
+
|
| 117 |
+
if save_dir is not None:
|
| 118 |
+
# save the depth map to the save_dir as npy
|
| 119 |
+
np.save(
|
| 120 |
+
f"{save_dir}/{file[0].split('/')[-1].replace('.png','depth.npy')}",
|
| 121 |
+
depth_map.cpu().numpy(),
|
| 122 |
+
)
|
| 123 |
+
depth_map = colorize_depth(depth_map)
|
| 124 |
+
cv2.imwrite(
|
| 125 |
+
f"{save_dir}/{file[0].split('/')[-1].replace('.png','depth.jpg')}",
|
| 126 |
+
depth_map,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def main():
|
| 131 |
+
args = get_args_parser()
|
| 132 |
+
args = args.parse_args()
|
| 133 |
+
|
| 134 |
+
if args.eval_dataset == "sintel":
|
| 135 |
+
args.full_seq = True
|
| 136 |
+
else:
|
| 137 |
+
args.full_seq = False
|
| 138 |
+
|
| 139 |
+
model = STream3R.from_pretrained("yslan/STream3R").to(args.device)
|
| 140 |
+
model.eval()
|
| 141 |
+
|
| 142 |
+
eval_mono_depth_estimation(args, model, args.device)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
main()
|
eval/monodepth/metadata.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
|
| 5 |
+
# Define the merged dataset metadata dictionary
|
| 6 |
+
dataset_metadata = {
|
| 7 |
+
"sun_rgbd": {
|
| 8 |
+
"img_path": "data/sun_rgbd/image/test",
|
| 9 |
+
"mask_path": None,
|
| 10 |
+
},
|
| 11 |
+
"davis": {
|
| 12 |
+
"img_path": "data/davis/DAVIS/JPEGImages/480p",
|
| 13 |
+
"mask_path": "data/davis/DAVIS/masked_images/480p",
|
| 14 |
+
"dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
|
| 15 |
+
"gt_traj_func": lambda img_path, anno_path, seq: None,
|
| 16 |
+
"traj_format": None,
|
| 17 |
+
"seq_list": None,
|
| 18 |
+
"full_seq": True,
|
| 19 |
+
"mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq),
|
| 20 |
+
"skip_condition": None,
|
| 21 |
+
"process_func": None, # Not used in mono depth estimation
|
| 22 |
+
},
|
| 23 |
+
"kitti": {
|
| 24 |
+
"img_path": "data/kitti/depth_selection/val_selection_cropped/image_gathered", # Default path
|
| 25 |
+
"mask_path": None,
|
| 26 |
+
"dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
|
| 27 |
+
"gt_traj_func": lambda img_path, anno_path, seq: None,
|
| 28 |
+
"traj_format": None,
|
| 29 |
+
"seq_list": None,
|
| 30 |
+
"full_seq": True,
|
| 31 |
+
"mask_path_seq_func": lambda mask_path, seq: None,
|
| 32 |
+
"skip_condition": None,
|
| 33 |
+
"process_func": lambda args, img_path: process_kitti(args, img_path),
|
| 34 |
+
},
|
| 35 |
+
"bonn": {
|
| 36 |
+
"img_path": "data/bonn/rgbd_bonn_dataset",
|
| 37 |
+
"mask_path": None,
|
| 38 |
+
"dir_path_func": lambda img_path, seq: os.path.join(
|
| 39 |
+
img_path, f"rgbd_bonn_{seq}", "rgb_110"
|
| 40 |
+
),
|
| 41 |
+
"gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
|
| 42 |
+
img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt"
|
| 43 |
+
),
|
| 44 |
+
"traj_format": "tum",
|
| 45 |
+
"seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
|
| 46 |
+
"full_seq": False,
|
| 47 |
+
"mask_path_seq_func": lambda mask_path, seq: None,
|
| 48 |
+
"skip_condition": None,
|
| 49 |
+
"process_func": lambda args, img_path: process_bonn(args, img_path),
|
| 50 |
+
},
|
| 51 |
+
"nyu": {
|
| 52 |
+
"img_path": "data/nyu-v2/val/nyu_images",
|
| 53 |
+
"mask_path": None,
|
| 54 |
+
"process_func": lambda args, img_path: process_nyu(args, img_path),
|
| 55 |
+
},
|
| 56 |
+
"scannet": {
|
| 57 |
+
"img_path": "data/scannetv2",
|
| 58 |
+
"mask_path": None,
|
| 59 |
+
"dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
|
| 60 |
+
"gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
|
| 61 |
+
img_path, seq, "pose_90.txt"
|
| 62 |
+
),
|
| 63 |
+
"traj_format": "replica",
|
| 64 |
+
"seq_list": None,
|
| 65 |
+
"full_seq": True,
|
| 66 |
+
"mask_path_seq_func": lambda mask_path, seq: None,
|
| 67 |
+
"skip_condition": None, # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)),
|
| 68 |
+
"process_func": lambda args, img_path: process_scannet(args, img_path),
|
| 69 |
+
},
|
| 70 |
+
"tum": {
|
| 71 |
+
"img_path": "data/tum",
|
| 72 |
+
"mask_path": None,
|
| 73 |
+
"dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"),
|
| 74 |
+
"gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
|
| 75 |
+
img_path, seq, "groundtruth_90.txt"
|
| 76 |
+
),
|
| 77 |
+
"traj_format": "tum",
|
| 78 |
+
"seq_list": None,
|
| 79 |
+
"full_seq": True,
|
| 80 |
+
"mask_path_seq_func": lambda mask_path, seq: None,
|
| 81 |
+
"skip_condition": None,
|
| 82 |
+
"process_func": None,
|
| 83 |
+
},
|
| 84 |
+
"sintel": {
|
| 85 |
+
"img_path": "data/sintel/training/final",
|
| 86 |
+
"anno_path": "data/sintel/training/camdata_left",
|
| 87 |
+
"mask_path": None,
|
| 88 |
+
"dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
|
| 89 |
+
"gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq),
|
| 90 |
+
"traj_format": None,
|
| 91 |
+
"seq_list": [
|
| 92 |
+
"alley_2",
|
| 93 |
+
"ambush_4",
|
| 94 |
+
"ambush_5",
|
| 95 |
+
"ambush_6",
|
| 96 |
+
"cave_2",
|
| 97 |
+
"cave_4",
|
| 98 |
+
"market_2",
|
| 99 |
+
"market_5",
|
| 100 |
+
"market_6",
|
| 101 |
+
"shaman_3",
|
| 102 |
+
"sleeping_1",
|
| 103 |
+
"sleeping_2",
|
| 104 |
+
"temple_2",
|
| 105 |
+
"temple_3",
|
| 106 |
+
],
|
| 107 |
+
"full_seq": False,
|
| 108 |
+
"mask_path_seq_func": lambda mask_path, seq: None,
|
| 109 |
+
"skip_condition": None,
|
| 110 |
+
"process_func": lambda args, img_path: process_sintel(args, img_path),
|
| 111 |
+
},
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# Define processing functions for each dataset
|
| 116 |
+
def process_kitti(args, img_path):
|
| 117 |
+
for dir in tqdm(sorted(glob.glob(f"{img_path}/*"))):
|
| 118 |
+
filelist = sorted(glob.glob(f"{dir}/*.png"))
|
| 119 |
+
save_dir = f"{args.output_dir}/{os.path.basename(dir)}"
|
| 120 |
+
yield filelist, save_dir
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def process_bonn(args, img_path):
|
| 124 |
+
if args.full_seq:
|
| 125 |
+
for dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
|
| 126 |
+
filelist = sorted(glob.glob(f"{dir}/rgb/*.png"))
|
| 127 |
+
save_dir = f"{args.output_dir}/{os.path.basename(os.path.dirname(dir))}"
|
| 128 |
+
yield filelist, save_dir
|
| 129 |
+
else:
|
| 130 |
+
seq_list = (
|
| 131 |
+
["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"]
|
| 132 |
+
if args.seq_list is None
|
| 133 |
+
else args.seq_list
|
| 134 |
+
)
|
| 135 |
+
for seq in tqdm(seq_list):
|
| 136 |
+
filelist = sorted(glob.glob(f"{img_path}/rgbd_bonn_{seq}/rgb_110/*.png"))
|
| 137 |
+
save_dir = f"{args.output_dir}/{seq}"
|
| 138 |
+
yield filelist, save_dir
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def process_sunrgbd(args, img_path):
|
| 142 |
+
filelist = sorted(glob.glob(f"{img_path}/*.jpg"))
|
| 143 |
+
save_dir = f"{args.output_dir}"
|
| 144 |
+
yield filelist, save_dir
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def process_nyu(args, img_path):
|
| 148 |
+
filelist = sorted(glob.glob(f"{img_path}/*.png"))
|
| 149 |
+
save_dir = f"{args.output_dir}"
|
| 150 |
+
yield filelist, save_dir
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def process_scannet(args, img_path):
|
| 154 |
+
seq_list = sorted(glob.glob(f"{img_path}/*"))
|
| 155 |
+
for seq in tqdm(seq_list):
|
| 156 |
+
filelist = sorted(glob.glob(f"{seq}/color_90/*.jpg"))
|
| 157 |
+
save_dir = f"{args.output_dir}/{os.path.basename(seq)}"
|
| 158 |
+
yield filelist, save_dir
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def process_sintel(args, img_path):
|
| 162 |
+
if args.full_seq:
|
| 163 |
+
for dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))):
|
| 164 |
+
filelist = sorted(glob.glob(f"{dir}/*.png"))
|
| 165 |
+
save_dir = f"{args.output_dir}/{os.path.basename(os.path.dirname(dir))}"
|
| 166 |
+
yield filelist, save_dir
|
| 167 |
+
else:
|
| 168 |
+
seq_list = [
|
| 169 |
+
"alley_2",
|
| 170 |
+
"ambush_4",
|
| 171 |
+
"ambush_5",
|
| 172 |
+
"ambush_6",
|
| 173 |
+
"cave_2",
|
| 174 |
+
"cave_4",
|
| 175 |
+
"market_2",
|
| 176 |
+
"market_5",
|
| 177 |
+
"market_6",
|
| 178 |
+
"shaman_3",
|
| 179 |
+
"sleeping_1",
|
| 180 |
+
"sleeping_2",
|
| 181 |
+
"temple_2",
|
| 182 |
+
"temple_3",
|
| 183 |
+
]
|
| 184 |
+
for seq in tqdm(seq_list):
|
| 185 |
+
filelist = sorted(glob.glob(f"{img_path}/{seq}/*.png"))
|
| 186 |
+
save_dir = f"{args.output_dir}/{seq}"
|
| 187 |
+
yield filelist, save_dir
|
eval/monodepth/run.sh
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
workdir='.'
|
| 5 |
+
|
| 6 |
+
datasets=('sintel' 'bonn' 'kitti' 'nyu')
|
| 7 |
+
model_name='stream3r'
|
| 8 |
+
|
| 9 |
+
for data in "${datasets[@]}"; do
|
| 10 |
+
output_dir="${workdir}/eval_results/monodepth/${model_name}/${data}"
|
| 11 |
+
echo "$output_dir"
|
| 12 |
+
|
| 13 |
+
python eval/monodepth/launch.py \
|
| 14 |
+
--output_dir="$output_dir" \
|
| 15 |
+
--eval_dataset="$data" \
|
| 16 |
+
|
| 17 |
+
python eval/monodepth/eval_metrics.py \
|
| 18 |
+
--output_dir "$output_dir" \
|
| 19 |
+
--eval_dataset "$data"
|
| 20 |
+
done
|
eval/monodepth/tools.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import glob
|
| 5 |
+
import argparse
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from scipy.optimize import minimize
|
| 10 |
+
import os
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def group_by_directory(pathes, idx=-1):
|
| 15 |
+
"""
|
| 16 |
+
Groups the file paths based on the second-to-last directory in their paths.
|
| 17 |
+
|
| 18 |
+
Parameters:
|
| 19 |
+
- pathes (list): List of file paths.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
- dict: A dictionary where keys are the second-to-last directory names and values are lists of file paths.
|
| 23 |
+
"""
|
| 24 |
+
grouped_pathes = defaultdict(list)
|
| 25 |
+
|
| 26 |
+
for path in pathes:
|
| 27 |
+
# Extract the second-to-last directory
|
| 28 |
+
dir_name = os.path.dirname(path).split("/")[idx]
|
| 29 |
+
grouped_pathes[dir_name].append(path)
|
| 30 |
+
|
| 31 |
+
return grouped_pathes
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def depth2disparity(depth, return_mask=False):
|
| 35 |
+
if isinstance(depth, torch.Tensor):
|
| 36 |
+
disparity = torch.zeros_like(depth)
|
| 37 |
+
elif isinstance(depth, np.ndarray):
|
| 38 |
+
disparity = np.zeros_like(depth)
|
| 39 |
+
non_negtive_mask = depth > 0
|
| 40 |
+
disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask]
|
| 41 |
+
if return_mask:
|
| 42 |
+
return disparity, non_negtive_mask
|
| 43 |
+
else:
|
| 44 |
+
return disparity
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def absolute_error_loss(params, predicted_depth, ground_truth_depth):
|
| 48 |
+
s, t = params
|
| 49 |
+
|
| 50 |
+
predicted_aligned = s * predicted_depth + t
|
| 51 |
+
|
| 52 |
+
abs_error = np.abs(predicted_aligned - ground_truth_depth)
|
| 53 |
+
return np.sum(abs_error)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def absolute_value_scaling(predicted_depth, ground_truth_depth, s=1, t=0):
|
| 57 |
+
predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1)
|
| 58 |
+
ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1)
|
| 59 |
+
|
| 60 |
+
initial_params = [s, t] # s = 1, t = 0
|
| 61 |
+
|
| 62 |
+
result = minimize(
|
| 63 |
+
absolute_error_loss,
|
| 64 |
+
initial_params,
|
| 65 |
+
args=(predicted_depth_np, ground_truth_depth_np),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
s, t = result.x
|
| 69 |
+
return s, t
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def absolute_value_scaling2(
|
| 73 |
+
predicted_depth,
|
| 74 |
+
ground_truth_depth,
|
| 75 |
+
s_init=1.0,
|
| 76 |
+
t_init=0.0,
|
| 77 |
+
lr=1e-4,
|
| 78 |
+
max_iters=1000,
|
| 79 |
+
tol=1e-6,
|
| 80 |
+
):
|
| 81 |
+
# Initialize s and t as torch tensors with requires_grad=True
|
| 82 |
+
s = torch.tensor(
|
| 83 |
+
[s_init],
|
| 84 |
+
requires_grad=True,
|
| 85 |
+
device=predicted_depth.device,
|
| 86 |
+
dtype=predicted_depth.dtype,
|
| 87 |
+
)
|
| 88 |
+
t = torch.tensor(
|
| 89 |
+
[t_init],
|
| 90 |
+
requires_grad=True,
|
| 91 |
+
device=predicted_depth.device,
|
| 92 |
+
dtype=predicted_depth.dtype,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
optimizer = torch.optim.Adam([s, t], lr=lr)
|
| 96 |
+
|
| 97 |
+
prev_loss = None
|
| 98 |
+
|
| 99 |
+
for i in range(max_iters):
|
| 100 |
+
optimizer.zero_grad()
|
| 101 |
+
|
| 102 |
+
# Compute predicted aligned depth
|
| 103 |
+
predicted_aligned = s * predicted_depth + t
|
| 104 |
+
|
| 105 |
+
# Compute absolute error
|
| 106 |
+
abs_error = torch.abs(predicted_aligned - ground_truth_depth)
|
| 107 |
+
|
| 108 |
+
# Compute loss
|
| 109 |
+
loss = torch.sum(abs_error)
|
| 110 |
+
|
| 111 |
+
# Backpropagate
|
| 112 |
+
loss.backward()
|
| 113 |
+
|
| 114 |
+
# Update parameters
|
| 115 |
+
optimizer.step()
|
| 116 |
+
|
| 117 |
+
# Check convergence
|
| 118 |
+
if prev_loss is not None and torch.abs(prev_loss - loss) < tol:
|
| 119 |
+
break
|
| 120 |
+
|
| 121 |
+
prev_loss = loss.item()
|
| 122 |
+
|
| 123 |
+
return s.detach().item(), t.detach().item()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def depth_evaluation(
|
| 127 |
+
predicted_depth_original,
|
| 128 |
+
ground_truth_depth_original,
|
| 129 |
+
max_depth=80,
|
| 130 |
+
custom_mask=None,
|
| 131 |
+
post_clip_min=None,
|
| 132 |
+
post_clip_max=None,
|
| 133 |
+
pre_clip_min=None,
|
| 134 |
+
pre_clip_max=None,
|
| 135 |
+
align_with_lstsq=False,
|
| 136 |
+
align_with_lad=False,
|
| 137 |
+
align_with_lad2=False,
|
| 138 |
+
metric_scale=False,
|
| 139 |
+
lr=1e-4,
|
| 140 |
+
max_iters=1000,
|
| 141 |
+
use_gpu=False,
|
| 142 |
+
align_with_scale=False,
|
| 143 |
+
disp_input=False,
|
| 144 |
+
):
|
| 145 |
+
"""
|
| 146 |
+
Evaluate the depth map using various metrics and return a depth error parity map, with an option for least squares alignment.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
predicted_depth (numpy.ndarray or torch.Tensor): The predicted depth map.
|
| 150 |
+
ground_truth_depth (numpy.ndarray or torch.Tensor): The ground truth depth map.
|
| 151 |
+
max_depth (float): The maximum depth value to consider. Default is 80 meters.
|
| 152 |
+
align_with_lstsq (bool): If True, perform least squares alignment of the predicted depth with ground truth.
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
dict: A dictionary containing the evaluation metrics.
|
| 156 |
+
torch.Tensor: The depth error parity map.
|
| 157 |
+
"""
|
| 158 |
+
if isinstance(predicted_depth_original, np.ndarray):
|
| 159 |
+
predicted_depth_original = torch.from_numpy(predicted_depth_original)
|
| 160 |
+
if isinstance(ground_truth_depth_original, np.ndarray):
|
| 161 |
+
ground_truth_depth_original = torch.from_numpy(ground_truth_depth_original)
|
| 162 |
+
if custom_mask is not None and isinstance(custom_mask, np.ndarray):
|
| 163 |
+
custom_mask = torch.from_numpy(custom_mask)
|
| 164 |
+
|
| 165 |
+
# if the dimension is 3, flatten to 2d along the batch dimension
|
| 166 |
+
if predicted_depth_original.dim() == 3:
|
| 167 |
+
_, h, w = predicted_depth_original.shape
|
| 168 |
+
predicted_depth_original = predicted_depth_original.view(-1, w)
|
| 169 |
+
ground_truth_depth_original = ground_truth_depth_original.view(-1, w)
|
| 170 |
+
if custom_mask is not None:
|
| 171 |
+
custom_mask = custom_mask.view(-1, w)
|
| 172 |
+
|
| 173 |
+
# put to device
|
| 174 |
+
if use_gpu:
|
| 175 |
+
predicted_depth_original = predicted_depth_original.cuda()
|
| 176 |
+
ground_truth_depth_original = ground_truth_depth_original.cuda()
|
| 177 |
+
|
| 178 |
+
# Filter out depths greater than max_depth
|
| 179 |
+
if max_depth is not None:
|
| 180 |
+
mask = (ground_truth_depth_original > 0) & (
|
| 181 |
+
ground_truth_depth_original < max_depth
|
| 182 |
+
)
|
| 183 |
+
else:
|
| 184 |
+
mask = ground_truth_depth_original > 0
|
| 185 |
+
predicted_depth = predicted_depth_original[mask]
|
| 186 |
+
ground_truth_depth = ground_truth_depth_original[mask]
|
| 187 |
+
|
| 188 |
+
# Clip the depth values
|
| 189 |
+
if pre_clip_min is not None:
|
| 190 |
+
predicted_depth = torch.clamp(predicted_depth, min=pre_clip_min)
|
| 191 |
+
if pre_clip_max is not None:
|
| 192 |
+
predicted_depth = torch.clamp(predicted_depth, max=pre_clip_max)
|
| 193 |
+
|
| 194 |
+
if disp_input: # align the pred to gt in the disparity space
|
| 195 |
+
real_gt = ground_truth_depth.clone()
|
| 196 |
+
ground_truth_depth = 1 / (ground_truth_depth + 1e-8)
|
| 197 |
+
|
| 198 |
+
# various alignment methods
|
| 199 |
+
if metric_scale:
|
| 200 |
+
predicted_depth = predicted_depth
|
| 201 |
+
elif align_with_lstsq:
|
| 202 |
+
# Convert to numpy for lstsq
|
| 203 |
+
predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1, 1)
|
| 204 |
+
ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1, 1)
|
| 205 |
+
|
| 206 |
+
# Add a column of ones for the shift term
|
| 207 |
+
A = np.hstack([predicted_depth_np, np.ones_like(predicted_depth_np)])
|
| 208 |
+
|
| 209 |
+
# Solve for scale (s) and shift (t) using least squares
|
| 210 |
+
result = np.linalg.lstsq(A, ground_truth_depth_np, rcond=None)
|
| 211 |
+
s, t = result[0][0], result[0][1]
|
| 212 |
+
|
| 213 |
+
# convert to torch tensor
|
| 214 |
+
s = torch.tensor(s, device=predicted_depth_original.device)
|
| 215 |
+
t = torch.tensor(t, device=predicted_depth_original.device)
|
| 216 |
+
|
| 217 |
+
# Apply scale and shift
|
| 218 |
+
predicted_depth = s * predicted_depth + t
|
| 219 |
+
elif align_with_lad:
|
| 220 |
+
s, t = absolute_value_scaling(
|
| 221 |
+
predicted_depth,
|
| 222 |
+
ground_truth_depth,
|
| 223 |
+
s=torch.median(ground_truth_depth) / torch.median(predicted_depth),
|
| 224 |
+
)
|
| 225 |
+
predicted_depth = s * predicted_depth + t
|
| 226 |
+
elif align_with_lad2:
|
| 227 |
+
s_init = (
|
| 228 |
+
torch.median(ground_truth_depth) / torch.median(predicted_depth)
|
| 229 |
+
).item()
|
| 230 |
+
s, t = absolute_value_scaling2(
|
| 231 |
+
predicted_depth,
|
| 232 |
+
ground_truth_depth,
|
| 233 |
+
s_init=s_init,
|
| 234 |
+
lr=lr,
|
| 235 |
+
max_iters=max_iters,
|
| 236 |
+
)
|
| 237 |
+
predicted_depth = s * predicted_depth + t
|
| 238 |
+
elif align_with_scale:
|
| 239 |
+
# Compute initial scale factor 's' using the closed-form solution (L2 norm)
|
| 240 |
+
dot_pred_gt = torch.nanmean(ground_truth_depth)
|
| 241 |
+
dot_pred_pred = torch.nanmean(predicted_depth)
|
| 242 |
+
s = dot_pred_gt / dot_pred_pred
|
| 243 |
+
|
| 244 |
+
# Iterative reweighted least squares using the Weiszfeld method
|
| 245 |
+
for _ in range(10):
|
| 246 |
+
# Compute residuals between scaled predictions and ground truth
|
| 247 |
+
residuals = s * predicted_depth - ground_truth_depth
|
| 248 |
+
abs_residuals = (
|
| 249 |
+
residuals.abs() + 1e-8
|
| 250 |
+
) # Add small constant to avoid division by zero
|
| 251 |
+
|
| 252 |
+
# Compute weights inversely proportional to the residuals
|
| 253 |
+
weights = 1.0 / abs_residuals
|
| 254 |
+
|
| 255 |
+
# Update 's' using weighted sums
|
| 256 |
+
weighted_dot_pred_gt = torch.sum(
|
| 257 |
+
weights * predicted_depth * ground_truth_depth
|
| 258 |
+
)
|
| 259 |
+
weighted_dot_pred_pred = torch.sum(weights * predicted_depth**2)
|
| 260 |
+
s = weighted_dot_pred_gt / weighted_dot_pred_pred
|
| 261 |
+
|
| 262 |
+
# Optionally clip 's' to prevent extreme scaling
|
| 263 |
+
s = s.clamp(min=1e-3)
|
| 264 |
+
|
| 265 |
+
# Detach 's' if you want to stop gradients from flowing through it
|
| 266 |
+
s = s.detach()
|
| 267 |
+
|
| 268 |
+
# Apply the scale factor to the predicted depth
|
| 269 |
+
predicted_depth = s * predicted_depth
|
| 270 |
+
|
| 271 |
+
else:
|
| 272 |
+
# Align the predicted depth with the ground truth using median scaling
|
| 273 |
+
scale_factor = torch.median(ground_truth_depth) / torch.median(predicted_depth)
|
| 274 |
+
predicted_depth *= scale_factor
|
| 275 |
+
|
| 276 |
+
if disp_input:
|
| 277 |
+
# convert back to depth
|
| 278 |
+
ground_truth_depth = real_gt
|
| 279 |
+
predicted_depth = depth2disparity(predicted_depth)
|
| 280 |
+
|
| 281 |
+
# Clip the predicted depth values
|
| 282 |
+
if post_clip_min is not None:
|
| 283 |
+
predicted_depth = torch.clamp(predicted_depth, min=post_clip_min)
|
| 284 |
+
if post_clip_max is not None:
|
| 285 |
+
predicted_depth = torch.clamp(predicted_depth, max=post_clip_max)
|
| 286 |
+
|
| 287 |
+
if custom_mask is not None:
|
| 288 |
+
assert custom_mask.shape == ground_truth_depth_original.shape
|
| 289 |
+
mask_within_mask = custom_mask.cpu()[mask]
|
| 290 |
+
predicted_depth = predicted_depth[mask_within_mask]
|
| 291 |
+
ground_truth_depth = ground_truth_depth[mask_within_mask]
|
| 292 |
+
|
| 293 |
+
# Calculate the metrics
|
| 294 |
+
abs_rel = torch.mean(
|
| 295 |
+
torch.abs(predicted_depth - ground_truth_depth) / ground_truth_depth
|
| 296 |
+
).item()
|
| 297 |
+
sq_rel = torch.mean(
|
| 298 |
+
((predicted_depth - ground_truth_depth) ** 2) / ground_truth_depth
|
| 299 |
+
).item()
|
| 300 |
+
|
| 301 |
+
# Correct RMSE calculation
|
| 302 |
+
rmse = torch.sqrt(torch.mean((predicted_depth - ground_truth_depth) ** 2)).item()
|
| 303 |
+
|
| 304 |
+
# Clip the depth values to avoid log(0)
|
| 305 |
+
predicted_depth = torch.clamp(predicted_depth, min=1e-5)
|
| 306 |
+
log_rmse = torch.sqrt(
|
| 307 |
+
torch.mean((torch.log(predicted_depth) - torch.log(ground_truth_depth)) ** 2)
|
| 308 |
+
).item()
|
| 309 |
+
|
| 310 |
+
# Calculate the accuracy thresholds
|
| 311 |
+
max_ratio = torch.maximum(
|
| 312 |
+
predicted_depth / ground_truth_depth, ground_truth_depth / predicted_depth
|
| 313 |
+
)
|
| 314 |
+
threshold_0 = torch.mean((max_ratio < 1.0).float()).item()
|
| 315 |
+
threshold_1 = torch.mean((max_ratio < 1.25).float()).item()
|
| 316 |
+
threshold_2 = torch.mean((max_ratio < 1.25**2).float()).item()
|
| 317 |
+
threshold_3 = torch.mean((max_ratio < 1.25**3).float()).item()
|
| 318 |
+
|
| 319 |
+
# Compute the depth error parity map
|
| 320 |
+
if metric_scale:
|
| 321 |
+
predicted_depth_original = predicted_depth_original
|
| 322 |
+
if disp_input:
|
| 323 |
+
predicted_depth_original = depth2disparity(predicted_depth_original)
|
| 324 |
+
depth_error_parity_map = (
|
| 325 |
+
torch.abs(predicted_depth_original - ground_truth_depth_original)
|
| 326 |
+
/ ground_truth_depth_original
|
| 327 |
+
)
|
| 328 |
+
elif align_with_lstsq or align_with_lad or align_with_lad2:
|
| 329 |
+
predicted_depth_original = predicted_depth_original * s + t
|
| 330 |
+
if disp_input:
|
| 331 |
+
predicted_depth_original = depth2disparity(predicted_depth_original)
|
| 332 |
+
depth_error_parity_map = (
|
| 333 |
+
torch.abs(predicted_depth_original - ground_truth_depth_original)
|
| 334 |
+
/ ground_truth_depth_original
|
| 335 |
+
)
|
| 336 |
+
elif align_with_scale:
|
| 337 |
+
predicted_depth_original = predicted_depth_original * s
|
| 338 |
+
if disp_input:
|
| 339 |
+
predicted_depth_original = depth2disparity(predicted_depth_original)
|
| 340 |
+
depth_error_parity_map = (
|
| 341 |
+
torch.abs(predicted_depth_original - ground_truth_depth_original)
|
| 342 |
+
/ ground_truth_depth_original
|
| 343 |
+
)
|
| 344 |
+
else:
|
| 345 |
+
predicted_depth_original = predicted_depth_original * scale_factor
|
| 346 |
+
if disp_input:
|
| 347 |
+
predicted_depth_original = depth2disparity(predicted_depth_original)
|
| 348 |
+
depth_error_parity_map = (
|
| 349 |
+
torch.abs(predicted_depth_original - ground_truth_depth_original)
|
| 350 |
+
/ ground_truth_depth_original
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
# Reshape the depth_error_parity_map back to the original image size
|
| 354 |
+
depth_error_parity_map_full = torch.zeros_like(ground_truth_depth_original)
|
| 355 |
+
depth_error_parity_map_full = torch.where(
|
| 356 |
+
mask, depth_error_parity_map, depth_error_parity_map_full
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
predict_depth_map_full = predicted_depth_original
|
| 360 |
+
gt_depth_map_full = torch.zeros_like(ground_truth_depth_original)
|
| 361 |
+
gt_depth_map_full = torch.where(
|
| 362 |
+
mask, ground_truth_depth_original, gt_depth_map_full
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
num_valid_pixels = (
|
| 366 |
+
torch.sum(mask).item()
|
| 367 |
+
if custom_mask is None
|
| 368 |
+
else torch.sum(mask_within_mask).item()
|
| 369 |
+
)
|
| 370 |
+
if num_valid_pixels == 0:
|
| 371 |
+
(
|
| 372 |
+
abs_rel,
|
| 373 |
+
sq_rel,
|
| 374 |
+
rmse,
|
| 375 |
+
log_rmse,
|
| 376 |
+
threshold_0,
|
| 377 |
+
threshold_1,
|
| 378 |
+
threshold_2,
|
| 379 |
+
threshold_3,
|
| 380 |
+
) = (0, 0, 0, 0, 0, 0, 0, 0)
|
| 381 |
+
|
| 382 |
+
results = {
|
| 383 |
+
"Abs Rel": abs_rel,
|
| 384 |
+
"Sq Rel": sq_rel,
|
| 385 |
+
"RMSE": rmse,
|
| 386 |
+
"Log RMSE": log_rmse,
|
| 387 |
+
"δ < 1.": threshold_0,
|
| 388 |
+
"δ < 1.25": threshold_1,
|
| 389 |
+
"δ < 1.25^2": threshold_2,
|
| 390 |
+
"δ < 1.25^3": threshold_3,
|
| 391 |
+
"valid_pixels": num_valid_pixels,
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
return (
|
| 395 |
+
results,
|
| 396 |
+
depth_error_parity_map_full,
|
| 397 |
+
predict_depth_map_full,
|
| 398 |
+
gt_depth_map_full,
|
| 399 |
+
)
|
eval/mv_recon/base.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# base class for implementing datasets
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import PIL
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from stream3r.dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
|
| 12 |
+
|
| 13 |
+
from eval.mv_recon.dataset_utils.transforms import ImgNorm
|
| 14 |
+
import eval.mv_recon.dataset_utils.cropping as cropping
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BaseStereoViewDataset:
|
| 18 |
+
"""Define all basic options.
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
class MyDataset (BaseStereoViewDataset):
|
| 22 |
+
def _get_views(self, idx, rng):
|
| 23 |
+
# overload here
|
| 24 |
+
views = []
|
| 25 |
+
views.append(dict(img=, ...))
|
| 26 |
+
return views
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
*, # only keyword arguments
|
| 32 |
+
split=None,
|
| 33 |
+
resolution=None, # square_size or (width, height) or list of [(width,height), ...]
|
| 34 |
+
transform=ImgNorm,
|
| 35 |
+
aug_crop=False,
|
| 36 |
+
seed=None,
|
| 37 |
+
):
|
| 38 |
+
self.num_views = 2
|
| 39 |
+
self.split = split
|
| 40 |
+
self._set_resolutions(resolution)
|
| 41 |
+
|
| 42 |
+
self.transform = transform
|
| 43 |
+
if isinstance(transform, str):
|
| 44 |
+
transform = eval(transform)
|
| 45 |
+
|
| 46 |
+
self.aug_crop = aug_crop
|
| 47 |
+
self.seed = seed
|
| 48 |
+
|
| 49 |
+
def __len__(self):
|
| 50 |
+
return len(self.scenes)
|
| 51 |
+
|
| 52 |
+
def get_stats(self):
|
| 53 |
+
return f"{len(self)} pairs"
|
| 54 |
+
|
| 55 |
+
def __repr__(self):
|
| 56 |
+
resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
|
| 57 |
+
return (
|
| 58 |
+
f"""{type(self).__name__}({self.get_stats()},
|
| 59 |
+
{self.split=},
|
| 60 |
+
{self.seed=},
|
| 61 |
+
resolutions={resolutions_str},
|
| 62 |
+
{self.transform=})""".replace(
|
| 63 |
+
"self.", ""
|
| 64 |
+
)
|
| 65 |
+
.replace("\n", "")
|
| 66 |
+
.replace(" ", "")
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
def _get_views(self, idx, resolution, rng):
|
| 70 |
+
raise NotImplementedError()
|
| 71 |
+
|
| 72 |
+
def __getitem__(self, idx):
|
| 73 |
+
if isinstance(idx, tuple):
|
| 74 |
+
# the idx is specifying the aspect-ratio
|
| 75 |
+
idx, ar_idx = idx
|
| 76 |
+
else:
|
| 77 |
+
assert len(self._resolutions) == 1
|
| 78 |
+
ar_idx = 0
|
| 79 |
+
|
| 80 |
+
# set-up the rng
|
| 81 |
+
if self.seed: # reseed for each __getitem__
|
| 82 |
+
self._rng = np.random.default_rng(seed=self.seed + idx)
|
| 83 |
+
elif not hasattr(self, "_rng"):
|
| 84 |
+
seed = torch.initial_seed() # this is different for each dataloader process
|
| 85 |
+
self._rng = np.random.default_rng(seed=seed)
|
| 86 |
+
|
| 87 |
+
# over-loaded code
|
| 88 |
+
resolution = self._resolutions[
|
| 89 |
+
ar_idx
|
| 90 |
+
] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
|
| 91 |
+
views = self._get_views(idx, resolution, self._rng)
|
| 92 |
+
|
| 93 |
+
# check data-types
|
| 94 |
+
for v, view in enumerate(views):
|
| 95 |
+
assert (
|
| 96 |
+
"pts3d" not in view
|
| 97 |
+
), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
|
| 98 |
+
view["idx"] = v
|
| 99 |
+
|
| 100 |
+
# encode the image
|
| 101 |
+
width, height = view["img"].size
|
| 102 |
+
view["true_shape"] = np.int32((height, width))
|
| 103 |
+
view["img"] = self.transform(view["img"])
|
| 104 |
+
|
| 105 |
+
assert "camera_intrinsics" in view
|
| 106 |
+
if "camera_pose" not in view:
|
| 107 |
+
view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
|
| 108 |
+
else:
|
| 109 |
+
assert np.isfinite(
|
| 110 |
+
view["camera_pose"]
|
| 111 |
+
).all(), f"NaN in camera pose for view {view_name(view)}"
|
| 112 |
+
assert "pts3d" not in view
|
| 113 |
+
assert "valid_mask" not in view
|
| 114 |
+
assert np.isfinite(
|
| 115 |
+
view["depthmap"]
|
| 116 |
+
).all(), f"NaN in depthmap for view {view_name(view)}"
|
| 117 |
+
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
|
| 118 |
+
|
| 119 |
+
view["pts3d"] = pts3d
|
| 120 |
+
view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
|
| 121 |
+
|
| 122 |
+
# check all datatypes
|
| 123 |
+
for key, val in view.items():
|
| 124 |
+
res, err_msg = is_good_type(key, val)
|
| 125 |
+
assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
|
| 126 |
+
K = view["camera_intrinsics"]
|
| 127 |
+
view["img_mask"] = True
|
| 128 |
+
view["ray_mask"] = False
|
| 129 |
+
view["ray_map"] = torch.full(
|
| 130 |
+
(6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
|
| 131 |
+
)
|
| 132 |
+
view["update"] = True
|
| 133 |
+
view["reset"] = False
|
| 134 |
+
|
| 135 |
+
# last thing done!
|
| 136 |
+
for view in views:
|
| 137 |
+
# transpose to make sure all views are the same size
|
| 138 |
+
transpose_to_landscape(view)
|
| 139 |
+
# this allows to check whether the RNG is is the same state each time
|
| 140 |
+
view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
|
| 141 |
+
return views
|
| 142 |
+
|
| 143 |
+
def _set_resolutions(self, resolutions):
|
| 144 |
+
"""Set the resolution(s) of the dataset.
|
| 145 |
+
Params:
|
| 146 |
+
- resolutions: int or tuple or list of tuples
|
| 147 |
+
"""
|
| 148 |
+
assert resolutions is not None, "undefined resolution"
|
| 149 |
+
|
| 150 |
+
if not isinstance(resolutions, list):
|
| 151 |
+
resolutions = [resolutions]
|
| 152 |
+
|
| 153 |
+
self._resolutions = []
|
| 154 |
+
for resolution in resolutions:
|
| 155 |
+
if isinstance(resolution, int):
|
| 156 |
+
width = height = resolution
|
| 157 |
+
else:
|
| 158 |
+
width, height = resolution
|
| 159 |
+
assert isinstance(
|
| 160 |
+
width, int
|
| 161 |
+
), f"Bad type for {width=} {type(width)=}, should be int"
|
| 162 |
+
assert isinstance(
|
| 163 |
+
height, int
|
| 164 |
+
), f"Bad type for {height=} {type(height)=}, should be int"
|
| 165 |
+
assert width >= height
|
| 166 |
+
self._resolutions.append((width, height))
|
| 167 |
+
|
| 168 |
+
def _crop_resize_if_necessary(
|
| 169 |
+
self, image, depthmap, intrinsics, resolution, rng=None, info=None
|
| 170 |
+
):
|
| 171 |
+
"""This function:
|
| 172 |
+
- first downsizes the image with LANCZOS inteprolation,
|
| 173 |
+
which is better than bilinear interpolation in
|
| 174 |
+
"""
|
| 175 |
+
if not isinstance(image, PIL.Image.Image):
|
| 176 |
+
image = PIL.Image.fromarray(image)
|
| 177 |
+
|
| 178 |
+
# downscale with lanczos interpolation so that image.size == resolution
|
| 179 |
+
# cropping centered on the principal point
|
| 180 |
+
W, H = image.size
|
| 181 |
+
cx, cy = intrinsics[:2, 2].round().astype(int)
|
| 182 |
+
|
| 183 |
+
# calculate min distance to margin
|
| 184 |
+
min_margin_x = min(cx, W - cx)
|
| 185 |
+
min_margin_y = min(cy, H - cy)
|
| 186 |
+
assert min_margin_x > W / 5, f"Bad principal point in view={info}"
|
| 187 |
+
assert min_margin_y > H / 5, f"Bad principal point in view={info}"
|
| 188 |
+
|
| 189 |
+
## Center crop
|
| 190 |
+
# Crop on the principal point, make it always centered
|
| 191 |
+
# the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
|
| 192 |
+
l, t = cx - min_margin_x, cy - min_margin_y
|
| 193 |
+
r, b = cx + min_margin_x, cy + min_margin_y
|
| 194 |
+
crop_bbox = (l, t, r, b)
|
| 195 |
+
|
| 196 |
+
image, depthmap, intrinsics = cropping.crop_image_depthmap(
|
| 197 |
+
image, depthmap, intrinsics, crop_bbox
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# # transpose the resolution if necessary
|
| 201 |
+
W, H = image.size # new size
|
| 202 |
+
assert resolution[0] >= resolution[1]
|
| 203 |
+
if H > 1.1 * W:
|
| 204 |
+
# image is portrait mode
|
| 205 |
+
resolution = resolution[::-1]
|
| 206 |
+
elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
|
| 207 |
+
# image is square, so we chose (portrait, landscape) randomly
|
| 208 |
+
if rng.integers(2):
|
| 209 |
+
resolution = resolution[::-1]
|
| 210 |
+
|
| 211 |
+
# high-quality Lanczos down-scaling
|
| 212 |
+
target_resolution = np.array(resolution)
|
| 213 |
+
# # if self.aug_crop > 1:
|
| 214 |
+
# # target_resolution += rng.integers(0, self.aug_crop)
|
| 215 |
+
# if resolution != (224, 224):
|
| 216 |
+
# halfw, halfh = ((2*(W//2))//16)*8, ((2*(H//2))//16)*8
|
| 217 |
+
# ## Recale with max factor, so one of width or height might be larger than target_resolution
|
| 218 |
+
# image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, (2*halfw, 2*halfh))
|
| 219 |
+
# else:
|
| 220 |
+
image, depthmap, intrinsics = cropping.rescale_image_depthmap(
|
| 221 |
+
image, depthmap, intrinsics, target_resolution
|
| 222 |
+
)
|
| 223 |
+
# actual cropping (if necessary) with bilinear interpolation
|
| 224 |
+
# if resolution == (224, 224):
|
| 225 |
+
intrinsics2 = cropping.camera_matrix_of_crop(
|
| 226 |
+
intrinsics, image.size, resolution, offset_factor=0.5
|
| 227 |
+
)
|
| 228 |
+
crop_bbox = cropping.bbox_from_intrinsics_in_out(
|
| 229 |
+
intrinsics, intrinsics2, resolution
|
| 230 |
+
)
|
| 231 |
+
image, depthmap, intrinsics = cropping.crop_image_depthmap(
|
| 232 |
+
image, depthmap, intrinsics, crop_bbox
|
| 233 |
+
)
|
| 234 |
+
return image, depthmap, intrinsics
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def is_good_type(key, v):
|
| 238 |
+
"""returns (is_good, err_msg)"""
|
| 239 |
+
if isinstance(v, (str, int, tuple)):
|
| 240 |
+
return True, None
|
| 241 |
+
if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
|
| 242 |
+
return False, f"bad {v.dtype=}"
|
| 243 |
+
return True, None
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def view_name(view, batch_index=None):
|
| 247 |
+
def sel(x):
|
| 248 |
+
return x[batch_index] if batch_index not in (None, slice(None)) else x
|
| 249 |
+
|
| 250 |
+
db = sel(view["dataset"])
|
| 251 |
+
label = sel(view["label"])
|
| 252 |
+
instance = sel(view["instance"])
|
| 253 |
+
return f"{db}/{label}/{instance}"
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def transpose_to_landscape(view):
|
| 257 |
+
height, width = view["true_shape"]
|
| 258 |
+
|
| 259 |
+
if width < height:
|
| 260 |
+
# rectify portrait to landscape
|
| 261 |
+
assert view["img"].shape == (3, height, width)
|
| 262 |
+
view["img"] = view["img"].swapaxes(1, 2)
|
| 263 |
+
|
| 264 |
+
assert view["valid_mask"].shape == (height, width)
|
| 265 |
+
view["valid_mask"] = view["valid_mask"].swapaxes(0, 1)
|
| 266 |
+
|
| 267 |
+
assert view["depthmap"].shape == (height, width)
|
| 268 |
+
view["depthmap"] = view["depthmap"].swapaxes(0, 1)
|
| 269 |
+
|
| 270 |
+
assert view["pts3d"].shape == (height, width, 3)
|
| 271 |
+
view["pts3d"] = view["pts3d"].swapaxes(0, 1)
|
| 272 |
+
|
| 273 |
+
# transpose x and y pixels
|
| 274 |
+
view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]
|