|
|
""" |
|
|
Simple test script for Modal Bird Classifier MCP Server |
|
|
Tests the deployed Modal server directly with both tools |
|
|
""" |
|
|
import asyncio |
|
|
import base64 |
|
|
import json |
|
|
import os |
|
|
from io import BytesIO |
|
|
from pathlib import Path |
|
|
from PIL import Image |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
from fastmcp import Client |
|
|
from fastmcp.client.transports import StreamableHttpTransport |
|
|
|
|
|
load_dotenv()

# Connection settings for the deployed server. load_dotenv() above lets them
# come from a local .env file as well as the process environment.
MODAL_MCP_URL = os.getenv("MODAL_MCP_URL")
BIRD_API_KEY = os.getenv("BIRD_CLASSIFIER_API_KEY")

print("=" * 70)
print("[STATUS]: Testing Classifier Modal MCP Server...")
print("=" * 70)
print(f"[MODAL URL]: {MODAL_MCP_URL}")
# Only report whether the key is present; never echo the secret itself.
print(f"[API KEY]: {'Set' if BIRD_API_KEY else 'Missing'}")
print("=" * 70)

# Every test needs both values, so stop before running anything without them.
if not MODAL_MCP_URL or not BIRD_API_KEY:
    print("\n[ERROR]: Missing MODAL_MCP_URL or BIRD_CLASSIFIER_API_KEY in .env")
    print(" Set these in your .env file:")
    print(" MODAL_MCP_URL=https://your-username--bird-classifier-mcp-web.modal.run/mcp/")
    print(" BIRD_CLASSIFIER_API_KEY=your-api-key")
    # exit() is injected by the `site` module for interactive use and is
    # unavailable under `python -S` or in frozen builds; SystemExit always works.
    raise SystemExit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def image_to_base64(image_path: str, max_size: int = 800, quality: int = 85) -> str:
    """Load an image, shrink and re-encode it as JPEG, and return it as base64.

    The image is scaled down (keeping its aspect ratio) so that its longest
    side is at most ``max_size`` pixels. It is then converted to RGB if needed
    and saved as an optimized JPEG before being base64-encoded.

    Args:
        image_path: Path to an image file readable by Pillow.
        max_size: Maximum length, in pixels, of the longest side. Images that
            are already small enough are not resized. Defaults to 800.
        quality: JPEG quality passed to Pillow's encoder. Defaults to 85.

    Returns:
        The JPEG bytes as an ASCII base64 string (no ``data:`` URI prefix).
    """
    # Use the context manager so the file handle opened by Image.open() is
    # released; all processing happens inside the block because the source
    # image may be lazily loaded from that handle.
    with Image.open(image_path) as source:
        img = source
        longest = max(img.size)
        if longest > max_size:
            ratio = max_size / longest
            # Clamp to at least 1px so extremely thin images don't round a
            # dimension down to 0, which resize() rejects.
            new_size = (
                max(1, int(img.size[0] * ratio)),
                max(1, int(img.size[1] * ratio)),
            )
            img = img.resize(new_size, Image.Resampling.LANCZOS)

        # JPEG output requires RGB, so convert any other mode first.
        if img.mode != "RGB":
            img = img.convert("RGB")

        buffered = BytesIO()
        img.save(buffered, format="JPEG", quality=quality, optimize=True)

    return base64.b64encode(buffered.getvalue()).decode("ascii")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_list_tools():
    """Test 1: connect with the configured API key and list the server's tools.

    Prints each tool's name and the first 60 characters of its description.

    Returns:
        True if the tool listing succeeded, False if any exception was raised.
    """
    print("\n" + "=" * 70)
    print("[TEST 1]: List Available Tools")
    print("=" * 70)

    try:
        transport = StreamableHttpTransport(
            url=MODAL_MCP_URL,
            headers={"X-API-Key": BIRD_API_KEY},
        )
        client = Client(transport)

        async with client:
            tools = await client.list_tools()

        print(f"\n[✅ FOUND]: {len(tools)} tools:")
        for tool in tools:
            print(f"  - {tool.name}")
            # A tool's description may be None, so normalize before slicing,
            # and only add an ellipsis when text was actually cut off.
            description = tool.description or ""
            snippet = description[:60] + ("..." if len(description) > 60 else "")
            print(f"    {snippet}")

        print("\n[✅ TEST 1 PASSED]")
        return True

    except Exception as e:
        print(f"\n[❌ TEST 1 FAILED]: {e}")
        return False
|
|
|
|
|
async def test_classify_from_url():
    """Test 2: ask the server to classify a bird photo from a public URL.

    Calls the ``classify_from_url`` tool and parses the first content item as
    JSON. The payload is treated as a failure if it contains an ``error`` key;
    otherwise its ``species``, ``confidence`` and ``source`` fields are printed.

    Returns:
        True on a successful classification, False otherwise.
    """
    print("\n" + "=" * 70)
    print("[TEST 2]: Classify Bird from URL")
    print("=" * 70)

    try:
        test_url = "https://images.unsplash.com/photo-1551085254-e96b210db58a?q=80&w=680&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
        print(f"\nURL: {test_url[:60]}...")

        transport = StreamableHttpTransport(
            url=MODAL_MCP_URL,
            headers={"X-API-Key": BIRD_API_KEY},
        )
        client = Client(transport)

        async with client:
            result = await client.call_tool(
                "classify_from_url",
                arguments={"image_url": test_url},
            )

        if not result.content:
            print("❌ No response from server")
            return False

        first = result.content[0]
        # Text content carries the JSON payload; fall back to str() otherwise.
        result_text = first.text if hasattr(first, "text") else str(first)
        data = json.loads(result_text)

        if "error" in data:
            # f-prefix is required so the server's error text is interpolated.
            print(f"❌ [ERROR]: Server error: {data['error']}")
            return False

        confidence = data.get("confidence")
        # Format as a percentage only when numeric, so a missing value is
        # printed instead of raising TypeError from the format spec.
        confidence_text = (
            f"{confidence:.1%}" if isinstance(confidence, (int, float)) else str(confidence)
        )

        print("\n✅ [CLASSIFICATION RESULT]:")
        print(f"  Species: {data.get('species')}")
        print(f"  Confidence: {confidence_text}")
        print(f"  Source: {data.get('source')}")

        print("\n✅ TEST 2 PASSED")
        return True

    except Exception as e:
        print(f"\n❌ TEST 2 FAILED: {e}")
        import traceback

        # Print the full stack trace so failures are debuggable.
        traceback.print_exc()
        return False
|
|
|
|
|
async def test_classify_from_base64_with_local_file():
    """Test 3: classify a local image sent to the server as base64.

    The image path is taken from the ``BIRD_TEST_IMAGE`` environment variable
    when set, otherwise from a hard-coded local example path. The image is
    encoded with ``image_to_base64`` and sent to the ``classify_from_base64``
    tool.

    Returns:
        True on a successful classification, or when the test image file is
        not found (the test is skipped but still counts as passed); False on a
        server error or any exception.
    """
    print("\n" + "=" * 70)
    print("[TEST 3]: Classify Bird from Local File (base64)")
    print("=" * 70)

    default_image = "/Users/jacobbinder/Desktop/hackathon/hackathon_draft/examples/another_bird.jpg"

    try:
        # `or` (rather than a getenv default) also covers an empty variable,
        # which would otherwise become Path("") -> the current directory.
        test_image = Path(os.getenv("BIRD_TEST_IMAGE") or default_image)

        # is_file() rather than exists(): a directory path must also skip.
        if not test_image.is_file():
            print(f"\n[ERROR]: Test image not found: {test_image}")
            print(" Skipping test 3...")
            print("\n[TEST 3 SKIPPED]")
            return True

        print(f"\nFile: {test_image.name}")

        print("[STATUS]: Converting image to base64...")
        img_base64 = image_to_base64(str(test_image))
        print(f"Base64 size: {len(img_base64) / 1024:.1f} KB")

        transport = StreamableHttpTransport(
            url=MODAL_MCP_URL,
            headers={"X-API-Key": BIRD_API_KEY},
        )
        client = Client(transport)

        async with client:
            print("[STATUS]: Sending to Modal server...")
            result = await client.call_tool(
                "classify_from_base64",
                arguments={"image_data": img_base64},
            )

        if not result.content:
            print("❌ [ERROR]: No response from server")
            return False

        first = result.content[0]
        # Text content carries the JSON payload; fall back to str() otherwise.
        result_text = first.text if hasattr(first, "text") else str(first)
        data = json.loads(result_text)

        if "error" in data:
            print(f"[SERVER ERROR]: {data['error']}")
            return False

        confidence = data.get("confidence")
        # Format as a percentage only when numeric, so a missing value is
        # printed instead of raising TypeError from the format spec.
        confidence_text = (
            f"{confidence:.1%}" if isinstance(confidence, (int, float)) else str(confidence)
        )

        print("\n[CLASSIFICATION RESULT]:")
        print(f"  Species: {data.get('species')}")
        print(f"  Confidence: {confidence_text}")
        print(f"  Source: {data.get('source')}")

        print("\n✅ [TEST 3 PASSED]")
        return True

    except Exception as e:
        print(f"\n❌ [TEST 3 FAILED]: {e}")
        import traceback

        traceback.print_exc()
        return False
|
|
|
|
|
async def test_auth_failure():
    """Test 4: verify that API key authentication works.

    Sends a request with a deliberately invalid ``X-API-Key`` header and
    expects it to be rejected.

    Returns:
        True if listing tools raises an error whose text mentions "401",
        "Unauthorized" or "Invalid"; False if the call succeeds or fails with
        some other error.
    """
    print("\n" + "=" * 70)
    print("[TEST 4]: API Key Authentication")
    print("=" * 70)

    try:
        print("\nTesting with INVALID API key...")

        transport = StreamableHttpTransport(
            url=MODAL_MCP_URL,
            headers={"X-API-Key": "invalid-key-123"},
        )
        client = Client(transport)

        async with client:
            # The result is irrelevant; getting past this call means the bad
            # key was accepted.
            await client.list_tools()

        print("❌ Should have failed with 401")
        return False

    except Exception as e:
        message = str(e)
        # Heuristic match on the error text rather than a specific exception
        # class, since the exception type surfaced here isn't pinned down.
        if any(marker in message for marker in ("401", "Unauthorized", "Invalid")):
            print(f"✅ Correctly rejected invalid API key: {message[:60]}...")
            print("\n✅ [TEST 4 PASSED]")
            return True

        print(f"❌ [ERROR]: {e}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def main():
    """Run every test in order and print a pass/fail summary.

    Returns:
        True if all tests passed, False otherwise.
    """
    print("\n")

    # Tests run one after another; each builds its own client and prints its
    # own section, so their output stays readable.
    tests = [
        ("[TEST 1]: List Tools", test_list_tools),
        ("[TEST 2]: Classify from URL", test_classify_from_url),
        ("[TEST 3]: Classify from Base64", test_classify_from_base64_with_local_file),
        ("[TEST 4]: API Key Auth", test_auth_failure),
    ]

    results = []
    for test_name, test_fn in tests:
        results.append((test_name, await test_fn()))

    print("\n" + "=" * 70)
    print("[TEST SUMMARY]")
    print("=" * 70)

    passed = sum(1 for _, ok in results if ok)
    total = len(results)

    for test_name, ok in results:
        status = "✅ [PASS]" if ok else "❌ [FAIL]"
        print(f"{status}: {test_name}")

    print(f"\n[TOTAL]: {passed}/{total} tests passed")

    if passed == total:
        print("\n🎉 All tests passed! Modal server is working correctly!")
    else:
        print(f"\n⚠️ {total - passed} test(s) failed. Check configuration and server status.")

    print("=" * 70 + "\n")
    return passed == total


if __name__ == "__main__":
    # Exit non-zero when any test fails so shells and CI can detect failures.
    raise SystemExit(0 if asyncio.run(main()) else 1)
|
|
|