""" SQL Validator - Security layer for SQL queries. Ensures ONLY safe SELECT queries are executed. Validates against whitelist and blocks dangerous operations. """ import logging import re from typing import List, Tuple, Optional, Set import sqlparse from sqlparse.sql import Statement, Token, Identifier, IdentifierList from sqlparse.tokens import Keyword, DML logger = logging.getLogger(__name__) class SQLValidationError(Exception): """Raised when SQL validation fails.""" pass class SQLValidator: """Validates SQL queries for safety before execution.""" FORBIDDEN_KEYWORDS = { 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 'TRUNCATE', 'GRANT', 'REVOKE', 'EXECUTE', 'EXEC', 'INTO OUTFILE', 'INTO DUMPFILE', 'LOAD_FILE', 'LOAD DATA' } FORBIDDEN_PATTERNS = [ r'INTO\s+OUTFILE', r'INTO\s+DUMPFILE', r'LOAD_FILE\s*\(', r'LOAD\s+DATA', r';\s*(?:DROP|DELETE|UPDATE|INSERT)', # Multi-statement attacks r'--', # SQL comments (potential injection) r'/\*.*\*/', # Block comments ] def __init__(self, allowed_tables: Optional[Set[str]] = None, max_limit: int = 100): self.allowed_tables = allowed_tables or set() self.max_limit = max_limit self._compiled_patterns = [re.compile(p, re.IGNORECASE) for p in self.FORBIDDEN_PATTERNS] def set_allowed_tables(self, tables: List[str]): """Set the whitelist of allowed tables.""" self.allowed_tables = set(tables) def validate(self, sql: str) -> Tuple[bool, str, Optional[str]]: """ Validate SQL query for safety. Returns: Tuple of (is_valid, message, sanitized_sql) """ if not sql or not sql.strip(): return False, "Empty SQL query", None sql = sql.strip() # Check for forbidden patterns for pattern in self._compiled_patterns: if pattern.search(sql): return False, f"Forbidden pattern detected in query", None # Parse SQL try: parsed = sqlparse.parse(sql) except Exception as e: return False, f"Failed to parse SQL: {e}", None if not parsed: return False, "Failed to parse SQL query", None # Only allow single statements if len(parsed) > 1: return False, "Multiple SQL statements not allowed", None statement = parsed[0] # Check statement type stmt_type = statement.get_type() if stmt_type != 'SELECT': return False, f"Only SELECT statements allowed, got: {stmt_type}", None # Check for forbidden keywords in tokens sql_upper = sql.upper() for keyword in self.FORBIDDEN_KEYWORDS: if keyword in sql_upper: return False, f"Forbidden keyword detected: {keyword}", None # Extract and validate tables tables = self._extract_tables(statement) if self.allowed_tables: # Normalize for comparison (remove quotes, lowercase) allowed_norm = {t.lower().replace('"', '').replace('`', '') for t in self.allowed_tables} tables_norm = {t.lower().replace('"', '').replace('`', '') for t in tables} invalid_tables = tables_norm - allowed_norm if invalid_tables: return False, f"Access denied to tables: {invalid_tables}", None # Ensure LIMIT clause exists sanitized = self._ensure_limit(sql) return True, "Query validated successfully", sanitized def _extract_tables(self, statement: Statement) -> Set[str]: """Extract table names from a SELECT statement using regex.""" tables = set() sql = str(statement) # Use regex to find tables after FROM and JOIN # Pattern: FROM table_name or JOIN table_name, supporting quotes # Matches: FROM table, FROM "table", FROM `table` from_pattern = re.compile( r'\bFROM\s+(?:["`]?)([a-zA-Z0-9_]+)(?:["`]?)', re.IGNORECASE ) join_pattern = re.compile( r'\bJOIN\s+(?:["`]?)([a-zA-Z0-9_]+)(?:["`]?)', re.IGNORECASE ) # Find all FROM tables for match in from_pattern.finditer(sql): tables.add(match.group(1)) # Find all JOIN tables for match in join_pattern.finditer(sql): tables.add(match.group(1)) return tables def _ensure_limit(self, sql: str) -> str: """Ensure the query has a LIMIT clause.""" sql_upper = sql.upper() if 'LIMIT' in sql_upper: # Check if limit is too high limit_match = re.search(r'LIMIT\s+(\d+)', sql_upper) if limit_match: current_limit = int(limit_match.group(1)) if current_limit > self.max_limit: # Replace with max limit sql = re.sub( r'LIMIT\s+\d+', f'LIMIT {self.max_limit}', sql, flags=re.IGNORECASE ) return sql else: # Add LIMIT clause sql = sql.rstrip(';').strip() return f"{sql} LIMIT {self.max_limit}" _validator: Optional[SQLValidator] = None def get_sql_validator() -> SQLValidator: global _validator if _validator is None: _validator = SQLValidator() return _validator