Spaces:
Running
Running
File size: 3,190 Bytes
570f7bd 105e019 570f7bd c1bc4eb 570f7bd 370553a c1bc4eb 370553a 570f7bd 370553a 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd 370553a 570f7bd 370553a 570f7bd c1bc4eb 570f7bd 370553a 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd c1bc4eb 570f7bd 370553a 570f7bd c1bc4eb 570f7bd a337fad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
from __future__ import annotations
import re
import time
from nl2sql.types import StageResult, StageTrace
# --- Regex utils ---
# Block comments: "/*" up to the nearest "*/"; DOTALL lets one comment span lines.
_COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
# Line comments: "--" up to the end of that line (MULTILINE makes "$" match per line).
_COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)
# String literals (single & double quotes); a backslash escapes the next character,
# so an escaped quote does not terminate the literal.
# A doubled quote ('') is not modeled as an escape; 'it''s' matches as two adjacent literals.
# NOTE(review): backslash escaping is dialect-specific — confirm it matches the
# target database's literal syntax.
_STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
_STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
# Case-insensitive, word-boundary forbidden keywords.
# The match is purely lexical: the bare word is rejected wherever it appears,
# including as an identifier or a function name such as replace(...).
# NOTE(review): presumably callers mask string literals before searching — verify.
_FORBIDDEN = re.compile(
r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b",
re.IGNORECASE,
)
# Allow: SELECT ... or WITH <cte...> SELECT ...
# Anchored at the start of the text; the optional WITH prefix is consumed lazily
# up to a ")" that is followed (after optional whitespace) by SELECT.
_ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)
# --- New cleanup helpers ---
# Markdown code-fence markers: the "```sql" opener (any letter case) and any bare "```".
_FENCE_SQL = re.compile(r"```sql", re.IGNORECASE)
_FENCE_ANY = re.compile(r"```")
def _sanitize_sql(sql: str) -> str:
    """Remove markdown fences, SQL comments, and a trailing semicolon.

    Comments are stripped in a single left-to-right scan that also recognises
    quoted literals. Whichever construct starts first wins, so comment markers
    that appear inside a literal (``'--'``, ``'/*'``) are preserved instead of
    swallowing the rest of the statement.

    Args:
        sql: Raw SQL candidate, possibly wrapped in markdown code fences.

    Returns:
        The text with fences removed, each comment replaced by one space,
        surrounding whitespace trimmed, and trailing semicolons dropped.
    """
    text = _FENCE_SQL.sub("", sql)
    text = _FENCE_ANY.sub("", text)

    # Build one alternation from the module-level patterns so literals and
    # comments are tokenised together. `re` caches compiled patterns, so
    # recompiling on each call is cheap.
    scanner = re.compile(
        f"(?P<literal>{_STRING_SINGLE.pattern}|{_STRING_DOUBLE.pattern})"
        f"|{_COMMENT_BLOCK.pattern}|{_COMMENT_LINE.pattern}",
        re.DOTALL | re.MULTILINE,
    )

    def _drop_comment(match: re.Match[str]) -> str:
        # Literals pass through unchanged; anything else matched is a comment.
        if match.group("literal") is not None:
            return match.group(0)
        return " "

    text = scanner.sub(_drop_comment, text).strip()
    # remove trailing semicolon safely
    return text.rstrip(";").strip()
def _mask_strings(s: str) -> str:
    """Replace every quoted literal in *s* with a fixed placeholder.

    Single-quoted literals become ``'X'`` and double-quoted ones ``"X"``.
    Both quote styles are matched in one left-to-right pass, so a quote
    character that sits inside a literal of the other style cannot open a
    bogus literal and hide the text that follows it.

    Args:
        s: SQL text, expected to be already stripped of comments.

    Returns:
        The text with literal contents replaced by placeholders.
    """
    # Build one alternation from the module-level patterns; `re` caches
    # compiled patterns, so recompiling on each call is cheap.
    literal = re.compile(
        f"(?P<single>{_STRING_SINGLE.pattern})|(?P<double>{_STRING_DOUBLE.pattern})",
        re.DOTALL,
    )

    def _placeholder(match: re.Match[str]) -> str:
        # Keep the quote style of whichever literal matched.
        return "'X'" if match.group("single") is not None else '"X"'

    return literal.sub(_placeholder, s)
def _split_statements(s: str) -> list[str]:
    """Split *s* on semicolons and return the non-empty, trimmed pieces.

    Pieces that are empty or whitespace-only (e.g. the tail after a trailing
    semicolon) are dropped, so a single statement yields a one-item list.
    """
    return [piece for piece in map(str.strip, s.split(";")) if piece]
class Safety:
    """Pipeline stage that accepts only a single read-only SELECT statement."""

    name = "safety"

    def _trace(self, started: float) -> StageTrace:
        """Build the stage trace with the time elapsed since *started*."""
        return StageTrace(
            stage=self.name, duration_ms=(time.perf_counter() - started) * 1000
        )

    def _reject(self, message: str, started: float) -> StageResult:
        """Build a failed result carrying a single error message."""
        return StageResult(ok=False, error=[message], trace=self._trace(started))

    def check(self, sql: str) -> StageResult:
        """Validate that *sql* is one SELECT (optionally WITH ... SELECT) statement.

        The candidate is cleaned of markdown fences and comments, then string
        literals are masked so their contents cannot trigger or hide a check.
        All checks run on the masked text.

        Args:
            sql: SQL candidate text.

        Returns:
            A StageResult. On success, ``data["sql"]`` holds the cleaned
            statement with its original literals intact (never the masked
            text) and ``data["rationale"]`` a short explanation. On failure,
            ``error`` holds one message naming the reason.
        """
        t0 = time.perf_counter()
        # NOTE(review): debug output on stdout; consider routing through logging.
        print("🧩 SQL candidate:", sql)

        # --- sanitize first ---
        cleaned = _sanitize_sql(sql)
        masked = _mask_strings(cleaned).strip()

        stmts = _split_statements(masked)
        if not stmts:
            return self._reject("Empty SQL statement", t0)
        if len(stmts) > 1:
            return self._reject("Multiple statements detected", t0)

        body = stmts[0]
        if _FORBIDDEN.search(body):
            return self._reject("Forbidden keyword detected", t0)
        if not _ALLOW_SELECT.match(body):
            return self._reject("Non-SELECT statement", t0)

        # Exactly one statement survived, so any semicolon outside a literal
        # is a stray leading/trailing separator. Trimming those from the
        # unmasked text yields the statement with its real literal values.
        statement = cleaned.strip("; \t\r\n\f\v")
        return StageResult(
            ok=True,
            data={
                "sql": statement,
                "rationale": "Statement validated as SELECT-only (strings/comments/markdown ignored).",
            },
            trace=self._trace(t0),
        )

    run = check
|