File size: 5,440 Bytes
bcc2f7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
"""
Basic API tests for the Marine Species Identification API.
"""
import pytest
import base64
import io
from PIL import Image
import numpy as np
from fastapi.testclient import TestClient
from app.main import app
# Module-level test client wrapping the imported `app`; every test class in
# this module issues its requests through this single shared instance.
client = TestClient(app)
def create_test_image(width: int = 640, height: int = 480) -> str:
    """Build a random RGB test image and return it as a base64 string.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        The image encoded as JPEG, then base64-encoded and decoded to a
        UTF-8 ``str``.
    """
    # ``high`` is exclusive in ``np.random.randint``; 256 (not 255) is needed
    # for the noise to cover the full uint8 range 0..255.
    pixels = np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
    pil_image = Image.fromarray(pixels)

    # Encode in memory so no temporary file is required.
    buffer = io.BytesIO()
    pil_image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
class TestHealthEndpoints:
    """Checks for the root, health, info and liveness routes."""

    def test_root_endpoint(self):
        """GET / returns 200 with a message and a version."""
        resp = client.get("/")
        assert resp.status_code == 200
        body = resp.json()
        for key in ("message", "version"):
            assert key in body

    def test_root_health(self):
        """GET /health returns 200 with status "ok"."""
        resp = client.get("/health")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "ok"

    def test_health_check(self):
        """GET /api/v1/health returns 200 with the detailed health fields."""
        resp = client.get("/api/v1/health")
        assert resp.status_code == 200
        body = resp.json()
        for key in ("status", "model_loaded", "timestamp"):
            assert key in body

    def test_api_info(self):
        """GET /api/v1/info returns 200 with name, version and endpoints."""
        resp = client.get("/api/v1/info")
        assert resp.status_code == 200
        body = resp.json()
        for key in ("name", "version", "endpoints"):
            assert key in body

    def test_liveness_check(self):
        """GET /api/v1/live returns 200 with status "alive"."""
        resp = client.get("/api/v1/live")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "alive"
class TestSpeciesEndpoints:
    """Checks for the species listing and lookup routes."""

    def test_list_species(self):
        """GET /api/v1/species returns a species list, or 503."""
        resp = client.get("/api/v1/species")
        # 503 is tolerated: the request may fail if the model is not loaded.
        assert resp.status_code in (200, 503)
        if resp.status_code != 200:
            return

        listing = resp.json()
        for key in ("species", "total_count"):
            assert key in listing
        assert isinstance(listing["species"], list)

    def test_get_species_info(self):
        """GET /api/v1/species/0 answers with 200, 404 or 503."""
        # The lookup may fail when the model is not loaded, which is expected
        # in a test environment.
        resp = client.get("/api/v1/species/0")
        assert resp.status_code in (200, 404, 503)
class TestInferenceEndpoints:
    """Checks for the detection (inference) route."""

    def test_detect_invalid_image(self):
        """POST /api/v1/detect with undecodable image data is rejected."""
        payload = {
            "image": "invalid_base64_data",
            "confidence_threshold": 0.25,
        }
        resp = client.post("/api/v1/detect", json=payload)
        # Either a bad request or service unavailable.
        assert resp.status_code in (400, 503)

    def test_detect_valid_request_format(self):
        """POST /api/v1/detect with a well-formed request body."""
        payload = {
            "image": create_test_image(),
            "confidence_threshold": 0.25,
            "iou_threshold": 0.45,
            "image_size": 640,
            "return_annotated_image": True,
        }
        resp = client.post("/api/v1/detect", json=payload)

        # A 503 may come back if the model is not loaded, which is expected
        # in a test environment.
        assert resp.status_code in (200, 503)
        if resp.status_code != 200:
            return

        result = resp.json()
        expected_keys = (
            "detections",
            "processing_time",
            "model_info",
            "image_dimensions",
        )
        for key in expected_keys:
            assert key in result

    def test_detect_parameter_validation(self):
        """Out-of-range request parameters produce a 422 validation error."""
        encoded = create_test_image()
        invalid_overrides = (
            {"confidence_threshold": 1.5},  # Invalid: > 1.0
            {"image_size": 100},  # Invalid: < 320
        )
        # Same requests, in the same order, each with one bad parameter.
        for override in invalid_overrides:
            resp = client.post(
                "/api/v1/detect",
                json={"image": encoded, **override},
            )
            assert resp.status_code == 422  # Validation error
class TestErrorHandling:
    """Checks for the error responses on bad routes and methods."""

    def test_404_endpoint(self):
        """A route that does not exist answers with 404."""
        resp = client.get("/api/v1/nonexistent")
        assert resp.status_code == 404

    def test_method_not_allowed(self):
        """A GET against the detect route answers with 405."""
        # The detect route should be called with POST.
        resp = client.get("/api/v1/detect")
        assert resp.status_code == 405
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly exits
    # with a non-zero status when a test fails, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
|