Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- outdoor_v48_16gpu_v2/.hydra/config.yaml +68 -0
- outdoor_v48_16gpu_v2/.hydra/hydra.yaml +156 -0
- outdoor_v48_16gpu_v2/.hydra/overrides.yaml +2 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/__init__.py +0 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/base_multiview_dataset.py +576 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/batched_sampler.py +93 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/easy_dataset.py +212 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/dynamic_replica.py +137 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/habitat_hm3d.py +174 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/hoi4d.py +84 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mapfree.py +282 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mvs_synth.py +144 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/omniobject3d.py +146 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/pointodyssey.py +178 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/realestate10k.py +139 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannet.py +149 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannetpp.py +211 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/smartportraits.py +85 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/threedkb.py +111 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/unreal4k.py +159 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/__init__.py +2 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/corr.py +129 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/cropping.py +147 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/transforms.py +80 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/waymo.py +178 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/wildrgbd.py +56 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/__init__.py +1 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/camera.py +463 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/device.py +88 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/geometry.py +554 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/image.py +271 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/misc.py +127 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/parallel.py +87 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/path_to_croco.py +47 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/render.py +75 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/__init__.py +6 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/__init__.py +4 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/backbones.py +156 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/utils.py +39 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/__init__.py +11 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/attention.py +89 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/block.py +259 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/dino_head.py +58 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/drop_path.py +34 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/layer_scale.py +27 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/mlp.py +40 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/patch_embed.py +88 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/swiglu_ffn.py +72 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/__init__.py +43 -0
- outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/vision_transformer.py +404 -0
outdoor_v48_16gpu_v2/.hydra/config.yaml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
teacher: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model
|
| 2 |
+
pretrained: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model
|
| 3 |
+
load_only_encoder: false
|
| 4 |
+
long_context: false
|
| 5 |
+
fixed_length: true
|
| 6 |
+
resume: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth
|
| 7 |
+
benchmark: false
|
| 8 |
+
num_views: 64
|
| 9 |
+
num_test_views: 4
|
| 10 |
+
n_corres_train: 0
|
| 11 |
+
n_corres_test: 0
|
| 12 |
+
train_criterion: DistillLoss()
|
| 13 |
+
test_criterion: DistillLoss()
|
| 14 |
+
allow_repeat: false
|
| 15 |
+
root_vkitti2: /scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti
|
| 16 |
+
root_kitti: /scratch-shared/wwei2/eval/kitti_odometry/dataset
|
| 17 |
+
root_kitti_velo: /gpfs/work2/0/prjs0824/semantickitti/dataset
|
| 18 |
+
root_kitti360: /scratch-shared/wwei2/downloads/kitti360/KITTI-360
|
| 19 |
+
root_kitti360_velo: /scratch-shared/wwei2/downloads/kitti360/KITTI-360
|
| 20 |
+
root_waymo: /scratch-shared/wwei2/waymo_v2
|
| 21 |
+
root_waymo_lidar: /scratch-shared/wwei2/waymo_v2
|
| 22 |
+
dataset_vkitti2: VirtualKITTI2_Multi(allow_repeat=${allow_repeat}, split='train',
|
| 23 |
+
ROOT="${root_vkitti2}", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294),
|
| 24 |
+
(518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views},
|
| 25 |
+
n_corres=${n_corres_train})
|
| 26 |
+
dataset_kitti360: KITTI360_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_kitti360}",
|
| 27 |
+
velodyne_root="${root_kitti360_velo}", aug_crop=16, resolution=[(518, 392), (518,
|
| 28 |
+
336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter,
|
| 29 |
+
num_views=${num_views}, n_corres=${n_corres_train})
|
| 30 |
+
dataset_waymo: Waymo_v2_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_waymo}",
|
| 31 |
+
lidar_root="${root_waymo_lidar}", aug_crop=16, resolution=[(518, 392), (518, 336),
|
| 32 |
+
(518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views},
|
| 33 |
+
n_corres=${n_corres_train})
|
| 34 |
+
train_dataset: 6000 @ ${dataset_vkitti2} + 6000 @ ${dataset_kitti360} + 5400 @ ${dataset_waymo}
|
| 35 |
+
test_dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="${root_vkitti2}", resolution=(518,
|
| 36 |
+
154), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
|
| 37 |
+
seed: 0
|
| 38 |
+
batch_size: 1
|
| 39 |
+
accum_iter: 1
|
| 40 |
+
gradient_checkpointing: false
|
| 41 |
+
epochs: 10
|
| 42 |
+
start_epoch: 0
|
| 43 |
+
start_step: 0
|
| 44 |
+
weight_decay: 0.05
|
| 45 |
+
lr: 1.0e-05
|
| 46 |
+
min_lr: 1.0e-08
|
| 47 |
+
warmup_epochs: 0.5
|
| 48 |
+
amp: 1
|
| 49 |
+
num_workers: 4
|
| 50 |
+
world_size: 1
|
| 51 |
+
local-rank: -1
|
| 52 |
+
dist_url: env://
|
| 53 |
+
rank: 0
|
| 54 |
+
gpu: 0
|
| 55 |
+
distributed: false
|
| 56 |
+
dist_backend: nccl
|
| 57 |
+
eval_freq: 1
|
| 58 |
+
save_freq: 0.1
|
| 59 |
+
max_checkpoints: 10
|
| 60 |
+
keep_freq: 1
|
| 61 |
+
print_freq: 10
|
| 62 |
+
print_img_freq: 50000000
|
| 63 |
+
num_imgs_vis: 4
|
| 64 |
+
save_dir: /scratch-shared/wwei2/training_upstream/checkpoints
|
| 65 |
+
exp_name: outdoor_v48_16gpu_v2
|
| 66 |
+
task: StreamVGGT
|
| 67 |
+
logdir: ${save_dir}/${exp_name}/logs
|
| 68 |
+
output_dir: ${save_dir}/${exp_name}/
|
outdoor_v48_16gpu_v2/.hydra/hydra.yaml
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
run:
|
| 3 |
+
dir: ${save_dir}/${exp_name}
|
| 4 |
+
sweep:
|
| 5 |
+
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 6 |
+
subdir: ${hydra.job.num}
|
| 7 |
+
launcher:
|
| 8 |
+
_target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
|
| 9 |
+
sweeper:
|
| 10 |
+
_target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
|
| 11 |
+
max_batch_size: null
|
| 12 |
+
params: null
|
| 13 |
+
help:
|
| 14 |
+
app_name: ${hydra.job.name}
|
| 15 |
+
header: '${hydra.help.app_name} is powered by Hydra.
|
| 16 |
+
|
| 17 |
+
'
|
| 18 |
+
footer: 'Powered by Hydra (https://hydra.cc)
|
| 19 |
+
|
| 20 |
+
Use --hydra-help to view Hydra specific help
|
| 21 |
+
|
| 22 |
+
'
|
| 23 |
+
template: '${hydra.help.header}
|
| 24 |
+
|
| 25 |
+
== Configuration groups ==
|
| 26 |
+
|
| 27 |
+
Compose your configuration from those groups (group=option)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
$APP_CONFIG_GROUPS
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
== Config ==
|
| 34 |
+
|
| 35 |
+
Override anything in the config (foo.bar=value)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
$CONFIG
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
${hydra.help.footer}
|
| 42 |
+
|
| 43 |
+
'
|
| 44 |
+
hydra_help:
|
| 45 |
+
template: 'Hydra (${hydra.runtime.version})
|
| 46 |
+
|
| 47 |
+
See https://hydra.cc for more info.
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
== Flags ==
|
| 51 |
+
|
| 52 |
+
$FLAGS_HELP
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
== Configuration groups ==
|
| 56 |
+
|
| 57 |
+
Compose your configuration from those groups (For example, append hydra/job_logging=disabled
|
| 58 |
+
to command line)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
$HYDRA_CONFIG_GROUPS
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
Use ''--cfg hydra'' to Show the Hydra config.
|
| 65 |
+
|
| 66 |
+
'
|
| 67 |
+
hydra_help: ???
|
| 68 |
+
hydra_logging:
|
| 69 |
+
version: 1
|
| 70 |
+
formatters:
|
| 71 |
+
simple:
|
| 72 |
+
format: '[%(asctime)s][HYDRA] %(message)s'
|
| 73 |
+
handlers:
|
| 74 |
+
console:
|
| 75 |
+
class: logging.StreamHandler
|
| 76 |
+
formatter: simple
|
| 77 |
+
stream: ext://sys.stdout
|
| 78 |
+
root:
|
| 79 |
+
level: INFO
|
| 80 |
+
handlers:
|
| 81 |
+
- console
|
| 82 |
+
loggers:
|
| 83 |
+
logging_example:
|
| 84 |
+
level: DEBUG
|
| 85 |
+
disable_existing_loggers: false
|
| 86 |
+
job_logging:
|
| 87 |
+
version: 1
|
| 88 |
+
formatters:
|
| 89 |
+
simple:
|
| 90 |
+
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
|
| 91 |
+
handlers:
|
| 92 |
+
console:
|
| 93 |
+
class: logging.StreamHandler
|
| 94 |
+
formatter: simple
|
| 95 |
+
stream: ext://sys.stdout
|
| 96 |
+
file:
|
| 97 |
+
class: logging.FileHandler
|
| 98 |
+
formatter: simple
|
| 99 |
+
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
|
| 100 |
+
root:
|
| 101 |
+
level: INFO
|
| 102 |
+
handlers:
|
| 103 |
+
- console
|
| 104 |
+
- file
|
| 105 |
+
disable_existing_loggers: false
|
| 106 |
+
env: {}
|
| 107 |
+
mode: RUN
|
| 108 |
+
searchpath: []
|
| 109 |
+
callbacks: {}
|
| 110 |
+
output_subdir: .hydra
|
| 111 |
+
overrides:
|
| 112 |
+
hydra:
|
| 113 |
+
- hydra.mode=RUN
|
| 114 |
+
task:
|
| 115 |
+
- exp_name=outdoor_v48_16gpu_v2
|
| 116 |
+
- resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth
|
| 117 |
+
job:
|
| 118 |
+
name: mytrain
|
| 119 |
+
chdir: null
|
| 120 |
+
override_dirname: exp_name=outdoor_v48_16gpu_v2,resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth
|
| 121 |
+
id: ???
|
| 122 |
+
num: ???
|
| 123 |
+
config_name: outdoor_v48
|
| 124 |
+
env_set: {}
|
| 125 |
+
env_copy: []
|
| 126 |
+
config:
|
| 127 |
+
override_dirname:
|
| 128 |
+
kv_sep: '='
|
| 129 |
+
item_sep: ','
|
| 130 |
+
exclude_keys: []
|
| 131 |
+
runtime:
|
| 132 |
+
version: 1.3.2
|
| 133 |
+
version_base: '1.3'
|
| 134 |
+
cwd: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src
|
| 135 |
+
config_sources:
|
| 136 |
+
- path: hydra.conf
|
| 137 |
+
schema: pkg
|
| 138 |
+
provider: hydra
|
| 139 |
+
- path: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/config
|
| 140 |
+
schema: file
|
| 141 |
+
provider: main
|
| 142 |
+
- path: ''
|
| 143 |
+
schema: structured
|
| 144 |
+
provider: schema
|
| 145 |
+
output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_16gpu_v2
|
| 146 |
+
choices:
|
| 147 |
+
hydra/env: default
|
| 148 |
+
hydra/callbacks: null
|
| 149 |
+
hydra/job_logging: default
|
| 150 |
+
hydra/hydra_logging: default
|
| 151 |
+
hydra/hydra_help: default
|
| 152 |
+
hydra/help: default
|
| 153 |
+
hydra/sweeper: basic
|
| 154 |
+
hydra/launcher: basic
|
| 155 |
+
hydra/output: default
|
| 156 |
+
verbose: true
|
outdoor_v48_16gpu_v2/.hydra/overrides.yaml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- exp_name=outdoor_v48_16gpu_v2
|
| 2 |
+
- resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/__init__.py
ADDED
|
File without changes
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/base_multiview_dataset.py
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import PIL
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import random
|
| 5 |
+
import itertools
|
| 6 |
+
from dust3r.datasets.base.easy_dataset import EasyDataset
|
| 7 |
+
from dust3r.datasets.utils.transforms import ImgNorm, SeqColorJitter
|
| 8 |
+
from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
|
| 9 |
+
import dust3r.datasets.utils.cropping as cropping
|
| 10 |
+
from dust3r.datasets.utils.corr import extract_correspondences_from_pts3d
|
| 11 |
+
|
| 12 |
+
from vggt.train_utils.augmentation import get_image_augmentation
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_ray_map(c2w1, c2w2, intrinsics, h, w):
    """Build an (h, w, 6) ray map of the second pose relative to the first.

    Args:
        c2w1: 4x4 pose matrix of the reference view; it is inverted here.
        c2w2: 4x4 pose matrix of the view the rays are generated for.
        intrinsics: 3x3 matrix; its inverse unprojects homogeneous pixels.
        h: number of pixel rows.
        w: number of pixel columns.

    Returns:
        Array of shape (h, w, 6). Channels 0-2 repeat the translation column
        of ``inv(c2w1) @ c2w2`` for every pixel; channels 3-5 hold the
        unit-length per-pixel vector.
    """
    relative_pose = np.linalg.inv(c2w1) @ c2w2

    # Homogeneous pixel coordinates (x, y, 1), flattened row by row.
    xs, ys = np.meshgrid(np.arange(w), np.arange(h), indexing="xy")
    pixels = np.stack([xs, ys, np.ones_like(xs)], axis=-1).reshape(-1, 3)

    # Unproject, then lift to 4-vectors with w = 1 so the full relative pose
    # (rotation and translation) is applied before normalisation.
    unprojected = np.linalg.inv(intrinsics) @ pixels.T
    lifted = np.vstack([unprojected, np.ones_like(unprojected[0])])
    directions = (relative_pose @ lifted).T[:, :3].reshape(h, w, 3)
    directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True)

    origins = np.broadcast_to(relative_pose[:3, 3], (h, w, 3))
    return np.concatenate([origins, directions], axis=-1)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class BaseMultiViewDataset(EasyDataset):
|
| 30 |
+
"""Define all basic options.
|
| 31 |
+
|
| 32 |
+
Usage:
|
| 33 |
+
class MyDataset (BaseMultiViewDataset):
|
| 34 |
+
def _get_views(self, idx, rng):
|
| 35 |
+
# overload here
|
| 36 |
+
views = []
|
| 37 |
+
views.append(dict(img=, ...))
|
| 38 |
+
return views
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(
    self,
    *,  # only keyword arguments
    num_views=None,
    split=None,
    resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
    transform=ImgNorm,
    aug_crop=False,
    n_corres=0,
    nneg=0,
    seed=None,
    allow_repeat=False,
    seq_aug_crop=False,
):
    """Store the dataset options and build the image augmentation.

    Args:
        num_views: required; number of views per sample.
        split: stored as ``self.split``.
        resolution: forwarded to ``self._set_resolutions``.
        transform: image transform, or a string that is evaluated to one.
            Passing the ``SeqColorJitter`` class itself instantiates it and
            sets ``self.is_seq_color_jitter``.
        aug_crop: stored as ``self.aug_crop``.
        n_corres: ``"all"``, an int, or a list with ``num_views`` entries.
        nneg: must be 0 when ``n_corres`` is ``"all"``.
        seed: stored as ``self.seed``.
        allow_repeat: stored as ``self.allow_repeat``.
        seq_aug_crop: stored as ``self.seq_aug_crop``.

    Raises:
        AssertionError: if ``num_views`` is missing or ``n_corres``/``nneg``
            violate the constraints above.
    """
    assert num_views is not None, "undefined num_views"
    self.num_views = num_views
    self.split = split
    self._set_resolutions(resolution)

    self.n_corres = n_corres
    self.nneg = nneg
    assert (
        n_corres == "all"
        or isinstance(n_corres, int)
        or (isinstance(n_corres, list) and len(n_corres) == num_views)
    ), f"Error, n_corres should either be 'all', a single integer or a list of length {self.num_views}"
    assert (
        nneg == 0 or n_corres != "all"
    ), "nneg should be 0 if n_corres is all"

    if isinstance(transform, str):
        # NOTE(review): eval() executes arbitrary code; only safe while the
        # transform string comes from trusted configuration.
        transform = eval(transform)
    if transform == SeqColorJitter:
        # The class (not an instance) was given: instantiate it and remember
        # that sequence-level colour jitter is in use.
        self.transform = SeqColorJitter()
        self.is_seq_color_jitter = True
    else:
        self.transform = transform
        self.is_seq_color_jitter = False

    jitter_settings = {
        'brightness': 0.5,
        'contrast': 0.5,
        'saturation': 0.5,
        'hue': 0.1,
        'p': 0.9,
    }
    self.image_aug = get_image_augmentation(
        color_jitter=jitter_settings,
        gray_scale=True,
        gau_blur=False,
    )

    self.aug_crop = aug_crop
    self.seed = seed
    self.allow_repeat = allow_repeat
    self.seq_aug_crop = seq_aug_crop
|
| 97 |
+
|
| 98 |
+
def __len__(self):
|
| 99 |
+
return len(self.scenes)
|
| 100 |
+
|
| 101 |
+
@staticmethod
|
| 102 |
+
def efficient_random_intervals(
|
| 103 |
+
start,
|
| 104 |
+
num_elements,
|
| 105 |
+
interval_range,
|
| 106 |
+
fixed_interval_prob=0.8,
|
| 107 |
+
weights=None,
|
| 108 |
+
seed=42,
|
| 109 |
+
):
|
| 110 |
+
if random.random() < fixed_interval_prob:
|
| 111 |
+
intervals = random.choices(interval_range, weights=weights) * (
|
| 112 |
+
num_elements - 1
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
intervals = [
|
| 116 |
+
random.choices(interval_range, weights=weights)[0]
|
| 117 |
+
for _ in range(num_elements - 1)
|
| 118 |
+
]
|
| 119 |
+
return list(itertools.accumulate([start] + intervals))
|
| 120 |
+
|
| 121 |
+
def sample_based_on_timestamps(self, i, timestamps, num_views, interval=1):
    """Group frames whose timestamp is close to frame ``i`` into view sets.

    Candidates are the indices whose timestamp differs from
    ``timestamps[i]`` by less than ``interval``. The result combines:

    * random disjoint groups of ``num_views`` candidates, plus one group
      that tops up the leftover candidates with already-used ones;
    * when ``self.allow_repeat`` is set, one group drawn with replacement;
    * evenly spaced groups grown alternately right and left of ``i`` for
      every spacing that fits inside the candidate window.

    Changes from the previous implementation:

    * The early-exit test was ``(allow_repeat and n < num_views // 3) or
      (n < num_views)``, which is just ``n < num_views``; the repeat branch
      and the later ``n >= num_views`` guards were unreachable. With
      ``allow_repeat`` the threshold is now ``num_views // 3`` (at least
      one candidate), otherwise ``num_views``.
    * ``num_views == 1`` no longer loops forever in the evenly spaced
      pass, and ``num_views < 1`` returns an empty list.

    Args:
        i: index of the anchor frame in ``timestamps``.
        timestamps: numpy array of per-frame timestamps.
        num_views: number of indices per returned group.
        interval: exclusive upper bound on the absolute time difference.

    Returns:
        List of groups; each group is a sorted list of ``num_views``
        indices into ``timestamps``. Empty when there are too few
        candidates.
    """
    if num_views < 1:
        return []

    time_diffs = np.abs(timestamps - timestamps[i])
    ids_candidate = np.sort(np.where(time_diffs < interval)[0])
    n_candidates = len(ids_candidate)

    # Sampling with replacement only needs a fraction of num_views distinct
    # frames; sampling without replacement needs all of them.
    if self.allow_repeat:
        min_candidates = max(num_views // 3, 1)
    else:
        min_candidates = num_views
    if n_candidates < min_candidates:
        return []

    ids_sel_list = []

    # Disjoint random groups until fewer than num_views candidates remain.
    ids_candidate_left = ids_candidate.copy()
    while len(ids_candidate_left) >= num_views:
        ids_sel = np.random.choice(ids_candidate_left, num_views, replace=False)
        ids_sel_list.append(sorted(ids_sel))
        ids_candidate_left = np.setdiff1d(ids_candidate_left, ids_sel)

    # Top up the leftovers with already-used candidates so they appear too.
    if len(ids_candidate_left) > 0 and n_candidates >= num_views:
        filler = np.random.choice(
            np.setdiff1d(ids_candidate, ids_candidate_left),
            num_views - len(ids_candidate_left),
            replace=False,
        )
        ids_sel_list.append(sorted(np.concatenate([ids_candidate_left, filler])))

    if self.allow_repeat:
        ids_sel_list.append(
            sorted(np.random.choice(ids_candidate, num_views, replace=True))
        )

    # Evenly spaced groups around i, for every spacing that still fits.
    if n_candidates >= num_views:
        pos_i = np.where(ids_candidate == i)[0][0]
        spacing = 1
        while True:
            pos_sel = [pos_i]
            grow_right = True
            fits = True
            while len(pos_sel) < num_views:
                if grow_right:
                    next_pos = pos_sel[-1] + spacing
                else:
                    next_pos = pos_sel[0] - spacing
                if next_pos < 0 or next_pos >= n_candidates:
                    fits = False
                    break
                if grow_right:
                    pos_sel.append(next_pos)
                else:
                    pos_sel.insert(0, next_pos)
                grow_right = not grow_right
            if not fits:
                break
            ids_sel = sorted(ids_candidate[pos] for pos in pos_sel)
            if ids_sel not in ids_sel_list:
                ids_sel_list.append(ids_sel)
            if num_views == 1:
                # A single-view group does not depend on the spacing, so
                # larger spacings can never fail to fit; stop after one.
                break
            spacing += 1
    return ids_sel_list
|
| 181 |
+
|
| 182 |
+
@staticmethod
|
| 183 |
+
def blockwise_shuffle(x, rng, block_shuffle):
|
| 184 |
+
if block_shuffle is None:
|
| 185 |
+
return rng.permutation(x).tolist()
|
| 186 |
+
else:
|
| 187 |
+
assert block_shuffle > 0
|
| 188 |
+
blocks = [x[i : i + block_shuffle] for i in range(0, len(x), block_shuffle)]
|
| 189 |
+
shuffled_blocks = [rng.permutation(block).tolist() for block in blocks]
|
| 190 |
+
shuffled_list = [item for block in shuffled_blocks for item in block]
|
| 191 |
+
return shuffled_list
|
| 192 |
+
|
| 193 |
+
def get_seq_from_start_id(
    self,
    num_views,
    id_ref,
    ids_all,
    rng,
    min_interval=1,
    max_interval=25,
    video_prob=0.5,
    fix_interval_prob=0.5,
    block_shuffle=None,
):
    """Pick ``num_views`` positions in ``ids_all`` starting from ``id_ref``.

    Args:
        num_views: number of positions to return.
        id_ref: reference id; must be an element of ``ids_all``.
        ids_all: list of all ids (``list.index`` is used on it).
        rng: numpy random Generator supplying ``choice`` and ``random``.
        min_interval: smallest gap between consecutive positions (> 0).
        max_interval: largest gap; clamped to what the sequence allows.
        video_prob: probability of producing an ordered ("video") sample.
        fix_interval_prob: probability that a video uses one constant gap.
        block_shuffle: forwarded to ``blockwise_shuffle`` for unordered
            samples.

    Returns:
        ``(pos, is_video)``: ``pos`` holds indices into ``ids_all``;
        ``is_video`` tells whether they are in temporal order. In the
        repeat branch ``pos`` can be a numpy array instead of a list.

    Raises:
        AssertionError: on invalid intervals or unknown ``id_ref``.
        ZeroDivisionError: when ``num_views == 1`` and frames remain after
            the reference, or when exactly one frame remains after the
            reference and more than two views are requested.
    """
    assert min_interval > 0, f"min_interval should be > 0, got {min_interval}"
    assert (
        min_interval <= max_interval
    ), f"min_interval should be <= max_interval, got {min_interval} and {max_interval}"
    assert id_ref in ids_all

    total = len(ids_all)
    pos_ref = ids_all.index(id_ref)
    remaining_sum = total - 1 - pos_ref  # frames strictly after the reference
    gap_count = num_views - 1

    if remaining_sum == gap_count:
        # Exactly enough frames left: the only choice is the consecutive run.
        assert ids_all[-num_views] == id_ref
        return [pos_ref + offset for offset in range(num_views)], True

    if remaining_sum > gap_count:
        # Enough distinct frames: sample without repetition.
        max_interval = min(max_interval, 2 * remaining_sum // gap_count)
        gaps = [
            rng.choice(range(min_interval, max_interval + 1))
            for _ in range(gap_count)
        ]

        is_video = bool(rng.random() < video_prob)
        if is_video and rng.random() < fix_interval_prob:
            # Constant gap, small enough that every view stays in range.
            gap_limit = min(remaining_sum // gap_count + 1, max_interval + 1)
            fixed_interval = rng.choice(range(1, gap_limit))
            gaps = [fixed_interval] * gap_count

        # Drop positions that ran past the end and refill from unused ones.
        pos = [p for p in itertools.accumulate([pos_ref] + gaps) if p < total]
        pos_candidates = [p for p in np.arange(pos_ref, total) if p not in pos]
        refill = rng.choice(
            pos_candidates, num_views - len(pos), replace=False
        ).tolist()
        pos = pos + refill

        if is_video:
            pos = sorted(pos)
        else:
            pos = self.blockwise_shuffle(pos, rng, block_shuffle)
    else:
        # Too few frames after the reference: build a short sequence from an
        # earlier start and repeat / resample it up to num_views.
        uniq_num = remaining_sum
        new_pos_ref = rng.choice(np.arange(pos_ref + 1))
        new_remaining_sum = total - 1 - new_pos_ref
        new_max_interval = min(max_interval, new_remaining_sum // (uniq_num - 1))
        new_gaps = [
            rng.choice(range(1, new_max_interval + 1)) for _ in range(uniq_num - 1)
        ]

        revisit_random = rng.random()
        video_random = rng.random()

        if rng.random() < fix_interval_prob and video_random < video_prob:
            fixed_interval = rng.choice(range(1, new_max_interval + 1))
            new_gaps = [fixed_interval] * (uniq_num - 1)
        pos = list(itertools.accumulate([new_pos_ref] + new_gaps))

        is_video = False
        if revisit_random < 0.5 or video_prob == 1.0:
            # Revisit: loop over the short sequence until num_views is reached.
            is_video = video_random < video_prob
            if not is_video:
                pos = self.blockwise_shuffle(pos, rng, block_shuffle)
            num_full_repeat = num_views // uniq_num
            tail_len = num_views - len(pos) * num_full_repeat
            pos = pos * num_full_repeat + pos[:tail_len]
        elif revisit_random < 0.9:
            # Unordered draw with replacement.
            pos = rng.choice(pos, num_views, replace=True)
        else:
            # Ordered draw with replacement.
            pos = sorted(rng.choice(pos, num_views, replace=True))

    assert len(pos) == num_views
    return pos, is_video
|
| 306 |
+
|
| 307 |
+
def get_img_and_ray_masks(self, is_metric, v, rng, p=[0.8, 0.15, 0.05]):
    """Choose the image mask and raymap mask flags for view ``v``.

    Args:
        is_metric: when falsy, the result is always ``(True, False)``.
        v: view index; view 0 always yields ``(True, False)``.
        rng: object with a ``random()`` method returning a float.
        p: thresholds; ``p[0]`` is the share of ``(True, False)`` and
            ``p[1]`` the share of ``(False, True)``; everything above
            ``p[0] + p[1]`` yields ``(True, True)``. ``p[2]`` is not read
            and the default list is never modified.

    Returns:
        Tuple ``(img_mask, raymap_mask)`` of booleans.
    """
    if v == 0 or not is_metric:
        return True, False

    draw = rng.random()
    if draw < p[0]:
        return True, False
    if draw < p[0] + p[1]:
        return False, True
    return True, True
|
| 324 |
+
|
| 325 |
+
def get_stats(self):
    """Return a short summary string based on ``len(self)``."""
    group_count = len(self)
    return f"{group_count} groups of views"
|
| 327 |
+
|
| 328 |
+
def __repr__(self):
|
| 329 |
+
resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
|
| 330 |
+
return (
|
| 331 |
+
f"""{type(self).__name__}({self.get_stats()},
|
| 332 |
+
{self.num_views=},
|
| 333 |
+
{self.split=},
|
| 334 |
+
{self.seed=},
|
| 335 |
+
resolutions={resolutions_str},
|
| 336 |
+
{self.transform=})""".replace(
|
| 337 |
+
"self.", ""
|
| 338 |
+
)
|
| 339 |
+
.replace("\n", "")
|
| 340 |
+
.replace(" ", "")
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
def _get_views(self, idx, resolution, rng, num_views):
|
| 344 |
+
raise NotImplementedError()
|
| 345 |
+
|
| 346 |
+
def __getitem__(self, idx):
    """Load one group of views and attach all derived per-view data.

    Args:
        idx: either a plain sample index, or a ``(sample_idx, ar_idx, nview)``
            triple (tuple, list or ndarray) selecting the resolution entry
            and the number of views. A plain index requires that exactly one
            resolution is configured and uses ``self.num_views`` views.

    Returns:
        The list of ``nview`` view dicts returned by ``self._get_views``,
        each extended in place with ``idx``, ``true_shape``, the transformed
        ``img``, ``sky_mask``, ``camera_pose`` (NaN-filled when absent),
        ``ray_map``, ``pts3d``, ``pts3d_local``, ``valid_mask``, optionally
        ``corres``/``valid_corres`` (when ``self.n_corres > 0``) and ``rng``.

    Raises:
        AssertionError: when a view violates one of the sanity checks below.
    """
    if isinstance(idx, (tuple, list, np.ndarray)):
        # the idx is specifying the aspect-ratio
        idx, ar_idx, nview = idx
    else:
        assert len(self._resolutions) == 1
        ar_idx = 0
        nview = self.num_views

    assert 1 <= nview <= self.num_views

    # set-up the rng
    # NOTE: the truthiness test means seed=0 behaves like "no seed".
    if self.seed:  # reseed for each __getitem__
        self._rng = np.random.default_rng(seed=self.seed + idx)
    elif not hasattr(self, "_rng"):
        seed = torch.randint(0, 2**32, (1,)).item()
        self._rng = np.random.default_rng(seed=seed)

    if self.aug_crop > 1 and self.seq_aug_crop:
        # One shared crop enlargement for the whole sequence; it is read
        # back by _crop_resize_if_necessary.
        self.delta_target_resolution = self._rng.integers(0, self.aug_crop)

    # over-loaded code
    # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
    resolution = self._resolutions[ar_idx]
    views = self._get_views(idx, resolution, self._rng, nview)
    assert len(views) == nview

    if "camera_pose" not in views[0]:
        # NOTE(review): the placeholder is an all-ones matrix, not an
        # identity -- confirm this is what get_ray_map expects.
        views[0]["camera_pose"] = np.ones((4, 4), dtype=np.float32)
    first_view_camera_pose = views[0]["camera_pose"]
    transform = SeqColorJitter() if self.is_seq_color_jitter else self.transform

    for v, view in enumerate(views):
        assert (
            "pts3d" not in view
        ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
        view["idx"] = (idx, ar_idx, v)

        # encode the image
        width, height = view["img"].size

        view["true_shape"] = np.int32((height, width))
        view["img"] = transform(view["img"])
        # Negative depth values are flagged as sky.
        view["sky_mask"] = view["depthmap"] < 0

        assert "camera_intrinsics" in view
        if "camera_pose" not in view:
            view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
        else:
            assert np.isfinite(
                view["camera_pose"]
            ).all(), f"NaN in camera pose for view {view_name(view)}"

        ray_map = get_ray_map(
            first_view_camera_pose,
            view["camera_pose"],
            view["camera_intrinsics"],
            height,
            width,
        )
        view["ray_map"] = ray_map.astype(np.float32)

        assert "valid_mask" not in view
        assert np.isfinite(
            view["depthmap"]
        ).all(), f"NaN in depthmap for view {view_name(view)}"
        pts3d, pts3d_local, valid_mask = depthmap_to_absolute_camera_coordinates(**view)

        view["pts3d"] = pts3d
        view["pts3d_local"] = pts3d_local
        view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)

        # check all datatypes
        for key, val in view.items():
            res, err_msg = is_good_type(key, val)
            assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"

    if self.n_corres > 0:
        # Correspondences are always extracted against the first view.
        ref_view = views[0]
        for view in views:
            corres1, corres2, valid = extract_correspondences_from_pts3d(
                ref_view, view, self.n_corres, self._rng, nneg=self.nneg
            )
            view["corres"] = (corres1, corres2)
            view["valid_corres"] = valid

    # last thing done!
    for view in views:
        view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
    return views
|
| 452 |
+
|
| 453 |
+
def _set_resolutions(self, resolutions):
    """Normalise ``resolutions`` into ``self._resolutions`` as (width, height) pairs.

    Args:
        resolutions: a single entry or a list of entries. An ``int`` entry
            means a square resolution; any other entry is unpacked as
            ``(width, height)``.

    Raises:
        AssertionError: if ``resolutions`` is None or a width/height is not
            an ``int``. Entries parsed before the failure stay stored.
    """
    assert resolutions is not None, "undefined resolution"

    entries = resolutions if isinstance(resolutions, list) else [resolutions]

    parsed = self._resolutions = []
    for entry in entries:
        width, height = (entry, entry) if isinstance(entry, int) else entry
        assert isinstance(
            width, int
        ), f"Bad type for {width=} {type(width)=}, should be int"
        assert isinstance(
            height, int
        ), f"Bad type for {height=} {type(height)=}, should be int"
        parsed.append((width, height))
|
| 472 |
+
|
| 473 |
+
def _crop_resize_if_necessary(
    self, image, depthmap, intrinsics, resolution, rng=None, info=None
):
    """Crop around the principal point, rescale, then crop to ``resolution``.

    Steps:
      1. Crop the largest window centred on the (rounded) principal point.
      2. Rescale image, depthmap and intrinsics to the target resolution via
         ``cropping.rescale_image_depthmap`` (the original comments describe
         this as a high-quality Lanczos down-scaling). When
         ``self.aug_crop > 1`` the target is enlarged by a random integer in
         ``[0, self.aug_crop)``: drawn from ``rng`` per call, or taken from
         ``self.delta_target_resolution`` when ``self.seq_aug_crop`` is set.
      3. Crop the rescaled data down to exactly ``resolution``.

    Args:
        image: a ``PIL.Image.Image`` or an array accepted by
            ``PIL.Image.fromarray``.
        depthmap: depth map matching ``image``.
        intrinsics: camera matrix; ``intrinsics[:2, 2]`` is read as the
            principal point.
        resolution: requested output size, forwarded to the cropping helpers.
        rng: random generator; must be provided when ``self.aug_crop > 1``
            and ``self.seq_aug_crop`` is falsy.
        info: free-form identifier, only used in assertion messages.

    Returns:
        Tuple ``(image, depthmap, intrinsics)`` after cropping and rescaling.

    Raises:
        AssertionError: if the principal point is closer than a fifth of the
            image width/height to a border.
    """
    if not isinstance(image, PIL.Image.Image):
        image = PIL.Image.fromarray(image)

    # cropping centered on the principal point
    W, H = image.size
    cx, cy = intrinsics[:2, 2].round().astype(int)
    min_margin_x = min(cx, W - cx)
    min_margin_y = min(cy, H - cy)
    assert min_margin_x > W / 5, f"Bad principal point in view={info}"
    assert min_margin_y > H / 5, f"Bad principal point in view={info}"
    # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
    left, top = cx - min_margin_x, cy - min_margin_y
    right, bottom = cx + min_margin_x, cy + min_margin_y
    crop_bbox = (left, top, right, bottom)
    image, depthmap, intrinsics = cropping.crop_image_depthmap(
        image, depthmap, intrinsics, crop_bbox
    )

    # rescale so that the image covers the (possibly enlarged) target size;
    # np.array() copies, so `resolution` itself is never modified
    target_resolution = np.array(resolution)
    if self.aug_crop > 1:
        target_resolution += (
            rng.integers(0, self.aug_crop)
            if not self.seq_aug_crop
            else self.delta_target_resolution
        )
    image, depthmap, intrinsics = cropping.rescale_image_depthmap(
        image, depthmap, intrinsics, target_resolution
    )

    # actual cropping (if necessary)
    intrinsics2 = cropping.camera_matrix_of_crop(
        intrinsics, image.size, resolution, offset_factor=0.5
    )
    crop_bbox = cropping.bbox_from_intrinsics_in_out(
        intrinsics, intrinsics2, resolution
    )
    image, depthmap, intrinsics2 = cropping.crop_image_depthmap(
        image, depthmap, intrinsics, crop_bbox
    )

    return image, depthmap, intrinsics2
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def is_good_type(key, v):
    """Check that a view entry has an accepted type.

    Args:
        key: name of the entry; not used by the check itself.
        v: value to validate.

    Returns:
        ``(True, None)`` when ``v`` is a ``str``, ``int`` (which includes
        ``bool``) or ``tuple``, or when ``v.dtype`` is one of float32
        (numpy or torch), bool, int32, int64 or uint8. Otherwise
        ``(False, err_msg)``; this now also covers values that have no
        ``dtype`` attribute at all, which previously raised AttributeError.
    """
    if isinstance(v, (str, int, tuple)):
        return True, None
    dtype = getattr(v, "dtype", None)
    if dtype is None:
        # e.g. a Python float, None or a list: report it instead of crashing
        return False, f"bad type {type(v).__name__} (no dtype)"
    if dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
        return False, f"bad {v.dtype=}"
    return True, None
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def view_name(view, batch_index=None):
    """Format a view as ``"<dataset>/<label>/<instance>"``.

    When ``batch_index`` is neither ``None`` nor ``slice(None)``, each of the
    three fields is indexed with it first; otherwise the fields are used
    as they are.
    """

    def pick(field):
        if batch_index not in (None, slice(None)):
            return field[batch_index]
        return field

    parts = [pick(view[name]) for name in ("dataset", "label", "instance")]
    return "{}/{}/{}".format(*parts)
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def transpose_to_landscape(view):
    """Transpose a portrait view to landscape, in place.

    Nothing happens unless ``view["true_shape"]`` (read as
    ``(height, width)``) has ``width < height``. In that case the two
    spatial axes of ``img``, ``valid_mask``, ``depthmap``, ``pts3d``,
    ``ray_map`` and ``sky_mask`` are swapped, the first two rows of
    ``camera_intrinsics`` are exchanged and, when present, the two columns
    of each ``corres`` array are exchanged. ``true_shape`` itself is left
    untouched.

    ``corres`` may be a tuple (as stored by the dataset's ``__getitem__``)
    or a list; a list is updated in place, any other container is replaced
    by a new tuple. Previously a tuple raised TypeError on item assignment.

    Raises:
        AssertionError: if an array does not have the shape implied by
            ``true_shape``.
    """
    height, width = view["true_shape"]

    if width >= height:
        return

    def swap(key, expected_shape, axis_a, axis_b):
        # Validate the stored shape before exchanging the spatial axes.
        assert view[key].shape == expected_shape
        view[key] = view[key].swapaxes(axis_a, axis_b)

    # rectify portrait to landscape
    swap("img", (3, height, width), 1, 2)
    swap("valid_mask", (height, width), 0, 1)
    swap("depthmap", (height, width), 0, 1)
    swap("pts3d", (height, width, 3), 0, 1)

    # transpose x and y pixels
    view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]

    swap("ray_map", (height, width, 6), 0, 1)
    swap("sky_mask", (height, width), 0, 1)

    if "corres" in view:
        # transpose correspondences x and y
        corres = view["corres"]
        first = corres[0][:, [1, 0]]
        second = corres[1][:, [1, 0]]
        if isinstance(corres, list):
            corres[0], corres[1] = first, second
        else:
            view["corres"] = (first, second) + tuple(corres[2:])
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/batched_sampler.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from accelerate import Accelerator
|
| 4 |
+
import torch.utils
|
| 5 |
+
from torch.utils.data import BatchSampler, Sampler
|
| 6 |
+
import torch.utils.data
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CustomRandomSampler(Sampler):
    """Random sampler whose consecutive runs of ``batch_size`` indices share settings.

    Every yielded index is a 3-tuple ``(sample_idx, feat_idx, view_count)``:

    * ``sample_idx`` comes from a shuffled ``arange(len(dataset))``;
    * ``feat_idx`` is drawn once per batch from ``range(pool_size)`` (for
      instance an image aspect-ratio index); with more than one entry in the
      pool, the first ``pool_size // 2`` entries are twice as likely;
    * ``view_count`` is drawn once per batch from
      ``[min_view_size, max_view_size]`` (both inclusive).

    The random stream only depends on the epoch set via ``set_epoch``.
    ``world_size`` and ``warmup`` are accepted but not used here.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        pool_size,
        min_view_size,
        max_view_size,
        world_size,
        warmup=1,
        drop_last=True,
    ):
        dataset_len = len(dataset)
        self.batch_size = batch_size
        self.pool_size = pool_size
        self.min_view_size = min_view_size
        self.max_view_size = max_view_size
        self.drop_last = drop_last
        self.len_dataset = dataset_len
        self.total_size = dataset_len
        self.epoch = None
        self.epochf = 0.0

    def __len__(self):
        """Number of index tuples produced per epoch."""
        return self.total_size

    def set_epoch(self, epoch):
        """Select the epoch that seeds the next iteration."""
        self.epoch = epoch

    def __iter__(self):
        """Yield ``total_size`` index tuples for the current epoch.

        Raises:
            ValueError: if ``set_epoch`` has not been called yet.
        """
        if self.epoch is None:
            raise ValueError(
                "Epoch number not set. Please call 'set_epoch(epoch)' before iterating."
            )

        total = self.total_size
        per_batch = self.batch_size
        pool = self.pool_size

        rng = np.random.default_rng(seed=self.epoch + 788)

        # shuffled sample indices
        order = np.arange(total)
        rng.shuffle(order)

        # one feature index per batch (ceil division gives the batch count)
        n_batches = (total + per_batch - 1) // per_batch
        if pool > 1:
            weights = np.ones(pool)
            weights[: pool // 2] *= 2
            weights = weights / weights.sum()
            batch_feats = rng.choice(pool, size=n_batches, p=weights)
        else:
            batch_feats = rng.integers(pool, size=n_batches)

        # one view count per batch
        batch_views = rng.integers(
            self.min_view_size, self.max_view_size + 1, size=n_batches
        )

        # repeat the per-batch draws for every slot, then trim the last batch
        feat_column = np.repeat(batch_feats, per_batch)[:total]
        view_column = np.repeat(batch_views, per_batch)[:total]

        for row in np.c_[order, feat_column, view_column]:
            yield tuple(row)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class BatchedRandomSampler(BatchSampler):
    """Batch sampler grouping the indices of a wrapped sampler into batches.

    Iteration and length come from ``BatchSampler``; this subclass only
    stores its three settings and forwards ``set_epoch`` to the wrapped
    sampler.
    """

    def __init__(self, sampler: CustomRandomSampler, batch_size, drop_last=True):
        # BatchSampler.__init__ is not called; the attributes are set directly.
        self.sampler, self.batch_size, self.drop_last = (
            sampler,
            batch_size,
            drop_last,
        )

    def set_epoch(self, epoch):
        """Forward the epoch number to the wrapped sampler."""
        self.sampler.set_epoch(epoch)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def round_by(total, multiple, up=False):
    """Round ``total`` to a multiple of ``multiple``.

    Rounds down by default; with ``up`` set, ``multiple - 1`` is added
    first so the result is rounded up instead.
    """
    numerator = (total + multiple - 1) if up else total
    return (numerator // multiple) * multiple
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/easy_dataset.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# modified from DUSt3R
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from dust3r.datasets.base.batched_sampler import (
|
| 9 |
+
BatchedRandomSampler,
|
| 10 |
+
CustomRandomSampler,
|
| 11 |
+
)
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class EasyDataset:
    """a dataset that you can easily resize and combine.
    Examples:
    ---------
        2 * dataset ==> duplicate each element 2x

        10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)

        dataset1 + dataset2 ==> concatenate datasets
    """

    def __add__(self, other):
        """``self + other``: concatenate the two datasets."""
        return CatDataset([self, other])

    def __rmul__(self, factor):
        """``factor * self``: repeat every element ``factor`` times."""
        return MulDataset(factor, self)

    def __rmatmul__(self, factor):
        """``factor @ self``: resize the dataset to ``factor`` elements."""
        return ResizedDataset(factor, self)

    def set_epoch(self, epoch):
        """Hook called with the epoch number; does nothing by default."""
        pass  # nothing to do by default

    def make_sampler(
        self, batch_size, shuffle=True, drop_last=True, world_size=1, rank=0, fixed_length=False
    ):
        """Build the batch sampler used to iterate over this dataset.

        Args:
            batch_size: number of samples per batch.
            shuffle: must be truthy; unshuffled sampling is not supported.
            drop_last: forwarded to both samplers.
            world_size: forwarded to ``CustomRandomSampler``.
            rank: accepted but not used here.
            fixed_length: when truthy every batch uses ``self.num_views``
                views; otherwise the per-batch view count is drawn between
                ``min(4, self.num_views)`` and ``self.num_views``.

        Returns:
            A ``BatchedRandomSampler`` wrapping a ``CustomRandomSampler``.

        Raises:
            NotImplementedError: if ``shuffle`` is falsy.
        """
        if not (shuffle):
            raise NotImplementedError()  # cannot deal yet
        num_of_aspect_ratios = len(self._resolutions)
        num_of_views = self.num_views
        # Clamp the lower bound: with fewer than 4 views the former constant
        # 4 exceeded the upper bound and the sampler's draw failed.
        min_num_of_views = num_of_views if fixed_length else min(4, num_of_views)
        sampler = CustomRandomSampler(
            self,
            batch_size,
            num_of_aspect_ratios,
            min_num_of_views,
            num_of_views,
            world_size,
            warmup=1,
            drop_last=drop_last,
        )
        return BatchedRandomSampler(sampler, batch_size, drop_last)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class MulDataset(EasyDataset):
    """Artificially augmenting the size of a dataset.

    Element ``i`` of this dataset is element ``i // multiplicator`` of the
    wrapped dataset, so every wrapped element appears ``multiplicator``
    times in a row.
    """

    multiplicator: int

    def __init__(self, multiplicator, dataset):
        assert isinstance(multiplicator, int) and multiplicator > 0
        self.dataset = dataset
        self.multiplicator = multiplicator

    def __len__(self):
        return len(self.dataset) * self.multiplicator

    def __repr__(self):
        return "{}*{}".format(self.multiplicator, repr(self.dataset))

    def __getitem__(self, idx):
        """Fetch an element; a tuple index carries two extra fields that are passed through."""
        if not isinstance(idx, tuple):
            return self.dataset[idx // self.multiplicator]
        sample_idx, other, another = idx
        return self.dataset[sample_idx // self.multiplicator, other, another]

    @property
    def _resolutions(self):
        # resolutions are those of the wrapped dataset
        return self.dataset._resolutions

    @property
    def num_views(self):
        # number of views is that of the wrapped dataset
        return self.dataset.num_views
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class ResizedDataset(EasyDataset):
    """Artificially changing the size of a dataset.

    ``set_epoch`` must be called before indexing: it builds the mapping from
    the ``new_size`` positions of this dataset to indices of the wrapped one.
    """

    new_size: int

    def __init__(self, new_size, dataset):
        assert isinstance(new_size, int) and new_size > 0
        self.dataset = dataset
        self.new_size = new_size

    def __len__(self):
        return self.new_size

    def __repr__(self):
        # Split the size into groups of three characters counted from the
        # right and join them with "_" (e.g. 1234567 -> 1_234_567).
        text = str(self.new_size)
        head = len(text) % 3 or 3
        groups = [text[:head]]
        groups.extend(text[start : start + 3] for start in range(head, len(text), 3))
        size_str = "_".join(groups)
        return f"{size_str} @ {repr(self.dataset)}"

    def set_epoch(self, epoch):
        """Rebuild the index mapping; it depends on ``epoch`` only."""
        base_len = len(self.dataset)
        rng = np.random.default_rng(seed=epoch + 777)

        # one permutation of the wrapped dataset ...
        perm = rng.permutation(base_len)

        # ... repeated as often as needed to cover the target size
        repeats = 1 + (len(self) - 1) // base_len
        cycled = np.concatenate([perm for _ in range(repeats)])
        self._idxs_mapping = cycled[: self.new_size]

        assert len(self._idxs_mapping) == self.new_size

    def __getitem__(self, idx):
        """Fetch an element; a tuple index carries two extra fields that are passed through."""
        assert hasattr(
            self, "_idxs_mapping"
        ), "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()"
        if not isinstance(idx, tuple):
            return self.dataset[self._idxs_mapping[idx]]
        sample_idx, other, another = idx
        return self.dataset[self._idxs_mapping[sample_idx], other, another]

    @property
    def _resolutions(self):
        # resolutions are those of the wrapped dataset
        return self.dataset._resolutions

    @property
    def num_views(self):
        # number of views is that of the wrapped dataset
        return self.dataset.num_views
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class CatDataset(EasyDataset):
    """Concatenation of several datasets"""

    def __init__(self, datasets):
        for dataset in datasets:
            assert isinstance(dataset, EasyDataset)
        self.datasets = datasets
        # _cum_sizes[k] is the total length of datasets[0..k]
        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])

    def __len__(self):
        # plain int rather than a numpy scalar
        return int(self._cum_sizes[-1])

    def __repr__(self):
        # remove uselessly long transform
        return " + ".join(
            repr(dataset).replace(
                ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
                "",
            )
            for dataset in self.datasets
        )

    def set_epoch(self, epoch):
        """Forward the epoch number to every member dataset."""
        for dataset in self.datasets:
            dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        """Fetch element ``idx``, falling back to the following elements on failure.

        Args:
            idx: a global index, or an ``(index, other, another)`` tuple
                whose two extra fields are forwarded to the member dataset.

        Returns:
            Whatever the member dataset returns for the first index that
            loads successfully, starting at ``idx`` and wrapping around.

        Raises:
            IndexError: if the requested index is out of range.
            RuntimeError: if every index of the concatenation failed to
                load (previously this case looped forever).
        """
        other = another = None
        if isinstance(idx, tuple):
            idx, other, another = idx

        total = len(self)
        if not (0 <= idx < total):
            raise IndexError()

        last_error = None
        for _ in range(total):
            db_idx = np.searchsorted(self._cum_sizes, idx, "right")
            dataset = self.datasets[db_idx]
            new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)

            if other is not None and another is not None:
                new_idx = (new_idx, other, another)

            try:
                return dataset[new_idx]
            except Exception as e:
                # Best effort: report the broken sample and try the next one.
                print(e)
                print("DATA ERROR", new_idx)
                last_error = e
                idx = (idx + 1) % total

        raise RuntimeError(
            f"CatDataset: all {total} samples failed to load"
        ) from last_error

    @property
    def _resolutions(self):
        # all member datasets must agree on the resolutions
        resolutions = self.datasets[0]._resolutions
        for dataset in self.datasets[1:]:
            assert tuple(dataset._resolutions) == tuple(resolutions)
        return resolutions

    @property
    def num_views(self):
        # all member datasets must agree on the number of views
        num_views = self.datasets[0].num_views
        for dataset in self.datasets[1:]:
            assert dataset.num_views == num_views
        return num_views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/dynamic_replica.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 11 |
+
from dust3r.utils.image import imread_cv2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DynamicReplica(BaseMultiViewDataset):
    """Multi-view dataset reading ``<ROOT>/<split>/<scene>/left/{rgb,depth,cam}``.

    Each scene contributes its ``.png`` frames (ordered by the numeric value
    of their basename); a sample is a start frame plus the sequence that
    ``get_seq_from_start_id`` derives from it.
    """

    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 16
        super().__init__(*args, **kwargs)

        # _load_data fills the index attributes and returns None.
        self.loaded_data = self._load_data(self.split)

    def _load_data(self, split):
        """Scan the scenes of ``split`` and build the frame index tables.

        Sets ``scenes`` (kept scene names), ``images`` (frame basenames of
        all kept scenes, concatenated), ``sceneids`` (scene number of each
        frame), ``scene_img_list`` (global frame ids per scene) and
        ``start_img_ids`` (global ids usable as sequence starts). Scenes
        with fewer frames than the cut-off are skipped.
        """
        split_dir = osp.join(self.ROOT, split)

        # Minimum number of frames a scene needs; constant for all scenes.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        # os.listdir order is arbitrary; sort so that the index -> sample
        # mapping is the same on every machine and run.
        for scene in tqdm(sorted(os.listdir(split_dir))):
            scene_dir = osp.join(split_dir, scene, "left")
            rgb_dir = osp.join(scene_dir, "rgb")
            basenames = sorted(
                [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")],
                key=lambda x: float(x),
            )
            num_imgs = len(basenames)
            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue

            img_ids = list(np.arange(num_imgs) + offset)
            start_img_ids.extend(img_ids[: num_imgs - cut_off + 1])
            # len(scenes) is the number this scene gets once appended
            sceneids.extend([len(scenes)] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # offset groups
            offset += num_imgs

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load the ``num_views`` views of sample ``idx`` at ``resolution``.

        Returns:
            List of view dicts holding the cropped/resized image, float32
            depthmap, camera pose and intrinsics, plus the metadata flags
            set below.
        """
        start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=self.max_interval,
            video_prob=1.0,
            fix_interval_prob=1.0,
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id], "left")
            rgb_dir = osp.join(scene_dir, "rgb")
            depth_dir = osp.join(scene_dir, "depth")
            cam_dir = osp.join(scene_dir, "cam")

            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png"))
            # Load depthmap
            depthmap = np.load(osp.join(depth_dir, basename + ".npy"))
            depthmap[~np.isfinite(depthmap)] = 0  # invalid

            # Close the .npz archive as soon as both arrays are read.
            with np.load(osp.join(cam_dir, basename + ".npz")) as cam:
                camera_pose = cam["pose"]
                intrinsics = cam["intrinsics"]
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.85, 0.10, 0.05]
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="dynamic_replica",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=f"{str(idx)}_{str(view_idx)}",
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/habitat_hm3d.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 11 |
+
from dust3r.utils.image import imread_cv2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class HabitatHM3D_Multi(BaseMultiViewDataset):
    """Multi-view dataset reading Habitat HM3D frames from ``ROOT``.

    ``ROOT`` holds one directory per scene.  For every frame ``<id>`` a scene
    directory provides ``<id>.npz`` (camera parameters), ``image_<id>.png``
    and ``depth_<id>.png``.  Every sample is a sequence of ``num_views``
    frames drawn from a single scene.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset settings, initialise the base class and index ``ROOT``."""
        self.ROOT = ROOT
        self.video = True
        self.is_metric = False
        self.max_interval = 8
        super().__init__(*args, **kwargs)
        # _load_data() fills the index attributes in place and returns None;
        # the attribute is kept because it already existed.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Index every scene directory under ``ROOT``.

        Populates ``scenes``, ``sceneids``, ``images``, ``start_img_ids``,
        ``scene_img_list`` and ``invalid_scenes``.  Image ids are global:
        each kept scene occupies a contiguous id range.
        """
        # NOTE(review): os.listdir order is filesystem dependent, so the
        # idx -> sample mapping may differ between machines.
        self.scenes = os.listdir(self.ROOT)

        # Minimum number of frames a scene needs (does not change per scene).
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        j = 0
        for scene in tqdm(self.scenes):
            scene_dir = osp.join(self.ROOT, scene)
            # Frame names are numeric; sort them by value, not lexically.
            basenames = sorted(
                [f[:-4] for f in os.listdir(scene_dir) if f.endswith(".npz")],
                key=int,
            )

            num_imgs = len(basenames)
            # TODO: because current minghui's training data is backward moving, now use seq from -1 to 0
            img_ids = list(np.arange(num_imgs) + offset)
            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue
            # Only frames with at least ``cut_off`` frames from them to the
            # end of the scene may start a sequence.
            start_img_ids_ = img_ids[: num_imgs - cut_off + 1]

            start_img_ids.extend([(scene, img_id) for img_id in start_img_ids_])
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # offset groups
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

        # Scenes are flagged here at runtime once one of their files fails to load.
        self.invalid_scenes = {scene: False for scene in self.scenes}

    def __len__(self):
        """Return the number of valid sequence start frames."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def _load_frame(self, scene_dir, basename):
        """Read RGB image, depth map, intrinsics and cam2world pose of one frame."""
        rgb_image = imread_cv2(osp.join(scene_dir, "image_" + basename + ".png"))
        depthmap = imread_cv2(
            osp.join(scene_dir, "depth_" + basename + ".png"), cv2.IMREAD_UNCHANGED
        )
        # presumably the PNG stores millimetres -- TODO confirm
        depthmap = depthmap.astype(np.float32) / 1000
        depthmap[~np.isfinite(depthmap)] = 0  # invalid

        camera_pose = np.eye(4, dtype=np.float32)
        # NpzFile keeps the archive open until closed; use it as a context
        # manager so file descriptors are not leaked.
        with np.load(osp.join(scene_dir, basename + ".npz")) as camera_params:
            intrinsics = np.float32(camera_params["intrinsics"])
            camera_pose[:3, :3] = camera_params["R_cam2world"]
            camera_pose[:3, 3] = camera_params["t_cam2world"]
        return rgb_image, depthmap, intrinsics, camera_pose

    def _get_views(self, idx, resolution, rng, num_views):
        """Return ``num_views`` view dicts for the sequence starting at ``idx``.

        A scene whose files cannot be loaded is flagged in
        ``invalid_scenes`` and a random replacement start frame is drawn.

        Raises:
            RuntimeError: if no complete sequence could be loaded within the
                retry budget, or no usable scene could be drawn.
        """
        scene, start_id = self.start_img_ids[idx]  # scene name and first global frame id

        # Bounded retries: an endless loop here would stall distributed training.
        max_retries = 100

        for attempt in range(1, max_retries + 1):
            # Emit a single warning half-way through the retry budget.
            if attempt == 50:
                print(f"[HabitatHM3D WARNING] Already retried {attempt} times for idx={idx}, scene={scene}")

            # The current scene is flagged invalid: draw random start frames
            # until one from a usable scene comes up.
            scene_retry = 0
            while self.invalid_scenes[scene]:
                scene_retry += 1
                if scene_retry > len(self.start_img_ids):
                    raise RuntimeError(
                        f"[HabitatHM3D] All scenes are invalid! Cannot find valid scene after {scene_retry} attempts."
                    )
                idx = rng.integers(low=0, high=len(self.start_img_ids))
                scene, start_id = self.start_img_ids[idx]

            # All global frame ids of the scene containing ``start_id``.
            all_image_ids = self.scene_img_list[self.sceneids[start_id]]
            pos, ordered_video = self.get_seq_from_start_id(
                num_views, start_id, all_image_ids, rng, max_interval=self.max_interval
            )
            image_idxs = np.array(all_image_ids)[pos]

            views = []
            load_failed = False
            for view_idx in image_idxs:
                scene_id = self.sceneids[view_idx]
                scene_dir = osp.join(self.ROOT, self.scenes[scene_id])
                basename = self.images[view_idx]

                try:
                    rgb_image, depthmap, intrinsics, camera_pose = self._load_frame(
                        scene_dir, basename
                    )
                except Exception as e:
                    print(f"[HabitatHM3D] Error loading {scene} {basename}: {e}, skipping scene")
                    self.invalid_scenes[scene] = True
                    load_failed = True
                    break

                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
                )

                views.append(
                    dict(
                        img=rgb_image,
                        depthmap=depthmap.astype(np.float32),
                        camera_pose=camera_pose.astype(np.float32),
                        camera_intrinsics=intrinsics.astype(np.float32),
                        dataset="habitatHM3D",
                        label=self.scenes[scene_id] + "_" + basename,
                        instance=f"{str(idx)}_{str(view_idx)}",
                        is_metric=self.is_metric,
                        is_video=ordered_video,
                        quantile=np.array(0.98, dtype=np.float32),
                        img_mask=True,
                        ray_mask=False,
                        camera_only=True,
                        depth_only=False,
                        single_view=False,
                        reset=False,
                    )
                )

            # Only a fully loaded sequence ends the retry loop.
            if not load_failed and len(views) == num_views:
                return views

        raise RuntimeError(
            f"[HabitatHM3D] Failed to get valid views after {max_retries} retries. "
            f"idx={idx}, scene={scene}, num_views={num_views}. "
            f"This may indicate insufficient valid frames in the dataset."
        )
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/hoi4d.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
sys.path.append(osp.join(osp.dirname(__file__), '..','..'))
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 10 |
+
from dust3r.utils.image import imread_cv2
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class HOI4D_Multi(BaseMultiViewDataset):
    """Dataset of independent HOI4D frames stored under ``ROOT``.

    Every scene directory provides ``rgb/<name>.png``, ``depth/<name>.npy``
    and ``cam/<name>.npz``.  A sample is ``num_views`` frames drawn at random
    from the whole dataset; no camera pose is available, so the identity is
    used as a placeholder.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset settings, initialise the base class and index ``ROOT``."""
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        super().__init__(*args, **kwargs)
        # _load_data() fills ``img_names`` in place and returns None.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Collect ``(scene, basename)`` for every PNG in each scene's ``rgb`` folder."""
        scenes = os.listdir(self.ROOT)
        img_names = []
        for scene in scenes:
            rgb_dir = osp.join(self.ROOT, scene, 'rgb')
            basenames = sorted(f[:-4] for f in os.listdir(rgb_dir) if f.endswith('.png'))
            img_names.extend((scene, basename) for basename in basenames)

        self.img_names = img_names

    def __len__(self):
        """Return the number of indexed frames."""
        return len(self.img_names)

    def get_image_num(self):
        """Return the number of indexed frames."""
        return len(self.img_names)

    def _get_views(self, idx, resolution, rng, num_views):
        """Return ``num_views`` randomly drawn single-frame view dicts.

        A draw containing a frame that fails to load is discarded and a new
        one is made.

        Raises:
            RuntimeError: if no complete draw could be loaded within the
                retry budget.
        """
        new_seed = rng.integers(0, 2**32) + idx
        new_rng = np.random.default_rng(new_seed)

        # Bounded retries so that a broken dataset raises instead of hanging
        # the data loader forever.
        max_retries = 100
        for _ in range(max_retries):
            # Draw row indices instead of passing the whole list to choice(),
            # which would rebuild a string array of all names on every call.
            picks = new_rng.choice(len(self.img_names), num_views, replace=False)

            views = []
            for pick in picks:
                scene, img_name = self.img_names[pick]
                try:
                    rgb_image = imread_cv2(osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"))
                    depthmap = np.load(osp.join(self.ROOT, scene, "depth", f"{img_name}.npy"))
                    depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)

                    # Close the NpzFile once the array has been read.
                    with np.load(osp.join(self.ROOT, scene, "cam", f"{img_name}.npz")) as cam:
                        intrinsics = cam["intrinsics"]
                except Exception as e:
                    # Exception (not a bare except) keeps KeyboardInterrupt
                    # and SystemExit working.
                    print(f"Error loading {scene} {img_name}: {e}, skipping")
                    break

                # camera pose is not provided, placeholder
                camera_pose = np.eye(4)

                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name
                )

                views.append(
                    dict(
                        img=rgb_image,
                        depthmap=depthmap.astype(np.float32),
                        camera_pose=camera_pose.astype(np.float32),
                        camera_intrinsics=intrinsics.astype(np.float32),
                        dataset='HOI4D',
                        label=img_name,
                        instance=osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"),
                        is_metric=self.is_metric,
                        is_video=False,
                        quantile=np.array(0.99, dtype=np.float32),
                        img_mask=True,
                        ray_mask=False,
                        camera_only=False,
                        depth_only=False,
                        single_view=True,
                        reset=True,
                    )
                )

            if len(views) == num_views:
                return views

        raise RuntimeError(
            f"[HOI4D] Failed to get valid views after {max_retries} retries. "
            f"idx={idx}, num_views={num_views}."
        )
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mapfree.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import itertools
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import pickle
|
| 9 |
+
import h5py
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 13 |
+
|
| 14 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 15 |
+
from dust3r.utils.image import imread_cv2
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MapFree_Multi(BaseMultiViewDataset):
    """Multi-view dataset over MapFree scenes stored under ``ROOT``.

    Each scene directory holds ``dense<k>`` sub-sequences (with ``rgb``,
    ``depth``, ``cam`` and ``sky_mask`` folders) and a ``metadata.pkl``
    describing image groups with overlap scores.  The index is cached in
    ``ROOT/cached_metadata_50_col_only.h5``.
    """

    def __init__(self, ROOT, *args, **kwargs):
        """Store the dataset settings, initialise the base class and load the index."""
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 30
        super().__init__(*args, **kwargs)

        self._load_data()

    def imgid2path(self, img_id, scene):
        """Return the RGB path of ``img_id = (sequence id, frame id)`` in ``scene``."""
        first_seq_id, first_frame_id = img_id
        return os.path.join(
            self.ROOT,
            scene,
            f"dense{first_seq_id}",
            "rgb",
            f"frame_{first_frame_id:05d}.jpg",
        )

    def path2imgid(self, subscene, filename):
        """Parse ``dense<k>`` and ``frame_<n>.jpg`` into ``[k, n]``."""
        first_seq_id = int(subscene[5:])  # strip "dense"
        first_frame_id = int(filename[6:-4])  # strip "frame_" and ".jpg"
        return [first_seq_id, first_frame_id]

    def _sibling_path(self, img_id, scene, subdir, ext):
        """Return the file of ``img_id`` inside ``subdir`` with extension ``ext``.

        Paths are assembled from components rather than by string replacement
        on the RGB path, so a ROOT or scene name containing "rgb" or ".jpg"
        cannot corrupt them.
        """
        seq_id, frame_id = img_id
        return os.path.join(
            self.ROOT, scene, f"dense{seq_id}", subdir, f"frame_{frame_id:05d}{ext}"
        )

    def _load_data(self):
        """Load the index from the HDF5 cache, or build it and write the cache."""
        cache_file = f"{self.ROOT}/cached_metadata_50_col_only.h5"
        if os.path.exists(cache_file):
            print(f"Loading cached metadata from {cache_file}")
            self._read_cache(cache_file)
        else:
            self._scan_metadata()
            self._write_cache(cache_file)

    def _read_cache(self, cache_file):
        """Read every index array from ``cache_file``."""
        with h5py.File(cache_file, "r") as hf:
            self.scenes = [name.decode("utf-8") for name in hf["scenes"][:]]
            self.sceneids = hf["sceneids"][:]
            self.scope = hf["scope"][:]
            self.video_flags = hf["video_flags"][:]
            self.groups = hf["groups"][:]
            self.id_ranges = hf["id_ranges"][:]
            self.images = hf["images"][:]

    def _scan_metadata(self):
        """Walk ``ROOT`` and build the index arrays from the files on disk."""
        scene_dirs = sorted(
            d for d in os.listdir(self.ROOT) if os.path.isdir(os.path.join(self.ROOT, d))
        )
        scenes = []
        sceneids = []
        groups = []
        scope = []
        images = []
        id_ranges = []
        is_video = []
        start = 0  # row in the concatenated ``groups`` where the next group begins
        offset = 0  # number of frames added to ``images`` so far

        for j, scene in enumerate(tqdm(scene_dirs)):
            scenes.append(scene)
            # video sequences
            subscenes = sorted(
                d for d in os.listdir(os.path.join(self.ROOT, scene)) if d.startswith("dense")
            )
            id_range_subscenes = []
            for subscene in subscenes:
                rgb_paths = sorted(
                    d
                    for d in os.listdir(os.path.join(self.ROOT, scene, subscene, "rgb"))
                    if d.endswith(".jpg")
                )
                assert (
                    len(rgb_paths) > 0
                ), f"{os.path.join(self.ROOT, scene, subscene)} is empty."
                num_imgs = len(rgb_paths)
                images.extend(
                    [self.path2imgid(subscene, rgb_path) for rgb_path in rgb_paths]
                )
                id_range_subscenes.append((offset, offset + num_imgs))
                offset += num_imgs

            # image collections
            # NOTE: pickle can execute arbitrary code while loading; only use
            # metadata files from a trusted source.
            with open(os.path.join(self.ROOT, scene, "metadata.pkl"), "rb") as f:
                metadata = pickle.load(f)

            img_groups = []
            for ref_img, other_imgs in metadata.items():
                if len(other_imgs) + 1 < self.num_views:
                    continue
                # Rows are (sequence id, frame id, overlap score); the
                # reference image comes first with score 1.
                group = [(*other_img[0], other_img[1]) for other_img in other_imgs]
                group.insert(0, (*ref_img, 1))
                img_groups.append(np.array(group))
                id_ranges.append(id_range_subscenes[ref_img[0]])
                scope.append(start)
                start = start + len(group)

            num_groups = len(img_groups)
            sceneids.extend([j] * num_groups)
            groups.extend(img_groups)
            is_video.extend([False] * num_groups)

        self.scenes = np.array(scenes)
        self.sceneids = np.array(sceneids)
        self.scope = np.array(scope)
        self.video_flags = np.array(is_video)
        self.groups = np.concatenate(groups, 0)
        self.id_ranges = np.array(id_ranges)
        self.images = np.array(images)

    def _write_cache(self, cache_file):
        """Write the index arrays to ``cache_file`` as LZF-compressed datasets."""
        with h5py.File(cache_file, "w") as h5f:
            h5f.create_dataset(
                "scenes",
                data=self.scenes.astype(object),
                dtype=h5py.string_dtype(encoding="utf-8"),
                compression="lzf",
                chunks=True,
            )
            for name in ("sceneids", "scope", "video_flags", "groups", "id_ranges", "images"):
                h5f.create_dataset(
                    name, data=getattr(self, name), compression="lzf", chunks=True
                )

    def __len__(self):
        """Return the number of image groups."""
        return len(self.scope)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def get_stats(self):
        """Return a short description of the dataset size."""
        return f"{len(self)} groups of views"

    def _get_views(self, idx, resolution, rng, num_views):
        """Return ``num_views`` view dicts for group ``idx``.

        With probability 0.6 the views come from the sub-sequence of the
        group's reference image; otherwise they are drawn from the group
        itself, weighted by overlap score.
        """
        scene = self.scenes[self.sceneids[idx]]
        if rng.random() < 0.6:
            ids = np.arange(self.id_ranges[idx][0], self.id_ranges[idx][1])
            cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3)
            start_ids = ids[: len(ids) - cut_off + 1]
            start_id = rng.choice(start_ids)
            pos, ordered_video = self.get_seq_from_start_id(
                num_views,
                start_id,
                ids.tolist(),
                rng,
                max_interval=self.max_interval,
                video_prob=0.8,
                fix_interval_prob=0.5,
                block_shuffle=16,
            )
            ids = np.array(ids)[pos]
            image_idxs = self.images[ids]
        else:
            ordered_video = False
            seq_start_index = self.scope[idx]
            # A ``None`` end selects up to the end of ``groups`` (last group).
            seq_end_index = self.scope[idx + 1] if idx < len(self.scope) - 1 else None
            image_idxs = self.groups[seq_start_index:seq_end_index]
            image_idxs, overlap_scores = image_idxs[:, :2], image_idxs[:, 2]
            # Sample with replacement when too few images have positive overlap.
            replace = bool(
                self.allow_repeat
                or len(overlap_scores[overlap_scores > 0]) < num_views
            )
            image_idxs = rng.choice(
                image_idxs,
                num_views,
                replace=replace,
                p=overlap_scores / np.sum(overlap_scores),
            )
            image_idxs = image_idxs.astype(np.int64)

        views = []
        for v, view_idx in enumerate(image_idxs):
            img_path = self.imgid2path(view_idx, scene)
            depth_path = self._sibling_path(view_idx, scene, "depth", ".npy")
            cam_path = self._sibling_path(view_idx, scene, "cam", ".npz")
            sky_mask_path = self._sibling_path(view_idx, scene, "sky_mask", ".jpg")

            image = imread_cv2(img_path)
            depthmap = np.load(depth_path)

            # cv2.imread returns None for a missing or unreadable file;
            # report that instead of failing on the comparison below.
            sky_mask_img = cv2.imread(sky_mask_path, cv2.IMREAD_UNCHANGED)
            if sky_mask_img is None:
                raise IOError(f"Could not load sky mask {sky_mask_path}")
            sky_mask = sky_mask_img >= 127

            # Close the NpzFile once both arrays have been read.
            with np.load(cam_path) as camera_params:
                intrinsics = camera_params["intrinsic"].astype(np.float32)
                camera_pose = camera_params["pose"].astype(np.float32)

            depthmap[sky_mask] = -1.0
            depthmap[depthmap > 400.0] = 0.0
            depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)
            # Zero everything above the 98th percentile of the positive depths.
            positive = depthmap[depthmap > 0]
            threshold = np.percentile(positive, 98) if positive.size > 0 else 0
            depthmap[depthmap > threshold] = 0.0

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(img_path)
            )
            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="MapFree",
                    label=img_path,
                    is_metric=self.is_metric,
                    instance=img_path,
                    is_video=ordered_video,
                    quantile=np.array(0.96, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mvs_synth.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 11 |
+
from dust3r.utils.image import imread_pil
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MVS_Synth_Multi(BaseMultiViewDataset):
    """Multi-view dataset over MVS-Synth scenes stored under ``ROOT``.

    Every scene directory provides ``rgb/<name>.jpg``, ``depth/<name>.npy``
    and ``cam/<name>.npz``.  A sample is a sequence of ``num_views`` frames
    from a single scene.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset settings, initialise the base class and index ``ROOT``."""
        self.ROOT = ROOT
        self.video = True
        self.is_metric = False
        self.max_interval = 4
        super().__init__(*args, **kwargs)
        # _load_data() fills the index attributes in place and returns None.
        self.loaded_data = self._load_data()
        print('DATA: mvs_synth', len(self))

    def _load_data(self):
        """Index every scene directory under ``ROOT``.

        Populates ``scenes``, ``sceneids``, ``images``, ``start_img_ids`` and
        ``scene_img_list``.  Image ids are global: each kept scene occupies a
        contiguous id range.
        """
        self.scenes = os.listdir(self.ROOT)

        # Minimum number of frames a scene needs (does not change per scene).
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        j = 0
        for scene in tqdm(self.scenes):
            rgb_dir = osp.join(self.ROOT, scene, "rgb")
            basenames = sorted(
                f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".jpg")
            )
            num_imgs = len(basenames)

            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue
            img_ids = list(np.arange(num_imgs) + offset)
            # Only frames with at least ``cut_off`` frames from them to the
            # end of the scene may start a sequence.
            start_img_ids_ = img_ids[: num_imgs - cut_off + 1]

            start_img_ids.extend(start_img_ids_)
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # offset groups
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        """Return the number of valid sequence start frames."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Return ``num_views`` view dicts for the sequence starting at ``idx``."""
        start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=self.max_interval,
            video_prob=1.0,
            fix_interval_prob=1.0,
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scenes[scene_id])
            rgb_dir = osp.join(scene_dir, "rgb")
            depth_dir = osp.join(scene_dir, "depth")
            cam_dir = osp.join(scene_dir, "cam")

            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_pil(osp.join(rgb_dir, basename + ".jpg"))
            # Load depthmap
            depthmap = np.load(osp.join(depth_dir, basename + ".npy"))
            depthmap[~np.isfinite(depthmap)] = 0  # invalid
            # Zero everything above the 98th percentile of the positive
            # depths, then anything still above 1000.
            positive = depthmap[depthmap > 0]
            threshold = np.percentile(positive, 98) if positive.size > 0 else 0
            depthmap[depthmap > threshold] = 0.0
            depthmap[depthmap > 1000] = 0.0

            # Close the NpzFile once both arrays have been read.
            with np.load(osp.join(cam_dir, basename + ".npz")) as cam:
                camera_pose = cam["pose"]
                intrinsics = cam["intrinsics"]
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.8, 0.15, 0.05]
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="MVS_Synth",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=osp.join(rgb_dir, basename + ".jpg"),
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/omniobject3d.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 12 |
+
from dust3r.utils.image import imread_cv2
|
| 13 |
+
import re
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def extract_number(filename):
    """Return the first run of digits in ``filename`` as an int, or 0 if none."""
    found = re.search(r"\d+", filename)
    return int(found.group()) if found else 0
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class OmniObject3D_Multi(BaseMultiViewDataset):
    """Multi-view dataset over OmniObject3D scenes stored under ``ROOT``.

    Every scene directory provides ``rgb/<name>.png``, ``depth/<name>.npy``
    and ``cam/<name>.npz``; ``ROOT/scale.json`` maps each scene name to the
    scale that depth and camera translation are divided by.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset settings, initialise the base class and index ``ROOT``."""
        self.ROOT = ROOT
        self.video = False
        self.is_metric = False  # True
        super().__init__(*args, **kwargs)

        # _load_data() fills the index attributes in place and returns None.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Index every visible scene directory under ``ROOT`` and load the scales.

        Populates ``scales``, ``scenes``, ``sceneids``, ``images``,
        ``start_img_ids`` and ``scene_img_list``.  Image ids are global: each
        kept scene occupies a contiguous id range.
        """
        # Skip plain files and hidden directories.
        self.scenes = [
            d
            for d in os.listdir(self.ROOT)
            if os.path.isdir(os.path.join(self.ROOT, d)) and not d.startswith('.')
        ]
        with open(os.path.join(self.ROOT, "scale.json"), "r") as f:
            self.scales = json.load(f)

        # Minimum number of frames a scene needs (does not change per scene).
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        j = 0
        for scene in tqdm(self.scenes):
            rgb_dir = osp.join(self.ROOT, scene, "rgb")
            # Order frames by the first number in their name.
            basenames = sorted(
                [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")],
                key=extract_number,
            )

            num_imgs = len(basenames)

            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue
            img_ids = list(np.arange(num_imgs) + offset)
            start_img_ids_ = img_ids[: num_imgs - cut_off + 1]

            start_img_ids.extend([(scene, img_id) for img_id in start_img_ids_])
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # offset groups
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        """Return the number of valid sequence start frames."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Return ``num_views`` view dicts for the sample starting at ``idx``."""
        scene, start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views, start_id, all_image_ids, rng, max_interval=100, video_prob=0.0
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scenes[scene_id])
            rgb_dir = osp.join(scene_dir, "rgb")
            depth_dir = osp.join(scene_dir, "depth")
            cam_dir = osp.join(scene_dir, "cam")

            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png"))
            depthmap = np.load(osp.join(depth_dir, basename + ".npy"))
            # Close the NpzFile once both arrays have been read.
            with np.load(osp.join(cam_dir, basename + ".npz")) as cam:
                camera_pose = cam["pose"]
                intrinsics = cam["intrinsics"]

            # Apply the per-scene scale and a further factor of 1000 to both
            # the depth and the camera translation, keeping them consistent.
            scale = self.scales[self.scenes[scene_id]]
            depthmap = depthmap / scale / 1000.0
            camera_pose[:3, 3] = camera_pose[:3, 3] / scale / 1000.0

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.8, 0.15, 0.05]
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="OmniObject3D",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=f"{str(idx)}_{str(view_idx)}",
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/pointodyssey.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 11 |
+
from dust3r.utils.image import imread_cv2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class PointOdyssey_Multi(BaseMultiViewDataset):
    """Multi-view dataset over PointOdyssey sequences.

    Frames are read from ``ROOT/<split>/<scene>/{rgb,depth,cam}``: ``.jpg``
    images, ``.npy`` depth maps and ``.npz`` camera files holding ``pose`` and
    ``intrinsics``, all sharing one basename per frame.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset settings and index the scenes of ``self.split``.

        Args:
            ROOT: Directory containing one sub-directory per split.
            *args: Forwarded unchanged to ``BaseMultiViewDataset``.
            **kwargs: Forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 4
        super().__init__(*args, **kwargs)
        assert self.split in ["train", "test", "val"]
        # Only these scenes are indexed; every other entry under
        # ROOT/<split> is ignored by _load_data.
        self.scenes_to_use = [
            # 'cab_h_bench_3rd', 'cab_h_bench_ego1', 'cab_h_bench_ego2',
            "cnb_dlab_0215_3rd",
            "cnb_dlab_0215_ego1",
            "cnb_dlab_0225_3rd",
            "cnb_dlab_0225_ego1",
            "dancing",
            "dancingroom0_3rd",
            "footlab_3rd",
            "footlab_ego1",
            "footlab_ego2",
            "girl",
            "girl_egocentric",
            "human_egocentric",
            "human_in_scene",
            "human_in_scene1",
            "kg",
            "kg_ego1",
            "kg_ego2",
            "kitchen_gfloor",
            "kitchen_gfloor_ego1",
            "kitchen_gfloor_ego2",
            "scene_carb_h_tables",
            "scene_carb_h_tables_ego1",
            "scene_carb_h_tables_ego2",
            "scene_j716_3rd",
            "scene_j716_ego1",
            "scene_j716_ego2",
            "scene_recording_20210910_S05_S06_0_3rd",
            "scene_recording_20210910_S05_S06_0_ego2",
            "scene1_0129",
            "scene1_0129_ego",
            "seminar_h52_3rd",
            "seminar_h52_ego1",
            "seminar_h52_ego2",
        ]
        # _load_data fills the index attributes in place and returns None;
        # the attribute is kept so existing readers of it do not break.
        self.loaded_data = self._load_data(self.split)

    def _load_data(self, split):
        """Build the flat frame index for every usable scene of ``split``.

        Populates:
            self.scenes: names of the scenes that were kept.
            self.sceneids: for each global frame id, its index in ``self.scenes``.
            self.images: for each global frame id, the frame basename.
            self.start_img_ids: global frame ids a sequence may start from.
            self.scene_img_list: per kept scene, the list of its global ids.

        Args:
            split: Name of the sub-directory of ``self.ROOT`` to index.
        """
        root = os.path.join(self.ROOT, split)
        self.scenes = []

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        # Minimum number of frames a scene needs; constant for the whole scan.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        j = 0
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # idx -> sample mapping identical on every machine and every run.
        for scene in tqdm(sorted(os.listdir(root))):
            if scene not in self.scenes_to_use:
                continue
            scene_dir = osp.join(root, scene)
            rgb_dir = osp.join(scene_dir, "rgb")
            basenames = sorted(
                [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".jpg")]
            )
            num_imgs = len(basenames)

            # Check before slicing: a negative slice bound would otherwise
            # produce a meaningless (but non-empty) list of start ids.
            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue

            img_ids = list(np.arange(num_imgs) + offset)
            start_img_ids_ = img_ids[: num_imgs - cut_off + 1]

            start_img_ids.extend(start_img_ids_)
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # offset groups
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        """Return the number of admissible sequence start frames."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames of one scene, starting from sample ``idx``.

        Args:
            idx: Index into ``self.start_img_ids``.
            resolution: Passed through to ``_crop_resize_if_necessary``.
            rng: Random generator handed to the base-class helpers.
            num_views: Number of views to return.

        Returns:
            A list of ``num_views`` dicts, one per view.
        """
        start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        # `pos` indexes into all_image_ids; the exact sampling policy lives in
        # the base class (not visible here).
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=self.max_interval,
            video_prob=1.0,
            fix_interval_prob=1.0,
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id])
            rgb_dir = osp.join(scene_dir, "rgb")
            depth_dir = osp.join(scene_dir, "depth")
            cam_dir = osp.join(scene_dir, "cam")

            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".jpg"))
            # Load depthmap
            depthmap = np.load(osp.join(depth_dir, basename + ".npy"))
            depthmap[~np.isfinite(depthmap)] = 0  # invalid
            # Values above 1000 are zeroed as well, i.e. treated like the
            # non-finite entries above.
            depthmap[depthmap > 1000] = 0.0

            cam = np.load(osp.join(cam_dir, basename + ".npz"))
            camera_pose = cam["pose"]
            intrinsics = cam["intrinsics"]
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.9, 0.05, 0.05]
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="PointOdyssey",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=osp.join(rgb_dir, basename + ".jpg"),
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/realestate10k.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 11 |
+
from dust3r.utils.image import imread_cv2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RE10K_Multi(BaseMultiViewDataset):
    """Multi-view dataset over RealEstate10K clips (camera supervision only).

    Frames are read from ``ROOT/<scene>/rgb/<n>.png`` with matching
    ``ROOT/<scene>/cam/<n>.npz`` files holding ``intrinsics`` and ``pose``.
    No depth is loaded: every view carries an all-ones depth map and
    ``camera_only=True``.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset settings and index every scene under ``ROOT``.

        Args:
            ROOT: Directory containing one sub-directory per scene.
            *args: Forwarded unchanged to ``BaseMultiViewDataset``.
            **kwargs: Forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = False
        self.max_interval = 128
        super().__init__(*args, **kwargs)
        # _load_data fills the index attributes in place and returns None.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Build the flat frame index for every scene with enough frames.

        Populates ``self.scenes``, ``self.sceneids``, ``self.images``,
        ``self.start_img_ids`` (``(scene, global_frame_id)`` pairs),
        ``self.scene_img_list`` and ``self.invalid_scenes`` (scene name ->
        ``False``; flipped to ``True`` by ``_get_views`` on a load failure).
        """
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # idx -> sample mapping identical on every machine and every run.
        self.scenes = sorted(os.listdir(self.ROOT))

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        # Minimum number of frames a scene needs; constant for the whole scan.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        j = 0
        for scene in tqdm(self.scenes):
            scene_dir = osp.join(self.ROOT, scene)
            rgb_dir = osp.join(scene_dir, "rgb")
            # Basenames are integers; sort numerically rather than lexically.
            basenames = sorted(
                [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")],
                key=int,
            )

            num_imgs = len(basenames)
            img_ids = list(np.arange(num_imgs) + offset)
            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue
            start_img_ids_ = img_ids[: num_imgs - cut_off + 1]

            start_img_ids.extend([(scene, img_id) for img_id in start_img_ids_])
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # offset groups
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

        self.invalid_scenes = {scene: False for scene in self.scenes}

    def __len__(self):
        """Return the number of admissible sequence start frames."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames, resampling when a scene cannot be read.

        A scene whose image or camera file fails to load is marked invalid
        and a different start index is drawn at random; the returned views
        may therefore belong to another sample than ``idx``.

        Args:
            idx: Index into ``self.start_img_ids``.
            resolution: Passed through to ``_crop_resize_if_necessary``.
            rng: Random generator used for resampling and by the helpers.
            num_views: Number of views to return.

        Returns:
            A list of ``num_views`` dicts, one per view.

        Raises:
            RuntimeError: If every indexed scene has been marked invalid.
        """
        invalid_seq = True
        scene, start_id = self.start_img_ids[idx]

        while invalid_seq:
            while self.invalid_scenes[scene]:
                # Without this guard the resampling loop would spin forever
                # once no readable scene is left.
                if all(self.invalid_scenes.values()):
                    raise RuntimeError(
                        "RE10K_Multi: every scene failed to load, nothing to sample"
                    )
                idx = rng.integers(low=0, high=len(self.start_img_ids))
                scene, start_id = self.start_img_ids[idx]

            all_image_ids = self.scene_img_list[self.sceneids[start_id]]
            pos, ordered_video = self.get_seq_from_start_id(
                num_views, start_id, all_image_ids, rng, max_interval=self.max_interval
            )
            image_idxs = np.array(all_image_ids)[pos]

            views = []
            for view_idx in image_idxs:
                scene_id = self.sceneids[view_idx]
                scene_dir = osp.join(self.ROOT, self.scenes[scene_id])
                rgb_dir = osp.join(scene_dir, "rgb")
                cam_dir = osp.join(scene_dir, "cam")

                basename = self.images[view_idx]

                try:
                    # Load RGB image
                    rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png"))
                    # Load depthmap, no depth, set to all ones
                    depthmap = np.ones_like(rgb_image[..., 0], dtype=np.float32)
                    cam = np.load(osp.join(cam_dir, basename + ".npz"))
                    intrinsics = cam["intrinsics"]
                    camera_pose = cam["pose"]
                except Exception:
                    # Best-effort skip of broken scenes. `Exception` (not a
                    # bare except) so Ctrl-C / SystemExit still propagate.
                    print(f"Error loading {scene} {basename}, skipping")
                    self.invalid_scenes[scene] = True
                    break

                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
                )

                views.append(
                    dict(
                        img=rgb_image,
                        depthmap=depthmap.astype(np.float32),
                        camera_pose=camera_pose.astype(np.float32),
                        camera_intrinsics=intrinsics.astype(np.float32),
                        dataset="realestate10k",
                        label=self.scenes[scene_id] + "_" + basename,
                        instance=f"{str(idx)}_{str(view_idx)}",
                        is_metric=self.is_metric,
                        is_video=ordered_video,
                        quantile=np.array(0.98, dtype=np.float32),
                        img_mask=True,
                        ray_mask=False,
                        camera_only=True,
                        depth_only=False,
                        single_view=False,
                        reset=False,
                    )
                )
            # A short list means the loop above bailed out; go round again.
            if len(views) == num_views:
                invalid_seq = False
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannet.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 11 |
+
from dust3r.utils.image import imread_cv2, imread_pil
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ScanNet_Multi(BaseMultiViewDataset):
    """ScanNet scans exposed as multi-view samples.

    Scenes live under ``ROOT/scans_train`` or ``ROOT/scans_test``; each scene
    directory provides ``new_scene_metadata.npz`` (key ``images``) together
    with ``color/*.jpg``, ``depth/*.png`` and ``cam/*.npz`` per frame.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Record the settings, run the base initialiser and index the split.

        Args:
            ROOT: Directory holding ``scans_train`` / ``scans_test``.
            *args: Forwarded unchanged to ``BaseMultiViewDataset``.
            **kwargs: Forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 30
        super().__init__(*args, **kwargs)

        # The loader works through side effects and hands back None.
        self.loaded_data = self._load_data(self.split)
        print('DATA: scannet', len(self))

    def _load_data(self, split):
        """Scan the scene folders of ``split`` and fill the index tables.

        Sets ``self.scene_root``, ``self.scenes``, ``self.sceneids``,
        ``self.images``, ``self.start_img_ids`` and ``self.scene_img_list``.

        Args:
            split: ``"train"`` selects ``scans_train``; anything else selects
                ``scans_test``.
        """
        subdir = "scans_train" if split == "train" else "scans_test"
        self.scene_root = osp.join(self.ROOT, subdir)
        self.scenes = [
            name for name in os.listdir(self.scene_root) if name.startswith("scene")
        ]

        kept_scenes = []
        frame_scene_ids = []
        frame_names = []
        ids_per_scene = []
        seq_starts = []

        for scene in tqdm(self.scenes):
            meta_path = osp.join(self.scene_root, scene, "new_scene_metadata.npz")
            with np.load(meta_path, allow_pickle=True) as data:
                basenames = data["images"]
            n_frames = len(basenames)

            if self.allow_repeat:
                min_frames = max(self.num_views // 3, 3)
            else:
                min_frames = self.num_views

            if n_frames < min_frames:
                print(f"Skipping {scene}")
                continue

            # Global ids continue where the previously kept scene stopped,
            # which is exactly the number of frames registered so far.
            first_id = len(frame_names)
            global_ids = list(np.arange(n_frames) + first_id)

            seq_starts.extend(global_ids[: n_frames - min_frames + 1])
            # The scene's index is the count of scenes kept before it.
            frame_scene_ids.extend([len(kept_scenes)] * n_frames)
            frame_names.extend(basenames)
            kept_scenes.append(scene)
            ids_per_scene.append(global_ids)

        self.scenes = kept_scenes
        self.sceneids = frame_scene_ids
        self.images = frame_names
        self.start_img_ids = seq_starts
        self.scene_img_list = ids_per_scene

    def __len__(self):
        """Number of frames a sequence is allowed to start from."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Number of frames across all kept scenes."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Assemble the views of sample ``idx``.

        Args:
            idx: Position in ``self.start_img_ids``.
            resolution: Forwarded to ``_crop_resize_if_necessary``.
            rng: Random generator shared with the base-class helpers.
            num_views: How many views the sample must contain.

        Returns:
            List of ``num_views`` view dictionaries.
        """
        first = self.start_img_ids[idx]
        candidates = self.scene_img_list[self.sceneids[first]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            first,
            candidates,
            rng,
            max_interval=self.max_interval,
            video_prob=0.6,
            fix_interval_prob=0.6,
            block_shuffle=16,
        )
        chosen = np.array(candidates)[pos]

        views = []
        for v, view_idx in enumerate(chosen):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.scene_root, self.scenes[scene_id])
            basename = self.images[view_idx]

            # Colour frame.
            rgb_image = imread_pil(osp.join(scene_dir, "color", basename + ".jpg"))

            # Depth is stored as PNG; divide by 1000 and zero out anything
            # that is not finite.
            depth_path = osp.join(scene_dir, "depth", basename + ".png")
            depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED)
            depthmap = depthmap.astype(np.float32) / 1000
            depthmap[~np.isfinite(depthmap)] = 0

            # Camera parameters.
            cam = np.load(osp.join(scene_dir, "cam", basename + ".npz"))
            camera_pose = cam["pose"]
            intrinsics = cam["intrinsics"]

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # Image / ray-map masks for this view position.
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            view = {
                "img": rgb_image,
                "depthmap": depthmap.astype(np.float32),
                "camera_pose": camera_pose.astype(np.float32),
                "camera_intrinsics": intrinsics.astype(np.float32),
                "dataset": "ScanNet",
                "label": self.scenes[scene_id] + "_" + basename,
                "instance": f"{str(idx)}_{str(view_idx)}",
                "is_metric": self.is_metric,
                "is_video": ordered_video,
                "quantile": np.array(0.98, dtype=np.float32),
                "img_mask": img_mask,
                "ray_mask": ray_mask,
                "camera_only": False,
                "depth_only": False,
                "single_view": False,
                "reset": False,
            }
            views.append(view)

        assert len(views) == num_views
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannetpp.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 9 |
+
|
| 10 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 11 |
+
from dust3r.utils.image import imread_cv2, imread_pil
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ScanNetpp_Multi(BaseMultiViewDataset):
    """Multi-view dataset over ScanNet++ scenes (train split only).

    ``ROOT/all_metadata.npz`` lists the scenes; each scene directory holds
    ``new_scene_metadata.npz`` (``images``, ``intrinsics``, ``trajectories``,
    ``image_collection``) plus ``images/*.jpg`` and ``depth/*.png``.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset settings and index all scenes.

        Args:
            ROOT: Directory containing ``all_metadata.npz`` and the scenes.
            *args: Forwarded unchanged to ``BaseMultiViewDataset``.
            **kwargs: Forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 3
        super().__init__(*args, **kwargs)
        assert self.split == "train"
        # _load_data fills the index attributes in place and returns None.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Build the image groups and per-group video ranges for every scene.

        Populates ``self.scenes``, ``self.sceneids``, ``self.images``,
        ``self.intrinsics``, ``self.trajectories`` (both concatenated over
        the kept scenes), ``self.groups`` (lists of global frame ids),
        ``self.id_ranges`` (one list of global frame ids per group, used for
        video-style sampling) and ``self.image_num``.
        """
        with np.load(osp.join(self.ROOT, "all_metadata.npz")) as data:
            self.scenes = data["scenes"]
        offset = 0
        scenes = []
        sceneids = []
        images = []
        intrinsics = []
        trajectories = []
        groups = []
        id_ranges = []
        j = 0
        self.image_num = 0

        # Same cut-off rule as the other datasets; constant for the scan.
        min_group_len = (
            self.num_views
            if not self.allow_repeat
            else max(self.num_views // 3, 3)
        )

        for scene in self.scenes:
            scene_dir = osp.join(self.ROOT, scene)
            # NOTE: allow_pickle=True is needed because `image_collection` is
            # read back as a Python dict via .item(); unpickling executes
            # arbitrary code, so only load metadata from a trusted source.
            with np.load(
                osp.join(scene_dir, "new_scene_metadata.npz"), allow_pickle=True
            ) as data:
                imgs = data["images"]
                self.image_num += len(imgs)
                intrins = data["intrinsics"]
                traj = data["trajectories"]
                # A set makes each membership test below O(1); the previous
                # list made the two filters quadratic in the frame count.
                imgs_on_disk = {
                    name[:-4] for name in os.listdir(osp.join(scene_dir, "images"))
                }

                dslr_ids = [
                    i + offset
                    for i, name in enumerate(imgs)
                    if name.startswith("DSC") and name in imgs_on_disk
                ]
                iphone_ids = [
                    i + offset
                    for i, name in enumerate(imgs)
                    if name.startswith("frame") and name in imgs_on_disk
                ]

                num_imgs = len(imgs)
                # All "DSC" frames must precede all "frame" frames in the
                # metadata ordering.
                assert max(dslr_ids) < min(iphone_ids)
                assert "image_collection" in data

                img_groups = []
                img_id_ranges = []

                for ref_id, group in data["image_collection"].item().items():
                    if len(group) + 1 < min_group_len:
                        continue
                    # The reference frame joins its own group with score 1.0.
                    group.insert(0, (ref_id, 1.0))
                    sorted_group = sorted(group, key=lambda x: x[1], reverse=True)
                    group = [int(x[0] + offset) for x in sorted_group]

                    # Pick the frame list used for video-style sampling.
                    # NOTE(review): a "frame*" reference is paired with the
                    # "DSC*" ids and vice versa -- confirm this cross-pairing
                    # is intended.
                    if imgs[ref_id].startswith("frame"):
                        video_ids = dslr_ids
                    else:
                        video_ids = iphone_ids

                    # Only keep the group when that frame list is long enough.
                    if len(video_ids) >= min_group_len:
                        img_groups.append(sorted(group))
                        img_id_ranges.append(video_ids)

                if len(img_groups) == 0:
                    print(f"Skipping {scene}")
                    continue
                scenes.append(scene)
                sceneids.extend([j] * num_imgs)
                images.extend(imgs)
                intrinsics.append(intrins)
                trajectories.append(traj)

                # offset groups
                groups.extend(img_groups)
                id_ranges.extend(img_id_ranges)
                offset += num_imgs
                j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.intrinsics = np.concatenate(intrinsics, axis=0)
        self.trajectories = np.concatenate(trajectories, axis=0)
        self.id_ranges = id_ranges
        self.groups = groups

    def __len__(self):
        """Return ten samples per group (``_get_views`` maps ``idx // 10``)."""
        return len(self.groups) * 10

    def get_image_num(self):
        """Return the number of frames listed in the scanned metadata files."""
        return self.image_num

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames for sample ``idx``.

        One random draw selects the sampling mode: below 0.7 (and when a
        start frame exists) frames come from the group's video range via
        ``get_seq_from_start_id``; otherwise they are drawn from the group
        itself, kept in sorted order unless the draw exceeds 0.75, in which
        case they are shuffled.

        Args:
            idx: Sample index; ten consecutive indices share one group.
            resolution: Passed through to ``_crop_resize_if_necessary``.
            rng: Random generator used for all sampling decisions.
            num_views: Number of views to return.

        Returns:
            A list of ``num_views`` dicts, one per view.
        """
        idx = idx // 10
        image_idxs = self.groups[idx]
        rand_val = rng.random()

        image_idxs_video = self.id_ranges[idx]
        cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3)
        # Clamp at 0: a negative bound would slice from the end and hand back
        # start frames that cannot host a full sequence.
        num_starts = max(0, len(image_idxs_video) - cut_off + 1)
        start_image_idxs = image_idxs_video[:num_starts]

        if rand_val < 0.7 and len(start_image_idxs) > 0:
            start_id = rng.choice(start_image_idxs)
            pos, ordered_video = self.get_seq_from_start_id(
                num_views,
                start_id,
                image_idxs_video,
                rng,
                max_interval=self.max_interval,
                video_prob=0.8,
                fix_interval_prob=0.5,
                block_shuffle=16,
            )
            image_idxs = np.array(image_idxs_video)[pos]

        else:
            ordered_video = True
            # ordered video with varying intervals
            num_candidates = len(image_idxs)
            max_id = min(num_candidates, int(num_views * (2 + 2 * rng.random())))

            # Make sure there are enough candidate frames.
            if num_candidates < num_views:
                # Too few candidates: sample with replacement.
                image_idxs = sorted(rng.choice(image_idxs, size=num_views, replace=True))
            else:
                image_idxs = sorted(rng.permutation(image_idxs[:max_id])[:num_views])

            if rand_val > 0.75:
                ordered_video = False
                image_idxs = rng.permutation(image_idxs)

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scenes[scene_id])

            intrinsics = self.intrinsics[view_idx]
            camera_pose = self.trajectories[view_idx]
            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_pil(osp.join(scene_dir, "images", basename + ".jpg"))
            # Load depthmap
            depthmap = imread_cv2(
                osp.join(scene_dir, "depth", basename + ".png"), cv2.IMREAD_UNCHANGED
            )
            depthmap = depthmap.astype(np.float32) / 1000
            depthmap[~np.isfinite(depthmap)] = 0  # invalid

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="ScanNet++",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=f"{str(idx)}_{str(view_idx)}",
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(0.99, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/smartportraits.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 11 |
+
from dust3r.utils.image import imread_cv2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SmartPortraits_Multi(BaseMultiViewDataset):
    """SmartPortraits frames served as unrelated single views.

    Each frame is read from ``ROOT/<scene>/rgb/<name>.png`` with its depth in
    ``depth/<name>.npy`` and intrinsics in ``cam/<name>.npz``. No camera pose
    is available, so every view carries an identity pose.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Record the settings, run the base initialiser and list all frames.

        Args:
            ROOT: Directory containing one sub-directory per scene.
            *args: Forwarded unchanged to ``BaseMultiViewDataset``.
            **kwargs: Forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        super().__init__(*args, **kwargs)
        # The loader works through side effects and hands back None.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Collect ``(scene, basename)`` for every PNG under ``<scene>/rgb``."""
        collected = []
        for scene in os.listdir(self.ROOT):
            rgb_dir = osp.join(self.ROOT, scene, "rgb")
            # Strip the extension first, then sort the bare names.
            stems = sorted(
                fname[:-4] for fname in os.listdir(rgb_dir) if fname.endswith(".png")
            )
            for stem in stems:
                collected.append((scene, stem))

        self.img_names = collected

    def __len__(self):
        """Number of indexed frames."""
        return len(self.img_names)

    def get_image_num(self):
        """Number of indexed frames (identical to ``len(self)``)."""
        return len(self.img_names)

    def _get_views(self, idx, resolution, rng, num_views):
        """Draw ``num_views`` distinct frames at random and load them.

        The frames are picked by a fresh generator seeded from ``rng`` and
        ``idx``; ``rng`` itself is still the one handed to the crop helper.

        Args:
            idx: Sample index, mixed into the seed of the frame picker.
            resolution: Forwarded to ``_crop_resize_if_necessary``.
            rng: Random generator supplied by the caller.
            num_views: How many views the sample must contain.

        Returns:
            List of ``num_views`` view dictionaries.
        """
        picker = np.random.default_rng(rng.integers(0, 2**32) + idx)
        picked = new_entries = picker.choice(self.img_names, num_views, replace=False)

        views = []
        for v, entry in enumerate(picked):
            scene, img_name = entry
            frame_root = osp.join(self.ROOT, scene)

            # Colour frame.
            rgb_image = imread_cv2(osp.join(frame_root, "rgb", f"{img_name}.png"))

            # Depth, with NaN / +-inf replaced by zero.
            depthmap = np.load(osp.join(frame_root, "depth", f"{img_name}.npy"))
            depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)

            cam = np.load(osp.join(frame_root, "cam", f"{img_name}.npz"))
            intrinsics = cam["intrinsics"]
            # The dataset ships no pose; use the identity as a stand-in.
            camera_pose = np.eye(4)

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name
            )

            view = {
                "img": rgb_image,
                "depthmap": depthmap.astype(np.float32),
                "camera_pose": camera_pose.astype(np.float32),
                "camera_intrinsics": intrinsics.astype(np.float32),
                "dataset": "SmartPortraits",
                "label": img_name,
                "instance": osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"),
                "is_metric": self.is_metric,
                "is_video": False,
                "quantile": np.array(0.98, dtype=np.float32),
                "img_mask": True,
                "ray_mask": False,
                "camera_only": False,
                "depth_only": False,
                "single_view": True,
                "reset": True,
            }
            views.append(view)

        assert len(views) == num_views
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/threedkb.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 11 |
+
from dust3r.utils.image import imread_cv2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ThreeDKenBurns(BaseMultiViewDataset):
    """Single-frame dataset read from ``<ROOT>/<scene>/{rgb,depth,cam}``.

    Every sample is ``num_views`` independently drawn frames, each flagged
    ``single_view=True``.  No camera pose is read from disk; an identity
    matrix is stored as a placeholder.
    """

    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        self.video = False
        self.is_metric = False
        super().__init__(*args, **kwargs)
        # _load_data() fills the index attributes in place and returns None;
        # the assignment is kept so the ``loaded_data`` attribute still exists.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Index every ``rgb/*.png`` frame of every scene directory under ROOT.

        Scene names are sorted and non-directory entries are skipped, so the
        index -> image mapping does not depend on the order in which the
        filesystem happens to list entries, and stray files in ROOT do not
        break the scan.

        Sets ``self.scenes`` (scene names), ``self.sceneids`` (scene index per
        image), ``self.images`` (basenames without ``.png``) and
        ``self.img_ids`` (global image indices).
        """
        scene_names = sorted(
            name
            for name in os.listdir(self.ROOT)
            if osp.isdir(osp.join(self.ROOT, name))
        )

        scenes = []
        sceneids = []
        images = []

        for scene in tqdm(scene_names):
            rgb_dir = osp.join(self.ROOT, scene, "rgb")
            basenames = sorted(
                f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")
            )
            # The scene index is its position in ``scenes``.
            sceneids.extend([len(scenes)] * len(basenames))
            images.extend(basenames)
            scenes.append(scene)

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        # Global ids are consecutive across scenes (0 .. number of images - 1).
        self.img_ids = list(np.arange(len(images)))

    def __len__(self):
        return len(self.img_ids)

    def get_image_num(self):
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Return ``num_views`` distinct frames drawn from the whole dataset.

        The draw uses a child generator seeded with one integer from ``rng``
        plus ``idx``; cropping/resizing still receives the caller's ``rng``.
        """
        new_seed = rng.integers(0, 2**32) + idx
        new_rng = np.random.default_rng(new_seed)
        image_idxs = new_rng.choice(self.img_ids, num_views, replace=False)

        views = []
        for view_idx in image_idxs:
            scene_id = self.sceneids[view_idx]
            scene_name = self.scenes[scene_id]
            scene_dir = osp.join(self.ROOT, scene_name)
            basename = self.images[view_idx]

            rgb_image = imread_cv2(osp.join(scene_dir, "rgb", basename + ".png"))
            depthmap = imread_cv2(osp.join(scene_dir, "depth", basename + ".exr"))
            # Values above 20000 are invalidated, then depth is divided by 1000
            # (presumably millimetres -> metres; confirm against the dataset).
            depthmap[depthmap > 20000] = 0.0
            depthmap = depthmap / 1000.0
            cam = np.load(osp.join(scene_dir, "cam", basename + ".npz"))
            intrinsics = cam["intrinsics"]
            # No pose is loaded for this dataset: identity placeholder.
            camera_pose = np.eye(4)

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="3DKenBurns",
                    label=scene_name + "_" + basename,
                    instance=f"{idx}_{view_idx}",
                    is_metric=self.is_metric,
                    is_video=False,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=True,
                    ray_mask=False,
                    camera_only=False,
                    depth_only=False,
                    single_view=True,
                    reset=True,
                )
            )
        assert len(views) == num_views
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/unreal4k.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import itertools
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 10 |
+
|
| 11 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 12 |
+
from dust3r.utils.image import imread_cv2
|
| 13 |
+
|
| 14 |
+
# 4x4 homogeneous matrix that swaps the first two axes; it is left-multiplied
# onto every cam2world pose loaded by the dataset below.
R_conv = np.array(
    [
        [0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ],
    dtype=np.int64,
).astype(np.float32)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class UnReal4K_Multi(BaseMultiViewDataset):
    """Multi-view sequences read from ``<ROOT>/<scene>/{0,1}``.

    Each sequence directory holds ``<name>_rgb.png``, ``<name>_depth.npy`` and
    ``<name>.npz`` (keys ``intrinsics`` and ``cam2world``) per frame.
    """

    def __init__(self, ROOT, *args, **kwargs):
        self.ROOT = ROOT
        self.max_interval = 2
        self.is_metric = True
        super().__init__(*args, **kwargs)
        # loading all
        assert self.split is None
        self._load_data()

    def _load_data(self):
        """Build the frame index and the list of valid sequence start frames.

        Only files ending in ``_rgb.png`` are indexed: the basename is obtained
        by stripping exactly that suffix, so any other PNG in the directory
        would otherwise yield a basename with no matching depth/camera files.
        """
        rgb_suffix = "_rgb.png"
        scene_dirs = sorted(
            d for d in os.listdir(self.ROOT) if os.path.isdir(os.path.join(self.ROOT, d))
        )
        seq_dirs = sorted(
            os.path.join(self.ROOT, scene, mode)
            for scene in scene_dirs
            for mode in ["0", "1"]
        )

        # Minimum number of frames a sequence needs; the same for every sequence.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0
        scenes = []
        sceneids = []
        images = []
        start_img_ids = []
        scene_img_list = []

        for seq_dir in seq_dirs:
            basenames = sorted(
                f[: -len(rgb_suffix)]
                for f in os.listdir(seq_dir)
                if f.endswith(rgb_suffix)
            )
            num_imgs = len(basenames)
            if num_imgs < cut_off:
                print(f"Skipping {seq_dir}")
                continue

            img_ids = list(np.arange(num_imgs) + offset)
            # A start frame must leave at least ``cut_off`` frames after it
            # (itself included).
            start_img_ids.extend(img_ids[: num_imgs - cut_off + 1])
            # The sequence index is its position in ``scenes``.
            sceneids.extend([len(scenes)] * num_imgs)
            images.extend(basenames)
            scenes.append(seq_dir)
            scene_img_list.append(img_ids)

            # offset groups
            offset += num_imgs

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        # Each start frame is exposed 10 times; _get_views maps idx back
        # with ``idx // 10``.
        return len(self.start_img_ids) * 10

    def get_image_num(self):
        return len(self.images)

    def get_stats(self):
        return f"{len(self)//10} groups of views"

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames of the sequence containing start frame ``idx // 10``."""
        idx = idx // 10
        start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views, start_id, all_image_ids, rng, max_interval=self.max_interval
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []

        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = self.scenes[scene_id]
            basename = self.images[view_idx]

            img = basename + "_rgb.png"
            image = imread_cv2(osp.join(scene_dir, img))
            depthmap = np.load(osp.join(scene_dir, basename + "_depth.npy"))
            camera_params = np.load(osp.join(scene_dir, basename + ".npz"))

            intrinsics = camera_params["intrinsics"].astype(np.float32)
            camera_pose = camera_params["cam2world"].astype(np.float32)

            camera_pose = R_conv @ camera_pose

            sky_mask = depthmap >= 1000
            depthmap[sky_mask] = -1.0  # sky
            # Zero out everything above the 98th percentile of the remaining
            # positive depths (threshold 0 when there are none).
            positive_depths = depthmap[depthmap > 0]
            threshold = (
                np.percentile(positive_depths, 98) if positive_depths.size > 0 else 0
            )
            depthmap[depthmap > threshold] = 0.0
            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(scene_dir, img)
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="UnReal4K",
                    label=scene_dir,
                    is_metric=self.is_metric,
                    instance=scene_dir + "_" + img,
                    is_video=ordered_video,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/corr.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# modified from DUSt3R
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from dust3r.utils.device import to_numpy
|
| 9 |
+
from dust3r.utils.geometry import inv, geotrf
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def reproject_view(pts3d, view2):
    """Project ``pts3d`` into ``view2`` using its intrinsics and inverted pose.

    The pixel grid size is taken from the first two dimensions of
    ``view2["pts3d"]``.  Returns whatever ``reproject`` returns.
    """
    target_shape = view2["pts3d"].shape[:2]
    intrinsics = view2["camera_intrinsics"]
    world_to_cam = inv(view2["camera_pose"])
    return reproject(pts3d, intrinsics, world_to_cam, target_shape)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def reproject(pts3d, K, world2cam, shape):
    """Project an ``(H, W, 3)`` point map and quantize it to flat pixel indices.

    Returns ``((H, W), indices)`` where ``(H, W)`` is the source map size and
    ``indices`` are ``x + W2 * y`` positions inside the target grid ``shape``.
    """
    src_h, src_w, channels = pts3d.shape
    assert channels == 3

    # reproject in camera2 space
    projection = K @ world2cam[:3]
    with np.errstate(divide="ignore", invalid="ignore"):
        pixels = geotrf(projection, pts3d, norm=1, ncol=2)

    # quantize to pixel positions
    flat_indices = ravel_xy(pixels, shape)
    return (src_h, src_w), flat_indices
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def ravel_xy(pos, shape):
    """Quantize ``(x, y)`` positions to flat ``x + W * y`` indices.

    ``pos`` is reshaped to ``(-1, 2)``, rounded and cast to int32; x is clipped
    to ``[0, W - 1]`` and y to ``[0, H - 1]`` where ``(H, W) = shape``.
    """
    height, width = shape
    with np.errstate(invalid="ignore"):
        cols, rows = pos.reshape(-1, 2).round().astype(np.int32).T
    # Clip in place so out-of-frame positions land on the border pixel.
    np.clip(cols, 0, width - 1, out=cols)
    np.clip(rows, 0, height - 1, out=rows)
    return cols + width * rows
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def unravel_xy(pos, shape):
    """Convert flat ``x + W * y`` indices back to an ``(N, 2)`` array of ``(x, y)``.

    Built from the documented return value of ``np.unravel_index`` (a tuple of
    ``(row, col)`` index arrays) instead of reaching into the ``.base`` buffer
    of that result, which is not a documented part of the API and is ``None``
    for scalar input.

    Args:
        pos: integer flat indices into a grid of size ``shape``.
        shape: ``(H, W)`` of the grid.

    Returns:
        A new ``(N, 2)`` integer array whose columns are ``x`` (column index)
        and ``y`` (row index).
    """
    flat = np.asarray(pos).reshape(-1)
    rows, cols = np.unravel_index(flat, shape)
    # Column order is (x, y), i.e. (col, row).
    return np.stack((cols, rows), axis=-1)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
    """Find indices whose 1 -> 2 -> 1 round trip returns to the start.

    Returns ``(pos1, pos2)``: the reciprocal indices in set 1 and their matches
    in set 2.  With ``ret_recip=True`` the boolean reciprocity mask over set 1
    is returned first.
    """
    round_trip = corres_2_to_1[corres_1_to_2]
    is_reciprocal1 = round_trip == np.arange(len(corres_1_to_2))
    pos1 = np.flatnonzero(is_reciprocal1)
    pos2 = corres_1_to_2[pos1]
    if not ret_recip:
        return pos1, pos2
    return is_reciprocal1, pos1, pos2
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def extract_correspondences_from_pts3d(
    view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
):
    """Sample pixel correspondences between two views from their point maps.

    Each view's ``pts3d`` is reprojected into the other view; pixels whose
    round trip is reciprocal are positives.  With ``target_n_corres=None`` all
    positives are returned as ``(pos1, pos2)``.  Otherwise exactly
    ``target_n_corres`` pairs are returned as ``(pos1, pos2, valid)``, where a
    fraction ``nneg`` is filled with non-reciprocal pixels (``valid`` False).
    Positions are ``(x, y)`` pairs when ``ret_xy`` is true, flat indices otherwise.
    """
    view1, view2 = to_numpy((view1, view2))
    # project pixels from image1 --> 3d points --> image2 pixels
    shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
    shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)

    # compute reciprocal correspondences:
    # pos1 == valid pixels (correspondences) in image1
    is_reciprocal1, pos1, pos2 = reciprocal_1d(
        corres1_to_2, corres2_to_1, ret_recip=True
    )
    is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))

    if target_n_corres is None:
        if ret_xy:
            pos1 = unravel_xy(pos1, shape1)
            pos2 = unravel_xy(pos2, shape2)
        return pos1, pos2

    not_recip1 = ~is_reciprocal1
    not_recip2 = ~is_reciprocal2

    available_negatives = min(not_recip1.sum(), not_recip2.sum())
    target_n_positives = int(target_n_corres * (1 - nneg))
    n_positives = min(len(pos1), target_n_positives)
    n_negatives = min(target_n_corres - n_positives, available_negatives)

    if n_negatives + n_positives != target_n_corres:
        # should be really rare => when there are not enough negatives
        # in that case, break nneg and add a few more positives ?
        n_positives = target_n_corres - n_negatives

    assert n_positives <= len(pos1)
    assert n_positives <= len(pos2)
    assert n_negatives <= not_recip1.sum()
    assert n_negatives <= not_recip2.sum()
    assert n_positives + n_negatives == target_n_corres

    valid = np.ones(n_positives, dtype=bool)
    if n_positives < len(pos1):
        # random sub-sampling of valid correspondences
        keep = rng.permutation(len(pos1))[:n_positives]
        pos1, pos2 = pos1[keep], pos2[keep]

    if n_negatives > 0:
        # add false correspondences if not enough
        # (drawn uniformly among the non-reciprocal pixels of each image)
        negatives1 = rng.choice(
            shape1[0] * shape1[1],
            size=n_negatives,
            replace=False,
            p=not_recip1 / not_recip1.sum(),
        )
        pos1 = np.r_[pos1, negatives1]
        negatives2 = rng.choice(
            shape2[0] * shape2[1],
            size=n_negatives,
            replace=False,
            p=not_recip2 / not_recip2.sum(),
        )
        pos2 = np.r_[pos2, negatives2]
        valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]

    # convert (x+W*y) back to 2d (x,y) coordinates
    if ret_xy:
        pos1 = unravel_xy(pos1, shape1)
        pos2 = unravel_xy(pos2, shape2)
    return pos1, pos2, valid
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/cropping.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# croppping utilities
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import PIL.Image
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 11 |
+
import cv2 # noqa
|
| 12 |
+
import numpy as np # noqa
|
| 13 |
+
from dust3r.utils.geometry import (
|
| 14 |
+
colmap_to_opencv_intrinsics,
|
| 15 |
+
opencv_to_colmap_intrinsics,
|
| 16 |
+
) # noqa
|
| 17 |
+
|
| 18 |
+
# Use the filters from ``PIL.Image.Resampling`` when that namespace exists,
# otherwise fall back to the constants defined directly on ``PIL.Image``.
lanczos = getattr(PIL.Image, "Resampling", PIL.Image).LANCZOS
bicubic = getattr(PIL.Image, "Resampling", PIL.Image).BICUBIC
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ImageList:
    """Convenience class to apply the same operation to a whole set of images."""

    def __init__(self, images):
        # A single image (anything that is not a tuple/list/set) is wrapped.
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        # Non-PIL inputs are converted with PIL.Image.fromarray.
        self.images = [
            im if isinstance(im, PIL.Image.Image) else PIL.Image.fromarray(im)
            for im in images
        ]

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        """Return the single image, or a tuple when several are held."""
        if len(self.images) > 1:
            return tuple(self.images)
        return self.images[0]

    @property
    def size(self):
        """Common PIL ``size`` of all images (asserts that they agree)."""
        sizes = [im.size for im in self.images]
        assert all(sizes[0] == s for s in sizes)
        return sizes[0]

    def resize(self, *args, **kwargs):
        resized = self._dispatch("resize", *args, **kwargs)
        return ImageList(resized)

    def crop(self, *args, **kwargs):
        cropped = self._dispatch("crop", *args, **kwargs)
        return ImageList(cropped)

    def _dispatch(self, func, *args, **kwargs):
        """Call the method named ``func`` on every image and collect the results."""
        results = []
        for im in self.images:
            results.append(getattr(im, func)(*args, **kwargs))
        return results
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def rescale_image_depthmap(
    image, depthmap, camera_intrinsics, output_resolution, force=True
):
    """Jointly rescale a (image, depthmap)
    so that (out_width, out_height) >= output_res

    The aspect ratio is kept: one scale factor (the larger of the two
    per-axis ratios) is applied to both axes, and the intrinsics are rescaled
    to match.

    Args:
        image: a PIL image / array, or a list of them (wrapped in ImageList).
        depthmap: array whose first two dims are (H, W) of the image, or None.
            Can also be a mask; it is resized with nearest-neighbour sampling.
        camera_intrinsics: intrinsics matrix passed to camera_matrix_of_crop.
        output_resolution: requested (W, H).
        force: when False, inputs are returned untouched if the scale factor
            is >= 1 (i.e. no downscaling would be needed).

    Returns:
        (image, depthmap, camera_intrinsics) after rescaling.
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # can also use this with masks instead of depthmaps
        assert tuple(depthmap.shape[:2]) == image.size[::-1]

    # define output resolution
    assert output_resolution.shape == (2,)
    scale_final = max(output_resolution / image.size) + 1e-8
    if scale_final >= 1 and not force:  # nothing to shrink
        return (image.to_pil(), depthmap, camera_intrinsics)
    output_resolution = np.floor(input_resolution * scale_final).astype(int)
    # PIL and OpenCV both document the target size as a tuple of ints; pass
    # plain Python ints rather than a NumPy array so the call does not depend
    # on how a given library version coerces array arguments.
    target_size = tuple(int(v) for v in output_resolution)

    # first rescale the image so that it contains the crop
    image = image.resize(
        target_size, resample=lanczos if scale_final < 1 else bicubic
    )
    if depthmap is not None:
        depthmap = cv2.resize(
            depthmap,
            target_size,
            fx=scale_final,
            fy=scale_final,
            interpolation=cv2.INTER_NEAREST,
        )

    # no offset here; simple rescaling
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
    )

    return image.to_pil(), depthmap, camera_intrinsics
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def camera_matrix_of_crop(
    input_camera_matrix,
    input_resolution,
    output_resolution,
    scaling=1,
    offset_factor=0.5,
    offset=None,
):
    """Return the intrinsics after scaling by ``scaling`` and cropping.

    The crop origin is ``offset`` when given, otherwise ``offset_factor`` times
    the margin left between the scaled input and ``output_resolution``.  The
    adjustment is applied to the matrix returned by
    ``opencv_to_colmap_intrinsics`` and converted back afterwards.
    """
    # Margins to offset the origin
    scaled_input = np.asarray(input_resolution) * scaling
    margins = scaled_input - output_resolution
    assert np.all(margins >= 0.0)
    crop_origin = offset_factor * margins if offset is None else offset

    # Generate new camera parameters
    colmap_matrix = opencv_to_colmap_intrinsics(input_camera_matrix)
    colmap_matrix[:2, :] *= scaling
    colmap_matrix[:2, 2] -= crop_origin
    return colmap_to_opencv_intrinsics(colmap_matrix)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
    """
    Return a crop of the input view.

    ``crop_bbox`` is ``(left, top, right, bottom)``.  The image and depthmap
    are cropped to it and the principal point of a copy of the intrinsics is
    shifted by ``(left, top)``.
    """
    left, top, right, bottom = crop_bbox

    cropped_image = ImageList(image).crop((left, top, right, bottom))
    cropped_depth = depthmap[top:bottom, left:right]

    shifted_intrinsics = camera_intrinsics.copy()
    shifted_intrinsics[0, 2] -= left
    shifted_intrinsics[1, 2] -= top

    return cropped_image.to_pil(), cropped_depth, shifted_intrinsics
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def bbox_from_intrinsics_in_out(
    input_camera_matrix, output_camera_matrix, output_resolution
):
    """Return the ``(l, t, r, b)`` crop box implied by a principal-point shift.

    The top-left corner is the rounded difference between the input and output
    principal points; the box has size ``output_resolution`` = (width, height).
    """
    out_width, out_height = output_resolution
    principal_shift = input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]
    left, top = np.int32(np.round(principal_shift))
    return (left, top, left + out_width, top + out_height)
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/transforms.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# DUST3R default transforms
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import torchvision.transforms as tvf
|
| 8 |
+
from dust3r.utils.image import ImgNorm
|
| 9 |
+
|
| 10 |
+
# Standard image transform: random color jitter followed by ImgNorm.
ColorJitter = tvf.Compose(
    [
        tvf.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
        ImgNorm,
    ]
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
|
| 15 |
+
if isinstance(value, (int, float)):
|
| 16 |
+
if value < 0:
|
| 17 |
+
raise ValueError(f"If is a single number, it must be non negative.")
|
| 18 |
+
value = [center - float(value), center + float(value)]
|
| 19 |
+
if clip_first_on_zero:
|
| 20 |
+
value[0] = max(value[0], 0.0)
|
| 21 |
+
elif isinstance(value, (tuple, list)) and len(value) == 2:
|
| 22 |
+
value = [float(value[0]), float(value[1])]
|
| 23 |
+
else:
|
| 24 |
+
raise TypeError(f"should be a single number or a list/tuple with length 2.")
|
| 25 |
+
|
| 26 |
+
if not bound[0] <= value[0] <= value[1] <= bound[1]:
|
| 27 |
+
raise ValueError(f"values should be between {bound}, but got {value}.")
|
| 28 |
+
|
| 29 |
+
# if value is 0 or (1., 1.) for brightness/contrast/saturation
|
| 30 |
+
# or (0., 0.) for hue, do nothing
|
| 31 |
+
if value[0] == value[1] == center:
|
| 32 |
+
return None
|
| 33 |
+
else:
|
| 34 |
+
return tuple(value)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
import torch
|
| 38 |
+
import torchvision.transforms.functional as F
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def SeqColorJitter():
    """
    Return a color jitter transform with same random parameters

    The operation order and the four factors are drawn once, when this
    function is called; the returned callable applies that fixed jitter to
    any image it receives and then runs ``ImgNorm`` on the result.
    """
    brightness = _check_input(0.5)
    contrast = _check_input(0.5)
    saturation = _check_input(0.5)
    hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    def _draw(bounds):
        # One uniform sample inside ``bounds``; None when jitter is disabled.
        if bounds is None:
            return None
        return float(torch.empty(1).uniform_(bounds[0], bounds[1]))

    # Random draws happen in this order: permutation, brightness, contrast,
    # saturation, hue.
    order = torch.randperm(4).tolist()
    brightness_factor = _draw(brightness)
    contrast_factor = _draw(contrast)
    saturation_factor = _draw(saturation)
    hue_factor = _draw(hue)

    # Index i in ``order`` selects the i-th (adjustment, factor) pair.
    adjustments = (
        (F.adjust_brightness, brightness_factor),
        (F.adjust_contrast, contrast_factor),
        (F.adjust_saturation, saturation_factor),
        (F.adjust_hue, hue_factor),
    )

    def _color_jitter(img):
        for op_id in order:
            adjust, factor = adjustments[op_id]
            if factor is not None:
                img = adjust(img, factor)
        return ImgNorm(img)

    return _color_jitter
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/waymo.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 7 |
+
import h5py
|
| 8 |
+
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
|
| 9 |
+
from dust3r.utils.image import imread_cv2
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Waymo_Multi(BaseMultiViewDataset):
|
| 13 |
+
"""Dataset of outdoor street scenes, 5 images each time"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, *args, ROOT, **kwargs):
|
| 16 |
+
self.ROOT = ROOT
|
| 17 |
+
self.max_interval = 8
|
| 18 |
+
self.video = True
|
| 19 |
+
self.is_metric = True
|
| 20 |
+
super().__init__(*args, **kwargs)
|
| 21 |
+
assert self.split is None
|
| 22 |
+
self._load_data()
|
| 23 |
+
|
| 24 |
+
def load_invalid_dict(self, h5_file_path):
|
| 25 |
+
invalid_dict = {}
|
| 26 |
+
with h5py.File(h5_file_path, "r") as h5f:
|
| 27 |
+
for scene in h5f:
|
| 28 |
+
data = h5f[scene]["invalid_pairs"][:]
|
| 29 |
+
invalid_pairs = set(
|
| 30 |
+
tuple(pair.decode("utf-8").split("_")) for pair in data
|
| 31 |
+
)
|
| 32 |
+
invalid_dict[scene] = invalid_pairs
|
| 33 |
+
return invalid_dict
|
| 34 |
+
|
| 35 |
+
def _load_data(self):
|
| 36 |
+
invalid_dict = self.load_invalid_dict(
|
| 37 |
+
os.path.join(self.ROOT, "invalid_files.h5")
|
| 38 |
+
)
|
| 39 |
+
scene_dirs = sorted(
|
| 40 |
+
[
|
| 41 |
+
d
|
| 42 |
+
for d in os.listdir(self.ROOT)
|
| 43 |
+
if os.path.isdir(os.path.join(self.ROOT, d))
|
| 44 |
+
]
|
| 45 |
+
)
|
| 46 |
+
offset = 0
|
| 47 |
+
scenes = []
|
| 48 |
+
sceneids = []
|
| 49 |
+
images = []
|
| 50 |
+
start_img_ids = []
|
| 51 |
+
scene_img_list = []
|
| 52 |
+
is_video = []
|
| 53 |
+
j = 0
|
| 54 |
+
|
| 55 |
+
for scene in scene_dirs:
|
| 56 |
+
scene_dir = osp.join(self.ROOT, scene)
|
| 57 |
+
invalid_pairs = invalid_dict.get(scene, set())
|
| 58 |
+
seq2frames = {}
|
| 59 |
+
for f in os.listdir(scene_dir):
|
| 60 |
+
if not f.endswith(".jpg"):
|
| 61 |
+
continue
|
| 62 |
+
basename = f[:-4]
|
| 63 |
+
frame_id = basename.split("_")[0]
|
| 64 |
+
seq_id = basename.split("_")[1]
|
| 65 |
+
if seq_id == "5":
|
| 66 |
+
continue
|
| 67 |
+
if (seq_id, frame_id) in invalid_pairs:
|
| 68 |
+
continue # Skip invalid files
|
| 69 |
+
if seq_id not in seq2frames:
|
| 70 |
+
seq2frames[seq_id] = []
|
| 71 |
+
seq2frames[seq_id].append(frame_id)
|
| 72 |
+
|
| 73 |
+
for seq_id, frame_ids in seq2frames.items():
|
| 74 |
+
frame_ids = sorted(frame_ids)
|
| 75 |
+
num_imgs = len(frame_ids)
|
| 76 |
+
img_ids = list(np.arange(num_imgs) + offset)
|
| 77 |
+
cut_off = (
|
| 78 |
+
self.num_views
|
| 79 |
+
if not self.allow_repeat
|
| 80 |
+
else max(self.num_views // 3, 3)
|
| 81 |
+
)
|
| 82 |
+
start_img_ids_ = img_ids[: num_imgs - cut_off + 1]
|
| 83 |
+
|
| 84 |
+
if num_imgs < cut_off:
|
| 85 |
+
print(f"Skipping {scene}_{seq_id}")
|
| 86 |
+
continue
|
| 87 |
+
|
| 88 |
+
scenes.append((scene, seq_id))
|
| 89 |
+
sceneids.extend([j] * num_imgs)
|
| 90 |
+
images.extend(frame_ids)
|
| 91 |
+
start_img_ids.extend(start_img_ids_)
|
| 92 |
+
scene_img_list.append(img_ids)
|
| 93 |
+
|
| 94 |
+
offset += num_imgs
|
| 95 |
+
j += 1
|
| 96 |
+
|
| 97 |
+
self.scenes = scenes
|
| 98 |
+
self.sceneids = sceneids
|
| 99 |
+
self.images = images
|
| 100 |
+
self.start_img_ids = start_img_ids
|
| 101 |
+
self.scene_img_list = scene_img_list
|
| 102 |
+
self.is_video = is_video
|
| 103 |
+
|
| 104 |
+
def __len__(self):
    """Return the number of view groups, i.e. the number of valid start image ids."""
    start_ids = self.start_img_ids
    return len(start_ids)
|
| 106 |
+
|
| 107 |
+
def get_image_num(self):
    """Return the total number of frame ids stored in ``self.images``."""
    total_frames = len(self.images)
    return total_frames
|
| 109 |
+
|
| 110 |
+
def get_stats(self):
    """Return a short human-readable summary of the dataset size."""
    n_groups = len(self)
    return f"{n_groups} groups of views"
|
| 112 |
+
|
| 113 |
+
def _get_views(self, idx, resolution, rng, num_views):
    """Load the views of one sample, starting from start-image entry ``idx``.

    Args:
        idx: index into ``self.start_img_ids``.
        resolution: target resolution forwarded to
            ``self._crop_resize_if_necessary``.
        rng: random generator forwarded to the sampling / cropping / masking
            helpers.
        num_views: number of views requested from
            ``self.get_seq_from_start_id``.

    Returns:
        A list with one dict per selected frame, holding the image, depth
        map, cam2world pose, intrinsics, masks and bookkeeping flags.
    """
    start_id = self.start_img_ids[idx]
    # Look the scene index up once; it addresses both the image-id list and
    # the (scene, seq_id) tuple.
    scene_idx = self.sceneids[start_id]
    all_image_ids = self.scene_img_list[scene_idx]
    _, seq_id = self.scenes[scene_idx]

    # Sequence "4" is sampled with half the maximum frame interval.
    max_interval = self.max_interval // 2 if seq_id == "4" else self.max_interval
    pos, _ = self.get_seq_from_start_id(
        num_views,
        start_id,
        all_image_ids,
        rng,
        max_interval=max_interval,
        video_prob=0.9,
        fix_interval_prob=0.9,
        block_shuffle=16,
    )
    image_idxs = np.array(all_image_ids)[pos]

    # NOTE(review): the "ordered video" flag returned by the sampler is
    # discarded and every view is tagged as video -- confirm this is intended.
    ordered_video = True

    views = []
    for v, view_idx in enumerate(image_idxs):
        scene_id = self.sceneids[view_idx]
        scene_name, seq_id = self.scenes[scene_id]
        scene_dir = osp.join(self.ROOT, scene_name)
        frame_id = self.images[view_idx]

        # Files are named "<frame_id>_<seq_id>.<ext>" inside the scene folder.
        impath = f"{frame_id}_{seq_id}"
        image = imread_cv2(osp.join(scene_dir, impath + ".jpg"))
        depthmap = imread_cv2(osp.join(scene_dir, impath + ".exr"))
        camera_params = np.load(osp.join(scene_dir, impath + ".npz"))

        intrinsics = np.float32(camera_params["intrinsics"])
        camera_pose = np.float32(camera_params["cam2world"])

        image, depthmap, intrinsics = self._crop_resize_if_necessary(
            image, depthmap, intrinsics, resolution, rng, info=(scene_dir, impath)
        )

        # generate img mask and raymap mask
        img_mask, ray_mask = self.get_img_and_ray_masks(
            self.is_metric, v, rng, p=[0.85, 0.10, 0.05]
        )

        views.append(
            dict(
                img=image,
                depthmap=depthmap,
                camera_pose=camera_pose,  # cam2world
                camera_intrinsics=intrinsics,
                dataset="Waymo",
                label=osp.relpath(scene_dir, self.ROOT),
                is_metric=self.is_metric,
                instance=osp.join(scene_dir, impath + ".jpg"),
                is_video=ordered_video,
                quantile=np.array(0.98, dtype=np.float32),
                img_mask=img_mask,
                ray_mask=ray_mask,
                camera_only=False,
                depth_only=False,
                single_view=False,
                reset=False,
            )
        )

    return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/wildrgbd.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from dust3r.datasets.co3d import Co3d_Multi
|
| 9 |
+
from dust3r.utils.image import imread_cv2
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class WildRGBD_Multi(Co3d_Multi):
    """WildRGBD multi-view dataset, implemented as a thin layer over ``Co3d_Multi``.

    The subclass overrides the file-layout helpers (rgb / depth / masks /
    metadata folders with 5-digit frame names), the depth decoding, and the
    per-view ``quantile`` value.
    """

    def __init__(self, mask_bg="rand", *args, ROOT, **kwargs):
        super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
        self.dataset_label = "WildRGBD"
        self.is_metric = True

        # Remove one specific scene if it is present, then snapshot the keys.
        self.scenes.pop(("box", "scenes/scene_257"), None)
        self.scene_list = list(self.scenes.keys())

        # Minimum number of frames a scene must still offer after a start frame.
        if not self.allow_repeat:
            self.cut_off = self.num_views
        else:
            self.cut_off = max(self.num_views // 3, 3)

        # Every (scene, frame) pair that may serve as the first view of a sample.
        ref_imgs = []
        for key, values in self.scenes.items():
            n_starts = len(values) - self.cut_off + 1
            for value in values[:n_starts]:
                ref_imgs.append((key, value))
        self.all_ref_imgs = ref_imgs

        self.invalidate = {scene: {} for scene in self.scene_list}
        self.invalid_scenes = dict.fromkeys(self.scene_list, False)

    def _get_metadatapath(self, obj, instance, view_idx):
        """Return the path of the ``.npz`` metadata file of one frame."""
        fname = f"{view_idx:0>5d}.npz"
        return osp.join(self.ROOT, obj, instance, "metadata", fname)

    def _get_impath(self, obj, instance, view_idx):
        """Return the path of the RGB ``.jpg`` file of one frame."""
        fname = f"{view_idx:0>5d}.jpg"
        return osp.join(self.ROOT, obj, instance, "rgb", fname)

    def _get_depthpath(self, obj, instance, view_idx):
        """Return the path of the depth ``.png`` file of one frame."""
        fname = f"{view_idx:0>5d}.png"
        return osp.join(self.ROOT, obj, instance, "depth", fname)

    def _get_maskpath(self, obj, instance, view_idx):
        """Return the path of the mask ``.png`` file of one frame."""
        fname = f"{view_idx:0>5d}.png"
        return osp.join(self.ROOT, obj, instance, "masks", fname)

    def _read_depthmap(self, depthpath, input_metadata):
        """Read a depth image and convert it to float32 metres.

        Depths are stored with a scale of 1000, so dividing the raw values by
        1000 gives metres. ``input_metadata`` is accepted but not used here.
        """
        raw = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
        return raw.astype(np.float32) / 1000.0

    def _get_views(self, idx, resolution, rng, num_views):
        """Return the parent's views with the ``quantile`` entry set to 0.96."""
        views = super()._get_views(idx, resolution, rng, num_views)
        quantile = 0.96
        for view in views:
            assert view["is_metric"]
            view["quantile"] = np.array(quantile, dtype=np.float32)
        return views
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/camera.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from croco.models.blocks import Mlp
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
inf = float("inf")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PoseDecoder(nn.Module):
    """MLP head that maps a pose feature vector to a pose encoding.

    Args:
        hidden_size: width of the incoming pose feature.
        mlp_ratio: the hidden width of the MLP is ``hidden_size * mlp_ratio``.
        pose_encoding_type: only ``"absT_quaR"`` is supported (7 outputs).

    Raises:
        ValueError: if ``pose_encoding_type`` is not supported.
    """

    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_encoding_type="absT_quaR",
    ):
        super().__init__()

        self.pose_encoding_type = pose_encoding_type
        if self.pose_encoding_type == "absT_quaR":
            self.target_dim = 7
        else:
            # Fail early with the same error the pose conversion helpers raise,
            # instead of a later AttributeError on the missing `target_dim`.
            raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=self.target_dim,
            drop=0,
        )

    def forward(
        self,
        pose_feat,
    ):
        """Regress the pose encoding.

        Args:
            pose_feat: pose features of shape BxC.

        Returns:
            Tensor of shape Bx7: 3 values for absT followed by 4 for quaR.
        """
        pred_cameras = self.mlp(pose_feat)  # Bx7, 3 for absT, 4 for quaR
        return pred_cameras
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class PoseEncoder(nn.Module):
    """Encode camera poses into feature vectors of width ``hidden_size``.

    The pose is converted to a pose encoding, passed through the inverse of
    ``postprocess_pose``, embedded with harmonic functions and fed to an MLP.

    Args:
        hidden_size: width of the produced pose feature.
        mlp_ratio: the hidden width of the MLP is ``hidden_size * mlp_ratio``.
        pose_mode: forwarded unchanged to ``postprocess_pose``.
        pose_encoding_type: only ``"absT_quaR"`` is supported (7 values).

    Raises:
        ValueError: if ``pose_encoding_type`` is not supported.
    """

    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_mode=("exp", -inf, inf),
        pose_encoding_type="absT_quaR",
    ):
        super().__init__()
        self.pose_encoding_type = pose_encoding_type
        self.pose_mode = pose_mode

        if self.pose_encoding_type == "absT_quaR":
            self.target_dim = 7
        else:
            # Fail early with the same error the pose conversion helpers raise,
            # instead of a later AttributeError on the missing `target_dim`.
            raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

        self.embed_pose = PoseEmbedding(
            target_dim=self.target_dim,
            out_dim=hidden_size,
            n_harmonic_functions=10,
            append_input=True,
        )
        self.pose_encoder = Mlp(
            in_features=self.embed_pose.out_dim,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=hidden_size,
            drop=0,
        )

    def forward(self, camera):
        """Return pose features for a batch of camera matrices.

        Args:
            camera: batch of camera matrices; the rotation is read from
                ``camera[:, :3, :3]`` and the translation from
                ``camera[:, :3, 3]``.
        """
        # Imported inside the method -- presumably to avoid a circular import
        # between the heads and utils packages; verify before moving it.
        from dust3r.heads.postprocess import postprocess_pose

        pose_enc = camera_to_pose_encoding(
            camera,
            pose_encoding_type=self.pose_encoding_type,
        ).to(camera.dtype)
        pose_enc = postprocess_pose(pose_enc, self.pose_mode, inverse=True)
        pose_feat = self.embed_pose(pose_enc)
        pose_feat = self.pose_encoder(pose_feat)
        return pose_feat
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class HarmonicEmbedding(torch.nn.Module):
    """Harmonic (sin/cos) embedding of the last dimension of a tensor.

    This is the positional encoding described in
    `NeRF <https://arxiv.org/abs/2003.08934>`_ and, when `diag_cov` is given
    to `forward`, the integrated positional encoding of
    `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.

    With frequencies `f_1, ..., f_N` (`N == n_harmonic_functions`) and an
    input `x` of shape `[..., dim]`, the output is the concatenation along
    the last dimension of three parts::

        [
            sin(f_k * x[..., i])  for i in range(dim) for k in 1..N,
            cos(f_k * x[..., i])  for i in range(dim) for k in 1..N,
            x[..., i]             for i in range(dim),  # only if append_input
        ]

    i.e. all sine terms first, then all cosine terms, then the raw input.
    The cosine is evaluated as `sin(. + pi/2)`.

    If `diag_cov` is not None, `x` and `diag_cov` are treated as the means
    and diagonal covariances of Gaussians, and every sine / cosine term is
    additionally multiplied by `exp(-0.5 * f_k**2 * diag_cov[..., i])`.

    If `logspace==True`, the frequencies are powers of 2:
        `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`

    If `logspace==False`, frequencies are linearly spaced between
    `1.0` and `2**(n_harmonic_functions-1)`:
        `f_1, ..., f_N = torch.linspace(
            1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
        )`

    All frequencies are multiplied by the base frequency `omega_0`.
    """

    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        Args:
            n_harmonic_functions: int, number of harmonic frequencies
            omega_0: float, base frequency
            logspace: bool, whether to space the frequencies in
                logspace or linear space
            append_input: bool, whether to concat the original
                input to the harmonic embedding. If true the
                output is of the form (embed.sin(), embed.cos(), x)
        """
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )

        # Non-persistent buffers: they follow the module across devices but
        # are not written to the state dict.
        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
        self.register_buffer(
            "_zero_half_pi",
            torch.tensor([0.0, 0.5 * torch.pi]),
            persistent=False,
        )
        self.append_input = append_input

    def forward(
        self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
            diag_cov: An optional tensor of shape `(..., dim)`
                representing the diagonal covariance matrices of our Gaussians,
                joined with x as means of the Gaussians.
            **kwargs: accepted and ignored.

        Returns:
            embedding: a harmonic embedding of `x` of shape
                [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
        """
        # [..., dim, N]: every input feature times every frequency.
        embed = x[..., None] * self._frequencies

        # [..., 2, dim, N]: phase 0 gives the sine, phase pi/2 the cosine.
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]

        embed = embed.sin()
        if diag_cov is not None:
            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
            exp_var = torch.exp(-0.5 * x_var)

            embed = embed * exp_var[..., None, :, :]

        # Flatten (phase, dim, frequency) into the last dimension.
        embed = embed.reshape(*x.shape[:-1], -1)

        if self.append_input:
            return torch.cat([embed, x], dim=-1)
        return embed

    @staticmethod
    def get_output_dim_static(
        input_dims: int, n_harmonic_functions: int, append_input: bool
    ) -> int:
        """
        Utility to help predict the shape of the output of `forward`.

        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether or not to concat the original
                input to the harmonic embedding
        Returns:
            int: the length of the last dimension of the output tensor
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as `get_output_dim_static`, using this module's own number of
        frequencies and `append_input` setting. The default for input_dims is
        3 for 3D applications which use harmonic embedding for positional
        encoding, so the input might be xyz.
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class PoseEmbedding(nn.Module):
    """Harmonic embedding of a pose encoding vector.

    Note that the ``out_dim`` constructor argument is not used: the
    ``out_dim`` attribute is always the output width of the harmonic
    embedding for ``target_dim`` inputs.
    """

    def __init__(self, target_dim, out_dim, n_harmonic_functions=10, append_input=True):
        super().__init__()

        harmonic = HarmonicEmbedding(
            n_harmonic_functions=n_harmonic_functions,
            append_input=append_input,
        )
        self._emb_pose = harmonic
        self.out_dim = harmonic.get_output_dim(target_dim)

    def forward(self, pose_encoding):
        """Return the harmonic embedding of ``pose_encoding``."""
        return self._emb_pose(pose_encoding)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Compute sqrt(max(0, x)) element-wise.

    Only strictly positive entries go through `sqrt`; all others stay at the
    zero they were initialised with, which gives a zero subgradient where
    x is 0.
    """
    out = torch.zeros_like(x)
    is_positive = x > 0
    out[is_positive] = x[is_positive].sqrt()
    return out
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Turn rotation matrices into quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the last two dimensions of `matrix` are not 3x3.
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    lead_shape = matrix.shape[:-2]
    flat = matrix.reshape(lead_shape + (9,))
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(flat, dim=-1)

    # Four candidate magnitudes, one per quaternion component.
    squared = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    )
    q_abs = _sqrt_positive_part(squared)

    # One candidate quaternion per row; each row is divided below by twice
    # the q_abs entry with the same index.
    row_r = torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1)
    row_i = torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1)
    row_j = torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1)
    row_k = torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1)
    quat_by_rijk = torch.stack([row_r, row_i, row_j, row_k], dim=-2)

    # Floor the denominator at 0.1; the candidate finally selected is the one
    # with the largest q_abs.
    floor = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    denom = 2.0 * q_abs[..., None].max(floor)
    quat_candidates = quat_by_rijk / denom

    best = F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5
    picked = quat_candidates[best, :].reshape(lead_shape + (4,))
    return standardize_quaternion(picked)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Normalise quaternions to unit length and flip their sign where needed so
    that the real part is non negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    unit = F.normalize(quaternions, p=2, dim=-1)
    negative_real = unit[..., 0:1] < 0
    return torch.where(negative_real, -unit, unit)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def camera_to_pose_encoding(
    camera,
    pose_encoding_type="absT_quaR",
):
    """
    Inverse to pose_encoding_to_camera.

    Args:
        camera: opencv, cam2world; rotation is read from `camera[:, :3, :3]`
            and translation from `camera[:, :3, 3]`.
        pose_encoding_type: only "absT_quaR" is supported.

    Returns:
        Tensor whose last dimension holds the 3 translation values followed
        by the 4 quaternion values (real part first).

    Raises:
        ValueError: for any other `pose_encoding_type`.
    """
    if pose_encoding_type != "absT_quaR":
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

    rotation = camera[:, :3, :3]
    translation = camera[:, :3, 3]
    quaternion_R = matrix_to_quaternion(rotation)
    return torch.cat([translation, quaternion_R], dim=-1)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Turn quaternions into rotation matrices.

    The input does not need to be normalised: every term is scaled by
    2 / |q|^2.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    squared_norm = (quaternions * quaternions).sum(-1)
    two_s = 2.0 / squared_norm

    # Matrix entries in row-major order.
    entries = (
        1 - two_s * (j * j + k * k),
        two_s * (i * j - k * r),
        two_s * (i * k + j * r),
        two_s * (i * j + k * r),
        1 - two_s * (i * i + k * k),
        two_s * (j * k - i * r),
        two_s * (i * k - j * r),
        two_s * (j * k + i * r),
        1 - two_s * (i * i + j * j),
    )
    flat = torch.stack(entries, -1)
    out_shape = quaternions.shape[:-1] + (3, 3)
    return flat.reshape(out_shape)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def pose_encoding_to_camera(
    pose_encoding,
    pose_encoding_type="absT_quaR",
):
    """
    Build 4x4 camera matrices from pose encodings.

    Args:
        pose_encoding: A tensor of shape `BxC`, containing a batch of
            `B` `C`-dimensional pose encodings. Columns 0-2 are read as the
            translation and columns 3-6 as the quaternion.
        pose_encoding_type: The type of pose encoding; only "absT_quaR" is
            supported.

    Returns:
        Tensor of shape Bx4x4 with the rotation in the upper-left 3x3 block
        and the translation in the last column.

    Raises:
        ValueError: for any other `pose_encoding_type`.
    """
    if pose_encoding_type != "absT_quaR":
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

    abs_T = pose_encoding[:, :3]
    R = quaternion_to_matrix(pose_encoding[:, 3:7])

    # Start from a batch of identity matrices, then fill in R and T.
    identity = torch.eye(4, 4).to(R.dtype).to(R.device)
    c2w_mats = identity[None].repeat(len(R), 1, 1)
    c2w_mats[:, :3, :3] = R
    c2w_mats[:, :3, 3] = abs_T
    return c2w_mats
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def quaternion_conjugate(q):
    """Return the conjugate of quaternion q (w, x, y, z): same w, negated x, y, z."""
    real_part = q[..., :1]
    imag_part = q[..., 1:]
    return torch.cat([real_part, -imag_part], dim=-1)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def quaternion_multiply(q1, q2):
    """Return the Hamilton product q1 * q2 of two (w, x, y, z) quaternions."""
    w1, x1, y1, z1 = q1.unbind(dim=-1)
    w2, x2, y2, z2 = q2.unbind(dim=-1)

    product = (
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,  # w
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,  # x
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,  # y
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,  # z
    )
    return torch.stack(product, dim=-1)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def rotate_vector(q, v):
    """Rotate vector v by quaternion q (w, x, y, z).

    Computes v + w * t + cross(q_xyz, t) with t = 2 * cross(q_xyz, v).
    NOTE(review): this form presumably expects q to be a unit quaternion --
    confirm against callers.
    """
    imag = q[..., 1:]
    real = q[..., :1]

    t = 2.0 * torch.cross(imag, v, dim=-1)
    return v + real * t + torch.cross(imag, t, dim=-1)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def relative_pose_absT_quatR(t1, q1, t2, q2):
    """Return the translation and quaternion of pose 2 relative to pose 1.

    The relative rotation is conj(q1) * q2 and the relative translation is
    (t2 - t1) rotated by conj(q1).
    NOTE(review): using the conjugate as the inverse presumably assumes unit
    quaternions -- confirm against callers.

    Returns:
        Tuple `(t_rel, q_rel)`.
    """
    q1_inv = quaternion_conjugate(q1)
    q_rel = quaternion_multiply(q1_inv, q2)
    t_rel = rotate_vector(q1_inv, t2 - t1)
    return t_rel, q_rel
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/device.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# modified from DUSt3R
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    Dicts, lists and tuples are traversed recursively and rebuilt with the
    same container type; `callback` and `non_blocking` are forwarded to every
    nested call.

    Args:
        batch: list, tuple, dict of tensors or other things
        device: pytorch device or 'numpy'
        callback: function that would be called on every sub-elements
            (containers included) before they are processed.
        non_blocking: forwarded to `Tensor.to` for every transferred tensor.

    Returns:
        The converted structure. With device 'numpy', tensors become numpy
        arrays and everything else is returned unchanged. Otherwise numpy
        arrays are turned into tensors, tensors are moved to `device`, and
        non-tensor leaves (including None) are returned unchanged.
    """
    if callback:
        batch = callback(batch)

    if isinstance(batch, dict):
        return {
            k: todevice(v, device, callback=callback, non_blocking=non_blocking)
            for k, v in batch.items()
        }

    if isinstance(batch, (tuple, list)):
        return type(batch)(
            todevice(x, device, callback=callback, non_blocking=non_blocking)
            for x in batch
        )

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
to_device = todevice # alias
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def to_numpy(x):
    """Convert every tensor inside `x` to a numpy array (see `todevice`)."""
    return todevice(x, device="numpy")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def to_cpu(x):
    """Move every tensor / array inside `x` to the "cpu" device (see `todevice`)."""
    return todevice(x, device="cpu")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def to_cuda(x):
    """Move every tensor / array inside `x` to the "cuda" device (see `todevice`)."""
    return todevice(x, device="cuda")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def collate_with_cat(whatever, lists=False):
    """Recursively merge a collection of per-sample results.

    Dicts are merged value by value. For a list / tuple, the type of its
    first element decides how it is merged:

    - empty sequence, or bool / float / int / str elements: returned as is;
    - None: the result is None;
    - tuples: transposed with `zip`, each group merged recursively;
    - dicts: merged key by key (keys of the first element);
    - tensors: `torch.cat`, or flattened with `listify` when `lists` is true;
    - numpy arrays: converted to tensors and `torch.cat`, or flattened with
      `listify` when `lists` is true;
    - anything else: concatenated with `sum` starting from an empty
      container of the same type.

    Any other input type yields None.
    """
    if isinstance(whatever, dict):
        return {key: collate_with_cat(val, lists=lists) for key, val in whatever.items()}

    if not isinstance(whatever, (tuple, list)):
        return None

    if len(whatever) == 0:
        return whatever

    first = whatever[0]
    container = type(whatever)

    if first is None:
        return None
    if isinstance(first, (bool, float, int, str)):
        return whatever
    if isinstance(first, tuple):
        return container(
            collate_with_cat(group, lists=lists) for group in zip(*whatever)
        )
    if isinstance(first, dict):
        return {
            key: collate_with_cat([item[key] for item in whatever], lists=lists)
            for key in first
        }

    if isinstance(first, torch.Tensor):
        if lists:
            return listify(whatever)
        return torch.cat(whatever)
    if isinstance(first, np.ndarray):
        if lists:
            return listify(whatever)
        return torch.cat([torch.from_numpy(arr) for arr in whatever])

    return sum(whatever, container())
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def listify(elems):
    """Flatten one level: return a list with the items of every element of `elems`."""
    flat = []
    for group in elems:
        for item in group:
            flat.append(item)
    return flat
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/geometry.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# modified from DUSt3R
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy.spatial import cKDTree as KDTree
|
| 10 |
+
|
| 11 |
+
from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
|
| 12 |
+
from dust3r.utils.device import to_numpy
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """Build a grid of pixel coordinates.

    With the default ``cat_dim=-1`` the result has shape (H, W, 2) with
        output[j, i, 0] = i + origin[0]
        output[j, i, 1] = j + origin[1]

    Args:
        W, H: grid width and height.
        device: ``None`` builds numpy arrays; any other value builds torch
            tensors on that device.
        origin: (x0, y0) offsets added to the x and y coordinates.
        unsqueeze: if not None, axis inserted into every component before
            stacking.
        cat_dim: axis along which the components are stacked; ``None``
            returns the tuple of components instead.
        homogeneous: if True, append an all-ones (H, W) component.
        **arange_kw: forwarded to ``arange`` (for example ``dtype``).

    Returns:
        The stacked grid, or a tuple of components when ``cat_dim`` is None.
    """
    use_numpy = device is None
    if use_numpy:
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
    else:

        def arange(*a, **kw):
            return torch.arange(*a, device=device, **kw)

        def ones(*a):
            return torch.ones(*a, device=device)

        meshgrid, stack = torch.meshgrid, torch.stack

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    # meshgrid may return a list or a tuple depending on the library/version;
    # normalise to a tuple so the concatenation below always works.
    grid = tuple(meshgrid(tw, th, indexing="xy"))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        # Apply to every component (including the homogeneous one) and use the
        # numpy equivalent when the grid is made of numpy arrays.
        if use_numpy:
            grid = tuple(np.expand_dims(g, unsqueeze) for g in grid)
        else:
            grid = tuple(g.unsqueeze(unsqueeze) for g in grid)
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a set of points.

    Args:
        Trf: numpy array or torch tensor whose last two dimensions form the
            transformation matrix; leading dimensions are batch dimensions.
        pts: coordinates (array, tensor or nested sequence). It is converted
            to the same library as ``Trf``. When the last dimension of ``pts``
            is one less than the matrix size, the last matrix column is added
            as a translation.
        ncol: number of trailing columns kept in the result; defaults to the
            last dimension of ``pts``.
        norm: if truthy, the result is divided by its last coordinate and,
            when ``norm != 1``, multiplied by ``norm``.

    Returns:
        The transformed points with shape ``pts.shape[:-1] + (ncol,)``.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    leading_shape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    batched_pointmap = (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    )
    if batched_pointmap:
        # Trf is (B, d, d) or (B, d+1, d+1) and pts is (B, H, W, d).
        d = pts.shape[3]
        mat_size = Trf.shape[-1]
        if mat_size == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif mat_size == d + 1:
            linear_part = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
            pts = linear_part + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # flatten everything between the batch and coordinate axes
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # one point per batch element
                pts = pts[:, None, :]

        point_dim, mat_size = pts.shape[-1], Trf.shape[-1]
        if point_dim + 1 == mat_size:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif point_dim == mat_size:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        # out-of-place division on purpose (the original notes a PyTorch issue with /=)
        pts = pts / pts[..., -1:]
        if norm != 1:
            pts *= norm

    return pts[..., :ncol].reshape(*leading_shape, ncol)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def inv(mat):
    """Return the inverse of a torch tensor or a numpy array.

    Raises:
        ValueError: if ``mat`` is neither a ``torch.Tensor`` nor a ``np.ndarray``.
    """
    backends = (
        (torch.Tensor, torch.linalg.inv),
        (np.ndarray, np.linalg.inv),
    )
    for array_type, invert in backends:
        if isinstance(mat, array_type):
            return invert(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
    """Unproject depth maps to 3D points using per-pixel focal lengths.

    Args:
        depth: tensor of shape (B, H, W), or (B, H, W, n).
        pseudo_focal: tensor of shape (B, H, W), (B, 2, H, W) or (B, 1, H, W).
            With two channels, channel 0 is used for x and channel 1 for y.
        pp: optional principal points indexed as ``pp[:, 0]`` (x) and
            ``pp[:, 1]`` (y). When None, the image centre
            ((W - 1) / 2, (H - 1) / 2) is used.
        **_: extra keyword arguments are accepted and ignored.

    Returns:
        Tensor of shape (B, H, W, 3), or (B, H, W, 3, n) for 4-D ``depth``.

    Raises:
        NotImplementedError: if ``pseudo_focal`` is neither 3-D nor 4-D.
    """
    if len(depth.shape) == 4:
        B, H, W, n = depth.shape
    else:
        B, H, W = depth.shape
        n = None

    focal_rank = len(pseudo_focal.shape)
    if focal_rank == 3:  # [B,H,W]
        focal_x = focal_y = pseudo_focal
    elif focal_rank == 4:  # [B,2,H,W] or [B,1,H,W]
        focal_x = pseudo_focal[:, 0]
        focal_y = pseudo_focal[:, 1] if pseudo_focal.shape[1] == 2 else focal_x
    else:
        raise NotImplementedError("Error, unknown input focal shape format.")

    assert focal_x.shape == depth.shape[:3]
    assert focal_y.shape == depth.shape[:3]
    grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]

    if pp is None:
        # centre the pixel grid on the middle of the image
        grid_x = grid_x - (W - 1) / 2
        grid_y = grid_y - (H - 1) / 2
    else:
        grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
        grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]

    if n is None:
        pts3d = torch.empty((B, H, W, 3), device=depth.device)
        pts3d[..., 0] = depth * grid_x / focal_x
        pts3d[..., 1] = depth * grid_y / focal_y
        pts3d[..., 2] = depth
        return pts3d

    pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
    pts3d[..., 0, :] = depth * (grid_x / focal_x)[..., None]
    pts3d[..., 1, :] = depth * (grid_y / focal_y)[..., None]
    pts3d[..., 2, :] = depth
    return pts3d
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """Unproject a depth map to 3D points in the camera frame.

    Args:
        depthmap: (H, W) array.
        camera_intrinsics: 3x3 matrix; its skew entries [0, 1] and [1, 0]
            must be zero.
        pseudo_focal: optional (H, W) array of per-pixel focal lengths used
            for both axes instead of the focals of ``camera_intrinsics``.

    Returns:
        A float32 (H, W, 3) point map and a boolean (H, W) mask that is True
        where ``depthmap > 0``.
    """
    K = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    assert K[0, 1] == 0.0
    assert K[1, 0] == 0.0
    if pseudo_focal is not None:
        assert pseudo_focal.shape == (H, W)
        focal_u = focal_v = pseudo_focal
    else:
        focal_u, focal_v = K[0, 0], K[1, 1]
    center_u, center_v = K[0, 2], K[1, 2]

    cols, rows = np.meshgrid(np.arange(W), np.arange(H))
    x_cam = (cols - center_u) * depthmap / focal_u
    y_cam = (rows - center_v) * depthmap / focal_v
    points = np.stack((x_cam, y_cam, depthmap), axis=-1).astype(np.float32)

    return points, depthmap > 0.0
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def depthmap_to_absolute_camera_coordinates(
    depthmap, camera_intrinsics, camera_pose, **kw
):
    """Unproject a depth map and move the points to world coordinates.

    Args:
        depthmap: (H, W) array.
        camera_intrinsics: 3x3 matrix.
        camera_pose: cam2world matrix read as ``[:3, :3]`` (rotation) and
            ``[:3, 3]`` (translation), or None to skip the transformation.
        **kw: accepted and ignored.

    Returns:
        ``(X_world, X_cam, valid_mask)``: the (H, W, 3) point map in world
        coordinates, the (H, W, 3) point map in camera coordinates, and the
        mask of valid pixels. ``X_world`` is ``X_cam`` itself when
        ``camera_pose`` is None.
    """
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    if camera_pose is None:
        return X_cam, X_cam, valid_mask

    rotation = camera_pose[:3, :3]
    translation = camera_pose[:3, 3]
    rotated = np.einsum("ik, vuk -> vui", rotation, X_cam)
    X_world = rotated + translation[None, None, :]

    return X_world, X_cam, valid_mask
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def colmap_to_opencv_intrinsics(K):
    """Convert camera intrinsics from the Colmap to the OpenCV convention.

    Coordinates of the center of the top-left pixel are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV

    Args:
        K: intrinsics matrix; the principal point is read from ``K[0, 2]``
            and ``K[1, 2]``. The input is not modified.

    Returns:
        A copy of ``K`` with the principal point shifted by -0.5. Non-floating
        inputs are promoted to float64 so the half-pixel shift is not
        truncated away.
    """
    K = K.copy()
    if not np.issubdtype(K.dtype, np.inexact):
        # an integer array would silently round the 0.5 shift on assignment
        K = K.astype(np.float64)
    K[0, 2] -= 0.5
    K[1, 2] -= 0.5
    return K
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def opencv_to_colmap_intrinsics(K):
    """Convert camera intrinsics from the OpenCV to the Colmap convention.

    Coordinates of the center of the top-left pixel are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV

    Args:
        K: intrinsics matrix; the principal point is read from ``K[0, 2]``
            and ``K[1, 2]``. The input is not modified.

    Returns:
        A copy of ``K`` with the principal point shifted by +0.5. Non-floating
        inputs are promoted to float64 so the half-pixel shift is not
        truncated away.
    """
    K = K.copy()
    if not np.issubdtype(K.dtype, np.inexact):
        # an integer array would silently round the 0.5 shift on assignment
        K = K.astype(np.float64)
    K[0, 2] += 0.5
    K[1, 2] += 0.5
    return K
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def normalize_pointcloud(
    pts1, pts2, norm_mode="avg_dis", valid1=None, valid2=None, ret_factor=False
):
    """Rescale one or two pointmaps by a common, per-batch norm factor.

    Args:
        pts1: tensor with at least 3 dimensions and a last dimension of 3.
        pts2: second pointmap with the same constraints, or None.
        norm_mode: string of the form ``"<norm>_<dis>"``.
            ``<norm>`` is ``"avg"``, ``"median"`` or ``"sqrt"``.
            ``<dis>`` is ``"dis"``, ``"log1p"`` or ``"warp-log1p"`` and is
            only used when ``<norm>`` is ``"avg"``.
        valid1, valid2: optional validity masks forwarded to
            ``invalid_to_zeros`` / ``invalid_to_nans``.
        ret_factor: also return the norm factor.

    Returns:
        ``pts1 / factor`` alone, or a tuple containing, in order, the
        normalized ``pts1``, the normalized ``pts2`` (if given) and the
        factor (if ``ret_factor``).

    Raises:
        ValueError: for an unknown norm or distance mode.
    """
    assert pts1.ndim >= 3 and pts1.shape[-1] == 3
    assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split("_")
    has_pts2 = pts2 is not None

    if norm_mode == "avg":
        # invalid points are zeroed, so they contribute nothing to the sums
        nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
        if has_pts2:
            nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3)
            all_pts = torch.cat((nan_pts1, nan_pts2), dim=1)
        else:
            nnz2 = 0
            all_pts = nan_pts1

        all_dis = all_pts.norm(dim=-1)
        if dis_mode == "dis":
            pass  # do nothing
        elif dis_mode == "log1p":
            all_dis = torch.log1p(all_dis)
        elif dis_mode == "warp-log1p":
            # shrink every point so that its norm becomes log1p(norm)
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            H1, W1 = pts1.shape[1:-1]
            pts1 = pts1 * warp_factor[:, : W1 * H1].view(-1, H1, W1, 1)
            if has_pts2:
                H2, W2 = pts2.shape[1:-1]
                pts2 = pts2 * warp_factor[:, W1 * H1 :].view(-1, H2, W2, 1)
            all_dis = log_dis  # this is their true distance afterwards
        else:
            raise ValueError(f"bad {dis_mode=}")

        norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
    else:
        # invalid points become NaN and are skipped by the nan-aware reductions
        nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
        if has_pts2:
            nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3)
            all_pts = torch.cat((nan_pts1, nan_pts2), dim=1)
        else:
            all_pts = nan_pts1

        all_dis = all_pts.norm(dim=-1)

        if norm_mode == "median":
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == "sqrt":
            norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2
        else:
            raise ValueError(f"bad {norm_mode=}")

    norm_factor = norm_factor.clip(min=1e-8)
    while norm_factor.ndim < pts1.ndim:
        norm_factor.unsqueeze_(-1)

    res = pts1 / norm_factor
    if has_pts2:
        res = (res, pts2 / norm_factor)
        if ret_factor:
            res = res + (norm_factor,)
    elif ret_factor:
        # single pointmap: res is a tensor, so build the tuple explicitly
        res = (res, norm_factor)
    return res
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def normalize_pointcloud_group(
    pts_list,
    norm_mode="avg_dis",
    valid_list=None,
    conf_list=None,
    ret_factor=False,
    ret_factor_only=False,
):
    """Rescale a list of pointmaps by a common, per-batch norm factor.

    Args:
        pts_list: tensors with at least 3 dimensions and a last dimension
            of 3.
        norm_mode: string of the form ``"<norm>_<dis>"``.
            ``<norm>`` is ``"avg"``, ``"median"`` or ``"sqrt"``.
            ``<dis>`` is ``"dis"``, ``"log1p"`` or ``"warp-log1p"`` and is
            only used when ``<norm>`` is ``"avg"``.
        valid_list: one validity mask per pointmap, forwarded to
            ``invalid_to_zeros`` / ``invalid_to_nans``. None means no mask
            for any pointmap.
        conf_list: optional per-point confidences used as weights of the
            average in ``"avg"`` mode.
        ret_factor: return ``(normalized_list, factor)``.
        ret_factor_only: return only the factor (takes precedence).

    Returns:
        The list of normalized pointmaps, optionally with the factor, or the
        factor alone.

    Raises:
        ValueError: for an unknown norm or distance mode.
    """
    for pts in pts_list:
        assert pts.ndim >= 3 and pts.shape[-1] == 3

    if valid_list is None:
        # NOTE(review): assumes the helpers accept a None mask, as
        # normalize_pointcloud passes its None defaults to them -- confirm.
        valid_list = [None] * len(pts_list)

    norm_mode, dis_mode = norm_mode.split("_")

    if norm_mode == "avg":
        # invalid points are zeroed, so they contribute nothing to the sums
        nan_pts_list, nnz_list = zip(
            *[
                invalid_to_zeros(pts1, valid1, ndim=3)
                for pts1, valid1 in zip(pts_list, valid_list)
            ]
        )
        all_pts = torch.cat(nan_pts_list, dim=1)
        if conf_list is not None:
            nan_conf_list = [
                invalid_to_zeros(conf1[..., None], valid1, ndim=3)[0]
                for conf1, valid1 in zip(conf_list, valid_list)
            ]
            all_conf = torch.cat(nan_conf_list, dim=1)[..., 0]
        else:
            all_conf = torch.ones_like(all_pts[..., 0])

        all_dis = all_pts.norm(dim=-1)
        if dis_mode == "dis":
            pass  # do nothing
        elif dis_mode == "log1p":
            all_dis = torch.log1p(all_dis)
        elif dis_mode == "warp-log1p":
            # shrink every point so that its norm becomes log1p(norm)
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            warped_list = []
            start = 0
            for pts in pts_list:
                H, W = pts.shape[1:-1]
                # each pointmap owns H * W consecutive columns of warp_factor
                stop = start + H * W
                warped_list.append(
                    pts * warp_factor[:, start:stop].view(-1, H, W, 1)
                )
                start = stop
            pts_list = warped_list
            all_dis = log_dis  # this is their true distance afterwards
        else:
            raise ValueError(f"bad {dis_mode=}")

        norm_factor = (all_conf * all_dis).sum(dim=1) / (all_conf.sum(dim=1) + 1e-8)
    else:
        # invalid points become NaN and are skipped by the nan-aware reductions
        nan_pts_list = [
            invalid_to_nans(pts1, valid1, ndim=3)
            for pts1, valid1 in zip(pts_list, valid_list)
        ]
        all_pts = torch.cat(nan_pts_list, dim=1)

        all_dis = all_pts.norm(dim=-1)

        if norm_mode == "median":
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == "sqrt":
            norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2
        else:
            raise ValueError(f"bad {norm_mode=}")

    norm_factor = norm_factor.clip(min=1e-8)
    while norm_factor.ndim < pts_list[0].ndim:
        norm_factor.unsqueeze_(-1)

    if ret_factor_only:
        return norm_factor

    res = [pts / norm_factor for pts in pts_list]
    if ret_factor:
        return res, norm_factor
    return res
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
@torch.no_grad()
def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
    """Compute a per-batch quantile of the valid depths of one or two maps.

    Invalid entries are replaced by NaN through ``invalid_to_nans`` and are
    ignored by the nan-aware reductions. ``z2`` may be None.

    Returns:
        Tensor of shape (B,): ``torch.nanmedian`` when ``quantile == 0.5``,
        otherwise ``torch.nanquantile`` at ``quantile``.
    """
    depths = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
    if z2 is not None:
        second = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1)
        depths = torch.cat((depths, second), dim=-1)

    if quantile != 0.5:
        return torch.nanquantile(depths, quantile, dim=-1)  # (B,)
    return torch.nanmedian(depths, dim=-1).values  # (B,)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
@torch.no_grad()
def get_group_pointcloud_depth(zs, valid_masks, quantile=0.5):
    """Compute a per-batch quantile of the valid depths of several maps.

    Each map in ``zs`` is paired with the mask at the same position in
    ``valid_masks``; invalid entries become NaN through ``invalid_to_nans``
    and are ignored by the nan-aware reductions.

    Returns:
        Tensor of shape (B,): ``torch.nanmedian`` when ``quantile == 0.5``,
        otherwise ``torch.nanquantile`` at ``quantile``.
    """
    flattened = []
    for depth, mask in zip(zs, valid_masks):
        flattened.append(invalid_to_nans(depth, mask).reshape(len(depth), -1))
    all_depths = torch.cat(flattened, dim=-1)

    if quantile != 0.5:
        return torch.nanquantile(all_depths, quantile, dim=-1)  # (B,)
    return torch.nanmedian(all_depths, dim=-1).values  # (B,)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
@torch.no_grad()
def get_joint_pointcloud_center_scale(
    pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True
):
    """Compute the median center and median scale of one or two pointmaps.

    Args:
        pts1, pts2: pointmaps whose last dimension is 3; ``pts2`` may be None.
        valid_mask1, valid_mask2: masks forwarded to ``invalid_to_nans``.
        z_only: zero the x and y components of the center.
        center: measure the scale relative to the center instead of the origin.

    Returns:
        ``(center, scale)`` with shapes (B, 1, 1, 3) and (B, 1, 1, 1).
    """
    cloud = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
    if pts2 is not None:
        other = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3)
        cloud = torch.cat((cloud, other), dim=1)

    median_point = torch.nanmedian(cloud, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        median_point[..., :2] = 0  # do not center X and Y

    offsets = cloud - median_point if center else cloud
    scale = torch.nanmedian(offsets.norm(dim=-1), dim=1).values
    return median_point[:, None, :, :], scale[:, None, None, None]
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
@torch.no_grad()
def get_group_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
    """Compute the median center and median scale of a group of pointmaps.

    Args:
        pts: sequence of pointmaps whose last dimension is 3.
        valid_masks: one mask per pointmap, forwarded to ``invalid_to_nans``.
            None means no mask for any pointmap.
        z_only: zero the x and y components of the center.
        center: measure the scale relative to the center instead of the origin.

    Returns:
        ``(center, scale)`` with shapes (B, 1, 1, 3) and (B, 1, 1, 1).
    """
    if valid_masks is None:
        # NOTE(review): assumes invalid_to_nans accepts a None mask, as
        # get_joint_pointcloud_center_scale passes its None defaults -- confirm.
        valid_masks = [None] * len(pts)

    _pts = [
        invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
        for pts1, valid_mask1 in zip(pts, valid_masks)
    ]
    _pts = torch.cat(_pts, dim=1)

    _center = torch.nanmedian(_pts, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        _center[..., :2] = 0  # do not center X and Y

    _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
    scale = torch.nanmedian(_norm, dim=1).values
    return _center[:, None, :, :], scale[:, None, None, None]
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def find_reciprocal_matches(P1, P2):
    """Find mutual nearest neighbours between two point sets.

    Returns 3 values:
    1 - reciprocal_in_P2: a boolean array of size P2.shape[0]; True marks a
        point of P2 whose nearest neighbour in P1 points back to it
    2 - nn2_in_P1: an int array of size P2.shape[0] with the index of the
        closest point in P1 for every point of P2
    3 - reciprocal_in_P2.sum(): the number of matches
    """
    # nearest neighbour of every P1 point in P2, and of every P2 point in P1
    nn1_in_P2 = KDTree(P2).query(P1, workers=8)[1]
    nn2_in_P1 = KDTree(P1).query(P2, workers=8)[1]

    indices_P1 = np.arange(len(nn1_in_P2))
    indices_P2 = np.arange(len(nn2_in_P1))
    # a point is reciprocal when the round trip comes back to its own index
    reciprocal_in_P1 = nn2_in_P1[nn1_in_P2] == indices_P1
    reciprocal_in_P2 = nn1_in_P2[nn2_in_P1] == indices_P2
    assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
    return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def get_med_dist_between_poses(poses):
    """Return the median pairwise distance between the ``[:3, 3]`` columns of ``poses``."""
    from scipy.spatial.distance import pdist

    translations = [to_numpy(pose[:3, 3]) for pose in poses]
    pairwise_distances = pdist(translations)
    return np.median(pairwise_distances)
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def weighted_procrustes(A, B, w, use_weights=True, eps=1e-16, return_T=False):
    """Estimate the rigid transform (R, t) such that ``R @ a + t`` matches ``b``.

    Args:
        A: torch tensor B x N x 3 (source points).
        B: torch tensor B x N x 3 (target points).
        w: torch tensor B x N of per-point weights; only read when
            ``use_weights`` is True.
        use_weights: weight the means and the covariance by ``w`` normalized
            by the sum of its absolute values.
        eps: added to the weight sum to avoid a division by zero.
        return_T: return a single B x 4 x 4 homogeneous matrix instead.

    Returns:
        ``(R, t)`` with R of shape B x 3 x 3 and ``t`` squeezed of all its
        size-1 dimensions (so B x 3, or just 3 when the batch size is 1),
        or the B x 4 x 4 matrix when ``return_T`` is True. Outputs use the
        dtype and device of ``A``.
    """
    assert len(A) == len(B)
    if use_weights:
        W1 = torch.abs(w).sum(1, keepdim=True)
        w_norm = (w / (W1 + eps)).unsqueeze(-1)
        a_mean = (w_norm * A).sum(dim=1, keepdim=True)
        b_mean = (w_norm * B).sum(dim=1, keepdim=True)

        A_c = A - a_mean
        B_c = B - b_mean

        # weighted cross-covariance
        H = torch.einsum("bni,bnj->bij", A_c, w_norm * B_c)
    else:
        a_mean = A.mean(axis=1, keepdim=True)
        b_mean = B.mean(axis=1, keepdim=True)

        A_c = A - a_mean
        B_c = B - b_mean

        H = torch.einsum("bij,bik->bjk", A_c, B_c)

    batch_size = A.shape[0]
    # torch.linalg.svd returns V transposed: H = U @ diag(S) @ Vh
    U, _, Vh = torch.linalg.svd(H)  # U: B x 3 x 3, Vh: B x 3 x 3
    V = Vh.transpose(1, 2)
    # Build the correction matrix in A's dtype so the products below do not
    # fail on non-float32 inputs.
    Z = torch.eye(3, dtype=A.dtype, device=A.device).unsqueeze(0).repeat(batch_size, 1, 1)
    # flip the last axis when needed so that det(R) = +1 (no reflection)
    Z[:, -1, -1] = torch.sign(torch.linalg.det(U @ Vh))
    R = V @ Z @ U.transpose(1, 2)  # B x 3 x 3
    t = b_mean - torch.einsum("bij,bjk->bik", R, a_mean.transpose(-2, -1)).transpose(
        -2, -1
    )  # B x 1 x 3
    if return_T:
        T = torch.eye(4, dtype=A.dtype, device=A.device).unsqueeze(0).repeat(batch_size, 1, 1)
        T[:, :3, :3] = R
        T[:, :3, 3] = t.squeeze(1)
        return T
    return R, t.squeeze()
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/image.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# modified from DUSt3R
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
import PIL.Image
|
| 11 |
+
from PIL.ImageOps import exif_transpose
|
| 12 |
+
import torchvision.transforms as tvf
|
| 13 |
+
|
| 14 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 15 |
+
import cv2 # noqa
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from pillow_heif import register_heif_opener # noqa
|
| 19 |
+
|
| 20 |
+
register_heif_opener()
|
| 21 |
+
heif_support_enabled = True
|
| 22 |
+
except ImportError:
|
| 23 |
+
heif_support_enabled = False
|
| 24 |
+
|
| 25 |
+
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def img_to_arr(img):
    """Load ``img`` with ``imread_cv2`` when it is a path string; return anything else unchanged."""
    return imread_cv2(img) if isinstance(img, str) else img
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def imread_cv2(path, options=cv2.IMREAD_COLOR):
    """Open an image or a depthmap with opencv-python.

    Paths ending in ".exr" or "EXR" are always read with
    ``cv2.IMREAD_ANYDEPTH``, whatever ``options`` was given. Images that come
    back with 3 dimensions are converted from BGR to RGB.

    Raises:
        IOError: if OpenCV could not load the file.
    """
    is_exr = path.endswith((".exr", "EXR"))
    options = cv2.IMREAD_ANYDEPTH if is_exr else options
    img = cv2.imread(path, options)
    if img is None:
        raise IOError(f"Could not load image={path} with {options=}")
    has_channels = img.ndim == 3
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if has_channels else img
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def imread_pil(path):
    """Open an image with PIL, apply its EXIF orientation and return it as an RGB numpy array."""
    oriented = exif_transpose(PIL.Image.open(path))
    return np.array(oriented.convert("RGB"))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def rgb(ftensor, true_shape=None):
    """Convert an image array/tensor to a channel-last float image in [0, 1].

    Args:
        ftensor: numpy array, torch tensor, or a list of them (converted
            element by element). A 3-D input whose first dimension is 3, or a
            4-D input whose second dimension is 3, is moved to channel-last.
        true_shape: optional ``(H, W)``; the image is cropped to its first H
            rows and W columns. For 4-D input the crop is applied to every
            image of the batch.

    Returns:
        The image(s) clipped to [0, 1]: uint8 input is divided by 255, any
        other dtype is mapped with ``x * 0.5 + 0.5``.
    """
    if isinstance(ftensor, list):
        return [rgb(x, true_shape=true_shape) for x in ftensor]
    if isinstance(ftensor, torch.Tensor):
        ftensor = ftensor.detach().cpu().numpy()
    if ftensor.ndim == 3 and ftensor.shape[0] == 3:
        ftensor = ftensor.transpose(1, 2, 0)  # 3,H,W -> H,W,3
    elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
        ftensor = ftensor.transpose(0, 2, 3, 1)  # B,3,H,W -> B,H,W,3
    if true_shape is not None:
        H, W = true_shape
        if ftensor.ndim == 4:
            # keep the batch axis intact; crop the spatial axes only
            ftensor = ftensor[:, :H, :W]
        else:
            ftensor = ftensor[:H, :W]
    if ftensor.dtype == np.uint8:
        img = np.float32(ftensor) / 255
    else:
        img = (ftensor * 0.5) + 0.5
    return img.clip(min=0, max=1)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _resize_pil_image(img, long_edge_size):
    """Resize a PIL image so that its longest edge becomes ``long_edge_size``.

    LANCZOS is used when shrinking (longest edge > target) and BICUBIC
    otherwise; the aspect ratio is kept up to rounding.
    """
    S = max(img.size)
    shrinking = S > long_edge_size
    interp = PIL.Image.LANCZOS if shrinking else PIL.Image.BICUBIC
    new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
    return img.resize(new_size, interp)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def load_images(folder_or_list, size, square_ok=False, verbose=True):
    """Open and convert all images in a list or folder to the DUSt3R input format.

    Args:
        folder_or_list: a folder path (its sorted content is used) or a list
            of image paths.
        size: target size. 224 means "resize the short side to 224, then
            center-crop a square"; any other value resizes the long side to
            ``size`` and center-crops to even multiples of 8 per half-size.
        square_ok: when False, square images are cropped to a 4:3 ratio
            (only for ``size != 224``).
        verbose: print progress information.

    Returns:
        A list of dicts with keys ``img`` (normalized tensor with a leading
        batch axis), ``true_shape`` (int32 array ``[[H, W]]``), ``idx`` and
        ``instance``.

    Raises:
        ValueError: if ``folder_or_list`` is neither a str nor a list.
    """
    if isinstance(folder_or_list, str):
        if verbose:
            print(f">> Loading images from {folder_or_list}")
        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
    elif isinstance(folder_or_list, list):
        if verbose:
            print(f">> Loading a list of {len(folder_or_list)} images")
        root, folder_content = "", folder_or_list
    else:
        raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")

    extensions = (".jpg", ".jpeg", ".png", ".bmp")
    if heif_support_enabled:
        extensions = extensions + (".heic", ".heif")

    square_mode = size == 224
    imgs = []
    for path in folder_content:
        if not path.lower().endswith(extensions):
            continue
        img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB")
        W1, H1 = img.size
        if square_mode:
            # resize the short side to `size` (the square crop comes next)
            long_edge = round(size * max(W1 / H1, H1 / W1))
        else:
            # resize the long side to `size`
            long_edge = size
        img = _resize_pil_image(img, long_edge)

        W, H = img.size
        cx, cy = W // 2, H // 2
        if square_mode:
            halfw = halfh = min(cx, cy)
        else:
            halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
            if not (square_ok) and W == H:
                halfh = 3 * halfw / 4
        img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))

        W2, H2 = img.size
        if verbose:
            print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
        index = len(imgs)
        imgs.append(
            dict(
                img=ImgNorm(img)[None],
                true_shape=np.int32([img.size[::-1]]),
                idx=index,
                instance=str(index),
            )
        )

    assert imgs, "no images foud at " + root
    if verbose:
        print(f" (Found {len(imgs)} images)")
    return imgs
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def load_images_for_eval(
    folder_or_list, size, square_ok=False, verbose=True, crop=True
):
    """Open and convert all images in a list or folder to the DUSt3R input format.

    Args:
        folder_or_list: a folder path (its sorted content is used) or a list
            of image paths.
        size: target size. 224 means "resize the short side to 224, then
            produce a square"; any other value resizes the long side to
            ``size`` and uses half-sizes that are multiples of 7.
        square_ok: when False, square images get a 4:3 ratio
            (only for ``size != 224``).
        verbose: print progress information.
        crop: if True the final size is reached by center-cropping,
            otherwise by resizing the whole image.

    Returns:
        A list of dicts with keys ``img`` (normalized tensor with a leading
        batch axis), ``true_shape`` (int32 array ``[[H, W]]``), ``idx`` and
        ``instance``.

    Raises:
        ValueError: if ``folder_or_list`` is neither a str nor a list.
    """
    if isinstance(folder_or_list, str):
        if verbose:
            print(f">> Loading images from {folder_or_list}")
        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))

    elif isinstance(folder_or_list, list):
        if verbose:
            print(f">> Loading a list of {len(folder_or_list)} images")
        root, folder_content = "", folder_or_list

    else:
        raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")

    supported_images_extensions = [".jpg", ".jpeg", ".png"]
    if heif_support_enabled:
        supported_images_extensions += [".heic", ".heif"]
    supported_images_extensions = tuple(supported_images_extensions)

    imgs = []
    for path in folder_content:
        if not path.lower().endswith(supported_images_extensions):
            continue
        img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB")
        W1, H1 = img.size
        if size == 224:
            # resize short side to 224 (then crop)
            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
        else:
            # resize long side to `size`
            img = _resize_pil_image(img, size)
        W, H = img.size
        cx, cy = W // 2, H // 2
        if size == 224:
            half = min(cx, cy)
            if crop:
                img = img.crop((cx - half, cy - half, cx + half, cy + half))
            else:  # resize
                img = img.resize((2 * half, 2 * half), PIL.Image.LANCZOS)
        else:
            halfw, halfh = ((2 * cx) // 14) * 7, ((2 * cy) // 14) * 7
            if not (square_ok) and W == H:
                halfh = 3 * halfw / 4  # float; may be fractional
            if crop:
                img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
            else:  # resize
                # the target size must be made of integers; halfh can be a float here
                target_size = (int(2 * halfw), int(round(2 * halfh)))
                img = img.resize(target_size, PIL.Image.LANCZOS)
        W2, H2 = img.size
        if verbose:
            print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
        imgs.append(
            dict(
                img=ImgNorm(img)[None],
                true_shape=np.int32([img.size[::-1]]),
                idx=len(imgs),
                instance=str(len(imgs)),
            )
        )

    assert imgs, "no images foud at " + root
    if verbose:
        print(f" (Found {len(imgs)} images)")
    return imgs
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def load_images_512(folder_or_list, size, square_ok=False, verbose=True):
    """Open and convert all images in a list or folder to the input format used by DUSt3R.

    Unlike the generic loader, every image is first force-resized to 512x384
    (ignoring its aspect ratio) and the final crop is aligned to multiples of 16.

    Args:
        folder_or_list: Either a directory path (its sorted content is scanned)
            or a list of image paths.
        size: Target size. ``224`` selects the "resize then square crop" path;
            any other value resizes via ``_resize_pil_image(img, size)`` and
            crops around the centre.
        square_ok: When False, a square image (after resizing) is cropped to a
            4:3 aspect ratio instead of being kept square.
        verbose: Print progress information.

    Returns:
        A list of dicts with keys ``img`` (normalised image with a leading batch
        axis), ``true_shape`` (``np.int32`` array ``[[H, W]]``), ``idx`` and
        ``instance``.

    Raises:
        ValueError: If ``folder_or_list`` is neither a ``str`` nor a ``list``.
        AssertionError: If no supported image was found.
    """
    if isinstance(folder_or_list, str):
        if verbose:
            print(f">> Loading images from {folder_or_list}")
        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))

    elif isinstance(folder_or_list, list):
        if verbose:
            print(f">> Loading a list of {len(folder_or_list)} images")
        root, folder_content = "", folder_or_list

    else:
        raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")

    supported_images_extensions = [".jpg", ".jpeg", ".png", ".bmp"]
    if heif_support_enabled:
        supported_images_extensions += [".heic", ".heif"]
    supported_images_extensions = tuple(supported_images_extensions)

    imgs = []
    for path in folder_content:
        if not path.lower().endswith(supported_images_extensions):
            continue
        # Close the underlying file as soon as the pixels have been converted;
        # ``convert`` always returns an independent, fully loaded image.
        with PIL.Image.open(os.path.join(root, path)) as raw:
            img = exif_transpose(raw).convert("RGB")
        # Forced resize: from here on the image is always 512x384.
        img = img.resize((512, 384))
        # NOTE: W1/H1 are measured after the forced resize, so the "original"
        # resolution printed below is always 512x384.
        W1, H1 = img.size
        if size == 224:
            # resize short side to 224 (then crop)
            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
        else:
            # resize long side to `size`
            img = _resize_pil_image(img, size)
        W, H = img.size
        cx, cy = W // 2, H // 2
        if size == 224:
            half = min(cx, cy)
            img = img.crop((cx - half, cy - half, cx + half, cy + half))
        else:
            # Half extents are multiples of 8, so both sides are multiples of 16.
            halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
            if not square_ok and W == H:
                # halfw is a multiple of 8, so this division is exact; keeping
                # it integral avoids passing float coordinates to ``crop``.
                halfh = 3 * halfw // 4
            img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))

        W2, H2 = img.size
        if verbose:
            print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
        imgs.append(
            dict(
                img=ImgNorm(img)[None],
                true_shape=np.int32([img.size[::-1]]),
                idx=len(imgs),
                instance=str(len(imgs)),
            )
        )

    assert imgs, "no images found at " + root
    if verbose:
        print(f" (Found {len(imgs)} images)")
    return imgs
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/misc.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# modified from DUSt3R
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def fill_default_args(kwargs, func):
    """Complete `kwargs` in place with the default values declared by `func`.

    Every parameter of `func` that has a default and is not already a key of
    `kwargs` is added with that default. Parameters without a default are
    skipped. The same (mutated) dict is returned.
    """
    import inspect  # local import, as in the original helper

    for name, param in inspect.signature(func).parameters.items():
        if param.default is not inspect.Parameter.empty:
            kwargs.setdefault(name, param.default)
    return kwargs
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def freeze_all_params(modules):
    """Turn off gradient tracking for every entry of `modules`.

    Entries exposing `named_parameters()` have each of their parameters frozen.
    An entry that raises `AttributeError` instead (for example a bare
    parameter) gets its own `requires_grad` set to False.
    """
    for entry in modules:
        try:
            for _name, weight in entry.named_parameters():
                weight.requires_grad = False
        except AttributeError:
            # Not a module: the entry is treated as the parameter itself.
            entry.requires_grad = False
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def is_symmetrized(gt1, gt2):
    """Tell whether two batched views form a symmetrized batch.

    A batch is symmetrized when consecutive pairs are swapped between the two
    views, i.e. ``x[i] == y[i + 1]`` and ``x[i + 1] == y[i]`` for every even
    ``i``, where ``x`` and ``y`` are the ``"instance"`` entries of the views.

    Args:
        gt1: First view; only ``gt1["instance"]`` is read.
        gt2: Second view; only ``gt2["instance"]`` is read.

    Returns:
        True if every consecutive pair is mirrored, False otherwise. A batch of
        size 1 is never symmetrized, and neither is a batch that cannot be
        split into complete pairs.
    """
    x = gt1["instance"]
    y = gt2["instance"]
    if len(x) == len(y) and len(x) == 1:
        return False  # special case of batchsize 1
    # An odd batch, or a second view shorter than the first, cannot be paired
    # up; answer False instead of indexing past the end.
    if len(x) % 2 or len(y) < len(x):
        return False
    # `all` stops at the first pair that is not mirrored.
    return all(x[i] == y[i + 1] and x[i + 1] == y[i] for i in range(0, len(x), 2))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def flip(tensor):
    """Swap neighbouring entries along dim 0, so that tensor[0::2] <=> tensor[1::2]."""
    odd_items = tensor[1::2]
    even_items = tensor[0::2]
    paired = torch.stack((odd_items, even_items), dim=1)
    return paired.flatten(0, 1)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def interleave(tensor1, tensor2):
    """Interleave two tensors along dim 0, in both orders.

    Returns a pair ``(res1, res2)`` where ``res1`` alternates
    ``tensor1[0], tensor2[0], tensor1[1], ...`` and ``res2`` alternates
    ``tensor2[0], tensor1[0], tensor2[1], ...``.
    """

    def _weave(first, second):
        return torch.stack((first, second), dim=1).flatten(0, 1)

    return _weave(tensor1, tensor2), _weave(tensor2, tensor1)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def transpose_to_landscape(head, activate=True):
    """Predict in the correct aspect-ratio,
    then transpose the result in landscape
    and stack everything back together.

    Args:
        head: Callable invoked as ``head(decout, (H, W), **kwargs)`` and
            returning a dict of tensors.
        activate: When True, return the orientation-aware wrapper; otherwise
            return a wrapper that only checks that all shapes are identical.

    Returns:
        A function ``wrapper(decout, true_shape, **kwargs)``. ``true_shape``
        holds one ``(height, width)`` row per batch element.
    """

    def wrapper_no(decout, true_shape, **kwargs):
        assert true_shape[0:1].allclose(true_shape), "true_shape must be all identical"
        H, W = true_shape[0].cpu().tolist()
        return head(decout, (H, W), **kwargs)

    def wrapper_yes(decout, true_shape, **kwargs):
        B = len(true_shape)
        # Landscape size shared by the batch. This assumes every element uses
        # the same two side lengths, possibly swapped.
        H, W = int(true_shape.min()), int(true_shape.max())

        height, width = true_shape.T
        is_landscape = width >= height
        is_portrait = ~is_landscape

        if is_landscape.all():
            return head(decout, (H, W), **kwargs)
        if is_portrait.all():
            return transposed(head(decout, (W, H), **kwargs))

        # The batch is a mix of both portrait and landscape elements.
        def selout(ar):
            return [d[ar] for d in decout]

        # Without a "pos" entry both sub-batches reuse the caller's kwargs.
        # (Previously these names were unbound in that case.)
        kwargs_landscape = kwargs_portrait = kwargs
        if "pos" in kwargs:
            kwargs_landscape = {**kwargs, "pos": kwargs["pos"][is_landscape]}
            kwargs_portrait = {**kwargs, "pos": kwargs["pos"][is_portrait]}
        l_result = head(selout(is_landscape), (H, W), **kwargs_landscape)
        p_result = transposed(head(selout(is_portrait), (W, H), **kwargs_portrait))

        # Scatter both partial results back to their original batch positions.
        result = {}
        for k in l_result | p_result:
            # Uninitialised buffer; every row is written by one of the two
            # assignments below.
            x = l_result[k].new(B, *l_result[k].shape[1:])
            x[is_landscape] = l_result[k]
            x[is_portrait] = p_result[k]
            result[k] = x

        return result

    return wrapper_yes if activate else wrapper_no


def transposed(dic):
    """Return a copy of `dic` where axes 1 and 2 are swapped for every value with more than two dimensions."""
    return {k: v.swapaxes(1, 2) if v.ndim > 2 else v for k, v in dic.items()}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def invalid_to_nans(arr, valid_mask, ndim=999):
    """Replace invalid entries of `arr` by NaN and optionally collapse dimensions.

    When `valid_mask` is given, a clone of `arr` is modified (the input is left
    untouched). If the result has more than `ndim` dimensions, the surplus ones
    are flattened into the second-to-last axis.
    """
    if valid_mask is not None:
        arr = arr.clone()
        arr[~valid_mask] = float("nan")
    surplus = arr.ndim - ndim
    if surplus > 0:
        arr = arr.flatten(-2 - surplus, -2)
    return arr
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def invalid_to_zeros(arr, valid_mask, ndim=999):
    """Zero out invalid entries of `arr` and count the valid ones.

    Returns a pair ``(arr, nnz)``. With a mask, `arr` is a modified clone and
    `nnz` is the per-item count of valid mask entries. Without a mask, `arr` is
    returned as given and `nnz` is the number of elements per item (0 for an
    empty batch). If the result has more than `ndim` dimensions, the surplus
    ones are flattened into the second-to-last axis.
    """
    if valid_mask is None:
        # number of points per image
        nnz = arr.numel() // len(arr) if len(arr) else 0
    else:
        arr = arr.clone()
        arr[~valid_mask] = 0
        nnz = valid_mask.view(len(valid_mask), -1).sum(1)
    surplus = arr.ndim - ndim
    if surplus > 0:
        arr = arr.flatten(-2 - surplus, -2)
    return arr, nnz
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/parallel.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# modified from DUSt3R
|
| 6 |
+
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from multiprocessing.dummy import Pool as ThreadPool
|
| 9 |
+
from multiprocessing import cpu_count
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=ThreadPool,
    **tqdm_kw
):
    """Map `function` over `args` with a worker pool and a tqdm progress bar.

    The result is equivalent to
        [function(arg)    # default
         function(*arg)   # if star_args is True
         function(**arg)  # if kw_args is True
         for arg in args]

    The first `front_num` items are evaluated directly in the caller, outside
    the pool, which helps debugging. A non-positive `workers` has the CPU count
    added to it until it is positive; with exactly one worker everything runs
    in the caller. Remaining keyword arguments are forwarded to `tqdm`.
    """
    while workers <= 0:
        workers += cpu_count()
    if workers == 1:
        # Single worker: never reach the pool.
        front_num = float("inf")

    # Progress-bar total, only known when `args` has a length.
    try:
        n_args_parallel = len(args) - front_num
    except TypeError:
        n_args_parallel = None
    pending = iter(args)

    def _invoke(item):
        if star_args:
            return function(*item)
        if kw_args:
            return function(**item)
        return function(item)

    head_results = []
    while len(head_results) < front_num:
        try:
            item = next(pending)
        except StopIteration:
            return head_results  # end of the iterable
        head_results.append(_invoke(item))

    with Pool(workers) as pool:
        # The module-level trampolines keep the call picklable for process pools.
        if star_args:
            futures = pool.imap(starcall, [(function, item) for item in pending])
        elif kw_args:
            futures = pool.imap(starstarcall, [(function, item) for item in pending])
        else:
            futures = pool.imap(function, pending)

        pooled_results = list(tqdm(futures, total=n_args_parallel, **tqdm_kw))
    return head_results + pooled_results
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def parallel_processes(*args, **kwargs):
    """Same as parallel_threads, but backed by a pool of processes.

    Any `Pool` passed by the caller is overridden with `multiprocessing.Pool`.
    """
    from multiprocessing import Pool as ProcessPool

    return parallel_threads(*args, **{**kwargs, "Pool": ProcessPool})
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def starcall(args):
    """Call ``function(*positional)`` for a ``(function, positional)`` pair (pool-friendly wrapper)."""
    function, positional = args
    return function(*positional)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def starstarcall(args):
    """Call ``function(**named)`` for a ``(function, named)`` pair (pool-friendly wrapper)."""
    function, named = args
    return function(**named)
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/path_to_croco.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# modified from DUSt3R
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
import os.path as path
|
| 9 |
+
import importlib
|
| 10 |
+
|
| 11 |
+
# Directory containing this file.
HERE_PATH = path.normpath(path.dirname(__file__))
# The croco checkout is expected two levels above this file, in `croco/`.
CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, "../../croco"))
CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, "models")
# IMPORTANT:
# Do NOT add `.../src/croco` directly to sys.path, otherwise subfolders like
# `croco/datasets` become a top-level module named `datasets`, which will shadow
# HuggingFace `datasets` and break `accelerate` (and others).
# Instead, add `.../src` so we import as `croco.*`.
# NOTE(review): SRC_PATH resolves three levels above this file, whereas the
# parent directory of CROCO_REPO_PATH (the directory that actually contains
# `croco/`) is only two levels above. Confirm that this is the intended
# directory; if it is not, the `import croco.models` below only works when the
# right directory is already on sys.path, and a failure is silently ignored.
SRC_PATH = path.normpath(path.join(HERE_PATH, "../../.."))

if path.isdir(CROCO_MODELS_PATH):

    # Prefer adding the `src` directory; this enables `import croco...` without
    # polluting top-level module names.
    if SRC_PATH not in sys.path:
        sys.path.insert(0, SRC_PATH)

    # In case an old run already inserted CROCO_REPO_PATH, remove it to avoid
    # shadowing top-level modules (e.g., `datasets`).
    while CROCO_REPO_PATH in sys.path:
        sys.path.remove(CROCO_REPO_PATH)

    # Backward-compat: DUSt3R code expects `models.*` to exist as a top-level package
    # (historically achieved by adding CROCO_REPO_PATH to sys.path). We keep that
    # import path working by aliasing `croco.models` to `models` without exposing
    # other top-level names like `datasets`.
    try:
        _croco_models = importlib.import_module("croco.models")
        # `setdefault` leaves an already-registered `models` module untouched.
        sys.modules.setdefault("models", _croco_models)
    except Exception:
        # If croco isn't importable yet, downstream import will raise a clearer error.
        pass
else:
    # The croco sources are missing altogether: fail at import time.
    raise ImportError(
        f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n "
        "Did you forget to run 'git submodule update --init --recursive' ?"
    )
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/render.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from gsplat import rasterization
|
| 3 |
+
from dust3r.utils.geometry import inv, geotrf
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def render(
    intrinsics: torch.Tensor,
    pts3d: torch.Tensor,
    rgbs: torch.Tensor | None = None,
    scale: float = 0.002,
    opacity: float = 0.95,
):
    """Rasterize per-pixel 3D points as small Gaussians and return the depth maps.

    Each batch element is rendered on its own with an identity view matrix and
    its own intrinsics, at the spatial size given by ``pts3d.shape[1:3]``.
    Every Gaussian uses the same `scale` and `opacity`, and a random unit
    quaternion.

    Returns:
        ``(rendered_rgbs, rendered_depths, accs)``. Only ``rendered_depths``
        (channel 3 of each "RGB+D" render, concatenated along dim 0) is filled;
        ``rendered_rgbs`` and ``accs`` are returned as empty lists.
    """
    device = pts3d.device
    batch_size = len(intrinsics)
    height, width = pts3d.shape[1:3]

    # One flat list of points per batch element.
    pts3d = pts3d.reshape(batch_size, -1, 3)
    num_pts = pts3d.shape[1]

    raw_quats = torch.randn((num_pts, 4), device=device)
    quats = raw_quats / raw_quats.norm(dim=-1, keepdim=True)
    scales = scale * torch.ones((num_pts, 3), device=device)
    opacities = opacity * torch.ones((num_pts), device=device)

    if rgbs is None:
        # No colours given: render everything white.
        rgbs = torch.ones_like(pts3d[:, :, :3])
    else:
        assert rgbs.shape[1] == 3
        rgbs = rgbs.reshape(batch_size, 3, -1).transpose(1, 2)

    identity_view = torch.eye(4, device=device)[None]

    rendered_rgbs = []
    accs = []
    depth_maps = []
    for i in range(batch_size):
        rgbd, acc, _ = rasterization(
            pts3d[i],
            quats,
            scales,
            opacities,
            rgbs[i],
            identity_view,
            intrinsics[[i]],
            width=width,
            height=height,
            packed=False,
            render_mode="RGB+D",
        )
        # Keep only the depth channel of the RGB+D output.
        depth_maps.append(rgbd[..., 3])

    rendered_depths = torch.cat(depth_maps, dim=0)

    return rendered_rgbs, rendered_depths, accs
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def get_render_results(gts, preds, self_view=False):
    """Render predicted and ground-truth depth maps for every view.

    For each ``(gt, pred)`` pair, the camera pose and intrinsics come from the
    view itself when `self_view` is True and from the first ground-truth view
    otherwise. The prediction always uses ``pred["pts3d_in_other_view"]``; the
    ground-truth points are first moved into the chosen camera frame.

    Returns:
        ``(depths, gt_depths)``: two lists with one rendered depth per view.
    """
    device = preds[0]["pts3d_in_other_view"].device
    depths, gt_depths = [], []
    with torch.no_grad():
        for gt, pred in zip(gts, preds):
            reference = gt if self_view else gts[0]
            camera = inv(reference["camera_pose"]).to(device)
            intrinsics = reference["camera_intrinsics"].to(device)
            pred_pts3d = pred["pts3d_in_other_view"]
            gt_img = gt["img"].to(device)
            gt_pts3d = gt["pts3d"].to(device)

            _, depth, _ = render(intrinsics, pred_pts3d, gt_img)
            _, gt_depth, _ = render(intrinsics, geotrf(camera, gt_pts3d), gt_img)
            depths.append(depth)
            gt_depths.append(gt_depth)
    return depths, gt_depths
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
__version__ = "0.0.1"
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/backbones.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Weights(Enum):
    """Names of the pretrained weight sets accepted by the model factories of this module."""

    # Weights pretrained on the LVD-142M dataset (the only set listed here).
    LVD142M = "LVD142M"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """Build a DINOv2 vision transformer and optionally load its pretrained weights.

    The constructor named `arch_name` is looked up in the `vision_transformer`
    module and called with the settings below; entries of `kwargs` override
    them. When `pretrained` is true, the checkpoint is downloaded from
    `_DINOV2_BASE_URL` and loaded strictly.

    Raises:
        AssertionError: If `weights` is a string that is not a `Weights` member name.
    """
    from ..models import vision_transformer as vits

    # `weights` is only validated here; it does not influence the download URL.
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = {
        "img_size": img_size,
        "patch_size": patch_size,
        "init_values": init_values,
        "ffn_layer": ffn_layer,
        "block_chunks": block_chunks,
        "num_register_tokens": num_register_tokens,
        "interpolate_antialias": interpolate_antialias,
        "interpolate_offset": interpolate_offset,
        **kwargs,
    }
    build = vars(vits)[arch_name]
    model = build(**vit_kwargs)

    if not pretrained:
        return model

    # The checkpoint lives in the folder of the register-free model name.
    model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
    url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
    state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    return model
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.

    Keyword-only wrapper around `_make_dinov2_model` with
    ``arch_name="vit_small"``; `pretrained`, `weights` and any extra keyword
    arguments are forwarded unchanged.
    """
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.

    Keyword-only wrapper around `_make_dinov2_model` with
    ``arch_name="vit_base"``; `pretrained`, `weights` and any extra keyword
    arguments are forwarded unchanged.
    """
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.

    Keyword-only wrapper around `_make_dinov2_model` with
    ``arch_name="vit_large"``; `pretrained`, `weights` and any extra keyword
    arguments are forwarded unchanged.
    """
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.

    Keyword-only wrapper around `_make_dinov2_model` with
    ``arch_name="vit_giant2"`` and ``ffn_layer="swiglufused"``; `pretrained`,
    `weights` and any extra keyword arguments are forwarded unchanged.
    Passing `ffn_layer` again through `kwargs` raises a `TypeError`
    (duplicate keyword argument).
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Keyword-only wrapper around `_make_dinov2_model` with
    ``arch_name="vit_small"``, 4 register tokens, antialiased interpolation
    and an interpolation offset of 0.0. Passing any of these fixed settings
    again through `kwargs` raises a `TypeError` (duplicate keyword argument).
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Keyword-only wrapper around `_make_dinov2_model` with
    ``arch_name="vit_base"``, 4 register tokens, antialiased interpolation
    and an interpolation offset of 0.0. Passing any of these fixed settings
    again through `kwargs` raises a `TypeError` (duplicate keyword argument).
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Keyword-only wrapper around `_make_dinov2_model` with
    ``arch_name="vit_large"``, 4 register tokens, antialiased interpolation
    and an interpolation offset of 0.0. Passing any of these fixed settings
    again through `kwargs` raises a `TypeError` (duplicate keyword argument).
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Keyword-only wrapper around `_make_dinov2_model` with
    ``arch_name="vit_giant2"``, ``ffn_layer="swiglufused"``, 4 register tokens,
    antialiased interpolation and an interpolation offset of 0.0. Passing any
    of these fixed settings again through `kwargs` raises a `TypeError`
    (duplicate keyword argument).
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/utils.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import itertools
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
|
| 18 |
+
compact_arch_name = arch_name.replace("_", "")[:4]
|
| 19 |
+
registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
|
| 20 |
+
return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CenterPadding(nn.Module):
    """Pad every dimension after the first two up to the next multiple of `multiple`.

    The padding is split between both sides of a dimension; when the amount is
    odd, the extra element goes to the trailing side.
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return the ``(before, after)`` padding that brings `size` to a multiple."""
        target = math.ceil(size / self.multiple) * self.multiple
        total = target - size
        before = total // 2
        return before, total - before

    @torch.inference_mode()
    def forward(self, x):
        # F.pad expects the pads of the last dimension first, so walk the
        # padded dimensions from last to first.
        pads = []
        for dim_size in reversed(x.shape[2:]):
            pads.extend(self._get_pad(dim_size))
        return F.pad(x, pads)
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .dino_head import DINOHead
|
| 7 |
+
from .mlp import Mlp
|
| 8 |
+
from .patch_embed import PatchEmbed
|
| 9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 10 |
+
from .block import NestedTensorBlock
|
| 11 |
+
from .attention import MemEffAttention
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/attention.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from torch import nn
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("dinov2")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# xFormers is considered enabled unless the XFORMERS_DISABLED environment
# variable is set (to any value).
XFORMERS_ENABLED = "XFORMERS_DISABLED" not in os.environ

# Stays False when xFormers is disabled or cannot be imported.
XFORMERS_AVAILABLE = False
if XFORMERS_ENABLED:
    try:
        from xformers.ops import memory_efficient_attention, unbind
    except ImportError:
        pass
    else:
        XFORMERS_AVAILABLE = True
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Attention(nn.Module):
    """Multi-head self-attention with a fused qkv projection and an output projection."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Queries are scaled by 1/sqrt(head_dim).
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        """Apply self-attention to `x` of shape (B, N, C); `attn_bias` is accepted but not used here."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3*C) -> (3, B, heads, N, head_dim), then split into q, k, v.
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C).
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class MemEffAttention(Attention):
    """Attention variant that uses xFormers' memory-efficient kernel when it is available."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        """Apply self-attention to `x` of shape (B, N, C).

        Without xFormers this falls back to the parent implementation, and an
        `attn_bias` cannot be honoured, so passing one raises `AssertionError`.
        """
        if XFORMERS_AVAILABLE:
            batch, tokens, channels = x.shape
            fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)

            # Split the fused projection along the q/k/v axis.
            q, k, v = unbind(fused, 2)

            out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
            out = out.reshape([batch, tokens, channels])

            return self.proj_drop(self.proj(out))

        if attn_bias is not None:
            raise AssertionError("xFormers is required for using nested tensors")
        return super().forward(x)
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/block.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn, Tensor
|
| 17 |
+
|
| 18 |
+
from .attention import Attention, MemEffAttention
|
| 19 |
+
from .drop_path import DropPath
|
| 20 |
+
from .layer_scale import LayerScale
|
| 21 |
+
from .mlp import Mlp
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("dinov2")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 28 |
+
try:
|
| 29 |
+
if XFORMERS_ENABLED:
|
| 30 |
+
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
| 31 |
+
|
| 32 |
+
XFORMERS_AVAILABLE = True
|
| 33 |
+
# warnings.warn("xFormers is available (Block)")
|
| 34 |
+
else:
|
| 35 |
+
# warnings.warn("xFormers is disabled (Block)")
|
| 36 |
+
raise ImportError
|
| 37 |
+
except ImportError:
|
| 38 |
+
XFORMERS_AVAILABLE = False
|
| 39 |
+
# warnings.warn("xFormers is not available (Block)")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Block(nn.Module):
    """Pre-norm transformer block with an attention branch and a feed-forward branch.

    Each branch is ``norm -> sublayer -> optional LayerScale`` and is added back
    to the input as a residual, optionally through stochastic depth.

    Args:
        dim: Token channel size.
        num_heads: Number of attention heads passed to ``attn_class``.
        mlp_ratio: Hidden width of the feed-forward layer as a multiple of ``dim``.
        qkv_bias: Bias flag for the attention q/k/v projection.
        proj_bias: Bias flag for the attention output projection.
        ffn_bias: Bias flag for the feed-forward layers.
        drop: Dropout probability used for the attention projection and the feed-forward layer.
        attn_drop: Dropout probability applied to attention weights.
        init_values: LayerScale initial value; a falsy value disables LayerScale.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        act_layer: Activation factory handed to ``ffn_layer``.
        norm_layer: Normalisation factory, called with ``dim``.
        attn_class: Attention module factory.
        ffn_layer: Feed-forward module factory.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply both residual branches to ``x`` and return the result."""

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Each branch uses its own drop-path module (previously drop_path1 was reused here).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual computed on a random subset of the batch.

    A random subset of rows (at least one) is passed through ``residual_func``;
    the result is scaled by ``batch / subset_size`` and added onto those rows
    only. The other rows are returned unchanged.

    Args:
        x: Input with three axes; the first is the batch axis.
        residual_func: Maps the selected rows to a residual of matching shape.
        sample_drop_ratio: Fraction of the batch to leave out of the residual.

    Returns:
        A tensor shaped like ``x``.
    """
    batch, _, _ = x.shape

    # 1) choose which rows keep the residual branch
    keep_count = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch, device=x.device)[:keep_count]

    # 2) evaluate the residual only on the kept rows
    branch_out = residual_func(x[kept_rows]).flatten(1)

    # 3) scatter-add it back, scaled up to compensate for the skipped rows
    rescale = batch / keep_count
    merged = torch.index_add(x.flatten(1), 0, kept_rows, branch_out.to(dtype=x.dtype), alpha=rescale)
    return merged.view_as(x)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick a random subset of batch rows and the matching residual scale.

    Args:
        x: Tensor with three axes; only the size of the first (batch) axis is used.
        sample_drop_ratio: Fraction of the batch to leave out.

    Returns:
        ``(indices, scale)`` where ``indices`` holds the kept row indices (at
        least one, in random order) and ``scale`` is ``batch / len(indices)``.
    """
    batch, _, _ = x.shape
    keep_count = max(int(batch * (1 - sample_drop_ratio)), 1)
    shuffled = torch.randperm(batch, device=x.device)
    return shuffled[:keep_count], batch / keep_count
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` (scaled) onto the rows of ``x`` selected by ``brange``.

    Args:
        x: Base tensor; the first axis is indexed by ``brange``.
        brange: Indices of the rows of ``x`` that receive the residual.
        residual: One residual row per entry of ``brange``.
        residual_scale_factor: Multiplier applied to the residual (``alpha``).
        scaling_vector: Optional extra scaling forwarded to ``scaled_index_add``.

    Returns:
        Without ``scaling_vector``: the result of ``torch.index_add`` on the
        flattened tensors, i.e. shaped ``(x.shape[0], -1)``. Otherwise whatever
        ``scaled_index_add`` returns (defined outside this block).
    """
    addend = residual.to(dtype=x.dtype)
    if scaling_vector is not None:
        return scaled_index_add(x, brange, addend, scaling=scaling_vector, alpha=residual_scale_factor)
    return torch.index_add(x.flatten(1), 0, brange, addend.flatten(1), alpha=residual_scale_factor)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# Attention-bias objects keyed by a tuple of (batch_size, sequence_length) pairs.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """Concatenate a list of tensors into one batch-of-one sequence and fetch its attention bias.

    Args:
        x_list: Tensors whose first axis is the batch and second axis the sequence length.
        branges: Optional per-tensor row indices; when given, only those rows are
            gathered (via ``index_select_cat``) before concatenation.

    Returns:
        ``(attn_bias, cat_tensors)``. ``attn_bias`` comes from ``attn_bias_cache``
        and is built with ``fmha.BlockDiagonalMask.from_seqlens`` on a cache miss;
        ``cat_tensors`` has a leading axis of size 1.
    """
    if branges is None:
        batch_sizes = [x.shape[0] for x in x_list]
    else:
        batch_sizes = [b.shape[0] for b in branges]

    cache_key = tuple((count, x.shape[1]) for count, x in zip(batch_sizes, x_list))
    if cache_key not in attn_bias_cache:
        # One sequence-length entry per (selected) sample, in list order.
        seqlens = [x.shape[1] for count, x in zip(batch_sizes, x_list) for _ in range(count)]
        mask = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        mask._batch_sizes = batch_sizes
        attn_bias_cache[cache_key] = mask

    if branges is None:
        folded = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(folded, dim=1)
    else:
        flat_inputs = [x.flatten(1) for x in x_list]
        cat_tensors = index_select_cat(flat_inputs, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[cache_key], cat_tensors
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Stochastic-depth residual add for a list of tensors processed as one nested batch.

    Args:
        x_list: Input tensors; each gets its own random row subset.
        residual_func: Called as ``residual_func(x_cat, attn_bias=attn_bias)`` on the
            concatenation of the selected rows.
        sample_drop_ratio: Fraction of each tensor's batch to leave out.
        scaling_vector: Optional scaling forwarded to ``add_residual``.

    Returns:
        A list with one tensor per input, each shaped like its input.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [indices for indices, _ in branges_scales]
    residual_scale_factors = [scale for _, scale in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    # 4) add each residual back onto the rows it was computed from
    return [
        add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)
        for x, brange, residual, residual_scale_factor in zip(
            x_list, branges, residual_list, residual_scale_factors
        )
    ]
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class NestedTensorBlock(Block):
    """Block that also accepts a list of tensors and processes them as one nested batch."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # LayerScale is not applied inside the residual functions here; its
            # gamma is handed to the residual add as ``scaling_vector`` instead.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Check ls2 itself (the original tested ls1 before reading ls2.gamma).
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            # Concatenate everything into one sequence, run once, then split back.
            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on input type: a single tensor uses ``Block.forward``, a list uses ``forward_nested``.

        Raises:
            AssertionError: if a list is given without xFormers, or the input is
                neither a tensor nor a list.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/dino_head.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.nn.init import trunc_normal_
|
| 9 |
+
from torch.nn.utils import weight_norm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DINOHead(nn.Module):
    """Projection head: MLP, L2 normalisation, then a weight-normalised linear layer.

    Args:
        in_dim: Input feature size.
        out_dim: Output size of the final weight-normalised layer.
        use_bn: Insert ``BatchNorm1d`` after the hidden linear layers of the MLP.
        nlayers: Number of linear layers in the MLP (values below 1 are treated as 1).
        hidden_dim: Hidden width of the MLP.
        bottleneck_dim: Output width of the MLP / input width of the last layer.
        mlp_bias: Bias flag for the MLP linear layers.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        depth = max(nlayers, 1)
        self.mlp = _build_mlp(depth, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        # Runs before last_layer is created, so that layer is not visited by _init_weights.
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for every ``nn.Linear``; other modules are left alone."""
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project ``x`` through the MLP, L2-normalise the last axis and apply the final layer."""
        hidden = self.mlp(x)
        # A larger epsilon is used for half precision.
        eps = 1e-6 if hidden.dtype == torch.float16 else 1e-12
        unit = nn.functional.normalize(hidden, dim=-1, p=2, eps=eps)
        return self.last_layer(unit)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    """Build the MLP used by the head.

    With ``nlayers == 1`` a single ``nn.Linear(in_dim, bottleneck_dim)`` is
    returned. Otherwise the result is an ``nn.Sequential`` of hidden stages
    (``Linear`` -> optional ``BatchNorm1d`` -> ``GELU``) followed by a final
    ``Linear(hidden_dim, bottleneck_dim)``.
    """
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)

    def hidden_stage(width_in):
        # Linear -> [BatchNorm1d] -> GELU, all producing hidden_dim features.
        stage = [nn.Linear(width_in, hidden_dim, bias=bias)]
        if use_bn:
            stage.append(nn.BatchNorm1d(hidden_dim))
        stage.append(nn.GELU())
        return stage

    modules = hidden_stage(in_dim)
    for _ in range(nlayers - 2):
        modules.extend(hidden_stage(hidden_dim))
    modules.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
    return nn.Sequential(*modules)
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/drop_path.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    Returns ``x`` itself when not training or when ``drop_prob`` is 0.
    Otherwise one Bernoulli draw per entry of the first axis decides whether
    that sample is kept; kept samples are divided by the keep probability.
    """
    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # One mask value per sample, broadcast over every remaining axis.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability of dropping a sample; passed through to drop_path unchanged.
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` using this module's rate and training flag."""
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 7 |
+
|
| 8 |
+
from typing import Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LayerScale(nn.Module):
    """Multiply the input by a learnable per-channel vector ``gamma``.

    Args:
        dim: Length of ``gamma`` (broadcast against the last axis of the input).
        init_values: Initial value of every ``gamma`` entry.
        inplace: When true, the input tensor is scaled in place.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by ``gamma``; ``x`` itself is modified when ``inplace`` is set."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/mlp.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional
|
| 12 |
+
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Mlp(nn.Module):
    """Two-layer feed-forward network: ``fc1 -> act -> drop -> fc2 -> drop``.

    Args:
        in_features: Input width.
        hidden_features: Hidden width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Activation factory, called without arguments.
        drop: Dropout probability (the same module is used after both linear layers).
        bias: Bias flag for both linear layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the feed-forward network to the last axis of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
from typing import Callable, Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_2tuple(x):
    """Return ``x`` as a pair: a 2-tuple is returned as is, an int ``n`` becomes ``(n, n)``.

    Any other input (or a tuple of the wrong length) fails an assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: When false, the output is reshaped to (B,H',W',D)
            with H', W' the patch-grid size.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel and stride both equal the patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Embed an image batch ``(B,C,H,W)``; H and W must be multiples of the patch size."""
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Estimate the cost of the projection, plus the norm when a real norm layer is configured."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # self.norm is never None (nn.Identity stands in when no norm layer was
        # given), so test for the placeholder instead of for None.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
from torch import Tensor, nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward layer: ``w3(silu(a) * b)`` where ``a, b`` are the two halves of ``w12(x)``.

    Args:
        in_features: Input width.
        hidden_features: Width of each gated half; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Accepted for signature compatibility; not used.
        drop: Accepted for signature compatibility; not used.
        bias: Bias flag for both linear layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        gate_width = hidden_features or in_features
        output_width = out_features or in_features
        # A single projection produces both the gate half and the value half.
        self.w12 = nn.Linear(in_features, 2 * gate_width, bias=bias)
        self.w3 = nn.Linear(gate_width, output_width, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated feed-forward layer to the last axis of ``x``."""
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 38 |
+
try:
|
| 39 |
+
if XFORMERS_ENABLED:
|
| 40 |
+
from xformers.ops import SwiGLU
|
| 41 |
+
|
| 42 |
+
XFORMERS_AVAILABLE = True
|
| 43 |
+
# warnings.warn("xFormers is available (SwiGLU)")
|
| 44 |
+
else:
|
| 45 |
+
# warnings.warn("xFormers is disabled (SwiGLU)")
|
| 46 |
+
raise ImportError
|
| 47 |
+
except ImportError:
|
| 48 |
+
SwiGLU = SwiGLUFFN
|
| 49 |
+
XFORMERS_AVAILABLE = False
|
| 50 |
+
|
| 51 |
+
# warnings.warn("xFormers is not available (SwiGLU)")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SwiGLUFFNFused(SwiGLU):
    """``SwiGLU`` with the hidden width reduced to about two thirds and rounded up to a multiple of 8.

    ``SwiGLU`` is bound at module level outside this class. ``act_layer`` and
    ``drop`` are accepted for signature compatibility and are not forwarded.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        output_width = out_features or in_features
        requested_width = hidden_features or in_features
        # Two thirds of the requested width, rounded up to the next multiple of 8.
        fused_width = (int(requested_width * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=fused_width,
            out_features=output_width,
            bias=bias,
        )
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
from . import vision_transformer as vits
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger("dinov2")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_model(args, only_teacher=False, img_size=224):
    """Instantiate the teacher (and optionally student) ViT named by ``args.arch``.

    ``args.arch`` is rewritten in place with any ``"_memeff"`` suffix removed.
    The constructor is looked up by name in the ``vits`` module.

    Returns:
        ``(teacher, embed_dim)`` when ``only_teacher`` is true, otherwise
        ``(student, teacher, embed_dim)``. ``None`` when ``"vit"`` does not
        occur in the architecture name.
    """
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" not in args.arch:
        return None

    vit_kwargs = {
        "img_size": img_size,
        "patch_size": args.patch_size,
        "init_values": args.layerscale,
        "ffn_layer": args.ffn_layer,
        "block_chunks": args.block_chunks,
        "qkv_bias": args.qkv_bias,
        "proj_bias": args.proj_bias,
        "ffn_bias": args.ffn_bias,
        "num_register_tokens": args.num_register_tokens,
        "interpolate_offset": args.interpolate_offset,
        "interpolate_antialias": args.interpolate_antialias,
    }
    teacher = vits.__dict__[args.arch](**vit_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim

    # Only the student receives the stochastic-depth settings.
    student = vits.__dict__[args.arch](
        **vit_kwargs,
        drop_path_rate=args.drop_path_rate,
        drop_path_uniform=args.drop_path_uniform,
    )
    return student, teacher, student.embed_dim
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def build_model_from_cfg(cfg, only_teacher=False):
    """Call ``build_model`` with ``cfg.student`` and the global crop size from ``cfg.crops``."""
    student_cfg = cfg.student
    return build_model(
        student_cfg,
        only_teacher=only_teacher,
        img_size=cfg.crops.global_crops_size,
    )
|
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/vision_transformer.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
from functools import partial
|
| 11 |
+
import math
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Sequence, Tuple, Union, Callable
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.utils.checkpoint import checkpoint
|
| 18 |
+
from torch.nn.init import trunc_normal_
|
| 19 |
+
|
| 20 |
+
from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
| 21 |
+
from ...layers.attention import FlashAttention
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# logger = logging.getLogger("dinov2")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` over a module tree.

    Args:
        fn: callable invoked with the keyword arguments ``module`` and ``name``.
        module: root of the tree to walk.
        name: dotted prefix used to build the qualified name of each child.
        depth_first: if True a module is visited after its children,
            otherwise before them.
        include_root: whether ``fn`` is also called on ``module`` itself.
            Every descendant is always visited, because the recursion passes
            ``include_root=True``.

    Returns:
        The ``module`` that was passed in.
    """
    visit_before_children = include_root and not depth_first
    visit_after_children = include_root and depth_first

    if visit_before_children:
        fn(module=module, name=name)

    for local_name, submodule in module.named_children():
        # Qualify the child name with the parent prefix, if there is one.
        qualified_name = name + "." + local_name if name else local_name
        named_apply(
            fn=fn,
            module=submodule,
            name=qualified_name,
            depth_first=depth_first,
            include_root=True,
        )

    if visit_after_children:
        fn(module=module, name=name)

    return module
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BlockChunk(nn.ModuleList):
    """A ``ModuleList`` whose forward pass applies its members in order."""

    def forward(self, x):
        """Feed ``x`` through every contained module in turn and return the result."""
        hidden = x
        for layer in self:
            hidden = layer(hidden)
        return hidden
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class DinoVisionTransformer(nn.Module):
    """Vision Transformer backbone with class, register and mask tokens.

    The input is patch-embedded, a class token is prepended, positional
    embeddings are added (bicubically resampled when the input size differs
    from the construction-time grid), optional register tokens are inserted
    after the class token, and the sequence is run through the transformer
    blocks followed by a final LayerNorm.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
                NOTE(review): interpolate_pos_encoding and get_intermediate_layers
                compute ``w // self.patch_size``, which only works for an int.
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap;
                0 keeps a flat list of blocks
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings

        Raises:
            NotImplementedError: if ``ffn_layer`` is not one of the supported names.
            ValueError: if ``block_chunks`` is positive but larger than ``depth``.
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # Resolve the FFN name to a factory that the blocks can call.
        if ffn_layer == "mlp":
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError(f"unsupported ffn_layer: {ffn_layer!r}")

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                # Passed at call time, so it takes precedence over any
                # attn_class bound into block_fn through functools.partial.
                attn_class=FlashAttention,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            if chunksize < 1:
                # A zero step would otherwise surface as an opaque range() error.
                raise ValueError(
                    f"block_chunks ({block_chunks}) must not exceed depth ({depth})"
                )
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialise positional/class/register tokens and all Linear layers."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Return positional embeddings matching the token grid of ``x``.

        Args:
            x: token sequence that already includes the class token
                (its second dimension is ``1 + number of patches``).
            w, h: the two spatial sizes of the input image, in the order they
                appear in its shape.

        Returns:
            ``self.pos_embed`` unchanged when the patch count matches and
            ``w == h``; otherwise the class embedding concatenated with the
            patch embeddings bicubically resampled to
            ``(w // patch_size, h // patch_size)``, cast to ``x.dtype``.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        # Interpolate in float32, then cast back to the input dtype at the end.
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Turn an image batch into the token sequence fed to the blocks.

        Args:
            x: 4-D image batch; the last two dimensions are taken as the
                spatial sizes.
            masks: optional boolean tensor selecting patch tokens to replace
                with the mask token (presumably one flag per patch - confirm
                against the caller).

        Returns:
            Tokens ordered as class token, register tokens (if any), patches.
        """
        _, _, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        # Registers are inserted after the positional embedding is added, so
        # they carry no positional embedding.
        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """Run a list of image batches through the backbone.

        Args:
            x_list: list of image batches.
            masks_list: list of per-batch masks (entries may be None), or
                None to use no mask for any batch.

        Returns:
            One feature dict per input, with the same keys as
            ``forward_features``.
        """
        if masks_list is None:
            # forward_features() forwards its default masks=None here;
            # zipping with None would raise a TypeError.
            masks_list = [None] * len(x_list)
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            if self.training:
                # Activation checkpointing in training mode.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Compute normalised class, register and patch tokens.

        Args:
            x: an image batch, or a list of image batches (dispatched to
                ``forward_features_list``).
            masks: optional mask (or list of masks) forwarded to
                ``prepare_tokens_with_masks``.

        Returns:
            A dict with keys ``x_norm_clstoken``, ``x_norm_regtokens``,
            ``x_norm_patchtokens``, ``x_prenorm`` and ``masks``; a list of
            such dicts when ``x`` is a list.
        """
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            if self.training:
                # Activation checkpointing in training mode.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect block outputs when ``self.blocks`` is a flat list of blocks."""
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect block outputs when ``self.blocks`` is a list of ``BlockChunk``s."""
        x = self.prepare_tokens_with_masks(x)
        # The last chunk is left-padded with Identity up to the full depth,
        # so its length is the total number of blocks.
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens from selected transformer blocks.

        Args:
            x: 4-D image batch.
            n: number of last blocks to take, or an explicit collection of
                block indices.
            reshape: if True, reshape each output to
                ``(B, C, w // patch_size, h // patch_size)``.
            return_class_token: if True, pair each output with its class token.
            norm: if True, apply the final LayerNorm to each block output.

        Returns:
            A tuple with one entry per selected block: the patch tokens, or a
            ``(patch_tokens, class_token)`` pair when ``return_class_token``.
        """
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Drop the class token and the register tokens, keeping patches only.
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Run ``forward_features`` and optionally reduce to the class token.

        Returns the full feature dict when ``is_training`` is True, otherwise
        ``self.head`` applied to the normalised class token.

        NOTE: with a list input ``forward_features`` returns a list, so such
        calls must pass ``is_training=True``; the string indexing below does
        not work on a list.
        """
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)

    Only ``nn.Linear`` modules are touched: the weight is drawn from a
    truncated normal with std 0.02 and the bias, if present, is zeroed.
    ``name`` is accepted so the function can be used with ``named_apply``;
    it is not used.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a DinoVisionTransformer with embed_dim 384, depth 12 and 6 heads.

    Extra keyword arguments are forwarded to the DinoVisionTransformer
    constructor.
    """
    block_factory = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=block_factory,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a DinoVisionTransformer with embed_dim 768, depth 12 and 12 heads.

    Extra keyword arguments are forwarded to the DinoVisionTransformer
    constructor.
    """
    block_factory = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=block_factory,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a DinoVisionTransformer with embed_dim 1024, depth 24 and 16 heads.

    Extra keyword arguments are forwarded to the DinoVisionTransformer
    constructor.
    """
    block_factory = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=block_factory,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64

    Depth is 40. Extra keyword arguments are forwarded to the
    DinoVisionTransformer constructor.
    """
    block_factory = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=block_factory,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|