Gem2a / app /api /gemini_proxy.py
misonL's picture
Feat: Adapt Gemini chat history and improve session cleanup
a43b63a
Raw
History Blame Contribute Delete
14.3 kB
# app/api/gemini_proxy.py
import os # 导入 os 模块用于读取环境变量
import time
import base64
import logging
import asyncio
import tempfile # 导入 tempfile 模块用于创建临时文件
from fastapi import APIRouter, Request, HTTPException, Depends, status
from fastapi.responses import JSONResponse
from typing import Dict, Any, List, Optional
# 导入自定义模块
from app.api.auth import (
get_user_api_key,
get_admin_api_key,
get_auth_token,
) # 从 auth 模块导入认证依赖
from app.api.metrics import (
update_metrics_on_request,
update_metrics_on_response,
get_current_metrics,
) # 从 metrics 模块导入指标和记录获取函数
from app.core.gemini_client_manager import get_gemini_client, reload_gemini_cookies # 从 core 模块导入 Gemini 客户端管理器
from app.core.session_manager import (
Session,
get_or_create_session,
get_session,
delete_session_by_id,
list_all_sessions,
cleanup_expired_sessions,
) # 从 core 模块导入会话管理器
from app.services.image_proxy_service import proxy_image_request # 从 services 模块导入图片代理服务
from app.utils.gemini_converter import format_gemini_response_to_openai, convert_openai_messages_to_gemini_history # 从 utils 模块导入格式转换工具
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router instance that every endpoint in this module registers itself on.
router = APIRouter()
# Admin-only endpoint that reloads the Gemini cookies.
@router.post("/refresh_cookies", summary="刷新 Gemini Cookies (管理员权限)")
async def refresh_gemini_cookies_endpoint(admin_token: str = Depends(get_admin_api_key)):
    """Reload the Gemini cookies on behalf of an administrator.

    Access is gated by the ``get_admin_api_key`` dependency; the resolved
    token is only needed for that check and is deliberately not used here.

    Per the original author's notes, ``reload_gemini_cookies()`` re-reads
    GEMINI_PSID_COOKIES and GEMINI_PSIDTS_COOKIES from the environment and
    drops all active GeminiClient instances so they are re-initialised on
    their next request (behaviour lives in the client manager module).

    Returns:
        JSONResponse with ``status`` "ok" and a confirmation message.
    """
    # Security: never write the admin credential itself into the log stream.
    logger.info("管理员请求刷新 Gemini Cookies。")
    reload_gemini_cookies()
    return JSONResponse(
        content={
            "status": "ok",
            "message": "Gemini Cookies 已重新加载,所有客户端实例已清空。",
        }
    )
@router.get("/image_proxy", summary="图片代理接口")
async def image_proxy_endpoint(image_url: str, request: Request):
    """Fetch an image through this server on the caller's behalf.

    Intended to work around browser cross-origin and hot-link protection
    restrictions. All of the work is delegated to ``proxy_image_request``.

    Args:
        image_url: URL of the image to fetch (query parameter).
        request: The incoming request, forwarded to the proxy service.

    Returns:
        Whatever response object ``proxy_image_request`` produces.
    """
    proxied_response = await proxy_image_request(image_url, request)
    return proxied_response
@router.post("/sessions", summary="创建新会话")
async def create_session_endpoint(auth_token: str = Depends(get_auth_token)):
    """Create a new chat session and return its identifier.

    Expired sessions belonging to the token are cleaned up first, then a
    Gemini client is obtained and a fresh session is created for it.

    Returns:
        JSONResponse with the new session's ``thread_id`` and ``name``.

    Raises:
        HTTPException: 500 when the Gemini client cannot be obtained. An
            HTTPException raised by the client manager itself is passed
            through unchanged.
    """
    cleanup_expired_sessions(auth_token)

    try:
        client = await get_gemini_client(auth_token)
    except HTTPException:
        # Keep a deliberate HTTP status from the client manager instead of
        # flattening it into a generic 500.
        raise
    except Exception as e:
        # Record the root cause; previously the failure left no log trace.
        logger.error("获取 Gemini 客户端失败: %s", e, exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e

    # Passing None as thread_id asks the session manager for a new session.
    session = get_or_create_session(auth_token, None, client)
    return JSONResponse({
        "thread_id": session.thread_id,
        "name": session.name,
    })
@router.get("/sessions", summary="获取会话列表")
async def list_sessions_endpoint(auth_token: str = Depends(get_auth_token)):
    """Return the sessions that belong to the authenticated token.

    The listing comes straight from ``list_all_sessions``; according to the
    original author's note that call also prunes stale sessions first.
    """
    return JSONResponse(list_all_sessions(auth_token))
@router.get("/sessions/{thread_id}/history", summary="获取会话历史")
async def get_session_history(thread_id: str, auth_token: str = Depends(get_auth_token)):
    """Return the message history of one session.

    Each entry of the chat instance's ``history`` becomes
    ``{"role": ..., "parts": [{"text": ...}]}``. Roles are not stored on the
    entries, so they are assumed to alternate strictly, starting with
    "user". Images are not included: the history does not expose their URLs.

    Args:
        thread_id: Identifier of the session to read (path parameter).
        auth_token: Token resolved by the ``get_auth_token`` dependency.

    Returns:
        JSONResponse with ``thread_id`` and ``history`` (possibly empty).

    Raises:
        HTTPException: 404 when the session does not exist or has no chat
            instance yet.
    """
    session = get_session(auth_token, thread_id)
    if not session:
        logger.warning(f"认证令牌 {auth_token} 请求会话历史,但会话 {thread_id} 未找到。")
        raise HTTPException(status_code=404, detail="会话未找到")

    chat = session._chat_instance
    if not chat:
        logger.warning(f"认证令牌 {auth_token} 请求会话历史,会话 {thread_id} 存在但聊天实例未初始化。")
        raise HTTPException(status_code=404, detail="会话聊天实例未初始化")

    history: List[Dict[str, Any]] = []
    raw_history = getattr(chat, "history", None)
    if raw_history:
        for index, part in enumerate(raw_history):
            # Even positions are taken to be user turns, odd ones model turns.
            role = "user" if index % 2 == 0 else "model"
            text = getattr(part, "text", None)
            history.append({
                "role": role,
                # A missing or None text becomes "" so the JSON never
                # carries a null where clients expect a string.
                "parts": [{"text": "" if text is None else text}],
            })
        logger.debug(f"认证令牌 {auth_token} 会话 {thread_id} 历史消息数量: {len(history)}")
    else:
        logger.warning(f"认证令牌 {auth_token} 会话 {thread_id} 的聊天实例没有 'history' 属性或历史为空。")

    return JSONResponse({
        "thread_id": thread_id,
        "history": history,
    })
@router.delete("/sessions/{thread_id}", summary="删除会话")
async def delete_session_endpoint(thread_id: str, auth_token: str = Depends(get_auth_token)):
    """Remove one session owned by the authenticated token.

    Returns:
        JSONResponse ``{"status": "deleted"}`` when the session was removed.

    Raises:
        HTTPException: 404 when ``delete_session_by_id`` reports that
            nothing was deleted.
    """
    was_deleted = delete_session_by_id(auth_token, thread_id)
    if not was_deleted:
        raise HTTPException(status_code=404, detail="Session not found")
    return JSONResponse({"status": "deleted"})
def _extract_prompt_text(content: Any) -> str:
    """Flatten an OpenAI message ``content`` value into one prompt string.

    A plain string is returned unchanged. For a list, only well-formed
    ``{"type": "text", "text": ...}`` items are used: each text is stripped,
    blanks are dropped and the rest are joined with newlines. Image items
    are ignored because images arrive separately as base64 in the body.

    Raises:
        HTTPException: 400 when ``content`` is neither a string nor a list.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="用户消息内容必须是字符串或列表",
        )
    text_parts: List[str] = []
    for item in content:
        # Skip malformed entries instead of crashing on ``item.get``.
        if not isinstance(item, dict) or item.get("type") != "text":
            continue
        text = item.get("text")
        if isinstance(text, str) and text.strip():
            text_parts.append(text.strip())
    return "\n".join(text_parts)


def _save_image_to_temp_file(image_data_b64: str, image_mime_type: str) -> str:
    """Decode a base64 image and store it in a temp file; return its path.

    The file suffix comes from the MIME subtype, reduced to characters that
    are safe in a file name. The caller must delete the returned file. If
    writing fails, the partially created file is removed before the error
    propagates.
    """
    image_bytes = base64.b64decode(image_data_b64)
    subtype = image_mime_type.split("/")[-1]
    extension = "".join(ch for ch in subtype if ch.isalnum() or ch in "+-.") or "bin"
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}")
    try:
        with temp_file:
            temp_file.write(image_bytes)
    except Exception:
        # Best-effort cleanup; the original write error is what matters.
        try:
            os.remove(temp_file.name)
        except OSError:
            pass
        raise
    return temp_file.name


@router.post("/chat/completions", summary="OpenAI 兼容聊天完成接口")
async def create_chat_completion(request: Request, auth_token: str = Depends(get_auth_token)):
    """Accept an OpenAI-style chat completion request and proxy it to Gemini.

    Body fields read: ``model``, ``messages`` (required; the last entry must
    be a user message), ``temperature``, ``max_tokens``, ``top_p``,
    ``thread_id``, ``image_data`` (base64) and ``image_mime_type``.

    Only the last user message is sent; earlier turns are expected to live
    in the session's chat instance. Sending is attempted up to three times
    with a five second pause; if every attempt fails the session is deleted.

    Returns:
        JSONResponse with the OpenAI-formatted reply plus ``thread_id``.

    Raises:
        HTTPException: 400 for a malformed request; 500 when Gemini cannot
            be reached, returns nothing, or an unexpected error occurs.
    """
    start_time = time.time()
    # request.client can be None (e.g. some ASGI servers / test transports).
    client_host = request.client.host if request.client else "unknown"
    request_body: Dict[str, Any] = {}  # Stays {} for metrics if parsing fails.
    temp_file_path: Optional[str] = None

    update_metrics_on_request(auth_token, client_host)

    try:
        # A broken body is the caller's fault: answer 400, not 500.
        try:
            parsed_body = await request.json()
        except ValueError as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="请求体不是合法的 JSON",
            ) from e
        if not isinstance(parsed_body, dict):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="请求体必须是 JSON 对象",
            )
        request_body = parsed_body

        client = await get_gemini_client(auth_token)

        model_name = request_body.get("model", "gemini-2.5-flash")
        model_mapping = {
            "gemini-2.5-flash-preview-05-20": "gemini-2.5-flash",
            "gemini-pro": "gemini-2.5-pro",
        }
        # NOTE(review): the mapped model name is computed but never passed
        # to the session or to send_message - confirm whether it should be.
        model_name = model_mapping.get(model_name, model_name)

        messages = request_body.get("messages")
        temperature = request_body.get("temperature")
        max_tokens = request_body.get("max_tokens")
        top_p = request_body.get("top_p")
        thread_id = request_body.get("thread_id")

        if not messages:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail="缺少 messages 参数"
            )
        if not isinstance(messages, list) or not isinstance(messages[-1], dict):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="messages 必须是消息对象列表",
            )

        # Only the trailing user message is forwarded as the current input.
        current_user_message = messages[-1]
        if current_user_message.get("role") != "user":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="最后一条消息必须是用户消息",
            )
        prompt = _extract_prompt_text(current_user_message.get("content"))

        # Optional image, sent by the front end as base64 plus a MIME type.
        image_data_b64 = request_body.get("image_data")
        image_mime_type = request_body.get("image_mime_type")
        if image_data_b64 and image_mime_type:
            try:
                temp_file_path = _save_image_to_temp_file(image_data_b64, image_mime_type)
                logger.info(f"临时图片文件已保存: {temp_file_path}")
            except Exception as e:
                logger.error(f"解码或保存图片失败: {e}", exc_info=True)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"图片数据处理失败: {e}",
                ) from e

        # Fail fast rather than burn three retries on an empty request.
        if not prompt and not temp_file_path:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户消息内容不能为空",
            )

        # Create/fetch the session only after validation so that rejected
        # requests do not leave orphan sessions behind.
        session = get_or_create_session(auth_token, thread_id, client)
        chat = session._chat_instance

        # Loop-invariant call arguments, built once.
        # NOTE(review): whether send_message accepts these sampling keyword
        # arguments depends on the gemini-webapi version - confirm.
        send_kwargs: Dict[str, Any] = {}
        if temperature is not None:
            send_kwargs["temperature"] = temperature
        if max_tokens is not None:
            send_kwargs["max_tokens"] = max_tokens
        if top_p is not None:
            send_kwargs["top_p"] = top_p
        if temp_file_path:
            send_kwargs["files"] = [temp_file_path]

        gemini_response = None
        max_retries = 3  # Total number of attempts.
        retry_delay = 5  # Seconds to wait between attempts.
        for attempt in range(max_retries):
            try:
                gemini_response = await chat.send_message(prompt=prompt, **send_kwargs)
                break
            except Exception as e:  # Any failure, including timeouts.
                if attempt < max_retries - 1:
                    logger.warning(
                        f"认证令牌 {auth_token} 会话 {session.thread_id} 调用 Gemini API 失败 (尝试 {attempt + 1}/{max_retries}),将在 {retry_delay} 秒后重试: {e}"
                    )
                    await asyncio.sleep(retry_delay)
                    continue
                logger.error(
                    f"认证令牌 {auth_token} 会话 {session.thread_id} 调用 Gemini API 失败,重试次数耗尽: {e}",
                    exc_info=True,
                )
                # Drop the session so a broken chat instance is not reused.
                delete_session_by_id(auth_token, session.thread_id)
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Gemini API 调用失败,重试次数耗尽: {e}",
                ) from e

        if not gemini_response:
            logger.error(f"认证令牌 {auth_token} 会话 {session.thread_id} 调用 Gemini API 未返回响应,尽管没有捕获到异常。")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Gemini API 返回空响应",
            )

        openai_response = await format_gemini_response_to_openai(gemini_response)

        response_time = time.time() - start_time
        update_metrics_on_response(auth_token, "success", response_time, client_host, request_body, openai_response)

        # thread_id is attached after the metrics call, as before.
        openai_response["thread_id"] = session.thread_id
        return JSONResponse(content=openai_response)

    except HTTPException as e:
        update_metrics_on_response(auth_token, "failed", time.time() - start_time, client_host, request_body, {"detail": e.detail})
        raise
    except Exception as e:
        logger.error("处理请求时发生错误。", exc_info=True)
        update_metrics_on_response(auth_token, "failed", time.time() - start_time, client_host, request_body, {"detail": "内部服务器错误"})
        # The client gets a generic message; details stay in the log.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="内部服务器错误,请联系管理员。",
        ) from e
    finally:
        # Always remove the temp image; a failed cleanup must not replace
        # the response or the original exception.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
                logger.info(f"临时图片文件已删除: {temp_file_path}")
            except OSError as cleanup_error:
                logger.warning(f"临时图片文件删除失败: {temp_file_path}: {cleanup_error}")