|
|
"""管理员认证中间件和装饰器""" |
|
|
import os |
|
|
import hmac |
|
|
from functools import wraps |
|
|
from typing import Optional |
|
|
from fastapi import HTTPException, Request, Depends |
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
|
|
from .admin_auth import get_admin_auth, validate_admin_session, is_auth_required |
|
|
|
|
|
# Shared HTTP Bearer scheme used as a dependency by the admin helpers in this
# module. auto_error=False makes FastAPI hand the dependency None (rather than
# rejecting the request) when no Authorization header is present; the admin
# dependencies perform their own session check instead.
security: HTTPBearer = HTTPBearer(auto_error=False)
|
|
|
|
|
|
|
|
def get_api_key() -> Optional[str]:
    """Return the configured API key.

    The ``API_KEY`` environment variable is read on every call, so a change
    to the environment is picked up without re-importing the module.

    Returns:
        The value of ``API_KEY``, or ``None`` when it is not set.
    """
    configured = os.environ.get("API_KEY")
    return configured
|
|
|
|
|
|
|
|
def is_api_auth_required() -> bool:
    """Report whether API-key authentication is enabled.

    Authentication is enabled exactly when ``get_api_key()`` returns a truthy
    value; an unset or empty ``API_KEY`` disables it.
    """
    if get_api_key():
        return True
    return False
|
|
|
|
|
|
|
|
def verify_api_key(provided_key: str) -> bool:
    """Check a client-supplied API key against the configured one.

    When no key is configured (``get_api_key()`` returns an unset or empty
    value), every key is accepted.

    Args:
        provided_key: The key extracted from the incoming request.

    Returns:
        ``True`` if the key matches or no key is configured, else ``False``.
    """
    expected_key = get_api_key()
    if not expected_key:
        return True
    # hmac.compare_digest accepts ``str`` arguments only if they are
    # ASCII-only and raises TypeError otherwise. Comparing UTF-8 bytes keeps
    # the constant-time comparison while rejecting a non-ASCII key supplied
    # by a client, instead of crashing with a 500.
    return hmac.compare_digest(
        provided_key.encode("utf-8"),
        expected_key.encode("utf-8"),
    )
|
|
|
|
|
|
|
|
def extract_api_key_from_request(request: Request) -> Optional[str]:
    """Extract the API key from a request.

    Sources are tried in this order; the first non-empty one wins:

    1. ``Authorization`` header. A ``Bearer `` prefix (scheme matched
       case-insensitively) is stripped. Any other non-empty header value is
       returned unchanged, so a raw key may be sent directly in the header.
    2. ``X-API-Key`` header.
    3. ``api_key`` query parameter.

    Args:
        request: The incoming request.

    Returns:
        The extracted key, or ``None`` if no source provided one.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header:
        # RFC 7235 §2.1: the auth-scheme token is case-insensitive, so
        # "bearer <key>" and "BEARER <key>" must be treated like "Bearer <key>".
        if auth_header[:7].lower() == "bearer ":
            return auth_header[7:]
        return auth_header

    api_key = request.headers.get("X-API-Key")
    if api_key:
        return api_key

    # NOTE(review): keys sent in the query string can leak into access logs,
    # proxy logs and browser history; the header forms are preferable.
    api_key = request.query_params.get("api_key")
    if api_key:
        return api_key

    return None
|
|
|
|
|
|
|
|
async def require_api_auth(request: Request) -> bool:
    """FastAPI dependency that enforces API-key authentication.

    If no API key is configured, the request is let through unconditionally.
    Otherwise a key is extracted from the request and verified.

    Args:
        request: The incoming request.

    Returns:
        ``True`` once the request is allowed through.

    Raises:
        HTTPException: 401 when the key is missing or does not match.
    """
    if is_api_auth_required():
        challenge = {"WWW-Authenticate": "Bearer"}
        supplied_key = extract_api_key_from_request(request)
        if not supplied_key:
            raise HTTPException(
                status_code=401,
                detail="Missing API key. Please provide via Authorization header (Bearer <key>), X-API-Key header, or api_key query parameter.",
                headers=challenge,
            )
        if not verify_api_key(supplied_key):
            raise HTTPException(
                status_code=401,
                detail="Invalid API key",
                headers=challenge,
            )
    return True
|
|
|
|
|
class AdminAuthMiddleware:
    """Admin authentication middleware: locates and validates admin sessions."""

    def __init__(self):
        # Result of get_admin_auth(); stored for external users of the
        # instance (the methods of this class do not read it).
        self.auth = get_admin_auth()

    def get_session_from_request(self, request: Request) -> Optional[str]:
        """Extract the admin session ID from a request.

        Sources are tried in this order; the first non-empty one wins:

        1. ``Authorization: Bearer <session>`` header (scheme matched
           case-insensitively).
        2. ``admin_session`` cookie.
        3. ``session`` query parameter.

        Args:
            request: The incoming request.

        Returns:
            The session ID, or ``None`` if no source provided one.
        """
        auth_header = request.headers.get("Authorization")
        # RFC 7235 §2.1: the auth-scheme token is case-insensitive.
        if auth_header and auth_header[:7].lower() == "bearer ":
            token = auth_header[7:]
            # An empty "Bearer " token used to be returned as "", which made
            # authentication fail even with a valid cookie; fall through to
            # the other sources instead.
            if token:
                return token

        session_id = request.cookies.get("admin_session")
        if session_id:
            return session_id

        # NOTE(review): session IDs in the query string can leak into logs
        # and browser history; prefer the header or cookie.
        session_id = request.query_params.get("session")
        if session_id:
            return session_id

        return None

    async def authenticate_request(self, request: Request) -> bool:
        """Return whether the request carries a valid admin session.

        Always succeeds when ``is_auth_required()`` is false. Otherwise the
        session ID from ``get_session_from_request`` is passed to
        ``validate_admin_session`` and its result is returned.

        Args:
            request: The incoming request.

        Returns:
            Truthy when authenticated (or auth is disabled), falsy otherwise.
        """
        if not is_auth_required():
            return True

        session_id = self.get_session_from_request(request)
        if not session_id:
            return False

        return validate_admin_session(session_id)
|
|
|
|
|
|
|
|
|
|
|
# Module-level AdminAuthMiddleware, created lazily by get_auth_middleware().
_middleware_instance: Optional[AdminAuthMiddleware] = None


def get_auth_middleware() -> AdminAuthMiddleware:
    """Return the shared AdminAuthMiddleware, creating it on first use."""
    global _middleware_instance
    if _middleware_instance is not None:
        return _middleware_instance
    _middleware_instance = AdminAuthMiddleware()
    return _middleware_instance
|
|
|
|
|
|
|
|
|
|
|
async def require_admin_auth(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
) -> Optional[str]:
    """FastAPI dependency that requires an authenticated admin.

    ``credentials`` is not read here; it is presumably declared so FastAPI
    advertises the Bearer scheme (e.g. in OpenAPI docs) — confirm. Session
    lookup and validation are delegated to the shared middleware.

    Args:
        request: The incoming request.
        credentials: Bearer credentials resolved by ``security`` (unused).

    Returns:
        ``"no_auth_required"`` when admin auth is disabled, otherwise the
        session ID extracted from the request.

    Raises:
        HTTPException: 401 when the request has no valid admin session.
    """
    if not is_auth_required():
        return "no_auth_required"

    middleware = get_auth_middleware()
    authenticated = await middleware.authenticate_request(request)
    if authenticated:
        return middleware.get_session_from_request(request)

    raise HTTPException(
        status_code=401,
        detail="需要管理员认证",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
|
|
|
|
async def optional_admin_auth(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
) -> Optional[str]:
    """FastAPI dependency that resolves an admin session when one is present.

    Unlike ``require_admin_auth``, an unauthenticated request does not cause
    an HTTP error; ``None`` is returned instead.

    Args:
        request: The incoming request.
        credentials: Bearer credentials resolved by ``security`` (unused).

    Returns:
        The session ID extracted from the request when it authenticates
        (this may itself be ``None`` if admin auth is disabled and no session
        was sent), otherwise ``None``.
    """
    middleware = get_auth_middleware()
    is_authenticated = await middleware.authenticate_request(request)
    return middleware.get_session_from_request(request) if is_authenticated else None
|
|
|
|
|
|
|
|
|
|
|
def require_admin_session(func):
    """Decorator that requires an admin session before calling an async endpoint.

    The ``Request`` is looked up in this order:

    1. The first positional argument that is a ``Request``.
    2. The ``request`` keyword argument.
    3. Any other keyword argument that is a ``Request``. FastAPI passes
       endpoint parameters as keywords, so a parameter named e.g.
       ``req: Request`` is found too.

    Args:
        func: The async callable to protect.

    Returns:
        An async wrapper with the same metadata as ``func``.

    Raises:
        HTTPException: 500 if no request object can be found among the
            arguments; 401 if the request is not authenticated.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        request = next((arg for arg in args if isinstance(arg, Request)), None)
        if request is None:
            request = kwargs.get("request")
        if request is None:
            request = next(
                (value for value in kwargs.values() if isinstance(value, Request)),
                None,
            )

        # Explicit None checks: Request behaves like a Mapping, so relying on
        # its truthiness is fragile.
        if request is None:
            raise HTTPException(status_code=500, detail="无法获取请求对象")

        middleware = get_auth_middleware()
        if not await middleware.authenticate_request(request):
            raise HTTPException(
                status_code=401,
                detail="需要管理员认证",
                headers={"WWW-Authenticate": "Bearer"},
            )

        return await func(*args, **kwargs)

    return wrapper