Melika Kheirieh committed on
Commit
1af43ae
·
1 Parent(s): 052c644

style: format code with ruff

Browse files
tests/test_pipeline_extra.py CHANGED
@@ -1,41 +1,51 @@
1
  from nl2sql.pipeline import Pipeline
2
  from nl2sql.types import StageResult
3
 
 
4
  class DetectorOK:
5
  def detect(self, *a, **k):
6
  return []
7
 
 
8
  class PlannerOK:
9
  def run(self, *a, **k):
10
  return StageResult(ok=True, data={"plan": "p"})
11
 
 
12
  class GeneratorOK:
13
  def run(self, *a, **k):
14
  return StageResult(ok=True, data={"sql": "SELECT * FROM t", "rationale": "ok"})
15
 
 
16
  class SafetyOK:
17
  def run(self, *a, **k):
18
  sql = k.get("sql", "SELECT * FROM t")
19
  return StageResult(ok=True, data={"sql": sql})
20
 
 
21
  class ExecOK:
22
  def run(self, *a, **k):
23
  return StageResult(ok=True, data={"rows": [{"x": 1}]})
24
 
 
25
  class VerifierThenOK:
26
  """اولین بار fail، بعد از repair pass می‌کند."""
 
27
  def __init__(self):
28
  self.calls = 0
 
29
  def run(self, *, sql, exec_result):
30
  self.calls += 1
31
  if self.calls == 1:
32
  return StageResult(ok=False, error=["first verify fail"])
33
  return StageResult(ok=True, data={"verified": True})
34
 
 
35
  class RepairOK:
36
  def run(self, *, sql, error_msg, schema_preview):
37
  return StageResult(ok=True, data={"sql": "SELECT * FROM t LIMIT 1"})
38
 
 
39
  def test_pipeline_repair_success_path():
40
  p = Pipeline(
41
  detector=DetectorOK(),
 
1
  from nl2sql.pipeline import Pipeline
2
  from nl2sql.types import StageResult
3
 
4
+
5
  class DetectorOK:
6
  def detect(self, *a, **k):
7
  return []
8
 
9
+
10
  class PlannerOK:
11
  def run(self, *a, **k):
12
  return StageResult(ok=True, data={"plan": "p"})
13
 
14
+
15
  class GeneratorOK:
16
  def run(self, *a, **k):
17
  return StageResult(ok=True, data={"sql": "SELECT * FROM t", "rationale": "ok"})
18
 
19
+
20
  class SafetyOK:
21
  def run(self, *a, **k):
22
  sql = k.get("sql", "SELECT * FROM t")
23
  return StageResult(ok=True, data={"sql": sql})
24
 
25
+
26
  class ExecOK:
27
  def run(self, *a, **k):
28
  return StageResult(ok=True, data={"rows": [{"x": 1}]})
29
 
30
+
31
  class VerifierThenOK:
32
  """اولین بار fail، بعد از repair pass می‌کند."""
33
+
34
  def __init__(self):
35
  self.calls = 0
36
+
37
  def run(self, *, sql, exec_result):
38
  self.calls += 1
39
  if self.calls == 1:
40
  return StageResult(ok=False, error=["first verify fail"])
41
  return StageResult(ok=True, data={"verified": True})
42
 
43
+
44
  class RepairOK:
45
  def run(self, *, sql, error_msg, schema_preview):
46
  return StageResult(ok=True, data={"sql": "SELECT * FROM t LIMIT 1"})
47
 
48
+
49
  def test_pipeline_repair_success_path():
50
  p = Pipeline(
51
  detector=DetectorOK(),
tests/test_pipeline_integration_real.py CHANGED
@@ -2,34 +2,53 @@ import sqlite3
2
  from nl2sql.pipeline import Pipeline
3
  from nl2sql.types import StageResult, StageTrace
4
 
 
5
  # --- Realistic dummy stages ----------------------------------
6
  class DetectorOK:
7
  """Always returns no ambiguities."""
 
8
  def detect(self, *a, **k):
9
  return []
10
 
 
11
  class PlannerLLM:
12
  def run(self, *, user_query, schema_preview):
13
  plan = f"Understand user query '{user_query}' and map to table."
14
- return StageResult(ok=True, data={"plan": plan},
15
- trace=StageTrace(stage="planner", duration_ms=0))
 
 
 
 
16
 
17
  class GeneratorSimple:
18
  def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
19
  sql = "SELECT city, COUNT(*) AS cnt FROM users GROUP BY city"
20
- return StageResult(ok=True, data={"sql": sql, "rationale": plan_text},
21
- trace=StageTrace(stage="generator", duration_ms=0))
 
 
 
 
22
 
23
  class SafetyReadOnly:
24
  def run(self, *, sql):
25
  if sql.strip().lower().startswith("select"):
26
- return StageResult(ok=True, data={"sql": sql},
27
- trace=StageTrace(stage="safety", duration_ms=0))
28
- return StageResult(ok=False, error=["Unsafe query"],
29
- trace=StageTrace(stage="safety", duration_ms=0, notes={"reason": "unsafe"}))
 
 
 
 
 
 
 
30
 
31
  class ExecutorSQLite:
32
  """Executes the SQL query on a temporary in-memory SQLite database."""
 
33
  def __init__(self):
34
  # create in-memory DB and seed some rows
35
  self.conn = sqlite3.connect(":memory:")
@@ -38,9 +57,14 @@ class ExecutorSQLite:
38
  def _seed(self):
39
  cur = self.conn.cursor()
40
  cur.execute("CREATE TABLE users (id INTEGER, city TEXT)")
41
- cur.executemany("INSERT INTO users VALUES (?, ?)", [
42
- (1, "Berlin"), (2, "Berlin"), (3, "Munich"),
43
- ])
 
 
 
 
 
44
  self.conn.commit()
45
 
46
  def run(self, *, sql):
@@ -50,7 +74,7 @@ class ExecutorSQLite:
50
  return StageResult(
51
  ok=True,
52
  data={"rows": rows},
53
- trace=StageTrace(stage="executor", duration_ms=0)
54
  )
55
 
56
 
@@ -58,13 +82,18 @@ class VerifierCheckCount:
58
  def run(self, *, sql, exec_result):
59
  rows = exec_result.get("rows", [])
60
  ok = bool(rows and "city" in rows[0] and "cnt" in rows[0])
61
- return StageResult(ok=ok, data={"verified": ok},
62
- trace=StageTrace(stage="verifier", duration_ms=0,
63
- notes={"rows_len": len(rows)}))
 
 
 
 
64
 
65
 
66
  class RepairNoOp:
67
  """Dummy repair stage (not triggered in this scenario)."""
 
68
  def run(self, *a, **k):
69
  return StageResult(ok=False, error=["no repair needed"])
70
 
@@ -82,7 +111,9 @@ def test_pipeline_end_to_end_real_sqlite():
82
  repair=RepairNoOp(),
83
  )
84
 
85
- result = pipe.run(user_query="count users per city", schema_preview="users(id, city)")
 
 
86
 
87
  # --- Assertions ---
88
  assert result.ok
 
2
  from nl2sql.pipeline import Pipeline
3
  from nl2sql.types import StageResult, StageTrace
4
 
5
+
6
  # --- Realistic dummy stages ----------------------------------
7
  class DetectorOK:
8
  """Always returns no ambiguities."""
9
+
10
  def detect(self, *a, **k):
11
  return []
12
 
13
+
14
  class PlannerLLM:
15
  def run(self, *, user_query, schema_preview):
16
  plan = f"Understand user query '{user_query}' and map to table."
17
+ return StageResult(
18
+ ok=True,
19
+ data={"plan": plan},
20
+ trace=StageTrace(stage="planner", duration_ms=0),
21
+ )
22
+
23
 
24
  class GeneratorSimple:
25
  def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
26
  sql = "SELECT city, COUNT(*) AS cnt FROM users GROUP BY city"
27
+ return StageResult(
28
+ ok=True,
29
+ data={"sql": sql, "rationale": plan_text},
30
+ trace=StageTrace(stage="generator", duration_ms=0),
31
+ )
32
+
33
 
34
  class SafetyReadOnly:
35
  def run(self, *, sql):
36
  if sql.strip().lower().startswith("select"):
37
+ return StageResult(
38
+ ok=True,
39
+ data={"sql": sql},
40
+ trace=StageTrace(stage="safety", duration_ms=0),
41
+ )
42
+ return StageResult(
43
+ ok=False,
44
+ error=["Unsafe query"],
45
+ trace=StageTrace(stage="safety", duration_ms=0, notes={"reason": "unsafe"}),
46
+ )
47
+
48
 
49
  class ExecutorSQLite:
50
  """Executes the SQL query on a temporary in-memory SQLite database."""
51
+
52
  def __init__(self):
53
  # create in-memory DB and seed some rows
54
  self.conn = sqlite3.connect(":memory:")
 
57
  def _seed(self):
58
  cur = self.conn.cursor()
59
  cur.execute("CREATE TABLE users (id INTEGER, city TEXT)")
60
+ cur.executemany(
61
+ "INSERT INTO users VALUES (?, ?)",
62
+ [
63
+ (1, "Berlin"),
64
+ (2, "Berlin"),
65
+ (3, "Munich"),
66
+ ],
67
+ )
68
  self.conn.commit()
69
 
70
  def run(self, *, sql):
 
74
  return StageResult(
75
  ok=True,
76
  data={"rows": rows},
77
+ trace=StageTrace(stage="executor", duration_ms=0),
78
  )
79
 
80
 
 
82
  def run(self, *, sql, exec_result):
83
  rows = exec_result.get("rows", [])
84
  ok = bool(rows and "city" in rows[0] and "cnt" in rows[0])
85
+ return StageResult(
86
+ ok=ok,
87
+ data={"verified": ok},
88
+ trace=StageTrace(
89
+ stage="verifier", duration_ms=0, notes={"rows_len": len(rows)}
90
+ ),
91
+ )
92
 
93
 
94
  class RepairNoOp:
95
  """Dummy repair stage (not triggered in this scenario)."""
96
+
97
  def run(self, *a, **k):
98
  return StageResult(ok=False, error=["no repair needed"])
99
 
 
111
  repair=RepairNoOp(),
112
  )
113
 
114
+ result = pipe.run(
115
+ user_query="count users per city", schema_preview="users(id, city)"
116
+ )
117
 
118
  # --- Assertions ---
119
  assert result.ok
tests/test_verifier.py CHANGED
@@ -1,16 +1,23 @@
1
  from nl2sql.verifier import Verifier
2
  from nl2sql.types import StageResult, StageTrace
3
 
 
4
  def make_exec_result(ok=True, error=None):
5
- return StageResult(ok=ok, data={"dummy": True} if ok else None, trace=None, error=error)
 
 
 
6
 
7
  def test_verifier_handles_execution_error():
8
  v = Verifier()
9
- r = v.run(sql="SELECT 1", exec_result=make_exec_result(ok=False, error=["db error"]))
 
 
10
  assert not r.ok
11
  assert "execution_error" in r.trace.notes["reason"]
12
  assert r.error == ["db error"]
13
 
 
14
  def test_verifier_detects_agg_without_group():
15
  v = Verifier()
16
  sql = "SELECT COUNT(*) FROM users"
@@ -18,6 +25,7 @@ def test_verifier_detects_agg_without_group():
18
  assert not r.ok
19
  assert any("Aggregation without GROUP BY" in e for e in r.error)
20
 
 
21
  def test_verifier_parses_valid_sql_ok():
22
  v = Verifier()
23
  sql = "SELECT COUNT(*), city FROM users GROUP BY city"
 
1
  from nl2sql.verifier import Verifier
2
  from nl2sql.types import StageResult, StageTrace
3
 
4
+
5
  def make_exec_result(ok=True, error=None):
6
+ return StageResult(
7
+ ok=ok, data={"dummy": True} if ok else None, trace=None, error=error
8
+ )
9
+
10
 
11
  def test_verifier_handles_execution_error():
12
  v = Verifier()
13
+ r = v.run(
14
+ sql="SELECT 1", exec_result=make_exec_result(ok=False, error=["db error"])
15
+ )
16
  assert not r.ok
17
  assert "execution_error" in r.trace.notes["reason"]
18
  assert r.error == ["db error"]
19
 
20
+
21
  def test_verifier_detects_agg_without_group():
22
  v = Verifier()
23
  sql = "SELECT COUNT(*) FROM users"
 
25
  assert not r.ok
26
  assert any("Aggregation without GROUP BY" in e for e in r.error)
27
 
28
+
29
  def test_verifier_parses_valid_sql_ok():
30
  v = Verifier()
31
  sql = "SELECT COUNT(*), city FROM users GROUP BY city"