Spaces:
Running
Running
File size: 5,064 Bytes
a45c0eb 570f7bd 4dae3e6 570f7bd c1bc4eb a45c0eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb a45c0eb 570f7bd a45c0eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd a45c0eb 570f7bd a45c0eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd a45c0eb 570f7bd a45c0eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd a45c0eb 570f7bd a45c0eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd a45c0eb 570f7bd a45c0eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd a45c0eb 570f7bd a45c0eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
from nl2sql.pipeline import Pipeline, FinalResult
from nl2sql.types import StageResult, StageTrace
# --- Dummy stages to isolate pipeline -----------------------------------------
class DummyDetector:
    """Stand-in for the ambiguity-detector stage.

    Constructed with ``ambiguous=True`` it always reports a single
    clarification question; otherwise it reports none.
    """

    def __init__(self, ambiguous: bool = False):
        self.ambiguous = ambiguous

    def detect(self, user_query, schema_preview):
        # Both inputs are ignored: the answer depends only on the flag.
        if not self.ambiguous:
            return []
        return ["Which column?"]
class DummyPlanner:
    """Stand-in for the planner stage.

    Fails whenever the query contains the marker ``"fail_plan"``; otherwise
    succeeds with a fixed plan string under the ``"plan"`` key.
    """

    def run(self, *, user_query, schema_preview):
        stage_trace = StageTrace(stage="planner", duration_ms=1.0)
        should_fail = "fail_plan" in user_query
        if should_fail:
            return StageResult(ok=False, error=["Planner failed"], trace=stage_trace)
        payload = {"plan": "plan text"}
        return StageResult(ok=True, data=payload, trace=stage_trace)
class DummyGenerator:
    """Stand-in for the SQL-generator stage.

    Fails whenever the query contains the marker ``"fail_gen"``; otherwise
    returns a fixed ``SELECT`` statement together with a short rationale.
    """

    def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
        stage_trace = StageTrace(stage="generator", duration_ms=1.0)
        if "fail_gen" in user_query:
            return StageResult(ok=False, error=["Generator failed"], trace=stage_trace)
        payload = {
            "sql": "SELECT * FROM singer;",
            "rationale": "List all singers.",
        }
        return StageResult(ok=True, data=payload, trace=stage_trace)
class DummySafety:
    """Stand-in for the safety-check stage.

    The pipeline calls this as ``safety.run(sql=...)``. A statement whose
    upper-cased text contains ``DROP`` is rejected; anything else passes,
    and the SQL is echoed back in ``data`` so it can feed the executor.
    """

    def run(self, *, sql):
        stage_trace = StageTrace(stage="safety", duration_ms=1.0)
        is_unsafe = "DROP" in sql.upper()
        if is_unsafe:
            return StageResult(ok=False, error=["Unsafe SQL"], trace=stage_trace)
        payload = {"sql": sql, "rationale": "safe"}
        return StageResult(ok=True, data=payload, trace=stage_trace)
# --- 1) Success path ----------------------------------------------------------
def test_pipeline_success():
    """A clean, unambiguous query succeeds and yields a SELECT statement."""
    stages = dict(
        detector=DummyDetector(ambiguous=False),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    pipeline = Pipeline(**stages)

    result = pipeline.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
    )

    assert isinstance(result, FinalResult)
    assert result.ok is True
    assert result.sql is not None
    assert result.sql.lower().startswith("select")

    # Each trace is a dict (StageTrace.__dict__); every stage must appear.
    for expected_stage in ("planner", "generator", "safety"):
        assert any(trace.get("stage") == expected_stage for trace in result.traces)
# --- 2) Ambiguity case --------------------------------------------------------
def test_pipeline_ambiguity():
    """An ambiguous query still returns ok, flagged with clarification questions."""
    stages = dict(
        detector=DummyDetector(ambiguous=True),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    result = Pipeline(**stages).run(
        user_query="show data",
        schema_preview="CREATE TABLE x(id int);",
    )

    assert isinstance(result, FinalResult)
    assert result.ok is True
    assert result.ambiguous is True
    assert isinstance(result.questions, list)
    assert len(result.questions) > 0
# --- 3) Planner failure -------------------------------------------------------
def test_pipeline_plan_fail():
    """A planner failure surfaces as ok=False with the planner's error in details."""
    stages = dict(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    # "fail_plan" is the marker that makes DummyPlanner fail.
    result = Pipeline(**stages).run(
        user_query="fail_plan",
        schema_preview="CREATE TABLE singer(id int);",
    )

    assert isinstance(result, FinalResult)
    assert result.ok is False
    assert result.details is not None
    joined_details = " ".join(result.details)
    assert "Planner failed" in joined_details
# --- 4) Generator failure -----------------------------------------------------
def test_pipeline_gen_fail():
    """A generator failure surfaces as ok=False with the generator's error in details."""
    stages = dict(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    # "fail_gen" is the marker that makes DummyGenerator fail.
    result = Pipeline(**stages).run(
        user_query="fail_gen",
        schema_preview="CREATE TABLE singer(id int);",
    )

    assert isinstance(result, FinalResult)
    assert result.ok is False
    assert result.details is not None
    joined_details = " ".join(result.details)
    assert "Generator failed" in joined_details
# --- 5) Safety failure --------------------------------------------------------
def test_pipeline_safety_fail():
    """A generated DROP statement makes the pipeline fail with an 'unsafe' detail."""

    class UnsafeGen(DummyGenerator):
        """Generator that ignores its inputs and always emits DROP TABLE."""

        def run(self, **_kwargs):
            stage_trace = StageTrace(stage="generator", duration_ms=1.0)
            # DummySafety rejects any SQL containing DROP.
            payload = {"sql": "DROP TABLE x;", "rationale": "oops"}
            return StageResult(ok=True, data=payload, trace=stage_trace)

    stages = dict(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=UnsafeGen(),
        safety=DummySafety(),
    )
    result = Pipeline(**stages).run(
        user_query="drop something",
        schema_preview="CREATE TABLE x(id int);",
    )

    assert isinstance(result, FinalResult)
    assert result.ok is False
    assert result.details is not None
    joined_details = " ".join(result.details).lower()
    assert "unsafe" in joined_details
|