WildDet3D Gradio demo
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +3 -0
- __pycache__/vis3d_glb.cpython-311.pyc +0 -0
- app.py +822 -0
- assets/demo/intrinsics.npy +3 -0
- assets/demo/rgb.png +3 -0
- requirements.txt +59 -0
- third_party/lingbot_depth/mdm/model/__init__.py +15 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/__init__.py +6 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/__init__.py +4 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/backbones.py +162 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/utils.py +39 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/__init__.py +12 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/attention.py +100 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/block.py +259 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/dino_head.py +58 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/drop_path.py +34 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/layer_scale.py +27 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/mlp.py +40 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed.py +88 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed_mlp.py +153 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/swiglu_ffn.py +72 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/__init__.py +55 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/mask_utils.py +137 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/vision_transformer.py +479 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/__init__.py +4 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/cluster.py +95 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/config.py +72 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/dtype.py +37 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/param_groups.py +103 -0
- third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/utils.py +95 -0
- third_party/lingbot_depth/mdm/model/modules_decoder.py +185 -0
- third_party/lingbot_depth/mdm/model/modules_rgbd_encoder.py +152 -0
- third_party/lingbot_depth/mdm/model/utils.py +127 -0
- third_party/lingbot_depth/mdm/model/v2.py +297 -0
- third_party/lingbot_depth/mdm/utils/__init__.py +0 -0
- third_party/lingbot_depth/mdm/utils/geo.py +105 -0
- third_party/lingbot_depth/mdm/utils/io.py +270 -0
- third_party/lingbot_depth/mdm/utils/tools.py +289 -0
- third_party/lingbot_depth/mdm/utils/vis.py +65 -0
- third_party/lingbot_depth/pyproject.toml +26 -0
- third_party/sam3/pyproject.toml +135 -0
- third_party/sam3/sam3/__init__.py +9 -0
- third_party/sam3/sam3/__pycache__/__init__.cpython-311.pyc +0 -0
- third_party/sam3/sam3/__pycache__/logger.cpython-311.pyc +0 -0
- third_party/sam3/sam3/__pycache__/model_builder.cpython-311.pyc +0 -0
- third_party/sam3/sam3/agent/__init__.py +3 -0
- third_party/sam3/sam3/agent/agent_core.py +565 -0
- third_party/sam3/sam3/agent/client_llm.py +207 -0
- third_party/sam3/sam3/agent/client_sam3.py +139 -0
- third_party/sam3/sam3/agent/helpers/__init__.py +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/demo/rgb.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
third_party/sam3/sam3/model/__pycache__/video_tracking_multiplex.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
third_party/sam3/sam3/perflib/tests/assets/masks.tiff filter=lfs diff=lfs merge=lfs -text
|
__pycache__/vis3d_glb.cpython-311.pyc
ADDED
|
Binary file (28.6 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,822 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio Web Demo for WildDet3D (5-mode).
|
| 2 |
+
|
| 3 |
+
Supports 5 prompt modes:
|
| 4 |
+
- Text: Enter text like "chair.table" (one-to-many)
|
| 5 |
+
- Visual: Click box on image, text="visual" (one-to-many)
|
| 6 |
+
- Visual+Label: Click box + category label (one-to-many)
|
| 7 |
+
- Geometry: Click box on image, text="geometric" (one-to-one)
|
| 8 |
+
- Geometry+Label: Click box + category label (one-to-one)
|
| 9 |
+
- Point: Click on image to select point
|
| 10 |
+
|
| 11 |
+
Requirements:
|
| 12 |
+
pip install gradio>=5.0.0
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
python demo/huggingface/app.py
|
| 16 |
+
|
| 17 |
+
Then open http://localhost:7860 in browser.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
# Add paths: support both local dev and HuggingFace Space.
|
| 25 |
+
# Local dev: demo/huggingface/app.py -> repo root = ../../
|
| 26 |
+
# HF Space: wilddet3d/ is bundled in the same directory as app.py
|
| 27 |
+
_this_dir = Path(__file__).resolve().parent
|
| 28 |
+
if (_this_dir / "wilddet3d").exists():
|
| 29 |
+
# HuggingFace Space: everything bundled next to app.py
|
| 30 |
+
sys.path.insert(0, str(_this_dir))
|
| 31 |
+
else:
|
| 32 |
+
# Local dev: repo root is two levels up
|
| 33 |
+
repo_root = _this_dir.parent.parent
|
| 34 |
+
sys.path.insert(0, str(repo_root))
|
| 35 |
+
|
| 36 |
+
import spaces
|
| 37 |
+
import gradio as gr
|
| 38 |
+
import numpy as np
|
| 39 |
+
import torch
|
| 40 |
+
import cv2
|
| 41 |
+
from PIL import Image
|
| 42 |
+
|
| 43 |
+
from wilddet3d.inference import build_model, WildDet3DPredictor
|
| 44 |
+
from wilddet3d.preprocessing import preprocess
|
| 45 |
+
from wilddet3d.vis.visualize import draw_3d_boxes
|
| 46 |
+
from vis3d_glb import (
|
| 47 |
+
depth_to_pointcloud, create_scene_glb, create_mesh_scene_glb,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def draw_points_on_image(image, points, color=(0, 255, 0), radius=8):
    """Render click points onto a copy of *image*.

    Positive points (label == 1) are filled with ``color``; negative
    points are filled red. Every point gets a white outline ring so it
    stays visible on any background.

    Args:
        image: numpy array (H, W, 3)
        points: iterable of (x, y, label) tuples
        color: fill color for positive points (green default)
        radius: disc radius in pixels

    Returns:
        A new image array with the points drawn (input is not modified).
    """
    canvas = image.copy()
    negative_color = (255, 0, 0)
    for px, py, lbl in points:
        center = (int(px), int(py))
        fill = color if lbl == 1 else negative_color
        cv2.circle(canvas, center, radius, fill, -1)
        cv2.circle(canvas, center, radius + 2, (255, 255, 255), 2)
    return canvas
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def draw_box_on_image(image, box, color=(0, 0, 255), thickness=3):
    """Render an axis-aligned rectangle onto a copy of *image*.

    Args:
        image: numpy array (H, W, 3)
        box: [x1, y1, x2, y2] corner coordinates
        color: rectangle color (red default)
        thickness: outline thickness in pixels

    Returns:
        A new image array with the rectangle drawn (input is not modified).
    """
    canvas = image.copy()
    left, top, right, bottom = (int(v) for v in box)
    cv2.rectangle(canvas, (left, top), (right, bottom), color, thickness)
    return canvas
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# HuggingFace Model Hub repo + filename used as the checkpoint fallback
# when none of LOCAL_CHECKPOINTS exists (see _resolve_checkpoint).
HF_MODEL_REPO = "weikaih/WildDet3D"
HF_CKPT_NAME = "wilddet3d.pt"

# Local checkpoint paths, tried in order before falling back to the Hub.
LOCAL_CHECKPOINTS = [
    "ckpt/wilddet3d.pt",  # release repo layout
]

# Default demo assets shipped with the Space (RGB image + 3x3 intrinsics).
DEFAULT_IMAGE_PATH = "assets/demo/rgb.png"
DEFAULT_INTRINSICS_PATH = "assets/demo/intrinsics.npy"

# Process-wide model singleton, lazily populated by get_model().
_cached_model = None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _resolve_checkpoint():
    """Locate the model checkpoint.

    Prefers the first existing path from LOCAL_CHECKPOINTS; otherwise
    downloads HF_CKPT_NAME from the HF_MODEL_REPO Hub repository,
    passing the HF_TOKEN environment variable (if set) for auth.

    Returns:
        Filesystem path to the checkpoint file.
    """
    local = next((p for p in LOCAL_CHECKPOINTS if os.path.exists(p)), None)
    if local is not None:
        return local
    # Imported lazily so local runs never need huggingface_hub installed.
    from huggingface_hub import hf_hub_download
    print(f"Downloading checkpoint from {HF_MODEL_REPO}...")
    return hf_hub_download(
        repo_id=HF_MODEL_REPO,
        filename=HF_CKPT_NAME,
        token=os.environ.get("HF_TOKEN"),
    )
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def get_model():
    """Return the process-wide WildDet3D model, building it on first use.

    Subsequent calls reuse the cached instance, so the checkpoint is
    resolved and loaded exactly once per process.
    """
    global _cached_model
    if _cached_model is not None:
        return _cached_model

    ckpt_path = _resolve_checkpoint()
    print(f"Loading WildDet3D model from {ckpt_path}...")
    _cached_model = build_model(
        checkpoint=ckpt_path,
        score_threshold=0.0,
        canonical_rotation=True,
        skip_pretrained=True,
    )
    print("Model loaded!")
    return _cached_model
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def load_default_image():
    """Return the bundled demo image as a numpy array, or None if absent."""
    if not os.path.exists(DEFAULT_IMAGE_PATH):
        return None
    return np.array(Image.open(DEFAULT_IMAGE_PATH))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def load_default_intrinsics():
    """Return default camera intrinsics as (fx, fy, cx, cy).

    Reads the bundled 3x3 intrinsics matrix when present; otherwise
    falls back to hard-coded demo values.
    """
    if not os.path.exists(DEFAULT_INTRINSICS_PATH):
        return 518.86, 519.47, 325.58, 253.74
    K = np.load(DEFAULT_INTRINSICS_PATH)
    return float(K[0, 0]), float(K[1, 1]), float(K[0, 2]), float(K[1, 2])
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def format_intrinsics(K):
    """Render a 3x3 (or batched 1x3x3) intrinsics matrix as a short string.

    Accepts a torch tensor or numpy array. Returns "Not available" when
    *K* is None; otherwise "fx=..., fy=..., cx=..., cy=..." with two
    decimal places.
    """
    if K is None:
        return "Not available"
    mat = K.cpu().numpy() if isinstance(K, torch.Tensor) else K
    if mat.ndim == 3:
        # Batched input: report the first (and in this app, only) camera.
        mat = mat[0]
    fx, fy = mat[0, 0], mat[1, 1]
    cx, cy = mat[0, 2], mat[1, 2]
    return f"fx={fx:.2f}, fy={fy:.2f}, cx={cx:.2f}, cy={cy:.2f}"
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def scale_intrinsics_to_original(K, input_hw, original_hw):
    """Rescale intrinsics from model-input resolution to the original.

    Handles a single 3x3 matrix or a batched (N, 3, 3) stack, as either
    a torch tensor or numpy array. The input is never modified; a
    scaled copy is returned. None passes through as None.

    Args:
        K: intrinsics matrix/stack, or None
        input_hw: (H, W) of the model input
        original_hw: (H, W) of the original image

    Returns:
        Scaled copy of *K* (same type), or None.
    """
    if K is None:
        return None

    out = K.clone() if isinstance(K, torch.Tensor) else K.copy()

    sx = original_hw[1] / input_hw[1]
    sy = original_hw[0] / input_hw[0]

    if out.ndim == 3:
        out[:, 0, 0] *= sx
        out[:, 1, 1] *= sy
        out[:, 0, 2] *= sx
        out[:, 1, 2] *= sy
    else:
        out[0, 0] *= sx
        out[1, 1] *= sy
        out[0, 2] *= sx
        out[1, 2] *= sy

    return out
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def transform_coords_to_input_space(x, y, original_hw, input_hw, padding):
    """Map a pixel coordinate from the original image into the padded,
    resized model-input frame.

    The content region of the input is the input size minus padding;
    the original image is scaled into that region, then offset by the
    left/top pad.

    Args:
        x, y: coordinate in original image space
        original_hw: (H, W) of the original image
        input_hw: (H, W) of the preprocessed input (e.g., 1008x1008)
        padding: (pad_left, pad_right, pad_top, pad_bottom)

    Returns:
        (x, y) in model-input space.
    """
    pad_left, pad_right, pad_top, pad_bottom = padding

    content_w = input_hw[1] - pad_left - pad_right
    content_h = input_hw[0] - pad_top - pad_bottom

    mapped_x = pad_left + x * (content_w / original_hw[1])
    mapped_y = pad_top + y * (content_h / original_hw[0])
    return mapped_x, mapped_y
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def on_image_select(
    evt: gr.SelectData, image, original_image, state,
    prompt_mode, point_label,
):
    """Handle a click on the image and visualize the click.

    Depending on *prompt_mode*, a click adds a point prompt, records one
    corner of a box prompt, or is ignored (Text mode).

    Bug fix: previously, once a box had two corners, every further click
    appended a third entry and the box was rebuilt from the two OLD
    corners — the new click was silently discarded, so a new box could
    never be started without pressing Clear. Now a click after a
    completed box starts a fresh box.

    Args:
        evt: Gradio select event carrying the clicked (x, y) pixel.
        image: image currently shown in the component.
        original_image: pristine upload used as the drawing canvas.
        state: dict with "points" ([(x, y, label), ...]) and "box"
            (up to two (x, y) corner clicks).
        prompt_mode: "Point", "Box-to-Multi-Object",
            "Box-to-Single-Object", or a text mode.
        point_label: label choice string; containing "Positive" => label 1.

    Returns:
        (new_state, info message, annotated image).
    """
    if image is None:
        return state, "Please upload an image first", None

    x, y = evt.index[0], evt.index[1]
    label = 1 if "Positive" in point_label else 0

    # Shallow-copy the state so Gradio sees a new object and re-renders.
    new_state = {
        "points": list(state.get("points", [])),
        "box": list(state.get("box", [])),
    }

    # Always draw on the pristine upload so annotations don't accumulate.
    vis_image = (
        original_image.copy()
        if original_image is not None
        else image.copy()
    )

    if prompt_mode == "Point":
        new_state["points"].append((x, y, label))
        new_state["box"] = []
        label_str = "+" if label == 1 else "-"
        info = (
            f"Points: {len(new_state['points'])} total. "
            f"Last: ({x}, {y}) [{label_str}]"
        )
        vis_image = draw_points_on_image(vis_image, new_state["points"])

    elif prompt_mode in ("Box-to-Multi-Object", "Box-to-Single-Object"):
        new_state["points"] = []
        box_clicks = list(new_state.get("box", []))
        # A completed box already has two corners; this click starts a
        # fresh box instead of being discarded.
        if len(box_clicks) >= 2:
            box_clicks = []
        box_clicks.append((x, y))

        if len(box_clicks) == 1:
            new_state["box"] = box_clicks
            info = (
                f"[{prompt_mode}] Corner 1: ({x}, {y}) "
                f"- click again for corner 2"
            )
            vis_image = draw_points_on_image(vis_image, [(x, y, 1)])
        else:
            x1, y1 = box_clicks[0]
            x2, y2 = box_clicks[1]
            # Normalize so the stored box is always (min corner, max corner).
            box = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]
            new_state["box"] = [(box[0], box[1]), (box[2], box[3])]
            info = (
                f"[{prompt_mode}] Box: "
                f"({box[0]}, {box[1]}) -> ({box[2]}, {box[3]})"
            )
            vis_image = draw_box_on_image(vis_image, box)
    else:
        info = "Text mode - just enter text and click Run"

    return new_state, info, vis_image
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def clear_clicks(state, original_image):
    """Drop all recorded point/box clicks and restore the clean image.

    Args:
        state: previous click state (ignored; a fresh state is returned).
        original_image: pristine upload, or None.

    Returns:
        (empty click state, status message, copy of the original image
        or None).
    """
    restored = None if original_image is None else original_image.copy()
    empty_state = {"points": [], "box": []}
    return empty_state, "Cleared - ready for new clicks", restored
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@spaces.GPU
def run_wilddet3d(
    image,
    state,
    prompt_mode,
    text_prompt,
    use_label,
    label_text,
    score_thres,
    use_predicted_K,
    fx, fy, cx, cy,
    enable_3d_vis=True,
    remove_edges=True,
    point_density=2,
    use_textured_mesh=True,
):
    """Run WildDet3D on *image* using the selected prompt mode.

    Dispatches to the detector with text, box, or point prompts, then
    builds a 2D overlay, an intrinsics summary string, an optional GLB
    3D scene, and a colorized depth map.

    Args:
        image: uploaded RGB(A) numpy image, or None.
        state: click state dict with "points" and "box" entries.
        prompt_mode: "Text", "Box-to-Multi-Object",
            "Box-to-Single-Object", or "Point".
        text_prompt: "."-separated category list for Text mode.
        use_label: whether to append *label_text* to the box/point prompt.
        label_text: optional category label for box/point modes.
        score_thres: minimum detection score kept in visualizations.
        use_predicted_K: if True, let the model predict intrinsics.
        fx, fy, cx, cy: manual intrinsics (used when use_predicted_K is
            False).
        enable_3d_vis: build the GLB scene when True.
        remove_edges: drop depth-discontinuity edges from the 3D scene.
        point_density: point-cloud subsample step (non-mesh path only).
        use_textured_mesh: textured mesh vs point cloud for the 3D scene.

    Returns:
        (annotated 2D image, intrinsics info string, GLB file path or
        None, depth visualization PIL image or None). On invalid input
        the first slot is None and the second is an error message.
    """
    if image is None:
        return None, "Please upload an image first", None, None

    # Convert RGBA to RGB if needed (drop the alpha channel).
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    detector = get_model()

    # Build intrinsics matrix (or None to have the model predict one).
    if use_predicted_K:
        intrinsics = None
    else:
        intrinsics = np.array([
            [fx, 0, cx],
            [0, fy, cy],
            [0, 0, 1]
        ], dtype=np.float32)

    # Preprocess image (resize + pad to the model's input resolution;
    # also returns input/original sizes and padding for coord mapping).
    data = preprocess(image.astype(np.float32), intrinsics)

    # Build prompt_text for box/point modes: "visual" => one-to-many,
    # "geometric" => one-to-one matching.
    if prompt_mode == "Box-to-Multi-Object":
        prefix = "visual"
    elif prompt_mode == "Box-to-Single-Object":
        prefix = "geometric"
    else:
        prefix = "geometric"  # Point mode default
    
    if prompt_mode != "Text":
        if use_label and label_text and label_text.strip():
            geo_prompt_text = f"{prefix}: {label_text.strip()}"
        else:
            geo_prompt_text = prefix

    # Prompt info kept in ORIGINAL image coordinates for visualization.
    prompt_points = None
    prompt_box = None

    # Run based on prompt mode. In every branch the detector returns an
    # 8-tuple, unpacked identically below.
    if prompt_mode == "Text":
        # "." separates categories, e.g. "chair.table" -> two queries.
        input_texts = [
            t.strip() for t in text_prompt.split(".") if t.strip()
        ]
        if not input_texts:
            input_texts = ["object"]

        results = detector(
            images=data["images"].to(device),
            intrinsics=data["intrinsics"].to(device)[None],
            input_hw=[data["input_hw"]],
            original_hw=[data["original_hw"]],
            padding=[data["padding"]],
            input_texts=input_texts,
            return_predicted_intrinsics=True,
        )
        (
            boxes, boxes3d, scores, scores_2d, scores_3d,
            class_ids, depth_maps, predicted_K,
        ) = results
        class_id_mapping = {i: t for i, t in enumerate(input_texts)}

    elif prompt_mode in ("Box-to-Multi-Object", "Box-to-Single-Object"):
        box_coords = state.get("box", [])
        if len(box_coords) < 2:
            return (
                None,
                "Please click twice on the image to define a box",
                None,
                None,
            )

        # Map the clicked corners into the padded model-input frame.
        x1_orig, y1_orig = box_coords[0]
        x2_orig, y2_orig = box_coords[1]
        x1, y1 = transform_coords_to_input_space(
            x1_orig, y1_orig,
            data["original_hw"], data["input_hw"], data["padding"],
        )
        x2, y2 = transform_coords_to_input_space(
            x2_orig, y2_orig,
            data["original_hw"], data["input_hw"], data["padding"],
        )
        box_xyxy = [float(x1), float(y1), float(x2), float(y2)]

        prompt_box = [x1_orig, y1_orig, x2_orig, y2_orig]

        results = detector(
            images=data["images"].to(device),
            intrinsics=data["intrinsics"].to(device)[None],
            input_hw=[data["input_hw"]],
            original_hw=[data["original_hw"]],
            padding=[data["padding"]],
            input_boxes=[box_xyxy],
            prompt_text=geo_prompt_text,
            return_predicted_intrinsics=True,
        )
        (
            boxes, boxes3d, scores, scores_2d, scores_3d,
            class_ids, depth_maps, predicted_K,
        ) = results
        class_id_mapping = {0: geo_prompt_text}

    elif prompt_mode == "Point":
        points = state.get("points", [])
        if not points:
            return (
                None,
                "Please click on the image to select a point",
                None,
                None,
            )

        # Map each clicked point into the padded model-input frame,
        # preserving its positive/negative label.
        transformed_points = []
        for x_orig, y_orig, lbl in points:
            x, y = transform_coords_to_input_space(
                x_orig, y_orig,
                data["original_hw"], data["input_hw"], data["padding"],
            )
            transformed_points.append((x, y, lbl))

        prompt_points = points

        results = detector(
            images=data["images"].to(device),
            intrinsics=data["intrinsics"].to(device)[None],
            input_hw=[data["input_hw"]],
            original_hw=[data["original_hw"]],
            padding=[data["padding"]],
            input_points=[transformed_points],
            prompt_text=geo_prompt_text,
            return_predicted_intrinsics=True,
        )
        (
            boxes, boxes3d, scores, scores_2d, scores_3d,
            class_ids, depth_maps, predicted_K,
        ) = results
        class_id_mapping = {0: geo_prompt_text}

    else:
        return None, f"Unknown prompt mode: {prompt_mode}", None, None

    # Scale predicted intrinsics from model-input to original resolution
    # so the displayed fx/fy/cx/cy refer to the uploaded image.
    predicted_K_scaled = scale_intrinsics_to_original(
        predicted_K,
        input_hw=data["input_hw"],
        original_hw=data["original_hw"],
    )

    # Format intrinsics info (when manual K was used, show both the
    # user-supplied and the model-predicted values).
    orig_h, orig_w = data["original_hw"]
    intrinsics_info = f"Image: {orig_w}x{orig_h}\n"
    intrinsics_info += f"Predicted: {format_intrinsics(predicted_K_scaled)}"
    if not use_predicted_K:
        intrinsics_info = f"Image: {orig_w}x{orig_h}\n"
        intrinsics_info += (
            f"Used: fx={fx:.2f}, fy={fy:.2f}, "
            f"cx={cx:.2f}, cy={cy:.2f}\n"
        )
        intrinsics_info += (
            f"Predicted: {format_intrinsics(predicted_K_scaled)}"
        )

    # 2D visualization (score filtering happens inside visualize_results).
    img_2d = visualize_results(
        data, boxes3d, scores, scores_2d, scores_3d,
        class_ids, class_id_mapping, score_thres,
        prompt_points=prompt_points, prompt_box=prompt_box,
    )

    # Depth map visualization: crop the padding, normalize over valid
    # (> 0.01) depths, invert so near = warm, and apply a TURBO colormap.
    depth_vis_img = None
    if depth_maps is not None and len(depth_maps) > 0:
        depth_np_raw = depth_maps[0].cpu().numpy()
        d = depth_np_raw.squeeze()

        pad_l, pad_r, pad_t, pad_b = data["padding"]
        h_end = d.shape[0] - pad_b if pad_b > 0 else d.shape[0]
        w_end = d.shape[1] - pad_r if pad_r > 0 else d.shape[1]
        d_crop = d[pad_t:h_end, pad_l:w_end]

        d_valid = d_crop[d_crop > 0.01]
        if len(d_valid) > 0:
            d_min, d_max = d_valid.min(), d_valid.max()
            d_norm = np.clip(
                (d_crop - d_min) / (d_max - d_min + 1e-6), 0, 1
            )
            d_norm = (1.0 - d_norm) * 255
            d_norm = d_norm.astype(np.uint8)
            depth_vis_img = cv2.applyColorMap(d_norm, cv2.COLORMAP_TURBO)
            depth_vis_img = cv2.cvtColor(depth_vis_img, cv2.COLOR_BGR2RGB)
            depth_vis_img = Image.fromarray(depth_vis_img)

    # 3D visualization (optional): unproject depth to a textured mesh or
    # point cloud and export boxes + geometry as a GLB scene.
    glb_path = None
    if enable_3d_vis and depth_maps is not None and len(depth_maps) > 0:
        depth_np = depth_maps[0].cpu().numpy()

        # Undo ImageNet normalization to recover the RGB input as uint8.
        input_img = data["images"].cpu()
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        input_img = (input_img * std + mean).clamp(0, 1) * 255
        input_img = (
            input_img.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
        )

        # Unproject with the model-input-resolution intrinsics (depth is
        # in the input frame, not the original frame).
        K_for_unproj = data["intrinsics"].cpu().numpy()

        # Keep only boxes above the score threshold, per image.
        filtered_boxes3d_np = []
        for i in range(len(boxes3d)):
            mask = scores[i] >= score_thres
            filtered_boxes3d_np.append(boxes3d[i][mask].cpu().numpy())

        glb_path = "/tmp/wilddet3d_scene.glb"

        if use_textured_mesh:
            create_mesh_scene_glb(
                depth_np, input_img, K_for_unproj,
                filtered_boxes3d_np, glb_path,
                max_depth=20.0,
                padding=data["padding"],
                remove_edge=remove_edges,
                edge_rtol=0.04,
            )
        else:
            subsample = max(1, int(point_density))
            points, point_colors = depth_to_pointcloud(
                depth_np, input_img, K_for_unproj,
                max_depth=20.0, subsample=subsample,
                padding=data["padding"],
                remove_edge=remove_edges,
                edge_rtol=0.04,
            )
            create_scene_glb(
                points, point_colors, filtered_boxes3d_np, glb_path
            )

    return img_2d, intrinsics_info, glb_path, depth_vis_img
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def visualize_results(
    data, boxes3d, scores, scores_2d, scores_3d, class_ids,
    class_id_mapping, score_thres,
    prompt_points=None, prompt_box=None,
):
    """Render detections as 3D boxes projected onto the original image.

    Keeps only detections whose score reaches ``score_thres``, draws any
    point/box prompts on the image, then delegates the box rendering to
    ``wilddet3d.vis.draw_3d_boxes``. Returns a PIL image either way.
    """
    kept_boxes, kept_s2d, kept_s3d, kept_cls = [], [], [], []
    for idx in range(len(boxes3d)):
        keep = scores[idx] >= score_thres
        kept_boxes.append(boxes3d[idx][keep])
        # Missing per-branch scores are rendered as zeros rather than omitted.
        kept_s2d.append(
            scores_2d[idx][keep] if scores_2d is not None
            else torch.zeros_like(scores[idx][keep])
        )
        kept_s3d.append(
            scores_3d[idx][keep] if scores_3d is not None
            else torch.zeros_like(scores[idx][keep])
        )
        kept_cls.append(class_ids[idx][keep])

    # Start from the unprocessed frame and overlay the user's prompts.
    canvas = data["original_images"].cpu().numpy().astype(np.uint8)
    if prompt_points is not None and len(prompt_points) > 0:
        canvas = draw_points_on_image(canvas, prompt_points)
    if prompt_box is not None and len(prompt_box) == 4:
        canvas = draw_box_on_image(canvas, prompt_box)

    intr = data["original_intrinsics"].cpu().numpy()
    if intr.ndim == 3:
        intr = intr[0]  # drop batch dimension if present

    # NOTE(review): assumes class_id_mapping keys are 0..N-1 — unknown ids
    # fall back to their numeric string.
    n_classes = max(len(class_id_mapping), 1)
    class_names = [class_id_mapping.get(c, str(c)) for c in range(n_classes)]

    # Nothing survived thresholding: return the annotated frame as-is.
    if not (len(kept_boxes) > 0 and len(kept_boxes[0]) > 0):
        return Image.fromarray(canvas)

    return draw_3d_boxes(
        image=canvas,
        boxes3d=kept_boxes[0],
        intrinsics=intr,
        scores_2d=kept_s2d[0],
        scores_3d=kept_s3d[0],
        class_ids=kept_cls[0],
        class_names=class_names,
        n_colors=n_classes,
    )
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
# Defaults shown in the UI before the user uploads anything.
default_fx, default_fy, default_cx, default_cy = load_default_intrinsics()
default_image = load_default_image()

# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="WildDet3D: 3D Detection") as demo:
    gr.Markdown("# WildDet3D: Open-Vocabulary 3D Detection in the Wild")
    gr.Markdown("""
    **How to use:**
    - **Text**: Enter object names (e.g., "chair.table"), click Run
    - **Box-to-Multi-Object**: Draw box -> detect ALL similar objects (one-to-many)
    - **Box-to-Single-Object**: Draw box -> detect ONLY the boxed object (one-to-one)
    - **Point**: Click on object, click Run
    - **+ Label**: Check this to attach a category name (e.g., "chair") to box/point prompts
    """)

    # Per-session state: accumulated prompt clicks plus the pristine image
    # (so click markers can be redrawn/cleared without re-uploading).
    click_state = gr.State({"points": [], "box": []})
    original_image_state = gr.State(
        default_image.copy() if default_image is not None else None
    )

    with gr.Row():
        # ---- left column: inputs and settings ----
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image (click for Box/Point mode)",
                type="numpy",
                value=default_image,
                interactive=True,
                sources=["upload", "clipboard"],
            )

            # Prompting controls; visibility is toggled by on_mode_change.
            prompt_mode = gr.Radio(
                choices=[
                    "Text",
                    "Box-to-Multi-Object",
                    "Box-to-Single-Object",
                    "Point",
                ],
                value="Text",
                label="Prompt Mode",
            )
            text_prompt = gr.Textbox(
                label="Text Prompt (e.g. 'chair.table')",
                value="chair.table",
                placeholder="Enter object names separated by '.'",
                visible=True,
            )
            use_label = gr.Checkbox(
                label="+ Label (attach category name to box/point prompt)",
                value=False,
                visible=False,
            )
            label_text = gr.Textbox(
                label="Category Label (e.g. 'chair')",
                value="",
                placeholder="Category name for the selected object",
                visible=False,
            )
            point_label = gr.Radio(
                choices=["Positive (include)", "Negative (exclude)"],
                value="Positive (include)",
                label="Point Label (for Point mode)",
                visible=False,
            )

            click_info = gr.Textbox(
                label="Click Info",
                value="Select mode and click on image",
                interactive=False,
            )

            with gr.Row():
                clear_btn = gr.Button("Clear Clicks")
                run_btn = gr.Button("Run Detection", variant="primary")

            # Camera intrinsics: either predicted by the model or set manually.
            use_predicted_K = gr.Checkbox(
                label="Use Predicted Intrinsics",
                value=True,
            )
            with gr.Row():
                fx = gr.Number(label="fx", value=default_fx)
                fy = gr.Number(label="fy", value=default_fy)
                cx = gr.Number(label="cx", value=default_cx)
                cy = gr.Number(label="cy", value=default_cy)

            score_thres = gr.Slider(
                minimum=0, maximum=1, value=0.3, step=0.05,
                label="Score Threshold",
            )

            gr.Markdown("### 3D Visualization Settings")
            enable_3d_vis = gr.Checkbox(
                label="Enable 3D Point Cloud / Mesh Visualization",
                value=False,
            )
            gr.Markdown(
                "*Notice: the model takes the depth latent to generate "
                "3D boxes, so the boxes and the point cloud might not "
                "exactly match.*"
            )
            use_textured_mesh = gr.Checkbox(
                label="Textured Mesh (otherwise point cloud)",
                value=True,
            )
            remove_edges = gr.Checkbox(
                label="Remove depth edges (cleaner geometry)",
                value=True,
            )
            point_density = gr.Slider(
                minimum=1, maximum=8, value=2, step=1,
                label="Point Subsample (point cloud mode only, 1=dense)",
            )

        # ---- right column: outputs ----
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="2D Detection Results", type="pil"
            )
            depth_image = gr.Image(label="Depth Map", type="pil")
            output_3d = gr.Model3D(
                label="3D View (Mesh/Point Cloud + Boxes)",
                clear_color=(0.1, 0.1, 0.1, 1.0),
            )
            intrinsics_info = gr.Textbox(
                label="Intrinsics Info", interactive=False
            )

    def on_mode_change(mode):
        # Show/hide prompt widgets so only the active mode's inputs appear.
        text_visible = mode == "Text"
        point_visible = mode == "Point"
        return (
            gr.update(visible=text_visible),      # text_prompt
            gr.update(visible=not text_visible),  # use_label
            gr.update(visible=not text_visible),  # label_text
            gr.update(visible=point_visible),     # point_label
        )

    prompt_mode.change(
        on_mode_change,
        inputs=[prompt_mode],
        outputs=[text_prompt, use_label, label_text, point_label],
    )

    # Clicks on the image add prompt points / box corners.
    input_image.select(
        on_image_select,
        inputs=[
            input_image, original_image_state, click_state,
            prompt_mode, point_label,
        ],
        outputs=[click_state, click_info, input_image],
    )

    clear_btn.click(
        clear_clicks,
        inputs=[click_state, original_image_state],
        outputs=[click_state, click_info, input_image],
    )

    def on_image_upload(image):
        # A fresh upload replaces the stored original and resets all clicks.
        if image is None:
            return None, {"points": [], "box": []}, "Upload an image"
        return (
            image.copy(),
            {"points": [], "box": []},
            "Image loaded - select mode and click",
        )

    input_image.upload(
        on_image_upload,
        inputs=[input_image],
        outputs=[original_image_state, click_state, click_info],
    )

    run_btn.click(
        run_wilddet3d,
        inputs=[
            input_image, click_state, prompt_mode, text_prompt,
            use_label, label_text, score_thres, use_predicted_K,
            fx, fy, cx, cy,
            enable_3d_vis, remove_edges, point_density, use_textured_mesh,
        ],
        outputs=[output_image, intrinsics_info, output_3d, depth_image],
    )
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
if __name__ == "__main__":
    # Startup banner, then launch the Gradio server.
    rule = "=" * 60
    for line in (rule, "WildDet3D Web Demo", rule, "", "Starting server..."):
        print(line)
    # Port is overridable via the standard Gradio environment variable.
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    demo.launch(share=False, server_name="0.0.0.0", server_port=server_port)
|
assets/demo/intrinsics.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a5e46d677b736c45d98fda89d2b4b6b8e88028f8c7a5e25df6c9c3e61f6c6fed
|
| 3 |
+
size 164
|
assets/demo/rgb.png
ADDED
|
Git LFS Details
|
requirements.txt
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Vis4D (same approach: install dependencies, not vis4d itself)
|
| 2 |
+
absl-py
|
| 3 |
+
appdirs
|
| 4 |
+
cloudpickle
|
| 5 |
+
cython
|
| 6 |
+
devtools
|
| 7 |
+
h5py
|
| 8 |
+
jsonargparse[signatures]
|
| 9 |
+
lightning
|
| 10 |
+
ml_collections==1.1.0
|
| 11 |
+
numpy>=1.21.0,<2.0.0
|
| 12 |
+
opencv-python
|
| 13 |
+
pandas
|
| 14 |
+
pillow
|
| 15 |
+
plyfile
|
| 16 |
+
pycocotools
|
| 17 |
+
pydantic>=2.0
|
| 18 |
+
setuptools
|
| 19 |
+
tensorboard
|
| 20 |
+
termcolor
|
| 21 |
+
terminaltables
|
| 22 |
+
timm>=0.6.0
|
| 23 |
+
torch>=2.0.0
|
| 24 |
+
torchvision>=0.15.1
|
| 25 |
+
tqdm
|
| 26 |
+
utm
|
| 27 |
+
wheel
|
| 28 |
+
scipy
|
| 29 |
+
|
| 30 |
+
# Git utils
|
| 31 |
+
gitdb
|
| 32 |
+
GitPython
|
| 33 |
+
|
| 34 |
+
# WildDet3D
|
| 35 |
+
einops
|
| 36 |
+
fvcore
|
| 37 |
+
nltk
|
| 38 |
+
transformers
|
| 39 |
+
fairscale
|
| 40 |
+
mmengine
|
| 41 |
+
decord
|
| 42 |
+
|
| 43 |
+
# SAM3 dependencies
|
| 44 |
+
ftfy
|
| 45 |
+
regex
|
| 46 |
+
iopath
|
| 47 |
+
omegaconf
|
| 48 |
+
hydra-core
|
| 49 |
+
scikit-image
|
| 50 |
+
scikit-learn
|
| 51 |
+
open_clip_torch
|
| 52 |
+
|
| 53 |
+
# 3D visualization
|
| 54 |
+
pygltflib
|
| 55 |
+
trimesh
|
| 56 |
+
utils3d
|
| 57 |
+
|
| 58 |
+
# Depth estimation
|
| 59 |
+
huggingface_hub
|
third_party/lingbot_depth/mdm/model/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from typing import *
|
| 3 |
+
|
| 4 |
+
if TYPE_CHECKING:
|
| 5 |
+
from .v2 import MDMModel as MDMModelV2
|
| 6 |
+
|
| 7 |
+
def import_model_class_by_version(version: str) -> Type[Union['MDMModelV2']]:
    """Dynamically import and return the ``MDMModel`` class for *version*.

    Args:
        version: Model version identifier; currently only ``'v2'`` is supported.

    Returns:
        The ``MDMModel`` class object from the ``.{version}`` submodule.

    Raises:
        AssertionError: If *version* is not a supported identifier.
        ValueError: If the submodule for *version* cannot be imported.
    """
    assert version in ['v2'], f'Unsupported model version: {version}'

    try:
        module = importlib.import_module(f'.{version}', __package__)
    except ModuleNotFoundError as e:
        # Chain the original import failure so the root cause stays visible.
        raise ValueError(f'Model version "{version}" not found.') from e
    return getattr(module, 'MDMModel')
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
__version__ = "0.0.1"
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/backbones.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Weights(Enum):
    # Pretrained weight sets published for DINOv2.
    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """Build a DINOv2 ViT backbone and optionally load its pretrained weights.

    Extra keyword arguments are forwarded to the vision-transformer
    constructor and override the defaults assembled here.
    """
    from ..models import vision_transformer as vits

    # Accept the weight set either as an enum member or its string name.
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = {
        "img_size": img_size,
        "patch_size": patch_size,
        "init_values": init_values,
        "ffn_layer": ffn_layer,
        "block_chunks": block_chunks,
        "num_register_tokens": num_register_tokens,
        "interpolate_antialias": interpolate_antialias,
        "interpolate_offset": interpolate_offset,
        **kwargs,
    }
    model = getattr(vits, arch_name)(**vit_kwargs)

    if pretrained:
        # Checkpoint URL follows the official naming scheme, including
        # the register-token suffix when applicable.
        full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = f"{_DINOV2_BASE_URL}/{base_name}/{full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Factory entry points mirroring the official DINOv2 hub API. Each simply
# forwards its architecture-specific settings to _make_dinov2_model.

# Settings shared by every "+registers" variant.
_REGISTER_KWARGS = dict(
    num_register_tokens=4,
    interpolate_antialias=True,
    interpolate_offset=0.0,
)


def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl16(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-L backbone constructed without loading pretrained weights.

    NOTE(review): despite the "16" in the name, this forwards
    arch_name="vit_large" with the caller-supplied (default /14) patch size,
    and it always passes pretrained=False — the `pretrained` argument is
    ignored, presumably because no official /16 checkpoint exists. Confirm
    with callers before changing.
    """
    # kwargs.update({'img_size': 224, 'patch_size': 16, })
    return _make_dinov2_model(arch_name="vit_large", pretrained=False, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        **_REGISTER_KWARGS,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        **_REGISTER_KWARGS,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        **_REGISTER_KWARGS,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **_REGISTER_KWARGS,
        **kwargs,
    )
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/utils.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import itertools
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
|
| 18 |
+
compact_arch_name = arch_name.replace("_", "")[:4]
|
| 19 |
+
registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
|
| 20 |
+
return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CenterPadding(nn.Module):
    """Zero-pad the trailing spatial dims of a tensor up to a multiple.

    Padding is split as evenly as possible between the two sides of each
    dimension (the extra unit, if any, goes to the right/bottom).
    """

    def __init__(self, multiple):
        super().__init__()
        # Every padded dimension is rounded up to a multiple of this value.
        self.multiple = multiple

    def _get_pad(self, size):
        # Amount needed to reach the next multiple, split floor/ceil.
        target = math.ceil(size / self.multiple) * self.multiple
        total = target - size
        left = total // 2
        return left, total - left

    @torch.inference_mode()
    def forward(self, x):
        # x.shape[:1:-1] walks dimensions from last down to index 2, which is
        # exactly the order F.pad consumes its (left, right) pairs in.
        pads = list(itertools.chain.from_iterable(self._get_pad(d) for d in x.shape[:1:-1]))
        return F.pad(x, pads)
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .dino_head import DINOHead
|
| 7 |
+
from .mlp import Mlp
|
| 8 |
+
from .patch_embed import PatchEmbed
|
| 9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 10 |
+
from .block import NestedTensorBlock
|
| 11 |
+
from .attention import MemEffAttention
|
| 12 |
+
from .patch_embed_mlp import PatchEmbed as PatchEmbedMLP
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/attention.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch import Tensor
|
| 16 |
+
from torch import nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger("dinov2")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 23 |
+
try:
|
| 24 |
+
if XFORMERS_ENABLED:
|
| 25 |
+
from xformers.ops import memory_efficient_attention, unbind
|
| 26 |
+
|
| 27 |
+
XFORMERS_AVAILABLE = True
|
| 28 |
+
# warnings.warn("xFormers is available (Attention)")
|
| 29 |
+
else:
|
| 30 |
+
# warnings.warn("xFormers is disabled (Attention)")
|
| 31 |
+
raise ImportError
|
| 32 |
+
except ImportError:
|
| 33 |
+
XFORMERS_AVAILABLE = False
|
| 34 |
+
# warnings.warn("xFormers is not available (Attention)")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    The attention itself is computed via F.scaled_dot_product_attention,
    which applies the 1/sqrt(head_dim) scaling internally.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Kept for API compatibility with callers that read it; SDPA scales
        # internally so this is not applied in forward().
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # (B, N, 3C) -> (3, B, H, N, C // H)
        qkv = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)  # each (B, H, N, C // H)

        # attn_bias is forwarded as SDPA's attn_mask (additive bias).
        out = F.scaled_dot_product_attention(q, k, v, attn_bias)
        out = out.permute(0, 2, 1, 3).reshape(batch, tokens, channels)

        return self.proj_drop(self.proj(out))
|
| 82 |
+
|
| 83 |
+
class MemEffAttention(Attention):
    """Attention backed by xFormers' memory_efficient_attention when available.

    Falls back to the parent SDPA implementation if xFormers is not
    importable; in that case nested-tensor attention biases are rejected.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            # Only the xFormers kernel can consume these bias objects.
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        batch, tokens, channels = x.shape
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
        q, k, v = unbind(qkv, 2)

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        out = out.reshape([batch, tokens, channels])

        return self.proj_drop(self.proj(out))
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/block.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn, Tensor
|
| 17 |
+
|
| 18 |
+
from .attention import Attention, MemEffAttention
|
| 19 |
+
from .drop_path import DropPath
|
| 20 |
+
from .layer_scale import LayerScale
|
| 21 |
+
from .mlp import Mlp
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("dinov2")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 28 |
+
try:
|
| 29 |
+
if XFORMERS_ENABLED:
|
| 30 |
+
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
| 31 |
+
|
| 32 |
+
XFORMERS_AVAILABLE = True
|
| 33 |
+
# warnings.warn("xFormers is available (Block)")
|
| 34 |
+
else:
|
| 35 |
+
# warnings.warn("xFormers is disabled (Block)")
|
| 36 |
+
raise ImportError
|
| 37 |
+
except ImportError:
|
| 38 |
+
XFORMERS_AVAILABLE = False
|
| 39 |
+
# warnings.warn("xFormers is not available (Block)")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Block(nn.Module):
    """Standard ViT transformer block: pre-norm attention + MLP, each branch
    with optional LayerScale and per-sample stochastic depth (drop path).

    Args:
        dim: Token embedding dimension.
        num_heads: Number of attention heads.
        mlp_ratio: FFN hidden width as a multiple of ``dim``.
        qkv_bias / proj_bias / ffn_bias: Bias switches for the qkv projection,
            the attention output projection, and the FFN linear layers.
        drop: Dropout rate after the attention/FFN projections.
        attn_drop: Dropout rate on the attention map.
        init_values: LayerScale init value; ``None``/0 disables LayerScale.
        drop_path: Stochastic-depth rate (per sample).
        act_layer / norm_layer: Activation and normalization constructors.
        attn_class: Attention implementation to instantiate.
        ffn_layer: Feed-forward implementation to instantiate.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # Batch-level sample dropping: the bookkeeping overhead is only
            # compensated for a drop-path rate larger than ~0.1.
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Fixed: the FFN branch previously reused drop_path1 (flagged with
            # a FIXME upstream). Both modules carry the same rate, so the
            # distribution is unchanged, but drop_path2 is the intended module.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Apply ``residual_func`` to a random subset of the batch and add the
    rescaled result back (batched stochastic depth).

    A fraction ``sample_drop_ratio`` of samples is skipped entirely; surviving
    residuals are scaled by ``b / kept`` so the expectation matches computing
    the residual for every sample.
    """
    # 1) draw the subset of samples that survive this layer
    batch, _, _ = x.shape
    kept = max(int(batch * (1 - sample_drop_ratio)), 1)
    keep_indices = torch.randperm(batch, device=x.device)[:kept]

    # 2) the residual branch runs only on the kept samples
    residual = residual_func(x[keep_indices]).flatten(1)

    # 3) scatter-add the residuals back, compensating for dropped samples
    scale = batch / kept
    out = torch.index_add(x.flatten(1), 0, keep_indices, residual.to(dtype=x.dtype), alpha=scale)
    return out.view_as(x)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick the random batch indices kept under stochastic depth, and the
    factor (``b / kept``) that rescales their residuals."""
    batch, _, _ = x.shape
    kept = max(int(batch * (1 - sample_drop_ratio)), 1)
    keep_indices = torch.randperm(batch, device=x.device)[:kept]
    return keep_indices, batch / kept
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` (computed on the kept subset ``brange``) back into ``x``.

    Without a ``scaling_vector`` this returns the *flattened* (B, N*D) sum via
    ``torch.index_add``; with one it defers to xFormers' fused
    ``scaled_index_add`` (LayerScale folded in) and keeps the original shape.
    """
    if scaling_vector is None:
        return torch.index_add(
            x.flatten(1), 0, brange, residual.flatten(1).to(dtype=x.dtype), alpha=residual_scale_factor
        )
    return scaled_index_add(
        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
    )
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# Cache of BlockDiagonalMask objects keyed by the tuple of (batch, seqlen)
# pairs — masks are reused across layers/steps with identical layouts.
# NOTE(review): the cache grows unboundedly if input shapes vary freely.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    Concatenate the tensors in ``x_list`` (optionally index-selecting the rows
    in ``branges`` first) into one (1, total_tokens, dim) sequence, and return
    it together with a cached block-diagonal attention bias that prevents
    tokens from different source tensors from attending to each other.
    """
    # Effective batch size per tensor: the kept subset size when branges is given.
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # One seqlen entry per sample, so the mask has sum(batch_sizes) blocks.
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        # Stash the per-tensor batch sizes on the mask so callers of
        # attn_bias.split() can recover the original grouping.
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Fused xFormers kernel: gather the kept rows of every tensor and
        # concatenate them in a single pass.
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        # No subset selection: flatten each tensor to batch 1 and concatenate.
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """Stochastic-depth residual update for a list of nested tensors.

    Each tensor in ``x_list`` gets its own random subset of kept samples; the
    kept rows are concatenated into one sequence, run through ``residual_func``
    under a block-diagonal attention bias, then split and scatter-added back
    (optionally scaled by a LayerScale ``scaling_vector``).
    """
    # 1) per-tensor kept indices and the matching rescale factors
    subsets = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [keep for keep, _ in subsets]
    scales = [scale for _, scale in subsets]

    # 2) concatenate the kept rows and fetch the cached block-diagonal bias
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) one fused forward pass over the concatenation, then split per tensor
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    return [
        add_residual(x, keep, res, scale, scaling_vector).view_as(x)
        for x, keep, res, scale in zip(x_list, branges, residual_list, scales)
    ]
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class NestedTensorBlock(Block):
    """Transformer block that can additionally process a *list* of tensors by
    concatenating them into one sequence under a block-diagonal attention mask
    (requires xFormers' memory-efficient attention)."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """Run the block jointly over ``x_list``; returns a list with the same
        shapes. Tokens from different tensors never attend to each other."""
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            # LayerScale is folded into the fused scatter-add (scaling_vector)
            # instead of being applied inside the residual funcs.
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Fixed: this previously tested isinstance(self.ls1, ...) while
                # reading self.ls2.gamma. ls1/ls2 are constructed together so
                # the outcome was identical, but ls2 is the correct module.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch: plain Tensor -> standard Block.forward; list of tensors ->
        nested path (xFormers required)."""
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/dino_head.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.nn.init import trunc_normal_
|
| 9 |
+
from torch.nn.utils import weight_norm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DINOHead(nn.Module):
    """DINO projection head: an MLP bottleneck followed by a weight-normalized
    linear layer onto the prototype dimension ``out_dim``."""

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        self.mlp = _build_mlp(
            max(nlayers, 1), in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias
        )
        self.apply(self._init_weights)
        # Weight-normalized last layer with its magnitude frozen at 1,
        # as in the original DINO recipe (initialized after the MLP so that
        # the trunc-normal init is not applied to it).
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        # Truncated-normal init for linear weights, zero for biases.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        # L2-normalize onto the unit sphere; fp16 needs a larger eps.
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        return self.last_layer(nn.functional.normalize(x, dim=-1, p=2, eps=eps))


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    """Build an ``nlayers``-deep MLP ending in ``bottleneck_dim`` (a single
    Linear when ``nlayers == 1``), with optional BatchNorm and GELU between
    the hidden layers."""
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    layers = []
    dims = [in_dim] + [hidden_dim] * (nlayers - 1)
    for i in range(nlayers - 1):
        layers.append(nn.Linear(dims[i], dims[i + 1], bias=bias))
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
    layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
    return nn.Sequential(*layers)
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/drop_path.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` with probability ``drop_prob``,
    rescaling survivors by 1/keep_prob (stochastic depth). No-op at eval time
    or with a zero rate (returns ``x`` unchanged)."""
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast across all remaining dims
    # (works with tensors of any rank, not just 2D ConvNet activations).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask


class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` — drop paths (stochastic depth)
    per sample, applied in the main path of residual blocks."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 7 |
+
|
| 8 |
+
from typing import Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LayerScale(nn.Module):
    """Per-channel learnable scaling (CaiT-style LayerScale): multiplies the
    input by a parameter vector ``gamma`` initialized to ``init_values``."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # In-place multiply saves memory but mutates the caller's tensor.
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/mlp.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional
|
| 12 |
+
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Mlp(nn.Module):
    """Two-layer transformer feed-forward block:
    Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features``/``out_features`` fall back to ``in_features`` when
    not given (or zero, matching the original ``or`` semantics).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out, bias=bias)
        # The same dropout module is reused after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
from typing import Callable, Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_2tuple(x):
    """Normalize an int or 2-tuple into a 2-tuple (pair)."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Splits the image into non-overlapping patches with a strided convolution
    and projects each patch to ``embed_dim``.

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If True return (B, N, D); otherwise (B, H', W', D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Patchify + project in a single strided convolution.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H' W'
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B H'W' C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H' W' C
        return x

    def flops(self) -> float:
        """Approximate multiply-accumulate count of the embedding."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # Fixed: self.norm is nn.Identity() (never None) when no norm_layer is
        # given, so the previous `is not None` check always counted norm flops.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed_mlp.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
from typing import Callable, Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
def make_2tuple(x):
    """Normalize an int or 2-tuple into a 2-tuple (pair)."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x


class PixelUnshuffle(nn.Module):
    """``F.pixel_unshuffle`` wrapper that additionally handles empty tensors
    (numel == 0), which the stock torch implementation does not."""

    def __init__(self, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input):
        if input.numel() == 0:
            # Empty-tensor path: compute the output shape by hand.
            C, H, W = input.shape[-3:]
            r = self.downscale_factor
            assert H and W and H % r == W % r == 0
            return input.view(*input.shape[:-3], C * r**2, H // r, W // r)
        return F.pixel_unshuffle(input, self.downscale_factor)


class Permute(nn.Module):
    """Module wrapper around ``Tensor.permute`` for use inside nn.Sequential."""

    dims: tuple[int, ...]

    def __init__(self, dims: tuple[int, ...]) -> None:
        super().__init__()
        self.dims = tuple(dims)

    def __repr__(self):
        return f"Permute{self.dims}"

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.permute(*self.dims)


from itertools import repeat
import collections.abc


def _ntuple(n):
    """timm-style helper: broadcast a scalar into an n-tuple; pass iterables
    (except strings) through unchanged."""

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.
    ``bias`` and ``drop`` may be scalars or per-layer 2-tuples."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    MLP variant: patches are extracted with PixelUnshuffle and projected to
    ``embed_dim`` through a two-layer Mlp (hidden width ``4 * embed_dim``)
    instead of a strided convolution.

    Args:
        img_size: Image size.
        patch_size: Patch token size (int or square 2-tuple).
        in_chans: Number of input image channels.
        embed_dim: Number of projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If True return (B, N, D); otherwise (B, H', W', D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        # Fixed: proj below previously used the raw `patch_size` argument,
        # which crashed when a 2-tuple was passed even though make_2tuple
        # accepts one. PixelUnshuffle takes a single downscale factor, so only
        # square patches are supported — fail early with a clear message.
        assert patch_HW[0] == patch_HW[1], f"non-square patch size {patch_HW} is not supported"
        patch = patch_HW[0]
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Patchify with PixelUnshuffle, move channels last for the Mlp,
        # project, then move channels back first.
        self.proj = nn.Sequential(
            PixelUnshuffle(patch),
            Permute((0, 2, 3, 1)),
            Mlp(in_chans * patch * patch, 4 * embed_dim, embed_dim),
            Permute((0, 3, 1, 2)),
        )

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H' W'
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B H'W' C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H' W' C
        return x

    def flops(self) -> float:
        """Approximate multiply-accumulate count of the embedding."""
        Ho, Wo = self.patches_resolution
        # Fixed: the previous formula was copied from the convolutional
        # PatchEmbed and ignored the Mlp's 4*embed_dim hidden layer.
        mlp_in = self.in_chans * self.patch_size[0] * self.patch_size[1]
        hidden = 4 * self.embed_dim
        flops = Ho * Wo * (mlp_in * hidden + hidden * self.embed_dim)
        # self.norm is nn.Identity() (never None) without a norm_layer, so
        # check the type rather than `is not None`.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
from torch import Tensor, nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: one fused Linear producing both gate and value,
    combined as ``silu(gate) * value``, then projected back out.
    ``act_layer`` and ``drop`` are accepted for interface parity but unused."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Single matmul producing both halves of the gated unit.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)


# Prefer the fused xFormers SwiGLU kernel when available; otherwise fall back
# to the pure-PyTorch implementation above. Setting XFORMERS_DISABLED forces
# the fallback.
XFORMERS_ENABLED = "XFORMERS_DISABLED" not in os.environ
XFORMERS_AVAILABLE = False
SwiGLU = SwiGLUFFN
if XFORMERS_ENABLED:
    try:
        from xformers.ops import SwiGLU
    except ImportError:
        pass  # keep the PyTorch fallback
    else:
        XFORMERS_AVAILABLE = True


class SwiGLUFFNFused(SwiGLU):
    """SwiGLU with the hidden width shrunk by 2/3 (keeping the parameter count
    comparable to a plain 4x MLP) and rounded up to a multiple of 8.
    ``act_layer`` and ``drop`` are accepted for interface parity but unused."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
'''
|
| 7 |
+
Docstring for MDM.mdm.model.dinov2_rgbd.models_vlmae
|
| 8 |
+
=======================================================
|
| 9 |
+
This version is modified from the original DINOv2 to support the MIM(masked image modeling) of RGBD input.
|
| 10 |
+
(The original DINOv2 is available at https://github.com/facebookresearch/dinov2.)
|
| 11 |
+
|
| 12 |
+
Core Changes:
|
| 13 |
+
1. We add the depth input into the original DINOv2 transformer encoder.
|
| 14 |
+
|
| 15 |
+
2. We support the Variable Mask Ratio MAE for both RGB and Depth input.
|
| 16 |
+
'''
|
| 17 |
+
|
| 18 |
+
import logging
|
| 19 |
+
|
| 20 |
+
from . import vision_transformer as vits
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger("dinov2")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def build_model(args, only_teacher=False, img_size=224):
    """Instantiate DINOv2 student/teacher backbones from a config namespace.

    Args:
        args: config object providing ``arch`` plus the ViT hyperparameters
            read below. ``args.arch`` is normalized in place (any trailing
            ``"_memeff"`` suffix is stripped).
        only_teacher: if True, build and return only the teacher model.
        img_size: input image resolution passed to the ViT constructor.

    Returns:
        ``(teacher, embed_dim)`` when ``only_teacher`` is True,
        ``(student, teacher, embed_dim)`` otherwise, or ``None`` for a
        non-ViT ``args.arch`` (original behavior: implicit fall-through).
    """
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" not in args.arch:
        # Non-ViT architectures are not handled here.
        return None

    shared_kwargs = dict(
        img_size=img_size,
        patch_size=args.patch_size,
        init_values=args.layerscale,
        ffn_layer=args.ffn_layer,
        block_chunks=args.block_chunks,
        qkv_bias=args.qkv_bias,
        proj_bias=args.proj_bias,
        ffn_bias=args.ffn_bias,
        num_register_tokens=args.num_register_tokens,
        interpolate_offset=args.interpolate_offset,
        interpolate_antialias=args.interpolate_antialias,
    )
    model_ctor = getattr(vits, args.arch)

    teacher = model_ctor(**shared_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim

    # The student additionally uses stochastic depth.
    student = model_ctor(
        **shared_kwargs,
        drop_path_rate=args.drop_path_rate,
        drop_path_uniform=args.drop_path_uniform,
    )
    return student, teacher, student.embed_dim
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def build_model_from_cfg(cfg, only_teacher=False):
    """Build student/teacher models from a full training config.

    Thin wrapper over :func:`build_model` that pulls the student section
    and the global crop size out of ``cfg``.
    """
    return build_model(
        cfg.student,
        only_teacher=only_teacher,
        img_size=cfg.crops.global_crops_size,
    )
|
| 55 |
+
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/mask_utils.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch


def depth_masking(
    x,
    patch_num_h,
    patch_num_w,
    depth_values,
    depth_mask_threshold_ratio=None,
    depth_mask_threshold_num=None,
    valid_depth_range=(0.1, 10.0),
):
    """Drop patches whose underlying depth measurements are (mostly) invalid.

    A patch is considered invalid when its fraction (or count) of in-range
    depth pixels falls below the given threshold(s). If both thresholds are
    ``None`` every patch is kept.

    Args:
        x: [B, N, D] patch features (after patch embedding).
        patch_num_h: height of the patch grid.
        patch_num_w: width of the patch grid.
        depth_values: [B, 1, H_img, W_img] raw depth map.
        depth_mask_threshold_ratio: float or per-sample list; minimum valid
            pixel ratio (0-1) for a patch to be kept.
        depth_mask_threshold_num: int or per-sample list; minimum valid
            pixel count for a patch to be kept.
        valid_depth_range: (min, max) closed interval of valid depth values.

    Returns:
        visible_list: list of B tensors [N_visible_i, D] — the surviving
            patches of each sample (lengths may differ per sample).
        mask_info: dict with per-sample 'visible_indices', 'mask_indices'
            and 'num_visible' lists.
    """
    B, N, D = x.shape

    assert N == patch_num_h * patch_num_w, \
        f"N={N} must equal patch_num_h * patch_num_w = {patch_num_h * patch_num_w}"

    # [B, N] boolean map; True marks a patch to drop.
    depth_invalid_mask = _compute_depth_invalid_mask(
        depth_values,
        patch_num_h,
        patch_num_w,
        depth_mask_threshold_ratio,
        depth_mask_threshold_num,
        valid_depth_range
    )

    # Gather visible patches per sample (lengths vary, so no batch tensor).
    visible_list = []
    mask_info = {
        'visible_indices': [],
        'mask_indices': [],
        'num_visible': [],
    }

    for i in range(B):
        valid_mask = ~depth_invalid_mask[i]  # [N]
        visible_indices = torch.where(valid_mask)[0]
        masked_indices = torch.where(depth_invalid_mask[i])[0]

        visible = x[i, visible_indices]  # [N_visible, D]
        visible_list.append(visible)

        mask_info['visible_indices'].append(visible_indices)
        mask_info['mask_indices'].append(masked_indices)
        mask_info['num_visible'].append(len(visible_indices))

    return visible_list, mask_info


def _compute_depth_invalid_mask(
    depth_values,
    H_patch,
    W_patch,
    threshold_ratio,
    threshold_num,
    valid_range
):
    """Compute a per-patch depth-invalidity mask.

    Args:
        depth_values: [B, 1, H_img, W_img] raw depth map.
        H_patch, W_patch: patch grid dimensions (image size must divide evenly).
        threshold_ratio: float or per-sample list, valid-pixel ratio threshold.
        threshold_num: int or per-sample list, valid-pixel count threshold.
        valid_range: (min_depth, max_depth), inclusive on both ends.

    Returns:
        invalid_mask: [B, N] bool tensor; True marks an invalid patch.
    """
    B, _, H_img, W_img = depth_values.shape
    N = H_patch * W_patch
    device = depth_values.device

    min_depth, max_depth = valid_range

    # Pixel extent of one patch.
    patch_h = H_img // H_patch
    patch_w = W_img // W_patch

    assert H_img % H_patch == 0 and W_img % W_patch == 0, \
        f"Image size ({H_img}, {W_img}) must be divisible by patch grid ({H_patch}, {W_patch})"

    # [B, 1, H_img, W_img] -> [B, 1, H_patch, patch_h, W_patch, patch_w]
    depth_reshaped = depth_values.view(B, 1, H_patch, patch_h, W_patch, patch_w)

    # -> [B, N, patch_h*patch_w], patches in row-major grid order
    depth_reshaped = depth_reshaped.permute(0, 2, 4, 1, 3, 5).reshape(B, N, -1)

    # Per-patch valid-pixel statistics.
    valid_depth = (depth_reshaped >= min_depth) & (depth_reshaped <= max_depth)
    valid_depth_ratio = valid_depth.float().mean(dim=-1)  # [B, N]
    valid_depth_num = valid_depth.float().sum(dim=-1)  # [B, N]

    if isinstance(threshold_ratio, list) or isinstance(threshold_num, list):
        # Per-sample thresholds: resolve each one, scalars broadcast.
        invalid_mask = torch.zeros(B, N, dtype=torch.bool, device=device)

        for i in range(B):
            tr = threshold_ratio[i] if isinstance(threshold_ratio, list) else threshold_ratio
            tn = threshold_num[i] if isinstance(threshold_num, list) else threshold_num

            sample_mask = torch.zeros(N, dtype=torch.bool, device=device)
            if tr is not None:
                sample_mask |= (valid_depth_ratio[i] < tr)
            if tn is not None:
                sample_mask |= (valid_depth_num[i] < tn)

            invalid_mask[i] = sample_mask
    else:
        # Uniform thresholds across the batch.
        invalid_mask = torch.zeros(B, N, dtype=torch.bool, device=device)

        if threshold_ratio is not None:
            invalid_mask |= (valid_depth_ratio < threshold_ratio)
        if threshold_num is not None:
            invalid_mask |= (valid_depth_num < threshold_num)

    return invalid_mask
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/vision_transformer.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
from functools import partial
|
| 11 |
+
import math
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Sequence, Tuple, Union, Callable, Optional, List
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.utils.checkpoint
|
| 18 |
+
from torch.nn.init import trunc_normal_
|
| 19 |
+
|
| 20 |
+
from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
| 21 |
+
from ..layers import PatchEmbedMLP
|
| 22 |
+
|
| 23 |
+
from .mask_utils import depth_masking
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger("dinov2_rgbd")
|
| 26 |
+
|
| 27 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over a module tree.

    Children are always visited; the node itself is visited only when
    ``include_root`` is True — before its children if ``depth_first`` is
    False, after them otherwise. Names are dotted paths rooted at ``name``.
    Returns ``module`` for chaining.
    """
    if include_root and not depth_first:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(fn=fn, module=child, name=qualified, depth_first=depth_first, include_root=True)
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BlockChunk(nn.ModuleList):
    """A ModuleList that runs its members sequentially, like ``nn.Sequential``.

    Used to group transformer blocks into chunks (e.g. for FSDP wrapping)
    while keeping a plain sequential forward pass.
    """

    def forward(self, x):
        for layer in self:
            x = layer(x)
        return x
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class DinoVisionTransformer(nn.Module):
    """DINOv2-style ViT extended to RGB-D input with MAE-style depth masking.

    Compared to the original DINOv2 backbone, this variant embeds the RGB
    image and the depth map with separate patch-embedding layers, tags each
    modality by adding a constant offset (1 for RGB, 2 for depth) on top of
    the shared positional encoding, and can drop depth tokens whose
    underlying depth pixels are invalid (see ``mask_utils.depth_masking``).
    Because each sample may keep a different number of depth tokens, the
    token preparation returns a per-sample list rather than a batch tensor.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        img_depth_fuse_mode='',
        depth_mask_ratio:Union[float, List[float]]=0.6,
        img_mask_ratio:Union[float, List[float]]=0.0,
        depth_mask_patch_grid_size: int=1,
        img_mask_patch_grid_size: int=1,
        depth_emb_mode='',
        # depth_emb_mode='conv_1c'
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
            img_depth_fuse_mode: (str) how RGB and depth tokens are combined;
                only 'cat_token' is supported by prepare_tokens_with_masks
            depth_mask_ratio / img_mask_ratio: NOTE(review): accepted but not
                used anywhere in this class — masking is driven by depth
                validity instead; confirm before removing
            depth_mask_patch_grid_size / img_mask_patch_grid_size: (int)
                must both be 1 in the current version (asserted below)
            depth_emb_mode: (str) 'conv_1c' builds a 1-channel patch embed
                for the depth map; any other value leaves it unset
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        # RGB patch embedding.
        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Separate single-channel patch embedding for the depth map; only the
        # 'conv_1c' mode is implemented (prepare_tokens_with_masks asserts it
        # is not None).
        self.depth_emb_mode = depth_emb_mode
        if self.depth_emb_mode == 'conv_1c':
            self.depth_patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=1, embed_dim=embed_dim)
        else:
            self.depth_patch_embed = None

        self.img_depth_fuse_mode = img_depth_fuse_mode

        self.depth_mask_patch_grid_size = depth_mask_patch_grid_size
        self.img_mask_patch_grid_size = img_mask_patch_grid_size
        assert self.depth_mask_patch_grid_size == 1, "depth_mask_patch_grid_size must be 1 in current version"
        assert self.img_mask_patch_grid_size == 1, "img_mask_patch_grid_size must be 1 in current version"

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional embedding covers cls token + RGB patch grid; the same
        # table (minus cls) is reused for depth tokens in
        # prepare_tokens_with_masks.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            # Factory that ignores its arguments and returns a no-op FFN.
            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # NOTE(review): mask_token is created and initialized but not
        # referenced by any method visible in this file — confirm whether a
        # decoder elsewhere consumes it.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    @property
    def onnx_compatible_mode(self):
        # When True, the shape-dependent early-return shortcuts in the
        # pos-embed interpolation are disabled so tracing/export sees a
        # single code path. Defaults to False when never set.
        return getattr(self, "_onnx_compatible_mode", False)

    @onnx_compatible_mode.setter
    def onnx_compatible_mode(self, value: bool):
        self._onnx_compatible_mode = value

    def init_weights(self):
        """Initialize pos/cls/register parameters and all Linear layers."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, h, w):
        """Resize the (cls + patch) positional embedding to an h x w input.

        Returns the stored embedding unchanged when the patch count already
        matches and the input is square; otherwise bicubically interpolates
        the patch part to the (h//patch, w//patch) grid and re-attaches the
        cls embedding. Assumes the stored patch grid is square (M x M).
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        batch_size = x.shape[0]  # (unused)
        N = self.pos_embed.shape[1] - 1
        if not self.onnx_compatible_mode and npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0, :]
        patch_pos_embed = pos_embed[:, 1:, :]
        dim = x.shape[-1]
        h0, w0 = h // self.patch_size, w // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if not self.onnx_compatible_mode and self.interpolate_offset > 0:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sy, sx)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (h0, w0)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )

        assert (h0, w0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
        return torch.cat((class_pos_embed[:, None, :].expand(patch_pos_embed.shape[0], -1, -1), patch_pos_embed), dim=1).to(previous_dtype)

    def interpolate_pos_encoding_without_cls(self, x, h, w, input_pos_embed):
        """Same as interpolate_pos_encoding but for a cls-free embedding.

        ``input_pos_embed`` is [1, N, dim] with N = M*M patch positions;
        the result matches the (h//patch, w//patch) grid of ``x``.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1]
        batch_size = x.shape[0]  # (unused)
        N = input_pos_embed.shape[1]
        if not self.onnx_compatible_mode and npatch == N and w == h:
            return input_pos_embed
        patch_pos_embed = input_pos_embed.float()
        dim = x.shape[-1]
        h0, w0 = h // self.patch_size, w // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if not self.onnx_compatible_mode and self.interpolate_offset > 0:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sy, sx)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (h0, w0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (h0, w0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
        return patch_pos_embed.to(previous_dtype)

    def prepare_tokens_with_masks(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, masks=None, **kwargs):
        """Embed RGB + depth, apply depth-validity masking, and build token sequences.

        Args:
            x_img: [B, C, H, W] RGB batch.
            x_depth: [B, 1, H, W] depth batch; a value of 0 is treated as
                a missing measurement.
            x_img_mask, x_depth_mask: accepted but not used here.
            masks: must be None (extra masks unsupported).
            **kwargs: 'enable_depth_mask' (default True) toggles depth-token
                dropping.

        Returns:
            List of B tensors, each [1, 1 + num_register_tokens + L_img +
            L_visible_depth_i, dim]; lengths differ per sample because each
            sample may keep a different number of depth tokens.
        """
        assert masks is None, "extra masks are not supported for this model."
        B, nc, h_img, w_img = x_img.shape
        _, _, h_depth, w_depth = x_depth.shape
        # Replace missing depth (0) with a -10 sentinel so it falls outside
        # the (-9.5, 200.0) valid range used by depth_masking below.
        x_depth_raw = x_depth.clone()
        x_depth_raw[x_depth_raw == 0] = -10

        depth_patch_num_h, depth_patch_num_w = h_depth // self.patch_size, w_depth // self.patch_size

        # patchify, embed image tokens and depth tokens
        x_img = self.patch_embed(x_img)  # batch, length_img, dim
        assert self.depth_patch_embed is not None
        x_depth = self.depth_patch_embed(x_depth)  # batch, length_depth, dim
        assert depth_patch_num_h * depth_patch_num_w == x_depth.shape[1]

        # get full pose enc of img and depth
        # 1-> img data type enc
        # 2-> depth data type enc
        img_pose_enc = 1 + self.interpolate_pos_encoding_without_cls(x_img, h_img, w_img, self.pos_embed[:, 1:]).repeat(B, 1, 1)
        depth_pose_enc = 2 + self.interpolate_pos_encoding_without_cls(x_depth, h_depth, w_depth, self.pos_embed[:, 1:]).repeat(B, 1, 1)

        # add pose enc to img and depth
        x_img = x_img + img_pose_enc
        x_depth = x_depth + depth_pose_enc

        ## mask depth tokens
        # A depth patch survives if it has at least one pixel in
        # (-9.5, 200.0), i.e. at least one real (non-sentinel) measurement.
        if kwargs.get('enable_depth_mask', True):
            x_depth_masked, depth_mask_info = depth_masking(
                x_depth,
                depth_patch_num_h,
                depth_patch_num_w,
                depth_values=x_depth_raw,
                depth_mask_threshold_num=[1]*B,
                valid_depth_range=(-9.5, 200.0)
            )
        else:
            x_depth_masked = x_depth
            depth_mask_info = None

        ## mask image tokens
        # NOTE(review): image-token masking is a no-op in this version;
        # depth_mask_info / img_mask_info are computed but not returned.
        x_img_masked = x_img
        img_mask_info = None

        # get cls token
        x_cls = self.cls_token.squeeze(0) + self.pos_embed.squeeze(0)[:1]  # 1, dim

        # cat cls, img and depth tokens
        assert self.img_depth_fuse_mode == 'cat_token', "Only cat_token mode is supported for this model."
        # NOTE(review): "x_mased" below is a typo for "x_masked" (local name
        # only, behavior unaffected).
        x_masked_list = []
        for i in range(B):
            if self.register_tokens is not None:
                x_mased = torch.cat([x_cls, self.register_tokens.squeeze(0), x_img_masked[i], x_depth_masked[i]], dim=0)  # 1 + num_register_tokens + length_img + length_depth, dim
            else:
                x_mased = torch.cat([x_cls, x_img_masked[i], x_depth_masked[i]], dim=0)  # 1 + length_img + length_depth, dim
            x_mased = x_mased.unsqueeze(0)  # 1, 1 + num_register_tokens + length_img + length_depth, dim
            x_masked_list.append(x_mased)

        return x_masked_list

    def _get_intermediate_layers_not_chunked(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, n=1, return_mae_aux=False, **kwargs):
        """Run the (unchunked) blocks and collect outputs of selected layers.

        When depth masking is enabled, x stays a per-sample list of
        variable-length token tensors — the blocks are expected to accept a
        list (NestedTensorBlock behavior; presumably — confirm against the
        Block implementation). When masking is disabled all samples have
        equal length, so they are concatenated into one batch tensor and
        split back per sample at the end.
        """
        x = self.prepare_tokens_with_masks(x_img, x_depth, x_img_mask, x_depth_mask, **kwargs)

        if not kwargs.get('enable_depth_mask', True):
            x = torch.cat(x, dim=0)

        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"

        if not kwargs.get('enable_depth_mask', True):
            output = [list(torch.split(out, 1, dim=0)) for out in output]
        return output

    def _get_intermediate_layers_chunked(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, n=1, return_mae_aux=False, **kwargs):
        """Chunked-blocks counterpart of _get_intermediate_layers_not_chunked.

        NOTE(review): unlike the unchunked variant, this path has no
        concat/split handling for the enable_depth_mask=False case —
        confirm whether that case is expected here.
        """
        x = self.prepare_tokens_with_masks(x_img, x_depth, x_img_mask, x_depth_mask, **kwargs)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"

        return output

    def extract_features(self, outputs, norm=True):
        """Split raw layer outputs into patch features and cls tokens.

        Each output is (optionally) layer-normed, then sliced: tokens after
        cls + registers become the feature map, token 0 is the class token.
        Handles both batch tensors and per-sample lists (the depth-masked
        path).
        """
        feat_outputs = []
        class_tokens = []
        feat_start_idx = 1 + self.num_register_tokens

        def process_output(out):
            normed = self.norm(out) if norm else out
            return normed[:, feat_start_idx:], normed[:, 0]

        for output in outputs:
            if isinstance(output, list):
                feats, tokens = zip(*[process_output(out) for out in output])
                feat_outputs.append(list(feats))
                class_tokens.append(list(tokens))
            else:
                feat, token = process_output(output)
                feat_outputs.append(feat)
                class_tokens.append(token)

        return feat_outputs, class_tokens

    def get_intermediate_layers_mae(
        self,
        x_img: torch.Tensor,
        x_depth: torch.Tensor,
        x_img_mask: torch.Tensor=None,
        x_depth_mask: torch.Tensor=None,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
        return_mae_aux=True,
        **kwargs
    ):
        """Public entry point: intermediate features for MAE-style RGBD input.

        Returns a tuple of per-layer features, or (feature, class_token)
        pairs when ``return_class_token`` is True. ``reshape`` is not
        supported (variable-length token lists cannot be reshaped to a
        grid).
        """
        assert reshape is False, "reshape is not supported for now"
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x_img, x_depth, x_img_mask, x_depth_mask, n, return_mae_aux=return_mae_aux,**kwargs)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x_img, x_depth, x_img_mask, x_depth_mask, n, return_mae_aux=return_mae_aux,**kwargs)

        feat_outputs, class_tokens = self.extract_features(outputs, norm)

        if return_class_token:
            return tuple(zip(feat_outputs, class_tokens))
        return tuple(feat_outputs)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility).

    Linear layers get truncated-normal weights (std 0.02) and zeroed bias;
    every other module type is left untouched.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
| 422 |
+
|
| 423 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-Small backbone: embed_dim 384, 12 layers, 6 heads, mem-eff attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-Base backbone: embed_dim 768, 12 layers, 12 heads, mem-eff attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-Large backbone: embed_dim 1024, 24 layers, 16 heads, mem-eff attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/cluster.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ClusterType(Enum):
    """Known compute environments; the value is the cluster's short string id."""
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _guess_cluster_type() -> ClusterType:
    """Infer the current cluster from kernel/hostname fingerprints.

    Falls back to FAIR when no AWS- or RSC-specific marker is found.
    """
    info = os.uname()
    if info.sysname == "Linux":
        # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
        if info.release.endswith("-aws"):
            return ClusterType.AWS
        # RSC hosts run standard kernels but their hostnames start with "rsc"
        if info.nodename.startswith("rsc"):
            return ClusterType.RSC
    return ClusterType.FAIR
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return the given cluster type, auto-detecting it when None."""
    return _guess_cluster_type() if cluster_type is None else cluster_type
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Root checkpoint directory for the (possibly auto-detected) cluster."""
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    dirname_by_cluster = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/") / dirname_by_cluster[cluster_type]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Per-user checkpoint directory: the $USER subdirectory of the cluster root."""
    root = get_checkpoint_path(cluster_type)
    if root is None:
        return None

    user = os.environ.get("USER")
    assert user is not None
    return root / user
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Default SLURM partition for the cluster, or None when it cannot be resolved."""
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    partition_by_cluster = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return partition_by_cluster[cluster_type]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Assemble SLURM executor parameters for `nodes` x `num_gpus_per_node`.

    Defaults are built first, then adjusted per cluster; any keyword
    arguments are applied last and therefore override everything.
    """
    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }

    # Cluster-specific adjustments.
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type == ClusterType.AWS:
        params["cpus_per_task"] = 12
        del params["mem_gb"]
    elif cluster_type == ClusterType.RSC:
        params["cpus_per_task"] = 12

    # Caller-supplied overrides win.
    params.update(kwargs)
    return params
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/config.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
|
| 12 |
+
import dinov2.distributed as distributed
|
| 13 |
+
from dinov2.logging import setup_logging
|
| 14 |
+
from dinov2.utils import utils
|
| 15 |
+
from dinov2.configs import dinov2_default_config
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("dinov2")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def apply_scaling_rules_to_cfg(cfg):  # to fix
    """Derive cfg.optim.lr from cfg.optim.base_lr using the configured scaling rule.

    Only the "sqrt_wrt_1024" rule is implemented: lr scales with the square
    root of the global batch size relative to 1024.
    """
    if cfg.optim.scaling_rule != "sqrt_wrt_1024":
        raise NotImplementedError
    base_lr = cfg.optim.base_lr
    global_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size()
    cfg.optim.lr = base_lr * math.sqrt(global_batch_size / 1024.0)
    logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    return cfg
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def write_config(cfg, output_dir, name="config.yaml"):
    """Log the resolved config and save it as YAML under output_dir.

    Returns the path of the saved file.
    """
    logger.info(OmegaConf.to_yaml(cfg))
    saved_cfg_path = os.path.join(output_dir, name)
    with open(saved_cfg_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return saved_cfg_path
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_cfg_from_args(args):
    """Build the run config: package defaults <- config file <- CLI overrides.

    Also mutates args: output_dir is made absolute and appended to args.opts
    so it reaches the merged config as train.output_dir.
    """
    args.output_dir = os.path.abspath(args.output_dir)
    args.opts += [f"train.output_dir={args.output_dir}"]
    base_cfg = OmegaConf.create(dinov2_default_config)
    file_cfg = OmegaConf.load(args.config_file)
    cli_cfg = OmegaConf.from_cli(args.opts)
    return OmegaConf.merge(base_cfg, file_cfg, cli_cfg)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def default_setup(args):
    """Common run setup: distributed init, logging configuration, and seeding.

    NOTE: rebinds the module-level ``logger`` after ``setup_logging`` so that
    subsequent log calls in this module use the freshly configured handlers.
    """
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    # Offset the seed by rank so each process draws a distinct RNG stream.
    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n {}\n".format(utils.get_sha()))
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg_from_args(args)
    # Output dir must exist before default_setup configures file logging
    # and before the config is written into it.
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(cfg)
    write_config(cfg, args.output_dir)
    return cfg
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/dtype.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from typing import Dict, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Anything that names a dtype: a string (e.g. "float32"), a numpy dtype,
# or a torch dtype.
TypeSpec = Union[str, np.dtype, torch.dtype]


# Exact numpy -> torch dtype correspondence for the types both libraries share.
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Coerce a dtype specification (str, numpy dtype, or torch dtype) to a torch.dtype.

    Raises:
        KeyError: if the numpy dtype has no torch equivalent.
        TypeError: (from np.dtype) if the string does not name a dtype.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    # Fixed typo in the original assertion message ("nunpy" -> "numpy").
    assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
    return _NUMPY_TO_TORCH_DTYPE[dtype]
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/param_groups.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger("dinov2")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
    Returns:
        lr decay rate for the given parameter.
    """
    # Default: behave like the topmost layer (multiplier lr_decay_rate**0 == 1).
    layer_id = num_layers + 1
    if name.startswith("backbone") or force_is_backbone:
        # Embedding-level parameters decay the most (layer 0).
        if (
            ".pos_embed" in name
            or ".patch_embed" in name
            or ".mask_token" in name
            or ".cls_token" in name
            or ".register_tokens" in name
        ):
            layer_id = 0
        # Same embedding-level names without the leading dot, for models where
        # the whole module is treated as backbone (no "backbone." prefix).
        elif force_is_backbone and (
            "pos_embed" in name
            or "patch_embed" in name
            or "mask_token" in name
            or "cls_token" in name
            or "register_tokens" in name
        ):
            layer_id = 0
        # "prefix.blocks.<idx>....": block index is the 3rd dotted component.
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
        # Chunked-FSDP layout: "blocks.<chunk>.<idx>...." — index is the 3rd component.
        elif chunked_blocks and "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
        # Plain layout: "blocks.<idx>...." — index is the 2nd component.
        elif "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1

    # Deeper layers (larger layer_id) get a multiplier closer to 1.
    return lr_decay_rate ** (num_layers + 1 - layer_id)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    """Build one optimizer param group per parameter, with layerwise lr decay.

    Each group carries lr/wd multipliers and an is_last_layer flag; use
    fuse_params_groups() afterwards to merge groups with identical settings.
    """
    # Discover how many transformer blocks the model has; the attribute layout
    # differs between chunked-FSDP models, plain ViTs, and wrapped backbones.
    chunked_blocks = False
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0
    all_param_groups = []

    for name, param in model.named_parameters():
        # Strip FSDP wrapper noise so name-based rules below still match.
        name = name.replace("_fsdp_wrapped_module.", "")
        if not param.requires_grad:
            continue
        decay_rate = get_vit_lr_decay_rate(
            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
        )
        d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}

        if "last_layer" in name:
            d.update({"is_last_layer": True})

        # No weight decay on biases, norm parameters, and layer-scale gammas.
        if name.endswith(".bias") or "norm" in name or "gamma" in name:
            d.update({"wd_multiplier": 0.0})

        # Optionally rescale the patch-embedding learning rate.
        if "patch_embed" in name:
            d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})

        all_param_groups.append(d)
        logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")

    return all_param_groups
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    """Merge per-parameter groups that share identical settings for `keys`.

    Returns a view of fused groups, each holding the shared settings plus a
    "params" list of all parameters that matched them.
    """
    fused_params_groups = defaultdict(lambda: {"params": []})
    for group in all_params_groups:
        # Key the fused bucket by the concatenation of the selected settings.
        identifier = "".join(f"{k}{group[k]}_" for k in keys)

        bucket = fused_params_groups[identifier]
        for k in keys:
            bucket[k] = group[k]
        bucket["params"].append(group["params"])

    return fused_params_groups.values()
|
third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/utils.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import subprocess
|
| 10 |
+
from urllib.parse import urlparse
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger("dinov2")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """Load a checkpoint (local path or URL) into `model`, non-strictly.

    If `checkpoint_key` is present in the loaded dict (e.g. "teacher"),
    only that sub-dict is used. Common wrapper prefixes are stripped so the
    keys line up with a bare backbone.
    """
    if urlparse(pretrained_weights).scheme:  # If it looks like an URL
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    # remove `module.` prefix
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # remove `backbone.` prefix induced by multicrop wrapper
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    # strict=False: tolerate missing/unexpected keys, but log the mismatch.
    msg = model.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def fix_random_seeds(seed=31):
    """Seed torch (CPU and all CUDA devices), NumPy, and the stdlib RNG."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_sha():
    """Best-effort one-line description of the git state of this source tree.

    Returns "sha: N/A, status: clean, branch: N/A" when git is unavailable
    or this file is not inside a checkout.
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha, diff, branch = "N/A", "clean", "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = "has uncommitted changes" if _run(["git", "diff-index", "HEAD"]) else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        pass  # deliberately best-effort: keep the defaults
    return f"sha: {sha}, status: {diff}, branch: {branch}"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class CosineScheduler(object):
    """Piecewise schedule: freeze (zeros), linear warmup, then cosine decay.

    Indexing past total_iters always yields final_value.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        frozen = np.zeros((freeze_iters))
        warmup = np.linspace(start_warmup_value, base_value, warmup_iters)
        steps = np.arange(total_iters - warmup_iters - freeze_iters)
        cosine = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * steps / len(steps)))
        self.schedule = np.concatenate((frozen, warmup, cosine))

        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        return self.final_value if it >= self.total_iters else self.schedule[it]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def has_batchnorms(model):
    """Return True iff `model` contains any batch-norm layer (1d/2d/3d/sync).

    Uses modules() with any() instead of the original named_modules() loop,
    which discarded the names it iterated.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(module, bn_types) for module in model.modules())
|
third_party/lingbot_depth/mdm/model/modules_decoder.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from numbers import Number
|
| 3 |
+
import importlib
|
| 4 |
+
import itertools
|
| 5 |
+
import functools
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from .utils import wrap_module_with_gradient_checkpointing
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ResidualConvBlock(nn.Module):
    """Pre-activation residual block: (norm, act, conv) x 2 plus a skip path.

    The skip is a 1x1 conv when in/out channel counts differ, identity otherwise.
    out_channels and hidden_channels default to in_channels.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        hidden_channels: int = None,
        kernel_size: int = 3,
        padding_mode: str = 'replicate',
        activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
        in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm',
        hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm',
    ):
        super(ResidualConvBlock, self).__init__()
        out_channels = in_channels if out_channels is None else out_channels
        hidden_channels = in_channels if hidden_channels is None else hidden_channels

        activation_classes = {
            'relu': nn.ReLU,
            'leaky_relu': functools.partial(nn.LeakyReLU, negative_slope=0.2),
            'silu': nn.SiLU,
            'elu': nn.ELU,
        }
        if activation not in activation_classes:
            raise ValueError(f'Unsupported activation function: {activation}')
        activation_cls = activation_classes[activation]

        def make_norm(kind: str, channels: int) -> nn.Module:
            # 'none' (and any unlisted value) falls through to Identity,
            # matching the original conditional-expression chain.
            if kind == 'group_norm':
                return nn.GroupNorm(channels // 32, channels)
            if kind == 'layer_norm':
                return nn.GroupNorm(1, channels)
            if kind == 'instance_norm':
                return nn.InstanceNorm2d(channels)
            return nn.Identity()

        conv_kwargs = dict(kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode)
        self.layers = nn.Sequential(
            make_norm(in_norm, in_channels),
            activation_cls(),
            nn.Conv2d(in_channels, hidden_channels, **conv_kwargs),
            make_norm(hidden_norm, hidden_channels),
            activation_cls(),
            nn.Conv2d(hidden_channels, out_channels, **conv_kwargs),
        )

        self.skip_connection = nn.Identity() if in_channels == out_channels \
            else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        # Residual sum of the conv branch and the (possibly projected) input.
        return self.layers(x) + self.skip_connection(x)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class Resampler(nn.Sequential):
    # Re-scales a feature map by `scale_factor` (up or down depending on type_)
    # while mapping in_channels -> out_channels.
    def __init__(self,
        in_channels: int,
        out_channels: int,
        type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'],
        scale_factor: int = 2,
    ):
        if type_ == 'pixel_shuffle':
            nn.Sequential.__init__(self,
                nn.Conv2d(in_channels, out_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
                nn.PixelShuffle(scale_factor),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
            # Tie the scale_factor**2 sub-pixel slices of the first conv to the
            # first slice, so all shuffled sub-pixels start out identical
            # (avoids checkerboard artifacts at initialization).
            for i in range(1, scale_factor ** 2):
                self[0].weight.data[i::scale_factor ** 2] = self[0].weight.data[0::scale_factor ** 2]
                self[0].bias.data[i::scale_factor ** 2] = self[0].bias.data[0::scale_factor ** 2]
        elif type_ in ['nearest', 'bilinear']:
            nn.Sequential.__init__(self,
                nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None),
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
        elif type_ == 'conv_transpose':
            nn.Sequential.__init__(self,
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
            # Broadcast the top-left kernel tap across the whole kernel so every
            # output position of the transposed conv starts with the same weights.
            self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
        elif type_ == 'pixel_unshuffle':
            nn.Sequential.__init__(self,
                nn.PixelUnshuffle(scale_factor),
                nn.Conv2d(in_channels * (scale_factor ** 2), out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
        elif type_ == 'avg_pool':
            nn.Sequential.__init__(self,
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
                nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
            )
        elif type_ == 'max_pool':
            nn.Sequential.__init__(self,
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
                nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
            )
        else:
            raise ValueError(f'Unsupported resampler type: {type_}')
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class MLP(nn.Sequential):
    """Plain ReLU MLP over `dims`: Linear+ReLU per hidden transition, final Linear bare."""

    def __init__(self, dims: Sequence[int]):
        layers = []
        for dim_in, dim_out in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(dim_in, dim_out))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(dims[-2], dims[-1]))
        nn.Sequential.__init__(self, *layers)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class ConvStack(nn.Module):
    """Multi-scale conv stack: per-scale input projection, residual blocks, and
    a resampler between consecutive scales.

    At each scale i the (optionally projected) input feature is summed into the
    running state, passed through that scale's residual blocks, projected to
    the output, and finally resampled for the next scale.
    """

    def __init__(self,
        dim_in: List[Optional[int]],
        dim_res_blocks: List[int],
        dim_out: List[Optional[int]],
        resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List],
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        res_block_in_norm: Literal['layer_norm', 'group_norm', 'instance_norm', 'none'] = 'layer_norm',
        res_block_hidden_norm: Literal['layer_norm', 'group_norm', 'instance_norm', 'none'] = 'group_norm',
        activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
    ):
        super().__init__()
        # Fix: a single resampler given as a string IS a Sequence, so the
        # original `isinstance(resamplers, Sequence)` check would zip over its
        # characters. Treat strings (and other scalars) as "one type for all scales".
        if isinstance(resamplers, str) or not isinstance(resamplers, Sequence):
            resamplers = itertools.repeat(resamplers)
        # Per-scale 1x1 input projection; Identity when no input dim is given.
        self.input_blocks = nn.ModuleList([
            nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity()
            for dim_in_, dim_res_block_ in zip(dim_in if isinstance(dim_in, Sequence) else itertools.repeat(dim_in), dim_res_blocks)
        ])
        # One resampler between each pair of consecutive scales (always factor 2).
        self.resamplers = nn.ModuleList([
            Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
            for dim_prev, dim_succ, resampler in zip(dim_res_blocks[:-1], dim_res_blocks[1:], resamplers)
        ])
        # Per-scale chain of residual blocks; num_res_blocks may be a list
        # giving a different count per scale.
        self.res_blocks = nn.ModuleList([
            nn.Sequential(
                *(
                    ResidualConvBlock(
                        dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_,
                        activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm
                    ) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks)
                )
            ) for i, dim_res_block_ in enumerate(dim_res_blocks)
        ])
        # Per-scale 1x1 output projection; Identity when no output dim is given.
        self.output_blocks = nn.ModuleList([
            nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity()
            for dim_out_, dim_res_block_ in zip(dim_out if isinstance(dim_out, Sequence) else itertools.repeat(dim_out), dim_res_blocks)
        ])

    def enable_gradient_checkpointing(self):
        """Wrap resamplers and residual blocks for activation checkpointing."""
        for i in range(len(self.resamplers)):
            self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i])
        for i in range(len(self.res_blocks)):
            for j in range(len(self.res_blocks[i])):
                self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j])

    def forward(self, in_features: List[torch.Tensor]):
        """Run the stack; in_features[i] may be None past the first scale (skipped)."""
        out_features = []
        for i in range(len(self.res_blocks)):
            feature = self.input_blocks[i](in_features[i])
            if i == 0:
                x = feature
            elif feature is not None:
                # Identity input blocks pass None through; only sum real features.
                x = x + feature
            x = self.res_blocks[i](x)
            out_features.append(self.output_blocks[i](x))
            if i < len(self.res_blocks) - 1:
                x = self.resamplers[i](x)
        return out_features
|
third_party/lingbot_depth/mdm/model/modules_rgbd_encoder.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from numbers import Number
|
| 3 |
+
import importlib
|
| 4 |
+
import itertools
|
| 5 |
+
import functools
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from .dinov2_rgbd.models.vision_transformer import DinoVisionTransformer
|
| 14 |
+
from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DINOv2_RGBD_Encoder(nn.Module):
    """DINOv2-based RGB-D encoder.

    Resizes the RGB image and depth map to a 14-px-per-token grid, feeds both
    through a DINOv2-style backbone (`get_intermediate_layers_mae`), and
    projects the selected intermediate feature maps to `dim_out` channels with
    1x1 convolutions, summing them into a single feature map.
    """

    backbone: DinoVisionTransformer
    image_mean: torch.Tensor   # registered buffer, ImageNet mean, shape (1, 3, 1, 1)
    image_std: torch.Tensor    # registered buffer, ImageNet std, shape (1, 3, 1, 1)
    dim_features: int          # backbone embedding width (qkv in_features)

    def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, ignore_layers: Union[str, List[str]]=[], in_chans: int=3, strict: bool=True, img_depth_fuse_mode='', depth_emb_mode='', depth_mask_ratio=0.6, img_mask_ratio=0.0, **deprecated_kwargs):
        # NOTE(review): `ignore_layers=[]` is a mutable default; it is only read
        # here, never mutated, so it is harmless in practice.
        super(DINOv2_RGBD_Encoder, self).__init__()

        self.intermediate_layers = intermediate_layers
        self.strict = strict
        self.ignore_layers = ignore_layers
        self.img_mask_ratio = img_mask_ratio
        # Load the backbone: `backbone` names a factory function in
        # .dinov2_rgbd.hub.backbones (e.g. a dinov2_vit* constructor).
        self.hub_loader = getattr(importlib.import_module(".dinov2_rgbd.hub.backbones", __package__), backbone)
        self.backbone_name = backbone
        self.backbone = self.hub_loader(pretrained=False,
                                        in_chans=in_chans,
                                        img_depth_fuse_mode=img_depth_fuse_mode,
                                        depth_emb_mode=depth_emb_mode,
                                        depth_mask_ratio=depth_mask_ratio,
                                        img_mask_ratio=img_mask_ratio)

        # Embedding width, read off the first attention block's qkv projection.
        self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
        # Number of intermediate feature maps that will be extracted.
        self.num_features = intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers)

        # NOTE(review): this mask token is only created when img_mask_ratio > 0,
        # but forward() asserts img_mask_ratio == 0 — the parameter appears
        # unused on the current forward path; confirm whether it is needed.
        if img_mask_ratio > 0:
            self.mask_token_mae = nn.Parameter(torch.zeros(1, 1, self.dim_features))
            torch.nn.init.normal_(self.mask_token_mae, std=.02)

        # One 1x1 conv per extracted intermediate layer, projecting to dim_out.
        self.output_projections = nn.ModuleList([
            nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0,)
            for _ in range(self.num_features)
        ])

        # ImageNet normalization constants, stored as buffers so they follow
        # .to(device)/.half() with the module.
        self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    @property
    def onnx_compatible_mode(self):
        # Defaults to False until the setter is used.
        return getattr(self, "_onnx_compatible_mode", False)

    @onnx_compatible_mode.setter
    def onnx_compatible_mode(self, value: bool):
        # Propagate the flag into the backbone so both stay in sync.
        self._onnx_compatible_mode = value
        self.backbone.onnx_compatible_mode = value

    def init_weights(self):
        """Load pretrained backbone weights, optionally skipping entries whose
        key contains any substring in `self.ignore_layers`."""
        pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
        ignore_layers = []
        if isinstance(self.ignore_layers, str):
            ignore_layers = [self.ignore_layers]
        else:
            ignore_layers = self.ignore_layers

        if len(ignore_layers) == 0:
            self.backbone.load_state_dict(pretrained_backbone_state_dict, strict=self.strict)
        else:
            # Substring match: any key containing an ignore pattern is dropped.
            state_dict = {}
            for k, v in pretrained_backbone_state_dict.items():
                is_ignore = False
                for ig_k in ignore_layers:
                    if ig_k in k:
                        is_ignore = True
                        break
                if not is_ignore:
                    state_dict[k] = v
            self.backbone.load_state_dict(state_dict, strict=self.strict)

    def enable_gradient_checkpointing(self):
        # Wrap every transformer block so activations are recomputed in backward.
        for i in range(len(self.backbone.blocks)):
            wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])

    def enable_pytorch_native_sdpa(self):
        # Swap each block's attention forward for F.scaled_dot_product_attention.
        for i in range(len(self.backbone.blocks)):
            wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)

    def forward(self,
                image: torch.Tensor,
                depth: torch.Tensor,
                token_rows: Union[int, torch.LongTensor],
                token_cols: Union[int, torch.LongTensor],
                return_class_token: bool = False,
                remap_depth_in: str='linear',
                **kwargs):
        """Encode an RGB image plus depth map into a projected feature map.

        Args:
            image: (B, 3, H, W) RGB tensor, expected in [0, 1] before ImageNet
                normalization — assumption based on the mean/std applied here;
                TODO confirm against callers.
            depth: (B, 1, H, W) depth map; non-finite and near-zero values are
                treated as invalid.
            token_rows, token_cols: token grid size; inputs are resized to
                14 px per token.
            return_class_token: if True, also return the last layer's class token.
            remap_depth_in: 'linear' keeps raw depth; 'log' feeds log-depth
                (invalid pixels forced back to 0).

        Returns:
            (x, cls_token, None, None) when return_class_token is True, else
            (x, None, None). NOTE(review): the two branches have different
            tuple arity (4 vs 3) — callers in this repo always pass
            return_class_token=True; confirm the 3-tuple branch is intended.
        """
        # Resize to the token grid; antialias is disabled in ONNX mode because
        # antialiased interpolate is not exportable.
        image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=not self.onnx_compatible_mode)
        image_14 = (image_14 - self.image_mean) / self.image_std

        # Nearest-neighbor for depth to avoid mixing valid/invalid values.
        depth_14 = F.interpolate(depth, (token_rows * 14, token_cols * 14), mode="nearest")

        # set invalid depth value to zero
        depth_14[torch.isinf(depth_14)] = 0.0
        depth_14[torch.isnan(depth_14)] = 0.0
        # Validity threshold 0.01 — presumably meters (i.e. <1 cm is invalid);
        # TODO confirm depth units.
        dmask_14 = (depth_14 > 0.01).detach()
        depth_14 = depth_14 * dmask_14.float()

        if remap_depth_in == 'linear':
            pass # do nothing
        elif remap_depth_in == 'log':
            # log(0) = -inf on invalid pixels; both the mask write and
            # nan_to_num below force those back to 0.
            depth_14 = torch.log(depth_14)
            depth_14[~dmask_14] = 0.0
            depth_14 = torch.nan_to_num(depth_14, nan=0.0, posinf=0.0, neginf=0.0)
        else:
            raise NotImplementedError

        # Get intermediate layers from the backbone
        features = self.backbone.get_intermediate_layers_mae(
            x_img=image_14,
            x_depth=depth_14,
            n=self.intermediate_layers,
            return_class_token=True,
            **kwargs)

        assert self.img_mask_ratio == 0, "img_mask_ratio is not supported in this encoder"

        # Backbone may return per-layer lists of (feature chunks, cls tokens);
        # flatten each into single tensors by concatenating along batch dim,
        # keeping only the first token_rows*token_cols (valid) tokens.
        if isinstance(features[0][0], list):
            num_valid_tokens = token_rows * token_cols
            features = tuple(
                (
                    torch.cat([feat[:, :num_valid_tokens].contiguous() for feat in feats], dim=0),
                    torch.cat(cls_tokens, dim=0)
                )
                for feats, cls_tokens in features
            )

        # Project features to the desired dimensionality: tokens -> (B, C, rows,
        # cols) maps, one 1x1 conv per layer, summed across layers.
        x = torch.stack([
            proj(feat.permute(0, 2, 1)[:, :, :token_rows*token_cols].unflatten(2, (token_rows, token_cols)).contiguous())
            for proj, (feat, clstoken) in zip(self.output_projections, features)
        ], dim=1).sum(dim=1)
        # Class token of the last extracted layer.
        cls_token = features[-1][1]

        if return_class_token:
            return x, cls_token, None, None
        else:
            return x, None, None
|
third_party/lingbot_depth/mdm/model/utils.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
def wrap_module_with_gradient_checkpointing(module: nn.Module):
    """Patch `module` in place so its forward runs under activation checkpointing.

    The module's class is replaced by a dynamically created subclass whose
    forward wraps the original forward in `torch.utils.checkpoint.checkpoint`.
    Parameters and buffers are untouched. The original class is stashed on
    `_restore_cls` so `unwrap_module_with_gradient_checkpointing` can revert
    the patch. Returns the same (mutated) module for chaining.
    """
    from torch.utils.checkpoint import checkpoint
    class _CheckpointingWrapper(module.__class__):
        _restore_cls = module.__class__
        def forward(self, *args, **kwargs):
            # use_reentrant=False is the recommended non-reentrant variant and
            # supports keyword arguments passed through to the inner forward.
            return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)

    module.__class__ = _CheckpointingWrapper
    return module
| 18 |
+
def unwrap_module_with_gradient_checkpointing(module: nn.Module):
    """Revert `wrap_module_with_gradient_checkpointing` on `module`.

    Restores the class saved on `_restore_cls` by the wrapper. Returns the
    same module for chaining, mirroring the wrap function (the original
    returned None, which was inconsistent with its counterpart).

    Raises:
        AttributeError: if `module` was not previously wrapped (no `_restore_cls`).
    """
    module.__class__ = module.__class__._restore_cls
    return module
| 22 |
+
def wrap_dinov2_attention_with_sdpa(module: nn.Module):
    """Patch a DINOv2 attention module in place to use PyTorch-native SDPA.

    Replaces the module's class with a subclass whose forward computes
    attention via `F.scaled_dot_product_attention` instead of the original
    implementation. Relies on the module exposing `qkv`, `num_heads`, `proj`,
    and `proj_drop`. Returns the same (mutated) module.

    NOTE(review): the version check compares strings lexicographically, which
    would misorder a hypothetical major version >= 10; fine for current
    PyTorch releases.
    """
    assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
    class _AttentionWrapper(module.__class__):
        def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
            B, N, C = x.shape
            # Fused QKV projection split into per-head q/k/v tensors.
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)

            q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)

            # attn_bias is forwarded as SDPA's attn_mask (additive bias).
            x = F.scaled_dot_product_attention(q, k, v, attn_bias)
            x = x.permute(0, 2, 1, 3).reshape(B, N, C)

            x = self.proj(x)
            x = self.proj_drop(x)
            return x
    module.__class__ = _AttentionWrapper
    return module
|
| 40 |
+
def wrap_dinov3_attention_with_sdpa(module: nn.Module):
    """Patch a DINOv3 attention module in place to use PyTorch-native SDPA.

    The DINOv3 attention modules targeted here expose the same interface as
    DINOv2's (`qkv`, `num_heads`, `proj`, `proj_drop`), and the original body
    of this function was a byte-for-byte copy of
    `wrap_dinov2_attention_with_sdpa` — so delegate to it instead of
    duplicating the wrapper class. Returns the same (mutated) module.
    """
    return wrap_dinov2_attention_with_sdpa(module)
|
| 58 |
+
def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]:
    """DDP communication hook that synchronously averages gradients over WORLD.

    Args:
        state: unused hook state (required by the DDP hook signature).
        bucket: gradient bucket whose flat buffer is reduced in place.

    Returns:
        An already-completed future holding the averaged gradient buffer, as
        required by `register_comm_hook`.
    """
    group_to_use = torch.distributed.group.WORLD
    world_size = group_to_use.size()
    grad = bucket.buffer()
    # Divide before the sum-all-reduce so the result is the mean gradient.
    grad.div_(world_size)
    torch.distributed.all_reduce(grad, group=group_to_use)
    # all_reduce above is blocking, so the future can be resolved immediately.
    fut = torch.futures.Future()
    fut.set_result(grad)
    return fut
| 68 |
+
def depth_to_pointcloud(depth, intrinsic_normalized, depth_scale=1.0):
    """
    Convert depth map to point cloud (pure Tensor version, no point filtering)

    Args:
        depth: torch.Tensor, shape (H, W) or (B, H, W), depth map
        intrinsic_normalized: torch.Tensor, shape (3, 3) or (B, 3, 3), normalized intrinsic matrix
            Normalized intrinsics: fx' = fx/W, fy' = fy/H, cx' = cx/W, cy' = cy/H
            A single (3, 3) matrix is broadcast across a batched depth map.
        depth_scale: float, divisor applied to depth values, default 1.0
            (the original docstring claimed 1000.0, which was wrong)

    Returns:
        points: torch.Tensor, shape (H, W, 3) or (B, H, W, 3), point cloud coordinates (x, y, z)
    """
    # Handle batch dimension: remember whether to squeeze it back on return.
    squeeze_output = depth.dim() == 2
    if squeeze_output:
        depth = depth.unsqueeze(0)  # (1, H, W)
    # Promote a single intrinsic matrix to batch form regardless of whether the
    # depth came batched (the original only handled the unbatched-depth case).
    if intrinsic_normalized.dim() == 2:
        intrinsic_normalized = intrinsic_normalized.unsqueeze(0)  # (1, 3, 3)

    B, H, W = depth.shape
    device = depth.device

    # Denormalize intrinsics to pixel units; keep a trailing (.., 1, 1) shape so
    # they broadcast against (B, H, W). view(-1, 1, 1) also lets a batch-1
    # intrinsic broadcast over a batch-B depth.
    fx = (intrinsic_normalized[:, 0, 0] * W).view(-1, 1, 1)
    fy = (intrinsic_normalized[:, 1, 1] * H).view(-1, 1, 1)
    cx = (intrinsic_normalized[:, 0, 2] * W).view(-1, 1, 1)
    cy = (intrinsic_normalized[:, 1, 2] * H).view(-1, 1, 1)

    # Pixel coordinate grid (H, W); broadcasts against the batch dimension.
    v, u = torch.meshgrid(
        torch.arange(H, device=device, dtype=torch.float32),
        torch.arange(W, device=device, dtype=torch.float32),
        indexing='ij'
    )

    # Backproject to 3D camera space via the pinhole model.
    z = depth / depth_scale       # (B, H, W)
    x = (u - cx) * z / fx         # (B, H, W)
    y = (v - cy) * z / fy         # (B, H, W)

    # Stack coordinates (B, H, W, 3)
    points = torch.stack([x, y, z], dim=-1)

    if squeeze_output:
        points = points.squeeze(0)  # (H, W, 3)

    return points
third_party/lingbot_depth/mdm/model/v2.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from numbers import Number
|
| 3 |
+
from functools import partial
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import warnings
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch.utils
|
| 11 |
+
import torch.utils.checkpoint
|
| 12 |
+
import torch.amp
|
| 13 |
+
import torch.version
|
| 14 |
+
from huggingface_hub import hf_hub_download
|
| 15 |
+
|
| 16 |
+
from .modules_rgbd_encoder import DINOv2_RGBD_Encoder
|
| 17 |
+
from .modules_decoder import MLP, ConvStack
|
| 18 |
+
from ..utils.geo import depth_to_pointcloud, normalized_view_plane_uv
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MDMModel(nn.Module):
    """Metric depth model: RGB-D encoder + shared conv neck + optional heads.

    The encoder produces a single feature map and a class token; the neck
    expands it into a 5-level pyramid (UV maps are concatenated at each level),
    and per-task heads decode depth and a validity mask.

    NOTE(review): `normal_head` and `scale_head` configs are accepted by
    __init__ but no corresponding modules are constructed, so the `hasattr`
    guards in forward() make normal/metric_scale outputs always None unless
    those attributes are attached externally — confirm this is intended.
    """
    # Class-level annotations (attributes are assigned in __init__).
    encoder: Union[DINOv2_RGBD_Encoder]  # NOTE(review): single-member Union; effectively just DINOv2_RGBD_Encoder
    neck: ConvStack
    points_head: ConvStack
    mask_head: ConvStack
    scale_head: MLP
    onnx_compatible_mode: bool

    def __init__(self,
                 encoder: Dict[str, Any],
                 neck: Dict[str, Any],
                 depth_head: Dict[str, Any] = None,
                 mask_head: Dict[str, Any] = None,
                 normal_head: Dict[str, Any] = None,
                 scale_head: Dict[str, Any] = None,
                 remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
                 remap_depth_in: Literal['linear', 'log'] = 'log',
                 remap_depth_out: Literal['linear', 'exp'] = 'exp',
                 num_tokens_range: List[int] = [1200, 3600],
                 **deprecated_kwargs
                 ):
        # NOTE(review): num_tokens_range uses a mutable default list; it is
        # only read, never mutated, so it is harmless in practice.
        super(MDMModel, self).__init__()
        if deprecated_kwargs:
            warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")

        self.remap_output = remap_output
        self.num_tokens_range = num_tokens_range
        self.remap_depth_in = remap_depth_in
        self.remap_depth_out = remap_depth_out

        # Sub-module configs are passed as kwargs dicts.
        self.encoder = DINOv2_RGBD_Encoder(**encoder)

        self.neck = ConvStack(**neck)
        if depth_head is not None:
            self.depth_head = ConvStack(**depth_head)
        if mask_head is not None:
            self.mask_head = ConvStack(**mask_head)

    @property
    def device(self) -> torch.device:
        # Device of the first parameter; assumes all parameters share a device.
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        # Dtype of the first parameter; assumes a uniform parameter dtype.
        return next(self.parameters()).dtype

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path, IO[bytes]],
        model_kwargs: Optional[Dict[str, Any]] = None,
        **hf_kwargs) -> 'MDMModel':
        """Load a checkpoint from a local path or a Hugging Face repo id.

        The checkpoint dict must hold 'model_config' (constructor kwargs) and
        'model' (state dict). `model_kwargs` overrides entries of the stored
        config. Loads with strict=False, so missing/unexpected keys are
        silently tolerated.
        """
        if Path(pretrained_model_name_or_path).exists():
            checkpoint_path = pretrained_model_name_or_path
        else:
            checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="model.pt",
                **hf_kwargs
            )
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)

        model_config = checkpoint['model_config']
        if model_kwargs is not None:
            model_config.update(model_kwargs)
        model = cls(**model_config)
        model.load_state_dict(checkpoint['model'], strict=False)

        return model

    def init_weights(self):
        # Only the encoder has pretrained weights to pull in.
        self.encoder.init_weights()

    def enable_pytorch_native_sdpa(self):
        self.encoder.enable_pytorch_native_sdpa()

    def forward(self,
                image: torch.Tensor,
                num_tokens: Union[int, torch.LongTensor],
                depth: Union[None, torch.Tensor]=None,
                **kwargs) -> Dict[str, torch.Tensor]:
        """Run the full model.

        Args:
            image: (B, 3, H, W) input image.
            num_tokens: target token budget; the token grid (base_h, base_w) is
                derived from it and the image aspect ratio.
            depth: (B, H, W) or (B, 1, H, W) input depth; required here.

        Returns:
            Dict with any of 'depth_reg' (B, H, W), 'normal' (B, H, W, 3,
            unit-normalized), 'mask' (B, H, W, probabilities); None entries
            are dropped.
        """
        batch_size, _, img_h, img_w = image.shape
        device, dtype = image.device, image.dtype

        assert depth is not None # in this version, depth is required
        if depth.dim() == 3:
            depth = depth.unsqueeze(1) # from (B, H, W) to (B, 1, H, W)

        # Token grid chosen so base_h * base_w ~= num_tokens with the image's
        # aspect ratio.
        aspect_ratio = img_w / img_h
        base_h, base_w = (num_tokens / aspect_ratio) ** 0.5, (num_tokens * aspect_ratio) ** 0.5
        if isinstance(base_h, torch.Tensor):
            base_h, base_w = base_h.round().long(), base_w.round().long()
        else:
            base_h, base_w = round(base_h), round(base_w)

        # Backbones encoding
        features, cls_token, _, _ = self.encoder(image, depth, base_h, base_w, return_class_token=True, remap_depth_in=self.remap_depth_in, **kwargs)

        # Broadcast-add the class token to every spatial position, then build a
        # 5-level input list for the neck (only level 0 has encoder features).
        features = features + cls_token[..., None, None]
        features = [features, None, None, None, None]

        # Concat UVs for aspect ratio input
        for level in range(5):
            uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
            if features[level] is None:
                features[level] = uv
            else:
                features[level] = torch.concat([features[level], uv], dim=1)

        # Shared neck
        features = self.neck(features)

        # Heads decoding: take the last (finest) output of each present head.
        depth_reg, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['depth_head', 'normal_head', 'mask_head'])
        metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None

        # Resize head outputs back to the input resolution.
        depth_reg, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [depth_reg, normal, mask])

        # Remap output
        if depth_reg is not None:
            if self.remap_depth_out == 'exp':
                # Head predicts log-depth; exponentiate back to metric depth.
                depth_reg = depth_reg.exp().squeeze(1)
            elif self.remap_depth_out == 'linear':
                depth_reg = depth_reg.squeeze(1)
            else:
                raise ValueError(f"Invalid remap_depth_out: {self.remap_depth_out}")
        if normal is not None:
            normal = normal.permute(0, 2, 3, 1)
            normal = F.normalize(normal, dim=-1)
        if mask is not None:
            mask_prob = mask.squeeze(1).sigmoid()
            # mask_logits = mask.squeeze(1)
        else:
            mask_prob = None
        if metric_scale is not None:
            metric_scale = metric_scale.squeeze(1).exp()

        # NOTE(review): metric_scale is computed but not returned — confirm
        # whether it should be part of the output dict.
        return_dict = {
            'depth_reg': depth_reg,
            'normal': normal,
            'mask': mask_prob,
        }
        return_dict = {k: v for k, v in return_dict.items() if v is not None}

        return return_dict

    @torch.inference_mode()
    def infer(
        self,
        image: torch.Tensor,
        depth_in: torch.Tensor = None,
        num_tokens: int = None,
        resolution_level: int = 9,
        apply_mask: bool = True,
        use_fp16: bool = True,
        intrinsics: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Inference convenience wrapper around forward().

        Args:
            image: (3, H, W) or (B, 3, H, W); a missing batch dim is added and
                squeezed back out on return.
            depth_in: optional input depth, (H, W) promoted to (1, H, W).
            num_tokens: explicit token budget; if None it is interpolated from
                `resolution_level` (0..9) across self.num_tokens_range.
            apply_mask: replace masked-out pixels with +inf in depth/points.
            use_fp16: NOTE(review): despite the name this enables a *bfloat16*
                autocast (skipped when the model is already bf16).
            intrinsics: optional normalized (3, 3)/(B, 3, 3) intrinsics; when
                given, a point cloud is backprojected from the predicted depth.

        Returns:
            Dict with any of 'points', 'depth', 'mask' (mask is boolean).
        """
        if image.dim() == 3:
            omit_batch_dim = True
            image = image.unsqueeze(0)
        else:
            omit_batch_dim = False
        image = image.to(dtype=self.dtype, device=self.device)

        if (depth_in is not None) and (depth_in.dim() == 2):
            depth_in = depth_in.unsqueeze(0).to(dtype=self.dtype, device=self.device)

        original_height, original_width = image.shape[-2:]
        # NOTE(review): `area` and `aspect_ratio` are computed but unused here.
        area = original_height * original_width
        aspect_ratio = original_width / original_height

        # Determine the number of base tokens to use
        if num_tokens is None:
            min_tokens, max_tokens = self.num_tokens_range
            num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))

        # Forward pass
        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=use_fp16 and self.dtype != torch.bfloat16):
            output = self.forward(image, num_tokens=num_tokens, depth=depth_in, **kwargs)
        depth_reg, mask = (output.get(k, None) for k in ['depth_reg', 'mask'])

        # Always process the output in fp32 precision
        depth_reg, mask = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [depth_reg, mask])
        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
            if mask is not None:
                mask_binary = mask > 0.5
            else:
                mask_binary = None

            depth = depth_reg
            if intrinsics is not None:
                points = depth_to_pointcloud(depth, intrinsics)
            else:
                points = None

            # Apply mask: invalid pixels become +inf rather than being dropped.
            if apply_mask and mask_binary is not None:
                points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None
                depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None

        return_dict = {
            'points': points,
            'depth': depth,
            'mask': mask_binary,
        }
        return_dict = {k: v for k, v in return_dict.items() if v is not None}

        if omit_batch_dim:
            return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}

        return return_dict

    def forward_feat(self,
                     image: torch.Tensor,
                     num_tokens: Union[int, torch.LongTensor],
                     depth: Union[None, torch.Tensor]=None,
                     **kwargs) -> Dict[str, torch.Tensor]:
        """Encoder-only forward: returns (features, cls_token) without running
        the neck or heads. Mirrors the prologue of forward()."""
        batch_size, _, img_h, img_w = image.shape
        device, dtype = image.device, image.dtype

        assert depth is not None # in this version, depth is required
        if depth.dim() == 3:
            depth = depth.unsqueeze(1) # from (B, H, W) to (B, 1, H, W)

        # Same token-grid computation as forward().
        aspect_ratio = img_w / img_h
        base_h, base_w = (num_tokens / aspect_ratio) ** 0.5, (num_tokens * aspect_ratio) ** 0.5
        if isinstance(base_h, torch.Tensor):
            base_h, base_w = base_h.round().long(), base_w.round().long()
        else:
            base_h, base_w = round(base_h), round(base_w)

        # Backbones encoding
        features, cls_token, _, _ = self.encoder(image, depth, base_h, base_w, return_class_token=True, remap_depth_in=self.remap_depth_in, **kwargs)

        return features, cls_token


    @torch.inference_mode()
    def infer_feat(
        self,
        image: torch.Tensor,
        depth_in: torch.Tensor = None,
        num_tokens: int = None,
        resolution_level: int = 9,
        apply_mask: bool = True,
        use_fp16: bool = True,
        intrinsics: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Inference wrapper for forward_feat(); returns (features, cls_token).

        `apply_mask` and `intrinsics` are accepted for signature parity with
        infer() but unused here.
        """
        if image.dim() == 3:
            omit_batch_dim = True
            image = image.unsqueeze(0)
        else:
            omit_batch_dim = False
        image = image.to(dtype=self.dtype, device=self.device)

        if (depth_in is not None) and (depth_in.dim() == 2):
            depth_in = depth_in.unsqueeze(0).to(dtype=self.dtype, device=self.device)

        original_height, original_width = image.shape[-2:]
        # NOTE(review): unused, kept for parity with infer().
        area = original_height * original_width
        aspect_ratio = original_width / original_height

        # Determine the number of base tokens to use
        if num_tokens is None:
            min_tokens, max_tokens = self.num_tokens_range
            num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))

        # Forward pass (bf16 autocast, same caveat as infer()).
        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=use_fp16 and self.dtype != torch.bfloat16):
            features, cls_token = self.forward_feat(image, num_tokens=num_tokens, depth=depth_in, **kwargs)

        return features, cls_token
|
third_party/lingbot_depth/mdm/utils/__init__.py
ADDED
|
File without changes
|
third_party/lingbot_depth/mdm/utils/geo.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
    if aspect_ratio is None:
        aspect_ratio = width / height

    # Half-extents of the view plane, normalized by its diagonal length.
    diagonal = (1 + aspect_ratio ** 2) ** 0.5
    half_u = aspect_ratio / diagonal
    half_v = 1 / diagonal

    # Sample at pixel centers: endpoints are pulled in by half a pixel, i.e.
    # scaled by (n - 1) / n on each axis.
    us = torch.linspace(-half_u * (width - 1) / width, half_u * (width - 1) / width, width, dtype=dtype, device=device)
    vs = torch.linspace(-half_v * (height - 1) / height, half_v * (height - 1) / height, height, dtype=dtype, device=device)

    # 'xy' indexing yields (height, width) grids; stack into (H, W, 2).
    grid_u, grid_v = torch.meshgrid(us, vs, indexing='xy')
    return torch.stack([grid_u, grid_v], dim=-1)
|
| 17 |
+
def depth_to_pointcloud(depth, intrinsic_normalized, depth_scale=1.0):
    """
    Convert depth map to point cloud (pure Tensor version, no point filtering)

    Args:
        depth: torch.Tensor, shape (H, W) or (B, H, W), depth map
        intrinsic_normalized: torch.Tensor, shape (3, 3) or (B, 3, 3), normalized intrinsic matrix
            Normalized intrinsics: fx' = fx/W, fy' = fy/H, cx' = cx/W, cy' = cy/H
            A single (3, 3) matrix is broadcast across a batched depth map.
        depth_scale: float, divisor applied to depth values, default 1.0
            (the original docstring claimed 1000.0, which was wrong)

    Returns:
        points: torch.Tensor, shape (H, W, 3) or (B, H, W, 3), point cloud coordinates (x, y, z)
    """
    # Handle batch dimension: remember whether to squeeze it back on return.
    squeeze_output = depth.dim() == 2
    if squeeze_output:
        depth = depth.unsqueeze(0)  # (1, H, W)
    # Promote a single intrinsic matrix to batch form regardless of whether the
    # depth came batched (the original only handled the unbatched-depth case).
    if intrinsic_normalized.dim() == 2:
        intrinsic_normalized = intrinsic_normalized.unsqueeze(0)  # (1, 3, 3)

    B, H, W = depth.shape
    device = depth.device

    # Denormalize intrinsics to pixel units; keep a trailing (.., 1, 1) shape so
    # they broadcast against (B, H, W). view(-1, 1, 1) also lets a batch-1
    # intrinsic broadcast over a batch-B depth.
    fx = (intrinsic_normalized[:, 0, 0] * W).view(-1, 1, 1)
    fy = (intrinsic_normalized[:, 1, 1] * H).view(-1, 1, 1)
    cx = (intrinsic_normalized[:, 0, 2] * W).view(-1, 1, 1)
    cy = (intrinsic_normalized[:, 1, 2] * H).view(-1, 1, 1)

    # Pixel coordinate grid (H, W); broadcasts against the batch dimension.
    v, u = torch.meshgrid(
        torch.arange(H, device=device, dtype=torch.float32),
        torch.arange(W, device=device, dtype=torch.float32),
        indexing='ij'
    )

    # Backproject to 3D camera space via the pinhole model.
    z = depth / depth_scale       # (B, H, W)
    x = (u - cx) * z / fx         # (B, H, W)
    y = (v - cy) * z / fy         # (B, H, W)

    # Stack coordinates (B, H, W, 3)
    points = torch.stack([x, y, z], dim=-1)

    if squeeze_output:
        points = points.squeeze(0)  # (H, W, 3)

    return points
|
| 78 |
+
|
| 79 |
+
# Usage example
if __name__ == "__main__":
    # Single image: random depth with a typical Kinect-style intrinsic,
    # pre-normalized by the image size.
    depth_map = torch.rand(480, 640) * 5000  # Depth values
    K_normalized = torch.tensor([
        [525.0 / 640, 0, 319.5 / 640],
        [0, 525.0 / 480, 239.5 / 480],
        [0, 0, 1],
    ])

    cloud = depth_to_pointcloud(depth_map, K_normalized)
    print(f"Point cloud shape: {cloud.shape}")  # (480, 640, 3)

    # Batch processing: repeat the same intrinsics across the batch.
    depth_batched = torch.rand(4, 480, 640) * 5000
    K_batched = K_normalized.unsqueeze(0).expand(4, -1, -1)

    cloud_batched = depth_to_pointcloud(depth_batched, K_batched)
    print(f"Batch point cloud shape: {cloud_batched.shape}")  # (4, 480, 640, 3)

    # Flatten to (N, 3) format if needed.
    cloud_flat = cloud.reshape(-1, 3)
    print(f"Flattened shape: {cloud_flat.shape}")  # (480*640, 3)

    # Batch flatten to (B, N, 3).
    cloud_batched_flat = cloud_batched.reshape(4, -1, 3)
    print(f"Batch flattened shape: {cloud_batched_flat.shape}")  # (4, 480*640, 3)
|
third_party/lingbot_depth/mdm/utils/io.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 3 |
+
from typing import IO
|
| 4 |
+
import zipfile
|
| 5 |
+
import json
|
| 6 |
+
import io
|
| 7 |
+
from typing import *
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import re
|
| 10 |
+
from PIL import Image, PngImagePlugin
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import cv2
|
| 14 |
+
|
| 15 |
+
from .tools import timeit
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def save_glb(
    save_path: Union[str, os.PathLike],
    vertices: np.ndarray,
    faces: np.ndarray,
    vertex_uvs: np.ndarray,
    texture: np.ndarray,
    vertex_normals: Optional[np.ndarray] = None,
):
    """
    Export a textured triangle mesh to a GLB file.

    Parameters are numpy arrays: vertices (V, 3), faces (F, 3), vertex_uvs (V, 2),
    texture as an RGB image array, and optional per-vertex normals (V, 3).
    """
    import trimesh
    import trimesh.visual
    from PIL import Image

    # Wrap the texture image in a PBR material; constants match the original export.
    material = trimesh.visual.material.PBRMaterial(
        baseColorTexture=Image.fromarray(texture),
        metallicFactor=0.5,
        roughnessFactor=1.0
    )
    visual = trimesh.visual.texture.TextureVisuals(uv=vertex_uvs, material=material)
    # process=False keeps vertex order/indices exactly as provided.
    mesh = trimesh.Trimesh(
        vertices=vertices,
        vertex_normals=vertex_normals,
        faces=faces,
        visual=visual,
        process=False
    )
    mesh.export(save_path)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def save_ply(
    save_path: Union[str, os.PathLike],
    vertices: np.ndarray,
    faces: np.ndarray,
    vertex_colors: np.ndarray,
    vertex_normals: Optional[np.ndarray] = None,
):
    """
    Export a vertex-colored triangle mesh to a PLY file.

    vertices (V, 3), faces (F, 3), vertex_colors (V, 3 or 4), optional normals (V, 3).
    """
    import trimesh
    import trimesh.visual
    from PIL import Image

    # process=False preserves the given vertex/face ordering untouched.
    mesh = trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=vertex_colors,
        vertex_normals=vertex_normals,
        process=False
    )
    mesh.export(save_path)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray:
    """
    Read a image, return uint8 RGB array of shape (H, W, 3).
    """
    # Accept either a filesystem path or an already-open binary file object.
    raw = Path(path).read_bytes() if isinstance(path, (str, os.PathLike)) else path.read()
    bgr = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    # OpenCV decodes to BGR; convert to RGB for the rest of the pipeline.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95):
    """
    Write a image, input uint8 RGB array of shape (H, W, 3).

    NOTE(review): the image is always JPEG-encoded regardless of the path's
    extension — confirm this is intended for .png destinations.
    """
    bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    payload = cv2.imencode('.jpg', bgr, [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes()
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_bytes(payload)
    else:
        path.write(payload)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def read_depth(path: Union[str, os.PathLike, IO]) -> np.ndarray:
    """
    Read a depth image, return float32 depth array of shape (H, W).

    Decodes the 16-bit log-encoded PNG written by `write_depth`:
    0 -> NaN (unknown), 65535 -> +inf, 1..65534 -> log-interpolated [near, far].
    """
    raw = Path(path).read_bytes() if isinstance(path, (str, os.PathLike)) else path.read()
    pil_image = Image.open(io.BytesIO(raw))
    near = float(pil_image.info.get('near'))
    far = float(pil_image.info.get('far'))
    encoded = np.array(pil_image)
    mask_nan = encoded == 0
    mask_inf = encoded == 65535
    # Map 1..65534 to t in [0, 1], then interpolate logarithmically between near and far.
    t = (encoded.astype(np.float32) - 1) / 65533
    depth = near ** (1 - t) * far ** t
    if 'unit' in pil_image.info:  # Legacy support for depth units
        depth = depth * float(pil_image.info.get('unit'))
    depth[mask_nan] = np.nan
    depth[mask_inf] = np.inf
    return depth
|
| 110 |
+
|
| 111 |
+
def write_depth(
    path: Union[str, os.PathLike, IO],
    depth: np.ndarray,
    max_range: float = 1e5,
    compression_level: int = 7,
):
    """
    Encode and write a depth image as 16-bit PNG format.

    ## Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to write to.
    - `depth: np.ndarray`
        The depth array, float32 array of shape (H, W).
        May contain `NaN` for invalid values and `Inf` for infinite values.
    - `max_range: float`
        Maximum far/near ratio retained by the encoding.
    - `compression_level: int`
        PNG compression level.

    Depth values are encoded as follows:
    - 0: unknown
    - 1 ~ 65534: depth values in logarithmic
    - 65535: infinity

    metadata is stored in the PNG file as text fields:
    - `near`: the minimum depth value
    - `far`: the maximum depth value
    """
    # Masks must be computed before the NaN/Inf values are overwritten below.
    mask_values = np.isfinite(depth)
    mask_nan = np.isnan(depth)
    mask_inf = np.isinf(depth)

    depth = depth.astype(np.float32)
    # FIX: removed dead assignment `mask_finite = depth` from the original —
    # it bound the raw depth array to a misleading name and was never used.
    near = max(depth[mask_values].min(), 1e-5)
    far = max(near * 1.1, min(depth[mask_values].max(), near * max_range))
    # Quantize log(depth/near)/log(far/near) into 1..65534.
    depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16)  # 1~65534
    depth[mask_nan] = 0
    depth[mask_inf] = 65535

    pil_image = Image.fromarray(depth)
    pnginfo = PngImagePlugin.PngInfo()
    pnginfo.add_text('near', str(near))
    pnginfo.add_text('far', str(far))
    pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]:
    """
    Read a segmentation mask
    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to read from.
    ### Returns:
    - `Tuple[np.ndarray, Dict[str, int]]`
        A tuple containing:
        - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W).
        - `labels`: Dict[str, int] or None. The label mapping {label_name: label_id},
          read from the PNG text metadata when present.
    """
    raw = Path(path).read_bytes() if isinstance(path, (str, os.PathLike)) else path.read()
    pil_image = Image.open(io.BytesIO(raw))
    labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None
    return np.array(pil_image), labels
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7):
    """
    Write a segmentation mask and label mapping, as PNG format.
    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to write to.
    - `mask: np.ndarray`
        The segmentation mask, uint8 or uint16 array of shape (H, W).
    - `labels: Dict[str, int] = None`
        The label mapping, a dictionary of {label_name: label_id}.
    - `compression_level: int = 7`
        The compression level for PNG compression.
    """
    assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}"
    metadata = PngImagePlugin.PngInfo()
    if labels is not None:
        # Compact JSON keeps the PNG text chunk small.
        metadata.add_text('labels', json.dumps(labels, ensure_ascii=True, separators=(',', ':')))
    Image.fromarray(mask).save(path, pnginfo=metadata, compress_level=compression_level)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray:
    """
    Read a normal image, return float32 normal array of shape (H, W, 3).

    All-zero pixels are treated as invalid and become NaN. The Y and Z
    components are sign-flipped to undo the encoding in `write_normal`.
    """
    if isinstance(path, (str, os.PathLike)):
        data = Path(path).read_bytes()
    else:
        data = path.read()
    normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
    mask_nan = np.all(normal == 0, axis=-1)
    normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
    # Renormalize to unit length.
    # FIX: the per-pixel norm has shape (H, W); it must be expanded to (H, W, 1)
    # before dividing, otherwise broadcasting against the (H, W, 3) array raises.
    norm = np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2]))
    normal = normal / (norm[..., None] + 1e-12)
    normal[mask_nan] = np.nan
    return normal
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> None:
    """
    Write a normal image, input float32 normal array of shape (H, W, 3).

    Components are mapped to uint16 with Y and Z sign-flipped; pixels whose
    input contains NaN are stored as all-zero (the invalid marker read back
    by `read_normal`). Writes a 16-bit PNG; returns nothing.
    """
    # Pixels with any NaN component are marked invalid.
    mask_nan = np.isnan(normal).any(axis=-1)
    # Map [-1, 1] (with Y/Z flipped) to [0, 65535].
    normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
    normal[mask_nan] = 0
    data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_bytes(data)
    else:
        path.write(data)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def read_mask(path: Union[str, os.PathLike, IO[bytes]]) -> np.ndarray:
    """
    Read a binary mask, return bool array of shape (H, W).
    """
    raw = Path(path).read_bytes() if isinstance(path, (str, os.PathLike)) else path.read()
    decoded = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_UNCHANGED)
    # Multi-channel images: keep only the first channel.
    if len(decoded.shape) == 3:
        decoded = decoded[..., 0]
    return decoded > 0
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def write_mask(path: Union[str, os.PathLike, IO[bytes]], mask: np.ndarray, compression_level: int = 7):
    """
    Write a binary mask, input bool array of shape (H, W).

    True pixels are stored as 255, False as 0, in a PNG file.
    """
    assert mask.dtype == bool, f"Mask must be bool array, got {mask.dtype}"
    # FIX: dropped the redundant second .astype(np.uint8) — uint8 * 255 already
    # yields a uint8 array of 0/255 values.
    mask = mask.astype(np.uint8) * 255
    data = cv2.imencode('.png', mask, [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_bytes(data)
    else:
        path.write(data)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
JSON_TYPE = Union[str, int, float, bool, None, Dict[str, "JSON"], List["JSON"]]


def read_json(path: Union[str, os.PathLike, IO[str]]) -> JSON_TYPE:
    """Parse a JSON document from a file path or an already-open text file object."""
    if isinstance(path, (str, os.PathLike)):
        content = Path(path).read_text()
    else:
        content = path.read()
    return json.loads(content)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def write_json(path: Union[str, os.PathLike, IO[str]], content: JSON_TYPE):
    """Serialize *content* as JSON to a file path or an open text file object."""
    serialized = json.dumps(content)
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_text(serialized)
    else:
        path.write(serialized)
|
third_party/lingbot_depth/mdm/utils/tools.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from numbers import Number
|
| 5 |
+
from functools import wraps
|
| 6 |
+
import warnings
|
| 7 |
+
import math
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import importlib
|
| 11 |
+
import importlib.util
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def catch_exception(fn):
    """
    Decorator: run *fn* and return its result; on any exception, print a short
    notice plus the traceback and return None instead of propagating.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            import traceback
            # FIX: the original used end='r' (missing backslash), which appended
            # a literal 'r' to the message instead of a carriage return.
            print(f"Exception in {fn.__name__}", end='\r')
            # print({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})
            traceback.print_exc(chain=False)
            # Brief pause — presumably to let the traceback flush before other
            # (possibly concurrent) output interleaves; TODO confirm.
            time.sleep(0.1)
            return None
    return wrapper
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CallbackOnException:
    """
    Context manager that swallows exceptions of a given type, invoking a
    callback instead of letting them propagate. Other exceptions pass through.
    """

    def __init__(self, callback: Callable, exception: type):
        self.exception = exception
        self.callback = callback

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not isinstance(exc_val, self.exception):
            # Returning False re-raises (or does nothing if no exception).
            return False
        self.callback()
        return True
|
| 42 |
+
|
| 43 |
+
def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
    """Yield the key-path tuple of every leaf value in a nested dict, depth-first."""
    for key, value in d.items():
        if isinstance(value, dict):
            yield from ((key,) + tail for tail in traverse_nested_dict_keys(value))
        else:
            yield (key,)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
    """
    Walk *keys* down into nested dict *d*, returning the value found.
    Stops early and returns None if any lookup yields None; missing keys
    fall back to *default* at each level.
    """
    node = d
    for key in keys:
        node = node.get(key, default)
        if node is None:
            break
    return node
|
| 58 |
+
|
| 59 |
+
def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
    """Assign *value* at key-path *keys* in *d*, creating intermediate dicts as needed."""
    *parents, leaf = keys
    node = d
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def key_average(list_of_dicts: list) -> Dict[str, Any]:
    """
    Returns a dictionary with the average value of each key in the input list of dictionaries.
    Missing and NaN entries are skipped; a key with no usable values averages to NaN.
    """
    all_key_paths = set()
    for entry in list_of_dicts:
        all_key_paths.update(traverse_nested_dict_keys(entry))

    result = {}
    for key_path in sorted(all_key_paths):
        usable = [
            v for v in (get_nested_dict(entry, key_path) for entry in list_of_dicts)
            if v is not None and not math.isnan(v)
        ]
        average = sum(usable) / len(usable) if usable else float('nan')
        set_nested_dict(result, key_path, average)
    return result
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
    """
    Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
    """
    prefix = () if parent_key is None else parent_key
    flat: Dict[Tuple[str, ...], Any] = {}
    for key, value in d.items():
        path = prefix + (key,)
        if isinstance(value, MutableMapping):
            flat.update(flatten_nested_dict(value, path))
        else:
            flat[path] = value
    return flat
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """
    Unflattens a single-level dictionary into a nested dictionary, with keys as tuples.
    """
    result = {}
    for key_path, value in d.items():
        node = result
        for part in key_path[:-1]:
            node = node.setdefault(part, {})
        node[key_path[-1]] = value
    return result
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def read_jsonl(file):
    """Read a JSON-Lines file: one parsed JSON object per line, returned as a list."""
    import json
    with open(file, 'r') as f:
        return [json.loads(line) for line in f]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def write_jsonl(data: List[dict], file):
    """Write *data* as JSON-Lines: one serialized object per line."""
    import json
    lines = [json.dumps(item) + '\n' for item in data]
    with open(file, 'w') as f:
        f.writelines(lines)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
    """
    Build a pandas DataFrame with hierarchical (MultiIndex) columns from a
    list of nested dicts. Columns are sorted lexicographically.
    (Name keeps the original spelling since callers rely on it.)
    """
    import pandas as pd
    flattened = [flatten_nested_dict(d) for d in data]
    df = pd.DataFrame(flattened).sort_index(axis=1)
    df.columns = pd.MultiIndex.from_tuples(df.columns)
    return df
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
    """
    Recursively apply substring replacements from *mapping* to every string
    reachable inside *d*. Lists and dicts are modified in place and returned;
    non-container, non-string values pass through unchanged.
    """
    if isinstance(d, str):
        for old, new in mapping.items():
            d = d.replace(old, new)
        return d
    if isinstance(d, list):
        for index in range(len(d)):
            d[index] = recursive_replace(d[index], mapping)
        return d
    if isinstance(d, dict):
        for key in d:
            d[key] = recursive_replace(d[key], mapping)
        return d
    return d
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class timeit:
    """
    Wall-clock timer usable as a context manager or a function decorator.

    With `average=True`, completed timers are appended to the class-level
    `_history` under `name`, and the printed figure is the running mean.
    """
    # Shared across all instances: name -> list of finished timers.
    _history: Dict[str, List['timeit']] = {}

    def __init__(self, name: str = None, verbose: bool = True, average: bool = False):
        self.name = name
        self.verbose = verbose
        self.start = None
        self.end = None
        self.average = average
        if average and name not in timeit._history:
            timeit._history[name] = []

    def __call__(self, func: Callable):
        """Decorator form: time each call of *func* (sync or async)."""
        import inspect
        # NOTE(review): the inner timeit() is created without `average`/`verbose`
        # from self, so decorating with average=True does not accumulate history
        # — confirm whether that is intended.
        if inspect.iscoroutinefunction(func):
            async def wrapper(*args, **kwargs):
                with timeit(self.name or func.__qualname__):
                    ret = await func(*args, **kwargs)
                return ret
            return wrapper
        else:
            def wrapper(*args, **kwargs):
                with timeit(self.name or func.__qualname__):
                    ret = func(*args, **kwargs)
                return ret
            return wrapper

    def __enter__(self):
        self.start = time.time()
        return self

    @property
    def time(self) -> float:
        """Elapsed seconds of a finished timing run."""
        assert self.start is not None, "Time not yet started."
        assert self.end is not None, "Time not yet ended."
        return self.end - self.start

    @property
    def average_time(self) -> float:
        """Mean elapsed time over all recorded runs sharing this name."""
        assert self.average, "Average time not available."
        return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])

    @property
    def history(self) -> List['timeit']:
        """All recorded timers for this name (empty list if none)."""
        return timeit._history.get(self.name, [])

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.time()
        # Record before printing so average_time includes this run.
        if self.average:
            timeit._history[self.name].append(self)
        if self.verbose:
            if self.average:
                avg = self.average_time
                print(f"{self.name or 'It'} took {avg:.6f} seconds in average.")
            else:
                print(f"{self.name or 'It'} took {self.time:.6f} seconds.")
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
    """
    Strip the shared leading prefix and trailing suffix from every string.

    NOTE: degenerate inputs (identical strings, a single string, empty list)
    keep the original function's edge-case behavior.
    """
    reference = strings[0]

    # First index at which the strings disagree (loop variable intentionally leaks).
    for start in range(len(reference)):
        if not all(s[start] == reference[start] for s in strings):
            break

    # Smallest back-offset at which the strings disagree.
    for end in range(1, min(len(s) for s in strings)):
        if not all(s[-end] == reference[-end] for s in strings):
            break

    # end - 1 characters form the common suffix, hence the +1 compensation.
    return [s[start:len(s) - end + 1] for s in strings]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
    """
    Return a decorator that immediately runs the decorated function over
    *inputs* on a thread pool, updating a tqdm progress bar.

    NOTE(review): the returned decorator executes eagerly and returns None
    (results of executor.map are discarded); exceptions in workers are
    printed and swallowed via catch_exception. Name keeps the original
    spelling ("multithead") since callers rely on it.
    """
    from concurrent.futures import ThreadPoolExecutor
    from contextlib import nullcontext  # NOTE(review): imported but unused in this body
    from tqdm import tqdm

    # Reuse a caller-supplied progress bar (fixing its total) or create one.
    if pbar is not None:
        pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
    else:
        pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)

    def decorator(fn: Callable):
        with (
            ThreadPoolExecutor(max_workers=num_workers) as executor,
            pbar
        ):
            pbar.refresh()
            # Each work item: run fn, bump the bar; errors are printed, not raised.
            @catch_exception
            @suppress_traceback
            def _fn(input):
                ret = fn(input)
                pbar.update()
                return ret
            executor.map(_fn, inputs)
            # Redundant with the `with` block's implicit shutdown, but harmless.
            executor.shutdown(wait=True)

    return decorator
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def suppress_traceback(fn):
    """
    Decorator that re-raises any exception with the two wrapper-owned frames
    trimmed from the front of its traceback, so reports start at *fn* itself.
    """
    @wraps(fn)
    def trimmed_call(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as err:
            # Skip this wrapper's frames; fragile if the chain is shorter than two.
            err.__traceback__ = err.__traceback__.tb_next.tb_next
            raise
    return trimmed_call
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class no_warnings:
    """
    Apply a warnings filter (default: ignore everything) around a block of
    code. Usable both as a context manager and as a decorator factory.
    Extra keyword arguments are forwarded to warnings.simplefilter.
    """

    def __init__(self, action: str = 'ignore', **kwargs):
        self.action = action
        self.filter_kwargs = kwargs

    def __call__(self, fn):
        """Decorator form: run each call of *fn* under the filter."""
        @wraps(fn)
        def filtered(*args, **kwargs):
            with warnings.catch_warnings():
                warnings.simplefilter(self.action, **self.filter_kwargs)
                return fn(*args, **kwargs)
        return filtered

    def __enter__(self):
        # Save the current filter state, then install ours.
        self.warnings_manager = warnings.catch_warnings()
        self.warnings_manager.__enter__()
        warnings.simplefilter(self.action, **self.filter_kwargs)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the saved filter state.
        self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str):
    """
    Load a Python source file and return it as a module object registered
    under *module_name* (the module is NOT inserted into sys.modules).
    """
    module_spec = importlib.util.spec_from_file_location(module_name, file_path)
    loaded = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded
|
third_party/lingbot_depth/mdm/utils/vis.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib
|
| 5 |
+
import trimesh
|
| 6 |
+
import random
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
    """
    Colorize a depth map via its disparity (1/depth), returning uint8 RGB (H, W, 3).

    Non-positive depths (and pixels outside *mask*) are treated as invalid.
    """
    depth = depth.copy()
    if mask is None:
        depth = np.where(depth > 0, depth, np.nan)
    else:
        depth = np.where((depth > 0) & mask, depth, np.nan)
    disp = 1 / depth
    if normalize:
        # NOTE(review): upper quantile is 0.99 here while the sibling colorize
        # functions use 0.999 — confirm whether this asymmetry is intentional.
        min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99)
        disp = (disp - min_disp) / (max_disp - min_disp)

    # FIX: np.nan_to_num's second positional argument is `copy`, so the original
    # `np.nan_to_num(x, 0)` set copy=False rather than the intended nan=0.
    colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], nan=0)
    colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
    return colored
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray:
    """
    Colorize a depth map by affinely rescaling raw depth (not disparity)
    between its 0.1% and 99.9% quantiles; returns uint8 RGB (H, W, 3).
    """
    if mask is not None:
        depth = np.where(mask, depth, np.nan)

    min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999)
    depth = (depth - min_depth) / (max_depth - min_depth)
    # FIX: np.nan_to_num's second positional argument is `copy`, so the original
    # `np.nan_to_num(x, 0)` set copy=False rather than the intended nan=0.
    colored = np.nan_to_num(matplotlib.colormaps[cmap](depth)[..., :3], nan=0)
    colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
    return colored
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
    """
    Colorize a disparity map, optionally normalizing between its 0.1% and
    99.9% quantiles; returns uint8 RGB (H, W, 3).
    """
    if mask is not None:
        disparity = np.where(mask, disparity, np.nan)

    if normalize:
        min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999)
        disparity = (disparity - min_disp) / (max_disp - min_disp)
    # FIX: np.nan_to_num's second positional argument is `copy`, so the original
    # `np.nan_to_num(x, 0)` set copy=False rather than the intended nan=0.
    colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity)[..., :3], nan=0)
    colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
    return colored
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def colorize_normal(normal: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
    """
    Map a normal array (H, W, 3) to a uint8 RGB visualization: X is kept,
    Y and Z are sign-flipped, then [-1, 1] maps to [0, 255]. Pixels outside
    *mask* are zeroed (rendered mid-gray) first.
    """
    if mask is not None:
        normal = np.where(mask[..., None], normal, 0)
    shifted = normal * [0.5, -0.5, -0.5] + 0.5
    return (shifted.clip(0, 1) * 255).astype(np.uint8)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def colorize_error_map(error_map: np.ndarray, mask: np.ndarray = None, cmap: str = 'plasma', value_range: Tuple[float, float] = None):
    """
    Colorize a scalar error map with *cmap*; values are scaled into
    [vmin, vmax] (given or taken from the data's nan-range), pixels outside
    *mask* are blacked out. Returns uint8 RGB (H, W, 3).
    """
    if value_range is not None:
        vmin, vmax = value_range
    else:
        vmin, vmax = np.nanmin(error_map), np.nanmax(error_map)
    # Distinct local name avoids shadowing the string parameter `cmap`.
    colormap = matplotlib.colormaps[cmap]
    scaled = ((error_map - vmin) / (vmax - vmin)).clip(0, 1)
    colorized_error_map = colormap(scaled)[..., :3]
    if mask is not None:
        colorized_error_map = np.where(mask[..., None], colorized_error_map, 0)
    return np.ascontiguousarray((colorized_error_map.clip(0, 1) * 255).astype(np.uint8))
|
third_party/lingbot_depth/pyproject.toml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "mdm"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
readme = "README.md"
|
| 9 |
+
dependencies = [
|
| 10 |
+
"click",
|
| 11 |
+
"opencv-python",
|
| 12 |
+
"scipy",
|
| 13 |
+
"matplotlib",
|
| 14 |
+
"trimesh",
|
| 15 |
+
"pillow",
|
| 16 |
+
"huggingface_hub",
|
| 17 |
+
"numpy",
|
| 18 |
+
"torch==2.6.0",
|
| 19 |
+
"torchvision",
|
| 20 |
+
"xformers==v0.0.29.post2",
|
| 21 |
+
]
|
| 22 |
+
requires-python = ">=3.9"
|
| 23 |
+
|
| 24 |
+
[tool.setuptools.packages.find]
|
| 25 |
+
where = ["."]
|
| 26 |
+
include = ["mdm*"]
|
third_party/sam3/pyproject.toml
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "sam3"
|
| 7 |
+
dynamic = ["version"]
|
| 8 |
+
description = "SAM3 (Segment Anything Model 3) implementation"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.8"
|
| 11 |
+
license = {file = "LICENSE"}
|
| 12 |
+
authors = [
|
| 13 |
+
{name = "Meta AI Research"}
|
| 14 |
+
]
|
| 15 |
+
classifiers = [
|
| 16 |
+
"Development Status :: 4 - Beta",
|
| 17 |
+
"Intended Audience :: Science/Research",
|
| 18 |
+
"License :: OSI Approved :: MIT License",
|
| 19 |
+
"Programming Language :: Python :: 3",
|
| 20 |
+
"Programming Language :: Python :: 3.8",
|
| 21 |
+
"Programming Language :: Python :: 3.9",
|
| 22 |
+
"Programming Language :: Python :: 3.10",
|
| 23 |
+
"Programming Language :: Python :: 3.11",
|
| 24 |
+
"Programming Language :: Python :: 3.12",
|
| 25 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 26 |
+
]
|
| 27 |
+
dependencies = [
|
| 28 |
+
"timm>=1.0.17",
|
| 29 |
+
"numpy>=1.26,<2",
|
| 30 |
+
"tqdm",
|
| 31 |
+
"ftfy==6.1.1",
|
| 32 |
+
"regex",
|
| 33 |
+
"iopath>=0.1.10",
|
| 34 |
+
"typing_extensions",
|
| 35 |
+
"huggingface_hub",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
[project.optional-dependencies]
|
| 39 |
+
dev = [
|
| 40 |
+
"pytest",
|
| 41 |
+
"pytest-cov",
|
| 42 |
+
"black==24.2.0",
|
| 43 |
+
"ufmt==2.8.0",
|
| 44 |
+
"ruff-api==0.1.0",
|
| 45 |
+
"usort==1.0.2",
|
| 46 |
+
"gitpython==3.1.31",
|
| 47 |
+
"yt-dlp",
|
| 48 |
+
"pandas",
|
| 49 |
+
"opencv-python",
|
| 50 |
+
"pycocotools",
|
| 51 |
+
"numba",
|
| 52 |
+
"python-rapidjson",
|
| 53 |
+
]
|
| 54 |
+
notebooks = [
|
| 55 |
+
"matplotlib",
|
| 56 |
+
"jupyter",
|
| 57 |
+
"notebook",
|
| 58 |
+
"ipywidgets",
|
| 59 |
+
"ipycanvas",
|
| 60 |
+
"ipympl",
|
| 61 |
+
"pycocotools",
|
| 62 |
+
"decord",
|
| 63 |
+
"opencv-python",
|
| 64 |
+
"einops",
|
| 65 |
+
"scikit-image",
|
| 66 |
+
"scikit-learn",
|
| 67 |
+
]
|
| 68 |
+
train = [
|
| 69 |
+
"hydra-core",
|
| 70 |
+
"submitit",
|
| 71 |
+
"tensorboard",
|
| 72 |
+
"zstandard",
|
| 73 |
+
"scipy",
|
| 74 |
+
"torchmetrics",
|
| 75 |
+
"fvcore",
|
| 76 |
+
"fairscale",
|
| 77 |
+
"scikit-image",
|
| 78 |
+
"scikit-learn",
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
[project.urls]
|
| 82 |
+
"Homepage" = "https://github.com/facebookresearch/sam3"
|
| 83 |
+
"Bug Tracker" = "https://github.com/facebookresearch/sam3/issues"
|
| 84 |
+
|
| 85 |
+
[tool.setuptools.packages.find]
|
| 86 |
+
include = ["sam3*"]
|
| 87 |
+
exclude = ["build*", "scripts*", "examples*"]
|
| 88 |
+
|
| 89 |
+
[tool.setuptools.package-data]
|
| 90 |
+
sam3 = ["assets/*.txt.gz"]
|
| 91 |
+
|
| 92 |
+
[tool.setuptools.dynamic]
|
| 93 |
+
version = {attr = "sam3.__version__"}
|
| 94 |
+
|
| 95 |
+
[tool.black]
|
| 96 |
+
line-length = 88
|
| 97 |
+
target-version = ['py38', 'py39', 'py310', 'py311', 'py312']
|
| 98 |
+
include = '\.pyi?$'
|
| 99 |
+
|
| 100 |
+
[tool.isort]
|
| 101 |
+
profile = "black"
|
| 102 |
+
multi_line_output = 3
|
| 103 |
+
|
| 104 |
+
[tool.usort]
|
| 105 |
+
first_party_detection = false
|
| 106 |
+
|
| 107 |
+
[tool.ufmt]
|
| 108 |
+
formatter = "ruff-api"
|
| 109 |
+
|
| 110 |
+
[tool.mypy]
|
| 111 |
+
python_version = "3.12"
|
| 112 |
+
warn_return_any = true
|
| 113 |
+
warn_unused_configs = true
|
| 114 |
+
disallow_untyped_defs = true
|
| 115 |
+
disallow_incomplete_defs = true
|
| 116 |
+
|
| 117 |
+
[[tool.mypy.overrides]]
|
| 118 |
+
module = [
|
| 119 |
+
"torch.*",
|
| 120 |
+
"torchvision.*",
|
| 121 |
+
"timm.*",
|
| 122 |
+
"numpy.*",
|
| 123 |
+
"PIL.*",
|
| 124 |
+
"tqdm.*",
|
| 125 |
+
"ftfy.*",
|
| 126 |
+
"regex.*",
|
| 127 |
+
"iopath.*",
|
| 128 |
+
]
|
| 129 |
+
ignore_missing_imports = true
|
| 130 |
+
|
| 131 |
+
[tool.pytest.ini_options]
|
| 132 |
+
testpaths = ["tests"]
|
| 133 |
+
python_files = "test_*.py"
|
| 134 |
+
python_classes = "Test*"
|
| 135 |
+
python_functions = "test_*"
|
third_party/sam3/sam3/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .model_builder import build_sam3_image_model, build_sam3_predictor
|
| 6 |
+
|
| 7 |
+
__version__ = "0.1.0"
|
| 8 |
+
|
| 9 |
+
__all__ = ["build_sam3_image_model", "build_sam3_predictor"]
|
third_party/sam3/sam3/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (453 Bytes). View file
|
|
|
third_party/sam3/sam3/__pycache__/logger.cpython-311.pyc
ADDED
|
Binary file (3.67 kB). View file
|
|
|
third_party/sam3/sam3/__pycache__/model_builder.cpython-311.pyc
ADDED
|
Binary file (42.9 kB). View file
|
|
|
third_party/sam3/sam3/agent/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
third_party/sam3/sam3/agent/agent_core.py
ADDED
|
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from .client_llm import send_generate_request
|
| 13 |
+
from .client_sam3 import call_sam_service
|
| 14 |
+
from .viz import visualize
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def save_debug_messages(messages_list, debug, debug_folder_path, debug_jsonl_path):
    """Persist the conversation history for post-mortem debugging.

    Writes one JSON document per message, one record per line (JSON Lines),
    overwriting any previous snapshot. No-op unless ``debug`` is enabled and
    a target path is provided.

    Args:
        messages_list: List of message dicts (role/content) to serialize.
        debug: Master switch; nothing is written when False.
        debug_folder_path: Directory holding the debug file; created if missing.
        debug_jsonl_path: Full path of the JSONL file to (over)write.
    """
    if debug and debug_jsonl_path:
        # Ensure the debug directory exists before writing
        os.makedirs(debug_folder_path, exist_ok=True)
        with open(debug_jsonl_path, "w") as f:
            for msg in messages_list:
                # Compact single-line records: the previous indent=4 produced
                # multi-line documents, which is not valid JSON Lines and
                # cannot be parsed back with one json.loads per line.
                f.write(json.dumps(msg) + "\n")
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path):
    """Best-effort removal of the debug history file and its folder.

    Called once a run completes successfully so debug artifacts from a
    clean run do not accumulate. Does nothing when debugging is disabled
    or no folder path was set; never raises on a failed delete.
    """
    if not (debug and debug_folder_path):
        return
    try:
        # Remove the history file first so the directory is empty for rmdir.
        if os.path.exists(debug_jsonl_path):
            os.remove(debug_jsonl_path)
        if os.path.exists(debug_folder_path):
            os.rmdir(debug_folder_path)
    except Exception as e:
        # Cleanup is advisory only — warn, but do not abort the caller.
        print(f"Warning: Could not clean up debug files: {e}")
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def count_images(messages):
    """Return the total number of image content items across all messages.

    Only messages whose "content" is a list are inspected; within them,
    every dict item with type == "image" counts as one image.
    """
    return sum(
        1
        for message in messages
        if isinstance(message.get("content"), list)
        for item in message["content"]
        if isinstance(item, dict) and item.get("type") == "image"
    )
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _prune_messages_for_next_round(
    messages_list,
    used_text_prompts,
    latest_sam3_text_prompt,
    img_path,
    initial_text_prompt,
):
    """Return a new messages list that contains only:
    1) messages[:2] (with optional warning text added to the second message's content)
    2) the latest assistant message (and everything after it) that contains a segment_phrase tool call

    This keeps the conversation history compact between generation rounds:
    the system prompt + initial user turn are always retained, everything
    between them and the most recent segment_phrase tool call is dropped,
    and a warning about already-tried prompts may be injected.

    Args:
        messages_list: Full conversation history (list of role/content dicts).
        used_text_prompts: All text prompts already sent to segment_phrase.
        latest_sam3_text_prompt: The prompt used in the most recent call,
            excluded from the "do not reuse" warning list.
        img_path: Path to the raw input image (re-embedded in the warning turn).
        initial_text_prompt: The original user query (re-embedded likewise).

    Returns:
        A new list (the input is not mutated; the kept head is deep-copied).
    """
    # There should not be more than 10 messages in the conversation history
    assert len(messages_list) < 10

    # Part 1: always keep the first two message JSONs
    # (deep copy so the warning injection below cannot mutate the caller's list)
    part1 = copy.deepcopy(messages_list[:2])

    # Part 2: search backwards for the latest assistant message containing a segment_phrase tool call
    part2_start_idx = None
    for idx in range(len(messages_list) - 1, 1, -1):
        msg = messages_list[idx]
        # We only consider assistant messages with a "content" list
        if msg.get("role") != "assistant" or "content" not in msg:
            continue
        # Look for any content element that is a text containing the segment_phrase tool call
        for content in msg["content"]:
            if (
                isinstance(content, dict)
                and content.get("type") == "text"
                and "<tool>" in content.get("text", "")
                and "segment_phrase" in content.get("text", "")
            ):
                part2_start_idx = idx
                break
        if part2_start_idx is not None:
            break

    # Everything from that assistant message onward is kept verbatim.
    part2 = messages_list[part2_start_idx:] if part2_start_idx is not None else []

    # Part 3: decide whether to add warning text to the second message in part1.
    # The latest prompt is excluded: only *previously* failed prompts are warned about.
    previously_used = (
        [p for p in used_text_prompts if p != latest_sam3_text_prompt]
        if latest_sam3_text_prompt
        else list(used_text_prompts)
    )
    if part2 and len(previously_used) > 0:
        warning_text = f'Note that we have previously called the segment_phrase tool with each "text_prompt" in this list: {list(previously_used)}, but none of the generated results were satisfactory. So make sure that you do not use any of these phrases as the "text_prompt" to call the segment_phrase tool again.'
        # Replace the second message entirely to keep exactly 2 content items
        part1[1] = {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {
                    "type": "text",
                    "text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'."
                    + " "
                    + warning_text,
                },
            ],
        }
        assert len(part1[1]["content"]) == 2

    # Build the new messages list: part1 (with optional warning), then part2
    new_messages = list(part1)
    new_messages.extend(part2)
    return new_messages
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def agent_inference(
    img_path: str,
    initial_text_prompt: str,
    debug: bool = False,
    send_generate_request=send_generate_request,
    call_sam_service=call_sam_service,
    max_generations: int = 100,
    output_dir="../../sam3_agent_out",
):
    """
    Given a text prompt and an image, this tool will perform all aspects of agentic problem solving,
    while saving sam3 and MLLM outputs to their respective directories.

    The agent loops: the MLLM proposes a tool call (segment_phrase,
    examine_each_mask, select_masks_and_return, or report_no_mask), the tool
    is executed, and its result is fed back until a final answer is returned
    or the generation budget is exhausted.

    Args:
        img_path: Path to the input image
        initial_text_prompt: Initial text prompt from the user
        debug: Whether to enable debug mode
        send_generate_request: Callable that queries the MLLM with a messages list.
        call_sam_service: Callable that runs SAM3 and returns a result-JSON path.
        max_generations: Maximum number of send_generate_request calls allowed (default: 100)
        output_dir: Root directory for sam outputs, error dumps, and debug files.

    Returns:
        (messages, final_outputs, rendered_final_output) on success.

    Raises:
        ValueError: On malformed tool calls, unknown tools, exhausted
            generation budget, or a None MLLM response.
    """
    # setup dir
    sam_output_dir = os.path.join(output_dir, "sam_out")
    error_save_dir = os.path.join(output_dir, "none_out")
    debug_save_dir = os.path.join(output_dir, "agent_debug_out")
    os.makedirs(sam_output_dir, exist_ok=True)
    os.makedirs(error_save_dir, exist_ok=True)
    os.makedirs(debug_save_dir, exist_ok=True)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    MLLM_SYSTEM_PROMPT_PATH = os.path.join(
        current_dir, "system_prompts/system_prompt.txt"
    )
    ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH = os.path.join(
        current_dir, "system_prompts/system_prompt_iterative_checking.txt"
    )
    # init variables
    PATH_TO_LATEST_OUTPUT_JSON = ""
    LATEST_SAM3_TEXT_PROMPT = ""
    USED_TEXT_PROMPTS = (
        set()
    )  # Track all previously used text prompts for segment_phrase
    generation_count = 0  # Counter for number of send_generate_request calls

    # debug setup
    debug_folder_path = None
    debug_jsonl_path = None
    if debug:
        # One debug folder per input image, named after the image file stem.
        debug_folder_path = os.path.join(
            debug_save_dir, f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}"
        )
        debug_jsonl_path = os.path.join(debug_folder_path, "debug_history.json")
        os.makedirs(debug_folder_path, exist_ok=True)

    # The helper functions are now defined outside the agent_inference function
    with open(MLLM_SYSTEM_PROMPT_PATH, "r") as f:
        system_prompt = f.read().strip()
    with open(ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH, "r") as f:
        iterative_checking_system_prompt = f.read().strip()

    # Construct the initial message list
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {
                    "type": "text",
                    "text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'.",
                },
            ],
        },
    ]
    print(f"> Text prompt: {initial_text_prompt}")
    print(f"> Image path: {img_path}")

    print("\n\n")
    print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30)
    print("\n\n")
    generated_text = send_generate_request(messages)
    print(f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n")
    while generated_text is not None:
        save_debug_messages(messages, debug, debug_folder_path, debug_jsonl_path)
        # BUGFIX: the original asserted a non-empty tuple
        # `assert ("<tool>" in generated_text, "msg")`, which is always truthy,
        # so the check could never fire. Assert the condition itself instead.
        assert (
            "<tool>" in generated_text
        ), f"Generated text does not contain <tool> tag: {generated_text}"
        # Sometimes the MLLM emits several tool calls in one round; keep only
        # the first one by truncating at the first closing tag.
        generated_text = generated_text.split("</tool>", 1)[0] + "</tool>"
        tool_call_json_str = (
            generated_text.split("<tool>")[-1]
            .split("</tool>")[0]
            .strip()
            .replace(r"}}}", r"}}")  # remove extra } if any
        )
        try:
            tool_call = json.loads(tool_call_json_str)
        except json.JSONDecodeError:
            raise ValueError(f"Invalid JSON in tool call: {tool_call_json_str}")

        if PATH_TO_LATEST_OUTPUT_JSON == "":
            # The first tool call must be segment_phrase or report_no_mask
            assert (
                tool_call["name"] == "segment_phrase"
                or tool_call["name"] == "report_no_mask"
            )

        if tool_call["name"] == "segment_phrase":
            print("🔍 Calling segment_phrase tool...")
            assert list(tool_call["parameters"].keys()) == ["text_prompt"]

            # Check if this text_prompt has been used before
            current_text_prompt = tool_call["parameters"]["text_prompt"]
            if current_text_prompt in USED_TEXT_PROMPTS:
                print(
                    f"❌ Text prompt '{current_text_prompt}' has been used before. Requesting a different prompt."
                )
                duplicate_prompt_message = f"You have previously used '{current_text_prompt}' as your text_prompt to call the segment_phrase tool. You may not use it again. Please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase prompt, while adhering to all the rules stated in the system prompt. You must also never use any of the following text_prompt(s): {str(list(USED_TEXT_PROMPTS))}."
                messages.append(
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                )
                messages.append(
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": duplicate_prompt_message}],
                    }
                )
            else:
                # Add the text_prompt to the set of used prompts
                USED_TEXT_PROMPTS.add(current_text_prompt)
                LATEST_SAM3_TEXT_PROMPT = current_text_prompt
                PATH_TO_LATEST_OUTPUT_JSON = call_sam_service(
                    image_path=img_path,
                    text_prompt=current_text_prompt,
                    output_folder_path=sam_output_dir,
                )
                # Use a context manager so the result file handle is closed
                # (the original json.load(open(...)) leaked the handle).
                with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as f:
                    sam3_outputs = json.load(f)
                sam3_output_image_path = sam3_outputs["output_image_path"]
                num_masks = len(sam3_outputs["pred_boxes"])

                messages.append(
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                )
                if num_masks == 0:
                    print("❌ No masks generated by SAM3, reporting no mask to Qwen.")
                    sam3_output_text_message = f"The segment_phrase tool did not generate any masks for the text_prompt '{current_text_prompt}'. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt. Please be reminded that the original user query was '{initial_text_prompt}'."
                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": sam3_output_text_message}
                            ],
                        }
                    )
                else:
                    sam3_output_text_message = rf"The segment_phrase tool generated {num_masks} available masks. All {num_masks} available masks are rendered in this image below, now you must analyze the {num_masks} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action. Please be reminded that the original user query was '{initial_text_prompt}'."
                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": sam3_output_text_message},
                                {"type": "image", "image": sam3_output_image_path},
                            ],
                        }
                    )
                print("\n\n>>> sam3_output_text_message:\n", sam3_output_text_message)

        elif tool_call["name"] == "examine_each_mask":
            print("🔍 Calling examine_each_mask tool...")
            assert LATEST_SAM3_TEXT_PROMPT != ""

            # Make sure that the last message is a image
            assert (
                messages[-1]["content"][1]["type"] == "image"
            ), "Second content element should be an image"
            messages.pop()  # Remove the last user message
            # Add simplified replacement message (drops the rendered image so
            # the per-mask checks below do not blow up the image budget)
            simplified_message = {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "The segment_phrase tool generated several masks. Now you must analyze the mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
                    }
                ],
            }
            messages.append(simplified_message)

            with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as f:
                current_outputs = json.load(f)
            num_masks = len(current_outputs["pred_masks"])
            masks_to_keep = []

            # MLLM check the mask one by one
            for i in range(num_masks):
                print(f"🔍 Checking mask {i + 1}/{num_masks}...")
                image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i)

                image_w_zoomed_in_mask_i_path = os.path.join(
                    sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_")
                ).replace(".png", f"_zoom_in_mask_{i + 1}.png")
                image_w_mask_i_path = os.path.join(
                    sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_")
                ).replace(".png", f"_selected_mask_{i + 1}.png")
                image_w_zoomed_in_mask_i.save(image_w_zoomed_in_mask_i_path)
                image_w_mask_i.save(image_w_mask_i_path)

                # Each per-mask verdict uses a fresh, isolated conversation.
                iterative_checking_messages = [
                    {"role": "system", "content": iterative_checking_system_prompt},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "The raw input image: "},
                            {"type": "image", "image": img_path},
                            {
                                "type": "text",
                                "text": f"The initial user input query is: '{initial_text_prompt}'",
                            },
                            {
                                "type": "text",
                                "text": "Image with the predicted segmentation mask rendered on it: ",
                            },
                            {"type": "image", "image": image_w_mask_i_path},
                            {
                                "type": "text",
                                "text": "Image with the zoomed-in mask: ",
                            },
                            {"type": "image", "image": image_w_zoomed_in_mask_i_path},
                        ],
                    },
                ]
                checking_generated_text = send_generate_request(
                    iterative_checking_messages
                )

                # Process the generated text to determine if the mask should be kept or rejected
                if checking_generated_text is None:
                    raise ValueError(
                        "Generated text is None, which is unexpected. Please check the Qwen server and the input parameters."
                    )
                print(f"Generated text for mask {i + 1}: {checking_generated_text}")
                verdict = (
                    checking_generated_text.split("<verdict>")[-1]
                    .split("</verdict>")[0]
                    .strip()
                )
                if "Accept" in verdict:
                    assert "Reject" not in verdict
                    print(f"Mask {i + 1} accepted, keeping it in the outputs.")
                    masks_to_keep.append(i)
                elif "Reject" in verdict:
                    assert "Accept" not in verdict
                    print(f"Mask {i + 1} rejected, removing it from the outputs.")
                else:
                    raise ValueError(
                        f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'."
                    )

            updated_outputs = {
                "original_image_path": current_outputs["original_image_path"],
                "orig_img_h": current_outputs["orig_img_h"],
                "orig_img_w": current_outputs["orig_img_w"],
                "pred_boxes": [current_outputs["pred_boxes"][i] for i in masks_to_keep],
                "pred_scores": [
                    current_outputs["pred_scores"][i] for i in masks_to_keep
                ],
                "pred_masks": [current_outputs["pred_masks"][i] for i in masks_to_keep],
            }

            image_w_check_masks = visualize(updated_outputs)
            image_w_check_masks_path = os.path.join(
                sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png"
            ).replace(
                ".png",
                f"_selected_masks_{'-'.join(map(str, [i + 1 for i in masks_to_keep]))}.png".replace(
                    "/", "_"
                ),
            )
            image_w_check_masks.save(image_w_check_masks_path)
            # save the updated json outputs and append to message history
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )
            if len(masks_to_keep) == 0:
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"The original user query was: '{initial_text_prompt}'. The examine_each_mask tool examined and rejected all of the masks generated by the segment_phrase tool. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt.",
                            }
                        ],
                    }
                )
            else:
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"The original user query was: '{initial_text_prompt}'. After calling the examine_each_mask tool on the available masks, the number of available masks is now {len(masks_to_keep)}. All {len(masks_to_keep)} available masks are rendered in this image below, now you must analyze the {len(masks_to_keep)} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
                            },
                            {"type": "image", "image": image_w_check_masks_path},
                        ],
                    }
                )

            # Create a new filename based on the original path to avoid filename length issues
            base_path = PATH_TO_LATEST_OUTPUT_JSON
            # Remove any existing "masks_" suffix to avoid duplication
            if "masks_" in base_path:
                base_path = base_path.split("masks_")[0] + ".json"
            # Create new filename with current masks; use a clearer suffix when empty
            if len(masks_to_keep) == 0:
                PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(
                    ".json", "masks_none.json"
                )
            else:
                PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(
                    ".json", f"masks_{'_'.join(map(str, masks_to_keep))}.json"
                )
            with open(PATH_TO_LATEST_OUTPUT_JSON, "w") as f:
                json.dump(updated_outputs, f, indent=4)

        elif tool_call["name"] == "select_masks_and_return":
            print("🔍 Calling select_masks_and_return tool...")
            with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as f:
                current_outputs = json.load(f)

            assert list(tool_call["parameters"].keys()) == ["final_answer_masks"]
            masks_to_keep = tool_call["parameters"]["final_answer_masks"]

            # Keep only valid mask indices, remove duplicates, and preserve deterministic ascending order
            available_masks = set(range(1, len(current_outputs["pred_masks"]) + 1))
            masks_to_keep = sorted({i for i in masks_to_keep if i in available_masks})
            # Change this to a update message telling the model to try again along with information about errors made.

            # Mask indices from the model are 1-based; shift to 0-based here.
            final_outputs = {
                "original_image_path": current_outputs["original_image_path"],
                "orig_img_h": current_outputs["orig_img_h"],
                "orig_img_w": current_outputs["orig_img_w"],
                "pred_boxes": [
                    current_outputs["pred_boxes"][i - 1] for i in masks_to_keep
                ],
                "pred_scores": [
                    current_outputs["pred_scores"][i - 1] for i in masks_to_keep
                ],
                "pred_masks": [
                    current_outputs["pred_masks"][i - 1] for i in masks_to_keep
                ],
            }

            rendered_final_output = visualize(final_outputs)
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )

            # Clean up debug files before successful return
            cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path)
            return messages, final_outputs, rendered_final_output

        elif tool_call["name"] == "report_no_mask":
            print("🔍 Calling report_no_mask tool...")
            height, width = cv2.imread(img_path).shape[:2]
            final_outputs = {
                "original_image_path": img_path,
                "orig_img_h": height,
                "orig_img_w": width,
                "pred_boxes": [],
                "pred_scores": [],
                "pred_masks": [],
            }
            rendered_final_output = Image.open(img_path)
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )
            # CONSISTENCY FIX: this is also a successful return, so clean up
            # debug artifacts just like the select_masks_and_return path does
            # (previously only that path removed them).
            cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path)
            return messages, final_outputs, rendered_final_output

        else:
            raise ValueError(f"Unknown tool call: {tool_call['name']}")

        # sometimes the MLLM don't know when to stop, and generates multiple tool calls in one round, so we need to split the generated text by </tool> and only keep the first one

        for message in messages:
            if message["role"] == "assistant" and "content" in message:
                for content in message["content"]:
                    if (
                        isinstance(content, dict)
                        and content.get("type") == "text"
                        and "text" in content
                    ):
                        content["text"] = (
                            content["text"].split("</tool>", 1)[0] + "</tool>\n\n"
                        )
        # Prune the messages history before the next MLLM generation round according to the 3-part rules.
        # This keeps history compact and ensures the model sees only the allowed parts.
        messages = _prune_messages_for_next_round(
            messages,
            USED_TEXT_PROMPTS,
            LATEST_SAM3_TEXT_PROMPT,
            img_path,
            initial_text_prompt,
        )
        # make sure there can never be more than 2 images in the context
        assert count_images(messages) <= 2
        generation_count += 1
        if generation_count > max_generations:
            raise ValueError(
                f"Exceeded maximum number of allowed generation requests ({max_generations})"
            )

        print("\n\n")
        print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30)
        print("\n\n")
        generated_text = send_generate_request(messages)
        print(
            f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n"
        )

    # Reaching here means the MLLM returned None: dump the history for triage.
    print("\n\n>>> SAM 3 Agent execution ended.\n\n")

    error_save_path = os.path.join(
        error_save_dir,
        f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}_error_history.json",
    )
    with open(error_save_path, "w") as f:
        json.dump(messages, f, indent=4)
    print("Saved messages history that caused error to:", error_save_path)
    raise ValueError(
        rf"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters for image path: {img_path} and initial text prompt: {initial_text_prompt}."
    )
third_party/sam3/sam3/agent/client_llm.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import base64
|
| 6 |
+
import os
|
| 7 |
+
from typing import Any, Optional
|
| 8 |
+
|
| 9 |
+
from openai import OpenAI
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_image_base64_and_mime(image_path):
    """Convert image file to base64 string and get MIME type"""
    # Extension -> MIME lookup; anything unrecognized falls back to JPEG.
    extension_to_mime = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
    }
    try:
        suffix = os.path.splitext(image_path)[1].lower()
        mime = extension_to_mime.get(suffix, "image/jpeg")  # Default to JPEG

        # Read the raw bytes and base64-encode them into ASCII text.
        with open(image_path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("utf-8")
        return encoded, mime
    except Exception as e:
        # Best-effort helper: report the problem and signal failure via Nones.
        print(f"Error converting image to base64: {e}")
        return None, None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def send_generate_request(
    messages,
    server_url=None,
    model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
    api_key=None,
    max_tokens=4096,
):
    """
    Sends a request to the OpenAI-compatible API endpoint using the OpenAI client library.

    Image entries in user messages (dicts with ``type == "image"`` holding a
    local file path) are converted to base64 ``image_url`` parts before the
    request is sent; images that cannot be read are dropped with a warning.

    Args:
        server_url (str): The base URL of the server, e.g. "http://127.0.0.1:8000"
        messages (list): A list of message dicts, each containing role and content.
        model (str): The model to use for generation (default: "llama-4")
        api_key (str): API key forwarded to the OpenAI client.
        max_tokens (int): Maximum number of tokens to generate (default: 4096)

    Returns:
        str: The generated response text from the server, or None on failure.
    """
    # Rewrite user-message image parts into base64 data-URL payloads.
    prepared = []
    for message in messages:
        outgoing = message.copy()
        if message["role"] == "user" and "content" in message:
            parts = []
            for part in message["content"]:
                # Non-image parts pass through untouched.
                if not (isinstance(part, dict) and part.get("type") == "image"):
                    parts.append(part)
                    continue

                image_path = part["image"]
                print("image_path", image_path)
                new_image_path = image_path.replace("?", "%3F")  # Escape ? in the path

                try:
                    base64_image, mime_type = get_image_base64_and_mime(
                        new_image_path
                    )
                    if base64_image is None:
                        print(
                            f"Warning: Could not convert image to base64: {new_image_path}"
                        )
                        continue

                    # Proper image_url structure carrying the base64 data.
                    parts.append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{base64_image}",
                                "detail": "high",
                            },
                        }
                    )
                except FileNotFoundError:
                    print(f"Warning: Image file not found: {new_image_path}")
                    continue
                except Exception as e:
                    print(f"Warning: Error processing image {new_image_path}: {e}")
                    continue

            outgoing["content"] = parts
        prepared.append(outgoing)

    # OpenAI client pointed at the custom base URL.
    client = OpenAI(api_key=api_key, base_url=server_url)

    try:
        print(f"🔍 Calling model {model}...")
        response = client.chat.completions.create(
            model=model,
            messages=prepared,
            max_completion_tokens=max_tokens,
            n=1,
        )

        # Pull the text out of the first choice, if any.
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content
        print(f"Unexpected response format: {response}")
        return None
    except Exception as e:
        print(f"Request failed: {e}")
        return None
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def send_direct_request(
    llm: Any,
    messages: list[dict[str, Any]],
    sampling_params: Any,
) -> Optional[str]:
    """
    Run inference on a vLLM model instance directly without using a server.

    Args:
        llm: Initialized vLLM LLM instance (passed from external initialization)
        messages: List of message dicts with role and content (OpenAI format)
        sampling_params: vLLM SamplingParams instance (initialized externally)

    Returns:
        str: Generated response text, or None if inference fails
    """
    try:
        # Rewrite image parts in user messages into base64 image_url payloads.
        prepared: list[dict[str, Any]] = []
        for message in messages:
            outgoing = message.copy()
            if message["role"] == "user" and "content" in message:
                parts = []
                for part in message["content"]:
                    # Non-image parts pass through untouched.
                    if not (isinstance(part, dict) and part.get("type") == "image"):
                        parts.append(part)
                        continue

                    image_path = part["image"]
                    new_image_path = image_path.replace("?", "%3F")

                    try:
                        base64_image, mime_type = get_image_base64_and_mime(
                            new_image_path
                        )
                        if base64_image is None:
                            print(
                                f"Warning: Could not convert image: {new_image_path}"
                            )
                            continue

                        # vLLM expects the OpenAI-style image_url format.
                        parts.append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{mime_type};base64,{base64_image}"
                                },
                            }
                        )
                    except Exception as e:
                        print(
                            f"Warning: Error processing image {new_image_path}: {e}"
                        )
                        continue

                outgoing["content"] = parts
            prepared.append(outgoing)

        print("🔍 Running direct inference with vLLM...")

        # In-process generation through vLLM's chat interface.
        outputs = llm.chat(
            messages=prepared,
            sampling_params=sampling_params,
        )

        # First completion of the first output carries the generated text.
        if outputs and len(outputs) > 0:
            return outputs[0].outputs[0].text
        print(f"Unexpected output format: {outputs}")
        return None
    except Exception as e:
        print(f"Direct inference failed: {e}")
        return None
|
third_party/sam3/sam3/agent/client_sam3.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from sam3.model.box_ops import box_xyxy_to_xywh
|
| 11 |
+
from sam3.train.masks_ops import rle_encode
|
| 12 |
+
|
| 13 |
+
from .helpers.mask_overlap_removal import remove_overlapping_masks
|
| 14 |
+
from .viz import visualize
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def sam3_inference(processor, image_path, text_prompt):
    """Run SAM 3 image inference with text prompts and format the outputs"""
    image = Image.open(image_path)
    width, height = image.size

    # Two-stage prompting: bind the image, then apply the text prompt.
    state = processor.set_image(image)
    state = processor.set_text_prompt(state=state, prompt=text_prompt)

    # Normalize each xyxy column by the matching image extent so boxes
    # land in the [0, 1] range.
    boxes = state["boxes"]
    norm_columns = [
        boxes[:, col] / extent
        for col, extent in enumerate((width, height, width, height))
    ]
    boxes_xyxy = torch.stack(norm_columns, dim=-1)
    boxes_xywh = box_xyxy_to_xywh(boxes_xyxy).tolist()

    # RLE-encode per-instance masks and keep only the "counts" payload.
    encoded_masks = rle_encode(state["masks"].squeeze(1))

    return {
        "orig_img_h": height,
        "orig_img_w": width,
        "pred_boxes": boxes_xywh,
        "pred_masks": [m["counts"] for m in encoded_masks],
        "pred_scores": state["scores"].tolist(),
    }
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def call_sam_service(
    sam3_processor,
    image_path: str,
    text_prompt: str,
    output_folder_path: str = "sam3_output",
):
    """
    Loads an image, sends it with a text prompt to the service,
    saves the results, and renders the visualization.
    """
    print(f"📞 Loading image '{image_path}' and sending with prompt '{text_prompt}'...")

    # Slashes in the prompt or image path would create nested directories;
    # flatten them before building output file names.
    safe_prompt = text_prompt.replace("/", "_")
    result_dir = os.path.join(output_folder_path, image_path.replace("/", "-"))
    os.makedirs(result_dir, exist_ok=True)
    output_json_path = os.path.join(result_dir, f"{safe_prompt}.json")
    output_image_path = os.path.join(result_dir, f"{safe_prompt}.png")

    try:
        # Run SAM3 on the image with the given concept prompt.
        response = sam3_inference(sam3_processor, image_path, text_prompt)

        # 1. Drop overlapping masks and attach the path metadata.
        response = remove_overlapping_masks(response)
        response = {
            "original_image_path": image_path,
            "output_image_path": output_image_path,
            **response,
        }

        # 2. Reorder predictions by score, highest to lowest, if present.
        scores = response.get("pred_scores")
        if scores:
            order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
            for key in ("pred_scores", "pred_boxes", "pred_masks"):
                response[key] = [response[key][i] for i in order]

        # 3. Drop degenerate RLE masks (shorter than 5 characters) together
        #    with their boxes and scores.
        keep = [i for i, rle in enumerate(response["pred_masks"]) if len(rle) > 4]
        response["pred_masks"] = [response["pred_masks"][i] for i in keep]
        response["pred_boxes"] = [response["pred_boxes"][i] for i in keep]
        response["pred_scores"] = [response["pred_scores"][i] for i in keep]

        with open(output_json_path, "w") as f:
            json.dump(response, f, indent=4)
        print(f"✅ Raw JSON response saved to '{output_json_path}'")

        # 4. Render the predictions onto the image and save alongside the JSON.
        print("🔍 Rendering visualizations on the image ...")
        viz_image = visualize(response)
        os.makedirs(os.path.dirname(output_image_path), exist_ok=True)
        viz_image.save(output_image_path)
        print("✅ Saved visualization at:", output_image_path)
    except Exception as e:
        print(f"❌ Error calling service: {e}")

    # NOTE: returned even when inference failed, so the file may not exist;
    # callers are expected to handle a missing JSON file.
    return output_json_path
|
third_party/sam3/sam3/agent/helpers/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|