Spaces:
Sleeping
Sleeping
File size: 2,208 Bytes
# tests/test_api/test_classify.py
import base64
from io import BytesIO
from fastapi.testclient import TestClient
from PIL import Image
from app.api.main import app
# Shared test client wrapping the FastAPI `app`; every test in this module
# sends its requests through this one instance.
client = TestClient(app)
def create_test_image_base64(
    size: tuple = (224, 224),
    color: str = "red",
    image_format: str = "PNG",
) -> str:
    """Create a solid-colour RGB test image and return it base64-encoded.

    Called with no arguments, this produces a 224x224 red PNG, which is the
    image used by the classify tests.

    Args:
        size: ``(width, height)`` of the image in pixels.
        color: Fill colour, in any form accepted by ``PIL.Image.new``.
        image_format: Format name passed to ``Image.save`` (e.g. ``"PNG"``).

    Returns:
        The encoded image bytes as a base64 string.
    """
    img = Image.new("RGB", size, color=color)
    buffer = BytesIO()
    img.save(buffer, format=image_format)
    # getvalue() returns the full buffer contents regardless of the current
    # stream position, so no seek(0) is required before reading.
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
class TestClassifyEndpoint:
    """Tests for the ``POST /classify`` endpoint."""

    @staticmethod
    def _post_classify(**extra):
        """POST a freshly generated test image to ``/classify``.

        Keyword arguments are merged into the JSON payload after
        ``image_base64`` (e.g. ``model_type="resnet18"``).

        Returns:
            The response object returned by ``client.post``.
        """
        payload = {"image_base64": create_test_image_base64(), **extra}
        return client.post("/classify", json=payload)

    def test_classify_returns_200(self):
        """Test classify endpoint returns 200 with valid image."""
        response = self._post_classify(model_type="resnet18")
        assert response.status_code == 200, response.text

    def test_classify_response_structure(self):
        """Test classify response has required fields."""
        response = self._post_classify()
        # Check the status first so a server error is reported with its body,
        # rather than as a misleading missing-key failure.
        assert response.status_code == 200, response.text
        data = response.json()
        for field in (
            "predicted_class",
            "predicted_class_id",
            "confidence",
            "probabilities",
        ):
            assert field in data, f"missing field: {field}"

    def test_classify_confidence_range(self):
        """Test confidence is between 0 and 1."""
        response = self._post_classify()
        assert response.status_code == 200, response.text
        data = response.json()
        assert 0.0 <= data["confidence"] <= 1.0

    def test_classify_probabilities_sum_to_one(self):
        """Test probabilities sum to approximately 1."""
        response = self._post_classify()
        assert response.status_code == 200, response.text
        data = response.json()
        probs = list(data["probabilities"].values())
        # An empty mapping would otherwise fail below with an unhelpful
        # "0.0 != 1.0" style message.
        assert probs, "probabilities mapping is empty"
        assert abs(sum(probs) - 1.0) < 0.01

    def test_classify_invalid_base64_returns_400(self):
        """Test invalid base64 returns 400."""
        response = client.post("/classify", json={"image_base64": "not_valid_base64!!!"})
        assert response.status_code == 400, response.text
|