File size: 5,440 Bytes
bcc2f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""
Basic API tests for the Marine Species Identification API.
"""

import pytest
import base64
import io
from PIL import Image
import numpy as np
from fastapi.testclient import TestClient

from app.main import app

# Single test client built at import time around the FastAPI ``app`` from
# app.main; every test class in this module issues its requests through it.
client = TestClient(app)


def create_test_image(width: int = 640, height: int = 480, seed: int = 0) -> str:
    """Build a random-noise RGB JPEG and return it base64-encoded.

    Args:
        width: Image width in pixels; must be positive.
        height: Image height in pixels; must be positive.
        seed: Seed for the pixel generator. A fixed seed makes the image
            (and therefore any test failure it triggers) reproducible.

    Returns:
        The JPEG bytes as a base64 string, as sent in the ``"image"`` field
        of the detection requests in this module.

    Raises:
        ValueError: If ``width`` or ``height`` is not positive.
    """
    if width <= 0 or height <= 0:
        raise ValueError(f"image dimensions must be positive, got {width}x{height}")

    # Use a local seeded generator instead of the global np.random state so
    # the same bytes are produced on every run. ``integers`` excludes its
    # upper bound, so 256 yields the full 0..255 pixel range.
    rng = np.random.default_rng(seed)
    pixels = rng.integers(0, 256, size=(height, width, 3), dtype=np.uint8)

    buffer = io.BytesIO()
    Image.fromarray(pixels).save(buffer, format="JPEG")
    # base64 output is pure ASCII.
    return base64.b64encode(buffer.getvalue()).decode("ascii")


class TestHealthEndpoints:
    """Checks for the service's health, liveness and metadata endpoints."""

    @staticmethod
    def _get_ok_json(path):
        """GET ``path``, require HTTP 200, and return the parsed JSON body."""
        reply = client.get(path)
        assert reply.status_code == 200
        return reply.json()

    def test_root_endpoint(self):
        """The root endpoint returns a message and a version."""
        body = self._get_ok_json("/")
        for field in ("message", "version"):
            assert field in body

    def test_root_health(self):
        """The unversioned /health endpoint reports status "ok"."""
        assert self._get_ok_json("/health")["status"] == "ok"

    def test_health_check(self):
        """The versioned health check exposes status, model state and a timestamp."""
        body = self._get_ok_json("/api/v1/health")
        for field in ("status", "model_loaded", "timestamp"):
            assert field in body

    def test_api_info(self):
        """The info endpoint returns a name, version and endpoint listing."""
        body = self._get_ok_json("/api/v1/info")
        for field in ("name", "version", "endpoints"):
            assert field in body

    def test_liveness_check(self):
        """The liveness probe answers with status "alive"."""
        assert self._get_ok_json("/api/v1/live")["status"] == "alive"


class TestSpeciesEndpoints:
    """Checks for the species catalogue endpoints.

    These endpoints may answer 503 when no model is loaded, which is the
    normal situation in a test environment, so that status is tolerated.
    """

    def test_list_species(self):
        """Listing species yields a list plus a total count when available."""
        reply = client.get("/api/v1/species")
        assert reply.status_code in (200, 503)
        if reply.status_code != 200:
            # Service unavailable: nothing further to inspect.
            return

        body = reply.json()
        assert "species" in body
        assert "total_count" in body
        assert isinstance(body["species"], list)

    def test_get_species_info(self):
        """Looking up species 0 succeeds, is missing, or is unavailable."""
        reply = client.get("/api/v1/species/0")
        assert reply.status_code in (200, 404, 503)


class TestInferenceEndpoints:
    """Checks for the POST /api/v1/detect inference endpoint."""

    @staticmethod
    def _post_detect(payload):
        """Send ``payload`` as JSON to the detection endpoint and return the reply."""
        return client.post("/api/v1/detect", json=payload)

    def test_detect_invalid_image(self):
        """Undecodable image data is rejected (or the service is unavailable)."""
        reply = self._post_detect(
            {
                "image": "invalid_base64_data",
                "confidence_threshold": 0.25,
            }
        )
        # 400 = bad request; 503 = no model loaded.
        assert reply.status_code in (400, 503)

    def test_detect_valid_request_format(self):
        """A well-formed request is accepted and returns the expected fields."""
        request_body = {
            "image": create_test_image(),
            "confidence_threshold": 0.25,
            "iou_threshold": 0.45,
            "image_size": 640,
            "return_annotated_image": True,
        }
        reply = self._post_detect(request_body)

        # 503 is expected when the test environment has no model loaded.
        assert reply.status_code in (200, 503)
        if reply.status_code != 200:
            return

        body = reply.json()
        for field in ("detections", "processing_time", "model_info", "image_dimensions"):
            assert field in body

    def test_detect_parameter_validation(self):
        """Out-of-range request parameters fail request validation with 422."""
        encoded = create_test_image()
        invalid_overrides = (
            {"confidence_threshold": 1.5},  # Invalid: > 1.0
            {"image_size": 100},  # Invalid: < 320
        )
        for override in invalid_overrides:
            reply = self._post_detect({"image": encoded, **override})
            assert reply.status_code == 422


class TestErrorHandling:
    """Checks for generic routing errors."""

    def test_404_endpoint(self):
        """An unknown path yields 404 Not Found."""
        assert client.get("/api/v1/nonexistent").status_code == 404

    def test_method_not_allowed(self):
        """GET on the POST-only detect endpoint yields 405 Method Not Allowed."""
        assert client.get("/api/v1/detect").status_code == 405


if __name__ == "__main__":
    # pytest.main() returns the exit code instead of exiting. Propagate it so
    # that running this file directly reports test failures to the shell/CI.
    raise SystemExit(pytest.main([__file__]))