import html
import re
from typing import Any, Dict, List

from email_validator import validate_email, EmailNotValidError
|
|
class InputValidator:
    """Comprehensive input validation and sanitization.

    Every ``validate_*`` method returns an ``(is_valid, error_message)``
    tuple; ``error_message`` is the empty string when ``is_valid`` is True.
    """

    # Usernames rejected regardless of format (compared case-insensitively).
    _RESERVED_USERNAMES = frozenset({
        'root', 'system', 'api', 'null', 'undefined', 'corpusdb', 'superuser',
    })

    # Passwords rejected outright (compared case-insensitively).
    _COMMON_PASSWORDS = frozenset({
        'password', '12345678', 'qwerty', 'abc123', 'password123',
        'admin123', 'letmein', 'welcome', 'monkey', '1234567890',
    })

    _ALLOWED_UPLOAD_EXTENSIONS = frozenset({'.csv', '.json', '.parquet', '.txt'})
    _ALLOWED_UPLOAD_TYPES = frozenset({
        'text/csv', 'application/json', 'application/octet-stream',
        'text/plain', 'application/vnd.apache.parquet',
    })
    _MAX_UPLOAD_BYTES = 100 * 1024 * 1024

    _WORKSPACE_ID_PATTERN = re.compile(r'^user_[a-f0-9]{16}$')

    # Translation table for the four characters sanitize_string() escapes.
    # The replacement text is taken from html.escape() so the entities are
    # never spelled out by hand; the ampersand is deliberately left alone,
    # matching the four replacements this class has always performed.
    _HTML_ESCAPES = str.maketrans(
        {char: html.escape(char, quote=True) for char in "<>\"'"}
    )

    def __init__(self):
        """Set up the identifier patterns, XSS patterns and SQL keyword set."""
        self.username_pattern = re.compile(r'^[a-zA-Z0-9_-]{3,32}$')
        self.database_name_pattern = re.compile(r'^[a-zA-Z0-9_-]{1,64}$')
        self.table_name_pattern = re.compile(r'^[a-zA-Z0-9_-]{1,64}$')
        self.column_name_pattern = re.compile(r'^[a-zA-Z0-9_]{1,64}$')

        # Regex fragments removed from strings by sanitize_string().
        self.xss_patterns = [
            r'<script[^>]*>.*?</script>',
            r'javascript:',
            r'on\w+\s*=',
            r'<iframe',
            r'<object',
            r'<embed',
        ]

        # Identifiers equal to one of these (case-insensitively) are rejected.
        self.sql_keywords = {
            'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE',
            'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE', 'UNION', 'DECLARE'
        }

    def validate_username(self, username: str) -> tuple[bool, str]:
        """Validate username length, character set and reserved names."""
        if not username:
            return False, "Username is required"

        if len(username) < 3:
            return False, "Username must be at least 3 characters"

        if len(username) > 32:
            return False, "Username must be at most 32 characters"

        # fullmatch, not match: with match() the pattern's '$' also matches
        # just before a trailing newline, so "alice\n" would be accepted.
        if not self.username_pattern.fullmatch(username):
            return False, "Username can only contain letters, numbers, underscore and hyphen"

        if username.lower() in self._RESERVED_USERNAMES:
            return False, "This username is reserved. Please choose a different one"

        return True, ""

    def validate_email(self, email: str) -> tuple[bool, str]:
        """Validate email syntax via ``email_validator`` (no deliverability check)."""
        if not email:
            return False, "Email is required"

        try:
            # Inside a method body this name resolves to the module-level
            # email_validator function, not to this method.
            validate_email(email, check_deliverability=False)
        except EmailNotValidError as exc:
            return False, str(exc)
        return True, ""

    def validate_password(self, password: str) -> tuple[bool, str]:
        """Validate password length, character classes and commonness."""
        if not password:
            return False, "Password is required"

        if len(password) < 8:
            return False, "Password must be at least 8 characters"

        if len(password) > 128:
            return False, "Password must be at most 128 characters"

        has_upper = any(c.isupper() for c in password)
        has_lower = any(c.islower() for c in password)
        has_digit = any(c.isdigit() for c in password)

        if not (has_upper and has_lower and has_digit):
            return False, "Password must contain uppercase, lowercase, and numbers"

        if password.lower() in self._COMMON_PASSWORDS:
            return False, "Password is too common"

        return True, ""

    def _validate_identifier(self, name, pattern, label, allowed_chars) -> tuple[bool, str]:
        """Shared check for database/table/column names.

        ``label`` and ``allowed_chars`` only shape the error messages.
        """
        if not name:
            return False, f"{label} name is required"

        # fullmatch so that a trailing newline cannot slip past the '$' anchor.
        if not pattern.fullmatch(name):
            return False, f"{label} name can only contain {allowed_chars}"

        if name.upper() in self.sql_keywords:
            return False, f"{label} name cannot be a SQL keyword"

        return True, ""

    def validate_database_name(self, name: str) -> tuple[bool, str]:
        """Validate database name"""
        return self._validate_identifier(
            name, self.database_name_pattern, "Database",
            "letters, numbers, underscore and hyphen",
        )

    def validate_table_name(self, name: str) -> tuple[bool, str]:
        """Validate table name"""
        return self._validate_identifier(
            name, self.table_name_pattern, "Table",
            "letters, numbers, underscore and hyphen",
        )

    def validate_column_name(self, name: str) -> tuple[bool, str]:
        """Validate column name (hyphens are not allowed here)."""
        return self._validate_identifier(
            name, self.column_name_pattern, "Column",
            "letters, numbers and underscore",
        )

    def sanitize_string(self, value: str, max_length: int = 1000) -> str:
        """Sanitize string input to prevent XSS.

        Steps, in order: coerce to ``str``, truncate to ``max_length``,
        remove everything matching ``self.xss_patterns``, replace ``<``,
        ``>``, double quote and single quote with their HTML entities, and
        strip surrounding whitespace.

        Non-string input is converted with ``str()`` and then sanitized like
        any other string, so an object's text form cannot bypass the filter.
        """
        if not isinstance(value, str):
            value = str(value)

        value = value[:max_length]

        # Repeat until nothing changes: a single pass would let
        # "javajavascript:script:" collapse into a fresh "javascript:".
        # Each substitution only removes text, so this terminates.
        previous = None
        while previous != value:
            previous = value
            for pattern in self.xss_patterns:
                # DOTALL lets the <script> pattern span line breaks.
                value = re.sub(pattern, '', value, flags=re.IGNORECASE | re.DOTALL)

        return value.translate(self._HTML_ESCAPES).strip()

    def validate_json_data(self, data: Dict, max_depth: int = 5, current_depth: int = 0) -> tuple[bool, str]:
        """Validate JSON data structure.

        Limits enforced: nesting depth (``max_depth``), at most 100 fields
        per object, string keys of at most 64 characters, at most 1000 items
        per array and at most 10000 characters per string value. Arrays are
        inspected recursively, so objects, arrays and strings nested inside
        an array are subject to the same limits.
        """
        if current_depth > max_depth:
            return False, "JSON structure too deep"

        if not isinstance(data, dict):
            return False, "Data must be a dictionary"

        if len(data) > 100:
            return False, "Too many fields (max 100)"

        for key, value in data.items():
            if not isinstance(key, str):
                return False, f"Key must be string: {key}"

            if len(key) > 64:
                return False, f"Key too long: {key}"

            # A list sits at its owning object's depth, so a flat list never
            # trips the depth limit; nested objects are one level deeper.
            child_depth = current_depth if isinstance(value, list) else current_depth + 1
            is_valid, error = self._validate_json_value(key, value, max_depth, child_depth)
            if not is_valid:
                return False, error

        return True, ""

    def _validate_json_value(self, key, value, max_depth, depth) -> tuple[bool, str]:
        """Validate one value; ``depth`` is the level it has if it is a container.

        ``key`` is the owning object's key and is only used in error messages.
        """
        if isinstance(value, dict):
            return self.validate_json_data(value, max_depth, depth)

        if isinstance(value, list):
            if depth > max_depth:
                return False, "JSON structure too deep"
            if len(value) > 1000:
                return False, f"Array too large for key: {key}"
            for item in value:
                is_valid, error = self._validate_json_value(key, item, max_depth, depth + 1)
                if not is_valid:
                    return False, error
            return True, ""

        if isinstance(value, str) and len(value) > 10000:
            return False, f"String too long for key: {key}"

        return True, ""

    def validate_file_upload(self, filename: str, content_type: str, size: int) -> tuple[bool, str]:
        """Validate an upload's filename, extension, content type and size."""
        if not filename:
            return False, "Filename is required"

        # Reject anything that could address a different directory.
        if '..' in filename or '/' in filename or '\\' in filename:
            return False, "Invalid filename"

        _, dot, suffix = filename.rpartition('.')
        ext = '.' + suffix.lower() if dot else ''
        if ext not in self._ALLOWED_UPLOAD_EXTENSIONS:
            # Sorted so the message is the same on every run.
            allowed = ', '.join(sorted(self._ALLOWED_UPLOAD_EXTENSIONS))
            return False, f"File type not allowed. Allowed: {allowed}"

        if content_type not in self._ALLOWED_UPLOAD_TYPES:
            return False, f"Content type not allowed: {content_type}"

        if size > self._MAX_UPLOAD_BYTES:
            return False, "File too large. Max size: 100MB"

        return True, ""

    def validate_workspace_id(self, workspace_id: str) -> tuple[bool, str]:
        """Validate workspace ID format: ``user_`` plus 16 lowercase hex digits."""
        if not workspace_id:
            return False, "Workspace ID is required"

        # fullmatch so that a trailing newline cannot slip past the '$' anchor.
        if not self._WORKSPACE_ID_PATTERN.fullmatch(workspace_id):
            return False, "Invalid workspace ID format"

        return True, ""
|
|
# Module-level InputValidator instance, created once at import time.
input_validator = InputValidator()
|
|