kiroproxy / kiro_proxy /core /auth_middleware.py
KiroProxy User
chore: repo cleanup and maintenance
0edbd7b
"""管理员认证中间件和装饰器"""
import os
import hmac
from functools import wraps
from typing import Optional
from fastapi import HTTPException, Request, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from .admin_auth import get_admin_auth, validate_admin_session, is_auth_required
# Bearer-credential extractor used as `Depends(security)` by the admin
# dependencies in this module. auto_error=False presumably makes FastAPI
# return None instead of raising when no bearer header is sent — confirm.
# The dependencies that take it do not inspect the credentials they receive.
security = HTTPBearer(auto_error=False)
def get_api_key() -> Optional[str]:
    """Return the raw value of the ``API_KEY`` environment variable.

    Returns ``None`` when the variable is unset; an empty string is returned
    as-is when it is set but empty.
    """
    return os.environ.get("API_KEY")
def is_api_auth_required() -> bool:
    """Report whether API-key authentication is enabled.

    It is enabled exactly when ``API_KEY`` is set to a non-empty value.
    """
    configured = get_api_key()
    return configured is not None and len(configured) > 0
def verify_api_key(provided_key: str) -> bool:
    """Check ``provided_key`` against the configured ``API_KEY``.

    Returns ``True`` unconditionally when no non-empty API key is configured.
    Otherwise the two keys are compared in constant time.

    Both values are encoded to bytes before comparison because
    ``hmac.compare_digest`` raises ``TypeError`` for ``str`` arguments that
    contain non-ASCII characters. Without the encoding, a client-supplied key
    such as ``"clé"`` would surface as a 500 error instead of a rejected key.
    """
    expected_key = get_api_key()
    if not expected_key:
        return True
    # "surrogatepass" keeps the encoding total and injective. Environment
    # values may carry lone surrogates (os.environ decodes with
    # surrogateescape on POSIX), and those must not make the encoding raise.
    return hmac.compare_digest(
        provided_key.encode("utf-8", "surrogatepass"),
        expected_key.encode("utf-8", "surrogatepass"),
    )
def extract_api_key_from_request(request: Request) -> Optional[str]:
    """Extract the client-supplied API key from ``request``.

    Sources are tried in this order; the first non-empty one wins:

    1. ``Authorization`` header. A ``Bearer <key>`` value yields ``<key>``;
       the scheme is matched case-insensitively (RFC 7235). Any other
       non-empty value is returned unchanged, so a bare key is accepted too.
       When this header is present, the other sources are not consulted.
    2. ``X-API-Key`` header.
    3. ``api_key`` query parameter.

    Returns:
        The key string, or ``None`` if no source supplied one.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header:
        scheme, sep, credentials = auth_header.partition(" ")
        if sep and scheme.lower() == "bearer":
            return credentials
        return auth_header
    api_key = request.headers.get("X-API-Key")
    if api_key:
        return api_key
    # NOTE(review): keys in query strings tend to end up in access logs.
    api_key = request.query_params.get("api_key")
    if api_key:
        return api_key
    return None
async def require_api_auth(request: Request) -> bool:
    """FastAPI dependency that enforces API-key authentication.

    When no ``API_KEY`` is configured, every request is let through. Otherwise
    the key is pulled from the request and checked against the configured one.

    Returns:
        ``True`` when the request may proceed.

    Raises:
        HTTPException: 401 if the key is missing or does not match.
    """
    if is_api_auth_required():
        supplied = extract_api_key_from_request(request)
        if not supplied:
            raise HTTPException(
                status_code=401,
                detail="Missing API key. Please provide via Authorization header (Bearer <key>), X-API-Key header, or api_key query parameter.",
                headers={"WWW-Authenticate": "Bearer"},
            )
        if not verify_api_key(supplied):
            raise HTTPException(
                status_code=401,
                detail="Invalid API key",
                headers={"WWW-Authenticate": "Bearer"},
            )
    return True
class AdminAuthMiddleware:
    """Admin authentication helper: finds a session ID on a request and validates it."""
    def __init__(self):
        # Admin-auth object from admin_auth; not read by the methods below.
        self.auth = get_admin_auth()
    def get_session_from_request(self, request: Request) -> Optional[str]:
        """Extract the admin session ID from ``request``.

        Sources are tried in this order; the first non-empty one wins:

        1. ``Authorization: Bearer <id>``. The scheme is matched
           case-insensitively (RFC 7235), and an empty token falls through
           to the next source.
        2. The ``admin_session`` cookie.
        3. The ``session`` query parameter, used where headers/cookies are
           impractical (e.g. WebSocket connections).

        Returns:
            The session ID, or ``None`` if no source supplied one.
        """
        # 1. Authorization header
        auth_header = request.headers.get("Authorization")
        if auth_header:
            scheme, _, token = auth_header.partition(" ")
            if scheme.lower() == "bearer" and token:
                return token
        # 2. Cookie
        session_id = request.cookies.get("admin_session")
        if session_id:
            return session_id
        # 3. Query parameter
        session_id = request.query_params.get("session")
        if session_id:
            return session_id
        return None
    async def authenticate_request(self, request: Request) -> bool:
        """Return whether ``request`` is authenticated as an admin.

        Always ``True`` when ``is_auth_required()`` is false. Otherwise the
        result depends on a session ID being present and
        ``validate_admin_session`` accepting it.
        """
        if not is_auth_required():
            return True
        session_id = self.get_session_from_request(request)
        if not session_id:
            return False
        return validate_admin_session(session_id)
# Module-level AdminAuthMiddleware, created lazily by get_auth_middleware().
_middleware_instance: Optional[AdminAuthMiddleware] = None
def get_auth_middleware() -> AdminAuthMiddleware:
    """Return the shared AdminAuthMiddleware, constructing it on first call.

    No locking is performed around the lazy construction.
    """
    global _middleware_instance
    shared = _middleware_instance
    if shared is None:
        shared = _middleware_instance = AdminAuthMiddleware()
    return shared
# FastAPI dependencies
async def require_admin_auth(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[str]:
    """FastAPI dependency that requires an authenticated admin.

    ``credentials`` is declared through ``Depends(security)`` but is not
    inspected; the session is read from the request itself.

    Returns:
        ``"no_auth_required"`` when admin auth is disabled, otherwise the
        session ID found on the request.

    Raises:
        HTTPException: 401 if the request is not authenticated.
    """
    if is_auth_required():
        middleware = get_auth_middleware()
        authenticated = await middleware.authenticate_request(request)
        if authenticated:
            return middleware.get_session_from_request(request)
        raise HTTPException(
            status_code=401,
            detail="需要管理员认证",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return "no_auth_required"
async def optional_admin_auth(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[str]:
    """FastAPI dependency for routes where admin auth is optional.

    Never raises for unauthenticated requests. If the request authenticates,
    the session ID found on it (possibly ``None`` when auth is disabled) is
    returned; otherwise ``None``. ``credentials`` is not inspected.
    """
    middleware = get_auth_middleware()
    is_authenticated = await middleware.authenticate_request(request)
    return middleware.get_session_from_request(request) if is_authenticated else None
# Decorator variant (for non-FastAPI functions)
def require_admin_session(func):
    """Decorator that requires an authenticated admin session.

    ``func`` must be a coroutine function: the wrapper awaits it.

    The Request is located in this order:

    1. the first positional argument that is a ``Request``;
    2. the ``request`` keyword argument;
    3. any other keyword argument whose value is a ``Request``.

    Raises:
        HTTPException: 500 if no Request can be found among the arguments,
            401 if the request is not authenticated.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        request = next((arg for arg in args if isinstance(arg, Request)), None)
        if request is None:
            # An explicit ``request=`` keyword keeps precedence, as before;
            # otherwise accept a Request passed under any other keyword.
            request = kwargs.get('request')
            if request is None:
                request = next(
                    (value for value in kwargs.values() if isinstance(value, Request)),
                    None,
                )
        # Explicit None check: Request is a Mapping, so truthiness would go
        # through its __len__ rather than test presence.
        if request is None:
            raise HTTPException(status_code=500, detail="无法获取请求对象")
        middleware = get_auth_middleware()
        if not await middleware.authenticate_request(request):
            raise HTTPException(
                status_code=401,
                detail="需要管理员认证",
                headers={"WWW-Authenticate": "Bearer"}
            )
        return await func(*args, **kwargs)
    return wrapper