Spaces:
Sleeping
Sleeping
github-actions[bot]
committed on
Commit
·
d2d07a3
1
Parent(s):
7315a86
Sync from GitHub main @ 015473c9c5ee20c6f880c09fb6f5dfc4070596e1
Browse files- nl2sql/safety.py +80 -8
nl2sql/safety.py
CHANGED
|
@@ -2,12 +2,13 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
import re
|
| 4 |
import time
|
| 5 |
-
from typing import List, Pattern
|
| 6 |
|
| 7 |
import sqlglot
|
|
|
|
| 8 |
|
| 9 |
from nl2sql.types import StageResult, StageTrace
|
| 10 |
-
from nl2sql.metrics import safety_blocks_total,
|
| 11 |
|
| 12 |
|
| 13 |
# ------------------------- Zero-width & basic regexes -------------------------
|
|
@@ -119,6 +120,54 @@ def _remove_comments(body: str) -> str:
|
|
| 119 |
return body
|
| 120 |
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
def _strip_strings(body: str) -> str:
|
| 123 |
"""
|
| 124 |
Remove string literals (so forbidden keyword checks won't fire on quoted text).
|
|
@@ -160,8 +209,11 @@ class Safety:
|
|
| 160 |
|
| 161 |
name = "safety"
|
| 162 |
|
| 163 |
-
def __init__(
|
|
|
|
|
|
|
| 164 |
self.allow_explain = allow_explain
|
|
|
|
| 165 |
|
| 166 |
def check(self, sql: str) -> StageResult:
|
| 167 |
t0 = time.perf_counter()
|
|
@@ -187,6 +239,16 @@ class Safety:
|
|
| 187 |
# 1) sanitize
|
| 188 |
body = _sanitize(sql)
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
# 2) single-statement check (semicolon + parser)
|
| 191 |
semicolon_count = _count_statements_semicolon(body)
|
| 192 |
glot_count = _count_statements_sqlglot(body)
|
|
@@ -232,8 +294,8 @@ class Safety:
|
|
| 232 |
|
| 233 |
# 4) read-only root kind (SELECT/EXPLAIN[/WITH])
|
| 234 |
try:
|
| 235 |
-
trees = sqlglot.parse(body)
|
| 236 |
-
root = trees[0]
|
| 237 |
except Exception as e:
|
| 238 |
safety_blocks_total.labels(reason="parse_error").inc()
|
| 239 |
safety_checks_total.labels(ok="false").inc()
|
|
@@ -252,10 +314,9 @@ class Safety:
|
|
| 252 |
if self.allow_explain and _EXPLAIN_HEAD_RE.match(body):
|
| 253 |
remainder = _EXPLAIN_HEAD_RE.sub("", body, count=1).lstrip()
|
| 254 |
try:
|
| 255 |
-
t2 = sqlglot.parse_one(remainder)
|
| 256 |
t2_type = type(t2).__name__.lower() if t2 else ""
|
| 257 |
if t2_type in {"select", "with"}:
|
| 258 |
-
stage_duration_ms.labels("safety").observe(_ms(t0) / 1.0)
|
| 259 |
safety_checks_total.labels(ok="true").inc()
|
| 260 |
return StageResult(
|
| 261 |
ok=True,
|
|
@@ -292,8 +353,19 @@ class Safety:
|
|
| 292 |
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 293 |
)
|
| 294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
# 5) success
|
| 296 |
-
stage_duration_ms.labels("safety").observe(_ms(t0) / 1.0)
|
| 297 |
safety_checks_total.labels(ok="true").inc()
|
| 298 |
return StageResult(
|
| 299 |
ok=True,
|
|
|
|
| 2 |
|
| 3 |
import re
|
| 4 |
import time
|
| 5 |
+
from typing import List, Pattern, Any, cast
|
| 6 |
|
| 7 |
import sqlglot
|
| 8 |
+
from sqlglot import exp
|
| 9 |
|
| 10 |
from nl2sql.types import StageResult, StageTrace
|
| 11 |
+
from nl2sql.metrics import safety_blocks_total, safety_checks_total
|
| 12 |
|
| 13 |
|
| 14 |
# ------------------------- Zero-width & basic regexes -------------------------
|
|
|
|
| 120 |
return body
|
| 121 |
|
| 122 |
|
| 123 |
+
def _has_comments(body: str) -> bool:
|
| 124 |
+
return bool(_LINE_COMMENT_RE.search(body) or _BLOCK_COMMENT_RE.search(body))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _contains_forbidden_ast(root: exp.Expression) -> tuple[bool, str]:
|
| 128 |
+
"""Return (blocked, reason) based on AST nodes/commands."""
|
| 129 |
+
forbidden_node_names = {
|
| 130 |
+
"insert",
|
| 131 |
+
"update",
|
| 132 |
+
"delete",
|
| 133 |
+
"drop",
|
| 134 |
+
"create",
|
| 135 |
+
"alter",
|
| 136 |
+
"truncate",
|
| 137 |
+
"merge",
|
| 138 |
+
"grant",
|
| 139 |
+
"revoke",
|
| 140 |
+
"execute",
|
| 141 |
+
"call",
|
| 142 |
+
"copy",
|
| 143 |
+
"replace",
|
| 144 |
+
}
|
| 145 |
+
forbidden_command_markers = ("pragma", "attach", "vacuum", "reindex", "analyze")
|
| 146 |
+
|
| 147 |
+
try:
|
| 148 |
+
walk = getattr(root, "walk", None)
|
| 149 |
+
if walk is None:
|
| 150 |
+
return False, ""
|
| 151 |
+
for node in root.walk():
|
| 152 |
+
name = type(node).__name__.lower()
|
| 153 |
+
if name in forbidden_node_names:
|
| 154 |
+
return True, name
|
| 155 |
+
if name == "command":
|
| 156 |
+
sql = ""
|
| 157 |
+
try:
|
| 158 |
+
sql = node.sql(dialect="sqlite").lower()
|
| 159 |
+
except Exception:
|
| 160 |
+
sql = str(node).lower()
|
| 161 |
+
for kw in forbidden_command_markers:
|
| 162 |
+
if kw in sql:
|
| 163 |
+
return True, f"command:{kw}"
|
| 164 |
+
except Exception:
|
| 165 |
+
# If the AST walk fails, fail open: do not block here (parse/root checks already ran).
|
| 166 |
+
return False, ""
|
| 167 |
+
|
| 168 |
+
return False, ""
|
| 169 |
+
|
| 170 |
+
|
| 171 |
def _strip_strings(body: str) -> str:
|
| 172 |
"""
|
| 173 |
Remove string literals (so forbidden keyword checks won't fire on quoted text).
|
|
|
|
| 209 |
|
| 210 |
name = "safety"
|
| 211 |
|
| 212 |
+
def __init__(
|
| 213 |
+
self, allow_explain: bool = True, forbid_comments: bool = False
|
| 214 |
+
) -> None:
|
| 215 |
self.allow_explain = allow_explain
|
| 216 |
+
self.forbid_comments = forbid_comments
|
| 217 |
|
| 218 |
def check(self, sql: str) -> StageResult:
|
| 219 |
t0 = time.perf_counter()
|
|
|
|
| 239 |
# 1) sanitize
|
| 240 |
body = _sanitize(sql)
|
| 241 |
|
| 242 |
+
# 1.5) comment policy (block if any comment tokens are present)
|
| 243 |
+
if self.forbid_comments and _has_comments(body):
|
| 244 |
+
safety_blocks_total.labels(reason="comments_not_allowed").inc()
|
| 245 |
+
safety_checks_total.labels(ok="false").inc()
|
| 246 |
+
return StageResult(
|
| 247 |
+
ok=False,
|
| 248 |
+
error=["comments_not_allowed"],
|
| 249 |
+
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
# 2) single-statement check (semicolon + parser)
|
| 253 |
semicolon_count = _count_statements_semicolon(body)
|
| 254 |
glot_count = _count_statements_sqlglot(body)
|
|
|
|
| 294 |
|
| 295 |
# 4) read-only root kind (SELECT/EXPLAIN[/WITH])
|
| 296 |
try:
|
| 297 |
+
trees: list[Any] = sqlglot.parse(body)
|
| 298 |
+
root = cast(exp.Expression, trees[0])
|
| 299 |
except Exception as e:
|
| 300 |
safety_blocks_total.labels(reason="parse_error").inc()
|
| 301 |
safety_checks_total.labels(ok="false").inc()
|
|
|
|
| 314 |
if self.allow_explain and _EXPLAIN_HEAD_RE.match(body):
|
| 315 |
remainder = _EXPLAIN_HEAD_RE.sub("", body, count=1).lstrip()
|
| 316 |
try:
|
| 317 |
+
t2 = cast(exp.Expression, sqlglot.parse_one(remainder))
|
| 318 |
t2_type = type(t2).__name__.lower() if t2 else ""
|
| 319 |
if t2_type in {"select", "with"}:
|
|
|
|
| 320 |
safety_checks_total.labels(ok="true").inc()
|
| 321 |
return StageResult(
|
| 322 |
ok=True,
|
|
|
|
| 353 |
trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
|
| 354 |
)
|
| 355 |
|
| 356 |
+
# 4.5) AST-based forbidden nodes / commands (defense-in-depth)
|
| 357 |
+
blocked, reason = _contains_forbidden_ast(root)
|
| 358 |
+
if blocked:
|
| 359 |
+
safety_blocks_total.labels(reason="forbidden_ast").inc()
|
| 360 |
+
safety_checks_total.labels(ok="false").inc()
|
| 361 |
+
return StageResult(
|
| 362 |
+
ok=False,
|
| 363 |
+
error=[f"Forbidden AST: {reason}"],
|
| 364 |
+
trace=StageTrace(
|
| 365 |
+
stage=self.name, duration_ms=_ms(t0), notes={"reason": reason}
|
| 366 |
+
),
|
| 367 |
+
)
|
| 368 |
# 5) success
|
|
|
|
| 369 |
safety_checks_total.labels(ok="true").inc()
|
| 370 |
return StageResult(
|
| 371 |
ok=True,
|