Spaces:
Sleeping
Sleeping
| from nl2sql.pipeline import Pipeline, FinalResult | |
| from nl2sql.types import StageResult, StageTrace | |
| # --- Dummy stages to isolate pipeline ----------------------------------------- | |
class DummyDetector:
    """Stand-in for the ambiguity-detection stage.

    The inputs are ignored entirely. Whether a clarification question is
    produced is fixed up front by the ``ambiguous`` constructor flag.
    """

    def __init__(self, ambiguous: bool = False):
        self.ambiguous = ambiguous

    def detect(self, user_query, schema_preview):
        """Return clarification questions; an empty list means "not ambiguous"."""
        if not self.ambiguous:
            return []
        return ["Which column?"]
class DummyPlanner:
    """Stand-in for the planner stage.

    It succeeds with a canned plan unless the query contains the marker
    ``"fail_plan"``, in which case it reports an error instead.
    """

    def run(self, *, user_query, schema_preview):
        """Return a StageResult holding either ``{"plan": ...}`` or an error."""
        stage_trace = StageTrace(stage="planner", duration_ms=1.0)
        should_fail = "fail_plan" in user_query
        if should_fail:
            return StageResult(ok=False, error=["Planner failed"], trace=stage_trace)
        return StageResult(ok=True, data={"plan": "plan text"}, trace=stage_trace)
class DummyGenerator:
    """Stand-in for the SQL generator stage.

    It always emits the same SELECT over ``singer`` unless the query contains
    ``"fail_gen"``. The schema, plan and clarification answers are accepted
    so the call signature matches, but they are not used.
    """

    def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
        """Return a StageResult with ``sql``/``rationale`` data, or an error."""
        stage_trace = StageTrace(stage="generator", duration_ms=1.0)
        if "fail_gen" not in user_query:
            payload = {
                "sql": "SELECT * FROM singer;",
                "rationale": "List all singers.",
            }
            return StageResult(ok=True, data=payload, trace=stage_trace)
        return StageResult(ok=False, error=["Generator failed"], trace=stage_trace)
class DummySafety:
    """Stand-in for the safety-check stage.

    The pipeline invokes it as ``safety.run(sql=...)``. Any statement that
    contains ``DROP`` (case-insensitive) is rejected. Everything else is
    echoed back unchanged in ``data["sql"]`` so later stages can consume it.
    """

    def run(self, *, sql):
        """Reject SQL containing DROP; otherwise pass it through as safe."""
        stage_trace = StageTrace(stage="safety", duration_ms=1.0)
        is_unsafe = "DROP" in sql.upper()
        if is_unsafe:
            return StageResult(ok=False, error=["Unsafe SQL"], trace=stage_trace)
        return StageResult(
            ok=True,
            data={"sql": sql, "rationale": "safe"},
            trace=stage_trace,
        )
| # --- 1) Success path ---------------------------------------------------------- | |
def test_pipeline_success():
    """A clean, unambiguous query passes every stage and yields a SELECT."""
    stages = dict(
        detector=DummyDetector(ambiguous=False),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )
    pipeline = Pipeline(**stages)

    result = pipeline.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
    )

    assert isinstance(result, FinalResult)
    assert result.ok is True
    assert result.sql is not None
    assert result.sql.lower().startswith("select")

    # Each trace is a plain dict (StageTrace.__dict__) with a "stage" key.
    stage_names = [trace.get("stage") for trace in result.traces]
    for expected in ("planner", "generator", "safety"):
        assert expected in stage_names
| # --- 2) Ambiguity case -------------------------------------------------------- | |
def test_pipeline_ambiguity():
    """An ambiguous query comes back flagged, with clarification questions."""
    pipeline = Pipeline(
        detector=DummyDetector(ambiguous=True),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
    )

    outcome = pipeline.run(
        user_query="show data",
        schema_preview="CREATE TABLE x(id int);",
    )

    assert isinstance(outcome, FinalResult)
    # The expected contract is that ambiguity is not an error: ok stays True.
    assert outcome.ok is True
    assert outcome.ambiguous is True
    questions = outcome.questions
    assert isinstance(questions, list)
    assert questions  # at least one clarification question must be surfaced
| # --- 3) Planner failure ------------------------------------------------------- | |
def test_pipeline_plan_fail():
    """A planner error yields ok=False, with the planner's message in details."""
    detector, planner = DummyDetector(), DummyPlanner()
    generator, safety = DummyGenerator(), DummySafety()
    pipeline = Pipeline(
        detector=detector, planner=planner, generator=generator, safety=safety
    )

    final = pipeline.run(
        user_query="fail_plan",  # marker that makes DummyPlanner return an error
        schema_preview="CREATE TABLE singer(id int);",
    )

    assert isinstance(final, FinalResult)
    assert final.ok is False
    assert final.details is not None
    combined = " ".join(final.details)
    assert "Planner failed" in combined
| # --- 4) Generator failure ----------------------------------------------------- | |
def test_pipeline_gen_fail():
    """A generator error yields ok=False, with its message in details."""
    detector, planner = DummyDetector(), DummyPlanner()
    generator, safety = DummyGenerator(), DummySafety()
    pipeline = Pipeline(
        detector=detector, planner=planner, generator=generator, safety=safety
    )

    final = pipeline.run(
        user_query="fail_gen",  # marker that makes DummyGenerator return an error
        schema_preview="CREATE TABLE singer(id int);",
    )

    assert isinstance(final, FinalResult)
    assert final.ok is False
    assert final.details is not None
    combined = " ".join(final.details)
    assert "Generator failed" in combined
| # --- 5) Safety failure -------------------------------------------------------- | |
def test_pipeline_safety_fail():
    """Generated SQL that trips the safety check makes the whole run fail."""

    class UnsafeGen(DummyGenerator):
        """Generator that always 'succeeds' with a destructive DROP statement."""

        def run(self, **kw):
            generator_trace = StageTrace(stage="generator", duration_ms=1.0)
            drop_payload = {"sql": "DROP TABLE x;", "rationale": "oops"}
            return StageResult(ok=True, data=drop_payload, trace=generator_trace)

    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=UnsafeGen(),
        safety=DummySafety(),
    )

    result = pipeline.run(
        user_query="drop something",
        schema_preview="CREATE TABLE x(id int);",
    )

    assert isinstance(result, FinalResult)
    assert result.ok is False
    assert result.details is not None
    details_text = " ".join(result.details).lower()
    assert "unsafe" in details_text