| """Tests for POST /predict.""" |
|
|
| from unittest.mock import MagicMock |
|
|
| import pytest |
| from fastapi.testclient import TestClient |
|
|
| from src.api import main as api_main |
| from src.api.state import get_state |
|
|
# Keys every successful POST /predict response is expected to contain. Tests
# check this as a subset, so the API may return additional fields. A frozenset
# keeps this shared module-level constant immutable across tests.
PREDICT_RESPONSE_KEYS = frozenset(
    {
        "text",
        "is_toxic",
        "probability",
        "status",
        "mode",
        "labels",
        "model_used",
        "latency_ms",
    }
)
|
|
|
|
@pytest.fixture
def client():
    """Yield a TestClient whose shared app state uses a mocked model service.

    Entering ``TestClient`` as a context manager runs the app's startup and
    shutdown handlers. After startup, the shared state returned by
    ``get_state()`` is overridden so requests hit ``mock_service``. That mock's
    ``predict`` always returns a non-toxic result with probability 0.12.

    On teardown, ``service`` and ``model_name`` are reset to ``None``. The
    reset sits in a ``finally`` so it still runs if app shutdown raises.
    """
    model_name = "Meta-Feature Stacking (Production)"

    mock_service = MagicMock()
    mock_service.predict.return_value = {
        "is_toxic": False,
        "probability": 0.12,
        "labels": [],
        "model_used": model_name,
    }

    try:
        with TestClient(api_main.app) as test_client:
            # Override after startup so whatever startup may have placed in
            # the state is replaced by the mock for the duration of the test.
            state = get_state()
            state["service"] = mock_service
            state["model_name"] = model_name
            state["predictions_served"] = 0
            state["startup_time"] = 0.0
            yield test_client
    finally:
        # Always clear the mock from the shared state so it cannot leak into
        # later tests, even if the TestClient's shutdown raised.
        state = get_state()
        state["service"] = None
        state["model_name"] = None
|
|
|
|
def test_predict_returns_correct_structure(client: TestClient):
    """POST /predict echoes the text and returns a complete binary-mode payload.

    The fixture's mock service reports a non-toxic result, so the status is
    expected to be "Safe".
    """
    sample = "This is a sample comment"
    resp = client.post("/predict", json={"text": sample, "threshold": 0.5})

    assert resp.status_code == 200
    body = resp.json()
    # Subset check: the response may carry extra fields beyond the required ones.
    assert PREDICT_RESPONSE_KEYS.issubset(body.keys())
    assert body["text"] == sample
    assert body["status"] == "Safe"
    assert body["mode"] == "binary"
    assert isinstance(body["is_toxic"], bool)
    assert 0.0 <= body["probability"] <= 1.0
|
|
|
|
def test_predict_rejects_empty_text(client: TestClient):
    """Whitespace-only text is rejected with HTTP 422 (Unprocessable Entity)."""
    blank_payload = {"text": " "}
    result = client.post("/predict", json=blank_payload)
    assert result.status_code == 422
|
|
|
|
def test_health_includes_project_name(client: TestClient):
    """GET /health succeeds and reports the project identifier."""
    health = client.get("/health")
    assert health.status_code == 200
    payload = health.json()
    assert payload["project"] == "youtube_hate_detector"
|
|
|
|
def test_predict_video_demo_comments_differ_by_url(client: TestClient, monkeypatch):
    """Without a YouTube API key, /predict-video serves URL-dependent demo comments.

    Two different video URLs must both report the ``"demo"`` source and return
    different first comments, which shows the demo data depends on the URL.
    """
    # Remove any real key from the environment so the demo path is exercised.
    monkeypatch.delenv("YOUTUBE_API_KEY", raising=False)

    urls = (
        "https://www.youtube.com/watch?v=jNQXAC9IVRw",
        "https://www.youtube.com/watch?v=IEEhzQoKtQU",
    )
    first_texts = []
    for url in urls:
        response = client.post(
            "/predict-video",
            json={"url": url, "max_comments": 5, "threshold": 0.5},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["source"] == "demo"
        # Fail with a clear message rather than an IndexError if nothing came back.
        assert data["results"], f"no demo comments returned for {url}"
        first_texts.append(data["results"][0]["text"])

    assert first_texts[0] != first_texts[1]
|
|
|
|
def test_catalog_has_demo_models():
    """The model catalog exposes exactly the three demo models."""
    from src.service.model_service import AVAILABLE_MODELS

    expected = {
        "LR + TF-IDF (Baseline)",
        "Frozen Toxic-BERT (Baseline)",
        "Meta-Feature Stacking (Production)",
    }
    catalog_names = set(AVAILABLE_MODELS.keys())
    assert catalog_names == expected
|
|
|
|
def test_select_model_via_post(client: TestClient):
    """POST /models/select accepts a catalog model name and echoes it as ``model``."""
    target = "LR + TF-IDF (Baseline)"
    reply = client.post("/models/select", json={"model_name": target})
    assert reply.status_code == 200
    assert reply.json()["model"] == target
|
|
|
|
def test_models_status_lists_catalog(client: TestClient):
    """GET /models/status lists every catalog model exactly once."""
    response = client.get("/models/status")
    assert response.status_code == 200
    data = response.json()
    assert "models" in data
    # Compare sorted lists instead of sets so a duplicated entry fails the
    # test rather than being silently collapsed. This also subsumes the old
    # "at least one model" check.
    names = sorted(m["name"] for m in data["models"])
    assert names == sorted(
        [
            "Meta-Feature Stacking (Production)",
            "LR + TF-IDF (Baseline)",
            "Frozen Toxic-BERT (Baseline)",
        ]
    )
|
|
|
|
def test_predict_video_comments_disabled_raises_422(client: TestClient, monkeypatch):
    """A CommentsFetchError from the comment fetcher surfaces as HTTP 422.

    The response ``detail`` is expected to mention that comments are disabled.
    """
    from src.api.youtube import CommentsFetchError

    # Provide a (fake) key; presumably this makes the route attempt a live
    # fetch instead of serving demo comments.
    monkeypatch.setenv("YOUTUBE_API_KEY", "fake-key")

    def _comments_disabled(*_args, **_kwargs):
        raise CommentsFetchError("Comments are disabled on this video")

    # Patch the name as imported into the route module, which is the
    # reference the route actually calls.
    monkeypatch.setattr("src.api.routes.predict.fetch_comments", _comments_disabled)

    request_body = {
        "url": "https://www.youtube.com/watch?v=disabled123",
        "max_comments": 5,
        "threshold": 0.5,
    }
    response = client.post("/predict-video", json=request_body)
    assert response.status_code == 422
    detail = response.json()["detail"]
    assert "disabled" in detail.lower()
|
|