Spaces:
Sleeping
Sleeping
Melika Kheirieh
commited on
Commit
·
1af43ae
1
Parent(s):
052c644
style: format code with ruff
Browse files- tests/test_pipeline_extra.py +10 -0
- tests/test_pipeline_integration_real.py +47 -16
- tests/test_verifier.py +10 -2
tests/test_pipeline_extra.py
CHANGED
|
@@ -1,41 +1,51 @@
|
|
| 1 |
from nl2sql.pipeline import Pipeline
|
| 2 |
from nl2sql.types import StageResult
|
| 3 |
|
|
|
|
| 4 |
class DetectorOK:
|
| 5 |
def detect(self, *a, **k):
|
| 6 |
return []
|
| 7 |
|
|
|
|
| 8 |
class PlannerOK:
|
| 9 |
def run(self, *a, **k):
|
| 10 |
return StageResult(ok=True, data={"plan": "p"})
|
| 11 |
|
|
|
|
| 12 |
class GeneratorOK:
|
| 13 |
def run(self, *a, **k):
|
| 14 |
return StageResult(ok=True, data={"sql": "SELECT * FROM t", "rationale": "ok"})
|
| 15 |
|
|
|
|
| 16 |
class SafetyOK:
|
| 17 |
def run(self, *a, **k):
|
| 18 |
sql = k.get("sql", "SELECT * FROM t")
|
| 19 |
return StageResult(ok=True, data={"sql": sql})
|
| 20 |
|
|
|
|
| 21 |
class ExecOK:
|
| 22 |
def run(self, *a, **k):
|
| 23 |
return StageResult(ok=True, data={"rows": [{"x": 1}]})
|
| 24 |
|
|
|
|
| 25 |
class VerifierThenOK:
|
| 26 |
"""اولین بار fail، بعد از repair pass میکند."""
|
|
|
|
| 27 |
def __init__(self):
|
| 28 |
self.calls = 0
|
|
|
|
| 29 |
def run(self, *, sql, exec_result):
|
| 30 |
self.calls += 1
|
| 31 |
if self.calls == 1:
|
| 32 |
return StageResult(ok=False, error=["first verify fail"])
|
| 33 |
return StageResult(ok=True, data={"verified": True})
|
| 34 |
|
|
|
|
| 35 |
class RepairOK:
|
| 36 |
def run(self, *, sql, error_msg, schema_preview):
|
| 37 |
return StageResult(ok=True, data={"sql": "SELECT * FROM t LIMIT 1"})
|
| 38 |
|
|
|
|
| 39 |
def test_pipeline_repair_success_path():
|
| 40 |
p = Pipeline(
|
| 41 |
detector=DetectorOK(),
|
|
|
|
| 1 |
from nl2sql.pipeline import Pipeline
|
| 2 |
from nl2sql.types import StageResult
|
| 3 |
|
| 4 |
+
|
| 5 |
class DetectorOK:
|
| 6 |
def detect(self, *a, **k):
|
| 7 |
return []
|
| 8 |
|
| 9 |
+
|
| 10 |
class PlannerOK:
|
| 11 |
def run(self, *a, **k):
|
| 12 |
return StageResult(ok=True, data={"plan": "p"})
|
| 13 |
|
| 14 |
+
|
| 15 |
class GeneratorOK:
|
| 16 |
def run(self, *a, **k):
|
| 17 |
return StageResult(ok=True, data={"sql": "SELECT * FROM t", "rationale": "ok"})
|
| 18 |
|
| 19 |
+
|
| 20 |
class SafetyOK:
|
| 21 |
def run(self, *a, **k):
|
| 22 |
sql = k.get("sql", "SELECT * FROM t")
|
| 23 |
return StageResult(ok=True, data={"sql": sql})
|
| 24 |
|
| 25 |
+
|
| 26 |
class ExecOK:
|
| 27 |
def run(self, *a, **k):
|
| 28 |
return StageResult(ok=True, data={"rows": [{"x": 1}]})
|
| 29 |
|
| 30 |
+
|
| 31 |
class VerifierThenOK:
|
| 32 |
"""اولین بار fail، بعد از repair pass میکند."""
|
| 33 |
+
|
| 34 |
def __init__(self):
|
| 35 |
self.calls = 0
|
| 36 |
+
|
| 37 |
def run(self, *, sql, exec_result):
|
| 38 |
self.calls += 1
|
| 39 |
if self.calls == 1:
|
| 40 |
return StageResult(ok=False, error=["first verify fail"])
|
| 41 |
return StageResult(ok=True, data={"verified": True})
|
| 42 |
|
| 43 |
+
|
| 44 |
class RepairOK:
|
| 45 |
def run(self, *, sql, error_msg, schema_preview):
|
| 46 |
return StageResult(ok=True, data={"sql": "SELECT * FROM t LIMIT 1"})
|
| 47 |
|
| 48 |
+
|
| 49 |
def test_pipeline_repair_success_path():
|
| 50 |
p = Pipeline(
|
| 51 |
detector=DetectorOK(),
|
tests/test_pipeline_integration_real.py
CHANGED
|
@@ -2,34 +2,53 @@ import sqlite3
|
|
| 2 |
from nl2sql.pipeline import Pipeline
|
| 3 |
from nl2sql.types import StageResult, StageTrace
|
| 4 |
|
|
|
|
| 5 |
# --- Realistic dummy stages ----------------------------------
|
| 6 |
class DetectorOK:
|
| 7 |
"""Always returns no ambiguities."""
|
|
|
|
| 8 |
def detect(self, *a, **k):
|
| 9 |
return []
|
| 10 |
|
|
|
|
| 11 |
class PlannerLLM:
|
| 12 |
def run(self, *, user_query, schema_preview):
|
| 13 |
plan = f"Understand user query '{user_query}' and map to table."
|
| 14 |
-
return StageResult(
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
class GeneratorSimple:
|
| 18 |
def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
|
| 19 |
sql = "SELECT city, COUNT(*) AS cnt FROM users GROUP BY city"
|
| 20 |
-
return StageResult(
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
class SafetyReadOnly:
|
| 24 |
def run(self, *, sql):
|
| 25 |
if sql.strip().lower().startswith("select"):
|
| 26 |
-
return StageResult(
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
class ExecutorSQLite:
|
| 32 |
"""Executes the SQL query on a temporary in-memory SQLite database."""
|
|
|
|
| 33 |
def __init__(self):
|
| 34 |
# create in-memory DB and seed some rows
|
| 35 |
self.conn = sqlite3.connect(":memory:")
|
|
@@ -38,9 +57,14 @@ class ExecutorSQLite:
|
|
| 38 |
def _seed(self):
|
| 39 |
cur = self.conn.cursor()
|
| 40 |
cur.execute("CREATE TABLE users (id INTEGER, city TEXT)")
|
| 41 |
-
cur.executemany(
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
self.conn.commit()
|
| 45 |
|
| 46 |
def run(self, *, sql):
|
|
@@ -50,7 +74,7 @@ class ExecutorSQLite:
|
|
| 50 |
return StageResult(
|
| 51 |
ok=True,
|
| 52 |
data={"rows": rows},
|
| 53 |
-
trace=StageTrace(stage="executor", duration_ms=0)
|
| 54 |
)
|
| 55 |
|
| 56 |
|
|
@@ -58,13 +82,18 @@ class VerifierCheckCount:
|
|
| 58 |
def run(self, *, sql, exec_result):
|
| 59 |
rows = exec_result.get("rows", [])
|
| 60 |
ok = bool(rows and "city" in rows[0] and "cnt" in rows[0])
|
| 61 |
-
return StageResult(
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
class RepairNoOp:
|
| 67 |
"""Dummy repair stage (not triggered in this scenario)."""
|
|
|
|
| 68 |
def run(self, *a, **k):
|
| 69 |
return StageResult(ok=False, error=["no repair needed"])
|
| 70 |
|
|
@@ -82,7 +111,9 @@ def test_pipeline_end_to_end_real_sqlite():
|
|
| 82 |
repair=RepairNoOp(),
|
| 83 |
)
|
| 84 |
|
| 85 |
-
result = pipe.run(
|
|
|
|
|
|
|
| 86 |
|
| 87 |
# --- Assertions ---
|
| 88 |
assert result.ok
|
|
|
|
| 2 |
from nl2sql.pipeline import Pipeline
|
| 3 |
from nl2sql.types import StageResult, StageTrace
|
| 4 |
|
| 5 |
+
|
| 6 |
# --- Realistic dummy stages ----------------------------------
|
| 7 |
class DetectorOK:
|
| 8 |
"""Always returns no ambiguities."""
|
| 9 |
+
|
| 10 |
def detect(self, *a, **k):
|
| 11 |
return []
|
| 12 |
|
| 13 |
+
|
| 14 |
class PlannerLLM:
|
| 15 |
def run(self, *, user_query, schema_preview):
|
| 16 |
plan = f"Understand user query '{user_query}' and map to table."
|
| 17 |
+
return StageResult(
|
| 18 |
+
ok=True,
|
| 19 |
+
data={"plan": plan},
|
| 20 |
+
trace=StageTrace(stage="planner", duration_ms=0),
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
|
| 24 |
class GeneratorSimple:
|
| 25 |
def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
|
| 26 |
sql = "SELECT city, COUNT(*) AS cnt FROM users GROUP BY city"
|
| 27 |
+
return StageResult(
|
| 28 |
+
ok=True,
|
| 29 |
+
data={"sql": sql, "rationale": plan_text},
|
| 30 |
+
trace=StageTrace(stage="generator", duration_ms=0),
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
|
| 34 |
class SafetyReadOnly:
|
| 35 |
def run(self, *, sql):
|
| 36 |
if sql.strip().lower().startswith("select"):
|
| 37 |
+
return StageResult(
|
| 38 |
+
ok=True,
|
| 39 |
+
data={"sql": sql},
|
| 40 |
+
trace=StageTrace(stage="safety", duration_ms=0),
|
| 41 |
+
)
|
| 42 |
+
return StageResult(
|
| 43 |
+
ok=False,
|
| 44 |
+
error=["Unsafe query"],
|
| 45 |
+
trace=StageTrace(stage="safety", duration_ms=0, notes={"reason": "unsafe"}),
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
|
| 49 |
class ExecutorSQLite:
|
| 50 |
"""Executes the SQL query on a temporary in-memory SQLite database."""
|
| 51 |
+
|
| 52 |
def __init__(self):
|
| 53 |
# create in-memory DB and seed some rows
|
| 54 |
self.conn = sqlite3.connect(":memory:")
|
|
|
|
| 57 |
def _seed(self):
|
| 58 |
cur = self.conn.cursor()
|
| 59 |
cur.execute("CREATE TABLE users (id INTEGER, city TEXT)")
|
| 60 |
+
cur.executemany(
|
| 61 |
+
"INSERT INTO users VALUES (?, ?)",
|
| 62 |
+
[
|
| 63 |
+
(1, "Berlin"),
|
| 64 |
+
(2, "Berlin"),
|
| 65 |
+
(3, "Munich"),
|
| 66 |
+
],
|
| 67 |
+
)
|
| 68 |
self.conn.commit()
|
| 69 |
|
| 70 |
def run(self, *, sql):
|
|
|
|
| 74 |
return StageResult(
|
| 75 |
ok=True,
|
| 76 |
data={"rows": rows},
|
| 77 |
+
trace=StageTrace(stage="executor", duration_ms=0),
|
| 78 |
)
|
| 79 |
|
| 80 |
|
|
|
|
| 82 |
def run(self, *, sql, exec_result):
|
| 83 |
rows = exec_result.get("rows", [])
|
| 84 |
ok = bool(rows and "city" in rows[0] and "cnt" in rows[0])
|
| 85 |
+
return StageResult(
|
| 86 |
+
ok=ok,
|
| 87 |
+
data={"verified": ok},
|
| 88 |
+
trace=StageTrace(
|
| 89 |
+
stage="verifier", duration_ms=0, notes={"rows_len": len(rows)}
|
| 90 |
+
),
|
| 91 |
+
)
|
| 92 |
|
| 93 |
|
| 94 |
class RepairNoOp:
|
| 95 |
"""Dummy repair stage (not triggered in this scenario)."""
|
| 96 |
+
|
| 97 |
def run(self, *a, **k):
|
| 98 |
return StageResult(ok=False, error=["no repair needed"])
|
| 99 |
|
|
|
|
| 111 |
repair=RepairNoOp(),
|
| 112 |
)
|
| 113 |
|
| 114 |
+
result = pipe.run(
|
| 115 |
+
user_query="count users per city", schema_preview="users(id, city)"
|
| 116 |
+
)
|
| 117 |
|
| 118 |
# --- Assertions ---
|
| 119 |
assert result.ok
|
tests/test_verifier.py
CHANGED
|
@@ -1,16 +1,23 @@
|
|
| 1 |
from nl2sql.verifier import Verifier
|
| 2 |
from nl2sql.types import StageResult, StageTrace
|
| 3 |
|
|
|
|
| 4 |
def make_exec_result(ok=True, error=None):
|
| 5 |
-
return StageResult(
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
def test_verifier_handles_execution_error():
|
| 8 |
v = Verifier()
|
| 9 |
-
r = v.run(
|
|
|
|
|
|
|
| 10 |
assert not r.ok
|
| 11 |
assert "execution_error" in r.trace.notes["reason"]
|
| 12 |
assert r.error == ["db error"]
|
| 13 |
|
|
|
|
| 14 |
def test_verifier_detects_agg_without_group():
|
| 15 |
v = Verifier()
|
| 16 |
sql = "SELECT COUNT(*) FROM users"
|
|
@@ -18,6 +25,7 @@ def test_verifier_detects_agg_without_group():
|
|
| 18 |
assert not r.ok
|
| 19 |
assert any("Aggregation without GROUP BY" in e for e in r.error)
|
| 20 |
|
|
|
|
| 21 |
def test_verifier_parses_valid_sql_ok():
|
| 22 |
v = Verifier()
|
| 23 |
sql = "SELECT COUNT(*), city FROM users GROUP BY city"
|
|
|
|
| 1 |
from nl2sql.verifier import Verifier
|
| 2 |
from nl2sql.types import StageResult, StageTrace
|
| 3 |
|
| 4 |
+
|
| 5 |
def make_exec_result(ok=True, error=None):
|
| 6 |
+
return StageResult(
|
| 7 |
+
ok=ok, data={"dummy": True} if ok else None, trace=None, error=error
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
|
| 11 |
def test_verifier_handles_execution_error():
|
| 12 |
v = Verifier()
|
| 13 |
+
r = v.run(
|
| 14 |
+
sql="SELECT 1", exec_result=make_exec_result(ok=False, error=["db error"])
|
| 15 |
+
)
|
| 16 |
assert not r.ok
|
| 17 |
assert "execution_error" in r.trace.notes["reason"]
|
| 18 |
assert r.error == ["db error"]
|
| 19 |
|
| 20 |
+
|
| 21 |
def test_verifier_detects_agg_without_group():
|
| 22 |
v = Verifier()
|
| 23 |
sql = "SELECT COUNT(*) FROM users"
|
|
|
|
| 25 |
assert not r.ok
|
| 26 |
assert any("Aggregation without GROUP BY" in e for e in r.error)
|
| 27 |
|
| 28 |
+
|
| 29 |
def test_verifier_parses_valid_sql_ok():
|
| 30 |
v = Verifier()
|
| 31 |
sql = "SELECT COUNT(*), city FROM users GROUP BY city"
|