nl2sql-copilot / tests /test_nl2sql_router.py
Melika Kheirieh
style: format code with ruff
4dae3e6
raw
history blame
3.2 kB
from fastapi.testclient import TestClient
from app.main import app
from nl2sql.pipeline import FinalResult
client = TestClient(app)
def fake_trace(stage: str) -> dict:
# FinalResult.traces is a list of dicts (StageTrace.__dict__)
return {"stage": stage, "duration_ms": 10.0, "cost_usd": None, "notes": None}
path = app.url_path_for("nl2sql_handler")
# --- 1) Clarify / ambiguity case ---------------------------------------------
def test_ambiguity_route(monkeypatch):
from app.routers import nl2sql
# mock pipeline to return FinalResult with ambiguous=True
def fake_run(*args, **kwargs):
return FinalResult(
ok=True,
ambiguous=True,
error=False,
details=["Ambiguities found: 1"],
questions=["Which table do you mean?"],
sql=None,
rationale=None,
verified=None,
traces=[fake_trace("detector")],
)
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
resp = client.post(
path,
json={
"query": "show all records",
"schema_preview": "CREATE TABLE ...",
},
)
assert resp.status_code == 200
data = resp.json()
assert data["ambiguous"] is True
assert "questions" in data
assert isinstance(data["questions"], list)
# --- 2) Error / failure case -------------------------------------------------
def test_error_route(monkeypatch):
from app.routers import nl2sql
def fake_run(*args, **kwargs):
return FinalResult(
ok=False,
ambiguous=False,
error=True,
details=["Bad SQL"],
questions=None,
sql=None,
rationale=None,
verified=None,
traces=[fake_trace("safety")],
)
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
resp = client.post(
path,
json={
"query": "drop table users;",
"schema_preview": "CREATE TABLE users(id int);",
},
)
assert resp.status_code == 400
assert "Bad SQL" in resp.json()["detail"]
# --- 3) Success / happy path -------------------------------------------------
def test_success_route(monkeypatch):
from app.routers import nl2sql
def fake_run(*args, **kwargs):
return FinalResult(
ok=True,
ambiguous=False,
error=False,
details=None,
questions=None,
sql="SELECT * FROM users;",
rationale="Simple listing",
verified=True,
traces=[fake_trace("planner"), fake_trace("generator")],
)
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
resp = client.post(
path,
json={
"query": "show all users",
"schema_preview": "CREATE TABLE users(id int, name text);",
},
)
assert resp.status_code == 200
data = resp.json()
assert data["sql"].lower().startswith("select")
assert isinstance(data["traces"], list)
assert any(t["stage"] == "planner" for t in data["traces"])
assert any(t["stage"] == "generator" for t in data["traces"])