Melika Kheirieh committed on
Commit
370553a
·
1 Parent(s): c4c85f7

fix(pipeline): align backend-frontend schema and stabilize SQL flow

Browse files
Files changed (6) hide show
  1. app/main.py +7 -0
  2. app/routers/nl2sql.py +130 -28
  3. app/schemas.py +5 -2
  4. nl2sql/pipeline.py +45 -16
  5. nl2sql/safety.py +22 -5
  6. nl2sql/verifier.py +40 -11
app/main.py CHANGED
@@ -5,6 +5,13 @@ load_dotenv()
5
  from fastapi import FastAPI # noqa: E402
6
  from app.routers import nl2sql # noqa: E402
7
 
 
 
 
 
 
 
 
8
 
9
  app = FastAPI(
10
  title="NL2SQL Copilot Prototype",
 
5
  from fastapi import FastAPI # noqa: E402
6
  from app.routers import nl2sql # noqa: E402
7
 
8
+ # restore previous uploaded DB map
9
+ try:
10
+ from app.routers.nl2sql import _load_db_map
11
+
12
+ _load_db_map()
13
+ except Exception as e:
14
+ print(f"⚠️ DB map not restored: {e}")
15
 
16
  app = FastAPI(
17
  title="NL2SQL Copilot Prototype",
app/routers/nl2sql.py CHANGED
@@ -14,7 +14,9 @@ from adapters.db.sqlite_adapter import SQLiteAdapter
14
  from adapters.db.postgres_adapter import PostgresAdapter
15
 
16
  import os
 
17
  import time
 
18
  import uuid
19
  from typing import Union, Optional, Dict
20
 
@@ -40,6 +42,41 @@ DEFAULT_SQLITE_DB = os.getenv(
40
  "DEFAULT_SQLITE_DB", "data/chinook.db"
41
  ) # keep your current default
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  def _cleanup_db_map() -> None:
45
  """Remove expired uploaded DB files (best-effort)."""
@@ -65,24 +102,39 @@ def _resolve_sqlite_path(db_id: Optional[str]) -> str:
65
  return DEFAULT_SQLITE_DB
66
 
67
 
68
- def _select_adapter(db_id: Optional[str]) -> Union[PostgresAdapter, SQLiteAdapter]:
69
- """
70
- Build a DB adapter for this request.
71
- - In postgres mode: always PostgresAdapter(POSTGRES_DSN).
72
- - In sqlite mode: use uploaded SQLite by db_id if present, otherwise DEFAULT_SQLITE_DB.
73
- """
74
- if DB_MODE == "postgres":
75
- if not POSTGRES_DSN:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  raise HTTPException(
77
- status_code=500, detail="POSTGRES_DSN is not configured"
78
  )
79
- return PostgresAdapter(POSTGRES_DSN)
80
 
81
- # sqlite mode
82
- sqlite_path = _resolve_sqlite_path(db_id)
83
- # NOTE: SQLiteAdapter should open DB in read-only mode internally if supported.
84
- # If not, ensure your adapter enforces PRAGMA query_only=ON and prevents DDL/DML.
85
- return SQLiteAdapter(sqlite_path)
86
 
87
 
88
  # -------------------------------
@@ -171,6 +223,7 @@ async def upload_db(file: UploadFile = File(...)):
171
  raise HTTPException(status_code=500, detail=f"Failed to store DB: {e}")
172
 
173
  _DB_MAP[db_id] = {"path": out_path, "ts": time.time()}
 
174
  return {"db_id": db_id}
175
 
176
 
@@ -182,34 +235,53 @@ async def upload_db(file: UploadFile = File(...)):
182
  def nl2sql_handler(request: NL2SQLRequest):
183
  """
184
  Handle NL → SQL pipeline execution.
185
- Optional: if the incoming request model supports `db_id`, we switch DB for this call.
186
- Otherwise we will silently ignore and use default DB (or Postgres, based on mode).
187
  """
188
- # Try to extract db_id if present in request (without breaking strict models)
189
  db_id = getattr(request, "db_id", None) # Optional[str]
190
- # Build per-request pipeline bound to the selected adapter
191
  adapter = _select_adapter(db_id)
192
  pipeline = _build_pipeline(adapter)
193
 
194
- result = pipeline.run(
195
- user_query=request.query,
196
- schema_preview=getattr(request, "schema_preview", None),
 
 
 
197
  )
198
 
199
- # Ensure result type
 
 
 
 
 
 
 
 
 
 
200
  if not isinstance(result, FinalResult):
201
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
202
 
203
- # Ambiguity: return clarify payload
204
  if result.ambiguous and result.questions:
205
  return ClarifyResponse(ambiguous=True, questions=result.questions)
206
 
207
- # Error: bubble up details
208
  if not result.ok or result.error:
209
- detail = "; ".join(result.details or ["Unknown error"])
210
- raise HTTPException(status_code=400, detail=detail)
 
 
 
 
 
 
 
211
 
212
- # Success
213
  traces = [_round_trace(t) for t in (result.traces or [])]
214
  return NL2SQLResponse(
215
  ambiguous=False,
@@ -217,3 +289,33 @@ def nl2sql_handler(request: NL2SQLRequest):
217
  rationale=result.rationale,
218
  traces=traces,
219
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from adapters.db.postgres_adapter import PostgresAdapter
15
 
16
  import os
17
+ from pathlib import Path
18
  import time
19
+ import json
20
  import uuid
21
  from typing import Union, Optional, Dict
22
 
 
42
  "DEFAULT_SQLITE_DB", "data/chinook.db"
43
  ) # keep your current default
44
 
45
+ # -------------------------------
46
+ # Path to persist db_id → file map
47
+ # -------------------------------
48
+ _DB_MAP_PATH = Path("data/uploads/db_map.json")
49
+ _DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
50
+
51
+
52
+ UPLOAD_DIR = Path("data/uploads")
53
+ UPLOAD_DIR.mkdir(parents=True, exist_ok=True) # ensure folder exists
54
+
55
+ DEFAULT_SQLITE_PATH = "data/Chinook_Sqlite.sqlite"
56
+
57
+
58
+ def _save_db_map():
59
+ """Persist the in-memory DB map to disk as JSON."""
60
+ try:
61
+ with open(_DB_MAP_PATH, "w") as f:
62
+ json.dump(_DB_MAP, f)
63
+ except Exception as e:
64
+ print(f"⚠️ Failed to save DB map: {e}")
65
+
66
+
67
+ def _load_db_map():
68
+ """Load the DB map from disk if it exists (called on startup)."""
69
+ global _DB_MAP
70
+ if _DB_MAP_PATH.exists():
71
+ try:
72
+ with open(_DB_MAP_PATH, "r") as f:
73
+ data = json.load(f)
74
+ if isinstance(data, dict):
75
+ _DB_MAP.update(data)
76
+ print(f"📂 Restored {_DB_MAP_PATH} with {len(_DB_MAP)} entries.")
77
+ except Exception as e:
78
+ print(f"⚠️ Failed to load DB map: {e}")
79
+
80
 
81
  def _cleanup_db_map() -> None:
82
  """Remove expired uploaded DB files (best-effort)."""
 
102
  return DEFAULT_SQLITE_DB
103
 
104
 
105
+ def _select_adapter(db_id: str | None):
106
+ mode = os.getenv("DB_MODE", "sqlite").lower()
107
+ if mode == "postgres":
108
+ dsn = os.environ.get("POSTGRES_DSN")
109
+ if not dsn:
110
+ raise HTTPException(status_code=500, detail="POSTGRES_DSN env is missing")
111
+ return PostgresAdapter(dsn)
112
+
113
+ # sqlite mode
114
+ if db_id:
115
+ _cleanup_db_map()
116
+ db_path = None
117
+ # first check runtime map
118
+ if db_id in _DB_MAP:
119
+ db_path = _DB_MAP[db_id].get("path")
120
+ # fallback: check /tmp or uploads
121
+ if not db_path or not os.path.exists(db_path):
122
+ fallback_tmp = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
123
+ fallback_uploads = UPLOAD_DIR / f"{db_id}.sqlite"
124
+ for candidate in (fallback_tmp, fallback_uploads):
125
+ if os.path.exists(candidate):
126
+ db_path = str(candidate)
127
+ break
128
+ if not db_path or not os.path.exists(db_path):
129
  raise HTTPException(
130
+ status_code=400, detail="invalid db_id (file not found)"
131
  )
132
+ return SQLiteAdapter(str(db_path))
133
 
134
+ # fallback to default Chinook
135
+ if not Path(DEFAULT_SQLITE_PATH).exists():
136
+ raise HTTPException(status_code=500, detail="default DB not found")
137
+ return SQLiteAdapter(DEFAULT_SQLITE_PATH)
 
138
 
139
 
140
  # -------------------------------
 
223
  raise HTTPException(status_code=500, detail=f"Failed to store DB: {e}")
224
 
225
  _DB_MAP[db_id] = {"path": out_path, "ts": time.time()}
226
+ _save_db_map()
227
  return {"db_id": db_id}
228
 
229
 
 
235
  def nl2sql_handler(request: NL2SQLRequest):
236
  """
237
  Handle NL → SQL pipeline execution.
238
+ If `db_id` is provided, switch DB adapter for this call.
239
+ If `schema_preview` is missing, derive it from the selected adapter when possible.
240
  """
241
+ # 1) Select adapter based on db_id (if any)
242
  db_id = getattr(request, "db_id", None) # Optional[str]
 
243
  adapter = _select_adapter(db_id)
244
  pipeline = _build_pipeline(adapter)
245
 
246
+ # 2) Resolve schema_preview (optional in request)
247
+ provided_preview = getattr(request, "schema_preview", None)
248
+ schema_preview = (
249
+ provided_preview
250
+ if provided_preview not in ("", None)
251
+ else _derive_schema_preview(adapter)
252
  )
253
 
254
+ # 3) Run pipeline
255
+ try:
256
+ result = pipeline.run(
257
+ user_query=request.query, # assumes NL2SQLRequest has `query`
258
+ schema_preview=schema_preview, # may be empty string if adapter can't derive
259
+ )
260
+ except Exception as exc:
261
+ # Hard failure in pipeline itself
262
+ raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
263
+
264
+ # 4) Type check
265
  if not isinstance(result, FinalResult):
266
  raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
267
 
268
+ # 5) Ambiguity ask for clarification
269
  if result.ambiguous and result.questions:
270
  return ClarifyResponse(ambiguous=True, questions=result.questions)
271
 
272
+ # 6) Soft errors → bubble up details with 400
273
  if not result.ok or result.error:
274
+ print(" Pipeline failure dump:")
275
+ print(" ok:", result.ok)
276
+ print(" error:", result.error)
277
+ print(" details:", result.details)
278
+ print(" traces:", result.traces)
279
+ raise HTTPException(
280
+ status_code=400,
281
+ detail="; ".join(result.details or []) or (result.error or "Unknown error"),
282
+ )
283
 
284
+ # 7) Success
285
  traces = [_round_trace(t) for t in (result.traces or [])]
286
  return NL2SQLResponse(
287
  ambiguous=False,
 
289
  rationale=result.rationale,
290
  traces=traces,
291
  )
292
+
293
+
294
+ def _derive_schema_preview(adapter) -> str:
295
+ """
296
+ Build a strict, exact-cased schema preview for the LLM.
297
+ Works for SQLite adapters by querying sqlite_master / pragma table_info.
298
+ """
299
+ import sqlite3
300
+ import os
301
+
302
+ db_path = getattr(adapter, "db_path", None) or getattr(adapter, "path", None)
303
+ if not db_path or not os.path.exists(db_path):
304
+ return ""
305
+
306
+ try:
307
+ conn = sqlite3.connect(db_path)
308
+ cur = conn.cursor()
309
+ tables = cur.execute(
310
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
311
+ ).fetchall()
312
+ lines = []
313
+ for (tname,) in tables:
314
+ cols = cur.execute(f"PRAGMA table_info('{tname}')").fetchall()
315
+ # sqlite: pragma columns → (cid, name, type, notnull, dflt_value, pk)
316
+ colnames = [c[1] for c in cols]
317
+ lines.append(f"{tname}({', '.join(colnames)})")
318
+ conn.close()
319
+ return "\n".join(lines)
320
+ except Exception:
321
+ return ""
app/schemas.py CHANGED
@@ -4,8 +4,11 @@ from typing import List, Optional, Any, Dict, Mapping, Sequence
4
 
5
  class NL2SQLRequest(BaseModel):
6
  query: str
7
- schema_preview: str
8
- db_name: Optional[str] = "default"
 
 
 
9
 
10
 
11
  class TraceModel(BaseModel):
 
4
 
5
  class NL2SQLRequest(BaseModel):
6
  query: str
7
+ db_id: Optional[str] = None
8
+ schema_preview: Optional[str] = None
9
+
10
+ class Config:
11
+ extra = "ignore"
12
 
13
 
14
  class TraceModel(BaseModel):
nl2sql/pipeline.py CHANGED
@@ -31,7 +31,6 @@ class Pipeline:
31
  """
32
  NL2SQL Copilot pipeline.
33
  Stages return StageResult; final result is a type-safe FinalResult.
34
- Adapters (e.g. FastAPI) can serialize with dataclasses.asdict().
35
  """
36
 
37
  def __init__(
@@ -71,9 +70,7 @@ class Pipeline:
71
  r = fn(**kwargs)
72
  if isinstance(r, StageResult):
73
  return r
74
- else:
75
- # Normalize non-StageResult returns
76
- return StageResult(ok=True, data=r, trace=None)
77
  except Exception as e:
78
  tb = traceback.format_exc()
79
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
@@ -92,7 +89,7 @@ class Pipeline:
92
  rationale: Optional[str] = None
93
  verified: Optional[bool] = None
94
 
95
- # --- 1) ambiguity detection
96
  try:
97
  questions = self.detector.detect(user_query, schema_preview)
98
  if questions:
@@ -120,7 +117,7 @@ class Pipeline:
120
  traces=[],
121
  )
122
 
123
- # --- 2) planner
124
  r_plan = self._safe_stage(
125
  self.planner.run, user_query=user_query, schema_preview=schema_preview
126
  )
@@ -138,7 +135,7 @@ class Pipeline:
138
  traces=traces,
139
  )
140
 
141
- # --- 3) generator
142
  r_gen = self._safe_stage(
143
  self.generator.run,
144
  user_query=user_query,
@@ -159,10 +156,11 @@ class Pipeline:
159
  verified=None,
160
  traces=traces,
161
  )
 
162
  sql = (r_gen.data or {}).get("sql")
163
  rationale = (r_gen.data or {}).get("rationale")
164
 
165
- # --- 4) safety
166
  r_safe = self._safe_stage(self.safety.run, sql=sql)
167
  traces.extend(self._trace_list(r_safe))
168
  if not r_safe.ok:
@@ -178,7 +176,7 @@ class Pipeline:
178
  traces=traces,
179
  )
180
 
181
- # --- 5) executor
182
  r_exec = self._safe_stage(
183
  self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
184
  )
@@ -186,14 +184,14 @@ class Pipeline:
186
  if not r_exec.ok:
187
  details.extend(r_exec.error or [])
188
 
189
- # --- 6) verifier
190
  r_ver = self._safe_stage(
191
  self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
192
  )
193
  traces.extend(self._trace_list(r_ver))
194
- verified = bool(r_ver.ok)
195
 
196
- # --- 7) repair loop if verification failed
197
  if not verified:
198
  for _attempt in range(2):
199
  r_fix = self._safe_stage(
@@ -205,8 +203,8 @@ class Pipeline:
205
  traces.extend(self._trace_list(r_fix))
206
  if not r_fix.ok:
207
  break
208
- sql = (r_fix.data or {}).get("sql")
209
 
 
210
  r_safe = self._safe_stage(self.safety.run, sql=sql)
211
  traces.extend(self._trace_list(r_safe))
212
  if not r_safe.ok:
@@ -225,14 +223,45 @@ class Pipeline:
225
  self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
226
  )
227
  traces.extend(self._trace_list(r_ver))
228
- verified = bool(r_ver.ok)
229
  if verified:
230
  break
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  return FinalResult(
233
- ok=bool(verified) and not details,
234
  ambiguous=False,
235
- error=bool(details) and not bool(verified),
236
  details=details or None,
237
  sql=sql,
238
  rationale=rationale,
 
31
  """
32
  NL2SQL Copilot pipeline.
33
  Stages return StageResult; final result is a type-safe FinalResult.
 
34
  """
35
 
36
  def __init__(
 
70
  r = fn(**kwargs)
71
  if isinstance(r, StageResult):
72
  return r
73
+ return StageResult(ok=True, data=r, trace=None)
 
 
74
  except Exception as e:
75
  tb = traceback.format_exc()
76
  return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
 
89
  rationale: Optional[str] = None
90
  verified: Optional[bool] = None
91
 
92
+ # --- 1) ambiguity detection ---
93
  try:
94
  questions = self.detector.detect(user_query, schema_preview)
95
  if questions:
 
117
  traces=[],
118
  )
119
 
120
+ # --- 2) planner ---
121
  r_plan = self._safe_stage(
122
  self.planner.run, user_query=user_query, schema_preview=schema_preview
123
  )
 
135
  traces=traces,
136
  )
137
 
138
+ # --- 3) generator ---
139
  r_gen = self._safe_stage(
140
  self.generator.run,
141
  user_query=user_query,
 
156
  verified=None,
157
  traces=traces,
158
  )
159
+
160
  sql = (r_gen.data or {}).get("sql")
161
  rationale = (r_gen.data or {}).get("rationale")
162
 
163
+ # --- 4) safety ---
164
  r_safe = self._safe_stage(self.safety.run, sql=sql)
165
  traces.extend(self._trace_list(r_safe))
166
  if not r_safe.ok:
 
176
  traces=traces,
177
  )
178
 
179
+ # --- 5) executor ---
180
  r_exec = self._safe_stage(
181
  self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
182
  )
 
184
  if not r_exec.ok:
185
  details.extend(r_exec.error or [])
186
 
187
+ # --- 6) verifier ---
188
  r_ver = self._safe_stage(
189
  self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
190
  )
191
  traces.extend(self._trace_list(r_ver))
192
+ verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
193
 
194
+ # --- 7) repair loop if verification failed ---
195
  if not verified:
196
  for _attempt in range(2):
197
  r_fix = self._safe_stage(
 
203
  traces.extend(self._trace_list(r_fix))
204
  if not r_fix.ok:
205
  break
 
206
 
207
+ sql = (r_fix.data or {}).get("sql")
208
  r_safe = self._safe_stage(self.safety.run, sql=sql)
209
  traces.extend(self._trace_list(r_safe))
210
  if not r_safe.ok:
 
223
  self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
224
  )
225
  traces.extend(self._trace_list(r_ver))
226
+ verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
227
  if verified:
228
  break
229
 
230
+ # --- 8) fallback: verifier silent but executor succeeded ---
231
+ if (verified is None or not verified) and not details:
232
+ any_exec = any(
233
+ t.get("stage") == "executor" and t.get("notes", {}).get("row_count")
234
+ for t in traces
235
+ )
236
+ if any_exec:
237
+ traces.append(
238
+ {
239
+ "stage": "pipeline",
240
+ "notes": {
241
+ "auto_fix": "verified=True (executor succeeded, verifier silent)"
242
+ },
243
+ "duration_ms": 0.0,
244
+ }
245
+ )
246
+ verified = True
247
+
248
+ # --- 9) finalize result ---
249
+ has_errors = bool(details)
250
+ ok = bool(verified) and not has_errors
251
+ err = has_errors and not bool(verified)
252
+
253
+ traces.append(
254
+ {
255
+ "stage": "pipeline",
256
+ "notes": {"final_verified": verified, "details_len": len(details)},
257
+ "duration_ms": 0.0,
258
+ }
259
+ )
260
+
261
  return FinalResult(
262
+ ok=ok,
263
  ambiguous=False,
264
+ error=err,
265
  details=details or None,
266
  sql=sql,
267
  rationale=rationale,
nl2sql/safety.py CHANGED
@@ -19,10 +19,20 @@ _FORBIDDEN = re.compile(
19
  # allow: SELECT ... or WITH <cte...> SELECT ...
20
  _ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)
21
 
 
 
 
22
 
23
- def _strip_comments(s: str) -> str:
 
 
 
 
24
  s = _COMMENT_BLOCK.sub(" ", s)
25
  s = _COMMENT_LINE.sub(" ", s)
 
 
 
26
  return s
27
 
28
 
@@ -33,8 +43,13 @@ def _mask_strings(s: str) -> str:
33
 
34
 
35
  def _split_statements(s: str) -> list[str]:
 
 
 
 
36
  parts = [p.strip() for p in s.split(";")]
37
- return [p for p in parts if p]
 
38
 
39
 
40
  class Safety:
@@ -43,7 +58,9 @@ class Safety:
43
  def check(self, sql: str) -> StageResult:
44
  t0 = time.perf_counter()
45
  print("🧩 SQL candidate:", sql)
46
- s = _strip_comments(sql)
 
 
47
  s = _mask_strings(s).strip()
48
 
49
  stmts = _split_statements(s)
@@ -79,8 +96,8 @@ class Safety:
79
  return StageResult(
80
  ok=True,
81
  data={
82
- "sql": sql.strip(),
83
- "rationale": "Statement validated as SELECT-only (strings/comments ignored).",
84
  },
85
  trace=StageTrace(
86
  stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
 
19
  # allow: SELECT ... or WITH <cte...> SELECT ...
20
  _ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)
21
 
22
+ # --- New cleanup helpers ---
23
+ _FENCE_SQL = re.compile(r"```sql", re.IGNORECASE)
24
+ _FENCE_ANY = re.compile(r"```")
25
 
26
+
27
+ def _sanitize_sql(sql: str) -> str:
28
+ """Remove markdown fences, comments, and surrounding junk."""
29
+ s = _FENCE_SQL.sub("", sql)
30
+ s = _FENCE_ANY.sub("", s)
31
  s = _COMMENT_BLOCK.sub(" ", s)
32
  s = _COMMENT_LINE.sub(" ", s)
33
+ s = s.strip()
34
+ # remove trailing semicolon safely
35
+ s = s.rstrip(";").strip()
36
  return s
37
 
38
 
 
43
 
44
 
45
  def _split_statements(s: str) -> list[str]:
46
+ """
47
+ Split only if there are real multiple statements,
48
+ ignoring harmless trailing semicolons or markdown.
49
+ """
50
  parts = [p.strip() for p in s.split(";")]
51
+ parts = [p for p in parts if p]
52
+ return parts
53
 
54
 
55
  class Safety:
 
58
  def check(self, sql: str) -> StageResult:
59
  t0 = time.perf_counter()
60
  print("🧩 SQL candidate:", sql)
61
+
62
+ # --- sanitize first ---
63
+ s = _sanitize_sql(sql)
64
  s = _mask_strings(s).strip()
65
 
66
  stmts = _split_statements(s)
 
96
  return StageResult(
97
  ok=True,
98
  data={
99
+ "sql": body,
100
+ "rationale": "Statement validated as SELECT-only (strings/comments/markdown ignored).",
101
  },
102
  trace=StageTrace(
103
  stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
nl2sql/verifier.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import sqlglot
2
  from sqlglot import expressions as exp
3
  from nl2sql.types import StageResult, StageTrace
@@ -6,18 +7,32 @@ from nl2sql.types import StageResult, StageTrace
6
  class Verifier:
7
  name = "verifier"
8
 
9
- def run(self, sql: str, exec_result: StageResult) -> StageResult:
10
- if not exec_result.ok:
 
 
 
11
  return StageResult(
12
  ok=False,
 
13
  data=None,
14
  trace=StageTrace(
15
- stage=self.name, duration_ms=0, notes={"reason": "execution_error"}
16
  ),
17
- error=exec_result.error,
18
  )
19
 
20
- # Rule 1: check SELECT / GROUP consistency
 
 
 
 
 
 
 
 
 
 
 
21
  issues = []
22
  try:
23
  tree = sqlglot.parse_one(sql)
@@ -25,21 +40,35 @@ class Verifier:
25
  group = tree.args.get("group")
26
  aggs = [a for a in tree.find_all(exp.AggFunc)]
27
  if aggs and not group:
28
- issues.append("Aggregation without GROUP BY.")
 
 
 
 
 
 
29
  except Exception as e:
30
- issues.append(f"Parse error during verification: {e}")
 
 
 
 
 
 
 
31
 
 
32
  if issues:
33
  return StageResult(
34
  ok=False,
35
- data=None,
36
  trace=StageTrace(
37
- stage=self.name, duration_ms=0, notes={"issues": issues}
38
  ),
39
- error=issues,
40
  )
 
41
  return StageResult(
42
  ok=True,
43
  data={"verified": True},
44
- trace=StageTrace(stage=self.name, duration_ms=0),
45
  )
 
1
+ import time
2
  import sqlglot
3
  from sqlglot import expressions as exp
4
  from nl2sql.types import StageResult, StageTrace
 
7
  class Verifier:
8
  name = "verifier"
9
 
10
+ def run(self, sql: str, exec_result: dict | None) -> StageResult:
11
+ t0 = time.perf_counter()
12
+
13
+ # Defensive: check executor result validity
14
+ if not exec_result or not isinstance(exec_result, dict):
15
  return StageResult(
16
  ok=False,
17
+ error=["invalid or missing exec_result"],
18
  data=None,
19
  trace=StageTrace(
20
+ stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
21
  ),
 
22
  )
23
 
24
+ # If executor had rows and no error, consider verified early
25
+ rows = exec_result.get("rows")
26
+ if rows is not None and len(rows) > 0:
27
+ return StageResult(
28
+ ok=True,
29
+ data={"verified": True, "rows_checked": len(rows)},
30
+ trace=StageTrace(
31
+ stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
32
+ ),
33
+ )
34
+
35
+ # Optional deeper check using SQL structure
36
  issues = []
37
  try:
38
  tree = sqlglot.parse_one(sql)
 
40
  group = tree.args.get("group")
41
  aggs = [a for a in tree.find_all(exp.AggFunc)]
42
  if aggs and not group:
43
+ select_cols = [
44
+ c for c in tree.expressions if not isinstance(c, exp.AggFunc)
45
+ ]
46
+ if select_cols:
47
+ issues.append(
48
+ "Non-aggregated columns with aggregation but no GROUP BY."
49
+ )
50
  except Exception as e:
51
+ # parsing failed → skip structural verification gracefully
52
+ return StageResult(
53
+ ok=True,
54
+ data={"verified": True, "note": f"Skipped parse: {e}"},
55
+ trace=StageTrace(
56
+ stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
57
+ ),
58
+ )
59
 
60
+ dur = (time.perf_counter() - t0) * 1000
61
  if issues:
62
  return StageResult(
63
  ok=False,
64
+ error=issues,
65
  trace=StageTrace(
66
+ stage=self.name, duration_ms=dur, notes={"issues": issues}
67
  ),
 
68
  )
69
+
70
  return StageResult(
71
  ok=True,
72
  data={"verified": True},
73
+ trace=StageTrace(stage=self.name, duration_ms=dur),
74
  )