Spaces:
Sleeping
Sleeping
| # tests/test_api/test_classify.py | |
| import base64 | |
| from io import BytesIO | |
| from fastapi.testclient import TestClient | |
| from PIL import Image | |
| from app.api.main import app | |
# Single TestClient wrapping the FastAPI ``app``; created once at import time
# and shared by every test in this module.
client: TestClient = TestClient(app)
def create_test_image_base64(
    size: tuple = (224, 224),
    color: str = "red",
    image_format: str = "PNG",
) -> str:
    """Create a solid-colour RGB test image and return it base64-encoded.

    The defaults reproduce the original fixture (224x224, red, PNG), so
    existing callers are unaffected; tests can override them to exercise
    other sizes, colours or encodings.

    Args:
        size: ``(width, height)`` of the image in pixels, passed to
            ``PIL.Image.new``.
        color: Fill colour, in any form accepted by ``PIL.Image.new``.
        image_format: Format name passed to ``Image.save`` (e.g. ``"PNG"``).

    Returns:
        The encoded image bytes as a base64 string decoded with UTF-8.
    """
    img = Image.new("RGB", size, color=color)
    buffer = BytesIO()
    img.save(buffer, format=image_format)
    # getvalue() returns the entire buffer regardless of the current
    # position, so no seek(0) + read() dance is needed.
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
class TestClassifyEndpoint:
    """Tests for the ``POST /classify`` endpoint."""

    @staticmethod
    def _post_image(**extra_fields):
        """POST a freshly generated test image to ``/classify``.

        Keyword arguments are merged into the JSON payload next to
        ``image_base64`` (e.g. ``model_type="resnet18"``).

        Returns:
            The raw TestClient response.
        """
        payload = {"image_base64": create_test_image_base64(), **extra_fields}
        return client.post("/classify", json=payload)

    @classmethod
    def _classify_ok(cls, **extra_fields):
        """Classify the test image, require HTTP 200, and return the JSON body.

        Checking the status first means a failing request reports its status
        code and body, instead of surfacing later as a ``KeyError`` on a
        missing response field.
        """
        response = cls._post_image(**extra_fields)
        assert response.status_code == 200, (
            f"expected 200, got {response.status_code}: {response.text}"
        )
        return response.json()

    def test_classify_returns_200(self):
        """Test classify endpoint returns 200 with valid image."""
        response = self._post_image(model_type="resnet18")
        assert response.status_code == 200, response.text

    def test_classify_response_structure(self):
        """Test classify response has required fields."""
        data = self._classify_ok()
        for field in ("predicted_class", "predicted_class_id", "confidence", "probabilities"):
            assert field in data, f"missing field {field!r} in response {data}"

    def test_classify_confidence_range(self):
        """Test confidence is between 0 and 1."""
        data = self._classify_ok()
        assert 0.0 <= data["confidence"] <= 1.0, f"confidence={data['confidence']}"

    def test_classify_probabilities_sum_to_one(self):
        """Test probabilities sum to approximately 1."""
        data = self._classify_ok()
        probs = list(data["probabilities"].values())
        total = sum(probs)
        assert abs(total - 1.0) < 0.01, f"probabilities sum to {total}"

    def test_classify_invalid_base64_returns_400(self):
        """Test invalid base64 returns 400."""
        response = client.post("/classify", json={"image_base64": "not_valid_base64!!!"})
        assert response.status_code == 400, response.text