Spaces:
Running
Running
File size: 2,379 Bytes
"""SQL safety validation.

Rejects any query that is not a pure SELECT statement.
"""
import re
# Keywords that can modify data, schema, or privileges. A query containing any
# of them as a whole word is rejected, even when the word only appears inside
# a string literal or quoted identifier: the validator deliberately prefers a
# false rejection over letting a write through.
_FORBIDDEN_KEYWORDS = [
    r"\bDROP\b",
    r"\bDELETE\b",
    r"\bUPDATE\b",
    r"\bALTER\b",
    r"\bTRUNCATE\b",
    r"\bINSERT\b",
    r"\bCREATE\b",
    r"\bGRANT\b",
    r"\bREVOKE\b",
    r"\bEXEC\b",
    r"\bEXECUTE\b",
    # MERGE may follow a CTE, and SELECT ... INTO creates a table (or writes
    # a file) in several SQL dialects; both start with SELECT/WITH and would
    # otherwise slip through.
    r"\bMERGE\b",
    r"\bINTO\b",
]
_FORBIDDEN_PATTERN = re.compile("|".join(_FORBIDDEN_KEYWORDS), re.IGNORECASE)

# A query must open with SELECT or WITH (common table expression).
_ALLOWED_START = re.compile(r"\s*(?:SELECT|WITH)\b", re.IGNORECASE)


def validate_sql(sql: str) -> tuple[bool, str]:
    """Check if a SQL string is safe to execute.

    The query is accepted only when it is a single statement that starts
    with SELECT or WITH and contains none of the forbidden keywords. The
    check is purely textual (no SQL parsing), so it errs on the side of
    rejecting: a forbidden word or a semicolon inside a string literal or
    comment also causes rejection.

    Parameters
    ----------
    sql
        The SQL text to validate. Surrounding whitespace and trailing
        semicolons are ignored.

    Returns
    -------
    (is_safe, reason)
        ``is_safe`` is True when the query passed every check, in which
        case ``reason`` is the empty string. Otherwise ``reason`` is a
        short human-readable explanation of the first failed check.
    """
    stripped = sql.strip()
    # Drop every trailing terminator (including forms like "; ;") so that a
    # semicolon left in the text can only be a statement separator.
    while stripped.endswith(";"):
        stripped = stripped[:-1].rstrip()

    if not stripped:
        return False, "Empty query."

    # Must start with SELECT or WITH (CTE)
    if not _ALLOWED_START.match(stripped):
        return False, "Only SELECT queries are allowed."

    # Check for forbidden keywords
    match = _FORBIDDEN_PATTERN.search(stripped)
    if match:
        return False, f"Forbidden keyword detected: {match.group().upper()}"

    # Any remaining semicolon means a second statement is stacked after the
    # SELECT. Checked last so queries that were already rejected keep their
    # original, more specific reason.
    if ";" in stripped:
        return False, "Multiple statements are not allowed."

    return True, ""
# FROM used as argument syntax inside a function call, e.g.
# EXTRACT(YEAR FROM col) or SUBSTRING(s FROM 1); it does not introduce a table.
_FUNCTION_FROM = re.compile(
    r"\b(?:EXTRACT|SUBSTRING|TRIM|OVERLAY)\s*\([^()]*?\bFROM\b",
    re.IGNORECASE,
)
# "name AS (" / "name (cols) AS (" -- the name being defined by a WITH clause.
_CTE_NAME = re.compile(
    r'"?\b(\w+)\b"?\s*(?:\([^()]*\)\s*)?\bAS\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\(',
    re.IGNORECASE,
)
# Optionally quoted, optionally dot-qualified identifier after FROM or JOIN.
# The leading \b keeps column names such as "valid_from" from matching, and
# the lookahead skips table functions such as generate_series(...).
_TABLE_REF = re.compile(
    r'\b(?:FROM|JOIN)\s+((?:"?\w+"?\s*\.\s*)*"?\w+\b"?)(?!\s*[.(])',
    re.IGNORECASE,
)
_GROUP_BY = re.compile(r"\bGROUP\s+BY\b", re.IGNORECASE)
_AGGREGATE_CALL = re.compile(r"(?:SUM|COUNT|AVG|MIN|MAX)\s*\(", re.IGNORECASE)


def check_sql_against_schema(
    sql: str, schema: dict[str, list[dict]]
) -> tuple[bool, list[str]]:
    """Programmatically check that tables referenced in SQL exist in the schema.

    This is a lightweight, regex-based check, not a SQL parser:

    * Every identifier that follows FROM or JOIN must be a key of ``schema``
      (compared case-insensitively) or the name of a CTE defined in the
      query. For a dot-qualified reference such as ``public.users`` either
      the full name or its last segment may match.
    * If the query has a GROUP BY clause, it must contain at least one call
      to SUM, COUNT, AVG, MIN or MAX.

    Column names are not validated; only the keys of ``schema`` are read.
    Known limitations: tables listed after a comma (``FROM a, b``) are not
    inspected, and FROM inside string literals or nested function arguments
    can still be misread as a table reference.

    Parameters
    ----------
    sql
        The SQL text to inspect.
    schema
        Mapping of table name to that table's column descriptions.

    Returns
    -------
    (is_valid, list_of_issues)
        ``is_valid`` is True exactly when ``list_of_issues`` is empty. One
        issue is reported per offending reference, in order of appearance.
    """
    issues: list[str] = []

    known_tables = {name.lower() for name in schema}
    # CTE names are legal FROM/JOIN targets even though they are not tables.
    cte_names = {name.lower() for name in _CTE_NAME.findall(sql)}

    # Remove function-argument FROMs so they are not mistaken for FROM clauses.
    scan_text = _FUNCTION_FROM.sub("(", sql)

    # Extract table references (FROM / JOIN)
    for ref in _TABLE_REF.findall(scan_text):
        parts = [part.strip() for part in ref.replace('"', "").split(".")]
        display = ".".join(parts)
        table = parts[-1].lower()
        if table in known_tables or display.lower() in known_tables:
            continue
        # Only an unqualified reference can point at a CTE.
        if len(parts) == 1 and table in cte_names:
            continue
        issues.append(f"Table '{display}' not found in schema")

    # Basic check: if GROUP BY is present, expect an aggregate call somewhere
    # (lightweight check, not full SQL parsing). Whitespace-tolerant so that
    # "GROUP\nBY" and "COUNT (*)" are recognised.
    if (
        _GROUP_BY.search(sql)
        and "SELECT" in sql.upper()
        and not _AGGREGATE_CALL.search(sql)
    ):
        issues.append("GROUP BY present but no aggregation function found")

    return (not issues, issues)
|