# foundationpose / app.py
# (repository header residue: commit a2e4f10, "Prepare job build context", author Georg)
"""
FoundationPose inference server for Hugging Face Spaces with ZeroGPU.
This version uses pure Gradio for ZeroGPU compatibility.
"""
import base64
import json
import logging
import os
from pathlib import Path
from typing import Dict, List
# Ensure OMP_NUM_THREADS is a valid integer to avoid libgomp warnings.
# Must run before cv2/torch are imported below, since libgomp reads the
# variable at library load time.
if not os.environ.get("OMP_NUM_THREADS", "").isdigit():
    os.environ["OMP_NUM_THREADS"] = "1"
import cv2
import gradio as gr
import numpy as np
import torch
from masks import generate_naive_mask
# Bundled demo assets used to pre-populate the Gradio inputs; paths are
# container-absolute and checked with .exists() before use in the UI.
DEFAULT_DATA_DIR = Path("/app/tests/reference/t_shape")
DEFAULT_MESH = DEFAULT_DATA_DIR / "t_shape.obj"      # default CAD mesh
DEFAULT_RGB = DEFAULT_DATA_DIR / "rgb_001.jpg"       # default query image
DEFAULT_DEPTH = DEFAULT_DATA_DIR / "depth_001.png"   # default depth image
DEFAULT_REF_IMAGES = [
    DEFAULT_DATA_DIR / "rgb_001.jpg",
    DEFAULT_DATA_DIR / "rgb_002.jpg",
    DEFAULT_DATA_DIR / "rgb_003.jpg",
]
# Lazily-initialized SlimSAM singletons, populated on first use by
# _get_slimsam() so the model download/load cost is not paid at startup.
_slimsam_model = None
_slimsam_processor = None
_slimsam_device = None
def _get_slimsam():
    """Lazy-load SlimSAM to avoid heavy startup cost.

    Returns the cached (model, processor, device) triple, loading the
    pretrained weights on first call only.
    """
    global _slimsam_model, _slimsam_processor, _slimsam_device
    needs_load = _slimsam_model is None or _slimsam_processor is None
    if needs_load:
        from transformers import SamModel, SamProcessor
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to(device)
        processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")
        _slimsam_device = device
        _slimsam_model = model
        _slimsam_processor = processor
        logger.info("SlimSAM loaded on %s", _slimsam_device)
    return _slimsam_model, _slimsam_processor, _slimsam_device
def _box_from_mask(mask_bool: np.ndarray) -> List[int]:
ys, xs = np.where(mask_bool)
if len(xs) == 0:
return [0, 0, mask_bool.shape[1] - 1, mask_bool.shape[0] - 1]
x0, x1 = int(xs.min()), int(xs.max())
y0, y1 = int(ys.min()), int(ys.max())
return [x0, y0, x1, y1]
def generate_slimsam_mask(rgb_image: np.ndarray, box_prompt: List[int]) -> tuple[np.ndarray, np.ndarray, float]:
    """Generate a SlimSAM mask using a box prompt.

    Args:
        rgb_image: HxWx3 uint8 RGB image.
        box_prompt: [x0, y0, x1, y1] pixel box guiding the segmentation.

    Returns:
        Tuple of (mask_bool, debug_mask, best_score): the boolean mask of the
        best-scoring prediction, a 0/255 uint8 copy for visualization, and
        the corresponding IoU score.
    """
    from PIL import Image
    model, processor, device = _get_slimsam()
    raw_image = Image.fromarray(rgb_image).convert("RGB")
    enc = processor(raw_image, input_boxes=[[box_prompt]], return_tensors="np")
    # Keep size tensors on CPU for post-processing
    original_sizes = torch.as_tensor(enc["original_sizes"])
    reshaped_sizes = torch.as_tensor(enc["reshaped_input_sizes"])
    # Move model inputs to device
    inputs = {
        k: torch.as_tensor(v).to(device)
        for k, v in enc.items()
        if k not in {"original_sizes", "reshaped_input_sizes"}
    }
    # FIX: inference-only forward pass — disable autograd so no computation
    # graph/activations are retained (saves memory, avoids grad tracking).
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        original_sizes,
        reshaped_sizes,
    )[0]
    scores = outputs.iou_scores.squeeze().cpu()
    best_idx = int(scores.argmax().item())
    best_mask = masks[0, best_idx].numpy()
    best_score = float(scores[best_idx].item())
    mask_bool = best_mask.astype(bool)
    debug_mask = (mask_bool.astype(np.uint8) * 255)
    return mask_bool, debug_mask, best_score
# Configure process-wide logging once at import time.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s"
)
logger = logging.getLogger(__name__)
# FIX: removed a duplicated OMP_NUM_THREADS guard here — the identical
# check already runs at the top of this module, before cv2/torch import.
# Always use real FoundationPose model
USE_REAL_MODEL = True
logger.info("Starting in REAL mode with FoundationPose")
class FoundationPoseInference:
    """Wrapper for FoundationPose model inference.

    Holds a lazily-initialized estimator plus a registry of tracked objects.
    If the real estimator cannot be loaded, the instance degrades to a
    placeholder mode: registrations are recorded but pose estimation
    returns empty results.
    """

    def __init__(self):
        # Real estimator instance (set by initialize_model) or None.
        self.model = None
        # torch.device chosen when initialize_model() runs.
        self.device = None
        # True once initialize_model() has completed (in either mode).
        self.initialized = False
        # object_id -> registration metadata (num_references, intrinsics, mesh).
        self.tracked_objects = {}
        # Starts from the module flag; flips to False on load failure.
        self.use_real_model = USE_REAL_MODEL

    def initialize_model(self):
        """Initialize the FoundationPose model on GPU.

        Idempotent: subsequent calls return immediately. Any load failure
        is logged and the wrapper falls back to placeholder mode instead
        of raising.
        """
        if self.initialized:
            logger.info("Model already initialized")
            return
        logger.info("Initializing FoundationPose model...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        if self.use_real_model:
            try:
                logger.info("Loading real FoundationPose model...")
                from estimator import FoundationPoseEstimator
                self.model = FoundationPoseEstimator(
                    device=str(self.device)
                )
                # Estimators without an `available` attribute are assumed OK.
                if getattr(self.model, "available", True):
                    logger.info("✓ Real FoundationPose model initialized successfully")
                else:
                    raise RuntimeError("FoundationPose dependencies missing")
            except Exception as e:
                # Deliberate best-effort: degrade to placeholder rather than crash.
                logger.error(f"Failed to initialize real model: {e}", exc_info=True)
                logger.warning("Falling back to placeholder mode")
                self.use_real_model = False
                self.model = None
        else:
            logger.info("Using placeholder mode (set USE_REAL_MODEL=true for real inference)")
            self.model = None
        self.initialized = True
        logger.info("FoundationPose inference ready")

    def register_object(
        self,
        object_id: str,
        reference_images: List[np.ndarray],
        camera_intrinsics: Dict = None,
        mesh_path: str = None
    ) -> bool:
        """Register an object for tracking with reference images.

        Initializes the model on first use. In real mode, delegates to the
        estimator and records the object only on success; in placeholder
        mode, always records the object and returns True.

        Returns:
            True on successful registration, False otherwise.
        """
        if not self.initialized:
            self.initialize_model()
        logger.info(f"Registering object '{object_id}' with {len(reference_images)} reference images")
        if self.use_real_model and self.model is not None:
            try:
                success = self.model.register_object(
                    object_id=object_id,
                    reference_images=reference_images,
                    camera_intrinsics=camera_intrinsics,
                    mesh_path=mesh_path
                )
                if success:
                    self.tracked_objects[object_id] = {
                        "num_references": len(reference_images),
                        "camera_intrinsics": camera_intrinsics,
                        "mesh_path": mesh_path
                    }
                return success
            except Exception as e:
                logger.error(f"Registration failed: {e}", exc_info=True)
                return False
        else:
            # Placeholder mode: record metadata so estimate_pose recognizes the id.
            self.tracked_objects[object_id] = {
                "num_references": len(reference_images),
                "camera_intrinsics": camera_intrinsics,
                "mesh_path": mesh_path
            }
            logger.info(f"✓ Object '{object_id}' registered (placeholder mode)")
            return True

    def estimate_pose(
        self,
        object_id: str,
        query_image: np.ndarray,
        camera_intrinsics: Dict = None,
        depth_image: np.ndarray = None,
        mask: np.ndarray = None
    ) -> Dict:
        """Estimate 6D pose of an object in a query image.

        Returns:
            Dict with "success" (bool), "poses" (list) and, on the real-model
            path, an optional "debug_mask" visualization. Failures return
            "success": False plus an "error" message.
            NOTE(review): the exception branch below omits the "debug_mask"
            key that the None-result branch includes — callers use .get(),
            but confirm downstream consumers tolerate the missing key.
        """
        if not self.initialized:
            return {"success": False, "error": "Model not initialized"}
        if object_id not in self.tracked_objects:
            return {"success": False, "error": f"Object '{object_id}' not registered"}
        logger.info(f"Estimating pose for object '{object_id}'")
        if self.use_real_model and self.model is not None:
            try:
                pose_result = self.model.estimate_pose(
                    object_id=object_id,
                    rgb_image=query_image,
                    depth_image=depth_image,
                    mask=mask,
                    camera_intrinsics=camera_intrinsics
                )
                if pose_result is None:
                    return {
                        "success": False,
                        "error": "Pose estimation returned None",
                        "poses": [],
                        "debug_mask": None
                    }
                # Extract debug mask if present
                debug_mask = pose_result.pop("debug_mask", None)
                return {
                    "success": True,
                    "poses": [pose_result],
                    "debug_mask": debug_mask
                }
            except Exception as e:
                logger.error(f"Pose estimation error: {e}", exc_info=True)
                return {"success": False, "error": str(e), "poses": []}
        else:
            logger.info("Placeholder mode: returning empty pose result")
            return {
                "success": True,
                "poses": [],
                "note": "Placeholder mode - set USE_REAL_MODEL=true for real inference"
            }
# Global model instance
# Shared singleton used by every Gradio callback below; model weights are
# loaded lazily on the first register/estimate call.
pose_estimator = FoundationPoseInference()
# Gradio wrapper functions
def gradio_initialize_cad(object_id: str, mesh_file, reference_files: List, fx: float, fy: float, cx: float, cy: float):
    """Gradio wrapper for CAD-based object initialization.

    Registers *object_id* with the uploaded mesh and any readable
    reference images, returning a human-readable status string.
    """
    try:
        if not mesh_file:
            return "Error: No mesh file provided"
        # Reference images are optional in CAD mode; unreadable files are skipped.
        reference_images = []
        for file in (reference_files or []):
            bgr = cv2.imread(file.name)
            if bgr is None:
                continue
            reference_images.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        # Camera intrinsics forwarded verbatim to the estimator.
        camera_intrinsics = {
            "fx": fx,
            "fy": fy,
            "cx": cx,
            "cy": cy
        }
        # Register the object together with its mesh.
        success = pose_estimator.register_object(
            object_id=object_id,
            reference_images=reference_images if reference_images else [],
            camera_intrinsics=camera_intrinsics,
            mesh_path=mesh_file.name
        )
        if not success:
            return f"✗ Failed to initialize object '{object_id}'"
        ref_info = f" and {len(reference_images)} reference images" if reference_images else ""
        return f"✓ Object '{object_id}' initialized with CAD model{ref_info}"
    except Exception as e:
        logger.error(f"CAD initialization error: {e}", exc_info=True)
        return f"Error: {str(e)}"
def gradio_initialize_model_free(object_id: str, reference_files: List, fx: float, fy: float, cx: float, cy: float):
    """Gradio wrapper for model-free object initialization.

    Loads the uploaded reference images, registers the object without a
    mesh, and returns a status string. A low image count produces a
    warning in the result but no longer blocks registration.
    """
    try:
        if not reference_files:
            return "Error: No reference images provided"
        # Load reference images (files that fail to decode are skipped)
        reference_images = []
        for file in reference_files:
            img = cv2.imread(file.name)
            if img is None:
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            reference_images.append(img)
        if not reference_images:
            return "Error: Could not load any reference images"
        # FIX: previously a count below 8 returned early and silently skipped
        # registration even though the count is only "recommended". Now the
        # object is registered regardless and the warning is appended.
        warning = ""
        if len(reference_images) < 8:
            warning = f"\nWarning: Only {len(reference_images)} images provided. 16-24 recommended for best results."
        # Prepare camera intrinsics
        camera_intrinsics = {
            "fx": fx,
            "fy": fy,
            "cx": cx,
            "cy": cy
        }
        # Register object without mesh (model-free)
        success = pose_estimator.register_object(
            object_id=object_id,
            reference_images=reference_images,
            camera_intrinsics=camera_intrinsics,
            mesh_path=None
        )
        if success:
            return f"✓ Object '{object_id}' initialized with {len(reference_images)} reference images (model-free mode)" + warning
        else:
            return f"✗ Failed to initialize object '{object_id}'" + warning
    except Exception as e:
        logger.error(f"Model-free initialization error: {e}", exc_info=True)
        return f"Error: {str(e)}"
def _prepare_depth(depth_image: np.ndarray, query_image: np.ndarray) -> np.ndarray:
    """Normalize a raw depth upload to a single-channel float32 map in meters."""
    # Check if depth needs resizing to match RGB
    if depth_image.shape[:2] != query_image.shape[:2]:
        logger.warning(f"Depth {depth_image.shape[:2]} and RGB {query_image.shape[:2]} sizes don't match, resizing depth")
        depth_image = cv2.resize(depth_image, (query_image.shape[1], query_image.shape[0]), interpolation=cv2.INTER_NEAREST)
    if depth_image.dtype == np.uint16:
        # Assume 16-bit depth in millimeters
        depth = depth_image.astype(np.float32) / 1000.0
        logger.info(f"Converted 16-bit depth to float32, range: [{depth.min():.3f}, {depth.max():.3f}]m")
    elif depth_image.dtype == np.uint8:
        # 8-bit depth (encoded); assume linear scaling to a reasonable range
        depth = depth_image.astype(np.float32) / 255.0 * 5.0  # Map to 0-5m
        logger.info(f"Converted 8-bit depth to float32, range: [{depth.min():.3f}, {depth.max():.3f}]m")
    else:
        # Already float, use as-is
        depth = depth_image.astype(np.float32)
        logger.info(f"Using provided depth (dtype={depth_image.dtype}), range: [{depth.min():.3f}, {depth.max():.3f}]m")
    # Handle color depth images (H, W, 3) - take first channel
    if len(depth.shape) == 3:
        logger.warning("Depth image has 3 channels, using first channel")
        depth = depth[:, :, 0]
    return depth

def _mask_from_editor(mask_editor_data):
    """Extract a 0/255 uint8 mask from Gradio ImageEditor data.

    Returns (mask, error): exactly one of the pair is None.
    """
    if isinstance(mask_editor_data, dict):
        layers = mask_editor_data.get("layers")
        if isinstance(layers, list) and layers:
            editor_mask = layers[-1]
        else:
            editor_mask = mask_editor_data.get("composite")
    else:
        editor_mask = mask_editor_data
    if editor_mask is None:
        return None, "Error: No editor mask provided"
    editor_mask = np.array(editor_mask)
    if editor_mask.ndim == 3 and editor_mask.shape[2] >= 4:
        # RGBA layer: painted pixels have nonzero alpha.
        alpha = editor_mask[:, :, 3]
        return (alpha > 0).astype(np.uint8) * 255, None
    if editor_mask.ndim == 3:
        gray = cv2.cvtColor(editor_mask, cv2.COLOR_RGB2GRAY)
        return (gray > 0).astype(np.uint8) * 255, None
    if editor_mask.ndim == 2:
        return (editor_mask > 0).astype(np.uint8) * 255, None
    return None, "Error: Unsupported editor mask format"

def _overlay_mask(query_image: np.ndarray, debug_mask: np.ndarray) -> np.ndarray:
    """Blend a green mask overlay onto the query image for visualization."""
    mask_overlay = np.zeros_like(query_image)
    mask_overlay[:, :, 1] = debug_mask  # Green channel
    return cv2.addWeighted(query_image.copy(), 0.7, mask_overlay, 0.3, 0)

def gradio_estimate(
    object_id: str,
    query_image: np.ndarray,
    depth_image: np.ndarray,
    fx: float,
    fy: float,
    cx: float,
    cy: float,
    mask_method: str,
    mask_editor_data
):
    """Gradio wrapper for pose estimation.

    Builds the requested mask, runs the estimator, and returns
    (status_text, query_image, mask_visualization).
    """
    try:
        if query_image is None:
            return "Error: No query image provided", None, None
        # Process depth image if provided
        depth = None
        if depth_image is not None:
            depth = _prepare_depth(depth_image, query_image)
        # Prepare camera intrinsics
        camera_intrinsics = {
            "fx": fx,
            "fy": fy,
            "cx": cx,
            "cy": cy
        }
        # Choose mask method
        mask = None
        debug_mask = None
        if mask_method == "SlimSAM":
            # Use Otsu mask as a box prompt to guide SlimSAM
            naive_mask, _, _, _ = generate_naive_mask(query_image)
            box_prompt = _box_from_mask(naive_mask)
            mask, debug_mask, score = generate_slimsam_mask(query_image, box_prompt)
            logger.info("SlimSAM mask generated (score=%.3f, box=%s)", score, box_prompt)
        elif mask_method == "Otsu":
            mask, debug_mask, mask_percentage, fallback_full_image = generate_naive_mask(query_image)
            logger.info("Otsu mask coverage %.1f%%", mask_percentage)
            if fallback_full_image:
                logger.warning("Otsu mask fallback to full image due to unrealistic coverage")
        elif mask_method == "From editor":
            mask, editor_error = _mask_from_editor(mask_editor_data)
            if editor_error is not None:
                return editor_error, query_image, None
            debug_mask = mask
        # Estimate pose
        result = pose_estimator.estimate_pose(
            object_id=object_id,
            query_image=query_image,
            depth_image=depth,
            camera_intrinsics=camera_intrinsics,
            mask=mask
        )
        if not result.get("success"):
            error = result.get("error", "Unknown error")
            # Still show mask output even on failure
            mask_vis = _overlay_mask(query_image, debug_mask) if debug_mask is not None else None
            return f"✗ Estimation failed: {error}", query_image, mask_vis
        poses = result.get("poses", [])
        note = result.get("note", "")
        if debug_mask is None:
            debug_mask = result.get("debug_mask", None)
        # Create mask visualization (green overlay on the original image)
        mask_vis = _overlay_mask(query_image, debug_mask) if debug_mask is not None else None
        # Format output
        if not poses:
            output = "⚠ No poses detected\n"
            if note:
                output += f"\nNote: {note}"
            if debug_mask is not None:
                mask_percentage = (debug_mask > 0).sum() / debug_mask.size * 100
                output += f"\n\nMask Coverage: {mask_percentage:.1f}% of image"
            return output, query_image, mask_vis
        output = f"✓ Detected {len(poses)} pose(s):\n\n"
        for i, pose in enumerate(poses):
            output += f"Pose {i + 1}:\n"
            output += f" Object ID: {pose.get('object_id', 'unknown')}\n"
            if 'position' in pose:
                pos = pose['position']
                output += f" Position:\n"
                output += f" x: {pos.get('x', 0):.4f} m\n"
                output += f" y: {pos.get('y', 0):.4f} m\n"
                output += f" z: {pos.get('z', 0):.4f} m\n"
            if 'orientation' in pose:
                ori = pose['orientation']
                output += f" Orientation (quaternion):\n"
                output += f" w: {ori.get('w', 0):.4f}\n"
                output += f" x: {ori.get('x', 0):.4f}\n"
                output += f" y: {ori.get('y', 0):.4f}\n"
                output += f" z: {ori.get('z', 0):.4f}\n"
            if 'confidence' in pose:
                output += f" Confidence: {pose['confidence']:.2%}\n"
            output += "\n"
        if debug_mask is not None:
            mask_percentage = (debug_mask > 0).sum() / debug_mask.size * 100
            output += f"\nMask Coverage: {mask_percentage:.1f}% of image"
        return output, query_image, mask_vis
    except Exception as e:
        logger.error(f"Gradio estimation error: {e}", exc_info=True)
        return f"Error: {str(e)}", None, None
# Gradio UI — two tabs: object initialization (CAD-based) and pose estimation.
with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 FoundationPose 6D Object Pose Estimation")
    mode_indicator = gr.Markdown(
        "**Mode:** 🟢 Real FoundationPose with GPU",
        elem_id="mode"
    )
    with gr.Tabs():
        # Tab 1: Initialize Object
        with gr.Tab("Initialize Object"):
            gr.Markdown("""
            Choose the initialization mode based on whether you have a 3D CAD model of your object.
            """)
            with gr.Tabs():
                # Sub-tab 1.1: CAD-Based Init
                with gr.Tab("CAD-Based (Model-Based)"):
                    gr.Markdown("""
                    **Model-Based Mode**: Use this if you have a 3D mesh/CAD model (.obj, .stl, .ply).
                    - Upload your 3D mesh file
                    - Optionally upload reference images for better initialization
                    - More accurate and robust
                    """)
                    with gr.Row():
                        with gr.Column():
                            cad_object_id = gr.Textbox(
                                label="Object ID",
                                placeholder="e.g., target_cube",
                                value="target_cube"
                            )
                            cad_mesh_file = gr.File(
                                label="3D Mesh File (.obj, .stl, .ply)",
                                file_count="single",
                                file_types=[".obj", ".stl", ".ply", ".mesh"],
                                value=str(DEFAULT_MESH) if DEFAULT_MESH.exists() else None
                            )
                            cad_ref_files = gr.File(
                                label="Reference Images (Optional)",
                                file_count="multiple",
                                file_types=["image"],
                                value=[str(p) for p in DEFAULT_REF_IMAGES if p.exists()]
                            )
                            gr.Markdown("### Camera Intrinsics")
                            with gr.Row():
                                cad_fx = gr.Number(label="fx", value=193.13708498984758)
                                cad_fy = gr.Number(label="fy", value=193.13708498984758)
                            with gr.Row():
                                cad_cx = gr.Number(label="cx", value=120.0)
                                cad_cy = gr.Number(label="cy", value=80.0)
                            cad_init_button = gr.Button("Initialize with CAD", variant="primary")
                        with gr.Column():
                            cad_init_output = gr.Textbox(
                                label="Initialization Result",
                                lines=5,
                                interactive=False
                            )
                    cad_init_button.click(
                        fn=gradio_initialize_cad,
                        inputs=[cad_object_id, cad_mesh_file, cad_ref_files, cad_fx, cad_fy, cad_cx, cad_cy],
                        outputs=cad_init_output
                    )
                # Sub-tab 1.2: Model-Free Init (disabled)
        # Tab 2: Estimate Pose
        with gr.Tab("Estimate Pose"):
            gr.Markdown("""
            Upload a query image containing the initialized object.
            The model will estimate the 6D pose (position + orientation).
            """)
            with gr.Row():
                with gr.Column():
                    est_object_id = gr.Textbox(
                        label="Object ID",
                        placeholder="e.g., target_cube",
                        value="target_cube"
                    )
                    est_query_image = gr.Image(
                        label="Query Image (RGB)",
                        type="numpy",
                        value=str(DEFAULT_RGB) if DEFAULT_RGB.exists() else None
                    )
                    est_depth_image = gr.Image(
                        label="Depth Image (Optional, 16-bit PNG)",
                        type="numpy",
                        value=str(DEFAULT_DEPTH) if DEFAULT_DEPTH.exists() else None
                    )
                    est_mask_method = gr.Radio(
                        choices=["SlimSAM", "Otsu", "From editor"],
                        value="SlimSAM",
                        label="Mask Method"
                    )
                    est_mask_editor = gr.ImageEditor(
                        label="Mask Editor (paint mask)",
                        type="numpy",
                        visible=False
                    )
                    # Intrinsics are fixed for the bundled demo camera, hence hidden.
                    est_fx = gr.Number(label="fx (focal length x)", value=193.13708498984758, visible=False)
                    est_fy = gr.Number(label="fy (focal length y)", value=193.13708498984758, visible=False)
                    est_cx = gr.Number(label="cx (principal point x)", value=120.0, visible=False)
                    est_cy = gr.Number(label="cy (principal point y)", value=80.0, visible=False)
                    est_button = gr.Button("Estimate Pose", variant="primary")
                with gr.Column():
                    est_mask = gr.Image(label="Auto-Generated Mask (green overlay)")
                    est_output = gr.Textbox(
                        label="Pose Estimation Result",
                        lines=15,
                        interactive=False
                    )
                    est_viz = gr.Image(label="Query Image")

            def _toggle_editor(method: str):
                # Show the paint editor only when the user picks "From editor".
                return gr.update(visible=method == "From editor")

            est_mask_method.change(
                fn=_toggle_editor,
                inputs=est_mask_method,
                outputs=est_mask_editor
            )
            est_button.click(
                fn=gradio_estimate,
                inputs=[
                    est_object_id,
                    est_query_image,
                    est_depth_image,
                    est_fx,
                    est_fy,
                    est_cx,
                    est_cy,
                    est_mask_method,
                    est_mask_editor,
                ],
                outputs=[est_output, est_viz, est_mask]
            )
    # FIX: the API example previously advertised api_name="/gradio_initialize",
    # but Gradio derives the endpoint from the bound function name
    # (gradio_initialize_cad), and the example omitted the mesh_file argument.
    gr.Markdown("""
    ---
    ## API Documentation
    This Space uses Gradio's built-in API. For programmatic access, use the `gradio_client` library:
    ```python
    from gradio_client import Client
    client = Client("https://gpue-foundationpose.hf.space")
    # Initialize object (CAD-based)
    result = client.predict(
        object_id="target_cube",
        mesh_file=mesh_file,
        reference_files=[file1, file2, ...],
        fx=500.0, fy=500.0, cx=320.0, cy=240.0,
        api_name="/gradio_initialize_cad"
    )
    # Estimate pose
    result = client.predict(
        object_id="target_cube",
        query_image=image,
        fx=500.0, fy=500.0, cx=320.0, cy=240.0,
        mask_method="SlimSAM",
        api_name="/gradio_estimate"
    )
    ```
    See [client.py](https://huggingface.co/spaces/gpue/foundationpose/blob/main/client.py) for a complete example.
    """)
if __name__ == "__main__":
    # Bind to all interfaces on the standard Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)