|
|
""" |
|
|
Base API Client - 共用的 API 客户端基础功能 |
|
|
提供错误处理、自动封禁、重试逻辑等共同功能 |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import json |
|
|
from datetime import datetime, timezone |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
from fastapi import Response |
|
|
|
|
|
from config import ( |
|
|
get_auto_ban_enabled, |
|
|
get_auto_ban_error_codes, |
|
|
get_retry_429_enabled, |
|
|
get_retry_429_interval, |
|
|
get_retry_429_max_retries, |
|
|
) |
|
|
from log import log |
|
|
from src.credential_manager import CredentialManager |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def check_should_auto_ban(status_code: int) -> bool:
    """Decide whether an upstream HTTP status code should trigger auto-ban.

    Args:
        status_code: HTTP status code returned by the upstream API.

    Returns:
        bool: True when auto-ban is enabled and ``status_code`` is one of
        the configured auto-ban error codes.
    """
    enabled = await get_auto_ban_enabled()
    if not enabled:
        # Mirror `a and b` short-circuiting: propagate the falsy value as-is
        # and skip fetching the error-code list entirely.
        return enabled
    ban_codes = await get_auto_ban_error_codes()
    return status_code in ban_codes
|
|
|
|
|
|
|
|
async def handle_auto_ban(
    credential_manager: CredentialManager,
    status_code: int,
    credential_name: str,
    mode: str = "geminicli"
) -> None:
    """Apply auto-ban by disabling the offending credential outright.

    Args:
        credential_manager: Credential manager instance.
        status_code: HTTP status code that triggered the ban.
        credential_name: Name of the credential to disable.
        mode: Operating mode ("geminicli" or "antigravity").
    """
    # No-op when either the manager or the credential name is missing/empty.
    if not credential_manager or not credential_name:
        return
    log.warning(
        f"[{mode.upper()} AUTO_BAN] Status {status_code} triggers auto-ban for credential: {credential_name}"
    )
    await credential_manager.set_cred_disabled(credential_name, True, mode=mode)
|
|
|
|
|
|
|
|
async def handle_error_with_retry(
    credential_manager: CredentialManager,
    status_code: int,
    credential_name: str,
    retry_enabled: bool,
    attempt: int,
    max_retries: int,
    retry_interval: float,
    mode: str = "geminicli"
) -> bool:
    """Handle an upstream error and decide whether the caller should retry.

    Automatic retries happen only for:
    1. 429 errors (rate limiting).
    2. Errors that cause the credential to be banned
       (AUTO_BAN_ERROR_CODES configuration).

    Args:
        credential_manager: Credential manager instance.
        status_code: HTTP status code of the failed call.
        credential_name: Name of the credential that made the call.
        retry_enabled: Whether retrying is enabled at all.
        attempt: Current (zero-based) retry attempt number.
        max_retries: Maximum number of retries allowed.
        retry_interval: Seconds to sleep before the next attempt.
        mode: Operating mode ("geminicli" or "antigravity").

    Returns:
        bool: True if the caller should retry, False otherwise.
    """
    # Retry budget is shared by both branches below.
    can_retry = retry_enabled and attempt < max_retries

    if await check_should_auto_ban(status_code):
        # Disable the credential first, then retry with the next one if allowed.
        await handle_auto_ban(credential_manager, status_code, credential_name, mode)
        if not can_retry:
            return False
        log.info(
            f"[{mode.upper()} RETRY] Retrying with next credential after auto-ban "
            f"(status {status_code}, attempt {attempt + 1}/{max_retries})"
        )
        await asyncio.sleep(retry_interval)
        return True

    # Plain rate limiting: retry without banning anything.
    if status_code == 429 and can_retry:
        log.info(
            f"[{mode.upper()} RETRY] 429 rate limit encountered, retrying "
            f"(attempt {attempt + 1}/{max_retries})"
        )
        await asyncio.sleep(retry_interval)
        return True

    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_retry_config() -> Dict[str, Any]:
    """Fetch the current 429-retry configuration.

    Returns:
        A dict with keys "retry_enabled", "max_retries" and "retry_interval".
    """
    config: Dict[str, Any] = {}
    config["retry_enabled"] = await get_retry_429_enabled()
    config["max_retries"] = await get_retry_429_max_retries()
    config["retry_interval"] = await get_retry_429_interval()
    return config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def record_api_call_success(
    credential_manager: CredentialManager,
    credential_name: str,
    mode: str = "geminicli",
    model_key: Optional[str] = None
) -> None:
    """Record a successful API call against a credential.

    Args:
        credential_manager: Credential manager instance.
        credential_name: Name of the credential that made the call.
        mode: Operating mode ("geminicli" or "antigravity").
        model_key: Model key (used for per-model cooldown tracking).
    """
    # Silently skip when the manager or credential name is unavailable.
    if not credential_manager or not credential_name:
        return
    await credential_manager.record_api_call_result(
        credential_name, True, mode=mode, model_key=model_key
    )
|
|
|
|
|
|
|
|
async def record_api_call_error(
    credential_manager: CredentialManager,
    credential_name: str,
    status_code: int,
    cooldown_until: Optional[float] = None,
    mode: str = "geminicli",
    model_key: Optional[str] = None
) -> None:
    """Record a failed API call against a credential.

    Args:
        credential_manager: Credential manager instance.
        credential_name: Name of the credential that made the call.
        status_code: HTTP status code of the failure.
        cooldown_until: Cooldown deadline as a Unix timestamp, if known.
        mode: Operating mode ("geminicli" or "antigravity").
        model_key: Model key (used for per-model cooldown tracking).
    """
    # Silently skip when the manager or credential name is unavailable.
    if not credential_manager or not credential_name:
        return
    await credential_manager.record_api_call_result(
        credential_name,
        False,
        status_code,
        cooldown_until=cooldown_until,
        mode=mode,
        model_key=model_key
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def parse_and_log_cooldown(
    error_text: str,
    mode: str = "geminicli"
) -> Optional[float]:
    """Parse a quota-cooldown deadline from an error body and log it.

    Args:
        error_text: Raw error response text (expected to be JSON).
        mode: Operating mode ("geminicli" or "antigravity").

    Returns:
        Cooldown deadline as a Unix timestamp, or None when parsing fails.
    """
    # The whole pipeline stays inside one try so that any failure
    # (bad JSON, bad timestamp, formatting error) degrades to None.
    try:
        ts = parse_quota_reset_timestamp(json.loads(error_text))
        if ts:
            reset_iso = datetime.fromtimestamp(ts, timezone.utc).isoformat()
            log.info(f"[{mode.upper()}] 检测到quota冷却时间: {reset_iso}")
            return ts
    except Exception as parse_err:
        log.debug(f"[{mode.upper()}] Failed to parse cooldown time: {parse_err}")
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def collect_streaming_response(stream_generator) -> Response:
    """Collect a Gemini streaming response into a single non-streaming Response.

    Args:
        stream_generator: Async generator yielding SSE lines in
            ``"data: {json}"`` format (str or bytes), or a Response object
            (which is treated as an upstream error and returned as-is).

    Returns:
        Response: The merged, complete JSON response.

    Example:
        >>> async for line in stream_generator:
        ...     # line format: "data: {...}" or Response object
        >>> response = await collect_streaming_response(stream_generator)
    """
    # Skeleton of the merged payload; parts/finishReason/usage are filled in
    # as chunks arrive. Note the extra {"response": ...} wrapper — it is
    # unwrapped again just before the final Response is built.
    merged_response = {
        "response": {
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": None,
                "safetyRatings": [],
                "citationMetadata": None
            }],
            "usageMetadata": {
                "promptTokenCount": 0,
                "candidatesTokenCount": 0,
                "totalTokenCount": 0
            }
        }
    }

    # Accumulators: regular text, "thought" text (parts flagged thought=True),
    # and non-text parts (inlineData/fileData/executableCode/codeExecutionResult).
    collected_text = []
    collected_thought_text = []
    collected_other_parts = []
    has_data = False  # becomes True once at least one JSON chunk parses
    line_count = 0

    log.debug("[STREAM COLLECTOR] Starting to collect streaming response")

    try:
        async for line in stream_generator:
            line_count += 1

            # An upstream error surfaced as a Response object: pass it through.
            if isinstance(line, Response):
                log.debug(f"[STREAM COLLECTOR] 收到错误Response,状态码: {line.status_code}")
                return line

            # Normalize bytes/str; anything else is skipped.
            if isinstance(line, bytes):
                line_str = line.decode('utf-8', errors='ignore')
                log.debug(f"[STREAM COLLECTOR] Processing bytes line {line_count}: {line_str[:200] if line_str else 'empty'}")
            elif isinstance(line, str):
                line_str = line
                log.debug(f"[STREAM COLLECTOR] Processing line {line_count}: {line_str[:200] if line_str else 'empty'}")
            else:
                log.debug(f"[STREAM COLLECTOR] Skipping non-string/bytes line: {type(line)}")
                continue

            # Only SSE data lines carry payload.
            if not line_str.startswith("data: "):
                log.debug(f"[STREAM COLLECTOR] Skipping line without 'data: ' prefix: {line_str[:100]}")
                continue

            raw = line_str[6:].strip()  # strip the "data: " prefix
            if raw == "[DONE]":
                log.debug("[STREAM COLLECTOR] Received [DONE] marker")
                break

            try:
                log.debug(f"[STREAM COLLECTOR] Parsing JSON: {raw[:200]}")
                chunk = json.loads(raw)
                has_data = True
                log.debug(f"[STREAM COLLECTOR] Chunk keys: {chunk.keys() if isinstance(chunk, dict) else type(chunk)}")

                # Chunks may be wrapped in {"response": ...} or be bare.
                response_obj = chunk.get("response", {})
                if not response_obj:
                    log.debug("[STREAM COLLECTOR] No 'response' key in chunk, trying direct access")
                    response_obj = chunk

                candidates = response_obj.get("candidates", [])
                log.debug(f"[STREAM COLLECTOR] Found {len(candidates)} candidates")
                if not candidates:
                    log.debug(f"[STREAM COLLECTOR] No candidates in chunk, chunk structure: {list(chunk.keys()) if isinstance(chunk, dict) else type(chunk)}")
                    continue

                # Only the first candidate is merged.
                candidate = candidates[0]

                content = candidate.get("content", {})
                parts = content.get("parts", [])
                log.debug(f"[STREAM COLLECTOR] Processing {len(parts)} parts from candidate")

                for part in parts:
                    if not isinstance(part, dict):
                        continue

                    text = part.get("text", "")
                    if text:
                        # Thought text and regular text are kept in separate
                        # streams so they can be re-emitted as distinct parts.
                        if part.get("thought", False):
                            collected_thought_text.append(text)
                            log.debug(f"[STREAM COLLECTOR] Collected thought text: {text[:100]}")
                        else:
                            collected_text.append(text)
                            log.debug(f"[STREAM COLLECTOR] Collected regular text: {text[:100]}")

                    elif "inlineData" in part or "fileData" in part or "executableCode" in part or "codeExecutionResult" in part:
                        collected_other_parts.append(part)
                        log.debug(f"[STREAM COLLECTOR] Collected non-text part: {list(part.keys())}")

                # Metadata fields: last chunk that provides a value wins.
                if candidate.get("finishReason"):
                    merged_response["response"]["candidates"][0]["finishReason"] = candidate["finishReason"]

                if candidate.get("safetyRatings"):
                    merged_response["response"]["candidates"][0]["safetyRatings"] = candidate["safetyRatings"]

                if candidate.get("citationMetadata"):
                    merged_response["response"]["candidates"][0]["citationMetadata"] = candidate["citationMetadata"]

                # Usage metadata is merged so later chunks overwrite counts.
                usage = response_obj.get("usageMetadata", {})
                if usage:
                    merged_response["response"]["usageMetadata"].update(usage)

            except json.JSONDecodeError as e:
                # Malformed chunks are skipped, not fatal.
                log.debug(f"[STREAM COLLECTOR] Failed to parse JSON chunk: {e}")
                continue
            except Exception as e:
                log.debug(f"[STREAM COLLECTOR] Error processing chunk: {e}")
                continue

    except Exception as e:
        # Failure while iterating the stream itself -> 500 error response.
        log.error(f"[STREAM COLLECTOR] Error collecting stream after {line_count} lines: {e}")
        return Response(
            content=json.dumps({"error": f"收集流式响应失败: {str(e)}"}),
            status_code=500,
            media_type="application/json"
        )

    log.debug(f"[STREAM COLLECTOR] Finished iteration, has_data={has_data}, line_count={line_count}")

    # The stream ended without a single parseable data chunk.
    if not has_data:
        log.error(f"[STREAM COLLECTOR] No data collected from stream after {line_count} lines")
        return Response(
            content=json.dumps({"error": "No data collected from stream"}),
            status_code=500,
            media_type="application/json"
        )

    # Assemble final parts: thought text first, then regular text, then
    # the non-text parts in arrival order.
    final_parts = []

    if collected_thought_text:
        final_parts.append({
            "text": "".join(collected_thought_text),
            "thought": True
        })

    if collected_text:
        final_parts.append({
            "text": "".join(collected_text)
        })

    final_parts.extend(collected_other_parts)

    # Guarantee at least one (possibly empty) text part.
    if not final_parts:
        final_parts.append({"text": ""})

    merged_response["response"]["candidates"][0]["content"]["parts"] = final_parts

    log.info(f"[STREAM COLLECTOR] Collected {len(collected_text)} text chunks, {len(collected_thought_text)} thought chunks, and {len(collected_other_parts)} other parts")

    # Unwrap the {"response": ...} wrapper so clients receive the bare
    # Gemini response shape. (The guard is always true for the structure
    # built above; it only protects against future shape changes.)
    if "response" in merged_response and "candidates" not in merged_response:
        log.debug(f"[STREAM COLLECTOR] 展开response包装")
        merged_response = merged_response["response"]

    return Response(
        content=json.dumps(merged_response, ensure_ascii=False).encode('utf-8'),
        status_code=200,
        headers={},
        media_type="application/json"
    )
|
|
|
|
|
|
|
|
def parse_quota_reset_timestamp(error_response: dict) -> Optional[float]:
    """Extract the quota-reset timestamp from a Google API error response.

    Args:
        error_response: Error response dict returned by the Google API.

    Returns:
        Unix timestamp in seconds, or None if it cannot be parsed.

    Example error response:
        {
            "error": {
                "code": 429,
                "message": "You have exhausted your capacity...",
                "status": "RESOURCE_EXHAUSTED",
                "details": [
                    {
                        "@type": "type.googleapis.com/google.rpc.ErrorInfo",
                        "reason": "QUOTA_EXHAUSTED",
                        "metadata": {
                            "quotaResetTimeStamp": "2025-11-30T14:57:24Z",
                            "quotaResetDelay": "13h19m1.20964964s"
                        }
                    }
                ]
            }
        }
    """
    try:
        for detail in error_response.get("error", {}).get("details", []):
            # Only ErrorInfo entries carry the reset metadata.
            if detail.get("@type") != "type.googleapis.com/google.rpc.ErrorInfo":
                continue
            raw_ts = detail.get("metadata", {}).get("quotaResetTimeStamp")
            if not raw_ts:
                continue

            # fromisoformat() on older Pythons rejects the trailing "Z".
            if raw_ts.endswith("Z"):
                raw_ts = raw_ts.replace("Z", "+00:00")

            reset_dt = datetime.fromisoformat(raw_ts)
            # Treat naive timestamps as UTC.
            if reset_dt.tzinfo is None:
                reset_dt = reset_dt.replace(tzinfo=timezone.utc)

            return reset_dt.astimezone(timezone.utc).timestamp()

        return None

    except Exception:
        # Any malformed structure or timestamp degrades to "unknown".
        return None
|
|
|
|
|
def get_model_group(model_name: str) -> str:
    """Map a model name to its quota group for the GCLI cooldown mechanism.

    Args:
        model_name: Model name, e.g. "gemini-2.5-flash".

    Returns:
        "pro" or "flash".

    Notes:
        - pro group: gemini-2.5-pro, gemini-3-pro-preview share one quota.
        - flash group: gemini-2.5-flash has its own quota.
    """
    # Case-insensitive substring check; everything non-flash falls into "pro".
    return "flash" if "flash" in model_name.lower() else "pro"