File size: 4,350 Bytes
"""Tests for the prediction API: /predict, /predict-video, /health, and /models endpoints."""
from unittest.mock import MagicMock
import pytest
from fastapi.testclient import TestClient
from src.api import main as api_main
from src.api.state import get_state
# Keys that a successful /predict response body is expected to contain
# (checked as a subset, so extra keys in the response are tolerated).
PREDICT_RESPONSE_KEYS = {
    # Input echo and verdict.
    "text",
    "status",
    "is_toxic",
    "probability",
    "labels",
    # Request/serving metadata.
    "mode",
    "model_used",
    "latency_ms",
}
@pytest.fixture
def client():
    """Yield a ``TestClient`` whose shared API state uses a mocked service.

    The mock's ``predict`` always returns a fixed non-toxic result, so the
    tests using this fixture never touch a real model. After the test, the
    ``service`` and ``model_name`` state entries are reset to ``None``.
    """
    production_model = "Meta-Feature Stacking (Production)"

    fake_service = MagicMock()
    fake_service.predict.return_value = {
        "is_toxic": False,
        "probability": 0.12,
        "labels": [],
        "model_used": production_model,
    }

    # Applied in insertion order, after the TestClient context is entered.
    overrides = {
        "service": fake_service,
        "model_name": production_model,
        "predictions_served": 0,
        "startup_time": 0.0,
    }

    with TestClient(api_main.app) as test_client:
        shared = get_state()
        for key, value in overrides.items():
            shared[key] = value
        yield test_client

    # Teardown: drop the mock so it does not leak into other tests.
    shared = get_state()
    for key in ("service", "model_name"):
        shared[key] = None
def test_predict_returns_correct_structure(client: TestClient):
    """POST /predict answers 200 with all expected keys and sane values."""
    comment = "This is a sample comment"
    payload = {"text": comment, "threshold": 0.5}

    response = client.post("/predict", json=payload)

    assert response.status_code == 200
    body = response.json()
    # Subset check: the response may carry additional keys.
    assert PREDICT_RESPONSE_KEYS.issubset(body.keys())
    assert body["text"] == comment
    assert body["status"] == "Safe"
    assert body["mode"] == "binary"
    assert isinstance(body["is_toxic"], bool)
    assert 0.0 <= body["probability"] <= 1.0
def test_predict_rejects_empty_text(client: TestClient):
    """POST /predict with whitespace-only text is expected to return 422."""
    blank_payload = {"text": " "}
    result = client.post("/predict", json=blank_payload)
    assert result.status_code == 422
def test_health_includes_project_name(client: TestClient):
    """GET /health answers 200 and reports the expected project name."""
    health = client.get("/health")
    assert health.status_code == 200
    report = health.json()
    assert report["project"] == "youtube_hate_detector"
def test_predict_video_demo_comments_differ_by_url(client: TestClient, monkeypatch):
    """Without an API key, two different video URLs yield different demo comments.

    Both responses are expected to report ``source == "demo"`` and their first
    result must carry different text.
    """
    monkeypatch.delenv("YOUTUBE_API_KEY", raising=False)

    video_urls = (
        "https://www.youtube.com/watch?v=jNQXAC9IVRw",
        "https://www.youtube.com/watch?v=IEEhzQoKtQU",
    )
    # Issue both requests (in order) before asserting on either response.
    first, second = [
        client.post(
            "/predict-video",
            json={"url": url, "max_comments": 5, "threshold": 0.5},
        )
        for url in video_urls
    ]

    assert first.status_code == 200
    assert second.status_code == 200

    first_body = first.json()
    second_body = second.json()
    assert first_body["source"] == "demo"
    assert second_body["source"] == "demo"

    first_text = first_body["results"][0]["text"]
    second_text = second_body["results"][0]["text"]
    assert first_text != second_text
def test_catalog_has_demo_models():
    """``AVAILABLE_MODELS`` has exactly the three expected model names as keys."""
    from src.service.model_service import AVAILABLE_MODELS

    expected_names = {
        "Meta-Feature Stacking (Production)",
        "LR + TF-IDF (Baseline)",
        "Frozen Toxic-BERT (Baseline)",
    }
    catalog_names = set(AVAILABLE_MODELS.keys())
    assert catalog_names == expected_names
def test_select_model_via_post(client: TestClient):
    """POST /models/select answers 200 and echoes the chosen model name."""
    chosen = "LR + TF-IDF (Baseline)"
    selection = {"model_name": chosen}

    response = client.post("/models/select", json=selection)

    assert response.status_code == 200
    body = response.json()
    assert body["model"] == chosen
def test_models_status_lists_catalog(client: TestClient):
    """GET /models/status lists exactly the three expected model names."""
    expected_names = {
        "Meta-Feature Stacking (Production)",
        "LR + TF-IDF (Baseline)",
        "Frozen Toxic-BERT (Baseline)",
    }

    response = client.get("/models/status")
    assert response.status_code == 200

    payload = response.json()
    assert "models" in payload
    entries = payload["models"]
    assert len(entries) >= 1

    listed_names = set(entry["name"] for entry in entries)
    assert listed_names == expected_names
def test_predict_video_comments_disabled_raises_422(client: TestClient, monkeypatch):
    """A ``CommentsFetchError`` from the comment fetcher surfaces as HTTP 422.

    ``fetch_comments`` in the predict route module is replaced with a stub that
    always raises, and an API key is set in the environment for the request.
    """
    from src.api.youtube import CommentsFetchError

    def _fail_with_disabled_comments(*_args, **_kwargs):
        raise CommentsFetchError("Comments are disabled on this video")

    monkeypatch.setenv("YOUTUBE_API_KEY", "fake-key")
    monkeypatch.setattr(
        "src.api.routes.predict.fetch_comments",
        _fail_with_disabled_comments,
    )

    request_body = {
        "url": "https://www.youtube.com/watch?v=disabled123",
        "max_comments": 5,
        "threshold": 0.5,
    }
    response = client.post("/predict-video", json=request_body)

    assert response.status_code == 422
    detail = response.json()["detail"]
    assert "disabled" in detail.lower()
|