File size: 8,942 Bytes
9c10293
 
570f7bd
1615809
570f7bd
9c10293
a45c0eb
570f7bd
 
9c10293
570f7bd
4dae3e6
a45c0eb
1615809
a45c0eb
c1bc4eb
4dae3e6
570f7bd
9c10293
1615809
 
9c10293
a45c0eb
570f7bd
a45c0eb
 
 
 
 
 
 
 
570f7bd
 
9c10293
 
 
 
 
 
 
 
 
 
 
 
570f7bd
 
 
9c10293
1615809
 
9c10293
a45c0eb
 
 
 
 
 
 
 
 
 
c1bc4eb
570f7bd
9c10293
 
 
 
 
 
 
 
 
 
 
 
 
570f7bd
 
 
9c10293
1615809
 
9c10293
a45c0eb
570f7bd
a45c0eb
 
 
 
 
 
 
 
570f7bd
 
9c10293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615809
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343ad62
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
from __future__ import annotations

from fastapi.testclient import TestClient

from app.main import app
from app.routers import nl2sql
from nl2sql.pipeline import FinalResult

# Shared HTTP test client bound to the FastAPI application under test.
client = TestClient(app)
# URL of the route registered under the name "nl2sql_handler", resolved from
# the app's routing table rather than written out by hand.
path = app.url_path_for("nl2sql_handler")


def fake_trace(stage: str) -> dict:
    """Build a stub trace entry for ``stage`` that the tests can share.

    Only the stage name varies: the duration is fixed at 10.0 and the cost
    and notes fields are both ``None``.
    """
    trace = dict(stage=stage, duration_ms=10.0)
    trace["cost_usd"] = None
    trace["notes"] = None
    return trace


# --- 1) Clarify / ambiguity case ---------------------------------------------
def test_ambiguity_route():
    """An ambiguous result must come back as HTTP 200 carrying a questions list."""

    def stub_runner(
        *, user_query: str, schema_preview: str | None = None
    ) -> FinalResult:
        # Runner stand-in: reports the query as ambiguous with one question.
        fields = dict(
            ok=True,
            ambiguous=True,
            error=False,
            details=["Ambiguities found: 1"],
            questions=["Which table do you mean?"],
            sql=None,
            rationale=None,
            verified=None,
            traces=[fake_trace("detector")],
        )
        return FinalResult(**fields)

    request_body = {"query": "show all records", "schema_preview": "CREATE TABLE ..."}
    app.dependency_overrides[nl2sql.get_runner] = lambda: stub_runner
    try:
        response = client.post(path, json=request_body)
        assert response.status_code == 200
        body = response.json()
        assert body["ambiguous"] is True
        assert "questions" in body
        assert isinstance(body["questions"], list)
    finally:
        # Remove the override again, whether or not the assertions passed.
        app.dependency_overrides.pop(nl2sql.get_runner, None)


# --- 2) Error / failure case -------------------------------------------------
def test_error_route():
    """A failed result must come back as HTTP 400 with its details in 'detail'."""

    def stub_runner(
        *, user_query: str, schema_preview: str | None = None
    ) -> FinalResult:
        # Runner stand-in: reports a failure with a single detail message.
        fields = dict(
            ok=False,
            ambiguous=False,
            error=True,
            details=["Bad SQL"],
            questions=None,
            sql=None,
            rationale=None,
            verified=None,
            traces=[fake_trace("safety")],
        )
        return FinalResult(**fields)

    request_body = {
        "query": "drop table users;",
        "schema_preview": "CREATE TABLE users(id int);",
    }
    app.dependency_overrides[nl2sql.get_runner] = lambda: stub_runner
    try:
        response = client.post(path, json=request_body)
        assert response.status_code == 400
        detail = response.json()["detail"]
        assert "Bad SQL" in detail
    finally:
        # Remove the override again, whether or not the assertions passed.
        app.dependency_overrides.pop(nl2sql.get_runner, None)


# --- 3) Success / happy path -------------------------------------------------
def test_success_route():
    """A successful result must come back as HTTP 200 with SQL and both stage traces."""

    def stub_runner(
        *, user_query: str, schema_preview: str | None = None
    ) -> FinalResult:
        # Runner stand-in: a verified SELECT with planner and generator traces.
        fields = dict(
            ok=True,
            ambiguous=False,
            error=False,
            details=None,
            questions=None,
            sql="SELECT * FROM users;",
            rationale="Simple listing",
            verified=True,
            traces=[fake_trace("planner"), fake_trace("generator")],
        )
        return FinalResult(**fields)

    request_body = {
        "query": "show all users",
        "schema_preview": "CREATE TABLE users(id int, name text);",
    }
    app.dependency_overrides[nl2sql.get_runner] = lambda: stub_runner
    try:
        response = client.post(path, json=request_body)
        assert response.status_code == 200
        body = response.json()
        assert body["sql"].lower().startswith("select")
        assert isinstance(body["traces"], list)
        # Each expected stage has to show up in at least one returned trace.
        for expected_stage in ("planner", "generator"):
            assert any(t["stage"] == expected_stage for t in body["traces"])
    finally:
        # Remove the override again, whether or not the assertions passed.
        app.dependency_overrides.pop(nl2sql.get_runner, None)


# --- 4) Success with db_id (per-request pipeline) ----------------------------
def test_success_route_with_db_id(monkeypatch):
    """A request carrying db_id must succeed using the patched per-request helpers."""

    def fake_select_adapter(db_id: str):
        # Hand back an empty placeholder object in place of a real adapter.
        class DummyAdapter:
            pass

        adapter = DummyAdapter()
        return adapter

    class DummyPipeline:
        """Pipeline stand-in whose ``run`` always reports a verified ``SELECT 1;``."""

        def run(
            self, *, user_query: str, schema_preview: str | None = None
        ) -> FinalResult:
            fields = dict(
                ok=True,
                ambiguous=False,
                error=False,
                details=None,
                questions=None,
                sql="SELECT 1;",
                rationale=None,
                verified=True,
                traces=[fake_trace("executor")],
            )
            return FinalResult(**fields)

    # Swap the router's three module-level helpers for the stand-ins above.
    replacements = (
        ("_select_adapter", fake_select_adapter),
        ("_build_pipeline", lambda _a: DummyPipeline()),
        ("_derive_schema_preview", lambda _a: "CREATE TABLE t(id int);"),
    )
    for helper_name, replacement in replacements:
        monkeypatch.setattr(nl2sql, helper_name, replacement)

    request_body = {"query": "anything", "db_id": "sqlite"}
    response = client.post(path, json=request_body)
    assert response.status_code == 200
    generated_sql = response.json()["sql"]
    assert generated_sql.startswith("SELECT")


# --- 5) Pipeline crash → 500 -------------------------------------------------
def test_pipeline_crash_returns_500():
    """A runner that raises must surface as HTTP 500 mentioning 'Pipeline crash'."""

    def exploding_runner(*, user_query: str, schema_preview: str | None = None):  # type: ignore[no-untyped-def]
        # Runner stand-in that fails unconditionally.
        raise RuntimeError("boom")

    request_body = {"query": "x"}
    app.dependency_overrides[nl2sql.get_runner] = lambda: exploding_runner
    try:
        response = client.post(path, json=request_body)
        assert response.status_code == 500
        detail = response.json()["detail"]
        assert "Pipeline crash" in detail
    finally:
        # Remove the override again, whether or not the assertions passed.
        app.dependency_overrides.pop(nl2sql.get_runner, None)


# --- 6) Unexpected output type → 500 -----------------------------------------
def test_pipeline_returns_non_finalresult():
    """A runner returning something other than FinalResult must surface as HTTP 500."""

    def wrong_type_runner(*, user_query: str, schema_preview: str | None = None):
        # Deliberately returns a plain dict rather than a FinalResult.
        plain_dict = {"ok": True}
        return plain_dict

    request_body = {"query": "x"}
    app.dependency_overrides[nl2sql.get_runner] = lambda: wrong_type_runner
    try:
        response = client.post(path, json=request_body)
        assert response.status_code == 500
        detail = response.json()["detail"]
        assert "unexpected type" in detail.lower()
    finally:
        # Remove the override again, whether or not the assertions passed.
        app.dependency_overrides.pop(nl2sql.get_runner, None)


# --- 7) Ambiguous without questions (edge case) ------------------------------
def test_ambiguity_without_questions_edge_case():
    """
    An ambiguous result whose questions are None must still get a normal response.

    Both outcomes are tolerated: 200 when the handler treats it as a
    clarification, 400 when it treats it as an error.
    """

    def questionless_runner(
        *, user_query: str, schema_preview: str | None = None
    ) -> FinalResult:
        # Runner stand-in: ambiguous=True, yet no questions are supplied.
        fields = dict(
            ok=True,
            ambiguous=True,
            error=False,
            details=["ambiguous but no questions"],
            questions=None,
            sql=None,
            rationale=None,
            verified=None,
            traces=[fake_trace("detector")],
        )
        return FinalResult(**fields)

    accepted_statuses = (200, 400)
    request_body = {"query": "x"}
    app.dependency_overrides[nl2sql.get_runner] = lambda: questionless_runner
    try:
        response = client.post(path, json=request_body)
        assert response.status_code in accepted_statuses
    finally:
        # Remove the override again, whether or not the assertion passed.
        app.dependency_overrides.pop(nl2sql.get_runner, None)


# --- 8) FastAPI validation (422) ---------------------------------------------
def test_validation_422_missing_query():
    """A request body that omits ``query`` must be rejected with HTTP 422."""
    incomplete_body = {"schema_preview": "CREATE TABLE t(id int);"}
    response = client.post(path, json=incomplete_body)
    assert response.status_code == 422


# --- 9) Trace rounding to int ------------------------------------------------
def test_traces_are_rounded_to_ints():
    """A float ``duration_ms`` from the runner must reach the client as an int."""

    def float_trace_runner(
        *, user_query: str, schema_preview: str | None = None
    ) -> FinalResult:
        # Runner stand-in: one trace whose duration is a non-integral float.
        float_trace = {
            "stage": "x",
            "duration_ms": 12.7,
            "notes": None,
            "cost_usd": None,
        }
        fields = dict(
            ok=True,
            ambiguous=False,
            error=False,
            details=None,
            questions=None,
            sql="SELECT 1;",
            rationale=None,
            verified=True,
            traces=[float_trace],
        )
        return FinalResult(**fields)

    request_body = {"query": "x"}
    app.dependency_overrides[nl2sql.get_runner] = lambda: float_trace_runner
    try:
        response = client.post(path, json=request_body)
        assert response.status_code == 200
        returned_traces = response.json()["traces"]
        assert isinstance(returned_traces, list)
        assert returned_traces
        first_duration = returned_traces[0]["duration_ms"]
        assert isinstance(first_duration, int)
    finally:
        # Remove the override again, whether or not the assertions passed.
        app.dependency_overrides.pop(nl2sql.get_runner, None)


def test_nl2sql_handler_returns_sql(monkeypatch):
    """A plain query with no runner override must return 200 with 'sql' and 'traces'.

    Unlike the other tests in this module, no dependency override is
    installed, so the request goes through whatever runner the application
    is configured with.

    ``monkeypatch`` is requested but not used; it is kept so the test's
    signature stays unchanged.
    """
    payload = {"query": "Top 5 albums by sales"}
    # Post to the router-resolved URL shared by the rest of the module instead
    # of a hard-coded "/nl2sql", so the test follows the route if it is ever
    # mounted under a different prefix.
    r = client.post(path, json=payload)
    assert r.status_code == 200
    data = r.json()
    assert "sql" in data
    assert "traces" in data