Spaces:
Sleeping
Sleeping
| """ | |
| Client for FoundationPose Hugging Face Space API | |
| This client can be used from the robot-ml training pipeline to call the | |
| FoundationPose inference API hosted on Hugging Face Spaces. | |
| """ | |
| import json | |
| import logging | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Dict, List, Optional | |
| import cv2 | |
| import numpy as np | |
| from gradio_client import Client, handle_file | |
| logger = logging.getLogger(__name__) | |
class FoundationPoseClient:
    """Client for FoundationPose Gradio API.

    Wraps a ``gradio_client.Client`` connected to the FoundationPose Space
    and exposes the two endpoints used by this module:
    ``/gradio_initialize`` and ``/gradio_estimate``. Images are handed to
    the API as temporary JPEG files, which are removed after each call.
    """

    def __init__(self, api_url: str = "https://gpue-foundationpose.hf.space"):
        """Initialize client.

        Args:
            api_url: Base URL of the FoundationPose Space
        """
        self.api_url = api_url.rstrip("/")
        logger.info(f"Initializing Gradio client for {self.api_url}")
        self.client = Client(self.api_url)
        logger.info("Gradio client initialized")

    @staticmethod
    def _resolve_intrinsics(camera_intrinsics: Optional[Dict]) -> tuple:
        """Return ``(fx, fy, cx, cy)``, filling in defaults for missing values.

        Args:
            camera_intrinsics: Optional dict with any of fx, fy, cx, cy

        Returns:
            Tuple ``(fx, fy, cx, cy)``; keys that are absent (or the whole
            dict, when it is ``None``/empty) fall back to 600, 600, 320, 240.
        """
        params = camera_intrinsics or {}
        return (
            params.get("fx", 600.0),
            params.get("fy", 600.0),
            params.get("cx", 320.0),
            params.get("cy", 240.0),
        )

    @staticmethod
    def _init_failure_reason(result) -> Optional[str]:
        """Classify the reply of the initialize endpoint.

        Args:
            result: Value returned by the Gradio call

        Returns:
            ``None`` if the reply counts as success, otherwise the error
            message to raise.
        """
        if not isinstance(result, str):
            return f"Unexpected result type: {type(result)}"
        # An explicit check mark is the unambiguous success marker.
        if "✓" in result:
            return None
        # Test for errors BEFORE accepting the word "initialized": a reply
        # such as "Error: model not initialized" must not count as success.
        if "Error" in result or "error" in result:
            return f"Initialization failed: {result}"
        # No error indication: assume success.
        return None

    @staticmethod
    def _discard_temp_file(path: str) -> None:
        """Best-effort removal of a temporary file; failures are only logged."""
        try:
            Path(path).unlink()
        except OSError as e:
            logger.debug(f"Could not remove temp file {path}: {e}")

    def _save_image_temp(self, image: np.ndarray) -> str:
        """Save image to temporary file.

        Args:
            image: RGB image as numpy array

        Returns:
            Path to temporary file (the caller is responsible for deleting it)

        Raises:
            RuntimeError: If OpenCV reports that the file was not written
        """
        # Convert RGB to BGR for OpenCV
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # Only the unique path is needed. Close our own handle immediately so
        # the descriptor is not leaked and the file is not held open while
        # OpenCV writes to it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as handle:
            temp_path = handle.name

        try:
            written = cv2.imwrite(temp_path, image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 95])
        except Exception:
            self._discard_temp_file(temp_path)
            raise
        # cv2.imwrite signals failure through its boolean return value.
        if not written:
            self._discard_temp_file(temp_path)
            raise RuntimeError(f"Failed to write temporary image: {temp_path}")
        return temp_path

    def initialize(
        self,
        object_id: str,
        reference_images: List[np.ndarray],
        camera_intrinsics: Optional[Dict] = None
    ) -> bool:
        """Initialize object tracking with reference images.

        Args:
            object_id: Unique ID for the object
            reference_images: List of RGB images (numpy arrays)
            camera_intrinsics: Optional camera parameters (dict with fx, fy, cx, cy)

        Returns:
            True if successful

        Raises:
            RuntimeError: If initialization fails
        """
        logger.info(f"Initializing object '{object_id}' with {len(reference_images)} reference images")

        # Save images to temporary files
        temp_files = []
        try:
            for img in reference_images:
                temp_files.append(self._save_image_temp(img))

            # Extract camera intrinsics or use defaults
            fx, fy, cx, cy = self._resolve_intrinsics(camera_intrinsics)

            # Call Gradio API
            result = self.client.predict(
                object_id=object_id,
                reference_files=[handle_file(f) for f in temp_files],
                fx=fx,
                fy=fy,
                cx=cx,
                cy=cy,
                api_name="/gradio_initialize"
            )

            # Parse result - Gradio returns plain text
            logger.info(f"API result: {result}")
            failure = self._init_failure_reason(result)
            if failure is not None:
                raise RuntimeError(failure)
            logger.info("Initialization successful")
            return True
        except RuntimeError:
            raise
        except Exception as e:
            logger.error(f"API request failed: {e}")
            raise RuntimeError(f"Failed to initialize object: {e}") from e
        finally:
            # Clean up temp files
            for temp_file in temp_files:
                self._discard_temp_file(temp_file)

    def estimate_pose(
        self,
        object_id: str,
        query_image: np.ndarray,
        camera_intrinsics: Optional[Dict] = None
    ) -> List[Dict]:
        """Estimate 6D pose of object in query image.

        Args:
            object_id: ID of object to detect
            query_image: RGB query image as numpy array
            camera_intrinsics: Optional camera parameters (dict with fx, fy, cx, cy)

        Returns:
            List of detected poses:
            [
                {
                    "object_id": str,
                    "position": {"x": float, "y": float, "z": float},
                    "orientation": {"w": float, "x": float, "y": float, "z": float},
                    "confidence": float,
                    "dimensions": [float, float, float]
                }
            ]
            An empty list is returned when no pose is detected or when the
            reply cannot be parsed.

        Raises:
            RuntimeError: If estimation fails (including failure to save the
                query image)
        """
        temp_file = None
        try:
            # Save query image to temp file (inside the try so that failures
            # are reported as RuntimeError, as in initialize()).
            temp_file = self._save_image_temp(query_image)

            # Extract camera intrinsics or use defaults
            fx, fy, cx, cy = self._resolve_intrinsics(camera_intrinsics)

            # Call Gradio API
            result = self.client.predict(
                object_id=object_id,
                query_image=handle_file(temp_file),
                fx=fx,
                fy=fy,
                cx=cx,
                cy=cy,
                api_name="/gradio_estimate"
            )

            # Parse result - Gradio may return tuple (text, image) or just text
            logger.info(f"API result type: {type(result)}")

            # If tuple, take first element (text output)
            if isinstance(result, tuple):
                result = result[0]

            if not isinstance(result, str):
                raise RuntimeError(f"Unexpected result type: {type(result)}")

            logger.info(f"API result: {result}")

            # Check for errors
            if "Error" in result or "not initialized" in result:
                raise RuntimeError(f"Pose estimation failed: {result}")

            # Try to parse as JSON (in case app.py returns JSON string)
            try:
                result_dict = json.loads(result)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                result_dict = None
            if isinstance(result_dict, dict) and "poses" in result_dict:
                return result_dict["poses"]

            # Check if the result indicates no poses detected
            if "No poses detected" in result or "⚠" in result:
                logger.info("No poses detected in query image")
                return []

            # For now, return empty list with a warning
            logger.warning(f"Could not parse pose from result: {result}")
            return []
        except RuntimeError:
            raise
        except Exception as e:
            logger.error(f"API request failed: {e}")
            raise RuntimeError(f"Failed to estimate pose: {e}") from e
        finally:
            # Clean up temp file
            if temp_file is not None:
                self._discard_temp_file(temp_file)
def load_reference_images(directory: Path) -> List[np.ndarray]:
    """Load reference images from directory.

    Only ``*.jpg`` files directly inside ``directory`` are considered, in
    sorted filename order. Files that OpenCV cannot decode are skipped with
    a warning instead of aborting the whole load.

    Args:
        directory: Path to directory containing images

    Returns:
        List of RGB images as numpy arrays (empty if nothing could be loaded)
    """
    directory = Path(directory)  # also accept a plain string path
    images = []
    for img_path in sorted(directory.glob("*.jpg")):
        img = cv2.imread(str(img_path))
        # cv2.imread returns None (no exception) for unreadable/corrupt
        # files; passing that to cvtColor would crash.
        if img is None:
            logger.warning(f"Skipping unreadable image: {img_path}")
            continue
        # OpenCV loads BGR; the rest of this module works with RGB.
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    logger.info(f"Loaded {len(images)} reference images from {directory}")
    return images
# Example usage: initialize a tracked object from a directory of reference
# images, then run pose estimation on the first reference image as a smoke test.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Initialize client
    client = FoundationPoseClient()

    # Load reference images
    ref_dir = Path("../training/perception/reference/target_cube")
    if ref_dir.exists():
        ref_images = load_reference_images(ref_dir)
        if ref_images:
            # Initialize object
            client.initialize("target_cube", ref_images)

            # Estimate pose on first reference image (for testing)
            poses = client.estimate_pose("target_cube", ref_images[0])
            print(f"Detected {len(poses)} poses:")
            for pose in poses:
                print(f"  {pose}")
        else:
            # The directory exists but holds no usable *.jpg files; indexing
            # ref_images[0] would raise IndexError.
            print(f"No reference images (*.jpg) found in {ref_dir}")
            print("Run 'make capture-reference' to collect reference images first")
    else:
        print(f"Reference directory not found: {ref_dir}")
        print("Run 'make capture-reference' to collect reference images first")