# NOTE(review): the three lines below were stray commit-message text
# (author / subject / short hash) accidentally pasted above the module
# docstring — commented out so the file parses and the docstring stays
# the first statement.
# steven
# FIX: patch DA3 export/__init__.py to skip pycolmap/gsplat deps
# 435e110
"""
ConceptPose Demo β€” Hugging Face Space (Gradio + ZeroGPU)
Estimate relative 6DoF pose between two images of the same object
using semantic concept-based 3D registration.
Pipeline: SAM3 (segmentation) β†’ DepthAnything3 (depth) β†’ ConceptPose (pose)
"""
import os
import subprocess
import sys
# DepthAnything3 — clone and patch to avoid heavy optional deps (pycolmap, gsplat, open3d).
# We only need the inference API, not export utilities.
_da3_dir = "/tmp/_da3_src"
if not os.path.exists(os.path.join(_da3_dir, "src", "depth_anything_3")):
    # Start from a clean directory in case a previous partial clone exists.
    subprocess.check_call(["rm", "-rf", _da3_dir])
    subprocess.check_call([
        "git", "clone", "--depth", "1",
        "https://github.com/ByteDance-Seed/Depth-Anything-3.git", _da3_dir,
    ])
    # Patch export/__init__.py to make all imports optional (we don't use export)
    _export_init = os.path.join(_da3_dir, "src", "depth_anything_3", "utils", "export", "__init__.py")
    with open(_export_init, "w") as f:
        f.write(
            "# Patched: lazy imports to avoid pycolmap/gsplat/open3d deps\n"
            "def export(*args, **kwargs):\n"
            "    raise NotImplementedError('Export not available in this environment')\n"
            # BUG FIX: __all__ entries must be strings. The previous patch
            # wrote "__all__ = [export]", putting the function *object* in
            # __all__, which makes `from ... import *` raise TypeError.
            "__all__ = ['export']\n"
        )
sys.path.insert(0, os.path.join(_da3_dir, "src"))
import gc
import tempfile
import numpy as np
import torch
import gradio as gr
import spaces
from PIL import Image
from pathlib import Path
# ---------------------------------------------------------------------------
# Pre-load the list of cached categories from parts.json (no GPU needed)
# ---------------------------------------------------------------------------
def _get_cached_categories():
"""Return list of categories available in the shipped parts.json cache."""
import json
parts_json = Path(__file__).parent / "concept_pose" / "partonomy" / "parts.json"
if not parts_json.exists():
# Installed as package β€” find via importlib
import importlib.resources
try:
ref = importlib.resources.files("concept_pose") / "partonomy" / "parts.json"
parts_json = ref
except Exception:
return []
try:
data = json.loads(parts_json.read_text() if hasattr(parts_json, "read_text") else open(parts_json).read())
return sorted(e.get("category_label", "") for e in data if e.get("category_label"))
except Exception:
return []
CACHED_CATEGORIES = _get_cached_categories()
# ---------------------------------------------------------------------------
# SAM3 in-process helper (replaces subprocess version for ZeroGPU)
# ---------------------------------------------------------------------------
def _sam3_segment_inprocess(estimator, image_paths, prompt):
    """
    Segment every image with SAM3 inside the current process.

    ZeroGPU cannot spawn CUDA subprocesses, so instead of shelling out we
    load SAM3 here, run it over each image, and unload it again so the
    VRAM is freed for the rest of the pipeline.
    """
    estimator._load_sam_model()
    object_masks = [
        estimator.get_object_mask(Image.open(path).convert("RGB"), prompt)
        for path in image_paths
    ]
    estimator._unload_sam_model()
    return object_masks
# ---------------------------------------------------------------------------
# GPU-accelerated pipeline
# ---------------------------------------------------------------------------
def _format_result(result):
    """Render the estimator's result dict as human-readable display text."""
    if not result["success"]:
        return "Pose estimation failed. Try different images or a different category."
    R = result["R"]
    t = result["t"]
    n_corr = result["num_correspondences"]
    n_inliers = result["num_inliers"]
    labels = result.get("semantic_labels", [])
    return (
        "Pose estimation successful!\n\n"
        f"Correspondences: {n_corr}\n"
        f"Inliers: {n_inliers}\n\n"
        f"Rotation matrix:\n{np.array2string(R, precision=4, suppress_small=True)}\n\n"
        f"Translation vector:\n{np.array2string(t, precision=4, suppress_small=True)}\n\n"
        f"Semantic labels used ({len(labels)}): {', '.join(labels[:10])}"
        + ("..." if len(labels) > 10 else "")
    )


def _load_visualizations(tmp_dir):
    """Load the (concept-saliency, pose-overlay) visualization images, if present."""
    # visualize=True produces: anchor_building.png, query_estimation.png,
    # pose_projection.png, pose_overlay.png
    # (correspondences.png requires return_debug_info=True which is too expensive)
    def _first_existing(names):
        for name in names:
            p = os.path.join(tmp_dir, name)
            if os.path.exists(p):
                return Image.open(p)
        return None

    # Query estimation view preferred; fall back to the anchor-building view.
    build_img = _first_existing(["query_estimation.png", "anchor_building.png"])
    # Pose overlay preferred (projected anchor point cloud onto query).
    pose_img = _first_existing(["pose_overlay.png", "pose_projection.png"])
    return build_img, pose_img


@spaces.GPU(duration=120)
def run_pipeline(
    anchor_image: Image.Image,
    query_image: Image.Image,
    category: str,
    custom_concepts: str,
    gemini_api_key: str,
):
    """Full pose estimation pipeline — runs on ZeroGPU.

    Args:
        anchor_image: Reference-view image of the object.
        query_image: Target-view image of the same object.
        category: Object category name (e.g. "car"); normalized to lowercase.
        custom_concepts: Optional comma-separated part names overriding defaults.
        gemini_api_key: Optional key, needed only for categories not in the cache.

    Returns:
        Tuple ``(result_text, build_img, pose_img)`` for the Gradio outputs;
        either image may be None if the estimator wrote no visualization.

    Raises:
        gr.Error: On missing inputs or any failure inside the pipeline.
    """
    # Ensure DA3 is on sys.path in the GPU worker process too — ZeroGPU
    # workers do not inherit the parent process's sys.path mutations.
    _da3_src = "/tmp/_da3_src/src"
    if _da3_src not in sys.path:
        sys.path.insert(0, _da3_src)

    # --- Input validation (raise before any GPU work) ---------------------
    if anchor_image is None or query_image is None:
        raise gr.Error("Please upload both an anchor and a query image.")
    if not category or not category.strip():
        raise gr.Error("Please enter an object category name.")
    category = category.strip().lower()

    # Parse custom concepts; blank or all-whitespace input falls back to None
    # so the estimator uses its default concept set.
    concepts = None
    if custom_concepts and custom_concepts.strip():
        concepts = [c.strip() for c in custom_concepts.split(",") if c.strip()] or None

    # Set Gemini key if provided (used only for uncached categories).
    if gemini_api_key and gemini_api_key.strip():
        os.environ["GEMINI_API_KEY"] = gemini_api_key.strip()

    # Save PIL images to temp files (DA3 needs file paths).
    tmp_dir = tempfile.mkdtemp()
    anchor_path = os.path.join(tmp_dir, "anchor.jpg")
    query_path = os.path.join(tmp_dir, "query.jpg")
    anchor_image.save(anchor_path)
    query_image.save(query_path)

    estimator = None
    try:
        from concept_pose.demo.wild_pose_estimator import WildPoseEstimator

        estimator = WildPoseEstimator(device="cuda")
        # Monkey-patch: replace subprocess SAM3 with in-process version
        # (ZeroGPU doesn't support spawning CUDA subprocesses)
        estimator.get_object_masks_subprocess = (
            lambda image_paths, prompt: _sam3_segment_inprocess(estimator, image_paths, prompt)
        )
        result = estimator.estimate(
            anchor_image=anchor_path,
            query_image=query_path,
            category=category,
            concepts=concepts,
            visualize=True,
            output_dir=tmp_dir,
        )
        result_text = _format_result(result)
        build_img, pose_img = _load_visualizations(tmp_dir)
        return result_text, build_img, pose_img
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise gr.Error(f"Pipeline error: {e}") from e
    finally:
        # BUG FIX: the original only ran estimator.cleanup() on the success
        # path, leaking model/GPU memory whenever the pipeline raised.
        # Release resources unconditionally.
        if estimator is not None:
            try:
                estimator.cleanup()
            except Exception:
                pass  # best-effort; never mask the primary error
            del estimator
        gc.collect()
        torch.cuda.empty_cache()
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def build_demo():
    """Construct and return the Gradio Blocks UI.

    Layout order matters in gradio — widgets appear in creation order:
    header markdown, the two image inputs, category textbox, an advanced
    accordion, the run button, the result textbox, and two output images.
    All heavy work is delegated to `run_pipeline` on button click.
    """
    with gr.Blocks(
        title="ConceptPose Demo",
        theme=gr.themes.Soft(),
    ) as demo:
        # Header with the pipeline summary and paper/code links.
        gr.Markdown(
            "# ConceptPose: In-the-Wild 6D Pose Estimation\n"
            "Upload two images of the **same object** from different viewpoints "
            "and get the estimated relative 6DoF pose.\n\n"
            "**Pipeline:** SAM3 (segmentation) → DepthAnything3 (depth) → ConceptPose (semantic 3D registration)\n\n"
            "[Paper](https://arxiv.org/abs/2506.10806) | "
            "[Code](https://github.com/StevenKuang/concept-pose)"
        )
        # Side-by-side upload widgets for the two viewpoints.
        with gr.Row():
            anchor_input = gr.Image(
                label="Anchor Image (reference view)",
                type="pil",
                height=350,
            )
            query_input = gr.Image(
                label="Query Image (target view)",
                type="pil",
                height=350,
            )
        # Free-text category; the info line previews the pre-cached list.
        category_input = gr.Textbox(
            label="Object Category",
            placeholder="e.g., car, bottle, mug, shoe, laptop ...",
            info=f"Pre-cached categories: {', '.join(CACHED_CATEGORIES[:20])}{'...' if len(CACHED_CATEGORIES) > 20 else ''}",
        )
        # Optional overrides, collapsed by default.
        with gr.Accordion("Advanced Options", open=False):
            custom_concepts_input = gr.Textbox(
                label="Custom Concepts (comma-separated)",
                placeholder="e.g., wheel, door, windshield, roof, bumper",
                info="Override auto-generated semantic parts. Leave empty to use defaults.",
            )
            gemini_key_input = gr.Textbox(
                label="Gemini API Key",
                type="password",
                placeholder="Optional — only needed for categories not in the cache",
                info="Required only for new categories not in the pre-cached list.",
            )
        run_btn = gr.Button("Estimate Pose", variant="primary", size="lg")
        result_text = gr.Textbox(label="Result", lines=12, interactive=False)
        # Output images: concept saliency and pose projection.
        with gr.Row():
            build_output = gr.Image(label="Concept Saliency Visualization", type="pil")
            pose_output = gr.Image(label="Pose Projection Visualization", type="pil")
        # Examples
        gr.Examples(
            examples=[
                ["examples/car.jpg", "examples/car-2.jpg", "car"],
            ],
            inputs=[anchor_input, query_input, category_input],
            label="Example Pairs",
        )
        # Wire the button to the GPU pipeline.
        run_btn.click(
            fn=run_pipeline,
            inputs=[
                anchor_input,
                query_input,
                category_input,
                custom_concepts_input,
                gemini_key_input,
            ],
            outputs=[result_text, build_output, pose_output],
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start the Gradio server.
    build_demo().launch()