# zenith-backend/app/middleware/validation_middleware.py
# Uploaded by teoat -- commit "Upload folder using huggingface_hub" (4ae946d, verified)
"""
Request validation middleware for enhanced input validation and security
"""
import logging
from urllib.parse import unquote

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
# Module-level logger used by both middleware classes in this file.
logger = logging.getLogger(__name__)
class RequestValidationMiddleware(BaseHTTPMiddleware):
    """Middleware for comprehensive request validation and security checks.

    For each request it:

    * rejects POST/PUT/PATCH requests whose declared ``Content-Length``
      exceeds ``max_body_size`` (413);
    * rejects POST/PUT/PATCH requests that carry a body without a
      ``Content-Type`` (400) or with an unsupported one (415);
    * logs a warning for requests matching simple attack heuristics.

    Validation failures are answered directly with a JSON error response of
    the form ``{"detail": ...}`` rather than being raised out of the
    middleware.
    """

    # HTTP methods that are checked for body size and content type.
    _BODY_METHODS = ("POST", "PUT", "PATCH")

    # Accepted content-type prefixes (prefix match keeps "; charset=..." working).
    _ALLOWED_CONTENT_TYPES = (
        "application/json",
        "application/x-www-form-urlencoded",
        "multipart/form-data",
        "text/plain",
        "application/xml",
        "application/octet-stream",
    )

    # Substring heuristics used only for logging; they never block a request.
    _SQL_PATTERNS = (
        "union",
        "select",
        "insert",
        "update",
        "delete",
        "drop",
        "exec",
        "script",
    )
    _XSS_PATTERNS = ("<script", "javascript:", "onload=", "onerror=")
    _SUSPICIOUS_USER_AGENTS = ("sqlmap", "nmap", "masscan", "dirbuster", "gobuster")

    def __init__(self, app, max_body_size: int = 10 * 1024 * 1024):  # 10MB default
        """Wrap ``app`` and remember the maximum accepted body size in bytes."""
        super().__init__(app)
        self.max_body_size = max_body_size

    async def dispatch(self, request: Request, call_next):
        """Validate the request, then forward it to the next handler.

        Returns the downstream response, or a JSON error response when one of
        the validation helpers raises ``HTTPException``. Any other exception
        is logged with its traceback and re-raised.
        """
        try:
            await self._validate_request_size(request)
            await self._validate_content_type(request)
            await self._log_suspicious_requests(request)
            return await call_next(request)
        except HTTPException as exc:
            # Middleware runs outside the layer that converts HTTPException
            # into a response, so re-raising here would surface as a 500.
            # Build the error response directly instead.
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=getattr(exc, "headers", None),
            )
        except Exception as exc:
            logger.exception("Request validation middleware error: %s", exc)
            raise

    async def _validate_request_size(self, request: Request) -> None:
        """Reject requests whose declared body size exceeds the limit.

        Only the ``Content-Length`` header is inspected; a request that does
        not send the header (or sends an unparsable one) is not limited here.

        Raises:
            HTTPException: 413 when the declared size is above
                ``self.max_body_size``.
        """
        if request.method not in self._BODY_METHODS:
            return

        content_length = request.headers.get("content-length")
        if not content_length:
            return

        try:
            size = int(content_length)
        except ValueError:
            # Invalid Content-Length header: leave it to the framework/server.
            return

        if size > self.max_body_size:
            raise HTTPException(
                status_code=413,
                detail=f"Request body too large. Maximum size: {self.max_body_size} bytes",
            )

    async def _has_body(self, request: Request) -> bool:
        """Return True when the request carries a non-empty body.

        A usable ``Content-Length`` header answers the question without
        buffering the payload; only when it is missing or invalid is the body
        actually read (which loads it fully into memory).
        """
        declared = request.headers.get("content-length")
        if declared:
            try:
                size = int(declared)
            except ValueError:
                size = -1
            if size >= 0:
                return size > 0
        return bool(await request.body())

    async def _validate_content_type(self, request: Request) -> None:
        """Validate the Content-Type of POST/PUT/PATCH requests.

        Raises:
            HTTPException: 400 when a body is sent without a Content-Type,
                415 when the Content-Type does not start with one of the
                allowed media types.
        """
        if request.method not in self._BODY_METHODS:
            return

        content_type = request.headers.get("content-type", "").lower()

        if not content_type:
            # A missing Content-Type is only an error if there is a body.
            if await self._has_body(request):
                raise HTTPException(
                    status_code=400,
                    detail="Content-Type header required for requests with body",
                )
            return

        if not content_type.startswith(self._ALLOWED_CONTENT_TYPES):
            raise HTTPException(
                status_code=415, detail=f"Unsupported content type: {content_type}"
            )

    async def _log_suspicious_requests(self, request: Request) -> None:
        """Log potentially suspicious requests for security monitoring.

        This never rejects a request; it only emits a warning listing which
        heuristics matched.
        """
        suspicious_indicators = []

        # str(query_params) is the URL-encoded form, in which characters such
        # as "<" and ":" appear percent-escaped. Match patterns against the
        # decoded name/value pairs so "<script" or "javascript:" can be found.
        raw_query = str(request.query_params)
        decoded_query = "&".join(
            f"{name}={value}" for name, value in request.query_params.multi_items()
        ).lower()

        if any(pattern in decoded_query for pattern in self._SQL_PATTERNS):
            suspicious_indicators.append("sql_injection_patterns")

        if any(pattern in decoded_query for pattern in self._XSS_PATTERNS):
            suspicious_indicators.append("xss_patterns")

        # Length is measured on the encoded query string, as sent on the wire.
        if len(raw_query) > 2000:
            suspicious_indicators.append("long_query_string")

        user_agent = request.headers.get("user-agent", "").lower()
        if any(ua in user_agent for ua in self._SUSPICIOUS_USER_AGENTS):
            suspicious_indicators.append("suspicious_user_agent")

        if suspicious_indicators:
            logger.warning(
                f"Suspicious request detected: {request.method} {request.url.path}",
                extra={
                    "client_ip": request.client.host if request.client else "unknown",
                    "user_agent": user_agent,
                    "indicators": suspicious_indicators,
                    "query_params_length": len(raw_query),
                },
            )
class InputValidationMiddleware(BaseHTTPMiddleware):
    """Middleware for input validation of the request path and query names.

    Rejects requests whose path contains a parent-directory segment or whose
    query parameter names contain characters from a small deny-list, and logs
    the presence of common proxy headers. Validation failures are answered
    directly with a JSON error response of the form ``{"detail": ...}``.
    """

    # Headers whose presence is logged (they are not modified or removed).
    _PROXY_HEADERS = ("x-forwarded-for", "x-real-ip", "x-client-ip")

    # Characters that are not accepted in query parameter names.
    _FORBIDDEN_PARAM_CHARS = ("<", ">", '"', "'", ";", "(", ")")

    async def dispatch(self, request: Request, call_next):
        """Validate the request, then forward it to the next handler.

        Returns the downstream response, or a JSON error response when a
        validation helper raises ``HTTPException``. Any other exception is
        logged with its traceback and re-raised.
        """
        try:
            await self._sanitize_headers(request)
            await self._validate_request_parameters(request)
            return await call_next(request)
        except HTTPException as exc:
            # Middleware runs outside the layer that converts HTTPException
            # into a response, so re-raising here would surface as a 500.
            # Build the error response directly instead.
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=getattr(exc, "headers", None),
            )
        except Exception as exc:
            logger.exception("Input validation middleware error: %s", exc)
            raise

    async def _sanitize_headers(self, request: Request) -> None:
        """Log which proxy-related headers are present on the request.

        Despite the name, no header is altered or removed; the method only
        emits an info-level log entry when any of the watched headers exist.
        """
        present = [name for name in self._PROXY_HEADERS if name in request.headers]
        if present:
            logger.info("Request with proxy headers: %s", present)

    async def _validate_request_parameters(self, request: Request) -> None:
        """Validate the request path and query parameter names.

        Raises:
            HTTPException: 400 when the path contains a ``..`` traversal
                sequence (including percent-encoded or backslash-separated
                forms), or when a query parameter name contains a forbidden
                character.
        """
        # Decode any remaining percent-escapes so sequences such as
        # "%2e%2e%2f" are caught, and treat backslashes as path separators.
        normalized_path = unquote(request.url.path).replace("\\", "/")

        # Appending "/" lets the same substring test also catch a path that
        # ends in ".." without a trailing separator.
        if "../" in f"{normalized_path}/":
            raise HTTPException(
                status_code=400, detail="Invalid path: path traversal detected"
            )

        # Validate query parameter names (no special characters that could cause issues)
        for param_name in request.query_params:
            if any(char in param_name for char in self._FORBIDDEN_PARAM_CHARS):
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid query parameter name: {param_name}",
                )