File size: 5,512 Bytes
29482b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bec596e
 
 
 
 
 
 
29482b8
bec596e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29482b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from typing import Optional, Tuple, Dict, Any
from fastapi import Request, HTTPException, Depends, Header
import aiohttp
import logging
from ..config import Config
from .auth import verify_jwt_token

logger = logging.getLogger("sora-api.dependencies")

# 全局会话池
session_pool: Optional[aiohttp.ClientSession] = None

# 获取Sora客户端
def get_sora_client(auth_token: str):
    from ..sora_integration import SoraClient
    
    # 使用字典缓存客户端实例
    if not hasattr(get_sora_client, "clients"):
        get_sora_client.clients = {}
        
    if auth_token not in get_sora_client.clients:
        proxy_host = Config.PROXY_HOST if Config.PROXY_HOST and Config.PROXY_HOST.strip() else None
        proxy_port = Config.PROXY_PORT if Config.PROXY_PORT and Config.PROXY_PORT.strip() else None
        proxy_user = Config.PROXY_USER if Config.PROXY_USER and Config.PROXY_USER.strip() else None
        proxy_pass = Config.PROXY_PASS if Config.PROXY_PASS and Config.PROXY_PASS.strip() else None
        
        get_sora_client.clients[auth_token] = SoraClient(
            proxy_host=proxy_host, 
            proxy_port=proxy_port,
            proxy_user=proxy_user,
            proxy_pass=proxy_pass,
            auth_token=auth_token
        )
    
    return get_sora_client.clients[auth_token]

# 从请求头中获取并验证认证令牌
async def get_token_from_header(authorization: Optional[str] = Header(None)) -> str:
    """从请求头中获取认证令牌"""
    if not authorization:
        raise HTTPException(status_code=401, detail="缺少认证头")
    
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="无效的认证头格式")
    
    return authorization.replace("Bearer ", "")

# 验证API key
async def verify_api_key(request: Request):
    """检查请求头中的API密钥"""
    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="缺少或无效的API key")
    
    api_key = auth_header.replace("Bearer ", "")
    
    # 优先使用API_ACCESS_TOKEN进行认证
    if Config.API_ACCESS_TOKEN:
        # 如果设置了API_ACCESS_TOKEN环境变量,则进行验证
        if api_key == Config.API_ACCESS_TOKEN:
            return api_key
        else:
            logger.warning(f"API访问认证失败: 提供的令牌与API_ACCESS_TOKEN不匹配")
    
    # 如果没有API_ACCESS_TOKEN或验证失败,验证是否为key_manager中的密钥
    from ..key_manager import key_manager
    # 获取所有启用的密钥的原始令牌值(不含Bearer前缀)
    valid_keys = []
    for k in key_manager.get_all_keys():
        if k.get("is_enabled", False):
            key_value = k.get("key", "")
            if key_value.startswith("Bearer "):
                key_value = key_value[7:]  # 移除Bearer前缀
            valid_keys.append(key_value)
    
    # 检查API密钥是否在有效列表中
    if api_key in valid_keys:
        return api_key
    
    # 最后检查是否为管理员密钥
    if api_key == Config.ADMIN_KEY:
        return api_key
    
    # 所有验证都失败
    logger.warning(f"API认证失败: 提供的key不在有效列表中")
    raise HTTPException(status_code=401, detail="API认证失败,key无效")

# 获取Sora客户端依赖
def get_sora_client_dep(specific_key=None):
    """返回一个依赖函数,用于获取Sora客户端
    
    Args:
        specific_key: 指定使用的API密钥,如果不为None,则优先使用此密钥
    """
    async def _get_client(auth_token: str = Depends(verify_api_key)):
        from ..key_manager import key_manager
        
        # 如果提供了特定密钥,则使用该密钥
        if specific_key:
            sora_auth_token = specific_key
        else:
            # 使用密钥管理器获取可用的API密钥
            sora_auth_token = key_manager.get_key()
            if not sora_auth_token:
                raise HTTPException(status_code=429, detail="所有API key都已达到速率限制")
            
        # 获取Sora客户端
        return get_sora_client(sora_auth_token), sora_auth_token
    
    return _get_client

# 验证JWT令牌并验证管理员权限
async def verify_admin_jwt(token: str = Depends(get_token_from_header)) -> Dict[str, Any]:
    """验证JWT令牌并确认管理员权限"""
    # 验证JWT令牌
    payload = verify_jwt_token(token)
    
    # 验证是否为管理员角色
    if payload.get("role") != "admin":
        raise HTTPException(status_code=403, detail="没有管理员权限")
    
    return payload

# 验证管理员权限(传统方法,保留向后兼容性)
async def verify_admin(request: Request):
    """验证管理员权限"""
    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="未授权")
    
    token = auth_header.replace("Bearer ", "")
    
    # 尝试JWT验证
    try:
        payload = verify_jwt_token(token)
        if payload.get("role") == "admin":
            return token
    except HTTPException:
        # JWT验证失败,尝试传统验证
        pass
    
    # 传统验证(直接验证管理员密钥)
    if token != Config.ADMIN_KEY:
        raise HTTPException(status_code=403, detail="没有管理员权限")
    
    return token