# nl2sql-copilot / tests/test_pipeline_integration.py
# Melika Kheirieh
# style: format code with ruff
# 4dae3e6
from nl2sql.pipeline import Pipeline, FinalResult
from nl2sql.types import StageResult, StageTrace
# --- Dummy stages to isolate pipeline -----------------------------------------
class DummyDetector:
    """Stand-in for the ambiguity-detector stage.

    The query and schema are ignored: the answer depends only on the
    ``ambiguous`` flag chosen at construction time.
    """

    def __init__(self, ambiguous: bool = False):
        self.ambiguous = ambiguous

    def detect(self, user_query, schema_preview):
        """Return clarification questions: one canned question, or none."""
        if not self.ambiguous:
            return []
        return ["Which column?"]
class DummyPlanner:
    """Stand-in for the planner stage.

    Fails when the query contains the marker ``"fail_plan"``; otherwise
    succeeds with a canned plan. ``schema_preview`` is accepted but unused.
    """

    def run(self, *, user_query, schema_preview):
        stage_trace = StageTrace(stage="planner", duration_ms=1.0)
        if "fail_plan" not in user_query:
            return StageResult(ok=True, data={"plan": "plan text"}, trace=stage_trace)
        return StageResult(ok=False, error=["Planner failed"], trace=stage_trace)
class DummyGenerator:
    """Stand-in for the SQL generator stage.

    Fails when the query contains the marker ``"fail_gen"``; otherwise returns
    a fixed ``SELECT`` over the ``singer`` table plus a short rationale.
    ``schema_preview``, ``plan_text`` and ``clarify_answers`` are accepted but
    unused.
    """

    def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
        stage_trace = StageTrace(stage="generator", duration_ms=1.0)
        if "fail_gen" not in user_query:
            payload = {
                "sql": "SELECT * FROM singer;",
                "rationale": "List all singers.",
            }
            return StageResult(ok=True, data=payload, trace=stage_trace)
        return StageResult(ok=False, error=["Generator failed"], trace=stage_trace)
class DummySafety:
    """Stand-in for the safety stage.

    Rejects any SQL containing ``DROP`` (case-insensitive). Otherwise it
    echoes the SQL back in ``data`` so it can feed the executor.
    """

    # NOTE: the pipeline now invokes this as ``safety.run(sql=...)``.
    def run(self, *, sql):
        stage_trace = StageTrace(stage="safety", duration_ms=1.0)
        is_unsafe = "DROP" in sql.upper()
        if not is_unsafe:
            return StageResult(
                ok=True, data={"sql": sql, "rationale": "safe"}, trace=stage_trace
            )
        return StageResult(ok=False, error=["Unsafe SQL"], trace=stage_trace)
# --- 1) Success path ----------------------------------------------------------
def test_pipeline_success():
    """An unambiguous query runs through the stages and yields a SELECT."""
    stages = {
        "detector": DummyDetector(ambiguous=False),
        "planner": DummyPlanner(),
        "generator": DummyGenerator(),
        "safety": DummySafety(),
    }
    pipeline = Pipeline(**stages)

    result = pipeline.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
    )

    assert isinstance(result, FinalResult)
    assert result.ok is True
    assert result.sql is not None
    assert result.sql.lower().startswith("select")

    # Each trace entry is a dict (StageTrace.__dict__); every stage must appear.
    for stage_name in ("planner", "generator", "safety"):
        assert any(entry.get("stage") == stage_name for entry in result.traces)
# --- 2) Ambiguity case --------------------------------------------------------
def test_pipeline_ambiguity():
    """An ambiguous query is flagged and comes back with clarification questions."""
    stages = {
        "detector": DummyDetector(ambiguous=True),
        "planner": DummyPlanner(),
        "generator": DummyGenerator(),
        "safety": DummySafety(),
    }
    pipeline = Pipeline(**stages)

    result = pipeline.run(
        user_query="show data", schema_preview="CREATE TABLE x(id int);"
    )

    assert isinstance(result, FinalResult)
    assert result.ok is True
    assert result.ambiguous is True
    questions = result.questions
    assert isinstance(questions, list)
    assert len(questions) > 0
# --- 3) Planner failure -------------------------------------------------------
def test_pipeline_plan_fail():
    """A planner error makes the run fail and its message appears in details."""
    stages = {
        "detector": DummyDetector(),
        "planner": DummyPlanner(),
        "generator": DummyGenerator(),
        "safety": DummySafety(),
    }
    pipeline = Pipeline(**stages)

    # "fail_plan" is the marker that makes DummyPlanner report an error.
    result = pipeline.run(
        user_query="fail_plan", schema_preview="CREATE TABLE singer(id int);"
    )

    assert isinstance(result, FinalResult)
    assert result.ok is False
    assert result.details is not None
    joined_details = " ".join(result.details)
    assert "Planner failed" in joined_details
# --- 4) Generator failure -----------------------------------------------------
def test_pipeline_gen_fail():
    """A generator error makes the run fail and its message appears in details."""
    stages = {
        "detector": DummyDetector(),
        "planner": DummyPlanner(),
        "generator": DummyGenerator(),
        "safety": DummySafety(),
    }
    pipeline = Pipeline(**stages)

    # "fail_gen" is the marker that makes DummyGenerator report an error.
    result = pipeline.run(
        user_query="fail_gen", schema_preview="CREATE TABLE singer(id int);"
    )

    assert isinstance(result, FinalResult)
    assert result.ok is False
    assert result.details is not None
    joined_details = " ".join(result.details)
    assert "Generator failed" in joined_details
# --- 5) Safety failure --------------------------------------------------------
def test_pipeline_safety_fail():
    """Generated SQL containing DROP makes the run fail with an 'unsafe' detail."""

    class UnsafeGen(DummyGenerator):
        """Generator that always emits a destructive statement."""

        def run(self, **_ignored):
            stage_trace = StageTrace(stage="generator", duration_ms=1.0)
            # DROP TABLE is exactly what DummySafety refuses.
            payload = {"sql": "DROP TABLE x;", "rationale": "oops"}
            return StageResult(ok=True, data=payload, trace=stage_trace)

    stages = {
        "detector": DummyDetector(),
        "planner": DummyPlanner(),
        "generator": UnsafeGen(),
        "safety": DummySafety(),
    }
    pipeline = Pipeline(**stages)

    result = pipeline.run(
        user_query="drop something", schema_preview="CREATE TABLE x(id int);"
    )

    assert isinstance(result, FinalResult)
    assert result.ok is False
    assert result.details is not None
    lowered_details = " ".join(result.details).lower()
    assert "unsafe" in lowered_details