Melika Kheirieh commited on
Commit
8618ece
·
1 Parent(s): 552a3c5

refactor(core): DI-ready Pipeline; add registry + YAML factory + typed trace/result

Browse files
Files changed (1) hide show
  1. nl2sql/pipeline_factory.py +89 -124
nl2sql/pipeline_factory.py CHANGED
@@ -1,4 +1,3 @@
1
- # nl2sql/pipeline_factory.py
2
  from __future__ import annotations
3
 
4
  import os
@@ -15,7 +14,15 @@ from nl2sql.registry import (
15
  VERIFIERS,
16
  REPAIRS,
17
  )
18
- from nl2sql.types import StageResult
 
 
 
 
 
 
 
 
19
  from adapters.db.base import DBAdapter
20
  from adapters.db.sqlite_adapter import SQLiteAdapter
21
  from adapters.db.postgres_adapter import PostgresAdapter
@@ -35,15 +42,12 @@ def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
35
  dsn = _require_str(adapter_cfg.get("dsn"), name="adapter.dsn")
36
  return SQLiteAdapter(dsn)
37
  if kind == "postgres":
38
- # Pass through any kwargs your adapter expects (dsn, host, user, ...)
39
  return PostgresAdapter(**adapter_cfg)
40
  raise ValueError(f"Unknown adapter kind: {kind}")
41
 
42
 
43
  def _build_llm(llm_cfg: Optional[Dict[str, Any]] = None) -> Any:
44
- """
45
- Build the LLM provider. Under pytest we return None so stubs are used.
46
- """
47
  if os.getenv("PYTEST_CURRENT_TEST"):
48
  return None
49
  _ = llm_cfg or {}
@@ -54,6 +58,25 @@ def _is_pytest() -> bool:
54
  return bool(os.getenv("PYTEST_CURRENT_TEST"))
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # ------------------------------ factory ------------------------------ #
58
  def pipeline_from_config(path: str) -> Pipeline:
59
  """
@@ -68,7 +91,6 @@ def pipeline_from_config(path: str) -> Pipeline:
68
  # --- Adapter ---
69
  adapter_cfg = cast(Dict[str, Any], cfg.get("adapter", {}))
70
  if is_pytest:
71
- # Avoid filesystem errors during tests
72
  adapter_cfg = {"kind": "sqlite", "dsn": ":memory:"}
73
  adapter = _build_adapter(adapter_cfg)
74
 
@@ -77,120 +99,85 @@ def pipeline_from_config(path: str) -> Pipeline:
77
  llm = _build_llm(llm_cfg)
78
 
79
  if is_pytest:
80
-
81
  class _StubDetector:
82
- # Domain method: return list[str]
83
  def detect(self, *args, **kwargs) -> list[str]:
84
- return [] # no ambiguities
85
 
86
- # Compatibility: return StageResult
87
  def run(self, *args, **kwargs) -> StageResult:
88
  return StageResult(
89
  ok=True,
90
  data={"questions": []},
91
- trace={
92
- "stage": "detector",
93
- "duration_ms": 0,
94
- "notes": {"ambiguous": False, "questions_len": 0},
95
- },
96
  )
97
 
98
  class _StubPlanner:
99
  def __init__(self, llm: Any = None) -> None: ...
100
-
101
- # Domain: return str (plan text)
102
  def plan(self, *args, **kwargs) -> str:
103
  return "stub plan"
104
 
105
- # Compat: StageResult
106
  def run(self, *args, **kwargs) -> StageResult:
 
107
  return StageResult(
108
  ok=True,
109
- data={"plan": "stub plan"},
110
- trace={
111
- "stage": "planner",
112
- "duration_ms": 0,
113
- "notes": {"len_plan": 9},
114
- },
115
  )
116
 
117
  class _StubGenerator:
118
  def __init__(self, llm: Any = None) -> None: ...
119
-
120
- # Domain: return tuple[str, str] → (sql, rationale)
121
  def generate(self, *args, **kwargs) -> tuple[str, str]:
122
  return "SELECT 1;", "stub"
123
 
124
- # Compat: StageResult
125
  def run(self, *args, **kwargs) -> StageResult:
126
  sql, rationale = self.generate(*args, **kwargs)
127
  return StageResult(
128
  ok=True,
129
  data={"sql": sql, "rationale": rationale},
130
- trace={
131
- "stage": "generator",
132
- "duration_ms": 0,
133
- "notes": {"rationale_len": len(rationale)},
134
- },
135
  )
136
 
137
  class _StubExecutor:
138
  def __init__(self, db: Any | None = None) -> None: ...
139
-
140
- # Domain: return dict (execution result)
141
  def execute(self, *args, **kwargs) -> Dict[str, Any]:
142
  rows = [{"x": 1}]
143
  return {"rows": rows, "row_count": len(rows)}
144
 
145
- # Compat: StageResult
146
  def run(self, *args, **kwargs) -> StageResult:
147
  out = self.execute(*args, **kwargs)
148
  return StageResult(
149
  ok=True,
150
  data=out,
151
- trace={
152
- "stage": "executor",
153
- "duration_ms": 0,
154
- "notes": {"row_count": out["row_count"]},
155
- },
156
  )
157
 
158
  class _StubVerifier:
159
- # Domain: return bool
160
  def verify(self, *args, **kwargs) -> bool:
161
  return True
162
 
163
- # Compat: StageResult
164
  def run(self, *args, **kwargs) -> StageResult:
165
  return StageResult(
166
- ok=True,
167
- data={"verified": True},
168
- trace={"stage": "verifier", "duration_ms": 0, "notes": None},
169
  )
170
 
171
  class _StubRepair:
172
  def __init__(self, llm: Any = None) -> None: ...
173
-
174
- # Domain: return str (repaired SQL)
175
  def repair(self, *args, **kwargs) -> str:
176
  return kwargs.get("sql") or "SELECT 1;"
177
 
178
- # Compat: StageResult
179
  def run(self, *args, **kwargs) -> StageResult:
180
  sql = self.repair(*args, **kwargs)
181
- return StageResult(
182
- ok=True,
183
- data={"sql": sql},
184
- trace={"stage": "repair", "duration_ms": 0, "notes": None},
185
- )
186
 
187
- detector = _StubDetector()
188
- planner = _StubPlanner()
189
- generator = _StubGenerator()
190
  safety = SAFETIES[cfg.get("safety", "default")]()
191
- executor = _StubExecutor(db=adapter)
192
- verifier = _StubVerifier()
193
- repair = _StubRepair()
194
 
195
  else:
196
  detector = DETECTORS[cfg.get("detector", "default")]()
@@ -227,105 +214,83 @@ def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipel
227
  if is_pytest:
228
 
229
  class _StubDetector:
230
- def detect(self, *args, **kwargs) -> StageResult:
 
 
 
231
  return StageResult(
232
  ok=True,
233
  data={"questions": []},
234
- trace={
235
- "stage": "detector",
236
- "duration_ms": 0,
237
- "notes": {"ambiguous": False, "questions_len": 0},
238
- },
239
  )
240
 
241
- def run(self, *args, **kwargs) -> StageResult:
242
- return self.detect(*args, **kwargs)
243
-
244
  class _StubPlanner:
245
  def __init__(self, llm: Any = None) -> None: ...
 
 
246
 
247
- def plan(self, *args, **kwargs) -> StageResult:
 
248
  return StageResult(
249
  ok=True,
250
- data={"plan": "stub plan"},
251
- trace={
252
- "stage": "planner",
253
- "duration_ms": 0,
254
- "notes": {"len_plan": 8},
255
- },
256
  )
257
 
258
- def run(self, *args, **kwargs) -> StageResult:
259
- return self.plan(*args, **kwargs)
260
-
261
  class _StubGenerator:
262
  def __init__(self, llm: Any = None) -> None: ...
 
 
263
 
264
- def generate(self, *args, **kwargs) -> StageResult:
 
265
  return StageResult(
266
  ok=True,
267
- data={"sql": "SELECT 1;", "rationale": "stub"},
268
- trace={
269
- "stage": "generator",
270
- "duration_ms": 0,
271
- "notes": {"rationale_len": 4},
272
- },
273
  )
274
 
275
- def run(self, *args, **kwargs) -> StageResult:
276
- return self.generate(*args, **kwargs)
277
-
278
  class _StubExecutor:
279
- def __init__(self, db: DBAdapter | None = None) -> None: ...
280
-
281
- def execute(self, *args, **kwargs) -> StageResult:
282
  rows = [{"x": 1}]
283
- return StageResult(
284
- ok=True,
285
- data={"rows": rows, "row_count": len(rows)},
286
- trace={
287
- "stage": "executor",
288
- "duration_ms": 0,
289
- "notes": {"row_count": len(rows)},
290
- },
291
- )
292
 
293
  def run(self, *args, **kwargs) -> StageResult:
294
- return self.execute(*args, **kwargs)
295
-
296
- class _StubVerifier:
297
- def verify(self, *args, **kwargs) -> StageResult:
298
  return StageResult(
299
  ok=True,
300
- data={"verified": True},
301
- trace={"stage": "verifier", "duration_ms": 0, "notes": None},
302
  )
303
 
 
 
 
 
304
  def run(self, *args, **kwargs) -> StageResult:
305
- return self.verify(*args, **kwargs)
 
 
306
 
307
  class _StubRepair:
308
  def __init__(self, llm: Any = None) -> None: ...
309
-
310
- def repair(self, *args, **kwargs) -> StageResult:
311
- # return original sql if any, else SELECT 1
312
- sql = kwargs.get("sql") or "SELECT 1;"
313
- return StageResult(
314
- ok=True,
315
- data={"sql": sql},
316
- trace={"stage": "repair", "duration_ms": 0, "notes": None},
317
- )
318
 
319
  def run(self, *args, **kwargs) -> StageResult:
320
- return self.repair(*args, **kwargs)
 
321
 
322
- detector = _StubDetector()
323
- planner = _StubPlanner()
324
- generator = _StubGenerator()
325
  safety = SAFETIES[cfg.get("safety", "default")]()
326
- executor = _StubExecutor(db=adapter)
327
- verifier = _StubVerifier()
328
- repair = _StubRepair()
329
 
330
  else:
331
  detector = DETECTORS[cfg.get("detector", "default")]()
 
 
1
  from __future__ import annotations
2
 
3
  import os
 
14
  VERIFIERS,
15
  REPAIRS,
16
  )
17
+ from nl2sql.types import StageResult, StageTrace
18
+
19
+ from nl2sql.ambiguity_detector import AmbiguityDetector
20
+ from nl2sql.planner import Planner
21
+ from nl2sql.generator import Generator
22
+ from nl2sql.executor import Executor
23
+ from nl2sql.verifier import Verifier
24
+ from nl2sql.repair import Repair
25
+
26
  from adapters.db.base import DBAdapter
27
  from adapters.db.sqlite_adapter import SQLiteAdapter
28
  from adapters.db.postgres_adapter import PostgresAdapter
 
42
  dsn = _require_str(adapter_cfg.get("dsn"), name="adapter.dsn")
43
  return SQLiteAdapter(dsn)
44
  if kind == "postgres":
 
45
  return PostgresAdapter(**adapter_cfg)
46
  raise ValueError(f"Unknown adapter kind: {kind}")
47
 
48
 
49
  def _build_llm(llm_cfg: Optional[Dict[str, Any]] = None) -> Any:
50
+ """Under pytest return None (stubs handle logic); otherwise real OpenAI provider."""
 
 
51
  if os.getenv("PYTEST_CURRENT_TEST"):
52
  return None
53
  _ = llm_cfg or {}
 
58
  return bool(os.getenv("PYTEST_CURRENT_TEST"))
59
 
60
 
61
+ def _tr(
62
+ stage: str,
63
+ *,
64
+ duration_ms: int = 0,
65
+ notes: Optional[Dict[str, Any]] = None,
66
+ token_in: Optional[int] = None,
67
+ token_out: Optional[int] = None,
68
+ cost_usd: Optional[float] = None,
69
+ ) -> StageTrace:
70
+ return StageTrace(
71
+ stage=stage,
72
+ duration_ms=duration_ms,
73
+ notes=notes,
74
+ token_in=token_in,
75
+ token_out=token_out,
76
+ cost_usd=cost_usd,
77
+ )
78
+
79
+
80
  # ------------------------------ factory ------------------------------ #
81
  def pipeline_from_config(path: str) -> Pipeline:
82
  """
 
91
  # --- Adapter ---
92
  adapter_cfg = cast(Dict[str, Any], cfg.get("adapter", {}))
93
  if is_pytest:
 
94
  adapter_cfg = {"kind": "sqlite", "dsn": ":memory:"}
95
  adapter = _build_adapter(adapter_cfg)
96
 
 
99
  llm = _build_llm(llm_cfg)
100
 
101
  if is_pytest:
102
+ # ---------- stubs: domain-shaped + StageResult on run() ----------
103
  class _StubDetector:
 
104
  def detect(self, *args, **kwargs) -> list[str]:
105
+ return []
106
 
 
107
  def run(self, *args, **kwargs) -> StageResult:
108
  return StageResult(
109
  ok=True,
110
  data={"questions": []},
111
+ trace=_tr(
112
+ "detector", notes={"ambiguous": False, "questions_len": 0}
113
+ ),
 
 
114
  )
115
 
116
  class _StubPlanner:
117
  def __init__(self, llm: Any = None) -> None: ...
 
 
118
  def plan(self, *args, **kwargs) -> str:
119
  return "stub plan"
120
 
 
121
  def run(self, *args, **kwargs) -> StageResult:
122
+ plan = self.plan(*args, **kwargs)
123
  return StageResult(
124
  ok=True,
125
+ data={"plan": plan},
126
+ trace=_tr("planner", notes={"len_plan": len(plan)}),
 
 
 
 
127
  )
128
 
129
  class _StubGenerator:
130
  def __init__(self, llm: Any = None) -> None: ...
 
 
131
  def generate(self, *args, **kwargs) -> tuple[str, str]:
132
  return "SELECT 1;", "stub"
133
 
 
134
  def run(self, *args, **kwargs) -> StageResult:
135
  sql, rationale = self.generate(*args, **kwargs)
136
  return StageResult(
137
  ok=True,
138
  data={"sql": sql, "rationale": rationale},
139
+ trace=_tr("generator", notes={"rationale_len": len(rationale)}),
 
 
 
 
140
  )
141
 
142
  class _StubExecutor:
143
  def __init__(self, db: Any | None = None) -> None: ...
 
 
144
  def execute(self, *args, **kwargs) -> Dict[str, Any]:
145
  rows = [{"x": 1}]
146
  return {"rows": rows, "row_count": len(rows)}
147
 
 
148
  def run(self, *args, **kwargs) -> StageResult:
149
  out = self.execute(*args, **kwargs)
150
  return StageResult(
151
  ok=True,
152
  data=out,
153
+ trace=_tr("executor", notes={"row_count": out["row_count"]}),
 
 
 
 
154
  )
155
 
156
  class _StubVerifier:
 
157
  def verify(self, *args, **kwargs) -> bool:
158
  return True
159
 
 
160
  def run(self, *args, **kwargs) -> StageResult:
161
  return StageResult(
162
+ ok=True, data={"verified": True}, trace=_tr("verifier")
 
 
163
  )
164
 
165
  class _StubRepair:
166
  def __init__(self, llm: Any = None) -> None: ...
 
 
167
  def repair(self, *args, **kwargs) -> str:
168
  return kwargs.get("sql") or "SELECT 1;"
169
 
 
170
  def run(self, *args, **kwargs) -> StageResult:
171
  sql = self.repair(*args, **kwargs)
172
+ return StageResult(ok=True, data={"sql": sql}, trace=_tr("repair"))
 
 
 
 
173
 
174
+ detector = cast(AmbiguityDetector, _StubDetector())
175
+ planner = cast(Planner, _StubPlanner())
176
+ generator = cast(Generator, _StubGenerator())
177
  safety = SAFETIES[cfg.get("safety", "default")]()
178
+ executor = cast(Executor, _StubExecutor(db=adapter))
179
+ verifier = cast(Verifier, _StubVerifier())
180
+ repair = cast(Repair, _StubRepair())
181
 
182
  else:
183
  detector = DETECTORS[cfg.get("detector", "default")]()
 
214
  if is_pytest:
215
 
216
  class _StubDetector:
217
+ def detect(self, *args, **kwargs) -> list[str]:
218
+ return []
219
+
220
+ def run(self, *args, **kwargs) -> StageResult:
221
  return StageResult(
222
  ok=True,
223
  data={"questions": []},
224
+ trace=_tr(
225
+ "detector", notes={"ambiguous": False, "questions_len": 0}
226
+ ),
 
 
227
  )
228
 
 
 
 
229
  class _StubPlanner:
230
  def __init__(self, llm: Any = None) -> None: ...
231
+ def plan(self, *args, **kwargs) -> str:
232
+ return "stub plan"
233
 
234
+ def run(self, *args, **kwargs) -> StageResult:
235
+ plan = self.plan(*args, **kwargs)
236
  return StageResult(
237
  ok=True,
238
+ data={"plan": plan},
239
+ trace=_tr("planner", notes={"len_plan": len(plan)}),
 
 
 
 
240
  )
241
 
 
 
 
242
  class _StubGenerator:
243
  def __init__(self, llm: Any = None) -> None: ...
244
+ def generate(self, *args, **kwargs) -> tuple[str, str]:
245
+ return "SELECT 1;", "stub"
246
 
247
+ def run(self, *args, **kwargs) -> StageResult:
248
+ sql, rationale = self.generate(*args, **kwargs)
249
  return StageResult(
250
  ok=True,
251
+ data={"sql": sql, "rationale": rationale},
252
+ trace=_tr("generator", notes={"rationale_len": len(rationale)}),
 
 
 
 
253
  )
254
 
 
 
 
255
  class _StubExecutor:
256
+ def __init__(self, db: Any | None = None) -> None: ...
257
+ def execute(self, *args, **kwargs) -> Dict[str, Any]:
 
258
  rows = [{"x": 1}]
259
+ return {"rows": rows, "row_count": len(rows)}
 
 
 
 
 
 
 
 
260
 
261
  def run(self, *args, **kwargs) -> StageResult:
262
+ out = self.execute(*args, **kwargs)
 
 
 
263
  return StageResult(
264
  ok=True,
265
+ data=out,
266
+ trace=_tr("executor", notes={"row_count": out["row_count"]}),
267
  )
268
 
269
+ class _StubVerifier:
270
+ def verify(self, *args, **kwargs) -> bool:
271
+ return True
272
+
273
  def run(self, *args, **kwargs) -> StageResult:
274
+ return StageResult(
275
+ ok=True, data={"verified": True}, trace=_tr("verifier")
276
+ )
277
 
278
  class _StubRepair:
279
  def __init__(self, llm: Any = None) -> None: ...
280
+ def repair(self, *args, **kwargs) -> str:
281
+ return kwargs.get("sql") or "SELECT 1;"
 
 
 
 
 
 
 
282
 
283
  def run(self, *args, **kwargs) -> StageResult:
284
+ sql = self.repair(*args, **kwargs)
285
+ return StageResult(ok=True, data={"sql": sql}, trace=_tr("repair"))
286
 
287
+ detector = cast(AmbiguityDetector, _StubDetector())
288
+ planner = cast(Planner, _StubPlanner())
289
+ generator = cast(Generator, _StubGenerator())
290
  safety = SAFETIES[cfg.get("safety", "default")]()
291
+ executor = cast(Executor, _StubExecutor(db=adapter))
292
+ verifier = cast(Verifier, _StubVerifier())
293
+ repair = cast(Repair, _StubRepair())
294
 
295
  else:
296
  detector = DETECTORS[cfg.get("detector", "default")]()