Spaces:
Running
Running
| """ | |
| Rate Limiting 中間件 | |
| 防止 API 濫用 | |
| """ | |
import logging
import time
from collections import defaultdict, deque
from typing import Callable, Deque, Dict, Optional, Tuple

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request as StarletteRequest
from starlette.responses import JSONResponse
# Module logger; the dotted name places its records under the "middleware" logger hierarchy.
logger = logging.getLogger(name="middleware.rate_limit")
class RateLimiter:
    """
    Simple in-memory sliding-window rate limiter.

    Each client key (normally an IP address) has two sliding windows, one of
    60 seconds and one of 3600 seconds. Each window holds the timestamps of
    that client's accepted requests. State lives only in this object, so
    limits are per process. A shared store such as Redis is recommended for
    production or multi-worker deployments.

    Memory is bounded by the number of clients seen within the last hour.
    Buckets that become empty are deleted, and idle clients are swept out
    roughly once a minute.
    """

    MINUTE_WINDOW_SECONDS = 60
    HOUR_WINDOW_SECONDS = 3600

    def __init__(
        self,
        requests_per_minute: int = 60,
        requests_per_hour: int = 1000,
        clock: Callable[[], float] = time.monotonic,
    ):
        """
        Args:
            requests_per_minute: max accepted requests per client per 60 s.
            requests_per_hour: max accepted requests per client per 3600 s.
            clock: zero-argument callable returning the current time in
                seconds. Must be non-decreasing. Defaults to
                ``time.monotonic`` so that wall-clock adjustments cannot
                stretch or shrink the windows.
        """
        self.requests_per_minute = requests_per_minute
        self.requests_per_hour = requests_per_hour
        self._clock = clock
        # {client_ip: deque of accepted-request timestamps, oldest first}
        self._minute_requests: Dict[str, Deque[float]] = defaultdict(deque)
        self._hour_requests: Dict[str, Deque[float]] = defaultdict(deque)
        self._last_sweep = clock()

    def _prune(self, client_ip: str, now: float) -> Tuple[int, int]:
        """
        Evict expired timestamps for ``client_ip`` and return its live counts.

        Buckets left empty are removed from the dicts, so clients that stop
        sending requests do not keep consuming memory.

        Returns:
            (minute_count, hour_count): requests still inside each window.
        """
        counts = []
        for store, window in (
            (self._minute_requests, self.MINUTE_WINDOW_SECONDS),
            (self._hour_requests, self.HOUR_WINDOW_SECONDS),
        ):
            # .get() rather than [] so a read never creates an empty entry.
            bucket = store.get(client_ip)
            if bucket is None:
                counts.append(0)
                continue
            # Timestamps are appended in clock order, so expired ones are on the left.
            while bucket and now - bucket[0] >= window:
                bucket.popleft()
            if not bucket:
                del store[client_ip]
            counts.append(len(bucket))
        return counts[0], counts[1]

    def _sweep_idle_clients(self, now: float) -> None:
        """At most once per minute, prune every tracked client so idle ones are dropped."""
        if now - self._last_sweep < self.MINUTE_WINDOW_SECONDS:
            return
        self._last_sweep = now
        # Snapshot the keys first: _prune deletes entries from the dicts.
        for client_ip in set(self._minute_requests) | set(self._hour_requests):
            self._prune(client_ip, now)

    def is_allowed(self, client_ip: str) -> Tuple[bool, str]:
        """
        Check whether a request from ``client_ip`` is allowed, and record it if so.

        Rejected requests are not recorded.

        Returns:
            (is_allowed, reason): ``reason`` is "" when allowed, otherwise a
            message naming the limit that was exceeded.
        """
        now = self._clock()
        self._sweep_idle_clients(now)
        minute_count, hour_count = self._prune(client_ip, now)

        if minute_count >= self.requests_per_minute:
            return False, f"每分鐘請求數超過限制({self.requests_per_minute})"
        if hour_count >= self.requests_per_hour:
            return False, f"每小時請求數超過限制({self.requests_per_hour})"

        self._minute_requests[client_ip].append(now)
        self._hour_requests[client_ip].append(now)
        return True, ""

    def get_remaining(self, client_ip: str) -> Dict[str, int]:
        """Return how many more requests ``client_ip`` may make in each window (never negative)."""
        minute_count, hour_count = self._prune(client_ip, self._clock())
        return {
            "minute_remaining": max(0, self.requests_per_minute - minute_count),
            "hour_remaining": max(0, self.requests_per_hour - hour_count),
        }
# Process-wide limiter instance; RateLimitMiddleware.dispatch reads this module-level name directly.
rate_limiter = RateLimiter(requests_per_minute=60, requests_per_hour=1000)
def get_client_ip(
    request: StarletteRequest,
    trusted_proxy_hops: Optional[int] = None,
) -> str:
    """
    Return the client IP address used as the rate-limit key.

    Args:
        request: the incoming request.
        trusted_proxy_hops: number of trusted reverse proxies in front of
            this app that append to X-Forwarded-For.
            * ``None`` (default): use the left-most X-Forwarded-For entry.
              This is the historical behaviour. That entry is supplied by the
              client itself, so a caller can vary it to dodge the rate limit
              and to create arbitrarily many limiter keys.
            * ``n >= 1``: use the n-th entry from the right, i.e. the address
              recorded by the outermost trusted proxy. If the header has
              fewer than ``n`` entries, fall back to the socket peer.
            * ``0`` or negative: ignore X-Forwarded-For entirely.

    Returns:
        The selected IP string. Falls back to the socket peer host when no
        usable X-Forwarded-For value exists, or ``"unknown"`` when neither is
        available.
    """
    # Starlette's Headers mapping is case-insensitive, so a single lookup
    # covers every capitalisation of the header name.
    xff = request.headers.get("x-forwarded-for")
    if xff:
        entries = [entry.strip() for entry in xff.split(",")]
        if trusted_proxy_hops is None:
            candidate = entries[0]
        elif 1 <= trusted_proxy_hops <= len(entries):
            candidate = entries[-trusted_proxy_hops]
        else:
            candidate = ""
        if candidate:
            return candidate
    return request.client.host if request.client else "unknown"
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Rate-limiting middleware.

    Requests to exempt paths (see EXEMPT_PATHS) pass straight through. Every
    other request is checked against the module-level ``rate_limiter``, keyed
    by client IP. A rejected request gets a 429 JSON response. An allowed
    request is forwarded, and X-RateLimit-Remaining-* headers are added to the
    downstream response.
    """

    # Paths that are never rate limited. An entry matches the path exactly or
    # any sub-path below it ("/static" covers "/static/app.js"). "/" matches
    # only the root page: as a plain prefix it would match every path and
    # disable rate limiting entirely.
    EXEMPT_PATHS = {
        "/",
        "/health",
        "/static",
        "/login",
        "/favicon.ico",
    }

    @classmethod
    def _is_exempt(cls, path: str) -> bool:
        """Return True if ``path`` equals an exempt entry or lies beneath a non-root one."""
        for exempt in cls.EXEMPT_PATHS:
            if path == exempt:
                return True
            prefix = exempt.rstrip("/")
            # An empty prefix means the root entry; its sub-path prefix "/" would match everything.
            if prefix and path.startswith(prefix + "/"):
                return True
        return False

    async def dispatch(self, request: StarletteRequest, call_next):
        path = request.url.path
        if self._is_exempt(path):
            return await call_next(request)

        client_ip = get_client_ip(request)

        is_allowed, reason = rate_limiter.is_allowed(client_ip)
        if not is_allowed:
            logger.warning("Rate limit exceeded for %s: %s", client_ip, reason)
            remaining = rate_limiter.get_remaining(client_ip)
            # NOTE(review): 60 s is suggested even when the hourly limit was
            # the one exceeded; the real wait can be longer in that case.
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Too Many Requests",
                    "message": reason,
                    "retry_after": 60,  # suggested wait time (seconds)
                },
                headers={
                    "Retry-After": "60",
                    "X-RateLimit-Remaining-Minute": str(remaining["minute_remaining"]),
                    "X-RateLimit-Remaining-Hour": str(remaining["hour_remaining"]),
                },
            )

        response = await call_next(request)

        # Report the remaining quota after this request has been counted.
        remaining = rate_limiter.get_remaining(client_ip)
        response.headers["X-RateLimit-Remaining-Minute"] = str(remaining["minute_remaining"])
        response.headers["X-RateLimit-Remaining-Hour"] = str(remaining["hour_remaining"])
        return response