""" Basic API tests for the Marine Species Identification API. """ import pytest import base64 import io from PIL import Image import numpy as np from fastapi.testclient import TestClient from app.main import app client = TestClient(app) def create_test_image(width: int = 640, height: int = 480) -> str: """Create a test image and return as base64 string.""" # Create a simple test image image = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8) pil_image = Image.fromarray(image) # Convert to base64 buffer = io.BytesIO() pil_image.save(buffer, format="JPEG") image_bytes = buffer.getvalue() return base64.b64encode(image_bytes).decode('utf-8') class TestHealthEndpoints: """Test health and status endpoints.""" def test_root_endpoint(self): """Test root endpoint.""" response = client.get("/") assert response.status_code == 200 data = response.json() assert "message" in data assert "version" in data def test_root_health(self): """Test root health endpoint.""" response = client.get("/health") assert response.status_code == 200 data = response.json() assert data["status"] == "ok" def test_health_check(self): """Test detailed health check.""" response = client.get("/api/v1/health") assert response.status_code == 200 data = response.json() assert "status" in data assert "model_loaded" in data assert "timestamp" in data def test_api_info(self): """Test API info endpoint.""" response = client.get("/api/v1/info") assert response.status_code == 200 data = response.json() assert "name" in data assert "version" in data assert "endpoints" in data def test_liveness_check(self): """Test liveness probe.""" response = client.get("/api/v1/live") assert response.status_code == 200 data = response.json() assert data["status"] == "alive" class TestSpeciesEndpoints: """Test species-related endpoints.""" def test_list_species(self): """Test species list endpoint.""" response = client.get("/api/v1/species") assert response.status_code in [200, 503] # May fail if model not loaded if response.status_code == 200: data = response.json() assert "species" in data assert "total_count" in data assert isinstance(data["species"], list) def test_get_species_info(self): """Test individual species info endpoint.""" # This may fail if model is not loaded, which is expected in test environment response = client.get("/api/v1/species/0") assert response.status_code in [200, 404, 503] class TestInferenceEndpoints: """Test inference endpoints.""" def test_detect_invalid_image(self): """Test detection with invalid image data.""" response = client.post( "/api/v1/detect", json={ "image": "invalid_base64_data", "confidence_threshold": 0.25 } ) assert response.status_code in [400, 503] # Bad request or service unavailable def test_detect_valid_request_format(self): """Test detection with valid request format.""" test_image = create_test_image() response = client.post( "/api/v1/detect", json={ "image": test_image, "confidence_threshold": 0.25, "iou_threshold": 0.45, "image_size": 640, "return_annotated_image": True } ) # May return 503 if model is not loaded, which is expected in test environment assert response.status_code in [200, 503] if response.status_code == 200: data = response.json() assert "detections" in data assert "processing_time" in data assert "model_info" in data assert "image_dimensions" in data def test_detect_parameter_validation(self): """Test parameter validation.""" test_image = create_test_image() # Test invalid confidence threshold response = client.post( "/api/v1/detect", json={ "image": test_image, "confidence_threshold": 1.5 # Invalid: > 1.0 } ) assert response.status_code == 422 # Validation error # Test invalid image size response = client.post( "/api/v1/detect", json={ "image": test_image, "image_size": 100 # Invalid: < 320 } ) assert response.status_code == 422 # Validation error class TestErrorHandling: """Test error handling.""" def test_404_endpoint(self): """Test non-existent endpoint.""" response = client.get("/api/v1/nonexistent") assert response.status_code == 404 def test_method_not_allowed(self): """Test wrong HTTP method.""" response = client.get("/api/v1/detect") # Should be POST assert response.status_code == 405 if __name__ == "__main__": pytest.main([__file__])