Spaces:
Sleeping
Sleeping
Upload 22 files
Browse files- src/anti_truncation.py +589 -0
- src/auth.py +1530 -0
- src/credential_manager.py +611 -0
- src/format_detector.py +172 -0
- src/gemini_router.py +583 -0
- src/google_chat_api.py +535 -0
- src/google_oauth_api.py +543 -0
- src/httpx_client.py +174 -0
- src/models.py +198 -0
- src/openai_router.py +395 -0
- src/openai_transfer.py +403 -0
- src/state_manager.py +166 -0
- src/storage/cache_manager.py +312 -0
- src/storage/file_storage_manager.py +606 -0
- src/storage/mongodb_manager.py +494 -0
- src/storage/postgres_manager.py +296 -0
- src/storage/redis_manager.py +504 -0
- src/storage_adapter.py +361 -0
- src/task_manager.py +137 -0
- src/usage_stats.py +444 -0
- src/utils.py +10 -0
- src/web_routes.py +1738 -0
src/anti_truncation.py
ADDED
|
@@ -0,0 +1,589 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Anti-Truncation Module - Ensures complete streaming output
|
| 3 |
+
保持一个流式请求内完整输出的反截断模块
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import re
|
| 7 |
+
from typing import Dict, Any, AsyncGenerator, List, Tuple
|
| 8 |
+
|
| 9 |
+
from fastapi.responses import StreamingResponse
|
| 10 |
+
|
| 11 |
+
from log import log
|
| 12 |
+
|
| 13 |
+
# 反截断配置
|
| 14 |
+
DONE_MARKER = "[done]"
|
| 15 |
+
MAX_CONTINUATION_ATTEMPTS = 3
|
| 16 |
+
CONTINUATION_PROMPT = f"""请从刚才被截断的地方继续输出剩余的所有内容。
|
| 17 |
+
|
| 18 |
+
重要提醒:
|
| 19 |
+
1. 不要重复前面已经输出的内容
|
| 20 |
+
2. 直接继续输出,无需任何前言或解释
|
| 21 |
+
3. 当你完整完成所有内容输出后,必须在最后一行单独输出:{DONE_MARKER}
|
| 22 |
+
4. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记
|
| 23 |
+
|
| 24 |
+
现在请继续输出:"""
|
| 25 |
+
|
| 26 |
+
# 正则替换配置
|
| 27 |
+
REGEX_REPLACEMENTS: List[Tuple[str, str, str]] = [
|
| 28 |
+
(
|
| 29 |
+
"age_pattern", # 替换规则名称
|
| 30 |
+
r"(?:[1-9]|1[0-8])岁(?:的)?|(?:十一|十二|十三|十四|十五|十六|十七|十八|十|一|二|三|四|五|六|七|八|九)岁(?:的)?", # 正则模式
|
| 31 |
+
"" # 替换文本
|
| 32 |
+
),
|
| 33 |
+
# 可在此处添加更多替换规则
|
| 34 |
+
# ("rule_name", r"pattern", "replacement"),
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
def apply_regex_replacements(text: str) -> str:
|
| 38 |
+
"""
|
| 39 |
+
对文本应用正则替换规则
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
text: 要处理的文本
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
处理后的文本
|
| 46 |
+
"""
|
| 47 |
+
if not text:
|
| 48 |
+
return text
|
| 49 |
+
|
| 50 |
+
processed_text = text
|
| 51 |
+
replacement_count = 0
|
| 52 |
+
|
| 53 |
+
for rule_name, pattern, replacement in REGEX_REPLACEMENTS:
|
| 54 |
+
try:
|
| 55 |
+
# 编译正则表达式,使用IGNORECASE标志
|
| 56 |
+
regex = re.compile(pattern, re.IGNORECASE)
|
| 57 |
+
|
| 58 |
+
# 执行替换
|
| 59 |
+
new_text, count = regex.subn(replacement, processed_text)
|
| 60 |
+
|
| 61 |
+
if count > 0:
|
| 62 |
+
log.debug(f"Regex replacement '{rule_name}': {count} matches replaced")
|
| 63 |
+
processed_text = new_text
|
| 64 |
+
replacement_count += count
|
| 65 |
+
|
| 66 |
+
except re.error as e:
|
| 67 |
+
log.error(f"Invalid regex pattern in rule '{rule_name}': {e}")
|
| 68 |
+
continue
|
| 69 |
+
|
| 70 |
+
if replacement_count > 0:
|
| 71 |
+
log.info(f"Applied {replacement_count} regex replacements to text")
|
| 72 |
+
|
| 73 |
+
return processed_text
|
| 74 |
+
|
| 75 |
+
def apply_regex_replacements_to_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
|
| 76 |
+
"""
|
| 77 |
+
对请求payload中的文本内容应用正则替换
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
payload: 请求payload
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
应用替换后的payload
|
| 84 |
+
"""
|
| 85 |
+
if not REGEX_REPLACEMENTS:
|
| 86 |
+
return payload
|
| 87 |
+
|
| 88 |
+
modified_payload = payload.copy()
|
| 89 |
+
request_data = modified_payload.get("request", {})
|
| 90 |
+
|
| 91 |
+
# 处理contents中的文本
|
| 92 |
+
contents = request_data.get("contents", [])
|
| 93 |
+
if contents:
|
| 94 |
+
new_contents = []
|
| 95 |
+
for content in contents:
|
| 96 |
+
if isinstance(content, dict):
|
| 97 |
+
new_content = content.copy()
|
| 98 |
+
parts = new_content.get("parts", [])
|
| 99 |
+
if parts:
|
| 100 |
+
new_parts = []
|
| 101 |
+
for part in parts:
|
| 102 |
+
if isinstance(part, dict) and "text" in part:
|
| 103 |
+
new_part = part.copy()
|
| 104 |
+
new_part["text"] = apply_regex_replacements(part["text"])
|
| 105 |
+
new_parts.append(new_part)
|
| 106 |
+
else:
|
| 107 |
+
new_parts.append(part)
|
| 108 |
+
new_content["parts"] = new_parts
|
| 109 |
+
new_contents.append(new_content)
|
| 110 |
+
else:
|
| 111 |
+
new_contents.append(content)
|
| 112 |
+
|
| 113 |
+
request_data["contents"] = new_contents
|
| 114 |
+
modified_payload["request"] = request_data
|
| 115 |
+
log.debug("Applied regex replacements to request contents")
|
| 116 |
+
|
| 117 |
+
return modified_payload
|
| 118 |
+
|
| 119 |
+
def apply_anti_truncation(payload: Dict[str, Any]) -> Dict[str, Any]:
|
| 120 |
+
"""
|
| 121 |
+
对请求payload应用反截断处理和正则替换
|
| 122 |
+
在systemInstruction中添加提醒,要求模型在结束时输出DONE_MARKER标记
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
payload: 原始请求payload
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
添加了反截断指令并应用了正则替换的payload
|
| 129 |
+
"""
|
| 130 |
+
# 首先应用正则替换
|
| 131 |
+
modified_payload = apply_regex_replacements_to_payload(payload)
|
| 132 |
+
request_data = modified_payload.get("request", {})
|
| 133 |
+
|
| 134 |
+
# 获取或创建systemInstruction
|
| 135 |
+
system_instruction = request_data.get("systemInstruction", {})
|
| 136 |
+
if not system_instruction:
|
| 137 |
+
system_instruction = {"parts": []}
|
| 138 |
+
elif "parts" not in system_instruction:
|
| 139 |
+
system_instruction["parts"] = []
|
| 140 |
+
|
| 141 |
+
# 添加反截断指令
|
| 142 |
+
anti_truncation_instruction = {
|
| 143 |
+
"text": f"""严格执行以下输出结束规则:
|
| 144 |
+
|
| 145 |
+
1. 当你完成完整回答时,必须在输出的最后单独一行输出:{DONE_MARKER}
|
| 146 |
+
2. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记
|
| 147 |
+
3. 只有输出了 {DONE_MARKER} 标记,系统才认为你的回答是完整的
|
| 148 |
+
4. 如果你的回答被截断,系统会要求你继续输出剩余内容
|
| 149 |
+
5. 无论回答长短,都必须以 {DONE_MARKER} 标记结束
|
| 150 |
+
|
| 151 |
+
示例格式:
|
| 152 |
+
```
|
| 153 |
+
你的回答内容...
|
| 154 |
+
更多回答内容...
|
| 155 |
+
{DONE_MARKER}
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
注意:{DONE_MARKER} 必须单独占一行,前面不要有任何其他字符。
|
| 159 |
+
|
| 160 |
+
这个规则对于确保输出完整性极其重要,请严格遵守。"""
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
# 检查是否已经包含反截断指令
|
| 164 |
+
has_done_instruction = any(
|
| 165 |
+
part.get("text", "").find(DONE_MARKER) != -1
|
| 166 |
+
for part in system_instruction["parts"]
|
| 167 |
+
if isinstance(part, dict)
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
if not has_done_instruction:
|
| 171 |
+
system_instruction["parts"].append(anti_truncation_instruction)
|
| 172 |
+
request_data["systemInstruction"] = system_instruction
|
| 173 |
+
modified_payload["request"] = request_data
|
| 174 |
+
|
| 175 |
+
log.debug("Applied anti-truncation instruction to request")
|
| 176 |
+
|
| 177 |
+
return modified_payload
|
| 178 |
+
|
| 179 |
+
class AntiTruncationStreamProcessor:
|
| 180 |
+
"""反截断流式处理器"""
|
| 181 |
+
|
| 182 |
+
def __init__(self,
|
| 183 |
+
original_request_func,
|
| 184 |
+
payload: Dict[str, Any],
|
| 185 |
+
max_attempts: int = MAX_CONTINUATION_ATTEMPTS):
|
| 186 |
+
self.original_request_func = original_request_func
|
| 187 |
+
self.base_payload = payload.copy()
|
| 188 |
+
self.max_attempts = max_attempts
|
| 189 |
+
self.collected_content = [] # 使用列表避免字符串重复拼接
|
| 190 |
+
self.current_attempt = 0
|
| 191 |
+
|
| 192 |
+
async def process_stream(self) -> AsyncGenerator[bytes, None]:
|
| 193 |
+
"""处理流式响应,检测并处理截断"""
|
| 194 |
+
|
| 195 |
+
while self.current_attempt < self.max_attempts:
|
| 196 |
+
self.current_attempt += 1
|
| 197 |
+
|
| 198 |
+
# 构建当前请求payload
|
| 199 |
+
current_payload = self._build_current_payload()
|
| 200 |
+
|
| 201 |
+
log.debug(f"Anti-truncation attempt {self.current_attempt}/{self.max_attempts}")
|
| 202 |
+
|
| 203 |
+
# 发送请求
|
| 204 |
+
try:
|
| 205 |
+
response = await self.original_request_func(current_payload)
|
| 206 |
+
|
| 207 |
+
if not isinstance(response, StreamingResponse):
|
| 208 |
+
# 非流式响应,直接处理
|
| 209 |
+
yield await self._handle_non_streaming_response(response)
|
| 210 |
+
return
|
| 211 |
+
|
| 212 |
+
# 处理流式响应
|
| 213 |
+
chunk_content = ""
|
| 214 |
+
found_done_marker = False
|
| 215 |
+
|
| 216 |
+
async for chunk in response.body_iterator:
|
| 217 |
+
if not chunk:
|
| 218 |
+
yield chunk
|
| 219 |
+
continue
|
| 220 |
+
|
| 221 |
+
# 处理不同数据类型的startswith问题
|
| 222 |
+
if isinstance(chunk, bytes):
|
| 223 |
+
if not chunk.startswith(b'data: '):
|
| 224 |
+
yield chunk
|
| 225 |
+
continue
|
| 226 |
+
payload_data = chunk[len(b'data: '):]
|
| 227 |
+
else:
|
| 228 |
+
chunk_str = str(chunk)
|
| 229 |
+
if not chunk_str.startswith('data: '):
|
| 230 |
+
yield chunk
|
| 231 |
+
continue
|
| 232 |
+
payload_data = chunk_str[len('data: '):].encode()
|
| 233 |
+
|
| 234 |
+
# 解析chunk内容
|
| 235 |
+
|
| 236 |
+
if payload_data.strip() == b'[DONE]':
|
| 237 |
+
# 检查是否找到了done标记
|
| 238 |
+
if found_done_marker:
|
| 239 |
+
log.info("Anti-truncation: Found [done] marker, output complete")
|
| 240 |
+
yield chunk
|
| 241 |
+
return
|
| 242 |
+
else:
|
| 243 |
+
log.warning("Anti-truncation: Stream ended without [done] marker")
|
| 244 |
+
# 不发送[DONE],准备继续
|
| 245 |
+
break
|
| 246 |
+
|
| 247 |
+
try:
|
| 248 |
+
data = json.loads(payload_data.decode())
|
| 249 |
+
content = self._extract_content_from_chunk(data)
|
| 250 |
+
|
| 251 |
+
if content:
|
| 252 |
+
chunk_content += content
|
| 253 |
+
|
| 254 |
+
# 检查是否包含done标记
|
| 255 |
+
if self._check_done_marker_in_chunk_content(content):
|
| 256 |
+
found_done_marker = True
|
| 257 |
+
log.info("Anti-truncation: Found [done] marker in chunk")
|
| 258 |
+
|
| 259 |
+
# 清理chunk中的[done]标记后再发送
|
| 260 |
+
cleaned_chunk = self._remove_done_marker_from_chunk(chunk, data)
|
| 261 |
+
yield cleaned_chunk
|
| 262 |
+
|
| 263 |
+
except (json.JSONDecodeError, UnicodeDecodeError):
|
| 264 |
+
yield chunk
|
| 265 |
+
continue
|
| 266 |
+
|
| 267 |
+
# 更新收集的内容 - 使用列表避免字符串重复拼接
|
| 268 |
+
if chunk_content:
|
| 269 |
+
self.collected_content.append(chunk_content)
|
| 270 |
+
|
| 271 |
+
# 如果找到了done标记,结束
|
| 272 |
+
if found_done_marker:
|
| 273 |
+
# 立即清理内容释放内存
|
| 274 |
+
self.collected_content.clear()
|
| 275 |
+
yield b'data: [DONE]\n\n'
|
| 276 |
+
return
|
| 277 |
+
|
| 278 |
+
# 只有在单个chunk中没有找到done标记时,才检查累积内容(防止done标记跨chunk出现)
|
| 279 |
+
if not found_done_marker:
|
| 280 |
+
accumulated_text = ''.join(self.collected_content) if self.collected_content else ""
|
| 281 |
+
if self._check_done_marker_in_text(accumulated_text):
|
| 282 |
+
log.info("Anti-truncation: Found [done] marker in accumulated content")
|
| 283 |
+
# 立即清理内容释放内存
|
| 284 |
+
self.collected_content.clear()
|
| 285 |
+
yield b'data: [DONE]\n\n'
|
| 286 |
+
return
|
| 287 |
+
|
| 288 |
+
# 如果没找到done标记且不是最后一次尝试,准备续传
|
| 289 |
+
if self.current_attempt < self.max_attempts:
|
| 290 |
+
total_length = sum(len(chunk) for chunk in self.collected_content) if self.collected_content else 0
|
| 291 |
+
log.info(f"Anti-truncation: No [done] marker found in output (length: {total_length}), preparing continuation (attempt {self.current_attempt + 1})")
|
| 292 |
+
if self.collected_content and total_length > 100:
|
| 293 |
+
last_chunk = self.collected_content[-1] if self.collected_content else ""
|
| 294 |
+
log.debug(f"Anti-truncation: Current collected content ends with: {'...' + last_chunk[-100:]}")
|
| 295 |
+
# 在下一次循环中会继续
|
| 296 |
+
continue
|
| 297 |
+
else:
|
| 298 |
+
# 最后一次尝试,直接结束
|
| 299 |
+
log.warning("Anti-truncation: Max attempts reached, ending stream")
|
| 300 |
+
# 立即清理内容释放内存
|
| 301 |
+
self.collected_content.clear()
|
| 302 |
+
yield b'data: [DONE]\n\n'
|
| 303 |
+
return
|
| 304 |
+
|
| 305 |
+
except Exception as e:
|
| 306 |
+
log.error(f"Anti-truncation error in attempt {self.current_attempt}: {str(e)}")
|
| 307 |
+
if self.current_attempt >= self.max_attempts:
|
| 308 |
+
# 发送错误chunk
|
| 309 |
+
error_chunk = {
|
| 310 |
+
"error": {
|
| 311 |
+
"message": f"Anti-truncation failed: {str(e)}",
|
| 312 |
+
"type": "api_error",
|
| 313 |
+
"code": 500
|
| 314 |
+
}
|
| 315 |
+
}
|
| 316 |
+
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
|
| 317 |
+
yield b'data: [DONE]\n\n'
|
| 318 |
+
return
|
| 319 |
+
# 否则继续下一次尝试
|
| 320 |
+
|
| 321 |
+
# 如果所有尝试都失败了
|
| 322 |
+
log.error("Anti-truncation: All attempts failed")
|
| 323 |
+
# 确保清理内容释放内存
|
| 324 |
+
self.collected_content.clear()
|
| 325 |
+
yield b'data: [DONE]\n\n'
|
| 326 |
+
|
| 327 |
+
def _build_current_payload(self) -> Dict[str, Any]:
|
| 328 |
+
"""构建当前请求的payload"""
|
| 329 |
+
if self.current_attempt == 1:
|
| 330 |
+
# 第一次请求,使用原始payload(已经包含反截断指令)
|
| 331 |
+
return self.base_payload
|
| 332 |
+
|
| 333 |
+
# 后续请求,添加续传指令
|
| 334 |
+
continuation_payload = self.base_payload.copy()
|
| 335 |
+
request_data = continuation_payload.get("request", {})
|
| 336 |
+
|
| 337 |
+
# 获取原始对话内容
|
| 338 |
+
contents = request_data.get("contents", [])
|
| 339 |
+
new_contents = contents.copy()
|
| 340 |
+
|
| 341 |
+
# 如果有收集到的内容,添加到对话中
|
| 342 |
+
if self.collected_content:
|
| 343 |
+
# 拼接收集的内容并添加模型的回复
|
| 344 |
+
accumulated_text = ''.join(self.collected_content)
|
| 345 |
+
new_contents.append({
|
| 346 |
+
"role": "model",
|
| 347 |
+
"parts": [{"text": accumulated_text}]
|
| 348 |
+
})
|
| 349 |
+
|
| 350 |
+
# 构建具体的续写指令,包含前面的内容摘要
|
| 351 |
+
content_summary = ""
|
| 352 |
+
if self.collected_content:
|
| 353 |
+
accumulated_text = ''.join(self.collected_content)
|
| 354 |
+
if len(accumulated_text) > 200:
|
| 355 |
+
content_summary = f"\n\n前面你已经输出了约 {len(accumulated_text)} 个字符的内容,结尾是:\n\"...{accumulated_text[-100:]}\""
|
| 356 |
+
else:
|
| 357 |
+
content_summary = f"\n\n前面你已经输出的内容是:\n\"{accumulated_text}\""
|
| 358 |
+
|
| 359 |
+
detailed_continuation_prompt = f"""{CONTINUATION_PROMPT}{content_summary}"""
|
| 360 |
+
|
| 361 |
+
# 添加继续指令
|
| 362 |
+
continuation_message = {
|
| 363 |
+
"role": "user",
|
| 364 |
+
"parts": [{"text": detailed_continuation_prompt}]
|
| 365 |
+
}
|
| 366 |
+
new_contents.append(continuation_message)
|
| 367 |
+
|
| 368 |
+
request_data["contents"] = new_contents
|
| 369 |
+
continuation_payload["request"] = request_data
|
| 370 |
+
|
| 371 |
+
return continuation_payload
|
| 372 |
+
|
| 373 |
+
def _extract_content_from_chunk(self, data: Dict[str, Any]) -> str:
|
| 374 |
+
"""从chunk数据中提取文本内容"""
|
| 375 |
+
content = ""
|
| 376 |
+
|
| 377 |
+
# 处理Gemini格式
|
| 378 |
+
if "candidates" in data:
|
| 379 |
+
for candidate in data["candidates"]:
|
| 380 |
+
if "content" in candidate:
|
| 381 |
+
parts = candidate["content"].get("parts", [])
|
| 382 |
+
for part in parts:
|
| 383 |
+
if "text" in part:
|
| 384 |
+
content += part["text"]
|
| 385 |
+
|
| 386 |
+
# 处理OpenAI格式
|
| 387 |
+
elif "choices" in data:
|
| 388 |
+
for choice in data["choices"]:
|
| 389 |
+
if "delta" in choice and "content" in choice["delta"]:
|
| 390 |
+
content += choice["delta"]["content"]
|
| 391 |
+
elif "message" in choice and "content" in choice["message"]:
|
| 392 |
+
content += choice["message"]["content"]
|
| 393 |
+
|
| 394 |
+
return content
|
| 395 |
+
|
| 396 |
+
async def _handle_non_streaming_response(self, response) -> bytes:
|
| 397 |
+
"""处理非流式响应"""
|
| 398 |
+
try:
|
| 399 |
+
if hasattr(response, 'body'):
|
| 400 |
+
content = response.body.decode() if isinstance(response.body, bytes) else response.body
|
| 401 |
+
elif hasattr(response, 'content'):
|
| 402 |
+
content = response.content.decode() if isinstance(response.content, bytes) else response.content
|
| 403 |
+
else:
|
| 404 |
+
content = str(response)
|
| 405 |
+
|
| 406 |
+
response_data = json.loads(content)
|
| 407 |
+
|
| 408 |
+
# 检查是否包含done标记
|
| 409 |
+
text_content = self._extract_content_from_response(response_data)
|
| 410 |
+
has_done_marker = self._check_done_marker_in_text(text_content)
|
| 411 |
+
|
| 412 |
+
if not has_done_marker and self.current_attempt < self.max_attempts:
|
| 413 |
+
log.info("Anti-truncation: Non-streaming response needs continuation")
|
| 414 |
+
if text_content:
|
| 415 |
+
self.collected_content.append(text_content)
|
| 416 |
+
# 递归处理续传
|
| 417 |
+
return await self._handle_non_streaming_response(
|
| 418 |
+
await self.original_request_func(self._build_current_payload())
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
return content.encode()
|
| 422 |
+
|
| 423 |
+
except Exception as e:
|
| 424 |
+
log.error(f"Anti-truncation non-streaming error: {str(e)}")
|
| 425 |
+
return json.dumps({
|
| 426 |
+
"error": {
|
| 427 |
+
"message": f"Anti-truncation failed: {str(e)}",
|
| 428 |
+
"type": "api_error",
|
| 429 |
+
"code": 500
|
| 430 |
+
}
|
| 431 |
+
}).encode()
|
| 432 |
+
|
| 433 |
+
def _check_done_marker_in_text(self, text: str) -> bool:
|
| 434 |
+
"""检测文本中是否包含DONE_MARKER(只检测指定标记)"""
|
| 435 |
+
if not text:
|
| 436 |
+
return False
|
| 437 |
+
|
| 438 |
+
# 只要文本中出现DONE_MARKER即可
|
| 439 |
+
return DONE_MARKER in text
|
| 440 |
+
|
| 441 |
+
def _check_done_marker_in_chunk_content(self, content: str) -> bool:
|
| 442 |
+
"""检查单个chunk内容中是否包含done标记"""
|
| 443 |
+
return self._check_done_marker_in_text(content)
|
| 444 |
+
|
| 445 |
+
def _extract_content_from_response(self, data: Dict[str, Any]) -> str:
|
| 446 |
+
"""从响应数据中提取文本内容"""
|
| 447 |
+
content = ""
|
| 448 |
+
|
| 449 |
+
# 处理Gemini格式
|
| 450 |
+
if "candidates" in data:
|
| 451 |
+
for candidate in data["candidates"]:
|
| 452 |
+
if "content" in candidate:
|
| 453 |
+
parts = candidate["content"].get("parts", [])
|
| 454 |
+
for part in parts:
|
| 455 |
+
if "text" in part:
|
| 456 |
+
content += part["text"]
|
| 457 |
+
|
| 458 |
+
# 处理OpenAI格式
|
| 459 |
+
elif "choices" in data:
|
| 460 |
+
for choice in data["choices"]:
|
| 461 |
+
if "message" in choice and "content" in choice["message"]:
|
| 462 |
+
content += choice["message"]["content"]
|
| 463 |
+
|
| 464 |
+
return content
|
| 465 |
+
|
| 466 |
+
def _remove_done_marker_from_chunk(self, chunk: bytes, data: Dict[str, Any]) -> bytes:
|
| 467 |
+
"""使用正则表达式从chunk中移除[done]标记"""
|
| 468 |
+
try:
|
| 469 |
+
# 首先检查是否真的包含[done]标记,如果没有则直接返回原始chunk
|
| 470 |
+
chunk_text = chunk.decode('utf-8', errors='ignore') if isinstance(chunk, bytes) else str(chunk)
|
| 471 |
+
if '[done]' not in chunk_text.lower():
|
| 472 |
+
return chunk # 没有[done]标记,直接返回原始chunk
|
| 473 |
+
|
| 474 |
+
# 编译正则表达式,匹配[done]标记(忽略大小写,包括可能的空白字符)
|
| 475 |
+
done_pattern = re.compile(r'\s*\[done\]\s*', re.IGNORECASE)
|
| 476 |
+
|
| 477 |
+
# 处理Gemini格式
|
| 478 |
+
if "candidates" in data:
|
| 479 |
+
modified_data = data.copy()
|
| 480 |
+
modified_data["candidates"] = []
|
| 481 |
+
|
| 482 |
+
for i, candidate in enumerate(data["candidates"]):
|
| 483 |
+
modified_candidate = candidate.copy()
|
| 484 |
+
# 只在最后一个candidate中清理[done]标记
|
| 485 |
+
is_last_candidate = (i == len(data["candidates"]) - 1)
|
| 486 |
+
|
| 487 |
+
if "content" in candidate:
|
| 488 |
+
modified_content = candidate["content"].copy()
|
| 489 |
+
if "parts" in modified_content:
|
| 490 |
+
modified_parts = []
|
| 491 |
+
for part in modified_content["parts"]:
|
| 492 |
+
if "text" in part and isinstance(part["text"], str):
|
| 493 |
+
modified_part = part.copy()
|
| 494 |
+
# 只在最后一个candidate中清理[done]标记
|
| 495 |
+
if is_last_candidate:
|
| 496 |
+
modified_part["text"] = done_pattern.sub('', part["text"])
|
| 497 |
+
modified_parts.append(modified_part)
|
| 498 |
+
else:
|
| 499 |
+
modified_parts.append(part)
|
| 500 |
+
modified_content["parts"] = modified_parts
|
| 501 |
+
modified_candidate["content"] = modified_content
|
| 502 |
+
modified_data["candidates"].append(modified_candidate)
|
| 503 |
+
|
| 504 |
+
# 重新编码为chunk格式,保持原始的换行符
|
| 505 |
+
if isinstance(chunk, bytes):
|
| 506 |
+
prefix = b'data: '
|
| 507 |
+
suffix = b'\n\n' # 确保有正确的换行符
|
| 508 |
+
json_data = json.dumps(modified_data, separators=(',',':'), ensure_ascii=False).encode('utf-8')
|
| 509 |
+
return prefix + json_data + suffix
|
| 510 |
+
else:
|
| 511 |
+
return f"data: {json.dumps(modified_data, separators=(',',':'), ensure_ascii=False)}\n\n"
|
| 512 |
+
|
| 513 |
+
# 处理OpenAI格式
|
| 514 |
+
elif "choices" in data:
|
| 515 |
+
modified_data = data.copy()
|
| 516 |
+
modified_data["choices"] = []
|
| 517 |
+
|
| 518 |
+
for choice in data["choices"]:
|
| 519 |
+
modified_choice = choice.copy()
|
| 520 |
+
if "delta" in choice and "content" in choice["delta"]:
|
| 521 |
+
modified_delta = choice["delta"].copy()
|
| 522 |
+
modified_delta["content"] = done_pattern.sub('', choice["delta"]["content"])
|
| 523 |
+
modified_choice["delta"] = modified_delta
|
| 524 |
+
elif "message" in choice and "content" in choice["message"]:
|
| 525 |
+
modified_message = choice["message"].copy()
|
| 526 |
+
modified_message["content"] = done_pattern.sub('', choice["message"]["content"])
|
| 527 |
+
modified_choice["message"] = modified_message
|
| 528 |
+
modified_data["choices"].append(modified_choice)
|
| 529 |
+
|
| 530 |
+
# 重新编码为chunk格式,保持原始的换行符
|
| 531 |
+
if isinstance(chunk, bytes):
|
| 532 |
+
prefix = b'data: '
|
| 533 |
+
suffix = b'\n\n' # 确保有正确的换行符
|
| 534 |
+
json_data = json.dumps(modified_data, separators=(',',':'), ensure_ascii=False).encode('utf-8')
|
| 535 |
+
return prefix + json_data + suffix
|
| 536 |
+
else:
|
| 537 |
+
return f"data: {json.dumps(modified_data, separators=(',',':'), ensure_ascii=False)}\n\n"
|
| 538 |
+
|
| 539 |
+
# 如果没有找到支持的格式,返回原始chunk
|
| 540 |
+
return chunk
|
| 541 |
+
|
| 542 |
+
except Exception as e:
|
| 543 |
+
log.warning(f"Failed to remove [done] marker from chunk: {str(e)}")
|
| 544 |
+
return chunk
|
| 545 |
+
|
| 546 |
+
async def apply_anti_truncation_to_stream(
|
| 547 |
+
request_func,
|
| 548 |
+
payload: Dict[str, Any],
|
| 549 |
+
max_attempts: int = MAX_CONTINUATION_ATTEMPTS
|
| 550 |
+
) -> StreamingResponse:
|
| 551 |
+
"""
|
| 552 |
+
对流式请求应用反截断处理
|
| 553 |
+
|
| 554 |
+
Args:
|
| 555 |
+
request_func: 原始请求函数
|
| 556 |
+
payload: 请求payload
|
| 557 |
+
max_attempts: 最大续传尝试次数
|
| 558 |
+
|
| 559 |
+
Returns:
|
| 560 |
+
处理后的StreamingResponse
|
| 561 |
+
"""
|
| 562 |
+
|
| 563 |
+
# 首先对payload应用反截断指令
|
| 564 |
+
anti_truncation_payload = apply_anti_truncation(payload)
|
| 565 |
+
|
| 566 |
+
# 创建反截断处理器
|
| 567 |
+
processor = AntiTruncationStreamProcessor(
|
| 568 |
+
lambda p: request_func(p),
|
| 569 |
+
anti_truncation_payload,
|
| 570 |
+
max_attempts
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
# 返回包装后的流式响应
|
| 574 |
+
return StreamingResponse(
|
| 575 |
+
processor.process_stream(),
|
| 576 |
+
media_type="text/event-stream"
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
def is_anti_truncation_enabled(request_data: Dict[str, Any]) -> bool:
|
| 580 |
+
"""
|
| 581 |
+
检查请求是否启用了反截断功能
|
| 582 |
+
|
| 583 |
+
Args:
|
| 584 |
+
request_data: 请求数据
|
| 585 |
+
|
| 586 |
+
Returns:
|
| 587 |
+
是否启用反截断
|
| 588 |
+
"""
|
| 589 |
+
return request_data.get("enable_anti_truncation", False)
|
src/auth.py
ADDED
|
@@ -0,0 +1,1530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
认证API模块 - 使用统一存储中间层,完全摆脱文件操作
|
| 3 |
+
"""
|
| 4 |
+
import asyncio
|
| 5 |
+
import json
|
| 6 |
+
import secrets
|
| 7 |
+
import socket
|
| 8 |
+
import threading
|
| 9 |
+
import time
|
| 10 |
+
import uuid
|
| 11 |
+
from datetime import timezone
|
| 12 |
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
| 13 |
+
from typing import Optional, Dict, Any, List
|
| 14 |
+
from urllib.parse import urlparse, parse_qs
|
| 15 |
+
|
| 16 |
+
from .google_oauth_api import Credentials, Flow, enable_required_apis, get_user_projects, select_default_project
|
| 17 |
+
from .storage_adapter import get_storage_adapter
|
| 18 |
+
from config import get_config_value
|
| 19 |
+
from log import log
|
| 20 |
+
|
| 21 |
+
# OAuth Configuration
|
| 22 |
+
CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
| 23 |
+
CLIENT_SECRET = "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"
|
| 24 |
+
SCOPES = [
|
| 25 |
+
"https://www.googleapis.com/auth/cloud-platform",
|
| 26 |
+
"https://www.googleapis.com/auth/userinfo.email",
|
| 27 |
+
"https://www.googleapis.com/auth/userinfo.profile",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
# 回调服务器配置
|
| 31 |
+
CALLBACK_HOST = 'localhost'
|
| 32 |
+
|
| 33 |
+
async def get_callback_port():
|
| 34 |
+
"""获取OAuth回调端口"""
|
| 35 |
+
return int(await get_config_value('oauth_callback_port', '8080', 'OAUTH_CALLBACK_PORT'))
|
| 36 |
+
|
| 37 |
+
# 全局状态管理 - 严格限制大小
|
| 38 |
+
auth_flows = {} # 存储进行中的认证流程
|
| 39 |
+
MAX_AUTH_FLOWS = 20 # 严格限制最大认证流程数
|
| 40 |
+
|
| 41 |
+
def cleanup_auth_flows_for_memory():
|
| 42 |
+
"""清理认证流程以释放内存"""
|
| 43 |
+
global auth_flows
|
| 44 |
+
cleaned = cleanup_expired_flows()
|
| 45 |
+
# 如果还是太多,强制清理一些旧的流程
|
| 46 |
+
if len(auth_flows) > 10:
|
| 47 |
+
# 按创建时间排序,保留最新的10个
|
| 48 |
+
sorted_flows = sorted(auth_flows.items(), key=lambda x: x[1].get('created_at', 0), reverse=True)
|
| 49 |
+
new_auth_flows = dict(sorted_flows[:10])
|
| 50 |
+
|
| 51 |
+
# 清理被移除的流程
|
| 52 |
+
for state, flow_data in auth_flows.items():
|
| 53 |
+
if state not in new_auth_flows:
|
| 54 |
+
try:
|
| 55 |
+
if flow_data.get('server'):
|
| 56 |
+
server = flow_data['server']
|
| 57 |
+
port = flow_data.get('callback_port')
|
| 58 |
+
async_shutdown_server(server, port)
|
| 59 |
+
except Exception:
|
| 60 |
+
pass
|
| 61 |
+
flow_data.clear()
|
| 62 |
+
|
| 63 |
+
auth_flows = new_auth_flows
|
| 64 |
+
log.info(f"强制清理认证流程,保留 {len(auth_flows)} 个最新流程")
|
| 65 |
+
|
| 66 |
+
return len(auth_flows)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
async def find_available_port(start_port: int = None) -> int:
|
| 70 |
+
"""动态查找可用端口"""
|
| 71 |
+
if start_port is None:
|
| 72 |
+
start_port = await get_callback_port()
|
| 73 |
+
|
| 74 |
+
# 首先尝试默认端口
|
| 75 |
+
for port in range(start_port, start_port + 100): # 尝试100个端口
|
| 76 |
+
try:
|
| 77 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 78 |
+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
| 79 |
+
s.bind(('0.0.0.0', port))
|
| 80 |
+
log.info(f"找到可用端口: {port}")
|
| 81 |
+
return port
|
| 82 |
+
except OSError:
|
| 83 |
+
continue
|
| 84 |
+
|
| 85 |
+
# 如果都不可用,让系统自动分配端口
|
| 86 |
+
try:
|
| 87 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 88 |
+
s.bind(('0.0.0.0', 0))
|
| 89 |
+
port = s.getsockname()[1]
|
| 90 |
+
log.info(f"系统分配可用端口: {port}")
|
| 91 |
+
return port
|
| 92 |
+
except OSError as e:
|
| 93 |
+
log.error(f"无法找到可用端口: {e}")
|
| 94 |
+
raise RuntimeError("无法找到可用端口")
|
| 95 |
+
|
| 96 |
+
def create_callback_server(port: int) -> HTTPServer:
|
| 97 |
+
"""创建指定端口的回调服务器,优化快速关闭"""
|
| 98 |
+
try:
|
| 99 |
+
# 服务器监听0.0.0.0
|
| 100 |
+
server = HTTPServer(("0.0.0.0", port), AuthCallbackHandler)
|
| 101 |
+
|
| 102 |
+
# 设置socket选项以支持快速关闭
|
| 103 |
+
server.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
| 104 |
+
# 设置较短的超时时间
|
| 105 |
+
server.timeout = 1.0
|
| 106 |
+
|
| 107 |
+
log.info(f"创建OAuth回调服务器,监听端口: {port}")
|
| 108 |
+
return server
|
| 109 |
+
except OSError as e:
|
| 110 |
+
log.error(f"创建端口{port}的服务器失败: {e}")
|
| 111 |
+
raise
|
| 112 |
+
|
| 113 |
+
class AuthCallbackHandler(BaseHTTPRequestHandler):
|
| 114 |
+
"""OAuth回调处理器"""
|
| 115 |
+
def do_GET(self):
|
| 116 |
+
query_components = parse_qs(urlparse(self.path).query)
|
| 117 |
+
code = query_components.get("code", [None])[0]
|
| 118 |
+
state = query_components.get("state", [None])[0]
|
| 119 |
+
|
| 120 |
+
log.info(f"收到OAuth回调: code={'已获取' if code else '未获取'}, state={state}")
|
| 121 |
+
|
| 122 |
+
if code and state and state in auth_flows:
|
| 123 |
+
# 更新流程状态
|
| 124 |
+
auth_flows[state]['code'] = code
|
| 125 |
+
auth_flows[state]['completed'] = True
|
| 126 |
+
|
| 127 |
+
log.info(f"OAuth回调成功处理: state={state}")
|
| 128 |
+
|
| 129 |
+
self.send_response(200)
|
| 130 |
+
self.send_header("Content-type", "text/html")
|
| 131 |
+
self.end_headers()
|
| 132 |
+
# 成功页面
|
| 133 |
+
self.wfile.write(b"<h1>OAuth authentication successful!</h1><p>You can close this window. Please return to the original page and click 'Get Credentials' button.</p>")
|
| 134 |
+
else:
|
| 135 |
+
self.send_response(400)
|
| 136 |
+
self.send_header("Content-type", "text/html")
|
| 137 |
+
self.end_headers()
|
| 138 |
+
self.wfile.write(b"<h1>Authentication failed.</h1><p>Please try again.</p>")
|
| 139 |
+
|
| 140 |
+
def log_message(self, format, *args):
|
| 141 |
+
# 减少日志噪音
|
| 142 |
+
pass
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
async def create_auth_url(project_id: Optional[str] = None, user_session: str = None, get_all_projects: bool = False) -> Dict[str, Any]:
|
| 146 |
+
"""创建认证URL,支持动态端口分配"""
|
| 147 |
+
try:
|
| 148 |
+
# 动态分配端口
|
| 149 |
+
callback_port = await find_available_port()
|
| 150 |
+
callback_url = f"http://{CALLBACK_HOST}:{callback_port}"
|
| 151 |
+
|
| 152 |
+
# 立即启动回调服务器
|
| 153 |
+
try:
|
| 154 |
+
callback_server = create_callback_server(callback_port)
|
| 155 |
+
# 在后台线程中运行服务器
|
| 156 |
+
server_thread = threading.Thread(
|
| 157 |
+
target=callback_server.serve_forever,
|
| 158 |
+
daemon=True,
|
| 159 |
+
name=f"OAuth-Server-{callback_port}"
|
| 160 |
+
)
|
| 161 |
+
server_thread.start()
|
| 162 |
+
log.info(f"OAuth回调服务器已启动,端口: {callback_port}")
|
| 163 |
+
except Exception as e:
|
| 164 |
+
log.error(f"启动回调服务器失败: {e}")
|
| 165 |
+
return {
|
| 166 |
+
'success': False,
|
| 167 |
+
'error': f'无法启动OAuth回调服务器,端口{callback_port}: {str(e)}'
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
# 创建OAuth流程
|
| 171 |
+
client_config = {
|
| 172 |
+
"installed": {
|
| 173 |
+
"client_id": CLIENT_ID,
|
| 174 |
+
"client_secret": CLIENT_SECRET,
|
| 175 |
+
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
| 176 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
flow = Flow(
|
| 181 |
+
client_id=CLIENT_ID,
|
| 182 |
+
client_secret=CLIENT_SECRET,
|
| 183 |
+
scopes=SCOPES,
|
| 184 |
+
redirect_uri=callback_url
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# 生成状态标识符,包含用户会话信息
|
| 188 |
+
if user_session:
|
| 189 |
+
state = f"{user_session}_{str(uuid.uuid4())}"
|
| 190 |
+
else:
|
| 191 |
+
state = str(uuid.uuid4())
|
| 192 |
+
|
| 193 |
+
# 生成认证URL
|
| 194 |
+
auth_url = flow.get_auth_url(state=state)
|
| 195 |
+
|
| 196 |
+
# 严格控制认证流程数量 - 超过限制时立即清理最旧的
|
| 197 |
+
if len(auth_flows) >= MAX_AUTH_FLOWS:
|
| 198 |
+
# 清理最旧的认证流程
|
| 199 |
+
oldest_state = min(auth_flows.keys(),
|
| 200 |
+
key=lambda k: auth_flows[k].get('created_at', 0))
|
| 201 |
+
try:
|
| 202 |
+
# 清理服务器资源
|
| 203 |
+
old_flow = auth_flows[oldest_state]
|
| 204 |
+
if old_flow.get('server'):
|
| 205 |
+
server = old_flow['server']
|
| 206 |
+
port = old_flow.get('callback_port')
|
| 207 |
+
async_shutdown_server(server, port)
|
| 208 |
+
except Exception as e:
|
| 209 |
+
log.warning(f"Failed to cleanup old auth flow {oldest_state}: {e}")
|
| 210 |
+
|
| 211 |
+
del auth_flows[oldest_state]
|
| 212 |
+
log.debug(f"Removed oldest auth flow: {oldest_state}")
|
| 213 |
+
|
| 214 |
+
# 保存流程状态
|
| 215 |
+
auth_flows[state] = {
|
| 216 |
+
'flow': flow,
|
| 217 |
+
'project_id': project_id, # 可能为None,稍后在回调时确定
|
| 218 |
+
'user_session': user_session,
|
| 219 |
+
'callback_port': callback_port, # 存储分配的端口
|
| 220 |
+
'callback_url': callback_url, # 存储完整回调URL
|
| 221 |
+
'server': callback_server, # 存储服务器实例
|
| 222 |
+
'server_thread': server_thread, # 存储服务器线程
|
| 223 |
+
'code': None,
|
| 224 |
+
'completed': False,
|
| 225 |
+
'created_at': time.time(),
|
| 226 |
+
'auto_project_detection': project_id is None, # 标记是否需要自动检测项目ID
|
| 227 |
+
'get_all_projects': get_all_projects # 是否为所有项目获取凭证
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
# 清理过期的流程(30分钟)
|
| 231 |
+
cleanup_expired_flows()
|
| 232 |
+
|
| 233 |
+
log.info(f"OAuth流程已创建: state={state}, project_id={project_id}")
|
| 234 |
+
log.info(f"用户需要访问认证URL,然后OAuth会回调到 {callback_url}")
|
| 235 |
+
log.info(f"为此认证流程分配的端口: {callback_port}")
|
| 236 |
+
|
| 237 |
+
return {
|
| 238 |
+
'auth_url': auth_url,
|
| 239 |
+
'state': state,
|
| 240 |
+
'callback_port': callback_port,
|
| 241 |
+
'success': True,
|
| 242 |
+
'auto_project_detection': project_id is None,
|
| 243 |
+
'detected_project_id': project_id
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
except Exception as e:
|
| 247 |
+
log.error(f"创建认证URL失败: {e}")
|
| 248 |
+
return {
|
| 249 |
+
'success': False,
|
| 250 |
+
'error': str(e)
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def wait_for_callback_sync(state: str, timeout: int = 300) -> Optional[str]:
|
| 255 |
+
"""同步等待OAuth回调完成,使用对应流程的专用服务器"""
|
| 256 |
+
if state not in auth_flows:
|
| 257 |
+
log.error(f"未找到状态为 {state} 的认证流程")
|
| 258 |
+
return None
|
| 259 |
+
|
| 260 |
+
flow_data = auth_flows[state]
|
| 261 |
+
callback_port = flow_data['callback_port']
|
| 262 |
+
|
| 263 |
+
# 服务器已经在create_auth_url时启动了,这里只需要等待
|
| 264 |
+
log.info(f"等待OAuth回调完成,端口: {callback_port}")
|
| 265 |
+
|
| 266 |
+
# 等待回调完成
|
| 267 |
+
start_time = time.time()
|
| 268 |
+
while time.time() - start_time < timeout:
|
| 269 |
+
if flow_data.get('code'):
|
| 270 |
+
log.info(f"OAuth回调成功完成")
|
| 271 |
+
return flow_data['code']
|
| 272 |
+
time.sleep(0.5) # 每0.5秒检查一次
|
| 273 |
+
|
| 274 |
+
# 刷新flow_data引用
|
| 275 |
+
if state in auth_flows:
|
| 276 |
+
flow_data = auth_flows[state]
|
| 277 |
+
|
| 278 |
+
log.warning(f"等待OAuth回调超时 ({timeout}秒)")
|
| 279 |
+
return None
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
async def complete_auth_flow(project_id: Optional[str] = None, user_session: str = None) -> Dict[str, Any]:
|
| 283 |
+
"""完成认证流程并保存凭证,支持自动检测项目ID"""
|
| 284 |
+
try:
|
| 285 |
+
# 查找对应的认证流程
|
| 286 |
+
state = None
|
| 287 |
+
flow_data = None
|
| 288 |
+
|
| 289 |
+
# 如果指定了project_id,先尝试匹配指定的项目
|
| 290 |
+
if project_id:
|
| 291 |
+
for s, data in auth_flows.items():
|
| 292 |
+
if data['project_id'] == project_id:
|
| 293 |
+
# 如果指定了用户会话,优先匹配相同会话的流程
|
| 294 |
+
if user_session and data.get('user_session') == user_session:
|
| 295 |
+
state = s
|
| 296 |
+
flow_data = data
|
| 297 |
+
break
|
| 298 |
+
# 如果没有指定会话,或没找到匹配会话的流程,使用第一个匹配项目ID的
|
| 299 |
+
elif not state:
|
| 300 |
+
state = s
|
| 301 |
+
flow_data = data
|
| 302 |
+
|
| 303 |
+
# 如果没有指定项目ID或没找到匹配的,查找需要自动检测项目ID的流程
|
| 304 |
+
if not state:
|
| 305 |
+
for s, data in auth_flows.items():
|
| 306 |
+
if data.get('auto_project_detection', False):
|
| 307 |
+
# 如果指定了用户会话,优先匹配相同会话的流程
|
| 308 |
+
if user_session and data.get('user_session') == user_session:
|
| 309 |
+
state = s
|
| 310 |
+
flow_data = data
|
| 311 |
+
break
|
| 312 |
+
# 使用第一个找到的需要自动检测的流程
|
| 313 |
+
elif not state:
|
| 314 |
+
state = s
|
| 315 |
+
flow_data = data
|
| 316 |
+
|
| 317 |
+
if not state or not flow_data:
|
| 318 |
+
return {
|
| 319 |
+
'success': False,
|
| 320 |
+
'error': '未找到对应的认证流程,请先点击获取认证链接'
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
if not project_id:
|
| 324 |
+
project_id = flow_data.get('project_id')
|
| 325 |
+
if not project_id:
|
| 326 |
+
return {
|
| 327 |
+
'success': False,
|
| 328 |
+
'error': '缺少项目ID,请指定项目ID',
|
| 329 |
+
'requires_manual_project_id': True
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
flow = flow_data['flow']
|
| 333 |
+
|
| 334 |
+
# 如果还没有授权码,需要等待回调
|
| 335 |
+
if not flow_data.get('code'):
|
| 336 |
+
log.info(f"等待用户完成OAuth授权 (state: {state})")
|
| 337 |
+
auth_code = wait_for_callback_sync(state)
|
| 338 |
+
|
| 339 |
+
if not auth_code:
|
| 340 |
+
return {
|
| 341 |
+
'success': False,
|
| 342 |
+
'error': '未接收到授权回调,请确保完成了浏览器中的OAuth认证'
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
# 更新流程数据
|
| 346 |
+
auth_flows[state]['code'] = auth_code
|
| 347 |
+
auth_flows[state]['completed'] = True
|
| 348 |
+
else:
|
| 349 |
+
auth_code = flow_data['code']
|
| 350 |
+
|
| 351 |
+
# 使用认证代码获取凭证
|
| 352 |
+
import oauthlib.oauth2.rfc6749.parameters
|
| 353 |
+
original_validate = oauthlib.oauth2.rfc6749.parameters.validate_token_parameters
|
| 354 |
+
|
| 355 |
+
def patched_validate(params):
|
| 356 |
+
try:
|
| 357 |
+
return original_validate(params)
|
| 358 |
+
except Warning:
|
| 359 |
+
pass
|
| 360 |
+
|
| 361 |
+
oauthlib.oauth2.rfc6749.parameters.validate_token_parameters = patched_validate
|
| 362 |
+
|
| 363 |
+
try:
|
| 364 |
+
credentials = await flow.exchange_code(auth_code)
|
| 365 |
+
# credentials 已经在 exchange_code 中获得
|
| 366 |
+
|
| 367 |
+
# 如果需要自动检测项目ID且没有提供项目ID
|
| 368 |
+
if flow_data.get('auto_project_detection', False) and not project_id:
|
| 369 |
+
log.info("尝试通过API获取用户项目列表...")
|
| 370 |
+
log.info(f"使用的token: {credentials.access_token[:20]}...")
|
| 371 |
+
log.info(f"Token过期时间: {credentials.expires_at}")
|
| 372 |
+
user_projects = await get_user_projects(credentials)
|
| 373 |
+
|
| 374 |
+
if user_projects:
|
| 375 |
+
# 如果只有一个项目,自动使用
|
| 376 |
+
if len(user_projects) == 1:
|
| 377 |
+
project_id = user_projects[0].get('projectId')
|
| 378 |
+
if project_id:
|
| 379 |
+
flow_data['project_id'] = project_id
|
| 380 |
+
log.info(f"自动选择唯一项目: {project_id}")
|
| 381 |
+
# 如果有多个项目,尝试选择默认项目
|
| 382 |
+
else:
|
| 383 |
+
project_id = await select_default_project(user_projects)
|
| 384 |
+
if project_id:
|
| 385 |
+
flow_data['project_id'] = project_id
|
| 386 |
+
log.info(f"自动选择默认项目: {project_id}")
|
| 387 |
+
else:
|
| 388 |
+
# 返回项目列表让用户选择
|
| 389 |
+
return {
|
| 390 |
+
'success': False,
|
| 391 |
+
'error': '请从以下项目中选择一个',
|
| 392 |
+
'requires_project_selection': True,
|
| 393 |
+
'available_projects': [
|
| 394 |
+
{
|
| 395 |
+
'projectId': p.get('projectId'),
|
| 396 |
+
'name': p.get('displayName') or p.get('projectId'),
|
| 397 |
+
'projectNumber': p.get('projectNumber')
|
| 398 |
+
}
|
| 399 |
+
for p in user_projects
|
| 400 |
+
]
|
| 401 |
+
}
|
| 402 |
+
else:
|
| 403 |
+
# 如果无法获取项目列表,提示手动输入
|
| 404 |
+
return {
|
| 405 |
+
'success': False,
|
| 406 |
+
'error': '无法获取您的项目列表,请手动指定项目ID',
|
| 407 |
+
'requires_manual_project_id': True
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
# 如果仍然没有项目ID,返回错误
|
| 411 |
+
if not project_id:
|
| 412 |
+
return {
|
| 413 |
+
'success': False,
|
| 414 |
+
'error': '缺少项目ID,请指定项目ID',
|
| 415 |
+
'requires_manual_project_id': True
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
# 保存凭证
|
| 419 |
+
saved_filename = await save_credentials(credentials, project_id)
|
| 420 |
+
|
| 421 |
+
# 准备返回的凭证数据
|
| 422 |
+
creds_data = {
|
| 423 |
+
"client_id": CLIENT_ID,
|
| 424 |
+
"client_secret": CLIENT_SECRET,
|
| 425 |
+
"token": credentials.access_token,
|
| 426 |
+
"refresh_token": credentials.refresh_token,
|
| 427 |
+
"scopes": SCOPES,
|
| 428 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
| 429 |
+
"project_id": project_id
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
if credentials.expires_at:
|
| 433 |
+
if credentials.expires_at.tzinfo is None:
|
| 434 |
+
expiry_utc = credentials.expires_at.replace(tzinfo=timezone.utc)
|
| 435 |
+
else:
|
| 436 |
+
expiry_utc = credentials.expires_at
|
| 437 |
+
creds_data["expiry"] = expiry_utc.isoformat()
|
| 438 |
+
|
| 439 |
+
# 清理使用过的流程
|
| 440 |
+
if state in auth_flows:
|
| 441 |
+
flow_data_to_clean = auth_flows[state]
|
| 442 |
+
# 快速关闭服务器
|
| 443 |
+
try:
|
| 444 |
+
if flow_data_to_clean.get('server'):
|
| 445 |
+
server = flow_data_to_clean['server']
|
| 446 |
+
port = flow_data_to_clean.get('callback_port')
|
| 447 |
+
async_shutdown_server(server, port)
|
| 448 |
+
except Exception as e:
|
| 449 |
+
log.debug(f"启动异步关闭服务器时出错: {e}")
|
| 450 |
+
|
| 451 |
+
del auth_flows[state]
|
| 452 |
+
|
| 453 |
+
log.info("OAuth认证成功,凭证已保存")
|
| 454 |
+
return {
|
| 455 |
+
'success': True,
|
| 456 |
+
'credentials': creds_data,
|
| 457 |
+
'file_path': saved_filename,
|
| 458 |
+
'auto_detected_project': flow_data.get('auto_project_detection', False)
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
except Exception as e:
|
| 462 |
+
log.error(f"获取凭证失败: {e}")
|
| 463 |
+
return {
|
| 464 |
+
'success': False,
|
| 465 |
+
'error': f'获取凭证失败: {str(e)}'
|
| 466 |
+
}
|
| 467 |
+
finally:
|
| 468 |
+
oauthlib.oauth2.rfc6749.parameters.validate_token_parameters = original_validate
|
| 469 |
+
|
| 470 |
+
except Exception as e:
|
| 471 |
+
log.error(f"完成认证流程失败: {e}")
|
| 472 |
+
return {
|
| 473 |
+
'success': False,
|
| 474 |
+
'error': str(e)
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
async def asyncio_complete_auth_flow(project_id: Optional[str] = None, user_session: str = None, get_all_projects: bool = False) -> Dict[str, Any]:
|
| 479 |
+
"""异步完成认证流程,支持自动检测项目ID"""
|
| 480 |
+
try:
|
| 481 |
+
log.info(f"asyncio_complete_auth_flow开始执行: project_id={project_id}, user_session={user_session}")
|
| 482 |
+
|
| 483 |
+
# 查找对应的认证流程
|
| 484 |
+
state = None
|
| 485 |
+
flow_data = None
|
| 486 |
+
|
| 487 |
+
log.debug(f"当前所有auth_flows: {list(auth_flows.keys())}")
|
| 488 |
+
|
| 489 |
+
# 如果指定了project_id,先尝试匹配指定的项目
|
| 490 |
+
if project_id:
|
| 491 |
+
log.info(f"尝试匹配指定的项目ID: {project_id}")
|
| 492 |
+
for s, data in auth_flows.items():
|
| 493 |
+
if data['project_id'] == project_id:
|
| 494 |
+
# 如果指定了用户会话,优先匹配相同会话的流程
|
| 495 |
+
if user_session and data.get('user_session') == user_session:
|
| 496 |
+
state = s
|
| 497 |
+
flow_data = data
|
| 498 |
+
log.info(f"找到匹配的用户会话: {s}")
|
| 499 |
+
break
|
| 500 |
+
# 如果没有指定会话,或没找到匹配会话的流程,使用第一个匹配项目ID的
|
| 501 |
+
elif not state:
|
| 502 |
+
state = s
|
| 503 |
+
flow_data = data
|
| 504 |
+
log.info(f"找到匹配的项目ID: {s}")
|
| 505 |
+
|
| 506 |
+
# 如果没有指定项目ID或没找到匹配的,查找需要自动检测项目ID的流程
|
| 507 |
+
if not state:
|
| 508 |
+
log.info(f"没有找到指定项目的流程,查找自动检测流程")
|
| 509 |
+
for s, data in auth_flows.items():
|
| 510 |
+
log.debug(f"检查流程 {s}: auto_project_detection={data.get('auto_project_detection', False)}")
|
| 511 |
+
if data.get('auto_project_detection', False):
|
| 512 |
+
# 如果指定了用户会话,优先匹配相同会话的流程
|
| 513 |
+
if user_session and data.get('user_session') == user_session:
|
| 514 |
+
state = s
|
| 515 |
+
flow_data = data
|
| 516 |
+
log.info(f"找到匹配用户会话的自动检测流程: {s}")
|
| 517 |
+
break
|
| 518 |
+
# 使用第一个找到的需要自动检测的流程
|
| 519 |
+
elif not state:
|
| 520 |
+
state = s
|
| 521 |
+
flow_data = data
|
| 522 |
+
log.info(f"找到自动检测流程: {s}")
|
| 523 |
+
|
| 524 |
+
if not state or not flow_data:
|
| 525 |
+
log.error(f"未找到认证流程: state={state}, flow_data存在={bool(flow_data)}")
|
| 526 |
+
log.debug(f"当前所有flow_data: {list(auth_flows.keys())}")
|
| 527 |
+
return {
|
| 528 |
+
'success': False,
|
| 529 |
+
'error': '未找到对应的认证流程,请先点击获取认证链接'
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
log.info(f"找到认证流程: state={state}")
|
| 533 |
+
log.info(f"flow_data内容: project_id={flow_data.get('project_id')}, auto_project_detection={flow_data.get('auto_project_detection')}")
|
| 534 |
+
log.info(f"传入的project_id参数: {project_id}")
|
| 535 |
+
|
| 536 |
+
# 如果需要自动检测项目ID且没有提供项目ID
|
| 537 |
+
log.info(f"检查auto_project_detection条件: auto_project_detection={flow_data.get('auto_project_detection', False)}, not project_id={not project_id}")
|
| 538 |
+
if flow_data.get('auto_project_detection', False) and not project_id:
|
| 539 |
+
log.info("跳过自动检测项目ID,进入等待阶段")
|
| 540 |
+
elif not project_id:
|
| 541 |
+
log.info("进入project_id检查分支")
|
| 542 |
+
project_id = flow_data.get('project_id')
|
| 543 |
+
if not project_id:
|
| 544 |
+
log.error("缺少项目ID,返回错误")
|
| 545 |
+
return {
|
| 546 |
+
'success': False,
|
| 547 |
+
'error': '缺少项目ID,请指定项目ID',
|
| 548 |
+
'requires_manual_project_id': True
|
| 549 |
+
}
|
| 550 |
+
else:
|
| 551 |
+
log.info(f"使用提供的项目ID: {project_id}")
|
| 552 |
+
|
| 553 |
+
# 检查是否已经有授权码
|
| 554 |
+
log.info(f"开始检查OAuth授权码...")
|
| 555 |
+
max_wait_time = 60 # 最多等待60秒
|
| 556 |
+
wait_interval = 1 # 每秒检查一次
|
| 557 |
+
waited = 0
|
| 558 |
+
|
| 559 |
+
while waited < max_wait_time:
|
| 560 |
+
log.debug(f"等待OAuth授权码... ({waited}/{max_wait_time}秒)")
|
| 561 |
+
if flow_data.get('code'):
|
| 562 |
+
log.info(f"检测到OAuth授权码,开始处理凭证 (等待时间: {waited}秒)")
|
| 563 |
+
break
|
| 564 |
+
|
| 565 |
+
# 异步等待
|
| 566 |
+
await asyncio.sleep(wait_interval)
|
| 567 |
+
waited += wait_interval
|
| 568 |
+
|
| 569 |
+
# 刷新flow_data引用,因为可能被回调更新了
|
| 570 |
+
if state in auth_flows:
|
| 571 |
+
flow_data = auth_flows[state]
|
| 572 |
+
log.debug(f"刷新flow_data: completed={flow_data.get('completed')}, code存在={bool(flow_data.get('code'))}")
|
| 573 |
+
|
| 574 |
+
if not flow_data.get('code'):
|
| 575 |
+
log.error(f"等待OAuth回调超时,等待了{waited}秒")
|
| 576 |
+
return {
|
| 577 |
+
'success': False,
|
| 578 |
+
'error': '等待OAuth回调超时,请确保完成了浏览器中的认证并看到成功页面'
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
flow = flow_data['flow']
|
| 582 |
+
auth_code = flow_data['code']
|
| 583 |
+
|
| 584 |
+
log.info(f"开始使用授权码获取凭证: code={'***' + auth_code[-4:] if auth_code else 'None'}")
|
| 585 |
+
|
| 586 |
+
# 使用认证代码获取凭证
|
| 587 |
+
import oauthlib.oauth2.rfc6749.parameters
|
| 588 |
+
original_validate = oauthlib.oauth2.rfc6749.parameters.validate_token_parameters
|
| 589 |
+
|
| 590 |
+
def patched_validate(params):
|
| 591 |
+
try:
|
| 592 |
+
return original_validate(params)
|
| 593 |
+
except Warning:
|
| 594 |
+
pass
|
| 595 |
+
|
| 596 |
+
oauthlib.oauth2.rfc6749.parameters.validate_token_parameters = patched_validate
|
| 597 |
+
|
| 598 |
+
try:
|
| 599 |
+
log.info(f"调用flow.exchange_code...")
|
| 600 |
+
credentials = await flow.exchange_code(auth_code)
|
| 601 |
+
log.info(f"成功获取凭证,token前缀: {credentials.access_token[:20] if credentials.access_token else 'None'}...")
|
| 602 |
+
|
| 603 |
+
log.info(f"检查是否需要项目检测: auto_project_detection={flow_data.get('auto_project_detection')}, project_id={project_id}")
|
| 604 |
+
|
| 605 |
+
# 检查是否为批量获取所有项目模式
|
| 606 |
+
if flow_data.get('get_all_projects', False) or get_all_projects:
|
| 607 |
+
log.info("批量模式:为所有项目并发获取凭证...")
|
| 608 |
+
user_projects = await get_user_projects(credentials)
|
| 609 |
+
|
| 610 |
+
if user_projects:
|
| 611 |
+
async def process_single_project(project_info):
|
| 612 |
+
"""并发处理单个项目的凭证获取"""
|
| 613 |
+
project_id_current = project_info.get('projectId')
|
| 614 |
+
project_name = project_info.get('displayName') or project_id_current
|
| 615 |
+
|
| 616 |
+
try:
|
| 617 |
+
log.info(f"为项目 {project_name} ({project_id_current}) 启用API服务...")
|
| 618 |
+
await enable_required_apis(credentials, project_id_current)
|
| 619 |
+
|
| 620 |
+
# 保存凭证
|
| 621 |
+
saved_filename = await save_credentials(credentials, project_id_current)
|
| 622 |
+
|
| 623 |
+
log.info(f"成功为项目 {project_name} 保存凭证")
|
| 624 |
+
return {
|
| 625 |
+
'status': 'success',
|
| 626 |
+
'project_id': project_id_current,
|
| 627 |
+
'project_name': project_name,
|
| 628 |
+
'file_path': saved_filename
|
| 629 |
+
}
|
| 630 |
+
|
| 631 |
+
except Exception as e:
|
| 632 |
+
log.error(f"为项目 {project_name} ({project_id_current}) 处理凭证失败: {e}")
|
| 633 |
+
return {
|
| 634 |
+
'status': 'failed',
|
| 635 |
+
'project_id': project_id_current,
|
| 636 |
+
'project_name': project_name,
|
| 637 |
+
'error': str(e)
|
| 638 |
+
}
|
| 639 |
+
|
| 640 |
+
# 并发处理所有项目
|
| 641 |
+
log.info(f"开始并发处理 {len(user_projects)} 个项目...")
|
| 642 |
+
tasks = [process_single_project(project_info) for project_info in user_projects]
|
| 643 |
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 644 |
+
|
| 645 |
+
# 整理结果
|
| 646 |
+
multiple_results = {'success': [], 'failed': []}
|
| 647 |
+
for result in results:
|
| 648 |
+
if isinstance(result, Exception):
|
| 649 |
+
log.error(f"并发处理项目时发生异常: {result}")
|
| 650 |
+
multiple_results['failed'].append({
|
| 651 |
+
'project_id': 'unknown',
|
| 652 |
+
'project_name': 'unknown',
|
| 653 |
+
'error': f'处理异常: {str(result)}'
|
| 654 |
+
})
|
| 655 |
+
elif result['status'] == 'success':
|
| 656 |
+
multiple_results['success'].append({
|
| 657 |
+
'project_id': result['project_id'],
|
| 658 |
+
'project_name': result['project_name'],
|
| 659 |
+
'file_path': result['file_path']
|
| 660 |
+
})
|
| 661 |
+
else: # failed
|
| 662 |
+
multiple_results['failed'].append({
|
| 663 |
+
'project_id': result['project_id'],
|
| 664 |
+
'project_name': result['project_name'],
|
| 665 |
+
'error': result['error']
|
| 666 |
+
})
|
| 667 |
+
|
| 668 |
+
# 清理使用过的流程
|
| 669 |
+
if state in auth_flows:
|
| 670 |
+
flow_data_to_clean = auth_flows[state]
|
| 671 |
+
try:
|
| 672 |
+
if flow_data_to_clean.get('server'):
|
| 673 |
+
server = flow_data_to_clean['server']
|
| 674 |
+
port = flow_data_to_clean.get('callback_port')
|
| 675 |
+
async_shutdown_server(server, port)
|
| 676 |
+
except Exception as e:
|
| 677 |
+
log.debug(f"启动异步关闭服务器时出错: {e}")
|
| 678 |
+
del auth_flows[state]
|
| 679 |
+
|
| 680 |
+
log.info(f"批量并发认证完成:成功 {len(multiple_results['success'])} 个,失败 {len(multiple_results['failed'])} 个")
|
| 681 |
+
return {
|
| 682 |
+
'success': True,
|
| 683 |
+
'multiple_credentials': multiple_results
|
| 684 |
+
}
|
| 685 |
+
else:
|
| 686 |
+
return {
|
| 687 |
+
'success': False,
|
| 688 |
+
'error': '无法获取您的项目列表,批量认证失败'
|
| 689 |
+
}
|
| 690 |
+
|
| 691 |
+
# 如果需要自动检测项目ID且没有提供项目ID(单项目模式)
|
| 692 |
+
elif flow_data.get('auto_project_detection', False) and not project_id:
|
| 693 |
+
log.info("尝试通过API获取用户项目列表...")
|
| 694 |
+
log.info(f"使用的token: {credentials.access_token[:20]}...")
|
| 695 |
+
log.info(f"Token过期时间: {credentials.expires_at}")
|
| 696 |
+
user_projects = await get_user_projects(credentials)
|
| 697 |
+
|
| 698 |
+
if user_projects:
|
| 699 |
+
# 如果只有一个项目,自动使用
|
| 700 |
+
if len(user_projects) == 1:
|
| 701 |
+
project_id = user_projects[0].get('projectId')
|
| 702 |
+
if project_id:
|
| 703 |
+
flow_data['project_id'] = project_id
|
| 704 |
+
log.info(f"自动选择唯一项目: {project_id}")
|
| 705 |
+
# 自动启用必需的API服务
|
| 706 |
+
log.info("正在自动启用必需的API服务...")
|
| 707 |
+
await enable_required_apis(credentials, project_id)
|
| 708 |
+
# 如果有多个项目,尝试选择默认项目
|
| 709 |
+
else:
|
| 710 |
+
project_id = await select_default_project(user_projects)
|
| 711 |
+
if project_id:
|
| 712 |
+
flow_data['project_id'] = project_id
|
| 713 |
+
log.info(f"自动选择默认项目: {project_id}")
|
| 714 |
+
# 自动启用必需的API服务
|
| 715 |
+
log.info("正在自动启用必需的API服务...")
|
| 716 |
+
await enable_required_apis(credentials, project_id)
|
| 717 |
+
else:
|
| 718 |
+
# 返回项目列表让用户选择
|
| 719 |
+
return {
|
| 720 |
+
'success': False,
|
| 721 |
+
'error': '请从以下项目中选择一个',
|
| 722 |
+
'requires_project_selection': True,
|
| 723 |
+
'available_projects': [
|
| 724 |
+
{
|
| 725 |
+
'projectId': p.get('projectId'),
|
| 726 |
+
'name': p.get('displayName') or p.get('projectId'),
|
| 727 |
+
'projectNumber': p.get('projectNumber')
|
| 728 |
+
}
|
| 729 |
+
for p in user_projects
|
| 730 |
+
]
|
| 731 |
+
}
|
| 732 |
+
else:
|
| 733 |
+
# 如果无法获取项目列表,提示手动输入
|
| 734 |
+
return {
|
| 735 |
+
'success': False,
|
| 736 |
+
'error': '无法获取您的项目列表,请手动指定项目ID',
|
| 737 |
+
'requires_manual_project_id': True
|
| 738 |
+
}
|
| 739 |
+
elif project_id:
|
| 740 |
+
# 如果已经有项目ID(手动提供或环境检测),也尝试启用API服务
|
| 741 |
+
log.info("正在为已提供的项目ID自动启用必需的API服务...")
|
| 742 |
+
await enable_required_apis(credentials, project_id)
|
| 743 |
+
|
| 744 |
+
# 如果仍然没有项目ID,返回错误
|
| 745 |
+
if not project_id:
|
| 746 |
+
return {
|
| 747 |
+
'success': False,
|
| 748 |
+
'error': '缺少项目ID,请指定项目ID',
|
| 749 |
+
'requires_manual_project_id': True
|
| 750 |
+
}
|
| 751 |
+
|
| 752 |
+
# 保存凭证
|
| 753 |
+
saved_filename = await save_credentials(credentials, project_id)
|
| 754 |
+
|
| 755 |
+
# 准备返回的凭证数据
|
| 756 |
+
creds_data = {
|
| 757 |
+
"client_id": CLIENT_ID,
|
| 758 |
+
"client_secret": CLIENT_SECRET,
|
| 759 |
+
"token": credentials.access_token,
|
| 760 |
+
"refresh_token": credentials.refresh_token,
|
| 761 |
+
"scopes": SCOPES,
|
| 762 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
| 763 |
+
"project_id": project_id
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
if credentials.expires_at:
|
| 767 |
+
if credentials.expires_at.tzinfo is None:
|
| 768 |
+
expiry_utc = credentials.expires_at.replace(tzinfo=timezone.utc)
|
| 769 |
+
else:
|
| 770 |
+
expiry_utc = credentials.expires_at
|
| 771 |
+
creds_data["expiry"] = expiry_utc.isoformat()
|
| 772 |
+
|
| 773 |
+
# 清理使用过的流程
|
| 774 |
+
if state in auth_flows:
|
| 775 |
+
flow_data_to_clean = auth_flows[state]
|
| 776 |
+
# 快速关闭服务器
|
| 777 |
+
try:
|
| 778 |
+
if flow_data_to_clean.get('server'):
|
| 779 |
+
server = flow_data_to_clean['server']
|
| 780 |
+
port = flow_data_to_clean.get('callback_port')
|
| 781 |
+
async_shutdown_server(server, port)
|
| 782 |
+
except Exception as e:
|
| 783 |
+
log.debug(f"启动异步关闭服务器时出错: {e}")
|
| 784 |
+
|
| 785 |
+
del auth_flows[state]
|
| 786 |
+
|
| 787 |
+
log.info("OAuth认证成功,凭证已保存")
|
| 788 |
+
return {
|
| 789 |
+
'success': True,
|
| 790 |
+
'credentials': creds_data,
|
| 791 |
+
'file_path': saved_filename,
|
| 792 |
+
'auto_detected_project': flow_data.get('auto_project_detection', False)
|
| 793 |
+
}
|
| 794 |
+
|
| 795 |
+
except Exception as e:
|
| 796 |
+
log.error(f"获取凭证失败: {e}")
|
| 797 |
+
return {
|
| 798 |
+
'success': False,
|
| 799 |
+
'error': f'获取凭证失败: {str(e)}'
|
| 800 |
+
}
|
| 801 |
+
finally:
|
| 802 |
+
oauthlib.oauth2.rfc6749.parameters.validate_token_parameters = original_validate
|
| 803 |
+
|
| 804 |
+
except Exception as e:
|
| 805 |
+
log.error(f"异步完成认证流程失败: {e}")
|
| 806 |
+
return {
|
| 807 |
+
'success': False,
|
| 808 |
+
'error': str(e)
|
| 809 |
+
}
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
async def complete_auth_flow_from_callback_url(callback_url: str, project_id: Optional[str] = None, get_all_projects: bool = False) -> Dict[str, Any]:
|
| 813 |
+
"""从回调URL直接完成认证流程,无需启动本地服务器"""
|
| 814 |
+
try:
|
| 815 |
+
log.info(f"开始从回调URL完成认证: {callback_url}")
|
| 816 |
+
|
| 817 |
+
# 解析回调URL
|
| 818 |
+
parsed_url = urlparse(callback_url)
|
| 819 |
+
query_params = parse_qs(parsed_url.query)
|
| 820 |
+
|
| 821 |
+
# 验证必要参数
|
| 822 |
+
if 'state' not in query_params or 'code' not in query_params:
|
| 823 |
+
return {
|
| 824 |
+
'success': False,
|
| 825 |
+
'error': '回调URL缺少必要参数 (state 或 code)'
|
| 826 |
+
}
|
| 827 |
+
|
| 828 |
+
state = query_params['state'][0]
|
| 829 |
+
code = query_params['code'][0]
|
| 830 |
+
|
| 831 |
+
log.info(f"从URL解析到: state={state}, code=xxx...")
|
| 832 |
+
|
| 833 |
+
# 检查是否有对应的认证流程
|
| 834 |
+
if state not in auth_flows:
|
| 835 |
+
return {
|
| 836 |
+
'success': False,
|
| 837 |
+
'error': f'未找到对应的认证流程,请先启动认证 (state: {state})'
|
| 838 |
+
}
|
| 839 |
+
|
| 840 |
+
flow_data = auth_flows[state]
|
| 841 |
+
flow = flow_data['flow']
|
| 842 |
+
|
| 843 |
+
# 构造回调URL(使用flow中存储的redirect_uri)
|
| 844 |
+
redirect_uri = flow.redirect_uri
|
| 845 |
+
log.info(f"使用redirect_uri: {redirect_uri}")
|
| 846 |
+
|
| 847 |
+
try:
|
| 848 |
+
# 使用authorization code获取token
|
| 849 |
+
credentials = await flow.exchange_code(code)
|
| 850 |
+
log.info("成功获取访问令牌")
|
| 851 |
+
|
| 852 |
+
# 检查是否为批量获取所有项目模式
|
| 853 |
+
if get_all_projects:
|
| 854 |
+
log.info("批量模式:从回调URL为所有项目并发获取凭证...")
|
| 855 |
+
try:
|
| 856 |
+
projects = await get_user_projects(credentials)
|
| 857 |
+
if projects:
|
| 858 |
+
async def process_single_project(project_info):
|
| 859 |
+
"""并发处理单个项目的凭证获取"""
|
| 860 |
+
project_id_current = project_info.get('projectId')
|
| 861 |
+
project_name = project_info.get('displayName') or project_id_current
|
| 862 |
+
|
| 863 |
+
try:
|
| 864 |
+
log.info(f"为项目 {project_name} ({project_id_current}) 启用API服务...")
|
| 865 |
+
await enable_required_apis(credentials, project_id_current)
|
| 866 |
+
|
| 867 |
+
# 保存凭证
|
| 868 |
+
saved_filename = await save_credentials(credentials, project_id_current)
|
| 869 |
+
|
| 870 |
+
log.info(f"成功为项目 {project_name} 保存凭证")
|
| 871 |
+
return {
|
| 872 |
+
'status': 'success',
|
| 873 |
+
'project_id': project_id_current,
|
| 874 |
+
'project_name': project_name,
|
| 875 |
+
'file_path': saved_filename
|
| 876 |
+
}
|
| 877 |
+
|
| 878 |
+
except Exception as e:
|
| 879 |
+
log.error(f"为项目 {project_name} ({project_id_current}) 处理凭证失败: {e}")
|
| 880 |
+
return {
|
| 881 |
+
'status': 'failed',
|
| 882 |
+
'project_id': project_id_current,
|
| 883 |
+
'project_name': project_name,
|
| 884 |
+
'error': str(e)
|
| 885 |
+
}
|
| 886 |
+
|
| 887 |
+
# 并发处理所有项目
|
| 888 |
+
log.info(f"开始并发处理 {len(projects)} 个项目...")
|
| 889 |
+
tasks = [process_single_project(project_info) for project_info in projects]
|
| 890 |
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 891 |
+
|
| 892 |
+
# 整理结果
|
| 893 |
+
multiple_results = {'success': [], 'failed': []}
|
| 894 |
+
for result in results:
|
| 895 |
+
if isinstance(result, Exception):
|
| 896 |
+
log.error(f"并发处理项目时发生异常: {result}")
|
| 897 |
+
multiple_results['failed'].append({
|
| 898 |
+
'project_id': 'unknown',
|
| 899 |
+
'project_name': 'unknown',
|
| 900 |
+
'error': f'处理异常: {str(result)}'
|
| 901 |
+
})
|
| 902 |
+
elif result['status'] == 'success':
|
| 903 |
+
multiple_results['success'].append({
|
| 904 |
+
'project_id': result['project_id'],
|
| 905 |
+
'project_name': result['project_name'],
|
| 906 |
+
'file_path': result['file_path']
|
| 907 |
+
})
|
| 908 |
+
else: # failed
|
| 909 |
+
multiple_results['failed'].append({
|
| 910 |
+
'project_id': result['project_id'],
|
| 911 |
+
'project_name': result['project_name'],
|
| 912 |
+
'error': result['error']
|
| 913 |
+
})
|
| 914 |
+
|
| 915 |
+
# 清理使用过的流程
|
| 916 |
+
if state in auth_flows:
|
| 917 |
+
flow_data_to_clean = auth_flows[state]
|
| 918 |
+
try:
|
| 919 |
+
if flow_data_to_clean.get('server'):
|
| 920 |
+
server = flow_data_to_clean['server']
|
| 921 |
+
port = flow_data_to_clean.get('callback_port')
|
| 922 |
+
async_shutdown_server(server, port)
|
| 923 |
+
except Exception as e:
|
| 924 |
+
log.debug(f"关闭服务器时出错: {e}")
|
| 925 |
+
del auth_flows[state]
|
| 926 |
+
|
| 927 |
+
log.info(f"从回调URL批量并发认证完成:成功 {len(multiple_results['success'])} 个,失败 {len(multiple_results['failed'])} 个")
|
| 928 |
+
return {
|
| 929 |
+
'success': True,
|
| 930 |
+
'multiple_credentials': multiple_results
|
| 931 |
+
}
|
| 932 |
+
else:
|
| 933 |
+
return {
|
| 934 |
+
'success': False,
|
| 935 |
+
'error': '无法获取您的项目列表,批量认证失败'
|
| 936 |
+
}
|
| 937 |
+
except Exception as e:
|
| 938 |
+
log.error(f"批量获取项目列表失败: {e}")
|
| 939 |
+
return {
|
| 940 |
+
'success': False,
|
| 941 |
+
'error': f'批量获取项目列表失败: {str(e)}'
|
| 942 |
+
}
|
| 943 |
+
|
| 944 |
+
# 单项目模式的项目ID处理逻辑
|
| 945 |
+
detected_project_id = None
|
| 946 |
+
auto_detected = False
|
| 947 |
+
|
| 948 |
+
if not project_id:
|
| 949 |
+
# 尝试自动检测项目ID
|
| 950 |
+
try:
|
| 951 |
+
projects = await get_user_projects(credentials)
|
| 952 |
+
if projects:
|
| 953 |
+
if len(projects) == 1:
|
| 954 |
+
# 只有一个项目,自动使用
|
| 955 |
+
detected_project_id = projects[0]['projectId']
|
| 956 |
+
auto_detected = True
|
| 957 |
+
log.info(f"自动检测到唯一项目ID: {detected_project_id}")
|
| 958 |
+
else:
|
| 959 |
+
# 多个项目,自动选择第一个
|
| 960 |
+
detected_project_id = projects[0]['projectId']
|
| 961 |
+
auto_detected = True
|
| 962 |
+
log.info(f"检测到{len(projects)}个项目,自动选择第一个: {detected_project_id}")
|
| 963 |
+
log.debug(f"其他可用项目: {[p['projectId'] for p in projects[1:]]}")
|
| 964 |
+
else:
|
| 965 |
+
# 没有项目访问权限
|
| 966 |
+
return {
|
| 967 |
+
'success': False,
|
| 968 |
+
'error': '未检测到可访问的项目,请检查权限或手动指定项目ID',
|
| 969 |
+
'requires_manual_project_id': True
|
| 970 |
+
}
|
| 971 |
+
except Exception as e:
|
| 972 |
+
log.warning(f"自动检测项目ID失败: {e}")
|
| 973 |
+
return {
|
| 974 |
+
'success': False,
|
| 975 |
+
'error': f'自动检测项目ID失败: {str(e)},请手动指定项目ID',
|
| 976 |
+
'requires_manual_project_id': True
|
| 977 |
+
}
|
| 978 |
+
else:
|
| 979 |
+
detected_project_id = project_id
|
| 980 |
+
|
| 981 |
+
# 启用必需的API服务
|
| 982 |
+
if detected_project_id:
|
| 983 |
+
try:
|
| 984 |
+
log.info(f"正在为项目 {detected_project_id} 启用必需的API服务...")
|
| 985 |
+
await enable_required_apis(credentials, detected_project_id)
|
| 986 |
+
except Exception as e:
|
| 987 |
+
log.warning(f"启用API服务失败: {e}")
|
| 988 |
+
|
| 989 |
+
# 保存凭证
|
| 990 |
+
saved_filename = await save_credentials(credentials, detected_project_id)
|
| 991 |
+
|
| 992 |
+
# 准备返回的凭证数据
|
| 993 |
+
creds_data = {
|
| 994 |
+
"client_id": CLIENT_ID,
|
| 995 |
+
"client_secret": CLIENT_SECRET,
|
| 996 |
+
"token": credentials.access_token,
|
| 997 |
+
"refresh_token": credentials.refresh_token,
|
| 998 |
+
"scopes": SCOPES,
|
| 999 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
| 1000 |
+
"project_id": detected_project_id
|
| 1001 |
+
}
|
| 1002 |
+
|
| 1003 |
+
if credentials.expires_at:
|
| 1004 |
+
if credentials.expires_at.tzinfo is None:
|
| 1005 |
+
expiry_utc = credentials.expires_at.replace(tzinfo=timezone.utc)
|
| 1006 |
+
else:
|
| 1007 |
+
expiry_utc = credentials.expires_at
|
| 1008 |
+
creds_data["expiry"] = expiry_utc.isoformat()
|
| 1009 |
+
|
| 1010 |
+
# 清理使用过的流程
|
| 1011 |
+
if state in auth_flows:
|
| 1012 |
+
flow_data_to_clean = auth_flows[state]
|
| 1013 |
+
# 快速关闭服务器(如果有)
|
| 1014 |
+
try:
|
| 1015 |
+
if flow_data_to_clean.get('server'):
|
| 1016 |
+
server = flow_data_to_clean['server']
|
| 1017 |
+
port = flow_data_to_clean.get('callback_port')
|
| 1018 |
+
async_shutdown_server(server, port)
|
| 1019 |
+
except Exception as e:
|
| 1020 |
+
log.debug(f"关闭服务器时出错: {e}")
|
| 1021 |
+
|
| 1022 |
+
del auth_flows[state]
|
| 1023 |
+
|
| 1024 |
+
log.info("从回调URL完成OAuth认证成功,凭证已保存")
|
| 1025 |
+
return {
|
| 1026 |
+
'success': True,
|
| 1027 |
+
'credentials': creds_data,
|
| 1028 |
+
'file_path': saved_filename,
|
| 1029 |
+
'auto_detected_project': auto_detected
|
| 1030 |
+
}
|
| 1031 |
+
|
| 1032 |
+
except Exception as e:
|
| 1033 |
+
log.error(f"从回调URL获取凭证失败: {e}")
|
| 1034 |
+
return {
|
| 1035 |
+
'success': False,
|
| 1036 |
+
'error': f'获取凭证失败: {str(e)}'
|
| 1037 |
+
}
|
| 1038 |
+
|
| 1039 |
+
except Exception as e:
|
| 1040 |
+
log.error(f"从回调URL完成认证流程失败: {e}")
|
| 1041 |
+
return {
|
| 1042 |
+
'success': False,
|
| 1043 |
+
'error': str(e)
|
| 1044 |
+
}
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
async def save_credentials(creds: Credentials, project_id: str) -> str:
|
| 1048 |
+
"""通过统一存储系统保存凭证"""
|
| 1049 |
+
# 生成文件名(使用project_id和时间戳)
|
| 1050 |
+
timestamp = int(time.time())
|
| 1051 |
+
filename = f"{project_id}-{timestamp}.json"
|
| 1052 |
+
|
| 1053 |
+
# 准备凭证数据
|
| 1054 |
+
creds_data = {
|
| 1055 |
+
"client_id": CLIENT_ID,
|
| 1056 |
+
"client_secret": CLIENT_SECRET,
|
| 1057 |
+
"token": creds.access_token,
|
| 1058 |
+
"refresh_token": creds.refresh_token,
|
| 1059 |
+
"scopes": SCOPES,
|
| 1060 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
| 1061 |
+
"project_id": project_id
|
| 1062 |
+
}
|
| 1063 |
+
|
| 1064 |
+
if creds.expires_at:
|
| 1065 |
+
if creds.expires_at.tzinfo is None:
|
| 1066 |
+
expiry_utc = creds.expires_at.replace(tzinfo=timezone.utc)
|
| 1067 |
+
else:
|
| 1068 |
+
expiry_utc = creds.expires_at
|
| 1069 |
+
creds_data["expiry"] = expiry_utc.isoformat()
|
| 1070 |
+
|
| 1071 |
+
# 通过存储适配器保存
|
| 1072 |
+
storage_adapter = await get_storage_adapter()
|
| 1073 |
+
success = await storage_adapter.store_credential(filename, creds_data)
|
| 1074 |
+
|
| 1075 |
+
if success:
|
| 1076 |
+
# 创建默认状态记录
|
| 1077 |
+
try:
|
| 1078 |
+
default_state = {
|
| 1079 |
+
"error_codes": [],
|
| 1080 |
+
"disabled": False,
|
| 1081 |
+
"last_success": time.time(),
|
| 1082 |
+
"user_email": None,
|
| 1083 |
+
"gemini_2_5_pro_calls": 0,
|
| 1084 |
+
"total_calls": 0,
|
| 1085 |
+
"next_reset_time": None,
|
| 1086 |
+
"daily_limit_gemini_2_5_pro": 100,
|
| 1087 |
+
"daily_limit_total": 1000
|
| 1088 |
+
}
|
| 1089 |
+
await storage_adapter.update_credential_state(filename, default_state)
|
| 1090 |
+
log.info(f"凭证和状态已保存到: {filename}")
|
| 1091 |
+
except Exception as e:
|
| 1092 |
+
log.warning(f"创建默认状态记录失败 {filename}: {e}")
|
| 1093 |
+
|
| 1094 |
+
return filename
|
| 1095 |
+
else:
|
| 1096 |
+
raise Exception(f"保存凭证失败: {filename}")
|
| 1097 |
+
|
| 1098 |
+
|
| 1099 |
+
def async_shutdown_server(server, port):
|
| 1100 |
+
"""异步关闭OAuth回调服务器,避免阻塞主流程"""
|
| 1101 |
+
def shutdown_server_async():
|
| 1102 |
+
try:
|
| 1103 |
+
# 设置一个标志来跟踪关闭状态
|
| 1104 |
+
shutdown_completed = threading.Event()
|
| 1105 |
+
|
| 1106 |
+
def do_shutdown():
|
| 1107 |
+
try:
|
| 1108 |
+
server.shutdown()
|
| 1109 |
+
server.server_close()
|
| 1110 |
+
shutdown_completed.set()
|
| 1111 |
+
log.info(f"已关闭端口 {port} 的OAuth回调服务器")
|
| 1112 |
+
except Exception as e:
|
| 1113 |
+
shutdown_completed.set()
|
| 1114 |
+
log.debug(f"关闭服务器时出错: {e}")
|
| 1115 |
+
|
| 1116 |
+
# 在单独线程中执行关闭操作
|
| 1117 |
+
shutdown_worker = threading.Thread(target=do_shutdown, daemon=True)
|
| 1118 |
+
shutdown_worker.start()
|
| 1119 |
+
|
| 1120 |
+
# 等待最多5秒,如果超时就放弃等待
|
| 1121 |
+
if shutdown_completed.wait(timeout=5):
|
| 1122 |
+
log.debug(f"端口 {port} 服务器关闭完成")
|
| 1123 |
+
else:
|
| 1124 |
+
log.warning(f"端口 {port} 服务器关闭超时,但不阻塞主流程")
|
| 1125 |
+
|
| 1126 |
+
except Exception as e:
|
| 1127 |
+
log.debug(f"异步关闭服务器时出错: {e}")
|
| 1128 |
+
|
| 1129 |
+
# 在后台线程中关闭服务器,不阻塞主流程
|
| 1130 |
+
shutdown_thread = threading.Thread(target=shutdown_server_async, daemon=True)
|
| 1131 |
+
shutdown_thread.start()
|
| 1132 |
+
log.debug(f"开始异步关闭端口 {port} 的OAuth回调服务器")
|
| 1133 |
+
|
| 1134 |
+
def cleanup_expired_flows():
|
| 1135 |
+
"""清理过期的认证流程"""
|
| 1136 |
+
current_time = time.time()
|
| 1137 |
+
EXPIRY_TIME = 600 # 10分钟过期
|
| 1138 |
+
|
| 1139 |
+
# 直接遍历删除,避免创建额外列表
|
| 1140 |
+
states_to_remove = [
|
| 1141 |
+
state for state, flow_data in auth_flows.items()
|
| 1142 |
+
if current_time - flow_data['created_at'] > EXPIRY_TIME
|
| 1143 |
+
]
|
| 1144 |
+
|
| 1145 |
+
# 批量清理,提高效率
|
| 1146 |
+
cleaned_count = 0
|
| 1147 |
+
for state in states_to_remove:
|
| 1148 |
+
flow_data = auth_flows.get(state)
|
| 1149 |
+
if flow_data:
|
| 1150 |
+
# 快速关闭可能存在的服务器
|
| 1151 |
+
try:
|
| 1152 |
+
if flow_data.get('server'):
|
| 1153 |
+
server = flow_data['server']
|
| 1154 |
+
port = flow_data.get('callback_port')
|
| 1155 |
+
async_shutdown_server(server, port)
|
| 1156 |
+
except Exception as e:
|
| 1157 |
+
log.debug(f"清理过期流程时启动异步关闭服务器失败: {e}")
|
| 1158 |
+
|
| 1159 |
+
# 显式清理流程数据,释放内存
|
| 1160 |
+
flow_data.clear()
|
| 1161 |
+
del auth_flows[state]
|
| 1162 |
+
cleaned_count += 1
|
| 1163 |
+
|
| 1164 |
+
if cleaned_count > 0:
|
| 1165 |
+
log.info(f"清理了 {cleaned_count} 个过期的认证流程")
|
| 1166 |
+
|
| 1167 |
+
# 更积极的垃圾回收触发条件
|
| 1168 |
+
if len(auth_flows) > 20: # 降低阈值
|
| 1169 |
+
import gc
|
| 1170 |
+
gc.collect()
|
| 1171 |
+
log.debug(f"触发垃圾回收,当前活跃认证流程数: {len(auth_flows)}")
|
| 1172 |
+
|
| 1173 |
+
|
| 1174 |
+
def get_auth_status(project_id: str) -> Dict[str, Any]:
|
| 1175 |
+
"""获取认证状态"""
|
| 1176 |
+
for state, flow_data in auth_flows.items():
|
| 1177 |
+
if flow_data['project_id'] == project_id:
|
| 1178 |
+
return {
|
| 1179 |
+
'status': 'completed' if flow_data['completed'] else 'pending',
|
| 1180 |
+
'state': state,
|
| 1181 |
+
'created_at': flow_data['created_at']
|
| 1182 |
+
}
|
| 1183 |
+
|
| 1184 |
+
return {
|
| 1185 |
+
'status': 'not_found'
|
| 1186 |
+
}
|
| 1187 |
+
|
| 1188 |
+
|
| 1189 |
+
# 鉴权功能 - 使用更小的数据结构
|
| 1190 |
+
auth_tokens = {} # 存储有效的认证令牌
|
| 1191 |
+
TOKEN_EXPIRY = 3600 # 1小时令牌过期时间
|
| 1192 |
+
|
| 1193 |
+
|
| 1194 |
+
async def verify_password(password: str) -> bool:
|
| 1195 |
+
"""验证密码(面板登录使用)"""
|
| 1196 |
+
from config import get_panel_password
|
| 1197 |
+
correct_password = await get_panel_password()
|
| 1198 |
+
return password == correct_password
|
| 1199 |
+
|
| 1200 |
+
|
| 1201 |
+
def generate_auth_token() -> str:
|
| 1202 |
+
"""生成认证令牌"""
|
| 1203 |
+
# 清理过期令牌
|
| 1204 |
+
cleanup_expired_tokens()
|
| 1205 |
+
|
| 1206 |
+
token = secrets.token_urlsafe(32)
|
| 1207 |
+
# 只存储创建时间
|
| 1208 |
+
auth_tokens[token] = time.time()
|
| 1209 |
+
return token
|
| 1210 |
+
|
| 1211 |
+
|
| 1212 |
+
def verify_auth_token(token: str) -> bool:
|
| 1213 |
+
"""验证认证令牌"""
|
| 1214 |
+
if not token or token not in auth_tokens:
|
| 1215 |
+
return False
|
| 1216 |
+
|
| 1217 |
+
created_at = auth_tokens[token]
|
| 1218 |
+
|
| 1219 |
+
# 检查令牌是否过期 (使用更短的过期时间)
|
| 1220 |
+
if time.time() - created_at > TOKEN_EXPIRY:
|
| 1221 |
+
del auth_tokens[token]
|
| 1222 |
+
return False
|
| 1223 |
+
|
| 1224 |
+
return True
|
| 1225 |
+
|
| 1226 |
+
|
| 1227 |
+
def cleanup_expired_tokens():
|
| 1228 |
+
"""清理过期的认证令牌"""
|
| 1229 |
+
current_time = time.time()
|
| 1230 |
+
expired_tokens = [
|
| 1231 |
+
token for token, created_at in auth_tokens.items()
|
| 1232 |
+
if current_time - created_at > TOKEN_EXPIRY
|
| 1233 |
+
]
|
| 1234 |
+
|
| 1235 |
+
for token in expired_tokens:
|
| 1236 |
+
del auth_tokens[token]
|
| 1237 |
+
|
| 1238 |
+
if expired_tokens:
|
| 1239 |
+
log.debug(f"清理了 {len(expired_tokens)} 个过期的认证令牌")
|
| 1240 |
+
|
| 1241 |
+
def invalidate_auth_token(token: str):
|
| 1242 |
+
"""使认证令牌失效"""
|
| 1243 |
+
if token in auth_tokens:
|
| 1244 |
+
del auth_tokens[token]
|
| 1245 |
+
|
| 1246 |
+
|
| 1247 |
+
# 文件验证和处理功能 - 使用统一存储系统
|
| 1248 |
+
def validate_credential_content(content: str) -> Dict[str, Any]:
|
| 1249 |
+
"""验证凭证内容格式"""
|
| 1250 |
+
try:
|
| 1251 |
+
creds_data = json.loads(content)
|
| 1252 |
+
|
| 1253 |
+
# 检查必要字段
|
| 1254 |
+
required_fields = ['client_id', 'client_secret', 'refresh_token', 'token_uri']
|
| 1255 |
+
missing_fields = [field for field in required_fields if field not in creds_data]
|
| 1256 |
+
|
| 1257 |
+
if missing_fields:
|
| 1258 |
+
return {
|
| 1259 |
+
'valid': False,
|
| 1260 |
+
'error': f'缺少必要字段: {", ".join(missing_fields)}'
|
| 1261 |
+
}
|
| 1262 |
+
|
| 1263 |
+
# 检查project_id
|
| 1264 |
+
if 'project_id' not in creds_data:
|
| 1265 |
+
log.warning("��证文件缺少project_id字段")
|
| 1266 |
+
|
| 1267 |
+
return {
|
| 1268 |
+
'valid': True,
|
| 1269 |
+
'data': creds_data
|
| 1270 |
+
}
|
| 1271 |
+
|
| 1272 |
+
except json.JSONDecodeError as e:
|
| 1273 |
+
return {
|
| 1274 |
+
'valid': False,
|
| 1275 |
+
'error': f'JSON格式错误: {str(e)}'
|
| 1276 |
+
}
|
| 1277 |
+
except Exception as e:
|
| 1278 |
+
return {
|
| 1279 |
+
'valid': False,
|
| 1280 |
+
'error': f'文件验证失败: {str(e)}'
|
| 1281 |
+
}
|
| 1282 |
+
|
| 1283 |
+
|
| 1284 |
+
async def save_uploaded_credential(content: str, original_filename: str) -> Dict[str, Any]:
|
| 1285 |
+
"""通过统一存储系统保存上传的凭证"""
|
| 1286 |
+
try:
|
| 1287 |
+
# 验证内容格式
|
| 1288 |
+
validation = validate_credential_content(content)
|
| 1289 |
+
if not validation['valid']:
|
| 1290 |
+
return {
|
| 1291 |
+
'success': False,
|
| 1292 |
+
'error': validation['error']
|
| 1293 |
+
}
|
| 1294 |
+
|
| 1295 |
+
creds_data = validation['data']
|
| 1296 |
+
|
| 1297 |
+
# 生成文件名
|
| 1298 |
+
project_id = creds_data.get('project_id', 'unknown')
|
| 1299 |
+
timestamp = int(time.time())
|
| 1300 |
+
|
| 1301 |
+
# 从原文件名中提取有用信息
|
| 1302 |
+
import os
|
| 1303 |
+
base_name = os.path.splitext(original_filename)[0]
|
| 1304 |
+
filename = f"{base_name}-{timestamp}.json"
|
| 1305 |
+
|
| 1306 |
+
# 通过存储适配器保存
|
| 1307 |
+
storage_adapter = await get_storage_adapter()
|
| 1308 |
+
success = await storage_adapter.store_credential(filename, creds_data)
|
| 1309 |
+
|
| 1310 |
+
if success:
|
| 1311 |
+
log.info(f"凭证文件已上传保存: {filename}")
|
| 1312 |
+
return {
|
| 1313 |
+
'success': True,
|
| 1314 |
+
'file_path': filename,
|
| 1315 |
+
'project_id': project_id
|
| 1316 |
+
}
|
| 1317 |
+
else:
|
| 1318 |
+
return {
|
| 1319 |
+
'success': False,
|
| 1320 |
+
'error': '保存到存储系统失败'
|
| 1321 |
+
}
|
| 1322 |
+
|
| 1323 |
+
except Exception as e:
|
| 1324 |
+
log.error(f"保存上传文件失败: {e}")
|
| 1325 |
+
return {
|
| 1326 |
+
'success': False,
|
| 1327 |
+
'error': str(e)
|
| 1328 |
+
}
|
| 1329 |
+
|
| 1330 |
+
|
| 1331 |
+
async def batch_upload_credentials(files_data: List[Dict[str, str]]) -> Dict[str, Any]:
|
| 1332 |
+
"""批量上传凭证文件到统一存储系统"""
|
| 1333 |
+
results = []
|
| 1334 |
+
success_count = 0
|
| 1335 |
+
|
| 1336 |
+
for file_data in files_data:
|
| 1337 |
+
filename = file_data.get('filename', 'unknown.json')
|
| 1338 |
+
content = file_data.get('content', '')
|
| 1339 |
+
|
| 1340 |
+
result = await save_uploaded_credential(content, filename)
|
| 1341 |
+
result['filename'] = filename
|
| 1342 |
+
results.append(result)
|
| 1343 |
+
|
| 1344 |
+
if result['success']:
|
| 1345 |
+
success_count += 1
|
| 1346 |
+
|
| 1347 |
+
return {
|
| 1348 |
+
'uploaded_count': success_count,
|
| 1349 |
+
'total_count': len(files_data),
|
| 1350 |
+
'results': results
|
| 1351 |
+
}
|
| 1352 |
+
|
| 1353 |
+
|
| 1354 |
+
# 环境变量批量导入功能 - 使用统一存储系统
|
| 1355 |
+
async def load_credentials_from_env() -> Dict[str, Any]:
|
| 1356 |
+
"""
|
| 1357 |
+
从环境变量加载多个凭证文件到统一存储系统
|
| 1358 |
+
支持两种环境变量格式:
|
| 1359 |
+
1. GCLI_CREDS_1, GCLI_CREDS_2, ... (编号格式)
|
| 1360 |
+
2. GCLI_CREDS_projectname1, GCLI_CREDS_projectname2, ... (项目名格式)
|
| 1361 |
+
"""
|
| 1362 |
+
import os
|
| 1363 |
+
|
| 1364 |
+
results = []
|
| 1365 |
+
success_count = 0
|
| 1366 |
+
|
| 1367 |
+
log.info("开始从环境变量加载认证凭证...")
|
| 1368 |
+
|
| 1369 |
+
# 获取所有以GCLI_CREDS_开头的环境变量
|
| 1370 |
+
creds_env_vars = {key: value for key, value in os.environ.items()
|
| 1371 |
+
if key.startswith('GCLI_CREDS_') and value.strip()}
|
| 1372 |
+
|
| 1373 |
+
if not creds_env_vars:
|
| 1374 |
+
log.info("未找到GCLI_CREDS_*环境变量")
|
| 1375 |
+
return {
|
| 1376 |
+
'loaded_count': 0,
|
| 1377 |
+
'total_count': 0,
|
| 1378 |
+
'results': [],
|
| 1379 |
+
'message': '未找到GCLI_CREDS_*环境变量'
|
| 1380 |
+
}
|
| 1381 |
+
|
| 1382 |
+
log.info(f"找到 {len(creds_env_vars)} 个凭证环境变量")
|
| 1383 |
+
|
| 1384 |
+
# 获取存储适配器
|
| 1385 |
+
storage_adapter = await get_storage_adapter()
|
| 1386 |
+
|
| 1387 |
+
for env_name, creds_content in creds_env_vars.items():
|
| 1388 |
+
# 从环境变量名提取标识符
|
| 1389 |
+
identifier = env_name.replace('GCLI_CREDS_', '')
|
| 1390 |
+
|
| 1391 |
+
try:
|
| 1392 |
+
# 验证JSON格式
|
| 1393 |
+
validation = validate_credential_content(creds_content)
|
| 1394 |
+
if not validation['valid']:
|
| 1395 |
+
result = {
|
| 1396 |
+
'env_name': env_name,
|
| 1397 |
+
'identifier': identifier,
|
| 1398 |
+
'success': False,
|
| 1399 |
+
'error': validation['error']
|
| 1400 |
+
}
|
| 1401 |
+
results.append(result)
|
| 1402 |
+
log.error(f"环境变量 {env_name} 验证失败: {validation['error']}")
|
| 1403 |
+
continue
|
| 1404 |
+
|
| 1405 |
+
creds_data = validation['data']
|
| 1406 |
+
project_id = creds_data.get('project_id', 'unknown')
|
| 1407 |
+
|
| 1408 |
+
# 生成文件名 (使用标识符和项目ID)
|
| 1409 |
+
timestamp = int(time.time())
|
| 1410 |
+
if identifier.isdigit():
|
| 1411 |
+
# 如果标识符是数字,使用项目ID作为主要标识
|
| 1412 |
+
filename = f"env-{project_id}-{identifier}-{timestamp}.json"
|
| 1413 |
+
else:
|
| 1414 |
+
# 如果标识符是项目名,直接使用
|
| 1415 |
+
filename = f"env-{identifier}-{timestamp}.json"
|
| 1416 |
+
|
| 1417 |
+
# 通过存储适配器保存
|
| 1418 |
+
success = await storage_adapter.store_credential(filename, creds_data)
|
| 1419 |
+
|
| 1420 |
+
if success:
|
| 1421 |
+
result = {
|
| 1422 |
+
'env_name': env_name,
|
| 1423 |
+
'identifier': identifier,
|
| 1424 |
+
'success': True,
|
| 1425 |
+
'file_path': filename,
|
| 1426 |
+
'project_id': project_id,
|
| 1427 |
+
'filename': filename
|
| 1428 |
+
}
|
| 1429 |
+
results.append(result)
|
| 1430 |
+
success_count += 1
|
| 1431 |
+
|
| 1432 |
+
log.info(f"成功从环境变量 {env_name} 保存凭证到: {filename}")
|
| 1433 |
+
else:
|
| 1434 |
+
result = {
|
| 1435 |
+
'env_name': env_name,
|
| 1436 |
+
'identifier': identifier,
|
| 1437 |
+
'success': False,
|
| 1438 |
+
'error': '保存到存储系统失败'
|
| 1439 |
+
}
|
| 1440 |
+
results.append(result)
|
| 1441 |
+
log.error(f"环境变量 {env_name} 保存失败")
|
| 1442 |
+
|
| 1443 |
+
except Exception as e:
|
| 1444 |
+
result = {
|
| 1445 |
+
'env_name': env_name,
|
| 1446 |
+
'identifier': identifier,
|
| 1447 |
+
'success': False,
|
| 1448 |
+
'error': str(e)
|
| 1449 |
+
}
|
| 1450 |
+
results.append(result)
|
| 1451 |
+
log.error(f"处理环境变量 {env_name} 时发生错误: {e}")
|
| 1452 |
+
|
| 1453 |
+
message = f"成功导入 {success_count}/{len(creds_env_vars)} 个凭证文件"
|
| 1454 |
+
log.info(message)
|
| 1455 |
+
|
| 1456 |
+
return {
|
| 1457 |
+
'loaded_count': success_count,
|
| 1458 |
+
'total_count': len(creds_env_vars),
|
| 1459 |
+
'results': results,
|
| 1460 |
+
'message': message
|
| 1461 |
+
}
|
| 1462 |
+
|
| 1463 |
+
|
| 1464 |
+
async def auto_load_env_credentials_on_startup() -> None:
|
| 1465 |
+
"""
|
| 1466 |
+
程序启动时自动从环境变量加载凭证到统一存储系统
|
| 1467 |
+
如果设置了 AUTO_LOAD_ENV_CREDS=true,则会自动执行
|
| 1468 |
+
"""
|
| 1469 |
+
from config import get_auto_load_env_creds
|
| 1470 |
+
auto_load = await get_auto_load_env_creds()
|
| 1471 |
+
|
| 1472 |
+
if not auto_load:
|
| 1473 |
+
log.debug("AUTO_LOAD_ENV_CREDS未启用,跳过自动加载")
|
| 1474 |
+
return
|
| 1475 |
+
|
| 1476 |
+
log.info("AUTO_LOAD_ENV_CREDS已启用,开始自动加载环境变量中的凭证...")
|
| 1477 |
+
|
| 1478 |
+
try:
|
| 1479 |
+
result = await load_credentials_from_env()
|
| 1480 |
+
if result['loaded_count'] > 0:
|
| 1481 |
+
log.info(f"启动时成功自动导入 {result['loaded_count']} 个凭证文件")
|
| 1482 |
+
else:
|
| 1483 |
+
log.info("启动时未找到可导入的环境变量凭证")
|
| 1484 |
+
except Exception as e:
|
| 1485 |
+
log.error(f"启动时自动加载环境变量凭证失败: {e}")
|
| 1486 |
+
|
| 1487 |
+
|
| 1488 |
+
async def clear_env_credentials() -> Dict[str, Any]:
|
| 1489 |
+
"""
|
| 1490 |
+
清除所有从环境变量导入的凭证文件
|
| 1491 |
+
仅删除文件名包含'env-'前缀的文件
|
| 1492 |
+
"""
|
| 1493 |
+
try:
|
| 1494 |
+
storage_adapter = await get_storage_adapter()
|
| 1495 |
+
|
| 1496 |
+
# 获取所有凭证
|
| 1497 |
+
all_credentials = await storage_adapter.list_credentials()
|
| 1498 |
+
|
| 1499 |
+
deleted_files = []
|
| 1500 |
+
deleted_count = 0
|
| 1501 |
+
|
| 1502 |
+
for credential_name in all_credentials:
|
| 1503 |
+
if credential_name.startswith('env-') and credential_name.endswith('.json'):
|
| 1504 |
+
try:
|
| 1505 |
+
success = await storage_adapter.delete_credential(credential_name)
|
| 1506 |
+
if success:
|
| 1507 |
+
deleted_files.append(credential_name)
|
| 1508 |
+
deleted_count += 1
|
| 1509 |
+
log.info(f"删除环境变量凭证文件: {credential_name}")
|
| 1510 |
+
else:
|
| 1511 |
+
log.error(f"删除文件 {credential_name} 失败")
|
| 1512 |
+
except Exception as e:
|
| 1513 |
+
log.error(f"删除文件 {credential_name} 失败: {e}")
|
| 1514 |
+
|
| 1515 |
+
message = f"成功删除 {deleted_count} 个环境变量凭证文件"
|
| 1516 |
+
log.info(message)
|
| 1517 |
+
|
| 1518 |
+
return {
|
| 1519 |
+
'deleted_count': deleted_count,
|
| 1520 |
+
'deleted_files': deleted_files,
|
| 1521 |
+
'message': message
|
| 1522 |
+
}
|
| 1523 |
+
|
| 1524 |
+
except Exception as e:
|
| 1525 |
+
error_message = f"清除环境变量凭证文件时发生错误: {e}"
|
| 1526 |
+
log.error(error_message)
|
| 1527 |
+
return {
|
| 1528 |
+
'deleted_count': 0,
|
| 1529 |
+
'error': error_message
|
| 1530 |
+
}
|
src/credential_manager.py
ADDED
|
@@ -0,0 +1,611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
凭证管理器 - 完全基于统一存储中间层
|
| 3 |
+
"""
|
| 4 |
+
import asyncio
|
| 5 |
+
import time
|
| 6 |
+
from datetime import datetime, timezone
|
| 7 |
+
from typing import Dict, Any, List, Optional, Tuple
|
| 8 |
+
from contextlib import asynccontextmanager
|
| 9 |
+
|
| 10 |
+
from config import get_calls_per_rotation, is_mongodb_mode
|
| 11 |
+
from log import log
|
| 12 |
+
from .storage_adapter import get_storage_adapter
|
| 13 |
+
from .google_oauth_api import fetch_user_email_from_file, Credentials
|
| 14 |
+
from .task_manager import task_manager
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CredentialManager:
|
| 18 |
+
"""
|
| 19 |
+
统一凭证管理器
|
| 20 |
+
所有存储操作通过storage_adapter进行
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
# 核心状态
|
| 25 |
+
self._initialized = False
|
| 26 |
+
self._storage_adapter = None
|
| 27 |
+
|
| 28 |
+
# 凭证轮换相关
|
| 29 |
+
self._credential_files: List[str] = [] # 存储凭证文件名列表
|
| 30 |
+
self._current_credential_index = 0
|
| 31 |
+
self._call_count = 0
|
| 32 |
+
self._last_scan_time = 0
|
| 33 |
+
|
| 34 |
+
# 当前使用的凭证信息
|
| 35 |
+
self._current_credential_file: Optional[str] = None
|
| 36 |
+
self._current_credential_data: Optional[Dict[str, Any]] = None
|
| 37 |
+
self._current_credential_state: Dict[str, Any] = {}
|
| 38 |
+
|
| 39 |
+
# 并发控制
|
| 40 |
+
self._state_lock = asyncio.Lock()
|
| 41 |
+
self._operation_lock = asyncio.Lock()
|
| 42 |
+
|
| 43 |
+
# 工作线程控制
|
| 44 |
+
self._shutdown_event = asyncio.Event()
|
| 45 |
+
self._write_worker_running = False
|
| 46 |
+
self._write_worker_task = None
|
| 47 |
+
|
| 48 |
+
# 原子操作计数器
|
| 49 |
+
self._atomic_counter = 0
|
| 50 |
+
self._atomic_lock = asyncio.Lock()
|
| 51 |
+
|
| 52 |
+
# Onboarding state
|
| 53 |
+
self._onboarding_complete = False
|
| 54 |
+
self._onboarding_checked = False
|
| 55 |
+
|
| 56 |
+
async def initialize(self):
|
| 57 |
+
"""初始化凭证管理器"""
|
| 58 |
+
async with self._state_lock:
|
| 59 |
+
if self._initialized:
|
| 60 |
+
return
|
| 61 |
+
|
| 62 |
+
# 初始化统一存储适配器
|
| 63 |
+
self._storage_adapter = await get_storage_adapter()
|
| 64 |
+
|
| 65 |
+
# 启动后台工作线程
|
| 66 |
+
await self._start_background_workers()
|
| 67 |
+
|
| 68 |
+
# 发现并加载凭证
|
| 69 |
+
await self._discover_credentials()
|
| 70 |
+
|
| 71 |
+
self._initialized = True
|
| 72 |
+
storage_type = "MongoDB" if await is_mongodb_mode() else "File"
|
| 73 |
+
log.debug(f"Credential manager initialized with {storage_type} storage backend")
|
| 74 |
+
|
| 75 |
+
async def close(self):
|
| 76 |
+
"""清理资源"""
|
| 77 |
+
log.debug("Closing credential manager...")
|
| 78 |
+
|
| 79 |
+
# 设置关闭标志
|
| 80 |
+
self._shutdown_event.set()
|
| 81 |
+
|
| 82 |
+
# 等待后台任务结束
|
| 83 |
+
if self._write_worker_task:
|
| 84 |
+
try:
|
| 85 |
+
await asyncio.wait_for(self._write_worker_task, timeout=5.0)
|
| 86 |
+
except asyncio.TimeoutError:
|
| 87 |
+
log.warning("Write worker task did not finish within timeout")
|
| 88 |
+
if not self._write_worker_task.done():
|
| 89 |
+
self._write_worker_task.cancel()
|
| 90 |
+
|
| 91 |
+
self._initialized = False
|
| 92 |
+
log.debug("Credential manager closed")
|
| 93 |
+
|
| 94 |
+
async def _start_background_workers(self):
|
| 95 |
+
"""启动后台工作线程"""
|
| 96 |
+
if not self._write_worker_running:
|
| 97 |
+
self._write_worker_running = True
|
| 98 |
+
self._write_worker_task = task_manager.create_task(
|
| 99 |
+
self._background_worker(),
|
| 100 |
+
name="credential_background_worker"
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
async def _background_worker(self):
|
| 104 |
+
"""后台工作线程,处理定期任务"""
|
| 105 |
+
while not self._shutdown_event.is_set():
|
| 106 |
+
try:
|
| 107 |
+
# 每60秒检查一次凭证更新
|
| 108 |
+
await asyncio.wait_for(self._shutdown_event.wait(), timeout=60.0)
|
| 109 |
+
if self._shutdown_event.is_set():
|
| 110 |
+
break
|
| 111 |
+
|
| 112 |
+
# 重新发现凭证(热更新)
|
| 113 |
+
await self._discover_credentials()
|
| 114 |
+
|
| 115 |
+
except asyncio.TimeoutError:
|
| 116 |
+
# 超时是正常的,继续下一轮
|
| 117 |
+
continue
|
| 118 |
+
except Exception as e:
|
| 119 |
+
log.error(f"Background worker error: {e}")
|
| 120 |
+
await asyncio.sleep(5) # 错误后等待5秒再继续
|
| 121 |
+
|
| 122 |
+
async def _discover_credentials(self):
|
| 123 |
+
"""发现和加载所有可用凭证"""
|
| 124 |
+
try:
|
| 125 |
+
# 从存储适配器获取所有凭证
|
| 126 |
+
all_credentials = await self._storage_adapter.list_credentials()
|
| 127 |
+
|
| 128 |
+
# 过滤出可用的凭证(排除被禁用的)- 批量读取状态以提升性能
|
| 129 |
+
available_credentials = []
|
| 130 |
+
|
| 131 |
+
# 批量获取所有凭证状态,避免多次读取状态文件
|
| 132 |
+
if all_credentials:
|
| 133 |
+
try:
|
| 134 |
+
all_states = await self._storage_adapter.get_all_credential_states()
|
| 135 |
+
|
| 136 |
+
for credential_name in all_credentials:
|
| 137 |
+
normalized_name = credential_name
|
| 138 |
+
# 标准化文件名以匹配状态数据中的键
|
| 139 |
+
if hasattr(self._storage_adapter._backend, '_normalize_filename'):
|
| 140 |
+
normalized_name = self._storage_adapter._backend._normalize_filename(credential_name)
|
| 141 |
+
|
| 142 |
+
state = all_states.get(normalized_name, {})
|
| 143 |
+
if not state.get("disabled", False):
|
| 144 |
+
available_credentials.append(credential_name)
|
| 145 |
+
except Exception as e:
|
| 146 |
+
log.warning(f"Failed to batch load credential states, falling back to individual checks: {e}")
|
| 147 |
+
# 如果批量读取失败,回退到逐个检查
|
| 148 |
+
for credential_name in all_credentials:
|
| 149 |
+
try:
|
| 150 |
+
state = await self._storage_adapter.get_credential_state(credential_name)
|
| 151 |
+
if not state.get("disabled", False):
|
| 152 |
+
available_credentials.append(credential_name)
|
| 153 |
+
except Exception as e2:
|
| 154 |
+
log.warning(f"Failed to check state for credential {credential_name}: {e2}")
|
| 155 |
+
|
| 156 |
+
# 更新凭证列表
|
| 157 |
+
old_credentials = set(self._credential_files)
|
| 158 |
+
new_credentials = set(available_credentials)
|
| 159 |
+
|
| 160 |
+
if old_credentials != new_credentials:
|
| 161 |
+
# 记录变化(只在非初始状态时记录)
|
| 162 |
+
is_initial_load = len(old_credentials) == 0
|
| 163 |
+
added = new_credentials - old_credentials
|
| 164 |
+
removed = old_credentials - new_credentials
|
| 165 |
+
|
| 166 |
+
self._credential_files = available_credentials
|
| 167 |
+
|
| 168 |
+
# 初始加载时只记录调试信息,运行时变化才记录INFO
|
| 169 |
+
if not is_initial_load:
|
| 170 |
+
if added:
|
| 171 |
+
log.info(f"发现新的可用凭证: {list(added)}")
|
| 172 |
+
if removed:
|
| 173 |
+
log.info(f"移除不可用凭证: {list(removed)}")
|
| 174 |
+
else:
|
| 175 |
+
# 初始加载时只记录调试信息
|
| 176 |
+
if available_credentials:
|
| 177 |
+
log.debug(f"初始加载发现 {len(available_credentials)} 个可用凭证")
|
| 178 |
+
|
| 179 |
+
# 重置当前索引如果需要
|
| 180 |
+
if self._current_credential_index >= len(self._credential_files):
|
| 181 |
+
self._current_credential_index = 0
|
| 182 |
+
|
| 183 |
+
if not self._credential_files:
|
| 184 |
+
log.warning("No available credential files found")
|
| 185 |
+
else:
|
| 186 |
+
log.debug(f"Available credentials: {len(self._credential_files)} files")
|
| 187 |
+
|
| 188 |
+
except Exception as e:
|
| 189 |
+
log.error(f"Failed to discover credentials: {e}")
|
| 190 |
+
|
| 191 |
+
async def _load_current_credential(self) -> Optional[Tuple[str, Dict[str, Any]]]:
|
| 192 |
+
"""加载当前选中的凭证数据,包含token过期检测和自动刷新"""
|
| 193 |
+
if not self._credential_files:
|
| 194 |
+
return None
|
| 195 |
+
|
| 196 |
+
try:
|
| 197 |
+
current_file = self._credential_files[self._current_credential_index]
|
| 198 |
+
|
| 199 |
+
# 从存储适配器加载凭证数据
|
| 200 |
+
credential_data = await self._storage_adapter.get_credential(current_file)
|
| 201 |
+
if not credential_data:
|
| 202 |
+
log.error(f"Failed to load credential data for: {current_file}")
|
| 203 |
+
return None
|
| 204 |
+
|
| 205 |
+
# 检查refresh_token
|
| 206 |
+
if "refresh_token" not in credential_data or not credential_data["refresh_token"]:
|
| 207 |
+
log.warning(f"No refresh token in {current_file}")
|
| 208 |
+
return None
|
| 209 |
+
|
| 210 |
+
# Auto-add 'type' field if missing but has required OAuth fields
|
| 211 |
+
if 'type' not in credential_data and all(key in credential_data for key in ['client_id', 'refresh_token']):
|
| 212 |
+
credential_data['type'] = 'authorized_user'
|
| 213 |
+
log.debug(f"Auto-added 'type' field to credential from file {current_file}")
|
| 214 |
+
|
| 215 |
+
# 兼容不同的token字段格式
|
| 216 |
+
if "access_token" in credential_data and "token" not in credential_data:
|
| 217 |
+
credential_data["token"] = credential_data["access_token"]
|
| 218 |
+
if "scope" in credential_data and "scopes" not in credential_data:
|
| 219 |
+
credential_data["scopes"] = credential_data["scope"].split()
|
| 220 |
+
|
| 221 |
+
# token过期检测和刷新
|
| 222 |
+
should_refresh = await self._should_refresh_token(credential_data)
|
| 223 |
+
|
| 224 |
+
if should_refresh:
|
| 225 |
+
log.debug(f"Token需要刷新 - 文件: {current_file}")
|
| 226 |
+
refreshed_data = await self._refresh_token(credential_data, current_file)
|
| 227 |
+
if refreshed_data:
|
| 228 |
+
credential_data = refreshed_data
|
| 229 |
+
log.debug(f"Token刷新成功: {current_file}")
|
| 230 |
+
else:
|
| 231 |
+
log.error(f"Token刷新失败: {current_file}")
|
| 232 |
+
return None
|
| 233 |
+
|
| 234 |
+
# 加载状态信息
|
| 235 |
+
state_data = await self._storage_adapter.get_credential_state(current_file)
|
| 236 |
+
|
| 237 |
+
# 缓存当前凭证信息
|
| 238 |
+
self._current_credential_file = current_file
|
| 239 |
+
self._current_credential_data = credential_data
|
| 240 |
+
self._current_credential_state = state_data
|
| 241 |
+
|
| 242 |
+
return current_file, credential_data
|
| 243 |
+
|
| 244 |
+
except Exception as e:
|
| 245 |
+
log.error(f"Error loading current credential: {e}")
|
| 246 |
+
return None
|
| 247 |
+
|
| 248 |
+
async def get_valid_credential(self) -> Optional[Tuple[str, Dict[str, Any]]]:
|
| 249 |
+
"""获取有效的凭证,自动处理轮换和失效凭证切换"""
|
| 250 |
+
async with self._operation_lock:
|
| 251 |
+
if not self._credential_files:
|
| 252 |
+
await self._discover_credentials()
|
| 253 |
+
if not self._credential_files:
|
| 254 |
+
return None
|
| 255 |
+
|
| 256 |
+
# 检查是否需要轮换
|
| 257 |
+
if await self._should_rotate():
|
| 258 |
+
await self._rotate_credential()
|
| 259 |
+
|
| 260 |
+
# 尝试获取有效凭证,如果失败则自动切换
|
| 261 |
+
max_attempts = len(self._credential_files) # 最多尝试所有凭证
|
| 262 |
+
|
| 263 |
+
for attempt in range(max_attempts):
|
| 264 |
+
try:
|
| 265 |
+
# 加载当前凭证
|
| 266 |
+
result = await self._load_current_credential()
|
| 267 |
+
if result:
|
| 268 |
+
return result
|
| 269 |
+
|
| 270 |
+
# 当前凭证加载失败,标记为失效并切换到下一个
|
| 271 |
+
current_file = self._credential_files[self._current_credential_index] if self._credential_files else None
|
| 272 |
+
if current_file:
|
| 273 |
+
log.warning(f"凭证失效,自动禁用并切换: {current_file}")
|
| 274 |
+
await self.set_cred_disabled(current_file, True)
|
| 275 |
+
|
| 276 |
+
# 重新发现可用凭证(排除刚禁用的)
|
| 277 |
+
await self._discover_credentials()
|
| 278 |
+
if not self._credential_files:
|
| 279 |
+
log.error("没有可用的凭证")
|
| 280 |
+
return None
|
| 281 |
+
|
| 282 |
+
# 重置索引到第一个可用凭证
|
| 283 |
+
self._current_credential_index = 0
|
| 284 |
+
log.info(f"切换到下一个可用凭证 (索引: {self._current_credential_index})")
|
| 285 |
+
else:
|
| 286 |
+
log.error("无法获取当前凭证文件名")
|
| 287 |
+
break
|
| 288 |
+
|
| 289 |
+
except Exception as e:
|
| 290 |
+
log.error(f"获取凭证时发生异常 (尝试 {attempt + 1}/{max_attempts}): {e}")
|
| 291 |
+
if attempt < max_attempts - 1:
|
| 292 |
+
# 切换到下一个凭证继续尝试
|
| 293 |
+
await self._rotate_credential()
|
| 294 |
+
continue
|
| 295 |
+
|
| 296 |
+
log.error(f"所有 {max_attempts} 个凭证都尝试失败")
|
| 297 |
+
return None
|
| 298 |
+
|
| 299 |
+
async def _should_rotate(self) -> bool:
|
| 300 |
+
"""检查是否需要轮换凭证"""
|
| 301 |
+
if not self._credential_files or len(self._credential_files) <= 1:
|
| 302 |
+
return False
|
| 303 |
+
|
| 304 |
+
current_calls_per_rotation = await get_calls_per_rotation()
|
| 305 |
+
return self._call_count >= current_calls_per_rotation
|
| 306 |
+
|
| 307 |
+
async def _rotate_credential(self):
|
| 308 |
+
"""轮换到下一个凭证"""
|
| 309 |
+
if len(self._credential_files) <= 1:
|
| 310 |
+
return
|
| 311 |
+
|
| 312 |
+
self._current_credential_index = (self._current_credential_index + 1) % len(self._credential_files)
|
| 313 |
+
self._call_count = 0
|
| 314 |
+
|
| 315 |
+
log.info(f"Rotated to credential index {self._current_credential_index}")
|
| 316 |
+
|
| 317 |
+
async def force_rotate_credential(self):
|
| 318 |
+
"""强制轮换到下一个凭证(用于429错误处理)"""
|
| 319 |
+
async with self._operation_lock:
|
| 320 |
+
if len(self._credential_files) <= 1:
|
| 321 |
+
log.warning("Only one credential available, cannot rotate")
|
| 322 |
+
return
|
| 323 |
+
|
| 324 |
+
await self._rotate_credential()
|
| 325 |
+
log.info("Forced credential rotation due to rate limit")
|
| 326 |
+
|
| 327 |
+
def increment_call_count(self):
|
| 328 |
+
"""增加调用计数"""
|
| 329 |
+
self._call_count += 1
|
| 330 |
+
|
| 331 |
+
async def update_credential_state(self, credential_name: str, state_updates: Dict[str, Any]):
|
| 332 |
+
"""更新凭证状态"""
|
| 333 |
+
try:
|
| 334 |
+
# 直接通过存储适配器更新状态
|
| 335 |
+
success = await self._storage_adapter.update_credential_state(credential_name, state_updates)
|
| 336 |
+
|
| 337 |
+
# 如果是当前使用的凭证,更新缓存
|
| 338 |
+
if credential_name == self._current_credential_file:
|
| 339 |
+
self._current_credential_state.update(state_updates)
|
| 340 |
+
|
| 341 |
+
if success:
|
| 342 |
+
log.debug(f"Updated credential state: {credential_name}")
|
| 343 |
+
else:
|
| 344 |
+
log.warning(f"Failed to update credential state: {credential_name}")
|
| 345 |
+
|
| 346 |
+
return success
|
| 347 |
+
|
| 348 |
+
except Exception as e:
|
| 349 |
+
log.error(f"Error updating credential state {credential_name}: {e}")
|
| 350 |
+
return False
|
| 351 |
+
|
| 352 |
+
async def set_cred_disabled(self, credential_name: str, disabled: bool):
|
| 353 |
+
"""设置凭证的启用/禁用状态"""
|
| 354 |
+
try:
|
| 355 |
+
state_updates = {"disabled": disabled}
|
| 356 |
+
success = await self.update_credential_state(credential_name, state_updates)
|
| 357 |
+
|
| 358 |
+
if success:
|
| 359 |
+
# 如果禁用了当前正在使用的凭证,需要重新发现可用凭证
|
| 360 |
+
if disabled and credential_name == self._current_credential_file:
|
| 361 |
+
await self._discover_credentials()
|
| 362 |
+
if self._credential_files:
|
| 363 |
+
await self._rotate_credential()
|
| 364 |
+
|
| 365 |
+
action = "disabled" if disabled else "enabled"
|
| 366 |
+
log.info(f"Credential {action}: {credential_name}")
|
| 367 |
+
|
| 368 |
+
return success
|
| 369 |
+
|
| 370 |
+
except Exception as e:
|
| 371 |
+
log.error(f"Error setting credential disabled state {credential_name}: {e}")
|
| 372 |
+
return False
|
| 373 |
+
|
| 374 |
+
async def get_creds_status(self) -> Dict[str, Dict[str, Any]]:
|
| 375 |
+
"""获取所有凭证的状态"""
|
| 376 |
+
try:
|
| 377 |
+
# 从存储适配器获取所有状态
|
| 378 |
+
all_states = await self._storage_adapter.get_all_credential_states()
|
| 379 |
+
return all_states
|
| 380 |
+
|
| 381 |
+
except Exception as e:
|
| 382 |
+
log.error(f"Error getting credential statuses: {e}")
|
| 383 |
+
return {}
|
| 384 |
+
|
| 385 |
+
async def get_or_fetch_user_email(self, credential_name: str) -> Optional[str]:
|
| 386 |
+
"""获取或获取用户邮箱地址"""
|
| 387 |
+
try:
|
| 388 |
+
# 首先检查缓存的状态
|
| 389 |
+
state = await self._storage_adapter.get_credential_state(credential_name)
|
| 390 |
+
cached_email = state.get("user_email")
|
| 391 |
+
|
| 392 |
+
if cached_email:
|
| 393 |
+
return cached_email
|
| 394 |
+
|
| 395 |
+
# 如果没有缓存,从凭证数据获取
|
| 396 |
+
credential_data = await self._storage_adapter.get_credential(credential_name)
|
| 397 |
+
if not credential_data:
|
| 398 |
+
return None
|
| 399 |
+
|
| 400 |
+
# 尝试获取邮箱
|
| 401 |
+
email = await fetch_user_email_from_file(credential_data)
|
| 402 |
+
|
| 403 |
+
if email:
|
| 404 |
+
# 缓存邮箱地址
|
| 405 |
+
await self.update_credential_state(credential_name, {"user_email": email})
|
| 406 |
+
return email
|
| 407 |
+
|
| 408 |
+
return None
|
| 409 |
+
|
| 410 |
+
except Exception as e:
|
| 411 |
+
log.error(f"Error fetching user email for {credential_name}: {e}")
|
| 412 |
+
return None
|
| 413 |
+
|
| 414 |
+
async def record_api_call_result(self, credential_name: str, success: bool, error_code: Optional[int] = None):
|
| 415 |
+
"""记录API调用结果"""
|
| 416 |
+
try:
|
| 417 |
+
state_updates = {}
|
| 418 |
+
|
| 419 |
+
if success:
|
| 420 |
+
state_updates["last_success"] = time.time()
|
| 421 |
+
# 清除错误码(如果之前有的话)
|
| 422 |
+
state_updates["error_codes"] = []
|
| 423 |
+
elif error_code:
|
| 424 |
+
# 记录错误码
|
| 425 |
+
current_state = await self._storage_adapter.get_credential_state(credential_name)
|
| 426 |
+
error_codes = current_state.get("error_codes", [])
|
| 427 |
+
|
| 428 |
+
if error_code not in error_codes:
|
| 429 |
+
error_codes.append(error_code)
|
| 430 |
+
# 限制错误码列表长度
|
| 431 |
+
if len(error_codes) > 10:
|
| 432 |
+
error_codes = error_codes[-10:]
|
| 433 |
+
|
| 434 |
+
state_updates["error_codes"] = error_codes
|
| 435 |
+
|
| 436 |
+
if state_updates:
|
| 437 |
+
await self.update_credential_state(credential_name, state_updates)
|
| 438 |
+
|
| 439 |
+
except Exception as e:
|
| 440 |
+
log.error(f"Error recording API call result for {credential_name}: {e}")
|
| 441 |
+
|
| 442 |
+
# 原子操作支持
|
| 443 |
+
@asynccontextmanager
|
| 444 |
+
async def _atomic_operation(self, operation_name: str):
|
| 445 |
+
"""原子操作上下文管理器"""
|
| 446 |
+
async with self._atomic_lock:
|
| 447 |
+
self._atomic_counter += 1
|
| 448 |
+
operation_id = self._atomic_counter
|
| 449 |
+
log.debug(f"开始原子操作[{operation_id}]: {operation_name}")
|
| 450 |
+
|
| 451 |
+
try:
|
| 452 |
+
yield operation_id
|
| 453 |
+
log.debug(f"完成原子操作[{operation_id}]: {operation_name}")
|
| 454 |
+
except Exception as e:
|
| 455 |
+
log.error(f"原子操作[{operation_id}]失败: {operation_name} - {e}")
|
| 456 |
+
raise
|
| 457 |
+
|
| 458 |
+
async def _should_refresh_token(self, credential_data: Dict[str, Any]) -> bool:
|
| 459 |
+
"""检查token是否需要刷新"""
|
| 460 |
+
try:
|
| 461 |
+
# 如果没有access_token或过期时间,需要刷新
|
| 462 |
+
if not credential_data.get("access_token") and not credential_data.get("token"):
|
| 463 |
+
log.debug("没有access_token,需要刷新")
|
| 464 |
+
return True
|
| 465 |
+
|
| 466 |
+
expiry_str = credential_data.get("expiry")
|
| 467 |
+
if not expiry_str:
|
| 468 |
+
log.debug("没有过期时间,需要刷新")
|
| 469 |
+
return True
|
| 470 |
+
|
| 471 |
+
# 解析过期时间
|
| 472 |
+
try:
|
| 473 |
+
if isinstance(expiry_str, str):
|
| 474 |
+
if "+" in expiry_str:
|
| 475 |
+
file_expiry = datetime.fromisoformat(expiry_str)
|
| 476 |
+
elif expiry_str.endswith("Z"):
|
| 477 |
+
file_expiry = datetime.fromisoformat(expiry_str.replace('Z', '+00:00'))
|
| 478 |
+
else:
|
| 479 |
+
file_expiry = datetime.fromisoformat(expiry_str)
|
| 480 |
+
else:
|
| 481 |
+
log.debug("过期时间格式无效,需要刷新")
|
| 482 |
+
return True
|
| 483 |
+
|
| 484 |
+
# 确保时区信息
|
| 485 |
+
if file_expiry.tzinfo is None:
|
| 486 |
+
file_expiry = file_expiry.replace(tzinfo=timezone.utc)
|
| 487 |
+
|
| 488 |
+
# 检查是否还有至少5分钟有效期
|
| 489 |
+
now = datetime.now(timezone.utc)
|
| 490 |
+
time_left = (file_expiry - now).total_seconds()
|
| 491 |
+
|
| 492 |
+
log.debug(f"Token剩余时间: {int(time_left/60)}分钟")
|
| 493 |
+
|
| 494 |
+
if time_left > 300: # 5分钟缓冲
|
| 495 |
+
return False
|
| 496 |
+
else:
|
| 497 |
+
log.debug(f"Token即将过期(剩余{int(time_left/60)}分钟),需要刷新")
|
| 498 |
+
return True
|
| 499 |
+
|
| 500 |
+
except Exception as e:
|
| 501 |
+
log.warning(f"解析过期时间失败: {e},需要刷新")
|
| 502 |
+
return True
|
| 503 |
+
|
| 504 |
+
except Exception as e:
|
| 505 |
+
log.error(f"检查token过期时出错: {e}")
|
| 506 |
+
return True
|
| 507 |
+
|
| 508 |
+
async def _refresh_token(self, credential_data: Dict[str, Any], filename: str) -> Optional[Dict[str, Any]]:
|
| 509 |
+
"""刷新token并更新存储"""
|
| 510 |
+
try:
|
| 511 |
+
# 创建Credentials对象
|
| 512 |
+
creds = Credentials.from_dict(credential_data)
|
| 513 |
+
|
| 514 |
+
# 检查是否可以刷新
|
| 515 |
+
if not creds.refresh_token:
|
| 516 |
+
log.error(f"没有refresh_token,无法刷新: {filename}")
|
| 517 |
+
return None
|
| 518 |
+
|
| 519 |
+
# 刷新token
|
| 520 |
+
log.debug(f"正在刷新token: {filename}")
|
| 521 |
+
await creds.refresh()
|
| 522 |
+
|
| 523 |
+
# 更新凭证数据
|
| 524 |
+
if creds.access_token:
|
| 525 |
+
credential_data["access_token"] = creds.access_token
|
| 526 |
+
# 保持兼容性
|
| 527 |
+
credential_data["token"] = creds.access_token
|
| 528 |
+
|
| 529 |
+
if creds.expires_at:
|
| 530 |
+
credential_data["expiry"] = creds.expires_at.isoformat()
|
| 531 |
+
|
| 532 |
+
# 保存到存储
|
| 533 |
+
await self._storage_adapter.store_credential(filename, credential_data)
|
| 534 |
+
log.info(f"Token刷新成功并已保存: {filename}")
|
| 535 |
+
|
| 536 |
+
return credential_data
|
| 537 |
+
|
| 538 |
+
except Exception as e:
|
| 539 |
+
error_msg = str(e)
|
| 540 |
+
log.error(f"Token刷新失败 {filename}: {error_msg}")
|
| 541 |
+
|
| 542 |
+
# 检查是否是凭证永久失效的错误
|
| 543 |
+
is_permanent_failure = self._is_permanent_refresh_failure(error_msg)
|
| 544 |
+
|
| 545 |
+
if is_permanent_failure:
|
| 546 |
+
log.warning(f"检测到凭证永久失效: {filename}")
|
| 547 |
+
# 记录失效状态,但不在这里禁用凭证,让上层调用者处理
|
| 548 |
+
await self.record_api_call_result(filename, False, 400)
|
| 549 |
+
|
| 550 |
+
return None
|
| 551 |
+
|
| 552 |
+
def _is_permanent_refresh_failure(self, error_msg: str) -> bool:
|
| 553 |
+
"""判断是否是凭证永久失效的错误"""
|
| 554 |
+
# 常见的永久失效错误模式
|
| 555 |
+
permanent_error_patterns = [
|
| 556 |
+
"400 Bad Request",
|
| 557 |
+
"invalid_grant",
|
| 558 |
+
"refresh_token_expired",
|
| 559 |
+
"invalid_refresh_token",
|
| 560 |
+
"unauthorized_client",
|
| 561 |
+
"access_denied"
|
| 562 |
+
]
|
| 563 |
+
|
| 564 |
+
error_msg_lower = error_msg.lower()
|
| 565 |
+
for pattern in permanent_error_patterns:
|
| 566 |
+
if pattern.lower() in error_msg_lower:
|
| 567 |
+
return True
|
| 568 |
+
|
| 569 |
+
return False
|
| 570 |
+
|
| 571 |
+
# 兼容性方法 - 保持与现有代码的接口兼容
|
| 572 |
+
async def _update_token_in_file(self, file_path: str, new_token: str, expires_at=None):
|
| 573 |
+
"""更新凭证令牌(兼容性方法)"""
|
| 574 |
+
try:
|
| 575 |
+
credential_data = await self._storage_adapter.get_credential(file_path)
|
| 576 |
+
if not credential_data:
|
| 577 |
+
log.error(f"Credential not found for token update: {file_path}")
|
| 578 |
+
return False
|
| 579 |
+
|
| 580 |
+
# 更新令牌数据
|
| 581 |
+
credential_data["token"] = new_token
|
| 582 |
+
if expires_at:
|
| 583 |
+
credential_data["expiry"] = expires_at.isoformat() if hasattr(expires_at, 'isoformat') else expires_at
|
| 584 |
+
|
| 585 |
+
# 保存更新后的凭证
|
| 586 |
+
success = await self._storage_adapter.store_credential(file_path, credential_data)
|
| 587 |
+
|
| 588 |
+
if success:
|
| 589 |
+
log.debug(f"Token updated for credential: {file_path}")
|
| 590 |
+
else:
|
| 591 |
+
log.error(f"Failed to update token for credential: {file_path}")
|
| 592 |
+
|
| 593 |
+
return success
|
| 594 |
+
|
| 595 |
+
except Exception as e:
|
| 596 |
+
log.error(f"Error updating token for {file_path}: {e}")
|
| 597 |
+
return False
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
# 全局实例管理(保持兼容性)
|
| 601 |
+
_credential_manager: Optional[CredentialManager] = None
|
| 602 |
+
|
| 603 |
+
async def get_credential_manager() -> CredentialManager:
|
| 604 |
+
"""获取全局凭证管理器实例"""
|
| 605 |
+
global _credential_manager
|
| 606 |
+
|
| 607 |
+
if _credential_manager is None:
|
| 608 |
+
_credential_manager = CredentialManager()
|
| 609 |
+
await _credential_manager.initialize()
|
| 610 |
+
|
| 611 |
+
return _credential_manager
|
src/format_detector.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Format detection utilities for supporting both OpenAI and Gemini request formats
|
| 3 |
+
"""
|
| 4 |
+
from typing import Dict, Any
|
| 5 |
+
|
| 6 |
+
from log import log
|
| 7 |
+
|
| 8 |
+
def detect_request_format(data: Dict[str, Any]) -> str:
|
| 9 |
+
"""
|
| 10 |
+
Detect whether the request is in OpenAI or Gemini format.
|
| 11 |
+
|
| 12 |
+
Returns:
|
| 13 |
+
"openai" or "gemini"
|
| 14 |
+
"""
|
| 15 |
+
# OpenAI format indicators:
|
| 16 |
+
# - Has "messages" field with array of {role, content} objects
|
| 17 |
+
# - Role values are "system", "user", "assistant"
|
| 18 |
+
if "messages" in data and isinstance(data["messages"], list):
|
| 19 |
+
if data["messages"] and isinstance(data["messages"][0], dict):
|
| 20 |
+
# Check for OpenAI role values
|
| 21 |
+
first_role = data["messages"][0].get("role", "")
|
| 22 |
+
if first_role in ["system", "user", "assistant"]:
|
| 23 |
+
return "openai"
|
| 24 |
+
|
| 25 |
+
# Gemini format indicators:
|
| 26 |
+
# - Has "contents" field with array of {role, parts} objects
|
| 27 |
+
# - Role values are "user", "model"
|
| 28 |
+
# - May have "systemInstruction" field
|
| 29 |
+
if "contents" in data and isinstance(data["contents"], list):
|
| 30 |
+
if data["contents"] and isinstance(data["contents"][0], dict):
|
| 31 |
+
# Check for Gemini structure
|
| 32 |
+
if "parts" in data["contents"][0]:
|
| 33 |
+
return "gemini"
|
| 34 |
+
|
| 35 |
+
# Additional Gemini indicators
|
| 36 |
+
if "systemInstruction" in data or "generationConfig" in data:
|
| 37 |
+
return "gemini"
|
| 38 |
+
|
| 39 |
+
# Default to OpenAI if unclear (for backwards compatibility)
|
| 40 |
+
log.debug(f"Unable to definitively detect format, defaulting to OpenAI. Keys present: {list(data.keys())}")
|
| 41 |
+
return "openai"
|
| 42 |
+
|
| 43 |
+
def gemini_request_to_openai(gemini_request: Dict[str, Any]) -> Dict[str, Any]:
|
| 44 |
+
"""
|
| 45 |
+
Convert a Gemini format request to OpenAI format.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
gemini_request: Request in Gemini API format
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Dictionary in OpenAI API format
|
| 52 |
+
"""
|
| 53 |
+
openai_request = {
|
| 54 |
+
"model": gemini_request.get("model", "gemini-2.5-pro"),
|
| 55 |
+
"messages": []
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
# Convert system instruction if present
|
| 59 |
+
if "systemInstruction" in gemini_request:
|
| 60 |
+
system_content = ""
|
| 61 |
+
if isinstance(gemini_request["systemInstruction"], dict):
|
| 62 |
+
parts = gemini_request["systemInstruction"].get("parts", [])
|
| 63 |
+
for part in parts:
|
| 64 |
+
if "text" in part:
|
| 65 |
+
system_content += part["text"]
|
| 66 |
+
elif isinstance(gemini_request["systemInstruction"], str):
|
| 67 |
+
system_content = gemini_request["systemInstruction"]
|
| 68 |
+
|
| 69 |
+
if system_content:
|
| 70 |
+
openai_request["messages"].append({
|
| 71 |
+
"role": "system",
|
| 72 |
+
"content": system_content
|
| 73 |
+
})
|
| 74 |
+
|
| 75 |
+
# Convert contents to messages
|
| 76 |
+
contents = gemini_request.get("contents", [])
|
| 77 |
+
for content in contents:
|
| 78 |
+
role = content.get("role", "user")
|
| 79 |
+
# Map Gemini roles to OpenAI roles
|
| 80 |
+
if role == "model":
|
| 81 |
+
role = "assistant"
|
| 82 |
+
|
| 83 |
+
# Convert parts to content
|
| 84 |
+
parts = content.get("parts", [])
|
| 85 |
+
if len(parts) == 1 and "text" in parts[0]:
|
| 86 |
+
# Simple text message
|
| 87 |
+
openai_request["messages"].append({
|
| 88 |
+
"role": role,
|
| 89 |
+
"content": parts[0]["text"]
|
| 90 |
+
})
|
| 91 |
+
elif len(parts) > 0:
|
| 92 |
+
# Multi-part message (could include images)
|
| 93 |
+
content_parts = []
|
| 94 |
+
for part in parts:
|
| 95 |
+
if "text" in part:
|
| 96 |
+
content_parts.append({
|
| 97 |
+
"type": "text",
|
| 98 |
+
"text": part["text"]
|
| 99 |
+
})
|
| 100 |
+
elif "inlineData" in part:
|
| 101 |
+
# Convert Gemini inline data to OpenAI image format
|
| 102 |
+
inline_data = part["inlineData"]
|
| 103 |
+
mime_type = inline_data.get("mimeType", "image/jpeg")
|
| 104 |
+
data = inline_data.get("data", "")
|
| 105 |
+
content_parts.append({
|
| 106 |
+
"type": "image_url",
|
| 107 |
+
"image_url": {
|
| 108 |
+
"url": f"data:{mime_type};base64,{data}"
|
| 109 |
+
}
|
| 110 |
+
})
|
| 111 |
+
|
| 112 |
+
if content_parts:
|
| 113 |
+
# If only one text part, use simple string format
|
| 114 |
+
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
|
| 115 |
+
openai_request["messages"].append({
|
| 116 |
+
"role": role,
|
| 117 |
+
"content": content_parts[0]["text"]
|
| 118 |
+
})
|
| 119 |
+
else:
|
| 120 |
+
openai_request["messages"].append({
|
| 121 |
+
"role": role,
|
| 122 |
+
"content": content_parts
|
| 123 |
+
})
|
| 124 |
+
|
| 125 |
+
# Convert generation config if present
|
| 126 |
+
if "generationConfig" in gemini_request:
|
| 127 |
+
config = gemini_request["generationConfig"]
|
| 128 |
+
if "temperature" in config:
|
| 129 |
+
openai_request["temperature"] = config["temperature"]
|
| 130 |
+
if "topP" in config:
|
| 131 |
+
openai_request["top_p"] = config["topP"]
|
| 132 |
+
if "topK" in config:
|
| 133 |
+
openai_request["top_k"] = config["topK"]
|
| 134 |
+
if "maxOutputTokens" in config:
|
| 135 |
+
openai_request["max_tokens"] = config["maxOutputTokens"]
|
| 136 |
+
if "stopSequences" in config:
|
| 137 |
+
openai_request["stop"] = config["stopSequences"]
|
| 138 |
+
if "frequencyPenalty" in config:
|
| 139 |
+
openai_request["frequency_penalty"] = config["frequencyPenalty"]
|
| 140 |
+
if "presencePenalty" in config:
|
| 141 |
+
openai_request["presence_penalty"] = config["presencePenalty"]
|
| 142 |
+
if "candidateCount" in config:
|
| 143 |
+
openai_request["n"] = config["candidateCount"]
|
| 144 |
+
if "seed" in config:
|
| 145 |
+
openai_request["seed"] = config["seed"]
|
| 146 |
+
|
| 147 |
+
# Preserve stream setting if present
|
| 148 |
+
if "stream" in gemini_request:
|
| 149 |
+
openai_request["stream"] = gemini_request["stream"]
|
| 150 |
+
|
| 151 |
+
return openai_request
|
| 152 |
+
|
| 153 |
+
def validate_and_normalize_request(data: Dict[str, Any]) -> Dict[str, Any]:
|
| 154 |
+
"""
|
| 155 |
+
Validate and normalize the request to OpenAI format.
|
| 156 |
+
Automatically detects format and converts if necessary.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
data: Raw request data
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
Normalized request in OpenAI format
|
| 163 |
+
"""
|
| 164 |
+
format_type = detect_request_format(data)
|
| 165 |
+
log.info(f"Detected request format: {format_type}")
|
| 166 |
+
|
| 167 |
+
if format_type == "gemini":
|
| 168 |
+
# Convert Gemini format to OpenAI format
|
| 169 |
+
return gemini_request_to_openai(data)
|
| 170 |
+
else:
|
| 171 |
+
# Already in OpenAI format
|
| 172 |
+
return data
|
src/gemini_router.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gemini Router - Handles native Gemini format API requests
|
| 3 |
+
处理原生Gemini格式请求的路由模块
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
import json
|
| 7 |
+
from contextlib import asynccontextmanager
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
from fastapi import APIRouter, HTTPException, Depends, Request, Path, Query, status, Header
|
| 11 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 12 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 13 |
+
|
| 14 |
+
from config import get_available_models, is_fake_streaming_model, is_anti_truncation_model, get_base_model_from_feature_model, get_anti_truncation_max_attempts, get_base_model_name
|
| 15 |
+
from log import log
|
| 16 |
+
from .anti_truncation import apply_anti_truncation_to_stream
|
| 17 |
+
from .credential_manager import CredentialManager
|
| 18 |
+
from .google_chat_api import send_gemini_request, build_gemini_payload_from_native
|
| 19 |
+
from .openai_transfer import _extract_content_and_reasoning
|
| 20 |
+
from .task_manager import create_managed_task
|
| 21 |
+
# 创建路由器
|
| 22 |
+
router = APIRouter()
|
| 23 |
+
security = HTTPBearer()
|
| 24 |
+
|
| 25 |
+
# 全局凭证管理器实例
|
| 26 |
+
credential_manager = None
|
| 27 |
+
|
| 28 |
+
@asynccontextmanager
|
| 29 |
+
async def get_credential_manager():
|
| 30 |
+
"""获取全局凭证管理器实例"""
|
| 31 |
+
global credential_manager
|
| 32 |
+
if not credential_manager:
|
| 33 |
+
credential_manager = CredentialManager()
|
| 34 |
+
await credential_manager.initialize()
|
| 35 |
+
yield credential_manager
|
| 36 |
+
|
| 37 |
+
async def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
|
| 38 |
+
"""验证用户密码(Bearer Token方式)"""
|
| 39 |
+
from config import get_api_password
|
| 40 |
+
password = await get_api_password()
|
| 41 |
+
token = credentials.credentials
|
| 42 |
+
if token != password:
|
| 43 |
+
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="密码错误")
|
| 44 |
+
return token
|
| 45 |
+
|
| 46 |
+
async def authenticate_gemini_flexible(
|
| 47 |
+
request: Request,
|
| 48 |
+
x_goog_api_key: Optional[str] = Header(None, alias="x-goog-api-key"),
|
| 49 |
+
key: Optional[str] = Query(None),
|
| 50 |
+
credentials: Optional[HTTPAuthorizationCredentials] = Depends(lambda: None)
|
| 51 |
+
) -> str:
|
| 52 |
+
"""灵活验证:支持x-goog-api-key头部、URL参数key或Authorization Bearer"""
|
| 53 |
+
from config import get_api_password
|
| 54 |
+
password = await get_api_password()
|
| 55 |
+
|
| 56 |
+
# 尝试从URL参数key获取(Google官方标准方式)
|
| 57 |
+
if key:
|
| 58 |
+
log.debug(f"Using URL parameter key authentication")
|
| 59 |
+
if key == password:
|
| 60 |
+
return key
|
| 61 |
+
|
| 62 |
+
# 尝试从Authorization头获取(兼容旧方式)
|
| 63 |
+
auth_header = request.headers.get("authorization")
|
| 64 |
+
if auth_header and auth_header.startswith("Bearer "):
|
| 65 |
+
token = auth_header[7:] # 移除 "Bearer " 前缀
|
| 66 |
+
log.debug(f"Using Bearer token authentication")
|
| 67 |
+
if token == password:
|
| 68 |
+
return token
|
| 69 |
+
|
| 70 |
+
# 尝试从x-goog-api-key头获取(新标准方式)
|
| 71 |
+
if x_goog_api_key:
|
| 72 |
+
log.debug(f"Using x-goog-api-key authentication")
|
| 73 |
+
if x_goog_api_key == password:
|
| 74 |
+
return x_goog_api_key
|
| 75 |
+
|
| 76 |
+
log.error(f"Authentication failed. Headers: {dict(request.headers)}, Query params: key={key}")
|
| 77 |
+
raise HTTPException(
|
| 78 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 79 |
+
detail="Missing or invalid authentication. Use 'key' URL parameter, 'x-goog-api-key' header, or 'Authorization: Bearer <token>'"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
@router.get("/v1/v1beta/models")
|
| 83 |
+
@router.get("/v1/v1/models")
|
| 84 |
+
@router.get("/v1beta/models")
|
| 85 |
+
@router.get("/v1/models")
|
| 86 |
+
async def list_gemini_models():
|
| 87 |
+
"""返回Gemini格式的模型列表"""
|
| 88 |
+
models = get_available_models("gemini")
|
| 89 |
+
|
| 90 |
+
# 构建符合Gemini API格式的模型列表
|
| 91 |
+
gemini_models = []
|
| 92 |
+
for model_name in models:
|
| 93 |
+
# 获取基础模型名
|
| 94 |
+
base_model = get_base_model_from_feature_model(model_name)
|
| 95 |
+
|
| 96 |
+
model_info = {
|
| 97 |
+
"name": f"models/{model_name}",
|
| 98 |
+
"baseModelId": base_model,
|
| 99 |
+
"version": "001",
|
| 100 |
+
"displayName": model_name,
|
| 101 |
+
"description": f"Gemini {base_model} model",
|
| 102 |
+
"inputTokenLimit": 1000000,
|
| 103 |
+
"outputTokenLimit": 8192,
|
| 104 |
+
"supportedGenerationMethods": ["generateContent", "streamGenerateContent"],
|
| 105 |
+
"temperature": 1.0,
|
| 106 |
+
"maxTemperature": 2.0,
|
| 107 |
+
"topP": 0.95,
|
| 108 |
+
"topK": 64
|
| 109 |
+
}
|
| 110 |
+
gemini_models.append(model_info)
|
| 111 |
+
|
| 112 |
+
return JSONResponse(content={
|
| 113 |
+
"models": gemini_models
|
| 114 |
+
})
|
| 115 |
+
|
| 116 |
+
@router.post("/v1/v1beta/models/{model:path}:generateContent")
|
| 117 |
+
@router.post("/v1/v1/models/{model:path}:generateContent")
|
| 118 |
+
@router.post("/v1beta/models/{model:path}:generateContent")
|
| 119 |
+
@router.post("/v1/models/{model:path}:generateContent")
|
| 120 |
+
async def generate_content(
|
| 121 |
+
model: str = Path(..., description="Model name"),
|
| 122 |
+
request: Request = None,
|
| 123 |
+
api_key: str = Depends(authenticate_gemini_flexible)
|
| 124 |
+
):
|
| 125 |
+
"""处理Gemini格式的内容生成请求(非流式)"""
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# 获取原始请求数据
|
| 129 |
+
try:
|
| 130 |
+
request_data = await request.json()
|
| 131 |
+
except Exception as e:
|
| 132 |
+
log.error(f"Failed to parse JSON request: {e}")
|
| 133 |
+
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
|
| 134 |
+
|
| 135 |
+
# 验证必要字段
|
| 136 |
+
if "contents" not in request_data or not request_data["contents"]:
|
| 137 |
+
raise HTTPException(status_code=400, detail="Missing required field: contents")
|
| 138 |
+
|
| 139 |
+
# 请求预处理:限制参数
|
| 140 |
+
if "generationConfig" in request_data and request_data["generationConfig"]:
|
| 141 |
+
generation_config = request_data["generationConfig"]
|
| 142 |
+
|
| 143 |
+
# 限制max_tokens (在Gemini中叫maxOutputTokens)
|
| 144 |
+
if "maxOutputTokens" in generation_config and generation_config["maxOutputTokens"] is not None:
|
| 145 |
+
if generation_config["maxOutputTokens"] > 65535:
|
| 146 |
+
generation_config["maxOutputTokens"] = 65535
|
| 147 |
+
|
| 148 |
+
# 覆写 top_k 为 64 (在Gemini中叫topK)
|
| 149 |
+
generation_config["topK"] = 64
|
| 150 |
+
else:
|
| 151 |
+
# 如果没有generationConfig,创建一个并设置topK
|
| 152 |
+
request_data["generationConfig"] = {"topK": 64}
|
| 153 |
+
|
| 154 |
+
# 处理模型名称和功能检测
|
| 155 |
+
use_anti_truncation = is_anti_truncation_model(model)
|
| 156 |
+
|
| 157 |
+
# 获取基础模型名
|
| 158 |
+
real_model = get_base_model_from_feature_model(model)
|
| 159 |
+
|
| 160 |
+
# 对于假流式模型,如果是流式端点才返回假流式响应
|
| 161 |
+
# 注意:这是generateContent端点,不应该触发假流式
|
| 162 |
+
|
| 163 |
+
# 对于抗截断模型的非流式请求,给出警告
|
| 164 |
+
if use_anti_truncation:
|
| 165 |
+
log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置")
|
| 166 |
+
|
| 167 |
+
# 健康检查
|
| 168 |
+
if (len(request_data["contents"]) == 1 and
|
| 169 |
+
request_data["contents"][0].get("role") == "user" and
|
| 170 |
+
request_data["contents"][0].get("parts", [{}])[0].get("text") == "Hi"):
|
| 171 |
+
return JSONResponse(content={
|
| 172 |
+
"candidates": [{
|
| 173 |
+
"content": {
|
| 174 |
+
"parts": [{"text": "gcli2api工作中"}],
|
| 175 |
+
"role": "model"
|
| 176 |
+
},
|
| 177 |
+
"finishReason": "STOP",
|
| 178 |
+
"index": 0
|
| 179 |
+
}]
|
| 180 |
+
})
|
| 181 |
+
|
| 182 |
+
# 获取凭证管理器
|
| 183 |
+
from src.credential_manager import get_credential_manager
|
| 184 |
+
cred_mgr = await get_credential_manager()
|
| 185 |
+
|
| 186 |
+
# 获取有效凭证
|
| 187 |
+
credential_result = await cred_mgr.get_valid_credential()
|
| 188 |
+
if not credential_result:
|
| 189 |
+
log.error("当前无可用凭证,请去控制台获取")
|
| 190 |
+
raise HTTPException(status_code=500, detail="当前无可用凭证,请去控制台获取")
|
| 191 |
+
|
| 192 |
+
# 增加调用计数
|
| 193 |
+
cred_mgr.increment_call_count()
|
| 194 |
+
|
| 195 |
+
# 构建Google API payload
|
| 196 |
+
try:
|
| 197 |
+
api_payload = build_gemini_payload_from_native(request_data, real_model)
|
| 198 |
+
except Exception as e:
|
| 199 |
+
log.error(f"Gemini payload build failed: {e}")
|
| 200 |
+
raise HTTPException(status_code=500, detail="Request processing failed")
|
| 201 |
+
|
| 202 |
+
# 发送请求(429重试已在google_api_client中处理)
|
| 203 |
+
response = await send_gemini_request(api_payload, False, cred_mgr)
|
| 204 |
+
|
| 205 |
+
# 处理响应
|
| 206 |
+
try:
|
| 207 |
+
if hasattr(response, 'body'):
|
| 208 |
+
response_data = json.loads(response.body.decode() if isinstance(response.body, bytes) else response.body)
|
| 209 |
+
elif hasattr(response, 'content'):
|
| 210 |
+
response_data = json.loads(response.content.decode() if isinstance(response.content, bytes) else response.content)
|
| 211 |
+
else:
|
| 212 |
+
response_data = json.loads(str(response))
|
| 213 |
+
|
| 214 |
+
return JSONResponse(content=response_data)
|
| 215 |
+
|
| 216 |
+
except Exception as e:
|
| 217 |
+
log.error(f"Response processing failed: {e}")
|
| 218 |
+
# 返回原始响应
|
| 219 |
+
if hasattr(response, 'content'):
|
| 220 |
+
return JSONResponse(content=json.loads(response.content))
|
| 221 |
+
else:
|
| 222 |
+
raise HTTPException(status_code=500, detail="Response processing failed")
|
| 223 |
+
|
| 224 |
+
@router.post("/v1/v1beta/models/{model:path}:streamGenerateContent")
|
| 225 |
+
@router.post("/v1/v1/models/{model:path}:streamGenerateContent")
|
| 226 |
+
@router.post("/v1beta/models/{model:path}:streamGenerateContent")
|
| 227 |
+
@router.post("/v1/models/{model:path}:streamGenerateContent")
|
| 228 |
+
async def stream_generate_content(
|
| 229 |
+
model: str = Path(..., description="Model name"),
|
| 230 |
+
request: Request = None,
|
| 231 |
+
api_key: str = Depends(authenticate_gemini_flexible)
|
| 232 |
+
):
|
| 233 |
+
"""处理Gemini格式的流式内容生成请求"""
|
| 234 |
+
log.debug(f"Stream request received for model: {model}")
|
| 235 |
+
log.debug(f"Request headers: {dict(request.headers)}")
|
| 236 |
+
log.debug(f"API key received: {api_key[:10] if api_key else None}...")
|
| 237 |
+
try:
|
| 238 |
+
body = await request.body()
|
| 239 |
+
log.debug(f"request body: {body.decode() if isinstance(body, bytes) else body}")
|
| 240 |
+
except Exception as e:
|
| 241 |
+
log.error(f"Failed to read request body: {e}")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# 获取原始请求数据
|
| 245 |
+
try:
|
| 246 |
+
request_data = await request.json()
|
| 247 |
+
except Exception as e:
|
| 248 |
+
log.error(f"Failed to parse JSON request: {e}")
|
| 249 |
+
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
|
| 250 |
+
|
| 251 |
+
# 验���必要字段
|
| 252 |
+
if "contents" not in request_data or not request_data["contents"]:
|
| 253 |
+
raise HTTPException(status_code=400, detail="Missing required field: contents")
|
| 254 |
+
|
| 255 |
+
# 请求预处理:限制参数
|
| 256 |
+
if "generationConfig" in request_data and request_data["generationConfig"]:
|
| 257 |
+
generation_config = request_data["generationConfig"]
|
| 258 |
+
|
| 259 |
+
# 限制max_tokens (在Gemini中叫maxOutputTokens)
|
| 260 |
+
if "maxOutputTokens" in generation_config and generation_config["maxOutputTokens"] is not None:
|
| 261 |
+
if generation_config["maxOutputTokens"] > 65535:
|
| 262 |
+
generation_config["maxOutputTokens"] = 65535
|
| 263 |
+
|
| 264 |
+
# 覆写 top_k 为 64 (在Gemini中叫topK)
|
| 265 |
+
generation_config["topK"] = 64
|
| 266 |
+
else:
|
| 267 |
+
# 如果没有generationConfig,创建一个并设置topK
|
| 268 |
+
request_data["generationConfig"] = {"topK": 64}
|
| 269 |
+
|
| 270 |
+
# 处理模型名称和功能检测
|
| 271 |
+
use_fake_streaming = is_fake_streaming_model(model)
|
| 272 |
+
use_anti_truncation = is_anti_truncation_model(model)
|
| 273 |
+
|
| 274 |
+
# 获取基础模型名
|
| 275 |
+
real_model = get_base_model_from_feature_model(model)
|
| 276 |
+
|
| 277 |
+
# 对于假流式模型,返回假流式响应
|
| 278 |
+
if use_fake_streaming:
|
| 279 |
+
return await fake_stream_response_gemini(request_data, real_model)
|
| 280 |
+
|
| 281 |
+
# 获取凭证管理器
|
| 282 |
+
from src.credential_manager import get_credential_manager
|
| 283 |
+
cred_mgr = await get_credential_manager()
|
| 284 |
+
|
| 285 |
+
# 获取有效凭证
|
| 286 |
+
credential_result = await cred_mgr.get_valid_credential()
|
| 287 |
+
if not credential_result:
|
| 288 |
+
log.error("当前无可用凭证,请去控制台获取")
|
| 289 |
+
raise HTTPException(status_code=500, detail="当前无可用凭证,请去控制台获取")
|
| 290 |
+
|
| 291 |
+
# 增加调用计数
|
| 292 |
+
cred_mgr.increment_call_count()
|
| 293 |
+
|
| 294 |
+
# 构建Google API payload
|
| 295 |
+
try:
|
| 296 |
+
api_payload = build_gemini_payload_from_native(request_data, real_model)
|
| 297 |
+
except Exception as e:
|
| 298 |
+
log.error(f"Gemini payload build failed: {e}")
|
| 299 |
+
raise HTTPException(status_code=500, detail="Request processing failed")
|
| 300 |
+
|
| 301 |
+
# 处理抗截断功能(仅流式传输时有效)
|
| 302 |
+
if use_anti_truncation:
|
| 303 |
+
log.info("启用流式抗截断功能")
|
| 304 |
+
# 使用流式抗截断处理器
|
| 305 |
+
max_attempts = await get_anti_truncation_max_attempts()
|
| 306 |
+
return await apply_anti_truncation_to_stream(
|
| 307 |
+
lambda payload: send_gemini_request(payload, True, cred_mgr),
|
| 308 |
+
api_payload,
|
| 309 |
+
max_attempts
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# 常规流式请求(429重试已在google_api_client中处理)
|
| 313 |
+
response = await send_gemini_request(api_payload, True, cred_mgr)
|
| 314 |
+
|
| 315 |
+
# 直接返回流式响应
|
| 316 |
+
return response
|
| 317 |
+
|
| 318 |
+
@router.post("/v1/v1beta/models/{model:path}:countTokens")
|
| 319 |
+
@router.post("/v1/v1/models/{model:path}:countTokens")
|
| 320 |
+
@router.post("/v1beta/models/{model:path}:countTokens")
|
| 321 |
+
@router.post("/v1/models/{model:path}:countTokens")
|
| 322 |
+
async def count_tokens(
|
| 323 |
+
request: Request = None,
|
| 324 |
+
api_key: str = Depends(authenticate_gemini_flexible)
|
| 325 |
+
):
|
| 326 |
+
"""模拟Gemini格式的token计数"""
|
| 327 |
+
|
| 328 |
+
try:
|
| 329 |
+
request_data = await request.json()
|
| 330 |
+
except Exception as e:
|
| 331 |
+
log.error(f"Failed to parse JSON request: {e}")
|
| 332 |
+
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
|
| 333 |
+
|
| 334 |
+
# 简单的token计数模拟 - 基于文本长度估算
|
| 335 |
+
total_tokens = 0
|
| 336 |
+
|
| 337 |
+
# 如果有contents字段
|
| 338 |
+
if "contents" in request_data:
|
| 339 |
+
for content in request_data["contents"]:
|
| 340 |
+
if "parts" in content:
|
| 341 |
+
for part in content["parts"]:
|
| 342 |
+
if "text" in part:
|
| 343 |
+
# 简单估算:大约4字符=1token
|
| 344 |
+
text_length = len(part["text"])
|
| 345 |
+
total_tokens += max(1, text_length // 4)
|
| 346 |
+
|
| 347 |
+
# 如果有generateContentRequest字段
|
| 348 |
+
elif "generateContentRequest" in request_data:
|
| 349 |
+
gen_request = request_data["generateContentRequest"]
|
| 350 |
+
if "contents" in gen_request:
|
| 351 |
+
for content in gen_request["contents"]:
|
| 352 |
+
if "parts" in content:
|
| 353 |
+
for part in content["parts"]:
|
| 354 |
+
if "text" in part:
|
| 355 |
+
text_length = len(part["text"])
|
| 356 |
+
total_tokens += max(1, text_length // 4)
|
| 357 |
+
|
| 358 |
+
# 返回Gemini格式的响应
|
| 359 |
+
return JSONResponse(content={
|
| 360 |
+
"totalTokens": total_tokens
|
| 361 |
+
})
|
| 362 |
+
|
| 363 |
+
@router.get("/v1/v1beta/models/{model:path}")
|
| 364 |
+
@router.get("/v1/v1/models/{model:path}")
|
| 365 |
+
@router.get("/v1beta/models/{model:path}")
|
| 366 |
+
@router.get("/v1/models/{model:path}")
|
| 367 |
+
async def get_model_info(
|
| 368 |
+
model: str = Path(..., description="Model name"),
|
| 369 |
+
api_key: str = Depends(authenticate_gemini_flexible)
|
| 370 |
+
):
|
| 371 |
+
"""获取特定模型的信息"""
|
| 372 |
+
|
| 373 |
+
# 获取基础模型名称
|
| 374 |
+
base_model = get_base_model_name(model)
|
| 375 |
+
|
| 376 |
+
# 模拟模型信息
|
| 377 |
+
model_info = {
|
| 378 |
+
"name": f"models/{base_model}",
|
| 379 |
+
"baseModelId": base_model,
|
| 380 |
+
"version": "001",
|
| 381 |
+
"displayName": base_model,
|
| 382 |
+
"description": f"Gemini {base_model} model",
|
| 383 |
+
"inputTokenLimit": 128000,
|
| 384 |
+
"outputTokenLimit": 8192,
|
| 385 |
+
"supportedGenerationMethods": [
|
| 386 |
+
"generateContent",
|
| 387 |
+
"streamGenerateContent"
|
| 388 |
+
],
|
| 389 |
+
"temperature": 1.0,
|
| 390 |
+
"maxTemperature": 2.0,
|
| 391 |
+
"topP": 0.95,
|
| 392 |
+
"topK": 64
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
return JSONResponse(content=model_info)
|
| 396 |
+
|
| 397 |
+
async def fake_stream_response_gemini(request_data: dict, model: str):
|
| 398 |
+
"""处理Gemini格式的假流式响应"""
|
| 399 |
+
|
| 400 |
+
async def gemini_stream_generator():
|
| 401 |
+
try:
|
| 402 |
+
# 获取凭证管理器
|
| 403 |
+
from src.credential_manager import get_credential_manager
|
| 404 |
+
cred_mgr = await get_credential_manager()
|
| 405 |
+
|
| 406 |
+
# 获取有效凭证
|
| 407 |
+
credential_result = await cred_mgr.get_valid_credential()
|
| 408 |
+
if not credential_result:
|
| 409 |
+
log.error("当前无可用凭证,请去控制台获取")
|
| 410 |
+
error_chunk = {
|
| 411 |
+
"error": {
|
| 412 |
+
"message": "当前无凭证,请去控制台获取",
|
| 413 |
+
"type": "authentication_error",
|
| 414 |
+
"code": 500
|
| 415 |
+
}
|
| 416 |
+
}
|
| 417 |
+
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
|
| 418 |
+
yield "data: [DONE]\n\n".encode()
|
| 419 |
+
return
|
| 420 |
+
|
| 421 |
+
# 增加调用计数
|
| 422 |
+
cred_mgr.increment_call_count()
|
| 423 |
+
|
| 424 |
+
# 构建Google API payload
|
| 425 |
+
try:
|
| 426 |
+
api_payload = build_gemini_payload_from_native(request_data, model)
|
| 427 |
+
except Exception as e:
|
| 428 |
+
log.error(f"Gemini payload build failed: {e}")
|
| 429 |
+
error_chunk = {
|
| 430 |
+
"error": {
|
| 431 |
+
"message": f"Request processing failed: {str(e)}",
|
| 432 |
+
"type": "api_error",
|
| 433 |
+
"code": 500
|
| 434 |
+
}
|
| 435 |
+
}
|
| 436 |
+
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
|
| 437 |
+
yield "data: [DONE]\n\n".encode()
|
| 438 |
+
return
|
| 439 |
+
|
| 440 |
+
# 发送心跳
|
| 441 |
+
heartbeat = {
|
| 442 |
+
"candidates": [{
|
| 443 |
+
"content": {
|
| 444 |
+
"parts": [{"text": ""}],
|
| 445 |
+
"role": "model"
|
| 446 |
+
},
|
| 447 |
+
"finishReason": None,
|
| 448 |
+
"index": 0
|
| 449 |
+
}]
|
| 450 |
+
}
|
| 451 |
+
yield f"data: {json.dumps(heartbeat)}\n\n".encode()
|
| 452 |
+
|
| 453 |
+
# 异步发送实际请求
|
| 454 |
+
async def get_response():
|
| 455 |
+
return await send_gemini_request(api_payload, False, cred_mgr)
|
| 456 |
+
|
| 457 |
+
# 创建请求任务
|
| 458 |
+
response_task = create_managed_task(get_response(), name="gemini_fake_stream_request")
|
| 459 |
+
|
| 460 |
+
try:
|
| 461 |
+
# 每3秒发送一次心跳,直到收到响应
|
| 462 |
+
while not response_task.done():
|
| 463 |
+
await asyncio.sleep(3.0)
|
| 464 |
+
if not response_task.done():
|
| 465 |
+
yield f"data: {json.dumps(heartbeat)}\n\n".encode()
|
| 466 |
+
|
| 467 |
+
# 获取响应结果
|
| 468 |
+
response = await response_task
|
| 469 |
+
|
| 470 |
+
except asyncio.CancelledError:
|
| 471 |
+
# 取消任务并传播取消
|
| 472 |
+
response_task.cancel()
|
| 473 |
+
try:
|
| 474 |
+
await response_task
|
| 475 |
+
except asyncio.CancelledError:
|
| 476 |
+
pass
|
| 477 |
+
raise
|
| 478 |
+
except Exception as e:
|
| 479 |
+
# 取消任务并处理其他异常
|
| 480 |
+
response_task.cancel()
|
| 481 |
+
try:
|
| 482 |
+
await response_task
|
| 483 |
+
except asyncio.CancelledError:
|
| 484 |
+
pass
|
| 485 |
+
log.error(f"Fake streaming request failed: {e}")
|
| 486 |
+
raise
|
| 487 |
+
|
| 488 |
+
# 发送实际请求
|
| 489 |
+
# response 已在上面获取
|
| 490 |
+
|
| 491 |
+
# 处理结果
|
| 492 |
+
try:
|
| 493 |
+
if hasattr(response, 'body'):
|
| 494 |
+
response_data = json.loads(response.body.decode() if isinstance(response.body, bytes) else response.body)
|
| 495 |
+
elif hasattr(response, 'content'):
|
| 496 |
+
response_data = json.loads(response.content.decode() if isinstance(response.content, bytes) else response.content)
|
| 497 |
+
else:
|
| 498 |
+
response_data = json.loads(str(response))
|
| 499 |
+
|
| 500 |
+
log.debug(f"Gemini fake stream response data: {response_data}")
|
| 501 |
+
|
| 502 |
+
# 发送完整内容作为单个chunk,使用思维链分离
|
| 503 |
+
if "candidates" in response_data and response_data["candidates"]:
|
| 504 |
+
candidate = response_data["candidates"][0]
|
| 505 |
+
if "content" in candidate and "parts" in candidate["content"]:
|
| 506 |
+
parts = candidate["content"]["parts"]
|
| 507 |
+
content, reasoning_content = _extract_content_and_reasoning(parts)
|
| 508 |
+
log.debug(f"Gemini extracted content: {content}")
|
| 509 |
+
log.debug(f"Gemini extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...")
|
| 510 |
+
|
| 511 |
+
# 如果没有正常内容但有思维内容
|
| 512 |
+
if not content and reasoning_content:
|
| 513 |
+
log.warning(f"Gemini fake stream contains only thinking content: {reasoning_content[:100]}...")
|
| 514 |
+
content = "[模型正在思考中,请稍后再试或重新提问]"
|
| 515 |
+
|
| 516 |
+
if content:
|
| 517 |
+
# 构建包含分离内容的响应
|
| 518 |
+
parts_response = [{"text": content}]
|
| 519 |
+
if reasoning_content:
|
| 520 |
+
parts_response.append({"text": reasoning_content, "thought": True})
|
| 521 |
+
|
| 522 |
+
content_chunk = {
|
| 523 |
+
"candidates": [{
|
| 524 |
+
"content": {
|
| 525 |
+
"parts": parts_response,
|
| 526 |
+
"role": "model"
|
| 527 |
+
},
|
| 528 |
+
"finishReason": candidate.get("finishReason", "STOP"),
|
| 529 |
+
"index": 0
|
| 530 |
+
}]
|
| 531 |
+
}
|
| 532 |
+
yield f"data: {json.dumps(content_chunk)}\n\n".encode()
|
| 533 |
+
else:
|
| 534 |
+
log.warning(f"No content found in Gemini candidate: {candidate}")
|
| 535 |
+
# 提供默认回复
|
| 536 |
+
error_chunk = {
|
| 537 |
+
"candidates": [{
|
| 538 |
+
"content": {
|
| 539 |
+
"parts": [{"text": "[响应为空,请重新尝试]"}],
|
| 540 |
+
"role": "model"
|
| 541 |
+
},
|
| 542 |
+
"finishReason": "STOP",
|
| 543 |
+
"index": 0
|
| 544 |
+
}]
|
| 545 |
+
}
|
| 546 |
+
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
|
| 547 |
+
else:
|
| 548 |
+
log.warning(f"No content/parts found in Gemini candidate: {candidate}")
|
| 549 |
+
# 返回原始响应
|
| 550 |
+
yield f"data: {json.dumps(response_data)}\n\n".encode()
|
| 551 |
+
else:
|
| 552 |
+
log.warning(f"No candidates found in Gemini response: {response_data}")
|
| 553 |
+
yield f"data: {json.dumps(response_data)}\n\n".encode()
|
| 554 |
+
|
| 555 |
+
except Exception as e:
|
| 556 |
+
log.error(f"Response parsing failed: {e}")
|
| 557 |
+
error_chunk = {
|
| 558 |
+
"candidates": [{
|
| 559 |
+
"content": {
|
| 560 |
+
"parts": [{"text": f"Response parsing error: {str(e)}"}],
|
| 561 |
+
"role": "model"
|
| 562 |
+
},
|
| 563 |
+
"finishReason": "ERROR",
|
| 564 |
+
"index": 0
|
| 565 |
+
}]
|
| 566 |
+
}
|
| 567 |
+
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
|
| 568 |
+
|
| 569 |
+
yield "data: [DONE]\n\n".encode()
|
| 570 |
+
|
| 571 |
+
except Exception as e:
|
| 572 |
+
log.error(f"Fake streaming error: {e}")
|
| 573 |
+
error_chunk = {
|
| 574 |
+
"error": {
|
| 575 |
+
"message": f"Fake streaming error: {str(e)}",
|
| 576 |
+
"type": "api_error",
|
| 577 |
+
"code": 500
|
| 578 |
+
}
|
| 579 |
+
}
|
| 580 |
+
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
|
| 581 |
+
yield "data: [DONE]\n\n".encode()
|
| 582 |
+
|
| 583 |
+
return StreamingResponse(gemini_stream_generator(), media_type="text/event-stream")
|
src/google_chat_api.py
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Google API Client - Handles all communication with Google's Gemini API.
|
| 3 |
+
This module is used by both OpenAI compatibility layer and native Gemini endpoints.
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
import gc
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
from fastapi import Response
|
| 10 |
+
from fastapi.responses import StreamingResponse
|
| 11 |
+
|
| 12 |
+
from config import (
|
| 13 |
+
get_code_assist_endpoint,
|
| 14 |
+
DEFAULT_SAFETY_SETTINGS,
|
| 15 |
+
get_base_model_name,
|
| 16 |
+
get_thinking_budget,
|
| 17 |
+
should_include_thoughts,
|
| 18 |
+
is_search_model,
|
| 19 |
+
get_auto_ban_enabled,
|
| 20 |
+
get_auto_ban_error_codes,
|
| 21 |
+
get_retry_429_max_retries,
|
| 22 |
+
get_retry_429_enabled,
|
| 23 |
+
get_retry_429_interval
|
| 24 |
+
)
|
| 25 |
+
from .httpx_client import http_client, create_streaming_client_with_kwargs
|
| 26 |
+
from log import log
|
| 27 |
+
from .credential_manager import CredentialManager
|
| 28 |
+
from .usage_stats import record_successful_call
|
| 29 |
+
from .utils import get_user_agent
|
| 30 |
+
|
| 31 |
+
def _create_error_response(message: str, status_code: int = 500) -> Response:
|
| 32 |
+
"""Create standardized error response."""
|
| 33 |
+
return Response(
|
| 34 |
+
content=json.dumps({
|
| 35 |
+
"error": {
|
| 36 |
+
"message": message,
|
| 37 |
+
"type": "api_error",
|
| 38 |
+
"code": status_code
|
| 39 |
+
}
|
| 40 |
+
}),
|
| 41 |
+
status_code=status_code,
|
| 42 |
+
media_type="application/json"
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
async def _handle_api_error(credential_manager: CredentialManager, status_code: int, response_content: str = ""):
|
| 46 |
+
"""Handle API errors by rotating credentials when needed. Error recording should be done before calling this function."""
|
| 47 |
+
if status_code == 429 and credential_manager:
|
| 48 |
+
if response_content:
|
| 49 |
+
log.error(f"Google API returned status 429 - quota exhausted. Response details: {response_content[:500]}")
|
| 50 |
+
else:
|
| 51 |
+
log.error("Google API returned status 429 - quota exhausted, switching credentials")
|
| 52 |
+
await credential_manager.force_rotate_credential()
|
| 53 |
+
|
| 54 |
+
# 处理自动封禁的错误码
|
| 55 |
+
elif await get_auto_ban_enabled() and status_code in await get_auto_ban_error_codes() and credential_manager:
|
| 56 |
+
if response_content:
|
| 57 |
+
log.error(f"Google API returned status {status_code} - auto ban triggered. Response details: {response_content[:500]}")
|
| 58 |
+
else:
|
| 59 |
+
log.warning(f"Google API returned status {status_code} - auto ban triggered, rotating credentials")
|
| 60 |
+
await credential_manager.force_rotate_credential()
|
| 61 |
+
|
| 62 |
+
async def _prepare_request_headers_and_payload(payload: dict, credential_data: dict):
|
| 63 |
+
"""Prepare request headers and final payload from credential data."""
|
| 64 |
+
# 尝试获取token,支持多种字段名
|
| 65 |
+
token = credential_data.get('token') or credential_data.get('access_token', '')
|
| 66 |
+
|
| 67 |
+
if not token:
|
| 68 |
+
raise Exception("凭证中没有找到有效的访问令牌(token或access_token字段)")
|
| 69 |
+
|
| 70 |
+
headers = {
|
| 71 |
+
"Authorization": f"Bearer {token}",
|
| 72 |
+
"Content-Type": "application/json",
|
| 73 |
+
"User-Agent": get_user_agent(),
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
# 直接使用凭证数据中的项目ID
|
| 77 |
+
project_id = credential_data.get("project_id", "")
|
| 78 |
+
if not project_id:
|
| 79 |
+
raise Exception("项目ID不存在于凭证数据中")
|
| 80 |
+
|
| 81 |
+
final_payload = {
|
| 82 |
+
"model": payload.get("model"),
|
| 83 |
+
"project": project_id,
|
| 84 |
+
"request": payload.get("request", {})
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
return headers, final_payload
|
| 88 |
+
|
| 89 |
+
async def send_gemini_request(payload: dict, is_streaming: bool = False, credential_manager: CredentialManager = None) -> Response:
|
| 90 |
+
"""
|
| 91 |
+
Send a request to Google's Gemini API.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
payload: The request payload in Gemini format
|
| 95 |
+
is_streaming: Whether this is a streaming request
|
| 96 |
+
credential_manager: CredentialManager instance
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
FastAPI Response object
|
| 100 |
+
"""
|
| 101 |
+
# 获取429重试配置
|
| 102 |
+
max_retries = await get_retry_429_max_retries()
|
| 103 |
+
retry_429_enabled = await get_retry_429_enabled()
|
| 104 |
+
retry_interval = await get_retry_429_interval()
|
| 105 |
+
|
| 106 |
+
# 确定API端点
|
| 107 |
+
action = "streamGenerateContent" if is_streaming else "generateContent"
|
| 108 |
+
target_url = f"{await get_code_assist_endpoint()}/v1internal:{action}"
|
| 109 |
+
if is_streaming:
|
| 110 |
+
target_url += "?alt=sse"
|
| 111 |
+
|
| 112 |
+
# 确保有credential_manager
|
| 113 |
+
if not credential_manager:
|
| 114 |
+
return _create_error_response("Credential manager not provided", 500)
|
| 115 |
+
|
| 116 |
+
# 获取当前凭证
|
| 117 |
+
try:
|
| 118 |
+
credential_result = await credential_manager.get_valid_credential()
|
| 119 |
+
if not credential_result:
|
| 120 |
+
return _create_error_response("No valid credentials available", 500)
|
| 121 |
+
|
| 122 |
+
current_file, credential_data = credential_result
|
| 123 |
+
headers, final_payload = await _prepare_request_headers_and_payload(payload, credential_data)
|
| 124 |
+
except Exception as e:
|
| 125 |
+
return _create_error_response(str(e), 500)
|
| 126 |
+
|
| 127 |
+
# 预序列化payload,避免重试时重复序列化
|
| 128 |
+
final_post_data = json.dumps(final_payload)
|
| 129 |
+
|
| 130 |
+
# Debug日志:打印请求体结构
|
| 131 |
+
log.debug(f"Final request payload structure: {json.dumps(final_payload, ensure_ascii=False, indent=2)}")
|
| 132 |
+
|
| 133 |
+
for attempt in range(max_retries + 1):
|
| 134 |
+
try:
|
| 135 |
+
if is_streaming:
|
| 136 |
+
# 流式请求处理 - 使用httpx_client模块的统一配置
|
| 137 |
+
client = await create_streaming_client_with_kwargs()
|
| 138 |
+
|
| 139 |
+
try:
|
| 140 |
+
# 使用stream方法但不在async with块中消费数据
|
| 141 |
+
stream_ctx = client.stream("POST", target_url, content=final_post_data, headers=headers)
|
| 142 |
+
resp = await stream_ctx.__aenter__()
|
| 143 |
+
|
| 144 |
+
if resp.status_code == 429:
|
| 145 |
+
# 记录429错误并获取响应内容
|
| 146 |
+
response_content = ""
|
| 147 |
+
try:
|
| 148 |
+
content_bytes = await resp.aread()
|
| 149 |
+
if isinstance(content_bytes, bytes):
|
| 150 |
+
response_content = content_bytes.decode('utf-8', errors='ignore')
|
| 151 |
+
except Exception as e:
|
| 152 |
+
log.debug(f"[STREAMING] Failed to read 429 response content: {e}")
|
| 153 |
+
|
| 154 |
+
# 显示详细的429错误信息
|
| 155 |
+
if response_content:
|
| 156 |
+
log.error(f"Google API returned status 429 (STREAMING). Response details: {response_content[:500]}")
|
| 157 |
+
else:
|
| 158 |
+
log.error("Google API returned status 429 (STREAMING) - quota exhausted, no response details available")
|
| 159 |
+
|
| 160 |
+
if credential_manager and current_file:
|
| 161 |
+
await credential_manager.record_api_call_result(current_file, False, 429)
|
| 162 |
+
|
| 163 |
+
# 清理资源
|
| 164 |
+
try:
|
| 165 |
+
await stream_ctx.__aexit__(None, None, None)
|
| 166 |
+
except:
|
| 167 |
+
pass
|
| 168 |
+
await client.aclose()
|
| 169 |
+
|
| 170 |
+
# 如果重试可用且未达到最大次数,进行重试
|
| 171 |
+
if retry_429_enabled and attempt < max_retries:
|
| 172 |
+
log.warning(f"[RETRY] 429 error encountered, retrying ({attempt + 1}/{max_retries})")
|
| 173 |
+
if credential_manager:
|
| 174 |
+
# 429错误时强制轮换凭证,不增加调用计数
|
| 175 |
+
await credential_manager.force_rotate_credential()
|
| 176 |
+
# 重新获取凭证和headers(凭证可能已轮换)
|
| 177 |
+
new_credential_result = await credential_manager.get_valid_credential()
|
| 178 |
+
if new_credential_result:
|
| 179 |
+
current_file, credential_data = new_credential_result
|
| 180 |
+
headers, updated_payload = await _prepare_request_headers_and_payload(payload, credential_data)
|
| 181 |
+
final_post_data = json.dumps(updated_payload)
|
| 182 |
+
await asyncio.sleep(retry_interval)
|
| 183 |
+
continue # 跳出内层处理,继续外层循环重试
|
| 184 |
+
else:
|
| 185 |
+
# 返回429错误流
|
| 186 |
+
async def error_stream():
|
| 187 |
+
error_response = {
|
| 188 |
+
"error": {
|
| 189 |
+
"message": "429 rate limit exceeded, max retries reached",
|
| 190 |
+
"type": "api_error",
|
| 191 |
+
"code": 429
|
| 192 |
+
}
|
| 193 |
+
}
|
| 194 |
+
yield f"data: {json.dumps(error_response)}\n\n"
|
| 195 |
+
return StreamingResponse(error_stream(), media_type="text/event-stream", status_code=429)
|
| 196 |
+
elif resp.status_code != 200:
|
| 197 |
+
# 处理其他非200状态码的错误
|
| 198 |
+
response_content = ""
|
| 199 |
+
try:
|
| 200 |
+
content_bytes = await resp.aread()
|
| 201 |
+
if isinstance(content_bytes, bytes):
|
| 202 |
+
response_content = content_bytes.decode('utf-8', errors='ignore')
|
| 203 |
+
except Exception as e:
|
| 204 |
+
log.debug(f"[STREAMING] Failed to read error response content: {e}")
|
| 205 |
+
|
| 206 |
+
# 显示详细的错误信息
|
| 207 |
+
if response_content:
|
| 208 |
+
log.error(f"Google API returned status {resp.status_code} (STREAMING). Response details: {response_content[:500]}")
|
| 209 |
+
else:
|
| 210 |
+
log.error(f"Google API returned status {resp.status_code} (STREAMING) - no response details available")
|
| 211 |
+
|
| 212 |
+
# 记录API调用错误
|
| 213 |
+
if credential_manager and current_file:
|
| 214 |
+
await credential_manager.record_api_call_result(current_file, False, resp.status_code)
|
| 215 |
+
|
| 216 |
+
# 清理资源
|
| 217 |
+
try:
|
| 218 |
+
await stream_ctx.__aexit__(None, None, None)
|
| 219 |
+
except:
|
| 220 |
+
pass
|
| 221 |
+
await client.aclose()
|
| 222 |
+
|
| 223 |
+
# 处理凭证轮换
|
| 224 |
+
await _handle_api_error(credential_manager, resp.status_code, response_content)
|
| 225 |
+
|
| 226 |
+
# 返回错误流
|
| 227 |
+
async def error_stream():
|
| 228 |
+
error_response = {
|
| 229 |
+
"error": {
|
| 230 |
+
"message": f"API error: {resp.status_code}",
|
| 231 |
+
"type": "api_error",
|
| 232 |
+
"code": resp.status_code
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
yield f"data: {json.dumps(error_response)}\n\n"
|
| 236 |
+
return StreamingResponse(error_stream(), media_type="text/event-stream", status_code=resp.status_code)
|
| 237 |
+
else:
|
| 238 |
+
# 成功响应,传递所有资源给流式处理函数管理
|
| 239 |
+
return _handle_streaming_response_managed(resp, stream_ctx, client, credential_manager, payload.get("model", ""), current_file)
|
| 240 |
+
|
| 241 |
+
except Exception as e:
|
| 242 |
+
# 清理资源
|
| 243 |
+
try:
|
| 244 |
+
await client.aclose()
|
| 245 |
+
except:
|
| 246 |
+
pass
|
| 247 |
+
raise e
|
| 248 |
+
|
| 249 |
+
else:
|
| 250 |
+
# 非流式请求处理 - 使用httpx_client模块
|
| 251 |
+
async with http_client.get_client(timeout=None) as client:
|
| 252 |
+
resp = await client.post(
|
| 253 |
+
target_url, content=final_post_data, headers=headers
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
if resp.status_code == 429:
|
| 257 |
+
# 记录429错误
|
| 258 |
+
if credential_manager and current_file:
|
| 259 |
+
await credential_manager.record_api_call_result(current_file, False, 429)
|
| 260 |
+
|
| 261 |
+
# 如果重试可用且未达到最大次数,继续重试
|
| 262 |
+
if retry_429_enabled and attempt < max_retries:
|
| 263 |
+
log.warning(f"[RETRY] 429 error encountered, retrying ({attempt + 1}/{max_retries})")
|
| 264 |
+
if credential_manager:
|
| 265 |
+
# 429错误时强制轮换凭证,不增加调用计数
|
| 266 |
+
await credential_manager.force_rotate_credential()
|
| 267 |
+
# 重新获取凭证和headers(凭证可能已轮换)
|
| 268 |
+
new_credential_result = await credential_manager.get_valid_credential()
|
| 269 |
+
if new_credential_result:
|
| 270 |
+
current_file, credential_data = new_credential_result
|
| 271 |
+
headers, updated_payload = await _prepare_request_headers_and_payload(payload, credential_data)
|
| 272 |
+
final_post_data = json.dumps(updated_payload)
|
| 273 |
+
await asyncio.sleep(retry_interval)
|
| 274 |
+
continue
|
| 275 |
+
else:
|
| 276 |
+
log.error(f"[RETRY] Max retries exceeded for 429 error")
|
| 277 |
+
return _create_error_response("429 rate limit exceeded, max retries reached", 429)
|
| 278 |
+
else:
|
| 279 |
+
# 非429错误或成功响应,正常处理
|
| 280 |
+
return await _handle_non_streaming_response(resp, credential_manager, payload.get("model", ""), current_file)
|
| 281 |
+
|
| 282 |
+
except Exception as e:
|
| 283 |
+
if attempt < max_retries:
|
| 284 |
+
log.warning(f"[RETRY] Request failed with exception, retrying ({attempt + 1}/{max_retries}): {str(e)}")
|
| 285 |
+
await asyncio.sleep(retry_interval)
|
| 286 |
+
continue
|
| 287 |
+
else:
|
| 288 |
+
log.error(f"Request to Google API failed: {str(e)}")
|
| 289 |
+
return _create_error_response(f"Request failed: {str(e)}")
|
| 290 |
+
|
| 291 |
+
# 如果循环结束仍未成功,返回错误
|
| 292 |
+
return _create_error_response("Max retries exceeded", 429)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _handle_streaming_response_managed(resp, stream_ctx, client, credential_manager: CredentialManager = None, model_name: str = "", current_file: str = None) -> StreamingResponse:
|
| 296 |
+
"""Handle streaming response with complete resource lifecycle management."""
|
| 297 |
+
|
| 298 |
+
# 检查HTTP错误
|
| 299 |
+
if resp.status_code != 200:
|
| 300 |
+
# 立即清理资源并返回错误
|
| 301 |
+
async def cleanup_and_error():
|
| 302 |
+
try:
|
| 303 |
+
await stream_ctx.__aexit__(None, None, None)
|
| 304 |
+
except:
|
| 305 |
+
pass
|
| 306 |
+
try:
|
| 307 |
+
await client.aclose()
|
| 308 |
+
except:
|
| 309 |
+
pass
|
| 310 |
+
|
| 311 |
+
# 获取响应内容用于详细错误显示
|
| 312 |
+
response_content = ""
|
| 313 |
+
try:
|
| 314 |
+
content_bytes = await resp.aread()
|
| 315 |
+
if isinstance(content_bytes, bytes):
|
| 316 |
+
response_content = content_bytes.decode('utf-8', errors='ignore')
|
| 317 |
+
except Exception as e:
|
| 318 |
+
log.debug(f"[STREAMING] Failed to read response content for error analysis: {e}")
|
| 319 |
+
response_content = ""
|
| 320 |
+
|
| 321 |
+
# 显示详细错误信息
|
| 322 |
+
if resp.status_code == 429:
|
| 323 |
+
if response_content:
|
| 324 |
+
log.error(f"Google API returned status 429 (STREAMING). Response details: {response_content[:500]}")
|
| 325 |
+
else:
|
| 326 |
+
log.error(f"Google API returned status 429 (STREAMING)")
|
| 327 |
+
else:
|
| 328 |
+
if response_content:
|
| 329 |
+
log.error(f"Google API returned status {resp.status_code} (STREAMING). Response details: {response_content[:500]}")
|
| 330 |
+
else:
|
| 331 |
+
log.error(f"Google API returned status {resp.status_code} (STREAMING)")
|
| 332 |
+
|
| 333 |
+
# 记录API调用错误
|
| 334 |
+
if credential_manager and current_file:
|
| 335 |
+
await credential_manager.record_api_call_result(current_file, False, resp.status_code)
|
| 336 |
+
|
| 337 |
+
await _handle_api_error(credential_manager, resp.status_code, response_content)
|
| 338 |
+
|
| 339 |
+
error_response = {
|
| 340 |
+
"error": {
|
| 341 |
+
"message": f"API error: {resp.status_code}",
|
| 342 |
+
"type": "api_error",
|
| 343 |
+
"code": resp.status_code
|
| 344 |
+
}
|
| 345 |
+
}
|
| 346 |
+
yield f'data: {json.dumps(error_response)}\n\n'.encode('utf-8')
|
| 347 |
+
|
| 348 |
+
return StreamingResponse(
|
| 349 |
+
cleanup_and_error(),
|
| 350 |
+
media_type="text/event-stream",
|
| 351 |
+
status_code=resp.status_code
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# 正常流式响应处理,确保资源在流结束时被清理
|
| 355 |
+
async def managed_stream_generator():
|
| 356 |
+
success_recorded = False
|
| 357 |
+
managed_stream_generator._chunk_count = 0 # 初始化chunk计数器
|
| 358 |
+
try:
|
| 359 |
+
async for chunk in resp.aiter_lines():
|
| 360 |
+
if not chunk or not chunk.startswith('data: '):
|
| 361 |
+
continue
|
| 362 |
+
|
| 363 |
+
# 记录第一次成功响应
|
| 364 |
+
if not success_recorded:
|
| 365 |
+
if current_file and credential_manager:
|
| 366 |
+
await credential_manager.record_api_call_result(current_file, True)
|
| 367 |
+
# 记录到使用统计
|
| 368 |
+
try:
|
| 369 |
+
await record_successful_call(current_file, model_name)
|
| 370 |
+
except Exception as e:
|
| 371 |
+
log.debug(f"Failed to record usage statistics: {e}")
|
| 372 |
+
success_recorded = True
|
| 373 |
+
|
| 374 |
+
payload = chunk[len('data: '):]
|
| 375 |
+
try:
|
| 376 |
+
obj = json.loads(payload)
|
| 377 |
+
if "response" in obj:
|
| 378 |
+
data = obj["response"]
|
| 379 |
+
yield f"data: {json.dumps(data, separators=(',',':'))}\n\n".encode()
|
| 380 |
+
await asyncio.sleep(0) # 让其他协程有机会运行
|
| 381 |
+
|
| 382 |
+
# 定期释放内存(每100个chunk)
|
| 383 |
+
if hasattr(managed_stream_generator, '_chunk_count'):
|
| 384 |
+
managed_stream_generator._chunk_count += 1
|
| 385 |
+
if managed_stream_generator._chunk_count % 100 == 0:
|
| 386 |
+
gc.collect()
|
| 387 |
+
else:
|
| 388 |
+
yield f"data: {json.dumps(obj, separators=(',',':'))}\n\n".encode()
|
| 389 |
+
except json.JSONDecodeError:
|
| 390 |
+
continue
|
| 391 |
+
|
| 392 |
+
except Exception as e:
|
| 393 |
+
log.error(f"Streaming error: {e}")
|
| 394 |
+
err = {"error": {"message": str(e), "type": "api_error", "code": 500}}
|
| 395 |
+
yield f"data: {json.dumps(err)}\n\n".encode()
|
| 396 |
+
finally:
|
| 397 |
+
# 确保清理所有资源
|
| 398 |
+
try:
|
| 399 |
+
await stream_ctx.__aexit__(None, None, None)
|
| 400 |
+
except Exception as e:
|
| 401 |
+
log.debug(f"Error closing stream context: {e}")
|
| 402 |
+
try:
|
| 403 |
+
await client.aclose()
|
| 404 |
+
except Exception as e:
|
| 405 |
+
log.debug(f"Error closing client: {e}")
|
| 406 |
+
|
| 407 |
+
return StreamingResponse(
|
| 408 |
+
managed_stream_generator(),
|
| 409 |
+
media_type="text/event-stream"
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
async def _handle_non_streaming_response(resp, credential_manager: CredentialManager = None, model_name: str = "", current_file: str = None) -> Response:
|
| 413 |
+
"""Handle non-streaming response from Google API."""
|
| 414 |
+
if resp.status_code == 200:
|
| 415 |
+
try:
|
| 416 |
+
# 记录成功响应
|
| 417 |
+
if current_file and credential_manager:
|
| 418 |
+
await credential_manager.record_api_call_result(current_file, True)
|
| 419 |
+
# 记录到使用统计
|
| 420 |
+
try:
|
| 421 |
+
await record_successful_call(current_file, model_name)
|
| 422 |
+
except Exception as e:
|
| 423 |
+
log.debug(f"Failed to record usage statistics: {e}")
|
| 424 |
+
|
| 425 |
+
raw = await resp.aread()
|
| 426 |
+
google_api_response = raw.decode('utf-8')
|
| 427 |
+
if google_api_response.startswith('data: '):
|
| 428 |
+
google_api_response = google_api_response[len('data: '):]
|
| 429 |
+
google_api_response = json.loads(google_api_response)
|
| 430 |
+
log.debug(f"Google API原始响应: {json.dumps(google_api_response, ensure_ascii=False)[:500]}...")
|
| 431 |
+
standard_gemini_response = google_api_response.get("response")
|
| 432 |
+
log.debug(f"提取的response字段: {json.dumps(standard_gemini_response, ensure_ascii=False)[:500]}...")
|
| 433 |
+
return Response(
|
| 434 |
+
content=json.dumps(standard_gemini_response),
|
| 435 |
+
status_code=200,
|
| 436 |
+
media_type="application/json; charset=utf-8"
|
| 437 |
+
)
|
| 438 |
+
except Exception as e:
|
| 439 |
+
log.error(f"Failed to parse Google API response: {str(e)}")
|
| 440 |
+
return Response(
|
| 441 |
+
content=resp.content,
|
| 442 |
+
status_code=resp.status_code,
|
| 443 |
+
media_type=resp.headers.get("Content-Type")
|
| 444 |
+
)
|
| 445 |
+
else:
|
| 446 |
+
# 获取响应内容用于详细错误显示
|
| 447 |
+
response_content = ""
|
| 448 |
+
try:
|
| 449 |
+
if hasattr(resp, 'content'):
|
| 450 |
+
content = resp.content
|
| 451 |
+
if isinstance(content, bytes):
|
| 452 |
+
response_content = content.decode('utf-8', errors='ignore')
|
| 453 |
+
else:
|
| 454 |
+
content_bytes = await resp.aread()
|
| 455 |
+
if isinstance(content_bytes, bytes):
|
| 456 |
+
response_content = content_bytes.decode('utf-8', errors='ignore')
|
| 457 |
+
except Exception as e:
|
| 458 |
+
log.debug(f"[NON-STREAMING] Failed to read response content for error analysis: {e}")
|
| 459 |
+
response_content = ""
|
| 460 |
+
|
| 461 |
+
# 显示详细错误信息
|
| 462 |
+
if resp.status_code == 429:
|
| 463 |
+
if response_content:
|
| 464 |
+
log.error(f"Google API returned status 429 (NON-STREAMING). Response details: {response_content[:500]}")
|
| 465 |
+
else:
|
| 466 |
+
log.error(f"Google API returned status 429 (NON-STREAMING)")
|
| 467 |
+
else:
|
| 468 |
+
if response_content:
|
| 469 |
+
log.error(f"Google API returned status {resp.status_code} (NON-STREAMING). Response details: {response_content[:500]}")
|
| 470 |
+
else:
|
| 471 |
+
log.error(f"Google API returned status {resp.status_code} (NON-STREAMING)")
|
| 472 |
+
|
| 473 |
+
# 记录API调用错误
|
| 474 |
+
if credential_manager and current_file:
|
| 475 |
+
await credential_manager.record_api_call_result(current_file, False, resp.status_code)
|
| 476 |
+
|
| 477 |
+
await _handle_api_error(credential_manager, resp.status_code, response_content)
|
| 478 |
+
|
| 479 |
+
return _create_error_response(f"API error: {resp.status_code}", resp.status_code)
|
| 480 |
+
|
| 481 |
+
def build_gemini_payload_from_native(native_request: dict, model_from_path: str) -> dict:
|
| 482 |
+
"""
|
| 483 |
+
Build a Gemini API payload from a native Gemini request with full pass-through support.
|
| 484 |
+
"""
|
| 485 |
+
# 创建请求副本以避免修改原始数据
|
| 486 |
+
request_data = native_request.copy()
|
| 487 |
+
|
| 488 |
+
# 应用默认安全设置(如果未指定)
|
| 489 |
+
if "safetySettings" not in request_data:
|
| 490 |
+
request_data["safetySettings"] = DEFAULT_SAFETY_SETTINGS
|
| 491 |
+
|
| 492 |
+
# 确保generationConfig存在
|
| 493 |
+
if "generationConfig" not in request_data:
|
| 494 |
+
request_data["generationConfig"] = {}
|
| 495 |
+
|
| 496 |
+
generation_config = request_data["generationConfig"]
|
| 497 |
+
|
| 498 |
+
# 配置thinking(如果未指定thinkingConfig)
|
| 499 |
+
if "thinkingConfig" not in generation_config:
|
| 500 |
+
generation_config["thinkingConfig"] = {}
|
| 501 |
+
|
| 502 |
+
thinking_config = generation_config["thinkingConfig"]
|
| 503 |
+
|
| 504 |
+
# 只有在未明确设置时才应用默认thinking配置
|
| 505 |
+
if "includeThoughts" not in thinking_config:
|
| 506 |
+
thinking_config["includeThoughts"] = should_include_thoughts(model_from_path)
|
| 507 |
+
if "thinkingBudget" not in thinking_config:
|
| 508 |
+
thinking_config["thinkingBudget"] = get_thinking_budget(model_from_path)
|
| 509 |
+
|
| 510 |
+
# 为搜索模型添加Google Search工具(如果未指定且没有functionDeclarations)
|
| 511 |
+
if is_search_model(model_from_path):
|
| 512 |
+
if "tools" not in request_data:
|
| 513 |
+
request_data["tools"] = []
|
| 514 |
+
# 检查是否已有functionDeclarations或googleSearch工具
|
| 515 |
+
has_function_declarations = any(tool.get("functionDeclarations") for tool in request_data["tools"])
|
| 516 |
+
has_google_search = any(tool.get("googleSearch") for tool in request_data["tools"])
|
| 517 |
+
|
| 518 |
+
# 只有在没有任何工具时才添加googleSearch,或者只有googleSearch工具时可以添加更多googleSearch
|
| 519 |
+
if not has_function_declarations and not has_google_search:
|
| 520 |
+
request_data["tools"].append({"googleSearch": {}})
|
| 521 |
+
|
| 522 |
+
# 透传所有其他Gemini原生字段:
|
| 523 |
+
# - contents (必需)
|
| 524 |
+
# - systemInstruction (可选)
|
| 525 |
+
# - generationConfig (已处理)
|
| 526 |
+
# - safetySettings (已处理)
|
| 527 |
+
# - tools (已处理)
|
| 528 |
+
# - toolConfig (透传)
|
| 529 |
+
# - cachedContent (透传)
|
| 530 |
+
# - 以及任何其他未知字段都会被透传
|
| 531 |
+
|
| 532 |
+
return {
|
| 533 |
+
"model": get_base_model_name(model_from_path),
|
| 534 |
+
"request": request_data
|
| 535 |
+
}
|
src/google_oauth_api.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Google OAuth2 认证模块
|
| 3 |
+
"""
|
| 4 |
+
import time
|
| 5 |
+
import jwt
|
| 6 |
+
import asyncio
|
| 7 |
+
from datetime import datetime, timezone, timedelta
|
| 8 |
+
from typing import Optional, Dict, Any, List
|
| 9 |
+
from urllib.parse import urlencode
|
| 10 |
+
|
| 11 |
+
from config import get_oauth_proxy_url, get_googleapis_proxy_url, get_resource_manager_api_url, get_service_usage_api_url
|
| 12 |
+
from log import log
|
| 13 |
+
from .httpx_client import get_async, post_async
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TokenError(Exception):
|
| 17 |
+
"""Token相关错误"""
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
class Credentials:
|
| 21 |
+
"""凭证类"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, access_token: str, refresh_token: str = None,
|
| 24 |
+
client_id: str = None, client_secret: str = None,
|
| 25 |
+
expires_at: datetime = None, project_id: str = None):
|
| 26 |
+
self.access_token = access_token
|
| 27 |
+
self.refresh_token = refresh_token
|
| 28 |
+
self.client_id = client_id
|
| 29 |
+
self.client_secret = client_secret
|
| 30 |
+
self.expires_at = expires_at
|
| 31 |
+
self.project_id = project_id
|
| 32 |
+
|
| 33 |
+
# 反代配置将在使用时异步获取
|
| 34 |
+
self.oauth_base_url = None
|
| 35 |
+
self.token_endpoint = None
|
| 36 |
+
|
| 37 |
+
def is_expired(self) -> bool:
|
| 38 |
+
"""检查token是否过期"""
|
| 39 |
+
if not self.expires_at:
|
| 40 |
+
return True
|
| 41 |
+
|
| 42 |
+
# 提前3分钟认为过期
|
| 43 |
+
buffer = timedelta(minutes=3)
|
| 44 |
+
return (self.expires_at - buffer) <= datetime.now(timezone.utc)
|
| 45 |
+
|
| 46 |
+
async def refresh_if_needed(self) -> bool:
|
| 47 |
+
"""如果需要则刷新token"""
|
| 48 |
+
if not self.is_expired():
|
| 49 |
+
return False
|
| 50 |
+
|
| 51 |
+
if not self.refresh_token:
|
| 52 |
+
raise TokenError("需要刷新令牌但未提供")
|
| 53 |
+
|
| 54 |
+
await self.refresh()
|
| 55 |
+
return True
|
| 56 |
+
|
| 57 |
+
async def refresh(self, max_retries: int = 3, base_delay: float = 1.0):
|
| 58 |
+
"""刷新访问令牌,支持重试机制"""
|
| 59 |
+
if not self.refresh_token:
|
| 60 |
+
raise TokenError("无刷新令牌")
|
| 61 |
+
|
| 62 |
+
data = {
|
| 63 |
+
'client_id': self.client_id,
|
| 64 |
+
'client_secret': self.client_secret,
|
| 65 |
+
'refresh_token': self.refresh_token,
|
| 66 |
+
'grant_type': 'refresh_token'
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
last_exception = None
|
| 70 |
+
for attempt in range(max_retries + 1):
|
| 71 |
+
try:
|
| 72 |
+
oauth_base_url = await get_oauth_proxy_url()
|
| 73 |
+
token_url = f"{oauth_base_url.rstrip('/')}/token"
|
| 74 |
+
response = await post_async(
|
| 75 |
+
token_url,
|
| 76 |
+
data=data,
|
| 77 |
+
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
| 78 |
+
)
|
| 79 |
+
response.raise_for_status()
|
| 80 |
+
|
| 81 |
+
token_data = response.json()
|
| 82 |
+
self.access_token = token_data['access_token']
|
| 83 |
+
|
| 84 |
+
if 'expires_in' in token_data:
|
| 85 |
+
expires_in = int(token_data['expires_in'])
|
| 86 |
+
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
| 87 |
+
|
| 88 |
+
if 'refresh_token' in token_data:
|
| 89 |
+
self.refresh_token = token_data['refresh_token']
|
| 90 |
+
|
| 91 |
+
if attempt > 0:
|
| 92 |
+
log.debug(f"Token刷新成功(第{attempt + 1}次尝试),过期时间: {self.expires_at}")
|
| 93 |
+
else:
|
| 94 |
+
log.debug(f"Token刷新成功,过期时间: {self.expires_at}")
|
| 95 |
+
return
|
| 96 |
+
|
| 97 |
+
except Exception as e:
|
| 98 |
+
last_exception = e
|
| 99 |
+
error_msg = str(e)
|
| 100 |
+
|
| 101 |
+
# 检查是否是不可恢复的错误,如果是则不重试
|
| 102 |
+
if self._is_non_retryable_error(error_msg):
|
| 103 |
+
log.error(f"Token刷新遇到不可恢复错误: {error_msg}")
|
| 104 |
+
break
|
| 105 |
+
|
| 106 |
+
if attempt < max_retries:
|
| 107 |
+
# 计算退避延迟时间(指数退避)
|
| 108 |
+
delay = base_delay * (2 ** attempt)
|
| 109 |
+
log.warning(f"Token刷新失败(第{attempt + 1}次尝试): {error_msg},{delay}秒后重试...")
|
| 110 |
+
await asyncio.sleep(delay)
|
| 111 |
+
else:
|
| 112 |
+
break
|
| 113 |
+
|
| 114 |
+
# 所有重试都失败了
|
| 115 |
+
error_msg = f"Token刷新失败(已重试{max_retries}次): {str(last_exception)}"
|
| 116 |
+
log.error(error_msg)
|
| 117 |
+
raise TokenError(error_msg)
|
| 118 |
+
|
| 119 |
+
def _is_non_retryable_error(self, error_msg: str) -> bool:
|
| 120 |
+
"""判断是否是不需要重试的错误"""
|
| 121 |
+
non_retryable_patterns = [
|
| 122 |
+
"400 Bad Request",
|
| 123 |
+
"invalid_grant",
|
| 124 |
+
"refresh_token_expired",
|
| 125 |
+
"invalid_refresh_token",
|
| 126 |
+
"unauthorized_client",
|
| 127 |
+
"access_denied",
|
| 128 |
+
"401 Unauthorized"
|
| 129 |
+
]
|
| 130 |
+
|
| 131 |
+
error_msg_lower = error_msg.lower()
|
| 132 |
+
for pattern in non_retryable_patterns:
|
| 133 |
+
if pattern.lower() in error_msg_lower:
|
| 134 |
+
return True
|
| 135 |
+
|
| 136 |
+
return False
|
| 137 |
+
|
| 138 |
+
@classmethod
|
| 139 |
+
def from_dict(cls, data: Dict[str, Any]) -> 'Credentials':
|
| 140 |
+
"""从字典创建凭证"""
|
| 141 |
+
# 处理过期时间
|
| 142 |
+
expires_at = None
|
| 143 |
+
if 'expiry' in data and data['expiry']:
|
| 144 |
+
try:
|
| 145 |
+
expiry_str = data['expiry']
|
| 146 |
+
if isinstance(expiry_str, str):
|
| 147 |
+
if expiry_str.endswith('Z'):
|
| 148 |
+
expires_at = datetime.fromisoformat(expiry_str.replace('Z', '+00:00'))
|
| 149 |
+
elif '+' in expiry_str:
|
| 150 |
+
expires_at = datetime.fromisoformat(expiry_str)
|
| 151 |
+
else:
|
| 152 |
+
expires_at = datetime.fromisoformat(expiry_str).replace(tzinfo=timezone.utc)
|
| 153 |
+
except ValueError:
|
| 154 |
+
log.warning(f"无法解析过期时间: {expiry_str}")
|
| 155 |
+
|
| 156 |
+
return cls(
|
| 157 |
+
access_token=data.get('token') or data.get('access_token', ''),
|
| 158 |
+
refresh_token=data.get('refresh_token'),
|
| 159 |
+
client_id=data.get('client_id'),
|
| 160 |
+
client_secret=data.get('client_secret'),
|
| 161 |
+
expires_at=expires_at,
|
| 162 |
+
project_id=data.get('project_id')
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 166 |
+
"""转为字典"""
|
| 167 |
+
result = {
|
| 168 |
+
'access_token': self.access_token,
|
| 169 |
+
'refresh_token': self.refresh_token,
|
| 170 |
+
'client_id': self.client_id,
|
| 171 |
+
'client_secret': self.client_secret,
|
| 172 |
+
'project_id': self.project_id
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
if self.expires_at:
|
| 176 |
+
result['expiry'] = self.expires_at.isoformat()
|
| 177 |
+
|
| 178 |
+
return result
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class Flow:
|
| 182 |
+
"""OAuth流程类"""
|
| 183 |
+
|
| 184 |
+
def __init__(self, client_id: str, client_secret: str, scopes: List[str],
|
| 185 |
+
redirect_uri: str = None):
|
| 186 |
+
self.client_id = client_id
|
| 187 |
+
self.client_secret = client_secret
|
| 188 |
+
self.scopes = scopes
|
| 189 |
+
self.redirect_uri = redirect_uri
|
| 190 |
+
|
| 191 |
+
# 反代配置将在使用时异步获取
|
| 192 |
+
self.oauth_base_url = None
|
| 193 |
+
self.token_endpoint = None
|
| 194 |
+
self.auth_endpoint = "https://accounts.google.com/o/oauth2/auth"
|
| 195 |
+
|
| 196 |
+
self.credentials: Optional[Credentials] = None
|
| 197 |
+
|
| 198 |
+
def get_auth_url(self, state: str = None, **kwargs) -> str:
|
| 199 |
+
"""生成授权URL"""
|
| 200 |
+
params = {
|
| 201 |
+
'client_id': self.client_id,
|
| 202 |
+
'redirect_uri': self.redirect_uri,
|
| 203 |
+
'scope': ' '.join(self.scopes),
|
| 204 |
+
'response_type': 'code',
|
| 205 |
+
'access_type': 'offline',
|
| 206 |
+
'prompt': 'consent',
|
| 207 |
+
'include_granted_scopes': 'true'
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
if state:
|
| 211 |
+
params['state'] = state
|
| 212 |
+
|
| 213 |
+
params.update(kwargs)
|
| 214 |
+
return f"{self.auth_endpoint}?{urlencode(params)}"
|
| 215 |
+
|
| 216 |
+
async def exchange_code(self, code: str) -> Credentials:
|
| 217 |
+
"""用授权码换取token"""
|
| 218 |
+
data = {
|
| 219 |
+
'client_id': self.client_id,
|
| 220 |
+
'client_secret': self.client_secret,
|
| 221 |
+
'redirect_uri': self.redirect_uri,
|
| 222 |
+
'code': code,
|
| 223 |
+
'grant_type': 'authorization_code'
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
try:
|
| 227 |
+
oauth_base_url = await get_oauth_proxy_url()
|
| 228 |
+
token_url = f"{oauth_base_url.rstrip('/')}/token"
|
| 229 |
+
response = await post_async(
|
| 230 |
+
token_url,
|
| 231 |
+
data=data,
|
| 232 |
+
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
| 233 |
+
)
|
| 234 |
+
response.raise_for_status()
|
| 235 |
+
|
| 236 |
+
token_data = response.json()
|
| 237 |
+
|
| 238 |
+
# 计算过期时间
|
| 239 |
+
expires_at = None
|
| 240 |
+
if 'expires_in' in token_data:
|
| 241 |
+
expires_in = int(token_data['expires_in'])
|
| 242 |
+
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
| 243 |
+
|
| 244 |
+
# 创建凭证对象
|
| 245 |
+
self.credentials = Credentials(
|
| 246 |
+
access_token=token_data['access_token'],
|
| 247 |
+
refresh_token=token_data.get('refresh_token'),
|
| 248 |
+
client_id=self.client_id,
|
| 249 |
+
client_secret=self.client_secret,
|
| 250 |
+
expires_at=expires_at
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
return self.credentials
|
| 254 |
+
|
| 255 |
+
except Exception as e:
|
| 256 |
+
error_msg = f"获取token失败: {str(e)}"
|
| 257 |
+
log.error(error_msg)
|
| 258 |
+
raise TokenError(error_msg)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class ServiceAccount:
|
| 262 |
+
"""Service Account类"""
|
| 263 |
+
|
| 264 |
+
def __init__(self, email: str, private_key: str, project_id: str = None,
|
| 265 |
+
scopes: List[str] = None):
|
| 266 |
+
self.email = email
|
| 267 |
+
self.private_key = private_key
|
| 268 |
+
self.project_id = project_id
|
| 269 |
+
self.scopes = scopes or []
|
| 270 |
+
|
| 271 |
+
# 反代配置将在使用时异步获取
|
| 272 |
+
self.oauth_base_url = None
|
| 273 |
+
self.token_endpoint = None
|
| 274 |
+
|
| 275 |
+
self.access_token: Optional[str] = None
|
| 276 |
+
self.expires_at: Optional[datetime] = None
|
| 277 |
+
|
| 278 |
+
def is_expired(self) -> bool:
|
| 279 |
+
"""检查token是否过期"""
|
| 280 |
+
if not self.expires_at:
|
| 281 |
+
return True
|
| 282 |
+
|
| 283 |
+
buffer = timedelta(minutes=3)
|
| 284 |
+
return (self.expires_at - buffer) <= datetime.now(timezone.utc)
|
| 285 |
+
|
| 286 |
+
def create_jwt(self) -> str:
|
| 287 |
+
"""创建JWT令牌"""
|
| 288 |
+
now = int(time.time())
|
| 289 |
+
|
| 290 |
+
payload = {
|
| 291 |
+
'iss': self.email,
|
| 292 |
+
'scope': ' '.join(self.scopes) if self.scopes else '',
|
| 293 |
+
'aud': self.token_endpoint,
|
| 294 |
+
'exp': now + 3600,
|
| 295 |
+
'iat': now
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
return jwt.encode(payload, self.private_key, algorithm='RS256')
|
| 299 |
+
|
| 300 |
+
async def get_access_token(self) -> str:
|
| 301 |
+
"""获取访问令牌"""
|
| 302 |
+
if not self.is_expired() and self.access_token:
|
| 303 |
+
return self.access_token
|
| 304 |
+
|
| 305 |
+
assertion = self.create_jwt()
|
| 306 |
+
|
| 307 |
+
data = {
|
| 308 |
+
'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer',
|
| 309 |
+
'assertion': assertion
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
try:
|
| 313 |
+
oauth_base_url = await get_oauth_proxy_url()
|
| 314 |
+
token_url = f"{oauth_base_url.rstrip('/')}/token"
|
| 315 |
+
response = await post_async(
|
| 316 |
+
token_url,
|
| 317 |
+
data=data,
|
| 318 |
+
headers={'Content-Type': 'application/x-www-form-urlencoded'}
|
| 319 |
+
)
|
| 320 |
+
response.raise_for_status()
|
| 321 |
+
|
| 322 |
+
token_data = response.json()
|
| 323 |
+
self.access_token = token_data['access_token']
|
| 324 |
+
|
| 325 |
+
if 'expires_in' in token_data:
|
| 326 |
+
expires_in = int(token_data['expires_in'])
|
| 327 |
+
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
| 328 |
+
|
| 329 |
+
return self.access_token
|
| 330 |
+
|
| 331 |
+
except Exception as e:
|
| 332 |
+
error_msg = f"Service Account获取token失败: {str(e)}"
|
| 333 |
+
log.error(error_msg)
|
| 334 |
+
raise TokenError(error_msg)
|
| 335 |
+
|
| 336 |
+
@classmethod
|
| 337 |
+
def from_dict(cls, data: Dict[str, Any], scopes: List[str] = None) -> 'ServiceAccount':
|
| 338 |
+
"""从字典创建Service Account凭证"""
|
| 339 |
+
return cls(
|
| 340 |
+
email=data['client_email'],
|
| 341 |
+
private_key=data['private_key'],
|
| 342 |
+
project_id=data.get('project_id'),
|
| 343 |
+
scopes=scopes
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
# 工具函数
|
| 348 |
+
async def get_user_info(credentials: Credentials) -> Optional[Dict[str, Any]]:
|
| 349 |
+
"""获取用户信息"""
|
| 350 |
+
await credentials.refresh_if_needed()
|
| 351 |
+
|
| 352 |
+
try:
|
| 353 |
+
googleapis_base_url = await get_googleapis_proxy_url()
|
| 354 |
+
userinfo_url = f"{googleapis_base_url.rstrip('/')}/oauth2/v2/userinfo"
|
| 355 |
+
response = await get_async(
|
| 356 |
+
userinfo_url,
|
| 357 |
+
headers={'Authorization': f'Bearer {credentials.access_token}'}
|
| 358 |
+
)
|
| 359 |
+
response.raise_for_status()
|
| 360 |
+
return response.json()
|
| 361 |
+
except Exception as e:
|
| 362 |
+
log.error(f"获取用户信息失败: {e}")
|
| 363 |
+
return None
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
async def get_user_email(credentials: Credentials) -> Optional[str]:
|
| 367 |
+
"""获取用户邮箱地址"""
|
| 368 |
+
try:
|
| 369 |
+
# 确保凭证有效
|
| 370 |
+
await credentials.refresh_if_needed()
|
| 371 |
+
|
| 372 |
+
# 调用Google userinfo API获取邮箱
|
| 373 |
+
user_info = await get_user_info(credentials)
|
| 374 |
+
if user_info:
|
| 375 |
+
email = user_info.get("email")
|
| 376 |
+
if email:
|
| 377 |
+
log.info(f"成功获取邮箱地址: {email}")
|
| 378 |
+
return email
|
| 379 |
+
else:
|
| 380 |
+
log.warning(f"userinfo响应中没有邮箱信息: {user_info}")
|
| 381 |
+
return None
|
| 382 |
+
else:
|
| 383 |
+
log.warning("获取用户信息失败")
|
| 384 |
+
return None
|
| 385 |
+
|
| 386 |
+
except Exception as e:
|
| 387 |
+
log.error(f"获取用户邮箱失败: {e}")
|
| 388 |
+
return None
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
async def fetch_user_email_from_file(cred_data: Dict[str, Any]) -> Optional[str]:
|
| 392 |
+
"""从凭证数据获取用户邮箱地址(支持统一存储)"""
|
| 393 |
+
try:
|
| 394 |
+
# 直接从凭证数据创建凭证对象
|
| 395 |
+
credentials = Credentials.from_dict(cred_data)
|
| 396 |
+
if not credentials or not credentials.access_token:
|
| 397 |
+
log.warning(f"无法从凭证数据创建凭证对象或获取访问令牌")
|
| 398 |
+
return None
|
| 399 |
+
|
| 400 |
+
# 获取邮箱
|
| 401 |
+
return await get_user_email(credentials)
|
| 402 |
+
|
| 403 |
+
except Exception as e:
|
| 404 |
+
log.error(f"从凭证数据获取用户邮箱失败: {e}")
|
| 405 |
+
return None
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
async def validate_token(token: str) -> Optional[Dict[str, Any]]:
|
| 409 |
+
"""验证访问令牌"""
|
| 410 |
+
try:
|
| 411 |
+
oauth_base_url = await get_oauth_proxy_url()
|
| 412 |
+
tokeninfo_url = f"{oauth_base_url.rstrip('/')}/tokeninfo?access_token={token}"
|
| 413 |
+
|
| 414 |
+
response = await get_async(tokeninfo_url)
|
| 415 |
+
response.raise_for_status()
|
| 416 |
+
return response.json()
|
| 417 |
+
except Exception as e:
|
| 418 |
+
log.error(f"验证令牌失败: {e}")
|
| 419 |
+
return None
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
async def enable_required_apis(credentials: Credentials, project_id: str) -> bool:
|
| 423 |
+
"""自动启用必需的API服务"""
|
| 424 |
+
try:
|
| 425 |
+
# 确保凭证有效
|
| 426 |
+
if credentials.is_expired() and credentials.refresh_token:
|
| 427 |
+
await credentials.refresh()
|
| 428 |
+
|
| 429 |
+
headers = {
|
| 430 |
+
"Authorization": f"Bearer {credentials.access_token}",
|
| 431 |
+
"Content-Type": "application/json",
|
| 432 |
+
"User-Agent": "geminicli-oauth/1.0",
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
# 需要启用的服务列表
|
| 436 |
+
required_services = [
|
| 437 |
+
"geminicloudassist.googleapis.com", # Gemini Cloud Assist API
|
| 438 |
+
"cloudaicompanion.googleapis.com" # Gemini for Google Cloud API
|
| 439 |
+
]
|
| 440 |
+
|
| 441 |
+
for service in required_services:
|
| 442 |
+
log.info(f"正在检查并启用服务: {service}")
|
| 443 |
+
|
| 444 |
+
# 检查服务是否已启用
|
| 445 |
+
service_usage_base_url = await get_service_usage_api_url()
|
| 446 |
+
check_url = f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}"
|
| 447 |
+
try:
|
| 448 |
+
check_response = await get_async(check_url, headers=headers)
|
| 449 |
+
if check_response.status_code == 200:
|
| 450 |
+
service_data = check_response.json()
|
| 451 |
+
if service_data.get("state") == "ENABLED":
|
| 452 |
+
log.info(f"服务 {service} 已启用")
|
| 453 |
+
continue
|
| 454 |
+
except Exception as e:
|
| 455 |
+
log.debug(f"检查服务状态失败,将尝试启用: {e}")
|
| 456 |
+
|
| 457 |
+
# 启用服务
|
| 458 |
+
enable_url = f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}:enable"
|
| 459 |
+
try:
|
| 460 |
+
enable_response = await post_async(enable_url, headers=headers, json={})
|
| 461 |
+
|
| 462 |
+
if enable_response.status_code in [200, 201]:
|
| 463 |
+
log.info(f"✅ 成功启用服务: {service}")
|
| 464 |
+
elif enable_response.status_code == 400:
|
| 465 |
+
error_data = enable_response.json()
|
| 466 |
+
if "already enabled" in error_data.get("error", {}).get("message", "").lower():
|
| 467 |
+
log.info(f"✅ 服务 {service} 已经启用")
|
| 468 |
+
else:
|
| 469 |
+
log.warning(f"⚠️ 启用服务 {service} 时出现警告: {error_data}")
|
| 470 |
+
else:
|
| 471 |
+
log.warning(f"⚠️ 启用服务 {service} 失败: {enable_response.status_code} - {enable_response.text}")
|
| 472 |
+
|
| 473 |
+
except Exception as e:
|
| 474 |
+
log.warning(f"⚠️ 启用服务 {service} 时发生异常: {e}")
|
| 475 |
+
|
| 476 |
+
return True
|
| 477 |
+
|
| 478 |
+
except Exception as e:
|
| 479 |
+
log.error(f"启用API服务时发生错误: {e}")
|
| 480 |
+
return False
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
async def get_user_projects(credentials: Credentials) -> List[Dict[str, Any]]:
|
| 484 |
+
"""获取用户可访问的Google Cloud项目列表"""
|
| 485 |
+
try:
|
| 486 |
+
# 确保凭证有效
|
| 487 |
+
if credentials.is_expired() and credentials.refresh_token:
|
| 488 |
+
await credentials.refresh()
|
| 489 |
+
|
| 490 |
+
headers = {
|
| 491 |
+
"Authorization": f"Bearer {credentials.access_token}",
|
| 492 |
+
"User-Agent": "geminicli-oauth/1.0",
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
# 使用Resource Manager API的正确域名和端点
|
| 496 |
+
resource_manager_base_url = await get_resource_manager_api_url()
|
| 497 |
+
url = f"{resource_manager_base_url.rstrip('/')}/v1/projects"
|
| 498 |
+
log.info(f"正在调用API: {url}")
|
| 499 |
+
response = await get_async(url, headers=headers)
|
| 500 |
+
|
| 501 |
+
log.info(f"API响应状态码: {response.status_code}")
|
| 502 |
+
if response.status_code != 200:
|
| 503 |
+
log.error(f"API响应内容: {response.text}")
|
| 504 |
+
|
| 505 |
+
if response.status_code == 200:
|
| 506 |
+
data = response.json()
|
| 507 |
+
projects = data.get('projects', [])
|
| 508 |
+
# 只返回活跃的项目
|
| 509 |
+
active_projects = [
|
| 510 |
+
project for project in projects
|
| 511 |
+
if project.get('lifecycleState') == 'ACTIVE'
|
| 512 |
+
]
|
| 513 |
+
log.info(f"获取到 {len(active_projects)} 个活跃项目")
|
| 514 |
+
return active_projects
|
| 515 |
+
else:
|
| 516 |
+
log.warning(f"获取项目列表失败: {response.status_code} - {response.text}")
|
| 517 |
+
return []
|
| 518 |
+
|
| 519 |
+
except Exception as e:
|
| 520 |
+
log.error(f"获取用户项目列表失败: {e}")
|
| 521 |
+
return []
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
async def select_default_project(projects: List[Dict[str, Any]]) -> Optional[str]:
|
| 527 |
+
"""从项目列表中选择默认项目"""
|
| 528 |
+
if not projects:
|
| 529 |
+
return None
|
| 530 |
+
|
| 531 |
+
# 策略1:查找显示名称或项目ID包含"default"的项目
|
| 532 |
+
for project in projects:
|
| 533 |
+
display_name = project.get('displayName', '').lower()
|
| 534 |
+
project_id = project.get('projectId', '')
|
| 535 |
+
if 'default' in display_name or 'default' in project_id.lower():
|
| 536 |
+
log.info(f"选择默认项目: {project_id} ({project.get('displayName', project_id)})")
|
| 537 |
+
return project_id
|
| 538 |
+
|
| 539 |
+
# 策略2:选择第一个项目
|
| 540 |
+
first_project = projects[0]
|
| 541 |
+
project_id = first_project.get('projectId', '')
|
| 542 |
+
log.info(f"选择第一个项目作为默认: {project_id} ({first_project.get('displayName', project_id)})")
|
| 543 |
+
return project_id
|
src/httpx_client.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
通用的HTTP客户端模块
|
| 3 |
+
为所有需要使用httpx的模块提供统一的客户端配置和方法
|
| 4 |
+
保持通用性,不与特定业务逻辑耦合
|
| 5 |
+
"""
|
| 6 |
+
import httpx
|
| 7 |
+
from typing import Optional, Dict, Any, AsyncGenerator
|
| 8 |
+
from contextlib import asynccontextmanager
|
| 9 |
+
|
| 10 |
+
from config import get_proxy_config
|
| 11 |
+
from log import log
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class HttpxClientManager:
|
| 15 |
+
"""通用HTTP客户端管理器"""
|
| 16 |
+
|
| 17 |
+
async def get_client_kwargs(self, timeout: float = 30.0, **kwargs) -> Dict[str, Any]:
|
| 18 |
+
"""获取httpx客户端的通用配置参数"""
|
| 19 |
+
client_kwargs = {
|
| 20 |
+
"timeout": timeout,
|
| 21 |
+
**kwargs
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
# 动态读取代理配置,支持热更新
|
| 25 |
+
current_proxy_config = await get_proxy_config()
|
| 26 |
+
if current_proxy_config:
|
| 27 |
+
client_kwargs["proxy"] = current_proxy_config
|
| 28 |
+
|
| 29 |
+
return client_kwargs
|
| 30 |
+
|
| 31 |
+
@asynccontextmanager
|
| 32 |
+
async def get_client(self, timeout: float = 30.0, **kwargs) -> AsyncGenerator[httpx.AsyncClient, None]:
|
| 33 |
+
"""获取配置好的异步HTTP客户端"""
|
| 34 |
+
client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs)
|
| 35 |
+
|
| 36 |
+
async with httpx.AsyncClient(**client_kwargs) as client:
|
| 37 |
+
yield client
|
| 38 |
+
|
| 39 |
+
@asynccontextmanager
|
| 40 |
+
async def get_streaming_client(self, timeout: float = None, **kwargs) -> AsyncGenerator[httpx.AsyncClient, None]:
|
| 41 |
+
"""获取用于流式请求的HTTP客户端(无超时限制)"""
|
| 42 |
+
client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs)
|
| 43 |
+
|
| 44 |
+
# 创建独立的客户端实例用于流式处理
|
| 45 |
+
client = httpx.AsyncClient(**client_kwargs)
|
| 46 |
+
try:
|
| 47 |
+
yield client
|
| 48 |
+
finally:
|
| 49 |
+
await client.aclose()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# 全局HTTP客户端管理器实例
|
| 53 |
+
http_client = HttpxClientManager()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# 通用的异步方法
|
| 57 |
+
async def get_async(url: str, headers: Optional[Dict[str, str]] = None,
|
| 58 |
+
timeout: float = 30.0, **kwargs) -> httpx.Response:
|
| 59 |
+
"""通用异步GET请求"""
|
| 60 |
+
async with http_client.get_client(timeout=timeout, **kwargs) as client:
|
| 61 |
+
return await client.get(url, headers=headers)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
async def post_async(url: str, data: Any = None, json: Any = None,
|
| 65 |
+
headers: Optional[Dict[str, str]] = None,
|
| 66 |
+
timeout: float = 30.0, **kwargs) -> httpx.Response:
|
| 67 |
+
"""通用异步POST请求"""
|
| 68 |
+
async with http_client.get_client(timeout=timeout, **kwargs) as client:
|
| 69 |
+
return await client.post(url, data=data, json=json, headers=headers)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
async def put_async(url: str, data: Any = None, json: Any = None,
|
| 73 |
+
headers: Optional[Dict[str, str]] = None,
|
| 74 |
+
timeout: float = 30.0, **kwargs) -> httpx.Response:
|
| 75 |
+
"""通用异步PUT请求"""
|
| 76 |
+
async with http_client.get_client(timeout=timeout, **kwargs) as client:
|
| 77 |
+
return await client.put(url, data=data, json=json, headers=headers)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
async def delete_async(url: str, headers: Optional[Dict[str, str]] = None,
|
| 81 |
+
timeout: float = 30.0, **kwargs) -> httpx.Response:
|
| 82 |
+
"""通用异步DELETE请求"""
|
| 83 |
+
async with http_client.get_client(timeout=timeout, **kwargs) as client:
|
| 84 |
+
return await client.delete(url, headers=headers)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# 错误处理装饰器
|
| 88 |
+
def handle_http_errors(func):
|
| 89 |
+
"""HTTP错误处理装饰器"""
|
| 90 |
+
async def wrapper(*args, **kwargs):
|
| 91 |
+
try:
|
| 92 |
+
response = await func(*args, **kwargs)
|
| 93 |
+
response.raise_for_status()
|
| 94 |
+
return response
|
| 95 |
+
except httpx.HTTPStatusError as e:
|
| 96 |
+
log.error(f"HTTP错误: {e.response.status_code} - {e.response.text}")
|
| 97 |
+
raise
|
| 98 |
+
except httpx.RequestError as e:
|
| 99 |
+
log.error(f"请求错误: {e}")
|
| 100 |
+
raise
|
| 101 |
+
except Exception as e:
|
| 102 |
+
log.error(f"未知错误: {e}")
|
| 103 |
+
raise
|
| 104 |
+
return wrapper
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# 应用错误处理的安全方法
|
| 108 |
+
@handle_http_errors
|
| 109 |
+
async def safe_get_async(url: str, headers: Optional[Dict[str, str]] = None,
|
| 110 |
+
timeout: float = 30.0, **kwargs) -> httpx.Response:
|
| 111 |
+
"""安全的异步GET请求(自动错误处理)"""
|
| 112 |
+
return await get_async(url, headers=headers, timeout=timeout, **kwargs)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@handle_http_errors
|
| 116 |
+
async def safe_post_async(url: str, data: Any = None, json: Any = None,
|
| 117 |
+
headers: Optional[Dict[str, str]] = None,
|
| 118 |
+
timeout: float = 30.0, **kwargs) -> httpx.Response:
|
| 119 |
+
"""安全的异步POST请求(自动错误处理)"""
|
| 120 |
+
return await post_async(url, data=data, json=json, headers=headers, timeout=timeout, **kwargs)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@handle_http_errors
|
| 124 |
+
async def safe_put_async(url: str, data: Any = None, json: Any = None,
|
| 125 |
+
headers: Optional[Dict[str, str]] = None,
|
| 126 |
+
timeout: float = 30.0, **kwargs) -> httpx.Response:
|
| 127 |
+
"""安全的异步PUT请求(自动错误处理)"""
|
| 128 |
+
return await put_async(url, data=data, json=json, headers=headers, timeout=timeout, **kwargs)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@handle_http_errors
|
| 132 |
+
async def safe_delete_async(url: str, headers: Optional[Dict[str, str]] = None,
|
| 133 |
+
timeout: float = 30.0, **kwargs) -> httpx.Response:
|
| 134 |
+
"""安全的异步DELETE请求(自动错误处理)"""
|
| 135 |
+
return await delete_async(url, headers=headers, timeout=timeout, **kwargs)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# 流式请求支持
|
| 139 |
+
class StreamingContext:
|
| 140 |
+
"""流式请求上下文管理器"""
|
| 141 |
+
|
| 142 |
+
def __init__(self, client: httpx.AsyncClient, stream_context):
|
| 143 |
+
self.client = client
|
| 144 |
+
self.stream_context = stream_context
|
| 145 |
+
self.response = None
|
| 146 |
+
|
| 147 |
+
async def __aenter__(self):
|
| 148 |
+
self.response = await self.stream_context.__aenter__()
|
| 149 |
+
return self.response
|
| 150 |
+
|
| 151 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
| 152 |
+
try:
|
| 153 |
+
if self.stream_context:
|
| 154 |
+
await self.stream_context.__aexit__(exc_type, exc_val, exc_tb)
|
| 155 |
+
finally:
|
| 156 |
+
if self.client:
|
| 157 |
+
await self.client.aclose()
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@asynccontextmanager
|
| 161 |
+
async def get_streaming_post_context(url: str, data: Any = None, json: Any = None,
|
| 162 |
+
headers: Optional[Dict[str, str]] = None,
|
| 163 |
+
timeout: float = None, **kwargs) -> AsyncGenerator[StreamingContext, None]:
|
| 164 |
+
"""获取流式POST请求的上下文管理器"""
|
| 165 |
+
async with http_client.get_streaming_client(timeout=timeout, **kwargs) as client:
|
| 166 |
+
stream_ctx = client.stream("POST", url, data=data, json=json, headers=headers)
|
| 167 |
+
streaming_context = StreamingContext(client, stream_ctx)
|
| 168 |
+
yield streaming_context
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
async def create_streaming_client_with_kwargs(**kwargs) -> httpx.AsyncClient:
|
| 172 |
+
"""创建用于流式处理的独立客户端实例(手动管理生命周期)"""
|
| 173 |
+
client_kwargs = await http_client.get_client_kwargs(timeout=None, **kwargs)
|
| 174 |
+
return httpx.AsyncClient(**client_kwargs)
|
src/models.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union, Dict, Any
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
|
| 5 |
+
# Common Models
|
| 6 |
+
class Model(BaseModel):
|
| 7 |
+
id: str
|
| 8 |
+
object: str = "model"
|
| 9 |
+
created: Optional[int] = None
|
| 10 |
+
owned_by: Optional[str] = "google"
|
| 11 |
+
|
| 12 |
+
class ModelList(BaseModel):
|
| 13 |
+
object: str = "list"
|
| 14 |
+
data: List[Model]
|
| 15 |
+
|
| 16 |
+
# OpenAI Models
|
| 17 |
+
class OpenAIChatMessage(BaseModel):
|
| 18 |
+
role: str
|
| 19 |
+
content: Union[str, List[Dict[str, Any]], None] = None
|
| 20 |
+
reasoning_content: Optional[str] = None
|
| 21 |
+
name: Optional[str] = None
|
| 22 |
+
|
| 23 |
+
class OpenAIChatCompletionRequest(BaseModel):
|
| 24 |
+
model: str
|
| 25 |
+
messages: List[OpenAIChatMessage]
|
| 26 |
+
stream: bool = False
|
| 27 |
+
temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
|
| 28 |
+
top_p: Optional[float] = Field(None, ge=0.0, le=1.0)
|
| 29 |
+
max_tokens: Optional[int] = Field(None, ge=1)
|
| 30 |
+
stop: Optional[Union[str, List[str]]] = None
|
| 31 |
+
frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0)
|
| 32 |
+
presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0)
|
| 33 |
+
n: Optional[int] = Field(1, ge=1, le=128)
|
| 34 |
+
seed: Optional[int] = None
|
| 35 |
+
response_format: Optional[Dict[str, Any]] = None
|
| 36 |
+
top_k: Optional[int] = Field(None, ge=1)
|
| 37 |
+
enable_anti_truncation: Optional[bool] = False
|
| 38 |
+
|
| 39 |
+
class Config:
|
| 40 |
+
extra = "allow" # Allow additional fields not explicitly defined
|
| 41 |
+
|
| 42 |
+
# 通用的聊天完成请求模型(兼容OpenAI和其他格式)
|
| 43 |
+
ChatCompletionRequest = OpenAIChatCompletionRequest
|
| 44 |
+
|
| 45 |
+
class OpenAIChatCompletionChoice(BaseModel):
|
| 46 |
+
index: int
|
| 47 |
+
message: OpenAIChatMessage
|
| 48 |
+
finish_reason: Optional[str] = None
|
| 49 |
+
logprobs: Optional[Dict[str, Any]] = None
|
| 50 |
+
|
| 51 |
+
class OpenAIChatCompletionResponse(BaseModel):
|
| 52 |
+
id: str
|
| 53 |
+
object: str = "chat.completion"
|
| 54 |
+
created: int
|
| 55 |
+
model: str
|
| 56 |
+
choices: List[OpenAIChatCompletionChoice]
|
| 57 |
+
usage: Optional[Dict[str, int]] = None
|
| 58 |
+
system_fingerprint: Optional[str] = None
|
| 59 |
+
|
| 60 |
+
class OpenAIDelta(BaseModel):
|
| 61 |
+
role: Optional[str] = None
|
| 62 |
+
content: Optional[str] = None
|
| 63 |
+
reasoning_content: Optional[str] = None
|
| 64 |
+
|
| 65 |
+
class OpenAIChatCompletionStreamChoice(BaseModel):
|
| 66 |
+
index: int
|
| 67 |
+
delta: OpenAIDelta
|
| 68 |
+
finish_reason: Optional[str] = None
|
| 69 |
+
logprobs: Optional[Dict[str, Any]] = None
|
| 70 |
+
|
| 71 |
+
class OpenAIChatCompletionStreamResponse(BaseModel):
|
| 72 |
+
id: str
|
| 73 |
+
object: str = "chat.completion.chunk"
|
| 74 |
+
created: int
|
| 75 |
+
model: str
|
| 76 |
+
choices: List[OpenAIChatCompletionStreamChoice]
|
| 77 |
+
system_fingerprint: Optional[str] = None
|
| 78 |
+
|
| 79 |
+
# Gemini Models
|
| 80 |
+
class GeminiPart(BaseModel):
|
| 81 |
+
text: Optional[str] = None
|
| 82 |
+
inlineData: Optional[Dict[str, Any]] = None
|
| 83 |
+
fileData: Optional[Dict[str, Any]] = None
|
| 84 |
+
thought: Optional[bool] = False
|
| 85 |
+
|
| 86 |
+
class GeminiContent(BaseModel):
|
| 87 |
+
role: str
|
| 88 |
+
parts: List[GeminiPart]
|
| 89 |
+
|
| 90 |
+
class GeminiSystemInstruction(BaseModel):
|
| 91 |
+
parts: List[GeminiPart]
|
| 92 |
+
|
| 93 |
+
class GeminiGenerationConfig(BaseModel):
|
| 94 |
+
temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
|
| 95 |
+
topP: Optional[float] = Field(None, ge=0.0, le=1.0)
|
| 96 |
+
topK: Optional[int] = Field(None, ge=1)
|
| 97 |
+
maxOutputTokens: Optional[int] = Field(None, ge=1)
|
| 98 |
+
stopSequences: Optional[List[str]] = None
|
| 99 |
+
responseMimeType: Optional[str] = None
|
| 100 |
+
responseSchema: Optional[Dict[str, Any]] = None
|
| 101 |
+
candidateCount: Optional[int] = Field(None, ge=1, le=8)
|
| 102 |
+
seed: Optional[int] = None
|
| 103 |
+
frequencyPenalty: Optional[float] = Field(None, ge=-2.0, le=2.0)
|
| 104 |
+
presencePenalty: Optional[float] = Field(None, ge=-2.0, le=2.0)
|
| 105 |
+
thinkingConfig: Optional[Dict[str, Any]] = None
|
| 106 |
+
|
| 107 |
+
class GeminiSafetySetting(BaseModel):
|
| 108 |
+
category: str
|
| 109 |
+
threshold: str
|
| 110 |
+
|
| 111 |
+
class GeminiRequest(BaseModel):
|
| 112 |
+
contents: List[GeminiContent]
|
| 113 |
+
systemInstruction: Optional[GeminiSystemInstruction] = None
|
| 114 |
+
generationConfig: Optional[GeminiGenerationConfig] = None
|
| 115 |
+
safetySettings: Optional[List[GeminiSafetySetting]] = None
|
| 116 |
+
tools: Optional[List[Dict[str, Any]]] = None
|
| 117 |
+
toolConfig: Optional[Dict[str, Any]] = None
|
| 118 |
+
cachedContent: Optional[str] = None
|
| 119 |
+
enable_anti_truncation: Optional[bool] = False
|
| 120 |
+
|
| 121 |
+
class Config:
|
| 122 |
+
extra = "allow" # 允许透传未定义的字段
|
| 123 |
+
|
| 124 |
+
class GeminiCandidate(BaseModel):
|
| 125 |
+
content: GeminiContent
|
| 126 |
+
finishReason: Optional[str] = None
|
| 127 |
+
index: int = 0
|
| 128 |
+
safetyRatings: Optional[List[Dict[str, Any]]] = None
|
| 129 |
+
citationMetadata: Optional[Dict[str, Any]] = None
|
| 130 |
+
tokenCount: Optional[int] = None
|
| 131 |
+
|
| 132 |
+
class GeminiUsageMetadata(BaseModel):
|
| 133 |
+
promptTokenCount: Optional[int] = None
|
| 134 |
+
candidatesTokenCount: Optional[int] = None
|
| 135 |
+
totalTokenCount: Optional[int] = None
|
| 136 |
+
|
| 137 |
+
class GeminiResponse(BaseModel):
|
| 138 |
+
candidates: List[GeminiCandidate]
|
| 139 |
+
usageMetadata: Optional[GeminiUsageMetadata] = None
|
| 140 |
+
modelVersion: Optional[str] = None
|
| 141 |
+
|
| 142 |
+
# Error Models
|
| 143 |
+
class APIError(BaseModel):
|
| 144 |
+
message: str
|
| 145 |
+
type: str = "api_error"
|
| 146 |
+
code: Optional[int] = None
|
| 147 |
+
|
| 148 |
+
class ErrorResponse(BaseModel):
|
| 149 |
+
error: APIError
|
| 150 |
+
|
| 151 |
+
# Control Panel Models
|
| 152 |
+
class SystemStatus(BaseModel):
|
| 153 |
+
status: str
|
| 154 |
+
timestamp: str
|
| 155 |
+
credentials: Dict[str, int]
|
| 156 |
+
config: Dict[str, Any]
|
| 157 |
+
current_credential: str
|
| 158 |
+
|
| 159 |
+
class CredentialInfo(BaseModel):
|
| 160 |
+
filename: str
|
| 161 |
+
project_id: Optional[str] = None
|
| 162 |
+
status: Dict[str, Any]
|
| 163 |
+
size: Optional[int] = None
|
| 164 |
+
modified_time: Optional[str] = None
|
| 165 |
+
error: Optional[str] = None
|
| 166 |
+
|
| 167 |
+
class LogEntry(BaseModel):
|
| 168 |
+
timestamp: str
|
| 169 |
+
level: str
|
| 170 |
+
message: str
|
| 171 |
+
module: Optional[str] = None
|
| 172 |
+
|
| 173 |
+
class ConfigValue(BaseModel):
|
| 174 |
+
key: str
|
| 175 |
+
value: Any
|
| 176 |
+
env_locked: bool = False
|
| 177 |
+
description: Optional[str] = None
|
| 178 |
+
|
| 179 |
+
# Authentication Models
|
| 180 |
+
class AuthRequest(BaseModel):
|
| 181 |
+
project_id: Optional[str] = None
|
| 182 |
+
user_session: Optional[str] = None
|
| 183 |
+
|
| 184 |
+
class AuthResponse(BaseModel):
|
| 185 |
+
success: bool
|
| 186 |
+
auth_url: Optional[str] = None
|
| 187 |
+
state: Optional[str] = None
|
| 188 |
+
error: Optional[str] = None
|
| 189 |
+
credentials: Optional[Dict[str, Any]] = None
|
| 190 |
+
file_path: Optional[str] = None
|
| 191 |
+
requires_manual_project_id: Optional[bool] = None
|
| 192 |
+
requires_project_selection: Optional[bool] = None
|
| 193 |
+
available_projects: Optional[List[Dict[str, str]]] = None
|
| 194 |
+
|
| 195 |
+
class CredentialStatus(BaseModel):
|
| 196 |
+
disabled: bool = False
|
| 197 |
+
error_codes: List[int] = []
|
| 198 |
+
last_success: Optional[str] = None
|
src/openai_router.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenAI Router - Handles OpenAI format API requests
|
| 3 |
+
处理OpenAI格式请求的路由模块
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
import uuid
|
| 8 |
+
import asyncio
|
| 9 |
+
from contextlib import asynccontextmanager
|
| 10 |
+
|
| 11 |
+
from fastapi import APIRouter, HTTPException, Depends, Request, status
|
| 12 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 13 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 14 |
+
|
| 15 |
+
from config import get_available_models, is_fake_streaming_model, is_anti_truncation_model, get_base_model_from_feature_model, get_anti_truncation_max_attempts
|
| 16 |
+
from log import log
|
| 17 |
+
from .anti_truncation import apply_anti_truncation_to_stream
|
| 18 |
+
from .credential_manager import CredentialManager
|
| 19 |
+
from .google_chat_api import send_gemini_request
|
| 20 |
+
from .models import ChatCompletionRequest, ModelList, Model
|
| 21 |
+
from .task_manager import create_managed_task
|
| 22 |
+
from .openai_transfer import openai_request_to_gemini_payload, gemini_response_to_openai, gemini_stream_chunk_to_openai
|
| 23 |
+
|
| 24 |
+
# 创建路由器
|
| 25 |
+
router = APIRouter()
|
| 26 |
+
security = HTTPBearer()
|
| 27 |
+
|
| 28 |
+
# 全局凭证管理器实例
|
| 29 |
+
credential_manager = None
|
| 30 |
+
|
| 31 |
+
@asynccontextmanager
|
| 32 |
+
async def get_credential_manager():
|
| 33 |
+
"""获取全局凭证管理器实例"""
|
| 34 |
+
global credential_manager
|
| 35 |
+
if not credential_manager:
|
| 36 |
+
credential_manager = CredentialManager()
|
| 37 |
+
await credential_manager.initialize()
|
| 38 |
+
yield credential_manager
|
| 39 |
+
|
| 40 |
+
async def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
|
| 41 |
+
"""验证用户密码"""
|
| 42 |
+
from config import get_api_password
|
| 43 |
+
password = await get_api_password()
|
| 44 |
+
token = credentials.credentials
|
| 45 |
+
if token != password:
|
| 46 |
+
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="密码错误")
|
| 47 |
+
return token
|
| 48 |
+
|
| 49 |
+
@router.get("/v1/models", response_model=ModelList)
|
| 50 |
+
async def list_models():
|
| 51 |
+
"""返回OpenAI格式的模型列表"""
|
| 52 |
+
models = get_available_models("openai")
|
| 53 |
+
return ModelList(data=[Model(id=m) for m in models])
|
| 54 |
+
|
| 55 |
+
@router.post("/v1/chat/completions")
|
| 56 |
+
async def chat_completions(
|
| 57 |
+
request: Request,
|
| 58 |
+
token: str = Depends(authenticate)
|
| 59 |
+
):
|
| 60 |
+
"""处理OpenAI格式的聊天完成请求"""
|
| 61 |
+
|
| 62 |
+
# 获取原始请求数据
|
| 63 |
+
try:
|
| 64 |
+
raw_data = await request.json()
|
| 65 |
+
except Exception as e:
|
| 66 |
+
log.error(f"Failed to parse JSON request: {e}")
|
| 67 |
+
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
|
| 68 |
+
|
| 69 |
+
# 创建请求对象
|
| 70 |
+
try:
|
| 71 |
+
request_data = ChatCompletionRequest(**raw_data)
|
| 72 |
+
except Exception as e:
|
| 73 |
+
log.error(f"Request validation failed: {e}")
|
| 74 |
+
raise HTTPException(status_code=400, detail=f"Request validation error: {str(e)}")
|
| 75 |
+
|
| 76 |
+
# 健康检查
|
| 77 |
+
if (len(request_data.messages) == 1 and
|
| 78 |
+
getattr(request_data.messages[0], "role", None) == "user" and
|
| 79 |
+
getattr(request_data.messages[0], "content", None) == "Hi"):
|
| 80 |
+
return JSONResponse(content={
|
| 81 |
+
"choices": [{"message": {"role": "assistant", "content": "gcli2api正常工作中"}}]
|
| 82 |
+
})
|
| 83 |
+
|
| 84 |
+
# 限制max_tokens
|
| 85 |
+
if getattr(request_data, "max_tokens", None) is not None and request_data.max_tokens > 65535:
|
| 86 |
+
request_data.max_tokens = 65535
|
| 87 |
+
|
| 88 |
+
# 覆写 top_k 为 64
|
| 89 |
+
setattr(request_data, "top_k", 64)
|
| 90 |
+
|
| 91 |
+
# 过滤空消息
|
| 92 |
+
filtered_messages = []
|
| 93 |
+
for m in request_data.messages:
|
| 94 |
+
content = getattr(m, "content", None)
|
| 95 |
+
if content:
|
| 96 |
+
if isinstance(content, str) and content.strip():
|
| 97 |
+
filtered_messages.append(m)
|
| 98 |
+
elif isinstance(content, list) and len(content) > 0:
|
| 99 |
+
has_valid_content = False
|
| 100 |
+
for part in content:
|
| 101 |
+
if isinstance(part, dict):
|
| 102 |
+
if part.get("type") == "text" and part.get("text", "").strip():
|
| 103 |
+
has_valid_content = True
|
| 104 |
+
break
|
| 105 |
+
elif part.get("type") == "image_url" and part.get("image_url", {}).get("url"):
|
| 106 |
+
has_valid_content = True
|
| 107 |
+
break
|
| 108 |
+
if has_valid_content:
|
| 109 |
+
filtered_messages.append(m)
|
| 110 |
+
|
| 111 |
+
request_data.messages = filtered_messages
|
| 112 |
+
|
| 113 |
+
# 处理模型名称和功能检测
|
| 114 |
+
model = request_data.model
|
| 115 |
+
use_fake_streaming = is_fake_streaming_model(model)
|
| 116 |
+
use_anti_truncation = is_anti_truncation_model(model)
|
| 117 |
+
|
| 118 |
+
# 获取基础模型名
|
| 119 |
+
real_model = get_base_model_from_feature_model(model)
|
| 120 |
+
request_data.model = real_model
|
| 121 |
+
|
| 122 |
+
# 获取凭证管理器
|
| 123 |
+
from src.credential_manager import get_credential_manager
|
| 124 |
+
cred_mgr = await get_credential_manager()
|
| 125 |
+
|
| 126 |
+
# 获取有效凭证
|
| 127 |
+
credential_result = await cred_mgr.get_valid_credential()
|
| 128 |
+
if not credential_result:
|
| 129 |
+
log.error("当前无可用凭证,请去控制台获取")
|
| 130 |
+
raise HTTPException(status_code=500, detail="当前无可用凭证,请去控制台获取")
|
| 131 |
+
|
| 132 |
+
current_file = credential_result
|
| 133 |
+
log.debug(f"Using credential: {current_file}")
|
| 134 |
+
|
| 135 |
+
# 增加调用计数
|
| 136 |
+
cred_mgr.increment_call_count()
|
| 137 |
+
|
| 138 |
+
# 转换为Gemini API payload格式
|
| 139 |
+
try:
|
| 140 |
+
api_payload = await openai_request_to_gemini_payload(request_data)
|
| 141 |
+
except Exception as e:
|
| 142 |
+
log.error(f"OpenAI to Gemini conversion failed: {e}")
|
| 143 |
+
raise HTTPException(status_code=500, detail="Request conversion failed")
|
| 144 |
+
|
| 145 |
+
# 处理假流式
|
| 146 |
+
if use_fake_streaming and getattr(request_data, "stream", False):
|
| 147 |
+
request_data.stream = False
|
| 148 |
+
return await fake_stream_response(api_payload, cred_mgr)
|
| 149 |
+
|
| 150 |
+
# 处理抗截断 (仅流式传输时有效)
|
| 151 |
+
is_streaming = getattr(request_data, "stream", False)
|
| 152 |
+
if use_anti_truncation and is_streaming:
|
| 153 |
+
log.info("启用流式抗截断功能")
|
| 154 |
+
max_attempts = await get_anti_truncation_max_attempts()
|
| 155 |
+
|
| 156 |
+
# 使用流式抗截断处理器
|
| 157 |
+
gemini_response = await apply_anti_truncation_to_stream(
|
| 158 |
+
lambda api_payload: send_gemini_request(api_payload, is_streaming, cred_mgr),
|
| 159 |
+
api_payload,
|
| 160 |
+
max_attempts
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
return await convert_streaming_response(gemini_response, model)
|
| 164 |
+
elif use_anti_truncation and not is_streaming:
|
| 165 |
+
log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置")
|
| 166 |
+
|
| 167 |
+
# 发送请求(429重试已在google_api_client中处理)
|
| 168 |
+
is_streaming = getattr(request_data, "stream", False)
|
| 169 |
+
log.debug(f"Sending request: streaming={is_streaming}, model={real_model}")
|
| 170 |
+
response = await send_gemini_request(api_payload, is_streaming, cred_mgr)
|
| 171 |
+
|
| 172 |
+
# 如果是流式响应,直接返回
|
| 173 |
+
if is_streaming:
|
| 174 |
+
return await convert_streaming_response(response, model)
|
| 175 |
+
|
| 176 |
+
# 转换非流式响应
|
| 177 |
+
try:
|
| 178 |
+
log.debug(f"Processing response: type={type(response)}")
|
| 179 |
+
if hasattr(response, 'body'):
|
| 180 |
+
response_data = json.loads(response.body.decode() if isinstance(response.body, bytes) else response.body)
|
| 181 |
+
else:
|
| 182 |
+
response_data = json.loads(response.content.decode() if isinstance(response.content, bytes) else response.content)
|
| 183 |
+
|
| 184 |
+
log.debug(f"Response data keys: {list(response_data.keys()) if isinstance(response_data, dict) else 'Not a dict'}")
|
| 185 |
+
openai_response = gemini_response_to_openai(response_data, model)
|
| 186 |
+
log.debug(f"Converted OpenAI response keys: {list(openai_response.keys()) if isinstance(openai_response, dict) else 'Not a dict'}")
|
| 187 |
+
return JSONResponse(content=openai_response)
|
| 188 |
+
|
| 189 |
+
except Exception as e:
|
| 190 |
+
log.error(f"Response conversion failed: {e}")
|
| 191 |
+
log.error(f"Response object: {response}")
|
| 192 |
+
raise HTTPException(status_code=500, detail="Response conversion failed")
|
| 193 |
+
|
| 194 |
+
async def fake_stream_response(api_payload: dict, cred_mgr: CredentialManager) -> StreamingResponse:
|
| 195 |
+
"""处理假流式响应"""
|
| 196 |
+
async def stream_generator():
|
| 197 |
+
try:
|
| 198 |
+
# 发送心跳
|
| 199 |
+
heartbeat = {
|
| 200 |
+
"choices": [{
|
| 201 |
+
"index": 0,
|
| 202 |
+
"delta": {"role": "assistant", "content": ""},
|
| 203 |
+
"finish_reason": None
|
| 204 |
+
}]
|
| 205 |
+
}
|
| 206 |
+
yield f"data: {json.dumps(heartbeat)}\n\n".encode()
|
| 207 |
+
|
| 208 |
+
# 异步发送实际请求
|
| 209 |
+
async def get_response():
|
| 210 |
+
return await send_gemini_request(api_payload, False, cred_mgr)
|
| 211 |
+
|
| 212 |
+
# 创建请求任务
|
| 213 |
+
response_task = create_managed_task(get_response(), name="openai_fake_stream_request")
|
| 214 |
+
|
| 215 |
+
try:
|
| 216 |
+
# 每3秒发送一次心跳,直到收到响应
|
| 217 |
+
while not response_task.done():
|
| 218 |
+
await asyncio.sleep(3.0)
|
| 219 |
+
if not response_task.done():
|
| 220 |
+
yield f"data: {json.dumps(heartbeat)}\n\n".encode()
|
| 221 |
+
|
| 222 |
+
# 获取响应结果
|
| 223 |
+
response = await response_task
|
| 224 |
+
|
| 225 |
+
except asyncio.CancelledError:
|
| 226 |
+
# 取消任务并传播取消
|
| 227 |
+
response_task.cancel()
|
| 228 |
+
try:
|
| 229 |
+
await response_task
|
| 230 |
+
except asyncio.CancelledError:
|
| 231 |
+
pass
|
| 232 |
+
raise
|
| 233 |
+
except Exception as e:
|
| 234 |
+
# 取消任务并处理其他异常
|
| 235 |
+
response_task.cancel()
|
| 236 |
+
try:
|
| 237 |
+
await response_task
|
| 238 |
+
except asyncio.CancelledError:
|
| 239 |
+
pass
|
| 240 |
+
log.error(f"Fake streaming request failed: {e}")
|
| 241 |
+
raise
|
| 242 |
+
|
| 243 |
+
# 发送实际请求
|
| 244 |
+
# response 已在上面获取
|
| 245 |
+
|
| 246 |
+
# 处理结果
|
| 247 |
+
if hasattr(response, 'body'):
|
| 248 |
+
body_str = response.body.decode() if isinstance(response.body, bytes) else str(response.body)
|
| 249 |
+
elif hasattr(response, 'content'):
|
| 250 |
+
body_str = response.content.decode() if isinstance(response.content, bytes) else str(response.content)
|
| 251 |
+
else:
|
| 252 |
+
body_str = str(response)
|
| 253 |
+
|
| 254 |
+
try:
|
| 255 |
+
response_data = json.loads(body_str)
|
| 256 |
+
log.debug(f"Fake stream response data: {response_data}")
|
| 257 |
+
|
| 258 |
+
# 从Gemini响应中提取内容,使用思维链分离逻辑
|
| 259 |
+
content = ""
|
| 260 |
+
reasoning_content = ""
|
| 261 |
+
if "candidates" in response_data and response_data["candidates"]:
|
| 262 |
+
# Gemini格式响应 - 使用思维链分离
|
| 263 |
+
from .openai_transfer import _extract_content_and_reasoning
|
| 264 |
+
candidate = response_data["candidates"][0]
|
| 265 |
+
if "content" in candidate and "parts" in candidate["content"]:
|
| 266 |
+
parts = candidate["content"]["parts"]
|
| 267 |
+
content, reasoning_content = _extract_content_and_reasoning(parts)
|
| 268 |
+
elif "choices" in response_data and response_data["choices"]:
|
| 269 |
+
# OpenAI格式响应
|
| 270 |
+
content = response_data["choices"][0].get("message", {}).get("content", "")
|
| 271 |
+
|
| 272 |
+
log.debug(f"Extracted content: {content}")
|
| 273 |
+
log.debug(f"Extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...")
|
| 274 |
+
|
| 275 |
+
# 如果没有正常内容但有思维内容,给出警告
|
| 276 |
+
if not content and reasoning_content:
|
| 277 |
+
log.warning(f"Fake stream response contains only thinking content: {reasoning_content[:100]}...")
|
| 278 |
+
content = "[模型正在思考中,请稍后再试或重新提问]"
|
| 279 |
+
|
| 280 |
+
if content:
|
| 281 |
+
# 构建响应块,包括思维内容(如果有)
|
| 282 |
+
delta = {"role": "assistant", "content": content}
|
| 283 |
+
if reasoning_content:
|
| 284 |
+
delta["reasoning_content"] = reasoning_content
|
| 285 |
+
|
| 286 |
+
content_chunk = {
|
| 287 |
+
"choices": [{
|
| 288 |
+
"index": 0,
|
| 289 |
+
"delta": delta,
|
| 290 |
+
"finish_reason": "stop"
|
| 291 |
+
}]
|
| 292 |
+
}
|
| 293 |
+
yield f"data: {json.dumps(content_chunk)}\n\n".encode()
|
| 294 |
+
else:
|
| 295 |
+
log.warning(f"No content found in response: {response_data}")
|
| 296 |
+
# 如果完全没有内容,提供默认回复
|
| 297 |
+
error_chunk = {
|
| 298 |
+
"choices": [{
|
| 299 |
+
"index": 0,
|
| 300 |
+
"delta": {"role": "assistant", "content": "[响应为空,请重新尝试]"},
|
| 301 |
+
"finish_reason": "stop"
|
| 302 |
+
}]
|
| 303 |
+
}
|
| 304 |
+
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
|
| 305 |
+
except json.JSONDecodeError:
|
| 306 |
+
error_chunk = {
|
| 307 |
+
"choices": [{
|
| 308 |
+
"index": 0,
|
| 309 |
+
"delta": {"role": "assistant", "content": body_str},
|
| 310 |
+
"finish_reason": "stop"
|
| 311 |
+
}]
|
| 312 |
+
}
|
| 313 |
+
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
|
| 314 |
+
|
| 315 |
+
yield "data: [DONE]\n\n".encode()
|
| 316 |
+
|
| 317 |
+
except Exception as e:
|
| 318 |
+
log.error(f"Fake streaming error: {e}")
|
| 319 |
+
error_chunk = {
|
| 320 |
+
"choices": [{
|
| 321 |
+
"index": 0,
|
| 322 |
+
"delta": {"role": "assistant", "content": f"Error: {str(e)}"},
|
| 323 |
+
"finish_reason": "stop"
|
| 324 |
+
}]
|
| 325 |
+
}
|
| 326 |
+
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
|
| 327 |
+
yield "data: [DONE]\n\n".encode()
|
| 328 |
+
|
| 329 |
+
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
| 330 |
+
|
| 331 |
+
async def convert_streaming_response(gemini_response, model: str) -> StreamingResponse:
|
| 332 |
+
"""转换流式响应为OpenAI格式"""
|
| 333 |
+
response_id = str(uuid.uuid4())
|
| 334 |
+
|
| 335 |
+
async def openai_stream_generator():
|
| 336 |
+
try:
|
| 337 |
+
# 处理不同类型的响应对象
|
| 338 |
+
if hasattr(gemini_response, 'body_iterator'):
|
| 339 |
+
# FastAPI StreamingResponse
|
| 340 |
+
async for chunk in gemini_response.body_iterator:
|
| 341 |
+
if not chunk:
|
| 342 |
+
continue
|
| 343 |
+
|
| 344 |
+
# 处理不同数据类型的startswith问题
|
| 345 |
+
if isinstance(chunk, bytes):
|
| 346 |
+
if not chunk.startswith(b'data: '):
|
| 347 |
+
continue
|
| 348 |
+
payload = chunk[len(b'data: '):]
|
| 349 |
+
else:
|
| 350 |
+
chunk_str = str(chunk)
|
| 351 |
+
if not chunk_str.startswith('data: '):
|
| 352 |
+
continue
|
| 353 |
+
payload = chunk_str[len('data: '):].encode()
|
| 354 |
+
try:
|
| 355 |
+
gemini_chunk = json.loads(payload.decode())
|
| 356 |
+
openai_chunk = gemini_stream_chunk_to_openai(gemini_chunk, model, response_id)
|
| 357 |
+
yield f"data: {json.dumps(openai_chunk, separators=(',',':'))}\n\n".encode()
|
| 358 |
+
except json.JSONDecodeError:
|
| 359 |
+
continue
|
| 360 |
+
else:
|
| 361 |
+
# 其他类型的响应,尝试直接处理
|
| 362 |
+
log.warning(f"Unexpected response type: {type(gemini_response)}")
|
| 363 |
+
error_chunk = {
|
| 364 |
+
"id": response_id,
|
| 365 |
+
"object": "chat.completion.chunk",
|
| 366 |
+
"created": int(time.time()),
|
| 367 |
+
"model": model,
|
| 368 |
+
"choices": [{
|
| 369 |
+
"index": 0,
|
| 370 |
+
"delta": {"role": "assistant", "content": "Response type error"},
|
| 371 |
+
"finish_reason": "stop"
|
| 372 |
+
}]
|
| 373 |
+
}
|
| 374 |
+
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
|
| 375 |
+
|
| 376 |
+
# 发送结束标记
|
| 377 |
+
yield "data: [DONE]\n\n".encode()
|
| 378 |
+
|
| 379 |
+
except Exception as e:
|
| 380 |
+
log.error(f"Stream conversion error: {e}")
|
| 381 |
+
error_chunk = {
|
| 382 |
+
"id": response_id,
|
| 383 |
+
"object": "chat.completion.chunk",
|
| 384 |
+
"created": int(time.time()),
|
| 385 |
+
"model": model,
|
| 386 |
+
"choices": [{
|
| 387 |
+
"index": 0,
|
| 388 |
+
"delta": {"role": "assistant", "content": f"Stream error: {str(e)}"},
|
| 389 |
+
"finish_reason": "stop"
|
| 390 |
+
}]
|
| 391 |
+
}
|
| 392 |
+
yield f"data: {json.dumps(error_chunk)}\n\n".encode()
|
| 393 |
+
yield "data: [DONE]\n\n".encode()
|
| 394 |
+
|
| 395 |
+
return StreamingResponse(openai_stream_generator(), media_type="text/event-stream")
|
src/openai_transfer.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenAI Transfer Module - Handles conversion between OpenAI and Gemini API formats
|
| 3 |
+
被openai-router调用,负责OpenAI格式与Gemini格式的双向转换
|
| 4 |
+
"""
|
| 5 |
+
import time
|
| 6 |
+
import uuid
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
|
| 9 |
+
from config import (
|
| 10 |
+
DEFAULT_SAFETY_SETTINGS,
|
| 11 |
+
get_base_model_name,
|
| 12 |
+
get_thinking_budget,
|
| 13 |
+
is_search_model,
|
| 14 |
+
should_include_thoughts,
|
| 15 |
+
get_compatibility_mode_enabled
|
| 16 |
+
)
|
| 17 |
+
from log import log
|
| 18 |
+
from .models import ChatCompletionRequest
|
| 19 |
+
|
| 20 |
+
async def openai_request_to_gemini_payload(openai_request: ChatCompletionRequest) -> Dict[str, Any]:
|
| 21 |
+
"""
|
| 22 |
+
将OpenAI聊天完成请求直接转换为完整的Gemini API payload格式
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
openai_request: OpenAI格式请求对象
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
完整的Gemini API payload,包含model和request字段
|
| 29 |
+
"""
|
| 30 |
+
contents = []
|
| 31 |
+
system_instructions = []
|
| 32 |
+
|
| 33 |
+
# 检查是否启用兼容性模式
|
| 34 |
+
compatibility_mode = await get_compatibility_mode_enabled()
|
| 35 |
+
|
| 36 |
+
# 处理对话中的每条消息
|
| 37 |
+
# 第一阶段:收集连续的system消息到system_instruction中(除非在兼容性模式下)
|
| 38 |
+
collecting_system = True if not compatibility_mode else False
|
| 39 |
+
|
| 40 |
+
for message in openai_request.messages:
|
| 41 |
+
role = message.role
|
| 42 |
+
|
| 43 |
+
# 处理系统消息
|
| 44 |
+
if role == "system":
|
| 45 |
+
if compatibility_mode:
|
| 46 |
+
# 兼容性模式:所有system消息转换为user消息
|
| 47 |
+
role = "user"
|
| 48 |
+
elif collecting_system:
|
| 49 |
+
# 正常模式:仍在收集连续的system消息
|
| 50 |
+
if isinstance(message.content, str):
|
| 51 |
+
system_instructions.append(message.content)
|
| 52 |
+
elif isinstance(message.content, list):
|
| 53 |
+
# 处理列表格式的系统消息
|
| 54 |
+
for part in message.content:
|
| 55 |
+
if part.get("type") == "text" and part.get("text"):
|
| 56 |
+
system_instructions.append(part["text"])
|
| 57 |
+
continue
|
| 58 |
+
else:
|
| 59 |
+
# 正常模式:后续的system消息转换为user消息
|
| 60 |
+
role = "user"
|
| 61 |
+
else:
|
| 62 |
+
# 遇到非system消息,停止收集system消息
|
| 63 |
+
collecting_system = False
|
| 64 |
+
|
| 65 |
+
# 将OpenAI角色映射到Gemini角色
|
| 66 |
+
if role == "assistant":
|
| 67 |
+
role = "model"
|
| 68 |
+
|
| 69 |
+
# 处理普通内容
|
| 70 |
+
if isinstance(message.content, list):
|
| 71 |
+
parts = []
|
| 72 |
+
for part in message.content:
|
| 73 |
+
if part.get("type") == "text":
|
| 74 |
+
parts.append({"text": part.get("text", "")})
|
| 75 |
+
elif part.get("type") == "image_url":
|
| 76 |
+
image_url = part.get("image_url", {}).get("url")
|
| 77 |
+
if image_url:
|
| 78 |
+
# 解析数据URI: "data:image/jpeg;base64,{base64_image}"
|
| 79 |
+
try:
|
| 80 |
+
mime_type, base64_data = image_url.split(";")
|
| 81 |
+
_, mime_type = mime_type.split(":")
|
| 82 |
+
_, base64_data = base64_data.split(",")
|
| 83 |
+
parts.append({
|
| 84 |
+
"inlineData": {
|
| 85 |
+
"mimeType": mime_type,
|
| 86 |
+
"data": base64_data
|
| 87 |
+
}
|
| 88 |
+
})
|
| 89 |
+
except ValueError:
|
| 90 |
+
continue
|
| 91 |
+
contents.append({"role": role, "parts": parts})
|
| 92 |
+
# log.debug(f"Added message to contents: role={role}, parts={parts}")
|
| 93 |
+
elif message.content:
|
| 94 |
+
# 简单文本内容
|
| 95 |
+
contents.append({"role": role, "parts": [{"text": message.content}]})
|
| 96 |
+
# log.debug(f"Added message to contents: role={role}, content={message.content}")
|
| 97 |
+
|
| 98 |
+
# 将OpenAI生成参数映射到Gemini格式
|
| 99 |
+
generation_config = {}
|
| 100 |
+
if openai_request.temperature is not None:
|
| 101 |
+
generation_config["temperature"] = openai_request.temperature
|
| 102 |
+
if openai_request.top_p is not None:
|
| 103 |
+
generation_config["topP"] = openai_request.top_p
|
| 104 |
+
if openai_request.max_tokens is not None:
|
| 105 |
+
generation_config["maxOutputTokens"] = openai_request.max_tokens
|
| 106 |
+
if openai_request.stop is not None:
|
| 107 |
+
# Gemini支持停止序列
|
| 108 |
+
if isinstance(openai_request.stop, str):
|
| 109 |
+
generation_config["stopSequences"] = [openai_request.stop]
|
| 110 |
+
elif isinstance(openai_request.stop, list):
|
| 111 |
+
generation_config["stopSequences"] = openai_request.stop
|
| 112 |
+
if openai_request.frequency_penalty is not None:
|
| 113 |
+
generation_config["frequencyPenalty"] = openai_request.frequency_penalty
|
| 114 |
+
if openai_request.presence_penalty is not None:
|
| 115 |
+
generation_config["presencePenalty"] = openai_request.presence_penalty
|
| 116 |
+
if openai_request.n is not None:
|
| 117 |
+
generation_config["candidateCount"] = openai_request.n
|
| 118 |
+
if openai_request.seed is not None:
|
| 119 |
+
generation_config["seed"] = openai_request.seed
|
| 120 |
+
if openai_request.response_format is not None:
|
| 121 |
+
# 处理JSON模式
|
| 122 |
+
if openai_request.response_format.get("type") == "json_object":
|
| 123 |
+
generation_config["responseMimeType"] = "application/json"
|
| 124 |
+
|
| 125 |
+
# 如果contents为空(只有系统消息的情况),添加一个默认的用户消息以满足Gemini API要求
|
| 126 |
+
if not contents:
|
| 127 |
+
contents.append({"role": "user", "parts": [{"text": "请根据系统指令回答。"}]})
|
| 128 |
+
|
| 129 |
+
# 构建请求数据
|
| 130 |
+
request_data = {
|
| 131 |
+
"contents": contents,
|
| 132 |
+
"generationConfig": generation_config,
|
| 133 |
+
"safetySettings": DEFAULT_SAFETY_SETTINGS,
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
# 如果有系统消息且未启用兼容性模式,添加systemInstruction
|
| 137 |
+
if system_instructions and not compatibility_mode:
|
| 138 |
+
combined_system_instruction = "\n\n".join(system_instructions)
|
| 139 |
+
request_data["systemInstruction"] = {"parts": [{"text": combined_system_instruction}]}
|
| 140 |
+
|
| 141 |
+
log.debug(f"Final request payload contents count: {len(contents)}, system_instruction: {bool(system_instructions and not compatibility_mode)}, compatibility_mode: {compatibility_mode}")
|
| 142 |
+
|
| 143 |
+
# 为thinking模型添加thinking配置
|
| 144 |
+
thinking_budget = get_thinking_budget(openai_request.model)
|
| 145 |
+
if thinking_budget is not None:
|
| 146 |
+
request_data["generationConfig"]["thinkingConfig"] = {
|
| 147 |
+
"thinkingBudget": thinking_budget,
|
| 148 |
+
"includeThoughts": should_include_thoughts(openai_request.model)
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
# 为搜索模型添加Google Search工具
|
| 152 |
+
if is_search_model(openai_request.model):
|
| 153 |
+
request_data["tools"] = [{"googleSearch": {}}]
|
| 154 |
+
|
| 155 |
+
# 移除None值
|
| 156 |
+
request_data = {k: v for k, v in request_data.items() if v is not None}
|
| 157 |
+
|
| 158 |
+
# 返回完整的Gemini API payload格式
|
| 159 |
+
return {
|
| 160 |
+
"model": get_base_model_name(openai_request.model),
|
| 161 |
+
"request": request_data
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
def _extract_content_and_reasoning(parts: list) -> tuple:
|
| 165 |
+
"""从Gemini响应部件中提取内容和推理内容"""
|
| 166 |
+
content = ""
|
| 167 |
+
reasoning_content = ""
|
| 168 |
+
|
| 169 |
+
for part in parts:
|
| 170 |
+
# 处理文本内容
|
| 171 |
+
if part.get("text"):
|
| 172 |
+
# 检查这个部件是否包含thinking tokens
|
| 173 |
+
if part.get("thought", False):
|
| 174 |
+
reasoning_content += part.get("text", "")
|
| 175 |
+
else:
|
| 176 |
+
content += part.get("text", "")
|
| 177 |
+
|
| 178 |
+
return content, reasoning_content
|
| 179 |
+
|
| 180 |
+
def _build_message_with_reasoning(role: str, content: str, reasoning_content: str) -> dict:
|
| 181 |
+
"""构建包含可选推理内容的消息对象"""
|
| 182 |
+
message = {
|
| 183 |
+
"role": role,
|
| 184 |
+
"content": content
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
# 如果有thinking tokens,添加reasoning_content
|
| 188 |
+
if reasoning_content:
|
| 189 |
+
message["reasoning_content"] = reasoning_content
|
| 190 |
+
|
| 191 |
+
return message
|
| 192 |
+
|
| 193 |
+
def gemini_response_to_openai(gemini_response: Dict[str, Any], model: str) -> Dict[str, Any]:
|
| 194 |
+
"""
|
| 195 |
+
将Gemini API响应转换为OpenAI聊天完成格式
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
gemini_response: 来自Gemini API的响应
|
| 199 |
+
model: 要在响应中包含的模型名称
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
OpenAI聊天完成格式的字典
|
| 203 |
+
"""
|
| 204 |
+
choices = []
|
| 205 |
+
|
| 206 |
+
for candidate in gemini_response.get("candidates", []):
|
| 207 |
+
role = candidate.get("content", {}).get("role", "assistant")
|
| 208 |
+
|
| 209 |
+
# 将Gemini角色映射回OpenAI角色
|
| 210 |
+
if role == "model":
|
| 211 |
+
role = "assistant"
|
| 212 |
+
|
| 213 |
+
# 提取并分离thinking tokens和常规内容
|
| 214 |
+
parts = candidate.get("content", {}).get("parts", [])
|
| 215 |
+
content, reasoning_content = _extract_content_and_reasoning(parts)
|
| 216 |
+
|
| 217 |
+
# 构建消息对象
|
| 218 |
+
message = _build_message_with_reasoning(role, content, reasoning_content)
|
| 219 |
+
|
| 220 |
+
choices.append({
|
| 221 |
+
"index": candidate.get("index", 0),
|
| 222 |
+
"message": message,
|
| 223 |
+
"finish_reason": _map_finish_reason(candidate.get("finishReason")),
|
| 224 |
+
})
|
| 225 |
+
|
| 226 |
+
return {
|
| 227 |
+
"id": str(uuid.uuid4()),
|
| 228 |
+
"object": "chat.completion",
|
| 229 |
+
"created": int(time.time()),
|
| 230 |
+
"model": model,
|
| 231 |
+
"choices": choices,
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
def gemini_stream_chunk_to_openai(gemini_chunk: Dict[str, Any], model: str, response_id: str) -> Dict[str, Any]:
|
| 235 |
+
"""
|
| 236 |
+
将Gemini流式响应块转换为OpenAI流式格式
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
gemini_chunk: 来自Gemini流式响应的单个块
|
| 240 |
+
model: 要在响应中包含的模型名称
|
| 241 |
+
response_id: 此流式响应的一致ID
|
| 242 |
+
|
| 243 |
+
Returns:
|
| 244 |
+
OpenAI流式格式的字典
|
| 245 |
+
"""
|
| 246 |
+
choices = []
|
| 247 |
+
|
| 248 |
+
for candidate in gemini_chunk.get("candidates", []):
|
| 249 |
+
role = candidate.get("content", {}).get("role", "assistant")
|
| 250 |
+
|
| 251 |
+
# 将Gemini角色映射回OpenAI角色
|
| 252 |
+
if role == "model":
|
| 253 |
+
role = "assistant"
|
| 254 |
+
|
| 255 |
+
# 提取并分离thinking tokens和常规内容
|
| 256 |
+
parts = candidate.get("content", {}).get("parts", [])
|
| 257 |
+
content, reasoning_content = _extract_content_and_reasoning(parts)
|
| 258 |
+
|
| 259 |
+
# 构建delta对象
|
| 260 |
+
delta = {}
|
| 261 |
+
if content:
|
| 262 |
+
delta["content"] = content
|
| 263 |
+
if reasoning_content:
|
| 264 |
+
delta["reasoning_content"] = reasoning_content
|
| 265 |
+
|
| 266 |
+
choices.append({
|
| 267 |
+
"index": candidate.get("index", 0),
|
| 268 |
+
"delta": delta,
|
| 269 |
+
"finish_reason": _map_finish_reason(candidate.get("finishReason")),
|
| 270 |
+
})
|
| 271 |
+
|
| 272 |
+
return {
|
| 273 |
+
"id": response_id,
|
| 274 |
+
"object": "chat.completion.chunk",
|
| 275 |
+
"created": int(time.time()),
|
| 276 |
+
"model": model,
|
| 277 |
+
"choices": choices,
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
def _map_finish_reason(gemini_reason: str) -> str:
|
| 281 |
+
"""
|
| 282 |
+
将Gemini结束原因映射到OpenAI结束原因
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
gemini_reason: 来自Gemini API的结束原因
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
OpenAI兼容的结束原因
|
| 289 |
+
"""
|
| 290 |
+
if gemini_reason == "STOP":
|
| 291 |
+
return "stop"
|
| 292 |
+
elif gemini_reason == "MAX_TOKENS":
|
| 293 |
+
return "length"
|
| 294 |
+
elif gemini_reason in ["SAFETY", "RECITATION"]:
|
| 295 |
+
return "content_filter"
|
| 296 |
+
else:
|
| 297 |
+
return None
|
| 298 |
+
|
| 299 |
+
def validate_openai_request(request_data: Dict[str, Any]) -> ChatCompletionRequest:
|
| 300 |
+
"""
|
| 301 |
+
验证并标准化OpenAI请求数据
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
request_data: 原始请求数据字典
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
验证后的ChatCompletionRequest对象
|
| 308 |
+
|
| 309 |
+
Raises:
|
| 310 |
+
ValueError: 当请求数据无效时
|
| 311 |
+
"""
|
| 312 |
+
try:
|
| 313 |
+
return ChatCompletionRequest(**request_data)
|
| 314 |
+
except Exception as e:
|
| 315 |
+
raise ValueError(f"Invalid OpenAI request format: {str(e)}")
|
| 316 |
+
|
| 317 |
+
def normalize_openai_request(request_data: ChatCompletionRequest) -> ChatCompletionRequest:
|
| 318 |
+
"""
|
| 319 |
+
标准化OpenAI请求数据,应用默认值和限制
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
request_data: 原始请求对象
|
| 323 |
+
|
| 324 |
+
Returns:
|
| 325 |
+
标准化后的请求对象
|
| 326 |
+
"""
|
| 327 |
+
# 限制max_tokens
|
| 328 |
+
if getattr(request_data, "max_tokens", None) is not None and request_data.max_tokens > 65535:
|
| 329 |
+
request_data.max_tokens = 65535
|
| 330 |
+
|
| 331 |
+
# 覆写 top_k 为 64
|
| 332 |
+
setattr(request_data, "top_k", 64)
|
| 333 |
+
|
| 334 |
+
# 过滤空消息
|
| 335 |
+
filtered_messages = []
|
| 336 |
+
for m in request_data.messages:
|
| 337 |
+
content = getattr(m, "content", None)
|
| 338 |
+
if content:
|
| 339 |
+
if isinstance(content, str) and content.strip():
|
| 340 |
+
filtered_messages.append(m)
|
| 341 |
+
elif isinstance(content, list) and len(content) > 0:
|
| 342 |
+
has_valid_content = False
|
| 343 |
+
for part in content:
|
| 344 |
+
if isinstance(part, dict):
|
| 345 |
+
if part.get("type") == "text" and part.get("text", "").strip():
|
| 346 |
+
has_valid_content = True
|
| 347 |
+
break
|
| 348 |
+
elif part.get("type") == "image_url" and part.get("image_url", {}).get("url"):
|
| 349 |
+
has_valid_content = True
|
| 350 |
+
break
|
| 351 |
+
if has_valid_content:
|
| 352 |
+
filtered_messages.append(m)
|
| 353 |
+
|
| 354 |
+
request_data.messages = filtered_messages
|
| 355 |
+
|
| 356 |
+
return request_data
|
| 357 |
+
|
| 358 |
+
def is_health_check_request(request_data: ChatCompletionRequest) -> bool:
|
| 359 |
+
"""
|
| 360 |
+
检查是否为健康检查请求
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
request_data: 请求对象
|
| 364 |
+
|
| 365 |
+
Returns:
|
| 366 |
+
是否为健康检查请求
|
| 367 |
+
"""
|
| 368 |
+
return (len(request_data.messages) == 1 and
|
| 369 |
+
getattr(request_data.messages[0], "role", None) == "user" and
|
| 370 |
+
getattr(request_data.messages[0], "content", None) == "Hi")
|
| 371 |
+
|
| 372 |
+
def create_health_check_response() -> Dict[str, Any]:
|
| 373 |
+
"""
|
| 374 |
+
创建健康检查响应
|
| 375 |
+
|
| 376 |
+
Returns:
|
| 377 |
+
健康检查响应字典
|
| 378 |
+
"""
|
| 379 |
+
return {
|
| 380 |
+
"choices": [{
|
| 381 |
+
"message": {
|
| 382 |
+
"role": "assistant",
|
| 383 |
+
"content": "gcli2api正常工作中"
|
| 384 |
+
}
|
| 385 |
+
}]
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
def extract_model_settings(model: str) -> Dict[str, Any]:
|
| 389 |
+
"""
|
| 390 |
+
从模型名称中提取设置信息
|
| 391 |
+
|
| 392 |
+
Args:
|
| 393 |
+
model: 模型名称
|
| 394 |
+
|
| 395 |
+
Returns:
|
| 396 |
+
包含模型设置的字典
|
| 397 |
+
"""
|
| 398 |
+
return {
|
| 399 |
+
"base_model": get_base_model_name(model),
|
| 400 |
+
"use_fake_streaming": model.endswith("-假流式"),
|
| 401 |
+
"thinking_budget": get_thinking_budget(model),
|
| 402 |
+
"include_thoughts": should_include_thoughts(model)
|
| 403 |
+
}
|
src/state_manager.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
统一状态管理器
|
| 3 |
+
"""
|
| 4 |
+
import asyncio
|
| 5 |
+
import os
|
| 6 |
+
from typing import Dict, Any
|
| 7 |
+
from contextlib import asynccontextmanager
|
| 8 |
+
|
| 9 |
+
from config import is_mongodb_mode
|
| 10 |
+
from log import log
|
| 11 |
+
from .storage_adapter import get_storage_adapter
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class StateManager:
|
| 15 |
+
"""
|
| 16 |
+
统一状态管理器
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, state_file_path: str):
|
| 20 |
+
self.state_file_path = state_file_path
|
| 21 |
+
self._lock = asyncio.Lock()
|
| 22 |
+
self._storage_adapter = None
|
| 23 |
+
self._initialized = False
|
| 24 |
+
|
| 25 |
+
# 从文件路径推断存储用途
|
| 26 |
+
self._storage_purpose = self._infer_storage_purpose(state_file_path)
|
| 27 |
+
|
| 28 |
+
def _infer_storage_purpose(self, file_path: str) -> str:
|
| 29 |
+
"""根据文件路径推断存储用途"""
|
| 30 |
+
filename = os.path.basename(file_path)
|
| 31 |
+
|
| 32 |
+
if "creds_state" in filename:
|
| 33 |
+
return "credential_state"
|
| 34 |
+
elif "config" in filename:
|
| 35 |
+
return "config"
|
| 36 |
+
elif "usage" in filename or "stats" in filename:
|
| 37 |
+
return "usage_stats"
|
| 38 |
+
else:
|
| 39 |
+
return "general"
|
| 40 |
+
|
| 41 |
+
async def _ensure_initialized(self):
|
| 42 |
+
"""确保状态管理器已初始化"""
|
| 43 |
+
if not self._initialized:
|
| 44 |
+
self._storage_adapter = await get_storage_adapter()
|
| 45 |
+
self._initialized = True
|
| 46 |
+
|
| 47 |
+
if await is_mongodb_mode():
|
| 48 |
+
log.debug(f"Unified state manager initialized with MongoDB backend for: {self._storage_purpose}")
|
| 49 |
+
else:
|
| 50 |
+
log.debug(f"Unified state manager initialized with file backend for: {self._storage_purpose}")
|
| 51 |
+
|
| 52 |
+
async def _load_state(self) -> Dict[str, Any]:
|
| 53 |
+
"""加载状态数据"""
|
| 54 |
+
await self._ensure_initialized()
|
| 55 |
+
|
| 56 |
+
if self._storage_purpose == "credential_state":
|
| 57 |
+
return await self._storage_adapter.get_all_credential_states()
|
| 58 |
+
elif self._storage_purpose == "config":
|
| 59 |
+
return await self._storage_adapter.get_all_config()
|
| 60 |
+
elif self._storage_purpose == "usage_stats":
|
| 61 |
+
return await self._storage_adapter.get_all_usage_stats()
|
| 62 |
+
else:
|
| 63 |
+
# 对于通用存储,尝试获取配置数据
|
| 64 |
+
return await self._storage_adapter.get_all_config()
|
| 65 |
+
|
| 66 |
+
async def _save_state(self, state: Dict[str, Any]):
|
| 67 |
+
"""保存状态数据"""
|
| 68 |
+
await self._ensure_initialized()
|
| 69 |
+
|
| 70 |
+
# 根据存储用途批量更新数据
|
| 71 |
+
if self._storage_purpose == "credential_state":
|
| 72 |
+
# 批量更新凭证状态
|
| 73 |
+
for filename, file_state in state.items():
|
| 74 |
+
await self._storage_adapter.update_credential_state(filename, file_state)
|
| 75 |
+
elif self._storage_purpose == "config":
|
| 76 |
+
# 批量更新配置
|
| 77 |
+
for key, value in state.items():
|
| 78 |
+
await self._storage_adapter.set_config(key, value)
|
| 79 |
+
elif self._storage_purpose == "usage_stats":
|
| 80 |
+
# 批量更新使用统计
|
| 81 |
+
for filename, stats in state.items():
|
| 82 |
+
await self._storage_adapter.update_usage_stats(filename, stats)
|
| 83 |
+
else:
|
| 84 |
+
# 通用存储,作为配置处理
|
| 85 |
+
for key, value in state.items():
|
| 86 |
+
await self._storage_adapter.set_config(key, value)
|
| 87 |
+
|
| 88 |
+
@asynccontextmanager
|
| 89 |
+
async def transaction(self):
|
| 90 |
+
"""
|
| 91 |
+
事务上下文管理器,兼容原有接口。
|
| 92 |
+
Usage:
|
| 93 |
+
async with state_manager.transaction() as state:
|
| 94 |
+
state['key'] = 'value'
|
| 95 |
+
# State is automatically saved on exit
|
| 96 |
+
"""
|
| 97 |
+
async with self._lock:
|
| 98 |
+
state = await self._load_state()
|
| 99 |
+
try:
|
| 100 |
+
yield state
|
| 101 |
+
await self._save_state(state)
|
| 102 |
+
except Exception:
|
| 103 |
+
# Don't save if there was an error
|
| 104 |
+
raise
|
| 105 |
+
|
| 106 |
+
async def read_file_state(self, filename: str) -> Dict[str, Any]:
|
| 107 |
+
"""读取特定文件的状态,兼容原有接口"""
|
| 108 |
+
await self._ensure_initialized()
|
| 109 |
+
|
| 110 |
+
if self._storage_purpose == "credential_state":
|
| 111 |
+
return await self._storage_adapter.get_credential_state(filename)
|
| 112 |
+
elif self._storage_purpose == "usage_stats":
|
| 113 |
+
return await self._storage_adapter.get_usage_stats(filename)
|
| 114 |
+
else:
|
| 115 |
+
# 对于配置和通用存储,filename作为配置键
|
| 116 |
+
value = await self._storage_adapter.get_config(filename)
|
| 117 |
+
return value if isinstance(value, dict) else {}
|
| 118 |
+
|
| 119 |
+
async def update_file_state(self, filename: str, updates: Dict[str, Any]):
|
| 120 |
+
"""更新特定文件的状态,兼容原有接口"""
|
| 121 |
+
await self._ensure_initialized()
|
| 122 |
+
|
| 123 |
+
if self._storage_purpose == "credential_state":
|
| 124 |
+
await self._storage_adapter.update_credential_state(filename, updates)
|
| 125 |
+
elif self._storage_purpose == "usage_stats":
|
| 126 |
+
await self._storage_adapter.update_usage_stats(filename, updates)
|
| 127 |
+
else:
|
| 128 |
+
# 对于配置存储,如果updates是字典则作为嵌套配置处理
|
| 129 |
+
if isinstance(updates, dict) and len(updates) == 1:
|
| 130 |
+
# 如果只有一个键值对,可能是设置单个配置
|
| 131 |
+
for key, value in updates.items():
|
| 132 |
+
await self._storage_adapter.set_config(f"{filename}.{key}", value)
|
| 133 |
+
else:
|
| 134 |
+
# 否则将整个updates作为配置值
|
| 135 |
+
await self._storage_adapter.set_config(filename, updates)
|
| 136 |
+
|
| 137 |
+
async def batch_update(self, updates: Dict[str, Dict[str, Any]]):
|
| 138 |
+
"""批量更新多个文件,兼容原有接口"""
|
| 139 |
+
await self._ensure_initialized()
|
| 140 |
+
|
| 141 |
+
for filename, file_updates in updates.items():
|
| 142 |
+
await self.update_file_state(filename, file_updates)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# 全局状态管理器实例缓存
|
| 146 |
+
_state_managers: Dict[str, StateManager] = {}
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def get_state_manager(state_file_path: str) -> StateManager:
|
| 150 |
+
"""获取或创建状态管理器实例,兼容原有接口"""
|
| 151 |
+
if state_file_path not in _state_managers:
|
| 152 |
+
_state_managers[state_file_path] = StateManager(state_file_path)
|
| 153 |
+
return _state_managers[state_file_path]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
async def close_all_state_managers():
|
| 157 |
+
"""关闭所有状态管理器(用于优雅关闭)"""
|
| 158 |
+
global _state_managers
|
| 159 |
+
|
| 160 |
+
# 关闭存储适配器(这会自动处理所有状态管理器)
|
| 161 |
+
from .storage_adapter import close_storage_adapter
|
| 162 |
+
await close_storage_adapter()
|
| 163 |
+
|
| 164 |
+
# 清空缓存
|
| 165 |
+
_state_managers.clear()
|
| 166 |
+
log.debug("All state managers closed")
|
src/storage/cache_manager.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
统一内存缓存管理器
|
| 3 |
+
为所有存储后端提供一致的内存缓存机制,确保读写一致性和高性能。
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
import time
|
| 7 |
+
from typing import Dict, Any, Optional
|
| 8 |
+
from collections import deque
|
| 9 |
+
from abc import ABC, abstractmethod
|
| 10 |
+
|
| 11 |
+
from log import log
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CacheBackend(ABC):
|
| 15 |
+
"""缓存后端接口,定义底层存储的读写操作"""
|
| 16 |
+
|
| 17 |
+
@abstractmethod
|
| 18 |
+
async def load_data(self) -> Dict[str, Any]:
|
| 19 |
+
"""从底层存储加载数据"""
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
@abstractmethod
|
| 23 |
+
async def write_data(self, data: Dict[str, Any]) -> bool:
|
| 24 |
+
"""将数据写入底层存储"""
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class UnifiedCacheManager:
|
| 29 |
+
"""统一缓存管理器"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
cache_backend: CacheBackend,
|
| 34 |
+
cache_ttl: float = 300.0,
|
| 35 |
+
write_delay: float = 1.0,
|
| 36 |
+
name: str = "cache"
|
| 37 |
+
):
|
| 38 |
+
"""
|
| 39 |
+
初始化缓存管理器
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
cache_backend: 缓存后端实现
|
| 43 |
+
cache_ttl: 缓存TTL(秒)
|
| 44 |
+
write_delay: 写入延迟(秒)
|
| 45 |
+
name: 缓存名称(用于日志)
|
| 46 |
+
"""
|
| 47 |
+
self._backend = cache_backend
|
| 48 |
+
self._cache_ttl = cache_ttl
|
| 49 |
+
self._write_delay = write_delay
|
| 50 |
+
self._name = name
|
| 51 |
+
|
| 52 |
+
# 缓存数据
|
| 53 |
+
self._cache: Dict[str, Any] = {}
|
| 54 |
+
self._cache_dirty = False
|
| 55 |
+
self._last_cache_time = 0
|
| 56 |
+
|
| 57 |
+
# 并发控制
|
| 58 |
+
self._cache_lock = asyncio.Lock()
|
| 59 |
+
|
| 60 |
+
# 异步写回任务
|
| 61 |
+
self._write_task: Optional[asyncio.Task] = None
|
| 62 |
+
self._shutdown_event = asyncio.Event()
|
| 63 |
+
|
| 64 |
+
# 性能监控
|
| 65 |
+
self._operation_count = 0
|
| 66 |
+
self._operation_times = deque(maxlen=1000)
|
| 67 |
+
|
| 68 |
+
async def start(self):
|
| 69 |
+
"""启动缓存管理器"""
|
| 70 |
+
if self._write_task and not self._write_task.done():
|
| 71 |
+
return
|
| 72 |
+
|
| 73 |
+
self._shutdown_event.clear()
|
| 74 |
+
self._write_task = asyncio.create_task(self._write_loop())
|
| 75 |
+
log.debug(f"{self._name} cache manager started")
|
| 76 |
+
|
| 77 |
+
async def stop(self):
|
| 78 |
+
"""停止缓存管理器并刷新数据"""
|
| 79 |
+
self._shutdown_event.set()
|
| 80 |
+
|
| 81 |
+
if self._write_task and not self._write_task.done():
|
| 82 |
+
try:
|
| 83 |
+
await asyncio.wait_for(self._write_task, timeout=5.0)
|
| 84 |
+
except asyncio.TimeoutError:
|
| 85 |
+
self._write_task.cancel()
|
| 86 |
+
log.warning(f"{self._name} cache writer forcibly cancelled")
|
| 87 |
+
|
| 88 |
+
# 刷新缓存
|
| 89 |
+
await self._flush_cache()
|
| 90 |
+
log.debug(f"{self._name} cache manager stopped")
|
| 91 |
+
|
| 92 |
+
async def get(self, key: str, default: Any = None) -> Any:
|
| 93 |
+
"""获取缓存项"""
|
| 94 |
+
async with self._cache_lock:
|
| 95 |
+
start_time = time.time()
|
| 96 |
+
|
| 97 |
+
try:
|
| 98 |
+
# 确保缓存已加载
|
| 99 |
+
await self._ensure_cache_loaded()
|
| 100 |
+
|
| 101 |
+
# 性能监控
|
| 102 |
+
self._operation_count += 1
|
| 103 |
+
operation_time = time.time() - start_time
|
| 104 |
+
self._operation_times.append(operation_time)
|
| 105 |
+
|
| 106 |
+
result = self._cache.get(key, default)
|
| 107 |
+
log.debug(f"{self._name} cache get: {key} in {operation_time:.3f}s")
|
| 108 |
+
return result
|
| 109 |
+
|
| 110 |
+
except Exception as e:
|
| 111 |
+
operation_time = time.time() - start_time
|
| 112 |
+
log.error(f"Error getting {self._name} cache key {key} in {operation_time:.3f}s: {e}")
|
| 113 |
+
return default
|
| 114 |
+
|
| 115 |
+
async def set(self, key: str, value: Any) -> bool:
|
| 116 |
+
"""设置缓存项"""
|
| 117 |
+
async with self._cache_lock:
|
| 118 |
+
start_time = time.time()
|
| 119 |
+
|
| 120 |
+
try:
|
| 121 |
+
# 确保缓存已加载
|
| 122 |
+
await self._ensure_cache_loaded()
|
| 123 |
+
|
| 124 |
+
# 更新缓存
|
| 125 |
+
self._cache[key] = value
|
| 126 |
+
self._cache_dirty = True
|
| 127 |
+
|
| 128 |
+
# 性能监控
|
| 129 |
+
self._operation_count += 1
|
| 130 |
+
operation_time = time.time() - start_time
|
| 131 |
+
self._operation_times.append(operation_time)
|
| 132 |
+
|
| 133 |
+
log.debug(f"{self._name} cache set: {key} in {operation_time:.3f}s")
|
| 134 |
+
return True
|
| 135 |
+
|
| 136 |
+
except Exception as e:
|
| 137 |
+
operation_time = time.time() - start_time
|
| 138 |
+
log.error(f"Error setting {self._name} cache key {key} in {operation_time:.3f}s: {e}")
|
| 139 |
+
return False
|
| 140 |
+
|
| 141 |
+
async def delete(self, key: str) -> bool:
|
| 142 |
+
"""删除缓存项"""
|
| 143 |
+
async with self._cache_lock:
|
| 144 |
+
start_time = time.time()
|
| 145 |
+
|
| 146 |
+
try:
|
| 147 |
+
# 确保缓存已加载
|
| 148 |
+
await self._ensure_cache_loaded()
|
| 149 |
+
|
| 150 |
+
if key in self._cache:
|
| 151 |
+
del self._cache[key]
|
| 152 |
+
self._cache_dirty = True
|
| 153 |
+
|
| 154 |
+
# 性能监控
|
| 155 |
+
self._operation_count += 1
|
| 156 |
+
operation_time = time.time() - start_time
|
| 157 |
+
self._operation_times.append(operation_time)
|
| 158 |
+
|
| 159 |
+
log.debug(f"{self._name} cache delete: {key} in {operation_time:.3f}s")
|
| 160 |
+
return True
|
| 161 |
+
else:
|
| 162 |
+
log.warning(f"{self._name} cache key not found for deletion: {key}")
|
| 163 |
+
return False
|
| 164 |
+
|
| 165 |
+
except Exception as e:
|
| 166 |
+
operation_time = time.time() - start_time
|
| 167 |
+
log.error(f"Error deleting {self._name} cache key {key} in {operation_time:.3f}s: {e}")
|
| 168 |
+
return False
|
| 169 |
+
|
| 170 |
+
async def get_all(self) -> Dict[str, Any]:
|
| 171 |
+
"""获取所有缓存数据"""
|
| 172 |
+
async with self._cache_lock:
|
| 173 |
+
start_time = time.time()
|
| 174 |
+
|
| 175 |
+
try:
|
| 176 |
+
# 确保缓存已加载
|
| 177 |
+
await self._ensure_cache_loaded()
|
| 178 |
+
|
| 179 |
+
# 性能监控
|
| 180 |
+
self._operation_count += 1
|
| 181 |
+
operation_time = time.time() - start_time
|
| 182 |
+
self._operation_times.append(operation_time)
|
| 183 |
+
|
| 184 |
+
log.debug(f"{self._name} cache get_all ({len(self._cache)}) in {operation_time:.3f}s")
|
| 185 |
+
return self._cache.copy()
|
| 186 |
+
|
| 187 |
+
except Exception as e:
|
| 188 |
+
operation_time = time.time() - start_time
|
| 189 |
+
log.error(f"Error getting all {self._name} cache in {operation_time:.3f}s: {e}")
|
| 190 |
+
return {}
|
| 191 |
+
|
| 192 |
+
async def update_multi(self, updates: Dict[str, Any]) -> bool:
|
| 193 |
+
"""批量更新缓存项"""
|
| 194 |
+
async with self._cache_lock:
|
| 195 |
+
start_time = time.time()
|
| 196 |
+
|
| 197 |
+
try:
|
| 198 |
+
# 确保缓存已加载
|
| 199 |
+
await self._ensure_cache_loaded()
|
| 200 |
+
|
| 201 |
+
# 批量更新
|
| 202 |
+
self._cache.update(updates)
|
| 203 |
+
self._cache_dirty = True
|
| 204 |
+
|
| 205 |
+
# 性能监控
|
| 206 |
+
self._operation_count += 1
|
| 207 |
+
operation_time = time.time() - start_time
|
| 208 |
+
self._operation_times.append(operation_time)
|
| 209 |
+
|
| 210 |
+
log.debug(f"{self._name} cache update_multi ({len(updates)}) in {operation_time:.3f}s")
|
| 211 |
+
return True
|
| 212 |
+
|
| 213 |
+
except Exception as e:
|
| 214 |
+
operation_time = time.time() - start_time
|
| 215 |
+
log.error(f"Error updating {self._name} cache multi in {operation_time:.3f}s: {e}")
|
| 216 |
+
return False
|
| 217 |
+
|
| 218 |
+
async def _ensure_cache_loaded(self):
|
| 219 |
+
"""确保缓存已从底层存储加载"""
|
| 220 |
+
current_time = time.time()
|
| 221 |
+
|
| 222 |
+
# 检查缓存是否需要加载(首次加载或过期)
|
| 223 |
+
# 如果缓存脏了(有未写入的数据),不要重新加载以避免数据丢失
|
| 224 |
+
if (self._last_cache_time == 0 or
|
| 225 |
+
(current_time - self._last_cache_time > self._cache_ttl and not self._cache_dirty)):
|
| 226 |
+
|
| 227 |
+
await self._load_cache()
|
| 228 |
+
self._last_cache_time = current_time
|
| 229 |
+
|
| 230 |
+
async def _load_cache(self):
|
| 231 |
+
"""从底层存储加载缓存"""
|
| 232 |
+
try:
|
| 233 |
+
start_time = time.time()
|
| 234 |
+
|
| 235 |
+
# 从后端加载数据
|
| 236 |
+
data = await self._backend.load_data()
|
| 237 |
+
|
| 238 |
+
if data:
|
| 239 |
+
self._cache = data
|
| 240 |
+
log.debug(f"{self._name} cache loaded ({len(self._cache)}) from backend")
|
| 241 |
+
else:
|
| 242 |
+
# 如果后端没有数据,初始化空缓存
|
| 243 |
+
self._cache = {}
|
| 244 |
+
log.debug(f"{self._name} cache initialized empty")
|
| 245 |
+
|
| 246 |
+
operation_time = time.time() - start_time
|
| 247 |
+
log.debug(f"{self._name} cache loaded in {operation_time:.3f}s")
|
| 248 |
+
|
| 249 |
+
except Exception as e:
|
| 250 |
+
log.error(f"Error loading {self._name} cache from backend: {e}")
|
| 251 |
+
self._cache = {}
|
| 252 |
+
|
| 253 |
+
async def _write_loop(self):
|
| 254 |
+
"""异步写回循环"""
|
| 255 |
+
while not self._shutdown_event.is_set():
|
| 256 |
+
try:
|
| 257 |
+
# 等待写入延迟或关闭信号
|
| 258 |
+
try:
|
| 259 |
+
await asyncio.wait_for(self._shutdown_event.wait(), timeout=self._write_delay)
|
| 260 |
+
break # 收到关闭信号
|
| 261 |
+
except asyncio.TimeoutError:
|
| 262 |
+
pass # 超时,检查是否需要写回
|
| 263 |
+
|
| 264 |
+
# 如果缓存脏了,写回底层存储
|
| 265 |
+
async with self._cache_lock:
|
| 266 |
+
if self._cache_dirty:
|
| 267 |
+
await self._write_cache()
|
| 268 |
+
|
| 269 |
+
except Exception as e:
|
| 270 |
+
log.error(f"Error in {self._name} cache writer loop: {e}")
|
| 271 |
+
await asyncio.sleep(1)
|
| 272 |
+
|
| 273 |
+
async def _write_cache(self):
|
| 274 |
+
"""将缓存写回底层存储"""
|
| 275 |
+
if not self._cache_dirty:
|
| 276 |
+
return
|
| 277 |
+
|
| 278 |
+
try:
|
| 279 |
+
start_time = time.time()
|
| 280 |
+
|
| 281 |
+
# 写入后端
|
| 282 |
+
success = await self._backend.write_data(self._cache.copy())
|
| 283 |
+
|
| 284 |
+
if success:
|
| 285 |
+
self._cache_dirty = False
|
| 286 |
+
operation_time = time.time() - start_time
|
| 287 |
+
log.debug(f"{self._name} cache written to backend in {operation_time:.3f}s ({len(self._cache)} items)")
|
| 288 |
+
else:
|
| 289 |
+
log.error(f"Failed to write {self._name} cache to backend")
|
| 290 |
+
|
| 291 |
+
except Exception as e:
|
| 292 |
+
log.error(f"Error writing {self._name} cache to backend: {e}")
|
| 293 |
+
|
| 294 |
+
async def _flush_cache(self):
|
| 295 |
+
"""立即刷新缓存到底层存储"""
|
| 296 |
+
async with self._cache_lock:
|
| 297 |
+
if self._cache_dirty:
|
| 298 |
+
await self._write_cache()
|
| 299 |
+
log.debug(f"{self._name} cache flushed to backend")
|
| 300 |
+
|
| 301 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 302 |
+
"""获取缓存统计信息"""
|
| 303 |
+
avg_time = sum(self._operation_times) / len(self._operation_times) if self._operation_times else 0
|
| 304 |
+
|
| 305 |
+
return {
|
| 306 |
+
"cache_name": self._name,
|
| 307 |
+
"cache_size": len(self._cache),
|
| 308 |
+
"cache_dirty": self._cache_dirty,
|
| 309 |
+
"operation_count": self._operation_count,
|
| 310 |
+
"avg_operation_time": avg_time,
|
| 311 |
+
"last_cache_time": self._last_cache_time,
|
| 312 |
+
}
|
src/storage/file_storage_manager.py
ADDED
|
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
本地文件存储管理器,使用统一缓存支持队列写入优化。
|
| 3 |
+
所有凭证和状态数据存储在creds.toml中,配置数据存储在config.toml中。
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import time
|
| 9 |
+
from typing import Dict, Any, List, Optional
|
| 10 |
+
|
| 11 |
+
import aiofiles
|
| 12 |
+
import toml
|
| 13 |
+
|
| 14 |
+
from log import log
|
| 15 |
+
from .cache_manager import UnifiedCacheManager, CacheBackend
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class FileCacheBackend(CacheBackend):
|
| 19 |
+
"""文件缓存后端实现"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, file_path: str):
|
| 22 |
+
self._file_path = file_path
|
| 23 |
+
|
| 24 |
+
async def load_data(self) -> Dict[str, Any]:
|
| 25 |
+
"""从TOML文件加载数据"""
|
| 26 |
+
try:
|
| 27 |
+
if not os.path.exists(self._file_path):
|
| 28 |
+
return {}
|
| 29 |
+
|
| 30 |
+
async with aiofiles.open(self._file_path, "r", encoding="utf-8") as f:
|
| 31 |
+
content = await f.read()
|
| 32 |
+
|
| 33 |
+
if not content.strip():
|
| 34 |
+
return {}
|
| 35 |
+
|
| 36 |
+
return toml.loads(content)
|
| 37 |
+
|
| 38 |
+
except Exception as e:
|
| 39 |
+
log.error(f"Error loading data from file {self._file_path}: {e}")
|
| 40 |
+
return {}
|
| 41 |
+
|
| 42 |
+
async def write_data(self, data: Dict[str, Any]) -> bool:
|
| 43 |
+
"""将数据写入TOML文件"""
|
| 44 |
+
try:
|
| 45 |
+
# 确保目录存在
|
| 46 |
+
os.makedirs(os.path.dirname(self._file_path), exist_ok=True)
|
| 47 |
+
|
| 48 |
+
# 写入TOML文件
|
| 49 |
+
toml_content = toml.dumps(data)
|
| 50 |
+
async with aiofiles.open(self._file_path, "w", encoding="utf-8") as f:
|
| 51 |
+
await f.write(toml_content)
|
| 52 |
+
|
| 53 |
+
return True
|
| 54 |
+
|
| 55 |
+
except Exception as e:
|
| 56 |
+
log.error(f"Error writing data to file {self._file_path}: {e}")
|
| 57 |
+
return False
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class FileStorageManager:
|
| 61 |
+
"""基于本地文件的存储管理器(使用统一缓存)"""
|
| 62 |
+
|
| 63 |
+
# 状态字段常量
|
| 64 |
+
STATE_FIELDS = {
|
| 65 |
+
"error_codes", "disabled", "last_success", "user_email",
|
| 66 |
+
"gemini_2_5_pro_calls", "total_calls", "next_reset_time",
|
| 67 |
+
"daily_limit_gemini_2_5_pro", "daily_limit_total"
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
# 默认状态数据模板(不包含动态值)
|
| 71 |
+
_DEFAULT_STATE_TEMPLATE = {
|
| 72 |
+
"error_codes": [],
|
| 73 |
+
"disabled": False,
|
| 74 |
+
"user_email": None,
|
| 75 |
+
"gemini_2_5_pro_calls": 0,
|
| 76 |
+
"total_calls": 0,
|
| 77 |
+
"next_reset_time": None,
|
| 78 |
+
"daily_limit_gemini_2_5_pro": 100,
|
| 79 |
+
"daily_limit_total": 1000
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
@classmethod
|
| 83 |
+
def get_default_state(cls) -> Dict[str, Any]:
|
| 84 |
+
"""获取默认状态数据(包含当前时间戳)"""
|
| 85 |
+
state = cls._DEFAULT_STATE_TEMPLATE.copy()
|
| 86 |
+
state["last_success"] = time.time()
|
| 87 |
+
return state
|
| 88 |
+
|
| 89 |
+
def __init__(self):
|
| 90 |
+
self._credentials_dir = None # 将通过异步初始化设置
|
| 91 |
+
self._state_file = None
|
| 92 |
+
self._config_file = None
|
| 93 |
+
self._lock = asyncio.Lock()
|
| 94 |
+
self._initialized = False
|
| 95 |
+
|
| 96 |
+
# 统一缓存管理器
|
| 97 |
+
self._credentials_cache_manager: Optional[UnifiedCacheManager] = None
|
| 98 |
+
self._config_cache_manager: Optional[UnifiedCacheManager] = None
|
| 99 |
+
|
| 100 |
+
# 配置参数
|
| 101 |
+
self._write_delay = 0.5 # 写入延迟(秒)
|
| 102 |
+
self._cache_ttl = 300 # 缓存TTL(秒)
|
| 103 |
+
|
| 104 |
+
async def initialize(self) -> None:
|
| 105 |
+
"""初始化文件存储"""
|
| 106 |
+
if self._initialized:
|
| 107 |
+
return
|
| 108 |
+
|
| 109 |
+
# 获取凭证目录配置(初始化时直接使用环境变量,避免循环依赖)
|
| 110 |
+
self._credentials_dir = os.getenv("CREDENTIALS_DIR", "./creds")
|
| 111 |
+
self._state_file = os.path.join(self._credentials_dir, "creds.toml")
|
| 112 |
+
self._config_file = os.path.join(self._credentials_dir, "config.toml")
|
| 113 |
+
|
| 114 |
+
# 确保目录存在
|
| 115 |
+
os.makedirs(self._credentials_dir, exist_ok=True)
|
| 116 |
+
|
| 117 |
+
# 执行JSON到TOML的迁移
|
| 118 |
+
await self._migrate_json_to_toml()
|
| 119 |
+
|
| 120 |
+
# 创建缓存管理器
|
| 121 |
+
credentials_backend = FileCacheBackend(self._state_file)
|
| 122 |
+
config_backend = FileCacheBackend(self._config_file)
|
| 123 |
+
|
| 124 |
+
self._credentials_cache_manager = UnifiedCacheManager(
|
| 125 |
+
credentials_backend,
|
| 126 |
+
cache_ttl=self._cache_ttl,
|
| 127 |
+
write_delay=self._write_delay,
|
| 128 |
+
name="credentials"
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
self._config_cache_manager = UnifiedCacheManager(
|
| 132 |
+
config_backend,
|
| 133 |
+
cache_ttl=self._cache_ttl,
|
| 134 |
+
write_delay=self._write_delay,
|
| 135 |
+
name="config"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# 启动缓存管理器
|
| 139 |
+
await self._credentials_cache_manager.start()
|
| 140 |
+
await self._config_cache_manager.start()
|
| 141 |
+
|
| 142 |
+
self._initialized = True
|
| 143 |
+
log.debug("File storage manager initialized with unified cache")
|
| 144 |
+
|
| 145 |
+
async def close(self) -> None:
|
| 146 |
+
"""关闭文件存储"""
|
| 147 |
+
# 停止缓存管理器
|
| 148 |
+
if self._credentials_cache_manager:
|
| 149 |
+
await self._credentials_cache_manager.stop()
|
| 150 |
+
if self._config_cache_manager:
|
| 151 |
+
await self._config_cache_manager.stop()
|
| 152 |
+
|
| 153 |
+
self._initialized = False
|
| 154 |
+
log.debug("File storage manager closed with unified cache flushed")
|
| 155 |
+
|
| 156 |
+
def _normalize_filename(self, filename: str) -> str:
|
| 157 |
+
"""标准化文件名"""
|
| 158 |
+
return os.path.basename(filename)
|
| 159 |
+
|
| 160 |
+
def _ensure_initialized(self):
|
| 161 |
+
"""确保已初始化"""
|
| 162 |
+
if not self._initialized:
|
| 163 |
+
raise RuntimeError("File storage manager not initialized")
|
| 164 |
+
|
| 165 |
+
async def _migrate_json_to_toml(self) -> None:
|
| 166 |
+
"""将现有的JSON凭证文件和旧的creds_state.toml迁移到新的creds.toml文件中"""
|
| 167 |
+
try:
|
| 168 |
+
# 扫描JSON凭证文件
|
| 169 |
+
json_files = []
|
| 170 |
+
if os.path.exists(self._credentials_dir):
|
| 171 |
+
for filename in os.listdir(self._credentials_dir):
|
| 172 |
+
if filename.endswith(".json"):
|
| 173 |
+
json_files.append(filename)
|
| 174 |
+
|
| 175 |
+
# 检查旧的creds_state.toml文件
|
| 176 |
+
old_state_file = os.path.join(self._credentials_dir, "creds_state.toml")
|
| 177 |
+
has_old_state = os.path.exists(old_state_file)
|
| 178 |
+
|
| 179 |
+
if not json_files and not has_old_state:
|
| 180 |
+
log.debug("No JSON credential files or old state file found for migration")
|
| 181 |
+
return
|
| 182 |
+
|
| 183 |
+
# 加载现有TOML数据(如果存在)
|
| 184 |
+
toml_data = {}
|
| 185 |
+
if os.path.exists(self._state_file):
|
| 186 |
+
try:
|
| 187 |
+
async with aiofiles.open(self._state_file, "r", encoding="utf-8") as f:
|
| 188 |
+
content = await f.read()
|
| 189 |
+
if content.strip():
|
| 190 |
+
toml_data = toml.loads(content)
|
| 191 |
+
except Exception as e:
|
| 192 |
+
log.error(f"Failed to load existing TOML file: {e}")
|
| 193 |
+
|
| 194 |
+
# 加载旧的creds_state.toml文件(稍后处理)
|
| 195 |
+
old_state_data = {}
|
| 196 |
+
if has_old_state:
|
| 197 |
+
try:
|
| 198 |
+
async with aiofiles.open(old_state_file, "r", encoding="utf-8") as f:
|
| 199 |
+
content = await f.read()
|
| 200 |
+
old_state_data = toml.loads(content)
|
| 201 |
+
log.debug("Loaded old state file for potential migration")
|
| 202 |
+
except Exception as e:
|
| 203 |
+
log.error(f"Failed to load old state file: {e}")
|
| 204 |
+
old_state_data = {}
|
| 205 |
+
|
| 206 |
+
if json_files:
|
| 207 |
+
log.info(f"Migrating {len(json_files)} JSON credential files to TOML")
|
| 208 |
+
|
| 209 |
+
# 处理每个JSON文件
|
| 210 |
+
migrated_count = 0
|
| 211 |
+
for filename in json_files:
|
| 212 |
+
try:
|
| 213 |
+
filepath = os.path.join(self._credentials_dir, filename)
|
| 214 |
+
|
| 215 |
+
# 读取JSON凭证数据
|
| 216 |
+
async with aiofiles.open(filepath, "r", encoding="utf-8") as f:
|
| 217 |
+
json_content = await f.read()
|
| 218 |
+
credential_data = json.loads(json_content)
|
| 219 |
+
|
| 220 |
+
# 创建新的section:凭证数据 + 状态数据
|
| 221 |
+
section_data = credential_data.copy()
|
| 222 |
+
|
| 223 |
+
# 首先添加默认状态数据
|
| 224 |
+
section_data.update(self.get_default_state())
|
| 225 |
+
|
| 226 |
+
# 如果旧状态文件中有该凭证的状态数据,则使用旧状态数据覆盖默认值
|
| 227 |
+
if filename in old_state_data and isinstance(old_state_data[filename], dict):
|
| 228 |
+
log.debug(f"Using old state data for: {filename}")
|
| 229 |
+
section_data.update(old_state_data[filename])
|
| 230 |
+
|
| 231 |
+
# 如果当前TOML中已存在该凭证,保留其状态数据
|
| 232 |
+
if filename in toml_data and isinstance(toml_data[filename], dict):
|
| 233 |
+
log.debug(f"Merging with existing TOML state for: {filename}")
|
| 234 |
+
existing_state = toml_data[filename]
|
| 235 |
+
section_data.update(existing_state)
|
| 236 |
+
|
| 237 |
+
# 最后确保凭证数据是最新的(覆盖任何冲突的字段)
|
| 238 |
+
section_data.update(credential_data)
|
| 239 |
+
|
| 240 |
+
toml_data[filename] = section_data
|
| 241 |
+
|
| 242 |
+
migrated_count += 1
|
| 243 |
+
log.debug(f"Migrated credential: {filename}")
|
| 244 |
+
|
| 245 |
+
except Exception as e:
|
| 246 |
+
log.error(f"Failed to migrate {filename}: {e}")
|
| 247 |
+
continue
|
| 248 |
+
|
| 249 |
+
# 保存TOML文件(如果有新的迁移)
|
| 250 |
+
if migrated_count > 0:
|
| 251 |
+
try:
|
| 252 |
+
toml_content = toml.dumps(toml_data)
|
| 253 |
+
async with aiofiles.open(self._state_file, "w", encoding="utf-8") as f:
|
| 254 |
+
await f.write(toml_content)
|
| 255 |
+
|
| 256 |
+
# 删除已迁移的JSON文件
|
| 257 |
+
for filename in json_files:
|
| 258 |
+
try:
|
| 259 |
+
if filename in toml_data: # 确保文件确实被迁移了
|
| 260 |
+
filepath = os.path.join(self._credentials_dir, filename)
|
| 261 |
+
os.remove(filepath)
|
| 262 |
+
log.debug(f"Removed migrated JSON file: {filename}")
|
| 263 |
+
except Exception as e:
|
| 264 |
+
log.warning(f"Failed to remove {filename}: {e}")
|
| 265 |
+
|
| 266 |
+
# 删除旧的状态文件(如果存在)
|
| 267 |
+
if has_old_state:
|
| 268 |
+
try:
|
| 269 |
+
os.remove(old_state_file)
|
| 270 |
+
log.debug("Removed old state file: creds_state.toml")
|
| 271 |
+
except Exception as e:
|
| 272 |
+
log.warning(f"Failed to remove old state file: {e}")
|
| 273 |
+
|
| 274 |
+
log.info(f"Migration completed: {migrated_count} files migrated to TOML format")
|
| 275 |
+
|
| 276 |
+
except Exception as e:
|
| 277 |
+
log.error(f"Failed to save migrated TOML file: {e}")
|
| 278 |
+
|
| 279 |
+
except Exception as e:
|
| 280 |
+
log.error(f"Error during JSON to TOML migration: {e}")
|
| 281 |
+
|
| 282 |
+
# ============ 凭证管理 ============
|
| 283 |
+
|
| 284 |
+
async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
|
| 285 |
+
"""存储凭证数据到统一缓存"""
|
| 286 |
+
self._ensure_initialized()
|
| 287 |
+
|
| 288 |
+
try:
|
| 289 |
+
filename = self._normalize_filename(filename)
|
| 290 |
+
|
| 291 |
+
# 获取现有数据或创建新数据
|
| 292 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 293 |
+
existing_state = all_data.get(filename, {})
|
| 294 |
+
|
| 295 |
+
# 创建新的section数据:凭证数据 + 状态数据
|
| 296 |
+
final_data = self.get_default_state()
|
| 297 |
+
final_data.update(existing_state)
|
| 298 |
+
final_data.update(credential_data) # 凭证数据覆盖状态数据中的同名字段
|
| 299 |
+
|
| 300 |
+
# 更新整个数据集
|
| 301 |
+
all_data[filename] = final_data
|
| 302 |
+
|
| 303 |
+
success = await self._credentials_cache_manager.update_multi({filename: final_data})
|
| 304 |
+
log.debug(f"Stored credential to unified cache: {filename}")
|
| 305 |
+
return success
|
| 306 |
+
|
| 307 |
+
except Exception as e:
|
| 308 |
+
log.error(f"Error storing credential {filename}: {e}")
|
| 309 |
+
return False
|
| 310 |
+
|
| 311 |
+
async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
|
| 312 |
+
"""从统一缓存获取凭证数据"""
|
| 313 |
+
self._ensure_initialized()
|
| 314 |
+
|
| 315 |
+
try:
|
| 316 |
+
filename = self._normalize_filename(filename)
|
| 317 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 318 |
+
|
| 319 |
+
if filename not in all_data:
|
| 320 |
+
return None
|
| 321 |
+
|
| 322 |
+
section_data = all_data[filename]
|
| 323 |
+
|
| 324 |
+
# 提取凭证数据(排除状态字段)
|
| 325 |
+
credential_data = {k: v for k, v in section_data.items() if k not in self.STATE_FIELDS}
|
| 326 |
+
return credential_data
|
| 327 |
+
|
| 328 |
+
except Exception as e:
|
| 329 |
+
log.error(f"Error getting credential {filename}: {e}")
|
| 330 |
+
return None
|
| 331 |
+
|
| 332 |
+
async def list_credentials(self) -> List[str]:
|
| 333 |
+
"""从统一缓存列出所有凭证文件名"""
|
| 334 |
+
self._ensure_initialized()
|
| 335 |
+
|
| 336 |
+
try:
|
| 337 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 338 |
+
return list(all_data.keys())
|
| 339 |
+
|
| 340 |
+
except Exception as e:
|
| 341 |
+
log.error(f"Error listing credentials: {e}")
|
| 342 |
+
return []
|
| 343 |
+
|
| 344 |
+
async def delete_credential(self, filename: str) -> bool:
|
| 345 |
+
"""从统一缓存删除凭证"""
|
| 346 |
+
self._ensure_initialized()
|
| 347 |
+
|
| 348 |
+
try:
|
| 349 |
+
filename = self._normalize_filename(filename)
|
| 350 |
+
success = await self._credentials_cache_manager.delete(filename)
|
| 351 |
+
log.debug(f"Deleted credential from unified cache: {filename}")
|
| 352 |
+
return success
|
| 353 |
+
|
| 354 |
+
except Exception as e:
|
| 355 |
+
log.error(f"Error deleting credential {filename}: {e}")
|
| 356 |
+
return False
|
| 357 |
+
|
| 358 |
+
# ============ 状态管理 ============
|
| 359 |
+
|
| 360 |
+
async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
|
| 361 |
+
"""更新凭证状态"""
|
| 362 |
+
self._ensure_initialized()
|
| 363 |
+
|
| 364 |
+
try:
|
| 365 |
+
filename = self._normalize_filename(filename)
|
| 366 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 367 |
+
|
| 368 |
+
if filename not in all_data:
|
| 369 |
+
all_data[filename] = self.get_default_state()
|
| 370 |
+
|
| 371 |
+
# 更新状态
|
| 372 |
+
all_data[filename].update(state_updates)
|
| 373 |
+
|
| 374 |
+
success = await self._credentials_cache_manager.update_multi({filename: all_data[filename]})
|
| 375 |
+
log.debug(f"Updated credential state in unified cache: {filename}")
|
| 376 |
+
return success
|
| 377 |
+
|
| 378 |
+
except Exception as e:
|
| 379 |
+
log.error(f"Error updating credential state {filename}: {e}")
|
| 380 |
+
return False
|
| 381 |
+
|
| 382 |
+
async def get_credential_state(self, filename: str) -> Dict[str, Any]:
|
| 383 |
+
"""从统一缓存获取凭证状态"""
|
| 384 |
+
self._ensure_initialized()
|
| 385 |
+
|
| 386 |
+
try:
|
| 387 |
+
filename = self._normalize_filename(filename)
|
| 388 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 389 |
+
|
| 390 |
+
if filename not in all_data:
|
| 391 |
+
# 返回基本的状态字段
|
| 392 |
+
default_state = self.get_default_state()
|
| 393 |
+
return {k: v for k, v in default_state.items() if k in {"error_codes", "disabled", "last_success", "user_email"}}
|
| 394 |
+
|
| 395 |
+
section_data = all_data[filename]
|
| 396 |
+
|
| 397 |
+
# 提取状态字段
|
| 398 |
+
state_data = {k: v for k, v in section_data.items() if k in self.STATE_FIELDS}
|
| 399 |
+
|
| 400 |
+
# 确保必要字段存在
|
| 401 |
+
basic_fields = {"error_codes", "disabled", "last_success", "user_email"}
|
| 402 |
+
default_state = self.get_default_state()
|
| 403 |
+
|
| 404 |
+
for field in basic_fields:
|
| 405 |
+
if field not in state_data:
|
| 406 |
+
state_data[field] = default_state[field]
|
| 407 |
+
|
| 408 |
+
return state_data
|
| 409 |
+
|
| 410 |
+
except Exception as e:
|
| 411 |
+
log.error(f"Error getting credential state {filename}: {e}")
|
| 412 |
+
return self.get_default_state()
|
| 413 |
+
|
| 414 |
+
async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
|
| 415 |
+
"""从统一缓存获取所有凭证状态"""
|
| 416 |
+
self._ensure_initialized()
|
| 417 |
+
|
| 418 |
+
try:
|
| 419 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 420 |
+
|
| 421 |
+
states = {}
|
| 422 |
+
for filename, section_data in all_data.items():
|
| 423 |
+
# 提取状态字段
|
| 424 |
+
state_data = {k: v for k, v in section_data.items() if k in self.STATE_FIELDS}
|
| 425 |
+
|
| 426 |
+
# 确保必要字段存在
|
| 427 |
+
basic_fields = {"error_codes", "disabled", "last_success", "user_email"}
|
| 428 |
+
default_state = self.get_default_state()
|
| 429 |
+
|
| 430 |
+
for field in basic_fields:
|
| 431 |
+
if field not in state_data:
|
| 432 |
+
state_data[field] = default_state[field]
|
| 433 |
+
|
| 434 |
+
states[filename] = state_data
|
| 435 |
+
|
| 436 |
+
return states
|
| 437 |
+
|
| 438 |
+
except Exception as e:
|
| 439 |
+
log.error(f"Error getting all credential states: {e}")
|
| 440 |
+
return {}
|
| 441 |
+
|
| 442 |
+
# ============ 配置管理 ============
|
| 443 |
+
|
| 444 |
+
async def set_config(self, key: str, value: Any) -> bool:
|
| 445 |
+
"""设置配置到统一缓存"""
|
| 446 |
+
self._ensure_initialized()
|
| 447 |
+
return await self._config_cache_manager.set(key, value)
|
| 448 |
+
|
| 449 |
+
async def get_config(self, key: str, default: Any = None) -> Any:
|
| 450 |
+
"""从统一缓存获取配置"""
|
| 451 |
+
self._ensure_initialized()
|
| 452 |
+
return await self._config_cache_manager.get(key, default)
|
| 453 |
+
|
| 454 |
+
async def get_all_config(self) -> Dict[str, Any]:
|
| 455 |
+
"""从统一缓存获取所有配置"""
|
| 456 |
+
self._ensure_initialized()
|
| 457 |
+
return await self._config_cache_manager.get_all()
|
| 458 |
+
|
| 459 |
+
async def delete_config(self, key: str) -> bool:
|
| 460 |
+
"""从统一缓存删除配置"""
|
| 461 |
+
self._ensure_initialized()
|
| 462 |
+
return await self._config_cache_manager.delete(key)
|
| 463 |
+
|
| 464 |
+
# ============ 使用统计管理 ============
|
| 465 |
+
|
| 466 |
+
async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
|
| 467 |
+
"""更新使用统计"""
|
| 468 |
+
self._ensure_initialized()
|
| 469 |
+
|
| 470 |
+
try:
|
| 471 |
+
filename = self._normalize_filename(filename)
|
| 472 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 473 |
+
|
| 474 |
+
if filename not in all_data:
|
| 475 |
+
all_data[filename] = self.get_default_state()
|
| 476 |
+
|
| 477 |
+
# 更新统计数据
|
| 478 |
+
all_data[filename].update(stats_updates)
|
| 479 |
+
|
| 480 |
+
success = await self._credentials_cache_manager.update_multi({filename: all_data[filename]})
|
| 481 |
+
log.debug(f"Updated usage stats in unified cache: {filename}")
|
| 482 |
+
return success
|
| 483 |
+
|
| 484 |
+
except Exception as e:
|
| 485 |
+
log.error(f"Error updating usage stats {filename}: {e}")
|
| 486 |
+
return False
|
| 487 |
+
|
| 488 |
+
async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
|
| 489 |
+
"""从统一缓存获取使用统计"""
|
| 490 |
+
self._ensure_initialized()
|
| 491 |
+
|
| 492 |
+
try:
|
| 493 |
+
filename = self._normalize_filename(filename)
|
| 494 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 495 |
+
|
| 496 |
+
if filename not in all_data:
|
| 497 |
+
# 返回基本的统计��段
|
| 498 |
+
default_state = self.get_default_state()
|
| 499 |
+
return {k: v for k, v in default_state.items() if k in {"gemini_2_5_pro_calls", "total_calls", "next_reset_time", "daily_limit_gemini_2_5_pro", "daily_limit_total"}}
|
| 500 |
+
|
| 501 |
+
section_data = all_data[filename]
|
| 502 |
+
|
| 503 |
+
# 提取统计字段
|
| 504 |
+
stats_fields = {"gemini_2_5_pro_calls", "total_calls", "next_reset_time", "daily_limit_gemini_2_5_pro", "daily_limit_total"}
|
| 505 |
+
stats_data = {k: v for k, v in section_data.items() if k in stats_fields}
|
| 506 |
+
|
| 507 |
+
# 确保必要字段存在
|
| 508 |
+
default_state = self.get_default_state()
|
| 509 |
+
for field in stats_fields:
|
| 510 |
+
if field not in stats_data:
|
| 511 |
+
stats_data[field] = default_state[field]
|
| 512 |
+
|
| 513 |
+
return stats_data
|
| 514 |
+
|
| 515 |
+
except Exception as e:
|
| 516 |
+
log.error(f"Error getting usage stats {filename}: {e}")
|
| 517 |
+
return self.get_default_state()
|
| 518 |
+
|
| 519 |
+
async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
|
| 520 |
+
"""从统一缓存获取所有使用统计"""
|
| 521 |
+
self._ensure_initialized()
|
| 522 |
+
|
| 523 |
+
try:
|
| 524 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 525 |
+
|
| 526 |
+
stats = {}
|
| 527 |
+
stats_fields = {"gemini_2_5_pro_calls", "total_calls", "next_reset_time", "daily_limit_gemini_2_5_pro", "daily_limit_total"}
|
| 528 |
+
|
| 529 |
+
for filename, section_data in all_data.items():
|
| 530 |
+
# 提取统计字段
|
| 531 |
+
stats_data = {k: v for k, v in section_data.items() if k in stats_fields}
|
| 532 |
+
|
| 533 |
+
# 确保必要字段存在
|
| 534 |
+
default_state = self.get_default_state()
|
| 535 |
+
for field in stats_fields:
|
| 536 |
+
if field not in stats_data:
|
| 537 |
+
stats_data[field] = default_state[field]
|
| 538 |
+
|
| 539 |
+
stats[filename] = stats_data
|
| 540 |
+
|
| 541 |
+
return stats
|
| 542 |
+
|
| 543 |
+
except Exception as e:
|
| 544 |
+
log.error(f"Error getting all usage stats: {e}")
|
| 545 |
+
return {}
|
| 546 |
+
|
| 547 |
+
# ============ 工具方法 ============
|
| 548 |
+
|
| 549 |
+
async def export_credential_to_json(self, filename: str, output_path: str = None) -> bool:
|
| 550 |
+
"""将TOML中的凭证导出为JSON文件(用于兼容性和备份)"""
|
| 551 |
+
self._ensure_initialized()
|
| 552 |
+
|
| 553 |
+
try:
|
| 554 |
+
filename = self._normalize_filename(filename)
|
| 555 |
+
credential_data = await self.get_credential(filename)
|
| 556 |
+
|
| 557 |
+
if credential_data is None:
|
| 558 |
+
log.warning(f"Credential not found for export: {filename}")
|
| 559 |
+
return False
|
| 560 |
+
|
| 561 |
+
if output_path is None:
|
| 562 |
+
output_path = os.path.join(self._credentials_dir, f"{filename}.json")
|
| 563 |
+
|
| 564 |
+
# 写入JSON文件
|
| 565 |
+
json_content = json.dumps(credential_data, indent=2, ensure_ascii=False)
|
| 566 |
+
async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
|
| 567 |
+
await f.write(json_content)
|
| 568 |
+
|
| 569 |
+
log.info(f"Credential exported to JSON: {output_path}")
|
| 570 |
+
return True
|
| 571 |
+
|
| 572 |
+
except Exception as e:
|
| 573 |
+
log.error(f"Error exporting credential {filename} to JSON: {e}")
|
| 574 |
+
return False
|
| 575 |
+
|
| 576 |
+
async def import_credential_from_json(self, json_path: str, filename: str = None) -> bool:
|
| 577 |
+
"""从JSON文件导入凭证到TOML"""
|
| 578 |
+
self._ensure_initialized()
|
| 579 |
+
|
| 580 |
+
try:
|
| 581 |
+
if not os.path.exists(json_path):
|
| 582 |
+
log.error(f"JSON file not found: {json_path}")
|
| 583 |
+
return False
|
| 584 |
+
|
| 585 |
+
# 读取JSON文件
|
| 586 |
+
async with aiofiles.open(json_path, "r", encoding="utf-8") as f:
|
| 587 |
+
json_content = await f.read()
|
| 588 |
+
|
| 589 |
+
credential_data = json.loads(json_content)
|
| 590 |
+
|
| 591 |
+
if filename is None:
|
| 592 |
+
filename = os.path.basename(json_path)
|
| 593 |
+
|
| 594 |
+
filename = self._normalize_filename(filename)
|
| 595 |
+
|
| 596 |
+
# 存储凭证
|
| 597 |
+
success = await self.store_credential(filename, credential_data)
|
| 598 |
+
|
| 599 |
+
if success:
|
| 600 |
+
log.info(f"Credential imported from JSON: {json_path} -> {filename}")
|
| 601 |
+
|
| 602 |
+
return success
|
| 603 |
+
|
| 604 |
+
except Exception as e:
|
| 605 |
+
log.error(f"Error importing credential from JSON {json_path}: {e}")
|
| 606 |
+
return False
|
src/storage/mongodb_manager.py
ADDED
|
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MongoDB数据库管理器,使用单文档设计和统一缓存。
|
| 3 |
+
所有凭证数据存储在一个文档中,配置数据存储在另一个文档中,类似TOML文件结构。
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from datetime import datetime, timezone
|
| 9 |
+
from typing import Dict, Any, List, Optional
|
| 10 |
+
from collections import deque
|
| 11 |
+
|
| 12 |
+
import motor.motor_asyncio
|
| 13 |
+
from log import log
|
| 14 |
+
from .cache_manager import UnifiedCacheManager, CacheBackend
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MongoDBCacheBackend(CacheBackend):
|
| 18 |
+
"""MongoDB缓存后端实现"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, db, collection_name: str, doc_key: str):
|
| 21 |
+
self._db = db
|
| 22 |
+
self._collection_name = collection_name
|
| 23 |
+
self._doc_key = doc_key
|
| 24 |
+
|
| 25 |
+
async def load_data(self) -> Dict[str, Any]:
|
| 26 |
+
"""从MongoDB文档加载数据"""
|
| 27 |
+
try:
|
| 28 |
+
collection = self._db[self._collection_name]
|
| 29 |
+
doc = await collection.find_one({"key": self._doc_key})
|
| 30 |
+
|
| 31 |
+
if doc and "data" in doc:
|
| 32 |
+
return doc["data"]
|
| 33 |
+
return {}
|
| 34 |
+
|
| 35 |
+
except Exception as e:
|
| 36 |
+
log.error(f"Error loading data from MongoDB document {self._doc_key}: {e}")
|
| 37 |
+
return {}
|
| 38 |
+
|
| 39 |
+
async def write_data(self, data: Dict[str, Any]) -> bool:
|
| 40 |
+
"""将数据写入MongoDB文档"""
|
| 41 |
+
try:
|
| 42 |
+
collection = self._db[self._collection_name]
|
| 43 |
+
|
| 44 |
+
doc = {
|
| 45 |
+
"key": self._doc_key,
|
| 46 |
+
"data": data,
|
| 47 |
+
"updated_at": datetime.now(timezone.utc)
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
await collection.replace_one(
|
| 51 |
+
{"key": self._doc_key},
|
| 52 |
+
doc,
|
| 53 |
+
upsert=True
|
| 54 |
+
)
|
| 55 |
+
return True
|
| 56 |
+
|
| 57 |
+
except Exception as e:
|
| 58 |
+
log.error(f"Error writing data to MongoDB document {self._doc_key}: {e}")
|
| 59 |
+
return False
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class MongoDBManager:
|
| 63 |
+
"""MongoDB数据库管理器"""
|
| 64 |
+
|
| 65 |
+
def __init__(self):
|
| 66 |
+
self._client: Optional[motor.motor_asyncio.AsyncIOMotorClient] = None
|
| 67 |
+
self._db: Optional[motor.motor_asyncio.AsyncIOMotorDatabase] = None
|
| 68 |
+
self._initialized = False
|
| 69 |
+
self._lock = asyncio.Lock()
|
| 70 |
+
|
| 71 |
+
# 配置
|
| 72 |
+
self._connection_uri = None
|
| 73 |
+
self._database_name = None
|
| 74 |
+
|
| 75 |
+
# 单文档设计 - 所有凭证存在一个文档中(类似TOML文件)
|
| 76 |
+
self._collection_name = "credentials_data"
|
| 77 |
+
|
| 78 |
+
# 性能监控
|
| 79 |
+
self._operation_count = 0
|
| 80 |
+
self._operation_times = deque(maxlen=5000)
|
| 81 |
+
|
| 82 |
+
# 统一缓存管理器
|
| 83 |
+
self._credentials_cache_manager: Optional[UnifiedCacheManager] = None
|
| 84 |
+
self._config_cache_manager: Optional[UnifiedCacheManager] = None
|
| 85 |
+
|
| 86 |
+
# 文档key定义
|
| 87 |
+
self._credentials_doc_key = "all_credentials"
|
| 88 |
+
self._config_doc_key = "config_data"
|
| 89 |
+
|
| 90 |
+
# 写入配置参数
|
| 91 |
+
self._write_delay = 1.0 # 写入延迟(秒)
|
| 92 |
+
self._cache_ttl = 300 # 缓存TTL(秒)
|
| 93 |
+
|
| 94 |
+
async def initialize(self):
|
| 95 |
+
"""初始化MongoDB连接"""
|
| 96 |
+
async with self._lock:
|
| 97 |
+
if self._initialized:
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
try:
|
| 101 |
+
# 获取连接配置
|
| 102 |
+
self._connection_uri = os.getenv("MONGODB_URI")
|
| 103 |
+
self._database_name = os.getenv("MONGODB_DATABASE", "gcli2api")
|
| 104 |
+
|
| 105 |
+
if not self._connection_uri:
|
| 106 |
+
raise ValueError("MONGODB_URI environment variable is required")
|
| 107 |
+
|
| 108 |
+
# 建立连接
|
| 109 |
+
self._client = motor.motor_asyncio.AsyncIOMotorClient(
|
| 110 |
+
self._connection_uri,
|
| 111 |
+
serverSelectionTimeoutMS=5000,
|
| 112 |
+
maxPoolSize=100,
|
| 113 |
+
minPoolSize=10,
|
| 114 |
+
maxIdleTimeMS=45000,
|
| 115 |
+
waitQueueTimeoutMS=10000,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# 验证连接
|
| 119 |
+
await self._client.admin.command('ping')
|
| 120 |
+
|
| 121 |
+
# 获取数据库
|
| 122 |
+
self._db = self._client[self._database_name]
|
| 123 |
+
|
| 124 |
+
# 创建索引
|
| 125 |
+
await self._create_indexes()
|
| 126 |
+
|
| 127 |
+
# 创建缓存管理器
|
| 128 |
+
credentials_backend = MongoDBCacheBackend(self._db, self._collection_name, self._credentials_doc_key)
|
| 129 |
+
config_backend = MongoDBCacheBackend(self._db, self._collection_name, self._config_doc_key)
|
| 130 |
+
|
| 131 |
+
self._credentials_cache_manager = UnifiedCacheManager(
|
| 132 |
+
credentials_backend,
|
| 133 |
+
cache_ttl=self._cache_ttl,
|
| 134 |
+
write_delay=self._write_delay,
|
| 135 |
+
name="credentials"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
self._config_cache_manager = UnifiedCacheManager(
|
| 139 |
+
config_backend,
|
| 140 |
+
cache_ttl=self._cache_ttl,
|
| 141 |
+
write_delay=self._write_delay,
|
| 142 |
+
name="config"
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# 启动缓存管理器
|
| 146 |
+
await self._credentials_cache_manager.start()
|
| 147 |
+
await self._config_cache_manager.start()
|
| 148 |
+
|
| 149 |
+
self._initialized = True
|
| 150 |
+
log.info(f"MongoDB connection established to {self._database_name} with unified cache")
|
| 151 |
+
|
| 152 |
+
except Exception as e:
|
| 153 |
+
log.error(f"Error initializing MongoDB: {e}")
|
| 154 |
+
raise
|
| 155 |
+
|
| 156 |
+
async def _create_indexes(self):
|
| 157 |
+
"""创建简单索引(单文档设计)"""
|
| 158 |
+
try:
|
| 159 |
+
# 单文档设计只需要主键索引
|
| 160 |
+
await self._db[self._collection_name].create_index("key", unique=True)
|
| 161 |
+
await self._db[self._collection_name].create_index("updated_at")
|
| 162 |
+
|
| 163 |
+
log.info("MongoDB indexes created for single-document design")
|
| 164 |
+
|
| 165 |
+
except Exception as e:
|
| 166 |
+
log.error(f"Error creating MongoDB indexes: {e}")
|
| 167 |
+
|
| 168 |
+
async def close(self):
|
| 169 |
+
"""关闭MongoDB连接"""
|
| 170 |
+
# 停止缓存管理器
|
| 171 |
+
if self._credentials_cache_manager:
|
| 172 |
+
await self._credentials_cache_manager.stop()
|
| 173 |
+
if self._config_cache_manager:
|
| 174 |
+
await self._config_cache_manager.stop()
|
| 175 |
+
|
| 176 |
+
if self._client:
|
| 177 |
+
self._client.close()
|
| 178 |
+
self._initialized = False
|
| 179 |
+
log.info("MongoDB connection closed with unified cache flushed")
|
| 180 |
+
|
| 181 |
+
def _ensure_initialized(self):
|
| 182 |
+
"""确保已初始化"""
|
| 183 |
+
if not self._initialized:
|
| 184 |
+
raise RuntimeError("MongoDB manager not initialized")
|
| 185 |
+
|
| 186 |
+
def _get_default_state(self) -> Dict[str, Any]:
|
| 187 |
+
"""获取默认状态数据"""
|
| 188 |
+
return {
|
| 189 |
+
"error_codes": [],
|
| 190 |
+
"disabled": False,
|
| 191 |
+
"last_success": time.time(),
|
| 192 |
+
"user_email": None,
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
def _get_default_stats(self) -> Dict[str, Any]:
|
| 196 |
+
"""获取默认统计数据"""
|
| 197 |
+
return {
|
| 198 |
+
"gemini_2_5_pro_calls": 0,
|
| 199 |
+
"total_calls": 0,
|
| 200 |
+
"next_reset_time": None,
|
| 201 |
+
"daily_limit_gemini_2_5_pro": 100,
|
| 202 |
+
"daily_limit_total": 1000
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
# ============ 凭证管理 ============
|
| 206 |
+
|
| 207 |
+
async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
|
| 208 |
+
"""存储凭证数据到统一缓存"""
|
| 209 |
+
self._ensure_initialized()
|
| 210 |
+
start_time = time.time()
|
| 211 |
+
|
| 212 |
+
try:
|
| 213 |
+
# 获取现有数据或创建新数据
|
| 214 |
+
existing_data = await self._credentials_cache_manager.get(filename, {})
|
| 215 |
+
|
| 216 |
+
credential_entry = {
|
| 217 |
+
"credential": credential_data,
|
| 218 |
+
"state": existing_data.get("state", self._get_default_state()),
|
| 219 |
+
"stats": existing_data.get("stats", self._get_default_stats())
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
success = await self._credentials_cache_manager.set(filename, credential_entry)
|
| 223 |
+
|
| 224 |
+
# 性能监控
|
| 225 |
+
self._operation_count += 1
|
| 226 |
+
operation_time = time.time() - start_time
|
| 227 |
+
self._operation_times.append(operation_time)
|
| 228 |
+
|
| 229 |
+
log.debug(f"Stored credential to unified cache: {filename} in {operation_time:.3f}s")
|
| 230 |
+
return success
|
| 231 |
+
|
| 232 |
+
except Exception as e:
|
| 233 |
+
operation_time = time.time() - start_time
|
| 234 |
+
log.error(f"Error storing credential {filename} in {operation_time:.3f}s: {e}")
|
| 235 |
+
return False
|
| 236 |
+
|
| 237 |
+
async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
|
| 238 |
+
"""从统一缓存获取凭证数据"""
|
| 239 |
+
self._ensure_initialized()
|
| 240 |
+
start_time = time.time()
|
| 241 |
+
|
| 242 |
+
try:
|
| 243 |
+
credential_entry = await self._credentials_cache_manager.get(filename)
|
| 244 |
+
|
| 245 |
+
# 性能监控
|
| 246 |
+
self._operation_count += 1
|
| 247 |
+
operation_time = time.time() - start_time
|
| 248 |
+
self._operation_times.append(operation_time)
|
| 249 |
+
|
| 250 |
+
if credential_entry and "credential" in credential_entry:
|
| 251 |
+
return credential_entry["credential"]
|
| 252 |
+
return None
|
| 253 |
+
|
| 254 |
+
except Exception as e:
|
| 255 |
+
operation_time = time.time() - start_time
|
| 256 |
+
log.error(f"Error retrieving credential {filename} in {operation_time:.3f}s: {e}")
|
| 257 |
+
return None
|
| 258 |
+
|
| 259 |
+
async def list_credentials(self) -> List[str]:
|
| 260 |
+
"""从统一缓存列出所有凭证文件名"""
|
| 261 |
+
self._ensure_initialized()
|
| 262 |
+
start_time = time.time()
|
| 263 |
+
|
| 264 |
+
try:
|
| 265 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 266 |
+
filenames = list(all_data.keys())
|
| 267 |
+
|
| 268 |
+
# 性能监控
|
| 269 |
+
self._operation_count += 1
|
| 270 |
+
operation_time = time.time() - start_time
|
| 271 |
+
self._operation_times.append(operation_time)
|
| 272 |
+
|
| 273 |
+
log.debug(f"Listed {len(filenames)} credentials from unified cache in {operation_time:.3f}s")
|
| 274 |
+
return filenames
|
| 275 |
+
|
| 276 |
+
except Exception as e:
|
| 277 |
+
operation_time = time.time() - start_time
|
| 278 |
+
log.error(f"Error listing credentials in {operation_time:.3f}s: {e}")
|
| 279 |
+
return []
|
| 280 |
+
|
| 281 |
+
async def delete_credential(self, filename: str) -> bool:
|
| 282 |
+
"""从统一缓存删除凭证及所有相关数据"""
|
| 283 |
+
self._ensure_initialized()
|
| 284 |
+
start_time = time.time()
|
| 285 |
+
|
| 286 |
+
try:
|
| 287 |
+
success = await self._credentials_cache_manager.delete(filename)
|
| 288 |
+
|
| 289 |
+
# 性能监控
|
| 290 |
+
self._operation_count += 1
|
| 291 |
+
operation_time = time.time() - start_time
|
| 292 |
+
self._operation_times.append(operation_time)
|
| 293 |
+
|
| 294 |
+
log.debug(f"Deleted credential from unified cache: {filename} in {operation_time:.3f}s")
|
| 295 |
+
return success
|
| 296 |
+
|
| 297 |
+
except Exception as e:
|
| 298 |
+
operation_time = time.time() - start_time
|
| 299 |
+
log.error(f"Error deleting credential {filename} in {operation_time:.3f}s: {e}")
|
| 300 |
+
return False
|
| 301 |
+
|
| 302 |
+
# ============ 状态管理 ============
|
| 303 |
+
|
| 304 |
+
async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
|
| 305 |
+
"""更新凭证状态(使用统一缓存)"""
|
| 306 |
+
self._ensure_initialized()
|
| 307 |
+
start_time = time.time()
|
| 308 |
+
|
| 309 |
+
try:
|
| 310 |
+
# 获取现有数据或创建新数据
|
| 311 |
+
existing_data = await self._credentials_cache_manager.get(filename, {})
|
| 312 |
+
|
| 313 |
+
if not existing_data:
|
| 314 |
+
existing_data = {
|
| 315 |
+
"credential": {},
|
| 316 |
+
"state": self._get_default_state(),
|
| 317 |
+
"stats": self._get_default_stats()
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
# 更新状态数据
|
| 321 |
+
existing_data["state"].update(state_updates)
|
| 322 |
+
|
| 323 |
+
success = await self._credentials_cache_manager.set(filename, existing_data)
|
| 324 |
+
|
| 325 |
+
# 性能监控
|
| 326 |
+
self._operation_count += 1
|
| 327 |
+
operation_time = time.time() - start_time
|
| 328 |
+
self._operation_times.append(operation_time)
|
| 329 |
+
|
| 330 |
+
log.debug(f"Updated credential state in unified cache: {filename} in {operation_time:.3f}s")
|
| 331 |
+
return success
|
| 332 |
+
|
| 333 |
+
except Exception as e:
|
| 334 |
+
operation_time = time.time() - start_time
|
| 335 |
+
log.error(f"Error updating credential state {filename} in {operation_time:.3f}s: {e}")
|
| 336 |
+
return False
|
| 337 |
+
|
| 338 |
+
async def get_credential_state(self, filename: str) -> Dict[str, Any]:
|
| 339 |
+
"""从统一缓存获取凭证状态"""
|
| 340 |
+
self._ensure_initialized()
|
| 341 |
+
start_time = time.time()
|
| 342 |
+
|
| 343 |
+
try:
|
| 344 |
+
credential_entry = await self._credentials_cache_manager.get(filename)
|
| 345 |
+
|
| 346 |
+
# 性能监控
|
| 347 |
+
self._operation_count += 1
|
| 348 |
+
operation_time = time.time() - start_time
|
| 349 |
+
self._operation_times.append(operation_time)
|
| 350 |
+
|
| 351 |
+
if credential_entry and "state" in credential_entry:
|
| 352 |
+
log.debug(f"Retrieved credential state from unified cache: {filename} in {operation_time:.3f}s")
|
| 353 |
+
return credential_entry["state"]
|
| 354 |
+
else:
|
| 355 |
+
# 返回默认状态
|
| 356 |
+
return self._get_default_state()
|
| 357 |
+
|
| 358 |
+
except Exception as e:
|
| 359 |
+
operation_time = time.time() - start_time
|
| 360 |
+
log.error(f"Error getting credential state {filename} in {operation_time:.3f}s: {e}")
|
| 361 |
+
return self._get_default_state()
|
| 362 |
+
|
| 363 |
+
async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
|
| 364 |
+
"""从统一缓存获取所有凭证状态"""
|
| 365 |
+
self._ensure_initialized()
|
| 366 |
+
start_time = time.time()
|
| 367 |
+
|
| 368 |
+
try:
|
| 369 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 370 |
+
|
| 371 |
+
states = {}
|
| 372 |
+
for filename, cred_data in all_data.items():
|
| 373 |
+
states[filename] = cred_data.get("state", self._get_default_state())
|
| 374 |
+
|
| 375 |
+
# 性能监控
|
| 376 |
+
self._operation_count += 1
|
| 377 |
+
operation_time = time.time() - start_time
|
| 378 |
+
self._operation_times.append(operation_time)
|
| 379 |
+
|
| 380 |
+
log.debug(f"Retrieved all credential states from unified cache ({len(states)}) in {operation_time:.3f}s")
|
| 381 |
+
return states
|
| 382 |
+
|
| 383 |
+
except Exception as e:
|
| 384 |
+
operation_time = time.time() - start_time
|
| 385 |
+
log.error(f"Error getting all credential states in {operation_time:.3f}s: {e}")
|
| 386 |
+
return {}
|
| 387 |
+
|
| 388 |
+
# ============ 配置管理 ============
|
| 389 |
+
|
| 390 |
+
async def set_config(self, key: str, value: Any) -> bool:
|
| 391 |
+
"""设置配置到统一缓存"""
|
| 392 |
+
self._ensure_initialized()
|
| 393 |
+
return await self._config_cache_manager.set(key, value)
|
| 394 |
+
|
| 395 |
+
async def get_config(self, key: str, default: Any = None) -> Any:
|
| 396 |
+
"""从统一缓存获取配置"""
|
| 397 |
+
self._ensure_initialized()
|
| 398 |
+
return await self._config_cache_manager.get(key, default)
|
| 399 |
+
|
| 400 |
+
async def get_all_config(self) -> Dict[str, Any]:
|
| 401 |
+
"""从统一缓存获取所有配置"""
|
| 402 |
+
self._ensure_initialized()
|
| 403 |
+
return await self._config_cache_manager.get_all()
|
| 404 |
+
|
| 405 |
+
async def delete_config(self, key: str) -> bool:
|
| 406 |
+
"""从统一缓存删除配置"""
|
| 407 |
+
self._ensure_initialized()
|
| 408 |
+
return await self._config_cache_manager.delete(key)
|
| 409 |
+
|
| 410 |
+
# ============ 使用统计管理 ============
|
| 411 |
+
|
| 412 |
+
async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
|
| 413 |
+
"""更新使用统计(使用统一缓存)"""
|
| 414 |
+
self._ensure_initialized()
|
| 415 |
+
start_time = time.time()
|
| 416 |
+
|
| 417 |
+
try:
|
| 418 |
+
# 获取现有数据或创建新数据
|
| 419 |
+
existing_data = await self._credentials_cache_manager.get(filename, {})
|
| 420 |
+
|
| 421 |
+
if not existing_data:
|
| 422 |
+
existing_data = {
|
| 423 |
+
"credential": {},
|
| 424 |
+
"state": self._get_default_state(),
|
| 425 |
+
"stats": self._get_default_stats()
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
# 更新统计数据
|
| 429 |
+
existing_data["stats"].update(stats_updates)
|
| 430 |
+
|
| 431 |
+
success = await self._credentials_cache_manager.set(filename, existing_data)
|
| 432 |
+
|
| 433 |
+
# 性能监控
|
| 434 |
+
self._operation_count += 1
|
| 435 |
+
operation_time = time.time() - start_time
|
| 436 |
+
self._operation_times.append(operation_time)
|
| 437 |
+
|
| 438 |
+
log.debug(f"Updated usage stats in unified cache: {filename} in {operation_time:.3f}s")
|
| 439 |
+
return success
|
| 440 |
+
|
| 441 |
+
except Exception as e:
|
| 442 |
+
operation_time = time.time() - start_time
|
| 443 |
+
log.error(f"Error updating usage stats {filename} in {operation_time:.3f}s: {e}")
|
| 444 |
+
return False
|
| 445 |
+
|
| 446 |
+
async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
|
| 447 |
+
"""从统一缓存获取使用统计"""
|
| 448 |
+
self._ensure_initialized()
|
| 449 |
+
start_time = time.time()
|
| 450 |
+
|
| 451 |
+
try:
|
| 452 |
+
credential_entry = await self._credentials_cache_manager.get(filename)
|
| 453 |
+
|
| 454 |
+
# 性能监控
|
| 455 |
+
self._operation_count += 1
|
| 456 |
+
operation_time = time.time() - start_time
|
| 457 |
+
self._operation_times.append(operation_time)
|
| 458 |
+
|
| 459 |
+
if credential_entry and "stats" in credential_entry:
|
| 460 |
+
log.debug(f"Retrieved usage stats from unified cache: {filename} in {operation_time:.3f}s")
|
| 461 |
+
return credential_entry["stats"]
|
| 462 |
+
else:
|
| 463 |
+
return self._get_default_stats()
|
| 464 |
+
|
| 465 |
+
except Exception as e:
|
| 466 |
+
operation_time = time.time() - start_time
|
| 467 |
+
log.error(f"Error getting usage stats {filename} in {operation_time:.3f}s: {e}")
|
| 468 |
+
return self._get_default_stats()
|
| 469 |
+
|
| 470 |
+
async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
|
| 471 |
+
"""从统一缓存获取所有使用统计"""
|
| 472 |
+
self._ensure_initialized()
|
| 473 |
+
start_time = time.time()
|
| 474 |
+
|
| 475 |
+
try:
|
| 476 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 477 |
+
|
| 478 |
+
stats = {}
|
| 479 |
+
for filename, cred_data in all_data.items():
|
| 480 |
+
if "stats" in cred_data:
|
| 481 |
+
stats[filename] = cred_data["stats"]
|
| 482 |
+
|
| 483 |
+
# 性能监控
|
| 484 |
+
self._operation_count += 1
|
| 485 |
+
operation_time = time.time() - start_time
|
| 486 |
+
self._operation_times.append(operation_time)
|
| 487 |
+
|
| 488 |
+
log.debug(f"Retrieved all usage stats from unified cache ({len(stats)}) in {operation_time:.3f}s")
|
| 489 |
+
return stats
|
| 490 |
+
|
| 491 |
+
except Exception as e:
|
| 492 |
+
operation_time = time.time() - start_time
|
| 493 |
+
log.error(f"Error getting all usage stats in {operation_time:.3f}s: {e}")
|
| 494 |
+
return {}
|
src/storage/postgres_manager.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Postgres数据库管理器,采用单行设计并兼容 UnifiedCacheManager。
|
| 3 |
+
实现与 mongodb_manager.py 风格一致的接口(异步)。
|
| 4 |
+
需要环境变量: POSTGRES_DSN (例如: postgresql://user:pass@host:port/dbname)
|
| 5 |
+
"""
|
| 6 |
+
import asyncio
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import json
|
| 10 |
+
from datetime import datetime, timezone
|
| 11 |
+
from typing import Dict, Any, List, Optional
|
| 12 |
+
from collections import deque
|
| 13 |
+
|
| 14 |
+
import asyncpg
|
| 15 |
+
from log import log
|
| 16 |
+
from .cache_manager import UnifiedCacheManager, CacheBackend
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class PostgresCacheBackend(CacheBackend):
|
| 20 |
+
"""Postgres缓存后端,数据存储为key, data(JSONB), updated_at
|
| 21 |
+
单行/单表设计:表名由管理器指定,每行以key区分。
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, conn_pool, table_name: str, row_key: str):
|
| 25 |
+
self._pool = conn_pool
|
| 26 |
+
self._table_name = table_name
|
| 27 |
+
self._row_key = row_key
|
| 28 |
+
|
| 29 |
+
async def load_data(self) -> Dict[str, Any]:
|
| 30 |
+
try:
|
| 31 |
+
async with self._pool.acquire() as conn:
|
| 32 |
+
row = await conn.fetchrow(
|
| 33 |
+
f"SELECT data FROM {self._table_name} WHERE key = $1",
|
| 34 |
+
self._row_key
|
| 35 |
+
)
|
| 36 |
+
if row and row.get('data') is not None:
|
| 37 |
+
data = row['data']
|
| 38 |
+
# JSONB字段返回JSON字符串,需要解析为字典
|
| 39 |
+
if isinstance(data, str):
|
| 40 |
+
return json.loads(data)
|
| 41 |
+
elif isinstance(data, dict):
|
| 42 |
+
return data
|
| 43 |
+
else:
|
| 44 |
+
log.warning(f"Unexpected data type from JSONB field: {type(data)}")
|
| 45 |
+
return {}
|
| 46 |
+
return {}
|
| 47 |
+
except Exception as e:
|
| 48 |
+
log.error(f"Error loading data from Postgres row {self._row_key}: {e}")
|
| 49 |
+
return {}
|
| 50 |
+
|
| 51 |
+
async def write_data(self, data: Dict[str, Any]) -> bool:
|
| 52 |
+
try:
|
| 53 |
+
async with self._pool.acquire() as conn:
|
| 54 |
+
await conn.execute(
|
| 55 |
+
f"INSERT INTO {self._table_name}(key, data, updated_at) VALUES($1, $2::jsonb, $3)"
|
| 56 |
+
" ON CONFLICT (key) DO UPDATE SET data = EXCLUDED.data, updated_at = EXCLUDED.updated_at",
|
| 57 |
+
self._row_key, json.dumps(data, default=str), datetime.now(timezone.utc)
|
| 58 |
+
)
|
| 59 |
+
return True
|
| 60 |
+
except Exception as e:
|
| 61 |
+
log.error(f"Error writing data to Postgres row {self._row_key}: {e}")
|
| 62 |
+
return False
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class PostgresManager:
|
| 66 |
+
"""Postgres管理器。
|
| 67 |
+
使用单表单行设计存储凭证和配置数据。
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__(self):
|
| 71 |
+
self._pool: Optional[asyncpg.pool.Pool] = None
|
| 72 |
+
self._initialized = False
|
| 73 |
+
self._lock = asyncio.Lock()
|
| 74 |
+
|
| 75 |
+
self._dsn = None
|
| 76 |
+
self._table_name = 'unified_storage'
|
| 77 |
+
|
| 78 |
+
self._operation_count = 0
|
| 79 |
+
|
| 80 |
+
self._operation_times = deque(maxlen=5000)
|
| 81 |
+
|
| 82 |
+
self._credentials_cache_manager: Optional[UnifiedCacheManager] = None
|
| 83 |
+
self._config_cache_manager: Optional[UnifiedCacheManager] = None
|
| 84 |
+
|
| 85 |
+
self._credentials_row_key = 'all_credentials'
|
| 86 |
+
self._config_row_key = 'config_data'
|
| 87 |
+
|
| 88 |
+
self._write_delay = 1.0
|
| 89 |
+
self._cache_ttl = 300
|
| 90 |
+
|
| 91 |
+
async def initialize(self):
|
| 92 |
+
async with self._lock:
|
| 93 |
+
if self._initialized:
|
| 94 |
+
return
|
| 95 |
+
try:
|
| 96 |
+
self._dsn = os.getenv('POSTGRES_DSN')
|
| 97 |
+
if not self._dsn:
|
| 98 |
+
raise ValueError('POSTGRES_DSN environment variable is required')
|
| 99 |
+
|
| 100 |
+
self._pool = await asyncpg.create_pool(dsn=self._dsn, max_size=20, min_size=1)
|
| 101 |
+
|
| 102 |
+
# 确保表存在
|
| 103 |
+
await self._ensure_table()
|
| 104 |
+
|
| 105 |
+
# 创建缓存管理器后端
|
| 106 |
+
credentials_backend = PostgresCacheBackend(self._pool, self._table_name, self._credentials_row_key)
|
| 107 |
+
config_backend = PostgresCacheBackend(self._pool, self._table_name, self._config_row_key)
|
| 108 |
+
|
| 109 |
+
self._credentials_cache_manager = UnifiedCacheManager(
|
| 110 |
+
credentials_backend, cache_ttl=self._cache_ttl, write_delay=self._write_delay, name='credentials'
|
| 111 |
+
)
|
| 112 |
+
self._config_cache_manager = UnifiedCacheManager(
|
| 113 |
+
config_backend, cache_ttl=self._cache_ttl, write_delay=self._write_delay, name='config'
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
await self._credentials_cache_manager.start()
|
| 117 |
+
await self._config_cache_manager.start()
|
| 118 |
+
|
| 119 |
+
self._initialized = True
|
| 120 |
+
log.info('Postgres connection established with unified cache')
|
| 121 |
+
except Exception as e:
|
| 122 |
+
log.error(f'Error initializing Postgres: {e}')
|
| 123 |
+
raise
|
| 124 |
+
|
| 125 |
+
async def _ensure_table(self):
|
| 126 |
+
try:
|
| 127 |
+
async with self._pool.acquire() as conn:
|
| 128 |
+
await conn.execute(
|
| 129 |
+
f"CREATE TABLE IF NOT EXISTS {self._table_name}(\n key TEXT PRIMARY KEY,\n data JSONB,\n updated_at TIMESTAMPTZ\n )"
|
| 130 |
+
)
|
| 131 |
+
except Exception as e:
|
| 132 |
+
log.error(f'Error ensuring Postgres table: {e}')
|
| 133 |
+
raise
|
| 134 |
+
|
| 135 |
+
async def close(self):
|
| 136 |
+
if self._credentials_cache_manager:
|
| 137 |
+
await self._credentials_cache_manager.stop()
|
| 138 |
+
if self._config_cache_manager:
|
| 139 |
+
await self._config_cache_manager.stop()
|
| 140 |
+
if self._pool:
|
| 141 |
+
await self._pool.close()
|
| 142 |
+
self._initialized = False
|
| 143 |
+
log.info('Postgres connection closed with unified cache flushed')
|
| 144 |
+
|
| 145 |
+
def _ensure_initialized(self):
|
| 146 |
+
if not self._initialized:
|
| 147 |
+
raise RuntimeError('Postgres manager not initialized')
|
| 148 |
+
|
| 149 |
+
def _get_default_state(self) -> Dict[str, Any]:
|
| 150 |
+
return {
|
| 151 |
+
'error_codes': [],
|
| 152 |
+
'disabled': False,
|
| 153 |
+
'last_success': time.time(),
|
| 154 |
+
'user_email': None,
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
def _get_default_stats(self) -> Dict[str, Any]:
|
| 158 |
+
return {
|
| 159 |
+
'gemini_2_5_pro_calls': 0,
|
| 160 |
+
'total_calls': 0,
|
| 161 |
+
'next_reset_time': None,
|
| 162 |
+
'daily_limit_gemini_2_5_pro': 100,
|
| 163 |
+
'daily_limit_total': 1000
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
# 以下方法委托给 UnifiedCacheManager
|
| 167 |
+
async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
|
| 168 |
+
self._ensure_initialized()
|
| 169 |
+
start_time = time.time()
|
| 170 |
+
try:
|
| 171 |
+
existing_data = await self._credentials_cache_manager.get(filename, {})
|
| 172 |
+
credential_entry = {
|
| 173 |
+
'credential': credential_data,
|
| 174 |
+
'state': existing_data.get('state', self._get_default_state()),
|
| 175 |
+
'stats': existing_data.get('stats', self._get_default_stats())
|
| 176 |
+
}
|
| 177 |
+
success = await self._credentials_cache_manager.set(filename, credential_entry)
|
| 178 |
+
self._operation_count += 1
|
| 179 |
+
self._operation_times.append(time.time() - start_time)
|
| 180 |
+
log.debug(f'Stored credential to unified cache (postgres): {filename}')
|
| 181 |
+
return success
|
| 182 |
+
except Exception as e:
|
| 183 |
+
log.error(f'Error storing credential {filename} in Postgres: {e}')
|
| 184 |
+
return False
|
| 185 |
+
|
| 186 |
+
async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
|
| 187 |
+
self._ensure_initialized()
|
| 188 |
+
try:
|
| 189 |
+
credential_entry = await self._credentials_cache_manager.get(filename)
|
| 190 |
+
self._operation_count += 1
|
| 191 |
+
if credential_entry and 'credential' in credential_entry:
|
| 192 |
+
return credential_entry['credential']
|
| 193 |
+
return None
|
| 194 |
+
except Exception as e:
|
| 195 |
+
log.error(f'Error retrieving credential {filename} from Postgres: {e}')
|
| 196 |
+
return None
|
| 197 |
+
|
| 198 |
+
async def list_credentials(self) -> List[str]:
|
| 199 |
+
self._ensure_initialized()
|
| 200 |
+
try:
|
| 201 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 202 |
+
return list(all_data.keys())
|
| 203 |
+
except Exception as e:
|
| 204 |
+
log.error(f'Error listing credentials from Postgres: {e}')
|
| 205 |
+
return []
|
| 206 |
+
|
| 207 |
+
async def delete_credential(self, filename: str) -> bool:
|
| 208 |
+
self._ensure_initialized()
|
| 209 |
+
try:
|
| 210 |
+
return await self._credentials_cache_manager.delete(filename)
|
| 211 |
+
except Exception as e:
|
| 212 |
+
log.error(f'Error deleting credential {filename} from Postgres: {e}')
|
| 213 |
+
return False
|
| 214 |
+
|
| 215 |
+
async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
|
| 216 |
+
self._ensure_initialized()
|
| 217 |
+
try:
|
| 218 |
+
existing_data = await self._credentials_cache_manager.get(filename, {})
|
| 219 |
+
if not existing_data:
|
| 220 |
+
existing_data = {'credential': {}, 'state': self._get_default_state(), 'stats': self._get_default_stats()}
|
| 221 |
+
existing_data['state'].update(state_updates)
|
| 222 |
+
return await self._credentials_cache_manager.set(filename, existing_data)
|
| 223 |
+
except Exception as e:
|
| 224 |
+
log.error(f'Error updating credential state {filename} in Postgres: {e}')
|
| 225 |
+
return False
|
| 226 |
+
|
| 227 |
+
async def get_credential_state(self, filename: str) -> Dict[str, Any]:
|
| 228 |
+
self._ensure_initialized()
|
| 229 |
+
try:
|
| 230 |
+
credential_entry = await self._credentials_cache_manager.get(filename)
|
| 231 |
+
if credential_entry and 'state' in credential_entry:
|
| 232 |
+
return credential_entry['state']
|
| 233 |
+
return self._get_default_state()
|
| 234 |
+
except Exception as e:
|
| 235 |
+
log.error(f'Error getting credential state {filename} from Postgres: {e}')
|
| 236 |
+
return self._get_default_state()
|
| 237 |
+
|
| 238 |
+
async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
|
| 239 |
+
self._ensure_initialized()
|
| 240 |
+
try:
|
| 241 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 242 |
+
states = {fn: data.get('state', self._get_default_state()) for fn, data in all_data.items()}
|
| 243 |
+
return states
|
| 244 |
+
except Exception as e:
|
| 245 |
+
log.error(f'Error getting all credential states from Postgres: {e}')
|
| 246 |
+
return {}
|
| 247 |
+
|
| 248 |
+
async def set_config(self, key: str, value: Any) -> bool:
|
| 249 |
+
self._ensure_initialized()
|
| 250 |
+
return await self._config_cache_manager.set(key, value)
|
| 251 |
+
|
| 252 |
+
async def get_config(self, key: str, default: Any = None) -> Any:
|
| 253 |
+
self._ensure_initialized()
|
| 254 |
+
return await self._config_cache_manager.get(key, default)
|
| 255 |
+
|
| 256 |
+
async def get_all_config(self) -> Dict[str, Any]:
|
| 257 |
+
self._ensure_initialized()
|
| 258 |
+
return await self._config_cache_manager.get_all()
|
| 259 |
+
|
| 260 |
+
async def delete_config(self, key: str) -> bool:
|
| 261 |
+
self._ensure_initialized()
|
| 262 |
+
return await self._config_cache_manager.delete(key)
|
| 263 |
+
|
| 264 |
+
async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
|
| 265 |
+
self._ensure_initialized()
|
| 266 |
+
try:
|
| 267 |
+
existing_data = await self._credentials_cache_manager.get(filename, {})
|
| 268 |
+
if not existing_data:
|
| 269 |
+
existing_data = {'credential': {}, 'state': self._get_default_state(), 'stats': self._get_default_stats()}
|
| 270 |
+
existing_data['stats'].update(stats_updates)
|
| 271 |
+
return await self._credentials_cache_manager.set(filename, existing_data)
|
| 272 |
+
except Exception as e:
|
| 273 |
+
log.error(f'Error updating usage stats for {filename} in Postgres: {e}')
|
| 274 |
+
return False
|
| 275 |
+
|
| 276 |
+
async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
|
| 277 |
+
self._ensure_initialized()
|
| 278 |
+
try:
|
| 279 |
+
credential_entry = await self._credentials_cache_manager.get(filename)
|
| 280 |
+
if credential_entry and 'stats' in credential_entry:
|
| 281 |
+
return credential_entry['stats']
|
| 282 |
+
return self._get_default_stats()
|
| 283 |
+
except Exception as e:
|
| 284 |
+
log.error(f'Error getting usage stats for {filename} from Postgres: {e}')
|
| 285 |
+
return self._get_default_stats()
|
| 286 |
+
|
| 287 |
+
async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
|
| 288 |
+
self._ensure_initialized()
|
| 289 |
+
try:
|
| 290 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 291 |
+
stats = {fn: data.get('stats', self._get_default_stats()) for fn, data in all_data.items()}
|
| 292 |
+
return stats
|
| 293 |
+
except Exception as e:
|
| 294 |
+
log.error(f'Error getting all usage stats from Postgres: {e}')
|
| 295 |
+
return {}
|
| 296 |
+
|
src/storage/redis_manager.py
ADDED
|
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Redis数据库管理器,使用哈希表设计和统一缓存。
|
| 3 |
+
所有凭证数据存储在一个哈希表中,配置数据存储在另一个哈希表中。
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
from typing import Dict, Any, List, Optional
|
| 10 |
+
from collections import deque
|
| 11 |
+
|
| 12 |
+
import redis.asyncio as redis
|
| 13 |
+
from log import log
|
| 14 |
+
from .cache_manager import UnifiedCacheManager, CacheBackend
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class RedisCacheBackend(CacheBackend):
|
| 18 |
+
"""Redis缓存后端实现"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, redis_client: redis.Redis, hash_name: str):
|
| 21 |
+
self._client = redis_client
|
| 22 |
+
self._hash_name = hash_name
|
| 23 |
+
|
| 24 |
+
async def load_data(self) -> Dict[str, Any]:
|
| 25 |
+
"""从Redis哈希表加载数据"""
|
| 26 |
+
try:
|
| 27 |
+
hash_data = await self._client.hgetall(self._hash_name)
|
| 28 |
+
if not hash_data:
|
| 29 |
+
return {}
|
| 30 |
+
|
| 31 |
+
result = {}
|
| 32 |
+
for key, value_str in hash_data.items():
|
| 33 |
+
try:
|
| 34 |
+
result[key] = json.loads(value_str)
|
| 35 |
+
except json.JSONDecodeError as e:
|
| 36 |
+
log.error(f"Error deserializing Redis data for key {key}: {e}")
|
| 37 |
+
continue
|
| 38 |
+
return result
|
| 39 |
+
except Exception as e:
|
| 40 |
+
log.error(f"Error loading data from Redis hash {self._hash_name}: {e}")
|
| 41 |
+
return {}
|
| 42 |
+
|
| 43 |
+
async def write_data(self, data: Dict[str, Any]) -> bool:
|
| 44 |
+
"""将数据写入Redis哈希表"""
|
| 45 |
+
try:
|
| 46 |
+
if not data:
|
| 47 |
+
await self._client.delete(self._hash_name)
|
| 48 |
+
return True
|
| 49 |
+
|
| 50 |
+
hash_data = {}
|
| 51 |
+
for key, value in data.items():
|
| 52 |
+
try:
|
| 53 |
+
hash_data[key] = json.dumps(value, ensure_ascii=False)
|
| 54 |
+
except (TypeError, ValueError) as e:
|
| 55 |
+
log.error(f"Error serializing data for key {key}: {e}")
|
| 56 |
+
continue
|
| 57 |
+
|
| 58 |
+
if not hash_data:
|
| 59 |
+
return True
|
| 60 |
+
|
| 61 |
+
pipe = self._client.pipeline()
|
| 62 |
+
pipe.delete(self._hash_name)
|
| 63 |
+
pipe.hset(self._hash_name, mapping=hash_data)
|
| 64 |
+
await pipe.execute()
|
| 65 |
+
return True
|
| 66 |
+
except Exception as e:
|
| 67 |
+
log.error(f"Error writing data to Redis hash {self._hash_name}: {e}")
|
| 68 |
+
return False
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class RedisManager:
|
| 72 |
+
"""Redis数据库管理器"""
|
| 73 |
+
|
| 74 |
+
def __init__(self):
|
| 75 |
+
self._client: Optional[redis.Redis] = None
|
| 76 |
+
self._initialized = False
|
| 77 |
+
self._lock = asyncio.Lock()
|
| 78 |
+
|
| 79 |
+
# 配置
|
| 80 |
+
self._connection_uri = None
|
| 81 |
+
self._database_index = 0
|
| 82 |
+
|
| 83 |
+
# 哈希表设计 - 所有凭证存在一个哈希表中
|
| 84 |
+
self._credentials_hash_name = "gcli2api:credentials"
|
| 85 |
+
self._config_hash_name = "gcli2api:config"
|
| 86 |
+
|
| 87 |
+
# 性能监控
|
| 88 |
+
self._operation_count = 0
|
| 89 |
+
self._operation_times = deque(maxlen=5000)
|
| 90 |
+
|
| 91 |
+
# 统一缓存管理器
|
| 92 |
+
self._credentials_cache_manager: Optional[UnifiedCacheManager] = None
|
| 93 |
+
self._config_cache_manager: Optional[UnifiedCacheManager] = None
|
| 94 |
+
|
| 95 |
+
# 写入配置参数
|
| 96 |
+
self._write_delay = 1.0 # 写入延迟(秒)
|
| 97 |
+
self._cache_ttl = 300 # 缓存TTL(秒)
|
| 98 |
+
|
| 99 |
+
async def initialize(self):
|
| 100 |
+
"""初始化Redis连接"""
|
| 101 |
+
async with self._lock:
|
| 102 |
+
if self._initialized:
|
| 103 |
+
return
|
| 104 |
+
|
| 105 |
+
try:
|
| 106 |
+
# 获取连接配置
|
| 107 |
+
self._connection_uri = os.getenv("REDIS_URI", "redis://localhost:6379")
|
| 108 |
+
self._database_index = int(os.getenv("REDIS_DATABASE", "0"))
|
| 109 |
+
|
| 110 |
+
# 建立连接 - 使用最简配置确保兼容性
|
| 111 |
+
# 检查是否需要 SSL
|
| 112 |
+
if self._connection_uri.startswith("rediss://"):
|
| 113 |
+
# SSL 连接
|
| 114 |
+
self._client = redis.from_url(
|
| 115 |
+
self._connection_uri,
|
| 116 |
+
db=self._database_index,
|
| 117 |
+
decode_responses=True,
|
| 118 |
+
ssl_cert_reqs=None,
|
| 119 |
+
ssl_check_hostname=False,
|
| 120 |
+
ssl_ca_certs=None
|
| 121 |
+
)
|
| 122 |
+
else:
|
| 123 |
+
# 普通连接
|
| 124 |
+
self._client = redis.from_url(
|
| 125 |
+
self._connection_uri,
|
| 126 |
+
db=self._database_index,
|
| 127 |
+
decode_responses=True
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# 验证连接
|
| 131 |
+
await self._client.ping()
|
| 132 |
+
|
| 133 |
+
# 创建缓存管理器
|
| 134 |
+
credentials_backend = RedisCacheBackend(self._client, self._credentials_hash_name)
|
| 135 |
+
config_backend = RedisCacheBackend(self._client, self._config_hash_name)
|
| 136 |
+
|
| 137 |
+
self._credentials_cache_manager = UnifiedCacheManager(
|
| 138 |
+
credentials_backend,
|
| 139 |
+
cache_ttl=self._cache_ttl,
|
| 140 |
+
write_delay=self._write_delay,
|
| 141 |
+
name="credentials"
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
self._config_cache_manager = UnifiedCacheManager(
|
| 145 |
+
config_backend,
|
| 146 |
+
cache_ttl=self._cache_ttl,
|
| 147 |
+
write_delay=self._write_delay,
|
| 148 |
+
name="config"
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# 启动缓存管理器
|
| 152 |
+
await self._credentials_cache_manager.start()
|
| 153 |
+
await self._config_cache_manager.start()
|
| 154 |
+
|
| 155 |
+
self._initialized = True
|
| 156 |
+
log.info(f"Redis connection established to database {self._database_index} with unified cache")
|
| 157 |
+
|
| 158 |
+
except Exception as e:
|
| 159 |
+
log.error(f"Error initializing Redis: {e}")
|
| 160 |
+
raise
|
| 161 |
+
|
| 162 |
+
async def close(self):
|
| 163 |
+
"""关闭Redis连接"""
|
| 164 |
+
# 停止缓存管理器
|
| 165 |
+
if self._credentials_cache_manager:
|
| 166 |
+
await self._credentials_cache_manager.stop()
|
| 167 |
+
if self._config_cache_manager:
|
| 168 |
+
await self._config_cache_manager.stop()
|
| 169 |
+
|
| 170 |
+
if self._client:
|
| 171 |
+
await self._client.close()
|
| 172 |
+
self._initialized = False
|
| 173 |
+
log.info("Redis connection closed with unified cache flushed")
|
| 174 |
+
|
| 175 |
+
def _ensure_initialized(self):
|
| 176 |
+
"""确保已初始化"""
|
| 177 |
+
if not self._initialized:
|
| 178 |
+
raise RuntimeError("Redis manager not initialized")
|
| 179 |
+
|
| 180 |
+
def _get_default_state(self) -> Dict[str, Any]:
|
| 181 |
+
"""获取默认状态数据"""
|
| 182 |
+
return {
|
| 183 |
+
"error_codes": [],
|
| 184 |
+
"disabled": False,
|
| 185 |
+
"last_success": time.time(),
|
| 186 |
+
"user_email": None,
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
def _get_default_stats(self) -> Dict[str, Any]:
|
| 190 |
+
"""获取默认统计数据"""
|
| 191 |
+
return {
|
| 192 |
+
"gemini_2_5_pro_calls": 0,
|
| 193 |
+
"total_calls": 0,
|
| 194 |
+
"next_reset_time": None,
|
| 195 |
+
"daily_limit_gemini_2_5_pro": 100,
|
| 196 |
+
"daily_limit_total": 1000
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
# ============ 凭证管理 ============
|
| 200 |
+
|
| 201 |
+
async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
|
| 202 |
+
"""存储凭证数据到统一缓存"""
|
| 203 |
+
self._ensure_initialized()
|
| 204 |
+
start_time = time.time()
|
| 205 |
+
|
| 206 |
+
try:
|
| 207 |
+
# 获取现有数据或创建新数据
|
| 208 |
+
existing_data = await self._credentials_cache_manager.get(filename, {})
|
| 209 |
+
|
| 210 |
+
credential_entry = {
|
| 211 |
+
"credential": credential_data,
|
| 212 |
+
"state": existing_data.get("state", self._get_default_state()),
|
| 213 |
+
"stats": existing_data.get("stats", self._get_default_stats())
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
success = await self._credentials_cache_manager.set(filename, credential_entry)
|
| 217 |
+
|
| 218 |
+
# 性能监控
|
| 219 |
+
self._operation_count += 1
|
| 220 |
+
operation_time = time.time() - start_time
|
| 221 |
+
self._operation_times.append(operation_time)
|
| 222 |
+
|
| 223 |
+
log.debug(f"Stored credential to unified cache: {filename} in {operation_time:.3f}s")
|
| 224 |
+
return success
|
| 225 |
+
|
| 226 |
+
except Exception as e:
|
| 227 |
+
operation_time = time.time() - start_time
|
| 228 |
+
log.error(f"Error storing credential {filename} in {operation_time:.3f}s: {e}")
|
| 229 |
+
return False
|
| 230 |
+
|
| 231 |
+
async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
|
| 232 |
+
"""从统一缓存获取凭证数据"""
|
| 233 |
+
self._ensure_initialized()
|
| 234 |
+
start_time = time.time()
|
| 235 |
+
|
| 236 |
+
try:
|
| 237 |
+
credential_entry = await self._credentials_cache_manager.get(filename)
|
| 238 |
+
|
| 239 |
+
# 性能监控
|
| 240 |
+
self._operation_count += 1
|
| 241 |
+
operation_time = time.time() - start_time
|
| 242 |
+
self._operation_times.append(operation_time)
|
| 243 |
+
|
| 244 |
+
if credential_entry and "credential" in credential_entry:
|
| 245 |
+
return credential_entry["credential"]
|
| 246 |
+
return None
|
| 247 |
+
|
| 248 |
+
except Exception as e:
|
| 249 |
+
operation_time = time.time() - start_time
|
| 250 |
+
log.error(f"Error retrieving credential {filename} in {operation_time:.3f}s: {e}")
|
| 251 |
+
return None
|
| 252 |
+
|
| 253 |
+
async def list_credentials(self) -> List[str]:
|
| 254 |
+
"""从统一缓存列出所有凭证文件名"""
|
| 255 |
+
self._ensure_initialized()
|
| 256 |
+
start_time = time.time()
|
| 257 |
+
|
| 258 |
+
try:
|
| 259 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 260 |
+
filenames = list(all_data.keys())
|
| 261 |
+
|
| 262 |
+
# 性能监控
|
| 263 |
+
self._operation_count += 1
|
| 264 |
+
operation_time = time.time() - start_time
|
| 265 |
+
self._operation_times.append(operation_time)
|
| 266 |
+
|
| 267 |
+
log.debug(f"Listed {len(filenames)} credentials from unified cache in {operation_time:.3f}s")
|
| 268 |
+
return filenames
|
| 269 |
+
|
| 270 |
+
except Exception as e:
|
| 271 |
+
operation_time = time.time() - start_time
|
| 272 |
+
log.error(f"Error listing credentials in {operation_time:.3f}s: {e}")
|
| 273 |
+
return []
|
| 274 |
+
|
| 275 |
+
async def delete_credential(self, filename: str) -> bool:
|
| 276 |
+
"""从统一缓存删除凭证及所有相关数据"""
|
| 277 |
+
self._ensure_initialized()
|
| 278 |
+
start_time = time.time()
|
| 279 |
+
|
| 280 |
+
try:
|
| 281 |
+
success = await self._credentials_cache_manager.delete(filename)
|
| 282 |
+
|
| 283 |
+
# 性能监控
|
| 284 |
+
self._operation_count += 1
|
| 285 |
+
operation_time = time.time() - start_time
|
| 286 |
+
self._operation_times.append(operation_time)
|
| 287 |
+
|
| 288 |
+
log.debug(f"Deleted credential from unified cache: {filename} in {operation_time:.3f}s")
|
| 289 |
+
return success
|
| 290 |
+
|
| 291 |
+
except Exception as e:
|
| 292 |
+
operation_time = time.time() - start_time
|
| 293 |
+
log.error(f"Error deleting credential {filename} in {operation_time:.3f}s: {e}")
|
| 294 |
+
return False
|
| 295 |
+
|
| 296 |
+
# ============ 状态管理 ============
|
| 297 |
+
|
| 298 |
+
async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
|
| 299 |
+
"""更新凭证状态(使用统一缓存)"""
|
| 300 |
+
self._ensure_initialized()
|
| 301 |
+
start_time = time.time()
|
| 302 |
+
|
| 303 |
+
try:
|
| 304 |
+
# 获取现有数据或创建新数据
|
| 305 |
+
existing_data = await self._credentials_cache_manager.get(filename, {})
|
| 306 |
+
|
| 307 |
+
if not existing_data:
|
| 308 |
+
existing_data = {
|
| 309 |
+
"credential": {},
|
| 310 |
+
"state": self._get_default_state(),
|
| 311 |
+
"stats": self._get_default_stats()
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
# 更新状态数据
|
| 315 |
+
existing_data["state"].update(state_updates)
|
| 316 |
+
|
| 317 |
+
success = await self._credentials_cache_manager.set(filename, existing_data)
|
| 318 |
+
|
| 319 |
+
# 性能监控
|
| 320 |
+
self._operation_count += 1
|
| 321 |
+
operation_time = time.time() - start_time
|
| 322 |
+
self._operation_times.append(operation_time)
|
| 323 |
+
|
| 324 |
+
log.debug(f"Updated credential state in unified cache: {filename} in {operation_time:.3f}s")
|
| 325 |
+
return success
|
| 326 |
+
|
| 327 |
+
except Exception as e:
|
| 328 |
+
operation_time = time.time() - start_time
|
| 329 |
+
log.error(f"Error updating credential state {filename} in {operation_time:.3f}s: {e}")
|
| 330 |
+
return False
|
| 331 |
+
|
| 332 |
+
async def get_credential_state(self, filename: str) -> Dict[str, Any]:
|
| 333 |
+
"""从统一缓存获取凭证状态"""
|
| 334 |
+
self._ensure_initialized()
|
| 335 |
+
start_time = time.time()
|
| 336 |
+
|
| 337 |
+
try:
|
| 338 |
+
credential_entry = await self._credentials_cache_manager.get(filename)
|
| 339 |
+
|
| 340 |
+
# 性能监控
|
| 341 |
+
self._operation_count += 1
|
| 342 |
+
operation_time = time.time() - start_time
|
| 343 |
+
self._operation_times.append(operation_time)
|
| 344 |
+
|
| 345 |
+
if credential_entry and "state" in credential_entry:
|
| 346 |
+
log.debug(f"Retrieved credential state from unified cache: {filename} in {operation_time:.3f}s")
|
| 347 |
+
return credential_entry["state"]
|
| 348 |
+
else:
|
| 349 |
+
# 返回默认状态
|
| 350 |
+
return self._get_default_state()
|
| 351 |
+
|
| 352 |
+
except Exception as e:
|
| 353 |
+
operation_time = time.time() - start_time
|
| 354 |
+
log.error(f"Error getting credential state {filename} in {operation_time:.3f}s: {e}")
|
| 355 |
+
return self._get_default_state()
|
| 356 |
+
|
| 357 |
+
async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
|
| 358 |
+
"""从统一缓存获取所有凭证状态"""
|
| 359 |
+
self._ensure_initialized()
|
| 360 |
+
start_time = time.time()
|
| 361 |
+
|
| 362 |
+
try:
|
| 363 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 364 |
+
|
| 365 |
+
states = {}
|
| 366 |
+
for filename, cred_data in all_data.items():
|
| 367 |
+
states[filename] = cred_data.get("state", self._get_default_state())
|
| 368 |
+
|
| 369 |
+
# 性能监控
|
| 370 |
+
self._operation_count += 1
|
| 371 |
+
operation_time = time.time() - start_time
|
| 372 |
+
self._operation_times.append(operation_time)
|
| 373 |
+
|
| 374 |
+
log.debug(f"Retrieved all credential states from unified cache ({len(states)}) in {operation_time:.3f}s")
|
| 375 |
+
return states
|
| 376 |
+
|
| 377 |
+
except Exception as e:
|
| 378 |
+
operation_time = time.time() - start_time
|
| 379 |
+
log.error(f"Error getting all credential states in {operation_time:.3f}s: {e}")
|
| 380 |
+
return {}
|
| 381 |
+
|
| 382 |
+
# ============ 配置管理 ============
|
| 383 |
+
|
| 384 |
+
async def set_config(self, key: str, value: Any) -> bool:
|
| 385 |
+
"""设置配置到统一缓存"""
|
| 386 |
+
self._ensure_initialized()
|
| 387 |
+
start_time = time.time()
|
| 388 |
+
|
| 389 |
+
try:
|
| 390 |
+
success = await self._config_cache_manager.set(key, value)
|
| 391 |
+
|
| 392 |
+
# ���能监控
|
| 393 |
+
self._operation_count += 1
|
| 394 |
+
operation_time = time.time() - start_time
|
| 395 |
+
self._operation_times.append(operation_time)
|
| 396 |
+
|
| 397 |
+
log.debug(f"Set config to unified cache: {key} in {operation_time:.3f}s")
|
| 398 |
+
return success
|
| 399 |
+
|
| 400 |
+
except Exception as e:
|
| 401 |
+
operation_time = time.time() - start_time
|
| 402 |
+
log.error(f"Error setting config {key} in {operation_time:.3f}s: {e}")
|
| 403 |
+
return False
|
| 404 |
+
|
| 405 |
+
async def get_config(self, key: str, default: Any = None) -> Any:
|
| 406 |
+
"""从统一缓存获取配置"""
|
| 407 |
+
self._ensure_initialized()
|
| 408 |
+
return await self._config_cache_manager.get(key, default)
|
| 409 |
+
|
| 410 |
+
async def get_all_config(self) -> Dict[str, Any]:
|
| 411 |
+
"""从统一缓存获取所有配置"""
|
| 412 |
+
self._ensure_initialized()
|
| 413 |
+
return await self._config_cache_manager.get_all()
|
| 414 |
+
|
| 415 |
+
async def delete_config(self, key: str) -> bool:
|
| 416 |
+
"""从统一缓存删除配置"""
|
| 417 |
+
self._ensure_initialized()
|
| 418 |
+
return await self._config_cache_manager.delete(key)
|
| 419 |
+
|
| 420 |
+
# ============ 使用统计管理 ============
|
| 421 |
+
|
| 422 |
+
async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
|
| 423 |
+
"""更新使用统计(使用统一缓存)"""
|
| 424 |
+
self._ensure_initialized()
|
| 425 |
+
start_time = time.time()
|
| 426 |
+
|
| 427 |
+
try:
|
| 428 |
+
# 获取现有数据或创建新数据
|
| 429 |
+
existing_data = await self._credentials_cache_manager.get(filename, {})
|
| 430 |
+
|
| 431 |
+
if not existing_data:
|
| 432 |
+
existing_data = {
|
| 433 |
+
"credential": {},
|
| 434 |
+
"state": self._get_default_state(),
|
| 435 |
+
"stats": self._get_default_stats()
|
| 436 |
+
}
|
| 437 |
+
|
| 438 |
+
# 更新统计数据
|
| 439 |
+
existing_data["stats"].update(stats_updates)
|
| 440 |
+
|
| 441 |
+
success = await self._credentials_cache_manager.set(filename, existing_data)
|
| 442 |
+
|
| 443 |
+
# 性能监控
|
| 444 |
+
self._operation_count += 1
|
| 445 |
+
operation_time = time.time() - start_time
|
| 446 |
+
self._operation_times.append(operation_time)
|
| 447 |
+
|
| 448 |
+
log.debug(f"Updated usage stats in unified cache: {filename} in {operation_time:.3f}s")
|
| 449 |
+
return success
|
| 450 |
+
|
| 451 |
+
except Exception as e:
|
| 452 |
+
operation_time = time.time() - start_time
|
| 453 |
+
log.error(f"Error updating usage stats {filename} in {operation_time:.3f}s: {e}")
|
| 454 |
+
return False
|
| 455 |
+
|
| 456 |
+
async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
|
| 457 |
+
"""从统一缓存获取使用统计"""
|
| 458 |
+
self._ensure_initialized()
|
| 459 |
+
start_time = time.time()
|
| 460 |
+
|
| 461 |
+
try:
|
| 462 |
+
credential_entry = await self._credentials_cache_manager.get(filename)
|
| 463 |
+
|
| 464 |
+
# 性能监控
|
| 465 |
+
self._operation_count += 1
|
| 466 |
+
operation_time = time.time() - start_time
|
| 467 |
+
self._operation_times.append(operation_time)
|
| 468 |
+
|
| 469 |
+
if credential_entry and "stats" in credential_entry:
|
| 470 |
+
log.debug(f"Retrieved usage stats from unified cache: {filename} in {operation_time:.3f}s")
|
| 471 |
+
return credential_entry["stats"]
|
| 472 |
+
else:
|
| 473 |
+
return self._get_default_stats()
|
| 474 |
+
|
| 475 |
+
except Exception as e:
|
| 476 |
+
operation_time = time.time() - start_time
|
| 477 |
+
log.error(f"Error getting usage stats {filename} in {operation_time:.3f}s: {e}")
|
| 478 |
+
return self._get_default_stats()
|
| 479 |
+
|
| 480 |
+
async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
|
| 481 |
+
"""从统一缓存获取所有使用统计"""
|
| 482 |
+
self._ensure_initialized()
|
| 483 |
+
start_time = time.time()
|
| 484 |
+
|
| 485 |
+
try:
|
| 486 |
+
all_data = await self._credentials_cache_manager.get_all()
|
| 487 |
+
|
| 488 |
+
stats = {}
|
| 489 |
+
for filename, cred_data in all_data.items():
|
| 490 |
+
if "stats" in cred_data:
|
| 491 |
+
stats[filename] = cred_data["stats"]
|
| 492 |
+
|
| 493 |
+
# 性能监控
|
| 494 |
+
self._operation_count += 1
|
| 495 |
+
operation_time = time.time() - start_time
|
| 496 |
+
self._operation_times.append(operation_time)
|
| 497 |
+
|
| 498 |
+
log.debug(f"Retrieved all usage stats from unified cache ({len(stats)}) in {operation_time:.3f}s")
|
| 499 |
+
return stats
|
| 500 |
+
|
| 501 |
+
except Exception as e:
|
| 502 |
+
operation_time = time.time() - start_time
|
| 503 |
+
log.error(f"Error getting all usage stats in {operation_time:.3f}s: {e}")
|
| 504 |
+
return {}
|
src/storage_adapter.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
存储适配器,提供统一的接口来处理Redis、MongoDB和本地文件存储。
|
| 3 |
+
根据配置自动选择存储后端,优先级:Redis > MongoDB > 本地文件。
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
from typing import Dict, Any, List, Optional, Protocol
|
| 9 |
+
|
| 10 |
+
from log import log
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class StorageBackend(Protocol):
|
| 14 |
+
"""存储后端协议"""
|
| 15 |
+
|
| 16 |
+
async def initialize(self) -> None:
|
| 17 |
+
"""初始化存储后端"""
|
| 18 |
+
...
|
| 19 |
+
|
| 20 |
+
async def close(self) -> None:
|
| 21 |
+
"""关闭存储后端"""
|
| 22 |
+
...
|
| 23 |
+
|
| 24 |
+
# 凭证管理
|
| 25 |
+
async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
|
| 26 |
+
"""存储凭证数据"""
|
| 27 |
+
...
|
| 28 |
+
|
| 29 |
+
async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
|
| 30 |
+
"""获取凭证数据"""
|
| 31 |
+
...
|
| 32 |
+
|
| 33 |
+
async def list_credentials(self) -> List[str]:
|
| 34 |
+
"""列出所有凭证文件名"""
|
| 35 |
+
...
|
| 36 |
+
|
| 37 |
+
async def delete_credential(self, filename: str) -> bool:
|
| 38 |
+
"""删除凭证"""
|
| 39 |
+
...
|
| 40 |
+
|
| 41 |
+
# 状态管理
|
| 42 |
+
async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
|
| 43 |
+
"""更新凭证状态"""
|
| 44 |
+
...
|
| 45 |
+
|
| 46 |
+
async def get_credential_state(self, filename: str) -> Dict[str, Any]:
|
| 47 |
+
"""获取凭证状态"""
|
| 48 |
+
...
|
| 49 |
+
|
| 50 |
+
async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
|
| 51 |
+
"""获取所有凭证状态"""
|
| 52 |
+
...
|
| 53 |
+
|
| 54 |
+
# 配置管理
|
| 55 |
+
async def set_config(self, key: str, value: Any) -> bool:
|
| 56 |
+
"""设置配置项"""
|
| 57 |
+
...
|
| 58 |
+
|
| 59 |
+
async def get_config(self, key: str, default: Any = None) -> Any:
|
| 60 |
+
"""获取配置项"""
|
| 61 |
+
...
|
| 62 |
+
|
| 63 |
+
async def get_all_config(self) -> Dict[str, Any]:
|
| 64 |
+
"""获取所有配置"""
|
| 65 |
+
...
|
| 66 |
+
|
| 67 |
+
async def delete_config(self, key: str) -> bool:
|
| 68 |
+
"""删除配置项"""
|
| 69 |
+
...
|
| 70 |
+
|
| 71 |
+
# 使用统计管理
|
| 72 |
+
async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
|
| 73 |
+
"""更新使用统计"""
|
| 74 |
+
...
|
| 75 |
+
|
| 76 |
+
async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
|
| 77 |
+
"""获取使用统计"""
|
| 78 |
+
...
|
| 79 |
+
|
| 80 |
+
async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
|
| 81 |
+
"""获取所有使用统计"""
|
| 82 |
+
...
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class StorageAdapter:
|
| 88 |
+
"""存储适配器,根据配置选择存储后端"""
|
| 89 |
+
|
| 90 |
+
def __init__(self):
|
| 91 |
+
self._backend: Optional["StorageBackend"] = None
|
| 92 |
+
self._initialized = False
|
| 93 |
+
self._lock = asyncio.Lock()
|
| 94 |
+
|
| 95 |
+
async def initialize(self) -> None:
|
| 96 |
+
"""初始化存储适配器"""
|
| 97 |
+
async with self._lock:
|
| 98 |
+
if self._initialized:
|
| 99 |
+
return
|
| 100 |
+
|
| 101 |
+
# 按优先级检查存储后端:Redis > MongoDB > 本地文件
|
| 102 |
+
redis_uri = os.getenv("REDIS_URI", "")
|
| 103 |
+
mongodb_uri = os.getenv("MONGODB_URI", "")
|
| 104 |
+
|
| 105 |
+
# 优先尝试Redis存储
|
| 106 |
+
if redis_uri:
|
| 107 |
+
try:
|
| 108 |
+
from .storage.redis_manager import RedisManager
|
| 109 |
+
self._backend = RedisManager()
|
| 110 |
+
await self._backend.initialize()
|
| 111 |
+
log.info("Using Redis storage backend")
|
| 112 |
+
except ImportError as e:
|
| 113 |
+
log.error(f"Failed to import Redis backend: {e}")
|
| 114 |
+
log.info("Falling back to next available storage backend")
|
| 115 |
+
except Exception as e:
|
| 116 |
+
log.error(f"Failed to initialize Redis backend: {e}")
|
| 117 |
+
log.info("Falling back to next available storage backend")
|
| 118 |
+
|
| 119 |
+
# 如果Redis不可用或未配置,接下来尝试Postgres(优先级低于Redis)
|
| 120 |
+
postgres_dsn = os.getenv("POSTGRES_DSN", "")
|
| 121 |
+
if not self._backend and postgres_dsn:
|
| 122 |
+
try:
|
| 123 |
+
from .storage.postgres_manager import PostgresManager
|
| 124 |
+
self._backend = PostgresManager()
|
| 125 |
+
await self._backend.initialize()
|
| 126 |
+
log.info("Using Postgres storage backend")
|
| 127 |
+
except ImportError as e:
|
| 128 |
+
log.error(f"Failed to import Postgres backend: {e}")
|
| 129 |
+
log.info("Falling back to next available storage backend")
|
| 130 |
+
except Exception as e:
|
| 131 |
+
log.error(f"Failed to initialize Postgres backend: {e}")
|
| 132 |
+
log.info("Falling back to next available storage backend")
|
| 133 |
+
|
| 134 |
+
# 如果Redis和Postgres不可用,尝试MongoDB存储
|
| 135 |
+
if not self._backend and mongodb_uri:
|
| 136 |
+
try:
|
| 137 |
+
from .storage.mongodb_manager import MongoDBManager
|
| 138 |
+
self._backend = MongoDBManager()
|
| 139 |
+
await self._backend.initialize()
|
| 140 |
+
log.info("Using MongoDB storage backend")
|
| 141 |
+
except ImportError as e:
|
| 142 |
+
log.error(f"Failed to import MongoDB backend: {e}")
|
| 143 |
+
log.info("Falling back to file storage backend")
|
| 144 |
+
except Exception as e:
|
| 145 |
+
log.error(f"Failed to initialize MongoDB backend: {e}")
|
| 146 |
+
log.info("Falling back to file storage backend")
|
| 147 |
+
|
| 148 |
+
# 如果Redis和MongoDB都不可用,使用文件存储
|
| 149 |
+
if not self._backend:
|
| 150 |
+
from .storage.file_storage_manager import FileStorageManager
|
| 151 |
+
self._backend = FileStorageManager()
|
| 152 |
+
await self._backend.initialize()
|
| 153 |
+
log.info("Using file storage backend")
|
| 154 |
+
|
| 155 |
+
self._initialized = True
|
| 156 |
+
|
| 157 |
+
async def close(self) -> None:
|
| 158 |
+
"""关闭存储适配器"""
|
| 159 |
+
if self._backend:
|
| 160 |
+
await self._backend.close()
|
| 161 |
+
self._backend = None
|
| 162 |
+
self._initialized = False
|
| 163 |
+
|
| 164 |
+
def _ensure_initialized(self):
|
| 165 |
+
"""确保存储适配器已初始化"""
|
| 166 |
+
if not self._initialized or not self._backend:
|
| 167 |
+
raise RuntimeError("Storage adapter not initialized")
|
| 168 |
+
|
| 169 |
+
# ============ 凭证管理 ============
|
| 170 |
+
|
| 171 |
+
async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
|
| 172 |
+
"""存储凭证数据"""
|
| 173 |
+
self._ensure_initialized()
|
| 174 |
+
return await self._backend.store_credential(filename, credential_data)
|
| 175 |
+
|
| 176 |
+
async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
|
| 177 |
+
"""获取凭证数据"""
|
| 178 |
+
self._ensure_initialized()
|
| 179 |
+
return await self._backend.get_credential(filename)
|
| 180 |
+
|
| 181 |
+
async def list_credentials(self) -> List[str]:
|
| 182 |
+
"""列出所有凭证文件名"""
|
| 183 |
+
self._ensure_initialized()
|
| 184 |
+
return await self._backend.list_credentials()
|
| 185 |
+
|
| 186 |
+
async def delete_credential(self, filename: str) -> bool:
|
| 187 |
+
"""删除凭证"""
|
| 188 |
+
self._ensure_initialized()
|
| 189 |
+
return await self._backend.delete_credential(filename)
|
| 190 |
+
|
| 191 |
+
# ============ 状态管理 ============
|
| 192 |
+
|
| 193 |
+
async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
|
| 194 |
+
"""更新凭证状态"""
|
| 195 |
+
self._ensure_initialized()
|
| 196 |
+
return await self._backend.update_credential_state(filename, state_updates)
|
| 197 |
+
|
| 198 |
+
async def get_credential_state(self, filename: str) -> Dict[str, Any]:
|
| 199 |
+
"""获取凭证状态"""
|
| 200 |
+
self._ensure_initialized()
|
| 201 |
+
return await self._backend.get_credential_state(filename)
|
| 202 |
+
|
| 203 |
+
async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
|
| 204 |
+
"""获取所有凭证状态"""
|
| 205 |
+
self._ensure_initialized()
|
| 206 |
+
return await self._backend.get_all_credential_states()
|
| 207 |
+
|
| 208 |
+
# ============ 配置管理 ============
|
| 209 |
+
|
| 210 |
+
async def set_config(self, key: str, value: Any) -> bool:
|
| 211 |
+
"""设置配置项"""
|
| 212 |
+
self._ensure_initialized()
|
| 213 |
+
return await self._backend.set_config(key, value)
|
| 214 |
+
|
| 215 |
+
async def get_config(self, key: str, default: Any = None) -> Any:
|
| 216 |
+
"""获取配置项"""
|
| 217 |
+
self._ensure_initialized()
|
| 218 |
+
return await self._backend.get_config(key, default)
|
| 219 |
+
|
| 220 |
+
async def get_all_config(self) -> Dict[str, Any]:
|
| 221 |
+
"""获取所有配置"""
|
| 222 |
+
self._ensure_initialized()
|
| 223 |
+
return await self._backend.get_all_config()
|
| 224 |
+
|
| 225 |
+
async def delete_config(self, key: str) -> bool:
|
| 226 |
+
"""删除配置项"""
|
| 227 |
+
self._ensure_initialized()
|
| 228 |
+
return await self._backend.delete_config(key)
|
| 229 |
+
|
| 230 |
+
# ============ 使用统计管理 ============
|
| 231 |
+
|
| 232 |
+
async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
|
| 233 |
+
"""更新使用统计"""
|
| 234 |
+
self._ensure_initialized()
|
| 235 |
+
return await self._backend.update_usage_stats(filename, stats_updates)
|
| 236 |
+
|
| 237 |
+
async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
|
| 238 |
+
"""获取使用统计"""
|
| 239 |
+
self._ensure_initialized()
|
| 240 |
+
return await self._backend.get_usage_stats(filename)
|
| 241 |
+
|
| 242 |
+
async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
|
| 243 |
+
"""获取所有使用统计"""
|
| 244 |
+
self._ensure_initialized()
|
| 245 |
+
return await self._backend.get_all_usage_stats()
|
| 246 |
+
|
| 247 |
+
# ============ 工具方法 ============
|
| 248 |
+
|
| 249 |
+
async def export_credential_to_json(self, filename: str, output_path: str = None) -> bool:
|
| 250 |
+
"""将凭证导出为JSON文件"""
|
| 251 |
+
self._ensure_initialized()
|
| 252 |
+
if hasattr(self._backend, 'export_credential_to_json'):
|
| 253 |
+
return await self._backend.export_credential_to_json(filename, output_path)
|
| 254 |
+
# MongoDB后端的fallback实现
|
| 255 |
+
credential_data = await self.get_credential(filename)
|
| 256 |
+
if credential_data is None:
|
| 257 |
+
return False
|
| 258 |
+
|
| 259 |
+
if output_path is None:
|
| 260 |
+
output_path = f"{filename}.json"
|
| 261 |
+
|
| 262 |
+
import aiofiles
|
| 263 |
+
try:
|
| 264 |
+
async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
|
| 265 |
+
await f.write(json.dumps(credential_data, indent=2, ensure_ascii=False))
|
| 266 |
+
return True
|
| 267 |
+
except Exception:
|
| 268 |
+
return False
|
| 269 |
+
|
| 270 |
+
async def import_credential_from_json(self, json_path: str, filename: str = None) -> bool:
|
| 271 |
+
"""从JSON文件导入凭证"""
|
| 272 |
+
self._ensure_initialized()
|
| 273 |
+
if hasattr(self._backend, 'import_credential_from_json'):
|
| 274 |
+
return await self._backend.import_credential_from_json(json_path, filename)
|
| 275 |
+
# MongoDB后端的fallback实现
|
| 276 |
+
try:
|
| 277 |
+
import aiofiles
|
| 278 |
+
async with aiofiles.open(json_path, "r", encoding="utf-8") as f:
|
| 279 |
+
content = await f.read()
|
| 280 |
+
|
| 281 |
+
credential_data = json.loads(content)
|
| 282 |
+
|
| 283 |
+
if filename is None:
|
| 284 |
+
filename = os.path.basename(json_path)
|
| 285 |
+
|
| 286 |
+
return await self.store_credential(filename, credential_data)
|
| 287 |
+
except Exception:
|
| 288 |
+
return False
|
| 289 |
+
|
| 290 |
+
def get_backend_type(self) -> str:
|
| 291 |
+
"""获取当前存储后端类型"""
|
| 292 |
+
if not self._backend:
|
| 293 |
+
return "none"
|
| 294 |
+
|
| 295 |
+
# 检查后端类型
|
| 296 |
+
backend_class_name = self._backend.__class__.__name__
|
| 297 |
+
if "File" in backend_class_name or "file" in backend_class_name.lower():
|
| 298 |
+
return "file"
|
| 299 |
+
elif "MongoDB" in backend_class_name or "mongo" in backend_class_name.lower():
|
| 300 |
+
return "mongodb"
|
| 301 |
+
elif "Redis" in backend_class_name or "redis" in backend_class_name.lower():
|
| 302 |
+
return "redis"
|
| 303 |
+
else:
|
| 304 |
+
return "unknown"
|
| 305 |
+
|
| 306 |
+
async def get_backend_info(self) -> Dict[str, Any]:
|
| 307 |
+
"""获取存储后端信息"""
|
| 308 |
+
self._ensure_initialized()
|
| 309 |
+
|
| 310 |
+
backend_type = self.get_backend_type()
|
| 311 |
+
info = {
|
| 312 |
+
"backend_type": backend_type,
|
| 313 |
+
"initialized": self._initialized
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
# 获取底层存储信息
|
| 317 |
+
if hasattr(self._backend, 'get_database_info'):
|
| 318 |
+
try:
|
| 319 |
+
db_info = await self._backend.get_database_info()
|
| 320 |
+
info.update(db_info)
|
| 321 |
+
except Exception as e:
|
| 322 |
+
info["database_error"] = str(e)
|
| 323 |
+
else:
|
| 324 |
+
backend_type = self.get_backend_type()
|
| 325 |
+
if backend_type == "file":
|
| 326 |
+
info.update({
|
| 327 |
+
"credentials_dir": getattr(self._backend, '_credentials_dir', None),
|
| 328 |
+
"state_file": getattr(self._backend, '_state_file', None),
|
| 329 |
+
"config_file": getattr(self._backend, '_config_file', None)
|
| 330 |
+
})
|
| 331 |
+
elif backend_type == "redis":
|
| 332 |
+
info.update({
|
| 333 |
+
"redis_url": getattr(self._backend, '_redis_url', None),
|
| 334 |
+
"connection_pool_size": getattr(self._backend, '_pool_size', None)
|
| 335 |
+
})
|
| 336 |
+
|
| 337 |
+
return info
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# 全局存储适配器实例
|
| 341 |
+
_storage_adapter: Optional[StorageAdapter] = None
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
async def get_storage_adapter() -> StorageAdapter:
|
| 345 |
+
"""获取全局存储适配器实例"""
|
| 346 |
+
global _storage_adapter
|
| 347 |
+
|
| 348 |
+
if _storage_adapter is None:
|
| 349 |
+
_storage_adapter = StorageAdapter()
|
| 350 |
+
await _storage_adapter.initialize()
|
| 351 |
+
|
| 352 |
+
return _storage_adapter
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
async def close_storage_adapter():
|
| 356 |
+
"""关闭全局存储适配器"""
|
| 357 |
+
global _storage_adapter
|
| 358 |
+
|
| 359 |
+
if _storage_adapter:
|
| 360 |
+
await _storage_adapter.close()
|
| 361 |
+
_storage_adapter = None
|
src/task_manager.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Global task lifecycle management module
|
| 3 |
+
管理应用程序中所有异步任务的生命周期,确保正确清理
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
import weakref
|
| 7 |
+
from typing import Set, Dict, Any
|
| 8 |
+
from log import log
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TaskManager:
|
| 12 |
+
"""全局异步任务管理器 - 单例模式"""
|
| 13 |
+
|
| 14 |
+
_instance = None
|
| 15 |
+
_lock = asyncio.Lock()
|
| 16 |
+
|
| 17 |
+
def __new__(cls):
|
| 18 |
+
if cls._instance is None:
|
| 19 |
+
cls._instance = super().__new__(cls)
|
| 20 |
+
cls._instance._initialized = False
|
| 21 |
+
return cls._instance
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
if self._initialized:
|
| 25 |
+
return
|
| 26 |
+
|
| 27 |
+
self._tasks: Set[asyncio.Task] = set()
|
| 28 |
+
self._resources: Set[Any] = set() # 需要关闭的资源
|
| 29 |
+
self._shutdown_event = asyncio.Event()
|
| 30 |
+
self._initialized = True
|
| 31 |
+
log.debug("TaskManager initialized")
|
| 32 |
+
|
| 33 |
+
def register_task(self, task: asyncio.Task, description: str = None) -> asyncio.Task:
|
| 34 |
+
"""注册一个任务供生命周期管理"""
|
| 35 |
+
self._tasks.add(task)
|
| 36 |
+
task.add_done_callback(lambda t: self._tasks.discard(t))
|
| 37 |
+
|
| 38 |
+
if description:
|
| 39 |
+
task.set_name(description)
|
| 40 |
+
|
| 41 |
+
log.debug(f"Registered task: {task.get_name() or 'unnamed'}")
|
| 42 |
+
return task
|
| 43 |
+
|
| 44 |
+
def create_task(self, coro, *, name: str = None) -> asyncio.Task:
|
| 45 |
+
"""创建并注册一个任务"""
|
| 46 |
+
task = asyncio.create_task(coro, name=name)
|
| 47 |
+
return self.register_task(task, name)
|
| 48 |
+
|
| 49 |
+
def register_resource(self, resource: Any) -> Any:
|
| 50 |
+
"""注册一个需要清理的资源(如HTTP客户端、文件句柄等)"""
|
| 51 |
+
# 使用弱引用避免循环引用
|
| 52 |
+
self._resources.add(weakref.ref(resource))
|
| 53 |
+
log.debug(f"Registered resource: {type(resource).__name__}")
|
| 54 |
+
return resource
|
| 55 |
+
|
| 56 |
+
async def shutdown(self, timeout: float = 30.0):
|
| 57 |
+
"""关闭所有任务和资源"""
|
| 58 |
+
log.info("TaskManager shutdown initiated")
|
| 59 |
+
|
| 60 |
+
# 设置关闭标志
|
| 61 |
+
self._shutdown_event.set()
|
| 62 |
+
|
| 63 |
+
# 取消所有未完成的任务
|
| 64 |
+
cancelled_count = 0
|
| 65 |
+
for task in list(self._tasks):
|
| 66 |
+
if not task.done():
|
| 67 |
+
task.cancel()
|
| 68 |
+
cancelled_count += 1
|
| 69 |
+
|
| 70 |
+
if cancelled_count > 0:
|
| 71 |
+
log.info(f"Cancelled {cancelled_count} pending tasks")
|
| 72 |
+
|
| 73 |
+
# 等待所有任务完成(包括取消)
|
| 74 |
+
if self._tasks:
|
| 75 |
+
try:
|
| 76 |
+
await asyncio.wait_for(
|
| 77 |
+
asyncio.gather(*self._tasks, return_exceptions=True),
|
| 78 |
+
timeout=timeout
|
| 79 |
+
)
|
| 80 |
+
except asyncio.TimeoutError:
|
| 81 |
+
log.warning(f"Some tasks did not complete within {timeout}s timeout")
|
| 82 |
+
|
| 83 |
+
# 清理资源
|
| 84 |
+
cleaned_resources = 0
|
| 85 |
+
for resource_ref in list(self._resources):
|
| 86 |
+
resource = resource_ref()
|
| 87 |
+
if resource is not None:
|
| 88 |
+
try:
|
| 89 |
+
if hasattr(resource, 'close'):
|
| 90 |
+
if asyncio.iscoroutinefunction(resource.close):
|
| 91 |
+
await resource.close()
|
| 92 |
+
else:
|
| 93 |
+
resource.close()
|
| 94 |
+
elif hasattr(resource, 'aclose'):
|
| 95 |
+
await resource.aclose()
|
| 96 |
+
cleaned_resources += 1
|
| 97 |
+
except Exception as e:
|
| 98 |
+
log.warning(f"Failed to close resource {type(resource).__name__}: {e}")
|
| 99 |
+
|
| 100 |
+
if cleaned_resources > 0:
|
| 101 |
+
log.info(f"Cleaned up {cleaned_resources} resources")
|
| 102 |
+
|
| 103 |
+
self._tasks.clear()
|
| 104 |
+
self._resources.clear()
|
| 105 |
+
log.info("TaskManager shutdown completed")
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def is_shutdown(self) -> bool:
|
| 109 |
+
"""检查是否已经开始关闭"""
|
| 110 |
+
return self._shutdown_event.is_set()
|
| 111 |
+
|
| 112 |
+
def get_stats(self) -> Dict[str, int]:
|
| 113 |
+
"""获取任务管理统计信息"""
|
| 114 |
+
return {
|
| 115 |
+
'active_tasks': len(self._tasks),
|
| 116 |
+
'registered_resources': len(self._resources),
|
| 117 |
+
'is_shutdown': self.is_shutdown
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# 全局任务管理器实例
|
| 122 |
+
task_manager = TaskManager()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def create_managed_task(coro, *, name: str = None) -> asyncio.Task:
|
| 126 |
+
"""创建一个被管理的异步任务的便捷函数"""
|
| 127 |
+
return task_manager.create_task(coro, name=name)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def register_resource(resource: Any) -> Any:
|
| 131 |
+
"""注册资源的便捷函数"""
|
| 132 |
+
return task_manager.register_resource(resource)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
async def shutdown_all_tasks(timeout: float = 30.0):
|
| 136 |
+
"""关闭所有任务的便捷函数"""
|
| 137 |
+
await task_manager.shutdown(timeout)
|
src/usage_stats.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage statistics module for tracking API calls per credential file.
|
| 3 |
+
Uses the simpler logic: compare current time with next_reset_time.
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime, timezone, timedelta
|
| 8 |
+
from threading import Lock
|
| 9 |
+
from typing import Dict, Any, Optional
|
| 10 |
+
|
| 11 |
+
from config import get_credentials_dir, is_mongodb_mode
|
| 12 |
+
from log import log
|
| 13 |
+
from .state_manager import get_state_manager
|
| 14 |
+
from .storage_adapter import get_storage_adapter
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _get_next_utc_7am() -> datetime:
|
| 18 |
+
"""
|
| 19 |
+
Calculate the next UTC 07:00 time for quota reset.
|
| 20 |
+
"""
|
| 21 |
+
now = datetime.now(timezone.utc)
|
| 22 |
+
today_7am = now.replace(hour=7, minute=0, second=0, microsecond=0)
|
| 23 |
+
|
| 24 |
+
if now < today_7am:
|
| 25 |
+
return today_7am
|
| 26 |
+
else:
|
| 27 |
+
return today_7am + timedelta(days=1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class UsageStats:
|
| 31 |
+
"""
|
| 32 |
+
Simplified usage statistics manager with clear reset logic.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self):
|
| 36 |
+
self._lock = Lock()
|
| 37 |
+
# 状态文件路径将在初始化时异步设置
|
| 38 |
+
self._state_file = None
|
| 39 |
+
self._state_manager = None
|
| 40 |
+
self._storage_adapter = None
|
| 41 |
+
self._stats_cache: Dict[str, Dict[str, Any]] = {}
|
| 42 |
+
self._initialized = False
|
| 43 |
+
self._cache_dirty = False # 缓存脏标记,减少不必要的写入
|
| 44 |
+
self._last_save_time = 0
|
| 45 |
+
self._save_interval = 60 # 最多每分钟保存一次,减少I/O
|
| 46 |
+
self._max_cache_size = 100 # 严格限制缓存大小
|
| 47 |
+
|
| 48 |
+
async def initialize(self):
|
| 49 |
+
"""Initialize the usage stats module."""
|
| 50 |
+
if self._initialized:
|
| 51 |
+
return
|
| 52 |
+
|
| 53 |
+
# 初始化存储适配器
|
| 54 |
+
self._storage_adapter = await get_storage_adapter()
|
| 55 |
+
|
| 56 |
+
# 只在文件模式下创建本地状态文件
|
| 57 |
+
if not await is_mongodb_mode():
|
| 58 |
+
credentials_dir = await get_credentials_dir()
|
| 59 |
+
self._state_file = os.path.join(credentials_dir, "creds_state.toml")
|
| 60 |
+
self._state_manager = get_state_manager(self._state_file)
|
| 61 |
+
|
| 62 |
+
await self._load_stats()
|
| 63 |
+
self._initialized = True
|
| 64 |
+
storage_type = "MongoDB" if await is_mongodb_mode() else "File"
|
| 65 |
+
log.debug(f"Usage statistics module initialized with {storage_type} storage backend")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _normalize_filename(self, filename: str) -> str:
|
| 69 |
+
"""Normalize filename to relative path for consistent storage."""
|
| 70 |
+
if not filename:
|
| 71 |
+
return ""
|
| 72 |
+
|
| 73 |
+
if os.path.sep not in filename and "/" not in filename:
|
| 74 |
+
return filename
|
| 75 |
+
|
| 76 |
+
return os.path.basename(filename)
|
| 77 |
+
|
| 78 |
+
def _is_gemini_2_5_pro(self, model_name: str) -> bool:
|
| 79 |
+
"""
|
| 80 |
+
Check if model is gemini-2.5-pro variant (including prefixes and suffixes).
|
| 81 |
+
"""
|
| 82 |
+
if not model_name:
|
| 83 |
+
return False
|
| 84 |
+
|
| 85 |
+
try:
|
| 86 |
+
from config import get_base_model_name, get_base_model_from_feature_model
|
| 87 |
+
|
| 88 |
+
# Remove feature prefixes (流式抗截断/, 假流式/)
|
| 89 |
+
base_with_suffix = get_base_model_from_feature_model(model_name)
|
| 90 |
+
|
| 91 |
+
# Remove thinking/search suffixes (-maxthinking, -nothinking, -search)
|
| 92 |
+
pure_base_model = get_base_model_name(base_with_suffix)
|
| 93 |
+
|
| 94 |
+
# Check if the pure base model is exactly "gemini-2.5-pro"
|
| 95 |
+
return pure_base_model == "gemini-2.5-pro"
|
| 96 |
+
|
| 97 |
+
except ImportError:
|
| 98 |
+
# Fallback logic if config import fails
|
| 99 |
+
clean_model = model_name
|
| 100 |
+
for prefix in ["流式抗截断/", "假流式/"]:
|
| 101 |
+
if clean_model.startswith(prefix):
|
| 102 |
+
clean_model = clean_model[len(prefix):]
|
| 103 |
+
break
|
| 104 |
+
|
| 105 |
+
for suffix in ["-maxthinking", "-nothinking", "-search"]:
|
| 106 |
+
if clean_model.endswith(suffix):
|
| 107 |
+
clean_model = clean_model[:-len(suffix)]
|
| 108 |
+
break
|
| 109 |
+
|
| 110 |
+
return clean_model == "gemini-2.5-pro"
|
| 111 |
+
|
| 112 |
+
async def _load_stats(self):
|
| 113 |
+
"""Load statistics from unified storage"""
|
| 114 |
+
try:
|
| 115 |
+
# 从统一存储获取所有使用统计,添加超时机制防止卡死
|
| 116 |
+
import asyncio
|
| 117 |
+
|
| 118 |
+
async def load_stats_with_timeout():
|
| 119 |
+
all_usage_stats = await self._storage_adapter.get_all_usage_stats()
|
| 120 |
+
|
| 121 |
+
log.debug(f"Processing {len(all_usage_stats)} usage statistics items...")
|
| 122 |
+
|
| 123 |
+
# 直接处理统计数据
|
| 124 |
+
stats_cache = {}
|
| 125 |
+
processed_count = 0
|
| 126 |
+
|
| 127 |
+
for filename, stats_data in all_usage_stats.items():
|
| 128 |
+
if isinstance(stats_data, dict):
|
| 129 |
+
normalized_filename = self._normalize_filename(filename)
|
| 130 |
+
|
| 131 |
+
# 提取使用统计字段
|
| 132 |
+
usage_data = {
|
| 133 |
+
"gemini_2_5_pro_calls": stats_data.get("gemini_2_5_pro_calls", 0),
|
| 134 |
+
"total_calls": stats_data.get("total_calls", 0),
|
| 135 |
+
"next_reset_time": stats_data.get("next_reset_time"),
|
| 136 |
+
"daily_limit_gemini_2_5_pro": stats_data.get("daily_limit_gemini_2_5_pro", 100),
|
| 137 |
+
"daily_limit_total": stats_data.get("daily_limit_total", 1000)
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
# 只加载有实际使用数据的统计,或者有reset时间的
|
| 141 |
+
if (usage_data.get("gemini_2_5_pro_calls", 0) > 0 or
|
| 142 |
+
usage_data.get("total_calls", 0) > 0 or
|
| 143 |
+
usage_data.get("next_reset_time")):
|
| 144 |
+
stats_cache[normalized_filename] = usage_data
|
| 145 |
+
processed_count += 1
|
| 146 |
+
|
| 147 |
+
return stats_cache, processed_count
|
| 148 |
+
|
| 149 |
+
# 设置15秒超时防止卡死
|
| 150 |
+
try:
|
| 151 |
+
self._stats_cache, processed_count = await asyncio.wait_for(
|
| 152 |
+
load_stats_with_timeout(), timeout=15.0
|
| 153 |
+
)
|
| 154 |
+
log.debug(f"Loaded usage statistics for {processed_count} credential files")
|
| 155 |
+
except asyncio.TimeoutError:
|
| 156 |
+
log.error("Loading usage statistics timed out after 30 seconds, using empty cache")
|
| 157 |
+
self._stats_cache = {}
|
| 158 |
+
return
|
| 159 |
+
|
| 160 |
+
except Exception as e:
|
| 161 |
+
log.error(f"Failed to load usage statistics: {e}")
|
| 162 |
+
self._stats_cache = {}
|
| 163 |
+
|
| 164 |
+
async def _save_stats(self):
|
| 165 |
+
"""Save statistics to unified storage."""
|
| 166 |
+
current_time = time.time()
|
| 167 |
+
|
| 168 |
+
# 使用脏标记和时间间隔控制,减少不必要的写入
|
| 169 |
+
if not self._cache_dirty or (current_time - self._last_save_time < self._save_interval):
|
| 170 |
+
return
|
| 171 |
+
|
| 172 |
+
try:
|
| 173 |
+
# 批量更新使用统计到存储适配器
|
| 174 |
+
log.debug(f"Saving {len(self._stats_cache)} usage statistics items...")
|
| 175 |
+
|
| 176 |
+
saved_count = 0
|
| 177 |
+
for filename, stats in self._stats_cache.items():
|
| 178 |
+
try:
|
| 179 |
+
stats_data = {
|
| 180 |
+
"gemini_2_5_pro_calls": stats.get("gemini_2_5_pro_calls", 0),
|
| 181 |
+
"total_calls": stats.get("total_calls", 0),
|
| 182 |
+
"next_reset_time": stats.get("next_reset_time"),
|
| 183 |
+
"daily_limit_gemini_2_5_pro": stats.get("daily_limit_gemini_2_5_pro", 100),
|
| 184 |
+
"daily_limit_total": stats.get("daily_limit_total", 1000)
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
success = await self._storage_adapter.update_usage_stats(filename, stats_data)
|
| 188 |
+
if success:
|
| 189 |
+
saved_count += 1
|
| 190 |
+
except Exception as e:
|
| 191 |
+
log.error(f"Failed to save stats for {filename}: {e}")
|
| 192 |
+
continue
|
| 193 |
+
|
| 194 |
+
self._cache_dirty = False # 清除脏标记
|
| 195 |
+
self._last_save_time = current_time
|
| 196 |
+
log.debug(f"Successfully saved {saved_count}/{len(self._stats_cache)} usage statistics to unified storage")
|
| 197 |
+
except Exception as e:
|
| 198 |
+
log.error(f"Failed to save usage statistics: {e}")
|
| 199 |
+
|
| 200 |
+
def _get_or_create_stats(self, filename: str) -> Dict[str, Any]:
|
| 201 |
+
"""Get or create statistics entry for a credential file."""
|
| 202 |
+
normalized_filename = self._normalize_filename(filename)
|
| 203 |
+
|
| 204 |
+
if normalized_filename not in self._stats_cache:
|
| 205 |
+
# 严格控制缓存大小 - 超过限制时删除最旧的条目
|
| 206 |
+
if len(self._stats_cache) >= self._max_cache_size:
|
| 207 |
+
# 删除最旧的统计数据(基于next_reset_time或没有该字段的)
|
| 208 |
+
oldest_key = min(self._stats_cache.keys(),
|
| 209 |
+
key=lambda k: self._stats_cache[k].get('next_reset_time', ''))
|
| 210 |
+
del self._stats_cache[oldest_key]
|
| 211 |
+
self._cache_dirty = True
|
| 212 |
+
log.debug(f"Removed oldest usage stats cache entry: {oldest_key}")
|
| 213 |
+
|
| 214 |
+
next_reset = _get_next_utc_7am()
|
| 215 |
+
self._stats_cache[normalized_filename] = {
|
| 216 |
+
"gemini_2_5_pro_calls": 0,
|
| 217 |
+
"total_calls": 0,
|
| 218 |
+
"next_reset_time": next_reset.isoformat(),
|
| 219 |
+
"daily_limit_gemini_2_5_pro": 100,
|
| 220 |
+
"daily_limit_total": 1000
|
| 221 |
+
}
|
| 222 |
+
self._cache_dirty = True # 标记缓存已修改
|
| 223 |
+
|
| 224 |
+
return self._stats_cache[normalized_filename]
|
| 225 |
+
|
| 226 |
+
def _check_and_reset_daily_quota(self, stats: Dict[str, Any]) -> bool:
|
| 227 |
+
"""
|
| 228 |
+
Simple reset logic: if current time >= next_reset_time, then reset.
|
| 229 |
+
"""
|
| 230 |
+
try:
|
| 231 |
+
next_reset_str = stats.get("next_reset_time")
|
| 232 |
+
if not next_reset_str:
|
| 233 |
+
# No next reset time recorded, set it up
|
| 234 |
+
next_reset = _get_next_utc_7am()
|
| 235 |
+
stats["next_reset_time"] = next_reset.isoformat()
|
| 236 |
+
return False
|
| 237 |
+
|
| 238 |
+
next_reset = datetime.fromisoformat(next_reset_str)
|
| 239 |
+
now = datetime.now(timezone.utc)
|
| 240 |
+
|
| 241 |
+
# Simple comparison: if current time >= next reset time, then reset
|
| 242 |
+
if now >= next_reset:
|
| 243 |
+
old_gemini_calls = stats.get("gemini_2_5_pro_calls", 0)
|
| 244 |
+
old_total_calls = stats.get("total_calls", 0)
|
| 245 |
+
|
| 246 |
+
# Reset counters and set new next reset time
|
| 247 |
+
new_next_reset = _get_next_utc_7am()
|
| 248 |
+
stats.update({
|
| 249 |
+
"gemini_2_5_pro_calls": 0,
|
| 250 |
+
"total_calls": 0,
|
| 251 |
+
"next_reset_time": new_next_reset.isoformat()
|
| 252 |
+
})
|
| 253 |
+
|
| 254 |
+
self._cache_dirty = True # 标记缓存已修改
|
| 255 |
+
log.info(f"Daily quota reset performed. Previous stats - Gemini 2.5 Pro: {old_gemini_calls}, Total: {old_total_calls}")
|
| 256 |
+
return True
|
| 257 |
+
|
| 258 |
+
return False
|
| 259 |
+
except Exception as e:
|
| 260 |
+
log.error(f"Error in daily quota reset check: {e}")
|
| 261 |
+
return False
|
| 262 |
+
|
| 263 |
+
async def record_successful_call(self, filename: str, model_name: str):
|
| 264 |
+
"""Record a successful API call for statistics."""
|
| 265 |
+
if not self._initialized:
|
| 266 |
+
await self.initialize()
|
| 267 |
+
|
| 268 |
+
with self._lock:
|
| 269 |
+
try:
|
| 270 |
+
normalized_filename = self._normalize_filename(filename)
|
| 271 |
+
stats = self._get_or_create_stats(normalized_filename)
|
| 272 |
+
|
| 273 |
+
# Check and perform daily reset if needed
|
| 274 |
+
reset_performed = self._check_and_reset_daily_quota(stats)
|
| 275 |
+
|
| 276 |
+
# Increment counters
|
| 277 |
+
is_gemini_2_5_pro = self._is_gemini_2_5_pro(model_name)
|
| 278 |
+
|
| 279 |
+
stats["total_calls"] += 1
|
| 280 |
+
if is_gemini_2_5_pro:
|
| 281 |
+
stats["gemini_2_5_pro_calls"] += 1
|
| 282 |
+
|
| 283 |
+
self._cache_dirty = True # 标记缓存已修改
|
| 284 |
+
|
| 285 |
+
log.debug(f"Usage recorded - File: {normalized_filename}, Model: {model_name}, "
|
| 286 |
+
f"Gemini 2.5 Pro: {stats['gemini_2_5_pro_calls']}/{stats.get('daily_limit_gemini_2_5_pro', 100)}, "
|
| 287 |
+
f"Total: {stats['total_calls']}/{stats.get('daily_limit_total', 1000)}")
|
| 288 |
+
|
| 289 |
+
if reset_performed:
|
| 290 |
+
log.info(f"Daily quota was reset for {normalized_filename}")
|
| 291 |
+
|
| 292 |
+
except Exception as e:
|
| 293 |
+
log.error(f"Failed to record usage statistics: {e}")
|
| 294 |
+
|
| 295 |
+
# Save stats asynchronously
|
| 296 |
+
try:
|
| 297 |
+
await self._save_stats()
|
| 298 |
+
except Exception as e:
|
| 299 |
+
log.error(f"Failed to save usage statistics after recording: {e}")
|
| 300 |
+
|
| 301 |
+
async def get_usage_stats(self, filename: str = None) -> Dict[str, Any]:
|
| 302 |
+
"""Get usage statistics."""
|
| 303 |
+
if not self._initialized:
|
| 304 |
+
await self.initialize()
|
| 305 |
+
|
| 306 |
+
with self._lock:
|
| 307 |
+
if filename:
|
| 308 |
+
normalized_filename = self._normalize_filename(filename)
|
| 309 |
+
stats = self._get_or_create_stats(normalized_filename)
|
| 310 |
+
# Check for daily reset before returning stats
|
| 311 |
+
self._check_and_reset_daily_quota(stats)
|
| 312 |
+
return {
|
| 313 |
+
"filename": normalized_filename,
|
| 314 |
+
"gemini_2_5_pro_calls": stats.get("gemini_2_5_pro_calls", 0),
|
| 315 |
+
"total_calls": stats.get("total_calls", 0),
|
| 316 |
+
"daily_limit_gemini_2_5_pro": stats.get("daily_limit_gemini_2_5_pro", 100),
|
| 317 |
+
"daily_limit_total": stats.get("daily_limit_total", 1000),
|
| 318 |
+
"next_reset_time": stats.get("next_reset_time")
|
| 319 |
+
}
|
| 320 |
+
else:
|
| 321 |
+
# Return all statistics
|
| 322 |
+
all_stats = {}
|
| 323 |
+
for filename, stats in self._stats_cache.items():
|
| 324 |
+
# Check for daily reset for each file
|
| 325 |
+
self._check_and_reset_daily_quota(stats)
|
| 326 |
+
all_stats[filename] = {
|
| 327 |
+
"gemini_2_5_pro_calls": stats.get("gemini_2_5_pro_calls", 0),
|
| 328 |
+
"total_calls": stats.get("total_calls", 0),
|
| 329 |
+
"daily_limit_gemini_2_5_pro": stats.get("daily_limit_gemini_2_5_pro", 100),
|
| 330 |
+
"daily_limit_total": stats.get("daily_limit_total", 1000),
|
| 331 |
+
"next_reset_time": stats.get("next_reset_time")
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
return all_stats
|
| 335 |
+
|
| 336 |
+
async def get_aggregated_stats(self) -> Dict[str, Any]:
|
| 337 |
+
"""Get aggregated statistics across all credential files."""
|
| 338 |
+
if not self._initialized:
|
| 339 |
+
await self.initialize()
|
| 340 |
+
|
| 341 |
+
all_stats = await self.get_usage_stats()
|
| 342 |
+
|
| 343 |
+
total_gemini_2_5_pro = 0
|
| 344 |
+
total_all_models = 0
|
| 345 |
+
total_files = len(all_stats)
|
| 346 |
+
|
| 347 |
+
for stats in all_stats.values():
|
| 348 |
+
total_gemini_2_5_pro += stats["gemini_2_5_pro_calls"]
|
| 349 |
+
total_all_models += stats["total_calls"]
|
| 350 |
+
|
| 351 |
+
return {
|
| 352 |
+
"total_files": total_files,
|
| 353 |
+
"total_gemini_2_5_pro_calls": total_gemini_2_5_pro,
|
| 354 |
+
"total_all_model_calls": total_all_models,
|
| 355 |
+
"avg_gemini_2_5_pro_per_file": total_gemini_2_5_pro / max(total_files, 1),
|
| 356 |
+
"avg_total_per_file": total_all_models / max(total_files, 1),
|
| 357 |
+
"next_reset_time": _get_next_utc_7am().isoformat()
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
async def update_daily_limits(self, filename: str, gemini_2_5_pro_limit: int = None,
|
| 361 |
+
total_limit: int = None):
|
| 362 |
+
"""Update daily limits for a specific credential file."""
|
| 363 |
+
if not self._initialized:
|
| 364 |
+
await self.initialize()
|
| 365 |
+
|
| 366 |
+
with self._lock:
|
| 367 |
+
try:
|
| 368 |
+
normalized_filename = self._normalize_filename(filename)
|
| 369 |
+
stats = self._get_or_create_stats(normalized_filename)
|
| 370 |
+
|
| 371 |
+
if gemini_2_5_pro_limit is not None:
|
| 372 |
+
stats["daily_limit_gemini_2_5_pro"] = gemini_2_5_pro_limit
|
| 373 |
+
|
| 374 |
+
if total_limit is not None:
|
| 375 |
+
stats["daily_limit_total"] = total_limit
|
| 376 |
+
|
| 377 |
+
log.info(f"Updated daily limits for {normalized_filename}: "
|
| 378 |
+
f"Gemini 2.5 Pro = {stats.get('daily_limit_gemini_2_5_pro', 100)}, "
|
| 379 |
+
f"Total = {stats.get('daily_limit_total', 1000)}")
|
| 380 |
+
|
| 381 |
+
except Exception as e:
|
| 382 |
+
log.error(f"Failed to update daily limits: {e}")
|
| 383 |
+
raise
|
| 384 |
+
|
| 385 |
+
await self._save_stats()
|
| 386 |
+
|
| 387 |
+
async def reset_stats(self, filename: str = None):
|
| 388 |
+
"""Reset usage statistics."""
|
| 389 |
+
if not self._initialized:
|
| 390 |
+
await self.initialize()
|
| 391 |
+
|
| 392 |
+
with self._lock:
|
| 393 |
+
if filename:
|
| 394 |
+
normalized_filename = self._normalize_filename(filename)
|
| 395 |
+
if normalized_filename in self._stats_cache:
|
| 396 |
+
# Manual reset: reset counters and set new next reset time
|
| 397 |
+
next_reset = _get_next_utc_7am()
|
| 398 |
+
self._stats_cache[normalized_filename].update({
|
| 399 |
+
"gemini_2_5_pro_calls": 0,
|
| 400 |
+
"total_calls": 0,
|
| 401 |
+
"next_reset_time": next_reset.isoformat()
|
| 402 |
+
})
|
| 403 |
+
log.info(f"Reset usage statistics for {normalized_filename}")
|
| 404 |
+
else:
|
| 405 |
+
# Reset all statistics
|
| 406 |
+
next_reset = _get_next_utc_7am()
|
| 407 |
+
for filename, stats in self._stats_cache.items():
|
| 408 |
+
stats.update({
|
| 409 |
+
"gemini_2_5_pro_calls": 0,
|
| 410 |
+
"total_calls": 0,
|
| 411 |
+
"next_reset_time": next_reset.isoformat()
|
| 412 |
+
})
|
| 413 |
+
log.info("Reset usage statistics for all credential files")
|
| 414 |
+
|
| 415 |
+
await self._save_stats()
|
| 416 |
+
|
| 417 |
+
# Global instance
|
| 418 |
+
_usage_stats_instance: Optional[UsageStats] = None
|
| 419 |
+
|
| 420 |
+
async def get_usage_stats_instance() -> UsageStats:
|
| 421 |
+
"""Get the global usage statistics instance."""
|
| 422 |
+
global _usage_stats_instance
|
| 423 |
+
if _usage_stats_instance is None:
|
| 424 |
+
_usage_stats_instance = UsageStats()
|
| 425 |
+
await _usage_stats_instance.initialize()
|
| 426 |
+
return _usage_stats_instance
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
async def record_successful_call(filename: str, model_name: str):
|
| 430 |
+
"""Convenience function to record a successful API call."""
|
| 431 |
+
stats = await get_usage_stats_instance()
|
| 432 |
+
await stats.record_successful_call(filename, model_name)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
async def get_usage_stats(filename: str = None) -> Dict[str, Any]:
|
| 436 |
+
"""Convenience function to get usage statistics."""
|
| 437 |
+
stats = await get_usage_stats_instance()
|
| 438 |
+
return await stats.get_usage_stats(filename)
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
async def get_aggregated_stats() -> Dict[str, Any]:
|
| 442 |
+
"""Convenience function to get aggregated statistics."""
|
| 443 |
+
stats = await get_usage_stats_instance()
|
| 444 |
+
return await stats.get_aggregated_stats()
|
src/utils.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import platform
|
| 2 |
+
|
| 3 |
+
CLI_VERSION = "0.1.5" # Match current gemini-cli version
|
| 4 |
+
|
| 5 |
+
def get_user_agent():
|
| 6 |
+
"""Generate User-Agent string matching gemini-cli format."""
|
| 7 |
+
version = CLI_VERSION
|
| 8 |
+
system = platform.system()
|
| 9 |
+
arch = platform.machine()
|
| 10 |
+
return f"GeminiCLI/{version} ({system}; {arch})"
|
src/web_routes.py
ADDED
|
@@ -0,0 +1,1738 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Web路由模块 - 处理认证相关的HTTP请求和控制面板功能
|
| 3 |
+
用于与上级web.py集成
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
import datetime
|
| 7 |
+
import io
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
import zipfile
|
| 12 |
+
from collections import deque
|
| 13 |
+
from typing import List, Optional, Dict, Any
|
| 14 |
+
from urllib.parse import urlparse
|
| 15 |
+
|
| 16 |
+
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, WebSocket, WebSocketDisconnect, Request
|
| 17 |
+
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse, Response
|
| 18 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 19 |
+
from pydantic import BaseModel
|
| 20 |
+
from starlette.websockets import WebSocketState
|
| 21 |
+
import toml
|
| 22 |
+
import zipfile
|
| 23 |
+
import httpx
|
| 24 |
+
|
| 25 |
+
import config
|
| 26 |
+
from log import log
|
| 27 |
+
from .auth import (
|
| 28 |
+
create_auth_url, get_auth_status,
|
| 29 |
+
verify_password, generate_auth_token, verify_auth_token,
|
| 30 |
+
asyncio_complete_auth_flow, complete_auth_flow_from_callback_url,
|
| 31 |
+
load_credentials_from_env, clear_env_credentials
|
| 32 |
+
)
|
| 33 |
+
from .credential_manager import CredentialManager
|
| 34 |
+
from .usage_stats import get_usage_stats, get_aggregated_stats, get_usage_stats_instance
|
| 35 |
+
from .storage_adapter import get_storage_adapter
|
| 36 |
+
|
| 37 |
+
# 创建路由器
|
| 38 |
+
router = APIRouter()
|
| 39 |
+
security = HTTPBearer()
|
| 40 |
+
|
| 41 |
+
# 创建credential manager实例
|
| 42 |
+
credential_manager = CredentialManager()
|
| 43 |
+
|
| 44 |
+
# WebSocket连接管理
|
| 45 |
+
|
| 46 |
+
class ConnectionManager:
|
| 47 |
+
def __init__(self, max_connections: int = 3): # 进一步降低最大连接数
|
| 48 |
+
# 使用双端队列严格限制内存使用
|
| 49 |
+
self.active_connections: deque = deque(maxlen=max_connections)
|
| 50 |
+
self.max_connections = max_connections
|
| 51 |
+
self._last_cleanup = 0
|
| 52 |
+
self._cleanup_interval = 120 # 120秒清理一次死连接
|
| 53 |
+
|
| 54 |
+
async def connect(self, websocket: WebSocket):
|
| 55 |
+
# 自动清理死连接
|
| 56 |
+
self._auto_cleanup()
|
| 57 |
+
|
| 58 |
+
# 限制最大连接数,防止内存无限增长
|
| 59 |
+
if len(self.active_connections) >= self.max_connections:
|
| 60 |
+
await websocket.close(code=1008, reason="Too many connections")
|
| 61 |
+
return False
|
| 62 |
+
|
| 63 |
+
await websocket.accept()
|
| 64 |
+
self.active_connections.append(websocket)
|
| 65 |
+
log.debug(f"WebSocket连接建立,当前连接数: {len(self.active_connections)}")
|
| 66 |
+
return True
|
| 67 |
+
|
| 68 |
+
def disconnect(self, websocket: WebSocket):
|
| 69 |
+
# 使用更高效的方式移除连接
|
| 70 |
+
try:
|
| 71 |
+
self.active_connections.remove(websocket)
|
| 72 |
+
except ValueError:
|
| 73 |
+
pass # 连接已不存在
|
| 74 |
+
log.debug(f"WebSocket连接断开,当前连接数: {len(self.active_connections)}")
|
| 75 |
+
|
| 76 |
+
async def send_personal_message(self, message: str, websocket: WebSocket):
|
| 77 |
+
try:
|
| 78 |
+
await websocket.send_text(message)
|
| 79 |
+
except Exception:
|
| 80 |
+
self.disconnect(websocket)
|
| 81 |
+
|
| 82 |
+
async def broadcast(self, message: str):
|
| 83 |
+
# 使用更高效的方式处理广播,避免索引操作
|
| 84 |
+
dead_connections = []
|
| 85 |
+
for conn in self.active_connections:
|
| 86 |
+
try:
|
| 87 |
+
await conn.send_text(message)
|
| 88 |
+
except Exception:
|
| 89 |
+
dead_connections.append(conn)
|
| 90 |
+
|
| 91 |
+
# 批量移除死连接
|
| 92 |
+
for dead_conn in dead_connections:
|
| 93 |
+
self.disconnect(dead_conn)
|
| 94 |
+
|
| 95 |
+
def _auto_cleanup(self):
|
| 96 |
+
"""自动清理死连接"""
|
| 97 |
+
current_time = time.time()
|
| 98 |
+
if current_time - self._last_cleanup > self._cleanup_interval:
|
| 99 |
+
self.cleanup_dead_connections()
|
| 100 |
+
self._last_cleanup = current_time
|
| 101 |
+
|
| 102 |
+
def cleanup_dead_connections(self):
|
| 103 |
+
"""清理已断开的连接"""
|
| 104 |
+
original_count = len(self.active_connections)
|
| 105 |
+
# 使用列表推导式过滤活跃连接,更高效
|
| 106 |
+
alive_connections = deque([
|
| 107 |
+
conn for conn in self.active_connections
|
| 108 |
+
if hasattr(conn, 'client_state') and conn.client_state != WebSocketState.DISCONNECTED
|
| 109 |
+
], maxlen=self.max_connections)
|
| 110 |
+
|
| 111 |
+
self.active_connections = alive_connections
|
| 112 |
+
cleaned = original_count - len(self.active_connections)
|
| 113 |
+
if cleaned > 0:
|
| 114 |
+
log.debug(f"清理了 {cleaned} 个死连接,剩余连接数: {len(self.active_connections)}")
|
| 115 |
+
|
| 116 |
+
manager = ConnectionManager()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
async def ensure_credential_manager_initialized():
|
| 120 |
+
"""确保credential manager已初始化"""
|
| 121 |
+
if not credential_manager._initialized:
|
| 122 |
+
await credential_manager.initialize()
|
| 123 |
+
|
| 124 |
+
async def get_credential_manager():
|
| 125 |
+
"""获取全局凭证管理器实例"""
|
| 126 |
+
global credential_manager
|
| 127 |
+
if not credential_manager:
|
| 128 |
+
credential_manager = CredentialManager()
|
| 129 |
+
await credential_manager.initialize()
|
| 130 |
+
return credential_manager
|
| 131 |
+
|
| 132 |
+
async def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
|
| 133 |
+
"""验证用户密码(控制面板使用)"""
|
| 134 |
+
from config import get_panel_password
|
| 135 |
+
password = await get_panel_password()
|
| 136 |
+
token = credentials.credentials
|
| 137 |
+
if token != password:
|
| 138 |
+
raise HTTPException(status_code=403, detail="密码错误")
|
| 139 |
+
return token
|
| 140 |
+
|
| 141 |
+
class LoginRequest(BaseModel):
|
| 142 |
+
password: str
|
| 143 |
+
|
| 144 |
+
class AuthStartRequest(BaseModel):
|
| 145 |
+
project_id: Optional[str] = None # 现在是可选的
|
| 146 |
+
get_all_projects: Optional[bool] = False # 是否为所有项目获取凭证
|
| 147 |
+
|
| 148 |
+
class AuthCallbackRequest(BaseModel):
|
| 149 |
+
project_id: Optional[str] = None # 现在是可选的
|
| 150 |
+
get_all_projects: Optional[bool] = False # 是否为所有项目获取凭证
|
| 151 |
+
|
| 152 |
+
class AuthCallbackUrlRequest(BaseModel):
|
| 153 |
+
callback_url: str # OAuth回调完整URL
|
| 154 |
+
project_id: Optional[str] = None # 可选的项目ID
|
| 155 |
+
get_all_projects: Optional[bool] = False # 是否为所有项目获取凭证
|
| 156 |
+
|
| 157 |
+
class CredFileActionRequest(BaseModel):
|
| 158 |
+
filename: str
|
| 159 |
+
action: str # enable, disable, delete
|
| 160 |
+
|
| 161 |
+
class CredFileBatchActionRequest(BaseModel):
|
| 162 |
+
action: str # "enable", "disable", "delete"
|
| 163 |
+
filenames: List[str] # 批量操作的文件名列表
|
| 164 |
+
|
| 165 |
+
class ConfigSaveRequest(BaseModel):
|
| 166 |
+
config: dict
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
| 171 |
+
"""验证认证令牌"""
|
| 172 |
+
if not verify_auth_token(credentials.credentials):
|
| 173 |
+
raise HTTPException(status_code=401, detail="无效的认证令牌")
|
| 174 |
+
return credentials.credentials
|
| 175 |
+
|
| 176 |
+
def is_mobile_user_agent(user_agent: str) -> bool:
|
| 177 |
+
"""检测是否为移动设备用户代理"""
|
| 178 |
+
if not user_agent:
|
| 179 |
+
return False
|
| 180 |
+
|
| 181 |
+
user_agent_lower = user_agent.lower()
|
| 182 |
+
mobile_keywords = [
|
| 183 |
+
'mobile', 'android', 'iphone', 'ipad', 'ipod',
|
| 184 |
+
'blackberry', 'windows phone', 'samsung', 'htc',
|
| 185 |
+
'motorola', 'nokia', 'palm', 'webos', 'opera mini',
|
| 186 |
+
'opera mobi', 'fennec', 'minimo', 'symbian', 'psp',
|
| 187 |
+
'nintendo', 'tablet'
|
| 188 |
+
]
|
| 189 |
+
|
| 190 |
+
return any(keyword in user_agent_lower for keyword in mobile_keywords)
|
| 191 |
+
|
| 192 |
+
@router.get("/", response_class=HTMLResponse)
|
| 193 |
+
@router.get("/v1", response_class=HTMLResponse)
|
| 194 |
+
@router.get("/auth", response_class=HTMLResponse)
|
| 195 |
+
async def serve_control_panel(request: Request):
|
| 196 |
+
"""提供统一控制面板(包含认证、文件管理、配置等功能)"""
|
| 197 |
+
try:
|
| 198 |
+
# 获取用户代理并判断是否为移动设备
|
| 199 |
+
user_agent = request.headers.get("user-agent", "")
|
| 200 |
+
is_mobile = is_mobile_user_agent(user_agent)
|
| 201 |
+
|
| 202 |
+
# 根据设备类型选择相应的HTML文件
|
| 203 |
+
if is_mobile:
|
| 204 |
+
html_file_path = "front/control_panel_mobile.html"
|
| 205 |
+
log.info(f"Serving mobile control panel to user-agent: {user_agent}")
|
| 206 |
+
else:
|
| 207 |
+
html_file_path = "front/control_panel.html"
|
| 208 |
+
log.info(f"Serving desktop control panel to user-agent: {user_agent}")
|
| 209 |
+
|
| 210 |
+
with open(html_file_path, "r", encoding="utf-8") as f:
|
| 211 |
+
html_content = f.read()
|
| 212 |
+
return HTMLResponse(content=html_content)
|
| 213 |
+
except FileNotFoundError:
|
| 214 |
+
log.error(f"控制面板页面文件不存在: {html_file_path}")
|
| 215 |
+
# 如果移动端文件不存在,回退到桌面版
|
| 216 |
+
if is_mobile:
|
| 217 |
+
try:
|
| 218 |
+
with open("front/control_panel.html", "r", encoding="utf-8") as f:
|
| 219 |
+
html_content = f.read()
|
| 220 |
+
return HTMLResponse(content=html_content)
|
| 221 |
+
except FileNotFoundError:
|
| 222 |
+
raise HTTPException(status_code=404, detail="控制面板页面不存在")
|
| 223 |
+
else:
|
| 224 |
+
raise HTTPException(status_code=404, detail="控制面板页面不存在")
|
| 225 |
+
except Exception as e:
|
| 226 |
+
log.error(f"加载控制面板页面失败: {e}")
|
| 227 |
+
raise HTTPException(status_code=500, detail="服务器内部错误")
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@router.post("/auth/login")
|
| 231 |
+
async def login(request: LoginRequest):
|
| 232 |
+
"""用户登录"""
|
| 233 |
+
try:
|
| 234 |
+
if await verify_password(request.password):
|
| 235 |
+
token = generate_auth_token()
|
| 236 |
+
return JSONResponse(content={"token": token, "message": "登录成功"})
|
| 237 |
+
else:
|
| 238 |
+
raise HTTPException(status_code=401, detail="密码错误")
|
| 239 |
+
except HTTPException:
|
| 240 |
+
raise
|
| 241 |
+
except Exception as e:
|
| 242 |
+
log.error(f"登录失败: {e}")
|
| 243 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
@router.post("/auth/start")
|
| 247 |
+
async def start_auth(request: AuthStartRequest, token: str = Depends(verify_token)):
|
| 248 |
+
"""开始认证流程,支持自动检测项目ID和批量获取所有项目"""
|
| 249 |
+
try:
|
| 250 |
+
# 检查是否为批量项目模式
|
| 251 |
+
if request.get_all_projects:
|
| 252 |
+
log.info("用户请求批量获取所有项目的凭证...")
|
| 253 |
+
project_id = None # 批量模式下不指定单个项目ID
|
| 254 |
+
else:
|
| 255 |
+
# 如果没有提供项目ID,尝试自动检测
|
| 256 |
+
project_id = request.project_id
|
| 257 |
+
if not project_id:
|
| 258 |
+
log.info("用户未提供项目ID,后续将使用自动检测...")
|
| 259 |
+
|
| 260 |
+
# 使用认证令牌作为用户会话标识
|
| 261 |
+
user_session = token if token else None
|
| 262 |
+
result = await create_auth_url(project_id, user_session, get_all_projects=request.get_all_projects)
|
| 263 |
+
|
| 264 |
+
if result['success']:
|
| 265 |
+
return JSONResponse(content={
|
| 266 |
+
"auth_url": result['auth_url'],
|
| 267 |
+
"state": result['state'],
|
| 268 |
+
"auto_project_detection": result.get('auto_project_detection', False),
|
| 269 |
+
"detected_project_id": result.get('detected_project_id'),
|
| 270 |
+
"get_all_projects": request.get_all_projects
|
| 271 |
+
})
|
| 272 |
+
else:
|
| 273 |
+
raise HTTPException(status_code=500, detail=result['error'])
|
| 274 |
+
|
| 275 |
+
except HTTPException:
|
| 276 |
+
raise
|
| 277 |
+
except Exception as e:
|
| 278 |
+
log.error(f"开始认证流程失败: {e}")
|
| 279 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
@router.post("/auth/callback")
|
| 283 |
+
async def auth_callback(request: AuthCallbackRequest, token: str = Depends(verify_token)):
|
| 284 |
+
"""处理认证回调,支持自动检测项目ID和批量获取所有项目"""
|
| 285 |
+
try:
|
| 286 |
+
# 项目ID现在是可选的,在回调处理中进行自动检测
|
| 287 |
+
project_id = request.project_id
|
| 288 |
+
get_all_projects = request.get_all_projects
|
| 289 |
+
|
| 290 |
+
# 使用认证令牌作为用户会话标识
|
| 291 |
+
user_session = token if token else None
|
| 292 |
+
# 异步等待OAuth回调完成
|
| 293 |
+
result = await asyncio_complete_auth_flow(project_id, user_session, get_all_projects=get_all_projects)
|
| 294 |
+
|
| 295 |
+
if result['success']:
|
| 296 |
+
if get_all_projects and result.get('multiple_credentials'):
|
| 297 |
+
# 批量认证成功,返回多个凭证信息
|
| 298 |
+
return JSONResponse(content={
|
| 299 |
+
"multiple_credentials": result['multiple_credentials'],
|
| 300 |
+
"message": "批量认证成功,已为多个项目保存凭证"
|
| 301 |
+
})
|
| 302 |
+
else:
|
| 303 |
+
# 单项目认证成功
|
| 304 |
+
return JSONResponse(content={
|
| 305 |
+
"credentials": result['credentials'],
|
| 306 |
+
"file_path": result['file_path'],
|
| 307 |
+
"message": "认证成功,凭证已保存",
|
| 308 |
+
"auto_detected_project": result.get('auto_detected_project', False)
|
| 309 |
+
})
|
| 310 |
+
else:
|
| 311 |
+
# 如果需要手动项目ID或项目选择,在响应中标明
|
| 312 |
+
if result.get('requires_manual_project_id'):
|
| 313 |
+
# 使用JSON响应
|
| 314 |
+
return JSONResponse(
|
| 315 |
+
status_code=400,
|
| 316 |
+
content={
|
| 317 |
+
"error": result['error'],
|
| 318 |
+
"requires_manual_project_id": True
|
| 319 |
+
}
|
| 320 |
+
)
|
| 321 |
+
elif result.get('requires_project_selection'):
|
| 322 |
+
# 返回项目列表供用户选择
|
| 323 |
+
return JSONResponse(
|
| 324 |
+
status_code=400,
|
| 325 |
+
content={
|
| 326 |
+
"error": result['error'],
|
| 327 |
+
"requires_project_selection": True,
|
| 328 |
+
"available_projects": result['available_projects']
|
| 329 |
+
}
|
| 330 |
+
)
|
| 331 |
+
else:
|
| 332 |
+
raise HTTPException(status_code=400, detail=result['error'])
|
| 333 |
+
|
| 334 |
+
except HTTPException:
|
| 335 |
+
raise
|
| 336 |
+
except Exception as e:
|
| 337 |
+
log.error(f"处理认证回调失败: {e}")
|
| 338 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
@router.post("/auth/callback-url")
|
| 342 |
+
async def auth_callback_url(request: AuthCallbackUrlRequest, token: str = Depends(verify_token)):
|
| 343 |
+
"""从回调URL直接完成认证,支持批量获取所有项目"""
|
| 344 |
+
try:
|
| 345 |
+
# 验证URL格式
|
| 346 |
+
if not request.callback_url or not request.callback_url.startswith(('http://', 'https://')):
|
| 347 |
+
raise HTTPException(status_code=400, detail="请提供有效的回调URL")
|
| 348 |
+
|
| 349 |
+
# 从回调URL完成认证
|
| 350 |
+
result = await complete_auth_flow_from_callback_url(
|
| 351 |
+
request.callback_url,
|
| 352 |
+
request.project_id,
|
| 353 |
+
get_all_projects=request.get_all_projects
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
if result['success']:
|
| 357 |
+
if request.get_all_projects and result.get('multiple_credentials'):
|
| 358 |
+
# 批量认证成功,返回多个凭证信息
|
| 359 |
+
return JSONResponse(content={
|
| 360 |
+
"multiple_credentials": result['multiple_credentials'],
|
| 361 |
+
"message": "从回调URL批量认证成功,已为多个项目保存凭证"
|
| 362 |
+
})
|
| 363 |
+
else:
|
| 364 |
+
# 单项目认证成功
|
| 365 |
+
return JSONResponse(content={
|
| 366 |
+
"credentials": result['credentials'],
|
| 367 |
+
"file_path": result['file_path'],
|
| 368 |
+
"message": "从回调URL认证成功,凭证已保存",
|
| 369 |
+
"auto_detected_project": result.get('auto_detected_project', False)
|
| 370 |
+
})
|
| 371 |
+
else:
|
| 372 |
+
# 处理各种错误情况
|
| 373 |
+
if result.get('requires_manual_project_id'):
|
| 374 |
+
return JSONResponse(
|
| 375 |
+
status_code=400,
|
| 376 |
+
content={
|
| 377 |
+
"error": result['error'],
|
| 378 |
+
"requires_manual_project_id": True
|
| 379 |
+
}
|
| 380 |
+
)
|
| 381 |
+
elif result.get('requires_project_selection'):
|
| 382 |
+
return JSONResponse(
|
| 383 |
+
status_code=400,
|
| 384 |
+
content={
|
| 385 |
+
"error": result['error'],
|
| 386 |
+
"requires_project_selection": True,
|
| 387 |
+
"available_projects": result['available_projects']
|
| 388 |
+
}
|
| 389 |
+
)
|
| 390 |
+
else:
|
| 391 |
+
raise HTTPException(status_code=400, detail=result['error'])
|
| 392 |
+
|
| 393 |
+
except HTTPException:
|
| 394 |
+
raise
|
| 395 |
+
except Exception as e:
|
| 396 |
+
log.error(f"从回调URL处理认证失败: {e}")
|
| 397 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
@router.get("/auth/status/{project_id}")
|
| 401 |
+
async def check_auth_status(project_id: str, token: str = Depends(verify_token)):
|
| 402 |
+
"""检查认证状态"""
|
| 403 |
+
try:
|
| 404 |
+
if not project_id:
|
| 405 |
+
raise HTTPException(status_code=400, detail="Project ID 不能为空")
|
| 406 |
+
|
| 407 |
+
status = get_auth_status(project_id)
|
| 408 |
+
return JSONResponse(content=status)
|
| 409 |
+
|
| 410 |
+
except Exception as e:
|
| 411 |
+
log.error(f"检查认证状态失败: {e}")
|
| 412 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
async def extract_json_files_from_zip(zip_file: UploadFile) -> List[dict]:
|
| 416 |
+
"""从ZIP文件中提取JSON文件"""
|
| 417 |
+
try:
|
| 418 |
+
# 读取ZIP文件内容
|
| 419 |
+
zip_content = await zip_file.read()
|
| 420 |
+
|
| 421 |
+
# 不限制ZIP文件大小,只在处理时控制文件数量
|
| 422 |
+
|
| 423 |
+
files_data = []
|
| 424 |
+
|
| 425 |
+
with zipfile.ZipFile(io.BytesIO(zip_content), 'r') as zip_ref:
|
| 426 |
+
# 获取ZIP中的所有文件
|
| 427 |
+
file_list = zip_ref.namelist()
|
| 428 |
+
json_files = [f for f in file_list if f.endswith('.json') and not f.startswith('__MACOSX/')]
|
| 429 |
+
|
| 430 |
+
if not json_files:
|
| 431 |
+
raise HTTPException(status_code=400, detail="ZIP文件中没有找到JSON文件")
|
| 432 |
+
|
| 433 |
+
log.info(f"从ZIP文件 {zip_file.filename} 中找到 {len(json_files)} 个JSON文件")
|
| 434 |
+
|
| 435 |
+
for json_filename in json_files:
|
| 436 |
+
try:
|
| 437 |
+
# 读取JSON文件内容
|
| 438 |
+
with zip_ref.open(json_filename) as json_file:
|
| 439 |
+
content = json_file.read()
|
| 440 |
+
|
| 441 |
+
try:
|
| 442 |
+
content_str = content.decode('utf-8')
|
| 443 |
+
except UnicodeDecodeError:
|
| 444 |
+
log.warning(f"跳过编码错误的文件: {json_filename}")
|
| 445 |
+
continue
|
| 446 |
+
|
| 447 |
+
# 使用原始文件名(去掉路径)
|
| 448 |
+
filename = os.path.basename(json_filename)
|
| 449 |
+
files_data.append({
|
| 450 |
+
'filename': filename,
|
| 451 |
+
'content': content_str
|
| 452 |
+
})
|
| 453 |
+
|
| 454 |
+
except Exception as e:
|
| 455 |
+
log.warning(f"处理ZIP中的文件 {json_filename} 时出错: {e}")
|
| 456 |
+
continue
|
| 457 |
+
|
| 458 |
+
log.info(f"成功从ZIP文件中提取 {len(files_data)} 个有效的JSON文件")
|
| 459 |
+
return files_data
|
| 460 |
+
|
| 461 |
+
except zipfile.BadZipFile:
|
| 462 |
+
raise HTTPException(status_code=400, detail="无效的ZIP文件格式")
|
| 463 |
+
except Exception as e:
|
| 464 |
+
log.error(f"处理ZIP文件失败: {e}")
|
| 465 |
+
raise HTTPException(status_code=500, detail=f"处理ZIP文件失败: {str(e)}")
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
@router.post("/auth/upload")
|
| 469 |
+
async def upload_credentials(files: List[UploadFile] = File(...), token: str = Depends(verify_token)):
|
| 470 |
+
"""批量上传认证文件"""
|
| 471 |
+
try:
|
| 472 |
+
if not files:
|
| 473 |
+
raise HTTPException(status_code=400, detail="请选择要上传的文件")
|
| 474 |
+
|
| 475 |
+
# 检查文件数量限制
|
| 476 |
+
if len(files) > 100:
|
| 477 |
+
raise HTTPException(status_code=400, detail=f"文件数量过多,最多支持100个文件,当前:{len(files)}个")
|
| 478 |
+
|
| 479 |
+
files_data = []
|
| 480 |
+
for file in files:
|
| 481 |
+
# 检查文件类型:支持JSON和ZIP
|
| 482 |
+
if file.filename.endswith('.zip'):
|
| 483 |
+
# 处理ZIP文件
|
| 484 |
+
zip_files_data = await extract_json_files_from_zip(file)
|
| 485 |
+
files_data.extend(zip_files_data)
|
| 486 |
+
log.info(f"从ZIP文件 {file.filename} 中提取了 {len(zip_files_data)} 个JSON文件")
|
| 487 |
+
|
| 488 |
+
elif file.filename.endswith('.json'):
|
| 489 |
+
# 处理单个JSON文件
|
| 490 |
+
# 流式读取文件内容
|
| 491 |
+
content_chunks = []
|
| 492 |
+
while True:
|
| 493 |
+
chunk = await file.read(8192) # 8KB chunks
|
| 494 |
+
if not chunk:
|
| 495 |
+
break
|
| 496 |
+
content_chunks.append(chunk)
|
| 497 |
+
|
| 498 |
+
content = b''.join(content_chunks)
|
| 499 |
+
try:
|
| 500 |
+
content_str = content.decode('utf-8')
|
| 501 |
+
except UnicodeDecodeError:
|
| 502 |
+
raise HTTPException(status_code=400, detail=f"文件 {file.filename} 编码格式不支持")
|
| 503 |
+
|
| 504 |
+
files_data.append({
|
| 505 |
+
'filename': file.filename,
|
| 506 |
+
'content': content_str
|
| 507 |
+
})
|
| 508 |
+
else:
|
| 509 |
+
raise HTTPException(status_code=400, detail=f"文件 {file.filename} 格式不支持,只支持JSON和ZIP文件")
|
| 510 |
+
|
| 511 |
+
# 获取存储适配器
|
| 512 |
+
storage_adapter = await get_storage_adapter()
|
| 513 |
+
|
| 514 |
+
# 分批处理大量文件以提高稳定性
|
| 515 |
+
batch_size = 1000 # 每批处理1000个文件
|
| 516 |
+
all_results = []
|
| 517 |
+
total_success = 0
|
| 518 |
+
|
| 519 |
+
for i in range(0, len(files_data), batch_size):
|
| 520 |
+
batch_files = files_data[i:i + batch_size]
|
| 521 |
+
|
| 522 |
+
# 使用并发处理提升文件上传性能
|
| 523 |
+
async def process_single_file(file_data):
|
| 524 |
+
"""处理单个文件的并发函数"""
|
| 525 |
+
try:
|
| 526 |
+
filename = file_data['filename']
|
| 527 |
+
content_str = file_data['content']
|
| 528 |
+
|
| 529 |
+
# 解析JSON内容
|
| 530 |
+
credential_data = json.loads(content_str)
|
| 531 |
+
|
| 532 |
+
# 存储到统一存储系统
|
| 533 |
+
success = await storage_adapter.store_credential(filename, credential_data)
|
| 534 |
+
if success:
|
| 535 |
+
# 创建默认状态记录(如果不存在)
|
| 536 |
+
try:
|
| 537 |
+
import time
|
| 538 |
+
default_state = {
|
| 539 |
+
"error_codes": [],
|
| 540 |
+
"disabled": False,
|
| 541 |
+
"last_success": time.time(),
|
| 542 |
+
"user_email": None,
|
| 543 |
+
"gemini_2_5_pro_calls": 0,
|
| 544 |
+
"total_calls": 0,
|
| 545 |
+
"next_reset_time": None,
|
| 546 |
+
"daily_limit_gemini_2_5_pro": 100,
|
| 547 |
+
"daily_limit_total": 1000
|
| 548 |
+
}
|
| 549 |
+
# 只在状态不存在时创建,避免覆盖现有状态
|
| 550 |
+
# 检查数据库中是否真正存在状态记录
|
| 551 |
+
all_states = await storage_adapter.get_all_credential_states()
|
| 552 |
+
if filename not in all_states:
|
| 553 |
+
await storage_adapter.update_credential_state(filename, default_state)
|
| 554 |
+
log.debug(f"Created default state for new credential: {filename}")
|
| 555 |
+
except Exception as e:
|
| 556 |
+
log.warning(f"Failed to create default state for {filename}: {e}")
|
| 557 |
+
|
| 558 |
+
log.debug(f"成功上传凭证文件: {filename}")
|
| 559 |
+
return {
|
| 560 |
+
"filename": filename,
|
| 561 |
+
"status": "success",
|
| 562 |
+
"message": "上传成功"
|
| 563 |
+
}
|
| 564 |
+
else:
|
| 565 |
+
return {
|
| 566 |
+
"filename": filename,
|
| 567 |
+
"status": "error",
|
| 568 |
+
"message": "存储失败"
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
except json.JSONDecodeError as e:
|
| 572 |
+
return {
|
| 573 |
+
"filename": file_data['filename'],
|
| 574 |
+
"status": "error",
|
| 575 |
+
"message": f"JSON格式错误: {str(e)}"
|
| 576 |
+
}
|
| 577 |
+
except Exception as e:
|
| 578 |
+
return {
|
| 579 |
+
"filename": file_data['filename'],
|
| 580 |
+
"status": "error",
|
| 581 |
+
"message": f"处理失败: {str(e)}"
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
# 并发处理这一批文件
|
| 585 |
+
log.info(f"开始并发处理 {len(batch_files)} 个文件...")
|
| 586 |
+
concurrent_tasks = [process_single_file(file_data) for file_data in batch_files]
|
| 587 |
+
batch_results = await asyncio.gather(*concurrent_tasks, return_exceptions=True)
|
| 588 |
+
|
| 589 |
+
# 处理异常结果
|
| 590 |
+
processed_results = []
|
| 591 |
+
batch_uploaded_count = 0
|
| 592 |
+
for result in batch_results:
|
| 593 |
+
if isinstance(result, Exception):
|
| 594 |
+
processed_results.append({
|
| 595 |
+
"filename": "unknown",
|
| 596 |
+
"status": "error",
|
| 597 |
+
"message": f"处理异常: {str(result)}"
|
| 598 |
+
})
|
| 599 |
+
else:
|
| 600 |
+
processed_results.append(result)
|
| 601 |
+
if result["status"] == "success":
|
| 602 |
+
batch_uploaded_count += 1
|
| 603 |
+
|
| 604 |
+
batch_results = processed_results
|
| 605 |
+
|
| 606 |
+
all_results.extend(batch_results)
|
| 607 |
+
total_success += batch_uploaded_count
|
| 608 |
+
|
| 609 |
+
# 记录批次进度
|
| 610 |
+
batch_num = (i // batch_size) + 1
|
| 611 |
+
total_batches = (len(files_data) + batch_size - 1) // batch_size
|
| 612 |
+
log.info(f"批次 {batch_num}/{total_batches} 完成: 成功 {batch_uploaded_count}/{len(batch_files)} 个文件")
|
| 613 |
+
|
| 614 |
+
if total_success > 0:
|
| 615 |
+
return JSONResponse(content={
|
| 616 |
+
"uploaded_count": total_success,
|
| 617 |
+
"total_count": len(files_data),
|
| 618 |
+
"results": all_results,
|
| 619 |
+
"message": f"批量上传完成: 成功 {total_success}/{len(files_data)} 个文件"
|
| 620 |
+
})
|
| 621 |
+
else:
|
| 622 |
+
raise HTTPException(status_code=400, detail="没有文件上传成功")
|
| 623 |
+
|
| 624 |
+
except HTTPException:
|
| 625 |
+
raise
|
| 626 |
+
except Exception as e:
|
| 627 |
+
log.error(f"批量上传失败: {e}")
|
| 628 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
@router.get("/creds/status")
|
| 632 |
+
async def get_creds_status(token: str = Depends(verify_token)):
|
| 633 |
+
"""获取所有凭证文件的状态"""
|
| 634 |
+
try:
|
| 635 |
+
await ensure_credential_manager_initialized()
|
| 636 |
+
|
| 637 |
+
# 获取存储适配器
|
| 638 |
+
storage_adapter = await get_storage_adapter()
|
| 639 |
+
|
| 640 |
+
# 获取所有凭证和状态
|
| 641 |
+
all_credentials = await storage_adapter.list_credentials()
|
| 642 |
+
all_states = await credential_manager.get_creds_status()
|
| 643 |
+
|
| 644 |
+
# 获取后端信息(一次性获取,避免重复查询)
|
| 645 |
+
backend_info = await storage_adapter.get_backend_info()
|
| 646 |
+
backend_type = backend_info.get("backend_type", "unknown")
|
| 647 |
+
|
| 648 |
+
# 并发处理所有凭证的数据获取(状态已获取,无需重复处理)
|
| 649 |
+
async def process_credential_data(filename):
|
| 650 |
+
"""并发处理单个凭证的数据获取"""
|
| 651 |
+
file_status = all_states.get(filename)
|
| 652 |
+
|
| 653 |
+
# 如果没有状态记录,创建默认状态
|
| 654 |
+
if not file_status:
|
| 655 |
+
try:
|
| 656 |
+
import time
|
| 657 |
+
default_state = {
|
| 658 |
+
"error_codes": [],
|
| 659 |
+
"disabled": False,
|
| 660 |
+
"last_success": time.time(),
|
| 661 |
+
"user_email": None,
|
| 662 |
+
"gemini_2_5_pro_calls": 0,
|
| 663 |
+
"total_calls": 0,
|
| 664 |
+
"next_reset_time": None,
|
| 665 |
+
"daily_limit_gemini_2_5_pro": 100,
|
| 666 |
+
"daily_limit_total": 1000
|
| 667 |
+
}
|
| 668 |
+
await storage_adapter.update_credential_state(filename, default_state)
|
| 669 |
+
file_status = default_state
|
| 670 |
+
log.debug(f"为凭证 {filename} 创建了默认状态记录")
|
| 671 |
+
except Exception as e:
|
| 672 |
+
log.warning(f"无法为凭证 {filename} 创建状态记录: {e}")
|
| 673 |
+
# 创建临时状态用于显示
|
| 674 |
+
file_status = {
|
| 675 |
+
"error_codes": [],
|
| 676 |
+
"disabled": False,
|
| 677 |
+
"last_success": time.time(),
|
| 678 |
+
"user_email": None,
|
| 679 |
+
"gemini_2_5_pro_calls": 0,
|
| 680 |
+
"total_calls": 0,
|
| 681 |
+
"next_reset_time": None,
|
| 682 |
+
"daily_limit_gemini_2_5_pro": 100,
|
| 683 |
+
"daily_limit_total": 1000
|
| 684 |
+
}
|
| 685 |
+
|
| 686 |
+
try:
|
| 687 |
+
# 从存储获取凭证数据
|
| 688 |
+
credential_data = await storage_adapter.get_credential(filename)
|
| 689 |
+
if credential_data:
|
| 690 |
+
result = {
|
| 691 |
+
"status": file_status,
|
| 692 |
+
"content": credential_data,
|
| 693 |
+
"filename": os.path.basename(filename),
|
| 694 |
+
"backend_type": backend_type, # 复用backend信息
|
| 695 |
+
"user_email": file_status.get("user_email")
|
| 696 |
+
}
|
| 697 |
+
|
| 698 |
+
# 如果是文件模式,添加文件元数据
|
| 699 |
+
if backend_type == "file" and os.path.exists(filename):
|
| 700 |
+
result.update({
|
| 701 |
+
"size": os.path.getsize(filename),
|
| 702 |
+
"modified_time": os.path.getmtime(filename)
|
| 703 |
+
})
|
| 704 |
+
|
| 705 |
+
return filename, result
|
| 706 |
+
else:
|
| 707 |
+
return filename, {
|
| 708 |
+
"status": file_status,
|
| 709 |
+
"content": None,
|
| 710 |
+
"filename": os.path.basename(filename),
|
| 711 |
+
"error": "凭证数据不存在"
|
| 712 |
+
}
|
| 713 |
+
|
| 714 |
+
except Exception as e:
|
| 715 |
+
log.error(f"读取凭证��件失败 {filename}: {e}")
|
| 716 |
+
return filename, {
|
| 717 |
+
"status": file_status,
|
| 718 |
+
"content": None,
|
| 719 |
+
"filename": os.path.basename(filename),
|
| 720 |
+
"error": str(e)
|
| 721 |
+
}
|
| 722 |
+
|
| 723 |
+
# 并发处理所有凭证数据获取
|
| 724 |
+
log.debug(f"开始并发获取 {len(all_credentials)} 个凭证数据...")
|
| 725 |
+
concurrent_tasks = [process_credential_data(filename) for filename in all_credentials]
|
| 726 |
+
results = await asyncio.gather(*concurrent_tasks, return_exceptions=True)
|
| 727 |
+
|
| 728 |
+
# 组装结果
|
| 729 |
+
creds_info = {}
|
| 730 |
+
for result in results:
|
| 731 |
+
if isinstance(result, Exception):
|
| 732 |
+
log.error(f"处理凭证状态异常: {result}")
|
| 733 |
+
else:
|
| 734 |
+
filename, credential_info = result
|
| 735 |
+
creds_info[filename] = credential_info
|
| 736 |
+
|
| 737 |
+
return JSONResponse(content={"creds": creds_info})
|
| 738 |
+
|
| 739 |
+
except Exception as e:
|
| 740 |
+
log.error(f"获取凭证状态失败: {e}")
|
| 741 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
@router.post("/creds/action")
|
| 745 |
+
async def creds_action(request: CredFileActionRequest, token: str = Depends(verify_token)):
|
| 746 |
+
"""对凭证文件执行操作(启用/禁用/删除)"""
|
| 747 |
+
try:
|
| 748 |
+
await ensure_credential_manager_initialized()
|
| 749 |
+
|
| 750 |
+
log.info(f"Received request: {request}")
|
| 751 |
+
|
| 752 |
+
filename = request.filename
|
| 753 |
+
action = request.action
|
| 754 |
+
|
| 755 |
+
log.info(f"Performing action '{action}' on file: {filename}")
|
| 756 |
+
|
| 757 |
+
# 验证文件名
|
| 758 |
+
if not filename.endswith('.json'):
|
| 759 |
+
log.error(f"Invalid filename: {filename} (not a .json file)")
|
| 760 |
+
raise HTTPException(status_code=400, detail=f"无效的文件名: {filename}")
|
| 761 |
+
|
| 762 |
+
# 获取存储适配器
|
| 763 |
+
storage_adapter = await get_storage_adapter()
|
| 764 |
+
|
| 765 |
+
# 检查凭证是否存在
|
| 766 |
+
credential_data = await storage_adapter.get_credential(filename)
|
| 767 |
+
if not credential_data:
|
| 768 |
+
log.error(f"Credential not found: {filename}")
|
| 769 |
+
raise HTTPException(status_code=404, detail="凭证文件不存在")
|
| 770 |
+
|
| 771 |
+
if action == "enable":
|
| 772 |
+
log.info(f"Web request: ENABLING file {filename}")
|
| 773 |
+
await credential_manager.set_cred_disabled(filename, False)
|
| 774 |
+
log.info(f"Web request: ENABLED file {filename} successfully")
|
| 775 |
+
return JSONResponse(content={"message": f"已启用凭证文件 {os.path.basename(filename)}"})
|
| 776 |
+
|
| 777 |
+
elif action == "disable":
|
| 778 |
+
log.info(f"Web request: DISABLING file {filename}")
|
| 779 |
+
await credential_manager.set_cred_disabled(filename, True)
|
| 780 |
+
log.info(f"Web request: DISABLED file {filename} successfully")
|
| 781 |
+
return JSONResponse(content={"message": f"已禁用凭证文件 {os.path.basename(filename)}"})
|
| 782 |
+
|
| 783 |
+
elif action == "delete":
|
| 784 |
+
try:
|
| 785 |
+
# 使用存储适配器删除凭证
|
| 786 |
+
success = await storage_adapter.delete_credential(filename)
|
| 787 |
+
if success:
|
| 788 |
+
log.info(f"Successfully deleted credential: {filename}")
|
| 789 |
+
return JSONResponse(content={"message": f"已删除凭证文件 {os.path.basename(filename)}"})
|
| 790 |
+
else:
|
| 791 |
+
raise HTTPException(status_code=500, detail="删除凭证失败")
|
| 792 |
+
except Exception as e:
|
| 793 |
+
log.error(f"Error deleting credential {filename}: {e}")
|
| 794 |
+
raise HTTPException(status_code=500, detail=f"删除文件失败: {str(e)}")
|
| 795 |
+
|
| 796 |
+
else:
|
| 797 |
+
raise HTTPException(status_code=400, detail="无效的操作类型")
|
| 798 |
+
|
| 799 |
+
except HTTPException:
|
| 800 |
+
raise
|
| 801 |
+
except Exception as e:
|
| 802 |
+
log.error(f"凭证文件操作失败: {e}")
|
| 803 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
@router.post("/creds/batch-action")
|
| 807 |
+
async def creds_batch_action(request: CredFileBatchActionRequest, token: str = Depends(verify_token)):
|
| 808 |
+
"""批量对凭证文件执行操作(启用/禁用/删除)"""
|
| 809 |
+
try:
|
| 810 |
+
await ensure_credential_manager_initialized()
|
| 811 |
+
|
| 812 |
+
action = request.action
|
| 813 |
+
filenames = request.filenames
|
| 814 |
+
|
| 815 |
+
if not filenames:
|
| 816 |
+
raise HTTPException(status_code=400, detail="文件名列表不能为空")
|
| 817 |
+
|
| 818 |
+
log.info(f"Performing batch action '{action}' on {len(filenames)} files")
|
| 819 |
+
|
| 820 |
+
success_count = 0
|
| 821 |
+
errors = []
|
| 822 |
+
|
| 823 |
+
# 获取存储适配器
|
| 824 |
+
storage_adapter = await get_storage_adapter()
|
| 825 |
+
|
| 826 |
+
for filename in filenames:
|
| 827 |
+
try:
|
| 828 |
+
# 验证文件名安全性
|
| 829 |
+
if not filename.endswith('.json'):
|
| 830 |
+
errors.append(f"{filename}: 无效的文件类型")
|
| 831 |
+
continue
|
| 832 |
+
|
| 833 |
+
# 检查凭证是否存在
|
| 834 |
+
credential_data = await storage_adapter.get_credential(filename)
|
| 835 |
+
if not credential_data:
|
| 836 |
+
errors.append(f"{filename}: 凭证不存在")
|
| 837 |
+
continue
|
| 838 |
+
|
| 839 |
+
# 执行相应操作
|
| 840 |
+
if action == "enable":
|
| 841 |
+
await credential_manager.set_cred_disabled(filename, False)
|
| 842 |
+
success_count += 1
|
| 843 |
+
|
| 844 |
+
elif action == "disable":
|
| 845 |
+
await credential_manager.set_cred_disabled(filename, True)
|
| 846 |
+
success_count += 1
|
| 847 |
+
|
| 848 |
+
elif action == "delete":
|
| 849 |
+
try:
|
| 850 |
+
# 使用存储适配器删除凭证
|
| 851 |
+
delete_success = await storage_adapter.delete_credential(filename)
|
| 852 |
+
if delete_success:
|
| 853 |
+
success_count += 1
|
| 854 |
+
log.info(f"Successfully deleted credential in batch: {filename}")
|
| 855 |
+
else:
|
| 856 |
+
errors.append(f"{filename}: 删除失败")
|
| 857 |
+
continue
|
| 858 |
+
except Exception as e:
|
| 859 |
+
errors.append(f"{filename}: 删除文件失败 - {str(e)}")
|
| 860 |
+
continue
|
| 861 |
+
else:
|
| 862 |
+
errors.append(f"{filename}: 无效的操作类型")
|
| 863 |
+
continue
|
| 864 |
+
|
| 865 |
+
except Exception as e:
|
| 866 |
+
log.error(f"Processing {filename} failed: {e}")
|
| 867 |
+
errors.append(f"{filename}: 处理失败 - {str(e)}")
|
| 868 |
+
continue
|
| 869 |
+
|
| 870 |
+
# 构建返回消息
|
| 871 |
+
result_message = f"批量操作完成:成功处理 {success_count}/{len(filenames)} 个文件"
|
| 872 |
+
if errors:
|
| 873 |
+
result_message += f"\n错误详情:\n" + "\n".join(errors)
|
| 874 |
+
|
| 875 |
+
response_data = {
|
| 876 |
+
"success_count": success_count,
|
| 877 |
+
"total_count": len(filenames),
|
| 878 |
+
"errors": errors,
|
| 879 |
+
"message": result_message
|
| 880 |
+
}
|
| 881 |
+
|
| 882 |
+
return JSONResponse(content=response_data)
|
| 883 |
+
|
| 884 |
+
except HTTPException:
|
| 885 |
+
raise
|
| 886 |
+
except Exception as e:
|
| 887 |
+
log.error(f"批量凭证文件操作失败: {e}")
|
| 888 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
@router.get("/creds/download/{filename}")
|
| 892 |
+
async def download_cred_file(filename: str, token: str = Depends(verify_token)):
|
| 893 |
+
"""下载单个凭证文件"""
|
| 894 |
+
try:
|
| 895 |
+
# 验证文件名安全性
|
| 896 |
+
if not filename.endswith('.json'):
|
| 897 |
+
raise HTTPException(status_code=404, detail="无效的文件名")
|
| 898 |
+
|
| 899 |
+
# 获取存储适配器
|
| 900 |
+
storage_adapter = await get_storage_adapter()
|
| 901 |
+
|
| 902 |
+
# 从存储系统获取凭证数据
|
| 903 |
+
credential_data = await storage_adapter.get_credential(filename)
|
| 904 |
+
if not credential_data:
|
| 905 |
+
raise HTTPException(status_code=404, detail="文件不存在")
|
| 906 |
+
|
| 907 |
+
# 转换为JSON字符串
|
| 908 |
+
content = json.dumps(credential_data, ensure_ascii=False, indent=2)
|
| 909 |
+
|
| 910 |
+
from fastapi.responses import Response
|
| 911 |
+
return Response(
|
| 912 |
+
content=content,
|
| 913 |
+
media_type="application/json",
|
| 914 |
+
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
| 915 |
+
)
|
| 916 |
+
|
| 917 |
+
except HTTPException:
|
| 918 |
+
raise
|
| 919 |
+
except Exception as e:
|
| 920 |
+
log.error(f"下载凭证文件失败: {e}")
|
| 921 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
@router.post("/creds/fetch-email/{filename}")
|
| 925 |
+
async def fetch_user_email(filename: str, token: str = Depends(verify_token)):
|
| 926 |
+
"""获取指定凭证文件的用户邮箱地址"""
|
| 927 |
+
try:
|
| 928 |
+
await ensure_credential_manager_initialized()
|
| 929 |
+
|
| 930 |
+
# 标准化文件名(只保留文件名部分)
|
| 931 |
+
import os
|
| 932 |
+
filename_only = os.path.basename(filename)
|
| 933 |
+
if not filename_only.endswith('.json'):
|
| 934 |
+
raise HTTPException(status_code=404, detail="无效的文件名")
|
| 935 |
+
|
| 936 |
+
# 检查凭证是否存在于存储系统中
|
| 937 |
+
storage_adapter = await get_storage_adapter()
|
| 938 |
+
credential_data = await storage_adapter.get_credential(filename_only)
|
| 939 |
+
if not credential_data:
|
| 940 |
+
raise HTTPException(status_code=404, detail="凭证文件不存在")
|
| 941 |
+
|
| 942 |
+
# 获取用户邮箱(使用凭证名称而不是文件路径)
|
| 943 |
+
email = await credential_manager.get_or_fetch_user_email(filename_only)
|
| 944 |
+
|
| 945 |
+
if email:
|
| 946 |
+
return JSONResponse(content={
|
| 947 |
+
"filename": filename_only,
|
| 948 |
+
"user_email": email,
|
| 949 |
+
"message": "成功获取用户邮箱"
|
| 950 |
+
})
|
| 951 |
+
else:
|
| 952 |
+
return JSONResponse(content={
|
| 953 |
+
"filename": filename_only,
|
| 954 |
+
"user_email": None,
|
| 955 |
+
"message": "无法获取用户邮箱,可能凭证已过期或权限不足"
|
| 956 |
+
}, status_code=400)
|
| 957 |
+
|
| 958 |
+
except HTTPException:
|
| 959 |
+
raise
|
| 960 |
+
except Exception as e:
|
| 961 |
+
log.error(f"获取用户邮箱失败: {e}")
|
| 962 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 963 |
+
|
| 964 |
+
@router.post("/creds/refresh-all-emails")
|
| 965 |
+
async def refresh_all_user_emails(token: str = Depends(verify_token)):
|
| 966 |
+
"""刷新所有凭证文件的用户邮箱地址"""
|
| 967 |
+
try:
|
| 968 |
+
await ensure_credential_manager_initialized()
|
| 969 |
+
|
| 970 |
+
# 获取存储适配器
|
| 971 |
+
storage_adapter = await get_storage_adapter()
|
| 972 |
+
|
| 973 |
+
# 获取所有凭证文件
|
| 974 |
+
credential_filenames = await storage_adapter.list_credentials()
|
| 975 |
+
|
| 976 |
+
results = []
|
| 977 |
+
success_count = 0
|
| 978 |
+
|
| 979 |
+
for filename in credential_filenames:
|
| 980 |
+
try:
|
| 981 |
+
email = await credential_manager.get_or_fetch_user_email(filename)
|
| 982 |
+
if email:
|
| 983 |
+
success_count += 1
|
| 984 |
+
results.append({
|
| 985 |
+
"filename": os.path.basename(filename),
|
| 986 |
+
"user_email": email,
|
| 987 |
+
"success": True
|
| 988 |
+
})
|
| 989 |
+
else:
|
| 990 |
+
results.append({
|
| 991 |
+
"filename": os.path.basename(filename),
|
| 992 |
+
"user_email": None,
|
| 993 |
+
"success": False,
|
| 994 |
+
"error": "无法获取邮箱"
|
| 995 |
+
})
|
| 996 |
+
except Exception as e:
|
| 997 |
+
results.append({
|
| 998 |
+
"filename": os.path.basename(filename),
|
| 999 |
+
"user_email": None,
|
| 1000 |
+
"success": False,
|
| 1001 |
+
"error": str(e)
|
| 1002 |
+
})
|
| 1003 |
+
|
| 1004 |
+
return JSONResponse(content={
|
| 1005 |
+
"success_count": success_count,
|
| 1006 |
+
"total_count": len(credential_filenames),
|
| 1007 |
+
"results": results,
|
| 1008 |
+
"message": f"成功获取 {success_count}/{len(credential_filenames)} 个邮箱地址"
|
| 1009 |
+
})
|
| 1010 |
+
|
| 1011 |
+
except Exception as e:
|
| 1012 |
+
log.error(f"批量获取用户邮箱失败: {e}")
|
| 1013 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1014 |
+
|
| 1015 |
+
@router.get("/creds/download-all")
|
| 1016 |
+
async def download_all_creds(token: str = Depends(verify_token)):
|
| 1017 |
+
"""打包下载所有凭证文件"""
|
| 1018 |
+
try:
|
| 1019 |
+
# 获取存储适配器
|
| 1020 |
+
storage_adapter = await get_storage_adapter()
|
| 1021 |
+
|
| 1022 |
+
# 获取所有凭证文件列表
|
| 1023 |
+
credential_filenames = await storage_adapter.list_credentials()
|
| 1024 |
+
|
| 1025 |
+
if not credential_filenames:
|
| 1026 |
+
raise HTTPException(status_code=404, detail="没有找到凭证文件")
|
| 1027 |
+
|
| 1028 |
+
# 创建内存中的ZIP文件
|
| 1029 |
+
zip_buffer = io.BytesIO()
|
| 1030 |
+
|
| 1031 |
+
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
| 1032 |
+
# 遍历所有凭证文件
|
| 1033 |
+
for filename in credential_filenames:
|
| 1034 |
+
try:
|
| 1035 |
+
credential_data = await storage_adapter.get_credential(filename)
|
| 1036 |
+
if credential_data:
|
| 1037 |
+
# 转换为JSON字符串
|
| 1038 |
+
content = json.dumps(credential_data, ensure_ascii=False, indent=2)
|
| 1039 |
+
|
| 1040 |
+
# 添加到ZIP文件中
|
| 1041 |
+
zip_file.writestr(os.path.basename(filename), content)
|
| 1042 |
+
log.debug(f"已添加到ZIP: {filename}")
|
| 1043 |
+
except Exception as e:
|
| 1044 |
+
log.warning(f"处理凭证文件 {filename} 时出错: {e}")
|
| 1045 |
+
continue
|
| 1046 |
+
|
| 1047 |
+
zip_buffer.seek(0)
|
| 1048 |
+
return Response(
|
| 1049 |
+
content=zip_buffer.getvalue(),
|
| 1050 |
+
media_type="application/zip",
|
| 1051 |
+
headers={"Content-Disposition": "attachment; filename=credentials.zip"}
|
| 1052 |
+
)
|
| 1053 |
+
|
| 1054 |
+
except Exception as e:
|
| 1055 |
+
log.error(f"打包下载失败: {e}")
|
| 1056 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
@router.get("/config/get")
|
| 1060 |
+
async def get_config(token: str = Depends(verify_token)):
|
| 1061 |
+
"""获取当前配置"""
|
| 1062 |
+
try:
|
| 1063 |
+
await ensure_credential_manager_initialized()
|
| 1064 |
+
|
| 1065 |
+
# 导入配置相关模块
|
| 1066 |
+
|
| 1067 |
+
# 读取当前配置(包括环境变量和TOML文件中的配置)
|
| 1068 |
+
current_config = {}
|
| 1069 |
+
env_locked = []
|
| 1070 |
+
|
| 1071 |
+
# 基础配置
|
| 1072 |
+
current_config["code_assist_endpoint"] = await config.get_code_assist_endpoint()
|
| 1073 |
+
current_config["credentials_dir"] = await config.get_credentials_dir()
|
| 1074 |
+
current_config["proxy"] = await config.get_proxy_config() or ""
|
| 1075 |
+
|
| 1076 |
+
# 代理端点配置
|
| 1077 |
+
current_config["oauth_proxy_url"] = await config.get_oauth_proxy_url()
|
| 1078 |
+
current_config["googleapis_proxy_url"] = await config.get_googleapis_proxy_url()
|
| 1079 |
+
current_config["resource_manager_api_url"] = await config.get_resource_manager_api_url()
|
| 1080 |
+
current_config["service_usage_api_url"] = await config.get_service_usage_api_url()
|
| 1081 |
+
|
| 1082 |
+
# 检查环境变量锁定状态
|
| 1083 |
+
if os.getenv("CODE_ASSIST_ENDPOINT"):
|
| 1084 |
+
env_locked.append("code_assist_endpoint")
|
| 1085 |
+
if os.getenv("CREDENTIALS_DIR"):
|
| 1086 |
+
env_locked.append("credentials_dir")
|
| 1087 |
+
if os.getenv("PROXY"):
|
| 1088 |
+
env_locked.append("proxy")
|
| 1089 |
+
if os.getenv("OAUTH_PROXY_URL"):
|
| 1090 |
+
env_locked.append("oauth_proxy_url")
|
| 1091 |
+
if os.getenv("GOOGLEAPIS_PROXY_URL"):
|
| 1092 |
+
env_locked.append("googleapis_proxy_url")
|
| 1093 |
+
if os.getenv("RESOURCE_MANAGER_API_URL"):
|
| 1094 |
+
env_locked.append("resource_manager_api_url")
|
| 1095 |
+
if os.getenv("SERVICE_USAGE_API_URL"):
|
| 1096 |
+
env_locked.append("service_usage_api_url")
|
| 1097 |
+
|
| 1098 |
+
# 自动封禁配置
|
| 1099 |
+
current_config["auto_ban_enabled"] = await config.get_auto_ban_enabled()
|
| 1100 |
+
current_config["auto_ban_error_codes"] = await config.get_auto_ban_error_codes()
|
| 1101 |
+
|
| 1102 |
+
# 检查环境变量锁定状态
|
| 1103 |
+
if os.getenv("AUTO_BAN"):
|
| 1104 |
+
env_locked.append("auto_ban_enabled")
|
| 1105 |
+
|
| 1106 |
+
# 从存储系统读取配置
|
| 1107 |
+
storage_adapter = await get_storage_adapter()
|
| 1108 |
+
storage_config = await storage_adapter.get_all_config()
|
| 1109 |
+
|
| 1110 |
+
# 合并存储系统配置(不覆盖环境变量)
|
| 1111 |
+
for key, value in storage_config.items():
|
| 1112 |
+
if key not in env_locked:
|
| 1113 |
+
current_config[key] = value
|
| 1114 |
+
|
| 1115 |
+
# 性能配置
|
| 1116 |
+
current_config["calls_per_rotation"] = await config.get_calls_per_rotation()
|
| 1117 |
+
|
| 1118 |
+
# 429重试配置
|
| 1119 |
+
current_config["retry_429_max_retries"] = await config.get_retry_429_max_retries()
|
| 1120 |
+
current_config["retry_429_enabled"] = await config.get_retry_429_enabled()
|
| 1121 |
+
current_config["retry_429_interval"] = await config.get_retry_429_interval()
|
| 1122 |
+
|
| 1123 |
+
|
| 1124 |
+
# 抗截断配置
|
| 1125 |
+
current_config["anti_truncation_max_attempts"] = await config.get_anti_truncation_max_attempts()
|
| 1126 |
+
|
| 1127 |
+
# 兼容性配置
|
| 1128 |
+
current_config["compatibility_mode_enabled"] = await config.get_compatibility_mode_enabled()
|
| 1129 |
+
|
| 1130 |
+
# 服务器配置
|
| 1131 |
+
current_config["host"] = await config.get_server_host()
|
| 1132 |
+
current_config["port"] = await config.get_server_port()
|
| 1133 |
+
current_config["api_password"] = await config.get_api_password()
|
| 1134 |
+
current_config["panel_password"] = await config.get_panel_password()
|
| 1135 |
+
current_config["password"] = await config.get_server_password()
|
| 1136 |
+
|
| 1137 |
+
# 检查其他环境变量锁定状态
|
| 1138 |
+
if os.getenv("RETRY_429_MAX_RETRIES"):
|
| 1139 |
+
env_locked.append("retry_429_max_retries")
|
| 1140 |
+
if os.getenv("RETRY_429_ENABLED"):
|
| 1141 |
+
env_locked.append("retry_429_enabled")
|
| 1142 |
+
if os.getenv("RETRY_429_INTERVAL"):
|
| 1143 |
+
env_locked.append("retry_429_interval")
|
| 1144 |
+
if os.getenv("ANTI_TRUNCATION_MAX_ATTEMPTS"):
|
| 1145 |
+
env_locked.append("anti_truncation_max_attempts")
|
| 1146 |
+
if os.getenv("COMPATIBILITY_MODE"):
|
| 1147 |
+
env_locked.append("compatibility_mode_enabled")
|
| 1148 |
+
if os.getenv("HOST"):
|
| 1149 |
+
env_locked.append("host")
|
| 1150 |
+
if os.getenv("PORT"):
|
| 1151 |
+
env_locked.append("port")
|
| 1152 |
+
if os.getenv("API_PASSWORD"):
|
| 1153 |
+
env_locked.append("api_password")
|
| 1154 |
+
if os.getenv("PANEL_PASSWORD"):
|
| 1155 |
+
env_locked.append("panel_password")
|
| 1156 |
+
if os.getenv("PASSWORD"):
|
| 1157 |
+
env_locked.append("password")
|
| 1158 |
+
|
| 1159 |
+
return JSONResponse(content={
|
| 1160 |
+
"config": current_config,
|
| 1161 |
+
"env_locked": env_locked
|
| 1162 |
+
})
|
| 1163 |
+
|
| 1164 |
+
except Exception as e:
|
| 1165 |
+
log.error(f"获取配置失败: {e}")
|
| 1166 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1167 |
+
|
| 1168 |
+
|
| 1169 |
+
@router.post("/config/save")
|
| 1170 |
+
async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_token)):
|
| 1171 |
+
"""保存配置到TOML文件"""
|
| 1172 |
+
try:
|
| 1173 |
+
await ensure_credential_manager_initialized()
|
| 1174 |
+
new_config = request.config
|
| 1175 |
+
|
| 1176 |
+
log.debug(f"收到的配置数据: {list(new_config.keys())}")
|
| 1177 |
+
log.debug(f"收到的password值: {new_config.get('password', 'NOT_FOUND')}")
|
| 1178 |
+
|
| 1179 |
+
# 验证配置项
|
| 1180 |
+
if "calls_per_rotation" in new_config:
|
| 1181 |
+
if not isinstance(new_config["calls_per_rotation"], int) or new_config["calls_per_rotation"] < 1:
|
| 1182 |
+
raise HTTPException(status_code=400, detail="凭证轮换调用次数必须是大于0的整数")
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
if "retry_429_max_retries" in new_config:
|
| 1186 |
+
if not isinstance(new_config["retry_429_max_retries"], int) or new_config["retry_429_max_retries"] < 0:
|
| 1187 |
+
raise HTTPException(status_code=400, detail="最大429重试次数必须是大于等于0的整数")
|
| 1188 |
+
|
| 1189 |
+
if "retry_429_enabled" in new_config:
|
| 1190 |
+
if not isinstance(new_config["retry_429_enabled"], bool):
|
| 1191 |
+
raise HTTPException(status_code=400, detail="429重试开关必须是布尔值")
|
| 1192 |
+
|
| 1193 |
+
# 验证新的配置项
|
| 1194 |
+
if "retry_429_interval" in new_config:
|
| 1195 |
+
try:
|
| 1196 |
+
interval = float(new_config["retry_429_interval"])
|
| 1197 |
+
if interval < 0.01 or interval > 10:
|
| 1198 |
+
raise HTTPException(status_code=400, detail="429重试间隔必须在0.01-10秒之间")
|
| 1199 |
+
except (ValueError, TypeError):
|
| 1200 |
+
raise HTTPException(status_code=400, detail="429重试间隔必须是有效的数字")
|
| 1201 |
+
|
| 1202 |
+
|
| 1203 |
+
if "anti_truncation_max_attempts" in new_config:
|
| 1204 |
+
if not isinstance(new_config["anti_truncation_max_attempts"], int) or new_config["anti_truncation_max_attempts"] < 1 or new_config["anti_truncation_max_attempts"] > 10:
|
| 1205 |
+
raise HTTPException(status_code=400, detail="抗截断最大重试次数必须是1-10之间的整数")
|
| 1206 |
+
|
| 1207 |
+
if "compatibility_mode_enabled" in new_config:
|
| 1208 |
+
if not isinstance(new_config["compatibility_mode_enabled"], bool):
|
| 1209 |
+
raise HTTPException(status_code=400, detail="兼容性模式开关必须是布尔值")
|
| 1210 |
+
|
| 1211 |
+
# 验证服务器配置
|
| 1212 |
+
if "host" in new_config:
|
| 1213 |
+
if not isinstance(new_config["host"], str) or not new_config["host"].strip():
|
| 1214 |
+
raise HTTPException(status_code=400, detail="服务器主机地址不能为空")
|
| 1215 |
+
|
| 1216 |
+
if "port" in new_config:
|
| 1217 |
+
if not isinstance(new_config["port"], int) or new_config["port"] < 1 or new_config["port"] > 65535:
|
| 1218 |
+
raise HTTPException(status_code=400, detail="端口号必须是1-65535之间的整数")
|
| 1219 |
+
|
| 1220 |
+
if "api_password" in new_config:
|
| 1221 |
+
if not isinstance(new_config["api_password"], str):
|
| 1222 |
+
raise HTTPException(status_code=400, detail="API访问密码必须是字符串")
|
| 1223 |
+
|
| 1224 |
+
if "panel_password" in new_config:
|
| 1225 |
+
if not isinstance(new_config["panel_password"], str):
|
| 1226 |
+
raise HTTPException(status_code=400, detail="控制面板密码必须是字符串")
|
| 1227 |
+
|
| 1228 |
+
if "password" in new_config:
|
| 1229 |
+
if not isinstance(new_config["password"], str):
|
| 1230 |
+
raise HTTPException(status_code=400, detail="访问密码必须是字符串")
|
| 1231 |
+
|
| 1232 |
+
# 读取现有的配置文件
|
| 1233 |
+
credentials_dir = await config.get_credentials_dir()
|
| 1234 |
+
config_file = os.path.join(credentials_dir, "config.toml")
|
| 1235 |
+
existing_config = {}
|
| 1236 |
+
|
| 1237 |
+
try:
|
| 1238 |
+
if os.path.exists(config_file):
|
| 1239 |
+
with open(config_file, "r", encoding="utf-8") as f:
|
| 1240 |
+
existing_config = toml.load(f)
|
| 1241 |
+
except Exception as e:
|
| 1242 |
+
log.warning(f"读取现有配置文件失败: {e}")
|
| 1243 |
+
|
| 1244 |
+
# 只更新不被环境变量锁定的配置项
|
| 1245 |
+
env_locked_keys = set()
|
| 1246 |
+
if os.getenv("CODE_ASSIST_ENDPOINT"):
|
| 1247 |
+
env_locked_keys.add("code_assist_endpoint")
|
| 1248 |
+
if os.getenv("CREDENTIALS_DIR"):
|
| 1249 |
+
env_locked_keys.add("credentials_dir")
|
| 1250 |
+
if os.getenv("PROXY"):
|
| 1251 |
+
env_locked_keys.add("proxy")
|
| 1252 |
+
if os.getenv("OAUTH_PROXY_URL"):
|
| 1253 |
+
env_locked_keys.add("oauth_proxy_url")
|
| 1254 |
+
if os.getenv("GOOGLEAPIS_PROXY_URL"):
|
| 1255 |
+
env_locked_keys.add("googleapis_proxy_url")
|
| 1256 |
+
if os.getenv("AUTO_BAN"):
|
| 1257 |
+
env_locked_keys.add("auto_ban_enabled")
|
| 1258 |
+
if os.getenv("RETRY_429_MAX_RETRIES"):
|
| 1259 |
+
env_locked_keys.add("retry_429_max_retries")
|
| 1260 |
+
if os.getenv("RETRY_429_ENABLED"):
|
| 1261 |
+
env_locked_keys.add("retry_429_enabled")
|
| 1262 |
+
if os.getenv("RETRY_429_INTERVAL"):
|
| 1263 |
+
env_locked_keys.add("retry_429_interval")
|
| 1264 |
+
if os.getenv("ANTI_TRUNCATION_MAX_ATTEMPTS"):
|
| 1265 |
+
env_locked_keys.add("anti_truncation_max_attempts")
|
| 1266 |
+
if os.getenv("COMPATIBILITY_MODE"):
|
| 1267 |
+
env_locked_keys.add("compatibility_mode_enabled")
|
| 1268 |
+
if os.getenv("HOST"):
|
| 1269 |
+
env_locked_keys.add("host")
|
| 1270 |
+
if os.getenv("PORT"):
|
| 1271 |
+
env_locked_keys.add("port")
|
| 1272 |
+
if os.getenv("API_PASSWORD"):
|
| 1273 |
+
env_locked_keys.add("api_password")
|
| 1274 |
+
if os.getenv("PANEL_PASSWORD"):
|
| 1275 |
+
env_locked_keys.add("panel_password")
|
| 1276 |
+
if os.getenv("PASSWORD"):
|
| 1277 |
+
env_locked_keys.add("password")
|
| 1278 |
+
|
| 1279 |
+
for key, value in new_config.items():
|
| 1280 |
+
if key not in env_locked_keys:
|
| 1281 |
+
existing_config[key] = value
|
| 1282 |
+
if key == 'password':
|
| 1283 |
+
log.debug(f"设置password字段为: {value}")
|
| 1284 |
+
elif key == 'api_password':
|
| 1285 |
+
log.debug(f"设置api_password字段为: {value}")
|
| 1286 |
+
elif key == 'panel_password':
|
| 1287 |
+
log.debug(f"设置panel_password字段为: {value}")
|
| 1288 |
+
log.debug(f"最终保存的existing_config中password = {existing_config.get('password', 'NOT_FOUND')}")
|
| 1289 |
+
|
| 1290 |
+
# 直接使用存储适配器保存配置
|
| 1291 |
+
storage_adapter = await get_storage_adapter()
|
| 1292 |
+
for key, value in existing_config.items():
|
| 1293 |
+
await storage_adapter.set_config(key, value)
|
| 1294 |
+
|
| 1295 |
+
# 验证保存后的结果
|
| 1296 |
+
test_api_password = await config.get_api_password()
|
| 1297 |
+
test_panel_password = await config.get_panel_password()
|
| 1298 |
+
test_password = await config.get_server_password()
|
| 1299 |
+
log.debug(f"保存后立即读取的API密码: {test_api_password}")
|
| 1300 |
+
log.debug(f"保存后立即读取的面板密码: {test_panel_password}")
|
| 1301 |
+
log.debug(f"保存后立即读取的通用密码: {test_password}")
|
| 1302 |
+
|
| 1303 |
+
# 热更新配置到内存中的模块(如果可能)
|
| 1304 |
+
hot_updated = [] # 记录成功热更新的配置项
|
| 1305 |
+
restart_required = [] # 记录需要重启的配置项
|
| 1306 |
+
|
| 1307 |
+
# 支持热更新的配置项:
|
| 1308 |
+
# - calls_per_rotation: 凭证轮换调用次数
|
| 1309 |
+
# - proxy: 网络配置
|
| 1310 |
+
# - log_level: 日志级别
|
| 1311 |
+
# - auto_ban_enabled, auto_ban_error_codes: 自动封禁配置
|
| 1312 |
+
# - retry_429_enabled, retry_429_max_retries, retry_429_interval: 429重试配置
|
| 1313 |
+
# - anti_truncation_max_attempts: 抗截断配置
|
| 1314 |
+
# - compatibility_mode_enabled: 兼容性模式
|
| 1315 |
+
# - api_password, panel_password, password: 访问密码
|
| 1316 |
+
#
|
| 1317 |
+
# 需要重启的配置项:
|
| 1318 |
+
# - host, port: 服务器地址和端口
|
| 1319 |
+
# - log_file: 日志文件路径
|
| 1320 |
+
|
| 1321 |
+
try:
|
| 1322 |
+
# save_config_to_toml已经更新了缓存,不需要reload
|
| 1323 |
+
|
| 1324 |
+
# 1. credential_manager配置通过config模块动态获取,无需手动更新
|
| 1325 |
+
if "calls_per_rotation" in new_config and "calls_per_rotation" not in env_locked_keys:
|
| 1326 |
+
# 新的credential_manager会通过get_calls_per_rotation()动态获取最新配置
|
| 1327 |
+
hot_updated.append("calls_per_rotation")
|
| 1328 |
+
|
| 1329 |
+
# 2. 代理配置(部分热更新)
|
| 1330 |
+
if "proxy" in new_config and "proxy" not in env_locked_keys:
|
| 1331 |
+
hot_updated.append("proxy")
|
| 1332 |
+
|
| 1333 |
+
# 代理端点配置(可热更新)
|
| 1334 |
+
proxy_endpoint_configs = ["oauth_proxy_url", "googleapis_proxy_url"]
|
| 1335 |
+
for config_key in proxy_endpoint_configs:
|
| 1336 |
+
if config_key in new_config and config_key not in env_locked_keys:
|
| 1337 |
+
hot_updated.append(config_key)
|
| 1338 |
+
|
| 1339 |
+
|
| 1340 |
+
# 4. 其他可热更新的配置项
|
| 1341 |
+
hot_updatable_configs = [
|
| 1342 |
+
"auto_ban_enabled", "auto_ban_error_codes",
|
| 1343 |
+
"retry_429_enabled", "retry_429_max_retries", "retry_429_interval",
|
| 1344 |
+
"anti_truncation_max_attempts", "compatibility_mode_enabled"
|
| 1345 |
+
]
|
| 1346 |
+
|
| 1347 |
+
for config_key in hot_updatable_configs:
|
| 1348 |
+
if config_key in new_config and config_key not in env_locked_keys:
|
| 1349 |
+
hot_updated.append(config_key)
|
| 1350 |
+
|
| 1351 |
+
# 4. 需要重启的配置项
|
| 1352 |
+
restart_required_configs = ["host", "port"]
|
| 1353 |
+
for config_key in restart_required_configs:
|
| 1354 |
+
if config_key in new_config and config_key not in env_locked_keys:
|
| 1355 |
+
restart_required.append(config_key)
|
| 1356 |
+
|
| 1357 |
+
# 5. 密码配置(立即生效)
|
| 1358 |
+
password_configs = ["api_password", "panel_password", "password"]
|
| 1359 |
+
for config_key in password_configs:
|
| 1360 |
+
if config_key in new_config and config_key not in env_locked_keys:
|
| 1361 |
+
hot_updated.append(config_key)
|
| 1362 |
+
|
| 1363 |
+
except Exception as e:
|
| 1364 |
+
log.warning(f"热更新配置失败: {e}")
|
| 1365 |
+
|
| 1366 |
+
# 构建响应消息
|
| 1367 |
+
response_data = {
|
| 1368 |
+
"message": "配置保存成功",
|
| 1369 |
+
"saved_config": {k: v for k, v in new_config.items() if k not in env_locked_keys}
|
| 1370 |
+
}
|
| 1371 |
+
|
| 1372 |
+
# 添加热更新状态信息
|
| 1373 |
+
if hot_updated:
|
| 1374 |
+
response_data["hot_updated"] = hot_updated
|
| 1375 |
+
|
| 1376 |
+
if restart_required:
|
| 1377 |
+
response_data["restart_required"] = restart_required
|
| 1378 |
+
response_data["restart_notice"] = f"以下配置项需要重启服务器才能生效: {', '.join(restart_required)}"
|
| 1379 |
+
|
| 1380 |
+
return JSONResponse(content=response_data)
|
| 1381 |
+
|
| 1382 |
+
except HTTPException:
|
| 1383 |
+
raise
|
| 1384 |
+
except Exception as e:
|
| 1385 |
+
log.error(f"保存配置失败: {e}")
|
| 1386 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1387 |
+
|
| 1388 |
+
|
| 1389 |
+
@router.post("/auth/load-env-creds")
|
| 1390 |
+
async def load_env_credentials(token: str = Depends(verify_token)):
|
| 1391 |
+
"""从环境变量加载凭证文件"""
|
| 1392 |
+
try:
|
| 1393 |
+
result = await load_credentials_from_env()
|
| 1394 |
+
|
| 1395 |
+
if result['loaded_count'] > 0:
|
| 1396 |
+
return JSONResponse(content={
|
| 1397 |
+
"loaded_count": result['loaded_count'],
|
| 1398 |
+
"total_count": result['total_count'],
|
| 1399 |
+
"results": result['results'],
|
| 1400 |
+
"message": result['message']
|
| 1401 |
+
})
|
| 1402 |
+
else:
|
| 1403 |
+
return JSONResponse(content={
|
| 1404 |
+
"loaded_count": 0,
|
| 1405 |
+
"total_count": result['total_count'],
|
| 1406 |
+
"message": result['message'],
|
| 1407 |
+
"results": result['results']
|
| 1408 |
+
})
|
| 1409 |
+
|
| 1410 |
+
except Exception as e:
|
| 1411 |
+
log.error(f"从环境变量加载凭证失败: {e}")
|
| 1412 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1413 |
+
|
| 1414 |
+
|
| 1415 |
+
@router.delete("/auth/env-creds")
|
| 1416 |
+
async def clear_env_creds(token: str = Depends(verify_token)):
|
| 1417 |
+
"""清除所有从环境变量导入的凭证文件"""
|
| 1418 |
+
try:
|
| 1419 |
+
result = await clear_env_credentials()
|
| 1420 |
+
|
| 1421 |
+
if 'error' in result:
|
| 1422 |
+
raise HTTPException(status_code=500, detail=result['error'])
|
| 1423 |
+
|
| 1424 |
+
return JSONResponse(content={
|
| 1425 |
+
"deleted_count": result['deleted_count'],
|
| 1426 |
+
"deleted_files": result.get('deleted_files', []),
|
| 1427 |
+
"message": result['message']
|
| 1428 |
+
})
|
| 1429 |
+
|
| 1430 |
+
except HTTPException:
|
| 1431 |
+
raise
|
| 1432 |
+
except Exception as e:
|
| 1433 |
+
log.error(f"清除环境变量凭证失败: {e}")
|
| 1434 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1435 |
+
|
| 1436 |
+
|
| 1437 |
+
@router.get("/auth/env-creds-status")
|
| 1438 |
+
async def get_env_creds_status(token: str = Depends(verify_token)):
|
| 1439 |
+
"""获取环境变量凭证状态"""
|
| 1440 |
+
try:
|
| 1441 |
+
# 检查有哪些环境变量可用
|
| 1442 |
+
available_env_vars = {key: "***已设置***" for key, value in os.environ.items()
|
| 1443 |
+
if key.startswith('GCLI_CREDS_') and value.strip()}
|
| 1444 |
+
|
| 1445 |
+
# 检查自动加载设置
|
| 1446 |
+
auto_load_enabled = await config.get_auto_load_env_creds()
|
| 1447 |
+
|
| 1448 |
+
# 统计已存在的环境变量凭证文件
|
| 1449 |
+
storage_adapter = await get_storage_adapter()
|
| 1450 |
+
all_credentials = await storage_adapter.list_credentials()
|
| 1451 |
+
existing_env_files = [
|
| 1452 |
+
filename for filename in all_credentials
|
| 1453 |
+
if filename.startswith('env-') and filename.endswith('.json')
|
| 1454 |
+
]
|
| 1455 |
+
|
| 1456 |
+
return JSONResponse(content={
|
| 1457 |
+
"available_env_vars": available_env_vars,
|
| 1458 |
+
"auto_load_enabled": auto_load_enabled,
|
| 1459 |
+
"existing_env_files_count": len(existing_env_files),
|
| 1460 |
+
"existing_env_files": existing_env_files
|
| 1461 |
+
})
|
| 1462 |
+
|
| 1463 |
+
except Exception as e:
|
| 1464 |
+
log.error(f"获取环境变量凭证状态失败: {e}")
|
| 1465 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1466 |
+
|
| 1467 |
+
|
| 1468 |
+
# =============================================================================
|
| 1469 |
+
# 实时日志WebSocket (Real-time Logs WebSocket)
|
| 1470 |
+
# =============================================================================
|
| 1471 |
+
|
| 1472 |
+
@router.post("/auth/logs/clear")
|
| 1473 |
+
async def clear_logs(token: str = Depends(verify_token)):
|
| 1474 |
+
"""清空日志文件"""
|
| 1475 |
+
try:
|
| 1476 |
+
# 直接使用环境变量获取日志文件路径
|
| 1477 |
+
log_file_path = os.getenv('LOG_FILE', 'log.txt')
|
| 1478 |
+
|
| 1479 |
+
# 检查日志文件是否存在
|
| 1480 |
+
if os.path.exists(log_file_path):
|
| 1481 |
+
try:
|
| 1482 |
+
# 清空文件内容(保留文件),确保以UTF-8编码写入
|
| 1483 |
+
with open(log_file_path, 'w', encoding='utf-8', newline='') as f:
|
| 1484 |
+
f.write('')
|
| 1485 |
+
f.flush() # 强制刷新到磁盘
|
| 1486 |
+
log.info(f"日志文件已清空: {log_file_path}")
|
| 1487 |
+
|
| 1488 |
+
# 通知所有WebSocket连接日志已清空
|
| 1489 |
+
await manager.broadcast("--- 日志文件已清空 ---")
|
| 1490 |
+
|
| 1491 |
+
return JSONResponse(content={"message": f"日志文件已清空: {os.path.basename(log_file_path)}"})
|
| 1492 |
+
except Exception as e:
|
| 1493 |
+
log.error(f"清空日志文件失败: {e}")
|
| 1494 |
+
raise HTTPException(status_code=500, detail=f"清空日志文件失败: {str(e)}")
|
| 1495 |
+
else:
|
| 1496 |
+
return JSONResponse(content={"message": "日志文件不存在"})
|
| 1497 |
+
|
| 1498 |
+
except Exception as e:
|
| 1499 |
+
log.error(f"清空日志文件失败: {e}")
|
| 1500 |
+
raise HTTPException(status_code=500, detail=f"清空日志文件失败: {str(e)}")
|
| 1501 |
+
|
| 1502 |
+
@router.get("/auth/logs/download")
|
| 1503 |
+
async def download_logs(token: str = Depends(verify_token)):
|
| 1504 |
+
"""下载日志文件"""
|
| 1505 |
+
try:
|
| 1506 |
+
# 直接使用环境变量获取日志文件路径
|
| 1507 |
+
log_file_path = os.getenv('LOG_FILE', 'log.txt')
|
| 1508 |
+
|
| 1509 |
+
# 检查日志文件是否存在
|
| 1510 |
+
if not os.path.exists(log_file_path):
|
| 1511 |
+
raise HTTPException(status_code=404, detail="日志文件不存在")
|
| 1512 |
+
|
| 1513 |
+
# 检查文件是否为空
|
| 1514 |
+
file_size = os.path.getsize(log_file_path)
|
| 1515 |
+
if file_size == 0:
|
| 1516 |
+
raise HTTPException(status_code=404, detail="日志文件为空")
|
| 1517 |
+
|
| 1518 |
+
# 生成文件名(包含时间戳)
|
| 1519 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 1520 |
+
filename = f"gcli2api_logs_{timestamp}.txt"
|
| 1521 |
+
|
| 1522 |
+
log.info(f"下载日志文件: {log_file_path}")
|
| 1523 |
+
|
| 1524 |
+
return FileResponse(
|
| 1525 |
+
path=log_file_path,
|
| 1526 |
+
filename=filename,
|
| 1527 |
+
media_type='text/plain',
|
| 1528 |
+
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
| 1529 |
+
)
|
| 1530 |
+
|
| 1531 |
+
except HTTPException:
|
| 1532 |
+
raise
|
| 1533 |
+
except Exception as e:
|
| 1534 |
+
log.error(f"下载日志文件失败: {e}")
|
| 1535 |
+
raise HTTPException(status_code=500, detail=f"下载日志文件失败: {str(e)}")
|
| 1536 |
+
|
| 1537 |
+
@router.websocket("/auth/logs/stream")
|
| 1538 |
+
async def websocket_logs(websocket: WebSocket):
|
| 1539 |
+
"""WebSocket端点,用于实时日志流"""
|
| 1540 |
+
# 检查连接数限制
|
| 1541 |
+
if not await manager.connect(websocket):
|
| 1542 |
+
return
|
| 1543 |
+
|
| 1544 |
+
try:
|
| 1545 |
+
# 直接使用环境变量获取日志文件路径
|
| 1546 |
+
log_file_path = os.getenv('LOG_FILE', 'log.txt')
|
| 1547 |
+
|
| 1548 |
+
# 发送初始日志(限制为最后50行,减少内存占用)
|
| 1549 |
+
if os.path.exists(log_file_path):
|
| 1550 |
+
try:
|
| 1551 |
+
with open(log_file_path, "r", encoding="utf-8") as f:
|
| 1552 |
+
lines = f.readlines()
|
| 1553 |
+
# 只发送最后50行,减少初始内存消耗
|
| 1554 |
+
for line in lines[-50:]:
|
| 1555 |
+
if line.strip():
|
| 1556 |
+
await websocket.send_text(line.strip())
|
| 1557 |
+
except Exception as e:
|
| 1558 |
+
await websocket.send_text(f"Error reading log file: {e}")
|
| 1559 |
+
|
| 1560 |
+
# 监控日志文件变化
|
| 1561 |
+
last_size = os.path.getsize(log_file_path) if os.path.exists(log_file_path) else 0
|
| 1562 |
+
max_read_size = 8192 # 限制单次读取大小为8KB,防止大量日志造成内存激增
|
| 1563 |
+
check_interval = 2 # 增加检查间隔,减少CPU和I/O开销
|
| 1564 |
+
|
| 1565 |
+
while websocket.client_state == WebSocketState.CONNECTED:
|
| 1566 |
+
await asyncio.sleep(check_interval)
|
| 1567 |
+
|
| 1568 |
+
if os.path.exists(log_file_path):
|
| 1569 |
+
current_size = os.path.getsize(log_file_path)
|
| 1570 |
+
if current_size > last_size:
|
| 1571 |
+
# 限制读取大小,防止单次读取过多内容
|
| 1572 |
+
read_size = min(current_size - last_size, max_read_size)
|
| 1573 |
+
|
| 1574 |
+
try:
|
| 1575 |
+
with open(log_file_path, "r", encoding="utf-8", errors="replace") as f:
|
| 1576 |
+
f.seek(last_size)
|
| 1577 |
+
new_content = f.read(read_size)
|
| 1578 |
+
|
| 1579 |
+
# 处理编码错误的情况
|
| 1580 |
+
if not new_content:
|
| 1581 |
+
last_size = current_size
|
| 1582 |
+
continue
|
| 1583 |
+
|
| 1584 |
+
# 分行发送,避免发送不完整的行
|
| 1585 |
+
lines = new_content.splitlines(keepends=True)
|
| 1586 |
+
if lines:
|
| 1587 |
+
# 如果最后一行没有换行符,保留到下次处理
|
| 1588 |
+
if not lines[-1].endswith('\n') and len(lines) > 1:
|
| 1589 |
+
# 除了最后一行,其他都发送
|
| 1590 |
+
for line in lines[:-1]:
|
| 1591 |
+
if line.strip():
|
| 1592 |
+
await websocket.send_text(line.rstrip())
|
| 1593 |
+
# 更新位置,但要退回最后一行的字节数
|
| 1594 |
+
last_size += len(new_content.encode('utf-8')) - len(lines[-1].encode('utf-8'))
|
| 1595 |
+
else:
|
| 1596 |
+
# 所有行都发送
|
| 1597 |
+
for line in lines:
|
| 1598 |
+
if line.strip():
|
| 1599 |
+
await websocket.send_text(line.rstrip())
|
| 1600 |
+
last_size += len(new_content.encode('utf-8'))
|
| 1601 |
+
except UnicodeDecodeError as e:
|
| 1602 |
+
# 遇到编码错误时,跳过这部分内容
|
| 1603 |
+
log.warning(f"WebSocket日志读取编码错误: {e}, 跳过部分内容")
|
| 1604 |
+
last_size = current_size
|
| 1605 |
+
except Exception as e:
|
| 1606 |
+
await websocket.send_text(f"Error reading new content: {e}")
|
| 1607 |
+
# 发生其他错误时,重置文件位置
|
| 1608 |
+
last_size = current_size
|
| 1609 |
+
|
| 1610 |
+
# 如果文件被截断(如清空日志),重置位置
|
| 1611 |
+
elif current_size < last_size:
|
| 1612 |
+
last_size = 0
|
| 1613 |
+
await websocket.send_text("--- 日志已清空 ---")
|
| 1614 |
+
|
| 1615 |
+
except WebSocketDisconnect:
|
| 1616 |
+
pass
|
| 1617 |
+
except Exception as e:
|
| 1618 |
+
log.error(f"WebSocket logs error: {e}")
|
| 1619 |
+
finally:
|
| 1620 |
+
manager.disconnect(websocket)
|
| 1621 |
+
|
| 1622 |
+
|
| 1623 |
+
# =============================================================================
|
| 1624 |
+
# Usage Statistics API (使用统计API)
|
| 1625 |
+
# =============================================================================
|
| 1626 |
+
|
| 1627 |
+
@router.get("/usage/stats")
|
| 1628 |
+
async def get_usage_statistics(filename: Optional[str] = None, token: str = Depends(verify_token)):
|
| 1629 |
+
"""
|
| 1630 |
+
获取使用统计信息
|
| 1631 |
+
|
| 1632 |
+
Args:
|
| 1633 |
+
filename: 可选,指定凭证文��名。如果不提供则返回所有文件的统计
|
| 1634 |
+
|
| 1635 |
+
Returns:
|
| 1636 |
+
usage statistics for the specified file or all files
|
| 1637 |
+
"""
|
| 1638 |
+
try:
|
| 1639 |
+
stats = await get_usage_stats(filename)
|
| 1640 |
+
return JSONResponse(content={
|
| 1641 |
+
"success": True,
|
| 1642 |
+
"data": stats
|
| 1643 |
+
})
|
| 1644 |
+
except Exception as e:
|
| 1645 |
+
log.error(f"获取使用统计失败: {e}")
|
| 1646 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1647 |
+
|
| 1648 |
+
|
| 1649 |
+
@router.get("/usage/aggregated")
|
| 1650 |
+
async def get_aggregated_usage_statistics(token: str = Depends(verify_token)):
|
| 1651 |
+
"""
|
| 1652 |
+
获取聚合使用统计信息
|
| 1653 |
+
|
| 1654 |
+
Returns:
|
| 1655 |
+
Aggregated statistics across all credential files
|
| 1656 |
+
"""
|
| 1657 |
+
try:
|
| 1658 |
+
stats = await get_aggregated_stats()
|
| 1659 |
+
return JSONResponse(content={
|
| 1660 |
+
"success": True,
|
| 1661 |
+
"data": stats
|
| 1662 |
+
})
|
| 1663 |
+
except Exception as e:
|
| 1664 |
+
log.error(f"获取聚合统计失败: {e}")
|
| 1665 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1666 |
+
|
| 1667 |
+
|
| 1668 |
+
|
| 1669 |
+
class UsageLimitsUpdateRequest(BaseModel):
|
| 1670 |
+
filename: str
|
| 1671 |
+
gemini_2_5_pro_limit: Optional[int] = None
|
| 1672 |
+
total_limit: Optional[int] = None
|
| 1673 |
+
|
| 1674 |
+
|
| 1675 |
+
@router.post("/usage/update-limits")
|
| 1676 |
+
async def update_usage_limits(request: UsageLimitsUpdateRequest, token: str = Depends(verify_token)):
|
| 1677 |
+
"""
|
| 1678 |
+
更新指定凭证文件的每日使用限制
|
| 1679 |
+
|
| 1680 |
+
Args:
|
| 1681 |
+
request: 包含文件名和新限制值的请求
|
| 1682 |
+
|
| 1683 |
+
Returns:
|
| 1684 |
+
Success message
|
| 1685 |
+
"""
|
| 1686 |
+
try:
|
| 1687 |
+
stats_instance = await get_usage_stats_instance()
|
| 1688 |
+
|
| 1689 |
+
await stats_instance.update_daily_limits(
|
| 1690 |
+
filename=request.filename,
|
| 1691 |
+
gemini_2_5_pro_limit=request.gemini_2_5_pro_limit,
|
| 1692 |
+
total_limit=request.total_limit
|
| 1693 |
+
)
|
| 1694 |
+
|
| 1695 |
+
return JSONResponse(content={
|
| 1696 |
+
"success": True,
|
| 1697 |
+
"message": f"已更新 {request.filename} 的使用限制"
|
| 1698 |
+
})
|
| 1699 |
+
|
| 1700 |
+
except Exception as e:
|
| 1701 |
+
log.error(f"更新使用限制失败: {e}")
|
| 1702 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1703 |
+
|
| 1704 |
+
|
| 1705 |
+
class UsageResetRequest(BaseModel):
|
| 1706 |
+
filename: Optional[str] = None
|
| 1707 |
+
|
| 1708 |
+
|
| 1709 |
+
@router.post("/usage/reset")
|
| 1710 |
+
async def reset_usage_statistics(request: UsageResetRequest, token: str = Depends(verify_token)):
|
| 1711 |
+
"""
|
| 1712 |
+
重置使用统计
|
| 1713 |
+
|
| 1714 |
+
Args:
|
| 1715 |
+
request: 包含可选文件名的请求。如果不提供文件名则重置所有统计
|
| 1716 |
+
|
| 1717 |
+
Returns:
|
| 1718 |
+
Success message
|
| 1719 |
+
"""
|
| 1720 |
+
try:
|
| 1721 |
+
stats_instance = await get_usage_stats_instance()
|
| 1722 |
+
|
| 1723 |
+
await stats_instance.reset_stats(filename=request.filename)
|
| 1724 |
+
|
| 1725 |
+
if request.filename:
|
| 1726 |
+
message = f"已重置 {request.filename} 的使用统计"
|
| 1727 |
+
else:
|
| 1728 |
+
message = "已重置所有文件的使用统计"
|
| 1729 |
+
|
| 1730 |
+
return JSONResponse(content={
|
| 1731 |
+
"success": True,
|
| 1732 |
+
"message": message
|
| 1733 |
+
})
|
| 1734 |
+
|
| 1735 |
+
except Exception as e:
|
| 1736 |
+
log.error(f"重置使用统计失败: {e}")
|
| 1737 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1738 |
+
|