Melika Kheirieh commited on
Commit
b0bec17
·
1 Parent(s): 602cae0

feat(safety): harden SQL validation (multi-CTE, recursive WITH, unicode normalization, precise errors, EXPLAIN gate)

Browse files
Files changed (2) hide show
  1. nl2sql/safety.py +66 -29
  2. tests/test_safety.py +121 -0
nl2sql/safety.py CHANGED
@@ -1,32 +1,56 @@
1
  from __future__ import annotations
 
2
  import re
3
  import time
 
4
  from nl2sql.types import StageResult, StageTrace
5
 
6
  # --- Regex utils ---
7
  _COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
8
  _COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)
9
- # string literals (single & double quotes), allow escaped quotes
 
10
  _STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
11
  _STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
12
 
13
- # case-insensitive, word-boundary forbidden keywords
14
  _FORBIDDEN = re.compile(
15
  r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b",
16
  re.IGNORECASE,
17
  )
18
 
19
- # allow: SELECT ... or WITH <cte...> SELECT ...
20
- _ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)
 
 
 
 
 
 
 
 
21
 
22
- # --- New cleanup helpers ---
23
  _FENCE_SQL = re.compile(r"```sql", re.IGNORECASE)
24
  _FENCE_ANY = re.compile(r"```")
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def _sanitize_sql(sql: str) -> str:
28
- """Remove markdown fences, comments, and surrounding junk."""
29
- s = _FENCE_SQL.sub("", sql)
 
30
  s = _FENCE_ANY.sub("", s)
31
  s = _COMMENT_BLOCK.sub(" ", s)
32
  s = _COMMENT_LINE.sub(" ", s)
@@ -37,6 +61,7 @@ def _sanitize_sql(sql: str) -> str:
37
 
38
 
39
  def _mask_strings(s: str) -> str:
 
40
  s = _STRING_SINGLE.sub("'X'", s)
41
  s = _STRING_DOUBLE.sub('"X"', s)
42
  return s
@@ -44,64 +69,76 @@ def _mask_strings(s: str) -> str:
44
 
45
  def _split_statements(s: str) -> list[str]:
46
  """
47
- Split only if there are real multiple statements,
48
- ignoring harmless trailing semicolons or markdown.
49
  """
50
  parts = [p.strip() for p in s.split(";")]
51
- parts = [p for p in parts if p]
52
- return parts
 
 
 
53
 
54
 
55
  class Safety:
56
  name = "safety"
57
 
 
 
 
 
 
 
58
  def check(self, sql: str) -> StageResult:
59
  t0 = time.perf_counter()
60
- print("🧩 SQL candidate:", sql)
61
 
62
- # --- sanitize first ---
63
  s = _sanitize_sql(sql)
64
  s = _mask_strings(s).strip()
65
 
 
66
  stmts = _split_statements(s)
67
  if len(stmts) != 1:
68
  return StageResult(
69
  ok=False,
70
  error=["Multiple statements detected"],
71
- trace=StageTrace(
72
- stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
73
- ),
74
  )
75
 
76
  body = stmts[0]
77
 
78
- if _FORBIDDEN.search(body):
 
 
79
  return StageResult(
80
  ok=False,
81
- error=["Forbidden keyword detected"],
82
- trace=StageTrace(
83
- stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
84
- ),
85
  )
86
 
87
- if not _ALLOW_SELECT.match(body):
 
 
 
 
 
88
  return StageResult(
89
  ok=False,
90
  error=["Non-SELECT statement"],
91
- trace=StageTrace(
92
- stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
93
- ),
94
  )
95
 
 
96
  return StageResult(
97
  ok=True,
98
  data={
99
  "sql": body,
100
- "rationale": "Statement validated as SELECT-only (strings/comments/markdown ignored).",
 
 
 
101
  },
102
- trace=StageTrace(
103
- stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
104
- ),
105
  )
106
 
 
107
  run = check
 
1
  from __future__ import annotations
2
+
3
  import re
4
  import time
5
+ import unicodedata
6
  from nl2sql.types import StageResult, StageTrace
7
 
8
  # --- Regex utils ---
9
  _COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
10
  _COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)
11
+
12
+ # String literals (single & double quotes), allow escaped quotes
13
  _STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
14
  _STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
15
 
16
+ # Case-insensitive, word-boundary forbidden keywords
17
  _FORBIDDEN = re.compile(
18
  r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b",
19
  re.IGNORECASE,
20
  )
21
 
22
+ # Allow: SELECT ... or WITH (one or many CTEs, optional RECURSIVE) ... SELECT ...
23
+ _ALLOW_SELECT = re.compile(
24
+ r"^(?:WITH\s+(?:RECURSIVE\s+)?"
25
+ r".*?\)\s*(?:,\s*.*?\)\s*)*"
26
+ r")?SELECT\b",
27
+ re.IGNORECASE | re.DOTALL,
28
+ )
29
+
30
+ # Optional allowance: EXPLAIN SELECT ...
31
+ _ALLOW_EXPLAIN_SELECT = re.compile(r"^EXPLAIN\s+SELECT\b", re.IGNORECASE | re.DOTALL)
32
 
33
+ # --- Cleanup helpers ---
34
  _FENCE_SQL = re.compile(r"```sql", re.IGNORECASE)
35
  _FENCE_ANY = re.compile(r"```")
36
 
37
 
38
+ def _normalize_sql(sql: str) -> str:
39
+ """Normalize to NFKC and strip zero-width characters to prevent obfuscation."""
40
+ s = unicodedata.normalize("NFKC", sql)
41
+ # strip common zero-width spaces/joiners
42
+ return (
43
+ s.replace("\u200b", "")
44
+ .replace("\u200c", "")
45
+ .replace("\u200d", "")
46
+ .replace("\ufeff", "")
47
+ )
48
+
49
+
50
  def _sanitize_sql(sql: str) -> str:
51
+ """Remove markdown fences, comments, and harmless trailing semicolons."""
52
+ s = _normalize_sql(sql)
53
+ s = _FENCE_SQL.sub("", s)
54
  s = _FENCE_ANY.sub("", s)
55
  s = _COMMENT_BLOCK.sub(" ", s)
56
  s = _COMMENT_LINE.sub(" ", s)
 
61
 
62
 
63
  def _mask_strings(s: str) -> str:
64
+ """Replace string literals so that inner semicolons/keywords don't affect checks."""
65
  s = _STRING_SINGLE.sub("'X'", s)
66
  s = _STRING_DOUBLE.sub('"X"', s)
67
  return s
 
69
 
70
  def _split_statements(s: str) -> list[str]:
71
  """
72
+ Split on semicolons after string-masking. Ignore empties (e.g., trailing ';').
 
73
  """
74
  parts = [p.strip() for p in s.split(";")]
75
+ return [p for p in parts if p]
76
+
77
+
78
+ def _ms(t0: float) -> int:
79
+ return int((time.perf_counter() - t0) * 1000)
80
 
81
 
82
  class Safety:
83
  name = "safety"
84
 
85
+ def __init__(self, allow_explain: bool = False) -> None:
86
+ """
87
+ :param allow_explain: If True, 'EXPLAIN SELECT ...' is allowed in addition to SELECT.
88
+ """
89
+ self.allow_explain = allow_explain
90
+
91
  def check(self, sql: str) -> StageResult:
92
  t0 = time.perf_counter()
 
93
 
94
+ # 1) Sanitize and mask
95
  s = _sanitize_sql(sql)
96
  s = _mask_strings(s).strip()
97
 
98
+ # 2) Multiple statements check
99
  stmts = _split_statements(s)
100
  if len(stmts) != 1:
101
  return StageResult(
102
  ok=False,
103
  error=["Multiple statements detected"],
104
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
 
 
105
  )
106
 
107
  body = stmts[0]
108
 
109
+ # 3) Forbidden keyword check (report exact offending token)
110
+ m = _FORBIDDEN.search(body)
111
+ if m:
112
  return StageResult(
113
  ok=False,
114
+ error=[f"Forbidden keyword detected: '{m.group(0)}'"],
115
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
 
 
116
  )
117
 
118
+ # 4) Allow only SELECT (or optionally EXPLAIN SELECT)
119
+ allowed = bool(_ALLOW_SELECT.match(body))
120
+ if not allowed and self.allow_explain:
121
+ allowed = bool(_ALLOW_EXPLAIN_SELECT.match(body))
122
+
123
+ if not allowed:
124
  return StageResult(
125
  ok=False,
126
  error=["Non-SELECT statement"],
127
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
 
 
128
  )
129
 
130
+ # 5) Success
131
  return StageResult(
132
  ok=True,
133
  data={
134
  "sql": body,
135
+ "rationale": (
136
+ "Statement validated as SELECT-only (strings/comments/markdown ignored)."
137
+ + (" EXPLAIN SELECT allowed." if self.allow_explain else "")
138
+ ),
139
  },
140
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
 
 
141
  )
142
 
143
+ # Backward-compat alias
144
  run = check
tests/test_safety.py CHANGED
@@ -119,3 +119,124 @@ def test_safety_blocks_multiple_nonempty_statements_even_if_second_is_comment():
119
  sql_bad = "SELECT 1; /* spacer */ DROP TABLE x;"
120
  assert s.check(sql).ok
121
  assert not s.check(sql_bad).ok
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  sql_bad = "SELECT 1; /* spacer */ DROP TABLE x;"
120
  assert s.check(sql).ok
121
  assert not s.check(sql_bad).ok
122
+
123
+
124
+ def test_safety_allows_multiple_ctes():
125
+ s = Safety()
126
+ sql = """
127
+ WITH a AS (SELECT 1 AS x),
128
+ b AS (SELECT 2 AS y)
129
+ SELECT a.x, b.y FROM a CROSS JOIN b;
130
+ """
131
+ assert s.check(sql).ok
132
+
133
+
134
+ def test_safety_allows_with_recursive():
135
+ s = Safety()
136
+ sql = """
137
+ WITH RECURSIVE cnt(x) AS (
138
+ SELECT 1 UNION ALL SELECT x+1 FROM cnt WHERE x < 3
139
+ )
140
+ SELECT * FROM cnt;
141
+ """
142
+ assert s.check(sql).ok
143
+
144
+
145
+ def test_safety_blocks_zero_width_obfuscation_in_keyword():
146
+ s = Safety()
147
+ # "DROP" با zero-width joiner وسط حروف
148
+ bad = "DR\u200dOP TABLE users;"
149
+ r = s.check(bad)
150
+ assert not r.ok
151
+
152
+
153
+ def test_safety_ignores_markdown_fences():
154
+ s = Safety()
155
+ sql = "```sql\nSELECT 1;\n```"
156
+ assert s.check(sql).ok
157
+
158
+
159
+ def test_safety_semicolon_inside_string_literal_is_ignored():
160
+ s = Safety()
161
+ sql = "SELECT 'a; b; c' AS sample;"
162
+ assert s.check(sql).ok
163
+
164
+
165
+ def test_safety_forbidden_keyword_inside_string_literal_ok():
166
+ s = Safety()
167
+ sql = "SELECT 'DROP TABLE x' AS note, 'delete from y' AS text;"
168
+ assert s.check(sql).ok
169
+
170
+
171
+ def test_safety_reports_offending_token_in_error_message():
172
+ s = Safety()
173
+ r = s.check(" \n ReIndex users;")
174
+ assert not r.ok
175
+ assert any("reindex" in e.lower() for e in (r.error or []))
176
+
177
+
178
+ def test_safety_multiple_statements_with_masked_strings_is_blocked():
179
+ s = Safety()
180
+ sql = "SELECT 'abc'; SELECT 1;"
181
+ r = s.check(sql)
182
+ assert not r.ok
183
+
184
+
185
+ def test_safety_duration_ms_is_int():
186
+ s = Safety()
187
+ r = s.check("SELECT 1;")
188
+ assert isinstance(r.trace.duration_ms, int)
189
+
190
+
191
+ def test_safety_allows_explain_select_when_enabled():
192
+ s = Safety(allow_explain=True)
193
+ r = s.check("EXPLAIN SELECT * FROM users;")
194
+ assert r.ok
195
+
196
+
197
+ def test_safety_blocks_explain_select_when_disabled():
198
+ s = Safety(allow_explain=False)
199
+ r = s.check("EXPLAIN SELECT * FROM users;")
200
+ assert not r.ok
201
+
202
+
203
+ def test_safety_blocks_forbidden_inside_cte_body():
204
+ s = Safety()
205
+ sql = """
206
+ WITH bad AS (DELETE FROM users)
207
+ SELECT * FROM users;
208
+ """
209
+ assert not s.check(sql).ok
210
+
211
+
212
+ def test_safety_permits_with_comments_and_newlines_complex():
213
+ s = Safety()
214
+ sql = """
215
+ /* head */ WITH a AS (SELECT 1 /*x*/ AS x) -- inline
216
+ , b AS (SELECT 2 AS y) /* tail */
217
+ SELECT a.x, b.y FROM a JOIN b; -- end
218
+ """
219
+ assert s.check(sql).ok
220
+
221
+
222
+ def test_safety_blocks_bom_prefixed_forbidden():
223
+ s = Safety()
224
+ sql = "\ufeffDROP TABLE x;"
225
+ assert not s.check(sql).ok
226
+
227
+
228
+ def test_safety_allows_trailing_double_semicolon():
229
+ s = Safety()
230
+ assert s.check("SELECT 1;;").ok
231
+
232
+
233
+ @pytest.mark.parametrize("q", ["explain select 1;", "EXPLAIN\nSELECT 1;"])
234
+ def test_safety_explain_various_spacing_when_enabled(q):
235
+ s = Safety(allow_explain=True)
236
+ assert s.check(q).ok
237
+
238
+
239
+ def test_safety_stage_name_constant():
240
+ s = Safety()
241
+ r = s.check("SELECT 1;")
242
+ assert r.trace.stage == "safety"