import math
import time
from collections import defaultdict, deque

from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse


class RateLimitMiddleware(BaseHTTPMiddleware):
    """API rate-limiting middleware.

    Applies a sliding 60-second window per client IP: each client may make at
    most ``calls_per_minute`` requests within any window. Requests over the
    limit are answered with ``429 Too Many Requests`` (JSON body
    ``{"detail": ...}`` plus a ``Retry-After`` header) and are not forwarded
    to the application.

    State lives in this object's memory, so limits are tracked per process.

    Args:
        app: The ASGI application being wrapped.
        calls_per_minute: Maximum accepted requests per client per window.
    """

    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60

    def __init__(self, app, calls_per_minute: int = 60):
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        # client IP -> time.monotonic() timestamps of accepted requests still
        # inside the window, oldest first (appended in arrival order).
        self.calls = defaultdict(deque)
        self._last_sweep = time.monotonic()

    def _sweep_idle_clients(self, cutoff: float) -> None:
        """Forget clients with no accepted request newer than *cutoff*.

        Without this, every IP ever seen would keep an entry forever.
        """
        idle = [
            ip
            for ip, stamps in self.calls.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for ip in idle:
            del self.calls[ip]

    async def dispatch(self, request: Request, call_next):
        """Reject the request with 429 if the client is over its limit, else forward it."""
        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: unaffected by wall-clock adjustments, which would
        # otherwise shrink or stretch the window.
        now = time.monotonic()
        cutoff = now - self.WINDOW_SECONDS

        # Periodically drop idle clients so memory stays bounded by the
        # number of clients active within roughly the last two windows.
        if now - self._last_sweep >= self.WINDOW_SECONDS:
            self._sweep_idle_clients(cutoff)
            self._last_sweep = now

        stamps = self.calls[client_ip]
        # Timestamps are in arrival order, so expired ones sit at the left end.
        while stamps and stamps[0] <= cutoff:
            stamps.popleft()

        # No await occurs between this check and the append below, so the
        # check-and-record is not interleaved with other requests on the loop.
        if len(stamps) >= self.calls_per_minute:
            if stamps:
                retry_after = max(1, math.ceil(stamps[0] + self.WINDOW_SECONDS - now))
            else:
                retry_after = self.WINDOW_SECONDS
            # Return the response rather than raising HTTPException: exceptions
            # raised in BaseHTTPMiddleware bypass FastAPI's exception handlers
            # and would surface as a 500 instead of a 429.
            return JSONResponse(
                status_code=429,
                content={
                    "detail": f"请求频率过高,每分钟最多允许 {self.calls_per_minute} 次请求"
                },
                headers={"Retry-After": str(retry_after)},
            )

        # Record this accepted call, then hand off to the application.
        stamps.append(now)
        return await call_next(request)