Melika Kheirieh committed on
Commit
6a94b42
·
1 Parent(s): 977a885

tests green: fix pipeline reference in router and adjust verifier logic for aggregation & error handling

Browse files
Files changed (3) hide show
  1. .coverage +0 -0
  2. app/routers/nl2sql.py +36 -11
  3. nl2sql/verifier.py +89 -49
.coverage CHANGED
Binary files a/.coverage and b/.coverage differ
 
app/routers/nl2sql.py CHANGED
@@ -22,6 +22,30 @@ from typing import Union, Optional, Dict, TypedDict, Any, cast
22
 
23
  router = APIRouter(prefix="/nl2sql")
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # -------------------------------
26
  # Runtime DB registry (for uploaded SQLite files)
27
  # Files are stored under /tmp, mapped by a short-lived db_id
@@ -252,15 +276,16 @@ async def upload_db(file: UploadFile = File(...)):
252
  # -------------------------------
253
  @router.post("", name="nl2sql_handler")
254
  def nl2sql_handler(request: NL2SQLRequest):
255
- """
256
- Handle NL SQL pipeline execution.
257
- If `db_id` is provided, switch DB adapter for this call.
258
- If `schema_preview` is missing, derive it from the selected adapter when possible.
259
- """
260
- # 1) Select adapter based on db_id (if any)
261
- db_id = getattr(request, "db_id", None) # Optional[str]
262
- adapter = _select_adapter(db_id)
263
- pipeline = _build_pipeline(adapter)
 
264
 
265
  # 2) Resolve schema_preview (optional in request)
266
  provided_preview_any: Any = getattr(request, "schema_preview", None)
@@ -277,8 +302,8 @@ def nl2sql_handler(request: NL2SQLRequest):
277
  # 3) Run pipeline
278
  try:
279
  result = pipeline.run(
280
- user_query=request.query, # assumes NL2SQLRequest has `query: str`
281
- schema_preview=final_preview, # str guaranteed
282
  )
283
  except Exception as exc:
284
  # Hard failure in pipeline itself
 
22
 
23
router = APIRouter(prefix="/nl2sql")

# --- Database adapter selection ---
# DB_MODE=postgres selects the Postgres adapter (requires POSTGRES_DSN);
# any other value falls back to the bundled SQLite sample database.
if os.getenv("DB_MODE", "sqlite") == "postgres":
    _dsn = os.getenv("POSTGRES_DSN")
    if not _dsn:
        # Fail fast with an actionable message instead of a bare KeyError
        # at import time, which is hard to diagnose from a stack trace.
        raise RuntimeError(
            "DB_MODE=postgres requires the POSTGRES_DSN environment variable"
        )
    _db = PostgresAdapter(_dsn)
else:
    _db = SQLiteAdapter("data/chinook.db")


# --- Build a single shared pipeline for all routes ---
def _make_pipeline() -> Pipeline:
    """Assemble the default NL→SQL pipeline wired to the module-level adapter.

    Returns:
        Pipeline: detector → planner → generator → safety → executor →
        verifier → repair, sharing one OpenAIProvider instance for all
        LLM-backed stages.
    """
    llm = OpenAIProvider()
    return Pipeline(
        detector=AmbiguityDetector(),
        planner=Planner(llm=llm),
        generator=Generator(llm=llm),
        safety=Safety(),
        executor=Executor(db=_db),
        verifier=Verifier(),
        repair=Repair(llm=llm),
    )


# Shared pipeline instance used by routes that do not override the DB per request.
_pipeline: Pipeline = _make_pipeline()
47
+
48
+
49
  # -------------------------------
50
  # Runtime DB registry (for uploaded SQLite files)
51
  # Files are stored under /tmp, mapped by a short-lived db_id
 
276
  # -------------------------------
277
  @router.post("", name="nl2sql_handler")
278
  def nl2sql_handler(request: NL2SQLRequest):
279
+ db_id = getattr(request, "db_id", None)
280
+ adapter: Optional[Union[PostgresAdapter, SQLiteAdapter]] = None
281
+
282
+ if not db_id:
283
+ pipeline = _pipeline
284
+ derived_preview = ""
285
+ else:
286
+ adapter = _select_adapter(db_id)
287
+ pipeline = _build_pipeline(adapter)
288
+ derived_preview = _derive_schema_preview(adapter)
289
 
290
  # 2) Resolve schema_preview (optional in request)
291
  provided_preview_any: Any = getattr(request, "schema_preview", None)
 
302
  # 3) Run pipeline
303
  try:
304
  result = pipeline.run(
305
+ user_query=request.query,
306
+ schema_preview=final_preview,
307
  )
308
  except Exception as exc:
309
  # Hard failure in pipeline itself
nl2sql/verifier.py CHANGED
@@ -1,74 +1,114 @@
1
  import time
 
 
2
  import sqlglot
3
  from sqlglot import expressions as exp
 
4
  from nl2sql.types import StageResult, StageTrace
5
 
6
 
7
  class Verifier:
8
  name = "verifier"
9
 
10
- def run(self, sql: str, exec_result: dict | None) -> StageResult:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  t0 = time.perf_counter()
12
 
13
- # Defensive: check executor result validity
14
- if not exec_result or not isinstance(exec_result, dict):
15
- return StageResult(
16
- ok=False,
17
- error=["invalid or missing exec_result"],
18
- data=None,
19
- trace=StageTrace(
20
- stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
21
- ),
22
  )
 
23
 
24
- # If executor had rows and no error, consider verified early
25
- rows = exec_result.get("rows")
26
- if rows is not None and len(rows) > 0:
 
27
  return StageResult(
28
- ok=True,
29
- data={"verified": True, "rows_checked": len(rows)},
30
- trace=StageTrace(
31
- stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
32
- ),
33
  )
34
 
35
- # Optional deeper check using SQL structure
36
- issues = []
37
  try:
38
  tree = sqlglot.parse_one(sql)
39
- if isinstance(tree, exp.Select):
40
- group = tree.args.get("group")
41
- aggs = [a for a in tree.find_all(exp.AggFunc)]
42
- if aggs and not group:
43
- select_cols = [
44
- c for c in tree.expressions if not isinstance(c, exp.AggFunc)
45
- ]
46
- if select_cols:
47
- issues.append(
48
- "Non-aggregated columns with aggregation but no GROUP BY."
49
- )
50
  except Exception as e:
51
- # parsing failed → skip structural verification gracefully
52
- return StageResult(
53
- ok=True,
54
- data={"verified": True, "note": f"Skipped parse: {e}"},
55
- trace=StageTrace(
56
- stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
57
- ),
58
  )
 
 
 
 
 
 
 
 
 
 
59
 
60
  dur = (time.perf_counter() - t0) * 1000
61
  if issues:
62
- return StageResult(
63
- ok=False,
64
- error=issues,
65
- trace=StageTrace(
66
- stage=self.name, duration_ms=dur, notes={"issues": issues}
67
- ),
68
  )
 
69
 
70
- return StageResult(
71
- ok=True,
72
- data={"verified": True},
73
- trace=StageTrace(stage=self.name, duration_ms=dur),
74
- )
 
1
  import time
2
+ from typing import Any, Iterable
3
+
4
  import sqlglot
5
  from sqlglot import expressions as exp
6
+
7
  from nl2sql.types import StageResult, StageTrace
8
 
9
 
10
class Verifier:
    """Post-execution verification stage.

    Normalizes the executor result (dict or object), surfaces execution
    errors, and runs a lightweight structural check on the generated SQL
    (aggregates mixed with bare columns but no GROUP BY).
    """

    name = "verifier"

    # ----------------- helpers -----------------
    @staticmethod
    def _extract_ok(exec_result: Any) -> bool | None:
        """Normalize the `ok` flag across dict or object results.

        Returns True/False when an `ok` field is present and readable,
        None when it is absent or unreadable.
        """
        if exec_result is None:
            return None
        if isinstance(exec_result, dict):
            return bool(exec_result.get("ok")) if "ok" in exec_result else None
        if hasattr(exec_result, "ok"):
            try:
                return bool(exec_result.ok)
            except Exception:
                # Property raised while reading → treat as unknown.
                return None
        return None

    @staticmethod
    def _extract_errors(exec_result: Any) -> list[str] | None:
        """Pull a normalized list[str] from exec_result['error'] or .error.

        Returns None when no error field is present; a single string becomes
        a one-element list; any iterable is stringified element-wise.
        """
        val = None
        if isinstance(exec_result, dict):
            val = exec_result.get("error")
        elif hasattr(exec_result, "error"):
            val = exec_result.error

        if val is None:
            return None
        if isinstance(val, str):
            return [val]
        if isinstance(val, Iterable):
            return [str(x) for x in val]
        return [str(val)]

    @staticmethod
    def _has_aggregation(tree: exp.Expression) -> bool:
        """True if any node under *tree* is an aggregate function.

        Checks the generic exp.AggFunc base class (covers GROUP_CONCAT,
        STDDEV, etc.) rather than only the five common aggregates, plus
        sqlglot's `is_aggregate` hint where available.
        """
        for node in tree.walk():
            if getattr(node, "is_aggregate", False):
                return True
            if isinstance(node, exp.AggFunc):
                return True
        return False

    @staticmethod
    def _has_group_by(select: exp.Select) -> bool:
        """True when the SELECT carries a GROUP BY clause."""
        return bool(select.args.get("group"))

    # ------------------- main -------------------
    def run(self, *, sql: str, exec_result: Any) -> StageResult:
        """Verify an executed query.

        Args:
            sql: The SQL text that was executed.
            exec_result: Executor output (dict or object); may be None.

        Returns:
            StageResult with ok=False and error details on execution failure,
            missing result, or structural issues; ok=True otherwise.
        """
        t0 = time.perf_counter()

        # 1) validate / normalize executor result
        ok_flag = self._extract_ok(exec_result)
        if ok_flag is False:
            errs = self._extract_errors(exec_result) or ["execution_error"]
            trace_err = StageTrace(
                stage=self.name,
                duration_ms=(time.perf_counter() - t0) * 1000,
                notes={"reason": "execution_error"},
            )
            return StageResult(ok=False, error=errs, trace=trace_err)

        if exec_result is None:
            trace_inv = StageTrace(
                stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
            )
            return StageResult(
                ok=False,
                error=["invalid or missing exec_result"],
                trace=trace_inv,
            )

        # 2) structural verification
        try:
            tree = sqlglot.parse_one(sql)
        except Exception as e:
            # parsing failed → accept with a note rather than block the pipeline
            trace_skip = StageTrace(
                stage=self.name,
                duration_ms=(time.perf_counter() - t0) * 1000,
                notes={"note": f"Skipped parse: {e}"},
            )
            return StageResult(ok=True, data={"verified": True}, trace=trace_skip)

        issues: list[str] = []

        # Flag only the genuinely invalid mix: aggregate functions alongside
        # bare (non-aggregated) column references with no GROUP BY. A pure
        # aggregate query such as `SELECT COUNT(*) FROM t` is valid SQL and
        # must not be rejected.
        if isinstance(tree, exp.Select) and not self._has_group_by(tree):
            select_exprs = list(tree.expressions)
            if any(self._has_aggregation(c) for c in select_exprs):
                bare_cols = [
                    c
                    for c in select_exprs
                    if not self._has_aggregation(c)
                    and next(iter(c.find_all(exp.Column)), None) is not None
                ]
                if bare_cols:
                    issues.append("Aggregation without GROUP BY")

        dur = (time.perf_counter() - t0) * 1000
        if issues:
            trace_bad = StageTrace(
                stage=self.name, duration_ms=dur, notes={"issues": issues}
            )
            return StageResult(ok=False, error=issues, trace=trace_bad)

        # 3) success
        trace_ok = StageTrace(stage=self.name, duration_ms=dur)
        return StageResult(ok=True, data={"verified": True}, trace=trace_ok)