File size: 4,919 Bytes
80e0598 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import time
from collections import defaultdict
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging
import asyncio # Concurrency limiting
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory, per-IP sliding-window rate limiter.

    Only paths that start with one of ``protected_routes`` are limited; all
    other requests pass straight through. State is process-local, so limits
    are per-worker when running multiple server processes.
    """

    # Shared fallback so the default is defined once (and so the class never
    # uses a mutable default argument, which would be shared across instances).
    DEFAULT_PROTECTED_ROUTES = ("/generate", "/api/generate", "/api/generate-with-report")

    def __init__(
        self,
        app,
        rate_limit_per_minute=10,
        rate_limit_window=60,
        protected_routes=None
    ):
        """
        Args:
            app: The downstream ASGI app.
            rate_limit_per_minute: Max requests allowed per IP per window.
            rate_limit_window: Window length in seconds.
            protected_routes: Path prefixes to limit; defaults to the
                generation endpoints (same default as before).
        """
        super().__init__(app)
        self.rate_limit_per_minute = rate_limit_per_minute
        self.rate_limit_window = rate_limit_window
        # Avoid the mutable-default-argument pitfall: fall back to the shared
        # tuple only when the caller did not supply routes.
        self.protected_routes = list(protected_routes) if protected_routes is not None \
            else list(self.DEFAULT_PROTECTED_ROUTES)
        # ip -> list of request timestamps within the current window
        self.ip_requests = defaultdict(list)
        logger.info(f"Rate limit middleware initialized: {rate_limit_per_minute} requests per {rate_limit_window}s")

    async def dispatch(self, request: Request, call_next):
        # request.client can be None (e.g. some test clients / unix sockets);
        # fall back to a sentinel bucket instead of raising AttributeError.
        client_ip = request.client.host if request.client is not None else "unknown"
        current_time = time.time()

        # Only apply rate limiting to protected routes
        if any(request.url.path.startswith(route) for route in self.protected_routes):
            # Drop timestamps that have aged out of the window.
            recent = [t for t in self.ip_requests[client_ip]
                      if current_time - t < self.rate_limit_window]

            if len(recent) >= self.rate_limit_per_minute:
                # Store the pruned list before rejecting so memory stays bounded.
                self.ip_requests[client_ip] = recent
                logger.warning(f"Rate limit exceeded for IP {client_ip} on {request.url.path}")
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Rate limit exceeded. Please try again later."}
                )

            # Record this request's timestamp.
            recent.append(current_time)
            self.ip_requests[client_ip] = recent

        # Process the request
        response = await call_next(request)
        return response
class ConcurrencyLimitMiddleware(BaseHTTPMiddleware):
    """Caps how many requests may run concurrently on protected routes.

    Requests that cannot acquire a slot within ``timeout`` seconds receive
    503. Non-protected routes bypass the semaphore entirely.
    """

    def __init__(
        self,
        app,
        max_concurrent_requests=5,
        timeout=5.0,
        protected_routes=None
    ):
        """
        Args:
            app: The downstream ASGI app.
            max_concurrent_requests: Semaphore size for protected routes.
            timeout: Seconds to wait for a free slot before returning 503.
            protected_routes: Path prefixes to guard; defaults to the
                generation endpoints.
        """
        super().__init__(app)
        self.semaphore = asyncio.Semaphore(max_concurrent_requests)
        self.timeout = timeout
        self.protected_routes = protected_routes or ["/generate", "/api/generate", "/api/generate-with-report"]
        logger.info(f"Concurrency limit middleware initialized: {max_concurrent_requests} concurrent requests")

    async def dispatch(self, request, call_next):
        # Guard clause: non-protected routes skip the semaphore entirely.
        if not any(request.url.path.startswith(route) for route in self.protected_routes):
            return await call_next(request)

        try:
            # wait_for (rather than a timeout context manager) for
            # compatibility with older Python versions.
            await asyncio.wait_for(self.semaphore.acquire(), timeout=self.timeout)
        except asyncio.TimeoutError:
            # Could not obtain a slot in time — shed load.
            logger.warning(f"Concurrency limit reached for {request.url.path}")
            return JSONResponse(
                status_code=503,
                content={"detail": "Server is at capacity. Please try again later."}
            )

        # Slot acquired: release it no matter how the handler exits.
        # NOTE: application exceptions are deliberately allowed to propagate so
        # FastAPI/Starlette exception handlers format them. The previous
        # blanket `except Exception` returned a 500 whose body included
        # str(e), leaking internal error details to clients.
        try:
            return await call_next(request)
        finally:
            self.semaphore.release()
# Protection against large request payloads
# Protection against large request payloads
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Rejects requests whose declared Content-Length exceeds a limit.

    Checks only the Content-Length header; requests without the header
    (e.g. chunked transfer encoding) pass through unchecked.
    """

    def __init__(self, app, max_content_length=1024*1024):  # 1MB default
        """
        Args:
            app: The downstream ASGI app.
            max_content_length: Maximum allowed body size in bytes.
        """
        super().__init__(app)
        self.max_content_length = max_content_length
        logger.info(f"Request size limit middleware initialized: {max_content_length} bytes")

    async def dispatch(self, request: Request, call_next):
        content_length = request.headers.get('content-length')
        if content_length is not None:
            # A malformed (non-numeric) Content-Length previously raised
            # ValueError here and surfaced as a 500; reject it as a client
            # error instead.
            try:
                declared_size = int(content_length)
            except ValueError:
                logger.warning(f"Malformed Content-Length header: {content_length!r}")
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Invalid Content-Length header"}
                )
            if declared_size > self.max_content_length:
                logger.warning(f"Request too large: {content_length} bytes")
                return JSONResponse(
                    status_code=413,
                    content={"detail": "Request too large"}
                )
        return await call_next(request)
|