Spaces:
Running
Running
File size: 8,633 Bytes
a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e a7c2198 96e312e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 | """
Security middleware for input validation, rate limiting, and request sanitization.
"""
import functools
import json
import logging
import time
from collections import defaultdict, deque
from typing import Dict, Any, Optional

from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app.utils.input_sanitizer import InputSanitizer
# Use standard logger for middleware to avoid circular dependencies
logger = logging.getLogger(__name__)
class SecurityMiddleware(BaseHTTPMiddleware):
    """
    Comprehensive security middleware that provides:
    - Request size limiting (based on the declared Content-Length header)
    - Rate limiting (per client IP, via RateLimiter)
    - Security headers on responses produced by the wrapped app
    - Request logging

    Any exception raised while handling the request, including one raised
    further down the stack, is logged with its traceback and turned into a
    generic 500 JSON response so no internal detail reaches the client.
    """

    def __init__(self, app, max_request_size: int = 10 * 1024 * 1024,
                 trust_forwarded_headers: bool = True):
        """
        Args:
            app: The ASGI application being wrapped.
            max_request_size: Largest accepted Content-Length in bytes
                (10 MB by default).
            trust_forwarded_headers: Take the client IP from the
                X-Forwarded-For / X-Real-IP headers. Those headers are
                supplied by the client unless a trusted reverse proxy
                overwrites them, so a client can rotate them to dodge per-IP
                rate limiting. Pass False when not behind such a proxy.
        """
        super().__init__(app)
        self.max_request_size = max_request_size
        self.trust_forwarded_headers = trust_forwarded_headers
        self.rate_limiter = RateLimiter()

    async def dispatch(self, request: Request, call_next):
        """Run size and rate checks, call the app, then add security headers."""
        start_time = time.perf_counter()
        try:
            # Reject oversized requests using the declared size only.
            # NOTE(review): bodies sent without Content-Length (e.g. chunked
            # transfer encoding) are not size-checked here.
            declared_size = request.headers.get("content-length")
            if declared_size is not None:
                content_length = self._parse_content_length(declared_size)
                if content_length is None:
                    # A malformed header is a client error, not a server one.
                    logger.warning("Invalid Content-Length header")
                    return JSONResponse(
                        status_code=400,
                        content={"error": "Invalid Content-Length header"}
                    )
                if content_length > self.max_request_size:
                    logger.warning("Request size too large")
                    return JSONResponse(
                        status_code=413,
                        content={"error": "Request entity too large"}
                    )

            # Rate limiting
            client_ip = self._get_client_ip(request)
            if not self.rate_limiter.is_allowed(client_ip, request.url.path):
                logger.warning("Rate limit exceeded for client")
                return JSONResponse(
                    status_code=429,
                    content={"error": "Rate limit exceeded"}
                )

            # Process request
            response = await call_next(request)

            # Add security headers
            response.headers["X-Content-Type-Options"] = "nosniff"
            response.headers["X-Frame-Options"] = "DENY"
            response.headers["X-XSS-Protection"] = "1; mode=block"
            response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

            # Log request safely (basic logging to avoid circular dependencies)
            processing_time = time.perf_counter() - start_time
            logger.info(
                "Request processed: %s %s in %.3fs with status %s",
                request.method, request.url.path, processing_time,
                response.status_code,
            )
            return response
        except Exception:
            # Standard logging only, to avoid circular dependencies;
            # logger.exception records the traceback so failures can be
            # diagnosed while the client still gets a generic message.
            logger.exception("Security middleware error occurred")
            return JSONResponse(
                status_code=500,
                content={"error": "Internal server error"}
            )

    @staticmethod
    def _parse_content_length(value: str) -> Optional[int]:
        """Return *value* as a non-negative int, or None if it is not one."""
        try:
            length = int(value)
        except ValueError:
            return None
        return length if length >= 0 else None

    def _get_client_ip(self, request: Request) -> str:
        """
        Return the client IP used as the rate-limiting key.

        When ``trust_forwarded_headers`` is enabled, the first X-Forwarded-For
        entry is used, then X-Real-IP. Otherwise, or when those headers are
        absent or blank, the socket peer address is used, or "unknown" when
        the request has no client information.
        """
        if self.trust_forwarded_headers:
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                first_hop = forwarded_for.split(",")[0].strip()
                if first_hop:
                    return first_hop
            real_ip = request.headers.get("X-Real-IP", "").strip()
            if real_ip:
                return real_ip
        # Fallback to client host
        return request.client.host if request.client else "unknown"
class RateLimiter:
    """
    Simple in-memory rate limiter with a sliding window.

    Accepted-request timestamps are tracked per ``(client_ip, pattern)``
    bucket. Traffic to one endpoint family therefore never consumes the
    budget of another, and each bucket is compared only against its own
    limit. Buckets idle for a full window are purged periodically, so memory
    stays proportional to recently active clients.

    State lives in this process only. In production, use Redis or a similar
    distributed cache.
    """

    # Requests allowed per window for different endpoint patterns. A pattern
    # matches when it occurs anywhere in the request path (checked in this
    # order); "default" applies when nothing else matches.
    DEFAULT_LIMITS: Dict[str, int] = {
        "/api/v1/merchants": 100,
        "/api/v1/helpers": 200,
        "/api/v1/nlp": 50,
        "default": 60,
    }

    def __init__(self, limits: Optional[Dict[str, int]] = None,
                 window_size: float = 60, clock=time.monotonic):
        """
        Args:
            limits: Mapping of path pattern -> max requests per window.
                Defaults to ``DEFAULT_LIMITS``. A "default" entry is added
                from ``DEFAULT_LIMITS`` if missing.
            window_size: Sliding-window length in seconds (1 minute default).
            clock: Zero-argument callable returning the current time in
                seconds. Defaults to ``time.monotonic`` so wall-clock changes
                cannot shrink or stretch the window.
        """
        self.limits = dict(self.DEFAULT_LIMITS if limits is None else limits)
        self.limits.setdefault("default", self.DEFAULT_LIMITS["default"])
        self.window_size = window_size
        self._clock = clock
        # (client_ip, pattern) -> timestamps of accepted requests, oldest first.
        self.requests = defaultdict(deque)
        self._last_sweep = clock()

    def is_allowed(self, client_ip: str, path: str) -> bool:
        """
        Check whether a request is allowed, recording it if so.

        Returns True, and records the request, when the client's bucket for
        *path* holds fewer than the applicable limit of requests within the
        last ``window_size`` seconds. Otherwise returns False without
        recording anything.
        """
        now = self._clock()
        cutoff = now - self.window_size
        self._sweep_idle_buckets(now, cutoff)

        pattern = self._match_pattern(path)
        limit = self.limits[pattern]
        timestamps = self.requests[(client_ip, pattern)]

        # Drop requests that have slid out of the window.
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()

        if len(timestamps) >= limit:
            return False
        timestamps.append(now)
        return True

    def _match_pattern(self, path: str) -> str:
        """Return the first configured pattern contained in *path*, else "default"."""
        for pattern in self.limits:
            if pattern != "default" and pattern in path:
                return pattern
        return "default"

    def _get_limit_for_path(self, path: str) -> int:
        """Get rate limit for specific path"""
        return self.limits[self._match_pattern(path)]

    def _sweep_idle_buckets(self, now: float, cutoff: float) -> None:
        """At most once per window, delete buckets with no request since *cutoff*."""
        if now - self._last_sweep < self.window_size:
            return
        self._last_sweep = now
        # Collect first: the dict must not change size while being iterated.
        idle = [key for key, stamps in self.requests.items()
                if not stamps or stamps[-1] < cutoff]
        for key in idle:
            del self.requests[key]
class RequestValidator:
    """Validates common request patterns and parameters"""

    @staticmethod
    def validate_pagination(limit: Optional[int], offset: Optional[int]) -> tuple:
        """
        Sanitize pagination parameters independently of each other.

        Each value that is not None goes through
        ``InputSanitizer.sanitize_pagination``, paired with a fixed
        placeholder for the other argument (offset 0 / limit 10). None values
        are returned unchanged.

        Returns:
            ``(limit, offset)`` tuple.
        """
        if limit is not None:
            limit = InputSanitizer.sanitize_pagination(limit, 0)[0]
        if offset is not None:
            offset = InputSanitizer.sanitize_pagination(10, offset)[1]
        return limit, offset

    @staticmethod
    def validate_search_params(params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Sanitize a dict of search parameters.

        Keys whose value is None are dropped. Other keys are handled as follows:

        - ``location_id`` / ``merchant_id``: their dedicated sanitizers.
        - ``latitude`` / ``longitude``: sanitized together as one pair, once.
        - ``limit`` / ``offset``: sanitized together as one pair, once. A
          missing or None limit defaults to 10 and offset to 0, and both keys
          appear in the result.
        - any other string value: ``InputSanitizer.sanitize_string``.
        - anything else: passed through unchanged.

        Raises:
            HTTPException: 400 when a sanitizer raises ValueError; the detail
                names the key being processed at the time.
        """
        validated: Dict[str, Any] = {}
        coordinates_done = False
        pagination_done = False
        for key, value in params.items():
            if value is None:
                continue
            try:
                if key == "location_id":
                    validated[key] = InputSanitizer.sanitize_location_id(value)
                elif key == "merchant_id":
                    validated[key] = InputSanitizer.sanitize_merchant_id(value)
                elif key in ("latitude", "longitude"):
                    # The pair is sanitized together; do it only once even
                    # when both keys are present.
                    if not coordinates_done:
                        coordinates_done = True
                        lat, lng = InputSanitizer.sanitize_coordinates(
                            params.get("latitude"), params.get("longitude"))
                        validated["latitude"] = lat
                        validated["longitude"] = lng
                elif key in ("limit", "offset"):
                    if not pagination_done:
                        pagination_done = True
                        # An explicit None must also fall back to the default;
                        # dict.get(key, default) would return None here.
                        limit = params.get("limit")
                        offset = params.get("offset")
                        limit, offset = InputSanitizer.sanitize_pagination(
                            10 if limit is None else limit,
                            0 if offset is None else offset)
                        validated["limit"] = limit
                        validated["offset"] = offset
                elif isinstance(value, str):
                    validated[key] = InputSanitizer.sanitize_string(value)
                else:
                    validated[key] = value
            except ValueError as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid parameter {key}: {str(e)}"
                ) from e
        return validated
class CSRFProtection:
    """
    Minimal CSRF guard for state-changing HTTP methods.

    Requests using a method outside ``protected_methods`` always pass. POST,
    PUT, DELETE and PATCH requests must carry an ``X-CSRF-Token`` header with
    at least one non-whitespace character.

    NOTE(review): the token value is never compared against a server-issued
    token, so any non-empty header passes. Real verification still needs to
    be wired in.
    """

    def __init__(self):
        self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}

    def validate_request(self, request: Request) -> bool:
        """Return True when *request* passes the CSRF check, else False."""
        if request.method in self.protected_methods:
            token = request.headers.get("X-CSRF-Token")
            # A missing header and a blank/whitespace-only one are both rejected.
            return bool(token and token.strip())
        return True
def create_security_middleware(app, **kwargs):
    """
    Build a SecurityMiddleware that wraps *app*.

    All keyword arguments (e.g. ``max_request_size``) are forwarded unchanged
    to the SecurityMiddleware constructor.
    """
    middleware = SecurityMiddleware(app, **kwargs)
    return middleware
# Utility decorators for endpoint protection
def require_valid_input(validation_func):
    """
    Decorator that validates an async endpoint's keyword arguments.

    ``validation_func`` receives the call's keyword arguments as a dict. The
    dict it returns is what is passed on to the wrapped coroutine function;
    positional arguments are forwarded untouched. A ValueError raised by the
    validator or by the endpoint is converted into an HTTP 400.

    ``functools.wraps`` keeps the endpoint's name and docstring, and exposes
    its signature through ``__wrapped__``. FastAPI inspects that signature
    to derive request parameters.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                validated_kwargs = validation_func(kwargs)
                return await func(*args, **validated_kwargs)
            except ValueError as e:
                raise HTTPException(status_code=400, detail=str(e)) from e
        return wrapper
    return decorator
def rate_limit(requests_per_minute: int = 60):
    """
    Placeholder decorator factory for endpoint-specific rate limiting.

    NOTE(review): not implemented yet. The decorated function is returned
    unchanged and ``requests_per_minute`` is ignored; per-path limits
    currently live in ``RateLimiter.limits`` instead.
    """
    def identity(endpoint):
        # Integration with the rate limiter is still to be done.
        return endpoint
    return identity
|