"""
FoundationPose inference server for Hugging Face Spaces with ZeroGPU.

This version uses pure Gradio for ZeroGPU compatibility.
"""

import base64
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Ensure OMP_NUM_THREADS is a valid integer to avoid libgomp warnings.
# This is done before importing cv2/torch so the value is already set when they load.
if not os.environ.get("OMP_NUM_THREADS", "").isdigit():
    os.environ["OMP_NUM_THREADS"] = "1"

import cv2
import gradio as gr
import numpy as np
import torch

from masks import generate_naive_mask

# Configure logging first so every helper below can use ``logger``.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Bundled sample data used to pre-populate the UI (each entry is used only if it exists).
DEFAULT_DATA_DIR = Path("/app/tests/reference/t_shape")
DEFAULT_MESH = DEFAULT_DATA_DIR / "t_shape.obj"
DEFAULT_RGB = DEFAULT_DATA_DIR / "rgb_001.jpg"
DEFAULT_DEPTH = DEFAULT_DATA_DIR / "depth_001.png"
DEFAULT_REF_IMAGES = [
    DEFAULT_DATA_DIR / "rgb_001.jpg",
    DEFAULT_DATA_DIR / "rgb_002.jpg",
    DEFAULT_DATA_DIR / "rgb_003.jpg",
]

# Default pinhole camera intrinsics pre-filled in both UI tabs.
DEFAULT_FX = 193.13708498984758
DEFAULT_FY = 193.13708498984758
DEFAULT_CX = 120.0
DEFAULT_CY = 80.0

# Model-free registration works with fewer images, but warns below this count.
MIN_RECOMMENDED_MODEL_FREE_REFS = 8

# Hugging Face checkpoint used for SlimSAM segmentation.
SLIMSAM_CHECKPOINT = "nielsr/slimsam-50-uniform"

# Lazily-initialized SlimSAM singletons (populated by _get_slimsam on first use).
_slimsam_model = None
_slimsam_processor = None
_slimsam_device = None

# Always use real FoundationPose model
USE_REAL_MODEL = True
logger.info("Starting in REAL mode with FoundationPose")


def _get_slimsam():
    """Lazy-load SlimSAM to avoid heavy startup cost.

    Returns:
        ``(model, processor, device)``. The objects are loaded once and cached
        in module globals; the model is placed on CUDA when available.
    """
    global _slimsam_model, _slimsam_processor, _slimsam_device
    if _slimsam_model is None or _slimsam_processor is None:
        from transformers import SamModel, SamProcessor

        _slimsam_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        _slimsam_model = SamModel.from_pretrained(SLIMSAM_CHECKPOINT).to(_slimsam_device)
        _slimsam_processor = SamProcessor.from_pretrained(SLIMSAM_CHECKPOINT)
        logger.info("SlimSAM loaded on %s", _slimsam_device)
    return _slimsam_model, _slimsam_processor, _slimsam_device


def _box_from_mask(mask_bool: np.ndarray) -> List[int]:
    """Return the tight ``[x0, y0, x1, y1]`` bounding box of the non-zero mask pixels.

    An empty mask yields the full-image box so a prompt is always available.
    """
    ys, xs = np.where(mask_bool)
    if len(xs) == 0:
        return [0, 0, mask_bool.shape[1] - 1, mask_bool.shape[0] - 1]
    x0, x1 = int(xs.min()), int(xs.max())
    y0, y1 = int(ys.min()), int(ys.max())
    return [x0, y0, x1, y1]


def generate_slimsam_mask(
    rgb_image: np.ndarray, box_prompt: List[int]
) -> tuple[np.ndarray, np.ndarray, float]:
    """Generate a SlimSAM mask using a box prompt.

    Args:
        rgb_image: Image array accepted by ``PIL.Image.fromarray`` (converted to RGB).
        box_prompt: ``[x0, y0, x1, y1]`` box in pixel coordinates.

    Returns:
        ``(mask_bool, debug_mask, score)``. ``mask_bool`` is the highest-scoring
        SAM mask as a boolean array. ``debug_mask`` is the same mask as uint8
        (0/255). ``score`` is SAM's predicted IoU for that mask.
    """
    from PIL import Image

    model, processor, device = _get_slimsam()

    raw_image = Image.fromarray(rgb_image).convert("RGB")
    enc = processor(raw_image, input_boxes=[[box_prompt]], return_tensors="np")

    # Keep size tensors on CPU for post-processing
    original_sizes = torch.as_tensor(enc["original_sizes"])
    reshaped_sizes = torch.as_tensor(enc["reshaped_input_sizes"])

    # Move model inputs to device
    inputs = {
        k: torch.as_tensor(v).to(device)
        for k, v in enc.items()
        if k not in {"original_sizes", "reshaped_input_sizes"}
    }

    # Pure inference: no_grad avoids building an autograd graph (saves GPU memory).
    with torch.no_grad():
        outputs = model(**inputs)
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            original_sizes,
            reshaped_sizes,
        )[0]
        # One image and one box were given, so flattening leaves one score per
        # candidate mask. Unlike squeeze(), this never collapses to a 0-d tensor.
        scores = outputs.iou_scores.flatten().cpu()

    best_idx = int(scores.argmax().item())
    best_mask = masks[0, best_idx].numpy()
    best_score = float(scores[best_idx].item())

    mask_bool = best_mask.astype(bool)
    debug_mask = mask_bool.astype(np.uint8) * 255
    return mask_bool, debug_mask, best_score


class FoundationPoseInference:
    """Wrapper for FoundationPose model inference.

    If the real estimator cannot be loaded, this wrapper falls back to a
    placeholder mode. In that mode registration succeeds but no poses are
    returned.
    """

    def __init__(self):
        self.model = None
        self.device = None
        self.initialized = False
        # object_id -> registration metadata (reference count, intrinsics, mesh path)
        self.tracked_objects: Dict[str, Dict[str, Any]] = {}
        self.use_real_model = USE_REAL_MODEL

    def initialize_model(self) -> None:
        """Initialize the FoundationPose model (on GPU when available).

        Safe to call repeatedly; subsequent calls are no-ops.
        """
        if self.initialized:
            logger.info("Model already initialized")
            return

        logger.info("Initializing FoundationPose model...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        if self.use_real_model:
            try:
                logger.info("Loading real FoundationPose model...")
                from estimator import FoundationPoseEstimator

                self.model = FoundationPoseEstimator(device=str(self.device))
                # An ``available`` attribute, if the estimator defines one, signals
                # whether its dependencies loaded; absence is treated as available.
                if getattr(self.model, "available", True):
                    logger.info("✓ Real FoundationPose model initialized successfully")
                else:
                    raise RuntimeError("FoundationPose dependencies missing")
            except Exception as e:
                logger.error(f"Failed to initialize real model: {e}", exc_info=True)
                logger.warning("Falling back to placeholder mode")
                self.use_real_model = False
                self.model = None
        else:
            logger.info("Using placeholder mode (USE_REAL_MODEL is disabled)")
            self.model = None

        self.initialized = True
        logger.info("FoundationPose inference ready")

    def _record_registration(
        self,
        object_id: str,
        reference_images: List[np.ndarray],
        camera_intrinsics: Optional[Dict],
        mesh_path: Optional[str],
    ) -> None:
        """Store metadata for a successfully registered object."""
        self.tracked_objects[object_id] = {
            "num_references": len(reference_images),
            "camera_intrinsics": camera_intrinsics,
            "mesh_path": mesh_path,
        }

    def register_object(
        self,
        object_id: str,
        reference_images: List[np.ndarray],
        camera_intrinsics: Optional[Dict] = None,
        mesh_path: Optional[str] = None,
    ) -> bool:
        """Register an object for tracking with reference images.

        Initializes the model on first use.

        Returns:
            Whether registration succeeded. Errors raised by the real estimator
            are logged and reported as ``False``.
        """
        if not self.initialized:
            self.initialize_model()

        logger.info(f"Registering object '{object_id}' with {len(reference_images)} reference images")

        if self.use_real_model and self.model is not None:
            try:
                success = self.model.register_object(
                    object_id=object_id,
                    reference_images=reference_images,
                    camera_intrinsics=camera_intrinsics,
                    mesh_path=mesh_path,
                )
                if success:
                    self._record_registration(object_id, reference_images, camera_intrinsics, mesh_path)
                return success
            except Exception as e:
                logger.error(f"Registration failed: {e}", exc_info=True)
                return False

        self._record_registration(object_id, reference_images, camera_intrinsics, mesh_path)
        logger.info(f"✓ Object '{object_id}' registered (placeholder mode)")
        return True

    def estimate_pose(
        self,
        object_id: str,
        query_image: np.ndarray,
        camera_intrinsics: Optional[Dict] = None,
        depth_image: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
    ) -> Dict:
        """Estimate 6D pose of an object in a query image.

        Returns:
            A dict with ``success`` plus either ``error`` or ``poses`` (a list of
            pose dicts from the estimator). It may also contain ``debug_mask``
            and, in placeholder mode, a ``note``.
        """
        if not self.initialized:
            return {"success": False, "error": "Model not initialized"}

        if object_id not in self.tracked_objects:
            return {"success": False, "error": f"Object '{object_id}' not registered"}

        logger.info(f"Estimating pose for object '{object_id}'")

        if self.use_real_model and self.model is not None:
            try:
                pose_result = self.model.estimate_pose(
                    object_id=object_id,
                    rgb_image=query_image,
                    depth_image=depth_image,
                    mask=mask,
                    camera_intrinsics=camera_intrinsics,
                )
                if pose_result is None:
                    return {
                        "success": False,
                        "error": "Pose estimation returned None",
                        "poses": [],
                        "debug_mask": None,
                    }
                # Extract debug mask if present so it is not reported as part of the pose
                debug_mask = pose_result.pop("debug_mask", None)
                return {
                    "success": True,
                    "poses": [pose_result],
                    "debug_mask": debug_mask,
                }
            except Exception as e:
                logger.error(f"Pose estimation error: {e}", exc_info=True)
                return {"success": False, "error": str(e), "poses": [], "debug_mask": None}

        logger.info("Placeholder mode: returning empty pose result")
        return {
            "success": True,
            "poses": [],
            "debug_mask": None,
            "note": "Placeholder mode - real FoundationPose model unavailable (see server logs)",
        }


# Global model instance
pose_estimator = FoundationPoseInference()


# ---------------------------------------------------------------------------
# Gradio helpers
# ---------------------------------------------------------------------------

def _file_path(file_obj: Any) -> str:
    """Return the filesystem path for a Gradio file value.

    Accepts plain path strings as well as file-like wrappers that expose the
    path as ``.name``.
    """
    return str(getattr(file_obj, "name", file_obj))


def _load_rgb_images(files: Any) -> List[np.ndarray]:
    """Read image files as RGB arrays, skipping any that OpenCV cannot decode."""
    if not files:
        return []
    # A single path/file value is treated as a one-element list (a bare str would
    # otherwise be iterated character by character).
    if isinstance(files, (str, os.PathLike)) or hasattr(files, "name"):
        files = [files]

    images = []
    for file in files:
        path = _file_path(file)
        img = cv2.imread(path)
        if img is None:
            logger.warning("Skipping unreadable reference image: %s", path)
            continue
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return images


def _camera_intrinsics(fx: float, fy: float, cx: float, cy: float) -> Dict[str, float]:
    """Bundle pinhole intrinsics into the dict format the estimator expects."""
    return {"fx": fx, "fy": fy, "cx": cx, "cy": cy}


def gradio_initialize_cad(object_id: str, mesh_file, reference_files: List,
                          fx: float, fy: float, cx: float, cy: float):
    """Gradio wrapper for CAD-based object initialization.

    Returns:
        A human-readable status string for the UI.
    """
    try:
        if not mesh_file:
            return "Error: No mesh file provided"

        # Reference images are optional in CAD mode
        reference_images = _load_rgb_images(reference_files)

        # Register object with mesh
        success = pose_estimator.register_object(
            object_id=object_id,
            reference_images=reference_images,
            camera_intrinsics=_camera_intrinsics(fx, fy, cx, cy),
            mesh_path=_file_path(mesh_file),
        )

        if success:
            ref_info = f" and {len(reference_images)} reference images" if reference_images else ""
            return f"✓ Object '{object_id}' initialized with CAD model{ref_info}"
        return f"✗ Failed to initialize object '{object_id}'"

    except Exception as e:
        logger.error(f"CAD initialization error: {e}", exc_info=True)
        return f"Error: {str(e)}"


def gradio_initialize_model_free(object_id: str, reference_files: List,
                                 fx: float, fy: float, cx: float, cy: float):
    """Gradio wrapper for model-free object initialization.

    Registers the object even when fewer images than recommended are provided.
    In that case a warning is appended to the returned status string.
    """
    try:
        if not reference_files:
            return "Error: No reference images provided"

        reference_images = _load_rgb_images(reference_files)
        if not reference_images:
            return "Error: Could not load any reference images"

        warning = ""
        if len(reference_images) < MIN_RECOMMENDED_MODEL_FREE_REFS:
            warning = (
                f"\nWarning: Only {len(reference_images)} images provided. "
                "16-24 recommended for best results."
            )

        # Register object without mesh (model-free)
        success = pose_estimator.register_object(
            object_id=object_id,
            reference_images=reference_images,
            camera_intrinsics=_camera_intrinsics(fx, fy, cx, cy),
            mesh_path=None,
        )

        if success:
            return (
                f"✓ Object '{object_id}' initialized with {len(reference_images)} "
                f"reference images (model-free mode){warning}"
            )
        return f"✗ Failed to initialize object '{object_id}'{warning}"

    except Exception as e:
        logger.error(f"Model-free initialization error: {e}", exc_info=True)
        return f"Error: {str(e)}"


def _prepare_depth(depth_image: Optional[np.ndarray],
                   target_hw: Tuple[int, ...]) -> Optional[np.ndarray]:
    """Convert an uploaded depth image to a single-channel float32 array in meters.

    The depth is resized (nearest-neighbour) to ``target_hw`` when the sizes
    differ. uint16 input is treated as millimeters. uint8 input is mapped
    linearly onto 0-5 m. Any other dtype is used as-is. For 3-channel input,
    only the first channel is kept.
    """
    if depth_image is None:
        return None

    # Check if depth needs resizing to match RGB
    if depth_image.shape[:2] != target_hw:
        logger.warning(f"Depth {depth_image.shape[:2]} and RGB {target_hw} sizes don't match, resizing depth")
        depth_image = cv2.resize(depth_image, (target_hw[1], target_hw[0]),
                                 interpolation=cv2.INTER_NEAREST)

    if depth_image.dtype == np.uint16:
        # Assume 16-bit depth in millimeters
        depth = depth_image.astype(np.float32) / 1000.0
        logger.info(f"Converted 16-bit depth to float32, range: [{depth.min():.3f}, {depth.max():.3f}]m")
    elif depth_image.dtype == np.uint8:
        # 8-bit depth (encoded). For now, assume linear scaling to a 0-5 m range.
        # NOTE(review): a gr.Image in its default RGB mode may deliver a 16-bit PNG
        # as 8-bit RGB, which lands here - confirm the depth upload path.
        depth = depth_image.astype(np.float32) / 255.0 * 5.0
        logger.info(f"Converted 8-bit depth to float32, range: [{depth.min():.3f}, {depth.max():.3f}]m")
    else:
        # Already float, use as-is
        depth = depth_image.astype(np.float32)
        logger.info(f"Using provided depth (dtype={depth_image.dtype}), range: [{depth.min():.3f}, {depth.max():.3f}]m")

    # Handle color depth images (H, W, 3) - take first channel
    if len(depth.shape) == 3:
        logger.warning("Depth image has 3 channels, using first channel")
        depth = depth[:, :, 0]
    return depth


def _mask_from_editor(mask_editor_data: Any,
                      target_hw: Tuple[int, ...]) -> Tuple[Optional[np.ndarray], Optional[str]]:
    """Turn ImageEditor output into a uint8 (0/255) mask sized ``target_hw``.

    Any painted pixel counts as foreground. The alpha channel is used when
    present; otherwise non-zero intensity is used.

    Returns:
        ``(mask, None)`` on success, or ``(None, error_message)``.
    """
    if isinstance(mask_editor_data, dict):
        layers = mask_editor_data.get("layers")
        if isinstance(layers, list) and layers:
            editor_mask = layers[-1]
        else:
            editor_mask = mask_editor_data.get("composite")
    else:
        editor_mask = mask_editor_data

    if editor_mask is None:
        return None, "Error: No editor mask provided"

    editor_mask = np.asarray(editor_mask)
    if editor_mask.ndim == 3 and editor_mask.shape[2] >= 4:
        mask = (editor_mask[:, :, 3] > 0).astype(np.uint8) * 255
    elif editor_mask.ndim == 3:
        gray = cv2.cvtColor(editor_mask, cv2.COLOR_RGB2GRAY)
        mask = (gray > 0).astype(np.uint8) * 255
    elif editor_mask.ndim == 2:
        mask = (editor_mask > 0).astype(np.uint8) * 255
    else:
        return None, "Error: Unsupported editor mask format"

    # The editor canvas need not match the query image; align it so downstream
    # consumers receive a mask of the same size as the image.
    if mask.shape[:2] != tuple(target_hw[:2]):
        logger.warning("Editor mask %s and RGB %s sizes don't match, resizing mask",
                       mask.shape[:2], tuple(target_hw[:2]))
        mask = cv2.resize(mask, (target_hw[1], target_hw[0]), interpolation=cv2.INTER_NEAREST)
    return mask, None


def _compute_mask(query_image: np.ndarray, mask_method: str,
                  mask_editor_data: Any) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[str]]:
    """Build the object mask for ``mask_method``.

    Returns:
        ``(mask, debug_mask, error)``. ``error`` is a UI message, or ``None``
        on success. An unrecognized method yields ``(None, None, None)``.
    """
    if mask_method == "SlimSAM":
        # Use Otsu mask as a box prompt to guide SlimSAM
        naive_mask, _, _, _ = generate_naive_mask(query_image)
        box_prompt = _box_from_mask(naive_mask)
        mask, debug_mask, score = generate_slimsam_mask(query_image, box_prompt)
        logger.info("SlimSAM mask generated (score=%.3f, box=%s)", score, box_prompt)
        return mask, debug_mask, None

    if mask_method == "Otsu":
        mask, debug_mask, mask_percentage, fallback_full_image = generate_naive_mask(query_image)
        logger.info("Otsu mask coverage %.1f%%", mask_percentage)
        if fallback_full_image:
            logger.warning("Otsu mask fallback to full image due to unrealistic coverage")
        return mask, debug_mask, None

    if mask_method == "From editor":
        mask, error = _mask_from_editor(mask_editor_data, query_image.shape[:2])
        return mask, mask, error

    return None, None, None


def _overlay_mask(image: np.ndarray, debug_mask: np.ndarray) -> np.ndarray:
    """Blend ``debug_mask`` onto ``image`` as a green overlay (70% image, 30% mask).

    Boolean masks are scaled to 0/255 so they are visible. A mask whose size
    differs from the image is resized with nearest-neighbour interpolation.
    """
    mask = np.asarray(debug_mask)
    if mask.dtype == bool:
        mask = mask.astype(np.uint8) * 255
    if mask.shape[:2] != image.shape[:2]:
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
    overlay = np.zeros_like(image)
    overlay[:, :, 1] = mask  # green channel
    return cv2.addWeighted(image, 0.7, overlay, 0.3, 0)


def _mask_coverage(debug_mask: np.ndarray) -> float:
    """Percentage of pixels where ``debug_mask`` is non-zero."""
    return float((np.asarray(debug_mask) > 0).mean() * 100)


def _format_poses(poses: List[Dict[str, Any]]) -> str:
    """Render estimated poses as the human-readable text shown in the UI."""
    parts = [f"✓ Detected {len(poses)} pose(s):\n\n"]
    for i, pose in enumerate(poses, start=1):
        parts.append(f"Pose {i}:\n")
        parts.append(f"  Object ID: {pose.get('object_id', 'unknown')}\n")
        if "position" in pose:
            pos = pose["position"]
            parts.append("  Position:\n")
            parts.extend(f"    {axis}: {pos.get(axis, 0):.4f} m\n" for axis in ("x", "y", "z"))
        if "orientation" in pose:
            ori = pose["orientation"]
            parts.append("  Orientation (quaternion):\n")
            parts.extend(f"    {axis}: {ori.get(axis, 0):.4f}\n" for axis in ("w", "x", "y", "z"))
        if "confidence" in pose:
            parts.append(f"  Confidence: {pose['confidence']:.2%}\n")
        parts.append("\n")
    return "".join(parts)


def gradio_estimate(
    object_id: str,
    query_image: np.ndarray,
    depth_image: np.ndarray,
    fx: float,
    fy: float,
    cx: float,
    cy: float,
    mask_method: str,
    mask_editor_data,
):
    """Gradio wrapper for pose estimation.

    Returns:
        ``(status_text, query_image, mask_visualization)`` matching the outputs
        ``[est_output, est_viz, est_mask]``. On an unexpected exception, both
        images are ``None``.
    """
    try:
        if query_image is None:
            return "Error: No query image provided", None, None

        depth = _prepare_depth(depth_image, query_image.shape[:2])
        camera_intrinsics = _camera_intrinsics(fx, fy, cx, cy)

        mask, debug_mask, mask_error = _compute_mask(query_image, mask_method, mask_editor_data)
        if mask_error is not None:
            return mask_error, query_image, None

        result = pose_estimator.estimate_pose(
            object_id=object_id,
            query_image=query_image,
            depth_image=depth,
            camera_intrinsics=camera_intrinsics,
            mask=mask,
        )

        if not result.get("success"):
            error = result.get("error", "Unknown error")
            # Still show mask output even on failure
            mask_vis = _overlay_mask(query_image, debug_mask) if debug_mask is not None else None
            return f"✗ Estimation failed: {error}", query_image, mask_vis

        poses = result.get("poses", [])
        note = result.get("note", "")
        # Prefer the locally computed mask; otherwise use one reported by the estimator.
        if debug_mask is None:
            debug_mask = result.get("debug_mask", None)

        mask_vis = _overlay_mask(query_image, debug_mask) if debug_mask is not None else None

        if not poses:
            output = "⚠ No poses detected\n"
            if note:
                output += f"\nNote: {note}"
            if debug_mask is not None:
                output += f"\n\nMask Coverage: {_mask_coverage(debug_mask):.1f}% of image"
            return output, query_image, mask_vis

        output = _format_poses(poses)
        if debug_mask is not None:
            output += f"\nMask Coverage: {_mask_coverage(debug_mask):.1f}% of image"
        return output, query_image, mask_vis

    except Exception as e:
        logger.error(f"Gradio estimation error: {e}", exc_info=True)
        return f"Error: {str(e)}", None, None


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 FoundationPose 6D Object Pose Estimation")

    mode_indicator = gr.Markdown(
        "**Mode:** 🟢 Real FoundationPose with GPU",
        elem_id="mode",
    )

    with gr.Tabs():
        # Tab 1: Initialize Object
        with gr.Tab("Initialize Object"):
            gr.Markdown("""
            Choose the initialization mode based on whether you have a 3D CAD model of your object.
            """)

            with gr.Tabs():
                # Sub-tab 1.1: CAD-Based Init
                with gr.Tab("CAD-Based (Model-Based)"):
                    gr.Markdown("""
                    **Model-Based Mode**: Use this if you have a 3D mesh/CAD model (.obj, .stl, .ply).

                    - Upload your 3D mesh file
                    - Optionally upload reference images for better initialization
                    - More accurate and robust
                    """)

                    with gr.Row():
                        with gr.Column():
                            cad_object_id = gr.Textbox(
                                label="Object ID",
                                placeholder="e.g., target_cube",
                                value="target_cube",
                            )
                            cad_mesh_file = gr.File(
                                label="3D Mesh File (.obj, .stl, .ply)",
                                file_count="single",
                                file_types=[".obj", ".stl", ".ply", ".mesh"],
                                value=str(DEFAULT_MESH) if DEFAULT_MESH.exists() else None,
                            )
                            cad_ref_files = gr.File(
                                label="Reference Images (Optional)",
                                file_count="multiple",
                                file_types=["image"],
                                value=[str(p) for p in DEFAULT_REF_IMAGES if p.exists()],
                            )

                            gr.Markdown("### Camera Intrinsics")
                            with gr.Row():
                                cad_fx = gr.Number(label="fx", value=DEFAULT_FX)
                                cad_fy = gr.Number(label="fy", value=DEFAULT_FY)
                            with gr.Row():
                                cad_cx = gr.Number(label="cx", value=DEFAULT_CX)
                                cad_cy = gr.Number(label="cy", value=DEFAULT_CY)

                            cad_init_button = gr.Button("Initialize with CAD", variant="primary")

                        with gr.Column():
                            cad_init_output = gr.Textbox(
                                label="Initialization Result",
                                lines=5,
                                interactive=False,
                            )

                    # Explicit api_name keeps the public endpoint name stable
                    # (it equals the function name Gradio would use by default).
                    cad_init_button.click(
                        fn=gradio_initialize_cad,
                        inputs=[cad_object_id, cad_mesh_file, cad_ref_files,
                                cad_fx, cad_fy, cad_cx, cad_cy],
                        outputs=cad_init_output,
                        api_name="gradio_initialize_cad",
                    )

                # Sub-tab 1.2: Model-Free Init (disabled)

        # Tab 2: Estimate Pose
        with gr.Tab("Estimate Pose"):
            gr.Markdown("""
            Upload a query image containing the initialized object.
            The model will estimate the 6D pose (position + orientation).
            """)

            with gr.Row():
                with gr.Column():
                    est_object_id = gr.Textbox(
                        label="Object ID",
                        placeholder="e.g., target_cube",
                        value="target_cube",
                    )
                    est_query_image = gr.Image(
                        label="Query Image (RGB)",
                        type="numpy",
                        value=str(DEFAULT_RGB) if DEFAULT_RGB.exists() else None,
                    )
                    est_depth_image = gr.Image(
                        label="Depth Image (Optional, 16-bit PNG)",
                        type="numpy",
                        value=str(DEFAULT_DEPTH) if DEFAULT_DEPTH.exists() else None,
                    )
                    est_mask_method = gr.Radio(
                        choices=["SlimSAM", "Otsu", "From editor"],
                        value="SlimSAM",
                        label="Mask Method",
                    )
                    est_mask_editor = gr.ImageEditor(
                        label="Mask Editor (paint mask)",
                        type="numpy",
                        visible=False,
                    )

                    # Intrinsics are hidden in this tab; the defaults are always used.
                    est_fx = gr.Number(label="fx (focal length x)", value=DEFAULT_FX, visible=False)
                    est_fy = gr.Number(label="fy (focal length y)", value=DEFAULT_FY, visible=False)
                    est_cx = gr.Number(label="cx (principal point x)", value=DEFAULT_CX, visible=False)
                    est_cy = gr.Number(label="cy (principal point y)", value=DEFAULT_CY, visible=False)

                    est_button = gr.Button("Estimate Pose", variant="primary")

                with gr.Column():
                    est_mask = gr.Image(label="Auto-Generated Mask (green overlay)")
                    est_output = gr.Textbox(
                        label="Pose Estimation Result",
                        lines=15,
                        interactive=False,
                    )
                    est_viz = gr.Image(label="Query Image")

            def _toggle_editor(method: str):
                """Show the mask editor only when the "From editor" method is selected."""
                return gr.update(visible=method == "From editor")

            est_mask_method.change(
                fn=_toggle_editor,
                inputs=est_mask_method,
                outputs=est_mask_editor,
            )

            est_button.click(
                fn=gradio_estimate,
                inputs=[
                    est_object_id,
                    est_query_image,
                    est_depth_image,
                    est_fx,
                    est_fy,
                    est_cx,
                    est_cy,
                    est_mask_method,
                    est_mask_editor,
                ],
                outputs=[est_output, est_viz, est_mask],
                api_name="gradio_estimate",
            )

    gr.Markdown("""
    ---

    ## API Documentation

    This Space uses Gradio's built-in API. For programmatic access, use the `gradio_client` library:

    ```python
    from gradio_client import Client

    client = Client("https://gpue-foundationpose.hf.space")

    # Initialize object from a CAD mesh
    result = client.predict(
        object_id="target_cube",
        mesh_file=mesh_file,                  # .obj / .stl / .ply
        reference_files=[file1, file2, ...],  # optional
        fx=500.0, fy=500.0, cx=320.0, cy=240.0,
        api_name="/gradio_initialize_cad"
    )

    # Estimate pose
    result = client.predict(
        object_id="target_cube",
        query_image=image,
        depth_image=depth,                    # optional, 16-bit PNG in millimeters
        fx=500.0, fy=500.0, cx=320.0, cy=240.0,
        mask_method="SlimSAM",
        api_name="/gradio_estimate"
    )
    ```

    See [client.py](https://huggingface.co/spaces/gpue/foundationpose/blob/main/client.py) for a complete example.
    """)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)