File size: 4,069 Bytes
34a177c
 
 
343ad62
34a177c
343ad62
 
 
 
 
 
 
 
 
 
34a177c
343ad62
 
34a177c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343ad62
 
 
34a177c
343ad62
34a177c
 
343ad62
34a177c
343ad62
 
 
 
34a177c
 
 
 
 
 
 
 
 
 
 
343ad62
34a177c
 
 
 
343ad62
 
 
34a177c
 
 
 
 
 
 
 
 
343ad62
34a177c
 
343ad62
34a177c
343ad62
34a177c
343ad62
 
 
 
 
 
 
 
 
 
 
 
 
34a177c
 
 
343ad62
 
 
34a177c
 
 
343ad62
34a177c
 
343ad62
34a177c
343ad62
34a177c
343ad62
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from __future__ import annotations

from typing import Any, Dict, Optional, cast
import yaml

from nl2sql.pipeline import Pipeline
from nl2sql.registry import (
    DETECTORS,
    PLANNERS,
    GENERATORS,
    SAFETIES,
    EXECUTORS,
    VERIFIERS,
    REPAIRS,
)
from adapters.db.base import DBAdapter
from adapters.db.sqlite_adapter import SQLiteAdapter
from adapters.db.postgres_adapter import PostgresAdapter

# 🔁 Use your real LLM provider here
from adapters.llm.openai_provider import OpenAIProvider  # noqa: F401


# ------------------ helpers ------------------ #
def _require_str(value: Any, *, name: str) -> str:
    if value is None:
        raise ValueError(f"Missing required string config: {name}")
    if not isinstance(value, str):
        raise TypeError(f"Config {name} must be a string, got {type(value).__name__}")
    v = value.strip()
    if not v:
        raise ValueError(f"Config {name} cannot be empty")
    return v


def _build_adapter(adapter_cfg: Dict[str, Any]) -> DBAdapter:
    """Create the database adapter described by ``adapter_cfg``.

    The ``kind`` entry selects the backend, compared case-insensitively; a
    missing or falsy ``kind`` selects ``sqlite``.

    * ``sqlite``: requires a non-empty string ``dsn`` and passes it to
      ``SQLiteAdapter``.
    * ``postgres``: forwards the whole mapping as keyword arguments to
      ``PostgresAdapter``.

    Raises:
        ValueError: if ``kind`` names an unsupported backend, or the sqlite
            ``dsn`` is missing or blank.
        TypeError: if the sqlite ``dsn`` is not a string.
    """
    backend = adapter_cfg.get("kind")
    backend = backend.lower() if backend else "sqlite"

    if backend == "postgres":
        # NOTE(review): every key, including "kind", is forwarded to the
        # constructor — confirm PostgresAdapter accepts a `kind` keyword.
        return PostgresAdapter(**adapter_cfg)

    if backend != "sqlite":
        raise ValueError(f"Unknown adapter kind: {backend}")

    dsn = _require_str(adapter_cfg.get("dsn"), name="adapter.dsn")
    return SQLiteAdapter(dsn)


def _build_llm(llm_cfg: Optional[Dict[str, Any]] = None) -> Any:
    """Return a new LLM provider instance.

    This is a placeholder factory: ``llm_cfg`` is accepted but not consulted
    yet, and an ``OpenAIProvider`` is always constructed with no arguments.
    Adapt it to the real provider signature (model name, base_url, api_key,
    ...) when wiring in configuration.

    Args:
        llm_cfg: Optional ``llm`` section of the pipeline config; currently
            ignored.
    """
    # Normalised only so a real implementation has a dict to start from.
    _unused_cfg: Dict[str, Any] = llm_cfg if llm_cfg else {}
    # NOTE(review): presumably OpenAIProvider() picks up its settings from the
    # environment — confirm against the provider implementation.
    return OpenAIProvider()


# ------------------ main: config → Pipeline ------------------ #
def _load_pipeline_config(path: str) -> Dict[str, Any]:
    """Read the YAML file at ``path`` and return its top-level mapping.

    An empty document is treated as an empty mapping, so every stage falls
    back to its default registry key instead of failing on ``None.get``.

    Raises:
        OSError: if the file cannot be opened.
        TypeError: if the document root is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as fh:
        loaded = yaml.safe_load(fh)
    if loaded is None:
        # safe_load yields None for an empty file.
        return {}
    if not isinstance(loaded, dict):
        raise TypeError(
            f"Pipeline config {path!r} must be a YAML mapping, "
            f"got {type(loaded).__name__}"
        )
    return cast(Dict[str, Any], loaded)


def _assemble_pipeline(cfg: Dict[str, Any], *, adapter: DBAdapter) -> Pipeline:
    """Instantiate every stage named in ``cfg`` and wire them into a Pipeline.

    Stage names are looked up in the matching registry; when a key is absent
    the fallbacks are ``default`` (detector, planner, safety, executor,
    repair), ``rules`` (generator) and ``basic`` (verifier). The planner,
    generator and repair stages receive the LLM; the executor receives the
    database adapter.
    """
    llm = _build_llm(cast(Optional[Dict[str, Any]], cfg.get("llm")))

    # Keyword arguments are evaluated in order, so stages are constructed in
    # the same sequence as they run: detector first, repair last.
    return Pipeline(
        detector=DETECTORS[cfg.get("detector", "default")](),
        planner=PLANNERS[cfg.get("planner", "default")](llm=llm),
        generator=GENERATORS[cfg.get("generator", "rules")](llm=llm),
        safety=SAFETIES[cfg.get("safety", "default")](),
        executor=EXECUTORS[cfg.get("executor", "default")](db=adapter),
        verifier=VERIFIERS[cfg.get("verifier", "basic")](),
        repair=REPAIRS[cfg.get("repair", "default")](llm=llm),
    )


def pipeline_from_config(path: str) -> Pipeline:
    """
    Build a Pipeline from YAML configuration.

    The optional ``adapter`` section selects the database adapter (sqlite when
    omitted) and the optional ``llm`` section is handed to the LLM factory.
    Constructor dependencies (llm, db/adapter) are injected into the stages
    that need them.

    Args:
        path: Location of the YAML configuration file.

    Raises:
        OSError: if the file cannot be opened.
        TypeError: if the document root or the ``adapter`` section is not a
            mapping.
        ValueError: if the adapter section is invalid (unknown kind, missing
            sqlite ``dsn``).
    """
    cfg = _load_pipeline_config(path)

    # `adapter:` with no value parses to None; treat it like a missing section.
    adapter_section = cfg.get("adapter") or {}
    if not isinstance(adapter_section, dict):
        raise TypeError(
            f"Config adapter must be a mapping, got {type(adapter_section).__name__}"
        )

    # Build the adapter before the LLM and the stages, as the config implies.
    adapter = _build_adapter(cast(Dict[str, Any], adapter_section))
    return _assemble_pipeline(cfg, adapter=adapter)


def pipeline_from_config_with_adapter(path: str, *, adapter: DBAdapter) -> Pipeline:
    """
    Same as pipeline_from_config, but force a specific adapter (per-request override).

    Any ``adapter`` section in the file is ignored; the executor is built with
    the supplied ``adapter`` instead.

    Args:
        path: Location of the YAML configuration file.
        adapter: Ready-made database adapter to hand to the executor.

    Raises:
        OSError: if the file cannot be opened.
        TypeError: if the document root is not a mapping.
    """
    cfg = _load_pipeline_config(path)
    return _assemble_pipeline(cfg, adapter=adapter)