Spaces:
Running
Running
File size: 2,674 Bytes
570f7bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import pytest
from fastapi.testclient import TestClient
from app.main import app
from nl2sql.types import StageResult, StageTrace
# Module-level objects shared by every test below.
# NOTE(review): the client is not entered as a context manager, so per the
# Starlette TestClient docs lifespan/startup handlers do not run — confirm the
# tests never depend on them (they stub the pipeline directly).
client = TestClient(app)
# Look the endpoint URL up by route name rather than hard-coding the path.
path = app.url_path_for("nl2sql_handler")


def fake_trace(stage: str) -> StageTrace:
    """Return a placeholder ``StageTrace`` for *stage* with a fixed 10 ms duration."""
    trace = StageTrace(stage=stage, duration_ms=10.0)
    return trace
# --- 1) Clarify / ambiguity case ---------------------------------------------
def test_ambiguity_route(monkeypatch):
    """An ambiguous query is answered with HTTP 200 and a ``questions`` field.

    The pipeline is stubbed to report ``ambiguous=True`` with one follow-up
    question, and the endpoint is expected to surface that flag.
    """
    from app.routers import nl2sql

    def stub_run(*_args, **_kwargs):
        # Build the canned result on each call, exactly as the original stub did.
        payload = {
            "ambiguous": True,
            "questions": ["Which table do you mean?"],
            "traces": [fake_trace("detector")],
        }
        return StageResult(ok=True, data=payload)

    monkeypatch.setattr(nl2sql._pipeline, "run", stub_run)

    request_body = {
        "query": "show all records",
        "schema_preview": "CREATE TABLE ...",
    }
    response = client.post(path, json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert body["ambiguous"] is True
    assert "questions" in body
# --- 2) Error / failure case -------------------------------------------------
def test_error_route(monkeypatch):
    """A failed pipeline run is reported as HTTP 400 with the error in ``detail``."""
    from app.routers import nl2sql

    def stub_run(*_args, **_kwargs):
        # Simulate a rejection from the "safety" stage.
        return StageResult(
            ok=False,
            error=["Bad SQL"],
            data={"traces": [fake_trace("safety")]},
        )

    monkeypatch.setattr(nl2sql._pipeline, "run", stub_run)

    request_body = {
        "query": "drop table users;",
        "schema_preview": "CREATE TABLE users(id int);",
    }
    response = client.post(path, json=request_body)

    assert response.status_code == 400
    detail = response.json()["detail"]
    # NOTE(review): `in` passes whether the router emits `detail` as a string
    # (substring match) or as the error list (membership) — confirm which.
    assert "Bad SQL" in detail
# --- 3) Success / happy path -------------------------------------------------
def test_success_route(monkeypatch):
    """Happy path: a clear query yields HTTP 200, a SELECT statement and traces."""
    from app.routers import nl2sql

    def stub_run(*_args, **_kwargs):
        payload = {
            "ambiguous": False,
            "sql": "SELECT * FROM users;",
            "rationale": "Simple listing",
            "traces": [fake_trace("planner"), fake_trace("generator")],
        }
        return StageResult(ok=True, data=payload)

    monkeypatch.setattr(nl2sql._pipeline, "run", stub_run)

    request_body = {
        "query": "show all users",
        "schema_preview": "CREATE TABLE users(id int, name text);",
    }
    response = client.post(path, json=request_body)

    assert response.status_code == 200
    body = response.json()
    sql_text = body["sql"]
    assert sql_text.lower().startswith("select")
    traces = body["traces"]
    assert isinstance(traces, list)
    # Stop at the first planner entry, matching the original short-circuit check.
    assert any(entry["stage"] == "planner" for entry in traces)
|