fishapi / tests /test_api.py
kamau1's picture
Initial commit
bcc2f7b verified
"""
Basic API tests for the Marine Species Identification API.
"""
import pytest
import base64
import io
from PIL import Image
import numpy as np
from fastapi.testclient import TestClient
from app.main import app
# Module-level TestClient wrapping `app`; the tests in this file issue their
# HTTP requests through this shared instance.
client = TestClient(app)
def create_test_image(width: int = 640, height: int = 480) -> str:
    """Generate a random-noise JPEG and return it as a base64 string.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        The JPEG-encoded image bytes, base64-encoded and decoded to text.
    """
    # Three-channel uint8 noise, laid out as (rows, columns, channels).
    pixels = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)

    # Encode to JPEG entirely in memory, then base64 the raw bytes.
    with io.BytesIO() as jpeg_stream:
        Image.fromarray(pixels).save(jpeg_stream, format="JPEG")
        encoded = base64.b64encode(jpeg_stream.getvalue())
    return encoded.decode('utf-8')
class TestHealthEndpoints:
    """Checks for the root, health, info and liveness routes."""

    def test_root_endpoint(self):
        """GET / returns 200 with "message" and "version" keys."""
        resp = client.get("/")
        assert resp.status_code == 200
        body = resp.json()
        for field in ("message", "version"):
            assert field in body

    def test_root_health(self):
        """GET /health returns 200 and a status of "ok"."""
        resp = client.get("/health")
        assert resp.status_code == 200
        assert resp.json()["status"] == "ok"

    def test_health_check(self):
        """GET /api/v1/health returns 200 with the detailed health keys."""
        resp = client.get("/api/v1/health")
        assert resp.status_code == 200
        body = resp.json()
        for field in ("status", "model_loaded", "timestamp"):
            assert field in body

    def test_api_info(self):
        """GET /api/v1/info returns 200 with name, version and endpoints."""
        resp = client.get("/api/v1/info")
        assert resp.status_code == 200
        body = resp.json()
        for field in ("name", "version", "endpoints"):
            assert field in body

    def test_liveness_check(self):
        """GET /api/v1/live returns 200 and a status of "alive"."""
        resp = client.get("/api/v1/live")
        assert resp.status_code == 200
        assert resp.json()["status"] == "alive"
class TestSpeciesEndpoints:
    """Checks for the species listing and lookup routes."""

    def test_list_species(self):
        """GET /api/v1/species returns a species list, or 503."""
        resp = client.get("/api/v1/species")
        # 503 is accepted because the model may not be loaded.
        assert resp.status_code in (200, 503)
        if resp.status_code != 200:
            return

        body = resp.json()
        assert "species" in body
        assert "total_count" in body
        assert isinstance(body["species"], list)

    def test_get_species_info(self):
        """GET /api/v1/species/0 answers with 200, 404 or 503."""
        # A non-200 answer is expected in a test environment where the
        # model is not loaded, so only the status code is checked.
        resp = client.get("/api/v1/species/0")
        assert resp.status_code in (200, 404, 503)
class TestInferenceEndpoints:
    """Checks for the POST /api/v1/detect inference route."""

    def test_detect_invalid_image(self):
        """A payload whose image is not usable is rejected."""
        payload = {
            "image": "invalid_base64_data",
            "confidence_threshold": 0.25,
        }
        resp = client.post("/api/v1/detect", json=payload)
        # Either a bad request or service unavailable is acceptable.
        assert resp.status_code in (400, 503)

    def test_detect_valid_request_format(self):
        """A well-formed request yields the documented response keys."""
        payload = {
            "image": create_test_image(),
            "confidence_threshold": 0.25,
            "iou_threshold": 0.45,
            "image_size": 640,
            "return_annotated_image": True,
        }
        resp = client.post("/api/v1/detect", json=payload)

        # 503 is expected when the model is not loaded in the test
        # environment; the body is only inspected on success.
        assert resp.status_code in (200, 503)
        if resp.status_code != 200:
            return

        body = resp.json()
        expected_fields = (
            "detections",
            "processing_time",
            "model_info",
            "image_dimensions",
        )
        for field in expected_fields:
            assert field in body

    def test_detect_parameter_validation(self):
        """Out-of-range parameters are rejected with a validation error."""
        encoded_image = create_test_image()
        invalid_overrides = (
            {"confidence_threshold": 1.5},  # Invalid: > 1.0
            {"image_size": 100},  # Invalid: < 320
        )
        # Each override is sent on its own, in order, alongside a valid image.
        for override in invalid_overrides:
            resp = client.post(
                "/api/v1/detect",
                json={"image": encoded_image, **override},
            )
            assert resp.status_code == 422  # Validation error
class TestErrorHandling:
    """Checks for the HTTP error statuses of unknown or misused routes."""

    def test_404_endpoint(self):
        """A GET to a route that does not exist returns 404."""
        missing = client.get("/api/v1/nonexistent")
        assert missing.status_code == 404

    def test_method_not_allowed(self):
        """A GET to the detect route (which should be POST) returns 405."""
        wrong_verb = client.get("/api/v1/detect")
        assert wrong_verb.status_code == 405
if __name__ == "__main__":
    # pytest.main() returns the run's exit code instead of exiting itself.
    # Raise SystemExit with it so that executing this file directly exits
    # non-zero when any test fails, rather than always exiting with 0.
    raise SystemExit(pytest.main([__file__]))