nl2sql-copilot / nl2sql /pipeline_factory.py
Melika Kheirieh
chore(factory): safely load .env via dotenv (with fallback under CI)
f8224ec
raw
history blame
11.2 kB
from __future__ import annotations
import os
from typing import Any, Dict, Optional, cast
import yaml # type: ignore[import-untyped]
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
pass
from nl2sql.pipeline import Pipeline
from nl2sql.registry import (
DETECTORS,
PLANNERS,
GENERATORS,
SAFETIES,
EXECUTORS,
VERIFIERS,
REPAIRS,
)
from nl2sql.types import StageResult, StageTrace
from nl2sql.ambiguity_detector import AmbiguityDetector
from nl2sql.planner import Planner
from nl2sql.generator import Generator
from nl2sql.executor import Executor
from nl2sql.verifier import Verifier
from nl2sql.repair import Repair
from adapters.db.base import DBAdapter
from adapters.db.sqlite_adapter import SQLiteAdapter
from adapters.db.postgres_adapter import PostgresAdapter
from adapters.llm.openai_provider import OpenAIProvider
# ------------------------------ helpers ------------------------------ #
def _require_str(value: Any, *, name: str) -> str:
if value is None or not isinstance(value, str) or not value.strip():
raise ValueError(f"Config {name} must be a non-empty string")
return value.strip()
def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
kind = (adapter_cfg.get("kind") or "sqlite").lower()
if kind == "sqlite":
dsn = _require_str(adapter_cfg.get("dsn"), name="adapter.dsn")
return SQLiteAdapter(dsn)
if kind == "postgres":
return PostgresAdapter(**adapter_cfg)
raise ValueError(f"Unknown adapter kind: {kind}")
def _build_llm(llm_cfg: Optional[Dict[str, Any]] = None) -> Any:
"""Under pytest return None (stubs handle logic); otherwise real OpenAI provider."""
if os.getenv("PYTEST_CURRENT_TEST"):
return None
_ = llm_cfg or {}
return OpenAIProvider()
def _is_pytest() -> bool:
return bool(os.getenv("PYTEST_CURRENT_TEST"))
def _tr(
stage: str,
*,
duration_ms: int = 0,
notes: Optional[Dict[str, Any]] = None,
token_in: Optional[int] = None,
token_out: Optional[int] = None,
cost_usd: Optional[float] = None,
) -> StageTrace:
return StageTrace(
stage=stage,
duration_ms=duration_ms,
notes=notes,
token_in=token_in,
token_out=token_out,
cost_usd=cost_usd,
)
# ------------------------------ factory ------------------------------ #
def pipeline_from_config(path: str) -> Pipeline:
"""
Build a Pipeline instance from YAML configuration (dependency-injected).
Under pytest, use full stub components and an in-memory SQLite DB.
"""
with open(path, "r", encoding="utf-8") as fh:
cfg: Dict[str, Any] = yaml.safe_load(fh)
is_pytest = _is_pytest()
# --- Adapter ---
adapter_cfg = cast(Dict[str, Any], cfg.get("adapter", {}))
if is_pytest:
adapter_cfg = {"kind": "sqlite", "dsn": ":memory:"}
adapter = _build_adapter(adapter_cfg)
# --- LLM ---
llm_cfg = cast(Optional[Dict[str, Any]], cfg.get("llm"))
llm = _build_llm(llm_cfg)
if is_pytest:
# ---------- stubs: domain-shaped + StageResult on run() ----------
class _StubDetector:
def detect(self, *args, **kwargs) -> list[str]:
return []
def run(self, *args, **kwargs) -> StageResult:
return StageResult(
ok=True,
data={"questions": []},
trace=_tr(
"detector", notes={"ambiguous": False, "questions_len": 0}
),
)
class _StubPlanner:
def __init__(self, llm: Any = None) -> None: ...
def plan(self, *args, **kwargs) -> str:
return "stub plan"
def run(self, *args, **kwargs) -> StageResult:
plan = self.plan(*args, **kwargs)
return StageResult(
ok=True,
data={"plan": plan},
trace=_tr("planner", notes={"len_plan": len(plan)}),
)
class _StubGenerator:
def __init__(self, llm: Any = None) -> None: ...
def generate(self, *args, **kwargs) -> tuple[str, str]:
return "SELECT 1;", "stub"
def run(self, *args, **kwargs) -> StageResult:
sql, rationale = self.generate(*args, **kwargs)
return StageResult(
ok=True,
data={"sql": sql, "rationale": rationale},
trace=_tr("generator", notes={"rationale_len": len(rationale)}),
)
class _StubExecutor:
def __init__(self, db: Any | None = None) -> None: ...
def execute(self, *args, **kwargs) -> Dict[str, Any]:
rows = [{"x": 1}]
return {"rows": rows, "row_count": len(rows)}
def run(self, *args, **kwargs) -> StageResult:
out = self.execute(*args, **kwargs)
return StageResult(
ok=True,
data=out,
trace=_tr("executor", notes={"row_count": out["row_count"]}),
)
class _StubVerifier:
def verify(self, *args, **kwargs) -> bool:
return True
def run(self, *args, **kwargs) -> StageResult:
return StageResult(
ok=True, data={"verified": True}, trace=_tr("verifier")
)
class _StubRepair:
def __init__(self, llm: Any = None) -> None: ...
def repair(self, *args, **kwargs) -> str:
return kwargs.get("sql") or "SELECT 1;"
def run(self, *args, **kwargs) -> StageResult:
sql = self.repair(*args, **kwargs)
return StageResult(ok=True, data={"sql": sql}, trace=_tr("repair"))
detector = cast(AmbiguityDetector, _StubDetector())
planner = cast(Planner, _StubPlanner())
generator = cast(Generator, _StubGenerator())
safety = SAFETIES[cfg.get("safety", "default")]()
executor = cast(Executor, _StubExecutor(db=adapter))
verifier = cast(Verifier, _StubVerifier())
repair = cast(Repair, _StubRepair())
else:
detector = DETECTORS[cfg.get("detector", "default")]()
planner = PLANNERS[cfg.get("planner", "default")](llm=llm)
generator = GENERATORS[cfg.get("generator", "rules")](llm=llm)
safety = SAFETIES[cfg.get("safety", "default")]()
executor = EXECUTORS[cfg.get("executor", "default")](db=adapter)
verifier = VERIFIERS[cfg.get("verifier", "basic")]()
repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
return Pipeline(
detector=detector,
planner=planner,
generator=generator,
safety=safety,
executor=executor,
verifier=verifier,
repair=repair,
)
def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipeline:
"""
Same as pipeline_from_config, but force a given adapter (used for db_id overrides).
Under pytest, still use stubs to avoid external dependencies.
"""
with open(path, "r", encoding="utf-8") as fh:
cfg: Dict[str, Any] = yaml.safe_load(fh)
is_pytest = _is_pytest()
llm_cfg = cast(Optional[Dict[str, Any]], cfg.get("llm"))
llm = _build_llm(llm_cfg)
if is_pytest:
class _StubDetector:
def detect(self, *args, **kwargs) -> list[str]:
return []
def run(self, *args, **kwargs) -> StageResult:
return StageResult(
ok=True,
data={"questions": []},
trace=_tr(
"detector", notes={"ambiguous": False, "questions_len": 0}
),
)
class _StubPlanner:
def __init__(self, llm: Any = None) -> None: ...
def plan(self, *args, **kwargs) -> str:
return "stub plan"
def run(self, *args, **kwargs) -> StageResult:
plan = self.plan(*args, **kwargs)
return StageResult(
ok=True,
data={"plan": plan},
trace=_tr("planner", notes={"len_plan": len(plan)}),
)
class _StubGenerator:
def __init__(self, llm: Any = None) -> None: ...
def generate(self, *args, **kwargs) -> tuple[str, str]:
return "SELECT 1;", "stub"
def run(self, *args, **kwargs) -> StageResult:
sql, rationale = self.generate(*args, **kwargs)
return StageResult(
ok=True,
data={"sql": sql, "rationale": rationale},
trace=_tr("generator", notes={"rationale_len": len(rationale)}),
)
class _StubExecutor:
def __init__(self, db: Any | None = None) -> None: ...
def execute(self, *args, **kwargs) -> Dict[str, Any]:
rows = [{"x": 1}]
return {"rows": rows, "row_count": len(rows)}
def run(self, *args, **kwargs) -> StageResult:
out = self.execute(*args, **kwargs)
return StageResult(
ok=True,
data=out,
trace=_tr("executor", notes={"row_count": out["row_count"]}),
)
class _StubVerifier:
def verify(self, *args, **kwargs) -> bool:
return True
def run(self, *args, **kwargs) -> StageResult:
return StageResult(
ok=True, data={"verified": True}, trace=_tr("verifier")
)
class _StubRepair:
def __init__(self, llm: Any = None) -> None: ...
def repair(self, *args, **kwargs) -> str:
return kwargs.get("sql") or "SELECT 1;"
def run(self, *args, **kwargs) -> StageResult:
sql = self.repair(*args, **kwargs)
return StageResult(ok=True, data={"sql": sql}, trace=_tr("repair"))
detector = cast(AmbiguityDetector, _StubDetector())
planner = cast(Planner, _StubPlanner())
generator = cast(Generator, _StubGenerator())
safety = SAFETIES[cfg.get("safety", "default")]()
executor = cast(Executor, _StubExecutor(db=adapter))
verifier = cast(Verifier, _StubVerifier())
repair = cast(Repair, _StubRepair())
else:
detector = DETECTORS[cfg.get("detector", "default")]()
planner = PLANNERS[cfg.get("planner", "default")](llm=llm)
generator = GENERATORS[cfg.get("generator", "rules")](llm=llm)
safety = SAFETIES[cfg.get("safety", "default")]()
executor = EXECUTORS[cfg.get("executor", "default")](db=adapter)
verifier = VERIFIERS[cfg.get("verifier", "basic")]()
repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
return Pipeline(
detector=detector,
planner=planner,
generator=generator,
safety=safety,
executor=executor,
verifier=verifier,
repair=repair,
)