# Gemini API proxy server: exposes OpenAI-compatible chat endpoints backed by
# the gemini_webapi (Gemini web session) client.
import argparse
import asyncio
import json
import logging
import os
import re
import time
from typing import Optional, List, Union, Dict, Any

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from gemini_webapi import GeminiClient, set_log_level
from httpx import RemoteProtocolError, ReadTimeout
from pydantic import BaseModel

# Load .env before anything reads os.environ (credentials, defaults below).
load_dotenv()
# Command-line argument parsing
def parse_args():
    """Build the CLI parser and return the parsed options.

    Every option falls back to a ``GEMINI_*`` environment variable, so the
    server can be configured either through flags or the environment.
    """

    def _truthy(raw):
        # argparse hands us the raw CLI string; anything but "true" is False.
        return raw.lower() == 'true'

    env = os.environ.get
    parser = argparse.ArgumentParser(description='Gemini API 代理服务器')
    parser.add_argument('--init-timeout', type=float,
                        default=float(env('GEMINI_INIT_TIMEOUT', '180')),
                        help='客户端初始化超时时间(秒)')
    parser.add_argument('--request-timeout', type=float,
                        default=float(env('GEMINI_REQUEST_TIMEOUT', '300')),
                        help='请求处理超时时间(秒)')
    parser.add_argument('--auto-close', type=_truthy,
                        default=env('GEMINI_AUTO_CLOSE', 'false').lower() == 'true',
                        help='是否自动关闭客户端')
    parser.add_argument('--close-delay', type=float,
                        default=float(env('GEMINI_CLOSE_DELAY', '300')),
                        help='自动关闭延迟时间(秒)')
    parser.add_argument('--auto-refresh', type=_truthy,
                        default=env('GEMINI_AUTO_REFRESH', 'true').lower() == 'true',
                        help='是否自动刷新会话')
    parser.add_argument('--refresh-interval', type=float,
                        default=float(env('GEMINI_REFRESH_INTERVAL', '540')),
                        help='刷新间隔(秒)')
    parser.add_argument('--host', type=str,
                        default=env('GEMINI_HOST', '0.0.0.0'),
                        help='服务器监听地址')
    parser.add_argument('--port', type=int,
                        default=int(env('GEMINI_PORT', '7860')),
                        help='服务器端口')
    # Retry-related options
    parser.add_argument('--max-retries', type=int,
                        default=int(env('GEMINI_MAX_RETRIES', '3')),
                        help='最大重试次数')
    parser.add_argument('--retry-delay', type=float,
                        default=float(env('GEMINI_RETRY_DELAY', '2')),
                        help='重试间隔时间(秒)')
    parser.add_argument('--retry-exceptions', type=str,
                        default=env('GEMINI_RETRY_EXCEPTIONS', 'RemoteProtocolError,ReadTimeout'),
                        help='需要重试的异常类型,以逗号分隔')
    # Long-response mode options
    parser.add_argument('--long-response-mode', type=_truthy,
                        default=env('GEMINI_LONG_RESPONSE_MODE', 'true').lower() == 'true',
                        help='是否启用长响应模式,在此模式下会等待更长时间而不是立即重试')
    parser.add_argument('--long-response-wait', type=float,
                        default=float(env('GEMINI_LONG_RESPONSE_WAIT', '180')),
                        help='长响应模式下的等待时间(秒),超过此时间才会重试')
    parser.add_argument('--max-long-response-retries', type=int,
                        default=int(env('GEMINI_MAX_LONG_RESPONSE_RETRIES', '5')),
                        help='长响应模式下的最大重试次数')
    # Conversation-history options
    parser.add_argument('--keep-conversation-history', type=_truthy,
                        default=env('GEMINI_KEEP_CONVERSATION_HISTORY', 'true').lower() == 'true',
                        help='是否保存对话历史,在SillyTavern等应用中非常有用')
    parser.add_argument('--filter-thinking-vessel', type=_truthy,
                        default=env('GEMINI_FILTER_THINKING_VESSEL', 'true').lower() == 'true',
                        help='是否过滤thinking和vessel标签')
    return parser.parse_args()
# Logging: everything goes both to stdout and to gemini_proxy.log.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("gemini_proxy.log")
    ]
)
logger = logging.getLogger("gemini-proxy")
# Turn on verbose logging inside the gemini_webapi library itself.
set_log_level("DEBUG")
# Gemini web-session cookies, read from the environment (.env loaded above).
SECURE_1PSID = os.environ.get("GEMINI_SECURE_1PSID", "")
SECURE_1PSIDTS = os.environ.get("GEMINI_SECURE_1PSIDTS", "")
# Debug lines deliberately left disabled: enabling them would leak the
# credentials into the log file.
# logger.info(f"SECURE_1PSID:{SECURE_1PSID}")
# logger.info(f"SECURE_1PSIDTS:{SECURE_1PSIDTS}")
# Static model catalogue, returned verbatim by the models-listing endpoint.
# Entries mimic OpenAI's model-object schema; `created` values are rough
# Unix timestamps, not authoritative release dates.
SUPPORTED_MODELS = [
    {
        "id": "gemini-pro",
        "object": "model",
        "created": 1678892800,
        "owned_by": "google",
        "permission": [],
        "root": "gemini-pro",
        "parent": None
    },
    {
        "id": "gemini-2.5-flash-preview-04-17",
        "object": "model",
        "created": 1713571200,  # 2024-04-20 (estimated)
        "owned_by": "google",
        "permission": [],
        "root": "gemini-2.5-flash-preview-04-17",
        "parent": None
    },
    {
        "id": "gemini-2.5-exp-advanced",
        "object": "model",
        "created": 1713571200,
        "owned_by": "google",
        "permission": [],
        "root": "gemini-2.5-exp-advanced",
        "parent": None
    },
    {
        "id": "gemini-2.0-exp-advanced",
        "object": "model",
        "created": 1713571200,
        "owned_by": "google",
        "permission": [],
        "root": "gemini-2.0-exp-advanced",
        "parent": None
    },
    {
        "id": "gemini-2.5-pro",
        "object": "model",
        "created": 1713571200,
        "owned_by": "google",
        "permission": [],
        "root": "gemini-2.5-pro",
        "parent": None
    },
    {
        "id": "gemini-2.5-flash",
        "object": "model",
        "created": 1713571200,
        "owned_by": "google",
        "permission": [],
        "root": "gemini-2.5-flash",
        "parent": None
    },
    {
        "id": "gemini-2.0-flash-thinking",
        "object": "model",
        "created": 1713571200,
        "owned_by": "google",
        "permission": [],
        "root": "gemini-2.0-flash-thinking",
        "parent": None
    },
    {
        "id": "gemini-2.0-flash",
        "object": "model",
        "created": 1713571200,
        "owned_by": "google",
        "permission": [],
        "root": "gemini-2.0-flash",
        "parent": None
    },
    {
        "id": "gemini-2.5-pro-exp-03-25",
        "object": "model",
        "created": 1711324800,  # 2024-03-25
        "owned_by": "google",
        "permission": [],
        "root": "gemini-2.5-pro-exp-03-25",
        "parent": None
    },
    {
        "id": "gemini-2.5-pro-preview-03-25",
        "object": "model",
        "created": 1711324800,  # 2024-03-25
        "owned_by": "google",
        "permission": [],
        "root": "gemini-2.5-pro-preview-03-25",
        "parent": None
    }
]
app = FastAPI(title="Gemini API Proxy")
# CORS middleware: allow any origin/method/header so browser front-ends
# (e.g. SillyTavern) can call the proxy directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # all origins
    allow_credentials=True,
    allow_methods=["*"],      # all HTTP methods
    allow_headers=["*"],      # all headers
)
# Shared Gemini client instance; created on startup and lazily re-created
# by the endpoints if it stops running.
client = None
# session_id -> gemini_webapi chat-session object
sessions = {}
# session_id -> list of {"role": ..., "content": ...} history entries
conversation_history = {}
# Pre-compiled pattern for <thinking>...</thinking> and <vessel>...</vessel>
# blocks. The optional `\\?` atoms also match the backslash-escaped variants
# (\<thinking\> ... \</thinking\>) that were previously handled by four
# separate re.sub calls recompiled on every invocation; the single pattern is
# a backward-compatible superset (it additionally removes mixed
# escaped/unescaped pairs, which the old code let through).
_TAG_FILTER_RE = re.compile(r'\\?<(thinking|vessel)\\?>[\s\S]*?\\?</\1\\?>')


def filter_thinking_vessel(text):
    """Strip thinking/vessel tag blocks from *text*.

    Removes ``<thinking>...</thinking>`` and ``<vessel>...</vessel>`` blocks,
    in plain or backslash-escaped form, non-greedily (each block is removed
    independently). Falsy input (None, "") is returned unchanged.
    """
    if not text:
        return text
    return _TAG_FILTER_RE.sub('', text)
# Runtime configuration; these defaults are overwritten from CLI/env in
# the __main__ block below.
config = {
    "init_timeout": 180,       # seconds allowed for GeminiClient.init
    "request_timeout": 300,    # per-request timeout for send_message/generate_content
    "auto_close": False,       # passed through to client.init
    "close_delay": 300,
    "auto_refresh": True,
    "refresh_interval": 540,
    # Retry settings (used by retry_request)
    "max_retries": 3,
    "retry_delay": 2,          # base delay; doubled per attempt (exponential backoff)
    "retry_exceptions": ["RemoteProtocolError", "ReadTimeout"],
    # Long-response mode: on "Server disconnected" errors, wait instead of
    # retrying immediately
    "long_response_mode": True,
    "long_response_wait": 180,
    "max_long_response_retries": 5,  # cap on long-response-mode retries
    # Conversation-history settings
    "keep_conversation_history": True,  # store per-session message history
    "filter_thinking_vessel": True      # strip thinking/vessel tags
}
class ChatRequest(BaseModel):
    """OpenAI-style chat completion request body."""
    # Each entry: {"role": ..., "content": str or list of typed parts}.
    messages: List[Dict[str, Any]]
    model: Optional[str] = None
    stream: Optional[bool] = True
    temperature: Optional[float] = None   # accepted but not forwarded to Gemini
    max_tokens: Optional[int] = None      # accepted but not forwarded to Gemini
    session_id: Optional[str] = None      # keys the server-side session & history
class Message(BaseModel):
    """Simple role/content pair.

    NOTE(review): not referenced by the visible endpoints, which read raw
    dicts from ChatRequest.messages instead.
    """
    role: str
    content: str
# NOTE(review): the event decorator was evidently lost in formatting; restored
# here so FastAPI actually runs this hook — confirm against the original file.
@app.on_event("startup")
async def startup_event():
    """Create and initialise the shared GeminiClient at server startup.

    Re-raises any initialisation failure so the server aborts instead of
    serving requests with a dead client.
    """
    global client
    try:
        logger.info("正在初始化Gemini客户端...")
        logger.info(f"使用初始化超时: {config['init_timeout']}秒")
        client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
        await client.init(
            timeout=config['init_timeout'],
            auto_close=config['auto_close'],
            close_delay=config['close_delay'],
            auto_refresh=config['auto_refresh'],
            refresh_interval=config['refresh_interval']
        )
        logger.info("Gemini客户端初始化成功")
        print("Gemini客户端初始化成功")
    except Exception as e:
        logger.error(f"Gemini客户端初始化失败: {e}", exc_info=True)
        # Bare raise keeps the original traceback (was `raise e`).
        raise
# NOTE(review): event decorator restored (lost in formatting) — confirm.
@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared GeminiClient on server shutdown; errors are logged,
    never propagated (shutdown should not fail)."""
    global client
    if client:
        try:
            logger.info("正在关闭Gemini客户端...")
            await client.close()
            logger.info("Gemini客户端已关闭")
            print("Gemini客户端已关闭")
        except Exception as e:
            logger.error(f"关闭Gemini客户端时出错: {e}", exc_info=True)
# NOTE(review): route decorator restored (lost in formatting) — confirm path.
@app.get("/")
async def root():
    """Health-check endpoint confirming the proxy is up."""
    logger.info("收到根路径请求")
    return {"message": "Gemini API Proxy服务正在运行"}
# NOTE(review): route decorator restored (lost in formatting); /v1/models is
# the OpenAI-compatible path the docstring implies — confirm.
@app.get("/v1/models")
async def list_models():
    """
    List supported models.

    OpenAI-API-compatible endpoint returning the static SUPPORTED_MODELS list.
    """
    logger.info("收到列出模型请求")
    return {
        "object": "list",
        "data": SUPPORTED_MODELS
    }
# NOTE(review): no route decorator is visible on this handler; presumably
# @app.post("/v1/chat/completions") — confirm against the original file.
async def chat_completions(request: ChatRequest):
    """
    OpenAI-style chat completions endpoint (usable from SillyTavern etc.).

    Flattens the incoming message list into one text prompt (Gemini web
    sessions take plain text), optionally records per-session history, then
    returns either a complete chat.completion object or a *simulated* SSE
    stream: the full reply is fetched first and re-chunked as
    chat.completion.chunk events.
    """
    global client, sessions, conversation_history
    logger.info(f"收到聊天完成请求, 模型: {request.model}, 会话ID: {request.session_id}, 流式: {request.stream}")
    # Lazily (re)initialise the shared client if it is missing or stopped.
    if not client or not client.running:
        try:
            logger.warning("客户端未初始化或已关闭,尝试重新初始化...")
            client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
            await client.init(
                timeout=config['init_timeout'],
                auto_close=config['auto_close'],
                close_delay=config['close_delay'],
                auto_refresh=config['auto_refresh'],
                refresh_interval=config['refresh_interval']
            )
            logger.info("重新初始化成功")
        except Exception as e:
            logger.error(f"Gemini客户端初始化失败: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Gemini客户端初始化失败: {str(e)}")
    try:
        session_id = request.session_id
        prompt = ""
        system_message = ""
        conversation_prompt = ""
        processed_messages_for_history = []  # messages (original shape) queued for the history store
        for message in request.messages:
            role = message.get("role", "")
            raw_content = message.get("content", "")  # str, or OpenAI multi-part list
            # Plain-text view of the message, used only for prompt building.
            text_content_for_prompt = ""
            if isinstance(raw_content, str):
                text_content_for_prompt = raw_content
            elif isinstance(raw_content, list):
                # Concatenate all "text" parts; non-text parts (e.g. image_url)
                # are ignored for the prompt.
                temp_text_parts = []
                for part in raw_content:
                    if isinstance(part, dict) and part.get("type") == "text":
                        temp_text_parts.append(part.get("text", ""))
                # NOTE(review): this joins with a literal backslash-n sequence
                # ("\\n"), matching the literal "\\n\\n" separators used below —
                # confirm this is intentional rather than a missing newline.
                text_content_for_prompt = "\\n".join(temp_text_parts)
            # Apply the tag filter to the prompt text when enabled.
            final_text_for_prompt = text_content_for_prompt
            if config['filter_thinking_vessel']:
                final_text_for_prompt = filter_thinking_vessel(text_content_for_prompt)
            # Queue raw_content for history unless it is effectively empty
            # (blank string, or a list with no meaningful parts).
            is_effectively_empty_for_history = False
            if isinstance(raw_content, str) and not raw_content.strip():
                is_effectively_empty_for_history = True
            elif isinstance(raw_content, list):
                has_meaningful_part = False
                for part in raw_content:
                    if isinstance(part, dict):
                        if part.get("type") == "text" and part.get("text", "").strip():
                            has_meaningful_part = True; break
                        elif part.get("type") != "text":  # e.g. image_url counts as meaningful
                            has_meaningful_part = True; break
                if not has_meaningful_part:
                    is_effectively_empty_for_history = True
            if not is_effectively_empty_for_history:
                processed_messages_for_history.append({"role": role, "content": raw_content})
            # System messages become the prompt preamble instead of a turn.
            if role == "system":
                system_message = final_text_for_prompt  # always the filtered string
                continue  # raw_content was already queued for history above
            # Append the filtered text as a conversation turn.
            if final_text_for_prompt.strip():
                role_prefix = "User: " if role == "user" else "Assistant: "
                conversation_prompt += f"{role_prefix}{final_text_for_prompt}\\n\\n"
        # Assemble the final prompt string.
        if system_message and conversation_prompt:
            prompt = f"{system_message}\\n\\n{conversation_prompt}"
        elif conversation_prompt:
            prompt = conversation_prompt
        elif system_message:
            prompt = system_message
        if not prompt:
            logger.error("请求中没有找到有效消息 (after processing for prompt)")
            # NOTE(review): raised inside the outer try, so the generic handler
            # below catches it and re-raises as a 500 — likely unintended.
            raise HTTPException(status_code=400, detail="没有找到有效消息 (after processing for prompt)")
        # Record the incoming messages into per-session history (before sending).
        if config['keep_conversation_history'] and session_id:
            if session_id not in conversation_history:
                conversation_history[session_id] = []
            for msg_hist_entry in processed_messages_for_history:
                hist_role = msg_hist_entry["role"]
                hist_content_original = msg_hist_entry["content"]  # str or list
                content_to_store_in_history = hist_content_original
                # Filter non-assistant messages only; assistant replies are
                # filtered when they are received below.
                if config['filter_thinking_vessel'] and hist_role != "assistant":
                    if isinstance(hist_content_original, str):
                        content_to_store_in_history = filter_thinking_vessel(hist_content_original)
                    elif isinstance(hist_content_original, list):
                        filtered_list_parts = []
                        for part in hist_content_original:
                            if isinstance(part, dict) and part.get("type") == "text":
                                filtered_text = filter_thinking_vessel(part.get("text", ""))
                                filtered_list_parts.append({"type": "text", "text": filtered_text})
                            else:
                                filtered_list_parts.append(part)  # keep non-text parts untouched
                        content_to_store_in_history = filtered_list_parts
                # Cheap de-duplication: skip if identical to the last stored entry.
                is_duplicate = False
                if conversation_history[session_id]:
                    last_entry = conversation_history[session_id][-1]
                    if last_entry["role"] == hist_role and \
                        isinstance(last_entry["content"], type(content_to_store_in_history)) and \
                        last_entry["content"] == content_to_store_in_history:
                        is_duplicate = True
                if not is_duplicate:
                    # Re-check emptiness after filtering, before storing.
                    is_effectively_empty_hist_store = False
                    if isinstance(content_to_store_in_history, str) and not content_to_store_in_history.strip():
                        is_effectively_empty_hist_store = True
                    elif isinstance(content_to_store_in_history, list):
                        has_meaningful_part_hist_store = False
                        for part_hist in content_to_store_in_history:
                            if isinstance(part_hist, dict):
                                if part_hist.get("type") == "text" and part_hist.get("text", "").strip():
                                    has_meaningful_part_hist_store = True; break
                                elif part_hist.get("type") != "text":
                                    has_meaningful_part_hist_store = True; break
                        if not has_meaningful_part_hist_store:
                            is_effectively_empty_hist_store = True
                    if not is_effectively_empty_hist_store:
                        conversation_history[session_id].append({"role": hist_role, "content": content_to_store_in_history})
        # Reuse an existing chat session for session_id, else create one.
        chat_session = None
        if session_id and session_id in sessions:
            logger.info(f"使用现有会话 {session_id}")
            chat_session = sessions[session_id]
            if request.model:
                chat_session.model = request.model
        else:
            logger.info("创建新会话")
            chat_session = client.start_chat(model=request.model or "gemini-pro")
            if session_id:
                logger.info(f"将新会话存储为ID {session_id}")
                sessions[session_id] = chat_session
        request_model_id = request.model or chat_session.model or "gemini-pro"
        # Streaming branch: simulated SSE — fetch the whole reply, then re-chunk it.
        if request.stream:
            logger.info(f"向会话 {chat_session.cid} 发送流式消息 (模拟): {prompt[:50]}...")
            async def stream_generator():
                full_assistant_response_text = ""
                gemini_response_obj = None  # holds the full Gemini response object
                try:
                    # 1. Fetch the complete response.
                    async def get_full_gemini_response():
                        return await chat_session.send_message(prompt, timeout=config['request_timeout'])
                    # retry_request wraps a single retryable request/response
                    # operation; the full reply is fetched here, then chunked.
                    try:
                        gemini_response_obj = await retry_request(get_full_gemini_response)
                        full_assistant_response_text = gemini_response_obj.text
                        logger.info(f"已获取完整响应用于模拟流,长度: {len(full_assistant_response_text)}")
                    except Exception as e_fetch:
                        logger.error(f"获取完整响应以进行模拟流式处理时出错: {e_fetch}", exc_info=True)
                        error_response_chunk = {
                            "id": f"chatcmpl-error-{chat_session.cid}",
                            "object": "chat.completion.chunk",
                            "created": int(asyncio.get_event_loop().time()),
                            "model": request_model_id,
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": {"role": "assistant", "content": f"Error fetching response: {str(e_fetch)}"},
                                    "finish_reason": "error"
                                }
                            ]
                        }
                        yield f"data: {json.dumps(error_response_chunk)}\n\n"
                        yield "data: [DONE]\n\n"
                        # Record the error into history as well.
                        if config['keep_conversation_history'] and session_id:
                            error_content_for_history = f"Error fetching response: {str(e_fetch)}"
                            if config['filter_thinking_vessel']:
                                error_content_for_history = filter_thinking_vessel(error_content_for_history)
                            conversation_history[session_id].append({"role": "assistant", "content": error_content_for_history})
                        return
                    # 2. Emit the reply in fixed-size SSE chunks
                    #    (chunking granularity could be per-word/per-sentence).
                    chunk_size = 20  # characters per simulated chunk
                    for i in range(0, len(full_assistant_response_text), chunk_size):
                        text_chunk = full_assistant_response_text[i:i+chunk_size]
                        stream_response_chunk = {
                            "id": f"chatcmpl-{chat_session.cid}",
                            "object": "chat.completion.chunk",
                            "created": int(asyncio.get_event_loop().time()),
                            "model": request_model_id,
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": {"content": text_chunk},
                                    "finish_reason": None
                                }
                            ]
                        }
                        # Only the first chunk carries the role.
                        if i == 0:
                            stream_response_chunk["choices"][0]["delta"]["role"] = "assistant"
                        yield f"data: {json.dumps(stream_response_chunk)}\n\n"
                        await asyncio.sleep(0.02)  # small delay to imitate a real stream
                    # 3. Final chunk with finish_reason and usage.
                    final_chunk = {
                        "id": f"chatcmpl-{chat_session.cid}",
                        "object": "chat.completion.chunk",
                        "created": int(asyncio.get_event_loop().time()),
                        "model": request_model_id,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {},
                                "finish_reason": "stop"
                            }
                        ],
                        "usage": {
                            # NOTE(review): character counts, not real token counts.
                            "prompt_tokens": len(prompt),
                            "completion_tokens": len(full_assistant_response_text),
                            "total_tokens": len(prompt) + len(full_assistant_response_text)
                        }
                    }
                    yield f"data: {json.dumps(final_chunk)}\n\n"
                    yield "data: [DONE]\n\n"
                    # After streaming, store the (optionally filtered) full reply.
                    if config['keep_conversation_history'] and session_id:
                        assistant_content_for_history = full_assistant_response_text
                        if config['filter_thinking_vessel']:
                            assistant_content_for_history = filter_thinking_vessel(assistant_content_for_history)
                        if assistant_content_for_history.strip():
                            conversation_history[session_id].append({"role": "assistant", "content": assistant_content_for_history})
                    logger.info(f"成功完成模拟流式响应,总长度: {len(full_assistant_response_text)}")
                except Exception as e_stream_outer:
                    # Catch-all for unexpected failures inside the generator itself.
                    logger.error(f"模拟流式处理中发生意外错误: {e_stream_outer}", exc_info=True)
                    # Best effort: push a final error chunk to the client.
                    error_response_chunk = {
                        "id": f"chatcmpl-error-outer-{chat_session.cid}",
                        "object": "chat.completion.chunk",
                        "created": int(asyncio.get_event_loop().time()),
                        "model": request_model_id,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {"role": "assistant", "content": f"Unexpected error during streaming: {str(e_stream_outer)}"},
                                "finish_reason": "error"
                            }
                        ]
                    }
                    yield f"data: {json.dumps(error_response_chunk)}\n\n"
                    yield f"data: [DONE]\n\n"
                    if config['keep_conversation_history'] and session_id:
                        error_content_for_history = f"Unexpected error during streaming: {str(e_stream_outer)}"
                        if config['filter_thinking_vessel']:
                            error_content_for_history = filter_thinking_vessel(error_content_for_history)
                        conversation_history[session_id].append({"role": "assistant", "content": error_content_for_history})
            return StreamingResponse(stream_generator(), media_type="text/event-stream")
        # Non-streaming branch (retry_request used directly).
        else:
            logger.info(f"向会话 {chat_session.cid} 发送非流式消息: {prompt[:50]}...")
            async def send_message_with_session_logic():
                return await chat_session.send_message(prompt, timeout=config['request_timeout'])
            response = await retry_request(send_message_with_session_logic)
            assistant_response_text = response.text
            # Filter thinking/vessel tags from the client-facing reply if enabled.
            if config['filter_thinking_vessel']:
                assistant_response_text_for_completion = filter_thinking_vessel(assistant_response_text)
            else:
                assistant_response_text_for_completion = assistant_response_text
            # Store the assistant reply in history, filtered per config.
            if config['keep_conversation_history'] and session_id:
                assistant_content_for_history = assistant_response_text  # unfiltered original
                if config['filter_thinking_vessel']:
                    assistant_content_for_history = filter_thinking_vessel(assistant_content_for_history)
                if assistant_content_for_history.strip():  # only store non-empty replies
                    conversation_history[session_id].append({"role": "assistant", "content": assistant_content_for_history})
            completion_response = {
                "id": f"chatcmpl-{chat_session.cid}",
                "object": "chat.completion",
                "created": int(asyncio.get_event_loop().time()),
                "model": request_model_id,
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": assistant_response_text_for_completion  # client-facing reply
                        },
                        "finish_reason": "stop"
                    }
                ],
                "usage": {
                    # NOTE(review): character counts, not real token counts.
                    "prompt_tokens": len(prompt),
                    "completion_tokens": len(assistant_response_text_for_completion),
                    "total_tokens": len(prompt) + len(assistant_response_text_for_completion)
                }
            }
            logger.info(f"成功生成响应,长度: {len(assistant_response_text_for_completion)}")
            return completion_response
    except Exception as e:
        # Error triage by message/class name. NOTE(review): HTTPExceptions
        # raised above (e.g. the 400) are also caught here and re-wrapped.
        error_message = str(e)
        error_class = e.__class__.__name__
        if "timeout" in error_message.lower():
            logger.error(f"处理聊天完成请求时超时: {e}", exc_info=True)
            raise HTTPException(
                status_code=500,
                detail="请求处理超时。如果您正在生成长文本或使用高级模型(如带思维链的模型),"
                "可能需要更长的处理时间。请稍后重试,或尝试使用不同的模型。"
            )
        elif error_class == "RemoteProtocolError":
            logger.error(f"服务器连接断开: {e}", exc_info=True)
            raise HTTPException(
                status_code=502,
                detail="Gemini服务器连接断开。这可能是由于服务器负载过高、网络不稳定或请求内容过于复杂。"
                "您可以尝试简化请求、使用不同的模型,或稍后重试。"
            )
        else:
            logger.error(f"处理聊天完成请求时出错: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"处理请求时出错: {str(e)}")
# NOTE(review): no route decorator visible; presumably an @app.post(...) on
# the original — confirm the path before relying on this endpoint.
async def gemini_direct(request: Request):
    """
    Forward a raw prompt straight to Gemini with no format conversion.

    Expects JSON ``{"prompt": str, "model": optional str}``; returns the
    reply text, any generated images, and the model's thoughts.
    """
    global client
    logger.info("收到直接Gemini API请求")
    # Lazily (re)initialise the shared client if it is missing or stopped.
    if not client or not client.running:
        try:
            logger.warning("客户端未初始化或已关闭,尝试重新初始化...")
            client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
            await client.init(
                timeout=config['init_timeout'],
                auto_close=config['auto_close'],
                close_delay=config['close_delay'],
                auto_refresh=config['auto_refresh'],
                refresh_interval=config['refresh_interval']
            )
            logger.info("重新初始化成功")
        except Exception as e:
            logger.error(f"Gemini客户端初始化失败: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Gemini客户端初始化失败: {str(e)}")
    try:
        # Read the request body.
        body = await request.json()
        prompt = body.get("prompt", "")
        model = body.get("model", "")  # optional model override
        if not prompt:
            logger.error("请求中缺少'prompt'字段")
            # NOTE(review): caught by the generic handler below and re-raised as 500.
            raise HTTPException(status_code=400, detail="请求中缺少'prompt'字段")
        # Send to Gemini, wrapped in the retry helper.
        logger.info(f"向Gemini发送直接请求: {prompt[:50]}...")
        async def generate_content_with_retry():
            # Use the explicit model only when one was provided.
            if model:
                return await client.generate_content(prompt, model=model, timeout=config['request_timeout'])
            else:
                return await client.generate_content(prompt, timeout=config['request_timeout'])
        response = await retry_request(generate_content_with_retry)
        # Build the response payload.
        result = {
            "response": response.text,
            "images": [{"url": img.url, "title": img.title} for img in response.images],
            "thoughts": response.thoughts
        }
        logger.info(f"成功生成响应,长度: {len(response.text)}, 图片数量: {len(response.images)}")
        return result
    except Exception as e:
        # Error triage by message/class name (same scheme as chat_completions).
        error_message = str(e)
        error_class = e.__class__.__name__
        if "timeout" in error_message.lower():
            logger.error(f"处理直接Gemini请求时超时: {e}", exc_info=True)
            raise HTTPException(
                status_code=500,
                detail="请求处理超时。如果您正在生成长文本或使用高级模型(如带思维链的模型),"
                "可能需要更长的处理时间。请稍后重试,或尝试使用不同的模型。"
            )
        elif error_class == "RemoteProtocolError":
            logger.error(f"服务器连接断开: {e}", exc_info=True)
            raise HTTPException(
                status_code=502,
                detail="Gemini服务器连接断开。这可能是由于服务器负载过高、网络不稳定或请求内容过于复杂。"
                "您可以尝试简化请求、使用不同的模型,或稍后重试。"
            )
        else:
            logger.error(f"处理直接Gemini请求时出错: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"处理请求时出错: {str(e)}")
# NOTE(review): no route decorator visible; presumably an @app.post(...) on
# the original — confirm the path before relying on this endpoint.
async def gemini_chat(request: Request):
    """
    Create or reuse a Gemini chat session and send one message to it.

    Expects JSON ``{"prompt": str, "session_id": optional str,
    "model": optional str}``; returns the session id, reply text, images,
    and thoughts.
    """
    global client, sessions
    logger.info("收到Gemini聊天请求")
    # Lazily (re)initialise the shared client if it is missing or stopped.
    if not client or not client.running:
        try:
            logger.warning("客户端未初始化或已关闭,尝试重新初始化...")
            client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
            await client.init(
                timeout=config['init_timeout'],
                auto_close=config['auto_close'],
                close_delay=config['close_delay'],
                auto_refresh=config['auto_refresh'],
                refresh_interval=config['refresh_interval']
            )
            logger.info("重新初始化成功")
        except Exception as e:
            logger.error(f"Gemini客户端初始化失败: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Gemini客户端初始化失败: {str(e)}")
    try:
        # Read the request body.
        body = await request.json()
        prompt = body.get("prompt", "")
        session_id = body.get("session_id", "")
        model = body.get("model", "")  # optional model override
        if not prompt:
            logger.error("请求中缺少'prompt'字段")
            # NOTE(review): caught by the generic handler below and re-raised as 500.
            raise HTTPException(status_code=400, detail="请求中缺少'prompt'字段")
        # Fetch an existing session or create a new one.
        chat_session = None
        if session_id and session_id in sessions:
            logger.info(f"使用现有会话 {session_id}")
            chat_session = sessions[session_id]
        if not chat_session:
            logger.info("创建新会话")
            chat_session = client.start_chat(model=model or "gemini-pro")
            if session_id:
                logger.info(f"将新会话存储为ID {session_id}")
                sessions[session_id] = chat_session
        # Send the message, wrapped in the retry helper.
        logger.info(f"向会话发送消息: {prompt[:50]}...")
        if model:  # switch the session's model when explicitly requested
            chat_session.model = model
        async def send_message_with_retry():
            return await chat_session.send_message(prompt, timeout=config['request_timeout'])
        response = await retry_request(send_message_with_retry)
        # Build the response payload.
        result = {
            "session_id": chat_session.cid,
            "response": response.text,
            "images": [{"url": img.url, "title": img.title} for img in response.images],
            "thoughts": response.thoughts
        }
        logger.info(f"成功生成响应,长度: {len(response.text)}, 图片数量: {len(response.images)}")
        return result
    except Exception as e:
        # Error triage by message/class name (same scheme as chat_completions).
        error_message = str(e)
        error_class = e.__class__.__name__
        if "timeout" in error_message.lower():
            logger.error(f"处理Gemini聊天请求时超时: {e}", exc_info=True)
            raise HTTPException(
                status_code=500,
                detail="请求处理超时。如果您正在生成长文本或使用高级模型(如带思维链的模型),"
                "可能需要更长的处理时间。请稍后重试,或尝试使用不同的模型。"
            )
        elif error_class == "RemoteProtocolError":
            logger.error(f"服务器连接断开: {e}", exc_info=True)
            raise HTTPException(
                status_code=502,
                detail="Gemini服务器连接断开。这可能是由于服务器负载过高、网络不稳定或请求内容过于复杂。"
                "您可以尝试简化请求、使用不同的模型,或稍后重试。"
            )
        else:
            logger.error(f"处理Gemini聊天请求时出错: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"处理请求时出错: {str(e)}")
# Retry helper shared by all endpoints.
async def retry_request(func, *args, **kwargs):
    """Execute *func* with retry semantics driven by the global `config`.

    Two mechanisms are layered:
    - "long-response mode": a RemoteProtocolError whose message contains
      "Server disconnected" is treated as a long-running generation; wait
      `long_response_wait` seconds and re-call, up to
      `max_long_response_retries` times, without consuming normal retries.
    - normal retries: exceptions whose class name appears in
      `retry_exceptions` are retried up to `max_retries` times with
      exponential backoff (`retry_delay * 2**(attempt-1)`).

    Parameters
    ----------
    func : async callable to execute
    *args, **kwargs : forwarded to func

    Returns
    -------
    Whatever func returns.

    Raises
    ------
    The last exception when the error is non-retryable or retries are
    exhausted.
    """
    max_retries = config["max_retries"]
    retry_delay = config["retry_delay"]
    retry_exceptions = config["retry_exceptions"]
    long_response_mode = config["long_response_mode"]
    long_response_wait = config["long_response_wait"]
    last_exception = None
    retry_count = 0
    # Local counter replaces the old setattr(func, '_long_retry_count', ...)
    # hack, which mutated the callable (AttributeError on bound methods and
    # stale state if the same callable were ever reused).
    long_retry_count = 0
    while retry_count <= max_retries:
        try:
            if retry_count > 0:
                logger.info(f"第{retry_count}次重试...")
            return await func(*args, **kwargs)
        except Exception as e:
            exception_name = e.__class__.__name__
            error_message = str(e)
            logger.warning(f"请求异常: {exception_name} - {error_message}")
            last_exception = e
            # Special case: "Server disconnected" may just mean the reply is
            # taking a long time — wait instead of hammering the server.
            if exception_name == "RemoteProtocolError" and "Server disconnected" in error_message and long_response_mode:
                long_retry_count += 1
                max_long_retries = config.get("max_long_response_retries", 3)
                if long_retry_count <= max_long_retries:
                    logger.info(f"检测到服务器断开连接,可能是长响应导致。启用长响应模式,等待 {long_response_wait} 秒... (第 {long_retry_count}/{max_long_retries} 次尝试)")
                    # Long-response attempts do not consume normal retries.
                    await asyncio.sleep(long_response_wait)
                    logger.info("尝试重新连接并获取响应...")
                    try:
                        return await func(*args, **kwargs)
                    except Exception as retry_error:
                        logger.warning(f"重新连接失败: {str(retry_error)}")
                        continue
                else:
                    # Long-response budget exhausted: fall back to normal retries.
                    logger.warning(f"长响应模式重试次数已达到最大值 {max_long_retries},切换回普通重试模式")
                    long_retry_count = 0
            # Normal retry path for configured retryable exception classes.
            if exception_name in retry_exceptions and retry_count < max_retries:
                retry_count += 1
                wait_time = retry_delay * (2 ** (retry_count - 1))  # exponential backoff
                logger.info(f"等待 {wait_time} 秒后重试...")
                await asyncio.sleep(wait_time)
            else:
                # Non-retryable, or retries exhausted: propagate.
                raise
    # Defensive: normally unreachable (the final failure re-raises above).
    raise last_exception
if __name__ == "__main__":
    # Parse CLI/env options and fold them into the shared config dict.
    args = parse_args()
    config.update({
        "init_timeout": args.init_timeout,
        "request_timeout": args.request_timeout,
        "auto_close": args.auto_close,
        "close_delay": args.close_delay,
        "auto_refresh": args.auto_refresh,
        "refresh_interval": args.refresh_interval,
        # Retry settings
        "max_retries": args.max_retries,
        "retry_delay": args.retry_delay,
        "retry_exceptions": args.retry_exceptions.split(','),
        # Long-response mode settings
        "long_response_mode": args.long_response_mode,
        "long_response_wait": args.long_response_wait,
        "max_long_response_retries": args.max_long_response_retries,
        # Conversation-history settings
        "keep_conversation_history": args.keep_conversation_history,
        "filter_thinking_vessel": args.filter_thinking_vessel
    })
    logger.info(f"启动Gemini API代理服务器,监听地址: {args.host}:{args.port}")
    logger.info(f"客户端配置: 初始化超时={config['init_timeout']}秒, 请求超时={config['request_timeout']}秒")
    logger.info(f"自动关闭={config['auto_close']}, 关闭延迟={config['close_delay']}秒")
    logger.info(f"自动刷新={config['auto_refresh']}, 刷新间隔={config['refresh_interval']}秒")
    logger.info(f"重试配置: 最大重试次数={config['max_retries']}, 重试间隔={config['retry_delay']}秒")
    logger.info(f"重试异常类型: {', '.join(config['retry_exceptions'])}")
    logger.info(f"长响应模式: 启用={config['long_response_mode']}, 等待时间={config['long_response_wait']}秒, 最大重试次数={config['max_long_response_retries']}")
    logger.info(f"对话历史配置: 保存历史={config['keep_conversation_history']}, 过滤thinking和vessel={config['filter_thinking_vessel']}")
    # NOTE(review): the app is passed as an import string, so this module must
    # be importable as `gemini_proxy_server` for uvicorn to resolve it.
    uvicorn.run("gemini_proxy_server:app", host=args.host, port=args.port)