"""
ConceptPose Demo — Hugging Face Space (Gradio + ZeroGPU)

Estimate relative 6DoF pose between two images of the same object using
semantic concept-based 3D registration.

Pipeline: SAM3 (segmentation) → DepthAnything3 (depth) → ConceptPose (pose)
"""

import os
import shutil
import subprocess
import sys

# DepthAnything3 — clone and patch to avoid heavy optional deps (pycolmap, gsplat, open3d).
# We only need the inference API, not export utilities.
_da3_dir = "/tmp/_da3_src"
_da3_src_dir = os.path.join(_da3_dir, "src")

if not os.path.exists(os.path.join(_da3_src_dir, "depth_anything_3")):
    # Remove any partial checkout left behind by an interrupted clone.
    shutil.rmtree(_da3_dir, ignore_errors=True)
    subprocess.check_call([
        "git", "clone", "--depth", "1",
        "https://github.com/ByteDance-Seed/Depth-Anything-3.git",
        _da3_dir,
    ])

# Patch export/__init__.py to make all imports optional (we don't use export).
# The patch is idempotent, so it is (re)applied on every start; this also covers
# a checkout whose patch step was interrupted on a previous run.
_export_init = os.path.join(_da3_src_dir, "depth_anything_3", "utils", "export", "__init__.py")
with open(_export_init, "w", encoding="utf-8") as _f:
    _f.write(
        "# Patched: lazy imports to avoid pycolmap/gsplat/open3d deps\n"
        "def export(*args, **kwargs):\n"
        "    raise NotImplementedError('Export not available in this environment')\n"
        # __all__ must contain names (strings), not objects, or `import *` fails.
        "__all__ = ['export']\n"
    )

sys.path.insert(0, _da3_src_dir)

import gc
import json
import tempfile
import numpy as np
import torch
import gradio as gr
import spaces
from PIL import Image
from pathlib import Path


# ---------------------------------------------------------------------------
# Pre-load the list of cached categories from parts.json (no GPU needed)
# ---------------------------------------------------------------------------

def _get_cached_categories():
    """Return the sorted category labels found in the shipped ``parts.json`` cache.

    Looks for ``concept_pose/partonomy/parts.json`` next to this file first and,
    failing that, inside the installed ``concept_pose`` package via
    ``importlib.resources``.

    Returns:
        Sorted list of non-empty ``category_label`` values, or an empty list if
        the cache cannot be located or parsed (the UI still works without it).
    """
    parts_json = Path(__file__).parent / "concept_pose" / "partonomy" / "parts.json"
    if not parts_json.exists():
        # Installed as package — find via importlib
        import importlib.resources
        try:
            parts_json = importlib.resources.files("concept_pose") / "partonomy" / "parts.json"
        except Exception:
            return []

    try:
        # Both pathlib.Path and importlib Traversable provide read_text().
        data = json.loads(parts_json.read_text(encoding="utf-8"))
        return sorted(e.get("category_label", "") for e in data if e.get("category_label"))
    except Exception:
        return []


CACHED_CATEGORIES = _get_cached_categories()


# ---------------------------------------------------------------------------
# SAM3 in-process helper (replaces subprocess version for ZeroGPU)
# ---------------------------------------------------------------------------

def _sam3_segment_inprocess(estimator, image_paths, prompt):
    """Run SAM3 segmentation in-process instead of in a subprocess.

    ZeroGPU does not support spawning CUDA subprocesses, so we load SAM3 in the
    main process, segment, then unload it to free VRAM.

    Args:
        estimator: Estimator exposing ``_load_sam_model``, ``get_object_mask``
            and ``_unload_sam_model``.
        image_paths: Paths of the images to segment.
        prompt: Text prompt passed to ``get_object_mask``.

    Returns:
        One mask per input path, in the same order.
    """
    estimator._load_sam_model()
    try:
        masks = []
        for img_path in image_paths:
            # Close the file handle once the RGB copy has been made.
            with Image.open(img_path) as img:
                pil_image = img.convert("RGB")
            masks.append(estimator.get_object_mask(pil_image, prompt))
    finally:
        # Always unload SAM3, even if segmentation fails, so its VRAM is freed.
        estimator._unload_sam_model()
    return masks


# ---------------------------------------------------------------------------
# Pipeline helpers
# ---------------------------------------------------------------------------

def _save_rgb(image, path):
    """Save *image* to *path*, converting to RGB first (JPEG cannot store alpha/palette modes)."""
    if image.mode != "RGB":
        image = image.convert("RGB")
    image.save(path)


def _load_first_image(directory, names):
    """Return the first of *names* that exists in *directory*, else ``None``.

    The pixels are copied into memory so the returned image stays valid after
    *directory* is deleted.
    """
    for name in names:
        path = os.path.join(directory, name)
        if os.path.exists(path):
            with Image.open(path) as img:
                return img.copy()
    return None


def _format_result_text(result):
    """Build the human-readable summary shown in the "Result" textbox."""
    if not result["success"]:
        return "Pose estimation failed. Try different images or a different category."

    R = result["R"]
    t = result["t"]
    labels = result.get("semantic_labels")
    if labels is None:
        labels = []
    return (
        "Pose estimation successful!\n\n"
        f"Correspondences: {result['num_correspondences']}\n"
        f"Inliers: {result['num_inliers']}\n\n"
        f"Rotation matrix:\n{np.array2string(R, precision=4, suppress_small=True)}\n\n"
        f"Translation vector:\n{np.array2string(t, precision=4, suppress_small=True)}\n\n"
        f"Semantic labels used ({len(labels)}): {', '.join(labels[:10])}"
        + ("..." if len(labels) > 10 else "")
    )


# ---------------------------------------------------------------------------
# GPU-accelerated pipeline
# ---------------------------------------------------------------------------

@spaces.GPU(duration=120)
def run_pipeline(
    anchor_image: Image.Image,
    query_image: Image.Image,
    category: str,
    custom_concepts: str,
    gemini_api_key: str,
):
    """Full pose estimation pipeline — runs on ZeroGPU.

    Args:
        anchor_image: Reference view of the object.
        query_image: Target view of the same object.
        category: Object category name (stripped and lower-cased).
        custom_concepts: Optional comma-separated list of semantic parts that
            overrides the auto-generated ones.
        gemini_api_key: Optional Gemini key, exposed as ``GEMINI_API_KEY`` only
            for the duration of this call.

    Returns:
        ``(result_text, saliency_image_or_None, pose_image_or_None)``.

    Raises:
        gr.Error: On missing inputs or any failure inside the pipeline.
    """
    # Ensure DA3 is on sys.path in the GPU worker process too
    if _da3_src_dir not in sys.path:
        sys.path.insert(0, _da3_src_dir)

    if anchor_image is None or query_image is None:
        raise gr.Error("Please upload both an anchor and a query image.")
    if not category or not category.strip():
        raise gr.Error("Please enter an object category name.")
    category = category.strip().lower()

    # Parse custom concepts; a list with only blank entries falls back to defaults.
    concepts = None
    if custom_concepts and custom_concepts.strip():
        concepts = [c.strip() for c in custom_concepts.split(",") if c.strip()] or None

    key_name = "GEMINI_API_KEY"
    user_key = (gemini_api_key or "").strip()
    previous_key = os.environ.get(key_name)

    # Temp dir holds the input JPEGs (DA3 needs file paths) and the visualizations.
    tmp_dir = tempfile.mkdtemp()
    estimator = None
    try:
        # Set Gemini key if provided; restored in `finally` so one user's key is
        # not left in the process environment for later requests.
        if user_key:
            os.environ[key_name] = user_key

        anchor_path = os.path.join(tmp_dir, "anchor.jpg")
        query_path = os.path.join(tmp_dir, "query.jpg")
        _save_rgb(anchor_image, anchor_path)
        _save_rgb(query_image, query_path)

        from concept_pose.demo.wild_pose_estimator import WildPoseEstimator

        estimator = WildPoseEstimator(device="cuda")
        # Monkey-patch: replace subprocess SAM3 with in-process version
        # (ZeroGPU doesn't support spawning CUDA subprocesses)
        estimator.get_object_masks_subprocess = (
            lambda image_paths, prompt: _sam3_segment_inprocess(estimator, image_paths, prompt)
        )

        result = estimator.estimate(
            anchor_image=anchor_path,
            query_image=query_path,
            category=category,
            concepts=concepts,
            visualize=True,
            output_dir=tmp_dir,
        )

        result_text = _format_result_text(result)

        # visualize=True produces: anchor_building.png, query_estimation.png,
        # pose_projection.png, pose_overlay.png
        # (correspondences.png requires return_debug_info=True which is too expensive)
        # Concept saliency maps: prefer the query view.
        build_img = _load_first_image(tmp_dir, ["query_estimation.png", "anchor_building.png"])
        # Pose overlay (projected anchor point cloud onto query).
        pose_img = _load_first_image(tmp_dir, ["pose_overlay.png", "pose_projection.png"])

        estimator.cleanup()
        return result_text, build_img, pose_img

    except Exception as e:
        # Best-effort release of estimator resources before surfacing the error;
        # a failure here must not mask the original pipeline error.
        if estimator is not None:
            try:
                estimator.cleanup()
            except Exception:
                pass
        raise gr.Error(f"Pipeline error: {e}") from e

    finally:
        # Drop the estimator reference so gc can reclaim its GPU tensors.
        estimator = None
        gc.collect()
        torch.cuda.empty_cache()
        shutil.rmtree(tmp_dir, ignore_errors=True)
        if user_key:
            if previous_key is None:
                os.environ.pop(key_name, None)
            else:
                os.environ[key_name] = previous_key


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

def build_demo():
    """Build and return the Gradio Blocks UI wired to ``run_pipeline``."""
    categories_preview = ", ".join(CACHED_CATEGORIES[:20])
    if len(CACHED_CATEGORIES) > 20:
        categories_preview += "..."

    with gr.Blocks(
        title="ConceptPose Demo",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            "# ConceptPose: In-the-Wild 6D Pose Estimation\n"
            "Upload two images of the **same object** from different viewpoints "
            "and get the estimated relative 6DoF pose.\n\n"
            "**Pipeline:** SAM3 (segmentation) → DepthAnything3 (depth) → ConceptPose (semantic 3D registration)\n\n"
            "[Paper](https://arxiv.org/abs/2506.10806) | "
            "[Code](https://github.com/StevenKuang/concept-pose)"
        )

        with gr.Row():
            anchor_input = gr.Image(
                label="Anchor Image (reference view)",
                type="pil",
                height=350,
            )
            query_input = gr.Image(
                label="Query Image (target view)",
                type="pil",
                height=350,
            )

        category_input = gr.Textbox(
            label="Object Category",
            placeholder="e.g., car, bottle, mug, shoe, laptop ...",
            info=f"Pre-cached categories: {categories_preview}",
        )

        with gr.Accordion("Advanced Options", open=False):
            custom_concepts_input = gr.Textbox(
                label="Custom Concepts (comma-separated)",
                placeholder="e.g., wheel, door, windshield, roof, bumper",
                info="Override auto-generated semantic parts. Leave empty to use defaults.",
            )
            gemini_key_input = gr.Textbox(
                label="Gemini API Key",
                type="password",
                placeholder="Optional — only needed for categories not in the cache",
                info="Required only for new categories not in the pre-cached list.",
            )

        run_btn = gr.Button("Estimate Pose", variant="primary", size="lg")

        result_text = gr.Textbox(label="Result", lines=12, interactive=False)

        with gr.Row():
            build_output = gr.Image(label="Concept Saliency Visualization", type="pil")
            pose_output = gr.Image(label="Pose Projection Visualization", type="pil")

        # Examples
        gr.Examples(
            examples=[
                ["examples/car.jpg", "examples/car-2.jpg", "car"],
            ],
            inputs=[anchor_input, query_input, category_input],
            label="Example Pairs",
        )

        run_btn.click(
            fn=run_pipeline,
            inputs=[
                anchor_input,
                query_input,
                category_input,
                custom_concepts_input,
                gemini_key_input,
            ],
            outputs=[result_text, build_output, pose_output],
        )

    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()