"""
Security middleware for input validation, rate limiting, and request sanitization.
"""

import functools
import json
import logging
import time
from collections import defaultdict, deque
from typing import Any, Callable, Dict, Optional, Tuple

from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app.utils.input_sanitizer import InputSanitizer

# Use standard logger for middleware to avoid circular dependencies
logger = logging.getLogger(__name__)


class SecurityMiddleware(BaseHTTPMiddleware):
    """
    Security middleware that provides:
    - Request size limiting (based on the declared Content-Length header)
    - Rate limiting (per client IP and endpoint bucket)
    - Request logging
    - Security headers on every downstream response

    Args:
        app: The ASGI application being wrapped.
        max_request_size: Maximum accepted Content-Length in bytes (10MB default).
        trust_forwarded_headers: When True (default, the historical behavior),
            the client IP is read from X-Forwarded-For / X-Real-IP. Those headers
            are client-controlled unless a trusted reverse proxy overwrites them;
            pass False for directly exposed deployments so clients cannot evade
            rate limiting by spoofing them.
    """

    def __init__(self, app, max_request_size: int = 10 * 1024 * 1024,
                 trust_forwarded_headers: bool = True):
        super().__init__(app)
        self.max_request_size = max_request_size
        self.trust_forwarded_headers = trust_forwarded_headers
        self.rate_limiter = RateLimiter()

    async def dispatch(self, request: Request, call_next):
        """Apply size and rate checks, call the app, then add security headers.

        Returns a 400/413/429 JSON response when a check fails, and a generic
        500 JSON response (with the traceback logged) on unexpected errors.
        """
        start_time = time.perf_counter()

        try:
            size_error = self._check_content_length(request)
            if size_error is not None:
                return size_error

            client_ip = self._get_client_ip(request)
            if not self.rate_limiter.is_allowed(client_ip, request.url.path):
                logger.warning("Rate limit exceeded for client")
                return JSONResponse(
                    status_code=429,
                    content={"error": "Rate limit exceeded"}
                )

            response = await call_next(request)

            response.headers["X-Content-Type-Options"] = "nosniff"
            response.headers["X-Frame-Options"] = "DENY"
            response.headers["X-XSS-Protection"] = "1; mode=block"
            response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

            processing_time = time.perf_counter() - start_time
            # Lazy %-style args: formatting is skipped when INFO is disabled.
            logger.info(
                "Request processed: %s %s in %.3fs with status %s",
                request.method, request.url.path,
                processing_time, response.status_code,
            )
            return response

        except Exception:
            # Top-level boundary: never leak internals to the client, but keep
            # the traceback in the server log so failures can be diagnosed.
            logger.exception("Security middleware error occurred")
            return JSONResponse(
                status_code=500,
                content={"error": "Internal server error"}
            )

    def _check_content_length(self, request: Request) -> Optional[JSONResponse]:
        """Return an error response if the declared Content-Length is invalid or too large.

        Only the declared header is checked; bodies sent without a
        Content-Length (e.g. chunked transfer) are not measured here.
        """
        raw_length = request.headers.get("content-length")
        if raw_length is None:
            return None

        try:
            content_length = int(raw_length)
        except ValueError:
            content_length = -1  # treat unparsable values like negative ones

        if content_length < 0:
            # A malformed header is a client error, not a server error.
            logger.warning("Invalid Content-Length header")
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid Content-Length header"}
            )

        if content_length > self.max_request_size:
            logger.warning("Request size too large")
            return JSONResponse(
                status_code=413,
                content={"error": "Request entity too large"}
            )

        return None

    def _get_client_ip(self, request: Request) -> str:
        """Extract the client IP address used as the rate-limiting key.

        When forwarded headers are trusted, the first non-empty hop of
        X-Forwarded-For wins, then X-Real-IP; otherwise (or if both are
        absent/empty) the socket peer address is used, or "unknown".
        """
        if self.trust_forwarded_headers:
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                first_hop = forwarded_for.split(",")[0].strip()
                # An empty first hop would make every such client share one bucket.
                if first_hop:
                    return first_hop

            real_ip = request.headers.get("X-Real-IP")
            if real_ip and real_ip.strip():
                return real_ip.strip()

        return request.client.host if request.client else "unknown"


class RateLimiter:
    """
    Simple in-memory sliding-window rate limiter.

    Requests are counted separately for each (client, endpoint bucket) pair, so
    traffic to one endpoint family does not consume another family's quota.
    State is per-process; in production, use Redis or a similar shared store.

    Args:
        window_size: Length of the sliding window in seconds (60 by default).
        clock: Monotonic time source in seconds; injectable for testing.
    """

    def __init__(self, window_size: float = 60,
                 clock: Callable[[], float] = time.monotonic):
        # Keyed by (client_ip, bucket) -> timestamps of accepted requests.
        self.requests: Dict[Tuple[str, str], deque] = defaultdict(deque)
        self.limits = {
            # requests per window for different endpoint patterns
            "/api/v1/merchants": 100,
            "/api/v1/helpers": 200,
            "/api/v1/nlp": 50,
            "default": 60
        }
        self.window_size = window_size
        self._clock = clock
        self._last_sweep = clock()

    def is_allowed(self, client_ip: str, path: str) -> bool:
        """Record and allow the request, or return False if the bucket's limit is reached."""
        current_time = self._clock()
        self._maybe_sweep(current_time)

        bucket = self._get_bucket_for_path(path)
        limit = self.limits[bucket]

        timestamps = self.requests[(client_ip, bucket)]
        cutoff = current_time - self.window_size
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()

        if len(timestamps) >= limit:
            return False

        timestamps.append(current_time)
        return True

    def _get_bucket_for_path(self, path: str) -> str:
        """Return the first configured pattern contained in *path*, else "default"."""
        for pattern in self.limits:
            if pattern != "default" and pattern in path:
                return pattern
        return "default"

    def _get_limit_for_path(self, path: str) -> int:
        """Get rate limit for specific path"""
        return self.limits[self._get_bucket_for_path(path)]

    def _maybe_sweep(self, current_time: float) -> None:
        """Evict entries whose requests have all left the window.

        Runs at most once per window, so memory stays bounded by recently
        active clients without scanning the whole table on every request.
        """
        if current_time - self._last_sweep < self.window_size:
            return
        self._last_sweep = current_time
        cutoff = current_time - self.window_size
        stale_keys = [
            key for key, timestamps in self.requests.items()
            if not timestamps or timestamps[-1] < cutoff
        ]
        for key in stale_keys:
            del self.requests[key]


class RequestValidator:
    """Validates common request patterns and parameters"""

    @staticmethod
    def validate_pagination(limit: Optional[int], offset: Optional[int]) -> tuple:
        """Validate pagination parameters independently; None values pass through.

        NOTE(review): assumes InputSanitizer.sanitize_pagination(limit, offset)
        returns a (limit, offset) tuple; each value is sanitized alongside a
        placeholder for the other one.
        """
        if limit is not None:
            limit = InputSanitizer.sanitize_pagination(limit, 0)[0]
        if offset is not None:
            offset = InputSanitizer.sanitize_pagination(10, offset)[1]
        return limit, offset

    @staticmethod
    def validate_search_params(params: Dict[str, Any]) -> Dict[str, Any]:
        """Sanitize search parameters, dropping keys whose value is None.

        latitude/longitude and limit/offset are sanitized as pairs; a missing
        or None limit defaults to 10 and a missing or None offset to 0.

        Raises:
            HTTPException: 400 if the sanitizer rejects a value (ValueError).
        """
        validated = {}

        for key, value in params.items():
            if value is None:
                continue
            if key in validated:
                # Already filled in while processing the other half of a pair.
                continue

            try:
                if key == "location_id":
                    validated[key] = InputSanitizer.sanitize_location_id(value)
                elif key == "merchant_id":
                    validated[key] = InputSanitizer.sanitize_merchant_id(value)
                elif key in ("latitude", "longitude"):
                    lat, lng = InputSanitizer.sanitize_coordinates(
                        params.get("latitude"), params.get("longitude"))
                    validated["latitude"] = lat
                    validated["longitude"] = lng
                elif key in ("limit", "offset"):
                    raw_limit = params.get("limit")
                    raw_offset = params.get("offset")
                    # params.get(k, default) would return None for a key present
                    # with a None value, so apply the defaults explicitly.
                    limit, offset = InputSanitizer.sanitize_pagination(
                        10 if raw_limit is None else raw_limit,
                        0 if raw_offset is None else raw_offset)
                    validated["limit"] = limit
                    validated["offset"] = offset
                elif isinstance(value, str):
                    validated[key] = InputSanitizer.sanitize_string(value)
                else:
                    validated[key] = value
            except ValueError as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid parameter {key}: {str(e)}"
                ) from e

        return validated


class CSRFProtection:
    """Basic CSRF check for state-changing operations.

    WARNING: this only verifies that a non-blank X-CSRF-Token header is present
    on POST/PUT/DELETE/PATCH requests; it does not validate the token's value
    and is not real CSRF protection on its own.
    """

    def __init__(self):
        self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}

    def validate_request(self, request: Request) -> bool:
        """Return True for safe methods, or when a non-blank CSRF header is present."""
        if request.method not in self.protected_methods:
            return True

        csrf_token = request.headers.get("X-CSRF-Token")
        if not csrf_token:
            return False

        # In production, validate against stored token
        # For now, just check that token exists and is not empty
        return len(csrf_token.strip()) > 0


def create_security_middleware(app, **kwargs):
    """Factory function to create security middleware with configuration"""
    return SecurityMiddleware(app, **kwargs)


# Utility decorators for endpoint protection
def require_valid_input(validation_func):
    """Decorator that passes an async endpoint's kwargs through *validation_func*.

    A ValueError from validation (or the endpoint) becomes HTTP 400.
    functools.wraps preserves the endpoint's name and signature, which FastAPI
    inspects to build request parameters and route names.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                validated_kwargs = validation_func(kwargs)
                return await func(*args, **validated_kwargs)
            except ValueError as e:
                raise HTTPException(status_code=400, detail=str(e)) from e
        return wrapper
    return decorator


def rate_limit(requests_per_minute: int = 60):
    """Placeholder decorator for endpoint-specific rate limiting.

    NOTE: currently a no-op — the decorated function is returned unchanged and
    *requests_per_minute* is ignored. Rate limiting is only enforced by
    SecurityMiddleware's RateLimiter.
    """
    def decorator(func):
        # This would integrate with the rate limiter
        # Implementation depends on your specific needs
        return func
    return decorator