# zenith-backend/app/middleware/rate_limit.py
# Source: Hugging Face upload by teoat ("Upload folder using huggingface_hub"), commit 4ae946d (verified).
import logging
import os
import time
from collections import defaultdict
from typing import Callable

from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
# Module logger: rate-limit rejections and temporary blocks are reported here.
logger: logging.Logger = logging.getLogger(name=__name__)
class RateLimiter:
    """In-memory, per-client sliding-window rate limiter with burst blocking.

    Two limits are enforced per client key (the client IP string in this module):

    * Sustained: at most ``requests_per_minute`` accepted requests within the
      last ``window_seconds``.
    * Burst: once ``burst_limit`` accepted requests fall within the last
      ``burst_window_seconds``, the next request puts the client on
      ``blocked_ips`` for ``block_duration_seconds``.

    Only accepted requests are recorded, so rejected requests never count
    toward either limit. All state lives in this object, so each process that
    creates an instance keeps independent counts. No locking is performed.
    NOTE(review): that is fine when called from a single asyncio event loop;
    add a lock if this is ever called from multiple threads.
    """

    def __init__(
        self,
        requests_per_minute: int = 100,
        burst_limit: int = 20,
        *,
        window_seconds: float = 60.0,
        burst_window_seconds: float = 10.0,
        block_duration_seconds: float = 300.0,
        clock: Callable[[], float] = time.monotonic,
    ):
        """Configure limits.

        Args:
            requests_per_minute: Max accepted requests per ``window_seconds``.
            burst_limit: Accepted requests within ``burst_window_seconds`` that
                trigger a temporary block on the next request.
            window_seconds: Length of the sustained sliding window.
            burst_window_seconds: Length of the burst window; should not exceed
                ``window_seconds`` (only timestamps inside the sustained window
                are retained, so a longer burst window would undercount).
            block_duration_seconds: How long an automatic burst block lasts.
            clock: Time source returning seconds; monotonic by default so wall
                clock adjustments cannot shift the windows.
        """
        self.requests_per_minute = requests_per_minute
        self.burst_limit = burst_limit
        self.window_seconds = window_seconds
        self.burst_window_seconds = burst_window_seconds
        self.block_duration_seconds = block_duration_seconds
        self._clock = clock
        # client -> timestamps (from ``clock``) of accepted requests, oldest first
        self.requests = defaultdict(list)
        # Clients refused outright. Entries added by the burst check expire via
        # ``_block_expiry``; entries added externally (no expiry) are permanent.
        self.blocked_ips = set()
        # client -> clock time at which its automatic block ends
        self._block_expiry = {}
        self._last_sweep = clock()

    def _maybe_sweep(self, now: float) -> None:
        """Drop idle clients and stale block records, at most once per window."""
        if now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        window_start = now - self.window_seconds
        # A client whose newest stamp is outside the window has nothing left
        # to count; removing it keeps memory bounded by *active* clients.
        idle = [
            ip
            for ip, stamps in self.requests.items()
            if not stamps or stamps[-1] <= window_start
        ]
        for ip in idle:
            del self.requests[ip]
        for ip, expiry in list(self._block_expiry.items()):
            if ip not in self.blocked_ips:
                # Unblocked externally; forget the stale expiry.
                del self._block_expiry[ip]
            elif expiry <= now:
                self.blocked_ips.discard(ip)
                del self._block_expiry[ip]

    def is_allowed(self, client_ip: str) -> bool:
        """Check if request is within rate limits with burst protection.

        Returns True and records the request when it is within limits.
        Returns False (without recording it) when the client is blocked, trips
        the burst limit (which also blocks it), or exceeds the sustained limit.
        """
        now = self._clock()
        self._maybe_sweep(now)

        if client_ip in self.blocked_ips:
            expiry = self._block_expiry.get(client_ip)
            if expiry is None or now < expiry:
                return False
            # Automatic block has elapsed: lift it lazily (no timer threads).
            self.blocked_ips.discard(client_ip)
            del self._block_expiry[client_ip]

        window_start = now - self.window_seconds
        recent = [t for t in self.requests[client_ip] if t > window_start]
        self.requests[client_ip] = recent

        burst_start = now - self.burst_window_seconds
        burst_count = sum(1 for t in recent if t > burst_start)
        if burst_count >= self.burst_limit:
            self.blocked_ips.add(client_ip)
            self._block_expiry[client_ip] = now + self.block_duration_seconds
            logger.warning("IP %s temporarily blocked for burst abuse", client_ip)
            return False

        if len(recent) >= self.requests_per_minute:
            logger.warning("Rate limit exceeded for IP %s", client_ip)
            return False

        recent.append(now)
        return True
# Process-wide limiter consulted by rate_limit_middleware for every request:
# 100 accepted requests per rolling minute, burst limit of 20.
rate_limiter = RateLimiter(requests_per_minute=100, burst_limit=20)
async def rate_limit_middleware(request: Request, call_next):
    """Rate limiting middleware: reject over-limit clients with HTTP 429.

    Skipped entirely when the ``ENVIRONMENT`` variable is unset or equals
    "development" (case-insensitive), and for loopback clients. Otherwise the
    client host is checked against the module-level ``rate_limiter``. Clients
    with no address information all share the "unknown" bucket.

    A 429 ``JSONResponse`` is returned rather than raising ``HTTPException``.
    Exceptions raised from an ``@app.middleware("http")`` function propagate
    outside FastAPI's exception handlers, so a raised HTTPException would
    reach the client as a 500 Internal Server Error.

    NOTE(review): ``request.client.host`` is the direct peer address. Behind a
    reverse proxy every request may appear to come from the proxy. Confirm
    that proxy headers are trusted in deployment.
    """
    if os.getenv("ENVIRONMENT", "development").lower() == "development":
        return await call_next(request)

    client_ip = request.client.host if request.client else "unknown"
    # Exempt localhost/127.0.0.1 from rate limiting
    if client_ip in {"127.0.0.1", "localhost", "::1"}:
        return await call_next(request)

    if not rate_limiter.is_allowed(client_ip):
        logger.warning("Rate limit exceeded for IP: %s", client_ip)
        # Same body FastAPI would produce for HTTPException(429, detail=...).
        return JSONResponse(
            status_code=429,
            content={"detail": "Too many requests. Please try again later."},
        )

    return await call_next(request)