github-actions[bot] commited on
Commit
d2d07a3
·
1 Parent(s): 7315a86

Sync from GitHub main @ 015473c9c5ee20c6f880c09fb6f5dfc4070596e1

Browse files
Files changed (1) hide show
  1. nl2sql/safety.py +80 -8
nl2sql/safety.py CHANGED
@@ -2,12 +2,13 @@ from __future__ import annotations
2
 
3
  import re
4
  import time
5
- from typing import List, Pattern
6
 
7
  import sqlglot
 
8
 
9
  from nl2sql.types import StageResult, StageTrace
10
- from nl2sql.metrics import safety_blocks_total, stage_duration_ms, safety_checks_total
11
 
12
 
13
  # ------------------------- Zero-width & basic regexes -------------------------
@@ -119,6 +120,54 @@ def _remove_comments(body: str) -> str:
119
  return body
120
 
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  def _strip_strings(body: str) -> str:
123
  """
124
  Remove string literals (so forbidden keyword checks won't fire on quoted text).
@@ -160,8 +209,11 @@ class Safety:
160
 
161
  name = "safety"
162
 
163
- def __init__(self, allow_explain: bool = True) -> None:
 
 
164
  self.allow_explain = allow_explain
 
165
 
166
  def check(self, sql: str) -> StageResult:
167
  t0 = time.perf_counter()
@@ -187,6 +239,16 @@ class Safety:
187
  # 1) sanitize
188
  body = _sanitize(sql)
189
 
 
 
 
 
 
 
 
 
 
 
190
  # 2) single-statement check (semicolon + parser)
191
  semicolon_count = _count_statements_semicolon(body)
192
  glot_count = _count_statements_sqlglot(body)
@@ -232,8 +294,8 @@ class Safety:
232
 
233
  # 4) read-only root kind (SELECT/EXPLAIN[/WITH])
234
  try:
235
- trees = sqlglot.parse(body)
236
- root = trees[0]
237
  except Exception as e:
238
  safety_blocks_total.labels(reason="parse_error").inc()
239
  safety_checks_total.labels(ok="false").inc()
@@ -252,10 +314,9 @@ class Safety:
252
  if self.allow_explain and _EXPLAIN_HEAD_RE.match(body):
253
  remainder = _EXPLAIN_HEAD_RE.sub("", body, count=1).lstrip()
254
  try:
255
- t2 = sqlglot.parse_one(remainder)
256
  t2_type = type(t2).__name__.lower() if t2 else ""
257
  if t2_type in {"select", "with"}:
258
- stage_duration_ms.labels("safety").observe(_ms(t0) / 1.0)
259
  safety_checks_total.labels(ok="true").inc()
260
  return StageResult(
261
  ok=True,
@@ -292,8 +353,19 @@ class Safety:
292
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
293
  )
294
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  # 5) success
296
- stage_duration_ms.labels("safety").observe(_ms(t0) / 1.0)
297
  safety_checks_total.labels(ok="true").inc()
298
  return StageResult(
299
  ok=True,
 
2
 
3
  import re
4
  import time
5
+ from typing import List, Pattern, Any, cast
6
 
7
  import sqlglot
8
+ from sqlglot import exp
9
 
10
  from nl2sql.types import StageResult, StageTrace
11
+ from nl2sql.metrics import safety_blocks_total, safety_checks_total
12
 
13
 
14
  # ------------------------- Zero-width & basic regexes -------------------------
 
120
  return body
121
 
122
 
123
+ def _has_comments(body: str) -> bool:
124
+ return bool(_LINE_COMMENT_RE.search(body) or _BLOCK_COMMENT_RE.search(body))
125
+
126
+
127
+ def _contains_forbidden_ast(root: exp.Expression) -> tuple[bool, str]:
128
+ """Return (blocked, reason) based on AST nodes/commands."""
129
+ forbidden_node_names = {
130
+ "insert",
131
+ "update",
132
+ "delete",
133
+ "drop",
134
+ "create",
135
+ "alter",
136
+ "truncate",
137
+ "merge",
138
+ "grant",
139
+ "revoke",
140
+ "execute",
141
+ "call",
142
+ "copy",
143
+ "replace",
144
+ }
145
+ forbidden_command_markers = ("pragma", "attach", "vacuum", "reindex", "analyze")
146
+
147
+ try:
148
+ walk = getattr(root, "walk", None)
149
+ if walk is None:
150
+ return False, ""
151
+ for node in root.walk():
152
+ name = type(node).__name__.lower()
153
+ if name in forbidden_node_names:
154
+ return True, name
155
+ if name == "command":
156
+ sql = ""
157
+ try:
158
+ sql = node.sql(dialect="sqlite").lower()
159
+ except Exception:
160
+ sql = str(node).lower()
161
+ for kw in forbidden_command_markers:
162
+ if kw in sql:
163
+ return True, f"command:{kw}"
164
+ except Exception:
165
+ # If AST walk fails, be conservative: do not block here (parse/root checks already ran).
166
+ return False, ""
167
+
168
+ return False, ""
169
+
170
+
171
  def _strip_strings(body: str) -> str:
172
  """
173
  Remove string literals (so forbidden keyword checks won't fire on quoted text).
 
209
 
210
  name = "safety"
211
 
212
+ def __init__(
213
+ self, allow_explain: bool = True, forbid_comments: bool = False
214
+ ) -> None:
215
  self.allow_explain = allow_explain
216
+ self.forbid_comments = forbid_comments
217
 
218
  def check(self, sql: str) -> StageResult:
219
  t0 = time.perf_counter()
 
239
  # 1) sanitize
240
  body = _sanitize(sql)
241
 
242
+ # 1.5) comment policy (block if any comment tokens are present)
243
+ if self.forbid_comments and _has_comments(body):
244
+ safety_blocks_total.labels(reason="comments_not_allowed").inc()
245
+ safety_checks_total.labels(ok="false").inc()
246
+ return StageResult(
247
+ ok=False,
248
+ error=["comments_not_allowed"],
249
+ trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
250
+ )
251
+
252
  # 2) single-statement check (semicolon + parser)
253
  semicolon_count = _count_statements_semicolon(body)
254
  glot_count = _count_statements_sqlglot(body)
 
294
 
295
  # 4) read-only root kind (SELECT/EXPLAIN[/WITH])
296
  try:
297
+ trees: list[Any] = sqlglot.parse(body)
298
+ root = cast(exp.Expression, trees[0])
299
  except Exception as e:
300
  safety_blocks_total.labels(reason="parse_error").inc()
301
  safety_checks_total.labels(ok="false").inc()
 
314
  if self.allow_explain and _EXPLAIN_HEAD_RE.match(body):
315
  remainder = _EXPLAIN_HEAD_RE.sub("", body, count=1).lstrip()
316
  try:
317
+ t2 = cast(exp.Expression, sqlglot.parse_one(remainder))
318
  t2_type = type(t2).__name__.lower() if t2 else ""
319
  if t2_type in {"select", "with"}:
 
320
  safety_checks_total.labels(ok="true").inc()
321
  return StageResult(
322
  ok=True,
 
353
  trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
354
  )
355
 
356
+ # 4.5) AST-based forbidden nodes / commands (defense-in-depth)
357
+ blocked, reason = _contains_forbidden_ast(root)
358
+ if blocked:
359
+ safety_blocks_total.labels(reason="forbidden_ast").inc()
360
+ safety_checks_total.labels(ok="false").inc()
361
+ return StageResult(
362
+ ok=False,
363
+ error=[f"Forbidden AST: {reason}"],
364
+ trace=StageTrace(
365
+ stage=self.name, duration_ms=_ms(t0), notes={"reason": reason}
366
+ ),
367
+ )
368
  # 5) success
 
369
  safety_checks_total.labels(ok="true").inc()
370
  return StageResult(
371
  ok=True,