Spaces: Sleeping

Melika Kheirieh committed
Commit: 370553a
Parent(s): c4c85f7

fix(pipeline): align backend-frontend schema and stabilize SQL flow

Files changed:
- app/main.py            +7   -0
- app/routers/nl2sql.py   +130 -28
- app/schemas.py          +5   -2
- nl2sql/pipeline.py      +45  -16
- nl2sql/safety.py        +22  -5
- nl2sql/verifier.py      +40  -11
app/main.py CHANGED

@@ -5,6 +5,13 @@ load_dotenv()
 from fastapi import FastAPI  # noqa: E402
 from app.routers import nl2sql  # noqa: E402
 
+# restore previous uploaded DB map
+try:
+    from app.routers.nl2sql import _load_db_map
+
+    _load_db_map()
+except Exception as e:
+    print(f"⚠️ DB map not restored: {e}")
 
 app = FastAPI(
     title="NL2SQL Copilot Prototype",
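Note on the startup hook: importing app.main runs the try/except above, so previously uploaded databases are re-registered before the first request. A minimal sketch of checking the restore (module paths as in this commit; the print is illustrative):

import app.main  # noqa: F401 — import side effect runs _load_db_map()
from app.routers.nl2sql import _DB_MAP

# _DB_MAP now contains any entries persisted in data/uploads/db_map.json
print(f"restored {len(_DB_MAP)} uploaded DB entries")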
app/routers/nl2sql.py CHANGED

@@ -14,7 +14,9 @@ from adapters.db.sqlite_adapter import SQLiteAdapter
 from adapters.db.postgres_adapter import PostgresAdapter
 
 import os
+from pathlib import Path
 import time
+import json
 import uuid
 from typing import Union, Optional, Dict
 
@@ -40,6 +42,41 @@ DEFAULT_SQLITE_DB = os.getenv(
     "DEFAULT_SQLITE_DB", "data/chinook.db"
 )  # keep your current default
 
+# -------------------------------
+# Path to persist db_id → file map
+# -------------------------------
+_DB_MAP_PATH = Path("data/uploads/db_map.json")
+_DB_MAP_PATH.parent.mkdir(parents=True, exist_ok=True)
+
+
+UPLOAD_DIR = Path("data/uploads")
+UPLOAD_DIR.mkdir(parents=True, exist_ok=True)  # ensure folder exists
+
+DEFAULT_SQLITE_PATH = "data/Chinook_Sqlite.sqlite"
+
+
+def _save_db_map():
+    """Persist the in-memory DB map to disk as JSON."""
+    try:
+        with open(_DB_MAP_PATH, "w") as f:
+            json.dump(_DB_MAP, f)
+    except Exception as e:
+        print(f"⚠️ Failed to save DB map: {e}")
+
+
+def _load_db_map():
+    """Load the DB map from disk if it exists (called on startup)."""
+    global _DB_MAP
+    if _DB_MAP_PATH.exists():
+        try:
+            with open(_DB_MAP_PATH, "r") as f:
+                data = json.load(f)
+            if isinstance(data, dict):
+                _DB_MAP.update(data)
+            print(f"📂 Restored {_DB_MAP_PATH} with {len(_DB_MAP)} entries.")
+        except Exception as e:
+            print(f"⚠️ Failed to load DB map: {e}")
+
 
 def _cleanup_db_map() -> None:
     """Remove expired uploaded DB files (best-effort)."""
@@ -65,24 +102,39 @@ def _resolve_sqlite_path(db_id: Optional[str]) -> str:
     return DEFAULT_SQLITE_DB
 
 
-def _select_adapter(db_id:
-    """
-
-
-
-
-
-
+def _select_adapter(db_id: str | None):
+    mode = os.getenv("DB_MODE", "sqlite").lower()
+    if mode == "postgres":
+        dsn = os.environ.get("POSTGRES_DSN")
+        if not dsn:
+            raise HTTPException(status_code=500, detail="POSTGRES_DSN env is missing")
+        return PostgresAdapter(dsn)
+
+    # sqlite mode
+    if db_id:
+        _cleanup_db_map()
+        db_path = None
+        # first check runtime map
+        if db_id in _DB_MAP:
+            db_path = _DB_MAP[db_id].get("path")
+        # fallback: check /tmp or uploads
+        if not db_path or not os.path.exists(db_path):
+            fallback_tmp = os.path.join(_DB_UPLOAD_DIR, f"{db_id}.sqlite")
+            fallback_uploads = UPLOAD_DIR / f"{db_id}.sqlite"
+            for candidate in (fallback_tmp, fallback_uploads):
+                if os.path.exists(candidate):
+                    db_path = str(candidate)
+                    break
+        if not db_path or not os.path.exists(db_path):
             raise HTTPException(
-        status_code=
+                status_code=400, detail="invalid db_id (file not found)"
             )
-    return
+        return SQLiteAdapter(str(db_path))
 
-    #
-
-
-
-    return SQLiteAdapter(sqlite_path)
+    # fallback to default Chinook
+    if not Path(DEFAULT_SQLITE_PATH).exists():
+        raise HTTPException(status_code=500, detail="default DB not found")
+    return SQLiteAdapter(DEFAULT_SQLITE_PATH)
 
 
 # -------------------------------
@@ -171,6 +223,7 @@ async def upload_db(file: UploadFile = File(...)):
         raise HTTPException(status_code=500, detail=f"Failed to store DB: {e}")
 
     _DB_MAP[db_id] = {"path": out_path, "ts": time.time()}
+    _save_db_map()
     return {"db_id": db_id}
 
 
@@ -182,34 +235,53 @@ async def upload_db(file: UploadFile = File(...)):
 def nl2sql_handler(request: NL2SQLRequest):
     """
     Handle NL → SQL pipeline execution.
-
-
+    If `db_id` is provided, switch DB adapter for this call.
+    If `schema_preview` is missing, derive it from the selected adapter when possible.
     """
-    #
+    # 1) Select adapter based on db_id (if any)
     db_id = getattr(request, "db_id", None)  # Optional[str]
-    # Build per-request pipeline bound to the selected adapter
     adapter = _select_adapter(db_id)
     pipeline = _build_pipeline(adapter)
 
-
-
-
+    # 2) Resolve schema_preview (optional in request)
+    provided_preview = getattr(request, "schema_preview", None)
+    schema_preview = (
+        provided_preview
+        if provided_preview not in ("", None)
+        else _derive_schema_preview(adapter)
     )
 
-    #
+    # 3) Run pipeline
+    try:
+        result = pipeline.run(
+            user_query=request.query,  # assumes NL2SQLRequest has `query`
+            schema_preview=schema_preview,  # may be empty string if adapter can't derive
+        )
+    except Exception as exc:
+        # Hard failure in pipeline itself
+        raise HTTPException(status_code=500, detail=f"Pipeline crash: {exc!s}")
+
+    # 4) Type check
     if not isinstance(result, FinalResult):
         raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
 
-    # Ambiguity
+    # 5) Ambiguity → ask for clarification
     if result.ambiguous and result.questions:
         return ClarifyResponse(ambiguous=True, questions=result.questions)
 
-    #
+    # 6) Soft errors → bubble up details with 400
     if not result.ok or result.error:
-
-
+        print("❌ Pipeline failure dump:")
+        print("   ok:", result.ok)
+        print("   error:", result.error)
+        print("   details:", result.details)
+        print("   traces:", result.traces)
+        raise HTTPException(
+            status_code=400,
+            detail="; ".join(result.details or []) or (result.error or "Unknown error"),
+        )
 
-    # Success
+    # 7) Success
     traces = [_round_trace(t) for t in (result.traces or [])]
     return NL2SQLResponse(
         ambiguous=False,
@@ -217,3 +289,33 @@ def nl2sql_handler(request: NL2SQLRequest):
         rationale=result.rationale,
         traces=traces,
     )
+
+
+def _derive_schema_preview(adapter) -> str:
+    """
+    Build a strict, exact-cased schema preview for the LLM.
+    Works for SQLite adapters by querying sqlite_master / pragma table_info.
+    """
+    import sqlite3
+    import os
+
+    db_path = getattr(adapter, "db_path", None) or getattr(adapter, "path", None)
+    if not db_path or not os.path.exists(db_path):
+        return ""
+
+    try:
+        conn = sqlite3.connect(db_path)
+        cur = conn.cursor()
+        tables = cur.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
+        ).fetchall()
+        lines = []
+        for (tname,) in tables:
+            cols = cur.execute(f"PRAGMA table_info('{tname}')").fetchall()
+            # sqlite: pragma columns → (cid, name, type, notnull, dflt_value, pk)
+            colnames = [c[1] for c in cols]
+            lines.append(f"{tname}({', '.join(colnames)})")
+        conn.close()
+        return "\n".join(lines)
+    except Exception:
+        return ""
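With this change the frontend only needs `query`, plus an optional `db_id` from the upload endpoint; `schema_preview` is derived server-side when omitted. A hedged client sketch (the /nl2sql path, port, and the example db_id are placeholders, not confirmed by the diff):

import requests  # assumes requests is installed client-side

payload = {
    "query": "Top 5 customers by total invoice amount",
    "db_id": "abc123",  # placeholder: value returned by the upload endpoint
    # "schema_preview" omitted → server calls _derive_schema_preview(adapter)
}
resp = requests.post("http://localhost:8000/nl2sql", json=payload)  # path assumed
print(resp.status_code, resp.json())  # NL2SQLResponse or ClarifyResponse payload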
app/schemas.py CHANGED

@@ -4,8 +4,11 @@ from typing import List, Optional, Any, Dict, Mapping, Sequence
 
 class NL2SQLRequest(BaseModel):
     query: str
-
-
+    db_id: Optional[str] = None
+    schema_preview: Optional[str] = None
+
+    class Config:
+        extra = "ignore"
 
 
 class TraceModel(BaseModel):
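The `extra = "ignore"` config is what aligns the schema with the frontend: unknown keys are silently dropped instead of failing validation. A small sketch of that behavior (pydantic v1-style Config, matching the diff):

from typing import Optional
from pydantic import BaseModel

class NL2SQLRequest(BaseModel):
    query: str
    db_id: Optional[str] = None
    schema_preview: Optional[str] = None

    class Config:
        extra = "ignore"

# A stray frontend field no longer causes a 422:
req = NL2SQLRequest(query="count albums", theme="dark")
print(req.dict())  # {'query': 'count albums', 'db_id': None, 'schema_preview': None}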
nl2sql/pipeline.py CHANGED

@@ -31,7 +31,6 @@ class Pipeline:
     """
     NL2SQL Copilot pipeline.
     Stages return StageResult; final result is a type-safe FinalResult.
-    Adapters (e.g. FastAPI) can serialize with dataclasses.asdict().
     """
 
     def __init__(
@@ -71,9 +70,7 @@ class Pipeline:
             r = fn(**kwargs)
             if isinstance(r, StageResult):
                 return r
-
-            # Normalize non-StageResult returns
-            return StageResult(ok=True, data=r, trace=None)
+            return StageResult(ok=True, data=r, trace=None)
         except Exception as e:
             tb = traceback.format_exc()
             return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])
@@ -92,7 +89,7 @@ class Pipeline:
         rationale: Optional[str] = None
         verified: Optional[bool] = None
 
-        # --- 1) ambiguity detection
+        # --- 1) ambiguity detection ---
         try:
             questions = self.detector.detect(user_query, schema_preview)
             if questions:
@@ -120,7 +117,7 @@ class Pipeline:
                 traces=[],
             )
 
-        # --- 2) planner
+        # --- 2) planner ---
         r_plan = self._safe_stage(
            self.planner.run, user_query=user_query, schema_preview=schema_preview
        )
@@ -138,7 +135,7 @@ class Pipeline:
                traces=traces,
            )
 
-        # --- 3) generator
+        # --- 3) generator ---
        r_gen = self._safe_stage(
            self.generator.run,
            user_query=user_query,
@@ -159,10 +156,11 @@ class Pipeline:
                verified=None,
                traces=traces,
            )
+
        sql = (r_gen.data or {}).get("sql")
        rationale = (r_gen.data or {}).get("rationale")
 
-        # --- 4) safety
+        # --- 4) safety ---
        r_safe = self._safe_stage(self.safety.run, sql=sql)
        traces.extend(self._trace_list(r_safe))
        if not r_safe.ok:
@@ -178,7 +176,7 @@ class Pipeline:
                traces=traces,
            )
 
-        # --- 5) executor
+        # --- 5) executor ---
        r_exec = self._safe_stage(
            self.executor.run, sql=(r_safe.data or {}).get("sql", sql)
        )
@@ -186,14 +184,14 @@ class Pipeline:
        if not r_exec.ok:
            details.extend(r_exec.error or [])
 
-        # --- 6) verifier
+        # --- 6) verifier ---
        r_ver = self._safe_stage(
            self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
        )
        traces.extend(self._trace_list(r_ver))
-        verified = bool(r_ver.ok
+        verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
 
-        # --- 7) repair loop if verification failed
+        # --- 7) repair loop if verification failed ---
        if not verified:
            for _attempt in range(2):
                r_fix = self._safe_stage(
@@ -205,8 +203,8 @@ class Pipeline:
                traces.extend(self._trace_list(r_fix))
                if not r_fix.ok:
                    break
-                sql = (r_fix.data or {}).get("sql")
 
+                sql = (r_fix.data or {}).get("sql")
                r_safe = self._safe_stage(self.safety.run, sql=sql)
                traces.extend(self._trace_list(r_safe))
                if not r_safe.ok:
@@ -225,14 +223,45 @@ class Pipeline:
                    self.verifier.run, sql=sql, exec_result=(r_exec.data or {})
                )
                traces.extend(self._trace_list(r_ver))
-                verified = bool(r_ver.ok
+                verified = bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok
                if verified:
                    break
 
+        # --- 8) fallback: verifier silent but executor succeeded ---
+        if (verified is None or not verified) and not details:
+            any_exec = any(
+                t.get("stage") == "executor" and t.get("notes", {}).get("row_count")
+                for t in traces
+            )
+            if any_exec:
+                traces.append(
+                    {
+                        "stage": "pipeline",
+                        "notes": {
+                            "auto_fix": "verified=True (executor succeeded, verifier silent)"
+                        },
+                        "duration_ms": 0.0,
+                    }
+                )
+                verified = True
+
+        # --- 9) finalize result ---
+        has_errors = bool(details)
+        ok = bool(verified) and not has_errors
+        err = has_errors and not bool(verified)
+
+        traces.append(
+            {
+                "stage": "pipeline",
+                "notes": {"final_verified": verified, "details_len": len(details)},
+                "duration_ms": 0.0,
+            }
+        )
+
        return FinalResult(
-            ok=
+            ok=ok,
            ambiguous=False,
-            error=
+            error=err,
            details=details or None,
            sql=sql,
            rationale=rationale,
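The new `verified` expression trusts the verifier's explicit data flag first and falls back to its bare ok status. A standalone sketch of the resulting truth table (the dataclass stands in for nl2sql.types.StageResult):

from dataclasses import dataclass
from typing import Optional

@dataclass
class StageResult:  # stand-in for nl2sql.types.StageResult
    ok: bool
    data: Optional[dict] = None

def verified_flag(r_ver: StageResult) -> bool:
    return bool(r_ver.data and r_ver.data.get("verified")) or r_ver.ok  # as in the diff

print(verified_flag(StageResult(ok=True, data={"verified": True})))   # True
print(verified_flag(StageResult(ok=False, data={"verified": True})))  # True — data flag wins
print(verified_flag(StageResult(ok=True, data=None)))                 # True — ok alone suffices
print(verified_flag(StageResult(ok=False, data=None)))                # False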
nl2sql/safety.py CHANGED

@@ -19,10 +19,20 @@ _FORBIDDEN = re.compile(
 # allow: SELECT ... or WITH <cte...> SELECT ...
 _ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)
 
+# --- New cleanup helpers ---
+_FENCE_SQL = re.compile(r"```sql", re.IGNORECASE)
+_FENCE_ANY = re.compile(r"```")
 
-
+
+def _sanitize_sql(sql: str) -> str:
+    """Remove markdown fences, comments, and surrounding junk."""
+    s = _FENCE_SQL.sub("", sql)
+    s = _FENCE_ANY.sub("", s)
     s = _COMMENT_BLOCK.sub(" ", s)
     s = _COMMENT_LINE.sub(" ", s)
+    s = s.strip()
+    # remove trailing semicolon safely
+    s = s.rstrip(";").strip()
     return s
 
 
@@ -33,8 +43,13 @@ def _mask_strings(s: str) -> str:
 
 
 def _split_statements(s: str) -> list[str]:
+    """
+    Split only if there are real multiple statements,
+    ignoring harmless trailing semicolons or markdown.
+    """
     parts = [p.strip() for p in s.split(";")]
-
+    parts = [p for p in parts if p]
+    return parts
 
 
 class Safety:
@@ -43,7 +58,9 @@ class Safety:
     def check(self, sql: str) -> StageResult:
         t0 = time.perf_counter()
         print("🧩 SQL candidate:", sql)
-
+
+        # --- sanitize first ---
+        s = _sanitize_sql(sql)
         s = _mask_strings(s).strip()
 
         stmts = _split_statements(s)
@@ -79,8 +96,8 @@ class Safety:
         return StageResult(
             ok=True,
             data={
-                "sql":
-                "rationale": "Statement validated as SELECT-only (strings/comments ignored).",
+                "sql": body,
+                "rationale": "Statement validated as SELECT-only (strings/comments/markdown ignored).",
             },
             trace=StageTrace(
                 stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
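The sanitizer exists because LLMs often return fenced SQL with a trailing semicolon, which the old SELECT-only check could reject as a second statement. A quick sketch of the fence/semicolon cleanup in isolation (regexes copied from the diff; the comment-stripping step is omitted since _COMMENT_BLOCK/_COMMENT_LINE are defined outside this hunk):

import re

_FENCE_SQL = re.compile(r"```sql", re.IGNORECASE)
_FENCE_ANY = re.compile(r"```")

raw = "```sql\nSELECT Name FROM Artist LIMIT 5;\n```"
s = _FENCE_SQL.sub("", raw)
s = _FENCE_ANY.sub("", s)
s = s.strip().rstrip(";").strip()
print(s)  # SELECT Name FROM Artist LIMIT 5

After this cleanup, _split_statements sees a single non-empty statement and the SELECT-only gate passes.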
nl2sql/verifier.py CHANGED

@@ -1,3 +1,4 @@
+import time
 import sqlglot
 from sqlglot import expressions as exp
 from nl2sql.types import StageResult, StageTrace
@@ -6,18 +7,32 @@ from nl2sql.types import StageResult, StageTrace
 class Verifier:
     name = "verifier"
 
-    def run(self, sql: str, exec_result:
-
+    def run(self, sql: str, exec_result: dict | None) -> StageResult:
+        t0 = time.perf_counter()
+
+        # Defensive: check executor result validity
+        if not exec_result or not isinstance(exec_result, dict):
             return StageResult(
                 ok=False,
+                error=["invalid or missing exec_result"],
                 data=None,
                 trace=StageTrace(
-                stage=self.name, duration_ms=
+                    stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
                 ),
-            error=exec_result.error,
             )
 
-        #
+        # If executor had rows and no error, consider verified early
+        rows = exec_result.get("rows")
+        if rows is not None and len(rows) > 0:
+            return StageResult(
+                ok=True,
+                data={"verified": True, "rows_checked": len(rows)},
+                trace=StageTrace(
+                    stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
+                ),
+            )
+
+        # Optional deeper check using SQL structure
         issues = []
         try:
             tree = sqlglot.parse_one(sql)
@@ -25,21 +40,35 @@ class Verifier:
             group = tree.args.get("group")
             aggs = [a for a in tree.find_all(exp.AggFunc)]
             if aggs and not group:
-
+                select_cols = [
+                    c for c in tree.expressions if not isinstance(c, exp.AggFunc)
+                ]
+                if select_cols:
+                    issues.append(
+                        "Non-aggregated columns with aggregation but no GROUP BY."
+                    )
         except Exception as e:
-
+            # parsing failed → skip structural verification gracefully
+            return StageResult(
+                ok=True,
+                data={"verified": True, "note": f"Skipped parse: {e}"},
+                trace=StageTrace(
+                    stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
+                ),
+            )
 
+        dur = (time.perf_counter() - t0) * 1000
         if issues:
             return StageResult(
                 ok=False,
-
+                error=issues,
                 trace=StageTrace(
-                stage=self.name, duration_ms=
+                    stage=self.name, duration_ms=dur, notes={"issues": issues}
                 ),
-            error=issues,
             )
+
         return StageResult(
             ok=True,
             data={"verified": True},
-        trace=StageTrace(stage=self.name, duration_ms=
+            trace=StageTrace(stage=self.name, duration_ms=dur),
         )
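The early-accept path means any query that already returned rows skips structural checks entirely; the sqlglot heuristic only runs for empty results. A hedged sketch of the three outcomes (the exec_result shapes are assumptions inferred from the keys the verifier reads):

from nl2sql.verifier import Verifier

v = Verifier()

# 1) executor returned rows → verified early, no parsing
r = v.run("SELECT 1", {"rows": [(1,)]})
print(r.ok, r.data)  # True {'verified': True, 'rows_checked': 1}

# 2) empty result → falls through to the GROUP BY heuristic
r = v.run("SELECT COUNT(*), Name FROM Artist", {"rows": []})
print(r.ok)  # False — mixed aggregate and bare column without GROUP BY

# 3) missing executor payload → defensive failure
r = v.run("SELECT 1", None)
print(r.ok, r.error)  # False ['invalid or missing exec_result']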
|