import html
import re
from typing import Any, Dict, List

from email_validator import validate_email, EmailNotValidError


class InputValidator:
    """Comprehensive input validation and sanitization.

    Every ``validate_*`` method returns an ``(is_valid, error_message)`` tuple;
    ``error_message`` is ``""`` when the input is valid.
    """

    # Truly system-level usernames that cannot be registered (compared lower-cased).
    _RESERVED_USERNAMES = frozenset({
        'root', 'system', 'api', 'null', 'undefined', 'corpusdb', 'superuser',
    })

    # Rejected regardless of case (compared against password.lower()).
    _COMMON_PASSWORDS = frozenset({
        'password', '12345678', 'qwerty', 'abc123', 'password123',
        'admin123', 'letmein', 'welcome', 'monkey', '1234567890',
    })

    _ALLOWED_UPLOAD_EXTENSIONS = frozenset({'.csv', '.json', '.parquet', '.txt'})
    _ALLOWED_UPLOAD_CONTENT_TYPES = frozenset({
        'text/csv',
        'application/json',
        'application/octet-stream',
        'text/plain',
        'application/vnd.apache.parquet',
    })
    _MAX_UPLOAD_BYTES = 100 * 1024 * 1024  # 100MB

    def __init__(self):
        # Regex patterns. They are applied with ``fullmatch`` throughout: with
        # ``match``, ``$`` also matches just before a trailing newline, so a
        # value such as "alice\n" would otherwise be accepted.
        self.username_pattern = re.compile(r'^[a-zA-Z0-9_-]{3,32}$')
        self.database_name_pattern = re.compile(r'^[a-zA-Z0-9_-]{1,64}$')
        self.table_name_pattern = re.compile(r'^[a-zA-Z0-9_-]{1,64}$')
        self.column_name_pattern = re.compile(r'^[a-zA-Z0-9_]{1,64}$')
        self.workspace_id_pattern = re.compile(r'^user_[a-f0-9]{16}$')

        # Dangerous patterns, stripped by sanitize_string before escaping.
        # NOTE(review): part of this list was lost when this file was run
        # through an HTML tag stripper; the entries after r'on\w+\s*=' are a
        # reconstruction -- confirm against version control.
        self.xss_patterns = [
            r'<script[^>]*>.*?</script>',
            r'javascript:',
            r'on\w+\s*=',
            r'<iframe',
            r'<object',
            r'<embed',
        ]

        # Upper-case SQL keywords that may not be used as database/table/column
        # names. NOTE(review): the original definition was lost in the same
        # corruption; this set is a reconstruction -- confirm the intended list.
        self.sql_keywords = {
            'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER',
            'TRUNCATE', 'TABLE', 'DATABASE', 'INDEX', 'VIEW', 'FROM', 'WHERE',
            'JOIN', 'UNION', 'GROUP', 'ORDER', 'BY', 'HAVING', 'INTO', 'VALUES',
            'SET', 'AND', 'OR', 'NOT', 'NULL', 'GRANT', 'REVOKE', 'EXEC',
            'EXECUTE', 'PRIMARY', 'FOREIGN', 'REFERENCES', 'LIMIT', 'OFFSET',
            'DISTINCT', 'AS', 'ON', 'IN', 'IS', 'LIKE',
        }

    def validate_username(self, username: str) -> tuple[bool, str]:
        """Validate username format.

        Requires 3-32 characters drawn from letters, digits, underscore and
        hyphen, and rejects a small set of reserved system names.
        NOTE(review): this method's ``def`` line was lost in the file's
        corruption; name and signature are reconstructed from the body.
        """
        if not username:
            return False, "Username is required"
        if len(username) < 3:
            return False, "Username must be at least 3 characters"
        if len(username) > 32:
            return False, "Username must be at most 32 characters"
        if not self.username_pattern.fullmatch(username):
            return False, "Username can only contain letters, numbers, underscore and hyphen"
        if username.lower() in self._RESERVED_USERNAMES:
            return False, "This username is reserved. Please choose a different one"
        return True, ""

    def validate_email(self, email: str) -> tuple[bool, str]:
        """Validate email syntax; DNS deliverability is deliberately not checked."""
        if not email:
            return False, "Email is required"
        try:
            # Resolves to the module-level email_validator function, not this
            # method: class attributes are not part of a method's scope.
            validate_email(email, check_deliverability=False)
        except EmailNotValidError as e:
            return False, str(e)
        return True, ""

    def validate_password(self, password: str) -> tuple[bool, str]:
        """Validate password strength.

        Requires 8-128 characters containing at least one uppercase letter,
        one lowercase letter and one digit, and not in the common-password list.
        """
        if not password:
            return False, "Password is required"
        if len(password) < 8:
            return False, "Password must be at least 8 characters"
        if len(password) > 128:
            return False, "Password must be at most 128 characters"

        has_upper = any(c.isupper() for c in password)
        has_lower = any(c.islower() for c in password)
        has_digit = any(c.isdigit() for c in password)
        if not (has_upper and has_lower and has_digit):
            return False, "Password must contain uppercase, lowercase, and numbers"

        if password.lower() in self._COMMON_PASSWORDS:
            return False, "Password is too common"
        return True, ""

    def _validate_identifier(self, name, label, pattern, allowed_chars) -> tuple[bool, str]:
        """Shared check for database/table/column names: present, charset, not a SQL keyword."""
        if not name:
            return False, f"{label} name is required"
        if not pattern.fullmatch(name):
            return False, f"{label} name can only contain {allowed_chars}"
        if name.upper() in self.sql_keywords:
            return False, f"{label} name cannot be a SQL keyword"
        return True, ""

    def validate_database_name(self, name: str) -> tuple[bool, str]:
        """Validate a database name (1-64 letters, digits, underscore, hyphen)."""
        return self._validate_identifier(
            name, "Database", self.database_name_pattern,
            "letters, numbers, underscore and hyphen",
        )

    def validate_table_name(self, name: str) -> tuple[bool, str]:
        """Validate a table name (1-64 letters, digits, underscore, hyphen)."""
        return self._validate_identifier(
            name, "Table", self.table_name_pattern,
            "letters, numbers, underscore and hyphen",
        )

    def validate_column_name(self, name: str) -> tuple[bool, str]:
        """Validate a column name (1-64 letters, digits, underscore; no hyphen)."""
        return self._validate_identifier(
            name, "Column", self.column_name_pattern,
            "letters, numbers and underscore",
        )

    def sanitize_string(self, value: str, max_length: int = 1000) -> str:
        """Sanitize string input to prevent XSS.

        Non-string values are converted with ``str()`` and then sanitized like
        any other input. The value is truncated to ``max_length`` characters,
        stripped of ``xss_patterns`` matches, HTML-escaped (``&``, ``<``, ``>``,
        ``"`` and ``'``), and finally whitespace-stripped. Escaping can make
        the result longer than ``max_length``.
        """
        if not isinstance(value, str):
            value = str(value)

        value = value[:max_length]

        # Strip repeatedly until nothing changes, so nested payloads such as
        # "javajavascript:script:" cannot reassemble after a single pass.
        # DOTALL lets <script>...</script> match across line breaks.
        previous = None
        while value != previous:
            previous = value
            for pattern in self.xss_patterns:
                value = re.sub(pattern, '', value, flags=re.IGNORECASE | re.DOTALL)

        # Escaping '&' as well stops entity-encoded payloads
        # (e.g. "&#106;avascript:") from being decoded by the browser.
        value = html.escape(value, quote=True)
        return value.strip()

    def validate_json_data(self, data: Dict, max_depth: int = 5, current_depth: int = 0) -> tuple[bool, str]:
        """Validate a JSON-like dict structure.

        Limits:
        - nesting depth is at most ``max_depth``;
        - at most 100 keys per object;
        - keys are strings of at most 64 characters;
        - arrays hold at most 1000 items;
        - strings are at most 10000 characters.

        Array elements are checked too: objects inside arrays are validated
        at the same depth as a directly nested object, and arrays nested
        inside arrays count as an extra level.
        """
        if current_depth > max_depth:
            return False, "JSON structure too deep"
        if not isinstance(data, dict):
            return False, "Data must be a dictionary"
        if len(data) > 100:
            return False, "Too many fields (max 100)"

        for key, value in data.items():
            if not isinstance(key, str):
                return False, f"Key must be string: {key}"
            if len(key) > 64:
                return False, f"Key too long: {key}"

            if isinstance(value, dict):
                is_valid, error = self.validate_json_data(value, max_depth, current_depth + 1)
            elif isinstance(value, list):
                is_valid, error = self._validate_json_list(key, value, max_depth, current_depth + 1)
            elif isinstance(value, str) and len(value) > 10000:
                is_valid, error = False, f"String too long for key: {key}"
            else:
                continue
            if not is_valid:
                return False, error
        return True, ""

    def _validate_json_list(self, key: str, items: List, max_depth: int, depth: int) -> tuple[bool, str]:
        """Validate an array under ``key``; ``depth`` is the level its object elements occupy."""
        if len(items) > 1000:
            return False, f"Array too large for key: {key}"
        for item in items:
            if isinstance(item, dict):
                is_valid, error = self.validate_json_data(item, max_depth, depth)
            elif isinstance(item, list):
                # Nested arrays consume a level so [[[...]]] stays bounded.
                if depth + 1 > max_depth:
                    is_valid, error = False, "JSON structure too deep"
                else:
                    is_valid, error = self._validate_json_list(key, item, max_depth, depth + 1)
            elif isinstance(item, str) and len(item) > 10000:
                is_valid, error = False, f"String too long for key: {key}"
            else:
                continue
            if not is_valid:
                return False, error
        return True, ""

    def validate_file_upload(self, filename: str, content_type: str, size: int) -> tuple[bool, str]:
        """Validate a file upload's name, extension, content type and size (max 100MB)."""
        if not filename:
            return False, "Filename is required"

        # Reject path traversal and directory separators, plus NUL bytes,
        # which can truncate the name at the OS layer.
        if ('..' in filename or '/' in filename or '\\' in filename
                or '\x00' in filename):
            return False, "Invalid filename"

        _, dot, suffix = filename.rpartition('.')
        ext = '.' + suffix.lower() if dot else ''
        if ext not in self._ALLOWED_UPLOAD_EXTENSIONS:
            # Sorted so the message is stable across runs (set order is not).
            allowed = ', '.join(sorted(self._ALLOWED_UPLOAD_EXTENSIONS))
            return False, f"File type not allowed. Allowed: {allowed}"

        # Media types are case-insensitive and may carry parameters
        # (e.g. "text/csv; charset=utf-8"); compare only the bare type.
        if isinstance(content_type, str):
            media_type = content_type.split(';', 1)[0].strip().lower()
        else:
            media_type = ''
        if media_type not in self._ALLOWED_UPLOAD_CONTENT_TYPES:
            return False, f"Content type not allowed: {content_type}"

        if size < 0:
            return False, "Invalid file size"
        if size > self._MAX_UPLOAD_BYTES:
            return False, "File too large. Max size: 100MB"
        return True, ""

    def validate_workspace_id(self, workspace_id: str) -> tuple[bool, str]:
        """Validate workspace ID format: 'user_' followed by 16 lowercase hex digits."""
        if not workspace_id:
            return False, "Workspace ID is required"
        if not self.workspace_id_pattern.fullmatch(workspace_id):
            return False, "Invalid workspace ID format"
        return True, ""


input_validator = InputValidator()