File size: 5,064 Bytes
a45c0eb
570f7bd
 
 
 
 
4dae3e6
570f7bd
 
c1bc4eb
a45c0eb
570f7bd
 
 
 
 
 
 
 
 
c1bc4eb
570f7bd
 
 
 
 
 
 
 
 
c1bc4eb
570f7bd
 
 
 
 
 
c1bc4eb
 
 
570f7bd
 
 
 
c1bc4eb
a45c0eb
 
570f7bd
 
 
a45c0eb
570f7bd
 
 
 
 
 
 
 
 
c1bc4eb
570f7bd
 
 
 
c1bc4eb
570f7bd
 
a45c0eb
570f7bd
a45c0eb
 
 
 
 
570f7bd
 
 
 
 
 
 
 
c1bc4eb
570f7bd
 
c1bc4eb
570f7bd
a45c0eb
570f7bd
a45c0eb
 
570f7bd
 
 
 
 
 
 
 
c1bc4eb
570f7bd
 
c1bc4eb
570f7bd
a45c0eb
570f7bd
a45c0eb
 
570f7bd
 
 
 
 
 
 
 
c1bc4eb
570f7bd
 
c1bc4eb
570f7bd
a45c0eb
570f7bd
a45c0eb
 
570f7bd
 
 
 
 
 
 
 
c1bc4eb
 
 
570f7bd
 
 
 
 
c1bc4eb
570f7bd
 
c1bc4eb
570f7bd
a45c0eb
570f7bd
a45c0eb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from nl2sql.pipeline import Pipeline, FinalResult
from nl2sql.types import StageResult, StageTrace


# --- Dummy stages to isolate pipeline -----------------------------------------


class DummyDetector:
    """Stand-in for the ambiguity detector stage.

    The ``ambiguous`` flag given at construction decides whether
    ``detect`` reports clarification questions or none at all.
    """

    def __init__(self, ambiguous: bool = False):
        self.ambiguous = ambiguous

    def detect(self, user_query, schema_preview):
        """Return one canned question when ambiguous, otherwise an empty list.

        Both arguments are accepted but never inspected.
        """
        if not self.ambiguous:
            return []
        return ["Which column?"]


class DummyPlanner:
    """Stand-in for the planner stage."""

    def run(self, *, user_query, schema_preview):
        """Return a canned plan, or a failure when the query asks for one.

        A query containing the marker ``fail_plan`` yields a failed
        StageResult; any other query yields a fixed plan payload.
        ``schema_preview`` is accepted but not used.
        """
        planner_trace = StageTrace(stage="planner", duration_ms=1.0)
        wants_failure = "fail_plan" in user_query
        if not wants_failure:
            return StageResult(
                ok=True, data={"plan": "plan text"}, trace=planner_trace
            )
        return StageResult(ok=False, error=["Planner failed"], trace=planner_trace)


class DummyGenerator:
    """Stand-in for the SQL generator stage."""

    def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
        """Return a fixed SELECT statement, or a failure on request.

        A query containing the marker ``fail_gen`` yields a failed
        StageResult. Every other argument is accepted but ignored.
        """
        generator_trace = StageTrace(stage="generator", duration_ms=1.0)
        wants_failure = "fail_gen" in user_query
        if not wants_failure:
            payload = {
                "sql": "SELECT * FROM singer;",
                "rationale": "List all singers.",
            }
            return StageResult(ok=True, data=payload, trace=generator_trace)
        return StageResult(
            ok=False, error=["Generator failed"], trace=generator_trace
        )


class DummySafety:
    """Stand-in for the safety stage."""

    # NOTE: the pipeline invokes this stage as safety.run(sql=...)
    def run(self, *, sql):
        """Reject SQL mentioning DROP (case-insensitive); pass anything else.

        On success the incoming SQL is echoed back inside ``data``.
        """
        safety_trace = StageTrace(stage="safety", duration_ms=1.0)
        is_unsafe = "DROP" in sql.upper()
        if not is_unsafe:
            # Hand the statement back unchanged so it can be consumed downstream.
            approved = {"sql": sql, "rationale": "safe"}
            return StageResult(ok=True, data=approved, trace=safety_trace)
        return StageResult(ok=False, error=["Unsafe SQL"], trace=safety_trace)


# --- 1) Success path ----------------------------------------------------------
def test_pipeline_success():
    """Happy path: an unambiguous query yields SQL plus a trace per stage."""
    stage_doubles = dict(
        detector=DummyDetector(ambiguous=False),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    result = Pipeline(**stage_doubles).run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
    )

    assert isinstance(result, FinalResult)
    assert result.ok is True
    assert result.sql is not None
    assert result.sql.lower().startswith("select")
    # Trace entries are read with .get(), i.e. treated as dict-like records.
    for expected_stage in ("planner", "generator", "safety"):
        assert any(
            entry.get("stage") == expected_stage for entry in result.traces
        )


# --- 2) Ambiguity case --------------------------------------------------------
def test_pipeline_ambiguity():
    """An ambiguous detector makes the pipeline return clarifying questions."""
    stage_doubles = dict(
        detector=DummyDetector(ambiguous=True),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    outcome = Pipeline(**stage_doubles).run(
        user_query="show data",
        schema_preview="CREATE TABLE x(id int);",
    )

    assert isinstance(outcome, FinalResult)
    assert outcome.ok is True
    assert outcome.ambiguous is True
    assert isinstance(outcome.questions, list)
    assert len(outcome.questions) > 0


# --- 3) Planner failure -------------------------------------------------------
def test_pipeline_plan_fail():
    """A planner failure surfaces as a failed result carrying its message."""
    stage_doubles = dict(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    # The "fail_plan" marker makes DummyPlanner return ok=False.
    outcome = Pipeline(**stage_doubles).run(
        user_query="fail_plan",
        schema_preview="CREATE TABLE singer(id int);",
    )

    assert isinstance(outcome, FinalResult)
    assert outcome.ok is False
    assert outcome.details is not None
    combined_details = " ".join(outcome.details)
    assert "Planner failed" in combined_details


# --- 4) Generator failure -----------------------------------------------------
def test_pipeline_gen_fail():
    """A generator failure surfaces as a failed result carrying its message."""
    stage_doubles = dict(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    # The "fail_gen" marker makes DummyGenerator return ok=False.
    outcome = Pipeline(**stage_doubles).run(
        user_query="fail_gen",
        schema_preview="CREATE TABLE singer(id int);",
    )

    assert isinstance(outcome, FinalResult)
    assert outcome.ok is False
    assert outcome.details is not None
    combined_details = " ".join(outcome.details)
    assert "Generator failed" in combined_details


# --- 5) Safety failure --------------------------------------------------------
def test_pipeline_safety_fail():
    """SQL rejected by the safety stage produces a failed result."""

    class UnsafeGen(DummyGenerator):
        """Generator double that always emits a DROP TABLE statement."""

        def run(self, **kwargs):
            # DROP TABLE is what DummySafety flags as unsafe.
            destructive = {"sql": "DROP TABLE x;", "rationale": "oops"}
            generator_trace = StageTrace(stage="generator", duration_ms=1.0)
            return StageResult(ok=True, data=destructive, trace=generator_trace)

    stage_doubles = dict(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=UnsafeGen(),
        safety=DummySafety(),
    )
    outcome = Pipeline(**stage_doubles).run(
        user_query="drop something",
        schema_preview="CREATE TABLE x(id int);",
    )

    assert isinstance(outcome, FinalResult)
    assert outcome.ok is False
    assert outcome.details is not None
    combined_details = " ".join(outcome.details).lower()
    assert "unsafe" in combined_details