Spaces:
Running on Zero
Running on Zero
Commit ·
1efbda0
1
Parent(s): 5c8071a
upload demo
Browse files- app.py +725 -4
- flow3r/models/dinov2/__init__.py +6 -0
- flow3r/models/dinov2/hub/__init__.py +4 -0
- flow3r/models/dinov2/hub/backbones.py +156 -0
- flow3r/models/dinov2/hub/utils.py +39 -0
- flow3r/models/dinov2/layers/__init__.py +11 -0
- flow3r/models/dinov2/layers/attention.py +89 -0
- flow3r/models/dinov2/layers/block.py +259 -0
- flow3r/models/dinov2/layers/dino_head.py +58 -0
- flow3r/models/dinov2/layers/drop_path.py +34 -0
- flow3r/models/dinov2/layers/layer_scale.py +27 -0
- flow3r/models/dinov2/layers/mlp.py +40 -0
- flow3r/models/dinov2/layers/patch_embed.py +88 -0
- flow3r/models/dinov2/layers/swiglu_ffn.py +72 -0
- flow3r/models/dinov2/models/__init__.py +43 -0
- flow3r/models/dinov2/models/vision_transformer.py +404 -0
- flow3r/models/dinov2/utils/__init__.py +4 -0
- flow3r/models/dinov2/utils/cluster.py +95 -0
- flow3r/models/dinov2/utils/config.py +72 -0
- flow3r/models/dinov2/utils/dtype.py +37 -0
- flow3r/models/dinov2/utils/param_groups.py +103 -0
- flow3r/models/dinov2/utils/utils.py +95 -0
- flow3r/models/flow3r.py +233 -0
- flow3r/models/flow_head/dpt_head.py +498 -0
- flow3r/models/flow_head/utils.py +108 -0
- flow3r/models/layers/attention.py +403 -0
- flow3r/models/layers/block.py +406 -0
- flow3r/models/layers/camera_head.py +93 -0
- flow3r/models/layers/pos_embed.py +174 -0
- flow3r/models/layers/transformer_head.py +389 -0
- flow3r/utils/alignment.py +499 -0
- flow3r/utils/basic.py +223 -0
- flow3r/utils/cropping.py +197 -0
- flow3r/utils/debug.py +63 -0
- flow3r/utils/flow_utils.py +472 -0
- flow3r/utils/geometry.py +367 -0
- requirements.txt +15 -0
app.py
CHANGED
|
@@ -1,8 +1,729 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
demo.
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import cv2
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
import gradio as gr
|
| 12 |
+
import sys
|
| 13 |
+
import shutil
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
import glob
|
| 16 |
+
import gc
|
| 17 |
+
import time
|
| 18 |
+
import trimesh
|
| 19 |
+
import matplotlib
|
| 20 |
+
|
| 21 |
+
from flow3r.models.flow3r import Flow3r
|
| 22 |
+
from flow3r.utils.basic import load_images_as_tensor
|
| 23 |
+
from flow3r.utils.geometry import depth_edge
|
| 24 |
+
|
| 25 |
+
from scipy.spatial.transform import Rotation
|
| 26 |
+
from huggingface_hub import hf_hub_download
|
| 27 |
+
|
| 28 |
+
# Select GPU when available; the demo still constructs on CPU otherwise
# (run_model() later raises if CUDA is missing).
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Initializing and loading Flow3r model...")

# Instantiate the network and pull the released checkpoint from the Hub.
model = Flow3r()
ckpt_path = hf_hub_download(repo_id="Clara211111/flow3r", filename="flow3r.bin")
# NOTE(review): weights_only=False allows arbitrary pickle execution; acceptable
# here only because the checkpoint comes from a fixed, trusted repo.
checkpoint = torch.load(ckpt_path, weights_only=False, map_location='cpu')
model.load_state_dict(checkpoint, strict=True)

# Inference-only: freeze batch-norm/dropout behavior and move to the device.
model.eval()
model = model.to(device)
|
| 39 |
+
|
| 40 |
+
# -------------------------------------------------------------------------
|
| 41 |
+
# Utils
|
| 42 |
+
# -------------------------------------------------------------------------
|
| 43 |
+
def predictions_to_glb(
    predictions,
    conf_thres=50.0,
    filter_by_frames="all",
    show_cam=True,
) -> trimesh.Scene:
    """
    Converts predictions to a 3D scene represented as a GLB file.

    Args:
        predictions (dict): Dictionary containing model predictions with keys:
            - points: 3D point coordinates (S, H, W, 3) -- assumed; TODO confirm against model output
            - conf: confidence scores matching points' leading dims (optional; defaults to ones)
            - images: input images, NCHW or NHWC float in [0, 1]
            - camera_poses: camera-to-world matrices, one per frame
        conf_thres (float): Confidence threshold as a percentage in [0, 100];
            points with conf below conf_thres/100 are dropped (default: 50.0).
        filter_by_frames (str): Either "all"/"All" or a "<idx>: <name>" string;
            only the index before the colon is used (default: "all").
        show_cam (bool): Include camera visualization (default: True).

    Returns:
        trimesh.Scene: Processed 3D scene containing point cloud and cameras.

    Raises:
        ValueError: If input predictions structure is invalid.
    """
    if not isinstance(predictions, dict):
        raise ValueError("predictions must be a dictionary")

    # Fallback threshold when the UI passes None.
    if conf_thres is None:
        conf_thres = 10

    print("Building GLB scene")
    selected_frame_idx = None
    if filter_by_frames != "all" and filter_by_frames != "All":
        try:
            # Extract the index part before the colon
            selected_frame_idx = int(filter_by_frames.split(":")[0])
        except (ValueError, IndexError):
            # Malformed filter strings silently fall back to "all frames".
            pass

    pred_world_points = predictions["points"]
    pred_world_points_conf = predictions.get("conf", np.ones_like(pred_world_points[..., 0]))

    # Get images from predictions
    images = predictions["images"]
    # Use extrinsic matrices instead of pred_extrinsic_list
    camera_poses = predictions["camera_poses"]

    # Restrict everything to a single frame while keeping a leading axis.
    if selected_frame_idx is not None:
        pred_world_points = pred_world_points[selected_frame_idx][None]
        pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
        images = images[selected_frame_idx][None]
        camera_poses = camera_poses[selected_frame_idx][None]

    vertices_3d = pred_world_points.reshape(-1, 3)
    # Handle different image formats - check if images need transposing
    if images.ndim == 4 and images.shape[1] == 3:  # NCHW format
        colors_rgb = np.transpose(images, (0, 2, 3, 1))
    else:  # Assume already in NHWC format
        colors_rgb = images
    # Images are assumed to be floats in [0, 1] -- TODO confirm upstream normalization.
    colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)

    conf = pred_world_points_conf.reshape(-1)
    # Convert percentage threshold to actual confidence value
    if conf_thres == 0.0:
        conf_threshold = 0.0
    else:
        # conf_threshold = np.percentile(conf, conf_thres)
        conf_threshold = conf_thres / 100

    # Drop points below the threshold and near-zero-confidence points
    # (e.g. those zeroed at depth edges by run_model).
    conf_mask = (conf >= conf_threshold) & (conf > 1e-5)

    vertices_3d = vertices_3d[conf_mask]
    colors_rgb = colors_rgb[conf_mask]

    if vertices_3d is None or np.asarray(vertices_3d).size == 0:
        # Degenerate case: keep the scene non-empty with a single white point.
        vertices_3d = np.array([[1, 0, 0]])
        colors_rgb = np.array([[255, 255, 255]])
        scene_scale = 1
    else:
        # Calculate the 5th and 95th percentiles along each axis
        lower_percentile = np.percentile(vertices_3d, 5, axis=0)
        upper_percentile = np.percentile(vertices_3d, 95, axis=0)

        # Calculate the diagonal length of the percentile bounding box
        scene_scale = np.linalg.norm(upper_percentile - lower_percentile)

    colormap = matplotlib.colormaps.get_cmap("gist_rainbow")

    # Initialize a 3D scene
    scene_3d = trimesh.Scene()

    # Add point cloud data to the scene
    point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)

    scene_3d.add_geometry(point_cloud_data)

    # Prepare 4x4 matrices for camera extrinsics
    num_cameras = len(camera_poses)

    if show_cam:
        # Add camera models to the scene, one rainbow color per frame.
        for i in range(num_cameras):
            camera_to_world = camera_poses[i]
            rgba_color = colormap(i / num_cameras)
            current_color = tuple(int(255 * x) for x in rgba_color[:3])

            # integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)
            integrate_camera_into_scene(scene_3d, camera_to_world, current_color, 1.)  # fixed camera size

    # Rotate scene for better visualize
    align_rotation = np.eye(4)
    align_rotation[:3, :3] = Rotation.from_euler("y", 100, degrees=True).as_matrix()  # plane rotate
    align_rotation[:3, :3] = align_rotation[:3, :3] @ Rotation.from_euler("x", 155, degrees=True).as_matrix()  # roll
    scene_3d.apply_transform(align_rotation)

    print("GLB Scene built")
    return scene_3d
|
| 161 |
+
|
| 162 |
+
def get_opengl_conversion_matrix() -> np.ndarray:
    """
    Return the 4x4 matrix that converts to OpenGL camera conventions.

    The conversion negates the y and z axes while leaving x and the
    homogeneous coordinate untouched.

    Returns:
        numpy.ndarray: A 4x4 diagonal conversion matrix diag(1, -1, -1, 1).
    """
    # Express the axis flips directly as a diagonal matrix.
    return np.diag([1.0, -1.0, -1.0, 1.0])
|
| 177 |
+
|
| 178 |
+
def integrate_camera_into_scene(scene: trimesh.Scene, transform: np.ndarray, face_colors: tuple, scene_scale: float):
    """
    Integrates a fake camera mesh into the 3D scene.

    Builds a triple-shelled cone frustum (original, 95%-scaled, and slightly
    rotated copies) to suggest a camera body, colors it, and adds it to the scene.

    Args:
        scene (trimesh.Scene): The 3D scene to add the camera model.
        transform (np.ndarray): 4x4 camera-to-world matrix for positioning.
        face_colors (tuple): RGB color (0-255 ints) of the camera faces.
        scene_scale (float): Scale of the scene; the cone size is proportional.
    """

    cam_width = scene_scale * 0.05
    cam_height = scene_scale * 0.1

    # Create cone shape for camera: rotate 45 deg about z so the 4-sided cone
    # reads as a square pyramid, and push it back by its height along z.
    rot_45_degree = np.eye(4)
    rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
    rot_45_degree[2, 3] = -cam_height

    opengl_transform = get_opengl_conversion_matrix()
    # Combine transformations: world pose -> OpenGL axes -> cone alignment.
    complete_transform = transform @ opengl_transform @ rot_45_degree
    camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)

    # Generate mesh for the camera: a 2-degree twist for the third shell so
    # the duplicated surfaces do not z-fight.
    slight_rotation = np.eye(4)
    slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()

    vertices_combined = np.concatenate(
        [
            camera_cone_shape.vertices,
            0.95 * camera_cone_shape.vertices,
            transform_points(slight_rotation, camera_cone_shape.vertices),
        ]
    )
    vertices_transformed = transform_points(complete_transform, vertices_combined)

    # Faces index into the three stacked vertex copies above.
    mesh_faces = compute_camera_faces(camera_cone_shape)

    # Add the camera mesh to the scene
    camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
    camera_mesh.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(camera_mesh)
|
| 221 |
+
|
| 222 |
+
def transform_points(transformation: np.ndarray, points: np.ndarray, dim: int = None) -> np.ndarray:
    """
    Apply a homogeneous 4x4 transformation to a batch of points.

    Args:
        transformation (np.ndarray): (..., 4, 4) transformation matrix.
        points (np.ndarray): (..., 3) points to be transformed.
        dim (int, optional): Number of trailing coordinates to keep in the
            result; defaults to the points' last-axis size.

    Returns:
        np.ndarray: Transformed points with the same leading shape.
    """
    pts = np.asarray(points)
    leading_shape = pts.shape[:-1]
    out_dim = dim if dim else pts.shape[-1]

    # Transpose so points can stay as row vectors: split the matrix into its
    # linear part (all rows but the last) and the translation row.
    mat_t = transformation.swapaxes(-1, -2)
    linear_part = mat_t[..., :-1, :]
    translation_row = mat_t[..., -1:, :]
    transformed = pts @ linear_part + translation_row

    # Trim the homogeneous coordinate and restore the leading shape.
    return transformed[..., :out_dim].reshape(*leading_shape, out_dim)
|
| 245 |
+
|
| 246 |
+
def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
    """
    Computes the faces for the camera mesh.

    The camera mesh stacks three copies of the cone's vertices (see
    integrate_camera_into_scene); faces here stitch the copies together,
    skipping any face touching vertex 0 (the cone apex).

    Args:
        cone_shape (trimesh.Trimesh): The shape of the camera cone.

    Returns:
        np.ndarray: Array of faces for the camera mesh, including reversed
        duplicates so the surface is visible from both sides.
    """
    # Create pseudo cameras
    faces_list = []
    num_vertices_cone = len(cone_shape.vertices)

    for face in cone_shape.faces:
        # Skip faces touching the apex vertex so the "lens" end stays open.
        if 0 in face:
            continue
        v1, v2, v3 = face
        # Same face indices shifted into the second and third vertex copies.
        v1_offset, v2_offset, v3_offset = face + num_vertices_cone
        v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone

        faces_list.extend(
            [
                (v1, v2, v2_offset),
                (v1, v1_offset, v3),
                (v3_offset, v2, v3),
                (v1, v2, v2_offset_2),
                (v1, v1_offset_2, v3),
                (v3_offset_2, v2, v3),
            ]
        )

    # Append every face with reversed winding so both sides render.
    faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
    return np.array(faces_list)
|
| 280 |
+
|
| 281 |
+
# -------------------------------------------------------------------------
|
| 282 |
+
# 1) Core model inference
|
| 283 |
+
# -------------------------------------------------------------------------
|
| 284 |
+
def run_model(target_dir, model) -> dict:
    """
    Run Flow3r inference on all images under ``target_dir/images``.

    Args:
        target_dir (str): Directory containing an ``images`` subfolder.
        model: The Flow3r network to run.

    Returns:
        dict: Model predictions with tensors converted to numpy arrays and
        the batch dimension removed. Includes at least 'images' (N, H, W, 3)
        and sigmoid-activated 'conf'; 'local_points' is consumed and deleted.

    Raises:
        ValueError: If CUDA is unavailable or no images were found.
    """
    print(f"Processing images from {target_dir}")

    # Device check -- this demo requires a GPU for inference.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")

    # Move model to device
    model = model.to(device)
    model.eval()

    # Load and preprocess images
    image_names = glob.glob(os.path.join(target_dir, "images", "*"))
    image_names = sorted(image_names)
    print(f"Found {len(image_names)} images")
    if len(image_names) == 0:
        raise ValueError("No images found. Check your upload.")

    # interval = 10 if target_dir.endswith('.mp4') else 1
    interval = 1
    imgs = load_images_as_tensor(os.path.join(target_dir, "images"), interval=interval).to(device)  # (N, 3, H, W)

    # Run inference
    print("Running inference...")
    # bf16 needs compute capability >= 8 (Ampere); fall back to fp16 otherwise.
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            predictions = model(imgs[None])  # Add batch dimension
            # Store NHWC images alongside predictions for visualization.
            predictions['images'] = imgs[None].permute(0, 1, 3, 4, 2)
            # Raw conf logits -> probabilities.
            predictions['conf'] = torch.sigmoid(predictions['conf'])
            # Zero confidence at depth discontinuities so edge points are
            # filtered out downstream (predictions_to_glb masks conf <= 1e-5).
            edge = depth_edge(predictions['local_points'][..., 2], rtol=0.03)
            predictions['conf'][edge] = 0.0
            del predictions['local_points']

    # Convert tensors to numpy
    for key in predictions.keys():
        if isinstance(predictions[key], torch.Tensor):
            predictions[key] = predictions[key].cpu().numpy().squeeze(0)  # remove batch dimension

    # Clean up
    torch.cuda.empty_cache()
    return predictions
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
# -------------------------------------------------------------------------
|
| 331 |
+
# 2) Handle uploaded video/images --> produce target_dir + images
|
| 332 |
+
# -------------------------------------------------------------------------
|
| 333 |
+
def handle_uploads(input_video, input_images):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or extracted frames from video into it.

    Args:
        input_video: Uploaded video -- either a path string or a Gradio dict
            with a "name" key; None when no video was provided.
        input_images: Iterable of uploaded image files (same two forms);
            None when no images were provided.

    Returns:
        tuple[str, list[str]]: (target_dir, sorted list of image paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique folder name (microsecond-stamped to avoid collisions).
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle images ---
    if input_images is not None:
        for file_data in input_images:
            # Gradio may hand back either a dict with a "name" key or a path.
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    # --- Handle video ---
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        try:
            fps = vs.get(cv2.CAP_PROP_FPS)
            # FIX: CAP_PROP_FPS can return 0 (or NaN) for broken containers,
            # which previously caused ZeroDivisionError in `count % frame_interval`.
            # Clamp to at least 1 so we still extract every frame.
            frame_interval = max(int(fps * 1), 1)  # 1 frame/sec

            count = 0
            video_frame_num = 0
            while True:
                gotit, frame = vs.read()
                if not gotit:
                    break
                count += 1
                if count % frame_interval == 0:
                    image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
                    cv2.imwrite(image_path, frame)
                    image_paths.append(image_path)
                    video_frame_num += 1
        finally:
            # FIX: release the capture handle (was previously leaked).
            vs.release()

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
    return target_dir, image_paths
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
# -------------------------------------------------------------------------
|
| 399 |
+
# 3) Update gallery on upload
|
| 400 |
+
# -------------------------------------------------------------------------
|
| 401 |
+
def update_gallery_on_upload(input_video, input_images):
    """
    Whenever user uploads or changes files, immediately handle them
    and show in the gallery.

    Returns a 4-tuple of (viewer reset, target_dir, image_paths, status text);
    all four are None when nothing has been uploaded yet.
    """
    # Guard: nothing to process until at least one upload exists.
    has_uploads = bool(input_video) or bool(input_images)
    if not has_uploads:
        return None, None, None, None

    target_dir, image_paths = handle_uploads(input_video, input_images)
    status = "Upload complete. Click 'Reconstruct' to begin 3D processing."
    return None, target_dir, image_paths, status
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
# -------------------------------------------------------------------------
|
| 414 |
+
# 4) Reconstruction: uses the target_dir plus any viz parameters
|
| 415 |
+
# -------------------------------------------------------------------------
|
| 416 |
+
def gradio_demo(
    target_dir,
    conf_thres=3.0,
    frame_filter="All",
    show_cam=True,
):
    """
    Perform reconstruction using the already-created target_dir/images.

    Args:
        target_dir (str): Directory produced by handle_uploads().
        conf_thres (float): Confidence threshold (percentage) for the GLB export.
        frame_filter (str): "All" or an "<idx>: <name>" frame selector.
        show_cam (bool): Whether to render camera frustums in the scene.

    Returns:
        tuple: (glb path or None, log message, frame-filter dropdown update
        or None) -- always three elements so the Gradio output binding holds.
    """
    # FIX: guard falsy/None target_dir first (os.path.isdir(None) raises
    # TypeError), matching the guard style in update_visualization.
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        # FIX: previously returned FOUR values here while the success path
        # returns three, which breaks the Gradio outputs mapping.
        return None, "No valid target directory found. Please upload first.", None

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare frame_filter dropdown
    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    # FIX: label entries with the actual filename; predictions_to_glb parses
    # only the index before the colon, so the label text is informational.
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running run_model...")
    with torch.no_grad():
        predictions = run_model(target_dir, model)

    # Save predictions so update_visualization can re-render without re-running.
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name encoding the viz parameters (acts as a cache key).
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."

    return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# -------------------------------------------------------------------------
|
| 479 |
+
# 5) Helper functions for UI resets + re-visualization
|
| 480 |
+
# -------------------------------------------------------------------------
|
| 481 |
+
def clear_fields():
    """
    Reset callback for the UI: returns None so the bound component
    (the 3D viewer) is cleared.
    """
    # A single None clears whichever output this callback is wired to.
    return None
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def update_log():
    """
    Return the interim status message shown while reconstruction runs.
    """
    return "Loading and Reconstructing..."
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def update_visualization(
    target_dir, conf_thres, frame_filter, show_cam, is_example
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for new
    parameters, and return it for the 3D viewer.

    Returns (glb path or None, status message). If is_example == "True"
    or no reconstruction exists yet, returns None plus a prompt.
    """
    no_recon_msg = "No reconstruction available. Please click the Reconstruct button first."

    # Example clicks and missing/invalid target dirs share the same prompt.
    if is_example == "True" or not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, no_recon_msg

    npz_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(npz_path):
        return None, f"No reconstruction available at {npz_path}. Please run 'Reconstruct' first."

    # Pull only the arrays the GLB builder needs out of the archive.
    wanted_keys = ("images", "points", "conf", "camera_poses")
    archive = np.load(npz_path)
    preds = {k: np.array(archive[k]) for k in wanted_keys}

    # The GLB filename encodes all viz parameters, so it doubles as a cache key.
    sanitized = frame_filter.replace(".", "_").replace(":", "").replace(" ", "_")
    glb_path = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{sanitized}_cam{show_cam}.glb",
    )

    if not os.path.exists(glb_path):
        scene = predictions_to_glb(
            preds,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            show_cam=show_cam,
        )
        scene.export(file_obj=glb_path)

    return glb_path, "Updating Visualization"
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
# -------------------------------------------------------------------------
|
| 543 |
+
# Example images
|
| 544 |
+
# -------------------------------------------------------------------------
|
| 545 |
+
|
| 546 |
+
# Paths to bundled example videos shown in the demo's examples section.
# NOTE(review): these are relative paths -- assumes the app runs from the
# repo root where examples/videos/ exists.
great_wall_video = "examples/videos/great_wall.mp4"
colosseum_video = "examples/videos/Colosseum.mp4"
room_video = "examples/videos/room.mp4"
kitchen_video = "examples/videos/kitchen.mp4"
fern_video = "examples/videos/fern.mp4"
single_cartoon_video = "examples/videos/single_cartoon.mp4"
single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
pyramid_video = "examples/videos/pyramid.mp4"
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
# -------------------------------------------------------------------------
|
| 557 |
+
# 6) Build Gradio UI
|
| 558 |
+
# -------------------------------------------------------------------------
|
| 559 |
+
theme = gr.themes.Ocean()
|
| 560 |
+
theme.set(
|
| 561 |
+
checkbox_label_background_fill_selected="*button_primary_background_fill",
|
| 562 |
+
checkbox_label_text_color_selected="*button_primary_text_color",
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
with gr.Blocks(
|
| 566 |
+
theme=theme,
|
| 567 |
+
css="""
|
| 568 |
+
.custom-log * {
|
| 569 |
+
font-style: italic;
|
| 570 |
+
font-size: 22px !important;
|
| 571 |
+
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
|
| 572 |
+
-webkit-background-clip: text;
|
| 573 |
+
background-clip: text;
|
| 574 |
+
font-weight: bold !important;
|
| 575 |
+
color: transparent !important;
|
| 576 |
+
text-align: center !important;
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
.example-log * {
|
| 580 |
+
font-style: italic;
|
| 581 |
+
font-size: 16px !important;
|
| 582 |
+
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
|
| 583 |
+
-webkit-background-clip: text;
|
| 584 |
+
background-clip: text;
|
| 585 |
+
color: transparent !important;
|
| 586 |
+
}
|
| 587 |
+
|
| 588 |
+
#my_radio .wrap {
|
| 589 |
+
display: flex;
|
| 590 |
+
flex-wrap: nowrap;
|
| 591 |
+
justify-content: center;
|
| 592 |
+
align-items: center;
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
#my_radio .wrap label {
|
| 596 |
+
display: flex;
|
| 597 |
+
width: 50%;
|
| 598 |
+
justify-content: center;
|
| 599 |
+
align-items: center;
|
| 600 |
+
margin: 0;
|
| 601 |
+
padding: 10px 0;
|
| 602 |
+
box-sizing: border-box;
|
| 603 |
+
}
|
| 604 |
+
""",
|
| 605 |
+
) as demo:
|
| 606 |
+
# Instead of gr.State, we use a hidden Textbox:
|
| 607 |
+
is_example = gr.Textbox(label="is_example", visible=False, value="None")
|
| 608 |
+
num_images = gr.Textbox(label="num_images", visible=False, value="None")
|
| 609 |
+
|
| 610 |
+
gr.HTML(
|
| 611 |
+
"""
|
| 612 |
+
<h1>Flow3r: Factored Flow Prediction for Visual Geometry Learning</h1>
|
| 613 |
+
<p>
|
| 614 |
+
<a href="https://github.com/Kidrauh/flow3r">GitHub Repository</a> |
|
| 615 |
+
<a href="https://flow3r-project.github.io/">Project Page</a>
|
| 616 |
+
</p>
|
| 617 |
+
|
| 618 |
+
<div style="font-size: 16px; line-height: 1.5;">
|
| 619 |
+
<p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. Flow3r takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
|
| 620 |
+
|
| 621 |
+
</div>
|
| 622 |
+
"""
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
|
| 626 |
+
|
| 627 |
+
with gr.Row():
|
| 628 |
+
with gr.Column(scale=2):
|
| 629 |
+
input_video = gr.Video(label="Upload Video", interactive=True)
|
| 630 |
+
input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
|
| 631 |
+
|
| 632 |
+
image_gallery = gr.Gallery(
|
| 633 |
+
label="Preview",
|
| 634 |
+
columns=4,
|
| 635 |
+
height="300px",
|
| 636 |
+
# show_download_button=True,
|
| 637 |
+
object_fit="contain",
|
| 638 |
+
preview=True,
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
with gr.Column(scale=4):
|
| 642 |
+
with gr.Column():
|
| 643 |
+
gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
|
| 644 |
+
log_output = gr.Markdown(
|
| 645 |
+
"Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
|
| 646 |
+
)
|
| 647 |
+
reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
|
| 648 |
+
|
| 649 |
+
with gr.Row():
|
| 650 |
+
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
|
| 651 |
+
clear_btn = gr.ClearButton(
|
| 652 |
+
[input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
|
| 653 |
+
scale=1,
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
with gr.Row():
|
| 657 |
+
conf_thres = gr.Slider(minimum=0, maximum=100, value=0, step=0.1, label="Confidence Threshold (%)")
|
| 658 |
+
frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
|
| 659 |
+
with gr.Column():
|
| 660 |
+
show_cam = gr.Checkbox(label="Show Camera", value=True)
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
|
| 664 |
+
fn=update_log, inputs=[], outputs=[log_output]
|
| 665 |
+
).then(
|
| 666 |
+
fn=gradio_demo,
|
| 667 |
+
inputs=[
|
| 668 |
+
target_dir_output,
|
| 669 |
+
conf_thres,
|
| 670 |
+
frame_filter,
|
| 671 |
+
show_cam,
|
| 672 |
+
],
|
| 673 |
+
outputs=[reconstruction_output, log_output, frame_filter],
|
| 674 |
+
).then(
|
| 675 |
+
fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
|
| 676 |
+
)
|
| 677 |
|
| 678 |
+
# -------------------------------------------------------------------------
|
| 679 |
+
# Real-time Visualization Updates
|
| 680 |
+
# -------------------------------------------------------------------------
|
| 681 |
+
conf_thres.change(
|
| 682 |
+
update_visualization,
|
| 683 |
+
[
|
| 684 |
+
target_dir_output,
|
| 685 |
+
conf_thres,
|
| 686 |
+
frame_filter,
|
| 687 |
+
show_cam,
|
| 688 |
+
is_example,
|
| 689 |
+
],
|
| 690 |
+
[reconstruction_output, log_output],
|
| 691 |
+
)
|
| 692 |
+
frame_filter.change(
|
| 693 |
+
update_visualization,
|
| 694 |
+
[
|
| 695 |
+
target_dir_output,
|
| 696 |
+
conf_thres,
|
| 697 |
+
frame_filter,
|
| 698 |
+
show_cam,
|
| 699 |
+
is_example,
|
| 700 |
+
],
|
| 701 |
+
[reconstruction_output, log_output],
|
| 702 |
+
)
|
| 703 |
|
| 704 |
+
show_cam.change(
|
| 705 |
+
update_visualization,
|
| 706 |
+
[
|
| 707 |
+
target_dir_output,
|
| 708 |
+
conf_thres,
|
| 709 |
+
frame_filter,
|
| 710 |
+
show_cam,
|
| 711 |
+
is_example,
|
| 712 |
+
],
|
| 713 |
+
[reconstruction_output, log_output],
|
| 714 |
+
)
|
| 715 |
+
# -------------------------------------------------------------------------
|
| 716 |
+
# Auto-update gallery whenever user uploads or changes their files
|
| 717 |
+
# -------------------------------------------------------------------------
|
| 718 |
+
input_video.change(
|
| 719 |
+
fn=update_gallery_on_upload,
|
| 720 |
+
inputs=[input_video, input_images],
|
| 721 |
+
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
|
| 722 |
+
)
|
| 723 |
+
input_images.change(
|
| 724 |
+
fn=update_gallery_on_upload,
|
| 725 |
+
inputs=[input_video, input_images],
|
| 726 |
+
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
|
| 727 |
+
)
|
| 728 |
|
| 729 |
+
demo.queue(max_size=20).launch(show_error=True, share=True)
|
flow3r/models/dinov2/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
__version__ = "0.0.1"
|
flow3r/models/dinov2/hub/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
flow3r/models/dinov2/hub/backbones.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Weights(Enum):
|
| 15 |
+
LVD142M = "LVD142M"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _make_dinov2_model(
|
| 19 |
+
*,
|
| 20 |
+
arch_name: str = "vit_large",
|
| 21 |
+
img_size: int = 518,
|
| 22 |
+
patch_size: int = 14,
|
| 23 |
+
init_values: float = 1.0,
|
| 24 |
+
ffn_layer: str = "mlp",
|
| 25 |
+
block_chunks: int = 0,
|
| 26 |
+
num_register_tokens: int = 0,
|
| 27 |
+
interpolate_antialias: bool = False,
|
| 28 |
+
interpolate_offset: float = 0.1,
|
| 29 |
+
pretrained: bool = True,
|
| 30 |
+
weights: Union[Weights, str] = Weights.LVD142M,
|
| 31 |
+
**kwargs,
|
| 32 |
+
):
|
| 33 |
+
from ..models import vision_transformer as vits
|
| 34 |
+
|
| 35 |
+
if isinstance(weights, str):
|
| 36 |
+
try:
|
| 37 |
+
weights = Weights[weights]
|
| 38 |
+
except KeyError:
|
| 39 |
+
raise AssertionError(f"Unsupported weights: {weights}")
|
| 40 |
+
|
| 41 |
+
model_base_name = _make_dinov2_model_name(arch_name, patch_size)
|
| 42 |
+
vit_kwargs = dict(
|
| 43 |
+
img_size=img_size,
|
| 44 |
+
patch_size=patch_size,
|
| 45 |
+
init_values=init_values,
|
| 46 |
+
ffn_layer=ffn_layer,
|
| 47 |
+
block_chunks=block_chunks,
|
| 48 |
+
num_register_tokens=num_register_tokens,
|
| 49 |
+
interpolate_antialias=interpolate_antialias,
|
| 50 |
+
interpolate_offset=interpolate_offset,
|
| 51 |
+
)
|
| 52 |
+
vit_kwargs.update(**kwargs)
|
| 53 |
+
model = vits.__dict__[arch_name](**vit_kwargs)
|
| 54 |
+
|
| 55 |
+
if pretrained:
|
| 56 |
+
model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
|
| 57 |
+
url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
|
| 58 |
+
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
|
| 59 |
+
model.load_state_dict(state_dict, strict=True)
|
| 60 |
+
|
| 61 |
+
return model
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 65 |
+
"""
|
| 66 |
+
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 67 |
+
"""
|
| 68 |
+
return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 72 |
+
"""
|
| 73 |
+
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 74 |
+
"""
|
| 75 |
+
return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 79 |
+
"""
|
| 80 |
+
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 81 |
+
"""
|
| 82 |
+
return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 86 |
+
"""
|
| 87 |
+
DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 88 |
+
"""
|
| 89 |
+
return _make_dinov2_model(
|
| 90 |
+
arch_name="vit_giant2",
|
| 91 |
+
ffn_layer="swiglufused",
|
| 92 |
+
weights=weights,
|
| 93 |
+
pretrained=pretrained,
|
| 94 |
+
**kwargs,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 99 |
+
"""
|
| 100 |
+
DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 101 |
+
"""
|
| 102 |
+
return _make_dinov2_model(
|
| 103 |
+
arch_name="vit_small",
|
| 104 |
+
pretrained=pretrained,
|
| 105 |
+
weights=weights,
|
| 106 |
+
num_register_tokens=4,
|
| 107 |
+
interpolate_antialias=True,
|
| 108 |
+
interpolate_offset=0.0,
|
| 109 |
+
**kwargs,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 114 |
+
"""
|
| 115 |
+
DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 116 |
+
"""
|
| 117 |
+
return _make_dinov2_model(
|
| 118 |
+
arch_name="vit_base",
|
| 119 |
+
pretrained=pretrained,
|
| 120 |
+
weights=weights,
|
| 121 |
+
num_register_tokens=4,
|
| 122 |
+
interpolate_antialias=True,
|
| 123 |
+
interpolate_offset=0.0,
|
| 124 |
+
**kwargs,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 129 |
+
"""
|
| 130 |
+
DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 131 |
+
"""
|
| 132 |
+
return _make_dinov2_model(
|
| 133 |
+
arch_name="vit_large",
|
| 134 |
+
pretrained=pretrained,
|
| 135 |
+
weights=weights,
|
| 136 |
+
num_register_tokens=4,
|
| 137 |
+
interpolate_antialias=True,
|
| 138 |
+
interpolate_offset=0.0,
|
| 139 |
+
**kwargs,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 144 |
+
"""
|
| 145 |
+
DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 146 |
+
"""
|
| 147 |
+
return _make_dinov2_model(
|
| 148 |
+
arch_name="vit_giant2",
|
| 149 |
+
ffn_layer="swiglufused",
|
| 150 |
+
weights=weights,
|
| 151 |
+
pretrained=pretrained,
|
| 152 |
+
num_register_tokens=4,
|
| 153 |
+
interpolate_antialias=True,
|
| 154 |
+
interpolate_offset=0.0,
|
| 155 |
+
**kwargs,
|
| 156 |
+
)
|
flow3r/models/dinov2/hub/utils.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import itertools
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
|
| 18 |
+
compact_arch_name = arch_name.replace("_", "")[:4]
|
| 19 |
+
registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
|
| 20 |
+
return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CenterPadding(nn.Module):
|
| 24 |
+
def __init__(self, multiple):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.multiple = multiple
|
| 27 |
+
|
| 28 |
+
def _get_pad(self, size):
|
| 29 |
+
new_size = math.ceil(size / self.multiple) * self.multiple
|
| 30 |
+
pad_size = new_size - size
|
| 31 |
+
pad_size_left = pad_size // 2
|
| 32 |
+
pad_size_right = pad_size - pad_size_left
|
| 33 |
+
return pad_size_left, pad_size_right
|
| 34 |
+
|
| 35 |
+
@torch.inference_mode()
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
|
| 38 |
+
output = F.pad(x, pads)
|
| 39 |
+
return output
|
flow3r/models/dinov2/layers/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .dino_head import DINOHead
|
| 7 |
+
from .mlp import Mlp
|
| 8 |
+
from .patch_embed import PatchEmbed
|
| 9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 10 |
+
from .block import NestedTensorBlock
|
| 11 |
+
from .attention import MemEffAttention
|
flow3r/models/dinov2/layers/attention.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from torch import nn
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("dinov2")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 22 |
+
try:
|
| 23 |
+
if XFORMERS_ENABLED:
|
| 24 |
+
from xformers.ops import memory_efficient_attention, unbind
|
| 25 |
+
|
| 26 |
+
XFORMERS_AVAILABLE = True
|
| 27 |
+
# warnings.warn("xFormers is available (Attention)")
|
| 28 |
+
else:
|
| 29 |
+
# warnings.warn("xFormers is disabled (Attention)")
|
| 30 |
+
raise ImportError
|
| 31 |
+
except ImportError:
|
| 32 |
+
XFORMERS_AVAILABLE = False
|
| 33 |
+
# warnings.warn("xFormers is not available (Attention)")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Attention(nn.Module):
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
dim: int,
|
| 40 |
+
num_heads: int = 8,
|
| 41 |
+
qkv_bias: bool = False,
|
| 42 |
+
proj_bias: bool = True,
|
| 43 |
+
attn_drop: float = 0.0,
|
| 44 |
+
proj_drop: float = 0.0,
|
| 45 |
+
) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.num_heads = num_heads
|
| 48 |
+
head_dim = dim // num_heads
|
| 49 |
+
self.scale = head_dim**-0.5
|
| 50 |
+
|
| 51 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 52 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 53 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 54 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 55 |
+
|
| 56 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 57 |
+
B, N, C = x.shape
|
| 58 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 59 |
+
|
| 60 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
| 61 |
+
attn = q @ k.transpose(-2, -1)
|
| 62 |
+
|
| 63 |
+
attn = attn.softmax(dim=-1)
|
| 64 |
+
attn = self.attn_drop(attn)
|
| 65 |
+
|
| 66 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 67 |
+
x = self.proj(x)
|
| 68 |
+
x = self.proj_drop(x)
|
| 69 |
+
return x
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class MemEffAttention(Attention):
|
| 73 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 74 |
+
if not XFORMERS_AVAILABLE:
|
| 75 |
+
if attn_bias is not None:
|
| 76 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 77 |
+
return super().forward(x)
|
| 78 |
+
|
| 79 |
+
B, N, C = x.shape
|
| 80 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 81 |
+
|
| 82 |
+
q, k, v = unbind(qkv, 2)
|
| 83 |
+
|
| 84 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
| 85 |
+
x = x.reshape([B, N, C])
|
| 86 |
+
|
| 87 |
+
x = self.proj(x)
|
| 88 |
+
x = self.proj_drop(x)
|
| 89 |
+
return x
|
flow3r/models/dinov2/layers/block.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn, Tensor
|
| 17 |
+
|
| 18 |
+
from .attention import Attention, MemEffAttention
|
| 19 |
+
from .drop_path import DropPath
|
| 20 |
+
from .layer_scale import LayerScale
|
| 21 |
+
from .mlp import Mlp
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("dinov2")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 28 |
+
try:
|
| 29 |
+
if XFORMERS_ENABLED:
|
| 30 |
+
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
| 31 |
+
|
| 32 |
+
XFORMERS_AVAILABLE = True
|
| 33 |
+
# warnings.warn("xFormers is available (Block)")
|
| 34 |
+
else:
|
| 35 |
+
# warnings.warn("xFormers is disabled (Block)")
|
| 36 |
+
raise ImportError
|
| 37 |
+
except ImportError:
|
| 38 |
+
XFORMERS_AVAILABLE = False
|
| 39 |
+
# warnings.warn("xFormers is not available (Block)")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Block(nn.Module):
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
dim: int,
|
| 46 |
+
num_heads: int,
|
| 47 |
+
mlp_ratio: float = 4.0,
|
| 48 |
+
qkv_bias: bool = False,
|
| 49 |
+
proj_bias: bool = True,
|
| 50 |
+
ffn_bias: bool = True,
|
| 51 |
+
drop: float = 0.0,
|
| 52 |
+
attn_drop: float = 0.0,
|
| 53 |
+
init_values=None,
|
| 54 |
+
drop_path: float = 0.0,
|
| 55 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 56 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 57 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 58 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 59 |
+
) -> None:
|
| 60 |
+
super().__init__()
|
| 61 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
| 62 |
+
self.norm1 = norm_layer(dim)
|
| 63 |
+
self.attn = attn_class(
|
| 64 |
+
dim,
|
| 65 |
+
num_heads=num_heads,
|
| 66 |
+
qkv_bias=qkv_bias,
|
| 67 |
+
proj_bias=proj_bias,
|
| 68 |
+
attn_drop=attn_drop,
|
| 69 |
+
proj_drop=drop,
|
| 70 |
+
)
|
| 71 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 72 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 73 |
+
|
| 74 |
+
self.norm2 = norm_layer(dim)
|
| 75 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 76 |
+
self.mlp = ffn_layer(
|
| 77 |
+
in_features=dim,
|
| 78 |
+
hidden_features=mlp_hidden_dim,
|
| 79 |
+
act_layer=act_layer,
|
| 80 |
+
drop=drop,
|
| 81 |
+
bias=ffn_bias,
|
| 82 |
+
)
|
| 83 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 84 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 85 |
+
|
| 86 |
+
self.sample_drop_ratio = drop_path
|
| 87 |
+
|
| 88 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 89 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
| 90 |
+
return self.ls1(self.attn(self.norm1(x)))
|
| 91 |
+
|
| 92 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 93 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 94 |
+
|
| 95 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
| 96 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
| 97 |
+
x = drop_add_residual_stochastic_depth(
|
| 98 |
+
x,
|
| 99 |
+
residual_func=attn_residual_func,
|
| 100 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 101 |
+
)
|
| 102 |
+
x = drop_add_residual_stochastic_depth(
|
| 103 |
+
x,
|
| 104 |
+
residual_func=ffn_residual_func,
|
| 105 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 106 |
+
)
|
| 107 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
| 108 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
| 109 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 110 |
+
else:
|
| 111 |
+
x = x + attn_residual_func(x)
|
| 112 |
+
x = x + ffn_residual_func(x)
|
| 113 |
+
return x
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def drop_add_residual_stochastic_depth(
|
| 117 |
+
x: Tensor,
|
| 118 |
+
residual_func: Callable[[Tensor], Tensor],
|
| 119 |
+
sample_drop_ratio: float = 0.0,
|
| 120 |
+
) -> Tensor:
|
| 121 |
+
# 1) extract subset using permutation
|
| 122 |
+
b, n, d = x.shape
|
| 123 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 124 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 125 |
+
x_subset = x[brange]
|
| 126 |
+
|
| 127 |
+
# 2) apply residual_func to get residual
|
| 128 |
+
residual = residual_func(x_subset)
|
| 129 |
+
|
| 130 |
+
x_flat = x.flatten(1)
|
| 131 |
+
residual = residual.flatten(1)
|
| 132 |
+
|
| 133 |
+
residual_scale_factor = b / sample_subset_size
|
| 134 |
+
|
| 135 |
+
# 3) add the residual
|
| 136 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 137 |
+
return x_plus_residual.view_as(x)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
| 141 |
+
b, n, d = x.shape
|
| 142 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 143 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 144 |
+
residual_scale_factor = b / sample_subset_size
|
| 145 |
+
return brange, residual_scale_factor
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
| 149 |
+
if scaling_vector is None:
|
| 150 |
+
x_flat = x.flatten(1)
|
| 151 |
+
residual = residual.flatten(1)
|
| 152 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 153 |
+
else:
|
| 154 |
+
x_plus_residual = scaled_index_add(
|
| 155 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
| 156 |
+
)
|
| 157 |
+
return x_plus_residual
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
| 164 |
+
"""
|
| 165 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
| 166 |
+
"""
|
| 167 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
| 168 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
| 169 |
+
if all_shapes not in attn_bias_cache.keys():
|
| 170 |
+
seqlens = []
|
| 171 |
+
for b, x in zip(batch_sizes, x_list):
|
| 172 |
+
for _ in range(b):
|
| 173 |
+
seqlens.append(x.shape[1])
|
| 174 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
| 175 |
+
attn_bias._batch_sizes = batch_sizes
|
| 176 |
+
attn_bias_cache[all_shapes] = attn_bias
|
| 177 |
+
|
| 178 |
+
if branges is not None:
|
| 179 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
| 180 |
+
else:
|
| 181 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
| 182 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
| 183 |
+
|
| 184 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def drop_add_residual_stochastic_depth_list(
|
| 188 |
+
x_list: List[Tensor],
|
| 189 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
| 190 |
+
sample_drop_ratio: float = 0.0,
|
| 191 |
+
scaling_vector=None,
|
| 192 |
+
) -> Tensor:
|
| 193 |
+
# 1) generate random set of indices for dropping samples in the batch
|
| 194 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
| 195 |
+
branges = [s[0] for s in branges_scales]
|
| 196 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
| 197 |
+
|
| 198 |
+
# 2) get attention bias and index+concat the tensors
|
| 199 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
| 200 |
+
|
| 201 |
+
# 3) apply residual_func to get residual, and split the result
|
| 202 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
| 203 |
+
|
| 204 |
+
outputs = []
|
| 205 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
| 206 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
| 207 |
+
return outputs
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class NestedTensorBlock(Block):
|
| 211 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
| 212 |
+
"""
|
| 213 |
+
x_list contains a list of tensors to nest together and run
|
| 214 |
+
"""
|
| 215 |
+
assert isinstance(self.attn, MemEffAttention)
|
| 216 |
+
|
| 217 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 218 |
+
|
| 219 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 220 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
| 221 |
+
|
| 222 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 223 |
+
return self.mlp(self.norm2(x))
|
| 224 |
+
|
| 225 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 226 |
+
x_list,
|
| 227 |
+
residual_func=attn_residual_func,
|
| 228 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 229 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
| 230 |
+
)
|
| 231 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 232 |
+
x_list,
|
| 233 |
+
residual_func=ffn_residual_func,
|
| 234 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 235 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
| 236 |
+
)
|
| 237 |
+
return x_list
|
| 238 |
+
else:
|
| 239 |
+
|
| 240 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 241 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
| 242 |
+
|
| 243 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 244 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 245 |
+
|
| 246 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
| 247 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
| 248 |
+
x = x + ffn_residual_func(x)
|
| 249 |
+
return attn_bias.split(x)
|
| 250 |
+
|
| 251 |
+
def forward(self, x_or_x_list):
|
| 252 |
+
if isinstance(x_or_x_list, Tensor):
|
| 253 |
+
return super().forward(x_or_x_list)
|
| 254 |
+
elif isinstance(x_or_x_list, list):
|
| 255 |
+
if not XFORMERS_AVAILABLE:
|
| 256 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 257 |
+
return self.forward_nested(x_or_x_list)
|
| 258 |
+
else:
|
| 259 |
+
raise AssertionError
|
flow3r/models/dinov2/layers/dino_head.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.nn.init import trunc_normal_
|
| 9 |
+
from torch.nn.utils import weight_norm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DINOHead(nn.Module):
    """DINO projection head.

    An MLP bottleneck followed by L2 normalization and a weight-normalized
    final linear layer that produces prototype logits.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        # Weight-normalized output layer; its magnitude ("weight_g") is pinned to 1.
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        # Truncated-normal init for linear weights, zero biases.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        projected = self.mlp(x)
        # Larger eps in half precision to avoid underflow in the norm division.
        eps = 1e-6 if projected.dtype == torch.float16 else 1e-12
        normalized = nn.functional.normalize(projected, dim=-1, p=2, eps=eps)
        return self.last_layer(normalized)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    """Build the projection MLP for the DINO head.

    With ``nlayers == 1`` this is a single linear map; otherwise it is
    Linear -> [BN] -> GELU repeated, ending in a linear bottleneck layer.
    """
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)

    layers = []
    width = in_dim
    # (nlayers - 1) hidden stages, each Linear (+ optional BatchNorm) + GELU.
    for _ in range(nlayers - 1):
        layers.append(nn.Linear(width, hidden_dim, bias=bias))
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        width = hidden_dim
    layers.append(nn.Linear(width, bottleneck_dim, bias=bias))
    return nn.Sequential(*layers)
|
flow3r/models/dinov2/layers/drop_path.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Stochastic depth: randomly zero whole samples of ``x``.

    Survivors are rescaled by ``1 / keep_prob`` so the expected value is
    preserved. A no-op when ``drop_prob`` is zero or in eval mode.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every remaining dim.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability of dropping a whole sample; forwarded to drop_path().
        self.drop_prob = drop_prob

    def forward(self, x):
        # Module training flag decides whether dropping is active.
        return drop_path(x, self.drop_prob, self.training)
|
flow3r/models/dinov2/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 7 |
+
|
| 8 |
+
from typing import Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LayerScale(nn.Module):
    """Per-channel learnable scaling (LayerScale), optionally applied in place."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # gamma starts small so the scaled branch initially contributes little.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
flow3r/models/dinov2/layers/mlp.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional
|
| 12 |
+
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Mlp(nn.Module):
    """Transformer feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        # A single Dropout module reused after each linear layer.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
flow3r/models/dinov2/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
from typing import Callable, Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_2tuple(x):
    """Return ``x`` unchanged if it is already a 2-tuple; duplicate an int into a pair."""
    if isinstance(x, tuple):
        assert len(x) == 2
        return x
    # Only ints and 2-tuples are accepted.
    assert isinstance(x, int)
    return (x, x)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)

        self.img_size = image_HW
        self.patch_size = patch_HW
        # Number of patches along each image dimension at the nominal img_size.
        self.patches_resolution = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Non-overlapping conv = per-patch linear projection.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        tokens = self.proj(x)  # B C H W
        grid_h, grid_w = tokens.size(2), tokens.size(3)
        tokens = tokens.flatten(2).transpose(1, 2)  # B HW C
        tokens = self.norm(tokens)
        if not self.flatten_embedding:
            tokens = tokens.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H W C
        return tokens

    def flops(self) -> float:
        """FLOP estimate for the projection (plus norm) at the nominal resolution."""
        Ho, Wo = self.patches_resolution
        total = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += Ho * Wo * self.embed_dim
        return total
|
flow3r/models/dinov2/layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
from torch import Tensor, nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network.

    ``w12`` produces gate and value halves in one projection; the output is
    ``w3(silu(gate) * value)``. ``act_layer``/``drop`` are accepted for
    interface parity but unused.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Fused projection: first half gates, second half carries values.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Optional xFormers backend: enabled unless the XFORMERS_DISABLED env var is set.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (SwiGLU)")
    else:
        # warnings.warn("xFormers is disabled (SwiGLU)")
        # Disabled via env var: fall through to the pure-PyTorch implementation.
        raise ImportError
except ImportError:
    # Fallback: use the pure-PyTorch SwiGLUFFN under the same name.
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False

    # warnings.warn("xFormers is not available (SwiGLU)")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN whose hidden width is rescaled before construction.

    The requested hidden width is shrunk by 2/3 (keeping parameter count
    comparable to a standard 4x MLP) and rounded up to a multiple of 8.
    The base class is xFormers' fused SwiGLU when available, otherwise the
    pure-PyTorch SwiGLUFFN. ``act_layer``/``drop`` are accepted for
    interface parity but not forwarded.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # 2/3 shrink, then round up to the next multiple of 8 for hardware-friendly sizes.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
|
flow3r/models/dinov2/models/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
from . import vision_transformer as vits
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger("dinov2")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_model(args, only_teacher=False, img_size=224):
    """Instantiate DINOv2 ViT backbone(s) from an argument namespace.

    Strips an optional ``_memeff`` suffix from ``args.arch`` (mutating the
    namespace), then builds the teacher — and, unless ``only_teacher``, a
    student with stochastic depth. Returns ``(teacher, embed_dim)`` or
    ``(student, teacher, embed_dim)``. Non-ViT arch names fall through and
    return None, matching the original behavior.
    """
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" not in args.arch:
        return None
    shared_kwargs = dict(
        img_size=img_size,
        patch_size=args.patch_size,
        init_values=args.layerscale,
        ffn_layer=args.ffn_layer,
        block_chunks=args.block_chunks,
        qkv_bias=args.qkv_bias,
        proj_bias=args.proj_bias,
        ffn_bias=args.ffn_bias,
        num_register_tokens=args.num_register_tokens,
        interpolate_offset=args.interpolate_offset,
        interpolate_antialias=args.interpolate_antialias,
    )
    arch_ctor = vits.__dict__[args.arch]
    teacher = arch_ctor(**shared_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim
    # Student additionally uses stochastic depth during training.
    student = arch_ctor(
        **shared_kwargs,
        drop_path_rate=args.drop_path_rate,
        drop_path_uniform=args.drop_path_uniform,
    )
    return student, teacher, student.embed_dim
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def build_model_from_cfg(cfg, only_teacher=False):
    # Convenience wrapper: pull the student arch settings and the global crop
    # size out of a structured config and delegate to build_model().
    return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
|
flow3r/models/dinov2/models/vision_transformer.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
from functools import partial
|
| 11 |
+
import math
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Sequence, Tuple, Union, Callable
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.utils.checkpoint import checkpoint
|
| 18 |
+
from torch.nn.init import trunc_normal_
|
| 19 |
+
|
| 20 |
+
from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
| 21 |
+
from ...layers.attention import FlashAttention
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# logger = logging.getLogger("dinov2")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over a module tree.

    ``name`` is the dotted path of each submodule. ``depth_first`` selects
    post-order (children first) vs pre-order; the root itself is visited
    only when ``include_root`` is True. Returns ``module`` for chaining.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        qualified = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=qualified, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BlockChunk(nn.ModuleList):
    """A ModuleList that runs its members sequentially when called."""

    def forward(self, x):
        for member in self:
            x = member(x)
        return x
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class DinoVisionTransformer(nn.Module):
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
img_size=224,
|
| 49 |
+
patch_size=16,
|
| 50 |
+
in_chans=3,
|
| 51 |
+
embed_dim=768,
|
| 52 |
+
depth=12,
|
| 53 |
+
num_heads=12,
|
| 54 |
+
mlp_ratio=4.0,
|
| 55 |
+
qkv_bias=True,
|
| 56 |
+
ffn_bias=True,
|
| 57 |
+
proj_bias=True,
|
| 58 |
+
drop_path_rate=0.0,
|
| 59 |
+
drop_path_uniform=False,
|
| 60 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
| 61 |
+
embed_layer=PatchEmbed,
|
| 62 |
+
act_layer=nn.GELU,
|
| 63 |
+
block_fn=Block,
|
| 64 |
+
ffn_layer="mlp",
|
| 65 |
+
block_chunks=1,
|
| 66 |
+
num_register_tokens=0,
|
| 67 |
+
interpolate_antialias=False,
|
| 68 |
+
interpolate_offset=0.1,
|
| 69 |
+
):
|
| 70 |
+
"""
|
| 71 |
+
Args:
|
| 72 |
+
img_size (int, tuple): input image size
|
| 73 |
+
patch_size (int, tuple): patch size
|
| 74 |
+
in_chans (int): number of input channels
|
| 75 |
+
embed_dim (int): embedding dimension
|
| 76 |
+
depth (int): depth of transformer
|
| 77 |
+
num_heads (int): number of attention heads
|
| 78 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 79 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 80 |
+
proj_bias (bool): enable bias for proj in attn if True
|
| 81 |
+
ffn_bias (bool): enable bias for ffn if True
|
| 82 |
+
drop_path_rate (float): stochastic depth rate
|
| 83 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
| 84 |
+
weight_init (str): weight init scheme
|
| 85 |
+
init_values (float): layer-scale init values
|
| 86 |
+
embed_layer (nn.Module): patch embedding layer
|
| 87 |
+
act_layer (nn.Module): MLP activation layer
|
| 88 |
+
block_fn (nn.Module): transformer block class
|
| 89 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
| 90 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
| 91 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
| 92 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
| 93 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
| 94 |
+
"""
|
| 95 |
+
super().__init__()
|
| 96 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
| 97 |
+
|
| 98 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 99 |
+
self.num_tokens = 1
|
| 100 |
+
self.n_blocks = depth
|
| 101 |
+
self.num_heads = num_heads
|
| 102 |
+
self.patch_size = patch_size
|
| 103 |
+
self.num_register_tokens = num_register_tokens
|
| 104 |
+
self.interpolate_antialias = interpolate_antialias
|
| 105 |
+
self.interpolate_offset = interpolate_offset
|
| 106 |
+
|
| 107 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 108 |
+
num_patches = self.patch_embed.num_patches
|
| 109 |
+
|
| 110 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 111 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
| 112 |
+
assert num_register_tokens >= 0
|
| 113 |
+
self.register_tokens = (
|
| 114 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if drop_path_uniform is True:
|
| 118 |
+
dpr = [drop_path_rate] * depth
|
| 119 |
+
else:
|
| 120 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 121 |
+
|
| 122 |
+
if ffn_layer == "mlp":
|
| 123 |
+
# logger.info("using MLP layer as FFN")
|
| 124 |
+
ffn_layer = Mlp
|
| 125 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
| 126 |
+
# logger.info("using SwiGLU layer as FFN")
|
| 127 |
+
ffn_layer = SwiGLUFFNFused
|
| 128 |
+
elif ffn_layer == "identity":
|
| 129 |
+
# logger.info("using Identity layer as FFN")
|
| 130 |
+
|
| 131 |
+
def f(*args, **kwargs):
|
| 132 |
+
return nn.Identity()
|
| 133 |
+
|
| 134 |
+
ffn_layer = f
|
| 135 |
+
else:
|
| 136 |
+
raise NotImplementedError
|
| 137 |
+
|
| 138 |
+
blocks_list = [
|
| 139 |
+
block_fn(
|
| 140 |
+
dim=embed_dim,
|
| 141 |
+
num_heads=num_heads,
|
| 142 |
+
mlp_ratio=mlp_ratio,
|
| 143 |
+
qkv_bias=qkv_bias,
|
| 144 |
+
proj_bias=proj_bias,
|
| 145 |
+
ffn_bias=ffn_bias,
|
| 146 |
+
drop_path=dpr[i],
|
| 147 |
+
norm_layer=norm_layer,
|
| 148 |
+
act_layer=act_layer,
|
| 149 |
+
ffn_layer=ffn_layer,
|
| 150 |
+
init_values=init_values,
|
| 151 |
+
attn_class=FlashAttention
|
| 152 |
+
)
|
| 153 |
+
for i in range(depth)
|
| 154 |
+
]
|
| 155 |
+
if block_chunks > 0:
|
| 156 |
+
self.chunked_blocks = True
|
| 157 |
+
chunked_blocks = []
|
| 158 |
+
chunksize = depth // block_chunks
|
| 159 |
+
for i in range(0, depth, chunksize):
|
| 160 |
+
# this is to keep the block index consistent if we chunk the block list
|
| 161 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
| 162 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
| 163 |
+
else:
|
| 164 |
+
self.chunked_blocks = False
|
| 165 |
+
self.blocks = nn.ModuleList(blocks_list)
|
| 166 |
+
|
| 167 |
+
self.norm = norm_layer(embed_dim)
|
| 168 |
+
self.head = nn.Identity()
|
| 169 |
+
|
| 170 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
| 171 |
+
|
| 172 |
+
self.init_weights()
|
| 173 |
+
|
| 174 |
+
def init_weights(self):
|
| 175 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
| 176 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
| 177 |
+
if self.register_tokens is not None:
|
| 178 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
| 179 |
+
named_apply(init_weights_vit_timm, self)
|
| 180 |
+
|
| 181 |
+
def interpolate_pos_encoding(self, x, w, h):
|
| 182 |
+
previous_dtype = x.dtype
|
| 183 |
+
npatch = x.shape[1] - 1
|
| 184 |
+
N = self.pos_embed.shape[1] - 1
|
| 185 |
+
if npatch == N and w == h:
|
| 186 |
+
return self.pos_embed
|
| 187 |
+
pos_embed = self.pos_embed.float()
|
| 188 |
+
class_pos_embed = pos_embed[:, 0]
|
| 189 |
+
patch_pos_embed = pos_embed[:, 1:]
|
| 190 |
+
dim = x.shape[-1]
|
| 191 |
+
w0 = w // self.patch_size
|
| 192 |
+
h0 = h // self.patch_size
|
| 193 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
| 194 |
+
assert N == M * M
|
| 195 |
+
kwargs = {}
|
| 196 |
+
if self.interpolate_offset:
|
| 197 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
| 198 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
| 199 |
+
sx = float(w0 + self.interpolate_offset) / M
|
| 200 |
+
sy = float(h0 + self.interpolate_offset) / M
|
| 201 |
+
kwargs["scale_factor"] = (sx, sy)
|
| 202 |
+
else:
|
| 203 |
+
# Simply specify an output size instead of a scale factor
|
| 204 |
+
kwargs["size"] = (w0, h0)
|
| 205 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 206 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
| 207 |
+
mode="bicubic",
|
| 208 |
+
antialias=self.interpolate_antialias,
|
| 209 |
+
**kwargs,
|
| 210 |
+
)
|
| 211 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
| 212 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 213 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
| 214 |
+
|
| 215 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
| 216 |
+
B, nc, w, h = x.shape
|
| 217 |
+
x = self.patch_embed(x)
|
| 218 |
+
if masks is not None:
|
| 219 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
| 220 |
+
|
| 221 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
| 222 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
| 223 |
+
|
| 224 |
+
if self.register_tokens is not None:
|
| 225 |
+
x = torch.cat(
|
| 226 |
+
(
|
| 227 |
+
x[:, :1],
|
| 228 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
| 229 |
+
x[:, 1:],
|
| 230 |
+
),
|
| 231 |
+
dim=1,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
return x
|
| 235 |
+
|
| 236 |
+
def forward_features_list(self, x_list, masks_list):
|
| 237 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
| 238 |
+
for blk in self.blocks:
|
| 239 |
+
if self.training:
|
| 240 |
+
x = checkpoint(blk, x, use_reentrant=False)
|
| 241 |
+
else:
|
| 242 |
+
x = blk(x)
|
| 243 |
+
|
| 244 |
+
all_x = x
|
| 245 |
+
output = []
|
| 246 |
+
for x, masks in zip(all_x, masks_list):
|
| 247 |
+
x_norm = self.norm(x)
|
| 248 |
+
output.append(
|
| 249 |
+
{
|
| 250 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 251 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 252 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 253 |
+
"x_prenorm": x,
|
| 254 |
+
"masks": masks,
|
| 255 |
+
}
|
| 256 |
+
)
|
| 257 |
+
return output
|
| 258 |
+
|
| 259 |
+
def forward_features(self, x, masks=None):
|
| 260 |
+
if isinstance(x, list):
|
| 261 |
+
return self.forward_features_list(x, masks)
|
| 262 |
+
|
| 263 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
| 264 |
+
|
| 265 |
+
for blk in self.blocks:
|
| 266 |
+
if self.training:
|
| 267 |
+
x = checkpoint(blk, x, use_reentrant=False)
|
| 268 |
+
else:
|
| 269 |
+
x = blk(x)
|
| 270 |
+
|
| 271 |
+
x_norm = self.norm(x)
|
| 272 |
+
return {
|
| 273 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 274 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 275 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 276 |
+
"x_prenorm": x,
|
| 277 |
+
"masks": masks,
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
| 281 |
+
x = self.prepare_tokens_with_masks(x)
|
| 282 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 283 |
+
output, total_block_len = [], len(self.blocks)
|
| 284 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 285 |
+
for i, blk in enumerate(self.blocks):
|
| 286 |
+
x = blk(x)
|
| 287 |
+
if i in blocks_to_take:
|
| 288 |
+
output.append(x)
|
| 289 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 290 |
+
return output
|
| 291 |
+
|
| 292 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
| 293 |
+
x = self.prepare_tokens_with_masks(x)
|
| 294 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
| 295 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 296 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 297 |
+
for block_chunk in self.blocks:
|
| 298 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
| 299 |
+
x = blk(x)
|
| 300 |
+
if i in blocks_to_take:
|
| 301 |
+
output.append(x)
|
| 302 |
+
i += 1
|
| 303 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 304 |
+
return output
|
| 305 |
+
|
| 306 |
+
def get_intermediate_layers(
    self,
    x: torch.Tensor,
    n: Union[int, Sequence] = 1,  # Layers or n last layers to take
    reshape: bool = False,
    return_class_token: bool = False,
    norm=True,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
    """Return intermediate patch-token features (optionally with class tokens).

    When `reshape` is True, patch tokens are returned as (B, C, H', W') maps.
    """
    getter = (
        self._get_intermediate_layers_chunked
        if self.chunked_blocks
        else self._get_intermediate_layers_not_chunked
    )
    outputs = getter(x, n)
    if norm:
        outputs = [self.norm(out) for out in outputs]
    class_tokens = [out[:, 0] for out in outputs]
    # Drop the cls token and the register tokens; keep only patch tokens.
    outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
    if reshape:
        B, _, w, h = x.shape
        grid_w = w // self.patch_size
        grid_h = h // self.patch_size
        outputs = [
            out.reshape(B, grid_w, grid_h, -1).permute(0, 3, 1, 2).contiguous()
            for out in outputs
        ]
    if return_class_token:
        return tuple(zip(outputs, class_tokens))
    return tuple(outputs)
|
| 331 |
+
|
| 332 |
+
def forward(self, *args, is_training=False, **kwargs):
    """Return the raw feature dict when training, else the head output on the cls token."""
    features = self.forward_features(*args, **kwargs)
    if not is_training:
        return self.head(features["x_norm_clstoken"])
    return features
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-S factory: 384-dim embeddings, 12 blocks, 6 heads."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-B factory: 768-dim embeddings, 12 blocks, 12 heads."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-L factory: 1024-dim embeddings, 24 blocks, 16 heads."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
flow3r/models/dinov2/utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
flow3r/models/dinov2/utils/cluster.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ClusterType(Enum):
    """Known compute-cluster environments; values are used as short names."""
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _guess_cluster_type() -> ClusterType:
    """Infer the cluster from the kernel release / hostname; default to FAIR."""
    info = os.uname()
    if info.sysname == "Linux":
        # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
        if info.release.endswith("-aws"):
            return ClusterType.AWS
        # RSC instances run standard kernels but hostnames start with "rsc"
        if info.nodename.startswith("rsc"):
            return ClusterType.RSC
    return ClusterType.FAIR
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return `cluster_type` unchanged, guessing the current cluster when None."""
    return _guess_cluster_type() if cluster_type is None else cluster_type
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Root checkpoint directory for the (possibly guessed) cluster, or None."""
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    dirname_by_cluster = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/") / dirname_by_cluster[cluster_type]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Per-user checkpoint directory ($USER under the cluster root), or None."""
    root = get_checkpoint_path(cluster_type)
    if root is None:
        return None

    username = os.environ.get("USER")
    assert username is not None
    return root / username
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Default SLURM partition for the (possibly guessed) cluster, or None."""
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    partition_by_cluster = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return partition_by_cluster[cluster_type]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Build SLURM executor parameters for a multi-node job.

    Extra keyword arguments override the computed defaults.
    """
    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    # Cluster-specific adjustments.
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type == ClusterType.AWS:
        params["cpus_per_task"] = 12
        params.pop("mem_gb")
    elif cluster_type == ClusterType.RSC:
        params["cpus_per_task"] = 12
    # Caller-supplied overrides win last.
    params.update(kwargs)
    return params
|
flow3r/models/dinov2/utils/config.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
|
| 12 |
+
import dinov2.distributed as distributed
|
| 13 |
+
from dinov2.logging import setup_logging
|
| 14 |
+
from dinov2.utils import utils
|
| 15 |
+
from dinov2.configs import dinov2_default_config
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("dinov2")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def apply_scaling_rules_to_cfg(cfg):  # to fix
    """Scale cfg.optim.lr from cfg.optim.base_lr according to the configured rule."""
    if cfg.optim.scaling_rule != "sqrt_wrt_1024":
        raise NotImplementedError
    base_lr = cfg.optim.base_lr
    global_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size()
    cfg.optim.lr = base_lr * math.sqrt(global_batch_size / 1024.0)
    logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    return cfg
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def write_config(cfg, output_dir, name="config.yaml"):
    """Log the config and save it as YAML under `output_dir`; return the file path."""
    logger.info(OmegaConf.to_yaml(cfg))
    target_path = os.path.join(output_dir, name)
    with open(target_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return target_path
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_cfg_from_args(args):
    """Merge the default config, the config file, and CLI overrides (in that order)."""
    args.output_dir = os.path.abspath(args.output_dir)
    args.opts += [f"train.output_dir={args.output_dir}"]
    base_cfg = OmegaConf.create(dinov2_default_config)
    file_cfg = OmegaConf.load(args.config_file)
    return OmegaConf.merge(base_cfg, file_cfg, OmegaConf.from_cli(args.opts))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def default_setup(args):
    """Initialize distributed mode, logging, and RNG seeds for a run.

    Rebinds the module-level ``logger`` after logging is configured so later
    records go through the freshly installed handlers.
    """
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    # Per-rank offset (seed + rank) gives each process a distinct RNG stream.
    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n {}\n".format(utils.get_sha()))
    # Dump all CLI/args key-value pairs, sorted, one per line.
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def setup(args):
    """Build the run config and perform basic process setup (dirs, logging, LR scaling)."""
    config = get_cfg_from_args(args)
    # The output dir must exist before default_setup configures file logging into it.
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(config)
    write_config(config, args.output_dir)
    return config
|
flow3r/models/dinov2/utils/dtype.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from typing import Dict, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
TypeSpec = Union[str, np.dtype, torch.dtype]


# Mapping from NumPy dtypes to their torch equivalents; consumed by as_torch_dtype.
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Coerce a string / numpy dtype / torch dtype spec into a torch.dtype.

    Args:
        dtype: a torch.dtype (returned as-is), a numpy dtype, or any string
            that ``np.dtype`` accepts (e.g. "float32").

    Raises:
        AssertionError: if the input is not a recognized dtype spec.
        KeyError: if the numpy dtype has no torch equivalent.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    # Fixed typo in the error message ("nunpy" -> "numpy").
    assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
    return _NUMPY_TO_TORCH_DTYPE[dtype]
|
flow3r/models/dinov2/utils/param_groups.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger("dinov2")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
    Returns:
        lr decay rate for the given parameter.
    """
    embed_keys = ("pos_embed", "patch_embed", "mask_token", "cls_token", "register_tokens")
    # Default: treat as the head (no decay applied).
    layer_id = num_layers + 1
    if name.startswith("backbone") or force_is_backbone:
        if any("." + key in name for key in embed_keys):
            layer_id = 0
        elif force_is_backbone and any(key in name for key in embed_keys):
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
        elif chunked_blocks and "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
        elif "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    """Build one parameter group per trainable parameter, with layer-wise LR decay.

    Each group carries lr/wd multipliers plus an `is_last_layer` flag for the
    optimizer schedule.
    """
    chunked_blocks = False
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0

    all_param_groups = []
    for name, param in model.named_parameters():
        name = name.replace("_fsdp_wrapped_module.", "")
        if not param.requires_grad:
            continue
        decay_rate = get_vit_lr_decay_rate(
            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
        )
        # Biases and norm/gamma parameters get no weight decay.
        no_wd = name.endswith(".bias") or "norm" in name or "gamma" in name
        lr_mult = decay_rate * patch_embed_lr_mult if "patch_embed" in name else decay_rate
        group = {
            "params": param,
            "is_last_layer": "last_layer" in name,
            "lr_multiplier": lr_mult,
            "wd_multiplier": 0.0 if no_wd else 1.0,
            "name": name,
        }
        all_param_groups.append(group)
        logger.info(f"""{name}: lr_multiplier: {group["lr_multiplier"]}, wd_multiplier: {group["wd_multiplier"]}""")

    return all_param_groups
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    """Merge per-parameter groups that share identical multiplier settings."""
    fused = defaultdict(lambda: {"params": []})
    for group in all_params_groups:
        # Identifier is the concatenation of key/value pairs, e.g. "lr_multiplier1.0_...".
        identifier = "".join(k + str(group[k]) + "_" for k in keys)
        entry = fused[identifier]
        for k in keys:
            entry[k] = group[k]
        entry["params"].append(group["params"])
    return fused.values()
|
flow3r/models/dinov2/utils/utils.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import subprocess
|
| 10 |
+
from urllib.parse import urlparse
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# logger = logging.getLogger("dinov2")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """Load a checkpoint (local path or URL) into `model`, non-strictly.

    Args:
        model: module to load the weights into.
        pretrained_weights: filesystem path or URL of the checkpoint.
        checkpoint_key: optional top-level key (e.g. "teacher") selecting a
            sub-dict of the checkpoint; ignored when absent from the dict.
    """
    if urlparse(pretrained_weights).scheme:  # If it looks like an URL
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        # NOTE(review): torch.load unpickles arbitrary objects — only pass trusted checkpoint files.
        state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        # logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    # remove `module.` prefix
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # remove `backbone.` prefix induced by multicrop wrapper
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    msg = model.load_state_dict(state_dict, strict=False)
    # logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def fix_random_seeds(seed=31):
    """
    Seed torch (CPU and all CUDA devices), numpy, and the stdlib RNG.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines
    np.random.seed(seed)
    random.seed(seed)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_sha():
    """Return a one-line summary of the repo state: commit sha, dirty/clean status, branch."""
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _git(command):
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha, diff, branch = "N/A", "clean", "N/A"
    try:
        sha = _git(["git", "rev-parse", "HEAD"])
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = "has uncommitted changes" if _git(["git", "diff-index", "HEAD"]) else "clean"
        branch = _git(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        # Not a git checkout (or git unavailable): keep the "N/A" defaults.
        pass
    return f"sha: {sha}, status: {diff}, branch: {branch}"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class CosineScheduler(object):
    """Piecewise schedule: optional freeze at 0, linear warmup, then cosine decay.

    Indexing at or past `total_iters` yields `final_value`.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        frozen_part = np.zeros((freeze_iters))
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
        steps = np.arange(total_iters - warmup_iters - freeze_iters)
        cosine_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * steps / len(steps)))
        self.schedule = np.concatenate((frozen_part, warmup_part, cosine_part))

        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        return self.final_value if it >= self.total_iters else self.schedule[it]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def has_batchnorms(model):
    """Return True if any submodule of `model` is a (sync) batch-norm layer."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for _, module in model.named_modules():
        if isinstance(module, bn_types):
            return True
    return False
|
flow3r/models/flow3r.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from functools import partial
|
| 4 |
+
from copy import deepcopy
|
| 5 |
+
|
| 6 |
+
from .dinov2.layers import Mlp
|
| 7 |
+
from ..utils.geometry import homogenize_points
|
| 8 |
+
from .layers.pos_embed import RoPE2D, PositionGetter
|
| 9 |
+
from .layers.block import BlockRope
|
| 10 |
+
from .layers.attention import FlashAttentionRope
|
| 11 |
+
from .layers.transformer_head import TransformerDecoder, LinearPts3d, ContextTransformerDecoder
|
| 12 |
+
from .layers.camera_head import CameraHead
|
| 13 |
+
from .flow_head.dpt_head import DPTHead
|
| 14 |
+
from .dinov2.hub.backbones import dinov2_vitl14_reg
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Flow3r(nn.Module):
|
| 18 |
+
def __init__(
    self,
    pos_type='rope100',
    decoder_size='large',
):
    """Flow3r: DINOv2 ViT-L/14 encoder + transformer decoder with point,
    confidence, camera-pose and optical-flow heads.

    Args:
        pos_type: positional-encoding spec; only 'rope<freq>' strings
            (e.g. 'rope100') are supported — anything else raises.
        decoder_size: 'small' | 'base' | 'large'; selects decoder width/depth.
    """
    super().__init__()


    # ----------------------
    # Encoder
    # ----------------------
    # DINOv2 ViT-L/14 with register tokens; pretrained weights are loaded elsewhere.
    self.encoder = dinov2_vitl14_reg(pretrained=False)
    self.patch_size = 14
    # The encoder's mask token is unused here; remove it from the module.
    del self.encoder.mask_token

    # ----------------------
    # Positional Encoding
    # ----------------------
    self.pos_type = pos_type if pos_type is not None else 'none'
    self.rope=None
    if self.pos_type.startswith('rope'): # eg rope100
        if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
        # RoPE base frequency is encoded in the suffix: 'rope100' -> 100.0.
        freq = float(self.pos_type[len('rope'):])
        self.rope = RoPE2D(freq=freq)
        self.position_getter = PositionGetter()
    else:
        raise NotImplementedError


    # ----------------------
    # Decoder
    # ----------------------
    if decoder_size == 'small':
        dec_embed_dim = 384
        dec_num_heads = 6
        mlp_ratio = 4
        dec_depth = 24
    elif decoder_size == 'base':
        dec_embed_dim = 768
        dec_num_heads = 12
        mlp_ratio = 4
        dec_depth = 24
    elif decoder_size == 'large':
        dec_embed_dim = 1024
        dec_num_heads = 16
        mlp_ratio = 4
        dec_depth = 36
    else:
        raise NotImplementedError
    self.decoder = nn.ModuleList([
        BlockRope(
            dim=dec_embed_dim,
            num_heads=dec_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=True,
            proj_bias=True,
            ffn_bias=True,
            drop_path=0.0,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            ffn_layer=Mlp,
            init_values=0.01,
            qk_norm=True,
            attn_class=FlashAttentionRope,
            rope=self.rope
        ) for _ in range(dec_depth)])
    self.dec_embed_dim = dec_embed_dim

    # ----------------------
    # Register_token
    # ----------------------
    # Learnable special tokens prepended to every view's patch tokens in decode().
    num_register_tokens = 5
    self.patch_start_idx = num_register_tokens
    self.register_token = nn.Parameter(torch.randn(1, 1, num_register_tokens, self.dec_embed_dim))
    nn.init.normal_(self.register_token, std=1e-6)

    # ----------------------
    # Local Points Decoder
    # ----------------------
    # Input dim is doubled because decode() concatenates the last two block outputs.
    self.point_decoder = TransformerDecoder(
        in_dim=2*self.dec_embed_dim,
        dec_embed_dim=1024,
        dec_num_heads=16,
        out_dim=1024,
        rope=self.rope
    )
    self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3)

    # ----------------------
    # Camera Pose Decoder
    # ----------------------
    self.camera_decoder = TransformerDecoder(
        in_dim=2*self.dec_embed_dim,
        dec_embed_dim=1024,
        dec_num_heads=16, # 8
        out_dim=512,
        rope=self.rope,
        use_checkpoint=False
    )
    self.camera_head = CameraHead(dim=512)

    # ----------------------
    # Motion Flow Decoder
    # ----------------------
    # Predicts 2-channel (u, v) flow via a DPT-style head.
    self.flow_head = DPTHead(
        patch_size=14,
        output_dim=2,
    )

    # ----------------------
    # Conf Decoder
    # ----------------------
    # Confidence branch mirrors the point decoder architecture (deepcopy => independent weights).
    self.conf_decoder = deepcopy(self.point_decoder)
    self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1)

    # For ImageNet Normalize
    image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    self.register_buffer("image_mean", image_mean)
    self.register_buffer("image_std", image_std)
|
| 138 |
+
self.register_buffer("image_std", image_std)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def decode(self, hidden, N, H, W):
    """Run the decoder, alternating frame-wise and global attention per block.

    Args:
        hidden: encoder patch tokens; assumed shape (B*N, hw, C) flattened
            over the N views — TODO confirm with caller.
        N: number of views per sample.
        H, W: input image height/width in pixels.

    Returns:
        (features, pos): features concatenate the outputs of the last two
        decoder blocks along the channel dim, shape (B*N, hw', 2*C); pos are
        the 2-D RoPE positions, shape (B*N, hw', 2).
    """
    BN, hw, _ = hidden.shape
    B = BN // N

    final_output = []

    hidden = hidden.reshape(B*N, hw, -1)

    # Expand the learnable register tokens to every (sample, view) pair.
    register_token = self.register_token.repeat(B, N, 1, 1).reshape(B*N, *self.register_token.shape[-2:])

    # Concatenate special tokens with patch tokens
    hidden = torch.cat([register_token, hidden], dim=1)
    hw = hidden.shape[1]

    if self.pos_type.startswith('rope'):
        pos = self.position_getter(B * N, H//self.patch_size, W//self.patch_size, hidden.device)

    if self.patch_start_idx > 0:
        # do not use position embedding for special tokens (camera and register tokens)
        # so set pos to 0 for the special tokens
        pos = pos + 1
        pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype)
        pos = torch.cat([pos_special, pos], dim=1)

    for i in range(len(self.decoder)):
        blk = self.decoder[i]

        # Even blocks attend within each view (B*N, hw); odd blocks attend
        # across all views jointly (B, N*hw).
        if i % 2 == 0:
            pos = pos.reshape(B*N, hw, -1)
            hidden = hidden.reshape(B*N, hw, -1)
        else:
            pos = pos.reshape(B, N*hw, -1)
            hidden = hidden.reshape(B, N*hw, -1)

        hidden = blk(hidden, xpos=pos)

        # Keep the outputs of the last two blocks for the downstream heads.
        if i+1 in [len(self.decoder)-1, len(self.decoder)]:
            final_output.append(hidden.reshape(B*N, hw, -1))

    return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(B*N, hw, -1)
|
| 181 |
+
|
| 182 |
+
def forward(self, imgs, pair_indices=None):
|
| 183 |
+
imgs = (imgs - self.image_mean) / self.image_std
|
| 184 |
+
# print("the shape of imgs is", imgs.shape)
|
| 185 |
+
|
| 186 |
+
B, N, _, H, W = imgs.shape
|
| 187 |
+
patch_h, patch_w = H // 14, W // 14
|
| 188 |
+
|
| 189 |
+
# encode by dinov2
|
| 190 |
+
imgs = imgs.reshape(B*N, _, H, W)
|
| 191 |
+
hidden = self.encoder(imgs, is_training=True)
|
| 192 |
+
|
| 193 |
+
if isinstance(hidden, dict):
|
| 194 |
+
hidden = hidden["x_norm_patchtokens"]
|
| 195 |
+
|
| 196 |
+
hidden, pos = self.decode(hidden, N, H, W)
|
| 197 |
+
|
| 198 |
+
point_hidden, point_intermediate = self.point_decoder(hidden, xpos=pos, return_intermediate=True)
|
| 199 |
+
conf_hidden = self.conf_decoder(hidden, xpos=pos)
|
| 200 |
+
camera_hidden, camera_intermediate = self.camera_decoder(hidden, xpos=pos, return_intermediate=True)
|
| 201 |
+
|
| 202 |
+
with torch.amp.autocast(device_type='cuda', enabled=False):
|
| 203 |
+
# local points
|
| 204 |
+
point_hidden = point_hidden.float()
|
| 205 |
+
ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
|
| 206 |
+
xy, z = ret.split([2, 1], dim=-1)
|
| 207 |
+
z = torch.exp(z)
|
| 208 |
+
local_points = torch.cat([xy * z, z], dim=-1)
|
| 209 |
+
|
| 210 |
+
# confidence
|
| 211 |
+
conf_hidden = conf_hidden.float()
|
| 212 |
+
conf = self.conf_head([conf_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
|
| 213 |
+
|
| 214 |
+
# camera
|
| 215 |
+
camera_hidden = camera_hidden.float()
|
| 216 |
+
camera_poses = self.camera_head(camera_hidden[:, self.patch_start_idx:], patch_h, patch_w).reshape(B, N, 4, 4)
|
| 217 |
+
|
| 218 |
+
# Flow
|
| 219 |
+
if pair_indices is not None:
|
| 220 |
+
flow = self.flow_head([t.float() for t in point_intermediate], [t.float() for t in camera_intermediate], pair_indices, self.patch_start_idx,(H, W), B, N)
|
| 221 |
+
else:
|
| 222 |
+
flow = None
|
| 223 |
+
|
| 224 |
+
# unproject local points using camera poses
|
| 225 |
+
points = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses, homogenize_points(local_points))[..., :3]
|
| 226 |
+
|
| 227 |
+
return dict(
|
| 228 |
+
points=points,
|
| 229 |
+
local_points=local_points,
|
| 230 |
+
conf=conf,
|
| 231 |
+
camera_poses=camera_poses,
|
| 232 |
+
flow=flow,
|
| 233 |
+
)
|
flow3r/models/flow_head/dpt_head.py
ADDED
|
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
from typing import List, Dict, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
from .utils import create_uv_grid, position_grid_to_embed
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DPTHead(nn.Module):
    """
    DPT head for dense flow prediction.

    Follows "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413): per-pair fused tokens from four
    decoder layers are projected to multi-scale feature maps, fused with
    refinenet blocks, and decoded into a dense per-pixel output.

    Args:
        dim_in (int): Input token dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the final prediction layers. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int = 1024,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super(DPTHead, self).__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx
        self.dim_in = dim_in
        self.output_dim = output_dim

        # Fuses [cam_i, cam_j, patch_i] (3 * dim_in) back down to dim_in.
        self.mlp = nn.Sequential(
            nn.Linear(3 * dim_in, 2 * dim_in),
            nn.ReLU(),
            nn.Linear(2 * dim_in, 2 * dim_in),
            nn.ReLU(),
            nn.Linear(2 * dim_in, dim_in),
        )

        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers: 4x up, 2x up, identity, 2x down — the DPT pyramid.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(
            out_channels,
            features,
            expand=False,
        )

        # Attach fusion modules to scratch; refinenet4 (coarsest) has no
        # lateral residual input.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
            )
            conv2_in_channels = head_features_1 // 2

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
            )

    def forward(
        self,
        patch_intermediate_4: List[torch.Tensor],  # len=4, each (B*N, hw, dim_in)
        camera_intermediate_4: List[torch.Tensor],  # len=4, each (B*N, hw, dim_in)
        pair_indices: torch.Tensor,  # (B, S, 2)
        patch_start_idx: int,
        img_shape: Tuple[int, int],
        B: int,
        N: int,
    ) -> torch.Tensor:
        """
        Predict per-pair dense flow from four layers of decoder tokens.

        Args:
            patch_intermediate_4: four point-decoder token tensors, each (B*N, hw, dim_in).
            camera_intermediate_4: four camera-decoder token tensors, each (B*N, hw, dim_in).
            pair_indices: (B, S, 2) frame-index pairs.
            patch_start_idx: starting index of patch tokens in the token
                sequence; earlier tokens (e.g. register tokens) are dropped.
            img_shape: (H, W) of the input images.
            B: batch size.
            N: number of frames per sample.

        Returns:
            Tensor of shape (B, S, H, W, output_dim): flow prediction for
            each requested frame pair.
        """
        feats_4 = [
            self._fuse_one_layer(
                patch_intermediate_4[layer],
                camera_intermediate_4[layer],
                patch_start_idx,
                pair_indices,
                img_shape,
                B,
                N,
            )
            for layer in range(4)
        ]

        flow = self._dpt_fuse_and_predict(feats_4, img_shape)  # (B*S, output_dim, H, W)

        H, W = img_shape
        S = pair_indices.shape[1]
        return flow.permute(0, 2, 3, 1).reshape(B, S, H, W, self.output_dim)

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Add a down-scaled sinusoidal positional embedding to feature map x.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio  # keep the embedding small relative to features
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def _fuse_one_layer(
        self,
        patch_hidden_l: torch.Tensor,  # (B*N, hw_total, dim_in)
        camera_hidden_l: torch.Tensor,  # (B*N, hw_total, dim_in)
        patch_start_idx: int,
        pair_indices: torch.Tensor,  # (B, S, 2)
        img_shape: Tuple[int, int],
        B: int,
        N: int,
    ) -> torch.Tensor:
        """
        Build one per-pair feature map by fusing the source frame's patch
        tokens with the camera tokens of both frames of the pair.

        Returns:
            feat_map: (B*S, dim_in, patch_h, patch_w)
        """
        H, W = img_shape
        hw = patch_hidden_l[:, patch_start_idx:].shape[1]
        patch_h, patch_w = H // self.patch_size, W // self.patch_size
        assert hw == patch_h * patch_w, (hw, patch_h, patch_w)

        # Drop special tokens and reshape to (B, N, hw, C).
        patch_hidden_l = patch_hidden_l[:, patch_start_idx:].reshape(B, N, hw, self.dim_in)
        camera_hidden_l = camera_hidden_l[:, patch_start_idx:].reshape(B, N, hw, self.dim_in)

        S = pair_indices.shape[1]
        batch_idx = torch.arange(B, device=pair_indices.device).unsqueeze(1).expand(B, S)
        idx_i = pair_indices[:, :, 0]
        idx_j = pair_indices[:, :, 1]

        patch_i = patch_hidden_l[batch_idx, idx_i]  # (B, S, hw, C)
        cam_i = camera_hidden_l[batch_idx, idx_i]  # (B, S, hw, C)
        cam_j = camera_hidden_l[batch_idx, idx_j]  # (B, S, hw, C)

        # Average cam_j into a single camera token per pair, broadcast over hw.
        cam_j = cam_j.mean(dim=2, keepdim=True).expand(-1, -1, hw, -1)

        # Concatenate and flatten the pair dimension into the batch.
        concat = torch.cat([cam_i, cam_j, patch_i], dim=-1)  # (B, S, hw, 3*C)
        x = concat.reshape(B * S, hw, 3 * self.dim_in)

        # MLP fuse back down to dim_in, then normalize.
        x = self.mlp(x)
        x = self.norm(x)

        # Tokens -> spatial grid.
        feat = x.transpose(1, 2).reshape(B * S, self.dim_in, patch_h, patch_w)
        return feat

    def _dpt_fuse_and_predict(
        self,
        feats_4: List[torch.Tensor],
        img_shape: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Run the standard DPT fusion over the four feature maps and predict flow.

        Returns:
            (B*S, output_dim, H, W)
        """
        H, W = img_shape
        out = []
        for i in range(4):
            x = feats_4[i]  # (B*S, dim_in, ph, pw)
            x = self.projects[i](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[i](x)  # multi-scale path
            out.append(x)

        x = self.scratch_forward(out)  # (B*S, features//2, h', w')
        x = custom_interpolate(
            x,
            (H, W),
            mode="bilinear",
            align_corners=True,
        )
        # BUGFIX: this result was previously bound to `out` inside the `if`,
        # so with pos_embed=False `out` remained the Python list built above
        # and output_conv2 received a list instead of a tensor.
        if self.pos_embed:
            x = self._apply_pos_embed(x, W, H)

        flow = self.scratch.output_conv2(x)  # (B*S, output_dim, H, W)
        return flow

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion (refinenet) blocks.

        Args:
            features (List[Tensor]): feature maps from the four layers.

        Returns:
            Tensor: fused feature map after output_conv1.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Coarse-to-fine fusion; free intermediates eagerly to cap peak memory.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
################################################################################
|
| 309 |
+
# Modules
|
| 310 |
+
################################################################################
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    """Construct a FeatureFusionBlock configured for the DPT head.

    Fixed choices: inplace ReLU activation, no deconv, no batch norm,
    no channel expansion, align_corners interpolation.
    """
    activation = nn.ReLU(inplace=True)
    return FeatureFusionBlock(
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
    )
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
|
| 328 |
+
scratch = nn.Module()
|
| 329 |
+
out_shape1 = out_shape
|
| 330 |
+
out_shape2 = out_shape
|
| 331 |
+
out_shape3 = out_shape
|
| 332 |
+
if len(in_shape) >= 4:
|
| 333 |
+
out_shape4 = out_shape
|
| 334 |
+
|
| 335 |
+
if expand:
|
| 336 |
+
out_shape1 = out_shape
|
| 337 |
+
out_shape2 = out_shape * 2
|
| 338 |
+
out_shape3 = out_shape * 4
|
| 339 |
+
if len(in_shape) >= 4:
|
| 340 |
+
out_shape4 = out_shape * 8
|
| 341 |
+
|
| 342 |
+
scratch.layer1_rn = nn.Conv2d(
|
| 343 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 344 |
+
)
|
| 345 |
+
scratch.layer2_rn = nn.Conv2d(
|
| 346 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 347 |
+
)
|
| 348 |
+
scratch.layer3_rn = nn.Conv2d(
|
| 349 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 350 |
+
)
|
| 351 |
+
if len(in_shape) >= 4:
|
| 352 |
+
scratch.layer4_rn = nn.Conv2d(
|
| 353 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 354 |
+
)
|
| 355 |
+
return scratch
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
class ResidualConvUnit(nn.Module):
    """Two activated 3x3 convolutions with an identity skip connection."""

    def __init__(self, features, activation, bn, groups=1):
        """
        Args:
            features (int): number of input/output channels.
            activation: activation module applied before each convolution.
            bn (bool): stored flag; the norm layers themselves are left as
                None, so no normalization is applied.
            groups (int): convolution groups.
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_kwargs)
        self.conv2 = nn.Conv2d(features, features, **conv_kwargs)

        # Placeholders: the branches in forward() are skipped while these are None.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Apply act->conv twice, then add the input back (pre-activation residual).

        Args:
            x (tensor): (B, features, H, W) input.

        Returns:
            tensor: (B, features, H, W) output.
        """
        branch = self.conv1(self.activation(x))
        if self.norm1 is not None:
            branch = self.norm1(branch)

        branch = self.conv2(self.activation(branch))
        if self.norm2 is not None:
            branch = self.norm2(branch)

        return self.skip_add.add(branch, x)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
class FeatureFusionBlock(nn.Module):
    """DPT fusion block: optional residual merge, refine, upsample, project."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """
        Args:
            features (int): number of feature channels.
            activation: activation module passed to the residual units.
            deconv (bool): stored flag (not used in forward).
            bn (bool): forwarded to ResidualConvUnit.
            expand (bool): if True, halve the channel count in out_conv.
            align_corners (bool): interpolation flag.
            size: fixed output size used when forward() receives no size.
            has_residual (bool): whether a second (lateral) input is fused in.
            groups (int): convolution groups.
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand

        out_features = features // 2 if self.expand else features
        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )

        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Fuse inputs, refine, resize (explicit size > self.size > 2x), project.

        Returns:
            tensor: fused, upsampled and projected feature map.
        """
        output = xs[0]

        # Merge the lateral input through the first residual unit, if any.
        if self.has_residual:
            output = self.skip_add.add(output, self.resConfUnit1(xs[1]))

        output = self.resConfUnit2(output)

        # Resolve the target resolution: call-site size wins, then the
        # configured size, otherwise a plain 2x upsample.
        if size is not None:
            resize_kwargs = {"size": size}
        elif self.size is not None:
            resize_kwargs = {"size": self.size}
        else:
            resize_kwargs = {"scale_factor": 2}

        output = custom_interpolate(output, **resize_kwargs, mode="bilinear", align_corners=self.align_corners)
        return self.out_conv(output)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Interpolation wrapper that sidesteps INT_MAX element-count limits in
    nn.functional.interpolate by splitting very large batches into chunks.
    """
    if size is None:
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    INT_MAX = 1610612736

    # Elements of the interpolated output: B * C * H_out * W_out.
    total = size[0] * size[1] * x.shape[0] * x.shape[1]
    if total <= INT_MAX:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)

    n_chunks = (total // INT_MAX) + 1
    pieces = [
        nn.functional.interpolate(part, size=size, mode=mode, align_corners=align_corners)
        for part in torch.chunk(x, chunks=n_chunks, dim=0)
    ]
    return torch.cat(pieces, dim=0).contiguous()
|
flow3r/models/flow_head/utils.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Convert a 2D position grid (H, W, 2) into sinusoidal embeddings (H, W, embed_dim).

    Args:
        pos_grid: (H, W, 2) tensor of 2D coordinates.
        embed_dim: output channel dimension (must be even).
        omega_0: base frequency passed to make_sincos_pos_embed.

    Returns:
        (H, W, embed_dim) tensor; the x-axis embedding occupies the first
        half of the channels and the y-axis embedding the second half.
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    coords = pos_grid.reshape(-1, grid_dim)  # (H*W, 2)

    # Embed each coordinate axis with half of the channel budget.
    half = embed_dim // 2
    emb_x = make_sincos_pos_embed(half, coords[:, 0], omega_0=omega_0)  # (H*W, half)
    emb_y = make_sincos_pos_embed(half, coords[:, 1], omega_0=omega_0)  # (H*W, half)

    return torch.cat([emb_x, emb_y], dim=-1).view(H, W, embed_dim)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    Generate a 1D sine/cosine positional embedding for each position in `pos`.

    Args:
        embed_dim: embedding dimension (must be even).
        pos: tensor of positions; flattened to (M,).
        omega_0: base of the geometric frequency ladder.

    Returns:
        (M, embed_dim) float32 tensor laid out as [sin | cos] halves.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2

    # Geometric frequency ladder: omega_0 ** (-k / half), k in [0, half),
    # computed in double precision.
    freqs = torch.arange(half, dtype=torch.double, device=pos.device)
    freqs = omega_0 ** (-freqs / half)  # (half,)

    # Outer product position x frequency.
    angles = pos.reshape(-1).unsqueeze(1) * freqs.unsqueeze(0)  # (M, half)

    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1).float()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Inspired by https://github.com/microsoft/moge
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def create_uv_grid(
    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
    """
    Create a normalized UV coordinate grid.

    Coordinates span [-span, span] per axis, where the spans are the
    width/height components of a plane with unit diagonal and the given
    aspect ratio; samples sit at pixel centers, hence the (n-1)/n shrink.

    Args:
        width (int): number of samples horizontally.
        height (int): number of samples vertically.
        aspect_ratio (float, optional): width/height ratio; defaults to width/height.
        dtype (torch.dtype, optional): dtype of the resulting tensor.
        device (torch.device, optional): device on which the tensor is created.

    Returns:
        torch.Tensor: UV coordinates of shape (height, width, 2) as produced
        by torch.meshgrid(..., indexing="xy").
    """
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Normalize so the plane's diagonal has unit length.
    diagonal = (aspect_ratio ** 2 + 1.0) ** 0.5
    span_x = aspect_ratio / diagonal
    span_y = 1.0 / diagonal

    # Pixel-center endpoints: shrink each span by half a step on both sides.
    x_max = span_x * (width - 1) / width
    y_max = span_y * (height - 1) / height

    xs = torch.linspace(-x_max, x_max, steps=width, dtype=dtype, device=device)
    ys = torch.linspace(-y_max, y_max, steps=height, dtype=dtype, device=device)

    uu, vv = torch.meshgrid(xs, ys, indexing="xy")
    return torch.stack((uu, vv), dim=-1)
|
flow3r/models/layers/attention.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from torch import nn
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 19 |
+
from torch.nn.attention import SDPBackend
|
| 20 |
+
|
| 21 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 22 |
+
try:
|
| 23 |
+
if XFORMERS_ENABLED:
|
| 24 |
+
from xformers.ops import memory_efficient_attention, unbind
|
| 25 |
+
|
| 26 |
+
XFORMERS_AVAILABLE = True
|
| 27 |
+
# warnings.warn("xFormers is available (Attention)")
|
| 28 |
+
else:
|
| 29 |
+
# warnings.warn("xFormers is disabled (Attention)")
|
| 30 |
+
raise ImportError
|
| 31 |
+
except ImportError:
|
| 32 |
+
XFORMERS_AVAILABLE = False
|
| 33 |
+
# warnings.warn("xFormers is not available (Attention)")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Attention(nn.Module):
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
dim: int,
|
| 40 |
+
num_heads: int = 8,
|
| 41 |
+
qkv_bias: bool = False,
|
| 42 |
+
proj_bias: bool = True,
|
| 43 |
+
attn_drop: float = 0.0,
|
| 44 |
+
proj_drop: float = 0.0,
|
| 45 |
+
) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.num_heads = num_heads
|
| 48 |
+
head_dim = dim // num_heads
|
| 49 |
+
self.scale = head_dim**-0.5
|
| 50 |
+
|
| 51 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 52 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 53 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 54 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 55 |
+
|
| 56 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 57 |
+
B, N, C = x.shape
|
| 58 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 59 |
+
|
| 60 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
| 61 |
+
attn = q @ k.transpose(-2, -1)
|
| 62 |
+
|
| 63 |
+
attn = attn.softmax(dim=-1)
|
| 64 |
+
attn = self.attn_drop(attn)
|
| 65 |
+
|
| 66 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 67 |
+
x = self.proj(x)
|
| 68 |
+
x = self.proj_drop(x)
|
| 69 |
+
return x
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class MemEffAttention(Attention):
    """Self-attention backed by xFormers' ``memory_efficient_attention``.

    When xFormers is unavailable, falls back to the reference
    :class:`Attention` implementation (which does not support ``attn_bias``).
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        # xFormers expects (B, N, heads, head_dim) for each of q/k/v.
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        q, k, v = (qkv[:, :, idx] for idx in range(3))

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        out = out.reshape([batch, tokens, channels])

        out = self.proj(out)
        return self.proj_drop(out)
class FlashAttention(Attention):
    """Self-attention via ``torch.nn.functional.scaled_dot_product_attention``.

    Routes bfloat16 inputs to the flash-attention kernel and everything else
    to the math / memory-efficient kernels (flash requires half precision).
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        B, N, C = x.shape
        # (B, N, 3, H, D) -> (B, H, 3, N, D) so q/k/v are (B, H, N, D).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3)

        # q, k, v = unbind(qkv, 2)
        q, k, v = [qkv[:, :, i] for i in range(3)]

        # Fix: honor the configured attention dropout (it was silently ignored),
        # matching FlashCrossAttentionRope. No-op for the default attn_drop == 0.
        dropout_p = self.attn_drop.p if self.training else 0.0

        if q.dtype == torch.bfloat16:
            with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                x = scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        else:
            with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                x = scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        x = x.transpose(1, 2).reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
"""
|
| 117 |
+
Following is written by GPT-4o
|
| 118 |
+
"""
|
| 119 |
+
class CrossAttentionRope(nn.Module):
|
| 120 |
+
def __init__(
|
| 121 |
+
self,
|
| 122 |
+
dim: int,
|
| 123 |
+
num_heads: int = 8,
|
| 124 |
+
qkv_bias: bool = False,
|
| 125 |
+
proj_bias: bool = True,
|
| 126 |
+
attn_drop: float = 0.0,
|
| 127 |
+
proj_drop: float = 0.0,
|
| 128 |
+
qk_norm: bool = False,
|
| 129 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 130 |
+
rope=None,
|
| 131 |
+
) -> None:
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.num_heads = num_heads
|
| 134 |
+
head_dim = dim // num_heads
|
| 135 |
+
self.scale = head_dim**-0.5
|
| 136 |
+
|
| 137 |
+
# Separate projection layers for query, key, and value
|
| 138 |
+
self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
|
| 139 |
+
self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
|
| 140 |
+
self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
|
| 141 |
+
|
| 142 |
+
self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
|
| 143 |
+
self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
|
| 144 |
+
|
| 145 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 146 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 147 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 148 |
+
|
| 149 |
+
self.rope = rope
|
| 150 |
+
|
| 151 |
+
def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
|
| 152 |
+
"""
|
| 153 |
+
Args:
|
| 154 |
+
query: Tensor of shape (B, N, C), input query
|
| 155 |
+
key: Tensor of shape (B, M, C), input key
|
| 156 |
+
value: Tensor of shape (B, M, C), input value
|
| 157 |
+
attn_bias: Optional tensor for attention bias
|
| 158 |
+
Returns:
|
| 159 |
+
Tensor of shape (B, N, C), output of cross-attention
|
| 160 |
+
"""
|
| 161 |
+
B, N, C = query.shape
|
| 162 |
+
_, M, _ = key.shape
|
| 163 |
+
|
| 164 |
+
# Project query, key, and value
|
| 165 |
+
q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
| 166 |
+
k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
| 167 |
+
v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
| 168 |
+
q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
|
| 169 |
+
|
| 170 |
+
if self.rope is not None:
|
| 171 |
+
q = self.rope(q, qpos)
|
| 172 |
+
k = self.rope(k, kpos)
|
| 173 |
+
|
| 174 |
+
# Scale query
|
| 175 |
+
q = q * self.scale
|
| 176 |
+
|
| 177 |
+
# Compute attention scores
|
| 178 |
+
attn = q @ k.transpose(-2, -1) # (B, num_heads, N, M)
|
| 179 |
+
if attn_bias is not None:
|
| 180 |
+
attn = attn + attn_bias
|
| 181 |
+
|
| 182 |
+
attn = attn.softmax(dim=-1)
|
| 183 |
+
attn = self.attn_drop(attn)
|
| 184 |
+
|
| 185 |
+
# Compute attention output
|
| 186 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # (B, N, C)
|
| 187 |
+
|
| 188 |
+
# Final projection
|
| 189 |
+
x = self.proj(x)
|
| 190 |
+
x = self.proj_drop(x)
|
| 191 |
+
return x
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class MemEffCrossAttentionRope(CrossAttentionRope):
    """Cross-attention with RoPE backed by xFormers' memory-efficient kernel.

    Falls back to the reference :class:`CrossAttentionRope` implementation
    when xFormers is unavailable.
    """

    def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
        """
        Args:
            query: Tensor of shape (B, N, C), input query
            key: Tensor of shape (B, M, C), input key
            value: Tensor of shape (B, M, C), input value
            attn_bias: Optional tensor for attention bias
            qpos/kpos: optional positions consumed by the RoPE module
        Returns:
            Tensor of shape (B, N, C), output of cross-attention
        """
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # Fix: forward qpos/kpos so RoPE is still applied on the fallback
            # path (they were previously dropped, silently disabling RoPE).
            return super().forward(query, key, value, attn_bias, qpos=qpos, kpos=kpos)

        B, N, C = query.shape
        _, M, _ = key.shape

        # Project query, key, and value -> (B, seq, heads, head_dim)
        q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads)
        k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads)
        v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads)

        # QK-norm and RoPE operate on (B, heads, seq, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, qpos)
            k = self.rope(k, kpos)

        # memory_efficient_attention expects (B, seq, heads, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)

        # Compute memory-efficient attention
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape(B, N, C)

        # Final projection
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class FlashCrossAttentionRope(CrossAttentionRope):
    """Cross-attention with RoPE computed through PyTorch SDPA kernels."""

    def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
        """Return cross-attention output of shape (B, N, C)."""
        batch, n_q, channels = query.shape
        n_kv = key.shape[1]
        per_head = channels // self.num_heads

        # Project q/k/v and lay them out as (B, heads, seq, head_dim).
        q = self.q_proj(query).reshape(batch, n_q, self.num_heads, per_head).permute(0, 2, 1, 3)
        k = self.k_proj(key).reshape(batch, n_kv, self.num_heads, per_head).permute(0, 2, 1, 3)
        v = self.v_proj(value).reshape(batch, n_kv, self.num_heads, per_head).permute(0, 2, 1, 3)

        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
        if self.rope is not None:
            q = self.rope(q, qpos)
            k = self.rope(k, kpos)

        # Attention-weight dropout is only active during training.
        drop_p = self.attn_drop.p if self.training else 0.0

        # The flash kernel requires bf16; other dtypes use math/efficient kernels.
        if q.dtype == torch.bfloat16:
            with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                out = scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=drop_p)
        else:
            with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                out = scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=drop_p)

        out = out.transpose(1, 2).reshape(batch, n_q, channels)

        out = self.proj(out)
        return self.proj_drop(out)
class AttentionRope(nn.Module):
|
| 273 |
+
def __init__(
|
| 274 |
+
self,
|
| 275 |
+
dim: int,
|
| 276 |
+
num_heads: int = 8,
|
| 277 |
+
qkv_bias: bool = False,
|
| 278 |
+
proj_bias: bool = True,
|
| 279 |
+
attn_drop: float = 0.0,
|
| 280 |
+
proj_drop: float = 0.0,
|
| 281 |
+
qk_norm: bool = False,
|
| 282 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 283 |
+
rope=None
|
| 284 |
+
) -> None:
|
| 285 |
+
super().__init__()
|
| 286 |
+
self.num_heads = num_heads
|
| 287 |
+
head_dim = dim // num_heads
|
| 288 |
+
self.scale = head_dim**-0.5
|
| 289 |
+
|
| 290 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 291 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 292 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 293 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 294 |
+
|
| 295 |
+
self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
|
| 296 |
+
self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
|
| 297 |
+
|
| 298 |
+
self.rope = rope
|
| 299 |
+
|
| 300 |
+
def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
|
| 301 |
+
B, N, C = x.shape
|
| 302 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 303 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 304 |
+
q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
|
| 305 |
+
|
| 306 |
+
if self.rope is not None:
|
| 307 |
+
q = self.rope(q, xpos)
|
| 308 |
+
k = self.rope(k, xpos)
|
| 309 |
+
|
| 310 |
+
q = q * self.scale
|
| 311 |
+
attn = q @ k.transpose(-2, -1)
|
| 312 |
+
|
| 313 |
+
attn = attn.softmax(dim=-1)
|
| 314 |
+
attn = self.attn_drop(attn)
|
| 315 |
+
|
| 316 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 317 |
+
x = self.proj(x)
|
| 318 |
+
x = self.proj_drop(x)
|
| 319 |
+
return x
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class MemEffAttentionRope(AttentionRope):
    """AttentionRope variant backed by xFormers' memory-efficient kernel.

    Falls back to the reference :class:`AttentionRope` implementation when
    xFormers is unavailable.
    """

    def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # Fix: forward xpos so RoPE is still applied on the fallback path
            # (it was previously dropped, silently disabling RoPE).
            return super().forward(x, xpos=xpos)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        # (B, H, 3, N, D): heads in front so QK-norm and RoPE see (B, H, N, D).
        qkv = qkv.transpose(1, 3)
        q, k, v = [qkv[:, :, i] for i in range(3)]
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, xpos)
            k = self.rope(k, xpos)

        # memory_efficient_attention expects (B, N, H, D).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class FlashAttentionRope(AttentionRope):
    """AttentionRope variant using torch SDPA (flash kernel for bf16)."""

    def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
        B, N, C = x.shape
        # (B, N, 3, H, D) -> (B, H, 3, N, D) so q/k/v are (B, H, N, D).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3)

        # q, k, v = unbind(qkv, 2)
        q, k, v = [qkv[:, :, i] for i in range(3)]
        q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)

        if self.rope is not None:
            q = self.rope(q, xpos)
            k = self.rope(k, xpos)

        # Fix: honor the configured attention dropout (it was silently ignored),
        # matching FlashCrossAttentionRope. No-op for the default attn_drop == 0.
        dropout_p = self.attn_drop.p if self.training else 0.0

        if q.dtype == torch.bfloat16:
            with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                x = scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        else:
            with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                x = scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        x = x.transpose(1, 2).reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def get_attn_score(blk_class, x, frame_num, token_length, xpos=None):
    """Compute a per-frame attention score matrix from a block's QK logits.

    Runs only the query/key half of ``blk_class.attn`` (norm, fused QKV
    projection, optional QK-norm and RoPE) and reduces the resulting logits
    to a (B, frame_num) score: logits are summed over heads, averaged over
    the tokens within each (query-frame, key-frame) pair, then summed over
    key frames.

    Args:
        blk_class: a transformer block exposing ``norm1`` and an ``attn``
            module with ``qkv``, ``num_heads``, ``q_norm``/``k_norm``,
            ``rope`` and ``scale`` attributes (e.g. BlockRope).
        x: (B, N, C) token tensor; assumes N == frame_num * token_length —
            TODO confirm with callers.
        frame_num: number of frames packed along the token axis.
        token_length: tokens per frame.
        xpos: optional positions consumed by the RoPE module.

    Returns:
        (B, frame_num) tensor of aggregated attention scores.
    """
    x = blk_class.norm1(x)

    B, N, C = x.shape
    # Fused QKV projection -> (B, N, 3, heads, head_dim)
    qkv = blk_class.attn.qkv(x).reshape(B, N, 3, blk_class.attn.num_heads, C // blk_class.attn.num_heads)

    # -> (B, heads, 3, N, head_dim) so QK-norm / RoPE see (B, heads, N, head_dim)
    qkv = qkv.transpose(1, 3)
    # q, k, v = unbind(qkv, 2)
    q, k, v = [qkv[:,:,i] for i in range(3)]
    q, k = blk_class.attn.q_norm(q).to(v.dtype), blk_class.attn.k_norm(k).to(v.dtype)

    if blk_class.attn.rope is not None:
        q = blk_class.attn.rope(q, xpos)
        k = blk_class.attn.rope(k, xpos)

    # -> (B, N, heads, head_dim)
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    # Scaled QK^T logits, summed over heads, regrouped by frame, averaged over
    # in-frame tokens, then summed over key frames -> (B, frame_num).
    score = (q.permute(0, 2, 1, 3) * blk_class.attn.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(B, frame_num, token_length, frame_num, token_length).mean(dim=[2, 4]).sum(-1)

    return score
flow3r/models/layers/block.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn, Tensor
|
| 17 |
+
|
| 18 |
+
from .attention import Attention, MemEffAttention, CrossAttentionRope, MemEffCrossAttentionRope, FlashAttentionRope
|
| 19 |
+
from ..dinov2.layers.drop_path import DropPath
|
| 20 |
+
from ..dinov2.layers.layer_scale import LayerScale
|
| 21 |
+
from ..dinov2.layers.mlp import Mlp
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 25 |
+
try:
|
| 26 |
+
if XFORMERS_ENABLED:
|
| 27 |
+
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
| 28 |
+
|
| 29 |
+
XFORMERS_AVAILABLE = True
|
| 30 |
+
# warnings.warn("xFormers is available (Block)")
|
| 31 |
+
else:
|
| 32 |
+
# warnings.warn("xFormers is disabled (Block)")
|
| 33 |
+
raise ImportError
|
| 34 |
+
except ImportError:
|
| 35 |
+
XFORMERS_AVAILABLE = False
|
| 36 |
+
# warnings.warn("xFormers is not available (Block)")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Block(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each branch
    wrapped with optional LayerScale and stochastic depth (DropPath).

    During training with drop_path > 0.1 the residual branches are computed
    on a random subset of the batch (batch-wise stochastic depth), which
    amortizes the overhead better than per-sample DropPath.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Threshold reused by forward() to pick the stochastic-depth strategy.
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Fix (resolves the upstream FIXME): use drop_path2 for the FFN
            # branch. Both modules share the same drop rate and are stateless,
            # so checkpoints and expected behavior are unaffected.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
def drop_add_residual_stochastic_depth(
|
| 115 |
+
x: Tensor,
|
| 116 |
+
residual_func: Callable[[Tensor], Tensor],
|
| 117 |
+
sample_drop_ratio: float = 0.0,
|
| 118 |
+
) -> Tensor:
|
| 119 |
+
# 1) extract subset using permutation
|
| 120 |
+
b, n, d = x.shape
|
| 121 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 122 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 123 |
+
x_subset = x[brange]
|
| 124 |
+
|
| 125 |
+
# 2) apply residual_func to get residual
|
| 126 |
+
residual = residual_func(x_subset)
|
| 127 |
+
|
| 128 |
+
x_flat = x.flatten(1)
|
| 129 |
+
residual = residual.flatten(1)
|
| 130 |
+
|
| 131 |
+
residual_scale_factor = b / sample_subset_size
|
| 132 |
+
|
| 133 |
+
# 3) add the residual
|
| 134 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 135 |
+
return x_plus_residual.view_as(x)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Draw the random batch-subset indices and residual rescale factor used
    by batch-wise stochastic depth.

    Returns:
        (indices, scale): a 1-D tensor of kept batch indices (at least one)
        and the ``b / subset_size`` factor that keeps the residual's
        expectation unchanged.
    """
    b, n, d = x.shape
    keep_count = max(int(b * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(b, device=x.device)[:keep_count]
    return kept_idx, b / keep_count
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Scatter-add a subset residual back into the full batch.

    With ``scaling_vector=None`` this flattens everything past the batch
    dimension and returns the *flattened* result (B, N*D); with a scaling
    vector it delegates to xFormers' fused ``scaled_index_add`` and keeps
    the original shape.
    """
    if scaling_vector is None:
        return torch.index_add(
            x.flatten(1),
            0,
            brange,
            residual.flatten(1).to(dtype=x.dtype),
            alpha=residual_scale_factor,
        )
    # LayerScale gamma folded into the fused xFormers kernel.
    return scaled_index_add(
        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
    )
# Module-level cache of BlockDiagonalMask objects, keyed by the tuple of
# (batch_size, seq_len) shapes — masks are reused across forward passes.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    # Effective batch per tensor: the kept-subset size when branges is given,
    # otherwise the full batch size.
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # One sequence length entry per (sample, tensor) pair for the
        # block-diagonal mask.
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        # Stash batch sizes on the mask so split() can undo the packing later.
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Fused index-select + concat (xFormers), packed into one long sequence.
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        # Plain concat along the token axis with batch folded in (bs=1).
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """Batch-wise stochastic depth over a list of variable-length tensors.

    Each tensor keeps a random subset of its batch; the kept rows are packed
    into a single sequence with a block-diagonal attention bias, run through
    ``residual_func`` once, then split and scatter-added back (rescaled) onto
    the corresponding inputs.
    """
    # 1) per-tensor random subset indices and residual rescale factors
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [pair[0] for pair in branges_scales]
    scale_factors = [pair[1] for pair in branges_scales]

    # 2) block-diagonal attention bias + packed concatenation of the subsets
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) run the residual branch once on the packed tensor, then split it back
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    return [
        add_residual(x, brange, residual, scale, scaling_vector).view_as(x)
        for x, brange, residual, scale in zip(x_list, branges, residual_list, scale_factors)
    ]
class NestedTensorBlock(Block):
    """Block variant that can process a list of variable-length token tensors
    by packing them into one sequence with a block-diagonal attention mask
    (requires xFormers)."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            # LayerScale is folded into the fused scatter-add via scaling_vector
            # (the residual funcs above deliberately skip ls1/ls2).
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Fix: check ls2 (was ls1 — copy-paste typo). Equivalent in
                # practice since ls1/ls2 are built from the same init_values,
                # but correct as written.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch: a plain Tensor runs through Block.forward; a list of
        tensors runs through the packed nested path."""
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
class BlockRope(nn.Module):
    """Pre-norm transformer block whose attention supports QK-norm and rotary
    position embedding; otherwise identical in structure to :class:`Block`."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            rope=rope
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Threshold reused by forward() to pick the stochastic-depth strategy.
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, xpos=None) -> Tensor:
        """Apply attention (with optional RoPE positions ``xpos``) and MLP."""
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), xpos=xpos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Fix (resolves the upstream FIXME): use drop_path2 for the FFN
            # branch. Both modules share the same drop rate and are stateless,
            # so checkpoints and expected behavior are unaffected.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
class CrossBlockRope(nn.Module):
    """Transformer block with rotary self-attention, rotary cross-attention and
    an MLP, each in a pre-norm residual branch with optional LayerScale."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        init_values=None,
        qk_norm: bool = False,
        rope=None
    ) -> None:
        super().__init__()

        def make_scale() -> nn.Module:
            # LayerScale only when an init value is provided; identity otherwise.
            return LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # --- self-attention branch ---
        self.ls1 = make_scale()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            rope=rope,
            qk_norm=qk_norm,
        )

        # --- cross-attention branch ---
        self.ls2 = make_scale()
        self.ls_y = make_scale()
        self.norm2 = norm_layer(dim)
        self.norm_y = norm_layer(dim)
        self.cross_attn = cross_attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            rope=rope,
            qk_norm=qk_norm,
        )

        # --- feed-forward branch ---
        self.norm3 = norm_layer(dim)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            bias=ffn_bias,
        )

    def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor:
        # Self-attention residual on the query stream.
        x = x + self.ls1(self.attn(self.norm1(x), xpos=xpos))
        # Cross-attention: queries from (normed) x, keys/values from the normalized context y.
        y_normed = self.norm_y(y)
        x = x + self.ls_y(self.cross_attn(self.norm2(x), y_normed, y_normed, qpos=xpos, kpos=ypos))
        # Feed-forward residual.
        x = x + self.ls2(self.mlp(self.norm3(x)))
        return x
|
flow3r/models/layers/camera_head.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
# code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172'
|
| 7 |
+
class ResConvBlock(nn.Module):
    """Pointwise (1x1-convolution-equivalent) residual block.

    Three stacked Linear+ReLU layers applied per token, added to a skip
    connection. The skip is the identity when the channel count is unchanged;
    otherwise a 1x1 Conv2d remaps channels. NOTE(review): the Conv2d skip
    expects NCHW input while the linear path expects channels-last tokens —
    only the equal-channel configuration appears to be exercised here;
    confirm before using in_channels != out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if self.in_channels == self.out_channels:
            self.head_skip = nn.Identity()
        else:
            self.head_skip = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
        # Linear layers act as 1x1 convolutions on channels-last tokens.
        self.res_conv1 = nn.Linear(self.in_channels, self.out_channels)
        self.res_conv2 = nn.Linear(self.out_channels, self.out_channels)
        self.res_conv3 = nn.Linear(self.out_channels, self.out_channels)

    def forward(self, res):
        out = res
        for layer in (self.res_conv1, self.res_conv2, self.res_conv3):
            out = F.relu(layer(out))
        # Residual addition: skip path plus the (non-negative) refined features.
        return self.head_skip(res) + out
|
| 31 |
+
|
| 32 |
+
class CameraHead(nn.Module):
    """Regresses a homogeneous 4x4 camera pose (SO(3) rotation + translation)
    from per-patch token features."""

    def __init__(self, dim=512):
        super().__init__()
        output_dim = dim
        # Two pointwise residual blocks refine the token features.
        self.res_conv = nn.ModuleList(
            [deepcopy(ResConvBlock(output_dim, output_dim)) for _ in range(2)]
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.more_mlps = nn.Sequential(
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU()
        )
        # Separate heads: 3D translation and a 9D rotation projected onto SO(3).
        self.fc_t = nn.Linear(output_dim, 3)
        self.fc_rot = nn.Linear(output_dim, 9)

    def forward(self, feat, patch_h, patch_w):
        """feat: (B*N, patch_h*patch_w, C) tokens -> (B*N, 4, 4) pose matrices."""
        BN, hw, c = feat.shape

        for block in self.res_conv:
            feat = block(feat)

        # Tokens -> NCHW grid, then global average pool to one vector per view.
        grid = feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous()
        feat = self.avgpool(grid)
        feat = feat.view(feat.size(0), -1)

        feat = self.more_mlps(feat)  # [B, D]
        # Pose regression is kept in fp32 regardless of any surrounding autocast.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            out_t = self.fc_t(feat.float())    # [B, 3]
            out_r = self.fc_rot(feat.float())  # [B, 9]
            pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device)

        return pose

    def convert_pose_to_4x4(self, B, out_r, out_t, device):
        """Assemble 4x4 homogeneous poses from 9D rotation and 3D translation."""
        rot = self.svd_orthogonalize(out_r)  # [N, 3, 3]
        pose = torch.zeros((B, 4, 4), device=device)
        pose[:, :3, :3] = rot
        pose[:, :3, 3] = out_t
        pose[:, 3, 3] = 1.
        return pose

    def svd_orthogonalize(self, m):
        """Convert a 9D representation to SO(3) using SVD orthogonalization.

        Args:
            m: [BATCH, 3, 3] matrices (or flat [BATCH, 9], reshaped here).

        Returns:
            [BATCH, 3, 3] proper rotation matrices (det = +1).
        """
        if m.dim() < 3:
            m = m.reshape((-1, 3, 3))
        m_transpose = torch.transpose(torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2)
        u, s, v = torch.svd(m_transpose)
        det = torch.det(torch.matmul(v, u.transpose(-2, -1)))
        # Flip the last column of v when det is negative so the result is a
        # rotation rather than a reflection.
        r = torch.matmul(
            torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2),
            u.transpose(-2, -1)
        )
        return r
|
flow3r/models/layers/pos_embed.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
# Position embedding utils
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
# --------------------------------------------------------
|
| 16 |
+
# 2D sine-cosine position embedding
|
| 17 |
+
# References:
|
| 18 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 19 |
+
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
|
| 20 |
+
# MoCo v3: https://github.com/facebookresearch/moco-v3
|
| 21 |
+
# --------------------------------------------------------
|
| 22 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
    """
    Build a 2D sine-cosine positional embedding for a square patch grid.

    Args:
        embed_dim: embedding dimension per position.
        grid_size: int, grid height and width.
        n_cls_token: number of leading class tokens given all-zero embeddings.

    Returns:
        [grid_size*grid_size, embed_dim] array, or
        [n_cls_token + grid_size*grid_size, embed_dim] when class tokens are requested.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid with w first, matching the reference MAE implementation.
    mesh = np.meshgrid(axis, axis)
    grid = np.stack(mesh, axis=0).reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if n_cls_token > 0:
        # Class tokens carry no positional information.
        pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
    return pos_embed
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-axis position grid; each axis gets half of the channels."""
    assert embed_dim % 2 == 0
    # First half of the channels encodes grid[0], second half grid[1].
    halves = [
        get_1d_sincos_pos_embed_from_grid(embed_dim // 2, axis_grid)  # (H*W, D/2)
        for axis_grid in (grid[0], grid[1])
    ]
    return np.concatenate(halves, axis=1)  # (H*W, D)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine-cosine embedding of scalar positions.

    Args:
        embed_dim: output dimension per position (must be even).
        pos: array of positions, any shape; flattened to (M,).

    Returns:
        (M, embed_dim) array laid out as [sin(w*pos) | cos(w*pos)]
        with log-spaced frequencies 1/10000^(2i/D).
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=float) / half)  # (D/2,)

    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# --------------------------------------------------------
|
| 73 |
+
# Interpolate position embeddings for high-resolution
|
| 74 |
+
# References:
|
| 75 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 76 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 77 |
+
# --------------------------------------------------------
|
| 78 |
+
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's 'pos_embed' to match the model's patch grid.

    Mutates ``checkpoint_model`` in place: when the checkpoint's (square)
    position grid differs from the model's, the position tokens are
    bicubically interpolated; any extra leading tokens (class/dist tokens)
    are kept unchanged.  Assumes both grids are square.
    """
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        # Tokens beyond the patch grid (e.g. class token) are "extra".
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            # (1, H*W, D) -> (1, D, H, W) for spatial interpolation.
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            # Back to (1, H'*W', D) and re-attach the untouched extra tokens.
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
#----------------------------------------------------------
|
| 103 |
+
# RoPE2D: RoPE implementation in 2D
|
| 104 |
+
#----------------------------------------------------------
|
| 105 |
+
|
| 106 |
+
try:
    # Prefer the CUDA-compiled implementation when available.
    from models.curope import cuRoPE2D
    RoPE2D = cuRoPE2D
except ImportError:
    print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')

    class RoPE2D(torch.nn.Module):
        """Pure-PyTorch fallback for 2D rotary position embedding (RoPE).

        Splits each token's feature vector in half; the first half is rotated
        by the token's y position, the second half by its x position.
        """

        def __init__(self, freq=100.0, F0=1.0):
            super().__init__()
            self.base = freq  # frequency base for the inverse-frequency schedule
            self.F0 = F0
            # cos/sin tables cached per (dim, seq_len, device, dtype).
            self.cache = {}

        def get_cos_sin(self, D, seq_len, device, dtype):
            # Build (or fetch) the cos/sin lookup tables for positions [0, seq_len).
            if (D,seq_len,device,dtype) not in self.cache:
                inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
                t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
                # Duplicate so the table covers the full D dims (paired halves).
                freqs = torch.cat((freqs, freqs), dim=-1)
                cos = freqs.cos() # (Seq, Dim)
                sin = freqs.sin()
                self.cache[D,seq_len,device,dtype] = (cos,sin)
            return self.cache[D,seq_len,device,dtype]

        @staticmethod
        def rotate_half(x):
            # (a, b) -> (-b, a) on the paired halves of the last dimension.
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rope1d(self, tokens, pos1d, cos, sin):
            # Standard 1D RoPE: rotate token features by their position angle.
            assert pos1d.ndim==2
            cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
            sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
            return (tokens * cos) + (self.rotate_half(tokens) * sin)

        def forward(self, tokens, positions):
            """
            input:
                * tokens: batch_size x nheads x ntokens x dim
                * positions: batch_size x ntokens x 2 (y and x position of each token)
            output:
                * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
            """
            assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
            D = tokens.size(3) // 2
            assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
            cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
            # split features into two along the feature dimension, and apply rope1d on each half
            y, x = tokens.chunk(2, dim=-1)
            y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
            x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
            tokens = torch.cat((y, x), dim=-1)
            return tokens
|
| 160 |
+
|
| 161 |
+
# patch embedding
|
| 162 |
+
class PositionGetter(object):
    """Return per-patch (y, x) integer positions for an h x w patch grid.

    Positions for a given (height, width, device) are computed once and
    cached; each call returns a fresh, batch-expanded copy so callers may
    mutate the result safely.
    """

    def __init__(self):
        # cache: (h, w, device) -> (h*w, 2) tensor of (y, x) coordinates.
        self.cache_positions = {}

    def __call__(self, b, h, w, device):
        """Return a (b, h*w, 2) tensor of (y, x) positions on ``device``."""
        # BUGFIX: the cache must be keyed on the device as well — the old
        # (h, w)-only key returned a tensor created on the *first* caller's
        # device for every subsequent call, silently ignoring ``device``.
        key = (h, w, device)
        if key not in self.cache_positions:
            x = torch.arange(w, device=device)
            y = torch.arange(h, device=device)
            self.cache_positions[key] = torch.cartesian_prod(y, x)  # (h*w, 2)
        pos = self.cache_positions[key].view(1, h * w, 2).expand(b, -1, 2).clone()
        return pos
|
flow3r/models/layers/transformer_head.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .attention import FlashAttentionRope, FlashCrossAttentionRope
|
| 2 |
+
from .block import BlockRope, CrossBlockRope
|
| 3 |
+
from ..dinov2.layers import Mlp
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from functools import partial
|
| 7 |
+
from torch.utils.checkpoint import checkpoint
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from flow3r.models.flow_head.utils import create_uv_grid, position_grid_to_embed
|
| 10 |
+
|
| 11 |
+
class TransformerDecoder(nn.Module):
    """Stack of rotary self-attention blocks with an optional input projection
    and a linear output head."""

    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        super().__init__()

        # Project inputs into the decoder width, or pass them through as-is.
        self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
        self.use_checkpoint = use_checkpoint

        def make_block():
            # One pre-norm rotary transformer block (flash attention variant).
            return BlockRope(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                attn_class=FlashAttentionRope,
                rope=rope,
            )

        self.blocks = nn.ModuleList([make_block() for _ in range(depth)])
        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None, return_intermediate=False):
        """Run the block stack; optionally also return the last 4 intermediate states."""
        hidden = self.projects(hidden)
        intermediate = []
        use_ckpt = self.use_checkpoint and self.training
        for blk in self.blocks:
            if use_ckpt:
                # Gradient checkpointing trades recompute for activation memory.
                hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False)
            else:
                hidden = blk(hidden, xpos=xpos)
            if return_intermediate:
                intermediate.append(hidden)

        out = self.linear_out(hidden)

        if return_intermediate:
            return out, intermediate[-4:]
        return out
|
| 68 |
+
|
| 69 |
+
class LinearPts3d (nn.Module):
    """
    Linear head for dust3r.

    Each token is projected to a patch_size x patch_size patch of
    `output_dim`-channel values (e.g. 3D points + optional confidence), and
    the patches are re-assembled into a full-resolution channels-last map.
    """

    def __init__(self, patch_size, dec_embed_dim, output_dim=3,):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Linear(dec_embed_dim, (output_dim)*self.patch_size**2)

    def forward(self, decout, img_shape):
        """decout: list of decoder states (last is used); img_shape: (H, W)."""
        H, W = img_shape
        tokens = decout[-1]
        B, S, D = tokens.shape

        # Project every token to its patch of output values.
        per_patch = self.proj(tokens)  # (B, S, output_dim * p^2)
        # Arrange as a channel grid, then pixel-shuffle up to full resolution.
        grid = per_patch.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
        full = F.pixel_shuffle(grid, self.patch_size)  # (B, output_dim, H, W)

        # Channels-last output.
        return full.permute(0, 2, 3, 1)
|
| 100 |
+
|
| 101 |
+
class LinearFlow2d (nn.Module):
    """
    Linear head for 2D flow with MLP fusion of camera and patch features.

    For each requested (i, j) image pair, the pair's patch features and both
    views' camera features are concatenated, fused by an MLP, and decoded so
    that each token yields a patch_size x patch_size patch of 2D flow.
    """

    def __init__(self, patch_size, dec_embed_dim, output_dim=2, camera_dim=512, num_heads=8, rope=None):
        super().__init__()
        self.patch_size = patch_size
        self.dec_embed_dim = dec_embed_dim
        self.camera_dim = camera_dim

        # Learned embeddings tagging camera features as first (0) / second (1) view.
        self.camera_pos_embed = nn.Parameter(torch.randn(2, 1, camera_dim))
        nn.init.normal_(self.camera_pos_embed, std=0.02)

        # Fuses [camera_i | camera_j | patch] features back to decoder width.
        fused_in = 2 * camera_dim + dec_embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(fused_in, 2 * dec_embed_dim),
            nn.ReLU(),
            nn.Linear(2 * dec_embed_dim, 2 * dec_embed_dim),
            nn.ReLU(),
            nn.Linear(2 * dec_embed_dim, dec_embed_dim),
        )

        # Per-token projection to output_dim * patch_size^2 flow values.
        self.proj = nn.Linear(dec_embed_dim, (output_dim) * self.patch_size ** 2)

    def forward(self, patch_hidden, camera_hidden, pair_indices, img_shape, B, N):
        """
        Args:
            patch_hidden: (B*N, hw, dec_embed_dim) - motion decoder output.
            camera_hidden: (B*N, hw, camera_dim) - camera decoder output.
            pair_indices: (B, S, 2) tensor of per-batch (i, j) image indices.
            img_shape: (H, W).
            B: batch size.
            N: sequence length (number of images).

        Returns:
            (B, S, H, W, output_dim) flow maps, one per (i, j) pair.

        Raises:
            ValueError: if pair_indices is not a 3-dim tensor.
        """
        H, W = img_shape
        hw = patch_hidden.shape[1]

        # (B*N, hw, dim) -> (B, N, hw, dim)
        patch_hidden = patch_hidden.reshape(B, N, hw, self.dec_embed_dim)
        camera_hidden = camera_hidden.reshape(B, N, hw, self.camera_dim)

        if not (isinstance(pair_indices, torch.Tensor) and pair_indices.dim() == 3):
            raise ValueError("Invalid pair_indices type")

        S = pair_indices.shape[1]
        batch_idx = torch.arange(B, device=pair_indices.device).unsqueeze(1).expand(B, S)
        idx_i = pair_indices[:, :, 0]
        idx_j = pair_indices[:, :, 1]

        # Gather per-pair features; tag each camera stream with its view embedding.
        patch_feat = patch_hidden[batch_idx, idx_i]                       # (B, S, hw, dec_dim)
        camera_i = camera_hidden[batch_idx, idx_i] + self.camera_pos_embed[0]
        camera_j = camera_hidden[batch_idx, idx_j] + self.camera_pos_embed[1]

        total_pairs = B * S
        fused = torch.cat([camera_i, camera_j, patch_feat], dim=-1)
        fused = fused.reshape(total_pairs, hw, 2 * self.camera_dim + self.dec_embed_dim)
        fused = self.mlp(fused)

        # Decode each token to its flow patch and pixel-shuffle to full resolution.
        out = self.proj(fused)  # (total_pairs, hw, output_dim * p^2)
        patch_h, patch_w = H // self.patch_size, W // self.patch_size
        out = out.transpose(-1, -2).reshape(total_pairs, -1, patch_h, patch_w)
        out = F.pixel_shuffle(out, self.patch_size)  # (total_pairs, output_dim, H, W)

        return out.permute(0, 2, 3, 1).reshape(B, S, H, W, -1)
|
| 218 |
+
|
| 219 |
+
class DPTFlow2d (nn.Module):
|
| 220 |
+
"""
|
| 221 |
+
Simplified DPT head for flow 2D with only one layer input
|
| 222 |
+
Each token outputs: - 16x16 2D flow
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
def __init__(self, patch_size, dec_embed_dim, output_dim=2, camera_dim=512, rope=None, features=256):
|
| 226 |
+
super().__init__()
|
| 227 |
+
self.patch_size = patch_size
|
| 228 |
+
self.dec_embed_dim = dec_embed_dim
|
| 229 |
+
self.camera_dim = camera_dim
|
| 230 |
+
|
| 231 |
+
# Projection to match camera feature dimension to patch feature dimension
|
| 232 |
+
# self.camera_proj = nn.Linear(camera_dim, dec_embed_dim)
|
| 233 |
+
|
| 234 |
+
# MLP to fuse camera features and patch features
|
| 235 |
+
self.mlp = nn.Sequential(
|
| 236 |
+
nn.Linear(2*camera_dim + dec_embed_dim, 2*dec_embed_dim),
|
| 237 |
+
nn.ReLU(),
|
| 238 |
+
nn.Linear(2*dec_embed_dim, 2*dec_embed_dim),
|
| 239 |
+
nn.ReLU(),
|
| 240 |
+
nn.Linear(2*dec_embed_dim, dec_embed_dim),
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
self.norm = nn.LayerNorm(dec_embed_dim)
|
| 244 |
+
|
| 245 |
+
self.project = nn.Conv2d(dec_embed_dim, features, kernel_size=1, stride=1, padding=0)
|
| 246 |
+
self.refine_low = nn.Sequential(
|
| 247 |
+
nn.Conv2d(features, features, 3, padding=1),
|
| 248 |
+
nn.GELU(),
|
| 249 |
+
nn.Conv2d(features, features, 3, padding=1),
|
| 250 |
+
nn.GELU(),
|
| 251 |
+
)
|
| 252 |
+
self.refine_high = nn.Sequential(
|
| 253 |
+
nn.Conv2d(features, features, 3, padding=1),
|
| 254 |
+
nn.GELU(),
|
| 255 |
+
nn.Conv2d(features, features, 3, padding=1),
|
| 256 |
+
nn.GELU(),
|
| 257 |
+
)
|
| 258 |
+
self.out_head = nn.Sequential(
|
| 259 |
+
nn.Conv2d(features, 64, 3, padding=1),
|
| 260 |
+
nn.GELU(),
|
| 261 |
+
nn.Conv2d(64, output_dim, 1),
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# Final projection to output dimension
|
| 265 |
+
# self.proj = nn.Linear(dec_embed_dim, (output_dim)*self.patch_size**2)
|
| 266 |
+
|
| 267 |
+
def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
| 268 |
+
"""
|
| 269 |
+
Apply positional embedding to tensor x.
|
| 270 |
+
"""
|
| 271 |
+
patch_w = x.shape[-1]
|
| 272 |
+
patch_h = x.shape[-2]
|
| 273 |
+
pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
| 274 |
+
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
|
| 275 |
+
pos_embed = pos_embed * ratio
|
| 276 |
+
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
| 277 |
+
return x + pos_embed
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
    def forward(self, patch_hidden, camera_hidden, pair_indices, img_shape, B, N):
        """
        Fuse per-pair camera and patch features and decode a dense flow map.

        Args:
            patch_hidden: (B*N, hw, dec_embed_dim) - motion decoder output
            camera_hidden: (B*N, hw, camera_dim) - camera decoder output
            pair_indices: LongTensor of shape (B, S, 2); each entry (i, j)
                indexes two images within the same batch element.
            img_shape: (H, W)
            B: batch size
            N: sequence length (number of images)
        Returns:
            flow: (B, S, H, W, C) where C == output_dim of the head
                (reshaped from the (B*S, C, H, W) conv output).
        """
        H, W = img_shape
        hw = patch_hidden.shape[1]

        # Reshape from (B*N, hw, dim) to (B, N, hw, dim)
        patch_hidden = patch_hidden.reshape(B, N, hw, self.dec_embed_dim)
        camera_hidden = camera_hidden.reshape(B, N, hw, self.camera_dim)

        # pair_indices is a (B, S, 2) tensor of per-batch image indices.
        S = pair_indices.shape[1]
        batch_idx = torch.arange(B, device=pair_indices.device).unsqueeze(1).expand(B, S)

        # Extract indices for i and j images: (B, S)
        idx_i = pair_indices[:, :, 0]
        idx_j = pair_indices[:, :, 1]

        # Patch features of the source (i) image: (B, S, hw, dec_embed_dim)
        patch_feat = patch_hidden[batch_idx, idx_i]

        # Camera features of both images: (B, S, hw, camera_dim)
        camera_i = camera_hidden[batch_idx, idx_i]
        camera_j = camera_hidden[batch_idx, idx_j]
        # Concatenate camera features and patch features: (B, S, hw, 2*camera_dim + dec_embed_dim)
        concat_features = torch.cat([camera_i, camera_j, patch_feat], dim=-1)

        # Flatten B and S dimensions
        total_pairs = B * S
        input_features = concat_features.reshape(total_pairs, hw, 2*self.camera_dim + self.dec_embed_dim)

        # Fuse camera and patch features with the MLP
        fused = self.mlp(input_features)  # (T, hw, dec_embed_dim)

        patch_h, patch_w = H // self.patch_size, W // self.patch_size
        assert hw == patch_h * patch_w, (hw, patch_h, patch_w)
        fused = self.norm(fused)
        # Tokens back to a 2-D feature map for the convolutional refinement.
        feat = fused.transpose(1, 2).reshape(total_pairs, self.dec_embed_dim, patch_h, patch_w)  # (T,D,h,w)

        # Low-resolution refinement at patch resolution.
        feat = self.project(feat)  # (T,features,h,w)
        feat = self._apply_pos_embed(feat, W, H)
        feat = self.refine_low(feat)

        # Upsample to full image resolution and refine again.
        feat = F.interpolate(feat, size=(H, W), mode="bilinear", align_corners=True)
        feat = self._apply_pos_embed(feat, W, H)
        feat = self.refine_high(feat)

        flow = self.out_head(feat)  # (T, output_dim, H, W)
        # (T, C, H, W) -> (B, S, H, W, C)
        return flow.permute(0, 2, 3, 1).reshape(B, S, H, W, -1)
|
| 342 |
+
|
| 343 |
+
class ContextTransformerDecoder(nn.Module):
|
| 344 |
+
def __init__(
|
| 345 |
+
self,
|
| 346 |
+
in_dim,
|
| 347 |
+
out_dim,
|
| 348 |
+
dec_embed_dim=512,
|
| 349 |
+
depth=5,
|
| 350 |
+
dec_num_heads=8,
|
| 351 |
+
mlp_ratio=4,
|
| 352 |
+
rope=None,
|
| 353 |
+
):
|
| 354 |
+
super().__init__()
|
| 355 |
+
|
| 356 |
+
self.projects_x = nn.Linear(in_dim, dec_embed_dim)
|
| 357 |
+
self.projects_y = nn.Linear(in_dim, dec_embed_dim)
|
| 358 |
+
|
| 359 |
+
self.blocks = nn.ModuleList([
|
| 360 |
+
CrossBlockRope(
|
| 361 |
+
dim=dec_embed_dim,
|
| 362 |
+
num_heads=dec_num_heads,
|
| 363 |
+
mlp_ratio=mlp_ratio,
|
| 364 |
+
qkv_bias=True,
|
| 365 |
+
proj_bias=True,
|
| 366 |
+
ffn_bias=True,
|
| 367 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 368 |
+
act_layer=nn.GELU,
|
| 369 |
+
ffn_layer=Mlp,
|
| 370 |
+
init_values=None,
|
| 371 |
+
qk_norm=False,
|
| 372 |
+
# attn_class=MemEffAttentionRope,
|
| 373 |
+
# cross_attn_class=MemEffCrossAttentionRope,
|
| 374 |
+
attn_class=FlashAttentionRope,
|
| 375 |
+
cross_attn_class=FlashCrossAttentionRope,
|
| 376 |
+
rope=rope
|
| 377 |
+
) for _ in range(depth)])
|
| 378 |
+
|
| 379 |
+
self.linear_out = nn.Linear(dec_embed_dim, out_dim)
|
| 380 |
+
|
| 381 |
+
def forward(self, hidden, context, xpos=None, ypos=None):
|
| 382 |
+
hidden = self.projects_x(hidden)
|
| 383 |
+
context = self.projects_y(context)
|
| 384 |
+
|
| 385 |
+
for i, blk in enumerate(self.blocks):
|
| 386 |
+
hidden = blk(hidden, context, xpos=xpos, ypos=ypos)
|
| 387 |
+
|
| 388 |
+
out = self.linear_out(hidden)
|
| 389 |
+
|
flow3r/utils/alignment.py
ADDED
|
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import math
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.types
|
| 10 |
+
# import utils3d
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
    """
    Scatter-reduce `src` along `dim` into a tensor of extent `size`, keeping
    the minimum per destination slot, and also recover the source position of
    each minimum (analogous to `torch.min` returning values and indices).

    Destination slots that receive no element hold `inf` with index -1.
    """
    out_shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
    # Minimum value per destination slot; untouched slots stay at +inf.
    values = torch.full(out_shape, float('inf'), dtype=src.dtype, device=src.device)
    values = values.scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
    # Positions in `src` whose value equals the minimum of their destination slot.
    winners = torch.where(src == torch.gather(values, dim=dim, index=index))
    argmins = torch.full(out_shape, -1, dtype=torch.long, device=src.device)
    argmins[(*winners[:dim], index[winners], *winners[dim + 1:])] = winners[dim]
    return torch.return_types.min((values, argmins))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
    """
    Run `fn` over its inputs in chunks along the batch (first) dimension and
    concatenate the results, bounding peak memory usage.

    Tensor arguments (positional and keyword) are split into chunks of
    `chunk_size` along dim 0; non-tensor arguments are replicated for every
    chunk. `fn` may return a single tensor or a tuple of tensors; the chunked
    results are concatenated back along dim 0.

    ### Parameters:
    - `fn`: callable applied to each chunk.
    - `chunk_size`: maximum batch size per call.
    - `*args`, `**kwargs`: forwarded to `fn`; at least one must be a tensor
      (its first dimension defines the batch size).
    """
    batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
    n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
    splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
    # Bug fix: the per-kwarg chunk sequences were wrapped in an extra
    # one-element list, so `v[i]` returned the whole split tuple for i == 0
    # and raised IndexError for i > 0 instead of yielding the i-th chunk.
    splited_kwargs = {k: v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks for k, v in kwargs.items()}
    results = []
    for i in range(n_chunks):
        chunk_args = tuple(arg[i] for arg in splited_args)
        chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
        results.append(fn(*chunk_args, **chunk_kwargs))

    if isinstance(results[0], tuple):
        # `fn` returned multiple tensors: concatenate each position separately.
        return tuple(torch.cat(r, dim=0) for r in zip(*results))
    else:
        return torch.cat(results, dim=0)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _pad_inf(x_: torch.Tensor):
|
| 41 |
+
return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _pad_cumsum(cumsum: torch.Tensor):
|
| 45 |
+
return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
|
| 49 |
+
return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
    """
    If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.

    w_i must be >= 0.

    ### Parameters:
    - `x`: tensor of shape (..., n)
    - `y`: tensor of shape (..., n)
    - `w`: tensor of shape (..., n)
    - `trunc`: optional, float or tensor of shape (..., n) or None

    ### Returns:
    - `a`: tensor of shape (...), differentiable
    - `loss`: tensor of shape (...), value of loss function at `a`, detached
    - `index`: tensor of shape (...), where a = y[idx] / x[idx]
    """
    if trunc is None:
        # The objective is piecewise linear in `a` with breakpoints at
        # a = y_i / x_i; the optimum is taken at one of the breakpoints.
        x, y, w = torch.broadcast_tensors(x, y, w)
        # Flip signs so that x >= 0; this leaves |a*x - y| unchanged.
        sign = torch.sign(x)
        x, y = x * sign, y * sign
        y_div_x = y / x.clamp_min(eps)
        y_div_x, argsort = y_div_x.sort(dim=-1)

        # Subgradient of the objective just past each sorted breakpoint:
        # terms before the breakpoint contribute +w_i*x_i, terms after -w_i*x_i.
        wx = torch.gather(x * w, dim=-1, index=argsort)
        derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
        # First breakpoint where the derivative becomes non-negative is the minimizer.
        search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)

        a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
        index = argsort.gather(dim=-1, index=search).squeeze(-1)
        loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)

    else:
        # Truncated objective: each term saturates at `trunc`, so the function
        # is piecewise linear with extra breakpoints where terms enter/leave
        # saturation, and may have multiple local minima.
        # Reshape to (batch_size, n) for simplicity
        x, y, w = torch.broadcast_tensors(x, y, w)
        batch_shape = x.shape[:-1]
        batch_size = math.prod(batch_shape)
        x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])

        sign = torch.sign(x)
        x, y = x * sign, y * sign
        wx, wy = w * x, w * y
        xyw = torch.stack([x, y, w], dim=-1)  # Stacked for convenient gathering

        # A: breakpoints a = y_i / x_i.
        # B/C: values of `a` where term i hits the truncation bound from below/above.
        y_div_x = A = y / x.clamp_min(eps)
        B = (wy - trunc) / wx.clamp_min(eps)
        C = (wy + trunc) / wx.clamp_min(eps)
        with torch.no_grad():
            # Calculate prefix sums of w*x in the orders of A, B, C
            A, A_argsort = A.sort(dim=-1)
            Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
            A, Q_A = _pad_inf(A), _pad_cumsum(Q_A)  # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.

            B, B_argsort = B.sort(dim=-1)
            Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
            B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)

            C, C_argsort = C.sort(dim=-1)
            Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
            C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)

            # Calculate left and right derivative at each breakpoint via the
            # padded prefix sums (searchsorted finds how many B/C/A thresholds
            # lie below each candidate `a`).
            j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
            left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
            j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
            right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)

            # Find extrema: derivative changes sign from negative to non-negative.
            is_extrema = (left_derivative < 0) & (right_derivative >= 0)
            is_extrema[..., 0] |= ~is_extrema.any(dim=-1)  # In case all derivatives are zero, take the first one as extrema.
            where_extrema_batch, where_extrema_index = torch.where(is_extrema)

            # Calculate objective value at extrema
            extrema_a = y_div_x[where_extrema_batch, where_extrema_index]  # (num_extrema,)
            MAX_ELEMENTS = 4096 ** 2  # Split into small batches to avoid OOM in case there are too many extrema. (~1G)
            SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
            extrema_value = torch.cat([
                _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
                for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
            ])  # (num_extrema,)

            # Find minima among corresponding extrema of each batch element.
            minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value)  # (batch_size,)
            index = where_extrema_index[indices]

        # Recompute `a` differentiably from the winning index (outside no_grad).
        a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
        a = a.reshape(batch_shape)
        loss = minima.reshape(batch_shape)
        index = index.reshape(batch_shape)

    return a, loss, index
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Estimate the scalar scale aligning `depth_src` to `depth_tgt`.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align`

    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    """
    # Only the optimal scale is needed; discard the loss and index from `align`.
    return align(depth_src, depth_tgt, weight, trunc)[0]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `depth_src` to `depth_tgt` with a per-batch affine map (scale and shift).

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc: float` or tensor of shape (..., N) or None

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (...).
    """
    dtype, device = depth_src.dtype, depth_src.device  # NOTE(review): currently unused

    # Flatten batch dimensions for simplicity
    batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
    batch_size = math.prod(batch_shape)
    depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)

    # Here, we take anchors only for non-zero weights.
    # Although the results will be still correct even if anchor points have zero weight,
    # it is wasting computation and may cause instability in some cases, e.g. too many extrema.
    anchors_where_batch, anchors_where_n = torch.where(weight > 0)

    # Stop gradient when solving optimal anchors
    with torch.no_grad():
        # Subtracting a candidate anchor point turns the affine problem into a
        # pure scale problem solvable by `align`; each anchor yields its own
        # optimal scale and loss.
        depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n]  # (anchors)
        depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n]  # (anchors)

        depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None]  # (anchors, n)
        depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None]  # (anchors, n)
        weight_anchored = weight[anchors_where_batch, :]  # (anchors, n)

        scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc)  # (anchors)

        # Keep, per batch element, the anchor with the smallest loss.
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss)  # (batch_size,)

    # Reproduce by indexing for shorter compute graph
    index_1 = anchors_where_n[index_anchor]  # (batch_size,)
    index_2 = index[index_anchor]  # (batch_size,)

    # The affine map is determined by the two support points (the winning
    # anchor and the scale-defining point); recompute it differentiably.
    tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
    shift = tgt_1 - scale * src_1

    scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)

    return scale, shift
|
| 215 |
+
|
| 216 |
+
def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
    """
    Align `depth_src` to `depth_tgt` with given constant weights using IRLS.

    Each iteration solves a weighted linear least-squares fit `y ~ a*x + b`
    via the normal equations, then resets the weights to the inverse absolute
    residuals so later iterations emphasize small-residual points
    (iteratively reweighted least squares).

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N), initial weights
    - `max_iter`: number of IRLS iterations (no early-stopping criterion)
    - `eps`: lower clamp on residuals to avoid division by zero

    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    - `shift: torch.Tensor` of shape (...)
    """
    dtype, device = depth_src.dtype, depth_src.device  # NOTE(review): currently unused

    w = weight
    x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)  # design matrix, shape (..., N, 2)
    y = depth_tgt

    for i in range(max_iter):
        # Normal equations: beta = (X^T W y) (X^T W X)^{-1}; the final
        # transpose is a no-op on the symmetric 2x2 Gram matrix inverse.
        beta = (x.transpose(-1, -2) @ (w * y)) @ (x.transpose(-1, -2) @ (w[..., None] * x)).inverse().transpose(-2, -1)
        # Reweight by inverse residual magnitude (clamped to avoid div-by-zero).
        w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)

    return beta[..., 0], beta[..., 1]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Estimate the scalar scale aligning `points_src` to `points_tgt`.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align`

    ### Returns:
    - `scale: torch.Tensor` of shape (...). Only positive solutions are
      guaranteed; filter out negative scales before using it.
    """
    # Broadcast the per-point weights over xyz, then flatten points so the
    # problem becomes a plain 1-D scale alignment over 3N values.
    flat_weight = weight[..., None].expand_as(points_src).flatten(-2)
    scale, _loss, _index = align(points_src.flatten(-2), points_tgt.flatten(-2), flat_weight, trunc)
    return scale
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
    It is similar to `align_depth_affine` but scale and shift are applied to different dimensions.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align`

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
    """
    dtype, device = points_src.dtype, points_src.device

    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Take anchors (only positively weighted points are candidates)
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)
    with torch.no_grad():
        # Only the z component of each anchor is subtracted, since the shift
        # is restricted to the z axis (x/y anchor components are zero).
        zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
        points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)
        points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)

        # Solve optimal scale and shift for each anchor, chunked to limit peak memory
        MAX_ELEMENTS = 2 ** 20
        scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)

        # Keep, per batch element, the anchor with the smallest loss.
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)

    # Reproduce by indexing for shorter compute graph
    index_2 = index[index_anchor]  # (batch_size,) [0, 3n)
    index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) [0, 3n)

    # Rebuild (0, 0, z) views so index_1 addresses the anchor's z-only point.
    zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
    points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
    tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and a full xyz shift.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align`
    - `max_iters`, `eps`: NOTE(review): currently unused by this implementation

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3)
    """
    dtype, device = points_src.dtype, points_src.device  # NOTE(review): currently unused

    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Take anchors (only positively weighted points are candidates)
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)

    with torch.no_grad():
        # Subtracting each candidate anchor reduces the scale+shift problem to
        # a pure scale problem that `align` can solve.
        points_src_anchor = points_src[anchor_where_batch, anchor_where_n]  # (anchors, 3)
        points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n]  # (anchors, 3)

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)

        # Solve optimal scale and shift for each anchor, chunked to limit memory.
        MAX_ELEMENTS = 2 ** 20
        # NOTE(review): the sibling `align_points_scale_z_shift` chunks by
        # MAX_ELEMENTS // n; the `// 2` here looks inconsistent — confirm the
        # intended chunk size.
        scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // 2, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)

        # Get optimal scale and shift for each batch element
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)

    index_2 = index[index_anchor]  # (batch_size,) [0, 3n)
    index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) [0, 3n)

    # Reproduce scale/shift by indexing the two support points (differentiable).
    src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)

    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a Z-axis shift only.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align`
    - `max_iters`, `eps`: unused; kept for signature compatibility

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3); x and y components are zero.
    """
    # Fitting s in |1 * s - (z_tgt - z_src)| gives the robust z-translation.
    z_delta = points_tgt[..., 2] - points_src[..., 2]
    z_shift, _loss, _index = align(torch.ones_like(z_delta), z_delta, weight, trunc)
    zero = torch.zeros_like(z_shift)
    return torch.stack([zero, zero, z_shift], dim=-1)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with an independent shift per axis (x, y, z).

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align`
    - `max_iters`, `eps`: unused; kept for signature compatibility

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3)
    """
    # Move the axis dimension in front of N so `align` solves the three axes
    # as three independent 1-D shift problems sharing the same point weights.
    ones = torch.ones_like(points_src).swapaxes(-2, -1)
    deltas = (points_tgt - points_src).swapaxes(-2, -1)
    shift, _loss, _index = align(ones, deltas, weight[..., None, :], trunc)
    return shift
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Solve `min sum_i w_i * (a * x_i + b - y_i) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.

    ### Parameters:
    - `x: torch.Tensor` of shape (..., N)
    - `y: torch.Tensor` of shape (..., N)
    - `w: torch.Tensor` of shape (..., N); non-negative weights. If None,
      all points are weighted equally.

    ### Returns:
    - `a: torch.Tensor` of shape (...,)
    - `b: torch.Tensor` of shape (...,)
    """
    w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
    # Weighted least squares: scale each whole row [x_i, 1] and target y_i by
    # sqrt(w_i). Bug fix: the constant column was previously left unweighted
    # (`torch.ones_like(x)`), which solved a different objective whenever the
    # weights were non-uniform.
    A = torch.stack([w_sqrt * x, w_sqrt], dim=-1)
    B = (w_sqrt * y)[..., None]
    a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
    return a, b
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def align_affine_lstsq_z_shift(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Least-squares fit of `y ~= a * x + [0, 0, shift_z]` for 3D point clouds.

    Minimizes `sum_i w_i * ||a * x_i + b - y_i||^2` over a scalar isotropic
    scale `a` and a Z-only translation `b = [0, 0, shift_z]`.

    Parameters:
    - `x: torch.Tensor` of shape (..., N, 3), the source point cloud.
    - `y: torch.Tensor` of shape (..., N, 3), the target point cloud.
    - `w: torch.Tensor` (optional) of shape (..., N), per-point weights;
      uniform weights are used when None.

    Returns:
    - `a: torch.Tensor` of shape (...,), the scalar scaling factor.
    - `b: torch.Tensor` of shape (..., 3), the translation `[0, 0, shift_z]`.
    """
    if x.shape[-1] != 3 or y.shape[-1] != 3:
        raise ValueError("Input tensors x and y must have 3 features in the last dimension (X, Y, Z). "
                         f"Got x shape: {x.shape}, y shape: {y.shape}")
    if x.shape[:-1] != y.shape[:-1]:
        raise ValueError("Input tensors x and y must have matching shapes up to the last dimension. "
                         f"Got x shape: {x.shape}, y shape: {y.shape}")
    if w is not None and w.shape != x.shape[:-1]:
        raise ValueError("Weights w, if provided, must have shape (..., N) matching x and y's point dimensions. "
                         f"Got w shape: {w.shape}, x shape: {x.shape}")

    if w is None:
        sqrt_w = torch.ones(x.shape[:-1], device=x.device, dtype=x.dtype)
    else:
        sqrt_w = w.sqrt()

    # Each point contributes three scalar rows sqrt(w_i)*(a*x_ic + b_c) = sqrt(w_i)*y_ic,
    # stacked coordinate-major (all X rows, then Y, then Z).  Row ordering does
    # not affect the least-squares solution.
    wx = x * sqrt_w[..., None]                                            # (..., N, 3)
    wy = y * sqrt_w[..., None]                                            # (..., N, 3)
    scale_col = torch.cat([wx[..., 0], wx[..., 1], wx[..., 2]], dim=-1)   # (..., 3N)

    # Only the Z rows see the shift, hence zero coefficients on X/Y rows.
    no_shift = torch.zeros_like(sqrt_w)
    shift_col = torch.cat([no_shift, no_shift, sqrt_w], dim=-1)           # (..., 3N)

    design = torch.stack([scale_col, shift_col], dim=-1)                  # (..., 3N, 2)
    rhs = torch.cat([wy[..., 0], wy[..., 1], wy[..., 2]], dim=-1)[..., None]

    solution = torch.linalg.lstsq(design, rhs)[0]                         # (..., 2, 1)
    a = solution[..., 0, 0]
    shift_z = solution[..., 1, 0]

    zero = torch.zeros_like(a)
    return a, torch.stack([zero, zero, shift_z], dim=-1)
|
flow3r/utils/basic.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import math
|
| 4 |
+
import cv2
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torch
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
from plyfile import PlyData, PlyElement
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
def load_images_as_tensor(path='data/truck', interval=1, PIXEL_LIMIT=255000):
    """
    Loads images from a directory or video, resizes them to a uniform size,
    then converts and stacks them into a single [N, 3, H, W] PyTorch tensor.

    Args:
        path: directory of .png/.jpg/.jpeg images, or a .mp4 video file.
        interval: keep every `interval`-th image/frame.
        PIXEL_LIMIT: upper bound on W*H of the resized frames; both sides are
            snapped to multiples of 14 (presumably a ViT patch size of 14 --
            TODO confirm against the downstream model).

    Returns:
        torch.Tensor of shape [N, 3, H, W] with values in [0, 1], or an empty
        tensor if nothing could be loaded/processed.
    """
    sources = []

    # --- 1. Load image paths or video frames ---
    if osp.isdir(path):
        print(f"Loading images from directory: {path}")
        filenames = sorted([x for x in os.listdir(path) if x.lower().endswith(('.png', '.jpg', '.jpeg'))])
        for i in range(0, len(filenames), interval):
            img_path = osp.join(path, filenames[i])
            try:
                sources.append(Image.open(img_path).convert('RGB'))
            except Exception as e:
                # Best effort: a single unreadable file does not abort loading.
                print(f"Could not load image {filenames[i]}: {e}")
    elif path.lower().endswith('.mp4'):
        print(f"Loading frames from video: {path}")
        cap = cv2.VideoCapture(path)
        if not cap.isOpened(): raise IOError(f"Cannot open video file: {path}")
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret: break
            if frame_idx % interval == 0:
                # OpenCV decodes BGR; convert to RGB before wrapping in PIL.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                sources.append(Image.fromarray(rgb_frame))
            frame_idx += 1
        cap.release()
    else:
        raise ValueError(f"Unsupported path. Must be a directory or a .mp4 file: {path}")

    if not sources:
        print("No images found or loaded.")
        return torch.empty(0)

    print(f"Found {len(sources)} images/frames. Processing...")

    # --- 2. Determine a uniform target size for all images based on the first image ---
    # This is necessary to ensure all tensors have the same dimensions for stacking.
    first_img = sources[0]
    W_orig, H_orig = first_img.size
    scale = math.sqrt(PIXEL_LIMIT / (W_orig * H_orig)) if W_orig * H_orig > 0 else 1
    W_target, H_target = W_orig * scale, H_orig * scale
    # k, m: side lengths in units of 14-pixel patches; shrink whichever side
    # overshoots the aspect ratio until the patch grid fits the pixel budget.
    k, m = round(W_target / 14), round(H_target / 14)
    while (k * 14) * (m * 14) > PIXEL_LIMIT:
        if k / m > W_target / H_target: k -= 1
        else: m -= 1
    TARGET_W, TARGET_H = max(1, k) * 14, max(1, m) * 14
    print(f"All images will be resized to a uniform size: ({TARGET_W}, {TARGET_H})")

    # --- 3. Resize images and convert them to tensors in the [0, 1] range ---
    tensor_list = []
    # Define a transform to convert a PIL Image to a CxHxW tensor and normalize to [0,1]
    to_tensor_transform = transforms.ToTensor()

    for img_pil in sources:
        try:
            # Resize to the uniform target size
            resized_img = img_pil.resize((TARGET_W, TARGET_H), Image.Resampling.LANCZOS)
            # Convert to tensor
            img_tensor = to_tensor_transform(resized_img)
            tensor_list.append(img_tensor)
        except Exception as e:
            print(f"Error processing an image: {e}")

    if not tensor_list:
        print("No images were successfully processed.")
        return torch.empty(0)

    # --- 4. Stack the list of tensors into a single [N, C, H, W] batch tensor ---
    return torch.stack(tensor_list, dim=0)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def tensor_to_pil(tensor):
    """
    Convert a PyTorch tensor (or a NumPy array) to a PIL image.

    Torch tensors are detached and moved to CPU first; the actual layout
    handling (squeezing, channel-axis placement) is delegated to
    `array_to_pil`.

    Args:
        tensor (torch.Tensor or np.ndarray): input of shape [C, H, W],
            [H, W, C], or [H, W].

    Returns:
        PIL.Image: the converted PIL image.
    """
    data = tensor.detach().cpu().numpy() if torch.is_tensor(tensor) else tensor
    return array_to_pil(data)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def array_to_pil(array):
    """
    Convert a NumPy array to a PIL image.

    Singleton dimensions are squeezed, and a leading 3-channel axis is moved
    to the end. Values are assumed to lie in [0, 1] and are scaled to uint8.

    Args:
        array (np.ndarray): input of shape [C, H, W], [H, W, C], or [H, W].

    Returns:
        PIL.Image: grayscale ("L") or color ("RGB") image.

    Raises:
        ValueError: if the squeezed array is neither [H, W] nor [H, W, 3].
    """
    array = np.squeeze(array)

    # Channel-first [3, H, W] -> channel-last [H, W, 3].
    if array.ndim == 3 and array.shape[0] == 3:
        array = np.transpose(array, (1, 2, 0))

    def to_uint8(a):
        # [0, 1] float -> [0, 255] uint8 for PIL.
        return (a * 255).astype(np.uint8)

    if array.ndim == 2:  # grayscale [H, W]
        return Image.fromarray(to_uint8(array), mode="L")
    if array.ndim == 3 and array.shape[2] == 3:  # color [H, W, 3]
        return Image.fromarray(to_uint8(array), mode="RGB")
    raise ValueError(f"Unsupported array shape for PIL conversion: {array.shape}")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def rotate_target_dim_to_last_axis(x, target_dim=3):
    """
    Move the trailing-most axis of size `target_dim` to the last position.

    Works on both NumPy arrays and PyTorch tensors. The shape is scanned from
    the end, and the first matching axis found (i.e. the last one in original
    order) is moved to the end; all other axes keep their relative order.

    Args:
        x (np.ndarray or torch.Tensor): input array/tensor.
        target_dim (int): axis size to search for (default 3, e.g. XYZ or RGB).

    Returns:
        Same type as `x` with the matched axis moved to the last position, or
        `x` unchanged if no axis matches or it is already last.
    """
    shape = x.shape
    axis_to_move = -1
    # Iterate backwards to find the first occurrence from the end
    # (which corresponds to the last dimension of size `target_dim`).
    for i in range(len(shape) - 1, -1, -1):
        if shape[i] == target_dim:
            axis_to_move = i
            break

    # Nothing to do when no axis matched or it is already the last one.
    if axis_to_move == -1 or axis_to_move == len(shape) - 1:
        return x

    # Build the new dimension order with the matched axis moved to the end.
    dims_order = list(range(len(shape)))
    dims_order.pop(axis_to_move)
    dims_order.append(axis_to_move)

    # BUG FIX: torch.Tensor.transpose() swaps exactly two dims, so unpacking a
    # full order into it raised for tensors with more than 2 dims; a full
    # reordering needs permute().  numpy's ndarray.transpose(*axes) accepts
    # the complete axis order directly.
    if torch.is_tensor(x):
        return x.permute(*dims_order)
    return x.transpose(*dims_order)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def write_ply(
    xyz,
    rgb=None,
    path='output.ply',
) -> None:
    """
    Write a colored point cloud to a binary PLY file.

    Args:
        xyz: point coordinates; torch.Tensor or np.ndarray with a size-3 axis
            (moved to the last position and flattened to (-1, 3)).
        rgb: optional per-point colors with a size-3 axis; values may be in
            [0, 1] or [0, 255].  When None, a rainbow coloring derived from
            the normalized coordinates is generated instead.
        path: output file path.
    """
    if torch.is_tensor(xyz):
        xyz = xyz.detach().cpu().numpy()

    if torch.is_tensor(rgb):
        rgb = rgb.detach().cpu().numpy()

    # Normalize 0-255 colors down to [0, 1]; scaled back to 255 when writing.
    if rgb is not None and rgb.max() > 1:
        rgb = rgb / 255.

    xyz = rotate_target_dim_to_last_axis(xyz, 3)
    xyz = xyz.reshape(-1, 3)

    if rgb is not None:
        rgb = rotate_target_dim_to_last_axis(rgb, 3)
        rgb = rgb.reshape(-1, 3)

    if rgb is None:
        # No colors given: synthesize a rainbow from the point positions so
        # the geometry is still visually distinguishable.
        min_coord = np.min(xyz, axis=0)
        max_coord = np.max(xyz, axis=0)
        normalized_coord = (xyz - min_coord) / (max_coord - min_coord + 1e-8)

        # Hue is a weighted blend of the normalized coordinates (X dominant);
        # saturation/value are fixed at 0.9/0.8.
        hue = 0.7 * normalized_coord[:,0] + 0.2 * normalized_coord[:,1] + 0.1 * normalized_coord[:,2]
        hsv = np.stack([hue, 0.9*np.ones_like(hue), 0.8*np.ones_like(hue)], axis=1)

        # Standard HSV -> RGB conversion: chroma c, intermediate x, match m.
        c = hsv[:,2:] * hsv[:,1:2]
        x = c * (1 - np.abs( (hsv[:,0:1]*6) % 2 - 1 ))
        m = hsv[:,2:] - c

        # Assign (c, x, 0) permutations per 60-degree hue sector.
        rgb = np.zeros_like(hsv)
        cond = (0 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 1)
        rgb[cond] = np.hstack([c[cond], x[cond], np.zeros_like(x[cond])])
        cond = (1 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 2)
        rgb[cond] = np.hstack([x[cond], c[cond], np.zeros_like(x[cond])])
        cond = (2 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 3)
        rgb[cond] = np.hstack([np.zeros_like(x[cond]), c[cond], x[cond]])
        cond = (3 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 4)
        rgb[cond] = np.hstack([np.zeros_like(x[cond]), x[cond], c[cond]])
        cond = (4 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 5)
        rgb[cond] = np.hstack([x[cond], np.zeros_like(x[cond]), c[cond]])
        cond = (5 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 6)
        rgb[cond] = np.hstack([c[cond], np.zeros_like(x[cond]), x[cond]])
        rgb = (rgb + m)

    # Vertex layout expected by common PLY viewers: position, normal, color.
    dtype = [
        ("x", "f4"),
        ("y", "f4"),
        ("z", "f4"),
        ("nx", "f4"),
        ("ny", "f4"),
        ("nz", "f4"),
        ("red", "u1"),
        ("green", "u1"),
        ("blue", "u1"),
    ]
    # Normals are not estimated here; write zeros to satisfy the layout.
    normals = np.zeros_like(xyz)
    elements = np.empty(xyz.shape[0], dtype=dtype)
    attributes = np.concatenate((xyz, normals, rgb * 255), axis=1)
    elements[:] = list(map(tuple, attributes))
    vertex_element = PlyElement.describe(elements, "vertex")
    ply_data = PlyData([vertex_element])
    ply_data.write(path)
|
flow3r/utils/cropping.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# croppping utilities
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import PIL.Image
import os
# Must be set before cv2 is imported so OpenCV enables its OpenEXR reader.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2  # noqa
import numpy as np  # noqa
# Pillow >= 9.1 moved the resampling constants under Image.Resampling;
# fall back to the old names on older Pillow versions.
try:
    lanczos = PIL.Image.Resampling.LANCZOS
    bicubic = PIL.Image.Resampling.BICUBIC
except AttributeError:
    lanczos = PIL.Image.LANCZOS
    bicubic = PIL.Image.BICUBIC

# NOTE(review): this module lives under flow3r/utils/ but imports from
# `utils.basic`, and the sibling basic.py in this commit does not define
# these intrinsics helpers -- verify the import path actually resolves.
from utils.basic import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics
|
| 20 |
+
|
| 21 |
+
class ImageList:
    """Convenience wrapper applying the same PIL operation to a whole set of images."""

    def __init__(self, images):
        # Accept a single image or any tuple/list/set of images; coerce
        # non-PIL entries (e.g. numpy arrays) into PIL images.
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        self.images = [
            im if isinstance(im, PIL.Image.Image) else PIL.Image.fromarray(im)
            for im in images
        ]

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        # Unwrap a singleton back to a bare image; otherwise return a tuple.
        if len(self.images) > 1:
            return tuple(self.images)
        return self.images[0]

    @property
    def size(self):
        # All wrapped images are required to share the same (W, H).
        sizes = {im.size for im in self.images}
        assert len(sizes) == 1
        return next(iter(sizes))

    def resize(self, *args, **kwargs):
        return ImageList(self._dispatch('resize', *args, **kwargs))

    def crop(self, *args, **kwargs):
        return ImageList(self._dispatch('crop', *args, **kwargs))

    def _dispatch(self, func, *args, **kwargs):
        # Apply the named PIL method to every image with identical arguments.
        return [getattr(im, func)(*args, **kwargs) for im in self.images]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution, force=True, normal=None, far_mask=None):
    """ Jointly rescale a (image, depthmap)
    so that (out_width, out_height) >= output_res

    Parameters:
    - image: PIL.Image (or anything ImageList accepts).
    - depthmap: np.ndarray of shape (H, W) or None; masks also work.
    - camera_intrinsics: np.ndarray 3x3, rescaled consistently with the image.
    - output_resolution: (out_width, out_height) lower bound.
    - force: if False and the image is already smaller, return it unchanged.
    - normal, far_mask: optional per-pixel maps, resized with nearest-neighbor.

    Returns:
    - (image, depthmap, camera_intrinsics, normal, far_mask)
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # can also use this with masks instead of depthmaps
        assert tuple(depthmap.shape[:2]) == image.size[::-1]

    # define output resolution
    assert output_resolution.shape == (2,)
    scale_final = max(output_resolution / image.size) + 1e-8
    if scale_final >= 1 and not force:  # image is already smaller than what is asked
        # BUG FIX: this early exit used to return only 3 values while the
        # normal path returns 5, crashing callers that unpack all five.
        return (image.to_pil(), depthmap, camera_intrinsics, normal, far_mask)
    output_resolution = np.floor(input_resolution * scale_final).astype(int)

    # first rescale the image so that it contains the crop
    image = image.resize(tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic)
    if depthmap is not None:
        # nearest-neighbor: depth values must not be blended across object edges
        depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final,
                              fy=scale_final, interpolation=cv2.INTER_NEAREST)

    if normal is not None:
        normal = cv2.resize(normal, output_resolution, fx=scale_final,
                            fy=scale_final, interpolation=cv2.INTER_NEAREST)
    if far_mask is not None:
        far_mask = cv2.resize(far_mask, output_resolution, fx=scale_final,
                              fy=scale_final, interpolation=cv2.INTER_NEAREST)

    # no offset here; simple rescaling
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)

    return image.to_pil(), depthmap, camera_intrinsics, normal, far_mask
|
| 92 |
+
|
| 93 |
+
def center_crop_image_depthmap(image, depthmap, camera_intrinsics, crop_scale, normal=None, far_mask=None):
    """
    Jointly center-crop an image and its depthmap, and adjust the camera intrinsics accordingly.

    Parameters:
    - image: PIL.Image or similar, the input image.
    - depthmap: np.ndarray or None, the corresponding depth map.
    - camera_intrinsics: np.ndarray, the 3x3 camera intrinsics matrix.
    - crop_scale: float between 0 and 1, the fraction of the image to keep.
    - normal: optional np.ndarray, per-pixel normals cropped alongside the image.
    - far_mask: optional np.ndarray, mask cropped alongside the image.

    Returns:
    - cropped_image: PIL.Image, the center-cropped image.
    - cropped_depthmap: np.ndarray, the center-cropped depth map.
    - adjusted_intrinsics: np.ndarray, the adjusted camera intrinsics matrix.
    - normal, far_mask: cropped versions of the optional inputs (or None).
    """
    # Ensure crop_scale is valid
    assert 0 < crop_scale <= 1, "crop_scale must be between 0 and 1"

    # Convert image to ImageList for consistent processing
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (width, height)
    if depthmap is not None:
        # Ensure depthmap matches the image size
        assert depthmap.shape[:2] == tuple(image.size[::-1]), "Depthmap size must match image size"

    # Compute output resolution after cropping
    output_resolution = np.floor(input_resolution * crop_scale).astype(int)
    # Recompute the effective per-axis scale after integer flooring.
    crop_scale = output_resolution / input_resolution

    # Compute margins (amount to crop from each side)
    margins = input_resolution - output_resolution
    offset = margins / 2  # Since we are center cropping

    # Calculate the crop bounding box
    l, t = offset.astype(int)
    r = l + output_resolution[0]
    b = t + output_resolution[1]
    crop_bbox = (l, t, r, b)

    # Crop the image and depthmap
    image = image.crop(crop_bbox)
    if depthmap is not None:
        depthmap = depthmap[t:b, l:r]
    if normal is not None:
        normal = normal[t:b, l:r]
    if far_mask is not None:
        far_mask = far_mask[t:b, l:r]

    # Adjust the camera intrinsics
    adjusted_intrinsics = camera_intrinsics.copy()

    # Cropping does not change focal lengths; only the principal point moves.
    # adjusted_intrinsics[0, 0] /= crop_scale[0] # fx
    # adjusted_intrinsics[1, 1] /= crop_scale[1] # fy

    # Adjust principal point (cx, cy)
    adjusted_intrinsics[0, 2] -= l  # cx
    adjusted_intrinsics[1, 2] -= t  # cy

    return image.to_pil(), depthmap, adjusted_intrinsics, normal, far_mask
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
    """
    Derive the intrinsics of a scaled-then-cropped view.

    The matrix is converted to COLMAP convention, scaled, shifted by the crop
    offset, and converted back to OpenCV convention.

    Parameters:
    - input_camera_matrix: 3x3 intrinsics (OpenCV convention).
    - input_resolution / output_resolution: (W, H) before/after the crop.
    - scaling: uniform rescaling applied before cropping.
    - offset_factor: fraction of the margins used as crop origin (0.5 = center).
    - offset: explicit crop origin; overrides offset_factor when given.
    """
    # Margins left after scaling; the crop origin is placed inside them.
    margins = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    colmap_K = opencv_to_colmap_intrinsics(input_camera_matrix)
    colmap_K[:2, :] *= scaling
    colmap_K[:2, 2] -= offset
    return colmap_to_opencv_intrinsics(colmap_K)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox, normal=None, far_mask=None):
    """
    Return a crop of the input view.

    Crops the image, depth map, and optional normal/far-mask to `crop_bbox`
    (left, top, right, bottom), and shifts the principal point accordingly.
    """
    left, top, right, bottom = crop_bbox

    cropped = ImageList(image).crop((left, top, right, bottom))
    depthmap = depthmap[top:bottom, left:right]
    normal = None if normal is None else normal[top:bottom, left:right]
    far_mask = None if far_mask is None else far_mask[top:bottom, left:right]

    # Cropping only translates the principal point; focal lengths are unchanged.
    K = camera_intrinsics.copy()
    K[0, 2] -= left
    K[1, 2] -= top

    return cropped.to_pil(), depthmap, K, normal, far_mask
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
    """
    Recover the crop bounding box implied by two intrinsics matrices.

    The top-left corner is the (rounded) difference between the principal
    points; the box extends by the requested output resolution.

    Returns:
        (l, t, r, b) integer bounding box.
    """
    width, height = output_resolution
    principal_shift = input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]
    left, top = np.int32(np.round(principal_shift))
    return (left, top, left + width, top + height)
|
flow3r/utils/debug.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import debugpy
|
| 4 |
+
import socket
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
def update_vscode_launch_file(host: str, port: int):
    """Update the .vscode/launch.json file with the new host and port."""
    launch_file_path = ".vscode/launch.json"
    # Single attach configuration pointing at the debugpy server.
    attach_config = {
        "name": "bash_debug",
        "type": "debugpy",
        "request": "attach",
        "connect": {
            "host": host,
            "port": port
        },
        "justMyCode": False
    }
    payload = {"version": "0.2.0", "configurations": [attach_config]}

    # Ensure the .vscode directory exists
    os.makedirs(".vscode", exist_ok=True)

    # Write the updated configuration to launch.json
    with open(launch_file_path, "w") as f:
        json.dump(payload, f, indent=4)
    print(f"Updated {launch_file_path} with host: {host} and port: {port}")
|
| 35 |
+
|
| 36 |
+
def is_port_in_use(host, port):
    """Return True if a TCP connection to (host, port) succeeds, i.e. something is listening."""
    # connect_ex() returns 0 on success instead of raising.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        return probe.connect_ex((host, port)) == 0
|
| 39 |
+
|
| 40 |
+
def setup_debug(is_main_process=True, max_retries=10, port_range=(10000, 20000)):
    """
    Start a debugpy server on the current SLURM node and block until a
    debugger attaches.

    Only the main process listens; tries up to `max_retries` random ports in
    `port_range`, writing the chosen host/port into .vscode/launch.json so
    VS Code can attach.

    Raises:
        RuntimeError: if no free port could be bound after `max_retries` tries.
    """
    if is_main_process:
        # First hostname in the allocation; assumes SLURM_NODELIST is set
        # (this helper is SLURM-specific).
        host = os.environ['SLURM_NODELIST'].split(',')[0]

        for _ in range(max_retries):
            port = random.randint(*port_range)
            try:
                if is_port_in_use(host, port):
                    print(f"Port {port} is already in use, trying another...")
                    continue

                # Update launch.json so VS Code attaches to the right endpoint.
                update_vscode_launch_file(host, port)

                print("master_addr = ", host)
                debugpy.listen((host, port))
                print(f"Waiting for debugger attach at port {port}...", flush=True)
                # Blocks until a client (VS Code) connects.
                debugpy.wait_for_client()
                print("Debugger attached", flush=True)
                return
            except Exception as e:
                print(f"Failed to bind to port {port}: {e}")

        raise RuntimeError("Could not find a free port for debugpy after several attempts.")
|
flow3r/utils/flow_utils.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import flow_vis
|
| 6 |
+
from .geometry import se3_inverse, homogenize_points
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import wandb
|
| 10 |
+
|
| 11 |
+
def warp_image_with_flow(source_image, source_mask, target_image, flow) -> np.ndarray:
    """
    Pull pixels from `target_image` back into the source frame via per-pixel flow.

    `flow` gives, for each source pixel, the displacement to its match in the
    target image; the target is bilinearly sampled at those positions.

    Args:
        source_image: (H, W, 3) array in [0, 1]; only its spatial size is used.
        source_mask: optional (H, W) non-occlusion mask in the source frame;
            output is zeroed where the mask is <= 0.5.
        target_image: (Ht, Wt, 3) array in [0, 1] to sample from.
        flow: (H, W, 2) displacement field (dx, dy) in pixels.

    Returns:
        (H, W, 3) float array: the target image warped into the source frame.
    """
    assert flow.shape[-1] == 2

    src_h, src_w = source_image.shape[:2]
    tgt_h, tgt_w = target_image.shape[:2]

    # Absolute sampling positions in the target, clamped to its bounds and
    # shifted to pixel centers (+0.5) for align_corners=False sampling.
    grid_x, grid_y = np.meshgrid(np.arange(src_w), np.arange(src_h))
    sample_x = np.clip(grid_x + flow[..., 0], 0, tgt_w - 1) + 0.5
    sample_y = np.clip(grid_y + flow[..., 1], 0, tgt_h - 1) + 0.5

    # Normalize pixel-center positions to [-1, 1] as grid_sample expects.
    sample_x = (sample_x / target_image.shape[1]) * 2 - 1
    sample_y = (sample_y / target_image.shape[0]) * 2 - 1
    sampling_grid = np.stack([sample_x, sample_y], axis=-1)

    target_t = torch.from_numpy(target_image).permute(2, 0, 1)[None, ...].float()
    grid_t = torch.from_numpy(sampling_grid).float()[None, ...]
    sampled = F.grid_sample(target_t, grid_t, mode="bilinear", align_corners=False)
    warped = sampled[0].permute(1, 2, 0).numpy()

    if source_mask is not None:
        # Suppress occluded / invalid source pixels.
        warped = warped * (source_mask > 0.5)[..., None]

    return warped
|
| 60 |
+
|
| 61 |
+
def ndc_to_pixel_coords(coords_ndc: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """
    Convert coordinates from PyTorch3D NDC space to pixel space.

    NDC: x in [1, -1] (+x left), y in [1, -1] (+y up), origin at center.
    Pixel: x in [0, W-1] (+x right), y in [0, H-1] (+y down), top-left origin.

    Args:
        coords_ndc: [..., 2] coordinates (x_ndc, y_ndc).
        H, W: image dimensions.

    Returns:
        [..., 2] coordinates (x_pix, y_pix).
    """
    half_extent_x = max(W - 1, 1) / 2.0
    half_extent_y = max(H - 1, 1) / 2.0
    x_pix = (1.0 - coords_ndc[..., 0]) * half_extent_x
    y_pix = (1.0 - coords_ndc[..., 1]) * half_extent_y
    return torch.stack((x_pix, y_pix), dim=-1)
|
| 81 |
+
|
| 82 |
+
def coords_to_flow(coords: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """
    Convert absolute target coordinates to flow (target minus source grid).

    Args:
        coords: [..., H, W, 2] target pixel positions for each source pixel.
        H, W: image dimensions.

    Returns:
        [..., H, W, 2] displacement field.
    """
    xs = torch.arange(W, device=coords.device, dtype=torch.float32)
    ys = torch.arange(H, device=coords.device, dtype=torch.float32)
    # indexing="xy" yields (H, W) grids with x varying along the last dim.
    gx, gy = torch.meshgrid(xs, ys, indexing="xy")
    base_grid = torch.stack((gx, gy), dim=-1)  # (H, W, 2)
    return coords - base_grid
|
| 107 |
+
|
| 108 |
+
def flow_to_coords(flow: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """
    Convert optical flow to absolute target coordinates (source grid plus flow).

    Args:
        flow: [..., H, W, 2] displacement field.
        H, W: image dimensions.

    Returns:
        [..., H, W, 2] absolute pixel positions in the target image.
    """
    xs = torch.arange(W, device=flow.device, dtype=torch.float32)
    ys = torch.arange(H, device=flow.device, dtype=torch.float32)
    # indexing="xy" yields (H, W) grids with x varying along the last dim.
    gx, gy = torch.meshgrid(xs, ys, indexing="xy")
    base_grid = torch.stack((gx, gy), dim=-1)  # (H, W, 2)
    return flow + base_grid
|
| 133 |
+
|
| 134 |
+
def ndc_pixels_to_flow(flow_ndc: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """
    Convert optical flow from PyTorch3D NDC space back to pixel space.

    Inverts dx_ndc = -2/(W-1) * dx_pix and dy_ndc = -2/(H-1) * dy_pix
    (NDC is +x left / +y up; screen is +x right / +y down).

    Args:
        flow_ndc: [..., 2] flow (dx_ndc, dy_ndc) in NDC.
        H, W: image height and width.

    Returns:
        [..., 2] flow (dx_pix, dy_pix) in screen pixel convention.
    """
    half_extent_x = max(W - 1, 1) / 2.0
    half_extent_y = max(H - 1, 1) / 2.0
    dx_pix = flow_ndc[..., 0] * (-half_extent_x)
    dy_pix = flow_ndc[..., 1] * (-half_extent_y)
    return torch.stack((dx_pix, dy_pix), dim=-1)
|
| 155 |
+
|
| 156 |
+
def coords_pixels_to_ndc(coords_px: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """
    Map pixel coordinates to PyTorch3D NDC.

    Pixel space: x in [0, W-1] (+x right), y in [0, H-1] (+y down), top-left origin.
    NDC space:   x in [1, -1] (+x left),  y in [1, -1] (+y up),   center origin.
    """
    denom_x = max(W - 1, 1)
    denom_y = max(H - 1, 1)
    ndc_x = 1.0 - (coords_px[..., 0] / denom_x) * 2.0
    ndc_y = 1.0 - (coords_px[..., 1] / denom_y) * 2.0
    return torch.stack((ndc_x, ndc_y), dim=-1)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def batched_pi3_motion_flow(world_points, camera_poses, camera_intrinsics, sampled_pairs, image_size):
    """
    Compute batched motion flow from img1 to img2 using world points and camera pose encodings.

    For each sampled (src, tgt) pair, the source frame's world points are
    transformed into the target camera frame and projected with the target
    intrinsics; the flow is the projected position minus the source pixel grid.

    Args:
        world_points: (B, N, H, W, 3) predicted world points per image.
        camera_poses: (B, N, 4, 4) extrinsics for each frame, camera-to-world.
        camera_intrinsics: (B, N, 3, 3) camera intrinsics for each frame.
        sampled_pairs: (B, P, 2) image pairs to compute flow between.
        image_size: indexable whose first element is an (H_img, W_img) pair —
            the full resolution the intrinsics are expressed in. NOTE(review):
            the code reads `image_size[0]` as (height, width), so it is not a
            plain int as a previous docstring claimed; it also assumes the same
            size for every batch item — confirm against the caller.

    Returns:
        flow: (B, P, H, W, 2) motion flows, (x, y) in pixel coordinates
    """
    B, N, H, W, _ = world_points.shape
    P = sampled_pairs.shape[1]
    device = world_points.device

    # Gather source points
    # (B, P)
    src_idx = sampled_pairs[..., 0]
    # (B, P, 1, 1, 1) -> (B, P, H, W, 3)
    # Expand indices to gather along N dimension
    src_idx_exp = src_idx.view(B, P, 1, 1, 1).expand(B, P, H, W, 3)
    src_points = torch.gather(world_points, 1, src_idx_exp)

    # Gather target poses and intrinsics
    # (B, P)
    tgt_idx = sampled_pairs[..., 1]

    tgt_poses = torch.gather(camera_poses, 1, tgt_idx.view(B, P, 1, 1).expand(B, P, 4, 4))
    tgt_intrinsics = torch.gather(camera_intrinsics, 1, tgt_idx.view(B, P, 1, 1).expand(B, P, 3, 3))

    # Transform points to target camera frame (world -> target camera)
    w2c_tgt = se3_inverse(tgt_poses)
    src_points_homo = homogenize_points(src_points)

    # P_cam = T_w2c @ P_world
    # (B, P, 4, 4) @ (B, P, H, W, 4) -> (B, P, H, W, 4)
    pts_cam = torch.einsum('bpij,bphwj->bphwi', w2c_tgt, src_points_homo)[..., :3]

    # Project to image plane
    # P_img = K @ P_cam
    # (B, P, 3, 3) @ (B, P, H, W, 3) -> (B, P, H, W, 3)
    pts_img = torch.einsum('bpij,bphwj->bphwi', tgt_intrinsics, pts_cam)

    # Normalize to pixels; the epsilon guards division by zero depth.
    uv_tgt = pts_img[..., :2] / (pts_img[..., 2:3] + 1e-6)

    # Generate source pixel coordinates
    # print("image_size is: ", image_size)
    H_img, W_img = image_size[0]

    # Scale factors from the (possibly downsampled) point-map grid to the
    # full image resolution the intrinsics operate in.
    scale_h = H_img / H
    scale_w = W_img / W

    y, x = torch.meshgrid(
        torch.arange(H, device=device, dtype=torch.float32),
        torch.arange(W, device=device, dtype=torch.float32),
        indexing='ij'
    )

    # Map grid to image coordinates (assuming center of pixels/patches)
    uv_src = torch.stack([
        (x + 0.5) * scale_w - 0.5,
        (y + 0.5) * scale_h - 0.5
    ], dim=-1)  # (H, W, 2)

    uv_src = uv_src.view(1, 1, H, W, 2).expand(B, P, -1, -1, -1)

    # Flow = projected target position minus source grid position.
    return uv_tgt - uv_src
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def visualize_flow(pred_motion_coords, motion_coords, covis_masks, sampled_pairs, images, pred_pi3_flow, iteration, accelerator, dataset_names):
    """
    Render and save side-by-side flow visualizations for the first two pairs
    of every batch item, and log the resulting grids to wandb.

    Args (shapes inferred from the indexing below — confirm against callers):
        pred_motion_coords: (B, P, H, W, 2) predicted target coords in NDC.
        motion_coords: (B, P, H, W, 2) ground-truth target coords in NDC.
        covis_masks: (B, P, H, W) covisibility masks.
        sampled_pairs: (B, P, 2) frame-index pairs.
        images: (B, N, 3, H, W) images in [0, 1].
        pred_pi3_flow: (B, P, H, W, 2) baseline (pi3) flow in pixel space.
        iteration: training step used for filenames and wandb step.
        accelerator: object with a `.log(dict, step=...)` method (e.g. HF
            Accelerate); used for wandb logging.
        dataset_names: per-batch-item dataset labels.
    """
    # visualize gt images, gt flow, pred flow, flow computed from predicted cameras and points
    # NOTE(review): hardcoded cluster-specific output path — parameterize before reuse.
    path = f"/ocean/projects/cis250013p/zcong/pi3/outputs/flow_vis/{iteration}"
    if not os.path.exists(path):
        os.makedirs(path)

    with torch.no_grad():
        # Get dimensions
        B, num_pairs = sampled_pairs.shape[0], sampled_pairs.shape[1]
        H, W = motion_coords[0, 0].shape[0], motion_coords[0, 0].shape[1]

        # Process all pairs for all batches
        for batch_idx in range(B):
            dataset_name = dataset_names[batch_idx]
            for pair_idx in range(num_pairs):
                # Only the first two pairs are visualized to bound cost.
                if pair_idx > 1: break
                # Get pair indices
                pairs = sampled_pairs[batch_idx, pair_idx].cpu().numpy() # (2,)
                img1 = images[batch_idx, pairs[0]].cpu().numpy()
                img2 = images[batch_idx, pairs[1]].cpu().numpy()

                # Convert ground truth coordinates to flow
                gt_coords_ndc = motion_coords[batch_idx, pair_idx] # NDC coordinates
                gt_coords_pixel = ndc_to_pixel_coords(gt_coords_ndc, H, W) # Convert to pixel coordinates
                flow_tensor = coords_to_flow(gt_coords_pixel, H, W).float().cpu() # (H, W, 2)
                flow = flow_tensor.numpy() # (H, W, 2)

                covis_mask = covis_masks[batch_idx, pair_idx].float().cpu().numpy() # (H, W)
                masked_flow = flow * covis_mask[..., None]

                # Convert predicted coordinates to flow
                pred_coords_ndc = pred_motion_coords[batch_idx, pair_idx] # NDC coordinates
                pred_coords_pixel = ndc_to_pixel_coords(pred_coords_ndc, H, W) # Convert to pixel coordinates
                pred_flow = coords_to_flow(pred_coords_pixel, H, W).float().cpu().numpy() # (H, W, 2)
                masked_pred_flow = pred_flow * covis_mask[..., None]

                # pi3 baseline flow is already in pixel space.
                pi3_flow = pred_pi3_flow[batch_idx, pair_idx].float().cpu().numpy() # (H, W, 2)
                masked_pi3_flow = pi3_flow * covis_mask[..., None]

                # warp img1 to img2
                # first compute gt warpping
                img1_np = np.transpose(img1, (1, 2, 0)) # [H, W, 3]
                img2_np = np.transpose(img2, (1, 2, 0)) # [H, W, 3]
                warped_img_gt = warp_image_with_flow(img1_np, covis_mask, img2_np, flow)
                warped_img_gt = warped_img_gt.clip(0, 1)
                warped_img_gt = Image.fromarray((warped_img_gt * 255).astype(np.uint8))
                # compute prediction warping
                warped_img_pred = warp_image_with_flow(img1_np, covis_mask, img2_np, pred_flow)
                warped_img_pred = warped_img_pred.clip(0, 1)
                warped_img_pred = Image.fromarray((warped_img_pred * 255).astype(np.uint8))
                # compute pi3 warping
                warped_img_pi3 = warp_image_with_flow(img1_np, covis_mask, img2_np, pi3_flow)
                warped_img_pi3 = warped_img_pi3.clip(0, 1)
                warped_img_pi3 = Image.fromarray((warped_img_pi3 * 255).astype(np.uint8))

                # visualize images
                img_array1 = np.transpose(img1, (1, 2, 0))
                img1_pil = Image.fromarray((img_array1 * 255).astype(np.uint8))
                img_array2 = np.transpose(img2, (1, 2, 0))
                img2_pil = Image.fromarray((img_array2 * 255).astype(np.uint8))

                # Calculate AEPE metrics
                # Only calculate on valid covisible pixels
                valid_mask = covis_mask > 0
                if np.sum(valid_mask) > 0:
                    # AEPE for predicted flow vs GT flow
                    flow_diff_pred = np.sqrt(np.sum((masked_pred_flow - masked_flow) ** 2, axis=-1))
                    aepe_pred = np.mean(flow_diff_pred[valid_mask])
                    aepe_5px_pred = np.mean(flow_diff_pred[valid_mask] < 5.0) * 100 # percentage

                    # AEPE for pi3 flow vs GT flow
                    flow_diff_pi3 = np.sqrt(np.sum((masked_pi3_flow - masked_flow) ** 2, axis=-1))
                    aepe_pi3 = np.mean(flow_diff_pi3[valid_mask])
                    aepe_5px_pi3 = np.mean(flow_diff_pi3[valid_mask] < 5.0) * 100 # percentage
                else:
                    # No covisible pixels: mark errors as infinite.
                    aepe_pred = float('inf')
                    aepe_5px_pred = 0.0
                    aepe_pi3 = float('inf')
                    aepe_5px_pi3 = 0.0

                # visualize flow
                flow_vis_image_gt = flow_vis.flow_to_color(masked_flow)
                flow_pil = Image.fromarray(flow_vis_image_gt.astype(np.uint8))
                flow_vis_image_pred = flow_vis.flow_to_color(masked_pred_flow)
                flow_pred_pil = Image.fromarray(flow_vis_image_pred.astype(np.uint8))
                flow_vis_image_pi3 = flow_vis.flow_to_color(masked_pi3_flow)
                flow_pi3_pil = Image.fromarray(flow_vis_image_pi3.astype(np.uint8))

                # Create metrics text
                metrics_text = {
                    'pred_aepe': aepe_pred,
                    'pred_5px_pct': aepe_5px_pred,
                    'pi3_aepe': aepe_pi3,
                    'pi3_5px_pct': aepe_5px_pi3,
                    'covis_ratio': float(np.mean(covis_mask)) * 100,
                    'pairs': pairs,
                    'dataset': dataset_name,
                }

                # Save individual visualization and log to wandb
                save_path = os.path.join(path, f"motion_flow_grid_batch_{batch_idx}_pair_{pair_idx}_imgs_{pairs[0]}_{pairs[1]}_iter_{iteration:08d}.png")
                visualize_motion_grid_nodepth_with_metrics(
                    img1_pil, img2_pil, flow_pil, flow_pred_pil, flow_pi3_pil,
                    warped_img_gt, warped_img_pred, warped_img_pi3,
                    metrics_text,
                    save_path=save_path,
                    pair_idx = pair_idx,
                    step=iteration,
                    log_to_wandb=True, # We'll handle wandb logging separately
                    accelerator=accelerator,
                    dataset_name=dataset_name
                )
|
| 358 |
+
|
| 359 |
+
def visualize_motion_grid_nodepth_with_metrics(img1, img2, flow_pil, flow_pred_pil, flow_pi3_pil, warped_img_gt, warped_img_pred, warped_img_pi3, metrics_text, pair_idx, save_path="motion_flow_grid.png", step=None, log_to_wandb=True, accelerator=None, dataset_name=None):
    """
    Draw a 3x3 matplotlib grid comparing images, flows and warps, save it to
    disk, and optionally log the saved PNG to wandb.

    Layout: row 0 = [image A, image B, metrics text panel];
            row 1 = [GT flow, predicted flow, pi3 flow];
            row 2 = [GT warp, predicted warp, pi3 warp].

    Args:
        img1, img2: PIL images of the frame pair.
        flow_pil, flow_pred_pil, flow_pi3_pil: PIL flow color visualizations.
        warped_img_gt, warped_img_pred, warped_img_pi3: PIL warped images.
        metrics_text: dict with keys 'pairs', 'dataset', 'covis_ratio',
            'pred_aepe', 'pred_5px_pct', 'pi3_aepe', 'pi3_5px_pct'.
        pair_idx: index used in the wandb key name.
        save_path: output PNG path.
        step: wandb logging step.
        log_to_wandb: if True, `accelerator.log` is called — accelerator must
            then be non-None.
        accelerator: object with a `.log(dict, step=...)` method.
        dataset_name: accepted but unused here (naming lives in metrics_text).
    """
    fig, axes = plt.subplots(3, 3, figsize=(20, 16))

    # images
    axes[0, 0].imshow(img1)
    axes[0, 0].set_title(f"Image {metrics_text['pairs'][0]}")
    axes[0, 0].axis("off")

    axes[0, 1].imshow(img2)
    axes[0, 1].set_title(f"Image {metrics_text['pairs'][1]}")
    axes[0, 1].axis("off")

    # Add overall metrics in the third subplot (text-only panel)
    axes[0, 2].text(0.1, 0.9, f"{metrics_text['dataset']} Pair: {metrics_text['pairs'][0]} → {metrics_text['pairs'][1]}",
                    fontsize=14, fontweight='bold', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.8, f"Covis Ratio: {metrics_text['covis_ratio']:.1f}%",
                    fontsize=12, transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.7, "Pred Flow Metrics:",
                    fontsize=12, fontweight='bold', color='blue', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.6, f"AEPE: {metrics_text['pred_aepe']:.3f}",
                    fontsize=11, color='blue', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.5, f"<5px: {metrics_text['pred_5px_pct']:.1f}%",
                    fontsize=11, color='blue', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.4, "Pi3 Flow Metrics:",
                    fontsize=12, fontweight='bold', color='red', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.3, f"AEPE: {metrics_text['pi3_aepe']:.3f}",
                    fontsize=11, color='red', transform=axes[0, 2].transAxes)
    axes[0, 2].text(0.1, 0.2, f"<5px: {metrics_text['pi3_5px_pct']:.1f}%",
                    fontsize=11, color='red', transform=axes[0, 2].transAxes)
    axes[0, 2].set_xlim(0, 1)
    axes[0, 2].set_ylim(0, 1)
    axes[0, 2].axis("off")

    # GT flow and Pred flow
    axes[1, 0].imshow(flow_pil)
    axes[1, 0].set_title("GT Motion Flow")
    axes[1, 0].axis("off")

    axes[1, 1].imshow(flow_pred_pil)
    axes[1, 1].set_title(f"Predicted Flow\nAEPE: {metrics_text['pred_aepe']:.3f}, <5px: {metrics_text['pred_5px_pct']:.1f}%")
    axes[1, 1].axis("off")

    axes[1, 2].imshow(flow_pi3_pil)
    axes[1, 2].set_title(f"Pi3 Flow\nAEPE: {metrics_text['pi3_aepe']:.3f}, <5px: {metrics_text['pi3_5px_pct']:.1f}%")
    axes[1, 2].axis("off")

    # GT warp and Pred warp
    axes[2, 0].imshow(warped_img_gt)
    axes[2, 0].set_title("GT Warped Image")
    axes[2, 0].axis("off")

    axes[2, 1].imshow(warped_img_pred)
    axes[2, 1].set_title("Pred Warped Image")
    axes[2, 1].axis("off")

    axes[2, 2].imshow(warped_img_pi3)
    axes[2, 2].set_title("PI3 Warped Image")
    axes[2, 2].axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    if log_to_wandb:
        # Logs the saved PNG file (not the in-memory figure) to wandb.
        accelerator.log({f"Visualization_{pair_idx}": wandb.Image(save_path)}, step=step)
    plt.close()
|
| 423 |
+
|
| 424 |
+
def calculate_flow_metrics(pred_motion_coords, motion_coords, covis_masks, sampled_pairs, pred_pi3_flow):
    """
    Compute average end-point-error (AEPE) metrics for predicted vs. GT flow.

    Args (shapes inferred from the indexing below — confirm against callers):
        pred_motion_coords: (B, P, H, W, 2) predicted target coords in NDC.
        motion_coords: (B, P, H, W, 2) ground-truth target coords in NDC.
        covis_masks: (B, P, H, W) covisibility masks (nonzero = valid pixel).
        sampled_pairs: (B, P, 2) frame-index pairs (used only for B and P).
        pred_pi3_flow: (B, P, H, W, 2) baseline (pi3) flow in pixel space.

    Returns:
        Tuple of four scalars averaged over all (batch, pair) items:
        (mean pred AEPE, mean % pred errors < 5 px,
         mean pi3 AEPE,  mean % pi3 errors < 5 px).
        NOTE(review): a pair with an empty covisible mask contributes
        inf AEPE, which makes the returned mean AEPE inf as well.
    """
    with torch.no_grad():
        # Get dimensions
        B, num_pairs = sampled_pairs.shape[0], sampled_pairs.shape[1]
        H, W = motion_coords[0, 0].shape[0], motion_coords[0, 0].shape[1]
        aepe_pred, aepe_5px_pred, aepe_pi3, aepe_5px_pi3 = [], [], [], []

        # Process all pairs for all batches
        for batch_idx in range(B):
            for pair_idx in range(num_pairs):
                # Convert ground truth coordinates to flow
                gt_coords_ndc = motion_coords[batch_idx, pair_idx] # NDC coordinates
                gt_coords_pixel = ndc_to_pixel_coords(gt_coords_ndc, H, W) # Convert to pixel coordinates
                flow_tensor = coords_to_flow(gt_coords_pixel, H, W).float().cpu() # (H, W, 2)
                flow = flow_tensor.numpy() # (H, W, 2)

                covis_mask = covis_masks[batch_idx, pair_idx].float().cpu().numpy() # (H, W)
                masked_flow = flow * covis_mask[..., None]

                # Convert predicted coordinates to flow
                pred_coords_ndc = pred_motion_coords[batch_idx, pair_idx] # NDC coordinates
                pred_coords_pixel = ndc_to_pixel_coords(pred_coords_ndc, H, W) # Convert to pixel coordinates
                pred_flow = coords_to_flow(pred_coords_pixel, H, W).float().cpu().numpy() # (H, W, 2)
                masked_pred_flow = pred_flow * covis_mask[..., None]

                # pi3 baseline flow is already in pixel space.
                pi3_flow = pred_pi3_flow[batch_idx, pair_idx].float().cpu().numpy() # (H, W, 2)
                masked_pi3_flow = pi3_flow * covis_mask[..., None]

                # Calculate AEPE metrics
                # Only calculate on valid covisible pixels
                valid_mask = covis_mask > 0
                if np.sum(valid_mask) > 0:
                    # AEPE for predicted flow vs GT flow
                    flow_diff_pred = np.sqrt(np.sum((masked_pred_flow - masked_flow) ** 2, axis=-1))
                    aepe_pred.append(np.mean(flow_diff_pred[valid_mask]))
                    aepe_5px_pred.append(np.mean(flow_diff_pred[valid_mask] < 5.0) * 100) # percentage

                    # AEPE for pi3 flow vs GT flow
                    flow_diff_pi3 = np.sqrt(np.sum((masked_pi3_flow - masked_flow) ** 2, axis=-1))
                    aepe_pi3.append(np.mean(flow_diff_pi3[valid_mask]))
                    aepe_5px_pi3.append(np.mean(flow_diff_pi3[valid_mask] < 5.0) * 100) # percentage
                else:
                    # No covisible pixels for this pair.
                    aepe_pred.append(float('inf'))
                    aepe_5px_pred.append(0.0)
                    aepe_pi3.append(float('inf'))
                    aepe_5px_pi3.append(0.0)

        # print("aepe 5px pi3 is",aepe_5px_pi3)
        return np.mean(aepe_pred), np.mean(aepe_5px_pred), np.mean(aepe_pi3), np.mean(aepe_5px_pi3)
|
flow3r/utils/geometry.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
def se3_inverse(T):
    """
    Invert a (batched) SE(3) transform.

    Works on both torch tensors and numpy arrays of shape [..., 4, 4].
    Uses the closed form inv([R | t]) = [R^T | -R^T t] instead of a general
    matrix inverse.
    """
    if torch.is_tensor(T):
        rot = T[..., :3, :3]
        trans = T[..., :3, 3].unsqueeze(-1)
        rot_t = rot.transpose(-2, -1)
        trans_inv = -torch.matmul(rot_t, trans)
        # Broadcast the constant [0, 0, 0, 1] bottom row over the batch dims.
        bottom = torch.tensor([0, 0, 0, 1], device=T.device, dtype=T.dtype).repeat(*T.shape[:-2], 1, 1)
        top = torch.cat([rot_t, trans_inv], dim=-1)
        return torch.cat([top, bottom], dim=-2)

    rot = T[..., :3, :3]
    trans = T[..., :3, 3, np.newaxis]
    rot_t = np.swapaxes(rot, -2, -1)
    trans_inv = -rot_t @ trans
    bottom = np.zeros((*T.shape[:-2], 1, 4), dtype=T.dtype)
    bottom[..., :, 3] = 1
    top = np.concatenate([rot_t, trans_inv], axis=-1)
    return np.concatenate([top, bottom], axis=-2)
|
| 33 |
+
|
| 34 |
+
def get_pixel(H, W):
    """
    Build homogeneous pixel-center coordinates for an H x W image.

    Returns:
        (3, H*W) array of rows [u + 0.5, v + 0.5, 1], iterating pixels in
        row-major order (u fastest).
    """
    cols = np.tile(np.arange(W), H)      # u index per flattened pixel
    rows = np.repeat(np.arange(H), W)    # v index per flattened pixel
    return np.stack([
        cols + 0.5,
        rows + 0.5,
        np.ones_like(cols)
    ], axis=0)
|
| 46 |
+
|
| 47 |
+
def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, z_far=0, **kw):
    """
    Back-project a depth map to world-frame 3D points.

    Args:
        depthmap: (H, W) depth values.
        camera_intrinsics: 3x3 intrinsics matrix.
        camera_pose: cam2world matrix (at least 3x4); if None, points stay in
            the camera frame.
        z_far: if > 0, pixels with depth >= z_far are marked invalid.

    Returns:
        (H, W, 3) point map in world (or camera) coordinates, and an
        (H, W) boolean validity mask.
    """
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
    if z_far > 0:
        valid_mask = valid_mask & (depthmap < z_far)

    if camera_pose is None:
        # No extrinsics: return camera-frame points directly.
        return X_cam, valid_mask

    R_cam2world = camera_pose[:3, :3]
    t_cam2world = camera_pose[:3, 3]
    # Rotate then translate every pixel's camera-frame point into the world frame.
    X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
    return X_world, valid_mask
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Back-project a depth map to 3D points in the camera frame.

    Assumes a pinhole model with no skew (K[0,1] == K[1,0] == 0).

    Args:
        depthmap: (H, W) array of per-pixel depth values.
        camera_intrinsics: 3x3 intrinsics matrix.
        pseudo_focal: optional (H, W) per-pixel focal length that overrides the
            focal terms of the intrinsics (principal point still comes from K).

    Returns:
        X_cam: (H, W, 3) float32 camera-frame point map.
        valid_mask: (H, W) boolean mask, True where depth > 0.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    if pseudo_focal is None:
        fu = camera_intrinsics[0, 0]
        fv = camera_intrinsics[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu = camera_intrinsics[0, 2]
    cv = camera_intrinsics[1, 2]

    # Back-project each pixel ray scaled by its depth.
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z_cam = depthmap
    x_cam = (u - cu) * z_cam / fu
    y_cam = (v - cv) * z_cam / fv
    X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)

    # Only strictly positive depths are valid. (A previous far-plane cutoff
    # comment here was attached to a no-op statement; far-plane filtering is
    # handled by the z_far argument of the absolute-coordinates wrapper.)
    valid_mask = depthmap > 0.0
    return X_cam, valid_mask
|
| 107 |
+
|
| 108 |
+
def homogenize_points(
    points,
):
    """Append a homogeneous coordinate: (..., 3) xyz -> (..., 4) xyz1."""
    ones = torch.ones_like(points[..., :1])
    return torch.cat((points, ones), dim=-1)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
    """
    Compute the ground-truth warp (and covisibility probability) from view 1
    to view 2 using depth maps, relative pose and intrinsics.

    Args:
        depth1, depth2: (B, H, W) depth maps of the two views.
        T_1to2: (B, 3, 4) or (B, 4, 4) relative transform from view 1 to 2.
        K1, K2: (B, 3, 3) intrinsics of the two views.
        depth_interpolation_mode / relative_depth_error_threshold: forwarded
            to warp_kpts.
        H, W: optional output grid size; defaults to depth1's resolution.

    Returns:
        x2: (B, H, W, 2) warped normalized (-1, 1) coordinates in view 2.
        prob: (B, H, W) float mask, 1 where the warp is valid/covisible.
    """
    if H is None:
        B,H,W = depth1.shape
    else:
        B = depth1.shape[0]
    with torch.no_grad():
        # Build a 3-axis meshgrid over (B, H, W) of grid-cell-center
        # positions in (-1, 1); only the H and W axes are kept below.
        x1_n = torch.meshgrid(
            *[
                torch.linspace(
                    -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
                )
                for n in (B, H, W)
            ],
            indexing = 'ij'
        )
        # Stack (x, y) = (W-axis, H-axis) values into a (B, H*W, 2) kpt grid.
        x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
        # Warp in double precision for numerical stability.
        mask, x2 = warp_kpts(
            x1_n.double(),
            depth1.double(),
            depth2.double(),
            T_1to2.double(),
            K1.double(),
            K2.double(),
            depth_interpolation_mode = depth_interpolation_mode,
            relative_depth_error_threshold = relative_depth_error_threshold,
        )
        prob = mask.float().reshape(B, H, W)
        x2 = x2.reshape(B, H, W, 2)
    return x2, prob
|
| 145 |
+
|
| 146 |
+
@torch.no_grad()
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
    """Warp kpts0 from I0 to I1 with depth, K and Rt.

    Also check covisibility and depth consistency. Depth is considered consistent
    when the relative error is below `relative_depth_error_threshold`.
    # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here

    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
        depth0 (torch.Tensor): [N, H, W],
        depth1 (torch.Tensor): [N, H, W],
        T_0to1 (torch.Tensor): [N, 3, 4],
        K0 (torch.Tensor): [N, 3, 3],
        K1 (torch.Tensor): [N, 3, 3],
        smooth_mask: if truthy, return a soft exp(-err/smooth_mask) consistency
            score instead of a hard boolean mask (incompatible with "combined" mode).
        return_relative_depth_error: if True, return the raw relative depth error
            instead of the validity mask.
        depth_interpolation_mode: grid_sample mode for depth lookup, or "combined"
            to fill bilinear holes with nearest-neighbour results.
    Returns:
        calculable_mask (torch.Tensor): [N, L]
        warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>, normalized in (-1,1)
    """
    (
        n,
        h,
        w,
    ) = depth0.shape
    if depth_interpolation_mode == "combined":
        # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
        if smooth_mask:
            raise NotImplementedError("Combined bilinear and NN warp not implemented")
        # Recurse once per interpolation mode, then merge the two results.
        valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
                    smooth_mask = smooth_mask,
                    return_relative_depth_error = return_relative_depth_error,
                    depth_interpolation_mode = "bilinear",
                    relative_depth_error_threshold = relative_depth_error_threshold)
        valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
                    smooth_mask = smooth_mask,
                    return_relative_depth_error = return_relative_depth_error,
                    depth_interpolation_mode = "nearest-exact",
                    relative_depth_error_threshold = relative_depth_error_threshold)
        # Where bilinear failed but nearest succeeded, take the nearest warp.
        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
        warp = warp_bilinear.clone()
        warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
        valid = valid_bilinear | valid_nearest
        return valid, warp

    # Sample source depth at the (normalized) keypoint locations.
    kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
        :, 0, :, 0
    ]
    # Convert normalized coordinates to pixel coordinates.
    kpts0 = torch.stack(
        (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    # Sample depth, get calculable_mask on depth > 0
    nonzero_mask = kpts0_depth > 0

    # Unproject: lift pixels to 3-D camera-0 coordinates using the sampled depth.
    kpts0_h = (
        torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
        * kpts0_depth[..., None]
    )  # (N, L, 3)
    kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)
    kpts0_cam = kpts0_n

    # Rigid Transform into camera-1 coordinates.
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project back to image-1 pixel coordinates.
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0 = w_kpts0_h[:, :, :2] / (
        w_kpts0_h[:, :, [2]] + 1e-4
    )  # (N, L, 2), +1e-4 to avoid zero depth

    # Covisible Check: warped point must land strictly inside image 1.
    h, w = depth1.shape[1:3]
    covisible_mask = (
        (w_kpts0[:, :, 0] > 0)
        * (w_kpts0[:, :, 0] < w - 1)
        * (w_kpts0[:, :, 1] > 0)
        * (w_kpts0[:, :, 1] < h - 1)
    )
    # Re-normalize the warped coordinates for the second grid_sample and the caller.
    w_kpts0 = torch.stack(
        (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
    )  # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
    # w_kpts0[~covisible_mask, :] = -5 # xd

    # Sample target depth at the warped locations for the consistency check.
    w_kpts0_depth = F.grid_sample(
        depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
    )[:, 0, :, 0]

    # Compare observed depth in image 1 against the depth predicted by the warp.
    relative_depth_error = (
        (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
    ).abs()
    if not smooth_mask:
        consistent_mask = relative_depth_error < relative_depth_error_threshold
    else:
        # Soft consistency score in (0, 1]; smooth_mask acts as a temperature.
        consistent_mask = (-relative_depth_error/smooth_mask).exp()
    valid_mask = nonzero_mask * covisible_mask * consistent_mask
    if return_relative_depth_error:
        return relative_depth_error, w_kpts0
    else:
        return valid_mask, w_kpts0
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a list of 3-D points.

    Trf: 3x3 or 4x4 projection matrix (typically a Homography), optionally batched.
    pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the result is projected on the z=norm plane.

    Returns an array of projected 2d points.
    """
    assert Trf.ndim >= 2
    # Coerce pts to the same array family (and dtype, for torch) as Trf.
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code: batched torch fast path for Trf (B,d,d|d+1) and pts (B,H,W,d)
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # Pure linear transform.
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # Affine transform: rotation/scale part plus translation column.
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        # Generic path: flatten any batch dims so Trf becomes (B, r, c).
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # Homogeneous transform applied to non-homogeneous points.
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            # Fallback: transform column-vector points, then restore row layout.
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        # Perspective divide onto the z=norm plane.
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def inv(mat):
    """Invert a square matrix given as either a torch tensor or a numpy array.

    Raises:
        ValueError: if `mat` is neither a torch.Tensor nor a np.ndarray.
    """
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    elif isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')
|
| 321 |
+
|
| 322 |
+
def opencv_camera_to_plucker(poses, K, H, W):
    """Convert OpenCV-convention camera poses to per-pixel Plücker rays.

    Args:
        poses: (B, 4, 4) or (B, 3, 4) camera-to-world transforms.
        K: (B, 3, 3) intrinsics.
        H, W: image height and width.

    Returns:
        (B, H, W, 6) tensor of [direction (unit), origin x direction] per pixel.
    """
    dev = poses.device
    batch = poses.shape[0]

    # Homogeneous pixel grid, one copy per batch element: (B, H, W, 3).
    grid = torch.from_numpy(get_pixel(H, W).astype(np.float32)).to(dev).T.reshape(H, W, 3)[None].repeat(batch, 1, 1, 1)
    # Back-project to camera-space directions, then rotate into world space.
    cam_dirs = torch.einsum('bij, bhwj -> bhwi', torch.inverse(K), grid)
    world_dirs = torch.einsum('bij, bhwj -> bhwi', poses[..., :3, :3], cam_dirs)

    # Every ray starts at the camera centre.
    origins = poses[..., :3, 3][:, None, None].repeat(1, H, W, 1)

    # Normalize directions, then form the Plücker moment o x d.
    world_dirs = world_dirs / world_dirs.norm(dim=-1, keepdim=True)
    moments = torch.cross(origins, world_dirs, dim=-1)
    return torch.cat([world_dirs, moments], dim=-1)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor:
    """
    Compute the edge mask of a depth map. A pixel is an edge when the depth range
    (local max minus local min) inside its kernel_size neighbourhood exceeds the
    absolute tolerance `atol` and/or the relative tolerance `rtol`.

    Args:
        depth (torch.Tensor): shape (..., height, width), linear depth map
        atol (float): absolute tolerance on the local depth range
        rtol (float): relative tolerance (range divided by centre depth)
        kernel_size (int): side of the square neighbourhood window
        mask (torch.Tensor): optional validity mask; invalid pixels are ignored

    Returns:
        edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
    """
    orig_shape = depth.shape
    flat = depth.reshape(-1, 1, *orig_shape[-2:])
    pad = kernel_size // 2

    if mask is None:
        # Local range = max - min, both obtained via max_pool2d (min via negation).
        local_max = F.max_pool2d(flat, kernel_size, stride=1, padding=pad)
        local_min = -F.max_pool2d(-flat, kernel_size, stride=1, padding=pad)
    else:
        # Masked-out pixels are set to -inf so they never win the max pooling.
        m = mask.reshape(-1, 1, *orig_shape[-2:])
        local_max = F.max_pool2d(torch.where(m, flat, -torch.inf), kernel_size, stride=1, padding=pad)
        local_min = -F.max_pool2d(torch.where(m, -flat, -torch.inf), kernel_size, stride=1, padding=pad)
    diff = local_max - local_min

    edge = torch.zeros_like(flat, dtype=torch.bool)
    if atol is not None:
        edge |= diff > atol
    if rtol is not None:
        edge |= (diff / flat).nan_to_num_() > rtol
    return edge.reshape(*orig_shape)
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.5.1
|
| 2 |
+
torchvision==0.20.1
|
| 3 |
+
numpy==1.26.4
|
| 4 |
+
pillow
|
| 5 |
+
opencv-python
|
| 6 |
+
plyfile
|
| 7 |
+
huggingface_hub
|
| 8 |
+
safetensors
|
| 9 |
+
|
| 10 |
+
# below for gradio
|
| 11 |
+
gradio
|
| 12 |
+
trimesh
|
| 13 |
+
matplotlib
|
| 14 |
+
scipy
|
| 15 |
+
spaces
|