"""Route-level tests for the NL2SQL endpoint.

Each test replaces ``nl2sql._pipeline.run`` with a stub that returns a canned
``FinalResult`` and then posts to the ``nl2sql_handler`` route, checking how
the router maps that result to an HTTP response:

* ambiguous result -> 200 with a ``questions`` list
* error result     -> 400 with the pipeline details in ``detail``
* success result   -> 200 with ``sql`` and the stage ``traces``
"""

from fastapi.testclient import TestClient

from app.main import app
from nl2sql.pipeline import FinalResult

client = TestClient(app)


def fake_trace(stage: str) -> dict:
    """Return a minimal trace entry for *stage* with a fixed 10 ms duration."""
    # FinalResult.traces is a list of dicts (StageTrace.__dict__)
    return {"stage": stage, "duration_ms": 10.0, "cost_usd": None, "notes": None}


path = app.url_path_for("nl2sql_handler")


def _make_result(**overrides) -> FinalResult:
    """Build a ``FinalResult`` from neutral defaults updated with *overrides*.

    Defaults: ``ok=True``, ``ambiguous=False``, ``error=False``, ``traces=[]``
    and every other field ``None``. Each test overrides only the fields its
    scenario cares about.
    """
    # The dict literal is rebuilt on every call, so the default ``traces``
    # list is never shared between results.
    fields = {
        "ok": True,
        "ambiguous": False,
        "error": False,
        "details": None,
        "questions": None,
        "sql": None,
        "rationale": None,
        "verified": None,
        "traces": [],
    }
    fields.update(overrides)
    return FinalResult(**fields)


def _stub_pipeline(monkeypatch, **result_fields) -> None:
    """Make the router's ``_pipeline.run`` return a canned ``FinalResult``.

    The stub ignores all arguments. It builds a fresh result from
    *result_fields* on every call, so a router that mutates the result cannot
    leak state into a later call.
    """
    from app.routers import nl2sql

    def fake_run(*args, **kwargs):
        return _make_result(**result_fields)

    monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)


# --- 1) Clarify / ambiguity case ---------------------------------------------
def test_ambiguity_route(monkeypatch):
    """An ambiguous pipeline result yields 200 with the clarifying questions."""
    question = "Which table do you mean?"
    _stub_pipeline(
        monkeypatch,
        ambiguous=True,
        details=["Ambiguities found: 1"],
        questions=[question],
        traces=[fake_trace("detector")],
    )

    resp = client.post(
        path,
        json={
            "query": "show all records",
            "schema_preview": "CREATE TABLE ...",
        },
    )

    assert resp.status_code == 200
    data = resp.json()
    assert data["ambiguous"] is True
    assert "questions" in data
    assert isinstance(data["questions"], list)
    # The stubbed question itself must reach the client, not just any list.
    assert question in data["questions"]


# --- 2) Error / failure case -------------------------------------------------
def test_error_route(monkeypatch):
    """A failed pipeline result yields 400 with its details in ``detail``."""
    _stub_pipeline(
        monkeypatch,
        ok=False,
        error=True,
        details=["Bad SQL"],
        traces=[fake_trace("safety")],
    )

    resp = client.post(
        path,
        json={
            "query": "drop table users;",
            "schema_preview": "CREATE TABLE users(id int);",
        },
    )

    assert resp.status_code == 400
    assert "Bad SQL" in resp.json()["detail"]


# --- 3) Success / happy path -------------------------------------------------
def test_success_route(monkeypatch):
    """A successful pipeline result yields 200 with SQL and stage traces."""
    _stub_pipeline(
        monkeypatch,
        sql="SELECT * FROM users;",
        rationale="Simple listing",
        verified=True,
        traces=[fake_trace("planner"), fake_trace("generator")],
    )

    resp = client.post(
        path,
        json={
            "query": "show all users",
            "schema_preview": "CREATE TABLE users(id int, name text);",
        },
    )

    assert resp.status_code == 200
    data = resp.json()
    assert data["sql"].lower().startswith("select")
    assert isinstance(data["traces"], list)
    stages = {t["stage"] for t in data["traces"]}
    assert "planner" in stages
    assert "generator" in stages