Spaces:
Running
Running
File size: 4,718 Bytes
570f7bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import pytest
from nl2sql.pipeline import Pipeline
from nl2sql.types import StageResult, StageTrace
# --- Dummy stages to isolate pipeline -----------------------------------------
class DummyDetector:
    """Stand-in for the ambiguity-detector stage.

    Args:
        ambiguous: When truthy, ``detect`` reports one clarification
            question; when falsy, it reports none.
    """

    def __init__(self, ambiguous=False):
        self.ambiguous = ambiguous

    def detect(self, user_query, schema_preview):
        """Return the clarification questions for a query.

        ``user_query`` and ``schema_preview`` are accepted only to match the
        real detector's call signature; they are not inspected.

        Returns:
            ``["Which column?"]`` if the detector was built ambiguous,
            otherwise an empty list.
        """
        if not self.ambiguous:
            return []
        return ["Which column?"]
class DummyPlanner:
    """Stand-in for the planner stage.

    The stage fails whenever the query contains the marker ``"fail_plan"``.
    Otherwise it succeeds with a fixed plan string.
    """

    def run(self, *, user_query, schema_preview):
        """Produce a planner ``StageResult`` for ``user_query``.

        ``schema_preview`` is accepted for interface compatibility and ignored.
        Both outcomes carry a ``StageTrace`` tagged ``"planner"``.
        """
        trace = StageTrace(stage="planner", duration_ms=1.0)
        if "fail_plan" not in user_query:
            return StageResult(ok=True, data={"plan": "plan text"}, trace=trace)
        # The marker lets tests force the failure branch deterministically.
        return StageResult(ok=False, error=["Planner failed"], trace=trace)
class DummyGenerator:
    """Stand-in for the SQL-generator stage.

    The stage fails whenever the query contains the marker ``"fail_gen"``.
    Otherwise it returns a fixed ``SELECT`` over the ``singer`` table.
    """

    def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
        """Produce a generator ``StageResult`` for ``user_query``.

        ``schema_preview``, ``plan_text`` and ``clarify_answers`` are accepted
        for interface compatibility and ignored. On success, ``data`` holds
        the ``"sql"`` and ``"rationale"`` keys.
        """
        trace = StageTrace(stage="generator", duration_ms=1.0)
        if "fail_gen" not in user_query:
            payload = {
                "sql": "SELECT * FROM singer;",
                "rationale": "List all singers.",
            }
            return StageResult(ok=True, data=payload, trace=trace)
        return StageResult(ok=False, error=["Generator failed"], trace=trace)
class DummySafety:
    """Stand-in for the SQL safety-check stage.

    Any statement containing ``DROP`` (case-insensitive) is rejected.
    """

    def check(self, sql):
        """Validate ``sql`` and return a safety ``StageResult``.

        On success, ``data`` echoes the SQL back with a fixed ``"safe"``
        rationale. On rejection, ``error`` contains ``"Unsafe SQL"``.
        """
        trace = StageTrace(stage="safety", duration_ms=1.0)
        is_unsafe = "DROP" in sql.upper()
        if not is_unsafe:
            return StageResult(ok=True, data={"sql": sql, "rationale": "safe"}, trace=trace)
        return StageResult(ok=False, error=["Unsafe SQL"], trace=trace)
# --- 1) Success path ----------------------------------------------------------
def test_pipeline_success():
    """An unambiguous, safe query yields a SELECT plus per-stage traces."""
    pipeline = Pipeline(
        detector=DummyDetector(ambiguous=False),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    result = pipeline.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
    )

    assert isinstance(result, StageResult)
    assert result.ok is True

    payload = result.data or {}
    assert payload["sql"].lower().startswith("select")

    # Every stage that ran should have left a trace behind.
    traces = payload["traces"]
    for expected_stage in ("planner", "generator", "safety"):
        assert any(t.stage == expected_stage for t in traces)
# --- 2) Ambiguity case --------------------------------------------------------
def test_pipeline_ambiguity():
    """An ambiguous query is reported as such, together with a list of questions."""
    pipeline = Pipeline(
        detector=DummyDetector(ambiguous=True),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    result = pipeline.run(
        user_query="show data",
        schema_preview="CREATE TABLE x(id int);",
    )

    assert isinstance(result, StageResult)
    assert result.ok is True

    payload = result.data
    assert payload["ambiguous"] is True
    assert isinstance(payload["questions"], list)
# --- 3) Planner failure -------------------------------------------------------
def test_pipeline_plan_fail():
    """A planner failure surfaces as a failed result carrying its error text."""
    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    # "fail_plan" is the marker that makes DummyPlanner return ok=False.
    result = pipeline.run(
        user_query="fail_plan",
        schema_preview="CREATE TABLE singer(id int);",
    )

    assert isinstance(result, StageResult)
    assert result.ok is False
    joined_errors = " ".join(result.error or [])
    assert "Planner failed" in joined_errors
# --- 4) Generator failure -----------------------------------------------------
def test_pipeline_gen_fail():
    """A generator failure surfaces as a failed result carrying its error text."""
    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    # "fail_gen" is the marker that makes DummyGenerator return ok=False.
    r = pipeline.run(
        user_query="fail_gen",
        schema_preview="CREATE TABLE singer(id int);",
    )
    # Check the result type as the other pipeline tests do, so a wrong
    # return type is reported here instead of as an AttributeError below.
    assert isinstance(r, StageResult)
    assert r.ok is False
    assert "Generator failed" in " ".join(r.error or [])
# --- 5) Safety failure --------------------------------------------------------
def test_pipeline_safety_fail():
    """Generated SQL rejected by the safety stage yields a failed result."""

    class UnsafeGen(DummyGenerator):
        """Generator that always emits a destructive ``DROP TABLE`` statement."""

        def run(self, **kw):
            trace = StageTrace(stage="generator", duration_ms=1.0)
            # A DROP TABLE statement is what DummySafety rejects as unsafe.
            return StageResult(
                ok=True,
                data={"sql": "DROP TABLE x;", "rationale": "oops"},
                trace=trace,
            )

    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=UnsafeGen(),
        safety=DummySafety(),
    )
    r = pipeline.run(
        user_query="drop something",
        schema_preview="CREATE TABLE x(id int);",
    )
    # Check the result type as the other pipeline tests do.
    assert isinstance(r, StageResult)
    assert r.ok is False
    assert "unsafe" in " ".join(r.error or []).lower()
|