# foundationpose/tests/test_estimator.py
# Georg
# Update test to verify mask generation and add psutil dependency
# f7e2564
"""
Test script for FoundationPose HuggingFace API.
This test verifies that the API can:
1. Initialize an object with CAD model (T-shape mesh)
2. Estimate pose from query images
"""
import sys
from pathlib import Path
import numpy as np
import cv2
from gradio_client import Client, handle_file
# Add parent directory to path to import client
sys.path.insert(0, str(Path(__file__).parent.parent))
from client import FoundationPoseClient
def load_test_data(reference_dir: Path):
    """Read the RGB/depth image pair stored in ``reference_dir``.

    The function looks for ``rgb_001.jpg`` and ``depth_001.png``. The colour
    image is reordered from OpenCV's BGR layout to RGB. The depth image is
    read at its stored bit depth, resized with nearest-neighbour
    interpolation when its height/width differ from the colour image, and
    finally divided by 1000 as ``float32``.

    Args:
        reference_dir: Directory containing the two image files.

    Returns:
        A ``(rgb, depth)`` tuple of NumPy arrays.

    Raises:
        FileNotFoundError: If OpenCV cannot read either image.
    """
    color_file = reference_dir / "rgb_001.jpg"
    range_file = reference_dir / "depth_001.png"

    # Colour image: OpenCV decodes as BGR, callers expect RGB.
    print(f"Loading RGB: {color_file}")
    bgr = cv2.imread(str(color_file))
    if bgr is None:
        raise FileNotFoundError(f"Could not load RGB image: {color_file}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Depth image: IMREAD_ANYDEPTH keeps the stored bit depth (16-bit PNG).
    print(f"Loading depth: {range_file}")
    raw_depth = cv2.imread(str(range_file), cv2.IMREAD_ANYDEPTH)
    if raw_depth is None:
        raise FileNotFoundError(f"Could not load depth image: {range_file}")

    # Bring the depth map onto the colour image's pixel grid if necessary.
    target_hw = rgb.shape[:2]
    if raw_depth.shape[:2] != target_hw:
        print(f"⚠ Warning: Depth ({raw_depth.shape[:2]}) and RGB ({target_hw}) sizes don't match")
        print(f" Resizing depth to match RGB...")
        height, width = target_hw
        # cv2.resize takes (width, height); nearest neighbour avoids
        # inventing depth values between neighbouring pixels.
        raw_depth = cv2.resize(raw_depth, (width, height), interpolation=cv2.INTER_NEAREST)

    # Scale to float32 metres — assumes the PNG stores millimetres; confirm
    # against the sensor that produced the reference data.
    depth = raw_depth.astype(np.float32) / 1000.0

    print(f"✓ RGB loaded: shape={rgb.shape}, dtype={rgb.dtype}")
    print(f"✓ Depth loaded: shape={depth.shape}, dtype={depth.dtype}, range=[{depth.min():.3f}, {depth.max():.3f}]m")
    return rgb, depth
def test_client_initialization():
    """Create the FoundationPose API client for the hosted Space.

    Returns:
        The ``FoundationPoseClient`` instance on success, or ``None`` when
        construction raises (the error is printed, not propagated).
    """
    banner = "=" * 60
    print(banner)
    print("Test 1: API Client Initialization")
    print(banner)

    try:
        api_client = FoundationPoseClient(api_url="https://gpue-foundationpose.hf.space")
        print("✓ API client initialized successfully")
        return api_client
    except Exception as e:
        # Any failure here is reported to the caller as "no client".
        print(f"✗ API client initialization failed: {e}")
        return None
def test_cad_initialization(client, mesh_path):
    """Test CAD-based object initialization via API.

    Uploads the mesh to the ``/gradio_initialize_cad`` endpoint under the
    object id ``"t_shape"`` and interprets the returned status.

    Args:
        client: Object whose ``client`` attribute exposes ``predict(...)``.
        mesh_path: Path to the mesh file; its ``name`` is printed and its
            string form is uploaded.

    Returns:
        bool: ``False`` when the call raises or the returned text reports an
        error; ``True`` otherwise (including non-text responses, which are
        assumed to be successful).
    """
    import traceback

    print("\n" + "=" * 60)
    print("Test 2: CAD-Based Initialization via API")
    print("=" * 60)
    print(f"Mesh file: {mesh_path.name}")

    # Define camera intrinsics matching the actual image size (240x160)
    # Principal point (cx, cy) should be at image center
    # Focal lengths estimated assuming ~60° FOV
    camera_intrinsics = {
        "fx": 200.0,  # Focal length adjusted for 240px width
        "fy": 200.0,  # Focal length adjusted for 160px height
        "cx": 120.0,  # Image center x (240/2)
        "cy": 80.0,  # Image center y (160/2)
    }

    try:
        # Call CAD-based initialization endpoint directly
        result = client.client.predict(
            object_id="t_shape",
            mesh_file=handle_file(str(mesh_path)),
            reference_files=[],  # No reference images needed for CAD mode
            fx=camera_intrinsics["fx"],
            fy=camera_intrinsics["fy"],
            cx=camera_intrinsics["cx"],
            cy=camera_intrinsics["cy"],
            api_name="/gradio_initialize_cad",
        )
    except Exception as e:
        print(f"✗ Object initialization failed with exception: {e}")
        traceback.print_exc()
        return False

    print(f"API result: {result}")

    if isinstance(result, str):
        lowered = result.lower()
        # Failure markers are checked BEFORE success markers: an error text
        # such as "Error: object not initialized" contains the word
        # "initialized" and must not be mistaken for success.
        if "error" in lowered or "✗" in result:
            print(f"✗ Object initialization failed: {result}")
            return False
        if "✓" in result or "initialized" in lowered:
            print("✓ Object initialized successfully with CAD model")
            return True

    # Unrecognized or non-text response: keep the original best-effort
    # behavior of treating it as success.
    print("✓ Object initialized (assuming success)")
    return True
def test_pose_estimation(client, query_image, depth_image, query_name):
    """Test pose estimation on a query image via API with depth and mask verification.

    The RGB and depth arrays are written to temporary files, uploaded to the
    ``/gradio_estimate`` endpoint for object ``"t_shape"``, and the three
    returned outputs (text, visualization, mask) are checked.

    Args:
        client: Object whose ``client`` attribute exposes ``predict(...)``.
        query_image: RGB image array; converted to BGR and written as JPEG.
        depth_image: Depth array; multiplied by 1000 and written as a 16-bit
            PNG (i.e. treated as metres converted back to millimetres).
        query_name: Label printed for the query image.

    Returns:
        bool: ``True`` only when three outputs came back, a mask is present,
        and the text output reports neither an error nor "no poses"; ``False``
        otherwise, including when any step raises.
    """
    import tempfile
    import traceback

    print("\n" + "=" * 60)
    print("Test 3: Pose Estimation via API (with Depth & Mask)")
    print("=" * 60)
    print(f"Query image: {query_name}")

    # Define camera intrinsics (must match initialization and actual image size)
    camera_intrinsics = {
        "fx": 200.0,  # Focal length for 240px width
        "fy": 200.0,  # Focal length for 160px height
        "cx": 120.0,  # Image center x (240/2)
        "cy": 80.0,  # Image center y (160/2)
    }

    try:
        temp_paths = []
        try:
            # Reserve two temp file names and close the handles immediately:
            # leaving them open leaks descriptors and prevents other writers
            # from reopening the file on Windows.
            for suffix in (".jpg", ".png"):
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
                    temp_paths.append(Path(handle.name))
            rgb_path, depth_path = temp_paths

            # Save RGB as JPEG (OpenCV expects BGR channel order)
            rgb_bgr = cv2.cvtColor(query_image, cv2.COLOR_RGB2BGR)
            if not cv2.imwrite(str(rgb_path), rgb_bgr, [cv2.IMWRITE_JPEG_QUALITY, 95]):
                raise OSError(f"Could not write temporary RGB image: {rgb_path}")

            # Save depth as 16-bit PNG (convert back from meters to mm).
            # Round instead of truncating so float error (e.g. 1233.9999)
            # does not shift values by 1 mm, and clip so out-of-range or
            # non-finite values cannot wrap around in uint16.
            depth_mm = np.nan_to_num(np.asarray(depth_image, dtype=np.float64) * 1000.0)
            depth_mm = np.clip(np.rint(depth_mm), 0, np.iinfo(np.uint16).max)
            depth_uint16 = depth_mm.astype(np.uint16)
            if not cv2.imwrite(str(depth_path), depth_uint16):
                raise OSError(f"Could not write temporary depth image: {depth_path}")

            print(f"Calling API with RGB + Depth images...")

            # Call Gradio API directly to get all outputs (text, viz, mask)
            result = client.client.predict(
                object_id="t_shape",
                query_image=handle_file(str(rgb_path)),
                depth_image=handle_file(str(depth_path)),
                fx=camera_intrinsics["fx"],
                fy=camera_intrinsics["fy"],
                cx=camera_intrinsics["cx"],
                cy=camera_intrinsics["cy"],
                api_name="/gradio_estimate",
            )
        finally:
            # Clean up temp files even when writing or the API call failed.
            for temp_path in temp_paths:
                try:
                    temp_path.unlink()
                except OSError:
                    # Best-effort cleanup; a leftover temp file must not
                    # mask the real test outcome.
                    pass

        # Result should be a 3-item sequence: (text_output, viz_image, mask_image)
        if not isinstance(result, (tuple, list)) or len(result) != 3:
            print(f"✗ Unexpected result format: {type(result)}, length={len(result) if isinstance(result, (tuple, list)) else 'N/A'}")
            return False

        text_output, viz_image, mask_image = result
        print(f"\n✓ API returned 3 outputs as expected")
        print(f" - Text output: {len(text_output)} chars")
        # NOTE(review): gradio_client appears to return image outputs as file
        # references rather than arrays — confirm against the deployed
        # version. Fall back to printing the value itself when it has no
        # ``shape`` so the check does not die with AttributeError.
        print(f" - Viz image: {getattr(viz_image, 'shape', viz_image)}")
        print(f" - Mask image: {getattr(mask_image, 'shape', mask_image)}")

        # Verify mask was generated
        if mask_image is None:
            print(f"✗ Mask was not returned (expected auto-generated mask)")
            return False
        if hasattr(mask_image, "shape"):
            print(f"✓ Mask returned: shape={mask_image.shape}, dtype={getattr(mask_image, 'dtype', None)}")
        else:
            print(f"✓ Mask returned: {mask_image}")

        # Check text output for success/failure
        if "Error" in text_output or "✗" in text_output:
            print(f"✗ Estimation failed: {text_output[:200]}")
            return False

        # Check if poses were detected
        if "No poses detected" in text_output or "⚠" in text_output:
            print(f"⚠ No poses detected (API working, but no objects found)")
            print(f"Output: {text_output[:300]}")
            return False

        # Success - print the full output
        print(f"✓ Pose estimation succeeded!")
        print(f"\nEstimation output:")
        print(text_output)
        return True
    except Exception as e:
        print(f"✗ Pose estimation failed with exception: {e}")
        traceback.print_exc()
        return False
def main():
    """Run all tests in order and print a summary.

    The checks run sequentially (client creation, CAD initialization, pose
    estimation); a failure in the first two aborts the run.

    Returns:
        int: ``0`` when every check passed, ``1`` when required files are
        missing, a check aborted the run, or pose estimation failed. The
        value is used as the process exit status when run as a script.
    """
    banner = "=" * 60

    print("\n" + banner)
    print("FoundationPose CAD-Based API Test Suite")
    print(banner)

    # Setup paths
    test_dir = Path(__file__).parent
    reference_dir = test_dir / "reference" / "t_shape"
    mesh_path = reference_dir / "t_shape.obj"

    # Check if mesh file exists
    if not mesh_path.exists():
        print(f"✗ Mesh file not found: {mesh_path}")
        return 1

    # Check if reference images exist (for query testing)
    if not reference_dir.exists():
        print(f"✗ Reference directory not found: {reference_dir}")
        return 1

    print(f"\nUsing T-shape mesh: {mesh_path}")
    print(f"Using test data from: {reference_dir}")

    # Load test RGB and depth images
    try:
        rgb_image, depth_image = load_test_data(reference_dir)
    except FileNotFoundError as e:
        print(f"✗ {e}")
        return 1
    print(f"\n✓ Loaded RGB and depth images - testing with both")

    # Test 1: Initialize API client
    client = test_client_initialization()
    if client is None:
        print("\n" + banner)
        print("TESTS ABORTED: API client initialization failed")
        print(banner)
        return 1

    # Test 2: Initialize object with CAD model
    if not test_cad_initialization(client, mesh_path):
        print("\n" + banner)
        print("TESTS ABORTED: CAD initialization failed")
        print(banner)
        return 1

    # Test 3: Estimate pose using RGB + depth images
    success = test_pose_estimation(client, rgb_image, depth_image, "rgb_001.jpg")

    # Print final results
    print("\n" + banner)
    print("TEST SUMMARY")
    print(banner)
    print("✓ API client initialization: PASSED")
    print("✓ CAD-based object initialization: PASSED")
    if success:
        print("✓ Pose estimation with RGB+depth: PASSED")
        print("✓ Mask generation verification: PASSED")
        print("\n🎉 ALL TESTS PASSED")
    else:
        print("⚠ Pose estimation: Issues detected (see output above)")
        print("\n📊 API TESTS PARTIALLY PASSED (2/3 core functions verified)")
        print("\nPossible reasons for no detections:")
        print(" - Camera intrinsics mismatch")
        print(" - Object not visible or occluded in image")
        print(" - Depth data quality issues")
        print(" - Mask segmentation inaccurate")
    print(banner)

    # Non-zero status lets shells and CI detect a failed run.
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())