diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..d7b99820e699849afed50b932ee9930fd4890ddf 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..69ad179855e254f05cd15d45418c5f736fbd7d34 --- /dev/null +++ b/LICENSE @@ -0,0 +1,13 @@ +S-Lab License 1.0 +Copyright 2025 S-Lab + +Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. \ No newline at end of file diff --git a/assets/pipeline.png b/assets/pipeline.png new file mode 100644 index 0000000000000000000000000000000000000000..430c91c99f49b571661c3d598bc2f72df4789089 --- /dev/null +++ b/assets/pipeline.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:099a2c82e37b04878112826abcc85b02cf86e0bd059824dfe98e4a99782d6aac +size 654500 diff --git a/assets/teaser_dynamic.gif b/assets/teaser_dynamic.gif new file mode 100644 index 0000000000000000000000000000000000000000..1fa0427c48feb862fddd327bc37f43790bba76ff --- /dev/null +++ b/assets/teaser_dynamic.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb25ab7cf2e3dcff862a3e8e82657dbba7fc0cbc36856f315d2b6e25f9bb9d72 +size 2330571 diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc12b57bb1525eeec58b6c9cba8c684ed7e314ac --- /dev/null +++ b/configs/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# this file is needed here to include configs when building project as a package diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..521d6e411e1c83d2fc696ac6007f163ec6bfa8a3 --- /dev/null +++ b/configs/callbacks/default.yaml @@ -0,0 +1,22 @@ +defaults: + - model_checkpoint + - early_stopping + - model_summary + - rich_progress_bar + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:03d}" + monitor: "val/loss" + mode: "min" + save_last: True + auto_insert_metric_name: False + +early_stopping: + monitor: "val/loss" + patience: 100 + mode: "min" + +model_summary: + max_depth: -1 diff --git a/configs/callbacks/early_stopping.yaml b/configs/callbacks/early_stopping.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c826c8d58651a5e2c7cca0e99948a9b6ccabccf3 --- /dev/null +++ b/configs/callbacks/early_stopping.yaml @@ -0,0 +1,15 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html + +early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: ??? # quantity to be monitored, must be specified !!! + min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement + patience: 3 # number of checks with no improvement after which training will be stopped + verbose: False # verbosity mode + mode: "min" # "max" means higher metric value is better, can be also "min" + strict: True # whether to crash the training if monitor is not found in the validation metrics + check_finite: True # when set True, stops training when the monitor becomes NaN or infinite + stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold + divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold + check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch + # log_rank_zero_only: False # this keyword argument isn't available in stable version diff --git a/configs/callbacks/model_checkpoint.yaml b/configs/callbacks/model_checkpoint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9c035acb5350a73a77b5038095eccc43b0aec9e --- /dev/null +++ b/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,17 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: null # directory to save the model file + filename: null # checkpoint filename + monitor: 'val/loss' # name of the logged metric which determines when model is improving + verbose: False # verbosity mode + save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 1 # save k best models (determined by above metric) + mode: "min" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: False # when True, the checkpoints filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: 20 # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/configs/callbacks/model_summary.yaml b/configs/callbacks/model_summary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b75981d8cd5d73f61088d80495dc540274bca3d1 --- /dev/null +++ b/configs/callbacks/model_summary.yaml @@ -0,0 +1,5 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html + +model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 1 # the maximum depth of layer nesting that the summary will include diff --git a/configs/callbacks/none.yaml b/configs/callbacks/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/callbacks/rich_progress_bar.yaml b/configs/callbacks/rich_progress_bar.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f761eb783958a9c32b64a7c753027358f29751b2 --- /dev/null +++ b/configs/callbacks/rich_progress_bar.yaml @@ -0,0 +1,18 @@ +# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html + +rich_progress_bar: + _target_: lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar + refresh_rate: 1 + leave: false + theme: + _target_: lightning.pytorch.callbacks.progress.rich_progress.RichProgressBarTheme + description: green_yellow + progress_bar: green1 + progress_bar_finished: green1 + progress_bar_pulse: "#6206E0" + batch_progress: green_yellow + time: blue + processing_speed: cyan + metrics: grey82 + metrics_text_delimiter: " " + metrics_format: .4g diff --git a/configs/data/multiview_dust3r.yaml b/configs/data/multiview_dust3r.yaml new file mode 100644 index 0000000000000000000000000000000000000000..18d93923e8dfd575b6d68116d0e881af6800d718 --- /dev/null +++ b/configs/data/multiview_dust3r.yaml @@ -0,0 +1,25 @@ +# Define the common data root and number of views +data_root: /path/to/dust3r_data +num_views: 4 +num_views_val: 10 + +data_module: + _target_: stream3r.data.multiview_dust3r_datamodule.MultiViewDUSt3RDataModule + train_datasets: + - 80_000 @ Co3d_Multiview(split='train', num_views=${data.num_views}, window_degree_range=360, num_samples_per_window=100, ROOT='${data.data_root}/co3d_50_seqs_per_category_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + - 80_000 @ MegaDepth_Multiview(split='train', num_views=${data.num_views}, window_size=${python_eval:"${data.num_views} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/megadepth_processed', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + - 80_000 @ ScanNetpp_Multiview(split='train', num_views=${data.num_views}, window_size=${python_eval:"${data.num_views} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/scannetpp_processed', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + - 80_000 @ ARKitScenes_Multiview(split='train', num_views=${data.num_views}, num_samples_per_window=10, ROOT='${data.data_root}/arkitscenes_processed', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + - 80_000 @ Habitat_Multiview(1_000_000, split='train', num_views=${data.num_views}, ROOT='${data.data_root}/habitat_processed', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + validation_datasets: + - 100 @ Co3d_Multiview(split='test', num_views=${data.num_views_val}, window_degree_range=360, num_samples_per_window=100, ROOT='${data.data_root}/co3d_50_seqs_per_category_subset_processed', resolution=(512, 384), seed=777) + - 100 @ MegaDepth_Multiview(split='val', num_views=${data.num_views_val}, window_size=${python_eval:"${data.num_views_val} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/megadepth_processed', resolution=(512, 336), seed=777) + - 100 @ ScanNetpp_Multiview(split='train', num_views=${data.num_views_val}, window_size=${python_eval:"${data.num_views_val} * 2"}, num_samples_per_window=100, ROOT='${data.data_root}/scannetpp_processed', resolution=(512, 384), seed=777) + - 100 @ ARKitScenes_Multiview(split='train', num_views=${data.num_views_val}, num_samples_per_window=10, ROOT='${data.data_root}/arkitscenes_processed', resolution=(512, 384), seed=777) + - 100 @ Habitat_Multiview(100_000, split='val', num_views=${data.num_views_val}, ROOT='${data.data_root}/habitat_processed', resolution=(512,384), seed=777) + batch_size_per_device: 6 + batch_size_per_device_val: 4 + num_workers: 6 + pin_memory: True + + diff --git a/configs/debug/ddp_debug.yaml b/configs/debug/ddp_debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c475993fead54368a715068560b754ca361ecc13 --- /dev/null +++ b/configs/debug/ddp_debug.yaml @@ -0,0 +1,48 @@ +# @package _global_ + +#use a smaller dataset for faster initializations +defaults: + - override /data: multiview_dust3r_tiny + - override /logger: + - csv + - wandb + +# overwrite task name so debugging logs are stored in separate folder +task_name: "debug" + +logger: + wandb: + name: ${paths.run_folder_name} + +# ckpt_path: /some/random/path + +extras: + ignore_warnings: False + enforce_tags: False + +# sets level of all command line loggers to 'DEBUG' +# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ +hydra: + job_logging: + root: + level: DEBUG + +model: + net: + random_image_idx_embedding: true + +data: + num_views: 4 + data_module: + num_workers: 0 # debuggers don't like multiprocessing + pin_memory: false # disable gpu memory pin + batch_size_per_device: 6 + +trainer: + log_every_n_steps: 1 + devices: auto + # fast_dev_run: 1 + limit_train_batches: 1 + limit_val_batches: 10000 + precision: 32 + diff --git a/configs/debug/default.yaml b/configs/debug/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1886902b39f1be560e314bce7b3778f95b44754c --- /dev/null +++ b/configs/debug/default.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +# default debugging setup, runs 1 full epoch +# other debugging configs can inherit from this one + +# overwrite task name so debugging logs are stored in separate folder +task_name: "debug" + +# disable callbacks and loggers during debugging +callbacks: null +logger: null + +extras: + ignore_warnings: False + enforce_tags: False + +# sets level of all command line loggers to 'DEBUG' +# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ +hydra: + job_logging: + root: + level: DEBUG + + # use this to also set hydra loggers to 'DEBUG' + # verbose: True + +trainer: + max_epochs: 1 + accelerator: cpu # debuggers don't like gpus + devices: 1 # debuggers don't like multiprocessing + detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor + +data: + num_workers: 0 # debuggers don't like multiprocessing + pin_memory: False # disable gpu memory pin diff --git a/configs/debug/fdr.yaml b/configs/debug/fdr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f2d34fa37c31017e749d5a4fc5ae6763e688b46 --- /dev/null +++ b/configs/debug/fdr.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +# runs 1 train, 1 validation and 1 test step + +defaults: + - default + +trainer: + fast_dev_run: true diff --git a/configs/debug/limit.yaml b/configs/debug/limit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..514d77fbd1475b03fff0372e3da3c2fa7ea7d190 --- /dev/null +++ b/configs/debug/limit.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# uses only 1% of the training data and 5% of validation/test data + +defaults: + - default + +trainer: + max_epochs: 3 + limit_train_batches: 0.01 + limit_val_batches: 0.05 + limit_test_batches: 0.05 diff --git a/configs/debug/overfit.yaml b/configs/debug/overfit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9906586a67a12aa81ff69138f589a366dbe2222f --- /dev/null +++ b/configs/debug/overfit.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +# overfits to 3 batches + +defaults: + - default + +trainer: + max_epochs: 20 + overfit_batches: 3 + +# model ckpt and early stopping need to be disabled during overfitting +callbacks: null diff --git a/configs/debug/profiler.yaml b/configs/debug/profiler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2bd7da87ae23ed425ace99b09250a76a5634a3fb --- /dev/null +++ b/configs/debug/profiler.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# runs with execution time profiling + +defaults: + - default + +trainer: + max_epochs: 1 + profiler: "simple" + # profiler: "advanced" + # profiler: "pytorch" diff --git a/configs/eval.yaml b/configs/eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51743dd563b674b6f65e1f469ba21d231509935a --- /dev/null +++ b/configs/eval.yaml @@ -0,0 +1,19 @@ +# @package _global_ + +defaults: + - _self_ + - data: multiview_dust3r + - model: stream3r + - logger: many_loggers + - trainer: ddp_eval + - paths: default + - extras: default + - hydra: default + - eval: default + +task_name: "eval" + +tags: ["eval"] + +# passing checkpoint path is necessary for evaluation +ckpt_path: ??? diff --git a/configs/experiment/stream3r/stream3r.yaml b/configs/experiment/stream3r/stream3r.yaml new file mode 100644 index 0000000000000000000000000000000000000000..097facd7572148c9ee989157368e41aa8a75f7ba --- /dev/null +++ b/configs/experiment/stream3r/stream3r.yaml @@ -0,0 +1,125 @@ +# @package _global_ + +defaults: + - override /model: stream3r + +# seed for random number generators in pytorch, numpy and python.random +seed: 42 + + +tags: ["train", "stream3r"] + +task_name: stream3r +slurm_job_id: 99999 # must set in the command line + +# ckpt_path: /path/to/resume.ckpt # uncomment to resume training from a checkpoint + +paths: + run_folder_name: ${task_name}_${slurm_job_id} + +logger: + wandb: + name: ${task_name}_${slurm_job_id} + project: stream3r + +data: + data_scaling: 1.0 + data_root: /data + num_views: 24 + resolution: + - [518, 392] + - [518, 378] + - [518, 336] + - [518, 294] + - [518, 252] + - [518, 210] + - [518, 140] + - [378, 518] + - [336, 518] + - [294, 518] + - [252, 518] + - [224, 224] + allow_repeat: true + n_corres_train: 0 + data_module: + _target_: stream3r.data.multiview_dust3r_datamodule.MultiViewDUSt3RDataModule + pin_memory: true + num_workers: 16 + num_workers_val: 1 # have to be a low number when using DeepSpeed ZeRO-2 + batch_size_per_device: 1 + batch_size_per_device_val: 1 + train_datasets: + - 44800 @ Co3d_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT='${data.data_root}/processed_co3d', aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 56000 @ WildRGBD_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_wildrgbd_mp", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 22400 @ ARKitScenesHighRes_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_arkitscene_highres", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 38400 @ ScanNet_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_scannet/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 16800 @ ScanNetpp_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_scannetpp/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 84000 @ MapFree_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_mapfree/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 20000 @ Waymo_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_waymo/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 56000 @ TartanAir_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_tartanair/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 9400 @ Spring(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_spring/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 36000 @ BEDLAM_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_bedlam/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 28800 @ MP3D_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_mp3d/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 14400 @ UASOL_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_uasol/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 1400 @ MVS_Synth_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_mvs_synth", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 7200 @ PointOdyssey_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_pointodyssey", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 11200 @ HyperSim_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_hypersim_new", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 22400 @ BlendedMVS_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_blendedmvs/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 22400 @ MegaDepth_Multi(allow_repeat=${data.allow_repeat}, split="train", ROOT="${data.data_root}/processed_megadepth", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 5600 @ VirtualKITTI2_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_vkitti", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 168 @ UnReal4K_Multi(allow_repeat=${data.allow_repeat}, split=None, ROOT="${data.data_root}/processed_unreal4k/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 74000 @ DL3DV_Multi(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_dl3dv_ours_parts/processed_dl3dv_ours", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + + - 36000 @ DynamicReplica(allow_repeat=${data.allow_repeat}, split='train', ROOT="${data.data_root}/processed_dynamic_replica/", aug_crop=16, resolution=${data.resolution}, transform=ColorJitter, num_views=${data.num_views}, n_corres=${data.n_corres_train}) + +model: + pretrained: weights/vggt/model.pt + net: + freeze: encoder + + scheduler: + warmup_start_lr: 1e-6 + warmup_epochs: 1 + + train_criterion: + _target_: stream3r.loss.losses.CausalLoss + gradient_loss: grad + is_metric: false + + validation_criterion: + _target_: stream3r.loss.losses.CausalLoss + gradient_loss: grad + is_metric: false + + optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 1e-5 + betas: + - 0.9 + - 0.95 + weight_decay: 0.05 + +trainer: + devices: auto + max_epochs: 500 + accumulate_grad_batches: 4 + strategy: + _target_: lightning.pytorch.strategies.DeepSpeedStrategy + timeout: + _target_: datetime.timedelta + minutes: 80 + plugins: null + limit_val_batches: 0 + precision: bf16-mixed + log_every_n_steps: 20 + +callbacks: + model_checkpoint: + every_n_train_steps: 2000 + every_n_epochs: null + save_top_k: -1 + filename: "{epoch:03d}-{step:08d}" + save_last: false + monitor: "train/loss" + early_stopping: + monitor: "train/loss" \ No newline at end of file diff --git a/configs/extras/default.yaml b/configs/extras/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9c6b622283a647fbc513166fc14f016cc3ed8a0 --- /dev/null +++ b/configs/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: False + +# ask user for tags if none are provided in the config +enforce_tags: True + +# pretty print config tree at the start of the run using Rich library +print_config: True diff --git a/configs/hparams_search/mnist_optuna.yaml b/configs/hparams_search/mnist_optuna.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1391183ebcdec3d8f5eb61374e0719d13c7545da --- /dev/null +++ b/configs/hparams_search/mnist_optuna.yaml @@ -0,0 +1,52 @@ +# @package _global_ + +# example hyperparameter optimization of some experiment with Optuna: +# python train.py -m hparams_search=mnist_optuna experiment=example + +defaults: + - override /hydra/sweeper: optuna + +# choose metric which will be optimized by Optuna +# make sure this is the correct name of some metric logged in lightning module! +optimized_metric: "val/acc_best" + +# here we define Optuna hyperparameter search +# it optimizes for value returned from function with @hydra.main decorator +# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper +hydra: + mode: "MULTIRUN" # set hydra to multirun by default if this config is attached + + sweeper: + _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper + + # storage URL to persist optimization results + # for example, you can use SQLite if you set 'sqlite:///example.db' + storage: null + + # name of the study to persist optimization results + study_name: null + + # number of parallel workers + n_jobs: 1 + + # 'minimize' or 'maximize' the objective + direction: maximize + + # total number of runs that will be executed + n_trials: 20 + + # choose Optuna hyperparameter sampler + # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others + # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html + sampler: + _target_: optuna.samplers.TPESampler + seed: 1234 + n_startup_trials: 10 # number of random sampling runs before optimization starts + + # define hyperparameter search space + params: + model.optimizer.lr: interval(0.0001, 0.1) + data.batch_size: choice(32, 64, 128, 256) + model.net.lin1_size: choice(64, 128, 256) + model.net.lin2_size: choice(64, 128, 256) + model.net.lin3_size: choice(32, 64, 128, 256) diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..955d3bab7e1cc2648e693bcf0820a19122d158b4 --- /dev/null +++ b/configs/hydra/default.yaml @@ -0,0 +1,19 @@ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${task_name}/runs/${paths.run_folder_name} +sweep: + dir: ${paths.log_dir}/${task_name}/multiruns/${paths.run_folder_name} + subdir: ${hydra.job.num} + +job_logging: + handlers: + file: + # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 + filename: ${hydra.runtime.output_dir}/${task_name}.log diff --git a/configs/hydra/launcher/fair_a100.yaml b/configs/hydra/launcher/fair_a100.yaml new file mode 100644 index 0000000000000000000000000000000000000000..28a19b84dd7c78cc43587dbe873ae8f65e9a6d6c --- /dev/null +++ b/configs/hydra/launcher/fair_a100.yaml @@ -0,0 +1,43 @@ +defaults: + - submitit_slurm + +# see: https://github.com/facebookresearch/hydra/blob/main/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/config.py +_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher +submitit_folder: ${hydra.sweep.dir}/.submitit/%j +name: ${hydra.job.name} +timeout_min: 20160 # 14 days : 60 * 24 * 14 +account: cortex +qos: cortex_high +comment: "multiview_dust3r experiment" +nodes: 1 +gres: "gpu:8" +tasks_per_node: 8 +cpus_per_task: 12 +signal_delay_s: 120 # USR1 signal delay (seconds) before timeout +max_num_timeout: 0 # number of times the job can be restarted after timeout +array_parallelism: 256 # Maximum number of jobs running in parallel + +# Useful to add parameters which are not currently available in the plugin. +# Eg: {"mail-user": "blublu@fb.com", "mail-type": "BEGIN"} +additional_parameters: + mail-user: "jianingy@meta.com" + mail-type: "BEGIN,END" + output: "/path/to/slurm_out/%x-%j.out" + +setup: # A list of commands to run in sbatch befure running srun + - echo "Begin setting up env on head node ($HOSTNAME)..." + - echo $(env | grep SLURM) + - export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) + - export MASTER_PORT=9929 + - export RDZV_ID=$SLURM_JOBID + - export OMP_NUM_THREADS=12 + - . /path/to/miniforge3/etc/profile.d/conda.sh # activate conda + - conda activate dust3r + - cd /path/to/project # cd to the project directory + - export NCCL_DEBUG=INFO + - export PYTHONFAULTHANDLER=1 + - export TORCH_DISTRIBUTED_DEBUG=INFO + - echo "env setup on head node ($HOSTNAME) finished, starting srun..." + +srun_args: + - "--cpu-bind=none" # This is critical to ensure dataloaders uses all CPUs! diff --git a/configs/local/.gitkeep b/configs/local/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/logger/aim.yaml b/configs/logger/aim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f9f6adad7feb2780c2efd5ddb0ed053621e05f8 --- /dev/null +++ b/configs/logger/aim.yaml @@ -0,0 +1,28 @@ +# https://aimstack.io/ + +# example usage in lightning module: +# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py + +# open the Aim UI with the following command (run in the folder containing the `.aim` folder): +# `aim up` + +aim: + _target_: aim.pytorch_lightning.AimLogger + repo: ${paths.root_dir} # .aim folder will be created here + # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# + + # aim allows to group runs under experiment name + experiment: null # any string, set to "default" if not specified + + train_metric_prefix: "train/" + val_metric_prefix: "val/" + test_metric_prefix: "test/" + + # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) + system_tracking_interval: 10 # set to null to disable system metrics tracking + + # enable/disable logging of system params such as installed packages, git info, env vars, etc. + log_system_params: true + + # enable/disable tracking console logs (default value is true) + capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0789274e2137ee6c97ca37a5d56c2b8abaf0aaa --- /dev/null +++ b/configs/logger/comet.yaml @@ -0,0 +1,12 @@ +# https://www.comet.ml + +comet: + _target_: lightning.pytorch.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + save_dir: "${paths.output_dir}" + project_name: "lightning-hydra-template" + rest_api_key: null + # experiment_name: "" + experiment_key: null # set to resume experiment + offline: False + prefix: "" diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa028e9c146430c319101ffdfce466514338591c --- /dev/null +++ b/configs/logger/csv.yaml @@ -0,0 +1,7 @@ +# csv logger built in lightning + +csv: + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger + save_dir: "${paths.output_dir}" + name: "csv/" + prefix: "" diff --git a/configs/logger/many_loggers.yaml b/configs/logger/many_loggers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd586800bdccb4e8f4b0236a181b7ddd756ba9ab --- /dev/null +++ b/configs/logger/many_loggers.yaml @@ -0,0 +1,9 @@ +# train with many loggers at once + +defaults: + # - comet + - csv + # - mlflow + # - neptune + - tensorboard + - wandb diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8fb7e685fa27fc8141387a421b90a0b9b492d9e --- /dev/null +++ b/configs/logger/mlflow.yaml @@ -0,0 +1,12 @@ +# https://mlflow.org + +mlflow: + _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger + # experiment_name: "" + # run_name: "" + tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI + tags: null + # save_dir: "./mlruns" + prefix: "" + artifact_location: null + # run_id: "" diff --git a/configs/logger/neptune.yaml b/configs/logger/neptune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8233c140018ecce6ab62971beed269991d31c89b --- /dev/null +++ b/configs/logger/neptune.yaml @@ -0,0 +1,9 @@ +# https://neptune.ai + +neptune: + _target_: lightning.pytorch.loggers.neptune.NeptuneLogger + api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable + project: username/lightning-hydra-template + # name: "" + log_model_checkpoints: True + prefix: "" diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2bd31f6d8ba68d1f5c36a804885d5b9f9c1a9302 --- /dev/null +++ b/configs/logger/tensorboard.yaml @@ -0,0 +1,10 @@ +# https://www.tensorflow.org/tensorboard/ + +tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.output_dir}/tensorboard/" + name: null + log_graph: False + default_hp_metric: True + prefix: "" + # version: "" diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a10dcb7eb9ccaddac5b8fa74eab3a342816f3103 --- /dev/null +++ b/configs/logger/wandb.yaml @@ -0,0 +1,16 @@ +# https://wandb.ai + +wandb: + _target_: lightning.pytorch.loggers.wandb.WandbLogger + name: null # name of the run (normally generated by wandb) + save_dir: "${paths.output_dir}" + offline: False + id: null # pass correct id to resume experiment! + anonymous: null # enable anonymous logging + project: "stream3r" + log_model: False # upload lightning ckpts + prefix: "" # a string to put at the beginning of metric keys + # entity: "" # set to name of your wandb team + group: "" + tags: [] + job_type: "" diff --git a/configs/model/stream3r.yaml b/configs/model/stream3r.yaml new file mode 100644 index 0000000000000000000000000000000000000000..33faf87b05ba8cb4479998b7ae86a58a69543216 --- /dev/null +++ b/configs/model/stream3r.yaml @@ -0,0 +1,42 @@ +_target_: stream3r.models.multiview_dust3r_module.MultiViewDUSt3RLitModule + +pretrained: null +resume_from_checkpoint: ${ckpt_path} + +eval_use_pts3d_from_local_head: true + +train_criterion: + _target_: stream3r.loss.losses.CausalLoss + +validation_criterion: + _target_: stream3r.loss.losses.CausalLoss + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 1e-4 + betas: + - 0.9 + - 0.95 + weight_decay: 0.05 + +# scheduler: +# _target_: torch.optim.lr_scheduler.ReduceLROnPlateau +# _partial_: true +# mode: min +# factor: 0.1 +# patience: 10 + +scheduler: + _target_: pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR + _partial_: true + warmup_epochs: 10 + max_epochs: ${trainer.max_epochs} + eta_min: 1e-06 + +net: + _target_: stream3r.models.stream3r.STream3R + + +# compile model for faster training with pytorch 2.0 +compile: false diff --git a/configs/paths/default.yaml b/configs/paths/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a5761adfdb395a54545010abcefde57c6e585041 --- /dev/null +++ b/configs/paths/default.yaml @@ -0,0 +1,21 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# you can replace it with "." if you want the root to be the current working directory +# root_dir: R${oc.env:PROJECT_ROOT} +root_dir: . + +# path to data directory +data_dir: ${paths.root_dir}/data/ + +# path to logging directory +log_dir: ${paths.root_dir}/logs/ + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:runtime.output_dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} + +run_folder_name: ${now:%Y-%m-%d}_${now:%H-%M-%S} diff --git a/configs/train.yaml b/configs/train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..caf7e447fd4415c63a6dd800daef4726ed08516c --- /dev/null +++ b/configs/train.yaml @@ -0,0 +1,49 @@ +# @package _global_ + +# specify here default configuration +# order of defaults determines the order in which configs override each other +defaults: + - _self_ + - data: multiview_dust3r + - model: stream3r + - callbacks: default + - logger: many_loggers # set logger here or use command line (e.g. `python train.py logger=tensorboard`) + - trainer: ddp + - paths: default + - extras: default + - hydra: default + + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + + # config for hyperparameter optimization + - hparams_search: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default + + # debugging config (enable through command line, e.g. `python train.py debug=default) + - debug: null + +# task name, determines output directory path +task_name: "train" + +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +tags: ["dev"] + +# set False to skip model training +train: True + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: True + +# simply provide checkpoint path to resume training +ckpt_path: null + +# seed for random number generators in pytorch, numpy and python.random +seed: 42 diff --git a/configs/trainer/cpu.yaml b/configs/trainer/cpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7d6767e60c956567555980654f15e7bb673a41f --- /dev/null +++ b/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: cpu +devices: 1 diff --git a/configs/trainer/ddp.yaml b/configs/trainer/ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..49d123313d6851e7cfef7b4035bdcbc196c7c6ad --- /dev/null +++ b/configs/trainer/ddp.yaml @@ -0,0 +1,12 @@ +defaults: + - default + +# strategy: ddp +strategy: ddp_find_unused_parameters_true + +accelerator: gpu +devices: auto +num_nodes: 1 +sync_batchnorm: true + +use_distributed_sampler: false diff --git a/configs/trainer/ddp_eval.yaml b/configs/trainer/ddp_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..861c524bd3723ab7302f987a8f74065dfdaf8d77 --- /dev/null +++ b/configs/trainer/ddp_eval.yaml @@ -0,0 +1,16 @@ +defaults: + - default + +# strategy: ddp +strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + timeout: + _target_: datetime.timedelta + minutes: 30 + +accelerator: gpu +devices: auto +num_nodes: 1 +sync_batchnorm: true + +use_distributed_sampler: false diff --git a/configs/trainer/ddp_sim.yaml b/configs/trainer/ddp_sim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8404419e5c295654967d0dfb73a7366e75be2f1f --- /dev/null +++ b/configs/trainer/ddp_sim.yaml @@ -0,0 +1,7 @@ +defaults: + - default + +# simulate DDP on CPU, useful for debugging +accelerator: cpu +devices: 2 +strategy: ddp_spawn diff --git a/configs/trainer/deepspeed_stage_2.yaml b/configs/trainer/deepspeed_stage_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bd83c02a33cc1356b701090e0d0099577803d19c --- /dev/null +++ b/configs/trainer/deepspeed_stage_2.yaml @@ -0,0 +1,9 @@ +defaults: + - default + +# strategy: deepspeed_stage_2 +strategy: deepspeed_stage_2 + +accelerator: gpu +devices: auto +num_nodes: 1 diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9be013e36aac3c25c72fda8c86f18f16087729a2 --- /dev/null +++ b/configs/trainer/default.yaml @@ -0,0 +1,30 @@ +_target_: lightning.pytorch.trainer.Trainer +_convert_: partial + +default_root_dir: ${paths.output_dir} + +min_epochs: 1 # prevents early stopping +max_epochs: 100 + +accelerator: cpu +devices: 1 + +# mixed precision for extra speed-up +# precision: 16 + +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + +# set True to to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: False + +plugins: + - _target_: lightning.pytorch.plugins.environments.SLURMEnvironment + auto_requeue: true # auto-resubmit the job when it is preempted by slurm + requeue_signal: ${python_eval:"signal.SIGUSR1"} # singal code is platform dependent, so it has to be decided at runtime + # requeue_signal: + # _target_: signal.Signals + # _args_: + # - 10 # SIGUSR1, see: https://chromium.googlesource.com/chromiumos/docs/+/master/constants/signals.md + diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2389510a90f5f0161cff6ccfcb4a96097ddf9a1 --- /dev/null +++ b/configs/trainer/gpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: gpu +devices: 1 diff --git a/configs/trainer/mps.yaml b/configs/trainer/mps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ecf6d5cc3a34ca127c5510f4a18e989561e38e4 --- /dev/null +++ b/configs/trainer/mps.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: mps +devices: 1 diff --git a/eval/monodepth/eval_metrics.py b/eval/monodepth/eval_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..81a59325d4551ed86fc3c3f42e74d06fa06b328a --- /dev/null +++ b/eval/monodepth/eval_metrics.py @@ -0,0 +1,211 @@ +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +from eval.monodepth.tools import depth_evaluation +import numpy as np +import json +from tqdm import tqdm +import glob +import cv2 +from eval.monodepth.metadata import dataset_metadata +import argparse +from PIL import Image + +TAG_FLOAT = 202021.25 + + +def depth_read_sintel(filename): + """Read depth data from file, return as numpy array.""" + f = open(filename, "rb") + check = np.fromfile(f, dtype=np.float32, count=1)[0] + assert ( + check == TAG_FLOAT + ), " depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format( + TAG_FLOAT, check + ) + width = np.fromfile(f, dtype=np.int32, count=1)[0] + height = np.fromfile(f, dtype=np.int32, count=1)[0] + size = width * height + assert ( + width > 0 and height > 0 and size > 1 and size < 100000000 + ), " depth_read:: Wrong input size (width = {0}, height = {1}).".format( + width, height + ) + depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width)) + return depth + + +def depth_read_bonn(filename): + # loads depth map D from png file + # and returns it as a numpy array + depth_png = np.asarray(Image.open(filename)) + # make sure we have a proper 16bit depth map here.. not 8bit! + assert np.max(depth_png) > 255 + depth = depth_png.astype(np.float64) / 5000.0 + depth[depth_png == 0] = -1.0 + return depth + + +def depth_read_kitti(filename): + # loads depth map D from png file + # and returns it as a numpy array, + # for details see readme.txt + img_pil = Image.open(filename) + depth_png = np.array(img_pil, dtype=int) + # make sure we have a proper 16bit depth map here.. not 8bit! + assert np.max(depth_png) > 255 + + depth = depth_png.astype(float) / 256.0 + depth[depth_png == 0] = -1.0 + return depth + + +def get_gt_depth(filename, dataset): + if dataset == "sintel": + return depth_read_sintel(filename) + elif dataset == "bonn": + return depth_read_bonn(filename) + elif dataset == "kitti": + return depth_read_kitti(filename) + elif dataset == "nyu": + return np.load(filename) + else: + raise NotImplementedError + + +def get_args_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--output_dir", + type=str, + default="", + help="value for outdir", + ) + parser.add_argument( + "--eval_dataset", type=str, default="nyu", choices=list(dataset_metadata.keys()) + ) + return parser + + +def main(args): + if args.eval_dataset == "nyu": + depth_pathes = glob.glob("data/nyu-v2/val/nyu_depths/*.npy") + depth_pathes = sorted(depth_pathes) + pred_pathes = glob.glob( + f"{args.output_dir}/*.npy" + ) # TODO: update the path to your prediction + pred_pathes = sorted(pred_pathes) + elif args.eval_dataset == "sintel": + pred_pathes = glob.glob( + f"{args.output_dir}/*/*.npy" + ) # TODO: update the path to your prediction + pred_pathes = sorted(pred_pathes) + full = len(pred_pathes) > 643 + if full: + depth_pathes = glob.glob(f"data/sintel/training/depth/*/*.dpt") + depth_pathes = sorted(depth_pathes) + else: + seq_list = [ + "alley_2", + "ambush_4", + "ambush_5", + "ambush_6", + "cave_2", + "cave_4", + "market_2", + "market_5", + "market_6", + "shaman_3", + "sleeping_1", + "sleeping_2", + "temple_2", + "temple_3", + ] + depth_pathes_folder = [ + f"data/sintel/training/depth/{seq}" for seq in seq_list + ] + depth_pathes = [] + for depth_pathes_folder_i in depth_pathes_folder: + depth_pathes += glob.glob(depth_pathes_folder_i + "/*.dpt") + depth_pathes = sorted(depth_pathes) + elif args.eval_dataset == "bonn": + seq_list = ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"] + img_pathes_folder = [ + f"data/bonn/rgbd_bonn_dataset/rgbd_bonn_{seq}/rgb_110/*.png" + for seq in seq_list + ] + img_pathes = [] + for img_pathes_folder_i in img_pathes_folder: + img_pathes += glob.glob(img_pathes_folder_i) + img_pathes = sorted(img_pathes) + depth_pathes_folder = [ + f"data/bonn/rgbd_bonn_dataset/rgbd_bonn_{seq}/depth_110/*.png" + for seq in seq_list + ] + depth_pathes = [] + for depth_pathes_folder_i in depth_pathes_folder: + depth_pathes += glob.glob(depth_pathes_folder_i) + depth_pathes = sorted(depth_pathes) + pred_pathes = glob.glob( + f"{args.output_dir}/*/*.npy" + ) # TODO: update the path to your prediction + pred_pathes = sorted(pred_pathes) + elif args.eval_dataset == "kitti": + depth_pathes = glob.glob( + "data/kitti/depth_selection/val_selection_cropped/groundtruth_depth_gathered/*/*.png" + ) + depth_pathes = sorted(depth_pathes) + pred_pathes = glob.glob( + f"{args.output_dir}/*/*depth.npy" + ) # TODO: update the path to your prediction + pred_pathes = sorted(pred_pathes) + else: + raise NotImplementedError + + gathered_depth_metrics = [] + for idx in tqdm(range(len(depth_pathes))): + pred_depth = np.load(pred_pathes[idx]) + gt_depth = get_gt_depth(depth_pathes[idx], args.eval_dataset) + pred_depth = cv2.resize( + pred_depth, + (gt_depth.shape[1], gt_depth.shape[0]), + interpolation=cv2.INTER_CUBIC, + ) + if args.eval_dataset == "nyu": + depth_results, error_map, depth_predict, depth_gt = depth_evaluation( + pred_depth, gt_depth, max_depth=None, lr=1e-3 + ) + elif args.eval_dataset == "sintel": + depth_results, error_map, depth_predict, depth_gt = depth_evaluation( + pred_depth, gt_depth, max_depth=70, use_gpu=True, post_clip_max=70 + ) + elif args.eval_dataset == "bonn": + depth_results, error_map, depth_predict, depth_gt = depth_evaluation( + pred_depth, gt_depth, max_depth=70, use_gpu=True + ) + elif args.eval_dataset == "kitti": + depth_results, error_map, depth_predict, depth_gt = depth_evaluation( + pred_depth, gt_depth, max_depth=None, use_gpu=True + ) + gathered_depth_metrics.append(depth_results) + + depth_log_path = os.path.join(args.output_dir, "metric.json") + average_metrics = { + key: np.average( + [metrics[key] for metrics in gathered_depth_metrics], + weights=[metrics["valid_pixels"] for metrics in gathered_depth_metrics], + ) + for key in gathered_depth_metrics[0].keys() + if key != "valid_pixels" + } + print(f"{args.eval_dataset} - Average depth evaluation metrics:", average_metrics) + with open(depth_log_path, "w") as f: + f.write(json.dumps(average_metrics)) + + +if __name__ == "__main__": + args = get_args_parser() + args = args.parse_args() + main(args) diff --git a/eval/monodepth/launch.py b/eval/monodepth/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..8a072096424ed6871fdfbafa937b536b7d8c246c --- /dev/null +++ b/eval/monodepth/launch.py @@ -0,0 +1,146 @@ +import torch +import numpy as np +import matplotlib +import numpy as np +import cv2 +import argparse +from pathlib import Path +from tqdm import tqdm +import os +import sys + +from stream3r.models.stream3r import STream3R +from stream3r.dust3r.utils.device import collate_with_cat +from stream3r.dust3r.utils.image import load_images_for_eval as load_images +from stream3r.utils.utils import ImgDust3r2Stream3r + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +from eval.monodepth.metadata import dataset_metadata + + +torch.backends.cuda.matmul.allow_tf32 = True + +# avoid high cpu usage +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" +os.environ["OPENBLAS_NUM_THREADS"] = "1" +torch.set_num_threads(1) +# =========================================== + + +def colorize_depth(depth: np.ndarray, + mask: np.ndarray = None, + normalize: bool = True, + cmap: str = 'Spectral') -> np.ndarray: + if mask is None: + depth = np.where(depth > 0, depth, np.nan) + else: + depth = np.where((depth > 0) & mask, depth, np.nan) + disp = 1 / depth + if normalize: + min_disp, max_disp = np.nanquantile(disp, + 0.001), np.nanquantile(disp, 0.99) + disp = (disp - min_disp) / (max_disp - min_disp) + colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], 0) + colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) + return colored + + +def get_args_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument("--device", + type=str, + default="cuda", + help="pytorch device") + parser.add_argument("--output_dir", + type=str, + default="", + help="value for outdir") + parser.add_argument("--no_crop", + type=bool, + default=True, + help="whether to crop input data") + parser.add_argument("--full_seq", + type=bool, + default=False, + help="whether to use all seqs") + parser.add_argument("--seq_list", default=None) + + parser.add_argument("--eval_dataset", + type=str, + default="nyu", + choices=list(dataset_metadata.keys())) + return parser + + +def eval_mono_depth_estimation(args, model, device): + metadata = dataset_metadata.get(args.eval_dataset) + if metadata is None: + raise ValueError(f"Unknown dataset: {args.eval_dataset}") + + img_path = metadata.get("img_path") + if "img_path_func" in metadata: + img_path = metadata["img_path_func"](args) + + process_func = metadata.get("process_func") + if process_func is None: + raise ValueError( + f"No processing function defined for dataset: {args.eval_dataset}") + + for filelist, save_dir in process_func(args, img_path): + Path(save_dir).mkdir(parents=True, exist_ok=True) + eval_mono_depth(args, model, device, filelist, save_dir=save_dir) + + +def eval_mono_depth(args, model, device, filelist, save_dir=None): + for file in tqdm(filelist): + file = [file] + images = load_images( + file, + size=518, + verbose=True, + crop=False, + patch_size=14, + ) + + images = collate_with_cat([tuple(images)]) + images = torch.stack([view["img"] for view in images], dim=1) + images = ImgDust3r2Stream3r(images).to(device) + + with torch.no_grad(): + predictions = model(images) + + depth_map = predictions['depth'][0,0].squeeze(-1).cpu() + + if save_dir is not None: + # save the depth map to the save_dir as npy + np.save( + f"{save_dir}/{file[0].split('/')[-1].replace('.png','depth.npy')}", + depth_map.cpu().numpy(), + ) + depth_map = colorize_depth(depth_map) + cv2.imwrite( + f"{save_dir}/{file[0].split('/')[-1].replace('.png','depth.jpg')}", + depth_map, + ) + + +def main(): + args = get_args_parser() + args = args.parse_args() + + if args.eval_dataset == "sintel": + args.full_seq = True + else: + args.full_seq = False + + model = STream3R.from_pretrained("yslan/STream3R").to(args.device) + model.eval() + + eval_mono_depth_estimation(args, model, args.device) + + +if __name__ == "__main__": + main() diff --git a/eval/monodepth/metadata.py b/eval/monodepth/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..459511277cff7cc319780dc665bb2fdba7acbe9d --- /dev/null +++ b/eval/monodepth/metadata.py @@ -0,0 +1,187 @@ +import os +import glob +from tqdm import tqdm + +# Define the merged dataset metadata dictionary +dataset_metadata = { + "sun_rgbd": { + "img_path": "data/sun_rgbd/image/test", + "mask_path": None, + }, + "davis": { + "img_path": "data/davis/DAVIS/JPEGImages/480p", + "mask_path": "data/davis/DAVIS/masked_images/480p", + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq), + "gt_traj_func": lambda img_path, anno_path, seq: None, + "traj_format": None, + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq), + "skip_condition": None, + "process_func": None, # Not used in mono depth estimation + }, + "kitti": { + "img_path": "data/kitti/depth_selection/val_selection_cropped/image_gathered", # Default path + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq), + "gt_traj_func": lambda img_path, anno_path, seq: None, + "traj_format": None, + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": lambda args, img_path: process_kitti(args, img_path), + }, + "bonn": { + "img_path": "data/bonn/rgbd_bonn_dataset", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join( + img_path, f"rgbd_bonn_{seq}", "rgb_110" + ), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt" + ), + "traj_format": "tum", + "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"], + "full_seq": False, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": lambda args, img_path: process_bonn(args, img_path), + }, + "nyu": { + "img_path": "data/nyu-v2/val/nyu_images", + "mask_path": None, + "process_func": lambda args, img_path: process_nyu(args, img_path), + }, + "scannet": { + "img_path": "data/scannetv2", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "pose_90.txt" + ), + "traj_format": "replica", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)), + "process_func": lambda args, img_path: process_scannet(args, img_path), + }, + "tum": { + "img_path": "data/tum", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "groundtruth_90.txt" + ), + "traj_format": "tum", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": None, + }, + "sintel": { + "img_path": "data/sintel/training/final", + "anno_path": "data/sintel/training/camdata_left", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq), + "traj_format": None, + "seq_list": [ + "alley_2", + "ambush_4", + "ambush_5", + "ambush_6", + "cave_2", + "cave_4", + "market_2", + "market_5", + "market_6", + "shaman_3", + "sleeping_1", + "sleeping_2", + "temple_2", + "temple_3", + ], + "full_seq": False, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": lambda args, img_path: process_sintel(args, img_path), + }, +} + + +# Define processing functions for each dataset +def process_kitti(args, img_path): + for dir in tqdm(sorted(glob.glob(f"{img_path}/*"))): + filelist = sorted(glob.glob(f"{dir}/*.png")) + save_dir = f"{args.output_dir}/{os.path.basename(dir)}" + yield filelist, save_dir + + +def process_bonn(args, img_path): + if args.full_seq: + for dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))): + filelist = sorted(glob.glob(f"{dir}/rgb/*.png")) + save_dir = f"{args.output_dir}/{os.path.basename(os.path.dirname(dir))}" + yield filelist, save_dir + else: + seq_list = ( + ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"] + if args.seq_list is None + else args.seq_list + ) + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f"{img_path}/rgbd_bonn_{seq}/rgb_110/*.png")) + save_dir = f"{args.output_dir}/{seq}" + yield filelist, save_dir + + +def process_sunrgbd(args, img_path): + filelist = sorted(glob.glob(f"{img_path}/*.jpg")) + save_dir = f"{args.output_dir}" + yield filelist, save_dir + + +def process_nyu(args, img_path): + filelist = sorted(glob.glob(f"{img_path}/*.png")) + save_dir = f"{args.output_dir}" + yield filelist, save_dir + + +def process_scannet(args, img_path): + seq_list = sorted(glob.glob(f"{img_path}/*")) + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f"{seq}/color_90/*.jpg")) + save_dir = f"{args.output_dir}/{os.path.basename(seq)}" + yield filelist, save_dir + + +def process_sintel(args, img_path): + if args.full_seq: + for dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))): + filelist = sorted(glob.glob(f"{dir}/*.png")) + save_dir = f"{args.output_dir}/{os.path.basename(os.path.dirname(dir))}" + yield filelist, save_dir + else: + seq_list = [ + "alley_2", + "ambush_4", + "ambush_5", + "ambush_6", + "cave_2", + "cave_4", + "market_2", + "market_5", + "market_6", + "shaman_3", + "sleeping_1", + "sleeping_2", + "temple_2", + "temple_3", + ] + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f"{img_path}/{seq}/*.png")) + save_dir = f"{args.output_dir}/{seq}" + yield filelist, save_dir diff --git a/eval/monodepth/run.sh b/eval/monodepth/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..9a0412e83aeefaf7fe13aa1650803b1d23729273 --- /dev/null +++ b/eval/monodepth/run.sh @@ -0,0 +1,20 @@ +#!/bin/bash +set -e + +workdir='.' + +datasets=('sintel' 'bonn' 'kitti' 'nyu') +model_name='stream3r' + +for data in "${datasets[@]}"; do + output_dir="${workdir}/eval_results/monodepth/${model_name}/${data}" + echo "$output_dir" + + python eval/monodepth/launch.py \ + --output_dir="$output_dir" \ + --eval_dataset="$data" \ + + python eval/monodepth/eval_metrics.py \ + --output_dir "$output_dir" \ + --eval_dataset "$data" +done \ No newline at end of file diff --git a/eval/monodepth/tools.py b/eval/monodepth/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..d6786fa6f25def0110ce22dbb7d44a7a08c952c8 --- /dev/null +++ b/eval/monodepth/tools.py @@ -0,0 +1,399 @@ +import torch +import numpy as np +import cv2 +import glob +import argparse +from pathlib import Path +from tqdm import tqdm +from copy import deepcopy +from scipy.optimize import minimize +import os +from collections import defaultdict + + +def group_by_directory(pathes, idx=-1): + """ + Groups the file paths based on the second-to-last directory in their paths. + + Parameters: + - pathes (list): List of file paths. + + Returns: + - dict: A dictionary where keys are the second-to-last directory names and values are lists of file paths. + """ + grouped_pathes = defaultdict(list) + + for path in pathes: + # Extract the second-to-last directory + dir_name = os.path.dirname(path).split("/")[idx] + grouped_pathes[dir_name].append(path) + + return grouped_pathes + + +def depth2disparity(depth, return_mask=False): + if isinstance(depth, torch.Tensor): + disparity = torch.zeros_like(depth) + elif isinstance(depth, np.ndarray): + disparity = np.zeros_like(depth) + non_negtive_mask = depth > 0 + disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask] + if return_mask: + return disparity, non_negtive_mask + else: + return disparity + + +def absolute_error_loss(params, predicted_depth, ground_truth_depth): + s, t = params + + predicted_aligned = s * predicted_depth + t + + abs_error = np.abs(predicted_aligned - ground_truth_depth) + return np.sum(abs_error) + + +def absolute_value_scaling(predicted_depth, ground_truth_depth, s=1, t=0): + predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1) + ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1) + + initial_params = [s, t] # s = 1, t = 0 + + result = minimize( + absolute_error_loss, + initial_params, + args=(predicted_depth_np, ground_truth_depth_np), + ) + + s, t = result.x + return s, t + + +def absolute_value_scaling2( + predicted_depth, + ground_truth_depth, + s_init=1.0, + t_init=0.0, + lr=1e-4, + max_iters=1000, + tol=1e-6, +): + # Initialize s and t as torch tensors with requires_grad=True + s = torch.tensor( + [s_init], + requires_grad=True, + device=predicted_depth.device, + dtype=predicted_depth.dtype, + ) + t = torch.tensor( + [t_init], + requires_grad=True, + device=predicted_depth.device, + dtype=predicted_depth.dtype, + ) + + optimizer = torch.optim.Adam([s, t], lr=lr) + + prev_loss = None + + for i in range(max_iters): + optimizer.zero_grad() + + # Compute predicted aligned depth + predicted_aligned = s * predicted_depth + t + + # Compute absolute error + abs_error = torch.abs(predicted_aligned - ground_truth_depth) + + # Compute loss + loss = torch.sum(abs_error) + + # Backpropagate + loss.backward() + + # Update parameters + optimizer.step() + + # Check convergence + if prev_loss is not None and torch.abs(prev_loss - loss) < tol: + break + + prev_loss = loss.item() + + return s.detach().item(), t.detach().item() + + +def depth_evaluation( + predicted_depth_original, + ground_truth_depth_original, + max_depth=80, + custom_mask=None, + post_clip_min=None, + post_clip_max=None, + pre_clip_min=None, + pre_clip_max=None, + align_with_lstsq=False, + align_with_lad=False, + align_with_lad2=False, + metric_scale=False, + lr=1e-4, + max_iters=1000, + use_gpu=False, + align_with_scale=False, + disp_input=False, +): + """ + Evaluate the depth map using various metrics and return a depth error parity map, with an option for least squares alignment. + + Args: + predicted_depth (numpy.ndarray or torch.Tensor): The predicted depth map. + ground_truth_depth (numpy.ndarray or torch.Tensor): The ground truth depth map. + max_depth (float): The maximum depth value to consider. Default is 80 meters. + align_with_lstsq (bool): If True, perform least squares alignment of the predicted depth with ground truth. + + Returns: + dict: A dictionary containing the evaluation metrics. + torch.Tensor: The depth error parity map. + """ + if isinstance(predicted_depth_original, np.ndarray): + predicted_depth_original = torch.from_numpy(predicted_depth_original) + if isinstance(ground_truth_depth_original, np.ndarray): + ground_truth_depth_original = torch.from_numpy(ground_truth_depth_original) + if custom_mask is not None and isinstance(custom_mask, np.ndarray): + custom_mask = torch.from_numpy(custom_mask) + + # if the dimension is 3, flatten to 2d along the batch dimension + if predicted_depth_original.dim() == 3: + _, h, w = predicted_depth_original.shape + predicted_depth_original = predicted_depth_original.view(-1, w) + ground_truth_depth_original = ground_truth_depth_original.view(-1, w) + if custom_mask is not None: + custom_mask = custom_mask.view(-1, w) + + # put to device + if use_gpu: + predicted_depth_original = predicted_depth_original.cuda() + ground_truth_depth_original = ground_truth_depth_original.cuda() + + # Filter out depths greater than max_depth + if max_depth is not None: + mask = (ground_truth_depth_original > 0) & ( + ground_truth_depth_original < max_depth + ) + else: + mask = ground_truth_depth_original > 0 + predicted_depth = predicted_depth_original[mask] + ground_truth_depth = ground_truth_depth_original[mask] + + # Clip the depth values + if pre_clip_min is not None: + predicted_depth = torch.clamp(predicted_depth, min=pre_clip_min) + if pre_clip_max is not None: + predicted_depth = torch.clamp(predicted_depth, max=pre_clip_max) + + if disp_input: # align the pred to gt in the disparity space + real_gt = ground_truth_depth.clone() + ground_truth_depth = 1 / (ground_truth_depth + 1e-8) + + # various alignment methods + if metric_scale: + predicted_depth = predicted_depth + elif align_with_lstsq: + # Convert to numpy for lstsq + predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1, 1) + ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1, 1) + + # Add a column of ones for the shift term + A = np.hstack([predicted_depth_np, np.ones_like(predicted_depth_np)]) + + # Solve for scale (s) and shift (t) using least squares + result = np.linalg.lstsq(A, ground_truth_depth_np, rcond=None) + s, t = result[0][0], result[0][1] + + # convert to torch tensor + s = torch.tensor(s, device=predicted_depth_original.device) + t = torch.tensor(t, device=predicted_depth_original.device) + + # Apply scale and shift + predicted_depth = s * predicted_depth + t + elif align_with_lad: + s, t = absolute_value_scaling( + predicted_depth, + ground_truth_depth, + s=torch.median(ground_truth_depth) / torch.median(predicted_depth), + ) + predicted_depth = s * predicted_depth + t + elif align_with_lad2: + s_init = ( + torch.median(ground_truth_depth) / torch.median(predicted_depth) + ).item() + s, t = absolute_value_scaling2( + predicted_depth, + ground_truth_depth, + s_init=s_init, + lr=lr, + max_iters=max_iters, + ) + predicted_depth = s * predicted_depth + t + elif align_with_scale: + # Compute initial scale factor 's' using the closed-form solution (L2 norm) + dot_pred_gt = torch.nanmean(ground_truth_depth) + dot_pred_pred = torch.nanmean(predicted_depth) + s = dot_pred_gt / dot_pred_pred + + # Iterative reweighted least squares using the Weiszfeld method + for _ in range(10): + # Compute residuals between scaled predictions and ground truth + residuals = s * predicted_depth - ground_truth_depth + abs_residuals = ( + residuals.abs() + 1e-8 + ) # Add small constant to avoid division by zero + + # Compute weights inversely proportional to the residuals + weights = 1.0 / abs_residuals + + # Update 's' using weighted sums + weighted_dot_pred_gt = torch.sum( + weights * predicted_depth * ground_truth_depth + ) + weighted_dot_pred_pred = torch.sum(weights * predicted_depth**2) + s = weighted_dot_pred_gt / weighted_dot_pred_pred + + # Optionally clip 's' to prevent extreme scaling + s = s.clamp(min=1e-3) + + # Detach 's' if you want to stop gradients from flowing through it + s = s.detach() + + # Apply the scale factor to the predicted depth + predicted_depth = s * predicted_depth + + else: + # Align the predicted depth with the ground truth using median scaling + scale_factor = torch.median(ground_truth_depth) / torch.median(predicted_depth) + predicted_depth *= scale_factor + + if disp_input: + # convert back to depth + ground_truth_depth = real_gt + predicted_depth = depth2disparity(predicted_depth) + + # Clip the predicted depth values + if post_clip_min is not None: + predicted_depth = torch.clamp(predicted_depth, min=post_clip_min) + if post_clip_max is not None: + predicted_depth = torch.clamp(predicted_depth, max=post_clip_max) + + if custom_mask is not None: + assert custom_mask.shape == ground_truth_depth_original.shape + mask_within_mask = custom_mask.cpu()[mask] + predicted_depth = predicted_depth[mask_within_mask] + ground_truth_depth = ground_truth_depth[mask_within_mask] + + # Calculate the metrics + abs_rel = torch.mean( + torch.abs(predicted_depth - ground_truth_depth) / ground_truth_depth + ).item() + sq_rel = torch.mean( + ((predicted_depth - ground_truth_depth) ** 2) / ground_truth_depth + ).item() + + # Correct RMSE calculation + rmse = torch.sqrt(torch.mean((predicted_depth - ground_truth_depth) ** 2)).item() + + # Clip the depth values to avoid log(0) + predicted_depth = torch.clamp(predicted_depth, min=1e-5) + log_rmse = torch.sqrt( + torch.mean((torch.log(predicted_depth) - torch.log(ground_truth_depth)) ** 2) + ).item() + + # Calculate the accuracy thresholds + max_ratio = torch.maximum( + predicted_depth / ground_truth_depth, ground_truth_depth / predicted_depth + ) + threshold_0 = torch.mean((max_ratio < 1.0).float()).item() + threshold_1 = torch.mean((max_ratio < 1.25).float()).item() + threshold_2 = torch.mean((max_ratio < 1.25**2).float()).item() + threshold_3 = torch.mean((max_ratio < 1.25**3).float()).item() + + # Compute the depth error parity map + if metric_scale: + predicted_depth_original = predicted_depth_original + if disp_input: + predicted_depth_original = depth2disparity(predicted_depth_original) + depth_error_parity_map = ( + torch.abs(predicted_depth_original - ground_truth_depth_original) + / ground_truth_depth_original + ) + elif align_with_lstsq or align_with_lad or align_with_lad2: + predicted_depth_original = predicted_depth_original * s + t + if disp_input: + predicted_depth_original = depth2disparity(predicted_depth_original) + depth_error_parity_map = ( + torch.abs(predicted_depth_original - ground_truth_depth_original) + / ground_truth_depth_original + ) + elif align_with_scale: + predicted_depth_original = predicted_depth_original * s + if disp_input: + predicted_depth_original = depth2disparity(predicted_depth_original) + depth_error_parity_map = ( + torch.abs(predicted_depth_original - ground_truth_depth_original) + / ground_truth_depth_original + ) + else: + predicted_depth_original = predicted_depth_original * scale_factor + if disp_input: + predicted_depth_original = depth2disparity(predicted_depth_original) + depth_error_parity_map = ( + torch.abs(predicted_depth_original - ground_truth_depth_original) + / ground_truth_depth_original + ) + + # Reshape the depth_error_parity_map back to the original image size + depth_error_parity_map_full = torch.zeros_like(ground_truth_depth_original) + depth_error_parity_map_full = torch.where( + mask, depth_error_parity_map, depth_error_parity_map_full + ) + + predict_depth_map_full = predicted_depth_original + gt_depth_map_full = torch.zeros_like(ground_truth_depth_original) + gt_depth_map_full = torch.where( + mask, ground_truth_depth_original, gt_depth_map_full + ) + + num_valid_pixels = ( + torch.sum(mask).item() + if custom_mask is None + else torch.sum(mask_within_mask).item() + ) + if num_valid_pixels == 0: + ( + abs_rel, + sq_rel, + rmse, + log_rmse, + threshold_0, + threshold_1, + threshold_2, + threshold_3, + ) = (0, 0, 0, 0, 0, 0, 0, 0) + + results = { + "Abs Rel": abs_rel, + "Sq Rel": sq_rel, + "RMSE": rmse, + "Log RMSE": log_rmse, + "δ < 1.": threshold_0, + "δ < 1.25": threshold_1, + "δ < 1.25^2": threshold_2, + "δ < 1.25^3": threshold_3, + "valid_pixels": num_valid_pixels, + } + + return ( + results, + depth_error_parity_map_full, + predict_depth_map_full, + gt_depth_map_full, + ) diff --git a/eval/mv_recon/base.py b/eval/mv_recon/base.py new file mode 100644 index 0000000000000000000000000000000000000000..af28e3419d5f569751f3b57c825dd9c6080d8537 --- /dev/null +++ b/eval/mv_recon/base.py @@ -0,0 +1,274 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# base class for implementing datasets +# -------------------------------------------------------- +import PIL +import numpy as np +import torch + +from stream3r.dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates + +from eval.mv_recon.dataset_utils.transforms import ImgNorm +import eval.mv_recon.dataset_utils.cropping as cropping + + +class BaseStereoViewDataset: + """Define all basic options. + + Usage: + class MyDataset (BaseStereoViewDataset): + def _get_views(self, idx, rng): + # overload here + views = [] + views.append(dict(img=, ...)) + return views + """ + + def __init__( + self, + *, # only keyword arguments + split=None, + resolution=None, # square_size or (width, height) or list of [(width,height), ...] + transform=ImgNorm, + aug_crop=False, + seed=None, + ): + self.num_views = 2 + self.split = split + self._set_resolutions(resolution) + + self.transform = transform + if isinstance(transform, str): + transform = eval(transform) + + self.aug_crop = aug_crop + self.seed = seed + + def __len__(self): + return len(self.scenes) + + def get_stats(self): + return f"{len(self)} pairs" + + def __repr__(self): + resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]" + return ( + f"""{type(self).__name__}({self.get_stats()}, + {self.split=}, + {self.seed=}, + resolutions={resolutions_str}, + {self.transform=})""".replace( + "self.", "" + ) + .replace("\n", "") + .replace(" ", "") + ) + + def _get_views(self, idx, resolution, rng): + raise NotImplementedError() + + def __getitem__(self, idx): + if isinstance(idx, tuple): + # the idx is specifying the aspect-ratio + idx, ar_idx = idx + else: + assert len(self._resolutions) == 1 + ar_idx = 0 + + # set-up the rng + if self.seed: # reseed for each __getitem__ + self._rng = np.random.default_rng(seed=self.seed + idx) + elif not hasattr(self, "_rng"): + seed = torch.initial_seed() # this is different for each dataloader process + self._rng = np.random.default_rng(seed=seed) + + # over-loaded code + resolution = self._resolutions[ + ar_idx + ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) + views = self._get_views(idx, resolution, self._rng) + + # check data-types + for v, view in enumerate(views): + assert ( + "pts3d" not in view + ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" + view["idx"] = v + + # encode the image + width, height = view["img"].size + view["true_shape"] = np.int32((height, width)) + view["img"] = self.transform(view["img"]) + + assert "camera_intrinsics" in view + if "camera_pose" not in view: + view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32) + else: + assert np.isfinite( + view["camera_pose"] + ).all(), f"NaN in camera pose for view {view_name(view)}" + assert "pts3d" not in view + assert "valid_mask" not in view + assert np.isfinite( + view["depthmap"] + ).all(), f"NaN in depthmap for view {view_name(view)}" + pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view) + + view["pts3d"] = pts3d + view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1) + + # check all datatypes + for key, val in view.items(): + res, err_msg = is_good_type(key, val) + assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" + K = view["camera_intrinsics"] + view["img_mask"] = True + view["ray_mask"] = False + view["ray_map"] = torch.full( + (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan + ) + view["update"] = True + view["reset"] = False + + # last thing done! + for view in views: + # transpose to make sure all views are the same size + transpose_to_landscape(view) + # this allows to check whether the RNG is is the same state each time + view["rng"] = int.from_bytes(self._rng.bytes(4), "big") + return views + + def _set_resolutions(self, resolutions): + """Set the resolution(s) of the dataset. + Params: + - resolutions: int or tuple or list of tuples + """ + assert resolutions is not None, "undefined resolution" + + if not isinstance(resolutions, list): + resolutions = [resolutions] + + self._resolutions = [] + for resolution in resolutions: + if isinstance(resolution, int): + width = height = resolution + else: + width, height = resolution + assert isinstance( + width, int + ), f"Bad type for {width=} {type(width)=}, should be int" + assert isinstance( + height, int + ), f"Bad type for {height=} {type(height)=}, should be int" + assert width >= height + self._resolutions.append((width, height)) + + def _crop_resize_if_necessary( + self, image, depthmap, intrinsics, resolution, rng=None, info=None + ): + """This function: + - first downsizes the image with LANCZOS inteprolation, + which is better than bilinear interpolation in + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # downscale with lanczos interpolation so that image.size == resolution + # cropping centered on the principal point + W, H = image.size + cx, cy = intrinsics[:2, 2].round().astype(int) + + # calculate min distance to margin + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + assert min_margin_x > W / 5, f"Bad principal point in view={info}" + assert min_margin_y > H / 5, f"Bad principal point in view={info}" + + ## Center crop + # Crop on the principal point, make it always centered + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + + image, depthmap, intrinsics = cropping.crop_image_depthmap( + image, depthmap, intrinsics, crop_bbox + ) + + # # transpose the resolution if necessary + W, H = image.size # new size + assert resolution[0] >= resolution[1] + if H > 1.1 * W: + # image is portrait mode + resolution = resolution[::-1] + elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]: + # image is square, so we chose (portrait, landscape) randomly + if rng.integers(2): + resolution = resolution[::-1] + + # high-quality Lanczos down-scaling + target_resolution = np.array(resolution) + # # if self.aug_crop > 1: + # # target_resolution += rng.integers(0, self.aug_crop) + # if resolution != (224, 224): + # halfw, halfh = ((2*(W//2))//16)*8, ((2*(H//2))//16)*8 + # ## Recale with max factor, so one of width or height might be larger than target_resolution + # image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, (2*halfw, 2*halfh)) + # else: + image, depthmap, intrinsics = cropping.rescale_image_depthmap( + image, depthmap, intrinsics, target_resolution + ) + # actual cropping (if necessary) with bilinear interpolation + # if resolution == (224, 224): + intrinsics2 = cropping.camera_matrix_of_crop( + intrinsics, image.size, resolution, offset_factor=0.5 + ) + crop_bbox = cropping.bbox_from_intrinsics_in_out( + intrinsics, intrinsics2, resolution + ) + image, depthmap, intrinsics = cropping.crop_image_depthmap( + image, depthmap, intrinsics, crop_bbox + ) + return image, depthmap, intrinsics + + +def is_good_type(key, v): + """returns (is_good, err_msg)""" + if isinstance(v, (str, int, tuple)): + return True, None + if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): + return False, f"bad {v.dtype=}" + return True, None + + +def view_name(view, batch_index=None): + def sel(x): + return x[batch_index] if batch_index not in (None, slice(None)) else x + + db = sel(view["dataset"]) + label = sel(view["label"]) + instance = sel(view["instance"]) + return f"{db}/{label}/{instance}" + + +def transpose_to_landscape(view): + height, width = view["true_shape"] + + if width < height: + # rectify portrait to landscape + assert view["img"].shape == (3, height, width) + view["img"] = view["img"].swapaxes(1, 2) + + assert view["valid_mask"].shape == (height, width) + view["valid_mask"] = view["valid_mask"].swapaxes(0, 1) + + assert view["depthmap"].shape == (height, width) + view["depthmap"] = view["depthmap"].swapaxes(0, 1) + + assert view["pts3d"].shape == (height, width, 3) + view["pts3d"] = view["pts3d"].swapaxes(0, 1) + + # transpose x and y pixels + view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]] diff --git a/eval/mv_recon/criterion.py b/eval/mv_recon/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..86f3ab5099c67d8752f9d5d4b9d929d75419024f --- /dev/null +++ b/eval/mv_recon/criterion.py @@ -0,0 +1,550 @@ +import torch +import torch.nn as nn +from copy import copy, deepcopy +from stream3r.dust3r.utils.misc import invalid_to_zeros, invalid_to_nans +from stream3r.dust3r.utils.geometry import inv, geotrf, depthmap_to_pts3d +from stream3r.dust3r.utils.camera import pose_encoding_to_camera + + +class BaseCriterion(nn.Module): + + def __init__(self, reduction="mean"): + super().__init__() + self.reduction = reduction + + +class Criterion(nn.Module): + + def __init__(self, criterion=None): + super().__init__() + assert isinstance( + criterion, + BaseCriterion), f"{criterion} is not a proper criterion!" + self.criterion = copy(criterion) + + def get_name(self): + return f"{type(self).__name__}({self.criterion})" + + def with_reduction(self, mode="none"): + res = loss = deepcopy(self) + while loss is not None: + assert isinstance(loss, Criterion) + loss.criterion.reduction = mode # make it return the loss for each sample + loss = loss._loss2 # we assume loss is a Multiloss + return res + + +class MultiLoss(nn.Module): + """Easily combinable losses (also keep track of individual loss values): + loss = MyLoss1() + 0.1*MyLoss2() + Usage: + Inherit from this class and override get_name() and compute_loss() + """ + + def __init__(self): + super().__init__() + self._alpha = 1 + self._loss2 = None + + def compute_loss(self, *args, **kwargs): + raise NotImplementedError() + + def get_name(self): + raise NotImplementedError() + + def __mul__(self, alpha): + assert isinstance(alpha, (int, float)) + res = copy(self) + res._alpha = alpha + return res + + __rmul__ = __mul__ # same + + def __add__(self, loss2): + assert isinstance(loss2, MultiLoss) + res = cur = copy(self) + + while cur._loss2 is not None: + cur = cur._loss2 + cur._loss2 = loss2 + return res + + def __repr__(self): + name = self.get_name() + if self._alpha != 1: + name = f"{self._alpha:g}*{name}" + if self._loss2: + name = f"{name} + {self._loss2}" + return name + + def forward(self, *args, **kwargs): + loss = self.compute_loss(*args, **kwargs) + if isinstance(loss, tuple): + loss, details = loss + elif loss.ndim == 0: + details = {self.get_name(): float(loss)} + else: + details = {} + loss = loss * self._alpha + + if self._loss2: + loss2, details2 = self._loss2(*args, **kwargs) + loss = loss + loss2 + details |= details2 + + return loss, details + + +class LLoss(BaseCriterion): + """L-norm loss""" + + def forward(self, a, b): + assert (a.shape == b.shape and a.ndim >= 2 + and 1 <= a.shape[-1] <= 3), f"Bad shape = {a.shape}" + dist = self.distance(a, b) + + if self.reduction == "none": + return dist + if self.reduction == "sum": + return dist.sum() + if self.reduction == "mean": + return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) + raise ValueError(f"bad {self.reduction=} mode") + + def distance(self, a, b): + raise NotImplementedError() + + +class L21Loss(LLoss): + """Euclidean distance between 3d points""" + + def distance(self, a, b): + return torch.norm(a - b, dim=-1) # normalized L2 distance + + +L21 = L21Loss() + + +def get_pred_pts3d(gt, pred, use_pose=False): + if "depth" in pred and "pseudo_focal" in pred: + try: + pp = gt["camera_intrinsics"][..., :2, 2] + except KeyError: + pp = None + pts3d = depthmap_to_pts3d(**pred, pp=pp) + + elif "pts3d" in pred: + # pts3d from my camera + pts3d = pred["pts3d"] + + elif "pts3d_in_other_view" in pred: + # pts3d from the other camera, already transformed + assert use_pose is False + return pred["pts3d_in_other_view"] # return! + + if use_pose: + camera_pose = pred.get("camera_pose") + pts3d = pred.get("pts3d_in_self_view") + assert camera_pose is not None + assert pts3d is not None + pts3d = geotrf(pose_encoding_to_camera(camera_pose), pts3d) + + return pts3d + + +def Sum(losses, masks, conf=None): + loss, mask = losses[0], masks[0] + if loss.ndim > 0: + # we are actually returning the loss for every pixels + if conf is not None: + return losses, masks, conf + return losses, masks + else: + # we are returning the global loss + for loss2 in losses[1:]: + loss = loss + loss2 + return loss + + +def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True): + assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3 + assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3) + norm_mode, dis_mode = norm_mode.split("_") + + nan_pts = [] + nnzs = [] + + if norm_mode == "avg": + # gather all points together (joint normalization) + + for i, pt in enumerate(pts): + nan_pt, nnz = invalid_to_zeros(pt, valids[i], ndim=3) + nan_pts.append(nan_pt) + nnzs.append(nnz) + + if fix_first: + break + all_pts = torch.cat(nan_pts, dim=1) + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + if dis_mode == "dis": + pass # do nothing + elif dis_mode == "log1p": + all_dis = torch.log1p(all_dis) + else: + raise ValueError(f"bad {dis_mode=}") + + norm_factor = all_dis.sum(dim=1) / (torch.cat(nnzs).sum() + 1e-8) + else: + raise ValueError(f"Not implemented {norm_mode=}") + + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts[0].ndim: + norm_factor.unsqueeze_(-1) + + return norm_factor + + +def normalize_pointcloud_t(pts, + norm_mode="avg_dis", + valids=None, + fix_first=True, + gt=False): + if gt: + norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first) + res = [] + + for i, pt in enumerate(pts): + res.append(pt / norm_factor) + + else: + # pts_l, pts_r = pts + # use pts_l and pts_r[-1] as pts to normalize + norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first) + + res = [] + + for i in range(len(pts)): + res.append(pts[i] / norm_factor) + # res_r.append(pts_r[i] / norm_factor) + + # res = [res_l, res_r] + + return res, norm_factor + + +@torch.no_grad() +def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5): + # set invalid points to NaN + _zs = [] + for i in range(len(zs)): + valid_mask = valid_masks[i] if valid_masks is not None else None + _z = invalid_to_nans(zs[i], valid_mask).reshape(len(zs[i]), -1) + _zs.append(_z) + + _zs = torch.cat(_zs, dim=-1) + + # compute median depth overall (ignoring nans) + if quantile == 0.5: + shift_z = torch.nanmedian(_zs, dim=-1).values + else: + shift_z = torch.nanquantile(_zs, quantile, dim=-1) + return shift_z # (B,) + + +@torch.no_grad() +def get_joint_pointcloud_center_scale(pts, + valid_masks=None, + z_only=False, + center=True): + # set invalid points to NaN + + _pts = [] + for i in range(len(pts)): + valid_mask = valid_masks[i] if valid_masks is not None else None + _pt = invalid_to_nans(pts[i], valid_mask).reshape(len(pts[i]), -1, 3) + _pts.append(_pt) + + _pts = torch.cat(_pts, dim=1) + + # compute median center + _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3) + if z_only: + _center[..., :2] = 0 # do not center X and Y + + # compute median norm + _norm = ((_pts - _center) if center else _pts).norm(dim=-1) + scale = torch.nanmedian(_norm, dim=1).values + return _center[:, None, :, :], scale[:, None, None, None] + + +class Regr3D_t(Criterion, MultiLoss): + + def __init__(self, + criterion, + norm_mode="avg_dis", + gt_scale=False, + fix_first=True): + super().__init__(criterion) + self.norm_mode = norm_mode + self.gt_scale = gt_scale + self.fix_first = fix_first + + def get_all_pts3d_t(self, gts, preds, dist_clip=None): + # everything is normalized w.r.t. camera of view1 + in_camera1 = inv(gts[0]["camera_pose"]) + + gt_pts = [] + valids = [] + pr_pts = [] + + for i, gt in enumerate(gts): + # in_camera1: Bs, 4, 4 gt['pts3d']: Bs, H, W, 3 + gt_pts.append(geotrf(in_camera1, gt["pts3d"])) + + valid = gt["valid_mask"].clone() + + if dist_clip is not None: + # points that are too far-away == invalid + dis = gt["pts3d"].norm(dim=-1) + valid = valid & (dis <= dist_clip) + + valids.append(valid) + # pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True)) + pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=False)) + # if i != len(gts)-1: + # pr_pts_l.append(get_pred_pts3d(gt, preds[i][0], use_pose=(i!=0))) + + # if i != 0: + # pr_pts_r.append(get_pred_pts3d(gt, preds[i-1][1], use_pose=(i!=0))) + + # pr_pts = (pr_pts_l, pr_pts_r) + + if self.norm_mode: + pr_pts, pr_factor = normalize_pointcloud_t( + pr_pts, + self.norm_mode, + valids, + fix_first=self.fix_first, + gt=False) + else: + pr_factor = None + + if self.norm_mode and not self.gt_scale: + gt_pts, gt_factor = normalize_pointcloud_t( + gt_pts, + self.norm_mode, + valids, + fix_first=self.fix_first, + gt=True) + else: + gt_factor = None + + return gt_pts, pr_pts, gt_factor, pr_factor, valids, {} + + def compute_frame_loss(self, gts, preds, **kw): + gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( + self.get_all_pts3d_t(gts, preds, **kw)) + + pred_pts_l, pred_pts_r = pred_pts + + loss_all = [] + mask_all = [] + conf_all = [] + + loss_left = 0 + loss_right = 0 + pred_conf_l = 0 + pred_conf_r = 0 + + for i in range(len(gt_pts)): + + # Left (Reference) + if i != len(gt_pts) - 1: + frame_loss = self.criterion(pred_pts_l[i][masks[i]], + gt_pts[i][masks[i]]) + + loss_all.append(frame_loss) + mask_all.append(masks[i]) + conf_all.append(preds[i][0]["conf"]) + + # To compare target/reference loss + if i != 0: + loss_left += frame_loss.cpu().detach().numpy().mean() + pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy( + ).mean() + + # Right (Target) + if i != 0: + frame_loss = self.criterion(pred_pts_r[i - 1][masks[i]], + gt_pts[i][masks[i]]) + + loss_all.append(frame_loss) + mask_all.append(masks[i]) + conf_all.append(preds[i - 1][1]["conf"]) + + # To compare target/reference loss + if i != len(gt_pts) - 1: + loss_right += frame_loss.cpu().detach().numpy().mean() + pred_conf_r += preds[ + i - 1][1]["conf"].cpu().detach().numpy().mean() + + if pr_factor is not None and gt_factor is not None: + filter_factor = pr_factor[pr_factor > gt_factor] + else: + filter_factor = [] + + if len(filter_factor) > 0: + factor_loss = (filter_factor - gt_factor).abs().mean() + else: + factor_loss = 0.0 + + self_name = type(self).__name__ + details = { + self_name + "_pts3d_1": float(loss_all[0].mean()), + self_name + "_pts3d_2": float(loss_all[1].mean()), + self_name + "loss_left": float(loss_left), + self_name + "loss_right": float(loss_right), + self_name + "conf_left": float(pred_conf_l), + self_name + "conf_right": float(pred_conf_r), + } + + return Sum(loss_all, mask_all, + conf_all), (details | monitoring), factor_loss + + +class ConfLoss_t(MultiLoss): + """Weighted regression by learned confidence. + Assuming the input pixel_loss is a pixel-level regression loss. + + Principle: + high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10) + low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10) + + alpha: hyperparameter + """ + + def __init__(self, pixel_loss, alpha=1): + super().__init__() + assert alpha > 0 + self.alpha = alpha + self.pixel_loss = pixel_loss.with_reduction("none") + + def get_name(self): + return f"ConfLoss({self.pixel_loss})" + + def get_conf_log(self, x): + return x, torch.log(x) + + def compute_frame_loss(self, gts, preds, **kw): + # compute per-pixel loss + (losses, masks, + confs), details, loss_factor = (self.pixel_loss.compute_frame_loss( + gts, preds, **kw)) + + # weight by confidence + conf_losses = [] + conf_sum = 0 + for i in range(len(losses)): + conf, log_conf = self.get_conf_log(confs[i][masks[i]]) + conf_sum += conf.mean() + conf_loss = losses[i] * conf - self.alpha * log_conf + conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0 + conf_losses.append(conf_loss) + + conf_losses = torch.stack(conf_losses) * 2.0 + conf_loss_mean = conf_losses.mean() + + return ( + conf_loss_mean, + dict( + conf_loss_1=float(conf_losses[0]), + conf_loss2=float(conf_losses[1]), + conf_mean=conf_sum / len(losses), + **details, + ), + loss_factor, + ) + + +class Regr3D_t_ShiftInv(Regr3D_t): + """Same than Regr3D but invariant to depth shift.""" + + def get_all_pts3d_t(self, gts, preds): + # compute unnormalized points + gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( + super().get_all_pts3d_t(gts, preds)) + + # pred_pts_l, pred_pts_r = pred_pts + gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts] + + pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts] + # pred_zs.append(pred_pts_r[-1][..., 2]) + + # compute median depth + gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None] + pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, + None] + + # subtract the median depth + for i in range(len(gt_pts)): + gt_pts[i][..., 2] -= gt_shift_z + + for i in range(len(pred_pts)): + # for j in range(len(pred_pts[i])): + # pred_pts[i][..., 2] -= pred_shift_z + pred_pts[i] = pred_pts[i].clone() + pred_pts[i][..., 2] -= pred_shift_z # avoid in-place modification + + monitoring = dict( + monitoring, + gt_shift_z=gt_shift_z.mean().detach(), + pred_shift_z=pred_shift_z.mean().detach(), + ) + return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring + + +class Regr3D_t_ScaleInv(Regr3D_t): + """Same than Regr3D but invariant to depth shift. + if gt_scale == True: enforce the prediction to take the same scale than GT + """ + + def get_all_pts3d_t(self, gts, preds): + # compute depth-normalized points + gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( + super().get_all_pts3d_t(gts, preds)) + + pred_pts_all = [x.clone() for x in pred_pts + ] # [pred_pt for pred_pt in pred_pts_l] + + _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks) + _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks) + + # prevent predictions to be in a ridiculous range + pred_scale = pred_scale.clip(min=1e-3, max=1e3) + + # subtract the median depth + if self.gt_scale: + for i in range(len(pred_pts)): + # for j in range(len(pred_pts[i])): + pred_pts[i] *= gt_scale / pred_scale + + else: + for i in range(len(pred_pts)): + # for j in range(len(pred_pts[i])): + pred_pts[i] *= pred_scale / gt_scale + + for i in range(len(gt_pts)): + gt_pts[i] *= gt_scale / pred_scale + + monitoring = dict(monitoring, + gt_scale=gt_scale.mean(), + pred_scale=pred_scale.mean().detach()) + + return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring + + +class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv): + # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv + pass diff --git a/eval/mv_recon/data.py b/eval/mv_recon/data.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0c67f2ae9429dd6ee06d6fcabccd4697612ca1 --- /dev/null +++ b/eval/mv_recon/data.py @@ -0,0 +1,532 @@ +import os +import cv2 +import numpy as np +import os.path as osp +from collections import deque +import random + +from eval.mv_recon.base import BaseStereoViewDataset +import eval.mv_recon.dataset_utils.cropping as cropping + +def imread_cv2(path, options=cv2.IMREAD_COLOR): + """Open an image or a depthmap with opencv-python.""" + if path.endswith((".exr", "EXR")): + options = cv2.IMREAD_ANYDEPTH + img = cv2.imread(path, options) + if img is None: + raise IOError(f"Could not load image={path} with {options=}") + if img.ndim == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def shuffle_deque(dq, seed=None): + # Set the random seed for reproducibility + if seed is not None: + random.seed(seed) + + # Convert deque to list, shuffle, and convert back + shuffled_list = list(dq) + random.shuffle(shuffled_list) + return deque(shuffled_list) + + +class SevenScenes(BaseStereoViewDataset): + def __init__( + self, + num_seq=1, + num_frames=5, + min_thresh=10, + max_thresh=100, + test_id=None, + full_video=False, + tuple_list=None, + seq_id=None, + rebuttal=False, + shuffle_seed=-1, + kf_every=1, + *args, + ROOT, + **kwargs, + ): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self.num_seq = num_seq + self.num_frames = num_frames + self.max_thresh = max_thresh + self.min_thresh = min_thresh + self.test_id = test_id + self.full_video = full_video + self.kf_every = kf_every + self.seq_id = seq_id + self.rebuttal = rebuttal + self.shuffle_seed = shuffle_seed + + # load all scenes + self.load_all_tuples(tuple_list) + self.load_all_scenes(ROOT) + + def __len__(self): + if self.tuple_list is not None: + return len(self.tuple_list) + return len(self.scene_list) * self.num_seq + + def load_all_tuples(self, tuple_list): + if tuple_list is not None: + self.tuple_list = tuple_list + # with open(tuple_path) as f: + # self.tuple_list = f.read().splitlines() + + else: + self.tuple_list = None + + def load_all_scenes(self, base_dir): + + if self.tuple_list is not None: + # Use pre-defined simplerecon scene_ids + self.scene_list = [ + "stairs/seq-06", + "stairs/seq-02", + "pumpkin/seq-06", + "chess/seq-01", + "heads/seq-02", + "fire/seq-02", + "office/seq-03", + "pumpkin/seq-03", + "redkitchen/seq-07", + "chess/seq-02", + "office/seq-01", + "redkitchen/seq-01", + "fire/seq-01", + ] + print(f"Found {len(self.scene_list)} sequences in split {self.split}") + return + + scenes = os.listdir(base_dir) + + file_split = {"train": "TrainSplit.txt", "test": "TestSplit.txt"}[self.split] + + self.scene_list = [] + for scene in scenes: + if self.test_id is not None and scene != self.test_id: + continue + # read file split + with open(osp.join(base_dir, scene, file_split)) as f: + seq_ids = f.read().splitlines() + + for seq_id in seq_ids: + # seq is string, take the int part and make it 01, 02, 03 + # seq_id = 'seq-{:2d}'.format(int(seq_id)) + num_part = "".join(filter(str.isdigit, seq_id)) + seq_id = f"seq-{num_part.zfill(2)}" + if self.seq_id is not None and seq_id != self.seq_id: + continue + self.scene_list.append(f"{scene}/{seq_id}") + + print(f"Found {len(self.scene_list)} sequences in split {self.split}") + + def _get_views(self, idx, resolution, rng): + + if self.tuple_list is not None: + line = self.tuple_list[idx].split(" ") + scene_id = line[0] + img_idxs = line[1:] + + else: + scene_id = self.scene_list[idx // self.num_seq] + seq_id = idx % self.num_seq + + data_path = osp.join(self.ROOT, scene_id) + num_files = len([name for name in os.listdir(data_path) if "color" in name]) + img_idxs = [f"{i:06d}" for i in range(num_files)] + img_idxs = img_idxs[:: self.kf_every] + + # Intrinsics used in SimpleRecon + fx, fy, cx, cy = 525, 525, 320, 240 + intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + + views = [] + imgs_idxs = deque(img_idxs) + if self.shuffle_seed >= 0: + imgs_idxs = shuffle_deque(imgs_idxs) + + while len(imgs_idxs) > 0: + im_idx = imgs_idxs.popleft() + impath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.color.png") + depthpath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.depth.proj.png") + posepath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.pose.txt") + + rgb_image = imread_cv2(impath) + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0])) + + depthmap[depthmap == 65535] = 0 + depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0 + depthmap[depthmap > 10] = 0 + depthmap[depthmap < 1e-3] = 0 + + camera_pose = np.loadtxt(posepath).astype(np.float32) + + if resolution != (224, 224) or self.rebuttal: + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath + ) + else: + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath + ) + W, H = rgb_image.size + cx = W // 2 + cy = H // 2 + l, t = cx - 112, cy - 112 + r, b = cx + 112, cy + 112 + crop_bbox = (l, t, r, b) + rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap( + rgb_image, depthmap, intrinsics, crop_bbox + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset="7scenes", + label=osp.join(scene_id, im_idx), + instance=impath, + ) + ) + return views + + +class DTU(BaseStereoViewDataset): + def __init__( + self, + num_seq=49, + num_frames=5, + min_thresh=10, + max_thresh=30, + test_id=None, + full_video=False, + sample_pairs=False, + kf_every=1, + *args, + ROOT, + **kwargs, + ): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + + self.num_seq = num_seq + self.num_frames = num_frames + self.max_thresh = max_thresh + self.min_thresh = min_thresh + self.test_id = test_id + self.full_video = full_video + self.kf_every = kf_every + self.sample_pairs = sample_pairs + + # load all scenes + self.load_all_scenes(ROOT) + + def __len__(self): + return len(self.scene_list) * self.num_seq + + def load_all_scenes(self, base_dir): + + if self.test_id is None: + self.scene_list = os.listdir(osp.join(base_dir)) + print(f"Found {len(self.scene_list)} scenes in split {self.split}") + + else: + if isinstance(self.test_id, list): + self.scene_list = self.test_id + else: + self.scene_list = [self.test_id] + + print(f"Test_id: {self.test_id}") + + def load_cam_mvsnet(self, file, interval_scale=1): + """read camera txt file""" + cam = np.zeros((2, 4, 4)) + words = file.read().split() + # read extrinsic + for i in range(0, 4): + for j in range(0, 4): + extrinsic_index = 4 * i + j + 1 + cam[0][i][j] = words[extrinsic_index] + + # read intrinsic + for i in range(0, 3): + for j in range(0, 3): + intrinsic_index = 3 * i + j + 18 + cam[1][i][j] = words[intrinsic_index] + + if len(words) == 29: + cam[1][3][0] = words[27] + cam[1][3][1] = float(words[28]) * interval_scale + cam[1][3][2] = 192 + cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2] + elif len(words) == 30: + cam[1][3][0] = words[27] + cam[1][3][1] = float(words[28]) * interval_scale + cam[1][3][2] = words[29] + cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * cam[1][3][2] + elif len(words) == 31: + cam[1][3][0] = words[27] + cam[1][3][1] = float(words[28]) * interval_scale + cam[1][3][2] = words[29] + cam[1][3][3] = words[30] + else: + cam[1][3][0] = 0 + cam[1][3][1] = 0 + cam[1][3][2] = 0 + cam[1][3][3] = 0 + + extrinsic = cam[0].astype(np.float32) + intrinsic = cam[1].astype(np.float32) + + return intrinsic, extrinsic + + def _get_views(self, idx, resolution, rng): + scene_id = self.scene_list[idx // self.num_seq] + seq_id = idx % self.num_seq + + print("Scene ID:", scene_id) + + image_path = osp.join(self.ROOT, scene_id, "images") + depth_path = osp.join(self.ROOT, scene_id, "depths") + mask_path = osp.join(self.ROOT, scene_id, "binary_masks") + cam_path = osp.join(self.ROOT, scene_id, "cams") + pairs_path = osp.join(self.ROOT, scene_id, "pair.txt") + + if not self.full_video: + img_idxs = self.sample_pairs(pairs_path, seq_id) + else: + img_idxs = sorted(os.listdir(image_path)) + img_idxs = img_idxs[:: self.kf_every] + + views = [] + imgs_idxs = deque(img_idxs) + + while len(imgs_idxs) > 0: + im_idx = imgs_idxs.pop() + impath = osp.join(image_path, im_idx) + depthpath = osp.join(depth_path, im_idx.replace(".jpg", ".npy")) + campath = osp.join(cam_path, im_idx.replace(".jpg", "_cam.txt")) + maskpath = osp.join(mask_path, im_idx.replace(".jpg", ".png")) + + rgb_image = imread_cv2(impath) + depthmap = np.load(depthpath) + depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) + + mask = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED) / 255.0 + mask = mask.astype(np.float32) + + mask[mask > 0.5] = 1.0 + mask[mask < 0.5] = 0.0 + + mask = cv2.resize( + mask, + (depthmap.shape[1], depthmap.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + kernel = np.ones((10, 10), np.uint8) # Define the erosion kernel + mask = cv2.erode(mask, kernel, iterations=1) + depthmap = depthmap * mask + + cur_intrinsics, camera_pose = self.load_cam_mvsnet(open(campath, "r")) + intrinsics = cur_intrinsics[:3, :3] + camera_pose = np.linalg.inv(camera_pose) + + if resolution != (224, 224): + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath + ) + else: + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, (512, 384), rng=rng, info=impath + ) + W, H = rgb_image.size + cx = W // 2 + cy = H // 2 + l, t = cx - 112, cy - 112 + r, b = cx + 112, cy + 112 + crop_bbox = (l, t, r, b) + rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap( + rgb_image, depthmap, intrinsics, crop_bbox + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset="dtu", + label=osp.join(scene_id, im_idx), + instance=impath, + ) + ) + + return views + + +class NRGBD(BaseStereoViewDataset): + def __init__( + self, + num_seq=1, + num_frames=5, + min_thresh=10, + max_thresh=100, + test_id=None, + full_video=False, + tuple_list=None, + seq_id=None, + rebuttal=False, + shuffle_seed=-1, + kf_every=1, + *args, + ROOT, + **kwargs, + ): + + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self.num_seq = num_seq + self.num_frames = num_frames + self.max_thresh = max_thresh + self.min_thresh = min_thresh + self.test_id = test_id + self.full_video = full_video + self.kf_every = kf_every + self.seq_id = seq_id + self.rebuttal = rebuttal + self.shuffle_seed = shuffle_seed + + # load all scenes + self.load_all_tuples(tuple_list) + self.load_all_scenes(ROOT) + + def __len__(self): + if self.tuple_list is not None: + return len(self.tuple_list) + return len(self.scene_list) * self.num_seq + + def load_all_tuples(self, tuple_list): + if tuple_list is not None: + self.tuple_list = tuple_list + # with open(tuple_path) as f: + # self.tuple_list = f.read().splitlines() + + else: + self.tuple_list = None + + def load_all_scenes(self, base_dir): + + scenes = [ + d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) + ] + + if self.test_id is not None: + self.scene_list = [self.test_id] + + else: + self.scene_list = scenes + + print(f"Found {len(self.scene_list)} sequences in split {self.split}") + + def load_poses(self, path): + file = open(path, "r") + lines = file.readlines() + file.close() + poses = [] + valid = [] + lines_per_matrix = 4 + for i in range(0, len(lines), lines_per_matrix): + if "nan" in lines[i]: + valid.append(False) + poses.append(np.eye(4, 4, dtype=np.float32).tolist()) + else: + valid.append(True) + pose_floats = [ + [float(x) for x in line.split()] + for line in lines[i : i + lines_per_matrix] + ] + poses.append(pose_floats) + + return np.array(poses, dtype=np.float32), valid + + def _get_views(self, idx, resolution, rng): + + if self.tuple_list is not None: + line = self.tuple_list[idx].split(" ") + scene_id = line[0] + img_idxs = line[1:] + + else: + scene_id = self.scene_list[idx // self.num_seq] + + num_files = len(os.listdir(os.path.join(self.ROOT, scene_id, "images"))) + img_idxs = [f"{i}" for i in range(num_files)] + img_idxs = img_idxs[:: min(self.kf_every, len(img_idxs) // 2)] + + fx, fy, cx, cy = 554.2562584220408, 554.2562584220408, 320, 240 + intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + + posepath = osp.join(self.ROOT, scene_id, f"poses.txt") + camera_poses, valids = self.load_poses(posepath) + + imgs_idxs = deque(img_idxs) + if self.shuffle_seed >= 0: + imgs_idxs = shuffle_deque(imgs_idxs) + views = [] + + while len(imgs_idxs) > 0: + im_idx = imgs_idxs.popleft() + + impath = osp.join(self.ROOT, scene_id, "images", f"img{im_idx}.png") + depthpath = osp.join(self.ROOT, scene_id, "depth", f"depth{im_idx}.png") + + rgb_image = imread_cv2(impath) + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0 + depthmap[depthmap > 10] = 0 + depthmap[depthmap < 1e-3] = 0 + + rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0])) + + camera_pose = camera_poses[int(im_idx)] + # gl to cv + camera_pose[:, 1:3] *= -1.0 + if resolution != (224, 224) or self.rebuttal: + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath + ) + else: + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath + ) + W, H = rgb_image.size + cx = W // 2 + cy = H // 2 + l, t = cx - 112, cy - 112 + r, b = cx + 112, cy + 112 + crop_bbox = (l, t, r, b) + rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap( + rgb_image, depthmap, intrinsics, crop_bbox + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset="nrgbd", + label=osp.join(scene_id, im_idx), + instance=impath, + ) + ) + + return views diff --git a/eval/mv_recon/dataset_utils/__init__.py b/eval/mv_recon/dataset_utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/eval/mv_recon/dataset_utils/__init__.py @@ -0,0 +1 @@ + diff --git a/eval/mv_recon/dataset_utils/corr.py b/eval/mv_recon/dataset_utils/corr.py new file mode 100755 index 0000000000000000000000000000000000000000..d39d8fad844c65f0f839de6b728f2ab72b19f6a2 --- /dev/null +++ b/eval/mv_recon/dataset_utils/corr.py @@ -0,0 +1,122 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- + +import numpy as np +from dust3r.utils.device import to_numpy +from dust3r.utils.geometry import inv, geotrf + + +def reproject_view(pts3d, view2): + shape = view2["pts3d"].shape[:2] + return reproject( + pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape + ) + + +def reproject(pts3d, K, world2cam, shape): + H, W, THREE = pts3d.shape + assert THREE == 3 + + with np.errstate(divide="ignore", invalid="ignore"): + pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2) + + return (H, W), ravel_xy(pos, shape) + + +def ravel_xy(pos, shape): + H, W = shape + with np.errstate(invalid="ignore"): + qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T + quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip( + min=0, max=H - 1, out=qy + ) + return quantized_pos + + +def unravel_xy(pos, shape): + + return np.unravel_index(pos, shape)[0].base[:, ::-1].copy() + + +def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False): + is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2)) + pos1 = is_reciprocal1.nonzero()[0] + pos2 = corres_1_to_2[pos1] + if ret_recip: + return is_reciprocal1, pos1, pos2 + return pos1, pos2 + + +def extract_correspondences_from_pts3d( + view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0 +): + view1, view2 = to_numpy((view1, view2)) + + shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2) + shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1) + + is_reciprocal1, pos1, pos2 = reciprocal_1d( + corres1_to_2, corres2_to_1, ret_recip=True + ) + is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1)) + + if target_n_corres is None: + if ret_xy: + pos1 = unravel_xy(pos1, shape1) + pos2 = unravel_xy(pos2, shape2) + return pos1, pos2 + + available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum()) + target_n_positives = int(target_n_corres * (1 - nneg)) + n_positives = min(len(pos1), target_n_positives) + n_negatives = min(target_n_corres - n_positives, available_negatives) + + if n_negatives + n_positives != target_n_corres: + + n_positives = target_n_corres - n_negatives + assert n_positives <= len(pos1) + + assert n_positives <= len(pos1) + assert n_positives <= len(pos2) + assert n_negatives <= (~is_reciprocal1).sum() + assert n_negatives <= (~is_reciprocal2).sum() + assert n_positives + n_negatives == target_n_corres + + valid = np.ones(n_positives, dtype=bool) + if n_positives < len(pos1): + + perm = rng.permutation(len(pos1))[:n_positives] + pos1 = pos1[perm] + pos2 = pos2[perm] + + if n_negatives > 0: + + def norm(p): + return p / p.sum() + + pos1 = np.r_[ + pos1, + rng.choice( + shape1[0] * shape1[1], + size=n_negatives, + replace=False, + p=norm(~is_reciprocal1), + ), + ] + pos2 = np.r_[ + pos2, + rng.choice( + shape2[0] * shape2[1], + size=n_negatives, + replace=False, + p=norm(~is_reciprocal2), + ), + ] + valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)] + + if ret_xy: + pos1 = unravel_xy(pos1, shape1) + pos2 = unravel_xy(pos2, shape2) + return pos1, pos2, valid diff --git a/eval/mv_recon/dataset_utils/cropping.py b/eval/mv_recon/dataset_utils/cropping.py new file mode 100755 index 0000000000000000000000000000000000000000..ec9182beac4fc5906bd844673a0f166f6dcddcbc --- /dev/null +++ b/eval/mv_recon/dataset_utils/cropping.py @@ -0,0 +1,142 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- + +import PIL.Image +import os + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa +import numpy as np # noqa +from stream3r.dust3r.utils.geometry import ( + colmap_to_opencv_intrinsics, + opencv_to_colmap_intrinsics, +) # noqa + +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + + +class ImageList: + """Convenience class to aply the same operation to a whole set of images.""" + + def __init__(self, images): + if not isinstance(images, (tuple, list, set)): + images = [images] + self.images = [] + for image in images: + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + self.images.append(image) + + def __len__(self): + return len(self.images) + + def to_pil(self): + return tuple(self.images) if len(self.images) > 1 else self.images[0] + + @property + def size(self): + sizes = [im.size for im in self.images] + assert all(sizes[0] == s for s in sizes) + return sizes[0] + + def resize(self, *args, **kwargs): + return ImageList(self._dispatch("resize", *args, **kwargs)) + + def crop(self, *args, **kwargs): + return ImageList(self._dispatch("crop", *args, **kwargs)) + + def _dispatch(self, func, *args, **kwargs): + return [getattr(im, func)(*args, **kwargs) for im in self.images] + + +def rescale_image_depthmap( + image, depthmap, camera_intrinsics, output_resolution, force=True +): + """Jointly rescale a (image, depthmap) + so that (out_width, out_height) >= output_res + """ + image = ImageList(image) + input_resolution = np.array(image.size) # (W,H) + output_resolution = np.array(output_resolution) + if depthmap is not None: + + assert tuple(depthmap.shape[:2]) == image.size[::-1] + + assert output_resolution.shape == (2,) + scale_final = max(output_resolution / image.size) + 1e-8 + if scale_final >= 1 and not force: # image is already smaller than what is asked + return (image.to_pil(), depthmap, camera_intrinsics) + output_resolution = np.floor(input_resolution * scale_final).astype(int) + + image = image.resize( + output_resolution, resample=lanczos if scale_final < 1 else bicubic + ) + if depthmap is not None: + depthmap = cv2.resize( + depthmap, + output_resolution, + fx=scale_final, + fy=scale_final, + interpolation=cv2.INTER_NEAREST, + ) + + camera_intrinsics = camera_matrix_of_crop( + camera_intrinsics, input_resolution, output_resolution, scaling=scale_final + ) + + return image.to_pil(), depthmap, camera_intrinsics + + +def camera_matrix_of_crop( + input_camera_matrix, + input_resolution, + output_resolution, + scaling=1, + offset_factor=0.5, + offset=None, +): + + margins = np.asarray(input_resolution) * scaling - output_resolution + assert np.all(margins >= 0.0) + if offset is None: + offset = offset_factor * margins + + output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) + output_camera_matrix_colmap[:2, :] *= scaling + output_camera_matrix_colmap[:2, 2] -= offset + output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) + + return output_camera_matrix + + +def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox): + """ + Return a crop of the input view. + """ + image = ImageList(image) + l, t, r, b = crop_bbox + + image = image.crop((l, t, r, b)) + depthmap = depthmap[t:b, l:r] + + camera_intrinsics = camera_intrinsics.copy() + camera_intrinsics[0, 2] -= l + camera_intrinsics[1, 2] -= t + + return image.to_pil(), depthmap, camera_intrinsics + + +def bbox_from_intrinsics_in_out( + input_camera_matrix, output_camera_matrix, output_resolution +): + out_width, out_height = output_resolution + l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) + crop_bbox = (l, t, l + out_width, t + out_height) + return crop_bbox diff --git a/eval/mv_recon/dataset_utils/transforms.py b/eval/mv_recon/dataset_utils/transforms.py new file mode 100755 index 0000000000000000000000000000000000000000..98df66dd7cdf3b223041673ac430fd79af358da5 --- /dev/null +++ b/eval/mv_recon/dataset_utils/transforms.py @@ -0,0 +1,77 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- + +import torchvision.transforms as tvf +from stream3r.dust3r.utils.image import ImgNorm + + +ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) + + +def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True): + if isinstance(value, (int, float)): + if value < 0: + raise ValueError(f"If is a single number, it must be non negative.") + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + value = [float(value[0]), float(value[1])] + else: + raise TypeError(f"should be a single number or a list/tuple with length 2.") + + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"values should be between {bound}, but got {value}.") + + if value[0] == value[1] == center: + return None + else: + return tuple(value) + + +import torch +import torchvision.transforms.functional as F + + +def SeqColorJitter(): + """ + Return a color jitter transform with same random parameters + """ + brightness = _check_input(0.5) + contrast = _check_input(0.5) + saturation = _check_input(0.5) + hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + + fn_idx = torch.randperm(4) + brightness_factor = ( + None + if brightness is None + else float(torch.empty(1).uniform_(brightness[0], brightness[1])) + ) + contrast_factor = ( + None + if contrast is None + else float(torch.empty(1).uniform_(contrast[0], contrast[1])) + ) + saturation_factor = ( + None + if saturation is None + else float(torch.empty(1).uniform_(saturation[0], saturation[1])) + ) + hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) + + def _color_jitter(img): + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return ImgNorm(img) + + return _color_jitter diff --git a/eval/mv_recon/launch.py b/eval/mv_recon/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..16f263b59204118971ded959b5f926417b334dac --- /dev/null +++ b/eval/mv_recon/launch.py @@ -0,0 +1,426 @@ +import os +import sys +from copy import deepcopy + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +import time +import torch +import argparse +import numpy as np +import open3d as o3d +import os.path as osp + +from torch.utils.data._utils.collate import default_collate +from tqdm import tqdm + +from stream3r.models.stream3r import STream3R +from stream3r.dust3r.utils.geometry import geotrf +from stream3r.models.components.utils.geometry import unproject_depth_map_to_point_map +from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri +from stream3r.utils.utils import ImgDust3r2Stream3r + +from eval.mv_recon.criterion import Regr3D_t_ScaleShiftInv, L21 +from eval.mv_recon.utils import accuracy, completion +from eval.mv_recon.data import SevenScenes, NRGBD + + +torch.backends.cuda.matmul.allow_tf32 = True + +# avoid high cpu usage +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" +os.environ["OPENBLAS_NUM_THREADS"] = "1" +torch.set_num_threads(1) +# =========================================== + + +def get_args_parser(): + parser = argparse.ArgumentParser("3D Reconstruction evaluation", + add_help=False) + parser.add_argument( + "--weights", + type=str, + default="", + help="ckpt name", + ) + parser.add_argument("--device", type=str, default="cuda:0", help="device") + parser.add_argument("--model_name", type=str, default="") + parser.add_argument("--conf_thresh", + type=float, + default=0.0, + help="confidence threshold") + parser.add_argument( + "--output_dir", + type=str, + default="", + help="value for outdir", + ) + parser.add_argument("--size", type=int, default=512) + parser.add_argument("--revisit", type=int, default=1, help="revisit times") + parser.add_argument("--freeze", action="store_true") + return parser + + +def main(args): + if args.size == 518: + resolution = (518, 392) + elif args.size == 512: + resolution = (512, 384) + elif args.size == 224: + resolution = 224 + else: + raise NotImplementedError + + datasets_all = { + "7scenes": + SevenScenes( + split="test", + ROOT="./data/7scenes", + resolution=resolution, + num_seq=1, + full_video=True, + kf_every=200, + ), # 20), + "NRGBD": + NRGBD( + split="test", + ROOT="./data/neural_rgbd", + resolution=resolution, + num_seq=1, + full_video=True, + kf_every=500, + ), + } + + device = 'cuda' + model_name = args.model_name + + device = "cuda" if torch.cuda.is_available() else "cpu" + + model = STream3R.from_pretrained("yslan/STream3R").to(device) + model.eval() + + os.makedirs(args.output_dir, exist_ok=True) + + criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True) + + with torch.no_grad(): + for name_data, dataset in datasets_all.items(): + save_path = osp.join(args.output_dir, name_data) + os.makedirs(save_path, exist_ok=True) + log_file = osp.join(save_path, + f"logs_0.txt") + + acc_all = 0 + acc_all_med = 0 + comp_all = 0 + comp_all_med = 0 + nc1_all = 0 + nc1_all_med = 0 + nc2_all = 0 + nc2_all_med = 0 + + idxs = list(range(len(dataset))) + for data_idx in tqdm(idxs): + batch = default_collate([dataset[data_idx]]) + ignore_keys = set([ + "depthmap", + "dataset", + "label", + "instance", + "idx", + "true_shape", + "rng", + ]) + for view in batch: + for name in view.keys(): # pseudo_focal + if name in ignore_keys: + continue + if isinstance(view[name], tuple) or isinstance( + view[name], list): + view[name] = [ + x.to(device, non_blocking=True) + for x in view[name] + ] + else: + view[name] = view[name].to(device, + non_blocking=True) + + if model_name == "ours" or model_name == "stream3r": + revisit = args.revisit + update = not args.freeze + if revisit > 1: + # repeat input for 'revisit' times + new_views = [] + for r in range(revisit): + for i in range(len(batch)): + new_view = deepcopy(batch[i]) + new_view["idx"] = [ + (r * len(batch) + i) + for _ in range(len(batch[i]["idx"])) + ] + new_view["instance"] = [ + str(r * len(batch) + i) for _ in range( + len(batch[i]["instance"])) + ] + if r > 0: + if not update: + new_view[ + "update"] = torch.zeros_like( + batch[i]["update"]).bool() + new_views.append(new_view) + batch = new_views + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + batch_cpu = [ + { + k: v.to('cpu') if isinstance(v, torch.Tensor) else v for k, v in sample.items() + } for sample in batch + ] + # move all stuffs in batch to cuda + with torch.autocast('cuda', enabled=False): + images = torch.cat([item['img'] for item in batch]) + images = ImgDust3r2Stream3r(images).to(device) + + with torch.no_grad(): + predictions = model(images) + + extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], predictions["images"].shape[-2:]) + world_points_from_depth = unproject_depth_map_to_point_map( + predictions["depth"].cpu().numpy().squeeze(0), + extrinsic.cpu().numpy().squeeze(0), + intrinsic.cpu().numpy().squeeze(0) + ) + world_points_from_depth = torch.from_numpy(world_points_from_depth).unsqueeze(0).to(device=device) + + preds = world_points_from_depth + confs = predictions["depth_conf"] + + all_preds = [] + for idx in range(preds.shape[1]): + all_preds.append( + {'pts3d': preds[0][idx:idx+1].cpu(), 'conf': confs[0][idx:idx+1]} + ) + # convert preds into list + views = batch_cpu + preds = all_preds + + valid_length = len(preds) // revisit + preds = preds[-valid_length:] + batch = batch[-valid_length:] + + # Evaluation + print( + f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}" + ) + gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( + criterion.get_all_pts3d_t(views, preds)) + pred_scale, gt_scale, pred_shift_z, gt_shift_z = ( + monitoring["pred_scale"], + monitoring["gt_scale"], + monitoring["pred_shift_z"], + monitoring["gt_shift_z"], + ) + + in_camera1 = None + pts_all = [] + pts_gt_all = [] + images_all = [] + masks_all = [] + conf_all = [] + + for j, view in enumerate(batch): + if in_camera1 is None: + in_camera1 = view["camera_pose"][0].cpu() + + image = view["img"].permute(0, 2, 3, + 1).cpu().numpy()[0] + mask = view["valid_mask"].cpu().numpy()[0] + + pts = pred_pts[j].cpu().numpy()[0] + conf = preds[j]["conf"].cpu().data.numpy()[0] + # mask = mask & (conf > 1.8) + + pts_gt = gt_pts[j].detach().cpu().numpy()[0] + + H, W = image.shape[:2] + cx = W // 2 + cy = H // 2 + l, t = cx - 112, cy - 112 + r, b = cx + 112, cy + 112 + image = image[t:b, l:r] + mask = mask[t:b, l:r] + pts = pts[t:b, l:r] + pts_gt = pts_gt[t:b, l:r] + + #### Align predicted 3D points to the ground truth + pts[..., -1] += gt_shift_z.cpu().numpy().item() + pts = geotrf(in_camera1, pts) + + pts_gt[..., -1] += gt_shift_z.cpu().numpy().item() + pts_gt = geotrf(in_camera1, pts_gt) + + images_all.append((image[None, ...] + 1.0) / 2.0) + pts_all.append(pts[None, ...]) + pts_gt_all.append(pts_gt[None, ...]) + masks_all.append(mask[None, ...]) + conf_all.append(conf[None, ...]) + + images_all = np.concatenate(images_all, axis=0) + pts_all = np.concatenate(pts_all, axis=0) + pts_gt_all = np.concatenate(pts_gt_all, axis=0) + masks_all = np.concatenate(masks_all, axis=0) + + scene_id = view["label"][0].rsplit("/", 1)[0] + + save_params = {} + + save_params["images_all"] = images_all + save_params["pts_all"] = pts_all + save_params["pts_gt_all"] = pts_gt_all + save_params["masks_all"] = masks_all + + np.save( + os.path.join(save_path, + f"{scene_id.replace('/', '_')}.npy"), + save_params, + ) + + if "DTU" in name_data: + threshold = 100 + else: + threshold = 0.1 + + pts_all_masked = pts_all[masks_all > 0] + pts_gt_all_masked = pts_gt_all[masks_all > 0] + images_all_masked = images_all[masks_all > 0] + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector( + pts_all_masked.reshape(-1, 3)) + pcd.colors = o3d.utility.Vector3dVector( + images_all_masked.reshape(-1, 3)) + o3d.io.write_point_cloud( + os.path.join(save_path, + f"{scene_id.replace('/', '_')}-mask.ply"), + pcd, + ) + + pcd_gt = o3d.geometry.PointCloud() + pcd_gt.points = o3d.utility.Vector3dVector( + pts_gt_all_masked.reshape(-1, 3)) + pcd_gt.colors = o3d.utility.Vector3dVector( + images_all_masked.reshape(-1, 3)) + o3d.io.write_point_cloud( + os.path.join(save_path, + f"{scene_id.replace('/', '_')}-gt.ply"), + pcd_gt, + ) + + trans_init = np.eye(4) + + reg_p2p = o3d.pipelines.registration.registration_icp( + pcd, + pcd_gt, + threshold, + trans_init, + o3d.pipelines.registration. + TransformationEstimationPointToPoint(), + ) + + transformation = reg_p2p.transformation + + pcd = pcd.transform(transformation) + pcd.estimate_normals() + pcd_gt.estimate_normals() + + gt_normal = np.asarray(pcd_gt.normals) + pred_normal = np.asarray(pcd.normals) + + acc, acc_med, nc1, nc1_med = accuracy( + pcd_gt.points, pcd.points, gt_normal, pred_normal) + comp, comp_med, nc2, nc2_med = completion( + pcd_gt.points, pcd.points, gt_normal, pred_normal) + print( + f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}" + ) + print( + f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}", + file=open(log_file, "a"), + ) + + acc_all += acc + comp_all += comp + nc1_all += nc1 + nc2_all += nc2 + + acc_all_med += acc_med + comp_all_med += comp_med + nc1_all_med += nc1_med + nc2_all_med += nc2_med + + # release cuda memory + torch.cuda.empty_cache() + + to_write = "" + # Copy the error log from each process to the main error log + for i in range(8): + if not os.path.exists(osp.join(save_path, + f"logs_{i}.txt")): + break + with open(osp.join(save_path, f"logs_{i}.txt"), + "r") as f_sub: + to_write += f_sub.read() + + with open(osp.join(save_path, f"logs_all.txt"), "w") as f: + log_data = to_write + metrics = defaultdict(list) + for line in log_data.strip().split("\n"): + match = regex.match(line) + if match: + data = match.groupdict() + # Exclude 'scene_id' from metrics as it's an identifier + for key, value in data.items(): + if key != "scene_id": + metrics[key].append(float(value)) + metrics["nc"].append( + (float(data["nc1"]) + float(data["nc2"])) / 2) + metrics["nc_med"].append( + (float(data["nc1_med"]) + + float(data["nc2_med"])) / 2) + mean_metrics = { + metric: sum(values) / len(values) + for metric, values in metrics.items() + } + + c_name = "mean" + print_str = f"{c_name.ljust(20)}: " + for m_name in mean_metrics: + print_num = np.mean(mean_metrics[m_name]) + print_str = print_str + f"{m_name}: {print_num:.3f} | " + print_str = print_str + "\n" + f.write(to_write + print_str) + + +from collections import defaultdict +import re + +pattern = r""" + Idx:\s*(?P[^,]+),\s* + Acc:\s*(?P[^,]+),\s* + Comp:\s*(?P[^,]+),\s* + NC1:\s*(?P[^,]+),\s* + NC2:\s*(?P[^,]+)\s*-\s* + Acc_med:\s*(?P[^,]+),\s* + Compc_med:\s*(?P[^,]+),\s* + NC1c_med:\s*(?P[^,]+),\s* + NC2c_med:\s*(?P[^,]+) +""" + +regex = re.compile(pattern, re.VERBOSE) + +if __name__ == "__main__": + parser = get_args_parser() + args = parser.parse_args() + + main(args) diff --git a/eval/mv_recon/run.sh b/eval/mv_recon/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..51f1b319965876a7a8b36d71378d89fed940b2ed --- /dev/null +++ b/eval/mv_recon/run.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -e + +workdir='.' + +model_name='stream3r' + +output_dir="${workdir}/eval_results/mv_recon/${model_name}/" +echo "$output_dir" + +python eval/mv_recon/launch.py \ + --output_dir="$output_dir" \ + --size=518 \ + --model_name="stream3r" \ \ No newline at end of file diff --git a/eval/mv_recon/utils.py b/eval/mv_recon/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..546f8535e9d2f9ea25efec4d12b7b655bf940568 --- /dev/null +++ b/eval/mv_recon/utils.py @@ -0,0 +1,59 @@ +import numpy as np +from scipy.spatial import cKDTree as KDTree +import torch + + +def completion_ratio(gt_points, rec_points, dist_th=0.05): + gen_points_kd_tree = KDTree(rec_points) + distances, _ = gen_points_kd_tree.query(gt_points) + comp_ratio = np.mean((distances < dist_th).astype(np.float32)) + return comp_ratio + + +def accuracy(gt_points, rec_points, gt_normals=None, rec_normals=None): + gt_points_kd_tree = KDTree(gt_points) + distances, idx = gt_points_kd_tree.query(rec_points, workers=-1) + acc = np.mean(distances) + + acc_median = np.median(distances) + + if gt_normals is not None and rec_normals is not None: + normal_dot = np.sum(gt_normals[idx] * rec_normals, axis=-1) + normal_dot = np.abs(normal_dot) + + return acc, acc_median, np.mean(normal_dot), np.median(normal_dot) + + return acc, acc_median + + +def completion(gt_points, rec_points, gt_normals=None, rec_normals=None): + gt_points_kd_tree = KDTree(rec_points) + distances, idx = gt_points_kd_tree.query(gt_points, workers=-1) + comp = np.mean(distances) + comp_median = np.median(distances) + + if gt_normals is not None and rec_normals is not None: + normal_dot = np.sum(gt_normals * rec_normals[idx], axis=-1) + normal_dot = np.abs(normal_dot) + + return comp, comp_median, np.mean(normal_dot), np.median(normal_dot) + + return comp, comp_median + + +def compute_iou(pred_vox, target_vox): + # Get voxel indices + v_pred_indices = [voxel.grid_index for voxel in pred_vox.get_voxels()] + v_target_indices = [voxel.grid_index for voxel in target_vox.get_voxels()] + + # Convert to sets for set operations + v_pred_filled = set(tuple(np.round(x, 4)) for x in v_pred_indices) + v_target_filled = set(tuple(np.round(x, 4)) for x in v_target_indices) + + # Compute intersection and union + intersection = v_pred_filled & v_target_filled + union = v_pred_filled | v_target_filled + + # Compute IoU + iou = len(intersection) / len(union) + return iou diff --git a/eval/relpose/evo_utils.py b/eval/relpose/evo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..513c6d05b6ed6fe47289297aa3fb668097d4a193 --- /dev/null +++ b/eval/relpose/evo_utils.py @@ -0,0 +1,427 @@ +import os +import re +from copy import deepcopy +from pathlib import Path + +import evo.main_ape as main_ape +import evo.main_rpe as main_rpe +import matplotlib.pyplot as plt +import numpy as np +from evo.core import sync +from evo.core.metrics import PoseRelation, Unit +from evo.core.trajectory import PosePath3D, PoseTrajectory3D +from evo.tools import file_interface, plot +from scipy.spatial.transform import Rotation +from evo.core import metrics + + +def sintel_cam_read(filename): + """Read camera data, return (M,N) tuple. + + M is the intrinsic matrix, N is the extrinsic matrix, so that + + x = M*N*X, + with x being a point in homogeneous image pixel coordinates, X being a + point in homogeneous world coordinates. + """ + TAG_FLOAT = 202021.25 + + f = open(filename, "rb") + check = np.fromfile(f, dtype=np.float32, count=1)[0] + assert ( + check == TAG_FLOAT + ), " cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format( + TAG_FLOAT, check + ) + M = np.fromfile(f, dtype="float64", count=9).reshape((3, 3)) + N = np.fromfile(f, dtype="float64", count=12).reshape((3, 4)) + return M, N + + +def load_replica_traj(gt_file): + traj_w_c = np.loadtxt(gt_file) + assert traj_w_c.shape[1] == 12 or traj_w_c.shape[1] == 16 + poses = [ + np.array( + [ + [r[0], r[1], r[2], r[3]], + [r[4], r[5], r[6], r[7]], + [r[8], r[9], r[10], r[11]], + [0, 0, 0, 1], + ] + ) + for r in traj_w_c + ] + + pose_path = PosePath3D(poses_se3=poses) + timestamps_mat = np.arange(traj_w_c.shape[0]).astype(float) + + traj = PoseTrajectory3D(poses_se3=pose_path.poses_se3, timestamps=timestamps_mat) + xyz = traj.positions_xyz + # shift -1 column -> w in back column + # quat = np.roll(traj.orientations_quat_wxyz, -1, axis=1) + # uncomment this line if the quaternion is in scalar-first format + quat = traj.orientations_quat_wxyz + + traj_tum = np.column_stack((xyz, quat)) + return (traj_tum, timestamps_mat) + + +def load_sintel_traj(gt_file): # './data/sintel/training/camdata_left/alley_2' + # Refer to ParticleSfM + gt_pose_lists = sorted(os.listdir(gt_file)) + gt_pose_lists = [ + os.path.join(gt_file, x) for x in gt_pose_lists if x.endswith(".cam") + ] + tstamps = [float(x.split("/")[-1][:-4].split("_")[-1]) for x in gt_pose_lists] + gt_poses = [ + sintel_cam_read(f)[1] for f in gt_pose_lists + ] # [1] means get the extrinsic + xyzs, wxyzs = [], [] + tum_gt_poses = [] + for gt_pose in gt_poses: + gt_pose = np.concatenate([gt_pose, np.array([[0, 0, 0, 1]])], 0) + gt_pose_inv = np.linalg.inv(gt_pose) # world2cam -> cam2world + xyz = gt_pose_inv[:3, -1] + xyzs.append(xyz) + R = Rotation.from_matrix(gt_pose_inv[:3, :3]) + xyzw = R.as_quat() # scalar-last for scipy + wxyz = np.array([xyzw[-1], xyzw[0], xyzw[1], xyzw[2]]) + wxyzs.append(wxyz) + tum_gt_pose = np.concatenate([xyz, wxyz], 0) # TODO: check if this is correct + tum_gt_poses.append(tum_gt_pose) + + tum_gt_poses = np.stack(tum_gt_poses, 0) + tum_gt_poses[:, :3] = tum_gt_poses[:, :3] - np.mean( + tum_gt_poses[:, :3], 0, keepdims=True + ) + tt = np.expand_dims(np.stack(tstamps, 0), -1) + return tum_gt_poses, tt + + +def load_traj(gt_traj_file, traj_format="sintel", skip=0, stride=1, num_frames=None): + """Read trajectory format. Return in TUM-RGBD format. + Returns: + traj_tum (N, 7): camera to world poses in (x,y,z,qx,qy,qz,qw) + timestamps_mat (N, 1): timestamps + """ + if traj_format == "replica": + traj_tum, timestamps_mat = load_replica_traj(gt_traj_file) + elif traj_format == "sintel": + traj_tum, timestamps_mat = load_sintel_traj(gt_traj_file) + elif traj_format in ["tum", "tartanair"]: + traj = file_interface.read_tum_trajectory_file(gt_traj_file) + xyz = traj.positions_xyz + quat = traj.orientations_quat_wxyz + timestamps_mat = traj.timestamps + traj_tum = np.column_stack((xyz, quat)) + else: + raise NotImplementedError + + traj_tum = traj_tum[skip::stride] + timestamps_mat = timestamps_mat[skip::stride] + if num_frames is not None: + traj_tum = traj_tum[:num_frames] + timestamps_mat = timestamps_mat[:num_frames] + return traj_tum, timestamps_mat + + +def update_timestamps(gt_file, traj_format, skip=0, stride=1): + """Update timestamps given a""" + if traj_format == "tum": + traj_t_map_file = gt_file.replace("groundtruth.txt", "rgb.txt") + timestamps = load_timestamps(traj_t_map_file, traj_format) + return timestamps[skip::stride] + elif traj_format == "tartanair": + traj_t_map_file = gt_file.replace("gt_pose.txt", "times.txt") + timestamps = load_timestamps(traj_t_map_file, traj_format) + return timestamps[skip::stride] + + +def load_timestamps(time_file, traj_format="replica"): + if traj_format in ["tum", "tartanair"]: + with open(time_file, "r+") as f: + lines = f.readlines() + timestamps_mat = [ + float(x.split(" ")[0]) for x in lines if not x.startswith("#") + ] + return timestamps_mat + + +def make_traj(args) -> PoseTrajectory3D: + if isinstance(args, tuple) or isinstance(args, list): + traj, tstamps = args + return PoseTrajectory3D( + positions_xyz=traj[:, :3], + orientations_quat_wxyz=traj[:, 3:], + timestamps=tstamps, + ) + assert isinstance(args, PoseTrajectory3D), type(args) + return deepcopy(args) + + +def eval_metrics(pred_traj, gt_traj=None, seq="", filename="", sample_stride=1): + + if sample_stride > 1: + pred_traj[0] = pred_traj[0][::sample_stride] + pred_traj[1] = pred_traj[1][::sample_stride] + if gt_traj is not None: + updated_gt_traj = [] + updated_gt_traj.append(gt_traj[0][::sample_stride]) + updated_gt_traj.append(gt_traj[1][::sample_stride]) + gt_traj = updated_gt_traj + + pred_traj = make_traj(pred_traj) + + if gt_traj is not None: + gt_traj = make_traj(gt_traj) + + if pred_traj.timestamps.shape[0] == gt_traj.timestamps.shape[0]: + pred_traj.timestamps = gt_traj.timestamps + else: + print(pred_traj.timestamps.shape[0], gt_traj.timestamps.shape[0]) + + gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj) + + # ATE + traj_ref = gt_traj + traj_est = pred_traj + + ate_result = main_ape.ape( + traj_ref, + traj_est, + est_name="traj", + pose_relation=PoseRelation.translation_part, + align=True, + correct_scale=True, + ) + + ate = ate_result.stats["rmse"] + # print(ate_result.np_arrays['error_array']) + # exit() + + # RPE rotation and translation + delta_list = [1] + rpe_rots, rpe_transs = [], [] + for delta in delta_list: + rpe_rots_result = main_rpe.rpe( + traj_ref, + traj_est, + est_name="traj", + pose_relation=PoseRelation.rotation_angle_deg, + align=True, + correct_scale=True, + delta=delta, + delta_unit=Unit.frames, + rel_delta_tol=0.01, + all_pairs=True, + ) + + rot = rpe_rots_result.stats["rmse"] + rpe_rots.append(rot) + + for delta in delta_list: + rpe_transs_result = main_rpe.rpe( + traj_ref, + traj_est, + est_name="traj", + pose_relation=PoseRelation.translation_part, + align=True, + correct_scale=True, + delta=delta, + delta_unit=Unit.frames, + rel_delta_tol=0.01, + all_pairs=True, + ) + + trans = rpe_transs_result.stats["rmse"] + rpe_transs.append(trans) + + rpe_trans, rpe_rot = np.mean(rpe_transs), np.mean(rpe_rots) + with open(filename, "w+") as f: + f.write(f"Seq: {seq} \n\n") + f.write(f"{ate_result}") + f.write(f"{rpe_rots_result}") + f.write(f"{rpe_transs_result}") + + print(f"Save results to {filename}") + return ate, rpe_trans, rpe_rot + + +def eval_metrics_first_pose_align_last_pose( + pred_traj, gt_traj=None, seq="", filename="", figpath="", sample_stride=1 +): + if sample_stride > 1: + pred_traj[0] = pred_traj[0][::sample_stride] + pred_traj[1] = pred_traj[1][::sample_stride] + if gt_traj is not None: + gt_traj = [gt_traj[0][::sample_stride], gt_traj[1][::sample_stride]] + pred_traj = make_traj(pred_traj) + if gt_traj is not None: + gt_traj = make_traj(gt_traj) + + if pred_traj.timestamps.shape[0] == gt_traj.timestamps.shape[0]: + pred_traj.timestamps = gt_traj.timestamps + else: + print( + "Different number of poses:", + pred_traj.timestamps.shape[0], + gt_traj.timestamps.shape[0], + ) + + gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj) + + if gt_traj is not None and pred_traj is not None: + if len(gt_traj.poses_se3) > 0 and len(pred_traj.poses_se3) > 0: + first_gt_pose = gt_traj.poses_se3[0] + first_pred_pose = pred_traj.poses_se3[0] + # T = (first_gt_pose) * inv(first_pred_pose) + T = first_gt_pose @ np.linalg.inv(first_pred_pose) + + # Apply T to every predicted pose + aligned_pred_poses = [] + for pose in pred_traj.poses_se3: + aligned_pred_poses.append(T @ pose) + aligned_pred_traj = PoseTrajectory3D( + poses_se3=aligned_pred_poses, + timestamps=np.array(pred_traj.timestamps), + # optionally copy other fields if your make_traj object has them + ) + pred_traj = aligned_pred_traj # .poses_se3 = aligned_pred_poses + plot_trajectory( + pred_traj, + gt_traj, + title=seq, + filename=figpath, + align=False, + correct_scale=False, + ) + + if gt_traj is not None and len(gt_traj.poses_se3) > 0: + gt_traj = PoseTrajectory3D( + poses_se3=[gt_traj.poses_se3[-1]], timestamps=[gt_traj.timestamps[-1]] + ) + if pred_traj is not None and len(pred_traj.poses_se3) > 0: + pred_traj = PoseTrajectory3D( + poses_se3=[pred_traj.poses_se3[-1]], timestamps=[pred_traj.timestamps[-1]] + ) + + ate_result = main_ape.ape( + gt_traj, + pred_traj, + est_name="traj", + pose_relation=PoseRelation.translation_part, + align=False, # <-- important + correct_scale=False, # <-- important + ) + ate = ate_result.stats["rmse"] + with open(filename, "w+") as f: + f.write(f"Seq: {seq}\n\n") + f.write(f"{ate_result}") + + print(f"Save results to {filename}") + + return ate + + +def best_plotmode(traj): + _, i1, i2 = np.argsort(np.var(traj.positions_xyz, axis=0)) + plot_axes = "xyz"[i2] + "xyz"[i1] + return getattr(plot.PlotMode, plot_axes) + + +def plot_trajectory( + pred_traj, gt_traj=None, title="", filename="", align=True, correct_scale=True +): + pred_traj = make_traj(pred_traj) + + if gt_traj is not None: + gt_traj = make_traj(gt_traj) + if pred_traj.timestamps.shape[0] == gt_traj.timestamps.shape[0]: + pred_traj.timestamps = gt_traj.timestamps + else: + print("WARNING", pred_traj.timestamps.shape[0], gt_traj.timestamps.shape[0]) + + gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj) + + if align: + pred_traj.align(gt_traj, correct_scale=correct_scale) + + plot_collection = plot.PlotCollection("PlotCol") + fig = plt.figure(figsize=(8, 8)) + plot_mode = best_plotmode(gt_traj if (gt_traj is not None) else pred_traj) + ax = plot.prepare_axis(fig, plot_mode) + ax.set_title(title) + if gt_traj is not None: + plot.traj(ax, plot_mode, gt_traj, "--", "gray", "Ground Truth") + plot.traj(ax, plot_mode, pred_traj, "-", "blue", "Predicted") + plot_collection.add_figure("traj_error", fig) + plot_collection.export(filename, confirm_overwrite=False) + plt.close(fig=fig) + print(f"Saved trajectory to {filename.replace('.png','')}_traj_error.png") + + +def save_trajectory_tum_format(traj, filename): + traj = make_traj(traj) + tostr = lambda a: " ".join(map(str, a)) + with Path(filename).open("w") as f: + for i in range(traj.num_poses): + f.write( + f"{traj.timestamps[i]} {tostr(traj.positions_xyz[i])} {tostr(traj.orientations_quat_wxyz[i][[0,1,2,3]])}\n" + ) + print(f"Saved trajectory to {filename}") + + +def extract_metrics(file_path): + with open(file_path, "r") as file: + content = file.read() + + # Extract metrics using regex + ate_match = re.search( + r"APE w.r.t. translation part \(m\).*?rmse\s+([0-9.]+)", content, re.DOTALL + ) + rpe_trans_match = re.search( + r"RPE w.r.t. translation part \(m\).*?rmse\s+([0-9.]+)", content, re.DOTALL + ) + rpe_rot_match = re.search( + r"RPE w.r.t. rotation angle in degrees \(deg\).*?rmse\s+([0-9.]+)", + content, + re.DOTALL, + ) + + ate = float(ate_match.group(1)) if ate_match else 0.0 + rpe_trans = float(rpe_trans_match.group(1)) if rpe_trans_match else 0.0 + rpe_rot = float(rpe_rot_match.group(1)) if rpe_rot_match else 0.0 + + return ate, rpe_trans, rpe_rot + + +def process_directory(directory): + results = [] + for root, _, files in os.walk(directory): + if files is not None: + files = sorted(files) + for file in files: + if file.endswith("_metric.txt"): + file_path = os.path.join(root, file) + seq_name = file.replace("_eval_metric.txt", "") + ate, rpe_trans, rpe_rot = extract_metrics(file_path) + results.append((seq_name, ate, rpe_trans, rpe_rot)) + + return results + + +def calculate_averages(results): + total_ate = sum(r[1] for r in results) + total_rpe_trans = sum(r[2] for r in results) + total_rpe_rot = sum(r[3] for r in results) + count = len(results) + + if count == 0: + return 0.0, 0.0, 0.0 + + avg_ate = total_ate / count + avg_rpe_trans = total_rpe_trans / count + avg_rpe_rot = total_rpe_rot / count + + return avg_ate, avg_rpe_trans, avg_rpe_rot diff --git a/eval/relpose/launch.py b/eval/relpose/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..505ea29e3817a6e937728d3dd7c11094f35373f3 --- /dev/null +++ b/eval/relpose/launch.py @@ -0,0 +1,272 @@ +import os +import sys +import torch +import argparse +from tqdm import tqdm +from accelerate import PartialState + +from stream3r.models.stream3r import STream3R +from stream3r.stream_session import StreamSession +from stream3r.dust3r.utils.image import load_images_for_eval as load_images +from stream3r.dust3r.utils.device import collate_with_cat +from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri +from stream3r.dust3r.utils.geometry import inv +from stream3r.utils.utils import ImgDust3r2Stream3r + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +from eval.relpose.metadata import dataset_metadata +from eval.relpose.utils import * + + +torch.backends.cuda.matmul.allow_tf32 = True + +# avoid high cpu usage +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" +os.environ["OPENBLAS_NUM_THREADS"] = "1" +torch.set_num_threads(1) +# =========================================== + + +def get_args_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument("--device", + type=str, + default="cuda", + help="pytorch device") + parser.add_argument( + "--output_dir", + type=str, + default="", + help="value for outdir", + ) + parser.add_argument("--no_crop", + type=bool, + default=True, + help="whether to crop input data") + + parser.add_argument( + "--eval_dataset", + type=str, + default="sintel", + choices=list(dataset_metadata.keys()), + ) + parser.add_argument("--size", type=int, default="224") + + parser.add_argument("--pose_eval_stride", + default=1, + type=int, + help="stride for pose evaluation") + parser.add_argument("--shuffle", action="store_true", default=False) + parser.add_argument( + "--full_seq", + action="store_true", + default=False, + help="use full sequence for pose evaluation", + ) + parser.add_argument( + "--seq_list", + nargs="+", + default=None, + help="list of sequences for pose evaluation", + ) + + parser.add_argument("--freeze_state", action="store_true", default=False) + return parser + + +def eval_pose_estimation_dist(args, + model, + img_path, + save_dir=None, + mask_path=None): + + metadata = dataset_metadata.get(args.eval_dataset) + anno_path = metadata.get("anno_path", None) + + seq_list = args.seq_list + if seq_list is None: + if metadata.get("full_seq", False): + args.full_seq = True + else: + seq_list = metadata.get("seq_list", []) + if args.full_seq: + seq_list = os.listdir(img_path) + seq_list = [ + seq for seq in seq_list + if os.path.isdir(os.path.join(img_path, seq)) + ] + seq_list = sorted(seq_list) + + if save_dir is None: + save_dir = args.output_dir + + distributed_state = PartialState() + model.to(distributed_state.device) + device = distributed_state.device + + with distributed_state.split_between_processes(seq_list) as seqs: + ate_list = [] + rpe_trans_list = [] + rpe_rot_list = [] + error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt" # Unique log file per process + for seq in tqdm(seqs): + try: + dir_path = metadata["dir_path_func"](img_path, seq) + + # Handle skip_condition + skip_condition = metadata.get("skip_condition", None) + if skip_condition is not None and skip_condition( + save_dir, seq): + continue + + mask_path_seq_func = metadata.get("mask_path_seq_func", + lambda mask_path, seq: None) + mask_path_seq = mask_path_seq_func(mask_path, seq) + + filelist = [ + os.path.join(dir_path, name) + for name in os.listdir(dir_path) + ] + filelist.sort() + filelist = filelist[::args.pose_eval_stride] + + images = load_images( + filelist, + size=518, + verbose=True, + crop=False, + patch_size=14, + ) + + images = collate_with_cat([tuple(images)]) + images = torch.stack([view["img"] for view in images], dim=1) + images = ImgDust3r2Stream3r(images).to(device) + + with torch.no_grad(): + session = StreamSession(model, mode="causal") + for i in range(images.shape[1]): + image = images[:, i:i+1] + predictions = session.forward_stream(image) + + extrinsic, _ = pose_encoding_to_extri_intri(predictions["pose_enc"], predictions["images"].shape[-2:]) + + pr_poses = [] + for i in range(extrinsic.shape[1]): + pr_poses.append(inv(torch.cat([extrinsic[0, i], torch.tensor([[0, 0, 0, 1]], device=device)], dim=0))) + + pred_traj = get_tum_poses(pr_poses) + os.makedirs(f"{save_dir}/{seq}", exist_ok=True) + save_tum_poses(pr_poses, f"{save_dir}/{seq}/pred_traj.txt") + + gt_traj_file = metadata["gt_traj_func"](img_path, anno_path, + seq) + traj_format = metadata.get("traj_format", None) + + if args.eval_dataset == "sintel": + gt_traj = load_traj(gt_traj_file=gt_traj_file, + stride=args.pose_eval_stride) + elif traj_format is not None: + gt_traj = load_traj( + gt_traj_file=gt_traj_file, + traj_format=traj_format, + stride=args.pose_eval_stride, + ) + else: + gt_traj = None + + if gt_traj is not None: + ate, rpe_trans, rpe_rot = eval_metrics( + pred_traj, + gt_traj, + seq=seq, + filename=f"{save_dir}/{seq}_eval_metric.txt", + ) + plot_trajectory(pred_traj, + gt_traj, + title=seq, + filename=f"{save_dir}/{seq}.png") + else: + ate, rpe_trans, rpe_rot = 0, 0, 0 + bug = True + + ate_list.append(ate) + rpe_trans_list.append(rpe_trans) + rpe_rot_list.append(rpe_rot) + + # Write to error log after each sequence + with open(error_log_path, "a") as f: + f.write( + f"{args.eval_dataset}-{seq: <16} | ATE: {ate:.5f}, RPE trans: {rpe_trans:.5f}, RPE rot: {rpe_rot:.5f}\n" + ) + f.write(f"{ate:.5f}\n") + f.write(f"{rpe_trans:.5f}\n") + f.write(f"{rpe_rot:.5f}\n") + + except Exception as e: + if "out of memory" in str(e): + # Handle OOM + torch.cuda.empty_cache() # Clear the CUDA memory + with open(error_log_path, "a") as f: + f.write( + f"OOM error in sequence {seq}, skipping this sequence.\n" + ) + print(f"OOM error in sequence {seq}, skipping...") + elif "Degenerate covariance rank" in str( + e) or "Eigenvalues did not converge" in str(e): + # Handle Degenerate covariance rank exception and Eigenvalues did not converge exception + with open(error_log_path, "a") as f: + f.write(f"Exception in sequence {seq}: {str(e)}\n") + print( + f"Traj evaluation error in sequence {seq}, skipping.") + else: + raise e # Rethrow if it's not an expected exception + + distributed_state.wait_for_everyone() + + results = process_directory(save_dir) + avg_ate, avg_rpe_trans, avg_rpe_rot = calculate_averages(results) + + # Write the averages to the error log (only on the main process) + if distributed_state.is_main_process: + with open(f"{save_dir}/_error_log.txt", "a") as f: + # Copy the error log from each process to the main error log + for i in range(distributed_state.num_processes): + if not os.path.exists(f"{save_dir}/_error_log_{i}.txt"): + break + with open(f"{save_dir}/_error_log_{i}.txt", "r") as f_sub: + f.write(f_sub.read()) + f.write( + f"Average ATE: {avg_ate:.5f}, Average RPE trans: {avg_rpe_trans:.5f}, Average RPE rot: {avg_rpe_rot:.5f}\n" + ) + + return avg_ate, avg_rpe_trans, avg_rpe_rot + + +def eval_pose_estimation(args, model, save_dir=None): + metadata = dataset_metadata.get(args.eval_dataset) + img_path = metadata["img_path"] + mask_path = metadata["mask_path"] + + ate_mean, rpe_trans_mean, rpe_rot_mean = eval_pose_estimation_dist( + args, model, save_dir=save_dir, img_path=img_path, mask_path=mask_path) + return ate_mean, rpe_trans_mean, rpe_rot_mean + + +def main(): + args = get_args_parser() + args = args.parse_args() + + args.full_seq = False + args.no_crop = False + + model = STream3R.from_pretrained("yslan/STream3R").to(args.device) + model.eval() + + eval_pose_estimation(args, model, save_dir=args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/eval/relpose/metadata.py b/eval/relpose/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d97c5e21c4671a705e311ecf9bb2f2a8c1e516 --- /dev/null +++ b/eval/relpose/metadata.py @@ -0,0 +1,233 @@ +import os +import glob +from tqdm import tqdm + +# Define the merged dataset metadata dictionary +dataset_metadata = { + "davis": { + "img_path": "data/davis/DAVIS/JPEGImages/480p", + "mask_path": "data/davis/DAVIS/masked_images/480p", + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq), + "gt_traj_func": lambda img_path, anno_path, seq: None, + "traj_format": None, + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq), + "skip_condition": None, + "process_func": None, # Not used in mono depth estimation + }, + "kitti": { + "img_path": "data/kitti/depth_selection/val_selection_cropped/image_gathered", # Default path + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq), + "gt_traj_func": lambda img_path, anno_path, seq: None, + "traj_format": None, + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": lambda args, img_path: process_kitti(args, img_path), + }, + "bonn": { + "img_path": "data/bonn/rgbd_bonn_dataset", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join( + img_path, f"rgbd_bonn_{seq}", "rgb_110" + ), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt" + ), + "traj_format": "tum", + "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"], + "full_seq": False, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": lambda args, img_path: process_bonn(args, img_path), + }, + "nyu": { + "img_path": "data/nyu-v2/val/nyu_images", + "mask_path": None, + "process_func": lambda args, img_path: process_nyu(args, img_path), + }, + "scannet": { + "img_path": "data/scannetv2", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "pose_90.txt" + ), + "traj_format": "replica", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)), + "process_func": lambda args, img_path: process_scannet(args, img_path), + }, + "scannet-257": { + "img_path": "data/scannetv2_3_257", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "pose_90.txt" + ), + "traj_format": "replica", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)), + "process_func": lambda args, img_path: process_scannet(args, img_path), + }, + "scannet-129": { + "img_path": "data/scannetv2_3_129", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "pose_90.txt" + ), + "traj_format": "replica", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)), + "process_func": lambda args, img_path: process_scannet(args, img_path), + }, + "scannet-65": { + "img_path": "data/scannetv2_3_65", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "pose_90.txt" + ), + "traj_format": "replica", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)), + "process_func": lambda args, img_path: process_scannet(args, img_path), + }, + "scannet-33": { + "img_path": "data/scannetv2_3_33", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "pose_90.txt" + ), + "traj_format": "replica", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)), + "process_func": lambda args, img_path: process_scannet(args, img_path), + }, + "tum": { + "img_path": "data/tum", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "groundtruth_90.txt" + ), + "traj_format": "tum", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": None, + }, + "sintel": { + "img_path": "data/sintel/training/final", + "anno_path": "data/sintel/training/camdata_left", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq), + "traj_format": None, + "seq_list": [ + "alley_2", + "ambush_4", + "ambush_5", + "ambush_6", + "cave_2", + "cave_4", + "market_2", + "market_5", + "market_6", + "shaman_3", + "sleeping_1", + "sleeping_2", + "temple_2", + "temple_3", + ], + "full_seq": False, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": lambda args, img_path: process_sintel(args, img_path), + }, +} + + +# Define processing functions for each dataset +def process_kitti(args, img_path): + for dir in tqdm(sorted(glob.glob(f"{img_path}/*"))): + filelist = sorted(glob.glob(f"{dir}/*.png")) + save_dir = f"{args.output_dir}/{os.path.basename(dir)}" + yield filelist, save_dir + + +def process_bonn(args, img_path): + if args.full_seq: + for dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))): + filelist = sorted(glob.glob(f"{dir}/rgb/*.png")) + save_dir = f"{args.output_dir}/{os.path.basename(os.path.dirname(dir))}" + yield filelist, save_dir + else: + seq_list = ( + ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"] + if args.seq_list is None + else args.seq_list + ) + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f"{img_path}/rgbd_bonn_{seq}/rgb_110/*.png")) + save_dir = f"{args.output_dir}/{seq}" + yield filelist, save_dir + + +def process_nyu(args, img_path): + filelist = sorted(glob.glob(f"{img_path}/*.png")) + save_dir = f"{args.output_dir}" + yield filelist, save_dir + + +def process_scannet(args, img_path): + seq_list = sorted(glob.glob(f"{img_path}/*")) + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f"{seq}/color_90/*.jpg")) + save_dir = f"{args.output_dir}/{os.path.basename(seq)}" + yield filelist, save_dir + + +def process_sintel(args, img_path): + if args.full_seq: + for dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))): + filelist = sorted(glob.glob(f"{dir}/*.png")) + save_dir = f"{args.output_dir}/{os.path.basename(os.path.dirname(dir))}" + yield filelist, save_dir + else: + seq_list = [ + "alley_2", + "ambush_4", + "ambush_5", + "ambush_6", + "cave_2", + "cave_4", + "market_2", + "market_5", + "market_6", + "shaman_3", + "sleeping_1", + "sleeping_2", + "temple_2", + "temple_3", + ] + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f"{img_path}/{seq}/*.png")) + save_dir = f"{args.output_dir}/{seq}" + yield filelist, save_dir diff --git a/eval/relpose/run.sh b/eval/relpose/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..ad38bb91f0bfc6b7f231c3acf25f9b4edabb4fea --- /dev/null +++ b/eval/relpose/run.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e + +workdir='.' + +datasets=('tum' 'sintel' 'scannet') +model_name='stream3r' + +for data in "${datasets[@]}"; do + output_dir="${workdir}/eval_results/relpose/${model_name}/${data}" + echo "$output_dir" + accelerate launch --num_processes 1 --main_process_port 29558 eval/relpose/launch.py \ + --output_dir "$output_dir/" \ + --eval_dataset "$data" +done + + diff --git a/eval/relpose/utils.py b/eval/relpose/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e138cf66ea1270086cf9bab1c45f7217c038eb97 --- /dev/null +++ b/eval/relpose/utils.py @@ -0,0 +1,303 @@ +import cv2 +import numpy as np +import torch +import matplotlib as mpl +import matplotlib.cm as cm +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg +from scipy.spatial.transform import Rotation +from eval.relpose.evo_utils import * +from PIL import Image +import imageio.v2 as iio +from matplotlib.figure import Figure + + +def todevice(batch, device, callback=None, non_blocking=False): + """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). + + batch: list, tuple, dict of tensors or other things + device: pytorch device or 'numpy' + callback: function that would be called on every sub-elements. + """ + if callback: + batch = callback(batch) + + if isinstance(batch, dict): + return {k: todevice(v, device) for k, v in batch.items()} + + if isinstance(batch, (tuple, list)): + return type(batch)(todevice(x, device) for x in batch) + + x = batch + if device == "numpy": + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + elif x is not None: + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if torch.is_tensor(x): + x = x.to(device, non_blocking=non_blocking) + return x + + +to_device = todevice # alias + + +def to_numpy(x): + return todevice(x, "numpy") + + +def c2w_to_tumpose(c2w): + """ + Convert a camera-to-world matrix to a tuple of translation and rotation + + input: c2w: 4x4 matrix + output: tuple of translation and rotation (x y z qw qx qy qz) + """ + # convert input to numpy + c2w = to_numpy(c2w) + xyz = c2w[:3, -1] + rot = Rotation.from_matrix(c2w[:3, :3]) + qx, qy, qz, qw = rot.as_quat() + tum_pose = np.concatenate([xyz, [qw, qx, qy, qz]]) + return tum_pose + + +def get_tum_poses(poses): + """ + poses: list of 4x4 arrays + """ + tt = np.arange(len(poses)).astype(float) + tum_poses = [c2w_to_tumpose(p) for p in poses] + tum_poses = np.stack(tum_poses, 0) + return [tum_poses, tt] + + +def save_tum_poses(poses, path): + traj = get_tum_poses(poses) + save_trajectory_tum_format(traj, path) + return traj[0] # return the poses + + +def save_focals(cam_dict, path): + # convert focal to txt + focals = cam_dict["focal"] + np.savetxt(path, focals, fmt="%.6f") + return focals + + +def save_intrinsics(cam_dict, path): + K_raw = np.eye(3)[None].repeat(len(cam_dict["focal"]), axis=0) + K_raw[:, 0, 0] = cam_dict["focal"] + K_raw[:, 1, 1] = cam_dict["focal"] + K_raw[:, :2, 2] = cam_dict["pp"] + K = K_raw.reshape(-1, 9) + np.savetxt(path, K, fmt="%.6f") + return K_raw + + +def save_conf_maps(conf, path): + for i, c in enumerate(conf): + np.save(f"{path}/conf_{i}.npy", c.detach().cpu().numpy()) + return conf + + +def save_rgb_imgs(colors, path): + imgs = colors + for i, img in enumerate(imgs): + # convert from rgb to bgr + iio.imwrite( + f"{path}/frame_{i:04d}.jpg", (img.cpu().numpy() * 255).astype(np.uint8) + ) + return imgs + + +def save_depth_maps(pts3ds_self, path, conf_self=None): + depth_maps = torch.stack([pts3d_self[..., -1] for pts3d_self in pts3ds_self], 0) + min_depth = depth_maps.min() # float(torch.quantile(out, 0.01)) + max_depth = depth_maps.max() # float(torch.quantile(out, 0.99)) + colored_depth = colorize( + depth_maps, + cmap_name="Spectral_r", + range=(min_depth, max_depth), + append_cbar=True, + ) + images = [] + + if conf_self is not None: + conf_selfs = torch.concat(conf_self, 0) + min_conf = torch.log(conf_selfs.min()) # float(torch.quantile(out, 0.01)) + max_conf = torch.log(conf_selfs.max()) # float(torch.quantile(out, 0.99)) + colored_conf = colorize( + torch.log(conf_selfs), + cmap_name="jet", + range=(min_conf, max_conf), + append_cbar=True, + ) + + for i, depth_map in enumerate(colored_depth): + # Apply color map to depth map + img_path = f"{path}/frame_{(i):04d}.png" + if conf_self is None: + to_save = (depth_map * 255).detach().cpu().numpy().astype(np.uint8) + else: + to_save = torch.cat([depth_map, colored_conf[i]], dim=1) + to_save = (to_save * 255).detach().cpu().numpy().astype(np.uint8) + iio.imwrite(img_path, to_save) + images.append(Image.open(img_path)) + np.save(f"{path}/frame_{(i):04d}.npy", depth_maps[i].detach().cpu().numpy()) + + images[0].save( + f"{path}/_depth_maps.gif", + save_all=True, + append_images=images[1:], + duration=100, + loop=0, + ) + + return depth_maps + + +def get_vertical_colorbar(h, vmin, vmax, cmap_name="jet", label=None, cbar_precision=2): + """ + :param w: pixels + :param h: pixels + :param vmin: min value + :param vmax: max value + :param cmap_name: + :param label + :return: + """ + fig = Figure(figsize=(2, 8), dpi=100) + fig.subplots_adjust(right=1.5) + canvas = FigureCanvasAgg(fig) + + # Do some plotting. + ax = fig.add_subplot(111) + cmap = cm.get_cmap(cmap_name) + norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + + tick_cnt = 6 + tick_loc = np.linspace(vmin, vmax, tick_cnt) + cb1 = mpl.colorbar.ColorbarBase( + ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical" + ) + + tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc] + if cbar_precision == 0: + tick_label = [x[:-2] for x in tick_label] + + cb1.set_ticklabels(tick_label) + + cb1.ax.tick_params(labelsize=18, rotation=0) + if label is not None: + cb1.set_label(label) + + # fig.tight_layout() + + canvas.draw() + s, (width, height) = canvas.print_to_buffer() + + im = np.frombuffer(s, np.uint8).reshape((height, width, 4)) + + im = im[:, :, :3].astype(np.float32) / 255.0 + if h != im.shape[0]: + w = int(im.shape[1] / im.shape[0] * h) + im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA) + + return im + + +def colorize_np( + x, + cmap_name="jet", + mask=None, + range=None, + append_cbar=False, + cbar_in_image=False, + cbar_precision=2, +): + """ + turn a grayscale image into a color image + :param x: input grayscale, [H, W] + :param cmap_name: the colorization method + :param mask: the mask image, [H, W] + :param range: the range for scaling, automatic if None, [min, max] + :param append_cbar: if append the color bar + :param cbar_in_image: put the color bar inside the image to keep the output image the same size as the input image + :return: colorized image, [H, W] + """ + if range is not None: + vmin, vmax = range + elif mask is not None: + # vmin, vmax = np.percentile(x[mask], (2, 100)) + vmin = np.min(x[mask][np.nonzero(x[mask])]) + vmax = np.max(x[mask]) + # vmin = vmin - np.abs(vmin) * 0.01 + x[np.logical_not(mask)] = vmin + # print(vmin, vmax) + else: + vmin, vmax = np.percentile(x, (1, 100)) + vmax += 1e-6 + + x = np.clip(x, vmin, vmax) + x = (x - vmin) / (vmax - vmin) + # x = np.clip(x, 0., 1.) + + cmap = cm.get_cmap(cmap_name) + x_new = cmap(x)[:, :, :3] + + if mask is not None: + mask = np.float32(mask[:, :, np.newaxis]) + x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask) + + cbar = get_vertical_colorbar( + h=x.shape[0], + vmin=vmin, + vmax=vmax, + cmap_name=cmap_name, + cbar_precision=cbar_precision, + ) + + if append_cbar: + if cbar_in_image: + x_new[:, -cbar.shape[1] :, :] = cbar + else: + x_new = np.concatenate( + (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1 + ) + return x_new + else: + return x_new + + +# tensor +def colorize( + x, cmap_name="jet", mask=None, range=None, append_cbar=False, cbar_in_image=False +): + """ + turn a grayscale image into a color image + :param x: torch.Tensor, grayscale image, [H, W] or [B, H, W] + :param mask: torch.Tensor or None, mask image, [H, W] or [B, H, W] or None + """ + + device = x.device + x = x.cpu().numpy() + if mask is not None: + mask = mask.cpu().numpy() > 0.99 + kernel = np.ones((3, 3), np.uint8) + + if x.ndim == 2: + x = x[None] + if mask is not None: + mask = mask[None] + + out = [] + for x_ in x: + if mask is not None: + mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool) + + x_ = colorize_np(x_, cmap_name, mask, range, append_cbar, cbar_in_image) + out.append(torch.from_numpy(x_).to(device).float()) + out = torch.stack(out).squeeze(0) + return out diff --git a/eval/video_depth/eval_depth.py b/eval/video_depth/eval_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..bb37e5795459491e5c3ffd404a665cbdf4313293 --- /dev/null +++ b/eval/video_depth/eval_depth.py @@ -0,0 +1,385 @@ +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +from eval.video_depth.tools import depth_evaluation, group_by_directory +import numpy as np +import cv2 +from tqdm import tqdm +import glob +from PIL import Image +import argparse +import json +from eval.video_depth.metadata import dataset_metadata + + +def get_args_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--output_dir", + type=str, + default="", + help="value for outdir", + ) + parser.add_argument( + "--eval_dataset", type=str, default="nyu", choices=list(dataset_metadata.keys()) + ) + parser.add_argument( + "--align", + type=str, + default="scale&shift", + choices=["scale&shift", "scale", "metric"], + ) + return parser + + +def main(args): + if args.eval_dataset == "sintel": + TAG_FLOAT = 202021.25 + + def depth_read(filename): + """Read depth data from file, return as numpy array.""" + f = open(filename, "rb") + check = np.fromfile(f, dtype=np.float32, count=1)[0] + assert ( + check == TAG_FLOAT + ), " depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format( + TAG_FLOAT, check + ) + width = np.fromfile(f, dtype=np.int32, count=1)[0] + height = np.fromfile(f, dtype=np.int32, count=1)[0] + size = width * height + assert ( + width > 0 and height > 0 and size > 1 and size < 100000000 + ), " depth_read:: Wrong input size (width = {0}, height = {1}).".format( + width, height + ) + depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width)) + return depth + + pred_pathes = glob.glob( + f"{args.output_dir}/*/frame_*.npy" + ) # TODO: update the path to your prediction + pred_pathes = sorted(pred_pathes) + + if len(pred_pathes) > 643: + full = True + else: + full = False + + if full: + depth_pathes = glob.glob(f"data/sintel/training/depth/*/*.dpt") + depth_pathes = sorted(depth_pathes) + else: + seq_list = [ + "alley_2", + "ambush_4", + "ambush_5", + "ambush_6", + "cave_2", + "cave_4", + "market_2", + "market_5", + "market_6", + "shaman_3", + "sleeping_1", + "sleeping_2", + "temple_2", + "temple_3", + ] + depth_pathes_folder = [ + f"data/sintel/training/depth/{seq}" for seq in seq_list + ] + depth_pathes = [] + for depth_pathes_folder_i in depth_pathes_folder: + depth_pathes += glob.glob(depth_pathes_folder_i + "/*.dpt") + depth_pathes = sorted(depth_pathes) + + def get_video_results(): + grouped_pred_depth = group_by_directory(pred_pathes) + + grouped_gt_depth = group_by_directory(depth_pathes) + gathered_depth_metrics = [] + + for key in tqdm(grouped_pred_depth.keys()): + pd_pathes = grouped_pred_depth[key] + gt_pathes = grouped_gt_depth[key.replace("_pred_depth", "")] + + gt_depth = np.stack( + [depth_read(gt_path) for gt_path in gt_pathes], axis=0 + ) + pr_depth = np.stack( + [ + cv2.resize( + np.load(pd_path), + (gt_depth.shape[2], gt_depth.shape[1]), + interpolation=cv2.INTER_CUBIC, + ) + for pd_path in pd_pathes + ], + axis=0, + ) + # for depth eval, set align_with_lad2=False to use median alignment; set align_with_lad2=True to use scale&shift alignment + if args.align == "scale&shift": + depth_results, error_map, depth_predict, depth_gt = ( + depth_evaluation( + pr_depth, + gt_depth, + max_depth=70, + align_with_lad2=True, + use_gpu=True, + post_clip_max=70, + ) + ) + elif args.align == "scale": + depth_results, error_map, depth_predict, depth_gt = ( + depth_evaluation( + pr_depth, + gt_depth, + max_depth=70, + align_with_scale=True, + use_gpu=True, + post_clip_max=70, + ) + ) + elif args.align == "metric": + depth_results, error_map, depth_predict, depth_gt = ( + depth_evaluation( + pr_depth, + gt_depth, + max_depth=70, + metric_scale=True, + use_gpu=True, + post_clip_max=70, + ) + ) + gathered_depth_metrics.append(depth_results) + + depth_log_path = f"{args.output_dir}/result_{args.align}.json" + average_metrics = { + key: np.average( + [metrics[key] for metrics in gathered_depth_metrics], + weights=[ + metrics["valid_pixels"] for metrics in gathered_depth_metrics + ], + ) + for key in gathered_depth_metrics[0].keys() + if key != "valid_pixels" + } + print("Average depth evaluation metrics:", average_metrics) + with open(depth_log_path, "w") as f: + f.write(json.dumps(average_metrics)) + + get_video_results() + elif args.eval_dataset == "bonn": + + def depth_read(filename): + # loads depth map D from png file + # and returns it as a numpy array + depth_png = np.asarray(Image.open(filename)) + # make sure we have a proper 16bit depth map here.. not 8bit! + assert np.max(depth_png) > 255 + depth = depth_png.astype(np.float64) / 5000.0 + depth[depth_png == 0] = -1.0 + return depth + + seq_list = ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"] + + img_pathes_folder = [ + f"data/bonn/rgbd_bonn_dataset/rgbd_bonn_{seq}/rgb_110/*.png" + for seq in seq_list + ] + img_pathes = [] + for img_pathes_folder_i in img_pathes_folder: + img_pathes += glob.glob(img_pathes_folder_i) + img_pathes = sorted(img_pathes) + depth_pathes_folder = [ + f"data/bonn/rgbd_bonn_dataset/rgbd_bonn_{seq}/depth_110/*.png" + for seq in seq_list + ] + depth_pathes = [] + for depth_pathes_folder_i in depth_pathes_folder: + depth_pathes += glob.glob(depth_pathes_folder_i) + depth_pathes = sorted(depth_pathes) + pred_pathes = glob.glob( + f"{args.output_dir}/*/frame*.npy" + ) # TODO: update the path to your prediction + pred_pathes = sorted(pred_pathes) + + def get_video_results(): + grouped_pred_depth = group_by_directory(pred_pathes) + grouped_gt_depth = group_by_directory(depth_pathes, idx=-2) + gathered_depth_metrics = [] + for key in tqdm(grouped_gt_depth.keys()): + pd_pathes = grouped_pred_depth[key[10:]] + gt_pathes = grouped_gt_depth[key] + gt_depth = np.stack( + [depth_read(gt_path) for gt_path in gt_pathes], axis=0 + ) + pr_depth = np.stack( + [ + cv2.resize( + np.load(pd_path), + (gt_depth.shape[2], gt_depth.shape[1]), + interpolation=cv2.INTER_CUBIC, + ) + for pd_path in pd_pathes + ], + axis=0, + ) + # for depth eval, set align_with_lad2=False to use median alignment; set align_with_lad2=True to use scale&shift alignment + if args.align == "scale&shift": + depth_results, error_map, depth_predict, depth_gt = ( + depth_evaluation( + pr_depth, + gt_depth, + max_depth=70, + align_with_lad2=True, + use_gpu=True, + ) + ) + elif args.align == "scale": + depth_results, error_map, depth_predict, depth_gt = ( + depth_evaluation( + pr_depth, + gt_depth, + max_depth=70, + align_with_scale=True, + use_gpu=True, + ) + ) + elif args.align == "metric": + depth_results, error_map, depth_predict, depth_gt = ( + depth_evaluation( + pr_depth, + gt_depth, + max_depth=70, + metric_scale=True, + use_gpu=True, + ) + ) + gathered_depth_metrics.append(depth_results) + + # seq_len = gt_depth.shape[0] + # error_map = error_map.reshape(seq_len, -1, error_map.shape[-1]).cpu() + # error_map_colored = colorize(error_map, range=(error_map.min(), error_map.max()), append_cbar=True) + # ImageSequenceClip([x for x in (error_map_colored.numpy()*255).astype(np.uint8)], fps=10).write_videofile(f'{args.output_dir}/errormap_{key}_{args.align}.mp4', fps=10) + + depth_log_path = f"{args.output_dir}/result_{args.align}.json" + average_metrics = { + key: np.average( + [metrics[key] for metrics in gathered_depth_metrics], + weights=[ + metrics["valid_pixels"] for metrics in gathered_depth_metrics + ], + ) + for key in gathered_depth_metrics[0].keys() + if key != "valid_pixels" + } + print("Average depth evaluation metrics:", average_metrics) + with open(depth_log_path, "w") as f: + f.write(json.dumps(average_metrics)) + + get_video_results() + elif args.eval_dataset == "kitti": + + def depth_read(filename): + # loads depth map D from png file + # and returns it as a numpy array, + # for details see readme.txt + img_pil = Image.open(filename) + depth_png = np.array(img_pil, dtype=int) + # make sure we have a proper 16bit depth map here.. not 8bit! + assert np.max(depth_png) > 255 + + depth = depth_png.astype(float) / 256.0 + depth[depth_png == 0] = -1.0 + return depth + + depth_pathes = glob.glob( + "data/kitti/depth_selection/val_selection_cropped/groundtruth_depth_gathered/*/*.png" + ) + depth_pathes = sorted(depth_pathes) + pred_pathes = glob.glob( + f"{args.output_dir}/*/frame_*.npy" + ) # TODO: update the path to your prediction + pred_pathes = sorted(pred_pathes) + + def get_video_results(): + grouped_pred_depth = group_by_directory(pred_pathes) + grouped_gt_depth = group_by_directory(depth_pathes) + gathered_depth_metrics = [] + for key in tqdm(grouped_pred_depth.keys()): + pd_pathes = grouped_pred_depth[key] + gt_pathes = grouped_gt_depth[key] + gt_depth = np.stack( + [depth_read(gt_path) for gt_path in gt_pathes], axis=0 + ) + pr_depth = np.stack( + [ + cv2.resize( + np.load(pd_path), + (gt_depth.shape[2], gt_depth.shape[1]), + interpolation=cv2.INTER_CUBIC, + ) + for pd_path in pd_pathes + ], + axis=0, + ) + + # for depth eval, set align_with_lad2=False to use median alignment; set align_with_lad2=True to use scale&shift alignment + if args.align == "scale&shift": + depth_results, error_map, depth_predict, depth_gt = ( + depth_evaluation( + pr_depth, + gt_depth, + max_depth=None, + align_with_lad2=True, + use_gpu=True, + ) + ) + elif args.align == "scale": + depth_results, error_map, depth_predict, depth_gt = ( + depth_evaluation( + pr_depth, + gt_depth, + max_depth=None, + align_with_scale=True, + use_gpu=True, + ) + ) + elif args.align == "metric": + depth_results, error_map, depth_predict, depth_gt = ( + depth_evaluation( + pr_depth, + gt_depth, + max_depth=None, + metric_scale=True, + use_gpu=True, + ) + ) + gathered_depth_metrics.append(depth_results) + + depth_log_path = f"{args.output_dir}/result_{args.align}.json" + average_metrics = { + key: np.average( + [metrics[key] for metrics in gathered_depth_metrics], + weights=[ + metrics["valid_pixels"] for metrics in gathered_depth_metrics + ], + ) + for key in gathered_depth_metrics[0].keys() + if key != "valid_pixels" + } + print("Average depth evaluation metrics:", average_metrics) + with open(depth_log_path, "w") as f: + f.write(json.dumps(average_metrics)) + + get_video_results() + + +if __name__ == "__main__": + args = get_args_parser() + args = args.parse_args() + main(args) diff --git a/eval/video_depth/launch.py b/eval/video_depth/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..99ade77b28ab65d54db321341d6ff145c148de35 --- /dev/null +++ b/eval/video_depth/launch.py @@ -0,0 +1,249 @@ +import os +import sys +import numpy as np +import torch +import argparse +from accelerate import PartialState +from tqdm import tqdm +from PIL import Image +import imageio.v2 as iio + +from stream3r.models.stream3r import STream3R +from stream3r.stream_session import StreamSession +from stream3r.dust3r.utils.image import load_images_for_eval as load_images +from stream3r.dust3r.utils.device import collate_with_cat +from stream3r.utils.utils import ImgDust3r2Stream3r + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +from eval.video_depth.metadata import dataset_metadata +from eval.video_depth.utils import colorize + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +torch.backends.cuda.matmul.allow_tf32 = True + +# avoid high cpu usage +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" +os.environ["OPENBLAS_NUM_THREADS"] = "1" +torch.set_num_threads(1) +# =========================================== + + +def save_depth_maps(pts3ds_self, path, conf_self=None, depth_maps=None): + if depth_maps is None: + depth_maps = torch.stack([pts3d_self[..., -1] for pts3d_self in pts3ds_self], 0) + min_depth = depth_maps.min() # float(torch.quantile(out, 0.01)) + max_depth = depth_maps.max() # float(torch.quantile(out, 0.99)) + colored_depth = colorize( + depth_maps, + cmap_name="Spectral_r", + range=(min_depth, max_depth), + append_cbar=True, + ) + images = [] + + if conf_self is not None: + conf_selfs = torch.concat(conf_self, 0) + min_conf = torch.log(conf_selfs.min()) # float(torch.quantile(out, 0.01)) + max_conf = torch.log(conf_selfs.max()) # float(torch.quantile(out, 0.99)) + colored_conf = colorize( + torch.log(conf_selfs), + cmap_name="jet", + range=(min_conf, max_conf), + append_cbar=True, + ) + + for i, depth_map in enumerate(colored_depth): + # Apply color map to depth map + img_path = f"{path}/frame_{(i):04d}.png" + if conf_self is None: + to_save = (depth_map * 255).detach().cpu().numpy().astype(np.uint8) + else: + to_save = torch.cat([depth_map, colored_conf[i]], dim=1) + to_save = (to_save * 255).detach().cpu().numpy().astype(np.uint8) + iio.imwrite(img_path, to_save) + images.append(Image.open(img_path)) + np.save(f"{path}/frame_{(i):04d}.npy", depth_maps[i].detach().cpu().numpy()) + + return depth_maps + + +def get_args_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument("--device", type=str, default="cuda", help="pytorch device") + parser.add_argument( + "--output_dir", + type=str, + default="", + help="value for outdir", + ) + parser.add_argument( + "--no_crop", type=bool, default=True, help="whether to crop input data" + ) + + parser.add_argument( + "--eval_dataset", + type=str, + default="sintel", + choices=list(dataset_metadata.keys()), + ) + parser.add_argument("--size", type=int, default="512") + + parser.add_argument( + "--pose_eval_stride", default=1, type=int, help="stride for pose evaluation" + ) + parser.add_argument( + "--full_seq", + action="store_true", + default=False, + help="use full sequence for pose evaluation", + ) + parser.add_argument( + "--seq_list", + nargs="+", + default=None, + help="list of sequences for pose evaluation", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default="", + help="path to the checkpoint directory", + ) + return parser + + +def eval_pose_estimation(args, model, save_dir=None): + metadata = dataset_metadata.get(args.eval_dataset) + img_path = metadata["img_path"] + mask_path = metadata["mask_path"] + + ate_mean, rpe_trans_mean, rpe_rot_mean = eval_pose_estimation_dist( + args, model, save_dir=save_dir, img_path=img_path, mask_path=mask_path) + return ate_mean, rpe_trans_mean, rpe_rot_mean + + +def eval_pose_estimation_dist(args, + model, + img_path, + save_dir=None, + mask_path=None): + metadata = dataset_metadata.get(args.eval_dataset) + model.eval() + + seq_list = args.seq_list + + if seq_list is None: + if metadata.get("full_seq", False): + args.full_seq = True + else: + seq_list = metadata.get("seq_list", []) + if args.full_seq: + seq_list = os.listdir(img_path) + seq_list = [ + seq for seq in seq_list + if os.path.isdir(os.path.join(img_path, seq)) + ] + seq_list = sorted(seq_list) + + if save_dir is None: + save_dir = args.output_dir + + distributed_state = PartialState() + model.to(distributed_state.device) + device = distributed_state.device + + with distributed_state.split_between_processes(seq_list) as seqs: + error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt" # Unique log file per process + for seq in tqdm(seqs): + try: + dir_path = metadata["dir_path_func"](img_path, seq) + + # Handle skip_condition + skip_condition = metadata.get("skip_condition", None) + if skip_condition is not None and skip_condition( + save_dir, seq): + continue + + mask_path_seq_func = metadata.get("mask_path_seq_func", + lambda mask_path, seq: None) + mask_path_seq = mask_path_seq_func(mask_path, seq) + + filelist = [ + os.path.join(dir_path, name) + for name in os.listdir(dir_path) + ] + filelist.sort() + filelist = filelist[::args.pose_eval_stride] + + images = load_images( + filelist, + size=518, + verbose=True, + crop=False, + patch_size=14, + ) + + images = collate_with_cat([tuple(images)]) + images = torch.stack([view["img"] for view in images], dim=1) + images = ImgDust3r2Stream3r(images).to(device) + + with torch.no_grad(): + session = StreamSession(model, mode="causal") + for i in range(images.shape[1]): + image = images[:, i:i+1] + predictions = session.forward_stream(image) + + print( + f"Finished depth estmation of {len(filelist)} images" + ) + + os.makedirs(f"{save_dir}/{seq}", exist_ok=True) + save_depth_maps(None, + f"{save_dir}/{seq}", + conf_self=None, + depth_maps=predictions['depth'].squeeze().cpu()) + + except Exception as e: + if "out of memory" in str(e): + # Handle OOM + torch.cuda.empty_cache() # Clear the CUDA memory + with open(error_log_path, "a") as f: + f.write( + f"OOM error in sequence {seq}, skipping this sequence.\n" + ) + print(f"OOM error in sequence {seq}, skipping...") + elif "Degenerate covariance rank" in str( + e) or "Eigenvalues did not converge" in str(e): + # Handle Degenerate covariance rank exception and Eigenvalues did not converge exception + with open(error_log_path, "a") as f: + f.write(f"Exception in sequence {seq}: {str(e)}\n") + print( + f"Traj evaluation error in sequence {seq}, skipping.") + else: + raise e # Rethrow if it's not an expected exception + return None, None, None + + +def main(): + args = get_args_parser() + args = args.parse_args() + + if args.eval_dataset == "sintel": + args.full_seq = True + else: + args.full_seq = False + args.no_crop = True + + model = STream3R.from_pretrained("yslan/STream3R").to(args.device) + model.eval() + + eval_pose_estimation(args, model, save_dir=args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/eval/video_depth/metadata.py b/eval/video_depth/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..48e30d9751af2aafd9c84b2dc059a9ee160a4475 --- /dev/null +++ b/eval/video_depth/metadata.py @@ -0,0 +1,177 @@ +import os +import glob +from tqdm import tqdm + +# Define the merged dataset metadata dictionary +dataset_metadata = { + "davis": { + "img_path": "data/davis/DAVIS/JPEGImages/480p", + "mask_path": "data/davis/DAVIS/masked_images/480p", + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq), + "gt_traj_func": lambda img_path, anno_path, seq: None, + "traj_format": None, + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq), + "skip_condition": None, + "process_func": None, # Not used in mono depth estimation + }, + "kitti": { + "img_path": "data/kitti/depth_selection/val_selection_cropped/image_gathered", # Default path + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq), + "gt_traj_func": lambda img_path, anno_path, seq: None, + "traj_format": None, + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": lambda args, img_path: process_kitti(args, img_path), + }, + "bonn": { + "img_path": "data/bonn/rgbd_bonn_dataset", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join( + img_path, f"rgbd_bonn_{seq}", "rgb_110" + ), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt" + ), + "traj_format": "tum", + "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"], + "full_seq": False, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": lambda args, img_path: process_bonn(args, img_path), + }, + "nyu": { + "img_path": "data/nyu-v2/val/nyu_images", + "mask_path": None, + "process_func": lambda args, img_path: process_nyu(args, img_path), + }, + "scannet": { + "img_path": "data/scannetv2", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "pose_90.txt" + ), + "traj_format": "replica", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, # lambda save_dir, seq: os.path.exists(os.path.join(save_dir, seq)), + "process_func": lambda args, img_path: process_scannet(args, img_path), + }, + "tum": { + "img_path": "data/tum", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "groundtruth_90.txt" + ), + "traj_format": "tum", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": None, + }, + "sintel": { + "img_path": "data/sintel/training/final", + "anno_path": "data/sintel/training/camdata_left", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq), + "traj_format": None, + "seq_list": [ + "alley_2", + "ambush_4", + "ambush_5", + "ambush_6", + "cave_2", + "cave_4", + "market_2", + "market_5", + "market_6", + "shaman_3", + "sleeping_1", + "sleeping_2", + "temple_2", + "temple_3", + ], + "full_seq": False, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": lambda args, img_path: process_sintel(args, img_path), + }, +} + + +# Define processing functions for each dataset +def process_kitti(args, img_path): + for dir in tqdm(sorted(glob.glob(f"{img_path}/*"))): + filelist = sorted(glob.glob(f"{dir}/*.png")) + save_dir = f"{args.output_dir}/{os.path.basename(dir)}" + yield filelist, save_dir + + +def process_bonn(args, img_path): + if args.full_seq: + for dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))): + filelist = sorted(glob.glob(f"{dir}/rgb/*.png")) + save_dir = f"{args.output_dir}/{os.path.basename(os.path.dirname(dir))}" + yield filelist, save_dir + else: + seq_list = ( + ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"] + if args.seq_list is None + else args.seq_list + ) + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f"{img_path}/rgbd_bonn_{seq}/rgb_110/*.png")) + save_dir = f"{args.output_dir}/{seq}" + yield filelist, save_dir + + +def process_nyu(args, img_path): + filelist = sorted(glob.glob(f"{img_path}/*.png")) + save_dir = f"{args.output_dir}" + yield filelist, save_dir + + +def process_scannet(args, img_path): + seq_list = sorted(glob.glob(f"{img_path}/*")) + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f"{seq}/color_90/*.jpg")) + save_dir = f"{args.output_dir}/{os.path.basename(seq)}" + yield filelist, save_dir + + +def process_sintel(args, img_path): + if args.full_seq: + for dir in tqdm(sorted(glob.glob(f"{img_path}/*/"))): + filelist = sorted(glob.glob(f"{dir}/*.png")) + save_dir = f"{args.output_dir}/{os.path.basename(os.path.dirname(dir))}" + yield filelist, save_dir + else: + seq_list = [ + "alley_2", + "ambush_4", + "ambush_5", + "ambush_6", + "cave_2", + "cave_4", + "market_2", + "market_5", + "market_6", + "shaman_3", + "sleeping_1", + "sleeping_2", + "temple_2", + "temple_3", + ] + for seq in tqdm(seq_list): + filelist = sorted(glob.glob(f"{img_path}/{seq}/*.png")) + save_dir = f"{args.output_dir}/{seq}" + yield filelist, save_dir diff --git a/eval/video_depth/run.sh b/eval/video_depth/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..288c2984ae44ea645b389c854a432efdbd0cc307 --- /dev/null +++ b/eval/video_depth/run.sh @@ -0,0 +1,22 @@ +#!/bin/bash +set -e + +workdir='.' + +datasets=('sintel' 'bonn' 'kitti') +model_name='stream3r' + +for data in "${datasets[@]}"; do + output_dir="${workdir}/eval_results/video_depth/${model_name}/${data}" + echo "$output_dir" + + python eval/video_depth/launch.py \ + --output_dir="$output_dir" \ + --eval_dataset="$data" \ + --checkpoint_dir="$ckpt_dir/${model_name}.ckpt" \ + + python eval/video_depth/eval_depth.py \ + --output_dir "$output_dir" \ + --eval_dataset "$data" \ + --align "scale" +done \ No newline at end of file diff --git a/eval/video_depth/tools.py b/eval/video_depth/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..d6786fa6f25def0110ce22dbb7d44a7a08c952c8 --- /dev/null +++ b/eval/video_depth/tools.py @@ -0,0 +1,399 @@ +import torch +import numpy as np +import cv2 +import glob +import argparse +from pathlib import Path +from tqdm import tqdm +from copy import deepcopy +from scipy.optimize import minimize +import os +from collections import defaultdict + + +def group_by_directory(pathes, idx=-1): + """ + Groups the file paths based on the second-to-last directory in their paths. + + Parameters: + - pathes (list): List of file paths. + + Returns: + - dict: A dictionary where keys are the second-to-last directory names and values are lists of file paths. + """ + grouped_pathes = defaultdict(list) + + for path in pathes: + # Extract the second-to-last directory + dir_name = os.path.dirname(path).split("/")[idx] + grouped_pathes[dir_name].append(path) + + return grouped_pathes + + +def depth2disparity(depth, return_mask=False): + if isinstance(depth, torch.Tensor): + disparity = torch.zeros_like(depth) + elif isinstance(depth, np.ndarray): + disparity = np.zeros_like(depth) + non_negtive_mask = depth > 0 + disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask] + if return_mask: + return disparity, non_negtive_mask + else: + return disparity + + +def absolute_error_loss(params, predicted_depth, ground_truth_depth): + s, t = params + + predicted_aligned = s * predicted_depth + t + + abs_error = np.abs(predicted_aligned - ground_truth_depth) + return np.sum(abs_error) + + +def absolute_value_scaling(predicted_depth, ground_truth_depth, s=1, t=0): + predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1) + ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1) + + initial_params = [s, t] # s = 1, t = 0 + + result = minimize( + absolute_error_loss, + initial_params, + args=(predicted_depth_np, ground_truth_depth_np), + ) + + s, t = result.x + return s, t + + +def absolute_value_scaling2( + predicted_depth, + ground_truth_depth, + s_init=1.0, + t_init=0.0, + lr=1e-4, + max_iters=1000, + tol=1e-6, +): + # Initialize s and t as torch tensors with requires_grad=True + s = torch.tensor( + [s_init], + requires_grad=True, + device=predicted_depth.device, + dtype=predicted_depth.dtype, + ) + t = torch.tensor( + [t_init], + requires_grad=True, + device=predicted_depth.device, + dtype=predicted_depth.dtype, + ) + + optimizer = torch.optim.Adam([s, t], lr=lr) + + prev_loss = None + + for i in range(max_iters): + optimizer.zero_grad() + + # Compute predicted aligned depth + predicted_aligned = s * predicted_depth + t + + # Compute absolute error + abs_error = torch.abs(predicted_aligned - ground_truth_depth) + + # Compute loss + loss = torch.sum(abs_error) + + # Backpropagate + loss.backward() + + # Update parameters + optimizer.step() + + # Check convergence + if prev_loss is not None and torch.abs(prev_loss - loss) < tol: + break + + prev_loss = loss.item() + + return s.detach().item(), t.detach().item() + + +def depth_evaluation( + predicted_depth_original, + ground_truth_depth_original, + max_depth=80, + custom_mask=None, + post_clip_min=None, + post_clip_max=None, + pre_clip_min=None, + pre_clip_max=None, + align_with_lstsq=False, + align_with_lad=False, + align_with_lad2=False, + metric_scale=False, + lr=1e-4, + max_iters=1000, + use_gpu=False, + align_with_scale=False, + disp_input=False, +): + """ + Evaluate the depth map using various metrics and return a depth error parity map, with an option for least squares alignment. + + Args: + predicted_depth (numpy.ndarray or torch.Tensor): The predicted depth map. + ground_truth_depth (numpy.ndarray or torch.Tensor): The ground truth depth map. + max_depth (float): The maximum depth value to consider. Default is 80 meters. + align_with_lstsq (bool): If True, perform least squares alignment of the predicted depth with ground truth. + + Returns: + dict: A dictionary containing the evaluation metrics. + torch.Tensor: The depth error parity map. + """ + if isinstance(predicted_depth_original, np.ndarray): + predicted_depth_original = torch.from_numpy(predicted_depth_original) + if isinstance(ground_truth_depth_original, np.ndarray): + ground_truth_depth_original = torch.from_numpy(ground_truth_depth_original) + if custom_mask is not None and isinstance(custom_mask, np.ndarray): + custom_mask = torch.from_numpy(custom_mask) + + # if the dimension is 3, flatten to 2d along the batch dimension + if predicted_depth_original.dim() == 3: + _, h, w = predicted_depth_original.shape + predicted_depth_original = predicted_depth_original.view(-1, w) + ground_truth_depth_original = ground_truth_depth_original.view(-1, w) + if custom_mask is not None: + custom_mask = custom_mask.view(-1, w) + + # put to device + if use_gpu: + predicted_depth_original = predicted_depth_original.cuda() + ground_truth_depth_original = ground_truth_depth_original.cuda() + + # Filter out depths greater than max_depth + if max_depth is not None: + mask = (ground_truth_depth_original > 0) & ( + ground_truth_depth_original < max_depth + ) + else: + mask = ground_truth_depth_original > 0 + predicted_depth = predicted_depth_original[mask] + ground_truth_depth = ground_truth_depth_original[mask] + + # Clip the depth values + if pre_clip_min is not None: + predicted_depth = torch.clamp(predicted_depth, min=pre_clip_min) + if pre_clip_max is not None: + predicted_depth = torch.clamp(predicted_depth, max=pre_clip_max) + + if disp_input: # align the pred to gt in the disparity space + real_gt = ground_truth_depth.clone() + ground_truth_depth = 1 / (ground_truth_depth + 1e-8) + + # various alignment methods + if metric_scale: + predicted_depth = predicted_depth + elif align_with_lstsq: + # Convert to numpy for lstsq + predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1, 1) + ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1, 1) + + # Add a column of ones for the shift term + A = np.hstack([predicted_depth_np, np.ones_like(predicted_depth_np)]) + + # Solve for scale (s) and shift (t) using least squares + result = np.linalg.lstsq(A, ground_truth_depth_np, rcond=None) + s, t = result[0][0], result[0][1] + + # convert to torch tensor + s = torch.tensor(s, device=predicted_depth_original.device) + t = torch.tensor(t, device=predicted_depth_original.device) + + # Apply scale and shift + predicted_depth = s * predicted_depth + t + elif align_with_lad: + s, t = absolute_value_scaling( + predicted_depth, + ground_truth_depth, + s=torch.median(ground_truth_depth) / torch.median(predicted_depth), + ) + predicted_depth = s * predicted_depth + t + elif align_with_lad2: + s_init = ( + torch.median(ground_truth_depth) / torch.median(predicted_depth) + ).item() + s, t = absolute_value_scaling2( + predicted_depth, + ground_truth_depth, + s_init=s_init, + lr=lr, + max_iters=max_iters, + ) + predicted_depth = s * predicted_depth + t + elif align_with_scale: + # Compute initial scale factor 's' using the closed-form solution (L2 norm) + dot_pred_gt = torch.nanmean(ground_truth_depth) + dot_pred_pred = torch.nanmean(predicted_depth) + s = dot_pred_gt / dot_pred_pred + + # Iterative reweighted least squares using the Weiszfeld method + for _ in range(10): + # Compute residuals between scaled predictions and ground truth + residuals = s * predicted_depth - ground_truth_depth + abs_residuals = ( + residuals.abs() + 1e-8 + ) # Add small constant to avoid division by zero + + # Compute weights inversely proportional to the residuals + weights = 1.0 / abs_residuals + + # Update 's' using weighted sums + weighted_dot_pred_gt = torch.sum( + weights * predicted_depth * ground_truth_depth + ) + weighted_dot_pred_pred = torch.sum(weights * predicted_depth**2) + s = weighted_dot_pred_gt / weighted_dot_pred_pred + + # Optionally clip 's' to prevent extreme scaling + s = s.clamp(min=1e-3) + + # Detach 's' if you want to stop gradients from flowing through it + s = s.detach() + + # Apply the scale factor to the predicted depth + predicted_depth = s * predicted_depth + + else: + # Align the predicted depth with the ground truth using median scaling + scale_factor = torch.median(ground_truth_depth) / torch.median(predicted_depth) + predicted_depth *= scale_factor + + if disp_input: + # convert back to depth + ground_truth_depth = real_gt + predicted_depth = depth2disparity(predicted_depth) + + # Clip the predicted depth values + if post_clip_min is not None: + predicted_depth = torch.clamp(predicted_depth, min=post_clip_min) + if post_clip_max is not None: + predicted_depth = torch.clamp(predicted_depth, max=post_clip_max) + + if custom_mask is not None: + assert custom_mask.shape == ground_truth_depth_original.shape + mask_within_mask = custom_mask.cpu()[mask] + predicted_depth = predicted_depth[mask_within_mask] + ground_truth_depth = ground_truth_depth[mask_within_mask] + + # Calculate the metrics + abs_rel = torch.mean( + torch.abs(predicted_depth - ground_truth_depth) / ground_truth_depth + ).item() + sq_rel = torch.mean( + ((predicted_depth - ground_truth_depth) ** 2) / ground_truth_depth + ).item() + + # Correct RMSE calculation + rmse = torch.sqrt(torch.mean((predicted_depth - ground_truth_depth) ** 2)).item() + + # Clip the depth values to avoid log(0) + predicted_depth = torch.clamp(predicted_depth, min=1e-5) + log_rmse = torch.sqrt( + torch.mean((torch.log(predicted_depth) - torch.log(ground_truth_depth)) ** 2) + ).item() + + # Calculate the accuracy thresholds + max_ratio = torch.maximum( + predicted_depth / ground_truth_depth, ground_truth_depth / predicted_depth + ) + threshold_0 = torch.mean((max_ratio < 1.0).float()).item() + threshold_1 = torch.mean((max_ratio < 1.25).float()).item() + threshold_2 = torch.mean((max_ratio < 1.25**2).float()).item() + threshold_3 = torch.mean((max_ratio < 1.25**3).float()).item() + + # Compute the depth error parity map + if metric_scale: + predicted_depth_original = predicted_depth_original + if disp_input: + predicted_depth_original = depth2disparity(predicted_depth_original) + depth_error_parity_map = ( + torch.abs(predicted_depth_original - ground_truth_depth_original) + / ground_truth_depth_original + ) + elif align_with_lstsq or align_with_lad or align_with_lad2: + predicted_depth_original = predicted_depth_original * s + t + if disp_input: + predicted_depth_original = depth2disparity(predicted_depth_original) + depth_error_parity_map = ( + torch.abs(predicted_depth_original - ground_truth_depth_original) + / ground_truth_depth_original + ) + elif align_with_scale: + predicted_depth_original = predicted_depth_original * s + if disp_input: + predicted_depth_original = depth2disparity(predicted_depth_original) + depth_error_parity_map = ( + torch.abs(predicted_depth_original - ground_truth_depth_original) + / ground_truth_depth_original + ) + else: + predicted_depth_original = predicted_depth_original * scale_factor + if disp_input: + predicted_depth_original = depth2disparity(predicted_depth_original) + depth_error_parity_map = ( + torch.abs(predicted_depth_original - ground_truth_depth_original) + / ground_truth_depth_original + ) + + # Reshape the depth_error_parity_map back to the original image size + depth_error_parity_map_full = torch.zeros_like(ground_truth_depth_original) + depth_error_parity_map_full = torch.where( + mask, depth_error_parity_map, depth_error_parity_map_full + ) + + predict_depth_map_full = predicted_depth_original + gt_depth_map_full = torch.zeros_like(ground_truth_depth_original) + gt_depth_map_full = torch.where( + mask, ground_truth_depth_original, gt_depth_map_full + ) + + num_valid_pixels = ( + torch.sum(mask).item() + if custom_mask is None + else torch.sum(mask_within_mask).item() + ) + if num_valid_pixels == 0: + ( + abs_rel, + sq_rel, + rmse, + log_rmse, + threshold_0, + threshold_1, + threshold_2, + threshold_3, + ) = (0, 0, 0, 0, 0, 0, 0, 0) + + results = { + "Abs Rel": abs_rel, + "Sq Rel": sq_rel, + "RMSE": rmse, + "Log RMSE": log_rmse, + "δ < 1.": threshold_0, + "δ < 1.25": threshold_1, + "δ < 1.25^2": threshold_2, + "δ < 1.25^3": threshold_3, + "valid_pixels": num_valid_pixels, + } + + return ( + results, + depth_error_parity_map_full, + predict_depth_map_full, + gt_depth_map_full, + ) diff --git a/eval/video_depth/utils.py b/eval/video_depth/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c639c0b266b3197f61a2e3df8f5ddc9b213a8f30 --- /dev/null +++ b/eval/video_depth/utils.py @@ -0,0 +1,237 @@ +from copy import deepcopy +import cv2 + +import numpy as np +import torch +import torch.nn as nn +import roma +from copy import deepcopy +import tqdm +import matplotlib as mpl +import matplotlib.cm as cm +import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg +from scipy.spatial.transform import Rotation +from PIL import Image +import imageio.v2 as iio +from matplotlib.figure import Figure + + +def save_focals(cam_dict, path): + # convert focal to txt + focals = cam_dict["focal"] + np.savetxt(path, focals, fmt="%.6f") + return focals + + +def save_intrinsics(cam_dict, path): + K_raw = np.eye(3)[None].repeat(len(cam_dict["focal"]), axis=0) + K_raw[:, 0, 0] = cam_dict["focal"] + K_raw[:, 1, 1] = cam_dict["focal"] + K_raw[:, :2, 2] = cam_dict["pp"] + K = K_raw.reshape(-1, 9) + np.savetxt(path, K, fmt="%.6f") + return K_raw + + +def save_conf_maps(conf, path): + for i, c in enumerate(conf): + np.save(f"{path}/conf_{i}.npy", c.detach().cpu().numpy()) + return conf + + +def save_rgb_imgs(colors, path): + imgs = colors + for i, img in enumerate(imgs): + # convert from rgb to bgr + iio.imwrite( + f"{path}/frame_{i:04d}.jpg", (img.cpu().numpy() * 255).astype(np.uint8) + ) + return imgs + + +def save_depth_maps(pts3ds_self, path, conf_self=None, depth_maps=None): + if depth_maps is None: + depth_maps = torch.stack([pts3d_self[..., -1] for pts3d_self in pts3ds_self], 0) + min_depth = depth_maps.min() # float(torch.quantile(out, 0.01)) + max_depth = depth_maps.max() # float(torch.quantile(out, 0.99)) + colored_depth = colorize( + depth_maps, + cmap_name="Spectral_r", + range=(min_depth, max_depth), + append_cbar=True, + ) + images = [] + + if conf_self is not None: + conf_selfs = torch.concat(conf_self, 0) + min_conf = torch.log(conf_selfs.min()) # float(torch.quantile(out, 0.01)) + max_conf = torch.log(conf_selfs.max()) # float(torch.quantile(out, 0.99)) + colored_conf = colorize( + torch.log(conf_selfs), + cmap_name="jet", + range=(min_conf, max_conf), + append_cbar=True, + ) + + for i, depth_map in enumerate(colored_depth): + # Apply color map to depth map + img_path = f"{path}/frame_{(i):04d}.png" + if conf_self is None: + to_save = (depth_map * 255).detach().cpu().numpy().astype(np.uint8) + else: + to_save = torch.cat([depth_map, colored_conf[i]], dim=1) + to_save = (to_save * 255).detach().cpu().numpy().astype(np.uint8) + iio.imwrite(img_path, to_save) + images.append(Image.open(img_path)) + np.save(f"{path}/frame_{(i):04d}.npy", depth_maps[i].detach().cpu().numpy()) + + # comment this as it may fail sometimes + # images[0].save(f'{path}/_depth_maps.gif', save_all=True, append_images=images[1:], duration=100, loop=0) + + return depth_maps + + +def get_vertical_colorbar(h, vmin, vmax, cmap_name="jet", label=None, cbar_precision=2): + """ + :param w: pixels + :param h: pixels + :param vmin: min value + :param vmax: max value + :param cmap_name: + :param label + :return: + """ + fig = Figure(figsize=(2, 8), dpi=100) + fig.subplots_adjust(right=1.5) + canvas = FigureCanvasAgg(fig) + + # Do some plotting. + ax = fig.add_subplot(111) + cmap = cm.get_cmap(cmap_name) + norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + + tick_cnt = 6 + tick_loc = np.linspace(vmin, vmax, tick_cnt) + cb1 = mpl.colorbar.ColorbarBase( + ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical" + ) + + tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc] + if cbar_precision == 0: + tick_label = [x[:-2] for x in tick_label] + + cb1.set_ticklabels(tick_label) + + cb1.ax.tick_params(labelsize=18, rotation=0) + if label is not None: + cb1.set_label(label) + + # fig.tight_layout() + + canvas.draw() + s, (width, height) = canvas.print_to_buffer() + + im = np.frombuffer(s, np.uint8).reshape((height, width, 4)) + + im = im[:, :, :3].astype(np.float32) / 255.0 + if h != im.shape[0]: + w = int(im.shape[1] / im.shape[0] * h) + im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA) + + return im + + +def colorize_np( + x, + cmap_name="jet", + mask=None, + range=None, + append_cbar=False, + cbar_in_image=False, + cbar_precision=2, +): + """ + turn a grayscale image into a color image + :param x: input grayscale, [H, W] + :param cmap_name: the colorization method + :param mask: the mask image, [H, W] + :param range: the range for scaling, automatic if None, [min, max] + :param append_cbar: if append the color bar + :param cbar_in_image: put the color bar inside the image to keep the output image the same size as the input image + :return: colorized image, [H, W] + """ + if range is not None: + vmin, vmax = range + elif mask is not None: + # vmin, vmax = np.percentile(x[mask], (2, 100)) + vmin = np.min(x[mask][np.nonzero(x[mask])]) + vmax = np.max(x[mask]) + # vmin = vmin - np.abs(vmin) * 0.01 + x[np.logical_not(mask)] = vmin + # print(vmin, vmax) + else: + vmin, vmax = np.percentile(x, (1, 100)) + vmax += 1e-6 + + x = np.clip(x, vmin, vmax) + x = (x - vmin) / (vmax - vmin) + # x = np.clip(x, 0., 1.) + + cmap = cm.get_cmap(cmap_name) + x_new = cmap(x)[:, :, :3] + + if mask is not None: + mask = np.float32(mask[:, :, np.newaxis]) + x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask) + + cbar = get_vertical_colorbar( + h=x.shape[0], + vmin=vmin, + vmax=vmax, + cmap_name=cmap_name, + cbar_precision=cbar_precision, + ) + + if append_cbar: + if cbar_in_image: + x_new[:, -cbar.shape[1] :, :] = cbar + else: + x_new = np.concatenate( + (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1 + ) + return x_new + else: + return x_new + + +# tensor +def colorize( + x, cmap_name="jet", mask=None, range=None, append_cbar=False, cbar_in_image=False +): + """ + turn a grayscale image into a color image + :param x: torch.Tensor, grayscale image, [H, W] or [B, H, W] + :param mask: torch.Tensor or None, mask image, [H, W] or [B, H, W] or None + """ + + device = x.device + x = x.cpu().numpy() + if mask is not None: + mask = mask.cpu().numpy() > 0.99 + kernel = np.ones((3, 3), np.uint8) + + if x.ndim == 2: + x = x[None] + if mask is not None: + mask = mask[None] + + out = [] + for x_ in x: + if mask is not None: + mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool) + + x_ = colorize_np(x_, cmap_name, mask, range, append_cbar, cbar_in_image) + out.append(torch.from_numpy(x_).to(device).float()) + out = torch.stack(out).squeeze(0) + return out diff --git a/examples/dynamic_car/00.jpg b/examples/dynamic_car/00.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c7d006ff27f18fec58a2235f53db1d8134ce1d14 --- /dev/null +++ b/examples/dynamic_car/00.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8889a71ccb90a785ebdeb2272a0e9cd214e6753068de8943b49a904c4ee646c1 +size 142463 diff --git a/examples/dynamic_car/01.jpg b/examples/dynamic_car/01.jpg new file mode 100644 index 0000000000000000000000000000000000000000..af3a28636271e73d1b9f39f3102065e4d42fd103 --- /dev/null +++ b/examples/dynamic_car/01.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3fdf9c525e8d5e02aa4b3bb48aed8aaadd104f41dcbd44ff537287cf6ce7548a +size 131152 diff --git a/examples/dynamic_car/02.jpg b/examples/dynamic_car/02.jpg new file mode 100644 index 0000000000000000000000000000000000000000..05653f623c24f69f5e308678e13d3e9342ee04da --- /dev/null +++ b/examples/dynamic_car/02.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbe380d18d08d458f6ec8bd0ddec427312237942876e86107a9977687bb2cede +size 132169 diff --git a/examples/dynamic_car/03.jpg b/examples/dynamic_car/03.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a70ce2157d41c1bdf98b648590838b347d440f70 --- /dev/null +++ b/examples/dynamic_car/03.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f9536ed20f287a23ee5d7411bd32244315bcf1979c3197e0b870deb5a08a65aa +size 134919 diff --git a/examples/dynamic_car/04.jpg b/examples/dynamic_car/04.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c471d0f6e26882660eaa2d7eac422aa9cf43ff22 --- /dev/null +++ b/examples/dynamic_car/04.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d1f7ec6b6c018bd34bd627313308f5dd52f7bf73a704c8e0b9b159bb0bcb0f2 +size 134127 diff --git a/examples/dynamic_car/05.jpg b/examples/dynamic_car/05.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f0b3370257c4d02ab8e8843972779d91f8d6f8d0 --- /dev/null +++ b/examples/dynamic_car/05.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b5145f6dd62243fa1368104e91409fc6bcc04b729682f4be56066e4c35f0397 +size 134778 diff --git a/examples/static_room/00.png b/examples/static_room/00.png new file mode 100644 index 0000000000000000000000000000000000000000..653af99154c795ffd6d3fbaba35736eb9ece780e --- /dev/null +++ b/examples/static_room/00.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c52d7c29412217255fa3c5323e882297b460a7840a8ea2867c2d7a483d198b9b +size 191389 diff --git a/examples/static_room/01.png b/examples/static_room/01.png new file mode 100644 index 0000000000000000000000000000000000000000..d96c3900c8fc6156603a92ed9f3d4553b71c9c58 --- /dev/null +++ b/examples/static_room/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8f3348e4a779bcec006653a5c6920192344a17893bfab2f662d7e659b555709 +size 171643 diff --git a/examples/static_room/02.png b/examples/static_room/02.png new file mode 100644 index 0000000000000000000000000000000000000000..2ed25b1d21be318f20afce18225b0ed7271dd94e --- /dev/null +++ b/examples/static_room/02.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ba96e5378cc9cc71c6971be9902130297ce3d53cfa82f3812809adf8fae7ae2 +size 161186 diff --git a/examples/static_room/03.png b/examples/static_room/03.png new file mode 100644 index 0000000000000000000000000000000000000000..42c21aa78b34fdbfce8864f0ee5899f17fe157d3 --- /dev/null +++ b/examples/static_room/03.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:062e91e26356f9a16252ffdae377b580b145bd3623e2c6a2f51bbcb3de50426f +size 178446 diff --git a/examples/static_room/04.png b/examples/static_room/04.png new file mode 100644 index 0000000000000000000000000000000000000000..e87f3ef275bb3c48b991beea69db6202def61989 --- /dev/null +++ b/examples/static_room/04.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02c0f01192033c6207aaeb8b42dea0be09af32ea015272500769d365130d2eee +size 187766 diff --git a/examples/static_room/05.png b/examples/static_room/05.png new file mode 100644 index 0000000000000000000000000000000000000000..9a28d850d1f12a8b6816250d3a05ee166b0fb800 --- /dev/null +++ b/examples/static_room/05.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb1ace8553fd327f22288a92eb94d1347b6e2e4ec79d3b0939e5fe9ea6f3d3d1 +size 202827 diff --git a/examples/static_room/06.png b/examples/static_room/06.png new file mode 100644 index 0000000000000000000000000000000000000000..b2d0bbf440b3c0f52f70d633d81e428a9e3d4126 --- /dev/null +++ b/examples/static_room/06.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7840250c6ba0a656401715a435cc38ee5c83c35ffa6a932c0329f4cefa03fb25 +size 209898 diff --git a/examples/static_room/07.png b/examples/static_room/07.png new file mode 100644 index 0000000000000000000000000000000000000000..8fe40f3c422b489d3013cbaaa3c2d7d10e06b510 --- /dev/null +++ b/examples/static_room/07.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c2c5e765c2594b26240714179df55972a6f52ea56fccb9b7da3179af05f8d2f +size 226887 diff --git a/examples/static_room/08.png b/examples/static_room/08.png new file mode 100644 index 0000000000000000000000000000000000000000..e3cc1a55d2d123ceb261a11f0c87ce07d75318f0 --- /dev/null +++ b/examples/static_room/08.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79565ebdacc4222e961c077e37ed98c9ce9ae589fa5a9148b1475a85c828c33b +size 168540 diff --git a/examples/static_room/09.png b/examples/static_room/09.png new file mode 100644 index 0000000000000000000000000000000000000000..426af197b47b9936f54b8a83dd7ab8b80caf5ead --- /dev/null +++ b/examples/static_room/09.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fe465ac4ec5aa33f41619fbc1361257eaf8f34e8ddce70c57dc21cd90b3bd9f +size 175460 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..434684e0c50d6649a5411175c0165454bdb19090 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,51 @@ +# --------- lightning --------- # +lightning>=2.5.0 +lightning-bolts +lightning-utilities +torchmetrics +torchinfo + +# --------- hydra --------- # +hydra-core +hydra-colorlog +hydra-optuna-sweeper + +# --------- loggers --------- # +wandb + +# --------- others --------- # +rootutils # standardizing the project root setup +pre-commit # hooks for applying linters on commit +rich # beautiful text formatting in terminal +pytest # tests +nvitop +h5py + +# --------- project core --------- # +trimesh +roma +open3d +einops +deepspeed +gradio +matplotlib +tqdm +jupyterlab +opencv-python +viser +pillow_heif +plotly +scikit-image +scikit-learn +scipy +seaborn +pyglet<2 +huggingface-hub[torch]>=0.22 + +# --------- eval --------- # +accelerate +evo + +# --------- demo --------- # +gradio==5.17.1 +onnxruntime \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..cc8c3a9194c40f62b06907d9cd1a11654bc8044d --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import find_packages, setup + +setup( + name="stream3r", + version="1.0", + packages=find_packages(include=["stream3r"]), +) diff --git a/stream3r/__init__.py b/stream3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a49a3a9427e67836de554fed0bc7f6466adbe06 --- /dev/null +++ b/stream3r/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + diff --git a/stream3r/__pycache__/__init__.cpython-311.pyc b/stream3r/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ac122fe40efb693b54bd41bd29ef617cf38118e Binary files /dev/null and b/stream3r/__pycache__/__init__.cpython-311.pyc differ diff --git a/stream3r/__pycache__/stream_session.cpython-311.pyc b/stream3r/__pycache__/stream_session.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0eae57e7f52e973ab1d1dfc779938d5a0806d76 Binary files /dev/null and b/stream3r/__pycache__/stream_session.cpython-311.pyc differ diff --git a/stream3r/croco/LICENSE b/stream3r/croco/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c9342d78f441dccebe0fe4461ad6a791196ef484 --- /dev/null +++ b/stream3r/croco/LICENSE @@ -0,0 +1,52 @@ +CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license. + +A summary of the CC BY-NC-SA 4.0 license is located here: + https://creativecommons.org/licenses/by-nc-sa/4.0/ + +The CC BY-NC-SA 4.0 license is located here: + https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode + + +SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py + +*************************** + +NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py + +This software is being redistributed in a modifiled form. The original form is available here: + +https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +This software in this file incorporates parts of the following software available here: + +Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py +available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE + +MoCo v3: https://github.com/facebookresearch/moco-v3 +available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE + +DeiT: https://github.com/facebookresearch/deit +available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE + + +ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW: + +https://github.com/facebookresearch/mae/blob/main/LICENSE + +Attribution-NonCommercial 4.0 International + +*************************** + +NOTICE WITH RESPECT TO THE FILE: models/blocks.py + +This software is being redistributed in a modifiled form. The original form is available here: + +https://github.com/rwightman/pytorch-image-models + +ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW: + +https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE + +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ diff --git a/stream3r/croco/NOTICE b/stream3r/croco/NOTICE new file mode 100644 index 0000000000000000000000000000000000000000..2a44a6a89e9ace025db2b713442d91e47bfc4656 --- /dev/null +++ b/stream3r/croco/NOTICE @@ -0,0 +1,21 @@ +CroCo +Copyright 2022-present NAVER Corp. + +This project contains subcomponents with separate copyright notices and license terms. +Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. + +==== + +facebookresearch/mae +https://github.com/facebookresearch/mae + +Attribution-NonCommercial 4.0 International + +==== + +rwightman/pytorch-image-models +https://github.com/rwightman/pytorch-image-models + +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ diff --git a/stream3r/croco/README.MD b/stream3r/croco/README.MD new file mode 100644 index 0000000000000000000000000000000000000000..ecc8f8263b52a0ec3b826de0418e946a9783ed36 --- /dev/null +++ b/stream3r/croco/README.MD @@ -0,0 +1,124 @@ +# CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow + +[[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)] + +This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), refered to as CroCo v2: + +![image](assets/arch.jpg) + +```bibtex +@inproceedings{croco, + title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}}, + author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}}, + booktitle={{NeurIPS}}, + year={2022} +} + +@inproceedings{croco_v2, + title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}}, + author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me}, + booktitle={ICCV}, + year={2023} +} +``` + +## License + +The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information. +Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License. +Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license. + +## Preparation + +1. Install dependencies on a machine with a NVidia GPU using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can ignore the line installing it and use a more recent python version. + +```bash +conda create -n croco python=3.7 cmake=3.14.0 +conda activate croco +conda install habitat-sim headless -c conda-forge -c aihabitat +conda install pytorch torchvision -c pytorch +conda install notebook ipykernel matplotlib +conda install ipywidgets widgetsnbextension +conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation + +``` + +2. Compile cuda kernels for RoPE + +CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels. +```bash +cd models/curope/ +python setup.py build_ext --inplace +cd ../../ +``` + +This can be a bit long as we compile for all cuda architectures, feel free to update L9 of `models/curope/setup.py` to compile for specific architectures only. +You might also need to set the environment `CUDA_HOME` in case you use a custom cuda installation. + +In case you cannot provide, we also provide a slow pytorch version, which will be automatically loaded. + +3. Download pre-trained model + +We provide several pre-trained models: + +| modelname | pre-training data | pos. embed. | Encoder | Decoder | +|------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------| +| [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small | +| [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small | +| [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base | +| [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base | + +To download a specific model, i.e., the first one (`CroCo.pth`) +```bash +mkdir -p pretrained_models/ +wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/ +``` + +## Reconstruction example + +Simply run after downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or update the corresponding line in `demo.py`) +```bash +python demo.py +``` + +## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator + +First download the test scene from Habitat: +```bash +python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/ +``` + +Then, run the Notebook demo `interactive_demo.ipynb`. + +In this demo, you should be able to sample a random reference viewpoint from an [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change viewpoint and select a masked target view to reconstruct using CroCo. +![croco_interactive_demo](https://user-images.githubusercontent.com/1822210/200516576-7937bc6a-55f8-49ed-8618-3ddf89433ea4.jpg) + +## Pre-training + +### CroCo + +To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command: +``` +torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/ +``` + +Our CroCo pre-training was launched on a single server with 4 GPUs. +It should take around 10 days with A100 or 15 days with V100 to do the 400 pre-training epochs, but decent performances are obtained earlier in training. +Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experimented if it is valid in our case. +The first run can take a few minutes to start, to parse all available pre-training pairs. + +### CroCo v2 + +For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD). +Then, run the following command for the largest model (ViT-L encoder, Base decoder): +``` +torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/ +``` + +Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per gpu in all cases. +The largest model should take around 12 days on A100. +Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experimented if it is valid in our case. + +## Stereo matching and Optical flow downstream tasks + +For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD). diff --git a/stream3r/croco/assets/Chateau1.png b/stream3r/croco/assets/Chateau1.png new file mode 100644 index 0000000000000000000000000000000000000000..295b00e46972ffcacaca60c2c7c7ec7a04c762fa --- /dev/null +++ b/stream3r/croco/assets/Chateau1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71ffb8c7d77e5ced0bb3dcd2cb0db84d0e98e6ff5ffd2d02696a7156e5284857 +size 112106 diff --git a/stream3r/croco/assets/Chateau2.png b/stream3r/croco/assets/Chateau2.png new file mode 100644 index 0000000000000000000000000000000000000000..97b3c058ff180a6d0c0853ab533b0823a06f8425 --- /dev/null +++ b/stream3r/croco/assets/Chateau2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3a0be9e19f6b89491d692c71e3f2317c2288a898a990561d48b7667218b47c8 +size 109905 diff --git a/stream3r/croco/assets/arch.jpg b/stream3r/croco/assets/arch.jpg new file mode 100644 index 0000000000000000000000000000000000000000..894c58e25c2d9ee0b579c6f5a6ce78d12217d106 --- /dev/null +++ b/stream3r/croco/assets/arch.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05fbf12896a79819a3864a800b174896bd3b6fa29b4f4f580d06725ff7c30dc7 +size 74842 diff --git a/stream3r/croco/croco-stereo-flow-demo.ipynb b/stream3r/croco/croco-stereo-flow-demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..2b00a7607ab5f82d1857041969bfec977e56b3e0 --- /dev/null +++ b/stream3r/croco/croco-stereo-flow-demo.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9bca0f41", + "metadata": {}, + "source": [ + "# Simple inference example with CroCo-Stereo or CroCo-Flow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80653ef7", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n", + "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)." + ] + }, + { + "cell_type": "markdown", + "id": "4f033862", + "metadata": {}, + "source": [ + "First download the model(s) of your choice by running\n", + "```\n", + "bash stereoflow/download_model.sh crocostereo.pth\n", + "bash stereoflow/download_model.sh crocoflow.pth\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fb2e392", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n", + "device = torch.device('cuda:0' if use_gpu else 'cpu')\n", + "import matplotlib.pylab as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0e25d77", + "metadata": {}, + "outputs": [], + "source": [ + "from stereoflow.test import _load_model_and_criterion\n", + "from stereoflow.engine import tiled_pred\n", + "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n", + "from stereoflow.datasets_flow import flowToColor\n", + "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower" + ] + }, + { + "cell_type": "markdown", + "id": "86a921f5", + "metadata": {}, + "source": [ + "### CroCo-Stereo example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64e483cb", + "metadata": {}, + "outputs": [], + "source": [ + "image1 = np.asarray(Image.open(''))\n", + "image2 = np.asarray(Image.open(''))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0d04303", + "metadata": {}, + "outputs": [], + "source": [ + "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47dc14b5", + "metadata": {}, + "outputs": [], + "source": [ + "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n", + "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n", + "with torch.inference_mode():\n", + " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n", + "pred = pred.squeeze(0).squeeze(0).cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "583b9f16", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(vis_disparity(pred))\n", + "plt.axis('off')" + ] + }, + { + "cell_type": "markdown", + "id": "d2df5d70", + "metadata": {}, + "source": [ + "### CroCo-Flow example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ee257a7", + "metadata": {}, + "outputs": [], + "source": [ + "image1 = np.asarray(Image.open(''))\n", + "image2 = np.asarray(Image.open(''))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5edccf0", + "metadata": {}, + "outputs": [], + "source": [ + "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b19692c3", + "metadata": {}, + "outputs": [], + "source": [ + "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n", + "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n", + "with torch.inference_mode():\n", + " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n", + "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26f79db3", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(flowToColor(pred))\n", + "plt.axis('off')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/stream3r/croco/datasets/__init__.py b/stream3r/croco/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a49a3a9427e67836de554fed0bc7f6466adbe06 --- /dev/null +++ b/stream3r/croco/datasets/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + diff --git a/stream3r/croco/datasets/crops/README.MD b/stream3r/croco/datasets/crops/README.MD new file mode 100644 index 0000000000000000000000000000000000000000..71b97a61084536305bfca2cebabed89a16340e0a --- /dev/null +++ b/stream3r/croco/datasets/crops/README.MD @@ -0,0 +1,104 @@ +## Generation of crops from the real datasets + +The instructions below allow to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL. + +### Download the metadata of the crops to generate + +First, download the metadata and put them in `./data/`: +``` +mkdir -p data +cd data/ +wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip +unzip crop_metadata.zip +rm crop_metadata.zip +cd .. +``` + +### Prepare the original datasets + +Second, download the original datasets in `./data/original_datasets/`. +``` +mkdir -p data/original_datasets +``` + +##### ARKitScenes + +Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`. +The resulting file structure should be like: +``` +./data/original_datasets/ARKitScenes/ +└───Training + └───40753679 + │ │ ultrawide + │ │ ... + └───40753686 + │ + ... +``` + +##### MegaDepth + +Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`. +The resulting file structure should be like: + +``` +./data/original_datasets/MegaDepth/ +└───0000 +│ └───images +│ │ │ 1000557903_87fa96b8a4_o.jpg +│ │ └ ... +│ └─── ... +└───0001 +│ │ +│ └ ... +└─── ... +``` + +##### 3DStreetView + +Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`. +The resulting file structure should be like: + +``` +./data/original_datasets/3DStreetView/ +└───dataset_aligned +│ └───0002 +│ │ │ 0000002_0000001_0000002_0000001.jpg +│ │ └ ... +│ └─── ... +└───dataset_unaligned +│ └───0003 +│ │ │ 0000003_0000001_0000002_0000001.jpg +│ │ └ ... +│ └─── ... +``` + +##### IndoorVL + +Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture). + +``` +pip install kapture +mkdir -p ./data/original_datasets/IndoorVL +cd ./data/original_datasets/IndoorVL +kapture_download_dataset.py update +kapture_download_dataset.py install "HyundaiDepartmentStore_*" +kapture_download_dataset.py install "GangnamStation_*" +cd - +``` + +### Extract the crops + +Now, extract the crops for each of the dataset: +``` +for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL; +do + python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500; +done +``` + +##### Note for IndoorVL + +Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper. +To account for it in terms of number of pre-training iterations, the pre-training command in this repository uses 125 training epochs including 12 warm-up epochs and learning rate cosine schedule of 250, instead of 100, 10 and 200 respectively. +The impact on the performance is negligible. diff --git a/stream3r/croco/datasets/crops/extract_crops_from_images.py b/stream3r/croco/datasets/crops/extract_crops_from_images.py new file mode 100644 index 0000000000000000000000000000000000000000..ab766697a78771b38577987ed52868e59ec759bb --- /dev/null +++ b/stream3r/croco/datasets/crops/extract_crops_from_images.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Extracting crops for pre-training +# -------------------------------------------------------- + +import argparse +import functools +import math +import os +from multiprocessing import Pool + +from PIL import Image +from tqdm import tqdm + + +def arg_parser(): + parser = argparse.ArgumentParser( + "Generate cropped image pairs from image crop list" + ) + + parser.add_argument("--crops", type=str, required=True, help="crop file") + parser.add_argument("--root-dir", type=str, required=True, help="root directory") + parser.add_argument( + "--output-dir", type=str, required=True, help="output directory" + ) + parser.add_argument("--imsize", type=int, default=256, help="size of the crops") + parser.add_argument( + "--nthread", type=int, required=True, help="number of simultaneous threads" + ) + parser.add_argument( + "--max-subdir-levels", + type=int, + default=5, + help="maximum number of subdirectories", + ) + parser.add_argument( + "--ideal-number-pairs-in-dir", + type=int, + default=500, + help="number of pairs stored in a dir", + ) + return parser + + +def main(args): + listing_path = os.path.join(args.output_dir, "listing.txt") + + print(f"Loading list of crops ... ({args.nthread} threads)") + crops, num_crops_to_generate = load_crop_file(args.crops) + + print(f"Preparing jobs ({len(crops)} candidate image pairs)...") + num_levels = min( + math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), + args.max_subdir_levels, + ) + num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1 / num_levels)) + + jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir) + del crops + + os.makedirs(args.output_dir, exist_ok=True) + mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map + call = functools.partial(save_image_crops, args) + + print(f"Generating cropped images to {args.output_dir} ...") + with open(listing_path, "w") as listing: + listing.write("# pair_path\n") + for results in tqdm(mmap(call, jobs), total=len(jobs)): + for path in results: + listing.write(f"{path}\n") + print("Finished writing listing to", listing_path) + + +def load_crop_file(path): + data = open(path).read().splitlines() + pairs = [] + num_crops_to_generate = 0 + for line in tqdm(data): + if line.startswith("#"): + continue + line = line.split(", ") + if len(line) < 8: + img1, img2, rotation = line + pairs.append((img1, img2, int(rotation), [])) + else: + l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line) + rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2) + pairs[-1][-1].append((rect1, rect2)) + num_crops_to_generate += 1 + return pairs, num_crops_to_generate + + +def prepare_jobs(pairs, num_levels, num_pairs_in_dir): + jobs = [] + powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))] + + def get_path(idx): + idx_array = [] + d = idx + for level in range(num_levels - 1): + idx_array.append(idx // powers[level]) + idx = idx % powers[level] + idx_array.append(d) + return "/".join(map(lambda x: hex(x)[2:], idx_array)) + + idx = 0 + for pair_data in tqdm(pairs): + img1, img2, rotation, crops = pair_data + if -60 <= rotation and rotation <= 60: + rotation = 0 # most likely not a true rotation + paths = [get_path(idx + k) for k in range(len(crops))] + idx += len(crops) + jobs.append(((img1, img2), rotation, crops, paths)) + return jobs + + +def load_image(path): + try: + return Image.open(path).convert("RGB") + except Exception as e: + print("skipping", path, e) + raise OSError() + + +def save_image_crops(args, data): + # load images + img_pair, rot, crops, paths = data + try: + img1, img2 = [ + load_image(os.path.join(args.root_dir, impath)) for impath in img_pair + ] + except OSError as e: + return [] + + def area(sz): + return sz[0] * sz[1] + + tgt_size = (args.imsize, args.imsize) + + def prepare_crop(img, rect, rot=0): + # actual crop + img = img.crop(rect) + + # resize to desired size + interp = ( + Image.Resampling.LANCZOS + if area(img.size) > 4 * area(tgt_size) + else Image.Resampling.BICUBIC + ) + img = img.resize(tgt_size, resample=interp) + + # rotate the image + rot90 = (round(rot / 90) % 4) * 90 + if rot90 == 90: + img = img.transpose(Image.Transpose.ROTATE_90) + elif rot90 == 180: + img = img.transpose(Image.Transpose.ROTATE_180) + elif rot90 == 270: + img = img.transpose(Image.Transpose.ROTATE_270) + return img + + results = [] + for (rect1, rect2), path in zip(crops, paths): + crop1 = prepare_crop(img1, rect1) + crop2 = prepare_crop(img2, rect2, rot) + + fullpath1 = os.path.join(args.output_dir, path + "_1.jpg") + fullpath2 = os.path.join(args.output_dir, path + "_2.jpg") + os.makedirs(os.path.dirname(fullpath1), exist_ok=True) + + assert not os.path.isfile(fullpath1), fullpath1 + assert not os.path.isfile(fullpath2), fullpath2 + crop1.save(fullpath1) + crop2.save(fullpath2) + results.append(path) + + return results + + +if __name__ == "__main__": + args = arg_parser().parse_args() + main(args) diff --git a/stream3r/croco/datasets/habitat_sim/README.MD b/stream3r/croco/datasets/habitat_sim/README.MD new file mode 100644 index 0000000000000000000000000000000000000000..a505781ff9eb91bce7f1d189e848f8ba1c560940 --- /dev/null +++ b/stream3r/croco/datasets/habitat_sim/README.MD @@ -0,0 +1,76 @@ +## Generation of synthetic image pairs using Habitat-Sim + +These instructions allow to generate pre-training pairs from the Habitat simulator. +As we did not save metadata of the pairs used in the original paper, they are not strictly the same, but these data use the same setting and are equivalent. + +### Download Habitat-Sim scenes +Download Habitat-Sim scenes: +- Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md +- We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets. +- Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or update manually paths in `paths.py`. +``` +./data/ +└──habitat-sim-data/ + └──scene_datasets/ + ├──hm3d/ + ├──gibson/ + ├──habitat-test-scenes/ + ├──replica_cad_baked_lighting/ + ├──replica_cad/ + ├──ReplicaDataset/ + └──scannet/ +``` + +### Image pairs generation +We provide metadata to generate reproducible images pairs for pretraining and validation. +Experiments described in the paper used similar data, but whose generation was not reproducible at the time. + +Specifications: +- 256x256 resolution images, with 60 degrees field of view . +- Up to 1000 image pairs per scene. +- Number of scenes considered/number of images pairs per dataset: + - Scannet: 1097 scenes / 985 209 pairs + - HM3D: + - hm3d/train: 800 / 800k pairs + - hm3d/val: 100 scenes / 100k pairs + - hm3d/minival: 10 scenes / 10k pairs + - habitat-test-scenes: 3 scenes / 3k pairs + - replica_cad_baked_lighting: 13 scenes / 13k pairs + +- Scenes from hm3d/val and hm3d/minival pairs were not used for the pre-training but kept for validation purposes. + +Download metadata and extract it: +```bash +mkdir -p data/habitat_release_metadata/ +cd data/habitat_release_metadata/ +wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz +tar -xvf multiview_habitat_metadata.tar.gz +cd ../.. +# Location of the metadata +METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata" +``` + +Generate image pairs from metadata: +- The following command will print a list of commandlines to generate image pairs for each scene: +```bash +# Target output directory +PAIRS_DATASET_DIR="./data/habitat_release/" +python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR +``` +- One can launch multiple of such commands in parallel e.g. using GNU Parallel: +```bash +python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16 +``` + +## Metadata generation + +Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible: +```bash +# Print commandlines to generate image pairs from the different scenes available. +PAIRS_DATASET_DIR=MY_CUSTOM_PATH +python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR + +# Once a dataset is generated, pack metadata files for reproducibility. +METADATA_DIR=MY_CUSTON_PATH +python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR +``` diff --git a/stream3r/croco/datasets/habitat_sim/__init__.py b/stream3r/croco/datasets/habitat_sim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a49a3a9427e67836de554fed0bc7f6466adbe06 --- /dev/null +++ b/stream3r/croco/datasets/habitat_sim/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + diff --git a/stream3r/croco/datasets/habitat_sim/generate_from_metadata.py b/stream3r/croco/datasets/habitat_sim/generate_from_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..3e133fdd3fd7489a9f28e491c175e7ffcc764633 --- /dev/null +++ b/stream3r/croco/datasets/habitat_sim/generate_from_metadata.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +""" +Script to generate image pairs for a given scene reproducing poses provided in a metadata file. +""" +import argparse +import json +import os + +import cv2 +import PIL.Image +import quaternion +from datasets.habitat_sim.multiview_habitat_sim_generator import ( + MultiviewHabitatSimGenerator, +) +from datasets.habitat_sim.paths import SCENES_DATASET +from tqdm import tqdm + + +def generate_multiview_images_from_metadata( + metadata_filename, + output_dir, + overload_params=dict(), + scene_datasets_paths=None, + exist_ok=False, +): + """ + Generate images from a metadata file for reproducibility purposes. + """ + # Reorder paths by decreasing label length, to avoid collisions when testing if a string by such label + if scene_datasets_paths is not None: + scene_datasets_paths = dict( + sorted(scene_datasets_paths.items(), key=lambda x: len(x[0]), reverse=True) + ) + + with open(metadata_filename, "r") as f: + input_metadata = json.load(f) + metadata = dict() + for key, value in input_metadata.items(): + # Optionally replace some paths + if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": + if scene_datasets_paths is not None: + for dataset_label, dataset_path in scene_datasets_paths.items(): + if value.startswith(dataset_label): + value = os.path.normpath( + os.path.join( + dataset_path, os.path.relpath(value, dataset_label) + ) + ) + break + metadata[key] = value + + # Overload some parameters + for key, value in overload_params.items(): + metadata[key] = value + + generation_entries = dict( + [ + (key, value) + for key, value in metadata.items() + if not (key in ("multiviews", "output_dir", "generate_depth")) + ] + ) + generate_depth = metadata["generate_depth"] + + os.makedirs(output_dir, exist_ok=exist_ok) + + generator = MultiviewHabitatSimGenerator(**generation_entries) + + # Generate views + for idx_label, data in tqdm(metadata["multiviews"].items()): + positions = data["positions"] + orientations = data["orientations"] + n = len(positions) + for oidx in range(n): + observation = generator.render_viewpoint( + positions[oidx], quaternion.from_float_array(orientations[oidx]) + ) + observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1 + # Color image saved using PIL + img = PIL.Image.fromarray(observation["color"][:, :, :3]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg") + img.save(filename) + if generate_depth: + # Depth image as EXR file + filename = os.path.join( + output_dir, f"{idx_label}_{observation_label}_depth.exr" + ) + cv2.imwrite( + filename, + observation["depth"], + [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF], + ) + # Camera parameters + camera_params = dict( + [ + (key, observation[key].tolist()) + for key in ("camera_intrinsics", "R_cam2world", "t_cam2world") + ] + ) + filename = os.path.join( + output_dir, f"{idx_label}_{observation_label}_camera_params.json" + ) + with open(filename, "w") as f: + json.dump(camera_params, f) + # Save metadata + with open(os.path.join(output_dir, "metadata.json"), "w") as f: + json.dump(metadata, f) + + generator.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_filename", required=True) + parser.add_argument("--output_dir", required=True) + args = parser.parse_args() + + generate_multiview_images_from_metadata( + metadata_filename=args.metadata_filename, + output_dir=args.output_dir, + scene_datasets_paths=SCENES_DATASET, + overload_params=dict(), + exist_ok=True, + ) diff --git a/stream3r/croco/datasets/habitat_sim/generate_from_metadata_files.py b/stream3r/croco/datasets/habitat_sim/generate_from_metadata_files.py new file mode 100644 index 0000000000000000000000000000000000000000..afcc62116624beb162e4402df4ca3527d04eb620 --- /dev/null +++ b/stream3r/croco/datasets/habitat_sim/generate_from_metadata_files.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +""" +Script generating commandlines to generate image pairs from metadata files. +""" +import argparse +import glob +import os + +from tqdm import tqdm + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", required=True) + parser.add_argument("--output_dir", required=True) + parser.add_argument( + "--prefix", + default="", + help="Commanline prefix, useful e.g. to setup environment.", + ) + args = parser.parse_args() + + input_metadata_filenames = glob.iglob( + f"{args.input_dir}/**/metadata.json", recursive=True + ) + + for metadata_filename in tqdm(input_metadata_filenames): + output_dir = os.path.join( + args.output_dir, + os.path.relpath(os.path.dirname(metadata_filename), args.input_dir), + ) + # Do not process the scene if the metadata file already exists + if os.path.exists(os.path.join(output_dir, "metadata.json")): + continue + commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}" + print(commandline) diff --git a/stream3r/croco/datasets/habitat_sim/generate_multiview_images.py b/stream3r/croco/datasets/habitat_sim/generate_multiview_images.py new file mode 100644 index 0000000000000000000000000000000000000000..8b51827af7e8376fc68df4a93157acc5e80afac9 --- /dev/null +++ b/stream3r/croco/datasets/habitat_sim/generate_multiview_images.py @@ -0,0 +1,237 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import argparse +import json +import os + +import cv2 +import numpy as np +import PIL.Image +import quaternion +from datasets.habitat_sim.multiview_habitat_sim_generator import ( + MultiviewHabitatSimGenerator, + NoNaviguableSpaceError, +) +from datasets.habitat_sim.paths import list_scenes_available +from tqdm import tqdm + + +def generate_multiview_images_for_scene( + scene_dataset_config_file, + scene, + navmesh, + output_dir, + views_count, + size, + exist_ok=False, + generate_depth=False, + **kwargs, +): + """ + Generate tuples of overlapping views for a given scene. + generate_depth: generate depth images and camera parameters. + """ + if os.path.exists(output_dir) and not exist_ok: + print(f"Scene {scene}: data already generated. Ignoring generation.") + return + try: + print(f"Scene {scene}: {size} multiview acquisitions to generate...") + os.makedirs(output_dir, exist_ok=exist_ok) + + metadata_filename = os.path.join(output_dir, "metadata.json") + + metadata_template = dict( + scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + views_count=views_count, + size=size, + generate_depth=generate_depth, + **kwargs, + ) + metadata_template["multiviews"] = dict() + + if os.path.exists(metadata_filename): + print("Metadata file already exists:", metadata_filename) + print("Loading already generated metadata file...") + with open(metadata_filename, "r") as f: + metadata = json.load(f) + + for key in metadata_template.keys(): + if key != "multiviews": + assert ( + metadata_template[key] == metadata[key] + ), f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}." + else: + print("No temporary file found. Starting generation from scratch...") + metadata = metadata_template + + starting_id = len(metadata["multiviews"]) + print(f"Starting generation from index {starting_id}/{size}...") + if starting_id >= size: + print("Generation already done.") + return + + generator = MultiviewHabitatSimGenerator( + scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + views_count=views_count, + size=size, + **kwargs, + ) + + for idx in tqdm(range(starting_id, size)): + # Generate / re-generate the observations + try: + data = generator[idx] + observations = data["observations"] + positions = data["positions"] + orientations = data["orientations"] + + idx_label = f"{idx:08}" + for oidx, observation in enumerate(observations): + observation_label = ( + f"{oidx + 1}" # Leonid is indexing starting from 1 + ) + # Color image saved using PIL + img = PIL.Image.fromarray(observation["color"][:, :, :3]) + filename = os.path.join( + output_dir, f"{idx_label}_{observation_label}.jpeg" + ) + img.save(filename) + if generate_depth: + # Depth image as EXR file + filename = os.path.join( + output_dir, f"{idx_label}_{observation_label}_depth.exr" + ) + cv2.imwrite( + filename, + observation["depth"], + [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF], + ) + # Camera parameters + camera_params = dict( + [ + (key, observation[key].tolist()) + for key in ( + "camera_intrinsics", + "R_cam2world", + "t_cam2world", + ) + ] + ) + filename = os.path.join( + output_dir, + f"{idx_label}_{observation_label}_camera_params.json", + ) + with open(filename, "w") as f: + json.dump(camera_params, f) + metadata["multiviews"][idx_label] = { + "positions": positions.tolist(), + "orientations": orientations.tolist(), + "covisibility_ratios": data["covisibility_ratios"].tolist(), + "valid_fractions": data["valid_fractions"].tolist(), + "pairwise_visibility_ratios": data[ + "pairwise_visibility_ratios" + ].tolist(), + } + except RecursionError: + print( + "Recursion error: unable to sample observations for this scene. We will stop there." + ) + break + + # Regularly save a temporary metadata file, in case we need to restart the generation + if idx % 10 == 0: + with open(metadata_filename, "w") as f: + json.dump(metadata, f) + + # Save metadata + with open(metadata_filename, "w") as f: + json.dump(metadata, f) + + generator.close() + except NoNaviguableSpaceError: + pass + + +def create_commandline(scene_data, generate_depth, exist_ok=False): + """ + Create a commandline string to generate a scene. + """ + + def my_formatting(val): + if val is None or val == "": + return '""' + else: + return val + + commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)} + --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)} + --navmesh {my_formatting(scene_data.navmesh)} + --output_dir {my_formatting(scene_data.output_dir)} + --generate_depth {int(generate_depth)} + --exist_ok {int(exist_ok)} + """ + commandline = " ".join(commandline.split()) + return commandline + + +if __name__ == "__main__": + os.umask(2) + + parser = argparse.ArgumentParser( + description="""Example of use -- listing commands to generate data for scenes available: + > python datasets/habitat_sim/generate_multiview_habitat_images.py --list_commands + """ + ) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--list_commands", action="store_true", help="list commandlines to run if true" + ) + parser.add_argument("--scene", type=str, default="") + parser.add_argument("--scene_dataset_config_file", type=str, default="") + parser.add_argument("--navmesh", type=str, default="") + + parser.add_argument("--generate_depth", type=int, default=1) + parser.add_argument("--exist_ok", type=int, default=0) + + kwargs = dict(resolution=(256, 256), hfov=60, views_count=2, size=1000) + + args = parser.parse_args() + generate_depth = bool(args.generate_depth) + exist_ok = bool(args.exist_ok) + + if args.list_commands: + # Listing scenes available... + scenes_data = list_scenes_available(base_output_dir=args.output_dir) + + for scene_data in scenes_data: + print( + create_commandline( + scene_data, generate_depth=generate_depth, exist_ok=exist_ok + ) + ) + else: + if args.scene == "" or args.output_dir == "": + print("Missing scene or output dir argument!") + print(parser.format_help()) + else: + generate_multiview_images_for_scene( + scene=args.scene, + scene_dataset_config_file=args.scene_dataset_config_file, + navmesh=args.navmesh, + output_dir=args.output_dir, + exist_ok=exist_ok, + generate_depth=generate_depth, + **kwargs, + ) diff --git a/stream3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py b/stream3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c0ea0fac439986a654dfee3fbfc2541a4baf6a44 --- /dev/null +++ b/stream3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py @@ -0,0 +1,508 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + + +import cv2 +import habitat_sim +import numpy as np +import quaternion +from sklearn.neighbors import NearestNeighbors + +# OpenCV to habitat camera convention transformation +R_OPENCV2HABITAT = np.stack( + (habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0 +) +R_HABITAT2OPENCV = R_OPENCV2HABITAT.T +DEG2RAD = np.pi / 180 + + +def compute_camera_intrinsics(height, width, hfov): + f = width / 2 / np.tan(hfov / 2 * np.pi / 180) + cu, cv = width / 2, height / 2 + return f, cu, cv + + +def compute_camera_pose_opencv_convention(camera_position, camera_orientation): + R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT + t_cam2world = np.asarray(camera_position) + return R_cam2world, t_cam2world + + +def compute_pointmap(depthmap, hfov): + """Compute a HxWx3 pointmap in camera frame from a HxW depth map.""" + height, width = depthmap.shape + f, cu, cv = compute_camera_intrinsics(height, width, hfov) + # Cast depth map to point + z_cam = depthmap + u, v = np.meshgrid(range(width), range(height)) + x_cam = (u - cu) / f * z_cam + y_cam = (v - cv) / f * z_cam + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1) + return X_cam + + +def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation): + """Return a 3D point cloud corresponding to valid pixels of the depth map""" + R_cam2world, t_cam2world = compute_camera_pose_opencv_convention( + camera_position, camera_rotation + ) + + X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov) + valid_mask = X_cam[:, :, 2] != 0.0 + + X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()] + X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3) + return X_world + + +def compute_pointcloud_overlaps_scikit( + pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False +): + """ + Compute 'overlapping' metrics based on a distance threshold between two point clouds. + """ + nbrs = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(pointcloud2) + distances, indices = nbrs.kneighbors(pointcloud1) + intersection1 = np.count_nonzero(distances.flatten() < distance_threshold) + + data = {"intersection1": intersection1, "size1": len(pointcloud1)} + if compute_symmetric: + nbrs = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(pointcloud1) + distances, indices = nbrs.kneighbors(pointcloud2) + intersection2 = np.count_nonzero(distances.flatten() < distance_threshold) + data["intersection2"] = intersection2 + data["size2"] = len(pointcloud2) + + return data + + +def _append_camera_parameters(observation, hfov, camera_location, camera_rotation): + """ + Add camera parameters to the observation dictionnary produced by Habitat-Sim + In-place modifications. + """ + R_cam2world, t_cam2world = compute_camera_pose_opencv_convention( + camera_location, camera_rotation + ) + height, width = observation["depth"].shape + f, cu, cv = compute_camera_intrinsics(height, width, hfov) + K = np.asarray([[f, 0, cu], [0, f, cv], [0, 0, 1.0]]) + observation["camera_intrinsics"] = K + observation["t_cam2world"] = t_cam2world + observation["R_cam2world"] = R_cam2world + + +def look_at(eye, center, up, return_cam2world=True): + """ + Return camera pose looking at a given center point. + Analogous of gluLookAt function, using OpenCV camera convention. + """ + z = center - eye + z /= np.linalg.norm(z, axis=-1, keepdims=True) + y = -up + y = y - np.sum(y * z, axis=-1, keepdims=True) * z + y /= np.linalg.norm(y, axis=-1, keepdims=True) + x = np.cross(y, z, axis=-1) + + if return_cam2world: + R = np.stack((x, y, z), axis=-1) + t = eye + else: + # World to camera transformation + # Transposed matrix + R = np.stack((x, y, z), axis=-2) + t = -np.einsum("...ij, ...j", R, eye) + return R, t + + +def look_at_for_habitat(eye, center, up, return_cam2world=True): + R, t = look_at(eye, center, up) + orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T) + return orientation, t + + +def generate_orientation_noise(pan_range, tilt_range, roll_range): + return ( + quaternion.from_rotation_vector( + np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP + ) + * quaternion.from_rotation_vector( + np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT + ) + * quaternion.from_rotation_vector( + np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT + ) + ) + + +class NoNaviguableSpaceError(RuntimeError): + def __init__(self, *args): + super().__init__(*args) + + +class MultiviewHabitatSimGenerator: + def __init__( + self, + scene, + navmesh, + scene_dataset_config_file, + resolution=(240, 320), + views_count=2, + hfov=60, + gpu_id=0, + size=10000, + minimum_covisibility=0.5, + transform=None, + ): + self.scene = scene + self.navmesh = navmesh + self.scene_dataset_config_file = scene_dataset_config_file + self.resolution = resolution + self.views_count = views_count + assert self.views_count >= 1 + self.hfov = hfov + self.gpu_id = gpu_id + self.size = size + self.transform = transform + + # Noise added to camera orientation + self.pan_range = (-3, 3) + self.tilt_range = (-10, 10) + self.roll_range = (-5, 5) + + # Height range to sample cameras + self.height_range = (1.2, 1.8) + + # Random steps between the camera views + self.random_steps_count = 5 + self.random_step_variance = 2.0 + + # Minimum fraction of the scene which should be valid (well defined depth) + self.minimum_valid_fraction = 0.7 + + # Distance threshold to see to select pairs + self.distance_threshold = 0.05 + # Minimum IoU of a view point cloud with respect to the reference view to be kept. + self.minimum_covisibility = minimum_covisibility + + # Maximum number of retries. + self.max_attempts_count = 100 + + self.seed = None + self._lazy_initialization() + + def _lazy_initialization(self): + # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly + if self.seed == None: + # Re-seed numpy generator + np.random.seed() + self.seed = np.random.randint(2**32 - 1) + sim_cfg = habitat_sim.SimulatorConfiguration() + sim_cfg.scene_id = self.scene + if ( + self.scene_dataset_config_file is not None + and self.scene_dataset_config_file != "" + ): + sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file + sim_cfg.random_seed = self.seed + sim_cfg.load_semantic_mesh = False + sim_cfg.gpu_device_id = self.gpu_id + + depth_sensor_spec = habitat_sim.CameraSensorSpec() + depth_sensor_spec.uuid = "depth" + depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH + depth_sensor_spec.resolution = self.resolution + depth_sensor_spec.hfov = self.hfov + depth_sensor_spec.position = [0.0, 0.0, 0] + depth_sensor_spec.orientation + + rgb_sensor_spec = habitat_sim.CameraSensorSpec() + rgb_sensor_spec.uuid = "color" + rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR + rgb_sensor_spec.resolution = self.resolution + rgb_sensor_spec.hfov = self.hfov + rgb_sensor_spec.position = [0.0, 0.0, 0] + agent_cfg = habitat_sim.agent.AgentConfiguration( + sensor_specifications=[rgb_sensor_spec, depth_sensor_spec] + ) + + cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg]) + self.sim = habitat_sim.Simulator(cfg) + if self.navmesh is not None and self.navmesh != "": + # Use pre-computed navmesh when available (usually better than those generated automatically) + self.sim.pathfinder.load_nav_mesh(self.navmesh) + + if not self.sim.pathfinder.is_loaded: + # Try to compute a navmesh + navmesh_settings = habitat_sim.NavMeshSettings() + navmesh_settings.set_defaults() + self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True) + + # Ensure that the navmesh is not empty + if not self.sim.pathfinder.is_loaded: + raise NoNaviguableSpaceError( + f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})" + ) + + self.agent = self.sim.initialize_agent(agent_id=0) + + def close(self): + self.sim.close() + + def __del__(self): + self.sim.close() + + def __len__(self): + return self.size + + def sample_random_viewpoint(self): + """Sample a random viewpoint using the navmesh""" + nav_point = self.sim.pathfinder.get_random_navigable_point() + + # Sample a random viewpoint height + viewpoint_height = np.random.uniform(*self.height_range) + viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP + viewpoint_orientation = quaternion.from_rotation_vector( + np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP + ) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range) + return viewpoint_position, viewpoint_orientation, nav_point + + def sample_other_random_viewpoint(self, observed_point, nav_point): + """Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point.""" + other_nav_point = nav_point + + walk_directions = self.random_step_variance * np.asarray([1, 0, 1]) + for i in range(self.random_steps_count): + temp = self.sim.pathfinder.snap_point( + other_nav_point + walk_directions * np.random.normal(size=3) + ) + # Snapping may return nan when it fails + if not np.isnan(temp[0]): + other_nav_point = temp + + other_viewpoint_height = np.random.uniform(*self.height_range) + other_viewpoint_position = ( + other_nav_point + other_viewpoint_height * habitat_sim.geo.UP + ) + + # Set viewing direction towards the central point + rotation, position = look_at_for_habitat( + eye=other_viewpoint_position, + center=observed_point, + up=habitat_sim.geo.UP, + return_cam2world=True, + ) + rotation = rotation * generate_orientation_noise( + self.pan_range, self.tilt_range, self.roll_range + ) + return position, rotation, other_nav_point + + def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud): + """Check if a viewpoint is valid and overlaps significantly with a reference one.""" + # Observation + pixels_count = self.resolution[0] * self.resolution[1] + valid_fraction = len(other_pointcloud) / pixels_count + assert valid_fraction <= 1.0 and valid_fraction >= 0.0 + overlap = compute_pointcloud_overlaps_scikit( + ref_pointcloud, + other_pointcloud, + self.distance_threshold, + compute_symmetric=True, + ) + covisibility = min( + overlap["intersection1"] / pixels_count, + overlap["intersection2"] / pixels_count, + ) + is_valid = (valid_fraction >= self.minimum_valid_fraction) and ( + covisibility >= self.minimum_covisibility + ) + return is_valid, valid_fraction, covisibility + + def is_other_viewpoint_overlapping( + self, ref_pointcloud, observation, position, rotation + ): + """Check if a viewpoint is valid and overlaps significantly with a reference one.""" + # Observation + other_pointcloud = compute_pointcloud( + observation["depth"], self.hfov, position, rotation + ) + return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud) + + def render_viewpoint(self, viewpoint_position, viewpoint_orientation): + agent_state = habitat_sim.AgentState() + agent_state.position = viewpoint_position + agent_state.rotation = viewpoint_orientation + self.agent.set_state(agent_state) + viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0) + _append_camera_parameters( + viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation + ) + return viewpoint_observations + + def __getitem__(self, useless_idx): + ref_position, ref_orientation, nav_point = self.sample_random_viewpoint() + ref_observations = self.render_viewpoint(ref_position, ref_orientation) + # Extract point cloud + ref_pointcloud = compute_pointcloud( + depthmap=ref_observations["depth"], + hfov=self.hfov, + camera_position=ref_position, + camera_rotation=ref_orientation, + ) + + pixels_count = self.resolution[0] * self.resolution[1] + ref_valid_fraction = len(ref_pointcloud) / pixels_count + assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0 + if ref_valid_fraction < self.minimum_valid_fraction: + # This should produce a recursion error at some point when something is very wrong. + return self[0] + # Pick an reference observed point in the point cloud + observed_point = np.mean(ref_pointcloud, axis=0) + + # Add the first image as reference + viewpoints_observations = [ref_observations] + viewpoints_covisibility = [ref_valid_fraction] + viewpoints_positions = [ref_position] + viewpoints_orientations = [quaternion.as_float_array(ref_orientation)] + viewpoints_clouds = [ref_pointcloud] + viewpoints_valid_fractions = [ref_valid_fraction] + + for _ in range(self.views_count - 1): + # Generate an other viewpoint using some dummy random walk + successful_sampling = False + for sampling_attempt in range(self.max_attempts_count): + position, rotation, _ = self.sample_other_random_viewpoint( + observed_point, nav_point + ) + # Observation + other_viewpoint_observations = self.render_viewpoint(position, rotation) + other_pointcloud = compute_pointcloud( + other_viewpoint_observations["depth"], self.hfov, position, rotation + ) + + ( + is_valid, + valid_fraction, + covisibility, + ) = self.is_other_pointcloud_overlapping( + ref_pointcloud, other_pointcloud + ) + if is_valid: + successful_sampling = True + break + if not successful_sampling: + print("WARNING: Maximum number of attempts reached.") + # Dirty hack, try using a novel original viewpoint + return self[0] + viewpoints_observations.append(other_viewpoint_observations) + viewpoints_covisibility.append(covisibility) + viewpoints_positions.append(position) + viewpoints_orientations.append( + quaternion.as_float_array(rotation) + ) # WXYZ convention for the quaternion encoding. + viewpoints_clouds.append(other_pointcloud) + viewpoints_valid_fractions.append(valid_fraction) + + # Estimate relations between all pairs of images + pairwise_visibility_ratios = np.ones( + (len(viewpoints_observations), len(viewpoints_observations)) + ) + for i in range(len(viewpoints_observations)): + pairwise_visibility_ratios[i, i] = viewpoints_valid_fractions[i] + for j in range(i + 1, len(viewpoints_observations)): + overlap = compute_pointcloud_overlaps_scikit( + viewpoints_clouds[i], + viewpoints_clouds[j], + self.distance_threshold, + compute_symmetric=True, + ) + pairwise_visibility_ratios[i, j] = ( + overlap["intersection1"] / pixels_count + ) + pairwise_visibility_ratios[j, i] = ( + overlap["intersection2"] / pixels_count + ) + + # IoU is relative to the image 0 + data = { + "observations": viewpoints_observations, + "positions": np.asarray(viewpoints_positions), + "orientations": np.asarray(viewpoints_orientations), + "covisibility_ratios": np.asarray(viewpoints_covisibility), + "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float), + "pairwise_visibility_ratios": np.asarray( + pairwise_visibility_ratios, dtype=float + ), + } + + if self.transform is not None: + data = self.transform(data) + return data + + def generate_random_spiral_trajectory( + self, + images_count=100, + max_radius=0.5, + half_turns=5, + use_constant_orientation=False, + ): + """ + Return a list of images corresponding to a spiral trajectory from a random starting point. + Useful to generate nice visualisations. + Use an even number of half turns to get a nice "C1-continuous" loop effect + """ + ref_position, ref_orientation, navpoint = self.sample_random_viewpoint() + ref_observations = self.render_viewpoint(ref_position, ref_orientation) + ref_pointcloud = compute_pointcloud( + depthmap=ref_observations["depth"], + hfov=self.hfov, + camera_position=ref_position, + camera_rotation=ref_orientation, + ) + pixels_count = self.resolution[0] * self.resolution[1] + if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction: + # Dirty hack: ensure that the valid part of the image is significant + return self.generate_random_spiral_trajectory( + images_count, max_radius, half_turns, use_constant_orientation + ) + + # Pick an observed point in the point cloud + observed_point = np.mean(ref_pointcloud, axis=0) + ref_R, ref_t = compute_camera_pose_opencv_convention( + ref_position, ref_orientation + ) + + images = [] + is_valid = [] + # Spiral trajectory, use_constant orientation + for i, alpha in enumerate(np.linspace(0, 1, images_count)): + r = max_radius * np.abs( + np.sin(alpha * np.pi) + ) # Increase then decrease the radius + theta = alpha * half_turns * np.pi + x = r * np.cos(theta) + y = r * np.sin(theta) + z = 0.0 + position = ( + ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3, 1)).flatten() + ) + if use_constant_orientation: + orientation = ref_orientation + else: + # trajectory looking at a mean point in front of the ref observation + orientation, position = look_at_for_habitat( + eye=position, center=observed_point, up=habitat_sim.geo.UP + ) + observations = self.render_viewpoint(position, orientation) + images.append(observations["color"][..., :3]) + _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping( + ref_pointcloud, observations, position, orientation + ) + is_valid.append(_is_valid) + return images, np.all(is_valid) diff --git a/stream3r/croco/datasets/habitat_sim/pack_metadata_files.py b/stream3r/croco/datasets/habitat_sim/pack_metadata_files.py new file mode 100644 index 0000000000000000000000000000000000000000..4de14e152499d292698090c8daa7b2e42158a630 --- /dev/null +++ b/stream3r/croco/datasets/habitat_sim/pack_metadata_files.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +""" +Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere. +""" +import argparse +import collections +import glob +import json +import os + +from datasets.habitat_sim.paths import SCENES_DATASET +from tqdm import tqdm + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input_dir") + parser.add_argument("output_dir") + args = parser.parse_args() + + input_dirname = args.input_dir + output_dirname = args.output_dir + + input_metadata_filenames = glob.iglob( + f"{input_dirname}/**/metadata.json", recursive=True + ) + + images_count = collections.defaultdict(lambda: 0) + + os.makedirs(output_dirname) + for input_filename in tqdm(input_metadata_filenames): + # Ignore empty files + with open(input_filename, "r") as f: + original_metadata = json.load(f) + if ( + "multiviews" not in original_metadata + or len(original_metadata["multiviews"]) == 0 + ): + print("No views in", input_filename) + continue + + relpath = os.path.relpath(input_filename, input_dirname) + print(relpath) + + # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability. + # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting by the same string pattern. + scenes_dataset_paths = dict( + sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True) + ) + metadata = dict() + for key, value in original_metadata.items(): + if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": + known_path = False + for dataset, dataset_path in scenes_dataset_paths.items(): + if value.startswith(dataset_path): + value = os.path.join( + dataset, os.path.relpath(value, dataset_path) + ) + known_path = True + break + if not known_path: + raise KeyError("Unknown path:" + value) + metadata[key] = value + + # Compile some general statistics while packing data + scene_split = metadata["scene"].split("/") + upper_level = ( + "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0] + ) + images_count[upper_level] += len(metadata["multiviews"]) + + output_filename = os.path.join(output_dirname, relpath) + os.makedirs(os.path.dirname(output_filename), exist_ok=True) + with open(output_filename, "w") as f: + json.dump(metadata, f) + + # Print statistics + print("Images count:") + for upper_level, count in images_count.items(): + print(f"- {upper_level}: {count}") diff --git a/stream3r/croco/datasets/habitat_sim/paths.py b/stream3r/croco/datasets/habitat_sim/paths.py new file mode 100644 index 0000000000000000000000000000000000000000..8d673d7a1b244571501d1bc5a0d7550693801d10 --- /dev/null +++ b/stream3r/croco/datasets/habitat_sim/paths.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +""" +Paths to Habitat-Sim scenes +""" + +import collections +import os + +from tqdm import tqdm + +# Hardcoded path to the different scene datasets +SCENES_DATASET = { + "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/", + "gibson": "./data/habitat-sim-data/scene_datasets/gibson/", + "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/", + "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/", + "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/", + "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/", + "scannet": "./data/habitat-sim/scene_datasets/scannet/", +} + +SceneData = collections.namedtuple( + "SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"] +) + + +def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]): + scene_dataset_config_file = os.path.join( + base_path, "replicaCAD.scene_dataset_config.json" + ) + scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"] + navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + [ + "empty_stage.navmesh" + ] + scenes_data = [] + for idx in range(len(scenes)): + output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx]) + # Add scene + data = SceneData( + scene_dataset_config_file=scene_dataset_config_file, + scene=scenes[idx] + ".scene_instance.json", + navmesh=os.path.join(base_path, navmeshes[idx]), + output_dir=output_dir, + ) + scenes_data.append(data) + return scenes_data + + +def list_replica_cad_baked_lighting_scenes( + base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"] +): + scene_dataset_config_file = os.path.join( + base_path, "replicaCAD_baked.scene_dataset_config.json" + ) + scenes = sum( + [[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], [] + ) + navmeshes = "" # [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"] + scenes_data = [] + for idx in range(len(scenes)): + output_dir = os.path.join( + base_output_dir, "replica_cad_baked_lighting", scenes[idx] + ) + data = SceneData( + scene_dataset_config_file=scene_dataset_config_file, + scene=scenes[idx], + navmesh="", + output_dir=output_dir, + ) + scenes_data.append(data) + return scenes_data + + +def list_replica_scenes(base_output_dir, base_path): + scenes_data = [] + for scene_id in os.listdir(base_path): + scene = os.path.join(base_path, scene_id, "mesh.ply") + navmesh = os.path.join( + base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh" + ) # Not sure if I should use it + scene_dataset_config_file = "" + output_dir = os.path.join(base_output_dir, scene_id) + # Add scene only if it does not exist already, or if exist_ok + data = SceneData( + scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + output_dir=output_dir, + ) + scenes_data.append(data) + return scenes_data + + +def list_scenes(base_output_dir, base_path): + """ + Generic method iterating through a base_path folder to find scenes. + """ + scenes_data = [] + for root, dirs, files in os.walk(base_path, followlinks=True): + folder_scenes_data = [] + for file in files: + name, ext = os.path.splitext(file) + if ext == ".glb": + scene = os.path.join(root, name + ".glb") + navmesh = os.path.join(root, name + ".navmesh") + if not os.path.exists(navmesh): + navmesh = "" + relpath = os.path.relpath(root, base_path) + output_dir = os.path.abspath( + os.path.join(base_output_dir, relpath, name) + ) + data = SceneData( + scene_dataset_config_file="", + scene=scene, + navmesh=navmesh, + output_dir=output_dir, + ) + folder_scenes_data.append(data) + + # Specific check for HM3D: + # When two meshesxxxx.basis.glb and xxxx.glb are present, use the 'basis' version. + basis_scenes = [ + data.scene[: -len(".basis.glb")] + for data in folder_scenes_data + if data.scene.endswith(".basis.glb") + ] + if len(basis_scenes) != 0: + folder_scenes_data = [ + data + for data in folder_scenes_data + if not (data.scene[: -len(".glb")] in basis_scenes) + ] + + scenes_data.extend(folder_scenes_data) + return scenes_data + + +def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET): + scenes_data = [] + + # HM3D + for split in ("minival", "train", "val", "examples"): + scenes_data += list_scenes( + base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"), + base_path=f"{scenes_dataset_paths['hm3d']}/{split}", + ) + + # Gibson + scenes_data += list_scenes( + base_output_dir=os.path.join(base_output_dir, "gibson"), + base_path=scenes_dataset_paths["gibson"], + ) + + # Habitat test scenes (just a few) + scenes_data += list_scenes( + base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"), + base_path=scenes_dataset_paths["habitat-test-scenes"], + ) + + # ReplicaCAD (baked lightning) + scenes_data += list_replica_cad_baked_lighting_scenes( + base_output_dir=base_output_dir + ) + + # ScanNet + scenes_data += list_scenes( + base_output_dir=os.path.join(base_output_dir, "scannet"), + base_path=scenes_dataset_paths["scannet"], + ) + + # Replica + list_replica_scenes( + base_output_dir=os.path.join(base_output_dir, "replica"), + base_path=scenes_dataset_paths["replica"], + ) + return scenes_data diff --git a/stream3r/croco/datasets/pairs_dataset.py b/stream3r/croco/datasets/pairs_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..989ba30ce54630326c0e476adadbb00f23f15629 --- /dev/null +++ b/stream3r/croco/datasets/pairs_dataset.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import os + +from datasets.transforms import get_pair_transforms +from PIL import Image +from torch.utils.data import Dataset + + +def load_image(impath): + return Image.open(impath) + + +def load_pairs_from_cache_file(fname, root=""): + assert os.path.isfile( + fname + ), "cannot parse pairs from {:s}, file does not exist".format(fname) + with open(fname, "r") as fid: + lines = fid.read().strip().splitlines() + pairs = [ + (os.path.join(root, l.split()[0]), os.path.join(root, l.split()[1])) + for l in lines + ] + return pairs + + +def load_pairs_from_list_file(fname, root=""): + assert os.path.isfile( + fname + ), "cannot parse pairs from {:s}, file does not exist".format(fname) + with open(fname, "r") as fid: + lines = fid.read().strip().splitlines() + pairs = [ + (os.path.join(root, l + "_1.jpg"), os.path.join(root, l + "_2.jpg")) + for l in lines + if not l.startswith("#") + ] + return pairs + + +def write_cache_file(fname, pairs, root=""): + if len(root) > 0: + if not root.endswith("/"): + root += "/" + assert os.path.isdir(root) + s = "" + for im1, im2 in pairs: + if len(root) > 0: + assert im1.startswith(root), im1 + assert im2.startswith(root), im2 + s += "{:s} {:s}\n".format(im1[len(root) :], im2[len(root) :]) + with open(fname, "w") as fid: + fid.write(s[:-1]) + + +def parse_and_cache_all_pairs(dname, data_dir="./data/"): + if dname == "habitat_release": + dirname = os.path.join(data_dir, "habitat_release") + assert os.path.isdir(dirname), ( + "cannot find folder for habitat_release pairs: " + dirname + ) + cache_file = os.path.join(dirname, "pairs.txt") + assert not os.path.isfile(cache_file), ( + "cache file already exists: " + cache_file + ) + + print("Parsing pairs for dataset: " + dname) + pairs = [] + for root, dirs, files in os.walk(dirname): + if "val" in root: + continue + dirs.sort() + pairs += [ + ( + os.path.join(root, f), + os.path.join(root, f[: -len("_1.jpeg")] + "_2.jpeg"), + ) + for f in sorted(files) + if f.endswith("_1.jpeg") + ] + print("Found {:,} pairs".format(len(pairs))) + print("Writing cache to: " + cache_file) + write_cache_file(cache_file, pairs, root=dirname) + + else: + raise NotImplementedError("Unknown dataset: " + dname) + + +def dnames_to_image_pairs(dnames, data_dir="./data/"): + """ + dnames: list of datasets with image pairs, separated by + + """ + all_pairs = [] + for dname in dnames.split("+"): + if dname == "habitat_release": + dirname = os.path.join(data_dir, "habitat_release") + assert os.path.isdir(dirname), ( + "cannot find folder for habitat_release pairs: " + dirname + ) + cache_file = os.path.join(dirname, "pairs.txt") + assert os.path.isfile(cache_file), ( + "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. " + + cache_file + ) + pairs = load_pairs_from_cache_file(cache_file, root=dirname) + elif dname in ["ARKitScenes", "MegaDepth", "3DStreetView", "IndoorVL"]: + dirname = os.path.join(data_dir, dname + "_crops") + assert os.path.isdir( + dirname + ), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname) + list_file = os.path.join(dirname, "listing.txt") + assert os.path.isfile( + list_file + ), "cannot find list file for {:s} pairs, see instructions. {:s}".format( + dname, list_file + ) + pairs = load_pairs_from_list_file(list_file, root=dirname) + print(" {:s}: {:,} pairs".format(dname, len(pairs))) + all_pairs += pairs + if "+" in dnames: + print(" Total: {:,} pairs".format(len(all_pairs))) + return all_pairs + + +class PairsDataset(Dataset): + def __init__( + self, dnames, trfs="", totensor=True, normalize=True, data_dir="./data/" + ): + super().__init__() + self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir) + self.transforms = get_pair_transforms( + transform_str=trfs, totensor=totensor, normalize=normalize + ) + + def __len__(self): + return len(self.image_pairs) + + def __getitem__(self, index): + im1path, im2path = self.image_pairs[index] + im1 = load_image(im1path) + im2 = load_image(im2path) + if self.transforms is not None: + im1, im2 = self.transforms(im1, im2) + return im1, im2 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + prog="Computing and caching list of pairs for a given dataset" + ) + parser.add_argument( + "--data_dir", default="./data/", type=str, help="path where data are stored" + ) + parser.add_argument( + "--dataset", default="habitat_release", type=str, help="name of the dataset" + ) + args = parser.parse_args() + parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir) diff --git a/stream3r/croco/datasets/transforms.py b/stream3r/croco/datasets/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6c1dba67b092466c53af39372f07fdb4a8a040b8 --- /dev/null +++ b/stream3r/croco/datasets/transforms.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch +import torchvision.transforms +import torchvision.transforms.functional as F + +# "Pair": apply a transform on a pair +# "Both": apply the exact same transform to both images + + +class ComposePair(torchvision.transforms.Compose): + def __call__(self, img1, img2): + for t in self.transforms: + img1, img2 = t(img1, img2) + return img1, img2 + + +class NormalizeBoth(torchvision.transforms.Normalize): + def forward(self, img1, img2): + img1 = super().forward(img1) + img2 = super().forward(img2) + return img1, img2 + + +class ToTensorBoth(torchvision.transforms.ToTensor): + def __call__(self, img1, img2): + img1 = super().__call__(img1) + img2 = super().__call__(img2) + return img1, img2 + + +class RandomCropPair(torchvision.transforms.RandomCrop): + # the crop will be intentionally different for the two images with this class + def forward(self, img1, img2): + img1 = super().forward(img1) + img2 = super().forward(img2) + return img1, img2 + + +class ColorJitterPair(torchvision.transforms.ColorJitter): + # can be symmetric (same for both images) or assymetric (different jitter params for each image) depending on assymetric_prob + def __init__(self, assymetric_prob, **kwargs): + super().__init__(**kwargs) + self.assymetric_prob = assymetric_prob + + def jitter_one( + self, + img, + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ): + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return img + + def forward(self, img1, img2): + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + img1 = self.jitter_one( + img1, + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) + if torch.rand(1) < self.assymetric_prob: # assymetric: + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + img2 = self.jitter_one( + img2, + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) + return img1, img2 + + +def get_pair_transforms(transform_str, totensor=True, normalize=True): + # transform_str is eg crop224+color + trfs = [] + for s in transform_str.split("+"): + if s.startswith("crop"): + size = int(s[len("crop") :]) + trfs.append(RandomCropPair(size)) + elif s == "acolor": + trfs.append( + ColorJitterPair( + assymetric_prob=1.0, + brightness=(0.6, 1.4), + contrast=(0.6, 1.4), + saturation=(0.6, 1.4), + hue=0.0, + ) + ) + elif s == "": # if transform_str was "" + pass + else: + raise NotImplementedError("Unknown augmentation: " + s) + + if totensor: + trfs.append(ToTensorBoth()) + if normalize: + trfs.append( + NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ) + + if len(trfs) == 0: + return None + elif len(trfs) == 1: + return trfs + else: + return ComposePair(trfs) diff --git a/stream3r/croco/demo.py b/stream3r/croco/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..dcb8f1de2d9bbba7ac72b146b518df04ca8de3cc --- /dev/null +++ b/stream3r/croco/demo.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch +import torchvision.transforms +from models.croco import CroCoNet +from PIL import Image +from torchvision.transforms import Compose, Normalize, ToTensor + + +def main(): + device = torch.device( + "cuda:0" + if torch.cuda.is_available() and torch.cuda.device_count() > 0 + else "cpu" + ) + + # load 224x224 images and transform them to tensor + imagenet_mean = [0.485, 0.456, 0.406] + imagenet_mean_tensor = ( + torch.tensor(imagenet_mean).view(1, 3, 1, 1).to(device, non_blocking=True) + ) + imagenet_std = [0.229, 0.224, 0.225] + imagenet_std_tensor = ( + torch.tensor(imagenet_std).view(1, 3, 1, 1).to(device, non_blocking=True) + ) + trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)]) + image1 = ( + trfs(Image.open("assets/Chateau1.png").convert("RGB")) + .to(device, non_blocking=True) + .unsqueeze(0) + ) + image2 = ( + trfs(Image.open("assets/Chateau2.png").convert("RGB")) + .to(device, non_blocking=True) + .unsqueeze(0) + ) + + # load model + ckpt = torch.load("pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth", "cpu") + model = CroCoNet(**ckpt.get("croco_kwargs", {})).to(device) + model.eval() + msg = model.load_state_dict(ckpt["model"], strict=True) + + # forward + with torch.inference_mode(): + out, mask, target = model(image1, image2) + + # the output is normalized, thus use the mean/std of the actual image to go back to RGB space + patchified = model.patchify(image1) + mean = patchified.mean(dim=-1, keepdim=True) + var = patchified.var(dim=-1, keepdim=True) + decoded_image = model.unpatchify(out * (var + 1.0e-6) ** 0.5 + mean) + # undo imagenet normalization, prepare masked image + decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor + input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor + ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor + image_masks = model.unpatchify( + model.patchify(torch.ones_like(ref_image)) * mask[:, :, None] + ) + masked_input_image = (1 - image_masks) * input_image + + # make visualization + visualization = torch.cat( + (ref_image, masked_input_image, decoded_image, input_image), dim=3 + ) # 4*(B, 3, H, W) -> B, 3, H, W*4 + B, C, H, W = visualization.shape + visualization = visualization.permute(1, 0, 2, 3).reshape(C, B * H, W) + visualization = torchvision.transforms.functional.to_pil_image( + torch.clamp(visualization, 0, 1) + ) + fname = "demo_output.png" + visualization.save(fname) + print("Visualization save in " + fname) + + +if __name__ == "__main__": + main() diff --git a/stream3r/croco/interactive_demo.ipynb b/stream3r/croco/interactive_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..6cfc960af5baac9a69029c29a16eea4e24123a71 --- /dev/null +++ b/stream3r/croco/interactive_demo.ipynb @@ -0,0 +1,271 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Interactive demo of Cross-view Completion." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n", + "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "from models.croco import CroCoNet\n", + "from ipywidgets import interact, interactive, fixed, interact_manual\n", + "import ipywidgets as widgets\n", + "import matplotlib.pyplot as plt\n", + "import quaternion\n", + "import models.masking" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load CroCo model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')\n", + "model = CroCoNet( **ckpt.get('croco_kwargs',{}))\n", + "msg = model.load_state_dict(ckpt['model'], strict=True)\n", + "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n", + "device = torch.device('cuda:0' if use_gpu else 'cpu')\n", + "model = model.eval()\n", + "model = model.to(device=device)\n", + "print(msg)\n", + "\n", + "def process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches=False):\n", + " \"\"\"\n", + " Perform Cross-View completion using two input images, specified using Numpy arrays.\n", + " \"\"\"\n", + " # Replace the mask generator\n", + " model.mask_generator = models.masking.RandomMask(model.patch_embed.num_patches, masking_ratio)\n", + "\n", + " # ImageNet-1k color normalization\n", + " imagenet_mean = torch.as_tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1).to(device)\n", + " imagenet_std = torch.as_tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1).to(device)\n", + "\n", + " normalize_input_colors = True\n", + " is_output_normalized = True\n", + " with torch.no_grad():\n", + " # Cast data to torch\n", + " target_image = (torch.as_tensor(target_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n", + " ref_image = (torch.as_tensor(ref_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n", + "\n", + " if normalize_input_colors:\n", + " ref_image = (ref_image - imagenet_mean) / imagenet_std\n", + " target_image = (target_image - imagenet_mean) / imagenet_std\n", + "\n", + " out, mask, _ = model(target_image, ref_image)\n", + " # # get target\n", + " if not is_output_normalized:\n", + " predicted_image = model.unpatchify(out)\n", + " else:\n", + " # The output only contains higher order information,\n", + " # we retrieve mean and standard deviation from the actual target image\n", + " patchified = model.patchify(target_image)\n", + " mean = patchified.mean(dim=-1, keepdim=True)\n", + " var = patchified.var(dim=-1, keepdim=True)\n", + " pred_renorm = out * (var + 1.e-6)**.5 + mean\n", + " predicted_image = model.unpatchify(pred_renorm)\n", + "\n", + " image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])\n", + " masked_target_image = (1 - image_masks) * target_image\n", + " \n", + " if not reconstruct_unmasked_patches:\n", + " # Replace unmasked patches by their actual values\n", + " predicted_image = predicted_image * image_masks + masked_target_image\n", + "\n", + " # Unapply color normalization\n", + " if normalize_input_colors:\n", + " predicted_image = predicted_image * imagenet_std + imagenet_mean\n", + " masked_target_image = masked_target_image * imagenet_std + imagenet_mean\n", + " \n", + " # Cast to Numpy\n", + " masked_target_image = np.asarray(torch.clamp(masked_target_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n", + " predicted_image = np.asarray(torch.clamp(predicted_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n", + " return masked_target_image, predicted_image" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Use the Habitat simulator to render images from arbitrary viewpoints (requires habitat_sim to be installed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"MAGNUM_LOG\"]=\"quiet\"\n", + "os.environ[\"HABITAT_SIM_LOG\"]=\"quiet\"\n", + "import habitat_sim\n", + "\n", + "scene = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.glb\"\n", + "navmesh = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh\"\n", + "\n", + "sim_cfg = habitat_sim.SimulatorConfiguration()\n", + "if use_gpu: sim_cfg.gpu_device_id = 0\n", + "sim_cfg.scene_id = scene\n", + "sim_cfg.load_semantic_mesh = False\n", + "rgb_sensor_spec = habitat_sim.CameraSensorSpec()\n", + "rgb_sensor_spec.uuid = \"color\"\n", + "rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR\n", + "rgb_sensor_spec.resolution = (224,224)\n", + "rgb_sensor_spec.hfov = 56.56\n", + "rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n", + "rgb_sensor_spec.orientation = [0, 0, 0]\n", + "agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec])\n", + "\n", + "\n", + "cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])\n", + "sim = habitat_sim.Simulator(cfg)\n", + "if navmesh is not None:\n", + " sim.pathfinder.load_nav_mesh(navmesh)\n", + "agent = sim.initialize_agent(agent_id=0)\n", + "\n", + "def sample_random_viewpoint():\n", + " \"\"\" Sample a random viewpoint using the navmesh \"\"\"\n", + " nav_point = sim.pathfinder.get_random_navigable_point()\n", + " # Sample a random viewpoint height\n", + " viewpoint_height = np.random.uniform(1.0, 1.6)\n", + " viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP\n", + " viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(-np.pi, np.pi) * habitat_sim.geo.UP)\n", + " return viewpoint_position, viewpoint_orientation\n", + "\n", + "def render_viewpoint(position, orientation):\n", + " agent_state = habitat_sim.AgentState()\n", + " agent_state.position = position\n", + " agent_state.rotation = orientation\n", + " agent.set_state(agent_state)\n", + " viewpoint_observations = sim.get_sensor_observations(agent_ids=0)\n", + " image = viewpoint_observations['color'][:,:,:3]\n", + " image = np.asarray(np.clip(1.5 * np.asarray(image, dtype=float), 0, 255), dtype=np.uint8)\n", + " return image" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sample a random reference view" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ref_position, ref_orientation = sample_random_viewpoint()\n", + "ref_image = render_viewpoint(ref_position, ref_orientation)\n", + "plt.clf()\n", + "fig, axes = plt.subplots(1,1, squeeze=False, num=1)\n", + "axes[0,0].imshow(ref_image)\n", + "for ax in axes.flatten():\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Interactive cross-view completion using CroCo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstruct_unmasked_patches = False\n", + "\n", + "def show_demo(masking_ratio, x, y, z, panorama, elevation):\n", + " R = quaternion.as_rotation_matrix(ref_orientation)\n", + " target_position = ref_position + x * R[:,0] + y * R[:,1] + z * R[:,2]\n", + " target_orientation = (ref_orientation\n", + " * quaternion.from_rotation_vector(-elevation * np.pi/180 * habitat_sim.geo.LEFT) \n", + " * quaternion.from_rotation_vector(-panorama * np.pi/180 * habitat_sim.geo.UP))\n", + " \n", + " ref_image = render_viewpoint(ref_position, ref_orientation)\n", + " target_image = render_viewpoint(target_position, target_orientation)\n", + "\n", + " masked_target_image, predicted_image = process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches)\n", + "\n", + " fig, axes = plt.subplots(1,4, squeeze=True, dpi=300)\n", + " axes[0].imshow(ref_image)\n", + " axes[0].set_xlabel(\"Reference\")\n", + " axes[1].imshow(masked_target_image)\n", + " axes[1].set_xlabel(\"Masked target\")\n", + " axes[2].imshow(predicted_image)\n", + " axes[2].set_xlabel(\"Reconstruction\") \n", + " axes[3].imshow(target_image)\n", + " axes[3].set_xlabel(\"Target\")\n", + " for ax in axes.flatten():\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + "\n", + "interact(show_demo,\n", + " masking_ratio=widgets.FloatSlider(description='masking', value=0.9, min=0.0, max=1.0),\n", + " x=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", + " y=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", + " z=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", + " panorama=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),\n", + " elevation=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5));" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "vscode": { + "interpreter": { + "hash": "f9237820cd248d7e07cb4fb9f0e4508a85d642f19d831560c0a4b61f3e907e67" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/stream3r/croco/models/blocks.py b/stream3r/croco/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..1f652bdb4da109eae3cb926505a3626b4895817f --- /dev/null +++ b/stream3r/croco/models/blocks.py @@ -0,0 +1,483 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + + +# -------------------------------------------------------- +# Main encoder/decoder blocks +# -------------------------------------------------------- +# References: +# timm +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py + + +import collections.abc +from itertools import repeat +import math + +import torch +import torch.nn as nn +from torch.nn.functional import scaled_dot_product_attention +from torch.nn.attention import SDPBackend + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + bias=True, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, + attn_mask=None, is_causal=False, attn_implementation="pytorch_naive", + attn_bias_for_inference_enabled=False, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + # use attention biasing to accommodate for longer sequences than during training + self.attn_bias_for_inference_enabled = attn_bias_for_inference_enabled + gamma = 1.0 + train_seqlen = 20 + inference_seqlen = 137 + self.attn_bias_scale = head_dim**-0.5 * (gamma * math.log(inference_seqlen) / math.log(train_seqlen))**0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.dropout_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + self.attn_mask = attn_mask + self.is_causal = is_causal + self.attn_implementation = attn_implementation + + def forward(self, x, xpos): + B, N, C = x.shape + + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .transpose(1, 3) + ) + q, k, v = [qkv[:, :, i] for i in range(3)] + # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple) + + if self.rope is not None: + with torch.autocast(device_type=next(self.parameters()).device.type, dtype=torch.float32): # FIXME: for some reason Lightning didn't pick up torch.cuda.amp.custom_fwd when using bf16-true + q = self.rope(q, xpos) if xpos is not None else q + k = self.rope(k, xpos) if xpos is not None else k + + if not self.training and self.attn_bias_for_inference_enabled: + scale = self.attn_bias_scale + else: + scale = self.scale + + # Important: For the fusion Transformer, we forward through the attention with bfloat16 precision + # If you are not using this block for the fusion Transformer, you should double check the precision of the input and output + if self.attn_implementation == "pytorch_naive": + assert self.attn_mask is None, "attn_mask not supported for pytorch_naive implementation of scaled dot product attention" + assert self.is_causal is False, "is_causal not supported for pytorch_naive implementation of scaled dot product attention" + dtype = k.dtype + with torch.autocast("cuda", dtype=torch.bfloat16): + x = (q @ k.transpose(-2, -1)) * scale + x = x.softmax(dim=-1) + x = self.attn_drop(x) + if dtype == torch.float32: # if input was FP32, cast back to FP32 + x = x.to(torch.float32) + x = (x @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + elif self.attn_implementation == "flash_attention": + with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + dtype = k.dtype + with torch.autocast("cuda", dtype=torch.bfloat16): + x = scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout_p, is_causal=self.is_causal, scale=scale) + if dtype == torch.float32: # if input was FP32, cast back to FP32 + x = x.to(torch.float32) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + elif self.attn_implementation == "pytorch_auto": + with torch.nn.attention.sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION,]): + dtype = k.dtype + with torch.autocast("cuda", dtype=torch.bfloat16): + x = scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout_p, is_causal=self.is_causal, scale=scale) + if dtype == torch.float32: # if input was FP32, cast back to FP32 + x = x.to(torch.float32) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + else: + raise ValueError(f"Unknown attn_implementation: {self.attn_implementation}") + + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + rope=None, + attn_implementation="pytorch_naive", + attn_bias_for_inference_enabled=False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + rope=rope, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + attn_implementation=attn_implementation, + attn_bias_for_inference_enabled=attn_bias_for_inference_enabled, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x, xpos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, attn_mask=None, is_causal=False, attn_implementation="pytorch_naive" + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.projq = nn.Linear(dim, dim, bias=qkv_bias) + self.projk = nn.Linear(dim, dim, bias=qkv_bias) + self.projv = nn.Linear(dim, dim, bias=qkv_bias) + self.dropout_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = rope + + self.attn_mask = attn_mask + self.is_causal = is_causal + self.attn_implementation = attn_implementation + + def forward(self, query, key, value, qpos, kpos): + B, Nq, C = query.shape + Nk = key.shape[1] + Nv = value.shape[1] + + q = ( + self.projq(query) + .reshape(B, Nq, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = ( + self.projk(key) + .reshape(B, Nk, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + v = ( + self.projv(value) + .reshape(B, Nv, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.rope is not None: + with torch.autocast(device_type=next(self.parameters()).device.type, dtype=torch.float32): # FIXME: for some reason Lightning didn't pick up torch.cuda.amp.custom_fwd when using bf16-true + q = self.rope(q, qpos) if qpos is not None else q + k = self.rope(k, kpos) if kpos is not None else k + + if self.attn_implementation == "pytorch_naive": + assert self.attn_mask is None, "attn_mask not supported for pytorch_naive implementation of scaled dot product attention" + assert self.is_causal is False, "is_causal not supported for pytorch_naive implementation of scaled dot product attention" + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Nq, C) + x = self.proj(x) + x = self.proj_drop(x) + elif self.attn_implementation == "flash_attention": + with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + # cast to BF16 to use flash_attention + q = q.to(torch.bfloat16) + k = k.to(torch.bfloat16) + v = v.to(torch.bfloat16) + x = scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout_p, is_causal=self.is_causal, scale=self.scale) + # cast back to FP32 + x = x.to(torch.float32) + x = x.transpose(1, 2).reshape(B, Nq, C) + x = self.proj(x) + x = self.proj_drop(x) + else: + raise ValueError(f"Unknown attn_implementation: {self.attn_implementation}") + + return x + + +class TrackingAttentionBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + norm_mem=True, + rope=None, + attn_implementation="pytorch_naive", + ): + super().__init__() + self.cross_attn = CrossAttention( + dim, + rope=rope, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + attn_implementation=attn_implementation, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.norm3 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + self.norm_y = norm_layer(dim) if norm_mem else nn.Identity() + + def forward(self, x, y, xpos, ypos): + y_ = self.norm_y(y) + x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos)) + x = x + self.drop_path(self.mlp(self.norm3(x))) + return x + + +class DecoderBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + norm_mem=True, + rope=None, + attn_implementation="pytorch_naive", + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + rope=rope, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + attn_implementation=attn_implementation, + ) + self.cross_attn = CrossAttention( + dim, + rope=rope, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + attn_implementation=attn_implementation, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.norm3 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + self.norm_y = norm_layer(dim) if norm_mem else nn.Identity() + + def forward(self, x, y, xpos, ypos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + y_ = self.norm_y(y) + x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos)) + x = x + self.drop_path(self.mlp(self.norm3(x))) + return x, y + + +# patch embedding +class PositionGetter(object): + """return positions of patches""" + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, h, w, device): + if not (h, w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h, w, 2) + pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone() + return pos + + +class PatchEmbed(nn.Module): + """just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + self.position_getter = PositionGetter() + + def forward(self, x): + B, C, H, W = x.shape + torch._assert( + H == self.img_size[0], + f"Input image height ({H}) doesn't match model ({self.img_size[0]}).", + ) + torch._assert( + W == self.img_size[1], + f"Input image width ({W}) doesn't match model ({self.img_size[1]}).", + ) + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2).contiguous() # BCHW -> BNC + x = self.norm(x) + return x, pos + + def _init_weights(self): + w = self.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) diff --git a/stream3r/croco/models/criterion.py b/stream3r/croco/models/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..f9233ddc1233372e90d637a3977b70dc8f476bdd --- /dev/null +++ b/stream3r/croco/models/criterion.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Criterion to train CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# -------------------------------------------------------- + +import torch + + +class MaskedMSE(torch.nn.Module): + def __init__(self, norm_pix_loss=False, masked=True): + """ + norm_pix_loss: normalize each patch by their pixel mean and variance + masked: compute loss over the masked patches only + """ + super().__init__() + self.norm_pix_loss = norm_pix_loss + self.masked = masked + + def forward(self, pred, mask, target): + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + if self.masked: + loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches + else: + loss = loss.mean() # mean loss + return loss diff --git a/stream3r/croco/models/croco.py b/stream3r/croco/models/croco.py new file mode 100644 index 0000000000000000000000000000000000000000..989362928930aa5b04b53510eb1d4aedce352812 --- /dev/null +++ b/stream3r/croco/models/croco.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + + +# -------------------------------------------------------- +# CroCo model during pretraining +# -------------------------------------------------------- + + +import torch +import torch.nn as nn + +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 +from functools import partial + +from stream3r.croco.models.blocks import Block, DecoderBlock, PatchEmbed +from stream3r.croco.models.masking import RandomMask +from stream3r.croco.models.pos_embed import RoPE2D, get_2d_sincos_pos_embed + + +class CroCoNet(nn.Module): + def __init__( + self, + img_size=224, # input image size + patch_size=16, # patch_size + mask_ratio=0.9, # ratios of masked tokens + enc_embed_dim=768, # encoder feature dimension + enc_depth=12, # encoder depth + enc_num_heads=12, # encoder number of heads in the transformer block + dec_embed_dim=512, # decoder feature dimension + dec_depth=8, # decoder depth + dec_num_heads=16, # decoder number of heads in the transformer block + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder + pos_embed="cosine", # positional embedding (either cosine or RoPE100) + attn_implementation="pytorch_naive", # implementation of the scaled_dot_product_attention (either "pytorch_naive" or "flash_attention") + ): + super(CroCoNet, self).__init__() + + # patch embeddings (with initialization done as in MAE) + self._set_patch_embed(img_size, patch_size, enc_embed_dim) + + # mask generations + self._set_mask_generator(self.patch_embed.num_patches, mask_ratio) + + self.pos_embed = pos_embed + if pos_embed == "cosine": + # positional embedding of the encoder + enc_pos_embed = get_2d_sincos_pos_embed( + enc_embed_dim, int(self.patch_embed.num_patches**0.5), n_cls_token=0 + ) + self.register_buffer( + "enc_pos_embed", torch.from_numpy(enc_pos_embed).float() + ) + # positional embedding of the decoder + dec_pos_embed = get_2d_sincos_pos_embed( + dec_embed_dim, int(self.patch_embed.num_patches**0.5), n_cls_token=0 + ) + self.register_buffer( + "dec_pos_embed", torch.from_numpy(dec_pos_embed).float() + ) + # pos embedding in each block + self.rope = None # nothing for cosine + elif pos_embed.startswith("RoPE"): # eg RoPE100 + self.enc_pos_embed = None # nothing to add in the encoder with RoPE + self.dec_pos_embed = None # nothing to add in the decoder with RoPE + if RoPE2D is None: + raise ImportError( + "Cannot find cuRoPE2D, please install it following the README instructions" + ) + freq = float(pos_embed[len("RoPE") :]) + self.rope = RoPE2D(freq=freq) + else: + raise NotImplementedError("Unknown pos_embed " + pos_embed) + + self.attn_implementation = attn_implementation + # transformer for the encoder + self.enc_depth = enc_depth + self.enc_embed_dim = enc_embed_dim + self.enc_blocks = nn.ModuleList( + [ + Block( + enc_embed_dim, + enc_num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + rope=self.rope, + attn_implementation=attn_implementation + ) + for i in range(enc_depth) + ] + ) + self.enc_norm = norm_layer(enc_embed_dim) + + # masked tokens + self._set_mask_token(dec_embed_dim) + + # decoder + self._set_decoder( + enc_embed_dim, + dec_embed_dim, + dec_num_heads, + dec_depth, + mlp_ratio, + norm_layer, + norm_im2_in_dec, + ) + + # prediction head + self._set_prediction_head(dec_embed_dim, patch_size) + + # initializer weights + self.initialize_weights() + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim) + + def _set_mask_generator(self, num_patches, mask_ratio): + self.mask_generator = RandomMask(num_patches, mask_ratio) + + def _set_mask_token(self, dec_embed_dim): + self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim)) + + def _set_decoder( + self, + enc_embed_dim, + dec_embed_dim, + dec_num_heads, + dec_depth, + mlp_ratio, + norm_layer, + norm_im2_in_dec, + ): + self.dec_depth = dec_depth + self.dec_embed_dim = dec_embed_dim + # transfer from encoder to decoder + self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True) + # transformer for the decoder + self.dec_blocks = nn.ModuleList( + [ + DecoderBlock( + dec_embed_dim, + dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + norm_mem=norm_im2_in_dec, + rope=self.rope, + attn_implementation=self.attn_implementation + ) + for i in range(dec_depth) + ] + ) + # final norm layer + self.dec_norm = norm_layer(dec_embed_dim) + + def _set_prediction_head(self, dec_embed_dim, patch_size): + self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True) + + def initialize_weights(self): + # patch embed + self.patch_embed._init_weights() + # mask tokens + if self.mask_token is not None: + torch.nn.init.normal_(self.mask_token, std=0.02) + # linears and layer norms + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _encode_image(self, image, do_mask=False, return_all_blocks=False): + """ + image has B x 3 x img_size x img_size + do_mask: whether to perform masking or not + return_all_blocks: if True, return the features at the end of every block + instead of just the features from the last block (eg for some prediction heads) + """ + # embed the image into patches (x has size B x Npatches x C) + # and get position if each return patch (pos has size B x Npatches x 2) + x, pos = self.patch_embed(image) + # add positional embedding without cls token + if self.enc_pos_embed is not None: + x = x + self.enc_pos_embed[None, ...] + # apply masking + B, N, C = x.size() + if do_mask: + masks = self.mask_generator(x) + x = x[~masks].view(B, -1, C) + posvis = pos[~masks].view(B, -1, 2) + else: + B, N, C = x.size() + masks = torch.zeros((B, N), dtype=bool) + posvis = pos + # now apply the transformer encoder and normalization + if return_all_blocks: + out = [] + for blk in self.enc_blocks: + x = blk(x, posvis) + out.append(x) + out[-1] = self.enc_norm(out[-1]) + return out, pos, masks + else: + for blk in self.enc_blocks: + x = blk(x, posvis) + x = self.enc_norm(x) + return x, pos, masks + + def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False): + """ + return_all_blocks: if True, return the features at the end of every block + instead of just the features from the last block (eg for some prediction heads) + + masks1 can be None => assume image1 fully visible + """ + # encoder to decoder layer + visf1 = self.decoder_embed(feat1) + f2 = self.decoder_embed(feat2) + # append masked tokens to the sequence + B, Nenc, C = visf1.size() + if masks1 is None: # downstreams + f1_ = visf1 + else: # pretraining + Ntotal = masks1.size(1) + f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype) + f1_[~masks1] = visf1.view(B * Nenc, C) + # add positional embedding + if self.dec_pos_embed is not None: + f1_ = f1_ + self.dec_pos_embed + f2 = f2 + self.dec_pos_embed + # apply Transformer blocks + out = f1_ + out2 = f2 + if return_all_blocks: + _out, out = out, [] + for blk in self.dec_blocks: + _out, out2 = blk(_out, out2, pos1, pos2) + out.append(_out) + out[-1] = self.dec_norm(out[-1]) + else: + for blk in self.dec_blocks: + out, out2 = blk(out, out2, pos1, pos2) + out = self.dec_norm(out) + return out + + def patchify(self, imgs): + """ + imgs: (B, 3, H, W) + x: (B, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + + return x + + def unpatchify(self, x, channels=3): + """ + x: (N, L, patch_size**2 *channels) + imgs: (N, 3, H, W) + """ + patch_size = self.patch_embed.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size)) + return imgs + + def forward(self, img1, img2): + """ + img1: tensor of size B x 3 x img_size x img_size + img2: tensor of size B x 3 x img_size x img_size + + out will be B x N x (3*patch_size*patch_size) + masks are also returned as B x N just in case + """ + # encoder of the masked first image + feat1, pos1, mask1 = self._encode_image(img1, do_mask=True) + # encoder of the second image + feat2, pos2, _ = self._encode_image(img2, do_mask=False) + # decoder + decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2) + # prediction head + out = self.prediction_head(decfeat) + # get target + target = self.patchify(img1) + return out, mask1, target diff --git a/stream3r/croco/models/croco_downstream.py b/stream3r/croco/models/croco_downstream.py new file mode 100644 index 0000000000000000000000000000000000000000..0220607b0fa378be18fcb20c1a46c0315365befc --- /dev/null +++ b/stream3r/croco/models/croco_downstream.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# CroCo model for downstream tasks +# -------------------------------------------------------- + +import torch + +from .croco import CroCoNet + + +def croco_args_from_ckpt(ckpt): + if "croco_kwargs" in ckpt: # CroCo v2 released models + return ckpt["croco_kwargs"] + elif "args" in ckpt and hasattr( + ckpt["args"], "model" + ): # pretrained using the official code release + s = ckpt[ + "args" + ].model # eg "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)" + assert s.startswith("CroCoNet(") + return eval( + "dict" + s[len("CroCoNet") :] + ) # transform it into the string of a dictionary and evaluate it + else: # CroCo v1 released models + return dict() + + +class CroCoDownstreamMonocularEncoder(CroCoNet): + def __init__(self, head, **kwargs): + """Build network for monocular downstream task, only using the encoder. + It takes an extra argument head, that is called with the features + and a dictionary img_info containing 'width' and 'height' keys + The head is setup with the croconet arguments in this init function + NOTE: It works by *calling super().__init__() but with redefined setters + + """ + super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs) + head.setup(self) + self.head = head + + def _set_mask_generator(self, *args, **kwargs): + """No mask generator""" + return + + def _set_mask_token(self, *args, **kwargs): + """No mask token""" + self.mask_token = None + return + + def _set_decoder(self, *args, **kwargs): + """No decoder""" + return + + def _set_prediction_head(self, *args, **kwargs): + """No 'prediction head' for downstream tasks.""" + return + + def forward(self, img): + """ + img if of size batch_size x 3 x h x w + """ + B, C, H, W = img.size() + img_info = {"height": H, "width": W} + need_all_layers = ( + hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks + ) + out, _, _ = self._encode_image( + img, do_mask=False, return_all_blocks=need_all_layers + ) + return self.head(out, img_info) + + +class CroCoDownstreamBinocular(CroCoNet): + def __init__(self, head, **kwargs): + """Build network for binocular downstream task + It takes an extra argument head, that is called with the features + and a dictionary img_info containing 'width' and 'height' keys + The head is setup with the croconet arguments in this init function + """ + super(CroCoDownstreamBinocular, self).__init__(**kwargs) + head.setup(self) + self.head = head + + def _set_mask_generator(self, *args, **kwargs): + """No mask generator""" + return + + def _set_mask_token(self, *args, **kwargs): + """No mask token""" + self.mask_token = None + return + + def _set_prediction_head(self, *args, **kwargs): + """No prediction head for downstream tasks, define your own head""" + return + + def encode_image_pairs(self, img1, img2, return_all_blocks=False): + """run encoder for a pair of images + it is actually ~5% faster to concatenate the images along the batch dimension + than to encode them separately + """ + ## the two commented lines below is the naive version with separate encoding + # out, pos, _ = self._encode_image(img1, do_mask=False, return_all_blocks=return_all_blocks) + # out2, pos2, _ = self._encode_image(img2, do_mask=False, return_all_blocks=False) + ## and now the faster version + out, pos, _ = self._encode_image( + torch.cat((img1, img2), dim=0), + do_mask=False, + return_all_blocks=return_all_blocks, + ) + if return_all_blocks: + out, out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out]))) + out2 = out2[-1] + else: + out, out2 = out.chunk(2, dim=0) + pos, pos2 = pos.chunk(2, dim=0) + return out, out2, pos, pos2 + + def forward(self, img1, img2): + B, C, H, W = img1.size() + img_info = {"height": H, "width": W} + return_all_blocks = ( + hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks + ) + out, out2, pos, pos2 = self.encode_image_pairs( + img1, img2, return_all_blocks=return_all_blocks + ) + if return_all_blocks: + decout = self._decoder( + out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks + ) + decout = out + decout + else: + decout = self._decoder( + out, pos, None, out2, pos2, return_all_blocks=return_all_blocks + ) + return self.head(decout, img_info) diff --git a/stream3r/croco/models/curope/__init__.py b/stream3r/croco/models/curope/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..932cb72f80ad0686f65224dd11e058c763dc6dc1 --- /dev/null +++ b/stream3r/croco/models/curope/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +from .curope2d import cuRoPE2D diff --git a/stream3r/croco/models/curope/curope.cpp b/stream3r/croco/models/curope/curope.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8fc67ca3e01b666c56f96280a12089fa4ec2e2a7 --- /dev/null +++ b/stream3r/croco/models/curope/curope.cpp @@ -0,0 +1,69 @@ +/* + Copyright (C) 2022-present Naver Corporation. All rights reserved. + Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +*/ + +#include + +// forward declaration +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); + +void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) +{ + const int B = tokens.size(0); + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3) / 4; + + auto tok = tokens.accessor(); + auto pos = positions.accessor(); + + for (int b = 0; b < B; b++) { + for (int x = 0; x < 2; x++) { // y and then x (2d) + for (int n = 0; n < N; n++) { + + // grab the token position + const int p = pos[b][n][x]; + + for (int h = 0; h < H; h++) { + for (int d = 0; d < D; d++) { + // grab the two values + float u = tok[b][n][h][d+0+x*2*D]; + float v = tok[b][n][h][d+D+x*2*D]; + + // grab the cos,sin + const float inv_freq = fwd * p / powf(base, d/float(D)); + float c = cosf(inv_freq); + float s = sinf(inv_freq); + + // write the result + tok[b][n][h][d+0+x*2*D] = u*c - v*s; + tok[b][n][h][d+D+x*2*D] = v*c + u*s; + } + } + } + } + } +} + +void rope_2d( torch::Tensor tokens, // B,N,H,D + const torch::Tensor positions, // B,N,2 + const float base, + const float fwd ) +{ + TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); + TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); + TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); + TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); + TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); + TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); + + if (tokens.is_cuda()) + rope_2d_cuda( tokens, positions, base, fwd ); + else + rope_2d_cpu( tokens, positions, base, fwd ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); +} diff --git a/stream3r/croco/models/curope/curope2d.py b/stream3r/croco/models/curope/curope2d.py new file mode 100644 index 0000000000000000000000000000000000000000..5544c26cb360ef22c8042c7e5dcf71e853d21435 --- /dev/null +++ b/stream3r/croco/models/curope/curope2d.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch + +try: + import curope as _kernels # run `python setup.py install` +except ModuleNotFoundError: + from . import curope as _kernels # run `python setup.py build_ext --inplace` + + +class cuRoPE2D_func(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type='cuda') + def forward(ctx, tokens, positions, base, F0=1): + ctx.save_for_backward(positions) + ctx.saved_base = base + ctx.saved_F0 = F0 + # tokens = tokens.clone() # uncomment this if inplace doesn't work + _kernels.rope_2d(tokens, positions, base, F0) + ctx.mark_dirty(tokens) + return tokens + + @staticmethod + @torch.amp.custom_bwd(device_type='cuda') + def backward(ctx, grad_res): + positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 + _kernels.rope_2d(grad_res, positions, base, -F0) + ctx.mark_dirty(grad_res) + return grad_res, None, None, None + + +class cuRoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + + def forward(self, tokens, positions): + cuRoPE2D_func.apply(tokens.transpose(1, 2), positions, self.base, self.F0) + return tokens diff --git a/stream3r/croco/models/curope/kernels.cu b/stream3r/croco/models/curope/kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..01604c01a0e832d63bd77ef8c684145fb5d3d11b --- /dev/null +++ b/stream3r/croco/models/curope/kernels.cu @@ -0,0 +1,108 @@ +/* + Copyright (C) 2022-present Naver Corporation. All rights reserved. + Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +*/ + +#include +#include +#include +#include + +#define CHECK_CUDA(tensor) {\ + TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ + TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } +void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} + + +template < typename scalar_t > +__global__ void rope_2d_cuda_kernel( + //scalar_t* __restrict__ tokens, + torch::PackedTensorAccessor32 tokens, + const int64_t* __restrict__ pos, + const float base, + const float fwd ) + // const int N, const int H, const int D ) +{ + // tokens shape = (B, N, H, D) + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3); + + // each block update a single token, for all heads + // each thread takes care of a single output + extern __shared__ float shared[]; + float* shared_inv_freq = shared + D; + + const int b = blockIdx.x / N; + const int n = blockIdx.x % N; + + const int Q = D / 4; + // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D] + // u_Y v_Y u_X v_X + + // shared memory: first, compute inv_freq + if (threadIdx.x < Q) + shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); + __syncthreads(); + + // start of X or Y part + const int X = threadIdx.x < D/2 ? 0 : 1; + const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X + + // grab the cos,sin appropriate for me + const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; + const float cos = cosf(freq); + const float sin = sinf(freq); + /* + float* shared_cos_sin = shared + D + D/4; + if ((threadIdx.x % (D/2)) < Q) + shared_cos_sin[m+0] = cosf(freq); + else + shared_cos_sin[m+Q] = sinf(freq); + __syncthreads(); + const float cos = shared_cos_sin[m+0]; + const float sin = shared_cos_sin[m+Q]; + */ + + for (int h = 0; h < H; h++) + { + // then, load all the token for this head in shared memory + shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; + __syncthreads(); + + const float u = shared[m]; + const float v = shared[m+Q]; + + // write output + if ((threadIdx.x % (D/2)) < Q) + tokens[b][n][h][threadIdx.x] = u*cos - v*sin; + else + tokens[b][n][h][threadIdx.x] = v*cos + u*sin; + } +} + +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) +{ + const int B = tokens.size(0); // batch size + const int N = tokens.size(1); // sequence length + const int H = tokens.size(2); // number of heads + const int D = tokens.size(3); // dimension per head + + TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); + TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); + TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); + TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); + + // one block for each layer, one thread per local-max + const int THREADS_PER_BLOCK = D; + const int N_BLOCKS = B * N; // each block takes care of H*D values + const int SHARED_MEM = sizeof(float) * (D + D/4); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { + rope_2d_cuda_kernel <<>> ( + //tokens.data_ptr(), + tokens.packed_accessor32(), + pos.data_ptr(), + base, fwd); //, N, H, D ); + })); +} diff --git a/stream3r/croco/models/curope/setup.py b/stream3r/croco/models/curope/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..f6be2c84ec4c57dccbf12af07a55c64094b7428a --- /dev/null +++ b/stream3r/croco/models/curope/setup.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +from setuptools import setup +from torch import cuda +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +# compile for all possible CUDA architectures +all_cuda_archs = cuda.get_gencode_flags().replace("compute=", "arch=").split() +# alternatively, you can list cuda archs that you want, eg: +# all_cuda_archs = [ +# '-gencode', 'arch=compute_70,code=sm_70', +# '-gencode', 'arch=compute_75,code=sm_75', +# '-gencode', 'arch=compute_80,code=sm_80', +# '-gencode', 'arch=compute_86,code=sm_86' +# ] + +setup( + name="curope", + ext_modules=[ + CUDAExtension( + name="curope", + sources=[ + "curope.cpp", + "kernels.cu", + ], + extra_compile_args=dict( + nvcc=["-O3", "--ptxas-options=-v", "--use_fast_math"] + all_cuda_archs, + cxx=["-O3"], + ), + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/stream3r/croco/models/dpt_block.py b/stream3r/croco/models/dpt_block.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0b79e8dac0cbaeb33f056db84f82354567150c --- /dev/null +++ b/stream3r/croco/models/dpt_block.py @@ -0,0 +1,534 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# DPT head for ViTs +# -------------------------------------------------------- +# References: +# https://github.com/isl-org/DPT +# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py + +from typing import Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + scratch.layer_rn = nn.ModuleList( + [ + scratch.layer1_rn, + scratch.layer2_rn, + scratch.layer3_rn, + scratch.layer4_rn, + ] + ) + + return scratch + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + width_ratio=1, + ): + """Init. + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + self.width_ratio = width_ratio + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs, max_chunk_size=100): + """Forward pass. + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + if self.width_ratio != 1: + res = F.interpolate( + res, size=(output.shape[2], output.shape[3]), mode="bilinear" + ) + + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if self.width_ratio != 1: + # and output.shape[3] < self.width_ratio * output.shape[2] + # size=(image.shape[]) + if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio: + shape = 3 * output.shape[3] + else: + shape = int(self.width_ratio * 2 * output.shape[2]) + output = F.interpolate( + output, size=(2 * output.shape[2], shape), mode="bilinear" + ) + else: + # Split input into chunks to avoid memory issues with large batches + + chunks = torch.split(output, max_chunk_size, dim=0) + outputs = [] + + for chunk in chunks: + out_chunk = nn.functional.interpolate( + chunk, + scale_factor=2, + mode="bilinear", + align_corners=self.align_corners, + ) + outputs.append(out_chunk) + + # Concatenate outputs along the batch dimension + output = torch.cat(outputs, dim=0) + + output = self.out_conv(output) + return output + + +def make_fusion_block(features, use_bn, width_ratio=1): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + width_ratio=width_ratio, + ) + + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x + + +class DPTOutputAdapter(nn.Module): + """DPT output adapter. + + :param num_cahnnels: Number of output channels + :param stride_level: tride level compared to the full-sized image. + E.g. 4 for 1/4th the size of the image. + :param patch_size_full: Int or tuple of the patch size over the full image size. + Patch size for smaller inputs will be computed accordingly. + :param hooks: Index of intermediate layers + :param layer_dims: Dimension of intermediate layers + :param feature_dim: Feature dimension + :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression + :param use_bn: If set to True, activates batch norm + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + + def __init__( + self, + num_channels: int = 1, + stride_level: int = 1, + patch_size: Union[int, Tuple[int, int]] = 16, + main_tasks: Iterable[str] = ("rgb",), + hooks: List[int] = [2, 5, 8, 11], + layer_dims: List[int] = [96, 192, 384, 768], + feature_dim: int = 256, + last_dim: int = 32, + use_bn: bool = False, + dim_tokens_enc: Optional[int] = None, + head_type: str = "regression", + output_width_ratio=1, + **kwargs + ): + super().__init__() + self.num_channels = num_channels + self.stride_level = stride_level + self.patch_size = pair(patch_size) + self.main_tasks = main_tasks + self.hooks = hooks + self.layer_dims = layer_dims + self.feature_dim = feature_dim + self.dim_tokens_enc = ( + dim_tokens_enc * len(self.main_tasks) + if dim_tokens_enc is not None + else None + ) + self.head_type = head_type + + # Actual patch height and width, taking into account stride of input + self.P_H = max(1, self.patch_size[0] // stride_level) + self.P_W = max(1, self.patch_size[1] // stride_level) + + self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False) + + self.scratch.refinenet1 = make_fusion_block( + feature_dim, use_bn, output_width_ratio + ) + self.scratch.refinenet2 = make_fusion_block( + feature_dim, use_bn, output_width_ratio + ) + self.scratch.refinenet3 = make_fusion_block( + feature_dim, use_bn, output_width_ratio + ) + self.scratch.refinenet4 = make_fusion_block( + feature_dim, use_bn, output_width_ratio + ) + + if self.head_type == "regression": + # The "DPTDepthModel" head + self.head = nn.Sequential( + nn.Conv2d( + feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1 + ), + # the act_postprocess layers upsample each patch by 8 in total, + # so self.patch_size / 8 calculates how much more we need to upsample + # to get to the full image size (remember that num_patches = image_size / patch_size) + Interpolate(scale_factor=self.patch_size[0] / 8, mode="bilinear", align_corners=True), + nn.Conv2d( + feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1 + ), + nn.ReLU(True), + nn.Conv2d( + last_dim, self.num_channels, kernel_size=1, stride=1, padding=0 + ), + ) + elif self.head_type == "semseg": + # The "DPTSegmentationModel" head + self.head = nn.Sequential( + nn.Conv2d( + feature_dim, feature_dim, kernel_size=3, padding=1, bias=False + ), + nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(feature_dim, self.num_channels, kernel_size=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + ) + else: + raise ValueError('DPT head_type must be "regression" or "semseg".') + + if self.dim_tokens_enc is not None: + self.init(dim_tokens_enc=dim_tokens_enc) + + def init(self, dim_tokens_enc=768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + Should be called when setting up MultiMAE. + + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + # print(dim_tokens_enc) + + # Set up activation postprocessing layers + if isinstance(dim_tokens_enc, int): + dim_tokens_enc = 4 * [dim_tokens_enc] + + self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc] + + self.act_1_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[0], + out_channels=self.layer_dims[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[0], + out_channels=self.layer_dims[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + self.act_2_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[1], + out_channels=self.layer_dims[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[1], + out_channels=self.layer_dims[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + self.act_3_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[2], + out_channels=self.layer_dims[2], + kernel_size=1, + stride=1, + padding=0, + ) + ) + + self.act_4_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[3], + out_channels=self.layer_dims[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=self.layer_dims[3], + out_channels=self.layer_dims[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + self.act_postprocess = nn.ModuleList( + [ + self.act_1_postprocess, + self.act_2_postprocess, + self.act_3_postprocess, + self.act_4_postprocess, + ] + ) + + def adapt_tokens(self, encoder_tokens): + # Adapt tokens + x = [] + x.append(encoder_tokens[:, :]) + x = torch.cat(x, dim=-1) + return x + + def forward(self, encoder_tokens: List[torch.Tensor], image_size): + # input_info: Dict): + assert ( + self.dim_tokens_enc is not None + ), "Need to call init(dim_tokens_enc) function first" + H, W = image_size + + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + # Hook decoder onto 4 layers from specified ViT layers + layers = [encoder_tokens[hook] for hook in self.hooks] + + # Extract only task-relevant tokens and ignore global tokens. + layers = [self.adapt_tokens(l) for l in layers] + + # Reshape tokens to spatial representation + layers = [ + rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers + ] + + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + # Project layers to chosen feature dim + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[3]) + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + # Output head + out = self.head(path_1) + + return out diff --git a/stream3r/croco/models/head_downstream.py b/stream3r/croco/models/head_downstream.py new file mode 100644 index 0000000000000000000000000000000000000000..e584f30fded3bd6af495e627735caf0d56e7deea --- /dev/null +++ b/stream3r/croco/models/head_downstream.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Heads for downstream tasks +# -------------------------------------------------------- + +""" +A head is a module where the __init__ defines only the head hyperparameters. +A method setup(croconet) takes a CroCoNet and set all layers according to the head and croconet attributes. +The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height' +""" + +import torch +import torch.nn as nn + +from .dpt_block import DPTOutputAdapter + + +class PixelwiseTaskWithDPT(nn.Module): + """DPT module for CroCo. + by default, hooks_idx will be equal to: + * for encoder-only: 4 equally spread layers + * for encoder+decoder: last encoder + 3 equally spread layers of the decoder + """ + + def __init__( + self, + *, + hooks_idx=None, + layer_dims=[96, 192, 384, 768], + output_width_ratio=1, + num_channels=1, + postprocess=None, + **kwargs, + ): + super(PixelwiseTaskWithDPT, self).__init__() + self.return_all_blocks = True # backbone needs to return all layers + self.postprocess = postprocess + self.output_width_ratio = output_width_ratio + self.num_channels = num_channels + self.hooks_idx = hooks_idx + self.layer_dims = layer_dims + + def setup(self, croconet): + dpt_args = { + "output_width_ratio": self.output_width_ratio, + "num_channels": self.num_channels, + } + if self.hooks_idx is None: + if hasattr(croconet, "dec_blocks"): # encoder + decoder + step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth] + hooks_idx = [ + croconet.dec_depth + croconet.enc_depth - 1 - i * step + for i in range(3, -1, -1) + ] + else: # encoder only + step = croconet.enc_depth // 4 + hooks_idx = [ + croconet.enc_depth - 1 - i * step for i in range(3, -1, -1) + ] + self.hooks_idx = hooks_idx + print( + f" PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}" + ) + dpt_args["hooks"] = self.hooks_idx + dpt_args["layer_dims"] = self.layer_dims + self.dpt = DPTOutputAdapter(**dpt_args) + dim_tokens = [ + croconet.enc_embed_dim + if hook < croconet.enc_depth + else croconet.dec_embed_dim + for hook in self.hooks_idx + ] + dpt_init_args = {"dim_tokens_enc": dim_tokens} + self.dpt.init(**dpt_init_args) + + def forward(self, x, img_info): + out = self.dpt(x, image_size=(img_info["height"], img_info["width"])) + if self.postprocess: + out = self.postprocess(out) + return out diff --git a/stream3r/croco/models/masking.py b/stream3r/croco/models/masking.py new file mode 100644 index 0000000000000000000000000000000000000000..f1cb14a223f9c6db6fbb163ce9db6acc4fd7961c --- /dev/null +++ b/stream3r/croco/models/masking.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + + +# -------------------------------------------------------- +# Masking utils +# -------------------------------------------------------- + +import torch +import torch.nn as nn + + +class RandomMask(nn.Module): + """ + random masking + """ + + def __init__(self, num_patches, mask_ratio): + super().__init__() + self.num_patches = num_patches + self.num_mask = int(mask_ratio * self.num_patches) + + def __call__(self, x): + noise = torch.rand(x.size(0), self.num_patches, device=x.device) + argsort = torch.argsort(noise, dim=1) + return argsort < self.num_mask diff --git a/stream3r/croco/models/perceiver_block.py b/stream3r/croco/models/perceiver_block.py new file mode 100644 index 0000000000000000000000000000000000000000..d5116b8da37cebf843f9bebddb947d00a45ebf8b --- /dev/null +++ b/stream3r/croco/models/perceiver_block.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn as nn + +from blocks import CrossAttention, Block +from pos_embed import get_1d_sincos_pos_embed_from_grid + +class PerceiverCompressor(nn.Module): + def __init__(self, token_dim, latent_dim, num_latents, num_cross_layers, num_latent_transformer_layers, num_heads=8, dropout=0.0, norm_layer=nn.LayerNorm): + super(PerceiverCompressor, self).__init__() + self.token_dim = token_dim + self.latent_dim = latent_dim + self.num_latents = num_latents + self.num_heads = num_heads + self.num_cross_layers = num_cross_layers + self.num_latent_transformer_layers = num_latent_transformer_layers + + # Learnable latents + self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) + + # Cross-attention and latent transformer layers + self.cross_attention_layers = nn.ModuleList() + for _ in range(num_cross_layers): + cross_attn_layer = nn.ModuleDict({ + 'cross_attn': CrossAttention( + dim=latent_dim, num_heads=num_heads, qkv_bias=True, attn_drop=dropout, proj_drop=dropout + ), + 'latent_transformer': nn.ModuleList([ + Block( + dim=latent_dim, num_heads=num_heads, mlp_ratio=4.0, qkv_bias=True, drop=dropout, attn_drop=dropout, norm_layer=norm_layer + ) for _ in range(num_latent_transformer_layers) + ]), + 'norm1': norm_layer(latent_dim), + 'norm2': norm_layer(latent_dim), + 'norm_x': norm_layer(latent_dim) + }) + self.cross_attention_layers.append(cross_attn_layer) + + def forward(self, x, pos, image_ids): + """ + Args: + x (torch.Tensor): Input tensor of shape [B, P, C] where + B - batch size + P - total number of patches from all images + C - dimension of each visual token + pos (torch.Tensor): Positional tensor of shape [B, P, 2] indicating positions + image_ids (torch.Tensor): Tensor of shape [B, P] specifying which image each patch belongs to + Returns: + torch.Tensor: Compressed latent representation of shape [B, L, D] where + L - number of latents + D - dimension of each latent representation + """ + B, P, C = x.shape + + # Repeat the latents for each batch + latents = self.latents.unsqueeze(0).expand(B, -1, -1) + + # Compute image positional encoding dynamically + num_images = (torch.max(image_ids) + 1).cpu().item() + image_pos_emb = torch.from_numpy( + get_1d_sincos_pos_embed_from_grid(self.token_dim, np.arange(num_images)) + ).float().to(x.device) + + # Add image positional encoding to distinguish image sources + image_pos = image_pos_emb[image_ids] + x += image_pos + + # Alternate between cross-attention and latent transformer layers + for layer in self.cross_attention_layers: + # first, compress the image tokens into latents + latents = layer['cross_attn'](query=layer['norm1'](latents), key=layer['norm_x'](x), value=layer['norm_x'](x), qpos=None, kpos=pos) + # then, self-attend the latents to refine them + for latent_transformer_layer in layer['latent_transformer']: + latents = latent_transformer_layer(x=layer['norm2'](latents), xpos=None) + + return latents + + + + + +# Example usage +B, P, C = 2, 100*256, 768 # Example dimensions (batch size, total patches, token dimension) +L, D = 1000, 768 # Latent dimensions +num_cross_layers = 4 +num_latent_transformer_layers = 2 +num_heads = 8 +dropout = 0.1 + +compressor = PerceiverCompressor(token_dim=C, latent_dim=D, num_latents=L, num_cross_layers=num_cross_layers, num_latent_transformer_layers=num_latent_transformer_layers, num_heads=num_heads, dropout=dropout).cuda() +input_tensor = torch.randn(B, P, C).cuda() +pos_tensor = torch.randn(B, P, 2).cuda() # Example positional tensor +image_ids = torch.tensor([[i] * 256 for i in range(100)] * B).cuda().reshape(B, -1) # Example image IDs for patches +output_tensor = compressor(input_tensor, pos_tensor, image_ids) +print(output_tensor.shape) # Should print torch.Size([1, 1000, 768]) diff --git a/stream3r/croco/models/pos_embed.py b/stream3r/croco/models/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..31b9c33859dd272efc98d805292e7619b26ba067 --- /dev/null +++ b/stream3r/croco/models/pos_embed.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + + +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + + +import numpy as np +import torch + + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if n_cls_token > 0: + pos_embed = np.concatenate( + [np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +# ---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +# ---------------------------------------------------------- + +try: + from stream3r.croco.models.curope import cuRoPE2D + + RoPE2D = cuRoPE2D +except ImportError: + print( + "Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead" + ) + + class RoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, D, 2).float().to(device) / D) + ) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert ( + tokens.size(3) % 2 == 0 + ), "number of dimensions should be a multiple of two" + D = tokens.size(3) // 2 + assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin( + D, int(positions.max()) + 1, tokens.device, tokens.dtype + ) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:, :, 0], cos, sin) + x = self.apply_rope1d(x, positions[:, :, 1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens diff --git a/stream3r/croco/pretrain.py b/stream3r/croco/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..53ddeb03c1cf7bad13e10d28454fb70ff4600884 --- /dev/null +++ b/stream3r/croco/pretrain.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Pre-training CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import argparse +import datetime +import json +import math +import os +import sys +import time +from pathlib import Path +from typing import Iterable + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torchvision.datasets as datasets +import torchvision.transforms as transforms +import utils.misc as misc +from datasets.pairs_dataset import PairsDataset +from models.criterion import MaskedMSE +from models.croco import CroCoNet +from torch.utils.tensorboard import SummaryWriter +from utils.misc import NativeScalerWithGradNormCount as NativeScaler + + +def get_args_parser(): + parser = argparse.ArgumentParser("CroCo pre-training", add_help=False) + # model and criterion + parser.add_argument( + "--model", + default="CroCoNet()", + type=str, + help="string containing the model to build", + ) + parser.add_argument( + "--norm_pix_loss", + default=1, + choices=[0, 1], + help="apply per-patch mean/std normalization before applying the loss", + ) + # dataset + parser.add_argument( + "--dataset", default="habitat_release", type=str, help="training set" + ) + parser.add_argument( + "--transforms", default="crop224+acolor", type=str, help="transforms to apply" + ) # in the paper, we also use some homography and rotation, but find later that they were not useful or even harmful + # training + parser.add_argument("--seed", default=0, type=int, help="Random seed") + parser.add_argument( + "--batch_size", + default=64, + type=int, + help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus", + ) + parser.add_argument( + "--epochs", + default=800, + type=int, + help="Maximum number of epochs for the scheduler", + ) + parser.add_argument( + "--max_epoch", default=400, type=int, help="Stop training at this epoch" + ) + parser.add_argument( + "--accum_iter", + default=1, + type=int, + help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)", + ) + parser.add_argument( + "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)" + ) + parser.add_argument( + "--lr", + type=float, + default=None, + metavar="LR", + help="learning rate (absolute lr)", + ) + parser.add_argument( + "--blr", + type=float, + default=1.5e-4, + metavar="LR", + help="base learning rate: absolute_lr = base_lr * total_batch_size / 256", + ) + parser.add_argument( + "--min_lr", + type=float, + default=0.0, + metavar="LR", + help="lower lr bound for cyclic schedulers that hit 0", + ) + parser.add_argument( + "--warmup_epochs", type=int, default=40, metavar="N", help="epochs to warmup LR" + ) + parser.add_argument( + "--amp", + type=int, + default=1, + choices=[0, 1], + help="Use Automatic Mixed Precision for pretraining", + ) + # others + parser.add_argument("--num_workers", default=8, type=int) + parser.add_argument( + "--world_size", default=1, type=int, help="number of distributed processes" + ) + parser.add_argument("--local_rank", default=-1, type=int) + parser.add_argument( + "--dist_url", default="env://", help="url used to set up distributed training" + ) + parser.add_argument( + "--save_freq", + default=1, + type=int, + help="frequence (number of epochs) to save checkpoint in checkpoint-last.pth", + ) + parser.add_argument( + "--keep_freq", + default=20, + type=int, + help="frequence (number of epochs) to save checkpoint in checkpoint-%d.pth", + ) + parser.add_argument( + "--print_freq", + default=20, + type=int, + help="frequence (number of iterations) to print infos while training", + ) + # paths + parser.add_argument( + "--output_dir", + default="./output/", + type=str, + help="path where to save the output", + ) + parser.add_argument( + "--data_dir", default="./data/", type=str, help="path where data are stored" + ) + return parser + + +def main(args): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + world_size = misc.get_world_size() + + print("output_dir: " + args.output_dir) + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + # auto resume + last_ckpt_fname = os.path.join(args.output_dir, f"checkpoint-last.pth") + args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None + + print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(", ", ",\n")) + + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # fix the seed + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = True + + ## training dataset and loader + print( + "Building dataset for {:s} with transforms {:s}".format( + args.dataset, args.transforms + ) + ) + dataset = PairsDataset(args.dataset, trfs=args.transforms, data_dir=args.data_dir) + if world_size > 1: + sampler_train = torch.utils.data.DistributedSampler( + dataset, num_replicas=world_size, rank=global_rank, shuffle=True + ) + print("Sampler_train = %s" % str(sampler_train)) + else: + sampler_train = torch.utils.data.RandomSampler(dataset) + data_loader_train = torch.utils.data.DataLoader( + dataset, + sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + + ## model + print("Loading model: {:s}".format(args.model)) + model = eval(args.model) + print( + "Loading criterion: MaskedMSE(norm_pix_loss={:s})".format( + str(bool(args.norm_pix_loss)) + ) + ) + criterion = MaskedMSE(norm_pix_loss=bool(args.norm_pix_loss)) + + model.to(device) + model_without_ddp = model + print("Model = %s" % str(model_without_ddp)) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + if args.lr is None: # only base_lr is specified + args.lr = args.blr * eff_batch_size / 256 + print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) + print("actual lr: %.2e" % args.lr) + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True + ) + model_without_ddp = model.module + + param_groups = misc.get_parameter_groups( + model_without_ddp, args.weight_decay + ) # following timm: set wd as 0 for bias and norm layers + optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) + print(optimizer) + loss_scaler = NativeScaler() + + misc.load_model( + args=args, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + ) + + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter(log_dir=args.output_dir) + else: + log_writer = None + + print(f"Start training until {args.max_epoch} epochs") + start_time = time.time() + for epoch in range(args.start_epoch, args.max_epoch): + if world_size > 1: + data_loader_train.sampler.set_epoch(epoch) + + train_stats = train_one_epoch( + model, + criterion, + data_loader_train, + optimizer, + device, + epoch, + loss_scaler, + log_writer=log_writer, + args=args, + ) + + if args.output_dir and epoch % args.save_freq == 0: + misc.save_model( + args=args, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + epoch=epoch, + fname="last", + ) + + if ( + args.output_dir + and (epoch % args.keep_freq == 0 or epoch + 1 == args.max_epoch) + and (epoch > 0 or args.max_epoch == 1) + ): + misc.save_model( + args=args, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + epoch=epoch, + ) + + log_stats = { + **{f"train_{k}": v for k, v in train_stats.items()}, + "epoch": epoch, + } + + if args.output_dir and misc.is_main_process(): + if log_writer is not None: + log_writer.flush() + with open( + os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8" + ) as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("Training time {}".format(total_time_str)) + + +def train_one_epoch( + model: torch.nn.Module, + criterion: torch.nn.Module, + data_loader: Iterable, + optimizer: torch.optim.Optimizer, + device: torch.device, + epoch: int, + loss_scaler, + log_writer=None, + args=None, +): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")) + header = "Epoch: [{}]".format(epoch) + accum_iter = args.accum_iter + + optimizer.zero_grad() + + if log_writer is not None: + print("log_dir: {}".format(log_writer.log_dir)) + + for data_iter_step, (image1, image2) in enumerate( + metric_logger.log_every(data_loader, args.print_freq, header) + ): + # we use a per iteration lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate( + optimizer, data_iter_step / len(data_loader) + epoch, args + ) + + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + with torch.cuda.amp.autocast(enabled=bool(args.amp)): + out, mask, target = model(image1, image2) + loss = criterion(out, mask, target) + + loss_value = loss.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + loss /= accum_iter + loss_scaler( + loss, + optimizer, + parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0, + ) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) + + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(lr=lr) + + loss_value_reduce = misc.all_reduce_mean(loss_value) + if ( + log_writer is not None + and ((data_iter_step + 1) % (accum_iter * args.print_freq)) == 0 + ): + # x-axis is based on epoch_1000x in the tensorboard, calibrating differences curves when batch size changes + epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) + log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x) + log_writer.add_scalar("lr", lr, epoch_1000x) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +if __name__ == "__main__": + args = get_args_parser() + args = args.parse_args() + main(args) diff --git a/stream3r/croco/stereoflow/README.MD b/stream3r/croco/stereoflow/README.MD new file mode 100644 index 0000000000000000000000000000000000000000..ca802b2fe8d1c84c04d717234d3db9674691ed67 --- /dev/null +++ b/stream3r/croco/stereoflow/README.MD @@ -0,0 +1,318 @@ +## CroCo-Stereo and CroCo-Flow + +This README explains how to use CroCo-Stereo and CroCo-Flow as well as how they were trained. +All commands should be launched from the root directory. + +### Simple inference example + +We provide a simple inference exemple for CroCo-Stereo and CroCo-Flow in the Totebook `croco-stereo-flow-demo.ipynb`. +Before running it, please download the trained models with: +``` +bash stereoflow/download_model.sh crocostereo.pth +bash stereoflow/download_model.sh crocoflow.pth +``` + +### Prepare data for training or evaluation + +Put the datasets used for training/evaluation in `./data/stereoflow` (or update the paths at the top of `stereoflow/datasets_stereo.py` and `stereoflow/datasets_flow.py`). +Please find below on the file structure should look for each dataset: +
+FlyingChairs + +``` +./data/stereoflow/FlyingChairs/ +└───chairs_split.txt +└───data/ + └─── ... +``` +
+ +
+MPI-Sintel + +``` +./data/stereoflow/MPI-Sintel/ +└───training/ +│ └───clean/ +│ └───final/ +│ └───flow/ +└───test/ + └───clean/ + └───final/ +``` +
+ +
+SceneFlow (including FlyingThings) + +``` +./data/stereoflow/SceneFlow/ +└───Driving/ +│ └───disparity/ +│ └───frames_cleanpass/ +│ └───frames_finalpass/ +└───FlyingThings/ +│ └───disparity/ +│ └───frames_cleanpass/ +│ └───frames_finalpass/ +│ └───optical_flow/ +└───Monkaa/ + └───disparity/ + └───frames_cleanpass/ + └───frames_finalpass/ +``` +
+ +
+TartanAir + +``` +./data/stereoflow/TartanAir/ +└───abandonedfactory/ +│ └───.../ +└───abandonedfactory_night/ +│ └───.../ +└───.../ +``` +
+ +
+Booster + +``` +./data/stereoflow/booster_gt/ +└───train/ + └───balanced/ + └───Bathroom/ + └───Bedroom/ + └───... +``` +
+ +
+CREStereo + +``` +./data/stereoflow/crenet_stereo_trainset/ +└───stereo_trainset/ + └───crestereo/ + └───hole/ + └───reflective/ + └───shapenet/ + └───tree/ +``` +
+ +
+ETH3D Two-view Low-res + +``` +./data/stereoflow/eth3d_lowres/ +└───test/ +│ └───lakeside_1l/ +│ └───... +└───train/ +│ └───delivery_area_1l/ +│ └───... +└───train_gt/ + └───delivery_area_1l/ + └───... +``` +
+ +
+KITTI 2012 + +``` +./data/stereoflow/kitti-stereo-2012/ +└───testing/ +│ └───colored_0/ +│ └───colored_1/ +└───training/ + └───colored_0/ + └───colored_1/ + └───disp_occ/ + └───flow_occ/ +``` +
+ +
+KITTI 2015 + +``` +./data/stereoflow/kitti-stereo-2015/ +└───testing/ +│ └───image_2/ +│ └───image_3/ +└───training/ + └───image_2/ + └───image_3/ + └───disp_occ_0/ + └───flow_occ/ +``` +
+ +
+Middlebury + +``` +./data/stereoflow/middlebury +└───2005/ +│ └───train/ +│ └───Art/ +│ └───... +└───2006/ +│ └───Aloe/ +│ └───Baby1/ +│ └───... +└───2014/ +│ └───Adirondack-imperfect/ +│ └───Adirondack-perfect/ +│ └───... +└───2021/ +│ └───data/ +│ └───artroom1/ +│ └───artroom2/ +│ └───... +└───MiddEval3_F/ + └───test/ + │ └───Australia/ + │ └───... + └───train/ + └───Adirondack/ + └───... +``` +
+ +
+Spring + +``` +./data/stereoflow/spring/ +└───test/ +│ └───0003/ +│ └───... +└───train/ + └───0001/ + └───... +``` +
+ + +### CroCo-Stereo + +##### Main model + +The main training of CroCo-Stereo was performed on a series of datasets, and it was used as it for Middlebury v3 benchmark. + +``` +# Download the model +bash stereoflow/download_model.sh crocostereo.pth +# Middlebury v3 submission +python stereoflow/test.py --model stereoflow_models/crocostereo.pth --dataset "MdEval3('all_full')" --save submission --tile_overlap 0.9 +# Training command that was used, using checkpoint-last.pth +python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main/ +# or it can be launched on multiple gpus (while maintaining the effective batch size), e.g. on 3 gpus: +torchrun --nproc_per_node 3 stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 2 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main/ +``` + +For evaluation of validation set, we also provide the model trained on the `subtrain` subset of the training sets. + +``` +# Download the model +bash stereoflow/download_model.sh crocostereo_subtrain.pth +# Evaluation on validation sets +python stereoflow/test.py --model stereoflow_models/crocostereo_subtrain.pth --dataset "MdEval3('subval_full')+ETH3DLowRes('subval')+SceneFlow('test_finalpass')+SceneFlow('test_cleanpass')" --save metrics --tile_overlap 0.9 +# Training command that was used (same as above but on subtrain, using checkpoint-best.pth), can also be launched on multiple gpus +python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('subtrain')+50*Md05('subtrain')+50*Md06('subtrain')+50*Md14('subtrain')+50*Md21('subtrain')+50*MdEval3('subtrain_full')+Booster('subtrain_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main_subtrain/ +``` + +##### Other models + +
+ Model for ETH3D + The model used for the submission on ETH3D is trained with the same command but using an unbounded Laplacian loss. + + # Download the model + bash stereoflow/download_model.sh crocostereo_eth3d.pth + # ETH3D submission + python stereoflow/test.py --model stereoflow_models/crocostereo_eth3d.pth --dataset "ETH3DLowRes('all')" --save submission --tile_overlap 0.9 + # Training command that was used + python -u stereoflow/train.py stereo --criterion "LaplacianLoss()" --tile_conf_mode conf_expbeta3 --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main_eth3d/ + +
+ +
+ Main model finetuned on Kitti + + # Download the model + bash stereoflow/download_model.sh crocostereo_finetune_kitti.pth + # Kitti submission + python stereoflow/test.py --model stereoflow_models/crocostereo_finetune_kitti.pth --dataset "Kitti15('test')" --save submission --tile_overlap 0.9 + # Training that was used + python -u stereoflow/train.py stereo --crop 352 1216 --criterion "LaplacianLossBounded2()" --dataset "Kitti12('train')+Kitti15('train')" --lr 3e-5 --batch_size 1 --accum_iter 6 --epochs 20 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocostereo.pth --output_dir xps/crocostereo/finetune_kitti/ --save_every 5 +
+ +
+ Main model finetuned on Spring + + # Download the model + bash stereoflow/download_model.sh crocostereo_finetune_spring.pth + # Spring submission + python stereoflow/test.py --model stereoflow_models/crocostereo_finetune_spring.pth --dataset "Spring('test')" --save submission --tile_overlap 0.9 + # Training command that was used + python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "Spring('train')" --lr 3e-5 --batch_size 6 --epochs 8 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocostereo.pth --output_dir xps/crocostereo/finetune_spring/ +
+ +
+ Smaller models + To train CroCo-Stereo with smaller CroCo pretrained models, simply replace the --pretrained argument. To download the smaller CroCo-Stereo models based on CroCo v2 pretraining with ViT-Base encoder and Small encoder, use bash stereoflow/download_model.sh crocostereo_subtrain_vitb_smalldecoder.pth, and for the model with a ViT-Base encoder and a Base decoder, use bash stereoflow/download_model.sh crocostereo_subtrain_vitb_basedecoder.pth. +
+ + +### CroCo-Flow + +##### Main model + +The main training of CroCo-Flow was performed on the FlyingThings, FlyingChairs, MPI-Sintel and TartanAir datasets. +It was used for our submission to the MPI-Sintel benchmark. + +``` +# Download the model +bash stereoflow/download_model.sh crocoflow.pth +# Evaluation +python stereoflow/test.py --model stereoflow_models/crocoflow.pth --dataset "MPISintel('subval_cleanpass')+MPISintel('subval_finalpass')" --save metrics --tile_overlap 0.9 +# Sintel submission +python stereoflow/test.py --model stereoflow_models/crocoflow.pth --dataset "MPISintel('test_allpass')" --save submission --tile_overlap 0.9 +# Training command that was used, with checkpoint-best.pth +python -u stereoflow/train.py flow --criterion "LaplacianLossBounded()" --dataset "40*MPISintel('subtrain_cleanpass')+40*MPISintel('subtrain_finalpass')+4*FlyingThings('train_allpass')+4*FlyingChairs('train')+TartanAir('train')" --val_dataset "MPISintel('subval_cleanpass')+MPISintel('subval_finalpass')" --lr 2e-5 --batch_size 8 --epochs 240 --img_per_epoch 30000 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocoflow/main/ +``` + +##### Other models + +
+ Main model finetuned on Kitti + + # Download the model + bash stereoflow/download_model.sh crocoflow_finetune_kitti.pth + # Kitti submission + python stereoflow/test.py --model stereoflow_models/crocoflow_finetune_kitti.pth --dataset "Kitti15('test')" --save submission --tile_overlap 0.99 + # Training that was used, with checkpoint-last.pth + python -u stereoflow/train.py flow --crop 352 1216 --criterion "LaplacianLossBounded()" --dataset "Kitti15('train')+Kitti12('train')" --lr 2e-5 --batch_size 1 --accum_iter 8 --epochs 150 --save_every 5 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocoflow.pth --output_dir xps/crocoflow/finetune_kitti/ +
+ +
+ Main model finetuned on Spring + + # Download the model + bash stereoflow/download_model.sh crocoflow_finetune_spring.pth + # Spring submission + python stereoflow/test.py --model stereoflow_models/crocoflow_finetune_spring.pth --dataset "Spring('test')" --save submission --tile_overlap 0.9 + # Training command that was used, with checkpoint-last.pth + python -u stereoflow/train.py flow --criterion "LaplacianLossBounded()" --dataset "Spring('train')" --lr 2e-5 --batch_size 8 --epochs 12 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocoflow.pth --output_dir xps/crocoflow/finetune_spring/ +
+ +
+ Smaller models + To train CroCo-Flow with smaller CroCo pretrained models, simply replace the --pretrained argument. To download the smaller CroCo-Flow models based on CroCo v2 pretraining with ViT-Base encoder and Small encoder, use bash stereoflow/download_model.sh crocoflow_vitb_smalldecoder.pth, and for the model with a ViT-Base encoder and a Base decoder, use bash stereoflow/download_model.sh crocoflow_vitb_basedecoder.pth. +
diff --git a/stream3r/croco/stereoflow/augmentor.py b/stream3r/croco/stereoflow/augmentor.py new file mode 100644 index 0000000000000000000000000000000000000000..761b117d67ff6c07674344530c352e41af52fd7f --- /dev/null +++ b/stream3r/croco/stereoflow/augmentor.py @@ -0,0 +1,394 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Data augmentation for training stereo and flow +# -------------------------------------------------------- + +# References +# https://github.com/autonomousvision/unimatch/blob/master/dataloader/stereo/transforms.py +# https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/transforms.py + + +import random + +import cv2 +import numpy as np +from PIL import Image + +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +import torchvision.transforms.functional as FF +from torchvision.transforms import ColorJitter + + +class StereoAugmentor(object): + def __init__( + self, + crop_size, + scale_prob=0.5, + scale_xonly=True, + lhth=800.0, + lminscale=0.0, + lmaxscale=1.0, + hminscale=-0.2, + hmaxscale=0.4, + scale_interp_nearest=True, + rightjitterprob=0.5, + v_flip_prob=0.5, + color_aug_asym=True, + color_choice_prob=0.5, + ): + self.crop_size = crop_size + self.scale_prob = scale_prob + self.scale_xonly = scale_xonly + self.lhth = lhth + self.lminscale = lminscale + self.lmaxscale = lmaxscale + self.hminscale = hminscale + self.hmaxscale = hmaxscale + self.scale_interp_nearest = scale_interp_nearest + self.rightjitterprob = rightjitterprob + self.v_flip_prob = v_flip_prob + self.color_aug_asym = color_aug_asym + self.color_choice_prob = color_choice_prob + + def _random_scale(self, img1, img2, disp): + ch, cw = self.crop_size + h, w = img1.shape[:2] + if self.scale_prob > 0.0 and np.random.rand() < self.scale_prob: + min_scale, max_scale = ( + (self.lminscale, self.lmaxscale) + if min(h, w) < self.lhth + else (self.hminscale, self.hmaxscale) + ) + scale_x = 2.0 ** np.random.uniform(min_scale, max_scale) + scale_x = np.clip(scale_x, (cw + 8) / float(w), None) + scale_y = 1.0 + if not self.scale_xonly: + scale_y = scale_x + scale_y = np.clip(scale_y, (ch + 8) / float(h), None) + img1 = cv2.resize( + img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + img2 = cv2.resize( + img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + disp = ( + cv2.resize( + disp, + None, + fx=scale_x, + fy=scale_y, + interpolation=cv2.INTER_LINEAR + if not self.scale_interp_nearest + else cv2.INTER_NEAREST, + ) + * scale_x + ) + else: # check if we need to resize to be able to crop + h, w = img1.shape[:2] + clip_scale = (cw + 8) / float(w) + if clip_scale > 1.0: + scale_x = clip_scale + scale_y = scale_x if not self.scale_xonly else 1.0 + img1 = cv2.resize( + img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + img2 = cv2.resize( + img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + disp = ( + cv2.resize( + disp, + None, + fx=scale_x, + fy=scale_y, + interpolation=cv2.INTER_LINEAR + if not self.scale_interp_nearest + else cv2.INTER_NEAREST, + ) + * scale_x + ) + return img1, img2, disp + + def _random_crop(self, img1, img2, disp): + h, w = img1.shape[:2] + ch, cw = self.crop_size + assert ch <= h and cw <= w, (img1.shape, h, w, ch, cw) + offset_x = np.random.randint(w - cw + 1) + offset_y = np.random.randint(h - ch + 1) + img1 = img1[offset_y : offset_y + ch, offset_x : offset_x + cw] + img2 = img2[offset_y : offset_y + ch, offset_x : offset_x + cw] + disp = disp[offset_y : offset_y + ch, offset_x : offset_x + cw] + return img1, img2, disp + + def _random_vflip(self, img1, img2, disp): + # vertical flip + if self.v_flip_prob > 0 and np.random.rand() < self.v_flip_prob: + img1 = np.copy(np.flipud(img1)) + img2 = np.copy(np.flipud(img2)) + disp = np.copy(np.flipud(disp)) + return img1, img2, disp + + def _random_rotate_shift_right(self, img2): + if self.rightjitterprob > 0.0 and np.random.rand() < self.rightjitterprob: + angle, pixel = 0.1, 2 + px = np.random.uniform(-pixel, pixel) + ag = np.random.uniform(-angle, angle) + image_center = ( + np.random.uniform(0, img2.shape[0]), + np.random.uniform(0, img2.shape[1]), + ) + rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0) + img2 = cv2.warpAffine( + img2, rot_mat, img2.shape[1::-1], flags=cv2.INTER_LINEAR + ) + trans_mat = np.float32([[1, 0, 0], [0, 1, px]]) + img2 = cv2.warpAffine( + img2, trans_mat, img2.shape[1::-1], flags=cv2.INTER_LINEAR + ) + return img2 + + def _random_color_contrast(self, img1, img2): + if np.random.random() < 0.5: + contrast_factor = np.random.uniform(0.8, 1.2) + img1 = FF.adjust_contrast(img1, contrast_factor) + if self.color_aug_asym and np.random.random() < 0.5: + contrast_factor = np.random.uniform(0.8, 1.2) + img2 = FF.adjust_contrast(img2, contrast_factor) + return img1, img2 + + def _random_color_gamma(self, img1, img2): + if np.random.random() < 0.5: + gamma = np.random.uniform(0.7, 1.5) + img1 = FF.adjust_gamma(img1, gamma) + if self.color_aug_asym and np.random.random() < 0.5: + gamma = np.random.uniform(0.7, 1.5) + img2 = FF.adjust_gamma(img2, gamma) + return img1, img2 + + def _random_color_brightness(self, img1, img2): + if np.random.random() < 0.5: + brightness = np.random.uniform(0.5, 2.0) + img1 = FF.adjust_brightness(img1, brightness) + if self.color_aug_asym and np.random.random() < 0.5: + brightness = np.random.uniform(0.5, 2.0) + img2 = FF.adjust_brightness(img2, brightness) + return img1, img2 + + def _random_color_hue(self, img1, img2): + if np.random.random() < 0.5: + hue = np.random.uniform(-0.1, 0.1) + img1 = FF.adjust_hue(img1, hue) + if self.color_aug_asym and np.random.random() < 0.5: + hue = np.random.uniform(-0.1, 0.1) + img2 = FF.adjust_hue(img2, hue) + return img1, img2 + + def _random_color_saturation(self, img1, img2): + if np.random.random() < 0.5: + saturation = np.random.uniform(0.8, 1.2) + img1 = FF.adjust_saturation(img1, saturation) + if self.color_aug_asym and np.random.random() < 0.5: + saturation = np.random.uniform(-0.8, 1.2) + img2 = FF.adjust_saturation(img2, saturation) + return img1, img2 + + def _random_color(self, img1, img2): + trfs = [ + self._random_color_contrast, + self._random_color_gamma, + self._random_color_brightness, + self._random_color_hue, + self._random_color_saturation, + ] + img1 = Image.fromarray(img1.astype("uint8")) + img2 = Image.fromarray(img2.astype("uint8")) + if np.random.random() < self.color_choice_prob: + # A single transform + t = random.choice(trfs) + img1, img2 = t(img1, img2) + else: + # Combination of trfs + # Random order + random.shuffle(trfs) + for t in trfs: + img1, img2 = t(img1, img2) + img1 = np.array(img1).astype(np.float32) + img2 = np.array(img2).astype(np.float32) + return img1, img2 + + def __call__(self, img1, img2, disp, dataset_name): + img1, img2, disp = self._random_scale(img1, img2, disp) + img1, img2, disp = self._random_crop(img1, img2, disp) + img1, img2, disp = self._random_vflip(img1, img2, disp) + img2 = self._random_rotate_shift_right(img2) + img1, img2 = self._random_color(img1, img2) + return img1, img2, disp + + +class FlowAugmentor: + def __init__( + self, + crop_size, + min_scale=-0.2, + max_scale=0.5, + spatial_aug_prob=0.8, + stretch_prob=0.8, + max_stretch=0.2, + h_flip_prob=0.5, + v_flip_prob=0.1, + asymmetric_color_aug_prob=0.2, + ): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = spatial_aug_prob + self.stretch_prob = stretch_prob + self.max_stretch = max_stretch + + # flip augmentation params + self.h_flip_prob = h_flip_prob + self.v_flip_prob = v_flip_prob + + # photometric augmentation params + self.photo_aug = ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14 + ) + + self.asymmetric_color_aug_prob = asymmetric_color_aug_prob + + def color_transform(self, img1, img2): + """Photometric augmentation""" + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array( + self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 + ) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def _resize_flow(self, flow, scale_x, scale_y, factor=1.0): + if np.all(np.isfinite(flow)): + flow = cv2.resize( + flow, + None, + fx=scale_x / factor, + fy=scale_y / factor, + interpolation=cv2.INTER_LINEAR, + ) + flow = flow * [scale_x, scale_y] + else: # sparse version + fx, fy = scale_x, scale_y + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = np.isfinite(flow[:, 0]) + + coords0 = coords[valid] + flow0 = flow[valid] + + ht1 = int(round(ht * fy / factor)) + wd1 = int(round(wd * fx / factor)) + + rescale = np.expand_dims(np.array([fx, fy]), axis=0) + coords1 = coords0 * rescale / factor + flow1 = flow0 * rescale + + xx = np.round(coords1[:, 0]).astype(np.int32) + yy = np.round(coords1[:, 1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow = np.inf * np.ones( + [ht1, wd1, 2], dtype=np.float32 + ) # invalid value every where, before we fill it with the correct ones + flow[yy, xx] = flow1 + return flow + + def spatial_transform(self, img1, img2, flow, dname): + if np.random.rand() < self.spatial_aug_prob: + # randomly sample scale + ht, wd = img1.shape[:2] + clip_min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd) + ) + min_scale, max_scale = self.min_scale, self.max_scale + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_x = np.clip(scale_x, clip_min_scale, None) + scale_y = np.clip(scale_y, clip_min_scale, None) + # rescale the images + img1 = cv2.resize( + img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + img2 = cv2.resize( + img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + flow = self._resize_flow( + flow, scale_x, scale_y, factor=2.0 if dname == "Spring" else 1.0 + ) + elif dname == "Spring": + flow = self._resize_flow(flow, 1.0, 1.0, factor=2.0) + + if self.h_flip_prob > 0.0 and np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if self.v_flip_prob > 0.0 and np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + # In case no cropping + if img1.shape[0] - self.crop_size[0] > 0: + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + else: + y0 = 0 + if img1.shape[1] - self.crop_size[1] > 0: + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + else: + x0 = 0 + + img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow, dname): + img1, img2, flow = self.spatial_transform(img1, img2, flow, dname) + img1, img2 = self.color_transform(img1, img2) + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + return img1, img2, flow diff --git a/stream3r/croco/stereoflow/criterion.py b/stream3r/croco/stereoflow/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..708be3eaf47baa5b875f9ea9f7af7485118cf053 --- /dev/null +++ b/stream3r/croco/stereoflow/criterion.py @@ -0,0 +1,352 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Losses, metrics per batch, metrics per dataset +# -------------------------------------------------------- + +import torch +import torch.nn.functional as F +from torch import nn + + +def _get_gtnorm(gt): + if gt.size(1) == 1: # stereo + return gt + # flow + return torch.sqrt(torch.sum(gt**2, dim=1, keepdims=True)) # Bx1xHxW + + +############ losses without confidence + + +class L1Loss(nn.Module): + def __init__(self, max_gtnorm=None): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = False + + def _error(self, gt, predictions): + return torch.abs(gt - predictions) + + def forward(self, predictions, gt, inspect=False): + mask = torch.isfinite(gt) + if self.max_gtnorm is not None: + mask *= _get_gtnorm(gt).expand(-1, gt.size(1), -1, -1) < self.max_gtnorm + if inspect: + return self._error(gt, predictions) + return self._error(gt[mask], predictions[mask]).mean() + + +############## losses with confience +## there are several parametrizations + + +class LaplacianLoss(nn.Module): # used for CroCo-Stereo on ETH3D, d'=exp(d) + def __init__(self, max_gtnorm=None): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = True + + def forward(self, predictions, gt, conf): + mask = torch.isfinite(gt) + mask = mask[:, 0, :, :] + if self.max_gtnorm is not None: + mask *= _get_gtnorm(gt)[:, 0, :, :] < self.max_gtnorm + conf = conf.squeeze(1) + return ( + torch.abs(gt - predictions).sum(dim=1)[mask] / torch.exp(conf[mask]) + + conf[mask] + ).mean() # + torch.log(2) => which is a constant + + +class LaplacianLossBounded( + nn.Module +): # used for CroCo-Flow ; in the equation of the paper, we have a=1/b + def __init__(self, max_gtnorm=10000.0, a=0.25, b=4.0): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = True + self.a, self.b = a, b + + def forward(self, predictions, gt, conf): + mask = torch.isfinite(gt) + mask = mask[:, 0, :, :] + if self.max_gtnorm is not None: + mask *= _get_gtnorm(gt)[:, 0, :, :] < self.max_gtnorm + conf = conf.squeeze(1) + conf = (self.b - self.a) * torch.sigmoid(conf) + self.a + return ( + torch.abs(gt - predictions).sum(dim=1)[mask] / conf[mask] + + torch.log(conf)[mask] + ).mean() # + torch.log(2) => which is a constant + + +class LaplacianLossBounded2( + nn.Module +): # used for CroCo-Stereo (except for ETH3D) ; in the equation of the paper, we have a=b + def __init__(self, max_gtnorm=None, a=3.0, b=3.0): + super().__init__() + self.max_gtnorm = max_gtnorm + self.with_conf = True + self.a, self.b = a, b + + def forward(self, predictions, gt, conf): + mask = torch.isfinite(gt) + mask = mask[:, 0, :, :] + if self.max_gtnorm is not None: + mask *= _get_gtnorm(gt)[:, 0, :, :] < self.max_gtnorm + conf = conf.squeeze(1) + conf = 2 * self.a * (torch.sigmoid(conf / self.b) - 0.5) + return ( + torch.abs(gt - predictions).sum(dim=1)[mask] / torch.exp(conf[mask]) + + conf[mask] + ).mean() # + torch.log(2) => which is a constant + + +############## metrics per batch + + +class StereoMetrics(nn.Module): + def __init__(self, do_quantile=False): + super().__init__() + self.bad_ths = [0.5, 1, 2, 3] + self.do_quantile = do_quantile + + def forward(self, predictions, gt): + B = predictions.size(0) + metrics = {} + gtcopy = gt.clone() + mask = torch.isfinite(gtcopy) + gtcopy[ + ~mask + ] = 999999.0 # we make a copy and put a non-infinite value, such that it does not become nan once multiplied by the mask value 0 + Npx = mask.view(B, -1).sum(dim=1) + L1error = (torch.abs(gtcopy - predictions) * mask).view(B, -1) + L2error = (torch.square(gtcopy - predictions) * mask).view(B, -1) + # avgerr + metrics["avgerr"] = torch.mean(L1error.sum(dim=1) / Npx) + # rmse + metrics["rmse"] = torch.sqrt(L2error.sum(dim=1) / Npx).mean(dim=0) + # err > t for t in [0.5,1,2,3] + for ths in self.bad_ths: + metrics["bad@{:.1f}".format(ths)] = ( + ((L1error > ths) * mask.view(B, -1)).sum(dim=1) / Npx + ).mean(dim=0) * 100 + return metrics + + +class FlowMetrics(nn.Module): + def __init__(self): + super().__init__() + self.bad_ths = [1, 3, 5] + + def forward(self, predictions, gt): + B = predictions.size(0) + metrics = {} + mask = torch.isfinite(gt[:, 0, :, :]) # both x and y would be infinite + Npx = mask.view(B, -1).sum(dim=1) + gtcopy = ( + gt.clone() + ) # to compute L1/L2 error, we need to have non-infinite value, the error computed at this locations will be ignored + gtcopy[:, 0, :, :][~mask] = 999999.0 + gtcopy[:, 1, :, :][~mask] = 999999.0 + L1error = (torch.abs(gtcopy - predictions).sum(dim=1) * mask).view(B, -1) + L2error = ( + torch.sqrt(torch.sum(torch.square(gtcopy - predictions), dim=1)) * mask + ).view(B, -1) + metrics["L1err"] = torch.mean(L1error.sum(dim=1) / Npx) + metrics["EPE"] = torch.mean(L2error.sum(dim=1) / Npx) + for ths in self.bad_ths: + metrics["bad@{:.1f}".format(ths)] = ( + ((L2error > ths) * mask.view(B, -1)).sum(dim=1) / Npx + ).mean(dim=0) * 100 + return metrics + + +############## metrics per dataset +## we update the average and maintain the number of pixels while adding data batch per batch +## at the beggining, call reset() +## after each batch, call add_batch(...) +## at the end: call get_results() + + +class StereoDatasetMetrics(nn.Module): + def __init__(self): + super().__init__() + self.bad_ths = [0.5, 1, 2, 3] + + def reset(self): + self.agg_N = 0 # number of pixels so far + self.agg_L1err = torch.tensor(0.0) # L1 error so far + self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels + self._metrics = None + + def add_batch(self, predictions, gt): + assert predictions.size(1) == 1, predictions.size() + assert gt.size(1) == 1, gt.size() + if ( + gt.size(2) == predictions.size(2) * 2 + and gt.size(3) == predictions.size(3) * 2 + ): # special case for Spring ... + L1err = torch.minimum( + torch.minimum( + torch.minimum( + torch.sum(torch.abs(gt[:, :, 0::2, 0::2] - predictions), dim=1), + torch.sum(torch.abs(gt[:, :, 1::2, 0::2] - predictions), dim=1), + ), + torch.sum(torch.abs(gt[:, :, 0::2, 1::2] - predictions), dim=1), + ), + torch.sum(torch.abs(gt[:, :, 1::2, 1::2] - predictions), dim=1), + ) + valid = torch.isfinite(L1err) + else: + valid = torch.isfinite(gt[:, 0, :, :]) # both x and y would be infinite + L1err = torch.sum(torch.abs(gt - predictions), dim=1) + N = valid.sum() + Nnew = self.agg_N + N + self.agg_L1err = ( + float(self.agg_N) / Nnew * self.agg_L1err + + L1err[valid].mean().cpu() * float(N) / Nnew + ) + self.agg_N = Nnew + for i, th in enumerate(self.bad_ths): + self.agg_Nbad[i] += (L1err[valid] > th).sum().cpu() + + def _compute_metrics(self): + if self._metrics is not None: + return + out = {} + out["L1err"] = self.agg_L1err.item() + for i, th in enumerate(self.bad_ths): + out["bad@{:.1f}".format(th)] = ( + float(self.agg_Nbad[i]) / self.agg_N + ).item() * 100.0 + self._metrics = out + + def get_results(self): + self._compute_metrics() # to avoid recompute them multiple times + return self._metrics + + +class FlowDatasetMetrics(nn.Module): + def __init__(self): + super().__init__() + self.bad_ths = [0.5, 1, 3, 5] + self.speed_ths = [(0, 10), (10, 40), (40, torch.inf)] + + def reset(self): + self.agg_N = 0 # number of pixels so far + self.agg_L1err = torch.tensor(0.0) # L1 error so far + self.agg_L2err = torch.tensor(0.0) # L2 (=EPE) error so far + self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels + self.agg_EPEspeed = [ + torch.tensor(0.0) for _ in self.speed_ths + ] # EPE per speed bin so far + self.agg_Nspeed = [0 for _ in self.speed_ths] # N pixels per speed bin so far + self._metrics = None + self.pairname_results = {} + + def add_batch(self, predictions, gt): + assert predictions.size(1) == 2, predictions.size() + assert gt.size(1) == 2, gt.size() + if ( + gt.size(2) == predictions.size(2) * 2 + and gt.size(3) == predictions.size(3) * 2 + ): # special case for Spring ... + L1err = torch.minimum( + torch.minimum( + torch.minimum( + torch.sum(torch.abs(gt[:, :, 0::2, 0::2] - predictions), dim=1), + torch.sum(torch.abs(gt[:, :, 1::2, 0::2] - predictions), dim=1), + ), + torch.sum(torch.abs(gt[:, :, 0::2, 1::2] - predictions), dim=1), + ), + torch.sum(torch.abs(gt[:, :, 1::2, 1::2] - predictions), dim=1), + ) + L2err = torch.minimum( + torch.minimum( + torch.minimum( + torch.sqrt( + torch.sum( + torch.square(gt[:, :, 0::2, 0::2] - predictions), dim=1 + ) + ), + torch.sqrt( + torch.sum( + torch.square(gt[:, :, 1::2, 0::2] - predictions), dim=1 + ) + ), + ), + torch.sqrt( + torch.sum( + torch.square(gt[:, :, 0::2, 1::2] - predictions), dim=1 + ) + ), + ), + torch.sqrt( + torch.sum(torch.square(gt[:, :, 1::2, 1::2] - predictions), dim=1) + ), + ) + valid = torch.isfinite(L1err) + gtspeed = ( + torch.sqrt(torch.sum(torch.square(gt[:, :, 0::2, 0::2]), dim=1)) + + torch.sqrt(torch.sum(torch.square(gt[:, :, 0::2, 1::2]), dim=1)) + + torch.sqrt(torch.sum(torch.square(gt[:, :, 1::2, 0::2]), dim=1)) + + torch.sqrt(torch.sum(torch.square(gt[:, :, 1::2, 1::2]), dim=1)) + ) / 4.0 # let's just average them + else: + valid = torch.isfinite(gt[:, 0, :, :]) # both x and y would be infinite + L1err = torch.sum(torch.abs(gt - predictions), dim=1) + L2err = torch.sqrt(torch.sum(torch.square(gt - predictions), dim=1)) + gtspeed = torch.sqrt(torch.sum(torch.square(gt), dim=1)) + N = valid.sum() + Nnew = self.agg_N + N + self.agg_L1err = ( + float(self.agg_N) / Nnew * self.agg_L1err + + L1err[valid].mean().cpu() * float(N) / Nnew + ) + self.agg_L2err = ( + float(self.agg_N) / Nnew * self.agg_L2err + + L2err[valid].mean().cpu() * float(N) / Nnew + ) + self.agg_N = Nnew + for i, th in enumerate(self.bad_ths): + self.agg_Nbad[i] += (L2err[valid] > th).sum().cpu() + for i, (th1, th2) in enumerate(self.speed_ths): + vv = (gtspeed[valid] >= th1) * (gtspeed[valid] < th2) + iNspeed = vv.sum() + if iNspeed == 0: + continue + iNnew = self.agg_Nspeed[i] + iNspeed + self.agg_EPEspeed[i] = ( + float(self.agg_Nspeed[i]) / iNnew * self.agg_EPEspeed[i] + + float(iNspeed) / iNnew * L2err[valid][vv].mean().cpu() + ) + self.agg_Nspeed[i] = iNnew + + def _compute_metrics(self): + if self._metrics is not None: + return + out = {} + out["L1err"] = self.agg_L1err.item() + out["EPE"] = self.agg_L2err.item() + for i, th in enumerate(self.bad_ths): + out["bad@{:.1f}".format(th)] = ( + float(self.agg_Nbad[i]) / self.agg_N + ).item() * 100.0 + for i, (th1, th2) in enumerate(self.speed_ths): + out[ + "s{:d}{:s}".format(th1, "-" + str(th2) if th2 < torch.inf else "+") + ] = self.agg_EPEspeed[i].item() + self._metrics = out + + def get_results(self): + self._compute_metrics() # to avoid recompute them multiple times + return self._metrics diff --git a/stream3r/croco/stereoflow/datasets_flow.py b/stream3r/croco/stereoflow/datasets_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..2a79d14faed21f5eb7d1f0e8b9968ad873ed8ced --- /dev/null +++ b/stream3r/croco/stereoflow/datasets_flow.py @@ -0,0 +1,934 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Dataset structure for flow +# -------------------------------------------------------- + +import os +import os.path as osp +import pickle +import struct +from copy import deepcopy + +import h5py +import numpy as np +import torch +from PIL import Image +from torch.utils import data + +from .augmentor import FlowAugmentor +from .datasets_stereo import _read_img, _read_pfm, dataset_to_root, img_to_tensor + +dataset_to_root = deepcopy(dataset_to_root) + +dataset_to_root.update( + **{ + "TartanAir": "./data/stereoflow/TartanAir", + "FlyingChairs": "./data/stereoflow/FlyingChairs/", + "FlyingThings": osp.join(dataset_to_root["SceneFlow"], "FlyingThings") + "/", + "MPISintel": "./data/stereoflow//MPI-Sintel/" + "/", + } +) +cache_dir = "./data/stereoflow/datasets_flow_cache/" + + +def flow_to_tensor(disp): + return torch.from_numpy(disp).float().permute(2, 0, 1) + + +class FlowDataset(data.Dataset): + def __init__(self, split, augmentor=False, crop_size=None, totensor=True): + self.split = split + if not augmentor: + assert crop_size is None + if crop_size is not None: + assert augmentor + self.crop_size = crop_size + self.augmentor_str = augmentor + self.augmentor = FlowAugmentor(crop_size) if augmentor else None + self.totensor = totensor + self.rmul = 1 # keep track of rmul + self.has_constant_resolution = True # whether the dataset has constant resolution or not (=> don't use batch_size>1 at test time) + self._prepare_data() + self._load_or_build_cache() + + def prepare_data(self): + """ + to be defined for each dataset + """ + raise NotImplementedError + + def __len__(self): + return len( + self.pairnames + ) # each pairname is typically of the form (str, int1, int2) + + def __getitem__(self, index): + pairname = self.pairnames[index] + + # get filenames + img1name = self.pairname_to_img1name(pairname) + img2name = self.pairname_to_img2name(pairname) + flowname = ( + self.pairname_to_flowname(pairname) + if self.pairname_to_flowname is not None + else None + ) + + # load images and disparities + img1 = _read_img(img1name) + img2 = _read_img(img2name) + flow = self.load_flow(flowname) if flowname is not None else None + + # apply augmentations + if self.augmentor is not None: + img1, img2, flow = self.augmentor(img1, img2, flow, self.name) + + if self.totensor: + img1 = img_to_tensor(img1) + img2 = img_to_tensor(img2) + if flow is not None: + flow = flow_to_tensor(flow) + else: + flow = torch.tensor( + [] + ) # to allow dataloader batching with default collate_gn + pairname = str( + pairname + ) # transform potential tuple to str to be able to batch it + + return img1, img2, flow, pairname + + def __rmul__(self, v): + self.rmul *= v + self.pairnames = v * self.pairnames + return self + + def __str__(self): + return f"{self.__class__.__name__}_{self.split}" + + def __repr__(self): + s = f"{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})" + if self.rmul == 1: + s += f"\n\tnum pairs: {len(self.pairnames)}" + else: + s += f"\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})" + return s + + def _set_root(self): + self.root = dataset_to_root[self.name] + assert os.path.isdir( + self.root + ), f"could not find root directory for dataset {self.name}: {self.root}" + + def _load_or_build_cache(self): + cache_file = osp.join(cache_dir, self.name + ".pkl") + if osp.isfile(cache_file): + with open(cache_file, "rb") as fid: + self.pairnames = pickle.load(fid)[self.split] + else: + tosave = self._build_cache() + os.makedirs(cache_dir, exist_ok=True) + with open(cache_file, "wb") as fid: + pickle.dump(tosave, fid) + self.pairnames = tosave[self.split] + + +class TartanAirDataset(FlowDataset): + def _prepare_data(self): + self.name = "TartanAir" + self._set_root() + assert self.split in ["train"] + self.pairname_to_img1name = lambda pairname: osp.join( + self.root, pairname[0], "image_left/{:06d}_left.png".format(pairname[1]) + ) + self.pairname_to_img2name = lambda pairname: osp.join( + self.root, pairname[0], "image_left/{:06d}_left.png".format(pairname[2]) + ) + self.pairname_to_flowname = lambda pairname: osp.join( + self.root, + pairname[0], + "flow/{:06d}_{:06d}_flow.npy".format(pairname[1], pairname[2]), + ) + self.pairname_to_str = lambda pairname: os.path.join( + pairname[0][pairname[0].find("/") + 1 :], + "{:06d}_{:06d}".format(pairname[1], pairname[2]), + ) + self.load_flow = _read_numpy_flow + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + pairs = [ + (osp.join(s, s, difficulty, Pxxx), int(a[:6]), int(a[:6]) + 1) + for s in seqs + for difficulty in ["Easy", "Hard"] + for Pxxx in sorted(os.listdir(osp.join(self.root, s, s, difficulty))) + for a in sorted( + os.listdir(osp.join(self.root, s, s, difficulty, Pxxx, "image_left/")) + )[:-1] + ] + assert len(pairs) == 306268, "incorrect parsing of pairs in TartanAir" + tosave = {"train": pairs} + return tosave + + +class FlyingChairsDataset(FlowDataset): + def _prepare_data(self): + self.name = "FlyingChairs" + self._set_root() + assert self.split in ["train", "val"] + self.pairname_to_img1name = lambda pairname: osp.join( + self.root, "data", pairname + "_img1.ppm" + ) + self.pairname_to_img2name = lambda pairname: osp.join( + self.root, "data", pairname + "_img2.ppm" + ) + self.pairname_to_flowname = lambda pairname: osp.join( + self.root, "data", pairname + "_flow.flo" + ) + self.pairname_to_str = lambda pairname: pairname + self.load_flow = _read_flo_file + + def _build_cache(self): + split_file = osp.join(self.root, "chairs_split.txt") + split_list = np.loadtxt(split_file, dtype=np.int32) + trainpairs = ["{:05d}".format(i) for i in np.where(split_list == 1)[0] + 1] + valpairs = ["{:05d}".format(i) for i in np.where(split_list == 2)[0] + 1] + assert ( + len(trainpairs) == 22232 and len(valpairs) == 640 + ), "incorrect parsing of pairs in MPI-Sintel" + tosave = {"train": trainpairs, "val": valpairs} + return tosave + + +class FlyingThingsDataset(FlowDataset): + def _prepare_data(self): + self.name = "FlyingThings" + self._set_root() + assert self.split in [ + f"{set_}_{pass_}pass{camstr}" + for set_ in ["train", "test", "test1024"] + for camstr in ["", "_rightcam"] + for pass_ in ["clean", "final", "all"] + ] + self.pairname_to_img1name = lambda pairname: osp.join( + self.root, + f"frames_{pairname[3]}pass", + pairname[0].replace("into_future", "").replace("into_past", ""), + "{:04d}.png".format(pairname[1]), + ) + self.pairname_to_img2name = lambda pairname: osp.join( + self.root, + f"frames_{pairname[3]}pass", + pairname[0].replace("into_future", "").replace("into_past", ""), + "{:04d}.png".format(pairname[2]), + ) + self.pairname_to_flowname = lambda pairname: osp.join( + self.root, + "optical_flow", + pairname[0], + "OpticalFlowInto{f:s}_{i:04d}_{c:s}.pfm".format( + f="Future" if "future" in pairname[0] else "Past", + i=pairname[1], + c="L" if "left" in pairname[0] else "R", + ), + ) + self.pairname_to_str = lambda pairname: os.path.join( + pairname[3] + "pass", + pairname[0], + "Into{f:s}_{i:04d}_{c:s}".format( + f="Future" if "future" in pairname[0] else "Past", + i=pairname[1], + c="L" if "left" in pairname[0] else "R", + ), + ) + self.load_flow = _read_pfm_flow + + def _build_cache(self): + tosave = {} + # train and test splits for the different passes + for set_ in ["train", "test"]: + sroot = osp.join(self.root, "optical_flow", set_.upper()) + fname_to_i = lambda f: int( + f[len("OpticalFlowIntoFuture_") : -len("_L.pfm")] + ) + pp = [ + (osp.join(set_.upper(), d, s, "into_future/left"), fname_to_i(fname)) + for d in sorted(os.listdir(sroot)) + for s in sorted(os.listdir(osp.join(sroot, d))) + for fname in sorted( + os.listdir(osp.join(sroot, d, s, "into_future/left")) + )[:-1] + ] + pairs = [(a, i, i + 1) for a, i in pp] + pairs += [(a.replace("into_future", "into_past"), i + 1, i) for a, i in pp] + assert ( + len(pairs) == {"train": 40302, "test": 7866}[set_] + ), "incorrect parsing of pairs Flying Things" + for cam in ["left", "right"]: + camstr = "" if cam == "left" else f"_{cam}cam" + for pass_ in ["final", "clean"]: + tosave[f"{set_}_{pass_}pass{camstr}"] = [ + (a.replace("left", cam), i, j, pass_) for a, i, j in pairs + ] + tosave[f"{set_}_allpass{camstr}"] = ( + tosave[f"{set_}_cleanpass{camstr}"] + + tosave[f"{set_}_finalpass{camstr}"] + ) + # test1024: this is the same split as unimatch 'validation' split + # see https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/datasets.py#L229 + test1024_nsamples = 1024 + alltest_nsamples = len(tosave["test_cleanpass"]) # 7866 + stride = alltest_nsamples // test1024_nsamples + remove = alltest_nsamples % test1024_nsamples + for cam in ["left", "right"]: + camstr = "" if cam == "left" else f"_{cam}cam" + for pass_ in ["final", "clean"]: + tosave[f"test1024_{pass_}pass{camstr}"] = sorted( + tosave[f"test_{pass_}pass{camstr}"] + )[:-remove][ + ::stride + ] # warning, it was not sorted before + assert ( + len(tosave["test1024_cleanpass"]) == 1024 + ), "incorrect parsing of pairs in Flying Things" + tosave[f"test1024_allpass{camstr}"] = ( + tosave[f"test1024_cleanpass{camstr}"] + + tosave[f"test1024_finalpass{camstr}"] + ) + return tosave + + +class MPISintelDataset(FlowDataset): + def _prepare_data(self): + self.name = "MPISintel" + self._set_root() + assert self.split in [ + s + "_" + p + for s in ["train", "test", "subval", "subtrain"] + for p in ["cleanpass", "finalpass", "allpass"] + ] + self.pairname_to_img1name = lambda pairname: osp.join( + self.root, pairname[0], "frame_{:04d}.png".format(pairname[1]) + ) + self.pairname_to_img2name = lambda pairname: osp.join( + self.root, pairname[0], "frame_{:04d}.png".format(pairname[1] + 1) + ) + self.pairname_to_flowname = ( + lambda pairname: None + if pairname[0].startswith("test/") + else osp.join( + self.root, + pairname[0].replace("/clean/", "/flow/").replace("/final/", "/flow/"), + "frame_{:04d}.flo".format(pairname[1]), + ) + ) + self.pairname_to_str = lambda pairname: osp.join( + pairname[0], "frame_{:04d}".format(pairname[1]) + ) + self.load_flow = _read_flo_file + + def _build_cache(self): + trainseqs = sorted(os.listdir(self.root + "training/clean")) + trainpairs = [ + (osp.join("training/clean", s), i) + for s in trainseqs + for i in range(1, len(os.listdir(self.root + "training/clean/" + s))) + ] + subvalseqs = ["temple_2", "temple_3"] + subtrainseqs = [s for s in trainseqs if s not in subvalseqs] + subvalpairs = [(p, i) for p, i in trainpairs if any(s in p for s in subvalseqs)] + subtrainpairs = [ + (p, i) for p, i in trainpairs if any(s in p for s in subtrainseqs) + ] + testseqs = sorted(os.listdir(self.root + "test/clean")) + testpairs = [ + (osp.join("test/clean", s), i) + for s in testseqs + for i in range(1, len(os.listdir(self.root + "test/clean/" + s))) + ] + assert ( + len(trainpairs) == 1041 + and len(testpairs) == 552 + and len(subvalpairs) == 98 + and len(subtrainpairs) == 943 + ), "incorrect parsing of pairs in MPI-Sintel" + tosave = {} + tosave["train_cleanpass"] = trainpairs + tosave["test_cleanpass"] = testpairs + tosave["subval_cleanpass"] = subvalpairs + tosave["subtrain_cleanpass"] = subtrainpairs + for t in ["train", "test", "subval", "subtrain"]: + tosave[t + "_finalpass"] = [ + (p.replace("/clean/", "/final/"), i) + for p, i in tosave[t + "_cleanpass"] + ] + tosave[t + "_allpass"] = tosave[t + "_cleanpass"] + tosave[t + "_finalpass"] + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, _time): + assert prediction.shape[2] == 2 + outfile = os.path.join( + outdir, "submission", self.pairname_to_str(pairname) + ".flo" + ) + os.makedirs(os.path.dirname(outfile), exist_ok=True) + writeFlowFile(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split == "test_allpass" + bundle_exe = "/nfs/data/ffs-3d/datasets/StereoFlow/MPI-Sintel/bundler/linux-x64/bundler" # eg + if os.path.isfile(bundle_exe): + cmd = f'{bundle_exe} "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"' + print(cmd) + os.system(cmd) + print(f'Done. Submission file at: "{outdir}/submission/bundled.lzma"') + else: + print("Could not find bundler executable for submission.") + print("Please download it and run:") + print( + f' "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"' + ) + + +class SpringDataset(FlowDataset): + def _prepare_data(self): + self.name = "Spring" + self._set_root() + assert self.split in ["train", "test", "subtrain", "subval"] + self.pairname_to_img1name = lambda pairname: osp.join( + self.root, + pairname[0], + pairname[1], + "frame_" + pairname[3], + "frame_{:s}_{:04d}.png".format(pairname[3], pairname[4]), + ) + self.pairname_to_img2name = lambda pairname: osp.join( + self.root, + pairname[0], + pairname[1], + "frame_" + pairname[3], + "frame_{:s}_{:04d}.png".format( + pairname[3], pairname[4] + (1 if pairname[2] == "FW" else -1) + ), + ) + self.pairname_to_flowname = ( + lambda pairname: None + if pairname[0] == "test" + else osp.join( + self.root, + pairname[0], + pairname[1], + f"flow_{pairname[2]}_{pairname[3]}", + f"flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5", + ) + ) + self.pairname_to_str = lambda pairname: osp.join( + pairname[0], + pairname[1], + f"flow_{pairname[2]}_{pairname[3]}", + f"flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}", + ) + self.load_flow = _read_hdf5_flow + + def _build_cache(self): + # train + trainseqs = sorted(os.listdir(osp.join(self.root, "train"))) + trainpairs = [] + for leftright in ["left", "right"]: + for fwbw in ["FW", "BW"]: + trainpairs += [ + ( + "train", + s, + fwbw, + leftright, + int(f[len(f"flow_{fwbw}_{leftright}_") : -len(".flo5")]), + ) + for s in trainseqs + for f in sorted( + os.listdir( + osp.join(self.root, "train", s, f"flow_{fwbw}_{leftright}") + ) + ) + ] + # test + testseqs = sorted(os.listdir(osp.join(self.root, "test"))) + testpairs = [] + for leftright in ["left", "right"]: + testpairs += [ + ( + "test", + s, + "FW", + leftright, + int(f[len(f"frame_{leftright}_") : -len(".png")]), + ) + for s in testseqs + for f in sorted( + os.listdir(osp.join(self.root, "test", s, f"frame_{leftright}")) + )[:-1] + ] + testpairs += [ + ( + "test", + s, + "BW", + leftright, + int(f[len(f"frame_{leftright}_") : -len(".png")]) + 1, + ) + for s in testseqs + for f in sorted( + os.listdir(osp.join(self.root, "test", s, f"frame_{leftright}")) + )[:-1] + ] + # subtrain / subval + subtrainpairs = [p for p in trainpairs if p[1] != "0041"] + subvalpairs = [p for p in trainpairs if p[1] == "0041"] + assert ( + len(trainpairs) == 19852 + and len(testpairs) == 3960 + and len(subtrainpairs) == 19472 + and len(subvalpairs) == 380 + ), "incorrect parsing of pairs in Spring" + tosave = { + "train": trainpairs, + "test": testpairs, + "subtrain": subtrainpairs, + "subval": subvalpairs, + } + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim == 3 + assert prediction.shape[2] == 2 + assert prediction.dtype == np.float32 + outfile = osp.join( + outdir, + pairname[0], + pairname[1], + f"flow_{pairname[2]}_{pairname[3]}", + f"flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5", + ) + os.makedirs(os.path.dirname(outfile), exist_ok=True) + writeFlo5File(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split == "test" + exe = "{self.root}/flow_subsampling" + if os.path.isfile(exe): + cmd = f'cd "{outdir}/test"; {exe} .' + print(cmd) + os.system(cmd) + print(f"Done. Submission file at {outdir}/test/flow_submission.hdf5") + else: + print("Could not find flow_subsampling executable for submission.") + print("Please download it and run:") + print(f'cd "{outdir}/test"; .') + + +class Kitti12Dataset(FlowDataset): + def _prepare_data(self): + self.name = "Kitti12" + self._set_root() + assert self.split in ["train", "test"] + self.pairname_to_img1name = lambda pairname: osp.join( + self.root, pairname + "_10.png" + ) + self.pairname_to_img2name = lambda pairname: osp.join( + self.root, pairname + "_11.png" + ) + self.pairname_to_flowname = ( + None + if self.split == "test" + else lambda pairname: osp.join( + self.root, pairname.replace("/colored_0/", "/flow_occ/") + "_10.png" + ) + ) + self.pairname_to_str = lambda pairname: pairname.replace("/colored_0/", "/") + self.load_flow = _read_kitti_flow + + def _build_cache(self): + trainseqs = ["training/colored_0/%06d" % (i) for i in range(194)] + testseqs = ["testing/colored_0/%06d" % (i) for i in range(195)] + assert ( + len(trainseqs) == 194 and len(testseqs) == 195 + ), "incorrect parsing of pairs in Kitti12" + tosave = {"train": trainseqs, "test": testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim == 3 + assert prediction.shape[2] == 2 + outfile = os.path.join(outdir, pairname.split("/")[-1] + "_10.png") + os.makedirs(os.path.dirname(outfile), exist_ok=True) + writeFlowKitti(outfile, prediction) + + def finalize_submission(self, outdir): + assert self.split == "test" + cmd = f'cd {outdir}/; zip -r "kitti12_flow_results.zip" .' + print(cmd) + os.system(cmd) + print(f"Done. Submission file at {outdir}/kitti12_flow_results.zip") + + +class Kitti15Dataset(FlowDataset): + def _prepare_data(self): + self.name = "Kitti15" + self._set_root() + assert self.split in ["train", "subtrain", "subval", "test"] + self.pairname_to_img1name = lambda pairname: osp.join( + self.root, pairname + "_10.png" + ) + self.pairname_to_img2name = lambda pairname: osp.join( + self.root, pairname + "_11.png" + ) + self.pairname_to_flowname = ( + None + if self.split == "test" + else lambda pairname: osp.join( + self.root, pairname.replace("/image_2/", "/flow_occ/") + "_10.png" + ) + ) + self.pairname_to_str = lambda pairname: pairname.replace("/image_2/", "/") + self.load_flow = _read_kitti_flow + + def _build_cache(self): + trainseqs = ["training/image_2/%06d" % (i) for i in range(200)] + subtrainseqs = trainseqs[:-10] + subvalseqs = trainseqs[-10:] + testseqs = ["testing/image_2/%06d" % (i) for i in range(200)] + assert ( + len(trainseqs) == 200 + and len(subtrainseqs) == 190 + and len(subvalseqs) == 10 + and len(testseqs) == 200 + ), "incorrect parsing of pairs in Kitti15" + tosave = { + "train": trainseqs, + "subtrain": subtrainseqs, + "subval": subvalseqs, + "test": testseqs, + } + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim == 3 + assert prediction.shape[2] == 2 + outfile = os.path.join(outdir, "flow", pairname.split("/")[-1] + "_10.png") + os.makedirs(os.path.dirname(outfile), exist_ok=True) + writeFlowKitti(outfile, prediction) + + def finalize_submission(self, outdir): + assert self.split == "test" + cmd = f'cd {outdir}/; zip -r "kitti15_flow_results.zip" flow' + print(cmd) + os.system(cmd) + print(f"Done. Submission file at {outdir}/kitti15_flow_results.zip") + + +import cv2 + + +def _read_numpy_flow(filename): + return np.load(filename) + + +def _read_pfm_flow(filename): + f, _ = _read_pfm(filename) + assert np.all(f[:, :, 2] == 0.0) + return np.ascontiguousarray(f[:, :, :2]) + + +TAG_FLOAT = 202021.25 # tag to check the sanity of the file +TAG_STRING = "PIEH" # string containing the tag +MIN_WIDTH = 1 +MAX_WIDTH = 99999 +MIN_HEIGHT = 1 +MAX_HEIGHT = 99999 + + +def readFlowFile(filename): + """ + readFlowFile() reads a flow file into a 2-band np.array. + if does not exist, an IOError is raised. + if does not finish by '.flo' or the tag, the width, the height or the file's size is illegal, an Expcetion is raised. + ---- PARAMETERS ---- + filename: string containg the name of the file to read a flow + ---- OUTPUTS ---- + a np.array of dimension (height x width x 2) containing the flow of type 'float32' + """ + + # check filename + if not filename.endswith(".flo"): + raise Exception( + "readFlowFile({:s}): filename must finish with '.flo'".format(filename) + ) + + # open the file and read it + with open(filename, "rb") as f: + # check tag + tag = struct.unpack("f", f.read(4))[0] + if tag != TAG_FLOAT: + raise Exception("flow_utils.readFlowFile({:s}): wrong tag".format(filename)) + # read dimension + w, h = struct.unpack("ii", f.read(8)) + if w < MIN_WIDTH or w > MAX_WIDTH: + raise Exception( + "flow_utils.readFlowFile({:s}: illegal width {:d}".format(filename, w) + ) + if h < MIN_HEIGHT or h > MAX_HEIGHT: + raise Exception( + "flow_utils.readFlowFile({:s}: illegal height {:d}".format(filename, h) + ) + flow = np.fromfile(f, "float32") + if not flow.shape == (h * w * 2,): + raise Exception( + "flow_utils.readFlowFile({:s}: illegal size of the file".format( + filename + ) + ) + flow.shape = (h, w, 2) + return flow + + +def writeFlowFile(flow, filename): + """ + writeFlowFile(flow,) write flow to the file . + if does not exist, an IOError is raised. + if does not finish with '.flo' or the flow has not 2 bands, an Exception is raised. + ---- PARAMETERS ---- + flow: np.array of dimension (height x width x 2) containing the flow to write + filename: string containg the name of the file to write a flow + """ + + # check filename + if not filename.endswith(".flo"): + raise Exception( + "flow_utils.writeFlowFile(,{:s}): filename must finish with '.flo'".format( + filename + ) + ) + + if not flow.shape[2:] == (2,): + raise Exception( + "flow_utils.writeFlowFile(,{:s}): must have 2 bands".format( + filename + ) + ) + + # open the file and write it + with open(filename, "wb") as f: + # write TAG + f.write(TAG_STRING.encode("utf-8")) + # write dimension + f.write(struct.pack("ii", flow.shape[1], flow.shape[0])) + # write the flow + + flow.astype(np.float32).tofile(f) + + +_read_flo_file = readFlowFile + + +def _read_kitti_flow(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) + valid = flow[:, :, 2] > 0 + flow = flow[:, :, :2] + flow = (flow - 2**15) / 64.0 + flow[~valid, 0] = np.inf + flow[~valid, 1] = np.inf + return flow + + +_read_hd1k_flow = _read_kitti_flow + + +def writeFlowKitti(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def writeFlo5File(flow, filename): + with h5py.File(filename, "w") as f: + f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5) + + +def _read_hdf5_flow(filename): + flow = np.asarray(h5py.File(filename)["flow"]) + flow[np.isnan(flow)] = np.inf # make invalid values as +inf + return flow.astype(np.float32) + + +# flow visualization +RY = 15 +YG = 6 +GC = 4 +CB = 11 +BM = 13 +MR = 6 +UNKNOWN_THRESH = 1e9 + + +def colorTest(): + """ + flow_utils.colorTest(): display an example of image showing the color encoding scheme + """ + import matplotlib.pylab as plt + + truerange = 1 + h, w = 151, 151 + trange = truerange * 1.04 + s2 = round(h / 2) + x, y = np.meshgrid(range(w), range(h)) + u = x * trange / s2 - trange + v = y * trange / s2 - trange + img = _computeColor( + np.concatenate((u[:, :, np.newaxis], v[:, :, np.newaxis]), 2) + / trange + / np.sqrt(2) + ) + plt.imshow(img) + plt.axis("off") + plt.axhline(round(h / 2), color="k") + plt.axvline(round(w / 2), color="k") + + +def flowToColor(flow, maxflow=None, maxmaxflow=None, saturate=False): + """ + flow_utils.flowToColor(flow): return a color code flow field, normalized based on the maximum l2-norm of the flow + flow_utils.flowToColor(flow,maxflow): return a color code flow field, normalized by maxflow + ---- PARAMETERS ---- + flow: flow to display of shape (height x width x 2) + maxflow (default:None): if given, normalize the flow by its value, otherwise by the flow norm + maxmaxflow (default:None): if given, normalize the flow by the max of its value and the flow norm + ---- OUTPUT ---- + an np.array of shape (height x width x 3) of type uint8 containing a color code of the flow + """ + h, w, n = flow.shape + # check size of flow + assert n == 2, "flow_utils.flowToColor(flow): flow must have 2 bands" + # fix unknown flow + unknown_idx = np.max(np.abs(flow), 2) > UNKNOWN_THRESH + flow[unknown_idx] = 0.0 + # compute max flow if needed + if maxflow is None: + maxflow = flowMaxNorm(flow) + if maxmaxflow is not None: + maxflow = min(maxmaxflow, maxflow) + # normalize flow + eps = np.spacing(1) # minimum positive float value to avoid division by 0 + # compute the flow + img = _computeColor(flow / (maxflow + eps), saturate=saturate) + # put black pixels in unknown location + img[np.tile(unknown_idx[:, :, np.newaxis], [1, 1, 3])] = 0.0 + return img + + +def flowMaxNorm(flow): + """ + flow_utils.flowMaxNorm(flow): return the maximum of the l2-norm of the given flow + ---- PARAMETERS ---- + flow: the flow + + ---- OUTPUT ---- + a float containing the maximum of the l2-norm of the flow + """ + return np.max(np.sqrt(np.sum(np.square(flow), 2))) + + +def _computeColor(flow, saturate=True): + """ + flow_utils._computeColor(flow): compute color codes for the flow field flow + + ---- PARAMETERS ---- + flow: np.array of dimension (height x width x 2) containing the flow to display + ---- OUTPUTS ---- + an np.array of dimension (height x width x 3) containing the color conversion of the flow + """ + # set nan to 0 + nanidx = np.isnan(flow[:, :, 0]) + flow[nanidx] = 0.0 + + # colorwheel + ncols = RY + YG + GC + CB + BM + MR + nchans = 3 + colorwheel = np.zeros((ncols, nchans), "uint8") + col = 0 + # RY + colorwheel[:RY, 0] = 255 + colorwheel[:RY, 1] = [(255 * i) // RY for i in range(RY)] + col += RY + # YG + colorwheel[col : col + YG, 0] = [255 - (255 * i) // YG for i in range(YG)] + colorwheel[col : col + YG, 1] = 255 + col += YG + # GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = [(255 * i) // GC for i in range(GC)] + col += GC + # CB + colorwheel[col : col + CB, 1] = [255 - (255 * i) // CB for i in range(CB)] + colorwheel[col : col + CB, 2] = 255 + col += CB + # BM + colorwheel[col : col + BM, 0] = [(255 * i) // BM for i in range(BM)] + colorwheel[col : col + BM, 2] = 255 + col += BM + # MR + colorwheel[col : col + MR, 0] = 255 + colorwheel[col : col + MR, 2] = [255 - (255 * i) // MR for i in range(MR)] + + # compute utility variables + rad = np.sqrt(np.sum(np.square(flow), 2)) # magnitude + a = np.arctan2(-flow[:, :, 1], -flow[:, :, 0]) / np.pi # angle + fk = (a + 1) / 2 * (ncols - 1) # map [-1,1] to [0,ncols-1] + k0 = np.floor(fk).astype("int") + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + + if not saturate: + rad = np.minimum(rad, 1) + + # compute the image + img = np.zeros((flow.shape[0], flow.shape[1], nchans), "uint8") + for i in range(nchans): + tmp = colorwheel[:, i].astype("float") + col0 = tmp[k0] / 255 + col1 = tmp[k1] / 255 + col = (1 - f) * col0 + f * col1 + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) # increase saturation with radius + col[~idx] *= 0.75 # out of range + img[:, :, i] = (255 * col * (1 - nanidx.astype("float"))).astype("uint8") + + return img + + +# flow dataset getter + + +def get_train_dataset_flow(dataset_str, augmentor=True, crop_size=None): + dataset_str = dataset_str.replace("(", "Dataset(") + if augmentor: + dataset_str = dataset_str.replace(")", ", augmentor=True)") + if crop_size is not None: + dataset_str = dataset_str.replace( + ")", ", crop_size={:s})".format(str(crop_size)) + ) + return eval(dataset_str) + + +def get_test_datasets_flow(dataset_str): + dataset_str = dataset_str.replace("(", "Dataset(") + return [eval(s) for s in dataset_str.split("+")] diff --git a/stream3r/croco/stereoflow/datasets_stereo.py b/stream3r/croco/stereoflow/datasets_stereo.py new file mode 100644 index 0000000000000000000000000000000000000000..29d4b1cec8978aaa87a82105e7bad73bf2be3a76 --- /dev/null +++ b/stream3r/croco/stereoflow/datasets_stereo.py @@ -0,0 +1,983 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Dataset structure for stereo +# -------------------------------------------------------- + +import os +import os.path as osp +import pickle +import sys +from glob import glob + +import cv2 +import h5py +import numpy as np +import torch +from PIL import Image +from torch.utils import data + +from .augmentor import StereoAugmentor + +dataset_to_root = { + "CREStereo": "./data/stereoflow//crenet_stereo_trainset/stereo_trainset/crestereo/", + "SceneFlow": "./data/stereoflow//SceneFlow/", + "ETH3DLowRes": "./data/stereoflow/eth3d_lowres/", + "Booster": "./data/stereoflow/booster_gt/", + "Middlebury2021": "./data/stereoflow/middlebury/2021/data/", + "Middlebury2014": "./data/stereoflow/middlebury/2014/", + "Middlebury2006": "./data/stereoflow/middlebury/2006/", + "Middlebury2005": "./data/stereoflow/middlebury/2005/train/", + "MiddleburyEval3": "./data/stereoflow/middlebury/MiddEval3/", + "Spring": "./data/stereoflow/spring/", + "Kitti15": "./data/stereoflow/kitti-stereo-2015/", + "Kitti12": "./data/stereoflow/kitti-stereo-2012/", +} +cache_dir = "./data/stereoflow/datasets_stereo_cache/" + + +in1k_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) +in1k_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + + +def img_to_tensor(img): + img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 + img = (img - in1k_mean) / in1k_std + return img + + +def disp_to_tensor(disp): + return torch.from_numpy(disp)[None, :, :] + + +class StereoDataset(data.Dataset): + def __init__(self, split, augmentor=False, crop_size=None, totensor=True): + self.split = split + if not augmentor: + assert crop_size is None + if crop_size: + assert augmentor + self.crop_size = crop_size + self.augmentor_str = augmentor + self.augmentor = StereoAugmentor(crop_size) if augmentor else None + self.totensor = totensor + self.rmul = 1 # keep track of rmul + self.has_constant_resolution = True # whether the dataset has constant resolution or not (=> don't use batch_size>1 at test time) + self._prepare_data() + self._load_or_build_cache() + + def prepare_data(self): + """ + to be defined for each dataset + """ + raise NotImplementedError + + def __len__(self): + return len(self.pairnames) + + def __getitem__(self, index): + pairname = self.pairnames[index] + + # get filenames + Limgname = self.pairname_to_Limgname(pairname) + Rimgname = self.pairname_to_Rimgname(pairname) + Ldispname = ( + self.pairname_to_Ldispname(pairname) + if self.pairname_to_Ldispname is not None + else None + ) + + # load images and disparities + Limg = _read_img(Limgname) + Rimg = _read_img(Rimgname) + disp = self.load_disparity(Ldispname) if Ldispname is not None else None + + # sanity check + if disp is not None: + assert np.all(disp > 0) or self.name == "Spring", ( + self.name, + pairname, + Ldispname, + ) + + # apply augmentations + if self.augmentor is not None: + Limg, Rimg, disp = self.augmentor(Limg, Rimg, disp, self.name) + + if self.totensor: + Limg = img_to_tensor(Limg) + Rimg = img_to_tensor(Rimg) + if disp is None: + disp = torch.tensor( + [] + ) # to allow dataloader batching with default collate_gn + else: + disp = disp_to_tensor(disp) + + return Limg, Rimg, disp, str(pairname) + + def __rmul__(self, v): + self.rmul *= v + self.pairnames = v * self.pairnames + return self + + def __str__(self): + return f"{self.__class__.__name__}_{self.split}" + + def __repr__(self): + s = f"{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})" + if self.rmul == 1: + s += f"\n\tnum pairs: {len(self.pairnames)}" + else: + s += f"\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})" + return s + + def _set_root(self): + self.root = dataset_to_root[self.name] + assert os.path.isdir( + self.root + ), f"could not find root directory for dataset {self.name}: {self.root}" + + def _load_or_build_cache(self): + cache_file = osp.join(cache_dir, self.name + ".pkl") + if osp.isfile(cache_file): + with open(cache_file, "rb") as fid: + self.pairnames = pickle.load(fid)[self.split] + else: + tosave = self._build_cache() + os.makedirs(cache_dir, exist_ok=True) + with open(cache_file, "wb") as fid: + pickle.dump(tosave, fid) + self.pairnames = tosave[self.split] + + +class CREStereoDataset(StereoDataset): + def _prepare_data(self): + self.name = "CREStereo" + self._set_root() + assert self.split in ["train"] + self.pairname_to_Limgname = lambda pairname: osp.join( + self.root, pairname + "_left.jpg" + ) + self.pairname_to_Rimgname = lambda pairname: osp.join( + self.root, pairname + "_right.jpg" + ) + self.pairname_to_Ldispname = lambda pairname: osp.join( + self.root, pairname + "_left.disp.png" + ) + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_crestereo_disp + + def _build_cache(self): + allpairs = [ + s + "/" + f[: -len("_left.jpg")] + for s in sorted(os.listdir(self.root)) + for f in sorted(os.listdir(self.root + "/" + s)) + if f.endswith("_left.jpg") + ] + assert len(allpairs) == 200000, "incorrect parsing of pairs in CreStereo" + tosave = {"train": allpairs} + return tosave + + +class SceneFlowDataset(StereoDataset): + def _prepare_data(self): + self.name = "SceneFlow" + self._set_root() + assert self.split in [ + "train_finalpass", + "train_cleanpass", + "train_allpass", + "test_finalpass", + "test_cleanpass", + "test_allpass", + "test1of100_cleanpass", + "test1of100_finalpass", + ] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join( + self.root, pairname + ).replace("/left/", "/right/") + self.pairname_to_Ldispname = ( + lambda pairname: osp.join(self.root, pairname) + .replace("/frames_finalpass/", "/disparity/") + .replace("/frames_cleanpass/", "/disparity/")[:-4] + + ".pfm" + ) + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_sceneflow_disp + + def _build_cache(self): + trainpairs = [] + # driving + pairs = sorted(glob(self.root + "Driving/frames_finalpass/*/*/*/left/*.png")) + pairs = list(map(lambda x: x[len(self.root) :], pairs)) + assert len(pairs) == 4400, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + # monkaa + pairs = sorted(glob(self.root + "Monkaa/frames_finalpass/*/left/*.png")) + pairs = list(map(lambda x: x[len(self.root) :], pairs)) + assert len(pairs) == 8664, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + # flyingthings + pairs = sorted( + glob(self.root + "FlyingThings/frames_finalpass/TRAIN/*/*/left/*.png") + ) + pairs = list(map(lambda x: x[len(self.root) :], pairs)) + assert len(pairs) == 22390, "incorrect parsing of pairs in SceneFlow" + trainpairs += pairs + assert len(trainpairs) == 35454, "incorrect parsing of pairs in SceneFlow" + testpairs = sorted( + glob(self.root + "FlyingThings/frames_finalpass/TEST/*/*/left/*.png") + ) + testpairs = list(map(lambda x: x[len(self.root) :], testpairs)) + assert len(testpairs) == 4370, "incorrect parsing of pairs in SceneFlow" + test1of100pairs = testpairs[::100] + assert len(test1of100pairs) == 44, "incorrect parsing of pairs in SceneFlow" + # all + tosave = { + "train_finalpass": trainpairs, + "train_cleanpass": list( + map( + lambda x: x.replace("frames_finalpass", "frames_cleanpass"), + trainpairs, + ) + ), + "test_finalpass": testpairs, + "test_cleanpass": list( + map( + lambda x: x.replace("frames_finalpass", "frames_cleanpass"), + testpairs, + ) + ), + "test1of100_finalpass": test1of100pairs, + "test1of100_cleanpass": list( + map( + lambda x: x.replace("frames_finalpass", "frames_cleanpass"), + test1of100pairs, + ) + ), + } + tosave["train_allpass"] = tosave["train_finalpass"] + tosave["train_cleanpass"] + tosave["test_allpass"] = tosave["test_finalpass"] + tosave["test_cleanpass"] + return tosave + + +class Md21Dataset(StereoDataset): + def _prepare_data(self): + self.name = "Middlebury2021" + self._set_root() + assert self.split in ["train", "subtrain", "subval"] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join( + self.root, pairname.replace("/im0", "/im1") + ) + self.pairname_to_Ldispname = lambda pairname: osp.join( + self.root, pairname.split("/")[0], "disp0.pfm" + ) + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury_disp + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + # trainpairs += [s+'/im0.png'] # we should remove it, it is included as such in other lightings + trainpairs += [ + s + "/ambient/" + b + "/" + a + for b in sorted(os.listdir(osp.join(self.root, s, "ambient"))) + for a in sorted(os.listdir(osp.join(self.root, s, "ambient", b))) + if a.startswith("im0") + ] + assert len(trainpairs) == 355 + subtrainpairs = [ + p for p in trainpairs if any(p.startswith(s + "/") for s in seqs[:-2]) + ] + subvalpairs = [ + p for p in trainpairs if any(p.startswith(s + "/") for s in seqs[-2:]) + ] + assert ( + len(subtrainpairs) == 335 and len(subvalpairs) == 20 + ), "incorrect parsing of pairs in Middlebury 2021" + tosave = {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs} + return tosave + + +class Md14Dataset(StereoDataset): + def _prepare_data(self): + self.name = "Middlebury2014" + self._set_root() + assert self.split in ["train", "subtrain", "subval"] + self.pairname_to_Limgname = lambda pairname: osp.join( + self.root, osp.dirname(pairname), "im0.png" + ) + self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Ldispname = lambda pairname: osp.join( + self.root, osp.dirname(pairname), "disp0.pfm" + ) + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury_disp + self.has_constant_resolution = False + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + trainpairs += [s + "/im1.png", s + "/im1E.png", s + "/im1L.png"] + assert len(trainpairs) == 138 + valseqs = ["Umbrella-imperfect", "Vintage-perfect"] + assert all(s in seqs for s in valseqs) + subtrainpairs = [ + p for p in trainpairs if not any(p.startswith(s + "/") for s in valseqs) + ] + subvalpairs = [ + p for p in trainpairs if any(p.startswith(s + "/") for s in valseqs) + ] + assert ( + len(subtrainpairs) == 132 and len(subvalpairs) == 6 + ), "incorrect parsing of pairs in Middlebury 2014" + tosave = {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs} + return tosave + + +class Md06Dataset(StereoDataset): + def _prepare_data(self): + self.name = "Middlebury2006" + self._set_root() + assert self.split in ["train", "subtrain", "subval"] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join( + self.root, osp.dirname(pairname), "view5.png" + ) + self.pairname_to_Ldispname = lambda pairname: osp.join( + self.root, pairname.split("/")[0], "disp1.png" + ) + self.load_disparity = _read_middlebury20052006_disp + self.has_constant_resolution = False + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + for i in ["Illum1", "Illum2", "Illum3"]: + for e in ["Exp0", "Exp1", "Exp2"]: + trainpairs.append(osp.join(s, i, e, "view1.png")) + assert len(trainpairs) == 189 + valseqs = ["Rocks1", "Wood2"] + assert all(s in seqs for s in valseqs) + subtrainpairs = [ + p for p in trainpairs if not any(p.startswith(s + "/") for s in valseqs) + ] + subvalpairs = [ + p for p in trainpairs if any(p.startswith(s + "/") for s in valseqs) + ] + assert ( + len(subtrainpairs) == 171 and len(subvalpairs) == 18 + ), "incorrect parsing of pairs in Middlebury 2006" + tosave = {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs} + return tosave + + +class Md05Dataset(StereoDataset): + def _prepare_data(self): + self.name = "Middlebury2005" + self._set_root() + assert self.split in ["train", "subtrain", "subval"] + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join( + self.root, osp.dirname(pairname), "view5.png" + ) + self.pairname_to_Ldispname = lambda pairname: osp.join( + self.root, pairname.split("/")[0], "disp1.png" + ) + self.pairname_to_str = lambda pairname: pairname[:-4] + self.load_disparity = _read_middlebury20052006_disp + + def _build_cache(self): + seqs = sorted(os.listdir(self.root)) + trainpairs = [] + for s in seqs: + for i in ["Illum1", "Illum2", "Illum3"]: + for e in ["Exp0", "Exp1", "Exp2"]: + trainpairs.append(osp.join(s, i, e, "view1.png")) + assert len(trainpairs) == 54, "incorrect parsing of pairs in Middlebury 2005" + valseqs = ["Reindeer"] + assert all(s in seqs for s in valseqs) + subtrainpairs = [ + p for p in trainpairs if not any(p.startswith(s + "/") for s in valseqs) + ] + subvalpairs = [ + p for p in trainpairs if any(p.startswith(s + "/") for s in valseqs) + ] + assert ( + len(subtrainpairs) == 45 and len(subvalpairs) == 9 + ), "incorrect parsing of pairs in Middlebury 2005" + tosave = {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs} + return tosave + + +class MdEval3Dataset(StereoDataset): + def _prepare_data(self): + self.name = "MiddleburyEval3" + self._set_root() + assert self.split in [ + s + "_" + r + for s in ["train", "subtrain", "subval", "test", "all"] + for r in ["full", "half", "quarter"] + ] + if self.split.endswith("_full"): + self.root = self.root.replace("/MiddEval3", "/MiddEval3_F") + elif self.split.endswith("_half"): + self.root = self.root.replace("/MiddEval3", "/MiddEval3_H") + else: + assert self.split.endswith("_quarter") + self.pairname_to_Limgname = lambda pairname: osp.join( + self.root, pairname, "im0.png" + ) + self.pairname_to_Rimgname = lambda pairname: osp.join( + self.root, pairname, "im1.png" + ) + self.pairname_to_Ldispname = ( + lambda pairname: None + if pairname.startswith("test") + else osp.join(self.root, pairname, "disp0GT.pfm") + ) + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_middlebury_disp + # for submission only + self.submission_methodname = "CroCo-Stereo" + self.submission_sresolution = ( + "F" + if self.split.endswith("_full") + else ("H" if self.split.endswith("_half") else "Q") + ) + + def _build_cache(self): + trainpairs = ["train/" + s for s in sorted(os.listdir(self.root + "train/"))] + testpairs = ["test/" + s for s in sorted(os.listdir(self.root + "test/"))] + subvalpairs = trainpairs[-1:] + subtrainpairs = trainpairs[:-1] + allpairs = trainpairs + testpairs + assert ( + len(trainpairs) == 15 + and len(testpairs) == 15 + and len(subvalpairs) == 1 + and len(subtrainpairs) == 14 + and len(allpairs) == 30 + ), "incorrect parsing of pairs in Middlebury Eval v3" + tosave = {} + for r in ["full", "half", "quarter"]: + tosave.update( + **{ + "train_" + r: trainpairs, + "subtrain_" + r: subtrainpairs, + "subval_" + r: subvalpairs, + "test_" + r: testpairs, + "all_" + r: allpairs, + } + ) + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim == 2 + assert prediction.dtype == np.float32 + outfile = os.path.join( + outdir, + pairname.split("/")[0].replace("train", "training") + + self.submission_sresolution, + pairname.split("/")[1], + "disp0" + self.submission_methodname + ".pfm", + ) + os.makedirs(os.path.dirname(outfile), exist_ok=True) + writePFM(outfile, prediction) + timefile = os.path.join( + os.path.dirname(outfile), "time" + self.submission_methodname + ".txt" + ) + with open(timefile, "w") as fid: + fid.write(str(time)) + + def finalize_submission(self, outdir): + cmd = f'cd {outdir}/; zip -r "{self.submission_methodname}.zip" .' + print(cmd) + os.system(cmd) + print(f"Done. Submission file at {outdir}/{self.submission_methodname}.zip") + + +class ETH3DLowResDataset(StereoDataset): + def _prepare_data(self): + self.name = "ETH3DLowRes" + self._set_root() + assert self.split in ["train", "test", "subtrain", "subval", "all"] + self.pairname_to_Limgname = lambda pairname: osp.join( + self.root, pairname, "im0.png" + ) + self.pairname_to_Rimgname = lambda pairname: osp.join( + self.root, pairname, "im1.png" + ) + self.pairname_to_Ldispname = ( + None + if self.split == "test" + else lambda pairname: None + if pairname.startswith("test/") + else osp.join( + self.root, pairname.replace("train/", "train_gt/"), "disp0GT.pfm" + ) + ) + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_eth3d_disp + self.has_constant_resolution = False + + def _build_cache(self): + trainpairs = ["train/" + s for s in sorted(os.listdir(self.root + "train/"))] + testpairs = ["test/" + s for s in sorted(os.listdir(self.root + "test/"))] + assert ( + len(trainpairs) == 27 and len(testpairs) == 20 + ), "incorrect parsing of pairs in ETH3D Low Res" + subvalpairs = [ + "train/delivery_area_3s", + "train/electro_3l", + "train/playground_3l", + ] + assert all(p in trainpairs for p in subvalpairs) + subtrainpairs = [p for p in trainpairs if not p in subvalpairs] + assert ( + len(subvalpairs) == 3 and len(subtrainpairs) == 24 + ), "incorrect parsing of pairs in ETH3D Low Res" + tosave = { + "train": trainpairs, + "test": testpairs, + "subtrain": subtrainpairs, + "subval": subvalpairs, + "all": trainpairs + testpairs, + } + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim == 2 + assert prediction.dtype == np.float32 + outfile = os.path.join( + outdir, "low_res_two_view", pairname.split("/")[1] + ".pfm" + ) + os.makedirs(os.path.dirname(outfile), exist_ok=True) + writePFM(outfile, prediction) + timefile = outfile[:-4] + ".txt" + with open(timefile, "w") as fid: + fid.write("runtime " + str(time)) + + def finalize_submission(self, outdir): + cmd = f'cd {outdir}/; zip -r "eth3d_low_res_two_view_results.zip" low_res_two_view' + print(cmd) + os.system(cmd) + print(f"Done. Submission file at {outdir}/eth3d_low_res_two_view_results.zip") + + +class BoosterDataset(StereoDataset): + def _prepare_data(self): + self.name = "Booster" + self._set_root() + assert self.split in [ + "train_balanced", + "test_balanced", + "subtrain_balanced", + "subval_balanced", + ] # we use only the balanced version + self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname) + self.pairname_to_Rimgname = lambda pairname: osp.join( + self.root, pairname + ).replace("/camera_00/", "/camera_02/") + self.pairname_to_Ldispname = lambda pairname: osp.join( + self.root, osp.dirname(pairname), "../disp_00.npy" + ) # same images with different colors, same gt per sequence + self.pairname_to_str = lambda pairname: pairname[:-4].replace( + "/camera_00/", "/" + ) + self.load_disparity = _read_booster_disp + + def _build_cache(self): + trainseqs = sorted(os.listdir(self.root + "train/balanced")) + trainpairs = [ + "train/balanced/" + s + "/camera_00/" + imname + for s in trainseqs + for imname in sorted( + os.listdir(self.root + "train/balanced/" + s + "/camera_00/") + ) + ] + testpairs = [ + "test/balanced/" + s + "/camera_00/" + imname + for s in sorted(os.listdir(self.root + "test/balanced")) + for imname in sorted( + os.listdir(self.root + "test/balanced/" + s + "/camera_00/") + ) + ] + assert len(trainpairs) == 228 and len(testpairs) == 191 + subtrainpairs = [p for p in trainpairs if any(s in p for s in trainseqs[:-2])] + subvalpairs = [p for p in trainpairs if any(s in p for s in trainseqs[-2:])] + # warning: if we do validation split, we should split scenes!!! + tosave = { + "train_balanced": trainpairs, + "test_balanced": testpairs, + "subtrain_balanced": subtrainpairs, + "subval_balanced": subvalpairs, + } + return tosave + + +class SpringDataset(StereoDataset): + def _prepare_data(self): + self.name = "Spring" + self._set_root() + assert self.split in ["train", "test", "subtrain", "subval"] + self.pairname_to_Limgname = lambda pairname: osp.join( + self.root, pairname + ".png" + ) + self.pairname_to_Rimgname = ( + lambda pairname: osp.join(self.root, pairname + ".png") + .replace("frame_right", "") + .replace("frame_left", "frame_right") + .replace("", "frame_left") + ) + self.pairname_to_Ldispname = ( + lambda pairname: None + if pairname.startswith("test") + else osp.join(self.root, pairname + ".dsp5") + .replace("frame_left", "disp1_left") + .replace("frame_right", "disp1_right") + ) + self.pairname_to_str = lambda pairname: pairname + self.load_disparity = _read_hdf5_disp + + def _build_cache(self): + trainseqs = sorted(os.listdir(osp.join(self.root, "train"))) + trainpairs = [ + osp.join("train", s, "frame_left", f[:-4]) + for s in trainseqs + for f in sorted(os.listdir(osp.join(self.root, "train", s, "frame_left"))) + ] + testseqs = sorted(os.listdir(osp.join(self.root, "test"))) + testpairs = [ + osp.join("test", s, "frame_left", f[:-4]) + for s in testseqs + for f in sorted(os.listdir(osp.join(self.root, "test", s, "frame_left"))) + ] + testpairs += [p.replace("frame_left", "frame_right") for p in testpairs] + """maxnorm = {'0001': 32.88, '0002': 228.5, '0004': 298.2, '0005': 142.5, '0006': 113.6, '0007': 27.3, '0008': 554.5, '0009': 155.6, '0010': 126.1, '0011': 87.6, '0012': 303.2, '0013': 24.14, '0014': 82.56, '0015': 98.44, '0016': 156.9, '0017': 28.17, '0018': 21.03, '0020': 178.0, '0021': 58.06, '0022': 354.2, '0023': 8.79, '0024': 97.06, '0025': 55.16, '0026': 91.9, '0027': 156.6, '0030': 200.4, '0032': 58.66, '0033': 373.5, '0036': 149.4, '0037': 5.625, '0038': 37.0, '0039': 12.2, '0041': 453.5, '0043': 457.0, '0044': 379.5, '0045': 161.8, '0047': 105.44} # => let'use 0041""" + subtrainpairs = [p for p in trainpairs if p.split("/")[1] != "0041"] + subvalpairs = [p for p in trainpairs if p.split("/")[1] == "0041"] + assert ( + len(trainpairs) == 5000 + and len(testpairs) == 2000 + and len(subtrainpairs) == 4904 + and len(subvalpairs) == 96 + ), "incorrect parsing of pairs in Spring" + tosave = { + "train": trainpairs, + "test": testpairs, + "subtrain": subtrainpairs, + "subval": subvalpairs, + } + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim == 2 + assert prediction.dtype == np.float32 + outfile = ( + os.path.join(outdir, pairname + ".dsp5") + .replace("frame_left", "disp1_left") + .replace("frame_right", "disp1_right") + ) + os.makedirs(os.path.dirname(outfile), exist_ok=True) + writeDsp5File(prediction, outfile) + + def finalize_submission(self, outdir): + assert self.split == "test" + exe = "{self.root}/disp1_subsampling" + if os.path.isfile(exe): + cmd = f'cd "{outdir}/test"; {exe} .' + print(cmd) + os.system(cmd) + else: + print("Could not find disp1_subsampling executable for submission.") + print("Please download it and run:") + print(f'cd "{outdir}/test"; .') + + +class Kitti12Dataset(StereoDataset): + def _prepare_data(self): + self.name = "Kitti12" + self._set_root() + assert self.split in ["train", "test"] + self.pairname_to_Limgname = lambda pairname: osp.join( + self.root, pairname + "_10.png" + ) + self.pairname_to_Rimgname = lambda pairname: osp.join( + self.root, pairname.replace("/colored_0/", "/colored_1/") + "_10.png" + ) + self.pairname_to_Ldispname = ( + None + if self.split == "test" + else lambda pairname: osp.join( + self.root, pairname.replace("/colored_0/", "/disp_occ/") + "_10.png" + ) + ) + self.pairname_to_str = lambda pairname: pairname.replace("/colored_0/", "/") + self.load_disparity = _read_kitti_disp + + def _build_cache(self): + trainseqs = ["training/colored_0/%06d" % (i) for i in range(194)] + testseqs = ["testing/colored_0/%06d" % (i) for i in range(195)] + assert ( + len(trainseqs) == 194 and len(testseqs) == 195 + ), "incorrect parsing of pairs in Kitti12" + tosave = {"train": trainseqs, "test": testseqs} + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim == 2 + assert prediction.dtype == np.float32 + outfile = os.path.join(outdir, pairname.split("/")[-1] + "_10.png") + os.makedirs(os.path.dirname(outfile), exist_ok=True) + img = (prediction * 256).astype("uint16") + Image.fromarray(img).save(outfile) + + def finalize_submission(self, outdir): + assert self.split == "test" + cmd = f'cd {outdir}/; zip -r "kitti12_results.zip" .' + print(cmd) + os.system(cmd) + print(f"Done. Submission file at {outdir}/kitti12_results.zip") + + +class Kitti15Dataset(StereoDataset): + def _prepare_data(self): + self.name = "Kitti15" + self._set_root() + assert self.split in ["train", "subtrain", "subval", "test"] + self.pairname_to_Limgname = lambda pairname: osp.join( + self.root, pairname + "_10.png" + ) + self.pairname_to_Rimgname = lambda pairname: osp.join( + self.root, pairname.replace("/image_2/", "/image_3/") + "_10.png" + ) + self.pairname_to_Ldispname = ( + None + if self.split == "test" + else lambda pairname: osp.join( + self.root, pairname.replace("/image_2/", "/disp_occ_0/") + "_10.png" + ) + ) + self.pairname_to_str = lambda pairname: pairname.replace("/image_2/", "/") + self.load_disparity = _read_kitti_disp + + def _build_cache(self): + trainseqs = ["training/image_2/%06d" % (i) for i in range(200)] + subtrainseqs = trainseqs[:-5] + subvalseqs = trainseqs[-5:] + testseqs = ["testing/image_2/%06d" % (i) for i in range(200)] + assert ( + len(trainseqs) == 200 + and len(subtrainseqs) == 195 + and len(subvalseqs) == 5 + and len(testseqs) == 200 + ), "incorrect parsing of pairs in Kitti15" + tosave = { + "train": trainseqs, + "subtrain": subtrainseqs, + "subval": subvalseqs, + "test": testseqs, + } + return tosave + + def submission_save_pairname(self, pairname, prediction, outdir, time): + assert prediction.ndim == 2 + assert prediction.dtype == np.float32 + outfile = os.path.join(outdir, "disp_0", pairname.split("/")[-1] + "_10.png") + os.makedirs(os.path.dirname(outfile), exist_ok=True) + img = (prediction * 256).astype("uint16") + Image.fromarray(img).save(outfile) + + def finalize_submission(self, outdir): + assert self.split == "test" + cmd = f'cd {outdir}/; zip -r "kitti15_results.zip" disp_0' + print(cmd) + os.system(cmd) + print(f"Done. Submission file at {outdir}/kitti15_results.zip") + + +### auxiliary functions + + +def _read_img(filename): + # convert to RGB for scene flow finalpass data + img = np.asarray(Image.open(filename).convert("RGB")) + return img + + +def _read_booster_disp(filename): + disp = np.load(filename) + disp[disp == 0.0] = np.inf + return disp + + +def _read_png_disp(filename, coef=1.0): + disp = np.asarray(Image.open(filename)) + disp = disp.astype(np.float32) / coef + disp[disp == 0.0] = np.inf + return disp + + +def _read_pfm_disp(filename): + disp = np.ascontiguousarray(_read_pfm(filename)[0]) + disp[ + disp <= 0 + ] = ( + np.inf + ) # eg /nfs/data/ffs-3d/datasets/middlebury/2014/Shopvac-imperfect/disp0.pfm + return disp + + +def _read_npy_disp(filename): + return np.load(filename) + + +def _read_crestereo_disp(filename): + return _read_png_disp(filename, coef=32.0) + + +def _read_middlebury20052006_disp(filename): + return _read_png_disp(filename, coef=1.0) + + +def _read_kitti_disp(filename): + return _read_png_disp(filename, coef=256.0) + + +_read_sceneflow_disp = _read_pfm_disp +_read_eth3d_disp = _read_pfm_disp +_read_middlebury_disp = _read_pfm_disp +_read_carla_disp = _read_pfm_disp +_read_tartanair_disp = _read_npy_disp + + +def _read_hdf5_disp(filename): + disp = np.asarray(h5py.File(filename)["disparity"]) + disp[np.isnan(disp)] = np.inf # make invalid values as +inf + # disp[disp==0.0] = np.inf # make invalid values as +inf + return disp.astype(np.float32) + + +import re + + +def _read_pfm(file): + file = open(file, "rb") + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file.") + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: # little-endian + endian = "<" + scale = -scale + else: + endian = ">" # big-endian + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + + +def writePFM(file, image, scale=1): + file = open(file, "wb") + + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def writeDsp5File(disp, filename): + with h5py.File(filename, "w") as f: + f.create_dataset("disparity", data=disp, compression="gzip", compression_opts=5) + + +# disp visualization + + +def vis_disparity(disp, m=None, M=None): + if m is None: + m = disp.min() + if M is None: + M = disp.max() + disp_vis = (disp - m) / (M - m) * 255.0 + disp_vis = disp_vis.astype("uint8") + disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) + return disp_vis + + +# dataset getter + + +def get_train_dataset_stereo(dataset_str, augmentor=True, crop_size=None): + dataset_str = dataset_str.replace("(", "Dataset(") + if augmentor: + dataset_str = dataset_str.replace(")", ", augmentor=True)") + if crop_size is not None: + dataset_str = dataset_str.replace( + ")", ", crop_size={:s})".format(str(crop_size)) + ) + return eval(dataset_str) + + +def get_test_datasets_stereo(dataset_str): + dataset_str = dataset_str.replace("(", "Dataset(") + return [eval(s) for s in dataset_str.split("+")] diff --git a/stream3r/croco/stereoflow/download_model.sh b/stream3r/croco/stereoflow/download_model.sh new file mode 100644 index 0000000000000000000000000000000000000000..f9f59121038b7c4b5d2782ce1d8ce35278d9db1c --- /dev/null +++ b/stream3r/croco/stereoflow/download_model.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +model=$1 +outfile="stereoflow_models/${model}" +if [[ ! -f $outfile ]] +then + mkdir -p stereoflow_models/; + wget https://download.europe.naverlabs.com/ComputerVision/CroCo/StereoFlow_models/"$1" -P stereoflow_models/; +else + echo "Model ${model} already downloaded in ${outfile}." +fi diff --git a/stream3r/croco/stereoflow/engine.py b/stream3r/croco/stereoflow/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..635d5e2768b8fd5dae5944a5cdb992dea041cbc6 --- /dev/null +++ b/stream3r/croco/stereoflow/engine.py @@ -0,0 +1,370 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Main function for training one epoch or testing +# -------------------------------------------------------- + +import math +import sys +from typing import Iterable + +import numpy as np +import torch +import torchvision +from utils import misc as misc + + +def split_prediction_conf(predictions, with_conf=False): + if not with_conf: + return predictions, None + conf = predictions[:, -1:, :, :] + predictions = predictions[:, :-1, :, :] + return predictions, conf + + +def train_one_epoch( + model: torch.nn.Module, + criterion: torch.nn.Module, + metrics: torch.nn.Module, + data_loader: Iterable, + optimizer: torch.optim.Optimizer, + device: torch.device, + epoch: int, + loss_scaler, + log_writer=None, + print_freq=20, + args=None, +): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")) + header = "Epoch: [{}]".format(epoch) + + accum_iter = args.accum_iter + + optimizer.zero_grad() + + details = {} + + if log_writer is not None: + print("log_dir: {}".format(log_writer.log_dir)) + + if args.img_per_epoch: + iter_per_epoch = args.img_per_epoch // args.batch_size + int( + args.img_per_epoch % args.batch_size > 0 + ) + assert ( + len(data_loader) >= iter_per_epoch + ), "Dataset is too small for so many iterations" + len_data_loader = iter_per_epoch + else: + len_data_loader, iter_per_epoch = len(data_loader), None + + for data_iter_step, (image1, image2, gt, pairname) in enumerate( + metric_logger.log_every( + data_loader, print_freq, header, max_iter=iter_per_epoch + ) + ): + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = gt.to(device, non_blocking=True) + + # we use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + misc.adjust_learning_rate( + optimizer, data_iter_step / len_data_loader + epoch, args + ) + + with torch.cuda.amp.autocast(enabled=bool(args.amp)): + prediction = model(image1, image2) + prediction, conf = split_prediction_conf(prediction, criterion.with_conf) + batch_metrics = metrics(prediction.detach(), gt) + loss = ( + criterion(prediction, gt) + if conf is None + else criterion(prediction, gt, conf) + ) + + loss_value = loss.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + loss /= accum_iter + loss_scaler( + loss, + optimizer, + parameters=model.parameters(), + update_grad=(data_iter_step + 1) % accum_iter == 0, + ) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) + for k, v in batch_metrics.items(): + metric_logger.update(**{k: v.item()}) + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(lr=lr) + + # if args.dsitributed: loss_value_reduce = misc.all_reduce_mean(loss_value) + time_to_log = (data_iter_step + 1) % ( + args.tboard_log_step * accum_iter + ) == 0 or data_iter_step == len_data_loader - 1 + loss_value_reduce = misc.all_reduce_mean(loss_value) + if log_writer is not None and time_to_log: + epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000) + # We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes. + log_writer.add_scalar("train/loss", loss_value_reduce, epoch_1000x) + log_writer.add_scalar("lr", lr, epoch_1000x) + for k, v in batch_metrics.items(): + log_writer.add_scalar("train/" + k, v.item(), epoch_1000x) + + # gather the stats from all processes + # if args.distributed: metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def validate_one_epoch( + model: torch.nn.Module, + criterion: torch.nn.Module, + metrics: torch.nn.Module, + data_loaders: list[Iterable], + device: torch.device, + epoch: int, + log_writer=None, + args=None, +): + model.eval() + metric_loggers = [] + header = "Epoch: [{}]".format(epoch) + print_freq = 20 + + conf_mode = args.tile_conf_mode + crop = args.crop + + if log_writer is not None: + print("log_dir: {}".format(log_writer.log_dir)) + + results = {} + dnames = [] + image1, image2, gt, prediction = None, None, None, None + for didx, data_loader in enumerate(data_loaders): + dname = str(data_loader.dataset) + dnames.append(dname) + metric_loggers.append(misc.MetricLogger(delimiter=" ")) + for data_iter_step, (image1, image2, gt, pairname) in enumerate( + metric_loggers[didx].log_every(data_loader, print_freq, header) + ): + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = gt.to(device, non_blocking=True) + if dname.startswith("Spring"): + assert ( + gt.size(2) == image1.size(2) * 2 + and gt.size(3) == image1.size(3) * 2 + ) + gt = ( + gt[:, :, 0::2, 0::2] + + gt[:, :, 0::2, 1::2] + + gt[:, :, 1::2, 0::2] + + gt[:, :, 1::2, 1::2] + ) / 4.0 # we approximate the gt based on the 2x upsampled ones + + with torch.inference_mode(): + prediction, tiled_loss, c = tiled_pred( + model, + criterion, + image1, + image2, + gt, + conf_mode=conf_mode, + overlap=args.val_overlap, + crop=crop, + with_conf=criterion.with_conf, + ) + batch_metrics = metrics(prediction.detach(), gt) + loss = ( + criterion(prediction.detach(), gt) + if not criterion.with_conf + else criterion(prediction.detach(), gt, c) + ) + loss_value = loss.item() + metric_loggers[didx].update(loss_tiled=tiled_loss.item()) + metric_loggers[didx].update(**{f"loss": loss_value}) + for k, v in batch_metrics.items(): + metric_loggers[didx].update(**{dname + "_" + k: v.item()}) + + results = { + k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items() + } + if len(dnames) > 1: + for k in batch_metrics.keys(): + results["AVG_" + k] = sum( + results[dname + "_" + k] for dname in dnames + ) / len(dnames) + + if log_writer is not None: + epoch_1000x = int((1 + epoch) * 1000) + for k, v in results.items(): + log_writer.add_scalar("val/" + k, v, epoch_1000x) + + print("Averaged stats:", results) + return results + + +import torch.nn.functional as F + + +def _resize_img(img, new_size): + return F.interpolate(img, size=new_size, mode="bicubic", align_corners=False) + + +def _resize_stereo_or_flow(data, new_size): + assert data.ndim == 4 + assert data.size(1) in [1, 2] + scale_x = new_size[1] / float(data.size(3)) + out = F.interpolate(data, size=new_size, mode="bicubic", align_corners=False) + out[:, 0, :, :] *= scale_x + if out.size(1) == 2: + scale_y = new_size[0] / float(data.size(2)) + out[:, 1, :, :] *= scale_y + print(scale_x, new_size, data.shape) + return out + + +@torch.no_grad() +def tiled_pred( + model, + criterion, + img1, + img2, + gt, + overlap=0.5, + bad_crop_thr=0.05, + downscale=False, + crop=512, + ret="loss", + conf_mode="conf_expsigmoid_10_5", + with_conf=False, + return_time=False, +): + # for each image, we are going to run inference on many overlapping patches + # then, all predictions will be weighted-averaged + if gt is not None: + B, C, H, W = gt.shape + else: + B, _, H, W = img1.shape + C = model.head.num_channels - int(with_conf) + win_height, win_width = crop[0], crop[1] + + # upscale to be larger than the crop + do_change_scale = H < win_height or W < win_width + if do_change_scale: + upscale_factor = max(win_width / W, win_height / W) + original_size = (H, W) + new_size = (round(H * upscale_factor), round(W * upscale_factor)) + img1 = _resize_img(img1, new_size) + img2 = _resize_img(img2, new_size) + # resize gt just for the computation of tiled losses + if gt is not None: + gt = _resize_stereo_or_flow(gt, new_size) + H, W = img1.shape[2:4] + + if conf_mode.startswith("conf_expsigmoid_"): # conf_expsigmoid_30_10 + beta, betasigmoid = map(float, conf_mode[len("conf_expsigmoid_") :].split("_")) + elif conf_mode.startswith("conf_expbeta"): # conf_expbeta3 + beta = float(conf_mode[len("conf_expbeta") :]) + else: + raise NotImplementedError(f"conf_mode {conf_mode} is not implemented") + + def crop_generator(): + for sy in _overlapping(H, win_height, overlap): + for sx in _overlapping(W, win_width, overlap): + yield sy, sx, sy, sx, True + + # keep track of weighted sum of prediction*weights and weights + accu_pred = img1.new_zeros( + (B, C, H, W) + ) # accumulate the weighted sum of predictions + accu_conf = img1.new_zeros((B, H, W)) + 1e-16 # accumulate the weights + accu_c = img1.new_zeros( + (B, H, W) + ) # accumulate the weighted sum of confidences ; not so useful except for computing some losses + + tiled_losses = [] + + if return_time: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + + for sy1, sx1, sy2, sx2, aligned in crop_generator(): + # compute optical flow there + pred = model(_crop(img1, sy1, sx1), _crop(img2, sy2, sx2)) + pred, predconf = split_prediction_conf(pred, with_conf=with_conf) + + if gt is not None: + gtcrop = _crop(gt, sy1, sx1) + if criterion is not None and gt is not None: + tiled_losses.append( + criterion(pred, gtcrop).item() + if predconf is None + else criterion(pred, gtcrop, predconf).item() + ) + + if conf_mode.startswith("conf_expsigmoid_"): + conf = torch.exp( + -beta * 2 * (torch.sigmoid(predconf / betasigmoid) - 0.5) + ).view(B, win_height, win_width) + elif conf_mode.startswith("conf_expbeta"): + conf = torch.exp(-beta * predconf).view(B, win_height, win_width) + else: + raise NotImplementedError + + accu_pred[..., sy1, sx1] += pred * conf[:, None, :, :] + accu_conf[..., sy1, sx1] += conf + accu_c[..., sy1, sx1] += predconf.view(B, win_height, win_width) * conf + + pred = accu_pred / accu_conf[:, None, :, :] + c = accu_c / accu_conf + assert not torch.any(torch.isnan(pred)) + + if return_time: + end.record() + torch.cuda.synchronize() + time = start.elapsed_time(end) / 1000.0 # this was in milliseconds + + if do_change_scale: + pred = _resize_stereo_or_flow(pred, original_size) + + if return_time: + return pred, torch.mean(torch.tensor(tiled_losses)), c, time + return pred, torch.mean(torch.tensor(tiled_losses)), c + + +def _overlapping(total, window, overlap=0.5): + assert total >= window and 0 <= overlap < 1, (total, window, overlap) + num_windows = 1 + int(np.ceil((total - window) / ((1 - overlap) * window))) + offsets = np.linspace(0, total - window, num_windows).round().astype(int) + yield from (slice(x, x + window) for x in offsets) + + +def _crop(img, sy, sx): + B, THREE, H, W = img.shape + if 0 <= sy.start and sy.stop <= H and 0 <= sx.start and sx.stop <= W: + return img[:, :, sy, sx] + l, r = max(0, -sx.start), max(0, sx.stop - W) + t, b = max(0, -sy.start), max(0, sy.stop - H) + img = torch.nn.functional.pad(img, (l, r, t, b), mode="constant") + return img[:, :, slice(sy.start + t, sy.stop + t), slice(sx.start + l, sx.stop + l)] diff --git a/stream3r/croco/stereoflow/test.py b/stream3r/croco/stereoflow/test.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbc37487af1a2006c711d415135102e92795254 --- /dev/null +++ b/stream3r/croco/stereoflow/test.py @@ -0,0 +1,305 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Main test function +# -------------------------------------------------------- + +import argparse +import os +import pickle + +import numpy as np +import torch +import utils.misc as misc +from models.croco_downstream import CroCoDownstreamBinocular +from models.head_downstream import PixelwiseTaskWithDPT +from PIL import Image +from stereoflow.criterion import FlowDatasetMetrics, StereoDatasetMetrics +from stereoflow.datasets_flow import flowToColor, get_test_datasets_flow +from stereoflow.datasets_stereo import get_test_datasets_stereo, vis_disparity +from stereoflow.engine import tiled_pred +from torch.utils.data import DataLoader +from tqdm import tqdm + + +def get_args_parser(): + parser = argparse.ArgumentParser("Test CroCo models on stereo/flow", add_help=False) + # important argument + parser.add_argument( + "--model", required=True, type=str, help="Path to the model to evaluate" + ) + parser.add_argument( + "--dataset", + required=True, + type=str, + help="test dataset (there can be multiple dataset separated by a +)", + ) + # tiling + parser.add_argument( + "--tile_conf_mode", + type=str, + default="", + help="Weights for the tiling aggregation based on confidence (empty means use the formula from the loaded checkpoint", + ) + parser.add_argument( + "--tile_overlap", type=float, default=0.7, help="overlap between tiles" + ) + # save (it will automatically go to _/_) + parser.add_argument( + "--save", + type=str, + nargs="+", + default=[], + help="what to save: \ + metrics (pickle file), \ + pred (raw prediction save as torch tensor), \ + visu (visualization in png of each prediction), \ + err10 (visualization in png of the error clamp at 10 for each prediction), \ + submission (submission file)", + ) + # other (no impact) + parser.add_argument("--num_workers", default=4, type=int) + return parser + + +def _load_model_and_criterion(model_path, do_load_metrics, device): + print("loading model from", model_path) + assert os.path.isfile(model_path) + ckpt = torch.load(model_path, "cpu") + + ckpt_args = ckpt["args"] + task = ckpt_args.task + tile_conf_mode = ckpt_args.tile_conf_mode + num_channels = {"stereo": 1, "flow": 2}[task] + with_conf = eval(ckpt_args.criterion).with_conf + if with_conf: + num_channels += 1 + print("head: PixelwiseTaskWithDPT()") + head = PixelwiseTaskWithDPT() + head.num_channels = num_channels + print("croco_args:", ckpt_args.croco_args) + model = CroCoDownstreamBinocular(head, **ckpt_args.croco_args) + msg = model.load_state_dict(ckpt["model"], strict=True) + model.eval() + model = model.to(device) + + if do_load_metrics: + if task == "stereo": + metrics = StereoDatasetMetrics().to(device) + else: + metrics = FlowDatasetMetrics().to(device) + else: + metrics = None + + return model, metrics, ckpt_args.crop, with_conf, task, tile_conf_mode + + +def _save_batch( + pred, gt, pairnames, dataset, task, save, outdir, time, submission_dir=None +): + for i in range(len(pairnames)): + pairname = ( + eval(pairnames[i]) if pairnames[i].startswith("(") else pairnames[i] + ) # unbatch pairname + fname = os.path.join(outdir, dataset.pairname_to_str(pairname)) + os.makedirs(os.path.dirname(fname), exist_ok=True) + + predi = pred[i, ...] + if gt is not None: + gti = gt[i, ...] + + if "pred" in save: + torch.save(predi.squeeze(0).cpu(), fname + "_pred.pth") + + if "visu" in save: + if task == "stereo": + disparity = predi.permute((1, 2, 0)).squeeze(2).cpu().numpy() + m, M = None + if gt is not None: + mask = torch.isfinite(gti) + m = gt[mask].min() + M = gt[mask].max() + img_disparity = vis_disparity(disparity, m=m, M=M) + Image.fromarray(img_disparity).save(fname + "_pred.png") + else: + # normalize flowToColor according to the maxnorm of gt (or prediction if not available) + flowNorm = ( + torch.sqrt( + torch.sum((gti if gt is not None else predi) ** 2, dim=0) + ) + .max() + .item() + ) + imgflow = flowToColor( + predi.permute((1, 2, 0)).cpu().numpy(), maxflow=flowNorm + ) + Image.fromarray(imgflow).save(fname + "_pred.png") + + if "err10" in save: + assert gt is not None + L2err = torch.sqrt(torch.sum((gti - predi) ** 2, dim=0)) + valid = torch.isfinite(gti[0, :, :]) + L2err[~valid] = 0.0 + L2err = torch.clamp(L2err, max=10.0) + red = (L2err * 255.0 / 10.0).to(dtype=torch.uint8)[:, :, None] + zer = torch.zeros_like(red) + imgerr = torch.cat((red, zer, zer), dim=2).cpu().numpy() + Image.fromarray(imgerr).save(fname + "_err10.png") + + if "submission" in save: + assert submission_dir is not None + predi_np = ( + predi.permute(1, 2, 0).squeeze(2).cpu().numpy() + ) # transform into HxWx2 for flow or HxW for stereo + dataset.submission_save_pairname(pairname, predi_np, submission_dir, time) + + +def main(args): + # load the pretrained model and metrics + device = ( + torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + ) + ( + model, + metrics, + cropsize, + with_conf, + task, + tile_conf_mode, + ) = _load_model_and_criterion(args.model, "metrics" in args.save, device) + if args.tile_conf_mode == "": + args.tile_conf_mode = tile_conf_mode + + # load the datasets + datasets = ( + get_test_datasets_stereo if task == "stereo" else get_test_datasets_flow + )(args.dataset) + dataloaders = [ + DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False, + ) + for dataset in datasets + ] + + # run + for i, dataloader in enumerate(dataloaders): + dataset = datasets[i] + dstr = args.dataset.split("+")[i] + + outdir = args.model + "_" + misc.filename(dstr) + if "metrics" in args.save and len(args.save) == 1: + fname = os.path.join( + outdir, f"conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}.pkl" + ) + if os.path.isfile(fname) and len(args.save) == 1: + print(" metrics already compute in " + fname) + with open(fname, "rb") as fid: + results = pickle.load(fid) + for k, v in results.items(): + print("{:s}: {:.3f}".format(k, v)) + continue + + if "submission" in args.save: + dirname = ( + f"submission_conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}" + ) + submission_dir = os.path.join(outdir, dirname) + else: + submission_dir = None + + print("") + print("saving {:s} in {:s}".format("+".join(args.save), outdir)) + print(repr(dataset)) + + if metrics is not None: + metrics.reset() + + for data_iter_step, (image1, image2, gt, pairnames) in enumerate( + tqdm(dataloader) + ): + do_flip = ( + task == "stereo" + and dstr.startswith("Spring") + and any("right" in p for p in pairnames) + ) # we flip the images and will flip the prediction after as we assume img1 is on the left + + image1 = image1.to(device, non_blocking=True) + image2 = image2.to(device, non_blocking=True) + gt = ( + gt.to(device, non_blocking=True) if gt.numel() > 0 else None + ) # special case for test time + if do_flip: + assert all("right" in p for p in pairnames) + image1 = image1.flip( + dims=[3] + ) # this is already the right frame, let's flip it + image2 = image2.flip(dims=[3]) + gt = gt # that is ok + + with torch.inference_mode(): + pred, _, _, time = tiled_pred( + model, + None, + image1, + image2, + None if dataset.name == "Spring" else gt, + conf_mode=args.tile_conf_mode, + overlap=args.tile_overlap, + crop=cropsize, + with_conf=with_conf, + return_time=True, + ) + + if do_flip: + pred = pred.flip(dims=[3]) + + if metrics is not None: + metrics.add_batch(pred, gt) + + if any(k in args.save for k in ["pred", "visu", "err10", "submission"]): + _save_batch( + pred, + gt, + pairnames, + dataset, + task, + args.save, + outdir, + time, + submission_dir=submission_dir, + ) + + # print + if metrics is not None: + results = metrics.get_results() + for k, v in results.items(): + print("{:s}: {:.3f}".format(k, v)) + + # save if needed + if "metrics" in args.save: + os.makedirs(os.path.dirname(fname), exist_ok=True) + with open(fname, "wb") as fid: + pickle.dump(results, fid) + print("metrics saved in", fname) + + # finalize submission if needed + if "submission" in args.save: + dataset.finalize_submission(submission_dir) + + +if __name__ == "__main__": + args = get_args_parser() + args = args.parse_args() + main(args) diff --git a/stream3r/croco/stereoflow/train.py b/stream3r/croco/stereoflow/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b32b0df4274952cc643a47ddc0baf72f7bae5d57 --- /dev/null +++ b/stream3r/croco/stereoflow/train.py @@ -0,0 +1,457 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# -------------------------------------------------------- +# Main training function +# -------------------------------------------------------- + +import argparse +import datetime +import json +import os +import time + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torchvision.datasets as datasets +import torchvision.transforms as transforms +import utils +import utils.misc as misc +from models.croco_downstream import CroCoDownstreamBinocular, croco_args_from_ckpt +from models.head_downstream import PixelwiseTaskWithDPT +from models.pos_embed import interpolate_pos_embed +from stereoflow.criterion import FlowMetrics, StereoMetrics +from stereoflow.datasets_flow import get_test_datasets_flow, get_train_dataset_flow +from stereoflow.datasets_stereo import ( + get_test_datasets_stereo, + get_train_dataset_stereo, +) +from stereoflow.engine import train_one_epoch, validate_one_epoch +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from utils.misc import NativeScalerWithGradNormCount as NativeScaler + + +def get_args_parser(): + # prepare subparsers + parser = argparse.ArgumentParser( + "Finetuning CroCo models on stereo or flow", add_help=False + ) + subparsers = parser.add_subparsers( + title="Task (stereo or flow)", dest="task", required=True + ) + parser_stereo = subparsers.add_parser("stereo", help="Training stereo model") + parser_flow = subparsers.add_parser("flow", help="Training flow model") + + def add_arg( + name_or_flags, default=None, default_stereo=None, default_flow=None, **kwargs + ): + if default is not None: + assert ( + default_stereo is None and default_flow is None + ), "setting default makes default_stereo and default_flow disabled" + parser_stereo.add_argument( + name_or_flags, + default=default if default is not None else default_stereo, + **kwargs, + ) + parser_flow.add_argument( + name_or_flags, + default=default if default is not None else default_flow, + **kwargs, + ) + + # output dir + add_arg( + "--output_dir", + required=True, + type=str, + help="path where to save, if empty, automatically created", + ) + # model + add_arg( + "--crop", + type=int, + nargs="+", + default_stereo=[352, 704], + default_flow=[320, 384], + help="size of the random image crops used during training.", + ) + add_arg( + "--pretrained", + required=True, + type=str, + help="Load pretrained model (required as croco arguments come from there)", + ) + # criterion + add_arg( + "--criterion", + default_stereo="LaplacianLossBounded2()", + default_flow="LaplacianLossBounded()", + type=str, + help="string to evaluate to get criterion", + ) + add_arg("--bestmetric", default_stereo="avgerr", default_flow="EPE", type=str) + # dataset + add_arg("--dataset", type=str, required=True, help="training set") + # training + add_arg("--seed", default=0, type=int, help="seed") + add_arg( + "--batch_size", + default_stereo=6, + default_flow=8, + type=int, + help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus", + ) + add_arg("--epochs", default=32, type=int, help="number of training epochs") + add_arg( + "--img_per_epoch", + type=int, + default=None, + help="Fix the number of images seen in an epoch (None means use all training pairs)", + ) + add_arg( + "--accum_iter", + default=1, + type=int, + help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)", + ) + add_arg( + "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)" + ) + add_arg( + "--lr", + type=float, + default_stereo=3e-5, + default_flow=2e-5, + metavar="LR", + help="learning rate (absolute lr)", + ) + add_arg( + "--min_lr", + type=float, + default=0.0, + metavar="LR", + help="lower lr bound for cyclic schedulers that hit 0", + ) + add_arg( + "--warmup_epochs", type=int, default=1, metavar="N", help="epochs to warmup LR" + ) + add_arg( + "--optimizer", + default="AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))", + type=str, + help="Optimizer from torch.optim [ default: AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) ]", + ) + add_arg( + "--amp", + default=0, + type=int, + choices=[0, 1], + help="enable automatic mixed precision training", + ) + # validation + add_arg( + "--val_dataset", + type=str, + default="", + help="Validation sets, multiple separated by + (empty string means that no validation is performed)", + ) + add_arg( + "--tile_conf_mode", + type=str, + default_stereo="conf_expsigmoid_15_3", + default_flow="conf_expsigmoid_10_5", + help="Weights for tile aggregation", + ) + add_arg( + "--val_overlap", default=0.7, type=float, help="Overlap value for the tiling" + ) + # others + add_arg("--num_workers", default=8, type=int) + add_arg("--eval_every", type=int, default=1, help="Val loss evaluation frequency") + add_arg("--save_every", type=int, default=1, help="Save checkpoint frequency") + add_arg( + "--start_from", + type=str, + default=None, + help="Start training using weights from an other model (eg for finetuning)", + ) + add_arg( + "--tboard_log_step", + type=int, + default=100, + help="Log to tboard every so many steps", + ) + add_arg( + "--dist_url", default="env://", help="url used to set up distributed training" + ) + + return parser + + +def main(args): + misc.init_distributed_mode(args) + global_rank = misc.get_rank() + num_tasks = misc.get_world_size() + + assert os.path.isfile(args.pretrained) + print("output_dir: " + args.output_dir) + os.makedirs(args.output_dir, exist_ok=True) + + # fix the seed for reproducibility + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + cudnn.benchmark = True + + # Metrics / criterion + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + metrics = (StereoMetrics if args.task == "stereo" else FlowMetrics)().to(device) + criterion = eval(args.criterion).to(device) + print("Criterion: ", args.criterion) + + # Prepare model + assert os.path.isfile(args.pretrained) + ckpt = torch.load(args.pretrained, "cpu") + croco_args = croco_args_from_ckpt(ckpt) + croco_args["img_size"] = (args.crop[0], args.crop[1]) + print("Croco args: " + str(croco_args)) + args.croco_args = croco_args # saved for test time + # prepare head + num_channels = {"stereo": 1, "flow": 2}[args.task] + if criterion.with_conf: + num_channels += 1 + print(f"Building head PixelwiseTaskWithDPT() with {num_channels} channel(s)") + head = PixelwiseTaskWithDPT() + head.num_channels = num_channels + # build model and load pretrained weights + model = CroCoDownstreamBinocular(head, **croco_args) + interpolate_pos_embed(model, ckpt["model"]) + msg = model.load_state_dict(ckpt["model"], strict=False) + print(msg) + + total_params = sum(p.numel() for p in model.parameters()) + total_params_trainable = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + print(f"Total params: {total_params}") + print(f"Total params trainable: {total_params_trainable}") + model_without_ddp = model.to(device) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + print("lr: %.2e" % args.lr) + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], static_graph=True + ) + model_without_ddp = model.module + + # following timm: set wd as 0 for bias and norm layers + param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) + optimizer = eval(f"torch.optim.{args.optimizer}") + print(optimizer) + loss_scaler = NativeScaler() + + # automatic restart + last_ckpt_fname = os.path.join(args.output_dir, f"checkpoint-last.pth") + args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None + + if not args.resume and args.start_from: + print(f"Starting from an other model's weights: {args.start_from}") + best_so_far = None + args.start_epoch = 0 + ckpt = torch.load(args.start_from, "cpu") + msg = model_without_ddp.load_state_dict(ckpt["model"], strict=False) + print(msg) + else: + best_so_far = misc.load_model( + args=args, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + ) + + if best_so_far is None: + best_so_far = np.inf + + # tensorboard + log_writer = None + if global_rank == 0 and args.output_dir is not None: + log_writer = SummaryWriter( + log_dir=args.output_dir, purge_step=args.start_epoch * 1000 + ) + + # dataset and loader + print("Building Train Data loader for dataset: ", args.dataset) + train_dataset = ( + get_train_dataset_stereo if args.task == "stereo" else get_train_dataset_flow + )(args.dataset, crop_size=args.crop) + + def _print_repr_dataset(d): + if isinstance(d, torch.utils.data.dataset.ConcatDataset): + for dd in d.datasets: + _print_repr_dataset(dd) + else: + print(repr(d)) + + _print_repr_dataset(train_dataset) + print(" total length:", len(train_dataset)) + if args.distributed: + sampler_train = torch.utils.data.DistributedSampler( + train_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + else: + sampler_train = torch.utils.data.RandomSampler(train_dataset) + data_loader_train = torch.utils.data.DataLoader( + train_dataset, + sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + if args.val_dataset == "": + data_loaders_val = None + else: + print("Building Val Data loader for datasets: ", args.val_dataset) + val_datasets = ( + get_test_datasets_stereo + if args.task == "stereo" + else get_test_datasets_flow + )(args.val_dataset) + for val_dataset in val_datasets: + print(repr(val_dataset)) + data_loaders_val = [ + DataLoader( + val_dataset, + batch_size=1, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False, + ) + for val_dataset in val_datasets + ] + bestmetric = ( + "AVG_" + if len(data_loaders_val) > 1 + else str(data_loaders_val[0].dataset) + "_" + ) + args.bestmetric + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + # Training Loop + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + data_loader_train.sampler.set_epoch(epoch) + + # Train + epoch_start = time.time() + train_stats = train_one_epoch( + model, + criterion, + metrics, + data_loader_train, + optimizer, + device, + epoch, + loss_scaler, + log_writer=log_writer, + args=args, + ) + epoch_time = time.time() - epoch_start + + if args.distributed: + dist.barrier() + + # Validation (current naive implementation runs the validation on every gpu ... not smart ...) + if ( + data_loaders_val is not None + and args.eval_every > 0 + and (epoch + 1) % args.eval_every == 0 + ): + val_epoch_start = time.time() + val_stats = validate_one_epoch( + model, + criterion, + metrics, + data_loaders_val, + device, + epoch, + log_writer=log_writer, + args=args, + ) + val_epoch_time = time.time() - val_epoch_start + + val_best = val_stats[bestmetric] + + # Save best of all + if val_best <= best_so_far: + best_so_far = val_best + misc.save_model( + args=args, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + epoch=epoch, + best_so_far=best_so_far, + fname="best", + ) + + log_stats = { + **{f"train_{k}": v for k, v in train_stats.items()}, + "epoch": epoch, + **{f"val_{k}": v for k, v in val_stats.items()}, + } + else: + log_stats = { + **{f"train_{k}": v for k, v in train_stats.items()}, + "epoch": epoch, + } + + if args.distributed: + dist.barrier() + + # Save stuff + if args.output_dir and ( + (epoch + 1) % args.save_every == 0 or epoch + 1 == args.epochs + ): + misc.save_model( + args=args, + model_without_ddp=model_without_ddp, + optimizer=optimizer, + loss_scaler=loss_scaler, + epoch=epoch, + best_so_far=best_so_far, + fname="last", + ) + + if args.output_dir: + if log_writer is not None: + log_writer.flush() + with open( + os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8" + ) as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("Training time {}".format(total_time_str)) + + +if __name__ == "__main__": + args = get_args_parser() + args = args.parse_args() + main(args) diff --git a/stream3r/croco/utils/misc.py b/stream3r/croco/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..8c34192786d91aeb2d67e8b5a8cbf7412a1c682d --- /dev/null +++ b/stream3r/croco/utils/misc.py @@ -0,0 +1,531 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for CroCo +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import json +import math +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import numpy as np +import torch +import torch.distributed as dist +from torch import inf + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None, max_iter=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable) + space_fmt = ":" + str(len(str(len_iterable))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for it, obj in enumerate(iterable): + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len_iterable - 1: + eta_seconds = iter_time.global_avg * (len_iterable - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len_iterable, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len_iterable, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + if max_iter and it >= max_iter: + break + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len_iterable + ) + ) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print("[{}] ".format(now), end="") # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + nodist = args.nodist if hasattr(args, "nodist") else False + if "RANK" in os.environ and "WORLD_SIZE" in os.environ and not nodist: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + else: + print("Not using distributed mode") + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}): {}, gpu {}".format( + args.rank, args.dist_url, args.gpu + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self, enabled=True): + self._scaler = torch.cuda.amp.GradScaler(enabled=enabled) + + def __call__( + self, + loss, + optimizer, + clip_grad=None, + parameters=None, + create_graph=False, + update_grad=True, + ): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_( + optimizer + ) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + # TODO: FIXME: get_grad_norm_ is a poorly-implemented func that is very slow, and its return not even used + # norm = get_grad_norm_(parameters) + norm = None + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.0) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm( + torch.stack( + [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] + ), + norm_type, + ) + return total_norm + + +def save_model( + args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None +): + output_dir = Path(args.output_dir) + if fname is None: + fname = str(epoch) + checkpoint_path = output_dir / ("checkpoint-%s.pth" % fname) + to_save = { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "scaler": loss_scaler.state_dict(), + "args": args, + "epoch": epoch, + } + if best_so_far is not None: + to_save["best_so_far"] = best_so_far + print(f">> Saving model to {checkpoint_path} ...") + save_on_master(to_save, checkpoint_path) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + args.start_epoch = 0 + best_so_far = None + if args.resume is not None: + if args.resume.startswith("https"): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location="cpu", check_hash=True + ) + else: + checkpoint = torch.load(args.resume, map_location="cpu") + print("Resume checkpoint %s" % args.resume) + model_without_ddp.load_state_dict(checkpoint["model"], strict=False) + args.start_epoch = checkpoint["epoch"] + 1 + optimizer.load_state_dict(checkpoint["optimizer"]) + if "scaler" in checkpoint: + loss_scaler.load_state_dict(checkpoint["scaler"]) + if "best_so_far" in checkpoint: + best_so_far = checkpoint["best_so_far"] + print(" & best_so_far={:g}".format(best_so_far)) + else: + print("") + print("With optim & sched! start_epoch={:d}".format(args.start_epoch), end="") + return best_so_far + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + + +def _replace(text, src, tgt, rm=""): + """Advanced string replacement. + Given a text: + - replace all elements in src by the corresponding element in tgt + - remove all elements in rm + """ + if len(tgt) == 1: + tgt = tgt * len(src) + assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len" + for s, t in zip(src, tgt): + text = text.replace(s, t) + for c in rm: + text = text.replace(c, "") + return text + + +def filename(obj): + """transform a python obj or cmd into a proper filename. + - \1 gets replaced by slash '/' + - \2 gets replaced by comma ',' + """ + if not isinstance(obj, str): + obj = repr(obj) + obj = str(obj).replace("()", "") + obj = _replace(obj, "_,(*/\1\2", "-__x%/,", rm=" )'\"") + assert all(len(s) < 256 for s in obj.split(os.sep)), ( + "filename too long (>256 characters):\n" + obj + ) + return obj + + +def _get_num_layer_for_vit(var_name, enc_depth, dec_depth): + if var_name in ("cls_token", "mask_token", "pos_embed", "global_tokens"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("enc_blocks"): + layer_id = int(var_name.split(".")[1]) + return layer_id + 1 + elif var_name.startswith("decoder_embed") or var_name.startswith( + "enc_norm" + ): # part of the last black + return enc_depth + elif var_name.startswith("dec_blocks"): + layer_id = int(var_name.split(".")[1]) + return enc_depth + layer_id + 1 + elif var_name.startswith("dec_norm"): # part of the last block + return enc_depth + dec_depth + elif any(var_name.startswith(k) for k in ["head", "prediction_head"]): + return enc_depth + dec_depth + 1 + else: + raise NotImplementedError(var_name) + + +def get_parameter_groups( + model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[] +): + parameter_group_names = {} + parameter_group_vars = {} + enc_depth, dec_depth = None, None + # prepare layer decay values + assert layer_decay == 1.0 or 0.0 < layer_decay < 1.0 + if layer_decay < 1.0: + enc_depth = model.enc_depth + dec_depth = model.dec_depth if hasattr(model, "dec_blocks") else 0 + num_layers = enc_depth + dec_depth + layer_decay_values = list( + layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2) + ) + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + + # Assign weight decay values + if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: + group_name = "no_decay" + this_weight_decay = 0.0 + else: + group_name = "decay" + this_weight_decay = weight_decay + + # Assign layer ID for LR scaling + if layer_decay < 1.0: + skip_scale = False + layer_id = _get_num_layer_for_vit(name, enc_depth, dec_depth) + group_name = "layer_%d_%s" % (layer_id, group_name) + if name in no_lr_scale_list: + skip_scale = True + group_name = f"{group_name}_no_lr_scale" + else: + layer_id = 0 + skip_scale = True + + if group_name not in parameter_group_names: + if not skip_scale: + scale = layer_decay_values[layer_id] + else: + scale = 1.0 + + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale, + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale, + } + + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + return list(parameter_group_vars.values()) + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( + 1.0 + + math.cos( + math.pi + * (epoch - args.warmup_epochs) + / (args.epochs - args.warmup_epochs) + ) + ) + + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + + return lr diff --git a/stream3r/dust3r/__init__.py b/stream3r/dust3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea9d87a320e848d1c4851a1e2408313c9255365 --- /dev/null +++ b/stream3r/dust3r/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/stream3r/dust3r/__pycache__/__init__.cpython-311.pyc b/stream3r/dust3r/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30233195d61b7ccc31b7c4e4c99fae410172e7d3 Binary files /dev/null and b/stream3r/dust3r/__pycache__/__init__.cpython-311.pyc differ diff --git a/stream3r/dust3r/cloud_opt/__init__.py b/stream3r/dust3r/cloud_opt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed185c899d09c004e533803014ca02a45f20bae --- /dev/null +++ b/stream3r/dust3r/cloud_opt/__init__.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# global alignment optimization wrapper function +# -------------------------------------------------------- +from enum import Enum + +from .modular_optimizer import ModularPointCloudOptimizer +from .optimizer import PointCloudOptimizer +from .pair_viewer import PairViewer + + +class GlobalAlignerMode(Enum): + PointCloudOptimizer = "PointCloudOptimizer" + ModularPointCloudOptimizer = "ModularPointCloudOptimizer" + PairViewer = "PairViewer" + + +def global_aligner( + dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw +): + # extract all inputs + view1, view2, pred1, pred2 = [ + dust3r_output[k] for k in "view1 view2 pred1 pred2".split() + ] + # build the optimizer + if mode == GlobalAlignerMode.PointCloudOptimizer: + net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device) + elif mode == GlobalAlignerMode.ModularPointCloudOptimizer: + net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to( + device + ) + elif mode == GlobalAlignerMode.PairViewer: + net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device) + else: + raise NotImplementedError(f"Unknown mode {mode}") + + return net diff --git a/stream3r/dust3r/cloud_opt/base_opt.py b/stream3r/dust3r/cloud_opt/base_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..437f19266e8d0226da1de4b9524c0edee1477e20 --- /dev/null +++ b/stream3r/dust3r/cloud_opt/base_opt.py @@ -0,0 +1,453 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Base class for the global alignement procedure +# -------------------------------------------------------- +from copy import deepcopy + +import numpy as np +import roma +import torch +import torch.nn as nn +import tqdm + +import stream3r.dust3r.cloud_opt.init_im_poses as init_fun +from stream3r.dust3r.cloud_opt.commons import ( + ALL_DISTS, + NoGradParamDict, + cosine_schedule, + edge_str, + get_conf_trf, + get_imshapes, + linear_schedule, + signed_expm1, + signed_log1p, +) +from stream3r.dust3r.optim_factory import adjust_learning_rate_by_lr +from stream3r.dust3r.utils.device import to_numpy +from stream3r.dust3r.utils.geometry import geotrf, inv +from stream3r.dust3r.utils.image import rgb +from stream3r.dust3r.viz import SceneViz, auto_cam_size, segment_sky + + +class BasePCOptimizer(nn.Module): + """Optimize a global scene, given a list of pairwise observations. + Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, **kwargs): + if len(args) == 1 and len(kwargs) == 0: + other = deepcopy(args[0]) + attrs = """edges is_symmetrized dist n_imgs pred_i pred_j imshapes + min_conf_thr conf_thr conf_i conf_j im_conf + base_scale norm_pw_scale POSE_DIM pw_poses + pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose""".split() + self.__dict__.update({k: other[k] for k in attrs}) + else: + self._init_from_views(*args, **kwargs) + + def _init_from_views( + self, + view1, + view2, + pred1, + pred2, + dist="l1", + conf="log", + min_conf_thr=3, + base_scale=0.5, + allow_pw_adaptors=False, + pw_break=20, + rand_pose=torch.randn, + iterationsCount=None, + verbose=True, + ): + super().__init__() + if not isinstance(view1["idx"], list): + view1["idx"] = view1["idx"].tolist() + if not isinstance(view2["idx"], list): + view2["idx"] = view2["idx"].tolist() + self.edges = [(int(i), int(j)) for i, j in zip(view1["idx"], view2["idx"])] + self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges} + self.dist = ALL_DISTS[dist] + self.verbose = verbose + + self.n_imgs = self._check_edges() + + # input data + pred1_pts = pred1["pts3d"] + pred2_pts = pred2["pts3d_in_other_view"] + self.pred_i = NoGradParamDict( + {ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)} + ) + self.pred_j = NoGradParamDict( + {ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)} + ) + self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts) + + # work in log-scale with conf + pred1_conf = pred1["conf"] + pred2_conf = pred2["conf"] + self.min_conf_thr = min_conf_thr + self.conf_trf = get_conf_trf(conf) + + self.conf_i = NoGradParamDict( + {ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)} + ) + self.conf_j = NoGradParamDict( + {ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)} + ) + self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf) + for i in range(len(self.im_conf)): + self.im_conf[i].requires_grad = False + + # pairwise pose parameters + self.base_scale = base_scale + self.norm_pw_scale = True + self.pw_break = pw_break + self.POSE_DIM = 7 + self.pw_poses = nn.Parameter( + rand_pose((self.n_edges, 1 + self.POSE_DIM)) + ) # pairwise poses + self.pw_adaptors = nn.Parameter( + torch.zeros((self.n_edges, 2)) + ) # slight xy/z adaptation + self.pw_adaptors.requires_grad_(allow_pw_adaptors) + self.has_im_poses = False + self.rand_pose = rand_pose + + # possibly store images for show_pointcloud + self.imgs = None + if "img" in view1 and "img" in view2: + imgs = [torch.zeros((3,) + hw) for hw in self.imshapes] + for v in range(len(self.edges)): + idx = view1["idx"][v] + imgs[idx] = view1["img"][v] + idx = view2["idx"][v] + imgs[idx] = view2["img"][v] + self.imgs = rgb(imgs) + + @property + def n_edges(self): + return len(self.edges) + + @property + def str_edges(self): + return [edge_str(i, j) for i, j in self.edges] + + @property + def imsizes(self): + return [(w, h) for h, w in self.imshapes] + + @property + def device(self): + return next(iter(self.parameters())).device + + def state_dict(self, trainable=True): + all_params = super().state_dict() + return { + k: v + for k, v in all_params.items() + if k.startswith(("_", "pred_i.", "pred_j.", "conf_i.", "conf_j.")) + != trainable + } + + def load_state_dict(self, data): + return super().load_state_dict(self.state_dict(trainable=False) | data) + + def _check_edges(self): + indices = sorted({i for edge in self.edges for i in edge}) + assert indices == list(range(len(indices))), "bad pair indices: missing values " + return len(indices) + + @torch.no_grad() + def _compute_img_conf(self, pred1_conf, pred2_conf): + im_conf = nn.ParameterList( + [torch.zeros(hw, device=self.device) for hw in self.imshapes] + ) + for e, (i, j) in enumerate(self.edges): + im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e]) + im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e]) + return im_conf + + def get_adaptors(self): + adapt = self.pw_adaptors + adapt = torch.cat( + (adapt[:, 0:1], adapt), dim=-1 + ) # (scale_xy, scale_xy, scale_z) + if self.norm_pw_scale: # normalize so that the product == 1 + adapt = adapt - adapt.mean(dim=1, keepdim=True) + return (adapt / self.pw_break).exp() + + def _get_poses(self, poses): + # normalize rotation + Q = poses[:, :4] + T = signed_expm1(poses[:, 4:7]) + RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous() + return RT + + def _set_pose(self, poses, idx, R, T=None, scale=None, force=False): + # all poses == cam-to-world + pose = poses[idx] + if not (pose.requires_grad or force): + return pose + + if R.shape == (4, 4): + assert T is None + T = R[:3, 3] + R = R[:3, :3] + + if R is not None: + pose.data[0:4] = roma.rotmat_to_unitquat(R) + if T is not None: + pose.data[4:7] = signed_log1p( + T / (scale or 1) + ) # translation is function of scale + + if scale is not None: + assert poses.shape[-1] in (8, 13) + pose.data[-1] = np.log(float(scale)) + return pose + + def get_pw_norm_scale_factor(self): + if self.norm_pw_scale: + # normalize scales so that things cannot go south + # we want that exp(scale) ~= self.base_scale + return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp() + else: + return 1 # don't norm scale for known poses + + def get_pw_scale(self): + scale = self.pw_poses[:, -1].exp() # (n_edges,) + scale = scale * self.get_pw_norm_scale_factor() + return scale + + def get_pw_poses(self): # cam to world + RT = self._get_poses(self.pw_poses) + scaled_RT = RT.clone() + scaled_RT[:, :3] *= self.get_pw_scale().view( + -1, 1, 1 + ) # scale the rotation AND translation + return scaled_RT + + def get_masks(self): + return [(conf > self.min_conf_thr) for conf in self.im_conf] + + def depth_to_pts3d(self): + raise NotImplementedError() + + def get_pts3d(self, raw=False): + res = self.depth_to_pts3d() + if not raw: + res = [dm[: h * w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def _set_focal(self, idx, focal, force=False): + raise NotImplementedError() + + def get_focals(self): + raise NotImplementedError() + + def get_known_focal_mask(self): + raise NotImplementedError() + + def get_principal_points(self): + raise NotImplementedError() + + def get_conf(self, mode=None): + trf = self.conf_trf if mode is None else get_conf_trf(mode) + return [trf(c) for c in self.im_conf] + + def get_im_poses(self): + raise NotImplementedError() + + def _set_depthmap(self, idx, depth, force=False): + raise NotImplementedError() + + def get_depthmaps(self, raw=False): + raise NotImplementedError() + + @torch.no_grad() + def clean_pointcloud(self, tol=0.001, max_bad_conf=0): + """Method: + 1) express all 3d points in each camera coordinate frame + 2) if they're in front of a depthmap --> then lower their confidence + """ + assert 0 <= tol < 1 + cams = inv(self.get_im_poses()) + K = self.get_intrinsics() + depthmaps = self.get_depthmaps() + res = deepcopy(self) + + for i, pts3d in enumerate(self.depth_to_pts3d()): + for j in range(self.n_imgs): + if i == j: + continue + + # project 3dpts in other view + Hi, Wi = self.imshapes[i] + Hj, Wj = self.imshapes[j] + proj = geotrf(cams[j], pts3d[: Hi * Wi]).reshape(Hi, Wi, 3) + proj_depth = proj[:, :, 2] + u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1) + + # check which points are actually in the visible cone + msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj) + msk_j = v[msk_i], u[msk_i] + + # find bad points = those in front but less confident + bad_points = (proj_depth[msk_i] < (1 - tol) * depthmaps[j][msk_j]) & ( + res.im_conf[i][msk_i] < res.im_conf[j][msk_j] + ) + + bad_msk_i = msk_i.clone() + bad_msk_i[msk_i] = bad_points + res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_( + max=max_bad_conf + ) + + return res + + def forward(self, ret_details=False): + pw_poses = self.get_pw_poses() # cam-to-world + pw_adapt = self.get_adaptors() + proj_pts3d = self.get_pts3d() + # pre-compute pixel weights + weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()} + weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()} + + loss = 0 + if ret_details: + details = -torch.ones((self.n_imgs, self.n_imgs)) + + for e, (i, j) in enumerate(self.edges): + i_j = edge_str(i, j) + # distance in image i and j + aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j]) + aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j]) + li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean() + lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean() + loss = loss + li + lj + + if ret_details: + details[i, j] = li + lj + loss /= self.n_edges # average over all pairs + + if ret_details: + return loss, details + return loss + + @torch.amp.autocast("cuda", enabled=False) + def compute_global_alignment(self, init=None, niter_PnP=10, **kw): + if init is None: + pass + elif init == "msp" or init == "mst": + init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP) + elif init == "known_poses": + init_fun.init_from_known_poses( + self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP + ) + else: + raise ValueError(f"bad value for {init=}") + + return global_alignment_loop(self, **kw) + + @torch.no_grad() + def mask_sky(self): + res = deepcopy(self) + for i in range(self.n_imgs): + sky = segment_sky(self.imgs[i]) + res.im_conf[i][sky] = 0 + return res + + def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw): + viz = SceneViz() + if self.imgs is None: + colors = np.random.randint(0, 256, size=(self.n_imgs, 3)) + colors = list(map(tuple, colors.tolist())) + for n in range(self.n_imgs): + viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n]) + else: + viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks()) + colors = np.random.randint(256, size=(self.n_imgs, 3)) + + # camera poses + im_poses = to_numpy(self.get_im_poses()) + if cam_size is None: + cam_size = auto_cam_size(im_poses) + viz.add_cameras( + im_poses, + self.get_focals(), + colors=colors, + images=self.imgs, + imsizes=self.imsizes, + cam_size=cam_size, + ) + if show_pw_cams: + pw_poses = self.get_pw_poses() + viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size) + + if show_pw_pts3d: + pts = [ + geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) + for e, (i, j) in enumerate(self.edges) + ] + viz.add_pointcloud(pts, (128, 0, 128)) + + viz.show(**kw) + return viz + + +def global_alignment_loop(net, lr=0.01, niter=300, schedule="cosine", lr_min=1e-6): + params = [p for p in net.parameters() if p.requires_grad] + if not params: + return net + + verbose = net.verbose + if verbose: + print("Global alignement - optimizing for:") + print([name for name, value in net.named_parameters() if value.requires_grad]) + + lr_base = lr + optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9)) + + loss = float("inf") + if verbose: + with tqdm.tqdm(total=niter) as bar: + while bar.n < bar.total: + loss, lr = global_alignment_iter( + net, bar.n, niter, lr_base, lr_min, optimizer, schedule + ) + bar.set_postfix_str(f"{lr=:g} loss={loss:g}") + bar.update() + else: + for n in range(niter): + loss, _ = global_alignment_iter( + net, n, niter, lr_base, lr_min, optimizer, schedule + ) + return loss + + +def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule): + t = cur_iter / niter + if schedule == "cosine": + lr = cosine_schedule(t, lr_base, lr_min) + elif schedule == "linear": + lr = linear_schedule(t, lr_base, lr_min) + else: + raise ValueError(f"bad lr {schedule=}") + adjust_learning_rate_by_lr(optimizer, lr) + optimizer.zero_grad() + loss = net() + loss.backward() + optimizer.step() + + return float(loss), lr diff --git a/stream3r/dust3r/cloud_opt/commons.py b/stream3r/dust3r/cloud_opt/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..a7f596f5aab122a82e4c7d9f2565454f3d2757f3 --- /dev/null +++ b/stream3r/dust3r/cloud_opt/commons.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utility functions for global alignment +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn + + +def edge_str(i, j): + return f"{i}_{j}" + + +def i_j_ij(ij): + return edge_str(*ij), ij + + +def edge_conf(conf_i, conf_j, edge): + return float(conf_i[edge].mean() * conf_j[edge].mean()) + + +def compute_edge_scores(edges, conf_i, conf_j): + return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges} + + +def NoGradParamDict(x): + assert isinstance(x, dict) + return nn.ParameterDict(x).requires_grad_(False) + + +def get_imshapes(edges, pred_i, pred_j): + n_imgs = max(max(e) for e in edges) + 1 + imshapes = [None] * n_imgs + for e, (i, j) in enumerate(edges): + shape_i = tuple(pred_i[e].shape[0:2]) + shape_j = tuple(pred_j[e].shape[0:2]) + if imshapes[i]: + assert imshapes[i] == shape_i, f"incorrect shape for image {i}" + if imshapes[j]: + assert imshapes[j] == shape_j, f"incorrect shape for image {j}" + imshapes[i] = shape_i + imshapes[j] = shape_j + return imshapes + + +def get_conf_trf(mode): + if mode == "log": + + def conf_trf(x): + return x.log() + + elif mode == "sqrt": + + def conf_trf(x): + return x.sqrt() + + elif mode == "m1": + + def conf_trf(x): + return x - 1 + + elif mode in ("id", "none"): + + def conf_trf(x): + return x + + else: + raise ValueError(f"bad mode for {mode=}") + return conf_trf + + +def l2_dist(a, b, weight): + return (a - b).square().sum(dim=-1) * weight + + +def l1_dist(a, b, weight): + return (a - b).norm(dim=-1) * weight + + +ALL_DISTS = dict(l1=l1_dist, l2=l2_dist) + + +def signed_log1p(x): + sign = torch.sign(x) + return sign * torch.log1p(torch.abs(x)) + + +def signed_expm1(x): + sign = torch.sign(x) + return sign * torch.expm1(torch.abs(x)) + + +def cosine_schedule(t, lr_start, lr_end): + assert 0 <= t <= 1 + return lr_end + (lr_start - lr_end) * (1 + np.cos(t * np.pi)) / 2 + + +def linear_schedule(t, lr_start, lr_end): + assert 0 <= t <= 1 + return lr_start + (lr_end - lr_start) * t diff --git a/stream3r/dust3r/cloud_opt/init_im_poses.py b/stream3r/dust3r/cloud_opt/init_im_poses.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f8649767ab77a1df0f5fe88444f5a728d739a0 --- /dev/null +++ b/stream3r/dust3r/cloud_opt/init_im_poses.py @@ -0,0 +1,382 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Initialization functions for global alignment +# -------------------------------------------------------- +from functools import cache + +import cv2 +import numpy as np +import roma +import scipy.sparse as sp +import torch +from tqdm import tqdm + +from stream3r.dust3r.cloud_opt.commons import compute_edge_scores, edge_str, i_j_ij +from stream3r.dust3r.post_process import estimate_focal_knowing_depth +from stream3r.dust3r.utils.geometry import geotrf, get_med_dist_between_poses, inv +from stream3r.dust3r.viz import to_numpy + + +@torch.no_grad() +def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3): + device = self.device + + # indices of known poses + nkp, known_poses_msk, known_poses = get_known_poses(self) + assert nkp == self.n_imgs, "not all poses are known" + + # get all focals + nkf, _, im_focals = get_known_focals(self) + assert nkf == self.n_imgs + im_pp = self.get_principal_points() + + best_depthmaps = {} + # init all pairwise poses + for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)): + i_j = edge_str(i, j) + + # find relative pose for this pair + P1 = torch.eye(4, device=device) + msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1) + _, P2 = fast_pnp( + self.pred_j[i_j], + float(im_focals[i].mean()), + pp=im_pp[i], + msk=msk, + device=device, + niter_PnP=niter_PnP, + ) + + # align the two predicted camera with the two gt cameras + s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]]) + # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1 + # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3]) + self._set_pose(self.pw_poses, e, R, T, scale=s) + + # remember if this is a good depthmap + score = float(self.conf_i[i_j].mean()) + if score > best_depthmaps.get(i, (0,))[0]: + best_depthmaps[i] = score, i_j, s + + # init all image poses + for n in range(self.n_imgs): + assert known_poses_msk[n] + _, i_j, scale = best_depthmaps[n] + depth = self.pred_i[i_j][:, :, 2] + self._set_depthmap(n, depth * scale) + + +@torch.no_grad() +def init_minimum_spanning_tree(self, **kw): + """Init all camera poses (image-wise and pairwise poses) given + an initial set of pairwise estimations. + """ + device = self.device + pts3d, _, im_focals, im_poses = minimum_spanning_tree( + self.imshapes, + self.edges, + self.pred_i, + self.pred_j, + self.conf_i, + self.conf_j, + self.im_conf, + self.min_conf_thr, + device, + has_im_poses=self.has_im_poses, + verbose=self.verbose, + **kw, + ) + + return init_from_pts3d(self, pts3d, im_focals, im_poses) + + +def init_from_pts3d(self, pts3d, im_focals, im_poses): + # init poses + nkp, known_poses_msk, known_poses = get_known_poses(self) + if nkp == 1: + raise NotImplementedError( + "Would be simpler to just align everything afterwards on the single known pose" + ) + elif nkp > 1: + # global rigid SE3 alignment + s, R, T = align_multiple_poses( + im_poses[known_poses_msk], known_poses[known_poses_msk] + ) + trf = sRT_to_4x4(s, R, T, device=known_poses.device) + + # rotate everything + im_poses = trf @ im_poses + im_poses[:, :3, :3] /= s # undo scaling on the rotation part + for img_pts3d in pts3d: + img_pts3d[:] = geotrf(trf, img_pts3d) + + # set all pairwise poses + for e, (i, j) in enumerate(self.edges): + i_j = edge_str(i, j) + # compute transform that goes from cam to world + s, R, T = rigid_points_registration( + self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j] + ) + self._set_pose(self.pw_poses, e, R, T, scale=s) + + # take into account the scale normalization + s_factor = self.get_pw_norm_scale_factor() + im_poses[:, :3, 3] *= s_factor # apply downscaling factor + for img_pts3d in pts3d: + img_pts3d *= s_factor + + # init all image poses + if self.has_im_poses: + for i in range(self.n_imgs): + cam2world = im_poses[i] + depth = geotrf(inv(cam2world), pts3d[i])[..., 2] + self._set_depthmap(i, depth) + self._set_pose(self.im_poses, i, cam2world) + if im_focals[i] is not None: + self._set_focal(i, im_focals[i]) + + if self.verbose: + print(" init loss =", float(self())) + + +def minimum_spanning_tree( + imshapes, + edges, + pred_i, + pred_j, + conf_i, + conf_j, + im_conf, + min_conf_thr, + device, + has_im_poses=True, + niter_PnP=10, + verbose=True, +): + n_imgs = len(imshapes) + sparse_graph = -dict_to_sparse_graph( + compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j) + ) + msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo() + + # temp variable to store 3d points + pts3d = [None] * len(imshapes) + + todo = sorted(zip(-msp.data, msp.row, msp.col)) # sorted edges + im_poses = [None] * n_imgs + im_focals = [None] * n_imgs + + # init with strongest edge + score, i, j = todo.pop() + if verbose: + print(f" init edge ({i}*,{j}*) {score=}") + i_j = edge_str(i, j) + pts3d[i] = pred_i[i_j].clone() + pts3d[j] = pred_j[i_j].clone() + done = {i, j} + if has_im_poses: + im_poses[i] = torch.eye(4, device=device) + im_focals[i] = estimate_focal(pred_i[i_j]) + + # set initial pointcloud based on pairwise graph + msp_edges = [(i, j)] + while todo: + # each time, predict the next one + score, i, j = todo.pop() + + if im_focals[i] is None: + im_focals[i] = estimate_focal(pred_i[i_j]) + + if i in done: + if verbose: + print(f" init edge ({i},{j}*) {score=}") + assert j not in done + # align pred[i] with pts3d[i], and then set j accordingly + i_j = edge_str(i, j) + s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j]) + trf = sRT_to_4x4(s, R, T, device) + pts3d[j] = geotrf(trf, pred_j[i_j]) + done.add(j) + msp_edges.append((i, j)) + + if has_im_poses and im_poses[i] is None: + im_poses[i] = sRT_to_4x4(1, R, T, device) + + elif j in done: + if verbose: + print(f" init edge ({i}*,{j}) {score=}") + assert i not in done + i_j = edge_str(i, j) + s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j]) + trf = sRT_to_4x4(s, R, T, device) + pts3d[i] = geotrf(trf, pred_i[i_j]) + done.add(i) + msp_edges.append((i, j)) + + if has_im_poses and im_poses[i] is None: + im_poses[i] = sRT_to_4x4(1, R, T, device) + else: + # let's try again later + todo.insert(0, (score, i, j)) + + if has_im_poses: + # complete all missing informations + pair_scores = list( + sparse_graph.values() + ) # already negative scores: less is best + edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[ + np.argsort(pair_scores) + ] + for i, j in edges_from_best_to_worse.tolist(): + if im_focals[i] is None: + im_focals[i] = estimate_focal(pred_i[edge_str(i, j)]) + + for i in range(n_imgs): + if im_poses[i] is None: + msk = im_conf[i] > min_conf_thr + res = fast_pnp( + pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP + ) + if res: + im_focals[i], im_poses[i] = res + if im_poses[i] is None: + im_poses[i] = torch.eye(4, device=device) + im_poses = torch.stack(im_poses) + else: + im_poses = im_focals = None + + return pts3d, msp_edges, im_focals, im_poses + + +def dict_to_sparse_graph(dic): + n_imgs = max(max(e) for e in dic) + 1 + res = sp.dok_array((n_imgs, n_imgs)) + for edge, value in dic.items(): + res[edge] = value + return res + + +def rigid_points_registration(pts1, pts2, conf): + R, T, s = roma.rigid_points_registration( + pts1.reshape(-1, 3), + pts2.reshape(-1, 3), + weights=conf.ravel(), + compute_scaling=True, + ) + return s, R, T # return un-scaled (R, T) + + +def sRT_to_4x4(scale, R, T, device): + trf = torch.eye(4, device=device) + trf[:3, :3] = R * scale + trf[:3, 3] = T.ravel() # doesn't need scaling + return trf + + +def estimate_focal(pts3d_i, pp=None): + if pp is None: + H, W, THREE = pts3d_i.shape + assert THREE == 3 + pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device) + focal = estimate_focal_knowing_depth( + pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode="weiszfeld" + ).ravel() + return float(focal) + + +@cache +def pixel_grid(H, W): + return np.mgrid[:W, :H].T.astype(np.float32) + + +def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10, num_guessed_focals=100): + # extract camera poses and focals with RANSAC-PnP + if msk.sum() < 4: + return None, None # we need at least 4 points for PnP + pts3d, msk = map(to_numpy, (pts3d, msk)) + + H, W, THREE = pts3d.shape + assert THREE == 3 + pixels = pixel_grid(H, W) + + if focal is None: + S = max(W, H) + tentative_focals = np.geomspace(S / 2, S * 3, num=num_guessed_focals) + else: + tentative_focals = [focal] + + if pp is None: + pp = (W / 2, H / 2) + else: + pp = to_numpy(pp) + + best = (0,) + for focal in tentative_focals: + K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)]) + + try: # solvePnPRansac is not always solvable, especially when the predicted points are not very good + success, R, T, inliers = cv2.solvePnPRansac( + pts3d[msk], + pixels[msk], + K, + None, + iterationsCount=niter_PnP, + reprojectionError=5, + flags=cv2.SOLVEPNP_SQPNP, + ) + if not success: + continue + except cv2.error: + continue + + score = len(inliers) + if success and score > best[0]: + best = score, R, T, focal + + if not best[0]: + return None, None + + _, R, T, best_focal = best + R = cv2.Rodrigues(R)[0] # world to cam + R, T = map(torch.from_numpy, (R, T)) + return best_focal, inv(sRT_to_4x4(1, R, T, device)) # cam to world + + +def get_known_poses(self): + if self.has_im_poses: + known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses]) + known_poses = self.get_im_poses() + return known_poses_msk.sum(), known_poses_msk, known_poses + else: + return 0, None, None + + +def get_known_focals(self): + if self.has_im_poses: + known_focal_msk = self.get_known_focal_mask() + known_focals = self.get_focals() + return known_focal_msk.sum(), known_focal_msk, known_focals + else: + return 0, None, None + + +def align_multiple_poses(src_poses, target_poses): + N = len(src_poses) + assert src_poses.shape == target_poses.shape == (N, 4, 4) + + def center_and_z(poses): + eps = get_med_dist_between_poses(poses) / 100 + return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps * poses[:, :3, 2])) + + R, T, s = roma.rigid_points_registration( + center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True + ) + return s, R, T diff --git a/stream3r/dust3r/cloud_opt/modular_optimizer.py b/stream3r/dust3r/cloud_opt/modular_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..30a14fe9b132c05bd1c8cf516f446e56fea16d8b --- /dev/null +++ b/stream3r/dust3r/cloud_opt/modular_optimizer.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Slower implementation of the global alignment that allows to freeze partial poses/intrinsics +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn + +from stream3r.dust3r.cloud_opt.base_opt import BasePCOptimizer +from stream3r.dust3r.utils.device import to_cpu, to_numpy +from stream3r.dust3r.utils.geometry import depthmap_to_pts3d, geotrf + + +class ModularPointCloudOptimizer(BasePCOptimizer): + """Optimize a global scene, given a list of pairwise observations. + Unlike PointCloudOptimizer, you can fix parts of the optimization process (partial poses/intrinsics) + Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__( + self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs + ): + super().__init__(*args, **kwargs) + self.has_im_poses = True # by definition of this class + self.focal_brake = focal_brake + + # adding thing to optimize + self.im_depthmaps = nn.ParameterList( + torch.randn(H, W) / 10 - 3 for H, W in self.imshapes + ) # log(depth) + self.im_poses = nn.ParameterList( + self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs) + ) # camera poses + default_focals = [ + self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes + ] + self.im_focals = nn.ParameterList( + torch.FloatTensor([f, f] if fx_and_fy else [f]) for f in default_focals + ) # camera intrinsics + self.im_pp = nn.ParameterList( + torch.zeros((2,)) for _ in range(self.n_imgs) + ) # camera intrinsics + self.im_pp.requires_grad_(optimize_pp) + + def preset_pose(self, known_poses, pose_msk=None): # cam-to-world + if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2: + known_poses = [known_poses] + for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses): + if self.verbose: + print(f" (setting pose #{idx} = {pose[:3,3]})") + self._no_grad( + self._set_pose(self.im_poses, idx, torch.tensor(pose), force=True) + ) + + # normalize scale if there's less than 1 known pose + n_known_poses = sum((p.requires_grad is False) for p in self.im_poses) + self.norm_pw_scale = n_known_poses <= 1 + + def preset_intrinsics(self, known_intrinsics, msk=None): + if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2: + known_intrinsics = [known_intrinsics] + for K in known_intrinsics: + assert K.shape == (3, 3) + self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk) + self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk) + + def preset_focal(self, known_focals, msk=None): + for idx, focal in zip(self._get_msk_indices(msk), known_focals): + if self.verbose: + print(f" (setting focal #{idx} = {focal})") + self._no_grad(self._set_focal(idx, focal, force=True)) + + def preset_principal_point(self, known_pp, msk=None): + for idx, pp in zip(self._get_msk_indices(msk), known_pp): + if self.verbose: + print(f" (setting principal point #{idx} = {pp})") + self._no_grad(self._set_principal_point(idx, pp, force=True)) + + def _no_grad(self, tensor): + return tensor.requires_grad_(False) + + def _get_msk_indices(self, msk): + if msk is None: + return range(self.n_imgs) + elif isinstance(msk, int): + return [msk] + elif isinstance(msk, (tuple, list)): + return self._get_msk_indices(np.array(msk)) + elif msk.dtype in (bool, torch.bool, np.bool_): + assert len(msk) == self.n_imgs + return np.where(msk)[0] + elif np.issubdtype(msk.dtype, np.integer): + return msk + else: + raise ValueError(f"bad {msk=}") + + def _set_focal(self, idx, focal, force=False): + param = self.im_focals[idx] + if ( + param.requires_grad or force + ): # can only init a parameter not already initialized + param.data[:] = self.focal_brake * np.log(focal) + return param + + def get_focals(self): + log_focals = torch.stack(list(self.im_focals), dim=0) + return (log_focals / self.focal_brake).exp() + + def _set_principal_point(self, idx, pp, force=False): + param = self.im_pp[idx] + H, W = self.imshapes[idx] + if ( + param.requires_grad or force + ): # can only init a parameter not already initialized + param.data[:] = to_cpu(to_numpy(pp) - (W / 2, H / 2)) / 10 + return param + + def get_principal_points(self): + return torch.stack( + [ + pp.new((W / 2, H / 2)) + 10 * pp + for pp, (H, W) in zip(self.im_pp, self.imshapes) + ] + ) + + def get_intrinsics(self): + K = torch.zeros((self.n_imgs, 3, 3), device=self.device) + focals = self.get_focals().view(self.n_imgs, -1) + K[:, 0, 0] = focals[:, 0] + K[:, 1, 1] = focals[:, -1] + K[:, :2, 2] = self.get_principal_points() + K[:, 2, 2] = 1 + return K + + def get_im_poses(self): # cam to world + cam2world = self._get_poses(torch.stack(list(self.im_poses))) + return cam2world + + def _set_depthmap(self, idx, depth, force=False): + param = self.im_depthmaps[idx] + if ( + param.requires_grad or force + ): # can only init a parameter not already initialized + param.data[:] = depth.log().nan_to_num(neginf=0) + return param + + def get_depthmaps(self): + return [d.exp() for d in self.im_depthmaps] + + def depth_to_pts3d(self): + # Get depths and projection params if not provided + focals = self.get_focals() + pp = self.get_principal_points() + im_poses = self.get_im_poses() + depth = self.get_depthmaps() + + # convert focal to (1,2,H,W) constant field + def focal_ex(i): + return focals[i][..., None, None].expand( + 1, *focals[i].shape, *self.imshapes[i] + ) + + # get pointmaps in camera frame + rel_ptmaps = [ + depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i : i + 1])[0] + for i in range(im_poses.shape[0]) + ] + # project to world frame + return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)] + + def get_pts3d(self): + return self.depth_to_pts3d() diff --git a/stream3r/dust3r/cloud_opt/optimizer.py b/stream3r/dust3r/cloud_opt/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ac3d84916951c190fa780bcf093f55960a34ee --- /dev/null +++ b/stream3r/dust3r/cloud_opt/optimizer.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Main class for the implementation of the global alignment +# -------------------------------------------------------- +import numpy as np +import torch +import torch.nn as nn + +from stream3r.dust3r.cloud_opt.base_opt import BasePCOptimizer +from stream3r.dust3r.utils.device import to_cpu, to_numpy +from stream3r.dust3r.utils.geometry import geotrf, xy_grid + + +class PointCloudOptimizer(BasePCOptimizer): + """Optimize a global scene, given a list of pairwise observations. + Graph node: images + Graph edges: observations = (pred1, pred2) + """ + + def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs): + super().__init__(*args, **kwargs) + + self.has_im_poses = True # by definition of this class + self.focal_break = focal_break + + # adding thing to optimize + self.im_depthmaps = nn.ParameterList( + torch.randn(H, W) / 10 - 3 for H, W in self.imshapes + ) # log(depth) + self.im_poses = nn.ParameterList( + self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs) + ) # camera poses + self.im_focals = nn.ParameterList( + torch.FloatTensor([self.focal_break * np.log(max(H, W))]) + for H, W in self.imshapes + ) # camera intrinsics + self.im_pp = nn.ParameterList( + torch.zeros((2,)) for _ in range(self.n_imgs) + ) # camera intrinsics + self.im_pp.requires_grad_(optimize_pp) + + self.imshape = self.imshapes[0] + im_areas = [h * w for h, w in self.imshapes] + self.max_area = max(im_areas) + + # adding thing to optimize + self.im_depthmaps = ParameterStack( + self.im_depthmaps, is_param=True, fill=self.max_area + ) + self.im_poses = ParameterStack(self.im_poses, is_param=True) + self.im_focals = ParameterStack(self.im_focals, is_param=True) + self.im_pp = ParameterStack(self.im_pp, is_param=True) + self.register_buffer( + "_pp", torch.tensor([(w / 2, h / 2) for h, w in self.imshapes]) + ) + self.register_buffer( + "_grid", + ParameterStack( + [xy_grid(W, H, device=self.device) for H, W in self.imshapes], + fill=self.max_area, + ), + ) + + # pre-compute pixel weights + self.register_buffer( + "_weight_i", + ParameterStack( + [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], + fill=self.max_area, + ), + ) + self.register_buffer( + "_weight_j", + ParameterStack( + [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], + fill=self.max_area, + ), + ) + + # precompute aa + self.register_buffer( + "_stacked_pred_i", + ParameterStack(self.pred_i, self.str_edges, fill=self.max_area), + ) + self.register_buffer( + "_stacked_pred_j", + ParameterStack(self.pred_j, self.str_edges, fill=self.max_area), + ) + self.register_buffer("_ei", torch.tensor([i for i, j in self.edges])) + self.register_buffer("_ej", torch.tensor([j for i, j in self.edges])) + self.total_area_i = sum([im_areas[i] for i, j in self.edges]) + self.total_area_j = sum([im_areas[j] for i, j in self.edges]) + + def _check_all_imgs_are_selected(self, msk): + assert np.all( + self._get_msk_indices(msk) == np.arange(self.n_imgs) + ), "incomplete mask!" + + def preset_pose(self, known_poses, pose_msk=None): # cam-to-world + self._check_all_imgs_are_selected(pose_msk) + + if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2: + known_poses = [known_poses] + for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses): + if self.verbose: + print(f" (setting pose #{idx} = {pose[:3,3]})") + self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose))) + + # normalize scale if there's less than 1 known pose + n_known_poses = sum((p.requires_grad is False) for p in self.im_poses) + self.norm_pw_scale = n_known_poses <= 1 + + self.im_poses.requires_grad_(False) + self.norm_pw_scale = False + + def preset_focal(self, known_focals, msk=None): + self._check_all_imgs_are_selected(msk) + + for idx, focal in zip(self._get_msk_indices(msk), known_focals): + if self.verbose: + print(f" (setting focal #{idx} = {focal})") + self._no_grad(self._set_focal(idx, focal)) + + self.im_focals.requires_grad_(False) + + def preset_principal_point(self, known_pp, msk=None): + self._check_all_imgs_are_selected(msk) + + for idx, pp in zip(self._get_msk_indices(msk), known_pp): + if self.verbose: + print(f" (setting principal point #{idx} = {pp})") + self._no_grad(self._set_principal_point(idx, pp)) + + self.im_pp.requires_grad_(False) + + def _get_msk_indices(self, msk): + if msk is None: + return range(self.n_imgs) + elif isinstance(msk, int): + return [msk] + elif isinstance(msk, (tuple, list)): + return self._get_msk_indices(np.array(msk)) + elif msk.dtype in (bool, torch.bool, np.bool_): + assert len(msk) == self.n_imgs + return np.where(msk)[0] + elif np.issubdtype(msk.dtype, np.integer): + return msk + else: + raise ValueError(f"bad {msk=}") + + def _no_grad(self, tensor): + assert ( + tensor.requires_grad + ), "it must be True at this point, otherwise no modification occurs" + + def _set_focal(self, idx, focal, force=False): + param = self.im_focals[idx] + if ( + param.requires_grad or force + ): # can only init a parameter not already initialized + param.data[:] = self.focal_break * np.log(focal) + return param + + def get_focals(self): + log_focals = torch.stack(list(self.im_focals), dim=0) + return (log_focals / self.focal_break).exp() + + def get_known_focal_mask(self): + return torch.tensor([not (p.requires_grad) for p in self.im_focals]) + + def _set_principal_point(self, idx, pp, force=False): + param = self.im_pp[idx] + H, W = self.imshapes[idx] + if ( + param.requires_grad or force + ): # can only init a parameter not already initialized + param.data[:] = to_cpu(to_numpy(pp) - (W / 2, H / 2)) / 10 + return param + + def get_principal_points(self): + return self._pp + 10 * self.im_pp + + def get_intrinsics(self): + K = torch.zeros((self.n_imgs, 3, 3), device=self.device) + focals = self.get_focals().flatten() + K[:, 0, 0] = K[:, 1, 1] = focals + K[:, :2, 2] = self.get_principal_points() + K[:, 2, 2] = 1 + return K + + def get_im_poses(self): # cam to world + cam2world = self._get_poses(self.im_poses) + return cam2world + + def _set_depthmap(self, idx, depth, force=False): + depth = _ravel_hw(depth, self.max_area) + + param = self.im_depthmaps[idx] + if ( + param.requires_grad or force + ): # can only init a parameter not already initialized + param.data[:] = depth.log().nan_to_num(neginf=0) + return param + + def get_depthmaps(self, raw=False): + res = self.im_depthmaps.exp() + if not raw: + res = [dm[: h * w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def depth_to_pts3d(self): + # Get depths and projection params if not provided + focals = self.get_focals() + pp = self.get_principal_points() + im_poses = self.get_im_poses() + depth = self.get_depthmaps(raw=True) + + # get pointmaps in camera frame + rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp) + # project to world frame + return geotrf(im_poses, rel_ptmaps) + + def get_pts3d(self, raw=False): + res = self.depth_to_pts3d() + if not raw: + res = [dm[: h * w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)] + return res + + def forward(self): + pw_poses = self.get_pw_poses() # cam-to-world + pw_adapt = self.get_adaptors().unsqueeze(1) + proj_pts3d = self.get_pts3d(raw=True) + + # rotate pairwise prediction according to pw_poses + aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i) + aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j) + + # compute the less + li = ( + self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() + / self.total_area_i + ) + lj = ( + self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() + / self.total_area_j + ) + + return li + lj + + +def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp): + pp = pp.unsqueeze(1) + focal = focal.unsqueeze(1) + assert focal.shape == (len(depth), 1, 1) + assert pp.shape == (len(depth), 1, 2) + assert pixel_grid.shape == depth.shape + (2,) + depth = depth.unsqueeze(-1) + return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1) + + +def ParameterStack(params, keys=None, is_param=None, fill=0): + if keys is not None: + params = [params[k] for k in keys] + + if fill > 0: + params = [_ravel_hw(p, fill) for p in params] + + requires_grad = params[0].requires_grad + assert all(p.requires_grad == requires_grad for p in params) + + params = torch.stack(list(params)).float().detach() + if is_param or requires_grad: + params = nn.Parameter(params) + params.requires_grad_(requires_grad) + return params + + +def _ravel_hw(tensor, fill=0): + # ravel H,W + tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:]) + + if len(tensor) < fill: + tensor = torch.cat( + (tensor, tensor.new_zeros((fill - len(tensor),) + tensor.shape[1:])) + ) + return tensor + + +def acceptable_focal_range(H, W, minf=0.5, maxf=3.5): + focal_base = max(H, W) / ( + 2 * np.tan(np.deg2rad(60) / 2) + ) # size / 1.1547005383792515 + return minf * focal_base, maxf * focal_base + + +def apply_mask(img, msk): + img = img.copy() + img[msk] = 0 + return img diff --git a/stream3r/dust3r/cloud_opt/pair_viewer.py b/stream3r/dust3r/cloud_opt/pair_viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..43d62074be850d622d94fe3e5307f911e9bb31d9 --- /dev/null +++ b/stream3r/dust3r/cloud_opt/pair_viewer.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dummy optimizer for visualizing pairs +# -------------------------------------------------------- +import cv2 +import numpy as np +import torch +import torch.nn as nn + +from stream3r.dust3r.cloud_opt.base_opt import BasePCOptimizer +from stream3r.dust3r.cloud_opt.commons import edge_str +from stream3r.dust3r.post_process import estimate_focal_knowing_depth +from stream3r.dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, geotrf, inv + + +class PairViewer(BasePCOptimizer): + """ + This a Dummy Optimizer. + To use only when the goal is to visualize the results for a pair of images (with is_symmetrized) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.is_symmetrized and self.n_edges == 2 + self.has_im_poses = True + + # compute all parameters directly from raw input + self.focals = [] + self.pp = [] + rel_poses = [] + confs = [] + for i in range(self.n_imgs): + conf = float( + self.conf_i[edge_str(i, 1 - i)].mean() + * self.conf_j[edge_str(i, 1 - i)].mean() + ) + if self.verbose: + print(f" - {conf=:.3} for edge {i}-{1-i}") + confs.append(conf) + + H, W = self.imshapes[i] + pts3d = self.pred_i[edge_str(i, 1 - i)] + pp = torch.tensor((W / 2, H / 2)) + focal = float( + estimate_focal_knowing_depth(pts3d[None], pp, focal_mode="weiszfeld") + ) + self.focals.append(focal) + self.pp.append(pp) + + # estimate the pose of pts1 in image 2 + pixels = np.mgrid[:W, :H].T.astype(np.float32) + pts3d = self.pred_j[edge_str(1 - i, i)].numpy() + assert pts3d.shape[:2] == (H, W) + msk = self.get_masks()[i].numpy() + K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)]) + + try: + res = cv2.solvePnPRansac( + pts3d[msk], + pixels[msk], + K, + None, + iterationsCount=100, + reprojectionError=5, + flags=cv2.SOLVEPNP_SQPNP, + ) + success, R, T, inliers = res + assert success + + R = cv2.Rodrigues(R)[0] # world to cam + pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world + except: + pose = np.eye(4) + rel_poses.append(torch.from_numpy(pose.astype(np.float32))) + + # let's use the pair with the most confidence + if confs[0] > confs[1]: + # ptcloud is expressed in camera1 + self.im_poses = [torch.eye(4), rel_poses[1]] # I, cam2-to-cam1 + self.depth = [ + self.pred_i["0_1"][..., 2], + geotrf(inv(rel_poses[1]), self.pred_j["0_1"])[..., 2], + ] + else: + # ptcloud is expressed in camera2 + self.im_poses = [rel_poses[0], torch.eye(4)] # I, cam1-to-cam2 + self.depth = [ + geotrf(inv(rel_poses[0]), self.pred_j["1_0"])[..., 2], + self.pred_i["1_0"][..., 2], + ] + + self.im_poses = nn.Parameter( + torch.stack(self.im_poses, dim=0), requires_grad=False + ) + self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False) + self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False) + self.depth = nn.ParameterList(self.depth) + for p in self.parameters(): + p.requires_grad = False + + def _set_depthmap(self, idx, depth, force=False): + if self.verbose: + print("_set_depthmap is ignored in PairViewer") + return + + def get_depthmaps(self, raw=False): + depth = [d.to(self.device) for d in self.depth] + return depth + + def _set_focal(self, idx, focal, force=False): + self.focals[idx] = focal + + def get_focals(self): + return self.focals + + def get_known_focal_mask(self): + return torch.tensor([not (p.requires_grad) for p in self.focals]) + + def get_principal_points(self): + return self.pp + + def get_intrinsics(self): + focals = self.get_focals() + pps = self.get_principal_points() + K = torch.zeros((len(focals), 3, 3), device=self.device) + for i in range(len(focals)): + K[i, 0, 0] = K[i, 1, 1] = focals[i] + K[i, :2, 2] = pps[i] + K[i, 2, 2] = 1 + return K + + def get_im_poses(self): + return self.im_poses + + def depth_to_pts3d(self): + pts3d = [] + for d, intrinsics, im_pose in zip( + self.depth, self.get_intrinsics(), self.get_im_poses() + ): + pts, _ = depthmap_to_absolute_camera_coordinates( + d.cpu().numpy(), intrinsics.cpu().numpy(), im_pose.cpu().numpy() + ) + pts3d.append(torch.from_numpy(pts).to(device=self.device)) + return pts3d + + def forward(self): + return float("nan") diff --git a/stream3r/dust3r/datasets/__init__.py b/stream3r/dust3r/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f0bd1d37f10d95ea5b32154601eead55620c545 --- /dev/null +++ b/stream3r/dust3r/datasets/__init__.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +from .utils.transforms import * +from .base.batched_sampler import BatchedRandomSampler # noqa + + +def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True, persistent_workers=False, multiprocessing_context=None): + import torch + from croco.utils.misc import get_world_size, get_rank + + # pytorch dataset + if isinstance(dataset, str): + dataset = eval(dataset) + + world_size = get_world_size() + rank = get_rank() + + try: + sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size, + rank=rank, drop_last=drop_last) + except (AttributeError, NotImplementedError): + # not avail for this dataset + if torch.distributed.is_initialized(): + sampler = torch.utils.data.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last + ) + elif shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_mem, + drop_last=drop_last, + persistent_workers=persistent_workers, + multiprocessing_context=multiprocessing_context, + ) + + return data_loader \ No newline at end of file diff --git a/stream3r/dust3r/datasets/aria/__init__.py b/stream3r/dust3r/datasets/aria/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a49a3a9427e67836de554fed0bc7f6466adbe06 --- /dev/null +++ b/stream3r/dust3r/datasets/aria/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + diff --git a/stream3r/dust3r/datasets/aria/camera_utils.py b/stream3r/dust3r/datasets/aria/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d06211965ad65ad6dff131a5a8bdbb0a95af664f --- /dev/null +++ b/stream3r/dust3r/datasets/aria/camera_utils.py @@ -0,0 +1,471 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 +import numpy as np +import os +import torch + + +def undistort_fisheye_to_pinhole_rgbd( + fisheye_img, fisheye_depth, fisheye_params, pinhole_params +): + """ + Undistort fisheye images and depth into pinhole camera images and depth. + Inputs: + fisheye_img: HxWx3 numpy array of the fisheye image + fisheye_depth: HxW numpy array of the fisheye depth image + fisheye_params: Bx(T)x16 tensor of Fisheye624 parameters + pinhole_params: Bx(T)x4 tensor of Pinhole parameters + Outputs: + pinhole_image: HxWx3 numpy array of pinhole image + pinhole_depth: HxW numpy array of pinhole depth image + """ + # Create a grid of (u, v) coordinates + h, w, _ = fisheye_img.shape + u, v = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy") + u, v = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy") + uv_grid = torch.hstack([u.reshape(-1, 1), v.reshape(-1, 1)]).reshape(-1, 2).float() + uv_grid = uv_grid[None, None] # Reshape to Bx(T)xNx2 + + # Unproject pinhole image points to 3D rays + rays = pinhole_unproject(uv_grid, pinhole_params) + + # Project 3D rays onto fisheye image plane + fisheye_uv = fisheye624_project(rays, fisheye_params) + + # Reshape the coordinates to the original image size + fisheye_uv = fisheye_uv.reshape(h, w, 2) + + # Convert the coordinates to a NumPy array + fisheye_uv_np = fisheye_uv.numpy() + + # Assuming `rays` is a Bx(T)xNx3 tensor of 3D ray vectors and `depth` is a Bx(T)xN tensor of ray distances + fisheye_rays = fisheye624_unproject(uv_grid, fisheye_params) + rays_normalized = torch.nn.functional.normalize( + fisheye_rays, dim=-1 + ) # Normalize the rays to unit length + + # The Z-axis depth is the length of the projection of the ray onto the Z-axis + # This is equivalent to the dot product of the ray with the Z-axis, since the rays are normalized + z_axis = torch.tensor([0, 0, 1]).to(rays.device) # The Z-axis vector + # Reshape depth to match the last dimension of rays_normalized + + z_depth = torch.sum(rays_normalized * z_axis, dim=-1) * fisheye_depth.reshape(-1) + z_depth = z_depth.reshape(fisheye_depth.shape).unsqueeze(-1) + z_depth = z_depth.numpy() + # Now `z_depth` is a Bx(T)xN tensor of Z depth values + + # Map the color values from the fisheye image to the pinhole image + pinhole_image = cv2.remap( + fisheye_img, + fisheye_uv_np[..., 0], + fisheye_uv_np[..., 1], + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0, + ) + + # Map the depth values from the fisheye depth image to the pinhole depth image + pinhole_depth = cv2.remap( + z_depth, + fisheye_uv_np[..., 0], + fisheye_uv_np[..., 1], + interpolation=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0, + ) + + return pinhole_image, pinhole_depth + + +class VignetteCorrector: + """ + A class to apply vignette correction to an RGB image. + """ + + def __init__(self, vignette_file=None): + """ + Initialize the VignetteCorrector with a vignette file. + Args: + vignette_file (str): The path to the vignette file. + """ + vignette_file = vignette_file or os.path.join(os.path.dirname(__file__), "vignette_imx577.png") + self.vignette = cv2.imread(vignette_file) + self.vignette = self.vignette / 255.0 + self.vignette = torch.from_numpy(self.vignette).float() + + def correct(self, rgb_image): + """ + Apply vignette correction to an RGB image. + Args: + rgb_image : The input RGB image. + Returns: + numpy.array : The corrected RGB image tensor. + """ + result_image = torch.from_numpy( + rgb_image + ).float() # Convert rgb_image to a PyTorch tensor + result_image = result_image / torch.clamp(self.vignette, min=1e-3) + result_image = result_image.clamp(0.0, 255.0) + # set resulting image to 0 at the pixels where vigenette is 0 + result_image = result_image * (self.vignette != 0.0) * 1.0 + return result_image.numpy().astype(np.float32) + + +# Source of the next methods is: +# https://github.com/nerfstudio-project/nerfstudio/blob/d1fc2ee33863071aa03c6679595d554d67246258/nerfstudio/cameras/camera_utils.py + + +def sign_plus(x): + """ + return +1 for positive and for 0.0 in x. This is important for our handling + of z values that should never be 0.0 + """ + sgn = torch.ones_like(x) + sgn[sgn < 0.0] = -1.0 + return sgn + + +@torch.jit.script +def fisheye624_project(xyz, params): + """ + Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera + model project() function. + + Inputs: + xyz: Bx(T)xNx3 tensor of 3D points to be projected + params: Bx(T)x16 tensor of Fisheye624 parameters formatted like this: + [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] + or Bx(T)x15 tensor of Fisheye624 parameters formatted like this: + [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] + Outputs: + uv: Bx(T)xNx2 tensor of 2D projections of xyz in image plane + + Model for fisheye cameras with radial, tangential, and thin-prism distortion. + This model allows fu != fv. + Specifically, the model is: + uvDistorted = [x_r] + tangentialDistortion + thinPrismDistortion + [y_r] + proj = diag(fu,fv) * uvDistorted + [cu;cv]; + where: + a = x/z, b = y/z, r = (a^2+b^2)^(1/2) + th = atan(r) + cosPhi = a/r, sinPhi = b/r + [x_r] = (th+ k0 * th^3 + k1* th^5 + ...) [cosPhi] + [y_r] [sinPhi] + the number of terms in the series is determined by the template parameter numK. + tangentialDistortion = [(2 x_r^2 + rd^2)*p_0 + 2*x_r*y_r*p_1] + [(2 y_r^2 + rd^2)*p_1 + 2*x_r*y_r*p_0] + where rd^2 = x_r^2 + y_r^2 + thinPrismDistortion = [s0 * rd^2 + s1 rd^4] + [s2 * rd^2 + s3 rd^4] + + Author: Daniel DeTone (ddetone) + """ + + assert (xyz.ndim == 3 and params.ndim == 2) or ( + xyz.ndim == 4 and params.ndim == 3 + ), f"point dim {xyz.shape} does not match cam parameter dim {params}" + assert xyz.shape[-1] == 3 + assert ( + params.shape[-1] == 16 or params.shape[-1] == 15 + ), "This model allows fx != fy" + assert xyz.dtype == params.dtype, "data type must match" + + eps = 1e-9 + T = -1 + if xyz.ndim == 4: + # has T dim + T, N = xyz.shape[1], xyz.shape[2] + xyz = xyz.reshape(-1, N, 3) # (BxT)xNx3 + params = params.reshape(-1, params.shape[-1]) # (BxT)x16 + + B, N = xyz.shape[0], xyz.shape[1] + + # Radial correction. + z = xyz[:, :, 2].reshape(B, N, 1) + # Do not use torch.sign(z) it leads to 0.0 zs if z == 0.0 which leads to a + # nan when we compute xy/z + z = torch.where(torch.abs(z) < eps, eps * sign_plus(z), z) + ab = xyz[:, :, :2] / z + # make sure abs are not too small or 0 otherwise gradients are nan + ab = torch.where(torch.abs(ab) < eps, eps * sign_plus(ab), ab) + r = torch.norm(ab, dim=-1, p=2, keepdim=True) + th = torch.atan(r) + th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r) + th_k = th.reshape(B, N, 1).clone() + for i in range(6): + th_k = th_k + params[:, -12 + i].reshape(B, 1, 1) * torch.pow(th, 3 + i * 2) + xr_yr = th_k * th_divr + uv_dist = xr_yr + + # Tangential correction. + p0 = params[:, -6].reshape(B, 1) + p1 = params[:, -5].reshape(B, 1) + xr = xr_yr[:, :, 0].reshape(B, N) + yr = xr_yr[:, :, 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) + yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + uv_dist_tu = uv_dist[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1) + uv_dist_tv = uv_dist[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0) + uv_dist = torch.stack( + [uv_dist_tu, uv_dist_tv], dim=-1 + ) # Avoids in-place complaint. + + # Thin Prism correction. + s0 = params[:, -4].reshape(B, 1) + s1 = params[:, -3].reshape(B, 1) + s2 = params[:, -2].reshape(B, 1) + s3 = params[:, -1].reshape(B, 1) + rd_4 = torch.square(rd_sq) + uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4) + uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4) + + # Finally, apply standard terms: focal length and camera centers. + if params.shape[-1] == 15: + fx_fy = params[:, 0].reshape(B, 1, 1) + cx_cy = params[:, 1:3].reshape(B, 1, 2) + else: + fx_fy = params[:, 0:2].reshape(B, 1, 2) + cx_cy = params[:, 2:4].reshape(B, 1, 2) + result = uv_dist * fx_fy + cx_cy + + if T > 0: + result = result.reshape(B // T, T, N, 2) + + assert result.ndim == 4 or result.ndim == 3 + assert result.shape[-1] == 2 + + return result + + +@torch.jit.script +def fisheye624_unproject(uv, params, max_iters: int = 5): + """ + Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera + model. There is no analytical solution for the inverse of the project() + function so this solves an optimization problem using Newton's method to get + the inverse. + + Inputs: + uv: Bx(T)xNx2 tensor of 2D pixels to be projected + params: Bx(T)x16 tensor of Fisheye624 parameters formatted like this: + [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] + or Bx(T)x15 tensor of Fisheye624 parameters formatted like this: + [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] + Outputs: + xyz: Bx(T)xNx3 tensor of 3D rays of uv points with z = 1. + + Model for fisheye cameras with radial, tangential, and thin-prism distortion. + This model assumes fu=fv. This unproject function holds that: + + X = unproject(project(X)) [for X=(x,y,z) in R^3, z>0] + + and + + x = project(unproject(s*x)) [for s!=0 and x=(u,v) in R^2] + + Author: Daniel DeTone (ddetone) + """ + # Note(nyn): The unprojection sometimes results in NaNs when using Float32. + # A temporary workaround in Perveiver is passing in Float64 (double) parameters. + + assert uv.ndim == 3 or uv.ndim == 4, "Expected batched input shaped Bx(T)xNx2" + assert uv.shape[-1] == 2 + assert ( + params.ndim == 2 or params.ndim == 3 + ), "Expected batched input shaped Bx(T)x16 or Bx(T)x15" + assert ( + params.shape[-1] == 16 or params.shape[-1] == 15 + ), "This model allows fx != fy" + assert str(uv.dtype) == str( + params.dtype + ), f"data type must match {uv.dtype} <> {params.dtype}" + eps = 1e-6 + + T = -1 + if uv.ndim == 4: + # has T dim + T, N = uv.shape[1], uv.shape[2] + uv = uv.reshape(-1, N, 2) # (BxT)xNx2 + params = params.reshape(-1, params.shape[-1]) # (BxT)x16 + params = params.reshape(-1, params.shape[-1]) # (BxT)x16 + + B, N = uv.shape[0], uv.shape[1] + + if params.shape[-1] == 15: + fx_fy = params[:, 0].reshape(B, 1, 1) + cx_cy = params[:, 1:3].reshape(B, 1, 2) + else: + fx_fy = params[:, 0:2].reshape(B, 1, 2) + cx_cy = params[:, 2:4].reshape(B, 1, 2) + + uv_dist = (uv - cx_cy) / fx_fy + + # Compute xr_yr using Newton's method. + xr_yr = uv_dist.clone() # Initial guess. + for _ in range(max_iters): + uv_dist_est = xr_yr.clone() + # Tangential terms. + p0 = params[:, -6].reshape(B, 1) + p1 = params[:, -5].reshape(B, 1) + xr = xr_yr[:, :, 0].reshape(B, N) + yr = xr_yr[:, :, 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) + yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + ( + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + ) + uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + ( + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + ) + # Thin Prism terms. + s0 = params[:, -4].reshape(B, 1) + s1 = params[:, -3].reshape(B, 1) + s2 = params[:, -2].reshape(B, 1) + s3 = params[:, -1].reshape(B, 1) + rd_4 = torch.square(rd_sq) + uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4) + uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4) + # Compute the derivative of uv_dist w.r.t. xr_yr. + duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2) + duv_dist_dxr_yr[:, :, 0, 0] = ( + 1.0 + 6.0 * xr_yr[:, :, 0] * p0 + 2.0 * xr_yr[:, :, 1] * p1 + ) + offdiag = 2.0 * (xr_yr[:, :, 0] * p1 + xr_yr[:, :, 1] * p0) + duv_dist_dxr_yr[:, :, 0, 1] = offdiag + duv_dist_dxr_yr[:, :, 1, 0] = offdiag + duv_dist_dxr_yr[:, :, 1, 1] = ( + 1.0 + 6.0 * xr_yr[:, :, 1] * p1 + 2.0 * xr_yr[:, :, 0] * p0 + ) + xr_yr_sq_norm = xr_yr_sq[:, :, 0] + xr_yr_sq[:, :, 1] + temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm) + duv_dist_dxr_yr[:, :, 0, 0] = duv_dist_dxr_yr[:, :, 0, 0] + ( + xr_yr[:, :, 0] * temp1 + ) + duv_dist_dxr_yr[:, :, 0, 1] = duv_dist_dxr_yr[:, :, 0, 1] + ( + xr_yr[:, :, 1] * temp1 + ) + temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm) + duv_dist_dxr_yr[:, :, 1, 0] = duv_dist_dxr_yr[:, :, 1, 0] + ( + xr_yr[:, :, 0] * temp2 + ) + duv_dist_dxr_yr[:, :, 1, 1] = duv_dist_dxr_yr[:, :, 1, 1] + ( + xr_yr[:, :, 1] * temp2 + ) + # Compute 2x2 inverse manually here since torch.inverse() is very slow. + # Because this is slow: inv = duv_dist_dxr_yr.inverse() + # About a 10x reduction in speed with above line. + mat = duv_dist_dxr_yr.reshape(-1, 2, 2) + a = mat[:, 0, 0].reshape(-1, 1, 1) + b = mat[:, 0, 1].reshape(-1, 1, 1) + c = mat[:, 1, 0].reshape(-1, 1, 1) + d = mat[:, 1, 1].reshape(-1, 1, 1) + det = 1.0 / ((a * d) - (b * c)) + top = torch.cat([d, -b], dim=2) + bot = torch.cat([-c, a], dim=2) + inv = det * torch.cat([top, bot], dim=1) + inv = inv.reshape(B, N, 2, 2) + # Manually compute 2x2 @ 2x1 matrix multiply. + # Because this is slow: step = (inv @ (uv_dist - uv_dist_est)[..., None])[..., 0] + diff = uv_dist - uv_dist_est + a = inv[:, :, 0, 0] + b = inv[:, :, 0, 1] + c = inv[:, :, 1, 0] + d = inv[:, :, 1, 1] + e = diff[:, :, 0] + f = diff[:, :, 1] + step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) + # Newton step. + xr_yr = xr_yr + step + + # Compute theta using Newton's method. + xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1) + th = xr_yr_norm.clone() + for _ in range(max_iters): + th_radial = uv.new_ones(B, N, 1) + dthd_th = uv.new_ones(B, N, 1) + for k in range(6): + r_k = params[:, -12 + k].reshape(B, 1, 1) + th_radial = th_radial + (r_k * torch.pow(th, 2 + k * 2)) + dthd_th = dthd_th + ((3.0 + 2.0 * k) * r_k * torch.pow(th, 2 + k * 2)) + th_radial = th_radial * th + step = (xr_yr_norm - th_radial) / dthd_th + # handle dthd_th close to 0. + step = torch.where(dthd_th.abs() > eps, step, sign_plus(step) * eps * 10.0) + th = th + step + # Compute the ray direction using theta and xr_yr. + close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps) + ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr) + ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2) + assert ray.shape[-1] == 3 + + if T > 0: + ray = ray.reshape(B // T, T, N, 3) + + return ray + + +def pinhole_project(xyz, params): + """ + Batched implementation of the Pinhole (aka Linear) camera + model project() function. + + Inputs: + xyz: Bx(T)xNx3 tensor of 3D points to be projected + params: Bx(T)x4 tensor of Pinhole parameters formatted like this: + [f_u f_v c_u c_v] + Outputs: + uv: Bx(T)xNx2 tensor of 2D projections of xyz in image plane + """ + + assert (xyz.ndim == 3 and params.ndim == 2) or (xyz.ndim == 4 and params.ndim == 3) + assert params.shape[-1] == 4 + eps = 1e-9 + + # Focal length and principal point + fx_fy = params[..., 0:2].reshape(*xyz.shape[:-2], 1, 2) + cx_cy = params[..., 2:4].reshape(*xyz.shape[:-2], 1, 2) + # Make sure depth is not too close to zero. + z = xyz[..., 2:] + # Do not use torch.sign(z) it leads to 0.0 zs if z == 0.0 which leads to a + # nan when we compute xy/z + z = torch.where(torch.abs(z) < eps, eps * sign_plus(z), z) + uv = (xyz[..., :2] / z) * fx_fy + cx_cy + return uv + + +def pinhole_unproject(uv, params, max_iters: int = 5): + """ + Batched implementation of the Pinhole (aka Linear) camera + model. + + Inputs: + uv: Bx(T)xNx2 tensor of 2D pixels to be projected + params: Bx(T)x4 tensor of Pinhole parameters formatted like this: + [f_u f_v c_u c_v] + Outputs: + xyz: Bx(T)xNx3 tensor of 3D rays of uv points with z = 1. + + """ + assert uv.ndim == 3 or uv.ndim == 4, "Expected batched input shaped Bx(T)xNx3" + assert params.ndim == 2 or params.ndim == 3 + assert params.shape[-1] == 4 + assert uv.shape[-1] == 2 + + # Focal length and principal point + fx_fy = params[..., 0:2].reshape(*uv.shape[:-2], 1, 2) + cx_cy = params[..., 2:4].reshape(*uv.shape[:-2], 1, 2) + + uv_dist = (uv - cx_cy) / fx_fy + + ray = torch.cat([uv_dist, uv.new_ones(*uv.shape[:-1], 1)], dim=-1) + return ray diff --git a/stream3r/dust3r/datasets/aria/vignette_imx577.png b/stream3r/dust3r/datasets/aria/vignette_imx577.png new file mode 100644 index 0000000000000000000000000000000000000000..ef1639a4000ef1395c0ec1dfbf630fa5cf4dad3c --- /dev/null +++ b/stream3r/dust3r/datasets/aria/vignette_imx577.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ef18ef9ac6a7d8316974c398394351357f10f56b609381335b7f87833b4be48 +size 52678 diff --git a/stream3r/dust3r/datasets/base/__init__.py b/stream3r/dust3r/datasets/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea9d87a320e848d1c4851a1e2408313c9255365 --- /dev/null +++ b/stream3r/dust3r/datasets/base/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/stream3r/dust3r/datasets/base/batched_sampler.py b/stream3r/dust3r/datasets/base/batched_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..d46b1a73ceaadb235f556970eee091c41c5fe1a9 --- /dev/null +++ b/stream3r/dust3r/datasets/base/batched_sampler.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Random sampling under a constraint +# -------------------------------------------------------- +import numpy as np +import torch + + +class BatchedRandomSampler: + """Random sampling under a constraint: each sample in the batch has the same feature, + which is chosen randomly from a known pool of 'features' for each batch. + + For instance, the 'feature' could be the image aspect-ratio. + + The index returned is a tuple (sample_idx, feat_idx). + This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. + """ + + def __init__( + self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True + ): + self.batch_size = batch_size + self.pool_size = pool_size + + self.len_dataset = N = len(dataset) + self.total_size = round_by(N, batch_size * world_size) if drop_last else N + assert ( + world_size == 1 or drop_last + ), "must drop the last batch in distributed mode" + + # distributed sampler + self.world_size = world_size + self.rank = rank + self.epoch = None + + def __len__(self): + return self.total_size // self.world_size + + def set_epoch(self, epoch): + self.epoch = epoch + + def __iter__(self): + # prepare RNG + if self.epoch is None: + assert ( + self.world_size == 1 and self.rank == 0 + ), "use set_epoch() if distributed mode is used" + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.epoch + 777 + rng = np.random.default_rng(seed=seed) + + # random indices (will restart from 0 if not drop_last) + sample_idxs = np.arange(self.total_size) + rng.shuffle(sample_idxs) + + # random feat_idxs (same across each batch) + n_batches = (self.total_size + self.batch_size - 1) // self.batch_size + feat_idxs = rng.integers(self.pool_size, size=n_batches) + feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) + feat_idxs = feat_idxs.ravel()[: self.total_size] + + # put them together + idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2) + + # Distributed sampler: we select a subset of batches + # make sure the slice for each node is aligned with batch_size + size_per_proc = self.batch_size * ( + (self.total_size + self.world_size * self.batch_size - 1) + // (self.world_size * self.batch_size) + ) + idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc] + + yield from (tuple(idx) for idx in idxs) + + +def round_by(total, multiple, up=False): + if up: + total = total + multiple - 1 + return (total // multiple) * multiple diff --git a/stream3r/dust3r/datasets/base/easy_dataset.py b/stream3r/dust3r/datasets/base/easy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c6362f3b4d3e9371a20632a122b3a62950cf8648 --- /dev/null +++ b/stream3r/dust3r/datasets/base/easy_dataset.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# A dataset base class that you can easily resize and combine. +# -------------------------------------------------------- +import numpy as np + +from stream3r.dust3r.datasets.base.batched_sampler import BatchedRandomSampler + + +class EasyDataset: + """a dataset that you can easily resize and combine. + Examples: + --------- + 2 * dataset ==> duplicate each element 2x + + 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary) + + dataset1 + dataset2 ==> concatenate datasets + """ + + def __add__(self, other): + return CatDataset([self, other]) + + def __rmul__(self, factor): + return MulDataset(factor, self) + + def __rmatmul__(self, factor): + return ResizedDataset(factor, self) + + def set_epoch(self, epoch): + pass # nothing to do by default + + def set_ratio(self, train_ratio): + self.train_ratio = train_ratio + + def make_sampler( + self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True + ): + if not (shuffle): + raise NotImplementedError() # cannot deal yet + num_of_aspect_ratios = len(self._resolutions) + return BatchedRandomSampler( + self, + batch_size, + num_of_aspect_ratios, + world_size=world_size, + rank=rank, + drop_last=drop_last, + ) + + +class MulDataset(EasyDataset): + """Artifically augmenting the size of a dataset.""" + + multiplicator: int + + def __init__(self, multiplicator, dataset): + assert isinstance(multiplicator, int) and multiplicator > 0 + self.multiplicator = multiplicator + self.dataset = dataset + + def __len__(self): + return self.multiplicator * len(self.dataset) + + def __repr__(self): + return f"{self.multiplicator}*{repr(self.dataset)}" + + def __getitem__(self, idx): + if isinstance(idx, tuple): + idx, other = idx + return self.dataset[idx // self.multiplicator, other] + else: + return self.dataset[idx // self.multiplicator] + + @property + def _resolutions(self): + return self.dataset._resolutions + + +class ResizedDataset(EasyDataset): + """Artifically changing the size of a dataset.""" + + new_size: int + + def __init__(self, new_size, dataset): + assert isinstance(new_size, int) and new_size > 0 + self.new_size = new_size + self.dataset = dataset + + def __len__(self): + return self.new_size + + def __repr__(self): + size_str = str(self.new_size) + for i in range((len(size_str) - 1) // 3): + sep = -4 * i - 3 + size_str = size_str[:sep] + "_" + size_str[sep:] + return f"{size_str} @ {repr(self.dataset)}" + + def set_epoch(self, epoch): + # this random shuffle only depends on the epoch + # we modify to depend on date too + from datetime import datetime + current_date = datetime.now() + date_int = int(current_date.strftime("%Y%m%d")) + rng = np.random.default_rng(seed=epoch + date_int + 777) + + # shuffle all indices + perm = rng.permutation(len(self.dataset)) + + # rotary extension until target size is met + shuffled_idxs = np.concatenate( + [perm] * (1 + (len(self) - 1) // len(self.dataset)) + ) + self._idxs_mapping = shuffled_idxs[: self.new_size] + + assert len(self._idxs_mapping) == self.new_size + + def set_ratio(self, train_ratio): + self.dataset.train_ratio = train_ratio + + def __getitem__(self, idx): + assert hasattr( + self, "_idxs_mapping" + ), "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()" + if isinstance(idx, tuple): + idx, other = idx + return self.dataset[self._idxs_mapping[idx], other] + else: + return self.dataset[self._idxs_mapping[idx]] + + @property + def _resolutions(self): + return self.dataset._resolutions + + +class CatDataset(EasyDataset): + """Concatenation of several datasets""" + + def __init__(self, datasets): + for dataset in datasets: + assert isinstance(dataset, EasyDataset) + self.datasets = datasets + self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets]) + + def __len__(self): + return self._cum_sizes[-1] + + def __repr__(self): + # remove uselessly long transform + return " + ".join( + repr(dataset).replace( + ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))", + "", + ) + for dataset in self.datasets + ) + + def set_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_epoch(epoch) + + def set_ratio(self, train_ratio): + for dataset in self.datasets: + dataset.set_ratio(train_ratio) + + def __getitem__(self, idx): + other = None + if isinstance(idx, tuple): + idx, other = idx + + if not (0 <= idx < len(self)): + raise IndexError() + + db_idx = np.searchsorted(self._cum_sizes, idx, "right") + dataset = self.datasets[db_idx] + new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) + + if other is not None: + new_idx = (new_idx, other) + return dataset[new_idx] + + @property + def _resolutions(self): + resolutions = self.datasets[0]._resolutions + for dataset in self.datasets[1:]: + assert tuple(dataset._resolutions) == tuple(resolutions) + return resolutions diff --git a/stream3r/dust3r/datasets/utils/__init__.py b/stream3r/dust3r/datasets/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea9d87a320e848d1c4851a1e2408313c9255365 --- /dev/null +++ b/stream3r/dust3r/datasets/utils/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/stream3r/dust3r/datasets/utils/cropping.py b/stream3r/dust3r/datasets/utils/cropping.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5ab593001355f862ba4fa76eedd32f4ad41310 --- /dev/null +++ b/stream3r/dust3r/datasets/utils/cropping.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# croppping utilities +# -------------------------------------------------------- +import PIL.Image +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa +import numpy as np # noqa +from stream3r.dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + + +class ImageList: + """ Convenience class to aply the same operation to a whole set of images. + """ + + def __init__(self, images): + if not isinstance(images, (tuple, list, set)): + images = [images] + self.images = [] + for image in images: + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + self.images.append(image) + + def __len__(self): + return len(self.images) + + def to_pil(self): + return tuple(self.images) if len(self.images) > 1 else self.images[0] + + @property + def size(self): + sizes = [im.size for im in self.images] + assert all(sizes[0] == s for s in sizes) + return sizes[0] + + def resize(self, *args, **kwargs): + return ImageList(self._dispatch('resize', *args, **kwargs)) + + def crop(self, *args, **kwargs): + return ImageList(self._dispatch('crop', *args, **kwargs)) + + def _dispatch(self, func, *args, **kwargs): + return [getattr(im, func)(*args, **kwargs) for im in self.images] + + +def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution, force=True, track=None, track_valid_mask=None): + """ Jointly rescale a (image, depthmap) + so that (out_width, out_height) >= output_res + """ + image = ImageList(image) + input_resolution = np.array(image.size) # (W,H) + output_resolution = np.array(output_resolution) + if depthmap is not None: + # can also use this with masks instead of depthmaps + assert tuple(depthmap.shape[:2]) == image.size[::-1] + + if track is not None: + # can also use this with masks instead of depthmaps + assert tuple(track.shape[:2]) == image.size[::-1] + assert tuple(track_valid_mask.shape[:2]) == image.size[::-1] + + # define output resolution + assert output_resolution.shape == (2,) + scale_final = max(output_resolution / image.size) + 1e-8 + if scale_final >= 1 and not force: # image is already smaller than what is asked + return (image.to_pil(), depthmap, camera_intrinsics) + output_resolution = np.floor(input_resolution * scale_final).astype(int) + + # first rescale the image so that it contains the crop + image = image.resize(tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic) + if depthmap is not None: + depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final, + fy=scale_final, interpolation=cv2.INTER_NEAREST) + + if track is not None: + track = cv2.resize(track, output_resolution, fx=scale_final, + fy=scale_final, interpolation=cv2.INTER_NEAREST) + track_valid_mask = cv2.resize(track_valid_mask.astype(np.uint8), output_resolution, fx=scale_final, + fy=scale_final, interpolation=cv2.INTER_NEAREST).astype(bool) + + # no offset here; simple rescaling + camera_intrinsics = camera_matrix_of_crop( + camera_intrinsics, input_resolution, output_resolution, scaling=scale_final) + + if track is not None: + return image.to_pil(), depthmap, camera_intrinsics, track, track_valid_mask + else: + return image.to_pil(), depthmap, camera_intrinsics + + +def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None): + # Margins to offset the origin + margins = np.asarray(input_resolution) * scaling - output_resolution + assert np.all(margins >= 0.0) + if offset is None: + offset = offset_factor * margins + + # Generate new camera parameters + output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) + output_camera_matrix_colmap[:2, :] *= scaling + output_camera_matrix_colmap[:2, 2] -= offset + output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) + + return output_camera_matrix + + +def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox, track=None, track_valid_mask=None): + """ + Return a crop of the input view. + """ + image = ImageList(image) + l, t, r, b = crop_bbox + + image = image.crop((l, t, r, b)) + if depthmap is not None: + depthmap = depthmap[t:b, l:r] + + if track is not None: + track = track[t:b, l:r] + track_valid_mask = track_valid_mask[t:b, l:r] + + camera_intrinsics = camera_intrinsics.copy() + camera_intrinsics[0, 2] -= l + camera_intrinsics[1, 2] -= t + + if track is not None: + return image.to_pil(), depthmap, camera_intrinsics, track, track_valid_mask + else: + return image.to_pil(), depthmap, camera_intrinsics + + +def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution): + out_width, out_height = output_resolution + l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) + crop_bbox = (l, t, l + out_width, t + out_height) + return crop_bbox diff --git a/stream3r/dust3r/datasets/utils/transforms.py b/stream3r/dust3r/datasets/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..9231072ac5b2823adc96b336dc7969f0194dbb77 --- /dev/null +++ b/stream3r/dust3r/datasets/utils/transforms.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# DUST3R default transforms +# -------------------------------------------------------- +import torchvision.transforms as tvf + +from stream3r.dust3r.utils.image import ImgNorm + +# define the standard image transforms +ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) + +# from CoTracker and DELTA +TrackAug = tvf.Compose([ + tvf.RandomApply([tvf.ColorJitter(0.2, 0.2, 0.2, 0.25/3.14)], p=0.25), + tvf.RandomApply([tvf.GaussianBlur(11, sigma=(0.1, 2.0))], p=0.05), + ImgNorm +]) \ No newline at end of file diff --git a/stream3r/dust3r/datasets_cut3r/__init__.py b/stream3r/dust3r/datasets_cut3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34bfaef0e742df925479cb6fa4ca800b5b7304ad --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/__init__.py @@ -0,0 +1,81 @@ +import torch +from torch.utils.data.distributed import DistributedSampler + +from .utils.transforms import * +from .base.batched_sampler import BatchedRandomSampler # noqa +from .arkitscenes import ARKitScenes_Multi # noqa +from .arkitscenes_highres import ARKitScenesHighRes_Multi +from .bedlam import BEDLAM_Multi +from .blendedmvs import BlendedMVS_Multi # noqa +from .co3d import Co3d_Multi # noqa +from .cop3d import Cop3D_Multi +from .dl3dv import DL3DV_Multi +from .dynamic_replica import DynamicReplica +from .eden import EDEN_Multi +from .hypersim import HyperSim_Multi +from .irs import IRS +from .hoi4d import HOI4D_Multi +from .mapfree import MapFree_Multi +from .megadepth import MegaDepth_Multi # noqa +from .mp3d import MP3D_Multi +from .mvimgnet import MVImgNet_Multi +from .mvs_synth import MVS_Synth_Multi +from .omniobject3d import OmniObject3D_Multi +from .pointodyssey import PointOdyssey_Multi +from .realestate10k import RE10K_Multi +from .scannet import ScanNet_Multi +from .scannetpp import ScanNetpp_Multi # noqa +from .smartportraits import SmartPortraits_Multi +from .spring import Spring +from .synscapes import SynScapes +from .tartanair import TartanAir_Multi +from .threedkb import ThreeDKenBurns +from .uasol import UASOL_Multi +from .urbansyn import UrbanSyn +from .unreal4k import UnReal4K_Multi +from .vkitti2 import VirtualKITTI2_Multi # noqa +from .waymo import Waymo_Multi # noqa +from .wildrgbd import WildRGBD_Multi # noqa + +# from spann3r, slam3r +from .habitat import Habitat +from .project_aria_seq import Aria_Seq + + +def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True, persistent_workers=False, multiprocessing_context=None): + import torch + from croco.utils.misc import get_world_size, get_rank + + # pytorch dataset + if isinstance(dataset, str): + dataset = eval(dataset) + + world_size = get_world_size() + rank = get_rank() + + try: + sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size, + rank=rank, drop_last=drop_last) + except (AttributeError, NotImplementedError): + # not avail for this dataset + if torch.distributed.is_initialized(): + sampler = torch.utils.data.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last + ) + elif shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_mem, + drop_last=drop_last, + persistent_workers=persistent_workers, + multiprocessing_context=multiprocessing_context, + ) + + return data_loader \ No newline at end of file diff --git a/stream3r/dust3r/datasets_cut3r/arkitscenes.py b/stream3r/dust3r/datasets_cut3r/arkitscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..c2ed6c7d58744aef3223a5ae08bce74487c9b6c1 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/arkitscenes.py @@ -0,0 +1,273 @@ +import os.path as osp +import pickle +import os +import sys +import itertools + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np + +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +def stratified_sampling(indices, num_samples, rng=None): + if num_samples > len(indices): + raise ValueError("num_samples cannot exceed the number of available indices.") + elif num_samples == len(indices): + return indices + + sorted_indices = sorted(indices) + stride = len(sorted_indices) / num_samples + sampled_indices = [] + if rng is None: + rng = np.random.default_rng() + + for i in range(num_samples): + start = int(i * stride) + end = int((i + 1) * stride) + # Ensure end does not exceed the list + end = min(end, len(sorted_indices)) + if start < end: + # Randomly select within the current stratum + rand_idx = rng.integers(start, end) + sampled_indices.append(sorted_indices[rand_idx]) + else: + # In case of any rounding issues, select the last index + sampled_indices.append(sorted_indices[-1]) + + return rng.permutation(sampled_indices) + + +class ARKitScenes_Multi(BaseMultiViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 8 + super().__init__(*args, **kwargs) + if split == "train": + self.split = "Training" + elif split == "test": + self.split = "Test" + # self.split = "Validation" + else: + raise ValueError("") + + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + + if os.path.exists(osp.join(self.ROOT, f'{split}-pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'{split}-pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.id_ranges = pre_calculated_data['id_ranges'] + self.intrinsics = pre_calculated_data['intrinsics'] + self.trajectories = pre_calculated_data['trajectories'] + self.groups = pre_calculated_data['groups'] + + else: + + + with np.load(osp.join(self.ROOT, split, "all_metadata.npz")) as data: + self.scenes: np.ndarray = data["scenes"] + high_res_list = np.array( + [ + d + for d in os.listdir( + os.path.join( + self.ROOT.rstrip("/") + "_highres", + split if split == "Training" else "Validation", + ) + ) + if os.path.join(self.ROOT + "_highres", split, d) + ] + ) + self.scenes = np.setdiff1d(self.scenes, high_res_list) + offset = 0 + counts = [] + scenes = [] + sceneids = [] + images = [] + intrinsics = [] + trajectories = [] + groups = [] + id_ranges = [] + j = 0 + for scene_idx, scene in enumerate(self.scenes): + scene_dir = osp.join(self.ROOT, self.split, scene) + with np.load( + osp.join(scene_dir, "new_scene_metadata.npz"), allow_pickle=True + ) as data: + imgs = data["images"] + intrins = data["intrinsics"] + traj = data["trajectories"] + min_seq_len = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + if len(imgs) < min_seq_len: + print(f"Skipping {scene}") + continue + + collections = {} + assert "image_collection" in data, "Image collection not found" + collections["image"] = data["image_collection"] + + num_imgs = imgs.shape[0] + img_groups = [] + min_group_len = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + for ref_id, group in collections["image"].item().items(): + if len(group) + 1 < min_group_len: + continue + + # groups are (idx, score)s + group.insert(0, (ref_id, 1.0)) + group = [int(x[0] + offset) for x in group] + img_groups.append(sorted(group)) + + if len(img_groups) == 0: + print(f"Skipping {scene}") + continue + + scenes.append(scene) + sceneids.extend([j] * num_imgs) + id_ranges.extend([(offset, offset + num_imgs) for _ in range(num_imgs)]) + images.extend(imgs) + K = np.expand_dims(np.eye(3), 0).repeat(num_imgs, 0) + + K[:, 0, 0] = [fx for _, _, fx, _, _, _ in intrins] + K[:, 1, 1] = [fy for _, _, _, fy, _, _ in intrins] + K[:, 0, 2] = [cx for _, _, _, _, cx, _ in intrins] + K[:, 1, 2] = [cy for _, _, _, _, _, cy in intrins] + intrinsics.extend(list(K)) + trajectories.extend(list(traj)) + + # offset groups + groups.extend(img_groups) + counts.append(offset) + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.id_ranges = id_ranges + self.images = images + self.intrinsics = intrinsics + self.trajectories = trajectories + self.groups = groups + + with open(osp.join(self.ROOT, f'{split}-pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + id_ranges=id_ranges, + intrinsics=intrinsics, + trajectories=trajectories, + groups=groups), + f, + ) + + + def __len__(self): + return len(self.groups) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + + if rng.choice([True, False]): + image_idxs = np.arange(self.id_ranges[idx][0], self.id_ranges[idx][1]) + cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3) + start_image_idxs = image_idxs[: len(image_idxs) - cut_off + 1] + start_id = rng.choice(start_image_idxs) + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + image_idxs.tolist(), + rng, + max_interval=self.max_interval, + video_prob=0.8, + fix_interval_prob=0.5, + block_shuffle=16, + ) + image_idxs = np.array(image_idxs)[pos] + else: + ordered_video = False + image_idxs = self.groups[idx] + image_idxs = rng.permutation(image_idxs) + if len(image_idxs) > num_views: + image_idxs = image_idxs[:num_views] + else: + if rng.random() < 0.8: + image_idxs = rng.choice(image_idxs, size=num_views, replace=True) + else: + repeat_num = num_views // len(image_idxs) + 1 + image_idxs = np.tile(image_idxs, repeat_num)[:num_views] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + assert ( + basename[:8] == self.scenes[scene_id] + ), f"{basename}, {self.scenes[scene_id]}" + # print(scene_dir, basename) + # Load RGB image + rgb_image = imread_cv2( + osp.join(scene_dir, "vga_wide", basename.replace(".png", ".jpg")) + ) + # Load depthmap + depthmap = imread_cv2( + osp.join(scene_dir, "lowres_depth", basename), cv2.IMREAD_UNCHANGED + ) + depthmap = depthmap.astype(np.float32) / 1000.0 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="arkitscenes", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/arkitscenes_highres.py b/stream3r/dust3r/datasets_cut3r/arkitscenes_highres.py new file mode 100755 index 0000000000000000000000000000000000000000..ec6783f18eb823d2d8a14f7209566a1d8e2a5ebf --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/arkitscenes_highres.py @@ -0,0 +1,207 @@ +import os.path as osp +import pickle +import os +import sys +import itertools + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np +import h5py +import math +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 +from safetensors.numpy import save_file, load_file + + +class ARKitScenesHighRes_Multi(BaseMultiViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.max_interval = 8 + self.is_metric = True + super().__init__(*args, **kwargs) + if split == "train": + self.split = "Training" + elif split == "test": + self.split = "Validation" + else: + raise ValueError("") + + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + + cache_file = f'{split}-pre-calculated-loaddata-{self.num_views}.pkl' + if os.path.exists(osp.join(self.ROOT, cache_file)): + with open(osp.join(self.ROOT, cache_file), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + self.intrinsics = pre_calculated_data['intrinsics'] + self.trajectories = pre_calculated_data['trajectories'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + + return + + all_scenes = sorted( + [ + d + for d in os.listdir(osp.join(self.ROOT, split)) + if osp.isdir(osp.join(self.ROOT, split, d)) + ] + ) + offset = 0 + scenes = [] + sceneids = [] + images = [] + start_img_ids = [] + scene_img_list = [] + timestamps = [] + intrinsics = [] + trajectories = [] + scene_id = 0 + for scene in all_scenes: + scene_dir = osp.join(self.ROOT, self.split, scene) + with np.load(osp.join(scene_dir, "scene_metadata.npz")) as data: + # with np.load(osp.join(scene_dir, "scene_metadata.safetensor.npz")) as data: + imgs_with_indices = sorted( + enumerate(data["images"]), key=lambda x: x[1] + ) + imgs = [x[1] for x in imgs_with_indices] + cut_off = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + if len(imgs) < cut_off: + print(f"Skipping {scene}") + continue + indices = [x[0] for x in imgs_with_indices] + tsps = np.array( + [float(img_name.split("_")[1][:-4]) for img_name in imgs] + ) + assert [img[:8] == scene for img in imgs], f"{scene}, {imgs}" + num_imgs = data["images"].shape[0] + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + scenes.append(scene) + scene_img_list.append(img_ids) + sceneids.extend([scene_id] * num_imgs) + images.extend(imgs) + start_img_ids.extend(start_img_ids_) + timestamps.extend(tsps) + + K = np.expand_dims(np.eye(3), 0).repeat(num_imgs, 0) + intrins = data["intrinsics"][indices] + K[:, 0, 0] = [fx for _, _, fx, _, _, _ in intrins] + K[:, 1, 1] = [fy for _, _, _, fy, _, _ in intrins] + K[:, 0, 2] = [cx for _, _, _, _, cx, _ in intrins] + K[:, 1, 2] = [cy for _, _, _, _, _, cy in intrins] + intrinsics.extend(list(K)) + trajectories.extend(list(data["trajectories"][indices])) + + # offset groups + offset += num_imgs + scene_id += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.scene_img_list = scene_img_list + self.intrinsics = intrinsics + self.trajectories = trajectories + self.start_img_ids = start_img_ids + assert len(self.images) == len(self.intrinsics) == len(self.trajectories) + + with open(osp.join(self.ROOT, cache_file), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + scene_img_list=scene_img_list, + intrinsics=intrinsics, + trajectories=trajectories, + start_img_ids=start_img_ids), + f, + ) + + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + block_shuffle=16, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id]) + + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + assert ( + basename[:8] == self.scenes[scene_id] + ), f"{basename}, {self.scenes[scene_id]}" + # print(scene_dir, basename) + # Load RGB image + rgb_image = imread_cv2( + osp.join(scene_dir, "vga_wide", basename.replace(".png", ".jpg")) + ) + # Load depthmap + depthmap = imread_cv2( + osp.join(scene_dir, "highres_depth", basename), cv2.IMREAD_UNCHANGED + ) + depthmap = depthmap.astype(np.float32) / 1000.0 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.7, 0.25, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="arkitscenes_highres", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.99, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/base/__init__.py b/stream3r/dust3r/datasets_cut3r/base/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stream3r/dust3r/datasets_cut3r/base/base_multiview_dataset.py b/stream3r/dust3r/datasets_cut3r/base/base_multiview_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..87c1379f460ccf7c859fc6f211ae02c5eac5f62b --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/base/base_multiview_dataset.py @@ -0,0 +1,570 @@ +import PIL +import types +import torchvision +import numpy as np +import torch +import random +import itertools +from stream3r.dust3r.datasets.base.easy_dataset import EasyDataset # here we use fast's EasyDataset +from stream3r.dust3r.datasets_cut3r.utils.transforms import ImgNorm, SeqColorJitter +from stream3r.dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates +import stream3r.dust3r.datasets_cut3r.utils.cropping as cropping +from stream3r.dust3r.datasets_cut3r.utils.corr import extract_correspondences_from_pts3d + +from pdb import set_trace as st + +def get_ray_map(c2w1, c2w2, intrinsics, h, w): + c2w = np.linalg.inv(c2w1) @ c2w2 + i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy") + grid = np.stack([i, j, np.ones_like(i)], axis=-1) + ro = c2w[:3, 3] + rd = np.linalg.inv(intrinsics) @ grid.reshape(-1, 3).T + rd = (c2w @ np.vstack([rd, np.ones_like(rd[0])])).T[:, :3].reshape(h, w, 3) + rd = rd / np.linalg.norm(rd, axis=-1, keepdims=True) + ro = np.broadcast_to(ro, (h, w, 3)) + ray_map = np.concatenate([ro, rd], axis=-1) + return ray_map + + +class BaseMultiViewDataset(EasyDataset): + """Define all basic options. + + Usage: + class MyDataset (BaseMultiViewDataset): + def _get_views(self, idx, rng): + # overload here + views = [] + views.append(dict(img=, ...)) + return views + """ + + def __init__( + self, + *, # only keyword arguments + num_views=None, + split=None, + resolution=None, # square_size or (width, height) or list of [(width,height), ...] + transform=ImgNorm, + aug_crop=False, + n_corres=0, + nneg=0, + seed=None, + allow_repeat=False, + seq_aug_crop=False, + ): + assert num_views is not None, "undefined num_views" + self.num_views = num_views + self.split = split + self._set_resolutions(resolution) + + self.n_corres = n_corres + self.nneg = nneg + assert ( + self.n_corres == "all" + or isinstance(self.n_corres, int) + or ( + isinstance(self.n_corres, list) and len(self.n_corres) == self.num_views + ) + ), f"Error, n_corres should either be 'all', a single integer or a list of length {self.num_views}" + assert ( + self.nneg == 0 or self.n_corres != "all" + ), "nneg should be 0 if n_corres is all" + + self.is_seq_color_jitter = False + # st() + if isinstance(transform, str): + transform = eval(transform) + + # more robust than [[ transform == SeqColorJitter() ]] + if isinstance(transform, types.FunctionType) and transform.__name__ == 'SeqColorJitter': + transform = SeqColorJitter() + self.is_seq_color_jitter = True + self.transform = transform + + self.aug_crop = aug_crop + self.seed = seed + self.allow_repeat = allow_repeat + self.seq_aug_crop = seq_aug_crop + + def __len__(self): + return len(self.scenes) + + @staticmethod + def efficient_random_intervals( + start, + num_elements, + interval_range, + fixed_interval_prob=0.8, + weights=None, + seed=42, + ): + if random.random() < fixed_interval_prob: + intervals = random.choices(interval_range, weights=weights) * ( + num_elements - 1 + ) + else: + intervals = [ + random.choices(interval_range, weights=weights)[0] + for _ in range(num_elements - 1) + ] + return list(itertools.accumulate([start] + intervals)) + + def sample_based_on_timestamps(self, i, timestamps, num_views, interval=1): + time_diffs = np.abs(timestamps - timestamps[i]) + ids_candidate = np.where(time_diffs < interval)[0] + ids_candidate = np.sort(ids_candidate) + if (self.allow_repeat and len(ids_candidate) < num_views // 3) or ( + len(ids_candidate) < num_views + ): + return [] + ids_sel_list = [] + ids_candidate_left = ids_candidate.copy() + while len(ids_candidate_left) >= num_views: + ids_sel = np.random.choice(ids_candidate_left, num_views, replace=False) + ids_sel_list.append(sorted(ids_sel)) + ids_candidate_left = np.setdiff1d(ids_candidate_left, ids_sel) + + if len(ids_candidate_left) > 0 and len(ids_candidate) >= num_views: + ids_sel = np.concatenate( + [ + ids_candidate_left, + np.random.choice( + np.setdiff1d(ids_candidate, ids_candidate_left), + num_views - len(ids_candidate_left), + replace=False, + ), + ] + ) + ids_sel_list.append(sorted(ids_sel)) + + if self.allow_repeat: + ids_sel_list.append( + sorted(np.random.choice(ids_candidate, num_views, replace=True)) + ) + + # add sequences with fixed intervals (all possible intervals) + pos_i = np.where(ids_candidate == i)[0][0] + curr_interval = 1 + stop = len(ids_candidate) < num_views + while not stop: + pos_sel = [pos_i] + count = 0 + while len(pos_sel) < num_views: + if count % 2 == 0: + curr_pos_i = pos_sel[-1] + curr_interval + if curr_pos_i >= len(ids_candidate): + stop = True + break + pos_sel.append(curr_pos_i) + else: + curr_pos_i = pos_sel[0] - curr_interval + if curr_pos_i < 0: + stop = True + break + pos_sel.insert(0, curr_pos_i) + count += 1 + if not stop and len(pos_sel) == num_views: + ids_sel = sorted([ids_candidate[pos] for pos in pos_sel]) + if ids_sel not in ids_sel_list: + ids_sel_list.append(ids_sel) + curr_interval += 1 + return ids_sel_list + + @staticmethod + def blockwise_shuffle(x, rng, block_shuffle): + if block_shuffle is None: + return rng.permutation(x).tolist() + else: + assert block_shuffle > 0 + blocks = [x[i : i + block_shuffle] for i in range(0, len(x), block_shuffle)] + shuffled_blocks = [rng.permutation(block).tolist() for block in blocks] + shuffled_list = [item for block in shuffled_blocks for item in block] + return shuffled_list + + def get_seq_from_start_id( + self, + num_views, + id_ref, + ids_all, + rng, + min_interval=1, + max_interval=25, + video_prob=0.5, + fix_interval_prob=0.5, + block_shuffle=None, + ): + """ + args: + num_views: number of views to return + id_ref: the reference id (first id) + ids_all: all the ids + rng: random number generator + max_interval: maximum interval between two views + returns: + pos: list of positions of the views in ids_all, i.e., index for ids_all + is_video: True if the views are consecutive + """ + assert min_interval > 0, f"min_interval should be > 0, got {min_interval}" + assert ( + min_interval <= max_interval + ), f"min_interval should be <= max_interval, got {min_interval} and {max_interval}" + assert id_ref in ids_all + pos_ref = ids_all.index(id_ref) + all_possible_pos = np.arange(pos_ref, len(ids_all)) + + remaining_sum = len(ids_all) - 1 - pos_ref + + if remaining_sum >= num_views - 1: + if remaining_sum == num_views - 1: + assert ids_all[-num_views] == id_ref + return [pos_ref + i for i in range(num_views)], True + max_interval = min(max_interval, 2 * remaining_sum // (num_views - 1)) + intervals = [ + rng.choice(range(min_interval, max_interval + 1)) + for _ in range(num_views - 1) + ] + + # if video or collection + if rng.random() < video_prob: + # if fixed interval or random + if rng.random() < fix_interval_prob: + # regular interval + fixed_interval = rng.choice( + range( + 1, + min(remaining_sum // (num_views - 1) + 1, max_interval + 1), + ) + ) + intervals = [fixed_interval for _ in range(num_views - 1)] + is_video = True + else: + is_video = False + + pos = list(itertools.accumulate([pos_ref] + intervals)) + pos = [p for p in pos if p < len(ids_all)] + pos_candidates = [p for p in all_possible_pos if p not in pos] + pos = ( + pos + + rng.choice( + pos_candidates, num_views - len(pos), replace=False + ).tolist() + ) + + pos = ( + sorted(pos) + if is_video + else self.blockwise_shuffle(pos, rng, block_shuffle) + ) + else: + # assert self.allow_repeat + uniq_num = remaining_sum + new_pos_ref = rng.choice(np.arange(pos_ref + 1)) + new_remaining_sum = len(ids_all) - 1 - new_pos_ref + new_max_interval = min(max_interval, new_remaining_sum // (uniq_num - 1)) + new_intervals = [ + rng.choice(range(1, new_max_interval + 1)) for _ in range(uniq_num - 1) + ] + + revisit_random = rng.random() + video_random = rng.random() + + if rng.random() < fix_interval_prob and video_random < video_prob: + # regular interval + fixed_interval = rng.choice(range(1, new_max_interval + 1)) + new_intervals = [fixed_interval for _ in range(uniq_num - 1)] + pos = list(itertools.accumulate([new_pos_ref] + new_intervals)) + + is_video = False + if revisit_random < 0.5 or video_prob == 1.0: # revisit, video / collection + is_video = video_random < video_prob + pos = ( + self.blockwise_shuffle(pos, rng, block_shuffle) + if not is_video + else pos + ) + num_full_repeat = num_views // uniq_num + pos = ( + pos * num_full_repeat + + pos[: num_views - len(pos) * num_full_repeat] + ) + elif revisit_random < 0.9: # random + pos = rng.choice(pos, num_views, replace=True) + else: # ordered + pos = sorted(rng.choice(pos, num_views, replace=True)) + assert len(pos) == num_views + return pos, is_video + + def get_img_and_ray_masks(self, is_metric, v, rng, p=[0.8, 0.15, 0.05]): + # generate img mask and raymap mask + if v == 0 or (not is_metric): + img_mask = True + raymap_mask = False + else: + # rand_val = rng.random() + # if rand_val < p[0]: + img_mask = True + raymap_mask = False + # elif rand_val < p[0] + p[1]: + # img_mask = False + # raymap_mask = True + # else: + # img_mask = True + # raymap_mask = True + return img_mask, raymap_mask + + def get_stats(self): + return f"{len(self)} groups of views" + + def __repr__(self): + resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]" + return ( + f"""{type(self).__name__}({self.get_stats()}, + {self.num_views=}, + {self.split=}, + {self.seed=}, + resolutions={resolutions_str}, + {self.transform=})""".replace( + "self.", "" + ) + .replace("\n", "") + .replace(" ", "") + ) + + def _get_views(self, idx, resolution, rng, num_views): + raise NotImplementedError() + + def __getitem__(self, idx): + # print("Receiving:" , idx) + # print(idx) + if isinstance(idx, (tuple, list, np.ndarray)): + # the idx is specifying the aspect-ratio + idx, ar_idx = idx + nview = self.num_views + else: + assert len(self._resolutions) == 1 + ar_idx = 0 + nview = self.num_views + + assert nview >= 1 and nview <= self.num_views + # set-up the rng + if self.seed: # reseed for each __getitem__ + self._rng = np.random.default_rng(seed=self.seed + idx) + elif not hasattr(self, "_rng"): + seed = torch.randint(0, 2**32, (1,)).item() + self._rng = np.random.default_rng(seed=seed) + + if self.aug_crop > 1 and self.seq_aug_crop: + self.delta_target_resolution = self._rng.integers(0, self.aug_crop) + + # over-loaded code + resolution = self._resolutions[ + ar_idx + ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) + try: + views = self._get_views(idx, resolution, self._rng, nview) + except Exception as e: + print(f"Error in _get_views for view {idx}: {e}") + return self.__getitem__(((idx + 999)%len(self), ar_idx)) + + assert len(views) == nview + + if "camera_pose" not in views[0]: + views[0]["camera_pose"] = np.ones((4, 4), dtype=np.float32) + first_view_camera_pose = views[0]["camera_pose"] + transform = SeqColorJitter() if self.is_seq_color_jitter else self.transform + + for v, view in enumerate(views): + assert ( + "pts3d" not in view + ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" + view["idx"] = (idx, ar_idx, v) + + # encode the image + width, height = view["img"].size + + view["true_shape"] = np.int32((height, width)) + # if self.is_seq_color_jitter: + # st() + # pass + view["img"] = transform(view["img"]) + blk_img_flag = view["img"].max()==-1 # some training imgs are purely black. + view["sky_mask"] = view["depthmap"] < 0 + + assert "camera_intrinsics" in view + if "camera_pose" not in view: + view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32) + else: + assert np.isfinite( + view["camera_pose"] + ).all(), f"NaN in camera pose for view {view_name(view)}" + + ray_map = get_ray_map( + first_view_camera_pose, + view["camera_pose"], + view["camera_intrinsics"], + height, + width, + ) + view["ray_map"] = ray_map.astype(np.float32) + + assert "pts3d" not in view + assert "valid_mask" not in view + assert np.isfinite( + view["depthmap"] + ).all(), f"NaN in depthmap for view {view_name(view)}" + pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view) + + view["pts3d"] = pts3d + view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1) & (not blk_img_flag) + + # TODO, add black check. + if blk_img_flag: + print(f"{view['dataset']}-{view['instance']}-{view['label']} is purely black!") + + + # check all datatypes + for key, val in view.items(): + res, err_msg = is_good_type(key, val) + assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" + K = view["camera_intrinsics"] + view['is_metric_scale'] = view['is_metric'] # for loss compat issue + + if self.n_corres > 0: # for mast3r training. + ref_view = views[0] + for view in views: + corres1, corres2, valid = extract_correspondences_from_pts3d( + ref_view, view, self.n_corres, self._rng, nneg=self.nneg + ) + view["corres"] = (corres1, corres2) + view["valid_corres"] = valid + + # last thing done! + for view in views: + view["rng"] = int.from_bytes(self._rng.bytes(4), "big") + return views + + def _set_resolutions(self, resolutions): + assert resolutions is not None, "undefined resolution" + + if not isinstance(resolutions, list): + resolutions = [resolutions] + + self._resolutions = [] + for resolution in resolutions: + if isinstance(resolution, int): + width = height = resolution + else: + width, height = resolution + assert isinstance( + width, int + ), f"Bad type for {width=} {type(width)=}, should be int" + assert isinstance( + height, int + ), f"Bad type for {height=} {type(height)=}, should be int" + self._resolutions.append((width, height)) + + def _crop_resize_if_necessary( + self, image, depthmap, intrinsics, resolution, rng=None, info=None + ): + """This function: + - first downsizes the image with LANCZOS inteprolation, + which is better than bilinear interpolation in + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # downscale with lanczos interpolation so that image.size == resolution + # cropping centered on the principal point + W, H = image.size + cx, cy = intrinsics[:2, 2].round().astype(int) + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + # st() + assert min_margin_x > W / 5, f"Bad principal point in view={info}" + assert min_margin_y > H / 5, f"Bad principal point in view={info}" + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + image, depthmap, intrinsics = cropping.crop_image_depthmap( + image, depthmap, intrinsics, crop_bbox + ) + + # transpose the resolution if necessary + W, H = image.size # new size + + # high-quality Lanczos down-scaling + target_resolution = np.array(resolution) + if self.aug_crop > 1: + target_resolution += ( + rng.integers(0, self.aug_crop) + if not self.seq_aug_crop + else self.delta_target_resolution + ) + image, depthmap, intrinsics = cropping.rescale_image_depthmap( + image, depthmap, intrinsics, target_resolution + ) + + # actual cropping (if necessary) with bilinear interpolation + intrinsics2 = cropping.camera_matrix_of_crop( + intrinsics, image.size, resolution, offset_factor=0.5 + ) + crop_bbox = cropping.bbox_from_intrinsics_in_out( + intrinsics, intrinsics2, resolution + ) + image, depthmap, intrinsics2 = cropping.crop_image_depthmap( + image, depthmap, intrinsics, crop_bbox + ) + + return image, depthmap, intrinsics2 + + +def is_good_type(key, v): + """returns (is_good, err_msg)""" + if isinstance(v, (str, int, tuple)): + return True, None + if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): + return False, f"bad {v.dtype=}" + return True, None + + +def view_name(view, batch_index=None): + def sel(x): + return x[batch_index] if batch_index not in (None, slice(None)) else x + + db = sel(view["dataset"]) + label = sel(view["label"]) + instance = sel(view["instance"]) + return f"{db}/{label}/{instance}" + + +def transpose_to_landscape(view): + height, width = view["true_shape"] + + if width < height: + # rectify portrait to landscape + assert view["img"].shape == (3, height, width) + view["img"] = view["img"].swapaxes(1, 2) + + assert view["valid_mask"].shape == (height, width) + view["valid_mask"] = view["valid_mask"].swapaxes(0, 1) + + assert view["depthmap"].shape == (height, width) + view["depthmap"] = view["depthmap"].swapaxes(0, 1) + + assert view["pts3d"].shape == (height, width, 3) + view["pts3d"] = view["pts3d"].swapaxes(0, 1) + + # transpose x and y pixels + view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]] + + assert view["ray_map"].shape == (height, width, 6) + view["ray_map"] = view["ray_map"].swapaxes(0, 1) + + assert view["sky_mask"].shape == (height, width) + view["sky_mask"] = view["sky_mask"].swapaxes(0, 1) + + if "corres" in view: + # transpose correspondences x and y + view["corres"][0] = view["corres"][0][:, [1, 0]] + view["corres"][1] = view["corres"][1][:, [1, 0]] diff --git a/stream3r/dust3r/datasets_cut3r/base/batched_sampler.py b/stream3r/dust3r/datasets_cut3r/base/batched_sampler.py new file mode 100755 index 0000000000000000000000000000000000000000..9da0f066e67351cffe33ac3affb8b3b4d7334249 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/base/batched_sampler.py @@ -0,0 +1,120 @@ +import numpy as np +from pdb import set_trace as st +from torch.utils.data.distributed import DistributedSampler +import torch +from accelerate import Accelerator +import torch.utils +from torch.utils.data import BatchSampler, Sampler +import torch.utils.data + + +class CustomRandomSampler(Sampler): + """Random sampling under a constraint: each sample in the batch has the same feature, + which is chosen randomly from a known pool of 'features' for each batch. + + For instance, the 'feature' could be the image aspect-ratio. + + The index returned is a tuple (sample_idx, feat_idx). + This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. + """ + + def __init__( + self, + dataset, + batch_size, + pool_size, + min_view_size, # 2 if not fixed_length else num_of_views. support paired pre-training first. + max_view_size, + world_size, + # rank, + warmup=1, + drop_last=True, + ): + # st() + self.batch_size = batch_size + self.world_size = world_size + self.pool_size = pool_size # 10, num_of_aspect_ratios == len(self._resolutions) + self.min_view_size = min_view_size + self.max_view_size = max_view_size + self.drop_last = drop_last + self.len_dataset = N = len(dataset) + self.total_size = N + + self.epoch = None + self.epochf = 0.0 + + def __len__(self): + return self.total_size + + def set_epoch(self, epoch): + self.epoch = epoch + + def set_indices(self, indices): + """Set indices from DistributedSampler to ensure each GPU gets a unique subset""" + self.indices = indices + + def __iter__(self): + if self.epoch is None: + raise ValueError( + "Epoch number not set. Please call 'set_epoch(epoch)' before iterating." + ) + + seed = self.epoch + 788 + rng = np.random.default_rng(seed=seed) + # random indices (will restart from 0 if not drop_last) + sample_idxs = np.arange(self.total_size) # dataset lengths + rng.shuffle(sample_idxs) + # random feat_idxs (same across each batch) + n_batches = (self.total_size + self.batch_size - 1) // self.batch_size + if self.pool_size > 1: + p = np.ones(self.pool_size) + p[:self.pool_size // + 2] *= 2 # the first half ratios have double probability to be trained on. + p = p / p.sum() + _feat_idxs = rng.choice( + self.pool_size, size=n_batches, + p=p) # each batch being designated into a ratio + else: + _feat_idxs = rng.integers(self.pool_size, size=n_batches) + # st() + _feat_idxs = np.broadcast_to(_feat_idxs[:, None], + (n_batches, self.batch_size)) # broadcast to (n_batches, self.batch_size) + _feat_idxs = _feat_idxs.ravel()[:self.total_size] # flatten + # st() + _view_idxs = rng.integers(self.min_view_size, # get random integeres from low to high + self.max_view_size + 1, + size=n_batches) # ValueError: low >= high + _view_idxs = np.broadcast_to(_view_idxs[:, None], + (n_batches, self.batch_size)) + _view_idxs = _view_idxs.ravel()[:self.total_size] + # st() + # print('sample_idxs ', sample_idxs) + # print('_feat_idxs ', _feat_idxs) + # print('_view_idxs ', _view_idxs) + # st() + + idxs = np.c_[sample_idxs, _feat_idxs, _view_idxs] + # print(idxs) + yield from (tuple(idx) for idx in idxs) + + +class BatchedRandomSampler(BatchSampler): + """Batch sampler that groups indices from RandomSampler into batches.""" + + def __init__(self, + sampler: CustomRandomSampler, + batch_size, + drop_last=True): + self.sampler = sampler # An instance of RandomSampler + self.batch_size = batch_size + self.drop_last = drop_last + # self.set_epoch(0) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + +def round_by(total, multiple, up=False): + if up: + total = total + multiple - 1 + return (total // multiple) * multiple diff --git a/stream3r/dust3r/datasets_cut3r/base/easy_dataset.py b/stream3r/dust3r/datasets_cut3r/base/easy_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..f30145fadc787b7b8dae57e7ec0cbed7e3572ea9 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/base/easy_dataset.py @@ -0,0 +1,227 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# modified from DUSt3R + +from torch.utils.data.distributed import DistributedSampler +from pdb import set_trace as st +import numpy as np +from stream3r.dust3r.datasets_cut3r.base.batched_sampler import ( + BatchedRandomSampler, + CustomRandomSampler, +) +import torch + + +class EasyDataset: + """a dataset that you can easily resize and combine. + Examples: + --------- + 2 * dataset ==> duplicate each element 2x + + 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary) + + dataset1 + dataset2 ==> concatenate datasets + """ + + def __add__(self, other): + return CatDataset([self, other]) + + def __rmul__(self, factor): + return MulDataset(factor, self) + + def __rmatmul__(self, factor): + return ResizedDataset(factor, self) + + def set_epoch(self, epoch): + pass # nothing to do by default + + def make_sampler(self, + batch_size, + shuffle=True, + drop_last=True, + world_size=1, + rank=0, + fixed_length=False): + if not (shuffle): + raise NotImplementedError() # cannot deal yet + num_of_aspect_ratios = len(self._resolutions) + num_of_views = self.num_views + # st() + + sampler = CustomRandomSampler( + self, + batch_size, + num_of_aspect_ratios, + 4 if not fixed_length else num_of_views, + # 2 if not fixed_length else num_of_views, + # 5 if not fixed_length else num_of_views, + # 4 if not fixed_length else num_of_views, + # 3 if not fixed_length else num_of_views, + num_of_views, + world_size, + warmup=1, + drop_last=drop_last, + ) + + if world_size > 1: # modified according to gpt + # print('world size: ', world_size, 'global rank: ', rank) + + # 1️⃣ Create DistributedSampler + distributed_sampler = DistributedSampler( + self, + num_replicas=world_size, + rank=rank, + shuffle=True, # This ensures dataset is shuffled across GPUs + drop_last=drop_last) + + # 2️⃣ Get indices for this GPU + distributed_indices = list(iter(distributed_sampler)) + sampler.set_indices( + distributed_indices) # Assign correct indices to the sampler + + return BatchedRandomSampler(sampler, batch_size, drop_last) + + +class MulDataset(EasyDataset): + """Artifically augmenting the size of a dataset.""" + + multiplicator: int + + def __init__(self, multiplicator, dataset): + assert isinstance(multiplicator, int) and multiplicator > 0 + self.multiplicator = multiplicator + self.dataset = dataset + + def __len__(self): + return self.multiplicator * len(self.dataset) + + def __repr__(self): + return f"{self.multiplicator}*{repr(self.dataset)}" + + def __getitem__(self, idx): + if isinstance(idx, tuple): + idx, other, another = idx + return self.dataset[idx // self.multiplicator, other, another] + else: + return self.dataset[idx // self.multiplicator] + + @property + def _resolutions(self): + return self.dataset._resolutions + + @property + def num_views(self): + return self.dataset.num_views + + +class ResizedDataset(EasyDataset): + """Artifically changing the size of a dataset.""" + + new_size: int + + def __init__(self, new_size, dataset): + assert isinstance(new_size, int) and new_size > 0 + self.new_size = new_size + self.dataset = dataset + + # self.set_epoch(0) # hard coded + + def __len__(self): + return self.new_size + + def __repr__(self): + size_str = str(self.new_size) + for i in range((len(size_str) - 1) // 3): + sep = -4 * i - 3 + size_str = size_str[:sep] + "_" + size_str[sep:] + return f"{size_str} @ {repr(self.dataset)}" + + def set_epoch(self, epoch): + # this random shuffle only depends on the epoch + rng = np.random.default_rng(seed=epoch + 777) + + # shuffle all indices + perm = rng.permutation(len(self.dataset)) + + # rotary extension until target size is met + # st() + shuffled_idxs = np.concatenate( + [perm] * (1 + (len(self) - 1) // len(self.dataset))) + self._idxs_mapping = shuffled_idxs[:self.new_size] + + assert len(self._idxs_mapping) == self.new_size + + def __getitem__(self, idx): + assert hasattr( + self, "_idxs_mapping" + ), "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()" + if isinstance(idx, tuple): + idx, other, another = idx + return self.dataset[self._idxs_mapping[idx], other, another] + else: + return self.dataset[self._idxs_mapping[idx]] + + @property + def _resolutions(self): + return self.dataset._resolutions + + @property + def num_views(self): + return self.dataset.num_views + + +class CatDataset(EasyDataset): + """Concatenation of several datasets""" + + def __init__(self, datasets): + for dataset in datasets: + assert isinstance(dataset, EasyDataset) + self.datasets = datasets + self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets]) + + def __len__(self): + return self._cum_sizes[-1] + + def __repr__(self): + # remove uselessly long transform + return " + ".join( + repr(dataset).replace( + ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))", + "", + ) for dataset in self.datasets) + + def set_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_epoch(epoch) + + def __getitem__(self, idx): + other = None + if isinstance(idx, tuple): + idx, other, another = idx + + if not (0 <= idx < len(self)): + raise IndexError() + + db_idx = np.searchsorted(self._cum_sizes, idx, "right") + dataset = self.datasets[db_idx] + new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) + + if other is not None and another is not None: + new_idx = (new_idx, other, another) + return dataset[new_idx] + + @property + def _resolutions(self): + resolutions = self.datasets[0]._resolutions + for dataset in self.datasets[1:]: + assert tuple(dataset._resolutions) == tuple(resolutions) + return resolutions + + @property + def num_views(self): + num_views = self.datasets[0].num_views + for dataset in self.datasets[1:]: + assert dataset.num_views == num_views + return num_views diff --git a/stream3r/dust3r/datasets_cut3r/bedlam.py b/stream3r/dust3r/datasets_cut3r/bedlam.py new file mode 100644 index 0000000000000000000000000000000000000000..8a363c90480e98fb7e116e755e64d53d364b8e15 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/bedlam.py @@ -0,0 +1,330 @@ +import os.path as osp +import pickle +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 +from safetensors.numpy import save_file, load_file + +invalid_seqs = [ + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000042", + "20221024_10_100_batch01handhair_zoom_suburb_d_seq_000059", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000079", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000978", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000081", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000268", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000089", + "20221013_3_250_batch01hand_orbit_bigOffice_seq_000189", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000034", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000889", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000293", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000067", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000904", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000434", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000044", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000013", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000396", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000012", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000082", + "20221013_3_250_batch01hand_orbit_bigOffice_seq_000120", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000324", + "20221013_3_250_batch01hand_static_bigOffice_seq_000038", + "20221012_3-10_500_batch01hand_zoom_highSchoolGym_seq_000486", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000421", + "20221013_3_250_batch01hand_orbit_bigOffice_seq_000226", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000012", + "20221013_3_250_batch01hand_orbit_bigOffice_seq_000149", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000311", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000080", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000122", + "20221012_3-10_500_batch01hand_zoom_highSchoolGym_seq_000079", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000077", + "20221014_3_250_batch01hand_orbit_archVizUI3_time15_seq_000095", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000062", + "20221013_3_250_batch01hand_static_bigOffice_seq_000015", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000095", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000119", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000297", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000011", + "20221013_3_250_batch01hand_orbit_bigOffice_seq_000196", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000316", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000283", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000085", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000287", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000163", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000804", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000842", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000027", + "20221013_3_250_batch01hand_orbit_bigOffice_seq_000182", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000982", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000029", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000031", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000025", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000250", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000785", + "20221024_10_100_batch01handhair_zoom_suburb_d_seq_000069", + "20221013_3_250_batch01hand_orbit_bigOffice_seq_000122", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000246", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000352", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000425", + "20221013_3_250_batch01hand_orbit_bigOffice_seq_000192", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000900", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000043", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000063", + "20221014_3_250_batch01hand_orbit_archVizUI3_time15_seq_000096", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000091", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000013", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000309", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000114", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000969", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000361", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000267", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000083", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000383", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000890", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000003", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000045", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000317", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000076", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000082", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000907", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000279", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000076", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000004", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000061", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000811", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000800", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000841", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000794", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000308", + "20221024_10_100_batch01handhair_zoom_suburb_d_seq_000064", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000284", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000752", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000269", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000036", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000419", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000290", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000322", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000818", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000327", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000326", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000002", + "20221024_10_100_batch01handhair_zoom_suburb_d_seq_000060", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000348", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000059", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000016", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000817", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000332", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000094", + "20221013_3_250_batch01hand_orbit_bigOffice_seq_000193", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000779", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000177", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000368", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000023", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000024", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000310", + "20221014_3_250_batch01hand_orbit_archVizUI3_time15_seq_000086", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000038", + "20221024_10_100_batch01handhair_zoom_suburb_d_seq_000071", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000768", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000017", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000053", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000097", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000856", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000827", + "20221013_3_250_batch01hand_orbit_bigOffice_seq_000161", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000084", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000106", + "20221013_3_250_batch01hand_orbit_bigOffice_seq_000207", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000007", + "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000013", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000251", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000796", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000105", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000251", + "20221019_3-8_250_highbmihand_orbit_stadium_seq_000046", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000334", + "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000453", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000373", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000283", + "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000249", +] +hdri_scenes = [ + "20221010_3_1000_batch01hand", + "20221017_3_1000_batch01hand", + "20221018_3-8_250_batch01hand", + "20221019_3_250_highbmihand", +] + + +class BEDLAM_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + # self.pose_root = os.path.join( + # os.path.dirname(ROOT), f"{os.path.basename(ROOT)}_pose" + # ) + self.pose_root = ROOT + # os.path.dirname(ROOT), f"{os.path.basename(ROOT)}_pose" + # ) + assert os.path.exists(self.pose_root) + self.video = True + self.is_metric = True + self.max_interval = 4 + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + + cache_file = osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl') + + if os.path.exists(cache_file): + with open(cache_file, + 'rb') as f: + pre_calculated_data = pickle.load(f) + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + # st() + return + + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + if scene in invalid_seqs: + continue + if any([scene.startswith(x) for x in hdri_scenes]): + continue + if "closeup" in scene: + continue + scene_dir = osp.join(self.ROOT, scene) + if not os.path.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")] + ) + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + assert len(set(self.scenes) - set(os.listdir(self.pose_root))) == 0 + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + with open(cache_file, 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list), + f, + ) + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=1.0, + fix_interval_prob=1.0, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(osp.join(self.pose_root, self.scenes[scene_id]), "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + # Load depthmap + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + depthmap[depthmap > 200.0] = 0.0 + + cam = load_file(osp.join(cam_dir, basename + ".safetensor")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.85, 0.10, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="BEDLAM", + label=self.scenes[scene_id] + "_" + basename, + instance=osp.join(rgb_dir, basename + ".png"), + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/blendedmvs.py b/stream3r/dust3r/datasets_cut3r/blendedmvs.py new file mode 100755 index 0000000000000000000000000000000000000000..d4247baa7967a573dee529e9ac227100e54cb447 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/blendedmvs.py @@ -0,0 +1,310 @@ +import os.path as osp +from pdb import set_trace as st +import numpy as np +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 +from safetensors.numpy import save_file, load_file +import h5py +from tqdm import tqdm + + +class BlendedMVS_Multi(BaseMultiViewDataset): + """Dataset of outdoor street scenes, 5 images each time""" + + def __init__(self, *args, ROOT, split=None, **kwargs): + self.ROOT = ROOT + self.video = False + self.is_metric = False + super().__init__(*args, **kwargs) + # assert split is None + self._load_data() + + def _load_data(self): + self.data_dict = self.read_h5_file(os.path.join(self.ROOT, "new_overlap.h5")) + self.num_imgs = sum( + [len(self.data_dict[s]["basenames"]) for s in self.data_dict.keys()] + ) + self.num_scenes = len(self.data_dict.keys()) + self.invalid_scenes = [] + self.is_reachable_cache = {scene: {} for scene in self.data_dict.keys()} + + def read_h5_file(self, h5_file_path): + data_dict = {} + self.all_ref_imgs = [] + with h5py.File(h5_file_path, "r") as f: # ! 9 seconds required + # st() + for scene_dir in tqdm(f.keys()): + group = f[scene_dir] + basenames = group["basenames"][:] + indices = group["indices"][:] + values = group["values"][:] + shape = group.attrs["shape"] + # Reconstruct the sparse matrix + score_matrix = np.zeros(shape, dtype=np.float32) + score_matrix[indices[0], indices[1]] = values + data_dict[scene_dir] = { + "basenames": basenames, + "score_matrix": self.build_adjacency_list(score_matrix), + } + self.all_ref_imgs.extend( + [(scene_dir, b) for b in range(len(basenames))] + ) + return data_dict + + @staticmethod + def build_adjacency_list(S, thresh=0.2): + adjacency_list = [[] for _ in range(len(S))] + S = S - thresh + S[S < 0] = 0 + rows, cols = np.nonzero(S) + for i, j in zip(rows, cols): + adjacency_list[i].append((j, S[i][j])) + return adjacency_list + + @staticmethod + def is_reachable(adjacency_list, start_index, k): + visited = set() + stack = [start_index] + while stack and len(visited) < k: + node = stack.pop() + if node not in visited: + visited.add(node) + for neighbor in adjacency_list[node]: + if neighbor[0] not in visited: + stack.append(neighbor[0]) + return len(visited) >= k + + @staticmethod + def random_sequence_no_revisit_with_backtracking( + adjacency_list, k, start_index, rng: np.random.Generator + ): + path = [start_index] + visited = set([start_index]) + + neighbor_iterators = [] + # Initialize the iterator for the start index + neighbors = adjacency_list[start_index] + neighbor_idxs = [n[0] for n in neighbors] + neighbor_weights = [n[1] for n in neighbors] + neighbor_idxs = rng.choice( + neighbor_idxs, + size=len(neighbor_idxs), + replace=False, + p=np.array(neighbor_weights) / np.sum(neighbor_weights), + ).tolist() + neighbor_iterators.append(iter(neighbor_idxs)) + + while len(path) < k: + if not neighbor_iterators: + # No possible sequence + return None + current_iterator = neighbor_iterators[-1] + try: + next_index = next(current_iterator) + if next_index not in visited: + path.append(next_index) + visited.add(next_index) + + # Prepare iterator for the next node + neighbors = adjacency_list[next_index] + neighbor_idxs = [n[0] for n in neighbors] + neighbor_weights = [n[1] for n in neighbors] + neighbor_idxs = rng.choice( + neighbor_idxs, + size=len(neighbor_idxs), + replace=False, + p=np.array(neighbor_weights) / np.sum(neighbor_weights), + ).tolist() + neighbor_iterators.append(iter(neighbor_idxs)) + except StopIteration: + # No more neighbors to try at this node, backtrack + neighbor_iterators.pop() + visited.remove(path.pop()) + return path + + @staticmethod + def random_sequence_with_optional_repeats( + adjacency_list, + k, + start_index, + rng: np.random.Generator, + max_k=None, + max_attempts=100, + ): + if max_k is None: + max_k = k + path = [start_index] + visited = set([start_index]) + current_index = start_index + attempts = 0 + + while len(path) < max_k and attempts < max_attempts: + attempts += 1 + neighbors = adjacency_list[current_index] + neighbor_idxs = [n[0] for n in neighbors] + neighbor_weights = [n[1] for n in neighbors] + + if not neighbor_idxs: + # No neighbors, cannot proceed further + break + + # Try to find unvisited neighbors + unvisited_neighbors = [ + (idx, wgt) + for idx, wgt in zip(neighbor_idxs, neighbor_weights) + if idx not in visited + ] + if unvisited_neighbors: + # Select among unvisited neighbors + unvisited_idxs = [idx for idx, _ in unvisited_neighbors] + unvisited_weights = [wgt for _, wgt in unvisited_neighbors] + probabilities = np.array(unvisited_weights) / np.sum(unvisited_weights) + next_index = rng.choice(unvisited_idxs, p=probabilities) + visited.add(next_index) + else: + # All neighbors visited, but we need to reach length max_k + # So we can revisit nodes + probabilities = np.array(neighbor_weights) / np.sum(neighbor_weights) + next_index = rng.choice(neighbor_idxs, p=probabilities) + + path.append(next_index) + current_index = next_index + + if len(set(path)) >= k: + # If path is shorter than max_k, extend it by repeating existing elements + while len(path) < max_k: + # Randomly select nodes from the existing path to repeat + next_index = rng.choice(path) + path.append(next_index) + return path + else: + # Could not reach k unique nodes + return None + + def __len__(self): + return len(self.all_ref_imgs) + + def get_image_num(self): + return self.num_imgs + + def get_stats(self): + return f"{len(self)} imgs from {self.num_scenes} scenes" + + def generate_sequence( + self, scene, adj_list, num_views, start_index, rng, allow_repeat=False + ): + # st() + cutoff = num_views if not allow_repeat else max(num_views // 5, 3) + if start_index in self.is_reachable_cache[scene]: + if not self.is_reachable_cache[scene][start_index]: + print( + f"Cannot reach {num_views} unique elements from index {start_index}." + ) + return None + else: + self.is_reachable_cache[scene][start_index] = self.is_reachable( + adj_list, start_index, cutoff + ) + if not self.is_reachable_cache[scene][start_index]: + print( + f"Cannot reach {num_views} unique elements from index {start_index}." + ) + return None + if not allow_repeat: + sequence = self.random_sequence_no_revisit_with_backtracking( + adj_list, cutoff, start_index, rng + ) + else: + sequence = self.random_sequence_with_optional_repeats( + adj_list, cutoff, start_index, rng, max_k=num_views + ) + if not sequence: + # st() + self.is_reachable_cache[scene][start_index] = False + # print("Failed to generate a sequence without revisiting.") + return sequence + + def _get_views(self, idx, resolution, rng: np.random.Generator, num_views): + scene_info, ref_img_idx = self.all_ref_imgs[idx] + invalid_seq = True + ordered_video = False + + while invalid_seq: + basenames = self.data_dict[scene_info]["basenames"] + if ( + sum( + [ + (1 - int(x)) + for x in list(self.is_reachable_cache[scene_info].values()) + ] + ) + > len(basenames) - self.num_views + ): + self.invalid_scenes.append(scene_info) + while scene_info in self.invalid_scenes: + idx = rng.integers(low=0, high=len(self.all_ref_imgs)) + scene_info, ref_img_idx = self.all_ref_imgs[idx] + basenames = self.data_dict[scene_info]["basenames"] + + score_matrix = self.data_dict[scene_info]["score_matrix"] + imgs_idxs = self.generate_sequence( + scene_info, score_matrix, num_views, ref_img_idx, rng, self.allow_repeat + ) + + if imgs_idxs is None: + random_direction = 2 * rng.choice(2) - 1 + for offset in range(1, len(basenames)): + tentative_im_idx = ( + ref_img_idx + (random_direction * offset) + ) % len(basenames) + if ( + tentative_im_idx not in self.is_reachable_cache[scene_info] + or self.is_reachable_cache[scene_info][tentative_im_idx] + ): + ref_img_idx = tentative_im_idx + break + else: + invalid_seq = False + views = [] + for view_idx in imgs_idxs: + scene_dir = osp.join(self.ROOT, scene_info) + impath = basenames[view_idx].decode("utf-8") + image = imread_cv2(osp.join(scene_dir, impath + ".jpg")) + depthmap = imread_cv2(osp.join(scene_dir, impath + ".exr")) + camera_params = load_file(osp.join(scene_dir, impath + ".safetensor")) + + intrinsics = np.float32(camera_params["intrinsics"]) + camera_pose = np.eye(4, dtype=np.float32) + camera_pose[:3, :3] = camera_params["R_cam2world"] + camera_pose[:3, 3] = camera_params["t_cam2world"] + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(scene_dir, impath) + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="BlendedMVS", + label=osp.relpath(scene_dir, self.ROOT), + is_metric=self.is_metric, + is_video=ordered_video, + instance=osp.join(scene_dir, impath + ".jpg"), + quantile=np.array(0.97, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/co3d.py b/stream3r/dust3r/datasets_cut3r/co3d.py new file mode 100755 index 0000000000000000000000000000000000000000..a363c713a34781b9c16b1d010fcf3ae6d05debd7 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/co3d.py @@ -0,0 +1,205 @@ +import os.path as osp +from pdb import set_trace as st +import json +import itertools +from collections import deque +import sys + +import cv2 +import numpy as np +import time + +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + +from safetensors.numpy import save_file, load_file + + +class Co3d_Multi(BaseMultiViewDataset): + + def __init__(self, mask_bg="rand", *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + assert mask_bg in (True, False, "rand") + self.mask_bg = mask_bg + self.is_metric = False + self.dataset_label = "Co3d_v2" + + overfitting_id = ("apple", "608_95658_192033") + # overfitting_id = ("apple", "272_29064_56962") # ! align with the chunk dataset? Not in here. + overfitting = False # whether working on a single object + # overfitting = True # whether working on a single object + + # load all scenes + with open(osp.join(self.ROOT, f"selected_seqs_{self.split}.json"), + "r") as f: + self.scenes = json.load(f) + self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0} + self.scenes = { + (k, k2): v2 + for k, v in self.scenes.items() + for k2, v2 in v.items() + } + + if overfitting: + self.scenes = {overfitting_id: self.scenes[overfitting_id]} + + self.scene_list = list(self.scenes.keys()) + cut_off = (self.num_views if not self.allow_repeat else max( + self.num_views // 3, 3)) + self.cut_off = cut_off + self.all_ref_imgs = [(key, value) + for key, values in self.scenes.items() + for value in values[:len(values) - cut_off + 1]] + self.invalidate = {scene: {} for scene in self.scene_list} + self.invalid_scenes = {scene: False for scene in self.scene_list} + + def __len__(self): + return len(self.all_ref_imgs) + + def _get_metadatapath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "images", + f"frame{view_idx:06n}.safetensor") + + def _get_impath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "images", + f"frame{view_idx:06n}.jpg") + + def _get_depthpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "depths", + f"frame{view_idx:06n}.jpg.geometric.png") + + def _get_maskpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "masks", + f"frame{view_idx:06n}.png") + + def _read_depthmap(self, depthpath, input_metadata): + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num( + input_metadata["maximum_depth"]) + return depthmap + + def _get_views(self, idx, resolution, rng, num_views): + invalid_seq = True + scene_info, ref_img_idx = self.all_ref_imgs[idx] + + while invalid_seq: + while self.invalid_scenes[scene_info]: + idx = rng.integers(low=0, high=len(self.all_ref_imgs)) + scene_info, ref_img_idx = self.all_ref_imgs[idx] + + obj, instance = scene_info + + image_pool = self.scenes[obj, instance] + if len(image_pool) < self.cut_off: + print("Invalid scene!") + self.invalid_scenes[scene_info] = True + continue + + imgs_idxs, ordered_video = self.get_seq_from_start_id( + num_views, ref_img_idx, image_pool, rng) + + if resolution not in self.invalidate[ + obj, instance]: # flag invalid images + self.invalidate[obj, instance][resolution] = [ + False for _ in range(len(image_pool)) + ] + # decide now if we mask the bg + mask_bg = (self.mask_bg + == True) or (self.mask_bg == "rand" + and rng.choice(2, p=[0.9, 0.1])) + views = [] + + imgs_idxs = deque(imgs_idxs) + + while len(imgs_idxs) > 0: # some images (few) have zero depth + if (len(image_pool) - + sum(self.invalidate[obj, instance][resolution]) + < self.cut_off): + print("Invalid scene!") + invalid_seq = True + self.invalid_scenes[scene_info] = True + break + + im_idx = imgs_idxs.pop() + if self.invalidate[obj, instance][resolution][im_idx]: + # search for a valid image + ordered_video = False + random_direction = 2 * rng.choice(2) - 1 + for offset in range(1, len(image_pool)): + tentative_im_idx = ( + im_idx + + (random_direction * offset)) % len(image_pool) + if not self.invalidate[ + obj, instance][resolution][tentative_im_idx]: + im_idx = tentative_im_idx + break + view_idx = image_pool[im_idx] + impath = self._get_impath(obj, instance, view_idx) + depthpath = self._get_depthpath(obj, instance, view_idx) + + # load camera params + metadata_path = self._get_metadatapath(obj, instance, view_idx) + # input_metadata = np.load(metadata_path) + # st() + input_metadata = load_file(metadata_path) + camera_pose = input_metadata["camera_pose"].astype(np.float32) + intrinsics = input_metadata["camera_intrinsics"].astype( + np.float32) + + # load image and depth + rgb_image = imread_cv2(impath) + depthmap = self._read_depthmap(depthpath, input_metadata) + + if mask_bg: + # load object mask + maskpath = self._get_maskpath(obj, instance, view_idx) + maskmap = imread_cv2( + maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32) + maskmap = (maskmap / 255.0) > 0.1 + + # update the depthmap with mask + depthmap *= maskmap + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, + depthmap, + intrinsics, + resolution, + rng=rng, + info=impath) + num_valid = (depthmap > 0.0).sum() + if num_valid == 0: + # problem, invalidate image and retry + self.invalidate[obj, instance][resolution][im_idx] = True + imgs_idxs.append(im_idx) + continue + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, len(views), rng) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset=self.dataset_label, + label=osp.join(obj, instance), + instance=osp.split(impath)[1], + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.9, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + )) + + if len(views) == num_views and not all( + [view["instance"] == views[0]["instance"] for view in views]): + invalid_seq = False + return views \ No newline at end of file diff --git a/stream3r/dust3r/datasets_cut3r/cop3d.py b/stream3r/dust3r/datasets_cut3r/cop3d.py new file mode 100755 index 0000000000000000000000000000000000000000..5894d25da5b9c191ecbbf170904bc963449140ec --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/cop3d.py @@ -0,0 +1,112 @@ +import os.path as osp +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np + +from stream3r.dust3r.datasets_cut3r.co3d import Co3d_Multi +from stream3r.dust3r.utils.image import imread_cv2 + +from safetensors.numpy import save_file, load_file + + +class Cop3D_Multi(Co3d_Multi): + def __init__(self, mask_bg="rand", *args, ROOT, **kwargs): + super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs) + self.dataset_label = "Cop3D" + self.is_metric = False + + def _get_metadatapath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06n}.safetensor") + + def _get_impath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06n}.jpg") + + def _get_depthpath(self, obj, instance, view_idx): + # no depth, pseduo path just for getting the right resolution + return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06n}.jpg") + + def _get_maskpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "masks", f"frame{view_idx:06n}.png") + + def _read_depthmap(self, impath, input_metadata): + # no depth, set to all ones + img = imread_cv2(impath, cv2.IMREAD_UNCHANGED) + depthmap = np.ones_like(img[..., 0], dtype=np.float32) + return depthmap + + def _get_views(self, idx, resolution, rng, num_views): + invalid_seq = True + scene_info, ref_img_idx = self.all_ref_imgs[idx] + + while invalid_seq: + while self.invalid_scenes[scene_info]: + idx = rng.integers(low=0, high=len(self.all_ref_imgs)) + scene_info, ref_img_idx = self.all_ref_imgs[idx] + + obj, instance = scene_info + + image_pool = self.scenes[obj, instance] + if len(image_pool) < self.num_views: + print("Invalid scene!") + self.invalid_scenes[scene_info] = True + continue + + imgs_idxs, ordered_video = self.get_seq_from_start_id( + num_views, + ref_img_idx, + image_pool, + rng, + max_interval=5, + video_prob=1.0, + fix_interval_prob=0.9, + ) + + views = [] + + for im_idx in imgs_idxs: + view_idx = image_pool[im_idx] + impath = self._get_impath(obj, instance, view_idx) + depthpath = self._get_depthpath(obj, instance, view_idx) + + # load camera params + metadata_path = self._get_metadatapath(obj, instance, view_idx) + input_metadata = load_file(metadata_path) + camera_pose = input_metadata["camera_pose"].astype(np.float32) + intrinsics = input_metadata["camera_intrinsics"].astype(np.float32) + + # load image and depth + rgb_image = imread_cv2(impath) + depthmap = self._read_depthmap(depthpath, input_metadata) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset=self.dataset_label, + label=osp.join(obj, instance), + instance=osp.split(impath)[1], + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.96, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=True, + depth_only=False, + single_view=False, + reset=False, + ) + ) + + if len(views) == num_views and not all( + [view["instance"] == views[0]["instance"] for view in views] + ): + invalid_seq = False + return views diff --git a/stream3r/dust3r/datasets_cut3r/dl3dv.py b/stream3r/dust3r/datasets_cut3r/dl3dv.py new file mode 100644 index 0000000000000000000000000000000000000000..c259166d73d31a34845154120d1fca568a40040a --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/dl3dv.py @@ -0,0 +1,202 @@ +import os.path as osp +from pdb import set_trace as st +import pickle +import os +import sys +import itertools + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np + +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class DL3DV_Multi(BaseMultiViewDataset): + + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.max_interval = 20 + self.is_metric = False + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data() + + def _load_data(self): + + if os.path.exists( + osp.join(self.ROOT, + f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open( + osp.join(self.ROOT, + f'pre-calculated-loaddata-{self.num_views}.pkl'), + 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + return + + self.all_scenes = sorted([ + f for f in os.listdir(self.ROOT) + if os.path.isdir(osp.join(self.ROOT, f)) + ]) + subscenes = [] + for scene in self.all_scenes: + # not empty + subscenes.extend([ + osp.join(scene, f) + for f in os.listdir(osp.join(self.ROOT, scene)) + if os.path.isdir(osp.join(self.ROOT, scene, f)) + and len(os.listdir(osp.join(self.ROOT, scene, f))) > 0 + ]) + + offset = 0 + scenes = [] + sceneids = [] + images = [] + scene_img_list = [] + start_img_ids = [] + j = 0 + + for scene_idx, scene in enumerate(subscenes): + scene_dir = osp.join(self.ROOT, scene, "dense") + rgb_paths = sorted([ + f for f in os.listdir(os.path.join(scene_dir, "rgb")) + if f.endswith(".png") + ]) + skip_flag = False + for sub_dir in ['cam', 'depth', 'outlier_mask', 'sky_mask']: + if len(os.listdir(os.path.join(scene_dir, + sub_dir))) != len(rgb_paths): + print('dl3dv ignore ', scene_dir, sub_dir) + skip_flag = True + break + if skip_flag: + # st() + continue + + assert len(rgb_paths) > 0, f"{scene_dir} is empty." + num_imgs = len(rgb_paths) + cut_off = (self.num_views if not self.allow_repeat else max( + self.num_views // 3, 3)) + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[:num_imgs - cut_off + 1] + + scenes.append(scene) + scene_img_list.append(img_ids) + sceneids.extend([j] * num_imgs) + images.extend(rgb_paths) + start_img_ids.extend(start_img_ids_) + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + with open( + osp.join(self.ROOT, + f'pre-calculated-loaddata-{self.num_views}.pkl'), + 'wb') as f: + pickle.dump( + dict( + scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list, + ), + f, + ) + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + scene_id = self.sceneids[start_id] + all_image_ids = self.scene_img_list[scene_id] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + block_shuffle=25, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for view_idx in image_idxs: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id], "dense") + + rgb_path = self.images[view_idx] + basename = rgb_path[:-4] + + rgb_image = imread_cv2(osp.join(scene_dir, "rgb", rgb_path), + cv2.IMREAD_COLOR) + depthmap = np.load(osp.join(scene_dir, "depth", + basename + ".npy")).astype(np.float32) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + cam_file = np.load(osp.join(scene_dir, "cam", basename + ".npz")) + sky_mask = (cv2.imread(osp.join(scene_dir, "sky_mask", rgb_path), + cv2.IMREAD_UNCHANGED) >= 127) + outlier_mask = cv2.imread( + osp.join(scene_dir, "outlier_mask", rgb_path), + cv2.IMREAD_UNCHANGED) + depthmap[sky_mask] = -1.0 + depthmap[outlier_mask >= 127] = 0.0 + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + threshold = (np.percentile(depthmap[depthmap > 0], 98) + if depthmap[depthmap > 0].size > 0 else 0) + depthmap[depthmap > threshold] = 0.0 + + intrinsics = cam_file["intrinsic"].astype(np.float32) + camera_pose = cam_file["pose"].astype(np.float32) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, + depthmap, + intrinsics, + resolution, + rng=rng, + info=view_idx) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="dl3dv", + label=self.scenes[scene_id] + "_" + rgb_path, + instance=osp.join(scene_dir, "rgb", rgb_path), + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.9, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + )) + return views diff --git a/stream3r/dust3r/datasets_cut3r/dynamic_replica.py b/stream3r/dust3r/datasets_cut3r/dynamic_replica.py new file mode 100755 index 0000000000000000000000000000000000000000..58a7df7329eaa5c16d830e144cc6c653fbe60302 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/dynamic_replica.py @@ -0,0 +1,231 @@ +import os.path as osp +import pickle +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 +from stream3r.dust3r.utils.geometry import inv + +from safetensors.numpy import save_file, load_file + + +class DynamicReplica(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 16 + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + return + + self.scenes = os.listdir(os.path.join(self.ROOT, split)) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, self.split, scene, "left") + if not os.path.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")], + key=lambda x: float(x), + ) + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list,), + f, + ) + + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=1.0, + fix_interval_prob=1.0, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id], "left") + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + # Load depthmap + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + cam = load_file(osp.join(cam_dir, basename + ".safetensor")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + + # ============fix + pytorch3d_w2c = inv(camera_pose) + pytorch3d_w2c[:2, 3] *= -1 + + opencv_w2c = np.eye(4) + opencv_w2c[:3, :3] = pytorch3d_w2c[:3, :3].T + opencv_w2c[:3, 3] = pytorch3d_w2c[:3, 3] + opencv_w2c[:2] *= -1 + + camera_pose = inv(opencv_w2c) + # ============fix done + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.85, 0.10, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="dynamic_replica", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views + + +if __name__ == "__main__": + import torch + import pause + from torchvision.transforms import ToPILImage + from stream3r.dust3r.datasets.base.base_stereo_view_dataset import view_name + from stream3r.dust3r.utils.image import rgb + from stream3r.dust3r.viz import SceneViz, auto_cam_size + from IPython.display import display + from stream3r.dust3r.datasets.utils.transforms import ImgNorm, convert_input_to_pred_format, vis_track + from stream3r.dust3r.utils.geometry import ( + geotrf, + inv, + ) + from stream3r.viz.viser_visualizer_track import start_visualization + + def main(): + dataset = DynamicReplica( + split="train", allow_repeat=False, ROOT="/mnt/storage/yslan-data/cut3r_processed/processed_dynamic_replica/", + aug_crop=0, resolution=(512, 384), num_views=20, transform=ImgNorm + ) + + # import random + # for i in random.sample(range(len(dataset)), 100): + # views = dataset[i] + # print(i) + + select_idx = 1 + views = dataset[select_idx] + output = convert_input_to_pred_format(views) + + # save_path = os.path.join("develop/2d_compare/test_data", views[0]['dataset'] + str(select_idx)) + # os.makedirs(save_path, exist_ok=True) + # for i in range(len(views)): + # print(view_name(views[i])) + # ToPILImage()(rgb(views[i]["img"])).save(f"{save_path}/{i}.png") + + server = start_visualization( + output=output, + min_conf_thr_percentile=0, + global_conf_thr_value_to_drop_view=1, + point_size=0.0016, + ) + + # share_url = servers.request_share_url() + # print(share_url) + + pause.days(1) + + main() \ No newline at end of file diff --git a/stream3r/dust3r/datasets_cut3r/eden.py b/stream3r/dust3r/datasets_cut3r/eden.py new file mode 100644 index 0000000000000000000000000000000000000000..b2558c524accb886d2ee5fb38192ebfed339598b --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/eden.py @@ -0,0 +1,169 @@ +import os.path as osp +from pdb import set_trace as st +import pickle +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 +from safetensors.numpy import save_file, load_file + + +class EDEN_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.img_names = pre_calculated_data['img_names'] + + return + + scenes = os.listdir(self.ROOT) + self.scenes = scenes + img_names = [] + for scene in scenes: + scene_dir = osp.join(self.ROOT, scene) + if not os.path.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")] + ) + img_names.extend([(scene, basename) for basename in basenames]) + + self.img_names = img_names + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + img_names=self.img_names), + f, + ) + + def __len__(self): + return len(self.img_names) + + def get_image_num(self): + return len(self.img_names) + + def _get_views(self, idx, resolution, rng, num_views): + + new_seed = rng.integers(0, 2**32) + idx + new_rng = np.random.default_rng(new_seed) + img_names = new_rng.permutation(self.img_names) + + views = [] + i = 0 + while len(views) < num_views: + # Load RGB image + scene, img_name = img_names[i] + try: + rgb_image = imread_cv2( + osp.join(self.ROOT, scene, "rgb", f"{img_name}.png") + ) + depthmap = np.load( + osp.join(self.ROOT, scene, "depth", f"{img_name}.npy") + ) + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + + intrinsics = load_file( + osp.join(self.ROOT, scene, "cam", f"{img_name}.safetensor") + )["intrinsics"] + # camera pose is not provided, placeholder + camera_pose = np.eye(4) + except: + i += 1 + continue + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="EDEN", + label=img_name, + instance=osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"), + is_metric=self.is_metric, + is_video=False, + quantile=np.array(1.0, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + ) + ) + i += 1 + return views + + +if __name__ == "__main__": + import torch + import pause + from torchvision.transforms import ToPILImage + from stream3r.dust3r.datasets.base.base_stereo_view_dataset import view_name + from stream3r.dust3r.utils.image import rgb + from stream3r.dust3r.viz import SceneViz, auto_cam_size + from IPython.display import display + from stream3r.dust3r.datasets.utils.transforms import ImgNorm, convert_input_to_pred_format, vis_track + from stream3r.dust3r.utils.geometry import ( + geotrf, + inv, + ) + from stream3r.viz.viser_visualizer_track import start_visualization + + def main(): + dataset = EDEN_Multi( + split="train", allow_repeat=False, ROOT="/mnt/storage/yslan-data/cut3r_processed/processed_eden/", + aug_crop=0, resolution=(512, 384), num_views=20, transform=ImgNorm + ) + + # import random + # for i in random.sample(range(len(dataset)), 100): + # views = dataset[i] + # print(i) + + select_idx = 1 + views = dataset[select_idx] + output = convert_input_to_pred_format(views) + + # save_path = os.path.join("develop/2d_compare/test_data", views[0]['dataset'] + str(select_idx)) + # os.makedirs(save_path, exist_ok=True) + # for i in range(len(views)): + # print(view_name(views[i])) + # ToPILImage()(rgb(views[i]["img"])).save(f"{save_path}/{i}.png") + + server = start_visualization( + output=output, + min_conf_thr_percentile=0, + global_conf_thr_value_to_drop_view=1, + point_size=0.0016, + ) + + # share_url = servers.request_share_url() + # print(share_url) + + pause.days(1) + + main() \ No newline at end of file diff --git a/stream3r/dust3r/datasets_cut3r/habitat.py b/stream3r/dust3r/datasets_cut3r/habitat.py new file mode 100644 index 0000000000000000000000000000000000000000..086eb866d247b41386e41ed1a29afdda5cc4961a --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/habitat.py @@ -0,0 +1,114 @@ +import os +import cv2 +import json +import numpy as np +import os.path as osp +from collections import deque +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +# from dust3r.datasets.base.base_stereo_view_dataset import BaseMultiViewDataset +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset + +from stream3r.dust3r.utils.image import imread_cv2 +# from .base_many_view_dataset import BaseManyViewDataset +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset + + +class Habitat(BaseMultiViewDataset): + def __init__(self, num_seq=200, *args, ROOT, **kwargs): # num_views=5, + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self.num_seq = num_seq + # self.num_frames = num_views + self.is_metric = True + + # load all scenes + self.load_all_scenes(ROOT, num_seq) + + def __len__(self): + return len(self.scene_list) * self.num_seq + + def load_all_scenes(self, base_dir, num_seq=200): + + self.scenes = {} + + data_all = os.listdir(base_dir) + print('All datasets in Habitat:', data_all) + + for data in data_all: + scenes = os.listdir(osp.join(base_dir, data)) + self.scenes[data] = scenes + + self.scenes = {(k, v2): list(range(num_seq)) for k, v in self.scenes.items() + for v2 in v} + self.scene_list = list(self.scenes.keys()) + + def _get_views(self, idx, resolution, rng, attempts=0): + data, scene = self.scene_list[idx // self.num_seq] + seq_id = idx % self.num_seq + + + imgs_idxs_ = list(range(1, self.num_views+1)) + rng.shuffle(imgs_idxs_) + imgs_idxs = deque(imgs_idxs_) + + views = [] + + while len(imgs_idxs) > 0: + im_idx = imgs_idxs.popleft() + + impath = osp.join(self.ROOT, data, scene, f"{seq_id:08}_{im_idx}.jpeg") + depthpath = osp.join(self.ROOT, data, scene, f"{seq_id:08}_{im_idx}_depth.exr") + cam_params_path = osp.join(self.ROOT, data, scene, f"{seq_id:08}_{im_idx}_camera_params.json") + + if not osp.exists(impath): + new_idx = rng.integers(0, self.__len__()-1) + return self._get_views(new_idx, resolution, rng) + + rgb_image = imread_cv2(impath) + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + + cam_params = json.load(open(cam_params_path, 'r')) + intrinsics_ = np.array(cam_params['camera_intrinsics'], dtype=np.float32) + + # cam_r: [3, 3], cam_t: [3, ] + cam_r = np.array(cam_params['R_cam2world'], dtype=np.float32) + cam_t = np.array(cam_params['t_cam2world'], dtype=np.float32) + + # camera_pose: [4, 4] + camera_pose = np.eye(4).astype(np.float32) + camera_pose[:3, :3] = cam_r + camera_pose[:3, 3] = cam_t + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath) + + num_valid = (depthmap > 0.0).sum() + if num_valid == 0 or (not np.isfinite(camera_pose).all()): + if attempts >= 5: + new_idx = rng.integers(0, self.__len__()-1) + return self._get_views(new_idx, resolution, rng) + return self._get_views(idx, resolution, rng, attempts+1) + + views.append(dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset='habitat', + label=osp.join(data, scene), + instance=osp.split(impath)[1], + is_metric=self.is_metric, + is_video=False, + quantile=np.array(0.98, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + )) + return views + diff --git a/stream3r/dust3r/datasets_cut3r/hoi4d.py b/stream3r/dust3r/datasets_cut3r/hoi4d.py new file mode 100644 index 0000000000000000000000000000000000000000..8442a75208d2c8c98f7fe3ca8a6b68e506dcacee --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/hoi4d.py @@ -0,0 +1,108 @@ +import os.path as osp +from pdb import set_trace as st +import pickle +import cv2 +import numpy as np +import itertools +import os +import sys +sys.path.append(osp.join(osp.dirname(__file__), '..','..')) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 +from safetensors.numpy import save_file, load_file + + +class HOI4D_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + + cache_file = osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl') + + if os.path.exists(cache_file): + with open(cache_file, + 'rb') as f: + pre_calculated_data = pickle.load(f) + self.img_names = pre_calculated_data['img_names'] + # st() + return + + else: + scenes = os.listdir(self.ROOT) + img_names = [] + for scene in scenes: + scene_dir = osp.join(self.ROOT, scene) + if not os.path.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, 'rgb') + basenames = sorted([f[:-4] for f in os.listdir(rgb_dir) if f.endswith('.png')]) + img_names.extend([(scene, basename) for basename in basenames]) + + self.img_names = img_names + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), + 'wb') as f: + pickle.dump(dict(img_names=self.img_names), f) + + # st() + # pass + + def __len__(self): + return len(self.img_names) + + def get_image_num(self): + return len(self.img_names) + + def _get_views(self, idx, resolution, rng, num_views): + new_seed = rng.integers(0, 2**32) + idx + new_rng = np.random.default_rng(new_seed) + invalid_seq = True + while invalid_seq: + img_names = new_rng.choice(self.img_names, num_views, replace=False) + + views = [] + for v, img_name in enumerate(img_names): + # Load RGB image + scene, img_name = img_name + try: + rgb_image = imread_cv2(osp.join(self.ROOT, scene, "rgb", f"{img_name}.png")) + depthmap = np.load(osp.join(self.ROOT, scene, "depth", f"{img_name}.npy")) + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + + intrinsics = load_file(osp.join(self.ROOT, scene, "cam", f"{img_name}.safetensor"))["intrinsics"] + except: + print(f"Error loading {scene} {img_name}, skipping") + break + # camera pose is not provided, placeholder + camera_pose = np.eye(4) + + rgb_image, depthmap, intrinsics= self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name) + + views.append(dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset='HOI4D', + label=img_name, + instance=osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"), + is_metric=self.is_metric, + is_video=False, + quantile=np.array(0.99, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + )) + if len(views) == num_views: + invalid_seq = False + return views diff --git a/stream3r/dust3r/datasets_cut3r/hypersim.py b/stream3r/dust3r/datasets_cut3r/hypersim.py new file mode 100755 index 0000000000000000000000000000000000000000..55176fe9f8feeeb7637152288d85584ca41bdbc1 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/hypersim.py @@ -0,0 +1,224 @@ +import os.path as osp +import pickle +import json +from pdb import set_trace as st +import os +import sys +import itertools + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np + +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + +from safetensors.numpy import save_file, load_file + +broken_scenes = [ + 'ai_003_001', + 'ai_004_009', + 'ai_015_006', + 'ai_038_007', + 'ai_046_001', + 'ai_046_009', + 'ai_048_004', + 'ai_053_005', + 'ai_012_007', + 'ai_013_001', + 'ai_023_008', + 'ai_026_020', + 'ai_023_009', + 'ai_038_007', + 'ai_023_004', + 'ai_023_006', + 'ai_026_013', + 'ai_026_018', +] + + +class HyperSim_Multi(BaseMultiViewDataset): + + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 4 + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data() + # https://github.com/apple/ml-hypersim/issues/22 + + def _load_data(self): + + if os.path.exists( + osp.join(self.ROOT, + f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open( + osp.join(self.ROOT, + f'pre-calculated-loaddata-{self.num_views}.pkl'), + 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + # st() + + else: + + self.all_scenes = sorted([ + f for f in os.listdir(self.ROOT) + if os.path.isdir(osp.join(self.ROOT, f)) + ]) + subscenes = [] + for scene in self.all_scenes: + if scene in broken_scenes: + print('hypersim ignore: ', scene) + continue + # not empty + subscenes.extend([ + osp.join(scene, f) + for f in os.listdir(osp.join(self.ROOT, scene)) + if os.path.isdir(osp.join(self.ROOT, scene, f)) + and len(os.listdir(osp.join(self.ROOT, scene, f))) > 0 + ]) + + offset = 0 + scenes = [] + sceneids = [] + images = [] + start_img_ids = [] + scene_img_list = [] + j = 0 + for scene_idx, scene in enumerate(subscenes): + scene_dir = osp.join(self.ROOT, scene) + if not os.path.isdir(scene_dir): + continue + rgb_paths = sorted( + [f for f in os.listdir(scene_dir) if f.endswith(".png")]) + assert len(rgb_paths) > 0, f"{scene_dir} is empty." + num_imgs = len(rgb_paths) + cut_off = (self.num_views if not self.allow_repeat else max( + self.num_views // 3, 3)) + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[:num_imgs - cut_off + 1] + + scenes.append(scene) + scene_img_list.append(img_ids) + sceneids.extend([j] * num_imgs) + images.extend(rgb_paths) + start_img_ids.extend(start_img_ids_) + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.scene_img_list = scene_img_list + self.start_img_ids = start_img_ids + # st() + + # save_file( {k: np.array(v) for k, v in dict(scenes=self.scenes, + # sceneids=self.sceneids, + # images=images, + # start_img_ids=start_img_ids, + # scene_img_list=scene_img_list,)}, + # osp.join(self.ROOT, 'pre-calculated-loaddata.safetensor') + # ) + + with open( + osp.join(self.ROOT, + f'pre-calculated-loaddata-{self.num_views}.pkl'), + 'wb') as f: + pickle.dump( + dict( + scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list, + ), + f, + ) + + pass + + def __len__(self): + return len(self.start_img_ids) * 10 + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + idx = idx // 10 + start_id = self.start_img_ids[idx] + scene_id = self.sceneids[start_id] + all_image_ids = self.scene_img_list[scene_id] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + block_shuffle=16, + ) + image_idxs = np.array(all_image_ids)[pos] + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + + rgb_path = self.images[view_idx] + depth_path = rgb_path.replace("rgb.png", "depth.npy") + # cam_path = rgb_path.replace("rgb.png", "cam.safetensor") + cam_path = rgb_path.replace("rgb.png", "cam.npz") + + rgb_image = imread_cv2(osp.join(scene_dir, rgb_path), + cv2.IMREAD_COLOR) + depthmap = np.load(osp.join(scene_dir, + depth_path)).astype(np.float32) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + # cam_file = load_file(osp.join(scene_dir, cam_path)) + cam_file = np.load(osp.join(scene_dir, cam_path)) + intrinsics = cam_file["intrinsics"].astype(np.float32) + camera_pose = cam_file["pose"].astype(np.float32) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, + depthmap, + intrinsics, + resolution, + rng=rng, + info=view_idx) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05]) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="hypersim", + label=self.scenes[scene_id] + "_" + rgb_path, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + )) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/irs.py b/stream3r/dust3r/datasets_cut3r/irs.py new file mode 100644 index 0000000000000000000000000000000000000000..8a531e903aee75bd21cc84286985c37b578bc3b0 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/irs.py @@ -0,0 +1,137 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class IRS(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = False + self.is_metric = True + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + scenes = os.listdir(self.ROOT) + img_names = [] + for scene in scenes: + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")] + ) + img_names.extend([(scene, basename) for basename in basenames]) + + self.img_names = img_names + + def __len__(self): + return len(self.img_names) + + def get_image_num(self): + return len(self.img_names) + + def _get_views(self, idx, resolution, rng, num_views): + new_seed = rng.integers(0, 2**32) + idx + new_rng = np.random.default_rng(new_seed) + img_names = new_rng.choice(self.img_names, num_views, replace=False) + + views = [] + for v, img_name in enumerate(img_names): + # Load RGB image + scene, img_name = img_name + rgb_image = imread_cv2(osp.join(self.ROOT, scene, "rgb", f"{img_name}.png")) + depthmap = np.load(osp.join(self.ROOT, scene, "depth", f"{img_name}.npy")) + depthmap[depthmap > 200] = 0.0 + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + + intrinsics = np.load(osp.join(self.ROOT, scene, "cam", f"{img_name}.npz"))[ + "intrinsics" + ] + # camera pose is not provided, placeholder + camera_pose = np.eye(4) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="irs", + label=img_name, + instance=f"{str(idx)}_{img_name}", + is_metric=self.is_metric, + is_video=False, + quantile=np.array(1.0, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + ) + ) + assert len(views) == num_views + return views + + +if __name__ == "__main__": + import torch + import pause + from torchvision.transforms import ToPILImage + from stream3r.dust3r.datasets.base.base_stereo_view_dataset import view_name + from stream3r.dust3r.utils.image import rgb + from stream3r.dust3r.viz import SceneViz, auto_cam_size + from IPython.display import display + from stream3r.dust3r.datasets.utils.transforms import ImgNorm, convert_input_to_pred_format, vis_track + from stream3r.dust3r.utils.geometry import ( + geotrf, + inv, + ) + from stream3r.viz.viser_visualizer_track import start_visualization + + def main(): + dataset = IRS( + split="train", allow_repeat=False, ROOT="/mnt/storage/yslan-data/cut3r_processed/processed_irs/Store/", + aug_crop=0, resolution=(512, 384), num_views=20, transform=ImgNorm + ) + + # import random + # for i in random.sample(range(len(dataset)), 100): + # views = dataset[i] + # print(i) + + select_idx = 1 + views = dataset[select_idx] + output = convert_input_to_pred_format(views) + + # save_path = os.path.join("develop/2d_compare/test_data", views[0]['dataset'] + str(select_idx)) + # os.makedirs(save_path, exist_ok=True) + # for i in range(len(views)): + # print(view_name(views[i])) + # ToPILImage()(rgb(views[i]["img"])).save(f"{save_path}/{i}.png") + + server = start_visualization( + output=output, + min_conf_thr_percentile=0, + global_conf_thr_value_to_drop_view=1, + point_size=0.0016, + ) + + # share_url = servers.request_share_url() + # print(share_url) + + pause.days(1) + + main() \ No newline at end of file diff --git a/stream3r/dust3r/datasets_cut3r/mapfree.py b/stream3r/dust3r/datasets_cut3r/mapfree.py new file mode 100644 index 0000000000000000000000000000000000000000..754a120726eff18e671e1c0a0c132a173c3a0725 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/mapfree.py @@ -0,0 +1,323 @@ +import os.path as osp +import pickle +import numpy as np +import cv2 +import numpy as np +import itertools +import os +import sys +import pickle +import h5py +from tqdm import tqdm + +from pdb import set_trace as st +from safetensors.numpy import save_file, load_file + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + +from safetensors.numpy import save_file, load_file + + +class MapFree_Multi(BaseMultiViewDataset): + + def __init__(self, ROOT, *args, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 30 + super().__init__(*args, **kwargs) + + self._load_data() + + def imgid2path(self, img_id, scene): + first_seq_id, first_frame_id = img_id + return os.path.join( + self.ROOT, + scene, + f"dense_{first_seq_id}", + "rgb", + f"frame_{first_frame_id:05d}.jpg", + ) + + def path2imgid(self, subscene, filename): + # st() + first_seq_id = int(subscene[6:]) # dense_ + first_frame_id = int(filename[6:-4]) + return [first_seq_id, first_frame_id] + + def _load_data(self): + # cache_file = f"{self.ROOT}/cached_metadata_50_col_only_fix.h5" + # cache_file = f"cached_metadata_50_col_only_fix.h5" + cache_file = f"{self.ROOT}/cached_metadata_50_col_only_fix.pkl" + # cache_file = "cached_metadata_50_col_only.h5" + # st() + if os.path.exists(cache_file): + # if False: + # print(f"Loading cached metadata from {cache_file}") + # with h5py.File(cache_file, "r") as hf: + # self.scenes = list(map(lambda x: x.decode("utf-8"), hf["scenes"][:])) + # self.sceneids = hf["sceneids"][:] + # self.scope = hf["scope"][:] + # self.video_flags = hf["video_flags"][:] + # self.groups = hf["groups"][:] + # self.id_ranges = hf["id_ranges"][:] + # self.images = hf["images"][:] + # ! save to pkl for /oss loading + # with open(f"{self.ROOT}/cached_metadata_50_col_only_fix.pkl", 'wb') as f: + # pickle.dump(dict( + # scenes=self.scenes, + # sceneids=self.sceneids, + # scope=self.scope, + # video_flags=self.video_flags, + # groups=self.groups, + # id_ranges=self.id_ranges, + # images=self.images + # ),f,) + + # with open(f"{self.ROOT}/cached_metadata_50_col_only_fix.pkl", "rb") as f: + with open(cache_file, "rb") as f: + hf = pickle.load(f) + self.scenes = hf['scenes'] + self.sceneids = hf["sceneids"][:] + self.scope = hf["scope"][:] + self.video_flags = hf["video_flags"][:] + self.groups = hf["groups"][:] + self.id_ranges = hf["id_ranges"][:] + self.images = hf["images"][:] + # st() + + else: + scene_dirs = sorted( + [ + d + for d in os.listdir(self.ROOT) + if os.path.isdir(os.path.join(self.ROOT, d)) + ] + ) + scenes = [] + sceneids = [] + groups = [] + scope = [] + images = [] + id_ranges = [] + is_video = [] + start = 0 + j = 0 + offset = 0 + + for scene in tqdm(scene_dirs): + scenes.append(scene) + # video sequences + subscenes = sorted( + [ + d + for d in os.listdir(os.path.join(self.ROOT, scene)) + if d.startswith("dense_") # dense_0, dense_1 + ] + ) + id_range_subscenes = [] + for subscene in subscenes: + rgb_paths = sorted( + [ + d + for d in os.listdir( + os.path.join(self.ROOT, scene, subscene, "rgb") + ) + if (d.endswith(".jpg") and os.path.exists(os.path.join(self.ROOT, scene, subscene, "depth", d.replace('jpg', 'npy')))) + ] + ) + # st() + assert ( + len(rgb_paths) > 0 + ), f"{os.path.join(self.ROOT, scene, subscene)} is empty." + num_imgs = len(rgb_paths) + images.extend( + [self.path2imgid(subscene, rgb_path) for rgb_path in rgb_paths] + ) + id_range_subscenes.append((offset, offset + num_imgs)) + offset += num_imgs + + # image collections + metadata = pickle.load( + open(os.path.join(self.ROOT, scene, "metadata.pkl"), "rb") + ) + ref_imgs = list(metadata.keys()) + img_groups = [] + for ref_img in ref_imgs: + other_imgs = metadata[ref_img] + if len(other_imgs) + 1 < self.num_views: + continue + group = [(*other_img[0], other_img[1]) for other_img in other_imgs] + group.insert(0, (*ref_img, 1)) + img_groups.append(np.array(group)) + id_ranges.append(id_range_subscenes[ref_img[0]]) + scope.append(start) + start = start + len(group) + + num_groups = len(img_groups) + sceneids.extend([j] * num_groups) + groups.extend(img_groups) + is_video.extend([False] * num_groups) + j += 1 + + self.scenes = np.array(scenes) + self.sceneids = np.array(sceneids) + self.scope = np.array(scope) + self.video_flags = np.array(is_video) + self.groups = np.concatenate(groups, 0) + self.id_ranges = np.array(id_ranges) + self.images = np.array(images) + + data = dict( + scenes=self.scenes, + sceneids=self.sceneids, + scope=self.scope, + video_flags=self.video_flags, + groups=self.groups, + id_ranges=self.id_ranges, + images=self.images, + ) + + # st() + with h5py.File(cache_file, "w") as h5f: + h5f.create_dataset( + "scenes", + data=data["scenes"].astype(object), + dtype=h5py.string_dtype(encoding="utf-8"), + compression="lzf", + chunks=True, + ) + h5f.create_dataset( + "sceneids", data=data["sceneids"], compression="lzf", chunks=True + ) + h5f.create_dataset( + "scope", data=data["scope"], compression="lzf", chunks=True + ) + h5f.create_dataset( + "video_flags", + data=data["video_flags"], + compression="lzf", + chunks=True, + ) + h5f.create_dataset( + "groups", data=data["groups"], compression="lzf", chunks=True + ) + h5f.create_dataset( + "id_ranges", data=data["id_ranges"], compression="lzf", chunks=True + ) + h5f.create_dataset( + "images", data=data["images"], compression="lzf", chunks=True + ) + # st() + + def __len__(self): + return len(self.scope) + + def get_image_num(self): + return len(self.images) + + def get_stats(self): + return f"{len(self)} groups of views" + + def _get_views(self, idx, resolution, rng, num_views): + scene = self.scenes[self.sceneids[idx]] + if rng.random() < 0.6: + ids = np.arange(self.id_ranges[idx][0], self.id_ranges[idx][1]) + cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3) + start_ids = ids[: len(ids) - cut_off + 1] + start_id = rng.choice(start_ids) + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + ids.tolist(), + rng, + max_interval=self.max_interval, + video_prob=0.8, + fix_interval_prob=0.5, + block_shuffle=16, + ) + ids = np.array(ids)[pos] + image_idxs = self.images[ids] + else: + ordered_video = False + seq_start_index = self.scope[idx] + seq_end_index = self.scope[idx + 1] if idx < len(self.scope) - 1 else None + image_idxs = ( + self.groups[seq_start_index:seq_end_index] + if seq_end_index is not None + else self.groups[seq_start_index:] + ) + image_idxs, overlap_scores = image_idxs[:, :2], image_idxs[:, 2] + replace = ( + True + if self.allow_repeat + or len(overlap_scores[overlap_scores > 0]) < num_views + else False + ) + image_idxs = rng.choice( + image_idxs, + num_views, + replace=replace, + p=overlap_scores / np.sum(overlap_scores), + ) + image_idxs = image_idxs.astype(np.int64) + + views = [] + for v, view_idx in enumerate(image_idxs): + img_path = self.imgid2path(view_idx, scene) + depth_path = img_path.replace("rgb", "depth").replace(".jpg", ".npy") + cam_path = img_path.replace("rgb", "cam").replace(".jpg", ".safetensor") + # cam_path = img_path.replace("rgb", "cam").replace(".jpg", ".npz") + sky_mask_path = img_path.replace("rgb", "sky_mask") + image = imread_cv2(img_path) + depthmap = np.load(depth_path) + # camera_params = np.load(cam_path) + camera_params = load_file(cam_path) + sky_mask = cv2.imread(sky_mask_path, cv2.IMREAD_UNCHANGED) >= 127 + + intrinsics = camera_params["intrinsic"].astype(np.float32) + camera_pose = camera_params["pose"].astype(np.float32) + + depthmap[sky_mask] = -1.0 + depthmap[depthmap > 400.0] = 0.0 + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + threshold = ( + np.percentile(depthmap[depthmap > 0], 90) # less noisy + if depthmap[depthmap > 0].size > 0 + else 0 + ) + depthmap[depthmap > threshold] = 0.0 + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(img_path) + ) + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="MapFree", + label=img_path, + is_metric=self.is_metric, + instance=img_path, + is_video=ordered_video, + quantile=np.array(0.96, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/megadepth.py b/stream3r/dust3r/datasets_cut3r/megadepth.py new file mode 100755 index 0000000000000000000000000000000000000000..8a0db352bbd54633279d0bc6581904437f936567 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/megadepth.py @@ -0,0 +1,100 @@ +import os.path as osp +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + +from safetensors.numpy import save_file, load_file + + +class MegaDepth_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self._load_data(self.split) + self.is_metric = False + if self.split is None: + pass + elif self.split == "train": + self.select_scene(("0015", "0022"), opposite=True) + elif self.split == "val": + self.select_scene(("0015", "0022")) + else: + raise ValueError(f"bad {self.split=}") + + def _load_data(self, split): + with np.load( + osp.join(self.ROOT, "megadepth_sets_64.npz"), allow_pickle=True + ) as data: + self.all_scenes = data["scenes"] + self.all_images = data["images"] + self.sets = data["sets"] + + def __len__(self): + return len(self.sets) + + def get_image_num(self): + return len(self.all_images) + + def get_stats(self): + return f"{len(self)} groups from {len(self.all_scenes)} scenes" + + def select_scene(self, scene, *instances, opposite=False): + scenes = (scene,) if isinstance(scene, str) else tuple(scene) + scene_id = [s.startswith(scenes) for s in self.all_scenes] + assert any(scene_id), "no scene found" + valid = np.in1d(self.sets[:, 0], np.nonzero(scene_id)[0]) + if instances: + raise NotImplementedError("selecting instances not implemented") + if opposite: + valid = ~valid + assert valid.any() + self.sets = self.sets[valid] + + def _get_views(self, idx, resolution, rng, num_views): + scene_id = self.sets[idx][0] + image_idxs = self.sets[idx][1:65] + replace = False if not self.allow_repeat else True + image_idxs = rng.choice(image_idxs, num_views, replace=replace) + scene, subscene = self.all_scenes[scene_id].split() + seq_path = osp.join(self.ROOT, scene, subscene) + views = [] + for im_id in image_idxs: + img = self.all_images[im_id] + try: + image = imread_cv2(osp.join(seq_path, img + ".jpg")) + depthmap = imread_cv2(osp.join(seq_path, img + ".exr")) + camera_params = load_file(osp.join(seq_path, img + ".safetensor")) + except Exception as e: + raise OSError(f"cannot load {img}, got exception {e}") + intrinsics = np.float32(camera_params["intrinsics"]) + camera_pose = np.float32(camera_params["cam2world"]) + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(seq_path, img) + ) + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="MegaDepth", + label=osp.relpath(seq_path, self.ROOT), + is_metric=self.is_metric, + instance=img, + is_video=False, + quantile=np.array(0.96, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/mp3d.py b/stream3r/dust3r/datasets_cut3r/mp3d.py new file mode 100755 index 0000000000000000000000000000000000000000..75facdfe9ff4cb7baed964ec06a1c97791d55962 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/mp3d.py @@ -0,0 +1,160 @@ +import os.path as osp +import pickle +import os +import sys +import itertools + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np +from safetensors.numpy import save_file, load_file + +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class MP3D_Multi(BaseMultiViewDataset): + def __init__(self, *args, split, ROOT, **kwargs): + self.ROOT = ROOT + self.video = False + self.is_metric = True + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data() + + def _load_data(self): + + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + self.images = pre_calculated_data['images'] + self.overlaps = pre_calculated_data['overlaps'] + + return + + scenes = os.listdir(self.ROOT) + offset = 0 + overlaps = {scene: [] for scene in scenes} + scene_img_list = {scene: [] for scene in scenes} + images = [] + + j = 0 + for scene in scenes: + scene_dir = osp.join(self.ROOT, scene) + if not os.path.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")] + ) + overlap = np.load(osp.join(scene_dir, "overlap.npy")) + overlaps[scene] = overlap + num_imgs = len(basenames) + + images.extend( + [(scene, i, basename) for i, basename in enumerate(basenames)] + ) + scene_img_list[scene] = np.arange(num_imgs) + offset + offset += num_imgs + j += 1 + + self.scenes = scenes + self.scene_img_list = scene_img_list + self.images = images + self.overlaps = overlaps + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + scene_img_list=self.scene_img_list, + images=images, + overlaps=overlaps), + f, + ) + + def __len__(self): + return len(self.images) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + num_views_posible = 0 + num_unique = num_views if not self.allow_repeat else max(num_views // 3, 3) + while num_views_posible < num_unique - 1: + scene, img_idx, _ = self.images[idx] + overlap = self.overlaps[scene] + sel_img_idx = np.where(overlap[:, 0] == img_idx)[0] + overlap_sel = overlap[sel_img_idx] + overlap_sel = overlap_sel[ + (overlap_sel[:, 2] > 0.01) * (overlap_sel[:, 2] < 1) + ] + num_views_posible = len(overlap_sel) + if num_views_posible >= num_unique - 1: + break + idx = rng.choice(len(self.images)) + + ref_id = self.scene_img_list[scene][img_idx] + ids = self.scene_img_list[scene][overlap_sel[:, 1].astype(np.int64)] + replace = False if not self.allow_repeat else True + image_idxs = rng.choice( + ids, + num_views - 1, + replace=replace, + p=overlap_sel[:, 2] / np.sum(overlap_sel[:, 2]), + ) + image_idxs = np.concatenate([[ref_id], image_idxs]) + + ordered_video = False + views = [] + for v, view_idx in enumerate(image_idxs): + scene, _, basename = self.images[view_idx] + scene_dir = osp.join(self.ROOT, scene) + rgb_path = osp.join(scene_dir, "rgb", basename + ".png") + depth_path = osp.join(scene_dir, "depth", basename + ".npy") + cam_path = osp.join(scene_dir, "cam", basename + ".safetensor") + + rgb_image = imread_cv2(rgb_path, cv2.IMREAD_COLOR) + depthmap = np.load(depth_path).astype(np.float32) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + # cam_file = np.load(cam_path) + cam_file = load_file(cam_path) + intrinsics = cam_file["intrinsics"] + camera_pose = cam_file["pose"] + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.85, 0.1, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="mp3d", + label=scene + "_" + rgb_path, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.99, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/mvimgnet.py b/stream3r/dust3r/datasets_cut3r/mvimgnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c30c0fc5e87ef3dd7105b01adcaeb8c9adbeb1a5 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/mvimgnet.py @@ -0,0 +1,190 @@ +import os.path as osp +import pickle +import json +from pathlib import Path +import cv2 +from pdb import set_trace as st +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + +from safetensors.numpy import save_file, load_file + + +class MVImgNet_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = False + self.max_interval = 32 + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data() + + def _load_data(self): + # self.scenes = os.listdir(self.ROOT) + # st() + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + # if False: + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + self.invalid_scenes = {scene: False for scene in self.scenes} + else: + + ls_file = Path(self.ROOT) / 'mvimgnet_ls.txt' + if ls_file.exists(): + with open(ls_file) as f: + self.scenes = [scene.strip() for scene in f.readlines()] + else: + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + if not os.path.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".jpg")] + ) + + num_imgs = len(basenames) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + start_img_ids.extend([(scene, id) for id in start_img_ids_]) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + self.invalid_scenes = {scene: False for scene in self.scenes} + + # st() # save all required stuffs to the json, avoid re-calculating all the stuffs during loading + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list, + invalid_scenes=self.invalid_scenes), + f, + ) + # st() + # pass + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + invalid_seq = True + scene, start_id = self.start_img_ids[idx] + + while invalid_seq: + while self.invalid_scenes[scene]: + idx = rng.integers(low=0, high=len(self.start_img_ids)) + scene, start_id = self.start_img_ids[idx] + + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, start_id, all_image_ids, rng, max_interval=self.max_interval + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for view_idx in image_idxs: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + try: + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".jpg")) + # Load depthmap, no depth, set to all ones + depthmap = np.ones_like(rgb_image[..., 0], dtype=np.float32) + cam = load_file(osp.join(cam_dir, basename + ".safetensor")) + camera_pose = cam["pose"] + intrinsics = np.eye(3) + intrinsics[0, 0] = cam["intrinsics"][0, 0] + intrinsics[1, 1] = cam["intrinsics"][0, 0] + intrinsics[0, 2] = cam["intrinsics"][1, 1] + intrinsics[1, 2] = cam["intrinsics"][0, 2] + except: + print(f"Error loading {scene} {basename}, skipping") + self.invalid_scenes[scene] = True + break + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="MVImgnet", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=True, + depth_only=False, + single_view=False, + reset=False, + ) + ) + if len(views) == num_views: + invalid_seq = False + return views diff --git a/stream3r/dust3r/datasets_cut3r/mvs_synth.py b/stream3r/dust3r/datasets_cut3r/mvs_synth.py new file mode 100644 index 0000000000000000000000000000000000000000..37e31a6c22c45c5b9cff76541dedb7ca86f883fd --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/mvs_synth.py @@ -0,0 +1,174 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 +from safetensors.numpy import save_file, load_file + +import pickle +from pdb import set_trace as st + + +class MVS_Synth_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = False + self.max_interval = 4 + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + return + + + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + if not os.path.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".jpg")] + ) + num_imgs = len(basenames) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list,), + f, + ) + + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=1.0, + fix_interval_prob=1.0, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".jpg")) + # Load depthmap + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + threshold = ( + np.percentile(depthmap[depthmap > 0], 98) + if depthmap[depthmap > 0].size > 0 + else 0 + ) + depthmap[depthmap > threshold] = 0.0 + depthmap[depthmap > 1000] = 0.0 + + cam = load_file(osp.join(cam_dir, basename + ".safetensor")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.8, 0.15, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="MVS_Synth", + label=self.scenes[scene_id] + "_" + basename, + instance=osp.join(rgb_dir, basename + ".jpg"), + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/omniobject3d.py b/stream3r/dust3r/datasets_cut3r/omniobject3d.py new file mode 100755 index 0000000000000000000000000000000000000000..dee8c9c7426bf2eb7eb77743569aad63871c8b27 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/omniobject3d.py @@ -0,0 +1,146 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys +import json + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 +import re + + +def extract_number(filename): + match = re.search(r"\d+", filename) + if match: + return int(match.group()) + return 0 + + +class OmniObject3D_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = False + self.is_metric = False # True + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data() + + def _load_data(self): + self.scenes = [ + d + for d in os.listdir(self.ROOT) + if os.path.isdir(os.path.join(self.ROOT, d)) + ] + with open(os.path.join(self.ROOT, "scale.json"), "r") as f: + self.scales = json.load(f) + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")], + key=extract_number, + ) + + num_imgs = len(basenames) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + start_img_ids.extend([(scene, id) for id in start_img_ids_]) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + scene, start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, start_id, all_image_ids, rng, max_interval=100, video_prob=0.0 + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + cam = np.load(osp.join(cam_dir, basename + ".npz")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + scale = self.scales[self.scenes[scene_id]] + depthmap = depthmap / scale / 1000.0 + camera_pose[:3, 3] = camera_pose[:3, 3] / scale / 1000.0 + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.8, 0.15, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="OmniObject3D", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/pointodyssey.py b/stream3r/dust3r/datasets_cut3r/pointodyssey.py new file mode 100644 index 0000000000000000000000000000000000000000..d49dc754de9749597cb2052b801029dcd2ada5bd --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/pointodyssey.py @@ -0,0 +1,206 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +import pickle +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + +from safetensors.numpy import save_file, load_file + +class PointOdyssey_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 4 + super().__init__(*args, **kwargs) + assert self.split in ["train", "test", "val"] + self.scenes_to_use = [ + # 'cab_h_bench_3rd', 'cab_h_bench_ego1', 'cab_h_bench_ego2', + "cnb_dlab_0215_3rd", + "cnb_dlab_0215_ego1", + "cnb_dlab_0225_3rd", + "cnb_dlab_0225_ego1", + "dancing", + "dancingroom0_3rd", + "footlab_3rd", + "footlab_ego1", + "footlab_ego2", + "girl", + "girl_egocentric", + "human_egocentric", + "human_in_scene", + "human_in_scene1", + "kg", + "kg_ego1", + "kg_ego2", + "kitchen_gfloor", + "kitchen_gfloor_ego1", + "kitchen_gfloor_ego2", + "scene_carb_h_tables", + "scene_carb_h_tables_ego1", + "scene_carb_h_tables_ego2", + "scene_j716_3rd", + "scene_j716_ego1", + "scene_j716_ego2", + "scene_recording_20210910_S05_S06_0_3rd", + "scene_recording_20210910_S05_S06_0_ego2", + "scene1_0129", + "scene1_0129_ego", + "seminar_h52_3rd", + "seminar_h52_ego1", + "seminar_h52_ego2", + ] + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + return + + root = os.path.join(self.ROOT, split) + self.scenes = [] + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(os.listdir(root)): + if scene not in self.scenes_to_use: + continue + scene_dir = osp.join(root, scene) + if not os.path.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".jpg")] + ) + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + # start_img_ids_ = img_ids[:-self.num_views+1] + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list,), + f, + ) + + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=1.0, + fix_interval_prob=1.0, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".jpg")) + # Load depthmap + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + depthmap[depthmap > 1000] = 0.0 + + cam = load_file(osp.join(cam_dir, basename + ".safetensor")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.9, 0.05, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="PointOdyssey", + label=self.scenes[scene_id] + "_" + basename, + instance=osp.join(rgb_dir, basename + ".jpg"), + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/project_aria_seq.py b/stream3r/dust3r/datasets_cut3r/project_aria_seq.py new file mode 100644 index 0000000000000000000000000000000000000000..1f27d11eb671c3aaf4a78a64de383b8bce4d2613 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/project_aria_seq.py @@ -0,0 +1,213 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Dataloader for preprocessed project-aria dataset +# -------------------------------------------------------- +import os.path as osp +import os +import cv2 +import numpy as np +import math +import sys # noqa: E402 + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +# SLAM3R_DIR = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) +# from slam3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class Aria_Seq(BaseMultiViewDataset): + def __init__(self, + ROOT='data/projectaria/ase_processed', + # num_views=2, + scene_name=None, # specify scene name(s) to load + sample_freq=1, # stride of the frmaes inside the sliding window + start_freq=1, # start frequency for the sliding window + filter=False, # filter out the windows with abnormally large stride + rand_sel=False, # randomly select views from a window + winsize=0, # window size to randomly select views + sel_num=0, # number of combinations to randomly select from a window + *args,**kwargs): + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.sample_freq = sample_freq + self.start_freq = start_freq + # self.num_views = num_views + self.is_metric = True + + self.rand_sel = rand_sel + if rand_sel: + assert winsize > 0 and sel_num > 0 + comb_num = math.comb(winsize-1, self.num_views-2) + assert comb_num >= sel_num + self.winsize = winsize + self.sel_num = sel_num + else: + self.winsize = sample_freq*(self.num_views-1) + + self.scene_names = os.listdir(self.ROOT) + self.scene_names = [int(scene_name) for scene_name in self.scene_names if scene_name.isdigit()] + self.scene_names = sorted(self.scene_names) + self.scene_names = [str(scene_name) for scene_name in self.scene_names] + total_scene_num = len(self.scene_names) + + if self.split == 'train': + # choose 90% of the data as training set + self.scene_names = self.scene_names[:int(total_scene_num*0.9)] + elif self.split=='test': + self.scene_names = self.scene_names[int(total_scene_num*0.9):] + if scene_name is not None: + assert self.split is None + if isinstance(scene_name, list): + self.scene_names = scene_name + else: + if isinstance(scene_name, int): + scene_name = str(scene_name) + assert isinstance(scene_name, str) + self.scene_names = [scene_name] + + self._load_data(filter=filter) + print(self) + + def filter_windows(self, sid, eid, image_names): + return False + + def _load_data(self, filter=False): + self.sceneids = [] + self.images = [] + self.intrinsics = [] #scene_num*(3,3) + self.win_bid = [] + + num_count = 0 + for id, scene_name in enumerate(self.scene_names): + scene_dir = os.path.join(self.ROOT, scene_name) + # print(id, scene_name) + image_names = os.listdir(os.path.join(scene_dir, 'color')) + image_names = sorted(image_names) + intrinsic = np.loadtxt(os.path.join(scene_dir, 'intrinsic', 'intrinsic_color.txt'))[:3,:3] + image_num = len(image_names) + # precompute the window indices + for i in range(0, image_num, self.start_freq): + last_id = i+self.winsize + if last_id >= image_num: + break + if filter and self.filter_windows(i, last_id, image_names): + continue + self.win_bid.append((num_count+i, num_count+last_id)) + + self.intrinsics.append(intrinsic) + self.images += image_names + self.sceneids += [id,] * image_num + num_count += image_num + # print(self.sceneids, self.scene_names) + self.intrinsics = np.stack(self.intrinsics, axis=0) + print(self.intrinsics.shape) + assert len(self.sceneids)==len(self.images), f"{len(self.sceneids)}, {len(self.images)}" + + def __len__(self): + if self.rand_sel: + return self.sel_num*len(self.win_bid) + return len(self.win_bid) + + def get_img_idxes(self, idx, rng, num_views): + if self.rand_sel: + sid, eid = self.win_bid[idx//self.sel_num] + if idx % self.sel_num == 0: + return np.linspace(sid, eid, num_views, endpoint=True, dtype=int) + + if self.num_views == 2: + return [sid, eid] + sel_ids = rng.choice(range(sid+1, eid), num_views-2, replace=False) + sel_ids.sort() + return [sid] + list(sel_ids) + [eid] + else: + sid, eid = self.win_bid[idx] + return [sid + i*self.sample_freq for i in range(num_views)] + + + def _get_views(self, idx, resolution, rng, num_views): + + image_idxes = self.get_img_idxes(idx, rng, num_views) + # print(image_idxes) + views = [] + for view_idx in image_idxes: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scene_names[scene_id]) + + intrinsics = self.intrinsics[scene_id] + basename = self.images[view_idx] + camera_pose = np.loadtxt(osp.join(scene_dir, 'pose', basename.replace('.jpg', '.txt'))) + # Load RGB image + rgb_image = imread_cv2(osp.join(scene_dir, 'color', basename)) + # Load depthmap + depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename.replace('.jpg', '.png')), cv2.IMREAD_UNCHANGED) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[depthmap > 20] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx) + + views.append(dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset='Aria', + label=self.scene_names[scene_id] + '_' + basename, + instance=f'{str(idx)}_{str(view_idx)}', + # other stuffs in cut3r + is_metric=self.is_metric, + is_video=False, + quantile=np.array(0.98, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + )) + # print([view['label'] for view in views]) + return views + +if __name__ == "__main__": + import trimesh + + num_views = 4 + # dataset = Aria_Seq(resolution=(224,224), + # num_views=num_views, + # start_freq=1, sample_freq=2) + dataset = Aria_Seq(split='train', resolution=(224,224), + num_views=num_views, + start_freq=1, rand_sel=True, winsize=6, sel_num=3) + save_dir = "visualization/aria_seq_views" + os.makedirs(save_dir, exist_ok=True) + + for idx in np.random.permutation(len(dataset))[:10]: + os.makedirs(osp.join(save_dir, str(idx)), exist_ok=True) + views = dataset[(idx,0)] + assert len(views) == num_views + all_pts = [] + all_color=[] + for i, view in enumerate(views): + img = np.array(view['img']).transpose(1, 2, 0) + # save_path = osp.join(save_dir, str(idx), f"{'_'.join(view_name(view).split('/')[1:])}.jpg") + save_path = osp.join(save_dir, str(idx), f"{i}_{view['label']}") + # img=cv2.COLOR_RGB2BGR(img) + img=img[...,::-1] + img = (img+1)/2 + cv2.imwrite(save_path, img*255) + print(f"save to {save_path}") + img = img[...,::-1] + pts3d = np.array(view['pts3d']).reshape(-1,3) + pct = trimesh.PointCloud(pts3d, colors=img.reshape(-1, 3)) + pct.export(save_path.replace('.jpg','.ply')) + all_pts.append(pts3d) + all_color.append(img.reshape(-1, 3)) + all_pts = np.concatenate(all_pts, axis=0) + all_color = np.concatenate(all_color, axis=0) + pct = trimesh.PointCloud(all_pts, all_color) + pct.export(osp.join(save_dir, str(idx), f"all.ply")) \ No newline at end of file diff --git a/stream3r/dust3r/datasets_cut3r/realestate10k.py b/stream3r/dust3r/datasets_cut3r/realestate10k.py new file mode 100644 index 0000000000000000000000000000000000000000..3ba1774a82b8d9063bf37bc14368f4ea51cbc9a9 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/realestate10k.py @@ -0,0 +1,167 @@ +import os.path as osp +import pickle +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class RE10K_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = False + self.max_interval = 128 + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + self.invalid_scenes = {scene: False for scene in self.scenes} + + return + + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + if not osp.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")], + key=lambda x: int(x), + ) + + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + start_img_ids.extend([(scene, id) for id in start_img_ids_]) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list,), + f, + ) + + self.invalid_scenes = {scene: False for scene in self.scenes} + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + invalid_seq = True + scene, start_id = self.start_img_ids[idx] + + while invalid_seq: + while self.invalid_scenes[scene]: + idx = rng.integers(low=0, high=len(self.start_img_ids)) + scene, start_id = self.start_img_ids[idx] + + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, start_id, all_image_ids, rng, max_interval=self.max_interval + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for view_idx in image_idxs: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + try: + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + # Load depthmap, no depth, set to all ones + depthmap = np.ones_like(rgb_image[..., 0], dtype=np.float32) + cam = np.load(osp.join(cam_dir, basename + ".npz")) + intrinsics = cam["intrinsics"] + camera_pose = cam["pose"] + except: + print(f"Error loading {scene} {basename}, skipping") + self.invalid_scenes[scene] = True + break + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="realestate10k", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=True, + depth_only=False, + single_view=False, + reset=False, + ) + ) + if len(views) == num_views: + invalid_seq = False + return views diff --git a/stream3r/dust3r/datasets_cut3r/scannet.py b/stream3r/dust3r/datasets_cut3r/scannet.py new file mode 100755 index 0000000000000000000000000000000000000000..56808ce9e5aeee54da792370053c5605ef89e55a --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/scannet.py @@ -0,0 +1,181 @@ +import os.path as osp +from safetensors.numpy import save_file, load_file +import cv2 +from pdb import set_trace as st +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + +import pickle + +class ScanNet_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 30 + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data(self.split) + + def _load_data(self, split): + self.scene_root = osp.join( + self.ROOT, "scans" if split == "train" else "scans_test" + ) + self.scenes = [ + scene for scene in os.listdir(self.scene_root) if scene.startswith("scene") + ] + + if os.path.exists(osp.join(self.ROOT, f'{split}-pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'{split}-pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + else: + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.scene_root, scene) + if not os.path.exists(osp.join(scene_dir, "new_scene_metadata.pkl")): + continue + with open(osp.join(scene_dir, "new_scene_metadata.pkl"), 'rb') as f: + data = pickle.load(f) + # with np.load( + # osp.join(scene_dir, "new_scene_metadata.npz"), allow_pickle=True + # ) as data: + basenames = data["images"] + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + with open(osp.join(self.ROOT, f'{split}-pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list,), + f, + ) + + # st() + # pass + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=0.6, + fix_interval_prob=0.6, + block_shuffle=16, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.scene_root, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "color") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".jpg")) + # Load depthmap + depthmap = imread_cv2( + osp.join(depth_dir, basename + ".png"), cv2.IMREAD_UNCHANGED + ) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + # cam = np.load(osp.join(cam_dir, basename + ".npz")) + cam = load_file(osp.join(cam_dir, basename + ".safetensor")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="ScanNet", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/scannetpp.py b/stream3r/dust3r/datasets_cut3r/scannetpp.py new file mode 100755 index 0000000000000000000000000000000000000000..98054c0b3d49c116e4608b961ce5cf5399426ff4 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/scannetpp.py @@ -0,0 +1,227 @@ +import os.path as osp +from pdb import set_trace as st +import pickle +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class ScanNetpp_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 3 + super().__init__(*args, **kwargs) + assert self.split == "train" + self.loaded_data = self._load_data() + + def _load_data(self): + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + # if False: + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.intrinsics = pre_calculated_data['intrinsics'] + self.trajectories = pre_calculated_data['trajectories'] + self.groups = pre_calculated_data['groups'] + self.id_ranges = pre_calculated_data['id_ranges'] + + return + + with np.load(osp.join(self.ROOT, "all_metadata.npz")) as data: + self.scenes = data["scenes"] + + offset = 0 + scenes = [] + sceneids = [] + images = [] + intrinsics = [] + trajectories = [] + groups = [] + id_ranges = [] + j = 0 + self.image_num = 0 + for scene in self.scenes: + scene_dir = osp.join(self.ROOT, scene) + if not os.path.isdir(scene_dir): + continue + with np.load( + osp.join(scene_dir, "new_scene_metadata.npz"), allow_pickle=True + ) as data: + imgs = data["images"] + self.image_num += len(imgs) + img_ids = np.arange(len(imgs)).tolist() + intrins = data["intrinsics"] + traj = data["trajectories"] + imgs_on_disk = sorted(os.listdir(osp.join(scene_dir, "images"))) + imgs_on_disk = list(map(lambda x: x[:-4], imgs_on_disk)) + + dslr_ids = [ + i + offset + for i in img_ids + if imgs[i].startswith("DSC") and imgs[i] in imgs_on_disk + ] + iphone_ids = [ + i + offset + for i in img_ids + if imgs[i].startswith("frame") and imgs[i] in imgs_on_disk + ] + + num_imgs = len(imgs) + assert max(dslr_ids) < min(iphone_ids) + assert "image_collection" in data + + img_groups = [] + img_id_ranges = [] + + for ref_id, group in data["image_collection"].item().items(): + if len(group) + 1 < self.num_views: + continue + group.insert(0, (ref_id, 1.0)) + sorted_group = sorted(group, key=lambda x: x[1], reverse=True) + group = [int(x[0] + offset) for x in sorted_group] + img_groups.append(sorted(group)) + + if imgs[ref_id].startswith("frame"): + img_id_ranges.append(dslr_ids) + else: + img_id_ranges.append(iphone_ids) + + if len(img_groups) == 0: + print(f"Skipping {scene}") + continue + scenes.append(scene) + sceneids.extend([j] * num_imgs) + images.extend(imgs) + intrinsics.append(intrins) + trajectories.append(traj) + + # offset groups + groups.extend(img_groups) + id_ranges.extend(img_id_ranges) + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.intrinsics = np.concatenate(intrinsics, axis=0) + self.trajectories = np.concatenate(trajectories, axis=0) + self.id_ranges = id_ranges + self.groups = groups + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=self.images, + intrinsics=self.intrinsics, + trajectories=self.trajectories, + groups=self.groups, + id_ranges=self.id_ranges), + f, + ) + + def __len__(self): + return len(self.groups) * 10 + + def get_image_num(self): + return self.image_num + + def _get_views(self, idx, resolution, rng, num_views): + idx = idx // 10 + image_idxs = self.groups[idx] + rand_val = rng.random() + + image_idxs_video = self.id_ranges[idx] + cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3) + start_image_idxs = image_idxs_video[: len(image_idxs_video) - cut_off + 1] + + if rand_val < 0.7 and len(start_image_idxs) > 0: + start_id = rng.choice(start_image_idxs) + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + image_idxs_video, + rng, + max_interval=self.max_interval, + video_prob=0.8, + fix_interval_prob=0.5, + block_shuffle=16, + ) + image_idxs = np.array(image_idxs_video)[pos] + + else: + ordered_video = True + # ordered video with varying intervals + num_candidates = len(image_idxs) + max_id = min(num_candidates, int(num_views * (2 + 2 * rng.random()))) + image_idxs = sorted(rng.permutation(image_idxs[:max_id])[:num_views]) + if rand_val > 0.75: + ordered_video = False + image_idxs = rng.permutation(image_idxs) + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + + # st() + intrinsics = self.intrinsics[view_idx] + camera_pose = self.trajectories[view_idx] + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(scene_dir, "images", basename + ".jpg")) + # Load depthmap + depthmap = imread_cv2( + osp.join(scene_dir, "depth", basename + ".png"), cv2.IMREAD_UNCHANGED + ) + depthmap = depthmap.astype(np.float32) / 1000 + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="ScanNet++", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.99, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + if not len(views) == num_views: # ! fails on world_size>1, views!=2 + st() + return views diff --git a/stream3r/dust3r/datasets_cut3r/smartportraits.py b/stream3r/dust3r/datasets_cut3r/smartportraits.py new file mode 100644 index 0000000000000000000000000000000000000000..ef4330a3314cd91aac712fb4b5d23d80742ed7bb --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/smartportraits.py @@ -0,0 +1,87 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class SmartPortraits_Multi(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + scenes = os.listdir(self.ROOT) + img_names = [] + for scene in scenes: + scene_dir = osp.join(self.ROOT, scene) + if not os.path.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")] + ) + img_names.extend([(scene, basename) for basename in basenames]) + + self.img_names = img_names + + def __len__(self): + return len(self.img_names) + + def get_image_num(self): + return len(self.img_names) + + def _get_views(self, idx, resolution, rng, num_views): + new_seed = rng.integers(0, 2**32) + idx + new_rng = np.random.default_rng(new_seed) + img_names = new_rng.choice(self.img_names, num_views, replace=False) + + views = [] + for v, img_name in enumerate(img_names): + # Load RGB image + scene, img_name = img_name + rgb_image = imread_cv2(osp.join(self.ROOT, scene, "rgb", f"{img_name}.png")) + depthmap = np.load(osp.join(self.ROOT, scene, "depth", f"{img_name}.npy")) + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + + intrinsics = np.load(osp.join(self.ROOT, scene, "cam", f"{img_name}.npz"))[ + "intrinsics" + ] + # camera pose is not provided, placeholder + camera_pose = np.eye(4) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="SmartPortraits", + label=img_name, + instance=osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"), + is_metric=self.is_metric, + is_video=False, + quantile=np.array(0.98, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/spring.py b/stream3r/dust3r/datasets_cut3r/spring.py new file mode 100755 index 0000000000000000000000000000000000000000..44a5fdb668c50114007b3acd34e751a2d3d55e80 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/spring.py @@ -0,0 +1,165 @@ +import os.path as osp +import pickle +from safetensors.numpy import save_file, load_file +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class Spring(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 16 + super().__init__(*args, **kwargs) + + self.loaded_data = self._load_data() + + def _load_data(self): + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + return + + + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + if not os.path.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")] + ) + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + # start_img_ids_ = img_ids[:-self.num_views+1] + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list,), + f, + ) + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=1.0, + fix_interval_prob=1.0, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + # Load depthmap + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + + cam = load_file(osp.join(cam_dir, basename + ".safetensor")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.85, 0.10, 0.05] + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="spring", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/synscapes.py b/stream3r/dust3r/datasets_cut3r/synscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..524ea275c7ebc2695864f8eba672af2ad46451e8 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/synscapes.py @@ -0,0 +1,154 @@ +import os.path as osp +import pickle +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from safetensors.numpy import save_file, load_file +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class SynScapes(BaseMultiViewDataset): + + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = False + self.is_metric = True + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), + 'rb') as f: + pre_calculated_data = pickle.load(f) + self.img_names = pre_calculated_data['img_names'] + return + + rgb_dir = osp.join(self.ROOT, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")], + key=lambda x: int(x), + ) + self.img_names = basenames # 25K imgs + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), + 'wb') as f: + pickle.dump(dict(img_names=self.img_names), f) + + def __len__(self): + return len(self.img_names) + + def get_image_num(self): + return len(self.img_names) + + def _get_views(self, idx, resolution, rng, num_views): + new_seed = rng.integers(0, 2**32) + idx + new_rng = np.random.default_rng(new_seed) + img_names = new_rng.choice(self.img_names, num_views, replace=False) + + views = [] + for v, img_name in enumerate(img_names): + # Load RGB image + rgb_image = imread_cv2( + osp.join(self.ROOT, "rgb", f"{img_name}.png")) + depthmap = np.load(osp.join(self.ROOT, "depth", f"{img_name}.npy")) + sky_mask = (imread_cv2( + osp.join(self.ROOT, "sky_mask", f"{img_name}.png"))[..., 0] + >= 127) + depthmap[sky_mask] = -1.0 + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + depthmap[depthmap > 200] = 0.0 + + intrinsics = load_file( + osp.join(self.ROOT, "cam", + f"{img_name}.safetensor"))["intrinsics"] + # camera pose is not provided, placeholder + camera_pose = np.eye(4) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, + depthmap, + intrinsics, + resolution, + rng=rng, + info=img_name) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="synscapes", + label=img_name, + instance=f"{str(idx)}_{img_name}", + is_metric=self.is_metric, + is_video=False, + quantile=np.array(1.0, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + )) + assert len(views) == num_views + return views + + +if __name__ == "__main__": + import torch + import pause + from torchvision.transforms import ToPILImage + from stream3r.dust3r.datasets.base.base_stereo_view_dataset import view_name + from stream3r.dust3r.utils.image import rgb + from stream3r.dust3r.viz import SceneViz, auto_cam_size + from IPython.display import display + from stream3r.dust3r.datasets.utils.transforms import ImgNorm, convert_input_to_pred_format, vis_track + from stream3r.dust3r.utils.geometry import ( + geotrf, + inv, + ) + from stream3r.viz.viser_visualizer_track import start_visualization + + def main(): + dataset = SynScapes( + split="train", allow_repeat=False, ROOT="/mnt/storage/yslan-data/cut3r_processed/processed_synscapes/", + aug_crop=0, resolution=(512, 384), num_views=20, transform=ImgNorm + ) + + # import random + # for i in random.sample(range(len(dataset)), 100): + # views = dataset[i] + # print(i) + + select_idx = 1 + views = dataset[select_idx] + output = convert_input_to_pred_format(views) + + # save_path = os.path.join("develop/2d_compare/test_data", views[0]['dataset'] + str(select_idx)) + # os.makedirs(save_path, exist_ok=True) + # for i in range(len(views)): + # print(view_name(views[i])) + # ToPILImage()(rgb(views[i]["img"])).save(f"{save_path}/{i}.png") + + server = start_visualization( + output=output, + min_conf_thr_percentile=0, + global_conf_thr_value_to_drop_view=1, + point_size=0.0016, + ) + + # share_url = servers.request_share_url() + # print(share_url) + + pause.days(1) + + main() \ No newline at end of file diff --git a/stream3r/dust3r/datasets_cut3r/tartanair.py b/stream3r/dust3r/datasets_cut3r/tartanair.py new file mode 100644 index 0000000000000000000000000000000000000000..530b0aa73e2a4a9b42dd40b6af326dc1bb154e6b --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/tartanair.py @@ -0,0 +1,191 @@ +import os.path as osp +import pickle +import numpy as np +import cv2 +import numpy as np +import itertools +import os +import sys +from pdb import set_trace as st + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + +from safetensors.numpy import save_file, load_file + + +class TartanAir_Multi(BaseMultiViewDataset): + + def __init__(self, ROOT, *args, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 20 + super().__init__(*args, **kwargs) + # loading all + assert self.split is None + self._load_data() + + def _load_data(self): + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + else: + + scene_dirs = sorted( + [ + d + for d in os.listdir(self.ROOT) + if os.path.isdir(os.path.join(self.ROOT, d)) + ] + ) + + offset = 0 + scenes = [] + sceneids = [] + images = [] + scene_img_list = [] + start_img_ids = [] + j = 0 + + for scene in scene_dirs: + for mode in ["Easy", "Hard"]: + seq_dirs = sorted( + [ + os.path.join(self.ROOT, scene, mode, d) + for d in os.listdir(os.path.join(self.ROOT, scene, mode)) + if os.path.isdir(os.path.join(self.ROOT, scene, mode, d)) + ] + ) + for seq_dir in seq_dirs: + basenames = sorted( + [f[:-8] for f in os.listdir(seq_dir) if f.endswith(".png")] + ) + num_imgs = len(basenames) + cut_off = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + scenes.append(seq_dir) + scene_img_list.append(img_ids) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + start_img_ids.extend(start_img_ids_) + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list,), + f, + ) + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def get_stats(self): + return f"{len(self)} groups of views" + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + scene_id = self.sceneids[start_id] + all_image_ids = self.scene_img_list[scene_id] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=0.8, + fix_interval_prob=0.8, + block_shuffle=16, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = self.scenes[scene_id] + basename = self.images[view_idx] + + img = basename + "_rgb.png" + image = imread_cv2(osp.join(scene_dir, img)) + depthmap = np.load(osp.join(scene_dir, basename + "_depth.npy")) + camera_params = load_file(osp.join(scene_dir, basename + "_cam.safetensor")) + + intrinsics = camera_params["camera_intrinsics"] + camera_pose = camera_params["camera_pose"] + + sky_mask = depthmap >= 1000 + depthmap[sky_mask] = -1.0 # sky + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + threshold = ( + np.percentile(depthmap[depthmap > 0], 98) + if depthmap[depthmap > 0].size > 0 + else 0 + ) + depthmap[depthmap > threshold] = 0.0 + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(scene_dir, img) + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="TartanAir", + label=scene_dir, + is_metric=self.is_metric, + instance=scene_dir + "_" + img, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/threedkb.py b/stream3r/dust3r/datasets_cut3r/threedkb.py new file mode 100644 index 0000000000000000000000000000000000000000..eca3dbba04cb4fc9bdadf71ce438186e7acdfb76 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/threedkb.py @@ -0,0 +1,111 @@ +import os.path as osp +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class ThreeDKenBurns(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = False + self.is_metric = False + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + images = [] + img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + rgb_dir = osp.join(scene_dir, "rgb") + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")] + ) + + num_imgs = len(basenames) + img_ids_ = list(np.arange(num_imgs) + offset) + + img_ids.extend(img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.img_ids = img_ids + + def __len__(self): + return len(self.img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + new_seed = rng.integers(0, 2**32) + idx + new_rng = np.random.default_rng(new_seed) + image_idxs = new_rng.choice(self.img_ids, num_views, replace=False) + + views = [] + for view_idx in image_idxs: + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + depthmap = imread_cv2(osp.join(depth_dir, basename + ".exr")) + depthmap[depthmap > 20000] = 0.0 + depthmap = depthmap / 1000.0 + cam = np.load(osp.join(cam_dir, basename + ".npz")) + intrinsics = cam["intrinsics"] + camera_pose = np.eye(4) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="3DKenBurns", + label=self.scenes[scene_id] + "_" + basename, + instance=f"{str(idx)}_{str(view_idx)}", + is_metric=self.is_metric, + is_video=False, + quantile=np.array(1.0, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/uasol.py b/stream3r/dust3r/datasets_cut3r/uasol.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c9609639e418377c60a5fbb31ad58ff59bc46f --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/uasol.py @@ -0,0 +1,188 @@ +import os.path as osp +import pickle +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + +import re + + +def extract_number(filename): + match = re.search(r"\d+", filename) + if match: + return int(match.group()) + return 0 + + +class UASOL_Multi(BaseMultiViewDataset): + + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 40 + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + + if os.path.exists( + osp.join(self.ROOT, + f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open( + osp.join(self.ROOT, + f'pre-calculated-loaddata-{self.num_views}.pkl'), + 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + return + + self.scenes = os.listdir(self.ROOT) + + offset = 0 + scenes = [] + sceneids = [] + scene_img_list = [] + images = [] + start_img_ids = [] + + j = 0 + for scene in tqdm(self.scenes): + scene_dir = osp.join(self.ROOT, scene) + if not osp.isdir(scene_dir): + continue + rgb_dir = osp.join(scene_dir, "rgb") + if not osp.isdir(rgb_dir): + continue + basenames = sorted( + [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")], + key=extract_number, + ) + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + # start_img_ids_ = img_ids[:-self.num_views+1] + cut_off = (self.num_views if not self.allow_repeat else max( + self.num_views // 3, 3)) + start_img_ids_ = img_ids[:num_imgs - cut_off + 1] + + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(scene) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + with open( + osp.join(self.ROOT, + f'pre-calculated-loaddata-{self.num_views}.pkl'), + 'wb') as f: + pickle.dump( + dict( + scenes=self.scenes, + sceneids=self.sceneids, + images=self.images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list, + ), + f, + ) + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=0.75, + fix_interval_prob=0.75, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + rgb_dir = osp.join(scene_dir, "rgb") + depth_dir = osp.join(scene_dir, "depth") + cam_dir = osp.join(scene_dir, "cam") + + basename = self.images[view_idx] + + # Load RGB image + rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png")) + # Load depthmap + depthmap = np.load(osp.join(depth_dir, basename + ".npy")) + depthmap[~np.isfinite(depthmap)] = 0 # invalid + depthmap[depthmap >= 20] = 0 # invalid + + cam = np.load(osp.join(cam_dir, basename + ".npz")) + camera_pose = cam["pose"] + intrinsics = cam["intrinsics"] + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, + depthmap, + intrinsics, + resolution, + rng=rng, + info=view_idx) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05]) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="UASOL", + label=self.scenes[scene_id] + "_" + basename, + instance=osp.join(rgb_dir, basename + ".png"), + is_metric=self.is_metric, + is_video=ordered_video, + quantile=np.array(0.9, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + )) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/unreal4k.py b/stream3r/dust3r/datasets_cut3r/unreal4k.py new file mode 100644 index 0000000000000000000000000000000000000000..94898150bde5349a71cc00ad3efc22c80643564c --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/unreal4k.py @@ -0,0 +1,200 @@ +import os.path as osp +from tqdm import tqdm +import pickle +import json +import numpy as np +import cv2 +import numpy as np +import itertools +import os +import sys + +from safetensors.numpy import save_file, load_file +from pdb import set_trace as st +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + + +R_conv = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]).astype( + np.float32 +) + + +class UnReal4K_Multi(BaseMultiViewDataset): + + def __init__(self, ROOT, *args, **kwargs): + self.ROOT = ROOT + self.max_interval = 2 + self.is_metric = True + super().__init__(*args, **kwargs) + # loading all + assert self.split is None + self._load_data() + + def _load_data(self): + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + + else: + + scene_dirs = sorted( + [ + d + for d in os.listdir(self.ROOT) + if os.path.isdir(os.path.join(self.ROOT, d)) + ] + ) + + offset = 0 + scenes = [] + sceneids = [] + images = [] + start_img_ids = [] + scene_img_list = [] + j = 0 + + seq_dirs = sorted( + [ + os.path.join(self.ROOT, scene, mode) + for scene in scene_dirs + for mode in ["0", "1"] + ] + ) + for seq_dir in tqdm(seq_dirs): + basenames = sorted( + [f[:-8] for f in os.listdir(seq_dir) if f.endswith(".png")] + ) + num_imgs = len(basenames) + img_ids = list(np.arange(num_imgs) + offset) + # start_img_ids_ = img_ids[:-self.num_views+1] + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + if num_imgs < cut_off: + print(f"Skipping {seq_dir}") + continue + + start_img_ids.extend(start_img_ids_) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + scenes.append(seq_dir) + scene_img_list.append(img_ids) + + # offset groups + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + # st() + + # save_file( + # dict(scenes=self.scenes, + # sceneids=self.sceneids, + # images=images, + # start_img_ids=start_img_ids, + # scene_img_list=scene_img_list,), + # osp.join(self.ROOT, 'pre-calculated-loaddata.safetensor') + # ) + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list,), + f, + ) + # st() + + def __len__(self): + return len(self.start_img_ids) * 10 + + def get_image_num(self): + return len(self.images) + + def get_stats(self): + return f"{len(self)//10} groups of views" + + def _get_views(self, idx, resolution, rng, num_views): + idx = idx // 10 + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + pos, ordered_video = self.get_seq_from_start_id( + num_views, start_id, all_image_ids, rng, max_interval=self.max_interval + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = self.scenes[scene_id] + basename = self.images[view_idx] + + img = basename + "_rgb.png" + image = imread_cv2(osp.join(scene_dir, img)) + depthmap = np.load(osp.join(scene_dir, basename + "_depth.npy")) + camera_params = np.load(osp.join(scene_dir, basename + ".npz")) + + intrinsics = camera_params["intrinsics"].astype(np.float32) + camera_pose = camera_params["cam2world"].astype(np.float32) + + camera_pose = R_conv @ camera_pose + + sky_mask = depthmap >= 1000 + depthmap[sky_mask] = -1.0 # sky + threshold = ( + np.percentile(depthmap[depthmap > 0], 98) + if depthmap[depthmap > 0].size > 0 + else 0 + ) + depthmap[depthmap > threshold] = 0.0 + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(scene_dir, img) + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.75, 0.2, 0.05] + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="UnReal4K", + label=scene_dir, + is_metric=self.is_metric, + instance=scene_dir + "_" + img, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/urbansyn.py b/stream3r/dust3r/datasets_cut3r/urbansyn.py new file mode 100644 index 0000000000000000000000000000000000000000..2853725ce144607d47e29c070370dcd24ae8578c --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/urbansyn.py @@ -0,0 +1,98 @@ +import os.path as osp +from safetensors.numpy import save_file, load_file +import pickle +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +from tqdm import tqdm +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class UrbanSyn(BaseMultiViewDataset): + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.video = False + self.is_metric = True + super().__init__(*args, **kwargs) + self.loaded_data = self._load_data() + + def _load_data(self): + + cache_file = osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl') + if os.path.exists(cache_file): + with open(cache_file, + 'rb') as f: + pre_calculated_data = pickle.load(f) + self.img_names = pre_calculated_data['img_names'] + return + else: + + rgb_dir = osp.join(self.ROOT, "rgb") + basenames = sorted([f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")]) + self.img_names = basenames + + with open(cache_file, + 'wb') as f: + pickle.dump(dict(img_names=self.img_names), f) + + def __len__(self): + return len(self.img_names) + + def get_image_num(self): + return len(self.img_names) + + def _get_views(self, idx, resolution, rng, num_views): + new_seed = rng.integers(0, 2**32) + idx + new_rng = np.random.default_rng(new_seed) + img_names = new_rng.choice(self.img_names, num_views, replace=False) + + views = [] + for img_name in img_names: + # Load RGB image + rgb_image = imread_cv2(osp.join(self.ROOT, "rgb", f"{img_name}.png")) + depthmap = np.load(osp.join(self.ROOT, "depth", f"{img_name}.npy")) + sky_mask = ( + imread_cv2(osp.join(self.ROOT, "sky_mask", f"{img_name}.png"))[..., 0] + >= 127 + ) + depthmap[sky_mask] = -1.0 + depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) + depthmap[depthmap > 200] = 0.0 + + intrinsics = load_file(osp.join(self.ROOT, "cam", f"{img_name}.safetensor"))[ + "intrinsics" + ] + # camera pose is not provided, placeholder + camera_pose = np.eye(4) + + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap.astype(np.float32), + camera_pose=camera_pose.astype(np.float32), + camera_intrinsics=intrinsics.astype(np.float32), + dataset="urbansyn", + label=img_name, + instance=f"{str(idx)}_{img_name}", + is_metric=self.is_metric, + is_video=False, + quantile=np.array(1.0, dtype=np.float32), + img_mask=True, + ray_mask=False, + camera_only=False, + depth_only=False, + single_view=True, + reset=True, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/utils/__init__.py b/stream3r/dust3r/datasets_cut3r/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/stream3r/dust3r/datasets_cut3r/utils/corr.py b/stream3r/dust3r/datasets_cut3r/utils/corr.py new file mode 100755 index 0000000000000000000000000000000000000000..1ad682643f385e1cc96a2fc6ef7e73e74ea03f19 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/utils/corr.py @@ -0,0 +1,129 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# modified from DUSt3R + +import numpy as np +from stream3r.dust3r.utils.device import to_numpy +from stream3r.dust3r.utils.geometry import inv, geotrf + + +def reproject_view(pts3d, view2): + shape = view2["pts3d"].shape[:2] + return reproject( + pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape + ) + + +def reproject(pts3d, K, world2cam, shape): + H, W, THREE = pts3d.shape + assert THREE == 3 + + # reproject in camera2 space + with np.errstate(divide="ignore", invalid="ignore"): + pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2) + + # quantize to pixel positions + return (H, W), ravel_xy(pos, shape) + + +def ravel_xy(pos, shape): + H, W = shape + with np.errstate(invalid="ignore"): + qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T + quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip( + min=0, max=H - 1, out=qy + ) + return quantized_pos + + +def unravel_xy(pos, shape): + # convert (x+W*y) back to 2d (x,y) coordinates + return np.unravel_index(pos, shape)[0].base[:, ::-1].copy() + + +def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False): + is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2)) + pos1 = is_reciprocal1.nonzero()[0] + pos2 = corres_1_to_2[pos1] + if ret_recip: + return is_reciprocal1, pos1, pos2 + return pos1, pos2 + + +def extract_correspondences_from_pts3d( + view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0 +): + view1, view2 = to_numpy((view1, view2)) + # project pixels from image1 --> 3d points --> image2 pixels + shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2) + shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1) + + # compute reciprocal correspondences: + # pos1 == valid pixels (correspondences) in image1 + is_reciprocal1, pos1, pos2 = reciprocal_1d( + corres1_to_2, corres2_to_1, ret_recip=True + ) + is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1)) + + if target_n_corres is None: + if ret_xy: + pos1 = unravel_xy(pos1, shape1) + pos2 = unravel_xy(pos2, shape2) + return pos1, pos2 + + available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum()) + target_n_positives = int(target_n_corres * (1 - nneg)) + n_positives = min(len(pos1), target_n_positives) + n_negatives = min(target_n_corres - n_positives, available_negatives) + + if n_negatives + n_positives != target_n_corres: + # should be really rare => when there are not enough negatives + # in that case, break nneg and add a few more positives ? + n_positives = target_n_corres - n_negatives + assert n_positives <= len(pos1) + + assert n_positives <= len(pos1) + assert n_positives <= len(pos2) + assert n_negatives <= (~is_reciprocal1).sum() + assert n_negatives <= (~is_reciprocal2).sum() + assert n_positives + n_negatives == target_n_corres + + valid = np.ones(n_positives, dtype=bool) + if n_positives < len(pos1): + # random sub-sampling of valid correspondences + perm = rng.permutation(len(pos1))[:n_positives] + pos1 = pos1[perm] + pos2 = pos2[perm] + + if n_negatives > 0: + # add false correspondences if not enough + def norm(p): + return p / p.sum() + + pos1 = np.r_[ + pos1, + rng.choice( + shape1[0] * shape1[1], + size=n_negatives, + replace=False, + p=norm(~is_reciprocal1), + ), + ] + pos2 = np.r_[ + pos2, + rng.choice( + shape2[0] * shape2[1], + size=n_negatives, + replace=False, + p=norm(~is_reciprocal2), + ), + ] + valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)] + + # convert (x+W*y) back to 2d (x,y) coordinates + if ret_xy: + pos1 = unravel_xy(pos1, shape1) + pos2 = unravel_xy(pos2, shape2) + return pos1, pos2, valid diff --git a/stream3r/dust3r/datasets_cut3r/utils/cropping.py b/stream3r/dust3r/datasets_cut3r/utils/cropping.py new file mode 100755 index 0000000000000000000000000000000000000000..dc091d576c024d207e3009ccb8f2573d84b48a71 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/utils/cropping.py @@ -0,0 +1,153 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# croppping utilities +# -------------------------------------------------------- +import PIL.Image +import os + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +from pdb import set_trace as st +import cv2 # noqa +import numpy as np # noqa +from stream3r.dust3r.utils.geometry import ( + colmap_to_opencv_intrinsics, + opencv_to_colmap_intrinsics, +) # noqa + +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + + +class ImageList: + """Convenience class to aply the same operation to a whole set of images.""" + + def __init__(self, images): + if not isinstance(images, (tuple, list, set)): + images = [images] + self.images = [] + for image in images: + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + self.images.append(image) + + def __len__(self): + return len(self.images) + + def to_pil(self): + return tuple(self.images) if len(self.images) > 1 else self.images[0] + + @property + def size(self): + sizes = [im.size for im in self.images] + assert all(sizes[0] == s for s in sizes) + return sizes[0] + + def resize(self, *args, **kwargs): + return ImageList(self._dispatch("resize", *args, **kwargs)) + + def crop(self, *args, **kwargs): + return ImageList(self._dispatch("crop", *args, **kwargs)) + + def _dispatch(self, func, *args, **kwargs): + return [getattr(im, func)(*args, **kwargs) for im in self.images] + + +def rescale_image_depthmap(image, + depthmap, + camera_intrinsics, + output_resolution, + force=True): + """Jointly rescale a (image, depthmap) + so that (out_width, out_height) >= output_res + """ + # st() + image = ImageList(image) + input_resolution = np.array(image.size) # (W,H) + output_resolution = np.array(output_resolution) + if depthmap is not None: + # can also use this with masks instead of depthmaps + assert tuple(depthmap.shape[:2]) == image.size[::-1] + + # define output resolution + assert output_resolution.shape == (2, ) + scale_final = max(output_resolution / image.size) + 1e-8 + if scale_final >= 1 and not force: # image is already smaller than what is asked + return (image.to_pil(), depthmap, camera_intrinsics) + output_resolution = np.floor(input_resolution * scale_final).astype(int) + + # first rescale the image so that it contains the crop + image = image.resize(tuple(output_resolution), + resample=lanczos if scale_final < 1 else bicubic) + if depthmap is not None: + depthmap = cv2.resize( + depthmap, + output_resolution, + fx=scale_final, + fy=scale_final, + interpolation=cv2.INTER_NEAREST, + ) + + # no offset here; simple rescaling + camera_intrinsics = camera_matrix_of_crop(camera_intrinsics, + input_resolution, + output_resolution, + scaling=scale_final) + + return image.to_pil(), depthmap, camera_intrinsics + + +def camera_matrix_of_crop( + input_camera_matrix, + input_resolution, + output_resolution, + scaling=1, + offset_factor=0.5, + offset=None, +): + # Margins to offset the origin + margins = np.asarray(input_resolution) * scaling - output_resolution + assert np.all(margins >= 0.0) + if offset is None: + offset = offset_factor * margins + + # Generate new camera parameters + output_camera_matrix_colmap = opencv_to_colmap_intrinsics( + input_camera_matrix) + output_camera_matrix_colmap[:2, :] *= scaling + output_camera_matrix_colmap[:2, 2] -= offset + output_camera_matrix = colmap_to_opencv_intrinsics( + output_camera_matrix_colmap) + + return output_camera_matrix + + +def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox): + """ + Return a crop of the input view. + """ + image = ImageList(image) + l, t, r, b = crop_bbox + + image = image.crop((l, t, r, b)) + depthmap = depthmap[t:b, l:r] + + camera_intrinsics = camera_intrinsics.copy() + camera_intrinsics[0, 2] -= l + camera_intrinsics[1, 2] -= t + + return image.to_pil(), depthmap, camera_intrinsics + + +def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, + output_resolution): + out_width, out_height = output_resolution + l, t = np.int32( + np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) + crop_bbox = (l, t, l + out_width, t + out_height) + return crop_bbox diff --git a/stream3r/dust3r/datasets_cut3r/utils/transforms.py b/stream3r/dust3r/datasets_cut3r/utils/transforms.py new file mode 100755 index 0000000000000000000000000000000000000000..32b491bf86a9b6ea033fd0e77f36ed16ca85d91f --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/utils/transforms.py @@ -0,0 +1,82 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# DUST3R default transforms +# -------------------------------------------------------- +import torchvision.transforms as tvf +from stream3r.dust3r.utils.image import ImgNorm +from pdb import set_trace as st + +# define the standard image transforms +ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) + + +def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True): + if isinstance(value, (int, float)): + if value < 0: + raise ValueError(f"If is a single number, it must be non negative.") + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + value = [float(value[0]), float(value[1])] + else: + raise TypeError(f"should be a single number or a list/tuple with length 2.") + + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"values should be between {bound}, but got {value}.") + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + return None + else: + return tuple(value) + + +import torch +import torchvision.transforms.functional as F + + +def SeqColorJitter(): + """ + Return a color jitter transform with same random parameters + """ + brightness = _check_input(0.5) + contrast = _check_input(0.5) + saturation = _check_input(0.5) + hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + + fn_idx = torch.randperm(4) + brightness_factor = ( + None + if brightness is None + else float(torch.empty(1).uniform_(brightness[0], brightness[1])) + ) + contrast_factor = ( + None + if contrast is None + else float(torch.empty(1).uniform_(contrast[0], contrast[1])) + ) + saturation_factor = ( + None + if saturation is None + else float(torch.empty(1).uniform_(saturation[0], saturation[1])) + ) + hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) + + def _color_jitter(img): + # st() + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return ImgNorm(img) + + return _color_jitter diff --git a/stream3r/dust3r/datasets_cut3r/vkitti2.py b/stream3r/dust3r/datasets_cut3r/vkitti2.py new file mode 100755 index 0000000000000000000000000000000000000000..a9a5fb6f7f2fb53bf11d159d7ca4ca6c493529b7 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/vkitti2.py @@ -0,0 +1,195 @@ +import os.path as osp +from pdb import set_trace as st +import pickle +import numpy as np +import cv2 +import numpy as np +import itertools +import os +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) + +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + +from safetensors.numpy import save_file, load_file + + +class VirtualKITTI2_Multi(BaseMultiViewDataset): + + def __init__(self, ROOT, *args, **kwargs): + self.ROOT = ROOT + self.video = True + self.is_metric = True + self.max_interval = 5 + super().__init__(*args, **kwargs) + # loading all + self._load_data(self.split) + + def _load_data(self, split=None): + scene_dirs = sorted( + [ + d + for d in os.listdir(self.ROOT) + if os.path.isdir(os.path.join(self.ROOT, d)) + ] + ) + if split == "train": + scene_dirs = scene_dirs[:-1] + elif split == "test": + scene_dirs = scene_dirs[-1:] + + if os.path.exists(osp.join(self.ROOT, f'{split}-pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'{split}-pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + else: + + seq_dirs = [] + for scene in scene_dirs: + seq_dirs += sorted( + [ + os.path.join(scene, d) + for d in os.listdir(os.path.join(self.ROOT, scene)) + if os.path.isdir(os.path.join(self.ROOT, scene, d)) + ] + ) + offset = 0 + scenes = [] + sceneids = [] + images = [] + scene_img_list = [] + start_img_ids = [] + j = 0 + + for seq_idx, seq in enumerate(seq_dirs): + seq_path = osp.join(self.ROOT, seq) + for cam in ["Camera_0", "Camera_1"]: + basenames = sorted( + [ + f[:5] + for f in os.listdir(seq_path + "/" + cam) + if f.endswith(".jpg") + ] + ) + num_imgs = len(basenames) + cut_off = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + if num_imgs < cut_off: + print(f"Skipping {scene}") + continue + img_ids = list(np.arange(num_imgs) + offset) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + scenes.append(seq + "/" + cam) + scene_img_list.append(img_ids) + sceneids.extend([j] * num_imgs) + images.extend(basenames) + start_img_ids.extend(start_img_ids_) + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + + with open(osp.join(self.ROOT, f'{split}-pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list,), + f, + ) + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def get_stats(self): + return f"{len(self)} groups of views" + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + scene_id = self.sceneids[start_id] + all_image_ids = self.scene_img_list[scene_id] + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=self.max_interval, + video_prob=1.0, + fix_interval_prob=0.9, + ) + image_idxs = np.array(all_image_ids)[pos] + + views = [] + + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir = osp.join(self.ROOT, self.scenes[scene_id]) + basename = self.images[view_idx] + + img = basename + "_rgb.jpg" + image = imread_cv2(osp.join(scene_dir, img)) + depthmap = ( + cv2.imread( + osp.join(scene_dir, basename + "_depth.png"), + cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH, + ).astype(np.float32) + / 100.0 + ) + camera_params = np.load(osp.join(scene_dir, basename + "_cam.npz")) + + intrinsics = camera_params["camera_intrinsics"] + camera_pose = camera_params["camera_pose"] + + sky_mask = depthmap >= 655 + depthmap[sky_mask] = -1.0 # sky + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(scene_dir, img) + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.85, 0.1, 0.05] + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="VirtualKITTI2", + label=scene_dir, + is_metric=self.is_metric, + instance=scene_dir + "_" + img, + is_video=ordered_video, + quantile=np.array(1.0, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + assert len(views) == num_views + return views diff --git a/stream3r/dust3r/datasets_cut3r/waymo.py b/stream3r/dust3r/datasets_cut3r/waymo.py new file mode 100755 index 0000000000000000000000000000000000000000..4259af77d56ba9dccc729f903f4b7a554f70d2d6 --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/waymo.py @@ -0,0 +1,210 @@ +import os.path as osp +import pickle +import os +import numpy as np +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import h5py + +from safetensors.numpy import save_file, load_file +from stream3r.dust3r.datasets_cut3r.base.base_multiview_dataset import BaseMultiViewDataset +from stream3r.dust3r.utils.image import imread_cv2 + + +class Waymo_Multi(BaseMultiViewDataset): + """Dataset of outdoor street scenes, 5 images each time""" + + def __init__(self, *args, ROOT, **kwargs): + self.ROOT = ROOT + self.max_interval = 8 + self.video = True + self.is_metric = True + super().__init__(*args, **kwargs) + assert self.split is None + self._load_data() + + def load_invalid_dict(self, h5_file_path): + invalid_dict = {} + with h5py.File(h5_file_path, "r") as h5f: + for scene in h5f: + data = h5f[scene]["invalid_pairs"][:] + invalid_pairs = set( + tuple(pair.decode("utf-8").split("_")) for pair in data + ) + invalid_dict[scene] = invalid_pairs + return invalid_dict + + def _load_data(self): + + if os.path.exists(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl')): + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'rb') as f: + pre_calculated_data = pickle.load(f) + + self.scenes = pre_calculated_data['scenes'] + self.sceneids = pre_calculated_data['sceneids'] + self.images = pre_calculated_data['images'] + self.start_img_ids = pre_calculated_data['start_img_ids'] + self.scene_img_list = pre_calculated_data['scene_img_list'] + self.is_video = pre_calculated_data['is_video'] + + return + + + invalid_dict = self.load_invalid_dict( + os.path.join(self.ROOT, "invalid_files.h5") + ) + scene_dirs = sorted( + [ + d + for d in os.listdir(self.ROOT) + if os.path.isdir(os.path.join(self.ROOT, d)) + ] + ) + offset = 0 + scenes = [] + sceneids = [] + images = [] + start_img_ids = [] + scene_img_list = [] + is_video = [] + j = 0 + + for scene in scene_dirs: + scene_dir = osp.join(self.ROOT, scene) + if not os.path.isdir(scene_dir): + continue + invalid_pairs = invalid_dict.get(scene, set()) + seq2frames = {} + for f in os.listdir(scene_dir): + if not f.endswith(".jpg"): + continue + basename = f[:-4] + frame_id = basename.split("_")[0] + seq_id = basename.split("_")[1] + if seq_id == "5": + continue + if (seq_id, frame_id) in invalid_pairs: + continue # Skip invalid files + if seq_id not in seq2frames: + seq2frames[seq_id] = [] + seq2frames[seq_id].append(frame_id) + + for seq_id, frame_ids in seq2frames.items(): + frame_ids = sorted(frame_ids) + num_imgs = len(frame_ids) + img_ids = list(np.arange(num_imgs) + offset) + cut_off = ( + self.num_views + if not self.allow_repeat + else max(self.num_views // 3, 3) + ) + start_img_ids_ = img_ids[: num_imgs - cut_off + 1] + + if num_imgs < cut_off: + print(f"Skipping {scene}_{seq_id}") + continue + + scenes.append((scene, seq_id)) + sceneids.extend([j] * num_imgs) + images.extend(frame_ids) + start_img_ids.extend(start_img_ids_) + scene_img_list.append(img_ids) + + offset += num_imgs + j += 1 + + self.scenes = scenes + self.sceneids = sceneids + self.images = images + self.start_img_ids = start_img_ids + self.scene_img_list = scene_img_list + self.is_video = is_video + + with open(osp.join(self.ROOT, f'pre-calculated-loaddata-{self.num_views}.pkl'), 'wb') as f: + pickle.dump( + dict(scenes=self.scenes, + sceneids=self.sceneids, + images=images, + start_img_ids=start_img_ids, + scene_img_list=scene_img_list, + is_video=is_video), + f, + ) + + + def __len__(self): + return len(self.start_img_ids) + + def get_image_num(self): + return len(self.images) + + def get_stats(self): + return f"{len(self)} groups of views" + + def _get_views(self, idx, resolution, rng, num_views): + start_id = self.start_img_ids[idx] + all_image_ids = self.scene_img_list[self.sceneids[start_id]] + _, seq_id = self.scenes[self.sceneids[start_id]] + max_interval = self.max_interval // 2 if seq_id == "4" else self.max_interval + pos, ordered_video = self.get_seq_from_start_id( + num_views, + start_id, + all_image_ids, + rng, + max_interval=max_interval, + video_prob=0.9, + fix_interval_prob=0.9, + block_shuffle=16, + ) + image_idxs = np.array(all_image_ids)[pos] + views = [] + ordered_video = True + + views = [] + + for v, view_idx in enumerate(image_idxs): + scene_id = self.sceneids[view_idx] + scene_dir, seq_id = self.scenes[scene_id] + scene_dir = osp.join(self.ROOT, scene_dir) + frame_id = self.images[view_idx] + + impath = f"{frame_id}_{seq_id}" + image = imread_cv2(osp.join(scene_dir, impath + ".jpg")) + depthmap = imread_cv2(osp.join(scene_dir, impath + ".exr")) + camera_params = load_file(osp.join(scene_dir, impath + ".safetensor")) + + intrinsics = np.float32(camera_params["intrinsics"]) + camera_pose = np.float32(camera_params["cam2world"]) + + image, depthmap, intrinsics = self._crop_resize_if_necessary( + image, depthmap, intrinsics, resolution, rng, info=(scene_dir, impath) + ) + + # generate img mask and raymap mask + img_mask, ray_mask = self.get_img_and_ray_masks( + self.is_metric, v, rng, p=[0.85, 0.10, 0.05] + ) + + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=camera_pose, # cam2world + camera_intrinsics=intrinsics, + dataset="Waymo", + label=osp.relpath(scene_dir, self.ROOT), + is_metric=self.is_metric, + instance=osp.join(scene_dir, impath + ".jpg"), + is_video=ordered_video, + quantile=np.array(0.98, dtype=np.float32), + img_mask=img_mask, + ray_mask=ray_mask, + camera_only=False, + depth_only=False, + single_view=False, + reset=False, + ) + ) + + return views diff --git a/stream3r/dust3r/datasets_cut3r/wildrgbd.py b/stream3r/dust3r/datasets_cut3r/wildrgbd.py new file mode 100755 index 0000000000000000000000000000000000000000..82738e8a71aca6a44e8ca780a4f79fe93d530c2b --- /dev/null +++ b/stream3r/dust3r/datasets_cut3r/wildrgbd.py @@ -0,0 +1,57 @@ +import os.path as osp +import sys + +sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) +import cv2 +import numpy as np + +from stream3r.dust3r.datasets_cut3r.co3d import Co3d_Multi +from safetensors.numpy import save_file, load_file +from stream3r.dust3r.utils.image import imread_cv2 + + +class WildRGBD_Multi(Co3d_Multi): + def __init__(self, mask_bg="rand", *args, ROOT, **kwargs): + super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs) + self.dataset_label = "WildRGBD" + self.is_metric = True + # load all scenes + self.scenes.pop(("box", "scenes/scene_257"), None) + self.scene_list = list(self.scenes.keys()) + cut_off = ( + self.num_views if not self.allow_repeat else max(self.num_views // 3, 3) + ) + self.cut_off = cut_off + self.all_ref_imgs = [ + (key, value) + for key, values in self.scenes.items() + for value in values[: len(values) - cut_off + 1] + ] + self.invalidate = {scene: {} for scene in self.scene_list} + self.invalid_scenes = {scene: False for scene in self.scene_list} + + def _get_metadatapath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "metadata", f"{view_idx:0>5d}.safetensor") + + def _get_impath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "rgb", f"{view_idx:0>5d}.jpg") + + def _get_depthpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "depth", f"{view_idx:0>5d}.png") + + def _get_maskpath(self, obj, instance, view_idx): + return osp.join(self.ROOT, obj, instance, "masks", f"{view_idx:0>5d}.png") + + def _read_depthmap(self, depthpath, input_metadata): + # We store depths in the depth scale of 1000. + # That is, when we load depth image and divide by 1000, we could get depth in meters. + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + depthmap = depthmap.astype(np.float32) / 1000.0 + return depthmap + + def _get_views(self, idx, resolution, rng, num_views): + views = super()._get_views(idx, resolution, rng, num_views) + for view in views: + assert view["is_metric"] + view["quantile"] = np.array(0.96, dtype=np.float32) + return views diff --git a/stream3r/dust3r/heads/__init__.py b/stream3r/dust3r/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f27fdc0d1577d588c3488e5bdd2a1111b8136a8 --- /dev/null +++ b/stream3r/dust3r/heads/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# head factory +# -------------------------------------------------------- +from .dpt_head import create_dpt_head +from .linear_head import LinearPts3d + + +def head_factory(head_type, output_mode, net, has_conf=False): + """ " build a prediction head for the decoder""" + if head_type == "linear" and output_mode == "pts3d": + return LinearPts3d(net, has_conf) + elif head_type == "dpt" and output_mode == "pts3d": + return create_dpt_head(net, has_conf=has_conf) + else: + raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}") diff --git a/stream3r/dust3r/heads/dpt_head.py b/stream3r/dust3r/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb2bcffacb23adfe07d52082f09988e9fcd1026 --- /dev/null +++ b/stream3r/dust3r/heads/dpt_head.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# dpt head implementation for DUST3R +# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ; +# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True +# the forward function also takes as input a dictionnary img_info with key "height" and "width" +# for PixelwiseTask, the output will be of dimension B x num_channels x H x W +# -------------------------------------------------------- +from typing import List + +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F + +from stream3r.croco.models.dpt_block import DPTOutputAdapter +from stream3r.croco.models.blocks import Mlp + +import stream3r.dust3r.utils.path_to_croco # noqa: F401 +from stream3r.dust3r.heads.postprocess import postprocess + + +class DPTOutputAdapter_fix(DPTOutputAdapter): + """ + Adapt croco's DPTOutputAdapter implementation for dust3r: + remove duplicated weigths, and fix forward for dust3r + """ + + def init(self, dim_tokens_enc=768): + super().init(dim_tokens_enc) + # these are duplicated weights + del self.act_1_postprocess + del self.act_2_postprocess + del self.act_3_postprocess + del self.act_4_postprocess + + def forward(self, encoder_tokens: List[torch.Tensor], image_size=None): + assert ( + self.dim_tokens_enc is not None + ), "Need to call init(dim_tokens_enc) function first" + # H, W = input_info['image_size'] + image_size = self.image_size if image_size is None else image_size + H, W = image_size + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + # Hook decoder onto 4 layers from specified ViT layers + layers = [encoder_tokens[hook] for hook in self.hooks] + + # Extract only task-relevant tokens and ignore global tokens. + layers = [self.adapt_tokens(l) for l in layers] + + # Reshape tokens to spatial representation + layers = [ + rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers + ] + + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + # Project layers to chosen feature dim + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[3])[ + :, :, : layers[2].shape[2], : layers[2].shape[3] + ] + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + # Split input into chunks to avoid memory issues with large batches + if self.training: + max_chunk_size = 1 + else: + max_chunk_size = 50 + chunks = torch.split(path_1, max_chunk_size, dim=0) + outputs = [] + + for chunk in chunks: + out_chunk = self.head(chunk) + outputs.append(out_chunk) + + # Concatenate outputs along the batch dimension + out = torch.cat(outputs, dim=0) + return out + + +class PixelwiseTaskWithDPT(nn.Module): + """DPT module for dust3r, can return 3D points + confidence for all pixels""" + + def __init__( + self, + *, + n_cls_token=0, + hooks_idx=None, + dim_tokens=None, + output_width_ratio=1, + num_channels=1, + postprocess=None, + depth_mode=None, + conf_mode=None, + vis_mode=None, + **kwargs + ): + super(PixelwiseTaskWithDPT, self).__init__() + self.return_all_layers = True # backbone needs to return all layers + self.postprocess = postprocess + self.depth_mode = depth_mode + self.conf_mode = conf_mode + self.vis_mode = vis_mode + + assert n_cls_token == 0, "Not implemented" + dpt_args = dict( + output_width_ratio=output_width_ratio, num_channels=num_channels, **kwargs + ) + if hooks_idx is not None: + dpt_args.update(hooks=hooks_idx) + self.dpt = DPTOutputAdapter_fix(**dpt_args) + dpt_init_args = {} if dim_tokens is None else {"dim_tokens_enc": dim_tokens} + self.dpt.init(**dpt_init_args) + + def forward(self, x, img_info): + out = self.dpt(x, image_size=(img_info[0], img_info[1])) + if self.postprocess: + out = self.postprocess(out, self.depth_mode, self.conf_mode, self.vis_mode) + return out + + +class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT): + """ Mixture between MLP and DPT head that outputs 3d points and local features (with MLP). + The input for both heads is a concatenation of Encoder and Decoder outputs + """ + + def __init__( + self, + desc_head_args, + local_feat_dim=16, + hidden_dim_factor=4., + hooks_idx=None, + dim_tokens=None, + num_channels=1, + postprocess=None, + feature_dim=256, + last_dim=32, + depth_mode=None, + conf_mode=None, + head_type="regression", + **kwargs + ): + super().__init__( + num_channels=num_channels, + feature_dim=feature_dim, + last_dim=last_dim, + hooks_idx=hooks_idx, + dim_tokens=dim_tokens, + depth_mode=depth_mode, + postprocess=postprocess, + conf_mode=conf_mode, + head_type=head_type + ) + self.local_feat_dim = local_feat_dim + + patch_size = desc_head_args['patch_size'] + if isinstance(patch_size, tuple): + assert len(patch_size) == 2 and isinstance(patch_size[0], int) and isinstance( + patch_size[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints." + assert patch_size[0] == patch_size[1], "Error, non square patches not managed" + patch_size = patch_size[0] + self.patch_size = patch_size + + self.desc_mode = desc_head_args['desc_mode'] + self.two_confs = desc_head_args['two_confs'] # independent confs for 3D regr and descs + self.desc_conf_mode = desc_head_args['desc_conf_mode'] + idim = desc_head_args['enc_embed_dim'] + desc_head_args['dec_embed_dim'] + + self.features_head = Mlp( + in_features = idim, + hidden_features = int(hidden_dim_factor * idim), + out_features = (self.local_feat_dim + self.two_confs) * self.patch_size**2 + ) + + def forward(self, decout, img_shape): + # pass through the heads + pts3d = self.dpt(decout, image_size=(img_shape[0], img_shape[1])) + + # recover encoder and decoder outputs + enc_output, dec_output = decout[0], decout[-1] + cat_output = torch.cat([enc_output, dec_output], dim=-1) # concatenate + H, W = img_shape + B, S, D = cat_output.shape + + # extract local_features + local_features = self.features_head(cat_output) # B,S,D + local_features = local_features.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size) + local_features = F.pixel_shuffle(local_features, self.patch_size) # B,d,H,W + + # post process 3D pts, descriptors and confidences + out = torch.cat([pts3d, local_features], dim=1) + if self.postprocess: + out = self.postprocess( + out, + depth_mode=self.depth_mode, + conf_mode=self.conf_mode, + desc_dim=self.local_feat_dim, + desc_mode=self.desc_mode, + two_confs=self.two_confs, + desc_conf_mode=self.desc_conf_mode + ) + + return out + + +def create_dpt_head(net, has_conf=False): + """ + return PixelwiseTaskWithDPT for given net params + """ + assert net.dec_depth > 9 + l2 = net.dec_depth + feature_dim = 256 + last_dim = feature_dim // 2 + out_nchan = 3 + ed = net.enc_embed_dim + dd = net.dec_embed_dim + return PixelwiseTaskWithDPT( + num_channels=out_nchan + has_conf, + feature_dim=feature_dim, + last_dim=last_dim, + hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2], + dim_tokens=[ed, dd, dd, dd], + postprocess=postprocess, + depth_mode=net.depth_mode, + conf_mode=net.conf_mode, + head_type="regression", + ) diff --git a/stream3r/dust3r/heads/linear_head.py b/stream3r/dust3r/heads/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..be86d749cfd7ad9a7aed1fde72576df92f8767cf --- /dev/null +++ b/stream3r/dust3r/heads/linear_head.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# linear head implementation for DUST3R +# -------------------------------------------------------- +import torch.nn as nn +import torch.nn.functional as F + +from stream3r.dust3r.heads.postprocess import postprocess + + +class LinearPts3d(nn.Module): + """ + Linear head for dust3r + Each token outputs: - 16x16 3D points (+ confidence) + """ + + def __init__(self, net, has_conf=False): + super().__init__() + self.patch_size = net.patch_embed.patch_size[0] + self.depth_mode = net.depth_mode + self.conf_mode = net.conf_mode + self.has_conf = has_conf + + self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf) * self.patch_size**2) + + def setup(self, croconet): + pass + + def forward(self, decout, img_shape): + H, W = img_shape + tokens = decout[-1] + B, S, D = tokens.shape + + # extract 3D points + feat = self.proj(tokens) # B,S,D + feat = feat.transpose(-1, -2).view( + B, -1, H // self.patch_size, W // self.patch_size + ) + feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W + + # permute + norm depth + return postprocess(feat, self.depth_mode, self.conf_mode) diff --git a/stream3r/dust3r/heads/postprocess.py b/stream3r/dust3r/heads/postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..5a088a73b23a807607b69d87064269e42f43af13 --- /dev/null +++ b/stream3r/dust3r/heads/postprocess.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# post process function for all heads: extract 3D points/confidence from output +# -------------------------------------------------------- +import torch + + +def reg_desc(desc, mode): + if 'norm' in mode: + desc = desc / desc.norm(dim=-1, keepdim=True) + else: + raise ValueError(f"Unknown desc mode {mode}") + return desc + + +def postprocess_with_feature(out, depth_mode, conf_mode, desc_dim=None, desc_mode='norm', two_confs=False, desc_conf_mode=None): + if desc_conf_mode is None: + desc_conf_mode = conf_mode + fmap = out.permute(0, 2, 3, 1) # B,H,W,D + res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode)) + if conf_mode is not None: + res['conf'] = reg_dense_conf(fmap[..., 3], mode=conf_mode) + if desc_dim is not None: + start = 3 + int(conf_mode is not None) + res['desc'] = reg_desc(fmap[..., start:start + desc_dim], mode=desc_mode) + if two_confs: + res['desc_conf'] = reg_dense_conf(fmap[..., start + desc_dim], mode=desc_conf_mode) + else: + res['desc_conf'] = res['conf'].clone() + return res + + +def postprocess(out, depth_mode, conf_mode, vis_mode): + """ + extract 3D points/confidence from prediction head output + """ + fmap = out.permute(0, 2, 3, 1) # B,H,W,3 + res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode)) + + if conf_mode is not None: + res["conf"] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode) + + if vis_mode is not None: + res["vis"] = reg_dense_conf(fmap[:, :, :, 3:4], mode=vis_mode) + + return res + + +def reg_dense_depth(xyz, mode): + """ + extract 3D points from prediction head output + """ + mode, vmin, vmax = mode + + no_bounds = (vmin == -float("inf")) and (vmax == float("inf")) + assert no_bounds + + if mode == "linear": + if no_bounds: + return xyz # [-inf, +inf] + return xyz.clip(min=vmin, max=vmax) + + # distance to origin + d = xyz.norm(dim=-1, keepdim=True) + xyz = xyz / d.clip(min=1e-8) + + if mode == "square": + return xyz * d.square() + + if mode == "exp": + return xyz * torch.expm1(d) + + raise ValueError(f"bad {mode=}") + + +def reg_dense_conf(x, mode): + """ + extract confidence from prediction head output + """ + mode, vmin, vmax = mode + if mode == "exp": + return vmin + x.exp().clip(max=vmax - vmin) + if mode == "sigmoid": + return (vmax - vmin) * torch.sigmoid(x) + vmin + if mode == "none": + return x + raise ValueError(f"bad {mode=}") diff --git a/stream3r/dust3r/image_pairs.py b/stream3r/dust3r/image_pairs.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1344e6e068ca083e5142b8e7a6cbc0f9e52b6b --- /dev/null +++ b/stream3r/dust3r/image_pairs.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilities needed to load image pairs +# -------------------------------------------------------- +import numpy as np +import torch + + +def make_pairs(imgs, scene_graph="complete", prefilter=None, symmetrize=True): + pairs = [] + if scene_graph == "complete": # complete graph + for i in range(len(imgs)): + for j in range(i): + pairs.append((imgs[i], imgs[j])) + elif scene_graph.startswith("swin"): + winsize = int(scene_graph.split("-")[1]) if "-" in scene_graph else 3 + pairsid = set() + for i in range(len(imgs)): + for j in range(1, winsize + 1): + idx = (i + j) % len(imgs) # explicit loop closure + pairsid.add((i, idx) if i < idx else (idx, i)) + for i, j in pairsid: + pairs.append((imgs[i], imgs[j])) + elif scene_graph.startswith("oneref"): + refid = int(scene_graph.split("-")[1]) if "-" in scene_graph else 0 + for j in range(len(imgs)): + if j != refid: + pairs.append((imgs[refid], imgs[j])) + if symmetrize: + pairs += [(img2, img1) for img1, img2 in pairs] + + # now, remove edges + if isinstance(prefilter, str) and prefilter.startswith("seq"): + pairs = filter_pairs_seq(pairs, int(prefilter[3:])) + + if isinstance(prefilter, str) and prefilter.startswith("cyc"): + pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True) + + return pairs + + +def sel(x, kept): + if isinstance(x, dict): + return {k: sel(v, kept) for k, v in x.items()} + if isinstance(x, (torch.Tensor, np.ndarray)): + return x[kept] + if isinstance(x, (tuple, list)): + return type(x)([x[k] for k in kept]) + + +def _filter_edges_seq(edges, seq_dis_thr, cyclic=False): + # number of images + n = max(max(e) for e in edges) + 1 + + kept = [] + for e, (i, j) in enumerate(edges): + dis = abs(i - j) + if cyclic: + dis = min(dis, abs(i + n - j), abs(i - n - j)) + if dis <= seq_dis_thr: + kept.append(e) + return kept + + +def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False): + edges = [(img1["idx"], img2["idx"]) for img1, img2 in pairs] + kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) + return [pairs[i] for i in kept] + + +def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False): + edges = [(int(i), int(j)) for i, j in zip(view1["idx"], view2["idx"])] + kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) + print( + f">> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges" + ) + return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept) diff --git a/stream3r/dust3r/inference.py b/stream3r/dust3r/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f7437a92eaff517e984824529b6df18c89f83d66 --- /dev/null +++ b/stream3r/dust3r/inference.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilities needed for the inference +# -------------------------------------------------------- +import torch +import tqdm + +from stream3r.dust3r.utils.device import collate_with_cat, to_cpu +from stream3r.dust3r.utils.geometry import depthmap_to_pts3d, geotrf +from stream3r.dust3r.utils.misc import invalid_to_nans + + +def _interleave_imgs(img1, img2): + res = {} + for key, value1 in img1.items(): + value2 = img2[key] + if isinstance(value1, torch.Tensor): + value = torch.stack((value1, value2), dim=1).flatten(0, 1) + else: + value = [x for pair in zip(value1, value2) for x in pair] + res[key] = value + return res + + +def make_batch_symmetric(batch): + view1, view2 = batch + view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1)) + return view1, view2 + + +def loss_of_one_batch( + batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None +): + view1, view2 = batch + for view in batch: + for ( + name + ) in ( + "img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres".split() + ): # pseudo_focal + if name not in view: + continue + view[name] = view[name].to(device, non_blocking=True) + + if symmetrize_batch: + view1, view2 = make_batch_symmetric(batch) + + with torch.cuda.amp.autocast(enabled=bool(use_amp)): + pred1, pred2 = model(view1, view2) + + # loss is supposed to be symmetric + with torch.cuda.amp.autocast(enabled=False): + loss = ( + criterion(view1, view2, pred1, pred2) if criterion is not None else None + ) + + result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss) + return result[ret] if ret else result + + +@torch.no_grad() +def inference(pairs, model, device, batch_size=8, verbose=True): + if verbose: + print(f">> Inference with model on {len(pairs)} image pairs") + result = [] + + # first, check if all images have the same size + multiple_shapes = not (check_if_same_size(pairs)) + if multiple_shapes: # force bs=1 + batch_size = 1 + + for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose): + res = loss_of_one_batch( + collate_with_cat(pairs[i : i + batch_size]), model, None, device + ) + result.append(to_cpu(res)) + + result = collate_with_cat(result, lists=multiple_shapes) + + return result + + +def check_if_same_size(pairs): + shapes1 = [img1["img"].shape[-2:] for img1, img2 in pairs] + shapes2 = [img2["img"].shape[-2:] for img1, img2 in pairs] + return all(shapes1[0] == s for s in shapes1) and all( + shapes2[0] == s for s in shapes2 + ) + + +def get_pred_pts3d(gt, pred, use_pose=False): + if "depth" in pred and "pseudo_focal" in pred: + try: + pp = gt["camera_intrinsics"][..., :2, 2] + except KeyError: + pp = None + pts3d = depthmap_to_pts3d(**pred, pp=pp) + + elif "pts3d" in pred: + # pts3d from my camera + pts3d = pred["pts3d"] + + elif "pts3d_in_other_view" in pred: + # pts3d from the other camera, already transformed + assert use_pose is True + return pred["pts3d_in_other_view"] # return! + + if use_pose: + camera_pose = pred.get("camera_pose") + assert camera_pose is not None + pts3d = geotrf(camera_pose, pts3d) + + return pts3d + + +def find_opt_scaling( + gt_pts1, + gt_pts2, + pr_pts1, + pr_pts2=None, + fit_mode="weiszfeld_stop_grad", + valid1=None, + valid2=None, +): + assert gt_pts1.ndim == pr_pts1.ndim == 4 + assert gt_pts1.shape == pr_pts1.shape + if gt_pts2 is not None: + assert gt_pts2.ndim == pr_pts2.ndim == 4 + assert gt_pts2.shape == pr_pts2.shape + + # concat the pointcloud + nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2) + nan_gt_pts2 = ( + invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None + ) + + pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2) + pr_pts2 = ( + invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None + ) + + all_gt = ( + torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) + if gt_pts2 is not None + else nan_gt_pts1 + ) + all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1 + + dot_gt_pr = (all_pr * all_gt).sum(dim=-1) + dot_gt_gt = all_gt.square().sum(dim=-1) + + if fit_mode.startswith("avg"): + # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1) + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + elif fit_mode.startswith("median"): + scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values + elif fit_mode.startswith("weiszfeld"): + # init scaling with l2 closed form + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + # iterative re-weighted least-squares + for iter in range(10): + # re-weighting by inverse of distance + dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1) + # print(dis.nanmean(-1)) + w = dis.clip_(min=1e-8).reciprocal() + # update the scaling with the new weights + scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1) + else: + raise ValueError(f"bad {fit_mode=}") + + if fit_mode.endswith("stop_grad"): + scaling = scaling.detach() + + scaling = scaling.clip(min=1e-3) + # assert scaling.isfinite().all(), bb() + return scaling diff --git a/stream3r/dust3r/inference_multiview.py b/stream3r/dust3r/inference_multiview.py new file mode 100644 index 0000000000000000000000000000000000000000..58fea88e27db70ffa3d6ac937df00da25486beaf --- /dev/null +++ b/stream3r/dust3r/inference_multiview.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilities needed for the inference +# -------------------------------------------------------- +import time +import torch +import tqdm + +from stream3r.dust3r.utils.device import collate_with_cat, to_cpu +from stream3r.dust3r.utils.geometry import depthmap_to_pts3d, geotrf +from stream3r.dust3r.utils.misc import invalid_to_nans + + +def loss_of_one_batch( + batch, model, criterion, device, precision, symmetrize_batch=False, use_amp=False, ret=None, profiling=False, +): + """ + Args: + batch (list[dict]): a list of views, each view is a dict of tensors, the tensors are batched + """ + for view in batch: + for ( + name + ) in ( + "img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres".split() + ): # pseudo_focal + if name not in view: + continue + view[name] = view[name].to(device, non_blocking=True) + + views = batch + + autocast_dict = dict(device_type=device.type) + if precision == "32": + autocast_dict["enabled"] = False + elif precision == "16-mixed": + autocast_dict["dtype"] = torch.float16 + elif precision in ["bf16-mixed", "bf16-mixed-no-grad-scaling"]: + autocast_dict["dtype"] = torch.bfloat16 + elif precision == torch.bfloat16: + autocast_dict["dtype"] = torch.bfloat16 + + + with torch.autocast(**autocast_dict): + if profiling: + preds, profiling_info = model(views, profiling=profiling) + else: + preds = model(views, profiling=profiling) + + # loss is supposed to be symmetric + loss = ( + criterion(views, preds) if criterion is not None else None + ) + + result = dict(views=views, preds=preds, loss=loss) + if profiling: + result["profiling_info"] = profiling_info + + return result[ret] if ret else result + + +@torch.no_grad() +def inference(multiple_views_in_one_sample, model, device, dtype, verbose=True, profiling=False, use_center_as_anchor=False): + if verbose: + print(f">> Inference with model on {len(multiple_views_in_one_sample)} images") + result = [] + + if use_center_as_anchor: + center_image_id = len(multiple_views_in_one_sample) // 2 + center_image = multiple_views_in_one_sample[center_image_id] + multiple_views_in_one_sample.pop(center_image_id) + multiple_views_in_one_sample.insert(0, center_image) + + # first, check if all images have the same size + multiple_shapes = not (check_if_same_size(multiple_views_in_one_sample)) + if multiple_shapes: # force bs=1 + batch_size = 1 + + # Get the result from loss_of_one_batch + res = loss_of_one_batch( + collate_with_cat([tuple(multiple_views_in_one_sample)]), model, None, device, dtype, profiling=profiling + ) + + # Extract profiling_info before to_cpu if it exists + profiling_info = None + if profiling and "profiling_info" in res: + profiling_info = res.pop("profiling_info") + + # Process the result without profiling_info + result.append(to_cpu(res)) + result = collate_with_cat(result, lists=multiple_shapes) + + # this would cause a bug as tracking rgb is still using the first view, but we just ignore it for now + if use_center_as_anchor: + # k is "preds", "views", "losses" (could be None) + for k, v in result.items(): + if v is not None: + center_result = v[0] + result[k] = v[1:] + result[k].insert(center_image_id, center_result) + + # Return the result with profiling_info if requested + if profiling and profiling_info is not None: + return result, profiling_info + + return result + + +def check_if_same_size(imgs): + shapes = [img["img"].shape[-2:] for img in imgs] + return all(shape == shapes[0] for shape in shapes) + + +def get_pred_pts3d(gt, pred, use_pose=False): + if "depth" in pred and "pseudo_focal" in pred: + try: + pp = gt["camera_intrinsics"][..., :2, 2] + except KeyError: + pp = None + pts3d = depthmap_to_pts3d(**pred, pp=pp) + + elif "pts3d" in pred: + # pts3d from my camera + pts3d = pred["pts3d"] + + elif "pts3d_in_other_view" in pred: + # pts3d from the other camera, already transformed + assert use_pose is True + return pred["pts3d_in_other_view"] # return! + + if use_pose: + camera_pose = pred.get("camera_pose") + assert camera_pose is not None + pts3d = geotrf(camera_pose, pts3d) + + return pts3d + + +def find_opt_scaling( + gt_pts1, + gt_pts2, + pr_pts1, + pr_pts2=None, + fit_mode="weiszfeld_stop_grad", + valid1=None, + valid2=None, +): + assert gt_pts1.ndim == pr_pts1.ndim == 4 + assert gt_pts1.shape == pr_pts1.shape + if gt_pts2 is not None: + assert gt_pts2.ndim == pr_pts2.ndim == 4 + assert gt_pts2.shape == pr_pts2.shape + + # concat the pointcloud + nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2) + nan_gt_pts2 = ( + invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None + ) + + pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2) + pr_pts2 = ( + invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None + ) + + all_gt = ( + torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) + if gt_pts2 is not None + else nan_gt_pts1 + ) + all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1 + + dot_gt_pr = (all_pr * all_gt).sum(dim=-1) + dot_gt_gt = all_gt.square().sum(dim=-1) + + if fit_mode.startswith("avg"): + # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1) + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + elif fit_mode.startswith("median"): + scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values + elif fit_mode.startswith("weiszfeld"): + # init scaling with l2 closed form + scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) + # iterative re-weighted least-squares + for iter in range(10): + # re-weighting by inverse of distance + dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1) + # print(dis.nanmean(-1)) + w = dis.clip_(min=1e-8).reciprocal() + # update the scaling with the new weights + scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1) + else: + raise ValueError(f"bad {fit_mode=}") + + if fit_mode.endswith("stop_grad"): + scaling = scaling.detach() + + scaling = scaling.clip(min=1e-3) + # assert scaling.isfinite().all(), bb() + return scaling diff --git a/stream3r/dust3r/model.py b/stream3r/dust3r/model.py new file mode 100644 index 0000000000000000000000000000000000000000..287335064fa33e18efaf1282e88e7ae1fc0a4523 --- /dev/null +++ b/stream3r/dust3r/model.py @@ -0,0 +1,615 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# DUSt3R model class +# -------------------------------------------------------- +import os +from copy import deepcopy + +import huggingface_hub +import torch +import torch.distributed +import torch.nn as nn +import numpy as np +from stream3r.croco.models.croco import CroCoNet +from stream3r.croco.models.blocks import Block +from stream3r.croco.models.pos_embed import get_1d_sincos_pos_embed_from_grid +from packaging import version + +from stream3r.dust3r.patch_embed import get_patch_embed + +from .heads import head_factory +from .utils.misc import ( + fill_default_args, + freeze_all_params, + interleave, + is_symmetrized, + transpose_to_landscape, +) +import torch.autograd.profiler as profiler + +inf = float("inf") + +hf_version_number = huggingface_hub.__version__ +assert version.parse(hf_version_number) >= version.parse( + "0.22.0" +), "Outdated huggingface_hub version, please reinstall requirements.txt" + + +def load_model(model_path, device, verbose=True): + if verbose: + print("... loading model from", model_path) + ckpt = torch.load(model_path, map_location="cpu") + args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") + if "landscape_only" not in args: + args = args[:-1] + ", landscape_only=False)" + else: + args = args.replace(" ", "").replace( + "landscape_only=True", "landscape_only=False" + ) + assert "landscape_only=False" in args + if verbose: + print(f"instantiating : {args}") + net = eval(args) + s = net.load_state_dict(ckpt["model"], strict=False) + if verbose: + print(s) + return net.to(device) + + +class AsymmetricCroCo3DStereo( + CroCoNet, + huggingface_hub.PyTorchModelHubMixin, + library_name="dust3r", + repo_url="https://github.com/naver/dust3r", + tags=["image-to-3d"], +): + """Two siamese encoders, followed by two decoders. + The goal is to output 3d points directly, both images in view1's frame + (hence the asymmetry). + """ + + def __init__( + self, + output_mode="pts3d", + head_type="linear", + depth_mode=("exp", -inf, inf), + conf_mode=("exp", 1, inf), + freeze="none", + landscape_only=True, + patch_embed_cls="PatchEmbedDust3R", # PatchEmbedDust3R or ManyAR_PatchEmbed + **croco_kwargs, + ): + self.patch_embed_cls = patch_embed_cls + self.croco_args = fill_default_args(croco_kwargs, super().__init__) + super().__init__(**croco_kwargs) + + # dust3r specific initialization + self.dec_blocks2 = deepcopy(self.dec_blocks) + self.set_downstream_head( + output_mode, + head_type, + landscape_only, + depth_mode, + conf_mode, + **croco_kwargs, + ) + self.set_freeze(freeze) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kw): + if os.path.isfile(pretrained_model_name_or_path): + return load_model(pretrained_model_name_or_path, device="cpu") + else: + return super(AsymmetricCroCo3DStereo, cls).from_pretrained( + pretrained_model_name_or_path, **kw + ) + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = get_patch_embed( + self.patch_embed_cls, img_size, patch_size, enc_embed_dim + ) + + def load_state_dict(self, ckpt, **kw): + # duplicate all weights for the second decoder if not present + new_ckpt = dict(ckpt) + if not any(k.startswith("dec_blocks2") for k in ckpt): + for key, value in ckpt.items(): + if key.startswith("dec_blocks"): + new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value + return super().load_state_dict(new_ckpt, **kw) + + def set_freeze(self, freeze): # this is for use by downstream models + self.freeze = freeze + to_be_frozen = { + "none": [], + "mask": [self.mask_token], + "encoder": [self.mask_token, self.patch_embed, self.enc_blocks], + } + freeze_all_params(to_be_frozen[freeze]) + + def _set_prediction_head(self, *args, **kwargs): + """No prediction head""" + return + + def set_downstream_head( + self, + output_mode, + head_type, + landscape_only, + depth_mode, + conf_mode, + patch_size, + img_size, + **kw, + ): + assert ( + img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0 + ), f"{img_size=} must be multiple of {patch_size=}" + self.output_mode = output_mode + self.head_type = head_type + self.depth_mode = depth_mode + self.conf_mode = conf_mode + # allocate heads + self.downstream_head1 = head_factory( + head_type, output_mode, self, has_conf=bool(conf_mode) + ) + self.downstream_head2 = head_factory( + head_type, output_mode, self, has_conf=bool(conf_mode) + ) + # magic wrapper + self.head1 = transpose_to_landscape( + self.downstream_head1, activate=landscape_only + ) + self.head2 = transpose_to_landscape( + self.downstream_head2, activate=landscape_only + ) + + def _encode_image(self, image, true_shape): + # embed the image into patches (x has size B x Npatches x C) + x, pos = self.patch_embed(image, true_shape=true_shape) + + # add positional embedding without cls token + assert self.enc_pos_embed is None + + # now apply the transformer encoder and normalization + for blk in self.enc_blocks: + x = blk(x, pos) + + x = self.enc_norm(x) + return x, pos, None + + def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2): + if img1.shape[-2:] == img2.shape[-2:]: + out, pos, _ = self._encode_image( + torch.cat((img1, img2), dim=0), + torch.cat((true_shape1, true_shape2), dim=0), + ) + out, out2 = out.chunk(2, dim=0) + pos, pos2 = pos.chunk(2, dim=0) + else: + out, pos, _ = self._encode_image(img1, true_shape1) + out2, pos2, _ = self._encode_image(img2, true_shape2) + return out, out2, pos, pos2 + + def _encode_symmetrized(self, view1, view2): + img1 = view1["img"] + img2 = view2["img"] + B = img1.shape[0] + # Recover true_shape when available, otherwise assume that the img shape is the true one + shape1 = view1.get( + "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1) + ) + shape2 = view2.get( + "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1) + ) + # warning! maybe the images have different portrait/landscape orientations + + if is_symmetrized(view1, view2): + # computing half of forward pass!' + feat1, feat2, pos1, pos2 = self._encode_image_pairs( + img1[::2], img2[::2], shape1[::2], shape2[::2] + ) + feat1, feat2 = interleave(feat1, feat2) + pos1, pos2 = interleave(pos1, pos2) + else: + feat1, feat2, pos1, pos2 = self._encode_image_pairs( + img1, img2, shape1, shape2 + ) + + return (shape1, shape2), (feat1, feat2), (pos1, pos2) + + def _decoder(self, f1, pos1, f2, pos2): + final_output = [(f1, f2)] # before projection + + # project to decoder dim + f1 = self.decoder_embed(f1) + f2 = self.decoder_embed(f2) + + final_output.append((f1, f2)) + for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2): + # img1 side + f1, _ = blk1(*final_output[-1][::+1], pos1, pos2) + # img2 side + f2, _ = blk2(*final_output[-1][::-1], pos2, pos1) + # store the result + final_output.append((f1, f2)) + + # normalize last output + del final_output[1] # duplicate with final_output[0] + final_output[-1] = tuple(map(self.dec_norm, final_output[-1])) + return zip(*final_output) + + def _downstream_head(self, head_num, decout, img_shape): + B, S, D = decout[-1].shape + # img_shape = tuple(map(int, img_shape)) + head = getattr(self, f"head{head_num}") + return head(decout, img_shape) + + def forward(self, view1, view2): + # encode the two images --> B,S,D + (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized( + view1, view2 + ) + + # combine all ref images into object-centric representation + dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2) + + with torch.cuda.amp.autocast(enabled=False): + res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1) + res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2) + + res2["pts3d_in_other_view"] = res2.pop( + "pts3d" + ) # predict view2's pts3d in view1's frame + return res1, res2 + + +class FlashDUSt3R( + CroCoNet, + huggingface_hub.PyTorchModelHubMixin, + library_name="dust3r", + repo_url="https://github.com/naver/dust3r", + tags=["image-to-3d"], +): + """Two siamese encoders, followed by a single large decoder. + The goal is to output 3d points directly, processing multiple views. + """ + + def __init__( + self, + output_mode="pts3d", + head_type="linear", + depth_mode=("exp", -inf, inf), + conf_mode=("exp", 1, inf), + freeze="none", + landscape_only=True, + patch_embed_cls="PatchEmbedDust3R", # PatchEmbedDust3R or ManyAR_PatchEmbed + decoder_pos_embed_type="sinusoidal", + attn_implementation="pytorch_naive", + random_image_idx_embedding=False, + **croco_kwargs, + ): + self.patch_embed_cls = patch_embed_cls + self.random_image_idx_embedding = random_image_idx_embedding + self.croco_args = fill_default_args(croco_kwargs, super().__init__) + croco_kwargs["attn_implementation"] = attn_implementation + super().__init__(**croco_kwargs) + + # Pre-initialize image position embeddings for IDs 0 to 9999 + self.register_buffer( + "image_idx_emb", + torch.from_numpy( + get_1d_sincos_pos_embed_from_grid(self.dec_embed_dim, np.arange(1000)) + ).float(), + persistent=False, + ) + + del self.dec_blocks # remove the decoder blocks + torch.cuda.empty_cache() + # dust3r specific initialization + self.decoder_pos_embed_type = decoder_pos_embed_type + self.multiview_dec_blocks = nn.ModuleList([ + Block( + dim=self.dec_embed_dim, num_heads=8, mlp_ratio=4.0, qkv_bias=True, drop=0.0, attn_drop=0.0, norm_layer=nn.LayerNorm, attn_implementation=attn_implementation, + ) for _ in range(12) + ]) + self.set_downstream_head( + output_mode, + head_type, + landscape_only, + depth_mode, + conf_mode, + **croco_kwargs, + ) + self.set_freeze(freeze) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kw): + if os.path.isfile(pretrained_model_name_or_path): + return load_model(pretrained_model_name_or_path, device="cpu") + else: + return super(FlashDUSt3R, cls).from_pretrained( + pretrained_model_name_or_path, **kw + ) + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = get_patch_embed( + self.patch_embed_cls, img_size, patch_size, enc_embed_dim + ) + + def load_state_dict(self, ckpt, **kw): + return super().load_state_dict(ckpt, **kw) + + def set_freeze(self, freeze): # this is for use by downstream models + self.freeze = freeze + to_be_frozen = { + "none": [], + "mask": [self.mask_token], + "encoder": [self.mask_token, self.patch_embed, self.enc_blocks], + "sandwich": [self.mask_token, self.patch_embed, self.enc_blocks, self.downstream_head], + } + freeze_all_params(to_be_frozen[freeze]) + + def _set_prediction_head(self, *args, **kwargs): + """No prediction head""" + return + + def set_downstream_head( + self, + output_mode, + head_type, + landscape_only, + depth_mode, + conf_mode, + patch_size, + img_size, + **kw, + ): + assert ( + img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0 + ), f"{img_size=} must be multiple of {patch_size=}" + self.output_mode = output_mode + self.head_type = head_type + self.depth_mode = depth_mode + self.conf_mode = conf_mode + # allocate head + self.downstream_head = head_factory( + head_type, output_mode, self, has_conf=bool(conf_mode) + ) + # magic wrapper + self.head = transpose_to_landscape( + self.downstream_head, activate=landscape_only + ) + + def _encode_image(self, image, true_shape): + # embed the image into patches (x has size B x Npatches x C) + x, pos = self.patch_embed(image, true_shape=true_shape) + + # add positional embedding without cls token + assert self.enc_pos_embed is None + + # now apply the transformer encoder and normalization + for blk in self.enc_blocks: + x = blk(x, pos) + + x = self.enc_norm(x) + return x, pos + + def _encode_images(self, views): + B = views[0]["img"].shape[0] + encoded_feats, positions, shapes = [], [], [] + + # TODO: Batchify this + for view in views: + img = view["img"] + true_shape = view.get( + "true_shape", torch.tensor(img.shape[-2:])[None].repeat(B, 1) + ) + feat, pos = self._encode_image(img, true_shape) + encoded_feats.append(feat) + positions.append(pos) + shapes.append(true_shape) + + return encoded_feats, positions, shapes + + def _generate_per_rank_generator(self): + # this way, the randperm will be different for each rank, but deterministic given a fixed number of forward passes (tracked by self.random_generator) + # and to ensure determinism when resuming from a checkpoint, we only need to save self.random_generator to state_dict + # generate a per-rank random seed + per_forward_pass_seed = torch.randint(0, 2 ** 32, (1,)).item() + world_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + per_rank_seed = per_forward_pass_seed + world_rank + + # Set the seed for the random generator + per_rank_generator = torch.Generator() + per_rank_generator.manual_seed(per_rank_seed) + return per_rank_generator + + def _get_random_image_pos(self, encoded_feats, batch_size, num_views, max_image_idx, device): + """ + Generates non-repeating random image indices for each sample, retrieves corresponding + positional embeddings for each view, and concatenates them. + + Args: + encoded_feats (list of tensors): Encoded features for each view. + batch_size (int): Number of samples in the batch. + num_views (int): Number of views per sample. + max_image_idx (int): Maximum image index for embedding. + device (torch.device): Device to move data to. + + Returns: + Tensor: Concatenated positional embeddings for the entire batch. + """ + # Generate random non-repeating image IDs (on CPU) + image_ids = torch.zeros(batch_size, num_views, dtype=torch.long) + + # First view is always 0 for all samples + image_ids[:, 0] = 0 + + # Get a generator that is unique to each rank, while also being deterministic based on the global across numbers of forward passes + per_rank_generator = self._generate_per_rank_generator() + + # Generate random non-repeating IDs for the remaining views using the generator + for b in range(batch_size): + # Use the torch.Generator for randomness to ensure randomness between forward passes + random_ids = torch.randperm(max_image_idx, generator=per_rank_generator)[:num_views - 1] + 1 + image_ids[b, 1:] = random_ids + + # Move the image IDs to the correct device + image_ids = image_ids.to(device) + + # Initialize list to store positional embeddings for all views + image_pos_list = [] + + for i in range(num_views): + # Retrieve the number of patches for this view + num_patches = encoded_feats[i].shape[1] + + # Gather the positional embeddings for the entire batch based on the random image IDs + image_pos_for_view = self.image_idx_emb[image_ids[:, i]] # (B, D) + + # Expand the positional embeddings to match the number of patches + image_pos_for_view = image_pos_for_view.unsqueeze(1).repeat(1, num_patches, 1) + + image_pos_list.append(image_pos_for_view) + + # Concatenate positional embeddings for all views along the patch dimension + image_pos = torch.cat(image_pos_list, dim=1) # (B, Npatches_total, D) + + return image_pos + + def _decoder(self, encoded_feats, positions, image_ids): + x = torch.cat(encoded_feats, dim=1) # concate along the patch dimension + pos = torch.cat(positions, dim=1) + + final_output = [x] # before projection + + # project to decoder dim + x = self.decoder_embed(x) + + # Add positional embedding based on image IDs + if self.random_image_idx_embedding: + # Generate random positional embeddings for all views and samples + image_pos = self._get_random_image_pos(encoded_feats=encoded_feats, + batch_size=encoded_feats[0].shape[0], + num_views=len(encoded_feats), + max_image_idx=self.image_idx_emb.shape[0] - 1, + device=x.device) + else: + # Use default image IDs from input + num_images = (torch.max(image_ids) + 1).cpu().item() + image_idx_emb = self.image_idx_emb[:num_images] + image_pos = image_idx_emb[image_ids] + + # Apply positional embedding based on image IDs and positions + x += image_pos # x has size B x Npatches x D, image_pos has size Npatches x D, so this is broadcasting + + for blk in self.multiview_dec_blocks: + x = blk(x, pos) + final_output.append(x) + + x = self.dec_norm(x) + final_output[-1] = x + return final_output + + def forward(self, views): + """ + Args: + views (list[dict]): a list of views, each view is a dict of tensors, the tensors are batched + + Returns: + list[dict]: a list of results for each view + """ + # encode the images --> B,S,D + encoded_feats, positions, shapes = self._encode_images(views) + + # Create image IDs for each patch + num_images = len(views) + B, _, _ = encoded_feats[0].shape + + different_resolution_across_views = not all(encoded_feats[0].shape[1] == encoded_feat.shape[1] for encoded_feat in encoded_feats) + + # Initialize an empty list to collect image IDs for each patch. + # Note that at inference time, different views may have different number of patches. + image_ids = [] + + # Loop through each encoded feature to get the actual number of patches + for i, encoded_feat in enumerate(encoded_feats): + num_patches = encoded_feat.shape[1] # Get the number of patches for this image + # Extend the image_ids list with the current image ID repeated num_patches times + image_ids.extend([i] * num_patches) + + # Repeat the image_ids list B times and reshape it to match the expected shape + image_ids = torch.tensor(image_ids * B).reshape(B, -1).to(encoded_feats[0].device) + + # combine all ref images into object-centric representation + dec_output = self._decoder(encoded_feats, positions, image_ids) + + ################## Forward pass through the head ################## + # TODO: optimize this + + # Initialize the final results list + final_results = [{} for _ in range(num_images)] + + with profiler.record_function("head: gathered outputs"): + # Prepare the gathered outputs for each layer + gathered_outputs_list = [] + if different_resolution_across_views: # If the views have different resolutions, gathered_outputs_list is a list of lists, the outer list is for different views, and the inner list is for different layers + for img_id in range(num_images): + gathered_outputs_per_view = [] + for layer_output in dec_output: + B, P, D = layer_output.shape + mask = (image_ids == img_id) + gathered_output = layer_output[mask].view(B, -1, D) + gathered_outputs_per_view.append(gathered_output) + gathered_outputs_list.append(gathered_outputs_per_view) + else: # If the views have the same resolution, gathered_outputs_list is a list of tensors, each tensor is for a different layer + for layer_output in dec_output: + B, P, D = layer_output.shape + gathered_outputs_per_view = [] + for img_id in range(num_images): + mask = (image_ids == img_id) + gathered_output = layer_output[mask].view(B, -1, D) + gathered_outputs_per_view.append(gathered_output) + gathered_outputs_list.append(torch.cat(gathered_outputs_per_view, dim=0)) # fold the view dimension into batch dimension + + with profiler.record_function("head: forward pass"): + if different_resolution_across_views: + # Forward pass for each view separately + final_results = [{} for _ in range(num_images)] + for img_id in range(num_images): + img_result = self.head(gathered_outputs_list[img_id], shapes[img_id]) + # Re-map the results back to the original batch and image order + for key in img_result.keys(): + if key == 'pts3d': + final_results[img_id]['pts3d_in_other_view'] = img_result[key] + else: + final_results[img_id][key] = img_result[key] + else: + # Concatenate shapes + concatenated_shapes = torch.cat(shapes, dim=0) + + # Forward pass through self.head() + result = self.head(gathered_outputs_list, concatenated_shapes) + + # Initialize the final results list + final_results = [{} for _ in range(num_images)] + + # Re-map the results back to the original batch and image order + for key in result.keys(): + for img_id in range(num_images): + img_result = result[key][img_id * B:(img_id + 1) * B] + if key == 'pts3d': + final_results[img_id]['pts3d_in_other_view'] = img_result + else: + final_results[img_id][key] = img_result + + return final_results diff --git a/stream3r/dust3r/optim_factory.py b/stream3r/dust3r/optim_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..773ea27b90e4d92b0cf7d41e5dd1e35745ed2a9b --- /dev/null +++ b/stream3r/dust3r/optim_factory.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# optimization functions +# -------------------------------------------------------- + + +def adjust_learning_rate_by_lr(optimizer, lr): + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr diff --git a/stream3r/dust3r/patch_embed.py b/stream3r/dust3r/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..17fa5b2b7948b554414ba3a22db7d6122b1233e6 --- /dev/null +++ b/stream3r/dust3r/patch_embed.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# PatchEmbed implementation for DUST3R, +# in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio +# -------------------------------------------------------- +import torch +from stream3r.croco.models.blocks import PatchEmbed + + +def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim): + assert patch_embed_cls in ["PatchEmbedDust3R", "ManyAR_PatchEmbed"] + patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim) + return patch_embed + + +class PatchEmbedDust3R(PatchEmbed): + def forward(self, x, **kw): + B, C, H, W = x.shape + assert ( + H % self.patch_size[0] == 0 + ), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert ( + W % self.patch_size[1] == 0 + ), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, pos + + +class ManyAR_PatchEmbed(PatchEmbed): + """Handle images with non-square aspect ratio. + All images in the same batch have the same aspect ratio. + true_shape = [(height, width) ...] indicates the actual shape of each image. + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + ): + self.embed_dim = embed_dim + super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten) + + def forward(self, img, true_shape): + if not self.training: + x = img + B, C, H, W = x.shape + assert ( + H % self.patch_size[0] == 0 + ), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert ( + W % self.patch_size[1] == 0 + ), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, pos + + B, C, H, W = img.shape + assert W >= H, f"img should be in landscape mode, but got {W=} {H=}" + assert ( + H % self.patch_size[0] == 0 + ), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert ( + W % self.patch_size[1] == 0 + ), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + assert true_shape.shape == ( + B, + 2, + ), f"true_shape has the wrong shape={true_shape.shape}" + + # size expressed in tokens + W //= self.patch_size[0] + H //= self.patch_size[1] + n_tokens = H * W + + height, width = true_shape.T + is_landscape = width >= height + is_portrait = ~is_landscape + + # linear projection, transposed if necessary + if is_landscape.any(): + new_landscape_content = self.proj(img[is_landscape]) + new_landscape_content = new_landscape_content.permute(0, 2, 3, 1).flatten(1, 2) + if is_portrait.any(): + new_protrait_content = self.proj(img[is_portrait].swapaxes(-1, -2)) + new_protrait_content = new_protrait_content.permute(0, 2, 3, 1).flatten(1, 2) + + # allocate space for result and set the content + x = img.new_empty((B, n_tokens, self.embed_dim), dtype=next(self.named_parameters())[1].dtype) # dynamically set dtype based on the current precision + if is_landscape.any(): + x[is_landscape] = new_landscape_content.to(x.dtype) + if is_portrait.any(): + x[is_portrait] = new_protrait_content.to(x.dtype) + + # allocate space for result and set the content + pos = img.new_empty((B, n_tokens, 2), dtype=torch.int64) + if is_landscape.any(): + pos[is_landscape] = self.position_getter(1, H, W, pos.device).expand(is_landscape.sum(), -1, -1) + if is_portrait.any(): + pos[is_portrait] = self.position_getter(1, W, H, pos.device).expand(is_portrait.sum(), -1, -1) + + x = self.norm(x) + return x, pos diff --git a/stream3r/dust3r/post_process.py b/stream3r/dust3r/post_process.py new file mode 100644 index 0000000000000000000000000000000000000000..72271b46c8f9b176735cdb33f511cd424abc39d6 --- /dev/null +++ b/stream3r/dust3r/post_process.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilities for interpreting the DUST3R output +# -------------------------------------------------------- +import numpy as np +import torch + +from stream3r.dust3r.utils.geometry import xy_grid + + +def estimate_focal_knowing_depth( + pts3d, pp, focal_mode="median", min_focal=0.0, max_focal=np.inf +): + """Reprojection method, for when the absolute depth is known: + 1) estimate the camera focal using a robust estimator + 2) reproject points onto true rays, minimizing a certain error + """ + B, H, W, THREE = pts3d.shape + assert THREE == 3 + + # centered pixel grid + pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view( + -1, 1, 2 + ) # B,HW,2 + pts3d = pts3d.flatten(1, 2) # (B, HW, 3) + + if focal_mode == "median": + with torch.no_grad(): + # direct estimation of focal + u, v = pixels.unbind(dim=-1) + x, y, z = pts3d.unbind(dim=-1) + fx_votes = (u * z) / x + fy_votes = (v * z) / y + + # assume square pixels, hence same focal for X and Y + f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1) + focal = torch.nanmedian(f_votes, dim=-1).values + + elif focal_mode == "weiszfeld": + # init focal with l2 closed form + # we try to find focal = argmin Sum | pixel - focal * (x,y)/z| + xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num( + posinf=0, neginf=0 + ) # homogeneous (x,y,1) + + dot_xy_px = (xy_over_z * pixels).sum(dim=-1) + dot_xy_xy = xy_over_z.square().sum(dim=-1) + + focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1) + + # iterative re-weighted least-squares + for iter in range(10): + # re-weighting by inverse of distance + dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1) + # print(dis.nanmean(-1)) + w = dis.clip(min=1e-8).reciprocal() + # update the scaling with the new weights + focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1) + else: + raise ValueError(f"bad {focal_mode=}") + + focal_base = max(H, W) / ( + 2 * np.tan(np.deg2rad(60) / 2) + ) # size / 1.1547005383792515 + focal = focal.clip(min=min_focal * focal_base, max=max_focal * focal_base) + # print(focal) + return focal + +def estimate_focal_knowing_depth_and_confidence_mask( + pts3d, pp, conf_mask, focal_mode="median", min_focal=0.0, max_focal=np.inf +): + """Reprojection method for when the absolute depth is known: + 1) estimate the camera focal using a robust estimator + 2) reproject points onto true rays, minimizing a certain error + This function considers only points where conf_mask is True. + """ + B, H, W, THREE = pts3d.shape + assert THREE == 3 + + # centered pixel grid + pixels = xy_grid(W, H, device=pts3d.device).view(1, H, W, 2) - pp.view( + -1, 1, 1, 2 + ) # B,H,W,2 + + # Apply the confidence mask + conf_mask = conf_mask.view(B, H, W) # Ensure conf_mask is of shape (B, H, W) + valid_indices = conf_mask # Boolean mask + + # Flatten the valid points + pts3d_valid = pts3d[valid_indices] # Shape: (N, 3) + pixels_valid = pixels[valid_indices] # Shape: (N, 2) + + if pts3d_valid.numel() == 0: + # No valid points, return a default focal length + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) + return torch.tensor([focal_base]) + + if focal_mode == "median": + with torch.no_grad(): + # Direct estimation of focal + u, v = pixels_valid.unbind(dim=-1) + x, y, z = pts3d_valid.unbind(dim=-1) + fx_votes = (u * z) / x + fy_votes = (v * z) / y + + # Assume square pixels, hence same focal for X and Y + f_votes = torch.cat((fx_votes.view(-1), fy_votes.view(-1)), dim=-1) + focal = torch.nanmedian(f_votes).unsqueeze(0) # Shape: (1,) + + elif focal_mode == "weiszfeld": + # Initialize focal with L2 closed-form solution + xy_over_z = (pts3d_valid[..., :2] / pts3d_valid[..., 2:3]).nan_to_num( + posinf=0, neginf=0 + ) # Shape: (N, 2) + + dot_xy_px = (xy_over_z * pixels_valid).sum(dim=-1) # Shape: (N,) + dot_xy_xy = xy_over_z.square().sum(dim=-1) # Shape: (N,) + + focal = dot_xy_px.mean() / dot_xy_xy.mean() # Shape: scalar + + # Iterative re-weighted least-squares + for _ in range(100): + # Re-weighting by inverse of distance + dis = (pixels_valid - focal * xy_over_z).norm(dim=-1) # Shape: (N,) + w = dis.clip(min=1e-8).reciprocal() # Shape: (N,) + # Update the scaling with the new weights + focal = (w * dot_xy_px).sum() / (w * dot_xy_xy).sum() + focal = focal.unsqueeze(0) # Shape: (1,) + else: + raise ValueError(f"bad focal_mode={focal_mode}") + + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) + focal = focal.clip(min=min_focal * focal_base, max=max_focal * focal_base) + return focal diff --git a/stream3r/dust3r/utils/__init__.py b/stream3r/dust3r/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea9d87a320e848d1c4851a1e2408313c9255365 --- /dev/null +++ b/stream3r/dust3r/utils/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/stream3r/dust3r/utils/__pycache__/__init__.cpython-311.pyc b/stream3r/dust3r/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b0f71c7cc120fa545873be68975ac4c43dfd361 Binary files /dev/null and b/stream3r/dust3r/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/stream3r/dust3r/utils/__pycache__/misc.cpython-311.pyc b/stream3r/dust3r/utils/__pycache__/misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54e7f3974e39fd5ef625424aeeeada5b18ed4f64 Binary files /dev/null and b/stream3r/dust3r/utils/__pycache__/misc.cpython-311.pyc differ diff --git a/stream3r/dust3r/utils/camera.py b/stream3r/dust3r/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..8474d869eacd064b1d561e9916a5aacdf61fe666 --- /dev/null +++ b/stream3r/dust3r/utils/camera.py @@ -0,0 +1,204 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +inf = float("inf") + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + return standardize_quaternion(out) + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + quaternions = F.normalize(quaternions, p=2, dim=-1) + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def camera_to_pose_encoding( + camera, + pose_encoding_type="absT_quaR", +): + """ + Inverse to pose_encoding_to_camera + camera: opencv, cam2world + """ + if pose_encoding_type == "absT_quaR": + + quaternion_R = matrix_to_quaternion(camera[:, :3, :3]) + + pose_encoding = torch.cat([camera[:, :3, 3], quaternion_R], dim=-1) + else: + raise ValueError(f"Unknown pose encoding {pose_encoding_type}") + + return pose_encoding + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def pose_encoding_to_camera( + pose_encoding, + pose_encoding_type="absT_quaR", +): + """ + Args: + pose_encoding: A tensor of shape `BxC`, containing a batch of + `B` `C`-dimensional pose encodings. + pose_encoding_type: The type of pose encoding, + """ + + if pose_encoding_type == "absT_quaR": + + abs_T = pose_encoding[:, :3] + quaternion_R = pose_encoding[:, 3:7] + R = quaternion_to_matrix(quaternion_R) + else: + raise ValueError(f"Unknown pose encoding {pose_encoding_type}") + + c2w_mats = torch.eye(4, 4).to(R.dtype).to(R.device) + c2w_mats = c2w_mats[None].repeat(len(R), 1, 1) + c2w_mats[:, :3, :3] = R + c2w_mats[:, :3, 3] = abs_T + + return c2w_mats + + +def quaternion_conjugate(q): + """Compute the conjugate of quaternion q (w, x, y, z).""" + + q_conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1) + return q_conj + + +def quaternion_multiply(q1, q2): + """Multiply two quaternions q1 and q2.""" + w1, x1, y1, z1 = q1.unbind(dim=-1) + w2, x2, y2, z2 = q2.unbind(dim=-1) + + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 + z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + + return torch.stack((w, x, y, z), dim=-1) + + +def rotate_vector(q, v): + """Rotate vector v by quaternion q.""" + q_vec = q[..., 1:] + q_w = q[..., :1] + + t = 2.0 * torch.cross(q_vec, v, dim=-1) + v_rot = v + q_w * t + torch.cross(q_vec, t, dim=-1) + return v_rot + + +def relative_pose_absT_quatR(t1, q1, t2, q2): + """Compute the relative translation and quaternion between two poses.""" + + q1_inv = quaternion_conjugate(q1) + + q_rel = quaternion_multiply(q1_inv, q2) + + delta_t = t2 - t1 + t_rel = rotate_vector(q1_inv, delta_t) + return t_rel, q_rel diff --git a/stream3r/dust3r/utils/device.py b/stream3r/dust3r/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..bae8a503be3d75b60258135a5134dfdd7154ad59 --- /dev/null +++ b/stream3r/dust3r/utils/device.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for DUSt3R +# -------------------------------------------------------- +import numpy as np +import torch + + +def todevice(batch, device, callback=None, non_blocking=False): + """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). + + batch: list, tuple, dict of tensors or other things + device: pytorch device or 'numpy' + callback: function that would be called on every sub-elements. + """ + if callback: + batch = callback(batch) + + if isinstance(batch, dict): + return {k: todevice(v, device) for k, v in batch.items()} + + if isinstance(batch, (tuple, list)): + return type(batch)(todevice(x, device) for x in batch) + + x = batch + if device == "numpy": + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + elif x is not None: + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if torch.is_tensor(x): + x = x.to(device, non_blocking=non_blocking) + return x + + +to_device = todevice # alias + + +def to_numpy(x): + return todevice(x, "numpy") + + +def to_cpu(x): + return todevice(x, "cpu") + + +def to_cuda(x): + return todevice(x, "cuda") + + +def collate_with_cat(whatever, lists=False): + if isinstance(whatever, dict): + return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()} + + elif isinstance(whatever, (tuple, list)): + if len(whatever) == 0: + return whatever + elem = whatever[0] + T = type(whatever) + + if elem is None: + return None + if isinstance(elem, (bool, float, int, str)): + return whatever + if isinstance(elem, tuple): + return T(collate_with_cat(x, lists=lists) for x in zip(*whatever)) + if isinstance(elem, dict): + return { + k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem + } + + if isinstance(elem, torch.Tensor): + return listify(whatever) if lists else torch.cat(whatever) + if isinstance(elem, np.ndarray): + return ( + listify(whatever) + if lists + else torch.cat([torch.from_numpy(x) for x in whatever]) + ) + + # otherwise, we just chain lists + return sum(whatever, T()) + + +def listify(elems): + return [x for e in elems for x in e] diff --git a/stream3r/dust3r/utils/geometry.py b/stream3r/dust3r/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..d95a6a0f9f2de4d7d5f83745a5bfeb1a933fde3f --- /dev/null +++ b/stream3r/dust3r/utils/geometry.py @@ -0,0 +1,411 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# geometry utilitary functions +# -------------------------------------------------------- +import numpy as np +import torch +from scipy.spatial import cKDTree as KDTree + +from stream3r.dust3r.utils.device import to_numpy +from stream3r.dust3r.utils.misc import invalid_to_nans, invalid_to_zeros + + +def xy_grid( + W, + H, + device=None, + origin=(0, 0), + unsqueeze=None, + cat_dim=-1, + homogeneous=False, + **arange_kw, +): + """Output a (H,W,2) array of int32 + with output[j,i,0] = i + origin[0] + output[j,i,1] = j + origin[1] + """ + if device is None: + # numpy + arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones + else: + # torch + arange = lambda *a, **kw: torch.arange(*a, device=device, **kw) + meshgrid, stack = torch.meshgrid, torch.stack + ones = lambda *a: torch.ones(*a, device=device) + + tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)] + grid = meshgrid(tw, th, indexing="xy") + if homogeneous: + grid = grid + (ones((H, W)),) + if unsqueeze is not None: + grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze)) + if cat_dim is not None: + grid = stack(grid, cat_dim) + return grid + + +def geotrf(Trf, pts, ncol=None, norm=False): + """Apply a geometric transformation to a list of 3-D points. + + H: 3x3 or 4x4 projection matrix (typically a Homography) + p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) + + ncol: int. number of columns of the result (2 or 3) + norm: float. if != 0, the resut is projected on the z=norm plane. + + Returns an array of projected 2d points. + """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + # adapt shape if necessary + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + # optimized code + if ( + isinstance(Trf, torch.Tensor) + and isinstance(pts, torch.Tensor) + and Trf.ndim == 3 + and pts.ndim == 4 + ): + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d + 1: + pts = ( + torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + + Trf[:, None, None, :d, d] + ) + else: + raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}") + else: + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], "batch size does not match" + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res + + +def inv(mat): + """Invert a torch or numpy matrix""" + if isinstance(mat, torch.Tensor): + # return torch.linalg.inv(mat) + # for mixed precision training + if mat.dtype == torch.bfloat16: + mat = mat.to(torch.float32) + mat = torch.linalg.inv(mat) + return mat + else: + mat = torch.linalg.inv(mat) + return mat + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f"bad matrix type = {type(mat)}") + + +def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_): + """ + Args: + - depthmap (BxHxW array): + - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W] + Returns: + pointmap of absolute coordinates (BxHxWx3 array) + """ + + if len(depth.shape) == 4: + B, H, W, n = depth.shape + else: + B, H, W = depth.shape + n = None + + if len(pseudo_focal.shape) == 3: # [B,H,W] + pseudo_focalx = pseudo_focaly = pseudo_focal + elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W] + pseudo_focalx = pseudo_focal[:, 0] + if pseudo_focal.shape[1] == 2: + pseudo_focaly = pseudo_focal[:, 1] + else: + pseudo_focaly = pseudo_focalx + else: + raise NotImplementedError("Error, unknown input focal shape format.") + + assert pseudo_focalx.shape == depth.shape[:3] + assert pseudo_focaly.shape == depth.shape[:3] + grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None] + + # set principal point + if pp is None: + grid_x = grid_x - (W - 1) / 2 + grid_y = grid_y - (H - 1) / 2 + else: + grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None] + grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None] + + if n is None: + pts3d = torch.empty((B, H, W, 3), device=depth.device) + pts3d[..., 0] = depth * grid_x / pseudo_focalx + pts3d[..., 1] = depth * grid_y / pseudo_focaly + pts3d[..., 2] = depth + else: + pts3d = torch.empty((B, H, W, 3, n), device=depth.device) + pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None] + pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None] + pts3d[..., 2, :] = depth + return pts3d + + +def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + u, v = np.meshgrid(np.arange(W), np.arange(H)) + z_cam = depthmap + x_cam = (u - cu) * z_cam / fu + y_cam = (v - cv) * z_cam / fv + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + # Mask for valid coordinates + valid_mask = depthmap > 0.0 + return X_cam, valid_mask + + +def depthmap_to_absolute_camera_coordinates( + depthmap, camera_intrinsics, camera_pose, **kw +): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics) + + # R_cam2world = np.float32(camera_params["R_cam2world"]) + # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze() + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + # Express in absolute coordinates (invalid depth values) + X_world = ( + np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] + ) + return X_world, valid_mask + + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + return K + + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + return K + + +def normalize_pointcloud(pts1, pts2, norm_mode="avg_dis", valid1=None, valid2=None): + """renorm pointmaps pts1, pts2 with norm_mode""" + assert pts1.ndim >= 3 and pts1.shape[-1] == 3 + assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3) + norm_mode, dis_mode = norm_mode.split("_") + + if norm_mode == "avg": + # gather all points together (joint normalization) + nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3) + nan_pts2, nnz2 = ( + invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0) + ) + all_pts = ( + torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + ) + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + if dis_mode == "dis": + pass # do nothing + elif dis_mode == "log1p": + all_dis = torch.log1p(all_dis) + elif dis_mode == "warp-log1p": + # actually warp input points before normalizing them + log_dis = torch.log1p(all_dis) + warp_factor = log_dis / all_dis.clip(min=1e-8) + H1, W1 = pts1.shape[1:-1] + pts1 = pts1 * warp_factor[:, : W1 * H1].view(-1, H1, W1, 1) + if pts2 is not None: + H2, W2 = pts2.shape[1:-1] + pts2 = pts2 * warp_factor[:, W1 * H1 :].view(-1, H2, W2, 1) + all_dis = log_dis # this is their true distance afterwards + else: + raise ValueError(f"bad {dis_mode=}") + + norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8) + else: + # gather all points together (joint normalization) + nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3) + nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None + all_pts = ( + torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + ) + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + + if norm_mode == "avg": + norm_factor = all_dis.nanmean(dim=1) + elif norm_mode == "median": + norm_factor = all_dis.nanmedian(dim=1).values.detach() + elif norm_mode == "sqrt": + norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2 + else: + raise ValueError(f"bad {norm_mode=}") + + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts1.ndim: + norm_factor.unsqueeze_(-1) + + res = pts1 / norm_factor + if pts2 is not None: + res = (res, pts2 / norm_factor) + return res + + +@torch.no_grad() +def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5): + # set invalid points to NaN + _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1) + _z2 = ( + invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) + if z2 is not None + else None + ) + _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1 + + # compute median depth overall (ignoring nans) + if quantile == 0.5: + shift_z = torch.nanmedian(_z, dim=-1).values + else: + shift_z = torch.nanquantile(_z, quantile, dim=-1) + return shift_z # (B,) + + +@torch.no_grad() +def get_joint_pointcloud_center_scale( + pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True +): + # set invalid points to NaN + _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3) + _pts2 = ( + invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) + if pts2 is not None + else None + ) + _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1 + + # compute median center + _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3) + if z_only: + _center[..., :2] = 0 # do not center X and Y + + # compute median norm + _norm = ((_pts - _center) if center else _pts).norm(dim=-1) + scale = torch.nanmedian(_norm, dim=1).values + return _center[:, None, :, :], scale[:, None, None, None] + + +def find_reciprocal_matches(P1, P2): + """ + returns 3 values: + 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match + 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1 + 3 - reciprocal_in_P2.sum(): the number of matches + """ + tree1 = KDTree(P1) + tree2 = KDTree(P2) + + _, nn1_in_P2 = tree2.query(P1, workers=8) + _, nn2_in_P1 = tree1.query(P2, workers=8) + + reciprocal_in_P1 = nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)) + reciprocal_in_P2 = nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)) + assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum() + return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum() + + +def get_med_dist_between_poses(poses): + from scipy.spatial.distance import pdist + + return np.median(pdist([to_numpy(p[:3, 3]) for p in poses])) diff --git a/stream3r/dust3r/utils/image.py b/stream3r/dust3r/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..23f3b8ae862fc29652b2e1726fc4b5db6fa7901f --- /dev/null +++ b/stream3r/dust3r/utils/image.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions about images (loading/converting...) +# -------------------------------------------------------- +import os + +import numpy as np +import PIL.Image +import torch +import torchvision.transforms as tvf +from PIL.ImageOps import exif_transpose + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +try: + from pillow_heif import register_heif_opener + + register_heif_opener() + heif_support_enabled = True +except ImportError: + heif_support_enabled = False + +ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + +def imread_cv2(path, options=cv2.IMREAD_COLOR): + """Open an image or a depthmap with opencv-python.""" + if path.endswith((".exr", "EXR")): + options = cv2.IMREAD_ANYDEPTH + img = cv2.imread(path, options) + if img is None: + raise IOError(f"Could not load image={path} with {options=}") + if img.ndim == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def rgb(ftensor, true_shape=None): + if isinstance(ftensor, list): + return [rgb(x, true_shape=true_shape) for x in ftensor] + if isinstance(ftensor, torch.Tensor): + ftensor = ftensor.detach().cpu().numpy() # H,W,3 + if ftensor.ndim == 3 and ftensor.shape[0] == 3: + ftensor = ftensor.transpose(1, 2, 0) + elif ftensor.ndim == 4 and ftensor.shape[1] == 3: + ftensor = ftensor.transpose(0, 2, 3, 1) + if true_shape is not None: + H, W = true_shape + ftensor = ftensor[:H, :W] + if ftensor.dtype == np.uint8: + img = np.float32(ftensor) / 255 + else: + img = (ftensor * 0.5) + 0.5 + return img.clip(min=0, max=1) + + +def _resize_pil_image(img, long_edge_size): + S = max(img.size) + if S > long_edge_size: + interp = PIL.Image.LANCZOS + elif S <= long_edge_size: + interp = PIL.Image.BICUBIC + new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size) + return img.resize(new_size, interp) + + +def load_images(folder_or_list, size, square_ok=False, verbose=True, rotate_clockwise_90=False, crop_to_landscape=False, patch_size=16): + """open and convert all images in a list or folder to proper input format for DUSt3R""" + if isinstance(folder_or_list, str): + if verbose: + print(f">> Loading images from {folder_or_list}") + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + + elif isinstance(folder_or_list, list): + if verbose: + print(f">> Loading a list of {len(folder_or_list)} images") + root, folder_content = "", folder_or_list + + else: + raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})") + + supported_images_extensions = [".jpg", ".jpeg", ".png"] + if heif_support_enabled: + supported_images_extensions += [".heic", ".heif"] + supported_images_extensions = tuple(supported_images_extensions) + + imgs = [] + for path in folder_content: + if not path.lower().endswith(supported_images_extensions): + continue + img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB") + + if rotate_clockwise_90: + img = img.rotate(-90, expand=True) + + # width, height = img.size + # if height > width: + # img = img.transpose(PIL.Image.TRANSPOSE) + + if crop_to_landscape: + # Crop to a landscape aspect ratio (e.g., 16:9) + desired_aspect_ratio = 4 / 3 + width, height = img.size + current_aspect_ratio = width / height + + if current_aspect_ratio > desired_aspect_ratio: + # Wider than landscape: crop width + new_width = int(height * desired_aspect_ratio) + left = (width - new_width) // 2 + right = left + new_width + top = 0 + bottom = height + else: + # Taller than landscape: crop height + new_height = int(width / desired_aspect_ratio) + top = (height - new_height) // 2 + bottom = top + new_height + left = 0 + right = width + + img = img.crop((left, top, right, bottom)) + + W1, H1 = img.size + if size == 224: + # resize short side to 224 (then crop) + img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1))) + else: + # resize long side to 512 + img = _resize_pil_image(img, size) + W, H = img.size + cx, cy = W // 2, H // 2 + if size == 224: + half = min(cx, cy) + img = img.crop((cx - half, cy - half, cx + half, cy + half)) + else: + # 16 is the patch size and 8 is the 16//2 + halfw, halfh = ((2 * cx) // patch_size) * patch_size//2, ((2 * cy) // patch_size) * patch_size//2 + if not (square_ok) and W == H: + halfh = 3 * halfw / 4 + img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh)) + + W2, H2 = img.size + if verbose: + print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}") + # true_shape = [img.size] if height > width else [img.size[::-1]] # if is protrait, the true shape should inverse + true_shape = [img.size[::-1]] # true shape requires H, W + imgs.append( + dict( + img=ImgNorm(img)[None], + true_shape=np.int32(true_shape), + idx=len(imgs), + instance=str(len(imgs)), + ) + ) + + assert imgs, "no images foud at " + root + if verbose: + print(f" (Found {len(imgs)} images)") + return imgs + + +def load_images_for_eval( + folder_or_list, size, square_ok=False, verbose=True, crop=True, patch_size=16 +): + """open and convert all images in a list or folder to proper input format for DUSt3R""" + if isinstance(folder_or_list, str): + if verbose: + print(f">> Loading images from {folder_or_list}") + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + + elif isinstance(folder_or_list, list): + if verbose: + print(f">> Loading a list of {len(folder_or_list)} images") + root, folder_content = "", folder_or_list + + else: + raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})") + + supported_images_extensions = [".jpg", ".jpeg", ".png"] + if heif_support_enabled: + supported_images_extensions += [".heic", ".heif"] + supported_images_extensions = tuple(supported_images_extensions) + + imgs = [] + for i, path in enumerate(folder_content): + if not path.lower().endswith(supported_images_extensions): + continue + img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB") + W1, H1 = img.size + if size == 224: + # resize short side to 224 (then crop) + img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1))) + else: + # resize long side to 512 + img = _resize_pil_image(img, size) + W, H = img.size + cx, cy = W // 2, H // 2 + if size == 224: + half = min(cx, cy) + if crop: + img = img.crop((cx - half, cy - half, cx + half, cy + half)) + else: # resize + img = img.resize((2 * half, 2 * half), PIL.Image.LANCZOS) + else: + halfw, halfh = ((2 * cx) // patch_size) * (patch_size//2), ((2 * cy) // patch_size) * (patch_size//2) + if not (square_ok) and W == H: + halfh = 3 * halfw / 4 + if crop: + img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh)) + else: # resize + img = img.resize((2 * halfw, 2 * halfh), PIL.Image.LANCZOS) + W2, H2 = img.size + if verbose: + print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}") + + imgs.append( + dict( + img=ImgNorm(img)[None], + true_shape=np.int32([img.size[::-1]]), + idx=len(imgs), + instance=str(len(imgs)), + ) + ) + + assert imgs, "no images foud at " + root + if verbose: + print(f" (Found {len(imgs)} images)") + return imgs diff --git a/stream3r/dust3r/utils/misc.py b/stream3r/dust3r/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..f110ef7ac0c45137865ccfd16e3a408c9c1203d4 --- /dev/null +++ b/stream3r/dust3r/utils/misc.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for DUSt3R +# -------------------------------------------------------- +import torch + + +def fill_default_args(kwargs, func): + import inspect # a bit hacky but it works reliably + + signature = inspect.signature(func) + + for k, v in signature.parameters.items(): + if v.default is inspect.Parameter.empty: + continue + kwargs.setdefault(k, v.default) + + return kwargs + + +def freeze_all_params(modules): + for module in modules: + try: + for n, param in module.named_parameters(): + param.requires_grad = False + except AttributeError: + # module is directly a parameter + module.requires_grad = False + + +def is_symmetrized(gt1, gt2): + x = gt1["instance"] + y = gt2["instance"] + if len(x) == len(y) and len(x) == 1: + return False # special case of batchsize 1 + ok = True + for i in range(0, len(x), 2): + ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i]) + return ok + + +def flip(tensor): + """flip so that tensor[0::2] <=> tensor[1::2]""" + return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1) + + +def interleave(tensor1, tensor2): + res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) + res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) + return res1, res2 + + +def transpose_to_landscape(head, activate=True): + """Predict in the correct aspect-ratio, + then transpose the result in landscape + and stack everything back together. + """ + + def wrapper_no(decout, true_shape): + B = len(true_shape) + assert true_shape[0:1].allclose(true_shape), "true_shape must be all identical" + H, W = true_shape[0].cpu().tolist() + res = head(decout, (H, W)) + return res + + def wrapper_yes(decout, true_shape): + if not head.training: + return wrapper_no(decout, true_shape) + + B = len(true_shape) + # by definition, the batch is in landscape mode so W >= H + H, W = int(true_shape.min()), int(true_shape.max()) + + height, width = true_shape.T + is_landscape = width >= height + is_portrait = ~is_landscape + + # true_shape = true_shape.cpu() + if is_landscape.all(): + return head(decout, (H, W)) + if is_portrait.all(): + return transposed(head(decout, (W, H))) + + # batch is a mix of both portraint & landscape + def selout(ar): + return [d[ar] for d in decout] + + l_result = head(selout(is_landscape), (H, W)) + p_result = transposed(head(selout(is_portrait), (W, H))) + + # allocate full result + result = {} + for k in l_result | p_result: + x = l_result[k].new(B, *l_result[k].shape[1:]) + x[is_landscape] = l_result[k] + x[is_portrait] = p_result[k] + result[k] = x + + return result + + return wrapper_yes if activate else wrapper_no + + +def transposed(dic): + return {k: v.swapaxes(1, 2) for k, v in dic.items()} + + +def invalid_to_nans(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = float("nan") + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr + + +def invalid_to_zeros(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = 0 + nnz = valid_mask.view(len(valid_mask), -1).sum(1) + else: + nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr, nnz diff --git a/stream3r/dust3r/utils/parallel.py b/stream3r/dust3r/utils/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..6c9c4e3ecb9968b390741154d9fdb191b12800c4 --- /dev/null +++ b/stream3r/dust3r/utils/parallel.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# utilitary functions for multiprocessing +# -------------------------------------------------------- +from tqdm import tqdm +from multiprocessing.dummy import Pool as ThreadPool +from multiprocessing import cpu_count + + +def parallel_threads(function, args, workers=0, star_args=False, kw_args=False, front_num=1, Pool=ThreadPool, **tqdm_kw): + """ tqdm but with parallel execution. + + Will essentially return + res = [ function(arg) # default + function(*arg) # if star_args is True + function(**arg) # if kw_args is True + for arg in args] + + Note: + the first elements of args will not be parallelized. + This can be useful for debugging. + """ + while workers <= 0: + workers += cpu_count() + if workers == 1: + front_num = float('inf') + + # convert into an iterable + try: + n_args_parallel = len(args) - front_num + except TypeError: + n_args_parallel = None + args = iter(args) + + # sequential execution first + front = [] + while len(front) < front_num: + try: + a = next(args) + except StopIteration: + return front # end of the iterable + front.append(function(*a) if star_args else function(**a) if kw_args else function(a)) + + # then parallel execution + out = [] + with Pool(workers) as pool: + # Pass the elements of args into function + if star_args: + futures = pool.imap(starcall, [(function, a) for a in args]) + elif kw_args: + futures = pool.imap(starstarcall, [(function, a) for a in args]) + else: + futures = pool.imap(function, args) + # Print out the progress as tasks complete + for f in tqdm(futures, total=n_args_parallel, **tqdm_kw): + out.append(f) + return front + out + + +def parallel_processes(*args, **kwargs): + """ Same as parallel_threads, with processes + """ + import multiprocessing as mp + kwargs['Pool'] = mp.Pool + return parallel_threads(*args, **kwargs) + + +def starcall(args): + """ convenient wrapper for Process.Pool """ + function, args = args + return function(*args) + + +def starstarcall(args): + """ convenient wrapper for Process.Pool """ + function, args = args + return function(**args) diff --git a/stream3r/dust3r/utils/path_to_croco.py b/stream3r/dust3r/utils/path_to_croco.py new file mode 100644 index 0000000000000000000000000000000000000000..977ea499a88b198ae89b850b2e13dfe3eadf3930 --- /dev/null +++ b/stream3r/dust3r/utils/path_to_croco.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# CroCo submodule import +# -------------------------------------------------------- + +import os.path as path +import sys + +HERE_PATH = path.normpath(path.dirname(__file__)) +CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, "../../croco")) +CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, "models") +# check the presence of models directory in repo to be sure its cloned +if path.isdir(CROCO_MODELS_PATH): + # workaround for sibling import + sys.path.insert(0, CROCO_REPO_PATH) +else: + raise ImportError( + f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n " + "Did you forget to run 'git submodule update --init --recursive' ?" + ) diff --git a/stream3r/dust3r/viz.py b/stream3r/dust3r/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..2a362911018be5021b28687ea231b5c91f5bff5b --- /dev/null +++ b/stream3r/dust3r/viz.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Visualization utilities using trimesh +# -------------------------------------------------------- +import numpy as np +import PIL.Image +import torch +from scipy.spatial.transform import Rotation + +from stream3r.dust3r.utils.device import to_numpy +from stream3r.dust3r.utils.geometry import geotrf, get_med_dist_between_poses +from stream3r.dust3r.utils.image import rgb + +try: + import trimesh +except ImportError: + print("/!\\ module trimesh is not installed, cannot visualize results /!\\") + + +def cat_3d(vecs): + if isinstance(vecs, (np.ndarray, torch.Tensor)): + vecs = [vecs] + return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)]) + + +def show_raw_pointcloud(pts3d, colors, point_size=2): + scene = trimesh.Scene() + + pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors)) + scene.add_geometry(pct) + + scene.show(line_settings={"point_size": point_size}) + + +def pts3d_to_trimesh(img, pts3d, valid=None): + H, W, THREE = img.shape + assert THREE == 3 + assert img.shape == pts3d.shape + + vertices = pts3d.reshape(-1, 3) + + # make squares: each pixel == 2 triangles + idx = np.arange(len(vertices)).reshape(H, W) + idx1 = idx[:-1, :-1].ravel() # top-left corner + idx2 = idx[:-1, +1:].ravel() # right-left corner + idx3 = idx[+1:, :-1].ravel() # bottom-left corner + idx4 = idx[+1:, +1:].ravel() # bottom-right corner + faces = np.concatenate( + ( + np.c_[idx1, idx2, idx3], + np.c_[ + idx3, idx2, idx1 + ], # same triangle, but backward (cheap solution to cancel face culling) + np.c_[idx2, idx3, idx4], + np.c_[ + idx4, idx3, idx2 + ], # same triangle, but backward (cheap solution to cancel face culling) + ), + axis=0, + ) + + # prepare triangle colors + face_colors = np.concatenate( + ( + img[:-1, :-1].reshape(-1, 3), + img[:-1, :-1].reshape(-1, 3), + img[+1:, +1:].reshape(-1, 3), + img[+1:, +1:].reshape(-1, 3), + ), + axis=0, + ) + + # remove invalid faces + if valid is not None: + assert valid.shape == (H, W) + valid_idxs = valid.ravel() + valid_faces = valid_idxs[faces].all(axis=-1) + faces = faces[valid_faces] + face_colors = face_colors[valid_faces] + + assert len(faces) == len(face_colors) + return dict(vertices=vertices, face_colors=face_colors, faces=faces) + + +def cat_meshes(meshes): + vertices, faces, colors = zip( + *[(m["vertices"], m["faces"], m["face_colors"]) for m in meshes] + ) + n_vertices = np.cumsum([0] + [len(v) for v in vertices]) + for i in range(len(faces)): + faces[i][:] += n_vertices[i] + + vertices = np.concatenate(vertices) + colors = np.concatenate(colors) + faces = np.concatenate(faces) + return dict(vertices=vertices, face_colors=colors, faces=faces) + + +def show_duster_pairs(view1, view2, pred1, pred2): + import matplotlib.pyplot as pl + + pl.ion() + + for e in range(len(view1["instance"])): + i = view1["idx"][e] + j = view2["idx"][e] + img1 = rgb(view1["img"][e]) + img2 = rgb(view2["img"][e]) + conf1 = pred1["conf"][e].squeeze() + conf2 = pred2["conf"][e].squeeze() + score = conf1.mean() * conf2.mean() + print(f">> Showing pair #{e} {i}-{j} {score=:g}") + pl.clf() + pl.subplot(221).imshow(img1) + pl.subplot(223).imshow(img2) + pl.subplot(222).imshow(conf1, vmin=1, vmax=30) + pl.subplot(224).imshow(conf2, vmin=1, vmax=30) + pts1 = pred1["pts3d"][e] + pts2 = pred2["pts3d_in_other_view"][e] + pl.subplots_adjust(0, 0, 1, 1, 0, 0) + if input("show pointcloud? (y/n) ") == "y": + show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5) + + +def auto_cam_size(im_poses): + return 0.1 * get_med_dist_between_poses(im_poses) + + +class SceneViz: + def __init__(self): + self.scene = trimesh.Scene() + + def add_pointcloud(self, pts3d, color, mask=None): + pts3d = to_numpy(pts3d) + mask = to_numpy(mask) + if mask is None: + mask = [slice(None)] * len(pts3d) + pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]) + pct = trimesh.PointCloud(pts.reshape(-1, 3)) + + if isinstance(color, (list, np.ndarray, torch.Tensor)): + color = to_numpy(color) + col = np.concatenate([p[m] for p, m in zip(color, mask)]) + assert col.shape == pts.shape + pct.visual.vertex_colors = uint8(col.reshape(-1, 3)) + else: + assert len(color) == 3 + pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape) + + self.scene.add_geometry(pct) + return self + + def add_camera( + self, + pose_c2w, + focal=None, + color=(0, 0, 0), + image=None, + imsize=None, + cam_size=0.03, + ): + pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image)) + add_scene_cam(self.scene, pose_c2w, color, image, focal, screen_width=cam_size) + return self + + def add_cameras( + self, poses, focals=None, images=None, imsizes=None, colors=None, **kw + ): + def get(arr, idx): + return None if arr is None else arr[idx] + + for i, pose_c2w in enumerate(poses): + self.add_camera( + pose_c2w, + get(focals, i), + image=get(images, i), + color=get(colors, i), + imsize=get(imsizes, i), + **kw, + ) + return self + + def show(self, point_size=2, viewer=None): + return self.scene.show(viewer=viewer, line_settings={"point_size": point_size}) + + +def show_raw_pointcloud_with_cams( + imgs, pts3d, mask, focals, cams2world, point_size=2, cam_size=0.05, cam_color=None +): + """Visualization of a pointcloud with cameras + imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...] + pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...] + focals = (N,) or N-size list of [focal, ...] + cams2world = (N,4,4) or N-size list of [(4,4), ...] + """ + assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals) + pts3d = to_numpy(pts3d) + imgs = to_numpy(imgs) + focals = to_numpy(focals) + cams2world = to_numpy(cams2world) + + scene = trimesh.Scene() + + # full pointcloud + pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]) + col = np.concatenate([p[m] for p, m in zip(imgs, mask)]) + pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3)) + scene.add_geometry(pct) + + # add each camera + for i, pose_c2w in enumerate(cams2world): + if isinstance(cam_color, list): + camera_edge_color = cam_color[i] + else: + camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)] + add_scene_cam( + scene, + pose_c2w, + camera_edge_color, + imgs[i] if i < len(imgs) else None, + focals[i], + screen_width=cam_size, + ) + + scene.show(line_settings={"point_size": point_size}) + + +def add_scene_cam( + scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03 +): + if image is not None: + H, W, THREE = image.shape + assert THREE == 3 + if image.dtype != np.uint8: + image = np.uint8(255 * image) + elif imsize is not None: + W, H = imsize + elif focal is not None: + H = W = focal / 1.1 + else: + H = W = 1 + + if focal is None: + focal = min(H, W) * 1.1 # default value + elif isinstance(focal, np.ndarray): + focal = focal[0] + + # create fake camera + height = focal * screen_width / H + width = screen_width * 0.5**0.5 + rot45 = np.eye(4) + rot45[:3, :3] = Rotation.from_euler("z", np.deg2rad(45)).as_matrix() + rot45[2, 3] = -height # set the tip of the cone = optical center + aspect_ratio = np.eye(4) + aspect_ratio[0, 0] = W / H + transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45 + cam = trimesh.creation.cone(width, height, sections=4) # , transform=transform) + + # this is the image + if image is not None: + vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]]) + faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]]) + img = trimesh.Trimesh(vertices=vertices, faces=faces) + uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]]) + img.visual = trimesh.visual.TextureVisuals( + uv_coords, image=PIL.Image.fromarray(image) + ) + scene.add_geometry(img) + + # this is the camera mesh + rot2 = np.eye(4) + rot2[:3, :3] = Rotation.from_euler("z", np.deg2rad(2)).as_matrix() + vertices = np.r_[cam.vertices, 0.95 * cam.vertices, geotrf(rot2, cam.vertices)] + vertices = geotrf(transform, vertices) + faces = [] + for face in cam.faces: + if 0 in face: + continue + a, b, c = face + a2, b2, c2 = face + len(cam.vertices) + a3, b3, c3 = face + 2 * len(cam.vertices) + + # add 3 pseudo-edges + faces.append((a, b, b2)) + faces.append((a, a2, c)) + faces.append((c2, b, c)) + + faces.append((a, b, b3)) + faces.append((a, a3, c)) + faces.append((c3, b, c)) + + # no culling + faces += [(c, b, a) for a, b, c in faces] + + cam = trimesh.Trimesh(vertices=vertices, faces=faces) + cam.visual.face_colors[:, :3] = edge_color + scene.add_geometry(cam) + + +def cat(a, b): + return np.concatenate((a.reshape(-1, 3), b.reshape(-1, 3))) + + +OPENGL = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + + +CAM_COLORS = [ + (255, 0, 0), + (0, 0, 255), + (0, 255, 0), + (255, 0, 255), + (255, 204, 0), + (0, 204, 204), + (128, 255, 255), + (255, 128, 255), + (255, 255, 128), + (0, 0, 0), + (128, 128, 128), +] + + +def uint8(colors): + if not isinstance(colors, np.ndarray): + colors = np.array(colors) + if np.issubdtype(colors.dtype, np.floating): + colors *= 255 + assert 0 <= colors.min() and colors.max() < 256 + return np.uint8(colors) + + +def segment_sky(image): + import cv2 + from scipy import ndimage + + # Convert to HSV + image = to_numpy(image) + if np.issubdtype(image.dtype, np.floating): + image = np.uint8(255 * image.clip(min=0, max=1)) + hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + + # Define range for blue color and create mask + lower_blue = np.array([0, 0, 100]) + upper_blue = np.array([30, 255, 255]) + mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool) + + # add luminous gray + mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150) + mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180) + mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220) + + # Morphological operations + kernel = np.ones((5, 5), np.uint8) + mask2 = ndimage.binary_opening(mask, structure=kernel) + + # keep only largest CC + _, labels, stats, _ = cv2.connectedComponentsWithStats( + mask2.view(np.uint8), connectivity=8 + ) + cc_sizes = stats[1:, cv2.CC_STAT_AREA] + order = cc_sizes.argsort()[::-1] # bigger first + i = 0 + selection = [] + while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2: + selection.append(1 + order[i]) + i += 1 + mask3 = np.in1d(labels, selection).reshape(labels.shape) + + # Apply mask + return torch.from_numpy(mask3) diff --git a/stream3r/dust3r/viz_plotly.py b/stream3r/dust3r/viz_plotly.py new file mode 100644 index 0000000000000000000000000000000000000000..b0eee7f73981d5b86383c299094994f21fecfe43 --- /dev/null +++ b/stream3r/dust3r/viz_plotly.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import plotly.graph_objects as go +from PIL import Image +from sklearn.cluster import KMeans +from sklearn.utils import shuffle +from scipy.spatial.transform import Rotation + + +class SceneViz: + def __init__(self): + self.fig = go.Figure() + + def export_html(self, filename="scene_visualization.html"): + """Exports the current figure to a self-contained HTML file.""" + # Use Plotly's write_html to save the figure as a standalone HTML file + self.fig.write_html(filename, include_plotlyjs='cdn', full_html=True) + + print(f"Visualization exported to {filename}") + + def add_pointcloud(self, pts3d, color, mask=None, point_size=1, view_idx=0): + """Adds a point cloud to the Plotly figure using original colors.""" + pts3d = np.array(pts3d) + if mask is None: + mask = np.ones(pts3d.shape[:2], dtype=bool) + + masked_pts = pts3d[mask] + masked_color = color[mask] + + # Add point cloud with original colors and adjustable point size + self.fig.add_trace(go.Scatter3d( + x=masked_pts[:, 0], + y=masked_pts[:, 1], + z=masked_pts[:, 2], + mode='markers', + marker=dict(size=point_size, color=masked_color.reshape(-1, 3), opacity=0.8), + name=f"Point Cloud {view_idx}" + )) + return self + + def clamp_color(self, rgb_color): + """Ensure that RGB values are clamped between 0 and 255.""" + return tuple(max(0, min(255, int(c))) for c in rgb_color) + + def add_camera(self, pose_c2w, focal=None, color=(0, 0, 0), image=None, imsize=None, cam_size=0.02, view_idx=0, enable_color_image=True): + """Adds a camera frustum to the plot, with an optional image texture.""" + focal = focal if focal is not None else 500 + # Clamp the color values to ensure valid RGB range + clamped_color = self.clamp_color(color) + color_str = f'rgb{clamped_color}' + + # Create frustum and image surface or mesh3d + frustum_traces = create_camera_frustum_with_image( + pose_c2w, focal=focal, H=imsize[0] if imsize else 1080, W=imsize[1] if imsize else 1920, + image=image, screen_width=cam_size, color=color_str, view_idx=view_idx, enable_color_image=enable_color_image + ) + + # Add the traces for both frustum and image + for trace in frustum_traces: + if trace: + trace.update(legendgroup=f"frustum_{id(pose_c2w)}", showlegend=True) + self.fig.add_trace(trace) + + return self + + def add_cameras(self, poses, focals=None, images=None, imsizes=None, colors=None, cam_size=0.02, enable_color_image=True): + """Add multiple cameras with adjustable frustum size.""" + def get(arr, idx): + return None if arr is None else arr[idx] + + for i, pose_c2w in enumerate(poses): + self.add_camera( + pose_c2w, + focal=get(focals, i), + image=get(images, i), + color=get(colors, i), + imsize=get(imsizes, i), + cam_size=cam_size, # Frustum size control + view_idx=i, # Use view index for naming + enable_color_image=enable_color_image # Flag to enable or disable colored images + ) + return self + + def show(self, point_size=1, viewer=None): + self.fig.update_layout( + title="Camera Poses and Point Clouds", + scene=dict( + xaxis_title='X', + yaxis_title='Y', + zaxis_title='Z' + ), + margin=dict(l=0, r=0, b=0, t=40), + height=800, + ) + self.fig.show() + return self + + +def image2zvals(img, n_colors=64, n_training_pixels=1000): + """Quantize the image using KMeans for color mapping.""" + rows, cols, _ = img.shape + img = np.clip(img / 255.0, 0, 1) # Normalize the image + + # Flatten and shuffle for KMeans clustering + observations = img[:, :, :3].reshape(rows * cols, 3) + training_pixels = shuffle(observations, random_state=42)[:n_training_pixels] + + kmeans = KMeans(n_clusters=n_colors, random_state=42).fit(training_pixels) + codebook = kmeans.cluster_centers_ + indices = kmeans.predict(observations) + + z_vals = indices.astype(float) / (n_colors - 1) # Normalize indices to [0, 1] + z_vals = z_vals.reshape(rows, cols) + + # Create Plotly color scale + scale = np.linspace(0, 1, n_colors) + colors = (codebook * 255).astype(np.uint8) + plotly_colorscale = [[s, f'rgb{tuple(c)}'] for s, c in zip(scale, colors)] + + return z_vals, plotly_colorscale + + +def regular_triangles(rows, cols): + """Generate regular triangles for a mesh.""" + triangles = [] + for i in range(rows - 1): + for j in range(cols - 1): + k = j + i * cols + triangles.extend([[k, k + cols, k + 1 + cols], [k, k + 1 + cols, k + 1]]) + return np.array(triangles) + + +def mesh_data(img, resolution=128, n_colors=64, n_training_pixels=1000): + """Generate mesh data with quantized color intensities for the image.""" + img_downsampled = np.array(Image.fromarray(img).resize((resolution, resolution))) + + # Quantize the downsampled image + z_vals, pl_colorscale = image2zvals(img_downsampled, n_colors=n_colors, n_training_pixels=n_training_pixels) + + # Generate triangles + rows, cols, _ = img_downsampled.shape + triangles = regular_triangles(rows, cols) + I, J, K = triangles.T + + # Assign intensity to each triangle + zc = z_vals.flatten()[triangles] + tri_color_intensity = [zc[k][2] if k % 2 else zc[k][1] for k in range(len(zc))] + + return I, J, K, tri_color_intensity, pl_colorscale + + +def generate_meshgrid(frustum_points_world, resolution): + """Generate a meshgrid and calculate X, Y, Z for the surface or mesh3d.""" + img_x, img_y, img_z = frustum_points_world[1:5, 0], frustum_points_world[1:5, 1], frustum_points_world[1:5, 2] + u = np.linspace(0, 1, resolution) + v = np.linspace(0, 1, resolution) + uu, vv = np.meshgrid(u, v) + + X = img_x[0] * (1 - uu) * (1 - vv) + img_x[1] * uu * (1 - vv) + img_x[3] * (1 - uu) * vv + img_x[2] * uu * vv + Y = img_y[0] * (1 - uu) * (1 - vv) + img_y[1] * uu * (1 - vv) + img_y[3] * (1 - uu) * vv + img_y[2] * uu * vv + Z = img_z[0] * (1 - uu) * (1 - vv) + img_z[1] * uu * (1 - vv) + img_z[3] * (1 - uu) * vv + img_z[2] * uu * vv + + return X, Y, Z + + +def create_mesh3d(pose_c2w, img, frustum_points_world, resolution=64, n_colors=64, view_idx=0): + """Creates a Mesh3d object for the image texture mapping.""" + X, Y, Z = generate_meshgrid(frustum_points_world, resolution) + + # Get the mesh data + I, J, K, tri_color_intensity, pl_colorscale = mesh_data(img, resolution=resolution, n_colors=n_colors) + + # Create the Mesh3d trace + mesh3d_trace = go.Mesh3d( + x=X.flatten(), y=Y.flatten(), z=Z.flatten(), # Use frustum base interpolation + i=I, j=J, k=K, + intensity=tri_color_intensity, + intensitymode="cell", + colorscale=pl_colorscale, + showscale=False, + name=f"Image {view_idx}" + ) + + return mesh3d_trace + + +def create_surface(frustum_points_world, img_downsampled, z_vals, resolution=64, view_idx=0): + """Creates a Surface object with grayscale for faster rendering.""" + X, Y, Z = generate_meshgrid(frustum_points_world, resolution) + + # Create a grayscale surface + surface_trace = go.Surface( + x=X, y=Y, z=Z, surfacecolor=z_vals, colorscale='gray', showscale=False, name=f"Image {view_idx}" + ) + + return surface_trace + + +def create_camera_frustum_with_image(pose_c2w, focal, H, W, image=None, screen_width=0.02, color='blue', resolution=128, view_idx=0, enable_color_image=True): + """Creates a frustum for a camera and optionally adds an image to the frustum.""" + if image is not None: + H, W, THREE = image.shape + assert THREE == 3 + if image.dtype != np.uint8: + image = np.uint8(255 * image) + + # Calculate frustum size based on screen width and focal length + depth = focal * screen_width / H + hw_ratio = W / H + + # Define frustum points in camera space + frustum_points = np.array([ + [0, 0, 0], # Camera origin + [-hw_ratio * depth, -depth, depth], # Bottom left + [hw_ratio * depth, -depth, depth], # Bottom right + [hw_ratio * depth, depth, depth], # Top right + [-hw_ratio * depth, depth, depth], # Top left + ]) + + # Transform points to world coordinates + frustum_points_homogeneous = np.hstack([frustum_points, np.ones((frustum_points.shape[0], 1))]) + frustum_points_world = (pose_c2w @ frustum_points_homogeneous.T).T[:, :3] + + # Define edges of the frustum + edges = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4), (4, 1)] + x_vals, y_vals, z_vals = [], [], [] + for edge in edges: + x_vals += [frustum_points_world[edge[0], 0], frustum_points_world[edge[1], 0], None] + y_vals += [frustum_points_world[edge[0], 1], frustum_points_world[edge[1], 1], None] + z_vals += [frustum_points_world[edge[0], 2], frustum_points_world[edge[1], 2], None] + + # Add frustum lines + frustum_trace = go.Scatter3d( + x=x_vals, y=y_vals, z=z_vals, mode='lines', + line=dict(color=color, width=2), name=f"Camera {view_idx}" + ) + + # Optionally add image to the base of the frustum + image_surface_trace = None + if image is not None: + if enable_color_image: + # plotly doesn't natively support texture mapping, so we use Mesh3d for colored image + # see: https://github.com/empet/Texture-mapping-with-Plotly/blob/main/Texture-mapping-surface.ipynb + # Create the Mesh3d for colored image + image_surface_trace = create_mesh3d(pose_c2w, img=image, frustum_points_world=frustum_points_world, resolution=resolution, n_colors=64, view_idx=view_idx) + else: + # Downsample and use grayscale for faster performance + img_downsampled = np.array(Image.fromarray(image).resize((resolution, resolution))) + z_vals = np.mean(img_downsampled / 255.0, axis=-1) + image_surface_trace = create_surface(frustum_points_world, img_downsampled, z_vals, resolution=resolution, view_idx=view_idx) + + return [frustum_trace, image_surface_trace] if image_surface_trace else [frustum_trace] diff --git a/stream3r/loss/losses.py b/stream3r/loss/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..c7fc7b607963a249b39e8f9a0ba89429b237f662 --- /dev/null +++ b/stream3r/loss/losses.py @@ -0,0 +1,266 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# Implementation of DUSt3R training losses +# -------------------------------------------------------- +from copy import copy, deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from stream3r.dust3r.utils.geometry import ( + geotrf, + inv, +) +from stream3r.loss.utils import camera_loss, point_loss, depth_loss + + + +class LLoss(nn.Module): + """L-norm loss""" + + def __init__(self, reduction="mean"): + super().__init__() + self.reduction = reduction + + def forward(self, a, b): + assert ( + a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3 + ), f"Bad shape = {a.shape}" + dist = self.distance(a, b) + assert dist.ndim == a.ndim - 1 # one dimension less + if self.reduction == "none": + return dist + if self.reduction == "sum": + return dist.sum() + if self.reduction == "mean": + return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) + raise ValueError(f"bad {self.reduction=} mode") + + def distance(self, a, b): + raise NotImplementedError() + + +class L21Loss(LLoss): + """Euclidean distance between 3d points""" + + def distance(self, a, b): + return torch.norm(a - b, dim=-1) # normalized L2 distance + + +L21 = L21Loss() + + +class Criterion(nn.Module): + def __init__(self, criterion=None): + super().__init__() + assert isinstance(criterion, LLoss), ( + f"{criterion} is not a proper criterion!" + bb() + ) + self.criterion = copy(criterion) + + def get_name(self): + return f"{type(self).__name__}({self.criterion})" + + def with_reduction(self, mode): + res = loss = deepcopy(self) + while loss is not None: + assert isinstance(loss, Criterion) + loss.criterion.reduction = "none" # make it return the loss for each sample + loss = loss._loss2 # we assume loss is a Multiloss + return res + + +class MultiLoss(nn.Module): + """Easily combinable losses (also keep track of individual loss values): + loss = MyLoss1() + 0.1*MyLoss2() + Usage: + Inherit from this class and override get_name() and compute_loss() + """ + + def __init__(self): + super().__init__() + self._alpha = 1 + self._loss2 = None + + def compute_loss(self, *args, **kwargs): + raise NotImplementedError() + + def get_name(self): + raise NotImplementedError() + + def __mul__(self, alpha): + assert isinstance(alpha, (int, float)) + res = copy(self) + res._alpha = alpha + return res + + __rmul__ = __mul__ # same + + def __add__(self, loss2): + assert isinstance(loss2, MultiLoss) + res = cur = copy(self) + # find the end of the chain + while cur._loss2 is not None: + cur = cur._loss2 + cur._loss2 = loss2 + return res + + def __repr__(self): + name = self.get_name() + if self._alpha != 1: + name = f"{self._alpha:g}*{name}" + if self._loss2: + name = f"{name} + {self._loss2}" + return name + + def forward(self, *args, **kwargs): + loss = self.compute_loss(*args, **kwargs) + if isinstance(loss, tuple): + loss, details = loss + elif loss.ndim == 0: + details = {self.get_name(): float(loss)} + else: + details = {} + loss = loss * self._alpha + + if self._loss2: + loss2, details2 = self._loss2(*args, **kwargs) + loss = loss + loss2 + details |= details2 + + return loss, details + + +class CausalLoss(MultiLoss): + def __init__(self, gradient_loss="grad", is_metric=True): + super().__init__() + + self.gradient_loss = gradient_loss + self.is_metric = is_metric + + def get_name(self): + return f"CausalLoss" + + def get_pts3d_from_views(self, gt_views, dist_clip=None, local=False, key_word="pts3d"): + """Get point clouds and valid masks for multiple views.""" + gt_pts_list = [] + valid_mask_list = [] + + if key_word == "pts3d": + mask_key_word = "valid_mask" + elif key_word == "track": + mask_key_word = "track_valid_mask" + else: + raise ValueError(f"Invalid key_word: {key_word}") + + if not local: # compute the inverse transformation for the anchor view (first view) + inv_matrix_anchor = inv(gt_views[0]["camera_pose"].float()) + + for gt_view in gt_views: + if local: + # Rotate GT points to align with the local camera origin for supervision + inv_matrix_local = inv(gt_view["camera_pose"].float()) + gt_pts = geotrf(inv_matrix_local, gt_view[key_word]) # Transform GT points to local view's origin + else: + # Use the anchor view (first view) transformation for global loss + gt_pts = geotrf(inv_matrix_anchor, gt_view[key_word]) # Transform GT points to anchor view + + valid_gt = gt_view[mask_key_word].clone() + + if dist_clip is not None: + dis = gt_pts.norm(dim=-1) + valid_gt &= dis <= dist_clip + + gt_pts_list.append(gt_pts) + valid_mask_list.append(valid_gt) + + gt_pts = torch.stack(gt_pts_list, dim=1) + valid_masks = torch.stack(valid_mask_list, dim=1) + + return gt_pts, valid_masks + + def get_depth_from_views(self, gt_views, dist_clip=None): + gt_pts_list = [] + valid_mask_list = [] + + mask_key_word = "valid_mask" + + for gt_view in gt_views: + gt_pts = gt_view["depthmap"] + valid_gt = gt_view[mask_key_word].clone() + + if dist_clip is not None: + dis = gt_pts.norm(dim=-1) + valid_gt &= dis <= dist_clip + + gt_pts_list.append(gt_pts) + valid_mask_list.append(valid_gt) + + gt_pts = torch.stack(gt_pts_list, dim=1) + valid_masks = torch.stack(valid_mask_list, dim=1) + + return gt_pts, valid_masks + + def get_camera_from_views(self, gt_views): + gt_extrinsic_list = [] + gt_intrinsic_list = [] + + image_size_hw = gt_views[0]["img"].shape[-2:] + for gt_view in gt_views: + gt_extrinsic_list.append(gt_view["camera_pose"]) + gt_intrinsic_list.append(gt_view["camera_intrinsics"]) + + gt_extrinsics = torch.stack(gt_extrinsic_list, dim=1) + gt_intrinsics = torch.stack(gt_intrinsic_list, dim=1) + + return gt_extrinsics, gt_intrinsics, image_size_hw + + def compute_loss(self, gts, preds, **kw): + details = {} + self_name = type(self).__name__ + + gt_pts3d_global, valid_mask_global = self.get_pts3d_from_views(gts, key_word="pts3d", **kw) # B, N, H, W, C + gt_depth, valid_mask_depth = self.get_depth_from_views(gts, **kw) # B, N, H, W, C + gt_extrinsics, gt_intrinsics, image_size_hw = self.get_camera_from_views(gts) + + pred_pts3d_global, pred_conf_global = preds["world_points"], preds["world_points_conf"] + pred_depth, pred_depth_conf = preds["depth"], preds["depth_conf"] + + # loss for pts3d global + loss_pts3d_global = point_loss(pred_pts3d_global, pred_conf_global, gt_pts3d_global, valid_mask_global, gradient_loss=self.gradient_loss, temporal_matching_loss=False, all_mean=True, valid_range=0.98, ormalize_pred=True, normalize_gt=True, normalize_using_first_view=False) + + # loss for depth + loss_depth = depth_loss(pred_depth, pred_depth_conf, gt_depth, valid_mask_depth, gradient_loss=self.gradient_loss, temporal_matching_loss=False, all_mean=True, valid_range=0.98, normalize_pred=True, normalize_gt=True, normalize_using_first_view=False) + gt_pts3d_scale = loss_depth[f"gt_pts3d_scale"] + pred_pts3d_scale = loss_depth[f"pred_pts3d_scale"] + + # loss for camera + pred_pose_enc_list = preds["pose_enc_list"] + loss_camera = camera_loss(pred_pose_enc_list, gt_extrinsics, gt_intrinsics, image_size_hw, loss_type="l1", gt_pts3d_scale=gt_pts3d_scale, pred_pts3d_scale=pred_pts3d_scale, pose_encoding_type="relT_quaR_FoV") + + # total loss + pts3d_loss = loss_pts3d_global["loss_conf"] + loss_pts3d_global["loss_grad"] + loss_depth["loss_conf"] + loss_depth["loss_grad"] + total_loss = pts3d_loss + loss_camera["loss_camera"] + + # logs + details[self_name + "_pts3d_loss" + "/00"] = float(pts3d_loss.detach()) + details[self_name + "_pts3d_loss_global" + "_conf" + "/00"] = float(loss_pts3d_global["loss_conf"].detach()) + details[self_name + "_pts3d_loss_global" + "_grad" + "/00"] = float(loss_pts3d_global["loss_grad"].detach()) + details[self_name + "_depth_loss" + "_conf" + "/00"] = float(loss_depth["loss_conf"].detach()) + details[self_name + "_depth_loss" + "_grad" + "/00"] = float(loss_depth["loss_grad"].detach()) + + details[self_name + "_camera_loss" + "_loss_camera" + "/00"] = float(loss_camera["loss_camera"].detach()) + details[self_name + "_camera_loss" + "_loss_T" + "/00"] = float(loss_camera["loss_T"].detach()) + details[self_name + "_camera_loss" + "_loss_R" + "/00"] = float(loss_camera["loss_R"].detach()) + details[self_name + "_camera_loss" + "_loss_fl" + "/00"] = float(loss_camera["loss_fl"].detach()) + + return total_loss, details \ No newline at end of file diff --git a/stream3r/loss/utils.py b/stream3r/loss/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..15272340cef6ba9f01f212cff0bb98c03bb26886 --- /dev/null +++ b/stream3r/loss/utils.py @@ -0,0 +1,600 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F +from math import ceil, floor + +from stream3r.dust3r.utils.geometry import inv +from stream3r.models.components.utils.pose_enc import extri_intri_to_pose_encoding + + +def check_and_fix_inf_nan(loss_tensor, loss_name, hard_max = 100): + """ + Checks if 'loss_tensor' contains inf or nan. If it does, replace those + values with zero and print the name of the loss tensor. + + Args: + loss_tensor (torch.Tensor): The loss tensor to check. + loss_name (str): Name of the loss (for diagnostic prints). + + Returns: + torch.Tensor: The checked and fixed loss tensor, with inf/nan replaced by 0. + """ + + if torch.isnan(loss_tensor).any() or torch.isinf(loss_tensor).any(): + for _ in range(10): + print(f"{loss_name} has inf or nan. Setting those values to 0.") + loss_tensor = torch.where( + torch.isnan(loss_tensor) | torch.isinf(loss_tensor), + torch.tensor(0.0, device=loss_tensor.device), + loss_tensor + ) + + loss_tensor = torch.clamp(loss_tensor, min=-hard_max, max=hard_max) + + return loss_tensor + + +def camera_loss(pred_pose_enc_list, gt_extrinsic, gt_intrinsic, image_size_hw, gt_pts3d_scale, pred_pts3d_scale, loss_type="l1", gamma=0.6, pose_encoding_type="relT_quaR_FoV", weight_T = 1.0, weight_R = 1.0, weight_fl = 0.5, frame_num = -100): + num_predictions = len(pred_pose_enc_list) + + anchor_camera_inv = inv(gt_extrinsic[:, 0:1, :, :]) + # in dust3r frammework, dataset gt extrinsic is cam2world, but vggt predicts world2cam + gt_extrinsic_aligned = inv(anchor_camera_inv @ gt_extrinsic) + gt_pose_encoding = extri_intri_to_pose_encoding(gt_extrinsic_aligned, gt_intrinsic, image_size_hw, pose_encoding_type=pose_encoding_type, gt_pts3d_scale=gt_pts3d_scale) + + loss_T = loss_R = loss_fl = 0 + + for i in range(num_predictions): + i_weight = gamma ** (num_predictions - i - 1) + + cur_pred_pose_enc = pred_pose_enc_list[i] # B, S, 9 + cur_pred_pose_enc = torch.cat([ + cur_pred_pose_enc[:, :, :3] / pred_pts3d_scale.view(-1, 1, 1), + cur_pred_pose_enc[:, :, 3:] + ], dim=2) + + if frame_num>0: + loss_T_i, loss_R_i, loss_fl_i = camera_loss_single(cur_pred_pose_enc[:, :frame_num].clone(), gt_pose_encoding[:, :frame_num].clone(), loss_type=loss_type) + else: + loss_T_i, loss_R_i, loss_fl_i = camera_loss_single(cur_pred_pose_enc.clone(), gt_pose_encoding.clone(), loss_type=loss_type) + + loss_T += loss_T_i * i_weight + loss_R += loss_R_i * i_weight + loss_fl += loss_fl_i * i_weight + + loss_T = loss_T / num_predictions + loss_R = loss_R / num_predictions + loss_fl = loss_fl / num_predictions + loss_camera = loss_T * weight_T + loss_R * weight_R + loss_fl * weight_fl + + loss_dict = { + "loss_camera": loss_camera, + "loss_T": loss_T, + "loss_R": loss_R, + "loss_fl": loss_fl + } + + return loss_dict + + +def camera_loss_single(cur_pred_pose_enc, gt_pose_encoding, loss_type="l1"): + if loss_type == "l1": + loss_T = (cur_pred_pose_enc[..., :3] - gt_pose_encoding[..., :3]).abs() + loss_R = (cur_pred_pose_enc[..., 3:7] - gt_pose_encoding[..., 3:7]).abs() + loss_fl = (cur_pred_pose_enc[..., 7:] - gt_pose_encoding[..., 7:]).abs() + elif loss_type == "l2": + loss_T = (cur_pred_pose_enc[..., :3] - gt_pose_encoding[..., :3]).norm(dim=-1, keepdim=True) + loss_R = (cur_pred_pose_enc[..., 3:7] - gt_pose_encoding[..., 3:7]).norm(dim=-1) + loss_fl = (cur_pred_pose_enc[..., 7:] - gt_pose_encoding[..., 7:]).norm(dim=-1) + else: + raise ValueError(f"Unknown loss type: {loss_type}") + + loss_T = check_and_fix_inf_nan(loss_T, "loss_T") + loss_R = check_and_fix_inf_nan(loss_R, "loss_R") + loss_fl = check_and_fix_inf_nan(loss_fl, "loss_fl") + + loss_T = loss_T.clamp(max=100) # TODO: remove this + loss_T = loss_T.mean() + loss_R = loss_R.mean() + loss_fl = loss_fl.mean() + + return loss_T, loss_R, loss_fl + + +def normalize_pointcloud(pts3d, valid_mask, normalize_using_first_view, eps=1e-3): + """ + pts3d: B, S, H, W, 3 + valid_mask: B, S, H, W + """ + if normalize_using_first_view: + dist = pts3d[:, 0:1, ...].norm(dim=-1) + valid_mask = valid_mask[:, 0:1, ...] + else: + dist = pts3d.norm(dim=-1) + + dist_sum = (dist * valid_mask).sum(dim=[1,2,3]) + valid_count = valid_mask.sum(dim=[1,2,3]) + + avg_scale = (dist_sum / (valid_count + eps)).clamp(min=eps, max=1e3) + + pts3d = pts3d / avg_scale.view(-1, 1, 1, 1, 1) + return pts3d, avg_scale + + +def depth_loss(depth, depth_conf, gt_depth, valid_mask, gamma=1.0, alpha=0.2, loss_type="conf", predict_disparity=False, affine_inv=False, gradient_loss= None, valid_range=-1, disable_conf=False, all_mean=False, normalize_gt=True, normalize_pred=False, normalize_using_first_view=False, normalize_with_metric_mask=False, is_metric_mask=None, gt_pts3d_scale=None, pred_pts3d_scale=None, **kwargs): + gt_depth = gt_depth[..., None] + + if loss_type == "conf": + conf_loss_dict = conf_loss(depth, depth_conf, gt_depth, valid_mask, + batch=None, normalize_pred=normalize_pred, normalize_gt=normalize_gt, + gamma=gamma, alpha=alpha, affine_inv=affine_inv, gradient_loss=gradient_loss, valid_range=valid_range, postfix="", disable_conf=disable_conf, all_mean=all_mean, normalize_using_first_view=normalize_using_first_view, normalize_with_metric_mask=normalize_with_metric_mask, is_metric_mask=is_metric_mask, gt_pts3d_scale=gt_pts3d_scale, pred_pts3d_scale=pred_pts3d_scale) + else: + raise ValueError(f"Invalid loss type: {loss_type}") + + return conf_loss_dict + + +def point_loss(pts3d, pts3d_conf, gt_pts3d, valid_mask, normalize_pred=True, normalize_gt=True, gamma=1.0, alpha=0.2, affine_inv=False, gradient_loss=None, valid_range=-1, camera_centric_reg=-1, disable_conf=False, all_mean=False, conf_loss_type="v1", gt_pts3d_scale=None, temporal_matching_loss=False, normalize_using_first_view=False, normalize_with_metric_mask=False, is_metric_mask=None, **kwargs): + """ + pts3d: B, S, H, W, 3 + pts3d_conf: B, S, H, W + gt_pts3d: B, S, H, W, 3 + valid_mask: B, S, H, W + """ + if conf_loss_type == "v1": + conf_loss_fn = conf_loss + else: + raise ValueError(f"Invalid conf loss type: {conf_loss_type}") + + conf_loss_dict = conf_loss_fn(pts3d, pts3d_conf, gt_pts3d, valid_mask, batch=None, + normalize_pred=normalize_pred, normalize_gt=normalize_gt, gamma=gamma, alpha=alpha, affine_inv=affine_inv, + gradient_loss=gradient_loss, valid_range=valid_range, camera_centric_reg=camera_centric_reg, disable_conf=disable_conf, all_mean=all_mean, gt_pts3d_scale=gt_pts3d_scale, temporal_matching_loss=temporal_matching_loss, normalize_using_first_view=normalize_using_first_view, normalize_with_metric_mask=normalize_with_metric_mask, is_metric_mask=is_metric_mask,) + + return conf_loss_dict + + +def filter_by_quantile(loss_tensor, valid_range, min_elements=1000, hard_max=100): + """ + Filters a loss tensor by keeping only values below a certain quantile threshold. + Also clamps individual values to hard_max. + + Args: + loss_tensor: Tensor containing loss values + valid_range: Float between 0 and 1 indicating the quantile threshold + min_elements: Minimum number of elements required to apply filtering + hard_max: Maximum allowed value for any individual loss + + Returns: + Filtered and clamped loss tensor + """ + if loss_tensor.numel() <= 1000: + # too small, just return + return loss_tensor + + # Randomly sample if tensor is too large + if loss_tensor.numel() > 100000000: + # Flatten and randomly select 1M elements + indices = torch.randperm(loss_tensor.numel(), device=loss_tensor.device)[:1_000_000] + loss_tensor = loss_tensor.view(-1)[indices] + + # First clamp individual values + loss_tensor = loss_tensor.clamp(max=hard_max) + + quantile_thresh = torch_quantile(loss_tensor.detach(), valid_range) + quantile_thresh = min(quantile_thresh, hard_max) + + # Apply quantile filtering if enough elements remain + quantile_mask = loss_tensor < quantile_thresh + if quantile_mask.sum() > min_elements: + return loss_tensor[quantile_mask] + return loss_tensor + + +def conf_loss(pts3d, pts3d_conf, gt_pts3d, valid_mask, batch, normalize_gt=True, normalize_pred=True, gamma=1.0, alpha=0.2, affine_inv=False, gradient_loss=None, valid_range=-1, camera_centric_reg=-1, disable_conf=False, all_mean=False, postfix="", gt_pts3d_scale=None, temporal_matching_loss=False, normalize_using_first_view=False, normalize_with_metric_mask=False, is_metric_mask=None, pred_pts3d_scale=None): + + # normalize + if gt_pts3d_scale is not None and pred_pts3d_scale is not None: + gt_pts3d = gt_pts3d / gt_pts3d_scale.view(-1, 1, 1, 1, 1) + pts3d = pts3d / pred_pts3d_scale.view(-1, 1, 1, 1, 1) + elif normalize_with_metric_mask: + assert is_metric_mask is not None + non_metric_mask = ~is_metric_mask + gt_pts3d_non_metric = gt_pts3d[non_metric_mask] + valid_mask_non_metric = valid_mask[non_metric_mask] + + # Normalize non-metric points + _, gt_pts3d_scale_non_metric = normalize_pointcloud(gt_pts3d_non_metric, valid_mask_non_metric, normalize_using_first_view) + + # for pred backpropagation, we have to normalize in place with divide operation, so we just get pred_scale here + pred_pts3d_non_metric = pts3d[non_metric_mask].clone().detach() + _, pred_pts3d_scale_non_metric = normalize_pointcloud(pred_pts3d_non_metric, valid_mask_non_metric, normalize_using_first_view) + + # Put normalized points back + gt_pts3d_scale = torch.ones_like(is_metric_mask, dtype=gt_pts3d.dtype) + gt_pts3d_scale[non_metric_mask] = gt_pts3d_scale_non_metric.to(gt_pts3d.dtype) + gt_pts3d = gt_pts3d / gt_pts3d_scale.view(-1, 1, 1, 1, 1) + + pred_pts3d_scale = torch.ones_like(is_metric_mask, dtype=pts3d.dtype) + pred_pts3d_scale[non_metric_mask] = pred_pts3d_scale_non_metric.to(pts3d.dtype) + pts3d = pts3d / pred_pts3d_scale.view(-1, 1, 1, 1, 1) + else: + if normalize_gt: + if gt_pts3d_scale is None: + gt_pts3d, gt_pts3d_scale = normalize_pointcloud(gt_pts3d, valid_mask, normalize_using_first_view) + else: + gt_pts3d = gt_pts3d / gt_pts3d_scale.view(-1, 1, 1, 1, 1) + + if normalize_pred: + pts3d, pred_pts3d_scale = normalize_pointcloud(pts3d, valid_mask, normalize_using_first_view) + + if affine_inv: + raise NotImplementedError() + # scale, shift = closed_form_scale_and_shift(pts3d, gt_pts3d, valid_mask) + # pts3d = pts3d * scale + shift + + loss_reg_first_frame, loss_reg_other_frames, loss_grad_first_frame, loss_grad_other_frames, loss_temporal_matching = reg_loss(pts3d, gt_pts3d, valid_mask, gradient_loss=gradient_loss, temporal_matching_loss=temporal_matching_loss) + + if disable_conf: + conf_loss_first_frame = gamma * loss_reg_first_frame + conf_loss_other_frames = gamma * loss_reg_other_frames + else: + first_frame_conf = pts3d_conf[:, 0:1, ...] + other_frames_conf = pts3d_conf[:, 1:, ...] + first_frame_mask = valid_mask[:, 0:1, ...] + other_frames_mask = valid_mask[:, 1:, ...] + + conf_loss_first_frame = gamma * loss_reg_first_frame * first_frame_conf[first_frame_mask] - alpha * torch.log(first_frame_conf[first_frame_mask]) + conf_loss_other_frames = gamma * loss_reg_other_frames * other_frames_conf[other_frames_mask] - alpha * torch.log(other_frames_conf[other_frames_mask]) + + if valid_range>0: + conf_loss_first_frame = filter_by_quantile(conf_loss_first_frame, valid_range) + conf_loss_other_frames = filter_by_quantile(conf_loss_other_frames, valid_range) + + conf_loss_first_frame = check_and_fix_inf_nan(conf_loss_first_frame, f"conf_loss_first_frame{postfix}") + conf_loss_other_frames = check_and_fix_inf_nan(conf_loss_other_frames, f"conf_loss_other_frames{postfix}") + + if all_mean and conf_loss_first_frame.numel() > 0 and conf_loss_other_frames.numel() > 0: + all_conf_loss = torch.cat([conf_loss_first_frame, conf_loss_other_frames]) + conf_loss = all_conf_loss.mean() if all_conf_loss.numel() > 0 else 0 + + # for logging only + conf_loss_first_frame = conf_loss_first_frame.mean() if conf_loss_first_frame.numel() > 0 else 0 + conf_loss_other_frames = conf_loss_other_frames.mean() if conf_loss_other_frames.numel() > 0 else 0 + else: + conf_loss_first_frame = conf_loss_first_frame.mean() if conf_loss_first_frame.numel() > 0 else 0 + conf_loss_other_frames = conf_loss_other_frames.mean() if conf_loss_other_frames.numel() > 0 else 0 + + conf_loss = conf_loss_first_frame + conf_loss_other_frames + + # Verified that the loss is the same + + loss_dict = { + f"loss_conf{postfix}": conf_loss, + f"loss_reg1{postfix}": loss_reg_first_frame.detach().mean() if loss_reg_first_frame.numel() > 0 else 0, + f"loss_reg2{postfix}": loss_reg_other_frames.detach().mean() if loss_reg_other_frames.numel() > 0 else 0, + f"loss_conf1{postfix}": conf_loss_first_frame, + f"loss_conf2{postfix}": conf_loss_other_frames, + } + + # loss_grad_first_frame and loss_grad_other_frames are already meaned + loss_grad = loss_grad_first_frame + loss_grad_other_frames + loss_dict[f"loss_grad1{postfix}"] = loss_grad_first_frame + loss_dict[f"loss_grad2{postfix}"] = loss_grad_other_frames + loss_dict[f"loss_grad{postfix}"] = loss_grad + + if temporal_matching_loss: + loss_dict[f"loss_temporal_matching{postfix}"] = loss_temporal_matching + else: + loss_dict[f"loss_temporal_matching{postfix}"] = 0 + + loss_dict[f"gt_pts3d_scale{postfix}"] = gt_pts3d_scale + loss_dict[f"pred_pts3d_scale{postfix}"] = pred_pts3d_scale + + return loss_dict + + +def reg_loss(pts3d, gt_pts3d, valid_mask, gradient_loss=None, temporal_matching_loss=False): + first_frame_pts3d = pts3d[:, 0:1, ...] + first_frame_gt_pts3d = gt_pts3d[:, 0:1, ...] + first_frame_mask = valid_mask[:, 0:1, ...] + + other_frames_pts3d = pts3d[:, 1:, ...] + other_frames_gt_pts3d = gt_pts3d[:, 1:, ...] + other_frames_mask = valid_mask[:, 1:, ...] + + loss_reg_first_frame = torch.norm(first_frame_gt_pts3d[first_frame_mask] - first_frame_pts3d[first_frame_mask], dim=-1) + loss_reg_other_frames = torch.norm(other_frames_gt_pts3d[other_frames_mask] - other_frames_pts3d[other_frames_mask], dim=-1) + + if gradient_loss == "grad": + bb, ss_f, hh, ww, nc = first_frame_pts3d.shape + loss_grad_first_frame = gradient_loss_multi_scale(first_frame_pts3d.reshape(bb*ss_f, hh, ww, nc), first_frame_gt_pts3d.reshape(bb*ss_f, hh, ww, nc), first_frame_mask.reshape(bb*ss_f, hh, ww)) + bb, ss_o, hh, ww, nc = other_frames_pts3d.shape + loss_grad_other_frames = gradient_loss_multi_scale(other_frames_pts3d.reshape(bb*ss_o, hh, ww, nc), other_frames_gt_pts3d.reshape(bb*ss_o, hh, ww, nc), other_frames_mask.reshape(bb*ss_o, hh, ww)) + + # we all mean gradient loss + loss_grad_other_frames *= (ss_o // ss_f) + + elif gradient_loss == "normal": + bb, ss, hh, ww, nc = first_frame_pts3d.shape + loss_grad_first_frame = gradient_loss_multi_scale(first_frame_pts3d.reshape(bb*ss, hh, ww, nc), first_frame_gt_pts3d.reshape(bb*ss, hh, ww, nc), first_frame_mask.reshape(bb*ss, hh, ww), gradient_loss_fn=normal_loss, scales=3) + bb, ss, hh, ww, nc = other_frames_pts3d.shape + loss_grad_other_frames = gradient_loss_multi_scale(other_frames_pts3d.reshape(bb*ss, hh, ww, nc), other_frames_gt_pts3d.reshape(bb*ss, hh, ww, nc), other_frames_mask.reshape(bb*ss, hh, ww), gradient_loss_fn=normal_loss, scales=3) + else: + loss_grad_first_frame = 0 + loss_grad_other_frames = 0 + + loss_reg_first_frame = check_and_fix_inf_nan(loss_reg_first_frame, "loss_reg_first_frame") + loss_reg_other_frames = check_and_fix_inf_nan(loss_reg_other_frames, "loss_reg_other_frames") + + if temporal_matching_loss: + # B, S, H, W, 3 + pred_diff = pts3d[:, 1:] - pts3d[:, :-1] + gt_diff = gt_pts3d[:, 1:] - gt_pts3d[:, :-1] + valid_mask = valid_mask[:, 1:] & valid_mask[:, :-1] + + loss_temporal_matching = F.l1_loss(pred_diff[valid_mask], gt_diff[valid_mask], reduction='none') + loss_temporal_matching = check_and_fix_inf_nan(loss_temporal_matching, "loss_temporal_matching") + valid_count = valid_mask.sum() + loss_temporal_matching = (loss_temporal_matching.sum() / valid_count) if valid_count > 0 else 0 + else: + loss_temporal_matching = 0 + + return loss_reg_first_frame, loss_reg_other_frames, loss_grad_first_frame, loss_grad_other_frames, loss_temporal_matching + + +def normal_loss(prediction, target, mask, cos_eps=1e-8, conf=None): + """ + Computes the normal-based loss by comparing the angle between + predicted normals and ground-truth normals. + + prediction: (B, H, W, 3) - Predicted 3D coordinates/points + target: (B, H, W, 3) - Ground-truth 3D coordinates/points + mask: (B, H, W) - Valid pixel mask (1 = valid, 0 = invalid) + + Returns: scalar (averaged over valid regions) + """ + pred_normals, pred_valids = point_map_to_normal(prediction, mask, eps=cos_eps) + gt_normals, gt_valids = point_map_to_normal(target, mask, eps=cos_eps) + + all_valid = pred_valids & gt_valids # shape: (4, B, H, W) + + # Early return if not enough valid points + divisor = torch.sum(all_valid) + if divisor < 10: + return 0 + + pred_normals = pred_normals[all_valid].clone() + gt_normals = gt_normals[all_valid].clone() + + # Compute cosine similarity between corresponding normals + # pred_normals and gt_normals are (4, B, H, W, 3) + # We want to compare corresponding normals where all_valid is True + dot = torch.sum(pred_normals * gt_normals, dim=-1) # shape: (4, B, H, W) + + # Clamp dot product to [-1, 1] for numerical stability + dot = torch.clamp(dot, -1 + cos_eps, 1 - cos_eps) + + # Compute loss as 1 - cos(theta), instead of arccos(dot) for numerical stability + loss = 1 - dot # shape: (4, B, H, W) + + + # Return mean loss if we have enough valid points + if loss.numel() < 10: + return 0 + else: + loss = check_and_fix_inf_nan(loss, "normal_loss") + + if conf is not None: + conf = conf[None, ...].expand(4, -1, -1, -1) + conf = conf[all_valid].clone() + + gamma = 1.0 # hard coded + alpha = 0.2 # hard coded + + loss = gamma * loss * conf - alpha * torch.log(conf) + return loss.mean() + else: + return loss.mean() + + +def point_map_to_normal(point_map, mask, eps=1e-6): + """ + point_map: (B, H, W, 3) - 3D points laid out in a 2D grid + mask: (B, H, W) - valid pixels (bool) + + Returns: + normals: (4, B, H, W, 3) - normal vectors for each of the 4 cross-product directions + valids: (4, B, H, W) - corresponding valid masks + """ + + with torch.cuda.amp.autocast(enabled=False): + # Pad inputs to avoid boundary issues + padded_mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0) + pts = F.pad(point_map.permute(0, 3, 1, 2), (1,1,1,1), mode='constant', value=0).permute(0, 2, 3, 1) + + # Each pixel's neighbors + center = pts[:, 1:-1, 1:-1, :] # B,H,W,3 + up = pts[:, :-2, 1:-1, :] + left = pts[:, 1:-1, :-2 , :] + down = pts[:, 2:, 1:-1, :] + right = pts[:, 1:-1, 2:, :] + + # Direction vectors + up_dir = up - center + left_dir = left - center + down_dir = down - center + right_dir = right - center + + # Four cross products (shape: B,H,W,3 each) + n1 = torch.cross(up_dir, left_dir, dim=-1) # up x left + n2 = torch.cross(left_dir, down_dir, dim=-1) # left x down + n3 = torch.cross(down_dir, right_dir, dim=-1) # down x right + n4 = torch.cross(right_dir,up_dir, dim=-1) # right x up + + # Validity for each cross-product direction + # We require that both directions' pixels are valid + v1 = padded_mask[:, :-2, 1:-1] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 1:-1, :-2] + v2 = padded_mask[:, 1:-1, :-2 ] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 2:, 1:-1] + v3 = padded_mask[:, 2:, 1:-1] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, 1:-1, 2:] + v4 = padded_mask[:, 1:-1, 2: ] & padded_mask[:, 1:-1, 1:-1] & padded_mask[:, :-2, 1:-1] + + # Stack them to shape (4,B,H,W,3), (4,B,H,W) + normals = torch.stack([n1, n2, n3, n4], dim=0) # shape [4, B, H, W, 3] + valids = torch.stack([v1, v2, v3, v4], dim=0) # shape [4, B, H, W] + + # Normalize each direction's normal + # shape is (4, B, H, W, 3), so dim=-1 is the vector dimension + # clamp_min(eps) to avoid division by zero + # lengths = torch.norm(normals, dim=-1, keepdim=True).clamp_min(eps) + # normals = normals / lengths + normals = F.normalize(normals, p=2, dim=-1, eps=eps) + + + # Zero out invalid entries so they don't pollute subsequent computations + # normals = normals * valids.unsqueeze(-1) + + return normals, valids + + +def gradient_loss(prediction, target, mask, conf=None, gamma=1.0, alpha=0.2): + # prediction: B, H, W, C + # target: B, H, W, C + # mask: B, H, W + + mask = mask[..., None].expand(-1, -1, -1, prediction.shape[-1]) + M = torch.sum(mask, (1, 2, 3)) + + diff = prediction - target + diff = torch.mul(mask, diff) + + grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) + mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) + grad_x = torch.mul(mask_x, grad_x) + + grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) + mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) + grad_y = torch.mul(mask_y, grad_y) + + grad_x = grad_x.clamp(max=100) + grad_y = grad_y.clamp(max=100) + + + if conf is not None: + conf = conf[..., None].expand(-1, -1, -1, prediction.shape[-1]) + conf_x = conf[:, :, 1:] + conf_y = conf[:, 1:, :] + gamma = 1.0 + alpha = 0.2 + + grad_x = gamma * grad_x * conf_x - alpha * torch.log(conf_x) + grad_y = gamma * grad_y * conf_y - alpha * torch.log(conf_y) + + + image_loss = torch.sum(grad_x, (1, 2, 3)) + torch.sum(grad_y, (1, 2, 3)) + image_loss = check_and_fix_inf_nan(image_loss, "gradient_loss") + + divisor = torch.sum(M) + + if divisor == 0: + return 0 + else: + image_loss = torch.sum(image_loss) / divisor + + return image_loss + + +def gradient_loss_multi_scale(prediction, target, mask, scales=4, gradient_loss_fn = gradient_loss, conf=None): + """ + Compute gradient loss across multiple scales + """ + + total = 0 + for scale in range(scales): + step = pow(2, scale) + + total += gradient_loss_fn( + prediction[:, ::step, ::step], + target[:, ::step, ::step], + mask[:, ::step, ::step], + conf=conf[:, ::step, ::step] if conf is not None else None + ) + + total = total / scales + return total + + +def torch_quantile( + input: torch.Tensor, + q: float | torch.Tensor, + dim: int | None = None, + keepdim: bool = False, + *, + interpolation: str = "nearest", + out: torch.Tensor | None = None, +) -> torch.Tensor: + """Better torch.quantile for one SCALAR quantile. + + Using torch.kthvalue. Better than torch.quantile because: + - No 2**24 input size limit (pytorch/issues/67592), + - Much faster, at least on big input sizes. + + Arguments: + input (torch.Tensor): See torch.quantile. + q (float): See torch.quantile. Supports only scalar input + currently. + dim (int | None): See torch.quantile. + keepdim (bool): See torch.quantile. Supports only False + currently. + interpolation: {"nearest", "lower", "higher"} + See torch.quantile. + out (torch.Tensor | None): See torch.quantile. Supports only + None currently. + """ + # https://github.com/pytorch/pytorch/issues/64947 + # Sanitization: q + try: + q = float(q) + assert 0 <= q <= 1 + except Exception: + raise ValueError(f"Only scalar input 0<=q<=1 is currently supported (got {q})!") + + # Sanitization: dim + # Because one cannot pass `dim=None` to `squeeze()` or `kthvalue()` + if dim_was_none := dim is None: + dim = 0 + input = input.reshape((-1,) + (1,) * (input.ndim - 1)) + + # Sanitization: inteporlation + if interpolation == "nearest": + inter = round + elif interpolation == "lower": + inter = floor + elif interpolation == "higher": + inter = ceil + else: + raise ValueError( + "Supported interpolations currently are {'nearest', 'lower', 'higher'} " + f"(got '{interpolation}')!" + ) + + # Sanitization: out + if out is not None: + raise ValueError(f"Only None value is currently supported for out (got {out})!") + + # Logic + k = inter(q * (input.shape[dim] - 1)) + 1 + out = torch.kthvalue(input, k, dim, keepdim=True, out=out)[0] + + # Rectification: keepdim + if keepdim: + return out + if dim_was_none: + return out.squeeze() + else: + return out.squeeze(dim) \ No newline at end of file diff --git a/stream3r/models/__init__.py b/stream3r/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a49a3a9427e67836de554fed0bc7f6466adbe06 --- /dev/null +++ b/stream3r/models/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + diff --git a/stream3r/models/__pycache__/__init__.cpython-311.pyc b/stream3r/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b733c86efdf73b61e80845be613b9411fce00436 Binary files /dev/null and b/stream3r/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/stream3r/models/__pycache__/stream3r.cpython-311.pyc b/stream3r/models/__pycache__/stream3r.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffb02d3e818e7195a07cd1f9b08275cbc229e0c5 Binary files /dev/null and b/stream3r/models/__pycache__/stream3r.cpython-311.pyc differ diff --git a/stream3r/models/components/aggregator/__pycache__/streamaggregator.cpython-311.pyc b/stream3r/models/components/aggregator/__pycache__/streamaggregator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c028a4798305177d7db913f7a315b516f296d04 Binary files /dev/null and b/stream3r/models/components/aggregator/__pycache__/streamaggregator.cpython-311.pyc differ diff --git a/stream3r/models/components/aggregator/streamaggregator.py b/stream3r/models/components/aggregator/streamaggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..51608a08aced6d0dc84794741235afc15fbdd172 --- /dev/null +++ b/stream3r/models/components/aggregator/streamaggregator.py @@ -0,0 +1,379 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import torch +import torch.nn as nn +from typing import Tuple, List +from torch.utils.checkpoint import checkpoint + +from stream3r.models.components.layers import PatchEmbed +from stream3r.models.components.layers.block import Block +from stream3r.models.components.layers.rope import RotaryPositionEmbedding2D, PositionGetter +from stream3r.models.components.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 + +logger = logging.getLogger(__name__) + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +class STreamAggregator(nn.Module): + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + num_register_tokens=4, + block_fn=Block, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + patch_embed="dinov2_vitl14_reg", + aa_order=["frame", "global"], + aa_block_size=1, + qk_norm=True, + rope_freq=100, + init_values=0.01, + ): + super().__init__() + + self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim) + + # Initialize rotary position embedding if frequency > 0 + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + self.position_getter = PositionGetter() if self.rope is not None else None + + self.frame_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.global_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.depth = depth + self.aa_order = aa_order + self.patch_size = patch_size + self.aa_block_size = aa_block_size + + # Validate that depth is divisible by aa_block_size + if self.depth % self.aa_block_size != 0: + raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})") + + self.aa_block_num = self.depth // self.aa_block_size + + # Note: We have two camera tokens, one for the first frame and one for the rest + # The same applies for register tokens + self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) + self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim)) + + # The patch tokens start after the camera and register tokens + self.patch_start_idx = 1 + num_register_tokens + + # Initialize parameters with small values + nn.init.normal_(self.camera_token, std=1e-6) + nn.init.normal_(self.register_token, std=1e-6) + + # Register normalization constants as buffers + for name, value in ( + ("_resnet_mean", _RESNET_MEAN), + ("_resnet_std", _RESNET_STD), + ): + self.register_buffer( + name, + torch.FloatTensor(value).view(1, 1, 3, 1, 1), + persistent=False, + ) + + def __build_patch_embed__( + self, + patch_embed, + img_size, + patch_size, + num_register_tokens, + interpolate_antialias=True, + interpolate_offset=0.0, + block_chunks=0, + init_values=1.0, + embed_dim=1024, + ): + """ + Build the patch embed layer. If 'conv', we use a + simple PatchEmbed conv layer. Otherwise, we use a vision transformer. + """ + + if "conv" in patch_embed: + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim) + else: + vit_models = { + "dinov2_vitl14_reg": vit_large, + "dinov2_vitb14_reg": vit_base, + "dinov2_vits14_reg": vit_small, + "dinov2_vitg2_reg": vit_giant2, + } + + self.patch_embed = vit_models[patch_embed]( + img_size=img_size, + patch_size=patch_size, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + block_chunks=block_chunks, + init_values=init_values, + ) + + # Disable gradient updates for mask token + if hasattr(self.patch_embed, "mask_token"): + self.patch_embed.mask_token.requires_grad_(False) + + def _create_attn_mask(self, S: int, P: int, mode: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + N = S * P + mask = torch.zeros((N, N), dtype=dtype, device=device) + + if mode == "causal": + for i in range(S): + curr_view_start = i * P + curr_view_end = (i + 1) * P + mask[curr_view_start:curr_view_end, curr_view_end:] = float('-inf') + elif mode == "window": + window_size = 5 + for i in range(S): + curr_view_start = i * P + curr_view_end = (i + 1) * P + mask[curr_view_start:curr_view_end, P:] = float('-inf') + start_view = max(1, i - window_size + 1) + mask[curr_view_start:curr_view_end, start_view*P:(i+1)*P] = 0 + elif mode == "full": + mask = None + else: + raise NotImplementedError(f"Unknown attention mode: {mode}") + + return mask + + def forward( + self, + images: torch.Tensor, + mode: str = "causal", + kv_cache_list: List[List[torch.Tensor]] = None + ) -> Tuple[List[torch.Tensor], int]: + """ + Args: + images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + mode (str): Global attention mode, could be either "causal", "window" or "full" + kv_cache_list (List[List[torch.Tensor]]): List of cached key-value pairs for + each global attention layer of the aggregator + + Returns: + (list[torch.Tensor], int): + The list of outputs from the attention blocks, + and the patch_start_idx indicating where patch tokens begin. + """ + B, S, C_in, H, W = images.shape + + if C_in != 3: + raise ValueError(f"Expected 3 input channels, got {C_in}") + + # Normalize images and reshape for patch embed + images = (images - self._resnet_mean) / self._resnet_std + + # Reshape to [B*S, C, H, W] for patch embedding + images = images.view(B * S, C_in, H, W) + + patch_tokens = self.patch_embed(images) + + if isinstance(patch_tokens, dict): + patch_tokens = patch_tokens["x_norm_patchtokens"] + + _, P, C = patch_tokens.shape + + # Expand camera and register tokens to match batch size and sequence length + is_anchor_exist = kv_cache_list is None or kv_cache_list[0][0] is None + camera_token = slice_expand_and_flatten(self.camera_token, B, S, is_anchor_exist=is_anchor_exist) + register_token = slice_expand_and_flatten(self.register_token, B, S, is_anchor_exist=is_anchor_exist) + + # Concatenate special tokens with patch tokens + tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) + + pos = None + if self.rope is not None: + pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + 1 + pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype) + pos = torch.cat([pos_special, pos], dim=1) + + # update P because we added special tokens + _, P, C = tokens.shape + + attn_mask = None + if kv_cache_list is None: + attn_mask = self._create_attn_mask(S, P, mode, tokens.dtype, tokens.device) + + frame_idx = 0 + global_idx = 0 + output_list = [] + + for _ in range(self.aa_block_num): + for attn_type in self.aa_order: + if attn_type == "frame": + tokens, frame_idx, frame_intermediates = self._process_frame_attention( + tokens, B, S, P, C, frame_idx, pos=pos + ) + elif attn_type == "global": + if kv_cache_list is not None: + kv_cache = kv_cache_list[global_idx] + tokens, global_idx, global_intermediates, kv_cache = self._process_global_attention( + tokens, B, S, P, C, global_idx, pos=pos, attn_mask=attn_mask, kv_cache=kv_cache + ) + kv_cache_list[global_idx-1] = kv_cache + else: + tokens, global_idx, global_intermediates = self._process_global_attention( + tokens, B, S, P, C, global_idx, pos=pos, attn_mask=attn_mask + ) + else: + raise ValueError(f"Unknown attention type: {attn_type}") + + for i in range(len(frame_intermediates)): + # concat frame and global intermediates, [B x S x P x 2C] + concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) + output_list.append(concat_inter) + + del concat_inter + del frame_intermediates + del global_intermediates + + if kv_cache_list is not None: + return output_list, self.patch_start_idx, kv_cache_list + else: + return output_list, self.patch_start_idx + + def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None): + """ + Process frame attention blocks. We keep tokens in shape (B*S, P, C). + """ + # If needed, reshape tokens or positions: + if tokens.shape != (B * S, P, C): + tokens = tokens.view(B, S, P, C).view(B * S, P, C) + + if pos is not None and pos.shape != (B * S, P, 2): + pos = pos.view(B, S, P, 2).view(B * S, P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + tokens = checkpoint( + self.frame_blocks[frame_idx], + tokens, + pos, + use_reentrant=False + ) + frame_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, frame_idx, intermediates + + def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None, attn_mask=None, kv_cache=None): + """ + Process global attention blocks. We keep tokens in shape (B, S*P, C). + """ + if tokens.shape != (B, S * P, C): + tokens = tokens.view(B, S, P, C).view(B, S * P, C) + + if pos is not None and pos.shape != (B, S * P, 2): + pos = pos.view(B, S, P, 2).view(B, S * P, 2) + + intermediates = [] + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + if kv_cache is not None: + tokens, kv_cache = checkpoint( + self.global_blocks[global_idx], + tokens, + pos, + attn_mask, + kv_cache, + use_reentrant=False + ) + else: + tokens = checkpoint( + self.global_blocks[global_idx], + tokens, + pos, + attn_mask, + use_reentrant=False + ) + global_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + if kv_cache is not None: + return tokens, global_idx, intermediates, kv_cache + + return tokens, global_idx, intermediates + + +def slice_expand_and_flatten(token_tensor, B, S, is_anchor_exist=False): + """ + Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: + 1) Uses the first position (index=0) for the first frame only + 2) Uses the second position (index=1) for all remaining frames (S-1 frames) + 3) Expands both to match batch size B + 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token + followed by (S-1) second-position tokens + 5) Flattens to (B*S, X, C) for processing + + Returns: + torch.Tensor: Processed tokens with shape (B*S, X, C) + """ + + # Slice out the "query" tokens => shape (1, 1, ...) + if is_anchor_exist: + query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) + else: + query = token_tensor[:, 1:, ...].expand(B, 1, *token_tensor.shape[2:]) + # Slice out the "other" tokens => shape (1, S-1, ...) + others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) + # Concatenate => shape (B, S, ...) + combined = torch.cat([query, others], dim=1) + + # Finally flatten => shape (B*S, ...) + combined = combined.view(B * S, *combined.shape[2:]) + return combined \ No newline at end of file diff --git a/stream3r/models/components/heads/__pycache__/camera_head.cpython-311.pyc b/stream3r/models/components/heads/__pycache__/camera_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a68a003dcca1fd5f4a80094cb5ed9ab0f789c792 Binary files /dev/null and b/stream3r/models/components/heads/__pycache__/camera_head.cpython-311.pyc differ diff --git a/stream3r/models/components/heads/__pycache__/dpt_head.cpython-311.pyc b/stream3r/models/components/heads/__pycache__/dpt_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e9c71113e61e58bd7a3cb4f36b53e740d9f6290 Binary files /dev/null and b/stream3r/models/components/heads/__pycache__/dpt_head.cpython-311.pyc differ diff --git a/stream3r/models/components/heads/__pycache__/head_act.cpython-311.pyc b/stream3r/models/components/heads/__pycache__/head_act.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca8980123948aa3f9063e356947561d299432d53 Binary files /dev/null and b/stream3r/models/components/heads/__pycache__/head_act.cpython-311.pyc differ diff --git a/stream3r/models/components/heads/__pycache__/utils.cpython-311.pyc b/stream3r/models/components/heads/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5435b15c03a49b7a5cdb14cac0b1933affc752e0 Binary files /dev/null and b/stream3r/models/components/heads/__pycache__/utils.cpython-311.pyc differ diff --git a/stream3r/models/components/heads/camera_head.py b/stream3r/models/components/heads/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..499b3e56102d0d9ae9dc4b2916ce05508adbfffd --- /dev/null +++ b/stream3r/models/components/heads/camera_head.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple + +import torch +import torch.nn as nn + +from stream3r.models.components.layers import Mlp +from stream3r.models.components.layers.block import Block +from stream3r.models.components.heads.head_act import activate_pose + + +class CameraHead(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_encoding_type: str = "absT_quaR_FoV", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. + ): + super().__init__() + + if pose_encoding_type == "absT_quaR_FoV": + self.target_dim = 9 + else: + raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + + # Build the trunk using a sequence of transformer blocks. + self.trunk = nn.Sequential( + *[ + Block( + dim=dim_in, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + init_values=init_values, + ) + for _ in range(trunk_depth) + ] + ) + + # Normalizations for camera token and trunk output. + self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + # Learnable empty camera pose token. + self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + # Module for producing modulation parameters: shift, scale, and a gate. + self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) + + # Adaptive layer normalization without affine parameters. + self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + self.pose_branch = Mlp( + in_features=dim_in, + hidden_features=dim_in // 2, + out_features=self.target_dim, + drop=0, + ) + + def _create_attn_mask(self, S: int, mode: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + N = S + mask = torch.zeros((N, N), dtype=dtype, device=device) + + if mode == "causal": + for i in range(S): + curr_view_start = i + curr_view_end = (i + 1) + mask[curr_view_start:curr_view_end, curr_view_end:] = float('-inf') + elif mode == "window": + window_size = 5 + for i in range(S): + curr_view_start = i + curr_view_end = (i + 1) + mask[curr_view_start:curr_view_end, 1:] = float('-inf') + start_view = max(1, i - window_size + 1) + mask[curr_view_start:curr_view_end, start_view:(i+1)] = 0 + elif mode == "full": + mask = None + else: + raise NotImplementedError(f"Unknown attention mode: {mode}") + + return mask + + def forward( + self, + aggregated_tokens_list: list, + num_iterations: int = 4, + mode: str = "causal", + kv_cache_list: List[List[List[torch.Tensor]]] = None + ) -> list: + """ + Forward pass to predict camera parameters. + + Args: + aggregated_tokens_list (list): List of token tensors from the network; + the last tensor is used for prediction. + num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. + mode (str): Global attention mode, could be either "causal", "window" or "full" + kv_cache_list (List[List[List[torch.Tensor]]]): List of cached key-value pairs for + each iterations and each attention layer of the camera head + + Returns: + list: A list of predicted camera encodings (post-activation) from each iteration. + """ + # Use tokens from the last block for camera prediction. + tokens = aggregated_tokens_list[-1] + + # Extract the camera tokens + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + B, S, C = pose_tokens.shape + attn_mask = None + if kv_cache_list is None: + attn_mask = self._create_attn_mask(S, mode, pose_tokens.dtype, pose_tokens.device) + + pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations, attn_mask, kv_cache_list) + return pred_pose_enc_list + + def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int, attn_mask: torch.Tensor, kv_cache_list: List[Tuple[torch.Tensor, torch.Tensor]] = None) -> list: + """ + Iteratively refine camera pose predictions. + + Args: + pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C]. + num_iterations (int): Number of refinement iterations. + + Returns: + list: List of activated camera encodings from each iteration. + """ + B, S, C = pose_tokens.shape + pred_pose_enc = None + pred_pose_enc_list = [] + + for iter in range(num_iterations): + # Use a learned empty pose for the first iteration. + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + # Detach the previous prediction to avoid backprop through time. + pred_pose_enc = pred_pose_enc.detach() + module_input = self.embed_pose(pred_pose_enc) + + # Generate modulation parameters and split them into shift, scale, and gate components. + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) + + # Adaptive layer normalization and modulation. + pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) + pose_tokens_modulated = pose_tokens_modulated + pose_tokens + + for i in range(self.trunk_depth): + if kv_cache_list is not None: + pose_tokens_modulated, kv_cache_list[iter][i] = self.trunk[i](pose_tokens_modulated, attn_mask=attn_mask, kv_cache=kv_cache_list[iter][i]) + else: + pose_tokens_modulated = self.trunk[i](pose_tokens_modulated, attn_mask=attn_mask) + + # Compute the delta update for the pose encoding. + pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + # Apply final activation functions for translation, quaternion, and field-of-view. + activated_pose = activate_pose( + pred_pose_enc, + trans_act=self.trans_act, + quat_act=self.quat_act, + fl_act=self.fl_act, + ) + pred_pose_enc_list.append(activated_pose) + + if kv_cache_list is not None: + return pred_pose_enc_list, kv_cache_list + else: + return pred_pose_enc_list + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Modulate the input tensor using scaling and shifting parameters. + """ + # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 + return x * (1 + scale) + shift diff --git a/stream3r/models/components/heads/dpt_head.py b/stream3r/models/components/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..555848340278724a702c0a2c9855a179d7eeaaa3 --- /dev/null +++ b/stream3r/models/components/heads/dpt_head.py @@ -0,0 +1,497 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 + + +import os +from typing import List, Dict, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .head_act import activate_head +from .utils import create_uv_grid, position_grid_to_embed + + +class DPTHead(nn.Module): + """ + DPT Head for dense prediction tasks. + + This implementation follows the architecture described in "Vision Transformers for Dense Prediction" + (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer + backbone and produces dense predictions by fusing multi-scale features. + + Args: + dim_in (int): Input dimension (channels). + patch_size (int, optional): Patch size. Default is 14. + output_dim (int, optional): Number of output channels. Default is 4. + activation (str, optional): Activation type. Default is "inv_log". + conf_activation (str, optional): Confidence activation type. Default is "expp1". + features (int, optional): Feature channels for intermediate representations. Default is 256. + out_channels (List[int], optional): Output channels for each intermediate layer. + intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. + pos_embed (bool, optional): Whether to use positional embedding. Default is True. + feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. + down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. + """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [4, 11, 17, 23], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + # Projection layers for each output channel from tokens. + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=dim_in, + out_channels=oc, + kernel_size=1, + stride=1, + padding=0, + ) + for oc in out_channels + ] + ) + + # Resize layers for upsampling feature maps. + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + + self.scratch = _make_scratch( + out_channels, + features, + expand=False, + ) + + # Attach additional modules to scratch. + self.scratch.stem_transpose = None + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_chunk_size: int = 8, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass through the DPT head, supports processing by chunking frames. + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + patch_start_idx (int): Starting index for patch tokens in the token sequence. + Used to separate patch tokens from other tokens (e.g., camera or register tokens). + frames_chunk_size (int, optional): Number of frames to process in each chunk. + If None or larger than S, all frames are processed at once. Default: 8. + + Returns: + Tensor or Tuple[Tensor, Tensor]: + - If feature_only=True: Feature maps with shape [B, S, C, H, W] + - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] + """ + B, S, _, H, W = images.shape + + # If frames_chunk_size is not specified or greater than S, process all frames at once + if frames_chunk_size is None or frames_chunk_size >= S: + return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) + + # Otherwise, process frames in chunks to manage memory usage + assert frames_chunk_size > 0 + + # Process frames in batches + all_preds = [] + all_conf = [] + + for frames_start_idx in range(0, S, frames_chunk_size): + frames_end_idx = min(frames_start_idx + frames_chunk_size, S) + + # Process batch of frames + if self.feature_only: + chunk_output = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_output) + else: + chunk_preds, chunk_conf = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_preds) + all_conf.append(chunk_conf) + + # Concatenate results along the sequence dimension + if self.feature_only: + return torch.cat(all_preds, dim=1) + else: + return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Implementation of the forward pass through the DPT head. + + This method processes a specific chunk of frames from the sequence. + + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W]. + patch_start_idx (int): Starting index for patch tokens. + frames_start_idx (int, optional): Starting index for frames to process. + frames_end_idx (int, optional): Ending index for frames to process. + + Returns: + Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). + """ + if frames_start_idx is not None and frames_end_idx is not None: + images = images[:, frames_start_idx:frames_end_idx].contiguous() + + B, S, _, H, W = images.shape + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + # Select frames if processing a chunk + if frames_start_idx is not None and frames_end_idx is not None: + x = x[:, frames_start_idx:frames_end_idx] + + x = x.view(B * S, -1, x.shape[-1]) + + x = self.norm(x) + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + # Fuse features from multiple layers. + out = self.scratch_forward(out) + # Interpolate fused output to match target image resolution. + out = custom_interpolate( + out, + (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.view(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) + + preds = preds.view(B, S, *preds.shape[1:]) + conf = conf.view(B, S, *conf.shape[1:]) + return preds, conf + + def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Forward pass through the fusion blocks. + + Args: + features (List[Tensor]): List of feature maps from different layers. + + Returns: + Tensor: Fused feature map. + """ + layer_1, layer_2, layer_3, layer_4 = features + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + del layer_4_rn, layer_4 + + out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) + del layer_3_rn, layer_3 + + out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) + del layer_2_rn, layer_2 + + out = self.scratch.refinenet1(out, layer_1_rn) + del layer_1_rn, layer_1 + + out = self.scratch.output_conv1(out) + return out + + +################################################################################ +# Modules +################################################################################ + + +def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. + """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) diff --git a/stream3r/models/components/heads/head_act.py b/stream3r/models/components/heads/head_act.py new file mode 100644 index 0000000000000000000000000000000000000000..2dedfcf1180a653dddc99623e60df625e5897489 --- /dev/null +++ b/stream3r/models/components/heads/head_act.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): + """ + Activate pose parameters with specified activation functions. + + Args: + pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] + trans_act: Activation type for translation component + quat_act: Activation type for quaternion component + fl_act: Activation type for focal length component + + Returns: + Activated pose parameters tensor + """ + T = pred_pose_enc[..., :3] + quat = pred_pose_enc[..., 3:7] + fl = pred_pose_enc[..., 7:] # or fov + + T = base_pose_act(T, trans_act) + quat = base_pose_act(quat, quat_act) + fl = base_pose_act(fl, fl_act) # or fov + + pred_pose_enc = torch.cat([T, quat, fl], dim=-1) + + return pred_pose_enc + + +def base_pose_act(pose_enc, act_type="linear"): + """ + Apply basic activation function to pose parameters. + + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. + + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + # Move channels from last dim to the 4th dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/stream3r/models/components/heads/utils.py b/stream3r/models/components/heads/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..533fc8ae67a75cd0a94d5ca96dc5a0513446c64f --- /dev/null +++ b/stream3r/models/components/heads/utils.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + device = pos.device + omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. + """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid diff --git a/stream3r/models/components/layers/__init__.py b/stream3r/models/components/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1 --- /dev/null +++ b/stream3r/models/components/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/stream3r/models/components/layers/__pycache__/__init__.cpython-311.pyc b/stream3r/models/components/layers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c65dd6934cdcf910f7e5c638f8381e5dc31d6a0a Binary files /dev/null and b/stream3r/models/components/layers/__pycache__/__init__.cpython-311.pyc differ diff --git a/stream3r/models/components/layers/__pycache__/attention.cpython-311.pyc b/stream3r/models/components/layers/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf2ab498332f862f4181eeb81f8530431bbf43cd Binary files /dev/null and b/stream3r/models/components/layers/__pycache__/attention.cpython-311.pyc differ diff --git a/stream3r/models/components/layers/__pycache__/block.cpython-311.pyc b/stream3r/models/components/layers/__pycache__/block.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1020e1e081ae3e424fe0cc05c26eaee5640bb7c7 Binary files /dev/null and b/stream3r/models/components/layers/__pycache__/block.cpython-311.pyc differ diff --git a/stream3r/models/components/layers/__pycache__/drop_path.cpython-311.pyc b/stream3r/models/components/layers/__pycache__/drop_path.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08a59a79aa097d9d0a8de67d6f755bcd5603cab6 Binary files /dev/null and b/stream3r/models/components/layers/__pycache__/drop_path.cpython-311.pyc differ diff --git a/stream3r/models/components/layers/__pycache__/layer_scale.cpython-311.pyc b/stream3r/models/components/layers/__pycache__/layer_scale.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61907f54e9cc6c53ea615c43b8ca16287b5429b6 Binary files /dev/null and b/stream3r/models/components/layers/__pycache__/layer_scale.cpython-311.pyc differ diff --git a/stream3r/models/components/layers/__pycache__/mlp.cpython-311.pyc b/stream3r/models/components/layers/__pycache__/mlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63362e9319c321a3fd8c82beaeafd614155bf7d4 Binary files /dev/null and b/stream3r/models/components/layers/__pycache__/mlp.cpython-311.pyc differ diff --git a/stream3r/models/components/layers/__pycache__/patch_embed.cpython-311.pyc b/stream3r/models/components/layers/__pycache__/patch_embed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21e175ddbf41669b3132646b291ad603af159df7 Binary files /dev/null and b/stream3r/models/components/layers/__pycache__/patch_embed.cpython-311.pyc differ diff --git a/stream3r/models/components/layers/__pycache__/rope.cpython-311.pyc b/stream3r/models/components/layers/__pycache__/rope.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ab5e08310ccd396fa3200c5bd3dce02e26e8a92 Binary files /dev/null and b/stream3r/models/components/layers/__pycache__/rope.cpython-311.pyc differ diff --git a/stream3r/models/components/layers/__pycache__/swiglu_ffn.cpython-311.pyc b/stream3r/models/components/layers/__pycache__/swiglu_ffn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c43aded34f79981ac55018b9c3c63066dd7c0916 Binary files /dev/null and b/stream3r/models/components/layers/__pycache__/swiglu_ffn.cpython-311.pyc differ diff --git a/stream3r/models/components/layers/__pycache__/vision_transformer.cpython-311.pyc b/stream3r/models/components/layers/__pycache__/vision_transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e32aa75b04fcc488793d9952e538264c9d4c2080 Binary files /dev/null and b/stream3r/models/components/layers/__pycache__/vision_transformer.cpython-311.pyc differ diff --git a/stream3r/models/components/layers/attention.py b/stream3r/models/components/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..46d4d32a5f940ea44750bc884d45d269e44bbc34 --- /dev/null +++ b/stream3r/models/components/layers/attention.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import torch +from torch import Tensor +from torch import nn +import torch.nn.functional as F + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None, attn_mask=None, kv_cache=None) -> Tensor: + B, N, C = x.shape + + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + q, k = self.q_norm(q), self.k_norm(k) + + if self.rope is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + if kv_cache is not None: + k_cache, v_cache = kv_cache + if k_cache is not None and v_cache is not None: + k = torch.cat([k_cache, k], dim=2) + v = torch.cat([v_cache, v], dim=2) + kv_cache = [k, v] + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + attn_mask=attn_mask, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if attn_mask is not None: + attn = attn + attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + + x = self.proj(x) + + x = self.proj_drop(x) + + if kv_cache is not None: + return x, kv_cache + + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: + assert pos is None + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/stream3r/models/components/layers/block.py b/stream3r/models/components/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..cef5b78b6c90a5b9be5f5a8fb909bf1c9a97c3a0 --- /dev/null +++ b/stream3r/models/components/layers/block.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + fused_attn=fused_attn, + rope=rope, + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None, attn_mask=None, kv_cache=None) -> Tensor: + def attn_residual_func(x: Tensor, pos=None, attn_mask=None, kv_cache=None) -> Tensor: + if kv_cache is not None: + x, kv_cache = self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask, kv_cache=kv_cache) + return self.ls1(x), kv_cache + elif attn_mask is not None: + return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask)) + else: + return self.ls1(self.attn(self.norm1(x), pos=pos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + assert attn_mask is None and kv_cache is None, "attn_mask and kv_cache are not supported for drop_add_residual_stochastic_depth yet" + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + pos=pos, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + if kv_cache is not None: + delta_x, kv_cache = attn_residual_func(x, pos=pos, attn_mask=attn_mask, kv_cache=kv_cache) + else: + delta_x = attn_residual_func(x, pos=pos, attn_mask=attn_mask) + x = x + self.drop_path1(delta_x) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + if kv_cache is not None: + delta_x, kv_cache = attn_residual_func(x, pos=pos, attn_mask=attn_mask, kv_cache=kv_cache) + else: + delta_x = attn_residual_func(x, pos=pos, attn_mask=attn_mask) + x = x + delta_x + x = x + ffn_residual_func(x) + + if kv_cache is not None: + return x, kv_cache + else: + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, + pos=None, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + if pos is not None: + # if necessary, apply rope to the subset + pos = pos[brange] + residual = residual_func(x_subset, pos=pos) + else: + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/stream3r/models/components/layers/drop_path.py b/stream3r/models/components/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/stream3r/models/components/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/stream3r/models/components/layers/layer_scale.py b/stream3r/models/components/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/stream3r/models/components/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/stream3r/models/components/layers/mlp.py b/stream3r/models/components/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/stream3r/models/components/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/stream3r/models/components/layers/patch_embed.py b/stream3r/models/components/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/stream3r/models/components/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/stream3r/models/components/layers/rope.py b/stream3r/models/components/layers/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5d33304e55dbd05687bd86752a47a80e5f82df --- /dev/null +++ b/stream3r/models/components/layers/rope.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +# Implementation of 2D Rotary Position Embeddings (RoPE). + +# This module provides a clean implementation of 2D Rotary Position Embeddings, +# which extends the original RoPE concept to handle 2D spatial positions. + +# Inspired by: +# https://github.com/meta-llama/codellama/blob/main/llama/model.py +# https://github.com/naver-ai/rope-vit + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Tuple + + +class PositionGetter: + """Generates and caches 2D spatial positions for patches in a grid. + + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. + + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. + """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + # Compute frequency bands + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency**exponents) + + # Generate position-dependent frequencies + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + # Compute and cache frequency components + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. + """ + # Embed positions with frequency components + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + + # Apply rotation + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + # Validate inputs + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" + + # Compute feature dimension for each spatial direction + feature_dim = tokens.size(-1) // 2 + + # Get frequency components + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) + + # Split features for vertical and horizontal processing + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + # Apply RoPE separately for each dimension + vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) + horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) + + # Combine processed features + return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/stream3r/models/components/layers/swiglu_ffn.py b/stream3r/models/components/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..54fe8e90b7bedf6fbdbf09c6215844e3cc63f857 --- /dev/null +++ b/stream3r/models/components/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +# try: +# if XFORMERS_ENABLED: +# from xformers.ops import SwiGLU + +# XFORMERS_AVAILABLE = True +# warnings.warn("xFormers is available (SwiGLU)") +# else: +# warnings.warn("xFormers is disabled (SwiGLU)") +# raise ImportError +# except ImportError: +SwiGLU = SwiGLUFFN +XFORMERS_AVAILABLE = False + +# warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/stream3r/models/components/layers/vision_transformer.py b/stream3r/models/components/layers/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..120cbe6c26650d212e50aefc497669abdc937467 --- /dev/null +++ b/stream3r/models/components/layers/vision_transformer.py @@ -0,0 +1,407 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.nn.init import trunc_normal_ +from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + qk_norm=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + # tricky but makes it work + self.use_checkpoint = False + # + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=True, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc b/stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cd8f52fdaf392844af85c63da1d3ba2ba5204d1 Binary files /dev/null and b/stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc differ diff --git a/stream3r/models/components/utils/__pycache__/load_fn.cpython-311.pyc b/stream3r/models/components/utils/__pycache__/load_fn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62ea113d189ce979408ca2998c2d5c5fc078e792 Binary files /dev/null and b/stream3r/models/components/utils/__pycache__/load_fn.cpython-311.pyc differ diff --git a/stream3r/models/components/utils/__pycache__/pose_enc.cpython-311.pyc b/stream3r/models/components/utils/__pycache__/pose_enc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa6bf7eebff58b46c3b82f1f86d7854e8e26ae54 Binary files /dev/null and b/stream3r/models/components/utils/__pycache__/pose_enc.cpython-311.pyc differ diff --git a/stream3r/models/components/utils/__pycache__/rotation.cpython-311.pyc b/stream3r/models/components/utils/__pycache__/rotation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95dda6921ef7a47a5b61312c2c0775afc2b63ed3 Binary files /dev/null and b/stream3r/models/components/utils/__pycache__/rotation.cpython-311.pyc differ diff --git a/stream3r/models/components/utils/geometry.py b/stream3r/models/components/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..8ebd25dbc6cac6b0095956524c4f0628410dd5cb --- /dev/null +++ b/stream3r/models/components/utils/geometry.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +import numpy as np + + +def unproject_depth_map_to_point_map( + depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray +) -> np.ndarray: + """ + Unproject a batch of depth maps to 3D world coordinates. + + Args: + depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) + extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) + intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) + + Returns: + np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) + """ + if isinstance(depth_map, torch.Tensor): + depth_map = depth_map.cpu().numpy() + if isinstance(extrinsics_cam, torch.Tensor): + extrinsics_cam = extrinsics_cam.cpu().numpy() + if isinstance(intrinsics_cam, torch.Tensor): + intrinsics_cam = intrinsics_cam.cpu().numpy() + + world_points_list = [] + for frame_idx in range(depth_map.shape[0]): + cur_world_points, _, _ = depth_to_world_coords_points( + depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] + ) + world_points_list.append(cur_world_points) + world_points_array = np.stack(world_points_list, axis=0) + + return world_points_array + + +def depth_to_world_coords_points( + depth_map: np.ndarray, + extrinsic: np.ndarray, + intrinsic: np.ndarray, + eps=1e-8, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Convert a depth map to world coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. + + Returns: + tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). + """ + if depth_map is None: + return None, None, None + + # Valid depth mask + point_mask = depth_map > eps + + # Convert depth map to camera coordinates + cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) + + # Multiply with the inverse of extrinsic matrix to transform to world coordinates + # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) + cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] + + R_cam_to_world = cam_to_world_extrinsic[:3, :3] + t_cam_to_world = cam_to_world_extrinsic[:3, 3] + + # Apply the rotation and translation to the camera coordinates + world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 + # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world + + return world_coords_points, cam_coords_points, point_mask + + +def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Convert a depth map to camera coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + + Returns: + tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) + """ + H, W = depth_map.shape + assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" + assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" + + # Intrinsic parameters + fu, fv = intrinsic[0, 0], intrinsic[1, 1] + cu, cv = intrinsic[0, 2], intrinsic[1, 2] + + # Generate grid of pixel coordinates + u, v = np.meshgrid(np.arange(W), np.arange(H)) + + # Unproject to camera coordinates + x_cam = (u - cu) * depth_map / fu + y_cam = (v - cv) * depth_map / fv + z_cam = depth_map + + # Stack to form camera coordinates + cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + return cam_coords + + +def closed_form_inverse_se3(se3, R=None, T=None): + """ + Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. + + If `R` and `T` are provided, they must correspond to the rotation and translation + components of `se3`. Otherwise, they will be extracted from `se3`. + + Args: + se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. + R (optional): Nx3x3 array or tensor of rotation matrices. + T (optional): Nx3x1 array or tensor of translation vectors. + + Returns: + Inverted SE3 matrices with the same type and device as `se3`. + + Shapes: + se3: (N, 4, 4) + R: (N, 3, 3) + T: (N, 3, 1) + """ + # Check if se3 is a numpy array or a torch tensor + is_numpy = isinstance(se3, np.ndarray) + + # Validate shapes + if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): + raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") + + # Extract R and T if not provided + if R is None: + R = se3[:, :3, :3] # (N,3,3) + if T is None: + T = se3[:, :3, 3:] # (N,3,1) + + # Transpose R + if is_numpy: + # Compute the transpose of the rotation for NumPy + R_transposed = np.transpose(R, (0, 2, 1)) + # -R^T t for NumPy + top_right = -np.matmul(R_transposed, T) + inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) + else: + R_transposed = R.transpose(1, 2) # (N,3,3) + top_right = -torch.bmm(R_transposed, T) # (N,3,1) + inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) + inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) + + inverted_matrix[:, :3, :3] = R_transposed + inverted_matrix[:, :3, 3:] = top_right + + return inverted_matrix diff --git a/stream3r/models/components/utils/load_fn.py b/stream3r/models/components/utils/load_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..d786e98a950f880342da9a13664be4fa32eb0bfa --- /dev/null +++ b/stream3r/models/components/utils/load_fn.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from PIL import Image +from torchvision import transforms as TF + + +def load_and_preprocess_images(image_path_list, mode="crop"): + """ + A quick start function to load and preprocess images for model input. + This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. + + Args: + image_path_list (list): List of paths to image files + mode (str, optional): Preprocessing mode, either "crop" or "pad". + - "crop" (default): Sets width to 518px and center crops height if needed. + - "pad": Preserves all pixels by making the largest dimension 518px + and padding the smaller dimension to reach a square shape. + + Returns: + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) + + Raises: + ValueError: If the input list is empty or if mode is invalid + + Notes: + - Images with different dimensions will be padded with white (value=1.0) + - A warning is printed when images have different shapes + - When mode="crop": The function ensures width=518px while maintaining aspect ratio + and height is center-cropped if larger than 518px + - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio + and the smaller dimension is padded to reach a square shape (518x518) + - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements + """ + # Check for empty list + if len(image_path_list) == 0: + raise ValueError("At least 1 image is required") + + # Validate mode + if mode not in ["crop", "pad"]: + raise ValueError("Mode must be either 'crop' or 'pad'") + + images = [] + shapes = set() + to_tensor = TF.ToTensor() + target_size = 518 + + # First process all images and collect their shapes + for image_path in image_path_list: + + # Open image + img = Image.open(image_path) + + # If there's an alpha channel, blend onto white background: + if img.mode == "RGBA": + # Create white background + background = Image.new("RGBA", img.size, (255, 255, 255, 255)) + # Alpha composite onto the white background + img = Image.alpha_composite(background, img) + + # Now convert to "RGB" (this step assigns white for transparent areas) + img = img.convert("RGB") + + width, height = img.size + + if mode == "pad": + # Make the largest dimension 518px while maintaining aspect ratio + if width >= height: + new_width = target_size + new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14 + else: + new_height = target_size + new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14 + else: # mode == "crop" + # Original behavior: set width to 518px + new_width = target_size + # Calculate height maintaining aspect ratio, divisible by 14 + new_height = round(height * (new_width / width) / 14) * 14 + + # Resize with new dimensions (width, height) + img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) + img = to_tensor(img) # Convert to tensor (0, 1) + + # Center crop height if it's larger than 518 (only in crop mode) + if mode == "crop" and new_height > target_size: + start_y = (new_height - target_size) // 2 + img = img[:, start_y : start_y + target_size, :] + + # For pad mode, pad to make a square of target_size x target_size + if mode == "pad": + h_padding = target_size - img.shape[1] + w_padding = target_size - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + # Pad with white (value=1.0) + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + + shapes.add((img.shape[1], img.shape[2])) + images.append(img) + + # Check if we have different shapes + # In theory our model can also work well with different shapes + if len(shapes) > 1: + print(f"Warning: Found images with different shapes: {shapes}") + # Find maximum dimensions + max_height = max(shape[0] for shape in shapes) + max_width = max(shape[1] for shape in shapes) + + # Pad images if necessary + padded_images = [] + for img in images: + h_padding = max_height - img.shape[1] + w_padding = max_width - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + padded_images.append(img) + images = padded_images + + images = torch.stack(images) # concatenate images + + # Ensure correct shape when single image + if len(image_path_list) == 1: + # Verify shape is (1, C, H, W) + if images.dim() == 3: + images = images.unsqueeze(0) + + return images diff --git a/stream3r/models/components/utils/pose_enc.py b/stream3r/models/components/utils/pose_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..8fc8d565cacc9506f4cb2fb9f297f65ceb525175 --- /dev/null +++ b/stream3r/models/components/utils/pose_enc.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from .rotation import quat_to_mat, mat_to_quat + + +def extri_intri_to_pose_encoding( + extrinsics, + intrinsics, + image_size_hw=None, # e.g., (256, 512) + pose_encoding_type="absT_quaR_FoV", + gt_pts3d_scale=None, +): + """Convert camera extrinsics and intrinsics to a compact pose encoding. + + This function transforms camera parameters into a unified pose encoding format, + which can be used for various downstream tasks like pose prediction or representation. + + Args: + extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, + where B is batch size and S is sequence length. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. + The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. + intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. + Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for computing field of view values. For example: (256, 512). + pose_encoding_type (str): Type of pose encoding to use. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + + Returns: + torch.Tensor: Encoded camera pose parameters with shape BxSx9. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + """ + + # extrinsics: BxSx3x4 + # intrinsics: BxSx3x3 + + if pose_encoding_type == "absT_quaR_FoV": + R = extrinsics[:, :, :3, :3] # BxSx3x3 + T = extrinsics[:, :, :3, 3] # BxSx3 + + quat = mat_to_quat(R) + # Note the order of h and w here + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + elif pose_encoding_type == "relT_quaR_FoV": + R = extrinsics[:, :, :3, :3] # BxSx3x3 + T = extrinsics[:, :, :3, 3] # BxSx3 + + T = T / gt_pts3d_scale.view(-1, 1, 1) + + quat = mat_to_quat(R) + # Note the order of h and w here + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + else: + raise NotImplementedError + + return pose_encoding + + +def pose_encoding_to_extri_intri( + pose_encoding, + image_size_hw=None, # e.g., (256, 512) + pose_encoding_type="absT_quaR_FoV", + build_intrinsics=True, +): + """Convert a pose encoding back to camera extrinsics and intrinsics. + + This function performs the inverse operation of extri_intri_to_pose_encoding, + reconstructing the full camera parameters from the compact encoding. + + Args: + pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, + where B is batch size and S is sequence length. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for reconstructing intrinsics from field of view values. + For example: (256, 512). + pose_encoding_type (str): Type of pose encoding used. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. + If False, only extrinsics are returned and intrinsics will be None. + + Returns: + tuple: (extrinsics, intrinsics) + - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world + transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is + a 3x1 translation vector. + - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, + or None if build_intrinsics is False. Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point, + assumed to be at the center of the image (W/2, H/2). + """ + + intrinsics = None + + if pose_encoding_type == "absT_quaR_FoV": + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + if build_intrinsics: + H, W = image_size_hw + fy = (H / 2.0) / torch.tan(fov_h / 2.0) + fx = (W / 2.0) / torch.tan(fov_w / 2.0) + intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 + else: + raise NotImplementedError + + return extrinsics, intrinsics diff --git a/stream3r/models/components/utils/rotation.py b/stream3r/models/components/utils/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..657583e6915437c824c192d51939990b589a14fa --- /dev/null +++ b/stream3r/models/components/utils/rotation.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d + +import torch +import numpy as np +import torch.nn.functional as F + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) diff --git a/stream3r/models/components/utils/visual_track.py b/stream3r/models/components/utils/visual_track.py new file mode 100644 index 0000000000000000000000000000000000000000..796c114ccba00b5f7850e04b9444a6cd5c44b154 --- /dev/null +++ b/stream3r/models/components/utils/visual_track.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 +import torch +import numpy as np +import os + + +def color_from_xy(x, y, W, H, cmap_name="hsv"): + """ + Map (x, y) -> color in (R, G, B). + 1) Normalize x,y to [0,1]. + 2) Combine them into a single scalar c in [0,1]. + 3) Use matplotlib's colormap to convert c -> (R,G,B). + + You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). + """ + import matplotlib.cm + import matplotlib.colors + + x_norm = x / max(W - 1, 1) + y_norm = y / max(H - 1, 1) + # Simple combination: + c = (x_norm + y_norm) / 2.0 + + cmap = matplotlib.cm.get_cmap(cmap_name) + # cmap(c) -> (r,g,b,a) in [0,1] + rgba = cmap(c) + r, g, b = rgba[0], rgba[1], rgba[2] + return (r, g, b) # in [0,1], RGB order + + +def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"): + """ + Given all tracks in one sample (b), compute a (N,3) array of RGB color values + in [0,255]. The color is determined by the (x,y) position in the first + visible frame for each track. + + Args: + tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. + vis_mask_b: (S, N) boolean mask; if None, assume all are visible. + image_width, image_height: used for normalizing (x, y). + cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). + + Returns: + track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. + """ + S, N, _ = tracks_b.shape + track_colors = np.zeros((N, 3), dtype=np.uint8) + + if vis_mask_b is None: + # treat all as visible + vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) + + for i in range(N): + # Find first visible frame for track i + visible_frames = torch.where(vis_mask_b[:, i])[0] + if len(visible_frames) == 0: + # track is never visible; just assign black or something + track_colors[i] = (0, 0, 0) + continue + + first_s = int(visible_frames[0].item()) + # use that frame's (x,y) + x, y = tracks_b[first_s, i].tolist() + + # map (x,y) -> (R,G,B) in [0,1] + r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name) + # scale to [0,255] + r, g, b = int(r * 255), int(g * 255), int(b * 255) + track_colors[i] = (r, g, b) + + return track_colors + + +def visualize_tracks_on_images( + images, + tracks, + track_vis_mask=None, + out_dir="track_visuals_concat_by_xy", + image_format="CHW", # "CHW" or "HWC" + normalize_mode="[0,1]", + cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" + frames_per_row=4, # New parameter for grid layout + save_grid=True, # Flag to control whether to save the grid image +): + """ + Visualizes frames in a grid layout with specified frames per row. + Each track's color is determined by its (x,y) position + in the first visible frame (or frame 0 if always visible). + Finally convert the BGR result to RGB before saving. + Also saves each individual frame as a separate PNG file. + + Args: + images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. + tracks: torch.Tensor (S, N, 2), last dim = (x, y). + track_vis_mask: torch.Tensor (S, N) or None. + out_dir: folder to save visualizations. + image_format: "CHW" or "HWC". + normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 + cmap_name: a matplotlib colormap name for color_from_xy. + frames_per_row: number of frames to display in each row of the grid. + save_grid: whether to save all frames in one grid image. + + Returns: + None (saves images in out_dir). + """ + + if len(tracks.shape) == 4: + tracks = tracks.squeeze(0) + images = images.squeeze(0) + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.squeeze(0) + + import matplotlib + + matplotlib.use("Agg") # for non-interactive (optional) + + os.makedirs(out_dir, exist_ok=True) + + S = images.shape[0] + _, N, _ = tracks.shape # (S, N, 2) + + # Move to CPU + images = images.cpu().clone() + tracks = tracks.cpu().clone() + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.cpu().clone() + + # Infer H, W from images shape + if image_format == "CHW": + # e.g. images[s].shape = (3, H, W) + H, W = images.shape[2], images.shape[3] + else: + # e.g. images[s].shape = (H, W, 3) + H, W = images.shape[1], images.shape[2] + + # Pre-compute the color for each track i based on first visible position + track_colors_rgb = get_track_colors_by_position( + tracks, # shape (S, N, 2) + vis_mask_b=track_vis_mask if track_vis_mask is not None else None, + image_width=W, + image_height=H, + cmap_name=cmap_name, + ) + + # We'll accumulate each frame's drawn image in a list + frame_images = [] + + for s in range(S): + # shape => either (3, H, W) or (H, W, 3) + img = images[s] + + # Convert to (H, W, 3) + if image_format == "CHW": + img = img.permute(1, 2, 0) # (H, W, 3) + # else "HWC", do nothing + + img = img.numpy().astype(np.float32) + + # Scale to [0,255] if needed + if normalize_mode == "[0,1]": + img = np.clip(img, 0, 1) * 255.0 + elif normalize_mode == "[-1,1]": + img = (img + 1.0) * 0.5 * 255.0 + img = np.clip(img, 0, 255.0) + # else no normalization + + # Convert to uint8 + img = img.astype(np.uint8) + + # For drawing in OpenCV, convert to BGR + img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + # Draw each visible track + cur_tracks = tracks[s] # shape (N, 2) + if track_vis_mask is not None: + valid_indices = torch.where(track_vis_mask[s])[0] + else: + valid_indices = range(N) + + cur_tracks_np = cur_tracks.numpy() + for i in valid_indices: + x, y = cur_tracks_np[i] + pt = (int(round(x)), int(round(y))) + + # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR + R, G, B = track_colors_rgb[i] + color_bgr = (int(B), int(G), int(R)) + cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) + + # Convert back to RGB for consistent final saving: + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + + # Save individual frame + frame_path = os.path.join(out_dir, f"frame_{s:04d}.png") + # Convert to BGR for OpenCV imwrite + frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) + cv2.imwrite(frame_path, frame_bgr) + + frame_images.append(img_rgb) + + # Only create and save the grid image if save_grid is True + if save_grid: + # Calculate grid dimensions + num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division + + # Create a grid of images + grid_img = None + for row in range(num_rows): + start_idx = row * frames_per_row + end_idx = min(start_idx + frames_per_row, S) + + # Concatenate this row horizontally + row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1) + + # If this row has fewer than frames_per_row images, pad with black + if end_idx - start_idx < frames_per_row: + padding_width = (frames_per_row - (end_idx - start_idx)) * W + padding = np.zeros((H, padding_width, 3), dtype=np.uint8) + row_img = np.concatenate([row_img, padding], axis=1) + + # Add this row to the grid + if grid_img is None: + grid_img = row_img + else: + grid_img = np.concatenate([grid_img, row_img], axis=0) + + out_path = os.path.join(out_dir, "tracks_grid.png") + # Convert back to BGR for OpenCV imwrite + grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(out_path, grid_img_bgr) + print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}") + + print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png") diff --git a/stream3r/models/multiview_dust3r_module.py b/stream3r/models/multiview_dust3r_module.py new file mode 100644 index 0000000000000000000000000000000000000000..5f291ee15ef7c09d31f92c8a80506294a0e5a412 --- /dev/null +++ b/stream3r/models/multiview_dust3r_module.py @@ -0,0 +1,343 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Optional, Tuple +import numpy as np +import re +import roma +import torch +from torch.distributed import all_gather_object, barrier +from lightning import LightningModule +from lightning.pytorch.loggers.wandb import WandbLogger +from torchmetrics import MaxMetric, MeanMetric, MinMetric, SumMetric, Metric +from torchmetrics.aggregation import BaseAggregator +from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR + +from stream3r.dust3r.model import FlashDUSt3R +from stream3r.models.stream3r import STream3R +from stream3r.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + +class AccumulatedSum(BaseAggregator): + def __init__( + self, + **kwargs: Any, + ) -> None: + super().__init__( + fn="sum", + default_value=torch.tensor(0.0, dtype=torch.long), + nan_strategy='warn', + state_name="sum_value", + **kwargs, + ) + + def update(self, value: int) -> None: + self.sum_value += value + + def compute(self) -> torch.LongTensor: + return self.sum_value + +class MultiViewDUSt3RLitModule(LightningModule): + def __init__( + self, + net: torch.nn.Module, + train_criterion: torch.nn.Module, + validation_criterion: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + compile: bool, + pretrained: Optional[str] = None, + resume_from_checkpoint: Optional[str] = None, + eval_use_pts3d_from_local_head: bool = True, + ) -> None: + super().__init__() + + self.save_hyperparameters(logger=False, ignore=['net', 'train_criterion', 'validation_criterion']) + + self.net = net + self.train_criterion = train_criterion + self.validation_criterion = validation_criterion + self.pretrained = pretrained + self.resume_from_checkpoint = resume_from_checkpoint + self.eval_use_pts3d_from_local_head = eval_use_pts3d_from_local_head + + # use register_buffer to save these with checkpoints + # so that when we resume training, these bookkeeping variables are preserved + self.register_buffer("epoch_fraction", torch.tensor(0.0, dtype=torch.float32, device=self.device)) + self.register_buffer("train_total_samples", torch.tensor(0, dtype=torch.long, device=self.device)) + self.register_buffer("train_total_images", torch.tensor(0, dtype=torch.long, device=self.device)) + + self.train_total_samples_per_step = AccumulatedSum() # these need to be reduced across GPUs, so use Metric + self.train_total_images_per_step = AccumulatedSum() # these need to be reduced across GPUs, so use Metric + + self.val_loss = MeanMetric() + + @classmethod + def load_for_inference(cls, net: STream3R): + lit_module = cls(net=net, train_criterion=None, validation_criterion=None, optimizer=None, scheduler=None, compile=False) + lit_module.eval() + return lit_module + + def forward(self, views: List[Dict[str, torch.Tensor]]) -> Any: + return self.net(views) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # Legacy: if the checkpoint does not contain the epoch_fraction, train_total_samples, and train_total_images + # we manually add them to the checkpoint + # if self.trainer.strategy.strategy_name != "deepseed": + # if checkpoint["state_dict"].get("epoch_fraction") is None: + # checkpoint["state_dict"]["epoch_fraction"] = self.epoch_fraction + # if checkpoint["state_dict"].get("train_total_samples") is None: + # checkpoint["state_dict"]["train_total_samples"] = self.train_total_samples + # if checkpoint["state_dict"].get("train_total_images") is None: + # checkpoint["state_dict"]["train_total_images"] = self.train_total_images + pass + + def on_train_start(self) -> None: + """Lightning hook that is called when training begins.""" + # by default lightning executes validation step sanity checks before training starts, + # so it's worth to make sure validation metrics don't store results from these checks + self.val_loss.reset() + + # the wandb logger lives in self.loggers + # find the wandb logger and watch the model and gradients + for logger in self.loggers: + if isinstance(logger, WandbLogger): + self.wandb_logger = logger + # log gradients, parameter histogram and model topology + self.wandb_logger.watch(self.net, log="all", log_freq=500, log_graph=False) + + def on_train_epoch_start(self) -> None: + # save initial checkpoint to check pretrained model + # if self.trainer.global_step == 0: + # checkpoint_path = os.path.join(self.trainer.checkpoint_callback.dirpath, "step_0.ckpt") + # self.trainer.save_checkpoint(checkpoint_path) + + # our custom dataset and sampler has to have epoch set by calling set_epoch + if hasattr(self.trainer.train_dataloader, "dataset") and hasattr(self.trainer.train_dataloader.dataset, "set_epoch"): + self.trainer.train_dataloader.dataset.set_epoch(self.current_epoch) + if hasattr(self.trainer.train_dataloader, "sampler") and hasattr(self.trainer.train_dataloader.sampler, "set_epoch"): + self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch) + + def on_validation_epoch_start(self) -> None: + # our custom dataset and sampler has to have epoch set by calling set_epoch + for loader in self.trainer.val_dataloaders: + if hasattr(loader, "dataset") and hasattr(loader.dataset, "set_epoch"): + loader.dataset.set_epoch(0) + if hasattr(loader, "sampler") and hasattr(loader.sampler, "set_epoch"): + loader.sampler.set_epoch(0) + + def model_step( + self, batch: List[Dict[str, torch.Tensor]], criterion: torch.nn.Module, + ) -> Tuple[torch.Tensor, Dict]: + device = self.device + + # Move data to device + for view in batch: + for name in "img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres".split(): + if name in view: + view[name] = view[name].to(device, non_blocking=True) + + views = batch + + preds = self.forward(views) + + # Compute the loss in higher precision + with torch.autocast(device_type=self.device.type, dtype=torch.float32): + loss, loss_details = criterion(views, preds) if criterion is not None else None + + return views, preds, loss, loss_details + + def training_step( + self, batch: List[Dict[str, torch.Tensor]], batch_idx: int + ) -> torch.Tensor: + views, preds, loss, loss_details = self.model_step(batch, self.train_criterion) + + if not isinstance(loss, (torch.Tensor, dict, type(None))): # this will cause a lightning.fabric.utilities.exceptions.MisconfigurationException + # log loss and the batch information to help debugging + # use print instead of log because the logger only logs on rank 0, but this could happen on any rank + print(f"Loss is not a tensor or dict but {type(loss)}, value: {loss}") + print(f"Loss details: {loss_details}") + print(f"Batch: {batch}") + print(f"Batch index: {batch_idx}") + print(f"Views: {views}") + print(f"Preds: {preds}") + loss = None # set loss to None will still break the training loop in DDP, this is intended - we should fix the data to avoid nan loss in the first place + return loss + + self.epoch_fraction = torch.tensor(self.trainer.current_epoch + batch_idx / self.trainer.num_training_batches, device=self.device) + + self.log("trainer/epoch", self.epoch_fraction, on_step=True, on_epoch=False, prog_bar=True) + self.log("trainer/lr", self.trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True) + self.log("train/loss", loss, on_step=True, on_epoch=False, prog_bar=True) + + # log the details of the loss + if loss_details is not None: + for key, value in loss_details.items(): + self.log(f"train_detail_{key}", value, on_step=True, on_epoch=False, prog_bar=False) + match = re.search(r'/(\d{1,2})$', key) + if match: + stripped_key = key[:match.start()] + self.log(f"train/{stripped_key}", value, on_step=True, on_epoch=False, prog_bar=False) + + # Log the total number of samples seen so far + batch_size = views[0]["img"].shape[0] + self.train_total_samples_per_step(batch_size) # aggregate across all GPUs + self.train_total_samples += self.train_total_samples_per_step.compute() # accumulate across all steps + self.train_total_samples_per_step.reset() + self.log("trainer/total_samples", self.train_total_samples, on_step=True, on_epoch=False, prog_bar=False) + + # Log the total number of images seen so far + num_views = len(views) + n_image_cur_step = batch_size * num_views + self.train_total_images_per_step(n_image_cur_step) # aggregate across all GPUs + self.train_total_images += self.train_total_images_per_step.compute() # accumulate across all steps + self.train_total_images_per_step.reset() + self.log("trainer/total_images", self.train_total_images, on_step=True, on_epoch=False, prog_bar=False) + + return loss + + def validation_step( + self, batch: List[Dict[str, torch.Tensor]], batch_idx: int, dataloader_idx: int = 0, + ) -> torch.Tensor: + views, preds, loss, loss_details = self.model_step(batch, self.validation_criterion) + + # Extract the dataset name and batch size + dataset_name = views[0]['dataset'][0] # all views should have the same dataset name because we use "sequential" mode of CombinedLoader + batch_size = views[0]["img"].shape[0] + + self.val_loss(loss) + + for key, value in loss_details.items(): + self.log( + f"val_detail_{dataset_name}_{key}", + value, + on_step=False, + on_epoch=True, + prog_bar=False, + reduce_fx="mean", + sync_dist=True, + add_dataloader_idx=False, + batch_size=batch_size, + ) + match = re.search(r'/(\d{1,2})$', key) + if match: + stripped_key = key[:match.start()] + self.log(f"val/{dataset_name}_{stripped_key}", value, on_step=False, on_epoch=True, prog_bar=False, reduce_fx="mean", sync_dist=True, add_dataloader_idx=False, batch_size=batch_size) + + loss_value = loss.detach().cpu().item() + del loss, loss_details + torch.cuda.empty_cache() + + del views, preds + torch.cuda.empty_cache() + + return loss_value + + def on_validation_epoch_end(self) -> None: + self.log("val/loss", self.val_loss, prog_bar=True) + + # if we dont do these, wandb for some reason cannot display the validation loss with them as the x-axis + self.log("trainer/epoch", self.epoch_fraction, sync_dist=True) + self.log("trainer/total_samples", self.train_total_samples.cpu().item(), sync_dist=True) + self.log("trainer/total_images", self.train_total_images.cpu().item(), sync_dist=True) + + # def test_step( + # self, batch: List[Dict[str, torch.Tensor]], batch_idx: int + # ) -> None: + # pass + + def configure_optimizers(self) -> Dict[str, Any]: + optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) + + if self.hparams.scheduler is not None: + scheduler_config = self.hparams.scheduler + + # HACK: if the class is pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR, + # both warmup_epochs and max_epochs should be scaled. + # more specifically, max_epochs should be scaled to total number of steps that we will have during training, + # and warmup_epochs should be scaled up proportionally. + if scheduler_config.func is LinearWarmupCosineAnnealingLR: + # Extract the keyword arguments from the partial object + scheduler_kwargs = {k: v for k, v in scheduler_config.keywords.items()} + original_warmup_epochs = scheduler_kwargs['warmup_epochs'] + original_max_epochs = scheduler_kwargs['max_epochs'] + + total_steps = self.trainer.estimated_stepping_batches # total number of total steps in all training epochs + + # Scale warmup_epochs and max_epochs + scaled_warmup_epochs = int(original_warmup_epochs * total_steps / original_max_epochs) + scaled_max_epochs = total_steps + + # Update the kwargs with scaled values + scheduler_kwargs.update({ + 'warmup_epochs': scaled_warmup_epochs, + 'max_epochs': scaled_max_epochs + }) + + # Re-initialize the scheduler with updated parameters + scheduler = LinearWarmupCosineAnnealingLR( + optimizer=optimizer, + **scheduler_kwargs + ) + else: + scheduler = scheduler_config(optimizer=optimizer) + + return { + 'optimizer': optimizer, + 'lr_scheduler': { + 'name': 'train/lr', # put lr inside train group in loggers + 'scheduler': scheduler, + 'interval': 'step' if scheduler_config.func is LinearWarmupCosineAnnealingLR else 'epoch', + 'frequency': 1, + } + } + + return {"optimizer": optimizer} + + def setup(self, stage: str) -> None: + if self.hparams.compile and stage == "fit": + self.net = torch.compile(self.net) + + # Load pretrained weights if available and not resuming + # note that if resume_from_checkpoint is set, the Trainer is responsible for actually loading the checkpoint + # so we are only using resume_from_checkpoint as a check of whether we should load the pretrained weights + if self.pretrained and not self.resume_from_checkpoint: + self._load_pretrained_weights() + + def _load_pretrained_weights(self) -> None: + log.info(f"Loading pretrained: {self.pretrained}") + if isinstance(self.net, FlashDUSt3R): # if the model is FlashDUSt3R, use the weights of the first head only + ckpt = torch.load(self.pretrained) + ckpt = self._update_ckpt_keys(ckpt, new_head_name='downstream_head', head_to_keep='downstream_head1', head_to_discard='downstream_head2') + self.net.load_state_dict(ckpt["model"], strict=False) + del ckpt # in case it occupies memory + elif isinstance(self.net, STream3R): + # if the checkpoint is also STream3R, load all weights + log.info(f"Loading pretrained weights from {self.pretrained}") + checkpoint = torch.load(self.pretrained) + missing_keys, unexpected_keys = self.net.load_state_dict(checkpoint, strict=False) + log.info(f"Missing keys: {missing_keys}") + log.info(f"Unexpected keys: {unexpected_keys}") + + @staticmethod + def _update_ckpt_keys(ckpt, new_head_name='downstream_head', head_to_keep='downstream_head1', head_to_discard='downstream_head2'): + """Helper function to use the weights of a model with multiple heads in a model with a single head. + specifically, keep only the weights of the first head and delete the weights of the second head. + """ + new_ckpt = {'model': {}} + + for key, value in ckpt['model'].items(): + if key.startswith(head_to_keep): + new_key = key.replace(head_to_keep, new_head_name) + new_ckpt['model'][new_key] = value + elif key.startswith(head_to_discard): + continue + else: + new_ckpt['model'][key] = value + + return new_ckpt \ No newline at end of file diff --git a/stream3r/models/stream3r.py b/stream3r/models/stream3r.py new file mode 100644 index 0000000000000000000000000000000000000000..3f72c3184bfa40500eea32d480dd41da63d0268c --- /dev/null +++ b/stream3r/models/stream3r.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple, List +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin # used for model hub + +from stream3r.dust3r.utils.misc import freeze_all_params +from stream3r.models.components.aggregator.streamaggregator import STreamAggregator +from stream3r.models.components.heads.camera_head import CameraHead +from stream3r.models.components.heads.dpt_head import DPTHead + + +class STream3R(nn.Module, PyTorchModelHubMixin): + def __init__(self, img_size=518, patch_size=14, embed_dim=1024, freeze="none"): + super().__init__() + + self.aggregator = STreamAggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) + self.camera_head = CameraHead(dim_in=2 * embed_dim) + self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1") + self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1") + + self.set_freeze(freeze) + + def set_freeze(self, freeze): + self.freeze = freeze + + to_be_frozen = { + "none": [], + "encoder": [self.aggregator.patch_embed], + } + freeze_all_params(to_be_frozen[freeze]) + + def forward( + self, + images: torch.Tensor, + mode: str = "causal", + aggregator_kv_cache_list: List[List[torch.Tensor]] = None, + camera_head_kv_cache_list: List[List[List[torch.Tensor]]] = None, + ): + """ + Forward pass of the STream3R model. + + Args: + images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + mode (str): Global attention mode, could be either "causal", "window", "full" + aggregator_kv_cache_list (List[List[torch.Tensor]]): List of cached key-value pairs for + each global attention layer of the aggregator + camera_head_kv_cache_list (List[List[List[torch.Tensor]]]): List of cached key-value pairs for + each iterations and each attention layer of the camera head + + Returns: + dict: A dictionary containing the following predictions: + - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) + - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] + - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] + - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] + - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] + - images (torch.Tensor): Original input images, preserved for visualization + """ + if self.training: + images = torch.stack([view["img"] for view in images], dim=1) + images = (images + 1.) / 2. + + # If without batch dimension, add it + if len(images.shape) == 4: + images = images.unsqueeze(0) + + if aggregator_kv_cache_list is not None: + aggregated_tokens_list, patch_start_idx, aggregator_kv_cache_list = self.aggregator(images, mode=mode, kv_cache_list=aggregator_kv_cache_list) + else: + aggregated_tokens_list, patch_start_idx = self.aggregator(images, mode=mode) + + predictions = {} + + with torch.autocast(device_type=next(self.parameters()).device.type, dtype=torch.float32): + if self.camera_head is not None: + if camera_head_kv_cache_list is not None: + pose_enc_list, camera_head_kv_cache_list = self.camera_head(aggregated_tokens_list, mode=mode, kv_cache_list=camera_head_kv_cache_list) + else: + pose_enc_list = self.camera_head(aggregated_tokens_list, mode=mode) + predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration + if self.training: + predictions["pose_enc_list"] = pose_enc_list + + if self.point_head is not None: + pts3d, pts3d_conf = self.point_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx + ) + predictions["world_points"] = pts3d + predictions["world_points_conf"] = pts3d_conf + + if self.depth_head is not None: + depth, depth_conf = self.depth_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx + ) + predictions["depth"] = depth + predictions["depth_conf"] = depth_conf + + if aggregator_kv_cache_list is not None: + predictions["aggregator_kv_cache_list"] = aggregator_kv_cache_list + + if camera_head_kv_cache_list is not None: + predictions["camera_head_kv_cache_list"] = camera_head_kv_cache_list + + if not self.training: + predictions["images"] = images + + return predictions \ No newline at end of file diff --git a/stream3r/stream_session.py b/stream3r/stream_session.py new file mode 100644 index 0000000000000000000000000000000000000000..bf46fbbd012285bdf9239ab8b82d170b67180525 --- /dev/null +++ b/stream3r/stream_session.py @@ -0,0 +1,99 @@ +import torch +from stream3r.models.stream3r import STream3R + + +class StreamSession: + """ + A causal streaming inference session with KV cache management for STream3R. + """ + def __init__(self, model: STream3R, mode: str): + self.model = model + self.mode = mode + self.aggregator_kv_cache_depth = model.aggregator.depth + self.camera_head_kv_cache_depth = model.camera_head.trunk_depth + self.camera_head_iterations = 4 + + if self.mode not in ["causal", "window"]: + raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}") + + self.clear() + + def _clear_predictions(self): + self.predictions = dict() + + def _update_predictions(self, predictions): + for k in ["pose_enc", "world_points", "world_points_conf", "depth", "depth_conf", "images"]: + if k in predictions: + self.predictions[k] = torch.cat( + [self.predictions.get(k, torch.empty(0, device=predictions[k].device)), predictions[k]], + dim=1 + ) + + def _clear_cache(self): + self.aggregator_kv_cache_list = [[None, None] for _ in range(self.aggregator_kv_cache_depth)] + self.camera_head_kv_cache_list = [[[None, None] for _ in range(self.camera_head_kv_cache_depth)] for _ in range(self.camera_head_iterations)] + + def _update_cache(self, aggregator_kv_cache_list, camera_head_kv_cache_list): + if self.mode == "causal": + self.aggregator_kv_cache_list = aggregator_kv_cache_list + self.camera_head_kv_cache_list = camera_head_kv_cache_list + elif self.mode == "window": + window_size = 5 + for k in range(2): + for i in range(self.aggregator_kv_cache_depth): + h, w = self.predictions["depth"].shape[2], self.predictions["depth"].shape[3] + P = h * w // self.model.aggregator.patch_size // self.model.aggregator.patch_size + self.model.aggregator.patch_start_idx + anchor_token = aggregator_kv_cache_list[i][k][:, :, :P] + window_tokens = aggregator_kv_cache_list[i][k][:, :, max(P, aggregator_kv_cache_list[i][k].size(2)-window_size*P):] + self.aggregator_kv_cache_list[i][k] = torch.cat( + [ + anchor_token, + window_tokens + ], + dim=2 + ) + for i in range(self.camera_head_iterations): + for j in range(self.camera_head_kv_cache_depth): + anchor_token = camera_head_kv_cache_list[i][j][k][:, :, :1] + window_tokens = camera_head_kv_cache_list[i][j][k][:, :, max(1, camera_head_kv_cache_list[i][j][k].size(2)-window_size):] + self.camera_head_kv_cache_list[i][j][k] = torch.cat( + [ + anchor_token, + window_tokens + ], + dim=2 + ) + else: + raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}") + + def _get_cache(self): + return self.aggregator_kv_cache_list, self.camera_head_kv_cache_list + + def get_all_predictions(self): + return self.predictions + + def get_last_prediction(self): + last_predictions = dict() + for k in ["pose_enc", "world_points", "world_points_conf", "depth", "depth_conf", "images"]: + if k in self.predictions: + last_predictions[k] = self.predictions[k][:, -1:] + return last_predictions + + def clear(self): + self._clear_predictions() + self._clear_cache() + + def forward_stream(self, images): + aggregator_kv_cache_list, camera_head_kv_cache_list = self._get_cache() + + outputs = self.model( + images=images, + mode=self.mode, + aggregator_kv_cache_list=aggregator_kv_cache_list, + camera_head_kv_cache_list=camera_head_kv_cache_list, + ) + + self._update_predictions(outputs) + self._update_cache(outputs["aggregator_kv_cache_list"], outputs["camera_head_kv_cache_list"]) + + return self.get_all_predictions() diff --git a/stream3r/train.py b/stream3r/train.py new file mode 100644 index 0000000000000000000000000000000000000000..0aa8791af5a05abc4200008f717b081946f30ed6 --- /dev/null +++ b/stream3r/train.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Optional, Tuple + +import hydra +import lightning as L +import rootutils +import torch +import signal # noqa: F401 +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig, OmegaConf + + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from src import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + +from stream3r.utils import ( + RankedLogger, + extras, + get_metric_value, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) + +log = RankedLogger(__name__, rank_zero_only=True) + + +def python_eval_resolver(code: str): + return eval(code) + + +# Register the resolver with OmegaConf +# usage: ${python_code:1 + 1} in yaml +OmegaConf.register_new_resolver("python_eval", python_eval_resolver) + + +@task_wrapper +def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + :param cfg: A DictConfig configuration composed by Hydra. + :return: A tuple with metrics and dict with all instantiated objects. + """ + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.data.data_module._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data.data_module) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") +def main(cfg: DictConfig) -> Optional[float]: + """Main entry point for training. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Optional[float] with optimized metric value. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) + extras(cfg) + + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = get_metric_value( + metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") + ) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + main() diff --git a/stream3r/utils/__init__.py b/stream3r/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a011fd57c80cd40a1f2d5011ea2c4213a6e736cf --- /dev/null +++ b/stream3r/utils/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from stream3r.utils.instantiators import instantiate_callbacks, instantiate_loggers +from stream3r.utils.logging_utils import log_hyperparameters +from stream3r.utils.pylogger import RankedLogger +from stream3r.utils.rich_utils import enforce_tags, print_config_tree +from stream3r.utils.utils import extras, get_metric_value, task_wrapper diff --git a/stream3r/utils/__pycache__/__init__.cpython-311.pyc b/stream3r/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab0140018b716247a09cb87bee653864a8cb5189 Binary files /dev/null and b/stream3r/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/stream3r/utils/__pycache__/instantiators.cpython-311.pyc b/stream3r/utils/__pycache__/instantiators.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c5d6857bb8764013e1c19cbf5bab9927613c555 Binary files /dev/null and b/stream3r/utils/__pycache__/instantiators.cpython-311.pyc differ diff --git a/stream3r/utils/__pycache__/logging_utils.cpython-311.pyc b/stream3r/utils/__pycache__/logging_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9ac54b89358b32ccbc6f2c5d3ca71d588ef36f8 Binary files /dev/null and b/stream3r/utils/__pycache__/logging_utils.cpython-311.pyc differ diff --git a/stream3r/utils/__pycache__/pylogger.cpython-311.pyc b/stream3r/utils/__pycache__/pylogger.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99088307043dac6b6570e01950053a3ccc3bbc6b Binary files /dev/null and b/stream3r/utils/__pycache__/pylogger.cpython-311.pyc differ diff --git a/stream3r/utils/__pycache__/rich_utils.cpython-311.pyc b/stream3r/utils/__pycache__/rich_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49e17be0e79d06d36ea6d5d856b8860af25cad93 Binary files /dev/null and b/stream3r/utils/__pycache__/rich_utils.cpython-311.pyc differ diff --git a/stream3r/utils/__pycache__/utils.cpython-311.pyc b/stream3r/utils/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b7a6668a400e3f9b85a65702cfeb515eb65bade Binary files /dev/null and b/stream3r/utils/__pycache__/utils.cpython-311.pyc differ diff --git a/stream3r/utils/__pycache__/visual_utils.cpython-311.pyc b/stream3r/utils/__pycache__/visual_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d72a2f481c99889706538bf01fb3e2ebf97ef05 Binary files /dev/null and b/stream3r/utils/__pycache__/visual_utils.cpython-311.pyc differ diff --git a/stream3r/utils/instantiators.py b/stream3r/utils/instantiators.py new file mode 100644 index 0000000000000000000000000000000000000000..c03a06c88d4274e6a7270e83dbefa7d36c8c37e4 --- /dev/null +++ b/stream3r/utils/instantiators.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +import hydra +from lightning import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from stream3r.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config. + + :param callbacks_cfg: A DictConfig object containing callback configurations. + :return: A list of instantiated callbacks. + """ + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config. + + :param logger_cfg: A DictConfig object containing logger configurations. + :return: A list of instantiated loggers. + """ + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/stream3r/utils/logging_utils.py b/stream3r/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0af4576c6b3f130c666363e1590bc81601510d28 --- /dev/null +++ b/stream3r/utils/logging_utils.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict + +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import OmegaConf + +from stream3r.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def log_hyperparameters(object_dict: Dict[str, Any]) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - Number of model parameters + + :param object_dict: A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. + """ + hparams = {} + + cfg = OmegaConf.to_container(object_dict["cfg"]) + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/stream3r/utils/pylogger.py b/stream3r/utils/pylogger.py new file mode 100644 index 0000000000000000000000000000000000000000..533f56324d0127208bfbb86172f418e4c0b35007 --- /dev/null +++ b/stream3r/utils/pylogger.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Mapping, Optional + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = False, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. + """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/stream3r/utils/rich_utils.py b/stream3r/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c8b7bea9d8ee08b281c2f945dd8ed74b801db008 --- /dev/null +++ b/stream3r/utils/rich_utils.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from stream3r.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints the contents of a DictConfig as a tree structure using the Rich library. + + :param cfg: A DictConfig composed by Hydra. + :param print_order: Determines in what order config components are printed. Default is ``("data", "model", + "callbacks", "logger", "trainer", "paths", "extras")``. + :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. + :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + queue.append(field) if field in cfg else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config. + + :param cfg: A DictConfig composed by Hydra. + :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. + """ + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/stream3r/utils/utils.py b/stream3r/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1fbe844baa401865fcc28dee0db5e1d17944f055 --- /dev/null +++ b/stream3r/utils/utils.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from importlib.util import find_spec +from typing import Any, Callable, Dict, Optional, Tuple +import torchvision.transforms as tvf +from omegaconf import DictConfig + +from stream3r.utils import pylogger, rich_utils + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + +ImgDust3r2Stream3r = tvf.Normalize(mean=[-1, -1, -1], std=[2, 2, 2]) + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + + :param cfg: A DictConfig object containing the config tree. + """ + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ... + return metric_dict, object_dict + ``` + + :param task_func: The task function to be wrapped. + + :return: The wrapped task function. + """ + + def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]: + """Safely retrieves value of the metric logged in LightningModule. + + :param metric_dict: A dict containing metric values. + :param metric_name: If provided, the name of the metric to retrieve. + :return: If a metric name was provided, the value of the metric. + """ + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") + + return metric_value diff --git a/stream3r/utils/visual_utils.py b/stream3r/utils/visual_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52599589d3ae8052e7ca630db5f73dee31a90aa0 --- /dev/null +++ b/stream3r/utils/visual_utils.py @@ -0,0 +1,456 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import trimesh +import numpy as np +import matplotlib +from scipy.spatial.transform import Rotation +import copy +import cv2 +import os +import requests + + +def predictions_to_glb( + predictions, + conf_thres=50.0, + filter_by_frames="all", + mask_black_bg=False, + mask_white_bg=False, + show_cam=True, + mask_sky=False, + target_dir=None, + prediction_mode="Predicted Pointmap", +) -> trimesh.Scene: + """ + Converts predictions to a 3D scene represented as a GLB file. + + Args: + predictions (dict): Dictionary containing model predictions with keys: + - world_points: 3D point coordinates (S, H, W, 3) + - world_points_conf: Confidence scores (S, H, W) + - images: Input images (S, H, W, 3) + - extrinsic: Camera extrinsic matrices (S, 3, 4) + conf_thres (float): Percentage of low-confidence points to filter out (default: 50.0) + filter_by_frames (str): Frame filter specification (default: "all") + mask_black_bg (bool): Mask out black background pixels (default: False) + mask_white_bg (bool): Mask out white background pixels (default: False) + show_cam (bool): Include camera visualization (default: True) + mask_sky (bool): Apply sky segmentation mask (default: False) + target_dir (str): Output directory for intermediate files (default: None) + prediction_mode (str): Prediction mode selector (default: "Predicted Pointmap") + + Returns: + trimesh.Scene: Processed 3D scene containing point cloud and cameras + + Raises: + ValueError: If input predictions structure is invalid + """ + if not isinstance(predictions, dict): + raise ValueError("predictions must be a dictionary") + + if conf_thres is None: + conf_thres = 10.0 + + print("Building GLB scene") + selected_frame_idx = None + if filter_by_frames != "all" and filter_by_frames != "All": + try: + # Extract the index part before the colon + selected_frame_idx = int(filter_by_frames.split(":")[0]) + except (ValueError, IndexError): + pass + + if "Pointmap" in prediction_mode: + print("Using Pointmap Branch") + if "world_points" in predictions: + pred_world_points = predictions["world_points"] # No batch dimension to remove + pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0])) + else: + print("Warning: world_points not found in predictions, falling back to depth-based points") + pred_world_points = predictions["world_points_from_depth"] + pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0])) + else: + print("Using Depthmap and Camera Branch") + pred_world_points = predictions["world_points_from_depth"] + pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0])) + + # Get images from predictions + images = predictions["images"] + # Use extrinsic matrices instead of pred_extrinsic_list + camera_matrices = predictions["extrinsic"] + + if mask_sky: + if target_dir is not None: + import onnxruntime + + skyseg_session = None + target_dir_images = target_dir + "/images" + image_list = sorted(os.listdir(target_dir_images)) + sky_mask_list = [] + + # Get the shape of pred_world_points_conf to match + S, H, W = ( + pred_world_points_conf.shape + if hasattr(pred_world_points_conf, "shape") + else (len(images), images.shape[1], images.shape[2]) + ) + + # Download skyseg.onnx if it doesn't exist + if not os.path.exists("skyseg.onnx"): + print("Downloading skyseg.onnx...") + download_file_from_url( + "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx" + ) + + for i, image_name in enumerate(image_list): + image_filepath = os.path.join(target_dir_images, image_name) + mask_filepath = os.path.join(target_dir, "sky_masks", image_name) + + # Check if mask already exists + if os.path.exists(mask_filepath): + # Load existing mask + sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) + else: + # Generate new mask + if skyseg_session is None: + skyseg_session = onnxruntime.InferenceSession("skyseg.onnx") + sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath) + + # Resize mask to match H×W if needed + if sky_mask.shape[0] != H or sky_mask.shape[1] != W: + sky_mask = cv2.resize(sky_mask, (W, H)) + + sky_mask_list.append(sky_mask) + + # Convert list to numpy array with shape S×H×W + sky_mask_array = np.array(sky_mask_list) + + # Apply sky mask to confidence scores + sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32) + pred_world_points_conf = pred_world_points_conf * sky_mask_binary + + if selected_frame_idx is not None: + pred_world_points = pred_world_points[selected_frame_idx][None] + pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None] + images = images[selected_frame_idx][None] + camera_matrices = camera_matrices[selected_frame_idx][None] + + vertices_3d = pred_world_points.reshape(-1, 3) + # Handle different image formats - check if images need transposing + if images.ndim == 4 and images.shape[1] == 3: # NCHW format + colors_rgb = np.transpose(images, (0, 2, 3, 1)) + else: # Assume already in NHWC format + colors_rgb = images + colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8) + + conf = pred_world_points_conf.reshape(-1) + # Convert percentage threshold to actual confidence value + if conf_thres == 0.0: + conf_threshold = 0.0 + else: + conf_threshold = np.percentile(conf, conf_thres) + + conf_mask = (conf >= conf_threshold) & (conf > 1e-5) + + if mask_black_bg: + black_bg_mask = colors_rgb.sum(axis=1) >= 16 + conf_mask = conf_mask & black_bg_mask + + if mask_white_bg: + # Filter out white background pixels (RGB values close to white) + # Consider pixels white if all RGB values are above 240 + white_bg_mask = ~((colors_rgb[:, 0] > 240) & (colors_rgb[:, 1] > 240) & (colors_rgb[:, 2] > 240)) + conf_mask = conf_mask & white_bg_mask + + vertices_3d = vertices_3d[conf_mask] + colors_rgb = colors_rgb[conf_mask] + + if vertices_3d is None or np.asarray(vertices_3d).size == 0: + vertices_3d = np.array([[1, 0, 0]]) + colors_rgb = np.array([[255, 255, 255]]) + scene_scale = 1 + else: + # Calculate the 5th and 95th percentiles along each axis + lower_percentile = np.percentile(vertices_3d, 5, axis=0) + upper_percentile = np.percentile(vertices_3d, 95, axis=0) + + # Calculate the diagonal length of the percentile bounding box + scene_scale = np.linalg.norm(upper_percentile - lower_percentile) + + colormap = matplotlib.colormaps.get_cmap("gist_rainbow") + + # Initialize a 3D scene + scene_3d = trimesh.Scene() + + # Add point cloud data to the scene + point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb) + + scene_3d.add_geometry(point_cloud_data) + + # Prepare 4x4 matrices for camera extrinsics + num_cameras = len(camera_matrices) + extrinsics_matrices = np.zeros((num_cameras, 4, 4)) + extrinsics_matrices[:, :3, :4] = camera_matrices + extrinsics_matrices[:, 3, 3] = 1 + + if show_cam: + # Add camera models to the scene + for i in range(num_cameras): + world_to_camera = extrinsics_matrices[i] + camera_to_world = np.linalg.inv(world_to_camera) + rgba_color = colormap(i / num_cameras) + current_color = tuple(int(255 * x) for x in rgba_color[:3]) + + integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale) + + # Align scene to the observation of the first camera + scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices) + + print("GLB Scene built") + return scene_3d + + +def integrate_camera_into_scene(scene: trimesh.Scene, transform: np.ndarray, face_colors: tuple, scene_scale: float): + """ + Integrates a fake camera mesh into the 3D scene. + + Args: + scene (trimesh.Scene): The 3D scene to add the camera model. + transform (np.ndarray): Transformation matrix for camera positioning. + face_colors (tuple): Color of the camera face. + scene_scale (float): Scale of the scene. + """ + + cam_width = scene_scale * 0.05 + cam_height = scene_scale * 0.1 + + # Create cone shape for camera + rot_45_degree = np.eye(4) + rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix() + rot_45_degree[2, 3] = -cam_height + + opengl_transform = get_opengl_conversion_matrix() + # Combine transformations + complete_transform = transform @ opengl_transform @ rot_45_degree + camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4) + + # Generate mesh for the camera + slight_rotation = np.eye(4) + slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix() + + vertices_combined = np.concatenate( + [ + camera_cone_shape.vertices, + 0.95 * camera_cone_shape.vertices, + transform_points(slight_rotation, camera_cone_shape.vertices), + ] + ) + vertices_transformed = transform_points(complete_transform, vertices_combined) + + mesh_faces = compute_camera_faces(camera_cone_shape) + + # Add the camera mesh to the scene + camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces) + camera_mesh.visual.face_colors[:, :3] = face_colors + scene.add_geometry(camera_mesh) + + +def apply_scene_alignment(scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray) -> trimesh.Scene: + """ + Aligns the 3D scene based on the extrinsics of the first camera. + + Args: + scene_3d (trimesh.Scene): The 3D scene to be aligned. + extrinsics_matrices (np.ndarray): Camera extrinsic matrices. + + Returns: + trimesh.Scene: Aligned 3D scene. + """ + # Set transformations for scene alignment + opengl_conversion_matrix = get_opengl_conversion_matrix() + + # Rotation matrix for alignment (180 degrees around the y-axis) + align_rotation = np.eye(4) + align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix() + + # Apply transformation + initial_transformation = np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation + scene_3d.apply_transform(initial_transformation) + return scene_3d + + +def get_opengl_conversion_matrix() -> np.ndarray: + """ + Constructs and returns the OpenGL conversion matrix. + + Returns: + numpy.ndarray: A 4x4 OpenGL conversion matrix. + """ + # Create an identity matrix + matrix = np.identity(4) + + # Flip the y and z axes + matrix[1, 1] = -1 + matrix[2, 2] = -1 + + return matrix + + +def transform_points(transformation: np.ndarray, points: np.ndarray, dim: int = None) -> np.ndarray: + """ + Applies a 4x4 transformation to a set of points. + + Args: + transformation (np.ndarray): Transformation matrix. + points (np.ndarray): Points to be transformed. + dim (int, optional): Dimension for reshaping the result. + + Returns: + np.ndarray: Transformed points. + """ + points = np.asarray(points) + initial_shape = points.shape[:-1] + dim = dim or points.shape[-1] + + # Apply transformation + transformation = transformation.swapaxes(-1, -2) # Transpose the transformation matrix + points = points @ transformation[..., :-1, :] + transformation[..., -1:, :] + + # Reshape the result + result = points[..., :dim].reshape(*initial_shape, dim) + return result + + +def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray: + """ + Computes the faces for the camera mesh. + + Args: + cone_shape (trimesh.Trimesh): The shape of the camera cone. + + Returns: + np.ndarray: Array of faces for the camera mesh. + """ + # Create pseudo cameras + faces_list = [] + num_vertices_cone = len(cone_shape.vertices) + + for face in cone_shape.faces: + if 0 in face: + continue + v1, v2, v3 = face + v1_offset, v2_offset, v3_offset = face + num_vertices_cone + v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone + + faces_list.extend( + [ + (v1, v2, v2_offset), + (v1, v1_offset, v3), + (v3_offset, v2, v3), + (v1, v2, v2_offset_2), + (v1, v1_offset_2, v3), + (v3_offset_2, v2, v3), + ] + ) + + faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list] + return np.array(faces_list) + + +def segment_sky(image_path, onnx_session, mask_filename=None): + """ + Segments sky from an image using an ONNX model. + Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing + + Args: + image_path: Path to input image + onnx_session: ONNX runtime session with loaded model + mask_filename: Path to save the output mask + + Returns: + np.ndarray: Binary mask where 255 indicates non-sky regions + """ + + assert mask_filename is not None + image = cv2.imread(image_path) + + result_map = run_skyseg(onnx_session, [320, 320], image) + # resize the result_map to the original image size + result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0])) + + # Fix: Invert the mask so that 255 = non-sky, 0 = sky + # The model outputs low values for sky, high values for non-sky + output_mask = np.zeros_like(result_map_original) + output_mask[result_map_original < 32] = 255 # Use threshold of 32 + + os.makedirs(os.path.dirname(mask_filename), exist_ok=True) + cv2.imwrite(mask_filename, output_mask) + return output_mask + + +def run_skyseg(onnx_session, input_size, image): + """ + Runs sky segmentation inference using ONNX model. + + Args: + onnx_session: ONNX runtime session + input_size: Target size for model input (width, height) + image: Input image in BGR format + + Returns: + np.ndarray: Segmentation mask + """ + + # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast + temp_image = copy.deepcopy(image) + resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1])) + x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB) + x = np.array(x, dtype=np.float32) + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + x = (x / 255 - mean) / std + x = x.transpose(2, 0, 1) + x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32") + + # Inference + input_name = onnx_session.get_inputs()[0].name + output_name = onnx_session.get_outputs()[0].name + onnx_result = onnx_session.run([output_name], {input_name: x}) + + # Post process + onnx_result = np.array(onnx_result).squeeze() + min_value = np.min(onnx_result) + max_value = np.max(onnx_result) + onnx_result = (onnx_result - min_value) / (max_value - min_value) + onnx_result *= 255 + onnx_result = onnx_result.astype("uint8") + + return onnx_result + + +def download_file_from_url(url, filename): + """Downloads a file from a Hugging Face model repo, handling redirects.""" + try: + # Get the redirect URL + response = requests.get(url, allow_redirects=False) + response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx) + + if response.status_code == 302: # Expecting a redirect + redirect_url = response.headers["Location"] + response = requests.get(redirect_url, stream=True) + response.raise_for_status() + else: + print(f"Unexpected status code: {response.status_code}") + return + + with open(filename, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + print(f"Downloaded {filename} successfully.") + + except requests.exceptions.RequestException as e: + print(f"Error downloading file: {e}")