Spaces:
Sleeping
Sleeping
| """ | |
| Security middleware for input validation, rate limiting, and request sanitization. | |
| """ | |
| import time | |
| import json | |
| import logging | |
| from typing import Dict, Any, Optional | |
| from collections import defaultdict, deque | |
| from fastapi import Request, Response, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from app.utils.input_sanitizer import InputSanitizer | |
# Use the standard library logger for the middleware to avoid circular
# dependencies. The logger is named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class SecurityMiddleware(BaseHTTPMiddleware):
    """
    Comprehensive security middleware that provides:
    - Request size limiting (based on the declared Content-Length header)
    - Rate limiting (per client IP, delegated to ``RateLimiter``)
    - Request logging
    - Security headers

    Responses generated by the middleware itself (400, 413, 429, 500) carry
    the same security headers as responses produced by the application.
    """

    # Headers attached to every response that leaves this middleware.
    _SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "X-XSS-Protection": "1; mode=block",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
    }

    def __init__(self, app, max_request_size: int = 10 * 1024 * 1024):  # 10MB default
        """
        Args:
            app: The ASGI application being wrapped.
            max_request_size: Largest accepted Content-Length, in bytes.
        """
        super().__init__(app)
        self.max_request_size = max_request_size
        self.rate_limiter = RateLimiter()

    async def dispatch(self, request: Request, call_next):
        """
        Run the security checks, forward the request, and decorate the response.

        Returns a 400 for a malformed Content-Length, a 413 for an oversized
        request, a 429 when the rate limiter rejects the client, and a generic
        500 if anything unexpected is raised while handling the request.
        """
        start_time = time.time()
        try:
            # Check request size
            rejection = self._check_request_size(request)
            if rejection is not None:
                return rejection

            # Rate limiting
            client_ip = self._get_client_ip(request)
            if not self.rate_limiter.is_allowed(client_ip, request.url.path):
                logger.warning("Rate limit exceeded for client")
                return self._error_response(429, "Rate limit exceeded")

            # Process request
            response = await call_next(request)
            self._apply_security_headers(response)

            # Log request safely (basic logging to avoid circular dependencies)
            processing_time = time.time() - start_time
            logger.info(
                "Request processed: %s %s in %.3fs with status %s",
                request.method,
                request.url.path,
                processing_time,
                response.status_code,
            )
            return response
        except Exception:
            # Top-level boundary: never leak internals to the client, but keep
            # the traceback in the log so the failure can be diagnosed.
            logger.exception("Security middleware error occurred")
            return self._error_response(500, "Internal server error")

    def _check_request_size(self, request: Request):
        """Return an error response if Content-Length is invalid or too large, else None."""
        raw_length = request.headers.get("content-length")
        if raw_length is None:
            return None
        try:
            content_length = int(raw_length)
        except ValueError:
            content_length = -1
        if content_length < 0:
            # A malformed header is a client error, not a server failure.
            logger.warning("Invalid Content-Length header")
            return self._error_response(400, "Invalid Content-Length header")
        if content_length > self.max_request_size:
            logger.warning("Request size too large")
            return self._error_response(413, "Request entity too large")
        return None

    def _error_response(self, status_code: int, message: str):
        """Build a JSON error response that also carries the security headers."""
        response = JSONResponse(status_code=status_code, content={"error": message})
        self._apply_security_headers(response)
        return response

    def _apply_security_headers(self, response) -> None:
        """Set (overwriting if present) the standard security headers on a response."""
        for header, value in self._SECURITY_HEADERS.items():
            response.headers[header] = value

    def _get_client_ip(self, request: Request) -> str:
        """Extract client IP address from request"""
        # NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled
        # unless a trusted reverse proxy overwrites them. A caller can spoof
        # them to evade per-IP rate limiting; confirm the deployment always
        # sits behind such a proxy.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            # The left-most entry is the originating client.
            return forwarded_for.split(",")[0].strip()
        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip
        # Fallback to client host
        return request.client.host if request.client else "unknown"
class RateLimiter:
    """
    Simple in-memory rate limiter with sliding window.
    In production, use Redis or similar distributed cache.

    Each client gets an independent request history per limit bucket (the
    matching key of ``limits``), so traffic to a high-limit endpoint does not
    consume the budget of a low-limit one. Histories of clients that have
    been idle for a full window are dropped periodically to bound memory.
    """

    def __init__(self):
        # (client_ip, bucket) -> timestamps of accepted requests, oldest first
        self.requests = defaultdict(deque)
        self.limits = {
            # requests per minute for different endpoint patterns
            "/api/v1/merchants": 100,
            "/api/v1/helpers": 200,
            "/api/v1/nlp": 50,
            "default": 60
        }
        self.window_size = 60  # 1 minute window
        # Time of the last idle-history sweep (monotonic clock).
        self._last_sweep = time.monotonic()

    def is_allowed(self, client_ip: str, path: str) -> bool:
        """
        Check if request is allowed based on rate limits.

        Args:
            client_ip: Identifier of the caller.
            path: Request path, used to select the limit bucket.

        Returns:
            True and records the request if the client is under the limit for
            the bucket; False (request not recorded) otherwise.
        """
        # Monotonic time is immune to wall-clock adjustments.
        current_time = time.monotonic()

        # Drop idle histories at most once per window.
        if current_time - self._last_sweep >= self.window_size:
            self._evict_idle(current_time)

        # Determine bucket and rate limit for this path
        bucket = self._get_bucket_for_path(path)
        limit = self.limits[bucket]

        # Clean old requests outside the window
        client_requests = self.requests[(client_ip, bucket)]
        cutoff = current_time - self.window_size
        while client_requests and client_requests[0] < cutoff:
            client_requests.popleft()

        # Check if limit exceeded
        if len(client_requests) >= limit:
            return False

        # Add current request
        client_requests.append(current_time)
        return True

    def _evict_idle(self, current_time: float) -> None:
        """Remove histories that are empty or whose newest entry left the window."""
        cutoff = current_time - self.window_size
        idle_keys = [
            key for key, history in self.requests.items()
            if not history or history[-1] < cutoff
        ]
        for key in idle_keys:
            del self.requests[key]
        self._last_sweep = current_time

    def _get_bucket_for_path(self, path: str) -> str:
        """Return the first ``limits`` key contained in *path*, or ``"default"``."""
        for pattern in self.limits:
            if pattern != "default" and pattern in path:
                return pattern
        return "default"

    def _get_limit_for_path(self, path: str) -> int:
        """Get rate limit for specific path"""
        return self.limits[self._get_bucket_for_path(path)]
class RequestValidator:
    """Validates common request patterns and parameters"""

    @staticmethod
    def validate_pagination(limit: Optional[int], offset: Optional[int]) -> tuple:
        """
        Validate pagination parameters.

        Each value that is not None is passed through
        ``InputSanitizer.sanitize_pagination``; None values are returned
        unchanged.

        Returns:
            A ``(limit, offset)`` tuple.
        """
        if limit is not None:
            # The offset argument is a placeholder; only the limit is used.
            limit = InputSanitizer.sanitize_pagination(limit, 0)[0]
        if offset is not None:
            # The limit argument is a placeholder; only the offset is used.
            offset = InputSanitizer.sanitize_pagination(10, offset)[1]
        return limit, offset

    @staticmethod
    def validate_search_params(params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate search parameters.

        Entries whose value is None are dropped. Known keys are routed to the
        matching ``InputSanitizer`` helper, other strings go through
        ``InputSanitizer.sanitize_string``, and remaining values are copied
        unchanged.

        Raises:
            HTTPException: 400 if a sanitizer raises ``ValueError``.
        """
        validated = {}
        for key, value in params.items():
            if value is None:
                continue
            try:
                if key == "location_id":
                    validated[key] = InputSanitizer.sanitize_location_id(value)
                elif key == "merchant_id":
                    validated[key] = InputSanitizer.sanitize_merchant_id(value)
                elif key in ["latitude", "longitude"]:
                    # Coordinates are sanitized as a pair, whichever key is seen.
                    lat = params.get("latitude")
                    lng = params.get("longitude")
                    lat, lng = InputSanitizer.sanitize_coordinates(lat, lng)
                    validated["latitude"] = lat
                    validated["longitude"] = lng
                elif key in ["limit", "offset"]:
                    # Pagination is sanitized as a pair, with defaults for a missing half.
                    limit = params.get("limit", 10)
                    offset = params.get("offset", 0)
                    limit, offset = InputSanitizer.sanitize_pagination(
                        limit, offset)
                    validated["limit"] = limit
                    validated["offset"] = offset
                elif isinstance(value, str):
                    validated[key] = InputSanitizer.sanitize_string(value)
                else:
                    validated[key] = value
            except ValueError as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid parameter {key}: {str(e)}"
                ) from e
        return validated
class CSRFProtection:
    """Basic CSRF protection for state-changing operations"""

    def __init__(self):
        # HTTP methods that require a CSRF token.
        self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}

    def validate_request(self, request: Request) -> bool:
        """
        Validate CSRF token for protected methods.

        Requests using a method outside ``protected_methods`` always pass.
        Otherwise the ``X-CSRF-Token`` header must be present and contain at
        least one non-whitespace character.
        """
        needs_token = request.method in self.protected_methods
        if not needs_token:
            return True

        token = request.headers.get("X-CSRF-Token")
        # In production, validate against stored token.
        # For now, only presence of a non-blank token is checked.
        return bool(token and token.strip())
def create_security_middleware(app, **kwargs):
    """
    Factory function to create security middleware with configuration.

    All keyword arguments are forwarded unchanged to ``SecurityMiddleware``.
    """
    middleware = SecurityMiddleware(app, **kwargs)
    return middleware
| # Utility decorators for endpoint protection | |
def require_valid_input(validation_func):
    """
    Decorator to validate input parameters.

    Args:
        validation_func: Callable that receives the endpoint's keyword
            arguments as a dict and returns the dict to call the endpoint with.

    The wrapped coroutine keeps the endpoint's name, docstring and
    ``__wrapped__`` reference. A ``ValueError`` raised during validation or by
    the endpoint is converted to an ``HTTPException`` with status 400.
    """
    from functools import wraps  # local import keeps this block self-contained

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                validated_kwargs = validation_func(kwargs)
                return await func(*args, **validated_kwargs)
            except ValueError as e:
                raise HTTPException(status_code=400, detail=str(e)) from e
        return wrapper
    return decorator
def rate_limit(requests_per_minute: int = 60):
    """
    Decorator for endpoint-specific rate limiting.

    Currently a placeholder: ``requests_per_minute`` is accepted but not
    used, and the decorated function is returned unchanged.
    """
    def passthrough(endpoint):
        # This would integrate with the rate limiter.
        # Implementation depends on your specific needs.
        return endpoint

    return passthrough