Spaces:
Sleeping
Sleeping
Melika Kheirieh
refactor(core): DI-ready Pipeline; add registry + YAML factory + typed trace/result
343ad62
| from __future__ import annotations | |
| from fastapi.testclient import TestClient | |
| from app.main import app | |
| from app.routers import nl2sql | |
| from nl2sql.pipeline import FinalResult | |
# Shared in-process HTTP client bound to the FastAPI app under test.
client = TestClient(app)
# Resolve the endpoint URL from its route name so tests do not hard-code it.
path = app.url_path_for("nl2sql_handler")
def fake_trace(stage: str) -> dict:
    """Build a minimal trace record for ``stage``.

    Only the stage name varies: the duration is fixed at 10.0 and the
    ``cost_usd`` / ``notes`` fields are always ``None``.
    """
    return dict(stage=stage, duration_ms=10.0, cost_usd=None, notes=None)
| # --- 1) Clarify / ambiguity case --------------------------------------------- | |
def test_ambiguity_route() -> None:
    """Ambiguous result: expect HTTP 200, ``ambiguous`` true and a ``questions`` list."""

    def stub_runner(
        *, user_query: str, schema_preview: str | None = None
    ) -> FinalResult:
        # Canned "needs clarification" outcome; the inputs are ignored.
        return FinalResult(
            ok=True,
            ambiguous=True,
            error=False,
            details=["Ambiguities found: 1"],
            questions=["Which table do you mean?"],
            sql=None,
            rationale=None,
            verified=None,
            traces=[fake_trace("detector")],
        )

    request_body = {"query": "show all records", "schema_preview": "CREATE TABLE ..."}
    app.dependency_overrides[nl2sql.get_runner] = lambda: stub_runner
    try:
        response = client.post(path, json=request_body)
    finally:
        # Always remove the override so it cannot leak into other tests.
        app.dependency_overrides.pop(nl2sql.get_runner, None)

    assert response.status_code == 200
    body = response.json()
    assert body["ambiguous"] is True
    assert "questions" in body
    assert isinstance(body["questions"], list)
| # --- 2) Error / failure case ------------------------------------------------- | |
def test_error_route() -> None:
    """Failed result: expect HTTP 400 whose ``detail`` contains the pipeline's details."""

    def stub_runner(
        *, user_query: str, schema_preview: str | None = None
    ) -> FinalResult:
        # Canned failure outcome carrying a single detail message.
        return FinalResult(
            ok=False,
            ambiguous=False,
            error=True,
            details=["Bad SQL"],
            questions=None,
            sql=None,
            rationale=None,
            verified=None,
            traces=[fake_trace("safety")],
        )

    request_body = {
        "query": "drop table users;",
        "schema_preview": "CREATE TABLE users(id int);",
    }
    app.dependency_overrides[nl2sql.get_runner] = lambda: stub_runner
    try:
        response = client.post(path, json=request_body)
    finally:
        # Always remove the override so it cannot leak into other tests.
        app.dependency_overrides.pop(nl2sql.get_runner, None)

    assert response.status_code == 400
    assert "Bad SQL" in response.json()["detail"]
| # --- 3) Success / happy path ------------------------------------------------- | |
def test_success_route() -> None:
    """Happy path: expect HTTP 200, a SELECT statement, and planner/generator traces."""

    def stub_runner(
        *, user_query: str, schema_preview: str | None = None
    ) -> FinalResult:
        # Canned successful outcome with two trace entries.
        return FinalResult(
            ok=True,
            ambiguous=False,
            error=False,
            details=None,
            questions=None,
            sql="SELECT * FROM users;",
            rationale="Simple listing",
            verified=True,
            traces=[fake_trace("planner"), fake_trace("generator")],
        )

    request_body = {
        "query": "show all users",
        "schema_preview": "CREATE TABLE users(id int, name text);",
    }
    app.dependency_overrides[nl2sql.get_runner] = lambda: stub_runner
    try:
        response = client.post(path, json=request_body)
    finally:
        # Always remove the override so it cannot leak into other tests.
        app.dependency_overrides.pop(nl2sql.get_runner, None)

    assert response.status_code == 200
    body = response.json()
    assert body["sql"].lower().startswith("select")
    assert isinstance(body["traces"], list)
    # Each stage emitted by the stub must appear in the response traces.
    for expected_stage in ("planner", "generator"):
        assert any(entry["stage"] == expected_stage for entry in body["traces"])
| # --- 4) Success with db_id (per-request pipeline) ---------------------------- | |
def test_success_route_with_db_id(monkeypatch):
    """With ``db_id`` in the request, the router's patched helpers serve the call.

    The router's ``_select_adapter``, ``_build_pipeline`` and
    ``_derive_schema_preview`` are replaced with stubs; the response must be
    HTTP 200 with SQL starting with ``SELECT``.
    """

    class DummyAdapter:
        pass

    class DummyPipeline:
        def run(
            self, *, user_query: str, schema_preview: str | None = None
        ) -> FinalResult:
            # Canned successful outcome; the inputs are ignored.
            return FinalResult(
                ok=True,
                ambiguous=False,
                error=False,
                details=None,
                questions=None,
                sql="SELECT 1;",
                rationale=None,
                verified=True,
                traces=[fake_trace("executor")],
            )

    def fake_select_adapter(db_id: str):
        return DummyAdapter()

    def fake_build_pipeline(_a):
        return DummyPipeline()

    def fake_schema_preview(_a):
        return "CREATE TABLE t(id int);"

    # monkeypatch restores the original attributes when the test ends.
    monkeypatch.setattr(nl2sql, "_select_adapter", fake_select_adapter)
    monkeypatch.setattr(nl2sql, "_build_pipeline", fake_build_pipeline)
    monkeypatch.setattr(nl2sql, "_derive_schema_preview", fake_schema_preview)

    request_body = {"query": "anything", "db_id": "sqlite"}
    response = client.post(path, json=request_body)

    assert response.status_code == 200
    generated_sql = response.json()["sql"]
    assert generated_sql.startswith("SELECT")
| # --- 5) Pipeline crash → 500 ------------------------------------------------- | |
def test_pipeline_crash_returns_500() -> None:
    """A runner that raises must produce HTTP 500 mentioning "Pipeline crash"."""

    def exploding_runner(*, user_query: str, schema_preview: str | None = None):  # type: ignore[no-untyped-def]
        # Simulate an unexpected failure inside the pipeline.
        raise RuntimeError("boom")

    app.dependency_overrides[nl2sql.get_runner] = lambda: exploding_runner
    try:
        response = client.post(path, json={"query": "x"})
    finally:
        # Always remove the override so it cannot leak into other tests.
        app.dependency_overrides.pop(nl2sql.get_runner, None)

    assert response.status_code == 500
    assert "Pipeline crash" in response.json()["detail"]
| # --- 6) Unexpected output type → 500 ----------------------------------------- | |
def test_pipeline_returns_non_finalresult() -> None:
    """A runner returning something other than ``FinalResult`` must yield HTTP 500."""

    def wrong_type_runner(*, user_query: str, schema_preview: str | None = None):
        # Deliberately return a plain dict instead of a FinalResult.
        return {"ok": True}

    app.dependency_overrides[nl2sql.get_runner] = lambda: wrong_type_runner
    try:
        response = client.post(path, json={"query": "x"})
    finally:
        # Always remove the override so it cannot leak into other tests.
        app.dependency_overrides.pop(nl2sql.get_runner, None)

    assert response.status_code == 500
    detail_text = response.json()["detail"].lower()
    assert "unexpected type" in detail_text
| # --- 7) Ambiguous without questions (edge case) ------------------------------ | |
def test_ambiguity_without_questions_edge_case() -> None:
    """
    Edge case: ``ambiguous=True`` with ``questions=None`` must not crash the handler.

    Either 200 (treated as a clarification) or 400 (treated as an error) is accepted.
    """

    def questionless_runner(
        *, user_query: str, schema_preview: str | None = None
    ) -> FinalResult:
        # Ambiguous outcome that carries no clarification questions.
        return FinalResult(
            ok=True,
            ambiguous=True,
            error=False,
            details=["ambiguous but no questions"],
            questions=None,
            sql=None,
            rationale=None,
            verified=None,
            traces=[fake_trace("detector")],
        )

    acceptable_statuses = (200, 400)
    app.dependency_overrides[nl2sql.get_runner] = lambda: questionless_runner
    try:
        response = client.post(path, json={"query": "x"})
    finally:
        # Always remove the override so it cannot leak into other tests.
        app.dependency_overrides.pop(nl2sql.get_runner, None)

    assert response.status_code in acceptable_statuses
| # --- 8) FastAPI validation (422) --------------------------------------------- | |
def test_validation_422_missing_query() -> None:
    """A request body without the ``query`` field must be rejected with HTTP 422."""
    body_without_query = {"schema_preview": "CREATE TABLE t(id int);"}
    response = client.post(path, json=body_without_query)
    assert response.status_code == 422
| # --- 9) Trace rounding to int ------------------------------------------------ | |
def test_traces_are_rounded_to_ints() -> None:
    """A float ``duration_ms`` from the runner must come back as an int in the response."""

    def float_duration_runner(
        *, user_query: str, schema_preview: str | None = None
    ) -> FinalResult:
        # Single trace with a fractional duration to exercise the coercion.
        fractional_trace = {
            "stage": "x",
            "duration_ms": 12.7,
            "notes": None,
            "cost_usd": None,
        }
        return FinalResult(
            ok=True,
            ambiguous=False,
            error=False,
            details=None,
            questions=None,
            sql="SELECT 1;",
            rationale=None,
            verified=True,
            traces=[fractional_trace],
        )

    app.dependency_overrides[nl2sql.get_runner] = lambda: float_duration_runner
    try:
        response = client.post(path, json={"query": "x"})
    finally:
        # Always remove the override so it cannot leak into other tests.
        app.dependency_overrides.pop(nl2sql.get_runner, None)

    assert response.status_code == 200
    returned_traces = response.json()["traces"]
    assert isinstance(returned_traces, list)
    assert returned_traces
    assert isinstance(returned_traces[0]["duration_ms"], int)
def test_nl2sql_handler_returns_sql(monkeypatch):
    """Smoke test: a bare query must return HTTP 200 with ``sql`` and ``traces`` keys.

    No dependency override is installed here, so the request is served by
    whatever runner the app resolves on its own.
    NOTE(review): the ``monkeypatch`` fixture is requested but unused; it is
    kept so the test signature stays unchanged. Confirm whether this test is
    meant to stub the runner like the other tests in this module.
    """
    payload = {"query": "Top 5 albums by sales"}
    # Use the URL resolved from the route name (module-level ``path``) instead
    # of a hard-coded "/nl2sql", so this test targets the same endpoint as the
    # rest of the module even if the router is mounted under a prefix.
    response = client.post(path, json=payload)
    assert response.status_code == 200
    data = response.json()
    assert "sql" in data
    assert "traces" in data