Spaces:
Restarting
Restarting
File size: 4,390 Bytes
570f7bd b0bec17 105e019 b0bec17 570f7bd c1bc4eb b0bec17 570f7bd b0bec17 570f7bd b0bec17 570f7bd b0bec17 370553a c1bc4eb 370553a b0bec17 370553a b0bec17 370553a 570f7bd 370553a 570f7bd c1bc4eb 570f7bd b0bec17 570f7bd c1bc4eb 570f7bd 370553a b0bec17 370553a 570f7bd b0bec17 570f7bd c1bc4eb 570f7bd b0bec17 570f7bd 370553a b0bec17 370553a 570f7bd b0bec17 570f7bd b0bec17 570f7bd b0bec17 570f7bd b0bec17 570f7bd b0bec17 570f7bd b0bec17 570f7bd b0bec17 570f7bd 370553a b0bec17 570f7bd b0bec17 570f7bd a337fad b0bec17 a337fad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from __future__ import annotations
import re
import time
import unicodedata
from nl2sql.types import StageResult, StageTrace
# --- Regex utils ---
# SQL comments: /* ... */ blocks (DOTALL lets one block span several lines) and
# "-- ..." runs up to the end of the current line (MULTILINE makes `$` match at
# every line end, not only at the end of the text).
_COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
_COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)
# Quoted literals (single & double quotes). A backslash consumes the following
# character, so \' and \" do not terminate the literal.
# NOTE(review): standard SQL / SQLite escape a quote by doubling it and do not
# treat backslash as an escape character — confirm this matches the target dialect.
_STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
_STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
# Forbidden keywords, matched case-insensitively and only as whole words
# (the \b anchors keep e.g. "updated_at" from matching "update").
# NOTE(review): `replace` also matches a call to the REPLACE() string function
# inside a SELECT — confirm that rejecting such queries is intended.
_FORBIDDEN = re.compile(
r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b",
re.IGNORECASE,
)
# Allow: SELECT ... or WITH (one or many CTEs, optional RECURSIVE) ... SELECT ...
# Anchored at the start of the statement. The WITH prefix is recognised loosely as
# comma-separated segments that each end in ")"; it is a shape check, not a parser.
# The pattern stops at the SELECT keyword, so nothing after it is inspected here.
_ALLOW_SELECT = re.compile(
r"^(?:WITH\s+(?:RECURSIVE\s+)?"
r".*?\)\s*(?:,\s*.*?\)\s*)*"
r")?SELECT\b",
re.IGNORECASE | re.DOTALL,
)
# Optional allowance: EXPLAIN SELECT ...
# Only the literal prefix "EXPLAIN <whitespace> SELECT" is accepted; a WITH clause
# after EXPLAIN does not match this pattern.
_ALLOW_EXPLAIN_SELECT = re.compile(r"^EXPLAIN\s+SELECT\b", re.IGNORECASE | re.DOTALL)
# --- Cleanup helpers ---
# Markdown code-fence markers: "```sql" in any letter case, and a bare "```".
_FENCE_SQL = re.compile(r"```sql", re.IGNORECASE)
_FENCE_ANY = re.compile(r"```")
# Translation table that deletes the zero-width code points handled by
# _normalize_sql (a value of None means "remove this character").
_ZERO_WIDTH_DELETIONS = dict.fromkeys(map(ord, "\u200b\u200c\u200d\ufeff"))


def _normalize_sql(sql: str) -> str:
    """Return *sql* in NFKC form with zero-width characters removed.

    NFKC folds compatibility forms (for example full-width letters) into their
    canonical equivalents. The zero-width space, non-joiner, joiner and
    BOM / zero-width no-break space are then dropped so they cannot be used to
    split a keyword invisibly.

    :param sql: Raw SQL text.
    :return: The normalized text.
    """
    folded = unicodedata.normalize("NFKC", sql)
    # Single pass over the string instead of one .replace() per code point.
    return folded.translate(_ZERO_WIDTH_DELETIONS)
# Quoted literals and SQL comments recognised in ONE left-to-right scan, so the
# construct that starts first wins: "--" or "/*" inside a literal stays part of
# the literal, and "/*" inside a line comment does not open a block comment.
# The literal syntax mirrors the masking patterns used elsewhere in this module
# (a backslash consumes the next character).
_LITERAL_OR_COMMENT = re.compile(
    r"(?P<literal>'(?:[^'\\]|\\.)*'"
    r'|"(?:[^"\\]|\\.)*")'
    r"|(?P<comment>/\*.*?\*/|--[^\n]*)",
    re.DOTALL,
)


def _sanitize_sql(sql: str) -> str:
    """Remove markdown fences, comments, and harmless trailing semicolons.

    Comments are stripped with awareness of quoted literals. Stripping them with
    independent passes (block comments first, then line comments) cut statements
    such as ``SELECT '--' AS dash`` in the middle of the literal, and could erase
    a real second statement placed after ``-- /*`` before the multi-statement
    check ever saw it.

    :param sql: Raw SQL text, possibly wrapped in markdown code fences.
    :return: Normalized SQL without fences or comments, stripped of surrounding
        whitespace and trailing semicolons.
    """
    s = _normalize_sql(sql)
    s = _FENCE_SQL.sub("", s)
    s = _FENCE_ANY.sub("", s)
    # Keep literals verbatim; replace each comment with a single space so the
    # tokens on either side of it do not fuse together.
    s = _LITERAL_OR_COMMENT.sub(lambda m: m.group("literal") or " ", s)
    # Remove trailing semicolons, then any whitespace they were hiding.
    return s.strip().rstrip(";").strip()
# Single- and double-quoted literals matched in ONE left-to-right scan; a
# backslash consumes the next character, as in the module's other literal patterns.
_QUOTED_LITERAL = re.compile(
    r"(?P<single>'(?:[^'\\]|\\.)*')"
    r'|(?P<double>"(?:[^"\\]|\\.)*")',
    re.DOTALL,
)
# Placeholder emitted for each kind of literal, keyed by the regex group name.
_MASK_BY_QUOTE = {"single": "'X'", "double": '"X"'}


def _mask_strings(s: str) -> str:
    """Replace string literals so that inner semicolons/keywords don't affect checks.

    Both quote styles are masked in a single pass so that whichever literal
    starts first is the one that gets consumed. Masking single-quoted literals
    first and double-quoted ones afterwards let an apostrophe inside a
    double-quoted token open a bogus literal that could swallow real SQL, e.g.
    ``SELECT "a'b" ; DROP TABLE t ; SELECT "c'd"`` collapsed to ``SELECT "X"``.

    :param s: SQL text.
    :return: The text with every single-quoted literal replaced by ``'X'`` and
        every double-quoted one by ``"X"``.
    """
    return _QUOTED_LITERAL.sub(lambda m: _MASK_BY_QUOTE[m.lastgroup], s)
def _split_statements(s: str) -> list[str]:
    """Split *s* on semicolons and return the non-empty, trimmed pieces.

    The caller is expected to have masked string literals already, so every
    remaining semicolon is treated as a statement separator. Pieces that are
    empty after trimming (e.g. the one after a trailing ';') are dropped.

    :param s: SQL text.
    :return: The statements in their original order.
    """
    trimmed_pieces = map(str.strip, s.split(";"))
    return [statement for statement in trimmed_pieces if statement]
def _ms(t0: float) -> int:
    """Return the whole milliseconds elapsed since *t0*.

    :param t0: A start timestamp taken from ``time.perf_counter()``.
    :return: Elapsed time in milliseconds, truncated to an int.
    """
    elapsed_seconds = time.perf_counter() - t0
    return int(elapsed_seconds * 1000)
class Safety:
    """Pipeline stage that accepts exactly one SELECT statement and rejects the rest.

    The input is sanitized and its quoted literals are masked before any check
    runs, so semicolons and keywords inside literals are not mistaken for SQL.
    """

    name = "safety"

    def __init__(self, allow_explain: bool = False) -> None:
        """
        :param allow_explain: If True, 'EXPLAIN SELECT ...' is allowed in addition to SELECT.
        """
        self.allow_explain = allow_explain

    def _reject(self, message: str, t0: float) -> StageResult:
        """Build the failed StageResult carrying *message* and the elapsed time since *t0*."""
        return StageResult(
            ok=False,
            error=[message],
            trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
        )

    def check(self, sql: str) -> StageResult:
        """Validate *sql* as a single read-only statement.

        :param sql: Candidate SQL text; may contain markdown fences and comments.
        :return: A StageResult with ``ok=True`` and ``data["sql"]`` / ``data["rationale"]``
            when the statement is accepted, otherwise ``ok=False`` with a
            one-element ``error`` list describing the first failed check.
        """
        t0 = time.perf_counter()

        # 1) Sanitize and mask
        masked = _mask_strings(_sanitize_sql(sql)).strip()

        # 2) Exactly one statement must remain. Empty input gets its own message
        #    instead of being reported as "multiple statements".
        stmts = _split_statements(masked)
        if not stmts:
            return self._reject("Empty SQL statement", t0)
        if len(stmts) > 1:
            return self._reject("Multiple statements detected", t0)
        body = stmts[0]

        # 3) Forbidden keyword check (report exact offending token)
        forbidden = _FORBIDDEN.search(body)
        if forbidden:
            return self._reject(
                f"Forbidden keyword detected: '{forbidden.group(0)}'", t0
            )

        # 4) Allow only SELECT (or optionally EXPLAIN SELECT)
        allowed = bool(_ALLOW_SELECT.match(body))
        if not allowed and self.allow_explain:
            allowed = bool(_ALLOW_EXPLAIN_SELECT.match(body))
        if not allowed:
            return self._reject("Non-SELECT statement", t0)

        # 5) Success
        # NOTE(review): `body` comes from the masked text, so quoted literals
        # appear as 'X' / "X" in the returned SQL — confirm callers expect that.
        return StageResult(
            ok=True,
            data={
                "sql": body,
                "rationale": (
                    "Statement validated as SELECT-only (strings/comments/markdown ignored)."
                    + (" EXPLAIN SELECT allowed." if self.allow_explain else "")
                ),
            },
            trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
        )

    # Backward-compat alias
    run = check
|