Spaces:
Running
Running
File size: 3,198 Bytes
570f7bd a45c0eb 570f7bd 4dae3e6 a45c0eb c1bc4eb 4dae3e6 570f7bd c1bc4eb 570f7bd a45c0eb 570f7bd a45c0eb 570f7bd a45c0eb 570f7bd ccefd8e 570f7bd a45c0eb 570f7bd a45c0eb c1bc4eb 570f7bd ccefd8e 570f7bd a45c0eb 570f7bd a45c0eb 570f7bd ccefd8e 570f7bd a45c0eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
from fastapi.testclient import TestClient
from app.main import app
from nl2sql.pipeline import FinalResult
# Shared in-process HTTP client for the FastAPI app; every test in this module
# sends its requests through this single instance.
client = TestClient(app)
def fake_trace(stage: str) -> dict:
    """Fabricate a minimal pipeline trace entry for ``stage``.

    ``FinalResult.traces`` holds plain dicts (``StageTrace.__dict__``), so the
    stubbed pipeline results in this module use this helper to build them with
    a fixed duration and no cost or notes.

    Args:
        stage: Name of the pipeline stage the trace belongs to.

    Returns:
        A new dict with the keys ``stage``, ``duration_ms``, ``cost_usd`` and
        ``notes`` (in that order).
    """
    trace = dict(stage=stage, duration_ms=10.0, cost_usd=None, notes=None)
    return trace
# Resolve the endpoint URL from its route name once, at import time, so the
# tests below don't hard-code the URL.
# NOTE(review): presumably this fails during test collection (Starlette's
# NoMatchFound) if no route named "nl2sql_handler" is registered — confirm.
path = app.url_path_for("nl2sql_handler")
# --- 1) Clarify / ambiguity case ---------------------------------------------
def test_ambiguity_route(monkeypatch):
    """An ambiguous pipeline result yields HTTP 200 plus a list of questions."""
    from app.routers import nl2sql

    # Replace Pipeline.run so the route receives a canned "ambiguous" result
    # instead of executing the real pipeline; a fresh FinalResult per call.
    monkeypatch.setattr(
        nl2sql.Pipeline,
        "run",
        lambda *_args, **_kwargs: FinalResult(
            ok=True,
            ambiguous=True,
            error=False,
            details=["Ambiguities found: 1"],
            questions=["Which table do you mean?"],
            sql=None,
            rationale=None,
            verified=None,
            traces=[fake_trace("detector")],
        ),
    )

    payload = {
        "query": "show all records",
        "schema_preview": "CREATE TABLE ...",
    }
    response = client.post(path, json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["ambiguous"] is True
    assert "questions" in body
    assert isinstance(body["questions"], list)
# --- 2) Error / failure case -------------------------------------------------
def test_error_route(monkeypatch):
    """A failed pipeline result yields HTTP 400 whose detail mentions the error."""
    from app.routers import nl2sql

    def _failing_run(*_args, **_kwargs):
        # Canned failure result as if the safety stage rejected the SQL.
        return FinalResult(
            ok=False,
            ambiguous=False,
            error=True,
            details=["Bad SQL"],
            questions=None,
            sql=None,
            rationale=None,
            verified=None,
            traces=[fake_trace("safety")],
        )

    monkeypatch.setattr(nl2sql.Pipeline, "run", _failing_run)

    payload = {
        "query": "drop table users;",
        "schema_preview": "CREATE TABLE users(id int);",
    }
    response = client.post(path, json=payload)

    assert response.status_code == 400
    assert "Bad SQL" in response.json()["detail"]
# --- 3) Success / happy path -------------------------------------------------
def test_success_route(monkeypatch):
    """A successful pipeline result yields HTTP 200 with the SQL and its traces."""
    from app.routers import nl2sql

    def _successful_run(*_args, **_kwargs):
        # Canned happy-path result carrying traces for two stages.
        return FinalResult(
            ok=True,
            ambiguous=False,
            error=False,
            details=None,
            questions=None,
            sql="SELECT * FROM users;",
            rationale="Simple listing",
            verified=True,
            traces=[fake_trace("planner"), fake_trace("generator")],
        )

    monkeypatch.setattr(nl2sql.Pipeline, "run", _successful_run)

    payload = {
        "query": "show all users",
        "schema_preview": "CREATE TABLE users(id int, name text);",
    }
    response = client.post(path, json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["sql"].lower().startswith("select")
    traces = body["traces"]
    assert isinstance(traces, list)
    # Each stubbed stage must appear in the serialized traces.
    for expected_stage in ("planner", "generator"):
        assert any(entry["stage"] == expected_stage for entry in traces)
|