qizhangslam commited on
Commit
ae2def3
·
verified ·
1 Parent(s): 69db92d

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. outdoor_v48_16gpu_v2/.hydra/config.yaml +68 -0
  2. outdoor_v48_16gpu_v2/.hydra/hydra.yaml +156 -0
  3. outdoor_v48_16gpu_v2/.hydra/overrides.yaml +2 -0
  4. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/__init__.py +0 -0
  5. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/base_multiview_dataset.py +576 -0
  6. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/batched_sampler.py +93 -0
  7. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/easy_dataset.py +212 -0
  8. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/dynamic_replica.py +137 -0
  9. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/habitat_hm3d.py +174 -0
  10. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/hoi4d.py +84 -0
  11. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mapfree.py +282 -0
  12. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mvs_synth.py +144 -0
  13. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/omniobject3d.py +146 -0
  14. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/pointodyssey.py +178 -0
  15. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/realestate10k.py +139 -0
  16. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannet.py +149 -0
  17. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannetpp.py +211 -0
  18. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/smartportraits.py +85 -0
  19. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/threedkb.py +111 -0
  20. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/unreal4k.py +159 -0
  21. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/__init__.py +2 -0
  22. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/corr.py +129 -0
  23. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/cropping.py +147 -0
  24. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/transforms.py +80 -0
  25. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/waymo.py +178 -0
  26. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/wildrgbd.py +56 -0
  27. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/__init__.py +1 -0
  28. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/camera.py +463 -0
  29. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/device.py +88 -0
  30. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/geometry.py +554 -0
  31. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/image.py +271 -0
  32. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/misc.py +127 -0
  33. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/parallel.py +87 -0
  34. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/path_to_croco.py +47 -0
  35. outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/render.py +75 -0
  36. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/__init__.py +6 -0
  37. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/__init__.py +4 -0
  38. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/backbones.py +156 -0
  39. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/utils.py +39 -0
  40. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/__init__.py +11 -0
  41. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/attention.py +89 -0
  42. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/block.py +259 -0
  43. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/dino_head.py +58 -0
  44. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/drop_path.py +34 -0
  45. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/layer_scale.py +27 -0
  46. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/mlp.py +40 -0
  47. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/patch_embed.py +88 -0
  48. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/swiglu_ffn.py +72 -0
  49. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/__init__.py +43 -0
  50. outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/vision_transformer.py +404 -0
outdoor_v48_16gpu_v2/.hydra/config.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ teacher: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model
2
+ pretrained: /gpfs/work2/0/prjs0824/qi_proj/ckpt/checkpoint-10.pth.model
3
+ load_only_encoder: false
4
+ long_context: false
5
+ fixed_length: true
6
+ resume: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth
7
+ benchmark: false
8
+ num_views: 64
9
+ num_test_views: 4
10
+ n_corres_train: 0
11
+ n_corres_test: 0
12
+ train_criterion: DistillLoss()
13
+ test_criterion: DistillLoss()
14
+ allow_repeat: false
15
+ root_vkitti2: /scratch-shared/wwei2/training/preprocessed_vkitti/mast3r_data/processed_vkitti
16
+ root_kitti: /scratch-shared/wwei2/eval/kitti_odometry/dataset
17
+ root_kitti_velo: /gpfs/work2/0/prjs0824/semantickitti/dataset
18
+ root_kitti360: /scratch-shared/wwei2/downloads/kitti360/KITTI-360
19
+ root_kitti360_velo: /scratch-shared/wwei2/downloads/kitti360/KITTI-360
20
+ root_waymo: /scratch-shared/wwei2/waymo_v2
21
+ root_waymo_lidar: /scratch-shared/wwei2/waymo_v2
22
+ dataset_vkitti2: VirtualKITTI2_Multi(allow_repeat=${allow_repeat}, split='train',
23
+ ROOT="${root_vkitti2}", aug_crop=16, resolution=[(518, 392), (518, 336), (518, 294),
24
+ (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views},
25
+ n_corres=${n_corres_train})
26
+ dataset_kitti360: KITTI360_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_kitti360}",
27
+ velodyne_root="${root_kitti360_velo}", aug_crop=16, resolution=[(518, 392), (518,
28
+ 336), (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter,
29
+ num_views=${num_views}, n_corres=${n_corres_train})
30
+ dataset_waymo: Waymo_v2_Multi(allow_repeat=${allow_repeat}, split='train', ROOT="${root_waymo}",
31
+ lidar_root="${root_waymo_lidar}", aug_crop=16, resolution=[(518, 392), (518, 336),
32
+ (518, 294), (518, 266), (518, 210), (518, 154)], transform=SeqColorJitter, num_views=${num_views},
33
+ n_corres=${n_corres_train})
34
+ train_dataset: 6000 @ ${dataset_vkitti2} + 6000 @ ${dataset_kitti360} + 5400 @ ${dataset_waymo}
35
+ test_dataset: 200 @ VirtualKITTI2_Multi(split='train', ROOT="${root_vkitti2}", resolution=(518,
36
+ 154), num_views=${num_test_views}, seed=42, n_corres=${n_corres_test})
37
+ seed: 0
38
+ batch_size: 1
39
+ accum_iter: 1
40
+ gradient_checkpointing: false
41
+ epochs: 10
42
+ start_epoch: 0
43
+ start_step: 0
44
+ weight_decay: 0.05
45
+ lr: 1.0e-05
46
+ min_lr: 1.0e-08
47
+ warmup_epochs: 0.5
48
+ amp: 1
49
+ num_workers: 4
50
+ world_size: 1
51
+ local-rank: -1
52
+ dist_url: env://
53
+ rank: 0
54
+ gpu: 0
55
+ distributed: false
56
+ dist_backend: nccl
57
+ eval_freq: 1
58
+ save_freq: 0.1
59
+ max_checkpoints: 10
60
+ keep_freq: 1
61
+ print_freq: 10
62
+ print_img_freq: 50000000
63
+ num_imgs_vis: 4
64
+ save_dir: /scratch-shared/wwei2/training_upstream/checkpoints
65
+ exp_name: outdoor_v48_16gpu_v2
66
+ task: StreamVGGT
67
+ logdir: ${save_dir}/${exp_name}/logs
68
+ output_dir: ${save_dir}/${exp_name}/
outdoor_v48_16gpu_v2/.hydra/hydra.yaml ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: ${save_dir}/${exp_name}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task:
115
+ - exp_name=outdoor_v48_16gpu_v2
116
+ - resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth
117
+ job:
118
+ name: mytrain
119
+ chdir: null
120
+ override_dirname: exp_name=outdoor_v48_16gpu_v2,resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth
121
+ id: ???
122
+ num: ???
123
+ config_name: outdoor_v48
124
+ env_set: {}
125
+ env_copy: []
126
+ config:
127
+ override_dirname:
128
+ kv_sep: '='
129
+ item_sep: ','
130
+ exclude_keys: []
131
+ runtime:
132
+ version: 1.3.2
133
+ version_base: '1.3'
134
+ cwd: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/src
135
+ config_sources:
136
+ - path: hydra.conf
137
+ schema: pkg
138
+ provider: hydra
139
+ - path: /gpfs/work2/0/prjs0824/qi_proj/slamformer_upstream/config
140
+ schema: file
141
+ provider: main
142
+ - path: ''
143
+ schema: structured
144
+ provider: schema
145
+ output_dir: /scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_16gpu_v2
146
+ choices:
147
+ hydra/env: default
148
+ hydra/callbacks: null
149
+ hydra/job_logging: default
150
+ hydra/hydra_logging: default
151
+ hydra/hydra_help: default
152
+ hydra/help: default
153
+ hydra/sweeper: basic
154
+ hydra/launcher: basic
155
+ hydra/output: default
156
+ verbose: true
outdoor_v48_16gpu_v2/.hydra/overrides.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ - exp_name=outdoor_v48_16gpu_v2
2
+ - resume=/scratch-shared/wwei2/training_upstream/checkpoints/outdoor_v48_4gpu_v2/checkpoint-last.pth
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/__init__.py ADDED
File without changes
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/base_multiview_dataset.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import numpy as np
3
+ import torch
4
+ import random
5
+ import itertools
6
+ from dust3r.datasets.base.easy_dataset import EasyDataset
7
+ from dust3r.datasets.utils.transforms import ImgNorm, SeqColorJitter
8
+ from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
9
+ import dust3r.datasets.utils.cropping as cropping
10
+ from dust3r.datasets.utils.corr import extract_correspondences_from_pts3d
11
+
12
+ from vggt.train_utils.augmentation import get_image_augmentation
13
+
14
+
15
+
16
def get_ray_map(c2w1, c2w2, intrinsics, h, w):
    """Build an (h, w, 6) per-pixel ray map of a view relative to a reference view.

    The relative pose ``inv(c2w1) @ c2w2`` is computed first.  Every pixel
    ``(x, y, 1)`` is multiplied by ``inv(intrinsics)``, lifted to homogeneous
    coordinates, mapped through the relative pose and L2-normalised; that
    gives the last three channels.  The first three channels hold the
    translation column of the relative pose, broadcast to every pixel.

    Args:
        c2w1: 4x4 pose of the reference view.
        c2w2: 4x4 pose of the current view.
        intrinsics: 3x3 camera matrix of the current view.
        h: image height in pixels.
        w: image width in pixels.

    Returns:
        Array of shape (h, w, 6): ``[origin(3), direction(3)]`` per pixel.
    """
    relative_pose = np.linalg.inv(c2w1) @ c2w2

    # Homogeneous pixel grid, x varying along the width axis.
    cols, rows = np.meshgrid(np.arange(w), np.arange(h), indexing="xy")
    pixels_h = np.stack([cols, rows, np.ones_like(cols)], axis=-1)

    # Back-project the pixels and move them through the relative pose.
    cam_pts = np.linalg.inv(intrinsics) @ pixels_h.reshape(-1, 3).T
    cam_pts_h = np.vstack([cam_pts, np.ones_like(cam_pts[0])])
    moved = (relative_pose @ cam_pts_h).T[:, :3].reshape(h, w, 3)
    directions = moved / np.linalg.norm(moved, axis=-1, keepdims=True)

    origins = np.broadcast_to(relative_pose[:3, 3], (h, w, 3))
    return np.concatenate([origins, directions], axis=-1)
27
+
28
+
29
+ class BaseMultiViewDataset(EasyDataset):
30
+ """Define all basic options.
31
+
32
+ Usage:
33
+ class MyDataset (BaseMultiViewDataset):
34
+ def _get_views(self, idx, rng):
35
+ # overload here
36
+ views = []
37
+ views.append(dict(img=, ...))
38
+ return views
39
+ """
40
+
41
def __init__(
    self,
    *,  # only keyword arguments
    num_views=None,
    split=None,
    resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
    transform=ImgNorm,
    aug_crop=False,
    n_corres=0,
    nneg=0,
    seed=None,
    allow_repeat=False,
    seq_aug_crop=False,
):
    """Store the sampling / augmentation options shared by all multi-view datasets.

    Args:
        num_views: number of views per sample (required).
        split: dataset split identifier, stored as-is.
        resolution: an int (square), a (width, height) pair, or a list of
            such entries; normalised by ``_set_resolutions``.
        transform: image transform, or a string that is evaluated to one.
            ``SeqColorJitter`` is instantiated and flagged via
            ``is_seq_color_jitter``.
        aug_crop: stored as-is; used by ``_crop_resize_if_necessary``.
        n_corres: ``"all"``, an int, or a list of length ``num_views``.
        nneg: must be 0 when ``n_corres == "all"``.
        seed: stored as-is; used by ``__getitem__`` to reseed the RNG.
        allow_repeat: stored as-is.
        seq_aug_crop: stored as-is.
    """
    assert num_views is not None, "undefined num_views"
    self.num_views = num_views
    self.split = split
    self._set_resolutions(resolution)

    self.n_corres = n_corres
    self.nneg = nneg
    n_corres_is_valid = (
        self.n_corres == "all"
        or isinstance(self.n_corres, int)
        or (isinstance(self.n_corres, list) and len(self.n_corres) == self.num_views)
    )
    assert (
        n_corres_is_valid
    ), f"Error, n_corres should either be 'all', a single integer or a list of length {self.num_views}"
    nneg_is_valid = self.nneg == 0 or self.n_corres != "all"
    assert nneg_is_valid, "nneg should be 0 if n_corres is all"

    self.is_seq_color_jitter = False
    if isinstance(transform, str):
        # NOTE(review): eval() runs arbitrary code; only safe while the
        # transform string comes from trusted configuration.
        transform = eval(transform)
    if transform == SeqColorJitter:
        # The factory is called here; __getitem__ builds a fresh jitter per sample.
        transform = SeqColorJitter()
        self.is_seq_color_jitter = True
    self.transform = transform

    jitter_settings = {
        "brightness": 0.5,
        "contrast": 0.5,
        "saturation": 0.5,
        "hue": 0.1,
        "p": 0.9,
    }
    self.image_aug = get_image_augmentation(
        color_jitter=jitter_settings,
        gray_scale=True,
        gau_blur=False,
    )

    self.aug_crop = aug_crop
    self.seed = seed
    self.allow_repeat = allow_repeat
    self.seq_aug_crop = seq_aug_crop
97
+
98
+ def __len__(self):
99
+ return len(self.scenes)
100
+
101
+ @staticmethod
102
+ def efficient_random_intervals(
103
+ start,
104
+ num_elements,
105
+ interval_range,
106
+ fixed_interval_prob=0.8,
107
+ weights=None,
108
+ seed=42,
109
+ ):
110
+ if random.random() < fixed_interval_prob:
111
+ intervals = random.choices(interval_range, weights=weights) * (
112
+ num_elements - 1
113
+ )
114
+ else:
115
+ intervals = [
116
+ random.choices(interval_range, weights=weights)[0]
117
+ for _ in range(num_elements - 1)
118
+ ]
119
+ return list(itertools.accumulate([start] + intervals))
120
+
121
+ def sample_based_on_timestamps(self, i, timestamps, num_views, interval=1):
122
+ time_diffs = np.abs(timestamps - timestamps[i])
123
+ ids_candidate = np.where(time_diffs < interval)[0]
124
+ ids_candidate = np.sort(ids_candidate)
125
+ if (self.allow_repeat and len(ids_candidate) < num_views // 3) or (
126
+ len(ids_candidate) < num_views
127
+ ):
128
+ return []
129
+ ids_sel_list = []
130
+ ids_candidate_left = ids_candidate.copy()
131
+ while len(ids_candidate_left) >= num_views:
132
+ ids_sel = np.random.choice(ids_candidate_left, num_views, replace=False)
133
+ ids_sel_list.append(sorted(ids_sel))
134
+ ids_candidate_left = np.setdiff1d(ids_candidate_left, ids_sel)
135
+
136
+ if len(ids_candidate_left) > 0 and len(ids_candidate) >= num_views:
137
+ ids_sel = np.concatenate(
138
+ [
139
+ ids_candidate_left,
140
+ np.random.choice(
141
+ np.setdiff1d(ids_candidate, ids_candidate_left),
142
+ num_views - len(ids_candidate_left),
143
+ replace=False,
144
+ ),
145
+ ]
146
+ )
147
+ ids_sel_list.append(sorted(ids_sel))
148
+
149
+ if self.allow_repeat:
150
+ ids_sel_list.append(
151
+ sorted(np.random.choice(ids_candidate, num_views, replace=True))
152
+ )
153
+
154
+ # add sequences with fixed intervals (all possible intervals)
155
+ pos_i = np.where(ids_candidate == i)[0][0]
156
+ curr_interval = 1
157
+ stop = len(ids_candidate) < num_views
158
+ while not stop:
159
+ pos_sel = [pos_i]
160
+ count = 0
161
+ while len(pos_sel) < num_views:
162
+ if count % 2 == 0:
163
+ curr_pos_i = pos_sel[-1] + curr_interval
164
+ if curr_pos_i >= len(ids_candidate):
165
+ stop = True
166
+ break
167
+ pos_sel.append(curr_pos_i)
168
+ else:
169
+ curr_pos_i = pos_sel[0] - curr_interval
170
+ if curr_pos_i < 0:
171
+ stop = True
172
+ break
173
+ pos_sel.insert(0, curr_pos_i)
174
+ count += 1
175
+ if not stop and len(pos_sel) == num_views:
176
+ ids_sel = sorted([ids_candidate[pos] for pos in pos_sel])
177
+ if ids_sel not in ids_sel_list:
178
+ ids_sel_list.append(ids_sel)
179
+ curr_interval += 1
180
+ return ids_sel_list
181
+
182
+ @staticmethod
183
+ def blockwise_shuffle(x, rng, block_shuffle):
184
+ if block_shuffle is None:
185
+ return rng.permutation(x).tolist()
186
+ else:
187
+ assert block_shuffle > 0
188
+ blocks = [x[i : i + block_shuffle] for i in range(0, len(x), block_shuffle)]
189
+ shuffled_blocks = [rng.permutation(block).tolist() for block in blocks]
190
+ shuffled_list = [item for block in shuffled_blocks for item in block]
191
+ return shuffled_list
192
+
193
def get_seq_from_start_id(
    self,
    num_views,
    id_ref,
    ids_all,
    rng,
    min_interval=1,
    max_interval=25,
    video_prob=0.5,
    fix_interval_prob=0.5,
    block_shuffle=None,
):
    """Pick ``num_views`` positions in ``ids_all`` starting from ``id_ref``.

    Args:
        num_views: number of views to return.
        id_ref: the reference id (first id); must be in ``ids_all``.
        ids_all: list of all ids.
        rng: random number generator (``random`` / ``choice`` are used).
        min_interval: smallest random step between two views (> 0).
        max_interval: largest step between two views.
        video_prob: probability of returning an ordered (video) sequence.
        fix_interval_prob: probability of using one constant step.
        block_shuffle: forwarded to ``blockwise_shuffle`` for unordered
            results.

    Returns:
        pos: list of positions of the views in ids_all, i.e., index for ids_all
        is_video: True if the views are returned in temporal order.

    NOTE(review): the first branch divides by ``num_views - 1`` and the
    second by ``uniq_num - 1``; ``num_views == 1`` with frames remaining, or
    exactly one remaining frame, raises ZeroDivisionError.
    """
    assert min_interval > 0, f"min_interval should be > 0, got {min_interval}"
    assert (
        min_interval <= max_interval
    ), f"min_interval should be <= max_interval, got {min_interval} and {max_interval}"
    assert id_ref in ids_all
    pos_ref = ids_all.index(id_ref)
    n_ids = len(ids_all)
    n_steps = num_views - 1

    # Number of frames available after the reference one.
    remaining_sum = n_ids - 1 - pos_ref

    if remaining_sum >= n_steps:
        if remaining_sum == n_steps:
            # Exactly enough frames left: take them all, in order.
            assert ids_all[-num_views] == id_ref
            return [pos_ref + i for i in range(num_views)], True

        max_interval = min(max_interval, 2 * remaining_sum // n_steps)
        intervals = [
            rng.choice(range(min_interval, max_interval + 1))
            for _ in range(n_steps)
        ]

        # if video or collection
        is_video = False
        if rng.random() < video_prob:
            is_video = True
            # if fixed interval or random
            if rng.random() < fix_interval_prob:
                # regular interval
                upper = min(remaining_sum // n_steps + 1, max_interval + 1)
                fixed_interval = rng.choice(range(1, upper))
                intervals = [fixed_interval] * n_steps

        # Drop positions running past the end, then refill from unused ones.
        pos = [
            p for p in itertools.accumulate([pos_ref] + intervals) if p < n_ids
        ]
        pos_candidates = [p for p in np.arange(pos_ref, n_ids) if p not in pos]
        refill = rng.choice(pos_candidates, num_views - len(pos), replace=False)
        pos = pos + refill.tolist()

        if is_video:
            pos = sorted(pos)
        else:
            pos = self.blockwise_shuffle(pos, rng, block_shuffle)
    else:
        # Not enough frames after id_ref: restart earlier and repeat frames.
        # assert self.allow_repeat
        uniq_num = remaining_sum
        new_pos_ref = rng.choice(np.arange(pos_ref + 1))
        new_remaining_sum = n_ids - 1 - new_pos_ref
        new_max_interval = min(max_interval, new_remaining_sum // (uniq_num - 1))
        new_intervals = [
            rng.choice(range(1, new_max_interval + 1)) for _ in range(uniq_num - 1)
        ]

        revisit_random = rng.random()
        video_random = rng.random()

        if rng.random() < fix_interval_prob and video_random < video_prob:
            # regular interval
            fixed_interval = rng.choice(range(1, new_max_interval + 1))
            new_intervals = [fixed_interval] * (uniq_num - 1)
        pos = list(itertools.accumulate([new_pos_ref] + new_intervals))

        is_video = False
        if revisit_random < 0.5 or video_prob == 1.0:  # revisit, video / collection
            is_video = video_random < video_prob
            if not is_video:
                pos = self.blockwise_shuffle(pos, rng, block_shuffle)
            # Cycle through the unique positions until num_views are filled.
            num_full_repeat = num_views // uniq_num
            pos = (
                pos * num_full_repeat
                + pos[: num_views - len(pos) * num_full_repeat]
            )
        elif revisit_random < 0.9:  # random
            pos = rng.choice(pos, num_views, replace=True)
        else:  # ordered
            pos = sorted(rng.choice(pos, num_views, replace=True))
    assert len(pos) == num_views
    return pos, is_video
306
+
307
def get_img_and_ray_masks(self, is_metric, v, rng, p=(0.8, 0.15, 0.05)):
    """Decide which inputs (image and/or ray map) are exposed for view ``v``.

    The first view (``v == 0``) and every view of a non-metric sample always
    use the image only.  Otherwise a single ``rng.random()`` draw selects:
    image only when below ``p[0]``, ray map only when below
    ``p[0] + p[1]``, and both otherwise.

    Args:
        is_metric: whether the sample is metric.
        v: index of the view inside the sample.
        rng: generator exposing ``random()``.
        p: sequence whose first two entries are the thresholds described
            above (an immutable tuple by default, so the default cannot be
            mutated across calls).

    Returns:
        Tuple ``(img_mask, raymap_mask)`` of booleans.
    """
    if v == 0 or (not is_metric):
        return True, False

    rand_val = rng.random()
    if rand_val < p[0]:
        return True, False
    if rand_val < p[0] + p[1]:
        return False, True
    return True, True
324
+
325
def get_stats(self):
    """Return a short human-readable size summary based on ``len(self)``."""
    n_groups = len(self)
    return f"{n_groups} groups of views"
327
+
328
# Compact one-line description of the dataset: class name, get_stats(),
# num_views, split, seed, the resolutions rendered as "[WxH;WxH;...]" and the
# transform.  The "self." prefixes produced by the f-string "=" specifier and
# the newlines of the multi-line template are stripped afterwards.
# NOTE(review): whitespace in this view appears collapsed, so the argument of
# the final .replace(...) and the template indentation may differ from the
# original file -- confirm before relying on the exact output format.
+ def __repr__(self):
328
+ resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
329
+ return (
330
+ f"""{type(self).__name__}({self.get_stats()},
331
+ {self.num_views=},
332
+ {self.split=},
333
+ {self.seed=},
334
+ resolutions={resolutions_str},
335
+ {self.transform=})""".replace(
336
+ "self.", ""
337
+ )
338
+ .replace("\n", "")
339
+ .replace(" ", "")
340
+ )
342
+
343
+ def _get_views(self, idx, resolution, rng, num_views):
344
+ raise NotImplementedError()
345
+
346
def __getitem__(self, idx):
    """Load one multi-view sample and attach the derived per-view fields.

    Args:
        idx: either a plain sample index (only valid when a single
            resolution is configured) or a ``(idx, ar_idx, nview)`` triple
            selecting the resolution index and the number of views.

    Returns:
        List of ``nview`` view dicts as produced by ``_get_views``, each
        extended with ``idx``, ``true_shape``, the transformed ``img``,
        ``sky_mask`` (``depthmap < 0``), ``camera_pose`` (NaN-filled when
        missing), ``ray_map`` relative to the first view, ``pts3d``,
        ``pts3d_local``, ``valid_mask``, optional ``corres`` /
        ``valid_corres`` and a per-view ``rng`` integer.
    """
    if isinstance(idx, (tuple, list, np.ndarray)):
        # the idx is specifying the aspect-ratio
        idx, ar_idx, nview = idx
    else:
        assert len(self._resolutions) == 1
        ar_idx = 0
        nview = self.num_views

    assert 1 <= nview <= self.num_views

    # set-up the rng
    # NOTE(review): a seed of 0 is falsy and therefore treated like "no seed".
    if self.seed:  # reseed for each __getitem__
        self._rng = np.random.default_rng(seed=self.seed + idx)
    elif not hasattr(self, "_rng"):
        seed = torch.randint(0, 2**32, (1,)).item()
        self._rng = np.random.default_rng(seed=seed)

    if self.aug_crop > 1 and self.seq_aug_crop:
        # One shared crop jitter for every view of the sequence.
        self.delta_target_resolution = self._rng.integers(0, self.aug_crop)

    # over-loaded code
    resolution = self._resolutions[
        ar_idx
    ]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
    views = self._get_views(idx, resolution, self._rng, nview)
    assert len(views) == nview

    if "camera_pose" not in views[0]:
        # The reference pose is inverted in get_ray_map, so the fallback must
        # be invertible: use the identity (an all-ones matrix is singular).
        views[0]["camera_pose"] = np.eye(4, dtype=np.float32)
    first_view_camera_pose = views[0]["camera_pose"]
    # A fresh jitter instance per sample when sequence jitter is enabled.
    transform = SeqColorJitter() if self.is_seq_color_jitter else self.transform

    for v, view in enumerate(views):
        assert (
            "pts3d" not in view
        ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
        view["idx"] = (idx, ar_idx, v)

        # encode the image
        width, height = view["img"].size

        view["true_shape"] = np.int32((height, width))
        view["img"] = transform(view["img"])
        view["sky_mask"] = view["depthmap"] < 0

        assert "camera_intrinsics" in view
        if "camera_pose" not in view:
            view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
        else:
            assert np.isfinite(
                view["camera_pose"]
            ).all(), f"NaN in camera pose for view {view_name(view)}"

        ray_map = get_ray_map(
            first_view_camera_pose,
            view["camera_pose"],
            view["camera_intrinsics"],
            height,
            width,
        )
        view["ray_map"] = ray_map.astype(np.float32)

        assert "pts3d" not in view
        assert "valid_mask" not in view
        assert np.isfinite(
            view["depthmap"]
        ).all(), f"NaN in depthmap for view {view_name(view)}"
        pts3d, pts3d_local, valid_mask = depthmap_to_absolute_camera_coordinates(**view)

        view["pts3d"] = pts3d
        view["pts3d_local"] = pts3d_local
        view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)

        # check all datatypes
        for key, val in view.items():
            res, err_msg = is_good_type(key, val)
            assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"

    # Extra image augmentation through self.image_aug is currently disabled.
    if False:
        if random.random() > 0.3:  # self.cojitter_ratio:
            images = torch.stack([view['img'] for view in views], axis=0)
            images = self.image_aug(images)
            for v, view in enumerate(views):
                view['img'] = images[v]
        else:
            for view in views:
                view['img'] = self.image_aug(view['img'][None])[0]

    # NOTE(review): __init__ also accepts n_corres == "all" or a list, which
    # cannot be compared with 0 here -- confirm those modes are unused.
    if self.n_corres > 0:
        ref_view = views[0]
        for view in views:
            corres1, corres2, valid = extract_correspondences_from_pts3d(
                ref_view, view, self.n_corres, self._rng, nneg=self.nneg
            )
            view["corres"] = (corres1, corres2)
            view["valid_corres"] = valid

    # last thing done!
    for view in views:
        view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
    return views
452
+
453
+ def _set_resolutions(self, resolutions):
454
+ assert resolutions is not None, "undefined resolution"
455
+
456
+ if not isinstance(resolutions, list):
457
+ resolutions = [resolutions]
458
+
459
+ self._resolutions = []
460
+ for resolution in resolutions:
461
+ if isinstance(resolution, int):
462
+ width = height = resolution
463
+ else:
464
+ width, height = resolution
465
+ assert isinstance(
466
+ width, int
467
+ ), f"Bad type for {width=} {type(width)=}, should be int"
468
+ assert isinstance(
469
+ height, int
470
+ ), f"Bad type for {height=} {type(height)=}, should be int"
471
+ self._resolutions.append((width, height))
472
+
473
def _crop_resize_if_necessary(
    self, image, depthmap, intrinsics, resolution, rng=None, info=None
):
    """Crop and rescale an image/depthmap pair to ``resolution``.

    Steps, all delegated to the ``cropping`` helpers:
      1. crop the largest window centred on the principal point;
      2. rescale to the target resolution (enlarged by a random amount when
         ``self.aug_crop > 1``) -- Lanczos down-scaling according to the
         original comments;
      3. crop to exactly ``resolution``.

    Args:
        image: PIL image or array (converted with ``PIL.Image.fromarray``).
        depthmap: depth map cropped/rescaled together with the image.
        intrinsics: 3x3 camera matrix.
        resolution: target (width, height).
        rng: generator used for the crop jitter when ``seq_aug_crop`` is off.
        info: label included in the assertion messages.

    Returns:
        Tuple ``(image, depthmap, intrinsics)`` after the final crop.
    """
    if not isinstance(image, PIL.Image.Image):
        image = PIL.Image.fromarray(image)

    # Step 1: symmetric crop around the principal point.
    full_w, full_h = image.size
    cx, cy = intrinsics[:2, 2].round().astype(int)
    min_margin_x = min(cx, full_w - cx)
    min_margin_y = min(cy, full_h - cy)
    assert min_margin_x > full_w / 5, f"Bad principal point in view={info}"
    assert min_margin_y > full_h / 5, f"Bad principal point in view={info}"
    # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
    centred_bbox = (
        cx - min_margin_x,
        cy - min_margin_y,
        cx + min_margin_x,
        cy + min_margin_y,
    )
    image, depthmap, intrinsics = cropping.crop_image_depthmap(
        image, depthmap, intrinsics, centred_bbox
    )

    # Step 2: rescale, optionally to a slightly larger size (crop augmentation).
    target_resolution = np.array(resolution)
    if self.aug_crop > 1:
        if self.seq_aug_crop:
            # Shared jitter drawn once per sample in __getitem__.
            enlargement = self.delta_target_resolution
        else:
            enlargement = rng.integers(0, self.aug_crop)
        target_resolution += enlargement
    image, depthmap, intrinsics = cropping.rescale_image_depthmap(
        image, depthmap, intrinsics, target_resolution
    )

    # Step 3: final crop down to the requested resolution.
    cropped_intrinsics = cropping.camera_matrix_of_crop(
        intrinsics, image.size, resolution, offset_factor=0.5
    )
    final_bbox = cropping.bbox_from_intrinsics_in_out(
        intrinsics, cropped_intrinsics, resolution
    )
    image, depthmap, cropped_intrinsics = cropping.crop_image_depthmap(
        image, depthmap, intrinsics, final_bbox
    )

    return image, depthmap, cropped_intrinsics
526
+
527
+
528
def is_good_type(key, v):
    """Check that ``v`` has a type accepted in a view dict.

    ``str``, ``int`` and ``tuple`` values are always accepted; anything else
    must expose a ``dtype`` from the allowed set.  ``key`` is not used.

    Returns:
        ``(is_good, err_msg)`` -- ``err_msg`` is ``None`` when accepted.
    """
    if not isinstance(v, (str, int, tuple)):
        allowed_dtypes = (
            np.float32,
            torch.float32,
            bool,
            np.int32,
            np.int64,
            np.uint8,
        )
        if v.dtype not in allowed_dtypes:
            return False, f"bad {v.dtype=}"
    return True, None
535
+
536
+
537
def view_name(view, batch_index=None):
    """Return ``"<dataset>/<label>/<instance>"`` for a view.

    When ``batch_index`` is neither ``None`` nor ``slice(None)``, each of the
    three fields is indexed with it first (batched views).
    """
    parts = []
    for field in ("dataset", "label", "instance"):
        value = view[field]
        if batch_index not in (None, slice(None)):
            value = value[batch_index]
        parts.append(value)
    db, label, instance = parts
    return f"{db}/{label}/{instance}"
545
+
546
+
547
def transpose_to_landscape(view):
    """Transpose a portrait view to landscape, in place.

    Nothing happens unless ``true_shape`` reports ``width < height``.  In
    that case ``img``, ``valid_mask``, ``depthmap``, ``pts3d``, ``ray_map``
    and ``sky_mask`` have their spatial axes swapped, the first two rows of
    ``camera_intrinsics`` are exchanged, and the x/y columns of the
    correspondences (when present) are swapped.

    ``view["corres"]`` may be a tuple (as stored by the dataset) or a list:
    a list is updated in place, a tuple is replaced by a new tuple.

    NOTE(review): ``true_shape`` and ``pts3d_local`` are left untouched --
    confirm that downstream code expects this.
    """
    height, width = view["true_shape"]

    if width < height:
        # rectify portrait to landscape
        assert view["img"].shape == (3, height, width)
        view["img"] = view["img"].swapaxes(1, 2)

        assert view["valid_mask"].shape == (height, width)
        view["valid_mask"] = view["valid_mask"].swapaxes(0, 1)

        assert view["depthmap"].shape == (height, width)
        view["depthmap"] = view["depthmap"].swapaxes(0, 1)

        assert view["pts3d"].shape == (height, width, 3)
        view["pts3d"] = view["pts3d"].swapaxes(0, 1)

        # transpose x and y pixels
        view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]

        assert view["ray_map"].shape == (height, width, 6)
        view["ray_map"] = view["ray_map"].swapaxes(0, 1)

        assert view["sky_mask"].shape == (height, width)
        view["sky_mask"] = view["sky_mask"].swapaxes(0, 1)

        if "corres" in view:
            # transpose correspondences x and y
            corres = view["corres"]
            swapped = [pts[:, [1, 0]] for pts in corres[:2]]
            if isinstance(corres, list):
                corres[:2] = swapped
            else:
                # Tuples do not support item assignment: rebuild the pair.
                view["corres"] = tuple(swapped) + tuple(corres[2:])
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/batched_sampler.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from accelerate import Accelerator
4
+ import torch.utils
5
+ from torch.utils.data import BatchSampler, Sampler
6
+ import torch.utils.data
7
+
8
+
9
class CustomRandomSampler(Sampler):
    """Random sampler whose consecutive runs of ``batch_size`` items share settings.

    Every yielded index is a tuple ``(sample_idx, feat_idx, view_idx)``. Within each
    run of ``batch_size`` consecutive indices, ``feat_idx`` (drawn from
    ``range(pool_size)``, e.g. an aspect-ratio index) and ``view_idx`` (drawn from
    ``min_view_size..max_view_size`` inclusive) are constant.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        pool_size,
        min_view_size,
        max_view_size,
        world_size,
        warmup=1,
        drop_last=True,
    ):
        """Store the sampling parameters.

        ``world_size`` and ``warmup`` are accepted but not used by this class.
        """
        dataset_len = len(dataset)
        self.batch_size = batch_size
        self.pool_size = pool_size
        self.min_view_size = min_view_size
        self.max_view_size = max_view_size
        self.drop_last = drop_last
        self.len_dataset = dataset_len
        self.total_size = dataset_len
        self.epoch = None
        self.epochf = 0.0

    def __len__(self):
        """Number of indices produced per epoch."""
        return self.total_size

    def set_epoch(self, epoch):
        """Set the epoch that seeds the next iteration."""
        self.epoch = epoch

    def __iter__(self):
        """Yield ``(sample_idx, feat_idx, view_idx)`` tuples for the current epoch.

        Raises:
            ValueError: If ``set_epoch`` has not been called yet.
        """
        if self.epoch is None:
            raise ValueError(
                "Epoch number not set. Please call 'set_epoch(epoch)' before iterating."
            )

        # The generator depends only on the epoch, so an epoch is reproducible.
        rng = np.random.default_rng(seed=self.epoch + 788)

        # 1) shuffled sample indices
        order = np.arange(self.total_size)
        rng.shuffle(order)

        n_batches = (self.total_size + self.batch_size - 1) // self.batch_size

        # 2) one feature index per batch; the first half of the pool is twice as likely
        if self.pool_size > 1:
            weights = np.ones(self.pool_size)
            weights[: self.pool_size // 2] *= 2
            weights = weights / weights.sum()
            feat_per_batch = rng.choice(self.pool_size, size=n_batches, p=weights)
        else:
            feat_per_batch = rng.integers(self.pool_size, size=n_batches)

        # 3) one view count per batch, upper bound inclusive
        views_per_batch = rng.integers(
            self.min_view_size, self.max_view_size + 1, size=n_batches
        )

        # Repeat each per-batch value batch_size times, then trim to the sample count.
        feat_column = np.repeat(feat_per_batch, self.batch_size)[: self.total_size]
        view_column = np.repeat(views_per_batch, self.batch_size)[: self.total_size]

        table = np.column_stack((order, feat_column, view_column))
        for row in table:
            yield tuple(row)
76
+
77
+
78
class BatchedRandomSampler(BatchSampler):
    """Batch sampler that groups the indices of a ``CustomRandomSampler`` into batches.

    Only the three attributes below are set here; iteration and length come from
    ``BatchSampler``.
    """

    def __init__(self, sampler: CustomRandomSampler, batch_size, drop_last=True):
        # The parent constructor is not called; the attributes are set directly.
        self.drop_last = drop_last
        self.batch_size = batch_size
        self.sampler = sampler  # the wrapped index sampler

    def set_epoch(self, epoch):
        """Forward the epoch to the wrapped sampler."""
        self.sampler.set_epoch(epoch)
88
+
89
+
90
def round_by(total, multiple, up=False):
    """Round ``total`` to a multiple of ``multiple``.

    Rounds down by default; with ``up=True`` rounds up to the next multiple
    (values that are already a multiple are returned unchanged).
    """
    numerator = total + multiple - 1 if up else total
    return (numerator // multiple) * multiple
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/base/easy_dataset.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # modified from DUSt3R
6
+
7
+ import numpy as np
8
+ from dust3r.datasets.base.batched_sampler import (
9
+ BatchedRandomSampler,
10
+ CustomRandomSampler,
11
+ )
12
+ import torch
13
+
14
+
15
class EasyDataset:
    """a dataset that you can easily resize and combine.
    Examples:
    ---------
        2 * dataset ==> duplicate each element 2x

        10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)

        dataset1 + dataset2 ==> concatenate datasets
    """

    def __add__(self, other):
        """``self + other``: concatenate the two datasets."""
        return CatDataset([self, other])

    def __rmul__(self, factor):
        """``factor * self``: repeat each element ``factor`` times."""
        return MulDataset(factor, self)

    def __rmatmul__(self, factor):
        """``factor @ self``: resize the dataset to ``factor`` elements."""
        return ResizedDataset(factor, self)

    def set_epoch(self, epoch):
        """Hook called at the start of an epoch; nothing to do by default."""
        pass  # nothing to do by default

    def make_sampler(
        self, batch_size, shuffle=True, drop_last=True, world_size=1, rank=0, fixed_length=False
    ):
        """Build the batch sampler for this dataset.

        Args:
            batch_size: Number of indices per batch.
            shuffle: Must be truthy; unshuffled sampling is not implemented.
            drop_last: Forwarded to both samplers.
            world_size: Forwarded to ``CustomRandomSampler``.
            rank: Accepted for interface compatibility; not used here.
            fixed_length: When true every batch uses exactly ``self.num_views`` views;
                otherwise the per-batch view count ranges from 4 (clamped to
                ``self.num_views``) up to ``self.num_views``.

        Returns:
            A ``BatchedRandomSampler`` wrapping a ``CustomRandomSampler``.

        Raises:
            NotImplementedError: If ``shuffle`` is falsy.
        """
        if not shuffle:
            raise NotImplementedError()  # cannot deal yet
        num_of_aspect_ratios = len(self._resolutions)
        num_of_views = self.num_views
        # The lower bound is normally 4, but it must never exceed num_views: the
        # sampler draws from [min, max] and an empty range raises ValueError.
        min_views = num_of_views if fixed_length else min(4, num_of_views)
        sampler = CustomRandomSampler(
            self,
            batch_size,
            num_of_aspect_ratios,
            min_views,
            num_of_views,
            world_size,
            warmup=1,
            drop_last=drop_last,
        )
        return BatchedRandomSampler(sampler, batch_size, drop_last)
56
+
57
+
58
class MulDataset(EasyDataset):
    """Artificially augmenting the size of a dataset.

    Index ``i`` of this dataset maps to index ``i // multiplicator`` of the wrapped
    dataset, so every element appears ``multiplicator`` times in a row.
    """

    multiplicator: int

    def __init__(self, multiplicator, dataset):
        assert isinstance(multiplicator, int) and multiplicator > 0
        self.multiplicator = multiplicator
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset) * self.multiplicator

    def __repr__(self):
        return f"{self.multiplicator}*{self.dataset!r}"

    def __getitem__(self, idx):
        """Fetch an item by plain index or by a 3-tuple whose first entry is the index."""
        if not isinstance(idx, tuple):
            return self.dataset[idx // self.multiplicator]
        # Only the sample index is rescaled; the two extra entries pass through.
        sample_idx, second, third = idx
        return self.dataset[sample_idx // self.multiplicator, second, third]

    @property
    def _resolutions(self):
        return self.dataset._resolutions

    @property
    def num_views(self):
        return self.dataset.num_views
88
+
89
+
90
class ResizedDataset(EasyDataset):
    """Artificially changing the size of a dataset.

    ``set_epoch`` builds a mapping of ``new_size`` indices into the wrapped dataset
    from an epoch-seeded permutation, repeated as often as needed. ``__getitem__``
    requires that mapping to exist.
    """

    new_size: int

    def __init__(self, new_size, dataset):
        assert isinstance(new_size, int) and new_size > 0
        self.new_size = new_size
        self.dataset = dataset

    def __len__(self):
        return self.new_size

    def __repr__(self):
        # Render the size with "_" between groups of three characters, counted
        # from the right (e.g. 1234567 -> 1_234_567).
        digits = str(self.new_size)
        lead = len(digits) % 3 or 3
        groups = [digits[:lead]]
        groups.extend(digits[pos : pos + 3] for pos in range(lead, len(digits), 3))
        size_str = "_".join(groups)
        return f"{size_str} @ {self.dataset!r}"

    def set_epoch(self, epoch):
        """Rebuild the index mapping; it depends only on ``epoch``."""
        rng = np.random.default_rng(seed=epoch + 777)

        # shuffle all indices of the wrapped dataset
        base_len = len(self.dataset)
        perm = rng.permutation(base_len)

        # repeat the permutation until the target size is covered, then trim
        laps = 1 + (len(self) - 1) // base_len
        self._idxs_mapping = np.tile(perm, laps)[: self.new_size]

        assert len(self._idxs_mapping) == self.new_size

    def __getitem__(self, idx):
        """Fetch an item by plain index or by a 3-tuple whose first entry is the index."""
        assert hasattr(
            self, "_idxs_mapping"
        ), "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()"
        if not isinstance(idx, tuple):
            return self.dataset[self._idxs_mapping[idx]]
        sample_idx, second, third = idx
        return self.dataset[self._idxs_mapping[sample_idx], second, third]

    @property
    def _resolutions(self):
        return self.dataset._resolutions

    @property
    def num_views(self):
        return self.dataset.num_views
142
+
143
+
144
class CatDataset(EasyDataset):
    """Concatenation of several datasets"""

    def __init__(self, datasets):
        for dataset in datasets:
            assert isinstance(dataset, EasyDataset)
        self.datasets = datasets
        # _cum_sizes[k] is the total length of datasets[0..k]
        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])

    def __len__(self):
        # Return a plain Python int rather than a numpy scalar.
        return int(self._cum_sizes[-1])

    def __repr__(self):
        # remove uselessly long transform
        return " + ".join(
            repr(dataset).replace(
                ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
                "",
            )
            for dataset in self.datasets
        )

    def set_epoch(self, epoch):
        """Forward the epoch to every concatenated dataset."""
        for dataset in self.datasets:
            dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        """Fetch an item, falling back to the following indices when loading fails.

        Args:
            idx: Either a plain index or a 3-tuple ``(index, other, another)``. The two
                extra entries are forwarded to the sub-dataset when both are not None.

        Returns:
            The first item that loads, starting at ``idx`` and moving forward
            (wrapping around) past items whose loading raises.

        Raises:
            IndexError: If the index is outside ``[0, len(self))``.
            RuntimeError: If a full sweep over the dataset fails to load any item.
        """
        other = another = None
        if isinstance(idx, tuple):
            idx, other, another = idx

        total = len(self)
        last_error = None
        # At most one full sweep: if every sample fails, stop instead of spinning forever.
        for _ in range(max(total, 1)):
            if not (0 <= idx < total):
                raise IndexError()

            db_idx = np.searchsorted(self._cum_sizes, idx, "right")
            dataset = self.datasets[db_idx]
            new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)

            if other is not None and another is not None:
                new_idx = (new_idx, other, another)

            try:
                return dataset[new_idx]
            except Exception as e:
                # Best effort: log the failure and try the next sample.
                print(e)
                print("DATA ERROR", new_idx)
                last_error = e
                idx = (idx + 1) % total

        raise RuntimeError(
            f"CatDataset: no sample could be loaded after trying all {total} indices"
        ) from last_error

    @property
    def _resolutions(self):
        resolutions = self.datasets[0]._resolutions
        # every concatenated dataset must expose the same resolutions
        for dataset in self.datasets[1:]:
            assert tuple(dataset._resolutions) == tuple(resolutions)
        return resolutions

    @property
    def num_views(self):
        num_views = self.datasets[0].num_views
        # every concatenated dataset must use the same number of views
        for dataset in self.datasets[1:]:
            assert dataset.num_views == num_views
        return num_views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/dynamic_replica.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+ from tqdm import tqdm
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2
12
+
13
+
14
class DynamicReplica(BaseMultiViewDataset):
    """Multi-view dataset read from ``ROOT/<split>/<scene>/left``.

    For each frame the loader reads ``rgb/<name>.png``, ``depth/<name>.npy`` and
    ``cam/<name>.npz`` (keys ``pose`` and ``intrinsics``). Frames of all kept scenes
    share one global numbering; one dataset item corresponds to one start frame.
    """

    def __init__(self, *args, ROOT, **kwargs):
        # These attributes are set before the base constructor runs.
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 16
        super().__init__(*args, **kwargs)

        self.loaded_data = self._load_data(self.split)

    def _load_data(self, split):
        """Index the scenes of ``split`` and fill the lookup tables.

        Sets ``scenes``, ``sceneids`` (global frame id -> scene index), ``images``
        (global frame id -> basename), ``start_img_ids`` and ``scene_img_list``.
        Scenes with fewer frames than the cut-off are skipped. Returns ``None``.
        """
        self.scenes = os.listdir(os.path.join(self.ROOT, split))

        kept_scenes = []
        frame_to_scene = []
        frames_by_scene = []
        frame_names = []
        valid_starts = []
        next_frame_id = 0

        for scene in tqdm(self.scenes):
            rgb_dir = osp.join(self.ROOT, self.split, scene, "left", "rgb")
            stems = [fname[:-4] for fname in os.listdir(rgb_dir) if fname.endswith(".png")]
            # basenames are numeric strings; order them by value
            basenames = sorted(stems, key=float)
            num_imgs = len(basenames)
            img_ids = list(np.arange(num_imgs) + next_frame_id)

            if self.allow_repeat:
                cut_off = max(self.num_views // 3, 3)
            else:
                cut_off = self.num_views
            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue

            scene_index = len(kept_scenes)
            valid_starts.extend(img_ids[: num_imgs - cut_off + 1])
            frame_to_scene.extend([scene_index] * num_imgs)
            frame_names.extend(basenames)
            kept_scenes.append(scene)
            frames_by_scene.append(img_ids)

            # the next kept scene continues the global numbering
            next_frame_id += num_imgs

        self.scenes = kept_scenes
        self.sceneids = frame_to_scene
        self.images = frame_names
        self.start_img_ids = valid_starts
        self.scene_img_list = frames_by_scene

    def __len__(self):
        return len(self.start_img_ids)

    def get_image_num(self):
        """Total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` views of the sequence starting at item ``idx``.

        Returns:
            A list of ``num_views`` dicts, one per view.
        """
        start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=self.max_interval,
            video_prob=1.0,
            fix_interval_prob=1.0,
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_name = self.scenes[scene_id]
            scene_dir = osp.join(self.ROOT, self.split, scene_name, "left")
            basename = self.images[view_idx]

            # RGB image
            rgb_image = imread_cv2(osp.join(scene_dir, "rgb", basename + ".png"))
            # depth map; non-finite values are marked invalid with 0
            depthmap = np.load(osp.join(scene_dir, "depth", basename + ".npy"))
            depthmap[~np.isfinite(depthmap)] = 0

            # camera parameters
            cam = np.load(osp.join(scene_dir, "cam", basename + ".npz"))
            camera_pose = cam["pose"]
            intrinsics = cam["intrinsics"]

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # image mask and ray-map mask for this view
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.85, 0.10, 0.05]
            )

            views.append(
                {
                    "img": rgb_image,
                    "depthmap": depthmap.astype(np.float32),
                    "camera_pose": camera_pose.astype(np.float32),
                    "camera_intrinsics": intrinsics.astype(np.float32),
                    "dataset": "dynamic_replica",
                    "label": scene_name + "_" + basename,
                    "instance": f"{str(idx)}_{str(view_idx)}",
                    "is_metric": self.is_metric,
                    "is_video": ordered_video,
                    "quantile": np.array(1.0, dtype=np.float32),
                    "img_mask": img_mask,
                    "ray_mask": ray_mask,
                    "camera_only": False,
                    "depth_only": False,
                    "single_view": False,
                    "reset": False,
                }
            )
        assert len(views) == num_views
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/habitat_hm3d.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+ from tqdm import tqdm
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2
12
+
13
+
14
class HabitatHM3D_Multi(BaseMultiViewDataset):
    """Multi-view dataset read from ``ROOT/<scene>``.

    For each frame ``<name>`` the loader reads ``image_<name>.png``,
    ``depth_<name>.png`` (divided by 1000 after loading) and ``<name>.npz`` (keys
    ``intrinsics``, ``R_cam2world``, ``t_cam2world``). A scene in which a file fails
    to load is flagged in ``invalid_scenes`` and avoided afterwards.
    """

    def __init__(self, *args, ROOT, **kwargs):
        # These attributes are set before the base constructor runs.
        self.ROOT = ROOT
        self.video = True
        self.is_metric = False
        self.max_interval = 8
        super().__init__(*args, **kwargs)
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Index all scenes under ``ROOT`` and fill the lookup tables.

        Sets ``scenes``, ``sceneids`` (global frame id -> scene index), ``images``
        (global frame id -> basename), ``start_img_ids`` (``(scene, frame id)``
        pairs), ``scene_img_list`` and ``invalid_scenes``. Returns ``None``.
        """
        self.scenes = os.listdir(self.ROOT)

        kept_scenes = []
        frame_to_scene = []
        frames_by_scene = []
        frame_names = []
        valid_starts = []
        next_frame_id = 0

        for scene in tqdm(self.scenes):
            scene_dir = osp.join(self.ROOT, scene)
            stems = [fname[:-4] for fname in os.listdir(scene_dir) if fname.endswith(".npz")]
            # basenames are integer strings; order them by value
            basenames = sorted(stems, key=int)

            num_imgs = len(basenames)
            # TODO (original author's note): the current training data is said to move
            # backward, so sequences were meant to run from -1 to 0. Nothing here
            # reverses the order -- confirm.
            img_ids = list(np.arange(num_imgs) + next_frame_id)
            if self.allow_repeat:
                cut_off = max(self.num_views // 3, 3)
            else:
                cut_off = self.num_views
            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue

            scene_index = len(kept_scenes)
            valid_starts.extend(
                [(scene, frame_id) for frame_id in img_ids[: num_imgs - cut_off + 1]]
            )
            frame_to_scene.extend([scene_index] * num_imgs)
            frame_names.extend(basenames)
            kept_scenes.append(scene)
            frames_by_scene.append(img_ids)

            # the next kept scene continues the global numbering
            next_frame_id += num_imgs

        self.scenes = kept_scenes
        self.sceneids = frame_to_scene
        self.images = frame_names
        self.start_img_ids = valid_starts
        self.scene_img_list = frames_by_scene

        self.invalid_scenes = dict.fromkeys(self.scenes, False)

    def __len__(self):
        return len(self.start_img_ids)

    def get_image_num(self):
        """Total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` views, retrying on another item when loading fails.

        Returns:
            A list of ``num_views`` dicts, one per view.

        Raises:
            RuntimeError: After 100 unsuccessful attempts, or when no valid scene can
                be found.
        """
        # scene name and start frame id of the requested item
        scene, start_id = self.start_img_ids[idx]

        # Cap the number of attempts so that a bad dataset cannot stall
        # distributed training in an endless loop.
        max_retries = 100

        for attempt in range(1, max_retries + 1):
            # warn once, halfway through the retry budget
            if attempt == 50:
                print(f"[HabitatHM3D WARNING] Already retried {attempt} times for idx={idx}, scene={scene}")

            # If the current scene is flagged invalid, draw a random item until it
            # belongs to a scene that is not flagged.
            scene_retry = 0
            while self.invalid_scenes[scene]:
                scene_retry += 1
                if scene_retry > len(self.start_img_ids):
                    raise RuntimeError(
                        f"[HabitatHM3D] All scenes are invalid! Cannot find valid scene after {scene_retry} attempts."
                    )
                idx = rng.integers(low=0, high=len(self.start_img_ids))
                scene, start_id = self.start_img_ids[idx]

            # all frame ids of the scene containing the start frame
            all_image_ids = self.scene_img_list[self.sceneids[start_id]]
            # positions of the selected frames within that list
            pos, ordered_video = self.get_seq_from_start_id(
                num_views, start_id, all_image_ids, rng, max_interval=self.max_interval
            )
            image_idxs = np.array(all_image_ids)[pos]

            views = []
            for view_idx in image_idxs:
                scene_id = self.sceneids[view_idx]
                scene_name = self.scenes[scene_id]
                scene_dir = osp.join(self.ROOT, scene_name)
                basename = self.images[view_idx]

                try:
                    # RGB image
                    rgb_image = imread_cv2(osp.join(scene_dir, "image_" + basename + ".png"))
                    # depth map; non-finite values are marked invalid with 0
                    depthmap = imread_cv2(
                        osp.join(scene_dir, "depth_" + basename + ".png"), cv2.IMREAD_UNCHANGED
                    )
                    depthmap = depthmap.astype(np.float32) / 1000
                    depthmap[~np.isfinite(depthmap)] = 0

                    # camera parameters; the pose is assembled as a 4x4 matrix
                    camera_params = np.load(osp.join(scene_dir, basename + ".npz"))
                    intrinsics = np.float32(camera_params["intrinsics"])
                    camera_pose = np.eye(4, dtype=np.float32)
                    camera_pose[:3, :3] = camera_params["R_cam2world"]
                    camera_pose[:3, 3] = camera_params["t_cam2world"]
                except Exception as e:
                    print(f"[HabitatHM3D] Error loading {scene} {basename}: {e}, skipping scene")
                    self.invalid_scenes[scene] = True
                    break

                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
                )

                views.append(
                    {
                        "img": rgb_image,
                        "depthmap": depthmap.astype(np.float32),
                        "camera_pose": camera_pose.astype(np.float32),
                        "camera_intrinsics": intrinsics.astype(np.float32),
                        "dataset": "habitatHM3D",
                        "label": scene_name + "_" + basename,
                        "instance": f"{str(idx)}_{str(view_idx)}",
                        "is_metric": self.is_metric,
                        "is_video": ordered_video,
                        "quantile": np.array(0.98, dtype=np.float32),
                        "img_mask": True,
                        "ray_mask": False,
                        "camera_only": True,
                        "depth_only": False,
                        "single_view": False,
                        "reset": False,
                    }
                )
            else:
                # Reached only when no load failed; succeed only with a full set of views.
                if len(views) == num_views:
                    return views

        # retry budget exhausted
        raise RuntimeError(
            f"[HabitatHM3D] Failed to get valid views after {max_retries} retries. "
            f"idx={idx}, scene={scene}, num_views={num_views}. "
            f"This may indicate insufficient valid frames in the dataset."
        )
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/hoi4d.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+ sys.path.append(osp.join(osp.dirname(__file__), '..','..'))
8
+ from tqdm import tqdm
9
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
10
+ from dust3r.utils.image import imread_cv2
11
+
12
+
13
class HOI4D_Multi(BaseMultiViewDataset):
    """Dataset of independent single views read from ``ROOT/<scene>``.

    For each frame ``<name>`` the loader reads ``rgb/<name>.png``,
    ``depth/<name>.npy`` and ``cam/<name>.npz`` (key ``intrinsics``). No camera pose
    is loaded; an identity matrix is used as a placeholder. The views of one item
    are drawn at random from the whole image list.
    """

    def __init__(self, *args, ROOT, **kwargs):
        # These attributes are set before the base constructor runs.
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        super().__init__(*args, **kwargs)
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Collect ``(scene, basename)`` for every ``.png`` under ``<scene>/rgb``.

        Sets ``self.img_names``. Returns ``None``.
        """
        scenes = os.listdir(self.ROOT)
        img_names = []
        for scene in scenes:
            scene_dir = osp.join(self.ROOT, scene)
            rgb_dir = osp.join(scene_dir, 'rgb')
            basenames = sorted([f[:-4] for f in os.listdir(rgb_dir) if f.endswith('.png')])
            img_names.extend([(scene, basename) for basename in basenames])

        self.img_names = img_names

    def __len__(self):
        return len(self.img_names)

    def get_image_num(self):
        """Total number of indexed images."""
        return len(self.img_names)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` randomly chosen views.

        The selection uses a private generator seeded from ``rng`` and ``idx``. When
        a file fails to load, a fresh selection is drawn, up to 100 times.

        Returns:
            A list of ``num_views`` dicts, one per view.

        Raises:
            RuntimeError: If no complete set of views could be loaded in 100 attempts.
        """
        new_seed = rng.integers(0, 2**32) + idx
        new_rng = np.random.default_rng(new_seed)

        # Bounded retries: a dataset that cannot be read must fail loudly rather
        # than loop forever.
        max_retries = 100
        for _ in range(max_retries):
            img_names = new_rng.choice(self.img_names, num_views, replace=False)

            views = []
            for img_name in img_names:
                scene, img_name = img_name
                try:
                    # RGB image, depth map (non-finite values set to 0) and intrinsics
                    rgb_image = imread_cv2(osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"))
                    depthmap = np.load(osp.join(self.ROOT, scene, "depth", f"{img_name}.npy"))
                    depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)

                    intrinsics = np.load(osp.join(self.ROOT, scene, "cam", f"{img_name}.npz"))["intrinsics"]
                except Exception as e:
                    # Only ordinary errors are handled here: KeyboardInterrupt and
                    # SystemExit must propagate so the process can be stopped.
                    print(f"Error loading {scene} {img_name}: {e}, skipping")
                    break
                # camera pose is not provided, placeholder
                camera_pose = np.eye(4)

                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name)

                views.append(dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset='HOI4D',
                    label=img_name,
                    instance=osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"),
                    is_metric=self.is_metric,
                    is_video=False,
                    quantile=np.array(0.99, dtype=np.float32),
                    img_mask=True,
                    ray_mask=False,
                    camera_only=False,
                    depth_only=False,
                    single_view=True,
                    reset=True,
                ))
            if len(views) == num_views:
                return views

        raise RuntimeError(
            f"[HOI4D] Failed to get valid views after {max_retries} retries. "
            f"idx={idx}, num_views={num_views}."
        )
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mapfree.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import numpy as np
3
+ import cv2
4
+ import numpy as np
5
+ import itertools
6
+ import os
7
+ import sys
8
+ import pickle
9
+ import h5py
10
+ from tqdm import tqdm
11
+
12
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
13
+
14
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
15
+ from dust3r.utils.image import imread_cv2
16
+
17
+
18
class MapFree_Multi(BaseMultiViewDataset):
    """Multi-view dataset read from ``ROOT/<scene>/dense<k>/rgb/frame_<id>.jpg``.

    Depth, camera and sky-mask files are located by replacing ``rgb`` in the image
    path with ``depth`` / ``cam`` / ``sky_mask``. One dataset item is one image
    group taken from the per-scene ``metadata.pkl``. The index is cached in
    ``ROOT/cached_metadata_50_col_only.h5``.
    """

    def __init__(self, ROOT, *args, **kwargs):
        # These attributes are set before the base constructor runs.
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 30
        super().__init__(*args, **kwargs)

        self._load_data()

    def imgid2path(self, img_id, scene):
        """Return the RGB path for ``img_id = (sequence id, frame id)`` in ``scene``."""
        first_seq_id, first_frame_id = img_id
        return os.path.join(
            self.ROOT,
            scene,
            f"dense{first_seq_id}",
            "rgb",
            f"frame_{first_frame_id:05d}.jpg",
        )

    def path2imgid(self, subscene, filename):
        """Inverse of ``imgid2path``: parse ``dense<k>`` and ``frame_<id>.jpg``."""
        first_seq_id = int(subscene[5:])  # strip the "dense" prefix
        first_frame_id = int(filename[6:-4])  # strip "frame_" and ".jpg"
        return [first_seq_id, first_frame_id]

    def _load_data(self):
        """Load the index from the HDF5 cache, or build it and write the cache.

        Sets ``scenes``, ``sceneids``, ``scope`` (start offset of each group inside
        ``groups``), ``video_flags``, ``groups`` (rows of
        ``(sequence id, frame id, score)``), ``id_ranges`` and ``images``.
        """
        cache_file = f"{self.ROOT}/cached_metadata_50_col_only.h5"
        if os.path.exists(cache_file):
            print(f"Loading cached metadata from {cache_file}")
            with h5py.File(cache_file, "r") as hf:
                self.scenes = [name.decode("utf-8") for name in hf["scenes"][:]]
                self.sceneids = hf["sceneids"][:]
                self.scope = hf["scope"][:]
                self.video_flags = hf["video_flags"][:]
                self.groups = hf["groups"][:]
                self.id_ranges = hf["id_ranges"][:]
                self.images = hf["images"][:]
            return

        scene_dirs = sorted(
            [
                d
                for d in os.listdir(self.ROOT)
                if os.path.isdir(os.path.join(self.ROOT, d))
            ]
        )
        scenes = []
        sceneids = []
        groups = []
        scope = []
        images = []
        id_ranges = []
        is_video = []
        start = 0
        j = 0
        offset = 0

        for scene in tqdm(scene_dirs):
            scenes.append(scene)
            # video sequences
            subscenes = sorted(
                [
                    d
                    for d in os.listdir(os.path.join(self.ROOT, scene))
                    if d.startswith("dense")
                ]
            )
            id_range_subscenes = []
            for subscene in subscenes:
                rgb_paths = sorted(
                    [
                        d
                        for d in os.listdir(
                            os.path.join(self.ROOT, scene, subscene, "rgb")
                        )
                        if d.endswith(".jpg")
                    ]
                )
                assert (
                    len(rgb_paths) > 0
                ), f"{os.path.join(self.ROOT, scene, subscene)} is empty."
                num_imgs = len(rgb_paths)
                images.extend(
                    [self.path2imgid(subscene, rgb_path) for rgb_path in rgb_paths]
                )
                id_range_subscenes.append((offset, offset + num_imgs))
                offset += num_imgs

            # image collections
            # NOTE(review): pickle executes arbitrary code on load; metadata.pkl must
            # come from a trusted source.
            with open(os.path.join(self.ROOT, scene, "metadata.pkl"), "rb") as fh:
                metadata = pickle.load(fh)
            ref_imgs = list(metadata.keys())
            img_groups = []
            for ref_img in ref_imgs:
                other_imgs = metadata[ref_img]
                # skip groups that cannot provide num_views images
                if len(other_imgs) + 1 < self.num_views:
                    continue
                group = [(*other_img[0], other_img[1]) for other_img in other_imgs]
                # the reference image leads its group with score 1
                group.insert(0, (*ref_img, 1))
                img_groups.append(np.array(group))
                id_ranges.append(id_range_subscenes[ref_img[0]])
                scope.append(start)
                start = start + len(group)

            num_groups = len(img_groups)
            sceneids.extend([j] * num_groups)
            groups.extend(img_groups)
            is_video.extend([False] * num_groups)
            j += 1

        self.scenes = np.array(scenes)
        self.sceneids = np.array(sceneids)
        self.scope = np.array(scope)
        self.video_flags = np.array(is_video)
        self.groups = np.concatenate(groups, 0)
        self.id_ranges = np.array(id_ranges)
        self.images = np.array(images)

        data = dict(
            scenes=self.scenes,
            sceneids=self.sceneids,
            scope=self.scope,
            video_flags=self.video_flags,
            groups=self.groups,
            id_ranges=self.id_ranges,
            images=self.images,
        )

        with h5py.File(cache_file, "w") as h5f:
            h5f.create_dataset(
                "scenes",
                data=data["scenes"].astype(object),
                dtype=h5py.string_dtype(encoding="utf-8"),
                compression="lzf",
                chunks=True,
            )
            h5f.create_dataset(
                "sceneids", data=data["sceneids"], compression="lzf", chunks=True
            )
            h5f.create_dataset(
                "scope", data=data["scope"], compression="lzf", chunks=True
            )
            h5f.create_dataset(
                "video_flags",
                data=data["video_flags"],
                compression="lzf",
                chunks=True,
            )
            h5f.create_dataset(
                "groups", data=data["groups"], compression="lzf", chunks=True
            )
            h5f.create_dataset(
                "id_ranges", data=data["id_ranges"], compression="lzf", chunks=True
            )
            h5f.create_dataset(
                "images", data=data["images"], compression="lzf", chunks=True
            )

    def __len__(self):
        return len(self.scope)

    def get_image_num(self):
        """Total number of indexed images."""
        return len(self.images)

    def get_stats(self):
        return f"{len(self)} groups of views"

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` views for item ``idx``.

        With probability 0.6 the views are taken from the video sequence referenced
        by ``id_ranges[idx]``; otherwise they are drawn from the item's image group
        with probabilities proportional to the group scores.

        Returns:
            A list of ``num_views`` dicts, one per view.
        """
        scene = self.scenes[self.sceneids[idx]]
        if rng.random() < 0.6:
            ids = np.arange(self.id_ranges[idx][0], self.id_ranges[idx][1])
            cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3)
            start_ids = ids[: len(ids) - cut_off + 1]
            start_id = rng.choice(start_ids)
            pos, ordered_video = self.get_seq_from_start_id(
                num_views,
                start_id,
                ids.tolist(),
                rng,
                max_interval=self.max_interval,
                video_prob=0.8,
                fix_interval_prob=0.5,
                block_shuffle=16,
            )
            ids = np.array(ids)[pos]
            image_idxs = self.images[ids]
        else:
            ordered_video = False
            seq_start_index = self.scope[idx]
            # The last group runs to the end of `groups` (a None stop does that).
            seq_end_index = self.scope[idx + 1] if idx < len(self.scope) - 1 else None
            image_idxs = self.groups[seq_start_index:seq_end_index]
            image_idxs, overlap_scores = image_idxs[:, :2], image_idxs[:, 2]
            # Sample with replacement when repeats are allowed or when too few
            # entries have a positive score.
            replace = bool(
                self.allow_repeat
                or len(overlap_scores[overlap_scores > 0]) < num_views
            )
            image_idxs = rng.choice(
                image_idxs,
                num_views,
                replace=replace,
                p=overlap_scores / np.sum(overlap_scores),
            )
            image_idxs = image_idxs.astype(np.int64)

        views = []
        for v, view_idx in enumerate(image_idxs):
            img_path = self.imgid2path(view_idx, scene)
            depth_path = img_path.replace("rgb", "depth").replace(".jpg", ".npy")
            cam_path = img_path.replace("rgb", "cam").replace(".jpg", ".npz")
            sky_mask_path = img_path.replace("rgb", "sky_mask")
            image = imread_cv2(img_path)
            depthmap = np.load(depth_path)
            camera_params = np.load(cam_path)
            sky_mask = cv2.imread(sky_mask_path, cv2.IMREAD_UNCHANGED) >= 127

            intrinsics = camera_params["intrinsic"].astype(np.float32)
            camera_pose = camera_params["pose"].astype(np.float32)

            # sky pixels get -1; depths above 400, non-finite values, and depths
            # above the 98th percentile of the positive values are zeroed
            depthmap[sky_mask] = -1.0
            depthmap[depthmap > 400.0] = 0.0
            depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)
            threshold = (
                np.percentile(depthmap[depthmap > 0], 98)
                if depthmap[depthmap > 0].size > 0
                else 0
            )
            depthmap[depthmap > threshold] = 0.0

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(img_path)
            )
            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="MapFree",
                    label=img_path,
                    is_metric=self.is_metric,
                    instance=img_path,
                    is_video=ordered_video,
                    quantile=np.array(0.96, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/mvs_synth.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+ from tqdm import tqdm
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_pil
12
+
13
+
14
class MVS_Synth_Multi(BaseMultiViewDataset):
    """Multi-view dataset over MVS-Synth scenes stored under ``ROOT``.

    Each scene directory is read as ``rgb/<name>.jpg``, ``depth/<name>.npy``
    and ``cam/<name>.npz`` (keys ``pose`` and ``intrinsics``). Frames are
    indexed globally across scenes; one sample is a sequence of ``num_views``
    frames taken from a single scene.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset options, initialise the base class, build the index.

        Args:
            ROOT: Directory holding one sub-directory per scene.
            *args: Forwarded unchanged to ``BaseMultiViewDataset``.
            **kwargs: Forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = False
        self.max_interval = 4
        super().__init__(*args, **kwargs)
        # _load_data() fills the index attributes in place and returns None;
        # the attribute is kept so existing lookups of it keep working.
        self.loaded_data = self._load_data()
        print('DATA: mvs_synth', len(self))

    def _load_data(self):
        """Scan ``ROOT`` and build the flat frame index used by ``_get_views``.

        Sets ``scenes``, ``sceneids`` (frame id -> scene index), ``images``
        (frame id -> basename), ``start_img_ids`` (frame ids allowed to start
        a sequence) and ``scene_img_list`` (scene index -> its frame ids).
        """
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # sample-index -> frame mapping identical on every process/machine.
        # Plain files sitting in ROOT are ignored instead of crashing below.
        self.scenes = sorted(
            name
            for name in os.listdir(self.ROOT)
            if osp.isdir(osp.join(self.ROOT, name))
        )

        # Minimum number of frames a scene needs to provide one sequence.
        # NOTE(review): num_views / allow_repeat are presumably set by the
        # base-class __init__ - confirm.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        j = 0  # index of the scene among the *kept* scenes
        for scene in tqdm(self.scenes):
            rgb_dir = osp.join(self.ROOT, scene, "rgb")
            basenames = sorted(
                f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".jpg")
            )
            num_imgs = len(basenames)

            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue

            # Global frame ids of this scene; only the first
            # num_imgs - cut_off + 1 of them may start a sequence.
            img_ids = list(np.arange(num_imgs) + offset)
            start_img_ids.extend(img_ids[: num_imgs - cut_off + 1])
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # Shift the global id space past this scene.
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        """Return the number of valid sequence start frames."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames of one scene, starting from sample ``idx``.

        Args:
            idx: Index into ``start_img_ids``.
            resolution: Passed through to ``_crop_resize_if_necessary``.
            rng: Random generator handed to the base-class helpers.
            num_views: Number of views to return.

        Returns:
            A list of ``num_views`` view dictionaries.
        """
        start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        # Base-class helper; returns positions into all_image_ids plus a flag
        # that is stored below as ``is_video``.
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=self.max_interval,
            video_prob=1.0,
            fix_interval_prob=1.0,
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scenes[scene_id])
            rgb_dir = osp.join(scene_dir, "rgb")
            depth_dir = osp.join(scene_dir, "depth")
            cam_dir = osp.join(scene_dir, "cam")

            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_pil(osp.join(rgb_dir, basename + ".jpg"))
            # Load depthmap; non-finite values are marked invalid (0)
            depthmap = np.load(osp.join(depth_dir, basename + ".npy"))
            depthmap[~np.isfinite(depthmap)] = 0
            # Drop everything above the 98th percentile of the positive
            # depths, then anything still larger than 1000.
            positive = depthmap[depthmap > 0]
            threshold = np.percentile(positive, 98) if positive.size > 0 else 0
            depthmap[depthmap > threshold] = 0.0
            depthmap[depthmap > 1000] = 0.0

            cam = np.load(osp.join(cam_dir, basename + ".npz"))
            camera_pose = cam["pose"]
            intrinsics = cam["intrinsics"]
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.8, 0.15, 0.05]
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="MVS_Synth",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=osp.join(rgb_dir, basename + ".jpg"),
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/omniobject3d.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+ import json
8
+
9
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
10
+ from tqdm import tqdm
11
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
12
+ from dust3r.utils.image import imread_cv2
13
+ import re
14
+
15
+
16
def extract_number(filename):
    """Return the first run of digits in ``filename`` as an int, or 0 if none."""
    found = re.search(r"\d+", filename)
    return int(found.group()) if found else 0
21
+
22
+
23
class OmniObject3D_Multi(BaseMultiViewDataset):
    """Multi-view dataset over OmniObject3D scenes stored under ``ROOT``.

    Each scene directory is read as ``rgb/<name>.png``, ``depth/<name>.npy``
    and ``cam/<name>.npz`` (keys ``pose`` and ``intrinsics``). ``ROOT`` also
    holds ``scale.json``, a per-scene factor used to rescale depth and the
    camera translation.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset options, initialise the base class, build the index.

        Args:
            ROOT: Directory holding ``scale.json`` and one sub-directory per scene.
            *args: Forwarded unchanged to ``BaseMultiViewDataset``.
            **kwargs: Forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = False
        # NOTE(review): the original line carried a trailing "# True" comment,
        # presumably the alternative setting - confirm before changing.
        self.is_metric = False
        super().__init__(*args, **kwargs)

        # _load_data() fills the index attributes in place and returns None.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Scan ``ROOT`` and build the flat frame index used by ``_get_views``.

        Sets ``scales``, ``scenes``, ``sceneids`` (frame id -> scene index),
        ``images`` (frame id -> basename), ``start_img_ids`` (list of
        ``(scene, frame id)`` pairs) and ``scene_img_list``.
        """
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # sample-index -> frame mapping identical on every process/machine.
        self.scenes = sorted(
            d
            for d in os.listdir(self.ROOT)
            if os.path.isdir(os.path.join(self.ROOT, d)) and not d.startswith('.')
        )
        with open(os.path.join(self.ROOT, "scale.json"), "r") as f:
            self.scales = json.load(f)

        # Minimum number of frames a scene needs to provide one sequence.
        # NOTE(review): num_views / allow_repeat are presumably set by the
        # base-class __init__ - confirm.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        j = 0  # index of the scene among the *kept* scenes
        for scene in tqdm(self.scenes):
            rgb_dir = osp.join(self.ROOT, scene, "rgb")
            # Order frames by the first number found in the file name.
            basenames = sorted(
                [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")],
                key=extract_number,
            )
            num_imgs = len(basenames)

            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue

            img_ids = list(np.arange(num_imgs) + offset)
            # Only the first num_imgs - cut_off + 1 frames may start a sequence.
            start_img_ids.extend(
                (scene, first) for first in img_ids[: num_imgs - cut_off + 1]
            )
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # Shift the global id space past this scene.
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        """Return the number of valid sequence start frames."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames of one scene, starting from sample ``idx``.

        Args:
            idx: Index into ``start_img_ids``.
            resolution: Passed through to ``_crop_resize_if_necessary``.
            rng: Random generator handed to the base-class helpers.
            num_views: Number of views to return.

        Returns:
            A list of ``num_views`` view dictionaries.
        """
        # The scene name stored next to the start id is not needed here; the
        # scene is recovered per view through ``sceneids``.
        _scene, start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views, start_id, all_image_ids, rng, max_interval=100, video_prob=0.0
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scenes[scene_id])
            rgb_dir = osp.join(scene_dir, "rgb")
            depth_dir = osp.join(scene_dir, "depth")
            cam_dir = osp.join(scene_dir, "cam")

            basename = self.images[view_idx]

            # Load RGB image, depth and camera parameters
            rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png"))
            depthmap = np.load(osp.join(depth_dir, basename + ".npy"))
            cam = np.load(osp.join(cam_dir, basename + ".npz"))
            camera_pose = cam["pose"]
            intrinsics = cam["intrinsics"]

            # Apply the same per-scene rescaling to the depth and to the
            # translation part of the pose so they stay consistent.
            scale = self.scales[self.scenes[scene_id]]
            depthmap = depthmap / scale / 1000.0
            camera_pose[:3, 3] = camera_pose[:3, 3] / scale / 1000.0

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.8, 0.15, 0.05]
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="OmniObject3D",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=f"{str(idx)}_{str(view_idx)}",
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/pointodyssey.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+ from tqdm import tqdm
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2
12
+
13
+
14
class PointOdyssey_Multi(BaseMultiViewDataset):
    """Multi-view dataset over a whitelist of PointOdyssey scenes.

    Scenes live under ``ROOT/<split>/<scene>`` and are read as
    ``rgb/<name>.jpg``, ``depth/<name>.npy`` and ``cam/<name>.npz`` (keys
    ``pose`` and ``intrinsics``).
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset options, initialise the base class, build the index.

        Args:
            ROOT: Directory holding one sub-directory per split.
            *args: Forwarded unchanged to ``BaseMultiViewDataset``.
            **kwargs: Forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 4
        super().__init__(*args, **kwargs)
        assert self.split in ["train", "test", "val"]
        # Only these scenes are indexed. Deliberately left out by the
        # original author: 'cab_h_bench_3rd', 'cab_h_bench_ego1',
        # 'cab_h_bench_ego2'.
        self.scenes_to_use = [
            "cnb_dlab_0215_3rd",
            "cnb_dlab_0215_ego1",
            "cnb_dlab_0225_3rd",
            "cnb_dlab_0225_ego1",
            "dancing",
            "dancingroom0_3rd",
            "footlab_3rd",
            "footlab_ego1",
            "footlab_ego2",
            "girl",
            "girl_egocentric",
            "human_egocentric",
            "human_in_scene",
            "human_in_scene1",
            "kg",
            "kg_ego1",
            "kg_ego2",
            "kitchen_gfloor",
            "kitchen_gfloor_ego1",
            "kitchen_gfloor_ego2",
            "scene_carb_h_tables",
            "scene_carb_h_tables_ego1",
            "scene_carb_h_tables_ego2",
            "scene_j716_3rd",
            "scene_j716_ego1",
            "scene_j716_ego2",
            "scene_recording_20210910_S05_S06_0_3rd",
            "scene_recording_20210910_S05_S06_0_ego2",
            "scene1_0129",
            "scene1_0129_ego",
            "seminar_h52_3rd",
            "seminar_h52_ego1",
            "seminar_h52_ego2",
        ]
        # _load_data() fills the index attributes in place and returns None.
        self.loaded_data = self._load_data(self.split)

    def _load_data(self, split):
        """Scan ``ROOT/<split>`` and build the flat frame index.

        Args:
            split: Name of the sub-directory of ``ROOT`` to index.

        Sets ``scenes``, ``sceneids`` (frame id -> scene index), ``images``
        (frame id -> basename), ``start_img_ids`` and ``scene_img_list``.
        """
        root = os.path.join(self.ROOT, split)
        self.scenes = []

        wanted = set(self.scenes_to_use)  # O(1) whitelist lookups
        # Minimum number of frames a scene needs to provide one sequence.
        # NOTE(review): num_views / allow_repeat are presumably set by the
        # base-class __init__ - confirm.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        j = 0  # index of the scene among the *kept* scenes
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # sample-index -> frame mapping identical on every process/machine.
        for scene in tqdm(sorted(os.listdir(root))):
            if scene not in wanted:
                continue
            rgb_dir = osp.join(root, scene, "rgb")
            basenames = sorted(
                f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".jpg")
            )
            num_imgs = len(basenames)

            # Check the length first so no start-id slice is built for a
            # scene that is going to be skipped anyway.
            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue

            img_ids = list(np.arange(num_imgs) + offset)
            start_img_ids.extend(img_ids[: num_imgs - cut_off + 1])
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # Shift the global id space past this scene.
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        """Return the number of valid sequence start frames."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames of one scene, starting from sample ``idx``.

        Args:
            idx: Index into ``start_img_ids``.
            resolution: Passed through to ``_crop_resize_if_necessary``.
            rng: Random generator handed to the base-class helpers.
            num_views: Number of views to return.

        Returns:
            A list of ``num_views`` view dictionaries.
        """
        start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=self.max_interval,
            video_prob=1.0,
            fix_interval_prob=1.0,
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id])
            rgb_dir = osp.join(scene_dir, "rgb")
            depth_dir = osp.join(scene_dir, "depth")
            cam_dir = osp.join(scene_dir, "cam")

            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".jpg"))
            # Load depthmap; non-finite and >1000 values are marked invalid (0)
            depthmap = np.load(osp.join(depth_dir, basename + ".npy"))
            depthmap[~np.isfinite(depthmap)] = 0
            depthmap[depthmap > 1000] = 0.0

            cam = np.load(osp.join(cam_dir, basename + ".npz"))
            camera_pose = cam["pose"]
            intrinsics = cam["intrinsics"]
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.9, 0.05, 0.05]
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="PointOdyssey",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=osp.join(rgb_dir, basename + ".jpg"),
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/realestate10k.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+ from tqdm import tqdm
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2
12
+
13
+
14
class RE10K_Multi(BaseMultiViewDataset):
    """Camera-only multi-view dataset over RealEstate10K scenes under ``ROOT``.

    Each scene directory is read as ``rgb/<n>.png`` and ``cam/<n>.npz`` (keys
    ``pose`` and ``intrinsics``) with integer basenames. No depth is stored;
    views carry an all-ones placeholder depth map and ``camera_only=True``.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset options, initialise the base class, build the index.

        Args:
            ROOT: Directory holding one sub-directory per scene.
            *args: Forwarded unchanged to ``BaseMultiViewDataset``.
            **kwargs: Forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = False
        self.max_interval = 128
        super().__init__(*args, **kwargs)
        # _load_data() fills the index attributes in place and returns None.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Scan ``ROOT`` and build the flat frame index used by ``_get_views``.

        Sets ``scenes``, ``sceneids`` (frame id -> scene index), ``images``
        (frame id -> basename), ``start_img_ids`` (list of ``(scene, frame
        id)`` pairs), ``scene_img_list`` and ``invalid_scenes`` (scene ->
        "failed to load" flag, initially all False).
        """
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # sample-index -> frame mapping identical on every process/machine.
        # Plain files sitting in ROOT are ignored instead of crashing below.
        self.scenes = sorted(
            name
            for name in os.listdir(self.ROOT)
            if osp.isdir(osp.join(self.ROOT, name))
        )

        # Minimum number of frames a scene needs to provide one sequence.
        # NOTE(review): num_views / allow_repeat are presumably set by the
        # base-class __init__ - confirm.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        j = 0  # index of the scene among the *kept* scenes
        for scene in tqdm(self.scenes):
            rgb_dir = osp.join(self.ROOT, scene, "rgb")
            # Basenames are integers; sort numerically, not lexically.
            basenames = sorted(
                [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")],
                key=int,
            )
            num_imgs = len(basenames)

            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue

            img_ids = list(np.arange(num_imgs) + offset)
            # Only the first num_imgs - cut_off + 1 frames may start a sequence.
            start_img_ids.extend(
                (scene, first) for first in img_ids[: num_imgs - cut_off + 1]
            )
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # Shift the global id space past this scene.
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

        self.invalid_scenes = {scene: False for scene in self.scenes}

    def __len__(self):
        """Return the number of valid sequence start frames."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames, resampling when a scene fails to load.

        A scene whose image or camera file cannot be read is flagged in
        ``invalid_scenes`` and a different start frame is drawn with ``rng``.

        Args:
            idx: Index into ``start_img_ids``.
            resolution: Passed through to ``_crop_resize_if_necessary``.
            rng: Random generator used for resampling and by the base helpers.
            num_views: Number of views to return.

        Returns:
            A list of ``num_views`` view dictionaries.

        Raises:
            RuntimeError: If every indexed scene has been flagged invalid, so
                no replacement sample can be drawn.
        """
        invalid_seq = True
        scene, start_id = self.start_img_ids[idx]

        while invalid_seq:
            if self.invalid_scenes[scene]:
                # Without this guard the resampling loop below would spin
                # forever once no loadable scene is left.
                if all(self.invalid_scenes.values()):
                    raise RuntimeError(
                        "RE10K_Multi: every scene failed to load; "
                        "cannot draw a valid sample"
                    )
                while self.invalid_scenes[scene]:
                    idx = rng.integers(low=0, high=len(self.start_img_ids))
                    scene, start_id = self.start_img_ids[idx]

            all_image_ids = self.scene_img_list[self.sceneids[start_id]]
            pos, ordered_video = self.get_seq_from_start_id(
                num_views, start_id, all_image_ids, rng, max_interval=self.max_interval
            )
            image_idxs = np.array(all_image_ids)[pos]

            views = []
            for view_idx in image_idxs:
                scene_id = self.sceneids[view_idx]
                scene_dir = osp.join(self.ROOT, self.scenes[scene_id])
                rgb_dir = osp.join(scene_dir, "rgb")
                cam_dir = osp.join(scene_dir, "cam")

                basename = self.images[view_idx]

                try:
                    # Load RGB image
                    rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png"))
                    # No depth is available: use an all-ones placeholder
                    depthmap = np.ones_like(rgb_image[..., 0], dtype=np.float32)
                    cam = np.load(osp.join(cam_dir, basename + ".npz"))
                    intrinsics = cam["intrinsics"]
                    camera_pose = cam["pose"]
                except Exception as err:
                    # ``Exception`` rather than a bare ``except`` so Ctrl-C and
                    # SystemExit still propagate; the cause is reported too.
                    print(f"Error loading {scene} {basename}, skipping")
                    print(f"  reason: {err!r}")
                    self.invalid_scenes[scene] = True
                    break

                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
                )

                views.append(
                    dict(
                        img=rgb_image,
                        depthmap=depthmap.astype(np.float32),
                        camera_pose=camera_pose.astype(np.float32),
                        camera_intrinsics=intrinsics.astype(np.float32),
                        dataset="realestate10k",
                        label=self.scenes[scene_id] + "_" + basename,
                        instance=f"{str(idx)}_{str(view_idx)}",
                        is_metric=self.is_metric,
                        is_video=ordered_video,
                        quantile=np.array(0.98, dtype=np.float32),
                        img_mask=True,
                        ray_mask=False,
                        camera_only=True,
                        depth_only=False,
                        single_view=False,
                        reset=False,
                    )
                )
            # A short list means the loop above bailed out on a load error.
            if len(views) == num_views:
                invalid_seq = False
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannet.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+ from tqdm import tqdm
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2, imread_pil
12
+
13
+
14
class ScanNet_Multi(BaseMultiViewDataset):
    """Multi-view dataset over ScanNet scenes stored under ``ROOT``.

    Scenes live in ``ROOT/scans_train`` (train split) or ``ROOT/scans_test``
    (any other split). Each scene provides ``new_scene_metadata.npz`` (key
    ``images``), ``color/<name>.jpg``, ``depth/<name>.png`` and
    ``cam/<name>.npz`` (keys ``pose`` and ``intrinsics``).
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset options, initialise the base class, build the index.

        Args:
            ROOT: Directory holding ``scans_train`` / ``scans_test``.
            *args: Forwarded unchanged to ``BaseMultiViewDataset``.
            **kwargs: Forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 30
        super().__init__(*args, **kwargs)

        # _load_data() fills the index attributes in place and returns None.
        self.loaded_data = self._load_data(self.split)
        print('DATA: scannet', len(self))

    def _load_data(self, split):
        """Read per-scene metadata and build the flat frame index.

        Args:
            split: ``"train"`` selects ``scans_train``; anything else selects
                ``scans_test``.

        Sets ``scene_root``, ``scenes``, ``sceneids`` (frame id -> scene
        index), ``images`` (frame id -> basename), ``start_img_ids`` and
        ``scene_img_list``.
        """
        self.scene_root = osp.join(
            self.ROOT, "scans_train" if split == "train" else "scans_test"
        )
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # sample-index -> frame mapping identical on every process/machine.
        self.scenes = sorted(
            scene for scene in os.listdir(self.scene_root) if scene.startswith("scene")
        )

        # Minimum number of frames a scene needs to provide one sequence.
        # NOTE(review): num_views / allow_repeat are presumably set by the
        # base-class __init__ - confirm.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0
        scenes = []
        sceneids = []
        scene_img_list = []
        images = []
        start_img_ids = []

        j = 0  # index of the scene among the *kept* scenes
        for scene in tqdm(self.scenes):
            scene_dir = osp.join(self.scene_root, scene)
            # NOTE(review): allow_pickle=True is kept from the original; it is
            # only safe for metadata files from a trusted source.
            with np.load(
                osp.join(scene_dir, "new_scene_metadata.npz"), allow_pickle=True
            ) as data:
                basenames = data["images"]
            num_imgs = len(basenames)

            # Check the length first so no start-id slice is built for a
            # scene that is going to be skipped anyway.
            if num_imgs < cut_off:
                print(f"Skipping {scene}")
                continue

            img_ids = list(np.arange(num_imgs) + offset)
            start_img_ids.extend(img_ids[: num_imgs - cut_off + 1])
            sceneids.extend([j] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)
            scene_img_list.append(img_ids)

            # Shift the global id space past this scene.
            offset += num_imgs
            j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.start_img_ids = start_img_ids
        self.scene_img_list = scene_img_list

    def __len__(self):
        """Return the number of valid sequence start frames."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames of one scene, starting from sample ``idx``.

        Args:
            idx: Index into ``start_img_ids``.
            resolution: Passed through to ``_crop_resize_if_necessary``.
            rng: Random generator handed to the base-class helpers.
            num_views: Number of views to return.

        Returns:
            A list of ``num_views`` view dictionaries.
        """
        start_id = self.start_img_ids[idx]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=self.max_interval,
            video_prob=0.6,
            fix_interval_prob=0.6,
            block_shuffle=16,
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.scene_root, self.scenes[scene_id])
            rgb_dir = osp.join(scene_dir, "color")
            depth_dir = osp.join(scene_dir, "depth")
            cam_dir = osp.join(scene_dir, "cam")

            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_pil(osp.join(rgb_dir, basename + ".jpg"))
            # Load depthmap: stored values are divided by 1000
            # (presumably millimetres -> metres; confirm against the data).
            depthmap = imread_cv2(
                osp.join(depth_dir, basename + ".png"), cv2.IMREAD_UNCHANGED
            )
            depthmap = depthmap.astype(np.float32) / 1000
            depthmap[~np.isfinite(depthmap)] = 0  # invalid

            cam = np.load(osp.join(cam_dir, basename + ".npz"))
            camera_pose = cam["pose"]
            intrinsics = cam["intrinsics"]
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="ScanNet",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=f"{str(idx)}_{str(view_idx)}",
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(0.98, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/scannetpp.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2, imread_pil
12
+
13
+
14
class ScanNetpp_Multi(BaseMultiViewDataset):
    """Multi-view dataset over ScanNet++ scenes stored under ``ROOT``.

    ``ROOT/all_metadata.npz`` lists the scenes. Each scene provides
    ``new_scene_metadata.npz`` (keys ``images``, ``intrinsics``,
    ``trajectories``, ``image_collection``), ``images/<name>.jpg`` and
    ``depth/<name>.png``. Images whose name starts with ``DSC`` and those
    starting with ``frame`` are tracked as two separate id lists.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store the dataset options, initialise the base class, build the index.

        Args:
            ROOT: Directory holding ``all_metadata.npz`` and the scenes.
            *args: Forwarded unchanged to ``BaseMultiViewDataset``.
            **kwargs: Forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 3
        super().__init__(*args, **kwargs)
        assert self.split == "train"
        # _load_data() fills the index attributes in place and returns None.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Read all scene metadata and build the group / video-range index.

        Sets ``scenes``, ``sceneids`` (frame id -> scene index), ``images``,
        ``intrinsics``, ``trajectories`` (both concatenated over the kept
        scenes), ``groups`` (sorted global frame ids per image collection),
        ``id_ranges`` (the frame-id list paired with each group) and
        ``image_num``.
        """
        with np.load(osp.join(self.ROOT, "all_metadata.npz")) as data:
            self.scenes = data["scenes"]

        # Same cut-off rule as the other datasets: the minimum number of
        # frames a group / video list needs to be usable.
        # NOTE(review): num_views / allow_repeat are presumably set by the
        # base-class __init__ - confirm.
        min_group_len = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        offset = 0
        scenes = []
        sceneids = []
        images = []
        intrinsics = []
        trajectories = []
        groups = []
        id_ranges = []
        j = 0  # index of the scene among the *kept* scenes
        self.image_num = 0
        for scene in self.scenes:
            scene_dir = osp.join(self.ROOT, scene)
            # allow_pickle is required: "image_collection" is read through
            # .item(), i.e. it is a pickled Python object. Only load metadata
            # files from a trusted source.
            with np.load(
                osp.join(scene_dir, "new_scene_metadata.npz"), allow_pickle=True
            ) as data:
                imgs = data["images"]
                self.image_num += len(imgs)
                img_ids = np.arange(len(imgs)).tolist()
                intrins = data["intrinsics"]
                traj = data["trajectories"]
                # Basenames present on disk (4-character extension stripped);
                # a set gives O(1) membership tests in the filters below.
                imgs_on_disk = {
                    fname[:-4] for fname in os.listdir(osp.join(scene_dir, "images"))
                }

                dslr_ids = [
                    i + offset
                    for i in img_ids
                    if imgs[i].startswith("DSC") and imgs[i] in imgs_on_disk
                ]
                iphone_ids = [
                    i + offset
                    for i in img_ids
                    if imgs[i].startswith("frame") and imgs[i] in imgs_on_disk
                ]

                num_imgs = len(imgs)
                # All "DSC" ids are expected to precede all "frame" ids. The
                # comparison is skipped when either list is empty, where
                # max()/min() would raise ValueError instead.
                assert (
                    not dslr_ids
                    or not iphone_ids
                    or max(dslr_ids) < min(iphone_ids)
                )
                assert "image_collection" in data

                img_groups = []
                img_id_ranges = []

                for ref_id, group in data["image_collection"].item().items():
                    if len(group) + 1 < min_group_len:
                        continue
                    # The group is the reference image plus its listed
                    # members, stored as sorted global frame ids.
                    member_ids = [ref_id] + [entry[0] for entry in group]
                    group_ids = sorted(int(m + offset) for m in member_ids)

                    # Pick the frame list paired with this group: a "frame"
                    # reference is paired with the "DSC" ids and vice versa.
                    if imgs[ref_id].startswith("frame"):
                        video_ids = dslr_ids
                    else:
                        video_ids = iphone_ids

                    # Only keep the group when that frame list is long enough.
                    if len(video_ids) >= min_group_len:
                        img_groups.append(group_ids)
                        img_id_ranges.append(video_ids)

                if len(img_groups) == 0:
                    print(f"Skipping {scene}")
                    continue
                scenes.append(scene)
                sceneids.extend([j] * num_imgs)
                images.extend(imgs)
                intrinsics.append(intrins)
                trajectories.append(traj)

                # offset groups
                groups.extend(img_groups)
                id_ranges.extend(img_id_ranges)
                offset += num_imgs
                j += 1

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.intrinsics = np.concatenate(intrinsics, axis=0)
        self.trajectories = np.concatenate(trajectories, axis=0)
        self.id_ranges = id_ranges
        self.groups = groups

    def __len__(self):
        """Return ten samples per group (``_get_views`` maps idx -> idx // 10)."""
        return len(self.groups) * 10

    def get_image_num(self):
        """Return the number of images listed in the scene metadata."""
        return self.image_num

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames for group ``idx // 10``.

        With probability 0.7 (when enough frames exist) a sequence is drawn
        from the group's paired frame list via ``get_seq_from_start_id``;
        otherwise frames are sampled from the group itself, and shuffled
        when the random draw exceeds 0.75.

        Args:
            idx: Sample index; ten consecutive indices share one group.
            resolution: Passed through to ``_crop_resize_if_necessary``.
            rng: Random generator used for all sampling.
            num_views: Number of views to return.

        Returns:
            A list of ``num_views`` view dictionaries.
        """
        idx = idx // 10
        image_idxs = self.groups[idx]
        rand_val = rng.random()

        image_idxs_video = self.id_ranges[idx]
        cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3)
        # Clamp at 0: a negative slice end would count from the back and
        # yield bogus start frames when the list is shorter than cut_off.
        num_starts = max(len(image_idxs_video) - cut_off + 1, 0)
        start_image_idxs = image_idxs_video[:num_starts]

        if rand_val < 0.7 and len(start_image_idxs) > 0:
            start_id = rng.choice(start_image_idxs)
            pos, ordered_video = self.get_seq_from_start_id(
                num_views,
                start_id,
                image_idxs_video,
                rng,
                max_interval=self.max_interval,
                video_prob=0.8,
                fix_interval_prob=0.5,
                block_shuffle=16,
            )
            image_idxs = np.array(image_idxs_video)[pos]
        else:
            ordered_video = True
            # ordered video with varying intervals
            num_candidates = len(image_idxs)
            max_id = min(num_candidates, int(num_views * (2 + 2 * rng.random())))

            if num_candidates < num_views:
                # Not enough candidate frames: sample with replacement.
                image_idxs = sorted(rng.choice(image_idxs, size=num_views, replace=True))
            else:
                image_idxs = sorted(rng.permutation(image_idxs[:max_id])[:num_views])

            if rand_val > 0.75:
                ordered_video = False
                image_idxs = rng.permutation(image_idxs)

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scenes[scene_id])

            intrinsics = self.intrinsics[view_idx]
            camera_pose = self.trajectories[view_idx]
            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_pil(osp.join(scene_dir, "images", basename + ".jpg"))
            # Load depthmap: stored values are divided by 1000
            # (presumably millimetres -> metres; confirm against the data).
            depthmap = imread_cv2(
                osp.join(scene_dir, "depth", basename + ".png"), cv2.IMREAD_UNCHANGED
            )
            depthmap = depthmap.astype(np.float32) / 1000
            depthmap[~np.isfinite(depthmap)] = 0  # invalid

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="ScanNet++",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=f"{str(idx)}_{str(view_idx)}",
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    quantile=np.array(0.99, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/smartportraits.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+ from tqdm import tqdm
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2
12
+
13
+
14
class SmartPortraits_Multi(BaseMultiViewDataset):
    """Single-view RGB-D samples read from a SmartPortraits directory tree.

    Expected layout (as read by this class)::

        ROOT/<scene>/rgb/<name>.png
        ROOT/<scene>/depth/<name>.npy
        ROOT/<scene>/cam/<name>.npz      # must contain an "intrinsics" entry

    Every returned view is flagged ``single_view=True`` / ``reset=True`` and
    carries an identity camera pose, because no pose is provided.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store dataset flags, initialise the base class and index ROOT.

        Args:
            ROOT: directory containing one sub-directory per scene.
            *args, **kwargs: forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        super().__init__(*args, **kwargs)
        # _load_data() fills self.img_names and returns None; the attribute is
        # kept so that any existing reader of `loaded_data` keeps working.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Index every ``(scene, basename)`` pair found under ``ROOT``.

        Scenes are visited in sorted order and non-directory entries are
        ignored, so the index (and therefore seeded sampling) does not depend
        on the arbitrary order returned by ``os.listdir``.
        """
        scenes = sorted(
            entry
            for entry in os.listdir(self.ROOT)
            if osp.isdir(osp.join(self.ROOT, entry))
        )
        img_names = []
        for scene in scenes:
            rgb_dir = osp.join(self.ROOT, scene, "rgb")
            basenames = sorted(
                f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")
            )
            img_names.extend((scene, basename) for basename in basenames)

        self.img_names = img_names

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img_names)

    def get_image_num(self):
        """Return the number of indexed images."""
        return len(self.img_names)

    def _get_views(self, idx, resolution, rng, num_views):
        """Draw ``num_views`` distinct images at random and load them.

        Args:
            idx: sample index; only used to perturb the sampling seed.
            resolution: forwarded to ``_crop_resize_if_necessary``.
            rng: numpy ``Generator``; one integer is drawn from it to seed the
                local sampler and it is also forwarded to the crop helper.
            num_views: number of views to return.

        Returns:
            A list of ``num_views`` view dictionaries.

        Raises:
            ValueError: from ``Generator.choice`` when ``num_views`` exceeds
                the number of indexed images (sampling is without replacement).
        """
        new_seed = rng.integers(0, 2**32) + idx
        new_rng = np.random.default_rng(new_seed)
        # choice() over a list of (scene, basename) tuples yields a 2-D string
        # array; each row is one selected (scene, basename) pair.
        picked = new_rng.choice(self.img_names, num_views, replace=False)

        views = []
        for scene, img_name in picked:
            # Load RGB image
            rgb_path = osp.join(self.ROOT, scene, "rgb", f"{img_name}.png")
            rgb_image = imread_cv2(rgb_path)
            depthmap = np.load(osp.join(self.ROOT, scene, "depth", f"{img_name}.npy"))
            # non-finite depth values are treated as invalid (0)
            depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)

            intrinsics = np.load(osp.join(self.ROOT, scene, "cam", f"{img_name}.npz"))[
                "intrinsics"
            ]
            # camera pose is not provided, placeholder
            camera_pose = np.eye(4)

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="SmartPortraits",
                    label=img_name,
                    instance=rgb_path,
                    is_metric=self.is_metric,
                    is_video=False,
                    quantile=np.array(0.98, dtype=np.float32),
                    img_mask=True,
                    ray_mask=False,
                    camera_only=False,
                    depth_only=False,
                    single_view=True,
                    reset=True,
                )
            )
        assert len(views) == num_views
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/threedkb.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+ from tqdm import tqdm
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2
12
+
13
+
14
class ThreeDKenBurns(BaseMultiViewDataset):
    """Single-view RGB-D samples read from a 3D Ken Burns directory tree.

    Expected layout (as read by this class)::

        ROOT/<scene>/rgb/<name>.png
        ROOT/<scene>/depth/<name>.exr
        ROOT/<scene>/cam/<name>.npz      # must contain an "intrinsics" entry

    Views are flagged ``single_view=True`` / ``reset=True`` and use an
    identity camera pose.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store dataset flags, initialise the base class and index ROOT.

        Args:
            ROOT: directory containing one sub-directory per scene.
            *args, **kwargs: forwarded unchanged to ``BaseMultiViewDataset``.
        """
        self.ROOT = ROOT
        self.video = False
        self.is_metric = False
        super().__init__(*args, **kwargs)
        # _load_data() populates the index attributes and returns None; the
        # attribute is kept so existing readers of `loaded_data` keep working.
        self.loaded_data = self._load_data()

    def _load_data(self):
        """Build the flat image index over all scenes under ``ROOT``.

        Scenes are visited in sorted order and non-directory entries are
        skipped, so global image ids are stable across machines instead of
        depending on the arbitrary order of ``os.listdir``.
        """
        scene_names = sorted(
            entry
            for entry in os.listdir(self.ROOT)
            if osp.isdir(osp.join(self.ROOT, entry))
        )

        offset = 0
        scenes = []
        sceneids = []
        images = []
        img_ids = []

        for scene_index, scene in enumerate(tqdm(scene_names)):
            rgb_dir = osp.join(self.ROOT, scene, "rgb")
            basenames = sorted(
                f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")
            )
            num_imgs = len(basenames)

            # global ids of this scene's images are contiguous, starting at offset
            img_ids.extend(list(np.arange(num_imgs) + offset))
            sceneids.extend([scene_index] * num_imgs)
            images.extend(basenames)
            scenes.append(scene)

            # offset groups
            offset += num_imgs

        self.scenes = scenes
        self.sceneids = sceneids
        self.images = images
        self.img_ids = img_ids

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img_ids)

    def get_image_num(self):
        """Return the number of indexed images."""
        return len(self.images)

    def _get_views(self, idx, resolution, rng, num_views):
        """Draw ``num_views`` distinct images at random and load them.

        Args:
            idx: sample index; perturbs the sampling seed and appears in the
                ``instance`` string of each view.
            resolution: forwarded to ``_crop_resize_if_necessary``.
            rng: numpy ``Generator``; one integer is drawn from it to seed the
                local sampler and it is also forwarded to the crop helper.
            num_views: number of views to return.

        Returns:
            A list of ``num_views`` view dictionaries.

        Raises:
            ValueError: from ``Generator.choice`` when ``num_views`` exceeds
                the number of indexed images (sampling is without replacement).
        """
        new_seed = rng.integers(0, 2**32) + idx
        new_rng = np.random.default_rng(new_seed)
        image_idxs = new_rng.choice(self.img_ids, num_views, replace=False)

        views = []
        for view_idx in image_idxs:
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.scenes[scene_id])
            rgb_dir = osp.join(scene_dir, "rgb")
            depth_dir = osp.join(scene_dir, "depth")
            cam_dir = osp.join(scene_dir, "cam")

            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png"))
            depthmap = imread_cv2(osp.join(depth_dir, basename + ".exr"))
            # raw values above 20000 are marked invalid (0) before rescaling
            # NOTE(review): the /1000 presumably converts the stored unit to
            # the working unit -- confirm against the data preparation script.
            depthmap[depthmap > 20000] = 0.0
            depthmap = depthmap / 1000.0
            cam = np.load(osp.join(cam_dir, basename + ".npz"))
            intrinsics = cam["intrinsics"]
            # no pose is loaded for this dataset; identity placeholder
            camera_pose = np.eye(4)

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
            )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap.astype(np.float32),
                    camera_pose=camera_pose.astype(np.float32),
                    camera_intrinsics=intrinsics.astype(np.float32),
                    dataset="3DKenBurns",
                    label=self.scenes[scene_id] + "_" + basename,
                    instance=f"{str(idx)}_{str(view_idx)}",
                    is_metric=self.is_metric,
                    is_video=False,
                    quantile=np.array(1.0, dtype=np.float32),
                    img_mask=True,
                    ray_mask=False,
                    camera_only=False,
                    depth_only=False,
                    single_view=True,
                    reset=True,
                )
            )
        assert len(views) == num_views
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/unreal4k.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import numpy as np
3
+ import cv2
4
+ import numpy as np
5
+ import itertools
6
+ import os
7
+ import sys
8
+
9
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
10
+
11
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
12
+ from dust3r.utils.image import imread_cv2
13
+
14
+ R_conv = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]).astype(
15
+ np.float32
16
+ )
17
+
18
+
19
class UnReal4K_Multi(BaseMultiViewDataset):
    """Multi-view sequences read from an UnReal4K directory tree.

    Each scene directory under ``ROOT`` is expected to hold two sequence
    sub-directories named ``"0"`` and ``"1"``; every frame consists of
    ``<name>_rgb.png``, ``<name>_depth.npy`` and ``<name>.npz`` (with
    ``intrinsics`` and ``cam2world`` entries).
    """

    def __init__(self, ROOT, *args, **kwargs):
        """Store dataset flags, initialise the base class and index ROOT."""
        self.ROOT = ROOT
        self.max_interval = 2
        self.is_metric = True
        super().__init__(*args, **kwargs)
        # this dataset has no split: everything under ROOT is loaded
        assert self.split is None
        self._load_data()

    def _load_data(self):
        """Index all sequences and the frame ids a sample may start from."""
        scene_names = sorted(
            entry
            for entry in os.listdir(self.ROOT)
            if os.path.isdir(os.path.join(self.ROOT, entry))
        )
        sequence_dirs = sorted(
            os.path.join(self.ROOT, scene_name, mode)
            for scene_name in scene_names
            for mode in ["0", "1"]
        )

        next_global_id = 0
        kept_sequences = []
        frame_to_sequence = []
        frame_names = []
        start_ids = []
        ids_per_sequence = []

        for seq_dir in sequence_dirs:
            # strip the 8-character "_rgb.png" suffix
            names = sorted(f[:-8] for f in os.listdir(seq_dir) if f.endswith(".png"))
            frame_count = len(names)

            if self.allow_repeat:
                cut_off = max(self.num_views // 3, 3)
            else:
                cut_off = self.num_views

            if frame_count < cut_off:
                print(f"Skipping {seq_dir}")
                continue

            global_ids = list(np.arange(frame_count) + next_global_id)

            # a sample may only start where at least `cut_off` frames remain
            start_ids.extend(global_ids[: frame_count - cut_off + 1])
            frame_to_sequence.extend([len(kept_sequences)] * frame_count)
            frame_names.extend(names)
            kept_sequences.append(seq_dir)
            ids_per_sequence.append(global_ids)

            # offset groups
            next_global_id += frame_count

        self.scenes = kept_sequences
        self.sceneids = frame_to_sequence
        self.images = frame_names
        self.start_img_ids = start_ids
        self.scene_img_list = ids_per_sequence

    def __len__(self):
        """Return ten times the number of start frames."""
        return 10 * len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def get_stats(self):
        """Return a short description of the number of view groups."""
        return f"{len(self)//10} groups of views"

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames from the sequence selected by ``idx``.

        ``idx`` is integer-divided by 10 to undo the length multiplier applied
        in ``__len__``; the frame positions are then chosen by the base class
        helper ``get_seq_from_start_id``.
        """
        start_id = self.start_img_ids[idx // 10]
        all_image_ids = self.scene_img_list[self.sceneids[start_id]]
        pos, ordered_video = self.get_seq_from_start_id(
            num_views, start_id, all_image_ids, rng, max_interval=self.max_interval
        )
        image_idxs = np.array(all_image_ids)[pos]

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_dir = self.scenes[self.sceneids[view_idx]]
            basename = self.images[view_idx]
            img = basename + "_rgb.png"

            image = imread_cv2(osp.join(scene_dir, img))
            depthmap = np.load(osp.join(scene_dir, basename + "_depth.npy"))
            camera_params = np.load(osp.join(scene_dir, basename + ".npz"))

            intrinsics = camera_params["intrinsics"].astype(np.float32)
            # apply the module-level axis conversion to the loaded cam2world
            camera_pose = R_conv @ camera_params["cam2world"].astype(np.float32)

            # depths of 1000 or more are treated as sky and marked with -1
            depthmap[depthmap >= 1000] = -1.0
            # everything beyond the 98th percentile of positive depth is
            # invalidated (set to 0)
            positive_depth = depthmap[depthmap > 0]
            if positive_depth.size > 0:
                threshold = np.percentile(positive_depth, 98)
            else:
                threshold = 0
            depthmap[depthmap > threshold] = 0.0

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(scene_dir, img)
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            view = dict(
                img=image,
                depthmap=depthmap,
                camera_pose=camera_pose,  # cam2world
                camera_intrinsics=intrinsics,
                dataset="UnReal4K",
                label=scene_dir,
                is_metric=self.is_metric,
                instance=scene_dir + "_" + img,
                is_video=ordered_video,
                quantile=np.array(1.0, dtype=np.float32),
                img_mask=img_mask,
                ray_mask=ray_mask,
                camera_only=False,
                depth_only=False,
                single_view=False,
                reset=False,
            )
            views.append(view)

        assert len(views) == num_views
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/corr.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # modified from DUSt3R
6
+
7
+ import numpy as np
8
+ from dust3r.utils.device import to_numpy
9
+ from dust3r.utils.geometry import inv, geotrf
10
+
11
+
12
def reproject_view(pts3d, view2):
    """Project ``pts3d`` into the camera of ``view2``.

    Reads ``pts3d`` (for its first two dimensions), ``camera_intrinsics`` and
    ``camera_pose`` from ``view2``; the pose is inverted before being passed
    to :func:`reproject`, whose result is returned unchanged.
    """
    target_shape = view2["pts3d"].shape[:2]
    intrinsics = view2["camera_intrinsics"]
    world2cam = inv(view2["camera_pose"])
    return reproject(pts3d, intrinsics, world2cam, target_shape)
17
+
18
+
19
def reproject(pts3d, K, world2cam, shape):
    """Project an ``(H, W, 3)`` point map and quantize it to pixel indices.

    Args:
        pts3d: array whose shape unpacks as ``(H, W, 3)``.
        K: intrinsics matrix, multiplied with the first three rows of
            ``world2cam``.
        world2cam: transform whose first three rows are used.
        shape: ``(H, W)`` of the target image, used for clipping/raveling.

    Returns:
        ``((H, W), flat_indices)`` where ``(H, W)`` comes from ``pts3d`` and
        ``flat_indices`` is the output of :func:`ravel_xy`.
    """
    height, width, channels = pts3d.shape
    assert channels == 3

    # reproject in camera2 space; division warnings are silenced here
    with np.errstate(divide="ignore", invalid="ignore"):
        pixel_xy = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)

    # quantize to pixel positions
    flat_indices = ravel_xy(pixel_xy, shape)
    return (height, width), flat_indices
29
+
30
+
31
def ravel_xy(pos, shape):
    """Round ``(x, y)`` positions, clip them to the image and flatten them.

    Args:
        pos: array reshaped to ``(-1, 2)``; column 0 is x, column 1 is y.
        shape: ``(H, W)`` of the image.

    Returns:
        1-D ``int32`` array of ``x + W * y`` after clipping x to
        ``[0, W - 1]`` and y to ``[0, H - 1]``.
    """
    height, width = shape
    # invalid-value warnings from the float -> int cast are silenced
    with np.errstate(invalid="ignore"):
        pixels = pos.reshape(-1, 2).round().astype(np.int32)
    cols = np.clip(pixels[:, 0], 0, width - 1)
    rows = np.clip(pixels[:, 1], 0, height - 1)
    return cols + width * rows
39
+
40
+
41
def unravel_xy(pos, shape):
    """Convert flat indices ``x + W * y`` back to ``(x, y)`` coordinates.

    Args:
        pos: integer array of flat (row-major) indices.
        shape: ``(H, W)`` of the image the indices refer to.

    Returns:
        A new ``(N, 2)`` integer array whose columns are ``(x, y)``.
    """
    # np.unravel_index returns (rows, cols) == (y, x). The result is assembled
    # from that documented tuple instead of reaching into the undocumented
    # `.base` buffer of the returned arrays.
    rows, cols = np.unravel_index(pos, shape)
    return np.stack((cols, rows), axis=-1)
44
+
45
+
46
def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
    """Find indices whose forward/backward correspondences agree.

    Index ``i`` is reciprocal when ``corres_2_to_1[corres_1_to_2[i]] == i``.

    Returns:
        ``(pos1, pos2)`` -- the reciprocal indices and their matches -- or
        ``(is_reciprocal, pos1, pos2)`` when ``ret_recip`` is true.
    """
    round_trip = corres_2_to_1[corres_1_to_2]
    is_reciprocal1 = round_trip == np.arange(len(corres_1_to_2))
    pos1 = np.flatnonzero(is_reciprocal1)
    pos2 = corres_1_to_2[pos1]
    if not ret_recip:
        return pos1, pos2
    return is_reciprocal1, pos1, pos2
53
+
54
+
55
def extract_correspondences_from_pts3d(
    view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
):
    """Sample pixel correspondences between two views from their 3D points.

    Each view's ``pts3d`` is reprojected into the other view; pixels whose
    forward and backward reprojections agree are positive correspondences.

    Args:
        view1, view2: views converted with ``to_numpy``; their ``pts3d``,
            ``camera_intrinsics`` and ``camera_pose`` entries are read via
            ``reproject_view``.
        target_n_corres: number of correspondences to return, or ``None`` to
            return every reciprocal correspondence.
        rng: object providing ``permutation`` and ``choice``.
        ret_xy: when true, flat indices are converted to ``(x, y)``.
        nneg: fraction of ``target_n_corres`` reserved for negatives.

    Returns:
        ``(pos1, pos2)`` when ``target_n_corres`` is ``None``; otherwise
        ``(pos1, pos2, valid)`` where ``valid`` is true for positives and
        false for the randomly drawn negatives appended after them.
    """
    view1, view2 = to_numpy((view1, view2))
    # project pixels from image1 --> 3d points --> image2 pixels
    shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
    shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)

    # compute reciprocal correspondences:
    # pos1 == valid pixels (correspondences) in image1
    is_reciprocal1, pos1, pos2 = reciprocal_1d(
        corres1_to_2, corres2_to_1, ret_recip=True
    )
    is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))

    if target_n_corres is None:
        if ret_xy:
            pos1 = unravel_xy(pos1, shape1)
            pos2 = unravel_xy(pos2, shape2)
        return pos1, pos2

    not_reciprocal1 = ~is_reciprocal1
    not_reciprocal2 = ~is_reciprocal2

    available_negatives = min(not_reciprocal1.sum(), not_reciprocal2.sum())
    target_n_positives = int(target_n_corres * (1 - nneg))
    n_positives = min(len(pos1), target_n_positives)
    n_negatives = min(target_n_corres - n_positives, available_negatives)

    if n_negatives + n_positives != target_n_corres:
        # should be really rare => when there are not enough negatives
        # in that case, break nneg and top up with positives
        n_positives = target_n_corres - n_negatives

    assert n_positives <= len(pos1)
    assert n_positives <= len(pos2)
    assert n_negatives <= not_reciprocal1.sum()
    assert n_negatives <= not_reciprocal2.sum()
    assert n_positives + n_negatives == target_n_corres

    valid = np.ones(n_positives, dtype=bool)
    if n_positives < len(pos1):
        # random sub-sampling of valid correspondences
        keep = rng.permutation(len(pos1))[:n_positives]
        pos1, pos2 = pos1[keep], pos2[keep]

    if n_negatives > 0:
        # add false correspondences, drawn uniformly among non-reciprocal
        # pixels of each image (image 1 is drawn first, then image 2)
        negatives1 = rng.choice(
            shape1[0] * shape1[1],
            size=n_negatives,
            replace=False,
            p=not_reciprocal1 / not_reciprocal1.sum(),
        )
        pos1 = np.r_[pos1, negatives1]
        negatives2 = rng.choice(
            shape2[0] * shape2[1],
            size=n_negatives,
            replace=False,
            p=not_reciprocal2 / not_reciprocal2.sum(),
        )
        pos2 = np.r_[pos2, negatives2]
        valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]

    # convert (x+W*y) back to 2d (x,y) coordinates
    if ret_xy:
        pos1 = unravel_xy(pos1, shape1)
        pos2 = unravel_xy(pos2, shape2)
    return pos1, pos2, valid
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/cropping.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # cropping utilities
6
+ # --------------------------------------------------------
7
+ import PIL.Image
8
+ import os
9
+
10
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
11
+ import cv2 # noqa
12
+ import numpy as np # noqa
13
+ from dust3r.utils.geometry import (
14
+ colmap_to_opencv_intrinsics,
15
+ opencv_to_colmap_intrinsics,
16
+ ) # noqa
17
+
18
+ try:
19
+ lanczos = PIL.Image.Resampling.LANCZOS
20
+ bicubic = PIL.Image.Resampling.BICUBIC
21
+ except AttributeError:
22
+ lanczos = PIL.Image.LANCZOS
23
+ bicubic = PIL.Image.BICUBIC
24
+
25
+
26
class ImageList:
    """Convenience wrapper that applies one operation to a set of images."""

    def __init__(self, images):
        """Wrap ``images``; anything that is not a PIL image is converted.

        A single image (i.e. anything that is not a tuple, list or set) is
        treated as a one-element collection.
        """
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        self.images = [
            image
            if isinstance(image, PIL.Image.Image)
            else PIL.Image.fromarray(image)
            for image in images
        ]

    def __len__(self):
        """Return the number of wrapped images."""
        return len(self.images)

    def to_pil(self):
        """Return the sole image, or a tuple when there are several."""
        if len(self.images) > 1:
            return tuple(self.images)
        return self.images[0]

    @property
    def size(self):
        """Return the common ``size`` of the images (asserted to be equal)."""
        sizes = [im.size for im in self.images]
        reference = sizes[0]
        assert all(reference == s for s in sizes)
        return reference

    def resize(self, *args, **kwargs):
        """Call ``resize`` on every image and wrap the results."""
        return ImageList(self._dispatch("resize", *args, **kwargs))

    def crop(self, *args, **kwargs):
        """Call ``crop`` on every image and wrap the results."""
        return ImageList(self._dispatch("crop", *args, **kwargs))

    def _dispatch(self, func, *args, **kwargs):
        """Call the method named ``func`` on every image; return the list."""
        results = []
        for im in self.images:
            method = getattr(im, func)
            results.append(method(*args, **kwargs))
        return results
58
+
59
+
60
def rescale_image_depthmap(
    image, depthmap, camera_intrinsics, output_resolution, force=True
):
    """Jointly rescale an image, its depthmap and its intrinsics.

    The scale is the smallest one for which the rescaled image covers
    ``output_resolution`` in both dimensions, so that
    ``(out_width, out_height) >= output_resolution``.

    Args:
        image: a single image or a collection accepted by ``ImageList``.
        depthmap: array whose first two dimensions match the image
            ``(H, W)``, or ``None``. Masks can be passed here as well.
        camera_intrinsics: intrinsics matrix to rescale.
        output_resolution: requested ``(W, H)``.
        force: when false, inputs are returned untouched if the computed
            scale is ``>= 1`` (the image is already smaller than requested).

    Returns:
        ``(image, depthmap, camera_intrinsics)`` after rescaling.
    """
    images = ImageList(image)
    source_size = np.array(images.size)  # (W,H)
    requested_size = np.array(output_resolution)
    if depthmap is not None:
        # can also use this with masks instead of depthmaps
        assert tuple(depthmap.shape[:2]) == images.size[::-1]

    # define output resolution
    assert requested_size.shape == (2,)
    scale_final = max(requested_size / images.size) + 1e-8
    if scale_final >= 1 and not force:
        return (images.to_pil(), depthmap, camera_intrinsics)
    scaled_size = np.floor(source_size * scale_final).astype(int)

    # first rescale the image so that it contains the crop;
    # lanczos when shrinking, bicubic otherwise
    resample = lanczos if scale_final < 1 else bicubic
    images = images.resize(scaled_size, resample=resample)

    if depthmap is not None:
        # nearest-neighbour resize of the depthmap
        depthmap = cv2.resize(
            depthmap,
            scaled_size,
            fx=scale_final,
            fy=scale_final,
            interpolation=cv2.INTER_NEAREST,
        )

    # no offset here; simple rescaling
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, source_size, scaled_size, scaling=scale_final
    )

    return images.to_pil(), depthmap, camera_intrinsics
99
+
100
+
101
def camera_matrix_of_crop(
    input_camera_matrix,
    input_resolution,
    output_resolution,
    scaling=1,
    offset_factor=0.5,
    offset=None,
):
    """Return the intrinsics after scaling an image and cropping it.

    Args:
        input_camera_matrix: intrinsics of the input image.
        input_resolution: input size, multiplied by ``scaling``.
        output_resolution: size of the crop; must fit in the scaled input.
        scaling: scale factor applied before cropping.
        offset_factor: fraction of the margin used as crop offset when
            ``offset`` is not given (0.5 centres the crop).
        offset: explicit crop offset; overrides ``offset_factor``.

    Returns:
        The adjusted intrinsics matrix.
    """
    # Margins to offset the origin
    scaled_resolution = np.asarray(input_resolution) * scaling
    margins = scaled_resolution - output_resolution
    assert np.all(margins >= 0.0)
    crop_offset = offset_factor * margins if offset is None else offset

    # Generate new camera parameters: the scale and shift are applied on the
    # converted matrix, which is then converted back
    colmap_matrix = opencv_to_colmap_intrinsics(input_camera_matrix)
    colmap_matrix[:2, :] *= scaling
    colmap_matrix[:2, 2] -= crop_offset
    return colmap_to_opencv_intrinsics(colmap_matrix)
122
+
123
+
124
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
    """Crop an image, its depthmap and its intrinsics to ``crop_bbox``.

    Args:
        image: a single image or a collection accepted by ``ImageList``.
        depthmap: array indexed as ``[row, column]``.
        camera_intrinsics: intrinsics matrix; it is copied, not modified.
        crop_bbox: ``(left, top, right, bottom)`` in pixels.

    Returns:
        ``(image, depthmap, camera_intrinsics)`` of the crop.
    """
    left, top, right, bottom = crop_bbox

    cropped_images = ImageList(image).crop((left, top, right, bottom))
    cropped_depth = depthmap[top:bottom, left:right]

    # the principal point moves by the crop origin
    cropped_intrinsics = camera_intrinsics.copy()
    cropped_intrinsics[0, 2] -= left
    cropped_intrinsics[1, 2] -= top

    return cropped_images.to_pil(), cropped_depth, cropped_intrinsics
139
+
140
+
141
def bbox_from_intrinsics_in_out(
    input_camera_matrix, output_camera_matrix, output_resolution
):
    """Return the crop box implied by a shift of the principal point.

    The top-left corner is the rounded difference between the input and
    output principal points (entries ``[0, 2]`` and ``[1, 2]``); the box
    extends over ``output_resolution = (width, height)``.

    Returns:
        ``(left, top, right, bottom)``.
    """
    principal_shift = input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]
    left, top = np.int32(np.round(principal_shift))
    width, height = output_resolution
    return (left, top, left + width, top + height)
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/utils/transforms.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # DUST3R default transforms
6
+ # --------------------------------------------------------
7
+ import torchvision.transforms as tvf
8
+ from dust3r.utils.image import ImgNorm
9
+
10
+ # define the standard image transforms
11
+ ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
12
+
13
+
14
+ def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
15
+ if isinstance(value, (int, float)):
16
+ if value < 0:
17
+ raise ValueError(f"If is a single number, it must be non negative.")
18
+ value = [center - float(value), center + float(value)]
19
+ if clip_first_on_zero:
20
+ value[0] = max(value[0], 0.0)
21
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
22
+ value = [float(value[0]), float(value[1])]
23
+ else:
24
+ raise TypeError(f"should be a single number or a list/tuple with length 2.")
25
+
26
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
27
+ raise ValueError(f"values should be between {bound}, but got {value}.")
28
+
29
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
30
+ # or (0., 0.) for hue, do nothing
31
+ if value[0] == value[1] == center:
32
+ return None
33
+ else:
34
+ return tuple(value)
35
+
36
+
37
+ import torch
38
+ import torchvision.transforms.functional as F
39
+
40
+
41
def SeqColorJitter():
    """Build a colour-jitter transform with parameters frozen at creation.

    The order of the four adjustments and their factors are drawn once, here,
    so every image passed to the returned callable receives the same jitter.
    The callable finishes by applying ``ImgNorm``.
    """
    brightness = _check_input(0.5)
    contrast = _check_input(0.5)
    saturation = _check_input(0.5)
    hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    # random order in which the four adjustments are applied
    fn_idx = torch.randperm(4)

    def _draw(bounds):
        # one uniform sample in [bounds[0], bounds[1]], or None if disabled
        if bounds is None:
            return None
        return float(torch.empty(1).uniform_(bounds[0], bounds[1]))

    # draw order matters for reproducibility:
    # brightness, contrast, saturation, hue
    brightness_factor = _draw(brightness)
    contrast_factor = _draw(contrast)
    saturation_factor = _draw(saturation)
    hue_factor = _draw(hue)

    # indexed by the ids produced by randperm(4)
    adjustments = (
        (F.adjust_brightness, brightness_factor),
        (F.adjust_contrast, contrast_factor),
        (F.adjust_saturation, saturation_factor),
        (F.adjust_hue, hue_factor),
    )

    def _color_jitter(img):
        for fn_id in fn_idx:
            adjust, factor = adjustments[int(fn_id)]
            if factor is not None:
                img = adjust(img, factor)
        return ImgNorm(img)

    return _color_jitter
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/waymo.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import os
3
+ import numpy as np
4
+ import sys
5
+
6
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
7
+ import h5py
8
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
9
+ from dust3r.utils.image import imread_cv2
10
+
11
+
12
class Waymo_Multi(BaseMultiViewDataset):
    """Multi-view sequences read from a preprocessed Waymo directory tree.

    Each scene directory holds files named ``<frame>_<seq>.jpg`` together
    with a matching ``.exr`` depthmap and a ``.npz`` holding ``intrinsics``
    and ``cam2world``. ``ROOT/invalid_files.h5`` lists pairs to skip.
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Store dataset flags, initialise the base class and index ROOT."""
        self.ROOT = ROOT
        self.max_interval = 8
        self.video = True
        self.is_metric = True
        super().__init__(*args, **kwargs)
        assert self.split is None
        self._load_data()

    def load_invalid_dict(self, h5_file_path):
        """Read the invalid pairs of every scene from an HDF5 file.

        Returns:
            A dict mapping each top-level key of the file to the set of
            tuples obtained by splitting its ``invalid_pairs`` entries on
            ``"_"``.
        """
        with h5py.File(h5_file_path, "r") as h5f:
            return {
                scene: {
                    tuple(raw.decode("utf-8").split("_"))
                    for raw in h5f[scene]["invalid_pairs"][:]
                }
                for scene in h5f
            }

    def _load_data(self):
        """Index every usable (scene, sequence) and its candidate start ids."""
        invalid_dict = self.load_invalid_dict(
            os.path.join(self.ROOT, "invalid_files.h5")
        )
        scene_names = sorted(
            entry
            for entry in os.listdir(self.ROOT)
            if os.path.isdir(os.path.join(self.ROOT, entry))
        )

        next_global_id = 0
        kept_sequences = []
        frame_to_sequence = []
        frame_names = []
        start_ids = []
        ids_per_sequence = []
        is_video = []  # never filled; kept because it is exposed as an attribute

        for scene in scene_names:
            scene_dir = osp.join(self.ROOT, scene)
            invalid_pairs = invalid_dict.get(scene, set())

            # group the frame ids of this scene by sequence id
            seq2frames = {}
            for filename in os.listdir(scene_dir):
                if not filename.endswith(".jpg"):
                    continue
                parts = filename[:-4].split("_")
                frame_id = parts[0]
                seq_id = parts[1]
                if seq_id == "5":
                    continue
                if (seq_id, frame_id) in invalid_pairs:
                    continue  # Skip invalid files
                seq2frames.setdefault(seq_id, []).append(frame_id)

            for seq_id, frame_ids in seq2frames.items():
                frame_ids = sorted(frame_ids)
                frame_count = len(frame_ids)

                if self.allow_repeat:
                    cut_off = max(self.num_views // 3, 3)
                else:
                    cut_off = self.num_views

                if frame_count < cut_off:
                    print(f"Skipping {scene}_{seq_id}")
                    continue

                global_ids = list(np.arange(frame_count) + next_global_id)

                frame_to_sequence.extend([len(kept_sequences)] * frame_count)
                kept_sequences.append((scene, seq_id))
                frame_names.extend(frame_ids)
                # a sample may only start where at least `cut_off` frames remain
                start_ids.extend(global_ids[: frame_count - cut_off + 1])
                ids_per_sequence.append(global_ids)

                next_global_id += frame_count

        self.scenes = kept_sequences
        self.sceneids = frame_to_sequence
        self.images = frame_names
        self.start_img_ids = start_ids
        self.scene_img_list = ids_per_sequence
        self.is_video = is_video

    def __len__(self):
        """Return the number of start frames."""
        return len(self.start_img_ids)

    def get_image_num(self):
        """Return the total number of indexed frames."""
        return len(self.images)

    def get_stats(self):
        """Return a short description of the number of view groups."""
        return f"{len(self)} groups of views"

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` frames from the sequence selected by ``idx``.

        Frame positions come from the base-class helper
        ``get_seq_from_start_id``; sequence ``"4"`` uses half the maximum
        interval. The returned views are always flagged ``is_video=True``,
        regardless of what the helper reports.
        """
        start_id = self.start_img_ids[idx]
        sequence_index = self.sceneids[start_id]
        all_image_ids = self.scene_img_list[sequence_index]
        _, seq_id = self.scenes[sequence_index]

        if seq_id == "4":
            max_interval = self.max_interval // 2
        else:
            max_interval = self.max_interval

        pos, ordered_video = self.get_seq_from_start_id(
            num_views,
            start_id,
            all_image_ids,
            rng,
            max_interval=max_interval,
            video_prob=0.9,
            fix_interval_prob=0.9,
            block_shuffle=16,
        )
        image_idxs = np.array(all_image_ids)[pos]
        # the helper's flag is deliberately overridden
        ordered_video = True

        views = []
        for v, view_idx in enumerate(image_idxs):
            scene_name, seq_id = self.scenes[self.sceneids[view_idx]]
            scene_dir = osp.join(self.ROOT, scene_name)
            frame_id = self.images[view_idx]

            impath = f"{frame_id}_{seq_id}"
            image = imread_cv2(osp.join(scene_dir, impath + ".jpg"))
            depthmap = imread_cv2(osp.join(scene_dir, impath + ".exr"))
            camera_params = np.load(osp.join(scene_dir, impath + ".npz"))

            intrinsics = np.float32(camera_params["intrinsics"])
            camera_pose = np.float32(camera_params["cam2world"])

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(scene_dir, impath)
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.85, 0.10, 0.05]
            )

            view = dict(
                img=image,
                depthmap=depthmap,
                camera_pose=camera_pose,  # cam2world
                camera_intrinsics=intrinsics,
                dataset="Waymo",
                label=osp.relpath(scene_dir, self.ROOT),
                is_metric=self.is_metric,
                instance=osp.join(scene_dir, impath + ".jpg"),
                is_video=ordered_video,
                quantile=np.array(0.98, dtype=np.float32),
                img_mask=img_mask,
                ray_mask=ray_mask,
                camera_only=False,
                depth_only=False,
                single_view=False,
                reset=False,
            )
            views.append(view)

        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/datasets/wildrgbd.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import sys
3
+
4
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from dust3r.datasets.co3d import Co3d_Multi
9
+ from dust3r.utils.image import imread_cv2
10
+
11
+
12
class WildRGBD_Multi(Co3d_Multi):
    """WildRGBD variant of ``Co3d_Multi`` with its own paths and depth scale."""

    def __init__(self, mask_bg="rand", *args, ROOT, **kwargs):
        """Initialise the parent loader, then rebuild the reference index.

        The scene ``("box", "scenes/scene_257")`` is removed (if present)
        before the scene list and reference images are computed.
        """
        super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
        self.dataset_label = "WildRGBD"
        self.is_metric = True

        # load all scenes
        self.scenes.pop(("box", "scenes/scene_257"), None)
        self.scene_list = list(self.scenes.keys())

        if self.allow_repeat:
            self.cut_off = max(self.num_views // 3, 3)
        else:
            self.cut_off = self.num_views
        cut_off = self.cut_off

        # only frames leaving at least `cut_off` frames after them are
        # usable as reference images
        all_ref_imgs = []
        for key, values in self.scenes.items():
            for value in values[: len(values) - cut_off + 1]:
                all_ref_imgs.append((key, value))
        self.all_ref_imgs = all_ref_imgs

        self.invalidate = {scene: {} for scene in self.scene_list}
        self.invalid_scenes = {scene: False for scene in self.scene_list}

    def _get_metadatapath(self, obj, instance, view_idx):
        """Return the path of the ``.npz`` metadata file of a frame."""
        filename = "{:0>5d}.npz".format(view_idx)
        return osp.join(self.ROOT, obj, instance, "metadata", filename)

    def _get_impath(self, obj, instance, view_idx):
        """Return the path of the RGB ``.jpg`` of a frame."""
        filename = "{:0>5d}.jpg".format(view_idx)
        return osp.join(self.ROOT, obj, instance, "rgb", filename)

    def _get_depthpath(self, obj, instance, view_idx):
        """Return the path of the depth ``.png`` of a frame."""
        filename = "{:0>5d}.png".format(view_idx)
        return osp.join(self.ROOT, obj, instance, "depth", filename)

    def _get_maskpath(self, obj, instance, view_idx):
        """Return the path of the mask ``.png`` of a frame."""
        filename = "{:0>5d}.png".format(view_idx)
        return osp.join(self.ROOT, obj, instance, "masks", filename)

    def _read_depthmap(self, depthpath, input_metadata):
        """Read a depth image and convert it to float32 metres.

        Depths are stored with a scale of 1000, so dividing the loaded
        values by 1000 yields metres. ``input_metadata`` is unused.
        """
        raw_depth = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
        return raw_depth.astype(np.float32) / 1000.0

    def _get_views(self, idx, resolution, rng, num_views):
        """Return the parent's views with ``quantile`` overridden to 0.96."""
        views = super()._get_views(idx, resolution, rng, num_views)
        quantile_value = 0.96
        for view in views:
            assert view["is_metric"]
            view["quantile"] = np.array(quantile_value, dtype=np.float32)
        return views
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/camera.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from croco.models.blocks import Mlp
8
+
9
+
10
+ inf = float("inf")
11
+
12
+
13
class PoseDecoder(nn.Module):
    """MLP head that maps a pose feature vector to a pose encoding."""

    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_encoding_type="absT_quaR",
    ):
        """
        Args:
            hidden_size: channel count of the incoming pose feature.
            mlp_ratio: hidden width of the MLP as a multiple of ``hidden_size``.
            pose_encoding_type: only ``"absT_quaR"`` (7 values) is supported.

        Raises:
            ValueError: if ``pose_encoding_type`` is not supported.
        """
        super().__init__()

        self.pose_encoding_type = pose_encoding_type
        if self.pose_encoding_type == "absT_quaR":
            self.target_dim = 7
        else:
            # Fail early with the same error the pose converters of this
            # module use, instead of an AttributeError on ``target_dim``.
            raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=self.target_dim,
            drop=0,
        )

    def forward(
        self,
        pose_feat,
    ):
        """Decode pose features.

        Args:
            pose_feat: BxC pose features.

        Returns:
            Bx7 raw pose encoding: 3 values for absT, 4 for quaR.
        """
        pred_cameras = self.mlp(pose_feat)  # Bx7, 3 for absT, 4 for quaR
        return pred_cameras
44
+
45
+
46
class PoseEncoder(nn.Module):
    """Encode camera matrices into feature vectors of width ``hidden_size``.

    The camera is converted to a pose encoding, passed through the inverse of
    ``postprocess_pose``, expanded with a harmonic embedding and finally
    projected by an MLP.
    """

    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_mode=("exp", -inf, inf),
        pose_encoding_type="absT_quaR",
    ):
        """
        Args:
            hidden_size: output feature width.
            mlp_ratio: hidden width of the MLP as a multiple of ``hidden_size``.
            pose_mode: forwarded to ``postprocess_pose(..., inverse=True)``.
            pose_encoding_type: only ``"absT_quaR"`` (7 values) is supported.

        Raises:
            ValueError: if ``pose_encoding_type`` is not supported.
        """
        super().__init__()
        self.pose_encoding_type = pose_encoding_type
        self.pose_mode = pose_mode

        if self.pose_encoding_type == "absT_quaR":
            self.target_dim = 7
        else:
            # Fail early with the same error the pose converters of this
            # module use, instead of an AttributeError on ``target_dim``.
            raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

        self.embed_pose = PoseEmbedding(
            target_dim=self.target_dim,
            out_dim=hidden_size,
            n_harmonic_functions=10,
            append_input=True,
        )
        self.pose_encoder = Mlp(
            in_features=self.embed_pose.out_dim,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=hidden_size,
            drop=0,
        )

    def forward(self, camera):
        """Return the pose feature of a batch of camera matrices.

        Args:
            camera: batched camera matrices; the rotation is read from
                ``camera[:, :3, :3]`` and the translation from
                ``camera[:, :3, 3]``.
        """
        # Imported lazily -- presumably to avoid a circular import between
        # the utils and heads packages (TODO confirm).
        from dust3r.heads.postprocess import postprocess_pose

        pose_enc = camera_to_pose_encoding(
            camera,
            pose_encoding_type=self.pose_encoding_type,
        ).to(camera.dtype)
        pose_enc = postprocess_pose(pose_enc, self.pose_mode, inverse=True)
        pose_feat = self.embed_pose(pose_enc)
        pose_feat = self.pose_encoder(pose_feat)
        return pose_feat
84
+
85
+
86
class HarmonicEmbedding(torch.nn.Module):
    """Sinusoidal (NeRF-style) embedding of the last dimension of a tensor."""

    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        Build the frequency table used by ``forward``.

        With ``N = n_harmonic_functions`` frequencies ``f_1 .. f_N`` and an
        input ``x`` of shape ``[..., dim]``, ``forward`` returns, concatenated
        along the last axis::

            sin(f_k * x[..., i])   for every i, then every k   (dim * N values)
            cos(f_k * x[..., i])   for every i, then every k   (dim * N values)
            x                      only if append_input is True (dim values)

        The cosine terms are evaluated as ``sin(. + pi / 2)``.

        If ``diag_cov`` is passed to ``forward``, every sine/cosine term is
        additionally multiplied by ``exp(-0.5 * f_k**2 * diag_cov[..., i])``
        (the integrated positional encoding of MIP-NeRF).

        Frequencies:
            * ``logspace=True``:  ``2.0 ** arange(N)``
            * ``logspace=False``: ``linspace(1.0, 2.0 ** (N - 1), N)``
          and in both cases they are multiplied by ``omega_0``.

        Args:
            n_harmonic_functions: number of frequencies ``N``.
            omega_0: base frequency multiplier.
            logspace: choose power-of-two or linearly spaced frequencies.
            append_input: whether to concatenate the raw input to the output.
        """
        super().__init__()

        if not logspace:
            freqs = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )
        else:
            freqs = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32)

        # Non-persistent buffers: they follow the module across devices but
        # are not written to the state dict.
        self.register_buffer("_frequencies", freqs * omega_0, persistent=False)
        phase_shifts = torch.tensor([0.0, 0.5 * torch.pi])
        self.register_buffer("_zero_half_pi", phase_shifts, persistent=False)
        self.append_input = append_input

    def forward(
        self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        """
        Args:
            x: tensor of shape ``[..., dim]``.
            diag_cov: optional tensor of shape ``[..., dim]`` holding diagonal
                covariances of Gaussians whose means are ``x``.

        Returns:
            Embedding of shape
            ``[..., dim * (2 * n_harmonic_functions + int(append_input))]``.
        """
        # [..., dim, N]
        phases = x[..., None] * self._frequencies
        # [..., 2, dim, N]: slot 0 is the sine phase, slot 1 the cosine phase.
        shifted = phases[..., None, :, :] + self._zero_half_pi[..., None, None]
        harmonics = shifted.sin()

        if diag_cov is not None:
            scaled_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
            damping = torch.exp(-0.5 * scaled_var)
            harmonics = harmonics * damping[..., None, :, :]

        flat = harmonics.reshape(*x.shape[:-1], -1)
        if not self.append_input:
            return flat
        return torch.cat([flat, x], dim=-1)

    @staticmethod
    def get_output_dim_static(
        input_dims: int, n_harmonic_functions: int, append_input: bool
    ) -> int:
        """
        Predict the size of the last dimension returned by ``forward``.

        Args:
            input_dims: size of the last dimension of the input tensor.
            n_harmonic_functions: number of embedding frequencies.
            append_input: whether the raw input is concatenated.

        Returns:
            int: size of the last dimension of the output tensor.
        """
        per_channel = 2 * n_harmonic_functions + int(append_input)
        return input_dims * per_channel

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Instance version of ``get_output_dim_static``; ``input_dims`` defaults
        to 3 (e.g. xyz inputs).
        """
        n_freqs = len(self._frequencies)
        return self.get_output_dim_static(input_dims, n_freqs, self.append_input)
253
+
254
+
255
class PoseEmbedding(nn.Module):
    """Harmonic embedding of a pose encoding.

    ``out_dim`` is accepted for interface compatibility but is not used: the
    ``self.out_dim`` attribute is always derived from the embedding itself.
    """

    def __init__(self, target_dim, out_dim, n_harmonic_functions=10, append_input=True):
        super().__init__()

        embedder = HarmonicEmbedding(
            n_harmonic_functions=n_harmonic_functions,
            append_input=append_input,
        )
        self._emb_pose = embedder
        # Width of the embedded pose, as reported by the harmonic embedding.
        self.out_dim = embedder.get_output_dim(target_dim)

    def forward(self, pose_encoding):
        """Return the harmonic embedding of ``pose_encoding``."""
        return self._emb_pose(pose_encoding)
268
+
269
+
270
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
271
+ """
272
+ Returns torch.sqrt(torch.max(0, x))
273
+ but with a zero subgradient where x is 0.
274
+ """
275
+ ret = torch.zeros_like(x)
276
+ positive_mask = x > 0
277
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
278
+ return ret
279
+
280
+
281
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Four candidate quaternions are built (one per component used as the
    pivot) and, for each matrix, the candidate whose pivot has the largest
    magnitude is returned after standardization.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the last two dimensions of ``matrix`` are not 3x3.
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    # Magnitude-like term for each of the four components (r, i, j, k);
    # negative arguments are clamped to zero by _sqrt_positive_part.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # Row n holds the (unscaled) quaternion obtained when component n is
    # taken as the pivot.
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # Floor the denominator at 0.1 so candidates with a tiny pivot do not
    # divide by (almost) zero; those candidates are not selected below.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # Keep, for every matrix, the candidate with the largest pivot.
    out = quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))
    return standardize_quaternion(out)
328
+
329
+
330
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Normalize quaternions to unit length and flip their sign where needed so
    that the real part is non negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    unit = F.normalize(quaternions, p=2, dim=-1)
    negative_real = unit[..., 0:1] < 0
    return torch.where(negative_real, -unit, unit)
344
+
345
+
346
def camera_to_pose_encoding(
    camera,
    pose_encoding_type="absT_quaR",
):
    """
    Inverse to pose_encoding_to_camera.

    Args:
        camera: batched camera matrices (opencv convention, cam2world); the
            rotation is read from ``camera[:, :3, :3]`` and the translation
            from ``camera[:, :3, 3]``.
        pose_encoding_type: only ``"absT_quaR"`` is supported.

    Returns:
        Tensor of shape (B, 7): translation (3) followed by the quaternion of
        the rotation (4, real part first).

    Raises:
        ValueError: for an unknown ``pose_encoding_type``.
    """
    if pose_encoding_type != "absT_quaR":
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

    translation = camera[:, :3, 3]
    rotation_quat = matrix_to_quaternion(camera[:, :3, :3])
    return torch.cat([translation, rotation_quat], dim=-1)
363
+
364
+
365
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    The quaternions do not have to be normalized: the result is scaled by
    their squared norm.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    w, x, y, z = torch.unbind(quaternions, -1)

    # 2 / |q|^2
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    # Row-major entries of the 3x3 matrix.
    entries = (
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )
    flat = torch.stack(entries, -1)
    out_shape = quaternions.shape[:-1] + (3, 3)
    return flat.reshape(out_shape)
395
+
396
+
397
def pose_encoding_to_camera(
    pose_encoding,
    pose_encoding_type="absT_quaR",
):
    """
    Turn a batch of pose encodings into 4x4 camera matrices.

    Args:
        pose_encoding: A tensor of shape `BxC`, containing a batch of
            `B` `C`-dimensional pose encodings; columns 0-2 are the
            translation and columns 3-6 the quaternion (real part first).
        pose_encoding_type: The type of pose encoding; only ``"absT_quaR"``
            is supported.

    Returns:
        Tensor of shape (B, 4, 4) with the rotation in the top-left 3x3 block,
        the translation in the last column and ``[0, 0, 0, 1]`` as last row.

    Raises:
        ValueError: for an unknown ``pose_encoding_type``.
    """
    if pose_encoding_type != "absT_quaR":
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

    translation = pose_encoding[:, :3]
    rotation = quaternion_to_matrix(pose_encoding[:, 3:7])

    # Start from a batch of identity matrices matching the rotation's dtype
    # and device, then fill in rotation and translation.
    identity = torch.eye(4, 4).to(rotation.dtype).to(rotation.device)
    cameras = identity[None].repeat(len(rotation), 1, 1)
    cameras[:, :3, :3] = rotation
    cameras[:, :3, 3] = translation
    return cameras
422
+
423
+
424
def quaternion_conjugate(q):
    """Return the conjugate of quaternion q (w, x, y, z): the vector part is negated."""
    real_part = q[..., :1]
    flipped_vector = -q[..., 1:]
    return torch.cat([real_part, flipped_vector], dim=-1)
429
+
430
+
431
def quaternion_multiply(q1, q2):
    """Return the Hamilton product q1 * q2 of two (w, x, y, z) quaternions."""
    a_w, a_x, a_y, a_z = q1.unbind(dim=-1)
    b_w, b_x, b_y, b_z = q2.unbind(dim=-1)

    product = (
        a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z,
        a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y,
        a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x,
        a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w,
    )
    return torch.stack(product, dim=-1)
442
+
443
+
444
def rotate_vector(q, v):
    """Rotate vector v by quaternion q (w, x, y, z).

    Uses ``v + w * t + u x t`` with ``u`` the vector part of ``q`` and
    ``t = 2 * (u x v)``.
    """
    axis_part = q[..., 1:]
    scalar_part = q[..., :1]

    twice_cross = 2.0 * torch.cross(axis_part, v, dim=-1)
    return v + scalar_part * twice_cross + torch.cross(axis_part, twice_cross, dim=-1)
452
+
453
+
454
def relative_pose_absT_quatR(t1, q1, t2, q2):
    """Compute the relative translation and quaternion between two poses.

    The conjugate of ``q1`` is used as its inverse (exact for unit
    quaternions). Returns ``(t_rel, q_rel)`` where ``q_rel = conj(q1) * q2``
    and ``t_rel`` is ``t2 - t1`` rotated by ``conj(q1)``.
    """
    inverse_q1 = quaternion_conjugate(q1)
    q_rel = quaternion_multiply(inverse_q1, q2)
    t_rel = rotate_vector(inverse_q1, t2 - t1)
    return t_rel, q_rel
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/device.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # modified from DUSt3R
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    Args:
        batch: list, tuple, dict of tensors or other things (nested freely).
        device: pytorch device or 'numpy'.
        callback: optional function called on ``batch`` and on every
            sub-element before it is processed; its return value replaces
            the element.
        non_blocking: forwarded to ``Tensor.to`` for every tensor, including
            tensors nested inside containers.

    Returns:
        A structure of the same container types with tensors/arrays moved.
        For ``device == "numpy"`` tensors become numpy arrays; otherwise numpy
        arrays become tensors on ``device``. Other leaves are returned as-is.
    """
    if callback:
        batch = callback(batch)

    # Recurse into containers, keeping ``callback`` and ``non_blocking`` so
    # nested elements are treated exactly like top-level ones.
    if isinstance(batch, dict):
        return {
            k: todevice(v, device, callback=callback, non_blocking=non_blocking)
            for k, v in batch.items()
        }

    if isinstance(batch, (tuple, list)):
        return type(batch)(
            todevice(x, device, callback=callback, non_blocking=non_blocking)
            for x in batch
        )

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
37
+
38
+
39
+ to_device = todevice # alias
40
+
41
+
42
def to_numpy(x):
    """Convert every tensor in ``x`` (possibly nested) to a numpy array."""
    return todevice(x, device="numpy")
44
+
45
+
46
def to_cpu(x):
    """Move every tensor/array in ``x`` (possibly nested) to a CPU tensor."""
    return todevice(x, device="cpu")
48
+
49
+
50
def to_cuda(x):
    """Move every tensor/array in ``x`` (possibly nested) to the "cuda" device."""
    return todevice(x, device="cuda")
52
+
53
+
54
def collate_with_cat(whatever, lists=False):
    """Merge a (nested) sequence of batches by concatenation.

    Dispatch is based on the first element of the sequence:
      * dict input: every value is collated recursively;
      * ``None`` first element: returns ``None``;
      * bool/float/int/str elements: the sequence is returned unchanged;
      * tuples: transposed with ``zip`` and each column collated;
      * dicts: collated key by key (keys taken from the first element);
      * tensors: ``torch.cat`` (or flattened into a list when ``lists``);
      * numpy arrays: converted to tensors and ``torch.cat`` (or flattened);
      * anything else: the elements are summed starting from an empty
        container of the input's type.

    Inputs that are neither dict, tuple nor list yield ``None``.
    """
    if isinstance(whatever, dict):
        return {
            key: collate_with_cat(group, lists=lists)
            for key, group in whatever.items()
        }

    if not isinstance(whatever, (tuple, list)):
        return None

    if len(whatever) == 0:
        return whatever

    first = whatever[0]
    container = type(whatever)

    if first is None:
        return None
    if isinstance(first, (bool, float, int, str)):
        return whatever
    if isinstance(first, tuple):
        columns = zip(*whatever)
        return container(collate_with_cat(col, lists=lists) for col in columns)
    if isinstance(first, dict):
        return {
            key: collate_with_cat([item[key] for item in whatever], lists=lists)
            for key in first
        }

    if isinstance(first, torch.Tensor):
        if lists:
            return listify(whatever)
        return torch.cat(whatever)
    if isinstance(first, np.ndarray):
        if lists:
            return listify(whatever)
        return torch.cat([torch.from_numpy(arr) for arr in whatever])

    # e.g. a list of lists: concatenate them.
    return sum(whatever, container())
85
+
86
+
87
def listify(elems):
    """Flatten one level: return a list with the items of every element of ``elems``."""
    flat = []
    for group in elems:
        flat.extend(group)
    return flat
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/geometry.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # modified from DUSt3R
6
+
7
+ import torch
8
+ import numpy as np
9
+ from scipy.spatial import cKDTree as KDTree
10
+
11
+ from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
12
+ from dust3r.utils.device import to_numpy
13
+
14
+
15
def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """Output a (H,W,2) grid of integer pixel coordinates
        with output[j,i,0] = i + origin[0]
             output[j,i,1] = j + origin[1]

    Args:
        W, H: grid width and height.
        device: ``None`` builds numpy arrays, anything else builds torch
            tensors on that device.
        origin: (x, y) coordinate of the top-left grid cell.
        unsqueeze: if not None, a new axis is inserted at this position in
            every grid component (including the homogeneous one).
        cat_dim: axis along which the components are stacked; ``None``
            returns the tuple of components instead.
        homogeneous: append a third component filled with ones.
        **arange_kw: forwarded to ``arange`` (e.g. ``dtype``).

    Returns:
        The stacked grid, or a tuple of components when ``cat_dim`` is None.
    """
    if device is None:
        # numpy backend
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
        expand = np.expand_dims
    else:
        # torch backend
        arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
        meshgrid, stack = torch.meshgrid, torch.stack
        ones = lambda *a: torch.ones(*a, device=device)
        expand = torch.unsqueeze

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    # meshgrid may return a list or a tuple depending on the backend/version;
    # normalise so the tuple concatenation below always works.
    grid = tuple(meshgrid(tw, th, indexing="xy"))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        # Apply to every component so the homogeneous channel is kept.
        grid = tuple(expand(g, unsqueeze) for g in grid)
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid
47
+
48
+
49
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a list of 3-D points.

    Trf: 3x3 or 4x4 projection matrix (typically a Homography), optionally
         with leading batch dimensions.
    pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3).
         It is converted to the same array library as ``Trf``.

    ncol: int. number of columns of the result (2 or 3); defaults to the
          last dimension of ``pts``.
    norm: float. if != 0, the result is projected on the z=norm plane.

    Returns an array of transformed points with the leading shape of ``pts``
    and ``ncol`` columns.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # Remember the output shape before pts is reshaped below.
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # Fast path: batched torch matrices (B,d,d) or (B,d+1,d+1) applied to
    # point maps (B,H,W,d) with einsum.
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # Linear part plus the translation column.
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        if Trf.ndim >= 3:
            # Flatten all batch dimensions of Trf into one.
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # Affine case: points have one coordinate less than the matrix.
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
116
+
117
+
118
def inv(mat):
    """Invert a torch or numpy matrix.

    Raises:
        ValueError: if ``mat`` is neither a torch tensor nor a numpy array.
    """
    backends = (
        (torch.Tensor, torch.linalg.inv),
        (np.ndarray, np.linalg.inv),
    )
    for array_type, invert in backends:
        if isinstance(mat, array_type):
            return invert(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
125
+
126
+
127
def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
    """
    Back-project depth maps to 3-D points using per-pixel focals.

    Args:
        - depth: BxHxW tensor, or BxHxWxn for n depth values per pixel
        - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
        - pp: optional principal points, indexed as pp[:, 0] (x) and
          pp[:, 1] (y); defaults to the image centre ((W-1)/2, (H-1)/2)
    Returns:
        pointmap of absolute coordinates (BxHxWx3 tensor, or BxHxWx3xn),
        allocated with torch's default dtype on ``depth.device``
    """
    if len(depth.shape) == 4:
        B, H, W, n = depth.shape
    else:
        B, H, W = depth.shape
        n = None

    focal_rank = len(pseudo_focal.shape)
    if focal_rank == 3:  # [B,H,W]
        pseudo_focalx = pseudo_focaly = pseudo_focal
    elif focal_rank == 4:  # [B,2,H,W] or [B,1,H,W]
        pseudo_focalx = pseudo_focal[:, 0]
        has_separate_y = pseudo_focal.shape[1] == 2
        pseudo_focaly = pseudo_focal[:, 1] if has_separate_y else pseudo_focalx
    else:
        raise NotImplementedError("Error, unknown input focal shape format.")

    assert pseudo_focalx.shape == depth.shape[:3]
    assert pseudo_focaly.shape == depth.shape[:3]
    grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]

    # Centre the pixel grid on the principal point.
    if pp is not None:
        grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
        grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
    else:
        grid_x = grid_x - (W - 1) / 2
        grid_y = grid_y - (H - 1) / 2

    if n is not None:
        pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
        pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
        pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
        pts3d[..., 2, :] = depth
        return pts3d

    pts3d = torch.empty((B, H, W, 3), device=depth.device)
    pts3d[..., 0] = depth * grid_x / pseudo_focalx
    pts3d[..., 1] = depth * grid_y / pseudo_focaly
    pts3d[..., 2] = depth
    return pts3d
175
+
176
+
177
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Back-project a depth map to 3-D points in the camera frame.

    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix without skew
        - pseudo_focal: optional HxW array of per-pixel focals that replaces
          both focal lengths of ``camera_intrinsics``
    Returns:
        pointmap in camera coordinates (HxWx3 float32 array), and a mask
        specifying valid pixels (depth > 0).
    """
    K = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Skew is not supported.
    assert K[0, 1] == 0.0
    assert K[1, 0] == 0.0

    if pseudo_focal is not None:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    else:
        fu, fv = K[0, 0], K[1, 1]
    cu, cv = K[0, 2], K[1, 2]

    cols, rows = np.meshgrid(np.arange(W), np.arange(H))
    x_cam = (cols - cu) * depthmap / fu
    y_cam = (rows - cv) * depthmap / fv
    points = np.stack((x_cam, y_cam, depthmap), axis=-1).astype(np.float32)

    return points, depthmap > 0.0
206
+
207
+
208
def depthmap_to_absolute_camera_coordinates(
    depthmap, camera_intrinsics, camera_pose, **kw
):
    """
    Back-project a depth map and move the points to world coordinates.

    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a 4x3 or 4x4 cam2world matrix, or None to stay in
          the camera frame
        - **kw: accepted but not used
    Returns:
        pointmap of absolute coordinates (HxWx3 array), the pointmap in
        camera coordinates, and a mask specifying valid pixels.
    """
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    if camera_pose is None:
        # No pose given: world and camera frames coincide.
        return X_cam, X_cam, valid_mask

    rotation = camera_pose[:3, :3]
    translation = camera_pose[:3, 3]
    # Rotate every pixel's point, then translate.
    rotated = np.einsum("ik, vuk -> vui", rotation, X_cam)
    X_world = rotated + translation[None, None, :]

    return X_world, X_cam, valid_mask
232
+
233
+
234
def colmap_to_opencv_intrinsics(K):
    """
    Return a copy of the intrinsics converted from the Colmap to the OpenCV
    pixel convention (principal point shifted by -0.5 on both axes).

    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    converted = K.copy()
    for row in (0, 1):
        converted[row, 2] -= 0.5
    return converted
245
+
246
+
247
def opencv_to_colmap_intrinsics(K):
    """
    Return a copy of the intrinsics converted from the OpenCV to the Colmap
    pixel convention (principal point shifted by +0.5 on both axes).

    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    converted = K.copy()
    for row in (0, 1):
        converted[row, 2] += 0.5
    return converted
258
+
259
+
260
def normalize_pointcloud(
    pts1, pts2, norm_mode="avg_dis", valid1=None, valid2=None, ret_factor=False
):
    """Renormalize pointmaps pts1 (and optionally pts2) with a shared factor.

    Args:
        pts1: pointmap with last dimension 3 and at least 3 dimensions.
        pts2: second pointmap or None.
        norm_mode: ``"<norm>_<dis>"``. ``<norm>`` is ``avg``, ``median`` or
            ``sqrt``; ``<dis>`` (``dis``, ``log1p`` or ``warp-log1p``) is only
            used by the ``avg`` mode.
        valid1, valid2: optional validity masks forwarded to
            ``invalid_to_zeros`` / ``invalid_to_nans``.
        ret_factor: also return the normalization factor.

    Returns:
        ``pts1 / factor`` when ``pts2`` is None, otherwise the tuple
        ``(pts1 / factor, pts2 / factor)``; the factor is appended as last
        tuple element when ``ret_factor`` is True.

    Raises:
        ValueError: for an unknown norm or distance mode.
    """
    assert pts1.ndim >= 3 and pts1.shape[-1] == 3
    assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split("_")

    if norm_mode == "avg":
        # Invalid points are zeroed so they do not contribute to the sum;
        # nnz* is used as the denominator of the mean.
        nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
        nan_pts2, nnz2 = (
            invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
        )
        all_pts = (
            torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
        )

        all_dis = all_pts.norm(dim=-1)
        if dis_mode == "dis":
            pass  # do nothing
        elif dis_mode == "log1p":
            all_dis = torch.log1p(all_dis)
        elif dis_mode == "warp-log1p":
            # Warp the input points so that their distance becomes log1p(d).
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            H1, W1 = pts1.shape[1:-1]
            pts1 = pts1 * warp_factor[:, : W1 * H1].view(-1, H1, W1, 1)
            if pts2 is not None:
                H2, W2 = pts2.shape[1:-1]
                pts2 = pts2 * warp_factor[:, W1 * H1 :].view(-1, H2, W2, 1)
            all_dis = log_dis  # this is their true distance afterwards
        else:
            raise ValueError(f"bad {dis_mode=}")

        norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
    else:
        # Invalid points become NaN and are skipped by the nan-aware stats.
        nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
        nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
        all_pts = (
            torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
        )

        all_dis = all_pts.norm(dim=-1)

        if norm_mode == "median":
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == "sqrt":
            norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2
        else:
            raise ValueError(f"bad {norm_mode=}")

    norm_factor = norm_factor.clip(min=1e-8)
    # Make the factor broadcastable against the pointmaps.
    while norm_factor.ndim < pts1.ndim:
        norm_factor.unsqueeze_(-1)

    normed1 = pts1 / norm_factor
    if pts2 is None:
        # Single pointmap: build the tuple explicitly instead of adding a
        # tuple to a tensor.
        return (normed1, norm_factor) if ret_factor else normed1

    res = (normed1, pts2 / norm_factor)
    if ret_factor:
        res = res + (norm_factor,)
    return res
326
+
327
+
328
def normalize_pointcloud_group(
    pts_list,
    norm_mode="avg_dis",
    valid_list=None,
    conf_list=None,
    ret_factor=False,
    ret_factor_only=False,
):
    """Renormalize a group of pointmaps with one shared factor.

    Args:
        pts_list: list of pointmaps, each with last dimension 3 and at least
            3 dimensions.
        norm_mode: ``"<norm>_<dis>"``. ``<norm>`` is ``avg``, ``median`` or
            ``sqrt``; ``<dis>`` (``dis``, ``log1p`` or ``warp-log1p``) is only
            used by the ``avg`` mode.
        valid_list: optional list of validity masks (one per pointmap).
            ``None`` means no mask for any pointmap -- this assumes the
            masking helpers accept a ``None`` mask, as
            ``normalize_pointcloud`` already relies on (TODO confirm).
        conf_list: optional per-point confidences used as weights of the
            average in ``avg`` mode.
        ret_factor: also return the normalization factor.
        ret_factor_only: return only the normalization factor.

    Returns:
        The list of normalized pointmaps, ``(list, factor)`` when
        ``ret_factor`` is True, or the factor alone when ``ret_factor_only``.

    Raises:
        ValueError: for an unknown norm or distance mode.
    """
    for pts in pts_list:
        assert pts.ndim >= 3 and pts.shape[-1] == 3

    if valid_list is None:
        valid_list = [None] * len(pts_list)

    norm_mode, dis_mode = norm_mode.split("_")

    if norm_mode == "avg":
        # Invalid points are zeroed so they do not contribute to the sum.
        nan_pts_list, nnz_list = zip(
            *[
                invalid_to_zeros(pts1, valid1, ndim=3)
                for pts1, valid1 in zip(pts_list, valid_list)
            ]
        )
        all_pts = torch.cat(nan_pts_list, dim=1)
        if conf_list is not None:
            nan_conf_list = [
                invalid_to_zeros(conf1[..., None], valid1, ndim=3)[0]
                for conf1, valid1 in zip(conf_list, valid_list)
            ]
            all_conf = torch.cat(nan_conf_list, dim=1)[..., 0]
        else:
            all_conf = torch.ones_like(all_pts[..., 0])

        all_dis = all_pts.norm(dim=-1)
        if dis_mode == "dis":
            pass  # do nothing
        elif dis_mode == "log1p":
            all_dis = torch.log1p(all_dis)
        elif dis_mode == "warp-log1p":
            # Warp the input points so that their distance becomes log1p(d).
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            H_W_list = [pts.shape[1:-1] for pts in pts_list]
            # Start offset of every pointmap inside the flattened point axis
            # (each pointmap contributes H * W points).
            offsets = [0]
            for H, W in H_W_list:
                offsets.append(offsets[-1] + H * W)
            pts_list = [
                pts * warp_factor[:, offsets[i] : offsets[i + 1]].view(-1, H, W, 1)
                for i, (pts, (H, W)) in enumerate(zip(pts_list, H_W_list))
            ]
            all_dis = log_dis  # this is their true distance afterwards
        else:
            raise ValueError(f"bad {dis_mode=}")

        # Confidence-weighted mean distance.
        norm_factor = (all_conf * all_dis).sum(dim=1) / (all_conf.sum(dim=1) + 1e-8)
    else:
        # Invalid points become NaN and are skipped by the nan-aware stats.
        nan_pts_list = [
            invalid_to_nans(pts1, valid1, ndim=3)
            for pts1, valid1 in zip(pts_list, valid_list)
        ]

        all_pts = torch.cat(nan_pts_list, dim=1)

        all_dis = all_pts.norm(dim=-1)

        if norm_mode == "median":
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == "sqrt":
            norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2
        else:
            raise ValueError(f"bad {norm_mode=}")

    norm_factor = norm_factor.clip(min=1e-8)
    # Make the factor broadcastable against the pointmaps.
    while norm_factor.ndim < pts_list[0].ndim:
        norm_factor.unsqueeze_(-1)

    if ret_factor_only:
        return norm_factor

    res = [pts / norm_factor for pts in pts_list]
    if ret_factor:
        return res, norm_factor
    return res
414
+
415
+
416
@torch.no_grad()
def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
    """Return a per-batch depth statistic over one or two depth maps.

    Invalid entries (per ``valid_mask*``) are turned into NaNs and ignored.
    Each map is flattened to ``(len(z), -1)`` and, when ``z2`` is given, the
    two are concatenated along the last dimension.

    ``quantile == 0.5`` uses ``torch.nanmedian``; any other value uses
    ``torch.nanquantile``. The result has one entry per batch element, (B,).
    """
    flat_depths = [invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)]
    if z2 is not None:
        flat_depths.append(invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1))

    if len(flat_depths) == 1:
        joint = flat_depths[0]
    else:
        joint = torch.cat(flat_depths, dim=-1)

    if quantile == 0.5:
        return torch.nanmedian(joint, dim=-1).values  # (B,)
    return torch.nanquantile(joint, quantile, dim=-1)  # (B,)
432
+
433
+
434
@torch.no_grad()
def get_group_pointcloud_depth(zs, valid_masks, quantile=0.5):
    """Return a per-batch depth statistic over a group of depth maps.

    ``zs`` and ``valid_masks`` are iterated in lockstep; each depth map has its
    invalid entries replaced by NaN, is flattened to ``(len(z), -1)``, and all
    maps are concatenated along the last dimension.

    ``quantile == 0.5`` uses ``torch.nanmedian``; any other value uses
    ``torch.nanquantile``. The result has one entry per batch element, (B,).
    """
    flattened = []
    for depth, mask in zip(zs, valid_masks):
        flattened.append(invalid_to_nans(depth, mask).reshape(len(depth), -1))
    joint = torch.cat(flattened, dim=-1)

    if quantile == 0.5:
        return torch.nanmedian(joint, dim=-1).values  # (B,)
    return torch.nanquantile(joint, quantile, dim=-1)  # (B,)
448
+
449
+
450
@torch.no_grad()
def get_joint_pointcloud_center_scale(
    pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True
):
    """Compute a robust center and scale for one or two point maps.

    Invalid points are replaced by NaN and ignored by the NaN-aware medians.
    Each point map is reshaped to ``(len(pts), -1, 3)``; when ``pts2`` is given
    both are concatenated along the point dimension.

    z_only: zero the X and Y components of the center (only Z is centered).
    center: when False, the scale is the median norm of the raw points rather
        than of the points minus the center (the center is still returned).

    Returns ``(center, scale)`` indexed to shapes (B,1,1,3) and (B,1,1,1).
    """
    clouds = [invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)]
    if pts2 is not None:
        clouds.append(invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3))
    all_pts = clouds[0] if len(clouds) == 1 else torch.cat(clouds, dim=1)

    median_pt = torch.nanmedian(all_pts, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        median_pt[..., :2] = 0  # do not center X and Y

    offsets = all_pts - median_pt if center else all_pts
    scale = torch.nanmedian(offsets.norm(dim=-1), dim=1).values
    return median_pt[:, None, :, :], scale[:, None, None, None]
470
+
471
+
472
@torch.no_grad()
def get_group_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
    """Compute a robust center and scale for a group of point maps.

    pts: sequence of point maps, each reshaped to ``(len(p), -1, 3)``.
    valid_masks: sequence of masks aligned with ``pts`` (entries may be None),
        or None to treat every point of every map as valid.
    z_only: zero the X and Y components of the center (only Z is centered).
    center: when False, the scale is the median norm of the raw points rather
        than of the points minus the center (the center is still returned).

    Returns ``(center, scale)`` indexed to shapes (B,1,1,3) and (B,1,1,1).
    """
    # The signature advertises valid_masks=None, but zip(pts, None) raises a
    # TypeError; expand None to one "no mask" entry per point map instead.
    if valid_masks is None:
        valid_masks = [None] * len(pts)

    _pts = [
        invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
        for pts1, valid_mask1 in zip(pts, valid_masks)
    ]
    _pts = torch.cat(_pts, dim=1)

    _center = torch.nanmedian(_pts, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        _center[..., :2] = 0  # do not center X and Y

    _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
    scale = torch.nanmedian(_norm, dim=1).values
    return _center[:, None, :, :], scale[:, None, None, None]
488
+
489
+
490
def find_reciprocal_matches(P1, P2):
    """
    Find mutual nearest neighbours between two point sets.

    returns 3 values:
    1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
    2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
    3 - reciprocal_in_P2.sum(): the number of matches
    """
    # Nearest neighbour of every P2 point in P1, and of every P1 point in P2.
    _, nn2_in_P1 = KDTree(P1).query(P2, workers=8)
    _, nn1_in_P2 = KDTree(P2).query(P1, workers=8)

    # A match is reciprocal when following both links returns to the start.
    mutual_in_P1 = nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2))
    mutual_in_P2 = nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1))

    num_matches = mutual_in_P2.sum()
    assert mutual_in_P1.sum() == num_matches
    return mutual_in_P2, nn2_in_P1, num_matches
507
+
508
+
509
def get_med_dist_between_poses(poses):
    """Return the median pairwise distance between the ``p[:3, 3]`` columns of *poses*."""
    from scipy.spatial.distance import pdist

    positions = [to_numpy(pose[:3, 3]) for pose in poses]
    pairwise = pdist(positions)
    return np.median(pairwise)
513
+
514
+
515
def weighted_procrustes(A, B, w, use_weights=True, eps=1e-16, return_T=False):
    """
    Solve for the rigid transform (R, t) that maps A onto B, i.e. B ~ R @ A + t.

    A: torch tensor B x N x 3 (source points)
    B: torch tensor B x N x 3 (target points)
    w: torch tensor B x N (per-point weights, only read when use_weights is True)
    eps: added to the weight sum before normalising to avoid a division by zero
    return_T: when True, return a single B x 4 x 4 homogeneous transform

    Returns (R, t) with R of shape B x 3 x 3. ``t`` goes through ``squeeze()``,
    so it is B x 3 in general but loses its batch dimension for a batch of 1.
    """
    assert len(A) == len(B)
    if use_weights:
        W1 = torch.abs(w).sum(1, keepdim=True)
        w_norm = (w / (W1 + eps)).unsqueeze(-1)
        a_mean = (w_norm * A).sum(dim=1, keepdim=True)
        b_mean = (w_norm * B).sum(dim=1, keepdim=True)

        A_c = A - a_mean
        B_c = B - b_mean

        # Weighted cross-covariance between the centred point sets.
        H = torch.einsum("bni,bnj->bij", A_c, w_norm * B_c)

    else:
        a_mean = A.mean(axis=1, keepdim=True)
        b_mean = B.mean(axis=1, keepdim=True)

        A_c = A - a_mean
        B_c = B - b_mean

        H = torch.einsum("bni,bnj->bij", A_c, B_c)

    # torch.linalg.svd returns V^H; transpose back to the V convention used below.
    U, _, Vh = torch.linalg.svd(H)  # U: B x 3 x 3, Vh: B x 3 x 3
    V = Vh.transpose(1, 2)

    # Build the helper matrices in A's dtype/device: a default float32 identity
    # makes the matmul below fail for float64 inputs.
    Z = torch.eye(3, dtype=A.dtype, device=A.device).unsqueeze(0).repeat(A.shape[0], 1, 1)
    # Flip the last axis when needed so that R is a proper rotation (det = +1).
    Z[:, -1, -1] = torch.sign(torch.linalg.det(U @ V.transpose(1, 2)))  # B x 3 x 3
    R = V @ Z @ U.transpose(1, 2)  # B x 3 x 3
    t = b_mean - torch.einsum("bij,bjk->bik", R, a_mean.transpose(-2, -1)).transpose(
        -2, -1
    )
    if return_T:
        T = torch.eye(4, dtype=A.dtype, device=A.device).unsqueeze(0).repeat(A.shape[0], 1, 1)
        T[:, :3, :3] = R
        T[:, :3, 3] = t.squeeze(1)
        return T
    return R, t.squeeze()
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/image.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # modified from DUSt3R
6
+
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+ import PIL.Image
11
+ from PIL.ImageOps import exif_transpose
12
+ import torchvision.transforms as tvf
13
+
14
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
15
+ import cv2 # noqa
16
+
17
+ try:
18
+ from pillow_heif import register_heif_opener # noqa
19
+
20
+ register_heif_opener()
21
+ heif_support_enabled = True
22
+ except ImportError:
23
+ heif_support_enabled = False
24
+
25
+ ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
26
+
27
+
28
def img_to_arr(img):
    """Load *img* with ``imread_cv2`` when it is a path string; otherwise return it unchanged."""
    return imread_cv2(img) if isinstance(img, str) else img
32
+
33
+
34
def imread_cv2(path, options=cv2.IMREAD_COLOR):
    """Open an image or a depthmap with opencv-python.

    path: file path; a ``.exr`` extension (any letter case) forces
        ``cv2.IMREAD_ANYDEPTH`` regardless of ``options``.
    options: flags forwarded to ``cv2.imread`` for all other files.

    Returns the loaded array; 3-dimensional results are converted from BGR to RGB.
    Raises IOError when ``cv2.imread`` returns None.
    """
    # The previous test was endswith((".exr", "EXR")): it missed mixed-case
    # extensions and matched any name merely ending in "EXR" without a dot.
    if path.lower().endswith(".exr"):
        options = cv2.IMREAD_ANYDEPTH
    img = cv2.imread(path, options)
    if img is None:
        raise IOError(f"Could not load image={path} with {options=}")
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
44
+
45
+
46
def imread_pil(path):
    """Open an RGB image using PIL and return as numpy array.

    The EXIF orientation is applied (``exif_transpose``) before converting to RGB.
    """
    # Use the image as a context manager so the underlying file is closed
    # deterministically instead of whenever the object is garbage-collected.
    with PIL.Image.open(path) as img:
        img = exif_transpose(img)
        img = img.convert("RGB")
        return np.array(img)
52
+
53
+
54
def rgb(ftensor, true_shape=None):
    """Convert an image (or list of images) to a float array clipped to [0, 1].

    ftensor: a list (handled element-wise), a torch tensor (moved to CPU numpy)
        or a numpy array. Arrays shaped (3, H, W) or (B, 3, H, W) are transposed
        to channels-last.
    true_shape: optional ``(H, W)``; the array is sliced with ``[:H, :W]`` on its
        first two axes after the transpose.

    uint8 data is divided by 255; anything else is mapped with ``x * 0.5 + 0.5``.
    """
    if isinstance(ftensor, list):
        return [rgb(item, true_shape=true_shape) for item in ftensor]

    arr = ftensor
    if isinstance(arr, torch.Tensor):
        arr = arr.detach().cpu().numpy()  # H,W,3

    if arr.ndim == 3 and arr.shape[0] == 3:
        arr = arr.transpose(1, 2, 0)
    elif arr.ndim == 4 and arr.shape[1] == 3:
        arr = arr.transpose(0, 2, 3, 1)

    if true_shape is not None:
        H, W = true_shape
        arr = arr[:H, :W]

    scaled = np.float32(arr) / 255 if arr.dtype == np.uint8 else (arr * 0.5) + 0.5
    return scaled.clip(min=0, max=1)
71
+
72
+
73
def _resize_pil_image(img, long_edge_size):
    """Resize *img* so that its longest side becomes ``long_edge_size``.

    Downscaling uses LANCZOS resampling; same-size or upscaling uses BICUBIC.
    Both sides are scaled by the same factor and rounded to integers.
    """
    longest = max(img.size)
    resample = PIL.Image.LANCZOS if longest > long_edge_size else PIL.Image.BICUBIC
    target = tuple(int(round(side * long_edge_size / longest)) for side in img.size)
    return img.resize(target, resample)
81
+
82
+
83
def load_images(folder_or_list, size, square_ok=False, verbose=True):
    """open and convert all images in a list or folder to proper input format for DUSt3R

    folder_or_list: a directory path (its sorted listing is used) or a list of
        image paths. Entries without a supported extension are skipped.
    size: 224 resizes the short side to 224 and center-crops a square; any other
        value resizes the long side to ``size`` and center-crops so that both
        sides are multiples of 16.
    square_ok: when False, square images are cropped to a 4:3 aspect ratio
        (non-224 path only).
    verbose: print progress messages.

    Returns a list of dicts with keys ``img`` (normalised tensor with a leading
    batch axis), ``true_shape`` (int32 ``[[H, W]]``), ``idx`` and ``instance``.
    Raises ValueError on an unsupported ``folder_or_list`` type and
    AssertionError when no image was loaded.
    """
    if isinstance(folder_or_list, str):
        if verbose:
            print(f">> Loading images from {folder_or_list}")
        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))

    elif isinstance(folder_or_list, list):
        if verbose:
            print(f">> Loading a list of {len(folder_or_list)} images")
        root, folder_content = "", folder_or_list

    else:
        raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")

    supported_images_extensions = [".jpg", ".jpeg", ".png", ".bmp"]
    if heif_support_enabled:
        supported_images_extensions += [".heic", ".heif"]
    supported_images_extensions = tuple(supported_images_extensions)

    imgs = []
    for path in folder_content:
        if not path.lower().endswith(supported_images_extensions):
            continue
        # Close the file as soon as the pixels are decoded; exif_transpose and
        # convert both return in-memory images that do not need the file.
        with PIL.Image.open(os.path.join(root, path)) as raw:
            img = exif_transpose(raw).convert("RGB")
        W1, H1 = img.size
        if size == 224:
            # resize short side to 224 (then crop)
            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
        else:
            # resize long side to `size`
            img = _resize_pil_image(img, size)
        W, H = img.size
        cx, cy = W // 2, H // 2
        if size == 224:
            half = min(cx, cy)
            img = img.crop((cx - half, cy - half, cx + half, cy + half))
        else:
            # half-extents are multiples of 8, so the crop sides are multiples of 16
            halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
            if not (square_ok) and W == H:
                halfh = 3 * halfw / 4
            img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))

        W2, H2 = img.size
        if verbose:
            print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
        imgs.append(
            dict(
                img=ImgNorm(img)[None],
                true_shape=np.int32([img.size[::-1]]),
                idx=len(imgs),
                instance=str(len(imgs)),
            )
        )

    assert imgs, "no images foud at " + root
    if verbose:
        print(f" (Found {len(imgs)} images)")
    return imgs
142
+
143
+
144
def load_images_for_eval(
    folder_or_list, size, square_ok=False, verbose=True, crop=True
):
    """open and convert all images in a list or folder to proper input format for DUSt3R

    folder_or_list: a directory path (its sorted listing is used) or a list of
        image paths. Entries without a supported extension are skipped.
    size: 224 targets a square around the image center; any other value resizes
        the long side to ``size`` and targets sides that are multiples of 14.
    square_ok: when False, square images get a 4:3 aspect ratio (non-224 path).
    verbose: print progress messages.
    crop: when True the target region is center-cropped; when False the whole
        image is resized to the target size instead.

    Returns a list of dicts with keys ``img`` (normalised tensor with a leading
    batch axis), ``true_shape`` (int32 ``[[H, W]]``), ``idx`` and ``instance``.
    Raises ValueError on an unsupported ``folder_or_list`` type and
    AssertionError when no image was loaded.
    """
    if isinstance(folder_or_list, str):
        if verbose:
            print(f">> Loading images from {folder_or_list}")
        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))

    elif isinstance(folder_or_list, list):
        if verbose:
            print(f">> Loading a list of {len(folder_or_list)} images")
        root, folder_content = "", folder_or_list

    else:
        raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")

    supported_images_extensions = [".jpg", ".jpeg", ".png"]
    if heif_support_enabled:
        supported_images_extensions += [".heic", ".heif"]
    supported_images_extensions = tuple(supported_images_extensions)

    imgs = []
    for path in folder_content:
        if not path.lower().endswith(supported_images_extensions):
            continue
        # Close the file as soon as the pixels are decoded; exif_transpose and
        # convert both return in-memory images that do not need the file.
        with PIL.Image.open(os.path.join(root, path)) as raw:
            img = exif_transpose(raw).convert("RGB")
        W1, H1 = img.size
        if size == 224:
            # resize short side to 224 (then crop)
            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
        else:
            # resize long side to `size`
            img = _resize_pil_image(img, size)
        W, H = img.size
        cx, cy = W // 2, H // 2
        if size == 224:
            half = min(cx, cy)
            if crop:
                img = img.crop((cx - half, cy - half, cx + half, cy + half))
            else:  # resize
                img = img.resize((2 * half, 2 * half), PIL.Image.LANCZOS)
        else:
            # half-extents are multiples of 7, so the target sides are multiples of 14
            halfw, halfh = ((2 * cx) // 14) * 7, ((2 * cy) // 14) * 7
            if not (square_ok) and W == H:
                halfh = 3 * halfw / 4
            if crop:
                img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
            else:  # resize
                # halfh is a float after the 4:3 adjustment above, and
                # Image.resize requires integer sizes, so round explicitly.
                img = img.resize(
                    (2 * halfw, int(round(2 * halfh))), PIL.Image.LANCZOS
                )
        W2, H2 = img.size
        if verbose:
            print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
        imgs.append(
            dict(
                img=ImgNorm(img)[None],
                true_shape=np.int32([img.size[::-1]]),
                idx=len(imgs),
                instance=str(len(imgs)),
            )
        )

    assert imgs, "no images foud at " + root
    if verbose:
        print(f" (Found {len(imgs)} images)")
    return imgs
210
+
211
+
212
def load_images_512(folder_or_list, size, square_ok=False, verbose=True):
    """open and convert all images in a list or folder to proper input format for DUSt3R

    Same pipeline as ``load_images`` except that every image is first resized
    to 512x384 regardless of its original aspect ratio.

    folder_or_list: a directory path (its sorted listing is used) or a list of
        image paths. Entries without a supported extension are skipped.
    size: 224 resizes the short side to 224 and center-crops a square; any other
        value resizes the long side to ``size`` and center-crops so that both
        sides are multiples of 16.
    square_ok: when False, square images are cropped to a 4:3 aspect ratio
        (non-224 path only).
    verbose: print progress messages.

    Returns a list of dicts with keys ``img`` (normalised tensor with a leading
    batch axis), ``true_shape`` (int32 ``[[H, W]]``), ``idx`` and ``instance``.
    Raises ValueError on an unsupported ``folder_or_list`` type and
    AssertionError when no image was loaded.
    """
    if isinstance(folder_or_list, str):
        if verbose:
            print(f">> Loading images from {folder_or_list}")
        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))

    elif isinstance(folder_or_list, list):
        if verbose:
            print(f">> Loading a list of {len(folder_or_list)} images")
        root, folder_content = "", folder_or_list

    else:
        raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")

    supported_images_extensions = [".jpg", ".jpeg", ".png", ".bmp"]
    if heif_support_enabled:
        supported_images_extensions += [".heic", ".heif"]
    supported_images_extensions = tuple(supported_images_extensions)

    imgs = []
    for path in folder_content:
        if not path.lower().endswith(supported_images_extensions):
            continue
        # Close the file as soon as the pixels are decoded; exif_transpose and
        # convert both return in-memory images that do not need the file.
        with PIL.Image.open(os.path.join(root, path)) as raw:
            img = exif_transpose(raw).convert("RGB")
        img = img.resize((512, 384))
        # NOTE: W1/H1 are read after the forced resize, so the "resolution"
        # printed below is always 512x384, not the size of the file on disk.
        W1, H1 = img.size
        if size == 224:
            # resize short side to 224 (then crop)
            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
        else:
            # resize long side to `size`
            img = _resize_pil_image(img, size)
        W, H = img.size
        cx, cy = W // 2, H // 2
        if size == 224:
            half = min(cx, cy)
            img = img.crop((cx - half, cy - half, cx + half, cy + half))
        else:
            # half-extents are multiples of 8, so the crop sides are multiples of 16
            halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
            if not (square_ok) and W == H:
                halfh = 3 * halfw / 4
            img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))

        W2, H2 = img.size
        if verbose:
            print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
        imgs.append(
            dict(
                img=ImgNorm(img)[None],
                true_shape=np.int32([img.size[::-1]]),
                idx=len(imgs),
                instance=str(len(imgs)),
            )
        )

    assert imgs, "no images foud at " + root
    if verbose:
        print(f" (Found {len(imgs)} images)")
    return imgs
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/misc.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # modified from DUSt3R
6
+
7
+ import torch
8
+
9
+
10
def fill_default_args(kwargs, func):
    """Complete *kwargs* in place with the default values declared by *func*.

    Parameters of ``func`` without a default are skipped and keys already
    present in ``kwargs`` are kept. The same dict is returned.
    """
    import inspect  # a bit hacky but it works reliably

    for name, param in inspect.signature(func).parameters.items():
        if param.default is not inspect.Parameter.empty:
            kwargs.setdefault(name, param.default)
    return kwargs
21
+
22
+
23
def freeze_all_params(modules):
    """Disable gradients for every entry of *modules*.

    Entries exposing ``named_parameters()`` have each parameter frozen; for any
    entry raising AttributeError (e.g. a bare parameter), ``requires_grad`` is
    cleared on the entry itself.
    """
    for entry in modules:
        try:
            for _, parameter in entry.named_parameters():
                parameter.requires_grad = False
        except AttributeError:
            entry.requires_grad = False
31
+
32
+
33
def is_symmetrized(gt1, gt2):
    """Tell whether a batch is made of consecutive swapped pairs.

    The ``"instance"`` sequences of ``gt1`` and ``gt2`` are compared pairwise:
    the batch is symmetrized when, for every even index ``i``,
    ``x[i] == y[i + 1]`` and ``x[i + 1] == y[i]``.

    Returns False for a batch of size 1 and for batches that cannot be split
    into pairs (odd length, or ``gt2`` shorter than ``gt1``).
    """
    x = gt1["instance"]
    y = gt2["instance"]
    if len(x) == len(y) and len(x) == 1:
        return False  # special case of batchsize 1
    # A batch that cannot be split into pairs is not symmetrized; indexing
    # i + 1 on it used to raise IndexError.
    if len(x) % 2 or len(y) < len(x):
        return False
    return all(
        (x[i] == y[i + 1]) and (x[i + 1] == y[i]) for i in range(0, len(x), 2)
    )
42
+
43
+
44
def flip(tensor):
    """flip so that tensor[0::2] <=> tensor[1::2]"""
    even_items = tensor[0::2]
    odd_items = tensor[1::2]
    # Stack each (odd, even) pair, then flatten the pair axis back into dim 0.
    swapped_pairs = torch.stack((odd_items, even_items), dim=1)
    return swapped_pairs.flatten(0, 1)
47
+
48
+
49
def interleave(tensor1, tensor2):
    """Interleave two tensors along dim 0, in both orders.

    Returns ``(res1, res2)`` where ``res1`` alternates tensor1/tensor2 entries
    and ``res2`` alternates tensor2/tensor1 entries.
    """
    first_then_second = torch.stack((tensor1, tensor2), dim=1)
    second_then_first = torch.stack((tensor2, tensor1), dim=1)
    return first_then_second.flatten(0, 1), second_then_first.flatten(0, 1)
53
+
54
+
55
def transpose_to_landscape(head, activate=True):
    """Predict in the correct aspect-ratio,
    then transpose the result in landscape
    and stack everything back together.

    head: callable invoked as ``head(decout, (H, W), **kwargs)`` and returning a
        dict of tensors.
    activate: when True return the orientation-aware wrapper, otherwise a
        wrapper that requires all ``true_shape`` rows to be identical.
    """

    def wrapper_no(decout, true_shape, **kwargs):
        assert true_shape[0:1].allclose(true_shape), "true_shape must be all identical"
        H, W = true_shape[0].cpu().tolist()
        res = head(decout, (H, W), **kwargs)
        return res

    def wrapper_yes(decout, true_shape, **kwargs):
        B = len(true_shape)

        # by definition, the batch is in landscape mode so W >= H
        H, W = int(true_shape.min()), int(true_shape.max())

        height, width = true_shape.T
        is_landscape = width >= height
        is_portrait = ~is_landscape

        if is_landscape.all():
            return head(decout, (H, W), **kwargs)
        if is_portrait.all():
            return transposed(head(decout, (W, H), **kwargs))

        # batch is a mix of both portrait & landscape
        def selout(ar):
            return [d[ar] for d in decout]

        # Always define the per-orientation kwargs; they used to be created
        # only when "pos" was passed, raising NameError otherwise.
        kwargs_landscape = kwargs.copy()
        kwargs_portrait = kwargs.copy()
        if "pos" in kwargs:
            kwargs_landscape["pos"] = kwargs["pos"][is_landscape]
            kwargs_portrait["pos"] = kwargs["pos"][is_portrait]
        l_result = head(selout(is_landscape), (H, W), **kwargs_landscape)
        p_result = transposed(head(selout(is_portrait), (W, H), **kwargs_portrait))

        # Scatter both partial results back into full-batch tensors.
        result = {}
        for k in l_result | p_result:
            x = l_result[k].new(B, *l_result[k].shape[1:])
            x[is_landscape] = l_result[k]
            x[is_portrait] = p_result[k]
            result[k] = x

        return result

    return wrapper_yes if activate else wrapper_no
103
+
104
+
105
def transposed(dic):
    """Return a new dict where every value with more than 2 dims has axes 1 and 2 swapped."""
    swapped = {}
    for key, value in dic.items():
        swapped[key] = value.swapaxes(1, 2) if value.ndim > 2 else value
    return swapped
107
+
108
+
109
def invalid_to_nans(arr, valid_mask, ndim=999):
    """Return *arr* with entries outside ``valid_mask`` replaced by NaN.

    The input is cloned before masking, so it is never modified; with
    ``valid_mask=None`` it is passed through unmasked. When the array has more
    than ``ndim`` dimensions, the extra ones are flattened into the
    second-to-last axis.
    """
    result = arr
    if valid_mask is not None:
        result = arr.clone()
        result[~valid_mask] = float("nan")
    extra_dims = result.ndim - ndim
    if extra_dims > 0:
        result = result.flatten(-2 - extra_dims, -2)
    return result
116
+
117
+
118
def invalid_to_zeros(arr, valid_mask, ndim=999):
    """Return ``(arr, nnz)`` with entries outside ``valid_mask`` set to zero.

    The input is cloned before masking. ``nnz`` is the per-batch count of True
    mask entries; with ``valid_mask=None`` it is ``arr.numel() // len(arr)``
    (0 for an empty batch). When the array has more than ``ndim`` dimensions,
    the extra ones are flattened into the second-to-last axis.
    """
    if valid_mask is None:
        nnz = arr.numel() // len(arr) if len(arr) else 0  # number of point per image
    else:
        arr = arr.clone()
        arr[~valid_mask] = 0
        nnz = valid_mask.view(len(valid_mask), -1).sum(1)

    extra_dims = arr.ndim - ndim
    if extra_dims > 0:
        arr = arr.flatten(-2 - extra_dims, -2)
    return arr, nnz
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/parallel.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # modified from DUSt3R
6
+
7
+ from tqdm import tqdm
8
+ from multiprocessing.dummy import Pool as ThreadPool
9
+ from multiprocessing import cpu_count
10
+
11
+
12
def parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=ThreadPool,
    **tqdm_kw
):
    """tqdm but with parallel execution.

    Will essentially return
      res = [ function(arg) # default
              function(*arg) # if star_args is True
              function(**arg) # if kw_args is True
              for arg in args]

    Note:
        the <front_num> first elements of args will not be parallelized.
        This can be useful for debugging.

    A non-positive ``workers`` is increased by ``cpu_count()`` until positive;
    with a single worker everything runs sequentially in the calling thread.
    Extra keyword arguments are forwarded to ``tqdm``.
    """

    def call_inline(arg):
        if star_args:
            return function(*arg)
        if kw_args:
            return function(**arg)
        return function(arg)

    while workers <= 0:
        workers += cpu_count()
    if workers == 1:
        front_num = float("inf")

    # Progress-bar total; unknown when args has no len().
    try:
        n_args_parallel = len(args) - front_num
    except TypeError:
        n_args_parallel = None
    remaining = iter(args)

    # Sequentially run the first <front_num> elements.
    front = []
    while len(front) < front_num:
        try:
            arg = next(remaining)
        except StopIteration:
            return front  # end of the iterable
        front.append(call_inline(arg))

    # Run whatever is left in the pool, preserving input order.
    with Pool(workers) as pool:
        if star_args:
            futures = pool.imap(starcall, [(function, a) for a in remaining])
        elif kw_args:
            futures = pool.imap(starstarcall, [(function, a) for a in remaining])
        else:
            futures = pool.imap(function, remaining)
        out = list(tqdm(futures, total=n_args_parallel, **tqdm_kw))
    return front + out
68
+
69
+
70
def parallel_processes(*args, **kwargs):
    """Same as parallel_threads, with processes"""
    import multiprocessing as mp

    # Force the process pool, overriding any Pool passed by the caller.
    kwargs.update(Pool=mp.Pool)
    return parallel_threads(*args, **kwargs)
76
+
77
+
78
def starcall(args):
    """convenient wrapper for Process.Pool

    Takes a ``(function, positional_args)`` pair and returns ``function(*positional_args)``.
    """
    func, positional = args
    return func(*positional)
82
+
83
+
84
def starstarcall(args):
    """convenient wrapper for Process.Pool

    Takes a ``(function, keyword_args)`` pair and returns ``function(**keyword_args)``.
    """
    func, keywords = args
    return func(**keywords)
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/path_to_croco.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # modified from DUSt3R
6
+
7
+ import sys
8
+ import os.path as path
9
+ import importlib
10
+
11
# Import-time setup: locate the croco directory relative to this file, expose
# it through sys.path, and alias `croco.models` as the top-level `models` package.
HERE_PATH = path.normpath(path.dirname(__file__))
CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, "../../croco"))
CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, "models")
# IMPORTANT:
# Do NOT add `.../src/croco` directly to sys.path, otherwise subfolders like
# `croco/datasets` become a top-level module named `datasets`, which will shadow
# HuggingFace `datasets` and break `accelerate` (and others).
# Instead, add `.../src` so we import as `croco.*`.
SRC_PATH = path.normpath(path.join(HERE_PATH, "../../.."))

if path.isdir(CROCO_MODELS_PATH):

    # Prefer adding the `src` directory; this enables `import croco...` without
    # polluting top-level module names.
    if SRC_PATH not in sys.path:
        sys.path.insert(0, SRC_PATH)

    # In case an old run already inserted CROCO_REPO_PATH, remove it to avoid
    # shadowing top-level modules (e.g., `datasets`).
    while CROCO_REPO_PATH in sys.path:
        sys.path.remove(CROCO_REPO_PATH)

    # Backward-compat: DUSt3R code expects `models.*` to exist as a top-level package
    # (historically achieved by adding CROCO_REPO_PATH to sys.path). We keep that
    # import path working by aliasing `croco.models` to `models` without exposing
    # other top-level names like `datasets`.
    try:
        _croco_models = importlib.import_module("croco.models")
        # setdefault: an already-registered `models` module is left untouched.
        sys.modules.setdefault("models", _croco_models)
    except Exception:
        # If croco isn't importable yet, downstream import will raise a clearer error.
        pass
else:
    # The croco models directory is missing: fail at import time with a hint.
    raise ImportError(
        f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n "
        "Did you forget to run 'git submodule update --init --recursive' ?"
    )
outdoor_v48_16gpu_v2/code/05_02-22:24:00/dust3r/utils/render.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from gsplat import rasterization
3
+ from dust3r.utils.geometry import inv, geotrf
4
+
5
+
6
def render(
    intrinsics: torch.Tensor,
    pts3d: torch.Tensor,
    rgbs: torch.Tensor | None = None,
    scale: float = 0.002,
    opacity: float = 0.95,
):
    """Rasterize point maps as splats and return the rendered depth.

    intrinsics: one camera intrinsics entry per batch element.
    pts3d: point maps; ``shape[1:3]`` is used as the (height, width) of the
        rendered image and the points are reshaped to ``(batch, -1, 3)``.
    rgbs: optional colors with 3 channels on dim 1; white when omitted.
    scale / opacity: constant splat scale and opacity shared by all points.

    Each batch element is rendered with an identity view matrix in "RGB+D"
    mode. Returns ``(rendered_rgbs, rendered_depths, accs)``: only the depth
    channel is collected, so the first and last items are always empty lists.
    """
    device = pts3d.device
    batch_size = len(intrinsics)
    img_size = pts3d.shape[1:3]
    pts3d = pts3d.reshape(batch_size, -1, 3)
    num_pts = pts3d.shape[1]

    # Random unit quaternions, shared across the batch.
    quats = torch.randn((num_pts, 4), device=device)
    quats = quats / quats.norm(dim=-1, keepdim=True)
    scales = scale * torch.ones((num_pts, 3), device=device)
    opacities = opacity * torch.ones((num_pts), device=device)

    if rgbs is None:
        rgbs = torch.ones_like(pts3d[:, :, :3])
    else:
        assert rgbs.shape[1] == 3
        rgbs = rgbs.reshape(batch_size, 3, -1).transpose(1, 2)

    depth_maps = []
    for idx in range(batch_size):
        rgbd, _acc, _ = rasterization(
            pts3d[idx],
            quats,
            scales,
            opacities,
            rgbs[idx],
            torch.eye(4, device=device)[None],
            intrinsics[[idx]],
            width=img_size[1],
            height=img_size[0],
            packed=False,
            render_mode="RGB+D",
        )
        depth_maps.append(rgbd[..., 3])

    rendered_depths = torch.cat(depth_maps, dim=0)
    return [], rendered_depths, []
52
+
53
+
54
def get_render_results(gts, preds, self_view=False):
    """Render predicted and ground-truth depth for every (gt, pred) pair.

    self_view: when True each view uses its own camera pose and intrinsics;
        otherwise those of the first view (``gts[0]``) are used for all views.

    Returns ``(depths, gt_depths)``, two lists with one rendered depth per pair.
    Runs under ``torch.no_grad()``.
    """
    device = preds[0]["pts3d_in_other_view"].device
    depths, gt_depths = [], []
    with torch.no_grad():
        for gt, pred in zip(gts, preds):
            reference = gt if self_view else gts[0]
            camera = inv(reference["camera_pose"]).to(device)
            intrinsics = reference["camera_intrinsics"].to(device)
            # NOTE(review): both modes read "pts3d_in_other_view" — confirm
            # this is intended for self_view=True.
            pred_pts = pred["pts3d_in_other_view"]
            gt_img = gt["img"].to(device)
            gt_pts3d = gt["pts3d"].to(device)

            _, depth, _ = render(intrinsics, pred_pts, gt_img)
            _, gt_depth, _ = render(intrinsics, geotrf(camera, gt_pts3d), gt_img)
            depths.append(depth)
            gt_depths.append(gt_depth)
    return depths, gt_depths
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ __version__ = "0.0.1"
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/backbones.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from typing import Union
8
+
9
+ import torch
10
+
11
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
12
+
13
+
14
class Weights(Enum):
    # Supported pretrained-weight identifiers; _make_dinov2_model looks string
    # values up by member name (Weights[name]).
    LVD142M = "LVD142M"
16
+
17
+
18
def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """Build a DINOv2 vision transformer and optionally load pretrained weights.

    ``arch_name`` selects the constructor from the ``vision_transformer``
    module; the named options plus ``**kwargs`` (which take precedence) are
    forwarded to it. A string ``weights`` must be the name of a ``Weights``
    member, otherwise AssertionError is raised. With ``pretrained=True`` the
    checkpoint is downloaded from ``_DINOV2_BASE_URL`` and loaded strictly.
    """
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)

    vit_kwargs = {
        "img_size": img_size,
        "patch_size": patch_size,
        "init_values": init_values,
        "ffn_layer": ffn_layer,
        "block_chunks": block_chunks,
        "num_register_tokens": num_register_tokens,
        "interpolate_antialias": interpolate_antialias,
        "interpolate_offset": interpolate_offset,
    }
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if not pretrained:
        return model

    # Checkpoints live under the base model name; the file name also encodes
    # the number of register tokens.
    model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
    url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
    state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    return model
62
+
63
+
64
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.

    Forwards to ``_make_dinov2_model`` with ``arch_name="vit_small"``; extra
    keyword arguments are passed through unchanged.
    """
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
69
+
70
+
71
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.

    Forwards to ``_make_dinov2_model`` with ``arch_name="vit_base"``; extra
    keyword arguments are passed through unchanged.
    """
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
76
+
77
+
78
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.

    Forwards to ``_make_dinov2_model`` with ``arch_name="vit_large"``; extra
    keyword arguments are passed through unchanged.
    """
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
83
+
84
+
85
def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.

    Forwards to ``_make_dinov2_model`` with ``arch_name="vit_giant2"`` and
    ``ffn_layer="swiglufused"``; extra keyword arguments are passed through.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )
96
+
97
+
98
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Forwards to ``_make_dinov2_model`` with ``arch_name="vit_small"``, 4 register
    tokens, antialiased interpolation and an interpolation offset of 0.0.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
111
+
112
+
113
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Forwards to ``_make_dinov2_model`` with ``arch_name="vit_base"``, 4 register
    tokens, antialiased interpolation and an interpolation offset of 0.0.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
126
+
127
+
128
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Forwards to ``_make_dinov2_model`` with ``arch_name="vit_large"``, 4 register
    tokens, antialiased interpolation and an interpolation offset of 0.0.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
141
+
142
+
143
def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.

    Forwards to ``_make_dinov2_model`` with ``arch_name="vit_giant2"``,
    ``ffn_layer="swiglufused"``, 4 register tokens, antialiased interpolation
    and an interpolation offset of 0.0.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/hub/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15
+
16
+
17
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
    """
    Compose a model identifier of the form `dinov2_<arch><patch>[_reg<N>]`.

    The architecture part is `arch_name` with underscores removed, truncated
    to its first four characters. The `_reg<N>` suffix is appended only when
    `num_register_tokens` is truthy.
    """
    short_arch = arch_name.replace("_", "")[:4]
    name = f"dinov2_{short_arch}{patch_size}"
    if num_register_tokens:
        name += f"_reg{num_register_tokens}"
    return name
21
+
22
+
23
class CenterPadding(nn.Module):
    """
    Pad every dimension after the first two up to the next multiple of `multiple`.

    The padding is split between both sides of each dimension; when the
    amount is odd the extra element goes on the trailing side. Padding uses
    `F.pad` with its default mode.
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return `(leading, trailing)` padding that rounds `size` up to a multiple."""
        padded_size = math.ceil(size / self.multiple) * self.multiple
        total = padded_size - size
        leading = total // 2
        return leading, total - leading

    @torch.inference_mode()
    def forward(self, x):
        # F.pad expects pairs starting from the LAST dimension, so walk the
        # shape backwards and stop before the first two dimensions.
        pads = []
        for dim_size in x.shape[:1:-1]:
            pads.extend(self._get_pad(dim_size))
        return F.pad(x, pads)
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/attention.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ # warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ # warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ # warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
class Attention(nn.Module):
    """
    Multi-head self-attention with a fused QKV projection.

    Args:
        dim: size of the last dimension of the input and output.
        num_heads: number of heads; each head gets `dim // num_heads` channels.
        qkv_bias: whether the fused QKV projection has a bias.
        proj_bias: whether the output projection has a bias.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        """Attend over dimension 1 of a 3-D input; `attn_bias` is accepted but not used here."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (batch, tokens, 3, heads, head_dim) -> (3, batch, heads, tokens, head_dim)
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        # The query is pre-scaled before the dot product.
        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
70
+
71
+
72
class MemEffAttention(Attention):
    """
    Attention that uses xFormers' `memory_efficient_attention` when available.

    When xFormers could not be imported (or was disabled) it falls back to the
    parent implementation, which cannot honour an `attn_bias`.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if XFORMERS_AVAILABLE:
            batch, tokens, channels = x.shape
            fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)

            # Split the fused projection along the q/k/v axis.
            q, k, v = unbind(fused, 2)

            attended = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
            attended = attended.reshape([batch, tokens, channels])
            return self.proj_drop(self.proj(attended))

        # Fallback path: a bias is only supported through xFormers.
        if attn_bias is not None:
            raise AssertionError("xFormers is required for using nested tensors")
        return super().forward(x)
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/block.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
28
+ try:
29
+ if XFORMERS_ENABLED:
30
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
31
+
32
+ XFORMERS_AVAILABLE = True
33
+ # warnings.warn("xFormers is available (Block)")
34
+ else:
35
+ # warnings.warn("xFormers is disabled (Block)")
36
+ raise ImportError
37
+ except ImportError:
38
+ XFORMERS_AVAILABLE = False
39
+ # warnings.warn("xFormers is not available (Block)")
40
+
41
+
42
class Block(nn.Module):
    """
    Transformer block with two pre-normalised residual branches (attention, then FFN).

    Each branch output is optionally scaled by a `LayerScale` and, in training,
    subject to stochastic depth.

    Args:
        dim: feature size passed to the norm, attention and FFN layers.
        num_heads: number of attention heads.
        mlp_ratio: FFN hidden size as a multiple of `dim`.
        qkv_bias: bias flag for the attention QKV projection.
        proj_bias: bias flag for the attention output projection.
        ffn_bias: bias flag for the FFN layers.
        drop: dropout probability for the attention projection and the FFN.
        attn_drop: dropout probability for the attention weights.
        init_values: LayerScale initial value; a falsy value disables LayerScale.
        drop_path: stochastic depth rate.
        act_layer: activation constructor passed to the FFN.
        norm_layer: normalisation constructor.
        attn_class: attention module constructor.
        ffn_layer: FFN module constructor.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention branch, then the FFN branch, each as a residual update."""

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Each branch uses its own DropPath module; previously the FFN
            # branch reused drop_path1 and drop_path2 was never called.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
114
+
115
+
116
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """
    Add a residual computed on a random subset of the batch (stochastic depth).

    A random subset of `max(int(b * (1 - sample_drop_ratio)), 1)` samples is
    passed through `residual_func`; the result is added back to those samples
    only, scaled by `b / subset_size`. `x` must be 3-D.

    Returns:
        A new tensor with the same shape as `x`.
    """
    batch_size, _, _ = x.shape

    # 1) choose which samples keep their residual branch
    kept = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(batch_size, device=x.device)[:kept]

    # 2) run the branch on the kept samples only
    branch_out = residual_func(x[kept_idx]).flatten(1)

    # 3) scatter the rescaled residual back onto the kept rows
    rescale = batch_size / kept
    summed = torch.index_add(x.flatten(1), 0, kept_idx, branch_out.to(dtype=x.dtype), alpha=rescale)
    return summed.view_as(x)
138
+
139
+
140
def get_branges_scales(x, sample_drop_ratio=0.0):
    """
    Pick the random batch subset used for stochastic depth and its rescale factor.

    Returns:
        `(brange, residual_scale_factor)`: the indices of the kept samples (at
        least one) and `batch_size / number_of_kept_samples`. `x` must be 3-D.
    """
    batch_size, _, _ = x.shape
    kept = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(batch_size, device=x.device)[:kept]
    return kept_idx, batch_size / kept
146
+
147
+
148
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """
    Add `residual` (scaled by `residual_scale_factor`) onto the rows of `x` listed in `brange`.

    With a `scaling_vector`, the work is delegated to xFormers'
    `scaled_index_add`. Without one, `x` and `residual` are flattened from
    dimension 1 onward and combined with `torch.index_add`, so the result of
    that path is 2-D.
    """
    residual = residual.to(dtype=x.dtype)
    if scaling_vector is not None:
        return scaled_index_add(
            x, brange, residual, scaling=scaling_vector, alpha=residual_scale_factor
        )
    return torch.index_add(x.flatten(1), 0, brange, residual.flatten(1), alpha=residual_scale_factor)
158
+
159
+
160
+ attn_bias_cache: Dict[Tuple, Any] = {}
161
+
162
+
163
def get_attn_bias_and_cat(x_list, branges=None):
    """
    Concatenate a list of tensors into one batch-of-one sequence and return its attention bias.

    The bias is an xFormers `BlockDiagonalMask` built from the per-sample
    sequence lengths. It is memoised in the module-level `attn_bias_cache`,
    keyed by the `(batch_size, seq_len)` pair of every tensor in `x_list`.

    Args:
        x_list: tensors to concatenate.
        branges: optional per-tensor index tensors; when given, only those
            rows are selected (via `index_select_cat`) before concatenation.

    Returns:
        `(attn_bias, cat_tensors)`.
    """
    if branges is None:
        batch_sizes = [x.shape[0] for x in x_list]
    else:
        batch_sizes = [b.shape[0] for b in branges]

    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache:
        # One sequence-length entry per sample of every tensor.
        seqlens = [x.shape[1] for b, x in zip(batch_sizes, x_list) for _ in range(b)]
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is None:
        # Fold each tensor's batch dimension into its sequence dimension.
        cat_tensors = torch.cat(tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list), dim=1)
    else:
        flattened = [x.flatten(1) for x in x_list]
        cat_tensors = index_select_cat(flattened, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[all_shapes], cat_tensors
185
+
186
+
187
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """
    Stochastic-depth residual update for a list of tensors processed as one nested batch.

    NOTE: despite the return annotation, the value returned is a list with one
    updated tensor per entry of `x_list`.
    """
    # 1) draw a random kept-sample subset (and its rescale factor) per tensor
    branges = []
    residual_scale_factors = []
    for x in x_list:
        brange, scale = get_branges_scales(x, sample_drop_ratio=sample_drop_ratio)
        branges.append(brange)
        residual_scale_factors.append(scale)

    # 2) select the kept rows, concatenate them and fetch the matching attention bias
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) run the branch once on the concatenation, then split per input tensor
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    # 4) add each residual back onto its source tensor
    return [
        add_residual(x, brange, residual, scale, scaling_vector).view_as(x)
        for x, brange, residual, scale in zip(x_list, branges, residual_list, residual_scale_factors)
    ]
208
+
209
+
210
class NestedTensorBlock(Block):
    """
    `Block` that additionally accepts a list of tensors.

    A list is processed as one concatenated sequence with a block-diagonal
    attention bias, which requires xFormers and a `MemEffAttention` attention
    module. A single tensor is handled by the parent `Block.forward`.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # LayerScale is not applied inside these branch functions; its
            # gamma is handed to the residual add as `scaling_vector` instead.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Guard on ls2 itself: the previous check looked at ls1 while
                # reading ls2.gamma.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on input type: a tensor goes to `Block.forward`, a list to `forward_nested`."""
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
class DINOHead(nn.Module):
    """
    Projection head: an MLP, L2 normalisation, then a weight-normalised bias-free linear layer.

    Args:
        in_dim: input feature size.
        out_dim: output size of the final layer.
        use_bn: insert BatchNorm1d after each hidden linear layer of the MLP.
        nlayers: number of linear layers in the MLP (clamped to at least 1).
        hidden_dim: width of the MLP hidden layers.
        bottleneck_dim: output size of the MLP / input size of the final layer.
        mlp_bias: bias flag for the MLP linear layers.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        self.mlp = _build_mlp(
            max(nlayers, 1),
            in_dim,
            bottleneck_dim,
            hidden_dim=hidden_dim,
            use_bn=use_bn,
            bias=mlp_bias,
        )
        # Custom init runs BEFORE last_layer exists, so it only touches the MLP.
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        """Truncated-normal weights (std 0.02) and zero bias for linear layers."""
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.mlp(x)
        # A larger epsilon is used for half precision inputs.
        eps = 1e-6 if features.dtype == torch.float16 else 1e-12
        unit = nn.functional.normalize(features, dim=-1, p=2, eps=eps)
        return self.last_layer(unit)
42
+
43
+
44
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    """
    Build the MLP trunk of a head.

    With `nlayers == 1` a single `nn.Linear(in_dim, bottleneck_dim)` is
    returned. Otherwise an `nn.Sequential` is returned: hidden stages of
    Linear -> (BatchNorm1d if `use_bn`) -> GELU, followed by a final
    `Linear(hidden_dim, bottleneck_dim)`.
    """
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)

    # Input width of every hidden stage: the first consumes in_dim, the
    # remaining (nlayers - 2) consume hidden_dim.
    stage_inputs = [in_dim] + [hidden_dim] * max(nlayers - 2, 0)

    modules = []
    for width in stage_inputs:
        modules.append(nn.Linear(width, hidden_dim, bias=bias))
        if use_bn:
            modules.append(nn.BatchNorm1d(hidden_dim))
        modules.append(nn.GELU())
    modules.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
    return nn.Sequential(*modules)
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Randomly zero whole samples of `x` (stochastic depth), rescaling the survivors.

    `x` itself is returned untouched when not training or when `drop_prob` is
    0. Otherwise one Bernoulli(keep_prob) draw is made per entry of the first
    dimension and broadcast over the remaining ones; kept samples are divided
    by `keep_prob` when it is positive.
    """
    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # One mask value per sample, broadcastable to any number of trailing dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
24
+
25
+
26
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around `drop_path` that supplies the stored drop probability
    and the module's current training flag.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        probability = self.drop_prob
        return drop_path(x, probability, self.training)
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
class LayerScale(nn.Module):
    """
    Multiply the input by a learnable vector `gamma` of length `dim`.

    Args:
        dim: length of `gamma`.
        init_values: initial value of `gamma` (multiplied into a vector of ones).
        inplace: when True the input tensor is modified in place and returned.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
class Mlp(nn.Module):
    """
    Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: input size.
        hidden_features: hidden size; falls back to `in_features` when falsy.
        out_features: output size; falls back to `in_features` when falsy.
        act_layer: zero-argument constructor of the activation module.
        drop: dropout probability (the same Dropout module is applied twice).
        bias: bias flag for both linear layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
def make_2tuple(x):
    """
    Normalise a size argument to a 2-tuple.

    An `int` is duplicated into `(x, x)`; a tuple is returned as is after
    asserting it has exactly two elements. Anything else fails an assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x
23
+
24
+
25
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Patches are extracted with a strided convolution whose kernel size and
    stride both equal the patch size.

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: when False, the output is reshaped to (B,H,W,D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        grid_h = image_HW[0] // patch_HW[0]
        grid_w = image_HW[1] // patch_HW[1]

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = (grid_h, grid_w)
        self.num_patches = grid_h * grid_w

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        # Without a norm_layer an Identity is stored, so self.norm is never None.
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        projected = self.proj(x)  # B C H W
        grid_h, grid_w = projected.size(2), projected.size(3)
        tokens = self.norm(projected.flatten(2).transpose(1, 2))  # B HW C
        if self.flatten_embedding:
            return tokens
        return tokens.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H W C

    def flops(self) -> float:
        """Operation count for the configured image size: projection plus normalisation."""
        Ho, Wo = self.patches_resolution
        patch_area = self.patch_size[0] * self.patch_size[1]
        total = Ho * Wo * self.embed_dim * self.in_chans * patch_area
        # NOTE: this condition always holds, because __init__ stores
        # nn.Identity() when no norm layer is given.
        if self.norm is not None:
            total += Ho * Wo * self.embed_dim
        return total
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
class SwiGLUFFN(nn.Module):
    """
    SwiGLU feed-forward network: `w3(silu(x1) * x2)` where `x1, x2` are the two halves of `w12(x)`.

    Args:
        in_features: input size.
        hidden_features: size of each half of the fused projection; falls
            back to `in_features` when falsy.
        out_features: output size; falls back to `in_features` when falsy.
        act_layer: accepted for signature compatibility; not used.
        drop: accepted for signature compatibility; not used.
        bias: bias flag for both linear layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        # A single projection produces both the gate and the value halves.
        self.w12 = nn.Linear(in_features, 2 * hidden_width, bias=bias)
        self.w3 = nn.Linear(hidden_width, output_width, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ try:
39
+ if XFORMERS_ENABLED:
40
+ from xformers.ops import SwiGLU
41
+
42
+ XFORMERS_AVAILABLE = True
43
+ # warnings.warn("xFormers is available (SwiGLU)")
44
+ else:
45
+ # warnings.warn("xFormers is disabled (SwiGLU)")
46
+ raise ImportError
47
+ except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ # warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
class SwiGLUFFNFused(SwiGLU):
    """
    SwiGLU FFN whose hidden width is two thirds of the requested one, rounded up to a multiple of 8.

    The base class is xFormers' `SwiGLU` when available, else `SwiGLUFFN`.
    `act_layer` and `drop` are accepted but not forwarded to the base class.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        output_width = out_features or in_features
        requested_hidden = hidden_features or in_features
        # Take 2/3 of the requested width, then round up to a multiple of 8.
        two_thirds = int(requested_hidden * 2 / 3)
        aligned_hidden = (two_thirds + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=aligned_hidden,
            out_features=output_width,
            bias=bias,
        )
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ from . import vision_transformer as vits
9
+
10
+
11
+ logger = logging.getLogger("dinov2")
12
+
13
+
14
def build_model(args, only_teacher=False, img_size=224):
    """
    Instantiate teacher (and optionally student) ViT backbones from a config object.

    `args.arch` is rewritten in place with any "_memeff" suffix removed. The
    architecture name is then looked up in the `vision_transformer` module.

    Returns:
        `(teacher, embed_dim)` when `only_teacher` is true, otherwise
        `(student, teacher, embed_dim)`. `None` when `args.arch` does not
        contain "vit".
    """
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" not in args.arch:
        return None

    vit_kwargs = {
        "img_size": img_size,
        "patch_size": args.patch_size,
        "init_values": args.layerscale,
        "ffn_layer": args.ffn_layer,
        "block_chunks": args.block_chunks,
        "qkv_bias": args.qkv_bias,
        "proj_bias": args.proj_bias,
        "ffn_bias": args.ffn_bias,
        "num_register_tokens": args.num_register_tokens,
        "interpolate_offset": args.interpolate_offset,
        "interpolate_antialias": args.interpolate_antialias,
    }

    teacher = vars(vits)[args.arch](**vit_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim

    # Only the student is built with the drop-path settings.
    student = vars(vits)[args.arch](
        **vit_kwargs,
        drop_path_rate=args.drop_path_rate,
        drop_path_uniform=args.drop_path_uniform,
    )
    return student, teacher, student.embed_dim
40
+
41
+
42
def build_model_from_cfg(cfg, only_teacher=False):
    """Call `build_model` with `cfg.student` as its args and `cfg.crops.global_crops_size` as the image size."""
    student_args = cfg.student
    global_size = cfg.crops.global_crops_size
    return build_model(student_args, only_teacher=only_teacher, img_size=global_size)
outdoor_v48_16gpu_v2/code/05_02-22:24:00/slamformer/models/dinov2/models/vision_transformer.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.checkpoint import checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+ from ...layers.attention import FlashAttention
22
+
23
+
24
+ # logger = logging.getLogger("dinov2")
25
+
26
+
27
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Call ``fn(module=..., name=...)`` on the descendants of ``module``.

    Children are visited recursively with dotted names built from ``name``.
    Every descendant is always visited; ``module`` itself is visited only when
    ``include_root`` is true. With ``depth_first`` a module is visited after
    its children, otherwise before them.

    Returns:
        ``module`` itself.
    """
    visit_self_first = include_root and not depth_first
    visit_self_last = include_root and depth_first

    if visit_self_first:
        fn(module=module, name=name)

    for local_name, submodule in module.named_children():
        qualified = f"{name}.{local_name}" if name else local_name
        # Below the top-level call every module counts as a root to visit.
        named_apply(fn=fn, module=submodule, name=qualified, depth_first=depth_first, include_root=True)

    if visit_self_last:
        fn(module=module, name=name)
    return module
36
+
37
+
38
class BlockChunk(nn.ModuleList):
    """A ``ModuleList`` that can be called like a sequential stack."""

    def forward(self, x):
        """Feed ``x`` through each contained module in order and return the result."""
        out = x
        for layer in self:
            out = layer(out)
        return out
43
+
44
+
45
class DinoVisionTransformer(nn.Module):
    """Vision Transformer backbone in the DINOv2 layout.

    After ``prepare_tokens_with_masks`` the token sequence is: one class
    token, then ``num_register_tokens`` register tokens, then the patch
    tokens. Blocks are run under activation checkpointing whenever the module
    is in training mode.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class; it is always called
                with ``attn_class=FlashAttention``
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings

        Raises:
            NotImplementedError: if ``ffn_layer`` is not one of the supported names.
            ValueError: if ``block_chunks`` is positive but ``depth // block_chunks`` is 0.
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                attn_class=FlashAttention,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            if chunksize == 0:
                # range() below would otherwise fail with an unhelpful
                # "arg 3 must not be zero" ValueError.
                raise ValueError(
                    f"block_chunks ({block_chunks}) must not exceed depth ({depth})"
                )
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialise positional/class/register tokens and all linear layers."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Return positional embeddings sized for the token sequence ``x``.

        Args:
            x: token sequence whose second dimension is ``1 + n_patches``.
            w, h: the last two sizes of the input image tensor, as unpacked by
                ``prepare_tokens_with_masks``.

        Returns:
            ``self.pos_embed`` unchanged when the patch count matches and
            ``w == h``; otherwise the class embedding concatenated with the
            patch embeddings bicubically resampled to a
            ``(w // patch_size, h // patch_size)`` grid, cast to ``x.dtype``.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Embed an image batch into the token sequence fed to the blocks.

        Patch-embeds ``x``, replaces masked patch tokens with ``mask_token``
        when ``masks`` is given, prepends the class token, adds positional
        embeddings and finally inserts the register tokens (if any) right
        after the class token.
        """
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            # Registers are inserted after positional embeddings are added,
            # so they carry no positional embedding themselves.
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def _run_blocks(self, x):
        # Apply every entry of self.blocks in order, checkpointing in training mode.
        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        return x

    def _pack_features(self, x, masks):
        # Normalise block output and split it into class / register / patch tokens.
        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def forward_features_list(self, x_list, masks_list):
        """Run the backbone on a list of image batches.

        Args:
            x_list: list of image tensors.
            masks_list: one mask (or ``None``) per entry of ``x_list``;
                ``None`` for the whole argument means "no masks at all".

        Returns:
            A list with one feature dict per input (same keys as
            ``forward_features``).
        """
        if masks_list is None:
            # forward_features() forwards its default masks=None here; zip()
            # over None would raise TypeError.
            masks_list = [None] * len(x_list)
        tokens = [self.prepare_tokens_with_masks(img, msk) for img, msk in zip(x_list, masks_list)]
        all_x = self._run_blocks(tokens)
        return [self._pack_features(feat, msk) for feat, msk in zip(all_x, masks_list)]

    def forward_features(self, x, masks=None):
        """Compute normalised features for an image batch or a list of batches.

        Returns:
            A dict with keys ``x_norm_clstoken``, ``x_norm_regtokens``,
            ``x_norm_patchtokens``, ``x_prenorm`` and ``masks``; a list of
            such dicts when ``x`` is a list.
        """
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)
        x = self._run_blocks(x)
        return self._pack_features(x, masks)

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # The last chunk is left-padded with Identity modules up to the total
        # depth, so its length equals the number of real blocks.
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens from selected blocks.

        Args:
            x: input image tensor.
            n: number of last blocks to take, or explicit block indices.
            reshape: if True, reshape each output to
                ``(B, C, w // patch_size, h // patch_size)``.
            return_class_token: if True, return ``(patch_tokens, class_token)``
                pairs instead of patch tokens only.
            norm: if True, apply the final norm layer to every output.

        Returns:
            A tuple with one entry per selected block.
        """
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Drop the class token and the register tokens.
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Return the full feature dict if ``is_training``, else the head applied to the class token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
338
+
339
+
340
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """Initialise one module the way the original timm ViT does.

    Only ``nn.Linear`` modules are touched: the weight is drawn from a
    truncated normal with std 0.02 and the bias, when present, is zeroed.
    ``name`` is accepted for ``named_apply`` compatibility and is unused.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    bias = module.bias
    if bias is None:
        return
    nn.init.zeros_(bias)
346
+
347
+
348
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build the small variant: width 384, 12 blocks, 6 heads, MLP ratio 4.

    Extra keyword arguments are forwarded to ``DinoVisionTransformer``.

    NOTE(review): ``DinoVisionTransformer.__init__`` in this file calls
    ``block_fn`` with ``attn_class=FlashAttention``, which takes precedence
    over the ``MemEffAttention`` bound in the partial below - confirm intent.
    """
    return DinoVisionTransformer(
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        **kwargs,
    )
360
+
361
+
362
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build the base variant: width 768, 12 blocks, 12 heads, MLP ratio 4.

    Extra keyword arguments are forwarded to ``DinoVisionTransformer``.

    NOTE(review): ``DinoVisionTransformer.__init__`` in this file calls
    ``block_fn`` with ``attn_class=FlashAttention``, which takes precedence
    over the ``MemEffAttention`` bound in the partial below - confirm intent.
    """
    return DinoVisionTransformer(
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        **kwargs,
    )
374
+
375
+
376
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build the large variant: width 1024, 24 blocks, 16 heads, MLP ratio 4.

    Extra keyword arguments are forwarded to ``DinoVisionTransformer``.

    NOTE(review): ``DinoVisionTransformer.__init__`` in this file calls
    ``block_fn`` with ``attn_class=FlashAttention``, which takes precedence
    over the ``MemEffAttention`` bound in the partial below - confirm intent.
    """
    return DinoVisionTransformer(
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        **kwargs,
    )
388
+
389
+
390
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """Build the giant2 variant: width 1536, 40 blocks, 24 heads, MLP ratio 4.

    Close to ViT-giant; 1536 / 24 gives 64 embedding dimensions per head.
    Extra keyword arguments are forwarded to ``DinoVisionTransformer``.

    NOTE(review): ``DinoVisionTransformer.__init__`` in this file calls
    ``block_fn`` with ``attn_class=FlashAttention``, which takes precedence
    over the ``MemEffAttention`` bound in the partial below - confirm intent.
    """
    return DinoVisionTransformer(
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        **kwargs,
    )