Spaces:
Sleeping
Sleeping
Melika Kheirieh
commited on
Commit
·
052c644
1
Parent(s):
be3a2bc
tests(pipeline,verifier): add unit and integration tests to increase coverage and validate end-to-end flow
Browse files- .coverage +0 -0
- tests/test_pipeline_extra.py +52 -0
- tests/test_pipeline_integration_real.py +98 -0
- tests/test_verifier.py +27 -0
.coverage
CHANGED
|
Binary files a/.coverage and b/.coverage differ
|
|
|
tests/test_pipeline_extra.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from nl2sql.pipeline import Pipeline
|
| 2 |
+
from nl2sql.types import StageResult
|
| 3 |
+
|
| 4 |
+
class DetectorOK:
|
| 5 |
+
def detect(self, *a, **k):
|
| 6 |
+
return []
|
| 7 |
+
|
| 8 |
+
class PlannerOK:
|
| 9 |
+
def run(self, *a, **k):
|
| 10 |
+
return StageResult(ok=True, data={"plan": "p"})
|
| 11 |
+
|
| 12 |
+
class GeneratorOK:
|
| 13 |
+
def run(self, *a, **k):
|
| 14 |
+
return StageResult(ok=True, data={"sql": "SELECT * FROM t", "rationale": "ok"})
|
| 15 |
+
|
| 16 |
+
class SafetyOK:
|
| 17 |
+
def run(self, *a, **k):
|
| 18 |
+
sql = k.get("sql", "SELECT * FROM t")
|
| 19 |
+
return StageResult(ok=True, data={"sql": sql})
|
| 20 |
+
|
| 21 |
+
class ExecOK:
|
| 22 |
+
def run(self, *a, **k):
|
| 23 |
+
return StageResult(ok=True, data={"rows": [{"x": 1}]})
|
| 24 |
+
|
| 25 |
+
class VerifierThenOK:
|
| 26 |
+
"""اولین بار fail، بعد از repair pass میکند."""
|
| 27 |
+
def __init__(self):
|
| 28 |
+
self.calls = 0
|
| 29 |
+
def run(self, *, sql, exec_result):
|
| 30 |
+
self.calls += 1
|
| 31 |
+
if self.calls == 1:
|
| 32 |
+
return StageResult(ok=False, error=["first verify fail"])
|
| 33 |
+
return StageResult(ok=True, data={"verified": True})
|
| 34 |
+
|
| 35 |
+
class RepairOK:
|
| 36 |
+
def run(self, *, sql, error_msg, schema_preview):
|
| 37 |
+
return StageResult(ok=True, data={"sql": "SELECT * FROM t LIMIT 1"})
|
| 38 |
+
|
| 39 |
+
def test_pipeline_repair_success_path():
|
| 40 |
+
p = Pipeline(
|
| 41 |
+
detector=DetectorOK(),
|
| 42 |
+
planner=PlannerOK(),
|
| 43 |
+
generator=GeneratorOK(),
|
| 44 |
+
safety=SafetyOK(),
|
| 45 |
+
executor=ExecOK(),
|
| 46 |
+
verifier=VerifierThenOK(),
|
| 47 |
+
repair=RepairOK(),
|
| 48 |
+
)
|
| 49 |
+
out = p.run(user_query="?", schema_preview="")
|
| 50 |
+
assert out.ok
|
| 51 |
+
assert out.verified
|
| 52 |
+
assert not out.error
|
tests/test_pipeline_integration_real.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
from nl2sql.pipeline import Pipeline
|
| 3 |
+
from nl2sql.types import StageResult, StageTrace
|
| 4 |
+
|
| 5 |
+
# --- Realistic dummy stages ----------------------------------
|
| 6 |
+
class DetectorOK:
|
| 7 |
+
"""Always returns no ambiguities."""
|
| 8 |
+
def detect(self, *a, **k):
|
| 9 |
+
return []
|
| 10 |
+
|
| 11 |
+
class PlannerLLM:
|
| 12 |
+
def run(self, *, user_query, schema_preview):
|
| 13 |
+
plan = f"Understand user query '{user_query}' and map to table."
|
| 14 |
+
return StageResult(ok=True, data={"plan": plan},
|
| 15 |
+
trace=StageTrace(stage="planner", duration_ms=0))
|
| 16 |
+
|
| 17 |
+
class GeneratorSimple:
|
| 18 |
+
def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
|
| 19 |
+
sql = "SELECT city, COUNT(*) AS cnt FROM users GROUP BY city"
|
| 20 |
+
return StageResult(ok=True, data={"sql": sql, "rationale": plan_text},
|
| 21 |
+
trace=StageTrace(stage="generator", duration_ms=0))
|
| 22 |
+
|
| 23 |
+
class SafetyReadOnly:
|
| 24 |
+
def run(self, *, sql):
|
| 25 |
+
if sql.strip().lower().startswith("select"):
|
| 26 |
+
return StageResult(ok=True, data={"sql": sql},
|
| 27 |
+
trace=StageTrace(stage="safety", duration_ms=0))
|
| 28 |
+
return StageResult(ok=False, error=["Unsafe query"],
|
| 29 |
+
trace=StageTrace(stage="safety", duration_ms=0, notes={"reason": "unsafe"}))
|
| 30 |
+
|
| 31 |
+
class ExecutorSQLite:
|
| 32 |
+
"""Executes the SQL query on a temporary in-memory SQLite database."""
|
| 33 |
+
def __init__(self):
|
| 34 |
+
# create in-memory DB and seed some rows
|
| 35 |
+
self.conn = sqlite3.connect(":memory:")
|
| 36 |
+
self._seed()
|
| 37 |
+
|
| 38 |
+
def _seed(self):
|
| 39 |
+
cur = self.conn.cursor()
|
| 40 |
+
cur.execute("CREATE TABLE users (id INTEGER, city TEXT)")
|
| 41 |
+
cur.executemany("INSERT INTO users VALUES (?, ?)", [
|
| 42 |
+
(1, "Berlin"), (2, "Berlin"), (3, "Munich"),
|
| 43 |
+
])
|
| 44 |
+
self.conn.commit()
|
| 45 |
+
|
| 46 |
+
def run(self, *, sql):
|
| 47 |
+
cur = self.conn.cursor()
|
| 48 |
+
cur.execute(sql)
|
| 49 |
+
rows = [dict(zip([d[0] for d in cur.description], r)) for r in cur.fetchall()]
|
| 50 |
+
return StageResult(
|
| 51 |
+
ok=True,
|
| 52 |
+
data={"rows": rows},
|
| 53 |
+
trace=StageTrace(stage="executor", duration_ms=0)
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class VerifierCheckCount:
|
| 58 |
+
def run(self, *, sql, exec_result):
|
| 59 |
+
rows = exec_result.get("rows", [])
|
| 60 |
+
ok = bool(rows and "city" in rows[0] and "cnt" in rows[0])
|
| 61 |
+
return StageResult(ok=ok, data={"verified": ok},
|
| 62 |
+
trace=StageTrace(stage="verifier", duration_ms=0,
|
| 63 |
+
notes={"rows_len": len(rows)}))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class RepairNoOp:
|
| 67 |
+
"""Dummy repair stage (not triggered in this scenario)."""
|
| 68 |
+
def run(self, *a, **k):
|
| 69 |
+
return StageResult(ok=False, error=["no repair needed"])
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# --- Integration test ----------------------------------------
|
| 73 |
+
def test_pipeline_end_to_end_real_sqlite():
|
| 74 |
+
"""Full NL2SQL pipeline test on real SQLite DB with no mocks."""
|
| 75 |
+
pipe = Pipeline(
|
| 76 |
+
detector=DetectorOK(),
|
| 77 |
+
planner=PlannerLLM(),
|
| 78 |
+
generator=GeneratorSimple(),
|
| 79 |
+
safety=SafetyReadOnly(),
|
| 80 |
+
executor=ExecutorSQLite(),
|
| 81 |
+
verifier=VerifierCheckCount(),
|
| 82 |
+
repair=RepairNoOp(),
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
result = pipe.run(user_query="count users per city", schema_preview="users(id, city)")
|
| 86 |
+
|
| 87 |
+
# --- Assertions ---
|
| 88 |
+
assert result.ok
|
| 89 |
+
assert result.verified
|
| 90 |
+
assert not result.error
|
| 91 |
+
assert "SELECT" in result.sql
|
| 92 |
+
|
| 93 |
+
# Ensure pipeline produced valid SQL and traces
|
| 94 |
+
assert isinstance(result.traces, list)
|
| 95 |
+
assert result.traces # not empty
|
| 96 |
+
|
| 97 |
+
# Logical validation
|
| 98 |
+
assert "city" in result.sql.lower()
|
tests/test_verifier.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from nl2sql.verifier import Verifier
|
| 2 |
+
from nl2sql.types import StageResult, StageTrace
|
| 3 |
+
|
| 4 |
+
def make_exec_result(ok=True, error=None):
|
| 5 |
+
return StageResult(ok=ok, data={"dummy": True} if ok else None, trace=None, error=error)
|
| 6 |
+
|
| 7 |
+
def test_verifier_handles_execution_error():
|
| 8 |
+
v = Verifier()
|
| 9 |
+
r = v.run(sql="SELECT 1", exec_result=make_exec_result(ok=False, error=["db error"]))
|
| 10 |
+
assert not r.ok
|
| 11 |
+
assert "execution_error" in r.trace.notes["reason"]
|
| 12 |
+
assert r.error == ["db error"]
|
| 13 |
+
|
| 14 |
+
def test_verifier_detects_agg_without_group():
|
| 15 |
+
v = Verifier()
|
| 16 |
+
sql = "SELECT COUNT(*) FROM users"
|
| 17 |
+
r = v.run(sql=sql, exec_result=make_exec_result(ok=True))
|
| 18 |
+
assert not r.ok
|
| 19 |
+
assert any("Aggregation without GROUP BY" in e for e in r.error)
|
| 20 |
+
|
| 21 |
+
def test_verifier_parses_valid_sql_ok():
|
| 22 |
+
v = Verifier()
|
| 23 |
+
sql = "SELECT COUNT(*), city FROM users GROUP BY city"
|
| 24 |
+
r = v.run(sql=sql, exec_result=make_exec_result(ok=True))
|
| 25 |
+
assert r.ok
|
| 26 |
+
assert r.data == {"verified": True}
|
| 27 |
+
assert isinstance(r.trace, StageTrace)
|