Spaces:
Running
Running
File size: 5,859 Bytes
f9ad313 a8441ef f9ad313 a8441ef f9ad313 a8441ef f9ad313 a8441ef f9ad313 a8441ef f9ad313 a8441ef f9ad313 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
"""
SQL Validator - Security layer for SQL queries.
Ensures ONLY safe SELECT queries are executed.
Validates against whitelist and blocks dangerous operations.
"""
import logging
import re
from typing import List, Tuple, Optional, Set
import sqlparse
from sqlparse.sql import Statement, Token, Identifier, IdentifierList
from sqlparse.tokens import Keyword, DML
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class SQLValidationError(Exception):
    """Exception type representing a failed SQL validation."""
class SQLValidator:
    """Validates SQL queries for safety before execution.

    A query is accepted only when it is a single SELECT statement, contains
    none of the forbidden patterns or keywords, and (when a whitelist is
    configured) references only whitelisted tables. Accepted queries are
    returned with a LIMIT no larger than ``max_limit``.

    The checks are textual and deliberately strict: forbidden words inside
    string literals are rejected too.
    """

    # Statements/functions that modify data or touch the file system.
    # Matched as whole words, so identifiers such as ``created_at`` pass.
    FORBIDDEN_KEYWORDS = {
        'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER',
        'TRUNCATE', 'GRANT', 'REVOKE', 'EXECUTE', 'EXEC',
        'INTO OUTFILE', 'INTO DUMPFILE', 'LOAD_FILE', 'LOAD DATA',
    }

    FORBIDDEN_PATTERNS = [
        r'INTO\s+OUTFILE',
        r'INTO\s+DUMPFILE',
        r'LOAD_FILE\s*\(',
        r'LOAD\s+DATA',
        r';\s*(?:DROP|DELETE|UPDATE|INSERT)',  # Multi-statement attacks
        r'--',  # SQL comments (potential injection)
        # Any block-comment opener: also catches multi-line and
        # unterminated comments, which a single-line ``/*...*/`` match misses.
        r'/\*',
    ]

    # Building blocks for table extraction. Identifiers may be wrapped in
    # double quotes or backticks and may be schema-qualified (``a.b``).
    _IDENTIFIER = r'["`]?[\w$]+["`]?'
    _QUALIFIED_NAME = _IDENTIFIER + r'(?:\s*\.\s*' + _IDENTIFIER + r')*'
    _TABLE_REF = _QUALIFIED_NAME + r'(?:\s+(?:AS\s+)?' + _IDENTIFIER + r')?'
    # The table list is captured inside a lookahead so that a keyword
    # swallowed as an "alias" (e.g. the JOIN in ``FROM a JOIN b``) is still
    # available as the start of the next match.
    _TABLE_LIST_RE = re.compile(
        r'\b(?:FROM|JOIN)\s+(?=('
        + _TABLE_REF + r'(?:\s*,\s*' + _TABLE_REF + r')*))',
        re.IGNORECASE,
    )
    _QUALIFIED_NAME_RE = re.compile(_QUALIFIED_NAME)

    # LIMIT used as a keyword: not part of a longer identifier
    # (``rate_limit``, ``limits``), not quoted and not a column (``t.limit``).
    _LIMIT_KEYWORD_RE = re.compile(r'(?<![\w"`.])LIMIT(?![\w"`])', re.IGNORECASE)
    # ``LIMIT n``, ``LIMIT offset, n`` or ``LIMIT ALL``.
    _LIMIT_VALUE_RE = re.compile(
        r'(?<![\w"`.])LIMIT\s+(?:(\d+)\s*,\s*)?(\d+|ALL\b)',
        re.IGNORECASE,
    )

    def __init__(self, allowed_tables: Optional[Set[str]] = None, max_limit: int = 100):
        """Create a validator.

        Args:
            allowed_tables: Whitelist of table names. When empty or None,
                no table check is performed.
            max_limit: Largest row limit a validated query may carry.
        """
        # Copy so later changes to the caller's collection do not silently
        # alter the whitelist.
        self.allowed_tables = set(allowed_tables) if allowed_tables else set()
        self.max_limit = max_limit
        self._compiled_patterns = [
            re.compile(p, re.IGNORECASE) for p in self.FORBIDDEN_PATTERNS
        ]
        self._keyword_pattern = self._build_keyword_pattern()

    def _build_keyword_pattern(self):
        """Compile FORBIDDEN_KEYWORDS into one whole-word regex (or None)."""
        if not self.FORBIDDEN_KEYWORDS:
            return None
        # Longest first so multi-word keywords win over their prefixes;
        # spaces inside a keyword match any run of whitespace.
        ordered = sorted(self.FORBIDDEN_KEYWORDS, key=len, reverse=True)
        alternatives = '|'.join(
            r'\s+'.join(re.escape(word) for word in keyword.split())
            for keyword in ordered
        )
        return re.compile(r'\b(' + alternatives + r')\b', re.IGNORECASE)

    def _find_forbidden_keyword(self, sql: str) -> Optional[str]:
        """Return the first forbidden keyword found in *sql*, or None.

        The keyword is returned upper-cased with single spaces, e.g.
        ``'UPDATE'`` or ``'LOAD DATA'``.
        """
        pattern = getattr(self, '_keyword_pattern', None)
        if pattern is None:
            pattern = self._build_keyword_pattern()
        if pattern is None:
            return None
        found = pattern.search(sql)
        if found is None:
            return None
        return ' '.join(found.group(1).upper().split())

    @staticmethod
    def _normalize_table(name: str) -> str:
        """Lower-case a table name and drop double quotes and backticks."""
        return name.lower().replace('"', '').replace('`', '')

    def set_allowed_tables(self, tables: List[str]):
        """Set the whitelist of allowed tables."""
        self.allowed_tables = set(tables)

    def validate(self, sql: str) -> Tuple[bool, str, Optional[str]]:
        """Validate SQL query for safety.

        Args:
            sql: The raw SQL text to check.

        Returns:
            Tuple of (is_valid, message, sanitized_sql). ``sanitized_sql`` is
            the query with an enforced LIMIT when valid, otherwise None.
        """
        if not sql or not sql.strip():
            return False, "Empty SQL query", None
        sql = sql.strip()

        # Check for forbidden patterns
        if any(pattern.search(sql) for pattern in self._compiled_patterns):
            return False, "Forbidden pattern detected in query", None

        # Parse SQL. Any parser failure is reported as a rejection rather
        # than propagated, so callers always get the result tuple.
        try:
            parsed = sqlparse.parse(sql)
        except Exception as e:
            return False, f"Failed to parse SQL: {e}", None
        if not parsed:
            return False, "Failed to parse SQL query", None

        # Only allow single statements
        if len(parsed) > 1:
            return False, "Multiple SQL statements not allowed", None
        statement = parsed[0]

        # Check statement type
        # NOTE(review): `SELECT ... INTO new_table` appears to be typed as
        # SELECT and INTO alone is not forbidden here - confirm whether the
        # target database supports that form.
        stmt_type = statement.get_type()
        if stmt_type != 'SELECT':
            return False, f"Only SELECT statements allowed, got: {stmt_type}", None

        # Check for forbidden keywords (whole words only, so column names
        # like created_at / updated_at are not rejected).
        keyword = self._find_forbidden_keyword(sql)
        if keyword is not None:
            return False, f"Forbidden keyword detected: {keyword}", None

        # Extract and validate tables
        tables = self._extract_tables(statement)
        if self.allowed_tables:
            # Normalize for comparison (remove quotes, lowercase)
            allowed_norm = {self._normalize_table(t) for t in self.allowed_tables}
            tables_norm = {self._normalize_table(t) for t in tables}
            invalid_tables = tables_norm - allowed_norm
            if invalid_tables:
                return False, f"Access denied to tables: {invalid_tables}", None

        # Ensure LIMIT clause exists
        sanitized = self._ensure_limit(sql)
        return True, "Query validated successfully", sanitized

    def _extract_tables(self, statement: "Statement") -> Set[str]:
        """Extract table names from a SELECT statement using regex.

        Every name in a comma-separated list after FROM or JOIN is
        collected (``FROM a, b`` yields both). Schema-qualified names are
        returned in full (``public.users``). Quotes, backticks and aliases
        are stripped; the original letter case is kept.

        Args:
            statement: Parsed statement; only its string form is used.

        Returns:
            Set of referenced table names.
        """
        tables = set()
        sql = str(statement)
        for match in self._TABLE_LIST_RE.finditer(sql):
            for table_ref in match.group(1).split(','):
                name = self._QUALIFIED_NAME_RE.match(table_ref.strip())
                if name:
                    tables.add(re.sub(r'["`\s]', '', name.group(0)))
        return tables

    def _ensure_limit(self, sql: str) -> str:
        """Ensure the query has a LIMIT clause no larger than ``max_limit``.

        If the query has no LIMIT keyword, ``LIMIT max_limit`` is appended
        (after dropping a trailing semicolon). Otherwise each ``LIMIT n``,
        ``LIMIT offset, n`` and ``LIMIT ALL`` whose row count exceeds
        ``max_limit`` is clamped individually. A LIMIT with a non-numeric
        value (e.g. a bind placeholder) is left untouched.
        """
        if not self._LIMIT_KEYWORD_RE.search(sql):
            # Add LIMIT clause
            sql = sql.rstrip(';').strip()
            return f"{sql} LIMIT {self.max_limit}"

        def clamp(match):
            offset, count = match.group(1), match.group(2)
            if count.isdigit() and int(count) <= self.max_limit:
                return match.group(0)
            if offset is not None:
                return f'LIMIT {offset}, {self.max_limit}'
            return f'LIMIT {self.max_limit}'

        return self._LIMIT_VALUE_RE.sub(clamp, sql)
# Lazily created validator instance shared by all callers of get_sql_validator().
_validator: Optional[SQLValidator] = None


def get_sql_validator() -> SQLValidator:
    """Return the shared SQLValidator, constructing it with defaults on first call."""
    global _validator
    if _validator is not None:
        return _validator
    _validator = SQLValidator()
    return _validator
|