from __future__ import annotations import re import time from nl2sql.types import StageResult, StageTrace # --- Regex utils --- _COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL) _COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE) # string literals (single & double quotes), allow escaped quotes _STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL) _STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL) # case-insensitive, word-boundary forbidden keywords _FORBIDDEN = re.compile( r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b", re.IGNORECASE, ) # allow: SELECT ... or WITH SELECT ... _ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL) # --- New cleanup helpers --- _FENCE_SQL = re.compile(r"```sql", re.IGNORECASE) _FENCE_ANY = re.compile(r"```") def _sanitize_sql(sql: str) -> str: """Remove markdown fences, comments, and surrounding junk.""" s = _FENCE_SQL.sub("", sql) s = _FENCE_ANY.sub("", s) s = _COMMENT_BLOCK.sub(" ", s) s = _COMMENT_LINE.sub(" ", s) s = s.strip() # remove trailing semicolon safely s = s.rstrip(";").strip() return s def _mask_strings(s: str) -> str: s = _STRING_SINGLE.sub("'X'", s) s = _STRING_DOUBLE.sub('"X"', s) return s def _split_statements(s: str) -> list[str]: """ Split only if there are real multiple statements, ignoring harmless trailing semicolons or markdown. """ parts = [p.strip() for p in s.split(";")] parts = [p for p in parts if p] return parts class Safety: name = "safety" def check(self, sql: str) -> StageResult: t0 = time.perf_counter() print("🧩 SQL candidate:", sql) # --- sanitize first --- s = _sanitize_sql(sql) s = _mask_strings(s).strip() stmts = _split_statements(s) if len(stmts) != 1: return StageResult( ok=False, error=["Multiple statements detected"], trace=StageTrace( stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000 ), ) body = stmts[0] if _FORBIDDEN.search(body): return StageResult( ok=False, error=["Forbidden keyword detected"], trace=StageTrace( stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000 ), ) if not _ALLOW_SELECT.match(body): return StageResult( ok=False, error=["Non-SELECT statement"], trace=StageTrace( stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000 ), ) return StageResult( ok=True, data={ "sql": body, "rationale": "Statement validated as SELECT-only (strings/comments/markdown ignored).", }, trace=StageTrace( stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000 ), ) run = check