File size: 3,759 Bytes
052c644
 
 
 
1af43ae
052c644
 
 
1af43ae
052c644
 
 
1af43ae
052c644
 
 
1af43ae
 
 
 
 
 
052c644
 
 
 
1af43ae
 
 
 
 
 
052c644
 
 
 
1af43ae
 
 
 
 
 
 
 
 
 
 
052c644
 
 
1af43ae
052c644
 
 
 
 
 
 
 
1af43ae
 
 
 
 
 
 
 
052c644
 
 
 
 
 
 
 
 
1af43ae
052c644
 
 
 
 
 
 
1af43ae
 
 
 
 
 
 
052c644
 
 
 
1af43ae
052c644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1af43ae
 
 
052c644
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import sqlite3
from nl2sql.pipeline import Pipeline
from nl2sql.types import StageResult, StageTrace


# --- Realistic dummy stages ----------------------------------
class DetectorOK:
    """Ambiguity-detector stub that never finds anything.

    Any positional or keyword arguments are accepted and ignored.
    """

    def detect(self, *args, **kwargs):
        # A new empty list on every call: "no ambiguities detected".
        no_ambiguities = []
        return no_ambiguities


class PlannerLLM:
    """Planner stub that wraps the user query in a one-sentence plan.

    ``schema_preview`` is accepted to satisfy the stage interface but is unused.
    """

    def run(self, *, user_query, schema_preview):
        plan_text = f"Understand user query '{user_query}' and map to table."
        planner_trace = StageTrace(stage="planner", duration_ms=0)
        return StageResult(ok=True, data={"plan": plan_text}, trace=planner_trace)


class GeneratorSimple:
    """Generator stub that always returns the same per-city user-count query."""

    # Fixed query returned regardless of the inputs to ``run``.
    _CANNED_SQL = "SELECT city, COUNT(*) AS cnt FROM users GROUP BY city"

    def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
        # Only plan_text is used: it is echoed back as the rationale.
        payload = {"sql": self._CANNED_SQL, "rationale": plan_text}
        generator_trace = StageTrace(stage="generator", duration_ms=0)
        return StageResult(ok=True, data=payload, trace=generator_trace)


class SafetyReadOnly:
    """Safety stub that only passes SQL beginning with ``select`` (case-insensitive).

    Leading/trailing whitespace is ignored for the check; the original ``sql``
    string is passed through unchanged on success.
    """

    def run(self, *, sql):
        normalized = sql.strip().lower()
        if normalized.startswith("select"):
            passed_trace = StageTrace(stage="safety", duration_ms=0)
            return StageResult(ok=True, data={"sql": sql}, trace=passed_trace)

        # Anything that is not a plain SELECT is rejected.
        rejected_trace = StageTrace(
            stage="safety", duration_ms=0, notes={"reason": "unsafe"}
        )
        return StageResult(ok=False, error=["Unsafe query"], trace=rejected_trace)


class ExecutorSQLite:
    """Executes SQL on a private in-memory SQLite database seeded with test data.

    The database contains one table, ``users(id INTEGER, city TEXT)``, with
    three rows: ids 1 and 2 in Berlin, id 3 in Munich.

    Call :meth:`close` (or use the instance as a context manager) when done.
    On Python 3.13+, sqlite3 emits ``ResourceWarning`` for connections that
    are garbage-collected without being closed.
    """

    def __init__(self):
        # create in-memory DB and seed some rows
        self.conn = sqlite3.connect(":memory:")
        self._seed()

    def _seed(self):
        """Create the ``users`` table and insert the three fixture rows."""
        self.conn.execute("CREATE TABLE users (id INTEGER, city TEXT)")
        self.conn.executemany(
            "INSERT INTO users VALUES (?, ?)",
            [
                (1, "Berlin"),
                (2, "Berlin"),
                (3, "Munich"),
            ],
        )
        self.conn.commit()

    @staticmethod
    def _fetch_rows(cur):
        """Return the cursor's pending result set as ``[{column: value}, ...]``.

        Statements that produce no result set (``cur.description is None``,
        e.g. UPDATE/INSERT) yield ``[]`` instead of raising ``TypeError``.
        """
        if cur.description is None:
            return []
        # Column names are computed once, not once per row.
        columns = [d[0] for d in cur.description]
        return [dict(zip(columns, r)) for r in cur.fetchall()]

    def run(self, *, sql):
        """Execute ``sql`` and return its rows under ``data["rows"]``.

        sqlite3 errors (e.g. bad SQL) propagate to the caller unchanged.
        """
        cur = self.conn.cursor()
        try:
            cur.execute(sql)
            rows = self._fetch_rows(cur)
        finally:
            cur.close()
        return StageResult(
            ok=True,
            data={"rows": rows},
            trace=StageTrace(stage="executor", duration_ms=0),
        )

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always release the connection; never suppress the exception.
        self.close()


class VerifierCheckCount:
    """Verifier stub: passes when the first result row has ``city`` and ``cnt`` keys.

    Only the first row is inspected; ``sql`` is accepted but unused.
    """

    def run(self, *, sql, exec_result):
        rows = exec_result.get("rows", [])
        has_columns = False
        if rows:
            head = rows[0]
            has_columns = "city" in head and "cnt" in head
        verifier_trace = StageTrace(
            stage="verifier", duration_ms=0, notes={"rows_len": len(rows)}
        )
        return StageResult(
            ok=has_columns,
            data={"verified": has_columns},
            trace=verifier_trace,
        )


class RepairNoOp:
    """Repair stub that always reports failure; not triggered in this scenario."""

    # Fixed error message returned on every call.
    _MESSAGE = "no repair needed"

    def run(self, *args, **kwargs):
        # Inputs are ignored; the result is always ok=False with no trace.
        return StageResult(ok=False, error=[self._MESSAGE])


# --- Integration test ----------------------------------------
def test_pipeline_end_to_end_real_sqlite():
    """Run the full NL2SQL pipeline end to end against a real in-memory SQLite DB.

    Every stage is a hand-written stub from this module (no mocking library);
    the executor runs the generated SQL against an actual sqlite3 database.
    """
    executor = ExecutorSQLite()
    try:
        pipe = Pipeline(
            detector=DetectorOK(),
            planner=PlannerLLM(),
            generator=GeneratorSimple(),
            safety=SafetyReadOnly(),
            executor=executor,
            verifier=VerifierCheckCount(),
            repair=RepairNoOp(),
        )

        result = pipe.run(
            user_query="count users per city", schema_preview="users(id, city)"
        )

        # --- Assertions ---
        assert result.ok
        assert result.verified
        assert not result.error
        assert "SELECT" in result.sql

        # Ensure pipeline produced valid SQL and traces
        assert isinstance(result.traces, list)
        assert result.traces  # not empty

        # Logical validation
        assert "city" in result.sql.lower()
    finally:
        # Close explicitly: unclosed sqlite3 connections emit ResourceWarning
        # on Python 3.13+, which fails runs that treat warnings as errors.
        executor.conn.close()