|
|
""" |
|
|
Basic API tests for the Marine Species Identification API. |
|
|
""" |
|
|
|
|
|
import pytest |
|
|
import base64 |
|
|
import io |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from fastapi.testclient import TestClient |
|
|
|
|
|
from app.main import app |
|
|
|
|
|
# In-process HTTP client shared by every test in this module.
# NOTE(review): this client is not entered as a context manager (`with TestClient(app)`),
# so, depending on the Starlette/FastAPI version, the app's startup/lifespan handlers
# may not run for these requests — confirm whether model loading relies on them.
client: TestClient = TestClient(app)
|
|
|
|
|
|
|
|
def create_test_image(width: int = 640, height: int = 480, seed: "int | None" = None) -> str:
    """Create a random-noise JPEG test image and return it as a base64 string.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        seed: Optional seed for reproducible pixel data. When ``None`` (the
            default) NumPy's global random state is used, as before; when an
            int is given, a private ``RandomState`` seeded with it is used so
            the same seed always yields the same image.

    Returns:
        The JPEG-encoded image bytes, base64-encoded and decoded as UTF-8 text.
    """
    # Keep using the global generator by default so any global seeding done by
    # callers (e.g. np.random.seed in a fixture) still applies.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # randint's upper bound is exclusive: 256 is required to cover the full
    # 0-255 uint8 range (the previous bound of 255 never produced white pixels).
    pixels = rng.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
    pil_image = Image.fromarray(pixels)

    with io.BytesIO() as buffer:
        pil_image.save(buffer, format="JPEG")
        image_bytes = buffer.getvalue()

    return base64.b64encode(image_bytes).decode('utf-8')
|
|
|
|
|
|
|
|
class TestHealthEndpoints:
    """Checks for the root, health, info and liveness endpoints."""

    @staticmethod
    def _get_ok_json(path):
        """GET ``path``, require an HTTP 200 response and return its decoded JSON body."""
        resp = client.get(path)
        assert resp.status_code == 200
        return resp.json()

    def test_root_endpoint(self):
        """The root endpoint returns 200 with ``message`` and ``version`` keys."""
        payload = self._get_ok_json("/")
        for key in ("message", "version"):
            assert key in payload

    def test_root_health(self):
        """The root ``/health`` endpoint returns 200 with ``status == "ok"``."""
        payload = self._get_ok_json("/health")
        assert payload["status"] == "ok"

    def test_health_check(self):
        """The versioned health endpoint returns 200 with status, model and timestamp keys."""
        payload = self._get_ok_json("/api/v1/health")
        for key in ("status", "model_loaded", "timestamp"):
            assert key in payload

    def test_api_info(self):
        """The info endpoint returns 200 with ``name``, ``version`` and ``endpoints`` keys."""
        payload = self._get_ok_json("/api/v1/info")
        for key in ("name", "version", "endpoints"):
            assert key in payload

    def test_liveness_check(self):
        """The liveness endpoint returns 200 with ``status == "alive"``."""
        payload = self._get_ok_json("/api/v1/live")
        assert payload["status"] == "alive"
|
|
|
|
|
|
|
|
class TestSpeciesEndpoints:
    """Checks for the species catalogue endpoints."""

    def test_list_species(self):
        """The species list returns 200 or 503; on 200 the body holds a species list and a count."""
        resp = client.get("/api/v1/species")
        assert resp.status_code in [200, 503]

        # Only a successful response has a body worth inspecting.
        if resp.status_code != 200:
            return

        body = resp.json()
        assert "species" in body
        assert "total_count" in body
        assert isinstance(body["species"], list)

    def test_get_species_info(self):
        """Requesting species id 0 yields one of 200, 404 or 503; only the status is checked."""
        resp = client.get("/api/v1/species/0")
        assert resp.status_code in [200, 404, 503]
|
|
|
|
|
|
|
|
class TestInferenceEndpoints:
    """Checks for the ``/api/v1/detect`` inference endpoint."""

    DETECT_URL = "/api/v1/detect"

    def test_detect_invalid_image(self):
        """An ``image`` field that is not valid image data yields 400 (or 503)."""
        request_body = {
            "image": "invalid_base64_data",
            "confidence_threshold": 0.25,
        }
        resp = client.post(self.DETECT_URL, json=request_body)
        assert resp.status_code in [400, 503]

    def test_detect_valid_request_format(self):
        """A fully specified request yields 200 or 503; a 200 body carries the result keys."""
        request_body = {
            "image": create_test_image(),
            "confidence_threshold": 0.25,
            "iou_threshold": 0.45,
            "image_size": 640,
            "return_annotated_image": True,
        }
        resp = client.post(self.DETECT_URL, json=request_body)
        assert resp.status_code in [200, 503]

        if resp.status_code == 200:
            body = resp.json()
            for key in ("detections", "processing_time", "model_info", "image_dimensions"):
                assert key in body

    def test_detect_parameter_validation(self):
        """Out-of-range request parameters are rejected with HTTP 422."""
        image_b64 = create_test_image()
        invalid_overrides = (
            # presumably above the schema's upper bound for the threshold — confirm
            {"confidence_threshold": 1.5},
            # presumably below the schema's minimum image size — confirm
            {"image_size": 100},
        )
        for override in invalid_overrides:
            resp = client.post(self.DETECT_URL, json={"image": image_b64, **override})
            assert resp.status_code == 422
|
|
|
|
|
|
|
|
class TestErrorHandling:
    """Checks for routing-level error responses."""

    def test_404_endpoint(self):
        """An unknown path under the API prefix is expected to return 404."""
        assert client.get("/api/v1/nonexistent").status_code == 404

    def test_method_not_allowed(self):
        """GET on the detect route (exercised via POST elsewhere) is expected to return 405."""
        assert client.get("/api/v1/detect").status_code == 405
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the run's exit code rather than exiting; propagate
    # it so that running this file directly exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__]))
|
|
|