nexa-classify-api / tests /test_api.py
Prototype6239's picture
Upload folder using huggingface_hub
a229747 verified
Raw
History Blame Contribute Delete
5.98 kB
def test_health_returns_ok(api_client):
    """GET /health answers 200 and reports a status of "ok"."""
    response = api_client.get("/health")
    assert response.status_code == 200

    payload = response.json()
    assert payload["status"] == "ok"
def test_labels_returns_four_classes(api_client):
    """GET /labels answers 200 and lists exactly four labels."""
    response = api_client.get("/labels")
    assert response.status_code == 200

    labels = response.json()["labels"]
    assert len(labels) == 4
def test_predict_missing_model_returns_404(api_client):
    """POST /predict naming a model that does not exist answers 404."""
    request_body = {"text": "Some news article", "model_name": "nonexistent_model"}
    response = api_client.post("/predict", json=request_body)
    assert response.status_code == 404
def test_predict_empty_text_returns_422(api_client):
    """POST /predict with an empty text string answers 422."""
    request_body = {"text": "", "model_name": "lr"}
    response = api_client.post("/predict", json=request_body)
    assert response.status_code == 422
def test_batch_predict_too_many_texts_returns_422(api_client):
    """POST /batch_predict with 257 texts answers 422."""
    # NOTE(review): 257 is presumably one more than the API's batch limit —
    # confirm against the request schema.
    oversized_batch = ["text"] * 257
    request_body = {"texts": oversized_batch, "model_name": "lr"}
    response = api_client.post("/batch_predict", json=request_body)
    assert response.status_code == 422
def test_analytics_summary_returns_valid_json(api_client):
    """GET /analytics/summary answers 200 with a "total_requests" entry."""
    response = api_client.get("/analytics/summary")
    assert response.status_code == 200

    summary = response.json()
    assert "total_requests" in summary
def test_models_endpoint_lists_sklearn_model(api_client_with_models):
    """GET /models answers 200 and includes a model named "lr"."""
    response = api_client_with_models.get("/models")
    assert response.status_code == 200

    # Collect every advertised model name before checking membership.
    model_names = set()
    for entry in response.json()["models"]:
        model_names.add(entry["name"])
    assert "lr" in model_names
def test_predict_lr_success_logs_request(api_client_with_models, test_db_path):
    """A successful "lr" prediction returns metadata and is logged once.

    The response must carry ``request_id``, ``is_low_confidence`` and
    ``latency_ms``; the request history read from ``test_db_path`` must then
    hold exactly one row whose ``request_id`` matches the response.
    """
    import database

    request_body = {"text": "Fed raises interest rates by 50 bps", "model_name": "lr"}
    response = api_client_with_models.post("/predict", json=request_body)
    assert response.status_code == 200

    result = response.json()
    for expected_key in ("request_id", "is_low_confidence", "latency_ms"):
        assert expected_key in result

    logged = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(logged) == 1
    only_row = logged[0]
    assert only_row["request_id"] == result["request_id"]
def test_predict_lr_missing_joblib_returns_404(api_client):
    """POST /predict for "lr" through the plain ``api_client`` answers 404."""
    # NOTE(review): presumably the plain api_client fixture has no model
    # artifacts on disk — confirm against the fixture definition.
    request_body = {"text": "Fed raises rates", "model_name": "lr"}
    response = api_client.post("/predict", json=request_body)
    assert response.status_code == 404
def test_batch_predict_lr_logs_each_item(api_client_with_models, test_db_path):
    """A two-text "lr" batch returns two predictions and logs two batch rows.

    Every row read back from the request history must have ``is_batch``
    equal to 1.
    """
    import database

    texts = ["Apple unveils new AI chip", "Team wins the final match in overtime"]
    response = api_client_with_models.post(
        "/batch_predict",
        json={"texts": texts, "model_name": "lr"},
    )
    assert response.status_code == 200

    result = response.json()
    assert result["count"] == 2
    assert len(result["predictions"]) == 2

    logged = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(logged) == 2
    for row in logged:
        assert int(row["is_batch"]) == 1
def test_predict_svm_success_uses_decision_function(api_client_with_models):
    """POST /predict with the "svm" model succeeds without probabilities.

    The response must answer 200, carry ``probabilities`` set to ``None``,
    and report a non-negative ``latency_ms``.
    """
    resp = api_client_with_models.post(
        "/predict",
        json={"text": "Championship match ends in overtime", "model_name": "svm"},
    )
    assert resp.status_code == 200

    body = resp.json()
    assert body["probabilities"] is None
    # Require the field explicitly: reading it with a 0.0 fallback would let
    # the latency check pass even when the response omits latency_ms.
    assert "latency_ms" in body
    assert float(body["latency_ms"]) >= 0.0
def test_batch_predict_svm_success(api_client_with_models):
    """POST /batch_predict with two texts for "svm" answers 200 and count 2."""
    request_body = {
        "texts": ["Markets rise on earnings", "New chip released"],
        "model_name": "svm",
    }
    response = api_client_with_models.post("/batch_predict", json=request_body)
    assert response.status_code == 200

    result = response.json()
    assert result["count"] == 2
def test_predict_lr_cached_path(api_client_with_models, test_db_path):
    """Posting the same text twice succeeds both times and logs two rows."""
    # NOTE(review): the second call presumably exercises a cache inside the
    # API — this test only observes the status codes and the history length.
    import database

    request_body = {"text": "Stocks fall on inflation data", "model_name": "lr"}
    # Send both requests before asserting on either response.
    responses = [
        api_client_with_models.post("/predict", json=request_body)
        for _ in range(2)
    ]
    for response in responses:
        assert response.status_code == 200

    logged = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert len(logged) == 2
def test_low_confidence_review_flow(api_client_with_models, test_db_path, monkeypatch):
    """Walk one prediction through the low-confidence review endpoints.

    Steps: set ``CFG.low_confidence_threshold`` to 0.99, predict, find the
    request under /analytics/low_confidence, PATCH a review note onto it, and
    confirm that the reviewed listing shows the flag as reviewed with that
    note. Finally the request history must be non-empty.
    """
    import database
    from config import CFG

    # NOTE(review): a 0.99 threshold presumably forces the prediction to be
    # flagged as low confidence — confirm against the API's flagging logic.
    monkeypatch.setattr(CFG, "low_confidence_threshold", 0.99)
    client = api_client_with_models

    predict_resp = client.post(
        "/predict",
        json={
            "text": "Mixed signals from markets after earnings report",
            "model_name": "lr",
        },
    )
    assert predict_resp.status_code == 200
    request_id = predict_resp.json()["request_id"]

    # The new request must appear among the low-confidence flags.
    pending = client.get("/analytics/low_confidence").json()
    assert any(flag["request_id"] == request_id for flag in pending)

    review_resp = client.patch(
        f"/analytics/review/{request_id}",
        json={"note": "needs review"},
    )
    assert review_resp.status_code == 200

    # After the PATCH the flag must be listed as reviewed, with the note.
    reviewed = client.get("/analytics/low_confidence?reviewed=true").json()
    ours = [flag for flag in reviewed if flag["request_id"] == request_id]
    assert ours
    first = ours[0]
    assert int(first["reviewed"]) == 1
    assert first["review_note"] == "needs review"

    logged = database.get_request_history(db_path=str(test_db_path), limit=10)
    assert logged
def test_export_flags_creates_files(api_client_with_models, tmp_path, monkeypatch):
    """POST /analytics/export_flags answers 200 and exports at least one flag.

    ``database.export_low_confidence_to_folder`` is replaced by a zero-argument
    wrapper that forwards to the original with ``output_dir`` pointed at a
    directory under ``tmp_path``.
    """
    # NOTE(review): despite the test name, nothing here checks that files
    # exist in the output directory — only the reported "exported" count.
    import database
    from config import CFG

    monkeypatch.setattr(CFG, "low_confidence_threshold", 0.99)
    api_client_with_models.post(
        "/predict",
        json={"text": "Unclear headline with mixed topics", "model_name": "lr"},
    )

    export_dir = tmp_path / "low_confidence_review"
    real_export = database.export_low_confidence_to_folder
    monkeypatch.setattr(
        database,
        "export_low_confidence_to_folder",
        lambda: real_export(output_dir=str(export_dir)),
    )

    response = api_client_with_models.post("/analytics/export_flags")
    assert response.status_code == 200

    result = response.json()
    assert result["exported"] >= 1
def test_health_with_startup_runs(api_client_with_startup):
    """GET /health answers 200 through the ``api_client_with_startup`` fixture."""
    response = api_client_with_startup.get("/health")
    assert response.status_code == 200