File size: 4,040 Bytes
090a270
 
f2a237f
 
090a270
f2a237f
 
090a270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# tests/test_api/test_classify_upload.py
from io import BytesIO

import pytest
from fastapi.testclient import TestClient
from PIL import Image

from app.api.main import app

# Module-level test client shared by every test in this file. TestClient's
# default (raise_server_exceptions=True) re-raises exceptions from inside the
# app to the caller instead of returning an HTTP 500 response.
client = TestClient(app)


def create_test_image_bytes() -> BytesIO:
    """Build an in-memory PNG for multipart upload tests.

    The image is a solid-colour RGB canvas, 256 wide by 512 tall. The buffer
    is rewound to offset 0 so it can be handed directly to ``files=``.
    """
    stream = BytesIO()
    canvas = Image.new("RGB", (256, 512), color=(100, 150, 80))
    canvas.save(stream, format="PNG")
    stream.seek(0)  # rewind so the upload reads from the beginning
    return stream


class TestClassifyUploadEndpoint:
    """Tests for POST /classify/upload."""

    def _upload(self, **kwargs):
        """POST a freshly generated PNG to /classify/upload.

        Extra keyword arguments (e.g. ``data=``) are forwarded unchanged to
        ``client.post``. Returns the raw response.
        """
        buf = create_test_image_bytes()
        return client.post(
            "/classify/upload",
            files={"file": ("test.png", buf, "image/png")},
            **kwargs,
        )

    def _upload_json(self) -> dict:
        """Upload a test image, assert HTTP 200, and return the parsed JSON body."""
        response = self._upload()
        # Check the status first so a server error fails with the status code
        # and response body rather than an opaque KeyError on a missing field.
        assert response.status_code == 200, response.text
        return response.json()

    def test_upload_returns_200(self):
        response = self._upload()
        assert response.status_code == 200

    def test_upload_response_structure(self):
        data = self._upload_json()

        assert "top_predictions" in data
        assert "predicted_species" in data
        assert "confidence" in data
        assert isinstance(data["top_predictions"], list)
        assert len(data["top_predictions"]) > 0

    def test_upload_prediction_fields(self):
        predictions = self._upload_json()["top_predictions"]

        # Require at least one prediction, then validate every entry rather
        # than only the first one.
        assert predictions
        for pred in predictions:
            assert "species" in pred
            assert "confidence" in pred
            assert 0.0 <= pred["confidence"] <= 1.0

    def test_upload_confidence_range(self):
        data = self._upload_json()

        assert 0.0 <= data["confidence"] <= 1.0

    def test_upload_with_invalid_model_id_returns_error(self):
        # The ValueError reaches this test only because TestClient re-raises
        # server-side exceptions by default (raise_server_exceptions=True).
        # A real HTTP client would presumably receive a 500 instead.
        # NOTE(review): the test name suggests an error *response*; consider
        # mapping this case to a 4xx in the endpoint and asserting on
        # response.status_code instead.
        with pytest.raises(ValueError, match="No weights found"):
            self._upload(data={"model_id": "nonexistent_model"})

    def test_upload_no_file_returns_422(self):
        # Missing required multipart field -> FastAPI request validation error.
        response = client.post("/classify/upload")
        assert response.status_code == 422


class TestExplainEndpoint:
    """Tests for POST /classify/explain (slow — requires model + captum)."""

    def _explain(self):
        """POST a freshly generated PNG to /classify/explain and return the response."""
        buf = create_test_image_bytes()
        return client.post(
            "/classify/explain",
            files={"file": ("test.png", buf, "image/png")},
        )

    @pytest.mark.slow
    def test_explain_returns_200(self):
        response = self._explain()
        assert response.status_code == 200

    @pytest.mark.slow
    def test_explain_response_structure(self):
        response = self._explain()
        # Fail on the status code (showing the body) instead of on a missing key
        # when the endpoint errors.
        assert response.status_code == 200, response.text
        data = response.json()

        assert "heatmap_base64" in data
        assert "predicted_species" in data
        assert "confidence" in data
        assert isinstance(data["heatmap_base64"], str)
        # Guards against an empty or placeholder heatmap string.
        assert len(data["heatmap_base64"]) > 100


class TestModelsEndpoint:
    """Tests for GET /classify/models."""

    def _models_json(self) -> dict:
        """GET /classify/models, assert HTTP 200, and return the parsed JSON body."""
        response = client.get("/classify/models")
        # Surface the status code and body on failure rather than a KeyError.
        assert response.status_code == 200, response.text
        return response.json()

    def test_models_returns_200(self):
        response = client.get("/classify/models")
        assert response.status_code == 200

    def test_models_response_structure(self):
        data = self._models_json()

        assert "models" in data
        assert isinstance(data["models"], list)

    def test_models_items_have_fields(self):
        # An empty model list is accepted; every model that is listed must
        # carry the expected keys (not just the first one).
        for m in self._models_json()["models"]:
            assert "id" in m
            assert "name" in m
            assert "model_variant" in m