Spaces:
Running
Running
File size: 5,984 Bytes
a229747 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 | def test_health_returns_ok(api_client):
resp = api_client.get("/health")
assert resp.status_code == 200
assert resp.json()["status"] == "ok"
def test_labels_returns_four_classes(api_client):
resp = api_client.get("/labels")
assert resp.status_code == 200
assert len(resp.json()["labels"]) == 4
def test_predict_missing_model_returns_404(api_client):
resp = api_client.post(
"/predict", json={"text": "Some news article", "model_name": "nonexistent_model"}
)
assert resp.status_code == 404
def test_predict_empty_text_returns_422(api_client):
resp = api_client.post("/predict", json={"text": "", "model_name": "lr"})
assert resp.status_code == 422
def test_batch_predict_too_many_texts_returns_422(api_client):
resp = api_client.post(
"/batch_predict", json={"texts": ["text"] * 257, "model_name": "lr"}
)
assert resp.status_code == 422
def test_analytics_summary_returns_valid_json(api_client):
resp = api_client.get("/analytics/summary")
assert resp.status_code == 200
assert "total_requests" in resp.json()
def test_models_endpoint_lists_sklearn_model(api_client_with_models):
resp = api_client_with_models.get("/models")
assert resp.status_code == 200
payload = resp.json()
names = {m["name"] for m in payload["models"]}
assert "lr" in names
def test_predict_lr_success_logs_request(api_client_with_models, test_db_path):
import database
resp = api_client_with_models.post(
"/predict", json={"text": "Fed raises interest rates by 50 bps", "model_name": "lr"}
)
assert resp.status_code == 200
body = resp.json()
assert "request_id" in body
assert "is_low_confidence" in body
assert "latency_ms" in body
history = database.get_request_history(db_path=str(test_db_path), limit=10)
assert len(history) == 1
assert history[0]["request_id"] == body["request_id"]
def test_predict_lr_missing_joblib_returns_404(api_client):
resp = api_client.post(
"/predict", json={"text": "Fed raises rates", "model_name": "lr"}
)
assert resp.status_code == 404
def test_batch_predict_lr_logs_each_item(api_client_with_models, test_db_path):
import database
texts = ["Apple unveils new AI chip", "Team wins the final match in overtime"]
resp = api_client_with_models.post(
"/batch_predict", json={"texts": texts, "model_name": "lr"}
)
assert resp.status_code == 200
body = resp.json()
assert body["count"] == 2
assert len(body["predictions"]) == 2
history = database.get_request_history(db_path=str(test_db_path), limit=10)
assert len(history) == 2
assert all(int(r["is_batch"]) == 1 for r in history)
def test_predict_svm_success_uses_decision_function(api_client_with_models):
resp = api_client_with_models.post(
"/predict", json={"text": "Championship match ends in overtime", "model_name": "svm"}
)
assert resp.status_code == 200
body = resp.json()
assert body["probabilities"] is None
assert 0.0 <= float(body.get("latency_ms", 0.0))
def test_batch_predict_svm_success(api_client_with_models):
resp = api_client_with_models.post(
"/batch_predict",
json={"texts": ["Markets rise on earnings", "New chip released"], "model_name": "svm"},
)
assert resp.status_code == 200
body = resp.json()
assert body["count"] == 2
def test_predict_lr_cached_path(api_client_with_models, test_db_path):
import database
r1 = api_client_with_models.post(
"/predict", json={"text": "Stocks fall on inflation data", "model_name": "lr"}
)
r2 = api_client_with_models.post(
"/predict", json={"text": "Stocks fall on inflation data", "model_name": "lr"}
)
assert r1.status_code == 200
assert r2.status_code == 200
history = database.get_request_history(db_path=str(test_db_path), limit=10)
assert len(history) == 2
def test_low_confidence_review_flow(api_client_with_models, test_db_path, monkeypatch):
    """A flagged prediction can be reviewed and then shows up as reviewed.

    The test raises ``CFG.low_confidence_threshold`` to 0.99 so the
    prediction is (presumably) flagged as low-confidence. It checks the flag
    appears in /analytics/low_confidence, then PATCHes a review note. Finally
    it reads the flag back with ``reviewed=true`` and checks the reviewed
    marker and note.
    """
    import database
    from config import CFG

    monkeypatch.setattr(CFG, "low_confidence_threshold", 0.99)
    client = api_client_with_models

    predict_resp = client.post(
        "/predict",
        json={"text": "Mixed signals from markets after earnings report", "model_name": "lr"},
    )
    assert predict_resp.status_code == 200
    request_id = predict_resp.json()["request_id"]

    open_flags = client.get("/analytics/low_confidence").json()
    assert any(flag["request_id"] == request_id for flag in open_flags)

    review_resp = client.patch(
        f"/analytics/review/{request_id}", json={"note": "needs review"}
    )
    assert review_resp.status_code == 200

    reviewed = [
        flag
        for flag in client.get("/analytics/low_confidence?reviewed=true").json()
        if flag["request_id"] == request_id
    ]
    assert reviewed
    first = reviewed[0]
    assert int(first["reviewed"]) == 1
    assert first["review_note"] == "needs review"

    assert database.get_request_history(db_path=str(test_db_path), limit=10)
def test_export_flags_creates_files(api_client_with_models, tmp_path, monkeypatch):
    """Exporting low-confidence flags writes them into a folder.

    The test makes a prediction with ``CFG.low_confidence_threshold`` raised
    to 0.99, so it is (presumably) flagged. It then redirects
    ``database.export_low_confidence_to_folder`` to a temporary directory and
    calls ``POST /analytics/export_flags``. At least one flag must be reported
    as exported, and the folder must exist and be non-empty.
    """
    import database
    from config import CFG

    monkeypatch.setattr(CFG, "low_confidence_threshold", 0.99)
    predict_resp = api_client_with_models.post(
        "/predict", json={"text": "Unclear headline with mixed topics", "model_name": "lr"}
    )
    # Fail here, not at the export count, if the prediction itself broke.
    assert predict_resp.status_code == 200

    out_dir = tmp_path / "low_confidence_review"
    original = database.export_low_confidence_to_folder

    def _export_override(*args, **kwargs):
        # Forward whatever the endpoint passes and only force the destination.
        # A zero-argument override raises TypeError as soon as the endpoint
        # passes any argument.
        kwargs["output_dir"] = str(out_dir)
        return original(*args, **kwargs)

    monkeypatch.setattr(database, "export_low_confidence_to_folder", _export_override)

    resp = api_client_with_models.post("/analytics/export_flags")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["exported"] >= 1
    # The test name promises files on disk, so check the redirected folder.
    assert out_dir.is_dir()
    assert any(out_dir.iterdir())
def test_health_with_startup_runs(api_client_with_startup):
resp = api_client_with_startup.get("/health")
assert resp.status_code == 200
|