Melika Kheirieh committed on
Commit
052c644
·
1 Parent(s): be3a2bc

tests(pipeline,verifier): add unit and integration tests to increase coverage and validate end-to-end flow

Browse files
.coverage CHANGED
Binary files a/.coverage and b/.coverage differ
 
tests/test_pipeline_extra.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nl2sql.pipeline import Pipeline
2
+ from nl2sql.types import StageResult
3
+
4
class DetectorOK:
    """Detector stub: ``detect`` ignores its arguments and returns an empty list."""

    def detect(self, *args, **kwargs):
        """Accept any arguments and return a fresh empty list."""
        findings = []
        return findings
7
+
8
class PlannerOK:
    """Planner stub that always succeeds with the fixed plan ``"p"``."""

    def run(self, *args, **kwargs):
        """Ignore all arguments and return a successful ``StageResult``."""
        payload = {"plan": "p"}
        return StageResult(ok=True, data=payload)
11
+
12
class GeneratorOK:
    """Generator stub that always succeeds with a fixed SQL statement."""

    def run(self, *args, **kwargs):
        """Ignore all arguments and return a canned SQL string plus rationale."""
        payload = {"sql": "SELECT * FROM t", "rationale": "ok"}
        return StageResult(ok=True, data=payload)
15
+
16
class SafetyOK:
    """Safety stub that approves whatever SQL it is handed."""

    def run(self, *args, **kwargs):
        """Echo the ``sql`` keyword argument back, using a default when it is absent."""
        approved_sql = kwargs.get("sql", "SELECT * FROM t")
        return StageResult(ok=True, data={"sql": approved_sql})
20
+
21
class ExecOK:
    """Executor stub that always succeeds with a single fixed row."""

    def run(self, *args, **kwargs):
        """Ignore all arguments and return one row ``{"x": 1}``."""
        fixed_rows = [{"x": 1}]
        return StageResult(ok=True, data={"rows": fixed_rows})
24
+
25
class VerifierThenOK:
    """Fails on the first call, then passes on every later call (i.e. after repair)."""

    def __init__(self):
        # Number of times ``run`` has been invoked so far.
        self.calls = 0

    def run(self, *, sql, exec_result):
        """Return a failure on the first invocation and a verified success afterwards."""
        self.calls += 1
        is_first_call = self.calls == 1
        if not is_first_call:
            return StageResult(ok=True, data={"verified": True})
        return StageResult(ok=False, error=["first verify fail"])
34
+
35
class RepairOK:
    """Repair stub that always succeeds with a fixed replacement query."""

    def run(self, *, sql, error_msg, schema_preview):
        """Ignore the inputs and return a canned repaired SQL statement."""
        repaired = {"sql": "SELECT * FROM t LIMIT 1"}
        return StageResult(ok=True, data=repaired)
38
+
39
def test_pipeline_repair_success_path():
    """Run the pipeline with a verifier that fails once and expect a verified result."""
    stages = {
        "detector": DetectorOK(),
        "planner": PlannerOK(),
        "generator": GeneratorOK(),
        "safety": SafetyOK(),
        "executor": ExecOK(),
        "verifier": VerifierThenOK(),
        "repair": RepairOK(),
    }
    pipeline = Pipeline(**stages)

    outcome = pipeline.run(user_query="?", schema_preview="")

    assert outcome.ok
    assert outcome.verified
    assert not outcome.error
tests/test_pipeline_integration_real.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from nl2sql.pipeline import Pipeline
3
+ from nl2sql.types import StageResult, StageTrace
4
+
5
+ # --- Realistic dummy stages ----------------------------------
6
class DetectorOK:
    """Detector stage that never reports anything: ``detect`` returns an empty list."""

    def detect(self, *args, **kwargs):
        """Accept any arguments and return a fresh empty list."""
        nothing_found = []
        return nothing_found
10
+
11
class PlannerLLM:
    """Planner stage that builds a one-sentence plan from the user query."""

    def run(self, *, user_query, schema_preview):
        """Return a successful result whose plan text embeds *user_query*.

        *schema_preview* is accepted but not used.
        """
        plan_text = f"Understand user query '{user_query}' and map to table."
        planner_trace = StageTrace(stage="planner", duration_ms=0)
        return StageResult(ok=True, data={"plan": plan_text}, trace=planner_trace)
16
+
17
class GeneratorSimple:
    """Generator stage that always emits the same per-city count query."""

    def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
        """Return a fixed GROUP BY query, using *plan_text* as the rationale.

        The other keyword arguments are accepted but not used.
        """
        payload = {
            "sql": "SELECT city, COUNT(*) AS cnt FROM users GROUP BY city",
            "rationale": plan_text,
        }
        generator_trace = StageTrace(stage="generator", duration_ms=0)
        return StageResult(ok=True, data=payload, trace=generator_trace)
22
+
23
class SafetyReadOnly:
    """Safety stage that only lets through SQL whose text begins with ``select``."""

    def run(self, *, sql):
        """Approve *sql* when, stripped and lower-cased, it starts with ``select``.

        Anything else yields a failed result carrying an "Unsafe query" error.
        """
        normalized = sql.strip().lower()
        if not normalized.startswith("select"):
            rejected_trace = StageTrace(
                stage="safety", duration_ms=0, notes={"reason": "unsafe"}
            )
            return StageResult(ok=False, error=["Unsafe query"], trace=rejected_trace)

        accepted_trace = StageTrace(stage="safety", duration_ms=0)
        return StageResult(ok=True, data={"sql": sql}, trace=accepted_trace)
30
+
31
class ExecutorSQLite:
    """Executes the SQL query on a temporary in-memory SQLite database.

    The database is created and seeded on construction with a
    ``users(id, city)`` table holding three rows (two "Berlin", one "Munich").
    Call :meth:`close` to release the connection when done.
    """

    def __init__(self):
        # Create the in-memory DB and seed some rows.
        self.conn = sqlite3.connect(":memory:")
        self._seed()

    def _seed(self):
        """Create the ``users`` table and insert the fixture rows."""
        cur = self.conn.cursor()
        try:
            cur.execute("CREATE TABLE users (id INTEGER, city TEXT)")
            cur.executemany(
                "INSERT INTO users VALUES (?, ?)",
                [(1, "Berlin"), (2, "Berlin"), (3, "Munich")],
            )
        finally:
            # Always release the cursor, even if seeding fails.
            cur.close()
        self.conn.commit()

    def _fetch_rows(self, sql):
        """Execute *sql* and return its result set as a list of column->value dicts.

        Returns an empty list for statements that produce no result set.
        Errors raised by ``sqlite3`` propagate to the caller.
        """
        cur = self.conn.cursor()
        try:
            cur.execute(sql)
            # ``description`` is None when the statement returns no rows at all
            # (e.g. DDL); iterating it would raise TypeError.
            if cur.description is None:
                return []
            # Build the column-name list once instead of once per row.
            columns = [col[0] for col in cur.description]
            return [dict(zip(columns, row)) for row in cur.fetchall()]
        finally:
            cur.close()

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()

    def run(self, *, sql):
        """Execute *sql* and wrap the fetched rows in a successful ``StageResult``.

        The result's ``data`` is ``{"rows": [...]}`` with one dict per row,
        and its trace is tagged with stage ``"executor"``.
        """
        return StageResult(
            ok=True,
            data={"rows": self._fetch_rows(sql)},
            trace=StageTrace(stage="executor", duration_ms=0),
        )
55
+
56
+
57
class VerifierCheckCount:
    """Verifier stage that checks the first result row has ``city`` and ``cnt`` keys."""

    def run(self, *, sql, exec_result):
        """Pass when ``exec_result["rows"]`` is non-empty and its first row has both keys.

        The number of rows is recorded in the trace notes under ``rows_len``.
        *sql* is accepted but not inspected.
        """
        result_rows = exec_result.get("rows", [])
        if result_rows:
            first_row = result_rows[0]
            passed = "city" in first_row and "cnt" in first_row
        else:
            passed = False

        verifier_trace = StageTrace(
            stage="verifier",
            duration_ms=0,
            notes={"rows_len": len(result_rows)},
        )
        return StageResult(ok=passed, data={"verified": passed}, trace=verifier_trace)
64
+
65
+
66
class RepairNoOp:
    """Placeholder repair stage; the scenario in this file is not expected to invoke it."""

    def run(self, *args, **kwargs):
        """Ignore all arguments and return a failed result saying no repair is needed."""
        reasons = ["no repair needed"]
        return StageResult(ok=False, error=reasons)
70
+
71
+
72
+ # --- Integration test ----------------------------------------
73
def test_pipeline_end_to_end_real_sqlite():
    """Drive the full NL2SQL pipeline against a real in-memory SQLite database."""
    stages = {
        "detector": DetectorOK(),
        "planner": PlannerLLM(),
        "generator": GeneratorSimple(),
        "safety": SafetyReadOnly(),
        "executor": ExecutorSQLite(),
        "verifier": VerifierCheckCount(),
        "repair": RepairNoOp(),
    }
    pipeline = Pipeline(**stages)

    outcome = pipeline.run(
        user_query="count users per city",
        schema_preview="users(id, city)",
    )

    # Overall status and generated SQL.
    assert outcome.ok
    assert outcome.verified
    assert not outcome.error
    assert "SELECT" in outcome.sql

    # The pipeline must have recorded a non-empty list of traces.
    assert isinstance(outcome.traces, list)
    assert outcome.traces

    # The generated query should mention the grouping column.
    assert "city" in outcome.sql.lower()
tests/test_verifier.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nl2sql.verifier import Verifier
2
+ from nl2sql.types import StageResult, StageTrace
3
+
4
def make_exec_result(ok=True, error=None):
    """Build a ``StageResult`` standing in for an executor outcome.

    A truthy *ok* attaches a dummy payload; otherwise ``data`` is ``None``.
    *error* is passed through unchanged and no trace is attached.
    """
    if ok:
        payload = {"dummy": True}
    else:
        payload = None
    return StageResult(ok=ok, data=payload, trace=None, error=error)
6
+
7
def test_verifier_handles_execution_error():
    """A failed execution result must surface as a verifier failure with the same error."""
    failed_exec = make_exec_result(ok=False, error=["db error"])
    verifier = Verifier()

    outcome = verifier.run(sql="SELECT 1", exec_result=failed_exec)

    assert not outcome.ok
    assert "execution_error" in outcome.trace.notes["reason"]
    assert outcome.error == ["db error"]
13
+
14
def test_verifier_detects_agg_without_group():
    """An aggregate query with no GROUP BY must be rejected with a matching error."""
    query = "SELECT COUNT(*) FROM users"
    verifier = Verifier()

    outcome = verifier.run(sql=query, exec_result=make_exec_result(ok=True))

    assert not outcome.ok
    assert any("Aggregation without GROUP BY" in message for message in outcome.error)
20
+
21
def test_verifier_parses_valid_sql_ok():
    """A grouped aggregate query must verify successfully and carry a trace."""
    query = "SELECT COUNT(*), city FROM users GROUP BY city"
    verifier = Verifier()

    outcome = verifier.run(sql=query, exec_result=make_exec_result(ok=True))

    assert outcome.ok
    assert outcome.data == {"verified": True}
    assert isinstance(outcome.trace, StageTrace)