|
|
"""Simple rate limiting middleware for demo/single user scenarios.""" |
|
|
|
|
|
import time |
|
|
from collections import defaultdict, deque |
|
|
from typing import Callable |
|
|
from fastapi import Request, HTTPException |
|
|
from fastapi.responses import JSONResponse |
|
|
|
|
|
from app.utils.constants import ( |
|
|
RATE_LIMIT_REQUESTS_PER_MINUTE, |
|
|
RATE_LIMIT_REQUESTS_PER_HOUR, |
|
|
) |
|
|
|
|
|
|
|
|
class SimpleRateLimiter: |
|
|
"""Simple in-memory rate limiter for demo use (not for production with multiple servers).""" |
|
|
|
|
|
def __init__(self): |
|
|
|
|
|
self._requests_by_ip: dict[str, deque] = defaultdict(lambda: deque()) |
|
|
self._last_cleanup = time.time() |
|
|
self._cleanup_interval = 300 |
|
|
|
|
|
def _cleanup_old_entries(self): |
|
|
"""Remove old request timestamps to prevent memory growth.""" |
|
|
current_time = time.time() |
|
|
if current_time - self._last_cleanup < self._cleanup_interval: |
|
|
return |
|
|
|
|
|
cutoff_minute = current_time - 60 |
|
|
cutoff_hour = current_time - 3600 |
|
|
|
|
|
for ip in list(self._requests_by_ip.keys()): |
|
|
requests = self._requests_by_ip[ip] |
|
|
|
|
|
while requests and requests[0] < cutoff_hour: |
|
|
requests.popleft() |
|
|
|
|
|
|
|
|
if not requests: |
|
|
del self._requests_by_ip[ip] |
|
|
|
|
|
self._last_cleanup = current_time |
|
|
|
|
|
def check_rate_limit(self, ip: str) -> tuple[bool, str | None]: |
|
|
""" |
|
|
Check if request should be allowed. |
|
|
|
|
|
Returns: |
|
|
(allowed, error_message) |
|
|
""" |
|
|
self._cleanup_old_entries() |
|
|
|
|
|
current_time = time.time() |
|
|
requests = self._requests_by_ip[ip] |
|
|
|
|
|
|
|
|
cutoff_hour = current_time - 3600 |
|
|
while requests and requests[0] < cutoff_hour: |
|
|
requests.popleft() |
|
|
|
|
|
|
|
|
if len(requests) >= RATE_LIMIT_REQUESTS_PER_HOUR: |
|
|
return False, f"Rate limit exceeded: {RATE_LIMIT_REQUESTS_PER_HOUR} requests per hour" |
|
|
|
|
|
|
|
|
cutoff_minute = current_time - 60 |
|
|
recent_requests = [r for r in requests if r >= cutoff_minute] |
|
|
if len(recent_requests) >= RATE_LIMIT_REQUESTS_PER_MINUTE: |
|
|
return False, f"Rate limit exceeded: {RATE_LIMIT_REQUESTS_PER_MINUTE} requests per minute" |
|
|
|
|
|
|
|
|
requests.append(current_time) |
|
|
return True, None |
|
|
|
|
|
|
|
|
|
|
|
# Shared limiter instance consulted by rate_limit_middleware; created once
# when this module is imported, so its counters live in this process's memory.
_rate_limiter = SimpleRateLimiter()
|
|
|
|
|
|
|
|
def _retry_after_seconds(timestamps, now: float) -> int:
    """Return whole seconds until the window that rejected a request frees a slot.

    Args:
        timestamps: The client's accepted-request timestamps, oldest first.
        now: Current ``time.time()`` value.

    Returns:
        A positive number of seconds, rounded up so clients never retry early.
    """
    if not timestamps:
        # Nothing recorded (e.g. a limit of 0): keep the historical default.
        return 60

    if len(timestamps) >= RATE_LIMIT_REQUESTS_PER_HOUR:
        # Hourly limit hit: a slot frees when the oldest stamp leaves the hour window.
        wait = timestamps[0] + 3600 - now
    else:
        # Minute limit hit: a slot frees when the oldest stamp still inside
        # the minute window leaves it.
        minute_cutoff = now - 60
        oldest_recent = next((t for t in timestamps if t >= minute_cutoff), now)
        wait = oldest_recent + 60 - now

    return max(1, int(wait) + 1)


async def rate_limit_middleware(request: Request, call_next: Callable):
    """Rate limiting middleware.

    Requests whose path is one of the public paths are passed through without
    being counted. Every other request is checked against the shared limiter,
    keyed by ``request.client.host`` (or ``"unknown"`` when no client info is
    present).

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request down the stack and
            returns the response.

    Returns:
        The downstream response with ``X-RateLimit-*`` headers added, or a 429
        JSON response with a ``Retry-After`` header when a limit is exceeded.
    """
    public_paths = ("/", "/health", "/docs", "/redoc", "/openapi.json", "/v1/stats")
    if request.url.path in public_paths:
        return await call_next(request)

    client_ip = request.client.host if request.client else "unknown"

    allowed, error_msg = _rate_limiter.check_rate_limit(client_ip)

    if not allowed:
        # Report how long the client actually has to wait; a fixed 60 seconds
        # is wrong when the hourly limit caused the rejection.
        retry_after = _retry_after_seconds(
            _rate_limiter._requests_by_ip[client_ip], time.time()
        )
        return JSONResponse(
            status_code=429,
            content={
                "error": {
                    "message": error_msg,
                    "type": "rate_limit_error",
                }
            },
            headers={
                "Retry-After": str(retry_after),
                "X-RateLimit-Limit-Minute": str(RATE_LIMIT_REQUESTS_PER_MINUTE),
                "X-RateLimit-Limit-Hour": str(RATE_LIMIT_REQUESTS_PER_HOUR),
            },
        )

    response = await call_next(request)

    # Remaining quota is computed after the handler ran, from the stamps that
    # are still inside each window at this moment.
    requests = _rate_limiter._requests_by_ip[client_ip]
    current_time = time.time()
    minute_cutoff = current_time - 60
    hour_cutoff = current_time - 3600
    used_minute = sum(1 for stamp in requests if stamp >= minute_cutoff)
    used_hour = sum(1 for stamp in requests if stamp >= hour_cutoff)

    response.headers["X-RateLimit-Limit-Minute"] = str(RATE_LIMIT_REQUESTS_PER_MINUTE)
    response.headers["X-RateLimit-Limit-Hour"] = str(RATE_LIMIT_REQUESTS_PER_HOUR)
    response.headers["X-RateLimit-Remaining-Minute"] = str(
        max(0, RATE_LIMIT_REQUESTS_PER_MINUTE - used_minute)
    )
    response.headers["X-RateLimit-Remaining-Hour"] = str(
        max(0, RATE_LIMIT_REQUESTS_PER_HOUR - used_hour)
    )

    return response
|
|
|
|
|
|