File size: 4,350 Bytes
e317d56
447c4a0
 
 
 
 
 
 
e317d56
447c4a0
 
 
 
 
e317d56
 
447c4a0
 
 
 
 
 
 
 
 
 
 
 
 
46cc63a
447c4a0
 
 
e317d56
 
46cc63a
e317d56
 
447c4a0
 
e317d56
 
 
447c4a0
 
 
 
 
 
 
 
 
 
 
 
e317d56
 
447c4a0
 
 
 
 
 
e317d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447c4a0
46cc63a
 
0f0ce9b
46cc63a
 
 
 
 
0f0ce9b
 
 
 
 
46cc63a
0f0ce9b
 
46cc63a
0f0ce9b
 
e317d56
 
 
 
 
 
 
46cc63a
 
 
 
 
e317d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447c4a0
e317d56
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Tests for the prediction API: /predict, /predict-video, /health and /models."""

from unittest.mock import MagicMock

import pytest
from fastapi.testclient import TestClient

from src.api import main as api_main
from src.api.state import get_state

# Keys that a /predict response body must contain. The structure test checks
# this as a subset, so responses may carry additional keys.
PREDICT_RESPONSE_KEYS = {
    "is_toxic",
    "labels",
    "latency_ms",
    "mode",
    "model_used",
    "probability",
    "status",
    "text",
}


@pytest.fixture
def client():
    """Yield a ``TestClient`` whose shared API state uses a mocked service.

    The mock's ``predict`` always returns one fixed, non-toxic result. The
    state entries are overridden only after the client context has been
    entered, and the ``service`` / ``model_name`` entries are set back to
    ``None`` once the client context has exited.
    """
    canned_prediction = {
        "is_toxic": False,
        "probability": 0.12,
        "labels": [],
        "model_used": "Meta-Feature Stacking (Production)",
    }
    stub_service = MagicMock()
    stub_service.predict.return_value = canned_prediction

    state_overrides = {
        "service": stub_service,
        "model_name": "Meta-Feature Stacking (Production)",
        "predictions_served": 0,
        "startup_time": 0.0,
    }

    with TestClient(api_main.app) as api_client:
        # Applied inside the context so the overrides are written after the
        # client has been entered.
        app_state = get_state()
        for key, value in state_overrides.items():
            app_state[key] = value
        yield api_client

    # Teardown: fetch the state again and clear the injected service/model.
    app_state = get_state()
    for key in ("service", "model_name"):
        app_state[key] = None


def test_predict_returns_correct_structure(client: TestClient):
    """POST /predict answers 200 with the required keys and expected values."""
    comment = "This is a sample comment"
    payload = {"text": comment, "threshold": 0.5}

    response = client.post("/predict", json=payload)

    assert response.status_code == 200
    body = response.json()
    # Subset check: extra keys in the response are tolerated.
    assert PREDICT_RESPONSE_KEYS.issubset(body.keys())
    assert body["text"] == comment
    assert body["status"] == "Safe"
    assert body["mode"] == "binary"
    assert isinstance(body["is_toxic"], bool)
    probability = body["probability"]
    assert 0.0 <= probability <= 1.0


def test_predict_rejects_empty_text(client: TestClient):
    """A whitespace-only ``text`` is rejected by /predict with HTTP 422."""
    blank_payload = {"text": "   "}
    result = client.post("/predict", json=blank_payload)
    assert result.status_code == 422


def test_health_includes_project_name(client: TestClient):
    """GET /health answers 200 and reports the expected ``project`` value."""
    health = client.get("/health")
    assert health.status_code == 200
    payload = health.json()
    assert payload["project"] == "youtube_hate_detector"


def test_predict_video_demo_comments_differ_by_url(client: TestClient, monkeypatch):
    """With YOUTUBE_API_KEY unset, two video URLs yield different demo comments.

    Both responses must be 200 with ``source == "demo"``, and the text of the
    first result must differ between the two videos.
    """
    monkeypatch.delenv("YOUTUBE_API_KEY", raising=False)

    video_urls = (
        "https://www.youtube.com/watch?v=jNQXAC9IVRw",
        "https://www.youtube.com/watch?v=IEEhzQoKtQU",
    )
    # Issue both requests up front, before any assertion runs.
    first, second = [
        client.post(
            "/predict-video",
            json={"url": video_url, "max_comments": 5, "threshold": 0.5},
        )
        for video_url in video_urls
    ]

    assert first.status_code == 200
    assert second.status_code == 200
    first_body = first.json()
    second_body = second.json()
    assert first_body["source"] == "demo"
    assert second_body["source"] == "demo"
    first_text = first_body["results"][0]["text"]
    second_text = second_body["results"][0]["text"]
    assert first_text != second_text


def test_catalog_has_demo_models():
    """``AVAILABLE_MODELS`` is keyed by exactly the three expected model names."""
    from src.service.model_service import AVAILABLE_MODELS

    expected_names = {
        "Meta-Feature Stacking (Production)",
        "LR + TF-IDF (Baseline)",
        "Frozen Toxic-BERT (Baseline)",
    }
    catalog_names = set(AVAILABLE_MODELS.keys())
    assert catalog_names == expected_names


def test_select_model_via_post(client: TestClient):
    """POST /models/select answers 200 and returns the chosen name as ``model``."""
    chosen_model = "LR + TF-IDF (Baseline)"
    selection = client.post("/models/select", json={"model_name": chosen_model})
    assert selection.status_code == 200
    assert selection.json()["model"] == chosen_model


def test_models_status_lists_catalog(client: TestClient):
    """GET /models/status returns a ``models`` list naming the expected models."""
    status = client.get("/models/status")
    assert status.status_code == 200
    payload = status.json()
    assert "models" in payload
    listed_models = payload["models"]
    assert len(listed_models) >= 1

    # Collect the distinct names reported by the endpoint.
    names = set()
    for entry in listed_models:
        names.add(entry["name"])

    expected_names = {
        "Meta-Feature Stacking (Production)",
        "LR + TF-IDF (Baseline)",
        "Frozen Toxic-BERT (Baseline)",
    }
    assert names == expected_names


def test_predict_video_comments_disabled_raises_422(client: TestClient, monkeypatch):
    """A ``CommentsFetchError`` from ``fetch_comments`` yields 422 with its message.

    An API key is set in the environment and the route module's
    ``fetch_comments`` is replaced with a stub that always raises.
    """
    from src.api.youtube import CommentsFetchError

    def _fail_with_disabled_comments(*_args, **_kwargs):
        raise CommentsFetchError("Comments are disabled on this video")

    monkeypatch.setenv("YOUTUBE_API_KEY", "fake-key")
    monkeypatch.setattr(
        "src.api.routes.predict.fetch_comments",
        _fail_with_disabled_comments,
    )

    request_body = {
        "url": "https://www.youtube.com/watch?v=disabled123",
        "max_comments": 5,
        "threshold": 0.5,
    }
    response = client.post("/predict-video", json=request_body)

    assert response.status_code == 422
    detail = response.json()["detail"]
    assert "disabled" in detail.lower()