Spaces:
Running
Running
def test_health_returns_ok(api_client):
    """GET /health answers 200 with a JSON body whose ``status`` is "ok"."""
    response = api_client.get("/health")
    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "ok"
def test_labels_returns_four_classes(api_client):
    """GET /labels answers 200 and its ``labels`` list has exactly four entries."""
    response = api_client.get("/labels")
    assert response.status_code == 200
    label_list = response.json()["labels"]
    assert len(label_list) == 4
def test_predict_missing_model_returns_404(api_client):
    """POST /predict naming a model that does not exist answers 404."""
    request_body = {
        "text": "Some news article",
        "model_name": "nonexistent_model",
    }
    response = api_client.post("/predict", json=request_body)
    assert response.status_code == 404
def test_predict_empty_text_returns_422(api_client):
    """POST /predict with an empty ``text`` field is rejected with 422."""
    request_body = {"text": "", "model_name": "lr"}
    assert api_client.post("/predict", json=request_body).status_code == 422
def test_batch_predict_too_many_texts_returns_422(api_client):
    """POST /batch_predict with 257 texts is rejected with 422."""
    # NOTE(review): 257 presumably sits one above the server's batch-size cap
    # (256?) -- confirm against the request schema / config.
    oversized_batch = ["text" for _ in range(257)]
    response = api_client.post(
        "/batch_predict",
        json={"texts": oversized_batch, "model_name": "lr"},
    )
    assert response.status_code == 422
def test_analytics_summary_returns_valid_json(api_client):
    """GET /analytics/summary answers 200 with a JSON body containing ``total_requests``."""
    response = api_client.get("/analytics/summary")
    assert response.status_code == 200
    summary = response.json()
    assert "total_requests" in summary
def test_models_endpoint_lists_sklearn_model(api_client_with_models):
    """GET /models answers 200 and lists a model whose ``name`` is "lr"."""
    response = api_client_with_models.get("/models")
    assert response.status_code == 200
    available_names = set()
    for model_entry in response.json()["models"]:
        available_names.add(model_entry["name"])
    assert "lr" in available_names
def test_predict_lr_success_logs_request(api_client_with_models, test_db_path):
    """A successful "lr" prediction returns request metadata and is logged exactly once.

    The response must carry ``request_id``, ``is_low_confidence`` and
    ``latency_ms``; the request history read from ``test_db_path`` must then
    contain a single row whose ``request_id`` matches the response.
    """
    import database

    response = api_client_with_models.post(
        "/predict",
        json={"text": "Fed raises interest rates by 50 bps", "model_name": "lr"},
    )
    assert response.status_code == 200
    result = response.json()
    for required_key in ("request_id", "is_low_confidence", "latency_ms"):
        assert required_key in result

    logged_rows = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(logged_rows) == 1
    assert logged_rows[0]["request_id"] == result["request_id"]
def test_predict_lr_missing_joblib_returns_404(api_client):
    """POST /predict for "lr" answers 404 when using the plain ``api_client``.

    Presumably this fixture has no serialized (joblib) model artifacts
    available, per the test name -- confirm against the fixture definition.
    """
    request_body = {"text": "Fed raises rates", "model_name": "lr"}
    response = api_client.post("/predict", json=request_body)
    assert response.status_code == 404
def test_batch_predict_lr_logs_each_item(api_client_with_models, test_db_path):
    """An "lr" batch returns one prediction per text and logs one batch row per text.

    Every history row read from ``test_db_path`` must have ``is_batch`` equal
    to 1 (compared after ``int()`` conversion).
    """
    import database

    texts = ["Apple unveils new AI chip", "Team wins the final match in overtime"]
    expected_count = len(texts)
    response = api_client_with_models.post(
        "/batch_predict", json={"texts": texts, "model_name": "lr"}
    )
    assert response.status_code == 200
    result = response.json()
    assert result["count"] == expected_count
    assert len(result["predictions"]) == expected_count

    logged_rows = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(logged_rows) == expected_count
    for row in logged_rows:
        assert int(row["is_batch"]) == 1
def test_predict_svm_success_uses_decision_function(api_client_with_models):
    """An "svm" prediction succeeds, reports latency, and returns no probabilities.

    Per the test name, the SVM model presumably scores via a decision function
    rather than class probabilities, hence ``probabilities`` is ``None`` --
    confirm against the prediction handler.
    """
    response = api_client_with_models.post(
        "/predict",
        json={"text": "Championship match ends in overtime", "model_name": "svm"},
    )
    assert response.status_code == 200
    body = response.json()
    assert body["probabilities"] is None
    # Require the field itself: the previous ``body.get("latency_ms", 0.0)``
    # default let a response that omitted latency_ms pass silently.
    assert "latency_ms" in body
    assert float(body["latency_ms"]) >= 0.0
def test_batch_predict_svm_success(api_client_with_models):
    """An "svm" batch succeeds and returns one prediction per input text."""
    texts = ["Markets rise on earnings", "New chip released"]
    response = api_client_with_models.post(
        "/batch_predict",
        json={"texts": texts, "model_name": "svm"},
    )
    assert response.status_code == 200
    body = response.json()
    assert body["count"] == len(texts)
    # Mirror the lr batch test: the count field alone does not prove that a
    # prediction was actually produced for every text.
    assert len(body["predictions"]) == len(texts)
def test_predict_lr_cached_path(api_client_with_models, test_db_path):
    """Two identical "lr" requests both succeed and both appear in the history.

    Per the test name, the second call presumably goes through a prediction
    cache; this test only checks that both calls return 200 and are logged.
    """
    import database

    request_body = {"text": "Stocks fall on inflation data", "model_name": "lr"}
    first = api_client_with_models.post("/predict", json=request_body)
    second = api_client_with_models.post("/predict", json=request_body)
    for response in (first, second):
        assert response.status_code == 200

    logged_rows = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(logged_rows) == 2
def test_low_confidence_review_flow(api_client_with_models, test_db_path, monkeypatch):
    """A low-confidence prediction is flagged, can be reviewed, and is logged.

    The confidence threshold is raised to 0.99 so the prediction is flagged.
    The flag must appear in GET /analytics/low_confidence. After
    PATCH /analytics/review/{request_id} with a note, it must appear in the
    ``reviewed=true`` listing with ``reviewed == 1`` and the note stored.
    """
    import database
    from config import CFG

    # A near-1.0 threshold makes this prediction almost certainly "low confidence".
    monkeypatch.setattr(CFG, "low_confidence_threshold", 0.99)
    resp = api_client_with_models.post(
        "/predict",
        json={"text": "Mixed signals from markets after earnings report", "model_name": "lr"},
    )
    assert resp.status_code == 200
    request_id = resp.json()["request_id"]

    # Check status before iterating: an error body (a dict) would otherwise be
    # iterated key-by-key and fail with a confusing TypeError.
    flags_resp = api_client_with_models.get("/analytics/low_confidence")
    assert flags_resp.status_code == 200
    flags = flags_resp.json()
    assert any(f["request_id"] == request_id for f in flags)

    patch_resp = api_client_with_models.patch(
        f"/analytics/review/{request_id}", json={"note": "needs review"}
    )
    assert patch_resp.status_code == 200

    reviewed_resp = api_client_with_models.get("/analytics/low_confidence?reviewed=true")
    assert reviewed_resp.status_code == 200
    reviewed_flags = reviewed_resp.json()
    match = [f for f in reviewed_flags if f["request_id"] == request_id]
    assert match
    assert int(match[0]["reviewed"]) == 1
    assert match[0]["review_note"] == "needs review"

    # The flagged request must itself be in the request log, not just "some" row.
    history = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert any(row["request_id"] == request_id for row in history)
def test_export_flags_creates_files(api_client_with_models, tmp_path, monkeypatch):
    """POST /analytics/export_flags exports low-confidence flags as files under ``tmp_path``.

    The export function on the ``database`` module is wrapped so that its
    output is redirected into a temporary folder. The test then checks both the
    reported ``exported`` count and that the folder actually received files.
    """
    import database
    from config import CFG

    # A near-1.0 threshold makes the prediction below almost certainly flagged.
    monkeypatch.setattr(CFG, "low_confidence_threshold", 0.99)
    pred = api_client_with_models.post(
        "/predict", json={"text": "Unclear headline with mixed topics", "model_name": "lr"}
    )
    assert pred.status_code == 200

    out_dir = tmp_path / "low_confidence_review"
    original = database.export_low_confidence_to_folder

    def _export_override(**kwargs):
        # Keep any keyword arguments the endpoint passes (e.g. a db path) and
        # force only the destination folder into the temp directory.
        kwargs["output_dir"] = str(out_dir)
        return original(**kwargs)

    # NOTE(review): this patch only takes effect if the API resolves the function
    # through the ``database`` module at call time. If it was imported by name,
    # the export goes to the default folder instead -- the directory checks
    # below will catch that.
    monkeypatch.setattr(database, "export_low_confidence_to_folder", _export_override)
    resp = api_client_with_models.post("/analytics/export_flags")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["exported"] >= 1
    assert out_dir.is_dir()
    assert any(out_dir.iterdir())
def test_health_with_startup_runs(api_client_with_startup):
    """GET /health answers 200 when using the ``api_client_with_startup`` fixture.

    Per the fixture name, this client presumably runs the application's
    startup logic first -- confirm against the fixture definition.
    """
    status = api_client_with_startup.get("/health").status_code
    assert status == 200