# NOTE(review): "Spaces: / Sleeping / Sleeping" below was HuggingFace page
# chrome captured when this file was extracted, not part of the script.
# Commented out so the module parses as Python.
# Spaces: Sleeping / Sleeping
| """ | |
| Test script for FoundationPose HuggingFace API. | |
| This test verifies that the API can: | |
| 1. Initialize an object with CAD model (T-shape mesh) | |
| 2. Estimate pose from query images | |
| """ | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import cv2 | |
| from gradio_client import Client, handle_file | |
| # Add parent directory to path to import client | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from client import FoundationPoseClient | |
def load_test_data(reference_dir: Path):
    """Load the RGB and depth test images stored under *reference_dir*.

    Returns:
        Tuple ``(rgb, depth)`` — ``rgb`` as an RGB uint8 array and ``depth``
        as a float32 array in meters, resized onto the RGB pixel grid.

    Raises:
        FileNotFoundError: if either image cannot be read from disk.
    """
    rgb_path = reference_dir / "rgb_001.jpg"
    depth_path = reference_dir / "depth_001.png"

    print(f"Loading RGB: {rgb_path}")
    rgb = cv2.imread(str(rgb_path))
    if rgb is None:
        raise FileNotFoundError(f"Could not load RGB image: {rgb_path}")
    # OpenCV decodes as BGR; downstream code expects RGB channel order.
    rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)

    print(f"Loading depth: {depth_path}")
    depth = cv2.imread(str(depth_path), cv2.IMREAD_ANYDEPTH)
    if depth is None:
        raise FileNotFoundError(f"Could not load depth image: {depth_path}")

    # Bring the depth map onto the RGB grid when resolutions differ;
    # nearest-neighbor avoids inventing interpolated depth values.
    if depth.shape[:2] != rgb.shape[:2]:
        print(f"⚠ Warning: Depth ({depth.shape[:2]}) and RGB ({rgb.shape[:2]}) sizes don't match")
        print(f" Resizing depth to match RGB...")
        height, width = rgb.shape[:2]
        depth = cv2.resize(depth, (width, height), interpolation=cv2.INTER_NEAREST)

    # FoundationPose expects float32 depth in meters; the PNG stores mm
    # (assumed — TODO confirm against the capture pipeline).
    depth = depth.astype(np.float32) / 1000.0

    print(f"✓ RGB loaded: shape={rgb.shape}, dtype={rgb.dtype}")
    print(f"✓ Depth loaded: shape={depth.shape}, dtype={depth.dtype}, range=[{depth.min():.3f}, {depth.max():.3f}]m")
    return rgb, depth
def test_client_initialization():
    """Test 1: construct the FoundationPose API client.

    Returns:
        The client instance on success, or ``None`` when construction fails.
    """
    banner = "=" * 60
    print(banner)
    print("Test 1: API Client Initialization")
    print(banner)
    try:
        api_client = FoundationPoseClient(api_url="https://gpue-foundationpose.hf.space")
    except Exception as exc:
        # Report and signal failure to the caller instead of crashing the suite.
        print(f"✗ API client initialization failed: {exc}")
        return None
    print("✓ API client initialized successfully")
    return api_client
def test_cad_initialization(client, mesh_path):
    """Test 2: initialize the "t_shape" object from a CAD mesh via the API.

    Args:
        client: Wrapper whose ``.client`` attribute is a gradio ``Client``.
        mesh_path: Path to the ``.obj`` mesh file to upload.

    Returns:
        True when the API response looks successful, False otherwise.
    """
    print("\n" + "=" * 60)
    print("Test 2: CAD-Based Initialization via API")
    print("=" * 60)
    print(f"Mesh file: {mesh_path.name}")

    # Camera intrinsics matching the actual image size (240x160):
    # principal point (cx, cy) at the image center, focal lengths
    # estimated assuming ~60° FOV.
    camera_intrinsics = {
        "fx": 200.0,  # Focal length adjusted for 240px width
        "fy": 200.0,  # Focal length adjusted for 160px height
        "cx": 120.0,  # Image center x (240/2)
        "cy": 80.0    # Image center y (160/2)
    }
    try:
        # Extract intrinsics (defaults mirror a generic 640x480 camera).
        fx = camera_intrinsics.get("fx", 600.0)
        fy = camera_intrinsics.get("fy", 600.0)
        cx = camera_intrinsics.get("cx", 320.0)
        cy = camera_intrinsics.get("cy", 240.0)
        # Call the CAD-based initialization endpoint directly.
        result = client.client.predict(
            object_id="t_shape",
            mesh_file=handle_file(str(mesh_path)),
            reference_files=[],  # No reference images needed for CAD mode
            fx=fx,
            fy=fy,
            cx=cx,
            cy=cy,
            api_name="/gradio_initialize_cad"
        )
        print(f"API result: {result}")
        # BUGFIX: check for an explicit error FIRST. A response such as
        # "Error: object not initialized" also contains the word
        # "initialized", so the old success-first ordering misread
        # error responses as success.
        if isinstance(result, str) and "error" in result.lower():
            print(f"✗ Object initialization failed: {result}")
            return False
        if isinstance(result, str) and ("✓" in result or "initialized" in result.lower()):
            print("✓ Object initialized successfully with CAD model")
            return True
        # Non-string / unrecognized payload: treat as success, matching the
        # original best-effort behavior.
        print("✓ Object initialized (assuming success)")
        return True
    except Exception as e:
        print(f"✗ Object initialization failed with exception: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_pose_estimation(client, query_image, depth_image, query_name):
    """Test 3: run pose estimation on one RGB+depth pair via the API.

    Args:
        client: Wrapper whose ``.client`` attribute is a gradio ``Client``.
        query_image: RGB uint8 image array.
        depth_image: float32 depth array in meters.
        query_name: Human-readable name of the query image (for logging).

    Returns:
        True when the API returns a pose plus an auto-generated mask,
        False on any failure or when no poses are detected.
    """
    print("\n" + "=" * 60)
    print("Test 3: Pose Estimation via API (with Depth & Mask)")
    print("=" * 60)
    print(f"Query image: {query_name}")

    # Camera intrinsics (must match initialization and actual image size).
    camera_intrinsics = {
        "fx": 200.0,  # Focal length for 240px width
        "fy": 200.0,  # Focal length for 160px height
        "cx": 120.0,  # Image center x (240/2)
        "cy": 80.0    # Image center y (160/2)
    }
    try:
        # Save images to temp files for API upload.
        import tempfile
        rgb_temp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
        depth_temp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        # BUGFIX: close the handles right away — cv2.imwrite reopens the
        # paths itself, and on Windows an open NamedTemporaryFile cannot be
        # reopened for writing; this also avoids leaking file descriptors.
        rgb_temp.close()
        depth_temp.close()
        try:
            # Save RGB as JPEG (convert back to OpenCV's BGR order first).
            rgb_bgr = cv2.cvtColor(query_image, cv2.COLOR_RGB2BGR)
            cv2.imwrite(rgb_temp.name, rgb_bgr, [cv2.IMWRITE_JPEG_QUALITY, 95])
            # Save depth as 16-bit PNG (convert back from meters to mm).
            depth_uint16 = (depth_image * 1000.0).astype(np.uint16)
            cv2.imwrite(depth_temp.name, depth_uint16)
            print(f"Calling API with RGB + Depth images...")
            # Call the Gradio API directly to get all outputs (text, viz, mask).
            result = client.client.predict(
                object_id="t_shape",
                query_image=handle_file(rgb_temp.name),
                depth_image=handle_file(depth_temp.name),
                fx=camera_intrinsics["fx"],
                fy=camera_intrinsics["fy"],
                cx=camera_intrinsics["cx"],
                cy=camera_intrinsics["cy"],
                api_name="/gradio_estimate"
            )
        finally:
            # BUGFIX: always remove the temp files, even when encoding or
            # the API call raises (the old code only cleaned up on success).
            # Path is already imported at module level.
            Path(rgb_temp.name).unlink(missing_ok=True)
            Path(depth_temp.name).unlink(missing_ok=True)

        # Result should be a tuple: (text_output, viz_image, mask_image).
        if not isinstance(result, tuple) or len(result) != 3:
            print(f"✗ Unexpected result format: {type(result)}, length={len(result) if isinstance(result, tuple) else 'N/A'}")
            return False
        text_output, viz_image, mask_image = result
        print(f"\n✓ API returned 3 outputs as expected")
        print(f"  - Text output: {len(text_output)} chars")
        print(f"  - Viz image: {viz_image.shape if viz_image is not None else 'None'}")
        print(f"  - Mask image: {mask_image.shape if mask_image is not None else 'None'}")

        # Verify the auto-generated mask came back.
        if mask_image is None:
            print(f"✗ Mask was not returned (expected auto-generated mask)")
            return False
        print(f"✓ Mask returned: shape={mask_image.shape}, dtype={mask_image.dtype}")

        # Check text output for an explicit failure.
        if "Error" in text_output or "✗" in text_output:
            print(f"✗ Estimation failed: {text_output[:200]}")
            return False
        # Check whether any poses were detected at all.
        if "No poses detected" in text_output or "⚠" in text_output:
            print(f"⚠ No poses detected (API working, but no objects found)")
            print(f"Output: {text_output[:300]}")
            return False

        # Success — report the raw estimation output.
        print(f"✓ Pose estimation succeeded!")
        print(f"\nEstimation output:")
        print(text_output)
        return True
    except Exception as e:
        print(f"✗ Pose estimation failed with exception: {e}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Run the full API test suite: client init, CAD init, pose estimation."""
    divider = "=" * 60
    print("\n" + divider)
    print("FoundationPose CAD-Based API Test Suite")
    print(divider)

    # Resolve test asset locations relative to this script.
    test_dir = Path(__file__).parent
    reference_dir = test_dir / "reference" / "t_shape"
    mesh_path = reference_dir / "t_shape.obj"

    # Bail out early when the required fixtures are missing.
    if not mesh_path.exists():
        print(f"✗ Mesh file not found: {mesh_path}")
        return
    if not reference_dir.exists():
        print(f"✗ Reference directory not found: {reference_dir}")
        return

    print(f"\nUsing T-shape mesh: {mesh_path}")
    print(f"Using test data from: {reference_dir}")

    # Load the RGB and depth query pair used by the estimation test.
    try:
        rgb_image, depth_image = load_test_data(reference_dir)
    except FileNotFoundError as err:
        print(f"✗ {err}")
        return
    print(f"\n✓ Loaded RGB and depth images - testing with both")

    # Test 1: API client construction.
    client = test_client_initialization()
    if client is None:
        print("\n" + divider)
        print("TESTS ABORTED: API client initialization failed")
        print(divider)
        return

    # Test 2: CAD-based object initialization.
    if not test_cad_initialization(client, mesh_path):
        print("\n" + divider)
        print("TESTS ABORTED: CAD initialization failed")
        print(divider)
        return

    # Test 3: pose estimation with RGB + depth.
    estimation_ok = test_pose_estimation(client, rgb_image, depth_image, "rgb_001.jpg")

    # Final summary.
    print("\n" + divider)
    print("TEST SUMMARY")
    print(divider)
    print("✓ API client initialization: PASSED")
    print("✓ CAD-based object initialization: PASSED")
    if estimation_ok:
        print("✓ Pose estimation with RGB+depth: PASSED")
        print("✓ Mask generation verification: PASSED")
        print("\n🎉 ALL TESTS PASSED")
    else:
        print("⚠ Pose estimation: Issues detected (see output above)")
        print("\n📊 API TESTS PARTIALLY PASSED (2/3 core functions verified)")
        print("\nPossible reasons for no detections:")
        print(" - Camera intrinsics mismatch")
        print(" - Object not visible or occluded in image")
        print(" - Depth data quality issues")
        print(" - Mask segmentation inaccurate")
    print(divider)
# Entry point: run the test suite when executed as a script.
if __name__ == "__main__":
    main()