# ALM-2/backend/middleware/validation.py
# ACA050's picture
# Upload 520 files
# 2ed8996 verified
"""
Input Validation and Security Middleware for AegisLM Backend.
Provides comprehensive input validation, sanitization, and security checks
for all incoming requests to prevent common vulnerabilities.
"""
import re
import html
import logging
from typing import Dict, Any, List, Optional
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
logger = logging.getLogger(__name__)
class SecurityValidator:
    """
    Security validation utilities for input sanitization and validation.

    Every public method is a classmethod; the class carries only the regex
    pattern lists below and no per-instance state.
    """

    # Common attack patterns (each is applied with re.search + re.IGNORECASE)
    SQL_INJECTION_PATTERNS = [
        r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|UNION)\b)",
        r"(\b(OR|AND)\s+\d+\s*=\s*\d+)",
        r"(\b(OR|AND)\s+['\"][\w\s]*['\"]\s*=\s*['\"][\w\s]*['\"])",
        r"(--|#|\/\*|\*\/)",
        r"(\b(SCRIPT|JAVASCRIPT|VBSCRIPT|ONLOAD|ONERROR)\b)",
    ]
    XSS_PATTERNS = [
        r"<script[^>]*>.*?</script>",
        r"javascript:",
        r"vbscript:",
        r"onload\s*=",
        r"onerror\s*=",
        r"onclick\s*=",
        r"<iframe[^>]*>",
        r"<object[^>]*>",
        r"<embed[^>]*>",
    ]
    PATH_TRAVERSAL_PATTERNS = [
        r"\.\.[\/\\]",
        r"\.\.%2f",
        r"\.\.%5c",
        r"%2e%2e%2f",
        r"%2e%2e%5c",
    ]

    @classmethod
    def _reject_on_match(cls, patterns, input_data: str, log_template: str) -> None:
        """Raise a 400 HTTPException if any pattern matches input_data.

        Args:
            patterns: Iterable of regex strings, searched case-insensitively
            input_data: String to scan
            log_template: %-style log format with one placeholder for the
                first 100 characters of the offending input
        Raises:
            HTTPException: 400 "Invalid input detected" on the first match
        """
        for pattern in patterns:
            if re.search(pattern, input_data, re.IGNORECASE):
                # %r escapes CR/LF and other control characters so untrusted
                # input cannot forge extra log lines; the snippet is capped
                # at 100 characters.
                logger.warning(log_template, input_data[:100])
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid input detected"
                )

    @classmethod
    def validate_input(cls, input_data: str, max_length: int = 10000) -> str:
        """
        Validate and sanitize input string.

        Non-string values are returned unchanged without any checks.

        Args:
            input_data: Input string to validate
            max_length: Maximum allowed length
        Returns:
            The input with HTML special characters escaped (html.escape)
        Raises:
            HTTPException: 413 if the input is longer than max_length,
                400 if it matches an SQL-injection, XSS or path-traversal
                pattern
        """
        if not isinstance(input_data, str):
            return input_data

        # Check length
        if len(input_data) > max_length:
            raise HTTPException(
                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                detail="Input too long"
            )

        # Pattern families are checked in the same order as before:
        # SQL injection, then XSS, then path traversal.
        cls._reject_on_match(
            cls.SQL_INJECTION_PATTERNS, input_data,
            "Potential vulnerability detected in input: %r..."
        )
        cls._reject_on_match(
            cls.XSS_PATTERNS, input_data,
            "Potential XSS detected: %r..."
        )
        cls._reject_on_match(
            cls.PATH_TRAVERSAL_PATTERNS, input_data,
            "Potential path traversal detected: %r..."
        )

        # HTML escape the input
        return html.escape(input_data)

    @classmethod
    def validate_json_data(cls, data: Dict[str, Any], max_depth: int = 10) -> Dict[str, Any]:
        """
        Validate JSON data recursively.

        String values are passed through validate_input (and therefore come
        back HTML-escaped); dicts and lists are rebuilt with validated
        members; every other value is returned as-is. Dictionary keys are
        not validated.

        Args:
            data: JSON data to validate
            max_depth: Maximum nesting depth
        Returns:
            Validated JSON data (a new structure; the input is not mutated)
        Raises:
            HTTPException: If data contains malicious content or is nested
                deeper than max_depth
        """
        def validate_recursive(obj, depth=0):
            if depth > max_depth:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Input nesting too deep"
                )
            if isinstance(obj, str):
                return cls.validate_input(obj)
            if isinstance(obj, dict):
                return {k: validate_recursive(v, depth + 1) for k, v in obj.items()}
            if isinstance(obj, list):
                return [validate_recursive(item, depth + 1) for item in obj]
            return obj

        return validate_recursive(data)

    @classmethod
    def validate_file_upload(cls, filename: str, content_type: str, size: int) -> bool:
        """
        Validate uploaded file for security.

        Args:
            filename: Uploaded filename
            content_type: File content type; compared case-insensitively and
                without any parameters (e.g. "; charset=utf-8")
            size: File size in bytes
        Returns:
            True if file is valid
        Raises:
            HTTPException: 413 if the file exceeds 10MB, 400 if the
                extension, filename or content type is not acceptable
        """
        # Check file size (max 10MB)
        max_size = 10 * 1024 * 1024  # 10MB
        if size > max_size:
            raise HTTPException(
                status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                detail="File too large"
            )

        # Check filename for suspicious patterns
        dangerous_extensions = ('.exe', '.bat', '.cmd', '.scr', '.pif', '.com')
        if filename.lower().endswith(dangerous_extensions):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="File type not allowed"
            )

        # Check for path traversal in filename
        if '..' in filename or '/' in filename or '\\' in filename:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid filename"
            )

        # Allowed content types
        allowed_types = (
            'image/jpeg', 'image/png', 'image/gif', 'image/webp',
            'text/plain', 'text/csv', 'application/json',
            'application/pdf',
        )
        # Media types are case-insensitive and may carry parameters, so only
        # the lower-cased "type/subtype" part is compared. Anything that is
        # not a string normalizes to "" and is rejected below.
        if isinstance(content_type, str):
            media_type = content_type.split(';', 1)[0].strip().lower()
        else:
            media_type = ''
        if media_type not in allowed_types:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Content type not allowed"
            )
        return True
class InputValidationMiddleware(BaseHTTPMiddleware):
    """
    Middleware for automatic input validation and sanitization.

    Query-string values and a fixed set of proxy headers are checked before
    the request is passed on. A failed check is answered directly with a
    JSON error response carrying the HTTPException's status code and detail.
    """

    def __init__(self, app, exclude_paths: Optional[List[str]] = None):
        """
        Args:
            app: The wrapped ASGI application
            exclude_paths: Path prefixes that skip validation; defaults to
                ['/health', '/docs', '/openapi.json'] when not given or empty
        """
        super().__init__(app)
        self.exclude_paths = exclude_paths or ['/health', '/docs', '/openapi.json']
        self.validator = SecurityValidator()

    async def dispatch(self, request: Request, call_next):
        """Validate the request, then forward it or answer with an error.

        Args:
            request: Incoming request
            call_next: Callable that forwards the request downstream
        Returns:
            The downstream response, or a JSONResponse describing the
            validation failure
        """
        # Skip validation for OPTIONS requests and excluded paths
        if request.method == "OPTIONS" or any(
            request.url.path.startswith(path) for path in self.exclude_paths
        ):
            return await call_next(request)

        try:
            self._validate_query_params(request)
            # Validate headers for security
            self._validate_headers(request)
        except HTTPException as exc:
            # An HTTPException raised from a BaseHTTPMiddleware is not seen
            # by the application's exception handlers and would surface as a
            # 500, so it is translated into the intended response here.
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=exc.headers,
            )

        # Process request
        return await call_next(request)

    def _validate_query_params(self, request: Request) -> None:
        """Run every query-string value through the validator.

        Raises:
            HTTPException: If any value is rejected or validation fails
        """
        params = request.query_params
        if not params:
            return

        validated = []
        # multi_items() yields every (key, value) pair, so repeated keys
        # such as ?q=a&q=b are all checked rather than only one of them.
        for key, value in params.multi_items():
            try:
                validated.append((key, self.validator.validate_input(value)))
            except HTTPException:
                raise
            except Exception as e:
                # %r keeps control characters in the key out of the log
                logger.error("Error validating query param %r: %s", key, e)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Invalid query parameter: {key}"
                ) from e

        # Replace query params with validated ones, keeping the original
        # container type so .get()/.getlist() continue to work.
        # NOTE(review): this only updates the cache on this Request object;
        # confirm whether downstream handlers actually observe it.
        request._query_params = type(params)(validated)

    def _validate_headers(self, request: Request):
        """Validate request headers for security issues.

        Raises:
            HTTPException: 400 if a checked header value exceeds 500
                characters
        """
        # Check for suspicious headers
        suspicious_headers = [
            'x-forwarded-for', 'x-real-ip', 'x-originating-ip',
            'x-cluster-client-ip', 'x-forwarded-host'
        ]
        for header in suspicious_headers:
            if header in request.headers:
                value = request.headers[header]
                # Basic validation for header values
                if len(value) > 500:  # Unusually long header value
                    logger.warning("Suspicious header detected: %s", header)
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="Invalid request headers"
                    )
class RateLimitValidator:
    """
    Rate limiting validation for API endpoints.

    Keeps, per "<client_ip>:<endpoint>" key, the timestamps of the requests
    accepted during the last window_size seconds. State is held in process
    memory only and no locking is performed.
    """

    def __init__(self):
        self.request_counts = {}  # Simple in-memory tracking: key -> [timestamps]
        self.window_size = 60  # 1 minute window
        self._last_sweep = 0.0  # time of the last stale-key sweep

    def check_rate_limit(self, client_ip: str, endpoint: str, limit: int = 100) -> bool:
        """
        Check if client has exceeded rate limit.

        A request that is within the limit is recorded; a rejected request
        is not.

        Args:
            client_ip: Client IP address
            endpoint: API endpoint
            limit: Request limit per window
        Returns:
            True if within limit, False otherwise
        """
        import time

        current_time = time.time()
        cutoff_time = current_time - self.window_size
        key = f"{client_ip}:{endpoint}"

        # Without this, a key is only pruned when the same client/endpoint
        # calls again, so one-off clients would accumulate forever.
        self._evict_stale_keys(current_time, cutoff_time)

        # Clean old entries for this key
        recent = [
            timestamp for timestamp in self.request_counts.get(key, ())
            if timestamp > cutoff_time
        ]
        self.request_counts[key] = recent

        # Check current count
        if len(recent) >= limit:
            return False

        # Add current request
        recent.append(current_time)
        return True

    def _evict_stale_keys(self, current_time: float, cutoff_time: float) -> None:
        """Drop keys whose newest timestamp fell out of the window.

        Runs at most once per window_size seconds to keep the per-request
        cost low.
        """
        if current_time - self._last_sweep < self.window_size:
            return
        self._last_sweep = current_time
        stale_keys = [
            key for key, stamps in self.request_counts.items()
            if not stamps or stamps[-1] <= cutoff_time
        ]
        for key in stale_keys:
            del self.request_counts[key]
class ContentSecurityMiddleware(BaseHTTPMiddleware):
    """
    Middleware that stamps a fixed set of security headers, including a
    Content Security Policy, on every response.
    """

    async def dispatch(self, request: Request, call_next):
        """Forward the request and add the security headers to the response.

        Args:
            request: Incoming request
            call_next: Callable that forwards the request downstream
        Returns:
            The downstream response with the headers below set (any existing
            value for the same header name is overwritten)
        """
        csp_directives = [
            "default-src 'self'",
            "script-src 'self' 'unsafe-inline'",
            "style-src 'self' 'unsafe-inline'",
            "img-src 'self' data: https:",
            "connect-src 'self'",
        ]
        security_headers = {
            'X-Content-Type-Options': 'nosniff',
            'X-Frame-Options': 'DENY',
            'X-XSS-Protection': '1; mode=block',
            'Referrer-Policy': 'strict-origin-when-cross-origin',
            'Content-Security-Policy': "; ".join(csp_directives),
        }

        outgoing = await call_next(request)
        # Add security headers
        for header_name, header_value in security_headers.items():
            outgoing.headers[header_name] = header_value
        return outgoing
def validate_evaluation_data(evaluation_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate evaluation-specific data.

    The nested 'model_config_data' and 'pipeline_config' dicts are updated
    in place: 'model_name' and 'api_endpoint' are replaced by their
    validated form, and 'max_iterations' / 'num_prompts' are replaced by
    their integer value.

    Args:
        evaluation_data: Evaluation configuration data
    Returns:
        Validated evaluation data (the same object that was passed in)
    Raises:
        HTTPException: If data is invalid
    """
    # Validate model configuration
    if 'model_config_data' in evaluation_data:
        model_config = evaluation_data['model_config_data']

        # Validate model name
        if 'model_name' in model_config:
            model_config['model_name'] = SecurityValidator.validate_input(
                model_config['model_name'], max_length=100
            )

        # Validate API endpoint
        # NOTE(review): validate_input HTML-escapes its result, so an
        # endpoint URL containing '&' comes back with '&amp;' — confirm this
        # is intended for URLs.
        if 'api_endpoint' in model_config:
            model_config['api_endpoint'] = SecurityValidator.validate_input(
                model_config['api_endpoint'], max_length=500
            )

        # Validate API key (don't log or escape this)
        if 'api_key' in model_config:
            api_key = model_config['api_key']
            # None is treated as "no key supplied"; any other non-string
            # would make len() raise TypeError, so reject it explicitly.
            if api_key is not None:
                if not isinstance(api_key, str):
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="Invalid API key"
                    )
                if len(api_key) > 500:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="API key too long"
                    )

    # Validate pipeline configuration
    if 'pipeline_config' in evaluation_data:
        pipeline_config = evaluation_data['pipeline_config']

        # Validate numeric parameters
        for field in ('max_iterations', 'num_prompts'):
            if field not in pipeline_config:
                continue
            try:
                value = int(pipeline_config[field])
            except (ValueError, TypeError, OverflowError):
                # OverflowError covers float infinity, which int() rejects
                # with a different exception type than NaN or bad strings.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Invalid {field} value"
                ) from None
            if value < 1 or value > 1000:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"{field} must be between 1 and 1000"
                )
            pipeline_config[field] = value

    return evaluation_data
# One case-insensitive pattern for every "<label> : value" / "<label> = value"
# pair that must not leak. A single alternation lets "api_key" be matched as
# a whole instead of being half-consumed by the shorter "key" label, and the
# value part accepts a quoted string so values containing spaces are fully
# covered.
_SENSITIVE_VALUE_RE = re.compile(
    r"(?:api_key|password|secret|token|key)\s*[:=]\s*"
    r"(?:\"[^\"]*\"|'[^']*'|\S+)",
    re.IGNORECASE,
)


def sanitize_error_message(message: str) -> str:
    """
    Sanitize error messages to prevent information leakage.

    Each "label: value" or "label = value" pair whose label ends in
    password, token, key, secret or api_key is replaced, label included,
    by the literal text "[REDACTED]".

    Args:
        message: Original error message
    Returns:
        Sanitized error message
    """
    return _SENSITIVE_VALUE_RE.sub('[REDACTED]', message)