# Source: Bloom_Ware/middleware/rate_limit.py
# Uploaded by XiaoBai1221 (latest revision 69fb140)
"""
Rate Limiting 中間件
防止 API 濫用
"""
import logging
import time
from typing import Dict, Tuple
from collections import defaultdict
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request as StarletteRequest
from starlette.responses import JSONResponse
# Module-level logger; the name is hard-coded rather than derived from __name__.
logger = logging.getLogger("middleware.rate_limit")
class RateLimiter:
    """Simple in-memory rate limiter with a per-minute and a per-hour window.

    Every allowed request is stored as a ``(timestamp, count)`` record under
    the client's identifier. State is kept in dictionaries on this instance,
    so it is local to the process; a Redis-backed implementation is
    recommended for production use.

    Entries whose records have all expired are removed, and read-only queries
    never allocate an entry, so memory does not grow with the number of
    distinct client identifiers ever seen.
    """

    _MINUTE_WINDOW = 60
    _HOUR_WINDOW = 3600
    # Minimum number of seconds between two full sweeps of idle clients.
    _SWEEP_INTERVAL = 60

    def __init__(
        self,
        requests_per_minute: int = 60,
        requests_per_hour: int = 1000,
    ):
        """Create a limiter.

        Args:
            requests_per_minute: Maximum requests allowed per 60-second window.
            requests_per_hour: Maximum requests allowed per 3600-second window.
        """
        self.requests_per_minute = requests_per_minute
        self.requests_per_hour = requests_per_hour
        # Request log per client: {client_ip: [(timestamp, count), ...]}
        self._minute_requests: Dict[str, list] = defaultdict(list)
        self._hour_requests: Dict[str, list] = defaultdict(list)
        self._last_sweep = time.time()

    def _cleanup_old_requests(self, requests: list, window_seconds: int) -> list:
        """Return the records in *requests* that are younger than *window_seconds*."""
        current_time = time.time()
        return [
            (ts, count) for ts, count in requests
            if current_time - ts < window_seconds
        ]

    def _live_records(
        self, store: Dict[str, list], client_ip: str, window_seconds: int
    ) -> list:
        """Prune one client's expired records and return what is left.

        The client's key is deleted when nothing remains. ``store.get`` is
        used on purpose: indexing the defaultdict would create an entry for
        every identifier that is merely looked up.
        """
        records = store.get(client_ip)
        if records is None:
            return []
        live = self._cleanup_old_requests(records, window_seconds)
        if live:
            store[client_ip] = live
        else:
            del store[client_ip]
        return live

    def _sweep_idle_clients(self, now: float) -> None:
        """Drop expired records of all clients, at most once per sweep interval.

        Without this, a client that never comes back would keep its entry
        forever because pruning otherwise only happens on that client's own
        requests.
        """
        # A negative delta means the wall clock moved backwards; sweep then too.
        if 0 <= now - self._last_sweep < self._SWEEP_INTERVAL:
            return
        self._last_sweep = now
        for store, window in (
            (self._minute_requests, self._MINUTE_WINDOW),
            (self._hour_requests, self._HOUR_WINDOW),
        ):
            # Iterate over a snapshot of the keys: _live_records deletes entries.
            for ip in list(store):
                self._live_records(store, ip, window)

    def is_allowed(self, client_ip: str) -> Tuple[bool, str]:
        """Check whether a request from *client_ip* may proceed, and record it if so.

        Returns:
            ``(True, "")`` when the request is allowed, otherwise
            ``(False, reason)`` where *reason* names the exceeded limit.
            Denied requests are not recorded.
        """
        current_time = time.time()
        self._sweep_idle_clients(current_time)

        minute_records = self._live_records(
            self._minute_requests, client_ip, self._MINUTE_WINDOW
        )
        hour_records = self._live_records(
            self._hour_requests, client_ip, self._HOUR_WINDOW
        )
        minute_count = sum(count for _, count in minute_records)
        hour_count = sum(count for _, count in hour_records)

        if minute_count >= self.requests_per_minute:
            return False, f"每分鐘請求數超過限制({self.requests_per_minute})"
        if hour_count >= self.requests_per_hour:
            return False, f"每小時請求數超過限制({self.requests_per_hour})"

        # Entries are created only here, i.e. only for requests that are allowed.
        self._minute_requests[client_ip].append((current_time, 1))
        self._hour_requests[client_ip].append((current_time, 1))
        return True, ""

    def get_remaining(self, client_ip: str) -> Dict[str, int]:
        """Return how many requests *client_ip* has left in each window.

        Returns:
            A dict with the keys ``minute_remaining`` and ``hour_remaining``;
            both values are clamped at zero.
        """
        minute_records = self._live_records(
            self._minute_requests, client_ip, self._MINUTE_WINDOW
        )
        hour_records = self._live_records(
            self._hour_requests, client_ip, self._HOUR_WINDOW
        )
        minute_count = sum(count for _, count in minute_records)
        hour_count = sum(count for _, count in hour_records)
        return {
            "minute_remaining": max(0, self.requests_per_minute - minute_count),
            "hour_remaining": max(0, self.requests_per_hour - hour_count),
        }
# Global rate limiter instance used by the middleware in this module.
rate_limiter = RateLimiter(requests_per_minute=60, requests_per_hour=1000)
def get_client_ip(request: StarletteRequest) -> str:
    """Return the client address to use for *request*.

    The first entry of the ``X-Forwarded-For`` header wins when it is present
    and non-blank. Otherwise the host of ``request.client`` is returned, or
    the literal ``"unknown"`` when no client information is available.

    NOTE(review): the header value is taken at face value. This presumably
    relies on a trusted reverse proxy setting it; confirm against the
    deployment, since a client-supplied header would otherwise choose its own
    rate-limit bucket.
    """
    headers = request.headers
    forwarded = headers.get("x-forwarded-for") or headers.get("X-Forwarded-For")
    if forwarded:
        # Only the left-most entry of the comma-separated list is used.
        first_entry = forwarded.partition(",")[0].strip()
        if first_entry:
            return first_entry

    peer = request.client
    if peer:
        return peer.host
    return "unknown"
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware that applies the module-level ``rate_limiter`` to requests.

    Requests whose path is exempt are passed straight through. All other
    requests are checked per client address: a rejected request gets a 429
    JSON response, an accepted one is forwarded and its response is annotated
    with the remaining quota.
    """

    # Paths that are not rate limited. An entry matches the path itself and
    # anything below it ("/static" covers "/static/app.js"); "/" matches only
    # the root path.
    EXEMPT_PATHS = {
        "/",
        "/health",
        "/static",
        "/login",
        "/favicon.ico",
    }

    def _is_exempt(self, path: str) -> bool:
        """Return True when *path* equals an exempt entry or lies beneath one.

        A plain ``startswith`` test is not usable here: every URL path starts
        with "/", so the "/" entry would exempt the whole application and the
        limiter would never run. Matching on a segment boundary also keeps
        "/login" from covering unrelated paths such as "/loginx".
        """
        for exempt in self.EXEMPT_PATHS:
            if path == exempt:
                return True
            prefix = exempt.rstrip("/")
            # prefix is empty for "/", which therefore only matches exactly.
            if prefix and path.startswith(prefix + "/"):
                return True
        return False

    @staticmethod
    def _remaining_headers(client_ip: str) -> Dict[str, str]:
        """Build the X-RateLimit-Remaining-* headers for *client_ip*."""
        remaining = rate_limiter.get_remaining(client_ip)
        return {
            "X-RateLimit-Remaining-Minute": str(remaining["minute_remaining"]),
            "X-RateLimit-Remaining-Hour": str(remaining["hour_remaining"]),
        }

    async def dispatch(self, request: StarletteRequest, call_next):
        """Rate limit *request* unless its path is exempt.

        Returns:
            The downstream response (with quota headers added for non-exempt
            paths), or a 429 ``JSONResponse`` when the client is over a limit.
        """
        if self._is_exempt(request.url.path):
            return await call_next(request)

        client_ip = get_client_ip(request)

        is_allowed, reason = rate_limiter.is_allowed(client_ip)
        if not is_allowed:
            logger.warning("Rate limit exceeded for %s: %s", client_ip, reason)
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Too Many Requests",
                    "message": reason,
                    "retry_after": 60,  # suggested wait time in seconds
                },
                headers={
                    "Retry-After": "60",
                    **self._remaining_headers(client_ip),
                },
            )

        response = await call_next(request)

        # Report the remaining quota on successful responses as well.
        for header_name, header_value in self._remaining_headers(client_ip).items():
            response.headers[header_name] = header_value
        return response