import re
from typing import Dict, List, Optional, Tuple
|
|
class QueryParser:
    """Parse and validate SQL queries to prevent injection attacks.

    All checks are textual (regular expressions over the query string); the
    SQL is never parsed into a syntax tree.  Keyword checks therefore also
    see text inside string literals and comments.
    """

    # Keywords expected in read-only queries.  Not consulted by
    # validate_query; kept for callers that reference it.
    SAFE_READ_KEYWORDS = {
        'SELECT', 'FROM', 'WHERE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER',
        'ON', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'BETWEEN', 'IS', 'NULL',
        'ORDER', 'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'AS', 'DISTINCT',
        'COUNT', 'SUM', 'AVG', 'MIN', 'MAX', 'UNION', 'INTERSECT', 'EXCEPT'
    }

    # Rejected by validate_query unless allow_write=True.
    WRITE_KEYWORDS = {
        'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'ALTER', 'DROP'
    }

    # Not consulted by validate_query (CREATE TABLE is checked explicitly).
    CREATE_KEYWORDS = {'CREATE'}

    # Rejected by validate_query regardless of flags.
    ALWAYS_DANGEROUS = {
        'GRANT', 'REVOKE', 'EXEC', 'EXECUTE', 'CALL', 'DECLARE', 'CURSOR',
        'PRAGMA'
    }

    # Rejected by validate_query regardless of flags.
    DANGEROUS_COMMANDS = {
        'ATTACH', 'DETACH'
    }

    # Rejected by validate_query unless allow_schema_ops=True.
    DROP_DATABASE_PATTERNS = [
        r'DROP\s+DATABASE',
        r'DROP\s+SCHEMA'
    ]

    # Rejected by validate_query unless allow_schema_ops=True.
    CREATE_DATABASE_PATTERNS = [
        r'CREATE\s+DATABASE',
        r'CREATE\s+SCHEMA'
    ]

    # Words that can follow FROM/JOIN without being a table name.
    _NON_TABLE_WORDS = frozenset({
        'SELECT', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'LIMIT',
        'ORDER', 'GROUP', 'HAVING', 'SET', 'VALUES', 'INTO', 'AS',
        'ON', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'CROSS', 'NATURAL',
        'BETWEEN', 'IS', 'NULL', 'TRUE', 'FALSE', 'DISTINCT', 'ALL',
        'EXISTS', 'CASE', 'WHEN', 'THEN', 'ELSE', 'END', 'UNION',
        'INTERSECT', 'EXCEPT', 'WITH', 'RECURSIVE',
    })

    # "schema.table" style references.
    _QUALIFIED_TABLE_PATTERNS = (
        r'\b(?:FROM|JOIN)\s+["\']?(\w+\.\w+)["\']?',
        r'\bCREATE\s+(?:OR\s+REPLACE\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?["\']?(\w+\.\w+)["\']?',
        r'\bINSERT\s+INTO\s+["\']?(\w+\.\w+)["\']?',
        r'\bUPDATE\s+["\']?(\w+\.\w+)["\']?',
        r'\bDELETE\s+FROM\s+["\']?(\w+\.\w+)["\']?',
    )

    # Bare table name after FROM/JOIN.  The trailing lookahead accepts any
    # terminator (whitespace, ';', ')', ',', end of input) while refusing to
    # match the first half of a dotted or quoted-dotted name.
    _SOURCE_TABLE_PATTERN = r'\b(?:FROM|JOIN)\s+["\']?(\w+)["\']?(?![\w."\'])'

    # Bare table name that is the target of a write / DDL statement.
    _TARGET_TABLE_PATTERNS = (
        # The (?!IF ...) guard stops the engine from backtracking out of the
        # optional IF NOT EXISTS group and capturing "IF" as the table name.
        r'\bCREATE\s+(?:OR\s+REPLACE\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?'
        r'(?!IF\s+NOT\s+EXISTS\b)["\']?(\w+)["\']?(?![\w."\'])',
        r'\bINSERT\s+INTO\s+["\']?(\w+)["\']?(?![\w."\'])',
        r'\bUPDATE\s+["\']?(\w+)["\']?(?![\w."\'])',
        r'\bDELETE\s+FROM\s+["\']?(\w+)["\']?(?![\w."\'])',
    )

    # Lines that begin a statement with one of these prefixes are dropped by
    # split_sql_statements.
    _SKIPPED_STATEMENT_PREFIXES = ('SET ', 'START TRANSACTION', 'COMMIT')

    def __init__(self):
        # Upper bound (in characters) enforced by validate_query.
        self.max_query_length = 10000

    def sanitize_identifier(self, identifier: str) -> str:
        """Sanitize a table/column identifier.

        Every character other than word characters and '.' is removed.

        Args:
            identifier: Raw identifier text.

        Returns:
            The identifier with disallowed characters stripped.

        Raises:
            ValueError: If the raw identifier starts with '/', if the
                sanitized result contains '..', or if nothing is left after
                sanitizing.
        """
        # Check for a leading '/' on the raw text: the substitution below
        # removes '/', so testing the sanitized value could never fire.
        if identifier.lstrip().startswith('/'):
            raise ValueError(f"Invalid identifier: {identifier}")

        sanitized = re.sub(r'[^\w\.]', '', identifier)

        # An empty identifier would produce an invalid "" in generated SQL.
        if not sanitized or '..' in sanitized:
            raise ValueError(f"Invalid identifier: {identifier}")

        return sanitized

    @staticmethod
    def _contains_keyword(sql_upper: str, keyword: str) -> bool:
        """Return True if *keyword* occurs in *sql_upper* as a whole word."""
        return re.search(r'\b' + re.escape(keyword) + r'\b', sql_upper) is not None

    def validate_query(self, sql: str, allow_write: bool = False, allow_schema_ops: bool = False) -> Tuple[bool, str]:
        """
        Validate SQL query for safety

        Keywords are matched as whole words on the upper-cased query, so
        identifiers that merely contain a keyword (``updated_at``,
        ``deleted``, ``executions``) are not rejected.  Matching is textual:
        a keyword inside a string literal or comment still counts.

        Args:
            sql: SQL query to validate
            allow_write: Allow INSERT, UPDATE, DELETE, CREATE TABLE
            allow_schema_ops: Allow CREATE/DROP DATABASE/SCHEMA (for SQL imports)

        Returns:
            (is_valid, error_message); error_message is "" when valid.
        """
        if not sql or not sql.strip():
            return False, "Empty query"

        if len(sql) > self.max_query_length:
            return False, f"Query too long (max {self.max_query_length} chars)"

        sql_upper = sql.upper()
        has_keyword = self._contains_keyword

        if not allow_schema_ops:
            if any(re.search(p, sql_upper) for p in self.DROP_DATABASE_PATTERNS):
                return False, "DROP DATABASE/SCHEMA not allowed"
            if any(re.search(p, sql_upper) for p in self.CREATE_DATABASE_PATTERNS):
                return False, "CREATE DATABASE/SCHEMA not allowed via SQL. Use the 'Create Database' button or API endpoint instead."

        if not allow_write and has_keyword(sql_upper, 'CREATE') and has_keyword(sql_upper, 'TABLE'):
            return False, "CREATE TABLE requires allow_write=True"

        # Sorted so the keyword named in the message is deterministic.
        for keyword in (*sorted(self.ALWAYS_DANGEROUS), *sorted(self.DANGEROUS_COMMANDS)):
            if has_keyword(sql_upper, keyword):
                return False, f"Dangerous keyword not allowed: {keyword}"

        if not allow_write:
            for keyword in sorted(self.WRITE_KEYWORDS):
                if has_keyword(sql_upper, keyword):
                    return False, f"Write operation not allowed: {keyword} (use allow_write=True)"

        injection_patterns = [
            r';\s*DROP\s+DATABASE',
            r'UNION\s+SELECT.*FROM\s+(?![\w\.]+\s)',
            r'OR\s+1\s*=\s*1',
            r'OR\s+\'1\'\s*=\s*\'1\'',
        ]

        if any(re.search(pattern, sql_upper) for pattern in injection_patterns):
            return False, "Potential SQL injection detected"

        return True, ""

    def extract_tables(self, sql: str) -> List[str]:
        """Extract table names from SQL query.

        Looks for names following FROM/JOIN, CREATE [OR REPLACE] TABLE
        [IF NOT EXISTS], INSERT INTO, UPDATE and DELETE FROM, both bare
        (``users``) and dotted (``main.users``).  This is a regex heuristic,
        not a parser: e.g. ``EXTRACT(YEAR FROM col)`` reports ``col``.

        Args:
            sql: SQL text to scan (case-insensitive).

        Returns:
            Sorted list of distinct, sanitized table names.
        """
        found = set()
        flags = re.IGNORECASE

        for pattern in self._QUALIFIED_TABLE_PATTERNS:
            found.update(re.findall(pattern, sql, flags))

        # Only FROM/JOIN matches are filtered against SQL keywords.
        found.update(
            name
            for name in re.findall(self._SOURCE_TABLE_PATTERN, sql, flags)
            if name.upper() not in self._NON_TABLE_WORDS
        )

        for pattern in self._TARGET_TABLE_PATTERNS:
            found.update(re.findall(pattern, sql, flags))

        return sorted(self.sanitize_identifier(name) for name in found)

    def build_safe_query(self, database: str, table: str, filters: Optional[Dict] = None,
                         limit: Optional[int] = None, offset: Optional[int] = None) -> str:
        """Build a safe parameterized query.

        Args:
            database: Database/schema name; sanitized and double-quoted.
            table: Table name; sanitized and double-quoted.
            filters: Mapping whose keys become ``"key" = ?`` equality
                clauses joined with AND.  Values are not embedded; the
                caller binds them to the ``?`` placeholders in key order.
            limit: Appended as ``LIMIT n`` when truthy.
            offset: Appended as ``OFFSET n`` when truthy.

        Returns:
            The SELECT statement text.

        Raises:
            ValueError: If any identifier is rejected by sanitize_identifier.
        """
        db = self.sanitize_identifier(database)
        tbl = self.sanitize_identifier(table)

        query = f'SELECT * FROM "{db}"."{tbl}"'

        if filters:
            where_clauses = [
                f'"{self.sanitize_identifier(key)}" = ?' for key in filters
            ]
            query += ' WHERE ' + ' AND '.join(where_clauses)

        if limit:
            query += f' LIMIT {int(limit)}'
        if offset:
            query += f' OFFSET {int(offset)}'

        return query

    def is_read_only(self, sql: str) -> bool:
        """Check if query is read-only.

        Returns True when the stripped, upper-cased text starts with SELECT
        or WITH.  Only the prefix is inspected.
        """
        return sql.upper().strip().startswith(('SELECT', 'WITH'))

    @staticmethod
    def _find_statement_end(text: str, in_string: bool, quote: Optional[str]) -> Tuple[int, bool, Optional[str]]:
        """Locate the first ';' in *text* that is outside a string literal.

        Args:
            text: Fragment to scan.
            in_string: Whether the scan starts inside a string literal.
            quote: The quote character that opened that literal, if any.

        Returns:
            (index, in_string, quote): index of the terminating ';' or -1 if
            none, plus the string-literal state at the point scanning stopped.
        """
        i = 0
        length = len(text)
        while i < length:
            char = text[i]
            if in_string:
                if char == '\\':
                    # Backslash escapes the next character, whatever it is
                    # (handles both \' and \\).
                    i += 2
                    continue
                if char == quote:
                    in_string = False
                    quote = None
            elif char in ('"', "'"):
                in_string = True
                quote = char
            elif char == ';':
                return i, False, None
            i += 1
        return -1, in_string, quote

    def split_sql_statements(self, sql: str) -> List[str]:
        """Split SQL dump into individual statements

        Handles phpMyAdmin-style SQL dumps with comments and multiple statements.

        Rules (applied only outside string literals):
          * blank lines and lines starting with ``--`` are dropped;
          * a line that *begins a new statement* with ``SET ``,
            ``START TRANSACTION`` or ``COMMIT`` is dropped;
          * ``;`` ends a statement, and text after it on the same line is
            processed as the start of the next statement.

        Lines belonging to one statement are joined with single spaces.
        A trailing statement without ``;`` gets one appended.  Comments that
        follow SQL on the same line (before the ``;``) are not stripped.

        Args:
            sql: Full dump text.

        Returns:
            List of statements, each ending with ';'.
        """
        statements: List[str] = []
        pending: List[str] = []  # fragments of the statement being assembled
        in_string = False
        quote: Optional[str] = None

        for line in sql.split('\n'):
            rest = line
            while True:
                if not in_string:
                    probe = rest.strip()
                    if not probe or probe.startswith('--'):
                        break
                    # Only skip at the start of a statement, so that a
                    # continuation such as "SET a = 1" after "UPDATE t"
                    # is kept.
                    if not pending and probe.upper().startswith(self._SKIPPED_STATEMENT_PREFIXES):
                        break

                end, in_string, quote = self._find_statement_end(rest, in_string, quote)
                if end < 0:
                    pending.append(rest)
                    break

                pending.append(rest[:end + 1])
                stmt = ' '.join(pending).strip()
                pending = []
                if stmt and not stmt.startswith('--'):
                    statements.append(stmt)
                # Continue with whatever follows the ';' on this line.
                rest = rest[end + 1:]

        if pending:
            stmt = ' '.join(pending).strip()
            if stmt and not stmt.startswith('--'):
                if not stmt.endswith(';'):
                    stmt += ';'
                statements.append(stmt)

        return statements
|
|
# Module-level QueryParser instance, created when this module is imported.
query_parser: QueryParser = QueryParser()
|
|