def test_health_returns_ok(api_client):
    """GET /health answers 200 with a JSON ``status`` of ``"ok"``."""
    response = api_client.get("/health")

    assert response.status_code == 200
    assert response.json()["status"] == "ok"


def test_labels_returns_four_classes(api_client):
    """GET /labels answers 200 and lists exactly four labels."""
    response = api_client.get("/labels")

    assert response.status_code == 200
    assert len(response.json()["labels"]) == 4


def test_predict_missing_model_returns_404(api_client):
    """POST /predict with an unknown model name answers 404."""
    request_body = {"text": "Some news article", "model_name": "nonexistent_model"}

    response = api_client.post("/predict", json=request_body)

    assert response.status_code == 404


def test_predict_empty_text_returns_422(api_client):
    """POST /predict with an empty ``text`` answers 422."""
    request_body = {"text": "", "model_name": "lr"}

    response = api_client.post("/predict", json=request_body)

    assert response.status_code == 422


def test_batch_predict_too_many_texts_returns_422(api_client):
    """POST /batch_predict with 257 texts answers 422."""
    request_body = {"texts": ["text"] * 257, "model_name": "lr"}

    response = api_client.post("/batch_predict", json=request_body)

    assert response.status_code == 422


def test_analytics_summary_returns_valid_json(api_client):
    """GET /analytics/summary answers 200 and includes ``total_requests``."""
    response = api_client.get("/analytics/summary")

    assert response.status_code == 200
    assert "total_requests" in response.json()


def test_models_endpoint_lists_sklearn_model(api_client_with_models):
    """GET /models answers 200 and one listed model is named ``lr``."""
    response = api_client_with_models.get("/models")

    assert response.status_code == 200
    listing = response.json()
    model_names = {entry["name"] for entry in listing["models"]}
    assert "lr" in model_names


def test_predict_lr_success_logs_request(api_client_with_models, test_db_path):
    """A successful ``lr`` prediction is recorded once in the request history."""
    import database

    request_body = {
        "text": "Fed raises interest rates by 50 bps",
        "model_name": "lr",
    }
    response = api_client_with_models.post("/predict", json=request_body)

    assert response.status_code == 200
    result = response.json()
    assert "request_id" in result
    assert "is_low_confidence" in result
    assert "latency_ms" in result

    # The single logged row must carry the id returned to the client.
    logged = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(logged) == 1
    assert logged[0]["request_id"] == result["request_id"]


def test_predict_lr_missing_joblib_returns_404(api_client):
    """POST /predict for ``lr`` through the plain ``api_client`` answers 404.

    NOTE(review): presumably ``api_client`` has no model artifacts on disk;
    confirm against the fixture definition.
    """
    request_body = {"text": "Fed raises rates", "model_name": "lr"}

    response = api_client.post("/predict", json=request_body)

    assert response.status_code == 404


def test_batch_predict_lr_logs_each_item(api_client_with_models, test_db_path):
    """A two-text ``lr`` batch yields two predictions and two batch-flagged rows."""
    import database

    texts = ["Apple unveils new AI chip", "Team wins the final match in overtime"]
    request_body = {"texts": texts, "model_name": "lr"}
    response = api_client_with_models.post("/batch_predict", json=request_body)

    assert response.status_code == 200
    result = response.json()
    assert result["count"] == 2
    assert len(result["predictions"]) == 2

    logged = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(logged) == 2
    assert all(int(row["is_batch"]) == 1 for row in logged)


def test_predict_svm_success_uses_decision_function(api_client_with_models):
    """An ``svm`` prediction succeeds and reports ``probabilities`` as ``None``."""
    request_body = {
        "text": "Championship match ends in overtime",
        "model_name": "svm",
    }
    response = api_client_with_models.post("/predict", json=request_body)

    assert response.status_code == 200
    result = response.json()
    assert result["probabilities"] is None
    # A missing latency falls back to 0.0, so only negative values fail here.
    assert float(result.get("latency_ms", 0.0)) >= 0.0


def test_batch_predict_svm_success(api_client_with_models):
    """A two-text ``svm`` batch succeeds and reports a count of two."""
    request_body = {
        "texts": ["Markets rise on earnings", "New chip released"],
        "model_name": "svm",
    }
    response = api_client_with_models.post("/batch_predict", json=request_body)

    assert response.status_code == 200
    result = response.json()
    assert result["count"] == 2


def test_predict_lr_cached_path(api_client_with_models, test_db_path):
    """Two identical ``lr`` requests both succeed and are each logged."""
    import database

    request_body = {"text": "Stocks fall on inflation data", "model_name": "lr"}
    first = api_client_with_models.post("/predict", json=request_body)
    second = api_client_with_models.post("/predict", json=request_body)

    assert first.status_code == 200
    assert second.status_code == 200

    logged = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(logged) == 2


def test_low_confidence_review_flow(api_client_with_models, test_db_path, monkeypatch):
    """A flagged prediction can be listed, reviewed, and read back as reviewed."""
    import database
    from config import CFG

    # Raise the threshold to 0.99 for the duration of this test.
    monkeypatch.setattr(CFG, "low_confidence_threshold", 0.99)

    request_body = {
        "text": "Mixed signals from markets after earnings report",
        "model_name": "lr",
    }
    response = api_client_with_models.post("/predict", json=request_body)
    assert response.status_code == 200
    request_id = response.json()["request_id"]

    pending = api_client_with_models.get("/analytics/low_confidence").json()
    assert any(flag["request_id"] == request_id for flag in pending)

    review_response = api_client_with_models.patch(
        f"/analytics/review/{request_id}",
        json={"note": "needs review"},
    )
    assert review_response.status_code == 200

    reviewed = api_client_with_models.get(
        "/analytics/low_confidence?reviewed=true"
    ).json()
    ours = [flag for flag in reviewed if flag["request_id"] == request_id]
    assert ours
    assert int(ours[0]["reviewed"]) == 1
    assert ours[0]["review_note"] == "needs review"

    logged = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert logged


def test_export_flags_creates_files(api_client_with_models, tmp_path, monkeypatch):
    """POST /analytics/export_flags answers 200 and reports at least one export."""
    import database
    from config import CFG

    monkeypatch.setattr(CFG, "low_confidence_threshold", 0.99)

    api_client_with_models.post(
        "/predict",
        json={"text": "Unclear headline with mixed topics", "model_name": "lr"},
    )

    export_dir = tmp_path / "low_confidence_review"
    real_export = database.export_low_confidence_to_folder

    def _export_into_tmp():
        """Call the real exporter with the output directory under ``tmp_path``."""
        return real_export(output_dir=str(export_dir))

    monkeypatch.setattr(database, "export_low_confidence_to_folder", _export_into_tmp)

    response = api_client_with_models.post("/analytics/export_flags")

    assert response.status_code == 200
    summary = response.json()
    assert summary["exported"] >= 1


def test_health_with_startup_runs(api_client_with_startup):
    """GET /health answers 200 through the ``api_client_with_startup`` fixture."""
    response = api_client_with_startup.get("/health")

    assert response.status_code == 200