"""
Request validation middleware for enhanced input validation and security
"""

import logging
from urllib.parse import unquote, unquote_plus

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

logger = logging.getLogger(__name__)

# HTTP methods that are expected to carry a request body.
_BODY_METHODS = frozenset({"POST", "PUT", "PATCH"})

# Accepted Content-Type prefixes (prefix match, so parameters such as
# "; charset=utf-8" or "; boundary=..." are accepted as well).
_ALLOWED_CONTENT_TYPES = (
    "application/json",
    "application/x-www-form-urlencoded",
    "multipart/form-data",
    "text/plain",
    "application/xml",
    "application/octet-stream",
)

# Substrings that mark a query string as a possible SQL injection probe.
_SQL_PATTERNS = (
    "union",
    "select",
    "insert",
    "update",
    "delete",
    "drop",
    "exec",
    "script",
)

# NOTE(review): the original XSS pattern list was lost when this file was
# extracted (everything between "<" and ">" was stripped). This list is a
# reconstruction -- confirm against version control.
_XSS_PATTERNS = ("<script", "javascript:", "onload=", "onerror=")

# Query strings longer than this (in encoded form) are reported as suspicious.
_MAX_QUERY_STRING_LENGTH = 2000

# User-agent substrings of common scanning tools.
_SUSPICIOUS_USER_AGENTS = ("sqlmap", "nmap", "masscan", "dirbuster", "gobuster")

# Headers normally set by proxies; their presence is logged for auditing.
_PROXY_HEADERS = ("x-forwarded-for", "x-real-ip", "x-client-ip")

# Characters that are rejected in query parameter names.
_FORBIDDEN_PARAM_NAME_CHARS = frozenset("<>\"';()")


def _http_error_response(exc: HTTPException) -> JSONResponse:
    """Convert an HTTPException into the JSON response sent to the client.

    Exceptions raised inside a ``BaseHTTPMiddleware`` are not seen by
    FastAPI's exception handlers (those run inside the middleware stack), so
    a raised ``HTTPException`` would surface as a 500. The response built
    here uses the same ``{"detail": ...}`` shape as FastAPI's default handler.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        headers=getattr(exc, "headers", None),
    )


class RequestValidationMiddleware(BaseHTTPMiddleware):
    """Middleware for comprehensive request validation and security checks.

    Rejects oversized bodies (413), bodies without a Content-Type (400) and
    unsupported content types (415), and logs requests that look suspicious.
    """

    def __init__(self, app, max_body_size: int = 10 * 1024 * 1024):  # 10MB default
        super().__init__(app)
        self.max_body_size = max_body_size

    async def dispatch(self, request: Request, call_next):
        """Run all validation steps, then forward the request downstream.

        Validation failures are returned as JSON error responses; unexpected
        errors in the validation code are logged with traceback and re-raised.
        """
        try:
            # Validate request size
            await self._validate_request_size(request)

            # Validate content type for POST/PUT/PATCH requests
            await self._validate_content_type(request)

            # Log suspicious requests
            await self._log_suspicious_requests(request)
        except HTTPException as exc:
            return _http_error_response(exc)
        except Exception as exc:
            logger.exception("Request validation middleware error: %s", exc)
            raise

        # Deliberately outside the try block: downstream failures are not
        # validation errors and must propagate untouched.
        return await call_next(request)

    async def _validate_request_size(self, request: Request) -> None:
        """Reject requests whose declared body size exceeds ``max_body_size``.

        Only the Content-Length header is inspected; bodies sent without it
        (for example chunked uploads) are not measured here.

        Raises:
            HTTPException: 413 when Content-Length exceeds the limit.
        """
        if request.method not in _BODY_METHODS:
            return

        content_length = request.headers.get("content-length")
        if not content_length:
            return

        try:
            size = int(content_length)
        except ValueError:
            # Invalid content-length header, let FastAPI handle it
            return

        if size > self.max_body_size:
            raise HTTPException(
                status_code=413,
                detail=(
                    "Request body too large. "
                    f"Maximum size: {self.max_body_size} bytes"
                ),
            )

    async def _validate_content_type(self, request: Request) -> None:
        """Validate the Content-Type of requests that may carry a body.

        Raises:
            HTTPException: 400 when a body is present without a Content-Type,
                415 when the Content-Type is not in the allowed list.
        """
        if request.method not in _BODY_METHODS:
            return

        content_type = request.headers.get("content-type", "").lower()

        # Require content-type for requests with bodies
        if not content_type:
            if await self._has_body(request):
                raise HTTPException(
                    status_code=400,
                    detail="Content-Type header required for requests with body",
                )
            return

        if not content_type.startswith(_ALLOWED_CONTENT_TYPES):
            raise HTTPException(
                status_code=415, detail=f"Unsupported content type: {content_type}"
            )

    @staticmethod
    async def _has_body(request: Request) -> bool:
        """Return True when the request carries (or declares) a body.

        The framing headers are consulted first so that a large or chunked
        upload is not buffered in memory merely to be rejected. The body is
        read only when neither header is present.
        """
        if "transfer-encoding" in request.headers:
            return True

        content_length = request.headers.get("content-length", "").strip()
        if content_length:
            try:
                return int(content_length) > 0
            except ValueError:
                # Malformed header: left to the server/framework to reject.
                return False

        # No framing headers (e.g. HTTP/2 without content-length).
        body = await request.body()
        return bool(body)

    async def _log_suspicious_requests(self, request: Request) -> None:
        """Log potentially suspicious requests for security monitoring.

        This never rejects a request; it only emits a warning listing the
        indicators that matched.
        """
        suspicious_indicators = []

        # str(QueryParams) is the URL-encoded form, where characters such as
        # "<" and ":" appear as %3C / %3A. Patterns are matched on the decoded
        # text; the length check keeps using the encoded form.
        query_params = str(request.query_params)
        decoded_query = unquote_plus(query_params).lower()

        # Check for SQL injection patterns in query parameters
        if any(pattern in decoded_query for pattern in _SQL_PATTERNS):
            suspicious_indicators.append("sql_injection_patterns")

        # Check for XSS patterns in query parameters
        # NOTE(review): indicator name reconstructed by analogy with the SQL
        # check; the original text was lost -- confirm.
        if any(pattern in decoded_query for pattern in _XSS_PATTERNS):
            suspicious_indicators.append("xss_patterns")

        # Check for unusually long query strings
        if len(query_params) > _MAX_QUERY_STRING_LENGTH:
            suspicious_indicators.append("long_query_string")

        # Check for suspicious user agents
        user_agent = request.headers.get("user-agent", "").lower()
        if any(ua in user_agent for ua in _SUSPICIOUS_USER_AGENTS):
            suspicious_indicators.append("suspicious_user_agent")

        # Log suspicious requests
        if suspicious_indicators:
            logger.warning(
                "Suspicious request detected: %s %s",
                request.method,
                request.url.path,
                extra={
                    "client_ip": request.client.host if request.client else "unknown",
                    "user_agent": user_agent,
                    "indicators": suspicious_indicators,
                    "query_params_length": len(query_params),
                },
            )


class InputValidationMiddleware(BaseHTTPMiddleware):
    """Middleware for input sanitization and validation.

    Logs proxy-related headers and rejects requests whose path contains a
    traversal segment or whose query parameter names contain special
    characters (both with status 400).
    """

    async def dispatch(self, request: Request, call_next):
        """Run the input checks, then forward the request downstream.

        Validation failures are returned as JSON error responses; unexpected
        errors in the validation code are logged with traceback and re-raised.
        """
        try:
            # Sanitize headers
            await self._sanitize_headers(request)

            # Validate request path and query parameters
            await self._validate_request_parameters(request)
        except HTTPException as exc:
            return _http_error_response(exc)
        except Exception as exc:
            logger.exception("Input validation middleware error: %s", exc)
            raise

        # Deliberately outside the try block: downstream failures are not
        # validation errors and must propagate untouched.
        return await call_next(request)

    async def _sanitize_headers(self, request: Request) -> None:
        """Log proxy-related headers found on the request.

        Despite the name, no header is modified or removed; the headers
        present are only reported at INFO level.
        """
        found_suspicious = [h for h in _PROXY_HEADERS if h in request.headers]
        if found_suspicious:
            logger.info("Request with proxy headers: %s", found_suspicious)

    async def _validate_request_parameters(self, request: Request) -> None:
        """Validate request path and query parameter names.

        Raises:
            HTTPException: 400 when the path contains a ".." segment (also
                when percent-encoded or written with backslashes), or when a
                query parameter name contains a forbidden character.
        """
        # Decode once more so still-encoded sequences (e.g. %2e%2e) are seen,
        # and treat backslashes as path separators.
        decoded_path = unquote(request.url.path).replace("\\", "/")
        if ".." in decoded_path.split("/"):
            raise HTTPException(
                status_code=400, detail="Invalid path: path traversal detected"
            )

        # Validate query parameter names (no special characters that could cause issues)
        for param_name in request.query_params:
            if not _FORBIDDEN_PARAM_NAME_CHARS.isdisjoint(param_name):
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid query parameter name: {param_name}",
                )