Spaces:
Sleeping
Sleeping
| # tests/test_api/test_classify_upload.py | |
| from io import BytesIO | |
| import pytest | |
| from fastapi.testclient import TestClient | |
| from PIL import Image | |
| from app.api.main import app | |
# One TestClient bound to the FastAPI ``app`` imported above; every test in
# this module issues its requests through this single shared instance.
client = TestClient(app)
def create_test_image_bytes() -> BytesIO:
    """Build an in-memory PNG for use as a multipart upload payload.

    The image is a solid-colour RGB canvas of size (256, 512) as passed to
    ``Image.new``. The returned buffer is rewound to offset 0 so the caller
    can hand it straight to the HTTP client.
    """
    stream = BytesIO()
    Image.new("RGB", (256, 512), color=(100, 150, 80)).save(stream, format="PNG")
    # Rewind so the upload reads the PNG from its first byte.
    stream.seek(0)
    return stream
class TestClassifyUploadEndpoint:
    """Tests for POST /classify/upload."""

    @staticmethod
    def _upload(**form):
        """POST a freshly generated PNG to /classify/upload.

        Keyword arguments are sent as extra multipart form fields
        (e.g. ``model_id``); with none given, no form data is sent.
        """
        buf = create_test_image_bytes()
        return client.post(
            "/classify/upload",
            files={"file": ("test.png", buf, "image/png")},
            data=form or None,
        )

    def _upload_ok(self) -> dict:
        """Upload an image, require HTTP 200, and return the parsed JSON body.

        Checking the status first makes a server-side failure show up as a
        clear status assertion (with the response text) instead of a
        KeyError or JSON decode error further down.
        """
        response = self._upload()
        assert response.status_code == 200, response.text
        return response.json()

    def test_upload_returns_200(self):
        response = self._upload()
        assert response.status_code == 200, response.text

    def test_upload_response_structure(self):
        data = self._upload_ok()
        assert "top_predictions" in data
        assert "predicted_species" in data
        assert "confidence" in data
        assert isinstance(data["top_predictions"], list)
        assert len(data["top_predictions"]) > 0

    def test_upload_prediction_fields(self):
        predictions = self._upload_ok()["top_predictions"]
        # Guard explicitly so an empty list fails with a message, not IndexError.
        assert predictions, "expected at least one prediction"
        # Every ranked prediction must honour the same contract, not just the top one.
        for pred in predictions:
            assert "species" in pred, pred
            assert "confidence" in pred, pred
            assert 0.0 <= pred["confidence"] <= 1.0, pred

    def test_upload_confidence_range(self):
        data = self._upload_ok()
        assert 0.0 <= data["confidence"] <= 1.0

    def test_upload_with_invalid_model_id_returns_error(self):
        # NOTE(review): despite the name, this asserts that the endpoint lets a
        # ValueError escape instead of mapping it to an HTTP error status.
        # pytest.raises only sees it because TestClient re-raises unhandled
        # server exceptions by default (raise_server_exceptions=True). If the
        # API is changed to return a 4xx response, update this test to check
        # the status code instead.
        with pytest.raises(ValueError, match="No weights found"):
            self._upload(model_id="nonexistent_model")

    def test_upload_no_file_returns_422(self):
        response = client.post("/classify/upload")
        assert response.status_code == 422
class TestExplainEndpoint:
    """Tests for POST /classify/explain (slow — requires model + captum)."""

    @staticmethod
    def _explain():
        """POST a freshly generated PNG to /classify/explain and return the response."""
        buf = create_test_image_bytes()
        return client.post(
            "/classify/explain",
            files={"file": ("test.png", buf, "image/png")},
        )

    def test_explain_returns_200(self):
        response = self._explain()
        assert response.status_code == 200, response.text

    def test_explain_response_structure(self):
        response = self._explain()
        # Check status before parsing so a server error reports the response
        # body instead of failing obscurely on a missing key.
        assert response.status_code == 200, response.text
        data = response.json()
        assert "heatmap_base64" in data
        assert "predicted_species" in data
        assert "confidence" in data
        assert isinstance(data["heatmap_base64"], str)
        # Loose sanity check that a non-trivial encoded heatmap was returned.
        assert len(data["heatmap_base64"]) > 100
class TestModelsEndpoint:
    """Tests for GET /classify/models."""

    @staticmethod
    def _models_payload() -> dict:
        """GET /classify/models, require HTTP 200, and return the parsed JSON body."""
        response = client.get("/classify/models")
        assert response.status_code == 200, response.text
        return response.json()

    def test_models_returns_200(self):
        response = client.get("/classify/models")
        assert response.status_code == 200, response.text

    def test_models_response_structure(self):
        data = self._models_payload()
        assert "models" in data
        assert isinstance(data["models"], list)

    def test_models_items_have_fields(self):
        models = self._models_payload()["models"]
        # Report an empty registry as a skip rather than a silent vacuous pass.
        if not models:
            pytest.skip("no models registered; nothing to check")
        # Validate every entry, not just the first one.
        for m in models:
            assert "id" in m, m
            assert "name" in m, m
            assert "model_variant" in m, m