File size: 2,655 Bytes
9c73b7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python3
"""

Test script for Plant Disease Detection API

"""

import requests
import time
from pathlib import Path

def _check_get_endpoint(url: str, label: str, timeout: float) -> bool:
    """GET *url*, print the JSON payload or the failure, and report success.

    Args:
        url: Full URL to request.
        label: Human-readable endpoint name used in the printed messages.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True when the server answered 200 with a JSON-decodable body,
        False on any other status code or on any raised exception.
    """
    try:
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            print(f"❌ {label} failed: {response.status_code}")
            return False
        data = response.json()
        print(f"✅ {label}: {data}")
        return True
    except Exception as e:
        # Deliberately broad: this is a smoke-test script, so any failure
        # (connection refused, timeout, invalid JSON, ...) is reported as a
        # failed check instead of crashing with a traceback.
        print(f"❌ {label} error: {e}")
        return False


def _format_confidence(value) -> str:
    """Render *value* as a percentage with two decimals, or 'N/A'."""
    # bool is a subclass of int; a True/False "confidence" is not meaningful.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return "N/A"
    return f"{value:.2%}"


def test_api(base_url: str = "http://localhost:7860", timeout: float = 10.0) -> bool:
    """Smoke-test the health, root and prediction endpoints of the API.

    Progress and results are printed to stdout.

    Args:
        base_url: Scheme, host and port of the running API (no trailing slash).
        timeout: Seconds to wait for the lightweight GET endpoints. Without a
            timeout ``requests`` waits indefinitely on an unresponsive server.
            The prediction request keeps its own, longer 60 second timeout.

    Returns:
        True when the health check and the prediction request both succeed.
        False as soon as either of them fails. A failing root endpoint is
        reported but does not affect the return value.

    Raises:
        ImportError: If Pillow is not installed (it is needed to build the
            sample image for the prediction request).
    """
    print(f"🧪 Testing API at {base_url}")

    # Test 1: Health check (fatal: nothing else is attempted if it fails)
    print("\n1. Testing health endpoint...")
    if not _check_get_endpoint(f"{base_url}/health", "Health check", timeout):
        return False

    # Test 2: Root endpoint (non-fatal: failure is printed, testing continues)
    print("\n2. Testing root endpoint...")
    _check_get_endpoint(f"{base_url}/", "Root endpoint", timeout)

    # Test 3: Prediction with sample image
    print("\n3. Testing prediction endpoint...")

    # Imported here so the GET checks above can run without Pillow.
    from PIL import Image
    import io

    # Build a 224x224 solid green (plant-like) JPEG entirely in memory.
    test_image = Image.new('RGB', (224, 224), color=(50, 150, 50))
    img_buffer = io.BytesIO()
    test_image.save(img_buffer, format='JPEG')
    img_bytes = img_buffer.getvalue()

    try:
        files = {"file": ("test_plant.jpg", img_bytes, "image/jpeg")}
        response = requests.post(f"{base_url}/predict", files=files, timeout=60)

        if response.status_code != 200:
            print(f"❌ Prediction failed: {response.status_code}")
            print(f"   Response: {response.text}")
            return False

        result = response.json()
        print("✅ Prediction successful!")
        print(f"   Disease: {result.get('predicted_disease', 'N/A')}")
        # NOTE(review): the response key is assumed to be 'confidence' --
        # confirm against the /predict response schema. A missing or
        # non-numeric value is shown as 'N/A'.
        print(f"   Confidence: {_format_confidence(result.get('confidence'))}")
        print(f"   Unknown: {result.get('is_unknown', 'N/A')}")
        print(f"   Top neighbors: {len(result.get('topk_neighbors', []))}")

    except Exception as e:
        # Broad on purpose (see _check_get_endpoint): report, don't crash.
        print(f"❌ Prediction error: {e}")
        return False

    print("\n🎉 All tests passed!")
    return True

if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides the default address.
    cli_args = sys.argv[1:]
    target_url = cli_args[0] if cli_args else "http://localhost:7860"

    # Map the boolean test outcome onto a process exit status:
    # 0 when every required check passed, 1 otherwise.
    passed = test_api(target_url)
    exit_code = 0 if passed else 1
    sys.exit(exit_code)