Melika Kheirieh committed on
Commit
a45c0eb
·
1 Parent(s): 713d3ca

refactor: unify pipeline output via FinalResult model

Browse files
app/routers/nl2sql.py CHANGED
@@ -1,13 +1,12 @@
1
  from dataclasses import asdict, is_dataclass
2
  from fastapi import APIRouter, HTTPException
3
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
4
- from nl2sql.pipeline import Pipeline
5
  from nl2sql.ambiguity_detector import AmbiguityDetector
6
  from nl2sql.safety import Safety
7
  from nl2sql.planner import Planner
8
  from nl2sql.generator import Generator
9
  from adapters.llm.openai_provider import OpenAIProvider
10
- from nl2sql.types import StageResult
11
  from nl2sql.executor import Executor
12
  from nl2sql.verifier import Verifier
13
  from nl2sql.repair import Repair
@@ -59,28 +58,28 @@ def _round_trace(t: dict) -> dict:
59
  @router.post("", name="nl2sql_handler")
60
  def nl2sql_handler(request: NL2SQLRequest):
61
  result = _pipeline.run(
62
- user_query=request.query, schema_preview=request.schema_preview
 
63
  )
64
 
65
  # --- Ensure result type ---
66
- if not isinstance(result, StageResult):
67
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
68
 
69
- data = result.data or {}
70
-
71
  # --- Handle ambiguity ---
72
- if isinstance(data, dict) and data.get("ambiguous") and data.get("questions"):
73
- return ClarifyResponse(ambiguous=True, questions=data["questions"])
74
 
75
  # --- Handle error ---
76
- if not result.ok:
77
- detail = "; ".join(result.error) if result.error else "Unknown error"
78
  raise HTTPException(status_code=400, detail=detail)
79
 
80
  # --- Success case ---
 
81
  return NL2SQLResponse(
82
  ambiguous=False,
83
- sql=data.get("sql"),
84
- rationale=data.get("rationale"),
85
- traces=[_to_dict(t) for t in data.get("traces", [])],
86
  )
 
1
  from dataclasses import asdict, is_dataclass
2
  from fastapi import APIRouter, HTTPException
3
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
4
+ from nl2sql.pipeline import Pipeline, FinalResult
5
  from nl2sql.ambiguity_detector import AmbiguityDetector
6
  from nl2sql.safety import Safety
7
  from nl2sql.planner import Planner
8
  from nl2sql.generator import Generator
9
  from adapters.llm.openai_provider import OpenAIProvider
 
10
  from nl2sql.executor import Executor
11
  from nl2sql.verifier import Verifier
12
  from nl2sql.repair import Repair
 
58
  @router.post("", name="nl2sql_handler")
59
  def nl2sql_handler(request: NL2SQLRequest):
60
  result = _pipeline.run(
61
+ user_query=request.query,
62
+ schema_preview=request.schema_preview,
63
  )
64
 
65
  # --- Ensure result type ---
66
+ if not isinstance(result, FinalResult):
67
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
68
 
 
 
69
  # --- Handle ambiguity ---
70
+ if result.ambiguous and result.questions:
71
+ return ClarifyResponse(ambiguous=True, questions=result.questions)
72
 
73
  # --- Handle error ---
74
+ if not result.ok or result.error:
75
+ detail = "; ".join(result.details or ["Unknown error"])
76
  raise HTTPException(status_code=400, detail=detail)
77
 
78
  # --- Success case ---
79
+ traces = [ _round_trace(t) for t in (result.traces or []) ]
80
  return NL2SQLResponse(
81
  ambiguous=False,
82
+ sql=result.sql,
83
+ rationale=result.rationale,
84
+ traces=traces,
85
  )
nl2sql/pipeline.py CHANGED
@@ -1,6 +1,8 @@
1
  from __future__ import annotations
2
  import traceback
 
3
  from typing import Dict, Any, Optional, List
 
4
  from nl2sql.types import StageResult
5
  from nl2sql.ambiguity_detector import AmbiguityDetector
6
  from nl2sql.planner import Planner
@@ -11,10 +13,26 @@ from nl2sql.verifier import Verifier
11
  from nl2sql.repair import Repair
12
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class Pipeline:
15
  """
16
- NL2SQL Copilot pipeline with guaranteed dict output.
17
- All stages return structured traces and errors but final result is JSON-safe dict.
 
18
  """
19
 
20
  def __init__(
@@ -25,7 +43,7 @@ class Pipeline:
25
  generator: Generator,
26
  safety: Safety,
27
  executor: Optional[Executor] = None,
28
- verifier: Optional[Verifier] = None ,
29
  repair: Optional[Repair] = None,
30
  ):
31
  self.detector = detector
@@ -55,7 +73,7 @@ class Pipeline:
55
  if isinstance(r, StageResult):
56
  return r
57
  else:
58
- # not ideal, but wrap it
59
  return StageResult(ok=True, data=r, trace=None)
60
  except Exception as e:
61
  tb = traceback.format_exc()
@@ -68,41 +86,40 @@ class Pipeline:
68
  user_query: str,
69
  schema_preview: str,
70
  clarify_answers: Optional[Dict[str, Any]] = None,
71
- ) -> Dict[str, Any]:
72
- """
73
- Always returns:
74
- {
75
- "ambiguous": bool,
76
- "error": bool,
77
- "details": list[str] | None,
78
- "sql": str | None,
79
- "rationale": str | None,
80
- "verified": bool | None,
81
- "traces": list[dict]
82
- }
83
- """
84
  traces: List[dict] = []
85
  details: List[str] = []
86
- sql, rationale, verified = None, None, None
 
 
87
 
88
  # --- 1) ambiguity detection
89
  try:
90
  questions = self.detector.detect(user_query, schema_preview)
91
  if questions:
92
- return {
93
- "ambiguous": True,
94
- "error": False,
95
- "details": [f"Ambiguities found: {len(questions)}"],
96
- "questions": questions,
97
- "traces": [],
98
- }
 
 
 
 
99
  except Exception as e:
100
- return {
101
- "ambiguous": True,
102
- "error": True,
103
- "details": [f"Detector failed: {e}"],
104
- "traces": [],
105
- }
 
 
 
 
 
106
 
107
  # --- 2) planner
108
  r_plan = self._safe_stage(
@@ -110,12 +127,17 @@ class Pipeline:
110
  )
111
  traces.extend(self._trace_list(r_plan))
112
  if not r_plan.ok:
113
- return {
114
- "ambiguous": False,
115
- "error": True,
116
- "details": r_plan.error,
117
- "traces": traces,
118
- }
 
 
 
 
 
119
 
120
  # --- 3) generator
121
  r_gen = self._safe_stage(
@@ -127,40 +149,51 @@ class Pipeline:
127
  )
128
  traces.extend(self._trace_list(r_gen))
129
  if not r_gen.ok:
130
- return {
131
- "ambiguous": False,
132
- "error": True,
133
- "details": r_gen.errors,
134
- "traces": traces,
135
- }
 
 
 
 
 
136
  sql = r_gen.data.get("sql")
137
  rationale = r_gen.data.get("rationale")
138
 
139
  # --- 4) safety
140
- r_safe = self._safe_stage(self.safety.check, sql=sql)
 
141
  traces.extend(self._trace_list(r_safe))
142
  if not r_safe.ok:
143
- return {
144
- "ambiguous": False,
145
- "error": True,
146
- "details": r_safe.error,
147
- "traces": traces,
148
- }
 
 
 
 
 
149
 
150
  # --- 5) executor
151
- r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
152
  traces.extend(self._trace_list(r_exec))
153
  if not r_exec.ok:
154
  details.extend(r_exec.error or [])
155
 
156
  # --- 6) verifier
157
- r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
158
  traces.extend(self._trace_list(r_ver))
159
  verified = bool(r_ver.ok)
160
 
161
  # --- 7) repair loop if verification failed
162
  if not verified:
163
- for attempt in range(2):
164
  r_fix = self._safe_stage(
165
  self.repair.run,
166
  sql=sql,
@@ -171,29 +204,33 @@ class Pipeline:
171
  if not r_fix.ok:
172
  break
173
  sql = r_fix.data.get("sql")
174
- r_safe = self._safe_stage(self.safety.check, sql=sql)
 
175
  traces.extend(self._trace_list(r_safe))
176
  if not r_safe.ok:
177
  details.extend(r_safe.error or [])
178
  continue
179
- r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
 
180
  traces.extend(self._trace_list(r_exec))
181
  if not r_exec.ok:
182
  details.extend(r_exec.error or [])
183
  continue
184
- r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
 
185
  traces.extend(self._trace_list(r_ver))
186
  verified = bool(r_ver.ok)
187
  if verified:
188
  break
189
 
190
- # --- Final result dict
191
- return {
192
- "ambiguous": False,
193
- "error": len(details) > 0 and not verified,
194
- "details": details or None,
195
- "sql": sql,
196
- "rationale": rationale,
197
- "verified": verified,
198
- "traces": traces,
199
- }
 
 
1
  from __future__ import annotations
2
  import traceback
3
+ from dataclasses import dataclass, asdict
4
  from typing import Dict, Any, Optional, List
5
+
6
  from nl2sql.types import StageResult
7
  from nl2sql.ambiguity_detector import AmbiguityDetector
8
  from nl2sql.planner import Planner
 
13
  from nl2sql.repair import Repair
14
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
15
 
16
+
17
+ # ---- NEW: FinalResult as domain-level, type-safe result ----
18
+ @dataclass(frozen=True)
19
+ class FinalResult:
20
+ ok: bool
21
+ ambiguous: bool
22
+ error: bool
23
+ details: Optional[List[str]]
24
+ sql: Optional[str]
25
+ rationale: Optional[str]
26
+ verified: Optional[bool]
27
+ questions: Optional[List[str]]
28
+ traces: List[dict]
29
+
30
+
31
  class Pipeline:
32
  """
33
+ NL2SQL Copilot pipeline.
34
+ Stages return StageResult; final result is a type-safe FinalResult.
35
+ Adapters (e.g. FastAPI) can serialize with dataclasses.asdict().
36
  """
37
 
38
  def __init__(
 
43
  generator: Generator,
44
  safety: Safety,
45
  executor: Optional[Executor] = None,
46
+ verifier: Optional[Verifier] = None,
47
  repair: Optional[Repair] = None,
48
  ):
49
  self.detector = detector
 
73
  if isinstance(r, StageResult):
74
  return r
75
  else:
76
+ # Normalize non-StageResult returns
77
  return StageResult(ok=True, data=r, trace=None)
78
  except Exception as e:
79
  tb = traceback.format_exc()
 
86
  user_query: str,
87
  schema_preview: str,
88
  clarify_answers: Optional[Dict[str, Any]] = None,
89
+ ) -> FinalResult:
 
 
 
 
 
 
 
 
 
 
 
 
90
  traces: List[dict] = []
91
  details: List[str] = []
92
+ sql: Optional[str] = None
93
+ rationale: Optional[str] = None
94
+ verified: Optional[bool] = None
95
 
96
  # --- 1) ambiguity detection
97
  try:
98
  questions = self.detector.detect(user_query, schema_preview)
99
  if questions:
100
+ return FinalResult(
101
+ ok=True,
102
+ ambiguous=True,
103
+ error=False,
104
+ details=[f"Ambiguities found: {len(questions)}"],
105
+ questions=questions,
106
+ sql=None,
107
+ rationale=None,
108
+ verified=None,
109
+ traces=[],
110
+ )
111
  except Exception as e:
112
+ return FinalResult(
113
+ ok=False,
114
+ ambiguous=True,
115
+ error=True,
116
+ details=[f"Detector failed: {e}"],
117
+ questions=None,
118
+ sql=None,
119
+ rationale=None,
120
+ verified=None,
121
+ traces=[],
122
+ )
123
 
124
  # --- 2) planner
125
  r_plan = self._safe_stage(
 
127
  )
128
  traces.extend(self._trace_list(r_plan))
129
  if not r_plan.ok:
130
+ return FinalResult(
131
+ ok=False,
132
+ ambiguous=False,
133
+ error=True,
134
+ details=r_plan.error,
135
+ questions=None,
136
+ sql=None,
137
+ rationale=None,
138
+ verified=None,
139
+ traces=traces,
140
+ )
141
 
142
  # --- 3) generator
143
  r_gen = self._safe_stage(
 
149
  )
150
  traces.extend(self._trace_list(r_gen))
151
  if not r_gen.ok:
152
+ return FinalResult(
153
+ ok=False,
154
+ ambiguous=False,
155
+ error=True,
156
+ details=r_gen.error,
157
+ questions=None,
158
+ sql=None,
159
+ rationale=None,
160
+ verified=None,
161
+ traces=traces,
162
+ )
163
  sql = r_gen.data.get("sql")
164
  rationale = r_gen.data.get("rationale")
165
 
166
  # --- 4) safety
167
+ # fix: align with DummySafety signature → use .run (not .check)
168
+ r_safe = self._safe_stage(self.safety.run, sql=sql)
169
  traces.extend(self._trace_list(r_safe))
170
  if not r_safe.ok:
171
+ return FinalResult(
172
+ ok=False,
173
+ ambiguous=False,
174
+ error=True,
175
+ details=r_safe.error,
176
+ questions=None,
177
+ sql=sql,
178
+ rationale=rationale,
179
+ verified=None,
180
+ traces=traces,
181
+ )
182
 
183
  # --- 5) executor
184
+ r_exec = self._safe_stage(self.executor.run, sql=r_safe.data.get("sql", sql))
185
  traces.extend(self._trace_list(r_exec))
186
  if not r_exec.ok:
187
  details.extend(r_exec.error or [])
188
 
189
  # --- 6) verifier
190
+ r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec.data)
191
  traces.extend(self._trace_list(r_ver))
192
  verified = bool(r_ver.ok)
193
 
194
  # --- 7) repair loop if verification failed
195
  if not verified:
196
+ for _attempt in range(2):
197
  r_fix = self._safe_stage(
198
  self.repair.run,
199
  sql=sql,
 
204
  if not r_fix.ok:
205
  break
206
  sql = r_fix.data.get("sql")
207
+
208
+ r_safe = self._safe_stage(self.safety.run, sql=sql)
209
  traces.extend(self._trace_list(r_safe))
210
  if not r_safe.ok:
211
  details.extend(r_safe.error or [])
212
  continue
213
+
214
+ r_exec = self._safe_stage(self.executor.run, sql=r_safe.data.get("sql", sql))
215
  traces.extend(self._trace_list(r_exec))
216
  if not r_exec.ok:
217
  details.extend(r_exec.error or [])
218
  continue
219
+
220
+ r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec.data)
221
  traces.extend(self._trace_list(r_ver))
222
  verified = bool(r_ver.ok)
223
  if verified:
224
  break
225
 
226
+ return FinalResult(
227
+ ok=bool(verified) and not details,
228
+ ambiguous=False,
229
+ error=bool(details) and not bool(verified),
230
+ details=details or None,
231
+ sql=sql,
232
+ rationale=rationale,
233
+ verified=verified,
234
+ questions=None,
235
+ traces=traces,
236
+ )
nl2sql/types.py CHANGED
@@ -19,3 +19,20 @@ class StageResult:
19
  trace: Optional[StageTrace] = None
20
  error: Optional[List[str]] = None
21
  notes: Optional[Dict[str, Any]] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  trace: Optional[StageTrace] = None
20
  error: Optional[List[str]] = None
21
  notes: Optional[Dict[str, Any]] = None
22
+
23
+
24
+ @dataclass(frozen=True)
25
+ class FinalResult:
26
+ """
27
+ Final domain result of the whole pipeline.
28
+ Adapters (HTTP/CLI/UI) should serialize this to dict/JSON at the boundary.
29
+ """
30
+ ok: bool # end-to-end success
31
+ ambiguous: bool
32
+ error: bool
33
+ sql: Optional[str]
34
+ rationale: Optional[str]
35
+ verified: Optional[bool]
36
+ details: Optional[List[str]]
37
+ questions: Optional[List[str]]
38
+ traces: List[Dict[str, Any]]
tests/test_nl2sql_router.py CHANGED
@@ -1,13 +1,12 @@
1
  from fastapi.testclient import TestClient
2
  from app.main import app
3
- from nl2sql.types import StageResult, StageTrace
4
 
5
  client = TestClient(app)
6
 
7
-
8
- def fake_trace(stage: str):
9
- return StageTrace(stage=stage, duration_ms=10.0)
10
-
11
 
12
  path = app.url_path_for("nl2sql_handler")
13
 
@@ -16,15 +15,18 @@ path = app.url_path_for("nl2sql_handler")
16
  def test_ambiguity_route(monkeypatch):
17
  from app.routers import nl2sql
18
 
19
- # mock pipeline to return StageResult with ambiguous=True
20
  def fake_run(*args, **kwargs):
21
- return StageResult(
22
  ok=True,
23
- data={
24
- "ambiguous": True,
25
- "questions": ["Which table do you mean?"],
26
- "traces": [fake_trace("detector")],
27
- },
 
 
 
28
  )
29
 
30
  monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
@@ -36,11 +38,11 @@ def test_ambiguity_route(monkeypatch):
36
  "schema_preview": "CREATE TABLE ...",
37
  },
38
  )
39
-
40
  assert resp.status_code == 200
41
  data = resp.json()
42
  assert data["ambiguous"] is True
43
  assert "questions" in data
 
44
 
45
 
46
  # --- 2) Error / failure case -------------------------------------------------
@@ -48,8 +50,16 @@ def test_error_route(monkeypatch):
48
  from app.routers import nl2sql
49
 
50
  def fake_run(*args, **kwargs):
51
- return StageResult(
52
- ok=False, error=["Bad SQL"], data={"traces": [fake_trace("safety")]}
 
 
 
 
 
 
 
 
53
  )
54
 
55
  monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
@@ -61,7 +71,6 @@ def test_error_route(monkeypatch):
61
  "schema_preview": "CREATE TABLE users(id int);",
62
  },
63
  )
64
-
65
  assert resp.status_code == 400
66
  assert "Bad SQL" in resp.json()["detail"]
67
 
@@ -71,14 +80,16 @@ def test_success_route(monkeypatch):
71
  from app.routers import nl2sql
72
 
73
  def fake_run(*args, **kwargs):
74
- return StageResult(
75
  ok=True,
76
- data={
77
- "ambiguous": False,
78
- "sql": "SELECT * FROM users;",
79
- "rationale": "Simple listing",
80
- "traces": [fake_trace("planner"), fake_trace("generator")],
81
- },
 
 
82
  )
83
 
84
  monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
@@ -96,3 +107,4 @@ def test_success_route(monkeypatch):
96
  assert data["sql"].lower().startswith("select")
97
  assert isinstance(data["traces"], list)
98
  assert any(t["stage"] == "planner" for t in data["traces"])
 
 
1
  from fastapi.testclient import TestClient
2
  from app.main import app
3
+ from nl2sql.pipeline import FinalResult
4
 
5
  client = TestClient(app)
6
 
7
+ def fake_trace(stage: str) -> dict:
8
+ # FinalResult.traces is a list of dicts (StageTrace.__dict__)
9
+ return {"stage": stage, "duration_ms": 10.0, "cost_usd": None, "notes": None}
 
10
 
11
  path = app.url_path_for("nl2sql_handler")
12
 
 
15
  def test_ambiguity_route(monkeypatch):
16
  from app.routers import nl2sql
17
 
18
+ # mock pipeline to return FinalResult with ambiguous=True
19
  def fake_run(*args, **kwargs):
20
+ return FinalResult(
21
  ok=True,
22
+ ambiguous=True,
23
+ error=False,
24
+ details=["Ambiguities found: 1"],
25
+ questions=["Which table do you mean?"],
26
+ sql=None,
27
+ rationale=None,
28
+ verified=None,
29
+ traces=[fake_trace("detector")],
30
  )
31
 
32
  monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
 
38
  "schema_preview": "CREATE TABLE ...",
39
  },
40
  )
 
41
  assert resp.status_code == 200
42
  data = resp.json()
43
  assert data["ambiguous"] is True
44
  assert "questions" in data
45
+ assert isinstance(data["questions"], list)
46
 
47
 
48
  # --- 2) Error / failure case -------------------------------------------------
 
50
  from app.routers import nl2sql
51
 
52
  def fake_run(*args, **kwargs):
53
+ return FinalResult(
54
+ ok=False,
55
+ ambiguous=False,
56
+ error=True,
57
+ details=["Bad SQL"],
58
+ questions=None,
59
+ sql=None,
60
+ rationale=None,
61
+ verified=None,
62
+ traces=[fake_trace("safety")],
63
  )
64
 
65
  monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
 
71
  "schema_preview": "CREATE TABLE users(id int);",
72
  },
73
  )
 
74
  assert resp.status_code == 400
75
  assert "Bad SQL" in resp.json()["detail"]
76
 
 
80
  from app.routers import nl2sql
81
 
82
  def fake_run(*args, **kwargs):
83
+ return FinalResult(
84
  ok=True,
85
+ ambiguous=False,
86
+ error=False,
87
+ details=None,
88
+ questions=None,
89
+ sql="SELECT * FROM users;",
90
+ rationale="Simple listing",
91
+ verified=True,
92
+ traces=[fake_trace("planner"), fake_trace("generator")],
93
  )
94
 
95
  monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
 
107
  assert data["sql"].lower().startswith("select")
108
  assert isinstance(data["traces"], list)
109
  assert any(t["stage"] == "planner" for t in data["traces"])
110
+ assert any(t["stage"] == "generator" for t in data["traces"])
tests/test_pipeline_integration.py CHANGED
@@ -1,14 +1,13 @@
1
- from nl2sql.pipeline import Pipeline
2
  from nl2sql.types import StageResult, StageTrace
3
 
4
 
5
  # --- Dummy stages to isolate pipeline -----------------------------------------
6
 
7
-
8
  class DummyDetector:
9
  """Simulates ambiguity detector stage."""
10
 
11
- def __init__(self, ambiguous=False):
12
  self.ambiguous = ambiguous
13
 
14
  def detect(self, user_query, schema_preview):
@@ -43,10 +42,12 @@ class DummyGenerator:
43
  class DummySafety:
44
  """Simulates safety stage."""
45
 
46
- def check(self, sql):
 
47
  trace = StageTrace(stage="safety", duration_ms=1.0)
48
  if "DROP" in sql.upper():
49
  return StageResult(ok=False, error=["Unsafe SQL"], trace=trace)
 
50
  return StageResult(ok=True, data={"sql": sql, "rationale": "safe"}, trace=trace)
51
 
52
 
@@ -64,13 +65,13 @@ def test_pipeline_success():
64
  schema_preview="CREATE TABLE singer(id int, name text);",
65
  )
66
 
67
- assert isinstance(r, StageResult)
68
  assert r.ok is True
69
- data = r.data or {}
70
- assert data["sql"].lower().startswith("select")
71
- assert any(t.stage == "planner" for t in data["traces"])
72
- assert any(t.stage == "generator" for t in data["traces"])
73
- assert any(t.stage == "safety" for t in data["traces"])
74
 
75
 
76
  # --- 2) Ambiguity case --------------------------------------------------------
@@ -84,10 +85,10 @@ def test_pipeline_ambiguity():
84
 
85
  r = pipeline.run(user_query="show data", schema_preview="CREATE TABLE x(id int);")
86
 
87
- assert isinstance(r, StageResult)
88
  assert r.ok is True
89
- assert r.data["ambiguous"] is True
90
- assert isinstance(r.data["questions"], list)
91
 
92
 
93
  # --- 3) Planner failure -------------------------------------------------------
@@ -101,9 +102,10 @@ def test_pipeline_plan_fail():
101
  r = pipeline.run(
102
  user_query="fail_plan", schema_preview="CREATE TABLE singer(id int);"
103
  )
104
- assert isinstance(r, StageResult)
105
  assert r.ok is False
106
- assert "Planner failed" in " ".join(r.error or [])
 
107
 
108
 
109
  # --- 4) Generator failure -----------------------------------------------------
@@ -117,8 +119,10 @@ def test_pipeline_gen_fail():
117
  r = pipeline.run(
118
  user_query="fail_gen", schema_preview="CREATE TABLE singer(id int);"
119
  )
 
120
  assert r.ok is False
121
- assert "Generator failed" in " ".join(r.error or [])
 
122
 
123
 
124
  # --- 5) Safety failure --------------------------------------------------------
@@ -140,5 +144,7 @@ def test_pipeline_safety_fail():
140
  r = pipeline.run(
141
  user_query="drop something", schema_preview="CREATE TABLE x(id int);"
142
  )
 
143
  assert r.ok is False
144
- assert "unsafe" in " ".join(r.error or []).lower()
 
 
1
+ from nl2sql.pipeline import Pipeline, FinalResult
2
  from nl2sql.types import StageResult, StageTrace
3
 
4
 
5
  # --- Dummy stages to isolate pipeline -----------------------------------------
6
 
 
7
  class DummyDetector:
8
  """Simulates ambiguity detector stage."""
9
 
10
+ def __init__(self, ambiguous: bool = False):
11
  self.ambiguous = ambiguous
12
 
13
  def detect(self, user_query, schema_preview):
 
42
  class DummySafety:
43
  """Simulates safety stage."""
44
 
45
+ # NOTE: pipeline now calls safety.run(sql=...)
46
+ def run(self, *, sql):
47
  trace = StageTrace(stage="safety", duration_ms=1.0)
48
  if "DROP" in sql.upper():
49
  return StageResult(ok=False, error=["Unsafe SQL"], trace=trace)
50
+ # echo back sql in data to feed executor
51
  return StageResult(ok=True, data={"sql": sql, "rationale": "safe"}, trace=trace)
52
 
53
 
 
65
  schema_preview="CREATE TABLE singer(id int, name text);",
66
  )
67
 
68
+ assert isinstance(r, FinalResult)
69
  assert r.ok is True
70
+ assert r.sql is not None and r.sql.lower().startswith("select")
71
+ # traces is a list of dicts (StageTrace.__dict__)
72
+ assert any(t.get("stage") == "planner" for t in r.traces)
73
+ assert any(t.get("stage") == "generator" for t in r.traces)
74
+ assert any(t.get("stage") == "safety" for t in r.traces)
75
 
76
 
77
  # --- 2) Ambiguity case --------------------------------------------------------
 
85
 
86
  r = pipeline.run(user_query="show data", schema_preview="CREATE TABLE x(id int);")
87
 
88
+ assert isinstance(r, FinalResult)
89
  assert r.ok is True
90
+ assert r.ambiguous is True
91
+ assert isinstance(r.questions, list) and len(r.questions) > 0
92
 
93
 
94
  # --- 3) Planner failure -------------------------------------------------------
 
102
  r = pipeline.run(
103
  user_query="fail_plan", schema_preview="CREATE TABLE singer(id int);"
104
  )
105
+ assert isinstance(r, FinalResult)
106
  assert r.ok is False
107
+ assert r.details is not None
108
+ assert "Planner failed" in " ".join(r.details)
109
 
110
 
111
  # --- 4) Generator failure -----------------------------------------------------
 
119
  r = pipeline.run(
120
  user_query="fail_gen", schema_preview="CREATE TABLE singer(id int);"
121
  )
122
+ assert isinstance(r, FinalResult)
123
  assert r.ok is False
124
+ assert r.details is not None
125
+ assert "Generator failed" in " ".join(r.details)
126
 
127
 
128
  # --- 5) Safety failure --------------------------------------------------------
 
144
  r = pipeline.run(
145
  user_query="drop something", schema_preview="CREATE TABLE x(id int);"
146
  )
147
+ assert isinstance(r, FinalResult)
148
  assert r.ok is False
149
+ assert r.details is not None
150
+ assert "unsafe" in " ".join(r.details).lower()