Spaces:
Sleeping
Sleeping
Melika Kheirieh
feat(safety): harden SQL validation (multi-CTE, recursive WITH, unicode normalization, precise errors, EXPLAIN gate)
b0bec17
| from __future__ import annotations | |
| import re | |
| import time | |
| import unicodedata | |
| from nl2sql.types import StageResult, StageTrace | |
# --- Regex utils ---
# Block comments may span lines (DOTALL); line comments run to the end of
# their own line (MULTILINE makes `$` match before each newline).
_COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
_COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)

# String literals (single & double quotes), allow escaped quotes
_STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
_STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)

# Case-insensitive, word-boundary forbidden keywords
_FORBIDDEN = re.compile(
    r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b",
    re.IGNORECASE,
)

# Allow: SELECT ... or WITH (one or many CTEs, optional RECURSIVE) ... SELECT ...
# The WITH prefix is accepted up to the first ")" that is followed (after
# optional whitespace) by SELECT. Because `.*?` already spans commas and
# parentheses under DOTALL, additional CTEs need no separate repeated group;
# an earlier form with a nested `(?:,\s*.*?\)\s*)*` repeat accepted exactly
# the same inputs but could backtrack exponentially on strings such as
# "WITH ,x),x),x)..." that never reach a SELECT.
_ALLOW_SELECT = re.compile(
    r"^(?:WITH\s+(?:RECURSIVE\s+)?.*?\)\s*)?SELECT\b",
    re.IGNORECASE | re.DOTALL,
)

# Optional allowance: EXPLAIN SELECT ...
_ALLOW_EXPLAIN_SELECT = re.compile(r"^EXPLAIN\s+SELECT\b", re.IGNORECASE | re.DOTALL)

# --- Cleanup helpers ---
# Markdown code fences: the "```sql" opener is matched case-insensitively and
# removed first, then any remaining bare "```".
_FENCE_SQL = re.compile(r"```sql", re.IGNORECASE)
_FENCE_ANY = re.compile(r"```")
def _normalize_sql(sql: str) -> str:
    """Return ``sql`` in NFKC form with zero-width characters deleted.

    NFKC folds compatibility look-alikes (for example full-width letters)
    into their plain equivalents so later keyword checks see ordinary text.
    The code points U+200B, U+200C, U+200D and U+FEFF are then removed so
    they cannot be used to split a keyword invisibly.
    """
    # Mapping a code point to None makes str.translate delete it.
    invisible = dict.fromkeys(map(ord, "\u200b\u200c\u200d\ufeff"))
    folded = unicodedata.normalize("NFKC", sql)
    return folded.translate(invisible)
# One left-to-right pass over the text: a quoted string is matched (and kept)
# before anything inside it can be mistaken for a comment; a comment that
# starts outside a string is matched (and dropped). Line comments stop at the
# newline; block comments may span lines (DOTALL).
_STRING_OR_COMMENT = re.compile(
    r"(?P<string>'(?:[^'\\]|\\.)*'"
    r'|"(?:[^"\\]|\\.)*")'
    r"|(?P<comment>/\*.*?\*/|--[^\n]*)",
    re.DOTALL,
)


def _sanitize_sql(sql: str) -> str:
    """Remove markdown fences, comments, and harmless trailing semicolons.

    Steps, in order: Unicode normalization via ``_normalize_sql``; removal of
    markdown code fences; replacement of every SQL comment with a single
    space; whitespace trimming; removal of trailing semicolons.

    Comment markers that appear inside a quoted string literal (``'--x'``,
    ``"/* y */"``) are left untouched, so the literal is not truncated.

    :param sql: Raw SQL text, possibly wrapped in markdown fences.
    :return: The cleaned SQL text.
    """
    s = _normalize_sql(sql)
    s = _FENCE_SQL.sub("", s)
    s = _FENCE_ANY.sub("", s)

    def _keep_string_drop_comment(m: re.Match) -> str:
        # Strings pass through verbatim; comments collapse to one space so
        # the tokens on either side stay separated.
        return m.group(0) if m.group("string") is not None else " "

    s = _STRING_OR_COMMENT.sub(_keep_string_drop_comment, s)
    s = s.strip()
    # remove trailing semicolon safely
    s = s.rstrip(";").strip()
    return s
# Single- or double-quoted literal, backslash escapes allowed. Both quote
# styles live in ONE pattern so the text is lexed in a single left-to-right
# pass: a quote character inside a literal of the other style can never open
# a literal of its own.
_STRING_ANY = re.compile(
    r"'(?:[^'\\]|\\.)*'"
    r'|"(?:[^"\\]|\\.)*"',
    re.DOTALL,
)


def _mask_strings(s: str) -> str:
    """Replace string literals so that inner semicolons/keywords don't affect checks.

    Every single-quoted literal becomes ``'X'`` and every double-quoted
    literal becomes ``"X"``. Masking the two quote styles in separate passes
    mis-lexes interleaved quotes (``"a'b" ; ... "d'e"`` would lose the text
    between the two apostrophes), so both are handled in one pass.

    :param s: SQL text, normally already sanitized.
    :return: The text with each literal's content replaced by ``X``.
    """

    def _placeholder(m: re.Match) -> str:
        quote = m.group(0)[0]  # the literal's own opening quote character
        return f"{quote}X{quote}"

    return _STRING_ANY.sub(_placeholder, s)
def _split_statements(s: str) -> list[str]:
    """Break ``s`` into statements at every semicolon.

    Each piece is whitespace-trimmed and empty pieces (such as the one after
    a trailing ';') are discarded. The caller is expected to have masked
    string literals first, since semicolons inside them are not recognised.
    """
    return [stmt for piece in s.split(";") if (stmt := piece.strip())]
def _ms(t0: float) -> int:
    """Return the whole milliseconds elapsed since ``t0``.

    ``t0`` is an earlier ``time.perf_counter()`` reading; the fractional
    millisecond is truncated by ``int``.
    """
    elapsed_seconds = time.perf_counter() - t0
    return int(elapsed_seconds * 1000)
class Safety:
    """Pipeline stage that accepts only a single SELECT statement.

    The input is sanitized (fences/comments removed), its string literals are
    masked, and the result must be exactly one statement that contains no
    forbidden keyword and starts with SELECT, a WITH ... SELECT, or
    (optionally) EXPLAIN SELECT.
    """

    name = "safety"

    def __init__(self, allow_explain: bool = False) -> None:
        """
        :param allow_explain: If True, 'EXPLAIN SELECT ...' is allowed in addition to SELECT.
        """
        self.allow_explain = allow_explain

    def _reject(self, message: str, t0: float) -> StageResult:
        """Build the failing StageResult shared by every rejection path."""
        return StageResult(
            ok=False,
            error=[message],
            trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
        )

    def check(self, sql: str) -> StageResult:
        """Validate ``sql`` and report whether it is a lone SELECT.

        :param sql: Candidate SQL text (may include markdown fences/comments).
        :return: ``StageResult`` with ``ok=True`` and ``data["sql"]`` /
            ``data["rationale"]`` on success, or ``ok=False`` and a
            one-element ``error`` list on rejection. The trace carries the
            stage name and elapsed milliseconds in both cases.
        """
        t0 = time.perf_counter()

        # 1) Sanitize and mask
        s = _sanitize_sql(sql)
        s = _mask_strings(s).strip()

        # 2) Statement count check. Empty input (nothing left after removing
        #    fences, comments and semicolons) is reported as such rather than
        #    being mislabelled as "multiple statements".
        stmts = _split_statements(s)
        if not stmts:
            return self._reject("Empty SQL statement", t0)
        if len(stmts) > 1:
            return self._reject("Multiple statements detected", t0)
        body = stmts[0]

        # 3) Forbidden keyword check (report exact offending token)
        m = _FORBIDDEN.search(body)
        if m:
            return self._reject(f"Forbidden keyword detected: '{m.group(0)}'", t0)

        # 4) Allow only SELECT (or optionally EXPLAIN SELECT)
        allowed = bool(_ALLOW_SELECT.match(body))
        if not allowed and self.allow_explain:
            allowed = bool(_ALLOW_EXPLAIN_SELECT.match(body))
        if not allowed:
            return self._reject("Non-SELECT statement", t0)

        # 5) Success
        # NOTE(review): `body` is the string-masked text, so literals in the
        # returned "sql" read 'X' / "X" rather than their original content;
        # confirm that callers do not execute this value expecting the
        # original literals.
        return StageResult(
            ok=True,
            data={
                "sql": body,
                "rationale": (
                    "Statement validated as SELECT-only (strings/comments/markdown ignored)."
                    + (" EXPLAIN SELECT allowed." if self.allow_explain else "")
                ),
            },
            trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
        )

    # Backward-compat alias
    run = check