| """ |
Input Validation and Security Middleware for AegisLM Backend.

Provides deny-list input validation and sanitization for incoming requests
(query parameters, selected forwarding headers, JSON payloads, file uploads),
plus security headers on responses, to mitigate common web vulnerabilities.
| """ |
|
|
| import re |
| import html |
| import logging |
| from typing import Dict, Any, List, Optional |
| from fastapi import Request, HTTPException, status |
| from starlette.middleware.base import BaseHTTPMiddleware |
| from starlette.responses import JSONResponse |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class SecurityValidator:
    """
    Security validation utilities for input sanitization and validation.

    String checks are regex deny-lists evaluated case-insensitively. Input
    that matches any pattern is rejected with an ``HTTPException``; input that
    passes is HTML-escaped before being returned.
    """

    # NOTE(review): deliberately broad -- plain words such as "select",
    # "update" or "delete", and any "#" or "--", are rejected.
    SQL_INJECTION_PATTERNS = [
        r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|UNION)\b)",
        r"(\b(OR|AND)\s+\d+\s*=\s*\d+)",
        r"(\b(OR|AND)\s+['\"][\w\s]*['\"]\s*=\s*['\"][\w\s]*['\"])",
        r"(--|#|\/\*|\*\/)",
        r"(\b(SCRIPT|JAVASCRIPT|VBSCRIPT|ONLOAD|ONERROR)\b)",
    ]

    XSS_PATTERNS = [
        r"<script[^>]*>.*?</script>",
        r"javascript:",
        r"vbscript:",
        r"onload\s*=",
        r"onerror\s*=",
        r"onclick\s*=",
        r"<iframe[^>]*>",
        r"<object[^>]*>",
        r"<embed[^>]*>",
    ]

    PATH_TRAVERSAL_PATTERNS = [
        r"\.\.[\/\\]",
        r"\.\.%2f",
        r"\.\.%5c",
        r"%2e%2e%2f",
        r"%2e%2e%5c",
    ]

    # Flags for every deny-list pattern. DOTALL lets ".*?" in the <script>
    # pattern span newlines, so a multi-line script body still matches.
    _PATTERN_FLAGS = re.IGNORECASE | re.DOTALL

    # Upload policy used by validate_file_upload().
    MAX_UPLOAD_SIZE = 10 * 1024 * 1024  # 10 MiB
    DANGEROUS_EXTENSIONS = ('.exe', '.bat', '.cmd', '.scr', '.pif', '.com')
    ALLOWED_CONTENT_TYPES = frozenset({
        'image/jpeg', 'image/png', 'image/gif', 'image/webp',
        'text/plain', 'text/csv', 'application/json',
        'application/pdf',
    })

    @classmethod
    def _reject_if_matches(cls, patterns: List[str], input_data: str, log_message: str) -> None:
        """Raise a 400 ``HTTPException`` if ``input_data`` matches any of ``patterns``."""
        for pattern in patterns:
            if re.search(pattern, input_data, cls._PATTERN_FLAGS):
                # %r escapes CR/LF and quotes, so attacker-controlled text
                # cannot forge extra log lines.
                logger.warning(log_message, input_data[:100])
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid input detected"
                )

    @classmethod
    def validate_input(cls, input_data: str, max_length: int = 10000) -> str:
        """
        Validate and sanitize input string.

        Args:
            input_data: Input string to validate. Non-string values are
                returned unchanged without any checks.
            max_length: Maximum allowed length in characters.

        Returns:
            The HTML-escaped input string.

        Raises:
            HTTPException: 413 if the input is longer than ``max_length``;
                400 if it matches an SQL-injection, XSS or path-traversal
                pattern.
        """
        if not isinstance(input_data, str):
            return input_data

        if len(input_data) > max_length:
            raise HTTPException(
                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                detail="Input too long"
            )

        cls._reject_if_matches(
            cls.SQL_INJECTION_PATTERNS, input_data,
            "Potential vulnerability detected in input: %r...",
        )
        cls._reject_if_matches(
            cls.XSS_PATTERNS, input_data,
            "Potential XSS detected: %r...",
        )
        cls._reject_if_matches(
            cls.PATH_TRAVERSAL_PATTERNS, input_data,
            "Potential path traversal detected: %r...",
        )

        return html.escape(input_data)

    @classmethod
    def validate_json_data(cls, data: Dict[str, Any], max_depth: int = 10) -> Dict[str, Any]:
        """
        Validate JSON data recursively.

        Every string value (at any depth inside dicts and lists) is passed
        through :meth:`validate_input`; other scalars are returned as-is.

        Args:
            data: JSON data to validate.
            max_depth: Maximum nesting depth.

        Returns:
            A new structure with all string values validated and escaped.

        Raises:
            HTTPException: 400 if nesting exceeds ``max_depth`` or any string
                value fails :meth:`validate_input`.
        """
        def validate_recursive(obj, depth=0):
            if depth > max_depth:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Input nesting too deep"
                )

            if isinstance(obj, str):
                return cls.validate_input(obj)
            if isinstance(obj, dict):
                # NOTE(review): only values are checked; keys pass through
                # unvalidated.
                return {k: validate_recursive(v, depth + 1) for k, v in obj.items()}
            if isinstance(obj, list):
                return [validate_recursive(item, depth + 1) for item in obj]
            return obj

        return validate_recursive(data)

    @classmethod
    def validate_file_upload(cls, filename: str, content_type: str, size: int) -> bool:
        """
        Validate uploaded file for security.

        Args:
            filename: Uploaded filename.
            content_type: File content type. Parameters such as
                ``; charset=utf-8`` and letter case are ignored.
            size: File size in bytes.

        Returns:
            True if the file is valid.

        Raises:
            HTTPException: 413 if larger than ``MAX_UPLOAD_SIZE``. 400 for a
                missing or unsafe filename, a dangerous extension, or a
                content type outside ``ALLOWED_CONTENT_TYPES``.
        """
        if size > cls.MAX_UPLOAD_SIZE:
            raise HTTPException(
                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                detail="File too large"
            )

        # Reject non-string names (e.g. None) and embedded NUL bytes. Some
        # lower layers treat NUL as a terminator ("a.exe\x00.txt").
        if not isinstance(filename, str) or '\x00' in filename:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid filename"
            )

        # Windows drops trailing dots and spaces, so "evil.exe." and
        # "evil.exe " would otherwise slip past the extension check.
        normalized_name = filename.lower().rstrip('. ')
        if normalized_name.endswith(cls.DANGEROUS_EXTENSIONS):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="File type not allowed"
            )

        if '..' in filename or '/' in filename or '\\' in filename:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid filename"
            )

        # Compare only the media type. Parameters such as "; charset=utf-8"
        # are legal, and MIME types are case-insensitive (RFC 2045).
        media_type = (content_type or '').split(';', 1)[0].strip().lower()
        if media_type not in cls.ALLOWED_CONTENT_TYPES:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Content type not allowed"
            )

        return True
|
|
|
|
class InputValidationMiddleware(BaseHTTPMiddleware):
    """
    Middleware that rejects requests with suspicious query parameters or
    oversized forwarding headers.

    Query parameter values are checked with ``SecurityValidator.validate_input``
    but are not rewritten. Downstream handlers build their own ``Request`` from
    the ASGI scope, so mutating this middleware's ``Request`` has no effect
    there.

    Validation failures are returned directly as JSON responses. An
    ``HTTPException`` raised inside a ``BaseHTTPMiddleware`` is not handled by
    the application's exception handlers and would surface as a 500.
    """

    # Forwarding headers that are commonly spoofed; only their length is checked.
    _SUSPICIOUS_HEADERS = (
        'x-forwarded-for', 'x-real-ip', 'x-originating-ip',
        'x-cluster-client-ip', 'x-forwarded-host',
    )
    _MAX_HEADER_LENGTH = 500

    def __init__(self, app, exclude_paths: Optional[List[str]] = None):
        super().__init__(app)
        # Paths (and their sub-paths) that skip validation entirely.
        self.exclude_paths = exclude_paths or ['/health', '/docs', '/openapi.json']
        self.validator = SecurityValidator()

    def _is_excluded(self, path: str) -> bool:
        """Return True if ``path`` equals an excluded path or lies beneath it."""
        for excluded in self.exclude_paths:
            base = excluded.rstrip('/')
            # Exact or sub-path match only: plain startswith() would also
            # exempt "/docs-admin" or "/healthz-internal".
            if path == base or path.startswith(base + '/'):
                return True
        return False

    async def dispatch(self, request: Request, call_next):
        # CORS preflight requests and excluded paths pass straight through.
        if request.method == "OPTIONS" or self._is_excluded(request.url.path):
            return await call_next(request)

        try:
            self._validate_query_params(request)
            self._validate_headers(request)
        except HTTPException as exc:
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=getattr(exc, "headers", None),
            )

        return await call_next(request)

    def _validate_query_params(self, request: Request) -> None:
        """Validate every query parameter value, including repeated keys."""
        # multi_items() yields every value of a repeated key (?q=a&q=b);
        # items() would only expose the last one, letting earlier values
        # bypass validation.
        for key, value in request.query_params.multi_items():
            try:
                self.validator.validate_input(value)
            except HTTPException:
                raise
            except Exception as e:
                logger.error("Error validating query param %r: %s", key, e)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Invalid query parameter: {key}"
                ) from e

    def _validate_headers(self, request: Request):
        """Reject the request if any copy of a listed forwarding header exceeds 500 characters."""
        for header in self._SUSPICIOUS_HEADERS:
            # getlist(): a client may repeat a header, and indexing would only
            # return the first copy.
            for value in request.headers.getlist(header):
                if len(value) > self._MAX_HEADER_LENGTH:
                    logger.warning("Suspicious header detected: %s", header)
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="Invalid request headers"
                    )
|
|
|
|
class RateLimitValidator:
    """
    In-memory sliding-window rate limiter keyed by ``"<client_ip>:<endpoint>"``.

    State lives in this instance's ``request_counts`` dict only.
    """

    def __init__(self):
        # key -> list of request timestamps (time.monotonic()) inside the window
        self.request_counts = {}
        # window length in seconds
        self.window_size = 60
        # monotonic time of the last stale-key sweep; -inf forces one on first use
        self._last_sweep = float("-inf")

    def check_rate_limit(self, client_ip: str, endpoint: str, limit: int = 100) -> bool:
        """
        Check if client has exceeded rate limit, recording the request if not.

        Args:
            client_ip: Client IP address.
            endpoint: API endpoint.
            limit: Maximum number of requests allowed per window.

        Returns:
            True if within limit (the request is counted), False otherwise
            (the request is not counted).
        """
        import time
        # Monotonic clock: wall-clock adjustments must not widen or shrink windows.
        now = time.monotonic()
        cutoff = now - self.window_size
        self._evict_stale_keys(now, cutoff)

        key = f"{client_ip}:{endpoint}"
        recent = [ts for ts in self.request_counts.get(key, ()) if ts > cutoff]
        self.request_counts[key] = recent

        if len(recent) >= limit:
            return False

        recent.append(now)
        return True

    def _evict_stale_keys(self, now: float, cutoff: float) -> None:
        """Drop keys with no timestamps inside the window, at most once per window."""
        # Without this, every distinct client/endpoint ever seen stays in
        # memory forever.
        if now - getattr(self, "_last_sweep", float("-inf")) < self.window_size:
            return
        self._last_sweep = now
        # Timestamps are appended in monotonic order, so ts[-1] is the newest.
        stale = [k for k, ts in self.request_counts.items() if not ts or ts[-1] <= cutoff]
        for k in stale:
            del self.request_counts[k]
|
|
|
|
class ContentSecurityMiddleware(BaseHTTPMiddleware):
    """
    Content Security Policy middleware.

    Sets browser security headers on every response, overwriting any values
    the endpoint set for the same header names.
    """

    SECURITY_HEADERS = {
        'X-Content-Type-Options': 'nosniff',
        'X-Frame-Options': 'DENY',
        # The legacy XSS auditor has been removed from modern browsers, and
        # "1; mode=block" could be abused for cross-site leaks. OWASP and MDN
        # recommend disabling it explicitly and relying on CSP instead.
        'X-XSS-Protection': '0',
        'Referrer-Policy': 'strict-origin-when-cross-origin',
        # NOTE(review): script-src only allows same-origin and inline scripts;
        # confirm that any UI served by this app (e.g. /docs) does not load
        # scripts from a CDN.
        'Content-Security-Policy': (
            "default-src 'self'; "
            "script-src 'self' 'unsafe-inline'; "
            "style-src 'self' 'unsafe-inline'; "
            "img-src 'self' data: https:; "
            "connect-src 'self'; "
            # CSP counterpart of X-Frame-Options: DENY.
            "frame-ancestors 'none'"
        ),
    }

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        for name, value in self.SECURITY_HEADERS.items():
            response.headers[name] = value
        return response
|
|
|
|
def validate_evaluation_data(evaluation_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate evaluation-specific data in place.

    Checks ``evaluation_data['model_config_data']`` (model name, API endpoint,
    API key) and ``evaluation_data['pipeline_config']`` (integer bounds). Either
    section may be absent or ``None``. Nested dicts are mutated in place, and
    the same ``evaluation_data`` object is returned.

    Args:
        evaluation_data: Evaluation configuration data.

    Returns:
        The validated ``evaluation_data`` (same object).

    Raises:
        HTTPException: 400 if any section is malformed or out of range.
    """
    if 'model_config_data' in evaluation_data:
        _validate_model_config(evaluation_data['model_config_data'])

    if 'pipeline_config' in evaluation_data:
        _validate_pipeline_config(evaluation_data['pipeline_config'])

    return evaluation_data


def _validate_model_config(model_config: Any) -> None:
    """Validate the ``model_config_data`` section in place (``None`` is skipped)."""
    if model_config is None:
        return
    if not isinstance(model_config, dict):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid model_config_data"
        )

    if 'model_name' in model_config:
        model_config['model_name'] = SecurityValidator.validate_input(
            model_config['model_name'], max_length=100
        )

    if 'api_endpoint' in model_config:
        # Run the deny-list checks but keep the raw URL. HTML-escaping would
        # turn "&" in a query string into "&amp;" and corrupt the endpoint.
        SecurityValidator.validate_input(model_config['api_endpoint'], max_length=500)

    if 'api_key' in model_config:
        api_key = model_config['api_key']
        if api_key is None:
            return  # key explicitly unset; nothing to check
        if not isinstance(api_key, str):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid API key"
            )
        if len(api_key) > 500:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="API key too long"
            )


def _validate_pipeline_config(pipeline_config: Any) -> None:
    """Coerce and range-check integer fields of ``pipeline_config`` in place (``None`` is skipped)."""
    if pipeline_config is None:
        return
    if not isinstance(pipeline_config, dict):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid pipeline_config"
        )

    for field in ('max_iterations', 'num_prompts'):
        if field not in pipeline_config:
            continue
        try:
            value = int(pipeline_config[field])
        except (ValueError, TypeError, OverflowError):
            # OverflowError: int(float('inf')) -- json.loads accepts "Infinity".
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid {field} value"
            ) from None
        if not 1 <= value <= 1000:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"{field} must be between 1 and 1000"
            )
        pipeline_config[field] = value
|
|
|
|
def sanitize_error_message(message: str) -> str:
    """
    Sanitize error messages to prevent information leakage.

    Any ``<name><sep><value>`` fragment is replaced by ``[REDACTED]``, where
    ``<name>`` is password/token/secret/key/api_key (case-insensitive),
    ``<sep>`` is ``:`` or ``=``, and ``<value>`` is a quoted string or a run of
    non-whitespace. An optional quote after the name is allowed, so dict and
    JSON reprs such as ``{'api_key': 'sk-1'}`` are covered too.

    Args:
        message: Original error message; non-strings (e.g. an exception) are
            converted with ``str()``.

    Returns:
        Sanitized error message.
    """
    # api_key variants come first in the alternation so "api_key=x" is
    # redacted whole instead of leaving an "api_" prefix behind.
    key_names = r"(?:api[_-]?key|password|token|secret|key)"
    # Quoted values are consumed whole so "password: 'a b'" leaks nothing.
    value = r"""(?:'[^']*'|"[^"]*"|\S+)"""
    pattern = re.compile(key_names + r"""['"]?\s*[:=]\s*""" + value, re.IGNORECASE)

    return pattern.sub('[REDACTED]', str(message))
|
|