# Hugging Face Space commit 69016ef (verified), author: hedonismbot24
# Commit message: "Add FoundationPose 6D pose estimation demo app"
"""
FoundationPose 6D Object Pose Estimation Demo
A polished Gradio interface for NVIDIA FoundationPose — the #1 method on the
BOP Challenge 2024 benchmark for model-based 6D object localization of unseen objects.
This app connects to a FoundationPose inference backend and provides:
- CAD-based (model-based) initialization with a 3D mesh
- Automatic object masking via SlimSAM
- 6D pose estimation (position + orientation)
- 3D pose visualization overlaid on the image
"""
import io
import logging
import math
import tempfile
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import cv2
import gradio as gr
import numpy as np
from PIL import Image
# Configure logging at import time; every record is prefixed with a timestamp and level.
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")
# Module-level logger used by all functions in this file.
logger = logging.getLogger(__name__)
# ── Backend connection ──────────────────────────────────────────────────────
# URL handed to gradio_client.Client by _get_client().
BACKEND_URL = "https://gpue-foundationpose.hf.space"
# Cached client instance; remains None until _get_client() creates it on first use.
_gradio_client = None
def _get_client():
    """Return the shared Gradio client for the FoundationPose backend.

    The client is built on the first call and stored in the module-level
    ``_gradio_client``; every later call returns that same instance.
    """
    global _gradio_client
    if _gradio_client is not None:
        return _gradio_client

    # Deferred import: gradio_client is only imported when a client is built.
    from gradio_client import Client

    logger.info(f"Connecting to FoundationPose backend at {BACKEND_URL}...")
    _gradio_client = Client(BACKEND_URL)
    logger.info("Connected.")
    return _gradio_client
# ── Pose visualization ──────────────────────────────────────────────────────
def draw_pose_axes(
    image: np.ndarray,
    pose_matrix: np.ndarray,
    K: np.ndarray,
    axis_length: float = 0.05,
    thickness: int = 3,
) -> np.ndarray:
    """Draw 3D coordinate axes on the image from a 4x4 pose matrix.

    Red = X, Green = Y, Blue = Z. Colors are given as RGB triples, so
    ``image`` is expected in RGB channel order (not OpenCV's BGR default).

    Args:
        image: Image to annotate; it is copied, never modified in place.
        pose_matrix: 4x4 transform whose upper-left 3x3 is the rotation and
            whose last column holds the translation in the camera frame.
        K: 3x3 camera intrinsics matrix used for the pinhole projection.
        axis_length: Length of each drawn axis, in the units of the
            translation.
        thickness: Line thickness in pixels.

    Returns:
        An annotated copy of ``image``. The copy is returned unannotated when
        the pose origin is at or behind the camera plane.
    """
    vis = image.copy()
    R = pose_matrix[:3, :3]
    t = pose_matrix[:3, 3]

    # Origin and axis endpoints in camera coordinates; one endpoint per column.
    origin = t.reshape(3, 1)
    axes_3d = origin + R @ (np.eye(3) * axis_length)  # (3, 3)

    def project(pt3d):
        """Pinhole-project a (3, 1) point; None when it is not in front of the camera."""
        p = (K @ pt3d).flatten()
        # Non-positive depth has no valid projection: dividing by a negative
        # depth would mirror the point through the principal point.
        if p[2] < 1e-6:
            return None
        return int(p[0] / p[2]), int(p[1] / p[2])

    o2d = project(origin)
    if o2d is None:
        return vis

    # RGB triples: X = red, Y = green, Z = blue.
    axis_styles = (
        ("X", (255, 0, 0)),
        ("Y", (0, 255, 0)),
        ("Z", (0, 0, 255)),
    )
    for i, (label, color) in enumerate(axis_styles):
        end = project(axes_3d[:, i:i + 1])
        if end is None:
            # This axis tip is behind the camera; skip just this axis.
            continue
        cv2.arrowedLine(vis, o2d, end, color, thickness, tipLength=0.2)
        cv2.putText(vis, label, (end[0] + 5, end[1] - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # Mark the origin with a white disc outlined in black.
    cv2.circle(vis, o2d, 5, (255, 255, 255), -1)
    cv2.circle(vis, o2d, 5, (0, 0, 0), 2)
    return vis
def draw_bounding_box_from_pose(
    image: np.ndarray,
    pose_matrix: np.ndarray,
    K: np.ndarray,
    size: float = 0.03,
) -> np.ndarray:
    """Draw a projected 3D bounding box around the object.

    The box is a cube with half-extent ``size`` centered on the object
    origin, transformed by the pose and projected with ``K``.

    Args:
        image: Image to annotate; it is copied, never modified in place.
        pose_matrix: 4x4 transform whose upper-left 3x3 is the rotation and
            whose last column holds the translation in the camera frame.
        K: 3x3 camera intrinsics matrix used for the pinhole projection.
        size: Half the cube's edge length, in the units of the translation.

    Returns:
        An annotated copy of ``image``. The copy is returned unannotated when
        any cube corner lies at or behind the camera plane.
    """
    vis = image.copy()
    R = pose_matrix[:3, :3]
    t = pose_matrix[:3, 3]

    # 8 corners of a cube centered at the object origin.
    corners = np.array([
        [-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1],
        [-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1],
    ], dtype=np.float64) * size

    # Transform to the camera frame.
    corners_cam = (R @ corners.T + t.reshape(3, 1)).T  # (8, 3)

    pts_2d = []
    for corner in corners_cam:
        p = K @ corner
        # A corner with non-positive depth cannot be projected correctly
        # (a negative depth mirrors it), so skip drawing the whole box.
        if p[2] < 1e-6:
            return vis
        pts_2d.append((int(p[0] / p[2]), int(p[1] / p[2])))

    edges = [
        (0, 1), (1, 2), (2, 3), (3, 0),  # back face
        (4, 5), (5, 6), (6, 7), (7, 4),  # front face
        (0, 4), (1, 5), (2, 6), (3, 7),  # connecting edges
    ]
    for i, j in edges:
        cv2.line(vis, pts_2d[i], pts_2d[j], (0, 255, 255), 2)
    return vis
def quat_to_euler(w, x, y, z) -> Tuple[float, float, float]:
    """Convert a (w, x, y, z) quaternion to (roll, pitch, yaw) in degrees.

    Roll is the rotation about X, pitch about Y and yaw about Z. The pitch
    sine is clamped to +/-90 degrees when its magnitude reaches 1, so asin
    is never called outside its domain.
    """
    # Roll (rotation about X).
    roll_rad = math.atan2(2 * (w * x + y * z), 1 - 2 * (x * x + y * y))

    # Pitch (rotation about Y); clamp instead of calling asin out of range.
    sin_pitch = 2 * (w * y - z * x)
    if abs(sin_pitch) < 1:
        pitch_rad = math.asin(sin_pitch)
    else:
        pitch_rad = math.copysign(math.pi / 2, sin_pitch)

    # Yaw (rotation about Z).
    yaw_rad = math.atan2(2 * (w * z + x * y), 1 - 2 * (y * y + z * z))

    return math.degrees(roll_rad), math.degrees(pitch_rad), math.degrees(yaw_rad)
def format_pose_result(pose: Dict) -> str:
    """Format a pose result into a readable multi-line string.

    Args:
        pose: Mapping that may contain ``"position"`` (keys x/y/z),
            ``"orientation"`` (quaternion keys w/x/y/z) and ``"confidence"``.
            Missing position components are shown as 0; a missing quaternion
            falls back to the identity rotation (w=1, x=y=z=0) so the printed
            quaternion always agrees with the printed Euler angles.

    Returns:
        The position, quaternion and Euler-angle sections, followed by a
        confidence line when ``"confidence"`` is present, joined by newlines.
    """
    # ``or {}`` also covers a key that is present but set to None.
    pos = pose.get("position") or {}
    ori = pose.get("orientation") or {}

    # Resolve the quaternion once so the display and the Euler conversion use
    # the same values (previously a missing w printed as 0 but converted as 1).
    qw = ori.get("w", 1)
    qx = ori.get("x", 0)
    qy = ori.get("y", 0)
    qz = ori.get("z", 0)
    roll, pitch, yaw = quat_to_euler(qw, qx, qy, qz)

    lines = [
        "━━━ Position (meters) ━━━",
        f" X: {pos.get('x', 0):+.4f}",
        f" Y: {pos.get('y', 0):+.4f}",
        f" Z: {pos.get('z', 0):+.4f}",
        "",
        "━━━ Orientation (quaternion) ━━━",
        f" W: {qw:+.6f}",
        f" X: {qx:+.6f}",
        f" Y: {qy:+.6f}",
        f" Z: {qz:+.6f}",
        "",
        "━━━ Euler Angles (degrees) ━━━",
        f" Roll: {roll:+.2f}°",
        f" Pitch: {pitch:+.2f}°",
        f" Yaw: {yaw:+.2f}°",
    ]
    if "confidence" in pose:
        lines.append("")
        lines.append(f"━━━ Confidence: {pose['confidence']:.2%} ━━━")
    return "\n".join(lines)
# ── Core API functions ───────────────────────────────────────────────────────
def initialize_object(
    object_id: str,
    mesh_file,
    reference_files: List,
    fx: float,
    fy: float,
    cx: float,
    cy: float,
):
    """Register an object with the backend from a CAD mesh and optional reference images.

    Args:
        object_id: Name under which the backend stores the object.
        mesh_file: Uploaded mesh, either an object exposing ``.name`` or a
            path string.
        reference_files: Optional uploads; entries that neither expose
            ``.name`` nor are strings are ignored.
        fx, fy, cx, cy: Camera intrinsics forwarded to the backend unchanged.

    Returns:
        A status string: the backend's reply prefixed with a check mark, or a
        message prefixed with a cross on validation failure or any exception.
    """
    if not object_id:
        return "❌ Please provide an Object ID"
    if not mesh_file:
        return "❌ Please upload a 3D mesh file (.obj, .stl, .ply)"

    try:
        from gradio_client import handle_file

        client = _get_client()

        # Reduce each usable reference upload to a path, then wrap it.
        reference_paths = [
            item.name if hasattr(item, "name") else item
            for item in (reference_files or [])
            if hasattr(item, "name") or isinstance(item, str)
        ]
        ref_handles = [handle_file(path) for path in reference_paths]

        mesh_path = mesh_file.name if hasattr(mesh_file, "name") else mesh_file
        mesh_handle = handle_file(mesh_path)

        result = client.predict(
            object_id,
            mesh_handle,
            ref_handles or None,  # the backend receives None when there are no references
            fx, fy, cx, cy,
            api_name="/gradio_initialize_cad",
        )
    except Exception as e:
        # UI boundary: log with traceback and report the failure as text.
        logger.exception(f"Initialization error: {e}")
        return f"❌ Error: {e!s}"
    return f"✅ {result}"
def estimate_pose(
    object_id: str,
    query_image: np.ndarray,
    depth_image: Optional[np.ndarray],
    fx: float,
    fy: float,
    cx: float,
    cy: float,
    mask_method: str,
):
    """Estimate the 6D pose via the backend and return text plus a visualization.

    Args:
        object_id: ID the object was initialized under.
        query_image: Query image array; it is written to a PNG for upload.
        depth_image: Optional depth array, uploaded as a PNG when given.
            NOTE(review): the UI supplies this through ``gr.Image(type="numpy")``;
            confirm that 16-bit depth values survive that conversion.
        fx, fy, cx, cy: Camera intrinsics, forwarded to the backend and used
            locally to project the pose overlay.
        mask_method: Masking method name forwarded to the backend.

    Returns:
        A ``(message, image)`` tuple. On validation failure the image is
        ``None``; on an exception the original ``query_image`` is returned
        with an error message; otherwise the image is a copy of the query,
        annotated with axes and a box when a pose could be parsed.
    """
    if query_image is None:
        return "❌ Please upload a query image", None
    if not object_id:
        return "❌ Please provide the Object ID (must match initialization)", None

    try:
        from gradio_client import handle_file

        client = _get_client()

        # Stage the uploads in a temporary directory that is removed as soon as
        # the request finishes, so repeated requests do not leak PNG files.
        with tempfile.TemporaryDirectory(prefix="foundationpose_") as tmp_dir:
            query_path = str(Path(tmp_dir) / "query.png")
            Image.fromarray(query_image).save(query_path)

            depth_handle = None
            if depth_image is not None:
                depth_path = str(Path(tmp_dir) / "depth.png")
                Image.fromarray(depth_image).save(depth_path)
                depth_handle = handle_file(depth_path)

            result = client.predict(
                object_id,
                handle_file(query_path),
                depth_handle,
                fx, fy, cx, cy,
                mask_method,
                None,  # mask_editor_data
                api_name="/gradio_estimate",
            )

        # The backend returns (text, image_path, mask_path); only the text is
        # used here because the overlay is rendered locally.
        if isinstance(result, (list, tuple)):
            text_result = result[0] if len(result) > 0 else ""
        else:
            text_result = str(result)

        vis_image = query_image.copy()

        pose_info = _parse_pose_text(text_result)
        if not pose_info:
            # Nothing parseable: pass the backend's raw text through.
            return text_result, vis_image

        # Prefer a full pose matrix when present; otherwise rebuild one from
        # the parsed position and quaternion (identity rotation if missing).
        pose_mat = None
        if "pose_matrix" in pose_info:
            pose_mat = np.array(pose_info["pose_matrix"])
        elif "position" in pose_info:
            ori = pose_info.get("orientation", {"w": 1, "x": 0, "y": 0, "z": 0})
            pose_mat = _quat_pos_to_matrix(pose_info["position"], ori)

        if pose_mat is not None:
            K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float64)
            vis_image = draw_pose_axes(vis_image, pose_mat, K, axis_length=0.05)
            vis_image = draw_bounding_box_from_pose(vis_image, pose_mat, K, size=0.03)

        formatted = format_pose_result(pose_info)
        return f"✅ Pose Estimated Successfully\n\n{formatted}", vis_image
    except Exception as e:
        # UI boundary: log with traceback and report the failure as text.
        logger.error(f"Pose estimation error: {e}", exc_info=True)
        return f"❌ Error: {str(e)}", query_image
def _parse_pose_text(text: str) -> Optional[Dict]:
"""Parse pose information from the backend's text output."""
if not text or "No poses" in text or "failed" in text.lower() or "error" in text.lower():
return None
pose = {}
lines = text.strip().split("\n")
position = {}
orientation = {}
in_position = False
in_orientation = False
for line in lines:
line = line.strip()
if "Position:" in line:
in_position = True
in_orientation = False
continue
if "Orientation" in line:
in_position = False
in_orientation = True
continue
if "Confidence" in line:
in_position = False
in_orientation = False
try:
val = line.split(":")[-1].strip().rstrip("%")
pose["confidence"] = float(val) / 100 if "%" in line else float(val)
except (ValueError, IndexError):
pass
continue
if in_position:
if "x:" in line:
try:
position["x"] = float(line.split(":")[-1].strip().split()[0])
except (ValueError, IndexError):
pass
elif "y:" in line:
try:
position["y"] = float(line.split(":")[-1].strip().split()[0])
except (ValueError, IndexError):
pass
elif "z:" in line:
try:
position["z"] = float(line.split(":")[-1].strip().split()[0])
except (ValueError, IndexError):
pass
if in_orientation:
if "w:" in line:
try:
orientation["w"] = float(line.split(":")[-1].strip())
except (ValueError, IndexError):
pass
elif "x:" in line:
try:
orientation["x"] = float(line.split(":")[-1].strip())
except (ValueError, IndexError):
pass
elif "y:" in line:
try:
orientation["y"] = float(line.split(":")[-1].strip())
except (ValueError, IndexError):
pass
elif "z:" in line:
try:
orientation["z"] = float(line.split(":")[-1].strip())
except (ValueError, IndexError):
pass
if position:
pose["position"] = position
if orientation:
pose["orientation"] = orientation
return pose if pose else None
def _quat_pos_to_matrix(pos: Dict, ori: Dict) -> np.ndarray:
"""Convert position + quaternion to a 4x4 transformation matrix."""
w, x, y, z = ori.get("w", 1), ori.get("x", 0), ori.get("y", 0), ori.get("z", 0)
# Rotation matrix from quaternion
R = np.array([
[1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y)],
[2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x)],
[2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y)],
], dtype=np.float64)
T = np.eye(4, dtype=np.float64)
T[:3, :3] = R
T[0, 3] = pos.get("x", 0)
T[1, 3] = pos.get("y", 0)
T[2, 3] = pos.get("z", 0)
return T
# ── Gradio UI ────────────────────────────────────────────────────────────────
# Markdown shown at the top of the app by build_ui(): method summary and headline numbers.
DESCRIPTION = """
# 🎯 FoundationPose — 6D Object Pose Estimation
**[FoundationPose](https://nvlabs.github.io/FoundationPose/)** by NVIDIA is the **#1 method** on the
[BOP Challenge 2024](https://bop.felk.cvut.cz/) benchmark for model-based 6D localization of unseen objects,
achieving an **AR score of 73.4** across 7 core datasets (LM-O, T-LESS, TUD-L, IC-BIN, ITODD, HB, YCB-V).
### How it works
1. **Initialize**: Upload a 3D mesh (.obj/.stl/.ply) of your object and optionally reference RGB images
2. **Estimate**: Upload a query RGB image (+ optional depth) and the model estimates the full 6D pose
3. **Visualize**: See the projected 3D axes and bounding box overlaid on the image
The pose output is a 4×4 transformation matrix (rotation + translation) from object frame to camera frame.
| Metric | Value |
|--------|-------|
| BOP AR Score | **73.4** |
| BOP Rank (2024) | **#1** (model-based unseen) |
| Paper | [CVPR 2024](https://arxiv.org/abs/2312.08344) |
| Input | RGB-D + CAD mesh |
"""
# Markdown help text rendered at the top of the "Initialize Object" tab.
INIT_HELP = """
### 📋 Initialization Guide
**Required:**
- **Object ID**: A unique name for your object (e.g., "mug", "wrench")
- **3D Mesh**: Upload an `.obj`, `.stl`, or `.ply` file of the object
**Optional but recommended:**
- **Reference Images**: 1+ RGB images of the object from known viewpoints
- **Camera Intrinsics**: Focal lengths (fx, fy) and principal point (cx, cy)
> 💡 **Tip**: The default intrinsics work for the bundled test data. For your own images,
> use the calibration values from your camera.
"""
# Markdown help text rendered at the top of the "Estimate Pose" tab.
ESTIMATE_HELP = """
### 📋 Estimation Guide
- **Query Image**: An RGB image containing the initialized object
- **Depth Image**: Optional 16-bit depth map (improves accuracy significantly)
- **Mask Method**:
- `SlimSAM` — automatic segmentation (recommended)
- `Otsu` — simple brightness-based thresholding
> ⚠️ **Important**: Camera intrinsics must match the query image resolution.
> If you resize the image, scale fx/fy/cx/cy proportionally.
"""
def build_ui():
    """Assemble the Gradio Blocks app and return it without launching.

    The layout is a header, three tabs (Initialize Object, Estimate Pose,
    About) and a footer. The two action buttons are wired to
    ``initialize_object`` and ``estimate_pose`` respectively.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    # Long literals are kept flush-left so their content is exactly as written.
    custom_css = """
.pose-output { font-family: monospace; }
.gr-button-primary { background: #6366f1 !important; }
"""

    about_markdown = """
## Architecture
FoundationPose uses a **two-stage pipeline**:
1. **Pose Hypothesis Generation**: Generates 42 coarse pose hypotheses from uniformly sampled viewpoints
2. **Transformer-based Refinement**: A ResNet-34 backbone with 4-head attention refines each hypothesis
3. **Contrastive Ranking**: InfoNCE loss ranks hypotheses, selecting the best pose
### Training Data
- **600K synthetic scenes** rendered on Objaverse objects with LLM-aided texture augmentation
- **1.2M training images** — no real-world training data needed
### Two Modes
- **Model-Based**: Uses a CAD mesh for precise render-and-compare
- **Model-Free**: Reconstructs a NeRF from 16-20 reference images
## BOP Challenge 2024 Results
| Dataset | AR Score |
|---------|----------|
| LM-O | 75.6 |
| T-LESS | 64.6 |
| TUD-L | 92.3 |
| IC-BIN | 50.8 |
| ITODD | 58.0 |
| HB | 83.5 |
| YCB-V | 88.9 |
| **Average** | **73.4** |
## Citation
```bibtex
@inproceedings{wen2024foundationpose,
title={FoundationPose: Unified 6D Pose Estimation and Tracking of Novel Objects},
author={Wen, Bowen and Yang, Wei and Kautz, Jan and Birchfield, Stan},
booktitle={CVPR},
year={2024}
}
```
## Links
- [Paper (arXiv)](https://arxiv.org/abs/2312.08344)
- [Project Page](https://nvlabs.github.io/FoundationPose/)
- [GitHub](https://github.com/NVlabs/FoundationPose)
- [Model Weights (HF Hub)](https://huggingface.co/gpue/foundationpose-weights)
- [Backend Space](https://huggingface.co/spaces/gpue/foundationpose)
- [BOP Challenge](https://bop.felk.cvut.cz/)
"""

    footer_markdown = """
---
<center>
Built with ❤️ using [FoundationPose](https://nvlabs.github.io/FoundationPose/) by NVIDIA
and [Gradio](https://gradio.app) — Powered by the
[FoundationPose backend Space](https://huggingface.co/spaces/gpue/foundationpose)
</center>
"""

    def intrinsics_inputs():
        """Add the intrinsics heading and the fx/fy, cx/cy fields to the current layout."""
        gr.Markdown("#### Camera Intrinsics")
        with gr.Row():
            fx_box = gr.Number(label="fx", value=193.137, precision=3)
            fy_box = gr.Number(label="fy", value=193.137, precision=3)
        with gr.Row():
            cx_box = gr.Number(label="cx", value=120.0, precision=1)
            cy_box = gr.Number(label="cy", value=80.0, precision=1)
        return fx_box, fy_box, cx_box, cy_box

    with gr.Blocks(
        title="FoundationPose 6D Pose Estimation",
        theme=gr.themes.Soft(),
        css=custom_css,
    ) as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Tabs():
            # ── Tab 1: Initialize ─────────────────────────────────
            with gr.Tab("① Initialize Object", id="init"):
                gr.Markdown(INIT_HELP)
                with gr.Row():
                    with gr.Column(scale=1):
                        init_object_id = gr.Textbox(
                            label="Object ID",
                            placeholder="e.g., target_cube",
                            value="target_cube",
                        )
                        init_mesh = gr.File(
                            label="3D Mesh (.obj / .stl / .ply)",
                            file_count="single",
                            file_types=[".obj", ".stl", ".ply", ".mesh"],
                        )
                        init_refs = gr.File(
                            label="Reference Images (optional)",
                            file_count="multiple",
                            file_types=["image"],
                        )
                        init_fx, init_fy, init_cx, init_cy = intrinsics_inputs()
                        init_btn = gr.Button("🚀 Initialize Object", variant="primary", size="lg")
                    with gr.Column(scale=1):
                        init_result = gr.Textbox(
                            label="Result",
                            lines=6,
                            interactive=False,
                            elem_classes=["pose-output"],
                        )
                init_btn.click(
                    fn=initialize_object,
                    inputs=[init_object_id, init_mesh, init_refs, init_fx, init_fy, init_cx, init_cy],
                    outputs=init_result,
                )

            # ── Tab 2: Estimate Pose ──────────────────────────────
            with gr.Tab("② Estimate Pose", id="estimate"):
                gr.Markdown(ESTIMATE_HELP)
                with gr.Row():
                    with gr.Column(scale=1):
                        est_object_id = gr.Textbox(
                            label="Object ID",
                            placeholder="Must match initialization",
                            value="target_cube",
                        )
                        est_query = gr.Image(
                            label="Query Image (RGB)",
                            type="numpy",
                        )
                        est_depth = gr.Image(
                            label="Depth Image (optional, 16-bit PNG)",
                            type="numpy",
                        )
                        est_mask = gr.Radio(
                            choices=["SlimSAM", "Otsu"],
                            value="SlimSAM",
                            label="Mask Method",
                        )
                        est_fx, est_fy, est_cx, est_cy = intrinsics_inputs()
                        est_btn = gr.Button("🎯 Estimate Pose", variant="primary", size="lg")
                    with gr.Column(scale=1):
                        est_viz = gr.Image(
                            label="Pose Visualization (axes + bounding box)",
                            type="numpy",
                        )
                        est_result = gr.Textbox(
                            label="Pose Output",
                            lines=18,
                            interactive=False,
                            elem_classes=["pose-output"],
                        )
                # Output order matches estimate_pose's (text, image) return.
                est_btn.click(
                    fn=estimate_pose,
                    inputs=[est_object_id, est_query, est_depth, est_fx, est_fy, est_cx, est_cy, est_mask],
                    outputs=[est_result, est_viz],
                )

            # ── Tab 3: About ──────────────────────────────────────
            with gr.Tab("ℹ️ About"):
                gr.Markdown(about_markdown)

        gr.Markdown(footer_markdown)
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces, port 7860.
    demo = build_ui()
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)