File size: 4,718 Bytes
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import pytest
from nl2sql.pipeline import Pipeline
from nl2sql.types import StageResult, StageTrace


# --- Dummy stages to isolate pipeline -----------------------------------------

class DummyDetector:
    """Stand-in for the ambiguity-detector stage.

    Emits one canned clarification question when built with
    ``ambiguous=True`` and no questions otherwise.
    """

    def __init__(self, ambiguous=False):
        # Truthiness of this flag decides whether ``detect`` asks a question.
        self.ambiguous = ambiguous

    def detect(self, user_query, schema_preview):
        """Return clarification questions; both arguments are ignored."""
        if not self.ambiguous:
            return []
        return ["Which column?"]


class DummyPlanner:
    """Stand-in for the planner stage.

    Any query containing ``"fail_plan"`` produces a failed result; every
    other query yields the fixed plan ``"plan text"``. Both outcomes carry a
    ``planner`` trace with ``duration_ms=1.0``.
    """

    def run(self, *, user_query, schema_preview):
        trace_record = StageTrace(stage="planner", duration_ms=1.0)
        if "fail_plan" not in user_query:
            return StageResult(ok=True, data={"plan": "plan text"}, trace=trace_record)
        return StageResult(ok=False, error=["Planner failed"], trace=trace_record)


class DummyGenerator:
    """Stand-in for the SQL-generator stage.

    Any query containing ``"fail_gen"`` produces a failed result; otherwise
    it returns a fixed ``SELECT`` over ``singer`` plus a short rationale.
    The plan, schema and clarification answers are accepted but ignored.
    """

    def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
        trace_record = StageTrace(stage="generator", duration_ms=1.0)
        if "fail_gen" not in user_query:
            return StageResult(
                ok=True,
                data={"sql": "SELECT * FROM singer;", "rationale": "List all singers."},
                trace=trace_record,
            )
        return StageResult(ok=False, error=["Generator failed"], trace=trace_record)


class DummySafety:
    """Stand-in for the safety-check stage.

    Rejects any SQL containing ``DROP`` (case-insensitive substring match)
    and passes everything else through unchanged.
    """

    def check(self, sql):
        trace_record = StageTrace(stage="safety", duration_ms=1.0)
        is_unsafe = "DROP" in sql.upper()
        if not is_unsafe:
            return StageResult(ok=True, data={"sql": sql, "rationale": "safe"}, trace=trace_record)
        return StageResult(ok=False, error=["Unsafe SQL"], trace=trace_record)


# --- 1) Success path ----------------------------------------------------------
def test_pipeline_success():
    """A clear query passes every stage and yields a SELECT with full traces."""
    components = dict(
        detector=DummyDetector(ambiguous=False),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    pipeline = Pipeline(**components)

    result = pipeline.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
    )

    assert isinstance(result, StageResult)
    assert result.ok is True
    payload = result.data or {}
    assert payload["sql"].lower().startswith("select")
    # Each stage must have left at least one trace entry behind.
    for expected_stage in ("planner", "generator", "safety"):
        assert any(t.stage == expected_stage for t in payload["traces"])


# --- 2) Ambiguity case --------------------------------------------------------
def test_pipeline_ambiguity():
    """An ambiguous query yields an ok result flagged ambiguous, with questions.

    The dummy detector returns one clarification question, so the result
    must report ``ambiguous`` and carry a non-empty list of questions.
    """
    pipeline = Pipeline(
        detector=DummyDetector(ambiguous=True),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )

    r = pipeline.run(
        user_query="show data",
        schema_preview="CREATE TABLE x(id int);",
    )

    assert isinstance(r, StageResult)
    assert r.ok is True
    # Guard against a None payload so a regression fails as an assertion
    # rather than a TypeError (same pattern as test_pipeline_success).
    data = r.data or {}
    assert data.get("ambiguous") is True
    questions = data.get("questions")
    assert isinstance(questions, list)
    # An empty list would mean the pipeline dropped the detector's question.
    assert questions, "expected at least one clarification question"


# --- 3) Planner failure -------------------------------------------------------
def test_pipeline_plan_fail():
    """A planner error yields a failed StageResult carrying the planner's message."""
    components = {
        "detector": DummyDetector(),
        "planner": DummyPlanner(),
        "generator": DummyGenerator(),
        "safety": DummySafety(),
    }
    outcome = Pipeline(**components).run(
        user_query="fail_plan",
        schema_preview="CREATE TABLE singer(id int);",
    )
    assert isinstance(outcome, StageResult)
    assert outcome.ok is False
    messages = " ".join(outcome.error or [])
    assert "Planner failed" in messages


# --- 4) Generator failure -----------------------------------------------------
def test_pipeline_gen_fail():
    """A generator error yields a failed StageResult carrying the generator's message."""
    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    r = pipeline.run(
        user_query="fail_gen",
        schema_preview="CREATE TABLE singer(id int);",
    )
    # Match the planner-failure test: failures must still come back as StageResults.
    assert isinstance(r, StageResult)
    assert r.ok is False
    assert "Generator failed" in " ".join(r.error or [])


# --- 5) Safety failure --------------------------------------------------------
def test_pipeline_safety_fail():
    """SQL rejected by the safety stage makes the pipeline report failure."""

    class UnsafeGen(DummyGenerator):
        """Generator that ignores its inputs and always emits a DROP TABLE."""

        def run(self, **kw):
            stage_trace = StageTrace(stage="generator", duration_ms=1.0)
            # DummySafety rejects anything containing DROP.
            dangerous = {"sql": "DROP TABLE x;", "rationale": "oops"}
            return StageResult(ok=True, data=dangerous, trace=stage_trace)

    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=UnsafeGen(),
        safety=DummySafety(),
    )
    outcome = pipeline.run(
        user_query="drop something",
        schema_preview="CREATE TABLE x(id int);",
    )
    assert outcome.ok is False
    joined_errors = " ".join(outcome.error or []).lower()
    assert "unsafe" in joined_errors