# corpusdb/app/input_validator.py
# mrsavage1's picture
# Upload 52 files
# 723f9ab verified
import re
from typing import Any, Dict, List
from email_validator import validate_email, EmailNotValidError
class InputValidator:
    """Comprehensive input validation and sanitization.

    Every ``validate_*`` method returns an ``(is_valid, error_message)``
    tuple; ``error_message`` is the empty string when validation succeeds.
    """

    # Usernames rejected regardless of letter case.
    _RESERVED_USERNAMES = frozenset({
        'root', 'system', 'api', 'null', 'undefined', 'corpusdb', 'superuser',
    })

    # Passwords rejected (case-insensitively) even if they pass the
    # complexity rules.
    _COMMON_PASSWORDS = frozenset({
        'password', '12345678', 'qwerty', 'abc123', 'password123',
        'admin123', 'letmein', 'welcome', 'monkey', '1234567890',
    })

    # Upload restrictions.
    _ALLOWED_EXTENSIONS = frozenset({'.csv', '.json', '.parquet', '.txt'})
    _ALLOWED_CONTENT_TYPES = frozenset({
        'text/csv', 'application/json', 'application/octet-stream',
        'text/plain', 'application/vnd.apache.parquet',
    })
    _MAX_UPLOAD_BYTES = 100 * 1024 * 1024  # 100MB

    # "user_" followed by exactly 16 lowercase hex characters.
    _WORKSPACE_ID_PATTERN = re.compile(r'^user_[a-f0-9]{16}\Z')

    def __init__(self):
        # Identifier patterns. They are anchored with \Z rather than $
        # because $ also matches just before a trailing newline, which
        # would let values such as "name\n" slip through.
        self.username_pattern = re.compile(r'^[a-zA-Z0-9_-]{3,32}\Z')
        self.database_name_pattern = re.compile(r'^[a-zA-Z0-9_-]{1,64}\Z')
        self.table_name_pattern = re.compile(r'^[a-zA-Z0-9_-]{1,64}\Z')
        self.column_name_pattern = re.compile(r'^[a-zA-Z0-9_]{1,64}\Z')

        # Patterns stripped (case-insensitively) by sanitize_string().
        self.xss_patterns = [
            r'<script[^>]*>.*?</script>',
            r'javascript:',
            r'on\w+\s*=',
            r'<iframe',
            r'<object',
            r'<embed',
        ]

        # SQL keywords that may not be used as database/table/column names.
        self.sql_keywords = {
            'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE',
            'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE', 'UNION', 'DECLARE'
        }

    def validate_username(self, username: str) -> tuple[bool, str]:
        """Validate username format.

        A username must be 3-32 characters long, consist only of ASCII
        letters, digits, underscore and hyphen, and must not be one of
        the reserved names (compared case-insensitively).
        """
        if not username:
            return False, "Username is required"
        if len(username) < 3:
            return False, "Username must be at least 3 characters"
        if len(username) > 32:
            return False, "Username must be at most 32 characters"
        if not self.username_pattern.fullmatch(username):
            return False, "Username can only contain letters, numbers, underscore and hyphen"
        if username.lower() in self._RESERVED_USERNAMES:
            return False, "This username is reserved. Please choose a different one"
        return True, ""

    def validate_email(self, email: str) -> tuple[bool, str]:
        """Validate email format without checking deliverability.

        On failure the error message is the text of the
        ``EmailNotValidError`` raised by the email-validator library.
        """
        if not email:
            return False, "Email is required"
        try:
            # Inside a method body this name resolves to the module-level
            # function imported from email_validator, not to this method.
            validate_email(email, check_deliverability=False)
        except EmailNotValidError as e:
            return False, str(e)
        return True, ""

    def validate_password(self, password: str) -> tuple[bool, str]:
        """Validate password strength.

        A password must be 8-128 characters long, contain at least one
        uppercase letter, one lowercase letter and one digit, and must not
        equal (case-insensitively) one of the known common passwords.
        """
        if not password:
            return False, "Password is required"
        if len(password) < 8:
            return False, "Password must be at least 8 characters"
        if len(password) > 128:
            return False, "Password must be at most 128 characters"

        has_upper = any(c.isupper() for c in password)
        has_lower = any(c.islower() for c in password)
        has_digit = any(c.isdigit() for c in password)
        if not (has_upper and has_lower and has_digit):
            return False, "Password must contain uppercase, lowercase, and numbers"

        if password.lower() in self._COMMON_PASSWORDS:
            return False, "Password is too common"
        return True, ""

    def _validate_identifier(self, name: str, pattern, label: str, charset: str) -> tuple[bool, str]:
        """Shared check for database, table and column names.

        ``label`` is the human-readable prefix of the error messages and
        ``charset`` describes the characters ``pattern`` allows.
        """
        if not name:
            return False, f"{label} is required"
        if not pattern.fullmatch(name):
            return False, f"{label} can only contain {charset}"
        if name.upper() in self.sql_keywords:
            return False, f"{label} cannot be a SQL keyword"
        return True, ""

    def validate_database_name(self, name: str) -> tuple[bool, str]:
        """Validate database name (1-64 letters, digits, ``_`` or ``-``; no SQL keywords)."""
        return self._validate_identifier(
            name,
            self.database_name_pattern,
            "Database name",
            "letters, numbers, underscore and hyphen",
        )

    def validate_table_name(self, name: str) -> tuple[bool, str]:
        """Validate table name (1-64 letters, digits, ``_`` or ``-``; no SQL keywords)."""
        return self._validate_identifier(
            name,
            self.table_name_pattern,
            "Table name",
            "letters, numbers, underscore and hyphen",
        )

    def validate_column_name(self, name: str) -> tuple[bool, str]:
        """Validate column name (1-64 letters, digits or ``_``; no SQL keywords)."""
        return self._validate_identifier(
            name,
            self.column_name_pattern,
            "Column name",
            "letters, numbers and underscore",
        )

    def sanitize_string(self, value: str, max_length: int = 1000) -> str:
        """Sanitize string input to prevent XSS.

        Non-string input is converted with ``str()`` and then sanitized
        like any other string. The value is truncated to ``max_length``
        characters, the patterns in ``self.xss_patterns`` are removed,
        ``<``, ``>``, ``"`` and ``'`` are replaced by HTML entities, and
        surrounding whitespace is stripped. Ampersands are left untouched.
        """
        if not isinstance(value, str):
            value = str(value)

        value = value[:max_length]

        # Removing a match can splice the surrounding text into a new match
        # (e.g. "javajavascript:script:" -> "javascript:"), so repeat until
        # nothing changes. Each change shortens the string, so this ends.
        previous = None
        while previous != value:
            previous = value
            for pattern in self.xss_patterns:
                value = re.sub(pattern, '', value, flags=re.IGNORECASE)

        value = value.replace('<', '&lt;').replace('>', '&gt;')
        value = value.replace('"', '&quot;').replace("'", '&#x27;')
        return value.strip()

    def validate_json_data(self, data: Dict, max_depth: int = 5, current_depth: int = 0) -> tuple[bool, str]:
        """Validate JSON data structure.

        Limits enforced: nesting no deeper than ``max_depth``, at most 100
        fields per dictionary, string keys of at most 64 characters, lists
        of at most 1000 items and strings of at most 10000 characters.
        List items are validated too, so dictionaries, lists and strings
        nested inside arrays are held to the same limits.

        ``current_depth`` is the depth of ``data`` itself; external callers
        normally leave it at 0.
        """
        if current_depth > max_depth:
            return False, "JSON structure too deep"
        if not isinstance(data, dict):
            return False, "Data must be a dictionary"
        if len(data) > 100:
            return False, "Too many fields (max 100)"

        for key, value in data.items():
            if not isinstance(key, str):
                return False, f"Key must be string: {key}"
            if len(key) > 64:
                return False, f"Key too long: {key}"
            is_valid, error = self._validate_json_value(key, value, max_depth, current_depth)
            if not is_valid:
                return False, error
        return True, ""

    def _validate_json_value(self, key: str, value: Any, max_depth: int, depth: int) -> tuple[bool, str]:
        """Validate one value found under ``key``.

        ``depth`` is the depth of the container that directly holds
        ``value``. Dictionaries are validated one level deeper; each list
        level also adds one to the depth seen by its items, which bounds
        arbitrarily nested arrays.
        """
        if isinstance(value, dict):
            return self.validate_json_data(value, max_depth, depth + 1)
        if isinstance(value, list):
            # Never true for a list sitting directly in a validated dict;
            # only lists nested inside other lists can exceed the limit.
            if depth > max_depth:
                return False, "JSON structure too deep"
            if len(value) > 1000:
                return False, f"Array too large for key: {key}"
            for item in value:
                is_valid, error = self._validate_json_value(key, item, max_depth, depth + 1)
                if not is_valid:
                    return False, error
            return True, ""
        if isinstance(value, str) and len(value) > 10000:
            return False, f"String too long for key: {key}"
        return True, ""

    def validate_file_upload(self, filename: str, content_type: str, size: int) -> tuple[bool, str]:
        """Validate file upload.

        Rejects empty filenames, filenames containing ``..``, ``/`` or a
        backslash, extensions and content types outside the allowed sets,
        negative sizes and sizes above 100MB.
        """
        if not filename:
            return False, "Filename is required"

        # Path traversal / directory components are never allowed.
        if '..' in filename or '/' in filename or '\\' in filename:
            return False, "Invalid filename"

        ext = '.' + filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
        if ext not in self._ALLOWED_EXTENSIONS:
            # Sorted so the message is identical from run to run.
            allowed = ', '.join(sorted(self._ALLOWED_EXTENSIONS))
            return False, f"File type not allowed. Allowed: {allowed}"

        if content_type not in self._ALLOWED_CONTENT_TYPES:
            return False, f"Content type not allowed: {content_type}"

        if size < 0:
            return False, "Invalid file size"
        if size > self._MAX_UPLOAD_BYTES:
            return False, "File too large. Max size: 100MB"
        return True, ""

    def validate_workspace_id(self, workspace_id: str) -> tuple[bool, str]:
        """Validate workspace ID format: ``user_`` plus 16 lowercase hex characters."""
        if not workspace_id:
            return False, "Workspace ID is required"
        if not self._WORKSPACE_ID_PATTERN.fullmatch(workspace_id):
            return False, "Invalid workspace ID format"
        return True, ""
# Module-level InputValidator instance, created once when this module is imported.
input_validator = InputValidator()