from __future__ import annotations import re, time from nl2sql.types import StageResult, StageTrace # --- Regex utils --- _COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL) _COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE) # string literals (single & double quotes), allow escaped quotes _STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL) _STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL) # case-insensitive, word-boundary forbidden keywords _FORBIDDEN = re.compile( r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b", re.IGNORECASE, ) # allow: SELECT ... or WITH SELECT ... _ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL) def _strip_comments(s: str) -> str: s = _COMMENT_BLOCK.sub(" ", s) s = _COMMENT_LINE.sub(" ", s) return s def _mask_strings(s: str) -> str: s = _STRING_SINGLE.sub("'X'", s) s = _STRING_DOUBLE.sub('"X"', s) return s def _split_statements(s: str) -> list[str]: parts = [p.strip() for p in s.split(";")] return [p for p in parts if p] class Safety: name = "safety" def check(self, sql: str) -> StageResult: t0 = time.perf_counter() print("🧩 SQL candidate:", sql) s = _strip_comments(sql) s = _mask_strings(s).strip() stmts = _split_statements(s) if len(stmts) != 1: return StageResult( ok=False, error=["Multiple statements detected"], trace=StageTrace( stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000 ), ) body = stmts[0] if _FORBIDDEN.search(body): return StageResult( ok=False, error=["Forbidden keyword detected"], trace=StageTrace( stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000 ), ) if not _ALLOW_SELECT.match(body): return StageResult( ok=False, error=["Non-SELECT statement"], trace=StageTrace( stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000 ), ) return StageResult( ok=True, data={ "sql": sql.strip(), "rationale": "Statement validated as SELECT-only (strings/comments ignored).", }, trace=StageTrace( stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000 ), )