Melika Kheirieh committed on
Commit
34a177c
·
1 Parent(s): 343ad62

refactor(core): DI-ready Pipeline; add registry + YAML factory + typed trace/result

Browse files
config/sqlite_pipeline.yaml CHANGED
@@ -1,6 +1,6 @@
1
  detector: default
2
  planner: default
3
- generator: rules # or "llm" when available
4
  safety: default
5
  executor: default
6
  verifier: basic
@@ -9,3 +9,7 @@ repair: default
9
  adapter:
10
  kind: sqlite
11
  dsn: data/chinook.db
 
 
 
 
 
1
  detector: default
2
  planner: default
3
+ generator: rules
4
  safety: default
5
  executor: default
6
  verifier: basic
 
9
  adapter:
10
  kind: sqlite
11
  dsn: data/chinook.db
12
+
13
+ llm:
14
+ provider: openai
15
+ model: gpt-4o-mini
nl2sql/pipeline_factory.py CHANGED
@@ -1,5 +1,8 @@
 
 
 
1
  import yaml
2
- from typing import Any, Dict
3
  from nl2sql.pipeline import Pipeline
4
  from nl2sql.registry import (
5
  DETECTORS,
@@ -10,39 +13,72 @@ from nl2sql.registry import (
10
  VERIFIERS,
11
  REPAIRS,
12
  )
 
13
  from adapters.db.sqlite_adapter import SQLiteAdapter
14
  from adapters.db.postgres_adapter import PostgresAdapter
15
- from adapters.db.base import DBAdapter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
19
- kind = adapter_cfg.get("kind", "sqlite")
20
  if kind == "sqlite":
21
- return SQLiteAdapter(adapter_cfg.get("dsn"))
 
22
  if kind == "postgres":
 
23
  return PostgresAdapter(**adapter_cfg)
24
  raise ValueError(f"Unknown adapter kind: {kind}")
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
27
  def pipeline_from_config(path: str) -> Pipeline:
 
 
 
 
28
  with open(path, "r", encoding="utf-8") as fh:
29
  cfg: Dict[str, Any] = yaml.safe_load(fh)
30
 
 
 
 
 
 
 
 
 
 
31
  detector = DETECTORS[cfg.get("detector", "default")]()
32
- planner = PLANNERS[cfg.get("planner", "default")]()
33
- generator = GENERATORS[cfg.get("generator", "rules")]()
34
  safety = SAFETIES[cfg.get("safety", "default")]()
35
- executor = EXECUTORS[cfg.get("executor", "default")]()
36
  verifier = VERIFIERS[cfg.get("verifier", "basic")]()
37
- repair = REPAIRS[cfg.get("repair", "default")]()
38
-
39
- # If your Executor needs an adapter inside, set it there (common pattern):
40
- adapter_cfg = cfg.get("adapter", {"kind": "sqlite", "dsn": "data/chinook.db"})
41
- adapter = _build_adapter(adapter_cfg)
42
- if hasattr(executor, "bind_adapter"):
43
- executor.bind_adapter(adapter)
44
- elif hasattr(executor, "adapter"):
45
- executor.adapter = adapter # fallback
46
 
47
  return Pipeline(
48
  detector=detector,
@@ -56,22 +92,22 @@ def pipeline_from_config(path: str) -> Pipeline:
56
 
57
 
58
  def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipeline:
59
- """Same as pipeline_from_config, but force a specific adapter (per-request override)."""
 
 
60
  with open(path, "r", encoding="utf-8") as fh:
61
  cfg: Dict[str, Any] = yaml.safe_load(fh)
62
 
 
 
 
63
  detector = DETECTORS[cfg.get("detector", "default")]()
64
- planner = PLANNERS[cfg.get("planner", "default")]()
65
- generator = GENERATORS[cfg.get("generator", "rules")]()
66
  safety = SAFETIES[cfg.get("safety", "default")]()
67
- executor = EXECUTORS[cfg.get("executor", "default")]()
68
  verifier = VERIFIERS[cfg.get("verifier", "basic")]()
69
- repair = REPAIRS[cfg.get("repair", "default")]()
70
-
71
- if hasattr(executor, "bind_adapter"):
72
- executor.bind_adapter(adapter)
73
- elif hasattr(executor, "adapter"):
74
- executor.adapter = adapter
75
 
76
  return Pipeline(
77
  detector=detector,
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Optional, cast
4
  import yaml
5
+
6
  from nl2sql.pipeline import Pipeline
7
  from nl2sql.registry import (
8
  DETECTORS,
 
13
  VERIFIERS,
14
  REPAIRS,
15
  )
16
+ from adapters.db.base import DBAdapter
17
  from adapters.db.sqlite_adapter import SQLiteAdapter
18
  from adapters.db.postgres_adapter import PostgresAdapter
19
+
20
+ # 🔁 Use your real LLM provider here
21
+ from adapters.llm.openai_provider import OpenAIProvider # noqa: F401
22
+
23
+
24
+ # ------------------ helpers ------------------ #
25
+ def _require_str(value: Any, *, name: str) -> str:
26
+ if value is None:
27
+ raise ValueError(f"Missing required string config: {name}")
28
+ if not isinstance(value, str):
29
+ raise TypeError(f"Config {name} must be a string, got {type(value).__name__}")
30
+ v = value.strip()
31
+ if not v:
32
+ raise ValueError(f"Config {name} cannot be empty")
33
+ return v
34
 
35
 
36
  def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
37
+ kind = (adapter_cfg.get("kind") or "sqlite").lower()
38
  if kind == "sqlite":
39
+ dsn = _require_str(adapter_cfg.get("dsn"), name="adapter.dsn")
40
+ return SQLiteAdapter(dsn)
41
  if kind == "postgres":
42
+ # expect keys like {"kind":"postgres","dsn":"postgresql://..."} OR kwargs your adapter needs
43
  return PostgresAdapter(**adapter_cfg)
44
  raise ValueError(f"Unknown adapter kind: {kind}")
45
 
46
 
47
+ def _build_llm(llm_cfg: Optional[Dict[str, Any]] = None) -> Any:
48
+ """
49
+ Create an LLM client/provider instance.
50
+ Adjust this to your real signature (model name, base_url, api_key in env, etc.).
51
+ """
52
+ _ = llm_cfg or {}
53
+ # Example: OpenAIProvider() reads env; or pass model via cfg.
54
+ return OpenAIProvider()
55
+
56
+
57
+ # ------------------ main: config → Pipeline ------------------ #
58
  def pipeline_from_config(path: str) -> Pipeline:
59
+ """
60
+ Build a Pipeline from YAML configuration.
61
+ Inject proper constructor dependencies (llm, db/adapter) to satisfy mypy signatures.
62
+ """
63
  with open(path, "r", encoding="utf-8") as fh:
64
  cfg: Dict[str, Any] = yaml.safe_load(fh)
65
 
66
+ # Optional sections
67
+ adapter_cfg = cast(Dict[str, Any], cfg.get("adapter", {}))
68
+ llm_cfg = cast(Optional[Dict[str, Any]], cfg.get("llm"))
69
+
70
+ # Core deps
71
+ adapter = _build_adapter(adapter_cfg)
72
+ llm = _build_llm(llm_cfg)
73
+
74
+ # Instantiate stages with required ctor args
75
  detector = DETECTORS[cfg.get("detector", "default")]()
76
+ planner = PLANNERS[cfg.get("planner", "default")](llm=llm)
77
+ generator = GENERATORS[cfg.get("generator", "rules")](llm=llm)
78
  safety = SAFETIES[cfg.get("safety", "default")]()
79
+ executor = EXECUTORS[cfg.get("executor", "default")](db=adapter)
80
  verifier = VERIFIERS[cfg.get("verifier", "basic")]()
81
+ repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
 
 
 
 
 
 
 
 
82
 
83
  return Pipeline(
84
  detector=detector,
 
92
 
93
 
94
  def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipeline:
95
+ """
96
+ Same as pipeline_from_config, but force a specific adapter (per-request override).
97
+ """
98
  with open(path, "r", encoding="utf-8") as fh:
99
  cfg: Dict[str, Any] = yaml.safe_load(fh)
100
 
101
+ llm_cfg = cast(Optional[Dict[str, Any]], cfg.get("llm"))
102
+ llm = _build_llm(llm_cfg)
103
+
104
  detector = DETECTORS[cfg.get("detector", "default")]()
105
+ planner = PLANNERS[cfg.get("planner", "default")](llm=llm)
106
+ generator = GENERATORS[cfg.get("generator", "rules")](llm=llm)
107
  safety = SAFETIES[cfg.get("safety", "default")]()
108
+ executor = EXECUTORS[cfg.get("executor", "default")](db=adapter)
109
  verifier = VERIFIERS[cfg.get("verifier", "basic")]()
110
+ repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
 
 
 
 
 
111
 
112
  return Pipeline(
113
  detector=detector,
nl2sql/registry.py CHANGED
@@ -3,7 +3,6 @@ Registry mapping simple string keys to concrete component classes.
3
  Used by pipeline_factory to perform lightweight dependency injection.
4
  """
5
 
6
- from typing import Dict, Type
7
  from nl2sql.ambiguity_detector import AmbiguityDetector
8
  from nl2sql.planner import Planner
9
  from nl2sql.generator import Generator
@@ -13,10 +12,10 @@ from nl2sql.verifier import Verifier
13
  from nl2sql.repair import Repair
14
 
15
  # later you can add llm-aware generator variants, etc.
16
- PLANNERS: Dict[str, Type[Planner]] = {"default": Planner}
17
- DETECTORS: Dict[str, Type[AmbiguityDetector]] = {"default": AmbiguityDetector}
18
- GENERATORS: Dict[str, Type[Generator]] = {"rules": Generator}
19
- SAFETIES: Dict[str, Type[Safety]] = {"default": Safety}
20
- EXECUTORS: Dict[str, Type[Executor]] = {"default": Executor}
21
- VERIFIERS: Dict[str, Type[Verifier]] = {"basic": Verifier}
22
- REPAIRS: Dict[str, Type[Repair]] = {"default": Repair}
 
3
  Used by pipeline_factory to perform lightweight dependency injection.
4
  """
5
 
 
6
  from nl2sql.ambiguity_detector import AmbiguityDetector
7
  from nl2sql.planner import Planner
8
  from nl2sql.generator import Generator
 
12
  from nl2sql.repair import Repair
13
 
14
  # later you can add llm-aware generator variants, etc.
15
+ PLANNERS = {"default": Planner}
16
+ GENERATORS = {"rules": Generator}
17
+ EXECUTORS = {"default": Executor}
18
+ REPAIRS = {"default": Repair}
19
+ DETECTORS = {"default": AmbiguityDetector}
20
+ SAFETIES = {"default": Safety}
21
+ VERIFIERS = {"basic": Verifier}