Melika Kheirieh committed on
Commit
343ad62
·
1 Parent(s): 76df10c

refactor(core): DI-ready Pipeline; add registry + YAML factory + typed trace/result

Browse files
app/routers/nl2sql.py CHANGED
@@ -7,10 +7,10 @@ import os
7
  from pathlib import Path
8
  import time
9
  import uuid
10
- from typing import Any, Dict, Optional, TypedDict, Union, Protocol, cast
11
 
12
  # --- Third-party ---
13
- from fastapi import APIRouter, HTTPException, Request, UploadFile, File, Depends
14
 
15
  # --- Local ---
16
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
@@ -25,6 +25,10 @@ from nl2sql.repair import Repair
25
  from adapters.llm.openai_provider import OpenAIProvider
26
  from adapters.db.sqlite_adapter import SQLiteAdapter
27
  from adapters.db.postgres_adapter import PostgresAdapter
 
 
 
 
28
 
29
 
30
  # Stable public re-exports
@@ -53,6 +57,10 @@ _DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
53
  UPLOAD_DIR = Path("data/uploads")
54
  UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
55
 
 
 
 
 
56
 
57
  class DBEntry(TypedDict):
58
  path: str
@@ -110,42 +118,55 @@ _load_db_map()
110
  # -------------------------------
111
  # Adapter selection (lazy)
112
  # -------------------------------
 
113
  def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
114
  """
115
- Resolve a DB adapter:
116
- - postgres: requires POSTGRES_DSN
117
- - sqlite with db_id: uploaded file or fallback locations
118
- - sqlite default: DEFAULT_SQLITE_PATH must exist
 
 
 
 
 
 
 
 
119
  """
120
- mode = os.getenv("DB_MODE", "sqlite").lower()
121
- if mode == "postgres":
122
  dsn = os.environ.get("POSTGRES_DSN")
123
  if not dsn:
124
  raise HTTPException(status_code=500, detail="POSTGRES_DSN env is missing")
125
  return PostgresAdapter(dsn)
126
 
127
  # sqlite mode
128
- _cleanup_db_map()
129
  if db_id:
130
- # Check runtime map
131
- entry = _DB_MAP.get(db_id)
132
- candidates = []
133
- if entry and os.path.exists(entry["path"]):
134
- candidates.append(entry["path"])
135
- # Fallback locations based on convention
136
- candidates.append(os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite"))
137
- candidates.append(str(UPLOAD_DIR / f"{db_id}.sqlite"))
138
-
139
- for p in candidates:
140
- if p and os.path.exists(p):
141
- return SQLiteAdapter(p)
 
 
 
 
 
142
 
143
  raise HTTPException(status_code=400, detail="invalid db_id (file not found)")
144
 
145
- # default sqlite
146
- if not Path(DEFAULT_SQLITE_PATH).exists():
147
- raise HTTPException(status_code=500, detail="default DB not found")
148
- return SQLiteAdapter(DEFAULT_SQLITE_PATH)
 
149
 
150
 
151
  # -------------------------------
@@ -289,57 +310,52 @@ async def upload_db(file: UploadFile = File(...)):
289
  # Main NL2SQL endpoint
290
  # -------------------------------
291
  @router.post("", name="nl2sql_handler")
292
- def nl2sql_handler(
293
- request: NL2SQLRequest,
294
- run: Runner = Depends(get_runner),
295
- ):
296
  """
297
- Handles NL→SQL conversion requests.
298
- Uses dependency-injected pipeline runner (get_runner).
299
- If db_id provided → builds a temporary per-request pipeline.
300
  """
301
  db_id = getattr(request, "db_id", None)
302
- provided_preview: Optional[str] = cast(
303
- Optional[str], getattr(request, "schema_preview", None)
304
  )
305
 
306
- # Select pipeline (DI default vs per-request)
307
  if db_id:
308
  adapter = _select_adapter(db_id)
309
- pipeline = _build_pipeline(adapter)
310
- derived_preview = _derive_schema_preview(adapter)
311
- runner: Runner = pipeline.run
312
- final_preview = provided_preview or derived_preview or ""
313
  else:
314
- runner = run
315
- final_preview = provided_preview or ""
316
 
317
- # Execute safely
318
  try:
319
  result = runner(user_query=request.query, schema_preview=final_preview)
320
  except Exception as exc:
321
  raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
322
 
 
323
  if not isinstance(result, FinalResult):
324
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
325
 
326
- # Ambiguous → 200
327
  if result.ambiguous and (result.questions is not None):
328
  return ClarifyResponse(ambiguous=True, questions=result.questions)
329
 
330
- # Error → 400 + dump
331
  if (not result.ok) or result.error:
332
  print("❌ Pipeline failure dump:")
333
  print(" ok:", result.ok)
334
  print(" error:", result.error)
335
  print(" details:", result.details)
336
  print(" traces:", result.traces)
337
- raise HTTPException(
338
- status_code=400,
339
- detail="; ".join(result.details or []) or (result.error or "Unknown error"),
340
- )
341
 
342
- # Success → 200
343
  traces = [_round_trace(t) for t in (result.traces or [])]
344
  return NL2SQLResponse(
345
  ambiguous=False,
 
7
  from pathlib import Path
8
  import time
9
  import uuid
10
+ from typing import Any, Dict, Optional, TypedDict, Union, Protocol, cast, List
11
 
12
  # --- Third-party ---
13
+ from fastapi import APIRouter, HTTPException, Request, UploadFile, File
14
 
15
  # --- Local ---
16
  from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
 
25
  from adapters.llm.openai_provider import OpenAIProvider
26
  from adapters.db.sqlite_adapter import SQLiteAdapter
27
  from adapters.db.postgres_adapter import PostgresAdapter
28
+ from nl2sql.pipeline_factory import (
29
+ pipeline_from_config,
30
+ pipeline_from_config_with_adapter,
31
+ )
32
 
33
 
34
  # Stable public re-exports
 
57
  UPLOAD_DIR = Path("data/uploads")
58
  UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
59
 
60
+ CONFIG_PATH = os.getenv("PIPELINE_CONFIG", "configs/sqlite_pipeline.yaml")
61
+ # Build a default pipeline once from config; adapter inside the config will be used.
62
+ _PIPELINE = pipeline_from_config(CONFIG_PATH)
63
+
64
 
65
  class DBEntry(TypedDict):
66
  path: str
 
118
  # -------------------------------
119
  # Adapter selection (lazy)
120
  # -------------------------------
121
# ---------- SELECT ADAPTER ----------
def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
    """
    Resolve a DB adapter from the DB_MODE env var and an optional db_id.

    - postgres mode (DB_MODE=postgres):
        requires POSTGRES_DSN in env
    - sqlite mode (default):
        if db_id provided, resolve a file by trying, in order:
          1) db_id as an absolute path
          2) uploads/{db_id}.sqlite
          3) uploads/{db_id}.db
          4) data/{db_id}.sqlite
          5) data/{db_id}.db
        else fall back to DEFAULT_SQLITE_PATH

    Raises:
        HTTPException(500): missing POSTGRES_DSN or missing default DB.
        HTTPException(400): db_id given but no matching file found.
    """
    # Read the mode from the environment at call time (the previous revision
    # referenced a module-level DB_MODE constant that is not defined here).
    mode = os.getenv("DB_MODE", "sqlite").lower()
    if mode == "postgres":
        dsn = os.environ.get("POSTGRES_DSN")
        if not dsn:
            raise HTTPException(status_code=500, detail="POSTGRES_DSN env is missing")
        return PostgresAdapter(dsn)

    # sqlite mode
    if db_id:
        candidates: List[Path] = []

        # 1) absolute path.
        # SECURITY NOTE(review): accepting an absolute path from a request
        # field lets clients open arbitrary files on the host — consider
        # restricting db_id to known directories if it is untrusted input.
        p = Path(db_id)
        if p.is_absolute():
            candidates.append(p)

        # 2) uploads/
        candidates.append(UPLOAD_DIR / f"{db_id}.sqlite")
        candidates.append(UPLOAD_DIR / f"{db_id}.db")

        # 3) data/
        candidates.append(Path("data") / f"{db_id}.sqlite")
        candidates.append(Path("data") / f"{db_id}.db")

        for c in candidates:
            if c.exists() and c.is_file():
                return SQLiteAdapter(str(c))

        raise HTTPException(status_code=400, detail="invalid db_id (file not found)")

    # default sqlite fallback
    default_path = Path(DEFAULT_SQLITE_PATH)
    if not default_path.exists():
        raise HTTPException(status_code=500, detail="default SQLite DB not found")
    return SQLiteAdapter(str(default_path))
170
 
171
 
172
  # -------------------------------
 
310
  # Main NL2SQL endpoint
311
  # -------------------------------
312
  @router.post("", name="nl2sql_handler")
313
+ def nl2sql_handler(request: NL2SQLRequest):
 
 
 
314
  """
315
+ NL→SQL handler using YAML-driven DI. If 'db_id' is provided, we override only the adapter
316
+ while keeping all other stages from the YAML config intact.
 
317
  """
318
  db_id = getattr(request, "db_id", None)
319
+ provided_preview = (
320
+ cast(Optional[str], getattr(request, "schema_preview", None)) or ""
321
  )
322
 
323
+ # Choose runner: default pipeline from YAML OR per-request override with a specific adapter
324
  if db_id:
325
  adapter = _select_adapter(db_id)
326
+ # Build a temporary pipeline from YAML but bind the per-request adapter
327
+ pipeline = pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
328
+ runner = pipeline.run
329
+ final_preview = provided_preview # keep simple; derive only if you have a SQLite schema helper
330
  else:
331
+ runner = _PIPELINE.run
332
+ final_preview = provided_preview
333
 
334
+ # Execute pipeline
335
  try:
336
  result = runner(user_query=request.query, schema_preview=final_preview)
337
  except Exception as exc:
338
  raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
339
 
340
+ # Type sanity
341
  if not isinstance(result, FinalResult):
342
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
343
 
344
+ # Ambiguity path → 200 with questions
345
  if result.ambiguous and (result.questions is not None):
346
  return ClarifyResponse(ambiguous=True, questions=result.questions)
347
 
348
+ # Error path → 400 with joined details
349
  if (not result.ok) or result.error:
350
  print("❌ Pipeline failure dump:")
351
  print(" ok:", result.ok)
352
  print(" error:", result.error)
353
  print(" details:", result.details)
354
  print(" traces:", result.traces)
355
+ message = "; ".join(result.details or []) or "Unknown error"
356
+ raise HTTPException(status_code=400, detail=message)
 
 
357
 
358
+ # Success path → 200
359
  traces = [_round_trace(t) for t in (result.traces or [])]
360
  return NL2SQLResponse(
361
  ambiguous=False,
configs/sqlite_pipeline.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Pipeline component selection.
# Each key must match an entry in the corresponding nl2sql/registry.py map.
detector: default
planner: default
generator: rules # or "llm" when available
safety: default
executor: default
verifier: basic
repair: default

# DB adapter bound to the executor when the pipeline is built.
adapter:
  kind: sqlite
  dsn: data/chinook.db
nl2sql/pipeline.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
  import traceback
3
  from dataclasses import dataclass
4
  from typing import Dict, Any, Optional, List
 
5
 
6
  from nl2sql.types import StageResult
7
  from nl2sql.ambiguity_detector import AmbiguityDetector
@@ -31,6 +32,7 @@ class Pipeline:
31
  """
32
  NL2SQL Copilot pipeline.
33
  Stages return StageResult; final result is a type-safe FinalResult.
 
34
  """
35
 
36
  def __init__(
@@ -53,19 +55,26 @@ class Pipeline:
53
  self.repair = repair or NoOpRepair()
54
 
55
  # ------------------------------------------------------------
56
- def _trace_list(self, *stages: StageResult) -> List[dict]:
57
- traces = []
 
 
58
  for s in stages:
59
  if not s:
60
  continue
61
  t = getattr(s, "trace", None)
62
- if t:
63
- traces.append(t.__dict__)
 
64
  return traces
65
 
66
  # ------------------------------------------------------------
67
- def _safe_stage(self, fn, **kwargs) -> StageResult:
68
- """Run a stage safely; if it throws, catch and convert to StageResult."""
 
 
 
 
69
  try:
70
  r = fn(**kwargs)
71
  if isinstance(r, StageResult):
@@ -75,6 +84,18 @@ class Pipeline:
75
  tb = traceback.format_exc()
76
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
77
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  # ------------------------------------------------------------
79
  def run(
80
  self,
@@ -88,11 +109,27 @@ class Pipeline:
88
  sql: Optional[str] = None
89
  rationale: Optional[str] = None
90
  verified: Optional[bool] = None
 
 
91
  schema_preview = schema_preview or ""
 
92
 
93
- # --- 1) ambiguity detection ---
94
  try:
 
95
  questions = self.detector.detect(user_query, schema_preview)
 
 
 
 
 
 
 
 
 
 
 
 
96
  if questions:
97
  return FinalResult(
98
  ok=True,
@@ -103,9 +140,11 @@ class Pipeline:
103
  sql=None,
104
  rationale=None,
105
  verified=None,
106
- traces=[],
107
  )
108
  except Exception as e:
 
 
109
  return FinalResult(
110
  ok=False,
111
  ambiguous=True,
@@ -115,7 +154,7 @@ class Pipeline:
115
  sql=None,
116
  rationale=None,
117
  verified=None,
118
- traces=[],
119
  )
120
 
121
  # --- 2) planner ---
@@ -142,7 +181,7 @@ class Pipeline:
142
  user_query=user_query,
143
  schema_preview=schema_preview,
144
  plan_text=(r_plan.data or {}).get("plan"),
145
- clarify_answers=clarify_answers or {},
146
  )
147
  traces.extend(self._trace_list(r_gen))
148
  if not r_gen.ok:
@@ -183,7 +222,9 @@ class Pipeline:
183
  )
184
  traces.extend(self._trace_list(r_exec))
185
  if not r_exec.ok:
186
- details.extend(r_exec.error or [])
 
 
187
 
188
  # --- 6) verifier ---
189
  r_ver = self._safe_stage(
@@ -203,13 +244,17 @@ class Pipeline:
203
  )
204
  traces.extend(self._trace_list(r_fix))
205
  if not r_fix.ok:
 
206
  break
207
 
208
- sql = (r_fix.data or {}).get("sql")
 
 
209
  r_safe = self._safe_stage(self.safety.run, sql=sql)
210
  traces.extend(self._trace_list(r_safe))
211
  if not r_safe.ok:
212
- details.extend(r_safe.error or [])
 
213
  continue
214
 
215
  r_exec = self._safe_stage(
@@ -217,7 +262,8 @@ class Pipeline:
217
  )
218
  traces.extend(self._trace_list(r_exec))
219
  if not r_exec.ok:
220
- details.extend(r_exec.error or [])
 
221
  continue
222
 
223
  r_ver = self._safe_stage(
@@ -230,19 +276,19 @@ class Pipeline:
230
 
231
  # --- 8) fallback: verifier silent but executor succeeded ---
232
  if (verified is None or not verified) and not details:
233
- any_exec = any(
234
- t.get("stage") == "executor" and t.get("notes", {}).get("row_count")
235
  for t in traces
236
  )
237
- if any_exec:
238
  traces.append(
239
- {
240
- "stage": "pipeline",
241
- "notes": {
 
242
  "auto_fix": "verified=True (executor succeeded, verifier silent)"
243
  },
244
- "duration_ms": 0.0,
245
- }
246
  )
247
  verified = True
248
 
@@ -252,11 +298,11 @@ class Pipeline:
252
  err = has_errors and not bool(verified)
253
 
254
  traces.append(
255
- {
256
- "stage": "pipeline",
257
- "notes": {"final_verified": verified, "details_len": len(details)},
258
- "duration_ms": 0.0,
259
- }
260
  )
261
 
262
  return FinalResult(
 
2
  import traceback
3
  from dataclasses import dataclass
4
  from typing import Dict, Any, Optional, List
5
+ import time
6
 
7
  from nl2sql.types import StageResult
8
  from nl2sql.ambiguity_detector import AmbiguityDetector
 
32
  """
33
  NL2SQL Copilot pipeline.
34
  Stages return StageResult; final result is a type-safe FinalResult.
35
+ DI-ready: all dependencies are injected via __init__.
36
  """
37
 
38
  def __init__(
 
55
  self.repair = repair or NoOpRepair()
56
 
57
  # ------------------------------------------------------------
58
+ @staticmethod
59
+ def _trace_list(*stages: Optional[StageResult]) -> List[dict]:
60
+ """Collect .trace objects (as dict) from StageResult items if present."""
61
+ traces: List[dict] = []
62
  for s in stages:
63
  if not s:
64
  continue
65
  t = getattr(s, "trace", None)
66
+ if t is not None:
67
+ # t is likely a dataclass – expose as plain dict for JSON safety
68
+ traces.append(getattr(t, "__dict__", t))
69
  return traces
70
 
71
  # ------------------------------------------------------------
72
+ @staticmethod
73
+ def _safe_stage(fn, **kwargs) -> StageResult:
74
+ """
75
+ Run a stage safely; if it throws, return a StageResult(ok=False, error=[...]).
76
+ If fn returns a non-StageResult (e.g., dict), coerce to StageResult(ok=True, data=...).
77
+ """
78
  try:
79
  r = fn(**kwargs)
80
  if isinstance(r, StageResult):
 
84
  tb = traceback.format_exc()
85
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
86
 
87
+ # ------------------------------------------------------------
88
+ @staticmethod
89
+ def _mk_trace(
90
+ stage: str, duration_ms: float, notes: Optional[Dict[str, Any]] = None
91
+ ) -> dict:
92
+ """Create a normalized trace dict."""
93
+ return {
94
+ "stage": stage,
95
+ "duration_ms": float(duration_ms),
96
+ "notes": notes or {},
97
+ }
98
+
99
  # ------------------------------------------------------------
100
  def run(
101
  self,
 
109
  sql: Optional[str] = None
110
  rationale: Optional[str] = None
111
  verified: Optional[bool] = None
112
+
113
+ # Normalize inputs
114
  schema_preview = schema_preview or ""
115
+ clarify_answers = clarify_answers or {}
116
 
117
+ # --- 1) ambiguity detection (with explicit timing & trace) ---
118
  try:
119
+ t0 = time.perf_counter()
120
  questions = self.detector.detect(user_query, schema_preview)
121
+ t1 = time.perf_counter()
122
+ traces.append(
123
+ self._mk_trace(
124
+ "detector",
125
+ (t1 - t0) * 1000.0,
126
+ {
127
+ "ambiguous": bool(questions),
128
+ "questions_len": len(questions or []),
129
+ },
130
+ )
131
+ )
132
+
133
  if questions:
134
  return FinalResult(
135
  ok=True,
 
140
  sql=None,
141
  rationale=None,
142
  verified=None,
143
+ traces=traces,
144
  )
145
  except Exception as e:
146
+ # detector crash – mark as error but keep trace so far
147
+ traces.append(self._mk_trace("detector", 0.0, {"error": str(e)}))
148
  return FinalResult(
149
  ok=False,
150
  ambiguous=True,
 
154
  sql=None,
155
  rationale=None,
156
  verified=None,
157
+ traces=traces,
158
  )
159
 
160
  # --- 2) planner ---
 
181
  user_query=user_query,
182
  schema_preview=schema_preview,
183
  plan_text=(r_plan.data or {}).get("plan"),
184
+ clarify_answers=clarify_answers,
185
  )
186
  traces.extend(self._trace_list(r_gen))
187
  if not r_gen.ok:
 
222
  )
223
  traces.extend(self._trace_list(r_exec))
224
  if not r_exec.ok:
225
+ # executor failure does not hard-fail the pipeline; accumulate details
226
+ if r_exec.error:
227
+ details.extend(r_exec.error)
228
 
229
  # --- 6) verifier ---
230
  r_ver = self._safe_stage(
 
244
  )
245
  traces.extend(self._trace_list(r_fix))
246
  if not r_fix.ok:
247
+ # repair failed – stop trying further
248
  break
249
 
250
+ # re-run safety executor → verifier on the fixed SQL
251
+ sql = (r_fix.data or {}).get("sql", sql)
252
+
253
  r_safe = self._safe_stage(self.safety.run, sql=sql)
254
  traces.extend(self._trace_list(r_safe))
255
  if not r_safe.ok:
256
+ if r_safe.error:
257
+ details.extend(r_safe.error)
258
  continue
259
 
260
  r_exec = self._safe_stage(
 
262
  )
263
  traces.extend(self._trace_list(r_exec))
264
  if not r_exec.ok:
265
+ if r_exec.error:
266
+ details.extend(r_exec.error)
267
  continue
268
 
269
  r_ver = self._safe_stage(
 
276
 
277
  # --- 8) fallback: verifier silent but executor succeeded ---
278
  if (verified is None or not verified) and not details:
279
+ any_exec_ok = any(
280
+ t.get("stage") == "executor" and (t.get("notes") or {}).get("row_count")
281
  for t in traces
282
  )
283
+ if any_exec_ok:
284
  traces.append(
285
+ self._mk_trace(
286
+ "pipeline",
287
+ 0.0,
288
+ {
289
  "auto_fix": "verified=True (executor succeeded, verifier silent)"
290
  },
291
+ )
 
292
  )
293
  verified = True
294
 
 
298
  err = has_errors and not bool(verified)
299
 
300
  traces.append(
301
+ self._mk_trace(
302
+ "pipeline",
303
+ 0.0,
304
+ {"final_verified": bool(verified), "details_len": len(details)},
305
+ )
306
  )
307
 
308
  return FinalResult(
nl2sql/pipeline_factory.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Build a Pipeline from a YAML config (lightweight dependency injection).

Public API:
    pipeline_from_config(path)                      -> Pipeline
    pipeline_from_config_with_adapter(path, adapter)-> Pipeline
"""

from typing import Any, Dict

import yaml

from nl2sql.pipeline import Pipeline
from nl2sql.registry import (
    DETECTORS,
    PLANNERS,
    GENERATORS,
    SAFETIES,
    EXECUTORS,
    VERIFIERS,
    REPAIRS,
)
from adapters.db.sqlite_adapter import SQLiteAdapter
from adapters.db.postgres_adapter import PostgresAdapter
from adapters.db.base import DBAdapter


def _load_config(path: str) -> Dict[str, Any]:
    """Read and parse the YAML config; an empty file yields an empty dict."""
    with open(path, "r", encoding="utf-8") as fh:
        return yaml.safe_load(fh) or {}


def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
    """
    Instantiate a DB adapter from its config mapping.

    Raises:
        ValueError: unknown adapter kind, or sqlite config without a dsn.
    """
    kind = adapter_cfg.get("kind", "sqlite")
    if kind == "sqlite":
        dsn = adapter_cfg.get("dsn")
        if not dsn:
            raise ValueError("sqlite adapter config requires a 'dsn' path")
        return SQLiteAdapter(dsn)
    if kind == "postgres":
        # Forward everything except 'kind' so PostgresAdapter does not
        # receive an unexpected keyword argument.
        kwargs = {k: v for k, v in adapter_cfg.items() if k != "kind"}
        return PostgresAdapter(**kwargs)
    raise ValueError(f"Unknown adapter kind: {kind}")


def _bind_adapter(executor: Any, adapter: DBAdapter) -> None:
    """Attach the adapter to the executor via bind_adapter() or attribute."""
    if hasattr(executor, "bind_adapter"):
        executor.bind_adapter(adapter)
    elif hasattr(executor, "adapter"):
        executor.adapter = adapter  # fallback


def _assemble(cfg: Dict[str, Any], adapter: DBAdapter) -> Pipeline:
    """Instantiate all stages from registry keys and wire them into a Pipeline."""
    detector = DETECTORS[cfg.get("detector", "default")]()
    planner = PLANNERS[cfg.get("planner", "default")]()
    generator = GENERATORS[cfg.get("generator", "rules")]()
    safety = SAFETIES[cfg.get("safety", "default")]()
    executor = EXECUTORS[cfg.get("executor", "default")]()
    verifier = VERIFIERS[cfg.get("verifier", "basic")]()
    repair = REPAIRS[cfg.get("repair", "default")]()

    _bind_adapter(executor, adapter)

    return Pipeline(
        detector=detector,
        planner=planner,
        generator=generator,
        safety=safety,
        executor=executor,
        verifier=verifier,
        repair=repair,
    )


def pipeline_from_config(path: str) -> Pipeline:
    """Build a Pipeline from YAML; the adapter described in the config is used."""
    cfg = _load_config(path)
    adapter_cfg = cfg.get("adapter", {"kind": "sqlite", "dsn": "data/chinook.db"})
    return _assemble(cfg, _build_adapter(adapter_cfg))


def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipeline:
    """Same as pipeline_from_config, but force a specific adapter (per-request override)."""
    return _assemble(_load_config(path), adapter)
nl2sql/registry.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Registry mapping simple string keys to concrete component classes.
Used by pipeline_factory to perform lightweight dependency injection.
"""

from typing import Dict, Type

from nl2sql.ambiguity_detector import AmbiguityDetector
from nl2sql.planner import Planner
from nl2sql.generator import Generator
from nl2sql.safety import Safety
from nl2sql.executor import Executor
from nl2sql.verifier import Verifier
from nl2sql.repair import Repair

# One map per pipeline stage; the keys are the names accepted in the YAML
# config. Additional variants (e.g. llm-aware generators) can be registered
# here later without touching the factory.
DETECTORS: Dict[str, Type[AmbiguityDetector]] = {"default": AmbiguityDetector}
PLANNERS: Dict[str, Type[Planner]] = {"default": Planner}
GENERATORS: Dict[str, Type[Generator]] = {"rules": Generator}
SAFETIES: Dict[str, Type[Safety]] = {"default": Safety}
EXECUTORS: Dict[str, Type[Executor]] = {"default": Executor}
VERIFIERS: Dict[str, Type[Verifier]] = {"basic": Verifier}
REPAIRS: Dict[str, Type[Repair]] = {"default": Repair}
tests/test_nl2sql_router.py CHANGED
@@ -252,3 +252,12 @@ def test_traces_are_rounded_to_ints():
252
  assert isinstance(traces[0]["duration_ms"], int)
253
  finally:
254
  app.dependency_overrides.pop(nl2sql.get_runner, None)
 
 
 
 
 
 
 
 
 
 
252
  assert isinstance(traces[0]["duration_ms"], int)
253
  finally:
254
  app.dependency_overrides.pop(nl2sql.get_runner, None)
255
+
256
+
257
def test_nl2sql_handler_returns_sql():
    """POST /nl2sql with a plain query yields 200 with SQL or a clarification."""
    payload = {"query": "Top 5 albums by sales"}
    r = client.post("/nl2sql", json=payload)
    assert r.status_code == 200
    data = r.json()
    # The handler may legitimately answer 200 with a ClarifyResponse instead
    # of SQL, so only require sql/traces on the non-ambiguous path.
    if data.get("ambiguous"):
        assert "questions" in data
    else:
        assert "sql" in data
        assert "traces" in data
tests/test_pipeline_factory.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Integration tests for the YAML-driven pipeline factory."""

from nl2sql.pipeline_factory import (
    pipeline_from_config,
    pipeline_from_config_with_adapter,
)
from adapters.db.sqlite_adapter import SQLiteAdapter

# Inline config so the tests do not depend on the repo's config directory
# (the commit adds config/sqlite_pipeline.yaml while the code referenced
# configs/sqlite_pipeline.yaml — an inline fixture sidesteps the mismatch).
CONFIG_YAML = """\
detector: default
planner: default
generator: rules
safety: default
executor: default
verifier: basic
repair: default

adapter:
  kind: sqlite
  dsn: data/chinook.db
"""


def _write_config(tmp_path) -> str:
    """Materialize the config in a temp dir and return its path."""
    cfg = tmp_path / "pipeline.yaml"
    cfg.write_text(CONFIG_YAML, encoding="utf-8")
    return str(cfg)


def test_pipeline_from_config_builds_and_runs(tmp_path):
    p = pipeline_from_config(_write_config(tmp_path))
    result = p.run(user_query="Top 3 albums by sales")
    assert result.sql is not None
    assert isinstance(result.traces, list)


def test_pipeline_from_config_with_adapter_override(tmp_path):
    adapter = SQLiteAdapter("data/chinook.db")
    p = pipeline_from_config_with_adapter(_write_config(tmp_path), adapter=adapter)
    result = p.run(user_query="Count customers")
    # Guard against None before calling .upper() so a failure reads clearly.
    assert result.sql is not None
    assert "SELECT" in result.sql.upper()
    assert isinstance(result.traces, list)


def test_full_pipeline_from_yaml(tmp_path):
    p = pipeline_from_config(_write_config(tmp_path))
    res = p.run(user_query="List all artists")
    assert res.ok
    assert isinstance(res.sql, str)
    # Use .get so a malformed trace entry fails the assertion, not with KeyError.
    assert any(t.get("stage") == "executor" for t in res.traces)