# File size: 5,984 Bytes
# a229747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def test_health_returns_ok(api_client):
    """GET /health answers 200 with a JSON body whose ``status`` is ``"ok"``."""
    response = api_client.get("/health")
    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "ok"


def test_labels_returns_four_classes(api_client):
    """GET /labels answers 200 and lists exactly four class labels."""
    response = api_client.get("/labels")
    assert response.status_code == 200
    labels = response.json()["labels"]
    assert len(labels) == 4


def test_predict_missing_model_returns_404(api_client):
    """POST /predict with an unknown ``model_name`` is rejected with 404."""
    request_body = {"text": "Some news article", "model_name": "nonexistent_model"}
    response = api_client.post("/predict", json=request_body)
    assert response.status_code == 404


def test_predict_empty_text_returns_422(api_client):
    """An empty ``text`` field on /predict is rejected with 422."""
    request_body = {"text": "", "model_name": "lr"}
    response = api_client.post("/predict", json=request_body)
    assert response.status_code == 422


def test_batch_predict_too_many_texts_returns_422(api_client):
    """A /batch_predict request carrying 257 texts is rejected with 422.

    NOTE(review): 257 presumably sits one past a 256-item batch limit in the
    request schema — confirm against the API's request model.
    """
    oversized_batch = ["text"] * 257
    response = api_client.post(
        "/batch_predict", json={"texts": oversized_batch, "model_name": "lr"}
    )
    assert response.status_code == 422


def test_analytics_summary_returns_valid_json(api_client):
    """GET /analytics/summary answers 200 with a JSON body containing ``total_requests``."""
    response = api_client.get("/analytics/summary")
    assert response.status_code == 200
    summary = response.json()
    assert "total_requests" in summary


def test_models_endpoint_lists_sklearn_model(api_client_with_models):
    """GET /models answers 200 and ``lr`` is among the listed model names."""
    response = api_client_with_models.get("/models")
    assert response.status_code == 200
    listed = response.json()["models"]
    assert "lr" in {entry["name"] for entry in listed}


def test_predict_lr_success_logs_request(api_client_with_models, test_db_path):
    """A successful ``lr`` prediction returns metadata and is recorded in the request log.

    The response must carry ``request_id``, ``is_low_confidence`` and
    ``latency_ms``. Afterwards the request history (first 10 rows) holds
    exactly one entry, whose ``request_id`` matches the response.
    """
    import database

    headline = "Fed raises interest rates by 50 bps"
    response = api_client_with_models.post(
        "/predict", json={"text": headline, "model_name": "lr"}
    )
    assert response.status_code == 200
    body = response.json()
    for required_key in ("request_id", "is_low_confidence", "latency_ms"):
        assert required_key in body

    logged = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(logged) == 1
    assert logged[0]["request_id"] == body["request_id"]


def test_predict_lr_missing_joblib_returns_404(api_client):
    """With the plain client fixture, predicting with ``lr`` yields 404.

    NOTE(review): relies on ``api_client`` not providing the ``lr`` model
    artifact (the test name suggests a missing .joblib file) — confirm in conftest.
    """
    request_body = {"text": "Fed raises rates", "model_name": "lr"}
    response = api_client.post("/predict", json=request_body)
    assert response.status_code == 404


def test_batch_predict_lr_logs_each_item(api_client_with_models, test_db_path):
    """/batch_predict with ``lr`` returns one prediction per text and logs each one.

    Every logged row must be marked as a batch request (``is_batch == 1``).
    """
    import database

    texts = ["Apple unveils new AI chip", "Team wins the final match in overtime"]
    response = api_client_with_models.post(
        "/batch_predict", json={"texts": texts, "model_name": "lr"}
    )
    assert response.status_code == 200
    body = response.json()
    assert body["count"] == len(texts)
    assert len(body["predictions"]) == len(texts)

    rows = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(rows) == len(texts)
    # int() cast: the stored flag's Python type is not fixed here (int or str).
    assert all(int(row["is_batch"]) == 1 for row in rows)


def test_predict_svm_success_uses_decision_function(api_client_with_models):
    """An ``svm`` prediction succeeds and reports no class probabilities.

    ``probabilities`` must be ``None``. Presumably this is because the SVM path
    scores with ``decision_function`` instead of ``predict_proba``; confirm
    in the serving code. ``latency_ms`` must be present and non-negative.
    """
    resp = api_client_with_models.post(
        "/predict", json={"text": "Championship match ends in overtime", "model_name": "svm"}
    )
    assert resp.status_code == 200
    body = resp.json()
    assert body["probabilities"] is None
    # Require the key explicitly: a .get() default of 0.0 would let a
    # response missing latency_ms pass this check silently.
    assert "latency_ms" in body
    assert float(body["latency_ms"]) >= 0.0


def test_batch_predict_svm_success(api_client_with_models):
    """/batch_predict with ``svm`` answers 200 with one prediction per input text."""
    texts = ["Markets rise on earnings", "New chip released"]
    resp = api_client_with_models.post(
        "/batch_predict",
        json={"texts": texts, "model_name": "svm"},
    )
    assert resp.status_code == 200
    body = resp.json()
    assert body["count"] == len(texts)
    # Mirror the lr batch test: ``count`` alone would not catch a short or
    # missing ``predictions`` list.
    assert len(body["predictions"]) == len(texts)


def test_predict_lr_cached_path(api_client_with_models, test_db_path):
    """Two identical ``lr`` predictions both succeed and both are logged.

    NOTE(review): the name implies the second call is served from a cache.
    This test only checks that both calls return 200 and that two requests
    appear in the history.
    """
    import database

    headline = "Stocks fall on inflation data"
    responses = [
        api_client_with_models.post(
            "/predict", json={"text": headline, "model_name": "lr"}
        )
        for _ in range(2)
    ]
    assert all(r.status_code == 200 for r in responses)

    logged = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(logged) == 2


def test_low_confidence_review_flow(api_client_with_models, test_db_path, monkeypatch):
    """Walk the low-confidence flag from creation through review.

    The steps, with the threshold raised to 0.99:

    - The ``lr`` prediction must appear in /analytics/low_confidence.
    - It is reviewed via PATCH /analytics/review/{request_id}.
    - It must then be listed under ``?reviewed=true`` with ``reviewed == 1``
      and the supplied note.
    - The request itself must be in the request history.
    """
    import database
    from config import CFG

    # Raise the threshold so this prediction is expected to be flagged.
    monkeypatch.setattr(CFG, "low_confidence_threshold", 0.99)
    resp = api_client_with_models.post(
        "/predict", json={"text": "Mixed signals from markets after earnings report", "model_name": "lr"}
    )
    assert resp.status_code == 200
    request_id = resp.json()["request_id"]

    flags_resp = api_client_with_models.get("/analytics/low_confidence")
    assert flags_resp.status_code == 200
    flags = flags_resp.json()
    assert any(f["request_id"] == request_id for f in flags)

    patch_resp = api_client_with_models.patch(
        f"/analytics/review/{request_id}", json={"note": "needs review"}
    )
    assert patch_resp.status_code == 200

    reviewed_resp = api_client_with_models.get("/analytics/low_confidence?reviewed=true")
    assert reviewed_resp.status_code == 200
    match = [f for f in reviewed_resp.json() if f["request_id"] == request_id]
    assert match
    assert int(match[0]["reviewed"]) == 1
    assert match[0]["review_note"] == "needs review"

    # Check that *this* request was logged, not merely that history is non-empty.
    history = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert any(h["request_id"] == request_id for h in history)


def test_export_flags_creates_files(api_client_with_models, tmp_path, monkeypatch):
    """POST /analytics/export_flags exports flagged requests into a folder.

    The threshold is raised to 0.99 so the seeded ``lr`` prediction is
    expected to be flagged. The database exporter is wrapped so its output
    goes to a per-test directory under ``tmp_path``. The test asserts that
    at least one item is reported exported and that the directory received
    files.
    """
    import database
    from config import CFG

    monkeypatch.setattr(CFG, "low_confidence_threshold", 0.99)
    seed = api_client_with_models.post(
        "/predict", json={"text": "Unclear headline with mixed topics", "model_name": "lr"}
    )
    # Fail here, not at the export assertions, if seeding itself broke.
    assert seed.status_code == 200

    out_dir = tmp_path / "low_confidence_review"
    original = database.export_low_confidence_to_folder

    def _export_override(*args, **kwargs):
        # Forward whatever the endpoint passes, but always redirect the
        # output into the temporary directory.
        kwargs["output_dir"] = str(out_dir)
        return original(*args, **kwargs)

    # NOTE(review): this only takes effect if the API resolves the exporter as
    # ``database.export_low_confidence_to_folder`` at call time — confirm.
    monkeypatch.setattr(database, "export_low_confidence_to_folder", _export_override)
    resp = api_client_with_models.post("/analytics/export_flags")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["exported"] >= 1
    # The test name promises files: verify something actually landed in out_dir.
    assert out_dir.is_dir()
    assert any(out_dir.iterdir())


def test_health_with_startup_runs(api_client_with_startup):
    """GET /health answers 200 through the ``api_client_with_startup`` fixture.

    NOTE(review): the fixture presumably runs the app's startup hooks — confirm in conftest.
    """
    health = api_client_with_startup.get("/health")
    assert health.status_code == 200