| |
|
|
| import os |
| import time |
| import base64 |
| import logging |
| import asyncio |
| import tempfile |
| from fastapi import APIRouter, Request, HTTPException, Depends, status |
| from fastapi.responses import JSONResponse |
| from typing import Dict, Any, List, Optional |
|
|
| |
| from app.api.auth import ( |
| get_user_api_key, |
| get_admin_api_key, |
| get_auth_token, |
| ) |
| from app.api.metrics import ( |
| update_metrics_on_request, |
| update_metrics_on_response, |
| get_current_metrics, |
| ) |
| from app.core.gemini_client_manager import get_gemini_client, reload_gemini_cookies |
| from app.core.session_manager import ( |
| Session, |
| get_or_create_session, |
| get_session, |
| delete_session_by_id, |
| list_all_sessions, |
| cleanup_expired_sessions, |
| ) |
| from app.services.image_proxy_service import proxy_image_request |
| from app.utils.gemini_converter import format_gemini_response_to_openai, convert_openai_messages_to_gemini_history |
|
|
# Logger for this module, named after the module path.
logger = logging.getLogger(name=__name__)

# Router that the endpoint functions below register themselves on
# (presumably included into the FastAPI app elsewhere).
router = APIRouter()
|
|
| |
@router.post("/refresh_cookies", summary="刷新 Gemini Cookies (管理员权限)")
async def refresh_gemini_cookies_endpoint(admin_token: str = Depends(get_admin_api_key)):
    """Admin endpoint that reloads the Gemini cookies.

    Delegates to ``reload_gemini_cookies``. According to the original
    description, that re-reads GEMINI_PSID_COOKIES / GEMINI_PSIDTS_COOKIES from
    the environment and drops every active GeminiClient, so they re-initialise
    on the next request (that behaviour lives in the helper, not here).

    Args:
        admin_token: Resolved by the ``get_admin_api_key`` dependency. The
            parameter exists so FastAPI enforces the admin check. Its value is
            deliberately never written to the log.

    Returns:
        JSONResponse with ``{"status": "ok", "message": ...}``.
    """
    # Log only that the action happened. Logging the admin key itself would
    # leak the credential to anyone with log access.
    logger.info("管理员请求刷新 Gemini Cookies。")
    reload_gemini_cookies()
    return JSONResponse(content={"status": "ok", "message": "Gemini Cookies 已重新加载,所有客户端实例已清空。"})
|
|
@router.get("/image_proxy", summary="图片代理接口")
async def image_proxy_endpoint(image_url: str, request: Request):
    """Download an image on the front end's behalf.

    Works around cross-origin and hot-link restrictions. All of the work is done
    by ``proxy_image_request``, and its result is returned unchanged.

    NOTE(review): this route has no auth dependency and fetches a
    client-supplied URL. Confirm that proxy_image_request restricts the allowed
    targets (SSRF).
    """
    proxied = await proxy_image_request(image_url, request)
    return proxied
|
|
@router.post("/sessions", summary="创建新会话")
async def create_session_endpoint(auth_token: str = Depends(get_auth_token)):
    """Create a session for the caller and return its id and name.

    Expired sessions belonging to ``auth_token`` are purged first.

    Returns:
        JSONResponse ``{"thread_id": ..., "name": ...}``.

    Raises:
        HTTPException: 500 if a Gemini client cannot be obtained for the token.
    """
    cleanup_expired_sessions(auth_token)

    try:
        client = await get_gemini_client(auth_token)
    except Exception as e:
        # Previously the failure was only echoed to the caller, and the server
        # kept no trace of it.
        logger.exception("创建会话时获取 Gemini 客户端失败。")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e

    # No thread_id is passed, so (presumably) a fresh session is created.
    session = get_or_create_session(auth_token, None, client)

    return JSONResponse({
        "thread_id": session.thread_id,
        "name": session.name,
    })
|
|
@router.get("/sessions", summary="获取会话列表")
async def list_sessions_endpoint(auth_token: str = Depends(get_auth_token)):
    """Return the sessions owned by the caller's auth token.

    The payload is whatever ``list_all_sessions`` produces, serialised as JSON
    without modification.
    """
    return JSONResponse(list_all_sessions(auth_token))
|
|
@router.get("/sessions/{thread_id}/history", summary="获取会话历史")
async def get_session_history(thread_id: str, auth_token: str = Depends(get_auth_token)):
    """Return the message history of one of the caller's sessions.

    Args:
        thread_id: Id of the session to read.
        auth_token: Caller's token, resolved by ``get_auth_token``.

    Returns:
        JSONResponse ``{"thread_id": ..., "history": [...]}``. Each entry has the
        shape ``{"role": "user" | "model", "parts": [{"text": ...}]}``. The list
        is empty when the chat instance has no (or an empty) ``history``.

    Raises:
        HTTPException: 404 if the session does not exist or has no chat instance.
    """
    # Only a short prefix of the credential ever reaches the log.
    token_hint = f"{auth_token[:4]}***" if auth_token and len(auth_token) > 8 else "***"

    session = get_session(auth_token, thread_id)
    if not session:
        logger.warning(f"认证令牌 {token_hint} 请求会话历史,但会话 {thread_id} 未找到。")
        raise HTTPException(status_code=404, detail="会话未找到")

    # _chat_instance is a private attribute of Session; it is read directly here.
    chat = session._chat_instance
    if not chat:
        logger.warning(f"认证令牌 {token_hint} 请求会话历史,会话 {thread_id} 存在但聊天实例未初始化。")
        raise HTTPException(status_code=404, detail="会话聊天实例未初始化")

    raw_history = getattr(chat, "history", None)
    if not raw_history:
        logger.warning(f"认证令牌 {token_hint} 会话 {thread_id} 的聊天实例没有 'history' 属性或历史为空。")
        return JSONResponse({"thread_id": thread_id, "history": []})

    # NOTE(review): roles are inferred purely from position (even index = user,
    # odd index = model). This is only correct if the history strictly
    # alternates starting with the user. Confirm against the chat client's
    # history format.
    history = [
        {
            "role": "user" if index % 2 == 0 else "model",
            "parts": [{"text": getattr(entry, "text", "")}],
        }
        for index, entry in enumerate(raw_history)
    ]
    logger.debug(f"认证令牌 {token_hint} 会话 {thread_id} 历史消息数量: {len(history)}")

    return JSONResponse({
        "thread_id": thread_id,
        "history": history,
    })
|
|
@router.delete("/sessions/{thread_id}", summary="删除会话")
async def delete_session_endpoint(thread_id: str, auth_token: str = Depends(get_auth_token)):
    """Delete one of the caller's sessions.

    Returns ``{"status": "deleted"}`` when ``delete_session_by_id`` reports
    success (truthy result).

    Raises:
        HTTPException: 404 when the deletion reports nothing was removed.
    """
    deleted = delete_session_by_id(auth_token, thread_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Session not found")
    return JSONResponse({"status": "deleted"})
|
|
# Model names accepted from OpenAI-style clients, mapped to the internal name.
_MODEL_ALIASES: Dict[str, str] = {
    "gemini-2.5-flash-preview-05-20": "gemini-2.5-flash",
    "gemini-pro": "gemini-2.5-pro",
}

# Optional OpenAI sampling parameters. Each is forwarded to send_message only
# when the request supplies a non-None value.
_SAMPLING_PARAMS = ("temperature", "max_tokens", "top_p")

# Retry policy for the upstream Gemini call.
_MAX_RETRIES = 3
_RETRY_DELAY_SECONDS = 5


def _mask_token(token: Optional[str]) -> str:
    """Return a log-safe hint for a credential (at most its first 4 characters)."""
    if not token or len(token) <= 8:
        return "***"
    return f"{token[:4]}***"


def _bad_request(detail: str) -> HTTPException:
    """Build a 400 HTTPException carrying ``detail``."""
    return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)


def _extract_prompt(content: Any) -> str:
    """Turn an OpenAI message ``content`` into a single prompt string.

    A plain string is returned unchanged. For a list of content parts, the
    ``"text"`` parts are stripped, empty ones are dropped, and the rest are
    joined with newlines. Other part types (e.g. ``image_url``) are ignored.
    A list always produces a str, never a list, so the upstream client sees one
    prompt type.

    Raises:
        HTTPException: 400 if ``content`` is neither str nor list, or a list
            item is not an object with string text.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        raise _bad_request("用户消息内容必须是字符串或列表")

    texts: List[str] = []
    for item in content:
        if not isinstance(item, dict):
            raise _bad_request("消息内容格式无效")
        if item.get("type") != "text":
            continue
        text = item.get("text") or ""
        if not isinstance(text, str):
            raise _bad_request("消息内容格式无效")
        stripped = text.strip()
        if stripped:
            texts.append(stripped)
    return "\n".join(texts)


def _remove_temp_file(path: Optional[str]) -> None:
    """Delete ``path`` if it exists. Best-effort: never raises."""
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        return
    except OSError as e:
        # Cleanup must not mask the request's real outcome.
        logger.warning(f"删除临时图片文件失败: {path}: {e}")
        return
    logger.info(f"临时图片文件已删除: {path}")


def _write_temp_image(image_data_b64: Any, image_mime_type: Any) -> str:
    """Decode a base64 image and save it to a new temporary file.

    Returns:
        Path of the file. The caller owns it and must delete it.

    Raises:
        HTTPException: 400 if decoding or writing fails. Any partially written
            file is removed before raising.
    """
    temp_path: Optional[str] = None
    try:
        image_bytes = base64.b64decode(image_data_b64)
        # The extension comes from client-supplied data. Keep only characters
        # that cannot act as a path separator.
        raw_ext = image_mime_type.split("/")[-1]
        safe_ext = "".join(ch for ch in raw_ext if ch.isalnum() or ch in "+-")
        suffix = f".{safe_ext}" if safe_ext else ""
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path before writing so a failed write is still cleaned up.
            temp_path = temp_file.name
            temp_file.write(image_bytes)
    except Exception as e:
        logger.error(f"解码或保存图片失败: {e}", exc_info=True)
        _remove_temp_file(temp_path)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"图片数据处理失败: {e}",
        ) from e
    logger.info(f"临时图片文件已保存: {temp_path}")
    return temp_path


async def _send_with_retries(
    chat: Any,
    prompt: str,
    send_kwargs: Dict[str, Any],
    auth_token: str,
    thread_id: str,
    max_retries: int = _MAX_RETRIES,
    retry_delay: float = _RETRY_DELAY_SECONDS,
) -> Any:
    """Call ``chat.send_message``, retrying on any exception.

    Sleeps ``retry_delay`` seconds between attempts. If every attempt fails,
    the session is deleted and HTTPException 500 is raised. Returns None only
    when ``max_retries`` <= 0.
    """
    token_hint = _mask_token(auth_token)
    for attempt in range(max_retries):
        try:
            return await chat.send_message(prompt=prompt, **send_kwargs)
        except Exception as e:
            if attempt < max_retries - 1:
                logger.warning(
                    f"认证令牌 {token_hint} 会话 {thread_id} 调用 Gemini API 失败 (尝试 {attempt + 1}/{max_retries}),将在 {retry_delay} 秒后重试: {e}"
                )
                await asyncio.sleep(retry_delay)
                continue
            logger.error(
                f"认证令牌 {token_hint} 会话 {thread_id} 调用 Gemini API 失败,重试次数耗尽: {e}",
                exc_info=True,
            )
            # A session whose upstream calls keep failing is discarded.
            delete_session_by_id(auth_token, thread_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Gemini API 调用失败,重试次数耗尽: {e}",
            ) from e
    return None


@router.post("/chat/completions", summary="OpenAI 兼容聊天完成接口")
async def create_chat_completion(request: Request, auth_token: str = Depends(get_auth_token)):
    """Accept an OpenAI-format chat completion request and proxy it to Gemini.

    Only the last message is sent; it must have role ``user``. An optional
    image may be attached via the non-standard ``image_data`` (base64) and
    ``image_mime_type`` body fields. The optional ``thread_id`` field selects
    an existing session.

    Returns:
        The OpenAI-format response produced by
        ``format_gemini_response_to_openai``, with ``thread_id`` added.

    Raises:
        HTTPException: 400 for malformed input, 500 for upstream/internal
            failures. Metrics are recorded for every outcome.
    """
    start_time = time.time()
    # request.client can be None (e.g. when the ASGI scope has no client).
    client_host = request.client.host if request.client else "unknown"
    token_hint = _mask_token(auth_token)
    request_body: Dict[str, Any] = {}
    temp_file_path: Optional[str] = None

    update_metrics_on_request(auth_token, client_host)

    try:
        # --- Parse and validate input before touching any external state ---
        try:
            parsed_body = await request.json()
        except ValueError as e:  # includes json.JSONDecodeError / UnicodeDecodeError
            raise _bad_request("请求体必须是有效的 JSON 对象") from e
        if not isinstance(parsed_body, dict):
            raise _bad_request("请求体必须是有效的 JSON 对象")
        request_body = parsed_body

        model_name = request_body.get("model", "gemini-2.5-flash")
        model_name = _MODEL_ALIASES.get(model_name, model_name)
        # NOTE(review): model_name is resolved but never passed to the session
        # or client, so the requested model currently has no effect. Confirm
        # how the model should be selected.
        logger.debug(f"请求模型: {model_name}")

        messages = request_body.get("messages")
        thread_id = request_body.get("thread_id")

        if not messages:
            raise _bad_request("缺少 messages 参数")
        if not isinstance(messages, list):
            raise _bad_request("messages 参数必须是列表")

        current_user_message = messages[-1]
        if not isinstance(current_user_message, dict) or current_user_message.get("role") != "user":
            raise _bad_request("最后一条消息必须是用户消息")

        prompt = _extract_prompt(current_user_message.get("content"))

        image_data_b64 = request_body.get("image_data")
        image_mime_type = request_body.get("image_mime_type")
        has_image = bool(image_data_b64 and image_mime_type)
        # Sending an empty prompt upstream would only burn retries and then
        # delete the session, so reject it here.
        if not prompt.strip() and not has_image:
            raise _bad_request("用户消息内容不能为空")

        send_kwargs: Dict[str, Any] = {}
        for param in _SAMPLING_PARAMS:
            value = request_body.get(param)
            if value is not None:
                send_kwargs[param] = value

        if has_image:
            temp_file_path = _write_temp_image(image_data_b64, image_mime_type)
            send_kwargs["files"] = [temp_file_path]

        # --- Call Gemini ---
        client = await get_gemini_client(auth_token)
        session = get_or_create_session(auth_token, thread_id, client)

        gemini_response = await _send_with_retries(
            session._chat_instance, prompt, send_kwargs, auth_token, session.thread_id
        )

        if not gemini_response:
            logger.error(f"认证令牌 {token_hint} 会话 {session.thread_id} 调用 Gemini API 未返回响应,尽管没有捕获到异常。")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Gemini API 返回空响应",
            )

        # --- Format the response and record success ---
        openai_response = await format_gemini_response_to_openai(gemini_response)

        response_time = time.time() - start_time
        update_metrics_on_response(auth_token, "success", response_time, client_host, request_body, openai_response)

        openai_response["thread_id"] = session.thread_id
        return JSONResponse(content=openai_response)

    except HTTPException as e:
        update_metrics_on_response(auth_token, "failed", time.time() - start_time, client_host, request_body, {"detail": e.detail})
        raise
    except Exception as e:
        logger.exception("处理请求时发生错误。")
        update_metrics_on_response(auth_token, "failed", time.time() - start_time, client_host, request_body, {"detail": "内部服务器错误"})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="内部服务器错误,请联系管理员。",
        ) from e
    finally:
        _remove_temp_file(temp_file_path)
|