corpusdb / app /query_parser.py
mrsavage1's picture
Upload 62 files
5f30028 verified
import re
from typing import Dict, List, Tuple
class QueryParser:
    """Parse and validate SQL queries to prevent injection attacks.

    Validation works on the upper-cased query text with keyword and regex
    checks rather than a real SQL parse. Words inside string literals and
    comments are matched too, so the checks err towards rejecting.
    """

    # Allowed SQL keywords for read operations (not consulted by validate_query)
    SAFE_READ_KEYWORDS = {
        'SELECT', 'FROM', 'WHERE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER',
        'ON', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'BETWEEN', 'IS', 'NULL',
        'ORDER', 'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'AS', 'DISTINCT',
        'COUNT', 'SUM', 'AVG', 'MIN', 'MAX', 'UNION', 'INTERSECT', 'EXCEPT'
    }

    # Write operations that need allow_write=True (matched as whole words)
    WRITE_KEYWORDS = {
        'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'ALTER', 'DROP'
    }

    # CREATE operations: any CREATE needs allow_write=True, except
    # CREATE DATABASE/SCHEMA, which is gated by allow_schema_ops instead
    CREATE_KEYWORDS = {'CREATE'}

    # Always dangerous - never allow (matched as standalone words)
    ALWAYS_DANGEROUS = {
        'GRANT', 'REVOKE', 'EXEC', 'EXECUTE', 'CALL', 'DECLARE', 'CURSOR',
        'PRAGMA'
    }

    # Dangerous commands (matched as standalone words, so table names such as
    # "attachments" are not affected)
    DANGEROUS_COMMANDS = {
        'ATTACH', 'DETACH'
    }

    # Keywords that need special validation
    DROP_DATABASE_PATTERNS = [
        r'DROP\s+DATABASE',
        r'DROP\s+SCHEMA'
    ]
    CREATE_DATABASE_PATTERNS = [
        r'CREATE\s+DATABASE',
        r'CREATE\s+SCHEMA'
    ]

    # Any standalone CREATE that is not CREATE DATABASE/SCHEMA (VIEW, INDEX,
    # SEQUENCE, ...). Applied to upper-cased text.
    _CREATE_OBJECT_RE = re.compile(r'\bCREATE\b(?!\s+(?:DATABASE|SCHEMA)\b)')

    # Heuristic injection signatures, applied to upper-cased text
    # (comments are deliberately allowed for phpMyAdmin imports).
    _INJECTION_PATTERNS = (
        r';\s*DROP\s+DATABASE',  # Only block DROP DATABASE after semicolon
        r'UNION\s+SELECT.*FROM\s+(?![\w\.]+\s)',  # UNION injection
        r'OR\s+1\s*=\s*1',
        r'OR\s+\'1\'\s*=\s*\'1\'',
    )

    # Clauses after which extract_tables expects a table name.
    _TABLE_CONTEXTS = (
        r'\b(?:FROM|JOIN)\s+',
        r'\bCREATE\s+(?:OR\s+REPLACE\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?',
        r'\bINSERT\s+INTO\s+',
        r'\bUPDATE\s+',
        r'\bDELETE\s+FROM\s+',
    )
    # Optional identifier quote: ANSI "x", MySQL `x`, or (leniently) 'x'.
    _IDENT_QUOTE = r'["\'`]?'
    # db.table with each part optionally quoted; captures (db, table).
    _QUALIFIED_TABLE = (_IDENT_QUOTE + r'(\w+)' + _IDENT_QUOTE + r'\.'
                        + _IDENT_QUOTE + r'(\w+)' + _IDENT_QUOTE)
    # Bare name ending at whitespace, ';', ',', '(', ')' or end of text.
    # The terminator is a lookahead so it is not consumed and the next
    # FROM/JOIN right after it can still match.
    _UNQUALIFIED_TABLE = _IDENT_QUOTE + r'(\w+)' + _IDENT_QUOTE + r'(?=[\s;,()]|$)'

    # Words that can follow the contexts above but are never table names.
    _NON_TABLE_WORDS = frozenset({
        'SELECT', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'LIMIT',
        'ORDER', 'GROUP', 'HAVING', 'SET', 'VALUES', 'INTO', 'AS',
        'ON', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'CROSS', 'NATURAL',
        'BETWEEN', 'IS', 'NULL', 'TRUE', 'FALSE', 'DISTINCT', 'ALL',
        'EXISTS', 'CASE', 'WHEN', 'THEN', 'ELSE', 'END', 'UNION',
        'INTERSECT', 'EXCEPT', 'WITH', 'RECURSIVE', 'IF',
    })

    # Session/transaction statements that split_sql_statements drops.
    _SESSION_PREFIXES = ('SET ', 'START TRANSACTION', 'COMMIT')

    def __init__(self, max_query_length: int = 10000):
        """Create a parser.

        Args:
            max_query_length: Longest query, in characters, that
                validate_query accepts.
        """
        self.max_query_length = max_query_length

    @staticmethod
    def _has_word(text: str, word: str) -> bool:
        """Return True if *word* occurs in *text* delimited by non-word characters."""
        return re.search(r'\b' + re.escape(word) + r'\b', text) is not None

    @classmethod
    def _find_keyword(cls, text: str, keywords) -> str:
        """Return the alphabetically first of *keywords* found as a whole word in *text*, else ''."""
        # Sorting makes the reported keyword deterministic when several match.
        for keyword in sorted(keywords):
            if cls._has_word(text, keyword):
                return keyword
        return ''

    def sanitize_identifier(self, identifier: str) -> str:
        """Sanitize a table/column identifier.

        Removes every character that is not a word character or '.'.

        Raises:
            ValueError: if nothing is left, or the result contains '..'
                (path traversal).
        """
        sanitized = re.sub(r'[^\w\.]', '', identifier)
        # '/' cannot survive the substitution above, so only an empty result
        # and '..' remain to be rejected.
        if not sanitized or '..' in sanitized:
            raise ValueError(f"Invalid identifier: {identifier}")
        return sanitized

    def validate_query(self, sql: str, allow_write: bool = False, allow_schema_ops: bool = False) -> Tuple[bool, str]:
        """
        Validate SQL query for safety.

        Keywords are matched as whole words, so identifiers such as
        ``created_at`` or ``deleted_items`` do not trip the write checks.

        Args:
            sql: SQL query to validate
            allow_write: Allow INSERT, UPDATE, DELETE, TRUNCATE, ALTER, DROP
                and CREATE (TABLE, VIEW, INDEX, ...)
            allow_schema_ops: Allow CREATE/DROP DATABASE/SCHEMA (for SQL imports)

        Returns:
            (is_valid, error_message); error_message is "" when valid.
        """
        if not sql or not sql.strip():
            return False, "Empty query"
        if len(sql) > self.max_query_length:
            return False, f"Query too long (max {self.max_query_length} chars)"

        # Normalize query
        sql_upper = sql.upper()

        # DROP/CREATE DATABASE/SCHEMA (blocked unless allow_schema_ops)
        if not allow_schema_ops:
            if any(re.search(p, sql_upper) for p in self.DROP_DATABASE_PATTERNS):
                return False, "DROP DATABASE/SCHEMA not allowed"
            if any(re.search(p, sql_upper) for p in self.CREATE_DATABASE_PATTERNS):
                return False, "CREATE DATABASE/SCHEMA not allowed via SQL. Use the 'Create Database' button or API endpoint instead."

        # CREATE TABLE (allowed with allow_write=True)
        if not allow_write and self._has_word(sql_upper, 'CREATE') and self._has_word(sql_upper, 'TABLE'):
            return False, "CREATE TABLE requires allow_write=True"

        # NOTE(review): only the keywords listed on this class are blocked;
        # engine-specific statements (e.g. COPY, INSTALL, LOAD) are not —
        # confirm against the target engine.
        dangerous = self._find_keyword(sql_upper, self.ALWAYS_DANGEROUS | self.DANGEROUS_COMMANDS)
        if dangerous:
            return False, f"Dangerous keyword not allowed: {dangerous}"

        if not allow_write:
            write_op = self._find_keyword(sql_upper, self.WRITE_KEYWORDS)
            if write_op:
                return False, f"Write operation not allowed: {write_op} (use allow_write=True)"
            # CREATE VIEW/INDEX/... writes too; previously only CREATE TABLE was gated.
            if self._CREATE_OBJECT_RE.search(sql_upper):
                return False, "Write operation not allowed: CREATE (use allow_write=True)"

        if any(re.search(p, sql_upper) for p in self._INJECTION_PATTERNS):
            return False, "Potential SQL injection detected"

        return True, ""

    def extract_tables(self, sql: str) -> List[str]:
        """Extract table names from SQL query.

        Looks for names after FROM/JOIN, CREATE TABLE, INSERT INTO, UPDATE
        and DELETE FROM. Qualified names (each part optionally quoted with
        ", ` or ') are returned as ``db.table``. This is regex based: in a
        comma-separated FROM list only the first table is found.

        Returns:
            Sorted, de-duplicated, sanitized table names.
        """
        tables = set()
        for context in self._TABLE_CONTEXTS:
            for db, table in re.findall(context + self._QUALIFIED_TABLE, sql, re.IGNORECASE):
                tables.add(f'{db}.{table}')
            for name in re.findall(context + self._UNQUALIFIED_TABLE, sql, re.IGNORECASE):
                # Keywords can land here, e.g. "IF" from
                # "CREATE TABLE IF NOT EXISTS db.t (".
                if name.upper() not in self._NON_TABLE_WORDS:
                    tables.add(name)
        return sorted(self.sanitize_identifier(name) for name in tables)

    def build_safe_query(self, database: str, table: str, filters: Dict = None,
                         limit: int = None, offset: int = None) -> str:
        """Build a safe parameterized SELECT over "database"."table".

        Identifiers go through sanitize_identifier, so they cannot contain a
        double quote. Filter values are never interpolated: each filter key
        becomes ``"key" = ?`` in ``filters`` iteration order, and the caller
        must bind the values in that same order (e.g. list(filters.values())).
        LIMIT/OFFSET are appended only when truthy, coerced with int().

        Raises:
            ValueError: if an identifier is invalid (see sanitize_identifier).
        """
        db = self.sanitize_identifier(database)
        tbl = self.sanitize_identifier(table)
        query = f'SELECT * FROM "{db}"."{tbl}"'
        if filters:
            conditions = [f'"{self.sanitize_identifier(key)}" = ?' for key in filters]
            query += ' WHERE ' + ' AND '.join(conditions)
        if limit:
            query += f' LIMIT {int(limit)}'
        if offset:
            query += f' OFFSET {int(offset)}'
        return query

    def is_read_only(self, sql: str) -> bool:
        """Return True if the query begins (ignoring surrounding whitespace and
        case) with SELECT or WITH. Only the prefix is inspected."""
        return sql.upper().strip().startswith(('SELECT', 'WITH'))

    @classmethod
    def _finish_statement(cls, parts: List[str]) -> str:
        """Join a statement's line fragments; return '' if it should be dropped."""
        stmt = '\n'.join(parts).strip()
        if not stmt.rstrip(';').strip():
            return ''  # empty statement, e.g. the second ';' of ';;'
        if stmt.upper().startswith(cls._SESSION_PREFIXES):
            return ''
        if not stmt.endswith(';'):
            stmt += ';'
        return stmt

    def split_sql_statements(self, sql: str) -> List[str]:
        """Split SQL dump into individual statements.

        Handles phpMyAdmin-style SQL dumps with comments and multiple
        statements:

        * ``--`` comments outside quotes, and blank lines outside string
          literals, are dropped;
        * ``;`` ends a statement only outside '...', "..." and `...` quotes
          and outside ``/* ... */`` comments, and every statement on a line
          is kept;
        * inside '...' and "..." a backslash escapes the next character
          (MySQL rules), so ``'a\\\\'`` is a complete string;
        * ``SET ...``, ``START TRANSACTION`` and ``COMMIT`` statements are
          skipped.

        Returns:
            Statements in source order, each ending with ``;`` (one is
            appended to a trailing unterminated statement). The lines of a
            multi-line statement are joined with ``\\n``.
        """
        statements: List[str] = []
        parts: List[str] = []  # line fragments of the statement being built
        quote = None           # delimiter of the open quoted token, if any
        in_block = False       # inside a /* ... */ comment

        for line in sql.split('\n'):
            # Blank lines only matter inside a string literal.
            if quote is None and not in_block and not line.strip():
                continue
            line_starts_in_quote = quote is not None
            start = 0         # first character not yet assigned to a statement
            end = len(line)   # shortened when a -- comment is found
            i = 0
            while i < end:
                ch = line[i]
                if quote is not None:
                    if ch == '\\' and quote != '`':
                        i += 2  # skip the escaped character
                        continue
                    if ch == quote:
                        quote = None
                elif in_block:
                    if line.startswith('*/', i):
                        in_block = False
                        i += 2
                        continue
                elif ch in ('"', "'", '`'):
                    quote = ch
                elif line.startswith('/*', i):
                    # Kept in the text: MySQL /*! ... */ comments are executable.
                    in_block = True
                    i += 2
                    continue
                elif line.startswith('--', i):
                    end = i  # drop the comment through end of line
                    break
                elif ch == ';':
                    parts.append(line[start:i + 1])
                    stmt = self._finish_statement(parts)
                    if stmt:
                        statements.append(stmt)
                    parts = []
                    start = i + 1
                i += 1
            fragment = line[start:end]
            # Whitespace-only fragments are noise unless they are part of a
            # multi-line string literal.
            if fragment.strip() or line_starts_in_quote:
                parts.append(fragment)

        # Add any remaining (unterminated) statement
        stmt = self._finish_statement(parts)
        if stmt:
            statements.append(stmt)
        return statements
# Shared module-level instance. QueryParser methods never mutate the instance
# after __init__, so importers can reuse this single object.
query_parser: QueryParser = QueryParser()