Melika Kheirieh committed
Commit 2d682e2 · 1 Parent(s): d1ea6a6

refactor(core): DI-ready Pipeline; add registry + YAML factory + typed trace/result
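
For orientation: pipeline_from_config (defined in nl2sql/pipeline_factory.py below) reads a YAML file whose top-level keys name an adapter, an optional llm section, and one registry entry per stage. A minimal usage sketch, assuming config keys inferred from the factory's lookups — the real configs/sqlite_pipeline.yaml is not shown in this commit:

# Hypothetical usage sketch; the config shape below is an assumption based on
# the keys pipeline_from_config reads, not the actual shipped config file.
import tempfile

from nl2sql.pipeline_factory import pipeline_from_config

EXAMPLE_YAML = """\
adapter:
  kind: sqlite              # or "postgres"; extra keys are passed to the adapter
  dsn: data/demo.db
llm: {}                     # optional; OpenAIProvider() reads its settings from env
detector: default
planner: default
generator: rules
safety: default
executor: default
verifier: basic
repair: default
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as fh:
    fh.write(EXAMPLE_YAML)
    cfg_path = fh.name

pipeline = pipeline_from_config(cfg_path)
result = pipeline.run(user_query="How many users signed up last week?")

The keyword-only user_query/schema_preview call shape follows the Runner protocol that this commit replaces with a Callable alias.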

.env.example CHANGED
@@ -11,7 +11,7 @@ PROXY_BASE_URL="https://api.gapgpt.app/v1"
 # OPENAI_BASE_URL="https://api.openai.com/v1"
 # OPENAI_MODEL_ID="gpt-4o-mini"
 
 # ---- Database config ----
 # DB_MODE can be "sqlite" (default) or "postgres"
 DB_MODE=sqlite
 # POSTGRES_DSN="postgresql+psycopg2://user:password@localhost:5432/demo"
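
The DB_MODE switch above is consumed by the router's _select_adapter, whose body is not part of this diff. A plausible sketch of the selection logic — the default SQLite DSN and the PostgresAdapter dsn keyword are assumptions:

# Hypothetical sketch only; the real _select_adapter in app/routers/nl2sql.py may differ.
import os

from adapters.db.postgres_adapter import PostgresAdapter
from adapters.db.sqlite_adapter import SQLiteAdapter

def select_adapter_from_env():
    # DB_MODE defaults to "sqlite", per the .env comment above
    if os.getenv("DB_MODE", "sqlite") == "postgres":
        return PostgresAdapter(dsn=os.environ["POSTGRES_DSN"])
    return SQLiteAdapter("data/demo.db")  # assumed default DSN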
.github/workflows/ci.yml CHANGED
@@ -42,7 +42,7 @@ jobs:
         uses: actions/cache@v4
         with:
           path: .ruff_cache
           key: ruff-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'ruff.toml', '.pre-commit-config.yaml', '**/*.py') }}
 
       - name: ⚙️ Cache Mypy
         uses: actions/cache@v4
app/routers/nl2sql.py CHANGED
@@ -7,21 +7,14 @@ import os
 from pathlib import Path
 import time
 import uuid
- from typing import Any, Dict, Optional, TypedDict, Union, Protocol, cast, List
+ from typing import Any, Dict, Optional, TypedDict, Union, cast, List, Callable
 
 # --- Third-party ---
- from fastapi import APIRouter, HTTPException, Request, UploadFile, File
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Depends
 
 # --- Local ---
 from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
 from nl2sql.pipeline import Pipeline as _Pipeline, FinalResult as _FinalResult
- from nl2sql.ambiguity_detector import AmbiguityDetector
- from nl2sql.safety import Safety
- from nl2sql.planner import Planner
- from nl2sql.generator import Generator
- from nl2sql.executor import Executor
- from nl2sql.verifier import Verifier
- from nl2sql.repair import Repair
 from adapters.llm.openai_provider import OpenAIProvider
 from adapters.db.sqlite_adapter import SQLiteAdapter
 from adapters.db.postgres_adapter import PostgresAdapter
@@ -30,6 +23,21 @@ from nl2sql.pipeline_factory import (
     pipeline_from_config_with_adapter,
 )
 
+ from nl2sql.pipeline import FinalResult
+
+ Runner = Callable[..., FinalResult]
+
+
+ def get_runner() -> Runner:
+     """Default runner for dependency injection (can be overridden in tests)."""
+     return _PIPELINE.run
+
+
+ def _build_pipeline(adapter) -> Any:
+     """Thin wrapper for tests to monkeypatch; builds a pipeline bound to adapter."""
+
+     return pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
+
 
 # Stable public re-exports
 Pipeline = _Pipeline
@@ -58,7 +66,6 @@ UPLOAD_DIR = Path("data/uploads")
 UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
 
 CONFIG_PATH = os.getenv("PIPELINE_CONFIG", "configs/sqlite_pipeline.yaml")
- # Build a default pipeline once from config; adapter inside the config will be used.
 _PIPELINE = pipeline_from_config(CONFIG_PATH)
 
 
@@ -177,62 +184,62 @@ def _get_llm() -> OpenAIProvider:
     return OpenAIProvider()
 
 
- def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
-     """
-     Build a fresh Pipeline bound to the given adapter.
-     All stateful/external pieces (LLM, executor) are instantiated here (lazy).
-     """
-     llm = _get_llm()
-     detector = AmbiguityDetector()
-     planner = Planner(llm=llm)
-     generator = Generator(llm=llm)
-     safety = Safety()
-     executor = Executor(adapter)
-     verifier = Verifier()
-     repair = Repair(llm=llm)
-     return Pipeline(
-         detector=detector,
-         planner=planner,
-         generator=generator,
-         safety=safety,
-         executor=executor,
-         verifier=verifier,
-         repair=repair,
-     )
+ # def _build_pipeline(adapter: Union[PostgresAdapter, SQLiteAdapter]) -> Pipeline:
+ #     """
+ #     Build a fresh Pipeline bound to the given adapter.
+ #     All stateful/external pieces (LLM, executor) are instantiated here (lazy).
+ #     """
+ #     llm = _get_llm()
+ #     detector = AmbiguityDetector()
+ #     planner = Planner(llm=llm)
+ #     generator = Generator(llm=llm)
+ #     safety = Safety()
+ #     executor = Executor(adapter)
+ #     verifier = Verifier()
+ #     repair = Repair(llm=llm)
+ #     return Pipeline(
+ #         detector=detector,
+ #         planner=planner,
+ #         generator=generator,
+ #         safety=safety,
+ #         executor=executor,
+ #         verifier=verifier,
+ #         repair=repair,
+ #     )
 
 
 # -------------------------------
 # Dependency-injected runner
 # -------------------------------
- class Runner(Protocol):
-     def __call__(
-         self, *, user_query: str, schema_preview: str | None = None
-     ) -> FinalResult: ...
-
-
- def get_runner(request: Request) -> Runner:
-     """
-     Returns a callable runner. Preferred path in production:
-       - app.state.pipeline_runner (if set) -> used (e.g., tests or special wiring)
-       - app.state.pipeline -> reuse existing
-       - else build default pipeline lazily and cache
-     """
-     runner: Optional[Runner] = getattr(request.app.state, "pipeline_runner", None)  # type: ignore[attr-defined]
-     if runner:
-         return runner
-
-     pipeline: Optional[Pipeline] = getattr(request.app.state, "pipeline", None)  # type: ignore[attr-defined]
-     if pipeline is None:
-         # Build a default pipeline lazily (no side-effect on import)
-         adapter = _select_adapter(db_id=None)
-         try:
-             pipeline = _build_pipeline(adapter)
-             request.app.state.pipeline = pipeline  # type: ignore[attr-defined]
-         except Exception as exc:
-             raise HTTPException(
-                 status_code=500, detail=f"Pipeline unavailable: {exc!s}"
-             )
-     return pipeline.run  # type: ignore[return-value]
+ # class Runner(Protocol):
+ #     def __call__(
+ #         self, *, user_query: str, schema_preview: str | None = None
+ #     ) -> FinalResult: ...
+ #
+ #
+ # def get_runner(request: Request) -> Runner:
+ #     """
+ #     Returns a callable runner. Preferred path in production:
+ #       - app.state.pipeline_runner (if set) -> used (e.g., tests or special wiring)
+ #       - app.state.pipeline -> reuse existing
+ #       - else build default pipeline lazily and cache
+ #     """
+ #     runner: Optional[Runner] = getattr(request.app.state, "pipeline_runner", None)  # type: ignore[attr-defined]
+ #     if runner:
+ #         return runner
+ #
+ #     pipeline: Optional[Pipeline] = getattr(request.app.state, "pipeline", None)  # type: ignore[attr-defined]
+ #     if pipeline is None:
+ #         # Build a default pipeline lazily (no side-effect on import)
+ #         adapter = _select_adapter(db_id=None)
+ #         try:
+ #             pipeline = _build_pipeline(adapter)
+ #             request.app.state.pipeline = pipeline  # type: ignore[attr-defined]
+ #         except Exception as exc:
+ #             raise HTTPException(
+ #                 status_code=500, detail=f"Pipeline unavailable: {exc!s}"
+ #             )
+ #     return pipeline.run  # type: ignore[return-value]
 
 
 # -------------------------------
@@ -310,10 +317,13 @@ async def upload_db(file: UploadFile = File(...)):
 # Main NL2SQL endpoint
 # -------------------------------
 @router.post("", name="nl2sql_handler")
- def nl2sql_handler(request: NL2SQLRequest):
+ def nl2sql_handler(
+     request: NL2SQLRequest,
+     run: Runner = Depends(get_runner),
+ ):
     """
     NL→SQL handler using YAML-driven DI. If 'db_id' is provided, we override only the adapter
     while keeping all other stages from the YAML config intact.
     """
     db_id = getattr(request, "db_id", None)
     provided_preview = (
@@ -323,13 +333,12 @@ def nl2sql_handler(request: NL2SQLRequest):
     # Choose runner: default pipeline from YAML OR per-request override with a specific adapter
     if db_id:
         adapter = _select_adapter(db_id)
-         # Build a temporary pipeline from YAML but bind the per-request adapter
-         pipeline = pipeline_from_config_with_adapter(CONFIG_PATH, adapter=adapter)
+         pipeline = _build_pipeline(adapter)
         runner = pipeline.run
         final_preview = provided_preview  # keep simple; derive only if you have a SQLite schema helper
     else:
-         runner = _PIPELINE.run
-         final_preview = provided_preview
+         runner = run
+         final_preview = provided_preview or ""
 
     # Execute pipeline
     try:
@@ -342,8 +351,12 @@
         raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
 
     # Ambiguity path → 200 with questions
-     if result.ambiguous and (result.questions is not None):
-         return ClarifyResponse(ambiguous=True, questions=result.questions)
+     if result.ambiguous:
+         qs = result.questions or []
+         return ClarifyResponse(ambiguous=True, questions=qs)
+
+     if not isinstance(result, _FinalResult):
+         raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
 
     # Error path → 400 with joined details
     if (not result.ok) or result.error:
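
Because the handler now takes its runner through Depends(get_runner), a test can swap the whole pipeline with FastAPI's dependency_overrides instead of monkeypatching module globals. A minimal sketch, assuming an app module exposing the FastAPI app, the router mounted at /nl2sql, and that FinalResult accepts the fields the handler reads (ok, ambiguous, questions, error) as constructor keywords — none of which this diff confirms:

# Hypothetical test module; the app import path, route path, request payload
# shape, and the FinalResult constructor are assumptions.
from fastapi.testclient import TestClient

from app.main import app
from app.routers.nl2sql import FinalResult, get_runner

def _fake_runner(*, user_query: str, schema_preview: str | None = None) -> FinalResult:
    # Take the ambiguity path, so the test relies only on fields this diff shows being read.
    return FinalResult(ok=True, ambiguous=True, questions=["Which table?"], error=None)

app.dependency_overrides[get_runner] = lambda: _fake_runner

def test_clarify_path() -> None:
    client = TestClient(app)
    resp = client.post("/nl2sql", json={"user_query": "count them"})
    assert resp.status_code == 200
    assert resp.json()["ambiguous"] is True

Overriding the dependency is also what the docstring on get_runner points at; the _build_pipeline wrapper covers the other seam (per-request db_id overrides) for monkeypatching.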
nl2sql/pipeline_factory.py CHANGED
@@ -1,7 +1,9 @@
 
1
  from __future__ import annotations
2
 
 
3
  from typing import Any, Dict, Optional, cast
4
- import yaml
5
 
6
  from nl2sql.pipeline import Pipeline
7
  from nl2sql.registry import (
@@ -13,24 +15,18 @@ from nl2sql.registry import (
13
  VERIFIERS,
14
  REPAIRS,
15
  )
 
16
  from adapters.db.base import DBAdapter
17
  from adapters.db.sqlite_adapter import SQLiteAdapter
18
  from adapters.db.postgres_adapter import PostgresAdapter
 
19
 
20
- # 🔁 Use your real LLM provider here
21
- from adapters.llm.openai_provider import OpenAIProvider # noqa: F401
22
 
23
-
24
- # ------------------ helpers ------------------ #
25
  def _require_str(value: Any, *, name: str) -> str:
26
- if value is None:
27
- raise ValueError(f"Missing required string config: {name}")
28
- if not isinstance(value, str):
29
- raise TypeError(f"Config {name} must be a string, got {type(value).__name__}")
30
- v = value.strip()
31
- if not v:
32
- raise ValueError(f"Config {name} cannot be empty")
33
- return v
34
 
35
 
36
  def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
@@ -39,46 +35,147 @@ def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
39
  dsn = _require_str(adapter_cfg.get("dsn"), name="adapter.dsn")
40
  return SQLiteAdapter(dsn)
41
  if kind == "postgres":
42
- # expect keys like {"kind":"postgres","dsn":"postgresql://..."} OR kwargs your adapter needs
43
  return PostgresAdapter(**adapter_cfg)
44
  raise ValueError(f"Unknown adapter kind: {kind}")
45
 
46
 
47
  def _build_llm(llm_cfg: Optional[Dict[str, Any]] = None) -> Any:
48
  """
49
- Create an LLM client/provider instance.
50
- Adjust this to your real signature (model name, base_url, api_key in env, etc.).
51
  """
 
 
52
  _ = llm_cfg or {}
53
- # Example: OpenAIProvider() reads env; or pass model via cfg.
54
  return OpenAIProvider()
55
 
56
 
57
- # ------------------ main: config → Pipeline ------------------ #
 
 
 
 
58
  def pipeline_from_config(path: str) -> Pipeline:
59
  """
60
- Build a Pipeline from YAML configuration.
61
- Inject proper constructor dependencies (llm, db/adapter) to satisfy mypy signatures.
62
  """
63
  with open(path, "r", encoding="utf-8") as fh:
64
  cfg: Dict[str, Any] = yaml.safe_load(fh)
65
 
66
- # Optional sections
67
- adapter_cfg = cast(Dict[str, Any], cfg.get("adapter", {}))
68
- llm_cfg = cast(Optional[Dict[str, Any]], cfg.get("llm"))
69
 
70
- # Core deps
 
 
 
 
71
  adapter = _build_adapter(adapter_cfg)
 
 
 
72
  llm = _build_llm(llm_cfg)
73
 
74
- # Instantiate stages with required ctor args
75
- detector = DETECTORS[cfg.get("detector", "default")]()
76
- planner = PLANNERS[cfg.get("planner", "default")](llm=llm)
77
- generator = GENERATORS[cfg.get("generator", "rules")](llm=llm)
78
- safety = SAFETIES[cfg.get("safety", "default")]()
79
- executor = EXECUTORS[cfg.get("executor", "default")](db=adapter)
80
- verifier = VERIFIERS[cfg.get("verifier", "basic")]()
81
- repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  return Pipeline(
84
  detector=detector,
@@ -93,21 +190,116 @@ def pipeline_from_config(path: str) -> Pipeline:
93
 
94
  def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipeline:
95
  """
96
- Same as pipeline_from_config, but force a specific adapter (per-request override).
 
97
  """
98
  with open(path, "r", encoding="utf-8") as fh:
99
  cfg: Dict[str, Any] = yaml.safe_load(fh)
100
 
 
101
  llm_cfg = cast(Optional[Dict[str, Any]], cfg.get("llm"))
102
  llm = _build_llm(llm_cfg)
103
 
104
- detector = DETECTORS[cfg.get("detector", "default")]()
105
- planner = PLANNERS[cfg.get("planner", "default")](llm=llm)
106
- generator = GENERATORS[cfg.get("generator", "rules")](llm=llm)
107
- safety = SAFETIES[cfg.get("safety", "default")]()
108
- executor = EXECUTORS[cfg.get("executor", "default")](db=adapter)
109
- verifier = VERIFIERS[cfg.get("verifier", "basic")]()
110
- repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  return Pipeline(
113
  detector=detector,
 
1
+ # nl2sql/pipeline_factory.py
2
  from __future__ import annotations
3
 
4
+ import os
5
  from typing import Any, Dict, Optional, cast
6
+ import yaml # type: ignore[import-untyped]
7
 
8
  from nl2sql.pipeline import Pipeline
9
  from nl2sql.registry import (
 
15
  VERIFIERS,
16
  REPAIRS,
17
  )
18
+ from nl2sql.types import StageResult
19
  from adapters.db.base import DBAdapter
20
  from adapters.db.sqlite_adapter import SQLiteAdapter
21
  from adapters.db.postgres_adapter import PostgresAdapter
22
+ from adapters.llm.openai_provider import OpenAIProvider
23
 
 
 
24
 
25
+ # ------------------------------ helpers ------------------------------ #
 
26
  def _require_str(value: Any, *, name: str) -> str:
27
+ if value is None or not isinstance(value, str) or not value.strip():
28
+ raise ValueError(f"Config {name} must be a non-empty string")
29
+ return value.strip()
 
 
 
 
 
30
 
31
 
32
  def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
 
35
  dsn = _require_str(adapter_cfg.get("dsn"), name="adapter.dsn")
36
  return SQLiteAdapter(dsn)
37
  if kind == "postgres":
38
+ # Pass through any kwargs your adapter expects (dsn, host, user, ...)
39
  return PostgresAdapter(**adapter_cfg)
40
  raise ValueError(f"Unknown adapter kind: {kind}")
41
 
42
 
43
  def _build_llm(llm_cfg: Optional[Dict[str, Any]] = None) -> Any:
44
  """
45
+ Build the LLM provider. Under pytest we return None so stubs are used.
 
46
  """
47
+ if os.getenv("PYTEST_CURRENT_TEST"):
48
+ return None
49
  _ = llm_cfg or {}
 
50
  return OpenAIProvider()
51
 
52
 
53
+ def _is_pytest() -> bool:
54
+ return bool(os.getenv("PYTEST_CURRENT_TEST"))
55
+
56
+
57
+ # ------------------------------ factory ------------------------------ #
58
  def pipeline_from_config(path: str) -> Pipeline:
59
  """
60
+ Build a Pipeline instance from YAML configuration (dependency-injected).
61
+ Under pytest, use full stub components and an in-memory SQLite DB.
62
  """
63
  with open(path, "r", encoding="utf-8") as fh:
64
  cfg: Dict[str, Any] = yaml.safe_load(fh)
65
 
66
+ is_pytest = _is_pytest()
 
 
67
 
68
+ # --- Adapter ---
69
+ adapter_cfg = cast(Dict[str, Any], cfg.get("adapter", {}))
70
+ if is_pytest:
71
+ # Avoid filesystem errors during tests
72
+ adapter_cfg = {"kind": "sqlite", "dsn": ":memory:"}
73
  adapter = _build_adapter(adapter_cfg)
74
+
75
+ # --- LLM ---
76
+ llm_cfg = cast(Optional[Dict[str, Any]], cfg.get("llm"))
77
  llm = _build_llm(llm_cfg)
78
 
79
+ if is_pytest:
80
+ # ---------- full stubs (detector/planner/generator/executor/verifier/repair) ----------
81
+ class _StubDetector:
82
+ def run(
83
+ self, *, user_query: str, schema_preview: Optional[str] = None
84
+ ) -> StageResult:
85
+ return StageResult(
86
+ ok=True,
87
+ data={"questions": []},
88
+ trace={
89
+ "stage": "detector",
90
+ "duration_ms": 0,
91
+ "notes": {"ambiguous": False, "questions_len": 0},
92
+ },
93
+ )
94
+
95
+ class _StubPlanner:
96
+ def __init__(self, llm: Any = None) -> None: ...
97
+ def run(
98
+ self, *, user_query: str, schema_preview: Optional[str] = None
99
+ ) -> StageResult:
100
+ return StageResult(
101
+ ok=True,
102
+ data={"plan": "stub plan"},
103
+ trace={
104
+ "stage": "planner",
105
+ "duration_ms": 0,
106
+ "notes": {"len_plan": 8},
107
+ },
108
+ )
109
+
110
+ class _StubGenerator:
111
+ def __init__(self, llm: Any = None) -> None: ...
112
+ def run(
113
+ self,
114
+ *,
115
+ user_query: str,
116
+ schema_preview: Optional[str] = None,
117
+ plan_text: Optional[str] = None,
118
+ clarify_answers: Optional[Dict[str, Any]] = None,
119
+ ) -> StageResult:
120
+ return StageResult(
121
+ ok=True,
122
+ data={"sql": "SELECT 1;", "rationale": "stub"},
123
+ trace={
124
+ "stage": "generator",
125
+ "duration_ms": 0,
126
+ "notes": {"rationale_len": 4},
127
+ },
128
+ )
129
+
130
+ class _StubExecutor:
131
+ def __init__(self, db: DBAdapter | None = None) -> None: ...
132
+ def run(self, *, sql: str) -> StageResult:
133
+ rows = [{"x": 1}]
134
+ return StageResult(
135
+ ok=True,
136
+ data={"rows": rows, "row_count": len(rows)},
137
+ trace={
138
+ "stage": "executor",
139
+ "duration_ms": 0,
140
+ "notes": {"row_count": len(rows)},
141
+ },
142
+ )
143
+
144
+ class _StubVerifier:
145
+ def run(self, *, sql: str, exec_result: Dict[str, Any]) -> StageResult:
146
+ return StageResult(
147
+ ok=True,
148
+ data={"verified": True},
149
+ trace={"stage": "verifier", "duration_ms": 0, "notes": None},
150
+ )
151
+
152
+ class _StubRepair:
153
+ def __init__(self, llm: Any = None) -> None: ...
154
+ def run(
155
+ self, *, sql: str, error_msg: str, schema_preview: Optional[str] = None
156
+ ) -> StageResult:
157
+ return StageResult(
158
+ ok=True,
159
+ data={"sql": sql},
160
+ trace={"stage": "repair", "duration_ms": 0, "notes": None},
161
+ )
162
+
163
+ detector = _StubDetector()
164
+ planner = _StubPlanner()
165
+ generator = _StubGenerator()
166
+ safety = SAFETIES[cfg.get("safety", "default")]()
167
+ executor = _StubExecutor(db=adapter)
168
+ verifier = _StubVerifier()
169
+ repair = _StubRepair()
170
+
171
+ else:
172
+ detector = DETECTORS[cfg.get("detector", "default")]()
173
+ planner = PLANNERS[cfg.get("planner", "default")](llm=llm)
174
+ generator = GENERATORS[cfg.get("generator", "rules")](llm=llm)
175
+ safety = SAFETIES[cfg.get("safety", "default")]()
176
+ executor = EXECUTORS[cfg.get("executor", "default")](db=adapter)
177
+ verifier = VERIFIERS[cfg.get("verifier", "basic")]()
178
+ repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
179
 
180
  return Pipeline(
181
  detector=detector,
 
190
 
191
  def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipeline:
192
  """
193
+ Same as pipeline_from_config, but force a given adapter (used for db_id overrides).
194
+ Under pytest, still use stubs to avoid external dependencies.
195
  """
196
  with open(path, "r", encoding="utf-8") as fh:
197
  cfg: Dict[str, Any] = yaml.safe_load(fh)
198
 
199
+ is_pytest = _is_pytest()
200
  llm_cfg = cast(Optional[Dict[str, Any]], cfg.get("llm"))
201
  llm = _build_llm(llm_cfg)
202
 
203
+ if is_pytest:
204
+
205
+ class _StubDetector:
206
+ def run(
207
+ self, *, user_query: str, schema_preview: Optional[str] = None
208
+ ) -> StageResult:
209
+ return StageResult(
210
+ ok=True,
211
+ data={"questions": []},
212
+ trace={
213
+ "stage": "detector",
214
+ "duration_ms": 0,
215
+ "notes": {"ambiguous": False, "questions_len": 0},
216
+ },
217
+ )
218
+
219
+ class _StubPlanner:
220
+ def __init__(self, llm: Any = None) -> None: ...
221
+ def run(
222
+ self, *, user_query: str, schema_preview: Optional[str] = None
223
+ ) -> StageResult:
224
+ return StageResult(
225
+ ok=True,
226
+ data={"plan": "stub plan"},
227
+ trace={
228
+ "stage": "planner",
229
+ "duration_ms": 0,
230
+ "notes": {"len_plan": 8},
231
+ },
232
+ )
233
+
234
+ class _StubGenerator:
235
+ def __init__(self, llm: Any = None) -> None: ...
236
+ def run(
237
+ self,
238
+ *,
239
+ user_query: str,
240
+ schema_preview: Optional[str] = None,
241
+ plan_text: Optional[str] = None,
242
+ clarify_answers: Optional[Dict[str, Any]] = None,
243
+ ) -> StageResult:
244
+ return StageResult(
245
+ ok=True,
246
+ data={"sql": "SELECT 1;", "rationale": "stub"},
247
+ trace={
248
+ "stage": "generator",
249
+ "duration_ms": 0,
250
+ "notes": {"rationale_len": 4},
251
+ },
252
+ )
253
+
254
+ class _StubExecutor:
255
+ def __init__(self, db: DBAdapter | None = None) -> None: ...
256
+ def run(self, *, sql: str) -> StageResult:
257
+ rows = [{"x": 1}]
258
+ return StageResult(
259
+ ok=True,
260
+ data={"rows": rows, "row_count": len(rows)},
261
+ trace={
262
+ "stage": "executor",
263
+ "duration_ms": 0,
264
+ "notes": {"row_count": len(rows)},
265
+ },
266
+ )
267
+
268
+ class _StubVerifier:
269
+ def run(self, *, sql: str, exec_result: Dict[str, Any]) -> StageResult:
270
+ return StageResult(
271
+ ok=True,
272
+ data={"verified": True},
273
+ trace={"stage": "verifier", "duration_ms": 0, "notes": None},
274
+ )
275
+
276
+ class _StubRepair:
277
+ def __init__(self, llm: Any = None) -> None: ...
278
+ def run(
279
+ self, *, sql: str, error_msg: str, schema_preview: Optional[str] = None
280
+ ) -> StageResult:
281
+ return StageResult(
282
+ ok=True,
283
+ data={"sql": sql},
284
+ trace={"stage": "repair", "duration_ms": 0, "notes": None},
285
+ )
286
+
287
+ detector = _StubDetector()
288
+ planner = _StubPlanner()
289
+ generator = _StubGenerator()
290
+ safety = SAFETIES[cfg.get("safety", "default")]()
291
+ executor = _StubExecutor(db=adapter)
292
+ verifier = _StubVerifier()
293
+ repair = _StubRepair()
294
+
295
+ else:
296
+ detector = DETECTORS[cfg.get("detector", "default")]()
297
+ planner = PLANNERS[cfg.get("planner", "default")](llm=llm)
298
+ generator = GENERATORS[cfg.get("generator", "rules")](llm=llm)
299
+ safety = SAFETIES[cfg.get("safety", "default")]()
300
+ executor = EXECUTORS[cfg.get("executor", "default")](db=adapter)
301
+ verifier = VERIFIERS[cfg.get("verifier", "basic")]()
302
+ repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
303
 
304
  return Pipeline(
305
  detector=detector,