Spaces:
Sleeping
Sleeping
File size: 3,201 Bytes
570f7bd a45c0eb 570f7bd 4dae3e6 a45c0eb c1bc4eb 4dae3e6 570f7bd c1bc4eb 570f7bd a45c0eb 570f7bd a45c0eb 570f7bd a45c0eb 570f7bd a45c0eb 570f7bd a45c0eb c1bc4eb 570f7bd a45c0eb 570f7bd a45c0eb 570f7bd a45c0eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
from fastapi.testclient import TestClient
from app.main import app
from nl2sql.pipeline import FinalResult
# Shared test client bound to the FastAPI app; the tests below send requests through it.
client = TestClient(app)
def fake_trace(stage: str) -> dict:
    """Build a minimal stage-trace dict for use in a mocked ``FinalResult``.

    ``FinalResult.traces`` is a list of dicts (``StageTrace.__dict__``), so the
    fake mirrors that shape with fixed placeholder values.

    Args:
        stage: Name of the pipeline stage the trace is attributed to.

    Returns:
        A new dict with keys ``stage``, ``duration_ms`` (always ``10.0``),
        ``cost_usd`` (``None``) and ``notes`` (``None``).
    """
    return {"stage": stage, "duration_ms": 10.0, "cost_usd": None, "notes": None}
# Look the endpoint path up by route name rather than hard-coding the URL in each test.
path = app.url_path_for("nl2sql_handler")
# --- 1) Clarify / ambiguity case ---------------------------------------------
def test_ambiguity_route(monkeypatch):
    """An ambiguous pipeline result yields HTTP 200 with a list of questions.

    The pipeline's ``run`` is replaced with a stub so the test exercises only
    how the route turns a ``FinalResult`` into an HTTP response.
    """
    # Imported here so the patch targets the router module's own pipeline object.
    from app.routers import nl2sql

    # mock pipeline to return FinalResult with ambiguous=True
    def fake_run(*args, **kwargs):
        return FinalResult(
            ok=True,
            ambiguous=True,
            error=False,
            details=["Ambiguities found: 1"],
            questions=["Which table do you mean?"],
            sql=None,
            rationale=None,
            verified=None,
            traces=[fake_trace("detector")],
        )

    # monkeypatch restores the real ``run`` automatically after the test.
    monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)

    resp = client.post(
        path,
        json={
            "query": "show all records",
            "schema_preview": "CREATE TABLE ...",
        },
    )

    assert resp.status_code == 200
    data = resp.json()
    assert data["ambiguous"] is True
    assert "questions" in data
    assert isinstance(data["questions"], list)
# --- 2) Error / failure case -------------------------------------------------
def test_error_route(monkeypatch):
    """A failed pipeline result yields HTTP 400 whose detail carries the error.

    The stubbed ``run`` reports ``ok=False`` / ``error=True`` with the detail
    ``"Bad SQL"``; the test checks that text appears in the response ``detail``.
    """
    # Imported here so the patch targets the router module's own pipeline object.
    from app.routers import nl2sql

    def fake_run(*args, **kwargs):
        return FinalResult(
            ok=False,
            ambiguous=False,
            error=True,
            details=["Bad SQL"],
            questions=None,
            sql=None,
            rationale=None,
            verified=None,
            traces=[fake_trace("safety")],
        )

    # monkeypatch restores the real ``run`` automatically after the test.
    monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)

    resp = client.post(
        path,
        json={
            "query": "drop table users;",
            "schema_preview": "CREATE TABLE users(id int);",
        },
    )

    assert resp.status_code == 400
    assert "Bad SQL" in resp.json()["detail"]
# --- 3) Success / happy path -------------------------------------------------
def test_success_route(monkeypatch):
    """A successful pipeline result yields HTTP 200 with SQL and stage traces.

    The stubbed ``run`` returns a verified ``SELECT`` statement plus traces for
    the ``planner`` and ``generator`` stages; the test checks both appear in
    the JSON response.
    """
    # Imported here so the patch targets the router module's own pipeline object.
    from app.routers import nl2sql

    def fake_run(*args, **kwargs):
        return FinalResult(
            ok=True,
            ambiguous=False,
            error=False,
            details=None,
            questions=None,
            sql="SELECT * FROM users;",
            rationale="Simple listing",
            verified=True,
            traces=[fake_trace("planner"), fake_trace("generator")],
        )

    # monkeypatch restores the real ``run`` automatically after the test.
    monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)

    resp = client.post(
        path,
        json={
            "query": "show all users",
            "schema_preview": "CREATE TABLE users(id int, name text);",
        },
    )

    assert resp.status_code == 200
    data = resp.json()
    assert data["sql"].lower().startswith("select")
    assert isinstance(data["traces"], list)
    assert any(t["stage"] == "planner" for t in data["traces"])
    assert any(t["stage"] == "generator" for t in data["traces"])
|