import re
from typing import Dict, List, Optional, Tuple


class QueryParser:
    """Parse and validate SQL queries to prevent injection attacks.

    All checks are keyword/regex heuristics over the raw SQL text; no real
    SQL parsing is performed. Keyword checks use word boundaries so that
    identifiers such as ``updated_at`` or ``call_logs`` are not mistaken for
    the ``UPDATE`` / ``CALL`` commands.
    """

    # Allowed SQL keywords for read operations
    SAFE_READ_KEYWORDS = {
        'SELECT', 'FROM', 'WHERE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER',
        'ON', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'BETWEEN', 'IS', 'NULL',
        'ORDER', 'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'AS', 'DISTINCT',
        'COUNT', 'SUM', 'AVG', 'MIN', 'MAX', 'UNION', 'INTERSECT', 'EXCEPT'
    }

    # Write operations that need allow_write=True
    WRITE_KEYWORDS = {
        'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'ALTER', 'DROP'
    }

    # CREATE operations (allowed for tables, blocked for databases)
    CREATE_KEYWORDS = {'CREATE'}

    # Always dangerous - never allow (must be standalone commands)
    ALWAYS_DANGEROUS = {
        'GRANT', 'REVOKE', 'EXEC', 'EXECUTE', 'CALL', 'DECLARE', 'CURSOR',
        'PRAGMA'
    }

    # Dangerous commands that need word boundary checking
    DANGEROUS_COMMANDS = {
        'ATTACH', 'DETACH'
    }

    # Keywords that need special validation
    DROP_DATABASE_PATTERNS = [
        r'DROP\s+DATABASE',
        r'DROP\s+SCHEMA'
    ]

    CREATE_DATABASE_PATTERNS = [
        r'CREATE\s+DATABASE',
        r'CREATE\s+SCHEMA'
    ]

    # Common SQL injection shapes (comments are allowed for phpMyAdmin imports)
    _INJECTION_PATTERNS = (
        # Only block DROP DATABASE after semicolon
        r';\s*DROP\s+DATABASE',
        # UNION injection: FROM must be followed by a plain identifier that is
        # terminated by whitespace, ';', ')' or end of input. `(?!\s)` stops
        # `\s+` from backtracking into a false hit on repeated spaces.
        r'UNION\s+SELECT.*FROM\s+(?!\s)(?![\w\.]+(?:[\s;)]|$))',
        r'OR\s+1\s*=\s*1',
        r"OR\s+'1'\s*=\s*'1'",
    )

    # Statement prefixes dropped by split_sql_statements
    _SKIPPED_STATEMENT_PREFIXES = ('SET ', 'START TRANSACTION', 'COMMIT')

    # Words that may follow FROM/JOIN but are never table names
    _NON_TABLE_WORDS = frozenset({
        'SELECT', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'LIMIT',
        'ORDER', 'GROUP', 'HAVING', 'SET', 'VALUES', 'INTO', 'AS', 'ON',
        'LEFT', 'RIGHT', 'INNER', 'OUTER', 'CROSS', 'NATURAL', 'BETWEEN',
        'IS', 'NULL', 'TRUE', 'FALSE', 'DISTINCT', 'ALL', 'EXISTS', 'CASE',
        'WHEN', 'THEN', 'ELSE', 'END', 'UNION', 'INTERSECT', 'EXCEPT',
        'WITH', 'RECURSIVE'
    })

    # (regex prefix that precedes a table name, apply _NON_TABLE_WORDS filter?)
    _TABLE_CONTEXTS = (
        (r'\b(?:FROM|JOIN)\s+', True),
        (r'\bCREATE\s+(?:OR\s+REPLACE\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?',
         False),
        (r'\bINSERT\s+INTO\s+', False),
        (r'\bUPDATE\s+', False),
        (r'\bDELETE\s+FROM\s+', False),
    )

    def __init__(self):
        self.max_query_length = 10000

    @staticmethod
    def _find_keyword(sql_upper: str, keywords) -> Optional[str]:
        """Return the first of *keywords* found as a whole word, else None."""
        if not keywords:
            return None
        alternation = '|'.join(re.escape(word) for word in sorted(keywords))
        match = re.search(rf'\b(?:{alternation})\b', sql_upper)
        return match.group(0) if match else None

    def sanitize_identifier(self, identifier: str) -> str:
        """Sanitize a table/column identifier.

        Every character other than word characters and ``.`` is removed.

        Args:
            identifier: Raw identifier (optionally dotted, e.g. ``db.table``).

        Returns:
            The sanitized identifier.

        Raises:
            ValueError: If nothing is left after sanitizing, if the result
                contains ``..``, or if the raw input starts with ``/``.
        """
        sanitized = re.sub(r'[^\w\.]', '', identifier)

        # Prevent path traversal. The leading-slash test must look at the raw
        # input: '/' never survives the substitution above.
        if (not sanitized or '..' in sanitized
                or identifier.lstrip().startswith('/')):
            raise ValueError(f"Invalid identifier: {identifier}")

        return sanitized

    def validate_query(self, sql: str, allow_write: bool = False,
                       allow_schema_ops: bool = False) -> Tuple[bool, str]:
        """Validate SQL query for safety.

        Args:
            sql: SQL query to validate
            allow_write: Allow INSERT, UPDATE, DELETE, CREATE TABLE
            allow_schema_ops: Allow CREATE/DROP DATABASE/SCHEMA
                (for SQL imports)

        Returns:
            (is_valid, error_message); error_message is "" when valid.
        """
        if not sql or not sql.strip():
            return False, "Empty query"

        if len(sql) > self.max_query_length:
            return False, f"Query too long (max {self.max_query_length} chars)"

        # Normalize query
        sql_upper = sql.upper()

        # DROP/CREATE DATABASE/SCHEMA are blocked unless allow_schema_ops
        if not allow_schema_ops:
            if any(re.search(pattern, sql_upper)
                   for pattern in self.DROP_DATABASE_PATTERNS):
                return False, "DROP DATABASE/SCHEMA not allowed"
            if any(re.search(pattern, sql_upper)
                   for pattern in self.CREATE_DATABASE_PATTERNS):
                return False, "CREATE DATABASE/SCHEMA not allowed via SQL. Use the 'Create Database' button or API endpoint instead."

        # CREATE TABLE (allowed with allow_write=True). Word boundaries keep
        # identifiers like "created_at" / "tables_x" from triggering this.
        # NOTE(review): other CREATE forms (INDEX, VIEW, ...) are not gated by
        # allow_write here and CREATE_KEYWORDS is unused - confirm intent.
        if not allow_write and re.search(r'\bCREATE\b.*\bTABLE\b', sql_upper,
                                         re.DOTALL):
            return False, "CREATE TABLE requires allow_write=True"

        # Commands that are never allowed (standalone words only)
        dangerous = self._find_keyword(
            sql_upper, self.ALWAYS_DANGEROUS | self.DANGEROUS_COMMANDS)
        if dangerous:
            return False, f"Dangerous keyword not allowed: {dangerous}"

        # Write operations
        if not allow_write:
            write_op = self._find_keyword(sql_upper, self.WRITE_KEYWORDS)
            if write_op:
                return False, f"Write operation not allowed: {write_op} (use allow_write=True)"

        # Common SQL injection patterns
        for pattern in self._INJECTION_PATTERNS:
            if re.search(pattern, sql_upper):
                return False, "Potential SQL injection detected"

        return True, ""

    def extract_tables(self, sql: str) -> List[str]:
        """Extract table names from SQL query.

        Looks after FROM/JOIN, CREATE TABLE, INSERT INTO, UPDATE and
        DELETE FROM for qualified (``db.table``) and unqualified names,
        optionally wrapped in ``"``, ``'`` or a backtick.

        Returns:
            Sorted list of distinct, sanitized table names.
        """
        found = set()
        quote = r'["\'`]?'

        for prefix, filter_keywords in self._TABLE_CONTEXTS:
            # Qualified: db.table
            qualified = prefix + quote + r'(\w+\.\w+)' + quote
            found.update(re.findall(qualified, sql, re.IGNORECASE))

            # Unqualified: the name must be terminated (whitespace, ';', ')',
            # ',' or end of input) so the "db" of "db.table" is not captured.
            unqualified = prefix + quote + r'(\w+)' + quote + r'(?=[\s;),]|$)'
            for name in re.findall(unqualified, sql, re.IGNORECASE):
                if filter_keywords and name.upper() in self._NON_TABLE_WORDS:
                    continue
                found.add(name)

        return sorted(self.sanitize_identifier(name) for name in found)

    def build_safe_query(self, database: str, table: str,
                         filters: Optional[Dict] = None,
                         limit: Optional[int] = None,
                         offset: Optional[int] = None) -> str:
        """Build a safe parameterized SELECT query.

        Args:
            database: Database name (sanitized before use).
            table: Table name (sanitized before use).
            filters: Column -> value mapping. Only the keys are used; each
                becomes a ``"col" = ?`` clause and the values are expected to
                be bound separately by the caller.
            limit: Optional row limit. ``0`` is honoured as ``LIMIT 0``.
            offset: Optional row offset; ``0`` adds no clause.

        Returns:
            The SQL text with ``?`` placeholders.

        Raises:
            ValueError: If any identifier is rejected by sanitize_identifier.
        """
        db = self.sanitize_identifier(database)
        tbl = self.sanitize_identifier(table)

        query = f'SELECT * FROM "{db}"."{tbl}"'

        if filters:
            # Use parameterized queries - values will be bound separately
            where_clauses = [
                f'"{self.sanitize_identifier(key)}" = ?' for key in filters
            ]
            query += ' WHERE ' + ' AND '.join(where_clauses)

        # `is not None` so that limit=0 yields LIMIT 0 instead of all rows
        if limit is not None:
            query += f' LIMIT {int(limit)}'

        # OFFSET 0 is a no-op, so a falsy offset adds nothing
        if offset:
            query += f' OFFSET {int(offset)}'

        return query

    def is_read_only(self, sql: str) -> bool:
        """Check if query is read-only.

        Prefix heuristic only: true when the stripped, upper-cased text
        starts with SELECT or WITH.
        """
        return sql.upper().strip().startswith(('SELECT', 'WITH'))

    def split_sql_statements(self, sql: str) -> List[str]:
        """Split SQL dump into individual statements.

        Handles phpMyAdmin-style SQL dumps with comments and multiple
        statements:

        * ``;`` outside a quoted string ends a statement; several statements
          may share one line.
        * ``--`` comments outside strings are removed.
        * Blank lines are skipped and the lines of one statement are joined
          with single spaces.
        * Whole statements starting with ``SET ``, ``START TRANSACTION`` or
          ``COMMIT`` are dropped.
        * A trailing statement without ``;`` gets one appended.

        Returns:
            The statements in source order.
        """
        statements: List[str] = []
        pending: List[str] = []   # fragments of the statement being built
        in_string = False
        string_char = None

        def flush(add_terminator: bool = False) -> None:
            stmt = ' '.join(pending).strip()
            pending.clear()
            # Nothing but terminators, or a leftover comment
            if not stmt.rstrip(';').strip() or stmt.startswith('--'):
                return
            # Filter on the assembled statement, not per line, so that the
            # SET clause of a multi-line UPDATE is never dropped.
            if stmt.upper().startswith(self._SKIPPED_STATEMENT_PREFIXES):
                return
            if add_terminator and not stmt.endswith(';'):
                stmt += ';'
            statements.append(stmt)

        for line in sql.split('\n'):
            stripped = line.strip()

            # Skip empty lines
            if not stripped:
                continue

            # Skip comment lines (a string literal may legitimately span
            # lines, so only when we are not inside one)
            if not in_string and stripped.startswith('--'):
                continue

            segment_start = 0
            cut = len(line)
            escaped = False
            for i, char in enumerate(line):
                if in_string:
                    if escaped:
                        escaped = False
                    elif char == '\\':
                        # Next character is escaped (handles \' and \\)
                        escaped = True
                    elif char == string_char:
                        in_string = False
                        string_char = None
                elif char in ('"', "'"):
                    in_string = True
                    string_char = char
                elif char == ';':
                    pending.append(line[segment_start:i + 1])
                    flush()
                    segment_start = i + 1
                elif char == '-' and line.startswith('--', i):
                    rest = line[i + 2:i + 3]
                    # Comment if followed by whitespace/end of line, or if it
                    # opens a fresh segment (same rule as comment lines).
                    if (not rest or rest.isspace()
                            or not line[segment_start:i].strip()):
                        cut = i
                        break

            remainder = line[segment_start:cut]
            if remainder.strip():
                pending.append(remainder)

        # Add any remaining statement (semicolon added if missing)
        flush(add_terminator=True)

        return statements


query_parser = QueryParser()