github-actions[bot] committed on
Commit
0ecc315
·
1 Parent(s): d2d07a3

Sync from GitHub main @ 793782272bfcd6bdae9a711aabd0ec2b0aef2312

Browse files
adapters/db/base.py CHANGED
@@ -13,5 +13,5 @@ class DBAdapter(Protocol):
13
  def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
14
  """Execute a SELECT query and return (rows, columns)."""
15
 
16
- def explain_query_plan(self, sql: str) -> None:
17
- """Validate SQL by asking the DB to plan it (must be read-only). Raise on failure."""
 
13
  def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
14
  """Execute a SELECT query and return (rows, columns)."""
15
 
16
+ def explain_query_plan(self, sql: str) -> List[str]:
17
+ """Return a query plan preview (must be read-only). Raise on failure."""
adapters/db/postgres_adapter.py CHANGED
@@ -69,7 +69,7 @@ class PostgresAdapter(DBAdapter):
69
  cols: List[str] = [d[0] for d in desc if d]
70
  return rows, cols
71
 
72
- def explain_query_plan(self, sql: str) -> None:
73
  sql_stripped = (sql or "").strip().rstrip(";")
74
  if not sql_stripped.lower().startswith("select"):
75
  raise ValueError("Only SELECT statements are allowed.")
@@ -79,5 +79,7 @@ class PostgresAdapter(DBAdapter):
79
  with conn.cursor() as cur:
80
  cur.execute("SET TRANSACTION READ ONLY;")
81
  cur.execute(f"EXPLAIN {sql_stripped}")
82
- # We don't need the output; if planning fails, it raises.
83
- _ = cur.fetchall()
 
 
 
69
  cols: List[str] = [d[0] for d in desc if d]
70
  return rows, cols
71
 
72
+ def explain_query_plan(self, sql: str) -> List[str]:
73
  sql_stripped = (sql or "").strip().rstrip(";")
74
  if not sql_stripped.lower().startswith("select"):
75
  raise ValueError("Only SELECT statements are allowed.")
 
79
  with conn.cursor() as cur:
80
  cur.execute("SET TRANSACTION READ ONLY;")
81
  cur.execute(f"EXPLAIN {sql_stripped}")
82
+ rows = cur.fetchall() or []
83
+ # psycopg returns rows like ("Seq Scan on ...",)
84
+ plan_lines: List[str] = [str(r[0]) for r in rows if r and len(r) >= 1]
85
+ return plan_lines
adapters/db/sqlite_adapter.py CHANGED
@@ -45,7 +45,7 @@ class SQLiteAdapter(DBAdapter):
45
  log.info("Query executed successfully. Returned %d rows.", len(rows))
46
  return rows, cols
47
 
48
- def explain_query_plan(self, sql: str) -> None:
49
  if not self.path.exists():
50
  raise FileNotFoundError(f"SQLite DB does not exist: {self.path}")
51
 
@@ -60,4 +60,8 @@ class SQLiteAdapter(DBAdapter):
60
  conn.execute("PRAGMA query_only = ON;")
61
  except Exception:
62
  pass
63
- conn.execute(f"EXPLAIN QUERY PLAN {sql_stripped}")
 
 
 
 
 
45
  log.info("Query executed successfully. Returned %d rows.", len(rows))
46
  return rows, cols
47
 
48
+ def explain_query_plan(self, sql: str) -> List[str]:
49
  if not self.path.exists():
50
  raise FileNotFoundError(f"SQLite DB does not exist: {self.path}")
51
 
 
60
  conn.execute("PRAGMA query_only = ON;")
61
  except Exception:
62
  pass
63
+ cur = conn.execute(f"EXPLAIN QUERY PLAN {sql_stripped}")
64
+ rows = cur.fetchall() or []
65
+ # Rows are typically (id, parent, notused, detail)
66
+ plan_lines: List[str] = [str(r[-1]) for r in rows if r]
67
+ return plan_lines
nl2sql/errors/codes.py CHANGED
@@ -14,6 +14,7 @@ class ErrorCode(str, Enum):
14
  # --- Executor / DB ---
15
  DB_LOCKED = "DB_LOCKED"
16
  DB_TIMEOUT = "DB_TIMEOUT"
 
17
  LLM_FAILURE = "LLM_FAILURE"
18
 
19
  # --- LLM ---
 
14
  # --- Executor / DB ---
15
  DB_LOCKED = "DB_LOCKED"
16
  DB_TIMEOUT = "DB_TIMEOUT"
17
+ EXECUTOR_COST_GUARDRAIL_BLOCKED = "EXECUTOR_COST_GUARDRAIL_BLOCKED"
18
  LLM_FAILURE = "LLM_FAILURE"
19
 
20
  # --- LLM ---
nl2sql/errors/mapper.py CHANGED
@@ -8,6 +8,7 @@ ERROR_MAP = {
8
  ErrorCode.PLAN_SYNTAX_ERROR: (422, False),
9
  ErrorCode.DB_LOCKED: (503, True),
10
  ErrorCode.DB_TIMEOUT: (503, True),
 
11
  ErrorCode.LLM_TIMEOUT: (503, True),
12
  ErrorCode.PIPELINE_CRASH: (500, False),
13
  }
 
8
  ErrorCode.PLAN_SYNTAX_ERROR: (422, False),
9
  ErrorCode.DB_LOCKED: (503, True),
10
  ErrorCode.DB_TIMEOUT: (503, True),
11
+ ErrorCode.EXECUTOR_COST_GUARDRAIL_BLOCKED: (422, False),
12
  ErrorCode.LLM_TIMEOUT: (503, True),
13
  ErrorCode.PIPELINE_CRASH: (500, False),
14
  }
nl2sql/executor.py CHANGED
@@ -1,5 +1,9 @@
 
 
 
1
  import time
2
  from nl2sql.types import StageResult, StageTrace
 
3
  from adapters.db.base import DBAdapter
4
 
5
 
@@ -9,8 +13,79 @@ class Executor:
9
  def __init__(self, db: DBAdapter):
10
  self.db = db
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def run(self, sql: str) -> StageResult:
13
  t0 = time.perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  try:
15
  rows, cols = self.db.execute(sql)
16
  trace = StageTrace(
@@ -20,6 +95,8 @@ class Executor:
20
  "row_count": len(rows),
21
  "col_count": len(cols),
22
  "sql_length": len(sql or ""),
 
 
23
  },
24
  )
25
  return StageResult(
@@ -33,6 +110,8 @@ class Executor:
33
  "error": str(e),
34
  "error_type": type(e).__name__,
35
  "sql_length": len(sql or ""),
 
 
36
  },
37
  )
38
  return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
 
1
+ import sqlglot
2
+ from sqlglot import exp
3
+
4
  import time
5
  from nl2sql.types import StageResult, StageTrace
6
+ from nl2sql.errors.codes import ErrorCode
7
  from adapters.db.base import DBAdapter
8
 
9
 
 
13
  def __init__(self, db: DBAdapter):
14
  self.db = db
15
 
16
def _preflight_cost_check(self, sql: str) -> tuple[bool, str, dict]:
    """Cheaply decide whether *sql* looks too expensive to execute.

    Combines structural signals from the parsed statement (LIMIT, JOIN
    count, ORDER BY, ``*``) with the adapter's query-plan preview.

    Args:
        sql: The SQL statement about to be executed. Surrounding
            whitespace and trailing semicolons are ignored.

    Returns:
        ``(ok, reason, notes)``. ``reason`` is a machine-readable string:
        ``"ok"``, ``"plan_unavailable"``, ``"empty_sql"``,
        ``"parse_error"``, ``"full_scan_without_limit"`` or
        ``"too_many_joins"``. ``notes`` is a dict of diagnostics
        collected along the way.
    """
    # MVP thresholds, named so the heuristics below read as policy.
    max_joins = 6
    preview_lines = 6

    sql_stripped = (sql or "").strip().rstrip(";")
    notes: dict = {"sql_length": len(sql_stripped)}
    if not sql_stripped:
        return False, "empty_sql", notes

    # Parse for cheap structural signals (LIMIT/JOIN/ORDER/*).
    try:
        tree = sqlglot.parse_one(
            sql_stripped, read=getattr(self.db, "dialect", None) or "sqlite"
        )
    except Exception:
        # Safety should usually catch parse errors; executor treats as reject.
        return False, "parse_error", notes

    has_limit = tree.find(exp.Limit) is not None
    join_count = sum(1 for _ in tree.find_all(exp.Join))
    has_order = tree.find(exp.Order) is not None
    has_star = tree.find(exp.Star) is not None
    notes.update(
        {
            "has_limit": has_limit,
            "join_count": join_count,
            "has_order": has_order,
            # Recorded because it participates in the block decision below.
            "has_star": has_star,
        }
    )

    # Ask DB for a plan preview.
    try:
        plan_lines = self.db.explain_query_plan(sql_stripped)
    except Exception as e:
        # Fail open: a planning failure is recorded in the notes but is not
        # treated as a cost-guardrail violation, so the check still passes.
        notes.update({"plan_error": str(e), "plan_error_type": type(e).__name__})
        return True, "plan_unavailable", notes

    # Anything other than a list is treated as "no plan available".
    if not isinstance(plan_lines, list):
        plan_lines = []
    notes["plan_preview"] = plan_lines[:preview_lines]

    # Join with a newline so words from adjacent plan lines cannot fuse into
    # a spurious "scan"/"index" match.
    plan_text = "\n".join(str(line) for line in plan_lines).lower()
    full_scan = ("scan" in plan_text) and ("index" not in plan_text)
    notes["full_scan"] = full_scan

    # MVP heuristics.
    # Block only the highest-risk pattern for v1: full scan + no LIMIT + SELECT *
    if full_scan and (not has_limit) and has_star:
        return False, "full_scan_without_limit", notes
    # Very high join count is a strong proxy for expensive queries.
    if join_count >= max_joins:
        return False, "too_many_joins", notes
    return True, "ok", notes
63
+
64
  def run(self, sql: str) -> StageResult:
65
  t0 = time.perf_counter()
66
+
67
+ preflight_ok, preflight_reason, preflight_notes = self._preflight_cost_check(
68
+ sql
69
+ )
70
+ if not preflight_ok:
71
+ trace = StageTrace(
72
+ stage=self.name,
73
+ duration_ms=(time.perf_counter() - t0) * 1000,
74
+ summary="blocked",
75
+ notes={
76
+ **preflight_notes,
77
+ "blocked_reason": preflight_reason,
78
+ },
79
+ )
80
+ return StageResult(
81
+ ok=False,
82
+ data=None,
83
+ trace=trace,
84
+ error=[preflight_reason],
85
+ error_code=ErrorCode.EXECUTOR_COST_GUARDRAIL_BLOCKED,
86
+ retryable=False,
87
+ )
88
+
89
  try:
90
  rows, cols = self.db.execute(sql)
91
  trace = StageTrace(
 
95
  "row_count": len(rows),
96
  "col_count": len(cols),
97
  "sql_length": len(sql or ""),
98
+ "preflight": preflight_reason,
99
+ **preflight_notes,
100
  },
101
  )
102
  return StageResult(
 
110
  "error": str(e),
111
  "error_type": type(e).__name__,
112
  "sql_length": len(sql or ""),
113
+ "preflight": preflight_reason,
114
+ **preflight_notes,
115
  },
116
  )
117
  return StageResult(ok=False, data=None, trace=trace, error=[str(e)])