lightspeed commited on
Commit
5868187
·
verified ·
1 Parent(s): b0b42d4

Upload 22 files

Browse files
src/anti_truncation.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Anti-Truncation Module - Ensures complete streaming output
3
+ 保持一个流式请求内完整输出的反截断模块
4
+ """
5
+ import json
6
+ import re
7
+ from typing import Dict, Any, AsyncGenerator, List, Tuple
8
+
9
+ from fastapi.responses import StreamingResponse
10
+
11
+ from log import log
12
+
13
+ # 反截断配置
14
+ DONE_MARKER = "[done]"
15
+ MAX_CONTINUATION_ATTEMPTS = 3
16
+ CONTINUATION_PROMPT = f"""请从刚才被截断的地方继续输出剩余的所有内容。
17
+
18
+ 重要提醒:
19
+ 1. 不要重复前面已经输出的内容
20
+ 2. 直接继续输出,无需任何前言或解释
21
+ 3. 当你完整完成所有内容输出后,必须在最后一行单独输出:{DONE_MARKER}
22
+ 4. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记
23
+
24
+ 现在请继续输出:"""
25
+
26
+ # 正则替换配置
27
+ REGEX_REPLACEMENTS: List[Tuple[str, str, str]] = [
28
+ (
29
+ "age_pattern", # 替换规则名称
30
+ r"(?:[1-9]|1[0-8])岁(?:的)?|(?:十一|十二|十三|十四|十五|十六|十七|十八|十|一|二|三|四|五|六|七|八|九)岁(?:的)?", # 正则模式
31
+ "" # 替换文本
32
+ ),
33
+ # 可在此处添加更多替换规则
34
+ # ("rule_name", r"pattern", "replacement"),
35
+ ]
36
+
37
+ def apply_regex_replacements(text: str) -> str:
38
+ """
39
+ 对文本应用正则替换规则
40
+
41
+ Args:
42
+ text: 要处理的文本
43
+
44
+ Returns:
45
+ 处理后的文本
46
+ """
47
+ if not text:
48
+ return text
49
+
50
+ processed_text = text
51
+ replacement_count = 0
52
+
53
+ for rule_name, pattern, replacement in REGEX_REPLACEMENTS:
54
+ try:
55
+ # 编译正则表达式,使用IGNORECASE标志
56
+ regex = re.compile(pattern, re.IGNORECASE)
57
+
58
+ # 执行替换
59
+ new_text, count = regex.subn(replacement, processed_text)
60
+
61
+ if count > 0:
62
+ log.debug(f"Regex replacement '{rule_name}': {count} matches replaced")
63
+ processed_text = new_text
64
+ replacement_count += count
65
+
66
+ except re.error as e:
67
+ log.error(f"Invalid regex pattern in rule '{rule_name}': {e}")
68
+ continue
69
+
70
+ if replacement_count > 0:
71
+ log.info(f"Applied {replacement_count} regex replacements to text")
72
+
73
+ return processed_text
74
+
75
+ def apply_regex_replacements_to_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
76
+ """
77
+ 对请求payload中的文本内容应用正则替换
78
+
79
+ Args:
80
+ payload: 请求payload
81
+
82
+ Returns:
83
+ 应用替换后的payload
84
+ """
85
+ if not REGEX_REPLACEMENTS:
86
+ return payload
87
+
88
+ modified_payload = payload.copy()
89
+ request_data = modified_payload.get("request", {})
90
+
91
+ # 处理contents中的文本
92
+ contents = request_data.get("contents", [])
93
+ if contents:
94
+ new_contents = []
95
+ for content in contents:
96
+ if isinstance(content, dict):
97
+ new_content = content.copy()
98
+ parts = new_content.get("parts", [])
99
+ if parts:
100
+ new_parts = []
101
+ for part in parts:
102
+ if isinstance(part, dict) and "text" in part:
103
+ new_part = part.copy()
104
+ new_part["text"] = apply_regex_replacements(part["text"])
105
+ new_parts.append(new_part)
106
+ else:
107
+ new_parts.append(part)
108
+ new_content["parts"] = new_parts
109
+ new_contents.append(new_content)
110
+ else:
111
+ new_contents.append(content)
112
+
113
+ request_data["contents"] = new_contents
114
+ modified_payload["request"] = request_data
115
+ log.debug("Applied regex replacements to request contents")
116
+
117
+ return modified_payload
118
+
119
+ def apply_anti_truncation(payload: Dict[str, Any]) -> Dict[str, Any]:
120
+ """
121
+ 对请求payload应用反截断处理和正则替换
122
+ 在systemInstruction中添加提醒,要求模型在结束时输出DONE_MARKER标记
123
+
124
+ Args:
125
+ payload: 原始请求payload
126
+
127
+ Returns:
128
+ 添加了反截断指令并应用了正则替换的payload
129
+ """
130
+ # 首先应用正则替换
131
+ modified_payload = apply_regex_replacements_to_payload(payload)
132
+ request_data = modified_payload.get("request", {})
133
+
134
+ # 获取或创建systemInstruction
135
+ system_instruction = request_data.get("systemInstruction", {})
136
+ if not system_instruction:
137
+ system_instruction = {"parts": []}
138
+ elif "parts" not in system_instruction:
139
+ system_instruction["parts"] = []
140
+
141
+ # 添加反截断指令
142
+ anti_truncation_instruction = {
143
+ "text": f"""严格执行以下输出结束规则:
144
+
145
+ 1. 当你完成完整回答时,必须在输出的最后单独一行输出:{DONE_MARKER}
146
+ 2. {DONE_MARKER} 标记表示你的回答已经完全结束,这是必需的结束标记
147
+ 3. 只有输出了 {DONE_MARKER} 标记,系统才认为你的回答是完整的
148
+ 4. 如果你的回答被截断,系统会要求你继续输出剩余内容
149
+ 5. 无论回答长短,都必须以 {DONE_MARKER} 标记结束
150
+
151
+ 示例格式:
152
+ ```
153
+ 你的回答内容...
154
+ 更多回答内容...
155
+ {DONE_MARKER}
156
+ ```
157
+
158
+ 注意:{DONE_MARKER} 必须单独占一行,前面不要有任何其他字符。
159
+
160
+ 这个规则对于确保输出完整性极其重要,请严格遵守。"""
161
+ }
162
+
163
+ # 检查是否已经包含反截断指令
164
+ has_done_instruction = any(
165
+ part.get("text", "").find(DONE_MARKER) != -1
166
+ for part in system_instruction["parts"]
167
+ if isinstance(part, dict)
168
+ )
169
+
170
+ if not has_done_instruction:
171
+ system_instruction["parts"].append(anti_truncation_instruction)
172
+ request_data["systemInstruction"] = system_instruction
173
+ modified_payload["request"] = request_data
174
+
175
+ log.debug("Applied anti-truncation instruction to request")
176
+
177
+ return modified_payload
178
+
179
+ class AntiTruncationStreamProcessor:
180
+ """反截断流式处理器"""
181
+
182
+ def __init__(self,
183
+ original_request_func,
184
+ payload: Dict[str, Any],
185
+ max_attempts: int = MAX_CONTINUATION_ATTEMPTS):
186
+ self.original_request_func = original_request_func
187
+ self.base_payload = payload.copy()
188
+ self.max_attempts = max_attempts
189
+ self.collected_content = [] # 使用列表避免字符串重复拼接
190
+ self.current_attempt = 0
191
+
192
+ async def process_stream(self) -> AsyncGenerator[bytes, None]:
193
+ """处理流式响应,检测并处理截断"""
194
+
195
+ while self.current_attempt < self.max_attempts:
196
+ self.current_attempt += 1
197
+
198
+ # 构建当前请求payload
199
+ current_payload = self._build_current_payload()
200
+
201
+ log.debug(f"Anti-truncation attempt {self.current_attempt}/{self.max_attempts}")
202
+
203
+ # 发送请求
204
+ try:
205
+ response = await self.original_request_func(current_payload)
206
+
207
+ if not isinstance(response, StreamingResponse):
208
+ # 非流式响应,直接处理
209
+ yield await self._handle_non_streaming_response(response)
210
+ return
211
+
212
+ # 处理流式响应
213
+ chunk_content = ""
214
+ found_done_marker = False
215
+
216
+ async for chunk in response.body_iterator:
217
+ if not chunk:
218
+ yield chunk
219
+ continue
220
+
221
+ # 处理不同数据类型的startswith问题
222
+ if isinstance(chunk, bytes):
223
+ if not chunk.startswith(b'data: '):
224
+ yield chunk
225
+ continue
226
+ payload_data = chunk[len(b'data: '):]
227
+ else:
228
+ chunk_str = str(chunk)
229
+ if not chunk_str.startswith('data: '):
230
+ yield chunk
231
+ continue
232
+ payload_data = chunk_str[len('data: '):].encode()
233
+
234
+ # 解析chunk内容
235
+
236
+ if payload_data.strip() == b'[DONE]':
237
+ # 检查是否找到了done标记
238
+ if found_done_marker:
239
+ log.info("Anti-truncation: Found [done] marker, output complete")
240
+ yield chunk
241
+ return
242
+ else:
243
+ log.warning("Anti-truncation: Stream ended without [done] marker")
244
+ # 不发送[DONE],准备继续
245
+ break
246
+
247
+ try:
248
+ data = json.loads(payload_data.decode())
249
+ content = self._extract_content_from_chunk(data)
250
+
251
+ if content:
252
+ chunk_content += content
253
+
254
+ # 检查是否包含done标记
255
+ if self._check_done_marker_in_chunk_content(content):
256
+ found_done_marker = True
257
+ log.info("Anti-truncation: Found [done] marker in chunk")
258
+
259
+ # 清理chunk中的[done]标记后再发送
260
+ cleaned_chunk = self._remove_done_marker_from_chunk(chunk, data)
261
+ yield cleaned_chunk
262
+
263
+ except (json.JSONDecodeError, UnicodeDecodeError):
264
+ yield chunk
265
+ continue
266
+
267
+ # 更新收集的内容 - 使用列表避免字符串重复拼接
268
+ if chunk_content:
269
+ self.collected_content.append(chunk_content)
270
+
271
+ # 如果找到了done标记,结束
272
+ if found_done_marker:
273
+ # 立即清理内容释放内存
274
+ self.collected_content.clear()
275
+ yield b'data: [DONE]\n\n'
276
+ return
277
+
278
+ # 只有在单个chunk中没有找到done标记时,才检查累积内容(防止done标记跨chunk出现)
279
+ if not found_done_marker:
280
+ accumulated_text = ''.join(self.collected_content) if self.collected_content else ""
281
+ if self._check_done_marker_in_text(accumulated_text):
282
+ log.info("Anti-truncation: Found [done] marker in accumulated content")
283
+ # 立即清理内容释放内存
284
+ self.collected_content.clear()
285
+ yield b'data: [DONE]\n\n'
286
+ return
287
+
288
+ # 如果没找到done标记且不是最后一次尝试,准备续传
289
+ if self.current_attempt < self.max_attempts:
290
+ total_length = sum(len(chunk) for chunk in self.collected_content) if self.collected_content else 0
291
+ log.info(f"Anti-truncation: No [done] marker found in output (length: {total_length}), preparing continuation (attempt {self.current_attempt + 1})")
292
+ if self.collected_content and total_length > 100:
293
+ last_chunk = self.collected_content[-1] if self.collected_content else ""
294
+ log.debug(f"Anti-truncation: Current collected content ends with: {'...' + last_chunk[-100:]}")
295
+ # 在下一次循环中会继续
296
+ continue
297
+ else:
298
+ # 最后一次尝试,直接结束
299
+ log.warning("Anti-truncation: Max attempts reached, ending stream")
300
+ # 立即清理内容释放内存
301
+ self.collected_content.clear()
302
+ yield b'data: [DONE]\n\n'
303
+ return
304
+
305
+ except Exception as e:
306
+ log.error(f"Anti-truncation error in attempt {self.current_attempt}: {str(e)}")
307
+ if self.current_attempt >= self.max_attempts:
308
+ # 发送错误chunk
309
+ error_chunk = {
310
+ "error": {
311
+ "message": f"Anti-truncation failed: {str(e)}",
312
+ "type": "api_error",
313
+ "code": 500
314
+ }
315
+ }
316
+ yield f"data: {json.dumps(error_chunk)}\n\n".encode()
317
+ yield b'data: [DONE]\n\n'
318
+ return
319
+ # 否则继续下一次尝试
320
+
321
+ # 如果所有尝试都失败了
322
+ log.error("Anti-truncation: All attempts failed")
323
+ # 确保清理内容释放内存
324
+ self.collected_content.clear()
325
+ yield b'data: [DONE]\n\n'
326
+
327
+ def _build_current_payload(self) -> Dict[str, Any]:
328
+ """构建当前请求的payload"""
329
+ if self.current_attempt == 1:
330
+ # 第一次请求,使用原始payload(已经包含反截断指令)
331
+ return self.base_payload
332
+
333
+ # 后续请求,添加续传指令
334
+ continuation_payload = self.base_payload.copy()
335
+ request_data = continuation_payload.get("request", {})
336
+
337
+ # 获取原始对话内容
338
+ contents = request_data.get("contents", [])
339
+ new_contents = contents.copy()
340
+
341
+ # 如果有收集到的内容,添加到对话中
342
+ if self.collected_content:
343
+ # 拼接收集的内容并添加模型的回复
344
+ accumulated_text = ''.join(self.collected_content)
345
+ new_contents.append({
346
+ "role": "model",
347
+ "parts": [{"text": accumulated_text}]
348
+ })
349
+
350
+ # 构建具体的续写指令,包含前面的内容摘要
351
+ content_summary = ""
352
+ if self.collected_content:
353
+ accumulated_text = ''.join(self.collected_content)
354
+ if len(accumulated_text) > 200:
355
+ content_summary = f"\n\n前面你已经输出了约 {len(accumulated_text)} 个字符的内容,结尾是:\n\"...{accumulated_text[-100:]}\""
356
+ else:
357
+ content_summary = f"\n\n前面你已经输出的内容是:\n\"{accumulated_text}\""
358
+
359
+ detailed_continuation_prompt = f"""{CONTINUATION_PROMPT}{content_summary}"""
360
+
361
+ # 添加继续指令
362
+ continuation_message = {
363
+ "role": "user",
364
+ "parts": [{"text": detailed_continuation_prompt}]
365
+ }
366
+ new_contents.append(continuation_message)
367
+
368
+ request_data["contents"] = new_contents
369
+ continuation_payload["request"] = request_data
370
+
371
+ return continuation_payload
372
+
373
+ def _extract_content_from_chunk(self, data: Dict[str, Any]) -> str:
374
+ """从chunk数据中提取文本内容"""
375
+ content = ""
376
+
377
+ # 处理Gemini格式
378
+ if "candidates" in data:
379
+ for candidate in data["candidates"]:
380
+ if "content" in candidate:
381
+ parts = candidate["content"].get("parts", [])
382
+ for part in parts:
383
+ if "text" in part:
384
+ content += part["text"]
385
+
386
+ # 处理OpenAI格式
387
+ elif "choices" in data:
388
+ for choice in data["choices"]:
389
+ if "delta" in choice and "content" in choice["delta"]:
390
+ content += choice["delta"]["content"]
391
+ elif "message" in choice and "content" in choice["message"]:
392
+ content += choice["message"]["content"]
393
+
394
+ return content
395
+
396
+ async def _handle_non_streaming_response(self, response) -> bytes:
397
+ """处理非流式响应"""
398
+ try:
399
+ if hasattr(response, 'body'):
400
+ content = response.body.decode() if isinstance(response.body, bytes) else response.body
401
+ elif hasattr(response, 'content'):
402
+ content = response.content.decode() if isinstance(response.content, bytes) else response.content
403
+ else:
404
+ content = str(response)
405
+
406
+ response_data = json.loads(content)
407
+
408
+ # 检查是否包含done标记
409
+ text_content = self._extract_content_from_response(response_data)
410
+ has_done_marker = self._check_done_marker_in_text(text_content)
411
+
412
+ if not has_done_marker and self.current_attempt < self.max_attempts:
413
+ log.info("Anti-truncation: Non-streaming response needs continuation")
414
+ if text_content:
415
+ self.collected_content.append(text_content)
416
+ # 递归处理续传
417
+ return await self._handle_non_streaming_response(
418
+ await self.original_request_func(self._build_current_payload())
419
+ )
420
+
421
+ return content.encode()
422
+
423
+ except Exception as e:
424
+ log.error(f"Anti-truncation non-streaming error: {str(e)}")
425
+ return json.dumps({
426
+ "error": {
427
+ "message": f"Anti-truncation failed: {str(e)}",
428
+ "type": "api_error",
429
+ "code": 500
430
+ }
431
+ }).encode()
432
+
433
+ def _check_done_marker_in_text(self, text: str) -> bool:
434
+ """检测文本中是否包含DONE_MARKER(只检测指定标记)"""
435
+ if not text:
436
+ return False
437
+
438
+ # 只要文本中出现DONE_MARKER即可
439
+ return DONE_MARKER in text
440
+
441
+ def _check_done_marker_in_chunk_content(self, content: str) -> bool:
442
+ """检查单个chunk内容中是否包含done标记"""
443
+ return self._check_done_marker_in_text(content)
444
+
445
+ def _extract_content_from_response(self, data: Dict[str, Any]) -> str:
446
+ """从响应数据中提取文本内容"""
447
+ content = ""
448
+
449
+ # 处理Gemini格式
450
+ if "candidates" in data:
451
+ for candidate in data["candidates"]:
452
+ if "content" in candidate:
453
+ parts = candidate["content"].get("parts", [])
454
+ for part in parts:
455
+ if "text" in part:
456
+ content += part["text"]
457
+
458
+ # 处理OpenAI格式
459
+ elif "choices" in data:
460
+ for choice in data["choices"]:
461
+ if "message" in choice and "content" in choice["message"]:
462
+ content += choice["message"]["content"]
463
+
464
+ return content
465
+
466
+ def _remove_done_marker_from_chunk(self, chunk: bytes, data: Dict[str, Any]) -> bytes:
467
+ """使用正则表达式从chunk中移除[done]标记"""
468
+ try:
469
+ # 首先检查是否真的包含[done]标记,如果没有则直接返回原始chunk
470
+ chunk_text = chunk.decode('utf-8', errors='ignore') if isinstance(chunk, bytes) else str(chunk)
471
+ if '[done]' not in chunk_text.lower():
472
+ return chunk # 没有[done]标记,直接返回原始chunk
473
+
474
+ # 编译正则表达式,匹配[done]标记(忽略大小写,包括可能的空白字符)
475
+ done_pattern = re.compile(r'\s*\[done\]\s*', re.IGNORECASE)
476
+
477
+ # 处理Gemini格式
478
+ if "candidates" in data:
479
+ modified_data = data.copy()
480
+ modified_data["candidates"] = []
481
+
482
+ for i, candidate in enumerate(data["candidates"]):
483
+ modified_candidate = candidate.copy()
484
+ # 只在最后一个candidate中清理[done]标记
485
+ is_last_candidate = (i == len(data["candidates"]) - 1)
486
+
487
+ if "content" in candidate:
488
+ modified_content = candidate["content"].copy()
489
+ if "parts" in modified_content:
490
+ modified_parts = []
491
+ for part in modified_content["parts"]:
492
+ if "text" in part and isinstance(part["text"], str):
493
+ modified_part = part.copy()
494
+ # 只在最后一个candidate中清理[done]标记
495
+ if is_last_candidate:
496
+ modified_part["text"] = done_pattern.sub('', part["text"])
497
+ modified_parts.append(modified_part)
498
+ else:
499
+ modified_parts.append(part)
500
+ modified_content["parts"] = modified_parts
501
+ modified_candidate["content"] = modified_content
502
+ modified_data["candidates"].append(modified_candidate)
503
+
504
+ # 重新编码为chunk格式,保持原始的换行符
505
+ if isinstance(chunk, bytes):
506
+ prefix = b'data: '
507
+ suffix = b'\n\n' # 确保有正确的换行符
508
+ json_data = json.dumps(modified_data, separators=(',',':'), ensure_ascii=False).encode('utf-8')
509
+ return prefix + json_data + suffix
510
+ else:
511
+ return f"data: {json.dumps(modified_data, separators=(',',':'), ensure_ascii=False)}\n\n"
512
+
513
+ # 处理OpenAI格式
514
+ elif "choices" in data:
515
+ modified_data = data.copy()
516
+ modified_data["choices"] = []
517
+
518
+ for choice in data["choices"]:
519
+ modified_choice = choice.copy()
520
+ if "delta" in choice and "content" in choice["delta"]:
521
+ modified_delta = choice["delta"].copy()
522
+ modified_delta["content"] = done_pattern.sub('', choice["delta"]["content"])
523
+ modified_choice["delta"] = modified_delta
524
+ elif "message" in choice and "content" in choice["message"]:
525
+ modified_message = choice["message"].copy()
526
+ modified_message["content"] = done_pattern.sub('', choice["message"]["content"])
527
+ modified_choice["message"] = modified_message
528
+ modified_data["choices"].append(modified_choice)
529
+
530
+ # 重新编码为chunk格式,保持原始的换行符
531
+ if isinstance(chunk, bytes):
532
+ prefix = b'data: '
533
+ suffix = b'\n\n' # 确保有正确的换行符
534
+ json_data = json.dumps(modified_data, separators=(',',':'), ensure_ascii=False).encode('utf-8')
535
+ return prefix + json_data + suffix
536
+ else:
537
+ return f"data: {json.dumps(modified_data, separators=(',',':'), ensure_ascii=False)}\n\n"
538
+
539
+ # 如果没有找到支持的格式,返回原始chunk
540
+ return chunk
541
+
542
+ except Exception as e:
543
+ log.warning(f"Failed to remove [done] marker from chunk: {str(e)}")
544
+ return chunk
545
+
546
+ async def apply_anti_truncation_to_stream(
547
+ request_func,
548
+ payload: Dict[str, Any],
549
+ max_attempts: int = MAX_CONTINUATION_ATTEMPTS
550
+ ) -> StreamingResponse:
551
+ """
552
+ 对流式请求应用反截断处理
553
+
554
+ Args:
555
+ request_func: 原始请求函数
556
+ payload: 请求payload
557
+ max_attempts: 最大续传尝试次数
558
+
559
+ Returns:
560
+ 处理后的StreamingResponse
561
+ """
562
+
563
+ # 首先对payload应用反截断指令
564
+ anti_truncation_payload = apply_anti_truncation(payload)
565
+
566
+ # 创建反截断处理器
567
+ processor = AntiTruncationStreamProcessor(
568
+ lambda p: request_func(p),
569
+ anti_truncation_payload,
570
+ max_attempts
571
+ )
572
+
573
+ # 返回包装后的流式响应
574
+ return StreamingResponse(
575
+ processor.process_stream(),
576
+ media_type="text/event-stream"
577
+ )
578
+
579
+ def is_anti_truncation_enabled(request_data: Dict[str, Any]) -> bool:
580
+ """
581
+ 检查请求是否启用了反截断功能
582
+
583
+ Args:
584
+ request_data: 请求数据
585
+
586
+ Returns:
587
+ 是否启用反截断
588
+ """
589
+ return request_data.get("enable_anti_truncation", False)
src/auth.py ADDED
@@ -0,0 +1,1530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 认证API模块 - 使用统一存储中间层,完全摆脱文件操作
3
+ """
4
+ import asyncio
5
+ import json
6
+ import secrets
7
+ import socket
8
+ import threading
9
+ import time
10
+ import uuid
11
+ from datetime import timezone
12
+ from http.server import BaseHTTPRequestHandler, HTTPServer
13
+ from typing import Optional, Dict, Any, List
14
+ from urllib.parse import urlparse, parse_qs
15
+
16
+ from .google_oauth_api import Credentials, Flow, enable_required_apis, get_user_projects, select_default_project
17
+ from .storage_adapter import get_storage_adapter
18
+ from config import get_config_value
19
+ from log import log
20
+
21
+ # OAuth Configuration
22
+ CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
23
+ CLIENT_SECRET = "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"
24
+ SCOPES = [
25
+ "https://www.googleapis.com/auth/cloud-platform",
26
+ "https://www.googleapis.com/auth/userinfo.email",
27
+ "https://www.googleapis.com/auth/userinfo.profile",
28
+ ]
29
+
30
+ # 回调服务器配置
31
+ CALLBACK_HOST = 'localhost'
32
+
33
+ async def get_callback_port():
34
+ """获取OAuth回调端口"""
35
+ return int(await get_config_value('oauth_callback_port', '8080', 'OAUTH_CALLBACK_PORT'))
36
+
37
+ # 全局状态管理 - 严格限制大小
38
+ auth_flows = {} # 存储进行中的认证流程
39
+ MAX_AUTH_FLOWS = 20 # 严格限制最大认证流程数
40
+
41
+ def cleanup_auth_flows_for_memory():
42
+ """清理认证流程以释放内存"""
43
+ global auth_flows
44
+ cleaned = cleanup_expired_flows()
45
+ # 如果还是太多,强制清理一些旧的流程
46
+ if len(auth_flows) > 10:
47
+ # 按创建时间排序,保留最新的10个
48
+ sorted_flows = sorted(auth_flows.items(), key=lambda x: x[1].get('created_at', 0), reverse=True)
49
+ new_auth_flows = dict(sorted_flows[:10])
50
+
51
+ # 清理被移除的流程
52
+ for state, flow_data in auth_flows.items():
53
+ if state not in new_auth_flows:
54
+ try:
55
+ if flow_data.get('server'):
56
+ server = flow_data['server']
57
+ port = flow_data.get('callback_port')
58
+ async_shutdown_server(server, port)
59
+ except Exception:
60
+ pass
61
+ flow_data.clear()
62
+
63
+ auth_flows = new_auth_flows
64
+ log.info(f"强制清理认证流程,保留 {len(auth_flows)} 个最新流程")
65
+
66
+ return len(auth_flows)
67
+
68
+
69
+ async def find_available_port(start_port: int = None) -> int:
70
+ """动态查找可用端口"""
71
+ if start_port is None:
72
+ start_port = await get_callback_port()
73
+
74
+ # 首先尝试默认端口
75
+ for port in range(start_port, start_port + 100): # 尝试100个端口
76
+ try:
77
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
78
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
79
+ s.bind(('0.0.0.0', port))
80
+ log.info(f"找到可用端口: {port}")
81
+ return port
82
+ except OSError:
83
+ continue
84
+
85
+ # 如果都不可用,让系统自动分配端口
86
+ try:
87
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
88
+ s.bind(('0.0.0.0', 0))
89
+ port = s.getsockname()[1]
90
+ log.info(f"系统分配可用端口: {port}")
91
+ return port
92
+ except OSError as e:
93
+ log.error(f"无法找到可用端口: {e}")
94
+ raise RuntimeError("无法找到可用端口")
95
+
96
+ def create_callback_server(port: int) -> HTTPServer:
97
+ """创建指定端口的回调服务器,优化快速关闭"""
98
+ try:
99
+ # 服务器监听0.0.0.0
100
+ server = HTTPServer(("0.0.0.0", port), AuthCallbackHandler)
101
+
102
+ # 设置socket选项以支持快速关闭
103
+ server.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
104
+ # 设置较短的超时时间
105
+ server.timeout = 1.0
106
+
107
+ log.info(f"创建OAuth回调服务器,监听端口: {port}")
108
+ return server
109
+ except OSError as e:
110
+ log.error(f"创建端口{port}的服务器失败: {e}")
111
+ raise
112
+
113
+ class AuthCallbackHandler(BaseHTTPRequestHandler):
114
+ """OAuth回调处理器"""
115
+ def do_GET(self):
116
+ query_components = parse_qs(urlparse(self.path).query)
117
+ code = query_components.get("code", [None])[0]
118
+ state = query_components.get("state", [None])[0]
119
+
120
+ log.info(f"收到OAuth回调: code={'已获取' if code else '未获取'}, state={state}")
121
+
122
+ if code and state and state in auth_flows:
123
+ # 更新流程状态
124
+ auth_flows[state]['code'] = code
125
+ auth_flows[state]['completed'] = True
126
+
127
+ log.info(f"OAuth回调成功处理: state={state}")
128
+
129
+ self.send_response(200)
130
+ self.send_header("Content-type", "text/html")
131
+ self.end_headers()
132
+ # 成功页面
133
+ self.wfile.write(b"<h1>OAuth authentication successful!</h1><p>You can close this window. Please return to the original page and click 'Get Credentials' button.</p>")
134
+ else:
135
+ self.send_response(400)
136
+ self.send_header("Content-type", "text/html")
137
+ self.end_headers()
138
+ self.wfile.write(b"<h1>Authentication failed.</h1><p>Please try again.</p>")
139
+
140
+ def log_message(self, format, *args):
141
+ # 减少日志噪音
142
+ pass
143
+
144
+
145
+ async def create_auth_url(project_id: Optional[str] = None, user_session: str = None, get_all_projects: bool = False) -> Dict[str, Any]:
146
+ """创建认证URL,支持动态端口分配"""
147
+ try:
148
+ # 动态分配端口
149
+ callback_port = await find_available_port()
150
+ callback_url = f"http://{CALLBACK_HOST}:{callback_port}"
151
+
152
+ # 立即启动回调服务器
153
+ try:
154
+ callback_server = create_callback_server(callback_port)
155
+ # 在后台线程中运行服务器
156
+ server_thread = threading.Thread(
157
+ target=callback_server.serve_forever,
158
+ daemon=True,
159
+ name=f"OAuth-Server-{callback_port}"
160
+ )
161
+ server_thread.start()
162
+ log.info(f"OAuth回调服务器已启动,端口: {callback_port}")
163
+ except Exception as e:
164
+ log.error(f"启动回调服务器失败: {e}")
165
+ return {
166
+ 'success': False,
167
+ 'error': f'无法启动OAuth回调服务器,端口{callback_port}: {str(e)}'
168
+ }
169
+
170
+ # 创建OAuth流程
171
+ client_config = {
172
+ "installed": {
173
+ "client_id": CLIENT_ID,
174
+ "client_secret": CLIENT_SECRET,
175
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
176
+ "token_uri": "https://oauth2.googleapis.com/token",
177
+ }
178
+ }
179
+
180
+ flow = Flow(
181
+ client_id=CLIENT_ID,
182
+ client_secret=CLIENT_SECRET,
183
+ scopes=SCOPES,
184
+ redirect_uri=callback_url
185
+ )
186
+
187
+ # 生成状态标识符,包含用户会话信息
188
+ if user_session:
189
+ state = f"{user_session}_{str(uuid.uuid4())}"
190
+ else:
191
+ state = str(uuid.uuid4())
192
+
193
+ # 生成认证URL
194
+ auth_url = flow.get_auth_url(state=state)
195
+
196
+ # 严格控制认证流程数量 - 超过限制时立即清理最旧的
197
+ if len(auth_flows) >= MAX_AUTH_FLOWS:
198
+ # 清理最旧的认证流程
199
+ oldest_state = min(auth_flows.keys(),
200
+ key=lambda k: auth_flows[k].get('created_at', 0))
201
+ try:
202
+ # 清理服务器资源
203
+ old_flow = auth_flows[oldest_state]
204
+ if old_flow.get('server'):
205
+ server = old_flow['server']
206
+ port = old_flow.get('callback_port')
207
+ async_shutdown_server(server, port)
208
+ except Exception as e:
209
+ log.warning(f"Failed to cleanup old auth flow {oldest_state}: {e}")
210
+
211
+ del auth_flows[oldest_state]
212
+ log.debug(f"Removed oldest auth flow: {oldest_state}")
213
+
214
+ # 保存流程状态
215
+ auth_flows[state] = {
216
+ 'flow': flow,
217
+ 'project_id': project_id, # 可能为None,稍后在回调时确定
218
+ 'user_session': user_session,
219
+ 'callback_port': callback_port, # 存储分配的端口
220
+ 'callback_url': callback_url, # 存储完整回调URL
221
+ 'server': callback_server, # 存储服务器实例
222
+ 'server_thread': server_thread, # 存储服务器线程
223
+ 'code': None,
224
+ 'completed': False,
225
+ 'created_at': time.time(),
226
+ 'auto_project_detection': project_id is None, # 标记是否需要自动检测项目ID
227
+ 'get_all_projects': get_all_projects # 是否为所有项目获取凭证
228
+ }
229
+
230
+ # 清理过期的流程(30分钟)
231
+ cleanup_expired_flows()
232
+
233
+ log.info(f"OAuth流程已创建: state={state}, project_id={project_id}")
234
+ log.info(f"用户需要访问认证URL,然后OAuth会回调到 {callback_url}")
235
+ log.info(f"为此认证流程分配的端口: {callback_port}")
236
+
237
+ return {
238
+ 'auth_url': auth_url,
239
+ 'state': state,
240
+ 'callback_port': callback_port,
241
+ 'success': True,
242
+ 'auto_project_detection': project_id is None,
243
+ 'detected_project_id': project_id
244
+ }
245
+
246
+ except Exception as e:
247
+ log.error(f"创建认证URL失败: {e}")
248
+ return {
249
+ 'success': False,
250
+ 'error': str(e)
251
+ }
252
+
253
+
254
+ def wait_for_callback_sync(state: str, timeout: int = 300) -> Optional[str]:
255
+ """同步等待OAuth回调完成,使用对应流程的专用服务器"""
256
+ if state not in auth_flows:
257
+ log.error(f"未找到状态为 {state} 的认证流程")
258
+ return None
259
+
260
+ flow_data = auth_flows[state]
261
+ callback_port = flow_data['callback_port']
262
+
263
+ # 服务器已经在create_auth_url时启动了,这里只需要等待
264
+ log.info(f"等待OAuth回调完成,端口: {callback_port}")
265
+
266
+ # 等待回调完成
267
+ start_time = time.time()
268
+ while time.time() - start_time < timeout:
269
+ if flow_data.get('code'):
270
+ log.info(f"OAuth回调成功完成")
271
+ return flow_data['code']
272
+ time.sleep(0.5) # 每0.5秒检查一次
273
+
274
+ # 刷新flow_data引用
275
+ if state in auth_flows:
276
+ flow_data = auth_flows[state]
277
+
278
+ log.warning(f"等待OAuth回调超时 ({timeout}秒)")
279
+ return None
280
+
281
+
282
+ async def complete_auth_flow(project_id: Optional[str] = None, user_session: str = None) -> Dict[str, Any]:
283
+ """完成认证流程并保存凭证,支持自动检测项目ID"""
284
+ try:
285
+ # 查找对应的认证流程
286
+ state = None
287
+ flow_data = None
288
+
289
+ # 如果指定了project_id,先尝试匹配指定的项目
290
+ if project_id:
291
+ for s, data in auth_flows.items():
292
+ if data['project_id'] == project_id:
293
+ # 如果指定了用户会话,优先匹配相同会话的流程
294
+ if user_session and data.get('user_session') == user_session:
295
+ state = s
296
+ flow_data = data
297
+ break
298
+ # 如果没有指定会话,或没找到匹配会话的流程,使用第一个匹配项目ID的
299
+ elif not state:
300
+ state = s
301
+ flow_data = data
302
+
303
+ # 如果没有指定项目ID或没找到匹配的,查找需要自动检测项目ID的流程
304
+ if not state:
305
+ for s, data in auth_flows.items():
306
+ if data.get('auto_project_detection', False):
307
+ # 如果指定了用户会话,优先匹配相同会话的流程
308
+ if user_session and data.get('user_session') == user_session:
309
+ state = s
310
+ flow_data = data
311
+ break
312
+ # 使用第一个找到的需要自动检测的流程
313
+ elif not state:
314
+ state = s
315
+ flow_data = data
316
+
317
+ if not state or not flow_data:
318
+ return {
319
+ 'success': False,
320
+ 'error': '未找到对应的认证流程,请先点击获取认证链接'
321
+ }
322
+
323
+ if not project_id:
324
+ project_id = flow_data.get('project_id')
325
+ if not project_id:
326
+ return {
327
+ 'success': False,
328
+ 'error': '缺少项目ID,请指定项目ID',
329
+ 'requires_manual_project_id': True
330
+ }
331
+
332
+ flow = flow_data['flow']
333
+
334
+ # 如果还没有授权码,需要等待回调
335
+ if not flow_data.get('code'):
336
+ log.info(f"等待用户完成OAuth授权 (state: {state})")
337
+ auth_code = wait_for_callback_sync(state)
338
+
339
+ if not auth_code:
340
+ return {
341
+ 'success': False,
342
+ 'error': '未接收到授权回调,请确保完成了浏览器中的OAuth认证'
343
+ }
344
+
345
+ # 更新流程数据
346
+ auth_flows[state]['code'] = auth_code
347
+ auth_flows[state]['completed'] = True
348
+ else:
349
+ auth_code = flow_data['code']
350
+
351
+ # 使用认证代码获取凭证
352
+ import oauthlib.oauth2.rfc6749.parameters
353
+ original_validate = oauthlib.oauth2.rfc6749.parameters.validate_token_parameters
354
+
355
+ def patched_validate(params):
356
+ try:
357
+ return original_validate(params)
358
+ except Warning:
359
+ pass
360
+
361
+ oauthlib.oauth2.rfc6749.parameters.validate_token_parameters = patched_validate
362
+
363
+ try:
364
+ credentials = await flow.exchange_code(auth_code)
365
+ # credentials 已经在 exchange_code 中获得
366
+
367
+ # 如果需要自动检测项目ID且没有提供项目ID
368
+ if flow_data.get('auto_project_detection', False) and not project_id:
369
+ log.info("尝试通过API获取用户项目列表...")
370
+ log.info(f"使用的token: {credentials.access_token[:20]}...")
371
+ log.info(f"Token过期时间: {credentials.expires_at}")
372
+ user_projects = await get_user_projects(credentials)
373
+
374
+ if user_projects:
375
+ # 如果只有一个项目,自动使用
376
+ if len(user_projects) == 1:
377
+ project_id = user_projects[0].get('projectId')
378
+ if project_id:
379
+ flow_data['project_id'] = project_id
380
+ log.info(f"自动选择唯一项目: {project_id}")
381
+ # 如果有多个项目,尝试选择默认项目
382
+ else:
383
+ project_id = await select_default_project(user_projects)
384
+ if project_id:
385
+ flow_data['project_id'] = project_id
386
+ log.info(f"自动选择默认项目: {project_id}")
387
+ else:
388
+ # 返回项目列表让用户选择
389
+ return {
390
+ 'success': False,
391
+ 'error': '请从以下项目中选择一个',
392
+ 'requires_project_selection': True,
393
+ 'available_projects': [
394
+ {
395
+ 'projectId': p.get('projectId'),
396
+ 'name': p.get('displayName') or p.get('projectId'),
397
+ 'projectNumber': p.get('projectNumber')
398
+ }
399
+ for p in user_projects
400
+ ]
401
+ }
402
+ else:
403
+ # 如果无法获取项目列表,提示手动输入
404
+ return {
405
+ 'success': False,
406
+ 'error': '无法获取您的项目列表,请手动指定项目ID',
407
+ 'requires_manual_project_id': True
408
+ }
409
+
410
+ # 如果仍然没有项目ID,返回错误
411
+ if not project_id:
412
+ return {
413
+ 'success': False,
414
+ 'error': '缺少项目ID,请指定项目ID',
415
+ 'requires_manual_project_id': True
416
+ }
417
+
418
+ # 保存凭证
419
+ saved_filename = await save_credentials(credentials, project_id)
420
+
421
+ # 准备返回的凭证数据
422
+ creds_data = {
423
+ "client_id": CLIENT_ID,
424
+ "client_secret": CLIENT_SECRET,
425
+ "token": credentials.access_token,
426
+ "refresh_token": credentials.refresh_token,
427
+ "scopes": SCOPES,
428
+ "token_uri": "https://oauth2.googleapis.com/token",
429
+ "project_id": project_id
430
+ }
431
+
432
+ if credentials.expires_at:
433
+ if credentials.expires_at.tzinfo is None:
434
+ expiry_utc = credentials.expires_at.replace(tzinfo=timezone.utc)
435
+ else:
436
+ expiry_utc = credentials.expires_at
437
+ creds_data["expiry"] = expiry_utc.isoformat()
438
+
439
+ # 清理使用过的流程
440
+ if state in auth_flows:
441
+ flow_data_to_clean = auth_flows[state]
442
+ # 快速关闭服务器
443
+ try:
444
+ if flow_data_to_clean.get('server'):
445
+ server = flow_data_to_clean['server']
446
+ port = flow_data_to_clean.get('callback_port')
447
+ async_shutdown_server(server, port)
448
+ except Exception as e:
449
+ log.debug(f"启动异步关闭服务器时出错: {e}")
450
+
451
+ del auth_flows[state]
452
+
453
+ log.info("OAuth认证成功,凭证已保存")
454
+ return {
455
+ 'success': True,
456
+ 'credentials': creds_data,
457
+ 'file_path': saved_filename,
458
+ 'auto_detected_project': flow_data.get('auto_project_detection', False)
459
+ }
460
+
461
+ except Exception as e:
462
+ log.error(f"获取凭证失败: {e}")
463
+ return {
464
+ 'success': False,
465
+ 'error': f'获取凭证失败: {str(e)}'
466
+ }
467
+ finally:
468
+ oauthlib.oauth2.rfc6749.parameters.validate_token_parameters = original_validate
469
+
470
+ except Exception as e:
471
+ log.error(f"完成认证流程失败: {e}")
472
+ return {
473
+ 'success': False,
474
+ 'error': str(e)
475
+ }
476
+
477
+
478
+ async def asyncio_complete_auth_flow(project_id: Optional[str] = None, user_session: str = None, get_all_projects: bool = False) -> Dict[str, Any]:
479
+ """异步完成认证流程,支持自动检测项目ID"""
480
+ try:
481
+ log.info(f"asyncio_complete_auth_flow开始执行: project_id={project_id}, user_session={user_session}")
482
+
483
+ # 查找对应的认证流程
484
+ state = None
485
+ flow_data = None
486
+
487
+ log.debug(f"当前所有auth_flows: {list(auth_flows.keys())}")
488
+
489
+ # 如果指定了project_id,先尝试匹配指定的项目
490
+ if project_id:
491
+ log.info(f"尝试匹配指定的项目ID: {project_id}")
492
+ for s, data in auth_flows.items():
493
+ if data['project_id'] == project_id:
494
+ # 如果指定了用户会话,优先匹配相同会话的流程
495
+ if user_session and data.get('user_session') == user_session:
496
+ state = s
497
+ flow_data = data
498
+ log.info(f"找到匹配的用户会话: {s}")
499
+ break
500
+ # 如果没有指定会话,或没找到匹配会话的流程,使用第一个匹配项目ID的
501
+ elif not state:
502
+ state = s
503
+ flow_data = data
504
+ log.info(f"找到匹配的项目ID: {s}")
505
+
506
+ # 如果没有指定项目ID或没找到匹配的,查找需要自动检测项目ID的流程
507
+ if not state:
508
+ log.info(f"没有找到指定项目的流程,查找自动检测流程")
509
+ for s, data in auth_flows.items():
510
+ log.debug(f"检查流程 {s}: auto_project_detection={data.get('auto_project_detection', False)}")
511
+ if data.get('auto_project_detection', False):
512
+ # 如果指定了用户会话,优先匹配相同会话的流程
513
+ if user_session and data.get('user_session') == user_session:
514
+ state = s
515
+ flow_data = data
516
+ log.info(f"找到匹配用户会话的自动检测流程: {s}")
517
+ break
518
+ # 使用第一个找到的需要自动检测的流程
519
+ elif not state:
520
+ state = s
521
+ flow_data = data
522
+ log.info(f"找到自动检测流程: {s}")
523
+
524
+ if not state or not flow_data:
525
+ log.error(f"未找到认证流程: state={state}, flow_data存在={bool(flow_data)}")
526
+ log.debug(f"当前所有flow_data: {list(auth_flows.keys())}")
527
+ return {
528
+ 'success': False,
529
+ 'error': '未找到对应的认证流程,请先点击获取认证链接'
530
+ }
531
+
532
+ log.info(f"找到认证流程: state={state}")
533
+ log.info(f"flow_data内容: project_id={flow_data.get('project_id')}, auto_project_detection={flow_data.get('auto_project_detection')}")
534
+ log.info(f"传入的project_id参数: {project_id}")
535
+
536
+ # 如果需要自动检测项目ID且没有提供项目ID
537
+ log.info(f"检查auto_project_detection条件: auto_project_detection={flow_data.get('auto_project_detection', False)}, not project_id={not project_id}")
538
+ if flow_data.get('auto_project_detection', False) and not project_id:
539
+ log.info("跳过自动检测项目ID,进入等待阶段")
540
+ elif not project_id:
541
+ log.info("进入project_id检查分支")
542
+ project_id = flow_data.get('project_id')
543
+ if not project_id:
544
+ log.error("缺少项目ID,返回错误")
545
+ return {
546
+ 'success': False,
547
+ 'error': '缺少项目ID,请指定项目ID',
548
+ 'requires_manual_project_id': True
549
+ }
550
+ else:
551
+ log.info(f"使用提供的项目ID: {project_id}")
552
+
553
+ # 检查是否已经有授权码
554
+ log.info(f"开始检查OAuth授权码...")
555
+ max_wait_time = 60 # 最多等待60秒
556
+ wait_interval = 1 # 每秒检查一次
557
+ waited = 0
558
+
559
+ while waited < max_wait_time:
560
+ log.debug(f"等待OAuth授权码... ({waited}/{max_wait_time}秒)")
561
+ if flow_data.get('code'):
562
+ log.info(f"检测到OAuth授权码,开始处理凭证 (等待时间: {waited}秒)")
563
+ break
564
+
565
+ # 异步等待
566
+ await asyncio.sleep(wait_interval)
567
+ waited += wait_interval
568
+
569
+ # 刷新flow_data引用,因为可能被回调更新了
570
+ if state in auth_flows:
571
+ flow_data = auth_flows[state]
572
+ log.debug(f"刷新flow_data: completed={flow_data.get('completed')}, code存在={bool(flow_data.get('code'))}")
573
+
574
+ if not flow_data.get('code'):
575
+ log.error(f"等待OAuth回调超时,等待了{waited}秒")
576
+ return {
577
+ 'success': False,
578
+ 'error': '等待OAuth回调超时,请确保完成了浏览器中的认证并看到成功页面'
579
+ }
580
+
581
+ flow = flow_data['flow']
582
+ auth_code = flow_data['code']
583
+
584
+ log.info(f"开始使用授权码获取凭证: code={'***' + auth_code[-4:] if auth_code else 'None'}")
585
+
586
+ # 使用认证代码获取凭证
587
+ import oauthlib.oauth2.rfc6749.parameters
588
+ original_validate = oauthlib.oauth2.rfc6749.parameters.validate_token_parameters
589
+
590
+ def patched_validate(params):
591
+ try:
592
+ return original_validate(params)
593
+ except Warning:
594
+ pass
595
+
596
+ oauthlib.oauth2.rfc6749.parameters.validate_token_parameters = patched_validate
597
+
598
+ try:
599
+ log.info(f"调用flow.exchange_code...")
600
+ credentials = await flow.exchange_code(auth_code)
601
+ log.info(f"成功获取凭证,token前缀: {credentials.access_token[:20] if credentials.access_token else 'None'}...")
602
+
603
+ log.info(f"检查是否需要项目检测: auto_project_detection={flow_data.get('auto_project_detection')}, project_id={project_id}")
604
+
605
+ # 检查是否为批量获取所有项目模式
606
+ if flow_data.get('get_all_projects', False) or get_all_projects:
607
+ log.info("批量模式:为所有项目并发获取凭证...")
608
+ user_projects = await get_user_projects(credentials)
609
+
610
+ if user_projects:
611
+ async def process_single_project(project_info):
612
+ """并发处理单个项目的凭证获取"""
613
+ project_id_current = project_info.get('projectId')
614
+ project_name = project_info.get('displayName') or project_id_current
615
+
616
+ try:
617
+ log.info(f"为项目 {project_name} ({project_id_current}) 启用API服务...")
618
+ await enable_required_apis(credentials, project_id_current)
619
+
620
+ # 保存凭证
621
+ saved_filename = await save_credentials(credentials, project_id_current)
622
+
623
+ log.info(f"成功为项目 {project_name} 保存凭证")
624
+ return {
625
+ 'status': 'success',
626
+ 'project_id': project_id_current,
627
+ 'project_name': project_name,
628
+ 'file_path': saved_filename
629
+ }
630
+
631
+ except Exception as e:
632
+ log.error(f"为项目 {project_name} ({project_id_current}) 处理凭证失败: {e}")
633
+ return {
634
+ 'status': 'failed',
635
+ 'project_id': project_id_current,
636
+ 'project_name': project_name,
637
+ 'error': str(e)
638
+ }
639
+
640
+ # 并发处理所有项目
641
+ log.info(f"开始并发处理 {len(user_projects)} 个项目...")
642
+ tasks = [process_single_project(project_info) for project_info in user_projects]
643
+ results = await asyncio.gather(*tasks, return_exceptions=True)
644
+
645
+ # 整理结果
646
+ multiple_results = {'success': [], 'failed': []}
647
+ for result in results:
648
+ if isinstance(result, Exception):
649
+ log.error(f"并发处理项目时发生异常: {result}")
650
+ multiple_results['failed'].append({
651
+ 'project_id': 'unknown',
652
+ 'project_name': 'unknown',
653
+ 'error': f'处理异常: {str(result)}'
654
+ })
655
+ elif result['status'] == 'success':
656
+ multiple_results['success'].append({
657
+ 'project_id': result['project_id'],
658
+ 'project_name': result['project_name'],
659
+ 'file_path': result['file_path']
660
+ })
661
+ else: # failed
662
+ multiple_results['failed'].append({
663
+ 'project_id': result['project_id'],
664
+ 'project_name': result['project_name'],
665
+ 'error': result['error']
666
+ })
667
+
668
+ # 清理使用过的流程
669
+ if state in auth_flows:
670
+ flow_data_to_clean = auth_flows[state]
671
+ try:
672
+ if flow_data_to_clean.get('server'):
673
+ server = flow_data_to_clean['server']
674
+ port = flow_data_to_clean.get('callback_port')
675
+ async_shutdown_server(server, port)
676
+ except Exception as e:
677
+ log.debug(f"启动异步关闭服务器时出错: {e}")
678
+ del auth_flows[state]
679
+
680
+ log.info(f"批量并发认证完成:成功 {len(multiple_results['success'])} 个,失败 {len(multiple_results['failed'])} 个")
681
+ return {
682
+ 'success': True,
683
+ 'multiple_credentials': multiple_results
684
+ }
685
+ else:
686
+ return {
687
+ 'success': False,
688
+ 'error': '无法获取您的项目列表,批量认证失败'
689
+ }
690
+
691
+ # 如果需要自动检测项目ID且没有提供项目ID(单项目模式)
692
+ elif flow_data.get('auto_project_detection', False) and not project_id:
693
+ log.info("尝试通过API获取用户项目列表...")
694
+ log.info(f"使用的token: {credentials.access_token[:20]}...")
695
+ log.info(f"Token过期时间: {credentials.expires_at}")
696
+ user_projects = await get_user_projects(credentials)
697
+
698
+ if user_projects:
699
+ # 如果只有一个项目,自动使用
700
+ if len(user_projects) == 1:
701
+ project_id = user_projects[0].get('projectId')
702
+ if project_id:
703
+ flow_data['project_id'] = project_id
704
+ log.info(f"自动选择唯一项目: {project_id}")
705
+ # 自动启用必需的API服务
706
+ log.info("正在自动启用必需的API服务...")
707
+ await enable_required_apis(credentials, project_id)
708
+ # 如果有多个项目,尝试选择默认项目
709
+ else:
710
+ project_id = await select_default_project(user_projects)
711
+ if project_id:
712
+ flow_data['project_id'] = project_id
713
+ log.info(f"自动选择默认项目: {project_id}")
714
+ # 自动启用必需的API服务
715
+ log.info("正在自动启用必需的API服务...")
716
+ await enable_required_apis(credentials, project_id)
717
+ else:
718
+ # 返回项目列表让用户选择
719
+ return {
720
+ 'success': False,
721
+ 'error': '请从以下项目中选择一个',
722
+ 'requires_project_selection': True,
723
+ 'available_projects': [
724
+ {
725
+ 'projectId': p.get('projectId'),
726
+ 'name': p.get('displayName') or p.get('projectId'),
727
+ 'projectNumber': p.get('projectNumber')
728
+ }
729
+ for p in user_projects
730
+ ]
731
+ }
732
+ else:
733
+ # 如果无法获取项目列表,提示手动输入
734
+ return {
735
+ 'success': False,
736
+ 'error': '无法获取您的项目列表,请手动指定项目ID',
737
+ 'requires_manual_project_id': True
738
+ }
739
+ elif project_id:
740
+ # 如果已经有项目ID(手动提供或环境检测),也尝试启用API服务
741
+ log.info("正在为已提供的项目ID自动启用必需的API服务...")
742
+ await enable_required_apis(credentials, project_id)
743
+
744
+ # 如果仍然没有项目ID,返回错误
745
+ if not project_id:
746
+ return {
747
+ 'success': False,
748
+ 'error': '缺少项目ID,请指定项目ID',
749
+ 'requires_manual_project_id': True
750
+ }
751
+
752
+ # 保存凭证
753
+ saved_filename = await save_credentials(credentials, project_id)
754
+
755
+ # 准备返回的凭证数据
756
+ creds_data = {
757
+ "client_id": CLIENT_ID,
758
+ "client_secret": CLIENT_SECRET,
759
+ "token": credentials.access_token,
760
+ "refresh_token": credentials.refresh_token,
761
+ "scopes": SCOPES,
762
+ "token_uri": "https://oauth2.googleapis.com/token",
763
+ "project_id": project_id
764
+ }
765
+
766
+ if credentials.expires_at:
767
+ if credentials.expires_at.tzinfo is None:
768
+ expiry_utc = credentials.expires_at.replace(tzinfo=timezone.utc)
769
+ else:
770
+ expiry_utc = credentials.expires_at
771
+ creds_data["expiry"] = expiry_utc.isoformat()
772
+
773
+ # 清理使用过的流程
774
+ if state in auth_flows:
775
+ flow_data_to_clean = auth_flows[state]
776
+ # 快速关闭服务器
777
+ try:
778
+ if flow_data_to_clean.get('server'):
779
+ server = flow_data_to_clean['server']
780
+ port = flow_data_to_clean.get('callback_port')
781
+ async_shutdown_server(server, port)
782
+ except Exception as e:
783
+ log.debug(f"启动异步关闭服务器时出错: {e}")
784
+
785
+ del auth_flows[state]
786
+
787
+ log.info("OAuth认证成功,凭证已保存")
788
+ return {
789
+ 'success': True,
790
+ 'credentials': creds_data,
791
+ 'file_path': saved_filename,
792
+ 'auto_detected_project': flow_data.get('auto_project_detection', False)
793
+ }
794
+
795
+ except Exception as e:
796
+ log.error(f"获取凭证失败: {e}")
797
+ return {
798
+ 'success': False,
799
+ 'error': f'获取凭证失败: {str(e)}'
800
+ }
801
+ finally:
802
+ oauthlib.oauth2.rfc6749.parameters.validate_token_parameters = original_validate
803
+
804
+ except Exception as e:
805
+ log.error(f"异步完成认证流程失败: {e}")
806
+ return {
807
+ 'success': False,
808
+ 'error': str(e)
809
+ }
810
+
811
+
812
+ async def complete_auth_flow_from_callback_url(callback_url: str, project_id: Optional[str] = None, get_all_projects: bool = False) -> Dict[str, Any]:
813
+ """从回调URL直接完成认证流程,无需启动本地服务器"""
814
+ try:
815
+ log.info(f"开始从回调URL完成认证: {callback_url}")
816
+
817
+ # 解析回调URL
818
+ parsed_url = urlparse(callback_url)
819
+ query_params = parse_qs(parsed_url.query)
820
+
821
+ # 验证必要参数
822
+ if 'state' not in query_params or 'code' not in query_params:
823
+ return {
824
+ 'success': False,
825
+ 'error': '回调URL缺少必要参数 (state 或 code)'
826
+ }
827
+
828
+ state = query_params['state'][0]
829
+ code = query_params['code'][0]
830
+
831
+ log.info(f"从URL解析到: state={state}, code=xxx...")
832
+
833
+ # 检查是否有对应的认证流程
834
+ if state not in auth_flows:
835
+ return {
836
+ 'success': False,
837
+ 'error': f'未找到对应的认证流程,请先启动认证 (state: {state})'
838
+ }
839
+
840
+ flow_data = auth_flows[state]
841
+ flow = flow_data['flow']
842
+
843
+ # 构造回调URL(使用flow中存储的redirect_uri)
844
+ redirect_uri = flow.redirect_uri
845
+ log.info(f"使用redirect_uri: {redirect_uri}")
846
+
847
+ try:
848
+ # 使用authorization code获取token
849
+ credentials = await flow.exchange_code(code)
850
+ log.info("成功获取访问令牌")
851
+
852
+ # 检查是否为批量获取所有项目模式
853
+ if get_all_projects:
854
+ log.info("批量模式:从回调URL为所有项目并发获取凭证...")
855
+ try:
856
+ projects = await get_user_projects(credentials)
857
+ if projects:
858
+ async def process_single_project(project_info):
859
+ """并发处理单个项目的凭证获取"""
860
+ project_id_current = project_info.get('projectId')
861
+ project_name = project_info.get('displayName') or project_id_current
862
+
863
+ try:
864
+ log.info(f"为项目 {project_name} ({project_id_current}) 启用API服务...")
865
+ await enable_required_apis(credentials, project_id_current)
866
+
867
+ # 保存凭证
868
+ saved_filename = await save_credentials(credentials, project_id_current)
869
+
870
+ log.info(f"成功为项目 {project_name} 保存凭证")
871
+ return {
872
+ 'status': 'success',
873
+ 'project_id': project_id_current,
874
+ 'project_name': project_name,
875
+ 'file_path': saved_filename
876
+ }
877
+
878
+ except Exception as e:
879
+ log.error(f"为项目 {project_name} ({project_id_current}) 处理凭证失败: {e}")
880
+ return {
881
+ 'status': 'failed',
882
+ 'project_id': project_id_current,
883
+ 'project_name': project_name,
884
+ 'error': str(e)
885
+ }
886
+
887
+ # 并发处理所有项目
888
+ log.info(f"开始并发处理 {len(projects)} 个项目...")
889
+ tasks = [process_single_project(project_info) for project_info in projects]
890
+ results = await asyncio.gather(*tasks, return_exceptions=True)
891
+
892
+ # 整理结果
893
+ multiple_results = {'success': [], 'failed': []}
894
+ for result in results:
895
+ if isinstance(result, Exception):
896
+ log.error(f"并发处理项目时发生异常: {result}")
897
+ multiple_results['failed'].append({
898
+ 'project_id': 'unknown',
899
+ 'project_name': 'unknown',
900
+ 'error': f'处理异常: {str(result)}'
901
+ })
902
+ elif result['status'] == 'success':
903
+ multiple_results['success'].append({
904
+ 'project_id': result['project_id'],
905
+ 'project_name': result['project_name'],
906
+ 'file_path': result['file_path']
907
+ })
908
+ else: # failed
909
+ multiple_results['failed'].append({
910
+ 'project_id': result['project_id'],
911
+ 'project_name': result['project_name'],
912
+ 'error': result['error']
913
+ })
914
+
915
+ # 清理使用过的流程
916
+ if state in auth_flows:
917
+ flow_data_to_clean = auth_flows[state]
918
+ try:
919
+ if flow_data_to_clean.get('server'):
920
+ server = flow_data_to_clean['server']
921
+ port = flow_data_to_clean.get('callback_port')
922
+ async_shutdown_server(server, port)
923
+ except Exception as e:
924
+ log.debug(f"关闭服务器时出错: {e}")
925
+ del auth_flows[state]
926
+
927
+ log.info(f"从回调URL批量并发认证完成:成功 {len(multiple_results['success'])} 个,失败 {len(multiple_results['failed'])} 个")
928
+ return {
929
+ 'success': True,
930
+ 'multiple_credentials': multiple_results
931
+ }
932
+ else:
933
+ return {
934
+ 'success': False,
935
+ 'error': '无法获取您的项目列表,批量认证失败'
936
+ }
937
+ except Exception as e:
938
+ log.error(f"批量获取项目列表失败: {e}")
939
+ return {
940
+ 'success': False,
941
+ 'error': f'批量获取项目列表失败: {str(e)}'
942
+ }
943
+
944
+ # 单项目模式的项目ID处理逻辑
945
+ detected_project_id = None
946
+ auto_detected = False
947
+
948
+ if not project_id:
949
+ # 尝试自动检测项目ID
950
+ try:
951
+ projects = await get_user_projects(credentials)
952
+ if projects:
953
+ if len(projects) == 1:
954
+ # 只有一个项目,自动使用
955
+ detected_project_id = projects[0]['projectId']
956
+ auto_detected = True
957
+ log.info(f"自动检测到唯一项目ID: {detected_project_id}")
958
+ else:
959
+ # 多个项目,自动选择第一个
960
+ detected_project_id = projects[0]['projectId']
961
+ auto_detected = True
962
+ log.info(f"检测到{len(projects)}个项目,自动选择第一个: {detected_project_id}")
963
+ log.debug(f"其他可用项目: {[p['projectId'] for p in projects[1:]]}")
964
+ else:
965
+ # 没有项目访问权限
966
+ return {
967
+ 'success': False,
968
+ 'error': '未检测到可访问的项目,请检查权限或手动指定项目ID',
969
+ 'requires_manual_project_id': True
970
+ }
971
+ except Exception as e:
972
+ log.warning(f"自动检测项目ID失败: {e}")
973
+ return {
974
+ 'success': False,
975
+ 'error': f'自动检测项目ID失败: {str(e)},请手动指定项目ID',
976
+ 'requires_manual_project_id': True
977
+ }
978
+ else:
979
+ detected_project_id = project_id
980
+
981
+ # 启用必需的API服务
982
+ if detected_project_id:
983
+ try:
984
+ log.info(f"正在为项目 {detected_project_id} 启用必需的API服务...")
985
+ await enable_required_apis(credentials, detected_project_id)
986
+ except Exception as e:
987
+ log.warning(f"启用API服务失败: {e}")
988
+
989
+ # 保存凭证
990
+ saved_filename = await save_credentials(credentials, detected_project_id)
991
+
992
+ # 准备返回的凭证数据
993
+ creds_data = {
994
+ "client_id": CLIENT_ID,
995
+ "client_secret": CLIENT_SECRET,
996
+ "token": credentials.access_token,
997
+ "refresh_token": credentials.refresh_token,
998
+ "scopes": SCOPES,
999
+ "token_uri": "https://oauth2.googleapis.com/token",
1000
+ "project_id": detected_project_id
1001
+ }
1002
+
1003
+ if credentials.expires_at:
1004
+ if credentials.expires_at.tzinfo is None:
1005
+ expiry_utc = credentials.expires_at.replace(tzinfo=timezone.utc)
1006
+ else:
1007
+ expiry_utc = credentials.expires_at
1008
+ creds_data["expiry"] = expiry_utc.isoformat()
1009
+
1010
+ # 清理使用过的流程
1011
+ if state in auth_flows:
1012
+ flow_data_to_clean = auth_flows[state]
1013
+ # 快速关闭服务器(如果有)
1014
+ try:
1015
+ if flow_data_to_clean.get('server'):
1016
+ server = flow_data_to_clean['server']
1017
+ port = flow_data_to_clean.get('callback_port')
1018
+ async_shutdown_server(server, port)
1019
+ except Exception as e:
1020
+ log.debug(f"关闭服务器时出错: {e}")
1021
+
1022
+ del auth_flows[state]
1023
+
1024
+ log.info("从回调URL完成OAuth认证成功,凭证已保存")
1025
+ return {
1026
+ 'success': True,
1027
+ 'credentials': creds_data,
1028
+ 'file_path': saved_filename,
1029
+ 'auto_detected_project': auto_detected
1030
+ }
1031
+
1032
+ except Exception as e:
1033
+ log.error(f"从回调URL获取凭证失败: {e}")
1034
+ return {
1035
+ 'success': False,
1036
+ 'error': f'获取凭证失败: {str(e)}'
1037
+ }
1038
+
1039
+ except Exception as e:
1040
+ log.error(f"从回调URL完成认证流程失败: {e}")
1041
+ return {
1042
+ 'success': False,
1043
+ 'error': str(e)
1044
+ }
1045
+
1046
+
1047
+ async def save_credentials(creds: Credentials, project_id: str) -> str:
1048
+ """通过统一存储系统保存凭证"""
1049
+ # 生成文件名(使用project_id和时间戳)
1050
+ timestamp = int(time.time())
1051
+ filename = f"{project_id}-{timestamp}.json"
1052
+
1053
+ # 准备凭证数据
1054
+ creds_data = {
1055
+ "client_id": CLIENT_ID,
1056
+ "client_secret": CLIENT_SECRET,
1057
+ "token": creds.access_token,
1058
+ "refresh_token": creds.refresh_token,
1059
+ "scopes": SCOPES,
1060
+ "token_uri": "https://oauth2.googleapis.com/token",
1061
+ "project_id": project_id
1062
+ }
1063
+
1064
+ if creds.expires_at:
1065
+ if creds.expires_at.tzinfo is None:
1066
+ expiry_utc = creds.expires_at.replace(tzinfo=timezone.utc)
1067
+ else:
1068
+ expiry_utc = creds.expires_at
1069
+ creds_data["expiry"] = expiry_utc.isoformat()
1070
+
1071
+ # 通过存储适配器保存
1072
+ storage_adapter = await get_storage_adapter()
1073
+ success = await storage_adapter.store_credential(filename, creds_data)
1074
+
1075
+ if success:
1076
+ # 创建默认状态记录
1077
+ try:
1078
+ default_state = {
1079
+ "error_codes": [],
1080
+ "disabled": False,
1081
+ "last_success": time.time(),
1082
+ "user_email": None,
1083
+ "gemini_2_5_pro_calls": 0,
1084
+ "total_calls": 0,
1085
+ "next_reset_time": None,
1086
+ "daily_limit_gemini_2_5_pro": 100,
1087
+ "daily_limit_total": 1000
1088
+ }
1089
+ await storage_adapter.update_credential_state(filename, default_state)
1090
+ log.info(f"凭证和状态已保存到: {filename}")
1091
+ except Exception as e:
1092
+ log.warning(f"创建默认状态记录失败 {filename}: {e}")
1093
+
1094
+ return filename
1095
+ else:
1096
+ raise Exception(f"保存凭证失败: {filename}")
1097
+
1098
+
1099
+ def async_shutdown_server(server, port):
1100
+ """异步关闭OAuth回调服务器,避免阻塞主流程"""
1101
+ def shutdown_server_async():
1102
+ try:
1103
+ # 设置一个标志来跟踪关闭状态
1104
+ shutdown_completed = threading.Event()
1105
+
1106
+ def do_shutdown():
1107
+ try:
1108
+ server.shutdown()
1109
+ server.server_close()
1110
+ shutdown_completed.set()
1111
+ log.info(f"已关闭端口 {port} 的OAuth回调服务器")
1112
+ except Exception as e:
1113
+ shutdown_completed.set()
1114
+ log.debug(f"关闭服务器时出错: {e}")
1115
+
1116
+ # 在单独线程中执行关闭操作
1117
+ shutdown_worker = threading.Thread(target=do_shutdown, daemon=True)
1118
+ shutdown_worker.start()
1119
+
1120
+ # 等待最多5秒,如果超时就放弃等待
1121
+ if shutdown_completed.wait(timeout=5):
1122
+ log.debug(f"端口 {port} 服务器关闭完成")
1123
+ else:
1124
+ log.warning(f"端口 {port} 服务器关闭超时,但不阻塞主流程")
1125
+
1126
+ except Exception as e:
1127
+ log.debug(f"异步关闭服务器时出错: {e}")
1128
+
1129
+ # 在后台线程中关闭服务器,不阻塞主流程
1130
+ shutdown_thread = threading.Thread(target=shutdown_server_async, daemon=True)
1131
+ shutdown_thread.start()
1132
+ log.debug(f"开始异步关闭端口 {port} 的OAuth回调服务器")
1133
+
1134
+ def cleanup_expired_flows():
1135
+ """清理过期的认证流程"""
1136
+ current_time = time.time()
1137
+ EXPIRY_TIME = 600 # 10分钟过期
1138
+
1139
+ # 直接遍历删除,避免创建额外列表
1140
+ states_to_remove = [
1141
+ state for state, flow_data in auth_flows.items()
1142
+ if current_time - flow_data['created_at'] > EXPIRY_TIME
1143
+ ]
1144
+
1145
+ # 批量清理,提高效率
1146
+ cleaned_count = 0
1147
+ for state in states_to_remove:
1148
+ flow_data = auth_flows.get(state)
1149
+ if flow_data:
1150
+ # 快速关闭可能存在的服务器
1151
+ try:
1152
+ if flow_data.get('server'):
1153
+ server = flow_data['server']
1154
+ port = flow_data.get('callback_port')
1155
+ async_shutdown_server(server, port)
1156
+ except Exception as e:
1157
+ log.debug(f"清理过期流程时启动异步关闭服务器失败: {e}")
1158
+
1159
+ # 显式清理流程数据,释放内存
1160
+ flow_data.clear()
1161
+ del auth_flows[state]
1162
+ cleaned_count += 1
1163
+
1164
+ if cleaned_count > 0:
1165
+ log.info(f"清理了 {cleaned_count} 个过期的认证流程")
1166
+
1167
+ # 更积极的垃圾回收触发条件
1168
+ if len(auth_flows) > 20: # 降低阈值
1169
+ import gc
1170
+ gc.collect()
1171
+ log.debug(f"触发垃圾回收,当前活跃认证流程数: {len(auth_flows)}")
1172
+
1173
+
1174
+ def get_auth_status(project_id: str) -> Dict[str, Any]:
1175
+ """获取认证状态"""
1176
+ for state, flow_data in auth_flows.items():
1177
+ if flow_data['project_id'] == project_id:
1178
+ return {
1179
+ 'status': 'completed' if flow_data['completed'] else 'pending',
1180
+ 'state': state,
1181
+ 'created_at': flow_data['created_at']
1182
+ }
1183
+
1184
+ return {
1185
+ 'status': 'not_found'
1186
+ }
1187
+
1188
+
1189
+ # 鉴权功能 - 使用更小的数据结构
1190
+ auth_tokens = {} # 存储有效的认证令牌
1191
+ TOKEN_EXPIRY = 3600 # 1小时令牌过期时间
1192
+
1193
+
1194
+ async def verify_password(password: str) -> bool:
1195
+ """验证密码(面板登录使用)"""
1196
+ from config import get_panel_password
1197
+ correct_password = await get_panel_password()
1198
+ return password == correct_password
1199
+
1200
+
1201
+ def generate_auth_token() -> str:
1202
+ """生成认证令牌"""
1203
+ # 清理过期令牌
1204
+ cleanup_expired_tokens()
1205
+
1206
+ token = secrets.token_urlsafe(32)
1207
+ # 只存储创建时间
1208
+ auth_tokens[token] = time.time()
1209
+ return token
1210
+
1211
+
1212
+ def verify_auth_token(token: str) -> bool:
1213
+ """验证认证令牌"""
1214
+ if not token or token not in auth_tokens:
1215
+ return False
1216
+
1217
+ created_at = auth_tokens[token]
1218
+
1219
+ # 检查令牌是否过期 (使用更短的过期时间)
1220
+ if time.time() - created_at > TOKEN_EXPIRY:
1221
+ del auth_tokens[token]
1222
+ return False
1223
+
1224
+ return True
1225
+
1226
+
1227
+ def cleanup_expired_tokens():
1228
+ """清理过期的认证令牌"""
1229
+ current_time = time.time()
1230
+ expired_tokens = [
1231
+ token for token, created_at in auth_tokens.items()
1232
+ if current_time - created_at > TOKEN_EXPIRY
1233
+ ]
1234
+
1235
+ for token in expired_tokens:
1236
+ del auth_tokens[token]
1237
+
1238
+ if expired_tokens:
1239
+ log.debug(f"清理了 {len(expired_tokens)} 个过期的认证令牌")
1240
+
1241
+ def invalidate_auth_token(token: str):
1242
+ """使认证令牌失效"""
1243
+ if token in auth_tokens:
1244
+ del auth_tokens[token]
1245
+
1246
+
1247
+ # 文件验证和处理功能 - 使用统一存储系统
1248
+ def validate_credential_content(content: str) -> Dict[str, Any]:
1249
+ """验证凭证内容格式"""
1250
+ try:
1251
+ creds_data = json.loads(content)
1252
+
1253
+ # 检查必要字段
1254
+ required_fields = ['client_id', 'client_secret', 'refresh_token', 'token_uri']
1255
+ missing_fields = [field for field in required_fields if field not in creds_data]
1256
+
1257
+ if missing_fields:
1258
+ return {
1259
+ 'valid': False,
1260
+ 'error': f'缺少必要字段: {", ".join(missing_fields)}'
1261
+ }
1262
+
1263
+ # 检查project_id
1264
+ if 'project_id' not in creds_data:
1265
+ log.warning("��证文件缺少project_id字段")
1266
+
1267
+ return {
1268
+ 'valid': True,
1269
+ 'data': creds_data
1270
+ }
1271
+
1272
+ except json.JSONDecodeError as e:
1273
+ return {
1274
+ 'valid': False,
1275
+ 'error': f'JSON格式错误: {str(e)}'
1276
+ }
1277
+ except Exception as e:
1278
+ return {
1279
+ 'valid': False,
1280
+ 'error': f'文件验证失败: {str(e)}'
1281
+ }
1282
+
1283
+
1284
+ async def save_uploaded_credential(content: str, original_filename: str) -> Dict[str, Any]:
1285
+ """通过统一存储系统保存上传的凭证"""
1286
+ try:
1287
+ # 验证内容格式
1288
+ validation = validate_credential_content(content)
1289
+ if not validation['valid']:
1290
+ return {
1291
+ 'success': False,
1292
+ 'error': validation['error']
1293
+ }
1294
+
1295
+ creds_data = validation['data']
1296
+
1297
+ # 生成文件名
1298
+ project_id = creds_data.get('project_id', 'unknown')
1299
+ timestamp = int(time.time())
1300
+
1301
+ # 从原文件名中提取有用信息
1302
+ import os
1303
+ base_name = os.path.splitext(original_filename)[0]
1304
+ filename = f"{base_name}-{timestamp}.json"
1305
+
1306
+ # 通过存储适配器保存
1307
+ storage_adapter = await get_storage_adapter()
1308
+ success = await storage_adapter.store_credential(filename, creds_data)
1309
+
1310
+ if success:
1311
+ log.info(f"凭证文件已上传保存: {filename}")
1312
+ return {
1313
+ 'success': True,
1314
+ 'file_path': filename,
1315
+ 'project_id': project_id
1316
+ }
1317
+ else:
1318
+ return {
1319
+ 'success': False,
1320
+ 'error': '保存到存储系统失败'
1321
+ }
1322
+
1323
+ except Exception as e:
1324
+ log.error(f"保存上传文件失败: {e}")
1325
+ return {
1326
+ 'success': False,
1327
+ 'error': str(e)
1328
+ }
1329
+
1330
+
1331
+ async def batch_upload_credentials(files_data: List[Dict[str, str]]) -> Dict[str, Any]:
1332
+ """批量上传凭证文件到统一存储系统"""
1333
+ results = []
1334
+ success_count = 0
1335
+
1336
+ for file_data in files_data:
1337
+ filename = file_data.get('filename', 'unknown.json')
1338
+ content = file_data.get('content', '')
1339
+
1340
+ result = await save_uploaded_credential(content, filename)
1341
+ result['filename'] = filename
1342
+ results.append(result)
1343
+
1344
+ if result['success']:
1345
+ success_count += 1
1346
+
1347
+ return {
1348
+ 'uploaded_count': success_count,
1349
+ 'total_count': len(files_data),
1350
+ 'results': results
1351
+ }
1352
+
1353
+
1354
+ # 环境变量批量导入功能 - 使用统一存储系统
1355
+ async def load_credentials_from_env() -> Dict[str, Any]:
1356
+ """
1357
+ 从环境变量加载多个凭证文件到统一存储系统
1358
+ 支持两种环境变量格式:
1359
+ 1. GCLI_CREDS_1, GCLI_CREDS_2, ... (编号格式)
1360
+ 2. GCLI_CREDS_projectname1, GCLI_CREDS_projectname2, ... (项目名格式)
1361
+ """
1362
+ import os
1363
+
1364
+ results = []
1365
+ success_count = 0
1366
+
1367
+ log.info("开始从环境变量加载认证凭证...")
1368
+
1369
+ # 获取所有以GCLI_CREDS_开头的环境变量
1370
+ creds_env_vars = {key: value for key, value in os.environ.items()
1371
+ if key.startswith('GCLI_CREDS_') and value.strip()}
1372
+
1373
+ if not creds_env_vars:
1374
+ log.info("未找到GCLI_CREDS_*环境变量")
1375
+ return {
1376
+ 'loaded_count': 0,
1377
+ 'total_count': 0,
1378
+ 'results': [],
1379
+ 'message': '未找到GCLI_CREDS_*环境变量'
1380
+ }
1381
+
1382
+ log.info(f"找到 {len(creds_env_vars)} 个凭证环境变量")
1383
+
1384
+ # 获取存储适配器
1385
+ storage_adapter = await get_storage_adapter()
1386
+
1387
+ for env_name, creds_content in creds_env_vars.items():
1388
+ # 从环境变量名提取标识符
1389
+ identifier = env_name.replace('GCLI_CREDS_', '')
1390
+
1391
+ try:
1392
+ # 验证JSON格式
1393
+ validation = validate_credential_content(creds_content)
1394
+ if not validation['valid']:
1395
+ result = {
1396
+ 'env_name': env_name,
1397
+ 'identifier': identifier,
1398
+ 'success': False,
1399
+ 'error': validation['error']
1400
+ }
1401
+ results.append(result)
1402
+ log.error(f"环境变量 {env_name} 验证失败: {validation['error']}")
1403
+ continue
1404
+
1405
+ creds_data = validation['data']
1406
+ project_id = creds_data.get('project_id', 'unknown')
1407
+
1408
+ # 生成文件名 (使用标识符和项目ID)
1409
+ timestamp = int(time.time())
1410
+ if identifier.isdigit():
1411
+ # 如果标识符是数字,使用项目ID作为主要标识
1412
+ filename = f"env-{project_id}-{identifier}-{timestamp}.json"
1413
+ else:
1414
+ # 如果标识符是项目名,直接使用
1415
+ filename = f"env-{identifier}-{timestamp}.json"
1416
+
1417
+ # 通过存储适配器保存
1418
+ success = await storage_adapter.store_credential(filename, creds_data)
1419
+
1420
+ if success:
1421
+ result = {
1422
+ 'env_name': env_name,
1423
+ 'identifier': identifier,
1424
+ 'success': True,
1425
+ 'file_path': filename,
1426
+ 'project_id': project_id,
1427
+ 'filename': filename
1428
+ }
1429
+ results.append(result)
1430
+ success_count += 1
1431
+
1432
+ log.info(f"成功从环境变量 {env_name} 保存凭证到: {filename}")
1433
+ else:
1434
+ result = {
1435
+ 'env_name': env_name,
1436
+ 'identifier': identifier,
1437
+ 'success': False,
1438
+ 'error': '保存到存储系统失败'
1439
+ }
1440
+ results.append(result)
1441
+ log.error(f"环境变量 {env_name} 保存失败")
1442
+
1443
+ except Exception as e:
1444
+ result = {
1445
+ 'env_name': env_name,
1446
+ 'identifier': identifier,
1447
+ 'success': False,
1448
+ 'error': str(e)
1449
+ }
1450
+ results.append(result)
1451
+ log.error(f"处理环境变量 {env_name} 时发生错误: {e}")
1452
+
1453
+ message = f"成功导入 {success_count}/{len(creds_env_vars)} 个凭证文件"
1454
+ log.info(message)
1455
+
1456
+ return {
1457
+ 'loaded_count': success_count,
1458
+ 'total_count': len(creds_env_vars),
1459
+ 'results': results,
1460
+ 'message': message
1461
+ }
1462
+
1463
+
1464
+ async def auto_load_env_credentials_on_startup() -> None:
1465
+ """
1466
+ 程序启动时自动从环境变量加载凭证到统一存储系统
1467
+ 如果设置了 AUTO_LOAD_ENV_CREDS=true,则会自动执行
1468
+ """
1469
+ from config import get_auto_load_env_creds
1470
+ auto_load = await get_auto_load_env_creds()
1471
+
1472
+ if not auto_load:
1473
+ log.debug("AUTO_LOAD_ENV_CREDS未启用,跳过自动加载")
1474
+ return
1475
+
1476
+ log.info("AUTO_LOAD_ENV_CREDS已启用,开始自动加载环境变量中的凭证...")
1477
+
1478
+ try:
1479
+ result = await load_credentials_from_env()
1480
+ if result['loaded_count'] > 0:
1481
+ log.info(f"启动时成功自动导入 {result['loaded_count']} 个凭证文件")
1482
+ else:
1483
+ log.info("启动时未找到可导入的环境变量凭证")
1484
+ except Exception as e:
1485
+ log.error(f"启动时自动加载环境变量凭证失败: {e}")
1486
+
1487
+
1488
+ async def clear_env_credentials() -> Dict[str, Any]:
1489
+ """
1490
+ 清除所有从环境变量导入的凭证文件
1491
+ 仅删除文件名包含'env-'前缀的文件
1492
+ """
1493
+ try:
1494
+ storage_adapter = await get_storage_adapter()
1495
+
1496
+ # 获取所有凭证
1497
+ all_credentials = await storage_adapter.list_credentials()
1498
+
1499
+ deleted_files = []
1500
+ deleted_count = 0
1501
+
1502
+ for credential_name in all_credentials:
1503
+ if credential_name.startswith('env-') and credential_name.endswith('.json'):
1504
+ try:
1505
+ success = await storage_adapter.delete_credential(credential_name)
1506
+ if success:
1507
+ deleted_files.append(credential_name)
1508
+ deleted_count += 1
1509
+ log.info(f"删除环境变量凭证文件: {credential_name}")
1510
+ else:
1511
+ log.error(f"删除文件 {credential_name} 失败")
1512
+ except Exception as e:
1513
+ log.error(f"删除文件 {credential_name} 失败: {e}")
1514
+
1515
+ message = f"成功删除 {deleted_count} 个环境变量凭证文件"
1516
+ log.info(message)
1517
+
1518
+ return {
1519
+ 'deleted_count': deleted_count,
1520
+ 'deleted_files': deleted_files,
1521
+ 'message': message
1522
+ }
1523
+
1524
+ except Exception as e:
1525
+ error_message = f"清除环境变量凭证文件时发生错误: {e}"
1526
+ log.error(error_message)
1527
+ return {
1528
+ 'deleted_count': 0,
1529
+ 'error': error_message
1530
+ }
src/credential_manager.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 凭证管理器 - 完全基于统一存储中间层
3
+ """
4
+ import asyncio
5
+ import time
6
+ from datetime import datetime, timezone
7
+ from typing import Dict, Any, List, Optional, Tuple
8
+ from contextlib import asynccontextmanager
9
+
10
+ from config import get_calls_per_rotation, is_mongodb_mode
11
+ from log import log
12
+ from .storage_adapter import get_storage_adapter
13
+ from .google_oauth_api import fetch_user_email_from_file, Credentials
14
+ from .task_manager import task_manager
15
+
16
+
17
+ class CredentialManager:
18
+ """
19
+ 统一凭证管理器
20
+ 所有存储操作通过storage_adapter进行
21
+ """
22
+
23
+ def __init__(self):
24
+ # 核心状态
25
+ self._initialized = False
26
+ self._storage_adapter = None
27
+
28
+ # 凭证轮换相关
29
+ self._credential_files: List[str] = [] # 存储凭证文件名列表
30
+ self._current_credential_index = 0
31
+ self._call_count = 0
32
+ self._last_scan_time = 0
33
+
34
+ # 当前使用的凭证信息
35
+ self._current_credential_file: Optional[str] = None
36
+ self._current_credential_data: Optional[Dict[str, Any]] = None
37
+ self._current_credential_state: Dict[str, Any] = {}
38
+
39
+ # 并发控制
40
+ self._state_lock = asyncio.Lock()
41
+ self._operation_lock = asyncio.Lock()
42
+
43
+ # 工作线程控制
44
+ self._shutdown_event = asyncio.Event()
45
+ self._write_worker_running = False
46
+ self._write_worker_task = None
47
+
48
+ # 原子操作计数器
49
+ self._atomic_counter = 0
50
+ self._atomic_lock = asyncio.Lock()
51
+
52
+ # Onboarding state
53
+ self._onboarding_complete = False
54
+ self._onboarding_checked = False
55
+
56
+ async def initialize(self):
57
+ """初始化凭证管理器"""
58
+ async with self._state_lock:
59
+ if self._initialized:
60
+ return
61
+
62
+ # 初始化统一存储适配器
63
+ self._storage_adapter = await get_storage_adapter()
64
+
65
+ # 启动后台工作线程
66
+ await self._start_background_workers()
67
+
68
+ # 发现并加载凭证
69
+ await self._discover_credentials()
70
+
71
+ self._initialized = True
72
+ storage_type = "MongoDB" if await is_mongodb_mode() else "File"
73
+ log.debug(f"Credential manager initialized with {storage_type} storage backend")
74
+
75
+ async def close(self):
76
+ """清理资源"""
77
+ log.debug("Closing credential manager...")
78
+
79
+ # 设置关闭标志
80
+ self._shutdown_event.set()
81
+
82
+ # 等待后台任务结束
83
+ if self._write_worker_task:
84
+ try:
85
+ await asyncio.wait_for(self._write_worker_task, timeout=5.0)
86
+ except asyncio.TimeoutError:
87
+ log.warning("Write worker task did not finish within timeout")
88
+ if not self._write_worker_task.done():
89
+ self._write_worker_task.cancel()
90
+
91
+ self._initialized = False
92
+ log.debug("Credential manager closed")
93
+
94
+ async def _start_background_workers(self):
95
+ """启动后台工作线程"""
96
+ if not self._write_worker_running:
97
+ self._write_worker_running = True
98
+ self._write_worker_task = task_manager.create_task(
99
+ self._background_worker(),
100
+ name="credential_background_worker"
101
+ )
102
+
103
+ async def _background_worker(self):
104
+ """后台工作线程,处理定期任务"""
105
+ while not self._shutdown_event.is_set():
106
+ try:
107
+ # 每60秒检查一次凭证更新
108
+ await asyncio.wait_for(self._shutdown_event.wait(), timeout=60.0)
109
+ if self._shutdown_event.is_set():
110
+ break
111
+
112
+ # 重新发现凭证(热更新)
113
+ await self._discover_credentials()
114
+
115
+ except asyncio.TimeoutError:
116
+ # 超时是正常的,继续下一轮
117
+ continue
118
+ except Exception as e:
119
+ log.error(f"Background worker error: {e}")
120
+ await asyncio.sleep(5) # 错误后等待5秒再继续
121
+
122
+ async def _discover_credentials(self):
123
+ """发现和加载所有可用凭证"""
124
+ try:
125
+ # 从存储适配器获取所有凭证
126
+ all_credentials = await self._storage_adapter.list_credentials()
127
+
128
+ # 过滤出可用的凭证(排除被禁用的)- 批量读取状态以提升性能
129
+ available_credentials = []
130
+
131
+ # 批量获取所有凭证状态,避免多次读取状态文件
132
+ if all_credentials:
133
+ try:
134
+ all_states = await self._storage_adapter.get_all_credential_states()
135
+
136
+ for credential_name in all_credentials:
137
+ normalized_name = credential_name
138
+ # 标准化文件名以匹配状态数据中的键
139
+ if hasattr(self._storage_adapter._backend, '_normalize_filename'):
140
+ normalized_name = self._storage_adapter._backend._normalize_filename(credential_name)
141
+
142
+ state = all_states.get(normalized_name, {})
143
+ if not state.get("disabled", False):
144
+ available_credentials.append(credential_name)
145
+ except Exception as e:
146
+ log.warning(f"Failed to batch load credential states, falling back to individual checks: {e}")
147
+ # 如果批量读取失败,回退到逐个检查
148
+ for credential_name in all_credentials:
149
+ try:
150
+ state = await self._storage_adapter.get_credential_state(credential_name)
151
+ if not state.get("disabled", False):
152
+ available_credentials.append(credential_name)
153
+ except Exception as e2:
154
+ log.warning(f"Failed to check state for credential {credential_name}: {e2}")
155
+
156
+ # 更新凭证列表
157
+ old_credentials = set(self._credential_files)
158
+ new_credentials = set(available_credentials)
159
+
160
+ if old_credentials != new_credentials:
161
+ # 记录变化(只在非初始状态时记录)
162
+ is_initial_load = len(old_credentials) == 0
163
+ added = new_credentials - old_credentials
164
+ removed = old_credentials - new_credentials
165
+
166
+ self._credential_files = available_credentials
167
+
168
+ # 初始加载时只记录调试信息,运行时变化才记录INFO
169
+ if not is_initial_load:
170
+ if added:
171
+ log.info(f"发现新的可用凭证: {list(added)}")
172
+ if removed:
173
+ log.info(f"移除不可用凭证: {list(removed)}")
174
+ else:
175
+ # 初始加载时只记录调试信息
176
+ if available_credentials:
177
+ log.debug(f"初始加载发现 {len(available_credentials)} 个可用凭证")
178
+
179
+ # 重置当前索引如果需要
180
+ if self._current_credential_index >= len(self._credential_files):
181
+ self._current_credential_index = 0
182
+
183
+ if not self._credential_files:
184
+ log.warning("No available credential files found")
185
+ else:
186
+ log.debug(f"Available credentials: {len(self._credential_files)} files")
187
+
188
+ except Exception as e:
189
+ log.error(f"Failed to discover credentials: {e}")
190
+
191
+ async def _load_current_credential(self) -> Optional[Tuple[str, Dict[str, Any]]]:
192
+ """加载当前选中的凭证数据,包含token过期检测和自动刷新"""
193
+ if not self._credential_files:
194
+ return None
195
+
196
+ try:
197
+ current_file = self._credential_files[self._current_credential_index]
198
+
199
+ # 从存储适配器加载凭证数据
200
+ credential_data = await self._storage_adapter.get_credential(current_file)
201
+ if not credential_data:
202
+ log.error(f"Failed to load credential data for: {current_file}")
203
+ return None
204
+
205
+ # 检查refresh_token
206
+ if "refresh_token" not in credential_data or not credential_data["refresh_token"]:
207
+ log.warning(f"No refresh token in {current_file}")
208
+ return None
209
+
210
+ # Auto-add 'type' field if missing but has required OAuth fields
211
+ if 'type' not in credential_data and all(key in credential_data for key in ['client_id', 'refresh_token']):
212
+ credential_data['type'] = 'authorized_user'
213
+ log.debug(f"Auto-added 'type' field to credential from file {current_file}")
214
+
215
+ # 兼容不同的token字段格式
216
+ if "access_token" in credential_data and "token" not in credential_data:
217
+ credential_data["token"] = credential_data["access_token"]
218
+ if "scope" in credential_data and "scopes" not in credential_data:
219
+ credential_data["scopes"] = credential_data["scope"].split()
220
+
221
+ # token过期检测和刷新
222
+ should_refresh = await self._should_refresh_token(credential_data)
223
+
224
+ if should_refresh:
225
+ log.debug(f"Token需要刷新 - 文件: {current_file}")
226
+ refreshed_data = await self._refresh_token(credential_data, current_file)
227
+ if refreshed_data:
228
+ credential_data = refreshed_data
229
+ log.debug(f"Token刷新成功: {current_file}")
230
+ else:
231
+ log.error(f"Token刷新失败: {current_file}")
232
+ return None
233
+
234
+ # 加载状态信息
235
+ state_data = await self._storage_adapter.get_credential_state(current_file)
236
+
237
+ # 缓存当前凭证信息
238
+ self._current_credential_file = current_file
239
+ self._current_credential_data = credential_data
240
+ self._current_credential_state = state_data
241
+
242
+ return current_file, credential_data
243
+
244
+ except Exception as e:
245
+ log.error(f"Error loading current credential: {e}")
246
+ return None
247
+
248
+ async def get_valid_credential(self) -> Optional[Tuple[str, Dict[str, Any]]]:
249
+ """获取有效的凭证,自动处理轮换和失效凭证切换"""
250
+ async with self._operation_lock:
251
+ if not self._credential_files:
252
+ await self._discover_credentials()
253
+ if not self._credential_files:
254
+ return None
255
+
256
+ # 检查是否需要轮换
257
+ if await self._should_rotate():
258
+ await self._rotate_credential()
259
+
260
+ # 尝试获取有效凭证,如果失败则自动切换
261
+ max_attempts = len(self._credential_files) # 最多尝试所有凭证
262
+
263
+ for attempt in range(max_attempts):
264
+ try:
265
+ # 加载当前凭证
266
+ result = await self._load_current_credential()
267
+ if result:
268
+ return result
269
+
270
+ # 当前凭证加载失败,标记为失效并切换到下一个
271
+ current_file = self._credential_files[self._current_credential_index] if self._credential_files else None
272
+ if current_file:
273
+ log.warning(f"凭证失效,自动禁用并切换: {current_file}")
274
+ await self.set_cred_disabled(current_file, True)
275
+
276
+ # 重新发现可用凭证(排除刚禁用的)
277
+ await self._discover_credentials()
278
+ if not self._credential_files:
279
+ log.error("没有可用的凭证")
280
+ return None
281
+
282
+ # 重置索引到第一个可用凭证
283
+ self._current_credential_index = 0
284
+ log.info(f"切换到下一个可用凭证 (索引: {self._current_credential_index})")
285
+ else:
286
+ log.error("无法获取当前凭证文件名")
287
+ break
288
+
289
+ except Exception as e:
290
+ log.error(f"获取凭证时发生异常 (尝试 {attempt + 1}/{max_attempts}): {e}")
291
+ if attempt < max_attempts - 1:
292
+ # 切换到下一个凭证继续尝试
293
+ await self._rotate_credential()
294
+ continue
295
+
296
+ log.error(f"所有 {max_attempts} 个凭证都尝试失败")
297
+ return None
298
+
299
+ async def _should_rotate(self) -> bool:
300
+ """检查是否需要轮换凭证"""
301
+ if not self._credential_files or len(self._credential_files) <= 1:
302
+ return False
303
+
304
+ current_calls_per_rotation = await get_calls_per_rotation()
305
+ return self._call_count >= current_calls_per_rotation
306
+
307
+ async def _rotate_credential(self):
308
+ """轮换到下一个凭证"""
309
+ if len(self._credential_files) <= 1:
310
+ return
311
+
312
+ self._current_credential_index = (self._current_credential_index + 1) % len(self._credential_files)
313
+ self._call_count = 0
314
+
315
+ log.info(f"Rotated to credential index {self._current_credential_index}")
316
+
317
+ async def force_rotate_credential(self):
318
+ """强制轮换到下一个凭证(用于429错误处理)"""
319
+ async with self._operation_lock:
320
+ if len(self._credential_files) <= 1:
321
+ log.warning("Only one credential available, cannot rotate")
322
+ return
323
+
324
+ await self._rotate_credential()
325
+ log.info("Forced credential rotation due to rate limit")
326
+
327
+ def increment_call_count(self):
328
+ """增加调用计数"""
329
+ self._call_count += 1
330
+
331
+ async def update_credential_state(self, credential_name: str, state_updates: Dict[str, Any]):
332
+ """更新凭证状态"""
333
+ try:
334
+ # 直接通过存储适配器更新状态
335
+ success = await self._storage_adapter.update_credential_state(credential_name, state_updates)
336
+
337
+ # 如果是当前使用的凭证,更新缓存
338
+ if credential_name == self._current_credential_file:
339
+ self._current_credential_state.update(state_updates)
340
+
341
+ if success:
342
+ log.debug(f"Updated credential state: {credential_name}")
343
+ else:
344
+ log.warning(f"Failed to update credential state: {credential_name}")
345
+
346
+ return success
347
+
348
+ except Exception as e:
349
+ log.error(f"Error updating credential state {credential_name}: {e}")
350
+ return False
351
+
352
+ async def set_cred_disabled(self, credential_name: str, disabled: bool):
353
+ """设置凭证的启用/禁用状态"""
354
+ try:
355
+ state_updates = {"disabled": disabled}
356
+ success = await self.update_credential_state(credential_name, state_updates)
357
+
358
+ if success:
359
+ # 如果禁用了当前正在使用的凭证,需要重新发现可用凭证
360
+ if disabled and credential_name == self._current_credential_file:
361
+ await self._discover_credentials()
362
+ if self._credential_files:
363
+ await self._rotate_credential()
364
+
365
+ action = "disabled" if disabled else "enabled"
366
+ log.info(f"Credential {action}: {credential_name}")
367
+
368
+ return success
369
+
370
+ except Exception as e:
371
+ log.error(f"Error setting credential disabled state {credential_name}: {e}")
372
+ return False
373
+
374
+ async def get_creds_status(self) -> Dict[str, Dict[str, Any]]:
375
+ """获取所有凭证的状态"""
376
+ try:
377
+ # 从存储适配器获取所有状态
378
+ all_states = await self._storage_adapter.get_all_credential_states()
379
+ return all_states
380
+
381
+ except Exception as e:
382
+ log.error(f"Error getting credential statuses: {e}")
383
+ return {}
384
+
385
+ async def get_or_fetch_user_email(self, credential_name: str) -> Optional[str]:
386
+ """获取或获取用户邮箱地址"""
387
+ try:
388
+ # 首先检查缓存的状态
389
+ state = await self._storage_adapter.get_credential_state(credential_name)
390
+ cached_email = state.get("user_email")
391
+
392
+ if cached_email:
393
+ return cached_email
394
+
395
+ # 如果没有缓存,从凭证数据获取
396
+ credential_data = await self._storage_adapter.get_credential(credential_name)
397
+ if not credential_data:
398
+ return None
399
+
400
+ # 尝试获取邮箱
401
+ email = await fetch_user_email_from_file(credential_data)
402
+
403
+ if email:
404
+ # 缓存邮箱地址
405
+ await self.update_credential_state(credential_name, {"user_email": email})
406
+ return email
407
+
408
+ return None
409
+
410
+ except Exception as e:
411
+ log.error(f"Error fetching user email for {credential_name}: {e}")
412
+ return None
413
+
414
+ async def record_api_call_result(self, credential_name: str, success: bool, error_code: Optional[int] = None):
415
+ """记录API调用结果"""
416
+ try:
417
+ state_updates = {}
418
+
419
+ if success:
420
+ state_updates["last_success"] = time.time()
421
+ # 清除错误码(如果之前有的话)
422
+ state_updates["error_codes"] = []
423
+ elif error_code:
424
+ # 记录错误码
425
+ current_state = await self._storage_adapter.get_credential_state(credential_name)
426
+ error_codes = current_state.get("error_codes", [])
427
+
428
+ if error_code not in error_codes:
429
+ error_codes.append(error_code)
430
+ # 限制错误码列表长度
431
+ if len(error_codes) > 10:
432
+ error_codes = error_codes[-10:]
433
+
434
+ state_updates["error_codes"] = error_codes
435
+
436
+ if state_updates:
437
+ await self.update_credential_state(credential_name, state_updates)
438
+
439
+ except Exception as e:
440
+ log.error(f"Error recording API call result for {credential_name}: {e}")
441
+
442
+ # 原子操作支持
443
+ @asynccontextmanager
444
+ async def _atomic_operation(self, operation_name: str):
445
+ """原子操作上下文管理器"""
446
+ async with self._atomic_lock:
447
+ self._atomic_counter += 1
448
+ operation_id = self._atomic_counter
449
+ log.debug(f"开始原子操作[{operation_id}]: {operation_name}")
450
+
451
+ try:
452
+ yield operation_id
453
+ log.debug(f"完成原子操作[{operation_id}]: {operation_name}")
454
+ except Exception as e:
455
+ log.error(f"原子操作[{operation_id}]失败: {operation_name} - {e}")
456
+ raise
457
+
458
+ async def _should_refresh_token(self, credential_data: Dict[str, Any]) -> bool:
459
+ """检查token是否需要刷新"""
460
+ try:
461
+ # 如果没有access_token或过期时间,需要刷新
462
+ if not credential_data.get("access_token") and not credential_data.get("token"):
463
+ log.debug("没有access_token,需要刷新")
464
+ return True
465
+
466
+ expiry_str = credential_data.get("expiry")
467
+ if not expiry_str:
468
+ log.debug("没有过期时间,需要刷新")
469
+ return True
470
+
471
+ # 解析过期时间
472
+ try:
473
+ if isinstance(expiry_str, str):
474
+ if "+" in expiry_str:
475
+ file_expiry = datetime.fromisoformat(expiry_str)
476
+ elif expiry_str.endswith("Z"):
477
+ file_expiry = datetime.fromisoformat(expiry_str.replace('Z', '+00:00'))
478
+ else:
479
+ file_expiry = datetime.fromisoformat(expiry_str)
480
+ else:
481
+ log.debug("过期时间格式无效,需要刷新")
482
+ return True
483
+
484
+ # 确保时区信息
485
+ if file_expiry.tzinfo is None:
486
+ file_expiry = file_expiry.replace(tzinfo=timezone.utc)
487
+
488
+ # 检查是否还有至少5分钟有效期
489
+ now = datetime.now(timezone.utc)
490
+ time_left = (file_expiry - now).total_seconds()
491
+
492
+ log.debug(f"Token剩余时间: {int(time_left/60)}分钟")
493
+
494
+ if time_left > 300: # 5分钟缓冲
495
+ return False
496
+ else:
497
+ log.debug(f"Token即将过期(剩余{int(time_left/60)}分钟),需要刷新")
498
+ return True
499
+
500
+ except Exception as e:
501
+ log.warning(f"解析过期时间失败: {e},需要刷新")
502
+ return True
503
+
504
+ except Exception as e:
505
+ log.error(f"检查token过期时出错: {e}")
506
+ return True
507
+
508
+ async def _refresh_token(self, credential_data: Dict[str, Any], filename: str) -> Optional[Dict[str, Any]]:
509
+ """刷新token并更新存储"""
510
+ try:
511
+ # 创建Credentials对象
512
+ creds = Credentials.from_dict(credential_data)
513
+
514
+ # 检查是否可以刷新
515
+ if not creds.refresh_token:
516
+ log.error(f"没有refresh_token,无法刷新: {filename}")
517
+ return None
518
+
519
+ # 刷新token
520
+ log.debug(f"正在刷新token: {filename}")
521
+ await creds.refresh()
522
+
523
+ # 更新凭证数据
524
+ if creds.access_token:
525
+ credential_data["access_token"] = creds.access_token
526
+ # 保持兼容性
527
+ credential_data["token"] = creds.access_token
528
+
529
+ if creds.expires_at:
530
+ credential_data["expiry"] = creds.expires_at.isoformat()
531
+
532
+ # 保存到存储
533
+ await self._storage_adapter.store_credential(filename, credential_data)
534
+ log.info(f"Token刷新成功并已保存: {filename}")
535
+
536
+ return credential_data
537
+
538
+ except Exception as e:
539
+ error_msg = str(e)
540
+ log.error(f"Token刷新失败 {filename}: {error_msg}")
541
+
542
+ # 检查是否是凭证永久失效的错误
543
+ is_permanent_failure = self._is_permanent_refresh_failure(error_msg)
544
+
545
+ if is_permanent_failure:
546
+ log.warning(f"检测到凭证永久失效: {filename}")
547
+ # 记录失效状态,但不在这里禁用凭证,让上层调用者处理
548
+ await self.record_api_call_result(filename, False, 400)
549
+
550
+ return None
551
+
552
+ def _is_permanent_refresh_failure(self, error_msg: str) -> bool:
553
+ """判断是否是凭证永久失效的错误"""
554
+ # 常见的永久失效错误模式
555
+ permanent_error_patterns = [
556
+ "400 Bad Request",
557
+ "invalid_grant",
558
+ "refresh_token_expired",
559
+ "invalid_refresh_token",
560
+ "unauthorized_client",
561
+ "access_denied"
562
+ ]
563
+
564
+ error_msg_lower = error_msg.lower()
565
+ for pattern in permanent_error_patterns:
566
+ if pattern.lower() in error_msg_lower:
567
+ return True
568
+
569
+ return False
570
+
571
+ # 兼容性方法 - 保持与现有代码的接口兼容
572
+ async def _update_token_in_file(self, file_path: str, new_token: str, expires_at=None):
573
+ """更新凭证令牌(兼容性方法)"""
574
+ try:
575
+ credential_data = await self._storage_adapter.get_credential(file_path)
576
+ if not credential_data:
577
+ log.error(f"Credential not found for token update: {file_path}")
578
+ return False
579
+
580
+ # 更新令牌数据
581
+ credential_data["token"] = new_token
582
+ if expires_at:
583
+ credential_data["expiry"] = expires_at.isoformat() if hasattr(expires_at, 'isoformat') else expires_at
584
+
585
+ # 保存更新后的凭证
586
+ success = await self._storage_adapter.store_credential(file_path, credential_data)
587
+
588
+ if success:
589
+ log.debug(f"Token updated for credential: {file_path}")
590
+ else:
591
+ log.error(f"Failed to update token for credential: {file_path}")
592
+
593
+ return success
594
+
595
+ except Exception as e:
596
+ log.error(f"Error updating token for {file_path}: {e}")
597
+ return False
598
+
599
+
600
+ # 全局实例管理(保持兼容性)
601
+ _credential_manager: Optional[CredentialManager] = None
602
+
603
+ async def get_credential_manager() -> CredentialManager:
604
+ """获取全局凭证管理器实例"""
605
+ global _credential_manager
606
+
607
+ if _credential_manager is None:
608
+ _credential_manager = CredentialManager()
609
+ await _credential_manager.initialize()
610
+
611
+ return _credential_manager
src/format_detector.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Format detection utilities for supporting both OpenAI and Gemini request formats
3
+ """
4
+ from typing import Dict, Any
5
+
6
+ from log import log
7
+
8
+ def detect_request_format(data: Dict[str, Any]) -> str:
9
+ """
10
+ Detect whether the request is in OpenAI or Gemini format.
11
+
12
+ Returns:
13
+ "openai" or "gemini"
14
+ """
15
+ # OpenAI format indicators:
16
+ # - Has "messages" field with array of {role, content} objects
17
+ # - Role values are "system", "user", "assistant"
18
+ if "messages" in data and isinstance(data["messages"], list):
19
+ if data["messages"] and isinstance(data["messages"][0], dict):
20
+ # Check for OpenAI role values
21
+ first_role = data["messages"][0].get("role", "")
22
+ if first_role in ["system", "user", "assistant"]:
23
+ return "openai"
24
+
25
+ # Gemini format indicators:
26
+ # - Has "contents" field with array of {role, parts} objects
27
+ # - Role values are "user", "model"
28
+ # - May have "systemInstruction" field
29
+ if "contents" in data and isinstance(data["contents"], list):
30
+ if data["contents"] and isinstance(data["contents"][0], dict):
31
+ # Check for Gemini structure
32
+ if "parts" in data["contents"][0]:
33
+ return "gemini"
34
+
35
+ # Additional Gemini indicators
36
+ if "systemInstruction" in data or "generationConfig" in data:
37
+ return "gemini"
38
+
39
+ # Default to OpenAI if unclear (for backwards compatibility)
40
+ log.debug(f"Unable to definitively detect format, defaulting to OpenAI. Keys present: {list(data.keys())}")
41
+ return "openai"
42
+
43
+ def gemini_request_to_openai(gemini_request: Dict[str, Any]) -> Dict[str, Any]:
44
+ """
45
+ Convert a Gemini format request to OpenAI format.
46
+
47
+ Args:
48
+ gemini_request: Request in Gemini API format
49
+
50
+ Returns:
51
+ Dictionary in OpenAI API format
52
+ """
53
+ openai_request = {
54
+ "model": gemini_request.get("model", "gemini-2.5-pro"),
55
+ "messages": []
56
+ }
57
+
58
+ # Convert system instruction if present
59
+ if "systemInstruction" in gemini_request:
60
+ system_content = ""
61
+ if isinstance(gemini_request["systemInstruction"], dict):
62
+ parts = gemini_request["systemInstruction"].get("parts", [])
63
+ for part in parts:
64
+ if "text" in part:
65
+ system_content += part["text"]
66
+ elif isinstance(gemini_request["systemInstruction"], str):
67
+ system_content = gemini_request["systemInstruction"]
68
+
69
+ if system_content:
70
+ openai_request["messages"].append({
71
+ "role": "system",
72
+ "content": system_content
73
+ })
74
+
75
+ # Convert contents to messages
76
+ contents = gemini_request.get("contents", [])
77
+ for content in contents:
78
+ role = content.get("role", "user")
79
+ # Map Gemini roles to OpenAI roles
80
+ if role == "model":
81
+ role = "assistant"
82
+
83
+ # Convert parts to content
84
+ parts = content.get("parts", [])
85
+ if len(parts) == 1 and "text" in parts[0]:
86
+ # Simple text message
87
+ openai_request["messages"].append({
88
+ "role": role,
89
+ "content": parts[0]["text"]
90
+ })
91
+ elif len(parts) > 0:
92
+ # Multi-part message (could include images)
93
+ content_parts = []
94
+ for part in parts:
95
+ if "text" in part:
96
+ content_parts.append({
97
+ "type": "text",
98
+ "text": part["text"]
99
+ })
100
+ elif "inlineData" in part:
101
+ # Convert Gemini inline data to OpenAI image format
102
+ inline_data = part["inlineData"]
103
+ mime_type = inline_data.get("mimeType", "image/jpeg")
104
+ data = inline_data.get("data", "")
105
+ content_parts.append({
106
+ "type": "image_url",
107
+ "image_url": {
108
+ "url": f"data:{mime_type};base64,{data}"
109
+ }
110
+ })
111
+
112
+ if content_parts:
113
+ # If only one text part, use simple string format
114
+ if len(content_parts) == 1 and content_parts[0]["type"] == "text":
115
+ openai_request["messages"].append({
116
+ "role": role,
117
+ "content": content_parts[0]["text"]
118
+ })
119
+ else:
120
+ openai_request["messages"].append({
121
+ "role": role,
122
+ "content": content_parts
123
+ })
124
+
125
+ # Convert generation config if present
126
+ if "generationConfig" in gemini_request:
127
+ config = gemini_request["generationConfig"]
128
+ if "temperature" in config:
129
+ openai_request["temperature"] = config["temperature"]
130
+ if "topP" in config:
131
+ openai_request["top_p"] = config["topP"]
132
+ if "topK" in config:
133
+ openai_request["top_k"] = config["topK"]
134
+ if "maxOutputTokens" in config:
135
+ openai_request["max_tokens"] = config["maxOutputTokens"]
136
+ if "stopSequences" in config:
137
+ openai_request["stop"] = config["stopSequences"]
138
+ if "frequencyPenalty" in config:
139
+ openai_request["frequency_penalty"] = config["frequencyPenalty"]
140
+ if "presencePenalty" in config:
141
+ openai_request["presence_penalty"] = config["presencePenalty"]
142
+ if "candidateCount" in config:
143
+ openai_request["n"] = config["candidateCount"]
144
+ if "seed" in config:
145
+ openai_request["seed"] = config["seed"]
146
+
147
+ # Preserve stream setting if present
148
+ if "stream" in gemini_request:
149
+ openai_request["stream"] = gemini_request["stream"]
150
+
151
+ return openai_request
152
+
153
+ def validate_and_normalize_request(data: Dict[str, Any]) -> Dict[str, Any]:
154
+ """
155
+ Validate and normalize the request to OpenAI format.
156
+ Automatically detects format and converts if necessary.
157
+
158
+ Args:
159
+ data: Raw request data
160
+
161
+ Returns:
162
+ Normalized request in OpenAI format
163
+ """
164
+ format_type = detect_request_format(data)
165
+ log.info(f"Detected request format: {format_type}")
166
+
167
+ if format_type == "gemini":
168
+ # Convert Gemini format to OpenAI format
169
+ return gemini_request_to_openai(data)
170
+ else:
171
+ # Already in OpenAI format
172
+ return data
src/gemini_router.py ADDED
@@ -0,0 +1,583 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gemini Router - Handles native Gemini format API requests
3
+ 处理原生Gemini格式请求的路由模块
4
+ """
5
+ import asyncio
6
+ import json
7
+ from contextlib import asynccontextmanager
8
+ from typing import Optional
9
+
10
+ from fastapi import APIRouter, HTTPException, Depends, Request, Path, Query, status, Header
11
+ from fastapi.responses import JSONResponse, StreamingResponse
12
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
+
14
+ from config import get_available_models, is_fake_streaming_model, is_anti_truncation_model, get_base_model_from_feature_model, get_anti_truncation_max_attempts, get_base_model_name
15
+ from log import log
16
+ from .anti_truncation import apply_anti_truncation_to_stream
17
+ from .credential_manager import CredentialManager
18
+ from .google_chat_api import send_gemini_request, build_gemini_payload_from_native
19
+ from .openai_transfer import _extract_content_and_reasoning
20
+ from .task_manager import create_managed_task
21
+ # 创建路由器
22
+ router = APIRouter()
23
+ security = HTTPBearer()
24
+
25
+ # 全局凭证管理器实例
26
+ credential_manager = None
27
+
28
+ @asynccontextmanager
29
+ async def get_credential_manager():
30
+ """获取全局凭证管理器实例"""
31
+ global credential_manager
32
+ if not credential_manager:
33
+ credential_manager = CredentialManager()
34
+ await credential_manager.initialize()
35
+ yield credential_manager
36
+
37
+ async def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
38
+ """验证用户密码(Bearer Token方式)"""
39
+ from config import get_api_password
40
+ password = await get_api_password()
41
+ token = credentials.credentials
42
+ if token != password:
43
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="密码错误")
44
+ return token
45
+
46
+ async def authenticate_gemini_flexible(
47
+ request: Request,
48
+ x_goog_api_key: Optional[str] = Header(None, alias="x-goog-api-key"),
49
+ key: Optional[str] = Query(None),
50
+ credentials: Optional[HTTPAuthorizationCredentials] = Depends(lambda: None)
51
+ ) -> str:
52
+ """灵活验证:支持x-goog-api-key头部、URL参数key或Authorization Bearer"""
53
+ from config import get_api_password
54
+ password = await get_api_password()
55
+
56
+ # 尝试从URL参数key获取(Google官方标准方式)
57
+ if key:
58
+ log.debug(f"Using URL parameter key authentication")
59
+ if key == password:
60
+ return key
61
+
62
+ # 尝试从Authorization头获取(兼容旧方式)
63
+ auth_header = request.headers.get("authorization")
64
+ if auth_header and auth_header.startswith("Bearer "):
65
+ token = auth_header[7:] # 移除 "Bearer " 前缀
66
+ log.debug(f"Using Bearer token authentication")
67
+ if token == password:
68
+ return token
69
+
70
+ # 尝试从x-goog-api-key头获取(新标准方式)
71
+ if x_goog_api_key:
72
+ log.debug(f"Using x-goog-api-key authentication")
73
+ if x_goog_api_key == password:
74
+ return x_goog_api_key
75
+
76
+ log.error(f"Authentication failed. Headers: {dict(request.headers)}, Query params: key={key}")
77
+ raise HTTPException(
78
+ status_code=status.HTTP_400_BAD_REQUEST,
79
+ detail="Missing or invalid authentication. Use 'key' URL parameter, 'x-goog-api-key' header, or 'Authorization: Bearer <token>'"
80
+ )
81
+
82
+ @router.get("/v1/v1beta/models")
83
+ @router.get("/v1/v1/models")
84
+ @router.get("/v1beta/models")
85
+ @router.get("/v1/models")
86
+ async def list_gemini_models():
87
+ """返回Gemini格式的模型列表"""
88
+ models = get_available_models("gemini")
89
+
90
+ # 构建符合Gemini API格式的模型列表
91
+ gemini_models = []
92
+ for model_name in models:
93
+ # 获取基础模型名
94
+ base_model = get_base_model_from_feature_model(model_name)
95
+
96
+ model_info = {
97
+ "name": f"models/{model_name}",
98
+ "baseModelId": base_model,
99
+ "version": "001",
100
+ "displayName": model_name,
101
+ "description": f"Gemini {base_model} model",
102
+ "inputTokenLimit": 1000000,
103
+ "outputTokenLimit": 8192,
104
+ "supportedGenerationMethods": ["generateContent", "streamGenerateContent"],
105
+ "temperature": 1.0,
106
+ "maxTemperature": 2.0,
107
+ "topP": 0.95,
108
+ "topK": 64
109
+ }
110
+ gemini_models.append(model_info)
111
+
112
+ return JSONResponse(content={
113
+ "models": gemini_models
114
+ })
115
+
116
+ @router.post("/v1/v1beta/models/{model:path}:generateContent")
117
+ @router.post("/v1/v1/models/{model:path}:generateContent")
118
+ @router.post("/v1beta/models/{model:path}:generateContent")
119
+ @router.post("/v1/models/{model:path}:generateContent")
120
+ async def generate_content(
121
+ model: str = Path(..., description="Model name"),
122
+ request: Request = None,
123
+ api_key: str = Depends(authenticate_gemini_flexible)
124
+ ):
125
+ """处理Gemini格式的内容生成请求(非流式)"""
126
+
127
+
128
+ # 获取原始请求数据
129
+ try:
130
+ request_data = await request.json()
131
+ except Exception as e:
132
+ log.error(f"Failed to parse JSON request: {e}")
133
+ raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
134
+
135
+ # 验证必要字段
136
+ if "contents" not in request_data or not request_data["contents"]:
137
+ raise HTTPException(status_code=400, detail="Missing required field: contents")
138
+
139
+ # 请求预处理:限制参数
140
+ if "generationConfig" in request_data and request_data["generationConfig"]:
141
+ generation_config = request_data["generationConfig"]
142
+
143
+ # 限制max_tokens (在Gemini中叫maxOutputTokens)
144
+ if "maxOutputTokens" in generation_config and generation_config["maxOutputTokens"] is not None:
145
+ if generation_config["maxOutputTokens"] > 65535:
146
+ generation_config["maxOutputTokens"] = 65535
147
+
148
+ # 覆写 top_k 为 64 (在Gemini中叫topK)
149
+ generation_config["topK"] = 64
150
+ else:
151
+ # 如果没有generationConfig,创建一个并设置topK
152
+ request_data["generationConfig"] = {"topK": 64}
153
+
154
+ # 处理模型名称和功能检测
155
+ use_anti_truncation = is_anti_truncation_model(model)
156
+
157
+ # 获取基础模型名
158
+ real_model = get_base_model_from_feature_model(model)
159
+
160
+ # 对于假流式模型,如果是流式端点才返回假流式响应
161
+ # 注意:这是generateContent端点,不应该触发假流式
162
+
163
+ # 对于抗截断模型的非流式请求,给出警告
164
+ if use_anti_truncation:
165
+ log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置")
166
+
167
+ # 健康检查
168
+ if (len(request_data["contents"]) == 1 and
169
+ request_data["contents"][0].get("role") == "user" and
170
+ request_data["contents"][0].get("parts", [{}])[0].get("text") == "Hi"):
171
+ return JSONResponse(content={
172
+ "candidates": [{
173
+ "content": {
174
+ "parts": [{"text": "gcli2api工作中"}],
175
+ "role": "model"
176
+ },
177
+ "finishReason": "STOP",
178
+ "index": 0
179
+ }]
180
+ })
181
+
182
+ # 获取凭证管理器
183
+ from src.credential_manager import get_credential_manager
184
+ cred_mgr = await get_credential_manager()
185
+
186
+ # 获取有效凭证
187
+ credential_result = await cred_mgr.get_valid_credential()
188
+ if not credential_result:
189
+ log.error("当前无可用凭证,请去控制台获取")
190
+ raise HTTPException(status_code=500, detail="当前无可用凭证,请去控制台获取")
191
+
192
+ # 增加调用计数
193
+ cred_mgr.increment_call_count()
194
+
195
+ # 构建Google API payload
196
+ try:
197
+ api_payload = build_gemini_payload_from_native(request_data, real_model)
198
+ except Exception as e:
199
+ log.error(f"Gemini payload build failed: {e}")
200
+ raise HTTPException(status_code=500, detail="Request processing failed")
201
+
202
+ # 发送请求(429重试已在google_api_client中处理)
203
+ response = await send_gemini_request(api_payload, False, cred_mgr)
204
+
205
+ # 处理响应
206
+ try:
207
+ if hasattr(response, 'body'):
208
+ response_data = json.loads(response.body.decode() if isinstance(response.body, bytes) else response.body)
209
+ elif hasattr(response, 'content'):
210
+ response_data = json.loads(response.content.decode() if isinstance(response.content, bytes) else response.content)
211
+ else:
212
+ response_data = json.loads(str(response))
213
+
214
+ return JSONResponse(content=response_data)
215
+
216
+ except Exception as e:
217
+ log.error(f"Response processing failed: {e}")
218
+ # 返回原始响应
219
+ if hasattr(response, 'content'):
220
+ return JSONResponse(content=json.loads(response.content))
221
+ else:
222
+ raise HTTPException(status_code=500, detail="Response processing failed")
223
+
224
+ @router.post("/v1/v1beta/models/{model:path}:streamGenerateContent")
225
+ @router.post("/v1/v1/models/{model:path}:streamGenerateContent")
226
+ @router.post("/v1beta/models/{model:path}:streamGenerateContent")
227
+ @router.post("/v1/models/{model:path}:streamGenerateContent")
228
+ async def stream_generate_content(
229
+ model: str = Path(..., description="Model name"),
230
+ request: Request = None,
231
+ api_key: str = Depends(authenticate_gemini_flexible)
232
+ ):
233
+ """处理Gemini格式的流式内容生成请求"""
234
+ log.debug(f"Stream request received for model: {model}")
235
+ log.debug(f"Request headers: {dict(request.headers)}")
236
+ log.debug(f"API key received: {api_key[:10] if api_key else None}...")
237
+ try:
238
+ body = await request.body()
239
+ log.debug(f"request body: {body.decode() if isinstance(body, bytes) else body}")
240
+ except Exception as e:
241
+ log.error(f"Failed to read request body: {e}")
242
+
243
+
244
+ # 获取原始请求数据
245
+ try:
246
+ request_data = await request.json()
247
+ except Exception as e:
248
+ log.error(f"Failed to parse JSON request: {e}")
249
+ raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
250
+
251
+ # 验���必要字段
252
+ if "contents" not in request_data or not request_data["contents"]:
253
+ raise HTTPException(status_code=400, detail="Missing required field: contents")
254
+
255
+ # 请求预处理:限制参数
256
+ if "generationConfig" in request_data and request_data["generationConfig"]:
257
+ generation_config = request_data["generationConfig"]
258
+
259
+ # 限制max_tokens (在Gemini中叫maxOutputTokens)
260
+ if "maxOutputTokens" in generation_config and generation_config["maxOutputTokens"] is not None:
261
+ if generation_config["maxOutputTokens"] > 65535:
262
+ generation_config["maxOutputTokens"] = 65535
263
+
264
+ # 覆写 top_k 为 64 (在Gemini中叫topK)
265
+ generation_config["topK"] = 64
266
+ else:
267
+ # 如果没有generationConfig,创建一个并设置topK
268
+ request_data["generationConfig"] = {"topK": 64}
269
+
270
+ # 处理模型名称和功能检测
271
+ use_fake_streaming = is_fake_streaming_model(model)
272
+ use_anti_truncation = is_anti_truncation_model(model)
273
+
274
+ # 获取基础模型名
275
+ real_model = get_base_model_from_feature_model(model)
276
+
277
+ # 对于假流式模型,返回假流式响应
278
+ if use_fake_streaming:
279
+ return await fake_stream_response_gemini(request_data, real_model)
280
+
281
+ # 获取凭证管理器
282
+ from src.credential_manager import get_credential_manager
283
+ cred_mgr = await get_credential_manager()
284
+
285
+ # 获取有效凭证
286
+ credential_result = await cred_mgr.get_valid_credential()
287
+ if not credential_result:
288
+ log.error("当前无可用凭证,请去控制台获取")
289
+ raise HTTPException(status_code=500, detail="当前无可用凭证,请去控制台获取")
290
+
291
+ # 增加调用计数
292
+ cred_mgr.increment_call_count()
293
+
294
+ # 构建Google API payload
295
+ try:
296
+ api_payload = build_gemini_payload_from_native(request_data, real_model)
297
+ except Exception as e:
298
+ log.error(f"Gemini payload build failed: {e}")
299
+ raise HTTPException(status_code=500, detail="Request processing failed")
300
+
301
+ # 处理抗截断功能(仅流式传输时有效)
302
+ if use_anti_truncation:
303
+ log.info("启用流式抗截断功能")
304
+ # 使用流式抗截断处理器
305
+ max_attempts = await get_anti_truncation_max_attempts()
306
+ return await apply_anti_truncation_to_stream(
307
+ lambda payload: send_gemini_request(payload, True, cred_mgr),
308
+ api_payload,
309
+ max_attempts
310
+ )
311
+
312
+ # 常规流式请求(429重试已在google_api_client中处理)
313
+ response = await send_gemini_request(api_payload, True, cred_mgr)
314
+
315
+ # 直接返回流式响应
316
+ return response
317
+
318
+ @router.post("/v1/v1beta/models/{model:path}:countTokens")
319
+ @router.post("/v1/v1/models/{model:path}:countTokens")
320
+ @router.post("/v1beta/models/{model:path}:countTokens")
321
+ @router.post("/v1/models/{model:path}:countTokens")
322
+ async def count_tokens(
323
+ request: Request = None,
324
+ api_key: str = Depends(authenticate_gemini_flexible)
325
+ ):
326
+ """模拟Gemini格式的token计数"""
327
+
328
+ try:
329
+ request_data = await request.json()
330
+ except Exception as e:
331
+ log.error(f"Failed to parse JSON request: {e}")
332
+ raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
333
+
334
+ # 简单的token计数模拟 - 基于文本长度估算
335
+ total_tokens = 0
336
+
337
+ # 如果有contents字段
338
+ if "contents" in request_data:
339
+ for content in request_data["contents"]:
340
+ if "parts" in content:
341
+ for part in content["parts"]:
342
+ if "text" in part:
343
+ # 简单估算:大约4字符=1token
344
+ text_length = len(part["text"])
345
+ total_tokens += max(1, text_length // 4)
346
+
347
+ # 如果有generateContentRequest字段
348
+ elif "generateContentRequest" in request_data:
349
+ gen_request = request_data["generateContentRequest"]
350
+ if "contents" in gen_request:
351
+ for content in gen_request["contents"]:
352
+ if "parts" in content:
353
+ for part in content["parts"]:
354
+ if "text" in part:
355
+ text_length = len(part["text"])
356
+ total_tokens += max(1, text_length // 4)
357
+
358
+ # 返回Gemini格式的响应
359
+ return JSONResponse(content={
360
+ "totalTokens": total_tokens
361
+ })
362
+
363
+ @router.get("/v1/v1beta/models/{model:path}")
364
+ @router.get("/v1/v1/models/{model:path}")
365
+ @router.get("/v1beta/models/{model:path}")
366
+ @router.get("/v1/models/{model:path}")
367
+ async def get_model_info(
368
+ model: str = Path(..., description="Model name"),
369
+ api_key: str = Depends(authenticate_gemini_flexible)
370
+ ):
371
+ """获取特定模型的信息"""
372
+
373
+ # 获取基础模型名称
374
+ base_model = get_base_model_name(model)
375
+
376
+ # 模拟模型信息
377
+ model_info = {
378
+ "name": f"models/{base_model}",
379
+ "baseModelId": base_model,
380
+ "version": "001",
381
+ "displayName": base_model,
382
+ "description": f"Gemini {base_model} model",
383
+ "inputTokenLimit": 128000,
384
+ "outputTokenLimit": 8192,
385
+ "supportedGenerationMethods": [
386
+ "generateContent",
387
+ "streamGenerateContent"
388
+ ],
389
+ "temperature": 1.0,
390
+ "maxTemperature": 2.0,
391
+ "topP": 0.95,
392
+ "topK": 64
393
+ }
394
+
395
+ return JSONResponse(content=model_info)
396
+
397
+ async def fake_stream_response_gemini(request_data: dict, model: str):
398
+ """处理Gemini格式的假流式响应"""
399
+
400
+ async def gemini_stream_generator():
401
+ try:
402
+ # 获取凭证管理器
403
+ from src.credential_manager import get_credential_manager
404
+ cred_mgr = await get_credential_manager()
405
+
406
+ # 获取有效凭证
407
+ credential_result = await cred_mgr.get_valid_credential()
408
+ if not credential_result:
409
+ log.error("当前无可用凭证,请去控制台获取")
410
+ error_chunk = {
411
+ "error": {
412
+ "message": "当前无凭证,请去控制台获取",
413
+ "type": "authentication_error",
414
+ "code": 500
415
+ }
416
+ }
417
+ yield f"data: {json.dumps(error_chunk)}\n\n".encode()
418
+ yield "data: [DONE]\n\n".encode()
419
+ return
420
+
421
+ # 增加调用计数
422
+ cred_mgr.increment_call_count()
423
+
424
+ # 构建Google API payload
425
+ try:
426
+ api_payload = build_gemini_payload_from_native(request_data, model)
427
+ except Exception as e:
428
+ log.error(f"Gemini payload build failed: {e}")
429
+ error_chunk = {
430
+ "error": {
431
+ "message": f"Request processing failed: {str(e)}",
432
+ "type": "api_error",
433
+ "code": 500
434
+ }
435
+ }
436
+ yield f"data: {json.dumps(error_chunk)}\n\n".encode()
437
+ yield "data: [DONE]\n\n".encode()
438
+ return
439
+
440
+ # 发送心跳
441
+ heartbeat = {
442
+ "candidates": [{
443
+ "content": {
444
+ "parts": [{"text": ""}],
445
+ "role": "model"
446
+ },
447
+ "finishReason": None,
448
+ "index": 0
449
+ }]
450
+ }
451
+ yield f"data: {json.dumps(heartbeat)}\n\n".encode()
452
+
453
+ # 异步发送实际请求
454
+ async def get_response():
455
+ return await send_gemini_request(api_payload, False, cred_mgr)
456
+
457
+ # 创建请求任务
458
+ response_task = create_managed_task(get_response(), name="gemini_fake_stream_request")
459
+
460
+ try:
461
+ # 每3秒发送一次心跳,直到收到响应
462
+ while not response_task.done():
463
+ await asyncio.sleep(3.0)
464
+ if not response_task.done():
465
+ yield f"data: {json.dumps(heartbeat)}\n\n".encode()
466
+
467
+ # 获取响应结果
468
+ response = await response_task
469
+
470
+ except asyncio.CancelledError:
471
+ # 取消任务并传播取消
472
+ response_task.cancel()
473
+ try:
474
+ await response_task
475
+ except asyncio.CancelledError:
476
+ pass
477
+ raise
478
+ except Exception as e:
479
+ # 取消任务并处理其他异常
480
+ response_task.cancel()
481
+ try:
482
+ await response_task
483
+ except asyncio.CancelledError:
484
+ pass
485
+ log.error(f"Fake streaming request failed: {e}")
486
+ raise
487
+
488
+ # 发送实际请求
489
+ # response 已在上面获取
490
+
491
+ # 处理结果
492
+ try:
493
+ if hasattr(response, 'body'):
494
+ response_data = json.loads(response.body.decode() if isinstance(response.body, bytes) else response.body)
495
+ elif hasattr(response, 'content'):
496
+ response_data = json.loads(response.content.decode() if isinstance(response.content, bytes) else response.content)
497
+ else:
498
+ response_data = json.loads(str(response))
499
+
500
+ log.debug(f"Gemini fake stream response data: {response_data}")
501
+
502
+ # 发送完整内容作为单个chunk,使用思维链分离
503
+ if "candidates" in response_data and response_data["candidates"]:
504
+ candidate = response_data["candidates"][0]
505
+ if "content" in candidate and "parts" in candidate["content"]:
506
+ parts = candidate["content"]["parts"]
507
+ content, reasoning_content = _extract_content_and_reasoning(parts)
508
+ log.debug(f"Gemini extracted content: {content}")
509
+ log.debug(f"Gemini extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...")
510
+
511
+ # 如果没有正常内容但有思维内容
512
+ if not content and reasoning_content:
513
+ log.warning(f"Gemini fake stream contains only thinking content: {reasoning_content[:100]}...")
514
+ content = "[模型正在思考中,请稍后再试或重新提问]"
515
+
516
+ if content:
517
+ # 构建包含分离内容的响应
518
+ parts_response = [{"text": content}]
519
+ if reasoning_content:
520
+ parts_response.append({"text": reasoning_content, "thought": True})
521
+
522
+ content_chunk = {
523
+ "candidates": [{
524
+ "content": {
525
+ "parts": parts_response,
526
+ "role": "model"
527
+ },
528
+ "finishReason": candidate.get("finishReason", "STOP"),
529
+ "index": 0
530
+ }]
531
+ }
532
+ yield f"data: {json.dumps(content_chunk)}\n\n".encode()
533
+ else:
534
+ log.warning(f"No content found in Gemini candidate: {candidate}")
535
+ # 提供默认回复
536
+ error_chunk = {
537
+ "candidates": [{
538
+ "content": {
539
+ "parts": [{"text": "[响应为空,请重新尝试]"}],
540
+ "role": "model"
541
+ },
542
+ "finishReason": "STOP",
543
+ "index": 0
544
+ }]
545
+ }
546
+ yield f"data: {json.dumps(error_chunk)}\n\n".encode()
547
+ else:
548
+ log.warning(f"No content/parts found in Gemini candidate: {candidate}")
549
+ # 返回原始响应
550
+ yield f"data: {json.dumps(response_data)}\n\n".encode()
551
+ else:
552
+ log.warning(f"No candidates found in Gemini response: {response_data}")
553
+ yield f"data: {json.dumps(response_data)}\n\n".encode()
554
+
555
+ except Exception as e:
556
+ log.error(f"Response parsing failed: {e}")
557
+ error_chunk = {
558
+ "candidates": [{
559
+ "content": {
560
+ "parts": [{"text": f"Response parsing error: {str(e)}"}],
561
+ "role": "model"
562
+ },
563
+ "finishReason": "ERROR",
564
+ "index": 0
565
+ }]
566
+ }
567
+ yield f"data: {json.dumps(error_chunk)}\n\n".encode()
568
+
569
+ yield "data: [DONE]\n\n".encode()
570
+
571
+ except Exception as e:
572
+ log.error(f"Fake streaming error: {e}")
573
+ error_chunk = {
574
+ "error": {
575
+ "message": f"Fake streaming error: {str(e)}",
576
+ "type": "api_error",
577
+ "code": 500
578
+ }
579
+ }
580
+ yield f"data: {json.dumps(error_chunk)}\n\n".encode()
581
+ yield "data: [DONE]\n\n".encode()
582
+
583
+ return StreamingResponse(gemini_stream_generator(), media_type="text/event-stream")
src/google_chat_api.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Google API Client - Handles all communication with Google's Gemini API.
3
+ This module is used by both OpenAI compatibility layer and native Gemini endpoints.
4
+ """
5
+ import asyncio
6
+ import gc
7
+ import json
8
+
9
+ from fastapi import Response
10
+ from fastapi.responses import StreamingResponse
11
+
12
+ from config import (
13
+ get_code_assist_endpoint,
14
+ DEFAULT_SAFETY_SETTINGS,
15
+ get_base_model_name,
16
+ get_thinking_budget,
17
+ should_include_thoughts,
18
+ is_search_model,
19
+ get_auto_ban_enabled,
20
+ get_auto_ban_error_codes,
21
+ get_retry_429_max_retries,
22
+ get_retry_429_enabled,
23
+ get_retry_429_interval
24
+ )
25
+ from .httpx_client import http_client, create_streaming_client_with_kwargs
26
+ from log import log
27
+ from .credential_manager import CredentialManager
28
+ from .usage_stats import record_successful_call
29
+ from .utils import get_user_agent
30
+
31
+ def _create_error_response(message: str, status_code: int = 500) -> Response:
32
+ """Create standardized error response."""
33
+ return Response(
34
+ content=json.dumps({
35
+ "error": {
36
+ "message": message,
37
+ "type": "api_error",
38
+ "code": status_code
39
+ }
40
+ }),
41
+ status_code=status_code,
42
+ media_type="application/json"
43
+ )
44
+
45
+ async def _handle_api_error(credential_manager: CredentialManager, status_code: int, response_content: str = ""):
46
+ """Handle API errors by rotating credentials when needed. Error recording should be done before calling this function."""
47
+ if status_code == 429 and credential_manager:
48
+ if response_content:
49
+ log.error(f"Google API returned status 429 - quota exhausted. Response details: {response_content[:500]}")
50
+ else:
51
+ log.error("Google API returned status 429 - quota exhausted, switching credentials")
52
+ await credential_manager.force_rotate_credential()
53
+
54
+ # 处理自动封禁的错误码
55
+ elif await get_auto_ban_enabled() and status_code in await get_auto_ban_error_codes() and credential_manager:
56
+ if response_content:
57
+ log.error(f"Google API returned status {status_code} - auto ban triggered. Response details: {response_content[:500]}")
58
+ else:
59
+ log.warning(f"Google API returned status {status_code} - auto ban triggered, rotating credentials")
60
+ await credential_manager.force_rotate_credential()
61
+
62
+ async def _prepare_request_headers_and_payload(payload: dict, credential_data: dict):
63
+ """Prepare request headers and final payload from credential data."""
64
+ # 尝试获取token,支持多种字段名
65
+ token = credential_data.get('token') or credential_data.get('access_token', '')
66
+
67
+ if not token:
68
+ raise Exception("凭证中没有找到有效的访问令牌(token或access_token字段)")
69
+
70
+ headers = {
71
+ "Authorization": f"Bearer {token}",
72
+ "Content-Type": "application/json",
73
+ "User-Agent": get_user_agent(),
74
+ }
75
+
76
+ # 直接使用凭证数据中的项目ID
77
+ project_id = credential_data.get("project_id", "")
78
+ if not project_id:
79
+ raise Exception("项目ID不存在于凭证数据中")
80
+
81
+ final_payload = {
82
+ "model": payload.get("model"),
83
+ "project": project_id,
84
+ "request": payload.get("request", {})
85
+ }
86
+
87
+ return headers, final_payload
88
+
89
+ async def send_gemini_request(payload: dict, is_streaming: bool = False, credential_manager: CredentialManager = None) -> Response:
90
+ """
91
+ Send a request to Google's Gemini API.
92
+
93
+ Args:
94
+ payload: The request payload in Gemini format
95
+ is_streaming: Whether this is a streaming request
96
+ credential_manager: CredentialManager instance
97
+
98
+ Returns:
99
+ FastAPI Response object
100
+ """
101
+ # 获取429重试配置
102
+ max_retries = await get_retry_429_max_retries()
103
+ retry_429_enabled = await get_retry_429_enabled()
104
+ retry_interval = await get_retry_429_interval()
105
+
106
+ # 确定API端点
107
+ action = "streamGenerateContent" if is_streaming else "generateContent"
108
+ target_url = f"{await get_code_assist_endpoint()}/v1internal:{action}"
109
+ if is_streaming:
110
+ target_url += "?alt=sse"
111
+
112
+ # 确保有credential_manager
113
+ if not credential_manager:
114
+ return _create_error_response("Credential manager not provided", 500)
115
+
116
+ # 获取当前凭证
117
+ try:
118
+ credential_result = await credential_manager.get_valid_credential()
119
+ if not credential_result:
120
+ return _create_error_response("No valid credentials available", 500)
121
+
122
+ current_file, credential_data = credential_result
123
+ headers, final_payload = await _prepare_request_headers_and_payload(payload, credential_data)
124
+ except Exception as e:
125
+ return _create_error_response(str(e), 500)
126
+
127
+ # 预序列化payload,避免重试时重复序列化
128
+ final_post_data = json.dumps(final_payload)
129
+
130
+ # Debug日志:打印请求体结构
131
+ log.debug(f"Final request payload structure: {json.dumps(final_payload, ensure_ascii=False, indent=2)}")
132
+
133
+ for attempt in range(max_retries + 1):
134
+ try:
135
+ if is_streaming:
136
+ # 流式请求处理 - 使用httpx_client模块的统一配置
137
+ client = await create_streaming_client_with_kwargs()
138
+
139
+ try:
140
+ # 使用stream方法但不在async with块中消费数据
141
+ stream_ctx = client.stream("POST", target_url, content=final_post_data, headers=headers)
142
+ resp = await stream_ctx.__aenter__()
143
+
144
+ if resp.status_code == 429:
145
+ # 记录429错误并获取响应内容
146
+ response_content = ""
147
+ try:
148
+ content_bytes = await resp.aread()
149
+ if isinstance(content_bytes, bytes):
150
+ response_content = content_bytes.decode('utf-8', errors='ignore')
151
+ except Exception as e:
152
+ log.debug(f"[STREAMING] Failed to read 429 response content: {e}")
153
+
154
+ # 显示详细的429错误信息
155
+ if response_content:
156
+ log.error(f"Google API returned status 429 (STREAMING). Response details: {response_content[:500]}")
157
+ else:
158
+ log.error("Google API returned status 429 (STREAMING) - quota exhausted, no response details available")
159
+
160
+ if credential_manager and current_file:
161
+ await credential_manager.record_api_call_result(current_file, False, 429)
162
+
163
+ # 清理资源
164
+ try:
165
+ await stream_ctx.__aexit__(None, None, None)
166
+ except:
167
+ pass
168
+ await client.aclose()
169
+
170
+ # 如果重试可用且未达到最大次数,进行重试
171
+ if retry_429_enabled and attempt < max_retries:
172
+ log.warning(f"[RETRY] 429 error encountered, retrying ({attempt + 1}/{max_retries})")
173
+ if credential_manager:
174
+ # 429错误时强制轮换凭证,不增加调用计数
175
+ await credential_manager.force_rotate_credential()
176
+ # 重新获取凭证和headers(凭证可能已轮换)
177
+ new_credential_result = await credential_manager.get_valid_credential()
178
+ if new_credential_result:
179
+ current_file, credential_data = new_credential_result
180
+ headers, updated_payload = await _prepare_request_headers_and_payload(payload, credential_data)
181
+ final_post_data = json.dumps(updated_payload)
182
+ await asyncio.sleep(retry_interval)
183
+ continue # 跳出内层处理,继续外层循环重试
184
+ else:
185
+ # 返回429错误流
186
+ async def error_stream():
187
+ error_response = {
188
+ "error": {
189
+ "message": "429 rate limit exceeded, max retries reached",
190
+ "type": "api_error",
191
+ "code": 429
192
+ }
193
+ }
194
+ yield f"data: {json.dumps(error_response)}\n\n"
195
+ return StreamingResponse(error_stream(), media_type="text/event-stream", status_code=429)
196
+ elif resp.status_code != 200:
197
+ # 处理其他非200状态码的错误
198
+ response_content = ""
199
+ try:
200
+ content_bytes = await resp.aread()
201
+ if isinstance(content_bytes, bytes):
202
+ response_content = content_bytes.decode('utf-8', errors='ignore')
203
+ except Exception as e:
204
+ log.debug(f"[STREAMING] Failed to read error response content: {e}")
205
+
206
+ # 显示详细的错误信息
207
+ if response_content:
208
+ log.error(f"Google API returned status {resp.status_code} (STREAMING). Response details: {response_content[:500]}")
209
+ else:
210
+ log.error(f"Google API returned status {resp.status_code} (STREAMING) - no response details available")
211
+
212
+ # 记录API调用错误
213
+ if credential_manager and current_file:
214
+ await credential_manager.record_api_call_result(current_file, False, resp.status_code)
215
+
216
+ # 清理资源
217
+ try:
218
+ await stream_ctx.__aexit__(None, None, None)
219
+ except:
220
+ pass
221
+ await client.aclose()
222
+
223
+ # 处理凭证轮换
224
+ await _handle_api_error(credential_manager, resp.status_code, response_content)
225
+
226
+ # 返回错误流
227
+ async def error_stream():
228
+ error_response = {
229
+ "error": {
230
+ "message": f"API error: {resp.status_code}",
231
+ "type": "api_error",
232
+ "code": resp.status_code
233
+ }
234
+ }
235
+ yield f"data: {json.dumps(error_response)}\n\n"
236
+ return StreamingResponse(error_stream(), media_type="text/event-stream", status_code=resp.status_code)
237
+ else:
238
+ # 成功响应,传递所有资源给流式处理函数管理
239
+ return _handle_streaming_response_managed(resp, stream_ctx, client, credential_manager, payload.get("model", ""), current_file)
240
+
241
+ except Exception as e:
242
+ # 清理资源
243
+ try:
244
+ await client.aclose()
245
+ except:
246
+ pass
247
+ raise e
248
+
249
+ else:
250
+ # 非流式请求处理 - 使用httpx_client模块
251
+ async with http_client.get_client(timeout=None) as client:
252
+ resp = await client.post(
253
+ target_url, content=final_post_data, headers=headers
254
+ )
255
+
256
+ if resp.status_code == 429:
257
+ # 记录429错误
258
+ if credential_manager and current_file:
259
+ await credential_manager.record_api_call_result(current_file, False, 429)
260
+
261
+ # 如果重试可用且未达到最大次数,继续重试
262
+ if retry_429_enabled and attempt < max_retries:
263
+ log.warning(f"[RETRY] 429 error encountered, retrying ({attempt + 1}/{max_retries})")
264
+ if credential_manager:
265
+ # 429错误时强制轮换凭证,不增加调用计数
266
+ await credential_manager.force_rotate_credential()
267
+ # 重新获取凭证和headers(凭证可能已轮换)
268
+ new_credential_result = await credential_manager.get_valid_credential()
269
+ if new_credential_result:
270
+ current_file, credential_data = new_credential_result
271
+ headers, updated_payload = await _prepare_request_headers_and_payload(payload, credential_data)
272
+ final_post_data = json.dumps(updated_payload)
273
+ await asyncio.sleep(retry_interval)
274
+ continue
275
+ else:
276
+ log.error(f"[RETRY] Max retries exceeded for 429 error")
277
+ return _create_error_response("429 rate limit exceeded, max retries reached", 429)
278
+ else:
279
+ # 非429错误或成功响应,正常处理
280
+ return await _handle_non_streaming_response(resp, credential_manager, payload.get("model", ""), current_file)
281
+
282
+ except Exception as e:
283
+ if attempt < max_retries:
284
+ log.warning(f"[RETRY] Request failed with exception, retrying ({attempt + 1}/{max_retries}): {str(e)}")
285
+ await asyncio.sleep(retry_interval)
286
+ continue
287
+ else:
288
+ log.error(f"Request to Google API failed: {str(e)}")
289
+ return _create_error_response(f"Request failed: {str(e)}")
290
+
291
+ # 如果循环结束仍未成功,返回错误
292
+ return _create_error_response("Max retries exceeded", 429)
293
+
294
+
295
+ def _handle_streaming_response_managed(resp, stream_ctx, client, credential_manager: CredentialManager = None, model_name: str = "", current_file: str = None) -> StreamingResponse:
296
+ """Handle streaming response with complete resource lifecycle management."""
297
+
298
+ # 检查HTTP错误
299
+ if resp.status_code != 200:
300
+ # 立即清理资源并返回错误
301
+ async def cleanup_and_error():
302
+ try:
303
+ await stream_ctx.__aexit__(None, None, None)
304
+ except:
305
+ pass
306
+ try:
307
+ await client.aclose()
308
+ except:
309
+ pass
310
+
311
+ # 获取响应内容用于详细错误显示
312
+ response_content = ""
313
+ try:
314
+ content_bytes = await resp.aread()
315
+ if isinstance(content_bytes, bytes):
316
+ response_content = content_bytes.decode('utf-8', errors='ignore')
317
+ except Exception as e:
318
+ log.debug(f"[STREAMING] Failed to read response content for error analysis: {e}")
319
+ response_content = ""
320
+
321
+ # 显示详细错误信息
322
+ if resp.status_code == 429:
323
+ if response_content:
324
+ log.error(f"Google API returned status 429 (STREAMING). Response details: {response_content[:500]}")
325
+ else:
326
+ log.error(f"Google API returned status 429 (STREAMING)")
327
+ else:
328
+ if response_content:
329
+ log.error(f"Google API returned status {resp.status_code} (STREAMING). Response details: {response_content[:500]}")
330
+ else:
331
+ log.error(f"Google API returned status {resp.status_code} (STREAMING)")
332
+
333
+ # 记录API调用错误
334
+ if credential_manager and current_file:
335
+ await credential_manager.record_api_call_result(current_file, False, resp.status_code)
336
+
337
+ await _handle_api_error(credential_manager, resp.status_code, response_content)
338
+
339
+ error_response = {
340
+ "error": {
341
+ "message": f"API error: {resp.status_code}",
342
+ "type": "api_error",
343
+ "code": resp.status_code
344
+ }
345
+ }
346
+ yield f'data: {json.dumps(error_response)}\n\n'.encode('utf-8')
347
+
348
+ return StreamingResponse(
349
+ cleanup_and_error(),
350
+ media_type="text/event-stream",
351
+ status_code=resp.status_code
352
+ )
353
+
354
+ # 正常流式响应处理,确保资源在流结束时被清理
355
+ async def managed_stream_generator():
356
+ success_recorded = False
357
+ managed_stream_generator._chunk_count = 0 # 初始化chunk计数器
358
+ try:
359
+ async for chunk in resp.aiter_lines():
360
+ if not chunk or not chunk.startswith('data: '):
361
+ continue
362
+
363
+ # 记录第一次成功响应
364
+ if not success_recorded:
365
+ if current_file and credential_manager:
366
+ await credential_manager.record_api_call_result(current_file, True)
367
+ # 记录到使用统计
368
+ try:
369
+ await record_successful_call(current_file, model_name)
370
+ except Exception as e:
371
+ log.debug(f"Failed to record usage statistics: {e}")
372
+ success_recorded = True
373
+
374
+ payload = chunk[len('data: '):]
375
+ try:
376
+ obj = json.loads(payload)
377
+ if "response" in obj:
378
+ data = obj["response"]
379
+ yield f"data: {json.dumps(data, separators=(',',':'))}\n\n".encode()
380
+ await asyncio.sleep(0) # 让其他协程有机会运行
381
+
382
+ # 定期释放内存(每100个chunk)
383
+ if hasattr(managed_stream_generator, '_chunk_count'):
384
+ managed_stream_generator._chunk_count += 1
385
+ if managed_stream_generator._chunk_count % 100 == 0:
386
+ gc.collect()
387
+ else:
388
+ yield f"data: {json.dumps(obj, separators=(',',':'))}\n\n".encode()
389
+ except json.JSONDecodeError:
390
+ continue
391
+
392
+ except Exception as e:
393
+ log.error(f"Streaming error: {e}")
394
+ err = {"error": {"message": str(e), "type": "api_error", "code": 500}}
395
+ yield f"data: {json.dumps(err)}\n\n".encode()
396
+ finally:
397
+ # 确保清理所有资源
398
+ try:
399
+ await stream_ctx.__aexit__(None, None, None)
400
+ except Exception as e:
401
+ log.debug(f"Error closing stream context: {e}")
402
+ try:
403
+ await client.aclose()
404
+ except Exception as e:
405
+ log.debug(f"Error closing client: {e}")
406
+
407
+ return StreamingResponse(
408
+ managed_stream_generator(),
409
+ media_type="text/event-stream"
410
+ )
411
+
412
+ async def _handle_non_streaming_response(resp, credential_manager: CredentialManager = None, model_name: str = "", current_file: str = None) -> Response:
413
+ """Handle non-streaming response from Google API."""
414
+ if resp.status_code == 200:
415
+ try:
416
+ # 记录成功响应
417
+ if current_file and credential_manager:
418
+ await credential_manager.record_api_call_result(current_file, True)
419
+ # 记录到使用统计
420
+ try:
421
+ await record_successful_call(current_file, model_name)
422
+ except Exception as e:
423
+ log.debug(f"Failed to record usage statistics: {e}")
424
+
425
+ raw = await resp.aread()
426
+ google_api_response = raw.decode('utf-8')
427
+ if google_api_response.startswith('data: '):
428
+ google_api_response = google_api_response[len('data: '):]
429
+ google_api_response = json.loads(google_api_response)
430
+ log.debug(f"Google API原始响应: {json.dumps(google_api_response, ensure_ascii=False)[:500]}...")
431
+ standard_gemini_response = google_api_response.get("response")
432
+ log.debug(f"提取的response字段: {json.dumps(standard_gemini_response, ensure_ascii=False)[:500]}...")
433
+ return Response(
434
+ content=json.dumps(standard_gemini_response),
435
+ status_code=200,
436
+ media_type="application/json; charset=utf-8"
437
+ )
438
+ except Exception as e:
439
+ log.error(f"Failed to parse Google API response: {str(e)}")
440
+ return Response(
441
+ content=resp.content,
442
+ status_code=resp.status_code,
443
+ media_type=resp.headers.get("Content-Type")
444
+ )
445
+ else:
446
+ # 获取响应内容用于详细错误显示
447
+ response_content = ""
448
+ try:
449
+ if hasattr(resp, 'content'):
450
+ content = resp.content
451
+ if isinstance(content, bytes):
452
+ response_content = content.decode('utf-8', errors='ignore')
453
+ else:
454
+ content_bytes = await resp.aread()
455
+ if isinstance(content_bytes, bytes):
456
+ response_content = content_bytes.decode('utf-8', errors='ignore')
457
+ except Exception as e:
458
+ log.debug(f"[NON-STREAMING] Failed to read response content for error analysis: {e}")
459
+ response_content = ""
460
+
461
+ # 显示详细错误信息
462
+ if resp.status_code == 429:
463
+ if response_content:
464
+ log.error(f"Google API returned status 429 (NON-STREAMING). Response details: {response_content[:500]}")
465
+ else:
466
+ log.error(f"Google API returned status 429 (NON-STREAMING)")
467
+ else:
468
+ if response_content:
469
+ log.error(f"Google API returned status {resp.status_code} (NON-STREAMING). Response details: {response_content[:500]}")
470
+ else:
471
+ log.error(f"Google API returned status {resp.status_code} (NON-STREAMING)")
472
+
473
+ # 记录API调用错误
474
+ if credential_manager and current_file:
475
+ await credential_manager.record_api_call_result(current_file, False, resp.status_code)
476
+
477
+ await _handle_api_error(credential_manager, resp.status_code, response_content)
478
+
479
+ return _create_error_response(f"API error: {resp.status_code}", resp.status_code)
480
+
481
+ def build_gemini_payload_from_native(native_request: dict, model_from_path: str) -> dict:
482
+ """
483
+ Build a Gemini API payload from a native Gemini request with full pass-through support.
484
+ """
485
+ # 创建请求副本以避免修改原始数据
486
+ request_data = native_request.copy()
487
+
488
+ # 应用默认安全设置(如果未指定)
489
+ if "safetySettings" not in request_data:
490
+ request_data["safetySettings"] = DEFAULT_SAFETY_SETTINGS
491
+
492
+ # 确保generationConfig存在
493
+ if "generationConfig" not in request_data:
494
+ request_data["generationConfig"] = {}
495
+
496
+ generation_config = request_data["generationConfig"]
497
+
498
+ # 配置thinking(如果未指定thinkingConfig)
499
+ if "thinkingConfig" not in generation_config:
500
+ generation_config["thinkingConfig"] = {}
501
+
502
+ thinking_config = generation_config["thinkingConfig"]
503
+
504
+ # 只有在未明确设置时才应用默认thinking配置
505
+ if "includeThoughts" not in thinking_config:
506
+ thinking_config["includeThoughts"] = should_include_thoughts(model_from_path)
507
+ if "thinkingBudget" not in thinking_config:
508
+ thinking_config["thinkingBudget"] = get_thinking_budget(model_from_path)
509
+
510
+ # 为搜索模型添加Google Search工具(如果未指定且没有functionDeclarations)
511
+ if is_search_model(model_from_path):
512
+ if "tools" not in request_data:
513
+ request_data["tools"] = []
514
+ # 检查是否已有functionDeclarations或googleSearch工具
515
+ has_function_declarations = any(tool.get("functionDeclarations") for tool in request_data["tools"])
516
+ has_google_search = any(tool.get("googleSearch") for tool in request_data["tools"])
517
+
518
+ # 只有在没有任何工具时才添加googleSearch,或者只有googleSearch工具时可以添加更多googleSearch
519
+ if not has_function_declarations and not has_google_search:
520
+ request_data["tools"].append({"googleSearch": {}})
521
+
522
+ # 透传所有其他Gemini原生字段:
523
+ # - contents (必需)
524
+ # - systemInstruction (可选)
525
+ # - generationConfig (已处理)
526
+ # - safetySettings (已处理)
527
+ # - tools (已处理)
528
+ # - toolConfig (透传)
529
+ # - cachedContent (透传)
530
+ # - 以及任何其他未知字段都会被透传
531
+
532
+ return {
533
+ "model": get_base_model_name(model_from_path),
534
+ "request": request_data
535
+ }
src/google_oauth_api.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Google OAuth2 认证模块
3
+ """
4
+ import time
5
+ import jwt
6
+ import asyncio
7
+ from datetime import datetime, timezone, timedelta
8
+ from typing import Optional, Dict, Any, List
9
+ from urllib.parse import urlencode
10
+
11
+ from config import get_oauth_proxy_url, get_googleapis_proxy_url, get_resource_manager_api_url, get_service_usage_api_url
12
+ from log import log
13
+ from .httpx_client import get_async, post_async
14
+
15
+
16
+ class TokenError(Exception):
17
+ """Token相关错误"""
18
+ pass
19
+
20
+ class Credentials:
21
+ """凭证类"""
22
+
23
+ def __init__(self, access_token: str, refresh_token: str = None,
24
+ client_id: str = None, client_secret: str = None,
25
+ expires_at: datetime = None, project_id: str = None):
26
+ self.access_token = access_token
27
+ self.refresh_token = refresh_token
28
+ self.client_id = client_id
29
+ self.client_secret = client_secret
30
+ self.expires_at = expires_at
31
+ self.project_id = project_id
32
+
33
+ # 反代配置将在使用时异步获取
34
+ self.oauth_base_url = None
35
+ self.token_endpoint = None
36
+
37
+ def is_expired(self) -> bool:
38
+ """检查token是否过期"""
39
+ if not self.expires_at:
40
+ return True
41
+
42
+ # 提前3分钟认为过期
43
+ buffer = timedelta(minutes=3)
44
+ return (self.expires_at - buffer) <= datetime.now(timezone.utc)
45
+
46
+ async def refresh_if_needed(self) -> bool:
47
+ """如果需要则刷新token"""
48
+ if not self.is_expired():
49
+ return False
50
+
51
+ if not self.refresh_token:
52
+ raise TokenError("需要刷新令牌但未提供")
53
+
54
+ await self.refresh()
55
+ return True
56
+
57
+ async def refresh(self, max_retries: int = 3, base_delay: float = 1.0):
58
+ """刷新访问令牌,支持重试机制"""
59
+ if not self.refresh_token:
60
+ raise TokenError("无刷新令牌")
61
+
62
+ data = {
63
+ 'client_id': self.client_id,
64
+ 'client_secret': self.client_secret,
65
+ 'refresh_token': self.refresh_token,
66
+ 'grant_type': 'refresh_token'
67
+ }
68
+
69
+ last_exception = None
70
+ for attempt in range(max_retries + 1):
71
+ try:
72
+ oauth_base_url = await get_oauth_proxy_url()
73
+ token_url = f"{oauth_base_url.rstrip('/')}/token"
74
+ response = await post_async(
75
+ token_url,
76
+ data=data,
77
+ headers={'Content-Type': 'application/x-www-form-urlencoded'}
78
+ )
79
+ response.raise_for_status()
80
+
81
+ token_data = response.json()
82
+ self.access_token = token_data['access_token']
83
+
84
+ if 'expires_in' in token_data:
85
+ expires_in = int(token_data['expires_in'])
86
+ self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
87
+
88
+ if 'refresh_token' in token_data:
89
+ self.refresh_token = token_data['refresh_token']
90
+
91
+ if attempt > 0:
92
+ log.debug(f"Token刷新成功(第{attempt + 1}次尝试),过期时间: {self.expires_at}")
93
+ else:
94
+ log.debug(f"Token刷新成功,过期时间: {self.expires_at}")
95
+ return
96
+
97
+ except Exception as e:
98
+ last_exception = e
99
+ error_msg = str(e)
100
+
101
+ # 检查是否是不可恢复的错误,如果是则不重试
102
+ if self._is_non_retryable_error(error_msg):
103
+ log.error(f"Token刷新遇到不可恢复错误: {error_msg}")
104
+ break
105
+
106
+ if attempt < max_retries:
107
+ # 计算退避延迟时间(指数退避)
108
+ delay = base_delay * (2 ** attempt)
109
+ log.warning(f"Token刷新失败(第{attempt + 1}次尝试): {error_msg},{delay}秒后重试...")
110
+ await asyncio.sleep(delay)
111
+ else:
112
+ break
113
+
114
+ # 所有重试都失败了
115
+ error_msg = f"Token刷新失败(已重试{max_retries}次): {str(last_exception)}"
116
+ log.error(error_msg)
117
+ raise TokenError(error_msg)
118
+
119
+ def _is_non_retryable_error(self, error_msg: str) -> bool:
120
+ """判断是否是不需要重试的错误"""
121
+ non_retryable_patterns = [
122
+ "400 Bad Request",
123
+ "invalid_grant",
124
+ "refresh_token_expired",
125
+ "invalid_refresh_token",
126
+ "unauthorized_client",
127
+ "access_denied",
128
+ "401 Unauthorized"
129
+ ]
130
+
131
+ error_msg_lower = error_msg.lower()
132
+ for pattern in non_retryable_patterns:
133
+ if pattern.lower() in error_msg_lower:
134
+ return True
135
+
136
+ return False
137
+
138
+ @classmethod
139
+ def from_dict(cls, data: Dict[str, Any]) -> 'Credentials':
140
+ """从字典创建凭证"""
141
+ # 处理过期时间
142
+ expires_at = None
143
+ if 'expiry' in data and data['expiry']:
144
+ try:
145
+ expiry_str = data['expiry']
146
+ if isinstance(expiry_str, str):
147
+ if expiry_str.endswith('Z'):
148
+ expires_at = datetime.fromisoformat(expiry_str.replace('Z', '+00:00'))
149
+ elif '+' in expiry_str:
150
+ expires_at = datetime.fromisoformat(expiry_str)
151
+ else:
152
+ expires_at = datetime.fromisoformat(expiry_str).replace(tzinfo=timezone.utc)
153
+ except ValueError:
154
+ log.warning(f"无法解析过期时间: {expiry_str}")
155
+
156
+ return cls(
157
+ access_token=data.get('token') or data.get('access_token', ''),
158
+ refresh_token=data.get('refresh_token'),
159
+ client_id=data.get('client_id'),
160
+ client_secret=data.get('client_secret'),
161
+ expires_at=expires_at,
162
+ project_id=data.get('project_id')
163
+ )
164
+
165
+ def to_dict(self) -> Dict[str, Any]:
166
+ """转为字典"""
167
+ result = {
168
+ 'access_token': self.access_token,
169
+ 'refresh_token': self.refresh_token,
170
+ 'client_id': self.client_id,
171
+ 'client_secret': self.client_secret,
172
+ 'project_id': self.project_id
173
+ }
174
+
175
+ if self.expires_at:
176
+ result['expiry'] = self.expires_at.isoformat()
177
+
178
+ return result
179
+
180
+
181
+ class Flow:
182
+ """OAuth流程类"""
183
+
184
+ def __init__(self, client_id: str, client_secret: str, scopes: List[str],
185
+ redirect_uri: str = None):
186
+ self.client_id = client_id
187
+ self.client_secret = client_secret
188
+ self.scopes = scopes
189
+ self.redirect_uri = redirect_uri
190
+
191
+ # 反代配置将在使用时异步获取
192
+ self.oauth_base_url = None
193
+ self.token_endpoint = None
194
+ self.auth_endpoint = "https://accounts.google.com/o/oauth2/auth"
195
+
196
+ self.credentials: Optional[Credentials] = None
197
+
198
+ def get_auth_url(self, state: str = None, **kwargs) -> str:
199
+ """生成授权URL"""
200
+ params = {
201
+ 'client_id': self.client_id,
202
+ 'redirect_uri': self.redirect_uri,
203
+ 'scope': ' '.join(self.scopes),
204
+ 'response_type': 'code',
205
+ 'access_type': 'offline',
206
+ 'prompt': 'consent',
207
+ 'include_granted_scopes': 'true'
208
+ }
209
+
210
+ if state:
211
+ params['state'] = state
212
+
213
+ params.update(kwargs)
214
+ return f"{self.auth_endpoint}?{urlencode(params)}"
215
+
216
+ async def exchange_code(self, code: str) -> Credentials:
217
+ """用授权码换取token"""
218
+ data = {
219
+ 'client_id': self.client_id,
220
+ 'client_secret': self.client_secret,
221
+ 'redirect_uri': self.redirect_uri,
222
+ 'code': code,
223
+ 'grant_type': 'authorization_code'
224
+ }
225
+
226
+ try:
227
+ oauth_base_url = await get_oauth_proxy_url()
228
+ token_url = f"{oauth_base_url.rstrip('/')}/token"
229
+ response = await post_async(
230
+ token_url,
231
+ data=data,
232
+ headers={'Content-Type': 'application/x-www-form-urlencoded'}
233
+ )
234
+ response.raise_for_status()
235
+
236
+ token_data = response.json()
237
+
238
+ # 计算过期时间
239
+ expires_at = None
240
+ if 'expires_in' in token_data:
241
+ expires_in = int(token_data['expires_in'])
242
+ expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
243
+
244
+ # 创建凭证对象
245
+ self.credentials = Credentials(
246
+ access_token=token_data['access_token'],
247
+ refresh_token=token_data.get('refresh_token'),
248
+ client_id=self.client_id,
249
+ client_secret=self.client_secret,
250
+ expires_at=expires_at
251
+ )
252
+
253
+ return self.credentials
254
+
255
+ except Exception as e:
256
+ error_msg = f"获取token失败: {str(e)}"
257
+ log.error(error_msg)
258
+ raise TokenError(error_msg)
259
+
260
+
261
+ class ServiceAccount:
262
+ """Service Account类"""
263
+
264
+ def __init__(self, email: str, private_key: str, project_id: str = None,
265
+ scopes: List[str] = None):
266
+ self.email = email
267
+ self.private_key = private_key
268
+ self.project_id = project_id
269
+ self.scopes = scopes or []
270
+
271
+ # 反代配置将在使用时异步获取
272
+ self.oauth_base_url = None
273
+ self.token_endpoint = None
274
+
275
+ self.access_token: Optional[str] = None
276
+ self.expires_at: Optional[datetime] = None
277
+
278
+ def is_expired(self) -> bool:
279
+ """检查token是否过期"""
280
+ if not self.expires_at:
281
+ return True
282
+
283
+ buffer = timedelta(minutes=3)
284
+ return (self.expires_at - buffer) <= datetime.now(timezone.utc)
285
+
286
+ def create_jwt(self) -> str:
287
+ """创建JWT令牌"""
288
+ now = int(time.time())
289
+
290
+ payload = {
291
+ 'iss': self.email,
292
+ 'scope': ' '.join(self.scopes) if self.scopes else '',
293
+ 'aud': self.token_endpoint,
294
+ 'exp': now + 3600,
295
+ 'iat': now
296
+ }
297
+
298
+ return jwt.encode(payload, self.private_key, algorithm='RS256')
299
+
300
+ async def get_access_token(self) -> str:
301
+ """获取访问令牌"""
302
+ if not self.is_expired() and self.access_token:
303
+ return self.access_token
304
+
305
+ assertion = self.create_jwt()
306
+
307
+ data = {
308
+ 'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer',
309
+ 'assertion': assertion
310
+ }
311
+
312
+ try:
313
+ oauth_base_url = await get_oauth_proxy_url()
314
+ token_url = f"{oauth_base_url.rstrip('/')}/token"
315
+ response = await post_async(
316
+ token_url,
317
+ data=data,
318
+ headers={'Content-Type': 'application/x-www-form-urlencoded'}
319
+ )
320
+ response.raise_for_status()
321
+
322
+ token_data = response.json()
323
+ self.access_token = token_data['access_token']
324
+
325
+ if 'expires_in' in token_data:
326
+ expires_in = int(token_data['expires_in'])
327
+ self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
328
+
329
+ return self.access_token
330
+
331
+ except Exception as e:
332
+ error_msg = f"Service Account获取token失败: {str(e)}"
333
+ log.error(error_msg)
334
+ raise TokenError(error_msg)
335
+
336
+ @classmethod
337
+ def from_dict(cls, data: Dict[str, Any], scopes: List[str] = None) -> 'ServiceAccount':
338
+ """从字典创建Service Account凭证"""
339
+ return cls(
340
+ email=data['client_email'],
341
+ private_key=data['private_key'],
342
+ project_id=data.get('project_id'),
343
+ scopes=scopes
344
+ )
345
+
346
+
347
+ # 工具函数
348
+ async def get_user_info(credentials: Credentials) -> Optional[Dict[str, Any]]:
349
+ """获取用户信息"""
350
+ await credentials.refresh_if_needed()
351
+
352
+ try:
353
+ googleapis_base_url = await get_googleapis_proxy_url()
354
+ userinfo_url = f"{googleapis_base_url.rstrip('/')}/oauth2/v2/userinfo"
355
+ response = await get_async(
356
+ userinfo_url,
357
+ headers={'Authorization': f'Bearer {credentials.access_token}'}
358
+ )
359
+ response.raise_for_status()
360
+ return response.json()
361
+ except Exception as e:
362
+ log.error(f"获取用户信息失败: {e}")
363
+ return None
364
+
365
+
366
+ async def get_user_email(credentials: Credentials) -> Optional[str]:
367
+ """获取用户邮箱地址"""
368
+ try:
369
+ # 确保凭证有效
370
+ await credentials.refresh_if_needed()
371
+
372
+ # 调用Google userinfo API获取邮箱
373
+ user_info = await get_user_info(credentials)
374
+ if user_info:
375
+ email = user_info.get("email")
376
+ if email:
377
+ log.info(f"成功获取邮箱地址: {email}")
378
+ return email
379
+ else:
380
+ log.warning(f"userinfo响应中没有邮箱信息: {user_info}")
381
+ return None
382
+ else:
383
+ log.warning("获取用户信息失败")
384
+ return None
385
+
386
+ except Exception as e:
387
+ log.error(f"获取用户邮箱失败: {e}")
388
+ return None
389
+
390
+
391
+ async def fetch_user_email_from_file(cred_data: Dict[str, Any]) -> Optional[str]:
392
+ """从凭证数据获取用户邮箱地址(支持统一存储)"""
393
+ try:
394
+ # 直接从凭证数据创建凭证对象
395
+ credentials = Credentials.from_dict(cred_data)
396
+ if not credentials or not credentials.access_token:
397
+ log.warning(f"无法从凭证数据创建凭证对象或获取访问令牌")
398
+ return None
399
+
400
+ # 获取邮箱
401
+ return await get_user_email(credentials)
402
+
403
+ except Exception as e:
404
+ log.error(f"从凭证数据获取用户邮箱失败: {e}")
405
+ return None
406
+
407
+
408
+ async def validate_token(token: str) -> Optional[Dict[str, Any]]:
409
+ """验证访问令牌"""
410
+ try:
411
+ oauth_base_url = await get_oauth_proxy_url()
412
+ tokeninfo_url = f"{oauth_base_url.rstrip('/')}/tokeninfo?access_token={token}"
413
+
414
+ response = await get_async(tokeninfo_url)
415
+ response.raise_for_status()
416
+ return response.json()
417
+ except Exception as e:
418
+ log.error(f"验证令牌失败: {e}")
419
+ return None
420
+
421
+
422
+ async def enable_required_apis(credentials: Credentials, project_id: str) -> bool:
423
+ """自动启用必需的API服务"""
424
+ try:
425
+ # 确保凭证有效
426
+ if credentials.is_expired() and credentials.refresh_token:
427
+ await credentials.refresh()
428
+
429
+ headers = {
430
+ "Authorization": f"Bearer {credentials.access_token}",
431
+ "Content-Type": "application/json",
432
+ "User-Agent": "geminicli-oauth/1.0",
433
+ }
434
+
435
+ # 需要启用的服务列表
436
+ required_services = [
437
+ "geminicloudassist.googleapis.com", # Gemini Cloud Assist API
438
+ "cloudaicompanion.googleapis.com" # Gemini for Google Cloud API
439
+ ]
440
+
441
+ for service in required_services:
442
+ log.info(f"正在检查并启用服务: {service}")
443
+
444
+ # 检查服务是否已启用
445
+ service_usage_base_url = await get_service_usage_api_url()
446
+ check_url = f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}"
447
+ try:
448
+ check_response = await get_async(check_url, headers=headers)
449
+ if check_response.status_code == 200:
450
+ service_data = check_response.json()
451
+ if service_data.get("state") == "ENABLED":
452
+ log.info(f"服务 {service} 已启用")
453
+ continue
454
+ except Exception as e:
455
+ log.debug(f"检查服务状态失败,将尝试启用: {e}")
456
+
457
+ # 启用服务
458
+ enable_url = f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}:enable"
459
+ try:
460
+ enable_response = await post_async(enable_url, headers=headers, json={})
461
+
462
+ if enable_response.status_code in [200, 201]:
463
+ log.info(f"✅ 成功启用服务: {service}")
464
+ elif enable_response.status_code == 400:
465
+ error_data = enable_response.json()
466
+ if "already enabled" in error_data.get("error", {}).get("message", "").lower():
467
+ log.info(f"✅ 服务 {service} 已经启用")
468
+ else:
469
+ log.warning(f"⚠️ 启用服务 {service} 时出现警告: {error_data}")
470
+ else:
471
+ log.warning(f"⚠️ 启用服务 {service} 失败: {enable_response.status_code} - {enable_response.text}")
472
+
473
+ except Exception as e:
474
+ log.warning(f"⚠️ 启用服务 {service} 时发生异常: {e}")
475
+
476
+ return True
477
+
478
+ except Exception as e:
479
+ log.error(f"启用API服务时发生错误: {e}")
480
+ return False
481
+
482
+
483
+ async def get_user_projects(credentials: Credentials) -> List[Dict[str, Any]]:
484
+ """获取用户可访问的Google Cloud项目列表"""
485
+ try:
486
+ # 确保凭证有效
487
+ if credentials.is_expired() and credentials.refresh_token:
488
+ await credentials.refresh()
489
+
490
+ headers = {
491
+ "Authorization": f"Bearer {credentials.access_token}",
492
+ "User-Agent": "geminicli-oauth/1.0",
493
+ }
494
+
495
+ # 使用Resource Manager API的正确域名和端点
496
+ resource_manager_base_url = await get_resource_manager_api_url()
497
+ url = f"{resource_manager_base_url.rstrip('/')}/v1/projects"
498
+ log.info(f"正在调用API: {url}")
499
+ response = await get_async(url, headers=headers)
500
+
501
+ log.info(f"API响应状态码: {response.status_code}")
502
+ if response.status_code != 200:
503
+ log.error(f"API响应内容: {response.text}")
504
+
505
+ if response.status_code == 200:
506
+ data = response.json()
507
+ projects = data.get('projects', [])
508
+ # 只返回活跃的项目
509
+ active_projects = [
510
+ project for project in projects
511
+ if project.get('lifecycleState') == 'ACTIVE'
512
+ ]
513
+ log.info(f"获取到 {len(active_projects)} 个活跃项目")
514
+ return active_projects
515
+ else:
516
+ log.warning(f"获取项目列表失败: {response.status_code} - {response.text}")
517
+ return []
518
+
519
+ except Exception as e:
520
+ log.error(f"获取用户项目列表失败: {e}")
521
+ return []
522
+
523
+
524
+
525
+
526
+ async def select_default_project(projects: List[Dict[str, Any]]) -> Optional[str]:
527
+ """从项目列表中选择默认项目"""
528
+ if not projects:
529
+ return None
530
+
531
+ # 策略1:查找显示名称或项目ID包含"default"的项目
532
+ for project in projects:
533
+ display_name = project.get('displayName', '').lower()
534
+ project_id = project.get('projectId', '')
535
+ if 'default' in display_name or 'default' in project_id.lower():
536
+ log.info(f"选择默认项目: {project_id} ({project.get('displayName', project_id)})")
537
+ return project_id
538
+
539
+ # 策略2:选择第一个项目
540
+ first_project = projects[0]
541
+ project_id = first_project.get('projectId', '')
542
+ log.info(f"选择第一个项目作为默认: {project_id} ({first_project.get('displayName', project_id)})")
543
+ return project_id
src/httpx_client.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 通用的HTTP客户端模块
3
+ 为所有需要使用httpx的模块提供统一的客户端配置和方法
4
+ 保持通用性,不与特定业务逻辑耦合
5
+ """
6
+ import httpx
7
+ from typing import Optional, Dict, Any, AsyncGenerator
8
+ from contextlib import asynccontextmanager
9
+
10
+ from config import get_proxy_config
11
+ from log import log
12
+
13
+
14
+ class HttpxClientManager:
15
+ """通用HTTP客户端管理器"""
16
+
17
+ async def get_client_kwargs(self, timeout: float = 30.0, **kwargs) -> Dict[str, Any]:
18
+ """获取httpx客户端的通用配置参数"""
19
+ client_kwargs = {
20
+ "timeout": timeout,
21
+ **kwargs
22
+ }
23
+
24
+ # 动态读取代理配置,支持热更新
25
+ current_proxy_config = await get_proxy_config()
26
+ if current_proxy_config:
27
+ client_kwargs["proxy"] = current_proxy_config
28
+
29
+ return client_kwargs
30
+
31
+ @asynccontextmanager
32
+ async def get_client(self, timeout: float = 30.0, **kwargs) -> AsyncGenerator[httpx.AsyncClient, None]:
33
+ """获取配置好的异步HTTP客户端"""
34
+ client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs)
35
+
36
+ async with httpx.AsyncClient(**client_kwargs) as client:
37
+ yield client
38
+
39
+ @asynccontextmanager
40
+ async def get_streaming_client(self, timeout: float = None, **kwargs) -> AsyncGenerator[httpx.AsyncClient, None]:
41
+ """获取用于流式请求的HTTP客户端(无超时限制)"""
42
+ client_kwargs = await self.get_client_kwargs(timeout=timeout, **kwargs)
43
+
44
+ # 创建独立的客户端实例用于流式处理
45
+ client = httpx.AsyncClient(**client_kwargs)
46
+ try:
47
+ yield client
48
+ finally:
49
+ await client.aclose()
50
+
51
+
52
+ # 全局HTTP客户端管理器实例
53
+ http_client = HttpxClientManager()
54
+
55
+
56
+ # 通用的异步方法
57
+ async def get_async(url: str, headers: Optional[Dict[str, str]] = None,
58
+ timeout: float = 30.0, **kwargs) -> httpx.Response:
59
+ """通用异步GET请求"""
60
+ async with http_client.get_client(timeout=timeout, **kwargs) as client:
61
+ return await client.get(url, headers=headers)
62
+
63
+
64
+ async def post_async(url: str, data: Any = None, json: Any = None,
65
+ headers: Optional[Dict[str, str]] = None,
66
+ timeout: float = 30.0, **kwargs) -> httpx.Response:
67
+ """通用异步POST请求"""
68
+ async with http_client.get_client(timeout=timeout, **kwargs) as client:
69
+ return await client.post(url, data=data, json=json, headers=headers)
70
+
71
+
72
+ async def put_async(url: str, data: Any = None, json: Any = None,
73
+ headers: Optional[Dict[str, str]] = None,
74
+ timeout: float = 30.0, **kwargs) -> httpx.Response:
75
+ """通用异步PUT请求"""
76
+ async with http_client.get_client(timeout=timeout, **kwargs) as client:
77
+ return await client.put(url, data=data, json=json, headers=headers)
78
+
79
+
80
+ async def delete_async(url: str, headers: Optional[Dict[str, str]] = None,
81
+ timeout: float = 30.0, **kwargs) -> httpx.Response:
82
+ """通用异步DELETE请求"""
83
+ async with http_client.get_client(timeout=timeout, **kwargs) as client:
84
+ return await client.delete(url, headers=headers)
85
+
86
+
87
+ # 错误处理装饰器
88
+ def handle_http_errors(func):
89
+ """HTTP错误处理装饰器"""
90
+ async def wrapper(*args, **kwargs):
91
+ try:
92
+ response = await func(*args, **kwargs)
93
+ response.raise_for_status()
94
+ return response
95
+ except httpx.HTTPStatusError as e:
96
+ log.error(f"HTTP错误: {e.response.status_code} - {e.response.text}")
97
+ raise
98
+ except httpx.RequestError as e:
99
+ log.error(f"请求错误: {e}")
100
+ raise
101
+ except Exception as e:
102
+ log.error(f"未知错误: {e}")
103
+ raise
104
+ return wrapper
105
+
106
+
107
+ # 应用错误处理的安全方法
108
+ @handle_http_errors
109
+ async def safe_get_async(url: str, headers: Optional[Dict[str, str]] = None,
110
+ timeout: float = 30.0, **kwargs) -> httpx.Response:
111
+ """安全的异步GET请求(自动错误处理)"""
112
+ return await get_async(url, headers=headers, timeout=timeout, **kwargs)
113
+
114
+
115
+ @handle_http_errors
116
+ async def safe_post_async(url: str, data: Any = None, json: Any = None,
117
+ headers: Optional[Dict[str, str]] = None,
118
+ timeout: float = 30.0, **kwargs) -> httpx.Response:
119
+ """安全的异步POST请求(自动错误处理)"""
120
+ return await post_async(url, data=data, json=json, headers=headers, timeout=timeout, **kwargs)
121
+
122
+
123
+ @handle_http_errors
124
+ async def safe_put_async(url: str, data: Any = None, json: Any = None,
125
+ headers: Optional[Dict[str, str]] = None,
126
+ timeout: float = 30.0, **kwargs) -> httpx.Response:
127
+ """安全的异步PUT请求(自动错误处理)"""
128
+ return await put_async(url, data=data, json=json, headers=headers, timeout=timeout, **kwargs)
129
+
130
+
131
+ @handle_http_errors
132
+ async def safe_delete_async(url: str, headers: Optional[Dict[str, str]] = None,
133
+ timeout: float = 30.0, **kwargs) -> httpx.Response:
134
+ """安全的异步DELETE请求(自动错误处理)"""
135
+ return await delete_async(url, headers=headers, timeout=timeout, **kwargs)
136
+
137
+
138
+ # 流式请求支持
139
+ class StreamingContext:
140
+ """流式请求上下文管理器"""
141
+
142
+ def __init__(self, client: httpx.AsyncClient, stream_context):
143
+ self.client = client
144
+ self.stream_context = stream_context
145
+ self.response = None
146
+
147
+ async def __aenter__(self):
148
+ self.response = await self.stream_context.__aenter__()
149
+ return self.response
150
+
151
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
152
+ try:
153
+ if self.stream_context:
154
+ await self.stream_context.__aexit__(exc_type, exc_val, exc_tb)
155
+ finally:
156
+ if self.client:
157
+ await self.client.aclose()
158
+
159
+
160
+ @asynccontextmanager
161
+ async def get_streaming_post_context(url: str, data: Any = None, json: Any = None,
162
+ headers: Optional[Dict[str, str]] = None,
163
+ timeout: float = None, **kwargs) -> AsyncGenerator[StreamingContext, None]:
164
+ """获取流式POST请求的上下文管理器"""
165
+ async with http_client.get_streaming_client(timeout=timeout, **kwargs) as client:
166
+ stream_ctx = client.stream("POST", url, data=data, json=json, headers=headers)
167
+ streaming_context = StreamingContext(client, stream_ctx)
168
+ yield streaming_context
169
+
170
+
171
+ async def create_streaming_client_with_kwargs(**kwargs) -> httpx.AsyncClient:
172
+ """创建用于流式处理的独立客户端实例(手动管理生命周期)"""
173
+ client_kwargs = await http_client.get_client_kwargs(timeout=None, **kwargs)
174
+ return httpx.AsyncClient(**client_kwargs)
src/models.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union, Dict, Any
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+ # Common Models
6
+ class Model(BaseModel):
7
+ id: str
8
+ object: str = "model"
9
+ created: Optional[int] = None
10
+ owned_by: Optional[str] = "google"
11
+
12
+ class ModelList(BaseModel):
13
+ object: str = "list"
14
+ data: List[Model]
15
+
16
+ # OpenAI Models
17
+ class OpenAIChatMessage(BaseModel):
18
+ role: str
19
+ content: Union[str, List[Dict[str, Any]], None] = None
20
+ reasoning_content: Optional[str] = None
21
+ name: Optional[str] = None
22
+
23
+ class OpenAIChatCompletionRequest(BaseModel):
24
+ model: str
25
+ messages: List[OpenAIChatMessage]
26
+ stream: bool = False
27
+ temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
28
+ top_p: Optional[float] = Field(None, ge=0.0, le=1.0)
29
+ max_tokens: Optional[int] = Field(None, ge=1)
30
+ stop: Optional[Union[str, List[str]]] = None
31
+ frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0)
32
+ presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0)
33
+ n: Optional[int] = Field(1, ge=1, le=128)
34
+ seed: Optional[int] = None
35
+ response_format: Optional[Dict[str, Any]] = None
36
+ top_k: Optional[int] = Field(None, ge=1)
37
+ enable_anti_truncation: Optional[bool] = False
38
+
39
+ class Config:
40
+ extra = "allow" # Allow additional fields not explicitly defined
41
+
42
+ # 通用的聊天完成请求模型(兼容OpenAI和其他格式)
43
+ ChatCompletionRequest = OpenAIChatCompletionRequest
44
+
45
+ class OpenAIChatCompletionChoice(BaseModel):
46
+ index: int
47
+ message: OpenAIChatMessage
48
+ finish_reason: Optional[str] = None
49
+ logprobs: Optional[Dict[str, Any]] = None
50
+
51
+ class OpenAIChatCompletionResponse(BaseModel):
52
+ id: str
53
+ object: str = "chat.completion"
54
+ created: int
55
+ model: str
56
+ choices: List[OpenAIChatCompletionChoice]
57
+ usage: Optional[Dict[str, int]] = None
58
+ system_fingerprint: Optional[str] = None
59
+
60
+ class OpenAIDelta(BaseModel):
61
+ role: Optional[str] = None
62
+ content: Optional[str] = None
63
+ reasoning_content: Optional[str] = None
64
+
65
+ class OpenAIChatCompletionStreamChoice(BaseModel):
66
+ index: int
67
+ delta: OpenAIDelta
68
+ finish_reason: Optional[str] = None
69
+ logprobs: Optional[Dict[str, Any]] = None
70
+
71
+ class OpenAIChatCompletionStreamResponse(BaseModel):
72
+ id: str
73
+ object: str = "chat.completion.chunk"
74
+ created: int
75
+ model: str
76
+ choices: List[OpenAIChatCompletionStreamChoice]
77
+ system_fingerprint: Optional[str] = None
78
+
79
+ # Gemini Models
80
+ class GeminiPart(BaseModel):
81
+ text: Optional[str] = None
82
+ inlineData: Optional[Dict[str, Any]] = None
83
+ fileData: Optional[Dict[str, Any]] = None
84
+ thought: Optional[bool] = False
85
+
86
+ class GeminiContent(BaseModel):
87
+ role: str
88
+ parts: List[GeminiPart]
89
+
90
+ class GeminiSystemInstruction(BaseModel):
91
+ parts: List[GeminiPart]
92
+
93
+ class GeminiGenerationConfig(BaseModel):
94
+ temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
95
+ topP: Optional[float] = Field(None, ge=0.0, le=1.0)
96
+ topK: Optional[int] = Field(None, ge=1)
97
+ maxOutputTokens: Optional[int] = Field(None, ge=1)
98
+ stopSequences: Optional[List[str]] = None
99
+ responseMimeType: Optional[str] = None
100
+ responseSchema: Optional[Dict[str, Any]] = None
101
+ candidateCount: Optional[int] = Field(None, ge=1, le=8)
102
+ seed: Optional[int] = None
103
+ frequencyPenalty: Optional[float] = Field(None, ge=-2.0, le=2.0)
104
+ presencePenalty: Optional[float] = Field(None, ge=-2.0, le=2.0)
105
+ thinkingConfig: Optional[Dict[str, Any]] = None
106
+
107
+ class GeminiSafetySetting(BaseModel):
108
+ category: str
109
+ threshold: str
110
+
111
+ class GeminiRequest(BaseModel):
112
+ contents: List[GeminiContent]
113
+ systemInstruction: Optional[GeminiSystemInstruction] = None
114
+ generationConfig: Optional[GeminiGenerationConfig] = None
115
+ safetySettings: Optional[List[GeminiSafetySetting]] = None
116
+ tools: Optional[List[Dict[str, Any]]] = None
117
+ toolConfig: Optional[Dict[str, Any]] = None
118
+ cachedContent: Optional[str] = None
119
+ enable_anti_truncation: Optional[bool] = False
120
+
121
+ class Config:
122
+ extra = "allow" # 允许透传未定义的字段
123
+
124
+ class GeminiCandidate(BaseModel):
125
+ content: GeminiContent
126
+ finishReason: Optional[str] = None
127
+ index: int = 0
128
+ safetyRatings: Optional[List[Dict[str, Any]]] = None
129
+ citationMetadata: Optional[Dict[str, Any]] = None
130
+ tokenCount: Optional[int] = None
131
+
132
+ class GeminiUsageMetadata(BaseModel):
133
+ promptTokenCount: Optional[int] = None
134
+ candidatesTokenCount: Optional[int] = None
135
+ totalTokenCount: Optional[int] = None
136
+
137
+ class GeminiResponse(BaseModel):
138
+ candidates: List[GeminiCandidate]
139
+ usageMetadata: Optional[GeminiUsageMetadata] = None
140
+ modelVersion: Optional[str] = None
141
+
142
+ # Error Models
143
+ class APIError(BaseModel):
144
+ message: str
145
+ type: str = "api_error"
146
+ code: Optional[int] = None
147
+
148
+ class ErrorResponse(BaseModel):
149
+ error: APIError
150
+
151
+ # Control Panel Models
152
+ class SystemStatus(BaseModel):
153
+ status: str
154
+ timestamp: str
155
+ credentials: Dict[str, int]
156
+ config: Dict[str, Any]
157
+ current_credential: str
158
+
159
+ class CredentialInfo(BaseModel):
160
+ filename: str
161
+ project_id: Optional[str] = None
162
+ status: Dict[str, Any]
163
+ size: Optional[int] = None
164
+ modified_time: Optional[str] = None
165
+ error: Optional[str] = None
166
+
167
+ class LogEntry(BaseModel):
168
+ timestamp: str
169
+ level: str
170
+ message: str
171
+ module: Optional[str] = None
172
+
173
+ class ConfigValue(BaseModel):
174
+ key: str
175
+ value: Any
176
+ env_locked: bool = False
177
+ description: Optional[str] = None
178
+
179
+ # Authentication Models
180
+ class AuthRequest(BaseModel):
181
+ project_id: Optional[str] = None
182
+ user_session: Optional[str] = None
183
+
184
+ class AuthResponse(BaseModel):
185
+ success: bool
186
+ auth_url: Optional[str] = None
187
+ state: Optional[str] = None
188
+ error: Optional[str] = None
189
+ credentials: Optional[Dict[str, Any]] = None
190
+ file_path: Optional[str] = None
191
+ requires_manual_project_id: Optional[bool] = None
192
+ requires_project_selection: Optional[bool] = None
193
+ available_projects: Optional[List[Dict[str, str]]] = None
194
+
195
+ class CredentialStatus(BaseModel):
196
+ disabled: bool = False
197
+ error_codes: List[int] = []
198
+ last_success: Optional[str] = None
src/openai_router.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenAI Router - Handles OpenAI format API requests
3
+ 处理OpenAI格式请求的路由模块
4
+ """
5
+ import json
6
+ import time
7
+ import uuid
8
+ import asyncio
9
+ from contextlib import asynccontextmanager
10
+
11
+ from fastapi import APIRouter, HTTPException, Depends, Request, status
12
+ from fastapi.responses import JSONResponse, StreamingResponse
13
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
14
+
15
+ from config import get_available_models, is_fake_streaming_model, is_anti_truncation_model, get_base_model_from_feature_model, get_anti_truncation_max_attempts
16
+ from log import log
17
+ from .anti_truncation import apply_anti_truncation_to_stream
18
+ from .credential_manager import CredentialManager
19
+ from .google_chat_api import send_gemini_request
20
+ from .models import ChatCompletionRequest, ModelList, Model
21
+ from .task_manager import create_managed_task
22
+ from .openai_transfer import openai_request_to_gemini_payload, gemini_response_to_openai, gemini_stream_chunk_to_openai
23
+
24
+ # 创建路由器
25
+ router = APIRouter()
26
+ security = HTTPBearer()
27
+
28
+ # 全局凭证管理器实例
29
+ credential_manager = None
30
+
31
+ @asynccontextmanager
32
+ async def get_credential_manager():
33
+ """获取全局凭证管理器实例"""
34
+ global credential_manager
35
+ if not credential_manager:
36
+ credential_manager = CredentialManager()
37
+ await credential_manager.initialize()
38
+ yield credential_manager
39
+
40
+ async def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
41
+ """验证用户密码"""
42
+ from config import get_api_password
43
+ password = await get_api_password()
44
+ token = credentials.credentials
45
+ if token != password:
46
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="密码错误")
47
+ return token
48
+
49
+ @router.get("/v1/models", response_model=ModelList)
50
+ async def list_models():
51
+ """返回OpenAI格式的模型列表"""
52
+ models = get_available_models("openai")
53
+ return ModelList(data=[Model(id=m) for m in models])
54
+
55
+ @router.post("/v1/chat/completions")
56
+ async def chat_completions(
57
+ request: Request,
58
+ token: str = Depends(authenticate)
59
+ ):
60
+ """处理OpenAI格式的聊天完成请求"""
61
+
62
+ # 获取原始请求数据
63
+ try:
64
+ raw_data = await request.json()
65
+ except Exception as e:
66
+ log.error(f"Failed to parse JSON request: {e}")
67
+ raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
68
+
69
+ # 创建请求对象
70
+ try:
71
+ request_data = ChatCompletionRequest(**raw_data)
72
+ except Exception as e:
73
+ log.error(f"Request validation failed: {e}")
74
+ raise HTTPException(status_code=400, detail=f"Request validation error: {str(e)}")
75
+
76
+ # 健康检查
77
+ if (len(request_data.messages) == 1 and
78
+ getattr(request_data.messages[0], "role", None) == "user" and
79
+ getattr(request_data.messages[0], "content", None) == "Hi"):
80
+ return JSONResponse(content={
81
+ "choices": [{"message": {"role": "assistant", "content": "gcli2api正常工作中"}}]
82
+ })
83
+
84
+ # 限制max_tokens
85
+ if getattr(request_data, "max_tokens", None) is not None and request_data.max_tokens > 65535:
86
+ request_data.max_tokens = 65535
87
+
88
+ # 覆写 top_k 为 64
89
+ setattr(request_data, "top_k", 64)
90
+
91
+ # 过滤空消息
92
+ filtered_messages = []
93
+ for m in request_data.messages:
94
+ content = getattr(m, "content", None)
95
+ if content:
96
+ if isinstance(content, str) and content.strip():
97
+ filtered_messages.append(m)
98
+ elif isinstance(content, list) and len(content) > 0:
99
+ has_valid_content = False
100
+ for part in content:
101
+ if isinstance(part, dict):
102
+ if part.get("type") == "text" and part.get("text", "").strip():
103
+ has_valid_content = True
104
+ break
105
+ elif part.get("type") == "image_url" and part.get("image_url", {}).get("url"):
106
+ has_valid_content = True
107
+ break
108
+ if has_valid_content:
109
+ filtered_messages.append(m)
110
+
111
+ request_data.messages = filtered_messages
112
+
113
+ # 处理模型名称和功能检测
114
+ model = request_data.model
115
+ use_fake_streaming = is_fake_streaming_model(model)
116
+ use_anti_truncation = is_anti_truncation_model(model)
117
+
118
+ # 获取基础模型名
119
+ real_model = get_base_model_from_feature_model(model)
120
+ request_data.model = real_model
121
+
122
+ # 获取凭证管理器
123
+ from src.credential_manager import get_credential_manager
124
+ cred_mgr = await get_credential_manager()
125
+
126
+ # 获取有效凭证
127
+ credential_result = await cred_mgr.get_valid_credential()
128
+ if not credential_result:
129
+ log.error("当前无可用凭证,请去控制台获取")
130
+ raise HTTPException(status_code=500, detail="当前无可用凭证,请去控制台获取")
131
+
132
+ current_file = credential_result
133
+ log.debug(f"Using credential: {current_file}")
134
+
135
+ # 增加调用计数
136
+ cred_mgr.increment_call_count()
137
+
138
+ # 转换为Gemini API payload格式
139
+ try:
140
+ api_payload = await openai_request_to_gemini_payload(request_data)
141
+ except Exception as e:
142
+ log.error(f"OpenAI to Gemini conversion failed: {e}")
143
+ raise HTTPException(status_code=500, detail="Request conversion failed")
144
+
145
+ # 处理假流式
146
+ if use_fake_streaming and getattr(request_data, "stream", False):
147
+ request_data.stream = False
148
+ return await fake_stream_response(api_payload, cred_mgr)
149
+
150
+ # 处理抗截断 (仅流式传输时有效)
151
+ is_streaming = getattr(request_data, "stream", False)
152
+ if use_anti_truncation and is_streaming:
153
+ log.info("启用流式抗截断功能")
154
+ max_attempts = await get_anti_truncation_max_attempts()
155
+
156
+ # 使用流式抗截断处理器
157
+ gemini_response = await apply_anti_truncation_to_stream(
158
+ lambda api_payload: send_gemini_request(api_payload, is_streaming, cred_mgr),
159
+ api_payload,
160
+ max_attempts
161
+ )
162
+
163
+ return await convert_streaming_response(gemini_response, model)
164
+ elif use_anti_truncation and not is_streaming:
165
+ log.warning("抗截断功能仅在流式传输时有效,非流式请求将忽略此设置")
166
+
167
+ # 发送请求(429重试已在google_api_client中处理)
168
+ is_streaming = getattr(request_data, "stream", False)
169
+ log.debug(f"Sending request: streaming={is_streaming}, model={real_model}")
170
+ response = await send_gemini_request(api_payload, is_streaming, cred_mgr)
171
+
172
+ # 如果是流式响应,直接返回
173
+ if is_streaming:
174
+ return await convert_streaming_response(response, model)
175
+
176
+ # 转换非流式响应
177
+ try:
178
+ log.debug(f"Processing response: type={type(response)}")
179
+ if hasattr(response, 'body'):
180
+ response_data = json.loads(response.body.decode() if isinstance(response.body, bytes) else response.body)
181
+ else:
182
+ response_data = json.loads(response.content.decode() if isinstance(response.content, bytes) else response.content)
183
+
184
+ log.debug(f"Response data keys: {list(response_data.keys()) if isinstance(response_data, dict) else 'Not a dict'}")
185
+ openai_response = gemini_response_to_openai(response_data, model)
186
+ log.debug(f"Converted OpenAI response keys: {list(openai_response.keys()) if isinstance(openai_response, dict) else 'Not a dict'}")
187
+ return JSONResponse(content=openai_response)
188
+
189
+ except Exception as e:
190
+ log.error(f"Response conversion failed: {e}")
191
+ log.error(f"Response object: {response}")
192
+ raise HTTPException(status_code=500, detail="Response conversion failed")
193
+
194
+ async def fake_stream_response(api_payload: dict, cred_mgr: CredentialManager) -> StreamingResponse:
195
+ """处理假流式响应"""
196
+ async def stream_generator():
197
+ try:
198
+ # 发送心跳
199
+ heartbeat = {
200
+ "choices": [{
201
+ "index": 0,
202
+ "delta": {"role": "assistant", "content": ""},
203
+ "finish_reason": None
204
+ }]
205
+ }
206
+ yield f"data: {json.dumps(heartbeat)}\n\n".encode()
207
+
208
+ # 异步发送实际请求
209
+ async def get_response():
210
+ return await send_gemini_request(api_payload, False, cred_mgr)
211
+
212
+ # 创建请求任务
213
+ response_task = create_managed_task(get_response(), name="openai_fake_stream_request")
214
+
215
+ try:
216
+ # 每3秒发送一次心跳,直到收到响应
217
+ while not response_task.done():
218
+ await asyncio.sleep(3.0)
219
+ if not response_task.done():
220
+ yield f"data: {json.dumps(heartbeat)}\n\n".encode()
221
+
222
+ # 获取响应结果
223
+ response = await response_task
224
+
225
+ except asyncio.CancelledError:
226
+ # 取消任务并传播取消
227
+ response_task.cancel()
228
+ try:
229
+ await response_task
230
+ except asyncio.CancelledError:
231
+ pass
232
+ raise
233
+ except Exception as e:
234
+ # 取消任务并处理其他异常
235
+ response_task.cancel()
236
+ try:
237
+ await response_task
238
+ except asyncio.CancelledError:
239
+ pass
240
+ log.error(f"Fake streaming request failed: {e}")
241
+ raise
242
+
243
+ # 发送实际请求
244
+ # response 已在上面获取
245
+
246
+ # 处理结果
247
+ if hasattr(response, 'body'):
248
+ body_str = response.body.decode() if isinstance(response.body, bytes) else str(response.body)
249
+ elif hasattr(response, 'content'):
250
+ body_str = response.content.decode() if isinstance(response.content, bytes) else str(response.content)
251
+ else:
252
+ body_str = str(response)
253
+
254
+ try:
255
+ response_data = json.loads(body_str)
256
+ log.debug(f"Fake stream response data: {response_data}")
257
+
258
+ # 从Gemini响应中提取内容,使用思维链分离逻辑
259
+ content = ""
260
+ reasoning_content = ""
261
+ if "candidates" in response_data and response_data["candidates"]:
262
+ # Gemini格式响应 - 使用思维链分离
263
+ from .openai_transfer import _extract_content_and_reasoning
264
+ candidate = response_data["candidates"][0]
265
+ if "content" in candidate and "parts" in candidate["content"]:
266
+ parts = candidate["content"]["parts"]
267
+ content, reasoning_content = _extract_content_and_reasoning(parts)
268
+ elif "choices" in response_data and response_data["choices"]:
269
+ # OpenAI格式响应
270
+ content = response_data["choices"][0].get("message", {}).get("content", "")
271
+
272
+ log.debug(f"Extracted content: {content}")
273
+ log.debug(f"Extracted reasoning: {reasoning_content[:100] if reasoning_content else 'None'}...")
274
+
275
+ # 如果没有正常内容但有思维内容,给出警告
276
+ if not content and reasoning_content:
277
+ log.warning(f"Fake stream response contains only thinking content: {reasoning_content[:100]}...")
278
+ content = "[模型正在思考中,请稍后再试或重新提问]"
279
+
280
+ if content:
281
+ # 构建响应块,包括思维内容(如果有)
282
+ delta = {"role": "assistant", "content": content}
283
+ if reasoning_content:
284
+ delta["reasoning_content"] = reasoning_content
285
+
286
+ content_chunk = {
287
+ "choices": [{
288
+ "index": 0,
289
+ "delta": delta,
290
+ "finish_reason": "stop"
291
+ }]
292
+ }
293
+ yield f"data: {json.dumps(content_chunk)}\n\n".encode()
294
+ else:
295
+ log.warning(f"No content found in response: {response_data}")
296
+ # 如果完全没有内容,提供默认回复
297
+ error_chunk = {
298
+ "choices": [{
299
+ "index": 0,
300
+ "delta": {"role": "assistant", "content": "[响应为空,请重新尝试]"},
301
+ "finish_reason": "stop"
302
+ }]
303
+ }
304
+ yield f"data: {json.dumps(error_chunk)}\n\n".encode()
305
+ except json.JSONDecodeError:
306
+ error_chunk = {
307
+ "choices": [{
308
+ "index": 0,
309
+ "delta": {"role": "assistant", "content": body_str},
310
+ "finish_reason": "stop"
311
+ }]
312
+ }
313
+ yield f"data: {json.dumps(error_chunk)}\n\n".encode()
314
+
315
+ yield "data: [DONE]\n\n".encode()
316
+
317
+ except Exception as e:
318
+ log.error(f"Fake streaming error: {e}")
319
+ error_chunk = {
320
+ "choices": [{
321
+ "index": 0,
322
+ "delta": {"role": "assistant", "content": f"Error: {str(e)}"},
323
+ "finish_reason": "stop"
324
+ }]
325
+ }
326
+ yield f"data: {json.dumps(error_chunk)}\n\n".encode()
327
+ yield "data: [DONE]\n\n".encode()
328
+
329
+ return StreamingResponse(stream_generator(), media_type="text/event-stream")
330
+
331
+ async def convert_streaming_response(gemini_response, model: str) -> StreamingResponse:
332
+ """转换流式响应为OpenAI格式"""
333
+ response_id = str(uuid.uuid4())
334
+
335
+ async def openai_stream_generator():
336
+ try:
337
+ # 处理不同类型的响应对象
338
+ if hasattr(gemini_response, 'body_iterator'):
339
+ # FastAPI StreamingResponse
340
+ async for chunk in gemini_response.body_iterator:
341
+ if not chunk:
342
+ continue
343
+
344
+ # 处理不同数据类型的startswith问题
345
+ if isinstance(chunk, bytes):
346
+ if not chunk.startswith(b'data: '):
347
+ continue
348
+ payload = chunk[len(b'data: '):]
349
+ else:
350
+ chunk_str = str(chunk)
351
+ if not chunk_str.startswith('data: '):
352
+ continue
353
+ payload = chunk_str[len('data: '):].encode()
354
+ try:
355
+ gemini_chunk = json.loads(payload.decode())
356
+ openai_chunk = gemini_stream_chunk_to_openai(gemini_chunk, model, response_id)
357
+ yield f"data: {json.dumps(openai_chunk, separators=(',',':'))}\n\n".encode()
358
+ except json.JSONDecodeError:
359
+ continue
360
+ else:
361
+ # 其他类型的响应,尝试直接处理
362
+ log.warning(f"Unexpected response type: {type(gemini_response)}")
363
+ error_chunk = {
364
+ "id": response_id,
365
+ "object": "chat.completion.chunk",
366
+ "created": int(time.time()),
367
+ "model": model,
368
+ "choices": [{
369
+ "index": 0,
370
+ "delta": {"role": "assistant", "content": "Response type error"},
371
+ "finish_reason": "stop"
372
+ }]
373
+ }
374
+ yield f"data: {json.dumps(error_chunk)}\n\n".encode()
375
+
376
+ # 发送结束标记
377
+ yield "data: [DONE]\n\n".encode()
378
+
379
+ except Exception as e:
380
+ log.error(f"Stream conversion error: {e}")
381
+ error_chunk = {
382
+ "id": response_id,
383
+ "object": "chat.completion.chunk",
384
+ "created": int(time.time()),
385
+ "model": model,
386
+ "choices": [{
387
+ "index": 0,
388
+ "delta": {"role": "assistant", "content": f"Stream error: {str(e)}"},
389
+ "finish_reason": "stop"
390
+ }]
391
+ }
392
+ yield f"data: {json.dumps(error_chunk)}\n\n".encode()
393
+ yield "data: [DONE]\n\n".encode()
394
+
395
+ return StreamingResponse(openai_stream_generator(), media_type="text/event-stream")
src/openai_transfer.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenAI Transfer Module - Handles conversion between OpenAI and Gemini API formats
3
+ 被openai-router调用,负责OpenAI格式与Gemini格式的双向转换
4
+ """
5
+ import time
6
+ import uuid
7
+ from typing import Dict, Any
8
+
9
+ from config import (
10
+ DEFAULT_SAFETY_SETTINGS,
11
+ get_base_model_name,
12
+ get_thinking_budget,
13
+ is_search_model,
14
+ should_include_thoughts,
15
+ get_compatibility_mode_enabled
16
+ )
17
+ from log import log
18
+ from .models import ChatCompletionRequest
19
+
20
+ async def openai_request_to_gemini_payload(openai_request: ChatCompletionRequest) -> Dict[str, Any]:
21
+ """
22
+ 将OpenAI聊天完成请求直接转换为完整的Gemini API payload格式
23
+
24
+ Args:
25
+ openai_request: OpenAI格式请求对象
26
+
27
+ Returns:
28
+ 完整的Gemini API payload,包含model和request字段
29
+ """
30
+ contents = []
31
+ system_instructions = []
32
+
33
+ # 检查是否启用兼容性模式
34
+ compatibility_mode = await get_compatibility_mode_enabled()
35
+
36
+ # 处理对话中的每条消息
37
+ # 第一阶段:收集连续的system消息到system_instruction中(除非在兼容性模式下)
38
+ collecting_system = True if not compatibility_mode else False
39
+
40
+ for message in openai_request.messages:
41
+ role = message.role
42
+
43
+ # 处理系统消息
44
+ if role == "system":
45
+ if compatibility_mode:
46
+ # 兼容性模式:所有system消息转换为user消息
47
+ role = "user"
48
+ elif collecting_system:
49
+ # 正常模式:仍在收集连续的system消息
50
+ if isinstance(message.content, str):
51
+ system_instructions.append(message.content)
52
+ elif isinstance(message.content, list):
53
+ # 处理列表格式的系统消息
54
+ for part in message.content:
55
+ if part.get("type") == "text" and part.get("text"):
56
+ system_instructions.append(part["text"])
57
+ continue
58
+ else:
59
+ # 正常模式:后续的system消息转换为user消息
60
+ role = "user"
61
+ else:
62
+ # 遇到非system消息,停止收集system消息
63
+ collecting_system = False
64
+
65
+ # 将OpenAI角色映射到Gemini角色
66
+ if role == "assistant":
67
+ role = "model"
68
+
69
+ # 处理普通内容
70
+ if isinstance(message.content, list):
71
+ parts = []
72
+ for part in message.content:
73
+ if part.get("type") == "text":
74
+ parts.append({"text": part.get("text", "")})
75
+ elif part.get("type") == "image_url":
76
+ image_url = part.get("image_url", {}).get("url")
77
+ if image_url:
78
+ # 解析数据URI: "data:image/jpeg;base64,{base64_image}"
79
+ try:
80
+ mime_type, base64_data = image_url.split(";")
81
+ _, mime_type = mime_type.split(":")
82
+ _, base64_data = base64_data.split(",")
83
+ parts.append({
84
+ "inlineData": {
85
+ "mimeType": mime_type,
86
+ "data": base64_data
87
+ }
88
+ })
89
+ except ValueError:
90
+ continue
91
+ contents.append({"role": role, "parts": parts})
92
+ # log.debug(f"Added message to contents: role={role}, parts={parts}")
93
+ elif message.content:
94
+ # 简单文本内容
95
+ contents.append({"role": role, "parts": [{"text": message.content}]})
96
+ # log.debug(f"Added message to contents: role={role}, content={message.content}")
97
+
98
+ # 将OpenAI生成参数映射到Gemini格式
99
+ generation_config = {}
100
+ if openai_request.temperature is not None:
101
+ generation_config["temperature"] = openai_request.temperature
102
+ if openai_request.top_p is not None:
103
+ generation_config["topP"] = openai_request.top_p
104
+ if openai_request.max_tokens is not None:
105
+ generation_config["maxOutputTokens"] = openai_request.max_tokens
106
+ if openai_request.stop is not None:
107
+ # Gemini支持停止序列
108
+ if isinstance(openai_request.stop, str):
109
+ generation_config["stopSequences"] = [openai_request.stop]
110
+ elif isinstance(openai_request.stop, list):
111
+ generation_config["stopSequences"] = openai_request.stop
112
+ if openai_request.frequency_penalty is not None:
113
+ generation_config["frequencyPenalty"] = openai_request.frequency_penalty
114
+ if openai_request.presence_penalty is not None:
115
+ generation_config["presencePenalty"] = openai_request.presence_penalty
116
+ if openai_request.n is not None:
117
+ generation_config["candidateCount"] = openai_request.n
118
+ if openai_request.seed is not None:
119
+ generation_config["seed"] = openai_request.seed
120
+ if openai_request.response_format is not None:
121
+ # 处理JSON模式
122
+ if openai_request.response_format.get("type") == "json_object":
123
+ generation_config["responseMimeType"] = "application/json"
124
+
125
+ # 如果contents为空(只有系统消息的情况),添加一个默认的用户消息以满足Gemini API要求
126
+ if not contents:
127
+ contents.append({"role": "user", "parts": [{"text": "请根据系统指令回答。"}]})
128
+
129
+ # 构建请求数据
130
+ request_data = {
131
+ "contents": contents,
132
+ "generationConfig": generation_config,
133
+ "safetySettings": DEFAULT_SAFETY_SETTINGS,
134
+ }
135
+
136
+ # 如果有系统消息且未启用兼容性模式,添加systemInstruction
137
+ if system_instructions and not compatibility_mode:
138
+ combined_system_instruction = "\n\n".join(system_instructions)
139
+ request_data["systemInstruction"] = {"parts": [{"text": combined_system_instruction}]}
140
+
141
+ log.debug(f"Final request payload contents count: {len(contents)}, system_instruction: {bool(system_instructions and not compatibility_mode)}, compatibility_mode: {compatibility_mode}")
142
+
143
+ # 为thinking模型添加thinking配置
144
+ thinking_budget = get_thinking_budget(openai_request.model)
145
+ if thinking_budget is not None:
146
+ request_data["generationConfig"]["thinkingConfig"] = {
147
+ "thinkingBudget": thinking_budget,
148
+ "includeThoughts": should_include_thoughts(openai_request.model)
149
+ }
150
+
151
+ # 为搜索模型添加Google Search工具
152
+ if is_search_model(openai_request.model):
153
+ request_data["tools"] = [{"googleSearch": {}}]
154
+
155
+ # 移除None值
156
+ request_data = {k: v for k, v in request_data.items() if v is not None}
157
+
158
+ # 返回完整的Gemini API payload格式
159
+ return {
160
+ "model": get_base_model_name(openai_request.model),
161
+ "request": request_data
162
+ }
163
+
164
+ def _extract_content_and_reasoning(parts: list) -> tuple:
165
+ """从Gemini响应部件中提取内容和推理内容"""
166
+ content = ""
167
+ reasoning_content = ""
168
+
169
+ for part in parts:
170
+ # 处理文本内容
171
+ if part.get("text"):
172
+ # 检查这个部件是否包含thinking tokens
173
+ if part.get("thought", False):
174
+ reasoning_content += part.get("text", "")
175
+ else:
176
+ content += part.get("text", "")
177
+
178
+ return content, reasoning_content
179
+
180
+ def _build_message_with_reasoning(role: str, content: str, reasoning_content: str) -> dict:
181
+ """构建包含可选推理内容的消息对象"""
182
+ message = {
183
+ "role": role,
184
+ "content": content
185
+ }
186
+
187
+ # 如果有thinking tokens,添加reasoning_content
188
+ if reasoning_content:
189
+ message["reasoning_content"] = reasoning_content
190
+
191
+ return message
192
+
193
+ def gemini_response_to_openai(gemini_response: Dict[str, Any], model: str) -> Dict[str, Any]:
194
+ """
195
+ 将Gemini API响应转换为OpenAI聊天完成格式
196
+
197
+ Args:
198
+ gemini_response: 来自Gemini API的响应
199
+ model: 要在响应中包含的模型名称
200
+
201
+ Returns:
202
+ OpenAI聊天完成格式的字典
203
+ """
204
+ choices = []
205
+
206
+ for candidate in gemini_response.get("candidates", []):
207
+ role = candidate.get("content", {}).get("role", "assistant")
208
+
209
+ # 将Gemini角色映射回OpenAI角色
210
+ if role == "model":
211
+ role = "assistant"
212
+
213
+ # 提取并分离thinking tokens和常规内容
214
+ parts = candidate.get("content", {}).get("parts", [])
215
+ content, reasoning_content = _extract_content_and_reasoning(parts)
216
+
217
+ # 构建消息对象
218
+ message = _build_message_with_reasoning(role, content, reasoning_content)
219
+
220
+ choices.append({
221
+ "index": candidate.get("index", 0),
222
+ "message": message,
223
+ "finish_reason": _map_finish_reason(candidate.get("finishReason")),
224
+ })
225
+
226
+ return {
227
+ "id": str(uuid.uuid4()),
228
+ "object": "chat.completion",
229
+ "created": int(time.time()),
230
+ "model": model,
231
+ "choices": choices,
232
+ }
233
+
234
+ def gemini_stream_chunk_to_openai(gemini_chunk: Dict[str, Any], model: str, response_id: str) -> Dict[str, Any]:
235
+ """
236
+ 将Gemini流式响应块转换为OpenAI流式格式
237
+
238
+ Args:
239
+ gemini_chunk: 来自Gemini流式响应的单个块
240
+ model: 要在响应中包含的模型名称
241
+ response_id: 此流式响应的一致ID
242
+
243
+ Returns:
244
+ OpenAI流式格式的字典
245
+ """
246
+ choices = []
247
+
248
+ for candidate in gemini_chunk.get("candidates", []):
249
+ role = candidate.get("content", {}).get("role", "assistant")
250
+
251
+ # 将Gemini角色映射回OpenAI角色
252
+ if role == "model":
253
+ role = "assistant"
254
+
255
+ # 提取并分离thinking tokens和常规内容
256
+ parts = candidate.get("content", {}).get("parts", [])
257
+ content, reasoning_content = _extract_content_and_reasoning(parts)
258
+
259
+ # 构建delta对象
260
+ delta = {}
261
+ if content:
262
+ delta["content"] = content
263
+ if reasoning_content:
264
+ delta["reasoning_content"] = reasoning_content
265
+
266
+ choices.append({
267
+ "index": candidate.get("index", 0),
268
+ "delta": delta,
269
+ "finish_reason": _map_finish_reason(candidate.get("finishReason")),
270
+ })
271
+
272
+ return {
273
+ "id": response_id,
274
+ "object": "chat.completion.chunk",
275
+ "created": int(time.time()),
276
+ "model": model,
277
+ "choices": choices,
278
+ }
279
+
280
+ def _map_finish_reason(gemini_reason: str) -> str:
281
+ """
282
+ 将Gemini结束原因映射到OpenAI结束原因
283
+
284
+ Args:
285
+ gemini_reason: 来自Gemini API的结束原因
286
+
287
+ Returns:
288
+ OpenAI兼容的结束原因
289
+ """
290
+ if gemini_reason == "STOP":
291
+ return "stop"
292
+ elif gemini_reason == "MAX_TOKENS":
293
+ return "length"
294
+ elif gemini_reason in ["SAFETY", "RECITATION"]:
295
+ return "content_filter"
296
+ else:
297
+ return None
298
+
299
+ def validate_openai_request(request_data: Dict[str, Any]) -> ChatCompletionRequest:
300
+ """
301
+ 验证并标准化OpenAI请求数据
302
+
303
+ Args:
304
+ request_data: 原始请求数据字典
305
+
306
+ Returns:
307
+ 验证后的ChatCompletionRequest对象
308
+
309
+ Raises:
310
+ ValueError: 当请求数据无效时
311
+ """
312
+ try:
313
+ return ChatCompletionRequest(**request_data)
314
+ except Exception as e:
315
+ raise ValueError(f"Invalid OpenAI request format: {str(e)}")
316
+
317
+ def normalize_openai_request(request_data: ChatCompletionRequest) -> ChatCompletionRequest:
318
+ """
319
+ 标准化OpenAI请求数据,应用默认值和限制
320
+
321
+ Args:
322
+ request_data: 原始请求对象
323
+
324
+ Returns:
325
+ 标准化后的请求对象
326
+ """
327
+ # 限制max_tokens
328
+ if getattr(request_data, "max_tokens", None) is not None and request_data.max_tokens > 65535:
329
+ request_data.max_tokens = 65535
330
+
331
+ # 覆写 top_k 为 64
332
+ setattr(request_data, "top_k", 64)
333
+
334
+ # 过滤空消息
335
+ filtered_messages = []
336
+ for m in request_data.messages:
337
+ content = getattr(m, "content", None)
338
+ if content:
339
+ if isinstance(content, str) and content.strip():
340
+ filtered_messages.append(m)
341
+ elif isinstance(content, list) and len(content) > 0:
342
+ has_valid_content = False
343
+ for part in content:
344
+ if isinstance(part, dict):
345
+ if part.get("type") == "text" and part.get("text", "").strip():
346
+ has_valid_content = True
347
+ break
348
+ elif part.get("type") == "image_url" and part.get("image_url", {}).get("url"):
349
+ has_valid_content = True
350
+ break
351
+ if has_valid_content:
352
+ filtered_messages.append(m)
353
+
354
+ request_data.messages = filtered_messages
355
+
356
+ return request_data
357
+
358
+ def is_health_check_request(request_data: ChatCompletionRequest) -> bool:
359
+ """
360
+ 检查是否为健康检查请求
361
+
362
+ Args:
363
+ request_data: 请求对象
364
+
365
+ Returns:
366
+ 是否为健康检查请求
367
+ """
368
+ return (len(request_data.messages) == 1 and
369
+ getattr(request_data.messages[0], "role", None) == "user" and
370
+ getattr(request_data.messages[0], "content", None) == "Hi")
371
+
372
+ def create_health_check_response() -> Dict[str, Any]:
373
+ """
374
+ 创建健康检查响应
375
+
376
+ Returns:
377
+ 健康检查响应字典
378
+ """
379
+ return {
380
+ "choices": [{
381
+ "message": {
382
+ "role": "assistant",
383
+ "content": "gcli2api正常工作中"
384
+ }
385
+ }]
386
+ }
387
+
388
+ def extract_model_settings(model: str) -> Dict[str, Any]:
389
+ """
390
+ 从模型名称中提取设置信息
391
+
392
+ Args:
393
+ model: 模型名称
394
+
395
+ Returns:
396
+ 包含模型设置的字典
397
+ """
398
+ return {
399
+ "base_model": get_base_model_name(model),
400
+ "use_fake_streaming": model.endswith("-假流式"),
401
+ "thinking_budget": get_thinking_budget(model),
402
+ "include_thoughts": should_include_thoughts(model)
403
+ }
src/state_manager.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 统一状态管理器
3
+ """
4
+ import asyncio
5
+ import os
6
+ from typing import Dict, Any
7
+ from contextlib import asynccontextmanager
8
+
9
+ from config import is_mongodb_mode
10
+ from log import log
11
+ from .storage_adapter import get_storage_adapter
12
+
13
+
14
+ class StateManager:
15
+ """
16
+ 统一状态管理器
17
+ """
18
+
19
+ def __init__(self, state_file_path: str):
20
+ self.state_file_path = state_file_path
21
+ self._lock = asyncio.Lock()
22
+ self._storage_adapter = None
23
+ self._initialized = False
24
+
25
+ # 从文件路径推断存储用途
26
+ self._storage_purpose = self._infer_storage_purpose(state_file_path)
27
+
28
+ def _infer_storage_purpose(self, file_path: str) -> str:
29
+ """根据文件路径推断存储用途"""
30
+ filename = os.path.basename(file_path)
31
+
32
+ if "creds_state" in filename:
33
+ return "credential_state"
34
+ elif "config" in filename:
35
+ return "config"
36
+ elif "usage" in filename or "stats" in filename:
37
+ return "usage_stats"
38
+ else:
39
+ return "general"
40
+
41
+ async def _ensure_initialized(self):
42
+ """确保状态管理器已初始化"""
43
+ if not self._initialized:
44
+ self._storage_adapter = await get_storage_adapter()
45
+ self._initialized = True
46
+
47
+ if await is_mongodb_mode():
48
+ log.debug(f"Unified state manager initialized with MongoDB backend for: {self._storage_purpose}")
49
+ else:
50
+ log.debug(f"Unified state manager initialized with file backend for: {self._storage_purpose}")
51
+
52
+ async def _load_state(self) -> Dict[str, Any]:
53
+ """加载状态数据"""
54
+ await self._ensure_initialized()
55
+
56
+ if self._storage_purpose == "credential_state":
57
+ return await self._storage_adapter.get_all_credential_states()
58
+ elif self._storage_purpose == "config":
59
+ return await self._storage_adapter.get_all_config()
60
+ elif self._storage_purpose == "usage_stats":
61
+ return await self._storage_adapter.get_all_usage_stats()
62
+ else:
63
+ # 对于通用存储,尝试获取配置数据
64
+ return await self._storage_adapter.get_all_config()
65
+
66
+ async def _save_state(self, state: Dict[str, Any]):
67
+ """保存状态数据"""
68
+ await self._ensure_initialized()
69
+
70
+ # 根据存储用途批量更新数据
71
+ if self._storage_purpose == "credential_state":
72
+ # 批量更新凭证状态
73
+ for filename, file_state in state.items():
74
+ await self._storage_adapter.update_credential_state(filename, file_state)
75
+ elif self._storage_purpose == "config":
76
+ # 批量更新配置
77
+ for key, value in state.items():
78
+ await self._storage_adapter.set_config(key, value)
79
+ elif self._storage_purpose == "usage_stats":
80
+ # 批量更新使用统计
81
+ for filename, stats in state.items():
82
+ await self._storage_adapter.update_usage_stats(filename, stats)
83
+ else:
84
+ # 通用存储,作为配置处理
85
+ for key, value in state.items():
86
+ await self._storage_adapter.set_config(key, value)
87
+
88
+ @asynccontextmanager
89
+ async def transaction(self):
90
+ """
91
+ 事务上下文管理器,兼容原有接口。
92
+ Usage:
93
+ async with state_manager.transaction() as state:
94
+ state['key'] = 'value'
95
+ # State is automatically saved on exit
96
+ """
97
+ async with self._lock:
98
+ state = await self._load_state()
99
+ try:
100
+ yield state
101
+ await self._save_state(state)
102
+ except Exception:
103
+ # Don't save if there was an error
104
+ raise
105
+
106
+ async def read_file_state(self, filename: str) -> Dict[str, Any]:
107
+ """读取特定文件的状态,兼容原有接口"""
108
+ await self._ensure_initialized()
109
+
110
+ if self._storage_purpose == "credential_state":
111
+ return await self._storage_adapter.get_credential_state(filename)
112
+ elif self._storage_purpose == "usage_stats":
113
+ return await self._storage_adapter.get_usage_stats(filename)
114
+ else:
115
+ # 对于配置和通用存储,filename作为配置键
116
+ value = await self._storage_adapter.get_config(filename)
117
+ return value if isinstance(value, dict) else {}
118
+
119
+ async def update_file_state(self, filename: str, updates: Dict[str, Any]):
120
+ """更新特定文件的状态,兼容原有接口"""
121
+ await self._ensure_initialized()
122
+
123
+ if self._storage_purpose == "credential_state":
124
+ await self._storage_adapter.update_credential_state(filename, updates)
125
+ elif self._storage_purpose == "usage_stats":
126
+ await self._storage_adapter.update_usage_stats(filename, updates)
127
+ else:
128
+ # 对于配置存储,如果updates是字典则作为嵌套配置处理
129
+ if isinstance(updates, dict) and len(updates) == 1:
130
+ # 如果只有一个键值对,可能是设置单个配置
131
+ for key, value in updates.items():
132
+ await self._storage_adapter.set_config(f"{filename}.{key}", value)
133
+ else:
134
+ # 否则将整个updates作为配置值
135
+ await self._storage_adapter.set_config(filename, updates)
136
+
137
+ async def batch_update(self, updates: Dict[str, Dict[str, Any]]):
138
+ """批量更新多个文件,兼容原有接口"""
139
+ await self._ensure_initialized()
140
+
141
+ for filename, file_updates in updates.items():
142
+ await self.update_file_state(filename, file_updates)
143
+
144
+
145
+ # 全局状态管理器实例缓存
146
+ _state_managers: Dict[str, StateManager] = {}
147
+
148
+
149
+ def get_state_manager(state_file_path: str) -> StateManager:
150
+ """获取或创建状态管理器实例,兼容原有接口"""
151
+ if state_file_path not in _state_managers:
152
+ _state_managers[state_file_path] = StateManager(state_file_path)
153
+ return _state_managers[state_file_path]
154
+
155
+
156
+ async def close_all_state_managers():
157
+ """关闭所有状态管理器(用于优雅关闭)"""
158
+ global _state_managers
159
+
160
+ # 关闭存储适配器(这会自动处理所有状态管理器)
161
+ from .storage_adapter import close_storage_adapter
162
+ await close_storage_adapter()
163
+
164
+ # 清空缓存
165
+ _state_managers.clear()
166
+ log.debug("All state managers closed")
src/storage/cache_manager.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 统一内存缓存管理器
3
+ 为所有存储后端提供一致的内存缓存机制,确保读写一致性和高性能。
4
+ """
5
+ import asyncio
6
+ import time
7
+ from typing import Dict, Any, Optional
8
+ from collections import deque
9
+ from abc import ABC, abstractmethod
10
+
11
+ from log import log
12
+
13
+
14
+ class CacheBackend(ABC):
15
+ """缓存后端接口,定义底层存储的读写操作"""
16
+
17
+ @abstractmethod
18
+ async def load_data(self) -> Dict[str, Any]:
19
+ """从底层存储加载数据"""
20
+ pass
21
+
22
+ @abstractmethod
23
+ async def write_data(self, data: Dict[str, Any]) -> bool:
24
+ """将数据写入底层存储"""
25
+ pass
26
+
27
+
28
+ class UnifiedCacheManager:
29
+ """统一缓存管理器"""
30
+
31
+ def __init__(
32
+ self,
33
+ cache_backend: CacheBackend,
34
+ cache_ttl: float = 300.0,
35
+ write_delay: float = 1.0,
36
+ name: str = "cache"
37
+ ):
38
+ """
39
+ 初始化缓存管理器
40
+
41
+ Args:
42
+ cache_backend: 缓存后端实现
43
+ cache_ttl: 缓存TTL(秒)
44
+ write_delay: 写入延迟(秒)
45
+ name: 缓存名称(用于日志)
46
+ """
47
+ self._backend = cache_backend
48
+ self._cache_ttl = cache_ttl
49
+ self._write_delay = write_delay
50
+ self._name = name
51
+
52
+ # 缓存数据
53
+ self._cache: Dict[str, Any] = {}
54
+ self._cache_dirty = False
55
+ self._last_cache_time = 0
56
+
57
+ # 并发控制
58
+ self._cache_lock = asyncio.Lock()
59
+
60
+ # 异步写回任务
61
+ self._write_task: Optional[asyncio.Task] = None
62
+ self._shutdown_event = asyncio.Event()
63
+
64
+ # 性能监控
65
+ self._operation_count = 0
66
+ self._operation_times = deque(maxlen=1000)
67
+
68
+ async def start(self):
69
+ """启动缓存管理器"""
70
+ if self._write_task and not self._write_task.done():
71
+ return
72
+
73
+ self._shutdown_event.clear()
74
+ self._write_task = asyncio.create_task(self._write_loop())
75
+ log.debug(f"{self._name} cache manager started")
76
+
77
+ async def stop(self):
78
+ """停止缓存管理器并刷新数据"""
79
+ self._shutdown_event.set()
80
+
81
+ if self._write_task and not self._write_task.done():
82
+ try:
83
+ await asyncio.wait_for(self._write_task, timeout=5.0)
84
+ except asyncio.TimeoutError:
85
+ self._write_task.cancel()
86
+ log.warning(f"{self._name} cache writer forcibly cancelled")
87
+
88
+ # 刷新缓存
89
+ await self._flush_cache()
90
+ log.debug(f"{self._name} cache manager stopped")
91
+
92
+ async def get(self, key: str, default: Any = None) -> Any:
93
+ """获取缓存项"""
94
+ async with self._cache_lock:
95
+ start_time = time.time()
96
+
97
+ try:
98
+ # 确保缓存已加载
99
+ await self._ensure_cache_loaded()
100
+
101
+ # 性能监控
102
+ self._operation_count += 1
103
+ operation_time = time.time() - start_time
104
+ self._operation_times.append(operation_time)
105
+
106
+ result = self._cache.get(key, default)
107
+ log.debug(f"{self._name} cache get: {key} in {operation_time:.3f}s")
108
+ return result
109
+
110
+ except Exception as e:
111
+ operation_time = time.time() - start_time
112
+ log.error(f"Error getting {self._name} cache key {key} in {operation_time:.3f}s: {e}")
113
+ return default
114
+
115
+ async def set(self, key: str, value: Any) -> bool:
116
+ """设置缓存项"""
117
+ async with self._cache_lock:
118
+ start_time = time.time()
119
+
120
+ try:
121
+ # 确保缓存已加载
122
+ await self._ensure_cache_loaded()
123
+
124
+ # 更新缓存
125
+ self._cache[key] = value
126
+ self._cache_dirty = True
127
+
128
+ # 性能监控
129
+ self._operation_count += 1
130
+ operation_time = time.time() - start_time
131
+ self._operation_times.append(operation_time)
132
+
133
+ log.debug(f"{self._name} cache set: {key} in {operation_time:.3f}s")
134
+ return True
135
+
136
+ except Exception as e:
137
+ operation_time = time.time() - start_time
138
+ log.error(f"Error setting {self._name} cache key {key} in {operation_time:.3f}s: {e}")
139
+ return False
140
+
141
+ async def delete(self, key: str) -> bool:
142
+ """删除缓存项"""
143
+ async with self._cache_lock:
144
+ start_time = time.time()
145
+
146
+ try:
147
+ # 确保缓存已加载
148
+ await self._ensure_cache_loaded()
149
+
150
+ if key in self._cache:
151
+ del self._cache[key]
152
+ self._cache_dirty = True
153
+
154
+ # 性能监控
155
+ self._operation_count += 1
156
+ operation_time = time.time() - start_time
157
+ self._operation_times.append(operation_time)
158
+
159
+ log.debug(f"{self._name} cache delete: {key} in {operation_time:.3f}s")
160
+ return True
161
+ else:
162
+ log.warning(f"{self._name} cache key not found for deletion: {key}")
163
+ return False
164
+
165
+ except Exception as e:
166
+ operation_time = time.time() - start_time
167
+ log.error(f"Error deleting {self._name} cache key {key} in {operation_time:.3f}s: {e}")
168
+ return False
169
+
170
+ async def get_all(self) -> Dict[str, Any]:
171
+ """获取所有缓存数据"""
172
+ async with self._cache_lock:
173
+ start_time = time.time()
174
+
175
+ try:
176
+ # 确保缓存已加载
177
+ await self._ensure_cache_loaded()
178
+
179
+ # 性能监控
180
+ self._operation_count += 1
181
+ operation_time = time.time() - start_time
182
+ self._operation_times.append(operation_time)
183
+
184
+ log.debug(f"{self._name} cache get_all ({len(self._cache)}) in {operation_time:.3f}s")
185
+ return self._cache.copy()
186
+
187
+ except Exception as e:
188
+ operation_time = time.time() - start_time
189
+ log.error(f"Error getting all {self._name} cache in {operation_time:.3f}s: {e}")
190
+ return {}
191
+
192
+ async def update_multi(self, updates: Dict[str, Any]) -> bool:
193
+ """批量更新缓存项"""
194
+ async with self._cache_lock:
195
+ start_time = time.time()
196
+
197
+ try:
198
+ # 确保缓存已加载
199
+ await self._ensure_cache_loaded()
200
+
201
+ # 批量更新
202
+ self._cache.update(updates)
203
+ self._cache_dirty = True
204
+
205
+ # 性能监控
206
+ self._operation_count += 1
207
+ operation_time = time.time() - start_time
208
+ self._operation_times.append(operation_time)
209
+
210
+ log.debug(f"{self._name} cache update_multi ({len(updates)}) in {operation_time:.3f}s")
211
+ return True
212
+
213
+ except Exception as e:
214
+ operation_time = time.time() - start_time
215
+ log.error(f"Error updating {self._name} cache multi in {operation_time:.3f}s: {e}")
216
+ return False
217
+
218
+ async def _ensure_cache_loaded(self):
219
+ """确保缓存已从底层存储加载"""
220
+ current_time = time.time()
221
+
222
+ # 检查缓存是否需要加载(首次加载或过期)
223
+ # 如果缓存脏了(有未写入的数据),不要重新加载以避免数据丢失
224
+ if (self._last_cache_time == 0 or
225
+ (current_time - self._last_cache_time > self._cache_ttl and not self._cache_dirty)):
226
+
227
+ await self._load_cache()
228
+ self._last_cache_time = current_time
229
+
230
+ async def _load_cache(self):
231
+ """从底层存储加载缓存"""
232
+ try:
233
+ start_time = time.time()
234
+
235
+ # 从后端加载数据
236
+ data = await self._backend.load_data()
237
+
238
+ if data:
239
+ self._cache = data
240
+ log.debug(f"{self._name} cache loaded ({len(self._cache)}) from backend")
241
+ else:
242
+ # 如果后端没有数据,初始化空缓存
243
+ self._cache = {}
244
+ log.debug(f"{self._name} cache initialized empty")
245
+
246
+ operation_time = time.time() - start_time
247
+ log.debug(f"{self._name} cache loaded in {operation_time:.3f}s")
248
+
249
+ except Exception as e:
250
+ log.error(f"Error loading {self._name} cache from backend: {e}")
251
+ self._cache = {}
252
+
253
+ async def _write_loop(self):
254
+ """异步写回循环"""
255
+ while not self._shutdown_event.is_set():
256
+ try:
257
+ # 等待写入延迟或关闭信号
258
+ try:
259
+ await asyncio.wait_for(self._shutdown_event.wait(), timeout=self._write_delay)
260
+ break # 收到关闭信号
261
+ except asyncio.TimeoutError:
262
+ pass # 超时,检查是否需要写回
263
+
264
+ # 如果缓存脏了,写回底层存储
265
+ async with self._cache_lock:
266
+ if self._cache_dirty:
267
+ await self._write_cache()
268
+
269
+ except Exception as e:
270
+ log.error(f"Error in {self._name} cache writer loop: {e}")
271
+ await asyncio.sleep(1)
272
+
273
+ async def _write_cache(self):
274
+ """将缓存写回底层存储"""
275
+ if not self._cache_dirty:
276
+ return
277
+
278
+ try:
279
+ start_time = time.time()
280
+
281
+ # 写入后端
282
+ success = await self._backend.write_data(self._cache.copy())
283
+
284
+ if success:
285
+ self._cache_dirty = False
286
+ operation_time = time.time() - start_time
287
+ log.debug(f"{self._name} cache written to backend in {operation_time:.3f}s ({len(self._cache)} items)")
288
+ else:
289
+ log.error(f"Failed to write {self._name} cache to backend")
290
+
291
+ except Exception as e:
292
+ log.error(f"Error writing {self._name} cache to backend: {e}")
293
+
294
+ async def _flush_cache(self):
295
+ """立即刷新缓存到底层存储"""
296
+ async with self._cache_lock:
297
+ if self._cache_dirty:
298
+ await self._write_cache()
299
+ log.debug(f"{self._name} cache flushed to backend")
300
+
301
+ def get_stats(self) -> Dict[str, Any]:
302
+ """获取缓存统计信息"""
303
+ avg_time = sum(self._operation_times) / len(self._operation_times) if self._operation_times else 0
304
+
305
+ return {
306
+ "cache_name": self._name,
307
+ "cache_size": len(self._cache),
308
+ "cache_dirty": self._cache_dirty,
309
+ "operation_count": self._operation_count,
310
+ "avg_operation_time": avg_time,
311
+ "last_cache_time": self._last_cache_time,
312
+ }
src/storage/file_storage_manager.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 本地文件存储管理器,使用统一缓存支持队列写入优化。
3
+ 所有凭证和状态数据存储在creds.toml中,配置数据存储在config.toml中。
4
+ """
5
+ import asyncio
6
+ import os
7
+ import json
8
+ import time
9
+ from typing import Dict, Any, List, Optional
10
+
11
+ import aiofiles
12
+ import toml
13
+
14
+ from log import log
15
+ from .cache_manager import UnifiedCacheManager, CacheBackend
16
+
17
+
18
+ class FileCacheBackend(CacheBackend):
19
+ """文件缓存后端实现"""
20
+
21
+ def __init__(self, file_path: str):
22
+ self._file_path = file_path
23
+
24
+ async def load_data(self) -> Dict[str, Any]:
25
+ """从TOML文件加载数据"""
26
+ try:
27
+ if not os.path.exists(self._file_path):
28
+ return {}
29
+
30
+ async with aiofiles.open(self._file_path, "r", encoding="utf-8") as f:
31
+ content = await f.read()
32
+
33
+ if not content.strip():
34
+ return {}
35
+
36
+ return toml.loads(content)
37
+
38
+ except Exception as e:
39
+ log.error(f"Error loading data from file {self._file_path}: {e}")
40
+ return {}
41
+
42
+ async def write_data(self, data: Dict[str, Any]) -> bool:
43
+ """将数据写入TOML文件"""
44
+ try:
45
+ # 确保目录存在
46
+ os.makedirs(os.path.dirname(self._file_path), exist_ok=True)
47
+
48
+ # 写入TOML文件
49
+ toml_content = toml.dumps(data)
50
+ async with aiofiles.open(self._file_path, "w", encoding="utf-8") as f:
51
+ await f.write(toml_content)
52
+
53
+ return True
54
+
55
+ except Exception as e:
56
+ log.error(f"Error writing data to file {self._file_path}: {e}")
57
+ return False
58
+
59
+
60
+ class FileStorageManager:
61
+ """基于本地文件的存储管理器(使用统一缓存)"""
62
+
63
+ # 状态字段常量
64
+ STATE_FIELDS = {
65
+ "error_codes", "disabled", "last_success", "user_email",
66
+ "gemini_2_5_pro_calls", "total_calls", "next_reset_time",
67
+ "daily_limit_gemini_2_5_pro", "daily_limit_total"
68
+ }
69
+
70
+ # 默认状态数据模板(不包含动态值)
71
+ _DEFAULT_STATE_TEMPLATE = {
72
+ "error_codes": [],
73
+ "disabled": False,
74
+ "user_email": None,
75
+ "gemini_2_5_pro_calls": 0,
76
+ "total_calls": 0,
77
+ "next_reset_time": None,
78
+ "daily_limit_gemini_2_5_pro": 100,
79
+ "daily_limit_total": 1000
80
+ }
81
+
82
+ @classmethod
83
+ def get_default_state(cls) -> Dict[str, Any]:
84
+ """获取默认状态数据(包含当前时间戳)"""
85
+ state = cls._DEFAULT_STATE_TEMPLATE.copy()
86
+ state["last_success"] = time.time()
87
+ return state
88
+
89
+ def __init__(self):
90
+ self._credentials_dir = None # 将通过异步初始化设置
91
+ self._state_file = None
92
+ self._config_file = None
93
+ self._lock = asyncio.Lock()
94
+ self._initialized = False
95
+
96
+ # 统一缓存管理器
97
+ self._credentials_cache_manager: Optional[UnifiedCacheManager] = None
98
+ self._config_cache_manager: Optional[UnifiedCacheManager] = None
99
+
100
+ # 配置参数
101
+ self._write_delay = 0.5 # 写入延迟(秒)
102
+ self._cache_ttl = 300 # 缓存TTL(秒)
103
+
104
+ async def initialize(self) -> None:
105
+ """初始化文件存储"""
106
+ if self._initialized:
107
+ return
108
+
109
+ # 获取凭证目录配置(初始化时直接使用环境变量,避免循环依赖)
110
+ self._credentials_dir = os.getenv("CREDENTIALS_DIR", "./creds")
111
+ self._state_file = os.path.join(self._credentials_dir, "creds.toml")
112
+ self._config_file = os.path.join(self._credentials_dir, "config.toml")
113
+
114
+ # 确保目录存在
115
+ os.makedirs(self._credentials_dir, exist_ok=True)
116
+
117
+ # 执行JSON到TOML的迁移
118
+ await self._migrate_json_to_toml()
119
+
120
+ # 创建缓存管理器
121
+ credentials_backend = FileCacheBackend(self._state_file)
122
+ config_backend = FileCacheBackend(self._config_file)
123
+
124
+ self._credentials_cache_manager = UnifiedCacheManager(
125
+ credentials_backend,
126
+ cache_ttl=self._cache_ttl,
127
+ write_delay=self._write_delay,
128
+ name="credentials"
129
+ )
130
+
131
+ self._config_cache_manager = UnifiedCacheManager(
132
+ config_backend,
133
+ cache_ttl=self._cache_ttl,
134
+ write_delay=self._write_delay,
135
+ name="config"
136
+ )
137
+
138
+ # 启动缓存管理器
139
+ await self._credentials_cache_manager.start()
140
+ await self._config_cache_manager.start()
141
+
142
+ self._initialized = True
143
+ log.debug("File storage manager initialized with unified cache")
144
+
145
+ async def close(self) -> None:
146
+ """关闭文件存储"""
147
+ # 停止缓存管理器
148
+ if self._credentials_cache_manager:
149
+ await self._credentials_cache_manager.stop()
150
+ if self._config_cache_manager:
151
+ await self._config_cache_manager.stop()
152
+
153
+ self._initialized = False
154
+ log.debug("File storage manager closed with unified cache flushed")
155
+
156
+ def _normalize_filename(self, filename: str) -> str:
157
+ """标准化文件名"""
158
+ return os.path.basename(filename)
159
+
160
+ def _ensure_initialized(self):
161
+ """确保已初始化"""
162
+ if not self._initialized:
163
+ raise RuntimeError("File storage manager not initialized")
164
+
165
+ async def _migrate_json_to_toml(self) -> None:
166
+ """将现有的JSON凭证文件和旧的creds_state.toml迁移到新的creds.toml文件中"""
167
+ try:
168
+ # 扫描JSON凭证文件
169
+ json_files = []
170
+ if os.path.exists(self._credentials_dir):
171
+ for filename in os.listdir(self._credentials_dir):
172
+ if filename.endswith(".json"):
173
+ json_files.append(filename)
174
+
175
+ # 检查旧的creds_state.toml文件
176
+ old_state_file = os.path.join(self._credentials_dir, "creds_state.toml")
177
+ has_old_state = os.path.exists(old_state_file)
178
+
179
+ if not json_files and not has_old_state:
180
+ log.debug("No JSON credential files or old state file found for migration")
181
+ return
182
+
183
+ # 加载现有TOML数据(如果存在)
184
+ toml_data = {}
185
+ if os.path.exists(self._state_file):
186
+ try:
187
+ async with aiofiles.open(self._state_file, "r", encoding="utf-8") as f:
188
+ content = await f.read()
189
+ if content.strip():
190
+ toml_data = toml.loads(content)
191
+ except Exception as e:
192
+ log.error(f"Failed to load existing TOML file: {e}")
193
+
194
+ # 加载旧的creds_state.toml文件(稍后处理)
195
+ old_state_data = {}
196
+ if has_old_state:
197
+ try:
198
+ async with aiofiles.open(old_state_file, "r", encoding="utf-8") as f:
199
+ content = await f.read()
200
+ old_state_data = toml.loads(content)
201
+ log.debug("Loaded old state file for potential migration")
202
+ except Exception as e:
203
+ log.error(f"Failed to load old state file: {e}")
204
+ old_state_data = {}
205
+
206
+ if json_files:
207
+ log.info(f"Migrating {len(json_files)} JSON credential files to TOML")
208
+
209
+ # 处理每个JSON文件
210
+ migrated_count = 0
211
+ for filename in json_files:
212
+ try:
213
+ filepath = os.path.join(self._credentials_dir, filename)
214
+
215
+ # 读取JSON凭证数据
216
+ async with aiofiles.open(filepath, "r", encoding="utf-8") as f:
217
+ json_content = await f.read()
218
+ credential_data = json.loads(json_content)
219
+
220
+ # 创建新的section:凭证数据 + 状态数据
221
+ section_data = credential_data.copy()
222
+
223
+ # 首先添加默认状态数据
224
+ section_data.update(self.get_default_state())
225
+
226
+ # 如果旧状态文件中有该凭证的状态数据,则使用旧状态数据覆盖默认值
227
+ if filename in old_state_data and isinstance(old_state_data[filename], dict):
228
+ log.debug(f"Using old state data for: {filename}")
229
+ section_data.update(old_state_data[filename])
230
+
231
+ # 如果当前TOML中已存在该凭证,保留其状态数据
232
+ if filename in toml_data and isinstance(toml_data[filename], dict):
233
+ log.debug(f"Merging with existing TOML state for: {filename}")
234
+ existing_state = toml_data[filename]
235
+ section_data.update(existing_state)
236
+
237
+ # 最后确保凭证数据是最新的(覆盖任何冲突的字段)
238
+ section_data.update(credential_data)
239
+
240
+ toml_data[filename] = section_data
241
+
242
+ migrated_count += 1
243
+ log.debug(f"Migrated credential: {filename}")
244
+
245
+ except Exception as e:
246
+ log.error(f"Failed to migrate {filename}: {e}")
247
+ continue
248
+
249
+ # 保存TOML文件(如果有新的迁移)
250
+ if migrated_count > 0:
251
+ try:
252
+ toml_content = toml.dumps(toml_data)
253
+ async with aiofiles.open(self._state_file, "w", encoding="utf-8") as f:
254
+ await f.write(toml_content)
255
+
256
+ # 删除已迁移的JSON文件
257
+ for filename in json_files:
258
+ try:
259
+ if filename in toml_data: # 确保文件确实被迁移了
260
+ filepath = os.path.join(self._credentials_dir, filename)
261
+ os.remove(filepath)
262
+ log.debug(f"Removed migrated JSON file: {filename}")
263
+ except Exception as e:
264
+ log.warning(f"Failed to remove {filename}: {e}")
265
+
266
+ # 删除旧的状态文件(如果存在)
267
+ if has_old_state:
268
+ try:
269
+ os.remove(old_state_file)
270
+ log.debug("Removed old state file: creds_state.toml")
271
+ except Exception as e:
272
+ log.warning(f"Failed to remove old state file: {e}")
273
+
274
+ log.info(f"Migration completed: {migrated_count} files migrated to TOML format")
275
+
276
+ except Exception as e:
277
+ log.error(f"Failed to save migrated TOML file: {e}")
278
+
279
+ except Exception as e:
280
+ log.error(f"Error during JSON to TOML migration: {e}")
281
+
282
+ # ============ 凭证管理 ============
283
+
284
+ async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
285
+ """存储凭证数据到统一缓存"""
286
+ self._ensure_initialized()
287
+
288
+ try:
289
+ filename = self._normalize_filename(filename)
290
+
291
+ # 获取现有数据或创建新数据
292
+ all_data = await self._credentials_cache_manager.get_all()
293
+ existing_state = all_data.get(filename, {})
294
+
295
+ # 创建新的section数据:凭证数据 + 状态数据
296
+ final_data = self.get_default_state()
297
+ final_data.update(existing_state)
298
+ final_data.update(credential_data) # 凭证数据覆盖状态数据中的同名字段
299
+
300
+ # 更新整个数据集
301
+ all_data[filename] = final_data
302
+
303
+ success = await self._credentials_cache_manager.update_multi({filename: final_data})
304
+ log.debug(f"Stored credential to unified cache: {filename}")
305
+ return success
306
+
307
+ except Exception as e:
308
+ log.error(f"Error storing credential {filename}: {e}")
309
+ return False
310
+
311
+ async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
312
+ """从统一缓存获取凭证数据"""
313
+ self._ensure_initialized()
314
+
315
+ try:
316
+ filename = self._normalize_filename(filename)
317
+ all_data = await self._credentials_cache_manager.get_all()
318
+
319
+ if filename not in all_data:
320
+ return None
321
+
322
+ section_data = all_data[filename]
323
+
324
+ # 提取凭证数据(排除状态字段)
325
+ credential_data = {k: v for k, v in section_data.items() if k not in self.STATE_FIELDS}
326
+ return credential_data
327
+
328
+ except Exception as e:
329
+ log.error(f"Error getting credential {filename}: {e}")
330
+ return None
331
+
332
+ async def list_credentials(self) -> List[str]:
333
+ """从统一缓存列出所有凭证文件名"""
334
+ self._ensure_initialized()
335
+
336
+ try:
337
+ all_data = await self._credentials_cache_manager.get_all()
338
+ return list(all_data.keys())
339
+
340
+ except Exception as e:
341
+ log.error(f"Error listing credentials: {e}")
342
+ return []
343
+
344
+ async def delete_credential(self, filename: str) -> bool:
345
+ """从统一缓存删除凭证"""
346
+ self._ensure_initialized()
347
+
348
+ try:
349
+ filename = self._normalize_filename(filename)
350
+ success = await self._credentials_cache_manager.delete(filename)
351
+ log.debug(f"Deleted credential from unified cache: {filename}")
352
+ return success
353
+
354
+ except Exception as e:
355
+ log.error(f"Error deleting credential {filename}: {e}")
356
+ return False
357
+
358
+ # ============ 状态管理 ============
359
+
360
+ async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
361
+ """更新凭证状态"""
362
+ self._ensure_initialized()
363
+
364
+ try:
365
+ filename = self._normalize_filename(filename)
366
+ all_data = await self._credentials_cache_manager.get_all()
367
+
368
+ if filename not in all_data:
369
+ all_data[filename] = self.get_default_state()
370
+
371
+ # 更新状态
372
+ all_data[filename].update(state_updates)
373
+
374
+ success = await self._credentials_cache_manager.update_multi({filename: all_data[filename]})
375
+ log.debug(f"Updated credential state in unified cache: {filename}")
376
+ return success
377
+
378
+ except Exception as e:
379
+ log.error(f"Error updating credential state {filename}: {e}")
380
+ return False
381
+
382
+ async def get_credential_state(self, filename: str) -> Dict[str, Any]:
383
+ """从统一缓存获取凭证状态"""
384
+ self._ensure_initialized()
385
+
386
+ try:
387
+ filename = self._normalize_filename(filename)
388
+ all_data = await self._credentials_cache_manager.get_all()
389
+
390
+ if filename not in all_data:
391
+ # 返回基本的状态字段
392
+ default_state = self.get_default_state()
393
+ return {k: v for k, v in default_state.items() if k in {"error_codes", "disabled", "last_success", "user_email"}}
394
+
395
+ section_data = all_data[filename]
396
+
397
+ # 提取状态字段
398
+ state_data = {k: v for k, v in section_data.items() if k in self.STATE_FIELDS}
399
+
400
+ # 确保必要字段存在
401
+ basic_fields = {"error_codes", "disabled", "last_success", "user_email"}
402
+ default_state = self.get_default_state()
403
+
404
+ for field in basic_fields:
405
+ if field not in state_data:
406
+ state_data[field] = default_state[field]
407
+
408
+ return state_data
409
+
410
+ except Exception as e:
411
+ log.error(f"Error getting credential state {filename}: {e}")
412
+ return self.get_default_state()
413
+
414
+ async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
415
+ """从统一缓存获取所有凭证状态"""
416
+ self._ensure_initialized()
417
+
418
+ try:
419
+ all_data = await self._credentials_cache_manager.get_all()
420
+
421
+ states = {}
422
+ for filename, section_data in all_data.items():
423
+ # 提取状态字段
424
+ state_data = {k: v for k, v in section_data.items() if k in self.STATE_FIELDS}
425
+
426
+ # 确保必要字段存在
427
+ basic_fields = {"error_codes", "disabled", "last_success", "user_email"}
428
+ default_state = self.get_default_state()
429
+
430
+ for field in basic_fields:
431
+ if field not in state_data:
432
+ state_data[field] = default_state[field]
433
+
434
+ states[filename] = state_data
435
+
436
+ return states
437
+
438
+ except Exception as e:
439
+ log.error(f"Error getting all credential states: {e}")
440
+ return {}
441
+
442
+ # ============ 配置管理 ============
443
+
444
+ async def set_config(self, key: str, value: Any) -> bool:
445
+ """设置配置到统一缓存"""
446
+ self._ensure_initialized()
447
+ return await self._config_cache_manager.set(key, value)
448
+
449
+ async def get_config(self, key: str, default: Any = None) -> Any:
450
+ """从统一缓存获取配置"""
451
+ self._ensure_initialized()
452
+ return await self._config_cache_manager.get(key, default)
453
+
454
+ async def get_all_config(self) -> Dict[str, Any]:
455
+ """从统一缓存获取所有配置"""
456
+ self._ensure_initialized()
457
+ return await self._config_cache_manager.get_all()
458
+
459
+ async def delete_config(self, key: str) -> bool:
460
+ """从统一缓存删除配置"""
461
+ self._ensure_initialized()
462
+ return await self._config_cache_manager.delete(key)
463
+
464
+ # ============ 使用统计管理 ============
465
+
466
+ async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
467
+ """更新使用统计"""
468
+ self._ensure_initialized()
469
+
470
+ try:
471
+ filename = self._normalize_filename(filename)
472
+ all_data = await self._credentials_cache_manager.get_all()
473
+
474
+ if filename not in all_data:
475
+ all_data[filename] = self.get_default_state()
476
+
477
+ # 更新统计数据
478
+ all_data[filename].update(stats_updates)
479
+
480
+ success = await self._credentials_cache_manager.update_multi({filename: all_data[filename]})
481
+ log.debug(f"Updated usage stats in unified cache: {filename}")
482
+ return success
483
+
484
+ except Exception as e:
485
+ log.error(f"Error updating usage stats {filename}: {e}")
486
+ return False
487
+
488
+ async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
489
+ """从统一缓存获取使用统计"""
490
+ self._ensure_initialized()
491
+
492
+ try:
493
+ filename = self._normalize_filename(filename)
494
+ all_data = await self._credentials_cache_manager.get_all()
495
+
496
+ if filename not in all_data:
497
+ # 返回基本的统计��段
498
+ default_state = self.get_default_state()
499
+ return {k: v for k, v in default_state.items() if k in {"gemini_2_5_pro_calls", "total_calls", "next_reset_time", "daily_limit_gemini_2_5_pro", "daily_limit_total"}}
500
+
501
+ section_data = all_data[filename]
502
+
503
+ # 提取统计字段
504
+ stats_fields = {"gemini_2_5_pro_calls", "total_calls", "next_reset_time", "daily_limit_gemini_2_5_pro", "daily_limit_total"}
505
+ stats_data = {k: v for k, v in section_data.items() if k in stats_fields}
506
+
507
+ # 确保必要字段存在
508
+ default_state = self.get_default_state()
509
+ for field in stats_fields:
510
+ if field not in stats_data:
511
+ stats_data[field] = default_state[field]
512
+
513
+ return stats_data
514
+
515
+ except Exception as e:
516
+ log.error(f"Error getting usage stats {filename}: {e}")
517
+ return self.get_default_state()
518
+
519
+ async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
520
+ """从统一缓存获取所有使用统计"""
521
+ self._ensure_initialized()
522
+
523
+ try:
524
+ all_data = await self._credentials_cache_manager.get_all()
525
+
526
+ stats = {}
527
+ stats_fields = {"gemini_2_5_pro_calls", "total_calls", "next_reset_time", "daily_limit_gemini_2_5_pro", "daily_limit_total"}
528
+
529
+ for filename, section_data in all_data.items():
530
+ # 提取统计字段
531
+ stats_data = {k: v for k, v in section_data.items() if k in stats_fields}
532
+
533
+ # 确保必要字段存在
534
+ default_state = self.get_default_state()
535
+ for field in stats_fields:
536
+ if field not in stats_data:
537
+ stats_data[field] = default_state[field]
538
+
539
+ stats[filename] = stats_data
540
+
541
+ return stats
542
+
543
+ except Exception as e:
544
+ log.error(f"Error getting all usage stats: {e}")
545
+ return {}
546
+
547
+ # ============ 工具方法 ============
548
+
549
+ async def export_credential_to_json(self, filename: str, output_path: str = None) -> bool:
550
+ """将TOML中的凭证导出为JSON文件(用于兼容性和备份)"""
551
+ self._ensure_initialized()
552
+
553
+ try:
554
+ filename = self._normalize_filename(filename)
555
+ credential_data = await self.get_credential(filename)
556
+
557
+ if credential_data is None:
558
+ log.warning(f"Credential not found for export: {filename}")
559
+ return False
560
+
561
+ if output_path is None:
562
+ output_path = os.path.join(self._credentials_dir, f"{filename}.json")
563
+
564
+ # 写入JSON文件
565
+ json_content = json.dumps(credential_data, indent=2, ensure_ascii=False)
566
+ async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
567
+ await f.write(json_content)
568
+
569
+ log.info(f"Credential exported to JSON: {output_path}")
570
+ return True
571
+
572
+ except Exception as e:
573
+ log.error(f"Error exporting credential {filename} to JSON: {e}")
574
+ return False
575
+
576
+ async def import_credential_from_json(self, json_path: str, filename: str = None) -> bool:
577
+ """从JSON文件导入凭证到TOML"""
578
+ self._ensure_initialized()
579
+
580
+ try:
581
+ if not os.path.exists(json_path):
582
+ log.error(f"JSON file not found: {json_path}")
583
+ return False
584
+
585
+ # 读取JSON文件
586
+ async with aiofiles.open(json_path, "r", encoding="utf-8") as f:
587
+ json_content = await f.read()
588
+
589
+ credential_data = json.loads(json_content)
590
+
591
+ if filename is None:
592
+ filename = os.path.basename(json_path)
593
+
594
+ filename = self._normalize_filename(filename)
595
+
596
+ # 存储凭证
597
+ success = await self.store_credential(filename, credential_data)
598
+
599
+ if success:
600
+ log.info(f"Credential imported from JSON: {json_path} -> {filename}")
601
+
602
+ return success
603
+
604
+ except Exception as e:
605
+ log.error(f"Error importing credential from JSON {json_path}: {e}")
606
+ return False
src/storage/mongodb_manager.py ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MongoDB数据库管理器,使用单文档设计和统一缓存。
3
+ 所有凭证数据存储在一个文档中,配置数据存储在另一个文档中,类似TOML文件结构。
4
+ """
5
+ import asyncio
6
+ import os
7
+ import time
8
+ from datetime import datetime, timezone
9
+ from typing import Dict, Any, List, Optional
10
+ from collections import deque
11
+
12
+ import motor.motor_asyncio
13
+ from log import log
14
+ from .cache_manager import UnifiedCacheManager, CacheBackend
15
+
16
+
17
+ class MongoDBCacheBackend(CacheBackend):
18
+ """MongoDB缓存后端实现"""
19
+
20
+ def __init__(self, db, collection_name: str, doc_key: str):
21
+ self._db = db
22
+ self._collection_name = collection_name
23
+ self._doc_key = doc_key
24
+
25
+ async def load_data(self) -> Dict[str, Any]:
26
+ """从MongoDB文档加载数据"""
27
+ try:
28
+ collection = self._db[self._collection_name]
29
+ doc = await collection.find_one({"key": self._doc_key})
30
+
31
+ if doc and "data" in doc:
32
+ return doc["data"]
33
+ return {}
34
+
35
+ except Exception as e:
36
+ log.error(f"Error loading data from MongoDB document {self._doc_key}: {e}")
37
+ return {}
38
+
39
+ async def write_data(self, data: Dict[str, Any]) -> bool:
40
+ """将数据写入MongoDB文档"""
41
+ try:
42
+ collection = self._db[self._collection_name]
43
+
44
+ doc = {
45
+ "key": self._doc_key,
46
+ "data": data,
47
+ "updated_at": datetime.now(timezone.utc)
48
+ }
49
+
50
+ await collection.replace_one(
51
+ {"key": self._doc_key},
52
+ doc,
53
+ upsert=True
54
+ )
55
+ return True
56
+
57
+ except Exception as e:
58
+ log.error(f"Error writing data to MongoDB document {self._doc_key}: {e}")
59
+ return False
60
+
61
+
62
+ class MongoDBManager:
63
+ """MongoDB数据库管理器"""
64
+
65
+ def __init__(self):
66
+ self._client: Optional[motor.motor_asyncio.AsyncIOMotorClient] = None
67
+ self._db: Optional[motor.motor_asyncio.AsyncIOMotorDatabase] = None
68
+ self._initialized = False
69
+ self._lock = asyncio.Lock()
70
+
71
+ # 配置
72
+ self._connection_uri = None
73
+ self._database_name = None
74
+
75
+ # 单文档设计 - 所有凭证存在一个文档中(类似TOML文件)
76
+ self._collection_name = "credentials_data"
77
+
78
+ # 性能监控
79
+ self._operation_count = 0
80
+ self._operation_times = deque(maxlen=5000)
81
+
82
+ # 统一缓存管理器
83
+ self._credentials_cache_manager: Optional[UnifiedCacheManager] = None
84
+ self._config_cache_manager: Optional[UnifiedCacheManager] = None
85
+
86
+ # 文档key定义
87
+ self._credentials_doc_key = "all_credentials"
88
+ self._config_doc_key = "config_data"
89
+
90
+ # 写入配置参数
91
+ self._write_delay = 1.0 # 写入延迟(秒)
92
+ self._cache_ttl = 300 # 缓存TTL(秒)
93
+
94
+ async def initialize(self):
95
+ """初始化MongoDB连接"""
96
+ async with self._lock:
97
+ if self._initialized:
98
+ return
99
+
100
+ try:
101
+ # 获取连接配置
102
+ self._connection_uri = os.getenv("MONGODB_URI")
103
+ self._database_name = os.getenv("MONGODB_DATABASE", "gcli2api")
104
+
105
+ if not self._connection_uri:
106
+ raise ValueError("MONGODB_URI environment variable is required")
107
+
108
+ # 建立连接
109
+ self._client = motor.motor_asyncio.AsyncIOMotorClient(
110
+ self._connection_uri,
111
+ serverSelectionTimeoutMS=5000,
112
+ maxPoolSize=100,
113
+ minPoolSize=10,
114
+ maxIdleTimeMS=45000,
115
+ waitQueueTimeoutMS=10000,
116
+ )
117
+
118
+ # 验证连接
119
+ await self._client.admin.command('ping')
120
+
121
+ # 获取数据库
122
+ self._db = self._client[self._database_name]
123
+
124
+ # 创建索引
125
+ await self._create_indexes()
126
+
127
+ # 创建缓存管理器
128
+ credentials_backend = MongoDBCacheBackend(self._db, self._collection_name, self._credentials_doc_key)
129
+ config_backend = MongoDBCacheBackend(self._db, self._collection_name, self._config_doc_key)
130
+
131
+ self._credentials_cache_manager = UnifiedCacheManager(
132
+ credentials_backend,
133
+ cache_ttl=self._cache_ttl,
134
+ write_delay=self._write_delay,
135
+ name="credentials"
136
+ )
137
+
138
+ self._config_cache_manager = UnifiedCacheManager(
139
+ config_backend,
140
+ cache_ttl=self._cache_ttl,
141
+ write_delay=self._write_delay,
142
+ name="config"
143
+ )
144
+
145
+ # 启动缓存管理器
146
+ await self._credentials_cache_manager.start()
147
+ await self._config_cache_manager.start()
148
+
149
+ self._initialized = True
150
+ log.info(f"MongoDB connection established to {self._database_name} with unified cache")
151
+
152
+ except Exception as e:
153
+ log.error(f"Error initializing MongoDB: {e}")
154
+ raise
155
+
156
+ async def _create_indexes(self):
157
+ """创建简单索引(单文档设计)"""
158
+ try:
159
+ # 单文档设计只需要主键索引
160
+ await self._db[self._collection_name].create_index("key", unique=True)
161
+ await self._db[self._collection_name].create_index("updated_at")
162
+
163
+ log.info("MongoDB indexes created for single-document design")
164
+
165
+ except Exception as e:
166
+ log.error(f"Error creating MongoDB indexes: {e}")
167
+
168
+ async def close(self):
169
+ """关闭MongoDB连接"""
170
+ # 停止缓存管理器
171
+ if self._credentials_cache_manager:
172
+ await self._credentials_cache_manager.stop()
173
+ if self._config_cache_manager:
174
+ await self._config_cache_manager.stop()
175
+
176
+ if self._client:
177
+ self._client.close()
178
+ self._initialized = False
179
+ log.info("MongoDB connection closed with unified cache flushed")
180
+
181
+ def _ensure_initialized(self):
182
+ """确保已初始化"""
183
+ if not self._initialized:
184
+ raise RuntimeError("MongoDB manager not initialized")
185
+
186
+ def _get_default_state(self) -> Dict[str, Any]:
187
+ """获取默认状态数据"""
188
+ return {
189
+ "error_codes": [],
190
+ "disabled": False,
191
+ "last_success": time.time(),
192
+ "user_email": None,
193
+ }
194
+
195
+ def _get_default_stats(self) -> Dict[str, Any]:
196
+ """获取默认统计数据"""
197
+ return {
198
+ "gemini_2_5_pro_calls": 0,
199
+ "total_calls": 0,
200
+ "next_reset_time": None,
201
+ "daily_limit_gemini_2_5_pro": 100,
202
+ "daily_limit_total": 1000
203
+ }
204
+
205
+ # ============ 凭证管理 ============
206
+
207
+ async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
208
+ """存储凭证数据到统一缓存"""
209
+ self._ensure_initialized()
210
+ start_time = time.time()
211
+
212
+ try:
213
+ # 获取现有数据或创建新数据
214
+ existing_data = await self._credentials_cache_manager.get(filename, {})
215
+
216
+ credential_entry = {
217
+ "credential": credential_data,
218
+ "state": existing_data.get("state", self._get_default_state()),
219
+ "stats": existing_data.get("stats", self._get_default_stats())
220
+ }
221
+
222
+ success = await self._credentials_cache_manager.set(filename, credential_entry)
223
+
224
+ # 性能监控
225
+ self._operation_count += 1
226
+ operation_time = time.time() - start_time
227
+ self._operation_times.append(operation_time)
228
+
229
+ log.debug(f"Stored credential to unified cache: {filename} in {operation_time:.3f}s")
230
+ return success
231
+
232
+ except Exception as e:
233
+ operation_time = time.time() - start_time
234
+ log.error(f"Error storing credential {filename} in {operation_time:.3f}s: {e}")
235
+ return False
236
+
237
+ async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
238
+ """从统一缓存获取凭证数据"""
239
+ self._ensure_initialized()
240
+ start_time = time.time()
241
+
242
+ try:
243
+ credential_entry = await self._credentials_cache_manager.get(filename)
244
+
245
+ # 性能监控
246
+ self._operation_count += 1
247
+ operation_time = time.time() - start_time
248
+ self._operation_times.append(operation_time)
249
+
250
+ if credential_entry and "credential" in credential_entry:
251
+ return credential_entry["credential"]
252
+ return None
253
+
254
+ except Exception as e:
255
+ operation_time = time.time() - start_time
256
+ log.error(f"Error retrieving credential {filename} in {operation_time:.3f}s: {e}")
257
+ return None
258
+
259
+ async def list_credentials(self) -> List[str]:
260
+ """从统一缓存列出所有凭证文件名"""
261
+ self._ensure_initialized()
262
+ start_time = time.time()
263
+
264
+ try:
265
+ all_data = await self._credentials_cache_manager.get_all()
266
+ filenames = list(all_data.keys())
267
+
268
+ # 性能监控
269
+ self._operation_count += 1
270
+ operation_time = time.time() - start_time
271
+ self._operation_times.append(operation_time)
272
+
273
+ log.debug(f"Listed {len(filenames)} credentials from unified cache in {operation_time:.3f}s")
274
+ return filenames
275
+
276
+ except Exception as e:
277
+ operation_time = time.time() - start_time
278
+ log.error(f"Error listing credentials in {operation_time:.3f}s: {e}")
279
+ return []
280
+
281
+ async def delete_credential(self, filename: str) -> bool:
282
+ """从统一缓存删除凭证及所有相关数据"""
283
+ self._ensure_initialized()
284
+ start_time = time.time()
285
+
286
+ try:
287
+ success = await self._credentials_cache_manager.delete(filename)
288
+
289
+ # 性能监控
290
+ self._operation_count += 1
291
+ operation_time = time.time() - start_time
292
+ self._operation_times.append(operation_time)
293
+
294
+ log.debug(f"Deleted credential from unified cache: {filename} in {operation_time:.3f}s")
295
+ return success
296
+
297
+ except Exception as e:
298
+ operation_time = time.time() - start_time
299
+ log.error(f"Error deleting credential {filename} in {operation_time:.3f}s: {e}")
300
+ return False
301
+
302
+ # ============ 状态管理 ============
303
+
304
+ async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
305
+ """更新凭证状态(使用统一缓存)"""
306
+ self._ensure_initialized()
307
+ start_time = time.time()
308
+
309
+ try:
310
+ # 获取现有数据或创建新数据
311
+ existing_data = await self._credentials_cache_manager.get(filename, {})
312
+
313
+ if not existing_data:
314
+ existing_data = {
315
+ "credential": {},
316
+ "state": self._get_default_state(),
317
+ "stats": self._get_default_stats()
318
+ }
319
+
320
+ # 更新状态数据
321
+ existing_data["state"].update(state_updates)
322
+
323
+ success = await self._credentials_cache_manager.set(filename, existing_data)
324
+
325
+ # 性能监控
326
+ self._operation_count += 1
327
+ operation_time = time.time() - start_time
328
+ self._operation_times.append(operation_time)
329
+
330
+ log.debug(f"Updated credential state in unified cache: {filename} in {operation_time:.3f}s")
331
+ return success
332
+
333
+ except Exception as e:
334
+ operation_time = time.time() - start_time
335
+ log.error(f"Error updating credential state {filename} in {operation_time:.3f}s: {e}")
336
+ return False
337
+
338
+ async def get_credential_state(self, filename: str) -> Dict[str, Any]:
339
+ """从统一缓存获取凭证状态"""
340
+ self._ensure_initialized()
341
+ start_time = time.time()
342
+
343
+ try:
344
+ credential_entry = await self._credentials_cache_manager.get(filename)
345
+
346
+ # 性能监控
347
+ self._operation_count += 1
348
+ operation_time = time.time() - start_time
349
+ self._operation_times.append(operation_time)
350
+
351
+ if credential_entry and "state" in credential_entry:
352
+ log.debug(f"Retrieved credential state from unified cache: {filename} in {operation_time:.3f}s")
353
+ return credential_entry["state"]
354
+ else:
355
+ # 返回默认状态
356
+ return self._get_default_state()
357
+
358
+ except Exception as e:
359
+ operation_time = time.time() - start_time
360
+ log.error(f"Error getting credential state {filename} in {operation_time:.3f}s: {e}")
361
+ return self._get_default_state()
362
+
363
+ async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
364
+ """从统一缓存获取所有凭证状态"""
365
+ self._ensure_initialized()
366
+ start_time = time.time()
367
+
368
+ try:
369
+ all_data = await self._credentials_cache_manager.get_all()
370
+
371
+ states = {}
372
+ for filename, cred_data in all_data.items():
373
+ states[filename] = cred_data.get("state", self._get_default_state())
374
+
375
+ # 性能监控
376
+ self._operation_count += 1
377
+ operation_time = time.time() - start_time
378
+ self._operation_times.append(operation_time)
379
+
380
+ log.debug(f"Retrieved all credential states from unified cache ({len(states)}) in {operation_time:.3f}s")
381
+ return states
382
+
383
+ except Exception as e:
384
+ operation_time = time.time() - start_time
385
+ log.error(f"Error getting all credential states in {operation_time:.3f}s: {e}")
386
+ return {}
387
+
388
+ # ============ 配置管理 ============
389
+
390
+ async def set_config(self, key: str, value: Any) -> bool:
391
+ """设置配置到统一缓存"""
392
+ self._ensure_initialized()
393
+ return await self._config_cache_manager.set(key, value)
394
+
395
+ async def get_config(self, key: str, default: Any = None) -> Any:
396
+ """从统一缓存获取配置"""
397
+ self._ensure_initialized()
398
+ return await self._config_cache_manager.get(key, default)
399
+
400
+ async def get_all_config(self) -> Dict[str, Any]:
401
+ """从统一缓存获取所有配置"""
402
+ self._ensure_initialized()
403
+ return await self._config_cache_manager.get_all()
404
+
405
+ async def delete_config(self, key: str) -> bool:
406
+ """从统一缓存删除配置"""
407
+ self._ensure_initialized()
408
+ return await self._config_cache_manager.delete(key)
409
+
410
+ # ============ 使用统计管理 ============
411
+
412
+ async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
413
+ """更新使用统计(使用统一缓存)"""
414
+ self._ensure_initialized()
415
+ start_time = time.time()
416
+
417
+ try:
418
+ # 获取现有数据或创建新数据
419
+ existing_data = await self._credentials_cache_manager.get(filename, {})
420
+
421
+ if not existing_data:
422
+ existing_data = {
423
+ "credential": {},
424
+ "state": self._get_default_state(),
425
+ "stats": self._get_default_stats()
426
+ }
427
+
428
+ # 更新统计数据
429
+ existing_data["stats"].update(stats_updates)
430
+
431
+ success = await self._credentials_cache_manager.set(filename, existing_data)
432
+
433
+ # 性能监控
434
+ self._operation_count += 1
435
+ operation_time = time.time() - start_time
436
+ self._operation_times.append(operation_time)
437
+
438
+ log.debug(f"Updated usage stats in unified cache: {filename} in {operation_time:.3f}s")
439
+ return success
440
+
441
+ except Exception as e:
442
+ operation_time = time.time() - start_time
443
+ log.error(f"Error updating usage stats {filename} in {operation_time:.3f}s: {e}")
444
+ return False
445
+
446
+ async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
447
+ """从统一缓存获取使用统计"""
448
+ self._ensure_initialized()
449
+ start_time = time.time()
450
+
451
+ try:
452
+ credential_entry = await self._credentials_cache_manager.get(filename)
453
+
454
+ # 性能监控
455
+ self._operation_count += 1
456
+ operation_time = time.time() - start_time
457
+ self._operation_times.append(operation_time)
458
+
459
+ if credential_entry and "stats" in credential_entry:
460
+ log.debug(f"Retrieved usage stats from unified cache: {filename} in {operation_time:.3f}s")
461
+ return credential_entry["stats"]
462
+ else:
463
+ return self._get_default_stats()
464
+
465
+ except Exception as e:
466
+ operation_time = time.time() - start_time
467
+ log.error(f"Error getting usage stats {filename} in {operation_time:.3f}s: {e}")
468
+ return self._get_default_stats()
469
+
470
+ async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
471
+ """从统一缓存获取所有使用统计"""
472
+ self._ensure_initialized()
473
+ start_time = time.time()
474
+
475
+ try:
476
+ all_data = await self._credentials_cache_manager.get_all()
477
+
478
+ stats = {}
479
+ for filename, cred_data in all_data.items():
480
+ if "stats" in cred_data:
481
+ stats[filename] = cred_data["stats"]
482
+
483
+ # 性能监控
484
+ self._operation_count += 1
485
+ operation_time = time.time() - start_time
486
+ self._operation_times.append(operation_time)
487
+
488
+ log.debug(f"Retrieved all usage stats from unified cache ({len(stats)}) in {operation_time:.3f}s")
489
+ return stats
490
+
491
+ except Exception as e:
492
+ operation_time = time.time() - start_time
493
+ log.error(f"Error getting all usage stats in {operation_time:.3f}s: {e}")
494
+ return {}
src/storage/postgres_manager.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Postgres数据库管理器,采用单行设计并兼容 UnifiedCacheManager。
3
+ 实现与 mongodb_manager.py 风格一致的接口(异步)。
4
+ 需要环境变量: POSTGRES_DSN (例如: postgresql://user:pass@host:port/dbname)
5
+ """
6
+ import asyncio
7
+ import os
8
+ import time
9
+ import json
10
+ from datetime import datetime, timezone
11
+ from typing import Dict, Any, List, Optional
12
+ from collections import deque
13
+
14
+ import asyncpg
15
+ from log import log
16
+ from .cache_manager import UnifiedCacheManager, CacheBackend
17
+
18
+
19
+ class PostgresCacheBackend(CacheBackend):
20
+ """Postgres缓存后端,数据存储为key, data(JSONB), updated_at
21
+ 单行/单表设计:表名由管理器指定,每行以key区分。
22
+ """
23
+
24
+ def __init__(self, conn_pool, table_name: str, row_key: str):
25
+ self._pool = conn_pool
26
+ self._table_name = table_name
27
+ self._row_key = row_key
28
+
29
+ async def load_data(self) -> Dict[str, Any]:
30
+ try:
31
+ async with self._pool.acquire() as conn:
32
+ row = await conn.fetchrow(
33
+ f"SELECT data FROM {self._table_name} WHERE key = $1",
34
+ self._row_key
35
+ )
36
+ if row and row.get('data') is not None:
37
+ data = row['data']
38
+ # JSONB字段返回JSON字符串,需要解析为字典
39
+ if isinstance(data, str):
40
+ return json.loads(data)
41
+ elif isinstance(data, dict):
42
+ return data
43
+ else:
44
+ log.warning(f"Unexpected data type from JSONB field: {type(data)}")
45
+ return {}
46
+ return {}
47
+ except Exception as e:
48
+ log.error(f"Error loading data from Postgres row {self._row_key}: {e}")
49
+ return {}
50
+
51
+ async def write_data(self, data: Dict[str, Any]) -> bool:
52
+ try:
53
+ async with self._pool.acquire() as conn:
54
+ await conn.execute(
55
+ f"INSERT INTO {self._table_name}(key, data, updated_at) VALUES($1, $2::jsonb, $3)"
56
+ " ON CONFLICT (key) DO UPDATE SET data = EXCLUDED.data, updated_at = EXCLUDED.updated_at",
57
+ self._row_key, json.dumps(data, default=str), datetime.now(timezone.utc)
58
+ )
59
+ return True
60
+ except Exception as e:
61
+ log.error(f"Error writing data to Postgres row {self._row_key}: {e}")
62
+ return False
63
+
64
+
65
+ class PostgresManager:
66
+ """Postgres管理器。
67
+ 使用单表单行设计存储凭证和配置数据。
68
+ """
69
+
70
+ def __init__(self):
71
+ self._pool: Optional[asyncpg.pool.Pool] = None
72
+ self._initialized = False
73
+ self._lock = asyncio.Lock()
74
+
75
+ self._dsn = None
76
+ self._table_name = 'unified_storage'
77
+
78
+ self._operation_count = 0
79
+
80
+ self._operation_times = deque(maxlen=5000)
81
+
82
+ self._credentials_cache_manager: Optional[UnifiedCacheManager] = None
83
+ self._config_cache_manager: Optional[UnifiedCacheManager] = None
84
+
85
+ self._credentials_row_key = 'all_credentials'
86
+ self._config_row_key = 'config_data'
87
+
88
+ self._write_delay = 1.0
89
+ self._cache_ttl = 300
90
+
91
+ async def initialize(self):
92
+ async with self._lock:
93
+ if self._initialized:
94
+ return
95
+ try:
96
+ self._dsn = os.getenv('POSTGRES_DSN')
97
+ if not self._dsn:
98
+ raise ValueError('POSTGRES_DSN environment variable is required')
99
+
100
+ self._pool = await asyncpg.create_pool(dsn=self._dsn, max_size=20, min_size=1)
101
+
102
+ # 确保表存在
103
+ await self._ensure_table()
104
+
105
+ # 创建缓存管理器后端
106
+ credentials_backend = PostgresCacheBackend(self._pool, self._table_name, self._credentials_row_key)
107
+ config_backend = PostgresCacheBackend(self._pool, self._table_name, self._config_row_key)
108
+
109
+ self._credentials_cache_manager = UnifiedCacheManager(
110
+ credentials_backend, cache_ttl=self._cache_ttl, write_delay=self._write_delay, name='credentials'
111
+ )
112
+ self._config_cache_manager = UnifiedCacheManager(
113
+ config_backend, cache_ttl=self._cache_ttl, write_delay=self._write_delay, name='config'
114
+ )
115
+
116
+ await self._credentials_cache_manager.start()
117
+ await self._config_cache_manager.start()
118
+
119
+ self._initialized = True
120
+ log.info('Postgres connection established with unified cache')
121
+ except Exception as e:
122
+ log.error(f'Error initializing Postgres: {e}')
123
+ raise
124
+
125
+ async def _ensure_table(self):
126
+ try:
127
+ async with self._pool.acquire() as conn:
128
+ await conn.execute(
129
+ f"CREATE TABLE IF NOT EXISTS {self._table_name}(\n key TEXT PRIMARY KEY,\n data JSONB,\n updated_at TIMESTAMPTZ\n )"
130
+ )
131
+ except Exception as e:
132
+ log.error(f'Error ensuring Postgres table: {e}')
133
+ raise
134
+
135
+ async def close(self):
136
+ if self._credentials_cache_manager:
137
+ await self._credentials_cache_manager.stop()
138
+ if self._config_cache_manager:
139
+ await self._config_cache_manager.stop()
140
+ if self._pool:
141
+ await self._pool.close()
142
+ self._initialized = False
143
+ log.info('Postgres connection closed with unified cache flushed')
144
+
145
+ def _ensure_initialized(self):
146
+ if not self._initialized:
147
+ raise RuntimeError('Postgres manager not initialized')
148
+
149
+ def _get_default_state(self) -> Dict[str, Any]:
150
+ return {
151
+ 'error_codes': [],
152
+ 'disabled': False,
153
+ 'last_success': time.time(),
154
+ 'user_email': None,
155
+ }
156
+
157
+ def _get_default_stats(self) -> Dict[str, Any]:
158
+ return {
159
+ 'gemini_2_5_pro_calls': 0,
160
+ 'total_calls': 0,
161
+ 'next_reset_time': None,
162
+ 'daily_limit_gemini_2_5_pro': 100,
163
+ 'daily_limit_total': 1000
164
+ }
165
+
166
+ # 以下方法委托给 UnifiedCacheManager
167
+ async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
168
+ self._ensure_initialized()
169
+ start_time = time.time()
170
+ try:
171
+ existing_data = await self._credentials_cache_manager.get(filename, {})
172
+ credential_entry = {
173
+ 'credential': credential_data,
174
+ 'state': existing_data.get('state', self._get_default_state()),
175
+ 'stats': existing_data.get('stats', self._get_default_stats())
176
+ }
177
+ success = await self._credentials_cache_manager.set(filename, credential_entry)
178
+ self._operation_count += 1
179
+ self._operation_times.append(time.time() - start_time)
180
+ log.debug(f'Stored credential to unified cache (postgres): {filename}')
181
+ return success
182
+ except Exception as e:
183
+ log.error(f'Error storing credential {filename} in Postgres: {e}')
184
+ return False
185
+
186
+ async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
187
+ self._ensure_initialized()
188
+ try:
189
+ credential_entry = await self._credentials_cache_manager.get(filename)
190
+ self._operation_count += 1
191
+ if credential_entry and 'credential' in credential_entry:
192
+ return credential_entry['credential']
193
+ return None
194
+ except Exception as e:
195
+ log.error(f'Error retrieving credential {filename} from Postgres: {e}')
196
+ return None
197
+
198
+ async def list_credentials(self) -> List[str]:
199
+ self._ensure_initialized()
200
+ try:
201
+ all_data = await self._credentials_cache_manager.get_all()
202
+ return list(all_data.keys())
203
+ except Exception as e:
204
+ log.error(f'Error listing credentials from Postgres: {e}')
205
+ return []
206
+
207
+ async def delete_credential(self, filename: str) -> bool:
208
+ self._ensure_initialized()
209
+ try:
210
+ return await self._credentials_cache_manager.delete(filename)
211
+ except Exception as e:
212
+ log.error(f'Error deleting credential {filename} from Postgres: {e}')
213
+ return False
214
+
215
+ async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
216
+ self._ensure_initialized()
217
+ try:
218
+ existing_data = await self._credentials_cache_manager.get(filename, {})
219
+ if not existing_data:
220
+ existing_data = {'credential': {}, 'state': self._get_default_state(), 'stats': self._get_default_stats()}
221
+ existing_data['state'].update(state_updates)
222
+ return await self._credentials_cache_manager.set(filename, existing_data)
223
+ except Exception as e:
224
+ log.error(f'Error updating credential state {filename} in Postgres: {e}')
225
+ return False
226
+
227
+ async def get_credential_state(self, filename: str) -> Dict[str, Any]:
228
+ self._ensure_initialized()
229
+ try:
230
+ credential_entry = await self._credentials_cache_manager.get(filename)
231
+ if credential_entry and 'state' in credential_entry:
232
+ return credential_entry['state']
233
+ return self._get_default_state()
234
+ except Exception as e:
235
+ log.error(f'Error getting credential state {filename} from Postgres: {e}')
236
+ return self._get_default_state()
237
+
238
+ async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
239
+ self._ensure_initialized()
240
+ try:
241
+ all_data = await self._credentials_cache_manager.get_all()
242
+ states = {fn: data.get('state', self._get_default_state()) for fn, data in all_data.items()}
243
+ return states
244
+ except Exception as e:
245
+ log.error(f'Error getting all credential states from Postgres: {e}')
246
+ return {}
247
+
248
+ async def set_config(self, key: str, value: Any) -> bool:
249
+ self._ensure_initialized()
250
+ return await self._config_cache_manager.set(key, value)
251
+
252
+ async def get_config(self, key: str, default: Any = None) -> Any:
253
+ self._ensure_initialized()
254
+ return await self._config_cache_manager.get(key, default)
255
+
256
+ async def get_all_config(self) -> Dict[str, Any]:
257
+ self._ensure_initialized()
258
+ return await self._config_cache_manager.get_all()
259
+
260
+ async def delete_config(self, key: str) -> bool:
261
+ self._ensure_initialized()
262
+ return await self._config_cache_manager.delete(key)
263
+
264
+ async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
265
+ self._ensure_initialized()
266
+ try:
267
+ existing_data = await self._credentials_cache_manager.get(filename, {})
268
+ if not existing_data:
269
+ existing_data = {'credential': {}, 'state': self._get_default_state(), 'stats': self._get_default_stats()}
270
+ existing_data['stats'].update(stats_updates)
271
+ return await self._credentials_cache_manager.set(filename, existing_data)
272
+ except Exception as e:
273
+ log.error(f'Error updating usage stats for {filename} in Postgres: {e}')
274
+ return False
275
+
276
+ async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
277
+ self._ensure_initialized()
278
+ try:
279
+ credential_entry = await self._credentials_cache_manager.get(filename)
280
+ if credential_entry and 'stats' in credential_entry:
281
+ return credential_entry['stats']
282
+ return self._get_default_stats()
283
+ except Exception as e:
284
+ log.error(f'Error getting usage stats for {filename} from Postgres: {e}')
285
+ return self._get_default_stats()
286
+
287
+ async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
288
+ self._ensure_initialized()
289
+ try:
290
+ all_data = await self._credentials_cache_manager.get_all()
291
+ stats = {fn: data.get('stats', self._get_default_stats()) for fn, data in all_data.items()}
292
+ return stats
293
+ except Exception as e:
294
+ log.error(f'Error getting all usage stats from Postgres: {e}')
295
+ return {}
296
+
src/storage/redis_manager.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Redis数据库管理器,使用哈希表设计和统一缓存。
3
+ 所有凭证数据存储在一个哈希表中,配置数据存储在另一个哈希表中。
4
+ """
5
+ import asyncio
6
+ import json
7
+ import os
8
+ import time
9
+ from typing import Dict, Any, List, Optional
10
+ from collections import deque
11
+
12
+ import redis.asyncio as redis
13
+ from log import log
14
+ from .cache_manager import UnifiedCacheManager, CacheBackend
15
+
16
+
17
+ class RedisCacheBackend(CacheBackend):
18
+ """Redis缓存后端实现"""
19
+
20
+ def __init__(self, redis_client: redis.Redis, hash_name: str):
21
+ self._client = redis_client
22
+ self._hash_name = hash_name
23
+
24
+ async def load_data(self) -> Dict[str, Any]:
25
+ """从Redis哈希表加载数据"""
26
+ try:
27
+ hash_data = await self._client.hgetall(self._hash_name)
28
+ if not hash_data:
29
+ return {}
30
+
31
+ result = {}
32
+ for key, value_str in hash_data.items():
33
+ try:
34
+ result[key] = json.loads(value_str)
35
+ except json.JSONDecodeError as e:
36
+ log.error(f"Error deserializing Redis data for key {key}: {e}")
37
+ continue
38
+ return result
39
+ except Exception as e:
40
+ log.error(f"Error loading data from Redis hash {self._hash_name}: {e}")
41
+ return {}
42
+
43
+ async def write_data(self, data: Dict[str, Any]) -> bool:
44
+ """将数据写入Redis哈希表"""
45
+ try:
46
+ if not data:
47
+ await self._client.delete(self._hash_name)
48
+ return True
49
+
50
+ hash_data = {}
51
+ for key, value in data.items():
52
+ try:
53
+ hash_data[key] = json.dumps(value, ensure_ascii=False)
54
+ except (TypeError, ValueError) as e:
55
+ log.error(f"Error serializing data for key {key}: {e}")
56
+ continue
57
+
58
+ if not hash_data:
59
+ return True
60
+
61
+ pipe = self._client.pipeline()
62
+ pipe.delete(self._hash_name)
63
+ pipe.hset(self._hash_name, mapping=hash_data)
64
+ await pipe.execute()
65
+ return True
66
+ except Exception as e:
67
+ log.error(f"Error writing data to Redis hash {self._hash_name}: {e}")
68
+ return False
69
+
70
+
71
+ class RedisManager:
72
+ """Redis数据库管理器"""
73
+
74
+ def __init__(self):
75
+ self._client: Optional[redis.Redis] = None
76
+ self._initialized = False
77
+ self._lock = asyncio.Lock()
78
+
79
+ # 配置
80
+ self._connection_uri = None
81
+ self._database_index = 0
82
+
83
+ # 哈希表设计 - 所有凭证存在一个哈希表中
84
+ self._credentials_hash_name = "gcli2api:credentials"
85
+ self._config_hash_name = "gcli2api:config"
86
+
87
+ # 性能监控
88
+ self._operation_count = 0
89
+ self._operation_times = deque(maxlen=5000)
90
+
91
+ # 统一缓存管理器
92
+ self._credentials_cache_manager: Optional[UnifiedCacheManager] = None
93
+ self._config_cache_manager: Optional[UnifiedCacheManager] = None
94
+
95
+ # 写入配置参数
96
+ self._write_delay = 1.0 # 写入延迟(秒)
97
+ self._cache_ttl = 300 # 缓存TTL(秒)
98
+
99
+ async def initialize(self):
100
+ """初始化Redis连接"""
101
+ async with self._lock:
102
+ if self._initialized:
103
+ return
104
+
105
+ try:
106
+ # 获取连接配置
107
+ self._connection_uri = os.getenv("REDIS_URI", "redis://localhost:6379")
108
+ self._database_index = int(os.getenv("REDIS_DATABASE", "0"))
109
+
110
+ # 建立连接 - 使用最简配置确保兼容性
111
+ # 检查是否需要 SSL
112
+ if self._connection_uri.startswith("rediss://"):
113
+ # SSL 连接
114
+ self._client = redis.from_url(
115
+ self._connection_uri,
116
+ db=self._database_index,
117
+ decode_responses=True,
118
+ ssl_cert_reqs=None,
119
+ ssl_check_hostname=False,
120
+ ssl_ca_certs=None
121
+ )
122
+ else:
123
+ # 普通连接
124
+ self._client = redis.from_url(
125
+ self._connection_uri,
126
+ db=self._database_index,
127
+ decode_responses=True
128
+ )
129
+
130
+ # 验证连接
131
+ await self._client.ping()
132
+
133
+ # 创建缓存管理器
134
+ credentials_backend = RedisCacheBackend(self._client, self._credentials_hash_name)
135
+ config_backend = RedisCacheBackend(self._client, self._config_hash_name)
136
+
137
+ self._credentials_cache_manager = UnifiedCacheManager(
138
+ credentials_backend,
139
+ cache_ttl=self._cache_ttl,
140
+ write_delay=self._write_delay,
141
+ name="credentials"
142
+ )
143
+
144
+ self._config_cache_manager = UnifiedCacheManager(
145
+ config_backend,
146
+ cache_ttl=self._cache_ttl,
147
+ write_delay=self._write_delay,
148
+ name="config"
149
+ )
150
+
151
+ # 启动缓存管理器
152
+ await self._credentials_cache_manager.start()
153
+ await self._config_cache_manager.start()
154
+
155
+ self._initialized = True
156
+ log.info(f"Redis connection established to database {self._database_index} with unified cache")
157
+
158
+ except Exception as e:
159
+ log.error(f"Error initializing Redis: {e}")
160
+ raise
161
+
162
+ async def close(self):
163
+ """关闭Redis连接"""
164
+ # 停止缓存管理器
165
+ if self._credentials_cache_manager:
166
+ await self._credentials_cache_manager.stop()
167
+ if self._config_cache_manager:
168
+ await self._config_cache_manager.stop()
169
+
170
+ if self._client:
171
+ await self._client.close()
172
+ self._initialized = False
173
+ log.info("Redis connection closed with unified cache flushed")
174
+
175
+ def _ensure_initialized(self):
176
+ """确保已初始化"""
177
+ if not self._initialized:
178
+ raise RuntimeError("Redis manager not initialized")
179
+
180
+ def _get_default_state(self) -> Dict[str, Any]:
181
+ """获取默认状态数据"""
182
+ return {
183
+ "error_codes": [],
184
+ "disabled": False,
185
+ "last_success": time.time(),
186
+ "user_email": None,
187
+ }
188
+
189
+ def _get_default_stats(self) -> Dict[str, Any]:
190
+ """获取默认统计数据"""
191
+ return {
192
+ "gemini_2_5_pro_calls": 0,
193
+ "total_calls": 0,
194
+ "next_reset_time": None,
195
+ "daily_limit_gemini_2_5_pro": 100,
196
+ "daily_limit_total": 1000
197
+ }
198
+
199
+ # ============ 凭证管理 ============
200
+
201
+ async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
202
+ """存储凭证数据到统一缓存"""
203
+ self._ensure_initialized()
204
+ start_time = time.time()
205
+
206
+ try:
207
+ # 获取现有数据或创建新数据
208
+ existing_data = await self._credentials_cache_manager.get(filename, {})
209
+
210
+ credential_entry = {
211
+ "credential": credential_data,
212
+ "state": existing_data.get("state", self._get_default_state()),
213
+ "stats": existing_data.get("stats", self._get_default_stats())
214
+ }
215
+
216
+ success = await self._credentials_cache_manager.set(filename, credential_entry)
217
+
218
+ # 性能监控
219
+ self._operation_count += 1
220
+ operation_time = time.time() - start_time
221
+ self._operation_times.append(operation_time)
222
+
223
+ log.debug(f"Stored credential to unified cache: {filename} in {operation_time:.3f}s")
224
+ return success
225
+
226
+ except Exception as e:
227
+ operation_time = time.time() - start_time
228
+ log.error(f"Error storing credential {filename} in {operation_time:.3f}s: {e}")
229
+ return False
230
+
231
+ async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
232
+ """从统一缓存获取凭证数据"""
233
+ self._ensure_initialized()
234
+ start_time = time.time()
235
+
236
+ try:
237
+ credential_entry = await self._credentials_cache_manager.get(filename)
238
+
239
+ # 性能监控
240
+ self._operation_count += 1
241
+ operation_time = time.time() - start_time
242
+ self._operation_times.append(operation_time)
243
+
244
+ if credential_entry and "credential" in credential_entry:
245
+ return credential_entry["credential"]
246
+ return None
247
+
248
+ except Exception as e:
249
+ operation_time = time.time() - start_time
250
+ log.error(f"Error retrieving credential {filename} in {operation_time:.3f}s: {e}")
251
+ return None
252
+
253
+ async def list_credentials(self) -> List[str]:
254
+ """从统一缓存列出所有凭证文件名"""
255
+ self._ensure_initialized()
256
+ start_time = time.time()
257
+
258
+ try:
259
+ all_data = await self._credentials_cache_manager.get_all()
260
+ filenames = list(all_data.keys())
261
+
262
+ # 性能监控
263
+ self._operation_count += 1
264
+ operation_time = time.time() - start_time
265
+ self._operation_times.append(operation_time)
266
+
267
+ log.debug(f"Listed {len(filenames)} credentials from unified cache in {operation_time:.3f}s")
268
+ return filenames
269
+
270
+ except Exception as e:
271
+ operation_time = time.time() - start_time
272
+ log.error(f"Error listing credentials in {operation_time:.3f}s: {e}")
273
+ return []
274
+
275
+ async def delete_credential(self, filename: str) -> bool:
276
+ """从统一缓存删除凭证及所有相关数据"""
277
+ self._ensure_initialized()
278
+ start_time = time.time()
279
+
280
+ try:
281
+ success = await self._credentials_cache_manager.delete(filename)
282
+
283
+ # 性能监控
284
+ self._operation_count += 1
285
+ operation_time = time.time() - start_time
286
+ self._operation_times.append(operation_time)
287
+
288
+ log.debug(f"Deleted credential from unified cache: {filename} in {operation_time:.3f}s")
289
+ return success
290
+
291
+ except Exception as e:
292
+ operation_time = time.time() - start_time
293
+ log.error(f"Error deleting credential {filename} in {operation_time:.3f}s: {e}")
294
+ return False
295
+
296
+ # ============ 状态管理 ============
297
+
298
+ async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
299
+ """更新凭证状态(使用统一缓存)"""
300
+ self._ensure_initialized()
301
+ start_time = time.time()
302
+
303
+ try:
304
+ # 获取现有数据或创建新数据
305
+ existing_data = await self._credentials_cache_manager.get(filename, {})
306
+
307
+ if not existing_data:
308
+ existing_data = {
309
+ "credential": {},
310
+ "state": self._get_default_state(),
311
+ "stats": self._get_default_stats()
312
+ }
313
+
314
+ # 更新状态数据
315
+ existing_data["state"].update(state_updates)
316
+
317
+ success = await self._credentials_cache_manager.set(filename, existing_data)
318
+
319
+ # 性能监控
320
+ self._operation_count += 1
321
+ operation_time = time.time() - start_time
322
+ self._operation_times.append(operation_time)
323
+
324
+ log.debug(f"Updated credential state in unified cache: {filename} in {operation_time:.3f}s")
325
+ return success
326
+
327
+ except Exception as e:
328
+ operation_time = time.time() - start_time
329
+ log.error(f"Error updating credential state {filename} in {operation_time:.3f}s: {e}")
330
+ return False
331
+
332
+ async def get_credential_state(self, filename: str) -> Dict[str, Any]:
333
+ """从统一缓存获取凭证状态"""
334
+ self._ensure_initialized()
335
+ start_time = time.time()
336
+
337
+ try:
338
+ credential_entry = await self._credentials_cache_manager.get(filename)
339
+
340
+ # 性能监控
341
+ self._operation_count += 1
342
+ operation_time = time.time() - start_time
343
+ self._operation_times.append(operation_time)
344
+
345
+ if credential_entry and "state" in credential_entry:
346
+ log.debug(f"Retrieved credential state from unified cache: {filename} in {operation_time:.3f}s")
347
+ return credential_entry["state"]
348
+ else:
349
+ # 返回默认状态
350
+ return self._get_default_state()
351
+
352
+ except Exception as e:
353
+ operation_time = time.time() - start_time
354
+ log.error(f"Error getting credential state {filename} in {operation_time:.3f}s: {e}")
355
+ return self._get_default_state()
356
+
357
+ async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
358
+ """从统一缓存获取所有凭证状态"""
359
+ self._ensure_initialized()
360
+ start_time = time.time()
361
+
362
+ try:
363
+ all_data = await self._credentials_cache_manager.get_all()
364
+
365
+ states = {}
366
+ for filename, cred_data in all_data.items():
367
+ states[filename] = cred_data.get("state", self._get_default_state())
368
+
369
+ # 性能监控
370
+ self._operation_count += 1
371
+ operation_time = time.time() - start_time
372
+ self._operation_times.append(operation_time)
373
+
374
+ log.debug(f"Retrieved all credential states from unified cache ({len(states)}) in {operation_time:.3f}s")
375
+ return states
376
+
377
+ except Exception as e:
378
+ operation_time = time.time() - start_time
379
+ log.error(f"Error getting all credential states in {operation_time:.3f}s: {e}")
380
+ return {}
381
+
382
+ # ============ 配置管理 ============
383
+
384
+ async def set_config(self, key: str, value: Any) -> bool:
385
+ """设置配置到统一缓存"""
386
+ self._ensure_initialized()
387
+ start_time = time.time()
388
+
389
+ try:
390
+ success = await self._config_cache_manager.set(key, value)
391
+
392
+ # ���能监控
393
+ self._operation_count += 1
394
+ operation_time = time.time() - start_time
395
+ self._operation_times.append(operation_time)
396
+
397
+ log.debug(f"Set config to unified cache: {key} in {operation_time:.3f}s")
398
+ return success
399
+
400
+ except Exception as e:
401
+ operation_time = time.time() - start_time
402
+ log.error(f"Error setting config {key} in {operation_time:.3f}s: {e}")
403
+ return False
404
+
405
+ async def get_config(self, key: str, default: Any = None) -> Any:
406
+ """从统一缓存获取配置"""
407
+ self._ensure_initialized()
408
+ return await self._config_cache_manager.get(key, default)
409
+
410
+ async def get_all_config(self) -> Dict[str, Any]:
411
+ """从统一缓存获取所有配置"""
412
+ self._ensure_initialized()
413
+ return await self._config_cache_manager.get_all()
414
+
415
+ async def delete_config(self, key: str) -> bool:
416
+ """从统一缓存删除配置"""
417
+ self._ensure_initialized()
418
+ return await self._config_cache_manager.delete(key)
419
+
420
+ # ============ 使用统计管理 ============
421
+
422
+ async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
423
+ """更新使用统计(使用统一缓存)"""
424
+ self._ensure_initialized()
425
+ start_time = time.time()
426
+
427
+ try:
428
+ # 获取现有数据或创建新数据
429
+ existing_data = await self._credentials_cache_manager.get(filename, {})
430
+
431
+ if not existing_data:
432
+ existing_data = {
433
+ "credential": {},
434
+ "state": self._get_default_state(),
435
+ "stats": self._get_default_stats()
436
+ }
437
+
438
+ # 更新统计数据
439
+ existing_data["stats"].update(stats_updates)
440
+
441
+ success = await self._credentials_cache_manager.set(filename, existing_data)
442
+
443
+ # 性能监控
444
+ self._operation_count += 1
445
+ operation_time = time.time() - start_time
446
+ self._operation_times.append(operation_time)
447
+
448
+ log.debug(f"Updated usage stats in unified cache: {filename} in {operation_time:.3f}s")
449
+ return success
450
+
451
+ except Exception as e:
452
+ operation_time = time.time() - start_time
453
+ log.error(f"Error updating usage stats {filename} in {operation_time:.3f}s: {e}")
454
+ return False
455
+
456
+ async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
457
+ """从统一缓存获取使用统计"""
458
+ self._ensure_initialized()
459
+ start_time = time.time()
460
+
461
+ try:
462
+ credential_entry = await self._credentials_cache_manager.get(filename)
463
+
464
+ # 性能监控
465
+ self._operation_count += 1
466
+ operation_time = time.time() - start_time
467
+ self._operation_times.append(operation_time)
468
+
469
+ if credential_entry and "stats" in credential_entry:
470
+ log.debug(f"Retrieved usage stats from unified cache: {filename} in {operation_time:.3f}s")
471
+ return credential_entry["stats"]
472
+ else:
473
+ return self._get_default_stats()
474
+
475
+ except Exception as e:
476
+ operation_time = time.time() - start_time
477
+ log.error(f"Error getting usage stats {filename} in {operation_time:.3f}s: {e}")
478
+ return self._get_default_stats()
479
+
480
+ async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
481
+ """从统一缓存获取所有使用统计"""
482
+ self._ensure_initialized()
483
+ start_time = time.time()
484
+
485
+ try:
486
+ all_data = await self._credentials_cache_manager.get_all()
487
+
488
+ stats = {}
489
+ for filename, cred_data in all_data.items():
490
+ if "stats" in cred_data:
491
+ stats[filename] = cred_data["stats"]
492
+
493
+ # 性能监控
494
+ self._operation_count += 1
495
+ operation_time = time.time() - start_time
496
+ self._operation_times.append(operation_time)
497
+
498
+ log.debug(f"Retrieved all usage stats from unified cache ({len(stats)}) in {operation_time:.3f}s")
499
+ return stats
500
+
501
+ except Exception as e:
502
+ operation_time = time.time() - start_time
503
+ log.error(f"Error getting all usage stats in {operation_time:.3f}s: {e}")
504
+ return {}
src/storage_adapter.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 存储适配器,提供统一的接口来处理Redis、MongoDB和本地文件存储。
3
+ 根据配置自动选择存储后端,优先级:Redis > MongoDB > 本地文件。
4
+ """
5
+ import asyncio
6
+ import os
7
+ import json
8
+ from typing import Dict, Any, List, Optional, Protocol
9
+
10
+ from log import log
11
+
12
+
13
+ class StorageBackend(Protocol):
14
+ """存储后端协议"""
15
+
16
+ async def initialize(self) -> None:
17
+ """初始化存储后端"""
18
+ ...
19
+
20
+ async def close(self) -> None:
21
+ """关闭存储后端"""
22
+ ...
23
+
24
+ # 凭证管理
25
+ async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
26
+ """存储凭证数据"""
27
+ ...
28
+
29
+ async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
30
+ """获取凭证数据"""
31
+ ...
32
+
33
+ async def list_credentials(self) -> List[str]:
34
+ """列出所有凭证文件名"""
35
+ ...
36
+
37
+ async def delete_credential(self, filename: str) -> bool:
38
+ """删除凭证"""
39
+ ...
40
+
41
+ # 状态管理
42
+ async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
43
+ """更新凭证状态"""
44
+ ...
45
+
46
+ async def get_credential_state(self, filename: str) -> Dict[str, Any]:
47
+ """获取凭证状态"""
48
+ ...
49
+
50
+ async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
51
+ """获取所有凭证状态"""
52
+ ...
53
+
54
+ # 配置管理
55
+ async def set_config(self, key: str, value: Any) -> bool:
56
+ """设置配置项"""
57
+ ...
58
+
59
+ async def get_config(self, key: str, default: Any = None) -> Any:
60
+ """获取配置项"""
61
+ ...
62
+
63
+ async def get_all_config(self) -> Dict[str, Any]:
64
+ """获取所有配置"""
65
+ ...
66
+
67
+ async def delete_config(self, key: str) -> bool:
68
+ """删除配置项"""
69
+ ...
70
+
71
+ # 使用统计管理
72
+ async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
73
+ """更新使用统计"""
74
+ ...
75
+
76
+ async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
77
+ """获取使用统计"""
78
+ ...
79
+
80
+ async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
81
+ """获取所有使用统计"""
82
+ ...
83
+
84
+
85
+
86
+
87
+ class StorageAdapter:
88
+ """存储适配器,根据配置选择存储后端"""
89
+
90
+ def __init__(self):
91
+ self._backend: Optional["StorageBackend"] = None
92
+ self._initialized = False
93
+ self._lock = asyncio.Lock()
94
+
95
+ async def initialize(self) -> None:
96
+ """初始化存储适配器"""
97
+ async with self._lock:
98
+ if self._initialized:
99
+ return
100
+
101
+ # 按优先级检查存储后端:Redis > MongoDB > 本地文件
102
+ redis_uri = os.getenv("REDIS_URI", "")
103
+ mongodb_uri = os.getenv("MONGODB_URI", "")
104
+
105
+ # 优先尝试Redis存储
106
+ if redis_uri:
107
+ try:
108
+ from .storage.redis_manager import RedisManager
109
+ self._backend = RedisManager()
110
+ await self._backend.initialize()
111
+ log.info("Using Redis storage backend")
112
+ except ImportError as e:
113
+ log.error(f"Failed to import Redis backend: {e}")
114
+ log.info("Falling back to next available storage backend")
115
+ except Exception as e:
116
+ log.error(f"Failed to initialize Redis backend: {e}")
117
+ log.info("Falling back to next available storage backend")
118
+
119
+ # 如果Redis不可用或未配置,接下来尝试Postgres(优先级低于Redis)
120
+ postgres_dsn = os.getenv("POSTGRES_DSN", "")
121
+ if not self._backend and postgres_dsn:
122
+ try:
123
+ from .storage.postgres_manager import PostgresManager
124
+ self._backend = PostgresManager()
125
+ await self._backend.initialize()
126
+ log.info("Using Postgres storage backend")
127
+ except ImportError as e:
128
+ log.error(f"Failed to import Postgres backend: {e}")
129
+ log.info("Falling back to next available storage backend")
130
+ except Exception as e:
131
+ log.error(f"Failed to initialize Postgres backend: {e}")
132
+ log.info("Falling back to next available storage backend")
133
+
134
+ # 如果Redis和Postgres不可用,尝试MongoDB存储
135
+ if not self._backend and mongodb_uri:
136
+ try:
137
+ from .storage.mongodb_manager import MongoDBManager
138
+ self._backend = MongoDBManager()
139
+ await self._backend.initialize()
140
+ log.info("Using MongoDB storage backend")
141
+ except ImportError as e:
142
+ log.error(f"Failed to import MongoDB backend: {e}")
143
+ log.info("Falling back to file storage backend")
144
+ except Exception as e:
145
+ log.error(f"Failed to initialize MongoDB backend: {e}")
146
+ log.info("Falling back to file storage backend")
147
+
148
+ # 如果Redis和MongoDB都不可用,使用文件存储
149
+ if not self._backend:
150
+ from .storage.file_storage_manager import FileStorageManager
151
+ self._backend = FileStorageManager()
152
+ await self._backend.initialize()
153
+ log.info("Using file storage backend")
154
+
155
+ self._initialized = True
156
+
157
+ async def close(self) -> None:
158
+ """关闭存储适配器"""
159
+ if self._backend:
160
+ await self._backend.close()
161
+ self._backend = None
162
+ self._initialized = False
163
+
164
+ def _ensure_initialized(self):
165
+ """确保存储适配器已初始化"""
166
+ if not self._initialized or not self._backend:
167
+ raise RuntimeError("Storage adapter not initialized")
168
+
169
+ # ============ 凭证管理 ============
170
+
171
+ async def store_credential(self, filename: str, credential_data: Dict[str, Any]) -> bool:
172
+ """存储凭证数据"""
173
+ self._ensure_initialized()
174
+ return await self._backend.store_credential(filename, credential_data)
175
+
176
+ async def get_credential(self, filename: str) -> Optional[Dict[str, Any]]:
177
+ """获取凭证数据"""
178
+ self._ensure_initialized()
179
+ return await self._backend.get_credential(filename)
180
+
181
+ async def list_credentials(self) -> List[str]:
182
+ """列出所有凭证文件名"""
183
+ self._ensure_initialized()
184
+ return await self._backend.list_credentials()
185
+
186
+ async def delete_credential(self, filename: str) -> bool:
187
+ """删除凭证"""
188
+ self._ensure_initialized()
189
+ return await self._backend.delete_credential(filename)
190
+
191
+ # ============ 状态管理 ============
192
+
193
+ async def update_credential_state(self, filename: str, state_updates: Dict[str, Any]) -> bool:
194
+ """更新凭证状态"""
195
+ self._ensure_initialized()
196
+ return await self._backend.update_credential_state(filename, state_updates)
197
+
198
+ async def get_credential_state(self, filename: str) -> Dict[str, Any]:
199
+ """获取凭证状态"""
200
+ self._ensure_initialized()
201
+ return await self._backend.get_credential_state(filename)
202
+
203
+ async def get_all_credential_states(self) -> Dict[str, Dict[str, Any]]:
204
+ """获取所有凭证状态"""
205
+ self._ensure_initialized()
206
+ return await self._backend.get_all_credential_states()
207
+
208
+ # ============ 配置管理 ============
209
+
210
+ async def set_config(self, key: str, value: Any) -> bool:
211
+ """设置配置项"""
212
+ self._ensure_initialized()
213
+ return await self._backend.set_config(key, value)
214
+
215
+ async def get_config(self, key: str, default: Any = None) -> Any:
216
+ """获取配置项"""
217
+ self._ensure_initialized()
218
+ return await self._backend.get_config(key, default)
219
+
220
+ async def get_all_config(self) -> Dict[str, Any]:
221
+ """获取所有配置"""
222
+ self._ensure_initialized()
223
+ return await self._backend.get_all_config()
224
+
225
+ async def delete_config(self, key: str) -> bool:
226
+ """删除配置项"""
227
+ self._ensure_initialized()
228
+ return await self._backend.delete_config(key)
229
+
230
+ # ============ 使用统计管理 ============
231
+
232
+ async def update_usage_stats(self, filename: str, stats_updates: Dict[str, Any]) -> bool:
233
+ """更新使用统计"""
234
+ self._ensure_initialized()
235
+ return await self._backend.update_usage_stats(filename, stats_updates)
236
+
237
+ async def get_usage_stats(self, filename: str) -> Dict[str, Any]:
238
+ """获取使用统计"""
239
+ self._ensure_initialized()
240
+ return await self._backend.get_usage_stats(filename)
241
+
242
+ async def get_all_usage_stats(self) -> Dict[str, Dict[str, Any]]:
243
+ """获取所有使用统计"""
244
+ self._ensure_initialized()
245
+ return await self._backend.get_all_usage_stats()
246
+
247
+ # ============ 工具方法 ============
248
+
249
+ async def export_credential_to_json(self, filename: str, output_path: str = None) -> bool:
250
+ """将凭证导出为JSON文件"""
251
+ self._ensure_initialized()
252
+ if hasattr(self._backend, 'export_credential_to_json'):
253
+ return await self._backend.export_credential_to_json(filename, output_path)
254
+ # MongoDB后端的fallback实现
255
+ credential_data = await self.get_credential(filename)
256
+ if credential_data is None:
257
+ return False
258
+
259
+ if output_path is None:
260
+ output_path = f"{filename}.json"
261
+
262
+ import aiofiles
263
+ try:
264
+ async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
265
+ await f.write(json.dumps(credential_data, indent=2, ensure_ascii=False))
266
+ return True
267
+ except Exception:
268
+ return False
269
+
270
+ async def import_credential_from_json(self, json_path: str, filename: str = None) -> bool:
271
+ """从JSON文件导入凭证"""
272
+ self._ensure_initialized()
273
+ if hasattr(self._backend, 'import_credential_from_json'):
274
+ return await self._backend.import_credential_from_json(json_path, filename)
275
+ # MongoDB后端的fallback实现
276
+ try:
277
+ import aiofiles
278
+ async with aiofiles.open(json_path, "r", encoding="utf-8") as f:
279
+ content = await f.read()
280
+
281
+ credential_data = json.loads(content)
282
+
283
+ if filename is None:
284
+ filename = os.path.basename(json_path)
285
+
286
+ return await self.store_credential(filename, credential_data)
287
+ except Exception:
288
+ return False
289
+
290
+ def get_backend_type(self) -> str:
291
+ """获取当前存储后端类型"""
292
+ if not self._backend:
293
+ return "none"
294
+
295
+ # 检查后端类型
296
+ backend_class_name = self._backend.__class__.__name__
297
+ if "File" in backend_class_name or "file" in backend_class_name.lower():
298
+ return "file"
299
+ elif "MongoDB" in backend_class_name or "mongo" in backend_class_name.lower():
300
+ return "mongodb"
301
+ elif "Redis" in backend_class_name or "redis" in backend_class_name.lower():
302
+ return "redis"
303
+ else:
304
+ return "unknown"
305
+
306
+ async def get_backend_info(self) -> Dict[str, Any]:
307
+ """获取存储后端信息"""
308
+ self._ensure_initialized()
309
+
310
+ backend_type = self.get_backend_type()
311
+ info = {
312
+ "backend_type": backend_type,
313
+ "initialized": self._initialized
314
+ }
315
+
316
+ # 获取底层存储信息
317
+ if hasattr(self._backend, 'get_database_info'):
318
+ try:
319
+ db_info = await self._backend.get_database_info()
320
+ info.update(db_info)
321
+ except Exception as e:
322
+ info["database_error"] = str(e)
323
+ else:
324
+ backend_type = self.get_backend_type()
325
+ if backend_type == "file":
326
+ info.update({
327
+ "credentials_dir": getattr(self._backend, '_credentials_dir', None),
328
+ "state_file": getattr(self._backend, '_state_file', None),
329
+ "config_file": getattr(self._backend, '_config_file', None)
330
+ })
331
+ elif backend_type == "redis":
332
+ info.update({
333
+ "redis_url": getattr(self._backend, '_redis_url', None),
334
+ "connection_pool_size": getattr(self._backend, '_pool_size', None)
335
+ })
336
+
337
+ return info
338
+
339
+
340
+ # 全局存储适配器实例
341
+ _storage_adapter: Optional[StorageAdapter] = None
342
+
343
+
344
+ async def get_storage_adapter() -> StorageAdapter:
345
+ """获取全局存储适配器实例"""
346
+ global _storage_adapter
347
+
348
+ if _storage_adapter is None:
349
+ _storage_adapter = StorageAdapter()
350
+ await _storage_adapter.initialize()
351
+
352
+ return _storage_adapter
353
+
354
+
355
+ async def close_storage_adapter():
356
+ """关闭全局存储适配器"""
357
+ global _storage_adapter
358
+
359
+ if _storage_adapter:
360
+ await _storage_adapter.close()
361
+ _storage_adapter = None
src/task_manager.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Global task lifecycle management module
3
+ 管理应用程序中所有异步任务的生命周期,确保正确清理
4
+ """
5
+ import asyncio
6
+ import weakref
7
+ from typing import Set, Dict, Any
8
+ from log import log
9
+
10
+
11
+ class TaskManager:
12
+ """全局异步任务管理器 - 单例模式"""
13
+
14
+ _instance = None
15
+ _lock = asyncio.Lock()
16
+
17
+ def __new__(cls):
18
+ if cls._instance is None:
19
+ cls._instance = super().__new__(cls)
20
+ cls._instance._initialized = False
21
+ return cls._instance
22
+
23
+ def __init__(self):
24
+ if self._initialized:
25
+ return
26
+
27
+ self._tasks: Set[asyncio.Task] = set()
28
+ self._resources: Set[Any] = set() # 需要关闭的资源
29
+ self._shutdown_event = asyncio.Event()
30
+ self._initialized = True
31
+ log.debug("TaskManager initialized")
32
+
33
+ def register_task(self, task: asyncio.Task, description: str = None) -> asyncio.Task:
34
+ """注册一个任务供生命周期管理"""
35
+ self._tasks.add(task)
36
+ task.add_done_callback(lambda t: self._tasks.discard(t))
37
+
38
+ if description:
39
+ task.set_name(description)
40
+
41
+ log.debug(f"Registered task: {task.get_name() or 'unnamed'}")
42
+ return task
43
+
44
+ def create_task(self, coro, *, name: str = None) -> asyncio.Task:
45
+ """创建并注册一个任务"""
46
+ task = asyncio.create_task(coro, name=name)
47
+ return self.register_task(task, name)
48
+
49
+ def register_resource(self, resource: Any) -> Any:
50
+ """注册一个需要清理的资源(如HTTP客户端、文件句柄等)"""
51
+ # 使用弱引用避免循环引用
52
+ self._resources.add(weakref.ref(resource))
53
+ log.debug(f"Registered resource: {type(resource).__name__}")
54
+ return resource
55
+
56
+ async def shutdown(self, timeout: float = 30.0):
57
+ """关闭所有任务和资源"""
58
+ log.info("TaskManager shutdown initiated")
59
+
60
+ # 设置关闭标志
61
+ self._shutdown_event.set()
62
+
63
+ # 取消所有未完成的任务
64
+ cancelled_count = 0
65
+ for task in list(self._tasks):
66
+ if not task.done():
67
+ task.cancel()
68
+ cancelled_count += 1
69
+
70
+ if cancelled_count > 0:
71
+ log.info(f"Cancelled {cancelled_count} pending tasks")
72
+
73
+ # 等待所有任务完成(包括取消)
74
+ if self._tasks:
75
+ try:
76
+ await asyncio.wait_for(
77
+ asyncio.gather(*self._tasks, return_exceptions=True),
78
+ timeout=timeout
79
+ )
80
+ except asyncio.TimeoutError:
81
+ log.warning(f"Some tasks did not complete within {timeout}s timeout")
82
+
83
+ # 清理资源
84
+ cleaned_resources = 0
85
+ for resource_ref in list(self._resources):
86
+ resource = resource_ref()
87
+ if resource is not None:
88
+ try:
89
+ if hasattr(resource, 'close'):
90
+ if asyncio.iscoroutinefunction(resource.close):
91
+ await resource.close()
92
+ else:
93
+ resource.close()
94
+ elif hasattr(resource, 'aclose'):
95
+ await resource.aclose()
96
+ cleaned_resources += 1
97
+ except Exception as e:
98
+ log.warning(f"Failed to close resource {type(resource).__name__}: {e}")
99
+
100
+ if cleaned_resources > 0:
101
+ log.info(f"Cleaned up {cleaned_resources} resources")
102
+
103
+ self._tasks.clear()
104
+ self._resources.clear()
105
+ log.info("TaskManager shutdown completed")
106
+
107
+ @property
108
+ def is_shutdown(self) -> bool:
109
+ """检查是否已经开始关闭"""
110
+ return self._shutdown_event.is_set()
111
+
112
+ def get_stats(self) -> Dict[str, int]:
113
+ """获取任务管理统计信息"""
114
+ return {
115
+ 'active_tasks': len(self._tasks),
116
+ 'registered_resources': len(self._resources),
117
+ 'is_shutdown': self.is_shutdown
118
+ }
119
+
120
+
121
+ # 全局任务管理器实例
122
+ task_manager = TaskManager()
123
+
124
+
125
+ def create_managed_task(coro, *, name: str = None) -> asyncio.Task:
126
+ """创建一个被管理的异步任务的便捷函数"""
127
+ return task_manager.create_task(coro, name=name)
128
+
129
+
130
+ def register_resource(resource: Any) -> Any:
131
+ """注册资源的便捷函数"""
132
+ return task_manager.register_resource(resource)
133
+
134
+
135
+ async def shutdown_all_tasks(timeout: float = 30.0):
136
+ """关闭所有任务的便捷函数"""
137
+ await task_manager.shutdown(timeout)
src/usage_stats.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage statistics module for tracking API calls per credential file.
3
+ Uses the simpler logic: compare current time with next_reset_time.
4
+ """
5
+ import os
6
+ import time
7
+ from datetime import datetime, timezone, timedelta
8
+ from threading import Lock
9
+ from typing import Dict, Any, Optional
10
+
11
+ from config import get_credentials_dir, is_mongodb_mode
12
+ from log import log
13
+ from .state_manager import get_state_manager
14
+ from .storage_adapter import get_storage_adapter
15
+
16
+
17
+ def _get_next_utc_7am() -> datetime:
18
+ """
19
+ Calculate the next UTC 07:00 time for quota reset.
20
+ """
21
+ now = datetime.now(timezone.utc)
22
+ today_7am = now.replace(hour=7, minute=0, second=0, microsecond=0)
23
+
24
+ if now < today_7am:
25
+ return today_7am
26
+ else:
27
+ return today_7am + timedelta(days=1)
28
+
29
+
30
+ class UsageStats:
31
+ """
32
+ Simplified usage statistics manager with clear reset logic.
33
+ """
34
+
35
+ def __init__(self):
36
+ self._lock = Lock()
37
+ # 状态文件路径将在初始化时异步设置
38
+ self._state_file = None
39
+ self._state_manager = None
40
+ self._storage_adapter = None
41
+ self._stats_cache: Dict[str, Dict[str, Any]] = {}
42
+ self._initialized = False
43
+ self._cache_dirty = False # 缓存脏标记,减少不必要的写入
44
+ self._last_save_time = 0
45
+ self._save_interval = 60 # 最多每分钟保存一次,减少I/O
46
+ self._max_cache_size = 100 # 严格限制缓存大小
47
+
48
+ async def initialize(self):
49
+ """Initialize the usage stats module."""
50
+ if self._initialized:
51
+ return
52
+
53
+ # 初始化存储适配器
54
+ self._storage_adapter = await get_storage_adapter()
55
+
56
+ # 只在文件模式下创建本地状态文件
57
+ if not await is_mongodb_mode():
58
+ credentials_dir = await get_credentials_dir()
59
+ self._state_file = os.path.join(credentials_dir, "creds_state.toml")
60
+ self._state_manager = get_state_manager(self._state_file)
61
+
62
+ await self._load_stats()
63
+ self._initialized = True
64
+ storage_type = "MongoDB" if await is_mongodb_mode() else "File"
65
+ log.debug(f"Usage statistics module initialized with {storage_type} storage backend")
66
+
67
+
68
+ def _normalize_filename(self, filename: str) -> str:
69
+ """Normalize filename to relative path for consistent storage."""
70
+ if not filename:
71
+ return ""
72
+
73
+ if os.path.sep not in filename and "/" not in filename:
74
+ return filename
75
+
76
+ return os.path.basename(filename)
77
+
78
+ def _is_gemini_2_5_pro(self, model_name: str) -> bool:
79
+ """
80
+ Check if model is gemini-2.5-pro variant (including prefixes and suffixes).
81
+ """
82
+ if not model_name:
83
+ return False
84
+
85
+ try:
86
+ from config import get_base_model_name, get_base_model_from_feature_model
87
+
88
+ # Remove feature prefixes (流式抗截断/, 假流式/)
89
+ base_with_suffix = get_base_model_from_feature_model(model_name)
90
+
91
+ # Remove thinking/search suffixes (-maxthinking, -nothinking, -search)
92
+ pure_base_model = get_base_model_name(base_with_suffix)
93
+
94
+ # Check if the pure base model is exactly "gemini-2.5-pro"
95
+ return pure_base_model == "gemini-2.5-pro"
96
+
97
+ except ImportError:
98
+ # Fallback logic if config import fails
99
+ clean_model = model_name
100
+ for prefix in ["流式抗截断/", "假流式/"]:
101
+ if clean_model.startswith(prefix):
102
+ clean_model = clean_model[len(prefix):]
103
+ break
104
+
105
+ for suffix in ["-maxthinking", "-nothinking", "-search"]:
106
+ if clean_model.endswith(suffix):
107
+ clean_model = clean_model[:-len(suffix)]
108
+ break
109
+
110
+ return clean_model == "gemini-2.5-pro"
111
+
112
+ async def _load_stats(self):
113
+ """Load statistics from unified storage"""
114
+ try:
115
+ # 从统一存储获取所有使用统计,添加超时机制防止卡死
116
+ import asyncio
117
+
118
+ async def load_stats_with_timeout():
119
+ all_usage_stats = await self._storage_adapter.get_all_usage_stats()
120
+
121
+ log.debug(f"Processing {len(all_usage_stats)} usage statistics items...")
122
+
123
+ # 直接处理统计数据
124
+ stats_cache = {}
125
+ processed_count = 0
126
+
127
+ for filename, stats_data in all_usage_stats.items():
128
+ if isinstance(stats_data, dict):
129
+ normalized_filename = self._normalize_filename(filename)
130
+
131
+ # 提取使用统计字段
132
+ usage_data = {
133
+ "gemini_2_5_pro_calls": stats_data.get("gemini_2_5_pro_calls", 0),
134
+ "total_calls": stats_data.get("total_calls", 0),
135
+ "next_reset_time": stats_data.get("next_reset_time"),
136
+ "daily_limit_gemini_2_5_pro": stats_data.get("daily_limit_gemini_2_5_pro", 100),
137
+ "daily_limit_total": stats_data.get("daily_limit_total", 1000)
138
+ }
139
+
140
+ # 只加载有实际使用数据的统计,或者有reset时间的
141
+ if (usage_data.get("gemini_2_5_pro_calls", 0) > 0 or
142
+ usage_data.get("total_calls", 0) > 0 or
143
+ usage_data.get("next_reset_time")):
144
+ stats_cache[normalized_filename] = usage_data
145
+ processed_count += 1
146
+
147
+ return stats_cache, processed_count
148
+
149
+ # 设置15秒超时防止卡死
150
+ try:
151
+ self._stats_cache, processed_count = await asyncio.wait_for(
152
+ load_stats_with_timeout(), timeout=15.0
153
+ )
154
+ log.debug(f"Loaded usage statistics for {processed_count} credential files")
155
+ except asyncio.TimeoutError:
156
+ log.error("Loading usage statistics timed out after 30 seconds, using empty cache")
157
+ self._stats_cache = {}
158
+ return
159
+
160
+ except Exception as e:
161
+ log.error(f"Failed to load usage statistics: {e}")
162
+ self._stats_cache = {}
163
+
164
+ async def _save_stats(self):
165
+ """Save statistics to unified storage."""
166
+ current_time = time.time()
167
+
168
+ # 使用脏标记和时间间隔控制,减少不必要的写入
169
+ if not self._cache_dirty or (current_time - self._last_save_time < self._save_interval):
170
+ return
171
+
172
+ try:
173
+ # 批量更新使用统计到存储适配器
174
+ log.debug(f"Saving {len(self._stats_cache)} usage statistics items...")
175
+
176
+ saved_count = 0
177
+ for filename, stats in self._stats_cache.items():
178
+ try:
179
+ stats_data = {
180
+ "gemini_2_5_pro_calls": stats.get("gemini_2_5_pro_calls", 0),
181
+ "total_calls": stats.get("total_calls", 0),
182
+ "next_reset_time": stats.get("next_reset_time"),
183
+ "daily_limit_gemini_2_5_pro": stats.get("daily_limit_gemini_2_5_pro", 100),
184
+ "daily_limit_total": stats.get("daily_limit_total", 1000)
185
+ }
186
+
187
+ success = await self._storage_adapter.update_usage_stats(filename, stats_data)
188
+ if success:
189
+ saved_count += 1
190
+ except Exception as e:
191
+ log.error(f"Failed to save stats for {filename}: {e}")
192
+ continue
193
+
194
+ self._cache_dirty = False # 清除脏标记
195
+ self._last_save_time = current_time
196
+ log.debug(f"Successfully saved {saved_count}/{len(self._stats_cache)} usage statistics to unified storage")
197
+ except Exception as e:
198
+ log.error(f"Failed to save usage statistics: {e}")
199
+
200
+ def _get_or_create_stats(self, filename: str) -> Dict[str, Any]:
201
+ """Get or create statistics entry for a credential file."""
202
+ normalized_filename = self._normalize_filename(filename)
203
+
204
+ if normalized_filename not in self._stats_cache:
205
+ # 严格控制缓存大小 - 超过限制时删除最旧的条目
206
+ if len(self._stats_cache) >= self._max_cache_size:
207
+ # 删除最旧的统计数据(基于next_reset_time或没有该字段的)
208
+ oldest_key = min(self._stats_cache.keys(),
209
+ key=lambda k: self._stats_cache[k].get('next_reset_time', ''))
210
+ del self._stats_cache[oldest_key]
211
+ self._cache_dirty = True
212
+ log.debug(f"Removed oldest usage stats cache entry: {oldest_key}")
213
+
214
+ next_reset = _get_next_utc_7am()
215
+ self._stats_cache[normalized_filename] = {
216
+ "gemini_2_5_pro_calls": 0,
217
+ "total_calls": 0,
218
+ "next_reset_time": next_reset.isoformat(),
219
+ "daily_limit_gemini_2_5_pro": 100,
220
+ "daily_limit_total": 1000
221
+ }
222
+ self._cache_dirty = True # 标记缓存已修改
223
+
224
+ return self._stats_cache[normalized_filename]
225
+
226
+ def _check_and_reset_daily_quota(self, stats: Dict[str, Any]) -> bool:
227
+ """
228
+ Simple reset logic: if current time >= next_reset_time, then reset.
229
+ """
230
+ try:
231
+ next_reset_str = stats.get("next_reset_time")
232
+ if not next_reset_str:
233
+ # No next reset time recorded, set it up
234
+ next_reset = _get_next_utc_7am()
235
+ stats["next_reset_time"] = next_reset.isoformat()
236
+ return False
237
+
238
+ next_reset = datetime.fromisoformat(next_reset_str)
239
+ now = datetime.now(timezone.utc)
240
+
241
+ # Simple comparison: if current time >= next reset time, then reset
242
+ if now >= next_reset:
243
+ old_gemini_calls = stats.get("gemini_2_5_pro_calls", 0)
244
+ old_total_calls = stats.get("total_calls", 0)
245
+
246
+ # Reset counters and set new next reset time
247
+ new_next_reset = _get_next_utc_7am()
248
+ stats.update({
249
+ "gemini_2_5_pro_calls": 0,
250
+ "total_calls": 0,
251
+ "next_reset_time": new_next_reset.isoformat()
252
+ })
253
+
254
+ self._cache_dirty = True # 标记缓存已修改
255
+ log.info(f"Daily quota reset performed. Previous stats - Gemini 2.5 Pro: {old_gemini_calls}, Total: {old_total_calls}")
256
+ return True
257
+
258
+ return False
259
+ except Exception as e:
260
+ log.error(f"Error in daily quota reset check: {e}")
261
+ return False
262
+
263
+ async def record_successful_call(self, filename: str, model_name: str):
264
+ """Record a successful API call for statistics."""
265
+ if not self._initialized:
266
+ await self.initialize()
267
+
268
+ with self._lock:
269
+ try:
270
+ normalized_filename = self._normalize_filename(filename)
271
+ stats = self._get_or_create_stats(normalized_filename)
272
+
273
+ # Check and perform daily reset if needed
274
+ reset_performed = self._check_and_reset_daily_quota(stats)
275
+
276
+ # Increment counters
277
+ is_gemini_2_5_pro = self._is_gemini_2_5_pro(model_name)
278
+
279
+ stats["total_calls"] += 1
280
+ if is_gemini_2_5_pro:
281
+ stats["gemini_2_5_pro_calls"] += 1
282
+
283
+ self._cache_dirty = True # 标记缓存已修改
284
+
285
+ log.debug(f"Usage recorded - File: {normalized_filename}, Model: {model_name}, "
286
+ f"Gemini 2.5 Pro: {stats['gemini_2_5_pro_calls']}/{stats.get('daily_limit_gemini_2_5_pro', 100)}, "
287
+ f"Total: {stats['total_calls']}/{stats.get('daily_limit_total', 1000)}")
288
+
289
+ if reset_performed:
290
+ log.info(f"Daily quota was reset for {normalized_filename}")
291
+
292
+ except Exception as e:
293
+ log.error(f"Failed to record usage statistics: {e}")
294
+
295
+ # Save stats asynchronously
296
+ try:
297
+ await self._save_stats()
298
+ except Exception as e:
299
+ log.error(f"Failed to save usage statistics after recording: {e}")
300
+
301
+ async def get_usage_stats(self, filename: str = None) -> Dict[str, Any]:
302
+ """Get usage statistics."""
303
+ if not self._initialized:
304
+ await self.initialize()
305
+
306
+ with self._lock:
307
+ if filename:
308
+ normalized_filename = self._normalize_filename(filename)
309
+ stats = self._get_or_create_stats(normalized_filename)
310
+ # Check for daily reset before returning stats
311
+ self._check_and_reset_daily_quota(stats)
312
+ return {
313
+ "filename": normalized_filename,
314
+ "gemini_2_5_pro_calls": stats.get("gemini_2_5_pro_calls", 0),
315
+ "total_calls": stats.get("total_calls", 0),
316
+ "daily_limit_gemini_2_5_pro": stats.get("daily_limit_gemini_2_5_pro", 100),
317
+ "daily_limit_total": stats.get("daily_limit_total", 1000),
318
+ "next_reset_time": stats.get("next_reset_time")
319
+ }
320
+ else:
321
+ # Return all statistics
322
+ all_stats = {}
323
+ for filename, stats in self._stats_cache.items():
324
+ # Check for daily reset for each file
325
+ self._check_and_reset_daily_quota(stats)
326
+ all_stats[filename] = {
327
+ "gemini_2_5_pro_calls": stats.get("gemini_2_5_pro_calls", 0),
328
+ "total_calls": stats.get("total_calls", 0),
329
+ "daily_limit_gemini_2_5_pro": stats.get("daily_limit_gemini_2_5_pro", 100),
330
+ "daily_limit_total": stats.get("daily_limit_total", 1000),
331
+ "next_reset_time": stats.get("next_reset_time")
332
+ }
333
+
334
+ return all_stats
335
+
336
+ async def get_aggregated_stats(self) -> Dict[str, Any]:
337
+ """Get aggregated statistics across all credential files."""
338
+ if not self._initialized:
339
+ await self.initialize()
340
+
341
+ all_stats = await self.get_usage_stats()
342
+
343
+ total_gemini_2_5_pro = 0
344
+ total_all_models = 0
345
+ total_files = len(all_stats)
346
+
347
+ for stats in all_stats.values():
348
+ total_gemini_2_5_pro += stats["gemini_2_5_pro_calls"]
349
+ total_all_models += stats["total_calls"]
350
+
351
+ return {
352
+ "total_files": total_files,
353
+ "total_gemini_2_5_pro_calls": total_gemini_2_5_pro,
354
+ "total_all_model_calls": total_all_models,
355
+ "avg_gemini_2_5_pro_per_file": total_gemini_2_5_pro / max(total_files, 1),
356
+ "avg_total_per_file": total_all_models / max(total_files, 1),
357
+ "next_reset_time": _get_next_utc_7am().isoformat()
358
+ }
359
+
360
+ async def update_daily_limits(self, filename: str, gemini_2_5_pro_limit: int = None,
361
+ total_limit: int = None):
362
+ """Update daily limits for a specific credential file."""
363
+ if not self._initialized:
364
+ await self.initialize()
365
+
366
+ with self._lock:
367
+ try:
368
+ normalized_filename = self._normalize_filename(filename)
369
+ stats = self._get_or_create_stats(normalized_filename)
370
+
371
+ if gemini_2_5_pro_limit is not None:
372
+ stats["daily_limit_gemini_2_5_pro"] = gemini_2_5_pro_limit
373
+
374
+ if total_limit is not None:
375
+ stats["daily_limit_total"] = total_limit
376
+
377
+ log.info(f"Updated daily limits for {normalized_filename}: "
378
+ f"Gemini 2.5 Pro = {stats.get('daily_limit_gemini_2_5_pro', 100)}, "
379
+ f"Total = {stats.get('daily_limit_total', 1000)}")
380
+
381
+ except Exception as e:
382
+ log.error(f"Failed to update daily limits: {e}")
383
+ raise
384
+
385
+ await self._save_stats()
386
+
387
+ async def reset_stats(self, filename: str = None):
388
+ """Reset usage statistics."""
389
+ if not self._initialized:
390
+ await self.initialize()
391
+
392
+ with self._lock:
393
+ if filename:
394
+ normalized_filename = self._normalize_filename(filename)
395
+ if normalized_filename in self._stats_cache:
396
+ # Manual reset: reset counters and set new next reset time
397
+ next_reset = _get_next_utc_7am()
398
+ self._stats_cache[normalized_filename].update({
399
+ "gemini_2_5_pro_calls": 0,
400
+ "total_calls": 0,
401
+ "next_reset_time": next_reset.isoformat()
402
+ })
403
+ log.info(f"Reset usage statistics for {normalized_filename}")
404
+ else:
405
+ # Reset all statistics
406
+ next_reset = _get_next_utc_7am()
407
+ for filename, stats in self._stats_cache.items():
408
+ stats.update({
409
+ "gemini_2_5_pro_calls": 0,
410
+ "total_calls": 0,
411
+ "next_reset_time": next_reset.isoformat()
412
+ })
413
+ log.info("Reset usage statistics for all credential files")
414
+
415
+ await self._save_stats()
416
+
417
+ # Global instance
418
+ _usage_stats_instance: Optional[UsageStats] = None
419
+
420
+ async def get_usage_stats_instance() -> UsageStats:
421
+ """Get the global usage statistics instance."""
422
+ global _usage_stats_instance
423
+ if _usage_stats_instance is None:
424
+ _usage_stats_instance = UsageStats()
425
+ await _usage_stats_instance.initialize()
426
+ return _usage_stats_instance
427
+
428
+
429
+ async def record_successful_call(filename: str, model_name: str):
430
+ """Convenience function to record a successful API call."""
431
+ stats = await get_usage_stats_instance()
432
+ await stats.record_successful_call(filename, model_name)
433
+
434
+
435
+ async def get_usage_stats(filename: str = None) -> Dict[str, Any]:
436
+ """Convenience function to get usage statistics."""
437
+ stats = await get_usage_stats_instance()
438
+ return await stats.get_usage_stats(filename)
439
+
440
+
441
+ async def get_aggregated_stats() -> Dict[str, Any]:
442
+ """Convenience function to get aggregated statistics."""
443
+ stats = await get_usage_stats_instance()
444
+ return await stats.get_aggregated_stats()
src/utils.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+
3
+ CLI_VERSION = "0.1.5" # Match current gemini-cli version
4
+
5
+ def get_user_agent():
6
+ """Generate User-Agent string matching gemini-cli format."""
7
+ version = CLI_VERSION
8
+ system = platform.system()
9
+ arch = platform.machine()
10
+ return f"GeminiCLI/{version} ({system}; {arch})"
src/web_routes.py ADDED
@@ -0,0 +1,1738 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Web路由模块 - 处理认证相关的HTTP请求和控制面板功能
3
+ 用于与上级web.py集成
4
+ """
5
+ import asyncio
6
+ import datetime
7
+ import io
8
+ import json
9
+ import os
10
+ import time
11
+ import zipfile
12
+ from collections import deque
13
+ from typing import List, Optional, Dict, Any
14
+ from urllib.parse import urlparse
15
+
16
+ from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, WebSocket, WebSocketDisconnect, Request
17
+ from fastapi.responses import HTMLResponse, JSONResponse, FileResponse, Response
18
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
19
+ from pydantic import BaseModel
20
+ from starlette.websockets import WebSocketState
21
+ import toml
22
+ import zipfile
23
+ import httpx
24
+
25
+ import config
26
+ from log import log
27
+ from .auth import (
28
+ create_auth_url, get_auth_status,
29
+ verify_password, generate_auth_token, verify_auth_token,
30
+ asyncio_complete_auth_flow, complete_auth_flow_from_callback_url,
31
+ load_credentials_from_env, clear_env_credentials
32
+ )
33
+ from .credential_manager import CredentialManager
34
+ from .usage_stats import get_usage_stats, get_aggregated_stats, get_usage_stats_instance
35
+ from .storage_adapter import get_storage_adapter
36
+
37
+ # 创建路由器
38
+ router = APIRouter()
39
+ security = HTTPBearer()
40
+
41
+ # 创建credential manager实例
42
+ credential_manager = CredentialManager()
43
+
44
+ # WebSocket连接管理
45
+
46
+ class ConnectionManager:
47
+ def __init__(self, max_connections: int = 3): # 进一步降低最大连接数
48
+ # 使用双端队列严格限制内存使用
49
+ self.active_connections: deque = deque(maxlen=max_connections)
50
+ self.max_connections = max_connections
51
+ self._last_cleanup = 0
52
+ self._cleanup_interval = 120 # 120秒清理一次死连接
53
+
54
+ async def connect(self, websocket: WebSocket):
55
+ # 自动清理死连接
56
+ self._auto_cleanup()
57
+
58
+ # 限制最大连接数,防止内存无限增长
59
+ if len(self.active_connections) >= self.max_connections:
60
+ await websocket.close(code=1008, reason="Too many connections")
61
+ return False
62
+
63
+ await websocket.accept()
64
+ self.active_connections.append(websocket)
65
+ log.debug(f"WebSocket连接建立,当前连接数: {len(self.active_connections)}")
66
+ return True
67
+
68
+ def disconnect(self, websocket: WebSocket):
69
+ # 使用更高效的方式移除连接
70
+ try:
71
+ self.active_connections.remove(websocket)
72
+ except ValueError:
73
+ pass # 连接已不存在
74
+ log.debug(f"WebSocket连接断开,当前连接数: {len(self.active_connections)}")
75
+
76
+ async def send_personal_message(self, message: str, websocket: WebSocket):
77
+ try:
78
+ await websocket.send_text(message)
79
+ except Exception:
80
+ self.disconnect(websocket)
81
+
82
+ async def broadcast(self, message: str):
83
+ # 使用更高效的方式处理广播,避免索引操作
84
+ dead_connections = []
85
+ for conn in self.active_connections:
86
+ try:
87
+ await conn.send_text(message)
88
+ except Exception:
89
+ dead_connections.append(conn)
90
+
91
+ # 批量移除死连接
92
+ for dead_conn in dead_connections:
93
+ self.disconnect(dead_conn)
94
+
95
+ def _auto_cleanup(self):
96
+ """自动清理死连接"""
97
+ current_time = time.time()
98
+ if current_time - self._last_cleanup > self._cleanup_interval:
99
+ self.cleanup_dead_connections()
100
+ self._last_cleanup = current_time
101
+
102
+ def cleanup_dead_connections(self):
103
+ """清理已断开的连接"""
104
+ original_count = len(self.active_connections)
105
+ # 使用列表推导式过滤活跃连接,更高效
106
+ alive_connections = deque([
107
+ conn for conn in self.active_connections
108
+ if hasattr(conn, 'client_state') and conn.client_state != WebSocketState.DISCONNECTED
109
+ ], maxlen=self.max_connections)
110
+
111
+ self.active_connections = alive_connections
112
+ cleaned = original_count - len(self.active_connections)
113
+ if cleaned > 0:
114
+ log.debug(f"清理了 {cleaned} 个死连接,剩余连接数: {len(self.active_connections)}")
115
+
116
+ manager = ConnectionManager()
117
+
118
+
119
+ async def ensure_credential_manager_initialized():
120
+ """确保credential manager已初始化"""
121
+ if not credential_manager._initialized:
122
+ await credential_manager.initialize()
123
+
124
+ async def get_credential_manager():
125
+ """获取全局凭证管理器实例"""
126
+ global credential_manager
127
+ if not credential_manager:
128
+ credential_manager = CredentialManager()
129
+ await credential_manager.initialize()
130
+ return credential_manager
131
+
132
+ async def authenticate(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
133
+ """验证用户密码(控制面板使用)"""
134
+ from config import get_panel_password
135
+ password = await get_panel_password()
136
+ token = credentials.credentials
137
+ if token != password:
138
+ raise HTTPException(status_code=403, detail="密码错误")
139
+ return token
140
+
141
+ class LoginRequest(BaseModel):
142
+ password: str
143
+
144
+ class AuthStartRequest(BaseModel):
145
+ project_id: Optional[str] = None # 现在是可选的
146
+ get_all_projects: Optional[bool] = False # 是否为所有项目获取凭证
147
+
148
+ class AuthCallbackRequest(BaseModel):
149
+ project_id: Optional[str] = None # 现在是可选的
150
+ get_all_projects: Optional[bool] = False # 是否为所有项目获取凭证
151
+
152
+ class AuthCallbackUrlRequest(BaseModel):
153
+ callback_url: str # OAuth回调完整URL
154
+ project_id: Optional[str] = None # 可选的项目ID
155
+ get_all_projects: Optional[bool] = False # 是否为所有项目获取凭证
156
+
157
+ class CredFileActionRequest(BaseModel):
158
+ filename: str
159
+ action: str # enable, disable, delete
160
+
161
+ class CredFileBatchActionRequest(BaseModel):
162
+ action: str # "enable", "disable", "delete"
163
+ filenames: List[str] # 批量操作的文件名列表
164
+
165
+ class ConfigSaveRequest(BaseModel):
166
+ config: dict
167
+
168
+
169
+
170
+ def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
171
+ """验证认证令牌"""
172
+ if not verify_auth_token(credentials.credentials):
173
+ raise HTTPException(status_code=401, detail="无效的认证令牌")
174
+ return credentials.credentials
175
+
176
+ def is_mobile_user_agent(user_agent: str) -> bool:
177
+ """检测是否为移动设备用户代理"""
178
+ if not user_agent:
179
+ return False
180
+
181
+ user_agent_lower = user_agent.lower()
182
+ mobile_keywords = [
183
+ 'mobile', 'android', 'iphone', 'ipad', 'ipod',
184
+ 'blackberry', 'windows phone', 'samsung', 'htc',
185
+ 'motorola', 'nokia', 'palm', 'webos', 'opera mini',
186
+ 'opera mobi', 'fennec', 'minimo', 'symbian', 'psp',
187
+ 'nintendo', 'tablet'
188
+ ]
189
+
190
+ return any(keyword in user_agent_lower for keyword in mobile_keywords)
191
+
192
+ @router.get("/", response_class=HTMLResponse)
193
+ @router.get("/v1", response_class=HTMLResponse)
194
+ @router.get("/auth", response_class=HTMLResponse)
195
+ async def serve_control_panel(request: Request):
196
+ """提供统一控制面板(包含认证、文件管理、配置等功能)"""
197
+ try:
198
+ # 获取用户代理并判断是否为移动设备
199
+ user_agent = request.headers.get("user-agent", "")
200
+ is_mobile = is_mobile_user_agent(user_agent)
201
+
202
+ # 根据设备类型选择相应的HTML文件
203
+ if is_mobile:
204
+ html_file_path = "front/control_panel_mobile.html"
205
+ log.info(f"Serving mobile control panel to user-agent: {user_agent}")
206
+ else:
207
+ html_file_path = "front/control_panel.html"
208
+ log.info(f"Serving desktop control panel to user-agent: {user_agent}")
209
+
210
+ with open(html_file_path, "r", encoding="utf-8") as f:
211
+ html_content = f.read()
212
+ return HTMLResponse(content=html_content)
213
+ except FileNotFoundError:
214
+ log.error(f"控制面板页面文件不存在: {html_file_path}")
215
+ # 如果移动端文件不存在,回退到桌面版
216
+ if is_mobile:
217
+ try:
218
+ with open("front/control_panel.html", "r", encoding="utf-8") as f:
219
+ html_content = f.read()
220
+ return HTMLResponse(content=html_content)
221
+ except FileNotFoundError:
222
+ raise HTTPException(status_code=404, detail="控制面板页面不存在")
223
+ else:
224
+ raise HTTPException(status_code=404, detail="控制面板页面不存在")
225
+ except Exception as e:
226
+ log.error(f"加载控制面板页面失败: {e}")
227
+ raise HTTPException(status_code=500, detail="服务器内部错误")
228
+
229
+
230
+ @router.post("/auth/login")
231
+ async def login(request: LoginRequest):
232
+ """用户登录"""
233
+ try:
234
+ if await verify_password(request.password):
235
+ token = generate_auth_token()
236
+ return JSONResponse(content={"token": token, "message": "登录成功"})
237
+ else:
238
+ raise HTTPException(status_code=401, detail="密码错误")
239
+ except HTTPException:
240
+ raise
241
+ except Exception as e:
242
+ log.error(f"登录失败: {e}")
243
+ raise HTTPException(status_code=500, detail=str(e))
244
+
245
+
246
+ @router.post("/auth/start")
247
+ async def start_auth(request: AuthStartRequest, token: str = Depends(verify_token)):
248
+ """开始认证流程,支持自动检测项目ID和批量获取所有项目"""
249
+ try:
250
+ # 检查是否为批量项目模式
251
+ if request.get_all_projects:
252
+ log.info("用户请求批量获取所有项目的凭证...")
253
+ project_id = None # 批量模式下不指定单个项目ID
254
+ else:
255
+ # 如果没有提供项目ID,尝试自动检测
256
+ project_id = request.project_id
257
+ if not project_id:
258
+ log.info("用户未提供项目ID,后续将使用自动检测...")
259
+
260
+ # 使用认证令牌作为用户会话标识
261
+ user_session = token if token else None
262
+ result = await create_auth_url(project_id, user_session, get_all_projects=request.get_all_projects)
263
+
264
+ if result['success']:
265
+ return JSONResponse(content={
266
+ "auth_url": result['auth_url'],
267
+ "state": result['state'],
268
+ "auto_project_detection": result.get('auto_project_detection', False),
269
+ "detected_project_id": result.get('detected_project_id'),
270
+ "get_all_projects": request.get_all_projects
271
+ })
272
+ else:
273
+ raise HTTPException(status_code=500, detail=result['error'])
274
+
275
+ except HTTPException:
276
+ raise
277
+ except Exception as e:
278
+ log.error(f"开始认证流程失败: {e}")
279
+ raise HTTPException(status_code=500, detail=str(e))
280
+
281
+
282
+ @router.post("/auth/callback")
283
+ async def auth_callback(request: AuthCallbackRequest, token: str = Depends(verify_token)):
284
+ """处理认证回调,支持自动检测项目ID和批量获取所有项目"""
285
+ try:
286
+ # 项目ID现在是可选的,在回调处理中进行自动检测
287
+ project_id = request.project_id
288
+ get_all_projects = request.get_all_projects
289
+
290
+ # 使用认证令牌作为用户会话标识
291
+ user_session = token if token else None
292
+ # 异步等待OAuth回调完成
293
+ result = await asyncio_complete_auth_flow(project_id, user_session, get_all_projects=get_all_projects)
294
+
295
+ if result['success']:
296
+ if get_all_projects and result.get('multiple_credentials'):
297
+ # 批量认证成功,返回多个凭证信息
298
+ return JSONResponse(content={
299
+ "multiple_credentials": result['multiple_credentials'],
300
+ "message": "批量认证成功,已为多个项目保存凭证"
301
+ })
302
+ else:
303
+ # 单项目认证成功
304
+ return JSONResponse(content={
305
+ "credentials": result['credentials'],
306
+ "file_path": result['file_path'],
307
+ "message": "认证成功,凭证已保存",
308
+ "auto_detected_project": result.get('auto_detected_project', False)
309
+ })
310
+ else:
311
+ # 如果需要手动项目ID或项目选择,在响应中标明
312
+ if result.get('requires_manual_project_id'):
313
+ # 使用JSON响应
314
+ return JSONResponse(
315
+ status_code=400,
316
+ content={
317
+ "error": result['error'],
318
+ "requires_manual_project_id": True
319
+ }
320
+ )
321
+ elif result.get('requires_project_selection'):
322
+ # 返回项目列表供用户选择
323
+ return JSONResponse(
324
+ status_code=400,
325
+ content={
326
+ "error": result['error'],
327
+ "requires_project_selection": True,
328
+ "available_projects": result['available_projects']
329
+ }
330
+ )
331
+ else:
332
+ raise HTTPException(status_code=400, detail=result['error'])
333
+
334
+ except HTTPException:
335
+ raise
336
+ except Exception as e:
337
+ log.error(f"处理认证回调失败: {e}")
338
+ raise HTTPException(status_code=500, detail=str(e))
339
+
340
+
341
+ @router.post("/auth/callback-url")
342
+ async def auth_callback_url(request: AuthCallbackUrlRequest, token: str = Depends(verify_token)):
343
+ """从回调URL直接完成认证,支持批量获取所有项目"""
344
+ try:
345
+ # 验证URL格式
346
+ if not request.callback_url or not request.callback_url.startswith(('http://', 'https://')):
347
+ raise HTTPException(status_code=400, detail="请提供有效的回调URL")
348
+
349
+ # 从回调URL完成认证
350
+ result = await complete_auth_flow_from_callback_url(
351
+ request.callback_url,
352
+ request.project_id,
353
+ get_all_projects=request.get_all_projects
354
+ )
355
+
356
+ if result['success']:
357
+ if request.get_all_projects and result.get('multiple_credentials'):
358
+ # 批量认证成功,返回多个凭证信息
359
+ return JSONResponse(content={
360
+ "multiple_credentials": result['multiple_credentials'],
361
+ "message": "从回调URL批量认证成功,已为多个项目保存凭证"
362
+ })
363
+ else:
364
+ # 单项目认证成功
365
+ return JSONResponse(content={
366
+ "credentials": result['credentials'],
367
+ "file_path": result['file_path'],
368
+ "message": "从回调URL认证成功,凭证已保存",
369
+ "auto_detected_project": result.get('auto_detected_project', False)
370
+ })
371
+ else:
372
+ # 处理各种错误情况
373
+ if result.get('requires_manual_project_id'):
374
+ return JSONResponse(
375
+ status_code=400,
376
+ content={
377
+ "error": result['error'],
378
+ "requires_manual_project_id": True
379
+ }
380
+ )
381
+ elif result.get('requires_project_selection'):
382
+ return JSONResponse(
383
+ status_code=400,
384
+ content={
385
+ "error": result['error'],
386
+ "requires_project_selection": True,
387
+ "available_projects": result['available_projects']
388
+ }
389
+ )
390
+ else:
391
+ raise HTTPException(status_code=400, detail=result['error'])
392
+
393
+ except HTTPException:
394
+ raise
395
+ except Exception as e:
396
+ log.error(f"从回调URL处理认证失败: {e}")
397
+ raise HTTPException(status_code=500, detail=str(e))
398
+
399
+
400
+ @router.get("/auth/status/{project_id}")
401
+ async def check_auth_status(project_id: str, token: str = Depends(verify_token)):
402
+ """检查认证状态"""
403
+ try:
404
+ if not project_id:
405
+ raise HTTPException(status_code=400, detail="Project ID 不能为空")
406
+
407
+ status = get_auth_status(project_id)
408
+ return JSONResponse(content=status)
409
+
410
+ except Exception as e:
411
+ log.error(f"检查认证状态失败: {e}")
412
+ raise HTTPException(status_code=500, detail=str(e))
413
+
414
+
415
+ async def extract_json_files_from_zip(zip_file: UploadFile) -> List[dict]:
416
+ """从ZIP文件中提取JSON文件"""
417
+ try:
418
+ # 读取ZIP文件内容
419
+ zip_content = await zip_file.read()
420
+
421
+ # 不限制ZIP文件大小,只在处理时控制文件数量
422
+
423
+ files_data = []
424
+
425
+ with zipfile.ZipFile(io.BytesIO(zip_content), 'r') as zip_ref:
426
+ # 获取ZIP中的所有文件
427
+ file_list = zip_ref.namelist()
428
+ json_files = [f for f in file_list if f.endswith('.json') and not f.startswith('__MACOSX/')]
429
+
430
+ if not json_files:
431
+ raise HTTPException(status_code=400, detail="ZIP文件中没有找到JSON文件")
432
+
433
+ log.info(f"从ZIP文件 {zip_file.filename} 中找到 {len(json_files)} 个JSON文件")
434
+
435
+ for json_filename in json_files:
436
+ try:
437
+ # 读取JSON文件内容
438
+ with zip_ref.open(json_filename) as json_file:
439
+ content = json_file.read()
440
+
441
+ try:
442
+ content_str = content.decode('utf-8')
443
+ except UnicodeDecodeError:
444
+ log.warning(f"跳过编码错误的文件: {json_filename}")
445
+ continue
446
+
447
+ # 使用原始文件名(去掉路径)
448
+ filename = os.path.basename(json_filename)
449
+ files_data.append({
450
+ 'filename': filename,
451
+ 'content': content_str
452
+ })
453
+
454
+ except Exception as e:
455
+ log.warning(f"处理ZIP中的文件 {json_filename} 时出错: {e}")
456
+ continue
457
+
458
+ log.info(f"成功从ZIP文件中提取 {len(files_data)} 个有效的JSON文件")
459
+ return files_data
460
+
461
+ except zipfile.BadZipFile:
462
+ raise HTTPException(status_code=400, detail="无效的ZIP文件格式")
463
+ except Exception as e:
464
+ log.error(f"处理ZIP文件失败: {e}")
465
+ raise HTTPException(status_code=500, detail=f"处理ZIP文件失败: {str(e)}")
466
+
467
+
468
+ @router.post("/auth/upload")
469
+ async def upload_credentials(files: List[UploadFile] = File(...), token: str = Depends(verify_token)):
470
+ """批量上传认证文件"""
471
+ try:
472
+ if not files:
473
+ raise HTTPException(status_code=400, detail="请选择要上传的文件")
474
+
475
+ # 检查文件数量限制
476
+ if len(files) > 100:
477
+ raise HTTPException(status_code=400, detail=f"文件数量过多,最多支持100个文件,当前:{len(files)}个")
478
+
479
+ files_data = []
480
+ for file in files:
481
+ # 检查文件类型:支持JSON和ZIP
482
+ if file.filename.endswith('.zip'):
483
+ # 处理ZIP文件
484
+ zip_files_data = await extract_json_files_from_zip(file)
485
+ files_data.extend(zip_files_data)
486
+ log.info(f"从ZIP文件 {file.filename} 中提取了 {len(zip_files_data)} 个JSON文件")
487
+
488
+ elif file.filename.endswith('.json'):
489
+ # 处理单个JSON文件
490
+ # 流式读取文件内容
491
+ content_chunks = []
492
+ while True:
493
+ chunk = await file.read(8192) # 8KB chunks
494
+ if not chunk:
495
+ break
496
+ content_chunks.append(chunk)
497
+
498
+ content = b''.join(content_chunks)
499
+ try:
500
+ content_str = content.decode('utf-8')
501
+ except UnicodeDecodeError:
502
+ raise HTTPException(status_code=400, detail=f"文件 {file.filename} 编码格式不支持")
503
+
504
+ files_data.append({
505
+ 'filename': file.filename,
506
+ 'content': content_str
507
+ })
508
+ else:
509
+ raise HTTPException(status_code=400, detail=f"文件 {file.filename} 格式不支持,只支持JSON和ZIP文件")
510
+
511
+ # 获取存储适配器
512
+ storage_adapter = await get_storage_adapter()
513
+
514
+ # 分批处理大量文件以提高稳定性
515
+ batch_size = 1000 # 每批处理1000个文件
516
+ all_results = []
517
+ total_success = 0
518
+
519
+ for i in range(0, len(files_data), batch_size):
520
+ batch_files = files_data[i:i + batch_size]
521
+
522
+ # 使用并发处理提升文件上传性能
523
+ async def process_single_file(file_data):
524
+ """处理单个文件的并发函数"""
525
+ try:
526
+ filename = file_data['filename']
527
+ content_str = file_data['content']
528
+
529
+ # 解析JSON内容
530
+ credential_data = json.loads(content_str)
531
+
532
+ # 存储到统一存储系统
533
+ success = await storage_adapter.store_credential(filename, credential_data)
534
+ if success:
535
+ # 创建默认状态记录(如果不存在)
536
+ try:
537
+ import time
538
+ default_state = {
539
+ "error_codes": [],
540
+ "disabled": False,
541
+ "last_success": time.time(),
542
+ "user_email": None,
543
+ "gemini_2_5_pro_calls": 0,
544
+ "total_calls": 0,
545
+ "next_reset_time": None,
546
+ "daily_limit_gemini_2_5_pro": 100,
547
+ "daily_limit_total": 1000
548
+ }
549
+ # 只在状态不存在时创建,避免覆盖现有状态
550
+ # 检查数据库中是否真正存在状态记录
551
+ all_states = await storage_adapter.get_all_credential_states()
552
+ if filename not in all_states:
553
+ await storage_adapter.update_credential_state(filename, default_state)
554
+ log.debug(f"Created default state for new credential: {filename}")
555
+ except Exception as e:
556
+ log.warning(f"Failed to create default state for {filename}: {e}")
557
+
558
+ log.debug(f"成功上传凭证文件: {filename}")
559
+ return {
560
+ "filename": filename,
561
+ "status": "success",
562
+ "message": "上传成功"
563
+ }
564
+ else:
565
+ return {
566
+ "filename": filename,
567
+ "status": "error",
568
+ "message": "存储失败"
569
+ }
570
+
571
+ except json.JSONDecodeError as e:
572
+ return {
573
+ "filename": file_data['filename'],
574
+ "status": "error",
575
+ "message": f"JSON格式错误: {str(e)}"
576
+ }
577
+ except Exception as e:
578
+ return {
579
+ "filename": file_data['filename'],
580
+ "status": "error",
581
+ "message": f"处理失败: {str(e)}"
582
+ }
583
+
584
+ # 并发处理这一批文件
585
+ log.info(f"开始并发处理 {len(batch_files)} 个文件...")
586
+ concurrent_tasks = [process_single_file(file_data) for file_data in batch_files]
587
+ batch_results = await asyncio.gather(*concurrent_tasks, return_exceptions=True)
588
+
589
+ # 处理异常结果
590
+ processed_results = []
591
+ batch_uploaded_count = 0
592
+ for result in batch_results:
593
+ if isinstance(result, Exception):
594
+ processed_results.append({
595
+ "filename": "unknown",
596
+ "status": "error",
597
+ "message": f"处理异常: {str(result)}"
598
+ })
599
+ else:
600
+ processed_results.append(result)
601
+ if result["status"] == "success":
602
+ batch_uploaded_count += 1
603
+
604
+ batch_results = processed_results
605
+
606
+ all_results.extend(batch_results)
607
+ total_success += batch_uploaded_count
608
+
609
+ # 记录批次进度
610
+ batch_num = (i // batch_size) + 1
611
+ total_batches = (len(files_data) + batch_size - 1) // batch_size
612
+ log.info(f"批次 {batch_num}/{total_batches} 完成: 成功 {batch_uploaded_count}/{len(batch_files)} 个文件")
613
+
614
+ if total_success > 0:
615
+ return JSONResponse(content={
616
+ "uploaded_count": total_success,
617
+ "total_count": len(files_data),
618
+ "results": all_results,
619
+ "message": f"批量上传完成: 成功 {total_success}/{len(files_data)} 个文件"
620
+ })
621
+ else:
622
+ raise HTTPException(status_code=400, detail="没有文件上传成功")
623
+
624
+ except HTTPException:
625
+ raise
626
+ except Exception as e:
627
+ log.error(f"批量上传失败: {e}")
628
+ raise HTTPException(status_code=500, detail=str(e))
629
+
630
+
631
+ @router.get("/creds/status")
632
+ async def get_creds_status(token: str = Depends(verify_token)):
633
+ """获取所有凭证文件的状态"""
634
+ try:
635
+ await ensure_credential_manager_initialized()
636
+
637
+ # 获取存储适配器
638
+ storage_adapter = await get_storage_adapter()
639
+
640
+ # 获取所有凭证和状态
641
+ all_credentials = await storage_adapter.list_credentials()
642
+ all_states = await credential_manager.get_creds_status()
643
+
644
+ # 获取后端信息(一次性获取,避免重复查询)
645
+ backend_info = await storage_adapter.get_backend_info()
646
+ backend_type = backend_info.get("backend_type", "unknown")
647
+
648
+ # 并发处理所有凭证的数据获取(状态已获取,无需重复处理)
649
+ async def process_credential_data(filename):
650
+ """并发处理单个凭证的数据获取"""
651
+ file_status = all_states.get(filename)
652
+
653
+ # 如果没有状态记录,创建默认状态
654
+ if not file_status:
655
+ try:
656
+ import time
657
+ default_state = {
658
+ "error_codes": [],
659
+ "disabled": False,
660
+ "last_success": time.time(),
661
+ "user_email": None,
662
+ "gemini_2_5_pro_calls": 0,
663
+ "total_calls": 0,
664
+ "next_reset_time": None,
665
+ "daily_limit_gemini_2_5_pro": 100,
666
+ "daily_limit_total": 1000
667
+ }
668
+ await storage_adapter.update_credential_state(filename, default_state)
669
+ file_status = default_state
670
+ log.debug(f"为凭证 {filename} 创建了默认状态记录")
671
+ except Exception as e:
672
+ log.warning(f"无法为凭证 {filename} 创建状态记录: {e}")
673
+ # 创建临时状态用于显示
674
+ file_status = {
675
+ "error_codes": [],
676
+ "disabled": False,
677
+ "last_success": time.time(),
678
+ "user_email": None,
679
+ "gemini_2_5_pro_calls": 0,
680
+ "total_calls": 0,
681
+ "next_reset_time": None,
682
+ "daily_limit_gemini_2_5_pro": 100,
683
+ "daily_limit_total": 1000
684
+ }
685
+
686
+ try:
687
+ # 从存储获取凭证数据
688
+ credential_data = await storage_adapter.get_credential(filename)
689
+ if credential_data:
690
+ result = {
691
+ "status": file_status,
692
+ "content": credential_data,
693
+ "filename": os.path.basename(filename),
694
+ "backend_type": backend_type, # 复用backend信息
695
+ "user_email": file_status.get("user_email")
696
+ }
697
+
698
+ # 如果是文件模式,添加文件元数据
699
+ if backend_type == "file" and os.path.exists(filename):
700
+ result.update({
701
+ "size": os.path.getsize(filename),
702
+ "modified_time": os.path.getmtime(filename)
703
+ })
704
+
705
+ return filename, result
706
+ else:
707
+ return filename, {
708
+ "status": file_status,
709
+ "content": None,
710
+ "filename": os.path.basename(filename),
711
+ "error": "凭证数据不存在"
712
+ }
713
+
714
+ except Exception as e:
715
+ log.error(f"读取凭证��件失败 {filename}: {e}")
716
+ return filename, {
717
+ "status": file_status,
718
+ "content": None,
719
+ "filename": os.path.basename(filename),
720
+ "error": str(e)
721
+ }
722
+
723
+ # 并发处理所有凭证数据获取
724
+ log.debug(f"开始并发获取 {len(all_credentials)} 个凭证数据...")
725
+ concurrent_tasks = [process_credential_data(filename) for filename in all_credentials]
726
+ results = await asyncio.gather(*concurrent_tasks, return_exceptions=True)
727
+
728
+ # 组装结果
729
+ creds_info = {}
730
+ for result in results:
731
+ if isinstance(result, Exception):
732
+ log.error(f"处理凭证状态异常: {result}")
733
+ else:
734
+ filename, credential_info = result
735
+ creds_info[filename] = credential_info
736
+
737
+ return JSONResponse(content={"creds": creds_info})
738
+
739
+ except Exception as e:
740
+ log.error(f"获取凭证状态失败: {e}")
741
+ raise HTTPException(status_code=500, detail=str(e))
742
+
743
+
744
+ @router.post("/creds/action")
745
+ async def creds_action(request: CredFileActionRequest, token: str = Depends(verify_token)):
746
+ """对凭证文件执行操作(启用/禁用/删除)"""
747
+ try:
748
+ await ensure_credential_manager_initialized()
749
+
750
+ log.info(f"Received request: {request}")
751
+
752
+ filename = request.filename
753
+ action = request.action
754
+
755
+ log.info(f"Performing action '{action}' on file: {filename}")
756
+
757
+ # 验证文件名
758
+ if not filename.endswith('.json'):
759
+ log.error(f"Invalid filename: {filename} (not a .json file)")
760
+ raise HTTPException(status_code=400, detail=f"无效的文件名: {filename}")
761
+
762
+ # 获取存储适配器
763
+ storage_adapter = await get_storage_adapter()
764
+
765
+ # 检查凭证是否存在
766
+ credential_data = await storage_adapter.get_credential(filename)
767
+ if not credential_data:
768
+ log.error(f"Credential not found: {filename}")
769
+ raise HTTPException(status_code=404, detail="凭证文件不存在")
770
+
771
+ if action == "enable":
772
+ log.info(f"Web request: ENABLING file {filename}")
773
+ await credential_manager.set_cred_disabled(filename, False)
774
+ log.info(f"Web request: ENABLED file {filename} successfully")
775
+ return JSONResponse(content={"message": f"已启用凭证文件 {os.path.basename(filename)}"})
776
+
777
+ elif action == "disable":
778
+ log.info(f"Web request: DISABLING file {filename}")
779
+ await credential_manager.set_cred_disabled(filename, True)
780
+ log.info(f"Web request: DISABLED file {filename} successfully")
781
+ return JSONResponse(content={"message": f"已禁用凭证文件 {os.path.basename(filename)}"})
782
+
783
+ elif action == "delete":
784
+ try:
785
+ # 使用存储适配器删除凭证
786
+ success = await storage_adapter.delete_credential(filename)
787
+ if success:
788
+ log.info(f"Successfully deleted credential: {filename}")
789
+ return JSONResponse(content={"message": f"已删除凭证文件 {os.path.basename(filename)}"})
790
+ else:
791
+ raise HTTPException(status_code=500, detail="删除凭证失败")
792
+ except Exception as e:
793
+ log.error(f"Error deleting credential {filename}: {e}")
794
+ raise HTTPException(status_code=500, detail=f"删除文件失败: {str(e)}")
795
+
796
+ else:
797
+ raise HTTPException(status_code=400, detail="无效的操作类型")
798
+
799
+ except HTTPException:
800
+ raise
801
+ except Exception as e:
802
+ log.error(f"凭证文件操作失败: {e}")
803
+ raise HTTPException(status_code=500, detail=str(e))
804
+
805
+
806
+ @router.post("/creds/batch-action")
807
+ async def creds_batch_action(request: CredFileBatchActionRequest, token: str = Depends(verify_token)):
808
+ """批量对凭证文件执行操作(启用/禁用/删除)"""
809
+ try:
810
+ await ensure_credential_manager_initialized()
811
+
812
+ action = request.action
813
+ filenames = request.filenames
814
+
815
+ if not filenames:
816
+ raise HTTPException(status_code=400, detail="文件名列表不能为空")
817
+
818
+ log.info(f"Performing batch action '{action}' on {len(filenames)} files")
819
+
820
+ success_count = 0
821
+ errors = []
822
+
823
+ # 获取存储适配器
824
+ storage_adapter = await get_storage_adapter()
825
+
826
+ for filename in filenames:
827
+ try:
828
+ # 验证文件名安全性
829
+ if not filename.endswith('.json'):
830
+ errors.append(f"{filename}: 无效的文件类型")
831
+ continue
832
+
833
+ # 检查凭证是否存在
834
+ credential_data = await storage_adapter.get_credential(filename)
835
+ if not credential_data:
836
+ errors.append(f"{filename}: 凭证不存在")
837
+ continue
838
+
839
+ # 执行相应操作
840
+ if action == "enable":
841
+ await credential_manager.set_cred_disabled(filename, False)
842
+ success_count += 1
843
+
844
+ elif action == "disable":
845
+ await credential_manager.set_cred_disabled(filename, True)
846
+ success_count += 1
847
+
848
+ elif action == "delete":
849
+ try:
850
+ # 使用存储适配器删除凭证
851
+ delete_success = await storage_adapter.delete_credential(filename)
852
+ if delete_success:
853
+ success_count += 1
854
+ log.info(f"Successfully deleted credential in batch: {filename}")
855
+ else:
856
+ errors.append(f"{filename}: 删除失败")
857
+ continue
858
+ except Exception as e:
859
+ errors.append(f"{filename}: 删除文件失败 - {str(e)}")
860
+ continue
861
+ else:
862
+ errors.append(f"{filename}: 无效的操作类型")
863
+ continue
864
+
865
+ except Exception as e:
866
+ log.error(f"Processing {filename} failed: {e}")
867
+ errors.append(f"{filename}: 处理失败 - {str(e)}")
868
+ continue
869
+
870
+ # 构建返回消息
871
+ result_message = f"批量操作完成:成功处理 {success_count}/{len(filenames)} 个文件"
872
+ if errors:
873
+ result_message += f"\n错误详情:\n" + "\n".join(errors)
874
+
875
+ response_data = {
876
+ "success_count": success_count,
877
+ "total_count": len(filenames),
878
+ "errors": errors,
879
+ "message": result_message
880
+ }
881
+
882
+ return JSONResponse(content=response_data)
883
+
884
+ except HTTPException:
885
+ raise
886
+ except Exception as e:
887
+ log.error(f"批量凭证文件操作失败: {e}")
888
+ raise HTTPException(status_code=500, detail=str(e))
889
+
890
+
891
+ @router.get("/creds/download/{filename}")
892
+ async def download_cred_file(filename: str, token: str = Depends(verify_token)):
893
+ """下载单个凭证文件"""
894
+ try:
895
+ # 验证文件名安全性
896
+ if not filename.endswith('.json'):
897
+ raise HTTPException(status_code=404, detail="无效的文件名")
898
+
899
+ # 获取存储适配器
900
+ storage_adapter = await get_storage_adapter()
901
+
902
+ # 从存储系统获取凭证数据
903
+ credential_data = await storage_adapter.get_credential(filename)
904
+ if not credential_data:
905
+ raise HTTPException(status_code=404, detail="文件不存在")
906
+
907
+ # 转换为JSON字符串
908
+ content = json.dumps(credential_data, ensure_ascii=False, indent=2)
909
+
910
+ from fastapi.responses import Response
911
+ return Response(
912
+ content=content,
913
+ media_type="application/json",
914
+ headers={"Content-Disposition": f"attachment; filename={filename}"}
915
+ )
916
+
917
+ except HTTPException:
918
+ raise
919
+ except Exception as e:
920
+ log.error(f"下载凭证文件失败: {e}")
921
+ raise HTTPException(status_code=500, detail=str(e))
922
+
923
+
924
+ @router.post("/creds/fetch-email/{filename}")
925
+ async def fetch_user_email(filename: str, token: str = Depends(verify_token)):
926
+ """获取指定凭证文件的用户邮箱地址"""
927
+ try:
928
+ await ensure_credential_manager_initialized()
929
+
930
+ # 标准化文件名(只保留文件名部分)
931
+ import os
932
+ filename_only = os.path.basename(filename)
933
+ if not filename_only.endswith('.json'):
934
+ raise HTTPException(status_code=404, detail="无效的文件名")
935
+
936
+ # 检查凭证是否存在于存储系统中
937
+ storage_adapter = await get_storage_adapter()
938
+ credential_data = await storage_adapter.get_credential(filename_only)
939
+ if not credential_data:
940
+ raise HTTPException(status_code=404, detail="凭证文件不存在")
941
+
942
+ # 获取用户邮箱(使用凭证名称而不是文件路径)
943
+ email = await credential_manager.get_or_fetch_user_email(filename_only)
944
+
945
+ if email:
946
+ return JSONResponse(content={
947
+ "filename": filename_only,
948
+ "user_email": email,
949
+ "message": "成功获取用户邮箱"
950
+ })
951
+ else:
952
+ return JSONResponse(content={
953
+ "filename": filename_only,
954
+ "user_email": None,
955
+ "message": "无法获取用户邮箱,可能凭证已过期或权限不足"
956
+ }, status_code=400)
957
+
958
+ except HTTPException:
959
+ raise
960
+ except Exception as e:
961
+ log.error(f"获取用户邮箱失败: {e}")
962
+ raise HTTPException(status_code=500, detail=str(e))
963
+
964
+ @router.post("/creds/refresh-all-emails")
965
+ async def refresh_all_user_emails(token: str = Depends(verify_token)):
966
+ """刷新所有凭证文件的用户邮箱地址"""
967
+ try:
968
+ await ensure_credential_manager_initialized()
969
+
970
+ # 获取存储适配器
971
+ storage_adapter = await get_storage_adapter()
972
+
973
+ # 获取所有凭证文件
974
+ credential_filenames = await storage_adapter.list_credentials()
975
+
976
+ results = []
977
+ success_count = 0
978
+
979
+ for filename in credential_filenames:
980
+ try:
981
+ email = await credential_manager.get_or_fetch_user_email(filename)
982
+ if email:
983
+ success_count += 1
984
+ results.append({
985
+ "filename": os.path.basename(filename),
986
+ "user_email": email,
987
+ "success": True
988
+ })
989
+ else:
990
+ results.append({
991
+ "filename": os.path.basename(filename),
992
+ "user_email": None,
993
+ "success": False,
994
+ "error": "无法获取邮箱"
995
+ })
996
+ except Exception as e:
997
+ results.append({
998
+ "filename": os.path.basename(filename),
999
+ "user_email": None,
1000
+ "success": False,
1001
+ "error": str(e)
1002
+ })
1003
+
1004
+ return JSONResponse(content={
1005
+ "success_count": success_count,
1006
+ "total_count": len(credential_filenames),
1007
+ "results": results,
1008
+ "message": f"成功获取 {success_count}/{len(credential_filenames)} 个邮箱地址"
1009
+ })
1010
+
1011
+ except Exception as e:
1012
+ log.error(f"批量获取用户邮箱失败: {e}")
1013
+ raise HTTPException(status_code=500, detail=str(e))
1014
+
1015
+ @router.get("/creds/download-all")
1016
+ async def download_all_creds(token: str = Depends(verify_token)):
1017
+ """打包下载所有凭证文件"""
1018
+ try:
1019
+ # 获取存储适配器
1020
+ storage_adapter = await get_storage_adapter()
1021
+
1022
+ # 获取所有凭证文件列表
1023
+ credential_filenames = await storage_adapter.list_credentials()
1024
+
1025
+ if not credential_filenames:
1026
+ raise HTTPException(status_code=404, detail="没有找到凭证文件")
1027
+
1028
+ # 创建内存中的ZIP文件
1029
+ zip_buffer = io.BytesIO()
1030
+
1031
+ with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
1032
+ # 遍历所有凭证文件
1033
+ for filename in credential_filenames:
1034
+ try:
1035
+ credential_data = await storage_adapter.get_credential(filename)
1036
+ if credential_data:
1037
+ # 转换为JSON字符串
1038
+ content = json.dumps(credential_data, ensure_ascii=False, indent=2)
1039
+
1040
+ # 添加到ZIP文件中
1041
+ zip_file.writestr(os.path.basename(filename), content)
1042
+ log.debug(f"已添加到ZIP: {filename}")
1043
+ except Exception as e:
1044
+ log.warning(f"处理凭证文件 {filename} 时出错: {e}")
1045
+ continue
1046
+
1047
+ zip_buffer.seek(0)
1048
+ return Response(
1049
+ content=zip_buffer.getvalue(),
1050
+ media_type="application/zip",
1051
+ headers={"Content-Disposition": "attachment; filename=credentials.zip"}
1052
+ )
1053
+
1054
+ except Exception as e:
1055
+ log.error(f"打包下载失败: {e}")
1056
+ raise HTTPException(status_code=500, detail=str(e))
1057
+
1058
+
1059
+ @router.get("/config/get")
1060
+ async def get_config(token: str = Depends(verify_token)):
1061
+ """获取当前配置"""
1062
+ try:
1063
+ await ensure_credential_manager_initialized()
1064
+
1065
+ # 导入配置相关模块
1066
+
1067
+ # 读取当前配置(包括环境变量和TOML文件中的配置)
1068
+ current_config = {}
1069
+ env_locked = []
1070
+
1071
+ # 基础配置
1072
+ current_config["code_assist_endpoint"] = await config.get_code_assist_endpoint()
1073
+ current_config["credentials_dir"] = await config.get_credentials_dir()
1074
+ current_config["proxy"] = await config.get_proxy_config() or ""
1075
+
1076
+ # 代理端点配置
1077
+ current_config["oauth_proxy_url"] = await config.get_oauth_proxy_url()
1078
+ current_config["googleapis_proxy_url"] = await config.get_googleapis_proxy_url()
1079
+ current_config["resource_manager_api_url"] = await config.get_resource_manager_api_url()
1080
+ current_config["service_usage_api_url"] = await config.get_service_usage_api_url()
1081
+
1082
+ # 检查环境变量锁定状态
1083
+ if os.getenv("CODE_ASSIST_ENDPOINT"):
1084
+ env_locked.append("code_assist_endpoint")
1085
+ if os.getenv("CREDENTIALS_DIR"):
1086
+ env_locked.append("credentials_dir")
1087
+ if os.getenv("PROXY"):
1088
+ env_locked.append("proxy")
1089
+ if os.getenv("OAUTH_PROXY_URL"):
1090
+ env_locked.append("oauth_proxy_url")
1091
+ if os.getenv("GOOGLEAPIS_PROXY_URL"):
1092
+ env_locked.append("googleapis_proxy_url")
1093
+ if os.getenv("RESOURCE_MANAGER_API_URL"):
1094
+ env_locked.append("resource_manager_api_url")
1095
+ if os.getenv("SERVICE_USAGE_API_URL"):
1096
+ env_locked.append("service_usage_api_url")
1097
+
1098
+ # 自动封禁配置
1099
+ current_config["auto_ban_enabled"] = await config.get_auto_ban_enabled()
1100
+ current_config["auto_ban_error_codes"] = await config.get_auto_ban_error_codes()
1101
+
1102
+ # 检查环境变量锁定状态
1103
+ if os.getenv("AUTO_BAN"):
1104
+ env_locked.append("auto_ban_enabled")
1105
+
1106
+ # 从存储系统读取配置
1107
+ storage_adapter = await get_storage_adapter()
1108
+ storage_config = await storage_adapter.get_all_config()
1109
+
1110
+ # 合并存储系统配置(不覆盖环境变量)
1111
+ for key, value in storage_config.items():
1112
+ if key not in env_locked:
1113
+ current_config[key] = value
1114
+
1115
+ # 性能配置
1116
+ current_config["calls_per_rotation"] = await config.get_calls_per_rotation()
1117
+
1118
+ # 429重试配置
1119
+ current_config["retry_429_max_retries"] = await config.get_retry_429_max_retries()
1120
+ current_config["retry_429_enabled"] = await config.get_retry_429_enabled()
1121
+ current_config["retry_429_interval"] = await config.get_retry_429_interval()
1122
+
1123
+
1124
+ # 抗截断配置
1125
+ current_config["anti_truncation_max_attempts"] = await config.get_anti_truncation_max_attempts()
1126
+
1127
+ # 兼容性配置
1128
+ current_config["compatibility_mode_enabled"] = await config.get_compatibility_mode_enabled()
1129
+
1130
+ # 服务器配置
1131
+ current_config["host"] = await config.get_server_host()
1132
+ current_config["port"] = await config.get_server_port()
1133
+ current_config["api_password"] = await config.get_api_password()
1134
+ current_config["panel_password"] = await config.get_panel_password()
1135
+ current_config["password"] = await config.get_server_password()
1136
+
1137
+ # 检查其他环境变量锁定状态
1138
+ if os.getenv("RETRY_429_MAX_RETRIES"):
1139
+ env_locked.append("retry_429_max_retries")
1140
+ if os.getenv("RETRY_429_ENABLED"):
1141
+ env_locked.append("retry_429_enabled")
1142
+ if os.getenv("RETRY_429_INTERVAL"):
1143
+ env_locked.append("retry_429_interval")
1144
+ if os.getenv("ANTI_TRUNCATION_MAX_ATTEMPTS"):
1145
+ env_locked.append("anti_truncation_max_attempts")
1146
+ if os.getenv("COMPATIBILITY_MODE"):
1147
+ env_locked.append("compatibility_mode_enabled")
1148
+ if os.getenv("HOST"):
1149
+ env_locked.append("host")
1150
+ if os.getenv("PORT"):
1151
+ env_locked.append("port")
1152
+ if os.getenv("API_PASSWORD"):
1153
+ env_locked.append("api_password")
1154
+ if os.getenv("PANEL_PASSWORD"):
1155
+ env_locked.append("panel_password")
1156
+ if os.getenv("PASSWORD"):
1157
+ env_locked.append("password")
1158
+
1159
+ return JSONResponse(content={
1160
+ "config": current_config,
1161
+ "env_locked": env_locked
1162
+ })
1163
+
1164
+ except Exception as e:
1165
+ log.error(f"获取配置失败: {e}")
1166
+ raise HTTPException(status_code=500, detail=str(e))
1167
+
1168
+
1169
+ @router.post("/config/save")
1170
+ async def save_config(request: ConfigSaveRequest, token: str = Depends(verify_token)):
1171
+ """保存配置到TOML文件"""
1172
+ try:
1173
+ await ensure_credential_manager_initialized()
1174
+ new_config = request.config
1175
+
1176
+ log.debug(f"收到的配置数据: {list(new_config.keys())}")
1177
+ log.debug(f"收到的password值: {new_config.get('password', 'NOT_FOUND')}")
1178
+
1179
+ # 验证配置项
1180
+ if "calls_per_rotation" in new_config:
1181
+ if not isinstance(new_config["calls_per_rotation"], int) or new_config["calls_per_rotation"] < 1:
1182
+ raise HTTPException(status_code=400, detail="凭证轮换调用次数必须是大于0的整数")
1183
+
1184
+
1185
+ if "retry_429_max_retries" in new_config:
1186
+ if not isinstance(new_config["retry_429_max_retries"], int) or new_config["retry_429_max_retries"] < 0:
1187
+ raise HTTPException(status_code=400, detail="最大429重试次数必须是大于等于0的整数")
1188
+
1189
+ if "retry_429_enabled" in new_config:
1190
+ if not isinstance(new_config["retry_429_enabled"], bool):
1191
+ raise HTTPException(status_code=400, detail="429重试开关必须是布尔值")
1192
+
1193
+ # 验证新的配置项
1194
+ if "retry_429_interval" in new_config:
1195
+ try:
1196
+ interval = float(new_config["retry_429_interval"])
1197
+ if interval < 0.01 or interval > 10:
1198
+ raise HTTPException(status_code=400, detail="429重试间隔必须在0.01-10秒之间")
1199
+ except (ValueError, TypeError):
1200
+ raise HTTPException(status_code=400, detail="429重试间隔必须是有效的数字")
1201
+
1202
+
1203
+ if "anti_truncation_max_attempts" in new_config:
1204
+ if not isinstance(new_config["anti_truncation_max_attempts"], int) or new_config["anti_truncation_max_attempts"] < 1 or new_config["anti_truncation_max_attempts"] > 10:
1205
+ raise HTTPException(status_code=400, detail="抗截断最大重试次数必须是1-10之间的整数")
1206
+
1207
+ if "compatibility_mode_enabled" in new_config:
1208
+ if not isinstance(new_config["compatibility_mode_enabled"], bool):
1209
+ raise HTTPException(status_code=400, detail="兼容性模式开关必须是布尔值")
1210
+
1211
+ # 验证服务器配置
1212
+ if "host" in new_config:
1213
+ if not isinstance(new_config["host"], str) or not new_config["host"].strip():
1214
+ raise HTTPException(status_code=400, detail="服务器主机地址不能为空")
1215
+
1216
+ if "port" in new_config:
1217
+ if not isinstance(new_config["port"], int) or new_config["port"] < 1 or new_config["port"] > 65535:
1218
+ raise HTTPException(status_code=400, detail="端口号必须是1-65535之间的整数")
1219
+
1220
+ if "api_password" in new_config:
1221
+ if not isinstance(new_config["api_password"], str):
1222
+ raise HTTPException(status_code=400, detail="API访问密码必须是字符串")
1223
+
1224
+ if "panel_password" in new_config:
1225
+ if not isinstance(new_config["panel_password"], str):
1226
+ raise HTTPException(status_code=400, detail="控制面板密码必须是字符串")
1227
+
1228
+ if "password" in new_config:
1229
+ if not isinstance(new_config["password"], str):
1230
+ raise HTTPException(status_code=400, detail="访问密码必须是字符串")
1231
+
1232
+ # 读取现有的配置文件
1233
+ credentials_dir = await config.get_credentials_dir()
1234
+ config_file = os.path.join(credentials_dir, "config.toml")
1235
+ existing_config = {}
1236
+
1237
+ try:
1238
+ if os.path.exists(config_file):
1239
+ with open(config_file, "r", encoding="utf-8") as f:
1240
+ existing_config = toml.load(f)
1241
+ except Exception as e:
1242
+ log.warning(f"读取现有配置文件失败: {e}")
1243
+
1244
+ # 只更新不被环境变量锁定的配置项
1245
+ env_locked_keys = set()
1246
+ if os.getenv("CODE_ASSIST_ENDPOINT"):
1247
+ env_locked_keys.add("code_assist_endpoint")
1248
+ if os.getenv("CREDENTIALS_DIR"):
1249
+ env_locked_keys.add("credentials_dir")
1250
+ if os.getenv("PROXY"):
1251
+ env_locked_keys.add("proxy")
1252
+ if os.getenv("OAUTH_PROXY_URL"):
1253
+ env_locked_keys.add("oauth_proxy_url")
1254
+ if os.getenv("GOOGLEAPIS_PROXY_URL"):
1255
+ env_locked_keys.add("googleapis_proxy_url")
1256
+ if os.getenv("AUTO_BAN"):
1257
+ env_locked_keys.add("auto_ban_enabled")
1258
+ if os.getenv("RETRY_429_MAX_RETRIES"):
1259
+ env_locked_keys.add("retry_429_max_retries")
1260
+ if os.getenv("RETRY_429_ENABLED"):
1261
+ env_locked_keys.add("retry_429_enabled")
1262
+ if os.getenv("RETRY_429_INTERVAL"):
1263
+ env_locked_keys.add("retry_429_interval")
1264
+ if os.getenv("ANTI_TRUNCATION_MAX_ATTEMPTS"):
1265
+ env_locked_keys.add("anti_truncation_max_attempts")
1266
+ if os.getenv("COMPATIBILITY_MODE"):
1267
+ env_locked_keys.add("compatibility_mode_enabled")
1268
+ if os.getenv("HOST"):
1269
+ env_locked_keys.add("host")
1270
+ if os.getenv("PORT"):
1271
+ env_locked_keys.add("port")
1272
+ if os.getenv("API_PASSWORD"):
1273
+ env_locked_keys.add("api_password")
1274
+ if os.getenv("PANEL_PASSWORD"):
1275
+ env_locked_keys.add("panel_password")
1276
+ if os.getenv("PASSWORD"):
1277
+ env_locked_keys.add("password")
1278
+
1279
+ for key, value in new_config.items():
1280
+ if key not in env_locked_keys:
1281
+ existing_config[key] = value
1282
+ if key == 'password':
1283
+ log.debug(f"设置password字段为: {value}")
1284
+ elif key == 'api_password':
1285
+ log.debug(f"设置api_password字段为: {value}")
1286
+ elif key == 'panel_password':
1287
+ log.debug(f"设置panel_password字段为: {value}")
1288
+ log.debug(f"最终保存的existing_config中password = {existing_config.get('password', 'NOT_FOUND')}")
1289
+
1290
+ # 直接使用存储适配器保存配置
1291
+ storage_adapter = await get_storage_adapter()
1292
+ for key, value in existing_config.items():
1293
+ await storage_adapter.set_config(key, value)
1294
+
1295
+ # 验证保存后的结果
1296
+ test_api_password = await config.get_api_password()
1297
+ test_panel_password = await config.get_panel_password()
1298
+ test_password = await config.get_server_password()
1299
+ log.debug(f"保存后立即读取的API密码: {test_api_password}")
1300
+ log.debug(f"保存后立即读取的面板密码: {test_panel_password}")
1301
+ log.debug(f"保存后立即读取的通用密码: {test_password}")
1302
+
1303
+ # 热更新配置到内存中的模块(如果可能)
1304
+ hot_updated = [] # 记录成功热更新的配置项
1305
+ restart_required = [] # 记录需要重启的配置项
1306
+
1307
+ # 支持热更新的配置项:
1308
+ # - calls_per_rotation: 凭证轮换调用次数
1309
+ # - proxy: 网络配置
1310
+ # - log_level: 日志级别
1311
+ # - auto_ban_enabled, auto_ban_error_codes: 自动封禁配置
1312
+ # - retry_429_enabled, retry_429_max_retries, retry_429_interval: 429重试配置
1313
+ # - anti_truncation_max_attempts: 抗截断配置
1314
+ # - compatibility_mode_enabled: 兼容性模式
1315
+ # - api_password, panel_password, password: 访问密码
1316
+ #
1317
+ # 需要重启的配置项:
1318
+ # - host, port: 服务器地址和端口
1319
+ # - log_file: 日志文件路径
1320
+
1321
+ try:
1322
+ # save_config_to_toml已经更新了缓存,不需要reload
1323
+
1324
+ # 1. credential_manager配置通过config模块动态获取,无需手动更新
1325
+ if "calls_per_rotation" in new_config and "calls_per_rotation" not in env_locked_keys:
1326
+ # 新的credential_manager会通过get_calls_per_rotation()动态获取最新配置
1327
+ hot_updated.append("calls_per_rotation")
1328
+
1329
+ # 2. 代理配置(部分热更新)
1330
+ if "proxy" in new_config and "proxy" not in env_locked_keys:
1331
+ hot_updated.append("proxy")
1332
+
1333
+ # 代理端点配置(可热更新)
1334
+ proxy_endpoint_configs = ["oauth_proxy_url", "googleapis_proxy_url"]
1335
+ for config_key in proxy_endpoint_configs:
1336
+ if config_key in new_config and config_key not in env_locked_keys:
1337
+ hot_updated.append(config_key)
1338
+
1339
+
1340
+ # 4. 其他可热更新的配置项
1341
+ hot_updatable_configs = [
1342
+ "auto_ban_enabled", "auto_ban_error_codes",
1343
+ "retry_429_enabled", "retry_429_max_retries", "retry_429_interval",
1344
+ "anti_truncation_max_attempts", "compatibility_mode_enabled"
1345
+ ]
1346
+
1347
+ for config_key in hot_updatable_configs:
1348
+ if config_key in new_config and config_key not in env_locked_keys:
1349
+ hot_updated.append(config_key)
1350
+
1351
+ # 4. 需要重启的配置项
1352
+ restart_required_configs = ["host", "port"]
1353
+ for config_key in restart_required_configs:
1354
+ if config_key in new_config and config_key not in env_locked_keys:
1355
+ restart_required.append(config_key)
1356
+
1357
+ # 5. 密码配置(立即生效)
1358
+ password_configs = ["api_password", "panel_password", "password"]
1359
+ for config_key in password_configs:
1360
+ if config_key in new_config and config_key not in env_locked_keys:
1361
+ hot_updated.append(config_key)
1362
+
1363
+ except Exception as e:
1364
+ log.warning(f"热更新配置失败: {e}")
1365
+
1366
+ # 构建响应消息
1367
+ response_data = {
1368
+ "message": "配置保存成功",
1369
+ "saved_config": {k: v for k, v in new_config.items() if k not in env_locked_keys}
1370
+ }
1371
+
1372
+ # 添加热更新状态信息
1373
+ if hot_updated:
1374
+ response_data["hot_updated"] = hot_updated
1375
+
1376
+ if restart_required:
1377
+ response_data["restart_required"] = restart_required
1378
+ response_data["restart_notice"] = f"以下配置项需要重启服务器才能生效: {', '.join(restart_required)}"
1379
+
1380
+ return JSONResponse(content=response_data)
1381
+
1382
+ except HTTPException:
1383
+ raise
1384
+ except Exception as e:
1385
+ log.error(f"保存配置失败: {e}")
1386
+ raise HTTPException(status_code=500, detail=str(e))
1387
+
1388
+
1389
+ @router.post("/auth/load-env-creds")
1390
+ async def load_env_credentials(token: str = Depends(verify_token)):
1391
+ """从环境变量加载凭证文件"""
1392
+ try:
1393
+ result = await load_credentials_from_env()
1394
+
1395
+ if result['loaded_count'] > 0:
1396
+ return JSONResponse(content={
1397
+ "loaded_count": result['loaded_count'],
1398
+ "total_count": result['total_count'],
1399
+ "results": result['results'],
1400
+ "message": result['message']
1401
+ })
1402
+ else:
1403
+ return JSONResponse(content={
1404
+ "loaded_count": 0,
1405
+ "total_count": result['total_count'],
1406
+ "message": result['message'],
1407
+ "results": result['results']
1408
+ })
1409
+
1410
+ except Exception as e:
1411
+ log.error(f"从环境变量加载凭证失败: {e}")
1412
+ raise HTTPException(status_code=500, detail=str(e))
1413
+
1414
+
1415
+ @router.delete("/auth/env-creds")
1416
+ async def clear_env_creds(token: str = Depends(verify_token)):
1417
+ """清除所有从环境变量导入的凭证文件"""
1418
+ try:
1419
+ result = await clear_env_credentials()
1420
+
1421
+ if 'error' in result:
1422
+ raise HTTPException(status_code=500, detail=result['error'])
1423
+
1424
+ return JSONResponse(content={
1425
+ "deleted_count": result['deleted_count'],
1426
+ "deleted_files": result.get('deleted_files', []),
1427
+ "message": result['message']
1428
+ })
1429
+
1430
+ except HTTPException:
1431
+ raise
1432
+ except Exception as e:
1433
+ log.error(f"清除环境变量凭证失败: {e}")
1434
+ raise HTTPException(status_code=500, detail=str(e))
1435
+
1436
+
1437
+ @router.get("/auth/env-creds-status")
1438
+ async def get_env_creds_status(token: str = Depends(verify_token)):
1439
+ """获取环境变量凭证状态"""
1440
+ try:
1441
+ # 检查有哪些环境变量可用
1442
+ available_env_vars = {key: "***已设置***" for key, value in os.environ.items()
1443
+ if key.startswith('GCLI_CREDS_') and value.strip()}
1444
+
1445
+ # 检查自动加载设置
1446
+ auto_load_enabled = await config.get_auto_load_env_creds()
1447
+
1448
+ # 统计已存在的环境变量凭证文件
1449
+ storage_adapter = await get_storage_adapter()
1450
+ all_credentials = await storage_adapter.list_credentials()
1451
+ existing_env_files = [
1452
+ filename for filename in all_credentials
1453
+ if filename.startswith('env-') and filename.endswith('.json')
1454
+ ]
1455
+
1456
+ return JSONResponse(content={
1457
+ "available_env_vars": available_env_vars,
1458
+ "auto_load_enabled": auto_load_enabled,
1459
+ "existing_env_files_count": len(existing_env_files),
1460
+ "existing_env_files": existing_env_files
1461
+ })
1462
+
1463
+ except Exception as e:
1464
+ log.error(f"获取环境变量凭证状态失败: {e}")
1465
+ raise HTTPException(status_code=500, detail=str(e))
1466
+
1467
+
1468
+ # =============================================================================
1469
+ # 实时日志WebSocket (Real-time Logs WebSocket)
1470
+ # =============================================================================
1471
+
1472
+ @router.post("/auth/logs/clear")
1473
+ async def clear_logs(token: str = Depends(verify_token)):
1474
+ """清空日志文件"""
1475
+ try:
1476
+ # 直接使用环境变量获取日志文件路径
1477
+ log_file_path = os.getenv('LOG_FILE', 'log.txt')
1478
+
1479
+ # 检查日志文件是否存在
1480
+ if os.path.exists(log_file_path):
1481
+ try:
1482
+ # 清空文件内容(保留文件),确保以UTF-8编码写入
1483
+ with open(log_file_path, 'w', encoding='utf-8', newline='') as f:
1484
+ f.write('')
1485
+ f.flush() # 强制刷新到磁盘
1486
+ log.info(f"日志文件已清空: {log_file_path}")
1487
+
1488
+ # 通知所有WebSocket连接日志已清空
1489
+ await manager.broadcast("--- 日志文件已清空 ---")
1490
+
1491
+ return JSONResponse(content={"message": f"日志文件已清空: {os.path.basename(log_file_path)}"})
1492
+ except Exception as e:
1493
+ log.error(f"清空日志文件失败: {e}")
1494
+ raise HTTPException(status_code=500, detail=f"清空日志文件失败: {str(e)}")
1495
+ else:
1496
+ return JSONResponse(content={"message": "日志文件不存在"})
1497
+
1498
+ except Exception as e:
1499
+ log.error(f"清空日志文件失败: {e}")
1500
+ raise HTTPException(status_code=500, detail=f"清空日志文件失败: {str(e)}")
1501
+
1502
+ @router.get("/auth/logs/download")
1503
+ async def download_logs(token: str = Depends(verify_token)):
1504
+ """下载日志文件"""
1505
+ try:
1506
+ # 直接使用环境变量获取日志文件路径
1507
+ log_file_path = os.getenv('LOG_FILE', 'log.txt')
1508
+
1509
+ # 检查日志文件是否存在
1510
+ if not os.path.exists(log_file_path):
1511
+ raise HTTPException(status_code=404, detail="日志文件不存在")
1512
+
1513
+ # 检查文件是否为空
1514
+ file_size = os.path.getsize(log_file_path)
1515
+ if file_size == 0:
1516
+ raise HTTPException(status_code=404, detail="日志文件为空")
1517
+
1518
+ # 生成文件名(包含时间戳)
1519
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
1520
+ filename = f"gcli2api_logs_{timestamp}.txt"
1521
+
1522
+ log.info(f"下载日志文件: {log_file_path}")
1523
+
1524
+ return FileResponse(
1525
+ path=log_file_path,
1526
+ filename=filename,
1527
+ media_type='text/plain',
1528
+ headers={"Content-Disposition": f"attachment; filename={filename}"}
1529
+ )
1530
+
1531
+ except HTTPException:
1532
+ raise
1533
+ except Exception as e:
1534
+ log.error(f"下载日志文件失败: {e}")
1535
+ raise HTTPException(status_code=500, detail=f"下载日志文件失败: {str(e)}")
1536
+
1537
+ @router.websocket("/auth/logs/stream")
1538
+ async def websocket_logs(websocket: WebSocket):
1539
+ """WebSocket端点,用于实时日志流"""
1540
+ # 检查连接数限制
1541
+ if not await manager.connect(websocket):
1542
+ return
1543
+
1544
+ try:
1545
+ # 直接使用环境变量获取日志文件路径
1546
+ log_file_path = os.getenv('LOG_FILE', 'log.txt')
1547
+
1548
+ # 发送初始日志(限制为最后50行,减少内存占用)
1549
+ if os.path.exists(log_file_path):
1550
+ try:
1551
+ with open(log_file_path, "r", encoding="utf-8") as f:
1552
+ lines = f.readlines()
1553
+ # 只发送最后50行,减少初始内存消耗
1554
+ for line in lines[-50:]:
1555
+ if line.strip():
1556
+ await websocket.send_text(line.strip())
1557
+ except Exception as e:
1558
+ await websocket.send_text(f"Error reading log file: {e}")
1559
+
1560
+ # 监控日志文件变化
1561
+ last_size = os.path.getsize(log_file_path) if os.path.exists(log_file_path) else 0
1562
+ max_read_size = 8192 # 限制单次读取大小为8KB,防止大量日志造成内存激增
1563
+ check_interval = 2 # 增加检查间隔,减少CPU和I/O开销
1564
+
1565
+ while websocket.client_state == WebSocketState.CONNECTED:
1566
+ await asyncio.sleep(check_interval)
1567
+
1568
+ if os.path.exists(log_file_path):
1569
+ current_size = os.path.getsize(log_file_path)
1570
+ if current_size > last_size:
1571
+ # 限制读取大小,防止单次读取过多内容
1572
+ read_size = min(current_size - last_size, max_read_size)
1573
+
1574
+ try:
1575
+ with open(log_file_path, "r", encoding="utf-8", errors="replace") as f:
1576
+ f.seek(last_size)
1577
+ new_content = f.read(read_size)
1578
+
1579
+ # 处理编码错误的情况
1580
+ if not new_content:
1581
+ last_size = current_size
1582
+ continue
1583
+
1584
+ # 分行发送,避免发送不完整的行
1585
+ lines = new_content.splitlines(keepends=True)
1586
+ if lines:
1587
+ # 如果最后一行没有换行符,保留到下次处理
1588
+ if not lines[-1].endswith('\n') and len(lines) > 1:
1589
+ # 除了最后一行,其他都发送
1590
+ for line in lines[:-1]:
1591
+ if line.strip():
1592
+ await websocket.send_text(line.rstrip())
1593
+ # 更新位置,但要退回最后一行的字节数
1594
+ last_size += len(new_content.encode('utf-8')) - len(lines[-1].encode('utf-8'))
1595
+ else:
1596
+ # 所有行都发送
1597
+ for line in lines:
1598
+ if line.strip():
1599
+ await websocket.send_text(line.rstrip())
1600
+ last_size += len(new_content.encode('utf-8'))
1601
+ except UnicodeDecodeError as e:
1602
+ # 遇到编码错误时,跳过这部分内容
1603
+ log.warning(f"WebSocket日志读取编码错误: {e}, 跳过部分内容")
1604
+ last_size = current_size
1605
+ except Exception as e:
1606
+ await websocket.send_text(f"Error reading new content: {e}")
1607
+ # 发生其他错误时,重置文件位置
1608
+ last_size = current_size
1609
+
1610
+ # 如果文件被截断(如清空日志),重置位置
1611
+ elif current_size < last_size:
1612
+ last_size = 0
1613
+ await websocket.send_text("--- 日志已清空 ---")
1614
+
1615
+ except WebSocketDisconnect:
1616
+ pass
1617
+ except Exception as e:
1618
+ log.error(f"WebSocket logs error: {e}")
1619
+ finally:
1620
+ manager.disconnect(websocket)
1621
+
1622
+
1623
+ # =============================================================================
1624
+ # Usage Statistics API (使用统计API)
1625
+ # =============================================================================
1626
+
1627
+ @router.get("/usage/stats")
1628
+ async def get_usage_statistics(filename: Optional[str] = None, token: str = Depends(verify_token)):
1629
+ """
1630
+ 获取使用统计信息
1631
+
1632
+ Args:
1633
+ filename: 可选,指定凭证文��名。如果不提供则返回所有文件的统计
1634
+
1635
+ Returns:
1636
+ usage statistics for the specified file or all files
1637
+ """
1638
+ try:
1639
+ stats = await get_usage_stats(filename)
1640
+ return JSONResponse(content={
1641
+ "success": True,
1642
+ "data": stats
1643
+ })
1644
+ except Exception as e:
1645
+ log.error(f"获取使用统计失败: {e}")
1646
+ raise HTTPException(status_code=500, detail=str(e))
1647
+
1648
+
1649
+ @router.get("/usage/aggregated")
1650
+ async def get_aggregated_usage_statistics(token: str = Depends(verify_token)):
1651
+ """
1652
+ 获取聚合使用统计信息
1653
+
1654
+ Returns:
1655
+ Aggregated statistics across all credential files
1656
+ """
1657
+ try:
1658
+ stats = await get_aggregated_stats()
1659
+ return JSONResponse(content={
1660
+ "success": True,
1661
+ "data": stats
1662
+ })
1663
+ except Exception as e:
1664
+ log.error(f"获取聚合统计失败: {e}")
1665
+ raise HTTPException(status_code=500, detail=str(e))
1666
+
1667
+
1668
+
1669
+ class UsageLimitsUpdateRequest(BaseModel):
1670
+ filename: str
1671
+ gemini_2_5_pro_limit: Optional[int] = None
1672
+ total_limit: Optional[int] = None
1673
+
1674
+
1675
+ @router.post("/usage/update-limits")
1676
+ async def update_usage_limits(request: UsageLimitsUpdateRequest, token: str = Depends(verify_token)):
1677
+ """
1678
+ 更新指定凭证文件的每日使用限制
1679
+
1680
+ Args:
1681
+ request: 包含文件名和新限制值的请求
1682
+
1683
+ Returns:
1684
+ Success message
1685
+ """
1686
+ try:
1687
+ stats_instance = await get_usage_stats_instance()
1688
+
1689
+ await stats_instance.update_daily_limits(
1690
+ filename=request.filename,
1691
+ gemini_2_5_pro_limit=request.gemini_2_5_pro_limit,
1692
+ total_limit=request.total_limit
1693
+ )
1694
+
1695
+ return JSONResponse(content={
1696
+ "success": True,
1697
+ "message": f"已更新 {request.filename} 的使用限制"
1698
+ })
1699
+
1700
+ except Exception as e:
1701
+ log.error(f"更新使用限制失败: {e}")
1702
+ raise HTTPException(status_code=500, detail=str(e))
1703
+
1704
+
1705
+ class UsageResetRequest(BaseModel):
1706
+ filename: Optional[str] = None
1707
+
1708
+
1709
+ @router.post("/usage/reset")
1710
+ async def reset_usage_statistics(request: UsageResetRequest, token: str = Depends(verify_token)):
1711
+ """
1712
+ 重置使用统计
1713
+
1714
+ Args:
1715
+ request: 包含可选文件名的请求。如果不提供文件名则重置所有统计
1716
+
1717
+ Returns:
1718
+ Success message
1719
+ """
1720
+ try:
1721
+ stats_instance = await get_usage_stats_instance()
1722
+
1723
+ await stats_instance.reset_stats(filename=request.filename)
1724
+
1725
+ if request.filename:
1726
+ message = f"已重置 {request.filename} 的使用统计"
1727
+ else:
1728
+ message = "已重置所有文件的使用统计"
1729
+
1730
+ return JSONResponse(content={
1731
+ "success": True,
1732
+ "message": message
1733
+ })
1734
+
1735
+ except Exception as e:
1736
+ log.error(f"重置使用统计失败: {e}")
1737
+ raise HTTPException(status_code=500, detail=str(e))
1738
+