Clara211111 committed on
Commit
1efbda0
·
1 Parent(s): 5c8071a

upload demo

Browse files
Files changed (37) hide show
  1. app.py +725 -4
  2. flow3r/models/dinov2/__init__.py +6 -0
  3. flow3r/models/dinov2/hub/__init__.py +4 -0
  4. flow3r/models/dinov2/hub/backbones.py +156 -0
  5. flow3r/models/dinov2/hub/utils.py +39 -0
  6. flow3r/models/dinov2/layers/__init__.py +11 -0
  7. flow3r/models/dinov2/layers/attention.py +89 -0
  8. flow3r/models/dinov2/layers/block.py +259 -0
  9. flow3r/models/dinov2/layers/dino_head.py +58 -0
  10. flow3r/models/dinov2/layers/drop_path.py +34 -0
  11. flow3r/models/dinov2/layers/layer_scale.py +27 -0
  12. flow3r/models/dinov2/layers/mlp.py +40 -0
  13. flow3r/models/dinov2/layers/patch_embed.py +88 -0
  14. flow3r/models/dinov2/layers/swiglu_ffn.py +72 -0
  15. flow3r/models/dinov2/models/__init__.py +43 -0
  16. flow3r/models/dinov2/models/vision_transformer.py +404 -0
  17. flow3r/models/dinov2/utils/__init__.py +4 -0
  18. flow3r/models/dinov2/utils/cluster.py +95 -0
  19. flow3r/models/dinov2/utils/config.py +72 -0
  20. flow3r/models/dinov2/utils/dtype.py +37 -0
  21. flow3r/models/dinov2/utils/param_groups.py +103 -0
  22. flow3r/models/dinov2/utils/utils.py +95 -0
  23. flow3r/models/flow3r.py +233 -0
  24. flow3r/models/flow_head/dpt_head.py +498 -0
  25. flow3r/models/flow_head/utils.py +108 -0
  26. flow3r/models/layers/attention.py +403 -0
  27. flow3r/models/layers/block.py +406 -0
  28. flow3r/models/layers/camera_head.py +93 -0
  29. flow3r/models/layers/pos_embed.py +174 -0
  30. flow3r/models/layers/transformer_head.py +389 -0
  31. flow3r/utils/alignment.py +499 -0
  32. flow3r/utils/basic.py +223 -0
  33. flow3r/utils/cropping.py +197 -0
  34. flow3r/utils/debug.py +63 -0
  35. flow3r/utils/flow_utils.py +472 -0
  36. flow3r/utils/geometry.py +367 -0
  37. requirements.txt +15 -0
app.py CHANGED
@@ -1,8 +1,729 @@
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import cv2
9
+ import torch
10
+ import numpy as np
11
  import gradio as gr
12
+ import sys
13
+ import shutil
14
+ from datetime import datetime
15
+ import glob
16
+ import gc
17
+ import time
18
+ import trimesh
19
+ import matplotlib
20
+
21
+ from flow3r.models.flow3r import Flow3r
22
+ from flow3r.utils.basic import load_images_as_tensor
23
+ from flow3r.utils.geometry import depth_edge
24
+
25
+ from scipy.spatial.transform import Rotation
26
+ from huggingface_hub import hf_hub_download
27
+
28
# Select GPU when available; note run_model() below hard-requires CUDA anyway.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Initializing and loading Flow3r model...")

# Download the released checkpoint from the Hugging Face Hub and load it in full.
model = Flow3r()
ckpt_path = hf_hub_download(repo_id="Clara211111/flow3r", filename="flow3r.bin")
# NOTE(review): weights_only=False unpickles arbitrary objects — acceptable for
# this first-party checkpoint, but never point this at untrusted files.
checkpoint = torch.load(ckpt_path, weights_only=False, map_location='cpu')
model.load_state_dict(checkpoint, strict=True)

model.eval()
model = model.to(device)
39
+
40
+ # -------------------------------------------------------------------------
41
+ # Utils
42
+ # -------------------------------------------------------------------------
43
def predictions_to_glb(
    predictions,
    conf_thres=50.0,
    filter_by_frames="all",
    show_cam=True,
) -> trimesh.Scene:
    """
    Converts predictions to a 3D scene represented as a GLB file.

    Args:
        predictions (dict): Dictionary containing model predictions with keys:
            - points: 3D point coordinates (S, H, W, 3)
            - conf: Confidence scores (S, H, W); defaults to all-ones when absent
            - images: Input images (S, H, W, 3) or (S, 3, H, W)
            - camera_poses: Camera-to-world matrices, one per frame
        conf_thres (float): Percentage threshold for filtering low-confidence
            points (default: 50.0); interpreted as conf_thres/100 below.
        filter_by_frames (str): Frame filter specification, either "all"/"All"
            or an "index: filename" string (default: "all")
        show_cam (bool): Include camera visualization (default: True)

    Returns:
        trimesh.Scene: Processed 3D scene containing point cloud and cameras

    Raises:
        ValueError: If input predictions structure is invalid
    """
    if not isinstance(predictions, dict):
        raise ValueError("predictions must be a dictionary")

    if conf_thres is None:
        conf_thres = 10

    print("Building GLB scene")
    selected_frame_idx = None
    if filter_by_frames != "all" and filter_by_frames != "All":
        try:
            # Dropdown entries look like "12: name" — extract the index part before the colon
            selected_frame_idx = int(filter_by_frames.split(":")[0])
        except (ValueError, IndexError):
            pass

    pred_world_points = predictions["points"]
    # Missing confidence → treat every point as fully confident
    pred_world_points_conf = predictions.get("conf", np.ones_like(pred_world_points[..., 0]))

    images = predictions["images"]
    camera_poses = predictions["camera_poses"]

    # Restrict everything to a single frame when the user picked one
    if selected_frame_idx is not None:
        pred_world_points = pred_world_points[selected_frame_idx][None]
        pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
        images = images[selected_frame_idx][None]
        camera_poses = camera_poses[selected_frame_idx][None]

    vertices_3d = pred_world_points.reshape(-1, 3)
    # Handle different image formats - check if images need transposing
    if images.ndim == 4 and images.shape[1] == 3:  # NCHW format
        colors_rgb = np.transpose(images, (0, 2, 3, 1))
    else:  # Assume already in NHWC format
        colors_rgb = images
    # assumes pixel values are in [0, 1] — TODO confirm against load_images_as_tensor
    colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)

    conf = pred_world_points_conf.reshape(-1)
    # Convert percentage threshold to an absolute confidence value
    if conf_thres == 0.0:
        conf_threshold = 0.0
    else:
        # conf_threshold = np.percentile(conf, conf_thres)
        conf_threshold = conf_thres / 100

    # Also drop (near-)zero confidences, e.g. points zeroed on depth edges
    conf_mask = (conf >= conf_threshold) & (conf > 1e-5)

    vertices_3d = vertices_3d[conf_mask]
    colors_rgb = colors_rgb[conf_mask]

    if vertices_3d is None or np.asarray(vertices_3d).size == 0:
        # Degenerate case: keep the scene non-empty with a single white point
        vertices_3d = np.array([[1, 0, 0]])
        colors_rgb = np.array([[255, 255, 255]])
        scene_scale = 1
    else:
        # Calculate the 5th and 95th percentiles along each axis to get a
        # robust bounding box that ignores outlier points
        lower_percentile = np.percentile(vertices_3d, 5, axis=0)
        upper_percentile = np.percentile(vertices_3d, 95, axis=0)

        # Diagonal length of the percentile bounding box
        scene_scale = np.linalg.norm(upper_percentile - lower_percentile)

    colormap = matplotlib.colormaps.get_cmap("gist_rainbow")

    # Initialize a 3D scene
    scene_3d = trimesh.Scene()

    # Add point cloud data to the scene
    point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)

    scene_3d.add_geometry(point_cloud_data)

    num_cameras = len(camera_poses)

    if show_cam:
        # Add one camera frustum mesh per frame, rainbow-colored by frame index
        for i in range(num_cameras):
            camera_to_world = camera_poses[i]
            rgba_color = colormap(i / num_cameras)
            current_color = tuple(int(255 * x) for x in rgba_color[:3])

            # integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)
            integrate_camera_into_scene(scene_3d, camera_to_world, current_color, 1.)  # fixed camera size

    # Rotate scene for better visualization in the viewer
    align_rotation = np.eye(4)
    align_rotation[:3, :3] = Rotation.from_euler("y", 100, degrees=True).as_matrix()  # plane rotate
    align_rotation[:3, :3] = align_rotation[:3, :3] @ Rotation.from_euler("x", 155, degrees=True).as_matrix()  # roll
    scene_3d.apply_transform(align_rotation)

    print("GLB Scene built")
    return scene_3d
161
+
162
def get_opengl_conversion_matrix() -> np.ndarray:
    """
    Constructs and returns the OpenGL conversion matrix.

    OpenGL cameras look down -Z with +Y up, so the y and z axes are negated.

    Returns:
        numpy.ndarray: A 4x4 OpenGL conversion matrix.
    """
    return np.diag([1.0, -1.0, -1.0, 1.0])
177
+
178
def integrate_camera_into_scene(scene: trimesh.Scene, transform: np.ndarray, face_colors: tuple, scene_scale: float):
    """
    Integrates a fake camera mesh into the 3D scene.

    The camera is rendered as a 4-sided cone (a pyramid-shaped frustum) built
    from three slightly offset copies of the same cone, which gives the mesh
    visible thickness.

    Args:
        scene (trimesh.Scene): The 3D scene to add the camera model.
        transform (np.ndarray): 4x4 camera-to-world matrix for positioning.
        face_colors (tuple): RGB color of the camera faces.
        scene_scale (float): Scale of the scene; sizes the frustum.
    """

    cam_width = scene_scale * 0.05
    cam_height = scene_scale * 0.1

    # Rotate the cone 45° about z so the 4-sided base reads as a camera frame,
    # and shift it back so the apex sits at the camera origin
    rot_45_degree = np.eye(4)
    rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
    rot_45_degree[2, 3] = -cam_height

    opengl_transform = get_opengl_conversion_matrix()
    # Combine transformations: pose ∘ axis-convention flip ∘ cone alignment
    complete_transform = transform @ opengl_transform @ rot_45_degree
    camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)

    # A 2° twist used for the third vertex copy below
    slight_rotation = np.eye(4)
    slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()

    # Three vertex sets: original, 95%-scaled, and slightly rotated — the faces
    # from compute_camera_faces() stitch them into a thin shell
    vertices_combined = np.concatenate(
        [
            camera_cone_shape.vertices,
            0.95 * camera_cone_shape.vertices,
            transform_points(slight_rotation, camera_cone_shape.vertices),
        ]
    )
    vertices_transformed = transform_points(complete_transform, vertices_combined)

    mesh_faces = compute_camera_faces(camera_cone_shape)

    # Add the camera mesh to the scene
    camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
    camera_mesh.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(camera_mesh)
221
+
222
def transform_points(transformation: np.ndarray, points: np.ndarray, dim: int = None) -> np.ndarray:
    """
    Applies a 4x4 (or general homogeneous) transformation to a set of points.

    Args:
        transformation (np.ndarray): Transformation matrix.
        points (np.ndarray): Points to be transformed, shape (..., D).
        dim (int, optional): Output dimension; defaults to the input dimension.

    Returns:
        np.ndarray: Transformed points with the same leading shape.
    """
    pts = np.asarray(points)
    leading_shape = pts.shape[:-1]
    out_dim = dim or pts.shape[-1]

    # Transpose so we can right-multiply: rows of mat_t are columns of the matrix
    mat_t = transformation.swapaxes(-1, -2)
    linear, translation = mat_t[..., :-1, :], mat_t[..., -1:, :]
    transformed = pts @ linear + translation

    # Keep only the first out_dim coordinates and restore the leading shape
    return transformed[..., :out_dim].reshape(*leading_shape, out_dim)
245
+
246
def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
    """
    Computes the faces for the camera mesh.

    Stitches triangles across the three stacked vertex copies produced by
    integrate_camera_into_scene (offsets 0, n and 2n into the vertex array),
    skipping faces that touch the cone apex (vertex 0).

    Args:
        cone_shape (trimesh.Trimesh): The shape of the camera cone.

    Returns:
        np.ndarray: Array of faces for the camera mesh.
    """
    n = len(cone_shape.vertices)
    triangles = []

    for face in cone_shape.faces:
        # Apex-adjacent faces are left open
        if 0 in face:
            continue
        v1, v2, v3 = face
        o1, o2, o3 = face + n
        p1, p2, p3 = face + 2 * n

        triangles.extend(
            [
                (v1, v2, o2),
                (v1, o1, v3),
                (o3, v2, v3),
                (v1, v2, p2),
                (v1, p1, v3),
                (p3, v2, v3),
            ]
        )

    # Append reversed windings so the shell is visible from both sides
    triangles += [(c, b, a) for a, b, c in triangles]
    return np.array(triangles)
280
+
281
+ # -------------------------------------------------------------------------
282
+ # 1) Core model inference
283
+ # -------------------------------------------------------------------------
284
def run_model(target_dir, model) -> dict:
    """
    Run Flow3r inference on every image under ``target_dir/images``.

    Args:
        target_dir (str): Session directory containing an ``images`` subfolder.
        model: The loaded Flow3r model.

    Returns:
        dict: Model predictions with tensors converted to numpy arrays and the
        batch dimension removed. Contains at least ``images``, ``points``,
        ``conf`` and ``camera_poses``.

    Raises:
        ValueError: If CUDA is unavailable or no images are found.
    """
    print(f"Processing images from {target_dir}")

    # Inference is CUDA-only (the autocast/bf16 path below requires it), so the
    # old `device = "cuda" if ... else "cpu"` fallback was dead code — fail fast.
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")
    device = "cuda"

    model = model.to(device)
    model.eval()

    # Load and preprocess images
    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    print(f"Found {len(image_names)} images")
    if len(image_names) == 0:
        raise ValueError("No images found. Check your upload.")

    # interval = 10 if target_dir.endswith('.mp4') else 1
    interval = 1
    imgs = load_images_as_tensor(os.path.join(target_dir, "images"), interval=interval).to(device)  # (N, 3, H, W)

    # Run inference
    print("Running inference...")
    # bf16 needs compute capability >= 8 (Ampere or newer); otherwise use fp16
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

    with torch.no_grad():
        # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast
        with torch.amp.autocast("cuda", dtype=dtype):
            predictions = model(imgs[None])  # add batch dimension
            predictions['images'] = imgs[None].permute(0, 1, 3, 4, 2)  # NCHW -> NHWC for visualization
            predictions['conf'] = torch.sigmoid(predictions['conf'])
            # Zero confidence on depth discontinuities so edge points get filtered out
            edge = depth_edge(predictions['local_points'][..., 2], rtol=0.03)
            predictions['conf'][edge] = 0.0
            del predictions['local_points']

    # Convert tensors to numpy and drop the batch dimension
    for key, value in predictions.items():
        if isinstance(value, torch.Tensor):
            predictions[key] = value.cpu().numpy().squeeze(0)

    # Clean up
    torch.cuda.empty_cache()
    return predictions
328
+
329
+
330
+ # -------------------------------------------------------------------------
331
+ # 2) Handle uploaded video/images --> produce target_dir + images
332
+ # -------------------------------------------------------------------------
333
def handle_uploads(input_video, input_images):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or extracted frames from video into it. Return (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Unique, timestamped working directory per upload session
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle images ---
    if input_images is not None:
        for file_data in input_images:
            # Gradio may hand back dicts ({"name": path}) or plain path strings
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    # --- Handle video ---
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        try:
            fps = vs.get(cv2.CAP_PROP_FPS)
            # Sample ~1 frame/sec. BUGFIX: videos with missing/zero FPS metadata
            # previously produced frame_interval == 0 and crashed on `count % 0`.
            frame_interval = max(int(fps), 1) if fps and fps > 0 else 1

            count = 0
            video_frame_num = 0
            while True:
                gotit, frame = vs.read()
                if not gotit:
                    break
                count += 1
                if count % frame_interval == 0:
                    image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
                    cv2.imwrite(image_path, frame)
                    image_paths.append(image_path)
                    video_frame_num += 1
        finally:
            # BUGFIX: release the capture handle (was leaked on every upload)
            vs.release()

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
    return target_dir, image_paths
396
+
397
+
398
+ # -------------------------------------------------------------------------
399
+ # 3) Update gallery on upload
400
+ # -------------------------------------------------------------------------
401
def update_gallery_on_upload(input_video, input_images):
    """
    Whenever user uploads or changes files, immediately handle them
    and show in the gallery. Return (target_dir, image_paths).
    If nothing is uploaded, returns "None" and empty list.
    """
    # Nothing uploaded yet: leave viewer, target dir, gallery and log untouched
    if not (input_video or input_images):
        return None, None, None, None

    target_dir, image_paths = handle_uploads(input_video, input_images)
    return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
411
+
412
+
413
+ # -------------------------------------------------------------------------
414
+ # 4) Reconstruction: uses the target_dir plus any viz parameters
415
+ # -------------------------------------------------------------------------
416
def gradio_demo(
    target_dir,
    conf_thres=3.0,
    frame_filter="All",
    show_cam=True,
):
    """
    Perform reconstruction using the already-created target_dir/images.

    Returns a 3-tuple of (glb_path_or_None, log_message, frame_filter_update),
    matching the three Gradio output components wired to this callback.
    """
    if not os.path.isdir(target_dir) or target_dir == "None":
        # BUGFIX: this branch previously returned 4 values while the success
        # path (and the 3 wired outputs) return 3, making Gradio error out.
        return None, "No valid target directory found. Please upload first.", None

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Populate the frame-filter dropdown with "index: filename" entries; the
    # numeric prefix is what predictions_to_glb parses back out.
    # BUGFIX: the label previously hard-coded "(unknown)" and ignored `filename`.
    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running run_model...")
    with torch.no_grad():
        predictions = run_model(target_dir, model)

    # Save predictions so later parameter tweaks re-render without re-running the model
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # One cached GLB per (confidence, frame filter, camera toggle) combination
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."

    return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
476
+
477
+
478
+ # -------------------------------------------------------------------------
479
+ # 5) Helper functions for UI resets + re-visualization
480
+ # -------------------------------------------------------------------------
481
def clear_fields():
    """Reset the 3D viewer component before a new reconstruction starts."""
    return None
486
+
487
+
488
def update_log():
    """Show a quick progress message while reconstruction is running."""
    return "Loading and Reconstructing..."
493
+
494
+
495
def update_visualization(
    target_dir, conf_thres, frame_filter, show_cam, is_example
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for the new
    parameters, and return it for the 3D viewer. If is_example == "True", skip.
    """
    no_recon_msg = "No reconstruction available. Please click the Reconstruct button first."

    # Example clicks are display-only; never rebuild for them
    if is_example == "True":
        return None, no_recon_msg

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, no_recon_msg

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."

    loaded = np.load(predictions_path)
    predictions = {
        name: np.array(loaded[name])
        for name in ("images", "points", "conf", "camera_poses")
    }

    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}.glb",
    )

    # Reuse a previously exported GLB when this parameter combination was seen
    if not os.path.exists(glbfile):
        scene = predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            show_cam=show_cam,
        )
        scene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization"
540
+
541
+
542
+ # -------------------------------------------------------------------------
543
+ # Example images
544
+ # -------------------------------------------------------------------------
545
+
546
# Example assets referenced by the demo (paths relative to the app root)
great_wall_video = "examples/videos/great_wall.mp4"
colosseum_video = "examples/videos/Colosseum.mp4"
room_video = "examples/videos/room.mp4"
kitchen_video = "examples/videos/kitchen.mp4"
fern_video = "examples/videos/fern.mp4"
single_cartoon_video = "examples/videos/single_cartoon.mp4"
single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
pyramid_video = "examples/videos/pyramid.mp4"


# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
theme = gr.themes.Ocean()
theme.set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
)

with gr.Blocks(
    theme=theme,
    css="""
    .custom-log * {
    font-style: italic;
    font-size: 22px !important;
    background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
    -webkit-background-clip: text;
    background-clip: text;
    font-weight: bold !important;
    color: transparent !important;
    text-align: center !important;
    }

    .example-log * {
    font-style: italic;
    font-size: 16px !important;
    background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
    -webkit-background-clip: text;
    background-clip: text;
    color: transparent !important;
    }

    #my_radio .wrap {
    display: flex;
    flex-wrap: nowrap;
    justify-content: center;
    align-items: center;
    }

    #my_radio .wrap label {
    display: flex;
    width: 50%;
    justify-content: center;
    align-items: center;
    margin: 0;
    padding: 10px 0;
    box-sizing: border-box;
    }
    """,
) as demo:
    # Instead of gr.State, we use a hidden Textbox:
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    # NOTE(review): num_images is declared but never wired to any event below —
    # confirm whether it can be removed.
    num_images = gr.Textbox(label="num_images", visible=False, value="None")

    gr.HTML(
        """
        <h1>Flow3r: Factored Flow Prediction for Visual Geometry Learning</h1>
        <p>
        <a href="https://github.com/Kidrauh/flow3r">GitHub Repository</a> |
        <a href="https://flow3r-project.github.io/">Project Page</a>
        </p>

        <div style="font-size: 16px; line-height: 1.5;">
        <p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. Flow3r takes these images and generates a 3D point cloud, along with estimated camera poses.</p>

        </div>
        """
    )

    # Hidden textbox carrying the per-session upload directory created by handle_uploads()
    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")

    with gr.Row():
        with gr.Column(scale=2):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)

            image_gallery = gr.Gallery(
                label="Preview",
                columns=4,
                height="300px",
                # show_download_button=True,
                object_fit="contain",
                preview=True,
            )

        with gr.Column(scale=4):
            with gr.Column():
                gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
                log_output = gr.Markdown(
                    "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
                )
                reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)

            with gr.Row():
                submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                clear_btn = gr.ClearButton(
                    [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
                    scale=1,
                )

            with gr.Row():
                conf_thres = gr.Slider(minimum=0, maximum=100, value=0, step=0.1, label="Confidence Threshold (%)")
                frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
                with gr.Column():
                    show_cam = gr.Checkbox(label="Show Camera", value=True)

    # Reconstruct pipeline: clear viewer -> show progress log -> run model ->
    # flip is_example to "False" so the sliders below re-render live.
    submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
        fn=update_log, inputs=[], outputs=[log_output]
    ).then(
        fn=gradio_demo,
        inputs=[
            target_dir_output,
            conf_thres,
            frame_filter,
            show_cam,
        ],
        outputs=[reconstruction_output, log_output, frame_filter],
    ).then(
        fn=lambda: "False", inputs=[], outputs=[is_example]  # set is_example to "False"
    )

    # -------------------------------------------------------------------------
    # Real-time Visualization Updates
    # -------------------------------------------------------------------------
    conf_thres.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            show_cam,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    frame_filter.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            show_cam,
            is_example,
        ],
        [reconstruction_output, log_output],
    )

    show_cam.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            show_cam,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    # -------------------------------------------------------------------------
    # Auto-update gallery whenever user uploads or changes their files
    # -------------------------------------------------------------------------
    input_video.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
    input_images.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )

# queue() serializes requests so concurrent users don't contend for the model/GPU
demo.queue(max_size=20).launch(show_error=True, share=True)
flow3r/models/dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ __version__ = "0.0.1"
flow3r/models/dinov2/hub/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
flow3r/models/dinov2/hub/backbones.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from typing import Union
8
+
9
+ import torch
10
+
11
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
12
+
13
+
14
class Weights(Enum):
    # Pretraining dataset tags of the released DINOv2 checkpoints
    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """
    Build a DINOv2 vision transformer and optionally load pretrained weights.

    Keyword Args:
        arch_name: Constructor name looked up in models.vision_transformer
            (e.g. "vit_small", "vit_large", "vit_giant2").
        img_size / patch_size: Input resolution and ViT patch size.
        init_values, ffn_layer, block_chunks, num_register_tokens,
        interpolate_antialias, interpolate_offset: Forwarded to the ViT.
        pretrained: Download and load the matching checkpoint when True.
        weights: A Weights enum member or its name as a string.
        **kwargs: Extra overrides forwarded to the ViT constructor.

    Returns:
        The constructed (and optionally pretrained) torch.nn.Module.
    """
    # Imported lazily to avoid a circular import at package load time
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        # The register-token count is part of the checkpoint file name
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model
62
+
63
+
64
# Public torch.hub entry points — thin wrappers around _make_dinov2_model that
# pin the architecture (and, for *_reg variants, the register-token settings).
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    # The giant model uses the fused SwiGLU feed-forward layer
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
flow3r/models/dinov2/hub/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15
+
16
+
17
+ def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
18
+ compact_arch_name = arch_name.replace("_", "")[:4]
19
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
20
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
21
+
22
+
23
+ class CenterPadding(nn.Module):
24
+ def __init__(self, multiple):
25
+ super().__init__()
26
+ self.multiple = multiple
27
+
28
+ def _get_pad(self, size):
29
+ new_size = math.ceil(size / self.multiple) * self.multiple
30
+ pad_size = new_size - size
31
+ pad_size_left = pad_size // 2
32
+ pad_size_right = pad_size - pad_size_left
33
+ return pad_size_left, pad_size_right
34
+
35
+ @torch.inference_mode()
36
+ def forward(self, x):
37
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
38
+ output = F.pad(x, pads)
39
+ return output
flow3r/models/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
flow3r/models/dinov2/layers/attention.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ # warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ # warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ # warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
+ class Attention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int = 8,
41
+ qkv_bias: bool = False,
42
+ proj_bias: bool = True,
43
+ attn_drop: float = 0.0,
44
+ proj_drop: float = 0.0,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ self.scale = head_dim**-0.5
50
+
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+
56
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
57
+ B, N, C = x.shape
58
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
59
+
60
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61
+ attn = q @ k.transpose(-2, -1)
62
+
63
+ attn = attn.softmax(dim=-1)
64
+ attn = self.attn_drop(attn)
65
+
66
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67
+ x = self.proj(x)
68
+ x = self.proj_drop(x)
69
+ return x
70
+
71
+
72
+ class MemEffAttention(Attention):
73
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74
+ if not XFORMERS_AVAILABLE:
75
+ if attn_bias is not None:
76
+ raise AssertionError("xFormers is required for using nested tensors")
77
+ return super().forward(x)
78
+
79
+ B, N, C = x.shape
80
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
81
+
82
+ q, k, v = unbind(qkv, 2)
83
+
84
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
85
+ x = x.reshape([B, N, C])
86
+
87
+ x = self.proj(x)
88
+ x = self.proj_drop(x)
89
+ return x
flow3r/models/dinov2/layers/block.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
28
+ try:
29
+ if XFORMERS_ENABLED:
30
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
31
+
32
+ XFORMERS_AVAILABLE = True
33
+ # warnings.warn("xFormers is available (Block)")
34
+ else:
35
+ # warnings.warn("xFormers is disabled (Block)")
36
+ raise ImportError
37
+ except ImportError:
38
+ XFORMERS_AVAILABLE = False
39
+ # warnings.warn("xFormers is not available (Block)")
40
+
41
+
42
+ class Block(nn.Module):
43
+ def __init__(
44
+ self,
45
+ dim: int,
46
+ num_heads: int,
47
+ mlp_ratio: float = 4.0,
48
+ qkv_bias: bool = False,
49
+ proj_bias: bool = True,
50
+ ffn_bias: bool = True,
51
+ drop: float = 0.0,
52
+ attn_drop: float = 0.0,
53
+ init_values=None,
54
+ drop_path: float = 0.0,
55
+ act_layer: Callable[..., nn.Module] = nn.GELU,
56
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
57
+ attn_class: Callable[..., nn.Module] = Attention,
58
+ ffn_layer: Callable[..., nn.Module] = Mlp,
59
+ ) -> None:
60
+ super().__init__()
61
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
62
+ self.norm1 = norm_layer(dim)
63
+ self.attn = attn_class(
64
+ dim,
65
+ num_heads=num_heads,
66
+ qkv_bias=qkv_bias,
67
+ proj_bias=proj_bias,
68
+ attn_drop=attn_drop,
69
+ proj_drop=drop,
70
+ )
71
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
72
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
73
+
74
+ self.norm2 = norm_layer(dim)
75
+ mlp_hidden_dim = int(dim * mlp_ratio)
76
+ self.mlp = ffn_layer(
77
+ in_features=dim,
78
+ hidden_features=mlp_hidden_dim,
79
+ act_layer=act_layer,
80
+ drop=drop,
81
+ bias=ffn_bias,
82
+ )
83
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
84
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
85
+
86
+ self.sample_drop_ratio = drop_path
87
+
88
+ def forward(self, x: Tensor) -> Tensor:
89
+ def attn_residual_func(x: Tensor) -> Tensor:
90
+ return self.ls1(self.attn(self.norm1(x)))
91
+
92
+ def ffn_residual_func(x: Tensor) -> Tensor:
93
+ return self.ls2(self.mlp(self.norm2(x)))
94
+
95
+ if self.training and self.sample_drop_ratio > 0.1:
96
+ # the overhead is compensated only for a drop path rate larger than 0.1
97
+ x = drop_add_residual_stochastic_depth(
98
+ x,
99
+ residual_func=attn_residual_func,
100
+ sample_drop_ratio=self.sample_drop_ratio,
101
+ )
102
+ x = drop_add_residual_stochastic_depth(
103
+ x,
104
+ residual_func=ffn_residual_func,
105
+ sample_drop_ratio=self.sample_drop_ratio,
106
+ )
107
+ elif self.training and self.sample_drop_ratio > 0.0:
108
+ x = x + self.drop_path1(attn_residual_func(x))
109
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
110
+ else:
111
+ x = x + attn_residual_func(x)
112
+ x = x + ffn_residual_func(x)
113
+ return x
114
+
115
+
116
+ def drop_add_residual_stochastic_depth(
117
+ x: Tensor,
118
+ residual_func: Callable[[Tensor], Tensor],
119
+ sample_drop_ratio: float = 0.0,
120
+ ) -> Tensor:
121
+ # 1) extract subset using permutation
122
+ b, n, d = x.shape
123
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
124
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
125
+ x_subset = x[brange]
126
+
127
+ # 2) apply residual_func to get residual
128
+ residual = residual_func(x_subset)
129
+
130
+ x_flat = x.flatten(1)
131
+ residual = residual.flatten(1)
132
+
133
+ residual_scale_factor = b / sample_subset_size
134
+
135
+ # 3) add the residual
136
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
137
+ return x_plus_residual.view_as(x)
138
+
139
+
140
+ def get_branges_scales(x, sample_drop_ratio=0.0):
141
+ b, n, d = x.shape
142
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
143
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
144
+ residual_scale_factor = b / sample_subset_size
145
+ return brange, residual_scale_factor
146
+
147
+
148
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
149
+ if scaling_vector is None:
150
+ x_flat = x.flatten(1)
151
+ residual = residual.flatten(1)
152
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
153
+ else:
154
+ x_plus_residual = scaled_index_add(
155
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
156
+ )
157
+ return x_plus_residual
158
+
159
+
160
+ attn_bias_cache: Dict[Tuple, Any] = {}
161
+
162
+
163
+ def get_attn_bias_and_cat(x_list, branges=None):
164
+ """
165
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
166
+ """
167
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
168
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
169
+ if all_shapes not in attn_bias_cache.keys():
170
+ seqlens = []
171
+ for b, x in zip(batch_sizes, x_list):
172
+ for _ in range(b):
173
+ seqlens.append(x.shape[1])
174
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
175
+ attn_bias._batch_sizes = batch_sizes
176
+ attn_bias_cache[all_shapes] = attn_bias
177
+
178
+ if branges is not None:
179
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
180
+ else:
181
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
182
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
183
+
184
+ return attn_bias_cache[all_shapes], cat_tensors
185
+
186
+
187
+ def drop_add_residual_stochastic_depth_list(
188
+ x_list: List[Tensor],
189
+ residual_func: Callable[[Tensor, Any], Tensor],
190
+ sample_drop_ratio: float = 0.0,
191
+ scaling_vector=None,
192
+ ) -> Tensor:
193
+ # 1) generate random set of indices for dropping samples in the batch
194
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
195
+ branges = [s[0] for s in branges_scales]
196
+ residual_scale_factors = [s[1] for s in branges_scales]
197
+
198
+ # 2) get attention bias and index+concat the tensors
199
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
200
+
201
+ # 3) apply residual_func to get residual, and split the result
202
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
203
+
204
+ outputs = []
205
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
206
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
207
+ return outputs
208
+
209
+
210
+ class NestedTensorBlock(Block):
211
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
212
+ """
213
+ x_list contains a list of tensors to nest together and run
214
+ """
215
+ assert isinstance(self.attn, MemEffAttention)
216
+
217
+ if self.training and self.sample_drop_ratio > 0.0:
218
+
219
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
220
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
221
+
222
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
223
+ return self.mlp(self.norm2(x))
224
+
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=attn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ x_list = drop_add_residual_stochastic_depth_list(
232
+ x_list,
233
+ residual_func=ffn_residual_func,
234
+ sample_drop_ratio=self.sample_drop_ratio,
235
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
236
+ )
237
+ return x_list
238
+ else:
239
+
240
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
241
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
242
+
243
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
244
+ return self.ls2(self.mlp(self.norm2(x)))
245
+
246
+ attn_bias, x = get_attn_bias_and_cat(x_list)
247
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
248
+ x = x + ffn_residual_func(x)
249
+ return attn_bias.split(x)
250
+
251
+ def forward(self, x_or_x_list):
252
+ if isinstance(x_or_x_list, Tensor):
253
+ return super().forward(x_or_x_list)
254
+ elif isinstance(x_or_x_list, list):
255
+ if not XFORMERS_AVAILABLE:
256
+ raise AssertionError("xFormers is required for using nested tensors")
257
+ return self.forward_nested(x_or_x_list)
258
+ else:
259
+ raise AssertionError
flow3r/models/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
+ class DINOHead(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_dim,
16
+ out_dim,
17
+ use_bn=False,
18
+ nlayers=3,
19
+ hidden_dim=2048,
20
+ bottleneck_dim=256,
21
+ mlp_bias=True,
22
+ ):
23
+ super().__init__()
24
+ nlayers = max(nlayers, 1)
25
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
26
+ self.apply(self._init_weights)
27
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
28
+ self.last_layer.weight_g.data.fill_(1)
29
+
30
+ def _init_weights(self, m):
31
+ if isinstance(m, nn.Linear):
32
+ trunc_normal_(m.weight, std=0.02)
33
+ if isinstance(m, nn.Linear) and m.bias is not None:
34
+ nn.init.constant_(m.bias, 0)
35
+
36
+ def forward(self, x):
37
+ x = self.mlp(x)
38
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
39
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
40
+ x = self.last_layer(x)
41
+ return x
42
+
43
+
44
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45
+ if nlayers == 1:
46
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47
+ else:
48
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49
+ if use_bn:
50
+ layers.append(nn.BatchNorm1d(hidden_dim))
51
+ layers.append(nn.GELU())
52
+ for _ in range(nlayers - 2):
53
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54
+ if use_bn:
55
+ layers.append(nn.BatchNorm1d(hidden_dim))
56
+ layers.append(nn.GELU())
57
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58
+ return nn.Sequential(*layers)
flow3r/models/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
flow3r/models/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ init_values: Union[float, Tensor] = 1e-5,
20
+ inplace: bool = False,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.inplace = inplace
24
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
25
+
26
+ def forward(self, x: Tensor) -> Tensor:
27
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
flow3r/models/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
flow3r/models/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (
51
+ image_HW[0] // patch_HW[0],
52
+ image_HW[1] // patch_HW[1],
53
+ )
54
+
55
+ self.img_size = image_HW
56
+ self.patch_size = patch_HW
57
+ self.patches_resolution = patch_grid_size
58
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59
+
60
+ self.in_chans = in_chans
61
+ self.embed_dim = embed_dim
62
+
63
+ self.flatten_embedding = flatten_embedding
64
+
65
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ _, _, H, W = x.shape
70
+ patch_H, patch_W = self.patch_size
71
+
72
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
74
+
75
+ x = self.proj(x) # B C H W
76
+ H, W = x.size(2), x.size(3)
77
+ x = x.flatten(2).transpose(1, 2) # B HW C
78
+ x = self.norm(x)
79
+ if not self.flatten_embedding:
80
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81
+ return x
82
+
83
+ def flops(self) -> float:
84
+ Ho, Wo = self.patches_resolution
85
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86
+ if self.norm is not None:
87
+ flops += Ho * Wo * self.embed_dim
88
+ return flops
flow3r/models/dinov2/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ try:
39
+ if XFORMERS_ENABLED:
40
+ from xformers.ops import SwiGLU
41
+
42
+ XFORMERS_AVAILABLE = True
43
+ # warnings.warn("xFormers is available (SwiGLU)")
44
+ else:
45
+ # warnings.warn("xFormers is disabled (SwiGLU)")
46
+ raise ImportError
47
+ except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ # warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
+ class SwiGLUFFNFused(SwiGLU):
55
+ def __init__(
56
+ self,
57
+ in_features: int,
58
+ hidden_features: Optional[int] = None,
59
+ out_features: Optional[int] = None,
60
+ act_layer: Callable[..., nn.Module] = None,
61
+ drop: float = 0.0,
62
+ bias: bool = True,
63
+ ) -> None:
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67
+ super().__init__(
68
+ in_features=in_features,
69
+ hidden_features=hidden_features,
70
+ out_features=out_features,
71
+ bias=bias,
72
+ )
flow3r/models/dinov2/models/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ from . import vision_transformer as vits
9
+
10
+
11
+ logger = logging.getLogger("dinov2")
12
+
13
+
14
+ def build_model(args, only_teacher=False, img_size=224):
15
+ args.arch = args.arch.removesuffix("_memeff")
16
+ if "vit" in args.arch:
17
+ vit_kwargs = dict(
18
+ img_size=img_size,
19
+ patch_size=args.patch_size,
20
+ init_values=args.layerscale,
21
+ ffn_layer=args.ffn_layer,
22
+ block_chunks=args.block_chunks,
23
+ qkv_bias=args.qkv_bias,
24
+ proj_bias=args.proj_bias,
25
+ ffn_bias=args.ffn_bias,
26
+ num_register_tokens=args.num_register_tokens,
27
+ interpolate_offset=args.interpolate_offset,
28
+ interpolate_antialias=args.interpolate_antialias,
29
+ )
30
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
31
+ if only_teacher:
32
+ return teacher, teacher.embed_dim
33
+ student = vits.__dict__[args.arch](
34
+ **vit_kwargs,
35
+ drop_path_rate=args.drop_path_rate,
36
+ drop_path_uniform=args.drop_path_uniform,
37
+ )
38
+ embed_dim = student.embed_dim
39
+ return student, teacher, embed_dim
40
+
41
+
42
+ def build_model_from_cfg(cfg, only_teacher=False):
43
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
flow3r/models/dinov2/models/vision_transformer.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.checkpoint import checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+ from ...layers.attention import FlashAttention
22
+
23
+
24
+ # logger = logging.getLogger("dinov2")
25
+
26
+
27
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
28
+ if not depth_first and include_root:
29
+ fn(module=module, name=name)
30
+ for child_name, child_module in module.named_children():
31
+ child_name = ".".join((name, child_name)) if name else child_name
32
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
33
+ if depth_first and include_root:
34
+ fn(module=module, name=name)
35
+ return module
36
+
37
+
38
+ class BlockChunk(nn.ModuleList):
39
+ def forward(self, x):
40
+ for b in self:
41
+ x = b(x)
42
+ return x
43
+
44
+
45
class DinoVisionTransformer(nn.Module):
    """DINOv2 Vision Transformer backbone (patch embedding + transformer blocks)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1  # the single CLS token
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable CLS token and positional embeddings (CLS + one per patch).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        # Optional "register" tokens (extra CLS-like tokens without positional embedding).
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            # logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            # logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            # logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        # NOTE: attn_class is hard-coded to FlashAttention here, overriding any
        # attn_class baked into block_fn via functools.partial (keyword given
        # explicitly wins only if block_fn does not already fix it; a partial
        # with attn_class set would raise — confirm against the Block factories).
        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                attn_class=FlashAttention
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # Token substituted for masked patches during masked-image modeling.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialize positional/CLS/register tokens and all Linear layers."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Bicubically resize the patch positional embeddings to a (w, h) input.

        Returns the original pos_embed unchanged when the patch count matches
        and the input is square; otherwise interpolates the patch grid and
        re-attaches the CLS position.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patchify, optionally mask, prepend CLS (and registers), add pos embed."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            # Replace masked patch embeddings with the learned mask token.
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            # Registers are inserted between CLS and the patch tokens and get
            # no positional embedding (pos embed was already added above).
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """Run the backbone over a list of (images, masks) pairs jointly.

        Uses gradient checkpointing during training to save memory.
        """
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Run the backbone and return normalized CLS/register/patch tokens."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            if self.training:
                # Gradient checkpointing trades compute for activation memory.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        # Collect outputs of selected blocks when self.blocks is a flat list.
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        # Collect outputs of selected blocks when self.blocks is chunked;
        # leading nn.Identity() padding keeps global block indices consistent.
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch-token features from selected blocks.

        Args:
            x: input image batch.
            n: number of last blocks to take, or an explicit list of indices.
            reshape: return features as (B, C, w', h') maps instead of (B, L, C).
            return_class_token: additionally return the CLS token per layer.
            norm: apply the final LayerNorm to each collected output.
        """
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Drop CLS + register tokens, keep patch tokens only.
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Full feature dict when is_training, else head applied to the CLS token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
338
+
339
+
340
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
346
+
347
+
348
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-Small backbone (embed dim 384, depth 12, 6 heads)."""
    arch = dict(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    return DinoVisionTransformer(**arch, **kwargs)
360
+
361
+
362
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-Base backbone (embed dim 768, depth 12, 12 heads)."""
    arch = dict(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    return DinoVisionTransformer(**arch, **kwargs)
374
+
375
+
376
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-Large backbone (embed dim 1024, depth 24, 16 heads)."""
    arch = dict(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    return DinoVisionTransformer(**arch, **kwargs)
388
+
389
+
390
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    arch = dict(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    return DinoVisionTransformer(**arch, **kwargs)
flow3r/models/dinov2/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
flow3r/models/dinov2/utils/cluster.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional
10
+
11
+
12
class ClusterType(Enum):
    # Known compute clusters; the values are the short string identifiers
    # used when a cluster type needs to be serialized or configured.
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
16
+
17
+
18
def _guess_cluster_type() -> ClusterType:
    """Guess which cluster we run on from kernel release / hostname; FAIR is the fallback."""
    info = os.uname()
    if info.sysname == "Linux":
        if info.release.endswith("-aws"):
            # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
            return ClusterType.AWS
        if info.nodename.startswith("rsc"):
            # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
            return ClusterType.RSC
    return ClusterType.FAIR
29
+
30
+
31
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return the given cluster type, guessing it when None is passed."""
    return _guess_cluster_type() if cluster_type is None else cluster_type
36
+
37
+
38
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the cluster-wide checkpoint root directory, or None if unknown."""
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    dirname = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }[cluster_type]
    return Path("/") / dirname
49
+
50
+
51
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the per-user checkpoint directory ($USER under the cluster root), or None."""
    root = get_checkpoint_path(cluster_type)
    if root is None:
        return None

    user = os.environ.get("USER")
    assert user is not None
    return root / user
59
+
60
+
61
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the default SLURM partition for the (possibly guessed) cluster."""
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    partitions = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return partitions[cluster_type]
72
+
73
+
74
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Build SLURM executor parameters for the given topology.

    Starts from sensible defaults, applies per-cluster adjustments, then lets
    explicit keyword arguments override everything.
    """
    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    # apply cluster-specific adjustments
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type == ClusterType.AWS:
        params["cpus_per_task"] = 12
        del params["mem_gb"]
    elif cluster_type == ClusterType.RSC:
        params["cpus_per_task"] = 12
    # set additional parameters / apply overrides
    params.update(kwargs)
    return params
flow3r/models/dinov2/utils/config.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import logging
8
+ import os
9
+
10
+ from omegaconf import OmegaConf
11
+
12
+ import dinov2.distributed as distributed
13
+ from dinov2.logging import setup_logging
14
+ from dinov2.utils import utils
15
+ from dinov2.configs import dinov2_default_config
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
def apply_scaling_rules_to_cfg(cfg):  # to fix
    """Derive cfg.optim.lr from cfg.optim.base_lr using the configured scaling rule."""
    if cfg.optim.scaling_rule != "sqrt_wrt_1024":
        raise NotImplementedError
    # Square-root scaling of the learning rate with the global batch size,
    # normalized so that a world batch of 1024 keeps the base learning rate.
    base_lr = cfg.optim.base_lr
    world_batch = cfg.train.batch_size_per_gpu * distributed.get_global_size()
    cfg.optim.lr = base_lr * math.sqrt(world_batch / 1024.0)
    logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    return cfg
30
+
31
+
32
def write_config(cfg, output_dir, name="config.yaml"):
    """Log the resolved config and save it as YAML under output_dir; return the path."""
    logger.info(OmegaConf.to_yaml(cfg))
    saved_cfg_path = os.path.join(output_dir, name)
    with open(saved_cfg_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return saved_cfg_path
38
+
39
+
40
def get_cfg_from_args(args):
    """Merge (in increasing priority) defaults, the config file, and CLI overrides."""
    args.output_dir = os.path.abspath(args.output_dir)
    # Force the output directory into the config via a synthetic CLI override.
    args.opts += [f"train.output_dir={args.output_dir}"]
    base_cfg = OmegaConf.create(dinov2_default_config)
    file_cfg = OmegaConf.load(args.config_file)
    return OmegaConf.merge(base_cfg, file_cfg, OmegaConf.from_cli(args.opts))
47
+
48
+
49
def default_setup(args):
    # Shared bootstrapping for all entry points: distributed init, logging,
    # deterministic-but-rank-dependent seeding, and environment diagnostics.
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    global logger  # re-bind the module-level logger once logging is configured
    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    # Offset the seed by rank so each process gets a distinct RNG stream.
    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n {}\n".format(utils.get_sha()))
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
61
+
62
+
63
def setup(args):
    """
    Create configs and perform basic setups.
    """
    # Order matters: the config must exist before scaling rules are applied,
    # and logging must be set up (default_setup) before write_config logs it.
    cfg = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(cfg)
    write_config(cfg, args.output_dir)
    return cfg
flow3r/models/dinov2/utils/dtype.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from typing import Dict, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
# A dtype specification: a numpy dtype name string, a numpy dtype, or a torch dtype.
TypeSpec = Union[str, np.dtype, torch.dtype]


# Mapping from numpy dtypes to their torch equivalents.
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Normalize a dtype specification to a torch.dtype.

    Accepts a torch dtype (returned as-is), a numpy dtype, or a string
    understood by ``np.dtype`` (e.g. ``"float32"``).

    Raises:
        KeyError: if the numpy dtype has no torch equivalent.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    # Fixed typo in the assertion message ("nunpy" -> "numpy").
    assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
    return _NUMPY_TO_TORCH_DTYPE[dtype]
flow3r/models/dinov2/utils/param_groups.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import defaultdict
7
+ import logging
8
+
9
+
10
+ logger = logging.getLogger("dinov2")
11
+
12
+
13
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
    Returns:
        lr decay rate for the given parameter.
    """
    stem_keys = ("pos_embed", "patch_embed", "mask_token", "cls_token", "register_tokens")
    # Default: treat the parameter as belonging to the head (no decay applied).
    layer_id = num_layers + 1
    if name.startswith("backbone") or force_is_backbone:
        if any("." + key in name for key in stem_keys):
            # Stem parameters (tokens / embeddings) decay the most.
            layer_id = 0
        elif force_is_backbone and any(key in name for key in stem_keys):
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
        elif chunked_blocks and "blocks." in name and "residual." not in name:
            # Chunked (FSDP) layout: "blocks.<chunk>.<block>...."
            layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
        elif "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)
49
+
50
+
51
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    """Build one optimizer parameter group per trainable parameter.

    Each group carries a layer-wise lr multiplier (see get_vit_lr_decay_rate),
    a weight-decay multiplier (0 for biases/norm/gamma parameters), a
    last-layer flag, and an extra lr boost for patch-embedding parameters.
    """
    # Figure out how many transformer blocks the model has and whether the
    # block list is chunked (FSDP layout).
    chunked_blocks = False
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0

    groups = []
    for name, param in model.named_parameters():
        name = name.replace("_fsdp_wrapped_module.", "")
        if not param.requires_grad:
            continue

        decay_rate = get_vit_lr_decay_rate(
            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
        )
        lr_multiplier = decay_rate * patch_embed_lr_mult if "patch_embed" in name else decay_rate
        # Biases and normalization/scale parameters get no weight decay.
        wd_multiplier = 0.0 if (name.endswith(".bias") or "norm" in name or "gamma" in name) else 1.0

        group = {
            "params": param,
            "is_last_layer": "last_layer" in name,
            "lr_multiplier": lr_multiplier,
            "wd_multiplier": wd_multiplier,
            "name": name,
        }
        groups.append(group)
        logger.info(f"""{name}: lr_multiplier: {group["lr_multiplier"]}, wd_multiplier: {group["wd_multiplier"]}""")

    return groups
90
+
91
+
92
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    """Merge per-parameter groups that share identical values for the given keys.

    Returns a view of fused groups, each holding the shared key values and a
    list of the parameters it covers.
    """
    fused = defaultdict(lambda: {"params": []})
    for group in all_params_groups:
        # The identifier is the concatenation of key/value pairs, so groups
        # with equal hyper-parameters collapse into one fused entry.
        identifier = "".join(k + str(group[k]) + "_" for k in keys)
        entry = fused[identifier]
        for k in keys:
            entry[k] = group[k]
        entry["params"].append(group["params"])

    return fused.values()
flow3r/models/dinov2/utils/utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import random
9
+ import subprocess
10
+ from urllib.parse import urlparse
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+
16
+
17
+ # logger = logging.getLogger("dinov2")
18
+
19
+
20
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """Load a checkpoint (local path or URL) into model, non-strictly.

    Optionally descends into state_dict[checkpoint_key] and strips the
    "module." (DDP) and "backbone." (multicrop wrapper) key prefixes.
    """
    if urlparse(pretrained_weights).scheme:  # If it looks like an URL
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        # logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    # Strip wrapper prefixes: "module." first, then "backbone." on the result.
    cleaned = {}
    for key, value in state_dict.items():
        cleaned[key.replace("module.", "").replace("backbone.", "")] = value
    msg = model.load_state_dict(cleaned, strict=False)
    # logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
34
+
35
+
36
def fix_random_seeds(seed=31):
    """
    Fix random seeds.
    """
    # Seed every RNG this codebase relies on: torch (CPU and all GPUs),
    # numpy, and the Python standard library.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
+
45
+
46
def get_sha():
    """Return a one-line description of the current git state: sha, dirty status, branch.

    Falls back to "N/A"/"clean" placeholders when git is unavailable or the
    module does not live inside a repository.
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _git(command):
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha, diff, branch = "N/A", "clean", "N/A"
    try:
        sha = _git(["git", "rev-parse", "HEAD"])
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = "has uncommitted changes" if _git(["git", "diff-index", "HEAD"]) else "clean"
        branch = _git(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        pass
    return f"sha: {sha}, status: {diff}, branch: {branch}"
65
+
66
+
67
class CosineScheduler(object):
    """Cosine decay schedule with optional freeze and linear warmup phases.

    Indexing with an iteration number returns the scheduled value; indices at
    or past total_iters clamp to final_value.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        # Phase 1: frozen at zero; phase 2: linear warmup; phase 3: cosine decay.
        frozen = np.zeros((freeze_iters))
        warmup = np.linspace(start_warmup_value, base_value, warmup_iters)
        steps = np.arange(total_iters - warmup_iters - freeze_iters)
        cosine = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * steps / len(steps)))
        self.schedule = np.concatenate((frozen, warmup, cosine))

        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        return self.final_value if it >= self.total_iters else self.schedule[it]
88
+
89
+
90
def has_batchnorms(model):
    """Return True if the model contains any batch-norm (or sync batch-norm) layer."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(module, bn_types) for _, module in model.named_modules())
flow3r/models/flow3r.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import partial
4
+ from copy import deepcopy
5
+
6
+ from .dinov2.layers import Mlp
7
+ from ..utils.geometry import homogenize_points
8
+ from .layers.pos_embed import RoPE2D, PositionGetter
9
+ from .layers.block import BlockRope
10
+ from .layers.attention import FlashAttentionRope
11
+ from .layers.transformer_head import TransformerDecoder, LinearPts3d, ContextTransformerDecoder
12
+ from .layers.camera_head import CameraHead
13
+ from .flow_head.dpt_head import DPTHead
14
+ from .dinov2.hub.backbones import dinov2_vitl14_reg
15
+
16
+
17
class Flow3r(nn.Module):
    """Multi-view 3D reconstruction + optical-flow model.

    A frozen-architecture DINOv2 ViT-L/14 encoder feeds a shared transformer
    decoder whose features branch into four heads: local 3D points,
    confidence, camera pose, and (optionally) 2D flow between frame pairs.
    """

    def __init__(
        self,
        pos_type='rope100',
        decoder_size='large',
    ):
        """
        Args:
            pos_type: positional-encoding spec; only 'rope<freq>' is supported
                (e.g. 'rope100' -> RoPE2D with freq=100).
            decoder_size: 'small' | 'base' | 'large' preset for the shared decoder.
        """
        super().__init__()

        # ----------------------
        # Encoder
        # ----------------------
        self.encoder = dinov2_vitl14_reg(pretrained=False)
        self.patch_size = 14
        # The mask token is only needed for masked-image modeling; drop it.
        del self.encoder.mask_token

        # ----------------------
        # Positonal Encoding
        # ----------------------
        self.pos_type = pos_type if pos_type is not None else 'none'
        self.rope = None
        if self.pos_type.startswith('rope'):  # eg rope100
            if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
            freq = float(self.pos_type[len('rope'):])
            self.rope = RoPE2D(freq=freq)
            self.position_getter = PositionGetter()
        else:
            raise NotImplementedError

        # ----------------------
        # Decoder
        # ----------------------
        if decoder_size == 'small':
            dec_embed_dim = 384
            dec_num_heads = 6
            mlp_ratio = 4
            dec_depth = 24
        elif decoder_size == 'base':
            dec_embed_dim = 768
            dec_num_heads = 12
            mlp_ratio = 4
            dec_depth = 24
        elif decoder_size == 'large':
            dec_embed_dim = 1024
            dec_num_heads = 16
            mlp_ratio = 4
            dec_depth = 36
        else:
            raise NotImplementedError
        self.decoder = nn.ModuleList([
            BlockRope(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=0.01,
                qk_norm=True,
                attn_class=FlashAttentionRope,
                rope=self.rope
            ) for _ in range(dec_depth)])
        self.dec_embed_dim = dec_embed_dim

        # ----------------------
        # Register_token
        # ----------------------
        # Learned per-frame special tokens prepended before the patch tokens.
        num_register_tokens = 5
        self.patch_start_idx = num_register_tokens
        self.register_token = nn.Parameter(torch.randn(1, 1, num_register_tokens, self.dec_embed_dim))
        nn.init.normal_(self.register_token, std=1e-6)

        # ----------------------
        # Local Points Decoder
        # ----------------------
        # Consumes the concatenation of the last two decoder layers (2 * dim).
        self.point_decoder = TransformerDecoder(
            in_dim=2*self.dec_embed_dim,
            dec_embed_dim=1024,
            dec_num_heads=16,
            out_dim=1024,
            rope=self.rope
        )
        self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3)

        # ----------------------
        # Camera Pose Decoder
        # ----------------------
        self.camera_decoder = TransformerDecoder(
            in_dim=2*self.dec_embed_dim,
            dec_embed_dim=1024,
            dec_num_heads=16,  # 8
            out_dim=512,
            rope=self.rope,
            use_checkpoint=False
        )
        self.camera_head = CameraHead(dim=512)

        # ----------------------
        # Motion Flow Decoder
        # ----------------------
        self.flow_head = DPTHead(
            patch_size=14,
            output_dim=2,
        )

        # ----------------------
        # Conf Decoder
        # ----------------------
        # Same architecture as the point decoder, separate weights.
        self.conf_decoder = deepcopy(self.point_decoder)
        self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1)

        # For ImageNet Normalize
        image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

        self.register_buffer("image_mean", image_mean)
        self.register_buffer("image_std", image_std)

    def decode(self, hidden, N, H, W):
        """Run the shared decoder over encoder tokens.

        Args:
            hidden: encoder patch tokens, shape (B*N, h*w, C).
            N: number of frames per sample; H, W: input image size in pixels.

        Returns:
            (features, pos): features are the concatenated outputs of the last
            two decoder layers, shape (B*N, tokens, 2*C); pos are the RoPE
            positions, shape (B*N, tokens, 2).
        """
        BN, hw, _ = hidden.shape
        B = BN // N

        final_output = []

        hidden = hidden.reshape(B*N, hw, -1)

        register_token = self.register_token.repeat(B, N, 1, 1).reshape(B*N, *self.register_token.shape[-2:])

        # Concatenate special tokens with patch tokens
        hidden = torch.cat([register_token, hidden], dim=1)
        hw = hidden.shape[1]

        if self.pos_type.startswith('rope'):
            pos = self.position_getter(B * N, H//self.patch_size, W//self.patch_size, hidden.device)

        if self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            # (patch positions are shifted by +1 so that 0 stays reserved for
            # the special tokens prepended below)
            pos = pos + 1
            pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

        for i in range(len(self.decoder)):
            blk = self.decoder[i]

            # Alternate attention scope: even layers attend within a single
            # frame (B*N, hw), odd layers attend across all frames (B, N*hw).
            if i % 2 == 0:
                pos = pos.reshape(B*N, hw, -1)
                hidden = hidden.reshape(B*N, hw, -1)
            else:
                pos = pos.reshape(B, N*hw, -1)
                hidden = hidden.reshape(B, N*hw, -1)

            hidden = blk(hidden, xpos=pos)

            # Keep the outputs of the last two layers for the heads.
            if i+1 in [len(self.decoder)-1, len(self.decoder)]:
                final_output.append(hidden.reshape(B*N, hw, -1))

        return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(B*N, hw, -1)

    def forward(self, imgs, pair_indices=None):
        """Predict points, confidence, camera poses, and (optionally) flow.

        Args:
            imgs: images of shape (B, N, 3, H, W); H and W are assumed to be
                multiples of the 14-pixel patch size.
            pair_indices: frame-pair indices for the flow head; when None the
                flow branch is skipped and `flow` is returned as None.

        Returns:
            dict with world-frame `points`, camera-frame `local_points`,
            `conf`, per-frame 4x4 `camera_poses`, and `flow` (or None).
        """
        imgs = (imgs - self.image_mean) / self.image_std
        # print("the shape of imgs is", imgs.shape)

        B, N, _, H, W = imgs.shape
        patch_h, patch_w = H // 14, W // 14

        # encode by dinov2
        imgs = imgs.reshape(B*N, _, H, W)
        hidden = self.encoder(imgs, is_training=True)

        if isinstance(hidden, dict):
            hidden = hidden["x_norm_patchtokens"]

        hidden, pos = self.decode(hidden, N, H, W)

        point_hidden, point_intermediate = self.point_decoder(hidden, xpos=pos, return_intermediate=True)
        conf_hidden = self.conf_decoder(hidden, xpos=pos)
        camera_hidden, camera_intermediate = self.camera_decoder(hidden, xpos=pos, return_intermediate=True)

        # Heads run in fp32 regardless of any autocast context.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            # local points
            point_hidden = point_hidden.float()
            ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
            xy, z = ret.split([2, 1], dim=-1)
            # Depth predicted in log space; exp keeps it positive.
            z = torch.exp(z)
            local_points = torch.cat([xy * z, z], dim=-1)

            # confidence
            conf_hidden = conf_hidden.float()
            conf = self.conf_head([conf_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)

            # camera
            camera_hidden = camera_hidden.float()
            camera_poses = self.camera_head(camera_hidden[:, self.patch_start_idx:], patch_h, patch_w).reshape(B, N, 4, 4)

            # Flow
            if pair_indices is not None:
                flow = self.flow_head([t.float() for t in point_intermediate], [t.float() for t in camera_intermediate], pair_indices, self.patch_start_idx,(H, W), B, N)
            else:
                flow = None

            # unproject local points using camera poses
            points = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses, homogenize_points(local_points))[..., :3]

        return dict(
            points=points,
            local_points=local_points,
            conf=conf,
            camera_poses=camera_poses,
            flow=flow,
        )
flow3r/models/flow_head/dpt_head.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import List, Dict, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from .utils import create_uv_grid, position_grid_to_embed
17
+
18
+
19
class DPTHead(nn.Module):
    """
    DPT head for dense prediction tasks.

    Follows "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413): transformer tokens from several layers
    are projected to multi-scale feature maps, fused through a RefineNet-style
    pyramid, and decoded into a dense per-pixel output. In this variant the
    head fuses per-frame patch tokens with camera tokens from a frame pair and
    predicts a dense flow map for each pair.

    Args:
        dim_in (int): Channel dimension of the incoming tokens.
        patch_size (int, optional): ViT patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type identifier. Default "inv_log".
            (Stored only; not applied inside this head.)
        conf_activation (str, optional): Confidence activation identifier.
            Default "expp1". (Stored only; not applied inside this head.)
        features (int, optional): Channel width of the fusion pyramid. Default 256.
        out_channels (List[int], optional): Projection widths for the four levels.
            Defaults to [256, 512, 1024, 1024].
        intermediate_layer_idx (List[int], optional): Indices of the aggregated-token
            layers used by the head. Defaults to [4, 11, 17, 23].
        pos_embed (bool, optional): Whether to add sinusoidal positional
            embeddings to intermediate feature maps. Default True.
        feature_only (bool, optional): If True, only `output_conv1` (full width)
            is built and no `output_conv2` prediction head exists.
            NOTE(review): `_dpt_fuse_and_predict` still calls `output_conv2`,
            so `forward` is only usable with feature_only=False — confirm intent.
        down_ratio (int, optional): Downscaling factor for the output resolution.
            Default is 1.
    """

    def __init__(
        self,
        dim_in: int = 1024,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: List[int] = None,
        intermediate_layer_idx: List[int] = None,
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super(DPTHead, self).__init__()
        # Avoid mutable default arguments; these lists are the original defaults.
        if out_channels is None:
            out_channels = [256, 512, 1024, 1024]
        if intermediate_layer_idx is None:
            intermediate_layer_idx = [4, 11, 17, 23]

        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx
        self.dim_in = dim_in
        self.output_dim = output_dim

        # Fuses [cam_i, cam_j, patch_i] (3 * dim_in) back down to dim_in.
        self.mlp = nn.Sequential(
            nn.Linear(3 * dim_in, 2 * dim_in),
            nn.ReLU(),
            nn.Linear(2 * dim_in, 2 * dim_in),
            nn.ReLU(),
            nn.Linear(2 * dim_in, dim_in),
        )

        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers: 4x up, 2x up, identity, 2x down — the DPT multi-scale path.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(
            out_channels,
            features,
            expand=False,
        )

        # Attach the RefineNet-style fusion blocks to scratch.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        # The deepest level has no finer skip input.
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
            )
            conv2_in_channels = head_features_1 // 2

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
            )

    def forward(
        self,
        patch_intermediate_4: List[torch.Tensor],  # len=4, each (B*N, tokens, dim_in)
        camera_intermediate_4: List[torch.Tensor],  # len=4, each (B*N, tokens, dim_in)
        pair_indices: torch.Tensor,  # (B, S, 2)
        patch_start_idx: int,
        img_shape: Tuple[int, int],
        B: int,
        N: int,
    ) -> torch.Tensor:
        """
        Fuse four intermediate token layers into a dense per-pair flow prediction.

        Args:
            patch_intermediate_4 (List[Tensor]): Four patch-token tensors from
                different transformer layers, each (B*N, tokens, dim_in).
            camera_intermediate_4 (List[Tensor]): Four camera-token tensors,
                each (B*N, tokens, dim_in).
            pair_indices (Tensor): (B, S, 2) frame-index pairs per flow map.
            patch_start_idx (int): Index where patch tokens begin; earlier
                tokens are special tokens (e.g. camera/register tokens).
            img_shape (Tuple[int, int]): (H, W) of the input images.
            B (int): Batch size.
            N (int): Number of frames per sample.

        Returns:
            Tensor: Flow of shape (B, S, H, W, output_dim).
        """
        feats_4 = []
        for l in range(4):
            feat_l = self._fuse_one_layer(
                patch_intermediate_4[l],
                camera_intermediate_4[l],
                patch_start_idx,
                pair_indices,
                img_shape,
                B,
                N,
            )
            feats_4.append(feat_l)

        flow = self._dpt_fuse_and_predict(feats_4, img_shape)  # (B*S, output_dim, H, W)

        H, W = img_shape
        S = pair_indices.shape[1]
        return flow.permute(0, 2, 3, 1).reshape(B, S, H, W, self.output_dim)

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """Add a scaled sinusoidal positional embedding to feature map x (B, C, h, w)."""
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def _fuse_one_layer(
        self,
        patch_hidden_l: torch.Tensor,  # (B*N, tokens, dim_in)
        camera_hidden_l: torch.Tensor,  # (B*N, tokens, dim_in)
        patch_start_idx: int,
        pair_indices: torch.Tensor,  # (B, S, 2)
        img_shape: Tuple[int, int],
        B: int,
        N: int,
    ) -> torch.Tensor:
        """
        Combine source-frame patch tokens with source/target camera tokens for
        one intermediate layer.

        Returns:
            Tensor: feature map of shape (B*S, dim_in, patch_h, patch_w).
        """
        H, W = img_shape
        hw = patch_hidden_l[:, patch_start_idx:].shape[1]
        patch_h, patch_w = H // self.patch_size, W // self.patch_size
        assert hw == patch_h * patch_w, (hw, patch_h, patch_w)

        # Drop special tokens and reshape to (B, N, hw, C).
        patch_hidden_l = patch_hidden_l[:, patch_start_idx:].reshape(B, N, hw, self.dim_in)
        camera_hidden_l = camera_hidden_l[:, patch_start_idx:].reshape(B, N, hw, self.dim_in)

        S = pair_indices.shape[1]
        batch_idx = torch.arange(B, device=pair_indices.device).unsqueeze(1).expand(B, S)
        idx_i = pair_indices[:, :, 0]
        idx_j = pair_indices[:, :, 1]

        patch_i = patch_hidden_l[batch_idx, idx_i]  # (B, S, hw, dim_in)
        cam_i = camera_hidden_l[batch_idx, idx_i]  # (B, S, hw, dim_in)
        cam_j = camera_hidden_l[batch_idx, idx_j]  # (B, S, hw, dim_in)

        # Average cam_j to get a single camera token per pair, broadcast over hw.
        cam_j = cam_j.mean(dim=2, keepdim=True).expand(-1, -1, hw, -1)

        # Concatenate and flatten the pair dimension into the batch.
        concat = torch.cat([cam_i, cam_j, patch_i], dim=-1)  # (B, S, hw, 3*dim_in)
        x = concat.reshape(B * S, hw, 3 * self.dim_in)

        # MLP fuse + norm.
        x = self.mlp(x)  # (B*S, hw, dim_in)
        x = self.norm(x)

        # Tokens -> spatial grid.
        feat = x.transpose(1, 2).reshape(B * S, self.dim_in, patch_h, patch_w)
        return feat

    def _dpt_fuse_and_predict(
        self,
        feats_4: List[torch.Tensor],
        img_shape: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Run the standard DPT fusion pyramid and predict flow.

        Args:
            feats_4 (List[Tensor]): Four (B*S, dim_in, ph, pw) feature maps.
            img_shape (Tuple[int, int]): Target (H, W) resolution.

        Returns:
            Tensor: (B*S, output_dim, H, W).
        """
        H, W = img_shape
        out = []
        for i in range(4):
            x = feats_4[i]  # (B*S, dim_in, ph, pw)
            x = self.projects[i](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[i](x)  # multi-scale path
            out.append(x)

        x = self.scratch_forward(out)  # (B*S, features//2, ...)
        x = custom_interpolate(
            x, (H, W),
            mode="bilinear",
            align_corners=True,
        )
        if self.pos_embed:
            # BUG FIX: this result was previously assigned to `out` (the stale
            # per-level feature list) and `output_conv2` was then fed `out`,
            # which crashed whenever pos_embed was False. Keep everything in `x`.
            x = self._apply_pos_embed(x, W, H)

        flow = self.scratch.output_conv2(x)  # (B*S, output_dim, H, W)
        return flow

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): Feature maps from the four levels,
                finest first.

        Returns:
            Tensor: Fused feature map after output_conv1.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Coarse-to-fine fusion; free intermediates eagerly to limit peak memory.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
305
+
306
+
307
+
308
+ ################################################################################
309
+ # Modules
310
+ ################################################################################
311
+
312
+
313
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    """Build a FeatureFusionBlock with the fixed defaults used by this head."""
    act = nn.ReLU(inplace=True)
    return FeatureFusionBlock(
        features,
        act,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
    )
325
+
326
+
327
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
328
+ scratch = nn.Module()
329
+ out_shape1 = out_shape
330
+ out_shape2 = out_shape
331
+ out_shape3 = out_shape
332
+ if len(in_shape) >= 4:
333
+ out_shape4 = out_shape
334
+
335
+ if expand:
336
+ out_shape1 = out_shape
337
+ out_shape2 = out_shape * 2
338
+ out_shape3 = out_shape * 4
339
+ if len(in_shape) >= 4:
340
+ out_shape4 = out_shape * 8
341
+
342
+ scratch.layer1_rn = nn.Conv2d(
343
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
344
+ )
345
+ scratch.layer2_rn = nn.Conv2d(
346
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
347
+ )
348
+ scratch.layer3_rn = nn.Conv2d(
349
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
350
+ )
351
+ if len(in_shape) >= 4:
352
+ scratch.layer4_rn = nn.Conv2d(
353
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
354
+ )
355
+ return scratch
356
+
357
+
358
class ResidualConvUnit(nn.Module):
    """Pre-activation residual unit: x + conv2(act(conv2_input)) over two 3x3 convs."""

    def __init__(self, features, activation, bn, groups=1):
        """
        Args:
            features (int): Number of input/output channels.
            activation (nn.Module): Activation applied before each conv.
            bn (bool): Stored flag; no norm layers are instantiated here.
            groups (int): Conv groups.
        """
        super().__init__()

        self.bn = bn
        self.groups = groups
        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        # Placeholders only; kept None so the norm branches below are no-ops.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        # FloatFunctional keeps the residual add quantization-friendly.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Apply act->conv1->(norm)->act->conv2->(norm) and add the input back."""
        branch = self.conv1(self.activation(x))
        if self.norm1 is not None:
            branch = self.norm1(branch)

        branch = self.conv2(self.activation(branch))
        if self.norm2 is not None:
            branch = self.norm2(branch)

        return self.skip_add.add(branch, x)
401
+
402
+
403
class FeatureFusionBlock(nn.Module):
    """Fuses a coarse feature map with an optional skip branch, then upsamples."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """
        Args:
            features (int): Number of input channels.
            activation (nn.Module): Activation used inside the residual units.
            deconv (bool): Stored flag (unused in forward).
            bn (bool): Passed through to the residual units.
            expand (bool): If True, halve the channel count in the output conv.
            align_corners (bool): Interpolation flag for the upsample.
            size: Fixed output size for the upsample, if any.
            has_residual (bool): Whether a skip input (xs[1]) is fused in.
            groups (int): Conv groups.
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand

        out_features = features // 2 if expand else features
        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )

        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """
        Fuse xs[0] (and xs[1] when has_residual), resize, and project.

        Returns:
            Tensor: fused, resized feature map.
        """
        fused = xs[0]

        if self.has_residual:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        fused = self.resConfUnit2(fused)

        # Explicit `size` wins, then the configured self.size, else 2x upsample.
        if size is not None:
            resize_kwargs = {"size": size}
        elif self.size is not None:
            resize_kwargs = {"size": self.size}
        else:
            resize_kwargs = {"scale_factor": 2}

        fused = custom_interpolate(fused, **resize_kwargs, mode="bilinear", align_corners=self.align_corners)
        return self.out_conv(fused)
471
+
472
+
473
def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Interpolate like nn.functional.interpolate, but chunk along the batch
    dimension when the output would exceed an INT_MAX-related element limit.
    """
    if size is None:
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    INT_MAX = 1610612736
    total_elements = x.shape[0] * x.shape[1] * size[0] * size[1]

    if total_elements <= INT_MAX:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)

    # Split into enough batch chunks that each interpolate call stays small.
    num_chunks = total_elements // INT_MAX + 1
    resized = [
        nn.functional.interpolate(piece, size=size, mode=mode, align_corners=align_corners)
        for piece in torch.chunk(x, chunks=num_chunks, dim=0)
    ]
    return torch.cat(resized, dim=0).contiguous()
flow3r/models/flow_head/utils.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Convert a 2D position grid (H, W, 2) into sinusoidal embeddings (H, W, embed_dim).

    Args:
        pos_grid: Tensor of shape (H, W, 2) of 2D coordinates.
        embed_dim: Output channel dimension (split evenly between x and y).
        omega_0: Base frequency passed through to make_sincos_pos_embed.

    Returns:
        Tensor of shape (H, W, embed_dim).
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    coords = pos_grid.reshape(-1, grid_dim)  # (H*W, 2)

    # Half of the channels encode x, the other half y.
    half = embed_dim // 2
    emb_x = make_sincos_pos_embed(half, coords[:, 0], omega_0=omega_0)  # (H*W, half)
    emb_y = make_sincos_pos_embed(half, coords[:, 1], omega_0=omega_0)  # (H*W, half)

    return torch.cat([emb_x, emb_y], dim=-1).view(H, W, embed_dim)
+ return emb.view(H, W, embed_dim) # [H, W, D]
34
+
35
+
36
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    Generate a 1D sine/cosine positional embedding for the given positions.

    Args:
        embed_dim: Embedding dimension (must be even; half sin, half cos).
        pos: Tensor of positions, flattened to (M,).
        omega_0: Base of the geometric frequency progression.

    Returns:
        Float tensor of shape (M, embed_dim): [sin(pos*w), cos(pos*w)].
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2

    # Frequencies omega_0 ** (-i / half), computed in double for accuracy.
    freqs = 1.0 / omega_0 ** (torch.arange(half, dtype=torch.double, device=pos.device) / half)

    # Outer product of positions with frequencies -> phase matrix (M, half).
    phases = torch.einsum("m,d->md", pos.reshape(-1), freqs)

    return torch.cat([torch.sin(phases), torch.cos(phases)], dim=1).float()
+ return emb.float()
60
+
61
+
62
+ # Inspired by https://github.com/microsoft/moge
63
+
64
+
65
+ def create_uv_grid(
66
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
67
+ ) -> torch.Tensor:
68
+ """
69
+ Create a normalized UV grid of shape (width, height, 2).
70
+
71
+ The grid spans horizontally and vertically according to an aspect ratio,
72
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
73
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
74
+
75
+ Args:
76
+ width (int): Number of points horizontally.
77
+ height (int): Number of points vertically.
78
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
79
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
80
+ device (torch.device, optional): Device on which the tensor is created.
81
+
82
+ Returns:
83
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
84
+ """
85
+ # Derive aspect ratio if not explicitly provided
86
+ if aspect_ratio is None:
87
+ aspect_ratio = float(width) / float(height)
88
+
89
+ # Compute normalized spans for X and Y
90
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
91
+ span_x = aspect_ratio / diag_factor
92
+ span_y = 1.0 / diag_factor
93
+
94
+ # Establish the linspace boundaries
95
+ left_x = -span_x * (width - 1) / width
96
+ right_x = span_x * (width - 1) / width
97
+ top_y = -span_y * (height - 1) / height
98
+ bottom_y = span_y * (height - 1) / height
99
+
100
+ # Generate 1D coordinates
101
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
102
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
103
+
104
+ # Create 2D meshgrid (width x height) and stack into UV
105
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
106
+ uv_grid = torch.stack((uu, vv), dim=-1)
107
+
108
+ return uv_grid
flow3r/models/layers/attention.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+ import torch
17
+
18
+ from torch.nn.functional import scaled_dot_product_attention
19
+ from torch.nn.attention import SDPBackend
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ # warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ # warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ # warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
class Attention(nn.Module):
    """Plain multi-head self-attention with an explicit softmax(QK^T)V path."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        """Self-attend over x of shape (B, N, C). `attn_bias` is ignored here."""
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
+ return x
70
+
71
+
72
class MemEffAttention(Attention):
    """Attention variant that routes through xFormers memory_efficient_attention."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # attn_bias is guaranteed None here, so the base path is equivalent.
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.unbind(dim=2)

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        out = out.reshape([B, N, C])

        return self.proj_drop(self.proj(out))
+ return x
91
+
92
+
93
+
94
class FlashAttention(Attention):
    """Attention variant using torch SDPA (flash backend when inputs are bf16)."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        # NOTE(review): attn_bias is accepted but never forwarded to SDPA —
        # confirm callers never pass a bias on this path.
        B, N, C = x.shape
        # (B, N, 3C) -> (B, heads, 3, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3)
        q, k, v = (qkv[:, :, i] for i in range(3))

        backends = (
            SDPBackend.FLASH_ATTENTION
            if q.dtype == torch.bfloat16
            else [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]
        )
        with nn.attention.sdpa_kernel(backends):
            out = scaled_dot_product_attention(q, k, v)

        out = out.transpose(1, 2).reshape([B, N, C])
        return self.proj_drop(self.proj(out))
+ return x
114
+
115
+
116
+ """
117
+ Following is written by GPT-4o
118
+ """
119
class CrossAttentionRope(nn.Module):
    """Multi-head cross-attention with optional rotary embedding and QK norm."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        qk_norm: bool = False,
        norm_layer: nn.Module = nn.LayerNorm,
        rope=None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        # Separate projections so query and key/value may come from different sources.
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)

        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.rope = rope

    def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
        """
        Args:
            query: (B, N, C) query tokens.
            key: (B, M, C) key tokens.
            value: (B, M, C) value tokens.
            attn_bias: optional additive attention bias.
            qpos/kpos: positions passed to the rotary embedding, if configured.

        Returns:
            (B, N, C) attended output.
        """
        B, N, C = query.shape
        M = key.shape[1]
        head_dim = C // self.num_heads

        # Project and reshape to (B, heads, seq, head_dim).
        q = self.q_proj(query).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        k = self.k_proj(key).reshape(B, M, self.num_heads, head_dim).permute(0, 2, 1, 3)
        v = self.v_proj(value).reshape(B, M, self.num_heads, head_dim).permute(0, 2, 1, 3)
        q = self.q_norm(q).to(v.dtype)
        k = self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, qpos)
            k = self.rope(k, kpos)

        scores = (q * self.scale) @ k.transpose(-2, -1)  # (B, heads, N, M)
        if attn_bias is not None:
            scores = scores + attn_bias

        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
+ return x
192
+
193
+
194
class MemEffCrossAttentionRope(CrossAttentionRope):
    """Cross-attention with RoPE using xFormers memory_efficient_attention."""

    def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
        """
        Args:
            query: (B, N, C) query tokens.
            key: (B, M, C) key tokens.
            value: (B, M, C) value tokens.
            attn_bias: optional attention bias (requires xFormers).
            qpos/kpos: positions passed to the rotary embedding, if configured.

        Returns:
            (B, N, C) attended output.
        """
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # BUG FIX: qpos/kpos were previously dropped on this fallback path,
            # silently disabling the rotary embedding when xFormers is missing.
            return super().forward(query, key, value, attn_bias, qpos=qpos, kpos=kpos)

        B, N, C = query.shape
        _, M, _ = key.shape

        # Project query, key, and value to (B, seq, heads, head_dim).
        q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads)
        k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads)
        v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads)

        # QK norm and rope operate on (B, heads, seq, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, qpos)
            k = self.rope(k, kpos)

        # xFormers expects (B, seq, heads, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+ return x
237
+
238
class FlashCrossAttentionRope(CrossAttentionRope):
    """Cross-attention with RoPE routed through torch SDPA (flash backend for bf16)."""

    def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
        B, N, C = query.shape
        M = key.shape[1]
        head_dim = C // self.num_heads

        # Project query/key/value and move to (B, num_heads, seq, head_dim).
        q = self.q_proj(query).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        k = self.k_proj(key).reshape(B, M, self.num_heads, head_dim).permute(0, 2, 1, 3)
        v = self.v_proj(value).reshape(B, M, self.num_heads, head_dim).permute(0, 2, 1, 3)

        q = self.q_norm(q).to(v.dtype)
        k = self.k_norm(k).to(v.dtype)
        if self.rope is not None:
            q = self.rope(q, qpos)
            k = self.rope(k, kpos)

        # Dropout only during training, matching nn.Dropout semantics.
        dropout_p = self.attn_drop.p if self.training else 0.0

        backends = (
            SDPBackend.FLASH_ATTENTION
            if q.dtype == torch.bfloat16
            else [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]
        )
        with nn.attention.sdpa_kernel(backends):
            out = scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=dropout_p)

        out = out.transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
+ return x
271
+
272
class AttentionRope(nn.Module):
    """Multi-head self-attention with optional rotary embedding and QK norm."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        qk_norm: bool = False,
        norm_layer: nn.Module = nn.LayerNorm,
        rope=None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()

        self.rope = rope

    def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
        """
        Self-attend over x of shape (B, N, C); `xpos` feeds the rotary
        embedding when one is configured. `attn_bias` is ignored here.
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q).to(v.dtype)
        k = self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, xpos)
            k = self.rope(k, xpos)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
+ return x
320
+
321
+
322
class MemEffAttentionRope(AttentionRope):
    """Self-attention using xFormers' memory-efficient kernel.

    Falls back to the parent's plain attention when xFormers is unavailable.
    """

    def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # Fixed: the original fallback called `super().forward(x)` and
            # dropped `xpos`, so RoPE was silently skipped (or crashed) on the
            # non-xFormers path.  Forward it so both paths behave the same.
            return super().forward(x, xpos=xpos)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        # (B, N, 3, heads, hd) -> (B, heads, 3, N, hd), then split q/k/v.
        qkv = qkv.transpose(1, 3)
        q, k, v = [qkv[:, :, i] for i in range(3)]
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, xpos)
            k = self.rope(k, xpos)

        # memory_efficient_attention expects a (B, N, heads, head_dim) layout.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
355
+
356
+
357
class FlashAttentionRope(AttentionRope):
    """Self-attention routed through torch's fused SDPA kernels.

    NOTE(review): unlike FlashCrossAttentionRope, this class does not pass
    `attn_drop` or `attn_bias` to scaled_dot_product_attention — attention
    dropout and masks are silently ignored here; confirm this is intended.
    """

    def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
        B, N, C = x.shape
        # Fused qkv projection: (B, N, 3C) -> (B, heads, 3, N, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3)

        # q, k, v = unbind(qkv, 2)
        q, k, v = [qkv[:,:,i] for i in range(3)]
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            # Rotary position embedding on queries and keys.
            q = self.rope(q, xpos)
            k = self.rope(k, xpos)

        if q.dtype == torch.bfloat16:
            # bf16 inputs can use the flash-attention backend.
            with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                x = scaled_dot_product_attention(q, k, v)
        else:
            # Other dtypes fall back to math / memory-efficient backends.
            with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                x = scaled_dot_product_attention(q, k, v)

        # Back to the (B, N, C) token layout.
        x = x.transpose(1, 2).reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
382
+
383
def get_attn_score(blk_class, x, frame_num, token_length, xpos=None):
    """Compute a per-frame attention score from a block's attention layer.

    Re-runs the block's qkv projection (including qk-norm and RoPE) on the
    pre-normed tokens, then reduces the raw (pre-softmax) attention logits:
    sum over heads, mean over the token dimensions of each frame pair, sum
    over target frames — giving one score per frame.

    Assumes N == frame_num * token_length — TODO confirm with callers.
    """
    x = blk_class.norm1(x)

    B, N, C = x.shape
    qkv = blk_class.attn.qkv(x).reshape(B, N, 3, blk_class.attn.num_heads, C // blk_class.attn.num_heads)

    # (B, N, 3, heads, hd) -> (B, heads, 3, N, hd)
    qkv = qkv.transpose(1, 3)
    # q, k, v = unbind(qkv, 2)
    q, k, v = [qkv[:,:,i] for i in range(3)]
    q, k = blk_class.attn.q_norm(q).to(v.dtype), blk_class.attn.k_norm(k).to(v.dtype)

    if blk_class.attn.rope is not None:
        q = blk_class.attn.rope(q, xpos)
        k = blk_class.attn.rope(k, xpos)

    # -> (B, N, heads, hd); permuted back to (B, heads, N, hd) in the matmul below.
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    # (B, heads, N, N) logits -> sum heads -> (B, N, N)
    # -> (B, frame, tok, frame, tok) -> mean over tokens -> sum over target frames.
    score = (q.permute(0, 2, 1, 3) * blk_class.attn.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(B, frame_num, token_length, frame_num, token_length).mean(dim=[2, 4]).sum(-1)

    return score
flow3r/models/layers/block.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention, CrossAttentionRope, MemEffCrossAttentionRope, FlashAttentionRope
19
+ from ..dinov2.layers.drop_path import DropPath
20
+ from ..dinov2.layers.layer_scale import LayerScale
21
+ from ..dinov2.layers.mlp import Mlp
22
+
23
+
24
# xFormers can be disabled explicitly by setting the XFORMERS_DISABLED env var;
# otherwise availability is probed by attempting the import below.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Block)")
    else:
        # warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Block)")
37
+
38
+
39
class Block(nn.Module):
    """Standard pre-norm transformer block (attention + MLP, residual adds).

    Supports LayerScale (enabled via `init_values`) and stochastic depth
    (`drop_path`); above a 0.1 drop rate a batched sample-dropping scheme
    replaces per-sample masking for efficiency.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        # LayerScale only when init_values is given; DropPath only when rate > 0.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Fixed: the FFN branch now uses drop_path2 (was drop_path1, per the
            # old FIXME).  Behavior is equivalent — both modules share the same
            # rate and sample independently — but this matches the intent.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
112
+
113
+
114
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Batch-level stochastic depth.

    Runs the residual branch on a random subset of the batch only, then
    adds it back rescaled by batch/subset so the expectation is unchanged.
    """
    batch = x.shape[0]

    # 1) choose a random subset of samples (never empty).
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_indices = torch.randperm(batch, device=x.device)[:keep]

    # 2) residual branch runs on the kept samples only.
    branch_out = residual_func(x[kept_indices]).flatten(1)

    # 3) scatter-add the rescaled residual back into the full batch.
    scale = batch / keep
    result = torch.index_add(
        x.flatten(1), 0, kept_indices, branch_out.to(dtype=x.dtype), alpha=scale
    )
    return result.view_as(x)
136
+
137
+
138
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Return (random subset indices, rescale factor) for stochastic depth.

    The subset has size max(int(B * (1 - ratio)), 1); the factor B/subset
    keeps the residual's expected magnitude unchanged.
    """
    batch = x.shape[0]
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    subset = torch.randperm(batch, device=x.device)[:keep]
    return subset, batch / keep
144
+
145
+
146
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add a per-subset residual back into `x` at rows `brange`.

    Without a scaling vector the result is returned flattened to (B, N*D)
    (callers view_as afterwards); with one, xFormers' fused scaled_index_add
    also applies the LayerScale gamma and keeps the original shape.
    """
    if scaling_vector is None:
        # Pure-torch path: index_add on flattened tokens.
        return torch.index_add(
            x.flatten(1), 0, brange, residual.flatten(1).to(dtype=x.dtype), alpha=residual_scale_factor
        )
    # Fused xFormers path (residual scaled per-channel by scaling_vector).
    return scaled_index_add(
        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
    )
156
+
157
+
158
# Module-level cache mapping a tuple of (batch_size, seq_len) pairs to the
# xFormers BlockDiagonalMask built for that configuration.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    # Effective batch size per tensor: the selected subset size when branges
    # is given, otherwise the full batch.
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # One sequence-length entry per (sample, tensor) pair; the resulting
        # block-diagonal mask prevents attention across samples.
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Fused select+concat into a single (1, total_tokens, D) tensor.
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        # Concatenate everything along the token dimension at batch size 1.
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
183
+
184
+
185
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Stochastic-depth residual add over a list of token tensors.

    Each tensor gets its own random sample subset; the subsets are packed
    into one ragged batch (with a block-diagonal attention bias), the
    residual branch runs once over the packed batch, and the pieces are
    scattered back into their inputs.

    Fixed: the return annotation said `Tensor`, but the function returns one
    tensor per input (a list).
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func once over the packed batch, split per input
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs
206
+
207
+
208
class NestedTensorBlock(Block):
    """Block that can also process a list of token tensors in one pass by
    packing them into a single "nested" sequence with a block-diagonal
    attention mask (requires xFormers)."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Fixed: this line checked isinstance(self.ls1, ...) while
                # reading self.ls2.gamma.  ls1/ls2 are constructed together so
                # behavior was unchanged, but the check now gates the value it
                # actually reads.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            # Pack, run both residual branches once, then split back per input.
            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch: single Tensor -> plain Block.forward; list -> nested path."""
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
258
+
259
class BlockRope(nn.Module):
    """Pre-norm transformer block whose attention supports QK-norm and RoPE.

    Same structure as `Block`, but `qk_norm`/`rope` are forwarded to the
    attention class and token positions (`xpos`) flow through `forward`.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool=False,
        rope=None
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            rope=rope
        )

        # LayerScale only when init_values is given; DropPath only when rate > 0.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, xpos=None) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), xpos=xpos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Fixed: FFN branch now uses drop_path2 (was drop_path1, per the old
            # FIXME).  Equivalent behavior since both modules share the same rate.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
336
+
337
+
338
class CrossBlockRope(nn.Module):
    """Decoder-style transformer block: self-attention, then cross-attention
    to a second token stream `y`, then an MLP — each branch pre-normed,
    optionally LayerScaled, and added residually.  RoPE positions
    (`xpos`/`ypos`) are forwarded to both attention modules."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        init_values=None,
        qk_norm: bool=False,
        rope=None
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        # Self-attention branch (output scaled by ls1).
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            rope=rope,
            qk_norm=qk_norm
        )

        # Cross-attention branch.  NOTE: ls_y scales the cross-attention
        # output while ls2 scales the MLP output (see forward).
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.ls_y = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm_y = norm_layer(dim)
        self.cross_attn = cross_attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            rope=rope,
            qk_norm=qk_norm
        )

        # MLP branch.
        self.norm3 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            bias=ffn_bias,
        )

    def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor:
        """x: query tokens; y: context tokens (keys/values for cross-attn)."""
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), xpos=xpos))

        def cross_attn_residual_func(x: Tensor, y: Tensor) -> Tensor:
            # Query side is normed by norm2; `y` is expected already normed.
            return self.ls_y(self.cross_attn(self.norm2(x), y, y, qpos=xpos, kpos=ypos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm3(x)))

        x = x + attn_residual_func(x)
        # Normalize the context once, outside the residual helper.
        y_ = self.norm_y(y)
        x = x + cross_attn_residual_func(x, y_)
        x = x + ffn_residual_func(x)

        return x
flow3r/models/layers/camera_head.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from copy import deepcopy
4
+ import torch.nn.functional as F
5
+
6
+ # code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172'
7
class ResConvBlock(nn.Module):
    """
    Residual token-wise MLP block (historically 1x1 convolutions, now
    implemented with nn.Linear on the channel/last dimension).

    NOTE(review): when in_channels != out_channels the skip path is a
    nn.Conv2d, but forward feeds it the same (B, tokens, C) tensor the
    Linear layers consume — that combination would fail at runtime; only
    the equal-channel (Identity skip) configuration is exercised by
    CameraHead.  Confirm before using unequal channel counts.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Identity skip when channel counts match, 1x1 conv projection otherwise.
        if self.in_channels == self.out_channels:
            self.head_skip = nn.Identity()
        else:
            self.head_skip = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)

        # Three token-wise linear layers (each equivalent to a 1x1 conv).
        self.res_conv1 = nn.Linear(self.in_channels, self.out_channels)
        self.res_conv2 = nn.Linear(self.out_channels, self.out_channels)
        self.res_conv3 = nn.Linear(self.out_channels, self.out_channels)

    def forward(self, res):
        hidden = F.relu(self.res_conv1(res))
        hidden = F.relu(self.res_conv2(hidden))
        hidden = F.relu(self.res_conv3(hidden))
        return self.head_skip(res) + hidden
31
+
32
class CameraHead(nn.Module):
    """Regress a per-view 4x4 camera pose from patch-token features.

    Two residual MLP blocks refine the tokens, 2D average pooling collapses
    them to one vector per view, and two linear heads predict a 3-vector
    translation plus a 9D rotation that is projected onto SO(3) via SVD.
    """

    def __init__(self, dim=512):
        super().__init__()
        output_dim = dim
        # Two token-wise residual refinement blocks.
        self.res_conv = nn.ModuleList([deepcopy(ResConvBlock(output_dim, output_dim))
                                    for _ in range(2)])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.more_mlps = nn.Sequential(
            nn.Linear(output_dim,output_dim),
            nn.ReLU(),
            nn.Linear(output_dim,output_dim),
            nn.ReLU()
        )
        # Separate heads: translation (3) and 9D rotation representation.
        self.fc_t = nn.Linear(output_dim, 3)
        self.fc_rot = nn.Linear(output_dim, 9)

    def forward(self, feat, patch_h, patch_w):
        # feat: (B*N, patch_h*patch_w, C) patch tokens; returns (B*N, 4, 4).
        BN, hw, c = feat.shape

        for i in range(2):
            feat = self.res_conv[i](feat)

        # feat = self.avgpool(feat)
        # Reshape tokens back to a (BN, C, H, W) map so the 2D average pool
        # collapses the spatial dimensions to a single feature vector.
        feat = self.avgpool(feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous()) ##########
        feat = feat.view(feat.size(0), -1)

        feat = self.more_mlps(feat)  # [B, D_]
        # Pose regression is forced to fp32 even under autocast, for stability.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            out_t = self.fc_t(feat.float())  # [B,3]
            out_r = self.fc_rot(feat.float())  # [B,9]
            pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device)

        return pose

    def convert_pose_to_4x4(self, B, out_r, out_t, device):
        # Assemble [R|t] into homogeneous (B, 4, 4) matrices.
        out_r = self.svd_orthogonalize(out_r)  # [N,3,3]
        pose = torch.zeros((B, 4, 4), device=device)
        pose[:, :3, :3] = out_r
        pose[:, :3, 3] = out_t
        pose[:, 3, 3] = 1.
        return pose

    def svd_orthogonalize(self, m):
        """Convert 9D representation to SO(3) using SVD orthogonalization.

        Args:
          m: [BATCH, 3, 3] 3x3 matrices.

        Returns:
          [BATCH, 3, 3] SO(3) rotation matrices.
        """
        if m.dim() < 3:
            m = m.reshape((-1, 3, 3))
        # Row-normalize, then orthogonalize via SVD.
        # NOTE(review): torch.svd is deprecated in favor of torch.linalg.svd;
        # kept as-is here to preserve exact behavior.
        m_transpose = torch.transpose(torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2)
        u, s, v = torch.svd(m_transpose)
        det = torch.det(torch.matmul(v, u.transpose(-2, -1)))
        # Check orientation reflection.
        r = torch.matmul(
            torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2),
            u.transpose(-2, -1)
        )
        return r
flow3r/models/layers/pos_embed.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Position embedding utils
7
+ # --------------------------------------------------------
8
+
9
+
10
+
11
+ import numpy as np
12
+
13
+ import torch
14
+
15
+ # --------------------------------------------------------
16
+ # 2D sine-cosine position embedding
17
+ # References:
18
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
20
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
21
+ # --------------------------------------------------------
22
def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
    """Build a (grid_size**2, embed_dim) 2D sine-cosine position embedding.

    When n_cls_token > 0, that many all-zero rows are prepended for the
    extra (class) tokens, giving (n_cls_token + grid_size**2, embed_dim).
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid with w first, matching the original MAE implementation.
    grid = np.stack(np.meshgrid(coords, coords), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if n_cls_token > 0:
        cls_rows = np.zeros([n_cls_token, embed_dim])
        pos_embed = np.concatenate([cls_rows, pos_embed], axis=0)
    return pos_embed
38
+
39
+
40
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-channel coordinate grid into (H*W, embed_dim).

    The first half of the channels encodes grid[0] (h), the second half
    grid[1] (w).
    """
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
49
+
50
+
51
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: array of positions to encode, any shape (flattened to (M,))
    returns: (M, embed_dim) — sin in the first half, cos in the second
    """
    assert embed_dim % 2 == 0
    # Frequencies 1 / 10000^(i / (D/2)) for i = 0 .. D/2 - 1.
    freqs = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=float) / (embed_dim / 2.0))

    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
70
+
71
+
72
+ # --------------------------------------------------------
73
+ # Interpolate position embeddings for high-resolution
74
+ # References:
75
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
76
+ # DeiT: https://github.com/facebookresearch/deit
77
+ # --------------------------------------------------------
78
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's 2D position embedding to the model's patch grid.

    Mutates `checkpoint_model['pos_embed']` in place: extra tokens (cls/dist)
    are kept untouched; only the patch-position rows are bicubically
    resampled.  Assumes a square patch grid on both sides.
    """
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        # Number of non-patch tokens prepended to the sequence (e.g. cls).
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            # (1, H*W, D) -> (1, D, H, W) for 2D interpolation.
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            # Back to (1, H'*W', D) and re-attach the extra tokens.
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
100
+
101
+
102
+ #----------------------------------------------------------
103
+ # RoPE2D: RoPE implementation in 2D
104
+ #----------------------------------------------------------
105
+
106
try:
    # Prefer the CUDA-compiled RoPE2D kernel when the extension is installed.
    from models.curope import cuRoPE2D
    RoPE2D = cuRoPE2D
except ImportError:
    print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')

    class RoPE2D(torch.nn.Module):
        """Pure-PyTorch fallback for 2D rotary position embedding.

        The feature dimension is split in half: the first half is rotated by
        each token's y position, the second half by its x position.
        """

        def __init__(self, freq=100.0, F0=1.0):
            super().__init__()
            self.base = freq
            self.F0 = F0
            # Memoized (cos, sin) tables keyed by (dim, seq_len, device, dtype).
            self.cache = {}

        def get_cos_sin(self, D, seq_len, device, dtype):
            # Build (once) the cos/sin tables for all positions < seq_len.
            if (D,seq_len,device,dtype) not in self.cache:
                inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
                t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
                freqs = torch.cat((freqs, freqs), dim=-1)
                cos = freqs.cos()  # (Seq, Dim)
                sin = freqs.sin()
                self.cache[D,seq_len,device,dtype] = (cos,sin)
            return self.cache[D,seq_len,device,dtype]

        @staticmethod
        def rotate_half(x):
            # (a, b) -> (-b, a) on the two halves of the last dimension.
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rope1d(self, tokens, pos1d, cos, sin):
            # Gather per-token cos/sin rows and apply the complex rotation.
            assert pos1d.ndim==2
            cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
            sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
            return (tokens * cos) + (self.rotate_half(tokens) * sin)

        def forward(self, tokens, positions):
            """
            input:
                * tokens: batch_size x nheads x ntokens x dim
                * positions: batch_size x ntokens x 2 (y and x position of each token)
            output:
                * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
            """
            assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
            D = tokens.size(3) // 2
            assert positions.ndim==3 and positions.shape[-1] == 2  # Batch, Seq, 2
            cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
            # split features into two along the feature dimension, and apply rope1d on each half
            y, x = tokens.chunk(2, dim=-1)
            y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
            x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
            tokens = torch.cat((y, x), dim=-1)
            return tokens
160
+
161
+ # patch embedding
162
class PositionGetter(object):
    """Produce (y, x) grid positions for every patch token, with caching."""

    def __init__(self):
        # Maps (h, w) -> flattened (h*w, 2) coordinate tensor.
        self.cache_positions = {}

    def __call__(self, b, h, w, device):
        key = (h, w)
        if key not in self.cache_positions:
            rows = torch.arange(h, device=device)
            cols = torch.arange(w, device=device)
            # All (y, x) pairs in row-major order: shape (h*w, 2).
            self.cache_positions[key] = torch.cartesian_prod(rows, cols)
        grid = self.cache_positions[key].view(1, h * w, 2)
        return grid.expand(b, -1, 2).clone()
flow3r/models/layers/transformer_head.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .attention import FlashAttentionRope, FlashCrossAttentionRope
2
+ from .block import BlockRope, CrossBlockRope
3
+ from ..dinov2.layers import Mlp
4
+ import torch
5
+ import torch.nn as nn
6
+ from functools import partial
7
+ from torch.utils.checkpoint import checkpoint
8
+ import torch.nn.functional as F
9
+ from flow3r.models.flow_head.utils import create_uv_grid, position_grid_to_embed
10
+
11
class TransformerDecoder(nn.Module):
    """Self-attention decoder: project input tokens, run `depth` RoPE
    transformer blocks, and map the result to `out_dim` with a linear layer."""

    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        super().__init__()

        # Optional input projection; Identity when the caller already supplies
        # features of width dec_embed_dim.
        self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            BlockRope(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                attn_class=FlashAttentionRope,
                rope=rope
            ) for _ in range(depth)])

        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None, return_intermediate=False):
        """Run the decoder.

        Args:
            hidden: (B, S, in_dim) input tokens.
            xpos: optional RoPE positions forwarded to every block.
            return_intermediate: if True, additionally return the hidden states
                of the last four blocks (before the output projection).

        Returns:
            (B, S, out_dim) tensor, or (out, [last-4 hidden states]) when
            return_intermediate is True.
        """
        hidden = self.projects(hidden)
        intermediate = []
        for i, blk in enumerate(self.blocks):
            # Gradient checkpointing only during training: trades compute for memory.
            if self.use_checkpoint and self.training:
                hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False)
            else:
                hidden = blk(hidden, xpos=xpos)

            if return_intermediate:
                intermediate.append(hidden)

        out = self.linear_out(hidden)

        if return_intermediate:
            return out, intermediate[-4:]

        return out
68
+
69
class LinearPts3d (nn.Module):
    """
    Linear prediction head (dust3r-style).
    Every token is projected to a patch_size x patch_size tile of `output_dim`
    channels; tiles are reassembled into a full-resolution map by pixel shuffle.
    """

    def __init__(self, patch_size, dec_embed_dim, output_dim=3,):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Linear(dec_embed_dim, (output_dim)*self.patch_size**2)

    def forward(self, decout, img_shape):
        """Predict a (B, H, W, output_dim) map from the last decoder output.

        Args:
            decout: sequence of decoder outputs; only the last (B, S, D) is used.
            img_shape: (H, W) of the target image.
        """
        H, W = img_shape
        last = decout[-1]                       # (B, S, D)
        B = last.shape[0]
        grid_h, grid_w = H // self.patch_size, W // self.patch_size

        per_token = self.proj(last)             # (B, S, output_dim * patch_size^2)
        planes = per_token.transpose(-1, -2).reshape(B, -1, grid_h, grid_w)
        full = F.pixel_shuffle(planes, self.patch_size)  # (B, output_dim, H, W)

        return full.permute(0, 2, 3, 1)
100
+
101
class LinearFlow2d (nn.Module):
    """
    Linear 2D-flow head. For each image pair (i, j) it fuses the patch
    features of view i with the camera features of both views (tagged with a
    learned which-view embedding) through an MLP, then unpacks every token
    into a patch_size x patch_size tile of flow values.
    """

    def __init__(self, patch_size, dec_embed_dim, output_dim=2, camera_dim=512, num_heads=8, rope=None):
        super().__init__()
        self.patch_size = patch_size
        self.dec_embed_dim = dec_embed_dim
        self.camera_dim = camera_dim

        # Learned embeddings distinguishing the first (0) and second (1) camera.
        self.camera_pos_embed = nn.Parameter(torch.randn(2, 1, camera_dim))
        nn.init.normal_(self.camera_pos_embed, std=0.02)

        # Fuses [camera_i | camera_j | patch_i] down to dec_embed_dim features.
        self.mlp = nn.Sequential(
            nn.Linear(2*camera_dim + dec_embed_dim, 2*dec_embed_dim),
            nn.ReLU(),
            nn.Linear(2*dec_embed_dim, 2*dec_embed_dim),
            nn.ReLU(),
            nn.Linear(2*dec_embed_dim, dec_embed_dim),
        )

        # Per-token projection to a flattened pixel tile.
        self.proj = nn.Linear(dec_embed_dim, (output_dim)*self.patch_size**2)

    def forward(self, patch_hidden, camera_hidden, pair_indices, img_shape, B, N):
        """
        Args:
            patch_hidden: (B*N, hw, dec_embed_dim) motion decoder output.
            camera_hidden: (B*N, hw, camera_dim) camera decoder output.
            pair_indices: (B, S, 2) long tensor of per-batch (i, j) view indices.
            img_shape: (H, W).
            B, N: batch size and number of images per sample.

        Returns:
            (B, S, H, W, output_dim) flow maps, one per pair.

        Raises:
            ValueError: if pair_indices is not a 3-D tensor.
        """
        H, W = img_shape
        hw = patch_hidden.shape[1]

        # Recover the (batch, view) layout.
        per_view_patch = patch_hidden.reshape(B, N, hw, self.dec_embed_dim)
        per_view_camera = camera_hidden.reshape(B, N, hw, self.camera_dim)

        if not (isinstance(pair_indices, torch.Tensor) and pair_indices.dim() == 3):
            raise ValueError("Invalid pair_indices type")

        S = pair_indices.shape[1]
        rows = torch.arange(B, device=pair_indices.device).unsqueeze(1).expand(B, S)
        src_idx = pair_indices[:, :, 0]
        tgt_idx = pair_indices[:, :, 1]

        # Gather per-pair features and tag camera features by their view role.
        patch_feat = per_view_patch[rows, src_idx]                              # (B, S, hw, D)
        cam_src = per_view_camera[rows, src_idx] + self.camera_pos_embed[0]
        cam_tgt = per_view_camera[rows, tgt_idx] + self.camera_pos_embed[1]

        stacked = torch.cat([cam_src, cam_tgt, patch_feat], dim=-1)
        total_pairs = B * S
        fused = self.mlp(stacked.reshape(total_pairs, hw, 2*self.camera_dim + self.dec_embed_dim))

        # Unfold each token into its pixel tile and reassemble the image grid.
        tiles = self.proj(fused)                                                # (T, hw, out*p^2)
        grid_h, grid_w = H // self.patch_size, W // self.patch_size
        planes = tiles.transpose(-1, -2).reshape(total_pairs, -1, grid_h, grid_w)
        flow = F.pixel_shuffle(planes, self.patch_size)                         # (T, out, H, W)

        return flow.permute(0, 2, 3, 1).reshape(B, S, H, W, -1)
218
+
219
class DPTFlow2d (nn.Module):
    """
    Simplified DPT-style head for 2D flow with a single-layer input.
    Fuses per-view patch features with the camera features of both views of a
    pair via an MLP, then refines convolutionally at patch resolution and at
    full image resolution to produce a dense `output_dim`-channel flow map.
    """

    def __init__(self, patch_size, dec_embed_dim, output_dim=2, camera_dim=512, rope=None, features=256):
        super().__init__()
        self.patch_size = patch_size
        self.dec_embed_dim = dec_embed_dim
        self.camera_dim = camera_dim

        # MLP to fuse camera features (both views) and patch features.
        self.mlp = nn.Sequential(
            nn.Linear(2*camera_dim + dec_embed_dim, 2*dec_embed_dim),
            nn.ReLU(),
            nn.Linear(2*dec_embed_dim, 2*dec_embed_dim),
            nn.ReLU(),
            nn.Linear(2*dec_embed_dim, dec_embed_dim),
        )

        self.norm = nn.LayerNorm(dec_embed_dim)

        # 1x1 projection into the refinement width, then two conv stages:
        # refine_low at patch resolution, refine_high at full image resolution.
        self.project = nn.Conv2d(dec_embed_dim, features, kernel_size=1, stride=1, padding=0)
        self.refine_low = nn.Sequential(
            nn.Conv2d(features, features, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(features, features, 3, padding=1),
            nn.GELU(),
        )
        self.refine_high = nn.Sequential(
            nn.Conv2d(features, features, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(features, features, 3, padding=1),
            nn.GELU(),
        )
        self.out_head = nn.Sequential(
            nn.Conv2d(features, 64, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(64, output_dim, 1),
        )

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Add a (scaled) positional embedding, derived from a UV grid matching
        x's spatial size, to the feature map x of shape (B, C, h, w).
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        # ratio keeps the positional signal small relative to the features.
        pos_embed = pos_embed * ratio
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed


    def forward(self, patch_hidden, camera_hidden, pair_indices, img_shape, B, N):
        """
        Args:
            patch_hidden: (B*N, hw, dec_embed_dim) - motion decoder output
            camera_hidden: (B*N, hw, camera_dim) - camera decoder output
            pair_indices: Tensor of shape (B, S, 2), indices (i, j) relative to
                each batch element.
            img_shape: (H, W)
            B: batch size
            N: sequence length (number of images)
        Returns:
            flow: (B, S, H, W, output_dim)
        """
        H, W = img_shape
        hw = patch_hidden.shape[1]

        # Reshape from (B*N, hw, dim) to (B, N, hw, dim)
        patch_hidden = patch_hidden.reshape(B, N, hw, self.dec_embed_dim)
        camera_hidden = camera_hidden.reshape(B, N, hw, self.camera_dim)

        # Handle Tensor input (B, S, 2)
        S = pair_indices.shape[1]
        batch_idx = torch.arange(B, device=pair_indices.device).unsqueeze(1).expand(B, S)

        # Extract indices for i and j images: (B, S)
        idx_i = pair_indices[:, :, 0]
        idx_j = pair_indices[:, :, 1]

        # Extract patch features: (B, S, hw, dim)
        patch_feat = patch_hidden[batch_idx, idx_i]

        # Extract camera features: (B, S, hw, dim)
        camera_i = camera_hidden[batch_idx, idx_i]
        camera_j = camera_hidden[batch_idx, idx_j]
        # Concatenate camera features and patch features: (B, S, hw, 2*camera_dim + dec_embed_dim)
        concat_features = torch.cat([camera_i, camera_j, patch_feat], dim=-1)

        # Flatten B and S dimensions
        total_pairs = B * S
        input_features = concat_features.reshape(total_pairs, hw, 2*self.camera_dim + self.dec_embed_dim)

        # Apply MLP
        fused = self.mlp(input_features)  # (T, hw, dec_embed_dim)

        patch_h, patch_w = H // self.patch_size, W // self.patch_size
        assert hw == patch_h * patch_w, (hw, patch_h, patch_w)
        fused = self.norm(fused)
        feat = fused.transpose(1, 2).reshape(total_pairs, self.dec_embed_dim, patch_h, patch_w)  # (T,D,h,w)

        # Refine at patch resolution ...
        feat = self.project(feat)  # (T,features,h,w)
        feat = self._apply_pos_embed(feat, W, H)
        feat = self.refine_low(feat)

        # ... then upsample to full resolution and refine again.
        feat = F.interpolate(feat, size=(H, W), mode="bilinear", align_corners=True)
        feat = self._apply_pos_embed(feat, W, H)
        feat = self.refine_high(feat)

        flow = self.out_head(feat)  # (T,output_dim,H,W)
        return flow.permute(0, 2, 3, 1).reshape(B, S, H, W, -1)
342
+
343
class ContextTransformerDecoder(nn.Module):
    """Cross-attention decoder: tokens attend to a separately projected
    context sequence through `depth` RoPE cross-attention blocks, then are
    mapped to `out_dim` with a linear layer."""

    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
    ):
        super().__init__()

        # Separate projections for the query stream (x) and the context stream (y).
        self.projects_x = nn.Linear(in_dim, dec_embed_dim)
        self.projects_y = nn.Linear(in_dim, dec_embed_dim)

        self.blocks = nn.ModuleList([
            CrossBlockRope(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                attn_class=FlashAttentionRope,
                cross_attn_class=FlashCrossAttentionRope,
                rope=rope
            ) for _ in range(depth)])

        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, context, xpos=None, ypos=None):
        """Decode `hidden` conditioned on `context`.

        Args:
            hidden: (B, S, in_dim) query tokens.
            context: (B, T, in_dim) context tokens attended to by every block.
            xpos, ypos: optional RoPE positions for queries / context.

        Returns:
            (B, S, out_dim) tensor.
        """
        hidden = self.projects_x(hidden)
        context = self.projects_y(context)

        for i, blk in enumerate(self.blocks):
            hidden = blk(hidden, context, xpos=xpos, ypos=ypos)

        out = self.linear_out(hidden)

        # BUG FIX: the original computed `out` but never returned it, so
        # forward() always yielded None.
        return out
389
+
flow3r/utils/alignment.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import math
3
+ from collections import namedtuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.types
10
+ # import utils3d
11
+
12
+
13
def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
    """Scatter-reduce `src` along `dim` into `size` buckets, keeping the minimum.

    Returns a (values, indices) pair mimicking torch.min: the values tensor
    holds the minimum of all src entries scattered into each bucket (inf for
    empty buckets), and indices holds the position along `dim` in `src` where
    that minimum came from (-1 for empty buckets).
    """
    out_shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
    values = torch.full(out_shape, float('inf'), dtype=src.dtype, device=src.device)
    values = values.scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
    # Locate every src element equal to its bucket's minimum ...
    hits = torch.where(src == torch.gather(values, dim=dim, index=index))
    argmins = torch.full(out_shape, -1, dtype=torch.long, device=src.device)
    # ... and record its position along `dim` (ties overwrite each other).
    argmins[(*hits[:dim], index[hits], *hits[dim + 1:])] = hits[dim]
    return torch.return_types.min((values, argmins))
21
+
22
+
23
def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
    """Apply `fn` in chunks along the batch (first) dimension and concatenate.

    Tensor arguments (positional and keyword) are split into `chunk_size`-sized
    chunks along dim 0; non-tensor arguments are repeated for every chunk. The
    per-chunk results are concatenated along dim 0 (element-wise if `fn`
    returns a tuple).

    Args:
        fn: callable applied to each chunk.
        chunk_size: maximum batch size per call.
        *args, **kwargs: forwarded to `fn`; at least one must be a tensor.

    Returns:
        Concatenated result tensor, or tuple of concatenated tensors.
    """
    batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
    n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
    splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
    # BUG FIX: the per-key chunk sequences were wrapped in an extra one-element
    # list (`{k: [ ... ]}`), so `v[i]` indexed the wrapper instead of the i-th
    # chunk: IndexError for every chunk after the first.
    splited_kwargs = {k: v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks for k, v in kwargs.items()}
    results = []
    for i in range(n_chunks):
        chunk_args = tuple(arg[i] for arg in splited_args)
        chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
        results.append(fn(*chunk_args, **chunk_kwargs))

    if isinstance(results[0], tuple):
        return tuple(torch.cat(r, dim=0) for r in zip(*results))
    else:
        return torch.cat(results, dim=0)
38
+
39
+
40
+ def _pad_inf(x_: torch.Tensor):
41
+ return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)
42
+
43
+
44
+ def _pad_cumsum(cumsum: torch.Tensor):
45
+ return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)
46
+
47
+
48
+ def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
49
+ return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)
50
+
51
+
52
def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
    """
    If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.

    w_i must be >= 0.

    ### Parameters:
    - `x`: tensor of shape (..., n)
    - `y`: tensor of shape (..., n)
    - `w`: tensor of shape (..., n)
    - `trunc`: optional, float or tensor of shape (..., n) or None

    ### Returns:
    - `a`: tensor of shape (...), differentiable
    - `loss`: tensor of shape (...), value of loss function at `a`, detached
    - `index`: tensor of shape (...), where a = y[idx] / x[idx]
    """
    if trunc is None:
        # Weighted-L1 line fit: the optimum is attained at one of the ratios
        # y_i / x_i (a weighted median); find where the subgradient crosses zero.
        x, y, w = torch.broadcast_tensors(x, y, w)
        sign = torch.sign(x)
        x, y = x * sign, y * sign
        y_div_x = y / x.clamp_min(eps)
        y_div_x, argsort = y_div_x.sort(dim=-1)

        wx = torch.gather(x * w, dim=-1, index=argsort)
        # Subgradient of the objective at each sorted breakpoint.
        derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
        search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)

        a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
        index = argsort.gather(dim=-1, index=search).squeeze(-1)
        loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)

    else:
        # Reshape to (batch_size, n) for simplicity
        x, y, w = torch.broadcast_tensors(x, y, w)
        batch_shape = x.shape[:-1]
        batch_size = math.prod(batch_shape)
        x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])

        sign = torch.sign(x)
        x, y = x * sign, y * sign
        wx, wy = w * x, w * y
        xyw = torch.stack([x, y, w], dim=-1)  # Stacked for convenient gathering

        # Breakpoints of the truncated objective: A where the residual is zero,
        # B/C where w*|a*x - y| crosses the truncation level from below/above.
        y_div_x = A = y / x.clamp_min(eps)
        B = (wy - trunc) / wx.clamp_min(eps)
        C = (wy + trunc) / wx.clamp_min(eps)
        with torch.no_grad():
            # Calculate prefix sums in the orders of A, B, C
            A, A_argsort = A.sort(dim=-1)
            Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
            A, Q_A = _pad_inf(A), _pad_cumsum(Q_A)  # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.

            B, B_argsort = B.sort(dim=-1)
            Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
            B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)

            C, C_argsort = C.sort(dim=-1)
            Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
            C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)

            # Calculate left and right derivatives at each candidate a = y_i/x_i
            j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
            left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
            j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
            right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)

            # Find extrema (local minima of the piecewise-linear objective)
            is_extrema = (left_derivative < 0) & (right_derivative >= 0)
            is_extrema[..., 0] |= ~is_extrema.any(dim=-1)  # In case all derivatives are zero, take the first one as extrema.
            where_extrema_batch, where_extrema_index = torch.where(is_extrema)

            # Calculate objective value at extrema
            extrema_a = y_div_x[where_extrema_batch, where_extrema_index]  # (num_extrema,)
            MAX_ELEMENTS = 4096 ** 2  # Split into small batches to avoid OOM in case there are too many extrema. (~1G)
            SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
            extrema_value = torch.cat([
                _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
                for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
            ])  # (num_extrema,)

            # Find the minimum among each batch element's extrema
            minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value)  # (batch_size,)
            index = where_extrema_index[indices]

        # Recompute a by indexing so gradients flow through x and y.
        a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
        a = a.reshape(batch_shape)
        loss = minima.reshape(batch_shape)
        index = index.reshape(batch_shape)

    return a, loss, index
147
+
148
+
149
def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Solve for the single multiplicative scale aligning `depth_src` to
    `depth_tgt` under an (optionally truncated) weighted-L1 objective.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight`: non-negative weights of shape (..., N)
    - `trunc`: optional truncation level for the robust loss

    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    """
    scale, _loss, _index = align(depth_src, depth_tgt, weight, trunc)
    return scale
161
+
162
+
163
def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `depth_src` to `depth_tgt` with given constant weights.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc: float` or tensor of shape (..., N) or None

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (...).
    """
    dtype, device = depth_src.dtype, depth_src.device

    # Flatten batch dimensions for simplicity
    batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
    batch_size = math.prod(batch_shape)
    depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)

    # Here, we take anchors only for non-zero weights.
    # Although the results will be still correct even anchor points have zero weight,
    # it is wasting computation and may cause instability in some cases, e.g. too many extrema.
    anchors_where_batch, anchors_where_n = torch.where(weight > 0)

    # Stop gradient when solving optimal anchors
    with torch.no_grad():
        depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n]  # (anchors)
        depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n]  # (anchors)

        # Subtracting the anchor reduces the affine problem to a pure-scale
        # problem for each candidate anchor point.
        depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None]  # (anchors, n)
        depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None]  # (anchors, n)
        weight_anchored = weight[anchors_where_batch, :]  # (anchors, n)

        scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc)  # (anchors)

        # Keep, per batch element, the anchor with the lowest residual.
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss)  # (batch_size,)

        # Reproduce by indexing for shorter compute graph
        index_1 = anchors_where_n[index_anchor]  # (batch_size,)
        index_2 = index[index_anchor]  # (batch_size,)

    # Recompute scale/shift from the two defining points so gradients flow.
    tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
    shift = tgt_1 - scale * src_1

    scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)

    return scale, shift
215
+
216
def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
    """
    Affine (scale, shift) alignment of depth via iteratively reweighted least squares.

    Repeatedly solves the weighted normal equations and resets the weights to
    1/|residual|, approximating a robust L1 fit.

    ### Parameters:
    - `depth_src`, `depth_tgt`: tensors of shape (..., N)
    - `weight`: initial non-negative weights of shape (..., N)
    - `max_iter`: number of IRLS iterations (no early stopping)
    - `eps`: floor on residual magnitude when inverting to weights

    ### Returns:
    - `(scale, shift)`: tensors of shape (...)
    """
    w = weight
    design = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)  # (..., N, 2)
    target = depth_tgt

    for _ in range(max_iter):
        # Weighted normal equations: beta = (X^T W y) (X^T W X)^-T
        gram = design.transpose(-1, -2) @ (w[..., None] * design)
        beta = (design.transpose(-1, -2) @ (w * target)) @ gram.inverse().transpose(-2, -1)
        residual = target - (design @ beta[..., None])[..., 0]
        w = 1 / residual.abs().clamp_min(eps)

    return beta[..., 0], beta[..., 1]
231
+
232
+
233
def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Solve for a single scale aligning `points_src` to `points_tgt`.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N), broadcast over xyz

    ### Returns:
    - `scale: torch.Tensor` of shape (...). Only positive solutions are
      guaranteed; filter out negative scales before using it.
    """
    flat_src = points_src.flatten(-2)
    flat_tgt = points_tgt.flatten(-2)
    flat_w = weight[..., None].expand_as(points_src).flatten(-2)
    scale, _loss, _index = align(flat_src, flat_tgt, flat_w, trunc)
    return scale
249
+
250
+
251
def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
    It is similar to `align_affine` but scale and shift are applied to different dimensions.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weights: torch.Tensor` of shape (..., N)

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
    """
    dtype, device = points_src.dtype, points_src.device

    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Take anchors only where the weight is non-zero.
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)
    with torch.no_grad():
        # Anchors shift only in z (x/y components zeroed), matching the
        # scale-plus-z-shift model being solved.
        zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
        points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)
        points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)

        # Solve optimal scale and shift for each anchor
        MAX_ELEMENTS = 2 ** 20
        scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)

        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)

        # Reproduce by indexing for shorter compute graph.
        # index_2 addresses the flattened (n*3) coordinate array; index_1 is
        # the matching coordinate of the winning anchor point.
        index_2 = index[index_anchor]  # (batch_size,) [0, 3n)
        index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) [0, 3n)

    # Recompute scale/shift from the two defining coordinates so gradients flow.
    zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
    points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
    tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift
303
+
304
+
305
def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a shared scalar scale and a 3D (xyz) shift.
    It is similar to `align_affine` but scale and shift are applied to different dimensions.

    Strategy: every positively-weighted point is treated as a candidate anchor;
    a robust 1D alignment (`align`) is solved per anchor on the anchored
    residuals, and per batch element the anchor with the lowest loss wins.
    The winning scale/shift are then re-derived by direct indexing so the
    autograd graph stays short.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N); only points with weight > 0 serve as anchors
    - `trunc`: optional truncation threshold forwarded to `align`
    - `max_iters`, `eps`: unused here; kept for signature parity with sibling aligners

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3)
    """
    dtype, device = points_src.dtype, points_src.device  # NOTE(review): currently unused

    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Take anchors: one candidate per positively-weighted point
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)

    # Anchor coordinates are detached so gradients flow only through the residuals
    with torch.no_grad():
        points_src_anchor = points_src[anchor_where_batch, anchor_where_n]  # (anchors, 3)
        points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n]  # (anchors, 3)

    points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
    points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
    weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)

    # Solve optimal scale and shift for each anchor.
    # NOTE(review): `split_batch_fwd` presumably chunks the anchor batch to bound
    # peak memory at MAX_ELEMENTS — confirm; the sibling aligner divides by n here.
    MAX_ELEMENTS = 2 ** 20
    scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // 2, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)

    # Get optimal scale and shift for each batch element: keep the minimum-loss anchor
    loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)

    # Flat coordinate indices (into the length-3n flattened point arrays) of the
    # two coordinates that define the winning solution
    index_2 = index[index_anchor]  # (batch_size,) [0, 3n)
    index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) [0, 3n)

    # Reproduce by indexing for a shorter compute graph
    src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    # Scale from the two picked coordinates; the torch.where guards against 0/0
    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    # Shift re-derived at the winning anchor point (index_1 // 3 is its point index)
    shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)

    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift
356
+
357
+
358
def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Estimate a pure Z-axis translation aligning `points_src` to `points_tgt`.

    The scalar z-offset is solved by the shared `align` helper, applied to the
    per-point depth residuals, and embedded into a [0, 0, shift_z] vector.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align`

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3), with zero x/y components
    """
    dtype, device = points_src.dtype, points_src.device

    depth_residual = points_tgt[..., 2] - points_src[..., 2]
    z_shift, _, _ = align(torch.ones_like(points_src[..., 2]), depth_residual, weight, trunc)

    xy_zeros = torch.zeros_like(z_shift)
    return torch.stack([xy_zeros, xy_zeros, z_shift], dim=-1)
377
+
378
+
379
def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Estimate a full 3D translation aligning `points_src` to `points_tgt`.

    Each axis is solved independently by the shared `align` helper, applied to
    the per-axis coordinate residuals.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align`

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3)
    """
    dtype, device = points_src.dtype, points_src.device

    residuals = (points_tgt - points_src).swapaxes(-2, -1)
    ones = torch.ones_like(points_src).swapaxes(-2, -1)
    shift, _, _ = align(ones, residuals, weight[..., None, :], trunc)

    return shift
397
+
398
+
399
def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Solve `min sum_i w_i * (a * x_i + b - y_i) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.

    ### Parameters:
    - `x: torch.Tensor` of shape (..., N)
    - `y: torch.Tensor` of shape (..., N)
    - `w: torch.Tensor` of shape (..., N); uniform weights when None

    ### Returns:
    - `a: torch.Tensor` of shape (...,)
    - `b: torch.Tensor` of shape (...,)
    """
    w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
    # Weighted least squares: BOTH design-matrix columns must be scaled by
    # sqrt(w). The previous version left the constant column unweighted, which
    # solved `a * sqrt(w_i) * x_i + b = sqrt(w_i) * y_i` instead — wrong for
    # any non-uniform w. (Matches the weighting in align_affine_lstsq_z_shift.)
    A = torch.stack([w_sqrt * x, w_sqrt], dim=-1)
    B = (w_sqrt * y)[..., None]
    a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
    return a, b
417
+
418
+
419
def align_affine_lstsq_z_shift(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Least-squares fit of `y ~= a * x + [0, 0, shift_z]` for 3D point clouds.

    `a` is a single isotropic scale and the translation is constrained to the
    Z axis. Minimizes `sum_i w_i * ||a * x_i + b - y_i||^2` in closed form by
    stacking all three coordinates of every point into one linear system for
    `torch.linalg.lstsq`.

    ### Parameters:
    - `x: torch.Tensor` of shape (..., N, 3), source point cloud.
    - `y: torch.Tensor` of shape (..., N, 3), target point cloud.
    - `w: torch.Tensor` (optional) of shape (..., N), per-point weights;
      uniform weights when None.

    ### Returns:
    - `a: torch.Tensor` of shape (...,), scalar scale.
    - `b: torch.Tensor` of shape (..., 3), translation of the form [0, 0, shift_z].
    """
    if x.shape[-1] != 3 or y.shape[-1] != 3:
        raise ValueError("Input tensors x and y must have 3 features in the last dimension (X, Y, Z). "
                         f"Got x shape: {x.shape}, y shape: {y.shape}")
    if x.shape[:-1] != y.shape[:-1]:
        raise ValueError("Input tensors x and y must have matching shapes up to the last dimension. "
                         f"Got x shape: {x.shape}, y shape: {y.shape}")
    if w is not None and w.shape != x.shape[:-1]:
        raise ValueError("Weights w, if provided, must have shape (..., N) matching x and y's point dimensions. "
                         f"Got w shape: {w.shape}, x shape: {x.shape}")

    batch_dims = x.shape[:-2]
    n_pts = x.shape[-2]
    # Axis holding the points, counted from the front (so we can concatenate
    # the X/Y/Z rows of the stacked system along it).
    cat_dim = len(batch_dims)

    if w is None:
        sw = torch.ones(*batch_dims, n_pts, device=x.device, dtype=x.dtype)
    else:
        sw = w.sqrt()

    # Column multiplying the scalar scale: sqrt(w) * [x_X; x_Y; x_Z].
    scale_col = torch.cat([sw * x[..., :, c] for c in range(3)], dim=cat_dim)  # (..., 3N)
    # Column multiplying shift_z: zero for the X/Y rows, sqrt(w) for the Z rows.
    zero_rows = torch.zeros_like(sw)
    shift_col = torch.cat([zero_rows, zero_rows, sw], dim=cat_dim)  # (..., 3N)

    design = torch.stack([scale_col, shift_col], dim=-1)  # (..., 3N, 2)
    rhs = torch.cat([sw * y[..., :, c] for c in range(3)], dim=cat_dim)  # (..., 3N)

    solution = torch.linalg.lstsq(design, rhs.unsqueeze(-1))[0]  # (..., 2, 1)
    a_val = solution[..., 0, 0]
    shift_z = solution[..., 1, 0]

    zeros = torch.zeros_like(a_val)
    return a_val, torch.stack([zeros, zeros, shift_z], dim=-1)
flow3r/utils/basic.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import math
4
+ import cv2
5
+ from PIL import Image
6
+ import torch
7
+ from torchvision import transforms
8
+ from plyfile import PlyData, PlyElement
9
+ import numpy as np
10
+
11
def load_images_as_tensor(path='data/truck', interval=1, PIXEL_LIMIT=255000):
    """
    Loads images from a directory or video, resizes them to a uniform size,
    then converts and stacks them into a single [N, 3, H, W] PyTorch tensor.

    Args:
        path: directory of .png/.jpg/.jpeg files, or a .mp4 video file.
        interval: keep every `interval`-th image/frame.
        PIXEL_LIMIT: upper bound on W*H of the resized frames. The target size
            is snapped to multiples of 14 — presumably the ViT patch size,
            TODO confirm against the model config.

    Returns:
        torch.Tensor of shape [N, 3, H, W], values in [0, 1]; an empty tensor
        when nothing could be loaded.
    """
    sources = []

    # --- 1. Load image paths or video frames ---
    if osp.isdir(path):
        print(f"Loading images from directory: {path}")
        filenames = sorted([x for x in os.listdir(path) if x.lower().endswith(('.png', '.jpg', '.jpeg'))])
        for i in range(0, len(filenames), interval):
            img_path = osp.join(path, filenames[i])
            try:
                sources.append(Image.open(img_path).convert('RGB'))
            except Exception as e:
                # Best effort: skip unreadable files rather than aborting the batch
                print(f"Could not load image {filenames[i]}: {e}")
    elif path.lower().endswith('.mp4'):
        print(f"Loading frames from video: {path}")
        cap = cv2.VideoCapture(path)
        if not cap.isOpened(): raise IOError(f"Cannot open video file: {path}")
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret: break
            if frame_idx % interval == 0:
                # OpenCV decodes BGR; convert before handing to PIL
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                sources.append(Image.fromarray(rgb_frame))
            frame_idx += 1
        cap.release()
    else:
        raise ValueError(f"Unsupported path. Must be a directory or a .mp4 file: {path}")

    if not sources:
        print("No images found or loaded.")
        return torch.empty(0)

    print(f"Found {len(sources)} images/frames. Processing...")

    # --- 2. Determine a uniform target size for all images based on the first image ---
    # This is necessary to ensure all tensors have the same dimensions for stacking.
    first_img = sources[0]
    W_orig, H_orig = first_img.size
    # Scale so that W*H is close to PIXEL_LIMIT while keeping the aspect ratio
    scale = math.sqrt(PIXEL_LIMIT / (W_orig * H_orig)) if W_orig * H_orig > 0 else 1
    W_target, H_target = W_orig * scale, H_orig * scale
    # Round both sides to multiples of 14, then shrink until under PIXEL_LIMIT,
    # trimming whichever side overshoots the target aspect ratio
    k, m = round(W_target / 14), round(H_target / 14)
    while (k * 14) * (m * 14) > PIXEL_LIMIT:
        if k / m > W_target / H_target: k -= 1
        else: m -= 1
    TARGET_W, TARGET_H = max(1, k) * 14, max(1, m) * 14
    print(f"All images will be resized to a uniform size: ({TARGET_W}, {TARGET_H})")

    # --- 3. Resize images and convert them to tensors in the [0, 1] range ---
    tensor_list = []
    # Define a transform to convert a PIL Image to a CxHxW tensor and normalize to [0,1]
    to_tensor_transform = transforms.ToTensor()

    for img_pil in sources:
        try:
            # Resize to the uniform target size
            # NOTE(review): Image.Resampling requires Pillow >= 9.1; cropping.py
            # guards this with a fallback — consider the same here.
            resized_img = img_pil.resize((TARGET_W, TARGET_H), Image.Resampling.LANCZOS)
            # Convert to tensor
            img_tensor = to_tensor_transform(resized_img)
            tensor_list.append(img_tensor)
        except Exception as e:
            print(f"Error processing an image: {e}")

    if not tensor_list:
        print("No images were successfully processed.")
        return torch.empty(0)

    # --- 4. Stack the list of tensors into a single [N, C, H, W] batch tensor ---
    return torch.stack(tensor_list, dim=0)
84
+
85
+
86
def tensor_to_pil(tensor):
    """
    Convert a PyTorch tensor (or NumPy array) to a PIL image.

    Tensors are detached and moved to CPU first; the layout handling
    (squeezing, channel moving) is delegated to `array_to_pil`.

    Args:
        tensor (torch.Tensor): Input tensor. Expected shape can be [C, H, W], [H, W, C], or [H, W].

    Returns:
        PIL.Image: The converted PIL image.
    """
    data = tensor.detach().cpu().numpy() if torch.is_tensor(tensor) else tensor
    return array_to_pil(data)
103
+
104
+
105
def array_to_pil(array):
    """
    Convert a NumPy array to a PIL image.

    Singleton dimensions are squeezed, and a leading 3-channel axis is moved
    to the end. Values are assumed to be in [0, 1] and are scaled to uint8.

    Args:
        array (np.ndarray): Input array. Expected shape can be [C, H, W], [H, W, C], or [H, W].

    Returns:
        PIL.Image: "L" image for 2D input, "RGB" for 3-channel input.

    Raises:
        ValueError: for any other shape after squeezing.
    """
    array = np.squeeze(array)

    # Channel-first -> channel-last when a leading 3-channel axis is present
    if array.ndim == 3 and array.shape[0] == 3:
        array = np.transpose(array, (1, 2, 0))

    if array.ndim == 2:  # grayscale [H, W]
        return Image.fromarray((array * 255).astype(np.uint8), mode="L")
    if array.ndim == 3 and array.shape[2] == 3:  # color [H, W, 3]
        return Image.fromarray((array * 255).astype(np.uint8), mode="RGB")
    raise ValueError(f"Unsupported array shape for PIL conversion: {array.shape}")
131
+
132
+
133
def rotate_target_dim_to_last_axis(x, target_dim=3):
    """
    Move the last axis of size `target_dim` to the end of `x`'s shape.

    Works on both NumPy arrays and PyTorch tensors. The search runs backwards
    so the axis closest to the end wins; if no axis matches, or the matching
    axis is already last, `x` is returned unchanged.

    Args:
        x: np.ndarray or torch.Tensor.
        target_dim: axis size to move to the last position (default 3).

    Returns:
        A view/permutation of `x` with the matched axis last.
    """
    shape = x.shape
    axis_to_move = -1
    # Iterate backwards to find the first occurrence from the end
    # (which corresponds to the last dimension of size `target_dim`).
    for i in range(len(shape) - 1, -1, -1):
        if shape[i] == target_dim:
            axis_to_move = i
            break

    # If the axis is found and it's not already in the last position, move it.
    if axis_to_move != -1 and axis_to_move != len(shape) - 1:
        dims_order = list(range(len(shape)))
        dims_order.pop(axis_to_move)
        dims_order.append(axis_to_move)

        # Bug fix: np.ndarray.transpose accepts a full axes permutation, but
        # torch.Tensor.transpose only swaps two dims, so `x.transpose(*dims_order)`
        # raised for tensors with ndim > 2. Use permute for tensors instead.
        if torch.is_tensor(x):
            ret = x.permute(*dims_order)
        else:
            ret = x.transpose(*dims_order)
    else:
        ret = x

    return ret
156
+
157
+
158
def write_ply(
    xyz,
    rgb=None,
    path='output.ply',
) -> None:
    """
    Write a colored point cloud to a PLY file.

    Args:
        xyz: point coordinates (torch.Tensor or np.ndarray); any axis of size 3
            is rotated to the last position, then flattened to (-1, 3).
        rgb: optional per-point colors, same layout as `xyz`. Values > 1 are
            assumed to be in [0, 255] and rescaled to [0, 1]. When None, points
            are colorized procedurally from their normalized position.
        path: output file path.
    """
    if torch.is_tensor(xyz):
        xyz = xyz.detach().cpu().numpy()

    if torch.is_tensor(rgb):
        rgb = rgb.detach().cpu().numpy()

    # Heuristic: treat colors with values above 1 as 0-255 encoded
    if rgb is not None and rgb.max() > 1:
        rgb = rgb / 255.

    xyz = rotate_target_dim_to_last_axis(xyz, 3)
    xyz = xyz.reshape(-1, 3)

    if rgb is not None:
        rgb = rotate_target_dim_to_last_axis(rgb, 3)
        rgb = rgb.reshape(-1, 3)

    if rgb is None:
        # Procedural coloring: hue derived from the point's normalized position,
        # followed by a manual HSV -> RGB conversion (S=0.9, V=0.8).
        min_coord = np.min(xyz, axis=0)
        max_coord = np.max(xyz, axis=0)
        normalized_coord = (xyz - min_coord) / (max_coord - min_coord + 1e-8)

        # Hue is a weighted mix of the normalized x/y/z coordinates
        hue = 0.7 * normalized_coord[:,0] + 0.2 * normalized_coord[:,1] + 0.1 * normalized_coord[:,2]
        hsv = np.stack([hue, 0.9*np.ones_like(hue), 0.8*np.ones_like(hue)], axis=1)

        # Standard HSV->RGB: chroma c, intermediate x, match value m,
        # then pick the (c, x, 0) permutation per 60-degree hue sector
        c = hsv[:,2:] * hsv[:,1:2]
        x = c * (1 - np.abs( (hsv[:,0:1]*6) % 2 - 1 ))
        m = hsv[:,2:] - c

        rgb = np.zeros_like(hsv)
        cond = (0 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 1)
        rgb[cond] = np.hstack([c[cond], x[cond], np.zeros_like(x[cond])])
        cond = (1 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 2)
        rgb[cond] = np.hstack([x[cond], c[cond], np.zeros_like(x[cond])])
        cond = (2 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 3)
        rgb[cond] = np.hstack([np.zeros_like(x[cond]), c[cond], x[cond]])
        cond = (3 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 4)
        rgb[cond] = np.hstack([np.zeros_like(x[cond]), x[cond], c[cond]])
        cond = (4 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 5)
        rgb[cond] = np.hstack([x[cond], np.zeros_like(x[cond]), c[cond]])
        cond = (5 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 6)
        rgb[cond] = np.hstack([c[cond], np.zeros_like(x[cond]), x[cond]])
        rgb = (rgb + m)

    # PLY vertex layout: position, (zeroed) normal, uint8 color
    dtype = [
        ("x", "f4"),
        ("y", "f4"),
        ("z", "f4"),
        ("nx", "f4"),
        ("ny", "f4"),
        ("nz", "f4"),
        ("red", "u1"),
        ("green", "u1"),
        ("blue", "u1"),
    ]
    normals = np.zeros_like(xyz)
    elements = np.empty(xyz.shape[0], dtype=dtype)
    attributes = np.concatenate((xyz, normals, rgb * 255), axis=1)
    elements[:] = list(map(tuple, attributes))
    vertex_element = PlyElement.describe(elements, "vertex")
    ply_data = PlyData([vertex_element])
    ply_data.write(path)
flow3r/utils/cropping.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # cropping utilities
6
+ # --------------------------------------------------------
7
+ import PIL.Image
8
+ import os
9
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
10
+ import cv2 # noqa
11
+ import numpy as np # noqa
12
+ try:
13
+ lanczos = PIL.Image.Resampling.LANCZOS
14
+ bicubic = PIL.Image.Resampling.BICUBIC
15
+ except AttributeError:
16
+ lanczos = PIL.Image.LANCZOS
17
+ bicubic = PIL.Image.BICUBIC
18
+
19
+ from utils.basic import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics
20
+
21
class ImageList:
    """ Convenience class to aply the same operation to a whole set of images.
    """

    def __init__(self, images):
        # Accept a single image or any tuple/list/set of images; arrays are
        # promoted to PIL images.
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        self.images = [
            img if isinstance(img, PIL.Image.Image) else PIL.Image.fromarray(img)
            for img in images
        ]

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        """Return the single held image, or a tuple when holding several."""
        if len(self.images) > 1:
            return tuple(self.images)
        return self.images[0]

    @property
    def size(self):
        """Common (width, height) of all held images; they must all agree."""
        all_sizes = [im.size for im in self.images]
        assert all(all_sizes[0] == s for s in all_sizes)
        return all_sizes[0]

    def resize(self, *args, **kwargs):
        return ImageList(self._dispatch('resize', *args, **kwargs))

    def crop(self, *args, **kwargs):
        return ImageList(self._dispatch('crop', *args, **kwargs))

    def _dispatch(self, func, *args, **kwargs):
        # Forward the named PIL method to every held image.
        return [getattr(im, func)(*args, **kwargs) for im in self.images]
54
+
55
+
56
def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution, force=True, normal=None, far_mask=None):
    """ Jointly rescale a (image, depthmap) pair -- plus optional normal map and
    far mask -- so that (out_width, out_height) >= output_res.

    Returns:
        (image, depthmap, camera_intrinsics, normal, far_mask)
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # can also use this with masks instead of depthmaps
        assert tuple(depthmap.shape[:2]) == image.size[::-1]

    # define output resolution
    assert output_resolution.shape == (2,)
    scale_final = max(output_resolution / image.size) + 1e-8
    if scale_final >= 1 and not force:  # image is already smaller than what is asked
        # Bug fix: this early exit used to return only 3 values while the
        # normal path returns 5, breaking tuple unpacking at call sites.
        return (image.to_pil(), depthmap, camera_intrinsics, normal, far_mask)
    output_resolution = np.floor(input_resolution * scale_final).astype(int)

    # first rescale the image so that it contains the crop
    image = image.resize(tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic)
    if depthmap is not None:
        # nearest-neighbour keeps depth values unblended across object edges
        depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final,
                              fy=scale_final, interpolation=cv2.INTER_NEAREST)

    if normal is not None:
        normal = cv2.resize(normal, output_resolution, fx=scale_final,
                            fy=scale_final, interpolation=cv2.INTER_NEAREST)
    if far_mask is not None:
        far_mask = cv2.resize(far_mask, output_resolution, fx=scale_final,
                              fy=scale_final, interpolation=cv2.INTER_NEAREST)

    # no offset here; simple rescaling
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)

    return image.to_pil(), depthmap, camera_intrinsics, normal, far_mask
92
+
93
def center_crop_image_depthmap(image, depthmap, camera_intrinsics, crop_scale, normal=None, far_mask=None):
    """
    Center-crop an image (plus optional depthmap / normal / far_mask) by a
    given fraction, and shift the camera principal point accordingly.

    Parameters:
    - image: PIL.Image or similar, the input image.
    - depthmap: np.ndarray, the corresponding depth map (or None).
    - camera_intrinsics: np.ndarray, the 3x3 camera intrinsics matrix.
    - crop_scale: float in (0, 1], the fraction of the image to keep.

    Returns:
    - (cropped image, cropped depthmap, adjusted intrinsics, normal, far_mask)
    """
    assert 0 < crop_scale <= 1, "crop_scale must be between 0 and 1"

    image = ImageList(image)
    full_size = np.array(image.size)  # (width, height)
    if depthmap is not None:
        assert depthmap.shape[:2] == tuple(image.size[::-1]), "Depthmap size must match image size"

    # Size kept after cropping, and the effective per-axis crop fraction
    kept_size = np.floor(full_size * crop_scale).astype(int)
    crop_scale = kept_size / full_size

    # Center the crop: split the removed margin evenly on both sides
    left, top = ((full_size - kept_size) / 2).astype(int)
    right = left + kept_size[0]
    bottom = top + kept_size[1]

    image = image.crop((left, top, right, bottom))
    if depthmap is not None:
        depthmap = depthmap[top:bottom, left:right]
    if normal is not None:
        normal = normal[top:bottom, left:right]
    if far_mask is not None:
        far_mask = far_mask[top:bottom, left:right]

    # Cropping only moves the principal point; focal lengths are unchanged
    adjusted_intrinsics = camera_intrinsics.copy()
    adjusted_intrinsics[0, 2] -= left   # cx
    adjusted_intrinsics[1, 2] -= top    # cy

    return image.to_pil(), depthmap, adjusted_intrinsics, normal, far_mask
154
+
155
+
156
def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
    """
    Intrinsics of a view that is first scaled by `scaling`, then cropped to
    `output_resolution` with the crop origin at `offset` (default: centered
    via `offset_factor` of the margins).
    """
    # Margins to offset the origin
    margins = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    # Do the scale/offset arithmetic in the COLMAP intrinsics convention,
    # then convert back to OpenCV convention (see the utils.basic helpers).
    colmap_K = opencv_to_colmap_intrinsics(input_camera_matrix)
    colmap_K[:2, :] *= scaling
    colmap_K[:2, 2] -= offset
    return colmap_to_opencv_intrinsics(colmap_K)
170
+
171
+
172
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox, normal=None, far_mask=None):
    """
    Crop all modalities of a view to `crop_bbox` = (left, top, right, bottom),
    shifting the principal point into the crop's frame.
    """
    l, t, r, b = crop_bbox

    image = ImageList(image).crop((l, t, r, b))
    depthmap = depthmap[t:b, l:r]
    if normal is not None:
        normal = normal[t:b, l:r]
    if far_mask is not None:
        far_mask = far_mask[t:b, l:r]

    # Only the principal point moves under a crop
    camera_intrinsics = camera_intrinsics.copy()
    camera_intrinsics[0, 2] -= l
    camera_intrinsics[1, 2] -= t

    return image.to_pil(), depthmap, camera_intrinsics, normal, far_mask
191
+
192
+
193
def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
    """
    Crop bounding box (l, t, r, b) implied by the principal-point shift
    between an input and an output intrinsics matrix.
    """
    out_width, out_height = output_resolution
    # Top-left corner = rounded difference of the principal points
    left, top = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
    return (left, top, left + out_width, top + out_height)
flow3r/utils/debug.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import debugpy
4
+ import socket
5
+ import random
6
+
7
def update_vscode_launch_file(host: str, port: int):
    """Write .vscode/launch.json with a debugpy attach configuration for the given host and port."""
    launch_file_path = ".vscode/launch.json"

    # Single attach configuration pointing at the debugpy listener
    attach_config = {
        "name": "bash_debug",
        "type": "debugpy",
        "request": "attach",
        "connect": {
            "host": host,
            "port": port
        },
        "justMyCode": False
    }
    launch_contents = {"version": "0.2.0", "configurations": [attach_config]}

    # Ensure the .vscode directory exists
    os.makedirs(".vscode", exist_ok=True)

    # Write the updated configuration to launch.json
    with open(launch_file_path, "w") as f:
        json.dump(launch_contents, f, indent=4)
    print(f"Updated {launch_file_path} with host: {host} and port: {port}")
35
+
36
def is_port_in_use(host, port):
    """Return True if a TCP connection to (host, port) succeeds, i.e. something is listening there."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex returns 0 on success instead of raising
        return probe.connect_ex((host, port)) == 0
    finally:
        probe.close()
39
+
40
def setup_debug(is_main_process=True, max_retries=10, port_range=(10000, 20000)):
    """
    Start a debugpy listener and block until a VS Code debugger attaches.

    Picks a random free port in `port_range` (up to `max_retries` attempts),
    writes the matching attach configuration to .vscode/launch.json, then
    waits for the client. Non-main processes return immediately.

    NOTE(review): assumes the SLURM_NODELIST env var is set and that its
    first entry is a resolvable hostname — confirm on the target cluster.

    Raises:
        RuntimeError: if no free port could be bound after `max_retries` tries.
    """
    if is_main_process:
        host = os.environ['SLURM_NODELIST'].split(',')[0]

        for _ in range(max_retries):
            port = random.randint(*port_range)
            try:
                if is_port_in_use(host, port):
                    print(f"Port {port} is already in use, trying another...")
                    continue

                # Update launch.json so VS Code can attach to this host/port
                update_vscode_launch_file(host, port)

                print("master_addr = ", host)
                debugpy.listen((host, port))
                print(f"Waiting for debugger attach at port {port}...", flush=True)
                debugpy.wait_for_client()
                print("Debugger attached", flush=True)
                return
            except Exception as e:
                # Best effort: try another random port on any bind failure
                print(f"Failed to bind to port {port}: {e}")

        raise RuntimeError("Could not find a free port for debugpy after several attempts.")
flow3r/utils/flow_utils.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import os
4
+ from PIL import Image
5
+ import flow_vis
6
+ from .geometry import se3_inverse, homogenize_points
7
+ import torch.nn.functional as F
8
+ import matplotlib.pyplot as plt
9
+ import wandb
10
+
11
def warp_image_with_flow(source_image, source_mask, target_image, flow) -> np.ndarray:
    """
    Backward-warp `target_image` into the source frame using per-pixel flow.

    Each source pixel (x, y) samples the target bilinearly (via grid_sample)
    at (x, y) + flow[y, x], clamped to the target bounds. When a mask is
    given, pixels with mask <= 0.5 are zeroed in the result.

    Args:
        source_image: np.ndarray of shape (H, W, ...), used for its spatial size
        source_mask: non-occluded mask in the source frame (H, W), or None
        target_image: np.ndarray of shape (Ht, Wt, C) to sample from
        flow: np.ndarray of shape (H, W, 2), source-to-target displacement

    Returns:
        np.ndarray of shape (H, W, C): target warped into the source frame
    """
    assert flow.shape[-1] == 2

    src_h, src_w = source_image.shape[:2]
    tgt_h, tgt_w = target_image.shape[:2]

    grid_x, grid_y = np.meshgrid(np.arange(src_w), np.arange(src_h))

    # Absolute sampling positions in the target, clamped to its bounds, then
    # shifted by half a pixel and mapped to [-1, 1] for grid_sample
    sample_x = np.clip(grid_x + flow[..., 0], 0, tgt_w - 1) + 0.5
    sample_y = np.clip(grid_y + flow[..., 1], 0, tgt_h - 1) + 0.5
    sample_x = (sample_x / target_image.shape[1]) * 2 - 1
    sample_y = (sample_y / target_image.shape[0]) * 2 - 1

    target_t = torch.from_numpy(target_image).permute(2, 0, 1)[None, ...].float()
    grid_t = torch.from_numpy(np.stack([sample_x, sample_y], axis=-1)).float()[None, ...]
    warped = F.grid_sample(
        target_t,
        grid_t,
        mode="bilinear",
        align_corners=False,
    )
    warped = warped[0].permute(1, 2, 0).numpy()

    if source_mask is not None:
        warped = warped * (source_mask > 0.5)[..., None]

    return warped
60
+
61
def ndc_to_pixel_coords(coords_ndc: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """
    Map NDC coordinates back to pixel coordinates.

    Inverse of the convention where, per axis, NDC +1 corresponds to pixel 0
    and NDC -1 to pixel (size - 1).

    Args:
        coords_ndc: [..., H, W, 2], coordinates in NDC space (x_ndc, y_ndc)
        H, W: image dimensions

    Returns:
        coords_px: [..., H, W, 2], coordinates in pixel space (x_pix, y_pix);
        the input tensor is left untouched.
    """
    coords_px = coords_ndc.clone()

    half_w = max(W - 1, 1) / 2.0
    half_h = max(H - 1, 1) / 2.0

    # x: NDC [1, -1] -> pixel [0, W-1]; y: NDC [1, -1] -> pixel [0, H-1]
    coords_px[..., 0] = (1.0 - coords_ndc[..., 0]) * half_w
    coords_px[..., 1] = (1.0 - coords_ndc[..., 1]) * half_h

    return coords_px
81
+
82
def coords_to_flow(coords: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """
    Turn absolute target coordinates into an optical-flow field.

    The flow at a source pixel is the displacement from that pixel to its
    target coordinate, i.e. ``coords - source_grid``.

    Args:
        coords: [..., H, W, 2] target coordinates (where source pixels land).
        H, W: image dimensions.

    Returns:
        flow: [..., H, W, 2] displacement vectors.
    """
    device = coords.device

    # Build the (x, y) pixel grid of the source image.
    gy, gx = torch.meshgrid(
        torch.arange(H, device=device),
        torch.arange(W, device=device),
        indexing="ij",
    )
    base_grid = torch.stack([gx, gy], dim=-1).float()  # (H, W, 2)

    return coords - base_grid
107
+
108
def flow_to_coords(flow: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """
    Turn an optical-flow field into absolute target coordinates.

    Inverse of ``coords_to_flow``: the target coordinate of a source pixel is
    its grid position plus the flow displacement.

    Args:
        flow: [..., H, W, 2] displacement vectors.
        H, W: image dimensions.

    Returns:
        coords: [..., H, W, 2] absolute pixel positions in the target image.
    """
    device = flow.device

    # Build the (x, y) pixel grid of the source image.
    gy, gx = torch.meshgrid(
        torch.arange(H, device=device),
        torch.arange(W, device=device),
        indexing="ij",
    )
    base_grid = torch.stack([gx, gy], dim=-1).float()  # (H, W, 2)

    return flow + base_grid
133
+
134
def ndc_pixels_to_flow(flow_ndc: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """
    Convert an optical-flow field from NDC units to pixel units.

    PyTorch3D NDC has +x left and +y up, while screen space has +x right and
    +y down, so both axes flip sign in addition to rescaling. This inverts
    ``dx_ndc = -2/(W-1) * dx_pix`` and ``dy_ndc = -2/(H-1) * dy_pix``.

    Args:
        flow_ndc: [..., 2] flow in NDC units (dx_ndc, dy_ndc).
        H, W: image height and width.

    Returns:
        flow_px: [..., 2] flow in pixel units (dx_pix, dy_pix).
    """
    # Per-axis NDC step sizes (guard against degenerate 1-pixel axes).
    step_x = 2.0 / max(W - 1, 1)
    step_y = 2.0 / max(H - 1, 1)

    flow_px = flow_ndc.clone()
    flow_px[..., 0] = -flow_px[..., 0] / step_x  # dx_pix
    flow_px[..., 1] = -flow_px[..., 1] / step_y  # dy_pix
    return flow_px
155
+
156
def coords_pixels_to_ndc(coords_px: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """
    Convert pixel coordinates to PyTorch3D NDC coordinates.

    PyTorch3D convention:
      - NDC: x in [1, -1] (+x left), y in [1, -1] (+y up), origin at center.
      - Pixel: x in [0, W-1] (+x right), y in [0, H-1] (+y down), origin top-left.
    """
    # Pixel 0 maps to NDC 1, pixel W-1 (resp. H-1) maps to NDC -1.
    x_ndc = 1.0 - (coords_px[..., 0] / max(W - 1, 1)) * 2.0
    y_ndc = 1.0 - (coords_px[..., 1] / max(H - 1, 1)) * 2.0
    return torch.stack([x_ndc, y_ndc], dim=-1)
171
+
172
+
173
def batched_pi3_motion_flow(world_points, camera_poses, camera_intrinsics, sampled_pairs, image_size):
    """
    Compute batched motion flow from img1 to img2 using world points and camera poses.

    For each sampled (src, tgt) pair, the source frame's predicted world points
    are reprojected into the target camera, and the flow is that reprojection
    minus the source pixel grid.

    Args:
        world_points: (B, N, H, W, 3) predicted world points per image.
        camera_poses: (B, N, 4, 4) extrinsics for each frame, camera-to-world.
        camera_intrinsics: (B, N, 3, 3) camera intrinsics for each frame.
        sampled_pairs: (B, P, 2) image pairs to compute flow between.
        image_size: indexable; image_size[0] is the (H_img, W_img) of the
            full-resolution image the intrinsics refer to (may differ from the
            H, W of world_points). NOTE(review): the old docstring called this
            an int, but the code unpacks image_size[0] into two values.

    Returns:
        flow: (B, P, H, W, 2) motion flows, (x, y) in pixel coordinates
    """
    B, N, H, W, _ = world_points.shape
    P = sampled_pairs.shape[1]
    device = world_points.device

    # Gather source points: expand the (B, P) source indices so torch.gather
    # selects full (H, W, 3) point maps along the frame dimension.
    src_idx = sampled_pairs[..., 0]  # (B, P)
    src_idx_exp = src_idx.view(B, P, 1, 1, 1).expand(B, P, H, W, 3)
    src_points = torch.gather(world_points, 1, src_idx_exp)

    # Gather target poses and intrinsics the same way.
    tgt_idx = sampled_pairs[..., 1]  # (B, P)

    tgt_poses = torch.gather(camera_poses, 1, tgt_idx.view(B, P, 1, 1).expand(B, P, 4, 4))
    tgt_intrinsics = torch.gather(camera_intrinsics, 1, tgt_idx.view(B, P, 1, 1).expand(B, P, 3, 3))

    # Transform world points into the target camera frame: P_cam = T_w2c @ P_world.
    w2c_tgt = se3_inverse(tgt_poses)
    src_points_homo = homogenize_points(src_points)

    # (B, P, 4, 4) @ (B, P, H, W, 4) -> (B, P, H, W, 4); keep xyz only.
    pts_cam = torch.einsum('bpij,bphwj->bphwi', w2c_tgt, src_points_homo)[..., :3]

    # Project onto the image plane: P_img = K @ P_cam.
    # (B, P, 3, 3) @ (B, P, H, W, 3) -> (B, P, H, W, 3)
    pts_img = torch.einsum('bpij,bphwj->bphwi', tgt_intrinsics, pts_cam)

    # Perspective divide (epsilon guards against zero depth).
    uv_tgt = pts_img[..., :2] / (pts_img[..., 2:3] + 1e-6)

    # Source pixel coordinates: the (H, W) grid may be coarser than the image
    # the intrinsics refer to, so scale grid centers into image coordinates.
    H_img, W_img = image_size[0]

    scale_h = H_img / H
    scale_w = W_img / W

    y, x = torch.meshgrid(
        torch.arange(H, device=device, dtype=torch.float32),
        torch.arange(W, device=device, dtype=torch.float32),
        indexing='ij'
    )

    # Map grid to image coordinates (assuming center of pixels/patches).
    uv_src = torch.stack([
        (x + 0.5) * scale_w - 0.5,
        (y + 0.5) * scale_h - 0.5
    ], dim=-1)  # (H, W, 2)

    uv_src = uv_src.view(1, 1, H, W, 2).expand(B, P, -1, -1, -1)

    return uv_tgt - uv_src
244
+
245
+
246
def visualize_flow(pred_motion_coords, motion_coords, covis_masks, sampled_pairs, images, pred_pi3_flow, iteration, accelerator, dataset_names):
    """
    Render side-by-side flow visualizations (GT vs predicted vs pi3) with
    AEPE metrics, save them to disk, and log them to wandb.

    Only the first two pairs of each batch element are visualized.

    Args:
        pred_motion_coords: (B, P, H, W, 2) predicted target coords in NDC.
        motion_coords: (B, P, H, W, 2) ground-truth target coords in NDC.
        covis_masks: (B, P, H, W) covisibility masks.
        sampled_pairs: (B, P, 2) frame-index pairs.
        images: (B, N, 3, H, W) input images in [0, 1].
        pred_pi3_flow: (B, P, H, W, 2) flow derived from pi3 cameras/points, in pixels.
        iteration: training step (used in filenames and wandb step).
        accelerator: accelerate object used for wandb logging.
        dataset_names: per-batch-element dataset labels.
    """
    # NOTE(review): hard-coded cluster output path — parameterize before reuse.
    path = f"/ocean/projects/cis250013p/zcong/pi3/outputs/flow_vis/{iteration}"
    if not os.path.exists(path):
        os.makedirs(path)

    with torch.no_grad():
        # Get dimensions
        B, num_pairs = sampled_pairs.shape[0], sampled_pairs.shape[1]
        H, W = motion_coords[0, 0].shape[0], motion_coords[0, 0].shape[1]

        # Process all pairs for all batches
        for batch_idx in range(B):
            dataset_name = dataset_names[batch_idx]
            for pair_idx in range(num_pairs):
                # Only the first two pairs are visualized to bound the cost.
                if pair_idx > 1: break
                # Get pair indices
                pairs = sampled_pairs[batch_idx, pair_idx].cpu().numpy()  # (2,)
                img1 = images[batch_idx, pairs[0]].cpu().numpy()
                img2 = images[batch_idx, pairs[1]].cpu().numpy()

                # Convert ground truth coordinates to flow (NDC -> pixels -> flow).
                gt_coords_ndc = motion_coords[batch_idx, pair_idx]  # NDC coordinates
                gt_coords_pixel = ndc_to_pixel_coords(gt_coords_ndc, H, W)  # Convert to pixel coordinates
                flow_tensor = coords_to_flow(gt_coords_pixel, H, W).float().cpu()  # (H, W, 2)
                flow = flow_tensor.numpy()  # (H, W, 2)

                covis_mask = covis_masks[batch_idx, pair_idx].float().cpu().numpy()  # (H, W)
                masked_flow = flow * covis_mask[..., None]

                # Convert predicted coordinates to flow the same way.
                pred_coords_ndc = pred_motion_coords[batch_idx, pair_idx]  # NDC coordinates
                pred_coords_pixel = ndc_to_pixel_coords(pred_coords_ndc, H, W)  # Convert to pixel coordinates
                pred_flow = coords_to_flow(pred_coords_pixel, H, W).float().cpu().numpy()  # (H, W, 2)
                masked_pred_flow = pred_flow * covis_mask[..., None]

                # pi3 flow is already in pixel units.
                pi3_flow = pred_pi3_flow[batch_idx, pair_idx].float().cpu().numpy()  # (H, W, 2)
                masked_pi3_flow = pi3_flow * covis_mask[..., None]

                # Warp img1 to img2 using each flow (GT, predicted, pi3).
                img1_np = np.transpose(img1, (1, 2, 0))  # [H, W, 3]
                img2_np = np.transpose(img2, (1, 2, 0))  # [H, W, 3]
                warped_img_gt = warp_image_with_flow(img1_np, covis_mask, img2_np, flow)
                warped_img_gt = warped_img_gt.clip(0, 1)
                warped_img_gt = Image.fromarray((warped_img_gt * 255).astype(np.uint8))
                # compute prediction warping
                warped_img_pred = warp_image_with_flow(img1_np, covis_mask, img2_np, pred_flow)
                warped_img_pred = warped_img_pred.clip(0, 1)
                warped_img_pred = Image.fromarray((warped_img_pred * 255).astype(np.uint8))
                # compute pi3 warping
                warped_img_pi3 = warp_image_with_flow(img1_np, covis_mask, img2_np, pi3_flow)
                warped_img_pi3 = warped_img_pi3.clip(0, 1)
                warped_img_pi3 = Image.fromarray((warped_img_pi3 * 255).astype(np.uint8))

                # Input images as PIL for the grid figure.
                img_array1 = np.transpose(img1, (1, 2, 0))
                img1_pil = Image.fromarray((img_array1 * 255).astype(np.uint8))
                img_array2 = np.transpose(img2, (1, 2, 0))
                img2_pil = Image.fromarray((img_array2 * 255).astype(np.uint8))

                # Calculate AEPE metrics on covisible pixels only.
                valid_mask = covis_mask > 0
                if np.sum(valid_mask) > 0:
                    # AEPE for predicted flow vs GT flow
                    flow_diff_pred = np.sqrt(np.sum((masked_pred_flow - masked_flow) ** 2, axis=-1))
                    aepe_pred = np.mean(flow_diff_pred[valid_mask])
                    aepe_5px_pred = np.mean(flow_diff_pred[valid_mask] < 5.0) * 100  # percentage
                    # AEPE for pi3 flow vs GT flow
                    flow_diff_pi3 = np.sqrt(np.sum((masked_pi3_flow - masked_flow) ** 2, axis=-1))
                    aepe_pi3 = np.mean(flow_diff_pi3[valid_mask])
                    aepe_5px_pi3 = np.mean(flow_diff_pi3[valid_mask] < 5.0) * 100  # percentage
                else:
                    # No covisible pixels: report degenerate metrics.
                    aepe_pred = float('inf')
                    aepe_5px_pred = 0.0
                    aepe_pi3 = float('inf')
                    aepe_5px_pi3 = 0.0

                # Color-code the three flow fields.
                flow_vis_image_gt = flow_vis.flow_to_color(masked_flow)
                flow_pil = Image.fromarray(flow_vis_image_gt.astype(np.uint8))
                flow_vis_image_pred = flow_vis.flow_to_color(masked_pred_flow)
                flow_pred_pil = Image.fromarray(flow_vis_image_pred.astype(np.uint8))
                flow_vis_image_pi3 = flow_vis.flow_to_color(masked_pi3_flow)
                flow_pi3_pil = Image.fromarray(flow_vis_image_pi3.astype(np.uint8))

                # Bundle metrics for the grid renderer.
                metrics_text = {
                    'pred_aepe': aepe_pred,
                    'pred_5px_pct': aepe_5px_pred,
                    'pi3_aepe': aepe_pi3,
                    'pi3_5px_pct': aepe_5px_pi3,
                    'covis_ratio': float(np.mean(covis_mask)) * 100,
                    'pairs': pairs,
                    'dataset': dataset_name,
                }

                # Save individual visualization and log to wandb
                save_path = os.path.join(path, f"motion_flow_grid_batch_{batch_idx}_pair_{pair_idx}_imgs_{pairs[0]}_{pairs[1]}_iter_{iteration:08d}.png")
                visualize_motion_grid_nodepth_with_metrics(
                    img1_pil, img2_pil, flow_pil, flow_pred_pil, flow_pi3_pil,
                    warped_img_gt, warped_img_pred, warped_img_pi3,
                    metrics_text,
                    save_path=save_path,
                    pair_idx = pair_idx,
                    step=iteration,
                    log_to_wandb=True,  # We'll handle wandb logging separately
                    accelerator=accelerator,
                    dataset_name=dataset_name
                )
358
+
359
def visualize_motion_grid_nodepth_with_metrics(img1, img2, flow_pil, flow_pred_pil, flow_pi3_pil, warped_img_gt, warped_img_pred, warped_img_pi3, metrics_text, pair_idx, save_path="motion_flow_grid.png", step=None, log_to_wandb=True, accelerator=None, dataset_name=None):
    """
    Lay out a 3x3 figure: input images + metric text (row 0), GT/pred/pi3 flow
    colorings (row 1), GT/pred/pi3 warped images (row 2). Saves to save_path
    and optionally logs the saved PNG to wandb through `accelerator`.

    Args:
        img1, img2: PIL input images.
        flow_pil, flow_pred_pil, flow_pi3_pil: PIL flow colorings.
        warped_img_gt, warped_img_pred, warped_img_pi3: PIL warped images.
        metrics_text: dict with keys 'pred_aepe', 'pred_5px_pct', 'pi3_aepe',
            'pi3_5px_pct', 'covis_ratio', 'pairs', 'dataset'.
        pair_idx: index used in the wandb key.
        save_path: output PNG path.
        step: wandb step.
        log_to_wandb: if True, requires a non-None accelerator.
        accelerator: accelerate object with a .log method.
        dataset_name: unused here (kept for interface compatibility).
    """
    fig, axes = plt.subplots(3, 3, figsize=(20, 16))

    # Row 0: the two input images.
    axes[0, 0].imshow(img1)
    axes[0, 0].set_title(f"Image {metrics_text['pairs'][0]}")
    axes[0, 0].axis("off")

    axes[0, 1].imshow(img2)
    axes[0, 1].set_title(f"Image {metrics_text['pairs'][1]}")
    axes[0, 1].axis("off")

    # Row 0, col 2: metrics rendered as text (blue = pred, red = pi3).
    axes[0, 2].text(0.1, 0.9, f"{metrics_text['dataset']} Pair: {metrics_text['pairs'][0]} → {metrics_text['pairs'][1]}",
                    fontsize=14, fontweight='bold', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.8, f"Covis Ratio: {metrics_text['covis_ratio']:.1f}%",
                    fontsize=12, transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.7, "Pred Flow Metrics:",
                    fontsize=12, fontweight='bold', color='blue', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.6, f"AEPE: {metrics_text['pred_aepe']:.3f}",
                    fontsize=11, color='blue', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.5, f"<5px: {metrics_text['pred_5px_pct']:.1f}%",
                    fontsize=11, color='blue', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.4, "Pi3 Flow Metrics:",
                    fontsize=12, fontweight='bold', color='red', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.3, f"AEPE: {metrics_text['pi3_aepe']:.3f}",
                    fontsize=11, color='red', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.2, f"<5px: {metrics_text['pi3_5px_pct']:.1f}%",
                    fontsize=11, color='red', transform=axes[0, 2].transAxes)
    axes[0, 2].set_xlim(0, 1)
    axes[0, 2].set_ylim(0, 1)
    axes[0, 2].axis("off")

    # Row 1: GT flow, predicted flow, pi3 flow.
    axes[1, 0].imshow(flow_pil)
    axes[1, 0].set_title("GT Motion Flow")
    axes[1, 0].axis("off")

    axes[1, 1].imshow(flow_pred_pil)
    axes[1, 1].set_title(f"Predicted Flow\nAEPE: {metrics_text['pred_aepe']:.3f}, <5px: {metrics_text['pred_5px_pct']:.1f}%")
    axes[1, 1].axis("off")

    axes[1, 2].imshow(flow_pi3_pil)
    axes[1, 2].set_title(f"Pi3 Flow\nAEPE: {metrics_text['pi3_aepe']:.3f}, <5px: {metrics_text['pi3_5px_pct']:.1f}%")
    axes[1, 2].axis("off")

    # Row 2: warped images under each flow.
    axes[2, 0].imshow(warped_img_gt)
    axes[2, 0].set_title("GT Warped Image")
    axes[2, 0].axis("off")

    axes[2, 1].imshow(warped_img_pred)
    axes[2, 1].set_title("Pred Warped Image")
    axes[2, 1].axis("off")

    axes[2, 2].imshow(warped_img_pi3)
    axes[2, 2].set_title("PI3 Warped Image")
    axes[2, 2].axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    if log_to_wandb:
        accelerator.log({f"Visualization_{pair_idx}": wandb.Image(save_path)}, step=step)
    plt.close()
423
+
424
def calculate_flow_metrics(pred_motion_coords, motion_coords, covis_masks, sampled_pairs, pred_pi3_flow):
    """
    Average end-point-error statistics over every (batch, pair) flow.

    Args:
        pred_motion_coords: (B, P, H, W, 2) predicted target coords in NDC.
        motion_coords: (B, P, H, W, 2) ground-truth target coords in NDC.
        covis_masks: (B, P, H, W) covisibility masks.
        sampled_pairs: (B, P, 2) frame-index pairs (only shape is used here).
        pred_pi3_flow: (B, P, H, W, 2) pi3-derived flow in pixel units.

    Returns:
        Tuple of (pred AEPE, pred <5px %, pi3 AEPE, pi3 <5px %), each averaged
        over all pairs. Pairs with no covisible pixels contribute inf / 0.
    """
    with torch.no_grad():
        B, num_pairs = sampled_pairs.shape[0], sampled_pairs.shape[1]
        H, W = motion_coords[0, 0].shape[0], motion_coords[0, 0].shape[1]
        pred_aepes, pred_inliers, pi3_aepes, pi3_inliers = [], [], [], []

        for b in range(B):
            for p in range(num_pairs):
                # GT: NDC coords -> pixel coords -> flow field.
                gt_px = ndc_to_pixel_coords(motion_coords[b, p], H, W)
                gt_flow = coords_to_flow(gt_px, H, W).float().cpu().numpy()

                mask = covis_masks[b, p].float().cpu().numpy()  # (H, W)
                weight = mask[..., None]

                # Predicted: same NDC -> pixel -> flow conversion.
                pred_px = ndc_to_pixel_coords(pred_motion_coords[b, p], H, W)
                pred_flow = coords_to_flow(pred_px, H, W).float().cpu().numpy()

                # pi3 flow is already in pixel units.
                pi3_flow = pred_pi3_flow[b, p].float().cpu().numpy()

                valid = mask > 0
                if np.sum(valid) == 0:
                    # No covisible pixels: degenerate metrics for this pair.
                    pred_aepes.append(float('inf'))
                    pred_inliers.append(0.0)
                    pi3_aepes.append(float('inf'))
                    pi3_inliers.append(0.0)
                    continue

                # Per-pixel end-point error on masked flows, then restrict to
                # covisible pixels before averaging.
                epe_pred = np.linalg.norm(pred_flow * weight - gt_flow * weight, axis=-1)[valid]
                epe_pi3 = np.linalg.norm(pi3_flow * weight - gt_flow * weight, axis=-1)[valid]

                pred_aepes.append(np.mean(epe_pred))
                pred_inliers.append(np.mean(epe_pred < 5.0) * 100)
                pi3_aepes.append(np.mean(epe_pi3))
                pi3_inliers.append(np.mean(epe_pi3 < 5.0) * 100)

        return np.mean(pred_aepes), np.mean(pred_inliers), np.mean(pi3_aepes), np.mean(pi3_inliers)
flow3r/utils/geometry.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
def se3_inverse(T):
    """
    Invert (batched) SE(3) transforms.

    Accepts torch tensors or numpy arrays of shape (..., 4, 4); the inverse of
    [R | t] is [R^T | -R^T t], computed without a general matrix inverse.
    """
    if torch.is_tensor(T):
        rot_t = T[..., :3, :3].transpose(-2, -1)
        trans = T[..., :3, 3].unsqueeze(-1)
        top = torch.cat([rot_t, -torch.matmul(rot_t, trans)], dim=-1)
        # Constant [0, 0, 0, 1] row, broadcast over the batch dims.
        bottom = torch.tensor([0, 0, 0, 1], device=T.device, dtype=T.dtype).repeat(*T.shape[:-2], 1, 1)
        return torch.cat([top, bottom], dim=-2)

    rot_t = np.swapaxes(T[..., :3, :3], -2, -1)
    trans = T[..., :3, 3, np.newaxis]
    top = np.concatenate([rot_t, -rot_t @ trans], axis=-1)

    bottom = np.zeros((*T.shape[:-2], 1, 4), dtype=T.dtype)
    bottom[..., :, 3] = 1
    return np.concatenate([top, bottom], axis=-2)
33
+
34
def get_pixel(H, W):
    """
    Homogeneous pixel-center coordinates for an H x W image.

    Returns:
        (3, H*W) array whose rows are (u + 0.5, v + 0.5, 1), pixels in
        row-major order.
    """
    cols, rows = np.meshgrid(np.arange(W), np.arange(H))
    u = cols.flatten() + 0.5
    v = rows.flatten() + 0.5
    return np.stack([u, v, np.ones_like(cols.flatten())], axis=0)
46
+
47
def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, z_far=0, **kw):
    """
    Back-project a depth map into world coordinates.

    Args:
        depthmap: (H, W) array of depths.
        camera_intrinsics: 3x3 intrinsics matrix.
        camera_pose: 4x4 (or 4x3) cam2world matrix, or None to stay in the
            camera frame.
        z_far: if > 0, depths >= z_far are marked invalid.

    Returns:
        (X_world, valid_mask): (H, W, 3) point map and (H, W) bool mask of
        valid pixels.
    """
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
    if z_far > 0:
        valid_mask = valid_mask & (depthmap < z_far)

    if camera_pose is None:
        return X_cam, valid_mask

    # Rotate + translate camera-frame points into world coordinates.
    R_cam2world = camera_pose[:3, :3]
    t_cam2world = camera_pose[:3, 3]
    X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
    return X_world, valid_mask
70
+
71
+
72
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Back-project a depth map into the camera frame (pinhole model, skew
    terms assumed zero).

    Args:
        depthmap: (H, W) depth values.
        camera_intrinsics: 3x3 intrinsics matrix.
        pseudo_focal: optional (H, W) per-pixel focal length overriding the
            intrinsics' focal entries.

    Returns:
        (X_cam, valid_mask): (H, W, 3) float32 point map and (H, W) bool mask
        marking pixels with depth > 0.
    """
    K = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    if pseudo_focal is None:
        fu, fv = K[0, 0], K[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu, cv = K[0, 2], K[1, 2]

    # Ray through each pixel scaled by its depth.
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    pts = np.empty((H, W, 3), dtype=np.float32)
    pts[..., 0] = (u - cu) * depthmap / fu
    pts[..., 1] = (v - cv) * depthmap / fv
    pts[..., 2] = depthmap

    return pts, depthmap > 0.0
107
+
108
def homogenize_points(points):
    """Append a homogeneous coordinate: (..., xyz) -> (..., xyz1)."""
    ones = torch.ones_like(points[..., :1])
    return torch.cat((points, ones), dim=-1)
113
+
114
+
115
def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
    """
    Compute the ground-truth warp field from image 1 to image 2 by
    depth-reprojecting a normalized pixel grid through warp_kpts.

    Args:
        depth1, depth2: (B, H, W) depth maps.
        T_1to2: (B, 4, 4) or (B, 3, 4) relative transform from cam 1 to cam 2
            (warp_kpts only reads the top 3x4 block).
        K1, K2: (B, 3, 3) intrinsics.
        depth_interpolation_mode: grid_sample mode forwarded to warp_kpts.
        relative_depth_error_threshold: depth-consistency tolerance.
        H, W: output grid size; defaults to depth1's resolution.

    Returns:
        x2: (B, H, W, 2) warped coordinates in normalized [-1, 1] space.
        prob: (B, H, W) float validity (covisible + depth-consistent) mask.
    """
    if H is None:
        B,H,W = depth1.shape
    else:
        B = depth1.shape[0]
    with torch.no_grad():
        # Pixel-center grid in [-1+1/n, 1-1/n]; the linspace over the batch
        # dimension is discarded — only the y (index 1) and x (index 2) grids
        # are kept below.
        x1_n = torch.meshgrid(
            *[
                torch.linspace(
                    -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
                )
                for n in (B, H, W)
            ],
            indexing = 'ij'
        )
        x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
        # Warp in double precision for numerical stability.
        mask, x2 = warp_kpts(
            x1_n.double(),
            depth1.double(),
            depth2.double(),
            T_1to2.double(),
            K1.double(),
            K2.double(),
            depth_interpolation_mode = depth_interpolation_mode,
            relative_depth_error_threshold = relative_depth_error_threshold,
        )
        prob = mask.float().reshape(B, H, W)
        x2 = x2.reshape(B, H, W, 2)
    return x2, prob
145
+
146
@torch.no_grad()
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
    """Warp kpts0 from I0 to I1 with depth, K and Rt.

    Also check covisibility and depth consistency (relative depth error below
    `relative_depth_error_threshold`).
    Adapted from LoFTR:
    https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py

    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y>, normalized in (-1, 1)
        depth0 (torch.Tensor): [N, H, W]
        depth1 (torch.Tensor): [N, H, W]
        T_0to1 (torch.Tensor): [N, 3, 4] (or [N, 4, 4]; only :3 rows used)
        K0 (torch.Tensor): [N, 3, 3]
        K1 (torch.Tensor): [N, 3, 3]
        smooth_mask: if truthy, return a soft exp(-err/smooth_mask) mask
            instead of a hard threshold.
        return_relative_depth_error: return the raw error instead of the mask.
        depth_interpolation_mode: "bilinear", "nearest-exact", or "combined".
    Returns:
        calculable_mask (torch.Tensor): [N, L]
        warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>, normalized
    """
    (
        n,
        h,
        w,
    ) = depth0.shape
    if depth_interpolation_mode == "combined":
        # Fill holes from bilinear interpolation with nearest-neighbour
        # interpolation (inloc-inspired): run both modes and merge.
        if smooth_mask:
            raise NotImplementedError("Combined bilinear and NN warp not implemented")
        valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
                      smooth_mask = smooth_mask,
                      return_relative_depth_error = return_relative_depth_error,
                      depth_interpolation_mode = "bilinear",
                      relative_depth_error_threshold = relative_depth_error_threshold)
        valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
                      smooth_mask = smooth_mask,
                      return_relative_depth_error = return_relative_depth_error,
                      depth_interpolation_mode = "nearest-exact",
                      relative_depth_error_threshold = relative_depth_error_threshold)
        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
        warp = warp_bilinear.clone()
        warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
        valid = valid_bilinear | valid_nearest
        return valid, warp

    # Sample source depth at the (normalized) keypoint locations.
    kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
        :, 0, :, 0
    ]
    kpts0 = torch.stack(
        (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    # Keypoints with non-positive depth cannot be warped.
    nonzero_mask = kpts0_depth > 0

    # Unproject to camera-0 coordinates.
    kpts0_h = (
        torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
        * kpts0_depth[..., None]
    )  # (N, L, 3)
    kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)
    kpts0_cam = kpts0_n

    # Rigid transform into camera-1 coordinates.
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project onto image 1.
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0 = w_kpts0_h[:, :, :2] / (
        w_kpts0_h[:, :, [2]] + 1e-4
    )  # (N, L, 2), +1e-4 to avoid zero depth

    # Covisibility: projected point must land strictly inside image 1.
    h, w = depth1.shape[1:3]
    covisible_mask = (
        (w_kpts0[:, :, 0] > 0)
        * (w_kpts0[:, :, 0] < w - 1)
        * (w_kpts0[:, :, 1] > 0)
        * (w_kpts0[:, :, 1] < h - 1)
    )
    w_kpts0 = torch.stack(
        (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
    )  # from [0.5, h-0.5] -> [-1+1/h, 1-1/h]

    # Depth consistency: compare projected depth against depth1 at the
    # warped location.
    w_kpts0_depth = F.grid_sample(
        depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
    )[:, 0, :, 0]

    relative_depth_error = (
        (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
    ).abs()
    if not smooth_mask:
        consistent_mask = relative_depth_error < relative_depth_error_threshold
    else:
        consistent_mask = (-relative_depth_error/smooth_mask).exp()
    valid_mask = nonzero_mask * covisible_mask * consistent_mask
    if return_relative_depth_error:
        return relative_depth_error, w_kpts0
    else:
        return valid_mask, w_kpts0
247
+
248
+
249
def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a list of 3-D points.

    Trf: 3x3 or 4x4 projection matrix (typically a Homography), possibly batched.
    pts: numpy/torch/tuple of coordinates. Shape must be (..., 2) or (..., 3).

    ncol: int. number of columns of the result (2 or 3).
    norm: float. if != 0, the result is projected on the z=norm plane.

    Returns an array of projected points with ncol columns, same leading
    shape as pts.
    """
    assert Trf.ndim >= 2
    # Coerce pts to match Trf's array library / dtype.
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # Fast path: batched torch Trf (B, d, d) applied to image-shaped points
    # (B, H, W, d) via a single einsum.
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # Homogeneous transform: rotate/scale then add translation column.
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        # Generic path: flatten batch dims and dispatch on shape compatibility.
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # Homogeneous case: multiply by the transposed linear part and add
            # the translation row.
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
311
+
312
+
313
def inv(mat):
    """Invert a square matrix given as a torch tensor or a numpy array."""
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')
321
+
322
def opencv_camera_to_plucker(poses, K, H, W):
    """
    Build per-pixel Plücker ray encodings for OpenCV-convention cameras.

    Args:
        poses: (B, 4, 4) camera-to-world extrinsics.
        K: (B, 3, 3) intrinsics.
        H, W: image dimensions.

    Returns:
        (B, H, W, 6) tensor of [unit ray direction, origin x direction].
    """
    device = poses.device
    B = poses.shape[0]

    # Homogeneous pixel centers, unprojected through K^-1 and rotated into
    # the world frame to get per-pixel ray directions.
    pixel = torch.from_numpy(get_pixel(H, W).astype(np.float32)).to(device).T.reshape(H, W, 3)[None].repeat(B, 1, 1, 1)  # (B, H, W, 3)
    pixel = torch.einsum('bij, bhwj -> bhwi', torch.inverse(K), pixel)
    ray_directions = torch.einsum('bij, bhwj -> bhwi', poses[..., :3, :3], pixel)

    # All rays of a camera share its world-space center as origin.
    ray_origins = poses[..., :3, 3][:, None, None].repeat(1, H, W, 1)

    # Normalize directions, then form the Plücker moment o x d.
    ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
    plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
    plucker_ray = torch.cat([ray_directions, plucker_normal], dim=-1)

    return plucker_ray
337
+
338
+
339
def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor:
    """
    Detect depth discontinuities: pixels whose neighborhood spans a large
    depth range (local max - local min exceeds the tolerance).

    Args:
        depth (torch.Tensor): shape (..., height, width), linear depth map
        atol (float): absolute tolerance on the local depth range
        rtol (float): relative tolerance (range / depth)
        kernel_size (int): neighborhood window size
        mask (torch.Tensor): optional (..., height, width) validity mask

    Returns:
        edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
    """
    orig_shape = depth.shape
    pad = kernel_size // 2
    flat = depth.reshape(-1, 1, *orig_shape[-2:])

    # Local max and min via max-pooling (min = -maxpool(-x)); invalid pixels
    # are filled with -inf so they never win either pool.
    if mask is None:
        local_max = F.max_pool2d(flat, kernel_size, stride=1, padding=pad)
        local_min = -F.max_pool2d(-flat, kernel_size, stride=1, padding=pad)
    else:
        valid = mask.reshape(-1, 1, *orig_shape[-2:])
        local_max = F.max_pool2d(torch.where(valid, flat, -torch.inf), kernel_size, stride=1, padding=pad)
        local_min = -F.max_pool2d(torch.where(valid, -flat, -torch.inf), kernel_size, stride=1, padding=pad)
    spread = local_max - local_min

    edge = torch.zeros_like(flat, dtype=torch.bool)
    if atol is not None:
        edge |= spread > atol
    if rtol is not None:
        edge |= (spread / flat).nan_to_num_() > rtol
    return edge.reshape(*orig_shape)
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.5.1
2
+ torchvision==0.20.1
3
+ numpy==1.26.4
4
+ pillow
5
+ opencv-python
6
+ plyfile
7
+ huggingface_hub
8
+ safetensors
9
+
10
+ # below for gradio
11
+ gradio
12
+ trimesh
13
+ matplotlib
14
+ scipy
15
+ spaces