Upload 9 files

- api_utils/__init__.py +78 -0
- api_utils/app.py +312 -0
- api_utils/auth_utils.py +32 -0
- api_utils/dependencies.py +57 -0
- api_utils/queue_worker.py +266 -0
- api_utils/request_processor.py +795 -0
- api_utils/request_processor_backup.py +274 -0
- api_utils/routes.py +374 -0
- api_utils/utils.py +372 -0
api_utils/__init__.py
CHANGED
@@ -0,0 +1,78 @@
"""
API utilities module.
Provides FastAPI app initialization, route handlers, and helper functions.
"""

# Application initialization
from .app import (
    create_app
)

# Route handlers
from .routes import (
    read_index,
    get_css,
    get_js,
    get_api_info,
    health_check,
    list_models,
    chat_completions,
    cancel_request,
    get_queue_status,
    websocket_log_endpoint
)

# Utility functions
from .utils import (
    generate_sse_chunk,
    generate_sse_stop_chunk,
    generate_sse_error_chunk,
    use_stream_response,
    clear_stream_queue,
    use_helper_get_response,
    validate_chat_request,
    prepare_combined_prompt,
    estimate_tokens,
    calculate_usage_stats
)

# Request processor
from .request_processor import (
    _process_request_refactored
)

# Queue worker
from .queue_worker import (
    queue_worker
)

__all__ = [
    # Application initialization
    'create_app',
    # Route handlers
    'read_index',
    'get_css',
    'get_js',
    'get_api_info',
    'health_check',
    'list_models',
    'chat_completions',
    'cancel_request',
    'get_queue_status',
    'websocket_log_endpoint',
    # Utility functions
    'generate_sse_chunk',
    'generate_sse_stop_chunk',
    'generate_sse_error_chunk',
    'use_stream_response',
    'clear_stream_queue',
    'use_helper_get_response',
    'validate_chat_request',
    'prepare_combined_prompt',
    'estimate_tokens',
    'calculate_usage_stats',
    # Request processor
    '_process_request_refactored',
    # Queue worker
    'queue_worker'
]
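The package exposes everything through this facade, so a host process only needs create_app. A minimal sketch of such an entry point, assuming uvicorn is installed and the surrounding repo provides the server module that the lifespan hooks import; the host/port values are illustrative, not taken from this commit:

import uvicorn

from api_utils import create_app

# lifespan() handles startup/shutdown (logging, Playwright, queue worker),
# so nothing else needs to run before serving.
app = create_app()

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=2048)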
api_utils/app.py
ADDED
@@ -0,0 +1,312 @@
"""
FastAPI application initialization and lifecycle management.
"""

import asyncio
import multiprocessing
import os
import sys
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from typing import Callable, Awaitable
from playwright.async_api import Browser as AsyncBrowser, Playwright as AsyncPlaywright

# --- config module imports ---
from config import *

# --- models module imports ---
from models import WebSocketConnectionManager

# --- logging_utils module imports ---
from logging_utils import setup_server_logging, restore_original_streams

# --- browser_utils module imports ---
from browser_utils import (
    _initialize_page_logic,
    _close_page_logic,
    load_excluded_models,
    _handle_initial_model_state_and_storage
)

import stream
from asyncio import Queue, Lock
from . import auth_utils

# Global state variables (these are referenced from server.py)
playwright_manager: Optional[AsyncPlaywright] = None
browser_instance: Optional[AsyncBrowser] = None
page_instance = None
is_playwright_ready = False
is_browser_connected = False
is_page_ready = False
is_initializing = False

global_model_list_raw_json = None
parsed_model_list = []
model_list_fetch_event = None

current_ai_studio_model_id = None
model_switching_lock = None

excluded_model_ids = set()

request_queue = None
processing_lock = None
worker_task = None

page_params_cache = {}
params_cache_lock = None

log_ws_manager = None

STREAM_QUEUE = None
STREAM_PROCESS = None

# --- Lifespan Context Manager ---
def _setup_logging():
    import server
    log_level_env = os.environ.get('SERVER_LOG_LEVEL', 'INFO')
    redirect_print_env = os.environ.get('SERVER_REDIRECT_PRINT', 'false')
    server.log_ws_manager = WebSocketConnectionManager()
    return setup_server_logging(
        logger_instance=server.logger,
        log_ws_manager=server.log_ws_manager,
        log_level_name=log_level_env,
        redirect_print_str=redirect_print_env
    )

def _initialize_globals():
    import server
    server.request_queue = Queue()
    server.processing_lock = Lock()
    server.model_switching_lock = Lock()
    server.params_cache_lock = Lock()
    auth_utils.initialize_keys()
    server.logger.info("API keys and global locks initialized.")

def _initialize_proxy_settings():
    import server
    STREAM_PORT = os.environ.get('STREAM_PORT')
    if STREAM_PORT == '0':
        PROXY_SERVER_ENV = os.environ.get('HTTPS_PROXY') or os.environ.get('HTTP_PROXY')
    else:
        PROXY_SERVER_ENV = f"http://127.0.0.1:{STREAM_PORT or 3120}/"

    if PROXY_SERVER_ENV:
        server.PLAYWRIGHT_PROXY_SETTINGS = {'server': PROXY_SERVER_ENV}
        if NO_PROXY_ENV:
            server.PLAYWRIGHT_PROXY_SETTINGS['bypass'] = NO_PROXY_ENV.replace(',', ';')
        server.logger.info(f"Playwright proxy settings configured: {server.PLAYWRIGHT_PROXY_SETTINGS}")
    else:
        server.logger.info("No proxy configured for Playwright.")

async def _start_stream_proxy():
    import server
    STREAM_PORT = os.environ.get('STREAM_PORT')
    if STREAM_PORT != '0':
        port = int(STREAM_PORT or 3120)
        STREAM_PROXY_SERVER_ENV = os.environ.get('UNIFIED_PROXY_CONFIG') or os.environ.get('HTTPS_PROXY') or os.environ.get('HTTP_PROXY')
        server.logger.info(f"Starting STREAM proxy on port {port} with upstream proxy: {STREAM_PROXY_SERVER_ENV}")
        server.STREAM_QUEUE = multiprocessing.Queue()
        server.STREAM_PROCESS = multiprocessing.Process(target=stream.start, args=(server.STREAM_QUEUE, port, STREAM_PROXY_SERVER_ENV))
        server.STREAM_PROCESS.start()
        server.logger.info("STREAM proxy process started.")

async def _initialize_browser_and_page():
    import server
    from playwright.async_api import async_playwright

    server.logger.info("Starting Playwright...")
    server.playwright_manager = await async_playwright().start()
    server.is_playwright_ready = True
    server.logger.info("Playwright started.")

    ws_endpoint = os.environ.get('CAMOUFOX_WS_ENDPOINT')
    launch_mode = os.environ.get('LAUNCH_MODE', 'unknown')

    if not ws_endpoint and launch_mode != "direct_debug_no_browser":
        raise ValueError("CAMOUFOX_WS_ENDPOINT environment variable is missing.")

    if ws_endpoint:
        server.logger.info(f"Connecting to browser at: {ws_endpoint}")
        server.browser_instance = await server.playwright_manager.firefox.connect(ws_endpoint, timeout=30000)
        server.is_browser_connected = True
        server.logger.info(f"Connected to browser: {server.browser_instance.version}")

        server.page_instance, server.is_page_ready = await _initialize_page_logic(server.browser_instance)
        if server.is_page_ready:
            await _handle_initial_model_state_and_storage(server.page_instance)
            server.logger.info("Page initialized successfully.")
        else:
            server.logger.error("Page initialization failed.")

    if not server.model_list_fetch_event.is_set():
        server.model_list_fetch_event.set()

async def _shutdown_resources():
    import server
    logger = server.logger
    logger.info("Shutting down resources...")

    if server.STREAM_PROCESS:
        server.STREAM_PROCESS.terminate()
        logger.info("STREAM proxy terminated.")

    if server.worker_task and not server.worker_task.done():
        server.worker_task.cancel()
        try:
            await asyncio.wait_for(server.worker_task, timeout=5.0)
        except (asyncio.TimeoutError, asyncio.CancelledError):
            pass
        logger.info("Worker task stopped.")

    if server.page_instance:
        await _close_page_logic()

    if server.browser_instance and server.browser_instance.is_connected():
        await server.browser_instance.close()
        logger.info("Browser connection closed.")

    if server.playwright_manager:
        await server.playwright_manager.stop()
        logger.info("Playwright stopped.")

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI application life cycle management."""
    import server
    from server import queue_worker

    original_streams = sys.stdout, sys.stderr
    initial_stdout, initial_stderr = _setup_logging()
    logger = server.logger

    _initialize_globals()
    _initialize_proxy_settings()
    load_excluded_models(EXCLUDED_MODELS_FILENAME)

    server.is_initializing = True
    logger.info("Starting AI Studio Proxy Server...")

    try:
        await _start_stream_proxy()
        await _initialize_browser_and_page()

        launch_mode = os.environ.get('LAUNCH_MODE', 'unknown')
        if server.is_page_ready or launch_mode == "direct_debug_no_browser":
            server.worker_task = asyncio.create_task(queue_worker())
            logger.info("Request processing worker started.")
        else:
            raise RuntimeError("Failed to initialize browser/page, worker not started.")

        logger.info("Server startup complete.")
        server.is_initializing = False
        yield
    except Exception as e:
        logger.critical(f"Application startup failed: {e}", exc_info=True)
        await _shutdown_resources()
        raise RuntimeError(f"Application startup failed: {e}") from e
    finally:
        logger.info("Shutting down server...")
        await _shutdown_resources()
        restore_original_streams(initial_stdout, initial_stderr)
        restore_original_streams(*original_streams)
        logger.info("Server shutdown complete.")


class APIKeyAuthMiddleware(BaseHTTPMiddleware):
    def __init__(self, app: ASGIApp):
        super().__init__(app)
        self.excluded_paths = [
            "/v1/models",
            "/health",
            "/docs",
            "/openapi.json",
            # Other documentation paths auto-generated by FastAPI
            "/redoc",
            "/favicon.ico"
        ]

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable]):
        if not auth_utils.API_KEYS:  # Skip validation entirely when API_KEYS is empty
            return await call_next(request)

        # Only /v1/ paths are protected
        if not request.url.path.startswith("/v1/"):
            return await call_next(request)

        # Skip excluded paths
        for excluded_path in self.excluded_paths:
            if request.url.path == excluded_path or request.url.path.startswith(excluded_path + "/"):
                return await call_next(request)

        # Support multiple auth header formats for OpenAI compatibility
        api_key = None

        # 1. Prefer the standard Authorization: Bearer <token> header
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            api_key = auth_header[7:]  # strip the "Bearer " prefix

        # 2. Fall back to the custom X-API-Key header (backward compatible)
        if not api_key:
            api_key = request.headers.get("X-API-Key")

        if not api_key or not auth_utils.verify_api_key(api_key):
            return JSONResponse(
                status_code=401,
                content={
                    "error": {
                        "message": "Invalid or missing API key. Please provide a valid API key using 'Authorization: Bearer <your_key>' or 'X-API-Key: <your_key>' header.",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key"
                    }
                }
            )
        return await call_next(request)

def create_app() -> FastAPI:
    """Create the FastAPI application instance."""
    app = FastAPI(
        title="AI Studio Proxy Server (integrated mode)",
        description="A proxy server that interacts with AI Studio via Playwright.",
        version="0.6.0-integrated",
        lifespan=lifespan
    )

    # Add middleware
    app.add_middleware(APIKeyAuthMiddleware)

    # Register routes
    from .routes import (
        read_index, get_css, get_js, get_api_info,
        health_check, list_models, chat_completions,
        cancel_request, get_queue_status, websocket_log_endpoint,
        get_api_keys, add_api_key, test_api_key, delete_api_key
    )
    from fastapi.responses import FileResponse

    app.get("/", response_class=FileResponse)(read_index)
    app.get("/webui.css")(get_css)
    app.get("/webui.js")(get_js)
    app.get("/api/info")(get_api_info)
    app.get("/health")(health_check)
    app.get("/v1/models")(list_models)
    app.post("/v1/chat/completions")(chat_completions)
    app.post("/v1/cancel/{req_id}")(cancel_request)
    app.get("/v1/queue")(get_queue_status)
    app.websocket("/ws/logs")(websocket_log_endpoint)

    # API key management endpoints
    app.get("/api/keys")(get_api_keys)
    app.post("/api/keys")(add_api_key)
    app.post("/api/keys/test")(test_api_key)
    app.delete("/api/keys")(delete_api_key)

    return app
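A quick way to exercise APIKeyAuthMiddleware from the client side, using only the standard library. This is a hedged sketch: it assumes the server is listening on localhost:2048 and that key.txt contains my-secret-key, both illustrative values not taken from this commit:

import json
import urllib.request

payload = json.dumps({
    "model": "gemini-2.5-pro",  # illustrative model id
    "messages": [{"role": "user", "content": "ping"}],
    "stream": False,
}).encode()

req = urllib.request.Request(
    "http://127.0.0.1:2048/v1/chat/completions",
    data=payload,
    headers={
        "Content-Type": "application/json",
        # The middleware accepts either header; Bearer is checked first.
        "Authorization": "Bearer my-secret-key",
        # "X-API-Key": "my-secret-key",
    },
)
with urllib.request.urlopen(req) as resp:
    print(resp.status, resp.read()[:200])

A request to a protected /v1/ path without a valid key gets the 401 JSON error body defined in the middleware above.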
api_utils/auth_utils.py
ADDED
@@ -0,0 +1,32 @@
import os
from typing import Set

API_KEYS: Set[str] = set()
KEY_FILE_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "key.txt")

def load_api_keys():
    """Loads API keys from the key file into the API_KEYS set."""
    global API_KEYS
    API_KEYS.clear()
    if os.path.exists(KEY_FILE_PATH):
        with open(KEY_FILE_PATH, "r") as f:
            for line in f:
                key = line.strip()
                if key:
                    API_KEYS.add(key)

def initialize_keys():
    """Initializes API keys. Ensures key.txt exists and loads keys."""
    if not os.path.exists(KEY_FILE_PATH):
        with open(KEY_FILE_PATH, "w") as f:
            pass  # Create an empty file
    load_api_keys()

def verify_api_key(api_key_from_header: str) -> bool:
    """
    Verifies the API key.
    Returns True if API_KEYS is empty (no validation) or if the key is valid.
    """
    if not API_KEYS:
        return True
    return api_key_from_header in API_KEYS
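Note the fail-open semantics: an empty key set disables authentication entirely, which matches the early return in APIKeyAuthMiddleware.dispatch. A small interactive sketch of the verification behavior (initialize_keys creates key.txt at the repo root if it is missing; "my-secret-key" is an illustrative value):

from api_utils import auth_utils

auth_utils.initialize_keys()                       # ensures key.txt exists, then loads it
print(auth_utils.verify_api_key("anything"))       # True while no keys are configured

auth_utils.API_KEYS.add("my-secret-key")           # simulate a configured key in memory
print(auth_utils.verify_api_key("my-secret-key"))  # True
print(auth_utils.verify_api_key("wrong"))          # False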
api_utils/dependencies.py
ADDED
@@ -0,0 +1,57 @@
"""
FastAPI dependency providers module.
"""
import logging
from asyncio import Queue, Lock, Event
from typing import Dict, Any, List, Set

from fastapi import Request

def get_logger() -> logging.Logger:
    from server import logger
    return logger

def get_log_ws_manager():
    from server import log_ws_manager
    return log_ws_manager

def get_request_queue() -> Queue:
    from server import request_queue
    return request_queue

def get_processing_lock() -> Lock:
    from server import processing_lock
    return processing_lock

def get_worker_task():
    from server import worker_task
    return worker_task

def get_server_state() -> Dict[str, Any]:
    from server import is_initializing, is_playwright_ready, is_browser_connected, is_page_ready
    return {
        "is_initializing": is_initializing,
        "is_playwright_ready": is_playwright_ready,
        "is_browser_connected": is_browser_connected,
        "is_page_ready": is_page_ready,
    }

def get_page_instance():
    from server import page_instance
    return page_instance

def get_model_list_fetch_event() -> Event:
    from server import model_list_fetch_event
    return model_list_fetch_event

def get_parsed_model_list() -> List[Dict[str, Any]]:
    from server import parsed_model_list
    return parsed_model_list

def get_excluded_model_ids() -> Set[str]:
    from server import excluded_model_ids
    return excluded_model_ids

def get_current_ai_studio_model_id() -> str:
    from server import current_ai_studio_model_id
    return current_ai_studio_model_id
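These providers re-import from server on every call, so routes always observe the current values rather than import-time snapshots. A sketch of how a route might consume one through FastAPI's dependency injection; the /status route here is illustrative, not part of this commit:

from fastapi import APIRouter, Depends

from api_utils.dependencies import get_server_state

router = APIRouter()

@router.get("/status")
async def status(state: dict = Depends(get_server_state)):
    # Each call re-reads the readiness flags from the server module.
    return state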
api_utils/queue_worker.py
ADDED
@@ -0,0 +1,266 @@
"""
Queue worker module.
Processes tasks from the request queue.
"""

import asyncio
import time
from fastapi import HTTPException



async def queue_worker():
    """Queue worker that processes tasks from the request queue."""
    # Import global state
    from server import (
        logger, request_queue, processing_lock, model_switching_lock,
        params_cache_lock
    )

    logger.info("--- Queue worker started ---")

    # Fall back to local instances if the globals were not initialized
    if request_queue is None:
        logger.info("Initializing request_queue...")
        from asyncio import Queue
        request_queue = Queue()

    if processing_lock is None:
        logger.info("Initializing processing_lock...")
        from asyncio import Lock
        processing_lock = Lock()

    if model_switching_lock is None:
        logger.info("Initializing model_switching_lock...")
        from asyncio import Lock
        model_switching_lock = Lock()

    if params_cache_lock is None:
        logger.info("Initializing params_cache_lock...")
        from asyncio import Lock
        params_cache_lock = Lock()

    was_last_request_streaming = False
    last_request_completion_time = 0

    while True:
        request_item = None
        result_future = None
        req_id = "UNKNOWN"
        completion_event = None
        # Pre-bind the stream handles so the cleanup block below cannot hit a
        # NameError on paths where processing is skipped.
        submit_btn_loc = None
        client_disco_checker = None

        try:
            # Scan queued items and mark requests whose clients have already disconnected
            queue_size = request_queue.qsize()
            if queue_size > 0:
                checked_count = 0
                items_to_requeue = []
                processed_ids = set()

                while checked_count < queue_size and checked_count < 10:
                    try:
                        item = request_queue.get_nowait()
                        item_req_id = item.get("req_id", "unknown")

                        if item_req_id in processed_ids:
                            items_to_requeue.append(item)
                            continue

                        processed_ids.add(item_req_id)

                        if not item.get("cancelled", False):
                            item_http_request = item.get("http_request")
                            if item_http_request:
                                try:
                                    if await item_http_request.is_disconnected():
                                        logger.info(f"[{item_req_id}] (Worker Queue Check) Client disconnected, marking as cancelled.")
                                        item["cancelled"] = True
                                        item_future = item.get("result_future")
                                        if item_future and not item_future.done():
                                            item_future.set_exception(HTTPException(status_code=499, detail=f"[{item_req_id}] Client disconnected while queued."))
                                except Exception as check_err:
                                    logger.error(f"[{item_req_id}] (Worker Queue Check) Error checking disconnect: {check_err}")

                        items_to_requeue.append(item)
                        checked_count += 1
                    except asyncio.QueueEmpty:
                        break

                for item in items_to_requeue:
                    await request_queue.put(item)

            # Fetch the next request
            try:
                request_item = await asyncio.wait_for(request_queue.get(), timeout=5.0)
            except asyncio.TimeoutError:
                # No new request within 5 seconds; loop and check again
                continue

            req_id = request_item["req_id"]
            request_data = request_item["request_data"]
            http_request = request_item["http_request"]
            result_future = request_item["result_future"]

            if request_item.get("cancelled", False):
                logger.info(f"[{req_id}] (Worker) Request already cancelled, skipping.")
                if not result_future.done():
                    result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] Request cancelled by user"))
                request_queue.task_done()
                continue

            is_streaming_request = request_data.stream
            logger.info(f"[{req_id}] (Worker) Dequeued request. Mode: {'streaming' if is_streaming_request else 'non-streaming'}")

            # Throttle back-to-back streaming requests
            current_time = time.time()
            if was_last_request_streaming and is_streaming_request and (current_time - last_request_completion_time < 1.0):
                delay_time = max(0.5, 1.0 - (current_time - last_request_completion_time))
                logger.info(f"[{req_id}] (Worker) Consecutive streaming requests, adding {delay_time:.2f}s delay...")
                await asyncio.sleep(delay_time)

            if await http_request.is_disconnected():
                logger.info(f"[{req_id}] (Worker) Client disconnected while waiting for the lock. Cancelling.")
                if not result_future.done():
                    result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] Client closed the request"))
                request_queue.task_done()
                continue

            logger.info(f"[{req_id}] (Worker) Waiting for processing lock...")
            async with processing_lock:
                logger.info(f"[{req_id}] (Worker) Processing lock acquired. Starting core processing...")

                if await http_request.is_disconnected():
                    logger.info(f"[{req_id}] (Worker) Client disconnected after acquiring the lock. Cancelling.")
                    if not result_future.done():
                        result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] Client closed the request"))
                elif result_future.done():
                    logger.info(f"[{req_id}] (Worker) Future already completed/cancelled before processing. Skipping.")
                else:
                    # Invoke the actual request handler
                    try:
                        from api_utils import _process_request_refactored
                        returned_value = await _process_request_refactored(
                            req_id, request_data, http_request, result_future
                        )

                        completion_event, submit_btn_loc, client_disco_checker = None, None, None
                        current_request_was_streaming = False

                        if isinstance(returned_value, tuple) and len(returned_value) == 3:
                            completion_event, submit_btn_loc, client_disco_checker = returned_value
                            if completion_event is not None:
                                current_request_was_streaming = True
                                logger.info(f"[{req_id}] (Worker) _process_request_refactored returned stream info (event, locator, checker).")
                            else:
                                current_request_was_streaming = False
                                logger.info(f"[{req_id}] (Worker) _process_request_refactored returned a tuple, but completion_event is None (likely non-stream or early exit).")
                        elif returned_value is None:
                            current_request_was_streaming = False
                            logger.info(f"[{req_id}] (Worker) _process_request_refactored returned non-stream completion (None).")
                        else:
                            current_request_was_streaming = False
                            logger.warning(f"[{req_id}] (Worker) _process_request_refactored returned unexpected type: {type(returned_value)}")

                        # Key fix: wait for stream completion while holding the lock (consistent with the original reference file)
                        if completion_event:
                            logger.info(f"[{req_id}] (Worker) Waiting for the stream generator's completion signal...")
                            try:
                                from server import RESPONSE_COMPLETION_TIMEOUT
                                await asyncio.wait_for(completion_event.wait(), timeout=RESPONSE_COMPLETION_TIMEOUT/1000 + 60)
                                logger.info(f"[{req_id}] (Worker) ✅ Stream generator completion signal received.")

                                # Wait for the submit button to be disabled, confirming the streamed response fully finished
                                if submit_btn_loc and client_disco_checker:
                                    logger.info(f"[{req_id}] (Worker) Stream finished; checking and handling submit button state...")
                                    wait_timeout_ms = 30000  # 30 seconds
                                    try:
                                        from playwright.async_api import expect as expect_async
                                        from api_utils.request_processor import ClientDisconnectedError

                                        # Check client connection state
                                        client_disco_checker("Post-stream button state check - precheck: ")
                                        await asyncio.sleep(0.5)  # give the UI a moment to update

                                        # If the button is still enabled, click it to stop generation
                                        logger.info(f"[{req_id}] (Worker) Checking submit button state...")
                                        try:
                                            is_button_enabled = await submit_btn_loc.is_enabled(timeout=2000)
                                            logger.info(f"[{req_id}] (Worker) Submit button enabled state: {is_button_enabled}")

                                            if is_button_enabled:
                                                # Button still enabled after the stream completed; click it to stop generation
                                                logger.info(f"[{req_id}] (Worker) Stream finished but button still enabled; clicking to stop generation...")
                                                await submit_btn_loc.click(timeout=5000, force=True)
                                                logger.info(f"[{req_id}] (Worker) ✅ Submit button clicked.")
                                            else:
                                                logger.info(f"[{req_id}] (Worker) Submit button already disabled; no click needed.")
                                        except Exception as button_check_err:
                                            logger.warning(f"[{req_id}] (Worker) Failed to check button state: {button_check_err}")

                                        # Wait for the button to be finally disabled
                                        logger.info(f"[{req_id}] (Worker) Waiting for the submit button to be disabled...")
                                        await expect_async(submit_btn_loc).to_be_disabled(timeout=wait_timeout_ms)
                                        logger.info(f"[{req_id}] ✅ Submit button disabled.")

                                    except ClientDisconnectedError:
                                        # Must precede the generic handler, otherwise it is unreachable.
                                        logger.info(f"[{req_id}] Client disconnected during post-stream button state handling.")
                                    except Exception as e_pw_disabled:
                                        logger.warning(f"[{req_id}] ⚠️ Timeout or error handling button state after stream: {e_pw_disabled}")
                                        from api_utils.request_processor import save_error_snapshot
                                        await save_error_snapshot(f"stream_post_submit_button_handling_timeout_{req_id}")
                                elif current_request_was_streaming:
                                    logger.warning(f"[{req_id}] (Worker) Streaming request but submit_btn_loc or client_disco_checker was not provided. Skipping the button-disabled wait.")

                            except asyncio.TimeoutError:
                                logger.warning(f"[{req_id}] (Worker) ⚠️ Timed out waiting for the stream generator's completion signal.")
                                if not result_future.done():
                                    result_future.set_exception(HTTPException(status_code=504, detail=f"[{req_id}] Stream generation timed out waiting for completion signal."))
                            except Exception as ev_wait_err:
                                logger.error(f"[{req_id}] (Worker) ❌ Error waiting for the stream completion event: {ev_wait_err}")
                                if not result_future.done():
                                    result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Error waiting for stream completion: {ev_wait_err}"))

                    except Exception as process_err:
                        logger.error(f"[{req_id}] (Worker) _process_request_refactored execution error: {process_err}")
                        if not result_future.done():
                            result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Request processing error: {process_err}"))

            logger.info(f"[{req_id}] (Worker) Processing lock released.")

            # Run cleanup immediately after releasing the processing lock
            try:
                # Clear the stream queue cache
                from api_utils import clear_stream_queue
                await clear_stream_queue()

                # Clear the chat history (for both modes: streaming and non-streaming)
                if submit_btn_loc and client_disco_checker:
                    from server import page_instance, is_page_ready
                    if page_instance and is_page_ready:
                        from browser_utils.page_controller import PageController
                        page_controller = PageController(page_instance, logger, req_id)
                        logger.info(f"[{req_id}] (Worker) Clearing chat history ({'streaming' if completion_event else 'non-streaming'} mode)...")
                        await page_controller.clear_chat_history(client_disco_checker)
                        logger.info(f"[{req_id}] (Worker) ✅ Chat history cleared.")
                else:
                    logger.info(f"[{req_id}] (Worker) Skipping chat history clear: missing required parameters (submit_btn_loc: {bool(submit_btn_loc)}, client_disco_checker: {bool(client_disco_checker)})")
            except Exception as clear_err:
                logger.error(f"[{req_id}] (Worker) Error during cleanup: {clear_err}", exc_info=True)

            was_last_request_streaming = is_streaming_request
            last_request_completion_time = time.time()

        except asyncio.CancelledError:
            logger.info("--- Queue worker cancelled ---")
            if result_future and not result_future.done():
                result_future.cancel("Worker cancelled")
            break
        except Exception as e:
            logger.error(f"[{req_id}] (Worker) ❌ Unexpected error while handling request: {e}", exc_info=True)
            if result_future and not result_future.done():
                result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Internal server error: {e}"))
        finally:
            if request_item:
                request_queue.task_done()

    logger.info("--- Queue worker stopped ---")
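For context, the worker expects each queue item to carry exactly the keys it reads above: req_id, request_data, http_request, result_future, and cancelled. A hedged sketch of the producer side, i.e. roughly what chat_completions in routes.py (included in this upload but not excerpted here) would do; the helper name is hypothetical:

import asyncio

async def enqueue_request(request_queue, req_id, request_data, http_request):
    # Build an item with the fields queue_worker reads.
    result_future = asyncio.get_running_loop().create_future()
    await request_queue.put({
        "req_id": req_id,               # used for logging and duplicate detection
        "request_data": request_data,   # parsed ChatCompletionRequest
        "http_request": http_request,   # polled via is_disconnected()
        "result_future": result_future, # resolved by the worker
        "cancelled": False,             # flipped by /v1/cancel/{req_id}
    })
    # Resolves to a (Streaming)Response or raises the worker's HTTPException.
    return await result_future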
api_utils/request_processor.py
ADDED
|
@@ -0,0 +1,795 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
请求处理器模块
|
| 3 |
+
包含核心的请求处理逻辑
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import random
|
| 10 |
+
import time
|
| 11 |
+
from typing import Optional, Tuple, Callable, AsyncGenerator
|
| 12 |
+
from asyncio import Event, Future
|
| 13 |
+
|
| 14 |
+
from fastapi import HTTPException, Request
|
| 15 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 16 |
+
from playwright.async_api import Page as AsyncPage, Locator, Error as PlaywrightAsyncError, expect as expect_async
|
| 17 |
+
|
| 18 |
+
# --- 配置模块导入 ---
|
| 19 |
+
from config import *
|
| 20 |
+
|
| 21 |
+
# --- models模块导入 ---
|
| 22 |
+
from models import ChatCompletionRequest, ClientDisconnectedError
|
| 23 |
+
|
| 24 |
+
# --- browser_utils模块导入 ---
|
| 25 |
+
from browser_utils import (
|
| 26 |
+
switch_ai_studio_model,
|
| 27 |
+
save_error_snapshot
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# --- api_utils模块导入 ---
|
| 31 |
+
from .utils import (
|
| 32 |
+
validate_chat_request,
|
| 33 |
+
prepare_combined_prompt,
|
| 34 |
+
generate_sse_chunk,
|
| 35 |
+
generate_sse_stop_chunk,
|
| 36 |
+
use_stream_response,
|
| 37 |
+
calculate_usage_stats
|
| 38 |
+
)
|
| 39 |
+
from browser_utils.page_controller import PageController
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
async def _initialize_request_context(req_id: str, request: ChatCompletionRequest) -> dict:
|
| 43 |
+
"""初始化请求上下文"""
|
| 44 |
+
from server import (
|
| 45 |
+
logger, page_instance, is_page_ready, parsed_model_list,
|
| 46 |
+
current_ai_studio_model_id, model_switching_lock, page_params_cache,
|
| 47 |
+
params_cache_lock
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
logger.info(f"[{req_id}] 开始处理请求...")
|
| 51 |
+
logger.info(f"[{req_id}] 请求参数 - Model: {request.model}, Stream: {request.stream}")
|
| 52 |
+
|
| 53 |
+
context = {
|
| 54 |
+
'logger': logger,
|
| 55 |
+
'page': page_instance,
|
| 56 |
+
'is_page_ready': is_page_ready,
|
| 57 |
+
'parsed_model_list': parsed_model_list,
|
| 58 |
+
'current_ai_studio_model_id': current_ai_studio_model_id,
|
| 59 |
+
'model_switching_lock': model_switching_lock,
|
| 60 |
+
'page_params_cache': page_params_cache,
|
| 61 |
+
'params_cache_lock': params_cache_lock,
|
| 62 |
+
'is_streaming': request.stream,
|
| 63 |
+
'model_actually_switched': False,
|
| 64 |
+
'requested_model': request.model,
|
| 65 |
+
'model_id_to_use': None,
|
| 66 |
+
'needs_model_switching': False
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
return context
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
async def _analyze_model_requirements(req_id: str, context: dict, request: ChatCompletionRequest) -> dict:
|
| 73 |
+
"""分析模型需求并确定是否需要切换"""
|
| 74 |
+
logger = context['logger']
|
| 75 |
+
current_ai_studio_model_id = context['current_ai_studio_model_id']
|
| 76 |
+
parsed_model_list = context['parsed_model_list']
|
| 77 |
+
requested_model = request.model
|
| 78 |
+
|
| 79 |
+
if requested_model and requested_model != MODEL_NAME:
|
| 80 |
+
requested_model_id = requested_model.split('/')[-1]
|
| 81 |
+
logger.info(f"[{req_id}] 请求使用模型: {requested_model_id}")
|
| 82 |
+
|
| 83 |
+
if parsed_model_list:
|
| 84 |
+
valid_model_ids = [m.get("id") for m in parsed_model_list]
|
| 85 |
+
if requested_model_id not in valid_model_ids:
|
| 86 |
+
raise HTTPException(
|
| 87 |
+
status_code=400,
|
| 88 |
+
detail=f"[{req_id}] Invalid model '{requested_model_id}'. Available models: {', '.join(valid_model_ids)}"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
context['model_id_to_use'] = requested_model_id
|
| 92 |
+
if current_ai_studio_model_id != requested_model_id:
|
| 93 |
+
context['needs_model_switching'] = True
|
| 94 |
+
logger.info(f"[{req_id}] 需要切换模型: 当前={current_ai_studio_model_id} -> 目标={requested_model_id}")
|
| 95 |
+
|
| 96 |
+
return context
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
async def _setup_disconnect_monitoring(req_id: str, http_request: Request, result_future: Future) -> Tuple[Event, asyncio.Task, Callable]:
|
| 100 |
+
"""设置客户端断开连接监控"""
|
| 101 |
+
from server import logger
|
| 102 |
+
|
| 103 |
+
client_disconnected_event = Event()
|
| 104 |
+
|
| 105 |
+
async def check_disconnect_periodically():
|
| 106 |
+
while not client_disconnected_event.is_set():
|
| 107 |
+
try:
|
| 108 |
+
if await http_request.is_disconnected():
|
| 109 |
+
logger.info(f"[{req_id}] 客户端断开,设置事件。")
|
| 110 |
+
client_disconnected_event.set()
|
| 111 |
+
if not result_future.done():
|
| 112 |
+
result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] 客户端关闭了请求"))
|
| 113 |
+
break
|
| 114 |
+
await asyncio.sleep(1.0)
|
| 115 |
+
except asyncio.CancelledError:
|
| 116 |
+
break
|
| 117 |
+
except Exception as e:
|
| 118 |
+
logger.error(f"[{req_id}] (Disco Check Task) 错误: {e}")
|
| 119 |
+
client_disconnected_event.set()
|
| 120 |
+
if not result_future.done():
|
| 121 |
+
result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Internal disconnect checker error: {e}"))
|
| 122 |
+
break
|
| 123 |
+
|
| 124 |
+
disconnect_check_task = asyncio.create_task(check_disconnect_periodically())
|
| 125 |
+
|
| 126 |
+
def check_client_disconnected(stage: str = ""):
|
| 127 |
+
if client_disconnected_event.is_set():
|
| 128 |
+
logger.info(f"[{req_id}] 在 '{stage}' 检测到客户端断开连接。")
|
| 129 |
+
raise ClientDisconnectedError(f"[{req_id}] Client disconnected at stage: {stage}")
|
| 130 |
+
return False
|
| 131 |
+
|
| 132 |
+
return client_disconnected_event, disconnect_check_task, check_client_disconnected
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
async def _validate_page_status(req_id: str, context: dict, check_client_disconnected: Callable) -> None:
|
| 136 |
+
"""验证页面状态"""
|
| 137 |
+
page = context['page']
|
| 138 |
+
is_page_ready = context['is_page_ready']
|
| 139 |
+
|
| 140 |
+
if not page or page.is_closed() or not is_page_ready:
|
| 141 |
+
raise HTTPException(status_code=503, detail=f"[{req_id}] AI Studio 页面丢失或未就绪。", headers={"Retry-After": "30"})
|
| 142 |
+
|
| 143 |
+
check_client_disconnected("Initial Page Check")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
async def _handle_model_switching(req_id: str, context: dict, check_client_disconnected: Callable) -> dict:
|
| 147 |
+
"""处理模型切换逻辑"""
|
| 148 |
+
if not context['needs_model_switching']:
|
| 149 |
+
return context
|
| 150 |
+
|
| 151 |
+
logger = context['logger']
|
| 152 |
+
page = context['page']
|
| 153 |
+
model_switching_lock = context['model_switching_lock']
|
| 154 |
+
model_id_to_use = context['model_id_to_use']
|
| 155 |
+
|
| 156 |
+
import server
|
| 157 |
+
|
| 158 |
+
async with model_switching_lock:
|
| 159 |
+
if server.current_ai_studio_model_id != model_id_to_use:
|
| 160 |
+
logger.info(f"[{req_id}] 准备切换模型: {server.current_ai_studio_model_id} -> {model_id_to_use}")
|
| 161 |
+
switch_success = await switch_ai_studio_model(page, model_id_to_use, req_id)
|
| 162 |
+
if switch_success:
|
| 163 |
+
server.current_ai_studio_model_id = model_id_to_use
|
| 164 |
+
context['model_actually_switched'] = True
|
| 165 |
+
context['current_ai_studio_model_id'] = model_id_to_use
|
| 166 |
+
logger.info(f"[{req_id}] ✅ 模型切换成功: {server.current_ai_studio_model_id}")
|
| 167 |
+
else:
|
| 168 |
+
await _handle_model_switch_failure(req_id, page, model_id_to_use, server.current_ai_studio_model_id, logger)
|
| 169 |
+
|
| 170 |
+
return context
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
async def _handle_model_switch_failure(req_id: str, page: AsyncPage, model_id_to_use: str, model_before_switch: str, logger) -> None:
|
| 174 |
+
"""处理模型切换失败的情况"""
|
| 175 |
+
import server
|
| 176 |
+
|
| 177 |
+
logger.warning(f"[{req_id}] ❌ 模型切换至 {model_id_to_use} 失败。")
|
| 178 |
+
# 尝试恢复全局状态
|
| 179 |
+
server.current_ai_studio_model_id = model_before_switch
|
| 180 |
+
|
| 181 |
+
raise HTTPException(
|
| 182 |
+
status_code=422,
|
| 183 |
+
detail=f"[{req_id}] 未能切换到模型 '{model_id_to_use}'。请确保模型可用。"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
async def _handle_parameter_cache(req_id: str, context: dict) -> None:
|
| 188 |
+
"""处理参数缓存"""
|
| 189 |
+
logger = context['logger']
|
| 190 |
+
params_cache_lock = context['params_cache_lock']
|
| 191 |
+
page_params_cache = context['page_params_cache']
|
| 192 |
+
current_ai_studio_model_id = context['current_ai_studio_model_id']
|
| 193 |
+
model_actually_switched = context['model_actually_switched']
|
| 194 |
+
|
| 195 |
+
async with params_cache_lock:
|
| 196 |
+
cached_model_for_params = page_params_cache.get("last_known_model_id_for_params")
|
| 197 |
+
|
| 198 |
+
if model_actually_switched or (current_ai_studio_model_id != cached_model_for_params):
|
| 199 |
+
logger.info(f"[{req_id}] 模型已更改,参数缓存失效。")
|
| 200 |
+
page_params_cache.clear()
|
| 201 |
+
page_params_cache["last_known_model_id_for_params"] = current_ai_studio_model_id
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
async def _prepare_and_validate_request(req_id: str, request: ChatCompletionRequest, check_client_disconnected: Callable) -> str:
|
| 205 |
+
"""准备和验证请求"""
|
| 206 |
+
try:
|
| 207 |
+
validate_chat_request(request.messages, req_id)
|
| 208 |
+
except ValueError as e:
|
| 209 |
+
raise HTTPException(status_code=400, detail=f"[{req_id}] 无效请求: {e}")
|
| 210 |
+
|
| 211 |
+
prepared_prompt = prepare_combined_prompt(request.messages, req_id)
|
| 212 |
+
check_client_disconnected("After Prompt Prep")
|
| 213 |
+
|
| 214 |
+
return prepared_prompt
|
| 215 |
+
|
| 216 |
+
async def _handle_response_processing(req_id: str, request: ChatCompletionRequest, page: AsyncPage,
|
| 217 |
+
context: dict, result_future: Future,
|
| 218 |
+
submit_button_locator: Locator, check_client_disconnected: Callable) -> Optional[Tuple[Event, Locator, Callable]]:
|
| 219 |
+
"""处理响应生成"""
|
| 220 |
+
from server import logger
|
| 221 |
+
|
| 222 |
+
is_streaming = request.stream
|
| 223 |
+
current_ai_studio_model_id = context.get('current_ai_studio_model_id')
|
| 224 |
+
|
| 225 |
+
# 检查是否使用辅助流
|
| 226 |
+
stream_port = os.environ.get('STREAM_PORT')
|
| 227 |
+
use_stream = stream_port != '0'
|
| 228 |
+
|
| 229 |
+
if use_stream:
|
| 230 |
+
return await _handle_auxiliary_stream_response(req_id, request, context, result_future, submit_button_locator, check_client_disconnected)
|
| 231 |
+
else:
|
| 232 |
+
return await _handle_playwright_response(req_id, request, page, context, result_future, submit_button_locator, check_client_disconnected)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
async def _handle_auxiliary_stream_response(req_id: str, request: ChatCompletionRequest, context: dict,
|
| 236 |
+
result_future: Future, submit_button_locator: Locator,
|
| 237 |
+
check_client_disconnected: Callable) -> Optional[Tuple[Event, Locator, Callable]]:
|
| 238 |
+
"""使用辅助流处理响应"""
|
| 239 |
+
from server import logger
|
| 240 |
+
|
| 241 |
+
is_streaming = request.stream
|
| 242 |
+
current_ai_studio_model_id = context.get('current_ai_studio_model_id')
|
| 243 |
+
|
| 244 |
+
def generate_random_string(length):
|
| 245 |
+
charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
| 246 |
+
return ''.join(random.choice(charset) for _ in range(length))
|
| 247 |
+
|
| 248 |
+
if is_streaming:
|
| 249 |
+
try:
|
| 250 |
+
completion_event = Event()
|
| 251 |
+
|
| 252 |
+
async def create_stream_generator_from_helper(event_to_set: Event) -> AsyncGenerator[str, None]:
|
| 253 |
+
last_reason_pos = 0
|
| 254 |
+
last_body_pos = 0
|
| 255 |
+
model_name_for_stream = current_ai_studio_model_id or MODEL_NAME
|
| 256 |
+
chat_completion_id = f"{CHAT_COMPLETION_ID_PREFIX}{req_id}-{int(time.time())}-{random.randint(100, 999)}"
|
| 257 |
+
created_timestamp = int(time.time())
|
| 258 |
+
|
| 259 |
+
# 用于收集完整内容以计算usage
|
| 260 |
+
full_reasoning_content = ""
|
| 261 |
+
full_body_content = ""
|
| 262 |
+
|
| 263 |
+
try:
|
| 264 |
+
async for raw_data in use_stream_response(req_id):
|
| 265 |
+
# 检查客户端是否断开连接
|
| 266 |
+
try:
|
| 267 |
+
check_client_disconnected(f"流式生成器循环 ({req_id}): ")
|
| 268 |
+
except ClientDisconnectedError:
|
| 269 |
+
logger.info(f"[{req_id}] 客户端断开连接,终止流式生成")
|
| 270 |
+
break
|
| 271 |
+
|
| 272 |
+
# 确保 data 是字典类型
|
| 273 |
+
if isinstance(raw_data, str):
|
| 274 |
+
try:
|
| 275 |
+
data = json.loads(raw_data)
|
| 276 |
+
except json.JSONDecodeError:
|
| 277 |
+
logger.warning(f"[{req_id}] 无法解析流数据JSON: {raw_data}")
|
| 278 |
+
continue
|
| 279 |
+
elif isinstance(raw_data, dict):
|
| 280 |
+
data = raw_data
|
| 281 |
+
else:
|
| 282 |
+
logger.warning(f"[{req_id}] 未知的流数据类型: {type(raw_data)}")
|
| 283 |
+
continue
|
| 284 |
+
|
| 285 |
+
# 确保必要的键存在
|
| 286 |
+
if not isinstance(data, dict):
|
| 287 |
+
logger.warning(f"[{req_id}] 数据不是字典类型: {data}")
|
| 288 |
+
continue
|
| 289 |
+
|
| 290 |
+
reason = data.get("reason", "")
|
| 291 |
+
body = data.get("body", "")
|
| 292 |
+
done = data.get("done", False)
|
| 293 |
+
function = data.get("function", [])
|
| 294 |
+
|
| 295 |
+
# 更新完整内容记录
|
| 296 |
+
if reason:
|
| 297 |
+
full_reasoning_content = reason
|
| 298 |
+
if body:
|
| 299 |
+
full_body_content = body
|
| 300 |
+
|
| 301 |
+
# 处理推理内容
|
| 302 |
+
if len(reason) > last_reason_pos:
|
| 303 |
+
output = {
|
| 304 |
+
"id": chat_completion_id,
|
| 305 |
+
"object": "chat.completion.chunk",
|
| 306 |
+
"model": model_name_for_stream,
|
| 307 |
+
"created": created_timestamp,
|
| 308 |
+
"choices":[{
|
| 309 |
+
"index": 0,
|
| 310 |
+
"delta":{
|
| 311 |
+
"role": "assistant",
                                        "content": None,
                                        "reasoning_content": reason[last_reason_pos:],
                                    },
                                    "finish_reason": None,
                                    "native_finish_reason": None,
                                }]
                            }
                            last_reason_pos = len(reason)
                            yield f"data: {json.dumps(output, ensure_ascii=False, separators=(',', ':'))}\n\n"

                        # Handle new body content
                        if len(body) > last_body_pos:
                            finish_reason_val = None
                            if done:
                                finish_reason_val = "stop"

                            delta_content = {"role": "assistant", "content": body[last_body_pos:]}
                            choice_item = {
                                "index": 0,
                                "delta": delta_content,
                                "finish_reason": finish_reason_val,
                                "native_finish_reason": finish_reason_val,
                            }

                            if done and function and len(function) > 0:
                                tool_calls_list = []
                                for func_idx, function_call_data in enumerate(function):
                                    tool_calls_list.append({
                                        "id": f"call_{generate_random_string(24)}",
                                        "index": func_idx,
                                        "type": "function",
                                        "function": {
                                            "name": function_call_data["name"],
                                            "arguments": json.dumps(function_call_data["params"]),
                                        },
                                    })
                                delta_content["tool_calls"] = tool_calls_list
                                choice_item["finish_reason"] = "tool_calls"
                                choice_item["native_finish_reason"] = "tool_calls"
                                delta_content["content"] = None

                            output = {
                                "id": chat_completion_id,
                                "object": "chat.completion.chunk",
                                "model": model_name_for_stream,
                                "created": created_timestamp,
                                "choices": [choice_item]
                            }
                            last_body_pos = len(body)
                            yield f"data: {json.dumps(output, ensure_ascii=False, separators=(',', ':'))}\n\n"

                        # Handle done=True with no new content (function calls only, or a bare end-of-stream)
                        elif done:
                            # Function calls arrived without any new body content
                            if function and len(function) > 0:
                                delta_content = {"role": "assistant", "content": None}
                                tool_calls_list = []
                                for func_idx, function_call_data in enumerate(function):
                                    tool_calls_list.append({
                                        "id": f"call_{generate_random_string(24)}",
                                        "index": func_idx,
                                        "type": "function",
                                        "function": {
                                            "name": function_call_data["name"],
                                            "arguments": json.dumps(function_call_data["params"]),
                                        },
                                    })
                                delta_content["tool_calls"] = tool_calls_list
                                choice_item = {
                                    "index": 0,
                                    "delta": delta_content,
                                    "finish_reason": "tool_calls",
                                    "native_finish_reason": "tool_calls",
                                }
                            else:
                                # Bare end-of-stream: no new content and no function calls
                                choice_item = {
                                    "index": 0,
                                    "delta": {"role": "assistant"},
                                    "finish_reason": "stop",
                                    "native_finish_reason": "stop",
                                }

                            output = {
                                "id": chat_completion_id,
                                "object": "chat.completion.chunk",
                                "model": model_name_for_stream,
                                "created": created_timestamp,
                                "choices": [choice_item]
                            }
                            yield f"data: {json.dumps(output, ensure_ascii=False, separators=(',', ':'))}\n\n"

                except ClientDisconnectedError:
                    logger.info(f"[{req_id}] 流式生成器中检测到客户端断开连接")
                except Exception as e:
                    logger.error(f"[{req_id}] 流式生成器处理过程中发生错误: {e}", exc_info=True)
                    # Report the error to the client
                    try:
                        error_chunk = {
                            "id": chat_completion_id,
                            "object": "chat.completion.chunk",
                            "model": model_name_for_stream,
                            "created": created_timestamp,
                            "choices": [{
                                "index": 0,
                                "delta": {"role": "assistant", "content": f"\n\n[错误: {str(e)}]"},
                                "finish_reason": "stop",
                                "native_finish_reason": "stop",
                            }]
                        }
                        yield f"data: {json.dumps(error_chunk, ensure_ascii=False, separators=(',', ':'))}\n\n"
                    except Exception:
                        pass  # if the error cannot be delivered, continue with the shutdown logic
                finally:
                    # Compute usage statistics
                    try:
                        usage_stats = calculate_usage_stats(
                            [msg.model_dump() for msg in request.messages],
                            full_body_content,
                            full_reasoning_content
                        )
                        logger.info(f"[{req_id}] 计算的token使用统计: {usage_stats}")

                        # Send the final chunk carrying the usage stats
                        final_chunk = {
                            "id": chat_completion_id,
                            "object": "chat.completion.chunk",
                            "model": model_name_for_stream,
                            "created": created_timestamp,
                            "choices": [{
                                "index": 0,
                                "delta": {},
                                "finish_reason": "stop",
                                "native_finish_reason": "stop"
                            }],
                            "usage": usage_stats
                        }
                        yield f"data: {json.dumps(final_chunk, ensure_ascii=False, separators=(',', ':'))}\n\n"
                        logger.info(f"[{req_id}] 已发送带usage统计的最终chunk")

                    except Exception as usage_err:
                        logger.error(f"[{req_id}] 计算或发送usage统计时出错: {usage_err}")

                    # Always send the [DONE] marker
                    try:
                        logger.info(f"[{req_id}] 流式生成器完成,发送 [DONE] 标记")
                        yield "data: [DONE]\n\n"
                    except Exception as done_err:
                        logger.error(f"[{req_id}] 发送 [DONE] 标记时出错: {done_err}")

                    # Make sure the completion event is set
                    if not event_to_set.is_set():
                        event_to_set.set()
                        logger.info(f"[{req_id}] 流式生成器完成事件已设置")

            stream_gen_func = create_stream_generator_from_helper(completion_event)
            if not result_future.done():
                result_future.set_result(StreamingResponse(stream_gen_func, media_type="text/event-stream"))
            else:
                if not completion_event.is_set():
                    completion_event.set()

            return completion_event, submit_button_locator, check_client_disconnected

        except Exception as e:
            logger.error(f"[{req_id}] 从队列获取流式数据时出错: {e}", exc_info=True)
            if completion_event and not completion_event.is_set():
                completion_event.set()
            raise

    else:  # non-streaming
        content = None
        reasoning_content = None
        functions = None
        final_data_from_aux_stream = None

        async for raw_data in use_stream_response(req_id):
            check_client_disconnected(f"非流式辅助流 - 循环中 ({req_id}): ")

            # Normalize raw_data into a dict
            if isinstance(raw_data, str):
                try:
                    data = json.loads(raw_data)
                except json.JSONDecodeError:
                    logger.warning(f"[{req_id}] 无法解析非流式数据JSON: {raw_data}")
                    continue
            elif isinstance(raw_data, dict):
                data = raw_data
            else:
                logger.warning(f"[{req_id}] 非流式未知数据类型: {type(raw_data)}")
                continue

            # Guard: the parsed data must be a dict
            if not isinstance(data, dict):
                logger.warning(f"[{req_id}] 非流式数据不是字典类型: {data}")
                continue

            final_data_from_aux_stream = data
            if data.get("done"):
                content = data.get("body")
                reasoning_content = data.get("reason")
                functions = data.get("function")
                break

        if final_data_from_aux_stream and final_data_from_aux_stream.get("reason") == "internal_timeout":
            logger.error(f"[{req_id}] 非流式请求通过辅助流失败: 内部超时")
            raise HTTPException(status_code=502, detail=f"[{req_id}] 辅助流处理错误 (内部超时)")

        if final_data_from_aux_stream and final_data_from_aux_stream.get("done") is True and content is None:
            logger.error(f"[{req_id}] 非流式请求通过辅助流完成但未提供内容")
            raise HTTPException(status_code=502, detail=f"[{req_id}] 辅助流完成但未提供内容")

        model_name_for_json = current_ai_studio_model_id or MODEL_NAME
        message_payload = {"role": "assistant", "content": content}
        finish_reason_val = "stop"

        if functions and len(functions) > 0:
            tool_calls_list = []
            for func_idx, function_call_data in enumerate(functions):
                tool_calls_list.append({
                    "id": f"call_{generate_random_string(24)}",
                    "index": func_idx,
                    "type": "function",
                    "function": {
                        "name": function_call_data["name"],
                        "arguments": json.dumps(function_call_data["params"]),
                    },
                })
            message_payload["tool_calls"] = tool_calls_list
            finish_reason_val = "tool_calls"
            message_payload["content"] = None

        if reasoning_content:
            message_payload["reasoning_content"] = reasoning_content

        # Compute token usage statistics
        usage_stats = calculate_usage_stats(
            [msg.model_dump() for msg in request.messages],
            content or "",
            reasoning_content
        )

        response_payload = {
            "id": f"{CHAT_COMPLETION_ID_PREFIX}{req_id}-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_name_for_json,
            "choices": [{
                "index": 0,
                "message": message_payload,
                "finish_reason": finish_reason_val,
                "native_finish_reason": finish_reason_val,
            }],
            "usage": usage_stats
        }

        if not result_future.done():
            result_future.set_result(JSONResponse(content=response_payload))
        return None


async def _handle_playwright_response(req_id: str, request: ChatCompletionRequest, page: AsyncPage,
                                      context: dict, result_future: Future, submit_button_locator: Locator,
                                      check_client_disconnected: Callable) -> Optional[Tuple[Event, Locator, Callable]]:
    """Handle the response via Playwright."""
    from server import logger

    is_streaming = request.stream
    current_ai_studio_model_id = context.get('current_ai_studio_model_id')

    logger.info(f"[{req_id}] 定位响应元素...")
    response_container = page.locator(RESPONSE_CONTAINER_SELECTOR).last
    response_element = response_container.locator(RESPONSE_TEXT_SELECTOR)

    try:
        await expect_async(response_container).to_be_attached(timeout=20000)
        check_client_disconnected("After Response Container Attached: ")
        await expect_async(response_element).to_be_attached(timeout=90000)
        logger.info(f"[{req_id}] 响应元素已定位。")
    except (PlaywrightAsyncError, asyncio.TimeoutError, ClientDisconnectedError) as locate_err:
        if isinstance(locate_err, ClientDisconnectedError):
            raise
        logger.error(f"[{req_id}] ❌ 错误: 定位响应元素失败或超时: {locate_err}")
        await save_error_snapshot(f"response_locate_error_{req_id}")
        raise HTTPException(status_code=502, detail=f"[{req_id}] 定位AI Studio响应元素失败: {locate_err}")
    except Exception as locate_exc:
        logger.exception(f"[{req_id}] ❌ 错误: 定位响应元素时意外错误")
        await save_error_snapshot(f"response_locate_unexpected_{req_id}")
        raise HTTPException(status_code=500, detail=f"[{req_id}] 定位响应元素时意外错误: {locate_exc}")

    check_client_disconnected("After Response Element Located: ")

    if is_streaming:
        completion_event = Event()

        async def create_response_stream_generator():
            try:
                # Fetch the response via PageController
                page_controller = PageController(page, logger, req_id)
                final_content = await page_controller.get_response(check_client_disconnected)

                # Generate the streaming response while preserving Markdown formatting:
                # split on newlines to keep line breaks and document structure intact
                lines = final_content.split('\n')
                for line_idx, line in enumerate(lines):
                    # Check whether the client has disconnected
                    try:
                        check_client_disconnected(f"Playwright流式生成器循环 ({req_id}): ")
                    except ClientDisconnectedError:
                        logger.info(f"[{req_id}] Playwright流式生成器中检测到客户端断开连接")
                        break

                    # Emit the current line (empty lines included, to preserve Markdown formatting)
                    if line:  # non-empty lines are emitted in character-sized chunks
                        chunk_size = 5  # 5 characters per chunk balances speed and perceived smoothness
                        for i in range(0, len(line), chunk_size):
                            chunk = line[i:i+chunk_size]
                            yield generate_sse_chunk(chunk, req_id, current_ai_studio_model_id or MODEL_NAME)
                            await asyncio.sleep(0.03)  # moderate output rate

                    # Append the newline (except after the last line)
                    if line_idx < len(lines) - 1:
                        yield generate_sse_chunk('\n', req_id, current_ai_studio_model_id or MODEL_NAME)
                        await asyncio.sleep(0.01)

                # Compute the usage stats for the final chunk
                usage_stats = calculate_usage_stats(
                    [msg.model_dump() for msg in request.messages],
                    final_content,
                    ""  # Playwright mode has no reasoning content
                )
                logger.info(f"[{req_id}] Playwright流式计算的token使用统计: {usage_stats}")

                # Send the stop chunk carrying the usage stats
                yield generate_sse_stop_chunk(req_id, current_ai_studio_model_id or MODEL_NAME, "stop", usage_stats)

            except ClientDisconnectedError:
                logger.info(f"[{req_id}] Playwright流式生成器中检测到客户端断开连接")
            except Exception as e:
                logger.error(f"[{req_id}] Playwright流式生成器处理过程中发生错误: {e}", exc_info=True)
                # Report the error to the client
                try:
                    yield generate_sse_chunk(f"\n\n[错误: {str(e)}]", req_id, current_ai_studio_model_id or MODEL_NAME)
                    yield generate_sse_stop_chunk(req_id, current_ai_studio_model_id or MODEL_NAME)
                except Exception:
                    pass  # if the error cannot be delivered, continue with the shutdown logic
            finally:
                # Make sure the completion event is set
                if not completion_event.is_set():
                    completion_event.set()
                    logger.info(f"[{req_id}] Playwright流式生成器完成事件已设置")

        stream_gen_func = create_response_stream_generator()
        if not result_future.done():
            result_future.set_result(StreamingResponse(stream_gen_func, media_type="text/event-stream"))

        return completion_event, submit_button_locator, check_client_disconnected
    else:
        # Fetch the response via PageController
        page_controller = PageController(page, logger, req_id)
        final_content = await page_controller.get_response(check_client_disconnected)

        # Compute token usage statistics
        usage_stats = calculate_usage_stats(
            [msg.model_dump() for msg in request.messages],
            final_content,
            ""  # Playwright mode has no reasoning content
        )
        logger.info(f"[{req_id}] Playwright非流式计算的token使用统计: {usage_stats}")

        response_payload = {
            "id": f"{CHAT_COMPLETION_ID_PREFIX}{req_id}-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": current_ai_studio_model_id or MODEL_NAME,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": final_content},
                "finish_reason": "stop"
            }],
            "usage": usage_stats
        }

        if not result_future.done():
            result_future.set_result(JSONResponse(content=response_payload))

        return None


async def _cleanup_request_resources(req_id: str, disconnect_check_task: Optional[asyncio.Task],
                                     completion_event: Optional[Event], result_future: Future,
                                     is_streaming: bool) -> None:
    """Clean up per-request resources."""
    from server import logger

    if disconnect_check_task and not disconnect_check_task.done():
        disconnect_check_task.cancel()
        try:
            await disconnect_check_task
        except asyncio.CancelledError:
            pass
        except Exception as task_clean_err:
            logger.error(f"[{req_id}] 清理任务时出错: {task_clean_err}")

    logger.info(f"[{req_id}] 处理完成。")

    if is_streaming and completion_event and not completion_event.is_set() and (result_future.done() and result_future.exception() is not None):
        logger.warning(f"[{req_id}] 流式请求异常,确保完成事件已设置。")
        completion_event.set()


async def _process_request_refactored(
    req_id: str,
    request: ChatCompletionRequest,
    http_request: Request,
    result_future: Future
) -> Optional[Tuple[Event, Locator, Callable[[str], bool]]]:
    """Core request-processing function (refactored version)."""

    context = await _initialize_request_context(req_id, request)
    context = await _analyze_model_requirements(req_id, context, request)

    client_disconnected_event, disconnect_check_task, check_client_disconnected = await _setup_disconnect_monitoring(
        req_id, http_request, result_future
    )

    page = context['page']
    submit_button_locator = page.locator(SUBMIT_BUTTON_SELECTOR) if page else None
    completion_event = None

    try:
        await _validate_page_status(req_id, context, check_client_disconnected)

        page_controller = PageController(page, context['logger'], req_id)

        await _handle_model_switching(req_id, context, check_client_disconnected)
        await _handle_parameter_cache(req_id, context)

        prepared_prompt = await _prepare_and_validate_request(req_id, request, check_client_disconnected)

        # Page interaction is driven through PageController.
        # Note: clearing the chat history has been moved to run after the queue lock is released.

        await page_controller.adjust_parameters(
            request.model_dump(exclude_none=True),  # exclude_none=True avoids passing None values through
            context['page_params_cache'],
            context['params_cache_lock'],
            context['model_id_to_use'],
            context['parsed_model_list'],
            check_client_disconnected
        )

        await page_controller.submit_prompt(prepared_prompt, check_client_disconnected)

        # Response handling still has to happen here: it decides between streaming
        # and non-streaming and sets the result future accordingly.
        response_result = await _handle_response_processing(
            req_id, request, page, context, result_future, submit_button_locator, check_client_disconnected
        )

        if response_result:
            completion_event, _, _ = response_result

        return completion_event, submit_button_locator, check_client_disconnected

    except ClientDisconnectedError as disco_err:
        context['logger'].info(f"[{req_id}] 捕获到客户端断开连接信号: {disco_err}")
        if not result_future.done():
            result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] Client disconnected during processing."))
    except HTTPException as http_err:
        context['logger'].warning(f"[{req_id}] 捕获到 HTTP 异常: {http_err.status_code} - {http_err.detail}")
        if not result_future.done():
            result_future.set_exception(http_err)
    except PlaywrightAsyncError as pw_err:
        context['logger'].error(f"[{req_id}] 捕获到 Playwright 错误: {pw_err}")
        await save_error_snapshot(f"process_playwright_error_{req_id}")
        if not result_future.done():
            result_future.set_exception(HTTPException(status_code=502, detail=f"[{req_id}] Playwright interaction failed: {pw_err}"))
    except Exception as e:
        context['logger'].exception(f"[{req_id}] 捕获到意外错误")
        await save_error_snapshot(f"process_unexpected_error_{req_id}")
        if not result_future.done():
            result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Unexpected server error: {e}"))
    finally:
        await _cleanup_request_resources(req_id, disconnect_check_task, completion_event, result_future, request.stream)
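Both streaming paths above speak the OpenAI-compatible SSE wire format: a sequence of `data: {...}` lines carrying `chat.completion.chunk` objects, a final chunk with `usage`, and a closing `data: [DONE]` sentinel. A minimal client sketch for consuming that stream follows; the base URL, port, and model id are illustrative assumptions, not values fixed by this repo:

# Minimal sketch of a client reading the SSE stream produced above.
# Host, port, and model id are assumptions; adjust to your deployment.
import json
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",  # assumed local deployment
    json={"model": "gemini-pro", "stream": True,   # "gemini-pro" is a placeholder id
          "messages": [{"role": "user", "content": "Hello"}]},
    stream=True,
)
for raw_line in resp.iter_lines(decode_unicode=True):
    if not raw_line or not raw_line.startswith("data: "):
        continue  # skip blank keep-alive lines between events
    payload = raw_line[len("data: "):]
    if payload == "[DONE]":  # sentinel emitted in the generator's finally block
        break
    chunk = json.loads(payload)
    delta = chunk["choices"][0]["delta"]
    print(delta.get("content") or "", end="", flush=True)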
api_utils/request_processor_backup.py
ADDED
@@ -0,0 +1,274 @@
"""
Request processor module
Contains the core request-handling logic
"""

import asyncio
import json
import os
import random
import time
from typing import Optional, Tuple, Callable, AsyncGenerator
from asyncio import Event, Future

from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from playwright.async_api import Page as AsyncPage, Locator, Error as PlaywrightAsyncError, expect as expect_async, TimeoutError

# --- config module imports ---
from config import *

# --- models module imports ---
from models import ChatCompletionRequest, ClientDisconnectedError

# --- browser_utils module imports ---
from browser_utils import (
    switch_ai_studio_model,
    save_error_snapshot,
    _wait_for_response_completion,
    _get_final_response_content,
    detect_and_extract_page_error
)

# --- api_utils module imports ---
from .utils import (
    validate_chat_request,
    prepare_combined_prompt,
    generate_sse_chunk,
    generate_sse_stop_chunk,
    generate_sse_error_chunk,
    use_helper_get_response,
    use_stream_response
)


async def _process_request_refactored(
    req_id: str,
    request: ChatCompletionRequest,
    http_request: Request,
    result_future: Future
) -> Optional[Tuple[Event, Locator, Callable[[str], bool]]]:
    """Core request-processing function - full version"""
    global current_ai_studio_model_id

    # Import shared globals from the server module
    from server import (
        logger, page_instance, is_page_ready, parsed_model_list,
        current_ai_studio_model_id, model_switching_lock, page_params_cache,
        params_cache_lock
    )

    model_actually_switched_in_current_api_call = False
    logger.info(f"[{req_id}] (Refactored Process) 开始处理请求...")
    logger.info(f"[{req_id}] 请求参数 - Model: {request.model}, Stream: {request.stream}")
    logger.info(f"[{req_id}] 请求参数 - Temperature: {request.temperature}")
    logger.info(f"[{req_id}] 请求参数 - Max Output Tokens: {request.max_output_tokens}")
    logger.info(f"[{req_id}] 请求参数 - Stop Sequences: {request.stop}")
    logger.info(f"[{req_id}] 请求参数 - Top P: {request.top_p}")

    is_streaming = request.stream
    page: Optional[AsyncPage] = page_instance
    completion_event: Optional[Event] = None
    requested_model = request.model
    model_id_to_use = None
    needs_model_switching = False

    if requested_model and requested_model != MODEL_NAME:
        requested_model_parts = requested_model.split('/')
        requested_model_id = requested_model_parts[-1] if len(requested_model_parts) > 1 else requested_model
        logger.info(f"[{req_id}] 请求使用模型: {requested_model_id}")
        if parsed_model_list:
            valid_model_ids = [m.get("id") for m in parsed_model_list]
            if requested_model_id not in valid_model_ids:
                logger.error(f"[{req_id}] ❌ 无效的模型ID: {requested_model_id}。可用模型: {valid_model_ids}")
                raise HTTPException(status_code=400, detail=f"[{req_id}] Invalid model '{requested_model_id}'. Available models: {', '.join(valid_model_ids)}")
        model_id_to_use = requested_model_id
        if current_ai_studio_model_id != model_id_to_use:
            needs_model_switching = True
            logger.info(f"[{req_id}] 需要切换模型: 当前={current_ai_studio_model_id} -> 目标={model_id_to_use}")
        else:
            logger.info(f"[{req_id}] 请求模型与当前模型相同 ({model_id_to_use}),无需切换")
    else:
        logger.info(f"[{req_id}] 未指定具体模型或使用代理模型名称,将使用当前模型: {current_ai_studio_model_id or '未知'}")

    client_disconnected_event = Event()
    disconnect_check_task = None
    input_field_locator = page.locator(INPUT_SELECTOR) if page else None
    submit_button_locator = page.locator(SUBMIT_BUTTON_SELECTOR) if page else None

    async def check_disconnect_periodically():
        while not client_disconnected_event.is_set():
            try:
                if await http_request.is_disconnected():
                    logger.info(f"[{req_id}] (Disco Check Task) 客户端断开。设置事件并尝试停止。")
                    client_disconnected_event.set()
                    try:
                        if submit_button_locator and await submit_button_locator.is_enabled(timeout=1500):
                            if input_field_locator and await input_field_locator.input_value(timeout=1500) == '':
                                logger.info(f"[{req_id}] (Disco Check Task) 点击停止...")
                                await submit_button_locator.click(timeout=3000, force=True)
                    except Exception as click_err:
                        logger.warning(f"[{req_id}] (Disco Check Task) 停止按钮点击失败: {click_err}")
                    if not result_future.done():
                        result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] 客户端在处理期间关闭了请求"))
                    break
                await asyncio.sleep(1.0)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"[{req_id}] (Disco Check Task) 错误: {e}")
                client_disconnected_event.set()
                if not result_future.done():
                    result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Internal disconnect checker error: {e}"))
                break

    disconnect_check_task = asyncio.create_task(check_disconnect_periodically())

    def check_client_disconnected(*args):
        msg_to_log = ""
        if len(args) == 1 and isinstance(args[0], str):
            msg_to_log = args[0]

        if client_disconnected_event.is_set():
            logger.info(f"[{req_id}] {msg_to_log}检测到客户端断开连接事件。")
            raise ClientDisconnectedError(f"[{req_id}] Client disconnected event set.")
        return False

    try:
        if not page or page.is_closed() or not is_page_ready:
            raise HTTPException(status_code=503, detail=f"[{req_id}] AI Studio 页面丢失或未就绪。", headers={"Retry-After": "30"})

        check_client_disconnected("Initial Page Check: ")

        # Model switching logic
        if needs_model_switching and model_id_to_use:
            async with model_switching_lock:
                model_before_switch_attempt = current_ai_studio_model_id
                if current_ai_studio_model_id != model_id_to_use:
                    logger.info(f"[{req_id}] 获取锁后准备切换: 当前内存中模型={current_ai_studio_model_id}, 目标={model_id_to_use}")
                    switch_success = await switch_ai_studio_model(page, model_id_to_use, req_id)
                    if switch_success:
                        current_ai_studio_model_id = model_id_to_use
                        model_actually_switched_in_current_api_call = True
                        logger.info(f"[{req_id}] ✅ 模型切换成功。全局模型状态已更新为: {current_ai_studio_model_id}")
                    else:
                        logger.warning(f"[{req_id}] ❌ 模型切换至 {model_id_to_use} 失败 (AI Studio 未接受或覆盖了更改)。")
                        active_model_id_after_fail = model_before_switch_attempt
                        try:
                            final_prefs_str_after_fail = await page.evaluate("() => localStorage.getItem('aiStudioUserPreference')")
                            if final_prefs_str_after_fail:
                                final_prefs_obj_after_fail = json.loads(final_prefs_str_after_fail)
                                model_path_in_final_prefs = final_prefs_obj_after_fail.get("promptModel")
                                if model_path_in_final_prefs and isinstance(model_path_in_final_prefs, str):
                                    active_model_id_after_fail = model_path_in_final_prefs.split('/')[-1]
                        except Exception as read_final_prefs_err:
                            logger.error(f"[{req_id}] 切换失败后读取最终 localStorage 出错: {read_final_prefs_err}")
                        current_ai_studio_model_id = active_model_id_after_fail
                        logger.info(f"[{req_id}] 全局模型状态在切换失败后设置为 (或保持为): {current_ai_studio_model_id}")
                        actual_displayed_model_name = "未知 (无法读取)"
                        try:
                            model_wrapper_locator = page.locator('#mat-select-value-0 mat-select-trigger').first
                            actual_displayed_model_name = await model_wrapper_locator.inner_text(timeout=3000)
                        except Exception:
                            pass
                        raise HTTPException(
                            status_code=422,
                            detail=f"[{req_id}] AI Studio 未能应用所请求的模型 '{model_id_to_use}' 或该模型不受支持。请选择 AI Studio 网页界面中可用的模型。当前实际生效的模型 ID 为 '{current_ai_studio_model_id}', 页面显示为 '{actual_displayed_model_name}'."
                        )
                else:
                    logger.info(f"[{req_id}] 获取锁后发现模型已是目标模型 {current_ai_studio_model_id},无需切换")

        # Parameter cache handling
        async with params_cache_lock:
            cached_model_for_params = page_params_cache.get("last_known_model_id_for_params")
            if model_actually_switched_in_current_api_call or \
               (current_ai_studio_model_id is not None and current_ai_studio_model_id != cached_model_for_params):
                action_taken = "Invalidating" if page_params_cache else "Initializing"
                logger.info(f"[{req_id}] {action_taken} parameter cache. Reason: Model context changed (switched this call: {model_actually_switched_in_current_api_call}, current model: {current_ai_studio_model_id}, cache model: {cached_model_for_params}).")
                page_params_cache.clear()
                if current_ai_studio_model_id:
                    page_params_cache["last_known_model_id_for_params"] = current_ai_studio_model_id
            else:
                logger.debug(f"[{req_id}] Parameter cache for model '{cached_model_for_params}' remains valid (current model: '{current_ai_studio_model_id}', switched this call: {model_actually_switched_in_current_api_call}).")

        # Validate the request
        try:
            validate_chat_request(request.messages, req_id)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=f"[{req_id}] 无效请求: {e}")

        # Prepare the prompt
        prepared_prompt = prepare_combined_prompt(request.messages, req_id)
        check_client_disconnected("After Prompt Prep: ")

        # The full handling logic belongs here - the function grew too long,
        # so this backup returns a simplified placeholder response for now.
        logger.info(f"[{req_id}] (Refactored Process) 处理完整逻辑 - 需要从备份恢复剩余部分")

        # Simple response used for testing
        if is_streaming:
            completion_event = Event()

            async def create_simple_stream_generator():
                try:
                    yield generate_sse_chunk("正在处理请求...", req_id, MODEL_NAME)
                    await asyncio.sleep(1)
                    yield generate_sse_chunk("处理完成", req_id, MODEL_NAME)
                    yield generate_sse_stop_chunk(req_id, MODEL_NAME)
                    yield "data: [DONE]\n\n"
                finally:
                    if not completion_event.is_set():
                        completion_event.set()

            if not result_future.done():
                result_future.set_result(StreamingResponse(create_simple_stream_generator(), media_type="text/event-stream"))

            return completion_event, submit_button_locator, check_client_disconnected
        else:
            response_payload = {
                "id": f"{CHAT_COMPLETION_ID_PREFIX}{req_id}-{int(time.time())}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": MODEL_NAME,
                "choices": [{
                    "index": 0,
                    "message": {"role": "assistant", "content": "处理完成 - 需要完整逻辑"},
                    "finish_reason": "stop"
                }],
                "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
            }

            if not result_future.done():
                result_future.set_result(JSONResponse(content=response_payload))

            return None

    except ClientDisconnectedError as disco_err:
        logger.info(f"[{req_id}] (Refactored Process) 捕获到客户端断开连接信号: {disco_err}")
        if not result_future.done():
            result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] Client disconnected during processing."))
    except HTTPException as http_err:
        logger.warning(f"[{req_id}] (Refactored Process) 捕获到 HTTP 异常: {http_err.status_code} - {http_err.detail}")
        if not result_future.done():
            result_future.set_exception(http_err)
    except Exception as e:
        logger.exception(f"[{req_id}] (Refactored Process) 捕获到意外错误")
        await save_error_snapshot(f"process_unexpected_error_{req_id}")
        if not result_future.done():
            result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Unexpected server error: {e}"))
    finally:
        if disconnect_check_task and not disconnect_check_task.done():
            disconnect_check_task.cancel()
            try:
                await disconnect_check_task
            except asyncio.CancelledError:
                pass
            except Exception as task_clean_err:
                logger.error(f"[{req_id}] 清理任务时出错: {task_clean_err}")

        logger.info(f"[{req_id}] (Refactored Process) 处理完成。")

        if is_streaming and completion_event and not completion_event.is_set() and (result_future.done() and result_future.exception() is not None):
            logger.warning(f"[{req_id}] (Refactored Process) 流式请求异常,确保完成事件已设置。")
            completion_event.set()

    return completion_event, submit_button_locator, check_client_disconnected
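The watchdog in check_disconnect_periodically is the part of this backup worth keeping in mind when reading the refactored processor: it polls `http_request.is_disconnected()` roughly once a second, fires a shared event, and fails the pending future with HTTP 499. A stripped-down sketch of the same pattern, with illustrative names and without the page-interaction details:

# Stripped-down sketch of the disconnect-watchdog pattern used above.
# Names are illustrative; the real task also tries to click the stop button.
import asyncio
from asyncio import Event, Future

from fastapi import HTTPException, Request


def start_disconnect_watchdog(http_request: Request, result_future: Future):
    disconnected = Event()

    async def poll() -> None:
        while not disconnected.is_set():
            if await http_request.is_disconnected():
                disconnected.set()
                if not result_future.done():
                    result_future.set_exception(
                        HTTPException(status_code=499, detail="Client closed request."))
                break
            await asyncio.sleep(1.0)  # same 1 s cadence as check_disconnect_periodically

    task = asyncio.create_task(poll())
    return disconnected, task  # the caller cancels the task during cleanup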
api_utils/routes.py
ADDED
@@ -0,0 +1,374 @@
"""
FastAPI route handler module
Contains the handler functions for all API endpoints
"""

import asyncio
import os
import random
import time
import uuid
from typing import Dict, List, Any, Set
from asyncio import Queue, Future, Lock, Event
import logging

from fastapi import HTTPException, Request, WebSocket, WebSocketDisconnect, Depends
from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel
from playwright.async_api import Page as AsyncPage

# --- config module imports ---
from config import *

# --- models module imports ---
from models import ChatCompletionRequest, WebSocketConnectionManager

# --- browser_utils module imports ---
from browser_utils import _handle_model_list_response

# --- dependency imports ---
from .dependencies import *


# --- Static file endpoints ---
async def read_index(logger: logging.Logger = Depends(get_logger)):
    """Serve the main page"""
    index_html_path = os.path.join(os.path.dirname(__file__), "..", "index.html")
    if not os.path.exists(index_html_path):
        logger.error(f"index.html not found at {index_html_path}")
        raise HTTPException(status_code=404, detail="index.html not found")
    return FileResponse(index_html_path)


async def get_css(logger: logging.Logger = Depends(get_logger)):
    """Serve the CSS file"""
    css_path = os.path.join(os.path.dirname(__file__), "..", "webui.css")
    if not os.path.exists(css_path):
        logger.error(f"webui.css not found at {css_path}")
        raise HTTPException(status_code=404, detail="webui.css not found")
    return FileResponse(css_path, media_type="text/css")


async def get_js(logger: logging.Logger = Depends(get_logger)):
    """Serve the JavaScript file"""
    js_path = os.path.join(os.path.dirname(__file__), "..", "webui.js")
    if not os.path.exists(js_path):
        logger.error(f"webui.js not found at {js_path}")
        raise HTTPException(status_code=404, detail="webui.js not found")
    return FileResponse(js_path, media_type="application/javascript")


# --- API info endpoint ---
async def get_api_info(request: Request, current_ai_studio_model_id: str = Depends(get_current_ai_studio_model_id)):
    """Return API information"""
    from api_utils import auth_utils

    server_port = request.url.port or os.environ.get('SERVER_PORT_INFO', '8000')
    host = request.headers.get('host') or f"127.0.0.1:{server_port}"
    scheme = request.headers.get('x-forwarded-proto', 'http')
    base_url = f"{scheme}://{host}"
    api_base = f"{base_url}/v1"
    effective_model_name = current_ai_studio_model_id or MODEL_NAME

    api_key_required = bool(auth_utils.API_KEYS)
    api_key_count = len(auth_utils.API_KEYS)

    if api_key_required:
        message = f"API Key is required. {api_key_count} valid key(s) configured."
    else:
        message = "API Key is not required."

    return JSONResponse(content={
        "model_name": effective_model_name,
        "api_base_url": api_base,
        "server_base_url": base_url,
        "api_key_required": api_key_required,
        "api_key_count": api_key_count,
        "auth_header": "Authorization: Bearer <token> or X-API-Key: <token>" if api_key_required else None,
        "openai_compatible": True,
        "supported_auth_methods": ["Authorization: Bearer", "X-API-Key"] if api_key_required else [],
        "message": message
    })


# --- Health check endpoint ---
async def health_check(
    server_state: Dict[str, Any] = Depends(get_server_state),
    worker_task = Depends(get_worker_task),
    request_queue: Queue = Depends(get_request_queue)
):
    """Health check"""
    is_worker_running = bool(worker_task and not worker_task.done())
    launch_mode = os.environ.get('LAUNCH_MODE', 'unknown')
    browser_page_critical = launch_mode != "direct_debug_no_browser"

    core_ready_conditions = [not server_state["is_initializing"], server_state["is_playwright_ready"]]
    if browser_page_critical:
        core_ready_conditions.extend([server_state["is_browser_connected"], server_state["is_page_ready"]])

    is_core_ready = all(core_ready_conditions)
    status_val = "OK" if is_core_ready and is_worker_running else "Error"
    q_size = request_queue.qsize() if request_queue else -1

    status_message_parts = []
    if server_state["is_initializing"]: status_message_parts.append("初始化进行中")
    if not server_state["is_playwright_ready"]: status_message_parts.append("Playwright 未就绪")
    if browser_page_critical:
        if not server_state["is_browser_connected"]: status_message_parts.append("浏览器未连接")
        if not server_state["is_page_ready"]: status_message_parts.append("页面未就绪")
    if not is_worker_running: status_message_parts.append("Worker 未运行")

    status = {
        "status": status_val,
        "message": "",
        "details": {**server_state, "workerRunning": is_worker_running, "queueLength": q_size, "launchMode": launch_mode, "browserAndPageCritical": browser_page_critical}
    }

    if status_val == "OK":
        status["message"] = f"服务运行中;队列长度: {q_size}。"
        return JSONResponse(content=status, status_code=200)
    else:
        status["message"] = f"服务不可用;问题: {(', '.join(status_message_parts) or '未知原因')}. 队列长度: {q_size}."
        return JSONResponse(content=status, status_code=503)


# --- Model list endpoint ---
async def list_models(
    logger: logging.Logger = Depends(get_logger),
    model_list_fetch_event: Event = Depends(get_model_list_fetch_event),
    page_instance: AsyncPage = Depends(get_page_instance),
    parsed_model_list: List[Dict[str, Any]] = Depends(get_parsed_model_list),
    excluded_model_ids: Set[str] = Depends(get_excluded_model_ids)
):
    """Return the model list"""
    logger.info("[API] 收到 /v1/models 请求。")

    if not model_list_fetch_event.is_set() and page_instance and not page_instance.is_closed():
        logger.info("/v1/models: 模型列表事件未设置,尝试刷新页面...")
        try:
            await page_instance.reload(wait_until="domcontentloaded", timeout=20000)
            await asyncio.wait_for(model_list_fetch_event.wait(), timeout=10.0)
        except Exception as e:
            logger.error(f"/v1/models: 刷新或等待模型列表时出错: {e}")
        finally:
            if not model_list_fetch_event.is_set():
                model_list_fetch_event.set()

    if parsed_model_list:
        final_model_list = [m for m in parsed_model_list if m.get("id") not in excluded_model_ids]
        return {"object": "list", "data": final_model_list}
    else:
        logger.warning("模型列表为空,返回默认后备模型。")
        return {"object": "list", "data": [{
            "id": DEFAULT_FALLBACK_MODEL_ID, "object": "model", "created": int(time.time()),
            "owned_by": "camoufox-proxy-fallback"
        }]}


# --- Chat completion endpoint ---
async def chat_completions(
    request: ChatCompletionRequest,
    http_request: Request,
    logger: logging.Logger = Depends(get_logger),
    request_queue: Queue = Depends(get_request_queue),
    server_state: Dict[str, Any] = Depends(get_server_state),
    worker_task = Depends(get_worker_task)
):
    """Handle chat completion requests"""
    req_id = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=7))
    logger.info(f"[{req_id}] 收到 /v1/chat/completions 请求 (Stream={request.stream})")

    launch_mode = os.environ.get('LAUNCH_MODE', 'unknown')
    browser_page_critical = launch_mode != "direct_debug_no_browser"

    service_unavailable = server_state["is_initializing"] or \
                          not server_state["is_playwright_ready"] or \
                          (browser_page_critical and (not server_state["is_page_ready"] or not server_state["is_browser_connected"])) or \
                          not worker_task or worker_task.done()

    if service_unavailable:
        raise HTTPException(status_code=503, detail=f"[{req_id}] 服务当前不可用。请稍后重试。", headers={"Retry-After": "30"})

    result_future = Future()
    await request_queue.put({
        "req_id": req_id, "request_data": request, "http_request": http_request,
        "result_future": result_future, "enqueue_time": time.time(), "cancelled": False
    })

    try:
        timeout_seconds = RESPONSE_COMPLETION_TIMEOUT / 1000 + 120
        return await asyncio.wait_for(result_future, timeout=timeout_seconds)
    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail=f"[{req_id}] 请求处理超时。")
    except asyncio.CancelledError:
        raise HTTPException(status_code=499, detail=f"[{req_id}] 请求被客户端取消。")
    except HTTPException:
        raise  # propagate HTTP errors set on the future by the worker unchanged
    except Exception as e:
        logger.exception(f"[{req_id}] 等待Worker响应时出错")
        raise HTTPException(status_code=500, detail=f"[{req_id}] 服务器内部错误: {e}")


# --- Request cancellation ---
async def cancel_queued_request(req_id: str, request_queue: Queue, logger: logging.Logger) -> bool:
    """Cancel a request that is still in the queue"""
    items_to_requeue = []
    found = False
    try:
        while not request_queue.empty():
            item = request_queue.get_nowait()
            if item.get("req_id") == req_id:
                logger.info(f"[{req_id}] 在队列中找到请求,标记为已取消。")
                item["cancelled"] = True
                if (future := item.get("result_future")) and not future.done():
                    future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] Request cancelled."))
                found = True
            items_to_requeue.append(item)
    finally:
        for item in items_to_requeue:
            await request_queue.put(item)
    return found


async def cancel_request(
    req_id: str,
    logger: logging.Logger = Depends(get_logger),
    request_queue: Queue = Depends(get_request_queue)
):
    """Cancel-request endpoint"""
    logger.info(f"[{req_id}] 收到取消请求。")
    if await cancel_queued_request(req_id, request_queue, logger):
        return JSONResponse(content={"success": True, "message": f"Request {req_id} marked as cancelled."})
    else:
        return JSONResponse(status_code=404, content={"success": False, "message": f"Request {req_id} not found in queue."})


# --- Queue status endpoint ---
async def get_queue_status(
    request_queue: Queue = Depends(get_request_queue),
    processing_lock: Lock = Depends(get_processing_lock)
):
    """Return the queue status"""
    queue_items = list(request_queue._queue)
    return JSONResponse(content={
        "queue_length": len(queue_items),
        "is_processing_locked": processing_lock.locked(),
        "items": sorted([
            {
                "req_id": item.get("req_id", "unknown"),
                "enqueue_time": item.get("enqueue_time", 0),
                "wait_time_seconds": round(time.time() - item.get("enqueue_time", 0), 2),
                "is_streaming": item.get("request_data").stream,
                "cancelled": item.get("cancelled", False)
            } for item in queue_items
        ], key=lambda x: x.get("enqueue_time", 0))
    })


# --- WebSocket log endpoint ---
async def websocket_log_endpoint(
    websocket: WebSocket,
    logger: logging.Logger = Depends(get_logger),
    log_ws_manager: WebSocketConnectionManager = Depends(get_log_ws_manager)
):
    """WebSocket log endpoint"""
    if not log_ws_manager:
        await websocket.close(code=1011)
        return

    client_id = str(uuid.uuid4())
    try:
        await log_ws_manager.connect(client_id, websocket)
        while True:
            await websocket.receive_text()  # keep the connection alive
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"日志 WebSocket (客户端 {client_id}) 发生异常: {e}", exc_info=True)
    finally:
        log_ws_manager.disconnect(client_id)


# --- API key management data models ---
class ApiKeyRequest(BaseModel):
    key: str

class ApiKeyTestRequest(BaseModel):
    key: str


# --- API key management endpoints ---
async def get_api_keys(logger: logging.Logger = Depends(get_logger)):
    """List the configured API keys"""
    from api_utils import auth_utils
    try:
        auth_utils.initialize_keys()
        keys_info = [{"value": key, "status": "有效"} for key in auth_utils.API_KEYS]
        return JSONResponse(content={"success": True, "keys": keys_info, "total_count": len(keys_info)})
    except Exception as e:
        logger.error(f"获取API密钥列表失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))


async def add_api_key(request: ApiKeyRequest, logger: logging.Logger = Depends(get_logger)):
    """Add an API key"""
    from api_utils import auth_utils
    key_value = request.key.strip()
    if not key_value or len(key_value) < 8:
        raise HTTPException(status_code=400, detail="无效的API密钥格式。")

    auth_utils.initialize_keys()
    if key_value in auth_utils.API_KEYS:
        raise HTTPException(status_code=400, detail="该API密钥已存在。")

    try:
        key_file_path = os.path.join(os.path.dirname(__file__), "..", "key.txt")
        with open(key_file_path, 'a+', encoding='utf-8') as f:
            f.seek(0)
            if f.read(): f.write("\n")
            f.write(key_value)

        auth_utils.initialize_keys()
        logger.info(f"API密钥已添加: {key_value[:4]}...{key_value[-4:]}")
        return JSONResponse(content={"success": True, "message": "API密钥添加成功", "key_count": len(auth_utils.API_KEYS)})
    except Exception as e:
        logger.error(f"添加API密钥失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))


async def test_api_key(request: ApiKeyTestRequest, logger: logging.Logger = Depends(get_logger)):
    """Test an API key"""
    from api_utils import auth_utils
    key_value = request.key.strip()
    if not key_value:
        raise HTTPException(status_code=400, detail="API密钥不能为空。")

    auth_utils.initialize_keys()
    is_valid = auth_utils.verify_api_key(key_value)
    logger.info(f"API密钥测试: {key_value[:4]}...{key_value[-4:]} - {'有效' if is_valid else '无效'}")
    return JSONResponse(content={"success": True, "valid": is_valid, "message": "密钥有效" if is_valid else "密钥无效或不存在"})


async def delete_api_key(request: ApiKeyRequest, logger: logging.Logger = Depends(get_logger)):
    """Delete an API key"""
    from api_utils import auth_utils
    key_value = request.key.strip()
    if not key_value:
        raise HTTPException(status_code=400, detail="API密钥不能为空。")

    auth_utils.initialize_keys()
    if key_value not in auth_utils.API_KEYS:
        raise HTTPException(status_code=404, detail="API密钥不存在。")

    try:
        key_file_path = os.path.join(os.path.dirname(__file__), "..", "key.txt")
        with open(key_file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        with open(key_file_path, 'w', encoding='utf-8') as f:
            f.writelines(line for line in lines if line.strip() != key_value)

        auth_utils.initialize_keys()
        logger.info(f"API密钥已删除: {key_value[:4]}...{key_value[-4:]}")
        return JSONResponse(content={"success": True, "message": "API密钥删除成功", "key_count": len(auth_utils.API_KEYS)})
    except Exception as e:
        logger.error(f"删除API密钥失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
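The two OpenAI-compatible paths confirmed by the log lines above are `/v1/models` and `/v1/chat/completions`; the remaining handlers are wired to their routes in app.py. An illustrative non-streaming exchange, assuming a local deployment on port 8000 with no API key configured:

# Illustrative client calls against the two OpenAI-compatible routes above.
# Assumes a local deployment on port 8000 and no API key configured.
import requests

BASE = "http://127.0.0.1:8000"

# List the models the proxy currently exposes (falls back to a default entry).
models = requests.get(f"{BASE}/v1/models").json()["data"]
print([m["id"] for m in models])

# Non-streaming completion; the queue worker resolves the pending future with this JSON.
r = requests.post(f"{BASE}/v1/chat/completions", json={
    "model": models[0]["id"],
    "stream": False,
    "messages": [{"role": "user", "content": "ping"}],
})
print(r.json()["choices"][0]["message"]["content"])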
api_utils/utils.py
ADDED
@@ -0,0 +1,372 @@
"""
API utility functions module
Contains helpers for SSE generation, stream handling, token statistics and request validation
"""

import asyncio
import json
import time
import datetime
from typing import Any, Dict, List, Optional, AsyncGenerator
from asyncio import Queue
from models import Message


# --- SSE generation helpers ---
def generate_sse_chunk(delta: str, req_id: str, model: str) -> str:
    """Generate an SSE data chunk"""
    chunk_data = {
        "id": f"chatcmpl-{req_id}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}]
    }
    return f"data: {json.dumps(chunk_data)}\n\n"


def generate_sse_stop_chunk(req_id: str, model: str, reason: str = "stop", usage: dict = None) -> str:
    """Generate an SSE stop chunk"""
    stop_chunk_data = {
        "id": f"chatcmpl-{req_id}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": reason}]
    }

    # Attach usage info when provided
    if usage:
        stop_chunk_data["usage"] = usage

    return f"data: {json.dumps(stop_chunk_data)}\n\ndata: [DONE]\n\n"


def generate_sse_error_chunk(message: str, req_id: str, error_type: str = "server_error") -> str:
    """Generate an SSE error chunk"""
    error_chunk = {"error": {"message": message, "type": error_type, "param": None, "code": req_id}}
    return f"data: {json.dumps(error_chunk)}\n\n"


# --- Stream handling helpers ---
async def use_stream_response(req_id: str) -> AsyncGenerator[Any, None]:
    """Consume the stream response (reads data from the server's global queue)"""
    from server import STREAM_QUEUE, logger
    import queue

    if STREAM_QUEUE is None:
        logger.warning(f"[{req_id}] STREAM_QUEUE is None, 无法使用流响应")
        return

    logger.info(f"[{req_id}] 开始使用流响应")

    empty_count = 0
    max_empty_retries = 300  # 30-second timeout (300 polls x 100 ms)
    data_received = False

    try:
        while True:
            try:
                # Fetch data from the queue
                data = STREAM_QUEUE.get_nowait()
                if data is None:  # end-of-stream marker
                    logger.info(f"[{req_id}] 接收到流结束标志")
                    break

                # Reset the empty-poll counter
                empty_count = 0
                data_received = True
                logger.debug(f"[{req_id}] 接收到流数据: {type(data)} - {str(data)[:200]}...")

                # Check for a JSON-string form of the completion marker
                if isinstance(data, str):
                    try:
                        parsed_data = json.loads(data)
                        if parsed_data.get("done") is True:
                            logger.info(f"[{req_id}] 接收到JSON格式的完成标志")
                            yield parsed_data
                            break
                        else:
                            yield parsed_data
                    except json.JSONDecodeError:
                        # Not JSON - yield the raw string
                        logger.debug(f"[{req_id}] 返回非JSON字符串数据")
                        yield data
                else:
                    # Yield the data as-is
                    yield data

                # Check for a dict-form completion marker
                if isinstance(data, dict) and data.get("done") is True:
                    logger.info(f"[{req_id}] 接收到字典格式的完成标志")
                    break

            except (queue.Empty, asyncio.QueueEmpty):
                empty_count += 1
                if empty_count % 50 == 0:  # log the wait status every 5 seconds
                    logger.info(f"[{req_id}] 等待流数据... ({empty_count}/{max_empty_retries})")

                if empty_count >= max_empty_retries:
                    if not data_received:
                        logger.error(f"[{req_id}] 流响应队列空读取次数达到上限且未收到任何数据,可能是辅助流未启动或出错")
                    else:
                        logger.warning(f"[{req_id}] 流响应队列空读取次数达到上限 ({max_empty_retries}),结束读取")

                    # Yield a timeout completion signal instead of silently exiting
                    yield {"done": True, "reason": "internal_timeout", "body": "", "function": []}
                    return

                await asyncio.sleep(0.1)  # 100 ms wait
                continue

    except Exception as e:
        logger.error(f"[{req_id}] 使用流响应时出错: {e}")
        raise
    finally:
        logger.info(f"[{req_id}] 流响应使用完成,数据接收状态: {data_received}")


async def clear_stream_queue():
    """Drain the stream queue (kept consistent with the original reference file)"""
    from server import STREAM_QUEUE, logger
    import queue

    if STREAM_QUEUE is None:
        logger.info("流队列未初始化或已被禁用,跳过清空操作。")
        return

    while True:
        try:
            data_chunk = await asyncio.to_thread(STREAM_QUEUE.get_nowait)
            # logger.info(f"清空流式队列缓存,丢弃数据: {data_chunk}")
        except queue.Empty:
            logger.info("流式队列已清空 (捕获到 queue.Empty)。")
            break
        except Exception as e:
            logger.error(f"清空流式队列时发生意外错误: {e}", exc_info=True)
            break
    logger.info("流式队列缓存清空完毕。")


# --- Helper response generator ---
async def use_helper_get_response(helper_endpoint: str, helper_sapisid: str) -> AsyncGenerator[str, None]:
    """Generator that fetches a response via the Helper service"""
    from server import logger
    import aiohttp

    logger.info(f"正在尝试使用Helper端点: {helper_endpoint}")

    try:
        async with aiohttp.ClientSession() as session:
            headers = {
                'Content-Type': 'application/json',
                'Cookie': f'SAPISID={helper_sapisid}' if helper_sapisid else ''
            }

            async with session.get(helper_endpoint, headers=headers) as response:
                if response.status == 200:
                    async for chunk in response.content.iter_chunked(1024):
|
| 170 |
+
if chunk:
|
| 171 |
+
yield chunk.decode('utf-8', errors='ignore')
|
| 172 |
+
else:
|
| 173 |
+
logger.error(f"Helper端点返回错误状态: {response.status}")
|
| 174 |
+
|
| 175 |
+
except Exception as e:
|
| 176 |
+
logger.error(f"使用Helper端点时出错: {e}")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# --- 请求验证函数 ---
|
| 180 |
+
def validate_chat_request(messages: List[Message], req_id: str) -> Dict[str, Optional[str]]:
|
| 181 |
+
"""验证聊天请求"""
|
| 182 |
+
from server import logger
|
| 183 |
+
|
| 184 |
+
if not messages:
|
| 185 |
+
raise ValueError(f"[{req_id}] 无效请求: 'messages' 数组缺失或为空。")
|
| 186 |
+
|
| 187 |
+
if not any(msg.role != 'system' for msg in messages):
|
| 188 |
+
raise ValueError(f"[{req_id}] 无效请求: 所有消息都是系统消息。至少需要一条用户或助手消息。")
|
| 189 |
+
|
| 190 |
+
# 返回验证结果
|
| 191 |
+
return {
|
| 192 |
+
"error": None,
|
| 193 |
+
"warning": None
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# --- 提示准备函数 ---
|
| 198 |
+
def prepare_combined_prompt(messages: List[Message], req_id: str) -> str:
|
| 199 |
+
"""准备组合提示"""
|
| 200 |
+
from server import logger
|
| 201 |
+
|
| 202 |
+
logger.info(f"[{req_id}] (准备提示) 正在从 {len(messages)} 条消息准备组合提示 (包括历史)。")
|
| 203 |
+
|
| 204 |
+
combined_parts = []
|
| 205 |
+
system_prompt_content: Optional[str] = None
|
| 206 |
+
processed_system_message_indices = set()
|
| 207 |
+
|
| 208 |
+
# 处理系统消息
|
| 209 |
+
for i, msg in enumerate(messages):
|
| 210 |
+
if msg.role == 'system':
|
| 211 |
+
content = msg.content
|
| 212 |
+
if isinstance(content, str) and content.strip():
|
| 213 |
+
system_prompt_content = content.strip()
|
| 214 |
+
processed_system_message_indices.add(i)
|
| 215 |
+
logger.info(f"[{req_id}] (准备提示) 在索引 {i} 找到并使用系统提示: '{system_prompt_content[:80]}...'")
|
| 216 |
+
system_instr_prefix = "系统指令:\n"
|
| 217 |
+
combined_parts.append(f"{system_instr_prefix}{system_prompt_content}")
|
| 218 |
+
else:
|
| 219 |
+
logger.info(f"[{req_id}] (准备提示) 在索引 {i} 忽略非字符串或空的系统消息。")
|
| 220 |
+
processed_system_message_indices.add(i)
|
| 221 |
+
break
|
| 222 |
+
|
| 223 |
+
role_map_ui = {"user": "用户", "assistant": "助手", "system": "系统", "tool": "工具"}
|
| 224 |
+
turn_separator = "\n---\n"
|
| 225 |
+
|
| 226 |
+
# 处理其他消息
|
| 227 |
+
for i, msg in enumerate(messages):
|
| 228 |
+
if i in processed_system_message_indices:
|
| 229 |
+
continue
|
| 230 |
+
|
| 231 |
+
if msg.role == 'system':
|
| 232 |
+
logger.info(f"[{req_id}] (准备提示) 跳过在索引 {i} 的后续系统消息。")
|
| 233 |
+
continue
|
| 234 |
+
|
| 235 |
+
if combined_parts:
|
| 236 |
+
combined_parts.append(turn_separator)
|
| 237 |
+
|
| 238 |
+
role = msg.role or 'unknown'
|
| 239 |
+
role_prefix_ui = f"{role_map_ui.get(role, role.capitalize())}:\n"
|
| 240 |
+
current_turn_parts = [role_prefix_ui]
|
| 241 |
+
|
| 242 |
+
content = msg.content or ''
|
| 243 |
+
content_str = ""
|
| 244 |
+
|
| 245 |
+
if isinstance(content, str):
|
| 246 |
+
content_str = content.strip()
|
| 247 |
+
elif isinstance(content, list):
|
| 248 |
+
# 处理多模态内容
|
| 249 |
+
text_parts = []
|
| 250 |
+
for item in content:
|
| 251 |
+
if hasattr(item, 'type') and item.type == 'text':
|
| 252 |
+
text_parts.append(item.text or '')
|
| 253 |
+
elif isinstance(item, dict) and item.get('type') == 'text':
|
| 254 |
+
text_parts.append(item.get('text', ''))
|
| 255 |
+
else:
|
| 256 |
+
logger.warning(f"[{req_id}] (准备提示) 警告: 在索引 {i} 的消息中忽略非文本或未知类型的 content item")
|
| 257 |
+
content_str = "\n".join(text_parts).strip()
|
| 258 |
+
else:
|
| 259 |
+
logger.warning(f"[{req_id}] (准备提示) 警告: 角色 {role} 在索引 {i} 的内容类型意外 ({type(content)}) 或为 None。")
|
| 260 |
+
content_str = str(content or "").strip()
|
| 261 |
+
|
| 262 |
+
if content_str:
|
| 263 |
+
current_turn_parts.append(content_str)
|
| 264 |
+
|
| 265 |
+
# 处理工具调用
|
| 266 |
+
tool_calls = msg.tool_calls
|
| 267 |
+
if role == 'assistant' and tool_calls:
|
| 268 |
+
if content_str:
|
| 269 |
+
current_turn_parts.append("\n")
|
| 270 |
+
|
| 271 |
+
tool_call_visualizations = []
|
| 272 |
+
for tool_call in tool_calls:
|
| 273 |
+
if hasattr(tool_call, 'type') and tool_call.type == 'function':
|
| 274 |
+
function_call = tool_call.function
|
| 275 |
+
func_name = function_call.name if function_call else None
|
| 276 |
+
func_args_str = function_call.arguments if function_call else None
|
| 277 |
+
|
| 278 |
+
try:
|
| 279 |
+
parsed_args = json.loads(func_args_str if func_args_str else '{}')
|
| 280 |
+
formatted_args = json.dumps(parsed_args, indent=2, ensure_ascii=False)
|
| 281 |
+
except (json.JSONDecodeError, TypeError):
|
| 282 |
+
formatted_args = func_args_str if func_args_str is not None else "{}"
|
| 283 |
+
|
| 284 |
+
tool_call_visualizations.append(
|
| 285 |
+
f"请求调用函数: {func_name}\n参数:\n{formatted_args}"
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if tool_call_visualizations:
|
| 289 |
+
current_turn_parts.append("\n".join(tool_call_visualizations))
|
| 290 |
+
|
| 291 |
+
if len(current_turn_parts) > 1 or (role == 'assistant' and tool_calls):
|
| 292 |
+
combined_parts.append("".join(current_turn_parts))
|
| 293 |
+
elif not combined_parts and not current_turn_parts:
|
| 294 |
+
logger.info(f"[{req_id}] (准备提示) 跳过角色 {role} 在索引 {i} 的空消息 (且无工具调用)。")
|
| 295 |
+
elif len(current_turn_parts) == 1 and not combined_parts:
|
| 296 |
+
logger.info(f"[{req_id}] (准备提示) 跳过角色 {role} 在索引 {i} 的空消息 (只有前缀)。")
|
| 297 |
+
|
| 298 |
+
final_prompt = "".join(combined_parts)
|
| 299 |
+
if final_prompt:
|
| 300 |
+
final_prompt += "\n"
|
| 301 |
+
|
| 302 |
+
preview_text = final_prompt[:300].replace('\n', '\\n')
|
| 303 |
+
logger.info(f"[{req_id}] (准备提示) 组合提示长度: {len(final_prompt)}。预览: '{preview_text}...'")
|
| 304 |
+
|
| 305 |
+
return final_prompt
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def estimate_tokens(text: str) -> int:
|
| 309 |
+
"""
|
| 310 |
+
估算文本的token数量
|
| 311 |
+
使用简单的字符计数方法:
|
| 312 |
+
- 英文:大约4个字符 = 1个token
|
| 313 |
+
- 中文:大约1.5个字符 = 1个token
|
| 314 |
+
- 混合文本:采用加权平均
|
| 315 |
+
"""
|
| 316 |
+
if not text:
|
| 317 |
+
return 0
|
| 318 |
+
|
| 319 |
+
# 统计中文字符数量(包括中文标点)
|
| 320 |
+
chinese_chars = sum(1 for char in text if '\u4e00' <= char <= '\u9fff' or '\u3000' <= char <= '\u303f' or '\uff00' <= char <= '\uffef')
|
| 321 |
+
|
| 322 |
+
# 统计非中文字符数量
|
| 323 |
+
non_chinese_chars = len(text) - chinese_chars
|
| 324 |
+
|
| 325 |
+
# 计算token估算
|
| 326 |
+
chinese_tokens = chinese_chars / 1.5 # 中文大约1.5字符/token
|
| 327 |
+
english_tokens = non_chinese_chars / 4.0 # 英文大约4字符/token
|
| 328 |
+
|
| 329 |
+
return max(1, int(chinese_tokens + english_tokens))
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def calculate_usage_stats(messages: List[dict], response_content: str, reasoning_content: str = None) -> dict:
|
| 333 |
+
"""
|
| 334 |
+
计算token使用统计
|
| 335 |
+
|
| 336 |
+
Args:
|
| 337 |
+
messages: 请求中的消息列表
|
| 338 |
+
response_content: 响应内容
|
| 339 |
+
reasoning_content: 推理内容(可选)
|
| 340 |
+
|
| 341 |
+
Returns:
|
| 342 |
+
包含token使用统计的字典
|
| 343 |
+
"""
|
| 344 |
+
# 计算输入token(prompt tokens)
|
| 345 |
+
prompt_text = ""
|
| 346 |
+
for message in messages:
|
| 347 |
+
role = message.get("role", "")
|
| 348 |
+
content = message.get("content", "")
|
| 349 |
+
prompt_text += f"{role}: {content}\n"
|
| 350 |
+
|
| 351 |
+
prompt_tokens = estimate_tokens(prompt_text)
|
| 352 |
+
|
| 353 |
+
# 计算输出token(completion tokens)
|
| 354 |
+
completion_text = response_content or ""
|
| 355 |
+
if reasoning_content:
|
| 356 |
+
completion_text += reasoning_content
|
| 357 |
+
|
| 358 |
+
completion_tokens = estimate_tokens(completion_text)
|
| 359 |
+
|
| 360 |
+
# 总token数
|
| 361 |
+
total_tokens = prompt_tokens + completion_tokens
|
| 362 |
+
|
| 363 |
+
return {
|
| 364 |
+
"prompt_tokens": prompt_tokens,
|
| 365 |
+
"completion_tokens": completion_tokens,
|
| 366 |
+
"total_tokens": total_tokens
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def generate_sse_stop_chunk_with_usage(req_id: str, model: str, usage_stats: dict, reason: str = "stop") -> str:
|
| 371 |
+
"""生成带usage统计的SSE停止块"""
|
| 372 |
+
return generate_sse_stop_chunk(req_id, model, reason, usage_stats)
|
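A minimal sketch of how the usage helpers in this file compose when a turn finishes; it is not part of the uploaded files, and the request id, model name, and messages below are hypothetical:

    from api_utils.utils import calculate_usage_stats, generate_sse_stop_chunk_with_usage

    # Hypothetical request/response pair (values are illustrative only)
    messages = [{"role": "user", "content": "你好, what can you do?"}]
    response_content = "I can proxy chat completions."

    # Character-count heuristic: prompt, completion, and total token estimates
    usage = calculate_usage_stats(messages, response_content)

    # Final SSE payload: a stop chunk carrying the usage block, then "data: [DONE]"
    tail = generate_sse_stop_chunk_with_usage("req-hypothetical", "some-model", usage)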