File size: 5,755 Bytes
bcc2f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#!/usr/bin/env python3
"""
Simple test script for the Marine Species Identification API.
This script can be used to quickly test the API functionality.
"""

import requests
import base64
import json
import time
from PIL import Image
import numpy as np
import io


def create_test_image(width: int = 640, height: int = 480, *, seed: "int | None" = None) -> str:
    """Create a synthetic RGB test image and return it as a base64-encoded JPEG.

    The image is uniform random noise with two solid squares painted on it:
    red over rows/columns 100-199 and green over rows/columns 300-399.
    On a canvas too small to hold a square, NumPy slicing clips it instead of
    raising.

    Args:
        width: Image width in pixels; must be positive.
        height: Image height in pixels; must be positive.
        seed: Optional seed for the noise generator. Pass an int to get the
            same image on every call (useful for reproducible API runs).
            ``None`` (the default) draws fresh noise each time.

    Returns:
        The JPEG-encoded image (quality 85) as an ASCII base64 string.

    Raises:
        ValueError: If ``width`` or ``height`` is not positive.
    """
    if width <= 0 or height <= 0:
        raise ValueError(f"width and height must be positive, got {width}x{height}")

    rng = np.random.default_rng(seed)
    # ``high`` is exclusive: 256 is required for the noise to cover the full
    # 0-255 uint8 range (a bound of 255 would never produce the value 255).
    image = rng.integers(0, 256, size=(height, width, 3), dtype=np.uint8)

    # Solid-colour patches so the picture is not pure noise.
    image[100:200, 100:200] = [255, 0, 0]  # Red square
    image[300:400, 300:400] = [0, 255, 0]  # Green square

    pil_image = Image.fromarray(image)

    # Encode as JPEG in memory, then base64 for transport inside JSON.
    buffer = io.BytesIO()
    pil_image.save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def test_api(
    base_url: str = "http://localhost:7860",
    *,
    timeout: float = 10.0,
    detect_timeout: float = 30.0,
) -> None:
    """Smoke-test the API endpoints and print a human-readable report.

    Runs five checks in order: root, health, info, species list and
    detection on a synthetic image. If the root endpoint cannot be reached
    (the request or its JSON decoding raises), the run stops there.
    Failures in later checks are printed and the remaining checks still run.

    Args:
        base_url: Scheme, host and port of the API; a trailing slash is ignored.
        timeout: Seconds to wait for each of the root/health/info/species
            requests. Without it a stalled server would hang the script forever.
        detect_timeout: Seconds to wait for the detection request.
    """
    # Avoid producing "//api/v1/..." paths when the caller adds a trailing slash.
    base_url = base_url.rstrip("/")

    print(f"🧪 Testing Marine Species Identification API at {base_url}")
    print("=" * 60)

    # The root endpoint doubles as a connectivity probe: if it fails, the
    # remaining checks would just repeat the same error.
    if not _check_root(base_url, timeout):
        return

    _check_health(base_url, timeout)
    _check_info(base_url, timeout)
    _check_species(base_url, timeout)
    _check_detection(base_url, detect_timeout)

    print("🎉 API testing completed!")
    print("=" * 60)


def _check_root(base_url: str, timeout: float) -> bool:
    """Check 1: GET ``/``; return False if the request or JSON decoding raised."""
    print("1. Testing root endpoint...")
    try:
        response = requests.get(f"{base_url}/", timeout=timeout)
        print(f"   Status: {response.status_code}")
        if response.status_code == 200:
            print(f"   Response: {response.json()}")
        print()
    except Exception as e:  # CLI boundary: report any failure instead of crashing
        print(f"   Error: {e}")
        return False
    return True


def _check_health(base_url: str, timeout: float) -> None:
    """Check 2: GET ``/api/v1/health`` and print its status fields."""
    print("2. Testing health check...")
    try:
        response = requests.get(f"{base_url}/api/v1/health", timeout=timeout)
        print(f"   Status: {response.status_code}")
        if response.status_code == 200:
            health_data = response.json()
            print(f"   API Status: {health_data.get('status')}")
            print(f"   Model Loaded: {health_data.get('model_loaded')}")
        print()
    except Exception as e:  # CLI boundary: report and continue with the next check
        print(f"   Error: {e}")
        print()


def _check_info(base_url: str, timeout: float) -> None:
    """Check 3: GET ``/api/v1/info`` and print name, version and class count."""
    print("3. Testing API info...")
    try:
        response = requests.get(f"{base_url}/api/v1/info", timeout=timeout)
        print(f"   Status: {response.status_code}")
        if response.status_code == 200:
            info_data = response.json()
            print(f"   API Name: {info_data.get('name')}")
            print(f"   Version: {info_data.get('version')}")
            # ``or {}`` also covers an explicit JSON null, not just a missing key.
            model_info = info_data.get('model_info') or {}
            print(f"   Model Classes: {model_info.get('total_classes')}")
        print()
    except Exception as e:  # CLI boundary: report and continue with the next check
        print(f"   Error: {e}")
        print()


def _check_species(base_url: str, timeout: float) -> None:
    """Check 4: GET ``/api/v1/species`` and print the count plus the first three entries."""
    print("4. Testing species list...")
    try:
        response = requests.get(f"{base_url}/api/v1/species", timeout=timeout)
        print(f"   Status: {response.status_code}")
        if response.status_code == 200:
            species_data = response.json()
            print(f"   Total Species: {species_data.get('total_count', 0)}")
            # Drive the listing off the actual list so a null/odd count cannot break it.
            species_list = species_data.get('species') or []
            if species_list:
                print("   First 3 species:")
                for species in species_list[:3]:
                    print(f"     - {species.get('class_name')} (ID: {species.get('class_id')})")
        print()
    except Exception as e:  # CLI boundary: report and continue with the next check
        print(f"   Error: {e}")
        print()


def _check_detection(base_url: str, timeout: float) -> None:
    """Check 5: POST a synthetic image to ``/api/v1/detect`` and summarise the result."""
    print("5. Testing marine species detection...")
    try:
        print("   Creating test image...")
        detection_request = {
            "image": create_test_image(),
            "confidence_threshold": 0.25,
            "iou_threshold": 0.45,
            "image_size": 640,
            "return_annotated_image": True
        }

        print("   Sending detection request...")
        # perf_counter is monotonic, unlike time.time(), so wall-clock
        # adjustments cannot distort the measured round trip.
        start_time = time.perf_counter()
        response = requests.post(
            f"{base_url}/api/v1/detect",
            json=detection_request,
            timeout=timeout
        )
        request_time = time.perf_counter() - start_time

        print(f"   Status: {response.status_code}")
        print(f"   Request Time: {request_time:.2f}s")

        if response.status_code == 200:
            detection_data = response.json()
            detections = detection_data.get('detections') or []
            processing_time = detection_data.get('processing_time') or 0

            print(f"   Processing Time: {processing_time:.3f}s")
            print(f"   Detections Found: {len(detections)}")

            if detections:
                print("   Top detections:")
                for rank, detection in enumerate(detections[:3], start=1):
                    confidence = detection.get('confidence')
                    # A missing/null confidence would make ":.3f" raise TypeError
                    # and hide the rest of the report, so render it as "n/a".
                    shown = (f"{confidence:.3f}"
                             if isinstance(confidence, (int, float)) else "n/a")
                    print(f"     {rank}. {detection.get('class_name')} "
                          f"(confidence: {shown})")

            if detection_data.get('annotated_image'):
                print("   ✅ Annotated image returned")
            else:
                print("   ❌ No annotated image returned")

        elif response.status_code == 503:
            print("   ⚠️  Service unavailable (model may not be loaded)")
        else:
            print(f"   ❌ Error: {response.text}")

        print()

    except Exception as e:  # CLI boundary: report instead of crashing
        print(f"   Error: {e}")
        print()


if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the default server address.
    cli_args = sys.argv[1:]
    target_url = cli_args[0] if cli_args else "http://localhost:7860"

    test_api(target_url)