BirdScopeAI / tests /test_modal_direct.py
facemelter's picture
Initial commit to hf space for hackathon
ff0e97f verified
"""
Simple test script for Modal Bird Classifier MCP Server
Tests the deployed Modal server directly: lists its tools, calls both classification tools (URL and base64), and checks that an invalid API key is rejected
"""
import asyncio
import base64
import json
import os
from io import BytesIO
from pathlib import Path
from PIL import Image
from dotenv import load_dotenv
from fastmcp import Client
from fastmcp.client.transports import StreamableHttpTransport
load_dotenv()
# ============================================================================
# CONFIGURATION
# ============================================================================
# Both settings are read from the environment (load_dotenv() above loads
# variables from a .env file, if present); every test below needs them.
MODAL_MCP_URL = os.getenv("MODAL_MCP_URL")
BIRD_API_KEY = os.getenv("BIRD_CLASSIFIER_API_KEY")
print("="*70)
print("[STATUS]: Testing Classifier Modal MCP Server...")
print("="*70)
print(f"[MODAL URL]: {MODAL_MCP_URL}")
# Report only whether the key is present; never echo the secret itself.
print(f"[API KEY]: {'Set' if BIRD_API_KEY else 'Missing'}")
print("="*70)
if not MODAL_MCP_URL or not BIRD_API_KEY:
    print("\n[ERROR]: Missing MODAL_MCP_URL or BIRD_CLASSIFIER_API_KEY in .env")
    print(" Set these in your .env file:")
    print(" MODAL_MCP_URL=https://your-username--bird-classifier-mcp-web.modal.run/mcp/")
    print(" BIRD_CLASSIFIER_API_KEY=your-api-key")
    # The bare ``exit`` builtin is injected by the ``site`` module for
    # interactive use and is missing under ``python -S`` or embedded
    # interpreters; raising SystemExit directly is the portable equivalent.
    raise SystemExit(1)
# ============================================================================
# HELPER FUNCTIONS
# ===========================================================================
def image_to_base64(image_path: str, *, max_size: int = 800, quality: int = 85) -> str:
    """Load an image, shrink it, re-encode it as JPEG and return it as base64.

    The image is converted to RGB and then downscaled, preserving the aspect
    ratio, so that its longest side is at most ``max_size`` pixels. It is then
    JPEG-compressed to keep the request payload small.

    Args:
        image_path: Path of the image file to load (anything ``Image.open``
            accepts, e.g. a ``str`` or ``pathlib.Path``).
        max_size: Maximum length in pixels of the longest side. Defaults to 800.
        quality: JPEG quality passed to Pillow's encoder. Defaults to 85.

    Returns:
        The base64-encoded JPEG bytes as an ASCII string.
    """
    # Context manager closes the underlying file handle deterministically
    # instead of leaving it open until the Image is garbage-collected.
    with Image.open(image_path) as source:
        image = source
        # Convert before resizing: Pillow silently downgrades the resampling
        # filter to NEAREST for palette ("P") and bilevel ("1") images, so
        # resizing first would ignore LANCZOS and yield a blocky thumbnail.
        if image.mode != "RGB":
            image = image.convert("RGB")
        longest = max(image.size)
        if longest > max_size:
            ratio = max_size / longest
            # Clamp each side to at least 1px so very thin images never
            # round down to a zero-sized (invalid) target.
            new_size = tuple(max(1, int(side * ratio)) for side in image.size)
            image = image.resize(new_size, Image.Resampling.LANCZOS)
        buffered = BytesIO()
        # Encode while the source file is still open: when no conversion or
        # resize happened, ``image`` is the lazily-loaded source itself.
        image.save(buffered, format="JPEG", quality=quality, optimize=True)
    return base64.b64encode(buffered.getvalue()).decode("ascii")
# ============================================================================
# TEST FUNCTIONS
# ============================================================================
async def test_list_tools():
    """Test 1: connect with the configured API key and list the server's tools.

    Prints each tool's name and (truncated) description.

    Returns:
        True if the tool listing succeeded, False if any exception was raised.
    """
    print("\n"+"="*70)
    print("[TEST 1]: List Available Tools")
    print("="*70)
    try:
        transport = StreamableHttpTransport(
            url=MODAL_MCP_URL,
            headers={"X-API-Key": BIRD_API_KEY},
        )
        client = Client(transport)
        async with client:
            tools = await client.list_tools()
        print(f"\n[✅ FOUND]: {len(tools)} tools:")
        for tool in tools:
            print(f" - {tool.name}")
            # MCP tool descriptions are optional; slicing None would raise
            # TypeError and wrongly fail the whole test.
            description = tool.description or ""
            suffix = "..." if len(description) > 60 else ""
            print(f". {description[:60]}{suffix}")
        print("\n[✅ TEST 1 PASSED]")
        return True
    except Exception as e:
        print(f"\n[❌TEST 1 FAILED]: {e}")
        return False
async def test_classify_from_url():
    """Test 2: ask the server to classify a bird photo from a public URL.

    Calls the ``classify_from_url`` tool and prints the ``species``,
    ``confidence`` and ``source`` fields of the JSON result.

    Returns:
        True if the server returned a JSON object without an ``"error"`` key,
        False otherwise (including when any exception is raised).
    """
    print("\n"+"="*70)
    print("[TEST 2]: Classify Bird from URL")
    print("="*70)
    try:
        test_url = "https://images.unsplash.com/photo-1551085254-e96b210db58a?q=80&w=680&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
        print(f"\nURL: {test_url[:60]}...")
        transport = StreamableHttpTransport(
            url=MODAL_MCP_URL,
            headers={"X-API-Key": BIRD_API_KEY},
        )
        client = Client(transport)
        async with client:
            result = await client.call_tool(
                "classify_from_url",
                arguments={"image_url": test_url},
            )
        if not result.content:
            print("❌ No response from server")
            return False
        # Prefer the text payload; fall back to the item's string form for
        # non-text content items.
        first = result.content[0]
        text = getattr(first, "text", None)
        data = json.loads(text if text is not None else str(first))
        if not isinstance(data, dict):
            print(f"❌ [ERROR]: Unexpected response: {data!r}")
            return False
        if "error" in data:
            print(f"❌ [ERROR]: Server error: {data['error']}")
            return False
        confidence = data.get("confidence")
        # Only numbers support the percent format spec; a missing/None value
        # would otherwise raise and be misreported as a test failure.
        confidence_text = (
            f"{confidence:.1%}" if isinstance(confidence, (int, float)) else str(confidence)
        )
        print("\n✅ [CLASSIFICATION RESULT]:")
        print(f" Species: {data.get('species')}")
        print(f" Confidence: {confidence_text}")
        print(f" Source: {data.get('source')}")
        print("\n✅ TEST 2 PASSED")
        return True
    except Exception as e:
        print(f"\n❌ TEST 2 FAILED: {e}")
        import traceback
        traceback.print_exc()
        return False
async def test_classify_from_base64_with_local_file():
    """Test 3: classify a local image sent to the server as base64 data.

    The image path comes from the ``BIRD_TEST_IMAGE`` environment variable
    when set, otherwise ``examples/another_bird.jpg`` relative to the
    repository root. If the file does not exist, the test is skipped and
    reported as passing.

    Returns:
        True if the server returned a JSON object without an ``"error"`` key,
        or if the test was skipped; False otherwise (including when any
        exception is raised).
    """
    print("\n"+"="*70)
    print("[TEST 3]: Classify Bird from Local File (base64)")
    print("="*70)
    try:
        # This file lives in tests/, so parent.parent is the repository root.
        # Assumes the sample image is checked in under examples/; point
        # BIRD_TEST_IMAGE at any other bird photo to override.
        override = os.getenv("BIRD_TEST_IMAGE")
        test_image = (
            Path(override)
            if override
            else Path(__file__).resolve().parent.parent / "examples" / "another_bird.jpg"
        )
        if not test_image.exists():
            print(f"\n[ERROR]: Test image not found: {test_image}")
            print(" Set BIRD_TEST_IMAGE to the path of a bird photo to run this test.")
            print(" Skipping test 3...")
            print("\n[TEST 3 SKIPPED]")
            # A missing sample image is treated as a skip, not a failure.
            return True
        print(f"\nFile: {test_image.name}")
        print("[STATUS]: Converting image to base64...")
        img_base64 = image_to_base64(str(test_image))
        print(f"Base64 size: {len(img_base64) / 1024:.1f} KB")
        transport = StreamableHttpTransport(
            url=MODAL_MCP_URL,
            headers={"X-API-Key": BIRD_API_KEY},
        )
        client = Client(transport)
        async with client:
            print("[STATUS]:Sending to Modal server...")
            result = await client.call_tool(
                "classify_from_base64",
                arguments={"image_data": img_base64},
            )
        if not result.content:
            print("❌ [ERROR]: No response from server")
            return False
        # Prefer the text payload; fall back to the item's string form for
        # non-text content items.
        first = result.content[0]
        text = getattr(first, "text", None)
        data = json.loads(text if text is not None else str(first))
        if not isinstance(data, dict):
            print(f"❌ [ERROR]: Unexpected response: {data!r}")
            return False
        if "error" in data:
            print(f"[SERVER ERROR]: {data['error']}")
            return False
        confidence = data.get("confidence")
        # Only numbers support the percent format spec; a missing/None value
        # would otherwise raise and be misreported as a test failure.
        confidence_text = (
            f"{confidence:.1%}" if isinstance(confidence, (int, float)) else str(confidence)
        )
        print("\n[CLASSIFICATION RESULT]:")
        print(f" Species: {data.get('species')}")
        print(f" Confidence: {confidence_text}")
        print(f" Source: {data.get('source')}")
        print("\n✅ [TEST 3 PASSED]")
        return True
    except Exception as e:
        print(f"\n❌ [TEST 3 FAILED]: {e}")
        import traceback
        traceback.print_exc()
        return False
def is_auth_rejection(exc: BaseException) -> bool:
    """Return True if *exc*, or any exception it wraps, signals an auth rejection.

    An exception counts as a rejection when it carries an HTTP response with
    status 401, or its message contains "401", "Unauthorized" or "Invalid".
    Chained causes/contexts and the members of exception groups (anything
    exposing an ``exceptions`` sequence) are searched as well, because client
    libraries often wrap the underlying HTTP error.
    """
    seen = set()
    pending = [exc]
    while pending:
        err = pending.pop()
        if err is None or id(err) in seen:
            continue
        seen.add(id(err))
        # Structured check first (e.g. httpx.HTTPStatusError exposes .response).
        status = getattr(getattr(err, "response", None), "status_code", None)
        message = str(err)
        if status == 401 or "401" in message or "Unauthorized" in message or "Invalid" in message:
            return True
        pending.extend(getattr(err, "exceptions", None) or ())
        pending.append(getattr(err, "__cause__", None))
        pending.append(getattr(err, "__context__", None))
    return False


async def test_auth_failure():
    """Test 4: verify the server rejects requests that carry an invalid API key.

    Returns:
        True if listing tools with a bogus key raised an authentication error
        (see ``is_auth_rejection``); False if the call succeeded or failed for
        some other reason.
    """
    print("\n"+"="*70)
    print("[TEST 4]: API Key Authentication")
    print("="*70)
    try:
        print("\nTesting with INVALID API key...")
        transport = StreamableHttpTransport(
            url=MODAL_MCP_URL,
            headers={"X-API-Key": "invalid-key-123"},
        )
        client = Client(transport)
        async with client:
            # Must fail with 401; the result itself is irrelevant.
            await client.list_tools()
        print("❌ Should have failed with 401")
        return False
    except Exception as e:
        if is_auth_rejection(e):
            print(f"✅ Correctly rejected invalid API key: {str(e)[:60]}...")
            print("\n✅ [TEST 4 PASSED]")
            return True
        print(f"❌ [ERROR]: {e}")
        return False
# ============================================================================
# MAIN
# ============================================================================
async def main():
    """Run every test in order, print a pass/fail summary, and report success.

    Returns:
        True if every test passed, False otherwise.
    """
    print("\n")
    tests = [
        ("[TEST 1]: List Tools", test_list_tools),
        ("[TEST 2]: Classify from URL", test_classify_from_url),
        ("[TEST 3]: Classify from Base64", test_classify_from_base64_with_local_file),
        ("[TEST 4]: API Key Auth", test_auth_failure),
    ]
    # Awaited one at a time (not gathered) so each test's output stays grouped.
    results = [(name, await run_test()) for name, run_test in tests]
    # Summary
    print("\n"+"="*70)
    print("[TEST SUMMARY]")
    print("="*70)
    passed = sum(1 for _, result in results if result)
    total = len(results)
    for test_name, result in results:
        status = "✅ [PASS]" if result else "❌ [FAIL]"
        print(f"{status}: {test_name}")
    print(f"\n[TOTAL]: {passed}/{total} tests passed")
    if passed == total:
        print("\n🎉 All tests passed! Modal server is working correctly!")
    else:
        print(f"\n⚠️ {total - passed} test(s) failed. Check configuration and server status.")
    print("="*70+"\n")
    return passed == total


if __name__ == "__main__":
    # Exit non-zero when any test fails so shell scripts and CI can detect a
    # broken deployment.
    raise SystemExit(0 if asyncio.run(main()) else 1)