Melika Kheirieh commited on
Commit
1615809
·
1 Parent(s): 9c10293

test(router): add full coverage for nl2sql_handler (clarify, error, success, db_id, crash, invalid-type, 422, trace rounding)

Browse files
Files changed (2) hide show
  1. .coverage +0 -0
  2. tests/test_nl2sql_router.py +148 -1
.coverage CHANGED
Binary files a/.coverage and b/.coverage differ
 
tests/test_nl2sql_router.py CHANGED
@@ -1,7 +1,7 @@
1
- # tests/test_nl2sql_router.py
2
  from __future__ import annotations
3
 
4
  from fastapi.testclient import TestClient
 
5
  from app.main import app
6
  from app.routers import nl2sql
7
  from nl2sql.pipeline import FinalResult
@@ -11,11 +11,14 @@ path = app.url_path_for("nl2sql_handler")
11
 
12
 
13
  def fake_trace(stage: str) -> dict:
 
14
  return {"stage": stage, "duration_ms": 10.0, "cost_usd": None, "notes": None}
15
 
16
 
17
  # --- 1) Clarify / ambiguity case ---------------------------------------------
18
  def test_ambiguity_route():
 
 
19
  def fake_run(*, user_query: str, schema_preview: str | None = None) -> FinalResult:
20
  return FinalResult(
21
  ok=True,
@@ -45,6 +48,8 @@ def test_ambiguity_route():
45
 
46
  # --- 2) Error / failure case -------------------------------------------------
47
  def test_error_route():
 
 
48
  def fake_run(*, user_query: str, schema_preview: str | None = None) -> FinalResult:
49
  return FinalResult(
50
  ok=False,
@@ -75,6 +80,8 @@ def test_error_route():
75
 
76
  # --- 3) Success / happy path -------------------------------------------------
77
  def test_success_route():
 
 
78
  def fake_run(*, user_query: str, schema_preview: str | None = None) -> FinalResult:
79
  return FinalResult(
80
  ok=True,
@@ -105,3 +112,143 @@ def test_success_route():
105
  assert any(t["stage"] == "generator" for t in data["traces"])
106
  finally:
107
  app.dependency_overrides.pop(nl2sql.get_runner, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  from fastapi.testclient import TestClient
4
+
5
  from app.main import app
6
  from app.routers import nl2sql
7
  from nl2sql.pipeline import FinalResult
 
11
 
12
 
13
  def fake_trace(stage: str) -> dict:
14
+ """Minimal trace stub used across tests."""
15
  return {"stage": stage, "duration_ms": 10.0, "cost_usd": None, "notes": None}
16
 
17
 
18
  # --- 1) Clarify / ambiguity case ---------------------------------------------
19
  def test_ambiguity_route():
20
+ """Should return 200 with ambiguous=True and questions present."""
21
+
22
  def fake_run(*, user_query: str, schema_preview: str | None = None) -> FinalResult:
23
  return FinalResult(
24
  ok=True,
 
48
 
49
  # --- 2) Error / failure case -------------------------------------------------
50
  def test_error_route():
51
+ """Should return 400 and include aggregated details in 'detail'."""
52
+
53
  def fake_run(*, user_query: str, schema_preview: str | None = None) -> FinalResult:
54
  return FinalResult(
55
  ok=False,
 
80
 
81
  # --- 3) Success / happy path -------------------------------------------------
82
  def test_success_route():
83
+ """Should return 200, include SQL and traces with expected stages."""
84
+
85
  def fake_run(*, user_query: str, schema_preview: str | None = None) -> FinalResult:
86
  return FinalResult(
87
  ok=True,
 
112
  assert any(t["stage"] == "generator" for t in data["traces"])
113
  finally:
114
  app.dependency_overrides.pop(nl2sql.get_runner, None)
115
+
116
+
117
+ # --- 4) Success with db_id (per-request pipeline) ----------------------------
118
+ def test_success_route_with_db_id(monkeypatch):
119
+ """Should build a per-request pipeline when db_id is provided."""
120
+
121
+ def fake_select_adapter(db_id: str):
122
+ class DummyAdapter:
123
+ pass
124
+
125
+ return DummyAdapter()
126
+
127
+ class DummyPipeline:
128
+ def run(
129
+ self, *, user_query: str, schema_preview: str | None = None
130
+ ) -> FinalResult:
131
+ return FinalResult(
132
+ ok=True,
133
+ ambiguous=False,
134
+ error=False,
135
+ details=None,
136
+ questions=None,
137
+ sql="SELECT 1;",
138
+ rationale=None,
139
+ verified=True,
140
+ traces=[fake_trace("executor")],
141
+ )
142
+
143
+ monkeypatch.setattr(nl2sql, "_select_adapter", fake_select_adapter)
144
+ monkeypatch.setattr(nl2sql, "_build_pipeline", lambda _a: DummyPipeline())
145
+ monkeypatch.setattr(
146
+ nl2sql, "_derive_schema_preview", lambda _a: "CREATE TABLE t(id int);"
147
+ )
148
+
149
+ resp = client.post(path, json={"query": "anything", "db_id": "sqlite"})
150
+ assert resp.status_code == 200
151
+ assert resp.json()["sql"].startswith("SELECT")
152
+
153
+
154
+ # --- 5) Pipeline crash → 500 -------------------------------------------------
155
+ def test_pipeline_crash_returns_500():
156
+ """Exceptions inside pipeline should result in HTTP 500 with a clear message."""
157
+
158
+ def crash_run(*, user_query: str, schema_preview: str | None = None): # type: ignore[no-untyped-def]
159
+ raise RuntimeError("boom")
160
+
161
+ app.dependency_overrides[nl2sql.get_runner] = lambda: crash_run
162
+ try:
163
+ resp = client.post(path, json={"query": "x"})
164
+ assert resp.status_code == 500
165
+ assert "Pipeline crash" in resp.json()["detail"]
166
+ finally:
167
+ app.dependency_overrides.pop(nl2sql.get_runner, None)
168
+
169
+
170
+ # --- 6) Unexpected output type → 500 -----------------------------------------
171
+ def test_pipeline_returns_non_finalresult():
172
+ """If pipeline returns a non-FinalResult, it must yield HTTP 500."""
173
+
174
+ def bad_run(
175
+ *, user_query: str, schema_preview: str | None = None
176
+ ): # no FinalResult
177
+ return {"ok": True}
178
+
179
+ app.dependency_overrides[nl2sql.get_runner] = lambda: bad_run
180
+ try:
181
+ resp = client.post(path, json={"query": "x"})
182
+ assert resp.status_code == 500
183
+ assert "unexpected type" in resp.json()["detail"].lower()
184
+ finally:
185
+ app.dependency_overrides.pop(nl2sql.get_runner, None)
186
+
187
+
188
+ # --- 7) Ambiguous without questions (edge case) ------------------------------
189
+ def test_ambiguity_without_questions_edge_case():
190
+ """
191
+ If ambiguous=True but questions is None, handler should not crash.
192
+ Accept either 200 (if handler treats it as clarify) or 400 (if treated as error).
193
+ """
194
+
195
+ def bad_ambiguous(
196
+ *, user_query: str, schema_preview: str | None = None
197
+ ) -> FinalResult:
198
+ return FinalResult(
199
+ ok=True,
200
+ ambiguous=True,
201
+ error=False,
202
+ details=["ambiguous but no questions"],
203
+ questions=None,
204
+ sql=None,
205
+ rationale=None,
206
+ verified=None,
207
+ traces=[fake_trace("detector")],
208
+ )
209
+
210
+ app.dependency_overrides[nl2sql.get_runner] = lambda: bad_ambiguous
211
+ try:
212
+ resp = client.post(path, json={"query": "x"})
213
+ assert resp.status_code in (200, 400)
214
+ finally:
215
+ app.dependency_overrides.pop(nl2sql.get_runner, None)
216
+
217
+
218
+ # --- 8) FastAPI validation (422) ---------------------------------------------
219
+ def test_validation_422_missing_query():
220
+ """Pydantic/FastAPI should return 422 when required field is missing."""
221
+ resp = client.post(path, json={"schema_preview": "CREATE TABLE t(id int);"})
222
+ assert resp.status_code == 422
223
+
224
+
225
+ # --- 9) Trace rounding to int ------------------------------------------------
226
+ def test_traces_are_rounded_to_ints():
227
+ """duration_ms in traces must be coerced/rounded to int in the response."""
228
+
229
+ def run_with_float_traces(
230
+ *, user_query: str, schema_preview: str | None = None
231
+ ) -> FinalResult:
232
+ return FinalResult(
233
+ ok=True,
234
+ ambiguous=False,
235
+ error=False,
236
+ details=None,
237
+ questions=None,
238
+ sql="SELECT 1;",
239
+ rationale=None,
240
+ verified=True,
241
+ traces=[
242
+ {"stage": "x", "duration_ms": 12.7, "notes": None, "cost_usd": None}
243
+ ],
244
+ )
245
+
246
+ app.dependency_overrides[nl2sql.get_runner] = lambda: run_with_float_traces
247
+ try:
248
+ resp = client.post(path, json={"query": "x"})
249
+ assert resp.status_code == 200
250
+ traces = resp.json()["traces"]
251
+ assert isinstance(traces, list) and traces
252
+ assert isinstance(traces[0]["duration_ms"], int)
253
+ finally:
254
+ app.dependency_overrides.pop(nl2sql.get_runner, None)