Spaces:
Sleeping
Sleeping
github-actions[bot]
commited on
Commit
·
0ecc315
1
Parent(s):
d2d07a3
Sync from GitHub main @ 793782272bfcd6bdae9a711aabd0ec2b0aef2312
Browse files- adapters/db/base.py +2 -2
- adapters/db/postgres_adapter.py +5 -3
- adapters/db/sqlite_adapter.py +6 -2
- nl2sql/errors/codes.py +1 -0
- nl2sql/errors/mapper.py +1 -0
- nl2sql/executor.py +79 -0
adapters/db/base.py
CHANGED
|
@@ -13,5 +13,5 @@ class DBAdapter(Protocol):
|
|
| 13 |
def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
|
| 14 |
"""Execute a SELECT query and return (rows, columns)."""
|
| 15 |
|
| 16 |
-
def explain_query_plan(self, sql: str) ->
|
| 17 |
-
"""
|
|
|
|
| 13 |
def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
|
| 14 |
"""Execute a SELECT query and return (rows, columns)."""
|
| 15 |
|
| 16 |
+
def explain_query_plan(self, sql: str) -> List[str]:
|
| 17 |
+
"""Return a query plan preview (must be read-only). Raise on failure."""
|
adapters/db/postgres_adapter.py
CHANGED
|
@@ -69,7 +69,7 @@ class PostgresAdapter(DBAdapter):
|
|
| 69 |
cols: List[str] = [d[0] for d in desc if d]
|
| 70 |
return rows, cols
|
| 71 |
|
| 72 |
-
def explain_query_plan(self, sql: str) ->
|
| 73 |
sql_stripped = (sql or "").strip().rstrip(";")
|
| 74 |
if not sql_stripped.lower().startswith("select"):
|
| 75 |
raise ValueError("Only SELECT statements are allowed.")
|
|
@@ -79,5 +79,7 @@ class PostgresAdapter(DBAdapter):
|
|
| 79 |
with conn.cursor() as cur:
|
| 80 |
cur.execute("SET TRANSACTION READ ONLY;")
|
| 81 |
cur.execute(f"EXPLAIN {sql_stripped}")
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
| 69 |
cols: List[str] = [d[0] for d in desc if d]
|
| 70 |
return rows, cols
|
| 71 |
|
| 72 |
+
def explain_query_plan(self, sql: str) -> List[str]:
|
| 73 |
sql_stripped = (sql or "").strip().rstrip(";")
|
| 74 |
if not sql_stripped.lower().startswith("select"):
|
| 75 |
raise ValueError("Only SELECT statements are allowed.")
|
|
|
|
| 79 |
with conn.cursor() as cur:
|
| 80 |
cur.execute("SET TRANSACTION READ ONLY;")
|
| 81 |
cur.execute(f"EXPLAIN {sql_stripped}")
|
| 82 |
+
rows = cur.fetchall() or []
|
| 83 |
+
# psycopg returns rows like ("Seq Scan on ...",)
|
| 84 |
+
plan_lines: List[str] = [str(r[0]) for r in rows if r and len(r) >= 1]
|
| 85 |
+
return plan_lines
|
adapters/db/sqlite_adapter.py
CHANGED
|
@@ -45,7 +45,7 @@ class SQLiteAdapter(DBAdapter):
|
|
| 45 |
log.info("Query executed successfully. Returned %d rows.", len(rows))
|
| 46 |
return rows, cols
|
| 47 |
|
| 48 |
-
def explain_query_plan(self, sql: str) ->
|
| 49 |
if not self.path.exists():
|
| 50 |
raise FileNotFoundError(f"SQLite DB does not exist: {self.path}")
|
| 51 |
|
|
@@ -60,4 +60,8 @@ class SQLiteAdapter(DBAdapter):
|
|
| 60 |
conn.execute("PRAGMA query_only = ON;")
|
| 61 |
except Exception:
|
| 62 |
pass
|
| 63 |
-
conn.execute(f"EXPLAIN QUERY PLAN {sql_stripped}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
log.info("Query executed successfully. Returned %d rows.", len(rows))
|
| 46 |
return rows, cols
|
| 47 |
|
| 48 |
+
def explain_query_plan(self, sql: str) -> List[str]:
|
| 49 |
if not self.path.exists():
|
| 50 |
raise FileNotFoundError(f"SQLite DB does not exist: {self.path}")
|
| 51 |
|
|
|
|
| 60 |
conn.execute("PRAGMA query_only = ON;")
|
| 61 |
except Exception:
|
| 62 |
pass
|
| 63 |
+
cur = conn.execute(f"EXPLAIN QUERY PLAN {sql_stripped}")
|
| 64 |
+
rows = cur.fetchall() or []
|
| 65 |
+
# Rows are typically (id, parent, notused, detail)
|
| 66 |
+
plan_lines: List[str] = [str(r[-1]) for r in rows if r]
|
| 67 |
+
return plan_lines
|
nl2sql/errors/codes.py
CHANGED
|
@@ -14,6 +14,7 @@ class ErrorCode(str, Enum):
|
|
| 14 |
# --- Executor / DB ---
|
| 15 |
DB_LOCKED = "DB_LOCKED"
|
| 16 |
DB_TIMEOUT = "DB_TIMEOUT"
|
|
|
|
| 17 |
LLM_FAILURE = "LLM_FAILURE"
|
| 18 |
|
| 19 |
# --- LLM ---
|
|
|
|
| 14 |
# --- Executor / DB ---
|
| 15 |
DB_LOCKED = "DB_LOCKED"
|
| 16 |
DB_TIMEOUT = "DB_TIMEOUT"
|
| 17 |
+
EXECUTOR_COST_GUARDRAIL_BLOCKED = "EXECUTOR_COST_GUARDRAIL_BLOCKED"
|
| 18 |
LLM_FAILURE = "LLM_FAILURE"
|
| 19 |
|
| 20 |
# --- LLM ---
|
nl2sql/errors/mapper.py
CHANGED
|
@@ -8,6 +8,7 @@ ERROR_MAP = {
|
|
| 8 |
ErrorCode.PLAN_SYNTAX_ERROR: (422, False),
|
| 9 |
ErrorCode.DB_LOCKED: (503, True),
|
| 10 |
ErrorCode.DB_TIMEOUT: (503, True),
|
|
|
|
| 11 |
ErrorCode.LLM_TIMEOUT: (503, True),
|
| 12 |
ErrorCode.PIPELINE_CRASH: (500, False),
|
| 13 |
}
|
|
|
|
| 8 |
ErrorCode.PLAN_SYNTAX_ERROR: (422, False),
|
| 9 |
ErrorCode.DB_LOCKED: (503, True),
|
| 10 |
ErrorCode.DB_TIMEOUT: (503, True),
|
| 11 |
+
ErrorCode.EXECUTOR_COST_GUARDRAIL_BLOCKED: (422, False),
|
| 12 |
ErrorCode.LLM_TIMEOUT: (503, True),
|
| 13 |
ErrorCode.PIPELINE_CRASH: (500, False),
|
| 14 |
}
|
nl2sql/executor.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import time
|
| 2 |
from nl2sql.types import StageResult, StageTrace
|
|
|
|
| 3 |
from adapters.db.base import DBAdapter
|
| 4 |
|
| 5 |
|
|
@@ -9,8 +13,79 @@ class Executor:
|
|
| 9 |
def __init__(self, db: DBAdapter):
|
| 10 |
self.db = db
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
def run(self, sql: str) -> StageResult:
|
| 13 |
t0 = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
try:
|
| 15 |
rows, cols = self.db.execute(sql)
|
| 16 |
trace = StageTrace(
|
|
@@ -20,6 +95,8 @@ class Executor:
|
|
| 20 |
"row_count": len(rows),
|
| 21 |
"col_count": len(cols),
|
| 22 |
"sql_length": len(sql or ""),
|
|
|
|
|
|
|
| 23 |
},
|
| 24 |
)
|
| 25 |
return StageResult(
|
|
@@ -33,6 +110,8 @@ class Executor:
|
|
| 33 |
"error": str(e),
|
| 34 |
"error_type": type(e).__name__,
|
| 35 |
"sql_length": len(sql or ""),
|
|
|
|
|
|
|
| 36 |
},
|
| 37 |
)
|
| 38 |
return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
|
|
|
|
| 1 |
+
import sqlglot
|
| 2 |
+
from sqlglot import exp
|
| 3 |
+
|
| 4 |
import time
|
| 5 |
from nl2sql.types import StageResult, StageTrace
|
| 6 |
+
from nl2sql.errors.codes import ErrorCode
|
| 7 |
from adapters.db.base import DBAdapter
|
| 8 |
|
| 9 |
|
|
|
|
| 13 |
def __init__(self, db: DBAdapter):
|
| 14 |
self.db = db
|
| 15 |
|
| 16 |
+
def _preflight_cost_check(self, sql: str) -> tuple[bool, str, dict]:
|
| 17 |
+
"""Return (ok, reason, notes). Reason is machine-readable."""
|
| 18 |
+
sql_stripped = (sql or "").strip().rstrip(";")
|
| 19 |
+
notes: dict = {"sql_length": len(sql_stripped)}
|
| 20 |
+
if not sql_stripped:
|
| 21 |
+
return False, "empty_sql", notes
|
| 22 |
+
|
| 23 |
+
# Parse for cheap structural signals (LIMIT/JOIN/ORDER)
|
| 24 |
+
try:
|
| 25 |
+
tree = sqlglot.parse_one(
|
| 26 |
+
sql_stripped, read=getattr(self.db, "dialect", None) or "sqlite"
|
| 27 |
+
)
|
| 28 |
+
except Exception:
|
| 29 |
+
# Safety should usually catch parse errors; executor treats as reject.
|
| 30 |
+
return False, "parse_error", notes
|
| 31 |
+
|
| 32 |
+
has_limit = tree.find(exp.Limit) is not None
|
| 33 |
+
join_count = sum(1 for _ in tree.find_all(exp.Join))
|
| 34 |
+
has_order = tree.find(exp.Order) is not None
|
| 35 |
+
has_star = tree.find(exp.Star) is not None
|
| 36 |
+
notes.update(
|
| 37 |
+
{"has_limit": has_limit, "join_count": join_count, "has_order": has_order}
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Ask DB for a plan preview
|
| 41 |
+
try:
|
| 42 |
+
plan_lines = self.db.explain_query_plan(sql_stripped)
|
| 43 |
+
except Exception as e:
|
| 44 |
+
# Planning failures are treated as non-OK but not as cost guardrail.
|
| 45 |
+
notes.update({"plan_error": str(e), "plan_error_type": type(e).__name__})
|
| 46 |
+
return True, "plan_unavailable", notes
|
| 47 |
+
|
| 48 |
+
plan_preview = plan_lines[:6] if isinstance(plan_lines, list) else []
|
| 49 |
+
notes.update({"plan_preview": plan_preview})
|
| 50 |
+
|
| 51 |
+
plan_text = "".join(plan_lines).lower() if isinstance(plan_lines, list) else ""
|
| 52 |
+
full_scan = ("scan" in plan_text) and ("index" not in plan_text)
|
| 53 |
+
notes.update({"full_scan": full_scan})
|
| 54 |
+
|
| 55 |
+
# MVP heuristics
|
| 56 |
+
# Block only the highest-risk pattern for v1: full scan + no LIMIT + SELECT *
|
| 57 |
+
if full_scan and (not has_limit) and has_star:
|
| 58 |
+
return False, "full_scan_without_limit", notes
|
| 59 |
+
# Very high join count is a strong proxy for expensive queries
|
| 60 |
+
if join_count >= 6:
|
| 61 |
+
return False, "too_many_joins", notes
|
| 62 |
+
return True, "ok", notes
|
| 63 |
+
|
| 64 |
def run(self, sql: str) -> StageResult:
|
| 65 |
t0 = time.perf_counter()
|
| 66 |
+
|
| 67 |
+
preflight_ok, preflight_reason, preflight_notes = self._preflight_cost_check(
|
| 68 |
+
sql
|
| 69 |
+
)
|
| 70 |
+
if not preflight_ok:
|
| 71 |
+
trace = StageTrace(
|
| 72 |
+
stage=self.name,
|
| 73 |
+
duration_ms=(time.perf_counter() - t0) * 1000,
|
| 74 |
+
summary="blocked",
|
| 75 |
+
notes={
|
| 76 |
+
**preflight_notes,
|
| 77 |
+
"blocked_reason": preflight_reason,
|
| 78 |
+
},
|
| 79 |
+
)
|
| 80 |
+
return StageResult(
|
| 81 |
+
ok=False,
|
| 82 |
+
data=None,
|
| 83 |
+
trace=trace,
|
| 84 |
+
error=[preflight_reason],
|
| 85 |
+
error_code=ErrorCode.EXECUTOR_COST_GUARDRAIL_BLOCKED,
|
| 86 |
+
retryable=False,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
try:
|
| 90 |
rows, cols = self.db.execute(sql)
|
| 91 |
trace = StageTrace(
|
|
|
|
| 95 |
"row_count": len(rows),
|
| 96 |
"col_count": len(cols),
|
| 97 |
"sql_length": len(sql or ""),
|
| 98 |
+
"preflight": preflight_reason,
|
| 99 |
+
**preflight_notes,
|
| 100 |
},
|
| 101 |
)
|
| 102 |
return StageResult(
|
|
|
|
| 110 |
"error": str(e),
|
| 111 |
"error_type": type(e).__name__,
|
| 112 |
"sql_length": len(sql or ""),
|
| 113 |
+
"preflight": preflight_reason,
|
| 114 |
+
**preflight_notes,
|
| 115 |
},
|
| 116 |
)
|
| 117 |
return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
|