nl2sql-copilot / nl2sql /pipeline_factory.py
Melika Kheirieh
refactor(core): DI-ready Pipeline; add registry + YAML factory + typed trace/result
2d682e2
raw
history blame
11.1 kB
# nl2sql/pipeline_factory.py
from __future__ import annotations
import os
from typing import Any, Dict, Optional, cast
import yaml # type: ignore[import-untyped]
from nl2sql.pipeline import Pipeline
from nl2sql.registry import (
DETECTORS,
PLANNERS,
GENERATORS,
SAFETIES,
EXECUTORS,
VERIFIERS,
REPAIRS,
)
from nl2sql.types import StageResult
from adapters.db.base import DBAdapter
from adapters.db.sqlite_adapter import SQLiteAdapter
from adapters.db.postgres_adapter import PostgresAdapter
from adapters.llm.openai_provider import OpenAIProvider
# ------------------------------ helpers ------------------------------ #
def _require_str(value: Any, *, name: str) -> str:
    """
    Validate a config value as a non-blank string and return it stripped.

    Args:
        value: The raw value read from the configuration.
        name: Dotted config key used in the error message (e.g. "adapter.dsn").

    Returns:
        ``value`` with leading/trailing whitespace removed.

    Raises:
        ValueError: If ``value`` is not a ``str`` or is empty after stripping.
    """
    # Anything that is not a string (including None) collapses to "" so a
    # single emptiness check covers every invalid case.
    cleaned = value.strip() if isinstance(value, str) else ""
    if cleaned:
        return cleaned
    raise ValueError(f"Config {name} must be a non-empty string")
def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
    """
    Build the database adapter described by the ``adapter`` config section.

    Args:
        adapter_cfg: Mapping with an optional ``kind`` ("sqlite" or
            "postgres", case-insensitive; missing or empty means "sqlite").
            For sqlite a non-blank ``dsn`` is required. For postgres every
            key except ``kind`` is forwarded to ``PostgresAdapter`` as a
            keyword argument.

    Returns:
        A ``SQLiteAdapter`` or ``PostgresAdapter`` instance.

    Raises:
        ValueError: If ``kind`` is not a string, names an unknown adapter,
            or the sqlite ``dsn`` is missing/blank.
    """
    raw_kind = adapter_cfg.get("kind") or "sqlite"
    if not isinstance(raw_kind, str):
        # Fail with a config error instead of AttributeError on .lower().
        raise ValueError(
            f"Config adapter.kind must be a string, got {type(raw_kind).__name__}"
        )
    kind = raw_kind.strip().lower()

    if kind == "sqlite":
        dsn = _require_str(adapter_cfg.get("dsn"), name="adapter.dsn")
        return SQLiteAdapter(dsn)

    if kind == "postgres":
        # "kind" only selects the adapter class in this factory; it is not a
        # connection option, so it is not forwarded with the adapter kwargs
        # (dsn, host, user, ...).
        # NOTE(review): assumes PostgresAdapter.__init__ has no `kind`
        # parameter -- confirm against its signature.
        connect_kwargs = {
            key: val for key, val in adapter_cfg.items() if key != "kind"
        }
        return PostgresAdapter(**connect_kwargs)

    raise ValueError(f"Unknown adapter kind: {kind}")
def _build_llm(llm_cfg: Optional[Dict[str, Any]] = None) -> Any:
    """
    Build the LLM provider used by the LLM-backed pipeline stages.

    Args:
        llm_cfg: The ``llm`` config section. It is accepted so callers can
            pass it through, but it is not consumed: the provider is always
            constructed with no arguments.

    Returns:
        ``None`` when the ``PYTEST_CURRENT_TEST`` environment variable is set
        to a non-empty value (so stubs are used), otherwise a new
        ``OpenAIProvider``.
    """
    running_under_pytest = bool(os.environ.get("PYTEST_CURRENT_TEST"))
    # llm_cfg is intentionally ignored on both paths (see docstring).
    return None if running_under_pytest else OpenAIProvider()
def _is_pytest() -> bool:
    """
    Report whether the process appears to be running a pytest test.

    Returns:
        True when the ``PYTEST_CURRENT_TEST`` environment variable is set to
        a non-empty value, False when it is unset or empty.
    """
    marker = os.environ.get("PYTEST_CURRENT_TEST")
    return marker is not None and marker != ""
# ------------------------------ factory ------------------------------ #
# Stage doubles used when running under pytest. Each returns a fixed,
# successful StageResult so tests never touch an LLM or a real database.
class _StubDetector:
    """Detector double: reports no clarification questions."""

    def run(
        self, *, user_query: str, schema_preview: Optional[str] = None
    ) -> StageResult:
        return StageResult(
            ok=True,
            data={"questions": []},
            trace={
                "stage": "detector",
                "duration_ms": 0,
                "notes": {"ambiguous": False, "questions_len": 0},
            },
        )


class _StubPlanner:
    """Planner double: always returns the fixed plan text "stub plan"."""

    def __init__(self, llm: Any = None) -> None: ...

    def run(
        self, *, user_query: str, schema_preview: Optional[str] = None
    ) -> StageResult:
        plan = "stub plan"
        return StageResult(
            ok=True,
            data={"plan": plan},
            trace={
                "stage": "planner",
                "duration_ms": 0,
                # Derived from the text rather than hand-counted; the former
                # literal (8) did not match the 9-character plan.
                "notes": {"len_plan": len(plan)},
            },
        )


class _StubGenerator:
    """Generator double: always emits ``SELECT 1;``."""

    def __init__(self, llm: Any = None) -> None: ...

    def run(
        self,
        *,
        user_query: str,
        schema_preview: Optional[str] = None,
        plan_text: Optional[str] = None,
        clarify_answers: Optional[Dict[str, Any]] = None,
    ) -> StageResult:
        rationale = "stub"
        return StageResult(
            ok=True,
            data={"sql": "SELECT 1;", "rationale": rationale},
            trace={
                "stage": "generator",
                "duration_ms": 0,
                "notes": {"rationale_len": len(rationale)},
            },
        )


class _StubExecutor:
    """Executor double: returns one fixed row without using the adapter."""

    def __init__(self, db: DBAdapter | None = None) -> None: ...

    def run(self, *, sql: str) -> StageResult:
        rows = [{"x": 1}]
        return StageResult(
            ok=True,
            data={"rows": rows, "row_count": len(rows)},
            trace={
                "stage": "executor",
                "duration_ms": 0,
                "notes": {"row_count": len(rows)},
            },
        )


class _StubVerifier:
    """Verifier double: marks every result as verified."""

    def run(self, *, sql: str, exec_result: Dict[str, Any]) -> StageResult:
        return StageResult(
            ok=True,
            data={"verified": True},
            trace={"stage": "verifier", "duration_ms": 0, "notes": None},
        )


class _StubRepair:
    """Repair double: returns the SQL it was given, unchanged."""

    def __init__(self, llm: Any = None) -> None: ...

    def run(
        self, *, sql: str, error_msg: str, schema_preview: Optional[str] = None
    ) -> StageResult:
        return StageResult(
            ok=True,
            data={"sql": sql},
            trace={"stage": "repair", "duration_ms": 0, "notes": None},
        )


def _load_config(path: str) -> Dict[str, Any]:
    """
    Read the YAML file at ``path`` and return its top-level mapping.

    An empty document (which ``yaml.safe_load`` parses to ``None``) is treated
    as an empty config so every stage falls back to its default.

    Raises:
        ValueError: If the document's top level is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as fh:
        loaded = yaml.safe_load(fh)
    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f"Config file {path} must contain a mapping at the top level, "
            f"got {type(loaded).__name__}"
        )
    return cast(Dict[str, Any], loaded)


def _assemble_pipeline(
    cfg: Dict[str, Any], adapter: DBAdapter, *, use_stubs: bool
) -> Pipeline:
    """
    Wire the pipeline stages for ``cfg`` around an already-built adapter.

    With ``use_stubs`` every stage except safety is a fixed-output double;
    safety is always resolved from the ``SAFETIES`` registry. Otherwise each
    stage is looked up in its registry by the name given in ``cfg`` (with the
    defaults shown below) and constructed with the LLM or adapter it needs.
    """
    llm = _build_llm(cast(Optional[Dict[str, Any]], cfg.get("llm")))

    # Stages are constructed in pipeline order in both branches.
    if use_stubs:
        detector = _StubDetector()
        planner = _StubPlanner()
        generator = _StubGenerator()
        safety = SAFETIES[cfg.get("safety", "default")]()
        executor = _StubExecutor(db=adapter)
        verifier = _StubVerifier()
        repair = _StubRepair()
    else:
        detector = DETECTORS[cfg.get("detector", "default")]()
        planner = PLANNERS[cfg.get("planner", "default")](llm=llm)
        generator = GENERATORS[cfg.get("generator", "rules")](llm=llm)
        safety = SAFETIES[cfg.get("safety", "default")]()
        executor = EXECUTORS[cfg.get("executor", "default")](db=adapter)
        verifier = VERIFIERS[cfg.get("verifier", "basic")]()
        repair = REPAIRS[cfg.get("repair", "default")](llm=llm)

    return Pipeline(
        detector=detector,
        planner=planner,
        generator=generator,
        safety=safety,
        executor=executor,
        verifier=verifier,
        repair=repair,
    )


def pipeline_from_config(path: str) -> Pipeline:
    """
    Build a Pipeline instance from YAML configuration (dependency-injected).

    Under pytest, use full stub components and an in-memory SQLite DB.

    Args:
        path: Path to the YAML config file.

    Raises:
        ValueError: If the config's top level or its ``adapter`` section is
            not a mapping (plus whatever ``_build_adapter`` raises).
    """
    cfg = _load_config(path)
    under_pytest = _is_pytest()

    if under_pytest:
        # Avoid filesystem errors during tests
        adapter_cfg: Dict[str, Any] = {"kind": "sqlite", "dsn": ":memory:"}
    else:
        # `adapter:` with no value parses to None; treat it like a missing key.
        section = cfg.get("adapter") or {}
        if not isinstance(section, dict):
            raise ValueError(
                f"Config adapter must be a mapping, got {type(section).__name__}"
            )
        adapter_cfg = cast(Dict[str, Any], section)

    adapter = _build_adapter(adapter_cfg)
    return _assemble_pipeline(cfg, adapter, use_stubs=under_pytest)


def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipeline:
    """
    Same as pipeline_from_config, but force a given adapter (used for db_id overrides).

    Under pytest, still use stubs to avoid external dependencies.

    Args:
        path: Path to the YAML config file; its ``adapter`` section is ignored.
        adapter: The adapter instance to hand to the executor stage.

    Raises:
        ValueError: If the config's top level is not a mapping.
    """
    cfg = _load_config(path)
    return _assemble_pipeline(cfg, adapter, use_stubs=_is_pytest())