bookmyservice-mhs / app /middleware /security_middleware.py
MukeshKapoor25's picture
fix(security): Resolve critical security middleware and logging vulnerabilities
96e312e
"""
Security middleware for input validation, rate limiting, and request sanitization.
"""
import time
import json
import logging
from typing import Dict, Any, Optional
from collections import defaultdict, deque
from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.utils.input_sanitizer import InputSanitizer
# Module logger from the standard library; the middleware deliberately uses a
# plain stdlib logger to avoid circular dependencies with other app modules.
logger = logging.getLogger(__name__)
class SecurityMiddleware(BaseHTTPMiddleware):
    """
    Security middleware applied to every HTTP request.

    Provides:
    - Request size limiting based on the ``Content-Length`` header
    - Per-client rate limiting (see ``RateLimiter``)
    - Security response headers, also on the error responses generated here
    - Request logging (method, path, duration, status)

    Note: only the ``Content-Length`` header is checked, so bodies sent
    without it (e.g. chunked transfer encoding) are not size-limited here.
    """

    # Headers attached to every response returned by this middleware,
    # including the 400/413/429/500 responses it produces itself.
    SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "X-XSS-Protection": "1; mode=block",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
    }

    def __init__(self, app, max_request_size: int = 10 * 1024 * 1024,
                 *, trust_proxy_headers: bool = True):  # 10MB default
        """
        Args:
            app: The ASGI application to wrap.
            max_request_size: Largest accepted ``Content-Length`` in bytes.
            trust_proxy_headers: When True (default, the historical
                behaviour) the client IP used for rate limiting is read from
                ``X-Forwarded-For`` / ``X-Real-IP``. Those headers are
                client-controlled unless a trusted reverse proxy overwrites
                them; if clients can reach the app directly they can spoof
                them to evade rate limits, so pass False in that setup.
        """
        super().__init__(app)
        self.max_request_size = max_request_size
        self.trust_proxy_headers = trust_proxy_headers
        self.rate_limiter = RateLimiter()

    async def dispatch(self, request: Request, call_next):
        """Reject oversized/over-limit requests, then forward and decorate the response."""
        start_time = time.perf_counter()
        try:
            rejection = self._reject_if_needed(request)
            if rejection is not None:
                return self._apply_security_headers(rejection)
            response = await call_next(request)
        except Exception:
            # The traceback goes to the server log only; the client gets a
            # generic message so no internal details leak.
            logger.exception("Security middleware error occurred")
            return self._apply_security_headers(JSONResponse(
                status_code=500,
                content={"error": "Internal server error"}
            ))

        self._apply_security_headers(response)
        processing_time = time.perf_counter() - start_time
        logger.info("Request processed: %s %s in %.3fs with status %s",
                    request.method, request.url.path, processing_time,
                    response.status_code)
        return response

    def _reject_if_needed(self, request: Request) -> Optional[Response]:
        """Return an error response if the request must be rejected, else None."""
        raw_length = request.headers.get("content-length")
        if raw_length is not None:
            try:
                content_length = int(raw_length)
            except ValueError:
                content_length = -1
            # A non-numeric or negative Content-Length is a client error,
            # not an internal failure.
            if content_length < 0:
                logger.warning("Invalid Content-Length header")
                return JSONResponse(
                    status_code=400,
                    content={"error": "Invalid Content-Length header"}
                )
            if content_length > self.max_request_size:
                logger.warning("Request size too large")
                return JSONResponse(
                    status_code=413,
                    content={"error": "Request entity too large"}
                )

        client_ip = self._get_client_ip(request)
        if not self.rate_limiter.is_allowed(client_ip, request.url.path):
            logger.warning("Rate limit exceeded for client")
            return JSONResponse(
                status_code=429,
                content={"error": "Rate limit exceeded"}
            )
        return None

    def _apply_security_headers(self, response: Response) -> Response:
        """Set every header in ``SECURITY_HEADERS`` on ``response`` and return it."""
        for name, value in self.SECURITY_HEADERS.items():
            response.headers[name] = value
        return response

    def _get_client_ip(self, request: Request) -> str:
        """Return the client IP used as the rate-limiting key.

        With ``trust_proxy_headers`` enabled, the first non-empty
        ``X-Forwarded-For`` entry wins, then ``X-Real-IP``. Otherwise (or when
        neither yields a value) the socket peer address is used, falling back
        to ``"unknown"`` when the request has no client information.
        """
        if self.trust_proxy_headers:
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                first_hop = forwarded_for.split(",")[0].strip()
                if first_hop:
                    return first_hop
            real_ip = request.headers.get("X-Real-IP")
            if real_ip:
                return real_ip
        return request.client.host if request.client else "unknown"
class RateLimiter:
    """
    Simple in-memory rate limiter with a sliding window.

    Requests are counted separately for each (client IP, endpoint pattern)
    pair, so every pattern in ``limits`` is enforced against its own budget
    instead of all endpoints sharing one counter per client. Buckets whose
    requests have all aged out are purged at most once per window so the
    table does not grow without bound.

    State lives only in this object (per process). In production, use Redis
    or a similar distributed cache.
    """

    # Requests per window for each endpoint pattern. A pattern applies when it
    # occurs as a substring of the request path; the first match in insertion
    # order wins. "default" covers every other path.
    DEFAULT_LIMITS = {
        "/api/v1/merchants": 100,
        "/api/v1/helpers": 200,
        "/api/v1/nlp": 50,
        "default": 60,
    }
    # Used when a custom ``limits`` mapping has no "default" entry.
    FALLBACK_LIMIT = 60

    def __init__(self, limits: Optional[Dict[str, int]] = None,
                 window_size: float = 60):
        """
        Args:
            limits: Mapping of path pattern -> max requests per window.
                Defaults to a copy of ``DEFAULT_LIMITS``.
            window_size: Sliding-window length in seconds (default 60, i.e.
                limits are requests per minute).
        """
        # (client_ip, pattern) -> monotonic timestamps of accepted requests
        self.requests = defaultdict(deque)
        self.limits = dict(self.DEFAULT_LIMITS if limits is None else limits)
        self.window_size = window_size
        self._last_sweep = time.monotonic()

    def is_allowed(self, client_ip: str, path: str) -> bool:
        """Record the request and return True if its bucket is under the limit.

        Returns False, without recording the request, once the limit for the
        (client, pattern) bucket has been reached within the window.
        """
        # Monotonic clock: wall-clock adjustments cannot shrink or extend windows.
        now = time.monotonic()
        cutoff = now - self.window_size
        self._purge_stale_buckets(now, cutoff)

        pattern = self._match_pattern(path)
        limit = self.limits.get(pattern, self.FALLBACK_LIMIT)

        timestamps = self.requests[(client_ip, pattern)]
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()
        if len(timestamps) >= limit:
            return False
        timestamps.append(now)
        return True

    def _get_limit_for_path(self, path: str) -> int:
        """Get the rate limit that applies to ``path``."""
        return self.limits.get(self._match_pattern(path), self.FALLBACK_LIMIT)

    def _match_pattern(self, path: str) -> str:
        """Return the first configured pattern contained in ``path``, else ``"default"``."""
        for pattern in self.limits:
            if pattern != "default" and pattern in path:
                return pattern
        return "default"

    def _purge_stale_buckets(self, now: float, cutoff: float) -> None:
        """Drop buckets with no requests inside the window (runs at most once per window)."""
        if now - self._last_sweep < self.window_size:
            return
        self._last_sweep = now
        stale = [key for key, stamps in self.requests.items()
                 if not stamps or stamps[-1] < cutoff]
        for key in stale:
            del self.requests[key]
class RequestValidator:
    """Validates common request patterns and parameters via InputSanitizer."""

    @staticmethod
    def validate_pagination(limit: Optional[int], offset: Optional[int]) -> tuple:
        """Sanitize pagination parameters independently.

        Each non-None value is passed through
        ``InputSanitizer.sanitize_pagination`` together with a placeholder for
        the other argument (offset 0 when checking ``limit``, limit 10 when
        checking ``offset``). None values are returned unchanged.

        Returns:
            ``(limit, offset)``.
        """
        if limit is not None:
            limit = InputSanitizer.sanitize_pagination(limit, 0)[0]
        if offset is not None:
            offset = InputSanitizer.sanitize_pagination(10, offset)[1]
        return limit, offset

    @staticmethod
    def validate_search_params(params: Dict[str, Any]) -> Dict[str, Any]:
        """Sanitize a dict of search parameters.

        - ``location_id`` / ``merchant_id``: dedicated InputSanitizer checks.
        - ``latitude`` / ``longitude``: sanitized together as a pair.
        - ``limit`` / ``offset``: sanitized together; a missing or None value
          falls back to 10 / 0 respectively.
        - other strings: ``InputSanitizer.sanitize_string``; any other value
          is copied through unchanged.

        A key whose own value is None is not processed, although its pair
        partner may still fill it in.

        Raises:
            HTTPException: 400 when a sanitizer raises ``ValueError``.
        """
        validated: Dict[str, Any] = {}
        for key, value in params.items():
            # Skip missing values, and the second member of a pair
            # (latitude/longitude, limit/offset) already handled with the first.
            if value is None or key in validated:
                continue
            try:
                if key == "location_id":
                    validated[key] = InputSanitizer.sanitize_location_id(value)
                elif key == "merchant_id":
                    validated[key] = InputSanitizer.sanitize_merchant_id(value)
                elif key in ("latitude", "longitude"):
                    lat, lng = InputSanitizer.sanitize_coordinates(
                        params.get("latitude"), params.get("longitude"))
                    validated["latitude"] = lat
                    validated["longitude"] = lng
                elif key in ("limit", "offset"):
                    # .get(key, default) would return None for a key that is
                    # present with a None value, so default explicitly.
                    limit = params.get("limit")
                    offset = params.get("offset")
                    limit, offset = InputSanitizer.sanitize_pagination(
                        10 if limit is None else limit,
                        0 if offset is None else offset)
                    validated["limit"] = limit
                    validated["offset"] = offset
                elif isinstance(value, str):
                    validated[key] = InputSanitizer.sanitize_string(value)
                else:
                    validated[key] = value
            except ValueError as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid parameter {key}: {str(e)}"
                ) from e
        return validated
class CSRFProtection:
    """Basic CSRF protection for state-changing operations.

    By default a request using a protected method must carry a non-blank
    ``X-CSRF-Token`` header; the token value itself is not verified.
    When ``cookie_name`` is given, the header must also match the cookie of
    that name (double-submit cookie pattern), compared in constant time.
    """

    def __init__(self, protected_methods=None, cookie_name: Optional[str] = None):
        """
        Args:
            protected_methods: Iterable of HTTP method names that require a
                token (case-insensitive). Defaults to POST, PUT, DELETE, PATCH.
            cookie_name: Optional cookie holding the expected token. When
                None, only the header's presence is enforced.
        """
        if protected_methods is None:
            protected_methods = ("POST", "PUT", "DELETE", "PATCH")
        self.protected_methods = {method.upper() for method in protected_methods}
        self.cookie_name = cookie_name

    def validate_request(self, request: "Request") -> bool:
        """Return True if ``request`` passes the CSRF check.

        Requests whose method is not protected always pass.
        """
        if request.method not in self.protected_methods:
            return True

        csrf_token = request.headers.get("X-CSRF-Token")
        if not csrf_token:
            return False
        token = csrf_token.strip()
        if not token:
            return False

        if self.cookie_name is None:
            return True

        cookie_token = request.cookies.get(self.cookie_name)
        if not cookie_token:
            return False
        import hmac
        # Constant-time comparison avoids leaking the token through timing.
        return hmac.compare_digest(token.encode("utf-8"),
                                   cookie_token.strip().encode("utf-8"))
def create_security_middleware(app, **kwargs):
    """Build a ``SecurityMiddleware`` wrapping ``app``.

    Every keyword argument is forwarded unchanged to the middleware's
    constructor.
    """
    middleware = SecurityMiddleware(app, **kwargs)
    return middleware
# Utility decorators for endpoint protection
def require_valid_input(validation_func):
    """Decorator factory that validates an async endpoint's keyword arguments.

    ``validation_func`` receives the call's keyword arguments as a dict and
    must return the (possibly transformed) dict the endpoint is called with.
    A ``ValueError`` raised by ``validation_func`` becomes an HTTP 400.

    Only the validation step is guarded: exceptions raised by the endpoint
    itself propagate unchanged instead of being reported as client errors.
    Positional arguments are passed through without validation. The wrapper
    preserves the endpoint's metadata and signature (via ``functools.wraps``)
    so frameworks that inspect the signature still see the real parameters.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                validated_kwargs = validation_func(kwargs)
            except ValueError as e:
                raise HTTPException(status_code=400, detail=str(e)) from e
            return await func(*args, **validated_kwargs)
        return wrapper
    return decorator
def rate_limit(requests_per_minute: int = 60):
    """Return a decorator reserved for endpoint-specific rate limiting.

    NOTE: this is currently a placeholder. The returned decorator hands back
    the decorated function unchanged and ``requests_per_minute`` is not used,
    so applying it does NOT rate-limit anything yet. Integrating it with
    ``RateLimiter`` depends on the application's needs.
    """
    def _passthrough(endpoint):
        # No wrapping yet; see the NOTE above.
        return endpoint

    return _passthrough