"""
Test script for the FoundationPose HuggingFace API.

This test verifies that the API can:
1. Initialize an object with a CAD model (T-shape mesh)
2. Estimate pose from a query RGB image plus its depth image
"""

import sys
import tempfile
import traceback
from pathlib import Path

import numpy as np
import cv2
from gradio_client import Client, handle_file

# Add parent directory to path to import client
sys.path.insert(0, str(Path(__file__).parent.parent))

from client import FoundationPoseClient  # noqa: E402  (needs the sys.path tweak above)

# URL of the hosted FoundationPose Space exercised by these tests.
API_URL = "https://gpue-foundationpose.hf.space"

# Camera intrinsics for the 240x160 test images. CAD initialization and pose
# estimation must use identical values, so both read this single constant.
# Principal point (cx, cy) is the image center; fx = fy = 200 px corresponds to
# a horizontal FOV of roughly 62 degrees for a 240 px wide image.
CAMERA_INTRINSICS = {
    "fx": 200.0,
    "fy": 200.0,
    "cx": 120.0,  # Image center x (240/2)
    "cy": 80.0,  # Image center y (160/2)
}

# Depth PNG values are assumed to be millimetres; FoundationPose wants metres.
DEPTH_SCALE_MM_PER_M = 1000.0

_UINT16_MAX = float(np.iinfo(np.uint16).max)


def load_test_data(reference_dir: Path):
    """Load the RGB and depth test images from the t_shape directory.

    Args:
        reference_dir: Directory containing ``rgb_001.jpg`` and ``depth_001.png``.

    Returns:
        ``(rgb, depth)``: ``rgb`` is the image converted from OpenCV's BGR
        order to RGB. ``depth`` is a ``float32`` array in metres, resized with
        nearest-neighbour interpolation to the RGB resolution if the sizes differ.

    Raises:
        FileNotFoundError: If either image cannot be read.
    """
    rgb_path = reference_dir / "rgb_001.jpg"
    depth_path = reference_dir / "depth_001.png"

    # Load RGB image (OpenCV decodes to BGR)
    print(f"Loading RGB: {rgb_path}")
    rgb = cv2.imread(str(rgb_path))
    if rgb is None:
        raise FileNotFoundError(f"Could not load RGB image: {rgb_path}")
    rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)

    # Load depth image (16-bit PNG); IMREAD_ANYDEPTH keeps the 16-bit values
    print(f"Loading depth: {depth_path}")
    depth = cv2.imread(str(depth_path), cv2.IMREAD_ANYDEPTH)
    if depth is None:
        raise FileNotFoundError(f"Could not load depth image: {depth_path}")

    # Depth must be pixel-aligned with RGB. Nearest-neighbour avoids inventing
    # depth values by blending across object boundaries.
    if depth.shape[:2] != rgb.shape[:2]:
        print(f"⚠ Warning: Depth ({depth.shape[:2]}) and RGB ({rgb.shape[:2]}) sizes don't match")
        print(" Resizing depth to match RGB...")
        depth = cv2.resize(depth, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST)

    # FoundationPose expects float32 depth in metres (PNG assumed to be in mm)
    depth = depth.astype(np.float32) / DEPTH_SCALE_MM_PER_M

    print(f"✓ RGB loaded: shape={rgb.shape}, dtype={rgb.dtype}")
    print(f"✓ Depth loaded: shape={depth.shape}, dtype={depth.dtype}, range=[{depth.min():.3f}, {depth.max():.3f}]m")
    return rgb, depth


def test_client_initialization(api_url: str = API_URL):
    """Create the API client.

    Args:
        api_url: Base URL of the FoundationPose Space.

    Returns:
        The ``FoundationPoseClient`` instance, or ``None`` if construction raised.
    """
    print("=" * 60)
    print("Test 1: API Client Initialization")
    print("=" * 60)

    try:
        client = FoundationPoseClient(api_url=api_url)
    except Exception as e:
        print(f"✗ API client initialization failed: {e}")
        return None
    print("✓ API client initialized successfully")
    return client


def test_cad_initialization(client, mesh_path):
    """Register the "t_shape" object on the server from its CAD mesh.

    Args:
        client: A ``FoundationPoseClient``. Its underlying Gradio client is
            reached through ``client.client``.
        mesh_path: Path to the mesh file to upload.

    Returns:
        ``True`` if the call succeeded or returned an unrecognised (non-error)
        result. ``False`` if the reply reports an error or the call raised.
    """
    print("\n" + "=" * 60)
    print("Test 2: CAD-Based Initialization via API")
    print("=" * 60)
    print(f"Mesh file: {mesh_path.name}")

    intrinsics = CAMERA_INTRINSICS
    try:
        # Call CAD-based initialization endpoint directly
        result = client.client.predict(
            object_id="t_shape",
            mesh_file=handle_file(str(mesh_path)),
            reference_files=[],  # No reference images needed for CAD mode
            fx=intrinsics["fx"],
            fy=intrinsics["fy"],
            cx=intrinsics["cx"],
            cy=intrinsics["cy"],
            api_name="/gradio_initialize_cad",
        )
    except Exception as e:
        print(f"✗ Object initialization failed with exception: {e}")
        traceback.print_exc()
        return False

    print(f"API result: {result}")

    if isinstance(result, str):
        # Check for errors FIRST: a reply such as "Error: object not
        # initialized" also contains the word "initialized".
        if "Error" in result or "error" in result or "✗" in result:
            print(f"✗ Object initialization failed: {result}")
            return False
        if "✓" in result or "initialized" in result.lower():
            print("✓ Object initialized successfully with CAD model")
            return True

    print("✓ Object initialized (assuming success)")
    return True


def _depth_to_uint16_mm(depth_m):
    """Convert float depth in metres to 16-bit millimetre values for PNG upload.

    Rounds to the nearest millimetre. A plain ``astype`` truncates, so a float
    that lands just below an integer would lose 1 mm. The result is clamped to
    the uint16 range. NaN and negative values become 0, and +inf becomes 65535.
    """
    depth_mm = np.asarray(depth_m, dtype=np.float64) * DEPTH_SCALE_MM_PER_M
    depth_mm = np.nan_to_num(depth_mm, nan=0.0, posinf=_UINT16_MAX, neginf=0.0)
    return np.clip(np.rint(depth_mm), 0.0, _UINT16_MAX).astype(np.uint16)


def _call_estimate_api(client, query_image, depth_image, intrinsics):
    """Upload RGB + depth to ``/gradio_estimate`` and return the raw API result.

    The images are written to a temporary directory. The directory is removed
    when the call finishes, even if the API call raises.

    Raises:
        OSError: If a temporary image cannot be written.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        rgb_path = Path(tmp_dir) / "query.jpg"
        depth_path = Path(tmp_dir) / "depth.png"

        # Save RGB as JPEG (OpenCV expects BGR channel order)
        rgb_bgr = cv2.cvtColor(query_image, cv2.COLOR_RGB2BGR)
        if not cv2.imwrite(str(rgb_path), rgb_bgr, [cv2.IMWRITE_JPEG_QUALITY, 95]):
            raise OSError(f"Could not write temporary RGB image: {rgb_path}")

        # Save depth as 16-bit PNG (convert back from meters to mm)
        if not cv2.imwrite(str(depth_path), _depth_to_uint16_mm(depth_image)):
            raise OSError(f"Could not write temporary depth image: {depth_path}")

        print("Calling API with RGB + Depth images...")
        # Call Gradio API directly to get all outputs (text, viz, mask).
        # handle_file uploads synchronously, so the files only need to exist
        # for the duration of this call.
        return client.client.predict(
            object_id="t_shape",
            query_image=handle_file(str(rgb_path)),
            depth_image=handle_file(str(depth_path)),
            fx=intrinsics["fx"],
            fy=intrinsics["fy"],
            cx=intrinsics["cx"],
            cy=intrinsics["cy"],
            api_name="/gradio_estimate",
        )


def _describe_output(value):
    """Describe a Gradio output for logging.

    Image outputs may arrive as file paths rather than arrays, so ``.shape``
    is reported only when the value actually has one.
    """
    if value is None:
        return "None"
    shape = getattr(value, "shape", None)
    return str(shape) if shape is not None else repr(value)


def test_pose_estimation(client, query_image, depth_image, query_name):
    """Run pose estimation on one RGB + depth pair and verify the outputs.

    Args:
        client: A ``FoundationPoseClient``. Its Gradio client is ``client.client``.
        query_image: RGB image array.
        depth_image: Depth array in metres, pixel-aligned with ``query_image``.
        query_name: Label used only in log output.

    Returns:
        ``True`` if the API returned three outputs, including a mask, and the
        text reports at least one pose. ``False`` otherwise, including when the
        call raised.
    """
    print("\n" + "=" * 60)
    print("Test 3: Pose Estimation via API (with Depth & Mask)")
    print("=" * 60)
    print(f"Query image: {query_name}")

    # Intrinsics must match initialization and the actual image size
    intrinsics = CAMERA_INTRINSICS
    expected_hw = (int(round(2 * intrinsics["cy"])), int(round(2 * intrinsics["cx"])))
    if tuple(query_image.shape[:2]) != expected_hw:
        print(f"⚠ Warning: image size {query_image.shape[:2]} differs from the {expected_hw} size the intrinsics assume")

    try:
        result = _call_estimate_api(client, query_image, depth_image, intrinsics)
    except Exception as e:
        print(f"✗ Pose estimation failed with exception: {e}")
        traceback.print_exc()
        return False

    # Result should be a sequence: (text_output, viz_image, mask_image)
    if not isinstance(result, (tuple, list)) or len(result) != 3:
        length = len(result) if isinstance(result, (tuple, list)) else "N/A"
        print(f"✗ Unexpected result format: {type(result)}, length={length}")
        return False

    text_output, viz_image, mask_image = result
    text_output = "" if text_output is None else str(text_output)

    print("\n✓ API returned 3 outputs as expected")
    print(f" - Text output: {len(text_output)} chars")
    print(f" - Viz image: {_describe_output(viz_image)}")
    print(f" - Mask image: {_describe_output(mask_image)}")

    # Verify mask was generated
    if mask_image is None:
        print("✗ Mask was not returned (expected auto-generated mask)")
        return False
    if hasattr(mask_image, "shape") and hasattr(mask_image, "dtype"):
        print(f"✓ Mask returned: shape={mask_image.shape}, dtype={mask_image.dtype}")
    else:
        print(f"✓ Mask returned: {mask_image}")

    # Check text output for success/failure
    if "Error" in text_output or "✗" in text_output:
        print(f"✗ Estimation failed: {text_output[:200]}")
        return False

    # Check if poses were detected
    if "No poses detected" in text_output or "⚠" in text_output:
        print("⚠ No poses detected (API working, but no objects found)")
        print(f"Output: {text_output[:300]}")
        return False

    print("✓ Pose estimation succeeded!")
    print("\nEstimation output:")
    print(text_output)
    return True


def _print_banner(message):
    """Print a message framed by separator lines."""
    print("\n" + "=" * 60)
    print(message)
    print("=" * 60)


def main():
    """Run all tests.

    Returns:
        Process exit status: 0 if every test passed, 1 otherwise.
    """
    _print_banner("FoundationPose CAD-Based API Test Suite")

    # Setup paths
    test_dir = Path(__file__).parent
    reference_dir = test_dir / "reference" / "t_shape"
    mesh_path = reference_dir / "t_shape.obj"

    # Check if mesh file exists
    if not mesh_path.exists():
        print(f"✗ Mesh file not found: {mesh_path}")
        return 1

    # Check if reference images exist (for query testing)
    if not reference_dir.exists():
        print(f"✗ Reference directory not found: {reference_dir}")
        return 1

    print(f"\nUsing T-shape mesh: {mesh_path}")
    print(f"Using test data from: {reference_dir}")

    # Load test RGB and depth images
    try:
        rgb_image, depth_image = load_test_data(reference_dir)
    except FileNotFoundError as e:
        print(f"✗ {e}")
        return 1

    print("\n✓ Loaded RGB and depth images - testing with both")

    # Test 1: Initialize API client
    client = test_client_initialization()
    if client is None:
        _print_banner("TESTS ABORTED: API client initialization failed")
        return 1

    # Test 2: Initialize object with CAD model
    if not test_cad_initialization(client, mesh_path):
        _print_banner("TESTS ABORTED: CAD initialization failed")
        return 1

    # Test 3: Estimate pose using RGB + depth images
    success = test_pose_estimation(client, rgb_image, depth_image, "rgb_001.jpg")

    # Print final results
    _print_banner("TEST SUMMARY")
    print("✓ API client initialization: PASSED")
    print("✓ CAD-based object initialization: PASSED")

    if success:
        print("✓ Pose estimation with RGB+depth: PASSED")
        print("✓ Mask generation verification: PASSED")
        print("\n🎉 ALL TESTS PASSED")
    else:
        print("⚠ Pose estimation: Issues detected (see output above)")
        print("\n📊 API TESTS PARTIALLY PASSED (2/3 core functions verified)")
        print("\nPossible reasons for no detections:")
        print(" - Camera intrinsics mismatch")
        print(" - Object not visible or occluded in image")
        print(" - Depth data quality issues")
        print(" - Mask segmentation inaccurate")

    print("=" * 60)
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())