# Hugging Face Space — running on ZeroGPU hardware.
| """ | |
| ConceptPose Demo β Hugging Face Space (Gradio + ZeroGPU) | |
| Estimate relative 6DoF pose between two images of the same object | |
| using semantic concept-based 3D registration. | |
| Pipeline: SAM3 (segmentation) β DepthAnything3 (depth) β ConceptPose (pose) | |
| """ | |
import os
import shutil
import subprocess
import sys

# DepthAnything3 → clone and patch to avoid heavy optional deps
# (pycolmap, gsplat, open3d). We only need the inference API, not the
# export utilities.
_da3_dir = "/tmp/_da3_src"
if not os.path.exists(os.path.join(_da3_dir, "src", "depth_anything_3")):
    # Remove any partial previous clone before re-cloning.
    shutil.rmtree(_da3_dir, ignore_errors=True)
    subprocess.check_call([
        "git", "clone", "--depth", "1",
        "https://github.com/ByteDance-Seed/Depth-Anything-3.git", _da3_dir,
    ])
    # Patch export/__init__.py so its imports become optional (we never export).
    _export_init = os.path.join(
        _da3_dir, "src", "depth_anything_3", "utils", "export", "__init__.py"
    )
    with open(_export_init, "w") as f:
        f.write(
            "# Patched: lazy imports to avoid pycolmap/gsplat/open3d deps\n"
            "def export(*args, **kwargs):\n"
            "    raise NotImplementedError('Export not available in this environment')\n"
            # BUG FIX: __all__ entries must be strings; the bare name `export`
            # would make `from ... import *` raise TypeError.
            "__all__ = ['export']\n"
        )
sys.path.insert(0, os.path.join(_da3_dir, "src"))

import gc
import tempfile
from pathlib import Path

import numpy as np
import torch

import gradio as gr
import spaces
from PIL import Image
| # --------------------------------------------------------------------------- | |
| # Pre-load the list of cached categories from parts.json (no GPU needed) | |
| # --------------------------------------------------------------------------- | |
| def _get_cached_categories(): | |
| """Return list of categories available in the shipped parts.json cache.""" | |
| import json | |
| parts_json = Path(__file__).parent / "concept_pose" / "partonomy" / "parts.json" | |
| if not parts_json.exists(): | |
| # Installed as package β find via importlib | |
| import importlib.resources | |
| try: | |
| ref = importlib.resources.files("concept_pose") / "partonomy" / "parts.json" | |
| parts_json = ref | |
| except Exception: | |
| return [] | |
| try: | |
| data = json.loads(parts_json.read_text() if hasattr(parts_json, "read_text") else open(parts_json).read()) | |
| return sorted(e.get("category_label", "") for e in data if e.get("category_label")) | |
| except Exception: | |
| return [] | |
| CACHED_CATEGORIES = _get_cached_categories() | |
| # --------------------------------------------------------------------------- | |
| # SAM3 in-process helper (replaces subprocess version for ZeroGPU) | |
| # --------------------------------------------------------------------------- | |
| def _sam3_segment_inprocess(estimator, image_paths, prompt): | |
| """ | |
| Run SAM3 segmentation in-process instead of subprocess. | |
| ZeroGPU does not support spawning CUDA subprocesses, so we load SAM3 | |
| in the main process, segment, then unload to free VRAM. | |
| """ | |
| estimator._load_sam_model() | |
| masks = [] | |
| for img_path in image_paths: | |
| pil_image = Image.open(img_path).convert("RGB") | |
| mask = estimator.get_object_mask(pil_image, prompt) | |
| masks.append(mask) | |
| estimator._unload_sam_model() | |
| return masks | |
| # --------------------------------------------------------------------------- | |
| # GPU-accelerated pipeline | |
| # --------------------------------------------------------------------------- | |
def _first_visualization(directory, filenames):
    """Return the first existing image among *filenames* in *directory*, or None."""
    for name in filenames:
        path = os.path.join(directory, name)
        if os.path.exists(path):
            return Image.open(path)
    return None


def _format_pose_result(result):
    """Render the estimator's result dict as a human-readable summary string."""
    if not result["success"]:
        return "Pose estimation failed. Try different images or a different category."
    R = result["R"]
    t = result["t"]
    labels = result.get("semantic_labels", [])
    return (
        "Pose estimation successful!\n\n"
        f"Correspondences: {result['num_correspondences']}\n"
        f"Inliers: {result['num_inliers']}\n\n"
        f"Rotation matrix:\n{np.array2string(R, precision=4, suppress_small=True)}\n\n"
        f"Translation vector:\n{np.array2string(t, precision=4, suppress_small=True)}\n\n"
        f"Semantic labels used ({len(labels)}): {', '.join(labels[:10])}"
        + ("..." if len(labels) > 10 else "")
    )


@spaces.GPU(duration=120)
def run_pipeline(
    anchor_image: Image.Image,
    query_image: Image.Image,
    category: str,
    custom_concepts: str,
    gemini_api_key: str,
):
    """Full pose estimation pipeline — runs on ZeroGPU.

    Args:
        anchor_image: reference-view image of the object.
        query_image: target-view image of the same object.
        category: object category name (e.g. "car").
        custom_concepts: optional comma-separated semantic parts overriding
            the auto-generated ones.
        gemini_api_key: optional key, needed only for uncached categories.

    Returns:
        (result_text, build_img, pose_img) — summary string and two optional
        visualization images (PIL or None).

    Raises:
        gr.Error: on missing/invalid inputs or any pipeline failure.
    """
    # Ensure DA3 is importable in the GPU worker process too — ZeroGPU
    # workers do not inherit the main process's sys.path mutation.
    _da3_src = "/tmp/_da3_src/src"
    if _da3_src not in sys.path:
        sys.path.insert(0, _da3_src)

    if anchor_image is None or query_image is None:
        raise gr.Error("Please upload both an anchor and a query image.")
    if not category or not category.strip():
        raise gr.Error("Please enter an object category name.")
    category = category.strip().lower()

    # Parse custom concepts; None means "use the defaults".
    concepts = None
    if custom_concepts and custom_concepts.strip():
        concepts = [c.strip() for c in custom_concepts.split(",") if c.strip()] or None

    # Set Gemini key if provided (needed only for categories not in the cache).
    if gemini_api_key and gemini_api_key.strip():
        os.environ["GEMINI_API_KEY"] = gemini_api_key.strip()

    # Save PIL images to temp files (DA3 needs file paths).
    tmp_dir = tempfile.mkdtemp()
    anchor_path = os.path.join(tmp_dir, "anchor.jpg")
    query_path = os.path.join(tmp_dir, "query.jpg")
    anchor_image.save(anchor_path)
    query_image.save(query_path)

    estimator = None
    try:
        from concept_pose.demo.wild_pose_estimator import WildPoseEstimator

        estimator = WildPoseEstimator(device="cuda")
        # Monkey-patch: replace subprocess SAM3 with the in-process version
        # (ZeroGPU doesn't support spawning CUDA subprocesses).
        estimator.get_object_masks_subprocess = (
            lambda image_paths, prompt: _sam3_segment_inprocess(estimator, image_paths, prompt)
        )
        result = estimator.estimate(
            anchor_image=anchor_path,
            query_image=query_path,
            category=category,
            concepts=concepts,
            visualize=True,
            output_dir=tmp_dir,
        )
        result_text = _format_pose_result(result)

        # visualize=True produces: anchor_building.png, query_estimation.png,
        # pose_projection.png, pose_overlay.png
        # (correspondences.png requires return_debug_info=True — too expensive)
        build_img = _first_visualization(tmp_dir, ["query_estimation.png", "anchor_building.png"])
        pose_img = _first_visualization(tmp_dir, ["pose_overlay.png", "pose_projection.png"])

        estimator.cleanup()
        return result_text, build_img, pose_img
    except Exception as e:
        # Chain the original traceback so the Space logs stay debuggable.
        raise gr.Error(f"Pipeline error: {e}") from e
    finally:
        # Free model memory on both success and failure paths.
        del estimator
        gc.collect()
        torch.cuda.empty_cache()
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
def build_demo():
    """Construct and return the Gradio Blocks UI for the ConceptPose demo."""
    with gr.Blocks(
        title="ConceptPose Demo",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            "# ConceptPose: In-the-Wild 6D Pose Estimation\n"
            "Upload two images of the **same object** from different viewpoints "
            "and get the estimated relative 6DoF pose.\n\n"
            # NOTE: arrows were previously mojibake ("β"); restored to "→".
            "**Pipeline:** SAM3 (segmentation) → DepthAnything3 (depth) → ConceptPose (semantic 3D registration)\n\n"
            "[Paper](https://arxiv.org/abs/2506.10806) | "
            "[Code](https://github.com/StevenKuang/concept-pose)"
        )
        with gr.Row():
            anchor_input = gr.Image(
                label="Anchor Image (reference view)",
                type="pil",
                height=350,
            )
            query_input = gr.Image(
                label="Query Image (target view)",
                type="pil",
                height=350,
            )
        category_input = gr.Textbox(
            label="Object Category",
            placeholder="e.g., car, bottle, mug, shoe, laptop ...",
            info=f"Pre-cached categories: {', '.join(CACHED_CATEGORIES[:20])}{'...' if len(CACHED_CATEGORIES) > 20 else ''}",
        )
        with gr.Accordion("Advanced Options", open=False):
            custom_concepts_input = gr.Textbox(
                label="Custom Concepts (comma-separated)",
                placeholder="e.g., wheel, door, windshield, roof, bumper",
                info="Override auto-generated semantic parts. Leave empty to use defaults.",
            )
            gemini_key_input = gr.Textbox(
                label="Gemini API Key",
                type="password",
                placeholder="Optional — only needed for categories not in the cache",
                info="Required only for new categories not in the pre-cached list.",
            )
        run_btn = gr.Button("Estimate Pose", variant="primary", size="lg")
        result_text = gr.Textbox(label="Result", lines=12, interactive=False)
        with gr.Row():
            build_output = gr.Image(label="Concept Saliency Visualization", type="pil")
            pose_output = gr.Image(label="Pose Projection Visualization", type="pil")
        # Examples
        gr.Examples(
            examples=[
                ["examples/car.jpg", "examples/car-2.jpg", "car"],
            ],
            inputs=[anchor_input, query_input, category_input],
            label="Example Pairs",
        )
        run_btn.click(
            fn=run_pipeline,
            inputs=[
                anchor_input,
                query_input,
                category_input,
                custom_concepts_input,
                gemini_key_input,
            ],
            outputs=[result_text, build_output, pose_output],
        )
    return demo
| if __name__ == "__main__": | |
| demo = build_demo() | |
| demo.launch() | |