Melika Kheirieh committed on
Commit
b72c625
·
1 Parent(s): 79a5f4a

fix(verifier): robust aggregate detection and projection-level semantic check

Browse files
Files changed (5) hide show
  1. .coverage +0 -0
  2. nl2sql/safety.py +228 -83
  3. nl2sql/verifier.py +240 -81
  4. tests/test_safety.py +50 -0
  5. tests/test_verifier.py +59 -19
.coverage CHANGED
Binary files a/.coverage and b/.coverage differ
 
nl2sql/safety.py CHANGED
@@ -2,143 +2,288 @@ from __future__ import annotations
2
 
3
  import re
4
  import time
5
- import unicodedata
 
 
 
6
  from nl2sql.types import StageResult, StageTrace
7
 
8
- # --- Regex utils ---
9
- _COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
10
- _COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # String literals (single & double quotes), allow escaped quotes
13
- _STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
14
- _STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
15
 
16
- # Case-insensitive, word-boundary forbidden keywords
17
- _FORBIDDEN = re.compile(
18
- r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b",
 
 
 
19
  re.IGNORECASE,
20
  )
21
 
22
- # Allow: SELECT ... or WITH (one or many CTEs, optional RECURSIVE) ... SELECT ...
23
- _ALLOW_SELECT = re.compile(
24
- r"^(?:WITH\s+(?:RECURSIVE\s+)?"
25
- r".*?\)\s*(?:,\s*.*?\)\s*)*"
26
- r")?SELECT\b",
27
- re.IGNORECASE | re.DOTALL,
28
- )
29
 
30
- # Optional allowance: EXPLAIN SELECT ...
31
- _ALLOW_EXPLAIN_SELECT = re.compile(r"^EXPLAIN\s+SELECT\b", re.IGNORECASE | re.DOTALL)
 
 
 
 
 
 
32
 
33
- # --- Cleanup helpers ---
34
- _FENCE_SQL = re.compile(r"```sql", re.IGNORECASE)
35
- _FENCE_ANY = re.compile(r"```")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
 
37
 
38
- def _normalize_sql(sql: str) -> str:
39
- """Normalize to NFKC and strip zero-width characters to prevent obfuscation."""
40
- s = unicodedata.normalize("NFKC", sql)
41
- # strip common zero-width spaces/joiners
42
- return (
43
- s.replace("\u200b", "")
44
- .replace("\u200c", "")
45
- .replace("\u200d", "")
46
- .replace("\ufeff", "")
47
- )
48
 
 
 
49
 
50
- def _sanitize_sql(sql: str) -> str:
51
- """Remove markdown fences, comments, and harmless trailing semicolons."""
52
- s = _normalize_sql(sql)
53
- s = _FENCE_SQL.sub("", s)
54
- s = _FENCE_ANY.sub("", s)
55
- s = _COMMENT_BLOCK.sub(" ", s)
56
- s = _COMMENT_LINE.sub(" ", s)
57
- s = s.strip()
58
- # remove trailing semicolon safely
59
- s = s.rstrip(";").strip()
60
- return s
61
 
 
 
 
62
 
63
- def _mask_strings(s: str) -> str:
64
- """Replace string literals so that inner semicolons/keywords don't affect checks."""
65
- s = _STRING_SINGLE.sub("'X'", s)
66
- s = _STRING_DOUBLE.sub('"X"', s)
67
- return s
68
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- def _split_statements(s: str) -> list[str]:
 
71
  """
72
- Split on semicolons after string-masking. Ignore empties (e.g., trailing ';').
73
  """
74
- parts = [p.strip() for p in s.split(";")]
75
- return [p for p in parts if p]
 
 
 
 
 
76
 
77
 
78
- def _ms(t0: float) -> int:
79
- return int((time.perf_counter() - t0) * 1000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
 
82
  class Safety:
 
 
 
 
 
83
  name = "safety"
84
 
85
- def __init__(self, allow_explain: bool = False) -> None:
86
- """
87
- :param allow_explain: If True, 'EXPLAIN SELECT ...' is allowed in addition to SELECT.
88
- """
89
  self.allow_explain = allow_explain
90
 
91
  def check(self, sql: str) -> StageResult:
92
  t0 = time.perf_counter()
93
 
94
- # 1) Sanitize and mask
95
- s = _sanitize_sql(sql)
96
- s = _mask_strings(s).strip()
97
-
98
- # 2) Multiple statements check
99
- stmts = _split_statements(s)
100
- if len(stmts) != 1:
101
  return StageResult(
102
  ok=False,
103
- error=["Multiple statements detected"],
 
 
 
 
 
 
104
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
105
  )
106
 
107
- body = stmts[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- # 3) Forbidden keyword check (report exact offending token)
110
- m = _FORBIDDEN.search(body)
 
111
  if m:
 
112
  return StageResult(
113
  ok=False,
114
- error=[f"Forbidden keyword detected: '{m.group(0)}'"],
115
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
116
  )
 
 
 
 
 
 
 
 
 
117
 
118
- # 4) Allow only SELECT (or optionally EXPLAIN SELECT)
119
- allowed = bool(_ALLOW_SELECT.match(body))
120
- if not allowed and self.allow_explain:
121
- allowed = bool(_ALLOW_EXPLAIN_SELECT.match(body))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
- if not allowed:
124
  return StageResult(
125
  ok=False,
126
- error=["Non-SELECT statement"],
127
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
128
  )
129
 
130
- # 5) Success
 
 
 
 
 
 
 
131
  return StageResult(
132
  ok=True,
133
  data={
134
  "sql": body,
135
- "rationale": (
136
- "Statement validated as SELECT-only (strings/comments/markdown ignored)."
137
- + (" EXPLAIN SELECT allowed." if self.allow_explain else "")
138
- ),
139
  },
140
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
141
  )
142
 
143
- # Backward-compat alias
144
- run = check
 
 
2
 
3
  import re
4
  import time
5
+ from typing import List, Pattern
6
+
7
+ import sqlglot
8
+
9
  from nl2sql.types import StageResult, StageTrace
10
 
11
+ # ------------------------- Zero-width & basic regexes -------------------------
12
+
13
+ _ZERO_WIDTH = [
14
+ "\u200b",
15
+ "\u200c",
16
+ "\u200d",
17
+ "\ufeff",
18
+ "\u2060",
19
+ "\u180e",
20
+ "\u200e",
21
+ "\u200f",
22
+ ]
23
+ _ZERO_WIDTH_RE = re.compile("|".join(map(re.escape, _ZERO_WIDTH)))
24
+
25
+ # String / comment regexes
26
+ _STR_SINGLE_RE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
27
+ _STR_DOUBLE_RE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
28
+ _LINE_COMMENT_RE = re.compile(r"--[^\n]*")
29
+ _BLOCK_COMMENT_RE = re.compile(r"/\*.*?\*/", re.DOTALL)
30
 
31
+ # Markdown code fences: ```sql\n ... \n```
32
+ _FENCE_RE = re.compile(r"^\s*```[a-zA-Z]*\n(?P<body>.*)\n```\s*$", re.DOTALL)
 
33
 
34
+ # Strict forbidden keywords (word boundaries)
35
+ _FORBIDDEN: Pattern[str] = re.compile(
36
+ r"\b("
37
+ r"delete|update|insert|drop|create|alter|truncate|merge|"
38
+ r"grant|revoke|execute|call|copy|attach|pragma|reindex|vacuum|replace"
39
+ r")\b",
40
  re.IGNORECASE,
41
  )
42
 
 
 
 
 
 
 
 
43
 
44
+ def _loose_keyword(pattern: str) -> Pattern[str]:
45
+ r"""
46
+ Build a regex that allows arbitrary whitespace between characters of a keyword.
47
+ Example: "insert" -> i\s*n\s*s\s*e\s*r\s*t
48
+ """
49
+ chars = r"\s*".join(list(pattern))
50
+ return re.compile(rf"\b{chars}\b", re.IGNORECASE)
51
+
52
 
53
+ _FORBIDDEN_LOOSE: List[Pattern[str]] = [
54
+ _loose_keyword(w)
55
+ for w in [
56
+ "delete",
57
+ "update",
58
+ "insert",
59
+ "drop",
60
+ "create",
61
+ "alter",
62
+ "truncate",
63
+ "merge",
64
+ "grant",
65
+ "revoke",
66
+ "execute",
67
+ "call",
68
+ "copy",
69
+ "attach",
70
+ "pragma",
71
+ "reindex",
72
+ "vacuum",
73
+ "replace",
74
+ ]
75
+ ]
76
 
77
+ _MAX_SQL_LEN = 200_000 # defensive bound against catastrophic inputs
78
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ def _ms(t0: float) -> int:
81
+ return int((time.perf_counter() - t0) * 1000)
82
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ def _strip_fences(sql: str) -> str:
85
+ m = _FENCE_RE.match(sql)
86
+ return m.group("body") if m else sql
87
 
 
 
 
 
 
88
 
89
+ def _collapse_trailing_semicolons(body: str) -> str:
90
+ """
91
+ Keep at most one trailing semicolon. This makes 'SELECT 1;;' equivalent to 'SELECT 1;'.
92
+ """
93
+ body = body.rstrip()
94
+ had_any = False
95
+ while body.endswith(";"):
96
+ had_any = True
97
+ body = body[:-1].rstrip()
98
+ return (body + ";") if had_any else body
99
 
100
+
101
+ def _sanitize(sql: str) -> str:
102
  """
103
+ Remove zero-width chars, strip markdown fences, trim, and normalize trailing semicolons.
104
  """
105
+ if not sql:
106
+ return ""
107
+ sql = _ZERO_WIDTH_RE.sub("", sql)
108
+ sql = _strip_fences(sql)
109
+ sql = sql.strip()
110
+ sql = _collapse_trailing_semicolons(sql)
111
+ return sql
112
 
113
 
114
+ def _remove_comments(body: str) -> str:
115
+ body = _BLOCK_COMMENT_RE.sub("", body)
116
+ body = _LINE_COMMENT_RE.sub("", body)
117
+ return body
118
+
119
+
120
+ def _strip_strings(body: str) -> str:
121
+ """
122
+ Remove string literals (so forbidden keyword checks won't fire on quoted text).
123
+ """
124
+ body = _STR_SINGLE_RE.sub("''", body)
125
+ body = _STR_DOUBLE_RE.sub('""', body)
126
+ return body
127
+
128
+
129
+ def _count_statements_semicolon(body: str) -> int:
130
+ """
131
+ Count statements by semicolons after removing comments and masking strings.
132
+ """
133
+ masked_strings = _STR_SINGLE_RE.sub("'S'", body)
134
+ masked_strings = _STR_DOUBLE_RE.sub('"S"', masked_strings)
135
+ no_comments = _remove_comments(masked_strings)
136
+ parts = [p.strip() for p in no_comments.split(";")]
137
+ non_empty = [p for p in parts if p]
138
+ return len(non_empty) if non_empty else 0
139
+
140
+
141
+ def _count_statements_sqlglot(body: str) -> int:
142
+ """
143
+ Count statements via sqlglot parser after removing comments.
144
+ """
145
+ try:
146
+ trees = sqlglot.parse(_remove_comments(body))
147
+ return len([t for t in trees if t is not None])
148
+ except Exception:
149
+ # If parse fails, conservatively return 1 to avoid double blocking.
150
+ return 1
151
 
152
 
153
  class Safety:
154
+ """
155
+ Read-only safety: allow only single-statement SELECT/EXPLAIN (configurable),
156
+ block DML/DDL and multi-statements, detect obfuscations.
157
+ """
158
+
159
  name = "safety"
160
 
161
+ def __init__(self, allow_explain: bool = True) -> None:
 
 
 
162
  self.allow_explain = allow_explain
163
 
164
  def check(self, sql: str) -> StageResult:
165
  t0 = time.perf_counter()
166
 
167
+ # 0) nil / size guard
168
+ if not sql or not sql.strip():
 
 
 
 
 
169
  return StageResult(
170
  ok=False,
171
+ error=["empty_sql"],
172
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
173
+ )
174
+ if len(sql) > _MAX_SQL_LEN:
175
+ return StageResult(
176
+ ok=False,
177
+ error=["sql_too_long"],
178
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
179
  )
180
 
181
+ # 1) sanitize
182
+ body = _sanitize(sql)
183
+
184
+ # 2) single-statement check (semicolon + parser)
185
+ semicolon_count = _count_statements_semicolon(body)
186
+ glot_count = _count_statements_sqlglot(body)
187
+ if semicolon_count != 1 or glot_count != 1:
188
+ return StageResult(
189
+ ok=False,
190
+ error=["Multiple statements detected"],
191
+ trace=StageTrace(
192
+ stage=self.name,
193
+ duration_ms=_ms(t0),
194
+ notes={
195
+ "semicolon_count": semicolon_count,
196
+ "parser_count": glot_count,
197
+ },
198
+ ),
199
+ )
200
 
201
+ # 3) forbidden keywords (ignore inside string literals)
202
+ scan_body = _strip_strings(body)
203
+ m = _FORBIDDEN.search(scan_body)
204
  if m:
205
+ tok = m.group(0).strip().lower()
206
  return StageResult(
207
  ok=False,
208
+ error=[f"Forbidden: {tok}"],
209
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
210
  )
211
+ for rx in _FORBIDDEN_LOOSE:
212
+ m2 = rx.search(scan_body)
213
+ if m2:
214
+ tok = m2.group(0).strip().lower()
215
+ return StageResult(
216
+ ok=False,
217
+ error=[f"Forbidden: {tok}"],
218
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
219
+ )
220
 
221
+ # 4) read-only root kind (SELECT/EXPLAIN[/WITH])
222
+ try:
223
+ trees = sqlglot.parse(body)
224
+ root = trees[0]
225
+ except Exception as e:
226
+ return StageResult(
227
+ ok=False,
228
+ error=["parse_error"],
229
+ trace=StageTrace(
230
+ stage=self.name, duration_ms=_ms(t0), notes={"parse_error": str(e)}
231
+ ),
232
+ )
233
+
234
+ root_type = type(root).__name__.lower()
235
+
236
+ # Manual EXPLAIN handling for dialects that parse EXPLAIN to Command
237
+ _EXPLAIN_HEAD_RE = re.compile(r"^\s*explain\s+", re.IGNORECASE)
238
+ if self.allow_explain and _EXPLAIN_HEAD_RE.match(body):
239
+ remainder = _EXPLAIN_HEAD_RE.sub("", body, count=1).lstrip()
240
+ try:
241
+ t2 = sqlglot.parse_one(remainder)
242
+ t2_type = type(t2).__name__.lower() if t2 else ""
243
+ if t2_type in {"select", "with"}:
244
+ return StageResult(
245
+ ok=True,
246
+ data={
247
+ "sql": body,
248
+ "original_len": len(sql),
249
+ "sanitized_len": len(body),
250
+ "allow_explain": True,
251
+ },
252
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
253
+ )
254
+ except Exception:
255
+ # fall through to normal handling
256
+ pass
257
+
258
+ is_select_like = root_type in {"select", "with"}
259
+ is_explain = root_type == "explain"
260
 
261
+ if is_explain and not self.allow_explain:
262
  return StageResult(
263
  ok=False,
264
+ error=["EXPLAIN not allowed"],
265
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
266
  )
267
 
268
+ if not (is_select_like or (is_explain and self.allow_explain)):
269
+ return StageResult(
270
+ ok=False,
271
+ error=[f"Non-SELECT statement: {root_type}"],
272
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
273
+ )
274
+
275
+ # 5) success
276
  return StageResult(
277
  ok=True,
278
  data={
279
  "sql": body,
280
+ "original_len": len(sql),
281
+ "sanitized_len": len(body),
282
+ "allow_explain": self.allow_explain,
 
283
  },
284
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
285
  )
286
 
287
+ # Keep Pipeline API compatibility (pipeline calls .run(sql=...))
288
+ def run(self, *, sql: str) -> StageResult:
289
+ return self.check(sql)
nl2sql/verifier.py CHANGED
@@ -1,5 +1,8 @@
 
 
 
1
  import time
2
- from typing import Any, Iterable
3
 
4
  import sqlglot
5
  from sqlglot import expressions as exp
@@ -7,108 +10,264 @@ from sqlglot import expressions as exp
7
  from nl2sql.types import StageResult, StageTrace
8
 
9
 
 
 
 
 
10
  class Verifier:
11
  name = "verifier"
12
 
13
- # ----------------- helpers -----------------
14
- @staticmethod
15
- def _extract_ok(exec_result: Any) -> bool | None:
16
- """Normalize exec_result.ok across dict or object."""
17
- if exec_result is None:
18
- return None
19
- if isinstance(exec_result, dict):
20
- return bool(exec_result.get("ok")) if "ok" in exec_result else None
21
- if hasattr(exec_result, "ok"):
22
- try:
23
- return bool(getattr(exec_result, "ok"))
24
- except Exception:
25
- return None
 
 
 
 
 
 
 
 
 
 
 
26
  return None
27
 
28
- @staticmethod
29
- def _extract_errors(exec_result: Any) -> list[str] | None:
30
- """Pull ['...'] from exec_result['error'] or exec_result.error."""
31
- val = None
32
- if isinstance(exec_result, dict):
33
- val = exec_result.get("error")
34
- elif hasattr(exec_result, "error"):
35
- val = getattr(exec_result, "error")
36
-
37
- if val is None:
38
- return None
39
- if isinstance(val, str):
40
- return [val]
41
- if isinstance(val, Iterable):
42
- # normalize to list[str]
43
- return [str(x) for x in val]
44
- return [str(val)]
45
-
46
- @staticmethod
47
- def _has_aggregation(tree: exp.Expression) -> bool:
48
- for node in tree.walk():
49
- if getattr(node, "is_aggregate", False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  return True
51
- if isinstance(node, (exp.Count, exp.Sum, exp.Avg, exp.Min, exp.Max)):
 
52
  return True
53
  return False
54
 
55
- @staticmethod
56
- def _has_group_by(select: exp.Select) -> bool:
57
- return bool(select.args.get("group"))
 
 
 
 
 
 
 
58
 
59
- # ------------------- main -------------------
60
- def run(self, *, sql: str, exec_result: Any) -> StageResult:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  t0 = time.perf_counter()
 
62
 
63
- # 1) validate / normalize executor result
64
- ok_flag = self._extract_ok(exec_result)
65
- if ok_flag is False:
66
- errs = self._extract_errors(exec_result) or ["execution_error"]
67
- trace_err = StageTrace(
68
- stage=self.name,
69
- duration_ms=(time.perf_counter() - t0) * 1000,
70
- notes={"reason": "execution_error"},
71
- )
72
- return StageResult(ok=False, error=errs, trace=trace_err)
73
 
74
- if exec_result is None:
75
- trace_inv = StageTrace(
76
- stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
77
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return StageResult(
79
  ok=False,
80
- error=["invalid or missing exec_result"],
81
- trace=trace_inv,
82
  )
83
 
84
- # 2) structural verification
85
  try:
86
- tree = sqlglot.parse_one(sql)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  except Exception as e:
88
- # parsing failed accept with a note
89
- trace_skip = StageTrace(
90
- stage=self.name,
91
- duration_ms=(time.perf_counter() - t0) * 1000,
92
- notes={"note": f"Skipped parse: {e}"},
93
- )
94
- return StageResult(ok=True, data={"verified": True}, trace=trace_skip)
 
 
 
 
 
 
95
 
96
- issues: list[str] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- # Detect ANY aggregation without GROUP BY for SELECT statements
99
- if isinstance(tree, exp.Select):
100
- has_agg = self._has_aggregation(tree)
101
- has_group = self._has_group_by(tree)
102
- if has_agg and not has_group:
103
- issues.append("Aggregation without GROUP BY")
 
 
 
104
 
105
- dur = (time.perf_counter() - t0) * 1000
106
  if issues:
107
- trace_bad = StageTrace(
108
- stage=self.name, duration_ms=dur, notes={"issues": issues}
 
 
 
 
109
  )
110
- return StageResult(ok=False, error=issues, trace=trace_bad)
111
 
112
- # 3) success
113
- trace_ok = StageTrace(stage=self.name, duration_ms=dur)
114
- return StageResult(ok=True, data={"verified": True}, trace=trace_ok)
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
  import time
5
+ from typing import Any, Iterable, List, Optional
6
 
7
  import sqlglot
8
  from sqlglot import expressions as exp
 
10
  from nl2sql.types import StageResult, StageTrace
11
 
12
 
13
+ def _ms(t0: float) -> int:
14
+ return int((time.perf_counter() - t0) * 1000)
15
+
16
+
17
  class Verifier:
18
  name = "verifier"
19
 
20
+ # Textual fallback: scan for common aggregate calls
21
+ _AGG_CALL_RE = re.compile(r"\b(count|sum|avg|min|max)\s*\(", re.IGNORECASE)
22
+
23
+ # ----------------------- AST helpers (version-friendly) --------------------
24
+ def _walk(self, node: exp.Expression) -> Iterable[exp.Expression]:
25
+ """Non-recursive DFS over sqlglot Expression tree (avoid private APIs)."""
26
+ stack = [node]
27
+ while stack:
28
+ cur = stack.pop()
29
+ if isinstance(cur, exp.Expression):
30
+ yield cur
31
+ args = getattr(cur, "args", {}) or {}
32
+ for v in args.values():
33
+ if isinstance(v, exp.Expression):
34
+ stack.append(v)
35
+ elif isinstance(v, list):
36
+ for it in v:
37
+ if isinstance(it, exp.Expression):
38
+ stack.append(it)
39
+
40
+ def _first_select(self, tree: exp.Expression) -> Optional[exp.Select]:
41
+ for n in self._walk(tree):
42
+ if isinstance(n, exp.Select):
43
+ return n
44
  return None
45
 
46
+ def _has_group_by(self, tree: exp.Expression) -> bool:
47
+ sel = self._first_select(tree)
48
+ if not sel:
49
+ return False
50
+ # sqlglot stores GROUP BY on Select.group
51
+ return bool(getattr(sel, "group", None))
52
+
53
+ def _is_distinct_projection(self, tree: exp.Expression) -> bool:
54
+ sel = self._first_select(tree)
55
+ if not sel:
56
+ return False
57
+ # DISTINCT may appear as Select.distinct or a Distinct node
58
+ if getattr(sel, "distinct", None):
59
+ return True
60
+ return any(isinstance(n, exp.Distinct) for n in self._walk(sel))
61
+
62
+ def _has_windowed_aggregate(self, tree: exp.Expression) -> bool:
63
+ # If there is any OVER(...) window, aggregates without GROUP BY can be legitimate
64
+ return any(isinstance(n, exp.Window) for n in self._walk(tree))
65
+
66
def _expr_contains_agg(self, node: exp.Expression) -> bool:
    """Return True if the subtree rooted at *node* contains an aggregate call."""
    agg_types = (exp.Count, exp.Sum, exp.Avg, exp.Min, exp.Max)
    # Optional aggregate classes are resolved one at a time: a name missing
    # from the installed sqlglot must not discard the ones that do exist
    # (a single try/except around all three dropped them together).
    for cls_name in ("GroupConcat", "ArrayAgg", "StringAgg"):
        cls = getattr(exp, cls_name, None)
        if cls is not None:
            agg_types += (cls,)

    return any(isinstance(n, agg_types) for n in self._walk(node))
77
+
78
def _has_nonagg_column(self, node: exp.Expression) -> bool:
    """
    Return True if the subtree contains a column reference that is NOT
    nested inside an aggregate call.

    The traversal prunes aggregate subtrees, so `COUNT(x)` yields False
    while `x + COUNT(y)` yields True. (Previously every branch returned
    True, so any column at all was reported, aggregated or not.)
    """
    agg_types = (exp.Count, exp.Sum, exp.Avg, exp.Min, exp.Max)
    for cls_name in ("GroupConcat", "ArrayAgg", "StringAgg"):
        cls = getattr(exp, cls_name, None)
        if cls is not None:
            agg_types += (cls,)

    stack = [node]
    while stack:
        cur = stack.pop()
        if not isinstance(cur, exp.Expression):
            continue
        if isinstance(cur, agg_types):
            # Columns below an aggregate are aggregated; do not descend.
            continue
        if isinstance(cur, exp.Column):
            return True
        for child in (getattr(cur, "args", None) or {}).values():
            if isinstance(child, list):
                stack.extend(child)
            else:
                stack.append(child)
    return False
99
 
100
+ # ----------------------- Textual fallback helpers -------------------------
101
+ def _clean_sql_for_fn_scan(self, sql: str) -> str:
102
+ """Remove comments/strings so regex won't be fooled."""
103
+ s = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL) # block comments
104
+ s = re.sub(r"--.*?$", " ", s, flags=re.MULTILINE) # line comments
105
+ s = re.sub(
106
+ r"('([^']|'')*'|\"([^\"]|\"\")*\"|`[^`]*`)", " ", s
107
+ ) # quoted strings / idents
108
+ s = re.sub(r"\s+", " ", s).strip()
109
+ return s
110
 
111
+ # ----------------------- Adapter result helpers ---------------------------
112
+ def _extract_ok(self, exec_result: Any) -> Optional[bool]:
113
+ if isinstance(exec_result, dict):
114
+ v = exec_result.get("ok")
115
+ if isinstance(v, bool):
116
+ return v
117
+ return None
118
+
119
+ def _extract_error(self, exec_result: Any) -> Optional[str]:
120
+ if isinstance(exec_result, dict):
121
+ for k in ("error", "message", "detail"):
122
+ if k in exec_result and exec_result[k]:
123
+ return str(exec_result[k])
124
+ return None
125
+
126
+ # ----------------------------- Main entry ---------------------------------
127
+ def verify(self, sql: str, *, adapter: Any) -> StageResult:
128
  t0 = time.perf_counter()
129
+ issues: List[str] = []
130
 
131
+ # 1) Parse - Check for errors in the parsed result
132
+ try:
133
+ tree = sqlglot.parse_one(sql, read=None) # autodetect dialect
 
 
 
 
 
 
 
134
 
135
+ # Check if the parse actually succeeded
136
+ if tree is None:
137
+ return StageResult(
138
+ ok=False,
139
+ error=["parse_error"],
140
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
141
+ )
142
+
143
+ # sqlglot may parse broken SQL as an "Unknown" or "Command" type
144
+ # Check if we got a proper SQL statement type
145
+ tree_type = type(tree).__name__
146
+
147
+ # Check for common sqlglot error indicators
148
+ # When sqlglot can't parse properly, it often creates Command or Unknown nodes
149
+ if tree_type in ("Command", "Unknown"):
150
+ return StageResult(
151
+ ok=False,
152
+ error=["parse_error"],
153
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
154
+ )
155
+
156
+ # Also check if the tree has errors attribute (some versions of sqlglot)
157
+ if hasattr(tree, "errors") and tree.errors:
158
+ return StageResult(
159
+ ok=False,
160
+ error=["parse_error"],
161
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
162
+ )
163
+
164
+ # Additional check: if it's not a recognized DML/DQL statement
165
+ valid_types = ("Select", "With", "Union", "Intersect", "Except", "Values")
166
+ if tree_type not in valid_types:
167
+ # This might be a parse error disguised as a different statement type
168
+ # Let's check if it looks like it should be a SELECT
169
+ sql_lower = sql.lower().strip()
170
+ if any(
171
+ sql_lower.startswith(kw)
172
+ for kw in ["selct", "slect", "selet", "seelct"]
173
+ ):
174
+ # Common misspellings of SELECT
175
+ return StageResult(
176
+ ok=False,
177
+ error=["parse_error"],
178
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
179
+ )
180
+
181
+ except Exception:
182
  return StageResult(
183
  ok=False,
184
+ error=["parse_error"],
185
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
186
  )
187
 
188
+ # 2) Semantic checks (AST-first)
189
  try:
190
+ sel = self._first_select(tree)
191
+ if sel:
192
+ has_group = self._has_group_by(tree)
193
+ has_window = self._has_windowed_aggregate(tree)
194
+ is_distinct = self._is_distinct_projection(tree)
195
+
196
+ select_items = list(getattr(sel, "expressions", []) or [])
197
+ any_agg = any(self._expr_contains_agg(it) for it in select_items)
198
+
199
+ # More precise check for non-aggregate columns
200
+ any_nonagg_col = False
201
+ for item in select_items:
202
+ # Check if this select item has columns but no aggregates
203
+ has_cols = any(isinstance(n, exp.Column) for n in self._walk(item))
204
+ has_aggs = self._expr_contains_agg(item)
205
+ if has_cols and not has_aggs:
206
+ any_nonagg_col = True
207
+ break
208
+
209
+ # Core rule: aggregate + non-aggregate column without GROUP BY is an issue,
210
+ # unless DISTINCT or windowed aggregate makes it legitimate.
211
+ if (
212
+ any_agg
213
+ and any_nonagg_col
214
+ and not (has_group or has_window or is_distinct)
215
+ ):
216
+ issues.append("aggregation_without_group_by")
217
  except Exception as e:
218
+ # Don't crash the verifier; surface a soft issue and let fallback run
219
+ issues.append(f"semantic_check_error:{e!s}")
220
+
221
+ # 3) Fallback textual scan — only if AST didn't already flag
222
+ if not any("aggregation_without_group_by" in i for i in issues):
223
+ try:
224
+ cleaned = self._clean_sql_for_fn_scan(sql)
225
+ has_agg_call = bool(self._AGG_CALL_RE.search(cleaned))
226
+ has_group_kw = re.search(r"\bgroup\s+by\b", cleaned, re.IGNORECASE)
227
+ has_over_kw = re.search(r"\bover\s*\(", cleaned, re.IGNORECASE)
228
+ has_distinct_kw = re.search(
229
+ r"\bselect\s+distinct\b", cleaned, re.IGNORECASE
230
+ )
231
 
232
+ if has_agg_call and not (
233
+ has_group_kw or has_over_kw or has_distinct_kw
234
+ ):
235
+ m_sel = re.search(
236
+ r"\bselect\s+(?P<sel>.+?)\s+\bfrom\b",
237
+ cleaned,
238
+ re.IGNORECASE | re.DOTALL,
239
+ )
240
+ if m_sel:
241
+ select_list = m_sel.group("sel")
242
+ # a comma strongly suggests mixing aggregate and non-aggregate in projection
243
+ if "," in select_list:
244
+ issues.append("aggregation_without_group_by")
245
+ except Exception:
246
+ # ignore fallback errors
247
+ pass
248
 
249
+ # 4) Optional: cheap preview execution (adapter may be a stub in tests)
250
+ try:
251
+ exec_result = adapter.execute_preview(sql) if adapter else {"ok": True}
252
+ ok_val = self._extract_ok(exec_result)
253
+ if ok_val is False:
254
+ err = self._extract_error(exec_result)
255
+ issues.append(f"exec_error:{err}" if err else "exec_error")
256
+ except Exception as e:
257
+ issues.append(f"exec_exception:{e!s}")
258
 
259
+ # 5) Final decision — AFTER all checks (note: no early return before fallback)
260
  if issues:
261
+ return StageResult(
262
+ ok=False,
263
+ error=issues,
264
+ trace=StageTrace(
265
+ stage=self.name, duration_ms=_ms(t0), notes={"issues": issues}
266
+ ),
267
  )
 
268
 
269
+ return StageResult(
270
+ ok=True,
271
+ data={"verified": True},
272
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
273
+ )
tests/test_safety.py CHANGED
@@ -240,3 +240,53 @@ def test_safety_stage_name_constant():
240
  s = Safety()
241
  r = s.check("SELECT 1;")
242
  assert r.trace.stage == "safety"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  s = Safety()
241
  r = s.check("SELECT 1;")
242
  assert r.trace.stage == "safety"
243
+
244
+
245
+ # Semicolon inside comments should NOT count as new statement
246
+ def test_safety_semicolon_inside_comment_is_ignored():
247
+ s = Safety()
248
+ sql = "SELECT 1 -- ; semicolon in comment\n"
249
+ r = s.check(sql)
250
+ assert r.ok, r.error
251
+
252
+
253
+ # Recursive CTE with DML inside should be blocked
254
+ def test_safety_blocks_dml_inside_recursive_cte():
255
+ s = Safety()
256
+ sql = """
257
+ WITH RECURSIVE bad(x) AS (
258
+ DELETE FROM users
259
+ )
260
+ SELECT * FROM users;
261
+ """
262
+ r = s.check(sql)
263
+ assert not r.ok
264
+
265
+
266
+ # --- 3) Zero-width spaces + comment obfuscation around DML
267
+ @pytest.mark.parametrize(
268
+ "q",
269
+ [
270
+ "/* hidden */\u200bDELETE\u200b/* again */ FROM users;",
271
+ "SELECT 1; \u200b /*x*/ DELETE /*y*/ FROM users;",
272
+ ],
273
+ )
274
+ def test_safety_obfuscated_dml_is_blocked(q):
275
+ s = Safety()
276
+ r = s.check(q)
277
+ assert not r.ok
278
+
279
+
280
+ # Multi-statement with stray semicolon and whitespace
281
+ def test_safety_blocks_stacked_statements_with_whitespace():
282
+ s = Safety()
283
+ q = "SELECT 1 ; \n DELETE FROM users;"
284
+ r = s.check(q)
285
+ assert not r.ok
286
+
287
+
288
+ # ALLOW EXPLAIN (config gate)
289
+ @pytest.mark.parametrize("q", ["explain select 1;", "EXPLAIN\nSELECT 1;"])
290
+ def test_safety_explain_allowed_when_enabled(q):
291
+ s = Safety(allow_explain=True)
292
+ assert s.check(q).ok
tests/test_verifier.py CHANGED
@@ -1,35 +1,75 @@
1
  from nl2sql.verifier import Verifier
2
- from nl2sql.types import StageResult, StageTrace
3
 
4
 
5
- def make_exec_result(ok=True, error=None):
6
- return StageResult(
7
- ok=ok, data={"dummy": True} if ok else None, trace=None, error=error
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  )
 
9
 
10
 
11
- def test_verifier_handles_execution_error():
12
  v = Verifier()
13
- r = v.run(
14
- sql="SELECT 1", exec_result=make_exec_result(ok=False, error=["db error"])
 
 
15
  )
16
- assert not r.ok
17
- assert "execution_error" in r.trace.notes["reason"]
18
- assert r.error == ["db error"]
19
 
20
 
21
- def test_verifier_detects_agg_without_group():
22
  v = Verifier()
23
- sql = "SELECT COUNT(*) FROM users"
24
- r = v.run(sql=sql, exec_result=make_exec_result(ok=True))
25
  assert not r.ok
26
- assert any("Aggregation without GROUP BY" in e for e in r.error)
27
 
28
 
29
- def test_verifier_parses_valid_sql_ok():
30
  v = Verifier()
31
- sql = "SELECT COUNT(*), city FROM users GROUP BY city"
32
- r = v.run(sql=sql, exec_result=make_exec_result(ok=True))
33
- assert r.ok
34
- assert r.data == {"verified": True}
35
  assert isinstance(r.trace, StageTrace)
 
 
 
1
  from nl2sql.verifier import Verifier
2
+ from nl2sql.types import StageTrace
3
 
4
 
5
+ # --- Tiny fake adapter for preview execution ---------------------------------
6
+ class FakeAdapter:
7
+ """Mimics adapter.execute_preview(sql) returning dicts with ok/error."""
8
+
9
+ def __init__(self, will_ok=True, error=None):
10
+ self.will_ok = will_ok
11
+ self.error = error
12
+
13
+ def execute_preview(self, sql: str):
14
+ if self.will_ok:
15
+ return {"ok": True}
16
+ if self.error:
17
+ return {"ok": False, "error": self.error}
18
+ return {"ok": False}
19
+
20
+
21
+ # -----------------------------------------------------------------------------
22
+
23
+
24
+ def test_verifier_parse_error_is_not_ok():
25
+ v = Verifier()
26
+ fake = FakeAdapter(will_ok=True)
27
+ r = v.verify("SELCT * FRM broken;", adapter=fake) # intentionally broken
28
+ assert not r.ok
29
+ assert r.error and "parse_error" in r.error
30
+
31
+
32
+ def test_verifier_plain_aggregate_without_groupby_is_flagged():
33
+ v = Verifier()
34
+ fake = FakeAdapter(will_ok=True)
35
+ r = v.verify("SELECT COUNT(*), country FROM customers;", adapter=fake)
36
+ assert not r.ok
37
+ assert r.error and "aggregation_without_group_by" in r.error
38
+
39
+
40
+ def test_verifier_windowed_aggregate_is_ok_without_groupby():
41
+ v = Verifier()
42
+ fake = FakeAdapter(will_ok=True)
43
+ r = v.verify(
44
+ "SELECT customer_id, SUM(amount) OVER (PARTITION BY customer_id) AS s FROM payments;",
45
+ adapter=fake,
46
  )
47
+ assert r.ok, r.error
48
 
49
 
50
+ def test_verifier_distinct_projection_is_ok_with_aggregate():
51
  v = Verifier()
52
+ fake = FakeAdapter(will_ok=True)
53
+ r = v.verify(
54
+ "SELECT DISTINCT artist_id, COUNT(*) FROM albums;",
55
+ adapter=fake,
56
  )
57
+ # DISTINCT + aggregate can be valid; avoid false positives.
58
+ assert r.ok or "aggregation_without_group_by" not in (r.error or [])
 
59
 
60
 
61
+ def test_verifier_exec_error_is_reported():
62
  v = Verifier()
63
+ fake = FakeAdapter(will_ok=False, error="no such table: imaginary_table")
64
+ r = v.verify("SELECT name FROM imaginary_table;", adapter=fake)
65
  assert not r.ok
66
+ assert any(("exec_error" in e) or ("exec_exception" in e) for e in (r.error or []))
67
 
68
 
69
+ def test_verifier_returns_trace_with_int_duration():
70
  v = Verifier()
71
+ fake = FakeAdapter(will_ok=True)
72
+ r = v.verify("SELECT 1;", adapter=fake)
 
 
73
  assert isinstance(r.trace, StageTrace)
74
+ # Some implementations store duration as int milliseconds:
75
+ assert isinstance(r.trace.duration_ms, int)