# File size: 4,999 Bytes
# c8df794
"""
Test script for all API endpoints
"""
import requests
import json
import os
from pathlib import Path

# API base URL. Defaults to a local server; set the API_BASE_URL environment
# variable to target another host. A trailing slash is stripped because the
# endpoint paths below are appended as "/<name>".
BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000").rstrip("/")

def test_health_endpoint():
    """Check that GET /health responds with HTTP 200.

    Prints the status code and the response body. The body is shown as parsed
    JSON on success and as raw text otherwise.

    Returns:
        bool: True if the endpoint returned 200 with a JSON body; False for any
        other status, a network error, or a body that is not valid JSON.
    """
    print("🔍 Testing /health endpoint...")
    try:
        # Without a timeout, requests waits indefinitely on an unresponsive server.
        response = requests.get(f"{BASE_URL}/health", timeout=10)
        print(f"Status Code: {response.status_code}")
        if response.status_code != 200:
            # An error body may not be JSON (e.g. an HTML page), so show it verbatim.
            print(f"Response: {response.text}")
            return False
        print(f"Response: {response.json()}")
        return True
    except Exception as e:
        # Best-effort smoke test: report the failure and let main() tally it.
        print(f"❌ Error: {e}")
        return False

def test_classes_endpoint():
    """Check that GET /classes responds with HTTP 200 and print a preview.

    On success, prints how many entries are under the ``classes`` key and the
    first three of them.

    Returns:
        bool: True if the endpoint returned 200 with a JSON body that could be
        summarised; False for any other status, a network error, or an
        unexpected body.
    """
    print("\n🔍 Testing /classes endpoint...")
    try:
        # Without a timeout, requests waits indefinitely on an unresponsive server.
        response = requests.get(f"{BASE_URL}/classes", timeout=10)
        print(f"Status Code: {response.status_code}")
        if response.status_code != 200:
            # An error body may not be JSON, so show it verbatim.
            print(f"Response: {response.text}")
            return False
        data = response.json()
        classes = data.get('classes', [])
        print(f"Number of classes: {len(classes)}")
        print(f"Classes: {classes[:3]}...")  # Show first 3
        return True
    except Exception as e:
        # Best-effort smoke test: report the failure and let main() tally it.
        print(f"❌ Error: {e}")
        return False

def test_model_info_endpoint():
    """Check that GET /model_info responds with HTTP 200.

    On success, prints the top-level keys of the JSON response.

    Returns:
        bool: True if the endpoint returned 200 with a JSON object body; False
        for any other status, a network error, or an unexpected body.
    """
    print("\n🔍 Testing /model_info endpoint...")
    try:
        # Without a timeout, requests waits indefinitely on an unresponsive server.
        response = requests.get(f"{BASE_URL}/model_info", timeout=10)
        print(f"Status Code: {response.status_code}")
        if response.status_code != 200:
            # An error body may not be JSON, so show it verbatim.
            print(f"Response: {response.text}")
            return False
        data = response.json()
        print(f"Model info keys: {list(data.keys())}")
        return True
    except Exception as e:
        # Best-effort smoke test: report the failure and let main() tally it.
        print(f"❌ Error: {e}")
        return False

def test_predict_endpoint():
    """Upload the sample image to POST /predict and print the prediction.

    The image is read from ``test_leaf_sample.jpg`` in the current working
    directory. It is sent as the multipart field ``file``.

    Returns:
        bool: True if the server answered 200 with a JSON body; False if the
        image is missing, the status is not 200, or any error occurs.
    """
    print("\n🔍 Testing /predict endpoint...")

    image_path = "test_leaf_sample.jpg"
    if not os.path.exists(image_path):
        print(f"❌ Test image not found: {image_path}")
        return False

    try:
        with open(image_path, 'rb') as f:
            files = {'file': ('test_image.jpg', f, 'image/jpeg')}
            # Inference may take a while, so allow a generous timeout rather
            # than none at all (which can hang forever).
            response = requests.post(
                f"{BASE_URL}/predict", files=files, timeout=120
            )

        print(f"Status Code: {response.status_code}")
        if response.status_code != 200:
            print(f"Response: {response.text}")
            return False

        data = response.json()
        print(f"Prediction: {data.get('predicted_class', 'N/A')}")
        print(f"Confidence: {data.get('confidence', 'N/A')}")
        print(f"Risk Level: {data.get('risk_level', 'N/A')}")
        print(f"Has explanation: {'explanation' in data}")
        return True
    except Exception as e:
        # Best-effort smoke test: report the failure and let main() tally it.
        print(f"❌ Error: {e}")
        return False

def test_batch_predict_endpoint():
    """Upload the sample image to POST /batch_predict as a one-item batch.

    The image is read from ``test_leaf_sample.jpg`` in the current working
    directory. It is sent under the multipart field name ``files``.

    Returns:
        bool: True if the server answered 200 with a JSON body; False if the
        image is missing, the status is not 200, or any error occurs.
    """
    print("\n🔍 Testing /batch_predict endpoint...")

    image_path = "test_leaf_sample.jpg"
    if not os.path.exists(image_path):
        print(f"❌ Test image not found: {image_path}")
        return False

    try:
        # Test with single image (simulating batch with one image)
        with open(image_path, 'rb') as f:
            files = {'files': ('test_image.jpg', f, 'image/jpeg')}
            # Inference may take a while, so allow a generous timeout rather
            # than none at all (which can hang forever).
            response = requests.post(
                f"{BASE_URL}/batch_predict", files=files, timeout=120
            )

        print(f"Status Code: {response.status_code}")
        if response.status_code != 200:
            print(f"Response: {response.text}")
            return False

        data = response.json()
        results = data.get('results', [])
        print(f"Number of results: {len(results)}")
        if results:
            first_prediction = results[0].get('predicted_class', 'N/A')
            print(f"First result prediction: {first_prediction}")
        return True
    except Exception as e:
        # Best-effort smoke test: report the failure and let main() tally it.
        print(f"❌ Error: {e}")
        return False

def main():
    """Run every API smoke test in order and print a pass/fail summary.

    Each test function prints its own diagnostics and returns a bool.

    Returns:
        bool: True only if every test reported success.
    """
    print("🚀 Starting API Tests...")
    print("=" * 50)

    tests = [
        ("Health Check", test_health_endpoint),
        ("Classes Endpoint", test_classes_endpoint),
        ("Model Info", test_model_info_endpoint),
        ("Predict Endpoint", test_predict_endpoint),
        ("Batch Predict", test_batch_predict_endpoint),
    ]

    results = {}
    for test_name, test_func in tests:
        results[test_name] = test_func()
        print()

    print("=" * 50)
    print("📊 Test Results Summary:")
    print("=" * 50)

    for test_name, result in results.items():
        status = "✅ PASSED" if result else "❌ FAILED"
        print(f"{test_name}: {status}")

    passed = sum(1 for result in results.values() if result)
    total = len(tests)
    print(f"\nOverall: {passed}/{total} tests passed")

    all_passed = passed == total
    if all_passed:
        print("🎉 All API tests passed!")
    else:
        print("⚠️  Some API tests failed. Check the output above for details.")

    return all_passed

if __name__ == "__main__":
    # Turn main()'s result into the process exit status (0 = all passed,
    # 1 = at least one failed) so shells and CI can detect failures.
    raise SystemExit(0 if main() else 1)