""" Input Validation and Security Middleware for AegisLM Backend. Provides comprehensive input validation, sanitization, and security checks for all incoming requests to prevent common vulnerabilities. """ import re import html import logging from typing import Dict, Any, List, Optional from fastapi import Request, HTTPException, status from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse logger = logging.getLogger(__name__) class SecurityValidator: """ Security validation utilities for input sanitization and validation. """ # Common attack patterns SQL_INJECTION_PATTERNS = [ r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|UNION)\b)", r"(\b(OR|AND)\s+\d+\s*=\s*\d+)", r"(\b(OR|AND)\s+['\"][\w\s]*['\"]\s*=\s*['\"][\w\s]*['\"])", r"(--|#|\/\*|\*\/)", r"(\b(SCRIPT|JAVASCRIPT|VBSCRIPT|ONLOAD|ONERROR)\b)", ] XSS_PATTERNS = [ r"]*>.*?", r"javascript:", r"vbscript:", r"onload\s*=", r"onerror\s*=", r"onclick\s*=", r"]*>", r"]*>", r"]*>", ] PATH_TRAVERSAL_PATTERNS = [ r"\.\.[\/\\]", r"\.\.%2f", r"\.\.%5c", r"%2e%2e%2f", r"%2e%2e%5c", ] @classmethod def validate_input(cls, input_data: str, max_length: int = 10000) -> str: """ Validate and sanitize input string. Args: input_data: Input string to validate max_length: Maximum allowed length Returns: Sanitized input string Raises: HTTPException: If input contains malicious content """ if not isinstance(input_data, str): return input_data # Check length if len(input_data) > max_length: raise HTTPException( status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="Input too long" ) # Check for SQL injection patterns for pattern in cls.SQL_INJECTION_PATTERNS: if re.search(pattern, input_data, re.IGNORECASE): # Using explicit slicing to help static analyzers input_snippet = str(input_data)[0:100] logger.warning(f"Potential vulnerability detected in input: {input_snippet}...") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid input detected" ) # Check for XSS patterns for pattern in cls.XSS_PATTERNS: if re.search(pattern, input_data, re.IGNORECASE): input_snippet = str(input_data)[0:100] logger.warning(f"Potential XSS detected: {input_snippet}...") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid input detected" ) # Check for path traversal patterns for pattern in cls.PATH_TRAVERSAL_PATTERNS: if re.search(pattern, input_data, re.IGNORECASE): input_snippet = str(input_data)[0:100] logger.warning(f"Potential path traversal detected: {input_snippet}...") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid input detected" ) # HTML escape the input sanitized = html.escape(input_data) return sanitized @classmethod def validate_json_data(cls, data: Dict[str, Any], max_depth: int = 10) -> Dict[str, Any]: """ Validate JSON data recursively. Args: data: JSON data to validate max_depth: Maximum nesting depth Returns: Validated JSON data Raises: HTTPException: If data contains malicious content """ def validate_recursive(obj, depth=0): if depth > max_depth: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Input nesting too deep" ) if isinstance(obj, str): return cls.validate_input(obj) elif isinstance(obj, dict): return {k: validate_recursive(v, depth + 1) for k, v in obj.items()} elif isinstance(obj, list): return [validate_recursive(item, depth + 1) for item in obj] else: return obj return validate_recursive(data) @classmethod def validate_file_upload(cls, filename: str, content_type: str, size: int) -> bool: """ Validate uploaded file for security. Args: filename: Uploaded filename content_type: File content type size: File size in bytes Returns: True if file is valid Raises: HTTPException: If file is invalid """ # Check file size (max 10MB) max_size = 10 * 1024 * 1024 # 10MB if size > max_size: raise HTTPException( status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="File too large" ) # Check filename for suspicious patterns dangerous_extensions = ['.exe', '.bat', '.cmd', '.scr', '.pif', '.com'] filename_lower = filename.lower() for ext in dangerous_extensions: if filename_lower.endswith(ext): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="File type not allowed" ) # Check for path traversal in filename if '..' in filename or '/' in filename or '\\' in filename: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid filename" ) # Allowed content types allowed_types = [ 'image/jpeg', 'image/png', 'image/gif', 'image/webp', 'text/plain', 'text/csv', 'application/json', 'application/pdf' ] if content_type not in allowed_types: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Content type not allowed" ) return True class InputValidationMiddleware(BaseHTTPMiddleware): """ Middleware for automatic input validation and sanitization. """ def __init__(self, app, exclude_paths: Optional[List[str]] = None): super().__init__(app) self.exclude_paths = exclude_paths or ['/health', '/docs', '/openapi.json'] self.validator = SecurityValidator() async def dispatch(self, request: Request, call_next): # Skip validation for OPTIONS requests and excluded paths if request.method == "OPTIONS" or any(request.url.path.startswith(path) for path in self.exclude_paths): return await call_next(request) # Validate query parameters if request.query_params: validated_params = {} for key, value in request.query_params.items(): try: validated_params[key] = self.validator.validate_input(value) except HTTPException: raise except Exception as e: logger.error(f"Error validating query param {key}: {str(e)}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid query parameter: {key}" ) # Replace query params with validated ones request._query_params = validated_params # Validate headers for security self._validate_headers(request) # Process request response = await call_next(request) return response def _validate_headers(self, request: Request): """Validate request headers for security issues.""" # Check for suspicious headers suspicious_headers = [ 'x-forwarded-for', 'x-real-ip', 'x-originating-ip', 'x-cluster-client-ip', 'x-forwarded-host' ] for header in suspicious_headers: if header in request.headers: value = request.headers[header] # Basic validation for header values if len(value) > 500: # Unusually long header value logger.warning(f"Suspicious header detected: {header}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request headers" ) class RateLimitValidator: """ Rate limiting validation for API endpoints. """ def __init__(self): self.request_counts = {} # Simple in-memory tracking self.window_size = 60 # 1 minute window def check_rate_limit(self, client_ip: str, endpoint: str, limit: int = 100) -> bool: """ Check if client has exceeded rate limit. Args: client_ip: Client IP address endpoint: API endpoint limit: Request limit per window Returns: True if within limit, False otherwise """ import time current_time = time.time() key = f"{client_ip}:{endpoint}" # Clean old entries cutoff_time = current_time - self.window_size if key in self.request_counts: self.request_counts[key] = [ timestamp for timestamp in self.request_counts[key] if timestamp > cutoff_time ] else: self.request_counts[key] = [] # Check current count if len(self.request_counts[key]) >= limit: return False # Add current request self.request_counts[key].append(current_time) return True class ContentSecurityMiddleware(BaseHTTPMiddleware): """ Content Security Policy middleware. """ async def dispatch(self, request: Request, call_next): response = await call_next(request) # Add security headers response.headers['X-Content-Type-Options'] = 'nosniff' response.headers['X-Frame-Options'] = 'DENY' response.headers['X-XSS-Protection'] = '1; mode=block' response.headers['Referrer-Policy'] = 'strict-origin-when-cross-origin' response.headers['Content-Security-Policy'] = ( "default-src 'self'; " "script-src 'self' 'unsafe-inline'; " "style-src 'self' 'unsafe-inline'; " "img-src 'self' data: https:; " "connect-src 'self'" ) return response def validate_evaluation_data(evaluation_data: Dict[str, Any]) -> Dict[str, Any]: """ Validate evaluation-specific data. Args: evaluation_data: Evaluation configuration data Returns: Validated evaluation data Raises: HTTPException: If data is invalid """ validator = SecurityValidator() # Validate model configuration if 'model_config_data' in evaluation_data: model_config = evaluation_data['model_config_data'] # Validate model name if 'model_name' in model_config: model_config['model_name'] = validator.validate_input( model_config['model_name'], max_length=100 ) # Validate API endpoint if 'api_endpoint' in model_config: model_config['api_endpoint'] = validator.validate_input( model_config['api_endpoint'], max_length=500 ) # Validate API key (don't log or escape this) if 'api_key' in model_config: if len(model_config['api_key']) > 500: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="API key too long" ) # Validate pipeline configuration if 'pipeline_config' in evaluation_data: pipeline_config = evaluation_data['pipeline_config'] # Validate numeric parameters for field in ['max_iterations', 'num_prompts']: if field in pipeline_config: try: value = int(pipeline_config[field]) if value < 1 or value > 1000: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"{field} must be between 1 and 1000" ) pipeline_config[field] = value except (ValueError, TypeError): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid {field} value" ) return evaluation_data def sanitize_error_message(message: str) -> str: """ Sanitize error messages to prevent information leakage. Args: message: Original error message Returns: Sanitized error message """ # Remove sensitive information from error messages sensitive_patterns = [ r'password\s*[:=]\s*\S+', r'token\s*[:=]\s*\S+', r'key\s*[:=]\s*\S+', r'secret\s*[:=]\s*\S+', r'api_key\s*[:=]\s*\S+', ] sanitized = message for pattern in sensitive_patterns: sanitized = re.sub(pattern, '[REDACTED]', sanitized, flags=re.IGNORECASE) return sanitized