| |
|
|
| import os |
| import time |
| import logging |
| from fastapi import Request, HTTPException, Depends, status |
| from typing import Dict, List |
|
|
| |
| logger = logging.getLogger(__name__) |
|
|
| |
| ADMIN_API_KEYS = [key.strip() for key in os.getenv("ADMIN_API_KEYS", "").split(",") if key.strip()] |
| |
| if not ADMIN_API_KEYS: |
| logger.warning("未配置 ADMIN_API_KEYS。某些管理功能可能无法访问。") |
|
|
| USER_API_KEYS = os.getenv("USER_API_KEYS", "").split(",") |
| USER_API_KEYS = [key.strip() for key in USER_API_KEYS if key.strip()] |
|
|
| ALL_API_KEYS = list(set(ADMIN_API_KEYS + USER_API_KEYS)) |
|
|
| |
| |
| active_keys_status: Dict[str, float] = {} |
|
|
| |
| |
| RATE_LIMIT_PER_MINUTE = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60")) |
| RATE_LIMIT_WINDOW_SECONDS = 60 |
|
|
| |
| request_timestamps: Dict[str, List[float]] = {} |
|
|
| |
| async def get_auth_token(request: Request) -> str: |
| """从请求头中获取认证令牌并进行验证,返回认证令牌。""" |
| auth_token = None |
| |
| auth_header = request.headers.get("Authorization") |
| if auth_header and auth_header.startswith("Bearer "): |
| auth_token = auth_header.split(" ")[1] |
| logger.debug(f"从 Authorization 头获取到认证令牌 (Bearer Token): {auth_token[:4]}... 来自 IP: {request.client.host}") |
| else: |
| |
| auth_token = request.headers.get("X-Auth-Token") |
| logger.debug(f"从 X-Auth-Token 头获取到认证令牌: {auth_token[:4] if auth_token else 'None'}... 来自 IP: {request.client.host}") |
|
|
| if not auth_token or auth_token not in ALL_API_KEYS: |
| |
| log_token_display = auth_token if auth_token else 'None' |
| logger.warning(f"认证失败: 无效或缺失的认证令牌 '{log_token_display}' 来自 IP: {request.client.host}") |
| raise HTTPException( |
| status_code=status.HTTP_401_UNAUTHORIZED, detail="无效或缺失的认证令牌" |
| ) |
|
|
| |
| active_keys_status[auth_token] = time.time() |
|
|
| |
| current_time = time.time() |
| |
| request_timestamps[auth_token] = [ |
| t for t in request_timestamps.get(auth_token, []) if current_time - t < RATE_LIMIT_WINDOW_SECONDS |
| ] |
| |
| |
| if len(request_timestamps[auth_token]) >= RATE_LIMIT_PER_MINUTE: |
| logger.warning(f"认证令牌 '{auth_token}' 达到速率限制。来自 IP: {request.client.host}") |
| raise HTTPException( |
| status_code=status.HTTP_429_TOO_MANY_REQUESTS, |
| detail=f"请求过于频繁,请稍后再试。当前限制为每分钟 {RATE_LIMIT_PER_MINUTE} 次。", |
| ) |
| |
| |
| request_timestamps[auth_token].append(current_time) |
| |
|
|
| return auth_token |
|
|
|
|
| async def get_admin_api_key(auth_token: str = Depends(get_auth_token)) -> str: |
| """验证认证令牌是否为管理员令牌。""" |
| if auth_token not in ADMIN_API_KEYS: |
| logger.warning(f"权限不足: 认证令牌 '{auth_token}' 不是管理员令牌,尝试访问受限资源。") |
| raise HTTPException( |
| status_code=status.HTTP_403_FORBIDDEN, detail="操作需要管理员权限" |
| ) |
| return auth_token |
|
|
|
|
| async def get_user_api_key(auth_token: str = Depends(get_auth_token)) -> str: |
| """验证认证令牌是否为有效用户令牌 (管理员或普通用户均可)。""" |
| |
| |
| return auth_token |
|
|