Spaces:
Sleeping
Sleeping
Melika Kheirieh
commited on
Commit
·
34a177c
1
Parent(s):
343ad62
refactor(core): DI-ready Pipeline; add registry + YAML factory + typed trace/result
Browse files- config/sqlite_pipeline.yaml +5 -1
- nl2sql/pipeline_factory.py +62 -26
- nl2sql/registry.py +7 -8
config/sqlite_pipeline.yaml
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
detector: default
|
| 2 |
planner: default
|
| 3 |
-
generator: rules
|
| 4 |
safety: default
|
| 5 |
executor: default
|
| 6 |
verifier: basic
|
|
@@ -9,3 +9,7 @@ repair: default
|
|
| 9 |
adapter:
|
| 10 |
kind: sqlite
|
| 11 |
dsn: data/chinook.db
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
detector: default
|
| 2 |
planner: default
|
| 3 |
+
generator: rules
|
| 4 |
safety: default
|
| 5 |
executor: default
|
| 6 |
verifier: basic
|
|
|
|
| 9 |
adapter:
|
| 10 |
kind: sqlite
|
| 11 |
dsn: data/chinook.db
|
| 12 |
+
|
| 13 |
+
llm:
|
| 14 |
+
provider: openai
|
| 15 |
+
model: gpt-4o-mini
|
nl2sql/pipeline_factory.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import yaml
|
| 2 |
-
|
| 3 |
from nl2sql.pipeline import Pipeline
|
| 4 |
from nl2sql.registry import (
|
| 5 |
DETECTORS,
|
|
@@ -10,39 +13,72 @@ from nl2sql.registry import (
|
|
| 10 |
VERIFIERS,
|
| 11 |
REPAIRS,
|
| 12 |
)
|
|
|
|
| 13 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 14 |
from adapters.db.postgres_adapter import PostgresAdapter
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
|
| 19 |
-
kind = adapter_cfg.get("kind"
|
| 20 |
if kind == "sqlite":
|
| 21 |
-
|
|
|
|
| 22 |
if kind == "postgres":
|
|
|
|
| 23 |
return PostgresAdapter(**adapter_cfg)
|
| 24 |
raise ValueError(f"Unknown adapter kind: {kind}")
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
def pipeline_from_config(path: str) -> Pipeline:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
with open(path, "r", encoding="utf-8") as fh:
|
| 29 |
cfg: Dict[str, Any] = yaml.safe_load(fh)
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
detector = DETECTORS[cfg.get("detector", "default")]()
|
| 32 |
-
planner = PLANNERS[cfg.get("planner", "default")]()
|
| 33 |
-
generator = GENERATORS[cfg.get("generator", "rules")]()
|
| 34 |
safety = SAFETIES[cfg.get("safety", "default")]()
|
| 35 |
-
executor = EXECUTORS[cfg.get("executor", "default")]()
|
| 36 |
verifier = VERIFIERS[cfg.get("verifier", "basic")]()
|
| 37 |
-
repair = REPAIRS[cfg.get("repair", "default")]()
|
| 38 |
-
|
| 39 |
-
# If your Executor needs an adapter inside, set it there (common pattern):
|
| 40 |
-
adapter_cfg = cfg.get("adapter", {"kind": "sqlite", "dsn": "data/chinook.db"})
|
| 41 |
-
adapter = _build_adapter(adapter_cfg)
|
| 42 |
-
if hasattr(executor, "bind_adapter"):
|
| 43 |
-
executor.bind_adapter(adapter)
|
| 44 |
-
elif hasattr(executor, "adapter"):
|
| 45 |
-
executor.adapter = adapter # fallback
|
| 46 |
|
| 47 |
return Pipeline(
|
| 48 |
detector=detector,
|
|
@@ -56,22 +92,22 @@ def pipeline_from_config(path: str) -> Pipeline:
|
|
| 56 |
|
| 57 |
|
| 58 |
def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipeline:
|
| 59 |
-
"""
|
|
|
|
|
|
|
| 60 |
with open(path, "r", encoding="utf-8") as fh:
|
| 61 |
cfg: Dict[str, Any] = yaml.safe_load(fh)
|
| 62 |
|
|
|
|
|
|
|
|
|
|
| 63 |
detector = DETECTORS[cfg.get("detector", "default")]()
|
| 64 |
-
planner = PLANNERS[cfg.get("planner", "default")]()
|
| 65 |
-
generator = GENERATORS[cfg.get("generator", "rules")]()
|
| 66 |
safety = SAFETIES[cfg.get("safety", "default")]()
|
| 67 |
-
executor = EXECUTORS[cfg.get("executor", "default")]()
|
| 68 |
verifier = VERIFIERS[cfg.get("verifier", "basic")]()
|
| 69 |
-
repair = REPAIRS[cfg.get("repair", "default")]()
|
| 70 |
-
|
| 71 |
-
if hasattr(executor, "bind_adapter"):
|
| 72 |
-
executor.bind_adapter(adapter)
|
| 73 |
-
elif hasattr(executor, "adapter"):
|
| 74 |
-
executor.adapter = adapter
|
| 75 |
|
| 76 |
return Pipeline(
|
| 77 |
detector=detector,
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, Optional, cast
|
| 4 |
import yaml
|
| 5 |
+
|
| 6 |
from nl2sql.pipeline import Pipeline
|
| 7 |
from nl2sql.registry import (
|
| 8 |
DETECTORS,
|
|
|
|
| 13 |
VERIFIERS,
|
| 14 |
REPAIRS,
|
| 15 |
)
|
| 16 |
+
from adapters.db.base import DBAdapter
|
| 17 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 18 |
from adapters.db.postgres_adapter import PostgresAdapter
|
| 19 |
+
|
| 20 |
+
# 🔁 Use your real LLM provider here
|
| 21 |
+
from adapters.llm.openai_provider import OpenAIProvider # noqa: F401
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ------------------ helpers ------------------ #
|
| 25 |
+
def _require_str(value: Any, *, name: str) -> str:
|
| 26 |
+
if value is None:
|
| 27 |
+
raise ValueError(f"Missing required string config: {name}")
|
| 28 |
+
if not isinstance(value, str):
|
| 29 |
+
raise TypeError(f"Config {name} must be a string, got {type(value).__name__}")
|
| 30 |
+
v = value.strip()
|
| 31 |
+
if not v:
|
| 32 |
+
raise ValueError(f"Config {name} cannot be empty")
|
| 33 |
+
return v
|
| 34 |
|
| 35 |
|
| 36 |
def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
|
| 37 |
+
kind = (adapter_cfg.get("kind") or "sqlite").lower()
|
| 38 |
if kind == "sqlite":
|
| 39 |
+
dsn = _require_str(adapter_cfg.get("dsn"), name="adapter.dsn")
|
| 40 |
+
return SQLiteAdapter(dsn)
|
| 41 |
if kind == "postgres":
|
| 42 |
+
# expect keys like {"kind":"postgres","dsn":"postgresql://..."} OR kwargs your adapter needs
|
| 43 |
return PostgresAdapter(**adapter_cfg)
|
| 44 |
raise ValueError(f"Unknown adapter kind: {kind}")
|
| 45 |
|
| 46 |
|
| 47 |
+
def _build_llm(llm_cfg: Optional[Dict[str, Any]] = None) -> Any:
|
| 48 |
+
"""
|
| 49 |
+
Create an LLM client/provider instance.
|
| 50 |
+
Adjust this to your real signature (model name, base_url, api_key in env, etc.).
|
| 51 |
+
"""
|
| 52 |
+
_ = llm_cfg or {}
|
| 53 |
+
# Example: OpenAIProvider() reads env; or pass model via cfg.
|
| 54 |
+
return OpenAIProvider()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ------------------ main: config → Pipeline ------------------ #
|
| 58 |
def pipeline_from_config(path: str) -> Pipeline:
|
| 59 |
+
"""
|
| 60 |
+
Build a Pipeline from YAML configuration.
|
| 61 |
+
Inject proper constructor dependencies (llm, db/adapter) to satisfy mypy signatures.
|
| 62 |
+
"""
|
| 63 |
with open(path, "r", encoding="utf-8") as fh:
|
| 64 |
cfg: Dict[str, Any] = yaml.safe_load(fh)
|
| 65 |
|
| 66 |
+
# Optional sections
|
| 67 |
+
adapter_cfg = cast(Dict[str, Any], cfg.get("adapter", {}))
|
| 68 |
+
llm_cfg = cast(Optional[Dict[str, Any]], cfg.get("llm"))
|
| 69 |
+
|
| 70 |
+
# Core deps
|
| 71 |
+
adapter = _build_adapter(adapter_cfg)
|
| 72 |
+
llm = _build_llm(llm_cfg)
|
| 73 |
+
|
| 74 |
+
# Instantiate stages with required ctor args
|
| 75 |
detector = DETECTORS[cfg.get("detector", "default")]()
|
| 76 |
+
planner = PLANNERS[cfg.get("planner", "default")](llm=llm)
|
| 77 |
+
generator = GENERATORS[cfg.get("generator", "rules")](llm=llm)
|
| 78 |
safety = SAFETIES[cfg.get("safety", "default")]()
|
| 79 |
+
executor = EXECUTORS[cfg.get("executor", "default")](db=adapter)
|
| 80 |
verifier = VERIFIERS[cfg.get("verifier", "basic")]()
|
| 81 |
+
repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
return Pipeline(
|
| 84 |
detector=detector,
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipeline:
|
| 95 |
+
"""
|
| 96 |
+
Same as pipeline_from_config, but force a specific adapter (per-request override).
|
| 97 |
+
"""
|
| 98 |
with open(path, "r", encoding="utf-8") as fh:
|
| 99 |
cfg: Dict[str, Any] = yaml.safe_load(fh)
|
| 100 |
|
| 101 |
+
llm_cfg = cast(Optional[Dict[str, Any]], cfg.get("llm"))
|
| 102 |
+
llm = _build_llm(llm_cfg)
|
| 103 |
+
|
| 104 |
detector = DETECTORS[cfg.get("detector", "default")]()
|
| 105 |
+
planner = PLANNERS[cfg.get("planner", "default")](llm=llm)
|
| 106 |
+
generator = GENERATORS[cfg.get("generator", "rules")](llm=llm)
|
| 107 |
safety = SAFETIES[cfg.get("safety", "default")]()
|
| 108 |
+
executor = EXECUTORS[cfg.get("executor", "default")](db=adapter)
|
| 109 |
verifier = VERIFIERS[cfg.get("verifier", "basic")]()
|
| 110 |
+
repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
return Pipeline(
|
| 113 |
detector=detector,
|
nl2sql/registry.py
CHANGED
|
@@ -3,7 +3,6 @@ Registry mapping simple string keys to concrete component classes.
|
|
| 3 |
Used by pipeline_factory to perform lightweight dependency injection.
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
from typing import Dict, Type
|
| 7 |
from nl2sql.ambiguity_detector import AmbiguityDetector
|
| 8 |
from nl2sql.planner import Planner
|
| 9 |
from nl2sql.generator import Generator
|
|
@@ -13,10 +12,10 @@ from nl2sql.verifier import Verifier
|
|
| 13 |
from nl2sql.repair import Repair
|
| 14 |
|
| 15 |
# later you can add llm-aware generator variants, etc.
|
| 16 |
-
PLANNERS
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
| 3 |
Used by pipeline_factory to perform lightweight dependency injection.
|
| 4 |
"""
|
| 5 |
|
|
|
|
| 6 |
from nl2sql.ambiguity_detector import AmbiguityDetector
|
| 7 |
from nl2sql.planner import Planner
|
| 8 |
from nl2sql.generator import Generator
|
|
|
|
| 12 |
from nl2sql.repair import Repair
|
| 13 |
|
| 14 |
# later you can add llm-aware generator variants, etc.
|
| 15 |
+
PLANNERS = {"default": Planner}
|
| 16 |
+
GENERATORS = {"rules": Generator}
|
| 17 |
+
EXECUTORS = {"default": Executor}
|
| 18 |
+
REPAIRS = {"default": Repair}
|
| 19 |
+
DETECTORS = {"default": AmbiguityDetector}
|
| 20 |
+
SAFETIES = {"default": Safety}
|
| 21 |
+
VERIFIERS = {"basic": Verifier}
|