weikaih commited on
Commit
f71ac1d
·
verified ·
1 Parent(s): 2fcc11c

WildDet3D Gradio demo

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. __pycache__/vis3d_glb.cpython-311.pyc +0 -0
  3. app.py +822 -0
  4. assets/demo/intrinsics.npy +3 -0
  5. assets/demo/rgb.png +3 -0
  6. requirements.txt +59 -0
  7. third_party/lingbot_depth/mdm/model/__init__.py +15 -0
  8. third_party/lingbot_depth/mdm/model/dinov2_rgbd/__init__.py +6 -0
  9. third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/__init__.py +4 -0
  10. third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/backbones.py +162 -0
  11. third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/utils.py +39 -0
  12. third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/__init__.py +12 -0
  13. third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/attention.py +100 -0
  14. third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/block.py +259 -0
  15. third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/dino_head.py +58 -0
  16. third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/drop_path.py +34 -0
  17. third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/layer_scale.py +27 -0
  18. third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/mlp.py +40 -0
  19. third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed.py +88 -0
  20. third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed_mlp.py +153 -0
  21. third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/swiglu_ffn.py +72 -0
  22. third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/__init__.py +55 -0
  23. third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/mask_utils.py +137 -0
  24. third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/vision_transformer.py +479 -0
  25. third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/__init__.py +4 -0
  26. third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/cluster.py +95 -0
  27. third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/config.py +72 -0
  28. third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/dtype.py +37 -0
  29. third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/param_groups.py +103 -0
  30. third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/utils.py +95 -0
  31. third_party/lingbot_depth/mdm/model/modules_decoder.py +185 -0
  32. third_party/lingbot_depth/mdm/model/modules_rgbd_encoder.py +152 -0
  33. third_party/lingbot_depth/mdm/model/utils.py +127 -0
  34. third_party/lingbot_depth/mdm/model/v2.py +297 -0
  35. third_party/lingbot_depth/mdm/utils/__init__.py +0 -0
  36. third_party/lingbot_depth/mdm/utils/geo.py +105 -0
  37. third_party/lingbot_depth/mdm/utils/io.py +270 -0
  38. third_party/lingbot_depth/mdm/utils/tools.py +289 -0
  39. third_party/lingbot_depth/mdm/utils/vis.py +65 -0
  40. third_party/lingbot_depth/pyproject.toml +26 -0
  41. third_party/sam3/pyproject.toml +135 -0
  42. third_party/sam3/sam3/__init__.py +9 -0
  43. third_party/sam3/sam3/__pycache__/__init__.cpython-311.pyc +0 -0
  44. third_party/sam3/sam3/__pycache__/logger.cpython-311.pyc +0 -0
  45. third_party/sam3/sam3/__pycache__/model_builder.cpython-311.pyc +0 -0
  46. third_party/sam3/sam3/agent/__init__.py +3 -0
  47. third_party/sam3/sam3/agent/agent_core.py +565 -0
  48. third_party/sam3/sam3/agent/client_llm.py +207 -0
  49. third_party/sam3/sam3/agent/client_sam3.py +139 -0
  50. third_party/sam3/sam3/agent/helpers/__init__.py +3 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/demo/rgb.png filter=lfs diff=lfs merge=lfs -text
37
+ third_party/sam3/sam3/model/__pycache__/video_tracking_multiplex.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
38
+ third_party/sam3/sam3/perflib/tests/assets/masks.tiff filter=lfs diff=lfs merge=lfs -text
__pycache__/vis3d_glb.cpython-311.pyc ADDED
Binary file (28.6 kB). View file
 
app.py ADDED
@@ -0,0 +1,822 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio Web Demo for WildDet3D (5-mode).
2
+
3
+ Supports 5 prompt modes:
4
+ - Text: Enter text like "chair.table" (one-to-many)
5
+ - Visual: Click box on image, text="visual" (one-to-many)
6
+ - Visual+Label: Click box + category label (one-to-many)
7
+ - Geometry: Click box on image, text="geometric" (one-to-one)
8
+ - Geometry+Label: Click box + category label (one-to-one)
9
+ - Point: Click on image to select point
10
+
11
+ Requirements:
12
+ pip install gradio>=5.0.0
13
+
14
+ Usage:
15
+ python demo/huggingface/app.py
16
+
17
+ Then open http://localhost:7860 in browser.
18
+ """
19
+
20
+ import os
21
+ import sys
22
+ from pathlib import Path
23
+
24
+ # Add paths: support both local dev and HuggingFace Space.
25
+ # Local dev: demo/huggingface/app.py -> repo root = ../../
26
+ # HF Space: wilddet3d/ is bundled in the same directory as app.py
27
+ _this_dir = Path(__file__).resolve().parent
28
+ if (_this_dir / "wilddet3d").exists():
29
+ # HuggingFace Space: everything bundled next to app.py
30
+ sys.path.insert(0, str(_this_dir))
31
+ else:
32
+ # Local dev: repo root is two levels up
33
+ repo_root = _this_dir.parent.parent
34
+ sys.path.insert(0, str(repo_root))
35
+
36
+ import spaces
37
+ import gradio as gr
38
+ import numpy as np
39
+ import torch
40
+ import cv2
41
+ from PIL import Image
42
+
43
+ from wilddet3d.inference import build_model, WildDet3DPredictor
44
+ from wilddet3d.preprocessing import preprocess
45
+ from wilddet3d.vis.visualize import draw_3d_boxes
46
+ from vis3d_glb import (
47
+ depth_to_pointcloud, create_scene_glb, create_mesh_scene_glb,
48
+ )
49
+
50
+
51
def draw_points_on_image(image, points, color=(0, 255, 0), radius=8):
    """Render prompt points onto a copy of *image*.

    Args:
        image: numpy array (H, W, 3).
        points: iterable of (x, y, label) tuples; label 1 is positive.
        color: fill color for positive points (green by default).
        radius: circle radius in pixels.

    Returns:
        A new image array with a filled, white-ringed dot per point.
    """
    canvas = image.copy()
    negative_color = (255, 0, 0)
    for px, py, lbl in points:
        center = (int(px), int(py))
        fill = color if lbl == 1 else negative_color
        # Filled dot plus a white outline ring for visibility.
        cv2.circle(canvas, center, radius, fill, -1)
        cv2.circle(canvas, center, radius + 2, (255, 255, 255), 2)
    return canvas
69
+
70
+
71
def draw_box_on_image(image, box, color=(0, 0, 255), thickness=3):
    """Render a rectangle prompt onto a copy of *image*.

    Args:
        image: numpy array (H, W, 3).
        box: [x1, y1, x2, y2] corner coordinates.
        color: rectangle color, (0, 0, 255) by default.
        thickness: rectangle line thickness in pixels.

    Returns:
        A new image array with the rectangle drawn.
    """
    canvas = image.copy()
    x1, y1, x2, y2 = map(int, box)
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, thickness)
    return canvas
87
+
88
+
89
# HuggingFace Model repo for checkpoints; used by _resolve_checkpoint()
# when no local checkpoint is found.
HF_MODEL_REPO = "weikaih/WildDet3D"
HF_CKPT_NAME = "wilddet3d.pt"

# Local checkpoint paths (tried in order before falling back to the Hub)
LOCAL_CHECKPOINTS = [
    "ckpt/wilddet3d.pt",  # release repo layout
]

# Default demo image path and its matching camera intrinsics file
DEFAULT_IMAGE_PATH = "assets/demo/rgb.png"
DEFAULT_INTRINSICS_PATH = "assets/demo/intrinsics.npy"

# Global model (loaded once, lazily, by get_model())
_cached_model = None
104
+
105
+
106
def _resolve_checkpoint():
    """Locate the model checkpoint.

    Prefers the first existing path in LOCAL_CHECKPOINTS; otherwise
    downloads HF_CKPT_NAME from HF_MODEL_REPO (HF_TOKEN env var is
    forwarded when set).

    Returns:
        Filesystem path to the checkpoint file.
    """
    existing = next((p for p in LOCAL_CHECKPOINTS if os.path.exists(p)), None)
    if existing is not None:
        return existing
    # Lazy import: only needed when no local checkpoint is present.
    from huggingface_hub import hf_hub_download
    print(f"Downloading checkpoint from {HF_MODEL_REPO}...")
    return hf_hub_download(
        repo_id=HF_MODEL_REPO,
        filename=HF_CKPT_NAME,
        token=os.environ.get("HF_TOKEN"),
    )
118
+
119
+
120
def get_model():
    """Return the global WildDet3D model, loading it on first use."""
    global _cached_model
    if _cached_model is not None:
        return _cached_model

    ckpt_path = _resolve_checkpoint()
    print(f"Loading WildDet3D model from {ckpt_path}...")
    _cached_model = build_model(
        checkpoint=ckpt_path,
        score_threshold=0.0,
        canonical_rotation=True,
        skip_pretrained=True,
    )
    print("Model loaded!")
    return _cached_model
134
+
135
+
136
def load_default_image():
    """Return the bundled demo image as a numpy array, or None if absent."""
    if not os.path.exists(DEFAULT_IMAGE_PATH):
        return None
    return np.array(Image.open(DEFAULT_IMAGE_PATH))
141
+
142
+
143
def load_default_intrinsics():
    """Return default camera parameters as (fx, fy, cx, cy).

    Reads them from DEFAULT_INTRINSICS_PATH when that file exists;
    otherwise falls back to hard-coded demo values.
    """
    if not os.path.exists(DEFAULT_INTRINSICS_PATH):
        return 518.86, 519.47, 325.58, 253.74
    K = np.load(DEFAULT_INTRINSICS_PATH)
    return float(K[0, 0]), float(K[1, 1]), float(K[0, 2]), float(K[1, 2])
154
+
155
+
156
def format_intrinsics(K):
    """Return a human-readable "fx=..., cy=..." summary of intrinsics.

    Accepts a 3x3 matrix (numpy or torch); a batched (B, 3, 3) input is
    reduced to its first matrix. None yields "Not available".
    """
    if K is None:
        return "Not available"
    if isinstance(K, torch.Tensor):
        K = K.cpu().numpy()
    mat = K[0] if K.ndim == 3 else K
    fx, fy = mat[0, 0], mat[1, 1]
    cx, cy = mat[0, 2], mat[1, 2]
    return f"fx={fx:.2f}, fy={fy:.2f}, cx={cx:.2f}, cy={cy:.2f}"
168
+
169
+
170
def scale_intrinsics_to_original(K, input_hw, original_hw):
    """Rescale camera intrinsics from model-input to original resolution.

    Args:
        K: 3x3 or batched (B, 3, 3) intrinsics, numpy or torch; None
            passes straight through.
        input_hw: (H, W) of the model input.
        original_hw: (H, W) of the original image.

    Returns:
        A rescaled copy of K (the input is never mutated), or None.
    """
    if K is None:
        return None

    # Work on a copy so the caller's matrix is untouched.
    scaled = K.clone() if isinstance(K, torch.Tensor) else K.copy()

    sx = original_hw[1] / input_hw[1]
    sy = original_hw[0] / input_hw[0]

    if scaled.ndim == 3:
        # Batched case: scale every matrix in the batch.
        scaled[:, 0, 0] *= sx
        scaled[:, 1, 1] *= sy
        scaled[:, 0, 2] *= sx
        scaled[:, 1, 2] *= sy
    else:
        scaled[0, 0] *= sx
        scaled[1, 1] *= sy
        scaled[0, 2] *= sx
        scaled[1, 2] *= sy

    return scaled
198
+
199
+
200
def transform_coords_to_input_space(x, y, original_hw, input_hw, padding):
    """Map a point from original-image space into preprocessed-input space.

    The preprocessed input is the original image resized into the padded
    content region, so the mapping is scale-then-offset.

    Args:
        x, y: coordinates in the original image.
        original_hw: (H, W) of the original image.
        input_hw: (H, W) of the preprocessed input (e.g., 1008x1008).
        padding: (pad_left, pad_right, pad_top, pad_bottom) in input pixels.

    Returns:
        (x, y) in preprocessed-input space.
    """
    pad_left, pad_right, pad_top, pad_bottom = padding
    orig_h, orig_w = original_hw

    # Size of the region actually occupied by image content.
    content_w = input_hw[1] - (pad_left + pad_right)
    content_h = input_hw[0] - (pad_top + pad_bottom)

    return (
        x * (content_w / orig_w) + pad_left,
        y * (content_h / orig_h) + pad_top,
    )
225
+
226
+
227
def on_image_select(
    evt: gr.SelectData, image, original_image, state,
    prompt_mode, point_label,
):
    """Handle a click on the input image and visualize the click.

    Args:
        evt: Gradio select event; ``evt.index`` holds the clicked (x, y).
        image: Currently displayed image (may already carry overlays).
        original_image: Clean copy of the uploaded image, if available.
        state: Dict with "points" (list of (x, y, label)) and "box"
            (list of corner (x, y) tuples) accumulated so far.
        prompt_mode: "Text", "Box-to-Multi-Object",
            "Box-to-Single-Object", or "Point".
        point_label: Radio value; any string containing "Positive"
            maps to label 1, everything else to 0.

    Returns:
        (new_state, info_message, visualization_image).
    """
    if image is None:
        return state, "Please upload an image first", None

    x, y = evt.index[0], evt.index[1]
    # Positive/negative flag is only meaningful in Point mode.
    label = 1 if "Positive" in point_label else 0

    # Copy the state so we never mutate Gradio's stored object in place.
    new_state = {
        "points": list(state.get("points", [])),
        "box": list(state.get("box", [])),
    }

    # Always draw on the clean original so overlays do not stack up.
    vis_image = (
        original_image.copy()
        if original_image is not None
        else image.copy()
    )

    if prompt_mode == "Point":
        new_state["points"].append((x, y, label))
        new_state["box"] = []  # point and box prompts are mutually exclusive
        label_str = "+" if label == 1 else "-"
        info = (
            f"Points: {len(new_state['points'])} total. "
            f"Last: ({x}, {y}) [{label_str}]"
        )
        vis_image = draw_points_on_image(vis_image, new_state["points"])

    elif prompt_mode in ("Box-to-Multi-Object", "Box-to-Single-Object"):
        new_state["points"] = []
        box_clicks = list(new_state.get("box", []))
        box_clicks.append((x, y))

        if len(box_clicks) == 1:
            # First corner recorded; wait for the opposite corner.
            new_state["box"] = box_clicks
            info = (
                f"[{prompt_mode}] Corner 1: ({x}, {y}) "
                f"- click again for corner 2"
            )
            vis_image = draw_points_on_image(vis_image, [(x, y, 1)])

        elif len(box_clicks) >= 2:
            # Normalize corners to (min, min, max, max) ordering.
            # NOTE(review): only the first two clicks define the box; a
            # third click is appended but ignored until "Clear Clicks".
            x1, y1 = box_clicks[0]
            x2, y2 = box_clicks[1]
            box = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]
            new_state["box"] = [(box[0], box[1]), (box[2], box[3])]
            info = (
                f"[{prompt_mode}] Box: "
                f"({box[0]}, {box[1]}) -> ({box[2]}, {box[3]})"
            )
            vis_image = draw_box_on_image(vis_image, box)
        else:
            info = f"Box clicks: {box_clicks}"
    else:
        info = "Text mode - just enter text and click Run"

    return new_state, info, vis_image
288
+
289
+
290
def clear_clicks(state, original_image):
    """Drop all recorded clicks and restore the unannotated image.

    Returns:
        (fresh_state, status_message, restored_image) where the image is
        a copy of the original, or None when no original is available.
    """
    restored = None if original_image is None else original_image.copy()
    return {"points": [], "box": []}, "Cleared - ready for new clicks", restored
298
+
299
+
300
@spaces.GPU
def run_wilddet3d(
    image,
    state,
    prompt_mode,
    text_prompt,
    use_label,
    label_text,
    score_thres,
    use_predicted_K,
    fx, fy, cx, cy,
    enable_3d_vis=True,
    remove_edges=True,
    point_density=2,
    use_textured_mesh=True,
):
    """Run WildDet3D inference for the selected prompt mode.

    Args:
        image: Uploaded RGB(A) numpy image, or None.
        state: Click-state dict with "points" and "box" entries.
        prompt_mode: "Text", "Box-to-Multi-Object",
            "Box-to-Single-Object", or "Point".
        text_prompt: "."-separated category names (Text mode only).
        use_label: Whether to attach *label_text* to box/point prompts.
        label_text: Optional category label for box/point prompts.
        score_thres: Minimum detection score kept in visualizations.
        use_predicted_K: If True, let the model predict intrinsics;
            otherwise fx/fy/cx/cy are used.
        fx, fy, cx, cy: Manual camera intrinsics (pixels).
        enable_3d_vis: Build a GLB scene when True.
        remove_edges: Drop depth-discontinuity geometry in the 3D scene.
        point_density: Point-cloud subsample stride (>= 1).
        use_textured_mesh: Textured mesh instead of a point cloud.

    Returns:
        (2D visualization PIL image, intrinsics info string,
        GLB file path or None, depth visualization PIL image or None).
    """
    if image is None:
        return None, "Please upload an image first", None, None

    # Convert RGBA to RGB if needed
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    detector = get_model()

    # Build intrinsics matrix (or None if using the model's prediction)
    if use_predicted_K:
        intrinsics = None
    else:
        intrinsics = np.array([
            [fx, 0, cx],
            [0, fy, cy],
            [0, 0, 1]
        ], dtype=np.float32)

    # Preprocess image (resize/pad to model input; returns dict with
    # "images", "intrinsics", "input_hw", "original_hw", "padding", ...)
    data = preprocess(image.astype(np.float32), intrinsics)

    # Build prompt_text for box/point modes: "visual" triggers
    # one-to-many matching, "geometric" one-to-one.
    if prompt_mode == "Box-to-Multi-Object":
        prefix = "visual"
    elif prompt_mode == "Box-to-Single-Object":
        prefix = "geometric"
    else:
        prefix = "geometric"  # Point mode default

    if prompt_mode != "Text":
        if use_label and label_text and label_text.strip():
            geo_prompt_text = f"{prefix}: {label_text.strip()}"
        else:
            geo_prompt_text = prefix

    # Initialize prompt info for visualization
    prompt_points = None
    prompt_box = None

    # Run based on prompt mode
    if prompt_mode == "Text":
        input_texts = [
            t.strip() for t in text_prompt.split(".") if t.strip()
        ]
        if not input_texts:
            input_texts = ["object"]

        results = detector(
            images=data["images"].to(device),
            intrinsics=data["intrinsics"].to(device)[None],
            input_hw=[data["input_hw"]],
            original_hw=[data["original_hw"]],
            padding=[data["padding"]],
            input_texts=input_texts,
            return_predicted_intrinsics=True,
        )
        (
            boxes, boxes3d, scores, scores_2d, scores_3d,
            class_ids, depth_maps, predicted_K,
        ) = results
        class_id_mapping = {i: t for i, t in enumerate(input_texts)}

    elif prompt_mode in ("Box-to-Multi-Object", "Box-to-Single-Object"):
        box_coords = state.get("box", [])
        if len(box_coords) < 2:
            return (
                None,
                "Please click twice on the image to define a box",
                None,
                None,
            )

        # Map both corners from original-image to model-input space.
        x1_orig, y1_orig = box_coords[0]
        x2_orig, y2_orig = box_coords[1]
        x1, y1 = transform_coords_to_input_space(
            x1_orig, y1_orig,
            data["original_hw"], data["input_hw"], data["padding"],
        )
        x2, y2 = transform_coords_to_input_space(
            x2_orig, y2_orig,
            data["original_hw"], data["input_hw"], data["padding"],
        )
        box_xyxy = [float(x1), float(y1), float(x2), float(y2)]

        # Keep the original-space box for drawing on the output image.
        prompt_box = [x1_orig, y1_orig, x2_orig, y2_orig]

        results = detector(
            images=data["images"].to(device),
            intrinsics=data["intrinsics"].to(device)[None],
            input_hw=[data["input_hw"]],
            original_hw=[data["original_hw"]],
            padding=[data["padding"]],
            input_boxes=[box_xyxy],
            prompt_text=geo_prompt_text,
            return_predicted_intrinsics=True,
        )
        (
            boxes, boxes3d, scores, scores_2d, scores_3d,
            class_ids, depth_maps, predicted_K,
        ) = results
        class_id_mapping = {0: geo_prompt_text}

    elif prompt_mode == "Point":
        points = state.get("points", [])
        if not points:
            return (
                None,
                "Please click on the image to select a point",
                None,
                None,
            )

        # Map every clicked point into model-input space.
        transformed_points = []
        for x_orig, y_orig, lbl in points:
            x, y = transform_coords_to_input_space(
                x_orig, y_orig,
                data["original_hw"], data["input_hw"], data["padding"],
            )
            transformed_points.append((x, y, lbl))

        prompt_points = points

        results = detector(
            images=data["images"].to(device),
            intrinsics=data["intrinsics"].to(device)[None],
            input_hw=[data["input_hw"]],
            original_hw=[data["original_hw"]],
            padding=[data["padding"]],
            input_points=[transformed_points],
            prompt_text=geo_prompt_text,
            return_predicted_intrinsics=True,
        )
        (
            boxes, boxes3d, scores, scores_2d, scores_3d,
            class_ids, depth_maps, predicted_K,
        ) = results
        class_id_mapping = {0: geo_prompt_text}

    else:
        return None, f"Unknown prompt mode: {prompt_mode}", None, None

    # Scale predicted intrinsics to original resolution for display
    predicted_K_scaled = scale_intrinsics_to_original(
        predicted_K,
        input_hw=data["input_hw"],
        original_hw=data["original_hw"],
    )

    # Format intrinsics info (manual values shown too when supplied)
    orig_h, orig_w = data["original_hw"]
    intrinsics_info = f"Image: {orig_w}x{orig_h}\n"
    intrinsics_info += f"Predicted: {format_intrinsics(predicted_K_scaled)}"
    if not use_predicted_K:
        intrinsics_info = f"Image: {orig_w}x{orig_h}\n"
        intrinsics_info += (
            f"Used: fx={fx:.2f}, fy={fy:.2f}, "
            f"cx={cx:.2f}, cy={cy:.2f}\n"
        )
        intrinsics_info += (
            f"Predicted: {format_intrinsics(predicted_K_scaled)}"
        )

    # 2D visualization with prompts overlaid
    img_2d = visualize_results(
        data, boxes3d, scores, scores_2d, scores_3d,
        class_ids, class_id_mapping, score_thres,
        prompt_points=prompt_points, prompt_box=prompt_box,
    )

    # Depth map visualization (padding cropped, inverted TURBO colormap)
    depth_vis_img = None
    if depth_maps is not None and len(depth_maps) > 0:
        depth_np_raw = depth_maps[0].cpu().numpy()
        d = depth_np_raw.squeeze()

        # Crop away the letterbox padding before normalizing.
        pad_l, pad_r, pad_t, pad_b = data["padding"]
        h_end = d.shape[0] - pad_b if pad_b > 0 else d.shape[0]
        w_end = d.shape[1] - pad_r if pad_r > 0 else d.shape[1]
        d_crop = d[pad_t:h_end, pad_l:w_end]

        # Normalize over valid (non-near-zero) depths only.
        d_valid = d_crop[d_crop > 0.01]
        if len(d_valid) > 0:
            d_min, d_max = d_valid.min(), d_valid.max()
            d_norm = np.clip(
                (d_crop - d_min) / (d_max - d_min + 1e-6), 0, 1
            )
            d_norm = (1.0 - d_norm) * 255
            d_norm = d_norm.astype(np.uint8)
            depth_vis_img = cv2.applyColorMap(d_norm, cv2.COLORMAP_TURBO)
            depth_vis_img = cv2.cvtColor(depth_vis_img, cv2.COLOR_BGR2RGB)
            depth_vis_img = Image.fromarray(depth_vis_img)

    # 3D visualization (optional GLB export)
    glb_path = None
    if enable_3d_vis and depth_maps is not None and len(depth_maps) > 0:
        depth_np = depth_maps[0].cpu().numpy()

        # Undo ImageNet normalization to recover an RGB uint8 image.
        input_img = data["images"].cpu()
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        input_img = (input_img * std + mean).clamp(0, 1) * 255
        input_img = (
            input_img.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
        )

        K_for_unproj = data["intrinsics"].cpu().numpy()

        # Keep only boxes above the score threshold, as numpy arrays.
        filtered_boxes3d_np = []
        for i in range(len(boxes3d)):
            mask = scores[i] >= score_thres
            filtered_boxes3d_np.append(boxes3d[i][mask].cpu().numpy())

        glb_path = "/tmp/wilddet3d_scene.glb"

        if use_textured_mesh:
            create_mesh_scene_glb(
                depth_np, input_img, K_for_unproj,
                filtered_boxes3d_np, glb_path,
                max_depth=20.0,
                padding=data["padding"],
                remove_edge=remove_edges,
                edge_rtol=0.04,
            )
        else:
            subsample = max(1, int(point_density))
            points, point_colors = depth_to_pointcloud(
                depth_np, input_img, K_for_unproj,
                max_depth=20.0, subsample=subsample,
                padding=data["padding"],
                remove_edge=remove_edges,
                edge_rtol=0.04,
            )
            create_scene_glb(
                points, point_colors, filtered_boxes3d_np, glb_path
            )

    return img_2d, intrinsics_info, glb_path, depth_vis_img
556
+
557
+
558
def visualize_results(
    data, boxes3d, scores, scores_2d, scores_3d, class_ids,
    class_id_mapping, score_thres,
    prompt_points=None, prompt_box=None,
):
    """Visualize 3D detection results using wilddet3d.vis.draw_3d_boxes.

    Args:
        data: Preprocessing dict; "original_images" and
            "original_intrinsics" are read here.
        boxes3d, scores, scores_2d, scores_3d, class_ids: Per-image
            detection outputs (lists of tensors).
        class_id_mapping: Dict mapping class id -> display name.
        score_thres: Detections with score below this are dropped.
        prompt_points: Optional (x, y, label) prompt points to overlay.
        prompt_box: Optional [x1, y1, x2, y2] prompt box to overlay.

    Returns:
        PIL image of the original frame with prompts and 3D boxes drawn.
    """
    # Filter every per-detection list by the score threshold; missing
    # 2D/3D score lists are replaced with zeros of matching shape.
    filtered_boxes3d = []
    filtered_scores_2d = []
    filtered_scores_3d = []
    filtered_class_ids = []

    for i in range(len(boxes3d)):
        mask = scores[i] >= score_thres
        filtered_boxes3d.append(boxes3d[i][mask])
        if scores_2d is not None:
            filtered_scores_2d.append(scores_2d[i][mask])
        else:
            filtered_scores_2d.append(torch.zeros_like(scores[i][mask]))
        if scores_3d is not None:
            filtered_scores_3d.append(scores_3d[i][mask])
        else:
            filtered_scores_3d.append(torch.zeros_like(scores[i][mask]))
        filtered_class_ids.append(class_ids[i][mask])

    # Get original image and draw prompts on it
    original_img = data["original_images"].cpu().numpy().astype(np.uint8)

    if prompt_points is not None and len(prompt_points) > 0:
        original_img = draw_points_on_image(original_img, prompt_points)

    if prompt_box is not None and len(prompt_box) == 4:
        original_img = draw_box_on_image(original_img, prompt_box)

    # Use wilddet3d's draw_3d_boxes for visualization; boxes are drawn
    # in original resolution, so use the original intrinsics.
    K = data["original_intrinsics"].cpu().numpy()
    if K.ndim == 3:
        K = K[0]

    class_names = [
        class_id_mapping.get(i, str(i))
        for i in range(max(len(class_id_mapping), 1))
    ]

    # Draw 3D boxes with 2D/3D score labels; fall back to the plain
    # image when nothing survives the threshold.
    if len(filtered_boxes3d) > 0 and len(filtered_boxes3d[0]) > 0:
        pil_img = draw_3d_boxes(
            image=original_img,
            boxes3d=filtered_boxes3d[0],
            intrinsics=K,
            scores_2d=filtered_scores_2d[0],
            scores_3d=filtered_scores_3d[0],
            class_ids=filtered_class_ids[0],
            class_names=class_names,
            n_colors=max(len(class_id_mapping), 1),
        )
    else:
        pil_img = Image.fromarray(original_img)

    return pil_img
617
+
618
+
619
# Load default values for the demo widgets
default_fx, default_fy, default_cx, default_cy = load_default_intrinsics()
default_image = load_default_image()

# Build Gradio interface
with gr.Blocks(title="WildDet3D: 3D Detection") as demo:
    gr.Markdown("# WildDet3D: Open-Vocabulary 3D Detection in the Wild")
    gr.Markdown("""
    **How to use:**
    - **Text**: Enter object names (e.g., "chair.table"), click Run
    - **Box-to-Multi-Object**: Draw box -> detect ALL similar objects (one-to-many)
    - **Box-to-Single-Object**: Draw box -> detect ONLY the boxed object (one-to-one)
    - **Point**: Click on object, click Run
    - **+ Label**: Check this to attach a category name (e.g., "chair") to box/point prompts
    """)

    # State for click coordinates and the clean original image
    click_state = gr.State({"points": [], "box": []})
    original_image_state = gr.State(
        default_image.copy() if default_image is not None else None
    )

    with gr.Row():
        # Left column: Input
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image (click for Box/Point mode)",
                type="numpy",
                value=default_image,
                interactive=True,
                sources=["upload", "clipboard"],
            )

            # Prompt settings
            prompt_mode = gr.Radio(
                choices=[
                    "Text",
                    "Box-to-Multi-Object",
                    "Box-to-Single-Object",
                    "Point",
                ],
                value="Text",
                label="Prompt Mode",
            )
            text_prompt = gr.Textbox(
                label="Text Prompt (e.g. 'chair.table')",
                value="chair.table",
                placeholder="Enter object names separated by '.'",
                visible=True,
            )
            use_label = gr.Checkbox(
                label="+ Label (attach category name to box/point prompt)",
                value=False,
                visible=False,
            )
            label_text = gr.Textbox(
                label="Category Label (e.g. 'chair')",
                value="",
                placeholder="Category name for the selected object",
                visible=False,
            )

            # Point label for Point mode
            point_label = gr.Radio(
                choices=["Positive (include)", "Negative (exclude)"],
                value="Positive (include)",
                label="Point Label (for Point mode)",
                visible=False,
            )

            # Click info display
            click_info = gr.Textbox(
                label="Click Info",
                value="Select mode and click on image",
                interactive=False,
            )

            with gr.Row():
                clear_btn = gr.Button("Clear Clicks")
                run_btn = gr.Button("Run Detection", variant="primary")

            # Intrinsics settings (manual values ignored when predicted
            # intrinsics are enabled)
            use_predicted_K = gr.Checkbox(
                label="Use Predicted Intrinsics",
                value=True,
            )
            with gr.Row():
                fx = gr.Number(label="fx", value=default_fx)
                fy = gr.Number(label="fy", value=default_fy)
                cx = gr.Number(label="cx", value=default_cx)
                cy = gr.Number(label="cy", value=default_cy)

            score_thres = gr.Slider(
                minimum=0, maximum=1, value=0.3, step=0.05,
                label="Score Threshold",
            )

            # 3D visualization settings
            gr.Markdown("### 3D Visualization Settings")
            enable_3d_vis = gr.Checkbox(
                label="Enable 3D Point Cloud / Mesh Visualization",
                value=False,
            )
            gr.Markdown(
                "*Notice: the model takes the depth latent to generate "
                "3D boxes, so the boxes and the point cloud might not "
                "exactly match.*"
            )
            use_textured_mesh = gr.Checkbox(
                label="Textured Mesh (otherwise point cloud)",
                value=True,
            )
            remove_edges = gr.Checkbox(
                label="Remove depth edges (cleaner geometry)",
                value=True,
            )
            point_density = gr.Slider(
                minimum=1, maximum=8, value=2, step=1,
                label="Point Subsample (point cloud mode only, 1=dense)",
            )

        # Right column: Output
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="2D Detection Results", type="pil"
            )
            depth_image = gr.Image(label="Depth Map", type="pil")
            output_3d = gr.Model3D(
                label="3D View (Mesh/Point Cloud + Boxes)",
                clear_color=(0.1, 0.1, 0.1, 1.0),
            )
            intrinsics_info = gr.Textbox(
                label="Intrinsics Info", interactive=False
            )

    # Toggle widget visibility based on the selected prompt mode
    def on_mode_change(mode):
        is_text = mode == "Text"
        is_point = mode == "Point"
        return (
            gr.update(visible=is_text),  # text_prompt
            gr.update(visible=not is_text),  # use_label
            gr.update(visible=not is_text),  # label_text
            gr.update(visible=is_point),  # point_label
        )

    prompt_mode.change(
        on_mode_change,
        inputs=[prompt_mode],
        outputs=[text_prompt, use_label, label_text, point_label],
    )

    # Connect events: clicks on the image feed the click state
    input_image.select(
        on_image_select,
        inputs=[
            input_image, original_image_state, click_state,
            prompt_mode, point_label,
        ],
        outputs=[click_state, click_info, input_image],
    )

    clear_btn.click(
        clear_clicks,
        inputs=[click_state, original_image_state],
        outputs=[click_state, click_info, input_image],
    )

    # When a new image is uploaded, save a clean copy and reset clicks
    def on_image_upload(image):
        if image is None:
            return None, {"points": [], "box": []}, "Upload an image"
        return (
            image.copy(),
            {"points": [], "box": []},
            "Image loaded - select mode and click",
        )

    input_image.upload(
        on_image_upload,
        inputs=[input_image],
        outputs=[original_image_state, click_state, click_info],
    )

    run_btn.click(
        run_wilddet3d,
        inputs=[
            input_image, click_state, prompt_mode, text_prompt,
            use_label, label_text, score_thres, use_predicted_K,
            fx, fy, cx, cy,
            enable_3d_vis, remove_edges, point_density, use_textured_mesh,
        ],
        outputs=[output_image, intrinsics_info, output_3d, depth_image],
    )
813
+
814
+
815
if __name__ == "__main__":
    print("=" * 60)
    print("WildDet3D Web Demo")
    print("=" * 60)
    print()
    print("Starting server...")
    # Honor the standard Gradio env var so hosting platforms can
    # override the port; default to 7860.
    port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    demo.launch(share=False, server_name="0.0.0.0", server_port=port)
assets/demo/intrinsics.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5e46d677b736c45d98fda89d2b4b6b8e88028f8c7a5e25df6c9c3e61f6c6fed
3
+ size 164
assets/demo/rgb.png ADDED

Git LFS Details

  • SHA256: 377def0b77a5d11be17fdf3f48466a7dfcde7fff9fd10e1e2f68c57efb18736e
  • Pointer size: 131 Bytes
  • Size of remote file: 449 kB
requirements.txt ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Vis4D (same approach: install dependencies, not vis4d itself)
2
+ absl-py
3
+ appdirs
4
+ cloudpickle
5
+ cython
6
+ devtools
7
+ h5py
8
+ jsonargparse[signatures]
9
+ lightning
10
+ ml_collections==1.1.0
11
+ numpy>=1.21.0,<2.0.0
12
+ opencv-python
13
+ pandas
14
+ pillow
15
+ plyfile
16
+ pycocotools
17
+ pydantic>=2.0
18
+ setuptools
19
+ tensorboard
20
+ termcolor
21
+ terminaltables
22
+ timm>=0.6.0
23
+ torch>=2.0.0
24
+ torchvision>=0.15.1
25
+ tqdm
26
+ utm
27
+ wheel
28
+ scipy
29
+
30
+ # Git utils
31
+ gitdb
32
+ GitPython
33
+
34
+ # WildDet3D
35
+ einops
36
+ fvcore
37
+ nltk
38
+ transformers
39
+ fairscale
40
+ mmengine
41
+ decord
42
+
43
+ # SAM3 dependencies
44
+ ftfy
45
+ regex
46
+ iopath
47
+ omegaconf
48
+ hydra-core
49
+ scikit-image
50
+ scikit-learn
51
+ open_clip_torch
52
+
53
+ # 3D visualization
54
+ pygltflib
55
+ trimesh
56
+ utils3d
57
+
58
+ # Depth estimation
59
+ huggingface_hub
third_party/lingbot_depth/mdm/model/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import importlib
from typing import TYPE_CHECKING, Type

if TYPE_CHECKING:
    from .v2 import MDMModel as MDMModelV2


def import_model_class_by_version(version: str) -> 'Type[MDMModelV2]':
    """Dynamically import and return the MDMModel class for a model version.

    Args:
        version: Version identifier of the model submodule (only 'v2' today).

    Returns:
        The ``MDMModel`` class exported by the ``.{version}`` submodule.

    Raises:
        AssertionError: If ``version`` is not a supported version string.
        ValueError: If the versioned submodule cannot be imported.
    """
    assert version in ['v2'], f'Unsupported model version: {version}'

    try:
        # Resolve the submodule relative to this package (e.g. ``.v2``).
        module = importlib.import_module(f'.{version}', __package__)
    except ModuleNotFoundError as err:
        # Chain the original error so the missing-module cause stays visible.
        raise ValueError(f'Model version "{version}" not found.') from err
    cls = getattr(module, 'MDMModel')
    return cls
third_party/lingbot_depth/mdm/model/dinov2_rgbd/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ __version__ = "0.0.1"
third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/backbones.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from typing import Union
8
+
9
+ import torch
10
+
11
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
12
+
13
+
14
class Weights(Enum):
    # Pretraining-dataset tag used by the released DINOv2 checkpoints.
    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """Instantiate a DINOv2 ViT backbone and optionally load released weights.

    Args:
        arch_name: Factory name looked up in ``..models.vision_transformer``
            (e.g. ``"vit_small"``, ``"vit_large"``, ``"vit_giant2"``).
        img_size: Input image size the ViT is configured for.
        patch_size: ViT patch size.
        init_values: LayerScale initial value.
        ffn_layer: FFN variant name (``"mlp"``, or ``"swiglufused"`` for giant).
        block_chunks: Block chunking factor passed to the ViT (0 disables).
        num_register_tokens: Number of register tokens (``*_reg`` variants).
        interpolate_antialias: Antialias flag for pos-embed interpolation.
        interpolate_offset: Offset used during pos-embed interpolation.
        pretrained: When True, download and load the official checkpoint.
        weights: Released-weights selector (enum member or its string name).
        **kwargs: Extra overrides forwarded to the ViT constructor.

    Raises:
        AssertionError: If ``weights`` is a string not matching a Weights member.
    """
    # Imported lazily to avoid a hard import cycle at module load time.
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)
    # Look up the architecture factory by name, e.g. ``vits.vit_large``.
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        # The register-token count only affects the checkpoint file name,
        # not the directory (which uses the base name).
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model
62
+
63
+
64
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl16(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    ViT-Large backbone intended for a 16-pixel patch size, built WITHOUT
    pretrained weights.

    NOTE(review): despite accepting ``pretrained``, this always passes
    ``pretrained=False`` (there is no released L/16 checkpoint), and it does
    not force ``patch_size=16`` itself — the caller is expected to supply
    ``img_size``/``patch_size`` via kwargs. Confirm this is intentional.
    """
    # kwargs.update({'img_size': 224, 'patch_size': 16, })
    return _make_dinov2_model(arch_name="vit_large", pretrained=False, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",  # giant models use the fused SwiGLU FFN
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15
+
16
+
17
+ def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
18
+ compact_arch_name = arch_name.replace("_", "")[:4]
19
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
20
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
21
+
22
+
23
class CenterPadding(nn.Module):
    """Zero-pad the trailing (spatial) dims of a tensor up to a multiple.

    The slack for each dimension is split as evenly as possible between the
    two sides, with the extra element going to the right/bottom side.
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        # Round up to the next multiple, then center the content.
        target = math.ceil(size / self.multiple) * self.multiple
        slack = target - size
        left = slack // 2
        return left, slack - left

    @torch.inference_mode()
    def forward(self, x):
        # F.pad wants pads ordered from the last dim backwards, two values
        # per dim; x.shape[:1:-1] walks the trailing dims in exactly that
        # order (W, H, ...) while skipping the leading batch dim.
        pad_spec = []
        for extent in x.shape[:1:-1]:
            pad_spec.extend(self._get_pad(extent))
        return F.pad(x, pad_spec)
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
12
+ from .patch_embed_mlp import PatchEmbed as PatchEmbedMLP
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/attention.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ import torch.nn.functional as F
15
+ from torch import Tensor
16
+ from torch import nn
17
+
18
+
19
+ logger = logging.getLogger("dinov2")
20
+
21
+
22
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
23
+ try:
24
+ if XFORMERS_ENABLED:
25
+ from xformers.ops import memory_efficient_attention, unbind
26
+
27
+ XFORMERS_AVAILABLE = True
28
+ # warnings.warn("xFormers is available (Attention)")
29
+ else:
30
+ # warnings.warn("xFormers is disabled (Attention)")
31
+ raise ImportError
32
+ except ImportError:
33
+ XFORMERS_AVAILABLE = False
34
+ # warnings.warn("xFormers is not available (Attention)")
35
+
36
+
37
class Attention(nn.Module):
    """Multi-head self-attention backed by PyTorch's fused SDPA kernel."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Kept for API compatibility; F.scaled_dot_product_attention applies
        # its own 1/sqrt(head_dim) scaling internally.
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # Single projection to stacked q/k/v:
        # (B, N, 3C) -> (B, N, 3, H, hd) -> (3, B, H, N, hd)
        stacked = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = stacked.unbind(0)  # each (B, H, N, hd)

        # attn_bias is forwarded as SDPA's additive attention mask.
        out = F.scaled_dot_product_attention(q, k, v, attn_bias)
        out = out.permute(0, 2, 1, 3).reshape(batch, tokens, channels)

        return self.proj_drop(self.proj(out))
82
+
83
class MemEffAttention(Attention):
    """Attention variant that uses xFormers' memory-efficient kernel.

    Falls back to the parent SDPA implementation when xFormers is not
    importable (module-level ``XFORMERS_AVAILABLE`` flag). Block-diagonal
    ``attn_bias`` values (nested-tensor mode) require xFormers and raise
    otherwise.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # NOTE: attn_bias is guaranteed None here, so dropping it is safe.
            return super().forward(x)

        B, N, C = x.shape
        # xFormers consumes the (B, N, 3, H, C // H) layout directly.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/block.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
28
+ try:
29
+ if XFORMERS_ENABLED:
30
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
31
+
32
+ XFORMERS_AVAILABLE = True
33
+ # warnings.warn("xFormers is available (Block)")
34
+ else:
35
+ # warnings.warn("xFormers is disabled (Block)")
36
+ raise ImportError
37
+ except ImportError:
38
+ XFORMERS_AVAILABLE = False
39
+ # warnings.warn("xFormers is not available (Block)")
40
+
41
+
42
class Block(nn.Module):
    """Pre-norm transformer block: attention + FFN residual branches, each
    with optional LayerScale and stochastic depth (drop path).

    When training with a drop-path rate above 0.1, whole samples are dropped
    from the residual computation (batch-subset trick) instead of masking,
    which amortizes the cost of the skipped work.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # LayerScale only when init_values is truthy; otherwise a no-op.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Fixed (was drop_path1 with an upstream FIXME): the FFN branch
            # gets its own DropPath module. Both are built identically, so
            # the sampling distribution is unchanged.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
114
+
115
+
116
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Batch-wise stochastic depth.

    Runs ``residual_func`` only on a random subset of the batch and
    scatter-adds the result back, upscaled so the expected output matches
    the dense computation.
    """
    batch, _, _ = x.shape
    # At least one sample always takes the residual branch.
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(batch, device=x.device)[:keep]

    # Residual branch is evaluated only on the sampled rows.
    branch_out = residual_func(x[kept_idx]).flatten(1)

    # Upweight the sparse residual to keep E[output] unchanged.
    scale = batch / keep
    merged = torch.index_add(
        x.flatten(1), 0, kept_idx, branch_out.to(dtype=x.dtype), alpha=scale
    )
    return merged.view_as(x)
138
+
139
+
140
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Sample the batch indices to keep and the compensating residual scale.

    Returns (kept_indices, batch / kept_count); the scale keeps the expected
    residual magnitude equal to the dense computation.
    """
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)  # never drop everything
    kept_idx = torch.randperm(batch, device=x.device)[:keep]
    return kept_idx, batch / keep
146
+
147
+
148
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Scatter-add a scaled residual into the rows selected by ``brange``.

    Without ``scaling_vector`` this is a plain flattened ``index_add``;
    with it, the fused xFormers ``scaled_index_add`` applies the
    per-channel LayerScale gamma at the same time.
    """
    if scaling_vector is None:
        flat = x.flatten(1)
        return torch.index_add(
            flat,
            0,
            brange,
            residual.flatten(1).to(dtype=x.dtype),
            alpha=residual_scale_factor,
        )
    # Fused path — requires xFormers to be importable.
    return scaled_index_add(
        x,
        brange,
        residual.to(dtype=x.dtype),
        scaling=scaling_vector,
        alpha=residual_scale_factor,
    )
158
+
159
+
160
# Memoizes one BlockDiagonalMask per ((batch, seqlen), ...) signature.
# NOTE(review): grows without bound across distinct shape combinations —
# acceptable for a fixed training setup, but confirm for variable shapes.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    # One batch size per tensor: the brange subset if given, else the full batch.
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # Build a block-diagonal mask so sequences of different lengths can
        # be packed into one attention call without cross-sample attention.
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes  # remembered for .split() later
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Fused xFormers select+concat over flattened tokens -> (1, total, D).
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        # Concatenate all samples along the token dim as a single pseudo-batch.
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
185
+
186
+
187
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Stochastic-depth residual update for a list of packed tensors.

    Samples a per-tensor subset of the batch, runs one packed attention/FFN
    call over the concatenation, and scatter-adds each (scaled) residual
    back into its source tensor.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    # back into per-tensor chunks using the mask's remembered batch sizes.
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs
208
+
209
+
210
class NestedTensorBlock(Block):
    """Transformer block that can also process a list of tensors packed into
    a single attention call via an xFormers block-diagonal mask."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        # Packed attn_bias handling only exists in the xFormers attention.
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            # LayerScale gammas are folded into the fused scatter-add.
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Fixed: gate on ls2 (was ls1). ls1/ls2 are always created
                # together with the same type, so behavior is unchanged,
                # but the original check read the wrong attribute.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            # Pack all tensors into one pseudo-batch, run both residual
            # branches, then split back into the original list layout.
            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            # Plain tensor: behave exactly like the parent Block.
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
class DINOHead(nn.Module):
    """DINO projection head: an MLP bottleneck followed by an L2-normalize
    and a weight-normalized prototype layer."""

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        # Initialize the MLP before adding last_layer, which has its own init.
        self.apply(self._init_weights)
        # Weight-normalized prototypes with the norm ("weight_g") pinned to 1.
        # NOTE(review): torch.nn.utils.weight_norm is deprecated in recent
        # PyTorch in favor of nn.utils.parametrizations.weight_norm —
        # confirm the target torch version before migrating.
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        # Truncated-normal weights and zero biases for every Linear layer.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        # Larger eps in fp16 avoids dividing by an underflowed norm.
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x
42
+
43
+
44
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45
+ if nlayers == 1:
46
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47
+ else:
48
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49
+ if use_bn:
50
+ layers.append(nn.BatchNorm1d(hidden_dim))
51
+ layers.append(nn.GELU())
52
+ for _ in range(nlayers - 2):
53
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54
+ if use_bn:
55
+ layers.append(nn.BatchNorm1d(hidden_dim))
56
+ layers.append(nn.GELU())
57
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58
+ return nn.Sequential(*layers)
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Zero out whole samples with probability ``drop_prob`` (stochastic
    depth), rescaling survivors by 1/keep_prob to preserve the expectation.

    A no-op when ``drop_prob == 0`` or when not training.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all non-batch dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # Rescale so the expected activation magnitude is unchanged.
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0):
        # Default changed from None to 0.0: with None, a training-mode
        # forward crashed on ``1 - None``; 0.0 is the equivalent no-op and
        # keeps eval-mode behavior identical.
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch (CaiT-style)."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # One learnable gain per channel, initialized to a small constant.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> act -> drop -> fc2 -> drop."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Hidden and output widths default to the input width.
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        # The same Dropout module is applied after both linear layers.
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
def make_2tuple(x):
    """Normalize an int or 2-tuple into an (h, w) 2-tuple."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple)
    assert len(x) == 2
    return x


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        img_hw = make_2tuple(img_size)
        patch_hw = make_2tuple(patch_size)
        grid_hw = (img_hw[0] // patch_hw[0], img_hw[1] // patch_hw[1])

        self.img_size = img_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid_hw
        self.num_patches = grid_hw[0] * grid_hw[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Non-overlapping conv == per-patch linear projection.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # (B, D, H/pH, W/pW)
        grid_h, grid_w = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        x = self.norm(x)
        if not self.flatten_embedding:
            # Restore the spatial grid layout instead of a flat token list.
            x = x.reshape(-1, grid_h, grid_w, self.embed_dim)  # (B, H', W', D)
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        total = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            # NOTE(review): self.norm is never None (Identity fallback), so
            # this norm term is always counted — confirm intent.
            total += Ho * Wo * self.embed_dim
        return total
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed_mlp.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ import torch
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
def make_2tuple(x):
    """Return *x* as an (a, b) pair: 2-tuples pass through, ints are duplicated."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x
24
+
25
class PixelUnshuffle(nn.Module):
    """Space-to-depth rearrangement: (..., C, H, W) -> (..., C*r*r, H/r, W/r).

    Behaves like ``F.pixel_unshuffle`` but additionally supports empty tensors
    by computing the output shape directly with ``view``.
    """

    def __init__(self, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input):
        r = self.downscale_factor
        if input.numel() != 0:
            return F.pixel_unshuffle(input, r)
        # Empty-tensor fallback (the stock torch op does not handle this case):
        # just reshape to the expected output dimensions.
        C, H, W = input.shape[-3:]
        assert H and W and H % r == W % r == 0
        return input.view(*input.shape[:-3], C * r * r, H // r, W // r)
38
+
39
class Permute(nn.Module):
    """``Tensor.permute`` wrapped as a module, so it can live inside ``nn.Sequential``."""

    dims: tuple[int, ...]

    def __init__(self, dims: tuple[int, ...]) -> None:
        super().__init__()
        self.dims = tuple(dims)

    def __repr__(self):
        return "Permute" + repr(self.dims)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.permute(input, self.dims)
50
+
51
from itertools import repeat
import collections.abc


def _ntuple(n):
    """Build a converter that broadcasts a scalar to an n-tuple.

    Non-string iterables are returned unchanged (timm convention).
    """

    def parse(value):
        if isinstance(value, collections.abc.Iterable) and not isinstance(value, str):
            return value
        return tuple(repeat(value, n))

    return parse


# Pair-broadcast helper used by Mlp for per-layer bias/dropout settings.
to_2tuple = _ntuple(2)


class Mlp(nn.Module):
    """Two-layer feed-forward block (Linear -> act -> drop -> Linear -> drop),
    as used in Vision Transformer, MLP-Mixer and related networks.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
        super().__init__()
        # Hidden/output widths default to the input width; bias and drop may be
        # given per-layer as a pair or broadcast from a single value.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias_pair[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        x = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(x))
83
+
84
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    MLP variant: instead of a strided conv, patches are formed by
    PixelUnshuffle (space-to-depth) and projected by a per-patch MLP.

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        # Normalize int-or-tuple arguments to (H, W) pairs.
        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # When False, forward() returns a (B, H', W', C) grid instead of flat (B, N, C).
        self.flatten_embedding = flatten_embedding

        # Space-to-depth, then a per-patch MLP (hidden width 4*embed_dim) over
        # the flattened pixels of each patch.
        # NOTE(review): uses the raw `patch_size` here rather than `patch_HW`,
        # so only square integer patch sizes work — confirm callers never pass a tuple.
        self.proj = nn.Sequential(
            PixelUnshuffle(patch_size),
            Permute((0,2,3,1)),
            Mlp(in_chans * patch_size * patch_size, 4*embed_dim, embed_dim),
            Permute((0,3,1,2)),
        )

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Embed images (B, C, H, W) into patch tokens (B, N, D), or a (B, H', W', D) grid when flatten_embedding is False."""
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)  # patch-grid resolution after projection
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Approximate FLOPs of one forward pass.

        NOTE(review): this is the conv-style formula inherited from the conv
        PatchEmbed; the actual MLP projection above (hidden = 4*embed_dim)
        performs more work, so this figure underestimates — confirm before
        relying on it.
        """
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # NOTE(review): self.norm defaults to nn.Identity(), never None, so
        # this branch always fires even for a no-op norm.
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network.

    A single fused projection produces both the gate and the value
    (2 * hidden units); the SiLU-gated product is then projected back out.
    ``act_layer`` and ``drop`` are accepted only for interface parity with
    ``Mlp`` and are ignored.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden, bias=bias)
        self.w3 = nn.Linear(hidden, out, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
35
+
36
+
37
# Optional xFormers acceleration: opt out by setting the XFORMERS_DISABLED
# environment variable. When xFormers is missing or disabled, fall back to
# the pure-PyTorch SwiGLUFFN above under the same `SwiGLU` name.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (SwiGLU)")
    else:
        # warnings.warn("xFormers is disabled (SwiGLU)")
        raise ImportError
except ImportError:
    # Fallback path: alias the stdlib/PyTorch implementation.
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False

    # warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN with an adjusted hidden width.

    The requested hidden width is scaled by 2/3 and rounded up to a multiple
    of 8 before being passed to the base ``SwiGLU`` (xFormers' fused op when
    available, otherwise the pure-PyTorch fallback). ``act_layer`` and
    ``drop`` are accepted for interface parity and ignored.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Scale by 2/3 (presumably to keep parameter count comparable to a
        # plain MLP — confirm) and round up to a multiple of 8.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ '''
7
+ Docstring for MDM.mdm.model.dinov2_rgbd.models_vlmae
8
+ =======================================================
9
+ This version is modified from the original DINOv2 to support the MIM(masked image modeling) of RGBD input.
10
+ (The original DINOv2 is available at https://github.com/facebookresearch/dinov2.)
11
+
12
+ Core Changes:
13
+ 1. We add the depth input into the original DINOv2 transformer encoder.
14
+
15
+ 2. We support the Variable Mask Ratio MAE for both RGB and Depth input.
16
+ '''
17
+
18
+ import logging
19
+
20
+ from . import vision_transformer as vits
21
+
22
+ logger = logging.getLogger("dinov2")
23
+
24
+
25
def build_model(args, only_teacher=False, img_size=224):
    """Instantiate ViT student/teacher models from a config namespace.

    Returns ``(teacher, embed_dim)`` when ``only_teacher`` is set, otherwise
    ``(student, teacher, embed_dim)``. The student additionally gets the
    configured stochastic-depth (drop-path) settings.
    """
    # NOTE: mutates the caller's args — strips a legacy "_memeff" arch suffix.
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" in args.arch:
        vit_kwargs = dict(
            img_size=img_size,
            patch_size=args.patch_size,
            init_values=args.layerscale,
            ffn_layer=args.ffn_layer,
            block_chunks=args.block_chunks,
            qkv_bias=args.qkv_bias,
            proj_bias=args.proj_bias,
            ffn_bias=args.ffn_bias,
            num_register_tokens=args.num_register_tokens,
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
        )
        # Look up the constructor (vit_small / vit_base / ...) by arch name.
        teacher = vits.__dict__[args.arch](**vit_kwargs)
        if only_teacher:
            return teacher, teacher.embed_dim
        student = vits.__dict__[args.arch](
            **vit_kwargs,
            drop_path_rate=args.drop_path_rate,
            drop_path_uniform=args.drop_path_uniform,
        )
        embed_dim = student.embed_dim
    # NOTE(review): a non-ViT arch falls through and returns None implicitly.
    return student, teacher, embed_dim
51
+
52
+
53
def build_model_from_cfg(cfg, only_teacher=False):
    """Build model(s) from a full config: student section + global crop size as image size."""
    return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
55
+
third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/mask_utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
def depth_masking(
    x,
    patch_num_h,
    patch_num_w,
    depth_values,
    depth_mask_threshold_ratio=None,
    depth_mask_threshold_num=None,
    valid_depth_range=(0.1, 10.0),
):
    """Drop patch tokens whose underlying depth pixels are mostly invalid.

    Args:
        x: [B, N, D] patch-embedded features.
        patch_num_h: height of the patch grid.
        patch_num_w: width of the patch grid.
        depth_values: [B, 1, H_img, W_img] raw depth map.
        depth_mask_threshold_ratio: float or per-sample list — minimum valid-pixel
            ratio a patch must reach to be kept.
        depth_mask_threshold_num: int or per-sample list — minimum valid-pixel
            count a patch must reach to be kept.
        valid_depth_range: (min, max) depth considered valid.

    Returns:
        visible_list: per-sample [N_visible_i, D] tensors of kept patches.
        mask_info: dict with per-sample 'visible_indices', 'mask_indices'
            and 'num_visible'.
    """
    B, N, D = x.shape
    device = x.device

    assert N == patch_num_h * patch_num_w, \
        f"N={N} must equal patch_num_h * patch_num_w = {patch_num_h * patch_num_w}"

    # [B, N] boolean map; True marks a patch to drop.
    invalid = _compute_depth_invalid_mask(
        depth_values,
        patch_num_h,
        patch_num_w,
        depth_mask_threshold_ratio,
        depth_mask_threshold_num,
        valid_depth_range,
    )

    visible_list = []
    mask_info = {
        'visible_indices': [],
        'mask_indices': [],
        'num_visible': [],
    }

    # Samples keep different numbers of patches, so results stay per-sample lists.
    for sample_invalid, sample_tokens in zip(invalid, x):
        keep_idx = torch.where(~sample_invalid)[0]
        drop_idx = torch.where(sample_invalid)[0]

        visible_list.append(sample_tokens[keep_idx])

        mask_info['visible_indices'].append(keep_idx)
        mask_info['mask_indices'].append(drop_idx)
        mask_info['num_visible'].append(len(keep_idx))

    return visible_list, mask_info
67
+
68
+ def _compute_depth_invalid_mask(
69
+ depth_values,
70
+ H_patch,
71
+ W_patch,
72
+ threshold_ratio,
73
+ threshold_num,
74
+ valid_range
75
+ ):
76
+ """
77
+ Compute depth validity for each patch
78
+
79
+ Args:
80
+ depth_values: [B, 1, H_img, W_img] raw depth map
81
+ H_patch, W_patch: patch grid dimensions
82
+ threshold_ratio: float or list, valid depth ratio threshold
83
+ threshold_num: int or list, valid depth pixel count threshold
84
+ valid_range: tuple, (min_depth, max_depth)
85
+
86
+ Returns:
87
+ invalid_mask: [B, N] bool tensor, True indicates this patch is invalid
88
+ """
89
+ B, _, H_img, W_img = depth_values.shape
90
+ N = H_patch * W_patch
91
+ device = depth_values.device
92
+
93
+ min_depth, max_depth = valid_range
94
+
95
+ # Calculate pixel size for each patch
96
+ patch_h = H_img // H_patch
97
+ patch_w = W_img // W_patch
98
+
99
+ assert H_img % H_patch == 0 and W_img % W_patch == 0, \
100
+ f"Image size ({H_img}, {W_img}) must be divisible by patch grid ({H_patch}, {W_patch})"
101
+
102
+ # Reshape depth map into patches: [B, 1, H_img, W_img] -> [B, H_patch, patch_h, W_patch, patch_w]
103
+ depth_reshaped = depth_values.view(B, 1, H_patch, patch_h, W_patch, patch_w)
104
+
105
+ # Transpose and flatten: [B, H_patch, W_patch, patch_h, patch_w] -> [B, N, patch_h*patch_w]
106
+ depth_reshaped = depth_reshaped.permute(0, 2, 4, 1, 3, 5).reshape(B, N, -1)
107
+
108
+ # Calculate valid depth
109
+ valid_depth = (depth_reshaped >= min_depth) & (depth_reshaped <= max_depth)
110
+ valid_depth_ratio = valid_depth.float().mean(dim=-1) # [B, N]
111
+ valid_depth_num = valid_depth.float().sum(dim=-1) # [B, N]
112
+
113
+ # Handle list-form thresholds (different thresholds for each sample in batch)
114
+ if isinstance(threshold_ratio, list) or isinstance(threshold_num, list):
115
+ invalid_mask = torch.zeros(B, N, dtype=torch.bool, device=device)
116
+
117
+ for i in range(B):
118
+ tr = threshold_ratio[i] if isinstance(threshold_ratio, list) else threshold_ratio
119
+ tn = threshold_num[i] if isinstance(threshold_num, list) else threshold_num
120
+
121
+ sample_mask = torch.zeros(N, dtype=torch.bool, device=device)
122
+ if tr is not None:
123
+ sample_mask |= (valid_depth_ratio[i] < tr)
124
+ if tn is not None:
125
+ sample_mask |= (valid_depth_num[i] < tn)
126
+
127
+ invalid_mask[i] = sample_mask
128
+ else:
129
+ # Uniform threshold
130
+ invalid_mask = torch.zeros(B, N, dtype=torch.bool, device=device)
131
+
132
+ if threshold_ratio is not None:
133
+ invalid_mask |= (valid_depth_ratio < threshold_ratio)
134
+ if threshold_num is not None:
135
+ invalid_mask |= (valid_depth_num < threshold_num)
136
+
137
+ return invalid_mask
third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/vision_transformer.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable, Optional, List
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+ from ..layers import PatchEmbedMLP
22
+
23
+ from .mask_utils import depth_masking
24
+
25
+ logger = logging.getLogger("dinov2_rgbd")
26
+
27
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over a module tree.

    Children are always visited (they recurse with include_root=True); the
    root itself is visited only when ``include_root`` is set. ``depth_first``
    controls whether a node is visited before or after its children.
    Returns the (possibly mutated) root module.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(fn=fn, module=child, name=qualified, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module
36
+
37
+
38
class BlockChunk(nn.ModuleList):
    """A callable ModuleList: applies its members to the input in sequence."""

    def forward(self, x):
        out = x
        for module in self:
            out = module(out)
        return out
43
+
44
+
45
class DinoVisionTransformer(nn.Module):
    """DINOv2-style ViT extended for RGB-D input.

    RGB and depth maps are patch-embedded separately, tagged with distinct
    data-type offsets on their positional encodings, optionally depth-masked
    (patches with no valid depth are dropped), and concatenated as
    [cls | registers | image tokens | depth tokens] before the transformer
    blocks.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        img_depth_fuse_mode='',
        depth_mask_ratio:Union[float, List[float]]=0.6,
        img_mask_ratio:Union[float, List[float]]=0.0,
        depth_mask_patch_grid_size: int=1,
        img_mask_patch_grid_size: int=1,
        depth_emb_mode='',
        # depth_emb_mode='conv_1c'
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Depth gets its own single-channel patch embedder in 'conv_1c' mode;
        # otherwise it stays None (prepare_tokens_with_masks asserts it exists).
        self.depth_emb_mode = depth_emb_mode
        if self.depth_emb_mode == 'conv_1c':
            self.depth_patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=1, embed_dim=embed_dim)
        else:
            self.depth_patch_embed = None

        self.img_depth_fuse_mode = img_depth_fuse_mode

        self.depth_mask_patch_grid_size = depth_mask_patch_grid_size
        self.img_mask_patch_grid_size = img_mask_patch_grid_size
        assert self.depth_mask_patch_grid_size == 1, "depth_mask_patch_grid_size must be 1 in current version"
        assert self.img_mask_patch_grid_size == 1, "img_mask_patch_grid_size must be 1 in current version"

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    @property
    def onnx_compatible_mode(self):
        # Defaults to False when never set; toggling it disables the
        # scale-factor interpolation path and the same-size early returns below.
        return getattr(self, "_onnx_compatible_mode", False)

    @onnx_compatible_mode.setter
    def onnx_compatible_mode(self, value: bool):
        self._onnx_compatible_mode = value

    def init_weights(self):
        """Initialize positional/cls/register parameters, then timm-style Linear init everywhere."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, h, w):
        """Resize the (cls + patch) positional embedding to an (h, w) input via bicubic interpolation."""
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        batch_size = x.shape[0]
        N = self.pos_embed.shape[1] - 1
        # Fast path: token count already matches a square input.
        if not self.onnx_compatible_mode and npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0, :]
        patch_pos_embed = pos_embed[:, 1:, :]
        dim = x.shape[-1]
        h0, w0 = h // self.patch_size, w // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if not self.onnx_compatible_mode and self.interpolate_offset > 0:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sy, sx)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (h0, w0)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )

        assert (h0, w0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
        return torch.cat((class_pos_embed[:, None, :].expand(patch_pos_embed.shape[0], -1, -1), patch_pos_embed), dim=1).to(previous_dtype)

    def interpolate_pos_encoding_without_cls(self, x, h, w, input_pos_embed):
        """Same as interpolate_pos_encoding but for a patch-only embedding (no cls slot)."""
        previous_dtype = x.dtype
        npatch = x.shape[1]
        batch_size = x.shape[0]
        N = input_pos_embed.shape[1]
        if not self.onnx_compatible_mode and npatch == N and w == h:
            return input_pos_embed
        patch_pos_embed = input_pos_embed.float()
        dim = x.shape[-1]
        h0, w0 = h // self.patch_size, w // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if not self.onnx_compatible_mode and self.interpolate_offset > 0:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sy, sx)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (h0, w0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (h0, w0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
        return patch_pos_embed.to(previous_dtype)

    def prepare_tokens_with_masks(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, masks=None, **kwargs):
        """Embed RGB + depth into a per-sample list of [1, T_i, D] token sequences.

        Zero depth pixels are replaced with -10 so that depth_masking (valid
        range starting at -9.5) treats them as invalid and drops their patches.
        Token order per sample: cls, registers (if any), image tokens, depth tokens.
        """
        assert masks is None, "extra masks are not supported for this model."
        B, nc, h_img, w_img = x_img.shape
        _, _, h_depth, w_depth = x_depth.shape
        x_depth_raw = x_depth.clone()
        x_depth_raw[x_depth_raw == 0] = -10

        depth_patch_num_h, depth_patch_num_w = h_depth // self.patch_size, w_depth // self.patch_size

        # patchify, embed image tokens and depth tokens
        x_img = self.patch_embed(x_img) # batch, length_img, dim
        assert self.depth_patch_embed is not None
        x_depth = self.depth_patch_embed(x_depth) # batch, length_depth, dim
        assert depth_patch_num_h * depth_patch_num_w == x_depth.shape[1]

        # get full pose enc of img and depth
        # 1-> img data type enc
        # 2-> depth data type enc
        img_pose_enc = 1 + self.interpolate_pos_encoding_without_cls(x_img, h_img, w_img, self.pos_embed[:, 1:]).repeat(B, 1, 1)
        depth_pose_enc = 2 + self.interpolate_pos_encoding_without_cls(x_depth, h_depth, w_depth, self.pos_embed[:, 1:]).repeat(B, 1, 1)

        # add pose enc to img and depth
        x_img = x_img + img_pose_enc
        x_depth = x_depth + depth_pose_enc

        ## mask depth tokens
        if kwargs.get('enable_depth_mask', True):
            # Drop depth patches with fewer than 1 valid pixel (i.e. fully invalid).
            x_depth_masked, depth_mask_info = depth_masking(
                x_depth,
                depth_patch_num_h,
                depth_patch_num_w,
                depth_values=x_depth_raw,
                depth_mask_threshold_num=[1]*B,
                valid_depth_range=(-9.5, 200.0)
            )
        else:
            x_depth_masked = x_depth
            depth_mask_info = None

        ## mask image tokens
        # NOTE(review): image masking is currently a no-op here.
        x_img_masked = x_img
        img_mask_info = None

        # get cls token
        x_cls = self.cls_token.squeeze(0) + self.pos_embed.squeeze(0)[:1] # 1, dim

        # cat cls, img and depth tokens
        assert self.img_depth_fuse_mode == 'cat_token', "Only cat_token mode is supported for this model."
        x_masked_list = []
        for i in range(B):
            if self.register_tokens is not None:
                x_mased = torch.cat([x_cls, self.register_tokens.squeeze(0), x_img_masked[i], x_depth_masked[i]], dim=0) # 1 + num_register_tokens + length_img + length_depth, dim
            else:
                x_mased = torch.cat([x_cls, x_img_masked[i], x_depth_masked[i]], dim=0) # 1 + length_img + length_depth, dim
            x_mased = x_mased.unsqueeze(0) # 1, 1 + num_register_tokens + length_img + length_depth, dim
            x_masked_list.append(x_mased)

        return x_masked_list

    def _get_intermediate_layers_not_chunked(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, n=1, return_mae_aux=False, **kwargs):
        """Run the (unchunked) blocks, collecting outputs of the n last blocks (or the listed indices)."""
        x = self.prepare_tokens_with_masks(x_img, x_depth, x_img_mask, x_depth_mask, **kwargs)

        # Without depth masking every sample has the same length, so the
        # per-sample list can be batched into a single tensor.
        if not kwargs.get('enable_depth_mask', True):
            x = torch.cat(x, dim=0)

        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"

        if not kwargs.get('enable_depth_mask', True):
            # Re-split the batched tensor back into per-sample chunks.
            output = [list(torch.split(out, 1, dim=0)) for out in output]
        return output

    def _get_intermediate_layers_chunked(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, n=1, return_mae_aux=False, **kwargs):
        """Chunked-blocks variant of _get_intermediate_layers_not_chunked (for FSDP-wrapped block chunks)."""
        x = self.prepare_tokens_with_masks(x_img, x_depth, x_img_mask, x_depth_mask, **kwargs)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"

        return output

    def extract_features(self, outputs, norm=True):
        """Split raw block outputs into (patch features, cls tokens), optionally layer-normed.

        Handles both batched tensors and per-sample lists of tensors.
        """
        feat_outputs = []
        class_tokens = []
        # Patch features start after the cls token and any register tokens.
        feat_start_idx = 1 + self.num_register_tokens

        def process_output(out):
            normed = self.norm(out) if norm else out
            return normed[:, feat_start_idx:], normed[:, 0]

        for output in outputs:
            if isinstance(output, list):
                feats, tokens = zip(*[process_output(out) for out in output])
                feat_outputs.append(list(feats))
                class_tokens.append(list(tokens))
            else:
                feat, token = process_output(output)
                feat_outputs.append(feat)
                class_tokens.append(token)

        return feat_outputs, class_tokens

    def get_intermediate_layers_mae(
        self,
        x_img: torch.Tensor,
        x_depth: torch.Tensor,
        x_img_mask: torch.Tensor=None,
        x_depth_mask: torch.Tensor=None,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
        return_mae_aux=True,
        **kwargs
    ):
        """Public entry point: intermediate features (and optionally cls tokens) for RGB-D input."""
        assert reshape is False, "reshape is not supported for now"
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x_img, x_depth, x_img_mask, x_depth_mask, n, return_mae_aux=return_mae_aux,**kwargs)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x_img, x_depth, x_img_mask, x_depth_mask, n, return_mae_aux=return_mae_aux,**kwargs)

        feat_outputs, class_tokens = self.extract_features(outputs, norm)

        if return_class_token:
            return tuple(zip(feat_outputs, class_tokens))
        return tuple(feat_outputs)
414
+
415
+
416
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility).

    Linear layers get trunc-normal weights (std 0.02) and zeroed biases;
    every other module type is left untouched.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
422
+
423
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-S/16 (embed 384, 12 blocks, 6 heads) with memory-efficient attention."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
435
+
436
+
437
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-B/16 (embed 768, 12 blocks, 12 heads) with memory-efficient attention."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
449
+
450
+
451
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-L/16 (embed 1024, 24 blocks, 16 heads) with memory-efficient attention."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
463
+
464
+
465
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/cluster.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional
10
+
11
+
12
class ClusterType(Enum):
    """Known compute environments."""

    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"


def _guess_cluster_type() -> ClusterType:
    """Infer the current cluster from kernel release / hostname; FAIR is the fallback."""
    info = os.uname()
    if info.sysname == "Linux":
        # AWS kernel releases look like "5.4.0-1051-aws"; RSC hostnames start with "rsc".
        if info.release.endswith("-aws"):
            return ClusterType.AWS
        if info.nodename.startswith("rsc"):
            return ClusterType.RSC
    return ClusterType.FAIR


def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return the given cluster type, guessing it when None is passed."""
    return _guess_cluster_type() if cluster_type is None else cluster_type


def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Root checkpoint directory for the (possibly guessed) cluster."""
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    dirnames = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/") / dirnames[resolved]


def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Per-user checkpoint directory; requires the USER environment variable."""
    root = get_checkpoint_path(cluster_type)
    if root is None:
        return None

    user = os.environ.get("USER")
    assert user is not None
    return root / user


def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Default SLURM partition name for the cluster."""
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    partitions = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return partitions[resolved]


def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Build submitit/SLURM executor parameters with per-cluster tweaks.

    Explicit keyword overrides in **kwargs always win (applied last).
    """
    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    # Cluster-specific adjustments.
    resolved = get_cluster_type(cluster_type)
    if resolved == ClusterType.AWS:
        params["cpus_per_task"] = 12
        del params["mem_gb"]  # AWS submission does not use the mem_gb request
    elif resolved == ClusterType.RSC:
        params["cpus_per_task"] = 12
    # Caller-supplied overrides are applied last.
    params.update(kwargs)
    return params
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/config.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import logging
8
+ import os
9
+
10
+ from omegaconf import OmegaConf
11
+
12
+ import dinov2.distributed as distributed
13
+ from dinov2.logging import setup_logging
14
+ from dinov2.utils import utils
15
+ from dinov2.configs import dinov2_default_config
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
def apply_scaling_rules_to_cfg(cfg):  # TODO: generalize beyond the single sqrt rule
    """Scale cfg.optim.lr from cfg.optim.base_lr by the global batch size.

    Only the "sqrt_wrt_1024" rule is implemented:
    lr = base_lr * sqrt(batch_size_per_gpu * world_size / 1024).
    Mutates and returns *cfg*; raises NotImplementedError for other rules.
    """
    if cfg.optim.scaling_rule == "sqrt_wrt_1024":
        base_lr = cfg.optim.base_lr
        cfg.optim.lr = base_lr
        cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
        logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    else:
        raise NotImplementedError
    return cfg


def write_config(cfg, output_dir, name="config.yaml"):
    """Log the config and save it as YAML under *output_dir*; return the saved path."""
    logger.info(OmegaConf.to_yaml(cfg))
    saved_cfg_path = os.path.join(output_dir, name)
    with open(saved_cfg_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return saved_cfg_path


def get_cfg_from_args(args):
    """Merge the default config, the config file, and CLI overrides (in that order)."""
    args.output_dir = os.path.abspath(args.output_dir)
    # Propagate the resolved output dir into the config via a CLI-style override.
    args.opts += [f"train.output_dir={args.output_dir}"]
    default_cfg = OmegaConf.create(dinov2_default_config)
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
    return cfg


def default_setup(args):
    """Enable distributed mode, configure logging, and seed RNGs per rank."""
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    # Re-bind the module-level logger after file logging has been configured.
    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    utils.fix_random_seeds(seed + rank)  # distinct seed per rank
    logger.info("git:\n  {}\n".format(utils.get_sha()))
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(cfg)
    write_config(cfg, args.output_dir)
    return cfg
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/dtype.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from typing import Dict, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
# Anything convertible to a torch dtype: a dtype name, a numpy dtype, or a torch dtype.
TypeSpec = Union[str, np.dtype, torch.dtype]


# Mapping from numpy dtypes to their torch equivalents.
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Convert a dtype name, numpy dtype, or torch dtype to a torch dtype.

    Torch dtypes pass through unchanged; strings are first parsed as numpy
    dtypes. Raises KeyError for numpy dtypes with no torch equivalent.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    # Fixed typo in the original message ("nunpy" -> "numpy").
    assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
    return _NUMPY_TO_TORCH_DTYPE[dtype]
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/param_groups.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import defaultdict
7
+ import logging
8
+
9
+
10
+ logger = logging.getLogger("dinov2")
11
+
12
+
13
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
    Returns:
        lr decay rate for the given parameter: lr_decay_rate ** (num_layers + 1 - layer_id),
        where layer_id is 0 for embeddings/tokens, block_index + 1 for block
        parameters, and num_layers + 1 (no decay) otherwise.
    """
    dotted_embed_keys = (".pos_embed", ".patch_embed", ".mask_token", ".cls_token", ".register_tokens")
    bare_embed_keys = ("pos_embed", "patch_embed", "mask_token", "cls_token", "register_tokens")

    layer_id = num_layers + 1  # default: no decay
    if name.startswith("backbone") or force_is_backbone:
        if any(key in name for key in dotted_embed_keys):
            layer_id = 0
        elif force_is_backbone and any(key in name for key in bare_embed_keys):
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks."):].split(".")[2]) + 1
        elif chunked_blocks and "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks."):].split(".")[2]) + 1
        elif "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks."):].split(".")[1]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)
49
+
50
+
51
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    """Build one optimizer param group per trainable parameter with layerwise LR decay.

    Each group carries: `params` (the tensor), `lr_multiplier` (from
    `get_vit_lr_decay_rate`, further scaled by *patch_embed_lr_mult* for
    patch-embed weights), `wd_multiplier` (0 for biases / norm / gamma
    parameters), `is_last_layer`, and `name`.
    """
    # Discover the number of transformer blocks (and whether they are FSDP-chunked).
    chunked_blocks = False
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0
    all_param_groups = []

    for name, param in model.named_parameters():
        name = name.replace("_fsdp_wrapped_module.", "")  # strip FSDP wrapper prefix
        if not param.requires_grad:
            continue
        decay_rate = get_vit_lr_decay_rate(
            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
        )
        d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}

        if "last_layer" in name:
            d.update({"is_last_layer": True})

        # No weight decay for biases and normalization / gamma parameters.
        if name.endswith(".bias") or "norm" in name or "gamma" in name:
            d.update({"wd_multiplier": 0.0})

        if "patch_embed" in name:
            d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})

        all_param_groups.append(d)
        logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")

    return all_param_groups
90
+
91
+
92
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    """Merge per-parameter groups that share the same settings for *keys*.

    Returns a view of dicts, each holding the shared key values plus a
    `params` list with every parameter that matched that combination.
    """
    fused_params_groups = defaultdict(lambda: {"params": []})
    for group in all_params_groups:
        # Identifier string uniquely encodes the key/value combination.
        identifier = "".join(f"{key}{group[key]}_" for key in keys)
        bucket = fused_params_groups[identifier]
        for key in keys:
            bucket[key] = group[key]
        bucket["params"].append(group["params"])

    return fused_params_groups.values()
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import random
9
+ import subprocess
10
+ from urllib.parse import urlparse
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """Load a checkpoint (local path or URL) into *model*, non-strictly.

    Args:
        model: module to load weights into.
        pretrained_weights: filesystem path or URL of the checkpoint.
        checkpoint_key: optional sub-dict key (e.g. "teacher") extracted first
            when present in the loaded state dict.
    """
    if urlparse(pretrained_weights).scheme:  # If it looks like an URL
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    # remove `module.` prefix (left by DataParallel/DDP wrapping)
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # remove `backbone.` prefix induced by multicrop wrapper
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    # strict=False tolerates missing/unexpected keys; the mismatch report is logged
    msg = model.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
34
+
35
+
36
def fix_random_seeds(seed=31):
    """
    Fix random seeds for torch (CPU and all CUDA devices), numpy, and the stdlib RNG.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
44
+
45
+
46
def get_sha():
    """Return a one-line description of the git state of this source tree.

    Falls back to "N/A"/"clean" placeholders when git (or a checkout) is
    unavailable; never raises.
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _git(*argv):
        # Run a git command in this file's directory and return stripped output.
        return subprocess.check_output(list(argv), cwd=cwd).decode("ascii").strip()

    sha, diff, branch = "N/A", "clean", "N/A"
    try:
        sha = _git("git", "rev-parse", "HEAD")
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = "has uncommitted changes" if _git("git", "diff-index", "HEAD") else "clean"
        branch = _git("git", "rev-parse", "--abbrev-ref", "HEAD")
    except Exception:
        pass  # not a git checkout, or git missing: keep placeholders
    return f"sha: {sha}, status: {diff}, branch: {branch}"
65
+
66
+
67
class CosineScheduler(object):
    """Cosine decay schedule with optional warmup and an initial frozen phase.

    The schedule is: `freeze_iters` zeros, then a linear warmup from
    `start_warmup_value` to `base_value` over `warmup_iters` steps, then a
    cosine decay down to `final_value`. Indexing at or past `total_iters`
    returns `final_value`.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        frozen_part = np.zeros((freeze_iters))
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
        steps = np.arange(total_iters - warmup_iters - freeze_iters)
        cosine_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * steps / len(steps)))

        self.schedule = np.concatenate((frozen_part, warmup_part, cosine_part))
        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        # Clamp to the final value once the schedule is exhausted.
        return self.final_value if it >= self.total_iters else self.schedule[it]
88
+
89
+
90
def has_batchnorms(model):
    """Return True if *model* contains any (sync) batch-norm submodule."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(module, bn_types) for module in model.modules())
third_party/lingbot_depth/mdm/model/modules_decoder.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from numbers import Number
3
+ import importlib
4
+ import itertools
5
+ import functools
6
+ import sys
7
+
8
+ import torch
9
+ from torch import Tensor
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .utils import wrap_module_with_gradient_checkpointing
14
+
15
+
16
class ResidualConvBlock(nn.Module):
    """Two-convolution residual block with configurable normalization/activation.

    Main path: norm -> act -> conv -> norm -> act -> conv (pre-activation
    ordering). The skip path is a 1x1 conv when in/out channel counts differ,
    identity otherwise. 'layer_norm' is implemented as a single-group GroupNorm.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        hidden_channels: int = None,
        kernel_size: int = 3,
        padding_mode: str = 'replicate',
        activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
        in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm',
        hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm',
    ):
        super(ResidualConvBlock, self).__init__()
        out_channels = in_channels if out_channels is None else out_channels
        hidden_channels = in_channels if hidden_channels is None else hidden_channels

        activations = {
            'relu': nn.ReLU,
            'leaky_relu': functools.partial(nn.LeakyReLU, negative_slope=0.2),
            'silu': nn.SiLU,
            'elu': nn.ELU,
        }
        if activation not in activations:
            raise ValueError(f'Unsupported activation function: {activation}')
        activation_cls = activations[activation]

        def make_norm(kind, channels):
            # group_norm uses channels // 32 groups; layer_norm is GroupNorm(1, C).
            if kind == 'group_norm':
                return nn.GroupNorm(channels // 32, channels)
            if kind == 'layer_norm':
                return nn.GroupNorm(1, channels)
            if kind == 'instance_norm':
                return nn.InstanceNorm2d(channels)
            return nn.Identity()

        conv_kwargs = dict(kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode)
        self.layers = nn.Sequential(
            make_norm(in_norm, in_channels),
            activation_cls(),
            nn.Conv2d(in_channels, hidden_channels, **conv_kwargs),
            make_norm(hidden_norm, hidden_channels),
            activation_cls(),
            nn.Conv2d(hidden_channels, out_channels, **conv_kwargs),
        )

        if in_channels != out_channels:
            self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x):
        """Return layers(x) + skip(x)."""
        return self.layers(x) + self.skip_connection(x)
67
+
68
+
69
class Resampler(nn.Sequential):
    """Convolutional spatial up/down-sampler.

    Upsamplers: 'pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose'.
    Downsamplers: 'pixel_unshuffle', 'avg_pool', 'max_pool'.
    Each variant pairs the resampling op with 3x3 replicate-padded convs.
    """

    def __init__(self,
        in_channels: int,
        out_channels: int,
        type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'],
        scale_factor: int = 2,
    ):
        conv3x3 = functools.partial(nn.Conv2d, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
        if type_ == 'pixel_shuffle':
            nn.Sequential.__init__(self,
                conv3x3(in_channels, out_channels * (scale_factor ** 2)),
                nn.PixelShuffle(scale_factor),
                conv3x3(out_channels, out_channels),
            )
            # Copy the first shuffle group's weights into the others so every
            # sub-pixel position starts identical (nearest-like initialization).
            step = scale_factor ** 2
            for i in range(1, step):
                self[0].weight.data[i::step] = self[0].weight.data[0::step]
                self[0].bias.data[i::step] = self[0].bias.data[0::step]
        elif type_ in ['nearest', 'bilinear']:
            nn.Sequential.__init__(self,
                nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None),
                conv3x3(in_channels, out_channels),
            )
        elif type_ == 'conv_transpose':
            nn.Sequential.__init__(self,
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor),
                conv3x3(out_channels, out_channels),
            )
            # Broadcast the top-left tap over the whole kernel (nearest-like init).
            self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
        elif type_ == 'pixel_unshuffle':
            nn.Sequential.__init__(self,
                nn.PixelUnshuffle(scale_factor),
                conv3x3(in_channels * (scale_factor ** 2), out_channels),
            )
        elif type_ == 'avg_pool':
            nn.Sequential.__init__(self,
                conv3x3(in_channels, out_channels),
                nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
            )
        elif type_ == 'max_pool':
            nn.Sequential.__init__(self,
                conv3x3(in_channels, out_channels),
                nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
            )
        else:
            raise ValueError(f'Unsupported resampler type: {type_}')
113
+
114
+
115
class MLP(nn.Sequential):
    """Plain multi-layer perceptron.

    `dims` lists layer widths; each hidden transition is Linear + inplace
    ReLU, and the final transition is a bare Linear.
    """

    def __init__(self, dims: Sequence[int]):
        layers = []
        for width_in, width_out in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(width_in, width_out))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(dims[-2], dims[-1]))
        nn.Sequential.__init__(self, *layers)
124
+
125
+
126
class ConvStack(nn.Module):
    """Multi-scale convolutional stack.

    At each scale i the stack applies: an optional 1x1 input projection
    (Identity when dim_in[i] is None), a chain of residual blocks, an
    optional 1x1 output projection, and — except at the last scale — a
    resampler to the next scale. Scale 0's projected input seeds the running
    state; later scales' projected inputs are added to it when not None.
    `forward` returns the list of per-scale output features.
    """

    def __init__(self,
        dim_in: List[Optional[int]],
        dim_res_blocks: List[int],
        dim_out: List[Optional[int]],
        resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List],
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,  # int (same count per scale) or per-scale list
        res_block_in_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'layer_norm',
        res_block_hidden_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'group_norm',
        activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
    ):
        super().__init__()
        # Per-scale 1x1 input projections; Identity where no input is expected.
        self.input_blocks = nn.ModuleList([
            nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity()
            for dim_in_, dim_res_block_ in zip(dim_in if isinstance(dim_in, Sequence) else itertools.repeat(dim_in), dim_res_blocks)
        ])
        # One resampler between each pair of consecutive scales (always scale factor 2).
        self.resamplers = nn.ModuleList([
            Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
            for i, (dim_prev, dim_succ, resampler) in enumerate(zip(
                dim_res_blocks[:-1],
                dim_res_blocks[1:],
                resamplers if isinstance(resamplers, Sequence) else itertools.repeat(resamplers)
            ))
        ])
        self.res_blocks = nn.ModuleList([
            nn.Sequential(
                *(
                    ResidualConvBlock(
                        dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_,
                        activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm
                    ) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks)
                )
            ) for i, dim_res_block_ in enumerate(dim_res_blocks)
        ])
        # Per-scale 1x1 output projections; Identity where raw features are wanted.
        self.output_blocks = nn.ModuleList([
            nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity()
            for dim_out_, dim_res_block_ in zip(dim_out if isinstance(dim_out, Sequence) else itertools.repeat(dim_out), dim_res_blocks)
        ])

    def enable_gradient_checkpointing(self):
        # Wrap resamplers and each residual block so activations are recomputed
        # in the backward pass instead of stored.
        for i in range(len(self.resamplers)):
            self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i])
        for i in range(len(self.res_blocks)):
            for j in range(len(self.res_blocks[i])):
                self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j])

    def forward(self, in_features: List[torch.Tensor]):
        """Run the stack; in_features[i] may be None for scales past the first.

        NOTE(review): scale 0's feature is used unconditionally (must not be
        None); later scales are skip-added only when their feature is not None.
        """
        out_features = []
        for i in range(len(self.res_blocks)):
            feature = self.input_blocks[i](in_features[i])
            if i == 0:
                x = feature
            elif feature is not None:
                x = x + feature
            x = self.res_blocks[i](x)
            out_features.append(self.output_blocks[i](x))
            if i < len(self.res_blocks) - 1:
                x = self.resamplers[i](x)
        return out_features
third_party/lingbot_depth/mdm/model/modules_rgbd_encoder.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from numbers import Number
3
+ import importlib
4
+ import itertools
5
+ import functools
6
+ import sys
7
+
8
+ import torch
9
+ from torch import Tensor
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .dinov2_rgbd.models.vision_transformer import DinoVisionTransformer
14
+ from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing
15
+
16
+
17
class DINOv2_RGBD_Encoder(nn.Module):
    """DINOv2 backbone variant that jointly encodes an RGB image and a depth map.

    Several intermediate transformer layers are projected to `dim_out`
    channels with 1x1 convolutions and summed into a single feature map of
    shape (B, dim_out, token_rows, token_cols).
    """

    backbone: DinoVisionTransformer
    image_mean: torch.Tensor  # ImageNet normalization mean, registered buffer (1, 3, 1, 1)
    image_std: torch.Tensor   # ImageNet normalization std, registered buffer (1, 3, 1, 1)
    dim_features: int         # backbone embedding width (inferred from the qkv projection)

    # NOTE(review): `ignore_layers` has a mutable default ([]); it is only read,
    # never mutated, so it is harmless here, but None/tuple would be safer.
    def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, ignore_layers: Union[str, List[str]]=[], in_chans: int=3, strict: bool=True, img_depth_fuse_mode='', depth_emb_mode='', depth_mask_ratio=0.6, img_mask_ratio=0.0, **deprecated_kwargs):
        super(DINOv2_RGBD_Encoder, self).__init__()

        self.intermediate_layers = intermediate_layers
        self.strict = strict
        self.ignore_layers = ignore_layers
        self.img_mask_ratio = img_mask_ratio
        # Load the backbone: `backbone` names a builder in dinov2_rgbd.hub.backbones.
        self.hub_loader = getattr(importlib.import_module(".dinov2_rgbd.hub.backbones", __package__), backbone)
        self.backbone_name = backbone
        self.backbone = self.hub_loader(pretrained=False,
                                        in_chans=in_chans,
                                        img_depth_fuse_mode=img_depth_fuse_mode,
                                        depth_emb_mode=depth_emb_mode,
                                        depth_mask_ratio=depth_mask_ratio,
                                        img_mask_ratio=img_mask_ratio)

        # Embedding width read off the first block's qkv input dimension.
        self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
        self.num_features = intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers)

        if img_mask_ratio > 0:
            # Learnable token substituted for masked image patches (MAE-style).
            self.mask_token_mae = nn.Parameter(torch.zeros(1, 1, self.dim_features))
            torch.nn.init.normal_(self.mask_token_mae, std=.02)

        # One 1x1 projection per tapped intermediate layer.
        self.output_projections = nn.ModuleList([
            nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0,)
            for _ in range(self.num_features)
        ])

        self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    @property
    def onnx_compatible_mode(self):
        # Defaults to False until explicitly set.
        return getattr(self, "_onnx_compatible_mode", False)

    @onnx_compatible_mode.setter
    def onnx_compatible_mode(self, value: bool):
        # Propagate the flag to the backbone as well.
        self._onnx_compatible_mode = value
        self.backbone.onnx_compatible_mode = value

    def init_weights(self):
        """Load pretrained backbone weights, skipping keys matching ignore_layers."""
        pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
        ignore_layers = []
        if isinstance(self.ignore_layers, str):
            ignore_layers = [self.ignore_layers]
        else:
            ignore_layers = self.ignore_layers

        if len(ignore_layers) == 0:
            self.backbone.load_state_dict(pretrained_backbone_state_dict, strict=self.strict)
        else:
            # Keep only keys that contain none of the ignore substrings.
            state_dict = {}
            for k, v in pretrained_backbone_state_dict.items():
                is_ignore = False
                for ig_k in ignore_layers:
                    if ig_k in k:
                        is_ignore = True
                        break
                if not is_ignore:
                    state_dict[k] = v
            self.backbone.load_state_dict(state_dict, strict=self.strict)

    def enable_gradient_checkpointing(self):
        """Wrap every backbone block with activation checkpointing."""
        for i in range(len(self.backbone.blocks)):
            wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])

    def enable_pytorch_native_sdpa(self):
        """Swap each block's attention forward for torch-native SDPA."""
        for i in range(len(self.backbone.blocks)):
            wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)

    def forward(self,
                image: torch.Tensor,
                depth: torch.Tensor,
                token_rows: Union[int, torch.LongTensor],
                token_cols: Union[int, torch.LongTensor],
                return_class_token: bool = False,
                remap_depth_in: str='linear',
                **kwargs):
        """Encode an RGB image plus aligned depth into a fused feature map.

        Args:
            image: (B, 3, H, W) RGB tensor; normalized with ImageNet statistics
                internally (assumes values in [0, 1] — TODO confirm with caller).
            depth: aligned depth map; non-finite and <= 0.01 values are zeroed.
            token_rows, token_cols: token grid size; inputs are resized to
                14 pixels per token on each side.
            return_class_token: also return the last tapped layer's class token.
            remap_depth_in: 'linear' (identity) or 'log' (log-depth, invalid -> 0).

        Returns:
            (x, cls_token, None, None) when return_class_token is True, else
            (x, None, None), where x is (B, dim_out, token_rows, token_cols).
            NOTE(review): the two branches return tuples of different arity.
        """
        image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=not self.onnx_compatible_mode)
        image_14 = (image_14 - self.image_mean) / self.image_std

        depth_14 = F.interpolate(depth, (token_rows * 14, token_cols * 14), mode="nearest")

        # set invalid depth value to zero
        depth_14[torch.isinf(depth_14)] = 0.0
        depth_14[torch.isnan(depth_14)] = 0.0
        dmask_14 = (depth_14 > 0.01).detach()
        depth_14 = depth_14 * dmask_14.float()

        if remap_depth_in == 'linear':
            pass # do nothing
        elif remap_depth_in == 'log':
            depth_14 = torch.log(depth_14)
            depth_14[~dmask_14] = 0.0
            depth_14 = torch.nan_to_num(depth_14, nan=0.0, posinf=0.0, neginf=0.0)
        else:
            raise NotImplementedError

        # Get intermediate layers from the backbone
        features = self.backbone.get_intermediate_layers_mae(
            x_img=image_14,
            x_depth=depth_14,
            n=self.intermediate_layers,
            return_class_token=True,
            **kwargs)

        assert self.img_mask_ratio == 0, "img_mask_ratio is not supported in this encoder"

        # When the backbone returns per-crop lists, concatenate them along batch.
        if isinstance(features[0][0], list):
            num_valid_tokens = token_rows * token_cols
            features = tuple(
                (
                    torch.cat([feat[:, :num_valid_tokens].contiguous() for feat in feats], dim=0),
                    torch.cat(cls_tokens, dim=0)
                )
                for feats, cls_tokens in features
            )

        # Project features to the desired dimensionality
        x = torch.stack([
            proj(feat.permute(0, 2, 1)[:, :, :token_rows*token_cols].unflatten(2, (token_rows, token_cols)).contiguous())
            for proj, (feat, clstoken) in zip(self.output_projections, features)
        ], dim=1).sum(dim=1)
        cls_token = features[-1][1]

        if return_class_token:
            return x, cls_token, None, None
        else:
            return x, None, None
third_party/lingbot_depth/mdm/model/utils.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
def wrap_module_with_gradient_checkpointing(module: nn.Module):
    """Patch *module* in place so its forward runs under activation checkpointing.

    The instance's class is swapped for a dynamically created subclass whose
    forward wraps the original forward in `torch.utils.checkpoint.checkpoint`
    (non-reentrant). The original class is stashed in `_restore_cls` so that
    `unwrap_module_with_gradient_checkpointing` can revert the patch.
    Returns the same (mutated) module.
    """
    from torch.utils.checkpoint import checkpoint
    class _CheckpointingWrapper(module.__class__):
        _restore_cls = module.__class__
        def forward(self, *args, **kwargs):
            return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)

    module.__class__ = _CheckpointingWrapper
    return module
16
+
17
+
18
def unwrap_module_with_gradient_checkpointing(module: nn.Module):
    """Undo `wrap_module_with_gradient_checkpointing` by restoring the saved class."""
    module.__class__ = module.__class__._restore_cls
20
+
21
+
22
def wrap_dinov2_attention_with_sdpa(module: nn.Module):
    """Swap a DINOv2 attention module's forward for torch-native SDPA.

    Patches *module* in place (class swap) and returns it. The module must
    provide `qkv`, `proj`, `proj_drop` and `num_heads`.
    """
    assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
    class _AttentionWrapper(module.__class__):
        def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
            batch, tokens, channels = x.shape
            head_dim = channels // self.num_heads
            # (B, N, 3C) -> (B, N, 3, H, D) -> (3, B, H, N, D)
            qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
            query, key, value = qkv[0], qkv[1], qkv[2]  # each (B, H, N, D)

            attended = F.scaled_dot_product_attention(query, key, value, attn_bias)
            attended = attended.permute(0, 2, 1, 3).reshape(batch, tokens, channels)

            return self.proj_drop(self.proj(attended))
    module.__class__ = _AttentionWrapper
    return module
39
+
40
def wrap_dinov3_attention_with_sdpa(module: nn.Module):
    """Swap a DINOv3 attention module's forward for torch-native SDPA.

    Same patch as `wrap_dinov2_attention_with_sdpa` (the implementations are
    currently identical); kept as a separate entry point for the v3 backbone.
    Patches *module* in place (class swap) and returns it.
    """
    assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
    class _AttentionWrapper(module.__class__):
        def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
            n_batch, n_tokens, n_channels = x.shape
            projected = self.qkv(x).reshape(n_batch, n_tokens, 3, self.num_heads, n_channels // self.num_heads)
            # (B, N, 3, H, D) -> (3, B, H, N, D), then split into q/k/v
            q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)

            out = F.scaled_dot_product_attention(q, k, v, attn_bias)
            out = out.permute(0, 2, 1, 3).reshape(n_batch, n_tokens, n_channels)

            return self.proj_drop(self.proj(out))
    module.__class__ = _AttentionWrapper
    return module
57
+
58
def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]:
    """DDP communication hook: synchronously all-reduce averaged gradients.

    Divides the bucket's gradient by the world size, sums it across all ranks
    in the WORLD group (in place), and returns an already-completed future
    holding the averaged gradient, as required by the DDP comm-hook API.
    """
    group_to_use = torch.distributed.group.WORLD
    world_size = group_to_use.size()
    grad = bucket.buffer()
    # Pre-divide so the all-reduce sum yields the mean gradient.
    grad.div_(world_size)
    torch.distributed.all_reduce(grad, group=group_to_use)
    fut = torch.futures.Future()
    fut.set_result(grad)
    return fut
67
+
68
def depth_to_pointcloud(depth, intrinsic_normalized, depth_scale=1.0):
    """
    Convert depth map to point cloud (pure Tensor version, no point filtering)

    Args:
        depth: torch.Tensor, shape (H, W) or (B, H, W), depth map
        intrinsic_normalized: torch.Tensor, shape (3, 3) or (B, 3, 3), normalized intrinsic matrix
            Normalized intrinsics: fx' = fx/W, fy' = fy/H, cx' = cx/W, cy' = cy/H
        depth_scale: float, divisor applied to depth values; default 1.0
            (pass e.g. 1000.0 for millimeter-encoded depth).
            Fixed: the docstring previously claimed a default of 1000.0,
            contradicting the signature's actual default of 1.0.

    Returns:
        points: torch.Tensor, shape (H, W, 3) or (B, H, W, 3), point cloud coordinates (x, y, z)
    """
    # Handle batch dimension
    if depth.dim() == 2:
        depth = depth.unsqueeze(0)  # (1, H, W)
        intrinsic_normalized = intrinsic_normalized.unsqueeze(0)  # (1, 3, 3)
        squeeze_output = True
    else:
        squeeze_output = False

    B, H, W = depth.shape
    device = depth.device

    # Denormalize intrinsics
    fx = intrinsic_normalized[:, 0, 0] * W  # (B,)
    fy = intrinsic_normalized[:, 1, 1] * H
    cx = intrinsic_normalized[:, 0, 2] * W
    cy = intrinsic_normalized[:, 1, 2] * H

    # Create pixel coordinate grid (H, W)
    v, u = torch.meshgrid(
        torch.arange(H, device=device, dtype=torch.float32),
        torch.arange(W, device=device, dtype=torch.float32),
        indexing='ij'
    )

    # Expand to batch dimension (B, H, W)
    u = u.unsqueeze(0).expand(B, -1, -1)
    v = v.unsqueeze(0).expand(B, -1, -1)

    # Backproject to 3D space
    z = depth / depth_scale  # (B, H, W)

    # Expand intrinsic dimensions for broadcasting (B, 1, 1)
    fx = fx.view(B, 1, 1)
    fy = fy.view(B, 1, 1)
    cx = cx.view(B, 1, 1)
    cy = cy.view(B, 1, 1)

    # Standard pinhole back-projection: X = (u - cx) * Z / fx, Y = (v - cy) * Z / fy
    x = (u - cx) * z / fx  # (B, H, W)
    y = (v - cy) * z / fy  # (B, H, W)

    # Stack coordinates (B, H, W, 3)
    points = torch.stack([x, y, z], dim=-1)

    if squeeze_output:
        points = points.squeeze(0)  # (H, W, 3)

    return points
third_party/lingbot_depth/mdm/model/v2.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from numbers import Number
3
+ from functools import partial
4
+ from pathlib import Path
5
+ import warnings
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.utils
11
+ import torch.utils.checkpoint
12
+ import torch.amp
13
+ import torch.version
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ from .modules_rgbd_encoder import DINOv2_RGBD_Encoder
17
+ from .modules_decoder import MLP, ConvStack
18
+ from ..utils.geo import depth_to_pointcloud, normalized_view_plane_uv
19
+
20
+
21
class MDMModel(nn.Module):
    """Monocular depth model: DINOv2 RGB-D encoder + convolutional neck and task heads.

    Consumes an RGB image together with an input depth map and predicts a
    refined depth map ('depth_reg'), plus optionally a validity mask
    (and normals, if a normal head were present).
    """
    # Class-level attribute annotations. NOTE(review): `points_head`,
    # `scale_head` and `onnx_compatible_mode` are declared but never assigned
    # in this version; forward() guards every optional head with hasattr().
    encoder: DINOv2_RGBD_Encoder
    neck: ConvStack
    points_head: ConvStack
    mask_head: ConvStack
    scale_head: MLP
    onnx_compatible_mode: bool

    def __init__(self,
        encoder: Dict[str, Any],
        neck: Dict[str, Any],
        depth_head: Dict[str, Any] = None,
        mask_head: Dict[str, Any] = None,
        normal_head: Dict[str, Any] = None,
        scale_head: Dict[str, Any] = None,
        remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
        remap_depth_in: Literal['linear', 'log'] = 'log',
        remap_depth_out: Literal['linear', 'exp'] = 'exp',
        num_tokens_range: List[int] = [1200, 3600],  # NOTE(review): mutable default — safe only while never mutated
        **deprecated_kwargs
    ):
        """Build sub-modules from their keyword-argument config dicts.

        NOTE(review): `normal_head` and `scale_head` configs are accepted but
        no module is constructed for them here, so forward() never produces
        'normal' output or a metric scale in this version — confirm intended.
        """
        super(MDMModel, self).__init__()
        if deprecated_kwargs:
            warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")

        self.remap_output = remap_output          # stored but not read in this file
        self.num_tokens_range = num_tokens_range  # [min, max] token budget used by infer()
        self.remap_depth_in = remap_depth_in      # remapping applied to the input depth before encoding
        self.remap_depth_out = remap_depth_out    # mapping from head output back to depth values

        self.encoder = DINOv2_RGBD_Encoder(**encoder)

        self.neck = ConvStack(**neck)
        if depth_head is not None:
            self.depth_head = ConvStack(**depth_head)
        if mask_head is not None:
            self.mask_head = ConvStack(**mask_head)

    @property
    def device(self) -> torch.device:
        """Device of the model parameters (first parameter's device)."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the model parameters (first parameter's dtype)."""
        return next(self.parameters()).dtype

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path, IO[bytes]],
        model_kwargs: Optional[Dict[str, Any]] = None,
        **hf_kwargs) -> 'MDMModel':
        """Load a model from a local checkpoint path or a Hugging Face Hub repo id.

        The checkpoint is a torch file with 'model_config' (constructor kwargs)
        and 'model' (state dict). `model_kwargs` entries override the stored config.
        """
        if Path(pretrained_model_name_or_path).exists():
            checkpoint_path = pretrained_model_name_or_path
        else:
            checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="model.pt",
                **hf_kwargs
            )
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)

        model_config = checkpoint['model_config']
        if model_kwargs is not None:
            model_config.update(model_kwargs)
        model = cls(**model_config)
        # strict=False: tolerate heads present in the checkpoint but absent here (and vice versa).
        model.load_state_dict(checkpoint['model'], strict=False)

        return model

    def init_weights(self):
        """Initialize encoder weights (delegates to the encoder)."""
        self.encoder.init_weights()

    def enable_pytorch_native_sdpa(self):
        """Switch the encoder's attention to PyTorch-native scaled_dot_product_attention."""
        self.encoder.enable_pytorch_native_sdpa()

    def forward(self,
        image: torch.Tensor,
        num_tokens: Union[int, torch.LongTensor],
        depth: Union[None, torch.Tensor]=None,
        **kwargs) -> Dict[str, torch.Tensor]:
        """Run the full model.

        Args:
            image: (B, 3, H, W) input image.
            num_tokens: target number of encoder tokens; converted to a
                (base_h, base_w) token grid preserving the image aspect ratio.
            depth: (B, H, W) or (B, 1, H, W) input depth — required in this version.

        Returns:
            Dict with 'depth_reg' (B, H, W) and, when the heads exist,
            'normal' (B, H, W, 3, unit vectors) and 'mask' (B, H, W, probabilities).
        """
        batch_size, _, img_h, img_w = image.shape
        device, dtype = image.device, image.dtype

        assert depth is not None  # in this version, depth is required
        if depth.dim() == 3:
            depth = depth.unsqueeze(1)  # from (B, H, W) to (B, 1, H, W)

        # Derive a token grid whose product ≈ num_tokens and whose h:w ratio matches the image.
        aspect_ratio = img_w / img_h
        base_h, base_w = (num_tokens / aspect_ratio) ** 0.5, (num_tokens * aspect_ratio) ** 0.5
        if isinstance(base_h, torch.Tensor):
            base_h, base_w = base_h.round().long(), base_w.round().long()
        else:
            base_h, base_w = round(base_h), round(base_w)

        # Backbones encoding
        features, cls_token, _, _ = self.encoder(image, depth, base_h, base_w, return_class_token=True, remap_depth_in=self.remap_depth_in, **kwargs)

        # Broadcast-add the class token to every spatial location, then arrange
        # a 5-level pyramid (only level 0 carries encoder features here).
        features = features + cls_token[..., None, None]
        features = [features, None, None, None, None]

        # Concat UVs for aspect ratio input
        for level in range(5):
            uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
            if features[level] is None:
                features[level] = uv
            else:
                features[level] = torch.concat([features[level], uv], dim=1)

        # Shared neck
        features = self.neck(features)

        # Heads decoding: each head is optional; missing heads yield None.
        depth_reg, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['depth_head', 'normal_head', 'mask_head'])
        metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None

        # Resize all dense outputs back to the input resolution.
        depth_reg, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [depth_reg, normal, mask])

        # Remap output
        if depth_reg is not None:
            if self.remap_depth_out == 'exp':
                depth_reg = depth_reg.exp().squeeze(1)  # head predicts log-depth
            elif self.remap_depth_out == 'linear':
                depth_reg = depth_reg.squeeze(1)
            else:
                raise ValueError(f"Invalid remap_depth_out: {self.remap_depth_out}")
        if normal is not None:
            normal = normal.permute(0, 2, 3, 1)
            normal = F.normalize(normal, dim=-1)
        if mask is not None:
            mask_prob = mask.squeeze(1).sigmoid()
        else:
            mask_prob = None
        if metric_scale is not None:
            metric_scale = metric_scale.squeeze(1).exp()

        return_dict = {
            'depth_reg': depth_reg,
            'normal': normal,
            'mask': mask_prob,
        }
        return_dict = {k: v for k, v in return_dict.items() if v is not None}

        return return_dict

    @torch.inference_mode()
    def infer(
        self,
        image: torch.Tensor,
        depth_in: torch.Tensor = None,
        num_tokens: int = None,
        resolution_level: int = 9,
        apply_mask: bool = True,
        use_fp16: bool = True,
        intrinsics: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Inference wrapper: handles batching, autocast and post-processing.

        Args:
            image: (3, H, W) or (B, 3, H, W) image tensor.
            depth_in: (H, W) or batched input depth.
                NOTE(review): device/dtype conversion is applied only on the
                2-D path; already-batched depth is forwarded as-is — confirm intended.
            num_tokens: explicit token budget; if None, interpolated from
                `resolution_level` (0..9) within self.num_tokens_range.
            apply_mask: replace masked-out pixels with +inf in depth/points.
            use_fp16: run the forward pass under bfloat16 autocast.
            intrinsics: optional normalized (3, 3) intrinsics; when given,
                a camera-space point cloud is also returned.

        Returns:
            Dict with 'depth', 'mask' and (if intrinsics given) 'points';
            the batch dim is squeezed when the input was unbatched.
        """
        if image.dim() == 3:
            omit_batch_dim = True
            image = image.unsqueeze(0)
        else:
            omit_batch_dim = False
        image = image.to(dtype=self.dtype, device=self.device)

        if (depth_in is not None) and (depth_in.dim() == 2):
            depth_in = depth_in.unsqueeze(0).to(dtype=self.dtype, device=self.device)

        original_height, original_width = image.shape[-2:]
        area = original_height * original_width
        aspect_ratio = original_width / original_height

        # Determine the number of base tokens to use
        if num_tokens is None:
            min_tokens, max_tokens = self.num_tokens_range
            num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))

        # Forward pass
        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=use_fp16 and self.dtype != torch.bfloat16):
            output = self.forward(image, num_tokens=num_tokens, depth=depth_in, **kwargs)
        depth_reg, mask = (output.get(k, None) for k in ['depth_reg', 'mask'])

        # Always process the output in fp32 precision
        depth_reg, mask = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [depth_reg, mask])
        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
            if mask is not None:
                mask_binary = mask > 0.5
            else:
                mask_binary = None

            depth = depth_reg
            if intrinsics is not None:
                points = depth_to_pointcloud(depth, intrinsics)
            else:
                points = None

            # Apply mask: invalid pixels become +inf (not dropped).
            if apply_mask and mask_binary is not None:
                points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None
                depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None

        return_dict = {
            'points': points,
            'depth': depth,
            'mask': mask_binary,
        }
        return_dict = {k: v for k, v in return_dict.items() if v is not None}

        if omit_batch_dim:
            return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}

        return return_dict

    def forward_feat(self,
        image: torch.Tensor,
        num_tokens: Union[int, torch.LongTensor],
        depth: Union[None, torch.Tensor]=None,
        **kwargs) -> Dict[str, torch.Tensor]:
        """Encoder-only forward: returns raw backbone features and the class token
        (same token-grid computation as forward(), no neck/heads)."""
        batch_size, _, img_h, img_w = image.shape
        device, dtype = image.device, image.dtype

        assert depth is not None  # in this version, depth is required
        if depth.dim() == 3:
            depth = depth.unsqueeze(1)  # from (B, H, W) to (B, 1, H, W)

        aspect_ratio = img_w / img_h
        base_h, base_w = (num_tokens / aspect_ratio) ** 0.5, (num_tokens * aspect_ratio) ** 0.5
        if isinstance(base_h, torch.Tensor):
            base_h, base_w = base_h.round().long(), base_w.round().long()
        else:
            base_h, base_w = round(base_h), round(base_w)

        # Backbones encoding
        features, cls_token, _, _ = self.encoder(image, depth, base_h, base_w, return_class_token=True, remap_depth_in=self.remap_depth_in, **kwargs)

        return features, cls_token


    @torch.inference_mode()
    def infer_feat(
        self,
        image: torch.Tensor,
        depth_in: torch.Tensor = None,
        num_tokens: int = None,
        resolution_level: int = 9,
        apply_mask: bool = True,
        use_fp16: bool = True,
        intrinsics: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Inference wrapper around forward_feat(); mirrors infer()'s batching,
        token-budget and autocast handling but returns (features, cls_token).
        `apply_mask`/`intrinsics` are accepted for signature parity but unused here.
        """
        if image.dim() == 3:
            omit_batch_dim = True
            image = image.unsqueeze(0)
        else:
            omit_batch_dim = False
        image = image.to(dtype=self.dtype, device=self.device)

        if (depth_in is not None) and (depth_in.dim() == 2):
            depth_in = depth_in.unsqueeze(0).to(dtype=self.dtype, device=self.device)

        original_height, original_width = image.shape[-2:]
        area = original_height * original_width
        aspect_ratio = original_width / original_height

        # Determine the number of base tokens to use
        if num_tokens is None:
            min_tokens, max_tokens = self.num_tokens_range
            num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))

        # Forward pass
        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=use_fp16 and self.dtype != torch.bfloat16):
            features, cls_token = self.forward_feat(image, num_tokens=num_tokens, depth=depth_in, **kwargs)

        return features, cls_token
third_party/lingbot_depth/mdm/utils/__init__.py ADDED
File without changes
third_party/lingbot_depth/mdm/utils/geo.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
    if aspect_ratio is None:
        aspect_ratio = width / height

    # Half-extent of the view plane along each axis, normalized by the diagonal.
    diagonal = (1 + aspect_ratio ** 2) ** 0.5
    span_x = aspect_ratio / diagonal
    span_y = 1 / diagonal

    # Sample pixel centers: endpoints are pulled in by half a pixel pitch.
    u_extent = span_x * (width - 1) / width
    v_extent = span_y * (height - 1) / height
    u = torch.linspace(-u_extent, u_extent, width, dtype=dtype, device=device)
    v = torch.linspace(-v_extent, v_extent, height, dtype=dtype, device=device)

    grid_u, grid_v = torch.meshgrid(u, v, indexing='xy')
    return torch.stack([grid_u, grid_v], dim=-1)
16
+
17
def depth_to_pointcloud(depth, intrinsic_normalized, depth_scale=1.0):
    """
    Back-project a depth map into a camera-space point cloud (no filtering).

    Args:
        depth: torch.Tensor of shape (H, W) or (B, H, W).
        intrinsic_normalized: torch.Tensor of shape (3, 3) or (B, 3, 3) holding
            resolution-normalized intrinsics (fx/W, fy/H, cx/W, cy/H).
        depth_scale: divisor applied to raw depth values (default 1.0).

    Returns:
        torch.Tensor of shape (H, W, 3) or (B, H, W, 3) with (x, y, z) coordinates.
    """
    batched = depth.dim() != 2
    if not batched:
        depth = depth[None]                                  # (1, H, W)
        intrinsic_normalized = intrinsic_normalized[None]    # (1, 3, 3)

    B, H, W = depth.shape
    dev = depth.device

    # Recover pixel-unit intrinsics, shaped (B, 1, 1) for broadcasting.
    fx = (intrinsic_normalized[:, 0, 0] * W).view(B, 1, 1)
    fy = (intrinsic_normalized[:, 1, 1] * H).view(B, 1, 1)
    cx = (intrinsic_normalized[:, 0, 2] * W).view(B, 1, 1)
    cy = (intrinsic_normalized[:, 1, 2] * H).view(B, 1, 1)

    # Pixel grid: rows index v (height), columns index u (width).
    rows, cols = torch.meshgrid(
        torch.arange(H, device=dev, dtype=torch.float32),
        torch.arange(W, device=dev, dtype=torch.float32),
        indexing='ij',
    )
    cols = cols[None].expand(B, -1, -1)
    rows = rows[None].expand(B, -1, -1)

    # Pinhole back-projection.
    z = depth / depth_scale
    x = (cols - cx) * z / fx
    y = (rows - cy) * z / fy

    pts = torch.stack([x, y, z], dim=-1)  # (B, H, W, 3)
    return pts if batched else pts.squeeze(0)
77
+
78
+
79
# Usage example (runs only when this module is executed directly)
if __name__ == "__main__":
    # Single image: (H, W) depth with a normalized 3x3 intrinsic matrix
    # (focal lengths divided by W/H, principal point divided by W/H).
    depth = torch.rand(480, 640) * 5000  # Depth values
    intrinsic_norm = torch.tensor([
        [525.0/640, 0, 319.5/640],
        [0, 525.0/480, 239.5/480],
        [0, 0, 1]
    ])

    points = depth_to_pointcloud(depth, intrinsic_norm)
    print(f"Point cloud shape: {points.shape}")  # (480, 640, 3)

    # Batch processing: share the same intrinsics across the batch.
    depth_batch = torch.rand(4, 480, 640) * 5000
    intrinsic_batch = intrinsic_norm.unsqueeze(0).expand(4, -1, -1)

    points_batch = depth_to_pointcloud(depth_batch, intrinsic_batch)
    print(f"Batch point cloud shape: {points_batch.shape}")  # (4, 480, 640, 3)

    # Flatten to (N, 3) format if needed
    points_flat = points.reshape(-1, 3)
    print(f"Flattened shape: {points_flat.shape}")  # (480*640, 3)

    # Batch flatten to (B, N, 3)
    points_batch_flat = points_batch.reshape(4, -1, 3)
    print(f"Batch flattened shape: {points_batch_flat.shape}")  # (4, 480*640, 3)
third_party/lingbot_depth/mdm/utils/io.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
3
+ from typing import IO
4
+ import zipfile
5
+ import json
6
+ import io
7
+ from typing import *
8
+ from pathlib import Path
9
+ import re
10
+ from PIL import Image, PngImagePlugin
11
+
12
+ import numpy as np
13
+ import cv2
14
+
15
+ from .tools import timeit
16
+
17
+
18
def save_glb(
    save_path: Union[str, os.PathLike],
    vertices: np.ndarray,
    faces: np.ndarray,
    vertex_uvs: np.ndarray,
    texture: np.ndarray,
    vertex_normals: Optional[np.ndarray] = None,
):
    """Export a UV-textured triangle mesh to a GLB file via trimesh."""
    import trimesh
    import trimesh.visual
    from PIL import Image

    material = trimesh.visual.material.PBRMaterial(
        baseColorTexture=Image.fromarray(texture),
        metallicFactor=0.5,
        roughnessFactor=1.0
    )
    visual = trimesh.visual.texture.TextureVisuals(uv=vertex_uvs, material=material)
    mesh = trimesh.Trimesh(
        vertices=vertices,
        vertex_normals=vertex_normals,
        faces=faces,
        visual=visual,
        process=False  # keep vertex order/attributes exactly as given
    )
    mesh.export(save_path)
44
+
45
+
46
def save_ply(
    save_path: Union[str, os.PathLike],
    vertices: np.ndarray,
    faces: np.ndarray,
    vertex_colors: np.ndarray,
    vertex_normals: Optional[np.ndarray] = None,
):
    """Export a per-vertex-colored triangle mesh to a PLY file via trimesh.

    Args:
        save_path: output file path.
        vertices: (N, 3) vertex positions.
        faces: (M, 3) triangle indices.
        vertex_colors: (N, 3) or (N, 4) per-vertex colors.
        vertex_normals: optional (N, 3) per-vertex normals.
    """
    # Imported lazily so trimesh is only required when this helper is used.
    # (Removed the unused `from PIL import Image` of the original.)
    import trimesh
    import trimesh.visual

    trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=vertex_colors,
        vertex_normals=vertex_normals,
        process=False  # keep vertex order/attributes exactly as given
    ).export(save_path)
64
+
65
+
66
def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray:
    """
    Read a image, return uint8 RGB array of shape (H, W, 3).
    """
    raw = Path(path).read_bytes() if isinstance(path, (str, os.PathLike)) else path.read()
    # Decode to BGR (OpenCV default), then convert to RGB for the caller.
    bgr = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
76
+
77
+
78
def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95):
    """
    Write a image, input uint8 RGB array of shape (H, W, 3).
    Encodes as JPEG with the given quality.
    """
    bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode('.jpg', bgr, [cv2.IMWRITE_JPEG_QUALITY, quality])
    payload = encoded.tobytes()
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_bytes(payload)
    else:
        path.write(payload)
87
+
88
+
89
def read_depth(path: Union[str, os.PathLike, IO]) -> np.ndarray:
    """
    Read a depth image, return float32 depth array of shape (H, W).

    Inverse of `write_depth`: 16-bit PNG codes are mapped back through a
    logarithmic scale between the 'near'/'far' values stored as PNG text
    metadata. Code 0 decodes to NaN (unknown) and 65535 to +inf.
    """
    if isinstance(path, (str, os.PathLike)):
        data = Path(path).read_bytes()
    else:
        data = path.read()
    pil_image = Image.open(io.BytesIO(data))
    # Metadata written by write_depth; float(None) raises if fields are absent.
    near = float(pil_image.info.get('near'))
    far = float(pil_image.info.get('far'))
    depth = np.array(pil_image)
    mask_nan, mask_inf = depth == 0, depth == 65535
    # Codes 1..65534 -> t in [0, 1], then exponential interpolation near^(1-t) * far^t.
    depth = (depth.astype(np.float32) - 1) / 65533
    depth = near ** (1 - depth) * far ** depth
    if 'unit' in pil_image.info:  # Legacy support for depth units
        unit = float(pil_image.info.get('unit'))
        depth = depth * unit
    depth[mask_nan] = np.nan
    depth[mask_inf] = np.inf
    return depth
110
+
111
def write_depth(
    path: Union[str, os.PathLike, IO],
    depth: np.ndarray,
    max_range: float = 1e5,
    compression_level: int = 7,
):
    """
    Encode and write a depth image as 16-bit PNG format.

    ## Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to write to.
    - `depth: np.ndarray`
        The depth array, float32 array of shape (H, W).
        May contain `NaN` for invalid values and `Inf` for infinite values.
    - `max_range: float`
        Maximum far/near ratio; far is clamped to `near * max_range`.
    - `compression_level: int`
        PNG compression level.

    Depth values are encoded as follows:
    - 0: unknown (NaN)
    - 1 ~ 65534: depth values on a logarithmic scale between near and far
    - 65535: infinity

    Metadata is stored in the PNG file as text fields:
    - `near`: the minimum depth value
    - `far`: the maximum depth value
    """
    mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth), np.isinf(depth)

    depth = depth.astype(np.float32)
    # (Removed the dead assignment `mask_finite = depth` from the original.)
    near = max(depth[mask_values].min(), 1e-5)
    far = max(near * 1.1, min(depth[mask_values].max(), near * max_range))
    # Map depth logarithmically into codes 1..65534.
    depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16)
    depth[mask_nan] = 0
    depth[mask_inf] = 65535

    pil_image = Image.fromarray(depth)
    pnginfo = PngImagePlugin.PngInfo()
    pnginfo.add_text('near', str(near))
    pnginfo.add_text('far', str(far))
    pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
150
+
151
+
152
def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]:
    """
    Read a segmentation mask.

    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to read from.

    ### Returns:
    - `Tuple[np.ndarray, Dict[str, int]]`
        - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W).
        - `labels`: the {label_name: label_id} mapping stored in the PNG
          metadata, or None when absent.
    """
    raw = Path(path).read_bytes() if isinstance(path, (str, os.PathLike)) else path.read()
    pil_image = Image.open(io.BytesIO(raw))
    labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None
    return np.array(pil_image), labels
172
+
173
+
174
def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7):
    """
    Write a segmentation mask and label mapping, as PNG format.

    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to write to.
    - `mask: np.ndarray`
        The segmentation mask, uint8 or uint16 array of shape (H, W).
    - `labels: Dict[str, int] = None`
        The label mapping, stored as compact JSON in a PNG text field.
    - `compression_level: int = 7`
        The compression level for PNG compression.
    """
    assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}"
    pnginfo = PngImagePlugin.PngInfo()
    if labels is not None:
        compact_json = json.dumps(labels, ensure_ascii=True, separators=(',', ':'))
        pnginfo.add_text('labels', compact_json)
    Image.fromarray(mask).save(path, pnginfo=pnginfo, compress_level=compression_level)
194
+
195
+
196
+
197
def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray:
    """
    Read a normal image, return float32 normal array of shape (H, W, 3).

    Inverse of `write_normal`: 16-bit RGB values are mapped back to [-1, 1]
    with the Y and Z components negated, then re-normalized to unit length.
    All-zero pixels decode to NaN (invalid).
    """
    if isinstance(path, (str, os.PathLike)):
        data = Path(path).read_bytes()
    else:
        data = path.read()
    normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
    mask_nan = np.all(normal == 0, axis=-1)
    normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
    # BUG FIX: the per-pixel norm has shape (H, W); keep a trailing axis so the
    # division broadcasts against the (H, W, 3) vectors instead of raising a
    # shape-mismatch error as the original did.
    norm = np.sqrt(np.sum(np.square(normal), axis=-1, keepdims=True)) + 1e-12
    normal = normal / norm
    normal[mask_nan] = np.nan
    return normal
211
+
212
+
213
def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> np.ndarray:
    """
    Write a normal image, input float32 normal array of shape (H, W, 3).
    Components are mapped to 16-bit PNG values with Y/Z negated; NaN pixels
    are stored as all-zero (see `read_normal`).
    """
    invalid = np.isnan(normal).any(axis=-1)
    encoded = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
    encoded[invalid] = 0
    png = cv2.imencode('.png', cv2.cvtColor(encoded, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_bytes(png)
    else:
        path.write(png)
225
+
226
+
227
def read_mask(path: Union[str, os.PathLike, IO[bytes]]) -> np.ndarray:
    """
    Read a binary mask, return bool array of shape (H, W).
    Any non-zero pixel counts as True; multi-channel images use channel 0.
    """
    raw = Path(path).read_bytes() if isinstance(path, (str, os.PathLike)) else path.read()
    decoded = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_UNCHANGED)
    if decoded.ndim == 3:
        decoded = decoded[..., 0]
    return decoded > 0
239
+
240
+
241
def write_mask(path: Union[str, os.PathLike, IO[bytes]], mask: np.ndarray, compression_level: int = 7):
    """
    Write a binary mask, input bool array of shape (H, W).
    True pixels are stored as 255, False as 0, in an 8-bit PNG.
    """
    assert mask.dtype == bool, f"Mask must be bool array, got {mask.dtype}"
    # bool -> uint8 {0, 255}; the original's second redundant astype is removed.
    mask_u8 = mask.astype(np.uint8) * 255
    data = cv2.imencode('.png', mask_u8, [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_bytes(data)
    else:
        path.write(data)
252
+
253
+
254
JSON_TYPE = Union[str, int, float, bool, None, Dict[str, "JSON"], List["JSON"]]


def read_json(path: Union[str, os.PathLike, IO[str]]) -> JSON_TYPE:
    """Load JSON from a file path or an open text stream."""
    text = Path(path).read_text() if isinstance(path, (str, os.PathLike)) else path.read()
    return json.loads(text)


def write_json(path: Union[str, os.PathLike, IO[str]], content: JSON_TYPE):
    """Serialize `content` as JSON to a file path or an open text stream."""
    text = json.dumps(content)
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_text(text)
    else:
        path.write(text)
third_party/lingbot_depth/mdm/utils/tools.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import time
3
+ from pathlib import Path
4
+ from numbers import Number
5
+ from functools import wraps
6
+ import warnings
7
+ import math
8
+ import json
9
+ import os
10
+ import importlib
11
+ import importlib.util
12
+
13
+
14
def catch_exception(fn):
    """Decorator: run `fn`, returning None (after printing the traceback) on any exception."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            import traceback
            # BUG FIX: the original used end='r', which appended a literal 'r'
            # to the message; print with a normal newline instead.
            print(f"Exception in {fn.__name__}")
            traceback.print_exc(chain=False)
            time.sleep(0.1)  # let the traceback flush before interleaved output
            return None
    return wrapper
27
+
28
+
29
class CallbackOnException:
    """Context manager that swallows a given exception type and invokes a callback instead.

    Other exception types propagate unchanged.
    """

    def __init__(self, callback: Callable, exception: type):
        self.exception = exception
        self.callback = callback

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning True suppresses the exception; only do so for the
        # configured type, after running the callback.
        handled = isinstance(exc_val, self.exception)
        if handled:
            self.callback()
        return handled
42
+
43
def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
    """Yield the key-path tuple of every leaf (non-dict) value in a nested dict."""
    for key, value in d.items():
        if isinstance(value, dict):
            yield from ((key,) + tail for tail in traverse_nested_dict_keys(value))
        else:
            yield (key,)
50
+
51
+
52
def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
    """Walk `keys` into nested dict `d`; stop early and return `default` when a level is missing."""
    node = d
    for key in keys:
        node = node.get(key, default)
        if node is None:
            break
    return node
58
+
59
def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
    """Set `value` at key-path `keys` in nested dict `d`, creating intermediate dicts as needed."""
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value
63
+
64
+
65
def key_average(list_of_dicts: list) -> Dict[str, Any]:
    """
    Returns a dictionary with the average value of each key in the input list of dictionaries.
    Missing and NaN entries are skipped; a key with no valid values averages to NaN.
    """
    # Collect every leaf key-path that occurs in any record.
    all_paths = set()
    for record in list_of_dicts:
        all_paths.update(traverse_nested_dict_keys(record))

    averaged: Dict[str, Any] = {}
    for path in sorted(all_paths):
        valid = []
        for record in list_of_dicts:
            value = get_nested_dict(record, path)
            if value is not None and not math.isnan(value):
                valid.append(value)
        mean = sum(valid) / len(valid) if valid else float('nan')
        set_nested_dict(averaged, path, mean)
    return averaged
83
+
84
+
85
def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
    """
    Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
    """
    prefix = () if parent_key is None else parent_key
    flat: Dict[Tuple[str, ...], Any] = {}
    for key, value in d.items():
        path = prefix + (key,)
        if isinstance(value, MutableMapping):
            flat.update(flatten_nested_dict(value, path))
        else:
            flat[path] = value
    return flat
99
+
100
+
101
def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """
    Unflattens a single-level dictionary into a nested dictionary, with keys as tuples.
    Inverse of `flatten_nested_dict`.
    """
    nested: Dict[str, Any] = {}
    for key_path, value in d.items():
        node = nested
        for part in key_path[:-1]:
            node = node.setdefault(part, {})
        node[key_path[-1]] = value
    return nested
114
+
115
+
116
def read_jsonl(file):
    """Read a JSON-Lines file into a list of parsed objects (one per line)."""
    import json
    with open(file, 'r') as f:
        return [json.loads(line) for line in f]
121
+
122
+
123
def write_jsonl(data: List[dict], file):
    """Write a list of dicts as JSON-Lines (one JSON object per line)."""
    import json
    lines = [json.dumps(item) + '\n' for item in data]
    with open(file, 'w') as f:
        f.writelines(lines)
128
+
129
+
130
def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
    """Build a pandas DataFrame with hierarchical (MultiIndex) columns from nested dict records.

    NOTE(review): the annotation suggests tuple keys, but each record is
    flattened here first, so plain nested dicts are the actual expected input.
    """
    import pandas as pd
    # Flatten each record so every leaf key becomes a tuple path usable as a MultiIndex entry.
    data = [flatten_nested_dict(d) for d in data]
    df = pd.DataFrame(data)
    # Sort columns so related sub-keys group together before building the MultiIndex.
    df = df.sort_index(axis=1)
    df.columns = pd.MultiIndex.from_tuples(df.columns)
    return df
137
+
138
+
139
def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
    """Recursively apply substring replacements to every string in a nested
    list/dict structure. Containers are modified in place; other leaf types
    are returned unchanged."""
    if isinstance(d, str):
        for old, new in mapping.items():
            d = d.replace(old, new)
        return d
    if isinstance(d, list):
        for index in range(len(d)):
            d[index] = recursive_replace(d[index], mapping)
        return d
    if isinstance(d, dict):
        for key in d:
            d[key] = recursive_replace(d[key], mapping)
        return d
    return d
150
+
151
+
152
class timeit:
    """Wall-clock timer usable as a context manager or a decorator.

    With average=True, completed runs are appended to a class-level history
    keyed by `name`, and the running average is reported instead of the
    single-run time.
    """
    # Shared across all instances: name -> list of completed timeit objects
    # (populated only when average=True).
    _history: Dict[str, List['timeit']] = {}

    def __init__(self, name: str = None, verbose: bool = True, average: bool = False):
        self.name = name
        self.verbose = verbose
        self.start = None
        self.end = None
        self.average = average
        if average and name not in timeit._history:
            timeit._history[name] = []

    def __call__(self, func: Callable):
        """Decorator form; supports both sync and async callables.

        NOTE(review): the inner `timeit(...)` is built with default flags, so
        `verbose`/`average` set on the decorator instance are not forwarded —
        confirm intended.
        """
        import inspect
        if inspect.iscoroutinefunction(func):
            async def wrapper(*args, **kwargs):
                with timeit(self.name or func.__qualname__):
                    ret = await func(*args, **kwargs)
                return ret
            return wrapper
        else:
            def wrapper(*args, **kwargs):
                with timeit(self.name or func.__qualname__):
                    ret = func(*args, **kwargs)
                return ret
            return wrapper

    def __enter__(self):
        self.start = time.time()
        return self

    @property
    def time(self) -> float:
        """Elapsed seconds of a completed run."""
        assert self.start is not None, "Time not yet started."
        assert self.end is not None, "Time not yet ended."
        return self.end - self.start

    @property
    def average_time(self) -> float:
        """Mean elapsed time over all recorded runs with this name (average=True only)."""
        assert self.average, "Average time not available."
        return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])

    @property
    def history(self) -> List['timeit']:
        """All recorded runs sharing this timer's name (empty if none recorded)."""
        return timeit._history.get(self.name, [])

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.time()
        # Record before reporting so the just-finished run is included in the average.
        if self.average:
            timeit._history[self.name].append(self)
        if self.verbose:
            if self.average:
                avg = self.average_time
                print(f"{self.name or 'It'} took {avg:.6f} seconds in average.")
            else:
                print(f"{self.name or 'It'} took {self.time:.6f} seconds.")
209
+
210
def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
    """Remove the longest common prefix and suffix shared by all strings.

    Fixes over the original index-scanning version:
    - no IndexError when some string is shorter than the first one;
    - identical strings (or a single string) now strip to empty strings
      instead of leaving residue from loops that completed without `break`.

    The suffix is capped so it never overlaps the prefix in the shortest string.
    """
    prefix_len = len(os.path.commonprefix(strings))
    suffix_len = len(os.path.commonprefix([s[::-1] for s in strings]))
    # Prevent the prefix and suffix from overlapping in the shortest string.
    shortest = min(len(s) for s in strings)
    suffix_len = max(0, min(suffix_len, shortest - prefix_len))
    return [s[prefix_len:len(s) - suffix_len] for s in strings]
222
+
223
+
224
def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
    """Decorator factory: run the decorated function over `inputs` in a thread pool.

    Usage:
        @multithead_execute(items, num_workers=8)
        def process(item): ...

    The decorated function is executed immediately for its side effects; the
    decorator returns None. A tqdm progress bar is created unless one is
    passed in. Each call is wrapped so exceptions are printed and swallowed
    (via catch_exception / suppress_traceback defined in this module).

    NOTE(review): executor.map results are never consumed, so return values
    are discarded; combined with catch_exception, per-item failures are only
    logged — confirm intended.
    """
    from concurrent.futures import ThreadPoolExecutor
    from contextlib import nullcontext
    from tqdm import tqdm

    # Reuse the caller's progress bar if given, otherwise create one sized to the inputs.
    if pbar is not None:
        pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
    else:
        pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)

    def decorator(fn: Callable):
        with (
            ThreadPoolExecutor(max_workers=num_workers) as executor,
            pbar
        ):
            pbar.refresh()
            @catch_exception
            @suppress_traceback
            def _fn(input):
                ret = fn(input)
                pbar.update()
                return ret
            executor.map(_fn, inputs)
            executor.shutdown(wait=True)

    return decorator
252
def suppress_traceback(fn):
    """Decorator: re-raise exceptions from `fn` with up to two innermost
    wrapper frames trimmed from the traceback.

    BUG FIX: the original did `e.__traceback__.tb_next.tb_next` unconditionally,
    which raises AttributeError when the traceback has fewer than two extra
    frames; this version advances frame-by-frame and stops safely.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            tb = e.__traceback__
            for _ in range(2):
                if tb is not None and tb.tb_next is not None:
                    tb = tb.tb_next
            e.__traceback__ = tb
            raise
    return wrapper
261
+
262
+
263
class no_warnings:
    """Apply a warnings filter (default: ignore everything) as a decorator or context manager."""

    def __init__(self, action: str = 'ignore', **kwargs):
        self.action = action
        self.filter_kwargs = kwargs

    def __call__(self, fn):
        """Decorator form: run `fn` inside a fresh filtered-warnings scope."""
        @wraps(fn)
        def wrapper(*args, **kwargs):
            with type(self)(self.action, **self.filter_kwargs):
                return fn(*args, **kwargs)
        return wrapper

    def __enter__(self):
        # Snapshot the current filter state so __exit__ can restore it.
        self.warnings_manager = warnings.catch_warnings()
        self.warnings_manager.__enter__()
        warnings.simplefilter(self.action, **self.filter_kwargs)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
284
+
285
def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str):
    """Load a Python source file and return it as a module object named `module_name`."""
    module_spec = importlib.util.spec_from_file_location(module_name, file_path)
    loaded = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded
third_party/lingbot_depth/mdm/utils/vis.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ import numpy as np
4
+ import matplotlib
5
+ import trimesh
6
+ import random
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import os
10
+
11
def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
    """Render a depth map as an RGB uint8 image via inverse depth (disparity).

    Pixels with non-positive depth (and, if given, pixels outside ``mask``)
    become NaN and map to black. When ``normalize`` is set, disparity is
    rescaled by its robust quantile range before colorizing.
    """
    valid = depth > 0 if mask is None else (depth > 0) & mask
    disp = 1 / np.where(valid, depth, np.nan)
    if normalize:
        # NOTE(review): upper quantile is 0.99 here, while the sibling
        # colorizers use 0.999 — presumably intentional headroom; confirm.
        lo = np.nanquantile(disp, 0.001)
        hi = np.nanquantile(disp, 0.99)
        disp = (disp - lo) / (hi - lo)
    rgb = matplotlib.colormaps[cmap](1.0 - disp)[..., :3]
    rgb = np.nan_to_num(rgb, copy=False)  # NaNs (invalid pixels) -> 0 -> black
    return np.ascontiguousarray((rgb.clip(0, 1) * 255).astype(np.uint8))
25
+
26
+
27
def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray:
    """Render depth known only up to an affine transform as RGB uint8.

    The map is rescaled by its robust (0.1%/99.9%) quantile range and fed
    to the colormap directly (no inversion); masked-out pixels become NaN
    and render black.
    """
    if mask is not None:
        depth = np.where(mask, depth, np.nan)
    lo = np.nanquantile(depth, 0.001)
    hi = np.nanquantile(depth, 0.999)
    scaled = (depth - lo) / (hi - lo)
    rgb = np.nan_to_num(matplotlib.colormaps[cmap](scaled)[..., :3], copy=False)
    return np.ascontiguousarray((rgb.clip(0, 1) * 255).astype(np.uint8))
36
+
37
+
38
def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
    """Render a disparity map as an RGB uint8 image.

    Optionally rescales by the robust (0.1%/99.9%) quantile range; the
    colormap input is flipped (``1 - disparity``) so near is warm. Masked
    pixels become NaN and render black.
    """
    if mask is not None:
        disparity = np.where(mask, disparity, np.nan)
    if normalize:
        lo = np.nanquantile(disparity, 0.001)
        hi = np.nanquantile(disparity, 0.999)
        disparity = (disparity - lo) / (hi - lo)
    rgb = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity)[..., :3], copy=False)
    return np.ascontiguousarray((rgb.clip(0, 1) * 255).astype(np.uint8))
48
+
49
+
50
def colorize_normal(normal: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
    """Map surface normals in [-1, 1]^3 to an RGB uint8 image.

    X keeps its sign while Y and Z are flipped before shifting into
    [0, 1] (presumably a camera-convention change, e.g. OpenGL vs
    OpenCV — confirm against callers). Masked pixels are zeroed first,
    so they render mid-gray (127, 127, 127).
    """
    if mask is not None:
        normal = np.where(mask[..., None], normal, 0)
    shifted = normal * np.array([0.5, -0.5, -0.5]) + 0.5
    return (shifted.clip(0, 1) * 255).astype(np.uint8)
56
+
57
+
58
def colorize_error_map(error_map: np.ndarray, mask: np.ndarray = None, cmap: str = 'plasma', value_range: Tuple[float, float] = None):
    """Render per-pixel error magnitudes with a sequential colormap.

    ``value_range`` pins (vmin, vmax); otherwise the map's own NaN-aware
    min/max are used. Values are clipped to [0, 1] after rescaling and
    masked-out pixels render black.
    """
    if value_range is not None:
        vmin, vmax = value_range
    else:
        vmin, vmax = np.nanmin(error_map), np.nanmax(error_map)
    normalized = ((error_map - vmin) / (vmax - vmin)).clip(0, 1)
    # Keep the colormap object in its own name instead of shadowing `cmap`.
    colormap = matplotlib.colormaps[cmap]
    colored = colormap(normalized)[..., :3]
    if mask is not None:
        colored = np.where(mask[..., None], colored, 0)
    return np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
third_party/lingbot_depth/pyproject.toml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "mdm"
7
+ version = "1.0.0"
8
+ readme = "README.md"
9
+ dependencies = [
10
+ "click",
11
+ "opencv-python",
12
+ "scipy",
13
+ "matplotlib",
14
+ "trimesh",
15
+ "pillow",
16
+ "huggingface_hub",
17
+ "numpy",
18
+ "torch==2.6.0",
19
+ "torchvision",
20
+ "xformers==v0.0.29.post2",
21
+ ]
22
+ requires-python = ">=3.9"
23
+
24
+ [tool.setuptools.packages.find]
25
+ where = ["."]
26
+ include = ["mdm*"]
third_party/sam3/pyproject.toml ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "sam3"
7
+ dynamic = ["version"]
8
+ description = "SAM3 (Segment Anything Model 3) implementation"
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ license = {file = "LICENSE"}
12
+ authors = [
13
+ {name = "Meta AI Research"}
14
+ ]
15
+ classifiers = [
16
+ "Development Status :: 4 - Beta",
17
+ "Intended Audience :: Science/Research",
18
+ "License :: OSI Approved :: MIT License",
19
+ "Programming Language :: Python :: 3",
20
+ "Programming Language :: Python :: 3.8",
21
+ "Programming Language :: Python :: 3.9",
22
+ "Programming Language :: Python :: 3.10",
23
+ "Programming Language :: Python :: 3.11",
24
+ "Programming Language :: Python :: 3.12",
25
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
26
+ ]
27
+ dependencies = [
28
+ "timm>=1.0.17",
29
+ "numpy>=1.26,<2",
30
+ "tqdm",
31
+ "ftfy==6.1.1",
32
+ "regex",
33
+ "iopath>=0.1.10",
34
+ "typing_extensions",
35
+ "huggingface_hub",
36
+ ]
37
+
38
+ [project.optional-dependencies]
39
+ dev = [
40
+ "pytest",
41
+ "pytest-cov",
42
+ "black==24.2.0",
43
+ "ufmt==2.8.0",
44
+ "ruff-api==0.1.0",
45
+ "usort==1.0.2",
46
+ "gitpython==3.1.31",
47
+ "yt-dlp",
48
+ "pandas",
49
+ "opencv-python",
50
+ "pycocotools",
51
+ "numba",
52
+ "python-rapidjson",
53
+ ]
54
+ notebooks = [
55
+ "matplotlib",
56
+ "jupyter",
57
+ "notebook",
58
+ "ipywidgets",
59
+ "ipycanvas",
60
+ "ipympl",
61
+ "pycocotools",
62
+ "decord",
63
+ "opencv-python",
64
+ "einops",
65
+ "scikit-image",
66
+ "scikit-learn",
67
+ ]
68
+ train = [
69
+ "hydra-core",
70
+ "submitit",
71
+ "tensorboard",
72
+ "zstandard",
73
+ "scipy",
74
+ "torchmetrics",
75
+ "fvcore",
76
+ "fairscale",
77
+ "scikit-image",
78
+ "scikit-learn",
79
+ ]
80
+
81
+ [project.urls]
82
+ "Homepage" = "https://github.com/facebookresearch/sam3"
83
+ "Bug Tracker" = "https://github.com/facebookresearch/sam3/issues"
84
+
85
+ [tool.setuptools.packages.find]
86
+ include = ["sam3*"]
87
+ exclude = ["build*", "scripts*", "examples*"]
88
+
89
+ [tool.setuptools.package-data]
90
+ sam3 = ["assets/*.txt.gz"]
91
+
92
+ [tool.setuptools.dynamic]
93
+ version = {attr = "sam3.__version__"}
94
+
95
+ [tool.black]
96
+ line-length = 88
97
+ target-version = ['py38', 'py39', 'py310', 'py311', 'py312']
98
+ include = '\.pyi?$'
99
+
100
+ [tool.isort]
101
+ profile = "black"
102
+ multi_line_output = 3
103
+
104
+ [tool.usort]
105
+ first_party_detection = false
106
+
107
+ [tool.ufmt]
108
+ formatter = "ruff-api"
109
+
110
+ [tool.mypy]
111
+ python_version = "3.12"
112
+ warn_return_any = true
113
+ warn_unused_configs = true
114
+ disallow_untyped_defs = true
115
+ disallow_incomplete_defs = true
116
+
117
+ [[tool.mypy.overrides]]
118
+ module = [
119
+ "torch.*",
120
+ "torchvision.*",
121
+ "timm.*",
122
+ "numpy.*",
123
+ "PIL.*",
124
+ "tqdm.*",
125
+ "ftfy.*",
126
+ "regex.*",
127
+ "iopath.*",
128
+ ]
129
+ ignore_missing_imports = true
130
+
131
+ [tool.pytest.ini_options]
132
+ testpaths = ["tests"]
133
+ python_files = "test_*.py"
134
+ python_classes = "Test*"
135
+ python_functions = "test_*"
third_party/sam3/sam3/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ from .model_builder import build_sam3_image_model, build_sam3_predictor
6
+
7
+ __version__ = "0.1.0"
8
+
9
+ __all__ = ["build_sam3_image_model", "build_sam3_predictor"]
third_party/sam3/sam3/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (453 Bytes). View file
 
third_party/sam3/sam3/__pycache__/logger.cpython-311.pyc ADDED
Binary file (3.67 kB). View file
 
third_party/sam3/sam3/__pycache__/model_builder.cpython-311.pyc ADDED
Binary file (42.9 kB). View file
 
third_party/sam3/sam3/agent/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
third_party/sam3/sam3/agent/agent_core.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ import copy
6
+ import json
7
+ import os
8
+
9
+ import cv2
10
+ from PIL import Image
11
+
12
+ from .client_llm import send_generate_request
13
+ from .client_sam3 import call_sam_service
14
+ from .viz import visualize
15
+
16
+
17
def save_debug_messages(messages_list, debug, debug_folder_path, debug_jsonl_path):
    """Persist the conversation history as JSON Lines when debugging is on.

    No-op unless both ``debug`` is truthy and ``debug_jsonl_path`` is set.
    The folder is created on demand and the file is rewritten from scratch
    on every call.

    Fix: the old code wrote ``json.dumps(msg, indent=4)`` records back to
    back, producing a file that is neither valid JSON nor valid JSONL.
    One compact object per line keeps it machine-parseable, matching the
    ``jsonl`` in the parameter name (NOTE(review): callers still pass a
    ``.json`` filename — consider renaming it to ``.jsonl``).
    """
    if debug and debug_jsonl_path:
        # Ensure the debug directory exists before writing
        os.makedirs(debug_folder_path, exist_ok=True)
        with open(debug_jsonl_path, "w") as f:
            for msg in messages_list:
                f.write(json.dumps(msg) + "\n")
25
+
26
+
27
def cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path):
    """Remove the debug history file and its folder after a successful run.

    Does nothing unless debugging was enabled and a folder path was set.
    Any filesystem error (e.g. a non-empty folder) is reported as a
    warning rather than raised, so cleanup can never fail the caller.
    """
    if not (debug and debug_folder_path):
        return
    try:
        # File first, then the (now hopefully empty) directory.
        for path, remove in ((debug_jsonl_path, os.remove), (debug_folder_path, os.rmdir)):
            if os.path.exists(path):
                remove(path)
    except Exception as e:
        print(f"Warning: Could not clean up debug files: {e}")
37
+
38
+
39
def count_images(messages):
    """Count the total number of images present in the messages history.

    Only messages whose ``content`` is a list are scanned; each content
    item that is a dict with ``type == "image"`` counts once. String
    contents (e.g. system prompts) contribute zero.
    """
    total = 0
    for message in messages:
        content = message.get("content")
        if not isinstance(content, list):
            continue
        total += sum(
            1
            for item in content
            if isinstance(item, dict) and item.get("type") == "image"
        )
    return total
54
+
55
+
56
+ def _prune_messages_for_next_round(
57
+ messages_list,
58
+ used_text_prompts,
59
+ latest_sam3_text_prompt,
60
+ img_path,
61
+ initial_text_prompt,
62
+ ):
63
+ """Return a new messages list that contains only:
64
+ 1) messages[:2] (with optional warning text added to the second message's content)
65
+ 2) the latest assistant message (and everything after it) that contains a segment_phrase tool call
66
+ """
67
+ # There should not be more than 10 messages in the conversation history
68
+ assert len(messages_list) < 10
69
+
70
+ # Part 1: always keep the first two message JSONs
71
+ part1 = copy.deepcopy(messages_list[:2])
72
+
73
+ # Part 2: search backwards for the latest assistant message containing a segment_phrase tool call
74
+ part2_start_idx = None
75
+ for idx in range(len(messages_list) - 1, 1, -1):
76
+ msg = messages_list[idx]
77
+ # We only consider assistant messages with a "content" list
78
+ if msg.get("role") != "assistant" or "content" not in msg:
79
+ continue
80
+ # Look for any content element that is a text containing the segment_phrase tool call
81
+ for content in msg["content"]:
82
+ if (
83
+ isinstance(content, dict)
84
+ and content.get("type") == "text"
85
+ and "<tool>" in content.get("text", "")
86
+ and "segment_phrase" in content.get("text", "")
87
+ ):
88
+ part2_start_idx = idx
89
+ break
90
+ if part2_start_idx is not None:
91
+ break
92
+
93
+ part2 = messages_list[part2_start_idx:] if part2_start_idx is not None else []
94
+
95
+ # Part 3: decide whether to add warning text to the second message in part1
96
+ previously_used = (
97
+ [p for p in used_text_prompts if p != latest_sam3_text_prompt]
98
+ if latest_sam3_text_prompt
99
+ else list(used_text_prompts)
100
+ )
101
+ if part2 and len(previously_used) > 0:
102
+ warning_text = f'Note that we have previously called the segment_phrase tool with each "text_prompt" in this list: {list(previously_used)}, but none of the generated results were satisfactory. So make sure that you do not use any of these phrases as the "text_prompt" to call the segment_phrase tool again.'
103
+ # Replace the second message entirely to keep exactly 2 content items
104
+ part1[1] = {
105
+ "role": "user",
106
+ "content": [
107
+ {"type": "image", "image": img_path},
108
+ {
109
+ "type": "text",
110
+ "text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'."
111
+ + " "
112
+ + warning_text,
113
+ },
114
+ ],
115
+ }
116
+ assert len(part1[1]["content"]) == 2
117
+
118
+ # Build the new messages list: part1 (with optional warning), then part2
119
+ new_messages = list(part1)
120
+ new_messages.extend(part2)
121
+ return new_messages
122
+
123
+
124
def agent_inference(
    img_path: str,
    initial_text_prompt: str,
    debug: bool = False,
    send_generate_request=send_generate_request,
    call_sam_service=call_sam_service,
    max_generations: int = 100,
    output_dir="../../sam3_agent_out",
):
    """
    Given a text prompt and an image, this tool will perform all aspects of agentic problem solving,
    while saving sam3 and MLLM outputs to their respective directories.

    Args:
        img_path: Path to the input image
        initial_text_prompt: Initial text prompt from the user
        debug: Whether to enable debug mode
        send_generate_request: Callable that sends a messages list to the MLLM
        call_sam_service: Callable that runs SAM3 for a text prompt on an image
        max_generations: Maximum number of send_generate_request calls allowed (default: 100)
        output_dir: Root folder for SAM3/agent artifacts

    Returns:
        (messages, final_outputs, rendered_final_output) on success.

    Raises:
        ValueError: On malformed tool calls, exceeded generation budget, or
            a None MLLM response (history is dumped to ``none_out`` first).
    """
    # setup dir
    sam_output_dir = os.path.join(output_dir, "sam_out")
    error_save_dir = os.path.join(output_dir, "none_out")
    debug_save_dir = os.path.join(output_dir, "agent_debug_out")
    os.makedirs(sam_output_dir, exist_ok=True)
    os.makedirs(error_save_dir, exist_ok=True)
    os.makedirs(debug_save_dir, exist_ok=True)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    MLLM_SYSTEM_PROMPT_PATH = os.path.join(
        current_dir, "system_prompts/system_prompt.txt"
    )
    ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH = os.path.join(
        current_dir, "system_prompts/system_prompt_iterative_checking.txt"
    )
    # init variables
    PATH_TO_LATEST_OUTPUT_JSON = ""
    LATEST_SAM3_TEXT_PROMPT = ""
    USED_TEXT_PROMPTS = (
        set()
    )  # Track all previously used text prompts for segment_phrase
    generation_count = 0  # Counter for number of send_generate_request calls

    # debug setup
    debug_folder_path = None
    debug_jsonl_path = None
    if debug:
        debug_folder_path = os.path.join(
            debug_save_dir, f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}"
        )
        debug_jsonl_path = os.path.join(debug_folder_path, "debug_history.json")
        os.makedirs(debug_folder_path, exist_ok=True)

    # The helper functions are now defined outside the agent_inference function
    with open(MLLM_SYSTEM_PROMPT_PATH, "r") as f:
        system_prompt = f.read().strip()
    with open(ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH, "r") as f:
        iterative_checking_system_prompt = f.read().strip()

    # Construct the initial message list
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {
                    "type": "text",
                    "text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'.",
                },
            ],
        },
    ]
    print(f"> Text prompt: {initial_text_prompt}")
    print(f"> Image path: {img_path}")

    print("\n\n")
    print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30)
    print("\n\n")
    generated_text = send_generate_request(messages)
    print(f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n")
    while generated_text is not None:
        save_debug_messages(messages, debug, debug_folder_path, debug_jsonl_path)
        # FIX: this used to be `assert (condition, message)` — an assertion on
        # a non-empty tuple, which is always true and never fired.
        assert "<tool>" in generated_text, (
            f"Generated text does not contain <tool> tag: {generated_text}"
        )
        generated_text = generated_text.split("</tool>", 1)[0] + "</tool>"
        tool_call_json_str = (
            generated_text.split("<tool>")[-1]
            .split("</tool>")[0]
            .strip()
            .replace(r"}}}", r"}}")  # remove extra } if any
        )
        try:
            tool_call = json.loads(tool_call_json_str)
        except json.JSONDecodeError:
            raise ValueError(f"Invalid JSON in tool call: {tool_call_json_str}")

        if PATH_TO_LATEST_OUTPUT_JSON == "":
            # The first tool call must be segment_phrase or report_no_mask
            assert (
                tool_call["name"] == "segment_phrase"
                or tool_call["name"] == "report_no_mask"
            )

        if tool_call["name"] == "segment_phrase":
            print("🔍 Calling segment_phrase tool...")
            assert list(tool_call["parameters"].keys()) == ["text_prompt"]

            # Check if this text_prompt has been used before
            current_text_prompt = tool_call["parameters"]["text_prompt"]
            if current_text_prompt in USED_TEXT_PROMPTS:
                print(
                    f"❌ Text prompt '{current_text_prompt}' has been used before. Requesting a different prompt."
                )
                duplicate_prompt_message = f"You have previously used '{current_text_prompt}' as your text_prompt to call the segment_phrase tool. You may not use it again. Please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase prompt, while adhering to all the rules stated in the system prompt. You must also never use any of the following text_prompt(s): {str(list(USED_TEXT_PROMPTS))}."
                messages.append(
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                )
                messages.append(
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": duplicate_prompt_message}],
                    }
                )
            else:
                # Add the text_prompt to the set of used prompts
                USED_TEXT_PROMPTS.add(current_text_prompt)
                LATEST_SAM3_TEXT_PROMPT = current_text_prompt
                PATH_TO_LATEST_OUTPUT_JSON = call_sam_service(
                    image_path=img_path,
                    text_prompt=current_text_prompt,
                    output_folder_path=sam_output_dir,
                )
                # FIX: close the JSON file instead of leaking the handle.
                with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as f:
                    sam3_outputs = json.load(f)
                sam3_output_image_path = sam3_outputs["output_image_path"]
                num_masks = len(sam3_outputs["pred_boxes"])

                messages.append(
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                )
                if num_masks == 0:
                    print("❌ No masks generated by SAM3, reporting no mask to Qwen.")
                    sam3_output_text_message = f"The segment_phrase tool did not generate any masks for the text_prompt '{current_text_prompt}'. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt. Please be reminded that the original user query was '{initial_text_prompt}'."
                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": sam3_output_text_message}
                            ],
                        }
                    )
                else:
                    sam3_output_text_message = rf"The segment_phrase tool generated {num_masks} available masks. All {num_masks} available masks are rendered in this image below, now you must analyze the {num_masks} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action. Please be reminded that the original user query was '{initial_text_prompt}'."
                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": sam3_output_text_message},
                                {"type": "image", "image": sam3_output_image_path},
                            ],
                        }
                    )
                print("\n\n>>> sam3_output_text_message:\n", sam3_output_text_message)

        elif tool_call["name"] == "examine_each_mask":
            print("🔍 Calling examine_each_mask tool...")
            assert LATEST_SAM3_TEXT_PROMPT != ""

            # Make sure that the last message is a image
            assert (
                messages[-1]["content"][1]["type"] == "image"
            ), "Second content element should be an image"
            messages.pop()  # Remove the last user message
            # Add simplified replacement message
            simplified_message = {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "The segment_phrase tool generated several masks. Now you must analyze the mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
                    }
                ],
            }
            messages.append(simplified_message)

            # FIX: close the JSON file instead of leaking the handle.
            with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as f:
                current_outputs = json.load(f)
            num_masks = len(current_outputs["pred_masks"])
            masks_to_keep = []

            # MLLM check the mask one by one
            for i in range(num_masks):
                print(f"🔍 Checking mask {i + 1}/{num_masks}...")
                image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i)

                image_w_zoomed_in_mask_i_path = os.path.join(
                    sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_")
                ).replace(".png", f"_zoom_in_mask_{i + 1}.png")
                image_w_mask_i_path = os.path.join(
                    sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_")
                ).replace(".png", f"_selected_mask_{i + 1}.png")
                image_w_zoomed_in_mask_i.save(image_w_zoomed_in_mask_i_path)
                image_w_mask_i.save(image_w_mask_i_path)

                iterative_checking_messages = [
                    {"role": "system", "content": iterative_checking_system_prompt},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": f"The raw input image: "},
                            {"type": "image", "image": img_path},
                            {
                                "type": "text",
                                "text": f"The initial user input query is: '{initial_text_prompt}'",
                            },
                            {
                                "type": "text",
                                "text": f"Image with the predicted segmentation mask rendered on it: ",
                            },
                            {"type": "image", "image": image_w_mask_i_path},
                            {
                                "type": "text",
                                "text": f"Image with the zoomed-in mask: ",
                            },
                            {"type": "image", "image": image_w_zoomed_in_mask_i_path},
                        ],
                    },
                ]
                checking_generated_text = send_generate_request(
                    iterative_checking_messages
                )

                # Process the generated text to determine if the mask should be kept or rejected
                if checking_generated_text is None:
                    raise ValueError(
                        "Generated text is None, which is unexpected. Please check the Qwen server and the input parameters."
                    )
                print(f"Generated text for mask {i + 1}: {checking_generated_text}")
                verdict = (
                    checking_generated_text.split("<verdict>")[-1]
                    .split("</verdict>")[0]
                    .strip()
                )
                if "Accept" in verdict:
                    assert "Reject" not in verdict
                    print(f"Mask {i + 1} accepted, keeping it in the outputs.")
                    masks_to_keep.append(i)
                elif "Reject" in verdict:
                    assert "Accept" not in verdict
                    print(f"Mask {i + 1} rejected, removing it from the outputs.")
                else:
                    raise ValueError(
                        f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'."
                    )

            updated_outputs = {
                "original_image_path": current_outputs["original_image_path"],
                "orig_img_h": current_outputs["orig_img_h"],
                "orig_img_w": current_outputs["orig_img_w"],
                "pred_boxes": [current_outputs["pred_boxes"][i] for i in masks_to_keep],
                "pred_scores": [
                    current_outputs["pred_scores"][i] for i in masks_to_keep
                ],
                "pred_masks": [current_outputs["pred_masks"][i] for i in masks_to_keep],
            }

            image_w_check_masks = visualize(updated_outputs)
            image_w_check_masks_path = os.path.join(
                sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png"
            ).replace(
                ".png",
                f"_selected_masks_{'-'.join(map(str, [i + 1 for i in masks_to_keep]))}.png".replace(
                    "/", "_"
                ),
            )
            image_w_check_masks.save(image_w_check_masks_path)
            # save the updated json outputs and append to message history
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )
            if len(masks_to_keep) == 0:
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"The original user query was: '{initial_text_prompt}'. The examine_each_mask tool examined and rejected all of the masks generated by the segment_phrase tool. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt.",
                            }
                        ],
                    }
                )
            else:
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"The original user query was: '{initial_text_prompt}'. After calling the examine_each_mask tool on the available masks, the number of available masks is now {len(masks_to_keep)}. All {len(masks_to_keep)} available masks are rendered in this image below, now you must analyze the {len(masks_to_keep)} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
                            },
                            {"type": "image", "image": image_w_check_masks_path},
                        ],
                    }
                )

            # Create a new filename based on the original path to avoid filename length issues
            base_path = PATH_TO_LATEST_OUTPUT_JSON
            # Remove any existing "masks_" suffix to avoid duplication
            if "masks_" in base_path:
                base_path = base_path.split("masks_")[0] + ".json"
            # Create new filename with current masks; use a clearer suffix when empty
            if len(masks_to_keep) == 0:
                PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(
                    ".json", "masks_none.json"
                )
            else:
                PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(
                    ".json", f"masks_{'_'.join(map(str, masks_to_keep))}.json"
                )
            # FIX: close the JSON file instead of leaking the handle.
            with open(PATH_TO_LATEST_OUTPUT_JSON, "w") as f:
                json.dump(updated_outputs, f, indent=4)

        elif tool_call["name"] == "select_masks_and_return":
            print("🔍 Calling select_masks_and_return tool...")
            # FIX: close the JSON file instead of leaking the handle.
            with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as f:
                current_outputs = json.load(f)

            assert list(tool_call["parameters"].keys()) == ["final_answer_masks"]
            masks_to_keep = tool_call["parameters"]["final_answer_masks"]

            # Keep only valid mask indices, remove duplicates, and preserve deterministic ascending order
            available_masks = set(range(1, len(current_outputs["pred_masks"]) + 1))
            masks_to_keep = sorted({i for i in masks_to_keep if i in available_masks})
            # Change this to a update message telling the model to try again along with information about errors made.

            final_outputs = {
                "original_image_path": current_outputs["original_image_path"],
                "orig_img_h": current_outputs["orig_img_h"],
                "orig_img_w": current_outputs["orig_img_w"],
                "pred_boxes": [
                    current_outputs["pred_boxes"][i - 1] for i in masks_to_keep
                ],
                "pred_scores": [
                    current_outputs["pred_scores"][i - 1] for i in masks_to_keep
                ],
                "pred_masks": [
                    current_outputs["pred_masks"][i - 1] for i in masks_to_keep
                ],
            }

            rendered_final_output = visualize(final_outputs)
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )

            # Clean up debug files before successful return
            cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path)
            return messages, final_outputs, rendered_final_output

        elif tool_call["name"] == "report_no_mask":
            print("🔍 Calling report_no_mask tool...")
            height, width = cv2.imread(img_path).shape[:2]
            final_outputs = {
                "original_image_path": img_path,
                "orig_img_h": height,
                "orig_img_w": width,
                "pred_boxes": [],
                "pred_scores": [],
                "pred_masks": [],
            }
            rendered_final_output = Image.open(img_path)
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )
            # FIX: this successful return now also clears debug artifacts,
            # matching the select_masks_and_return path.
            cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path)
            return messages, final_outputs, rendered_final_output

        else:
            raise ValueError(f"Unknown tool call: {tool_call['name']}")

        # sometimes the MLLM don't know when to stop, and generates multiple tool calls in one round, so we need to split the generated text by </tool> and only keep the first one

        for message in messages:
            if message["role"] == "assistant" and "content" in message:
                for content in message["content"]:
                    if (
                        isinstance(content, dict)
                        and content.get("type") == "text"
                        and "text" in content
                    ):
                        content["text"] = (
                            content["text"].split("</tool>", 1)[0] + "</tool>\n\n"
                        )
        # Prune the messages history before the next MLLM generation round according to the 3-part rules.
        # This keeps history compact and ensures the model sees only the allowed parts.
        messages = _prune_messages_for_next_round(
            messages,
            USED_TEXT_PROMPTS,
            LATEST_SAM3_TEXT_PROMPT,
            img_path,
            initial_text_prompt,
        )
        # make sure there can never be more than 2 images in the context
        assert count_images(messages) <= 2
        generation_count += 1
        if generation_count > max_generations:
            raise ValueError(
                f"Exceeded maximum number of allowed generation requests ({max_generations})"
            )

        print("\n\n")
        print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30)
        print("\n\n")
        generated_text = send_generate_request(messages)
        print(
            f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n"
        )

    print("\n\n>>> SAM 3 Agent execution ended.\n\n")

    error_save_path = os.path.join(
        error_save_dir,
        f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}_error_history.json",
    )
    with open(error_save_path, "w") as f:
        json.dump(messages, f, indent=4)
    print("Saved messages history that caused error to:", error_save_path)
    raise ValueError(
        rf"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters for image path: {img_path} and initial text prompt: {initial_text_prompt}."
    )
third_party/sam3/sam3/agent/client_llm.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ import base64
6
+ import os
7
+ from typing import Any, Optional
8
+
9
+ from openai import OpenAI
10
+
11
+
12
def get_image_base64_and_mime(image_path):
    """Return (base64_string, mime_type) for the image at *image_path*.

    The MIME type is inferred from the file extension, defaulting to JPEG
    for unknown extensions. On any failure (unreadable or missing file) the
    error is printed and (None, None) is returned so callers can skip the
    image gracefully.
    """
    known_mime_types = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
    }
    try:
        extension = os.path.splitext(image_path)[1].lower()
        # Fall back to JPEG when the extension is not recognized.
        mime_type = known_mime_types.get(extension, "image/jpeg")

        with open(image_path, "rb") as image_file:
            raw_bytes = image_file.read()
        encoded = base64.b64encode(raw_bytes).decode("utf-8")
        return encoded, mime_type
    except Exception as e:
        print(f"Error converting image to base64: {e}")
        return None, None
34
+
35
+
36
def send_generate_request(
    messages,
    server_url=None,
    model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
    api_key=None,
    max_tokens=4096,
):
    """
    Send a chat-completion request to an OpenAI-compatible API endpoint.

    Any {"type": "image"} entries in user-message content are converted to
    base64 "image_url" entries before sending; images that cannot be read
    are skipped with a warning instead of failing the whole request.

    Args:
        messages (list): Message dicts (role/content) in OpenAI chat format.
        server_url (str, optional): Base URL of the server, e.g.
            "http://127.0.0.1:8000". None uses the OpenAI client default.
        model (str): Model name to use for generation.
        api_key (str, optional): API key for the endpoint; None falls back
            to the OpenAI client's default environment-based lookup.
        max_tokens (int): Maximum number of completion tokens to generate.

    Returns:
        str or None: The generated response text, or None when the request
        fails or the response contains no choices.
    """
    # Process messages to convert image paths to base64 image_url entries.
    processed_messages = []
    for message in messages:
        processed_message = message.copy()
        if message["role"] == "user" and "content" in message:
            processed_content = []
            for c in message["content"]:
                if isinstance(c, dict) and c.get("type") == "image":
                    image_path = c["image"]
                    print("image_path", image_path)
                    # Escape "?" in the path so it is not treated as a query.
                    new_image_path = image_path.replace("?", "%3F")

                    try:
                        base64_image, mime_type = get_image_base64_and_mime(
                            new_image_path
                        )
                        if base64_image is None:
                            print(
                                f"Warning: Could not convert image to base64: {new_image_path}"
                            )
                            continue

                        # Proper image_url structure carrying base64 data.
                        processed_content.append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{mime_type};base64,{base64_image}",
                                    "detail": "high",
                                },
                            }
                        )
                    except FileNotFoundError:
                        print(f"Warning: Image file not found: {new_image_path}")
                        continue
                    except Exception as e:
                        print(f"Warning: Error processing image {new_image_path}: {e}")
                        continue
                else:
                    processed_content.append(c)

            processed_message["content"] = processed_content
        processed_messages.append(processed_message)

    # Create an OpenAI client pointed at the (possibly custom) base URL.
    client = OpenAI(api_key=api_key, base_url=server_url)

    try:
        print(f"🔍 Calling model {model}...")
        response = client.chat.completions.create(
            model=model,
            messages=processed_messages,
            max_completion_tokens=max_tokens,
            n=1,
        )

        # Extract the response content; None signals an empty/odd response.
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content
        else:
            print(f"Unexpected response format: {response}")
            return None

    except Exception as e:
        print(f"Request failed: {e}")
        return None
128
+
129
+
130
def send_direct_request(
    llm: Any,
    messages: list[dict[str, Any]],
    sampling_params: Any,
) -> Optional[str]:
    """
    Run inference directly on a vLLM model instance (no server round-trip).

    User-message image entries are converted to base64 "image_url" entries
    before being handed to the model; unreadable images are skipped with a
    warning.

    Args:
        llm: Initialized vLLM LLM instance (created by the caller)
        messages: Message dicts with role and content (OpenAI format)
        sampling_params: vLLM SamplingParams instance (created externally)

    Returns:
        str: Generated response text, or None if inference fails
    """
    try:

        def _as_image_url(entry):
            # Convert an {"type": "image"} entry to a base64 image_url
            # entry; return None when the image cannot be read.
            new_image_path = entry["image"].replace("?", "%3F")
            try:
                encoded, mime = get_image_base64_and_mime(new_image_path)
                if encoded is None:
                    print(f"Warning: Could not convert image: {new_image_path}")
                    return None
                # vLLM expects the OpenAI-style image_url structure.
                return {
                    "type": "image_url",
                    "image_url": {"url": f"data:{mime};base64,{encoded}"},
                }
            except Exception as e:
                print(f"Warning: Error processing image {new_image_path}: {e}")
                return None

        converted_messages = []
        for original in messages:
            rewritten = original.copy()
            if original["role"] == "user" and "content" in original:
                new_content = []
                for entry in original["content"]:
                    if isinstance(entry, dict) and entry.get("type") == "image":
                        replacement = _as_image_url(entry)
                        if replacement is not None:
                            new_content.append(replacement)
                    else:
                        new_content.append(entry)
                rewritten["content"] = new_content
            converted_messages.append(rewritten)

        print("🔍 Running direct inference with vLLM...")

        outputs = llm.chat(
            messages=converted_messages,
            sampling_params=sampling_params,
        )

        # First candidate of the first output is the generated text.
        if outputs:
            return outputs[0].outputs[0].text
        print(f"Unexpected output format: {outputs}")
        return None

    except Exception as e:
        print(f"Direct inference failed: {e}")
        return None
third_party/sam3/sam3/agent/client_sam3.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ import json
6
+ import os
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from sam3.model.box_ops import box_xyxy_to_xywh
11
+ from sam3.train.masks_ops import rle_encode
12
+
13
+ from .helpers.mask_overlap_removal import remove_overlapping_masks
14
+ from .viz import visualize
15
+
16
+
17
def sam3_inference(processor, image_path, text_prompt):
    """Run SAM 3 image inference for a text prompt and package the outputs.

    Returns a dict with the original image size, xywh boxes normalized to
    [0, 1], RLE-encoded masks (their "counts" strings), and per-prediction
    scores.
    """
    image = Image.open(image_path)
    img_w, img_h = image.size

    # Model inference: register the image, then prompt with text.
    state = processor.set_image(image)
    state = processor.set_text_prompt(state=state, prompt=text_prompt)

    # Normalize xyxy box coordinates into [0, 1] by the image size.
    x0, y0, x1, y1 = state["boxes"].unbind(-1)
    boxes_xyxy = torch.stack(
        [x0 / img_w, y0 / img_h, x1 / img_w, y1 / img_h],
        dim=-1,
    )
    boxes_xywh = box_xyxy_to_xywh(boxes_xyxy).tolist()

    # RLE-encode the masks and keep only the compact "counts" strings.
    encoded_masks = [m["counts"] for m in rle_encode(state["masks"].squeeze(1))]

    return {
        "orig_img_h": img_h,
        "orig_img_w": img_w,
        "pred_boxes": boxes_xywh,
        "pred_masks": encoded_masks,
        "pred_scores": state["scores"].tolist(),
    }
49
+
50
+
51
def _sort_predictions_by_score(response):
    """Reorder pred_scores/pred_boxes/pred_masks in lockstep, highest score first.

    No-op when "pred_scores" is absent or empty. Mutates and returns *response*.
    """
    scores = response.get("pred_scores")
    if not scores:
        return response
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    for key in ("pred_scores", "pred_boxes", "pred_masks"):
        response[key] = [response[key][i] for i in order]
    return response


def _drop_invalid_masks(response):
    """Drop predictions whose RLE "counts" string is too short (<= 4 chars) to
    be a valid encoding, keeping masks/boxes/scores aligned. Mutates and
    returns *response*.
    """
    kept = [
        (mask, box, score)
        for mask, box, score in zip(
            response["pred_masks"], response["pred_boxes"], response["pred_scores"]
        )
        if len(mask) > 4
    ]
    response["pred_masks"] = [m for m, _, _ in kept]
    response["pred_boxes"] = [b for _, b, _ in kept]
    response["pred_scores"] = [s for _, _, s in kept]
    return response


def call_sam_service(
    sam3_processor,
    image_path: str,
    text_prompt: str,
    output_folder_path: str = "sam3_output",
):
    """
    Run SAM 3 on an image with a text prompt and persist the results.

    The raw predictions are de-overlapped, sorted by score, filtered of
    degenerate masks, saved as JSON, and rendered as a visualization image.

    Args:
        sam3_processor: Initialized SAM 3 processor used for inference.
        image_path: Path to the input image.
        text_prompt: Text prompt describing what to segment.
        output_folder_path: Root folder for the JSON/PNG outputs.

    Returns:
        str: Path of the output JSON file. NOTE: this path is returned even
        when inference fails (the error is only printed), so callers should
        verify the file exists before reading it.
    """
    print(f"📞 Loading image '{image_path}' and sending with prompt '{text_prompt}'...")

    # "/" in the prompt would otherwise create nested directories.
    text_prompt_for_save_path = (
        text_prompt.replace("/", "_") if "/" in text_prompt else text_prompt
    )

    os.makedirs(
        os.path.join(output_folder_path, image_path.replace("/", "-")), exist_ok=True
    )
    output_json_path = os.path.join(
        output_folder_path,
        image_path.replace("/", "-"),
        rf"{text_prompt_for_save_path}.json",
    )
    output_image_path = os.path.join(
        output_folder_path,
        image_path.replace("/", "-"),
        rf"{text_prompt_for_save_path}.png",
    )

    try:
        # 1. Run inference and remove overlapping duplicate masks.
        serialized_response = sam3_inference(sam3_processor, image_path, text_prompt)
        serialized_response = remove_overlapping_masks(serialized_response)
        serialized_response = {
            "original_image_path": image_path,
            "output_image_path": output_image_path,
            **serialized_response,
        }

        # 2. Reorder predictions by score, highest to lowest.
        serialized_response = _sort_predictions_by_score(serialized_response)

        # 3. Remove invalid RLE masks that are too short to decode.
        serialized_response = _drop_invalid_masks(serialized_response)

        with open(output_json_path, "w") as f:
            json.dump(serialized_response, f, indent=4)
        print(f"✅ Raw JSON response saved to '{output_json_path}'")

        # 4. Render and save the visualization next to the JSON output.
        print("🔍 Rendering visualizations on the image ...")
        viz_image = visualize(serialized_response)
        os.makedirs(os.path.dirname(output_image_path), exist_ok=True)
        viz_image.save(output_image_path)
        print("✅ Saved visualization at:", output_image_path)
    except Exception as e:
        print(f"❌ Error calling service: {e}")

    return output_json_path
third_party/sam3/sam3/agent/helpers/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe