聂如 committed
Commit 91126af · 1 Parent(s): 7829591

Add design file

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LEGAL.md +7 -0
  2. LICENSE.txt +7 -0
  3. README.md +100 -13
  4. app.py +291 -0
  5. dust3r/croco/datasets/__init__.py +0 -0
  6. dust3r/croco/datasets/crops/README.MD +104 -0
  7. dust3r/croco/datasets/crops/extract_crops_from_images.py +159 -0
  8. dust3r/croco/datasets/habitat_sim/README.MD +76 -0
  9. dust3r/croco/datasets/habitat_sim/__init__.py +0 -0
  10. dust3r/croco/datasets/habitat_sim/generate_from_metadata.py +92 -0
  11. dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py +27 -0
  12. dust3r/croco/datasets/habitat_sim/generate_multiview_images.py +177 -0
  13. dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py +390 -0
  14. dust3r/croco/datasets/habitat_sim/pack_metadata_files.py +69 -0
  15. dust3r/croco/datasets/habitat_sim/paths.py +129 -0
  16. dust3r/croco/datasets/pairs_dataset.py +109 -0
  17. dust3r/croco/datasets/transforms.py +95 -0
  18. dust3r/croco/models/__pycache__/blocks.cpython-312.pyc +0 -0
  19. dust3r/croco/models/__pycache__/croco.cpython-312.pyc +0 -0
  20. dust3r/croco/models/__pycache__/dpt_block.cpython-312.pyc +0 -0
  21. dust3r/croco/models/__pycache__/masking.cpython-312.pyc +0 -0
  22. dust3r/croco/models/__pycache__/pos_embed.cpython-312.pyc +0 -0
  23. dust3r/croco/models/blocks.py +307 -0
  24. dust3r/croco/models/criterion.py +37 -0
  25. dust3r/croco/models/croco.py +288 -0
  26. dust3r/croco/models/dpt_block.py +450 -0
  27. dust3r/croco/models/head_downstream.py +58 -0
  28. dust3r/croco/models/masking.py +25 -0
  29. dust3r/croco/models/pos_embed.py +159 -0
  30. dust3r/croco/models/transformer_utils.py +1021 -0
  31. dust3r/croco/models/x_transformer.py +558 -0
  32. dust3r/croco/utils/misc.py +583 -0
  33. dust3r/dust3r/__init__.py +2 -0
  34. dust3r/dust3r/__pycache__/__init__.cpython-312.pyc +0 -0
  35. dust3r/dust3r/__pycache__/model.cpython-312.pyc +0 -0
  36. dust3r/dust3r/__pycache__/patch_embed.cpython-312.pyc +0 -0
  37. dust3r/dust3r/__pycache__/viz.cpython-312.pyc +0 -0
  38. dust3r/dust3r/datasets/CustomDataset.py +145 -0
  39. dust3r/dust3r/datasets/__init__.py +39 -0
  40. dust3r/dust3r/datasets/__pycache__/CustomDataset.cpython-312.pyc +0 -0
  41. dust3r/dust3r/datasets/__pycache__/__init__.cpython-312.pyc +0 -0
  42. dust3r/dust3r/datasets/base/__init__.py +2 -0
  43. dust3r/dust3r/datasets/base/__pycache__/__init__.cpython-312.pyc +0 -0
  44. dust3r/dust3r/datasets/base/__pycache__/base_stereo_view_dataset.cpython-312.pyc +0 -0
  45. dust3r/dust3r/datasets/base/__pycache__/batched_sampler.cpython-312.pyc +0 -0
  46. dust3r/dust3r/datasets/base/__pycache__/easy_dataset.cpython-312.pyc +0 -0
  47. dust3r/dust3r/datasets/base/base_stereo_view_dataset.py +774 -0
  48. dust3r/dust3r/datasets/base/batched_sampler.py +74 -0
  49. dust3r/dust3r/datasets/base/easy_dataset.py +157 -0
  50. dust3r/dust3r/datasets/utils/__init__.py +2 -0
LEGAL.md ADDED
@@ -0,0 +1,7 @@
+ Legal Disclaimer
+
+ Within this source code, the comments in Chinese shall be the original, governing version. Comments in other languages are for reference only. In the event of any conflict between the Chinese-language comments and comments in any other language, the Chinese version shall prevail.
+
+ 法律免责声明
+
+ 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
LICENSE.txt ADDED
@@ -0,0 +1,7 @@
+ FLARE, Copyright (c) 2025-present Ant Group, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
+
+ A summary of the CC BY-NC-SA 4.0 license is located here:
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ The CC BY-NC-SA 4.0 license is located here:
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
README.md CHANGED
@@ -1,13 +1,100 @@
- ---
- title: FLARE
- emoji: 🦀
- colorFrom: blue
- colorTo: yellow
- sdk: gradio
- sdk_version: 5.19.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # FLARE: Feed-forward Geometry, Appearance and Camera Estimation from Uncalibrated Sparse Views
+ [![Website](https://img.shields.io/website-up-down-green-red/http/shields.io.svg)](https://zhanghe3z.github.io/FLARE/)
+ [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97-Hugging%20Face-yellow)](https://huggingface.co/AntResearch/FLARE)
+ [![Video](https://img.shields.io/badge/Video-Demo-red)](https://zhanghe3z.github.io/FLARE/videos/teaser_video.mp4)
+
+ Official implementation of **FLARE** (CVPR 2025), a feed-forward model for joint camera pose estimation, 3D reconstruction, and novel view synthesis from sparse, uncalibrated views.
+
+ ![Teaser](./assets/teaser.jpg)
+
+ <!-- TOC start (generated with https://github.com/derlin/bitdowntoc) -->
+
+ - [📖 Overview](#-overview)
+ - [🛠️ TODO List](#-todo-list)
+ - [🌍 Installation](#-installation)
+ - [💿 Checkpoints](#-checkpoints)
+ - [🎯 Run a Demo (Point Cloud and Camera Pose Estimation)](#-run-a-demo-point-cloud-and-camera-pose-estimation)
+ - [👀 Visualization](#-visualization)
+ - [📜 Citation](#-citation)
+
+ <!-- TOC end -->
+
+ ## 📖 Overview
+ We present FLARE, a feed-forward model that simultaneously estimates high-quality camera poses, 3D geometry, and appearance from as few as 2-8 uncalibrated images. Our cascaded learning paradigm:
+
+ 1. **Camera Pose Estimation**: directly regresses camera poses without bundle adjustment.
+ 2. **Geometry Reconstruction**: decomposes geometry reconstruction into two simpler sub-problems.
+ 3. **Appearance Modeling**: enables photorealistic novel view synthesis via 3D Gaussians.
+
+ FLARE achieves state-of-the-art performance with inference times under 0.5 seconds.
+
+ ## 🛠️ TODO List
+ - [x] Release point cloud and camera pose estimation code.
+ - [x] Update Gradio demo (app.py).
+ - [ ] Release novel view synthesis code. (~2 weeks)
+ - [ ] Release evaluation code. (~2 weeks)
+ - [ ] Release training code.
+ - [ ] Release data processing code.
+
+ ## 🌍 Installation
+
+ ```bash
+ conda create -n flare python=3.8
+ conda activate flare
+ conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia # pick the CUDA version matching your system
+ pip install -r requirements.txt
+ conda uninstall ffmpeg
+ conda install -c conda-forge ffmpeg
+ ```
+
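+ A quick sanity check after installation (a hedged suggestion, not part of the original setup):
+
+ ```bash
+ python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+ ```
+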
+ ## 💿 Checkpoints
+ Download the checkpoint from [Hugging Face](https://huggingface.co/AntResearch/FLARE/blob/main/geometry_pose.pth) and place it at `checkpoints/geometry_pose.pth`.
+
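+ For example, a minimal sketch assuming the standard Hugging Face resolve-URL layout for the link above:
+
+ ```bash
+ mkdir -p checkpoints
+ wget https://huggingface.co/AntResearch/FLARE/resolve/main/geometry_pose.pth -P checkpoints/
+ ```
+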
+ ## 🎯 Run a Demo (Point Cloud and Camera Pose Estimation)
+
+ ```bash
+ sh scripts/run_pose_pointcloud.sh
+ ```
+
+ The script runs a command of the following form:
+
+ ```bash
+ torchrun --nproc_per_node=1 run_pose_pointcloud.py \
+     --test_dataset "1 @ CustomDataset(split='train', ROOT='Your/Data/Path', resolution=(512,384), seed=1, num_views=7, gt_num_image=0, aug_portrait_or_landscape=False, sequential_input=False)" \
+     --model "AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, two_confs=True, desc_conf_mode=('exp', 0, inf))" \
+     --pretrained "Your/Checkpoint/Path" \
+     --test_criterion "MeshOutput(sam=False)" --output_dir "log/" --amp 1 --seed 1 --num_workers 0
+ ```
+
+ **To run the demo with ground-truth camera poses:**
+ enable the `wpose=True` flag in both `CustomDataset` and `AsymmetricMASt3R`. An example script demonstrating this setup is provided in `scripts/run_pose_pointcloud_wpose.sh`:
+
+ ```bash
+ sh scripts/run_pose_pointcloud_wpose.sh
+ ```
+
+ ## 👀 Visualization
+
+ ```bash
+ sh ./visualizer/vis.sh
+ ```
+
+ Or invoke the visualizer directly on a single result:
+
+ ```bash
+ CUDA_VISIBLE_DEVICES=0 python visualizer/run_vis.py --result_npz data/mesh/IMG_1511.HEIC.JPG.JPG/pred.npz --results_folder data/mesh/IMG_1511.HEIC.JPG.JPG/
+ ```
+
+ ## 📜 Citation
+ ```bibtex
+ @misc{zhang2025flarefeedforwardgeometryappearance,
+       title={FLARE: Feed-forward Geometry, Appearance and Camera Estimation from Uncalibrated Sparse Views},
+       author={Shangzhan Zhang and Jianyuan Wang and Yinghao Xu and Nan Xue and Christian Rupprecht and Xiaowei Zhou and Yujun Shen and Gordon Wetzstein},
+       year={2025},
+       eprint={2502.12138},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV},
+       url={https://arxiv.org/abs/2502.12138},
+ }
+ ```
app.py ADDED
@@ -0,0 +1,291 @@
+ import spaces
+ import mast3r.utils.path_to_dust3r  # noqa
+ import dust3r.utils.path_to_croco  # noqa: F401
+ import os
+ import sys
+ import os.path as path
+ import torch
+ import tempfile
+ import gradio
+ import shutil
+ import math
+ from mast3r.model import AsymmetricMASt3R
+ import matplotlib.pyplot as pl
+ from dust3r.utils.image import load_images
+ import torch.nn.functional as F
+ from pytorch3d.ops import knn_points
+ from dust3r.utils.geometry import xy_grid
+ import numpy as np
+ import cv2
+ from dust3r.utils.device import to_numpy
+ import trimesh
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
+ from scipy.spatial.transform import Rotation
+
+ pl.ion()
+ # for gpu >= Ampere and pytorch >= 1.12
+ torch.backends.cuda.matmul.allow_tf32 = True
+ batch_size = 1
+ inf = float('inf')
+ weights_path = "checkpoints/geometry_pose.pth"
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ ckpt = torch.load(weights_path, map_location=device)
+ model = AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, two_confs=True, desc_conf_mode=('exp', 0, inf))
+ # pretrained weights downloaded from the Hub replace the freshly constructed model above
+ model = AsymmetricMASt3R.from_pretrained("zhang3z/FLARE").to(device)
+ # model.from_pretrained(ckpt['model'], strict=False)
+ model = model.to(device).eval()
+
+ tmpdirname = tempfile.mkdtemp(suffix='_FLARE_gradio_demo')
+ image_size = 512
+ silent = True
+ gradio_delete_cache = 7200
+ backbone = torch.hub.load(
+     "facebookresearch/dinov2", "dinov2_vitb14_reg"
+ )
+ backbone = backbone.eval().cuda()
+
+ class FileState:
+     def __init__(self, outfile_name=None):
+         self.outfile_name = outfile_name
+
+     def __del__(self):
+         if self.outfile_name is not None and os.path.isfile(self.outfile_name):
+             os.remove(self.outfile_name)
+         self.outfile_name = None
+
+ def pad_to_square(reshaped_image):
+     B, C, H, W = reshaped_image.shape
+     max_dim = max(H, W)
+     pad_height = max_dim - H
+     pad_width = max_dim - W
+     padding = (pad_width // 2, pad_width - pad_width // 2,
+                pad_height // 2, pad_height - pad_height // 2)
+     padded_image = F.pad(reshaped_image, padding, mode='constant', value=0)
+     return padded_image
+
+ def generate_rank_by_dino(
+     reshaped_image, backbone, query_frame_num, image_size=336
+ ):
+     # Downsample image to image_size x image_size
+     # because we found it is unnecessary to use high resolution
+     rgbs = pad_to_square(reshaped_image)
+     rgbs = F.interpolate(
+         rgbs,  # resize the padded image (the original passed the un-padded input here, discarding the padding)
+         (image_size, image_size),
+         mode="bilinear",
+         align_corners=True,
+     )
+     rgbs = _resnet_normalize_image(rgbs.cuda())
+
+     # Get the image features (patch level)
+     frame_feat = backbone(rgbs, is_training=True)
+     frame_feat = frame_feat["x_norm_patchtokens"]
+     frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
+
+     # Compute the similarity matrix
+     frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
+     similarity_matrix = torch.bmm(
+         frame_feat_norm, frame_feat_norm.transpose(-1, -2)
+     )
+     similarity_matrix = similarity_matrix.mean(dim=0)
+     distance_matrix = 100 - similarity_matrix.clone()
+
+     # Ignore self-pairing
+     similarity_matrix.fill_diagonal_(-100)
+
+     similarity_sum = similarity_matrix.sum(dim=1)
+
+     # Find the most common frame
+     most_common_frame_index = torch.argmax(similarity_sum).item()
+     return most_common_frame_index
+
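+ # Usage sketch for generate_rank_by_dino (hypothetical shapes): given a stack of
+ # N frames in [0, 1], e.g. images = torch.rand(5, 3, 384, 512), it returns the
+ # index of the frame most similar to all others; that frame is used below as the
+ # anchor (first) view for reconstruction.
+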
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
+ _RESNET_STD = [0.229, 0.224, 0.225]
+ _resnet_mean = torch.tensor(_RESNET_MEAN).view(1, 3, 1, 1).cuda()
+ _resnet_std = torch.tensor(_RESNET_STD).view(1, 3, 1, 1).cuda()
+ def _resnet_normalize_image(img: torch.Tensor) -> torch.Tensor:
+     return (img - _resnet_mean) / _resnet_std
+
+ def calculate_index_mappings(query_index, S, device=None):
+     """
+     Construct an ordering that swaps [query_index] and [0],
+     so that the content at query_index is placed at position [0].
+     """
+     new_order = torch.arange(S)
+     new_order[0] = query_index
+     new_order[query_index] = 0
+     if device is not None:
+         new_order = new_order.to(device)
+     return new_order
+
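+ # Worked example (hypothetical values): with S=4 and query_index=2,
+ # calculate_index_mappings(2, 4) returns tensor([2, 1, 0, 3]); indexing a
+ # list of views with this order moves the query view to the front.
+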
+ def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
+                                  cam_color=None, as_pointcloud=False,
+                                  transparent_cams=False, silent=False):
+     assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
+     pts3d = to_numpy(pts3d)
+     imgs = to_numpy(imgs)
+     focals = to_numpy(focals)
+     mask = to_numpy(mask)
+     cams2world = to_numpy(cams2world)
+
+     scene = trimesh.Scene()
+     # full pointcloud
+     if as_pointcloud:
+         pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
+         col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
+         valid_msk = np.isfinite(pts.sum(axis=1))
+         pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk])
+         scene.add_geometry(pct)
+     else:
+         meshes = []
+         for i in range(len(imgs)):
+             pts3d_i = pts3d[i].reshape(imgs[i].shape)
+             msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
+             meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
+         mesh = trimesh.Trimesh(**cat_meshes(meshes))
+         scene.add_geometry(mesh)
+
+     # add each camera
+     for i, pose_c2w in enumerate(cams2world):
+         if isinstance(cam_color, list):
+             camera_edge_color = cam_color[i]
+         else:
+             camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
+         add_scene_cam(scene, pose_c2w, camera_edge_color,
+                       None if transparent_cams else imgs[i], focals[i],
+                       imsize=imgs[i].shape[1::-1], screen_width=cam_size)
+
+     rot = np.eye(4)
+     rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
+     scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
+     if not silent:
+         print('(exporting 3D scene to', outfile, ')')
+
+     scene.export(file_obj=outfile)
+     return outfile
+
+ @spaces.GPU(duration=180)
+ def local_get_reconstructed_scene(inputfiles, min_conf_thr, cam_size):
+
+     batch = load_images(inputfiles, size=image_size, verbose=not silent)
+     images = [gt['img'] for gt in batch]
+     images = torch.cat(images, dim=0)
+     images = images / 2 + 0.5
+     index = generate_rank_by_dino(images, backbone, query_frame_num=1)
+     sorted_order = calculate_index_mappings(index, len(images), device=device)
+     sorted_batch = []
+     for i in range(len(batch)):
+         sorted_batch.append(batch[sorted_order[i]])
+     batch = sorted_batch
+     ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'rng', 'vid'])
+     ignore_dtype_keys = set(['true_shape', 'camera_pose', 'pts3d', 'fxfycxcy', 'img_org', 'camera_intrinsics', 'depthmap', 'depth_anything', 'fxfycxcy_unorm'])
+     dtype = torch.bfloat16
+     for view in batch:
+         for name in view.keys():  # pseudo_focal
+             if name in ignore_keys:
+                 continue
+             if isinstance(view[name], torch.Tensor):
+                 view[name] = view[name].to(device, non_blocking=True)
+             else:
+                 view[name] = torch.tensor(view[name]).to(device, non_blocking=True)
+             if view[name].dtype == torch.float32 and name not in ignore_dtype_keys:
+                 view[name] = view[name].to(dtype)
+     view1 = batch[:1]
+     view2 = batch[1:]
+     with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
+         pred1, pred2, pred_cameras = model(view1, view2, True, dtype)
+     pts3d = pred2['pts3d']
+     conf = pred2['conf']
+     pts3d = pts3d.detach().cpu()
+     B, N, H, W, _ = pts3d.shape
+     thres = torch.quantile(conf.flatten(2, 3), min_conf_thr, dim=-1)[0]
+     masks_conf = conf > thres[None, :, None, None]
+     masks_conf = masks_conf.cpu()
+
+     images = [view['img'] for view in view1 + view2]
+     shape = torch.stack([view['true_shape'] for view in view1 + view2], dim=1).detach().cpu().numpy()
+     images = torch.stack(images, 1).float().permute(0, 1, 3, 4, 2).detach().cpu().numpy()
+     images = images / 2 + 0.5
+     images = images.reshape(B, N, H, W, 3)
+     # estimate focal length
+     images = images[0]
+     pts3d = pts3d[0]
+     masks_conf = masks_conf[0]
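+     # Focal voting: with the principal point assumed at the image center, a pinhole
+     # camera gives u = f*x/z and v = f*y/z for a centered pixel (u, v) and camera-frame
+     # point (x, y, z), so each pixel votes f = u*z/x (or v*z/y); the median vote below
+     # is a robust estimate of the shared focal length.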
+     xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0)  # homogeneous (x,y,1)
+     pp = torch.tensor((W/2, H/2)).to(xy_over_z)
+     pixels = xy_grid(W, H, device=xy_over_z.device).view(1, -1, 2) - pp.view(-1, 1, 2)  # B,HW,2
+     u, v = pixels[:1].unbind(dim=-1)
+     x, y, z = pts3d[:1].reshape(-1, 3).unbind(dim=-1)
+     fx_votes = (u * z) / x
+     fy_votes = (v * z) / y
+     # assume square pixels, hence same focal for X and Y
+     f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
+     focal = torch.nanmedian(f_votes, dim=-1).values
+     focal = focal.item()
+     pts3d = pts3d.numpy()
+     # use PnP to estimate camera poses
+     pred_poses = []
+     for i in range(pts3d.shape[0]):
+         shape_input_each = shape[:, i]
+         mesh_grid = xy_grid(shape_input_each[0, 1], shape_input_each[0, 0])
+         cur_inlier = conf[0, i] > torch.quantile(conf[0, i], 0.6)
+         cur_inlier = cur_inlier.detach().cpu().numpy()
+         ransac_thres = 0.5
+         confidence = 0.9999
+         iterationsCount = 10_000
+         cur_pts3d = pts3d[i]
+         K = np.float32([(focal, 0, W/2), (0, focal, H/2), (0, 0, 1)])
+         success, r_pose, t_pose, _ = cv2.solvePnPRansac(cur_pts3d[cur_inlier].astype(np.float64), mesh_grid[cur_inlier].astype(np.float64), K, None,
+                                                         flags=cv2.SOLVEPNP_SQPNP,
+                                                         iterationsCount=iterationsCount,
+                                                         reprojectionError=1,
+                                                         confidence=confidence)
+         r_pose = cv2.Rodrigues(r_pose)[0]
+         RT = np.r_[np.c_[r_pose, t_pose], [(0, 0, 0, 1)]]
+         cam2world = np.linalg.inv(RT)
+         pred_poses.append(cam2world)
+     pred_poses = np.stack(pred_poses, axis=0)
+     pred_poses = torch.tensor(pred_poses)
+     # use kNN to clean the point cloud: drop points whose mean distance to their
+     # 10 nearest neighbours exceeds the 95th percentile (likely outliers)
+     K = 10
+     points = torch.tensor(pts3d.reshape(1, -1, 3)).cuda()
+     knn = knn_points(points, points, K=K)
+     dists = knn.dists
+     mean_dists = dists.mean(dim=-1)
+     masks_dist = mean_dists < torch.quantile(mean_dists.reshape(-1), 0.95)
+     masks_dist = masks_dist.detach().cpu().numpy()
+     masks_conf = (masks_conf > 0) & masks_dist.reshape(-1, H, W)
+     masks_conf = masks_conf > 0
+     outdir = tempfile.mkdtemp(suffix='_FLARE_gradio_demo')
+     os.makedirs(outdir, exist_ok=True)
+     focals = [focal] * len(images)
+     outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)
+
+     _convert_scene_output_to_glb(outfile_name, images, pts3d, masks_conf, focals, pred_poses, as_pointcloud=True,
+                                  transparent_cams=False, cam_size=cam_size, silent=silent)
+     return filestate, outfile_name
+
+ css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
+ title = "FLARE Demo"
+ with gradio.Blocks(css=css, title=title, delete_cache=(gradio_delete_cache, gradio_delete_cache)) as demo:
+     filestate = gradio.State(None)
+     gradio.HTML('<h2 style="text-align: center;">3D Reconstruction with FLARE</h2>')
+     with gradio.Column():
+         inputfiles = gradio.File(file_count="multiple")
+         snapshot = gradio.Image(None, visible=False)
+         with gradio.Row():
+             # adjust the confidence threshold
+             min_conf_thr = gradio.Slider(label="min_conf_thr", value=0.1, minimum=0.0, maximum=1, step=0.05)
+             # adjust the camera size in the output pointcloud
+             cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
+         run_btn = gradio.Button("Run")
+         outmodel = gradio.Model3D()
+
+     # events
+     run_btn.click(fn=local_get_reconstructed_scene,
+                   inputs=[inputfiles, min_conf_thr, cam_size],
+                   outputs=[filestate, outmodel])
+
+ demo.launch(show_error=True, share=None, server_name=None, server_port=None)
+ shutil.rmtree(tmpdirname)
dust3r/croco/datasets/__init__.py ADDED
File without changes
dust3r/croco/datasets/crops/README.MD ADDED
@@ -0,0 +1,104 @@
+ ## Generation of crops from the real datasets
+
+ The instructions below allow you to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL.
+
+ ### Download the metadata of the crops to generate
+
+ First, download the metadata and put it in `./data/`:
+ ```
+ mkdir -p data
+ cd data/
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip
+ unzip crop_metadata.zip
+ rm crop_metadata.zip
+ cd ..
+ ```
+
+ ### Prepare the original datasets
+
+ Second, download the original datasets into `./data/original_datasets/`.
+ ```
+ mkdir -p data/original_datasets
+ ```
+
+ ##### ARKitScenes
+
+ Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`.
+ The resulting file structure should look like:
+ ```
+ ./data/original_datasets/ARKitScenes/
+ └───Training
+     └───40753679
+     │   │ ultrawide
+     │   │ ...
+     └───40753686
+         ...
+ ```
+
+ ##### MegaDepth
+
+ Download the `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`.
+ The resulting file structure should look like:
+
+ ```
+ ./data/original_datasets/MegaDepth/
+ └───0000
+ │   └───images
+ │   │   │ 1000557903_87fa96b8a4_o.jpg
+ │   │   └ ...
+ │   └─── ...
+ └───0001
+ │   │
+ │   └ ...
+ └─── ...
+ ```
+
+ ##### 3DStreetView
+
+ Download the `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`.
+ The resulting file structure should look like:
+
+ ```
+ ./data/original_datasets/3DStreetView/
+ └───dataset_aligned
+ │   └───0002
+ │   │   │ 0000002_0000001_0000002_0000001.jpg
+ │   │   └ ...
+ │   └─── ...
+ └───dataset_unaligned
+ │   └───0003
+ │   │   │ 0000003_0000001_0000002_0000001.jpg
+ │   │   └ ...
+ │   └─── ...
+ ```
+
+ ##### IndoorVL
+
+ Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture).
+
+ ```
+ pip install kapture
+ mkdir -p ./data/original_datasets/IndoorVL
+ cd ./data/original_datasets/IndoorVL
+ kapture_download_dataset.py update
+ kapture_download_dataset.py install "HyundaiDepartmentStore_*"
+ kapture_download_dataset.py install "GangnamStation_*"
+ cd -
+ ```
+
+ ### Extract the crops
+
+ Now, extract the crops for each of the datasets:
+ ```
+ for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL;
+ do
+     python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500;
+ done
+ ```
+
+ ##### Note for IndoorVL
+
+ Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper.
+ To compensate in terms of the number of pre-training iterations, the pre-training command in this repository uses 125 training epochs, including 12 warm-up epochs, and a cosine learning-rate schedule over 250 epochs, instead of 100, 10 and 200 respectively.
+ The impact on performance is negligible.
dust3r/croco/datasets/crops/extract_crops_from_images.py ADDED
@@ -0,0 +1,159 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # Extracting crops for pre-training
+ # --------------------------------------------------------
+
+ import os
+ import argparse
+ from tqdm import tqdm
+ from PIL import Image
+ import functools
+ from multiprocessing import Pool
+ import math
+
+
+ def arg_parser():
+     parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list')
+
+     parser.add_argument('--crops', type=str, required=True, help='crop file')
+     parser.add_argument('--root-dir', type=str, required=True, help='root directory')
+     parser.add_argument('--output-dir', type=str, required=True, help='output directory')
+     parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
+     parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads')
+     parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories')
+     parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
+     return parser
+
+
+ def main(args):
+     listing_path = os.path.join(args.output_dir, 'listing.txt')
+
+     print(f'Loading list of crops ... ({args.nthread} threads)')
+     crops, num_crops_to_generate = load_crop_file(args.crops)
+
+     print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
+     num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels)
+     num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels))
+
+     jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
+     del crops
+
+     os.makedirs(args.output_dir, exist_ok=True)
+     mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map
+     call = functools.partial(save_image_crops, args)
+
+     print(f"Generating cropped images to {args.output_dir} ...")
+     with open(listing_path, 'w') as listing:
+         listing.write('# pair_path\n')
+         for results in tqdm(mmap(call, jobs), total=len(jobs)):
+             for path in results:
+                 listing.write(f'{path}\n')
+     print('Finished writing listing to', listing_path)
+
+
+ def load_crop_file(path):
+     data = open(path).read().splitlines()
+     pairs = []
+     num_crops_to_generate = 0
+     for line in tqdm(data):
+         if line.startswith('#'):
+             continue
+         line = line.split(', ')
+         if len(line) < 8:
+             img1, img2, rotation = line
+             pairs.append((img1, img2, int(rotation), []))
+         else:
+             l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
+             rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
+             pairs[-1][-1].append((rect1, rect2))
+             num_crops_to_generate += 1
+     return pairs, num_crops_to_generate
+
+
+ def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
+     jobs = []
+     powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))]
+
+     def get_path(idx):
+         idx_array = []
+         d = idx
+         for level in range(num_levels - 1):
+             idx_array.append(idx // powers[level])
+             idx = idx % powers[level]
+         idx_array.append(d)
+         return '/'.join(map(lambda x: hex(x)[2:], idx_array))
+
+
88
+ idx = 0
89
+ for pair_data in tqdm(pairs):
90
+ img1, img2, rotation, crops = pair_data
91
+ if -60 <= rotation and rotation <= 60:
92
+ rotation = 0 # most likely not a true rotation
93
+ paths = [get_path(idx + k) for k in range(len(crops))]
94
+ idx += len(crops)
95
+ jobs.append(((img1, img2), rotation, crops, paths))
96
+ return jobs
97
+
98
+
99
+ def load_image(path):
100
+ try:
101
+ return Image.open(path).convert('RGB')
102
+ except Exception as e:
103
+ print('skipping', path, e)
104
+ raise OSError()
105
+
106
+
107
+ def save_image_crops(args, data):
108
+ # load images
109
+ img_pair, rot, crops, paths = data
110
+ try:
111
+ img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
112
+ except OSError as e:
113
+ return []
114
+
115
+ def area(sz):
116
+ return sz[0] * sz[1]
117
+
118
+ tgt_size = (args.imsize, args.imsize)
119
+
120
+ def prepare_crop(img, rect, rot=0):
121
+ # actual crop
122
+ img = img.crop(rect)
123
+
124
+ # resize to desired size
125
+ interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC
126
+ img = img.resize(tgt_size, resample=interp)
127
+
128
+ # rotate the image
129
+ rot90 = (round(rot/90) % 4) * 90
130
+ if rot90 == 90:
131
+ img = img.transpose(Image.Transpose.ROTATE_90)
132
+ elif rot90 == 180:
133
+ img = img.transpose(Image.Transpose.ROTATE_180)
134
+ elif rot90 == 270:
135
+ img = img.transpose(Image.Transpose.ROTATE_270)
136
+ return img
137
+
138
+ results = []
139
+ for (rect1, rect2), path in zip(crops, paths):
140
+ crop1 = prepare_crop(img1, rect1)
141
+ crop2 = prepare_crop(img2, rect2, rot)
142
+
143
+ fullpath1 = os.path.join(args.output_dir, path+'_1.jpg')
144
+ fullpath2 = os.path.join(args.output_dir, path+'_2.jpg')
145
+ os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
146
+
147
+ assert not os.path.isfile(fullpath1), fullpath1
148
+ assert not os.path.isfile(fullpath2), fullpath2
149
+ crop1.save(fullpath1)
150
+ crop2.save(fullpath2)
151
+ results.append(path)
152
+
153
+ return results
154
+
155
+
156
+ if __name__ == '__main__':
157
+ args = arg_parser().parse_args()
158
+ main(args)
159
+
dust3r/croco/datasets/habitat_sim/README.MD ADDED
@@ -0,0 +1,76 @@
+ ## Generation of synthetic image pairs using Habitat-Sim
+
+ These instructions allow you to generate pre-training pairs from the Habitat simulator.
+ As we did not save the metadata of the pairs used in the original paper, they are not strictly the same, but these data use the same settings and are equivalent.
+
+ ### Download Habitat-Sim scenes
+ Download Habitat-Sim scenes:
+ - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
+ - We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets.
+ - Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or manually update the paths in `paths.py`.
+ ```
+ ./data/
+ └──habitat-sim-data/
+    └──scene_datasets/
+       ├──hm3d/
+       ├──gibson/
+       ├──habitat-test-scenes/
+       ├──replica_cad_baked_lighting/
+       ├──replica_cad/
+       ├──ReplicaDataset/
+       └──scannet/
+ ```
+
+ ### Image pairs generation
+ We provide metadata to generate reproducible image pairs for pre-training and validation.
+ Experiments described in the paper used similar data, but their generation was not reproducible at the time.
+
+ Specifications:
+ - 256x256 resolution images, with a 60-degree field of view.
+ - Up to 1000 image pairs per scene.
+ - Number of scenes / number of image pairs per dataset:
+   - ScanNet: 1097 scenes / 985,209 pairs
+   - HM3D:
+     - hm3d/train: 800 scenes / 800k pairs
+     - hm3d/val: 100 scenes / 100k pairs
+     - hm3d/minival: 10 scenes / 10k pairs
+   - habitat-test-scenes: 3 scenes / 3k pairs
+   - replica_cad_baked_lighting: 13 scenes / 13k pairs
+ - Scenes from hm3d/val and hm3d/minival were not used for pre-training but kept for validation purposes.
+
+ Download the metadata and extract it:
+ ```bash
+ mkdir -p data/habitat_release_metadata/
+ cd data/habitat_release_metadata/
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz
+ tar -xvf multiview_habitat_metadata.tar.gz
+ cd ../..
+ # Location of the metadata
+ METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata"
+ ```
+
+ Generate image pairs from metadata:
+ - The following command will print a list of command lines to generate image pairs for each scene:
+ ```bash
+ # Target output directory
+ PAIRS_DATASET_DIR="./data/habitat_release/"
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR
+ ```
+ - Multiple such commands can be launched in parallel, e.g. using GNU Parallel:
+ ```bash
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16
+ ```
+
+ ## Metadata generation
+
+ Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible:
+ ```bash
+ # Print command lines to generate image pairs from the different scenes available.
+ PAIRS_DATASET_DIR=MY_CUSTOM_PATH
+ python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR
+
+ # Once a dataset is generated, pack metadata files for reproducibility.
+ METADATA_DIR=MY_CUSTOM_PATH
+ python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR
+ ```
dust3r/croco/datasets/habitat_sim/__init__.py ADDED
File without changes
dust3r/croco/datasets/habitat_sim/generate_from_metadata.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ """
+ Script to generate image pairs for a given scene, reproducing poses provided in a metadata file.
+ """
+ import os
+ from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator
+ from datasets.habitat_sim.paths import SCENES_DATASET
+ import argparse
+ import quaternion
+ import PIL.Image
+ import cv2
+ import json
+ from tqdm import tqdm
+
+ def generate_multiview_images_from_metadata(metadata_filename,
+                                             output_dir,
+                                             overload_params=dict(),
+                                             scene_datasets_paths=None,
+                                             exist_ok=False):
+     """
+     Generate images from a metadata file for reproducibility purposes.
+     """
+     # Reorder paths by decreasing label length, to avoid collisions when testing whether a string starts with such a label
+     if scene_datasets_paths is not None:
+         scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key=lambda x: len(x[0]), reverse=True))
+
+     with open(metadata_filename, 'r') as f:
+         input_metadata = json.load(f)
+     metadata = dict()
+     for key, value in input_metadata.items():
+         # Optionally replace some paths
+         if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
+             if scene_datasets_paths is not None:
+                 for dataset_label, dataset_path in scene_datasets_paths.items():
+                     if value.startswith(dataset_label):
+                         value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label)))
+                         break
+         metadata[key] = value
+
+     # Overload some parameters
+     for key, value in overload_params.items():
+         metadata[key] = value
+
+     generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))])
+     generate_depth = metadata["generate_depth"]
+
+     os.makedirs(output_dir, exist_ok=exist_ok)
+
+     generator = MultiviewHabitatSimGenerator(**generation_entries)
+
+     # Generate views
+     for idx_label, data in tqdm(metadata['multiviews'].items()):
+         positions = data["positions"]
+         orientations = data["orientations"]
+         n = len(positions)
+         for oidx in range(n):
+             observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx]))
+             observation_label = f"{oidx + 1}"  # Leonid is indexing starting from 1
+             # Color image saved using PIL
+             img = PIL.Image.fromarray(observation['color'][:, :, :3])
+             filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
+             img.save(filename)
+             if generate_depth:
+                 # Depth image as EXR file
+                 filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
+                 cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
+                 # Camera parameters
+                 camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
+                 filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
+                 with open(filename, "w") as f:
+                     json.dump(camera_params, f)
+     # Save metadata
+     with open(os.path.join(output_dir, "metadata.json"), "w") as f:
+         json.dump(metadata, f)
+
+     generator.close()
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--metadata_filename", required=True)
+     parser.add_argument("--output_dir", required=True)
+     args = parser.parse_args()
+
+     generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename,
+                                             output_dir=args.output_dir,
+                                             scene_datasets_paths=SCENES_DATASET,
+                                             overload_params=dict(),
+                                             exist_ok=True)
dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ """
+ Script generating command lines to generate image pairs from metadata files.
+ """
+ import os
+ import glob
+ from tqdm import tqdm
+ import argparse
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--input_dir", required=True)
+     parser.add_argument("--output_dir", required=True)
+     parser.add_argument("--prefix", default="", help="Command-line prefix, useful e.g. to set up the environment.")
+     args = parser.parse_args()
+
+     input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True)
+
+     for metadata_filename in tqdm(input_metadata_filenames):
+         output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir))
+         # Do not process the scene if the output metadata file already exists
+         if os.path.exists(os.path.join(output_dir, "metadata.json")):
+             continue
+         commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
+         print(commandline)
dust3r/croco/datasets/habitat_sim/generate_multiview_images.py ADDED
@@ -0,0 +1,177 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ import os
+ from tqdm import tqdm
+ import argparse
+ import PIL.Image
+ import numpy as np
+ import json
+ from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator, NoNaviguableSpaceError
+ from datasets.habitat_sim.paths import list_scenes_available
+ import cv2
+ import quaternion
+ import shutil
+
+ def generate_multiview_images_for_scene(scene_dataset_config_file,
+                                         scene,
+                                         navmesh,
+                                         output_dir,
+                                         views_count,
+                                         size,
+                                         exist_ok=False,
+                                         generate_depth=False,
+                                         **kwargs):
+     """
+     Generate tuples of overlapping views for a given scene.
+     generate_depth: generate depth images and camera parameters.
+     """
+     if os.path.exists(output_dir) and not exist_ok:
+         print(f"Scene {scene}: data already generated. Ignoring generation.")
+         return
+     try:
+         print(f"Scene {scene}: {size} multiview acquisitions to generate...")
+         os.makedirs(output_dir, exist_ok=exist_ok)
+
+         metadata_filename = os.path.join(output_dir, "metadata.json")
+
+         metadata_template = dict(scene_dataset_config_file=scene_dataset_config_file,
+                                  scene=scene,
+                                  navmesh=navmesh,
+                                  views_count=views_count,
+                                  size=size,
+                                  generate_depth=generate_depth,
+                                  **kwargs)
+         metadata_template["multiviews"] = dict()
+
+         if os.path.exists(metadata_filename):
+             print("Metadata file already exists:", metadata_filename)
+             print("Loading already generated metadata file...")
+             with open(metadata_filename, "r") as f:
+                 metadata = json.load(f)
+
+             for key in metadata_template.keys():
+                 if key != "multiviews":
+                     assert metadata_template[key] == metadata[key], f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}."
+         else:
+             print("No temporary file found. Starting generation from scratch...")
+             metadata = metadata_template
+
+         starting_id = len(metadata["multiviews"])
+         print(f"Starting generation from index {starting_id}/{size}...")
+         if starting_id >= size:
+             print("Generation already done.")
+             return
+
+         generator = MultiviewHabitatSimGenerator(scene_dataset_config_file=scene_dataset_config_file,
+                                                  scene=scene,
+                                                  navmesh=navmesh,
+                                                  views_count=views_count,
+                                                  size=size,
+                                                  **kwargs)
+
+         for idx in tqdm(range(starting_id, size)):
+             # Generate / re-generate the observations
+             try:
+                 data = generator[idx]
+                 observations = data["observations"]
+                 positions = data["positions"]
+                 orientations = data["orientations"]
+
+                 idx_label = f"{idx:08}"
+                 for oidx, observation in enumerate(observations):
+                     observation_label = f"{oidx + 1}"  # Leonid is indexing starting from 1
+                     # Color image saved using PIL
+                     img = PIL.Image.fromarray(observation['color'][:, :, :3])
+                     filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
+                     img.save(filename)
+                     if generate_depth:
+                         # Depth image as EXR file
+                         filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
+                         cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
+                         # Camera parameters
+                         camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
+                         filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
+                         with open(filename, "w") as f:
+                             json.dump(camera_params, f)
+                 metadata["multiviews"][idx_label] = {"positions": positions.tolist(),
+                                                      "orientations": orientations.tolist(),
+                                                      "covisibility_ratios": data["covisibility_ratios"].tolist(),
+                                                      "valid_fractions": data["valid_fractions"].tolist(),
+                                                      "pairwise_visibility_ratios": data["pairwise_visibility_ratios"].tolist()}
+             except RecursionError:
+                 print("Recursion error: unable to sample observations for this scene. We will stop there.")
+                 break
+
+             # Regularly save a temporary metadata file, in case we need to restart the generation
+             if idx % 10 == 0:
+                 with open(metadata_filename, "w") as f:
+                     json.dump(metadata, f)
+
+         # Save metadata
+         with open(metadata_filename, "w") as f:
+             json.dump(metadata, f)
+
+         generator.close()
+     except NoNaviguableSpaceError:
+         pass
+
+ def create_commandline(scene_data, generate_depth, exist_ok=False):
+     """
+     Create a command-line string to generate a scene.
+     """
+     def my_formatting(val):
+         if val is None or val == "":
+             return '""'
+         else:
+             return val
+     commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)}
+     --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)}
+     --navmesh {my_formatting(scene_data.navmesh)}
+     --output_dir {my_formatting(scene_data.output_dir)}
+     --generate_depth {int(generate_depth)}
+     --exist_ok {int(exist_ok)}
+     """
+     commandline = " ".join(commandline.split())
+     return commandline
+
+ if __name__ == "__main__":
+     os.umask(2)
+
+     parser = argparse.ArgumentParser(description="""Example of use -- listing commands to generate data for the scenes available:
+     > python datasets/habitat_sim/generate_multiview_images.py --list_commands
+     """)
+
+     parser.add_argument("--output_dir", type=str, required=True)
+     parser.add_argument("--list_commands", action='store_true', help="list command lines to run if true")
+     parser.add_argument("--scene", type=str, default="")
+     parser.add_argument("--scene_dataset_config_file", type=str, default="")
+     parser.add_argument("--navmesh", type=str, default="")
+
+     parser.add_argument("--generate_depth", type=int, default=1)
+     parser.add_argument("--exist_ok", type=int, default=0)
+
+     kwargs = dict(resolution=(256, 256), hfov=60, views_count=2, size=1000)
+
+     args = parser.parse_args()
+     generate_depth = bool(args.generate_depth)
+     exist_ok = bool(args.exist_ok)
+
+     if args.list_commands:
+         # Listing scenes available...
+         scenes_data = list_scenes_available(base_output_dir=args.output_dir)
+
+         for scene_data in scenes_data:
+             print(create_commandline(scene_data, generate_depth=generate_depth, exist_ok=exist_ok))
+     else:
+         if args.scene == "" or args.output_dir == "":
+             print("Missing scene or output dir argument!")
+             print(parser.format_help())
+         else:
+             generate_multiview_images_for_scene(scene=args.scene,
+                                                 scene_dataset_config_file=args.scene_dataset_config_file,
+                                                 navmesh=args.navmesh,
+                                                 output_dir=args.output_dir,
+                                                 exist_ok=exist_ok,
+                                                 generate_depth=generate_depth,
+                                                 **kwargs)
dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py ADDED
@@ -0,0 +1,390 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ import os
+ import numpy as np
+ import quaternion
+ import habitat_sim
+ import json
+ from sklearn.neighbors import NearestNeighbors
+ import cv2
+
+ # OpenCV to Habitat camera convention transformation
+ R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0)
+ R_HABITAT2OPENCV = R_OPENCV2HABITAT.T
+ DEG2RAD = np.pi / 180
+
+ def compute_camera_intrinsics(height, width, hfov):
+     f = width/2 / np.tan(hfov/2 * np.pi/180)
+     cu, cv = width/2, height/2
+     return f, cu, cv
+
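+ # Worked example (hypothetical values): for width=256 and hfov=60,
+ # f = 128 / tan(30°) ≈ 221.7 pixels, with the principal point (cu, cv)
+ # at the image center.
+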
22
+ def compute_camera_pose_opencv_convention(camera_position, camera_orientation):
23
+ R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT
24
+ t_cam2world = np.asarray(camera_position)
25
+ return R_cam2world, t_cam2world
26
+
27
+ def compute_pointmap(depthmap, hfov):
28
+ """ Compute a HxWx3 pointmap in camera frame from a HxW depth map."""
29
+ height, width = depthmap.shape
30
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
31
+ # Cast depth map to point
32
+ z_cam = depthmap
33
+ u, v = np.meshgrid(range(width), range(height))
34
+ x_cam = (u - cu) / f * z_cam
35
+ y_cam = (v - cv) / f * z_cam
36
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1)
37
+ return X_cam
38
+
39
+ def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation):
40
+ """Return a 3D point cloud corresponding to valid pixels of the depth map"""
41
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_position, camera_rotation)
42
+
43
+ X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov)
44
+ valid_mask = (X_cam[:,:,2] != 0.0)
45
+
46
+ X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()]
47
+ X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3)
48
+ return X_world
49
+
50
+ def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False):
51
+ """
52
+ Compute 'overlapping' metrics based on a distance threshold between two point clouds.
53
+ """
54
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud2)
55
+ distances, indices = nbrs.kneighbors(pointcloud1)
56
+ intersection1 = np.count_nonzero(distances.flatten() < distance_threshold)
57
+
58
+ data = {"intersection1": intersection1,
59
+ "size1": len(pointcloud1)}
60
+ if compute_symmetric:
61
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud1)
62
+ distances, indices = nbrs.kneighbors(pointcloud2)
63
+ intersection2 = np.count_nonzero(distances.flatten() < distance_threshold)
64
+ data["intersection2"] = intersection2
65
+ data["size2"] = len(pointcloud2)
66
+
67
+ return data
68
+
69
+ def _append_camera_parameters(observation, hfov, camera_location, camera_rotation):
70
+ """
71
+ Add camera parameters to the observation dictionnary produced by Habitat-Sim
72
+ In-place modifications.
73
+ """
74
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_location, camera_rotation)
75
+ height, width = observation['depth'].shape
76
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
77
+ K = np.asarray([[f, 0, cu],
78
+ [0, f, cv],
79
+ [0, 0, 1.0]])
80
+ observation["camera_intrinsics"] = K
81
+ observation["t_cam2world"] = t_cam2world
82
+ observation["R_cam2world"] = R_cam2world
83
+
84
+ def look_at(eye, center, up, return_cam2world=True):
85
+ """
86
+ Return camera pose looking at a given center point.
87
+ Analogous of gluLookAt function, using OpenCV camera convention.
88
+ """
89
+ z = center - eye
90
+ z /= np.linalg.norm(z, axis=-1, keepdims=True)
91
+ y = -up
92
+ y = y - np.sum(y * z, axis=-1, keepdims=True) * z
93
+ y /= np.linalg.norm(y, axis=-1, keepdims=True)
94
+ x = np.cross(y, z, axis=-1)
95
+
96
+ if return_cam2world:
97
+ R = np.stack((x, y, z), axis=-1)
98
+ t = eye
99
+ else:
100
+ # World to camera transformation
101
+ # Transposed matrix
102
+ R = np.stack((x, y, z), axis=-2)
103
+ t = - np.einsum('...ij, ...j', R, eye)
104
+ return R, t
105
+
106
+ def look_at_for_habitat(eye, center, up, return_cam2world=True):
107
+ R, t = look_at(eye, center, up)
108
+ orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T)
109
+ return orientation, t
110
+
111
+ def generate_orientation_noise(pan_range, tilt_range, roll_range):
112
+ return (quaternion.from_rotation_vector(np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP)
113
+ * quaternion.from_rotation_vector(np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT)
114
+ * quaternion.from_rotation_vector(np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT))
115
+
116
+
117
+ class NoNaviguableSpaceError(RuntimeError):
118
+ def __init__(self, *args):
119
+ super().__init__(*args)
120
+
121
+ class MultiviewHabitatSimGenerator:
122
+ def __init__(self,
123
+ scene,
124
+ navmesh,
125
+ scene_dataset_config_file,
126
+ resolution = (240, 320),
127
+ views_count=2,
128
+ hfov = 60,
129
+ gpu_id = 0,
130
+ size = 10000,
131
+ minimum_covisibility = 0.5,
132
+ transform = None):
133
+ self.scene = scene
134
+ self.navmesh = navmesh
135
+ self.scene_dataset_config_file = scene_dataset_config_file
136
+ self.resolution = resolution
137
+ self.views_count = views_count
138
+ assert(self.views_count >= 1)
139
+ self.hfov = hfov
140
+ self.gpu_id = gpu_id
141
+ self.size = size
142
+ self.transform = transform
143
+
144
+ # Noise added to camera orientation
145
+ self.pan_range = (-3, 3)
146
+ self.tilt_range = (-10, 10)
147
+ self.roll_range = (-5, 5)
148
+
149
+ # Height range to sample cameras
150
+ self.height_range = (1.2, 1.8)
151
+
152
+ # Random steps between the camera views
153
+ self.random_steps_count = 5
154
+ self.random_step_variance = 2.0
155
+
156
+ # Minimum fraction of the scene which should be valid (well defined depth)
157
+ self.minimum_valid_fraction = 0.7
158
+
159
+ # Distance threshold to see to select pairs
160
+ self.distance_threshold = 0.05
161
+ # Minimum IoU of a view point cloud with respect to the reference view to be kept.
162
+ self.minimum_covisibility = minimum_covisibility
163
+
164
+ # Maximum number of retries.
165
+ self.max_attempts_count = 100
166
+
167
+ self.seed = None
168
+ self._lazy_initialization()
169
+
170
+ def _lazy_initialization(self):
171
+ # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
172
+ if self.seed == None:
173
+ # Re-seed numpy generator
174
+ np.random.seed()
175
+ self.seed = np.random.randint(2**32-1)
176
+ sim_cfg = habitat_sim.SimulatorConfiguration()
177
+ sim_cfg.scene_id = self.scene
178
+ if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "":
179
+ sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
180
+ sim_cfg.random_seed = self.seed
181
+ sim_cfg.load_semantic_mesh = False
182
+ sim_cfg.gpu_device_id = self.gpu_id
183
+
184
+ depth_sensor_spec = habitat_sim.CameraSensorSpec()
185
+ depth_sensor_spec.uuid = "depth"
186
+ depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
187
+ depth_sensor_spec.resolution = self.resolution
188
+ depth_sensor_spec.hfov = self.hfov
189
+ depth_sensor_spec.position = [0.0, 0.0, 0]
190
+ # the sensor orientation is left at its default value
191
+
192
+ rgb_sensor_spec = habitat_sim.CameraSensorSpec()
193
+ rgb_sensor_spec.uuid = "color"
194
+ rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
195
+ rgb_sensor_spec.resolution = self.resolution
196
+ rgb_sensor_spec.hfov = self.hfov
197
+ rgb_sensor_spec.position = [0.0, 0.0, 0]
198
+ agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec, depth_sensor_spec])
199
+
200
+ cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
201
+ self.sim = habitat_sim.Simulator(cfg)
202
+ if self.navmesh is not None and self.navmesh != "":
203
+ # Use pre-computed navmesh when available (usually better than those generated automatically)
204
+ self.sim.pathfinder.load_nav_mesh(self.navmesh)
205
+
206
+ if not self.sim.pathfinder.is_loaded:
207
+ # Try to compute a navmesh
208
+ navmesh_settings = habitat_sim.NavMeshSettings()
209
+ navmesh_settings.set_defaults()
210
+ self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)
211
+
212
+ # Ensure that the navmesh is not empty
213
+ if not self.sim.pathfinder.is_loaded:
214
+ raise NoNaviguableSpaceError(f"No navigable location (scene: {self.scene} -- navmesh: {self.navmesh})")
215
+
216
+ self.agent = self.sim.initialize_agent(agent_id=0)
217
+
218
+ def close(self):
219
+ self.sim.close()
220
+
221
+ def __del__(self):
222
+ self.sim.close()
223
+
224
+ def __len__(self):
225
+ return self.size
226
+
227
+ def sample_random_viewpoint(self):
228
+ """ Sample a random viewpoint using the navmesh """
229
+ nav_point = self.sim.pathfinder.get_random_navigable_point()
230
+
231
+ # Sample a random viewpoint height
232
+ viewpoint_height = np.random.uniform(*self.height_range)
233
+ viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP
234
+ viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
235
+ return viewpoint_position, viewpoint_orientation, nav_point
236
+
237
+ def sample_other_random_viewpoint(self, observed_point, nav_point):
238
+ """ Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point."""
239
+ other_nav_point = nav_point
240
+
241
+ walk_directions = self.random_step_variance * np.asarray([1,0,1])
242
+ for i in range(self.random_steps_count):
243
+ temp = self.sim.pathfinder.snap_point(other_nav_point + walk_directions * np.random.normal(size=3))
244
+ # Snapping may return nan when it fails
245
+ if not np.isnan(temp[0]):
246
+ other_nav_point = temp
247
+
248
+ other_viewpoint_height = np.random.uniform(*self.height_range)
249
+ other_viewpoint_position = other_nav_point + other_viewpoint_height * habitat_sim.geo.UP
250
+
251
+ # Set viewing direction towards the central point
252
+ rotation, position = look_at_for_habitat(eye=other_viewpoint_position, center=observed_point, up=habitat_sim.geo.UP, return_cam2world=True)
253
+ rotation = rotation * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
254
+ return position, rotation, other_nav_point
255
+
256
+ def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud):
257
+ """ Check if a viewpoint is valid and overlaps significantly with a reference one. """
258
+ # Observation
259
+ pixels_count = self.resolution[0] * self.resolution[1]
260
+ valid_fraction = len(other_pointcloud) / pixels_count
261
+ assert valid_fraction <= 1.0 and valid_fraction >= 0.0
262
+ overlap = compute_pointcloud_overlaps_scikit(ref_pointcloud, other_pointcloud, self.distance_threshold, compute_symmetric=True)
263
+ covisibility = min(overlap["intersection1"] / pixels_count, overlap["intersection2"] / pixels_count)
264
+ is_valid = (valid_fraction >= self.minimum_valid_fraction) and (covisibility >= self.minimum_covisibility)
265
+ return is_valid, valid_fraction, covisibility
266
+
267
+ def is_other_viewpoint_overlapping(self, ref_pointcloud, observation, position, rotation):
268
+ """ Check if a viewpoint is valid and overlaps significantly with a reference one. """
269
+ # Observation
270
+ other_pointcloud = compute_pointcloud(observation['depth'], self.hfov, position, rotation)
271
+ return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
272
+
273
+ def render_viewpoint(self, viewpoint_position, viewpoint_orientation):
274
+ agent_state = habitat_sim.AgentState()
275
+ agent_state.position = viewpoint_position
276
+ agent_state.rotation = viewpoint_orientation
277
+ self.agent.set_state(agent_state)
278
+ viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)
279
+ _append_camera_parameters(viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation)
280
+ return viewpoint_observations
281
+
282
+ def __getitem__(self, useless_idx):
283
+ ref_position, ref_orientation, nav_point = self.sample_random_viewpoint()
284
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
285
+ # Extract point cloud
286
+ ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
287
+ camera_position=ref_position, camera_rotation=ref_orientation)
288
+
289
+ pixels_count = self.resolution[0] * self.resolution[1]
290
+ ref_valid_fraction = len(ref_pointcloud) / pixels_count
291
+ assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0
292
+ if ref_valid_fraction < self.minimum_valid_fraction:
293
+ # This should produce a recursion error at some point when something is very wrong.
294
+ return self[0]
295
+ # Pick a reference observed point in the point cloud
296
+ observed_point = np.mean(ref_pointcloud, axis=0)
297
+
298
+ # Add the first image as reference
299
+ viewpoints_observations = [ref_observations]
300
+ viewpoints_covisibility = [ref_valid_fraction]
301
+ viewpoints_positions = [ref_position]
302
+ viewpoints_orientations = [quaternion.as_float_array(ref_orientation)]
303
+ viewpoints_clouds = [ref_pointcloud]
304
+ viewpoints_valid_fractions = [ref_valid_fraction]
305
+
306
+ for _ in range(self.views_count - 1):
307
+ # Generate another viewpoint using a simple random walk
308
+ successful_sampling = False
309
+ for sampling_attempt in range(self.max_attempts_count):
310
+ position, rotation, _ = self.sample_other_random_viewpoint(observed_point, nav_point)
311
+ # Observation
312
+ other_viewpoint_observations = self.render_viewpoint(position, rotation)
313
+ other_pointcloud = compute_pointcloud(other_viewpoint_observations['depth'], self.hfov, position, rotation)
314
+
315
+ is_valid, valid_fraction, covisibility = self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
316
+ if is_valid:
317
+ successful_sampling = True
318
+ break
319
+ if not successful_sampling:
320
+ print("WARNING: Maximum number of attempts reached.")
321
+ # Dirty hack: retry with a new reference viewpoint
322
+ return self[0]
323
+ viewpoints_observations.append(other_viewpoint_observations)
324
+ viewpoints_covisibility.append(covisibility)
325
+ viewpoints_positions.append(position)
326
+ viewpoints_orientations.append(quaternion.as_float_array(rotation)) # WXYZ convention for the quaternion encoding.
327
+ viewpoints_clouds.append(other_pointcloud)
328
+ viewpoints_valid_fractions.append(valid_fraction)
329
+
330
+ # Estimate relations between all pairs of images
331
+ pairwise_visibility_ratios = np.ones((len(viewpoints_observations), len(viewpoints_observations)))
332
+ for i in range(len(viewpoints_observations)):
333
+ pairwise_visibility_ratios[i,i] = viewpoints_valid_fractions[i]
334
+ for j in range(i+1, len(viewpoints_observations)):
335
+ overlap = compute_pointcloud_overlaps_scikit(viewpoints_clouds[i], viewpoints_clouds[j], self.distance_threshold, compute_symmetric=True)
336
+ pairwise_visibility_ratios[i,j] = overlap['intersection1'] / pixels_count
337
+ pairwise_visibility_ratios[j,i] = overlap['intersection2'] / pixels_count
338
+
339
+ # IoU is relative to image 0
340
+ data = {"observations": viewpoints_observations,
341
+ "positions": np.asarray(viewpoints_positions),
342
+ "orientations": np.asarray(viewpoints_orientations),
343
+ "covisibility_ratios": np.asarray(viewpoints_covisibility),
344
+ "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float),
345
+ "pairwise_visibility_ratios": np.asarray(pairwise_visibility_ratios, dtype=float),
346
+ }
347
+
348
+ if self.transform is not None:
349
+ data = self.transform(data)
350
+ return data
351
+
352
+ def generate_random_spiral_trajectory(self, images_count = 100, max_radius=0.5, half_turns=5, use_constant_orientation=False):
353
+ """
354
+ Return a list of images corresponding to a spiral trajectory from a random starting point.
355
+ Useful to generate nice visualisations.
356
+ Use an even number of half turns to get a nice "C1-continuous" loop effect
357
+ """
358
+ ref_position, ref_orientation, navpoint = self.sample_random_viewpoint()
359
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
360
+ ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
361
+ camera_position=ref_position, camera_rotation=ref_orientation)
362
+ pixels_count = self.resolution[0] * self.resolution[1]
363
+ if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction:
364
+ # Dirty hack: ensure that the valid part of the image is significant
365
+ return self.generate_random_spiral_trajectory(images_count, max_radius, half_turns, use_constant_orientation)
366
+
367
+ # Pick an observed point in the point cloud
368
+ observed_point = np.mean(ref_pointcloud, axis=0)
369
+ ref_R, ref_t = compute_camera_pose_opencv_convention(ref_position, ref_orientation)
370
+
371
+ images = []
372
+ is_valid = []
373
+ # Spiral trajectory, optionally with a constant orientation
374
+ for i, alpha in enumerate(np.linspace(0, 1, images_count)):
375
+ r = max_radius * np.abs(np.sin(alpha * np.pi)) # Increase then decrease the radius
376
+ theta = alpha * half_turns * np.pi
377
+ x = r * np.cos(theta)
378
+ y = r * np.sin(theta)
379
+ z = 0.0
380
+ position = ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3,1)).flatten()
381
+ if use_constant_orientation:
382
+ orientation = ref_orientation
383
+ else:
384
+ # trajectory looking at a mean point in front of the ref observation
385
+ orientation, position = look_at_for_habitat(eye=position, center=observed_point, up=habitat_sim.geo.UP)
386
+ observations = self.render_viewpoint(position, orientation)
387
+ images.append(observations['color'][...,:3])
388
+ _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(ref_pointcloud, observations, position, orientation)
389
+ is_valid.append(_is_valid)
390
+ return images, np.all(is_valid)
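For orientation, a minimal usage sketch of the generator defined above, assuming habitat-sim, numpy-quaternion and the scene assets are installed; all paths are placeholders:

# hedged usage sketch; scene/navmesh paths below are placeholders
generator = MultiviewHabitatSimGenerator(
    scene="./data/scene.glb",
    navmesh="./data/scene.navmesh",
    scene_dataset_config_file="",
    views_count=2,
    minimum_covisibility=0.5)
sample = generator[0]  # the index is ignored; every access renders a fresh multi-view sample
print(sample["positions"].shape, sample["covisibility_ratios"])
generator.close()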
dust3r/croco/datasets/habitat_sim/pack_metadata_files.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ """
4
+ Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere.
5
+ """
6
+ import os
7
+ import glob
8
+ from tqdm import tqdm
9
+ import shutil
10
+ import json
11
+ from datasets.habitat_sim.paths import *
12
+ import argparse
13
+ import collections
14
+
15
+ if __name__ == "__main__":
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument("input_dir")
18
+ parser.add_argument("output_dir")
19
+ args = parser.parse_args()
20
+
21
+ input_dirname = args.input_dir
22
+ output_dirname = args.output_dir
23
+
24
+ input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True)
25
+
26
+ images_count = collections.defaultdict(int)
27
+
28
+ os.makedirs(output_dirname)
29
+ for input_filename in tqdm(input_metadata_filenames):
30
+ # Skip metadata files that contain no views
31
+ with open(input_filename, "r") as f:
32
+ original_metadata = json.load(f)
33
+ if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0:
34
+ print("No views in", input_filename)
35
+ continue
36
+
37
+ relpath = os.path.relpath(input_filename, input_dirname)
38
+ print(relpath)
39
+
40
+ # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability.
41
+ # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting by the same string pattern.
42
+ scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True))
43
+ metadata = dict()
44
+ for key, value in original_metadata.items():
45
+ if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
46
+ known_path = False
47
+ for dataset, dataset_path in scenes_dataset_paths.items():
48
+ if value.startswith(dataset_path):
49
+ value = os.path.join(dataset, os.path.relpath(value, dataset_path))
50
+ known_path = True
51
+ break
52
+ if not known_path:
53
+ raise KeyError("Unknown path: " + value)
54
+ metadata[key] = value
55
+
56
+ # Compile some general statistics while packing data
57
+ scene_split = metadata["scene"].split("/")
58
+ upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0]
59
+ images_count[upper_level] += len(metadata["multiviews"])
60
+
61
+ output_filename = os.path.join(output_dirname, relpath)
62
+ os.makedirs(os.path.dirname(output_filename), exist_ok=True)
63
+ with open(output_filename, "w") as f:
64
+ json.dump(metadata, f)
65
+
66
+ # Print statistics
67
+ print("Images count:")
68
+ for upper_level, count in images_count.items():
69
+ print(f"- {upper_level}: {count}")
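A typical invocation of this packing script, run from the croco root so that the `datasets.habitat_sim.paths` import resolves (both paths are placeholders): `python -m datasets.habitat_sim.pack_metadata_files /path/to/raw_dataset /path/to/packed_metadata`.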
dust3r/croco/datasets/habitat_sim/paths.py ADDED
@@ -0,0 +1,129 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Paths to Habitat-Sim scenes
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import collections
11
+ from tqdm import tqdm
12
+
13
+
14
+ # Hardcoded path to the different scene datasets
15
+ SCENES_DATASET = {
16
+ "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/",
17
+ "gibson": "./data/habitat-sim-data/scene_datasets/gibson/",
18
+ "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/",
19
+ "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/",
20
+ "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/",
21
+ "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/",
22
+ "scannet": "./data/habitat-sim/scene_datasets/scannet/"
23
+ }
24
+
25
+ SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"])
26
+
27
+ def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]):
28
+ scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json")
29
+ scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"]
30
+ navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
31
+ scenes_data = []
32
+ for idx in range(len(scenes)):
33
+ output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx])
34
+ # Add scene
35
+ data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
36
+ scene = scenes[idx] + ".scene_instance.json",
37
+ navmesh = os.path.join(base_path, navmeshes[idx]),
38
+ output_dir = output_dir)
39
+ scenes_data.append(data)
40
+ return scenes_data
41
+
42
+ def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]):
43
+ scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json")
44
+ scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], [])
45
+ # no precomputed navmeshes are available for these scenes (navmesh is left empty below)
46
+ scenes_data = []
47
+ for idx in range(len(scenes)):
48
+ output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx])
49
+ data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
50
+ scene = scenes[idx],
51
+ navmesh = "",
52
+ output_dir = output_dir)
53
+ scenes_data.append(data)
54
+ return scenes_data
55
+
56
+ def list_replica_scenes(base_output_dir, base_path):
57
+ scenes_data = []
58
+ for scene_id in os.listdir(base_path):
59
+ scene = os.path.join(base_path, scene_id, "mesh.ply")
60
+ navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it
61
+ scene_dataset_config_file = ""
62
+ output_dir = os.path.join(base_output_dir, scene_id)
63
+ # Add scene
64
+ data = SceneData(scene_dataset_config_file = scene_dataset_config_file,
65
+ scene = scene,
66
+ navmesh = navmesh,
67
+ output_dir = output_dir)
68
+ scenes_data.append(data)
69
+ return scenes_data
70
+
71
+
72
+ def list_scenes(base_output_dir, base_path):
73
+ """
74
+ Generic method iterating through a base_path folder to find scenes.
75
+ """
76
+ scenes_data = []
77
+ for root, dirs, files in os.walk(base_path, followlinks=True):
78
+ folder_scenes_data = []
79
+ for file in files:
80
+ name, ext = os.path.splitext(file)
81
+ if ext == ".glb":
82
+ scene = os.path.join(root, name + ".glb")
83
+ navmesh = os.path.join(root, name + ".navmesh")
84
+ if not os.path.exists(navmesh):
85
+ navmesh = ""
86
+ relpath = os.path.relpath(root, base_path)
87
+ output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name))
88
+ data = SceneData(scene_dataset_config_file="",
89
+ scene = scene,
90
+ navmesh = navmesh,
91
+ output_dir = output_dir)
92
+ folder_scenes_data.append(data)
93
+
94
+ # Specific check for HM3D:
95
+ # When two meshes xxxx.basis.glb and xxxx.glb are present, use the 'basis' version.
96
+ basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")]
97
+ if len(basis_scenes) != 0:
98
+ folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)]
99
+
100
+ scenes_data.extend(folder_scenes_data)
101
+ return scenes_data
102
+
103
+ def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET):
104
+ scenes_data = []
105
+
106
+ # HM3D
107
+ for split in ("minival", "train", "val", "examples"):
108
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"),
109
+ base_path=f"{scenes_dataset_paths['hm3d']}/{split}")
110
+
111
+ # Gibson
112
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"),
113
+ base_path=scenes_dataset_paths["gibson"])
114
+
115
+ # Habitat test scenes (just a few)
116
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"),
117
+ base_path=scenes_dataset_paths["habitat-test-scenes"])
118
+
119
+ # ReplicaCAD (baked lightning)
120
+ scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir)
121
+
122
+ # ScanNet
123
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"),
124
+ base_path=scenes_dataset_paths["scannet"])
125
+
126
+ # Replica
127
+ scenes_data += list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"),
128
+ base_path=scenes_dataset_paths["replica"])
129
+ return scenes_data
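A short sketch of how these helpers fit together, assuming the hardcoded SCENES_DATASET paths exist (the output directory is a placeholder):

# enumerate every scene reachable through SCENES_DATASET
scenes = list_scenes_available(base_output_dir="./output/")
print(len(scenes), "scenes found")
first = scenes[0]  # a SceneData namedtuple
print(first.scene, "->", first.output_dir)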
dust3r/croco/datasets/pairs_dataset.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
+
8
+ from datasets.transforms import get_pair_transforms
9
+
10
+ def load_image(impath):
11
+ return Image.open(impath)
12
+
13
+ def load_pairs_from_cache_file(fname, root=''):
14
+ assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
15
+ with open(fname, 'r') as fid:
16
+ lines = fid.read().strip().splitlines()
17
+ pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines]
18
+ return pairs
19
+
20
+ def load_pairs_from_list_file(fname, root=''):
21
+ assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
22
+ with open(fname, 'r') as fid:
23
+ lines = fid.read().strip().splitlines()
24
+ pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')]
25
+ return pairs
26
+
27
+
28
+ def write_cache_file(fname, pairs, root=''):
29
+ if len(root)>0:
30
+ if not root.endswith('/'): root+='/'
31
+ assert os.path.isdir(root)
32
+ s = ''
33
+ for im1, im2 in pairs:
34
+ if len(root)>0:
35
+ assert im1.startswith(root), im1
36
+ assert im2.startswith(root), im2
37
+ s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):])
38
+ with open(fname, 'w') as fid:
39
+ fid.write(s[:-1])
40
+
41
+ def parse_and_cache_all_pairs(dname, data_dir='./data/'):
42
+ if dname=='habitat_release':
43
+ dirname = os.path.join(data_dir, 'habitat_release')
44
+ assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
45
+ cache_file = os.path.join(dirname, 'pairs.txt')
46
+ assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file
47
+
48
+ print('Parsing pairs for dataset: '+dname)
49
+ pairs = []
50
+ for root, dirs, files in os.walk(dirname):
51
+ if 'val' in root: continue
52
+ dirs.sort()
53
+ pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')]
54
+ print('Found {:,} pairs'.format(len(pairs)))
55
+ print('Writing cache to: '+cache_file)
56
+ write_cache_file(cache_file, pairs, root=dirname)
57
+
58
+ else:
59
+ raise NotImplementedError('Unknown dataset: '+dname)
60
+
61
+ def dnames_to_image_pairs(dnames, data_dir='./data/'):
62
+ """
63
+ dnames: list of datasets with image pairs, separated by +
64
+ """
65
+ all_pairs = []
66
+ for dname in dnames.split('+'):
67
+ if dname=='habitat_release':
68
+ dirname = os.path.join(data_dir, 'habitat_release')
69
+ assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
70
+ cache_file = os.path.join(dirname, 'pairs.txt')
71
+ assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file
72
+ pairs = load_pairs_from_cache_file(cache_file, root=dirname)
73
+ elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']:
74
+ dirname = os.path.join(data_dir, dname+'_crops')
75
+ assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
76
+ list_file = os.path.join(dirname, 'listing.txt')
77
+ assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. {:s}".format(dname, list_file)
78
+ pairs = load_pairs_from_list_file(list_file, root=dirname)
79
+ print(' {:s}: {:,} pairs'.format(dname, len(pairs)))
80
+ all_pairs += pairs
81
+ if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs)))
82
+ return all_pairs
83
+
84
+
85
+ class PairsDataset(Dataset):
86
+
87
+ def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'):
88
+ super().__init__()
89
+ self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
90
+ self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize)
91
+
92
+ def __len__(self):
93
+ return len(self.image_pairs)
94
+
95
+ def __getitem__(self, index):
96
+ im1path, im2path = self.image_pairs[index]
97
+ im1 = load_image(im1path)
98
+ im2 = load_image(im2path)
99
+ if self.transforms is not None: im1, im2 = self.transforms(im1, im2)
100
+ return im1, im2
101
+
102
+
103
+ if __name__=="__main__":
104
+ import argparse
105
+ parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset")
106
+ parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored")
107
+ parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset")
108
+ args = parser.parse_args()
109
+ parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
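A hedged usage sketch for the dataset class, assuming the habitat_release pair cache has already been built as described above:

from torch.utils.data import DataLoader

dataset = PairsDataset("habitat_release", trfs="crop224+acolor", data_dir="./data/")
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
im1, im2 = next(iter(loader))  # two aligned batches of shape (32, 3, 224, 224)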
dust3r/croco/datasets/transforms.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+ import torchvision.transforms
6
+ import torchvision.transforms.functional as F
7
+
8
+ # "Pair": apply a transform on a pair
9
+ # "Both": apply the exact same transform to both images
10
+
11
+ class ComposePair(torchvision.transforms.Compose):
12
+ def __call__(self, img1, img2):
13
+ for t in self.transforms:
14
+ img1, img2 = t(img1, img2)
15
+ return img1, img2
16
+
17
+ class NormalizeBoth(torchvision.transforms.Normalize):
18
+ def forward(self, img1, img2):
19
+ img1 = super().forward(img1)
20
+ img2 = super().forward(img2)
21
+ return img1, img2
22
+
23
+ class ToTensorBoth(torchvision.transforms.ToTensor):
24
+ def __call__(self, img1, img2):
25
+ img1 = super().__call__(img1)
26
+ img2 = super().__call__(img2)
27
+ return img1, img2
28
+
29
+ class RandomCropPair(torchvision.transforms.RandomCrop):
30
+ # the crop will be intentionally different for the two images with this class
31
+ def forward(self, img1, img2):
32
+ img1 = super().forward(img1)
33
+ img2 = super().forward(img2)
34
+ return img1, img2
35
+
36
+ class ColorJitterPair(torchvision.transforms.ColorJitter):
37
+ # can be symmetric (same for both images) or asymmetric (different jitter params for each image) depending on assymetric_prob
38
+ def __init__(self, assymetric_prob, **kwargs):
39
+ super().__init__(**kwargs)
40
+ self.assymetric_prob = assymetric_prob
41
+ def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor):
42
+ for fn_id in fn_idx:
43
+ if fn_id == 0 and brightness_factor is not None:
44
+ img = F.adjust_brightness(img, brightness_factor)
45
+ elif fn_id == 1 and contrast_factor is not None:
46
+ img = F.adjust_contrast(img, contrast_factor)
47
+ elif fn_id == 2 and saturation_factor is not None:
48
+ img = F.adjust_saturation(img, saturation_factor)
49
+ elif fn_id == 3 and hue_factor is not None:
50
+ img = F.adjust_hue(img, hue_factor)
51
+ return img
52
+
53
+ def forward(self, img1, img2):
54
+
55
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
56
+ self.brightness, self.contrast, self.saturation, self.hue
57
+ )
58
+ img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
59
+ if torch.rand(1) < self.assymetric_prob: # asymmetric: re-draw jitter params for the second image
60
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
61
+ self.brightness, self.contrast, self.saturation, self.hue
62
+ )
63
+ img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
64
+ return img1, img2
65
+
66
+ def get_pair_transforms(transform_str, totensor=True, normalize=True):
67
+ # transform_str is eg crop224+color
68
+ trfs = []
69
+ for s in transform_str.split('+'):
70
+ if s.startswith('crop'):
71
+ size = int(s[len('crop'):])
72
+ trfs.append(RandomCropPair(size))
73
+ elif s=='acolor':
74
+ trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0))
75
+ elif s=='': # if transform_str was ""
76
+ pass
77
+ else:
78
+ raise NotImplementedError('Unknown augmentation: '+s)
79
+
80
+ if totensor:
81
+ trfs.append( ToTensorBoth() )
82
+ if normalize:
83
+ trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) )
84
+
85
+ if len(trfs)==0:
86
+ return None
87
+ elif len(trfs)==1:
88
+ return trfs[0]
89
+ else:
90
+ return ComposePair(trfs)
91
+
92
+
93
+
94
+
95
+
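A minimal sketch of building and applying a pair transform (the blank images are placeholders for a real pair):

from PIL import Image

trf = get_pair_transforms("crop224+acolor")  # random crop + asymmetric color jitter
img1 = Image.new("RGB", (320, 240))
img2 = Image.new("RGB", (320, 240))
t1, t2 = trf(img1, img2)  # two normalized tensors of shape (3, 224, 224)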
dust3r/croco/models/__pycache__/blocks.cpython-312.pyc ADDED
Binary file (19.6 kB).
dust3r/croco/models/__pycache__/croco.cpython-312.pyc ADDED
Binary file (15.2 kB).
dust3r/croco/models/__pycache__/dpt_block.cpython-312.pyc ADDED
Binary file (16.9 kB).
dust3r/croco/models/__pycache__/masking.cpython-312.pyc ADDED
Binary file (1.28 kB).
dust3r/croco/models/__pycache__/pos_embed.cpython-312.pyc ADDED
Binary file (8.31 kB).
dust3r/croco/models/blocks.py ADDED
@@ -0,0 +1,307 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Main encoder/decoder blocks
7
+ # --------------------------------------------------------
8
+ # References:
9
+ # timm
10
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
11
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
12
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
13
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
14
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from itertools import repeat
21
+ import collections.abc
22
+
23
+
24
+ def _ntuple(n):
25
+ def parse(x):
26
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
27
+ return x
28
+ return tuple(repeat(x, n))
29
+ return parse
30
+ to_2tuple = _ntuple(2)
31
+
32
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
33
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
34
+ """
35
+ if drop_prob == 0. or not training:
36
+ return x
37
+ keep_prob = 1 - drop_prob
38
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
39
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
40
+ if keep_prob > 0.0 and scale_by_keep:
41
+ random_tensor.div_(keep_prob)
42
+ return x * random_tensor
43
+
44
+ class DropPath(nn.Module):
45
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
46
+ """
47
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
48
+ super(DropPath, self).__init__()
49
+ self.drop_prob = drop_prob
50
+ self.scale_by_keep = scale_by_keep
51
+
52
+ def forward(self, x):
53
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
54
+
55
+ def extra_repr(self):
56
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
57
+
58
+ class Mlp(nn.Module):
59
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
60
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
61
+ super().__init__()
62
+ out_features = out_features or in_features
63
+ hidden_features = hidden_features or in_features
64
+ bias = to_2tuple(bias)
65
+ drop_probs = to_2tuple(drop)
66
+
67
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
68
+ self.act = act_layer()
69
+ self.drop1 = nn.Dropout(drop_probs[0])
70
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
71
+ self.drop2 = nn.Dropout(drop_probs[1])
72
+
73
+ def forward(self, x):
74
+ x = self.fc1(x)
75
+ x = self.act(x)
76
+ x = self.drop1(x)
77
+ x = self.fc2(x)
78
+ x = self.drop2(x)
79
+ return x
80
+
81
+ class Attention(nn.Module):
82
+
83
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
84
+ super().__init__()
85
+ self.num_heads = num_heads
86
+ head_dim = dim // num_heads
87
+ self.scale = head_dim ** -0.5
88
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
89
+ self.attn_drop = nn.Dropout(attn_drop)
90
+ self.proj = nn.Linear(dim, dim)
91
+ self.proj_drop = nn.Dropout(proj_drop)
92
+ self.rope = rope
93
+
94
+ def forward(self, x, xpos):
95
+ B, N, C = x.shape
96
+
97
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
98
+ q, k, v = [qkv[:,:,i] for i in range(3)]
99
+ # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
100
+
101
+ if self.rope is not None:
102
+ q = self.rope(q, xpos)
103
+ k = self.rope(k, xpos)
104
+
105
+ # attn = (q @ k.transpose(-2, -1)) * self.scale
106
+ # attn = attn.softmax(dim=-1)
107
+ # attn = self.attn_drop(attn)
108
+ # x_old = (attn @ v).transpose(1, 2).reshape(B, N, C)
109
+ # with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
110
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale = self.scale, dropout_p=0.).transpose(1, 2).reshape(B, N, C)
113
+ x = self.proj(x)
114
+ x = self.proj_drop(x)
115
+ return x
116
+
117
+ class LayerNorm(nn.LayerNorm):
118
+ def forward(self, x):
119
+ t = x.dtype
120
+ x = super().forward(x.type(torch.float32))
121
+ return x.type(t)
122
+
123
+ class Block(nn.Module):
124
+
125
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
126
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
127
+ super().__init__()
128
+ self.norm1 = norm_layer(dim)
129
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
130
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
131
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
132
+ self.norm2 = norm_layer(dim)
133
+ mlp_hidden_dim = int(dim * mlp_ratio)
134
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
135
+
136
+ def forward(self, x, xpos):
137
+ dtype = x.dtype
138
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
139
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
140
+ return x
141
+
142
+ class CrossAttention(nn.Module):
143
+
144
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
145
+ super().__init__()
146
+ self.num_heads = num_heads
147
+ head_dim = dim // num_heads
148
+ self.scale = head_dim ** -0.5
149
+
150
+ self.projq = nn.Linear(dim, dim, bias=qkv_bias)
151
+ self.projk = nn.Linear(dim, dim, bias=qkv_bias)
152
+ self.projv = nn.Linear(dim, dim, bias=qkv_bias)
153
+ self.attn_drop = nn.Dropout(attn_drop)
154
+ self.proj = nn.Linear(dim, dim)
155
+ self.proj_drop = nn.Dropout(proj_drop)
156
+
157
+ self.rope = rope
158
+
159
+ def forward(self, query, key, value, qpos, kpos):
160
+ B, Nq, C = query.shape
161
+ Nk = key.shape[1]
162
+ Nv = value.shape[1]
163
+
164
+ q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
165
+ k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
166
+ v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
167
+
168
+ if self.rope is not None:
169
+ q = self.rope(q, qpos)
170
+ k = self.rope(k, kpos)
171
+
172
+ # attn = (q @ k.transpose(-2, -1)) * self.scale
173
+ # attn = attn.softmax(dim=-1)
174
+ # attn = self.attn_drop(attn)
175
+
176
+ # x_old = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
177
+ # with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
178
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale = self.scale, dropout_p=0.).transpose(1, 2).reshape(B, Nq, C)
181
+ x = self.proj(x)
182
+ x = self.proj_drop(x)
183
+ return x
184
+
185
+
186
+ class DecoderBlock_onlyself(nn.Module):
187
+
188
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
189
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
190
+ super().__init__()
191
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
192
+ # self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
193
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
194
+ self.norm1 = norm_layer(dim)
195
+ self.norm3 = norm_layer(dim)
196
+ mlp_hidden_dim = int(dim * mlp_ratio)
197
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
198
+ # self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
199
+
200
+ def forward(self, x, xpos, split=False):
201
+ # if split==False:
202
+ # else:
203
+ # x = x.reshape(-1, y.shape[1], y.shape[2])
204
+ # x = x + self.drop_path(self.attn(self.norm1(x), xpos.reshape(-1, y.shape[1], 2)))
205
+ # x = x.reshape(y.shape[0], -1, y.shape[2])
206
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
207
+ # y_ = self.norm_y(y)
208
+ # x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
209
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
210
+ return x
211
+
212
+
213
+
214
+ class DecoderBlock_onlycross(nn.Module):
215
+
216
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
217
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
218
+ super().__init__()
219
+ self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
220
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
221
+ self.norm2 = norm_layer(dim)
222
+ self.norm3 = norm_layer(dim)
223
+ mlp_hidden_dim = int(dim * mlp_ratio)
224
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
225
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
226
+ def forward(self, x, y, xpos, ypos, split=False):
227
+ y_ = self.norm_y(y)
228
+ x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
229
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
230
+ return x, y
231
+
232
+ class DecoderBlock(nn.Module):
233
+
234
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
235
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
236
+ super().__init__()
237
+ self.norm1 = norm_layer(dim)
238
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
239
+ self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
240
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
241
+ self.norm2 = norm_layer(dim)
242
+ self.norm3 = norm_layer(dim)
243
+ mlp_hidden_dim = int(dim * mlp_ratio)
244
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
245
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
246
+
247
+ def forward(self, x, y, xpos, ypos, split=False):
248
+ # if split==False:
249
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
250
+ # else:
251
+ # x = x.reshape(-1, y.shape[1], y.shape[2])
252
+ # x = x + self.drop_path(self.attn(self.norm1(x), xpos.reshape(-1, y.shape[1], 2)))
253
+ # x = x.reshape(y.shape[0], -1, y.shape[2])
254
+ y_ = self.norm_y(y)
255
+ x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
256
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
257
+ return x, y
258
+
259
+
260
+ # patch embedding
261
+ class PositionGetter(object):
262
+ """ return positions of patches """
263
+
264
+ def __init__(self):
265
+ self.cache_positions = {}
266
+
267
+ def __call__(self, b, h, w, device):
268
+ if not (h,w) in self.cache_positions:
269
+ x = torch.arange(w, device=device)
270
+ y = torch.arange(h, device=device)
271
+ self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h*w, 2)
272
+ pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
273
+ return pos
274
+
275
+ class PatchEmbed(nn.Module):
276
+ """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
277
+
278
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
279
+ super().__init__()
280
+ img_size = to_2tuple(img_size)
281
+ patch_size = to_2tuple(patch_size)
282
+ self.img_size = img_size
283
+ self.patch_size = patch_size
284
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
285
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
286
+ self.flatten = flatten
287
+
288
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
289
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
290
+
291
+ self.position_getter = PositionGetter()
292
+
293
+ def forward(self, x):
294
+ B, C, H, W = x.shape
295
+ torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
296
+ torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
297
+ x = self.proj(x)
298
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
299
+ if self.flatten:
300
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
301
+ x = self.norm(x)
302
+ return x, pos
303
+
304
+ def _init_weights(self):
305
+ w = self.proj.weight.data
306
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
307
+
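A shape sketch for the building blocks above; note the Attention module relies on the scale= argument of scaled_dot_product_attention, so PyTorch >= 2.1 is assumed:

import torch

embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)
x, pos = embed(torch.randn(2, 3, 224, 224))  # x: (2, 196, 768), pos: (2, 196, 2)
block = Block(dim=768, num_heads=12)
y = block(x, pos)  # output keeps the input shape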
dust3r/croco/models/criterion.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Criterion to train CroCo
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # MAE: https://github.com/facebookresearch/mae
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+
13
+ class MaskedMSE(torch.nn.Module):
14
+
15
+ def __init__(self, norm_pix_loss=False, masked=True):
16
+ """
17
+ norm_pix_loss: normalize each patch by their pixel mean and variance
18
+ masked: compute loss over the masked patches only
19
+ """
20
+ super().__init__()
21
+ self.norm_pix_loss = norm_pix_loss
22
+ self.masked = masked
23
+
24
+ def forward(self, pred, mask, target):
25
+
26
+ if self.norm_pix_loss:
27
+ mean = target.mean(dim=-1, keepdim=True)
28
+ var = target.var(dim=-1, keepdim=True)
29
+ target = (target - mean) / (var + 1.e-6)**.5
30
+
31
+ loss = (pred - target) ** 2
32
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
33
+ if self.masked:
34
+ loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches
35
+ else:
36
+ loss = loss.mean() # mean loss
37
+ return loss
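A sketch of the expected shapes, for orientation (values are illustrative: 196 patches of 16*16*3 pixels):

import torch

criterion = MaskedMSE(norm_pix_loss=True)
pred = torch.randn(4, 196, 768)
target = torch.randn(4, 196, 768)
mask = torch.rand(4, 196) > 0.1  # True where a patch was masked
loss = criterion(pred, mask, target)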
dust3r/croco/models/croco.py ADDED
@@ -0,0 +1,288 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # CroCo model during pretraining
7
+ # --------------------------------------------------------
8
+
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
14
+ from functools import partial
15
+ from models.blocks import Block, DecoderBlock, PatchEmbed, DecoderBlock_onlyself, DecoderBlock_onlycross
16
+ from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
17
+ from models.masking import RandomMask
18
+ from mast3r.modules import AttnBlock
19
+
20
+ class CroCoNet(nn.Module):
21
+
22
+ def __init__(self,
23
+ img_size=224, # input image size
24
+ patch_size=16, # patch_size
25
+ mask_ratio=0.9, # ratios of masked tokens
26
+ enc_embed_dim=768, # encoder feature dimension
27
+ enc_depth=12, # encoder depth
28
+ enc_num_heads=12, # encoder number of heads in the transformer block
29
+ dec_embed_dim=512, # decoder feature dimension
30
+ dec_depth=8, # decoder depth
31
+ dec_num_heads=16, # decoder number of heads in the transformer block
32
+ mlp_ratio=4,
33
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
34
+ norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder
35
+ pos_embed='cosine', # positional embedding (either cosine or RoPE100)
36
+ ):
37
+
38
+ super(CroCoNet, self).__init__()
39
+
40
+ # patch embeddings (with initialization done as in MAE)
41
+ self._set_patch_embed(img_size, patch_size, enc_embed_dim)
42
+
43
+ # mask generations
44
+ self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)
45
+
46
+ self.pos_embed = pos_embed
47
+ if pos_embed=='cosine':
48
+ # positional embedding of the encoder
49
+ enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
50
+ self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
51
+ # positional embedding of the decoder
52
+ dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
53
+ self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
54
+ # pos embedding in each block
55
+ self.rope = None # nothing for cosine
56
+ elif pos_embed.startswith('RoPE'): # eg RoPE100
57
+ self.enc_pos_embed = None # nothing to add in the encoder with RoPE
58
+ self.dec_pos_embed = None # nothing to add in the decoder with RoPE
59
+ if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
60
+ freq = float(pos_embed[len('RoPE'):])
61
+ self.rope = RoPE2D(freq=freq)
62
+ else:
63
+ raise NotImplementedError('Unknown pos_embed '+pos_embed)
64
+
65
+ # transformer for the encoder
66
+ self.enc_depth = enc_depth
67
+ self.enc_embed_dim = enc_embed_dim
68
+ self.enc_blocks = nn.ModuleList([
69
+ Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
70
+ for i in range(enc_depth)])
71
+ self.enc_norm = norm_layer(enc_embed_dim)
72
+ self.enc_blocks_stage2 = nn.ModuleList([
73
+ Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
74
+ for i in range(enc_depth//6)])
75
+ self.enc_norm_stage2 = norm_layer(enc_embed_dim)
76
+ # masked tokens
77
+ self._set_mask_token(dec_embed_dim)
78
+
79
+ # decoder
80
+ self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)
81
+
82
+ # prediction head
83
+ self._set_prediction_head(dec_embed_dim, patch_size)
84
+
85
+ # initialize weights
86
+ self.initialize_weights()
87
+
88
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
89
+ self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)
90
+
91
+
92
+ def _set_mask_generator(self, num_patches, mask_ratio):
93
+ self.mask_generator = RandomMask(num_patches, mask_ratio)
94
+
95
+ def _set_mask_token(self, dec_embed_dim):
96
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))
97
+
98
+ def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
99
+ self.dec_depth = dec_depth
100
+ self.dec_embed_dim = dec_embed_dim
101
+ # transfer from encoder to decoder
102
+ self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
103
+ # transformer for the decoder
104
+ self.dec_blocks = nn.ModuleList([
105
+ DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
106
+ for i in range(dec_depth)])
107
+ self.dec_blocks_fine = nn.ModuleList([
108
+ DecoderBlock_onlyself(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
109
+ for i in range(dec_depth)])
110
+ self.dec_blocks_point_cross = nn.ModuleList([
111
+ DecoderBlock_onlycross(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
112
+ for i in range(dec_depth)])
113
+ # final norm layer
114
+ self.cam_cond_encoder = nn.ModuleList([AttnBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
115
+ for _ in range(dec_depth)])
116
+ self.pose_token_ref = nn.Parameter(torch.randn(1, 1, dec_embed_dim))
117
+ self.pose_token_source = nn.Parameter(torch.randn(1, 1, dec_embed_dim))
118
+
119
+ self.cam_cond_embed = nn.ModuleList([nn.Linear(dec_embed_dim, dec_embed_dim, bias=False) for i in range(dec_depth)])
120
+ self.dec_norm = norm_layer(dec_embed_dim)
121
+ self.dec_cam_norm = norm_layer(dec_embed_dim)
122
+
123
+ def _set_prediction_head(self, dec_embed_dim, patch_size):
124
+ self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)
125
+
126
+
127
+ def initialize_weights(self):
128
+ # patch embed
129
+ self.patch_embed._init_weights()
130
+ # mask tokens
131
+ if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02)
132
+ # linears and layer norms
133
+ self.apply(self._init_weights)
134
+
135
+ def _init_weights(self, m):
136
+ if isinstance(m, nn.Linear):
137
+ # we use xavier_uniform following official JAX ViT:
138
+ torch.nn.init.xavier_uniform_(m.weight)
139
+ if isinstance(m, nn.Linear) and m.bias is not None:
140
+ nn.init.constant_(m.bias, 0)
141
+ elif isinstance(m, nn.LayerNorm):
142
+ if m.elementwise_affine == True:
143
+ nn.init.constant_(m.bias, 0)
144
+ nn.init.constant_(m.weight, 1.0)
145
+
146
+ def _encode_image_fine(self, image, shapes, dtype=torch.float32):
147
+ """
148
+ image has B x 3 x img_size x img_size
149
+ shapes: per-view patch-grid shapes passed to the fine patch embedding
152
+ """
153
+ # embed the image into patches (x has size B x Npatches x C)
154
+ # and get the position of each returned patch (pos has size B x Npatches x 2)
155
+ x, pos = self.patch_embed_fine(image, shapes)
156
+ x = x.to(dtype)
157
+ # add positional embedding without cls token
158
+ B,N,C = x.size()
159
+ posvis = pos
160
+ # now apply the transformer encoder and normalization
161
+ for blk in self.enc_fine_blocks:
162
+ x = blk(x, posvis)
163
+ x = self.enc_fine_norm(x)
164
+ x, pos = self.patch_embed_fine2(x)
165
+ x = self.enc_fine_norm2(x)
166
+ return x, pos, None
167
+
168
+ def _encode_image(self, image, do_mask=False, return_all_blocks=False):
169
+ """
170
+ image has B x 3 x img_size x img_size
171
+ do_mask: whether to perform masking or not
172
+ return_all_blocks: if True, return the features at the end of every block
173
+ instead of just the features from the last block (eg for some prediction heads)
174
+ """
175
+ # embed the image into patches (x has size B x Npatches x C)
176
+ # and get the position of each returned patch (pos has size B x Npatches x 2)
177
+ x, pos = self.patch_embed(image)
178
+ # add positional embedding without cls token
179
+ if self.enc_pos_embed is not None:
180
+ x = x + self.enc_pos_embed[None,...]
181
+ # apply masking
182
+ B,N,C = x.size()
183
+ if do_mask:
184
+ masks = self.mask_generator(x)
185
+ x = x[~masks].view(B, -1, C)
186
+ posvis = pos[~masks].view(B, -1, 2)
187
+ else:
188
+ B,N,C = x.size()
189
+ masks = torch.zeros((B,N), dtype=bool)
190
+ posvis = pos
191
+ # now apply the transformer encoder and normalization
192
+ if return_all_blocks:
193
+ out = []
194
+ for blk in self.enc_blocks:
195
+ x = blk(x, posvis)
196
+ out.append(x)
197
+ out[-1] = self.enc_norm(out[-1])
198
+ return out, pos, masks
199
+ else:
200
+ for blk in self.enc_blocks:
201
+ x = blk(x, posvis)
202
+ x = self.enc_norm(x)
203
+ return x, pos, masks
204
+
205
+ def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
206
+ """
207
+ return_all_blocks: if True, return the features at the end of every block
208
+ instead of just the features from the last block (eg for some prediction heads)
209
+
210
+ masks1 can be None => assume image1 fully visible
211
+ """
212
+ # encoder to decoder layer
213
+ visf1 = self.decoder_embed(feat1)
214
+ f2 = self.decoder_embed(feat2)
215
+ # append masked tokens to the sequence
216
+ B,Nenc,C = visf1.size()
217
+ if masks1 is None: # downstreams
218
+ f1_ = visf1
219
+ else: # pretraining
220
+ Ntotal = masks1.size(1)
221
+ f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
222
+ f1_[~masks1] = visf1.view(B * Nenc, C)
223
+ # add positional embedding
224
+ if self.dec_pos_embed is not None:
225
+ f1_ = f1_ + self.dec_pos_embed
226
+ f2 = f2 + self.dec_pos_embed
227
+ # apply Transformer blocks
228
+ out = f1_
229
+ out2 = f2
230
+ if return_all_blocks:
231
+ _out, out = out, []
232
+ for blk in self.dec_blocks:
233
+ _out, out2 = blk(_out, out2, pos1, pos2)
234
+ out.append(_out)
235
+ out[-1] = self.dec_norm(out[-1])
236
+ else:
237
+ for blk in self.dec_blocks:
238
+ out, out2 = blk(out, out2, pos1, pos2)
239
+ out = self.dec_norm(out)
240
+ return out
241
+
242
+ def patchify(self, imgs):
243
+ """
244
+ imgs: (B, 3, H, W)
245
+ x: (B, L, patch_size**2 *3)
246
+ """
247
+ p = self.patch_embed.patch_size[0]
248
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
249
+
250
+ h = w = imgs.shape[2] // p
251
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
252
+ x = torch.einsum('nchpwq->nhwpqc', x)
253
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
254
+
255
+ return x
256
+
257
+ def unpatchify(self, x, channels=3):
258
+ """
259
+ x: (N, L, patch_size**2 *channels)
260
+ imgs: (N, 3, H, W)
261
+ """
262
+ patch_size = self.patch_embed.patch_size[0]
263
+ h = w = int(x.shape[1]**.5)
264
+ assert h * w == x.shape[1]
265
+ x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
266
+ x = torch.einsum('nhwpqc->nchpwq', x)
267
+ imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, w * patch_size))
268
+ return imgs
269
+
270
+ def forward(self, img1, img2):
271
+ """
272
+ img1: tensor of size B x 3 x img_size x img_size
273
+ img2: tensor of size B x 3 x img_size x img_size
274
+
275
+ out will be B x N x (3*patch_size*patch_size)
276
+ masks are also returned as B x N just in case
277
+ """
278
+ # encoder of the masked first image
279
+ feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
280
+ # encoder of the second image
281
+ feat2, pos2, _ = self._encode_image(img2, do_mask=False)
282
+ # decoder
283
+ decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
284
+ # prediction head
285
+ out = self.prediction_head(decfeat)
286
+ # get target
287
+ target = self.patchify(img1)
288
+ return out, mask1, target
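For reference, a minimal standalone sketch of the patchify/unpatchify round trip implemented by the two methods above (the patch size and image size below are arbitrary illustrative choices, not values mandated by the model):

```python
# Standalone re-implementation of the patchify/unpatchify pair for illustration.
import torch

def patchify(imgs, p):
    # (B, C, H, W) -> (B, L, p*p*C), assuming H == W and H % p == 0
    B, C, H, W = imgs.shape
    assert H == W and H % p == 0
    h = w = H // p
    x = imgs.reshape(B, C, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(B, h * w, p * p * C)

def unpatchify(x, p, channels=3):
    # (B, L, p*p*channels) -> (B, channels, H, W), assuming L is a square number
    B, L, _ = x.shape
    h = w = int(L ** 0.5)
    assert h * w == L
    x = x.reshape(B, h, w, p, p, channels)
    x = torch.einsum('nhwpqc->nchpwq', x)
    return x.reshape(B, channels, h * p, w * p)

imgs = torch.randn(2, 3, 224, 224)
assert torch.equal(unpatchify(patchify(imgs, 16), 16), imgs)  # lossless round trip
```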
dust3r/croco/models/dpt_block.py ADDED
@@ -0,0 +1,450 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ # --------------------------------------------------------
+ # DPT head for ViTs
+ # --------------------------------------------------------
+ # References:
+ # https://github.com/isl-org/DPT
+ # https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+ from typing import Union, Tuple, Iterable, List, Optional, Dict
+
+ def pair(t):
+     return t if isinstance(t, tuple) else (t, t)
+
+ def make_scratch(in_shape, out_shape, groups=1, expand=False):
+     scratch = nn.Module()
+
+     out_shape1 = out_shape
+     out_shape2 = out_shape
+     out_shape3 = out_shape
+     out_shape4 = out_shape
+     if expand:
+         out_shape1 = out_shape
+         out_shape2 = out_shape * 2
+         out_shape3 = out_shape * 4
+         out_shape4 = out_shape * 8
+
+     scratch.layer1_rn = nn.Conv2d(
+         in_shape[0],
+         out_shape1,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+         bias=False,
+         groups=groups,
+     )
+     scratch.layer2_rn = nn.Conv2d(
+         in_shape[1],
+         out_shape2,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+         bias=False,
+         groups=groups,
+     )
+     scratch.layer3_rn = nn.Conv2d(
+         in_shape[2],
+         out_shape3,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+         bias=False,
+         groups=groups,
+     )
+     scratch.layer4_rn = nn.Conv2d(
+         in_shape[3],
+         out_shape4,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+         bias=False,
+         groups=groups,
+     )
+
+     scratch.layer_rn = nn.ModuleList([
+         scratch.layer1_rn,
+         scratch.layer2_rn,
+         scratch.layer3_rn,
+         scratch.layer4_rn,
+     ])
+
+     return scratch
+
+ class ResidualConvUnit_custom(nn.Module):
+     """Residual convolution module."""
+
+     def __init__(self, features, activation, bn):
+         """Init.
+         Args:
+             features (int): number of features
+         """
+         super().__init__()
+
+         self.bn = bn
+
+         self.groups = 1
+
+         self.conv1 = nn.Conv2d(
+             features,
+             features,
+             kernel_size=3,
+             stride=1,
+             padding=1,
+             bias=not self.bn,
+             groups=self.groups,
+         )
+
+         self.conv2 = nn.Conv2d(
+             features,
+             features,
+             kernel_size=3,
+             stride=1,
+             padding=1,
+             bias=not self.bn,
+             groups=self.groups,
+         )
+
+         if self.bn:
+             self.bn1 = nn.BatchNorm2d(features)
+             self.bn2 = nn.BatchNorm2d(features)
+
+         self.activation = activation
+
+         self.skip_add = nn.quantized.FloatFunctional()
+
+     def forward(self, x):
+         """Forward pass.
+         Args:
+             x (tensor): input
+         Returns:
+             tensor: output
+         """
+
+         out = self.activation(x)
+         out = self.conv1(out)
+         if self.bn:
+             out = self.bn1(out)
+
+         out = self.activation(out)
+         out = self.conv2(out)
+         if self.bn:
+             out = self.bn2(out)
+
+         # self.groups is fixed to 1 above, so this branch is never taken here
+         if self.groups > 1:
+             out = self.conv_merge(out)
+
+         return self.skip_add.add(out, x)
+
+ class FeatureFusionBlock_custom(nn.Module):
+     """Feature fusion block."""
+
+     def __init__(
+         self,
+         features,
+         activation,
+         deconv=False,
+         bn=False,
+         expand=False,
+         align_corners=True,
+         width_ratio=1,
+     ):
+         """Init.
+         Args:
+             features (int): number of features
+         """
+         super(FeatureFusionBlock_custom, self).__init__()
+         self.width_ratio = width_ratio
+
+         self.deconv = deconv
+         self.align_corners = align_corners
+
+         self.groups = 1
+
+         self.expand = expand
+         out_features = features
+         if self.expand:
+             out_features = features // 2
+
+         self.out_conv = nn.Conv2d(
+             features,
+             out_features,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             bias=True,
+             groups=1,
+         )
+
+         self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+         self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+         self.skip_add = nn.quantized.FloatFunctional()
+
+     def forward(self, *xs):
+         """Forward pass.
+         Returns:
+             tensor: output
+         """
+         output = xs[0]
+
+         if len(xs) == 2:
+             res = self.resConfUnit1(xs[1])
+             if self.width_ratio != 1:
+                 res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')
+
+             output = self.skip_add.add(output, res)
+
+         output = self.resConfUnit2(output)
+
+         if self.width_ratio != 1:
+             if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
+                 shape = 3 * output.shape[3]
+             else:
+                 shape = int(self.width_ratio * 2 * output.shape[2])
+             output = F.interpolate(output, size=(2 * output.shape[2], shape), mode='bilinear')
+         else:
+             output = F.interpolate(output, scale_factor=2,
+                                    mode="bilinear", align_corners=self.align_corners)
+         output = self.out_conv(output)
+         return output
+
+ def make_fusion_block(features, use_bn, width_ratio=1):
+     return FeatureFusionBlock_custom(
+         features,
+         nn.ReLU(False),
+         deconv=False,
+         bn=use_bn,
+         expand=False,
+         align_corners=True,
+         width_ratio=width_ratio,
+     )
+
+ class Interpolate(nn.Module):
+     """Interpolation module."""
+
+     def __init__(self, scale_factor, mode, align_corners=False):
+         """Init.
+         Args:
+             scale_factor (float): scaling
+             mode (str): interpolation mode
+         """
+         super(Interpolate, self).__init__()
+
+         self.interp = nn.functional.interpolate
+         self.scale_factor = scale_factor
+         self.mode = mode
+         self.align_corners = align_corners
+
+     def forward(self, x):
+         """Forward pass.
+         Args:
+             x (tensor): input
+         Returns:
+             tensor: interpolated data
+         """
+         dtype = x.dtype
+         x = self.interp(
+             x.float(),
+             scale_factor=self.scale_factor,
+             mode=self.mode,
+             align_corners=self.align_corners,
+         )
+         x = x.to(dtype)
+         return x
+
+ class DPTOutputAdapter(nn.Module):
+     """DPT output adapter.
+
+     :param num_channels: Number of output channels
+     :param stride_level: Stride level compared to the full-sized image.
+         E.g. 4 for 1/4th the size of the image.
+     :param patch_size_full: Int or tuple of the patch size over the full image size.
+         Patch size for smaller inputs will be computed accordingly.
+     :param hooks: Index of intermediate layers
+     :param layer_dims: Dimension of intermediate layers
+     :param feature_dim: Feature dimension
+     :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
+     :param use_bn: If set to True, activates batch norm
+     :param dim_tokens_enc: Dimension of tokens coming from encoder
+     """
+
+     def __init__(self,
+                  num_channels: int = 1,
+                  stride_level: int = 1,
+                  patch_size: Union[int, Tuple[int, int]] = 16,
+                  main_tasks: Iterable[str] = ('rgb',),
+                  hooks: List[int] = [2, 5, 8, 11],
+                  layer_dims: List[int] = [96, 192, 384, 768],
+                  feature_dim: int = 256,
+                  last_dim: int = 32,
+                  use_bn: bool = False,
+                  dim_tokens_enc: Optional[int] = None,
+                  head_type: str = 'regression',
+                  output_width_ratio=1,
+                  **kwargs):
+         super().__init__()
+         self.num_channels = num_channels
+         self.stride_level = stride_level
+         self.patch_size = pair(patch_size)
+         self.main_tasks = main_tasks
+         self.hooks = hooks
+         self.layer_dims = layer_dims
+         self.feature_dim = feature_dim
+         self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
+         self.head_type = head_type
+
+         # Actual patch height and width, taking into account stride of input
+         self.P_H = max(1, self.patch_size[0] // stride_level)
+         self.P_W = max(1, self.patch_size[1] // stride_level)
+
+         self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
+
+         self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+         self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+         self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+         self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
+
+         if self.head_type == 'regression':
+             # The "DPTDepthModel" head
+             self.head = nn.Sequential(
+                 nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
+                 Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+                 nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
+                 nn.ReLU(True),
+                 nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
+             )
+         elif self.head_type == 'semseg':
+             # The "DPTSegmentationModel" head
+             self.head = nn.Sequential(
+                 nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
+                 nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
+                 nn.ReLU(True),
+                 nn.Dropout(0.1, False),
+                 nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
+                 Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+             )
+         else:
+             raise ValueError('DPT head_type must be "regression" or "semseg".')
+
+         if self.dim_tokens_enc is not None:
+             self.init(dim_tokens_enc=dim_tokens_enc)
+
+     def init(self, dim_tokens_enc=768):
+         """
+         Initialize parts of decoder that are dependent on dimension of encoder tokens.
+         Should be called when setting up MultiMAE.
+
+         :param dim_tokens_enc: Dimension of tokens coming from encoder
+         """
+         # Set up activation postprocessing layers
+         if isinstance(dim_tokens_enc, int):
+             dim_tokens_enc = 4 * [dim_tokens_enc]
+
+         self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
+
+         self.act_1_postprocess = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=self.dim_tokens_enc[0],
+                 out_channels=self.layer_dims[0],
+                 kernel_size=1, stride=1, padding=0,
+             ),
+             nn.ConvTranspose2d(
+                 in_channels=self.layer_dims[0],
+                 out_channels=self.layer_dims[0],
+                 kernel_size=4, stride=4, padding=0,
+                 bias=True, dilation=1, groups=1,
+             )
+         )
+
+         self.act_2_postprocess = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=self.dim_tokens_enc[1],
+                 out_channels=self.layer_dims[1],
+                 kernel_size=1, stride=1, padding=0,
+             ),
+             nn.ConvTranspose2d(
+                 in_channels=self.layer_dims[1],
+                 out_channels=self.layer_dims[1],
+                 kernel_size=2, stride=2, padding=0,
+                 bias=True, dilation=1, groups=1,
+             )
+         )
+
+         self.act_3_postprocess = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=self.dim_tokens_enc[2],
+                 out_channels=self.layer_dims[2],
+                 kernel_size=1, stride=1, padding=0,
+             )
+         )
+
+         self.act_4_postprocess = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=self.dim_tokens_enc[3],
+                 out_channels=self.layer_dims[3],
+                 kernel_size=1, stride=1, padding=0,
+             ),
+             nn.Conv2d(
+                 in_channels=self.layer_dims[3],
+                 out_channels=self.layer_dims[3],
+                 kernel_size=3, stride=2, padding=1,
+             )
+         )
+
+         self.act_postprocess = nn.ModuleList([
+             self.act_1_postprocess,
+             self.act_2_postprocess,
+             self.act_3_postprocess,
+             self.act_4_postprocess
+         ])
+
+     def adapt_tokens(self, encoder_tokens):
+         # Adapt tokens
+         x = []
+         x.append(encoder_tokens[:, :])
+         x = torch.cat(x, dim=-1)
+         return x
+
+     def forward(self, encoder_tokens: List[torch.Tensor], image_size):
+         assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
+         H, W = image_size
+
+         # Number of patches in height and width
+         N_H = H // (self.stride_level * self.P_H)
+         N_W = W // (self.stride_level * self.P_W)
+
+         # Hook decoder onto 4 layers from specified ViT layers
+         layers = [encoder_tokens[hook] for hook in self.hooks]
+
+         # Extract only task-relevant tokens and ignore global tokens.
+         layers = [self.adapt_tokens(l) for l in layers]
+
+         # Reshape tokens to spatial representation
+         layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
+
+         layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
+         # Project layers to chosen feature dim
+         layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
+
+         # Fuse layers using refinement stages
+         path_4 = self.scratch.refinenet4(layers[3])
+         path_3 = self.scratch.refinenet3(path_4, layers[2])
+         path_2 = self.scratch.refinenet2(path_3, layers[1])
+         path_1 = self.scratch.refinenet1(path_2, layers[0])
+
+         # Output head
+         out = self.head(path_1)
+
+         return out
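A hedged usage sketch for `DPTOutputAdapter`: it consumes a list of per-layer token tensors of shape `(B, N_H*N_W, dim_tokens_enc)` indexed by `hooks` and produces a dense map. The import path assumes this repo's layout is importable as a package; batch size, resolution and embedding width are arbitrary choices here.

```python
import torch
from dust3r.croco.models.dpt_block import DPTOutputAdapter  # path assumed from this repo layout

B, H, W, D = 2, 224, 224, 768
adapter = DPTOutputAdapter(num_channels=1, hooks=[2, 5, 8, 11], dim_tokens_enc=D)
# 12 fake ViT layers, each a (B, 14*14, D) token map for a 224x224 image with 16x16 patches
tokens = [torch.randn(B, (H // 16) * (W // 16), D) for _ in range(12)]
out = adapter(tokens, image_size=(H, W))
print(out.shape)  # torch.Size([2, 1, 224, 224]) for the regression head
```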
dust3r/croco/models/head_downstream.py ADDED
@@ -0,0 +1,58 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ # --------------------------------------------------------
+ # Heads for downstream tasks
+ # --------------------------------------------------------
+
+ """
+ A head is a module whose __init__ defines only the head hyperparameters.
+ A setup(croconet) method takes a CroCoNet and sets all layers according to the head and croconet attributes.
+ The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height'
+ """
+
+ import torch
+ import torch.nn as nn
+ from .dpt_block import DPTOutputAdapter
+
+
+ class PixelwiseTaskWithDPT(nn.Module):
+     """ DPT module for CroCo.
+     by default, hooks_idx will be equal to:
+     * for encoder-only: 4 equally spread layers
+     * for encoder+decoder: last encoder + 3 equally spread layers of the decoder
+     """
+
+     def __init__(self, *, hooks_idx=None, layer_dims=[96, 192, 384, 768],
+                  output_width_ratio=1, num_channels=1, postprocess=None, **kwargs):
+         super(PixelwiseTaskWithDPT, self).__init__()
+         self.return_all_blocks = True  # backbone needs to return all layers
+         self.postprocess = postprocess
+         self.output_width_ratio = output_width_ratio
+         self.num_channels = num_channels
+         self.hooks_idx = hooks_idx
+         self.layer_dims = layer_dims
+
+     def setup(self, croconet):
+         dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels}
+         if self.hooks_idx is None:
+             if hasattr(croconet, 'dec_blocks'):  # encoder + decoder
+                 step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth]
+                 hooks_idx = [croconet.dec_depth + croconet.enc_depth - 1 - i * step for i in range(3, -1, -1)]
+             else:  # encoder only
+                 step = croconet.enc_depth // 4
+                 hooks_idx = [croconet.enc_depth - 1 - i * step for i in range(3, -1, -1)]
+             self.hooks_idx = hooks_idx
+             print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}')
+         dpt_args['hooks'] = self.hooks_idx
+         dpt_args['layer_dims'] = self.layer_dims
+         self.dpt = DPTOutputAdapter(**dpt_args)
+         dim_tokens = [croconet.enc_embed_dim if hook < croconet.enc_depth else croconet.dec_embed_dim
+                       for hook in self.hooks_idx]
+         dpt_init_args = {'dim_tokens_enc': dim_tokens}
+         self.dpt.init(**dpt_init_args)
+
+     def forward(self, x, img_info):
+         out = self.dpt(x, image_size=(img_info['height'], img_info['width']))
+         if self.postprocess:
+             out = self.postprocess(out)
+         return out
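A minimal sketch of how `setup` wires this head to a backbone. `backbone` below is a stand-in namespace exposing only the attributes `setup()` actually reads in the encoder-only case; a real CroCoNet would provide the same fields, and the import path is assumed:

```python
import torch
from types import SimpleNamespace
from dust3r.croco.models.head_downstream import PixelwiseTaskWithDPT  # path assumed

backbone = SimpleNamespace(enc_depth=12, enc_embed_dim=768)  # encoder-only stand-in
head = PixelwiseTaskWithDPT(num_channels=1)
head.setup(backbone)  # hooks default to 4 equally spread encoder layers: [2, 5, 8, 11]
feats = [torch.randn(2, 196, 768) for _ in range(12)]  # all-blocks features, 14x14 grid
out = head(feats, {'height': 224, 'width': 224})        # dense (2, 1, 224, 224) map
```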
dust3r/croco/models/masking.py ADDED
@@ -0,0 +1,25 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+
+ # --------------------------------------------------------
+ # Masking utils
+ # --------------------------------------------------------
+
+ import torch
+ import torch.nn as nn
+
+ class RandomMask(nn.Module):
+     """
+     random masking
+     """
+
+     def __init__(self, num_patches, mask_ratio):
+         super().__init__()
+         self.num_patches = num_patches
+         self.num_mask = int(mask_ratio * self.num_patches)
+
+     def __call__(self, x):
+         noise = torch.rand(x.size(0), self.num_patches, device=x.device)
+         argsort = torch.argsort(noise, dim=1)
+         return argsort < self.num_mask
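A quick sketch of `RandomMask` behavior (the patch count, ratio and batch size are arbitrary choices; the import path is assumed): each sample independently masks exactly `int(mask_ratio * num_patches)` patches.

```python
import torch
from dust3r.croco.models.masking import RandomMask  # path assumed

mask_gen = RandomMask(num_patches=196, mask_ratio=0.9)
x = torch.randn(4, 196, 768)   # (B, Npatches, C) patch tokens
masks = mask_gen(x)            # (B, Npatches) boolean, True = masked
assert masks.shape == (4, 196)
assert (masks.sum(dim=1) == int(0.9 * 196)).all()  # exactly 176 masked per sample
```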
dust3r/croco/models/pos_embed.py ADDED
@@ -0,0 +1,159 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+
+ # --------------------------------------------------------
+ # Position embedding utils
+ # --------------------------------------------------------
+
+
+ import numpy as np
+
+ import torch
+
+ # --------------------------------------------------------
+ # 2D sine-cosine position embedding
+ # References:
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
+ # --------------------------------------------------------
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if n_cls_token > 0:
+         pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=float)
+     omega /= embed_dim / 2.
+     omega = 1. / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ # --------------------------------------------------------
+ # Interpolate position embeddings for high-resolution
+ # References:
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+ # DeiT: https://github.com/facebookresearch/deit
+ # --------------------------------------------------------
+ def interpolate_pos_embed(model, checkpoint_model):
+     if 'pos_embed' in checkpoint_model:
+         pos_embed_checkpoint = checkpoint_model['pos_embed']
+         embedding_size = pos_embed_checkpoint.shape[-1]
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+         # height (== width) for the checkpoint position embedding
+         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+         # height (== width) for the new position embedding
+         new_size = int(num_patches ** 0.5)
+         # class_token and dist_token are kept unchanged
+         if orig_size != new_size:
+             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(
+                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model['pos_embed'] = new_pos_embed
+
+
+ # ----------------------------------------------------------
+ # RoPE2D: RoPE implementation in 2D
+ # ----------------------------------------------------------
+
+ try:
+     from models.curope import cuRoPE2D
+     RoPE2D = cuRoPE2D
+ except ImportError:
+     print('Warning: cannot find the CUDA-compiled version of RoPE2D, using a slow PyTorch version instead')
+
+     class RoPE2D(torch.nn.Module):
+
+         def __init__(self, freq=100.0, F0=1.0):
+             super().__init__()
+             self.base = freq
+             self.F0 = F0
+             self.cache = {}
+
+         def get_cos_sin(self, D, seq_len, device, dtype):
+             if (D, seq_len, device, dtype) not in self.cache:
+                 inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+                 t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+                 freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+                 freqs = torch.cat((freqs, freqs), dim=-1)
+                 cos = freqs.cos()  # (Seq, Dim)
+                 sin = freqs.sin()
+                 self.cache[D, seq_len, device, dtype] = (cos, sin)
+             return self.cache[D, seq_len, device, dtype]
+
+         @staticmethod
+         def rotate_half(x):
+             x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+             return torch.cat((-x2, x1), dim=-1)
+
+         def apply_rope1d(self, tokens, pos1d, cos, sin):
+             assert pos1d.ndim == 2
+             cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+             sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+             return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+         def forward(self, tokens, positions):
+             """
+             input:
+                 * tokens: batch_size x nheads x ntokens x dim
+                 * positions: batch_size x ntokens x 2 (y and x position of each token)
+             output:
+                 * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
+             """
+             assert tokens.size(3) % 2 == 0, "number of dimensions should be a multiple of two"
+             D = tokens.size(3) // 2
+             assert positions.ndim == 3 and positions.shape[-1] == 2  # Batch, Seq, 2
+             cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
+             # split features into two along the feature dimension, and apply rope1d on each half
+             y, x = tokens.chunk(2, dim=-1)
+             y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
+             x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
+             tokens = torch.cat((y, x), dim=-1)
+             return tokens
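An illustrative check of the 2D sin-cos embedding and the PyTorch RoPE2D fallback above (grid size, head count and dims are arbitrary choices; the import path is assumed, and if the cuRoPE2D extension is installed the same interface applies):

```python
import torch
from dust3r.croco.models.pos_embed import get_2d_sincos_pos_embed, RoPE2D  # path assumed

pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)
print(pe.shape)  # (196, 768): one row per patch; half the dims encode y, half encode x

rope = RoPE2D(freq=100.0)
tokens = torch.randn(2, 12, 196, 64)  # batch x nheads x ntokens x dim (dim must be even)
ys, xs = torch.meshgrid(torch.arange(14), torch.arange(14), indexing='ij')
positions = torch.stack([ys.flatten(), xs.flatten()], dim=-1)[None].repeat(2, 1, 1)
rotated = rope(tokens, positions)  # same shape; each token is rotated by its (y, x) position
assert rotated.shape == tokens.shape
```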
dust3r/croco/models/transformer_utils.py ADDED
@@ -0,0 +1,1021 @@
1
+ import os
2
+ from collections import OrderedDict
3
+ from typing import Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from torch import nn, einsum
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from inspect import isfunction
16
+ try:
17
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func, flash_attn_varlen_qkvpacked_func
18
+ from flash_attn.bert_padding import unpad_input, pad_input
19
+ except:
20
+ flash_attn_qkvpacked_func, flash_attn_func, flash_attn_varlen_qkvpacked_func = None, None, None
21
+ unpad_input, pad_input = None, None
22
+ from .x_transformer import AttentionLayers, BasicEncoder
23
+ import math
24
+
25
+
26
+
27
+ def _init_weights(module):
28
+ if isinstance(module, nn.Linear):
29
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
30
+ if module.bias is not None:
31
+ torch.nn.init.zeros_(module.bias)
32
+
33
+ def exists(val):
34
+ return val is not None
35
+
36
+ def default(val, d):
37
+ if exists(val):
38
+ return val
39
+ return d() if isfunction(d) else d
40
+
41
+ def zero_module(module):
42
+ """
43
+ Zero out the parameters of a module and return it.
44
+ """
45
+ for p in module.parameters():
46
+ p.detach().zero_()
47
+ return module
48
+
49
+ # Copy from CLIP GitHub
50
+ class LayerNorm(nn.LayerNorm):
51
+ """Subclass torch's LayerNorm to handle fp16."""
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ orig_type = x.dtype
55
+ ret = super().forward(x.type(torch.float32))
56
+ return ret.type(orig_type)
57
+
58
+ class QuickGELU(nn.Module):
59
+ def forward(self, x: torch.Tensor):
60
+ return x * torch.sigmoid(1.702 * x)
61
+
62
+ def modulate(x, shift, scale):
63
+ # from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
64
+ return x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
65
+
66
+
67
+ class MultiheadAttentionFlashV2(nn.Module):
68
+ def __init__(self, embed_dim, n_head, bias=False, shift_group=None, qkv_packed=False, window_size=None):
69
+ super().__init__()
70
+
71
+ self.head_dim = embed_dim// n_head
72
+ self.embed_dim = embed_dim
73
+ self.n_head = n_head
74
+ self.to_q = nn.Linear(embed_dim, embed_dim, bias=bias)
75
+ self.to_k = nn.Linear(embed_dim, embed_dim, bias=bias)
76
+ self.to_v = nn.Linear(embed_dim, embed_dim, bias=bias)
77
+ self.shift_group = shift_group
78
+ self.qkv_packed = qkv_packed
79
+ self.window_size = window_size
80
+
81
+
82
+ def forward(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, need_weights=False, attn_mask=None):
83
+ q = q.permute(1, 0, 2)
84
+ k = k.permute(1, 0, 2)
85
+ v = v.permute(1, 0, 2)
86
+
87
+ h = self.n_head
88
+ q = self.to_q(q)
89
+ k = self.to_k(k)
90
+ v = self.to_v(v)
91
+
92
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v))
93
+ # print(q.dtype, k.dtype, v.dtype)
94
+ if self.qkv_packed:
95
+ bsz, q_len, heads, head_dim = q.shape
96
+ group_size = self.shift_group
97
+ nheads = self.n_head
98
+ qkv = torch.stack([q,k,v], dim=2)
99
+ qkv = qkv.reshape(bsz, q_len, 3, 2, nheads // 2, self.head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2,
100
+ q_len, 3,
101
+ nheads // 2,
102
+ self.head_dim)
103
+
104
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
105
+ key_padding_mask = torch.ones(x.shape[0], x.shape[1], device=x.device, dtype=x.dtype)
106
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
107
+ cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)
108
+ cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2
109
+ cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).min
110
+ cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
111
+ cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
112
+ cu_q_lens = cu_q_lens[cu_q_lens >= 0]
113
+ x_unpad = rearrange(
114
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2
115
+ )
116
+ output_unpad = flash_attn_varlen_qkvpacked_func(
117
+ x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=False,
118
+ )
119
+ output = rearrange(
120
+ pad_input(
121
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len
122
+ ),
123
+ "b s (h d) -> b s h d",
124
+ h=nheads // 2,
125
+ )
126
+ r_out = output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim).transpose(1, 2).reshape(bsz, q_len, nheads,
127
+ self.head_dim)
128
+ else:
129
+ if self.shift_group is not None:
130
+ bsz, q_len, heads, head_dim = q.shape
131
+ assert q_len % self.shift_group == 0
132
+
133
+ def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
134
+ qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2)
135
+ qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)
136
+ return qkv
137
+
138
+ q = shift(q, bsz, q_len, self.shift_group, h, self.head_dim)
139
+ k = shift(k, bsz, q_len, self.shift_group, h, self.head_dim)
140
+ v = shift(v, bsz, q_len, self.shift_group, h, self.head_dim)
141
+ if self.window_size:
142
+ out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal, window_size=(self.window_size // 2, self.window_size // 2))
143
+ else:
144
+ out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal)
145
+
146
+ if self.shift_group is not None:
147
+ out = out.transpose(1, 2).contiguous()
148
+ out = rearrange(out, '(b l) g h d -> b (l g) h d', l=q_len // self.shift_group)
149
+ r_out = out.clone()
150
+ r_out[:, :, h//2:] = r_out[:, :, h//2:].roll(h//2, dims=1)
151
+ else:
152
+ r_out = out
153
+
154
+ r_out = rearrange(r_out, 'b n h d -> b n (h d)')
155
+ r_out = r_out.permute(1, 0, 2)
156
+ return (r_out,)
157
+
158
+ class PSUpsamplerBlock(nn.Module):
159
+ def __init__(self, d_model: int, d_model_out: int, scale_factor: int):
160
+ super().__init__()
161
+
162
+ # self.mlp = nn.Sequential(OrderedDict([
163
+ # ("c_fc", nn.Linear(d_model, d_model_out * scale_factor**2)),
164
+ # ("gelu", QuickGELU()),
165
+ # ("c_proj", nn.Linear(d_model_out * scale_factor**2, d_model_out * scale_factor**2))
166
+ # ]))
167
+ # self.ln_2 = LayerNorm(d_model)
168
+ self.scale_factor = scale_factor
169
+ self.d_model_out = d_model_out
170
+ self.residual_fc = nn.Linear(d_model, d_model_out * (scale_factor**2))
171
+ self.pixelshuffle = nn.PixelShuffle(scale_factor)
172
+
173
+ def forward(self, x: torch.Tensor):
174
+ # mlp block
175
+ # x.shape b, l, d
176
+ # y = self.ln_2(x)
177
+ # y = self.mlp(y)
178
+ # For here we have two cases:
179
+ # 1. If we have a modulation function for the MLP, we use it to modulate the output of the MLP
180
+ # 2. If we don't have a modulation function for the MLP, we use the modulation function for the attention
181
+ x = self.residual_fc(x)# .repeat(1, 1, self.scale_factor**2)
182
+ # x = x + y
183
+ bs, l, c = x.shape
184
+ resolution = int(np.sqrt(l))
185
+ x = x.permute(0, 2, 1).reshape(bs, c, resolution, resolution)
186
+ x = self.pixelshuffle(x)
187
+ x = x.reshape(bs, self.d_model_out, resolution*self.scale_factor*resolution*self.scale_factor)
188
+ x = x.permute(0, 2, 1)
189
+ # x = rearrange(x, 'b l (s c) -> b (l s) c', s=self.scale_factor**2)
190
+ return x
191
+
192
+ class ResidualAttentionBlock(nn.Module):
193
+ def __init__(self, d_model: int,
194
+ n_head: int,
195
+ attn_mask: torch.Tensor = None,
196
+ modulate_feature_size: int = None,
197
+ modulate_act_type: str = 'gelu',
198
+ cross_att: bool = None,
199
+ flash_v2: bool = None,
200
+ qkv_packed: bool = None,
201
+ shift_group: int = None,
202
+ window_size: int = None,):
203
+ super().__init__()
204
+
205
+ print('vit flashv2', flash_v2)
206
+
207
+ self.flash_v2 = flash_v2
208
+ self.window_size = window_size
209
+ if self.flash_v2:
210
+ self.attn = MultiheadAttentionFlashV2(d_model, n_head, shift_group=shift_group, qkv_packed=qkv_packed, window_size=window_size)
211
+ else:
212
+ self.attn = nn.MultiheadAttention(d_model, n_head)
213
+
214
+ self.ln_1 = LayerNorm(d_model)
215
+ self.mlp = nn.Sequential(OrderedDict([
216
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
217
+ ("gelu", QuickGELU()),
218
+ ("c_proj", nn.Linear(d_model * 4, d_model))
219
+ ]))
220
+ self.ln_2 = LayerNorm(d_model)
221
+ self.attn_mask = attn_mask
222
+ self.window_size = window_size
223
+
224
+ if modulate_feature_size is not None:
225
+ act_dict = {'gelu': QuickGELU,
226
+ 'silu': nn.SiLU}
227
+ self.modulation_fn = nn.Sequential(
228
+ LayerNorm(modulate_feature_size),
229
+ act_dict[modulate_act_type](),
230
+ nn.Linear(modulate_feature_size, 3 * d_model, bias=True)
231
+ )
232
+ self.mlp_modulation_fn = nn.Sequential(
233
+ LayerNorm(modulate_feature_size),
234
+ act_dict[modulate_act_type](),
235
+ nn.Linear(modulate_feature_size, 3 * d_model, bias=True)
236
+ )
237
+ else:
238
+ self.modulation_fn = None
239
+ self.mlp_modulation_fn = None
240
+
241
+ self.cross_att = cross_att
242
+ if self.cross_att:
243
+ self.cross_att = CrossAttention(query_dim=d_model, context_dim=d_model,
244
+ heads=n_head, dim_head=int(d_model//n_head), dropout=0)
245
+ self.ln_1_5 = LayerNorm(d_model)
246
+
247
+ def attention(self, x: torch.Tensor, index):
248
+ if self.attn_mask is not None:
249
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
250
+ length = x.shape[0]
251
+ attn_mask = self.attn_mask[:length, :length]
252
+ else:
253
+ attn_mask = None
254
+ if self.window_size is not None:
255
+ x = x.permute(1, 0, 2)
256
+ b, l, c = x.shape
257
+ # print(x.shape)
258
+ assert l % self.window_size == 0
259
+ if index % 2 == 0:
260
+ x = rearrange(x, 'b (p w) c -> (b p) w c', w=self.window_size)
261
+ x = x.permute(1, 0, 2) # w, bp, c
262
+ x = self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
263
+ x = x.permute(1, 0, 2) # bp, w, c
264
+ x = rearrange(x, '(b l) w c -> b (l w) c', l=l//self.window_size, w=self.window_size)
265
+ x = x.permute(1, 0, 2) # w, bp, c
266
+ else:
267
+ x = torch.roll(x, shifts=self.window_size//2, dims=1)
268
+ x = rearrange(x, 'b (p w) c -> (b p) w c', w=self.window_size)
269
+ x = x.permute(1, 0, 2) # w, bp, c
270
+ x = self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
271
+ x = x.permute(1, 0, 2) # w, bp, c
272
+ x = rearrange(x, '(b l) w c -> b (l w) c', l=l//self.window_size, w=self.window_size)
273
+ x = torch.roll(x, shifts=-self.window_size//2, dims=1)
274
+ x = x.permute(1, 0, 2)
275
+ else:
276
+ x = self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
277
+
278
+ return x
279
+
280
+ def forward(self, x: torch.Tensor, modulation: torch.Tensor = None, context: torch.Tensor = None, index=None):
281
+ # self attention block
282
+ y = self.ln_1(x)
283
+ if self.modulation_fn is not None:
284
+ shift, scale, gate = self.modulation_fn(modulation).chunk(3, dim=1)
285
+ y = modulate(y, shift, scale)
286
+ y = self.attention(y, index)
287
+ # If we have modulation func for mlp as well, we will just use the gate for the attention
288
+ if self.modulation_fn is not None and self.mlp_modulation_fn is not None:
289
+ y = y * gate.unsqueeze(0)
290
+ x = x + y
291
+
292
+ # cross attention block
293
+ if self.cross_att:
294
+ y = self.cross_att(self.ln_1_5(x), context=context)
295
+ # print(y.mean().item())
296
+ x = x + y
297
+
298
+ # mlp block
299
+ y = self.ln_2(x)
300
+ if self.mlp_modulation_fn is not None:
301
+ shift, scale, gate = self.mlp_modulation_fn(modulation).chunk(3, dim=1)
302
+ y = modulate(y, shift, scale)
303
+ y = self.mlp(y)
304
+ # For here we have two cases:
305
+ # 1. If we have a modulation function for the MLP, we use it to modulate the output of the MLP
306
+ # 2. If we don't have a modulation function for the MLP, we use the modulation function for the attention
307
+ if self.modulation_fn is not None:
308
+ y = y * gate.unsqueeze(0)
309
+ x = x + y
310
+
311
+ return x
312
+
313
+
314
+ class Transformer(nn.Module):
315
+ def __init__(self,
316
+ width: int,
317
+ layers: int,
318
+ heads: int,
319
+ attn_mask: torch.Tensor = None,
320
+ modulate_feature_size: int = None,
321
+ modulate_act_type: str = 'gelu',
322
+ cross_att_layers: int = 0,
323
+ return_all_layers=False,
324
+ flash_v2=True,
325
+ qkv_packed=False,
326
+ shift_group=None,
327
+ window_size=None):
328
+
329
+ super().__init__()
330
+ self.width = width
331
+ self.layers = layers
332
+
333
+ blocks = []
334
+ for _ in range(layers):
335
+ layer = ResidualAttentionBlock(width,
336
+ heads,
337
+ attn_mask,
338
+ modulate_feature_size=modulate_feature_size,
339
+ modulate_act_type=modulate_act_type,
340
+ cross_att = (_ + cross_att_layers)>=layers,
341
+ flash_v2=flash_v2,
342
+ qkv_packed=qkv_packed,
343
+ shift_group=shift_group,
344
+ window_size=window_size)
345
+ blocks.append(layer)
346
+
347
+ self.resblocks = nn.Sequential(*blocks)
348
+
349
+ self.grad_checkpointing = False
350
+ self.return_all_layers = return_all_layers
351
+ self.flash_v2 = flash_v2
352
+
353
+ def set_grad_checkpointing(self, flag=True):
354
+ self.grad_checkpointing = flag
355
+
356
+ def forward(self,
357
+ x: torch.Tensor,
358
+ modulation: torch.Tensor = None,
359
+ context: torch.Tensor = None,
360
+ additional_residuals = None):
361
+
362
+ all_x = []
363
+ if additional_residuals is not None:
364
+ assert len(additional_residuals) == self.layers
365
+ for res_i, module in enumerate(self.resblocks):
366
+ if self.grad_checkpointing:
367
+ # print("Grad checkpointing")
368
+ x = checkpoint(module, x, modulation, context, res_i)
369
+ else:
370
+ x = module(x, modulation, context, res_i)
371
+ if additional_residuals is not None:
372
+ add_res = additional_residuals[res_i]
373
+ x[:, :add_res.shape[1]] = x[:, :add_res.shape[1]] + add_res
374
+ all_x.append(x)
375
+ if self.return_all_layers:
376
+ return all_x
377
+ else:
378
+ return x
379
+
380
+ class GaussianUpsampler(nn.Module):
381
+ def __init__(self, width,
382
+ up_ratio,
383
+ ch_decay=1,
384
+ low_channels=64,
385
+ window_size=False,
386
+ with_additional_inputs=False):
387
+
388
+ super().__init__()
389
+ self.up_ratio = up_ratio
390
+ self.low_channels = low_channels
391
+ self.window_size = window_size
392
+ self.base_width = width
393
+ self.with_additional_inputs = with_additional_inputs
394
+ for res_log2 in range(int(np.log2(up_ratio))):
395
+ _width = width
396
+ width = max(width // ch_decay, 64)
397
+ heads = int(width / 64)
398
+ width = heads * 64
399
+ if self.with_additional_inputs:
400
+ self.add_module(f'upsampler_{res_log2}', PSUpsamplerBlock(_width+self.base_width, width, 2))
401
+ else:
402
+ self.add_module(f'upsampler_{res_log2}', PSUpsamplerBlock(_width, width, 2))
403
+ encoder = Transformer(width, 2, heads,
404
+ modulate_feature_size=None,
405
+ modulate_act_type=None,
406
+ cross_att_layers=0,
407
+ return_all_layers=False,
408
+ flash_v2=False,
409
+ qkv_packed=False,
410
+ shift_group=False,
411
+ window_size=window_size)
412
+ self.add_module(f'attention_{res_log2}', encoder)
413
+ self.out_channels = width
414
+ self.ln_post = LayerNorm(width)
415
+
416
+ def forward(self, x, additional_inputs=None):
417
+ if self.with_additional_inputs:
418
+ assert len(additional_inputs) == int(np.log2(self.up_ratio))
419
+ for res_log2 in range(int(np.log2(self.up_ratio))):
420
+ if self.with_additional_inputs:
421
+ add_input = additional_inputs[res_log2]
422
+ scale = x.shape[1] // add_input.shape[1]
423
+ add_input = add_input.repeat_interleave(scale, 1)
424
+ x = torch.cat([x, add_input], dim=2)
425
+ x = getattr(self, f'upsampler_{res_log2}')(x)
426
+ x = x.permute(1, 0, 2)
427
+ x = getattr(self, f'attention_{res_log2}')(x)
428
+ x = x.permute(1, 0, 2)
429
+ x = self.ln_post(x)
430
+ return x
431
+
432
+
433
+
434
+ class HyperGaussianUpsampler(nn.Module):
435
+ def __init__(self, width,
436
+ resolution,
437
+ up_ratio,
438
+ ch_decay=1,
439
+ window_size=False,
440
+ with_additional_inputs=False,
441
+ upsampler_kwargs={}):
442
+
443
+ super().__init__()
444
+ self.up_ratio = up_ratio
445
+ self.window_size = window_size
446
+ self.base_width = width
447
+ self.with_additional_inputs = with_additional_inputs
448
+ self.resolution = resolution
449
+ for res_log2 in range(int(np.log2(up_ratio))):
450
+ if res_log2 == 0:
451
+ _width = width
452
+ width = width
453
+ heads = int(width / 64)
454
+ width = heads * 64
455
+ if self.with_additional_inputs:
456
+ self.add_module(f'upsampler_{res_log2}', PSUpsamplerBlock(_width+self.base_width, width, 2))
457
+ else:
458
+ self.add_module(f'upsampler_{res_log2}', PSUpsamplerBlock(_width, width, 2))
459
+ encoder = Transformer(width, 2, heads,
460
+ modulate_feature_size=None,
461
+ modulate_act_type=None,
462
+ cross_att_layers=0,
463
+ return_all_layers=False,
464
+ flash_v2=False,
465
+ qkv_packed=False,
466
+ shift_group=False,
467
+ window_size=window_size)
468
+ self.add_module(f'attention_{res_log2}', encoder)
469
+ self.resolution = self.resolution * 2
470
+ else:
471
+ self.resolution = self.resolution * 2
472
+ self.add_module(f'upsample_{res_log2}',
473
+ UpsamplerLayers_conv(in_channels=width,
474
+ out_channels=width,
475
+ resolution=self.resolution,
476
+ conv_block_type = 'convnext',
477
+ **upsampler_kwargs))
478
+ self.out_channels = width
479
+ # self.ln_post = LayerNorm(width)
480
+ self.ln_post = LayerNorm([self.resolution, self.resolution, width])
481
+
482
+ def forward(self, x, additional_inputs=None):
483
+ if self.with_additional_inputs:
484
+ assert len(additional_inputs) == int(np.log2(self.up_ratio))
485
+ for res_log2 in range(int(np.log2(self.up_ratio))):
486
+ if res_log2 == 0:
487
+ if self.with_additional_inputs:
488
+ add_input = additional_inputs[res_log2]
489
+ scale = x.shape[1] // add_input.shape[1]
490
+ add_input = add_input.repeat_interleave(scale, 1)
491
+ x = torch.cat([x, add_input], dim=2)
492
+ x = getattr(self, f'upsampler_{res_log2}')(x)
493
+ x = x.permute(1, 0, 2)
494
+ x = getattr(self, f'attention_{res_log2}')(x)
495
+ x = x.permute(1, 0, 2)
496
+ x = x.reshape(x.shape[0], int(math.sqrt(x.shape[1])), int(math.sqrt(x.shape[1])), -1).permute(0, 3, 1, 2)
497
+ else:
498
+ x = getattr(self, f'upsample_{res_log2}')(x)
499
+ x = self.ln_post(x.permute(0, 2, 3, 1))
500
+ return x
501
+
502
+ class VisionTransformer(nn.Module):
503
+ def __init__(self,
504
+ # transformer params
505
+ in_channels: int,
506
+ patch_size: int,
507
+ width: int,
508
+ layers: int,
509
+ heads: int,
510
+ weight: str = None,
511
+ encode_layers: int = 0,
512
+ shift_group = False,
513
+ flash_v2 = False,
514
+ qkv_packed = False,
515
+ window_size = False,
516
+ use_pe = False,
517
+ # modualtion params
518
+ modulate_feature_size: int = None,
519
+ modulate_act_type: str = 'gelu',
520
+ # camera condition
521
+ camera_condition: str = 'plucker',
522
+ # init params
523
+ disable_dino=False,
524
+ error_weight_init_mode='mean',
525
+ # other params
526
+ add_zero_conv=False,
527
+ return_all_layers=False,
528
+ disable_post_ln=False,
529
+ rope=None):
530
+ super().__init__()
531
+ self.patch_size = patch_size
532
+ self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
533
+ self.use_pe = use_pe
534
+ self.rope = rope
535
+ self.disable_dino = disable_dino
536
+ # if not self.disable_dino:
537
+ # scale = width ** -0.5
538
+ # self.class_embedding = nn.Parameter(scale * torch.randn(width))
539
+ # self.positional_embedding = nn.Parameter(scale * torch.randn((input_res// patch_size) ** 2 + 1, width))
540
+ # else:
541
+ # if self.use_pe:
542
+ # self.positional_embedding = nn.Parameter(torch.zeros(1, (input_res// patch_size) ** 2, width))
543
+ # nn.init.trunc_normal_(self.positional_embedding, std=0.02)
544
+ self.ln_pre = LayerNorm(width)
545
+ self.add_zero_conv = add_zero_conv
546
+ self.return_all_layers = return_all_layers
547
+ self.disable_post_ln = disable_post_ln
548
+ self.flash_v2 = flash_v2
549
+ self.qkv_packed = qkv_packed
550
+
551
+ self.camera_condition = camera_condition
552
+ if self.camera_condition == 'plucker': assert modulate_feature_size is None
553
+
554
+ if self.add_zero_conv:
555
+ assert self.return_all_layers
556
+ self.zero_convs = nn.ModuleList([zero_module(nn.Conv1d(in_channels=width, out_channels=width, kernel_size=1, stride=1, bias=True)) for _ in range(layers)])
557
+
558
+ self.encode_layers = encode_layers
559
+ if self.encode_layers > 0:
560
+ self.encoder = Transformer(width, encode_layers, heads,
561
+ modulate_feature_size=modulate_feature_size,
562
+ modulate_act_type=modulate_act_type,
563
+ cross_att_layers=0,
564
+ return_all_layers=return_all_layers,
565
+ flash_v2=flash_v2,
566
+ qkv_packed=qkv_packed,
567
+ shift_group=shift_group,
568
+ window_size=window_size)
569
+ self.transformer = Transformer(width, layers-encode_layers, heads,
570
+ modulate_feature_size=modulate_feature_size,
571
+ modulate_act_type=modulate_act_type,
572
+ cross_att_layers=0,
573
+ return_all_layers=return_all_layers,
574
+ flash_v2=flash_v2,
575
+ qkv_packed=qkv_packed,
576
+ shift_group=shift_group,
577
+ window_size=window_size)
578
+
579
+ if not self.disable_post_ln:
580
+ self.ln_post = LayerNorm(width)
581
+
582
+ if weight is not None:
583
+ if not self.disable_dino:
584
+ if "clip" in weight:
585
+ raise NotImplementedError()
586
+ elif weight.startswith("vit_b_16"):
587
+ load_timm_to_clip(self, config_name=weight, init_mode=error_weight_init_mode)
588
+ elif weight.startswith("vit_b_8"):
589
+ load_timm_to_clip(self, config_name=weight, init_mode=error_weight_init_mode)
590
+ else:
591
+ raise NotImplementedError()
592
+ else:
593
+ self.apply(_init_weights)
594
+
595
+ # Init the weight and bias of modulation_fn to zero
596
+ if modulate_feature_size != 0:
597
+ for block in self.transformer.resblocks:
598
+ if block.modulation_fn is not None:
599
+ block.modulation_fn[2].weight.data.zero_()
600
+ block.modulation_fn[2].bias.data.zero_()
601
+ if block.mlp_modulation_fn is not None:
602
+ block.mlp_modulation_fn[2].weight.data.zero_()
603
+ block.mlp_modulation_fn[2].bias.data.zero_()
604
+ for block in self.transformer.resblocks:
605
+ if block.cross_att:
606
+ zero_module(block.cross_att.to_out)
607
+
608
+ def set_grad_checkpointing(self, flag=True):
609
+ self.transformer.set_grad_checkpointing(flag)
610
+
611
+ def forward(self,
612
+ x: torch.Tensor,
613
+ modulation: torch.Tensor = None,
614
+ context: torch.Tensor = None,
615
+ additional_residuals=None,
616
+ abla_crossview=False,
617
+ pos=None):
618
+
619
+ # image tokenization
620
+ bs, vs = x.shape[:2]
621
+ x = rearrange(x, 'b v c h w -> (b v) c h w')
622
+ pos = rearrange(pos, 'b v c h -> (b v) c h')
623
+ if self.camera_condition == 'plucker' and modulation is not None:
624
+ modulation = rearrange(modulation, 'b v c h w -> (b v) c h w')
625
+ x = torch.cat([x, modulation], dim=1)
626
+ modulation = None
627
+
628
+ x = self.conv1(x) # shape = [*, width, grid, grid]
629
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
630
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
631
+
632
+ # pre-normalization
633
+ x = self.ln_pre(x)
634
+ B, N, C = x.shape
635
+ x = x.reshape(B, N, -1, 64)
636
+ x = x.permute(0, 2, 1, 3)
637
+ # print('pre x mean: ', x.mean().item())
638
+ # print('pre x var: ', x.var().item())
639
+ x = x + self.rope(torch.ones_like(x).to(x), pos)
640
+ # print('x mean: ', x.mean().item())
641
+ # print('x var: ', x.var().item())
642
+
643
+ x = x.permute(0, 2, 1, 3)
644
+ x = x.reshape(B, N, -1)
645
+ # use encode to extract features
646
+ if self.encode_layers > 0:
647
+ x = x.permute(1, 0, 2) # NLD -> LND
648
+ x = self.encoder(x, modulation, context, additional_residuals=additional_residuals)
649
+ x = x.permute(1, 0, 2) # LND -> NLD
650
+ if not self.disable_dino:
651
+ x = x.permute(1, 0, 2) # NLD -> LND
652
+ else:
653
+ if not abla_crossview:
654
+ # flatten x along the video dimension
655
+ x = rearrange(x, '(b v) n d -> b (v n) d', v=vs)
656
+ # print(x.shape)
657
+ x = x.permute(1, 0, 2) # NLD -> LND
658
+ else:
659
+ x = x.permute(1, 0, 2)
660
+ x = self.transformer(x, modulation, context, additional_residuals=additional_residuals)
661
+
662
+
663
+ if self.add_zero_conv:
664
+ assert isinstance(x, (list, tuple))
665
+ assert len(x) == len(self.zero_convs)
666
+ new_x = []
667
+ for sub_x, sub_zero_conv in zip(x, self.zero_convs):
668
+ sub_x_out = sub_zero_conv(sub_x.permute(1, 2, 0))
669
+ new_x.append(sub_x_out.permute(2, 0, 1))
670
+ x = new_x
671
+
672
+ if self.return_all_layers:
673
+ assert isinstance(x, (list, tuple))
674
+ if not self.disable_post_ln:
675
+ x_final = x[-1].permute(1, 0, 2) # LND -> NLD
676
+ x_final = self.ln_post(x_final)
677
+ x_final = rearrange(x_final, 'b (v n) d -> b v n d', v=vs)
678
+ x = [s.permute(1, 0, 2) for s in x]
679
+ x.append(x_final)
680
+ return x
681
+
682
+ if not self.disable_post_ln:
683
+ x = x.permute(1, 0, 2) # LND -> NLD
684
+ x = self.ln_post(x)
685
+ if not self.disable_dino:
686
+ x = rearrange(x, '(b v) n d -> b v n d', b=bs, v=vs)
687
+ else:
688
+ if not abla_crossview:
689
+ # reshape x back to video dimension
690
+ x = rearrange(x, 'b (v n) d -> b v n d', v=vs)
691
+ else:
692
+ x = rearrange(x, '(b v) n d -> b v n d', v=vs)
693
+ return x
694
+
695
+ def extra_repr(self) -> str:
696
+ pass
697
+
698
+
699
+ class VisionTransformer_fusion(nn.Module):
+     def __init__(self,
+                  # transformer params
+                  in_channels: int,
+                  patch_size: int,
+                  width: int,
+                  layers: int,
+                  heads: int,
+                  weight: str = None,
+                  encode_layers: int = 0,
+                  shift_group=False,
+                  flash_v2=False,
+                  qkv_packed=False,
+                  window_size=False,
+                  use_pe=False,
+                  # modulation params
+                  modulate_feature_size: int = None,
+                  modulate_act_type: str = 'gelu',
+                  # camera condition
+                  camera_condition: str = 'plucker',
+                  # init params
+                  disable_dino=False,
+                  error_weight_init_mode='mean',
+                  # other params
+                  add_zero_conv=False,
+                  return_all_layers=False,
+                  disable_post_ln=False,
+                  rope=None):
+         super().__init__()
+         self.patch_size = patch_size
+         self.use_pe = use_pe
+         self.rope = rope
+         self.disable_dino = disable_dino
+         self.ln_pre = LayerNorm(width)
+         self.add_zero_conv = add_zero_conv
+         self.return_all_layers = return_all_layers
+         self.disable_post_ln = disable_post_ln
+         self.flash_v2 = flash_v2
+         self.qkv_packed = qkv_packed
+ 
+         self.camera_condition = camera_condition
+         if self.camera_condition == 'plucker':
+             assert modulate_feature_size is None
+ 
+         if self.add_zero_conv:
+             assert self.return_all_layers
+             self.zero_convs = nn.ModuleList(
+                 [zero_module(nn.Conv1d(in_channels=width, out_channels=width, kernel_size=1, stride=1, bias=True))
+                  for _ in range(layers)])
+ 
+         self.encode_layers = encode_layers
+         if self.encode_layers > 0:
+             self.encoder = Transformer(width, encode_layers, heads,
+                                        modulate_feature_size=modulate_feature_size,
+                                        modulate_act_type=modulate_act_type,
+                                        cross_att_layers=0,
+                                        return_all_layers=return_all_layers,
+                                        flash_v2=flash_v2,
+                                        qkv_packed=qkv_packed,
+                                        shift_group=shift_group,
+                                        window_size=window_size)
+         self.transformer = Transformer(width, layers - encode_layers, heads,
+                                        modulate_feature_size=modulate_feature_size,
+                                        modulate_act_type=modulate_act_type,
+                                        cross_att_layers=0,
+                                        return_all_layers=return_all_layers,
+                                        flash_v2=flash_v2,
+                                        qkv_packed=qkv_packed,
+                                        shift_group=shift_group,
+                                        window_size=window_size)
+ 
+         if not self.disable_post_ln:
+             self.ln_post = LayerNorm(width)
+ 
+         if weight is not None:
+             if not self.disable_dino:
+                 if "clip" in weight:
+                     raise NotImplementedError()
+                 elif weight.startswith("vit_b_16") or weight.startswith("vit_b_8"):
+                     load_timm_to_clip(self, config_name=weight, init_mode=error_weight_init_mode)
+                 else:
+                     raise NotImplementedError()
+             else:
+                 self.apply(_init_weights)
+ 
+         # Init the weight and bias of modulation_fn to zero
+         if modulate_feature_size != 0:
+             for block in self.transformer.resblocks:
+                 if block.modulation_fn is not None:
+                     block.modulation_fn[2].weight.data.zero_()
+                     block.modulation_fn[2].bias.data.zero_()
+                 if block.mlp_modulation_fn is not None:
+                     block.mlp_modulation_fn[2].weight.data.zero_()
+                     block.mlp_modulation_fn[2].bias.data.zero_()
+         for block in self.transformer.resblocks:
+             if block.cross_att:
+                 zero_module(block.cross_att.to_out)
+ 
+     def set_grad_checkpointing(self, flag=True):
+         self.transformer.set_grad_checkpointing(flag)
+ 
+     def forward(self,
+                 x: torch.Tensor,
+                 modulation: torch.Tensor = None,
+                 context: torch.Tensor = None,
+                 additional_residuals=None,
+                 abla_crossview=False,
+                 pos=None):
+ 
+         # image tokenization
+         bs, vs = x.shape[:2]
+         x = rearrange(x, 'b v h g -> (b v) h g')  # shape = [*, grid ** 2, width]
+         pos = rearrange(pos, 'b v c h -> (b v) c h')
+ 
+         # apply rotary position embedding per 64-dim head:
+         # add the rotary embedding of a unit tensor at the given positions
+         B, N, C = x.shape
+         x = x.reshape(B, N, -1, 64)
+         x = x.permute(0, 2, 1, 3)
+         x = x + self.rope(torch.ones_like(x).to(x), pos)
+         x = x.permute(0, 2, 1, 3)
+         x = x.reshape(B, N, -1)
+         # use the encoder to extract per-view features first
+         if self.encode_layers > 0:
+             x = x.permute(1, 0, 2)  # NLD -> LND
+             x = self.encoder(x, modulation, context, additional_residuals=additional_residuals)
+             x = x.permute(1, 0, 2)  # LND -> NLD
+         if not self.disable_dino:
+             x = x.permute(1, 0, 2)  # NLD -> LND
+         else:
+             if not abla_crossview:
+                 # flatten x along the view dimension for cross-view attention
+                 x = rearrange(x, '(b v) n d -> b (v n) d', v=vs)
+                 x = x.permute(1, 0, 2)  # NLD -> LND
+             else:
+                 x = x.permute(1, 0, 2)
+         x = self.transformer(x, modulation, context, additional_residuals=additional_residuals)
+ 
+         if self.add_zero_conv:
+             assert isinstance(x, (list, tuple))
+             assert len(x) == len(self.zero_convs)
+             new_x = []
+             for sub_x, sub_zero_conv in zip(x, self.zero_convs):
+                 sub_x_out = sub_zero_conv(sub_x.permute(1, 2, 0))
+                 new_x.append(sub_x_out.permute(2, 0, 1))
+             x = new_x
+ 
+         if self.return_all_layers:
+             assert isinstance(x, (list, tuple))
+             if not self.disable_post_ln:
+                 x_final = x[-1].permute(1, 0, 2)  # LND -> NLD
+                 x_final = self.ln_post(x_final)
+                 x_final = rearrange(x_final, 'b (v n) d -> b v n d', v=vs)
+                 x = [s.permute(1, 0, 2) for s in x]
+                 x.append(x_final)
+             return x
+ 
+         if not self.disable_post_ln:
+             x = x.permute(1, 0, 2)  # LND -> NLD
+             x = self.ln_post(x)
+         if not self.disable_dino:
+             x = rearrange(x, '(b v) n d -> b v n d', b=bs, v=vs)
+         else:
+             if not abla_crossview:
+                 # reshape x back to the view dimension
+                 x = rearrange(x, 'b (v n) d -> b v n d', v=vs)
+             else:
+                 x = rearrange(x, '(b v) n d -> b v n d', v=vs)
+         return x
+ 
+     def extra_repr(self) -> str:
+         return ''
+ 
+ 
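+ # Minimal usage sketch (shapes only). Assumes `my_rope` matches the 64-dim
+ # per-head split used in forward(), and that inputs are already patchified
+ # tokens of shape [batch, views, num_patches, width]; the names below are
+ # illustrative, not part of this module.
+ #
+ #   vit = VisionTransformer_fusion(in_channels=6, patch_size=16, width=768,
+ #                                  layers=12, heads=12, disable_dino=True,
+ #                                  rope=my_rope)
+ #   tokens = torch.randn(2, 4, 196, 768)     # [b, v, n, d]
+ #   pos = patch_positions                    # hypothetical RoPE grid coordinates, [b, v, n, 2]
+ #   out = vit(tokens, pos=pos)               # -> [2, 4, 196, 768]
+ 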
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic'):
+     """
+     Resize positional embeddings, implementation from google/simclr and open_clip.
+     """
+     # Rescale the grid of position embeddings when loading from state_dict
+     old_pos_embed = state_dict.get('positional_embedding', None)
+     if old_pos_embed is None:
+         return
+ 
+     # Compute the grid size and extra tokens
+     old_pos_len = state_dict["positional_embedding"].shape[0]
+     old_grid_size = round((state_dict["positional_embedding"].shape[0]) ** 0.5)
+     grid_size = round((model.positional_embedding.shape[0]) ** 0.5)
+     if old_grid_size == grid_size:
+         return
+     extra_tokens = old_pos_len - (old_grid_size ** 2)
+ 
+     if extra_tokens:
+         pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+     else:
+         pos_emb_tok, pos_emb_img = None, old_pos_embed
+ 
+     # Only interpolate the positional emb part, not the extra token part.
+     pos_emb_img = pos_emb_img.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
+     pos_emb_img = F.interpolate(
+         pos_emb_img,
+         size=grid_size,
+         mode=interpolation,
+         align_corners=True,
+     )
+     pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size * grid_size, -1)[0]
+ 
+     # Concatenate the extra tokens back in front of the interpolated grid
+     if pos_emb_tok is not None:
+         new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+     else:
+         new_pos_embed = pos_emb_img
+     state_dict['positional_embedding'] = new_pos_embed
+ 
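+ # Sketch of how resize_pos_embed is used when loading a checkpoint trained at a
+ # different input resolution (illustrative):
+ #
+ #   sd = torch.load('vit_b_16_in.pth')['model']
+ #   resize_pos_embed(sd, model)              # rewrites sd['positional_embedding'] in place
+ #   model.load_state_dict(sd, strict=False)
+ 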
+ 
+ myname2timmname = {
+     "vit_b_16_mae": None,
+     "vit_b_16_in": "vit_base_patch16_224",
+     "vit_b_16_in21k": 'vit_base_patch16_224_in21k',
+     "vit_b_16_sam": 'vit_base_patch16_224_sam',
+     "vit_b_16_dino": 'vit_base_patch16_224_dino',
+     "vit_b_16_mill_in21k": 'vit_base_patch16_224_miil_in21k',
+     "vit_b_16_mill": 'vit_base_patch16_224_miil',
+     "vit_b_8_dino": 'vit_base_patch16_224_dino',
+ }
+ 
+ def load_timm_to_clip(module, config_name="vit_b_16_mae", init_mode='zero'):
+     from torch import nn
+     from clip.model import LayerNorm as CLIPLayerNorm
+     from clip.model import QuickGELU
+ 
+     from torch.nn import GELU
+     from torch.nn import LayerNorm
+ 
+     import json
+     now_dir = os.path.abspath(os.path.dirname(__file__))
+     timm2clip = json.load(open(f"{now_dir}/timm2clip_vit_b_16.json"))
+ 
+     assert config_name in myname2timmname, f"The name {config_name} is not one of {list(myname2timmname.keys())}"
+     # Try the known checkpoint locations in order; fail with a clear message otherwise.
+     candidate_paths = [
+         f"/sensei-fs/users/hatan/model/{config_name}.pth",
+         f"/input/yhxu/models/dino_weights/{config_name}.pth",
+         f"/home/yhxu/models/dino_weights/{config_name}.pth",
+         f"/nas2/zifan/checkpoint/dino_weights/{config_name}.pth",
+     ]
+     timm_weight = None
+     for path in candidate_paths:
+         try:
+             timm_weight = torch.load(path)["model"]
+             break
+         except Exception:
+             continue
+     if timm_weight is None:
+         raise FileNotFoundError(
+             "Please download weight with support/dump_timm_weights.py. \n"
+             "If using mae weight, please check https://github.com/facebookresearch/mae, "
+             "and download the weight as vit_b_16_mae.pth")
+ 
+     # Build model's state dict
+     clipname2timmweight = {}
+     for timm_key, clip_key in timm2clip.items():
+         timm_value = timm_weight[timm_key]
+         clipname2timmweight[clip_key[len("visual."):]] = timm_value.squeeze()
+ 
+     # Resize positional embedding
+     resize_pos_embed(clipname2timmweight, module)
+ 
+     # Load weight to model.
+     try:
+         status = module.load_state_dict(clipname2timmweight, strict=False)
+     except RuntimeError:
+         # conv1 shape mismatch: the model expects more input channels than the
+         # pretrained 3-channel patch embedding provides, so inflate it.
+         print('conv1.weight has a shape mismatch, re-initializing it')
+         if init_mode == 'zero':
+             # keep the RGB filters and zero-init the extra channels
+             new_weight = torch.zeros_like(clipname2timmweight['conv1.weight'])
+             new_weight = new_weight.repeat(1, 2, 1, 1)
+             new_weight[:, :3] = clipname2timmweight['conv1.weight']
+         elif init_mode == 'mean':
+             # replicate the RGB filters and rescale so activations keep their magnitude
+             new_weight = clipname2timmweight['conv1.weight'].repeat(1, 3, 1, 1) / 3
+ 
+         clipname2timmweight['conv1.weight'] = new_weight
+         status = module.load_state_dict(clipname2timmweight, strict=False)
+ 
+     # Since the timm model has a bias, we add it back here.
+     module.conv1.bias = nn.Parameter(clipname2timmweight['conv1.bias'])
+ 
+     # Reinit the visual weights that are not covered by timm
+     module.ln_pre.reset_parameters()
+ 
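+ # Worked example of the 'mean' inflation above: for a conv1 expecting 9 input
+ # channels, each 3-channel pretrained filter is tiled 3x along dim=1 and divided
+ # by 3, so feeding the same image into every channel group reproduces the
+ # original pretrained response:
+ #
+ #   w = torch.randn(768, 3, 16, 16)               # pretrained patch-embed filters
+ #   w9 = w.repeat(1, 3, 1, 1) / 3                 # inflated to 9 input channels
+ #   x = torch.randn(1, 3, 16, 16)
+ #   y3 = torch.einsum('oikl,bikl->bo', w, x)
+ #   y9 = torch.einsum('oikl,bikl->bo', w9, x.repeat(1, 3, 1, 1))
+ #   torch.allclose(y3, y9, atol=1e-5)             # True
+ 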
+ def convert_clip_to_timm(module):
+     """Recursively swap CLIP-specific modules for their timm equivalents
+     (module-conversion pattern adapted from detectron2's FrozenBatchNorm converter)."""
+     from clip.model import LayerNorm as CLIPLayerNorm
+     from clip.model import QuickGELU
+     from torch.nn import GELU, LayerNorm
+ 
+     res = module
+     if isinstance(module, CLIPLayerNorm):
+         # Timm uses eps=1e-6 while CLIP uses eps=1e-5
+         res = LayerNorm(module.normalized_shape, eps=1e-6, elementwise_affine=module.elementwise_affine)
+         if module.elementwise_affine:
+             res.weight.data = module.weight.data.clone().detach()
+             res.bias.data = module.bias.data.clone().detach()
+     elif isinstance(module, QuickGELU):
+         # Timm uses GELU while CLIP uses QuickGELU
+         res = GELU()
+     else:
+         for name, child in module.named_children():
+             new_child = convert_clip_to_timm(child)
+             if new_child is not child:
+                 res.add_module(name, new_child)
+     return res
dust3r/croco/models/x_transformer.py ADDED
@@ -0,0 +1,558 @@
+ """from https://github.com/lucidrains/x-transformers"""
2
+ import math
3
+ from random import random
4
+
5
+ import torch
6
+ from torch import nn, einsum
7
+ import torch.nn.functional as F
8
+ from torch.utils.checkpoint import checkpoint
9
+
10
+ from functools import partial, wraps
11
+ from inspect import isfunction
12
+
13
+ from einops import rearrange, repeat, reduce
14
+
15
+
16
+ # constants
17
+
18
+ DEFAULT_DIM_HEAD = 64
19
+
20
+
21
+ # helpers
22
+
23
+ def exists(val):
24
+ return val is not None
25
+
26
+
27
+ def default(val, d):
28
+ if exists(val):
29
+ return val
30
+ return d() if isfunction(d) else d
31
+
32
+
33
+ def cast_tuple(val, depth):
34
+ return val if isinstance(val, tuple) else (val,) * depth
35
+
36
+
37
+ # init helpers
38
+
39
+ def init_zero_(layer):
40
+ nn.init.constant_(layer.weight, 0.)
41
+ if exists(layer.bias):
42
+ nn.init.constant_(layer.bias, 0.)
43
+
44
+
45
+ # keyword argument helpers
46
+
47
+ def pick_and_pop(keys, d):
48
+ values = list(map(lambda key: d.pop(key), keys))
49
+ return dict(zip(keys, values))
50
+
51
+
52
+ def group_dict_by_key(cond, d):
53
+ return_val = [dict(), dict()]
54
+ for key in d.keys():
55
+ match = bool(cond(key))
56
+ ind = int(not match)
57
+ return_val[ind][key] = d[key]
58
+ return (*return_val,)
59
+
60
+
61
+ def string_begins_with(prefix, str):
62
+ return str.startswith(prefix)
63
+
64
+
65
+ def group_by_key_prefix(prefix, d):
66
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
67
+
68
+
69
+ def groupby_prefix_and_trim(prefix, d):
70
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
71
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
72
+ return kwargs_without_prefix, kwargs
73
+
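+ # Example of the prefix grouping used by AttentionLayers below: kwargs that
+ # start with 'ff_' or 'attn_' are split out and stripped of their prefix:
+ #
+ #   ff_kwargs, rest = groupby_prefix_and_trim('ff_', {'ff_mult': 2, 'attn_heads': 8})
+ #   # ff_kwargs == {'mult': 2}, rest == {'attn_heads': 8}
+ 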
+ 
+ # initializations
+ 
+ def deepnorm_init(
+         transformer,
+         beta,
+         module_name_match_list=['.ff.', '.to_v', '.to_out']
+ ):
+     for name, module in transformer.named_modules():
+         if type(module) != nn.Linear:
+             continue
+ 
+         needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list))
+         gain = beta if needs_beta_gain else 1
+         nn.init.xavier_normal_(module.weight.data, gain=gain)
+ 
+         if exists(module.bias):
+             nn.init.constant_(module.bias.data, 0)
+ 
+ 
+ # activations
+ 
+ class ReluSquared(nn.Module):
+     def forward(self, x):
+         return F.relu(x) ** 2
+ 
+ 
+ # norms
+ 
+ class Scale(nn.Module):
+     def __init__(self, value, fn):
+         super().__init__()
+         self.value = value
+         self.fn = fn
+ 
+     def forward(self, x, **kwargs):
+         out = self.fn(x, **kwargs)
+         scale_fn = lambda t: t * self.value
+ 
+         if not isinstance(out, tuple):
+             return scale_fn(out)
+ 
+         return (scale_fn(out[0]), *out[1:])
+ 
+ 
+ class ScaleNorm(nn.Module):
+     def __init__(self, dim, eps=1e-5):
+         super().__init__()
+         self.eps = eps
+         self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
+ 
+     def forward(self, x):
+         norm = torch.norm(x, dim=-1, keepdim=True)
+         return x / norm.clamp(min=self.eps) * self.g
+ 
+ 
+ class RMSNorm(nn.Module):
+     def __init__(self, dim, eps=1e-8):
+         super().__init__()
+         self.scale = dim ** -0.5
+         self.eps = eps
+         self.g = nn.Parameter(torch.ones(dim))
+ 
+     def forward(self, x):
+         norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+         return x / norm.clamp(min=self.eps) * self.g
+ 
+ 
+ # residual and residual gates
+ 
+ class Residual(nn.Module):
+     def __init__(self, dim, scale_residual=False, scale_residual_constant=1.):
+         super().__init__()
+         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
+         self.scale_residual_constant = scale_residual_constant
+ 
+     def forward(self, x, residual):
+         if exists(self.residual_scale):
+             residual = residual * self.residual_scale
+ 
+         if self.scale_residual_constant != 1:
+             residual = residual * self.scale_residual_constant
+ 
+         return x + residual
+ 
+ 
+ class GRUGating(nn.Module):
+     def __init__(self, dim, scale_residual=False, **kwargs):
+         super().__init__()
+         self.gru = nn.GRUCell(dim, dim)
+         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
+ 
+     def forward(self, x, residual):
+         if exists(self.residual_scale):
+             residual = residual * self.residual_scale
+ 
+         gated_output = self.gru(
+             rearrange(x, 'b n d -> (b n) d'),
+             rearrange(residual, 'b n d -> (b n) d')
+         )
+ 
+         return gated_output.reshape_as(x)
+ 
+ 
+ # feedforward
+ 
+ class GLU(nn.Module):
+     def __init__(self, dim_in, dim_out, activation):
+         super().__init__()
+         self.act = activation
+         self.proj = nn.Linear(dim_in, dim_out * 2)
+ 
+     def forward(self, x):
+         x, gate = self.proj(x).chunk(2, dim=-1)
+         return x * self.act(gate)
+ 
+ 
+ class FeedForward(nn.Module):
+     def __init__(
+             self,
+             dim,
+             dim_out=None,
+             mult=4,
+             glu=False,
+             swish=False,
+             relu_squared=False,
+             post_act_ln=False,
+             dropout=0.,
+             no_bias=False,
+             zero_init_output=False
+     ):
+         super().__init__()
+         inner_dim = int(dim * mult)
+         dim_out = default(dim_out, dim)
+ 
+         if relu_squared:
+             activation = ReluSquared()
+         elif swish:
+             activation = nn.SiLU()
+         else:
+             activation = nn.GELU()
+ 
+         project_in = nn.Sequential(
+             nn.Linear(dim, inner_dim, bias=not no_bias),
+             activation
+         ) if not glu else GLU(dim, inner_dim, activation)
+ 
+         self.ff = nn.Sequential(
+             project_in,
+             nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
+             nn.Dropout(dropout),
+             nn.Linear(inner_dim, dim_out, bias=not no_bias)
+         )
+ 
+         # init last linear layer to 0
+         if zero_init_output:
+             init_zero_(self.ff[-1])
+ 
+     def forward(self, x):
+         return self.ff(x)
+ 
+ 
+ # attention
+ 
+ class Attention(nn.Module):
+     def __init__(
+             self,
+             dim,
+             kv_dim=None,
+             dim_head=DEFAULT_DIM_HEAD,
+             heads=8,
+             causal=False,
+             dropout=0.,
+             zero_init_output=False,
+             shared_kv=False,
+             value_dim_head=None,
+             flash_attention=True,
+     ):
+         super().__init__()
+         self.scale = dim_head ** -0.5
+         if kv_dim is None:
+             kv_dim = dim
+ 
+         self.heads = heads
+         self.causal = causal
+ 
+         value_dim_head = default(value_dim_head, dim_head)
+         q_dim = k_dim = dim_head * heads
+         v_dim = out_dim = value_dim_head * heads
+ 
+         self.to_q = nn.Linear(dim, q_dim, bias=False)
+         self.to_k = nn.Linear(kv_dim, k_dim, bias=False)
+ 
+         # shared key / values, for further memory savings during inference
+         assert not (shared_kv and value_dim_head != dim_head), \
+             'key and value head dimensions must be equal for shared key / values'
+         self.to_v = nn.Linear(kv_dim, v_dim, bias=False) if not shared_kv else None
+ 
+         # output projection
+         self.to_out = nn.Linear(out_dim, dim, bias=False)
+ 
+         # dropout
+         self.dropout_p = dropout
+         self.dropout = nn.Dropout(dropout)
+ 
+         # Flash Attention, needs PyTorch >= 1.13
+         self.flash = flash_attention
+         assert self.flash
+ 
+         # Use torch.nn.functional.scaled_dot_product_attention if available;
+         # otherwise fall back to the xformers library.
+         self.use_xformer = not hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+ 
+         # init output projection to 0
+         if zero_init_output:
+             init_zero_(self.to_out)
+ 
+     def forward(
+             self,
+             x,
+             context=None,
+             mask=None,
+             context_mask=None,
+     ):
+         h = self.heads
+         kv_input = default(context, x)
+ 
+         q_input = x
+         k_input = kv_input
+         v_input = kv_input
+ 
+         q = self.to_q(q_input)
+         k = self.to_k(k_input)
+         v = self.to_v(v_input) if exists(self.to_v) else k
+ 
+         if self.use_xformer:
+             # Since xformers only accepts bf16/fp16, we need to convert qkv to bf16/fp16
+             dtype = q.dtype
+             q, k, v = map(lambda t: t.bfloat16() if t.dtype == torch.float32 else t, (q, k, v))
+ 
+             q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v))
+             try:
+                 import xformers.ops as xops
+             except ImportError as e:
+                 print("Please install xformers to use flash attention for PyTorch < 2.0.0.")
+                 raise e
+ 
+             # Use the flash attention support from the xformers library
+             if self.causal:
+                 attention_bias = xops.LowerTriangularMask()
+             else:
+                 attention_bias = None
+ 
+             # memory_efficient_attention takes the input as (batch, seq_len, heads, dim)
+             out = xops.memory_efficient_attention(
+                 q, k, v, attn_bias=attention_bias,
+             )
+ 
+             out = out.to(dtype)
+ 
+             out = rearrange(out, 'b n h d -> b n (h d)')
+         else:
+             q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
+             # efficient attention using Flash Attention CUDA kernels
+             out = torch.nn.functional.scaled_dot_product_attention(
+                 q, k, v, attn_mask=None, dropout_p=self.dropout_p, is_causal=self.causal,
+             )
+             out = rearrange(out, 'b h n d -> b n (h d)')
+ 
+         out = self.to_out(out)
+ 
+         if exists(mask):
+             mask = rearrange(mask, 'b n -> b n 1')
+             out = out.masked_fill(~mask, 0.)
+ 
+         return out
+ 
+     def extra_repr(self) -> str:
+         return f"causal: {self.causal}, flash attention: {self.flash}, " \
+                f"use_xformers (if False, use torch.nn.functional.scaled_dot_product_attention): {self.use_xformer}"
+ 
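+ # Shape sketch for Attention (illustrative values): self-attention when
+ # `context` is None, cross-attention otherwise.
+ #
+ #   attn = Attention(dim=512, heads=8)
+ #   x = torch.randn(2, 100, 512)
+ #   out = attn(x)                        # -> [2, 100, 512]
+ #   ctx = torch.randn(2, 77, 512)
+ #   out = attn(x, context=ctx)           # -> [2, 100, 512], keys/values from ctx
+ 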
+ 
+ def modulate(x, shift, scale):
+     # from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+ 
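+ # Worked example of the DiT-style modulation: per-sample shift/scale vectors of
+ # shape [b, d] are broadcast over the sequence dimension of x ([b, n, d]):
+ #
+ #   x = torch.ones(2, 4, 8)
+ #   shift, scale = torch.zeros(2, 8), torch.full((2, 8), 0.5)
+ #   modulate(x, shift, scale)            # every element becomes 1 * (1 + 0.5) = 1.5
+ 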
+ 
+ class AttentionLayers(nn.Module):
+     def __init__(
+             self,
+             dim,
+             depth,
+             heads=8,
+             ctx_dim=None,
+             causal=False,
+             cross_attend=False,
+             only_cross=False,
+             use_scalenorm=False,
+             use_rmsnorm=False,
+             residual_attn=False,
+             cross_residual_attn=False,
+             macaron=False,
+             pre_norm=True,
+             gate_residual=False,
+             scale_residual=False,
+             scale_residual_constant=1.,
+             deepnorm=False,
+             sandwich_norm=False,
+             zero_init_branch_output=False,
+             layer_dropout=0.,
+             # Below are the arguments used for this img2nerf project
+             modulate_feature_size=None,  # None disables the DiT-style modulation branch
+             checkpointing=False,
+             checkpoint_every=1,
+             **kwargs
+     ):
+         super().__init__()
+ 
+         # Add checkpointing
+         self.checkpointing = checkpointing
+         self.checkpoint_every = checkpoint_every
+ 
+         ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
+         attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
+ 
+         self.dim = dim
+         self.depth = depth
+         self.layers = nn.ModuleList([])
+ 
+         # determine deepnorm and residual scale
+         if deepnorm:
+             assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings'
+             pre_norm = sandwich_norm = False
+             scale_residual = True
+             scale_residual_constant = (2 * depth) ** 0.25
+ 
+         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
+         self.pre_norm = pre_norm
+         self.sandwich_norm = sandwich_norm
+ 
+         self.residual_attn = residual_attn
+         self.cross_residual_attn = cross_residual_attn
+         self.cross_attend = cross_attend
+ 
+         norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
+         norm_class = RMSNorm if use_rmsnorm else norm_class
+         norm_fn = partial(norm_class, dim)
+ 
+         if cross_attend and not only_cross:
+             default_block = ('a', 'c', 'f')
+         elif cross_attend and only_cross:
+             default_block = ('c', 'f')
+         else:
+             default_block = ('a', 'f')
+ 
+         if macaron:
+             default_block = ('f',) + default_block
+ 
+         # zero init
+         if zero_init_branch_output:
+             attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
+             ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
+ 
+         # calculate layer block order
+         layer_types = default_block * depth
+         self.layer_types = layer_types
+ 
+         # stochastic depth
+         self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
+ 
+         # iterate and construct layers
+         for ind, layer_type in enumerate(self.layer_types):
+             is_last_layer = ind == (len(self.layer_types) - 1)
+ 
+             if layer_type == 'a':
+                 layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
+             elif layer_type == 'c':
+                 layer = Attention(dim, kv_dim=ctx_dim, heads=heads, **attn_kwargs)
+             elif layer_type == 'f':
+                 layer = FeedForward(dim, **ff_kwargs)
+                 layer = layer if not macaron else Scale(0.5, layer)
+             else:
+                 raise Exception(f'invalid layer type {layer_type}')
+ 
+             residual_fn = GRUGating if gate_residual else Residual
+             residual = residual_fn(dim, scale_residual=scale_residual, scale_residual_constant=scale_residual_constant)
+ 
+             pre_branch_norm = norm_fn() if pre_norm else None
+             post_branch_norm = norm_fn() if sandwich_norm else None
+             post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
+ 
+             # The whole modulation part is copied from DiT
+             # https://github.com/facebookresearch/DiT
+             modulation = None
+             if modulate_feature_size is not None:
+                 modulation = nn.Sequential(
+                     nn.LayerNorm(modulate_feature_size),
+                     nn.GELU(),
+                     nn.Linear(modulate_feature_size, 3 * dim, bias=True)
+                 )
+ 
+             norms = nn.ModuleList([
+                 pre_branch_norm,
+                 post_branch_norm,
+                 post_main_norm,
+             ])
+ 
+             self.layers.append(nn.ModuleList([
+                 norms,
+                 layer,
+                 residual,
+                 modulation,
+             ]))
+ 
+         if deepnorm:
+             init_gain = (8 * depth) ** -0.25
+             deepnorm_init(self, init_gain)
+ 
+     def forward(
+             self,
+             x,
+             context=None,
+             modulation=None,
+             mask=None,
+             context_mask=None,
+     ):
+         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
+ 
+         num_layers = len(self.layer_types)
+         assert num_layers % self.checkpoint_every == 0
+ 
+         for start_layer_idx in range(0, num_layers, self.checkpoint_every):
+             end_layer_idx = min(start_layer_idx + self.checkpoint_every, num_layers)
+ 
+             def run_layers(x, context, modulation, start, end):
+                 for ind, (layer_type, (norm, block, residual_fn, modulation_fn), layer_dropout) in enumerate(
+                         zip(self.layer_types[start: end], self.layers[start: end], self.layer_dropouts[start: end])):
+                     residual = x
+ 
+                     pre_branch_norm, post_branch_norm, post_main_norm = norm
+ 
+                     if exists(pre_branch_norm):
+                         x = pre_branch_norm(x)
+ 
+                     if modulation_fn is not None:
+                         shift, scale, gate = modulation_fn(modulation).chunk(3, dim=1)
+                         x = modulate(x, shift, scale)
+ 
+                     if layer_type == 'a':
+                         out = block(x, mask=mask)
+                     elif layer_type == 'c':
+                         out = block(x, context=context, mask=mask, context_mask=context_mask)
+                     elif layer_type == 'f':
+                         out = block(x)
+ 
+                     if exists(post_branch_norm):
+                         out = post_branch_norm(out)
+ 
+                     if modulation_fn is not None:
+                         # TODO: add an option to use the gate or not.
+                         out = out * gate.unsqueeze(1)
+ 
+                     x = residual_fn(out, residual)
+ 
+                     if exists(post_main_norm):
+                         x = post_main_norm(x)
+ 
+                 return x
+ 
+             if self.checkpointing:
+                 x = checkpoint(run_layers, x, context, modulation, start_layer_idx, end_layer_idx)
+             else:
+                 x = run_layers(x, context, modulation, start_layer_idx, end_layer_idx)
+ 
+         return x
+ 
dust3r/croco/utils/misc.py ADDED
@@ -0,0 +1,583 @@
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # utility functions for CroCo
+ # --------------------------------------------------------
+ # References:
+ # MAE: https://github.com/facebookresearch/mae
+ # DeiT: https://github.com/facebookresearch/deit
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
+ # --------------------------------------------------------
+ 
+ import builtins
+ import datetime
+ import os
+ import time
+ import math
+ import json
+ from collections import defaultdict, deque
+ from pathlib import Path
+ import numpy as np
+ from datetime import timedelta
+ import torch
+ import torch.distributed as dist
+ from torch import inf, Tensor
+ import functools
+ from typing import cast, Dict, Iterable, List, Optional, Tuple, Union
+ # private helpers used by clip_grad_norm_ below (available in recent torch, >= 2.x)
+ from torch.utils._foreach_utils import (_group_tensors_by_device_and_dtype,
+                                         _has_foreach_support, _device_has_foreach_support)
+ 
+ class SmoothedValue(object):
+     """Track a series of values and provide access to smoothed values over a
+     window or the global series average.
+     """
+ 
+     def __init__(self, window_size=20, fmt=None):
+         if fmt is None:
+             fmt = "{median:.4f} ({global_avg:.4f})"
+         self.deque = deque(maxlen=window_size)
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+ 
+     def update(self, value, n=1):
+         self.deque.append(value)
+         self.count += n
+         self.total += value * n
+ 
+     def synchronize_between_processes(self):
+         """
+         Warning: does not synchronize the deque!
+         """
+         if not is_dist_avail_and_initialized():
+             return
+         t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+         dist.barrier()
+         dist.all_reduce(t)
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+ 
+     @property
+     def median(self):
+         d = torch.tensor(list(self.deque))
+         return d.median().item()
+ 
+     @property
+     def avg(self):
+         d = torch.tensor(list(self.deque), dtype=torch.float32)
+         return d.mean().item()
+ 
+     @property
+     def global_avg(self):
+         return self.total / self.count
+ 
+     @property
+     def max(self):
+         return max(self.deque)
+ 
+     @property
+     def value(self):
+         return self.deque[-1]
+ 
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median,
+             avg=self.avg,
+             global_avg=self.global_avg,
+             max=self.max,
+             value=self.value)
+ 
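+ # Example (illustrative): smoothing a scalar over a window.
+ #
+ #   meter = SmoothedValue(window_size=20)
+ #   meter.update(0.5); meter.update(0.3); meter.update(0.4)
+ #   str(meter)   # -> '0.4000 (0.4000)' with the default '{median} ({global_avg})' format
+ 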
+ 
+ class MetricLogger(object):
+     def __init__(self, delimiter="\t"):
+         self.meters = defaultdict(SmoothedValue)
+         self.delimiter = delimiter
+ 
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if v is None:
+                 continue
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             assert isinstance(v, (float, int))
+             self.meters[k].update(v)
+ 
+     def __getattr__(self, attr):
+         if attr in self.meters:
+             return self.meters[attr]
+         if attr in self.__dict__:
+             return self.__dict__[attr]
+         raise AttributeError("'{}' object has no attribute '{}'".format(
+             type(self).__name__, attr))
+ 
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(
+                 "{}: {}".format(name, str(meter))
+             )
+         return self.delimiter.join(loss_str)
+ 
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+ 
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+ 
+     def log_every(self, iterable, print_freq, header=None, max_iter=None):
+         if not header:
+             header = ''
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt='{avg:.4f}')
+         data_time = SmoothedValue(fmt='{avg:.4f}')
+         len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable)
+         space_fmt = ':' + str(len(str(len_iterable))) + 'd'
+         log_msg = [
+             header,
+             '[{0' + space_fmt + '}/{1}]',
+             'eta: {eta}',
+             '{meters}',
+             'time: {time}',
+             'data: {data}'
+         ]
+         if torch.cuda.is_available():
+             log_msg.append('max mem: {memory:.0f}')
+         log_msg = self.delimiter.join(log_msg)
+         MB = 1024.0 * 1024.0
+         for i, obj in enumerate(iterable):
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0 or i == len_iterable - 1:
+                 eta_seconds = iter_time.global_avg * (len_iterable - i)
+                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                 if torch.cuda.is_available():
+                     print(log_msg.format(
+                         i, len_iterable, eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time),
+                         memory=torch.cuda.max_memory_allocated() / MB))
+                 else:
+                     print(log_msg.format(
+                         i, len_iterable, eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time)))
+             end = time.time()
+             if max_iter and i + 1 >= max_iter:
+                 break
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print('{} Total time: {} ({:.4f} s / it)'.format(
+             header, total_time_str, total_time / len_iterable))
+ 
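+ # Example (illustrative): wrapping a dataloader to get periodic progress lines.
+ #
+ #   logger = MetricLogger(delimiter="  ")
+ #   for batch in logger.log_every(data_loader, print_freq=50, header='Epoch [0]'):
+ #       loss = train_step(batch)      # hypothetical step function
+ #       logger.update(loss=loss)
+ 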
178
+
179
+ def setup_for_distributed(is_master):
180
+ """
181
+ This function disables printing when not in master process
182
+ """
183
+ builtin_print = builtins.print
184
+
185
+ def print(*args, **kwargs):
186
+ force = kwargs.pop('force', False)
187
+ force = force #or (get_world_size() > 8)
188
+ if is_master or force:
189
+ now = datetime.datetime.now().time()
190
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
191
+ builtin_print(*args, **kwargs)
192
+
193
+ builtins.print = print
194
+
195
+
196
+ def is_dist_avail_and_initialized():
197
+ if not dist.is_available():
198
+ return False
199
+ if not dist.is_initialized():
200
+ return False
201
+ return True
202
+
203
+
204
+ def get_world_size():
205
+ if not is_dist_avail_and_initialized():
206
+ return 1
207
+ return dist.get_world_size()
208
+
209
+
210
+ def get_rank():
211
+ if not is_dist_avail_and_initialized():
212
+ return 0
213
+ return dist.get_rank()
214
+
215
+
216
+ def is_main_process():
217
+ return get_rank() == 0
218
+
219
+
220
+ def save_on_master(*args, **kwargs):
221
+ if is_main_process():
222
+ torch.save(*args, **kwargs)
223
+
+ 
+ def init_distributed_mode(args):
+     nodist = args.nodist if hasattr(args, 'nodist') else False
+     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ and not nodist:
+         args.rank = int(os.environ["RANK"])
+         args.world_size = int(os.environ['WORLD_SIZE'])
+         args.gpu = int(os.environ['LOCAL_RANK'])
+     else:
+         print('Not using distributed mode')
+         setup_for_distributed(is_master=True)  # hack
+         args.distributed = False
+         return
+ 
+     args.distributed = True
+     torch.cuda.set_device(args.gpu)
+     args.dist_backend = 'nccl'
+     print('| distributed init (rank {}): {}, gpu {}'.format(
+         args.rank, args.dist_url, args.gpu), flush=True)
+     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                          timeout=timedelta(seconds=72000000),
+                                          world_size=args.world_size, rank=args.rank)
+     torch.distributed.barrier()
+     setup_for_distributed(args.gpu == 0)
+ 
+ def _no_grad(func):
+     """
+     This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
+     clip_grad_norm_ and clip_grad_value_ themselves.
+     """
+ 
+     def _no_grad_wrapper(*args, **kwargs):
+         with torch.no_grad():
+             return func(*args, **kwargs)
+ 
+     functools.update_wrapper(_no_grad_wrapper, func)
+     return _no_grad_wrapper
+ 
+ 
+ @_no_grad
+ def clip_grad_norm_(
+         parameters,
+         max_norm,
+         norm_type=2.0,
+         error_if_nonfinite=False,
+         foreach=None,
+ ):
+     r"""Clip the gradient norm of an iterable of parameters.
+ 
+     The norm is computed over the norms of the individual gradients of all parameters,
+     as if the norms of the individual gradients were concatenated into a single vector.
+     Gradients are modified in-place.
+ 
+     Args:
+         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+             single Tensor that will have gradients normalized
+         max_norm (float): max norm of the gradients
+         norm_type (float): type of the used p-norm. Can be ``'inf'`` for
+             infinity norm.
+         error_if_nonfinite (bool): if True, an error is thrown if the total
+             norm of the gradients from :attr:`parameters` is ``nan``,
+             ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+         foreach (bool): use the faster foreach-based implementation.
+             If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+             fall back to the slow implementation for other device types.
+             Default: ``None``
+ 
+     Returns:
+         Total norm of the parameter gradients (viewed as a single vector).
+     """
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     grads = [p.grad for p in parameters if p.grad is not None]
+     max_norm = float(max_norm)
+     norm_type = float(norm_type)
+     if len(grads) == 0:
+         return torch.tensor(0.0)
+     first_device = grads[0].device
+     grouped_grads: Dict[
+         Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+     ] = _group_tensors_by_device_and_dtype(
+         [grads]
+     )  # type: ignore[assignment]
+ 
+     norms: List[Tensor] = []
+     for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+         if (foreach is None and _has_foreach_support(device_grads, device)) or (
+                 foreach and _device_has_foreach_support(device)
+         ):
+             norms.extend(torch._foreach_norm(device_grads, norm_type))
+         elif foreach:
+             raise RuntimeError(
+                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+             )
+         else:
+             norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
+ 
+     total_norm = torch.linalg.vector_norm(
+         torch.stack([norm.to(first_device) for norm in norms]), norm_type
+     )
+ 
+     if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+         raise RuntimeError(
+             f'The total norm of order {norm_type} for gradients from `parameters` '
+             'is non-finite, so it cannot be clipped.')
+ 
+     # apply the in-place clipping promised by the docstring (as in torch.nn.utils.clip_grad_norm_);
+     # multiplying by the coefficient clamped to 1 is a no-op when the norm is already small
+     clip_coef = max_norm / (total_norm + 1e-6)
+     clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+     for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+         if (foreach is None and _has_foreach_support(device_grads, device)) or (
+                 foreach and _device_has_foreach_support(device)
+         ):
+             torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
+         else:
+             clip_coef_clamped_device = clip_coef_clamped.to(device)
+             for g in device_grads:
+                 g.mul_(clip_coef_clamped_device)
+ 
+     return total_norm
+ 
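+ # Typical use inside a training step (illustrative):
+ #
+ #   loss.backward()
+ #   total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
+ #   optimizer.step(); optimizer.zero_grad()
+ 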
+ 
+ class NativeScalerWithGradNormCount:
+     state_dict_key = "amp_scaler"
+ 
+     def __init__(self, enabled=True):
+         # NOTE: the scaler is deliberately constructed disabled; backward/step
+         # still run through it so the interface stays the same for fp32/bf16 training.
+         self._scaler = torch.cuda.amp.GradScaler(enabled=False)
+ 
+     def __call__(self, loss, optimizer, clip_grad=10, parameters=None, create_graph=False, update_grad=True):
+         self._scaler.scale(loss).backward(create_graph=create_graph)
+         if update_grad:
+             self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+             with torch.no_grad():
+                 if isinstance(parameters, torch.Tensor):
+                     parameters = [parameters]
+                 parameters = [p for p in parameters if p.grad is not None]
+                 for p in parameters:
+                     # replace non-finite gradient entries before clipping
+                     p.grad.nan_to_num_(nan=0., posinf=1e-3, neginf=-1e-3)
+                 norm = torch.nn.utils.clip_grad_norm_(parameters, 10.)
+             self._scaler.step(optimizer)
+             self._scaler.update()
+         else:
+             norm = None
+         return norm
+ 
+     def state_dict(self):
+         return self._scaler.state_dict()
+ 
+     def load_state_dict(self, state_dict):
+         self._scaler.load_state_dict(state_dict)
+ 
+ 
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     parameters = [p for p in parameters if p.grad is not None]
+     norm_type = float(norm_type)
+     if len(parameters) == 0:
+         return torch.tensor(0.)
+     device = parameters[0].grad.device
+     if norm_type == inf:
+         total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+     else:
+         total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+ 
+     return total_norm
+ 
+ 
+ def save_model(args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None):
+     output_dir = Path(args.output_dir)
+     if fname is None: fname = str(epoch)
+     checkpoint_path = output_dir / ('checkpoint-%s.pth' % fname)
+     to_save = {
+         'model': model_without_ddp.state_dict(),
+         'optimizer': optimizer.state_dict(),
+         'scaler': loss_scaler.state_dict(),
+         'args': args,
+         'epoch': epoch,
+     }
+     if best_so_far is not None: to_save['best_so_far'] = best_so_far
+     print(f'>> Saving model to {checkpoint_path} ...')
+     save_on_master(to_save, checkpoint_path)
+     if is_main_process():
+         os.system('ossutil64 cp -f %s oss://antsys-vilab/zsz/checkpoints/%s -j 200' % (checkpoint_path, checkpoint_path))
+ 
+ 
+ def load_model(args, model_without_ddp, optimizer, loss_scaler):
+     args.start_epoch = 0
+     best_so_far = None
+     if args.resume is not None:
+         if args.resume.startswith('https'):
+             checkpoint = torch.hub.load_state_dict_from_url(
+                 args.resume, map_location='cpu', check_hash=True)
+         else:
+             checkpoint = torch.load(args.resume, map_location='cpu')
+         print("Resume checkpoint %s" % args.resume)
+         model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
+         args.start_epoch = checkpoint['epoch'] + 1
+         optimizer.load_state_dict(checkpoint['optimizer'])
+         if 'scaler' in checkpoint:
+             loss_scaler.load_state_dict(checkpoint['scaler'])
+         print("With optim & sched! start_epoch={:d}".format(args.start_epoch), end='')
+         if 'best_so_far' in checkpoint:
+             best_so_far = checkpoint['best_so_far']
+             print(" & best_so_far={:g}".format(best_so_far))
+         else:
+             print("")
+     return best_so_far
+ 
+ def all_reduce_mean(x):
+     world_size = get_world_size()
+     if world_size > 1:
+         x_reduce = torch.tensor(x).cuda()
+         dist.all_reduce(x_reduce)
+         x_reduce /= world_size
+         return x_reduce.item()
+     else:
+         return x
+ 
+ 
+ def _replace(text, src, tgt, rm=''):
+     """ Advanced string replacement.
+     Given a text:
+     - replace each character in src by the corresponding character in tgt
+     - remove all characters in rm
+     """
+     if len(tgt) == 1:
+         tgt = tgt * len(src)
+     assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len"
+     for s, t in zip(src, tgt):
+         text = text.replace(s, t)
+     for c in rm:
+         text = text.replace(c, '')
+     return text
+ 
+ 
+ def filename(obj):
+     r""" transform a python obj or cmd into a proper filename.
+     - \1 gets replaced by slash '/'
+     - \2 gets replaced by comma ','
+     """
+     if not isinstance(obj, str):
+         obj = repr(obj)
+     obj = str(obj).replace('()', '')
+     obj = _replace(obj, '_,(*/\1\2', '-__x%/,', rm=' )\'"')
+     assert all(len(s) < 256 for s in obj.split(os.sep)), 'filename too long (>256 characters):\n' + obj
+     return obj
+ 
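+ # Example of the mapping above: underscores become dashes, commas and opening
+ # parens become underscores, '*' becomes 'x', '/' becomes '%', and spaces,
+ # quotes and closing parens are dropped:
+ #
+ #   filename("f(a, b=1)")   # -> 'f_a_b=1'
+ 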
+ def _get_num_layer_for_vit(var_name, enc_depth, dec_depth):
+     if var_name in ("cls_token", "mask_token", "pos_embed", "global_tokens"):
+         return 0
+     elif var_name.startswith("patch_embed"):
+         return 0
+     elif var_name.startswith("enc_blocks"):
+         layer_id = int(var_name.split('.')[1])
+         return layer_id + 1
+     elif var_name.startswith('decoder_embed') or var_name.startswith('enc_norm'):  # part of the last block
+         return enc_depth
+     elif var_name.startswith('dec_blocks'):
+         layer_id = int(var_name.split('.')[1])
+         return enc_depth + layer_id + 1
+     elif var_name.startswith('dec_norm'):  # part of the last block
+         return enc_depth + dec_depth
+     elif any(var_name.startswith(k) for k in ['head', 'prediction_head']):
+         return enc_depth + dec_depth + 1
+     else:
+         raise NotImplementedError(var_name)
+ 
+ def get_parameter_groups(model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[]):
+     parameter_group_names = {}
+     parameter_group_vars = {}
+     enc_depth, dec_depth = None, None
+     # prepare layer decay values
+     assert layer_decay == 1.0 or 0. < layer_decay < 1.
+     if layer_decay < 1.:
+         enc_depth = model.enc_depth
+         dec_depth = model.dec_depth if hasattr(model, 'dec_blocks') else 0
+         num_layers = enc_depth + dec_depth
+         layer_decay_values = list(layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))
+ 
+     for name, param in model.named_parameters():
+         if not param.requires_grad:
+             continue  # frozen weights
+ 
+         # Assign weight decay values
+         if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
+             group_name = "no_decay"
+             this_weight_decay = 0.
+         elif 'mask_token' in name or 'patch_embed' in name or 'enc_blocks' in name:
+             group_name = "encoder"
+             this_weight_decay = weight_decay
+         else:
+             group_name = "decay"
+             this_weight_decay = weight_decay
+ 
+         # Assign layer ID for LR scaling
+         if layer_decay < 1.:
+             skip_scale = False
+             layer_id = _get_num_layer_for_vit(name, enc_depth, dec_depth)
+             group_name = "layer_%d_%s" % (layer_id, group_name)
+             if name in no_lr_scale_list:
+                 skip_scale = True
+                 group_name = f'{group_name}_no_lr_scale'
+         else:
+             layer_id = 0
+             skip_scale = True
+ 
+         if group_name not in parameter_group_names:
+             if not skip_scale:
+                 scale = layer_decay_values[layer_id]
+             elif group_name == "encoder":
+                 scale = 0.5
+             else:
+                 scale = 1.
+ 
+             parameter_group_names[group_name] = {
+                 "weight_decay": this_weight_decay,
+                 "params": [],
+                 "lr_scale": scale
+             }
+             parameter_group_vars[group_name] = {
+                 "weight_decay": this_weight_decay,
+                 "params": [],
+                 "lr_scale": scale
+             }
+         parameter_group_vars[group_name]["params"].append(param)
+         parameter_group_names[group_name]["params"].append(name)
+     print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+     return list(parameter_group_vars.values())
+ 
+ 
+ def adjust_learning_rate(optimizer, epoch, args):
+     """Adjust the learning rate with warm restarts every args.cycle_epoch epochs:
+     the peak of each cycle follows a slow half-cycle cosine over the full run,
+     and within a cycle the lr warms up linearly, then decays with a cosine."""
+     lr_peak = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
+         (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
+     warmup_iters = 2
+     T_0 = args.cycle_epoch
+     # position inside the current cycle (keeps the fractional part of epoch)
+     epoch_int = int(epoch) % T_0
+     decimal_part = epoch - math.floor(epoch)
+     T_cur = decimal_part + epoch_int
+ 
+     if T_cur < warmup_iters:
+         warmup_ratio = T_cur / warmup_iters
+         lr = args.min_lr + (lr_peak - args.min_lr) * warmup_ratio
+     else:
+         T_cur_adjusted = T_cur - warmup_iters
+         T_i = T_0 - warmup_iters
+         lr = args.min_lr + (lr_peak - args.min_lr) * (1 + math.cos(math.pi * T_cur_adjusted / T_i)) / 2
+     for param_group in optimizer.param_groups:
+         if "lr_scale" in param_group:
+             param_group["lr"] = lr * param_group["lr_scale"]
+         else:
+             param_group["lr"] = lr
+     return lr
+ 
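+ # Worked example of the cosine phase above: with min_lr=1e-5, cycle peak 1e-4,
+ # warmup_iters=2 and cycle length T_0=100, at epoch 10 the phase is (10 - 2) / (100 - 2):
+ #
+ #   lr = 1e-5 + (1e-4 - 1e-5) * (1 + math.cos(math.pi * 8 / 98)) / 2   # ~= 9.85e-5
+ 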
dust3r/dust3r/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/dust3r/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (150 Bytes)
 
dust3r/dust3r/__pycache__/model.cpython-312.pyc ADDED
Binary file (12 kB)
 
dust3r/dust3r/__pycache__/patch_embed.cpython-312.pyc ADDED
Binary file (4.86 kB)
 
dust3r/dust3r/__pycache__/viz.cpython-312.pyc ADDED
Binary file (22.3 kB)
 
dust3r/dust3r/datasets/CustomDataset.py ADDED
@@ -0,0 +1,145 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # Dataloader for custom image folders, adapted from the preprocessed
+ # ARKitScenes dataloader (dataset at https://github.com/apple/ARKitScenes -
+ # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
+ # Public License, https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license)
+ # See datasets_preprocess/preprocess_arkitscenes.py
+ # --------------------------------------------------------
+ import os
+ import os.path as osp
+ import glob
+ import random
+ import cv2
+ import numpy as np
+ import mast3r.utils.path_to_dust3r  # noqa
+ from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset_test
+ 
+ 
+ class CustomDataset(BaseStereoViewDataset_test):
+     def __init__(self, *args, split, ROOT, wpose=False, sequential_input=False, index_list=None, **kwargs):
+         self.ROOT = ROOT
+         self.wpose = wpose
+         self.sequential_input = sequential_input
+         super().__init__(*args, **kwargs)
+ 
+     def __len__(self):
+         return 684000
+ 
+     @staticmethod
+     def image_read(image_file):
+         img = cv2.imread(image_file)
+         return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ 
+     def read_cam_file(self, filename):
+         with open(filename) as f:
+             lines = [line.rstrip() for line in f.readlines()]
+         # extrinsics: lines [1, 5), 4x4 matrix
+         extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
+         extrinsics = extrinsics.reshape((4, 4))
+         # intrinsics: lines [7, 10), 3x3 matrix
+         intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
+         intrinsics = intrinsics.reshape((3, 3))
+         # depth_min & depth_interval: line 11
+         depth_min = float(lines[11].split()[0])
+         depth_interval = float(lines[11].split()[1])
+         return intrinsics, extrinsics, depth_min, depth_interval
+ 
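+     # Expected layout of the per-image camera file, inferred from the parsing
+     # above (MVSNet-style; line numbers are 0-based):
+     #
+     #   extrinsic
+     #   <4x4 world-to-camera matrix on lines 1-4>
+     #   <blank>
+     #   intrinsic
+     #   <3x3 intrinsic matrix on lines 7-9>
+     #   <blank>
+     #   <depth_min depth_interval>          (line 11)
+ 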
+ 
+     def _get_views(self, idx, resolution, rng):
+         images_list = glob.glob(osp.join(self.ROOT, '*.png')) + glob.glob(osp.join(self.ROOT, '*.jpg')) + glob.glob(osp.join(self.ROOT, '*.JPG'))
+         images_list = sorted(images_list)
+         if self.num_image != len(images_list):
+             images_list = random.sample(images_list, self.num_image)
+         self.gt_num_image = 0
+         views = []
+         for image in images_list:
+             rgb_image = self.image_read(image)
+             H, W = rgb_image.shape[:2]
+             if not self.wpose:
+                 intrinsics = np.array([[W, 0, W / 2], [0, H, H / 2], [0, 0, 1]])
+                 camera_pose = np.eye(4)
+             else:
+                 image_index = image.split('/')[-1].split('.')[0]
+                 proj_mat_filename = os.path.join(self.ROOT, image_index + '.txt')
+                 intrinsics, camera_pose, depth_min, depth_interval = self.read_cam_file(proj_mat_filename)
+                 camera_pose = np.linalg.inv(camera_pose)
+ 
+             depthmap = np.zeros((H, W))
+             rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
+                 rgb_image, depthmap, intrinsics, resolution, rng=rng, info=None)
+             rgb_image_orig = rgb_image.copy()
+             H, W = depthmap.shape[:2]
+             fxfycxcy = np.array([intrinsics[0, 0] / W, intrinsics[1, 1] / H, intrinsics[0, 2] / W, intrinsics[1, 2] / H]).astype(np.float32)
+             views.append(dict(
+                 img_org=rgb_image_orig,
+                 img=rgb_image,
+                 depthmap=depthmap.astype(np.float32),
+                 camera_pose=camera_pose.astype(np.float32),
+                 camera_intrinsics=intrinsics.astype(np.float32),
+                 fxfycxcy=fxfycxcy,
+                 dataset='custom',
+                 label=image,
+                 instance=image,
+             ))
+         return views
+ 
+ 
+ if __name__ == "__main__":
+     from dust3r.datasets.base.base_stereo_view_dataset import view_name
+     from dust3r.viz import SceneViz, auto_cam_size
+     from dust3r.utils.image import rgb
+     import nerfvis.scene as scene_vis
+ 
+     # NOTE: the original snippet did not construct `dataset`; this is an
+     # illustrative instantiation (adjust ROOT/resolution to your data).
+     dataset = CustomDataset(split='train', ROOT='data/custom_images', resolution=224)
+ 
+     for idx in np.random.permutation(len(dataset)):
+         views = dataset[idx]
+         view_idxs = list(range(len(views)))
+         poses = [views[view_idx]['camera_pose'] for view_idx in view_idxs]
+         cam_size = max(auto_cam_size(poses), 0.001)
+         pts3ds = []
+         colors = []
+         valid_masks = []
+         c2ws = []
+         for view_idx in view_idxs:
+             pts3d = views[view_idx]['pts3d']
+             pts3ds.append(pts3d)
+             valid_mask = views[view_idx]['valid_mask']
+             valid_masks.append(valid_mask)
+             color = rgb(views[view_idx]['img'])
+             colors.append(color)
+             c2ws.append(views[view_idx]['camera_pose'])
+ 
+         pts3ds = np.stack(pts3ds, axis=0)
+         colors = np.stack(colors, axis=0)
+         valid_masks = np.stack(valid_masks, axis=0)
+         c2ws = np.stack(c2ws)
+         scene_vis.set_title("My Scene")
+         scene_vis.set_opencv()
+         f = 1111.0 / 2.5
+         z = 10.
+         scene_vis.add_camera_frustum("cameras", r=c2ws[:, :3, :3], t=c2ws[:, :3, 3], focal_length=f,
+                                      image_width=colors.shape[2], image_height=colors.shape[1],
+                                      z=z, connect=False, color=[1.0, 0.0, 0.0])
+         for i in range(len(c2ws)):
+             scene_vis.add_image(
+                 f"images/{i}",
+                 colors[i],  # Can be a list of paths too (requires joblib for that)
+                 r=c2ws[i, :3, :3],
+                 t=c2ws[i, :3, 3],
+                 # Alternatively: from nerfvis.utils import split_mat4; **split_mat4(c2ws)
+                 focal_length=f,
+                 z=z,
+             )
+         scene_vis.display(port=8081)
+ 
dust3r/dust3r/datasets/__init__.py ADDED
@@ -0,0 +1,39 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ from .utils.transforms import *
+ from .base.batched_sampler import BatchedRandomSampler  # noqa
+ from .CustomDataset import CustomDataset  # noqa
+ 
+ 
+ def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
+     import torch
+     from croco.utils.misc import get_world_size, get_rank
+ 
+     # the dataset can also be given as a string describing its constructor call
+     if isinstance(dataset, str):
+         dataset = eval(dataset)
+ 
+     world_size = get_world_size()
+     rank = get_rank()
+     try:
+         sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
+                                        rank=rank, drop_last=drop_last)
+     except (AttributeError, NotImplementedError):
+         # not available for this dataset
+         if torch.distributed.is_initialized():
+             sampler = torch.utils.data.DistributedSampler(
+                 dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
+             )
+         elif shuffle:
+             sampler = torch.utils.data.RandomSampler(dataset)
+         else:
+             sampler = torch.utils.data.SequentialSampler(dataset)
+ 
+     data_loader = torch.utils.data.DataLoader(
+         dataset,
+         sampler=sampler,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         pin_memory=pin_mem,
+         drop_last=drop_last,
+     )
+ 
+     return data_loader
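+ 
+ # Typical use (illustrative; the string form is eval'd into a constructor call):
+ #
+ #   loader = get_data_loader("CustomDataset(split='train', ROOT='data/imgs', resolution=224)",
+ #                            batch_size=2, num_workers=4)
+ #   for batch in loader:
+ #       ...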
dust3r/dust3r/datasets/__pycache__/CustomDataset.cpython-312.pyc ADDED
Binary file (8.14 kB)
 
dust3r/dust3r/datasets/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.76 kB)
 
dust3r/dust3r/datasets/base/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/dust3r/datasets/base/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (164 Bytes).
 
dust3r/dust3r/datasets/base/__pycache__/base_stereo_view_dataset.cpython-312.pyc ADDED
Binary file (32.6 kB).
 
dust3r/dust3r/datasets/base/__pycache__/batched_sampler.cpython-312.pyc ADDED
Binary file (4.09 kB).
 
dust3r/dust3r/datasets/base/__pycache__/easy_dataset.cpython-312.pyc ADDED
Binary file (8.87 kB).
 
dust3r/dust3r/datasets/base/base_stereo_view_dataset.py ADDED
@@ -0,0 +1,774 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # base class for implementing datasets
+ # --------------------------------------------------------
+ import copy
+ import random
+
+ import cv2
+ import numpy as np
+ import PIL
+ import torch
+ import torchvision.transforms as transforms
+ from scipy.spatial.transform import Rotation
+
+ from dust3r.datasets.base.easy_dataset import EasyDataset
+ from dust3r.datasets.utils.transforms import ImgNorm
+ from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, geotrf, inv
+ import dust3r.datasets.utils.cropping as cropping
+
+
+ class BaseStereoViewDataset_test(EasyDataset):
+     """ Define all basic options.
+
+     Usage:
+         class MyDataset(BaseStereoViewDataset):
+             def _get_views(self, idx, rng):
+                 # overload here
+                 views = []
+                 views.append(dict(img=, ...))
+                 return views
+     """
+
+     def __init__(self, *,  # only keyword arguments
+                  split=None,
+                  resolution=None,  # square_size or (width, height) or list of [(width, height), ...]
+                  transform=ImgNorm,
+                  aug_crop=False,
+                  seed=None,
+                  num_views=8,
+                  gt_num_image=0,
+                  aug_monocular=False,
+                  aug_portrait_or_landscape=False,
+                  aug_rot90=False,
+                  aug_swap=False,
+                  only_pose=False,
+                  sequential_input=False,
+                  overfit=False,
+                  caculate_mask=False):
+         self.sequential_input = sequential_input
+         self.split = split
+         self.num_image = num_views
+         self._set_resolutions(resolution)
+         self.gt_num_image = gt_num_image
+         self.aug_monocular = aug_monocular
+         self.aug_portrait_or_landscape = aug_portrait_or_landscape
+         if isinstance(transform, str):  # resolve string specs before deriving transform_org
+             transform = eval(transform)
+         self.transform = transform
+         self.transform_org = transforms.Compose([t for t in transform.transforms if type(t).__name__ != 'ColorJitter'])
+         self.aug_rot90 = aug_rot90
+         self.aug_swap = aug_swap
+         self.only_pose = only_pose
+         self.overfit = overfit
+         self.rendering = False
+         self.caculate_mask = caculate_mask
+         self.aug_crop = aug_crop
+         self.seed = seed
+         self.kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (9, 9))
+
+     def __len__(self):
+         return len(self.scenes)
+
+     def sequential_sample(self, im_start, last, interal):
+         # sample num_image frames spaced by `interal`, each jittered by up to interal//2,
+         # then add gt_num_image jittered copies of already-chosen frames
+         im_list = [
+             im_start + i * interal + random.choice(list(range(-interal // 2, interal // 2)))
+             for i in range(self.num_image)
+         ]
+         im_list += [
+             random.choice(im_list) + random.choice(list(range(-interal // 2, interal // 2)))
+             for _ in range(self.gt_num_image)
+         ]
+         return im_list
+
+     def get_stats(self):
+         return f"{len(self)} pairs"
+
+     def __repr__(self):
+         resolutions_str = '[' + ';'.join(f'{w}x{h}' for w, h in self._resolutions) + ']'
+         return f"""{type(self).__name__}({self.get_stats()},
+             {self.split=},
+             {self.seed=},
+             resolutions={resolutions_str},
+             {self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '')
+
+     def _get_views(self, idx, resolution, rng):
+         raise NotImplementedError()
+
+     def _swap_view_aug(self, views):
+         # shuffle the views in place (random.shuffle returns None)
+         random.shuffle(views)
+
+     def __getitem__(self, idx):
+         if isinstance(idx, tuple):
+             # the idx is specifying the aspect-ratio
+             idx, ar_idx = idx
+         else:
+             assert len(self._resolutions) == 1
+             ar_idx = 0
+
+         # set-up the rng
+         if self.seed:  # reseed for each __getitem__
+             self._rng = np.random.default_rng(seed=self.seed + idx)
+         elif not hasattr(self, '_rng'):
+             seed = torch.initial_seed()  # this is different for each dataloader process
+             self._rng = np.random.default_rng(seed=seed)
+
+         # over-loaded code
+         resolution = self._resolutions[ar_idx]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
+         views = self._get_views(idx, resolution, self._rng)
+
+         if self.only_pose:
+             for view in views:
+                 # this allows to check whether the RNG is in the same state each time
+                 view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
+             return views
+         else:
+             for v, view in enumerate(views):
+                 assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
+                 view['idx'] = (idx, ar_idx, v)
+                 # encode the image
+                 width, height = view['img'].size
+                 view['true_shape'] = np.int32((height, width))
+                 view['img'] = self.transform_org(view['img'])
+                 view['img_org'] = self.transform_org(view['img_org'])
+                 if 'depth_anything' not in view:
+                     view['depth_anything'] = np.zeros_like(view['depthmap'])
+                 assert 'camera_intrinsics' in view
+                 if 'camera_pose' not in view:
+                     view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
+                 else:
+                     assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
+                 assert 'pts3d' not in view
+                 assert 'valid_mask' not in view
+                 assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
+                 pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
+
+                 view['pts3d'] = pts3d
+                 view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
+
+                 # check all datatypes
+                 for key, val in view.items():
+                     res, err_msg = is_good_type(key, val)
+                     assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
+                 K = view['camera_intrinsics']
+
+             # de-normalize the intrinsics to pixel units
+             for view in views:
+                 fxfycxcy = view['fxfycxcy'].copy()
+                 H, W = view['img'].shape[1:]
+                 fxfycxcy[0] = fxfycxcy[0] * W
+                 fxfycxcy[1] = fxfycxcy[1] * H
+                 fxfycxcy[2] = fxfycxcy[2] * W
+                 fxfycxcy[3] = fxfycxcy[3] * H
+                 view['fxfycxcy_unorm'] = fxfycxcy
+
+             # last thing done!
+             for view in views:
+                 view['render_mask'] = np.ones((view['img'].shape[1], view['img'].shape[2]), dtype=np.uint8) > 0.1
+
+             for view in views:
+                 # transpose to make sure all views are the same size
+                 transpose_to_landscape(view)
+                 if 'sky_mask' in view:
+                     view.pop('sky_mask')
+                 # this allows to check whether the RNG is in the same state each time
+                 view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
+             return views
+
+     def _set_resolutions(self, resolutions):
+         assert resolutions is not None, 'undefined resolution'
+
+         if not isinstance(resolutions, list):
+             resolutions = [resolutions]
+
+         self._resolutions = []
+         for resolution in resolutions:
+             if isinstance(resolution, int):
+                 width = height = resolution
+             else:
+                 width, height = resolution
+             assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'
+             assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'
+             assert width >= height
+             self._resolutions.append((width, height))
+
+     def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None, depth_anything=None):
+         """ This function:
+         - first downsizes the image with LANCZOS interpolation,
+           which is better than bilinear interpolation for downscaling,
+         - then crops (if necessary) with bilinear interpolation.
+         """
+         if not isinstance(image, PIL.Image.Image):
+             image = PIL.Image.fromarray(image)
+
+         # transpose the resolution if necessary
+         W, H = image.size  # new size
+         assert resolution[0] >= resolution[1]
+         if H > 1.1 * W:
+             # image is portrait mode
+             resolution = resolution[::-1]
+         elif 0.7 < H / W < 1.3 and resolution[0] != resolution[1] and self.aug_portrait_or_landscape:
+             # image is roughly square, so we choose (portrait, landscape) randomly
+             if rng.integers(2):
+                 resolution = resolution[::-1]
+
+         # high-quality Lanczos down-scaling
+         target_resolution = np.array(resolution)
+         if depth_anything is not None:
+             image, depthmap, intrinsics, depth_anything = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution, depth_anything=depth_anything)
+         else:
+             image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
+
+         # actual cropping (if necessary) with bilinear interpolation
+         offset_factor = 0.5
+         intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
+         crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
+         if depth_anything is not None:
+             image, depthmap, intrinsics2, depth_anything = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox, depth_anything=depth_anything)
+             return image, depthmap, intrinsics2, depth_anything
+         else:
+             image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
+             return image, depthmap, intrinsics2
+
+     def _crop_resize_if_necessary_test(self, image, depthmap, intrinsics, resolution, rng=None, info=None, depth_anything=None):
+         """ Same as _crop_resize_if_necessary, but without the random
+         portrait/landscape augmentation (deterministic for testing).
+         """
+         if not isinstance(image, PIL.Image.Image):
+             image = PIL.Image.fromarray(image)
+
+         # transpose the resolution if necessary
+         W, H = image.size  # new size
+         assert resolution[0] >= resolution[1]
+         if H > 1.1 * W:
+             # image is portrait mode
+             resolution = resolution[::-1]
+
+         # high-quality Lanczos down-scaling
+         target_resolution = np.array(resolution)
+         if depth_anything is not None:
+             image, depthmap, intrinsics, depth_anything = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution, depth_anything=depth_anything)
+         else:
+             image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
+
+         # actual cropping (if necessary) with bilinear interpolation
+         offset_factor = 0.5
+         intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
+         crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
+         if depth_anything is not None:
+             image, depthmap, intrinsics2, depth_anything = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox, depth_anything=depth_anything)
+         else:
+             image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
+
+         return image, depthmap, intrinsics2
+
+
+ def rotate_90(views, k=1):
+     RT = np.eye(4, dtype=np.float32)
+     RT[:3, :3] = Rotation.from_euler('z', 90 * k, degrees=True).as_matrix()
+
+     for view in views:
+         view['img'] = torch.rot90(view['img'], k=k, dims=(-2, -1))  # WARNING!! dims=(-1, -2) != dims=(-2, -1)
+         view['depthmap'] = np.rot90(view['depthmap'], k=k).copy()
+         view['camera_pose'] = view['camera_pose'] @ RT
+
+         RT2 = np.eye(3, dtype=np.float32)
+         RT2[:2, :2] = RT[:2, :2] * ((1, -1), (-1, 1))
+         H, W = view['depthmap'].shape
+         if k % 4 == 0:
+             pass
+         elif k % 4 == 1:
+             # top-left (0,0) pixel becomes (0,H-1)
+             RT2[:2, 2] = (0, H - 1)
+         elif k % 4 == 2:
+             # top-left (0,0) pixel becomes (W-1,H-1)
+             RT2[:2, 2] = (W - 1, H - 1)
+         elif k % 4 == 3:
+             # top-left (0,0) pixel becomes (W-1,0)
+             RT2[:2, 2] = (W - 1, 0)
+         else:
+             raise ValueError(f'Bad value for {k=}')
+
+         view['camera_intrinsics'][:2, 2] = geotrf(RT2, view['camera_intrinsics'][:2, 2])
+         if k % 2 == 1:
+             K = view['camera_intrinsics']
+             np.fill_diagonal(K, K.diagonal()[[1, 0, 2]])
+
+         pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
+         view['pts3d'] = pts3d
+         view['valid_mask'] = np.rot90(view['valid_mask'], k=k).copy()
+         view['true_shape'] = np.int32((H, W))
+         intrinsics = view['camera_intrinsics']
+         fxfycxcy = np.array([intrinsics[0, 0] / W, intrinsics[1, 1] / H, intrinsics[0, 2] / W, intrinsics[1, 2] / H]).astype(np.float32)
+         view['fxfycxcy'] = fxfycxcy
+
+
+ def reciprocal_1d(corres_1_to_2, corres_2_to_1, shape1, shape2, ret_recip=False):
+     is_reciprocal1 = np.abs(unravel_xy(corres_2_to_1[corres_1_to_2], shape1) - unravel_xy(np.arange(len(corres_1_to_2)), shape1)).sum(-1) < 4
+     pos1 = is_reciprocal1.nonzero()[0]
+     pos2 = corres_1_to_2[pos1]
+     if ret_recip:
+         return is_reciprocal1, pos1, pos2
+     return pos1, pos2
+
+
+ def reproject_view(pts3d, view2):
+     shape = view2['pts3d'].shape[:2]
+     return reproject(pts3d, view2['camera_intrinsics'], inv(view2['camera_pose']), shape)
+
+
+ def reproject(pts3d, K, world2cam, shape):
+     H, W, THREE = pts3d.shape
+     assert THREE == 3
+
+     # reproject in camera2 space
+     with np.errstate(divide='ignore', invalid='ignore'):
+         pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)
+
+     # quantize to pixel positions
+     return (H, W), ravel_xy(pos, shape)
+
+
+ def ravel_xy(pos, shape):
+     H, W = shape
+     with np.errstate(invalid='ignore'):
+         qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
+     quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
+     return quantized_pos
+
+
+ def unravel_xy(pos, shape):
+     # convert (x + W*y) back to 2d (x, y) coordinates
+     return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
+
+
+ class BaseStereoViewDataset(EasyDataset):
+     """ Define all basic options.
+
+     Usage:
+         class MyDataset(BaseStereoViewDataset):
+             def _get_views(self, idx, rng):
+                 # overload here
+                 views = []
+                 views.append(dict(img=, ...))
+                 return views
+     """
+
+     def __init__(self, *,  # only keyword arguments
+                  split=None,
+                  resolution=None,  # square_size or (width, height) or list of [(width, height), ...]
+                  transform=ImgNorm,
+                  aug_crop=False,
+                  seed=None,
+                  num_views=8,
+                  gt_num_image=0,
+                  aug_monocular=False,
+                  aug_portrait_or_landscape=True,
+                  aug_rot90=False,
+                  aug_swap=False,
+                  only_pose=False,
+                  sequential_input=False,
+                  overfit=False,
+                  caculate_mask=False):
+         self.sequential_input = sequential_input
+         self.split = split
+         self.num_image = num_views
+         self._set_resolutions(resolution)
+         self.gt_num_image = gt_num_image
+         self.aug_monocular = aug_monocular
+         self.aug_portrait_or_landscape = aug_portrait_or_landscape
+         if isinstance(transform, str):  # resolve string specs before deriving transform_org
+             transform = eval(transform)
+         self.transform = transform
+         self.transform_org = transforms.Compose([t for t in transform.transforms if type(t).__name__ != 'ColorJitter'])
+         self.aug_rot90 = aug_rot90
+         self.aug_swap = aug_swap
+         self.only_pose = only_pose
+         self.overfit = overfit
+         self.rendering = False
+         self.caculate_mask = caculate_mask
+         self.aug_crop = aug_crop
+         self.seed = seed
+         self.kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (9, 9))
+
+     def __len__(self):
+         return len(self.scenes)
+
+     def sequential_sample(self, im_start, last, interal):
+         # sample num_image frames spaced by `interal`, each jittered by up to interal//2,
+         # then add gt_num_image jittered copies of already-chosen frames
+         im_list = [
+             im_start + i * interal + random.choice(list(range(-interal // 2, interal // 2)))
+             for i in range(self.num_image)
+         ]
+         im_list += [
+             random.choice(im_list) + random.choice(list(range(-interal // 2, interal // 2)))
+             for _ in range(self.gt_num_image)
+         ]
+         return im_list
+
+     def get_stats(self):
+         return f"{len(self)} pairs"
+
+     def __repr__(self):
+         resolutions_str = '[' + ';'.join(f'{w}x{h}' for w, h in self._resolutions) + ']'
+         return f"""{type(self).__name__}({self.get_stats()},
+             {self.split=},
+             {self.seed=},
+             resolutions={resolutions_str},
+             {self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '')
+
+     def _get_views(self, idx, resolution, rng):
+         raise NotImplementedError()
+
+     def _swap_view_aug(self, views):
+         # shuffle the views in place (random.shuffle returns None)
+         random.shuffle(views)
+
+     def __getitem__(self, idx):
+         if isinstance(idx, tuple):
+             # the idx is specifying the aspect-ratio
+             idx, ar_idx = idx
+         else:
+             assert len(self._resolutions) == 1
+             ar_idx = 0
+
+         # set-up the rng
+         if self.seed:  # reseed for each __getitem__
+             self._rng = np.random.default_rng(seed=self.seed + idx)
+         elif not hasattr(self, '_rng'):
+             seed = torch.initial_seed()  # this is different for each dataloader process
+             self._rng = np.random.default_rng(seed=seed)
+
+         # over-loaded code
+         resolution = self._resolutions[ar_idx]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
+         # retry a few times: _get_views can fail on corrupted samples
+         flag = False
+         i = 0
+         while not flag and i < 1000:
+             try:
+                 views = self._get_views(idx, resolution, self._rng)
+                 flag = True
+             except Exception:
+                 flag = False
+                 i += 1
+
+         if self.only_pose:
+             for view in views:
+                 # this allows to check whether the RNG is in the same state each time
+                 view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
+             return views
+         else:
+             for v, view in enumerate(views):
+                 assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
+                 view['idx'] = (idx, ar_idx, v)
+                 # encode the image
+                 width, height = view['img'].size
+                 view['true_shape'] = np.int32((height, width))
+                 view['img'] = self.transform(view['img'])
+                 view['img_org'] = self.transform_org(view['img_org'])
+                 if 'depth_anything' not in view:
+                     view['depth_anything'] = np.zeros_like(view['depthmap'])
+                 assert 'camera_intrinsics' in view
+                 if 'camera_pose' not in view:
+                     view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
+                 else:
+                     assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
+                 assert 'pts3d' not in view
+                 assert 'valid_mask' not in view
+                 assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
+                 pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
+
+                 view['pts3d'] = pts3d
+                 view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
+
+                 # check all datatypes
+                 for key, val in view.items():
+                     res, err_msg = is_good_type(key, val)
+                     assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
+                 K = view['camera_intrinsics']
+
+             # if self.aug_swap:
+             #     self._swap_view_aug(views)
+
+             if self.aug_monocular:
+                 if self._rng.random() < self.aug_monocular:
+                     random_idxs = random.choices(list(range(len(views) - 1)), k=self.num_image + self.gt_num_image - 1)
+                     views_copy = [views[-1]] + [copy.deepcopy(views[random_idxs[i]]) for i in range(len(views) - 1)]
+                     views = views_copy
+
+             # if self.aug_rot90 is False:
+             #     pass
+             # elif self.aug_rot90 == 'same':
+             #     rotate_90(views, k=self._rng.choice(4))
+             # elif self.aug_rot90 == 'diff':
+             #     views_list = []
+             #     for view in views:
+             #         views_list += rotate_90([view], k=self._rng.choice(4))
+             #     views = views_list
+             # else:
+             #     raise ValueError(f'Bad value for {self.aug_rot90=}')
+             if not self.rendering:
+                 self._rng.shuffle(views)
+
+             if self.caculate_mask:
+                 for view1 in views[self.num_image:]:
+                     render_mask = []
+                     start = True
+                     height, width = view1['true_shape']
+                     for view2 in views[:self.num_image]:
+                         shape1, corres1_to_2 = reproject_view(view1['pts3d'], view2)
+                         shape2, corres2_to_1 = reproject_view(view2['pts3d'], view1)
+                         # compute reciprocal correspondences:
+                         # pos1 == valid pixels (correspondences) in image1
+                         is_reciprocal1, pos1, pos2 = reciprocal_1d(corres1_to_2, corres2_to_1, shape1, shape2, ret_recip=True)
+                         render_mask.append(is_reciprocal1.reshape(shape1))
+                         if start:
+                             view2['render_mask'] = np.ones((view2['img'].shape[1], view2['img'].shape[2]), dtype=np.uint8) > 0.1
+                             start = False
+                     render_mask = np.stack(render_mask, axis=0).sum(0)
+                     render_mask = cv2.dilate(render_mask / 16, self.kernel)
+                     view1['render_mask'] = render_mask > 1e-5
+                     assert view1['render_mask'].shape == (height, width)
+             else:
+                 for view in views:
+                     view['render_mask'] = np.ones((view['img'].shape[1], view['img'].shape[2]), dtype=np.uint8) > 0.1
+
+             # de-normalize the intrinsics to pixel units
+             for view in views:
+                 fxfycxcy = view['fxfycxcy'].copy()
+                 H, W = view['img'].shape[1:]
+                 fxfycxcy[0] = fxfycxcy[0] * W
+                 fxfycxcy[1] = fxfycxcy[1] * H
+                 fxfycxcy[2] = fxfycxcy[2] * W
+                 fxfycxcy[3] = fxfycxcy[3] * H
+                 view['fxfycxcy_unorm'] = fxfycxcy
+
+             # last thing done!
+             for view in views:
+                 # transpose to make sure all views are the same size
+                 transpose_to_landscape(view)
+                 if 'sky_mask' in view:
+                     view.pop('sky_mask')
+                 # this allows to check whether the RNG is in the same state each time
+                 view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
+             return views
+
+     def _set_resolutions(self, resolutions):
+         assert resolutions is not None, 'undefined resolution'
+
+         if not isinstance(resolutions, list):
+             resolutions = [resolutions]
+
+         self._resolutions = []
+         for resolution in resolutions:
+             if isinstance(resolution, int):
+                 width = height = resolution
+             else:
+                 width, height = resolution
+             assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'
+             assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'
+             assert width >= height
+             self._resolutions.append((width, height))
+
+     def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None, depth_anything=None):
+         """ This function:
+         - first downsizes the image with LANCZOS interpolation,
+           which is better than bilinear interpolation for downscaling,
+         - then crops (if necessary) with bilinear interpolation.
+         """
+         if not isinstance(image, PIL.Image.Image):
+             image = PIL.Image.fromarray(image)
+
+         # transpose the resolution if necessary
+         W, H = image.size  # new size
+         assert resolution[0] >= resolution[1]
+         if H > 1.1 * W:
+             # image is portrait mode
+             resolution = resolution[::-1]
+         elif 0.7 < H / W < 1.3 and resolution[0] != resolution[1] and self.aug_portrait_or_landscape:
+             # image is roughly square, so we choose (portrait, landscape) randomly
+             if rng.integers(2):
+                 resolution = resolution[::-1]
+
+         # high-quality Lanczos down-scaling
+         target_resolution = np.array(resolution)
+         if depth_anything is not None:
+             image, depthmap, intrinsics, depth_anything = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution, depth_anything=depth_anything)
+         else:
+             image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
+
+         # actual cropping (if necessary) with bilinear interpolation
+         offset_factor = 0.5
+         intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
+         crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
+         if depth_anything is not None:
+             image, depthmap, intrinsics2, depth_anything = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox, depth_anything=depth_anything)
+             return image, depthmap, intrinsics2, depth_anything
+         else:
+             image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
+             return image, depthmap, intrinsics2
+
+     def _crop_resize_if_necessary_test(self, image, depthmap, intrinsics, resolution, rng=None, info=None, depth_anything=None):
+         """ Same as _crop_resize_if_necessary, but without the random
+         portrait/landscape augmentation (deterministic for testing).
+         """
+         if not isinstance(image, PIL.Image.Image):
+             image = PIL.Image.fromarray(image)
+
+         # transpose the resolution if necessary
+         W, H = image.size  # new size
+         assert resolution[0] >= resolution[1]
+         if H > 1.1 * W:
+             # image is portrait mode
+             resolution = resolution[::-1]
+
+         # high-quality Lanczos down-scaling
+         target_resolution = np.array(resolution)
+         if depth_anything is not None:
+             image, depthmap, intrinsics, depth_anything = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution, depth_anything=depth_anything)
+         else:
+             image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
+
+         # actual cropping (if necessary) with bilinear interpolation
+         offset_factor = 0.5
+         intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
+         crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
+         if depth_anything is not None:
+             image, depthmap, intrinsics2, depth_anything = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox, depth_anything=depth_anything)
+         else:
+             image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
+
+         return image, depthmap, intrinsics2
+
+
+ def is_good_type(key, v):
+     """ returns (is_good, err_msg) """
+     if isinstance(v, (str, int, tuple)):
+         return True, None
+     if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
+         return False, f"bad {v.dtype=}"
+     return True, None
+
+
+ def view_name(view, batch_index=None):
+     def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x
+     db = sel(view['dataset'])
+     label = sel(view['label'])
+     instance = sel(view['instance'])
+     return f"{db}/{label}/{instance}"
+
+
+ def transpose_to_landscape(view):
+     height, width = view['true_shape']
+
+     if width < height:
+         # rectify portrait to landscape
+         assert view['img'].shape == (3, height, width)
+         view['img'] = view['img'].swapaxes(1, 2)
+
+         if 'render_mask' in view:
+             assert view['render_mask'].shape == (height, width)
+             view['render_mask'] = view['render_mask'].swapaxes(0, 1)
+
+         assert view['img_org'].shape == (3, height, width)
+         view['img_org'] = view['img_org'].swapaxes(1, 2)
+
+         assert view['valid_mask'].shape == (height, width)
+         view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)
+
+         assert view['depthmap'].shape == (height, width)
+         view['depthmap'] = view['depthmap'].swapaxes(0, 1)
+
+         assert view['pts3d'].shape == (height, width, 3)
+         view['pts3d'] = view['pts3d'].swapaxes(0, 1)
+
+         assert view['depth_anything'].shape == (height, width)
+         view['depth_anything'] = view['depth_anything'].swapaxes(0, 1)
+
+         # pixel x/y swap for the intrinsics is currently disabled
+         view['camera_intrinsics'] = view['camera_intrinsics']  # [[1, 0, 2]]
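
A standalone numpy sketch (illustrative, not part of the commit) of the pixel quantization implemented by `ravel_xy`/`unravel_xy` above: sub-pixel (x, y) positions are rounded, clipped to the image bounds, and flattened to x + W*y:

    import numpy as np

    H, W = 4, 6                                        # toy image shape
    pos = np.array([[1.2, 2.7], [5.9, 0.1]])           # sub-pixel (x, y) positions
    qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
    flat = qx.clip(0, W - 1) + W * qy.clip(0, H - 1)   # same formula as ravel_xy
    ys, xs = np.unravel_index(flat, (H, W))            # back to integer (y, x)
    print(flat, xs, ys)                                # [19  5] [1 5] [3 0]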
dust3r/dust3r/datasets/base/batched_sampler.py ADDED
@@ -0,0 +1,74 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # Random sampling under a constraint
+ # --------------------------------------------------------
+ import numpy as np
+ import torch
+
+
+ class BatchedRandomSampler:
+     """ Random sampling under a constraint: each sample in the batch has the same feature,
+     which is chosen randomly from a known pool of 'features' for each batch.
+
+     For instance, the 'feature' could be the image aspect-ratio.
+
+     The index returned is a tuple (sample_idx, feat_idx).
+     This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
+     """
+
+     def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True):
+         self.batch_size = batch_size
+         self.pool_size = pool_size
+
+         self.len_dataset = N = len(dataset)
+         self.total_size = round_by(N, batch_size * world_size) if drop_last else N
+         assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode'
+
+         # distributed sampler
+         self.world_size = world_size
+         self.rank = rank
+         self.epoch = None
+
+     def __len__(self):
+         return self.total_size // self.world_size
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+     def __iter__(self):
+         # prepare RNG
+         if self.epoch is None:
+             assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used'
+             seed = int(torch.empty((), dtype=torch.int64).random_().item())
+         else:
+             seed = self.epoch + 777
+         rng = np.random.default_rng(seed=seed)
+
+         # random indices (will restart from 0 if not drop_last)
+         sample_idxs = np.arange(self.total_size)
+         rng.shuffle(sample_idxs)
+
+         # random feat_idxs (same across each batch)
+         n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
+         feat_idxs = rng.integers(self.pool_size, size=n_batches)
+         feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
+         feat_idxs = feat_idxs.ravel()[:self.total_size]
+
+         # put them together
+         idxs = np.c_[sample_idxs, feat_idxs]  # shape = (total_size, 2)
+
+         # Distributed sampler: we select a subset of batches
+         # make sure the slice for each node is aligned with batch_size
+         size_per_proc = self.batch_size * ((self.total_size + self.world_size *
+                                             self.batch_size - 1) // (self.world_size * self.batch_size))
+         idxs = idxs[self.rank * size_per_proc: (self.rank + 1) * size_per_proc]
+
+         yield from (tuple(idx) for idx in idxs)
+
+
+ def round_by(total, multiple, up=False):
+     if up:
+         total = total + multiple - 1
+     return (total // multiple) * multiple
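
A small sketch (illustrative, not part of the commit) of the batch-constant feature index that `BatchedRandomSampler` yields; `range(10)` stands in for a dataset since only its length is used:

    sampler = BatchedRandomSampler(range(10), batch_size=4, pool_size=3, drop_last=False)
    for sample_idx, feat_idx in sampler:
        print(sample_idx, feat_idx)  # feat_idx stays constant within each run of 4 samples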
dust3r/dust3r/datasets/base/easy_dataset.py ADDED
@@ -0,0 +1,157 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # A dataset base class that you can easily resize and combine.
+ # --------------------------------------------------------
+ import numpy as np
+ from dust3r.datasets.base.batched_sampler import BatchedRandomSampler
+
+
+ class EasyDataset:
+     """ a dataset that you can easily resize and combine.
+     Examples:
+     ---------
+         2 * dataset ==> duplicate each element 2x
+
+         10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)
+
+         dataset1 + dataset2 ==> concatenate datasets
+     """
+
+     def __add__(self, other):
+         return CatDataset([self, other])
+
+     def __rmul__(self, factor):
+         return MulDataset(factor, self)
+
+     def __rmatmul__(self, factor):
+         return ResizedDataset(factor, self)
+
+     def set_epoch(self, epoch):
+         pass  # nothing to do by default
+
+     def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True):
+         if not shuffle:
+             raise NotImplementedError()  # cannot deal yet
+         num_of_aspect_ratios = len(self._resolutions)
+         return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last)
+
+
+ class MulDataset(EasyDataset):
+     """ Artificially augmenting the size of a dataset.
+     """
+     multiplicator: int
+
+     def __init__(self, multiplicator, dataset):
+         assert isinstance(multiplicator, int) and multiplicator > 0
+         self.multiplicator = multiplicator
+         self.dataset = dataset
+
+     def __len__(self):
+         return self.multiplicator * len(self.dataset)
+
+     def __repr__(self):
+         return f'{self.multiplicator}*{repr(self.dataset)}'
+
+     def __getitem__(self, idx):
+         if isinstance(idx, tuple):
+             idx, other = idx
+             return self.dataset[idx // self.multiplicator, other]
+         else:
+             return self.dataset[idx // self.multiplicator]
+
+     @property
+     def _resolutions(self):
+         return self.dataset._resolutions
+
+
+ class ResizedDataset(EasyDataset):
+     """ Artificially changing the size of a dataset.
+     """
+     new_size: int
+
+     def __init__(self, new_size, dataset):
+         assert isinstance(new_size, int) and new_size > 0
+         self.new_size = new_size
+         self.dataset = dataset
+
+     def __len__(self):
+         return self.new_size
+
+     def __repr__(self):
+         size_str = str(self.new_size)
+         for i in range((len(size_str) - 1) // 3):
+             sep = -4 * i - 3
+             size_str = size_str[:sep] + '_' + size_str[sep:]
+         return f'{size_str} @ {repr(self.dataset)}'
+
+     def set_epoch(self, epoch):
+         # this random shuffle only depends on the epoch
+         rng = np.random.default_rng(seed=epoch + 777)
+
+         # shuffle all indices
+         perm = rng.permutation(len(self.dataset))
+
+         # repeat the permutation until the target size is met
+         shuffled_idxs = np.concatenate([perm] * (1 + (len(self) - 1) // len(self.dataset)))
+         self._idxs_mapping = shuffled_idxs[:self.new_size]
+
+         assert len(self._idxs_mapping) == self.new_size
+
+     def __getitem__(self, idx):
+         assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()'
+         if isinstance(idx, tuple):
+             idx, other = idx
+             return self.dataset[self._idxs_mapping[idx], other]
+         else:
+             return self.dataset[self._idxs_mapping[idx]]
+
+     @property
+     def _resolutions(self):
+         return self.dataset._resolutions
+
+
+ class CatDataset(EasyDataset):
+     """ Concatenation of several datasets
+     """
+
+     def __init__(self, datasets):
+         for dataset in datasets:
+             assert isinstance(dataset, EasyDataset)
+         self.datasets = datasets
+         self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])
+
+     def __len__(self):
+         return self._cum_sizes[-1]
+
+     def __repr__(self):
+         # remove uselessly long transform
+         return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets)
+
+     def set_epoch(self, epoch):
+         for dataset in self.datasets:
+             dataset.set_epoch(epoch)
+
+     def __getitem__(self, idx):
+         other = None
+         if isinstance(idx, tuple):
+             idx, other = idx
+
+         if not (0 <= idx < len(self)):
+             raise IndexError()
+
+         db_idx = np.searchsorted(self._cum_sizes, idx, 'right')
+         dataset = self.datasets[db_idx]
+         new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)
+
+         if other is not None:
+             new_idx = (new_idx, other)
+         return dataset[new_idx]
+
+     @property
+     def _resolutions(self):
+         resolutions = self.datasets[0]._resolutions
+         for dataset in self.datasets[1:]:
+             assert tuple(dataset._resolutions) == tuple(resolutions)
+         return resolutions
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).