File size: 2,208 Bytes
090a270
 
 
f2a237f
090a270
f2a237f
090a270
f2a237f
090a270
 
 
 
 
 
20d06bb
090a270
20d06bb
090a270
20d06bb
090a270
 
 
 
 
 
 
 
 
20d06bb
090a270
 
 
 
 
 
 
20d06bb
090a270
 
 
 
 
 
 
 
 
 
 
20d06bb
090a270
 
 
 
 
 
 
 
20d06bb
090a270
 
 
 
 
 
 
20d06bb
090a270
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# tests/test_api/test_classify.py
import base64
from io import BytesIO

from fastapi.testclient import TestClient
from PIL import Image

from app.api.main import app

# Shared test client for every test in this module; it wraps the FastAPI `app`.
client = TestClient(app=app)


def create_test_image_base64(
    *,
    size: tuple = (224, 224),
    color: object = "red",
    image_format: str = "PNG",
) -> str:
    """Create a solid-colour RGB test image and return it base64-encoded.

    Called with no arguments it produces the same 224x224 red PNG as before.
    The keyword-only parameters let individual tests request other images.

    Args:
        size: ``(width, height)`` of the image in pixels.
        color: Fill colour, passed straight to ``PIL.Image.new``.
        image_format: Encoder name passed to ``Image.save`` (e.g. ``"PNG"``).

    Returns:
        The encoded image bytes as a base64 ``str``.
    """
    img = Image.new("RGB", size, color=color)
    with BytesIO() as buffer:
        img.save(buffer, format=image_format)
        # getvalue() returns the whole buffer regardless of the stream position,
        # so no seek(0)/read() round-trip is needed.
        encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("utf-8")


class TestClassifyEndpoint:
    """Tests for the ``POST /classify`` endpoint.

    All requests go through the module-level ``client``. Tests that inspect
    the response body first require a 200 status. An endpoint failure is then
    reported together with the server's response text, rather than as a
    ``KeyError`` on a missing field further down.
    """

    @staticmethod
    def _post_valid_image(**extra_fields):
        """POST a freshly generated test image to ``/classify``.

        Args:
            **extra_fields: Extra JSON fields merged into the request body
                after ``image_base64`` (e.g. ``model_type="resnet18"``).

        Returns:
            The test client's response object.
        """
        payload = {"image_base64": create_test_image_base64(), **extra_fields}
        return client.post("/classify", json=payload)

    @classmethod
    def _classify_ok(cls, **extra_fields):
        """POST a valid image, require HTTP 200, and return the decoded JSON body."""
        response = cls._post_valid_image(**extra_fields)
        assert response.status_code == 200, response.text
        return response.json()

    def test_classify_returns_200(self):
        """Test classify endpoint returns 200 with valid image."""
        response = self._post_valid_image(model_type="resnet18")

        assert response.status_code == 200, response.text

    def test_classify_response_structure(self):
        """Test classify response has required fields."""
        data = self._classify_ok()

        required = {"predicted_class", "predicted_class_id", "confidence", "probabilities"}
        # Report every missing field at once instead of stopping at the first.
        missing = required.difference(data)
        assert not missing, f"response is missing fields: {sorted(missing)}"

    def test_classify_confidence_range(self):
        """Test confidence is between 0 and 1."""
        data = self._classify_ok()

        assert 0.0 <= data["confidence"] <= 1.0

    def test_classify_probabilities_sum_to_one(self):
        """Test probabilities sum to approximately 1."""
        probabilities = self._classify_ok()["probabilities"]

        # Loose tolerance: the values are floats that went through JSON.
        total = sum(probabilities.values())
        assert abs(total - 1.0) < 0.01, f"probabilities sum to {total}"

    def test_classify_invalid_base64_returns_400(self):
        """Test invalid base64 returns 400."""
        response = client.post("/classify", json={"image_base64": "not_valid_base64!!!"})

        assert response.status_code == 400, response.text