hins111 commited on
Commit
469e046
·
verified ·
1 Parent(s): 6841ed0

Upload 9 files

Browse files
api_utils/__init__.py CHANGED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API工具模块
3
+ 提供FastAPI应用初始化、路由处理和工具函数
4
+ """
5
+
6
+ # 应用初始化
7
+ from .app import (
8
+ create_app
9
+ )
10
+
11
+ # 路由处理器
12
+ from .routes import (
13
+ read_index,
14
+ get_css,
15
+ get_js,
16
+ get_api_info,
17
+ health_check,
18
+ list_models,
19
+ chat_completions,
20
+ cancel_request,
21
+ get_queue_status,
22
+ websocket_log_endpoint
23
+ )
24
+
25
+ # 工具函数
26
+ from .utils import (
27
+ generate_sse_chunk,
28
+ generate_sse_stop_chunk,
29
+ generate_sse_error_chunk,
30
+ use_stream_response,
31
+ clear_stream_queue,
32
+ use_helper_get_response,
33
+ validate_chat_request,
34
+ prepare_combined_prompt,
35
+ estimate_tokens,
36
+ calculate_usage_stats
37
+ )
38
+
39
+ # 请求处理器
40
+ from .request_processor import (
41
+ _process_request_refactored
42
+ )
43
+
44
+ # 队列工作器
45
+ from .queue_worker import (
46
+ queue_worker
47
+ )
48
+
49
+ __all__ = [
50
+ # 应用初始化
51
+ 'create_app',
52
+ # 路由处理器
53
+ 'read_index',
54
+ 'get_css',
55
+ 'get_js',
56
+ 'get_api_info',
57
+ 'health_check',
58
+ 'list_models',
59
+ 'chat_completions',
60
+ 'cancel_request',
61
+ 'get_queue_status',
62
+ 'websocket_log_endpoint',
63
+ # 工具函数
64
+ 'generate_sse_chunk',
65
+ 'generate_sse_stop_chunk',
66
+ 'generate_sse_error_chunk',
67
+ 'use_stream_response',
68
+ 'clear_stream_queue',
69
+ 'use_helper_get_response',
70
+ 'validate_chat_request',
71
+ 'prepare_combined_prompt',
72
+ 'estimate_tokens',
73
+ 'calculate_usage_stats',
74
+ # 请求处理器
75
+ '_process_request_refactored',
76
+ # 队列工作器
77
+ 'queue_worker'
78
+ ]
api_utils/app.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI应用初始化和生命周期管理
3
+ """
4
+
5
+ import asyncio
6
+ import multiprocessing
7
+ import os
8
+ import sys
9
+ from contextlib import asynccontextmanager
10
+ from typing import Optional
11
+
12
+ from fastapi import FastAPI, Request
13
+ from fastapi.responses import JSONResponse
14
+ from starlette.middleware.base import BaseHTTPMiddleware
15
+ from starlette.types import ASGIApp
16
+ from typing import Callable, Awaitable
17
+ from playwright.async_api import Browser as AsyncBrowser, Playwright as AsyncPlaywright
18
+
19
+ # --- 配置模块导入 ---
20
+ from config import *
21
+
22
+ # --- models模块导入 ---
23
+ from models import WebSocketConnectionManager
24
+
25
+ # --- logging_utils模块导入 ---
26
+ from logging_utils import setup_server_logging, restore_original_streams
27
+
28
+ # --- browser_utils模块导入 ---
29
+ from browser_utils import (
30
+ _initialize_page_logic,
31
+ _close_page_logic,
32
+ load_excluded_models,
33
+ _handle_initial_model_state_and_storage
34
+ )
35
+
36
+ import stream
37
+ from asyncio import Queue, Lock
38
+ from . import auth_utils
39
+
40
+ # 全局状态变量(这些将在server.py中被引用)
41
+ playwright_manager: Optional[AsyncPlaywright] = None
42
+ browser_instance: Optional[AsyncBrowser] = None
43
+ page_instance = None
44
+ is_playwright_ready = False
45
+ is_browser_connected = False
46
+ is_page_ready = False
47
+ is_initializing = False
48
+
49
+ global_model_list_raw_json = None
50
+ parsed_model_list = []
51
+ model_list_fetch_event = None
52
+
53
+ current_ai_studio_model_id = None
54
+ model_switching_lock = None
55
+
56
+ excluded_model_ids = set()
57
+
58
+ request_queue = None
59
+ processing_lock = None
60
+ worker_task = None
61
+
62
+ page_params_cache = {}
63
+ params_cache_lock = None
64
+
65
+ log_ws_manager = None
66
+
67
+ STREAM_QUEUE = None
68
+ STREAM_PROCESS = None
69
+
70
# --- Lifespan Context Manager ---
def _setup_logging():
    """Create the shared WebSocket log manager and configure server logging.

    Reads SERVER_LOG_LEVEL / SERVER_REDIRECT_PRINT from the environment and
    returns the pair of streams produced by setup_server_logging (restored
    again during shutdown).
    """
    import server

    server.log_ws_manager = WebSocketConnectionManager()
    level_name = os.environ.get('SERVER_LOG_LEVEL', 'INFO')
    redirect_flag = os.environ.get('SERVER_REDIRECT_PRINT', 'false')
    return setup_server_logging(
        logger_instance=server.logger,
        log_ws_manager=server.log_ws_manager,
        log_level_name=level_name,
        redirect_print_str=redirect_flag,
    )
82
+
83
def _initialize_globals():
    """Create the request queue and the asyncio locks on the server module,
    then load the API keys from disk."""
    import server

    server.request_queue = Queue()
    # All three locks are plain asyncio.Lock instances.
    for lock_name in ('processing_lock', 'model_switching_lock', 'params_cache_lock'):
        setattr(server, lock_name, Lock())
    auth_utils.initialize_keys()
    server.logger.info("API keys and global locks initialized.")
91
+
92
def _initialize_proxy_settings():
    """Derive Playwright proxy settings from STREAM_PORT and HTTP(S)_PROXY.

    When STREAM_PORT is '0' the local stream proxy is disabled and the
    system proxy (HTTPS_PROXY / HTTP_PROXY) is used directly; otherwise
    Playwright is pointed at the local stream proxy on 127.0.0.1.
    """
    import server

    stream_port = os.environ.get('STREAM_PORT')
    if stream_port == '0':
        proxy_url = os.environ.get('HTTPS_PROXY') or os.environ.get('HTTP_PROXY')
    else:
        proxy_url = f"http://127.0.0.1:{stream_port or 3120}/"

    if not proxy_url:
        server.logger.info("No proxy configured for Playwright.")
        return

    settings = {'server': proxy_url}
    if NO_PROXY_ENV:
        # Playwright expects a ';'-separated bypass list; NO_PROXY uses ','.
        settings['bypass'] = NO_PROXY_ENV.replace(',', ';')
    server.PLAYWRIGHT_PROXY_SETTINGS = settings
    server.logger.info(f"Playwright proxy settings configured: {settings}")
107
+
108
async def _start_stream_proxy():
    """Launch the stream proxy in a child process unless STREAM_PORT is '0'."""
    import server

    port_env = os.environ.get('STREAM_PORT')
    if port_env == '0':
        # Explicitly disabled; nothing to start.
        return

    port = int(port_env or 3120)
    upstream = (os.environ.get('UNIFIED_PROXY_CONFIG')
                or os.environ.get('HTTPS_PROXY')
                or os.environ.get('HTTP_PROXY'))
    server.logger.info(f"Starting STREAM proxy on port {port} with upstream proxy: {upstream}")
    server.STREAM_QUEUE = multiprocessing.Queue()
    server.STREAM_PROCESS = multiprocessing.Process(
        target=stream.start,
        args=(server.STREAM_QUEUE, port, upstream),
    )
    server.STREAM_PROCESS.start()
    server.logger.info("STREAM proxy process started.")
119
+
120
async def _initialize_browser_and_page():
    """Start Playwright, connect to the external browser over WebSocket and
    initialize the AI Studio page.

    Raises:
        ValueError: if CAMOUFOX_WS_ENDPOINT is missing and the launch mode
            is not "direct_debug_no_browser".
    """
    import server
    from playwright.async_api import async_playwright

    server.logger.info("Starting Playwright...")
    server.playwright_manager = await async_playwright().start()
    server.is_playwright_ready = True
    server.logger.info("Playwright started.")

    ws_endpoint = os.environ.get('CAMOUFOX_WS_ENDPOINT')
    launch_mode = os.environ.get('LAUNCH_MODE', 'unknown')

    if not ws_endpoint and launch_mode != "direct_debug_no_browser":
        raise ValueError("CAMOUFOX_WS_ENDPOINT environment variable is missing.")

    if ws_endpoint:
        server.logger.info(f"Connecting to browser at: {ws_endpoint}")
        # Connects via the firefox driver (Camoufox is Firefox-based — TODO confirm).
        server.browser_instance = await server.playwright_manager.firefox.connect(ws_endpoint, timeout=30000)
        server.is_browser_connected = True
        server.logger.info(f"Connected to browser: {server.browser_instance.version}")

        server.page_instance, server.is_page_ready = await _initialize_page_logic(server.browser_instance)
        if server.is_page_ready:
            await _handle_initial_model_state_and_storage(server.page_instance)
            server.logger.info("Page initialized successfully.")
        else:
            server.logger.error("Page initialization failed.")

    # Unblock any waiters even if the model list was never fetched.
    if not server.model_list_fetch_event.is_set():
        server.model_list_fetch_event.set()
150
+
151
async def _shutdown_resources():
    """Tear down the stream proxy, worker task, page, browser and Playwright.

    Every step is guarded by a presence check so partially-initialized
    state (e.g. a failed startup) still shuts down cleanly.
    """
    import server
    logger = server.logger
    logger.info("Shutting down resources...")

    if server.STREAM_PROCESS:
        server.STREAM_PROCESS.terminate()
        logger.info("STREAM proxy terminated.")

    if server.worker_task and not server.worker_task.done():
        server.worker_task.cancel()
        try:
            # Give the worker up to 5s to acknowledge the cancellation.
            await asyncio.wait_for(server.worker_task, timeout=5.0)
        except (asyncio.TimeoutError, asyncio.CancelledError):
            pass
        logger.info("Worker task stopped.")

    if server.page_instance:
        await _close_page_logic()

    if server.browser_instance and server.browser_instance.is_connected():
        await server.browser_instance.close()
        logger.info("Browser connection closed.")

    if server.playwright_manager:
        await server.playwright_manager.stop()
        logger.info("Playwright stopped.")
178
+
179
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: bring up logging, globals, proxies, the browser
    page and the queue worker; tear everything down on shutdown.

    Startup order matters: logging first (so later steps are captured),
    then locks/keys and proxy settings, the stream proxy process, and the
    browser/page before the worker is started.
    """
    import server
    from server import queue_worker

    # Keep the pristine streams as a fallback; _setup_logging returns the
    # streams that were active just before redirection.
    original_streams = sys.stdout, sys.stderr
    initial_stdout, initial_stderr = _setup_logging()
    logger = server.logger

    _initialize_globals()
    _initialize_proxy_settings()
    load_excluded_models(EXCLUDED_MODELS_FILENAME)

    server.is_initializing = True
    logger.info("Starting AI Studio Proxy Server...")

    try:
        await _start_stream_proxy()
        await _initialize_browser_and_page()

        launch_mode = os.environ.get('LAUNCH_MODE', 'unknown')
        if server.is_page_ready or launch_mode == "direct_debug_no_browser":
            server.worker_task = asyncio.create_task(queue_worker())
            logger.info("Request processing worker started.")
        else:
            raise RuntimeError("Failed to initialize browser/page, worker not started.")

        logger.info("Server startup complete.")
        server.is_initializing = False
        yield
    except Exception as e:
        logger.critical(f"Application startup failed: {e}", exc_info=True)
        # NOTE(review): _shutdown_resources also runs in the finally block
        # below, so a failed startup shuts down twice. The individual steps
        # are guarded, but confirm the double call is intended.
        await _shutdown_resources()
        raise RuntimeError(f"Application startup failed: {e}") from e
    finally:
        logger.info("Shutting down server...")
        await _shutdown_resources()
        # NOTE(review): streams are restored twice (logging's streams, then
        # the pristine ones captured above) — presumably belt-and-braces.
        restore_original_streams(initial_stdout, initial_stderr)
        restore_original_streams(*original_streams)
        logger.info("Server shutdown complete.")
220
+
221
+
222
class APIKeyAuthMiddleware(BaseHTTPMiddleware):
    """Middleware protecting /v1/* endpoints with an API key.

    Authentication is skipped entirely when no keys are configured.
    Both the OpenAI-style ``Authorization: Bearer <key>`` header and the
    legacy ``X-API-Key`` header are accepted.
    """

    def __init__(self, app: ASGIApp):
        super().__init__(app)
        # Paths that never require a key (model listing, health and docs).
        self.excluded_paths = [
            "/v1/models",
            "/health",
            "/docs",
            "/openapi.json",
            # other FastAPI auto-generated documentation paths
            "/redoc",
            "/favicon.ico"
        ]

    def _is_excluded(self, path: str) -> bool:
        """Return True when the request path is exempt from authentication."""
        return any(
            path == excluded or path.startswith(excluded + "/")
            for excluded in self.excluded_paths
        )

    @staticmethod
    def _extract_api_key(request: Request):
        """Pull the key from Authorization: Bearer, falling back to X-API-Key."""
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            token = auth_header[7:]  # strip the "Bearer " prefix
            if token:
                return token
        # Fall back to the custom header (backward compatible).
        return request.headers.get("X-API-Key")

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable]):
        path = request.url.path
        needs_auth = (
            auth_utils.API_KEYS              # empty key set -> auth disabled
            and path.startswith("/v1/")      # only /v1/* endpoints are protected
            and not self._is_excluded(path)
        )
        if not needs_auth:
            return await call_next(request)

        api_key = self._extract_api_key(request)
        if not api_key or not auth_utils.verify_api_key(api_key):
            return JSONResponse(
                status_code=401,
                content={
                    "error": {
                        "message": "Invalid or missing API key. Please provide a valid API key using 'Authorization: Bearer <your_key>' or 'X-API-Key: <your_key>' header.",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key"
                    }
                }
            )
        return await call_next(request)
273
+
274
def create_app() -> FastAPI:
    """Build the FastAPI app: middleware, UI/API routes, key management
    endpoints and the log-streaming websocket."""
    from .routes import (
        read_index, get_css, get_js, get_api_info,
        health_check, list_models, chat_completions,
        cancel_request, get_queue_status, websocket_log_endpoint,
        get_api_keys, add_api_key, test_api_key, delete_api_key
    )
    from fastapi.responses import FileResponse

    app = FastAPI(
        title="AI Studio Proxy Server (集成模式)",
        description="通过 Playwright与 AI Studio 交互的代理服务器。",
        version="0.6.0-integrated",
        lifespan=lifespan
    )
    app.add_middleware(APIKeyAuthMiddleware)

    # Static / web UI routes
    app.get("/", response_class=FileResponse)(read_index)
    app.get("/webui.css")(get_css)
    app.get("/webui.js")(get_js)

    # Informational endpoints
    app.get("/api/info")(get_api_info)
    app.get("/health")(health_check)

    # OpenAI-compatible API
    app.get("/v1/models")(list_models)
    app.post("/v1/chat/completions")(chat_completions)
    app.post("/v1/cancel/{req_id}")(cancel_request)
    app.get("/v1/queue")(get_queue_status)

    # Log streaming websocket
    app.websocket("/ws/logs")(websocket_log_endpoint)

    # API key management endpoints
    app.get("/api/keys")(get_api_keys)
    app.post("/api/keys")(add_api_key)
    app.post("/api/keys/test")(test_api_key)
    app.delete("/api/keys")(delete_api_key)

    return app
api_utils/auth_utils.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from typing import Set

# Set of valid API keys; an empty set means authentication is disabled.
API_KEYS: Set[str] = set()
# key.txt lives in the project root (one directory above this package).
KEY_FILE_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "key.txt")

def load_api_keys():
    """Loads API keys from the key file into the API_KEYS set."""
    global API_KEYS
    API_KEYS.clear()
    if not os.path.exists(KEY_FILE_PATH):
        return
    with open(KEY_FILE_PATH, "r") as f:
        stripped = (line.strip() for line in f)
        API_KEYS.update(key for key in stripped if key)

def initialize_keys():
    """Ensure key.txt exists, then (re)load the keys it contains."""
    if not os.path.exists(KEY_FILE_PATH):
        # Touch an empty key file so later loads/writes have a target.
        open(KEY_FILE_PATH, "w").close()
    load_api_keys()

def verify_api_key(api_key_from_header: str) -> bool:
    """Return True when auth is disabled (no keys) or the key is known."""
    return not API_KEYS or api_key_from_header in API_KEYS
api_utils/dependencies.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI 依赖项模块
3
+ """
4
+ import logging
5
+ from asyncio import Queue, Lock, Event
6
+ from typing import Dict, Any, List, Set
7
+
8
+ from fastapi import Request
9
+
10
+ def get_logger() -> logging.Logger:
11
+ from server import logger
12
+ return logger
13
+
14
+ def get_log_ws_manager():
15
+ from server import log_ws_manager
16
+ return log_ws_manager
17
+
18
+ def get_request_queue() -> Queue:
19
+ from server import request_queue
20
+ return request_queue
21
+
22
+ def get_processing_lock() -> Lock:
23
+ from server import processing_lock
24
+ return processing_lock
25
+
26
+ def get_worker_task():
27
+ from server import worker_task
28
+ return worker_task
29
+
30
+ def get_server_state() -> Dict[str, Any]:
31
+ from server import is_initializing, is_playwright_ready, is_browser_connected, is_page_ready
32
+ return {
33
+ "is_initializing": is_initializing,
34
+ "is_playwright_ready": is_playwright_ready,
35
+ "is_browser_connected": is_browser_connected,
36
+ "is_page_ready": is_page_ready,
37
+ }
38
+
39
+ def get_page_instance():
40
+ from server import page_instance
41
+ return page_instance
42
+
43
+ def get_model_list_fetch_event() -> Event:
44
+ from server import model_list_fetch_event
45
+ return model_list_fetch_event
46
+
47
+ def get_parsed_model_list() -> List[Dict[str, Any]]:
48
+ from server import parsed_model_list
49
+ return parsed_model_list
50
+
51
+ def get_excluded_model_ids() -> Set[str]:
52
+ from server import excluded_model_ids
53
+ return excluded_model_ids
54
+
55
+ def get_current_ai_studio_model_id() -> str:
56
+ from server import current_ai_studio_model_id
57
+ return current_ai_studio_model_id
api_utils/queue_worker.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 队列工作器模块
3
+ 处理请求队列中的任务
4
+ """
5
+
6
+ import asyncio
7
+ import time
8
+ from fastapi import HTTPException
9
+
10
+
11
+
12
+ async def queue_worker():
13
+ """队列工作器,处理请求队列中的任务"""
14
+ # 导入全局变量
15
+ from server import (
16
+ logger, request_queue, processing_lock, model_switching_lock,
17
+ params_cache_lock
18
+ )
19
+
20
+ logger.info("--- 队列 Worker 已启动 ---")
21
+
22
+ # 检查并初始化全局变量
23
+ if request_queue is None:
24
+ logger.info("初始化 request_queue...")
25
+ from asyncio import Queue
26
+ request_queue = Queue()
27
+
28
+ if processing_lock is None:
29
+ logger.info("初始化 processing_lock...")
30
+ from asyncio import Lock
31
+ processing_lock = Lock()
32
+
33
+ if model_switching_lock is None:
34
+ logger.info("初始化 model_switching_lock...")
35
+ from asyncio import Lock
36
+ model_switching_lock = Lock()
37
+
38
+ if params_cache_lock is None:
39
+ logger.info("初始化 params_cache_lock...")
40
+ from asyncio import Lock
41
+ params_cache_lock = Lock()
42
+
43
+ was_last_request_streaming = False
44
+ last_request_completion_time = 0
45
+
46
+ while True:
47
+ request_item = None
48
+ result_future = None
49
+ req_id = "UNKNOWN"
50
+ completion_event = None
51
+
52
+ try:
53
+ # 检查队列中的项目,清理已断开连接的请求
54
+ queue_size = request_queue.qsize()
55
+ if queue_size > 0:
56
+ checked_count = 0
57
+ items_to_requeue = []
58
+ processed_ids = set()
59
+
60
+ while checked_count < queue_size and checked_count < 10:
61
+ try:
62
+ item = request_queue.get_nowait()
63
+ item_req_id = item.get("req_id", "unknown")
64
+
65
+ if item_req_id in processed_ids:
66
+ items_to_requeue.append(item)
67
+ continue
68
+
69
+ processed_ids.add(item_req_id)
70
+
71
+ if not item.get("cancelled", False):
72
+ item_http_request = item.get("http_request")
73
+ if item_http_request:
74
+ try:
75
+ if await item_http_request.is_disconnected():
76
+ logger.info(f"[{item_req_id}] (Worker Queue Check) 检测到客户端已断开,标记为取消。")
77
+ item["cancelled"] = True
78
+ item_future = item.get("result_future")
79
+ if item_future and not item_future.done():
80
+ item_future.set_exception(HTTPException(status_code=499, detail=f"[{item_req_id}] Client disconnected while queued."))
81
+ except Exception as check_err:
82
+ logger.error(f"[{item_req_id}] (Worker Queue Check) Error checking disconnect: {check_err}")
83
+
84
+ items_to_requeue.append(item)
85
+ checked_count += 1
86
+ except asyncio.QueueEmpty:
87
+ break
88
+
89
+ for item in items_to_requeue:
90
+ await request_queue.put(item)
91
+
92
+ # 获取下一个请求
93
+ try:
94
+ request_item = await asyncio.wait_for(request_queue.get(), timeout=5.0)
95
+ except asyncio.TimeoutError:
96
+ # 如果5秒内没有新请求,继续循环检查
97
+ continue
98
+
99
+ req_id = request_item["req_id"]
100
+ request_data = request_item["request_data"]
101
+ http_request = request_item["http_request"]
102
+ result_future = request_item["result_future"]
103
+
104
+ if request_item.get("cancelled", False):
105
+ logger.info(f"[{req_id}] (Worker) 请求已取消,跳过。")
106
+ if not result_future.done():
107
+ result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] 请求已被用户取消"))
108
+ request_queue.task_done()
109
+ continue
110
+
111
+ is_streaming_request = request_data.stream
112
+ logger.info(f"[{req_id}] (Worker) 取出请求。模式: {'流式' if is_streaming_request else '非流式'}")
113
+
114
+ # 流式请求间隔控制
115
+ current_time = time.time()
116
+ if was_last_request_streaming and is_streaming_request and (current_time - last_request_completion_time < 1.0):
117
+ delay_time = max(0.5, 1.0 - (current_time - last_request_completion_time))
118
+ logger.info(f"[{req_id}] (Worker) 连续流式请求,添加 {delay_time:.2f}s 延迟...")
119
+ await asyncio.sleep(delay_time)
120
+
121
+ if await http_request.is_disconnected():
122
+ logger.info(f"[{req_id}] (Worker) 客户端在等待锁时断开。取消。")
123
+ if not result_future.done():
124
+ result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] 客户端关闭了请求"))
125
+ request_queue.task_done()
126
+ continue
127
+
128
+ logger.info(f"[{req_id}] (Worker) 等待处理锁...")
129
+ async with processing_lock:
130
+ logger.info(f"[{req_id}] (Worker) 已获取处理锁。开始核心处理...")
131
+
132
+ if await http_request.is_disconnected():
133
+ logger.info(f"[{req_id}] (Worker) 客户端在获取锁后断开。取消。")
134
+ if not result_future.done():
135
+ result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] 客户端关闭了请求"))
136
+ elif result_future.done():
137
+ logger.info(f"[{req_id}] (Worker) Future 在处理前已完成/取消。跳过。")
138
+ else:
139
+ # 调用实际的请求处理函数
140
+ try:
141
+ from api_utils import _process_request_refactored
142
+ returned_value = await _process_request_refactored(
143
+ req_id, request_data, http_request, result_future
144
+ )
145
+
146
+ completion_event, submit_btn_loc, client_disco_checker = None, None, None
147
+ current_request_was_streaming = False
148
+
149
+ if isinstance(returned_value, tuple) and len(returned_value) == 3:
150
+ completion_event, submit_btn_loc, client_disco_checker = returned_value
151
+ if completion_event is not None:
152
+ current_request_was_streaming = True
153
+ logger.info(f"[{req_id}] (Worker) _process_request_refactored returned stream info (event, locator, checker).")
154
+ else:
155
+ current_request_was_streaming = False
156
+ logger.info(f"[{req_id}] (Worker) _process_request_refactored returned a tuple, but completion_event is None (likely non-stream or early exit).")
157
+ elif returned_value is None:
158
+ current_request_was_streaming = False
159
+ logger.info(f"[{req_id}] (Worker) _process_request_refactored returned non-stream completion (None).")
160
+ else:
161
+ current_request_was_streaming = False
162
+ logger.warning(f"[{req_id}] (Worker) _process_request_refactored returned unexpected type: {type(returned_value)}")
163
+
164
+ # 关键修复:在锁内等待流式完成(与原始参考文件一致)
165
+ if completion_event:
166
+ logger.info(f"[{req_id}] (Worker) 等待流式生成器完成信号...")
167
+ try:
168
+ from server import RESPONSE_COMPLETION_TIMEOUT
169
+ await asyncio.wait_for(completion_event.wait(), timeout=RESPONSE_COMPLETION_TIMEOUT/1000 + 60)
170
+ logger.info(f"[{req_id}] (Worker) ✅ 流式生成器完成信号收到。")
171
+
172
+ # 等待发送按钮禁用确认流式响应完全结束
173
+ if submit_btn_loc and client_disco_checker:
174
+ logger.info(f"[{req_id}] (Worker) 流式响应完成,检查并处理发送按钮状态...")
175
+ wait_timeout_ms = 30000 # 30 seconds
176
+ try:
177
+ from playwright.async_api import expect as expect_async
178
+ from api_utils.request_processor import ClientDisconnectedError
179
+
180
+ # 检查客户端连接状态
181
+ client_disco_checker("流式响应后按钮状态检查 - 前置检查: ")
182
+ await asyncio.sleep(0.5) # 给UI一点时间更新
183
+
184
+ # 检查按钮是否仍然启用,如果启用则直接点击停止
185
+ logger.info(f"[{req_id}] (Worker) 检查发送按钮状态...")
186
+ try:
187
+ is_button_enabled = await submit_btn_loc.is_enabled(timeout=2000)
188
+ logger.info(f"[{req_id}] (Worker) 发��按钮启用状态: {is_button_enabled}")
189
+
190
+ if is_button_enabled:
191
+ # 流式响应完成后按钮仍启用,直接点击停止
192
+ logger.info(f"[{req_id}] (Worker) 流式响应完成但按钮仍启用,主动点击按钮停止生成...")
193
+ await submit_btn_loc.click(timeout=5000, force=True)
194
+ logger.info(f"[{req_id}] (Worker) ✅ 发送按钮点击完成。")
195
+ else:
196
+ logger.info(f"[{req_id}] (Worker) 发送按钮已禁用,无需点击。")
197
+ except Exception as button_check_err:
198
+ logger.warning(f"[{req_id}] (Worker) 检查按钮状态失败: {button_check_err}")
199
+
200
+ # 等待按钮最终禁用
201
+ logger.info(f"[{req_id}] (Worker) 等待发送按钮最终禁用...")
202
+ await expect_async(submit_btn_loc).to_be_disabled(timeout=wait_timeout_ms)
203
+ logger.info(f"[{req_id}] ✅ 发送按钮已禁用。")
204
+
205
+ except Exception as e_pw_disabled:
206
+ logger.warning(f"[{req_id}] ⚠️ 流式响应后按钮状态处理超时或错误: {e_pw_disabled}")
207
+ from api_utils.request_processor import save_error_snapshot
208
+ await save_error_snapshot(f"stream_post_submit_button_handling_timeout_{req_id}")
209
+ except ClientDisconnectedError:
210
+ logger.info(f"[{req_id}] 客户端在流式响应后按钮状态处理时断开连接。")
211
+ elif current_request_was_streaming:
212
+ logger.warning(f"[{req_id}] (Worker) 流式请求但 submit_btn_loc 或 client_disco_checker 未提供。跳过按钮禁用等待。")
213
+
214
+ except asyncio.TimeoutError:
215
+ logger.warning(f"[{req_id}] (Worker) ⚠️ 等待流式生成器完成信号超时。")
216
+ if not result_future.done():
217
+ result_future.set_exception(HTTPException(status_code=504, detail=f"[{req_id}] Stream generation timed out waiting for completion signal."))
218
+ except Exception as ev_wait_err:
219
+ logger.error(f"[{req_id}] (Worker) ❌ 等待流式完成事件时出错: {ev_wait_err}")
220
+ if not result_future.done():
221
+ result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Error waiting for stream completion: {ev_wait_err}"))
222
+
223
+ except Exception as process_err:
224
+ logger.error(f"[{req_id}] (Worker) _process_request_refactored execution error: {process_err}")
225
+ if not result_future.done():
226
+ result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Request processing error: {process_err}"))
227
+
228
+ logger.info(f"[{req_id}] (Worker) 释放处理锁。")
229
+
230
+ # 在释放处理锁后立即执行清空操作
231
+ try:
232
+ # 清空流式队列缓存
233
+ from api_utils import clear_stream_queue
234
+ await clear_stream_queue()
235
+
236
+ # 清空聊天历史(对于所有模式:流式和非流式)
237
+ if submit_btn_loc and client_disco_checker:
238
+ from server import page_instance, is_page_ready
239
+ if page_instance and is_page_ready:
240
+ from browser_utils.page_controller import PageController
241
+ page_controller = PageController(page_instance, logger, req_id)
242
+ logger.info(f"[{req_id}] (Worker) 执行聊天历史清空({'流式' if completion_event else '非流式'}模式)...")
243
+ await page_controller.clear_chat_history(client_disco_checker)
244
+ logger.info(f"[{req_id}] (Worker) ✅ 聊天历史清空完成。")
245
+ else:
246
+ logger.info(f"[{req_id}] (Worker) 跳过聊天历史清空:缺少必要参数(submit_btn_loc: {bool(submit_btn_loc)}, client_disco_checker: {bool(client_disco_checker)})")
247
+ except Exception as clear_err:
248
+ logger.error(f"[{req_id}] (Worker) 清空操作时发生错误: {clear_err}", exc_info=True)
249
+
250
+ was_last_request_streaming = is_streaming_request
251
+ last_request_completion_time = time.time()
252
+
253
+ except asyncio.CancelledError:
254
+ logger.info("--- 队列 Worker 被取消 ---")
255
+ if result_future and not result_future.done():
256
+ result_future.cancel("Worker cancelled")
257
+ break
258
+ except Exception as e:
259
+ logger.error(f"[{req_id}] (Worker) ❌ 处理请求时发生意外错误: {e}", exc_info=True)
260
+ if result_future and not result_future.done():
261
+ result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] 服务器内部错误: {e}"))
262
+ finally:
263
+ if request_item:
264
+ request_queue.task_done()
265
+
266
+ logger.info("--- 队列 Worker 已停止 ---")
api_utils/request_processor.py ADDED
@@ -0,0 +1,795 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 请求处理器模块
3
+ 包含核心的请求处理逻辑
4
+ """
5
+
6
+ import asyncio
7
+ import json
8
+ import os
9
+ import random
10
+ import time
11
+ from typing import Optional, Tuple, Callable, AsyncGenerator
12
+ from asyncio import Event, Future
13
+
14
+ from fastapi import HTTPException, Request
15
+ from fastapi.responses import JSONResponse, StreamingResponse
16
+ from playwright.async_api import Page as AsyncPage, Locator, Error as PlaywrightAsyncError, expect as expect_async
17
+
18
+ # --- 配置模块导入 ---
19
+ from config import *
20
+
21
+ # --- models模块导入 ---
22
+ from models import ChatCompletionRequest, ClientDisconnectedError
23
+
24
+ # --- browser_utils模块导入 ---
25
+ from browser_utils import (
26
+ switch_ai_studio_model,
27
+ save_error_snapshot
28
+ )
29
+
30
+ # --- api_utils模块导入 ---
31
+ from .utils import (
32
+ validate_chat_request,
33
+ prepare_combined_prompt,
34
+ generate_sse_chunk,
35
+ generate_sse_stop_chunk,
36
+ use_stream_response,
37
+ calculate_usage_stats
38
+ )
39
+ from browser_utils.page_controller import PageController
40
+
41
+
42
async def _initialize_request_context(req_id: str, request: ChatCompletionRequest) -> dict:
    """Snapshot the relevant server-module globals into a per-request
    context dict and log the incoming request parameters."""
    from server import (
        logger, page_instance, is_page_ready, parsed_model_list,
        current_ai_studio_model_id, model_switching_lock, page_params_cache,
        params_cache_lock
    )

    logger.info(f"[{req_id}] 开始处理请求...")
    logger.info(f"[{req_id}] 请求参数 - Model: {request.model}, Stream: {request.stream}")

    # Shared state plus per-request bookkeeping fields, filled in by the
    # later processing stages.
    return {
        'logger': logger,
        'page': page_instance,
        'is_page_ready': is_page_ready,
        'parsed_model_list': parsed_model_list,
        'current_ai_studio_model_id': current_ai_studio_model_id,
        'model_switching_lock': model_switching_lock,
        'page_params_cache': page_params_cache,
        'params_cache_lock': params_cache_lock,
        'is_streaming': request.stream,
        'model_actually_switched': False,
        'requested_model': request.model,
        'model_id_to_use': None,
        'needs_model_switching': False,
    }
70
+
71
+
72
async def _analyze_model_requirements(req_id: str, context: dict, request: ChatCompletionRequest) -> dict:
    """Resolve the requested model id, validate it against the parsed model
    list, and flag whether a model switch is required.

    Raises:
        HTTPException: 400 when the requested model is not in the list.
    """
    logger = context['logger']
    active_model_id = context['current_ai_studio_model_id']
    known_models = context['parsed_model_list']
    requested_model = request.model

    # Default model (or no model) requested: nothing to resolve.
    if not requested_model or requested_model == MODEL_NAME:
        return context

    requested_model_id = requested_model.split('/')[-1]
    logger.info(f"[{req_id}] 请求使用模型: {requested_model_id}")

    if known_models:
        valid_model_ids = [m.get("id") for m in known_models]
        if requested_model_id not in valid_model_ids:
            raise HTTPException(
                status_code=400,
                detail=f"[{req_id}] Invalid model '{requested_model_id}'. Available models: {', '.join(valid_model_ids)}"
            )

    context['model_id_to_use'] = requested_model_id
    if active_model_id != requested_model_id:
        context['needs_model_switching'] = True
        logger.info(f"[{req_id}] 需要切换模型: 当前={active_model_id} -> 目标={requested_model_id}")

    return context
97
+
98
+
99
async def _setup_disconnect_monitoring(req_id: str, http_request: Request, result_future: Future) -> Tuple[Event, asyncio.Task, Callable]:
    """Start a background watcher for client disconnects.

    Polls ``http_request.is_disconnected()`` once per second.  On disconnect it
    sets the shared event and fails ``result_future`` with HTTP 499 so the
    caller stops working on an abandoned request.

    Returns:
        (event, watcher_task, check_fn) where ``check_fn(stage)`` raises
        ClientDisconnectedError when the event has been set.
    """
    from server import logger

    client_disconnected_event = Event()

    async def check_disconnect_periodically():
        # Runs until the event is set (by us or by an external party) or the
        # task is cancelled during request cleanup.
        while not client_disconnected_event.is_set():
            try:
                if await http_request.is_disconnected():
                    logger.info(f"[{req_id}] 客户端断开,设置事件。")
                    client_disconnected_event.set()
                    # 499 mirrors nginx's "client closed request" convention.
                    if not result_future.done():
                        result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] 客户端关闭了请求"))
                    break
                await asyncio.sleep(1.0)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # A broken checker must not leave the request hanging: set the
                # event and fail the future before exiting.
                logger.error(f"[{req_id}] (Disco Check Task) 错误: {e}")
                client_disconnected_event.set()
                if not result_future.done():
                    result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Internal disconnect checker error: {e}"))
                break

    disconnect_check_task = asyncio.create_task(check_disconnect_periodically())

    def check_client_disconnected(stage: str = ""):
        """Raise ClientDisconnectedError if the client went away; else return False."""
        if client_disconnected_event.is_set():
            logger.info(f"[{req_id}] 在 '{stage}' 检测到客户端断开连接。")
            raise ClientDisconnectedError(f"[{req_id}] Client disconnected at stage: {stage}")
        return False

    return client_disconnected_event, disconnect_check_task, check_client_disconnected
133
+
134
+
135
async def _validate_page_status(req_id: str, context: dict, check_client_disconnected: Callable) -> None:
    """Fail fast with HTTP 503 unless the AI Studio page is present, open and ready."""
    page = context['page']
    page_usable = bool(page) and not page.is_closed() and context['is_page_ready']
    if not page_usable:
        raise HTTPException(
            status_code=503,
            detail=f"[{req_id}] AI Studio 页面丢失或未就绪。",
            headers={"Retry-After": "30"},
        )
    # Bail out early if the client has already gone away.
    check_client_disconnected("Initial Page Check")
144
+
145
+
146
async def _handle_model_switching(req_id: str, context: dict, check_client_disconnected: Callable) -> dict:
    """Switch the AI Studio page to the requested model when required.

    Serializes switches through ``model_switching_lock``.  On success the
    global ``server.current_ai_studio_model_id`` and the request context are
    updated; on failure ``_handle_model_switch_failure`` raises HTTP 422.

    Returns:
        The (possibly updated) request context.
    """
    if not context['needs_model_switching']:
        return context

    logger = context['logger']
    page = context['page']
    model_switching_lock = context['model_switching_lock']
    model_id_to_use = context['model_id_to_use']

    import server

    async with model_switching_lock:
        if server.current_ai_studio_model_id != model_id_to_use:
            logger.info(f"[{req_id}] 准备切换模型: {server.current_ai_studio_model_id} -> {model_id_to_use}")
            switch_success = await switch_ai_studio_model(page, model_id_to_use, req_id)
            if switch_success:
                server.current_ai_studio_model_id = model_id_to_use
                context['model_actually_switched'] = True
                context['current_ai_studio_model_id'] = model_id_to_use
                logger.info(f"[{req_id}] ✅ 模型切换成功: {server.current_ai_studio_model_id}")
            else:
                await _handle_model_switch_failure(req_id, page, model_id_to_use, server.current_ai_studio_model_id, logger)
        else:
            # Fix: another request switched to the target model while we waited
            # for the lock.  Sync the context so the parameter-cache step
            # compares against the real active model, not a stale snapshot
            # captured at context initialization.
            context['current_ai_studio_model_id'] = server.current_ai_studio_model_id

    return context
171
+
172
+
173
async def _handle_model_switch_failure(req_id: str, page: AsyncPage, model_id_to_use: str, model_before_switch: str, logger) -> None:
    """Roll back global model state after a failed switch and signal the caller.

    Raises:
        HTTPException: always, with status 422.
    """
    import server

    logger.warning(f"[{req_id}] ❌ 模型切换至 {model_id_to_use} 失败。")
    # Restore the shared state to the model that was active before the attempt.
    server.current_ai_studio_model_id = model_before_switch
    raise HTTPException(
        status_code=422,
        detail=f"[{req_id}] 未能切换到模型 '{model_id_to_use}'。请确保模型可用。",
    )
185
+
186
+
187
+ async def _handle_parameter_cache(req_id: str, context: dict) -> None:
188
+ """处理参数缓存"""
189
+ logger = context['logger']
190
+ params_cache_lock = context['params_cache_lock']
191
+ page_params_cache = context['page_params_cache']
192
+ current_ai_studio_model_id = context['current_ai_studio_model_id']
193
+ model_actually_switched = context['model_actually_switched']
194
+
195
+ async with params_cache_lock:
196
+ cached_model_for_params = page_params_cache.get("last_known_model_id_for_params")
197
+
198
+ if model_actually_switched or (current_ai_studio_model_id != cached_model_for_params):
199
+ logger.info(f"[{req_id}] 模型已更改,参数缓存失效。")
200
+ page_params_cache.clear()
201
+ page_params_cache["last_known_model_id_for_params"] = current_ai_studio_model_id
202
+
203
+
204
async def _prepare_and_validate_request(req_id: str, request: ChatCompletionRequest, check_client_disconnected: Callable) -> str:
    """Validate the chat messages and flatten them into one combined prompt.

    Raises:
        HTTPException: 400 when message validation fails.
        ClientDisconnectedError: if the client dropped during preparation.
    """
    try:
        validate_chat_request(request.messages, req_id)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=f"[{req_id}] 无效请求: {err}")

    combined_prompt = prepare_combined_prompt(request.messages, req_id)
    # Prompt assembly can take a moment; re-check the client afterwards.
    check_client_disconnected("After Prompt Prep")
    return combined_prompt
215
+
216
async def _handle_response_processing(req_id: str, request: ChatCompletionRequest, page: AsyncPage,
                                    context: dict, result_future: Future,
                                    submit_button_locator: Locator, check_client_disconnected: Callable) -> Optional[Tuple[Event, Locator, Callable]]:
    """Dispatch response handling to the auxiliary-stream or Playwright path.

    The auxiliary (helper) stream is used unless the STREAM_PORT environment
    variable is exactly '0'; note an *unset* variable also selects the
    auxiliary stream.

    Returns:
        Whatever the chosen handler returns: a (completion_event, locator,
        check_fn) tuple for streaming responses, or None for non-streaming.
    """
    # STREAM_PORT == '0' disables the helper stream; any other value, including
    # unset (None != '0'), enables it.
    use_stream = os.environ.get('STREAM_PORT') != '0'

    if use_stream:
        return await _handle_auxiliary_stream_response(req_id, request, context, result_future, submit_button_locator, check_client_disconnected)
    return await _handle_playwright_response(req_id, request, page, context, result_future, submit_button_locator, check_client_disconnected)
233
+
234
+
235
async def _handle_auxiliary_stream_response(req_id: str, request: ChatCompletionRequest, context: dict,
                                          result_future: Future, submit_button_locator: Locator,
                                          check_client_disconnected: Callable) -> Optional[Tuple[Event, Locator, Callable]]:
    """Produce the response from the auxiliary helper stream.

    Streaming requests get a StreamingResponse backed by an async generator
    that converts helper-stream frames into OpenAI-style SSE chunks;
    non-streaming requests drain the helper stream and build one JSON payload.

    Returns:
        (completion_event, submit_button_locator, check_client_disconnected)
        for streaming requests; None for non-streaming (the future is already
        resolved in that case).
    """
    from server import logger

    is_streaming = request.stream
    current_ai_studio_model_id = context.get('current_ai_studio_model_id')

    def generate_random_string(length):
        """Return a random lowercase-alphanumeric string (used for tool-call ids)."""
        charset = "abcdefghijklmnopqrstuvwxyz0123456789"
        return ''.join(random.choice(charset) for _ in range(length))

    if is_streaming:
        try:
            completion_event = Event()

            async def create_stream_generator_from_helper(event_to_set: Event) -> AsyncGenerator[str, None]:
                # Positions already emitted; each helper frame carries the
                # cumulative text, so only the tail past these offsets is new.
                last_reason_pos = 0
                last_body_pos = 0
                model_name_for_stream = current_ai_studio_model_id or MODEL_NAME
                chat_completion_id = f"{CHAT_COMPLETION_ID_PREFIX}{req_id}-{int(time.time())}-{random.randint(100, 999)}"
                created_timestamp = int(time.time())

                # Collect the complete content so usage stats can be computed at the end.
                full_reasoning_content = ""
                full_body_content = ""

                try:
                    async for raw_data in use_stream_response(req_id):
                        # Stop generating as soon as the client disconnects.
                        try:
                            check_client_disconnected(f"流式生成器循环 ({req_id}): ")
                        except ClientDisconnectedError:
                            logger.info(f"[{req_id}] 客户端断开连接,终止流式生成")
                            break

                        # Normalize the frame into a dict (helper may send JSON strings).
                        if isinstance(raw_data, str):
                            try:
                                data = json.loads(raw_data)
                            except json.JSONDecodeError:
                                logger.warning(f"[{req_id}] 无法解析流数据JSON: {raw_data}")
                                continue
                        elif isinstance(raw_data, dict):
                            data = raw_data
                        else:
                            logger.warning(f"[{req_id}] 未知的流数据类型: {type(raw_data)}")
                            continue

                        # Defensive: json.loads may yield a non-dict (list, scalar).
                        if not isinstance(data, dict):
                            logger.warning(f"[{req_id}] 数据不是字典类型: {data}")
                            continue

                        reason = data.get("reason", "")
                        body = data.get("body", "")
                        done = data.get("done", False)
                        function = data.get("function", [])

                        # Keep running copies of the cumulative content.
                        if reason:
                            full_reasoning_content = reason
                        if body:
                            full_body_content = body

                        # Emit newly arrived reasoning content.
                        if len(reason) > last_reason_pos:
                            output = {
                                "id": chat_completion_id,
                                "object": "chat.completion.chunk",
                                "model": model_name_for_stream,
                                "created": created_timestamp,
                                "choices":[{
                                    "index": 0,
                                    "delta":{
                                        "role": "assistant",
                                        "content": None,
                                        "reasoning_content": reason[last_reason_pos:],
                                    },
                                    "finish_reason": None,
                                    "native_finish_reason": None,
                                }]
                            }
                            last_reason_pos = len(reason)
                            yield f"data: {json.dumps(output, ensure_ascii=False, separators=(',', ':'))}\n\n"

                        # Emit newly arrived body content.
                        if len(body) > last_body_pos:
                            finish_reason_val = None
                            if done:
                                finish_reason_val = "stop"

                            delta_content = {"role": "assistant", "content": body[last_body_pos:]}
                            choice_item = {
                                "index": 0,
                                "delta": delta_content,
                                "finish_reason": finish_reason_val,
                                "native_finish_reason": finish_reason_val,
                            }

                            # On the final frame, attach any tool calls and switch
                            # the finish reason to "tool_calls" per OpenAI convention.
                            if done and function and len(function) > 0:
                                tool_calls_list = []
                                for func_idx, function_call_data in enumerate(function):
                                    tool_calls_list.append({
                                        "id": f"call_{generate_random_string(24)}",
                                        "index": func_idx,
                                        "type": "function",
                                        "function": {
                                            "name": function_call_data["name"],
                                            "arguments": json.dumps(function_call_data["params"]),
                                        },
                                    })
                                delta_content["tool_calls"] = tool_calls_list
                                choice_item["finish_reason"] = "tool_calls"
                                choice_item["native_finish_reason"] = "tool_calls"
                                delta_content["content"] = None

                            output = {
                                "id": chat_completion_id,
                                "object": "chat.completion.chunk",
                                "model": model_name_for_stream,
                                "created": created_timestamp,
                                "choices": [choice_item]
                            }
                            last_body_pos = len(body)
                            yield f"data: {json.dumps(output, ensure_ascii=False, separators=(',', ':'))}\n\n"

                        # done=True with no new body content: tool calls only, or a bare stop.
                        elif done:
                            # Tool calls present but no new body content.
                            if function and len(function) > 0:
                                delta_content = {"role": "assistant", "content": None}
                                tool_calls_list = []
                                for func_idx, function_call_data in enumerate(function):
                                    tool_calls_list.append({
                                        "id": f"call_{generate_random_string(24)}",
                                        "index": func_idx,
                                        "type": "function",
                                        "function": {
                                            "name": function_call_data["name"],
                                            "arguments": json.dumps(function_call_data["params"]),
                                        },
                                    })
                                delta_content["tool_calls"] = tool_calls_list
                                choice_item = {
                                    "index": 0,
                                    "delta": delta_content,
                                    "finish_reason": "tool_calls",
                                    "native_finish_reason": "tool_calls",
                                }
                            else:
                                # Bare stop: no new content and no tool calls.
                                choice_item = {
                                    "index": 0,
                                    "delta": {"role": "assistant"},
                                    "finish_reason": "stop",
                                    "native_finish_reason": "stop",
                                }

                            output = {
                                "id": chat_completion_id,
                                "object": "chat.completion.chunk",
                                "model": model_name_for_stream,
                                "created": created_timestamp,
                                "choices": [choice_item]
                            }
                            yield f"data: {json.dumps(output, ensure_ascii=False, separators=(',', ':'))}\n\n"

                except ClientDisconnectedError:
                    logger.info(f"[{req_id}] 流式生成器中检测到客户端断开连接")
                except Exception as e:
                    logger.error(f"[{req_id}] 流式生成器处理过程中发生错误: {e}", exc_info=True)
                    # Surface the error to the client as a final content chunk.
                    try:
                        error_chunk = {
                            "id": chat_completion_id,
                            "object": "chat.completion.chunk",
                            "model": model_name_for_stream,
                            "created": created_timestamp,
                            "choices": [{
                                "index": 0,
                                "delta": {"role": "assistant", "content": f"\n\n[错误: {str(e)}]"},
                                "finish_reason": "stop",
                                "native_finish_reason": "stop",
                            }]
                        }
                        yield f"data: {json.dumps(error_chunk, ensure_ascii=False, separators=(',', ':'))}\n\n"
                    except Exception:
                        pass  # If the error chunk cannot be sent, continue with the shutdown logic.
                finally:
                    # Compute and send usage statistics as a trailing chunk.
                    try:
                        usage_stats = calculate_usage_stats(
                            [msg.model_dump() for msg in request.messages],
                            full_body_content,
                            full_reasoning_content
                        )
                        logger.info(f"[{req_id}] 计算的token使用统计: {usage_stats}")

                        # Final chunk carrying the usage block.
                        final_chunk = {
                            "id": chat_completion_id,
                            "object": "chat.completion.chunk",
                            "model": model_name_for_stream,
                            "created": created_timestamp,
                            "choices": [{
                                "index": 0,
                                "delta": {},
                                "finish_reason": "stop",
                                "native_finish_reason": "stop"
                            }],
                            "usage": usage_stats
                        }
                        yield f"data: {json.dumps(final_chunk, ensure_ascii=False, separators=(',', ':'))}\n\n"
                        logger.info(f"[{req_id}] 已发送带usage统计的最终chunk")

                    except Exception as usage_err:
                        logger.error(f"[{req_id}] 计算或发送usage统计时出错: {usage_err}")

                    # Always attempt to emit the SSE [DONE] terminator.
                    try:
                        logger.info(f"[{req_id}] 流式生成器完成,发送 [DONE] 标记")
                        yield "data: [DONE]\n\n"
                    except Exception as done_err:
                        logger.error(f"[{req_id}] 发送 [DONE] 标记时出错: {done_err}")

                    # Make sure the completion event is set so the worker can proceed.
                    if not event_to_set.is_set():
                        event_to_set.set()
                        logger.info(f"[{req_id}] 流式生成器完成事件已设置")

            stream_gen_func = create_stream_generator_from_helper(completion_event)
            if not result_future.done():
                result_future.set_result(StreamingResponse(stream_gen_func, media_type="text/event-stream"))
            else:
                # The future was already resolved (e.g. client disconnect);
                # release the waiter immediately.
                if not completion_event.is_set():
                    completion_event.set()

            return completion_event, submit_button_locator, check_client_disconnected

        except Exception as e:
            logger.error(f"[{req_id}] 从队列获取流式数据时出错: {e}", exc_info=True)
            if completion_event and not completion_event.is_set():
                completion_event.set()
            raise

    else:  # Non-streaming path.
        content = None
        reasoning_content = None
        functions = None
        final_data_from_aux_stream = None

        # Drain the helper stream until a frame marked done arrives.
        async for raw_data in use_stream_response(req_id):
            check_client_disconnected(f"非流式辅助流 - 循环中 ({req_id}): ")

            # Normalize the frame into a dict (helper may send JSON strings).
            if isinstance(raw_data, str):
                try:
                    data = json.loads(raw_data)
                except json.JSONDecodeError:
                    logger.warning(f"[{req_id}] 无法解析非流式数据JSON: {raw_data}")
                    continue
            elif isinstance(raw_data, dict):
                data = raw_data
            else:
                logger.warning(f"[{req_id}] 非流式未知数据类型: {type(raw_data)}")
                continue

            # Defensive: json.loads may yield a non-dict.
            if not isinstance(data, dict):
                logger.warning(f"[{req_id}] 非流式数据不是字典类型: {data}")
                continue

            final_data_from_aux_stream = data
            if data.get("done"):
                content = data.get("body")
                reasoning_content = data.get("reason")
                functions = data.get("function")
                break

        if final_data_from_aux_stream and final_data_from_aux_stream.get("reason") == "internal_timeout":
            logger.error(f"[{req_id}] 非流式请求通过辅助流失败: 内部超时")
            raise HTTPException(status_code=502, detail=f"[{req_id}] 辅助流处理错误 (内部超时)")

        if final_data_from_aux_stream and final_data_from_aux_stream.get("done") is True and content is None:
            logger.error(f"[{req_id}] 非流式请求通过辅助流完成但未提供内容")
            raise HTTPException(status_code=502, detail=f"[{req_id}] 辅助流完成但未提供内容")

        model_name_for_json = current_ai_studio_model_id or MODEL_NAME
        message_payload = {"role": "assistant", "content": content}
        finish_reason_val = "stop"

        # Attach tool calls, if any, and null out content per OpenAI convention.
        if functions and len(functions) > 0:
            tool_calls_list = []
            for func_idx, function_call_data in enumerate(functions):
                tool_calls_list.append({
                    "id": f"call_{generate_random_string(24)}",
                    "index": func_idx,
                    "type": "function",
                    "function": {
                        "name": function_call_data["name"],
                        "arguments": json.dumps(function_call_data["params"]),
                    },
                })
            message_payload["tool_calls"] = tool_calls_list
            finish_reason_val = "tool_calls"
            message_payload["content"] = None

        if reasoning_content:
            message_payload["reasoning_content"] = reasoning_content

        # Token usage statistics for the whole exchange.
        usage_stats = calculate_usage_stats(
            [msg.model_dump() for msg in request.messages],
            content or "",
            reasoning_content
        )

        response_payload = {
            "id": f"{CHAT_COMPLETION_ID_PREFIX}{req_id}-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_name_for_json,
            "choices": [{
                "index": 0,
                "message": message_payload,
                "finish_reason": finish_reason_val,
                "native_finish_reason": finish_reason_val,
            }],
            "usage": usage_stats
        }

        if not result_future.done():
            result_future.set_result(JSONResponse(content=response_payload))
        return None
571
+
572
+
573
async def _handle_playwright_response(req_id: str, request: ChatCompletionRequest, page: AsyncPage,
                                     context: dict, result_future: Future, submit_button_locator: Locator,
                                     check_client_disconnected: Callable) -> Optional[Tuple[Event, Locator, Callable]]:
    """Produce the response by scraping the AI Studio page with Playwright.

    Used when the auxiliary helper stream is disabled.  Streaming requests get
    a simulated SSE stream that re-chunks the final scraped content;
    non-streaming requests return one JSON payload.

    Returns:
        (completion_event, submit_button_locator, check_client_disconnected)
        for streaming requests; None for non-streaming.
    """
    from server import logger

    is_streaming = request.stream
    current_ai_studio_model_id = context.get('current_ai_studio_model_id')

    logger.info(f"[{req_id}] 定位响应元素...")
    # The newest response container on the page holds the reply being generated.
    response_container = page.locator(RESPONSE_CONTAINER_SELECTOR).last
    response_element = response_container.locator(RESPONSE_TEXT_SELECTOR)

    try:
        # Container should appear quickly; the inner text element may take much
        # longer while the model generates, hence the larger timeout.
        await expect_async(response_container).to_be_attached(timeout=20000)
        check_client_disconnected("After Response Container Attached: ")
        await expect_async(response_element).to_be_attached(timeout=90000)
        logger.info(f"[{req_id}] 响应元素已定位。")
    except (PlaywrightAsyncError, asyncio.TimeoutError, ClientDisconnectedError) as locate_err:
        # Client disconnects propagate unchanged; locator failures become 502.
        if isinstance(locate_err, ClientDisconnectedError):
            raise
        logger.error(f"[{req_id}] ❌ 错误: 定位响应元素失败或超时: {locate_err}")
        await save_error_snapshot(f"response_locate_error_{req_id}")
        raise HTTPException(status_code=502, detail=f"[{req_id}] 定位AI Studio响应元素失败: {locate_err}")
    except Exception as locate_exc:
        logger.exception(f"[{req_id}] ❌ 错误: 定位响应元素时意外错误")
        await save_error_snapshot(f"response_locate_unexpected_{req_id}")
        raise HTTPException(status_code=500, detail=f"[{req_id}] 定位响应元素时意外错误: {locate_exc}")

    check_client_disconnected("After Response Element Located: ")

    if is_streaming:
        completion_event = Event()

        async def create_response_stream_generator():
            """Yield the scraped response as small SSE chunks to simulate streaming."""
            try:
                # Fetch the complete response via the PageController.
                page_controller = PageController(page, logger, req_id)
                final_content = await page_controller.get_response(check_client_disconnected)

                # Emit the content line by line so newlines and Markdown
                # structure are preserved.
                lines = final_content.split('\n')
                for line_idx, line in enumerate(lines):
                    # Abort as soon as the client disconnects.
                    try:
                        check_client_disconnected(f"Playwright流式生成器循环 ({req_id}): ")
                    except ClientDisconnectedError:
                        logger.info(f"[{req_id}] Playwright流式生成器中检测到客户端断开连接")
                        break

                    # Emit the line content (empty lines still produce a newline below).
                    if line:  # Non-empty lines are emitted in small character chunks.
                        chunk_size = 5  # 5 chars per chunk balances speed and smoothness.
                        for i in range(0, len(line), chunk_size):
                            chunk = line[i:i+chunk_size]
                            yield generate_sse_chunk(chunk, req_id, current_ai_studio_model_id or MODEL_NAME)
                            await asyncio.sleep(0.03)  # Moderate pacing between chunks.

                    # Re-insert the newline between lines (not after the last one).
                    if line_idx < len(lines) - 1:
                        yield generate_sse_chunk('\n', req_id, current_ai_studio_model_id or MODEL_NAME)
                        await asyncio.sleep(0.01)

                # Compute usage stats for the trailing stop chunk.
                usage_stats = calculate_usage_stats(
                    [msg.model_dump() for msg in request.messages],
                    final_content,
                    ""  # The Playwright path has no reasoning content.
                )
                logger.info(f"[{req_id}] Playwright非流式计算的token使用统计: {usage_stats}")

                # Final stop chunk carrying the usage block.
                yield generate_sse_stop_chunk(req_id, current_ai_studio_model_id or MODEL_NAME, "stop", usage_stats)

            except ClientDisconnectedError:
                logger.info(f"[{req_id}] Playwright流式生成器中检测到客户端断开连接")
            except Exception as e:
                logger.error(f"[{req_id}] Playwright流式生成器处理过程中发生错误: {e}", exc_info=True)
                # Surface the error to the client before closing the stream.
                try:
                    yield generate_sse_chunk(f"\n\n[错误: {str(e)}]", req_id, current_ai_studio_model_id or MODEL_NAME)
                    yield generate_sse_stop_chunk(req_id, current_ai_studio_model_id or MODEL_NAME)
                except Exception:
                    pass  # If the error chunk cannot be sent, continue with the shutdown logic.
            finally:
                # Make sure the completion event is set so the worker can proceed.
                if not completion_event.is_set():
                    completion_event.set()
                logger.info(f"[{req_id}] Playwright流式生成器完成事件已设置")

        stream_gen_func = create_response_stream_generator()
        if not result_future.done():
            result_future.set_result(StreamingResponse(stream_gen_func, media_type="text/event-stream"))

        return completion_event, submit_button_locator, check_client_disconnected
    else:
        # Fetch the complete response via the PageController.
        page_controller = PageController(page, logger, req_id)
        final_content = await page_controller.get_response(check_client_disconnected)

        # Token usage statistics for the whole exchange.
        usage_stats = calculate_usage_stats(
            [msg.model_dump() for msg in request.messages],
            final_content,
            ""  # The Playwright path has no reasoning content.
        )
        logger.info(f"[{req_id}] Playwright非流式计算的token使用统计: {usage_stats}")

        response_payload = {
            "id": f"{CHAT_COMPLETION_ID_PREFIX}{req_id}-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": current_ai_studio_model_id or MODEL_NAME,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": final_content},
                "finish_reason": "stop"
            }],
            "usage": usage_stats
        }

        if not result_future.done():
            result_future.set_result(JSONResponse(content=response_payload))

        return None
699
+
700
+
701
async def _cleanup_request_resources(req_id: str, disconnect_check_task: Optional[asyncio.Task],
                                   completion_event: Optional[Event], result_future: Future,
                                   is_streaming: bool) -> None:
    """Tear down the disconnect watcher and unblock a failed streaming request."""
    from server import logger

    if disconnect_check_task and not disconnect_check_task.done():
        disconnect_check_task.cancel()
        try:
            await disconnect_check_task
        except asyncio.CancelledError:
            pass
        except Exception as cancel_err:
            logger.error(f"[{req_id}] 清理任务时出错: {cancel_err}")

    logger.info(f"[{req_id}] 处理完成。")

    # A streaming request that ended in an exception may never have set its
    # completion event; set it here so the queue worker is not left waiting.
    if is_streaming and completion_event and not completion_event.is_set():
        if result_future.done() and result_future.exception() is not None:
            logger.warning(f"[{req_id}] 流式请求异常,确保完成事件已设置。")
            completion_event.set()
721
+
722
+
723
async def _process_request_refactored(
    req_id: str,
    request: ChatCompletionRequest,
    http_request: Request,
    result_future: Future
) -> Optional[Tuple[Event, Locator, Callable[[str], bool]]]:
    """Core request processor (refactored version).

    Orchestrates the whole pipeline: context setup, model analysis and
    switching, parameter cache handling, prompt preparation/submission and
    response processing.  The outcome is delivered through ``result_future``;
    all expected exception types are converted into HTTPExceptions set on the
    future rather than propagated.

    Returns:
        (completion_event, submit_button_locator, check_client_disconnected)
        for streaming responses, or None/implicit None when the future has
        already been resolved or an error was recorded on it.
    """

    context = await _initialize_request_context(req_id, request)
    context = await _analyze_model_requirements(req_id, context, request)

    client_disconnected_event, disconnect_check_task, check_client_disconnected = await _setup_disconnect_monitoring(
        req_id, http_request, result_future
    )

    page = context['page']
    submit_button_locator = page.locator(SUBMIT_BUTTON_SELECTOR) if page else None
    completion_event = None

    try:
        await _validate_page_status(req_id, context, check_client_disconnected)

        page_controller = PageController(page, context['logger'], req_id)

        await _handle_model_switching(req_id, context, check_client_disconnected)
        await _handle_parameter_cache(req_id, context)

        prepared_prompt = await _prepare_and_validate_request(req_id, request, check_client_disconnected)

        # Page interaction is delegated to the PageController.
        # NOTE: clearing the chat history was moved to run after the queue
        # processing lock is released.

        await page_controller.adjust_parameters(
            request.model_dump(exclude_none=True),  # exclude_none=True avoids passing None values
            context['page_params_cache'],
            context['params_cache_lock'],
            context['model_id_to_use'],
            context['parsed_model_list'],
            check_client_disconnected
        )

        await page_controller.submit_prompt(prepared_prompt, check_client_disconnected)

        # Response handling must stay here: it decides streaming vs
        # non-streaming and resolves the result future accordingly.
        response_result = await _handle_response_processing(
            req_id, request, page, context, result_future, submit_button_locator, check_client_disconnected
        )

        if response_result:
            completion_event, _, _ = response_result

        return completion_event, submit_button_locator, check_client_disconnected

    except ClientDisconnectedError as disco_err:
        context['logger'].info(f"[{req_id}] 捕获到客户端断开连接信号: {disco_err}")
        if not result_future.done():
            result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] Client disconnected during processing."))
    except HTTPException as http_err:
        context['logger'].warning(f"[{req_id}] 捕获到 HTTP 异常: {http_err.status_code} - {http_err.detail}")
        if not result_future.done():
            result_future.set_exception(http_err)
    except PlaywrightAsyncError as pw_err:
        context['logger'].error(f"[{req_id}] 捕获到 Playwright 错误: {pw_err}")
        await save_error_snapshot(f"process_playwright_error_{req_id}")
        if not result_future.done():
            result_future.set_exception(HTTPException(status_code=502, detail=f"[{req_id}] Playwright interaction failed: {pw_err}"))
    except Exception as e:
        context['logger'].exception(f"[{req_id}] 捕获到意外错误")
        await save_error_snapshot(f"process_unexpected_error_{req_id}")
        if not result_future.done():
            result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Unexpected server error: {e}"))
    finally:
        # Always stop the disconnect watcher and unblock any waiting stream.
        await _cleanup_request_resources(req_id, disconnect_check_task, completion_event, result_future, request.stream)
api_utils/request_processor_backup.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 请求处理器模块
3
+ 包含核心的请求处理逻辑
4
+ """
5
+
6
+ import asyncio
7
+ import json
8
+ import os
9
+ import random
10
+ import time
11
+ from typing import Optional, Tuple, Callable, AsyncGenerator
12
+ from asyncio import Event, Future
13
+
14
+ from fastapi import HTTPException, Request
15
+ from fastapi.responses import JSONResponse, StreamingResponse
16
+ from playwright.async_api import Page as AsyncPage, Locator, Error as PlaywrightAsyncError, expect as expect_async, TimeoutError
17
+
18
+ # --- 配置模块导入 ---
19
+ from config import *
20
+
21
+ # --- models模块导入 ---
22
+ from models import ChatCompletionRequest, ClientDisconnectedError
23
+
24
+ # --- browser_utils模块导入 ---
25
+ from browser_utils import (
26
+ switch_ai_studio_model,
27
+ save_error_snapshot,
28
+ _wait_for_response_completion,
29
+ _get_final_response_content,
30
+ detect_and_extract_page_error
31
+ )
32
+
33
+ # --- api_utils模块导入 ---
34
+ from .utils import (
35
+ validate_chat_request,
36
+ prepare_combined_prompt,
37
+ generate_sse_chunk,
38
+ generate_sse_stop_chunk,
39
+ generate_sse_error_chunk,
40
+ use_helper_get_response,
41
+ use_stream_response
42
+ )
43
+
44
+
45
+ async def _process_request_refactored(
46
+ req_id: str,
47
+ request: ChatCompletionRequest,
48
+ http_request: Request,
49
+ result_future: Future
50
+ ) -> Optional[Tuple[Event, Locator, Callable[[str], bool]]]:
51
+ """核心请求处理函数 - 完整版本"""
52
+ global current_ai_studio_model_id
53
+
54
+ # 导入全局变量
55
+ from server import (
56
+ logger, page_instance, is_page_ready, parsed_model_list,
57
+ current_ai_studio_model_id, model_switching_lock, page_params_cache,
58
+ params_cache_lock
59
+ )
60
+
61
+ model_actually_switched_in_current_api_call = False
62
+ logger.info(f"[{req_id}] (Refactored Process) 开始处理请求...")
63
+ logger.info(f"[{req_id}] 请求参数 - Model: {request.model}, Stream: {request.stream}")
64
+ logger.info(f"[{req_id}] 请求参数 - Temperature: {request.temperature}")
65
+ logger.info(f"[{req_id}] 请求参数 - Max Output Tokens: {request.max_output_tokens}")
66
+ logger.info(f"[{req_id}] 请求参数 - Stop Sequences: {request.stop}")
67
+ logger.info(f"[{req_id}] 请求参数 - Top P: {request.top_p}")
68
+
69
+ is_streaming = request.stream
70
+ page: Optional[AsyncPage] = page_instance
71
+ completion_event: Optional[Event] = None
72
+ requested_model = request.model
73
+ model_id_to_use = None
74
+ needs_model_switching = False
75
+
76
+ if requested_model and requested_model != MODEL_NAME:
77
+ requested_model_parts = requested_model.split('/')
78
+ requested_model_id = requested_model_parts[-1] if len(requested_model_parts) > 1 else requested_model
79
+ logger.info(f"[{req_id}] 请求使用模型: {requested_model_id}")
80
+ if parsed_model_list:
81
+ valid_model_ids = [m.get("id") for m in parsed_model_list]
82
+ if requested_model_id not in valid_model_ids:
83
+ logger.error(f"[{req_id}] ❌ 无效的模型ID: {requested_model_id}。可用模型: {valid_model_ids}")
84
+ raise HTTPException(status_code=400, detail=f"[{req_id}] Invalid model '{requested_model_id}'. Available models: {', '.join(valid_model_ids)}")
85
+ model_id_to_use = requested_model_id
86
+ if current_ai_studio_model_id != model_id_to_use:
87
+ needs_model_switching = True
88
+ logger.info(f"[{req_id}] 需要切换模型: 当前={current_ai_studio_model_id} -> 目标={model_id_to_use}")
89
+ else:
90
+ logger.info(f"[{req_id}] 请求模型与当前模型相同 ({model_id_to_use}),无需切换")
91
+ else:
92
+ logger.info(f"[{req_id}] 未指定具体模型或使用代理模型名称,将使用当前模型: {current_ai_studio_model_id or '未知'}")
93
+
94
+ client_disconnected_event = Event()
95
+ disconnect_check_task = None
96
+ input_field_locator = page.locator(INPUT_SELECTOR) if page else None
97
+ submit_button_locator = page.locator(SUBMIT_BUTTON_SELECTOR) if page else None
98
+
99
+ async def check_disconnect_periodically():
100
+ while not client_disconnected_event.is_set():
101
+ try:
102
+ if await http_request.is_disconnected():
103
+ logger.info(f"[{req_id}] (Disco Check Task) 客户端断开。设置事件并尝试停止。")
104
+ client_disconnected_event.set()
105
+ try:
106
+ if submit_button_locator and await submit_button_locator.is_enabled(timeout=1500):
107
+ if input_field_locator and await input_field_locator.input_value(timeout=1500) == '':
108
+ logger.info(f"[{req_id}] (Disco Check Task) 点击停止...")
109
+ await submit_button_locator.click(timeout=3000, force=True)
110
+ except Exception as click_err:
111
+ logger.warning(f"[{req_id}] (Disco Check Task) 停止按钮点击失败: {click_err}")
112
+ if not result_future.done():
113
+ result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] 客户端在处理期间关闭了请求"))
114
+ break
115
+ await asyncio.sleep(1.0)
116
+ except asyncio.CancelledError:
117
+ break
118
+ except Exception as e:
119
+ logger.error(f"[{req_id}] (Disco Check Task) 错误: {e}")
120
+ client_disconnected_event.set()
121
+ if not result_future.done():
122
+ result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Internal disconnect checker error: {e}"))
123
+ break
124
+
125
+ disconnect_check_task = asyncio.create_task(check_disconnect_periodically())
126
+
127
+ def check_client_disconnected(*args):
128
+ msg_to_log = ""
129
+ if len(args) == 1 and isinstance(args[0], str):
130
+ msg_to_log = args[0]
131
+
132
+ if client_disconnected_event.is_set():
133
+ logger.info(f"[{req_id}] {msg_to_log}检测到客户端断开连接事件。")
134
+ raise ClientDisconnectedError(f"[{req_id}] Client disconnected event set.")
135
+ return False
136
+
137
+ try:
138
+ if not page or page.is_closed() or not is_page_ready:
139
+ raise HTTPException(status_code=503, detail=f"[{req_id}] AI Studio 页面丢失或未就绪。", headers={"Retry-After": "30"})
140
+
141
+ check_client_disconnected("Initial Page Check: ")
142
+
143
+ # 模型切换逻辑
144
+ if needs_model_switching and model_id_to_use:
145
+ async with model_switching_lock:
146
+ model_before_switch_attempt = current_ai_studio_model_id
147
+ if current_ai_studio_model_id != model_id_to_use:
148
+ logger.info(f"[{req_id}] 获取锁后准备切换: 当前内存中模型={current_ai_studio_model_id}, 目标={model_id_to_use}")
149
+ switch_success = await switch_ai_studio_model(page, model_id_to_use, req_id)
150
+ if switch_success:
151
+ current_ai_studio_model_id = model_id_to_use
152
+ model_actually_switched_in_current_api_call = True
153
+ logger.info(f"[{req_id}] ✅ 模型切换成功。全局模型状态已更新为: {current_ai_studio_model_id}")
154
+ else:
155
+ logger.warning(f"[{req_id}] ❌ 模型切换至 {model_id_to_use} 失败 (AI Studio 未接受或覆盖了更改)。")
156
+ active_model_id_after_fail = model_before_switch_attempt
157
+ try:
158
+ final_prefs_str_after_fail = await page.evaluate("() => localStorage.getItem('aiStudioUserPreference')")
159
+ if final_prefs_str_after_fail:
160
+ final_prefs_obj_after_fail = json.loads(final_prefs_str_after_fail)
161
+ model_path_in_final_prefs = final_prefs_obj_after_fail.get("promptModel")
162
+ if model_path_in_final_prefs and isinstance(model_path_in_final_prefs, str):
163
+ active_model_id_after_fail = model_path_in_final_prefs.split('/')[-1]
164
+ except Exception as read_final_prefs_err:
165
+ logger.error(f"[{req_id}] 切换失败后读取最终 localStorage 出错: {read_final_prefs_err}")
166
+ current_ai_studio_model_id = active_model_id_after_fail
167
+ logger.info(f"[{req_id}] 全局模型状态在切换失败后设置为 (或保持为): {current_ai_studio_model_id}")
168
+ actual_displayed_model_name = "未知 (无法读取)"
169
+ try:
170
+ model_wrapper_locator = page.locator('#mat-select-value-0 mat-select-trigger').first
171
+ actual_displayed_model_name = await model_wrapper_locator.inner_text(timeout=3000)
172
+ except Exception:
173
+ pass
174
+ raise HTTPException(
175
+ status_code=422,
176
+ detail=f"[{req_id}] AI Studio 未能应用所请求的模型 '{model_id_to_use}' 或该模型不受支持。请选择 AI Studio 网页界面中可用的模型。当前实际生效的模型 ID 为 '{current_ai_studio_model_id}', 页面显示为 '{actual_displayed_model_name}'."
177
+ )
178
+ else:
179
+ logger.info(f"[{req_id}] 获取锁后发现模型已是目标模型 {current_ai_studio_model_id},无需切换")
180
+
181
+ # 参数缓存处理
182
+ async with params_cache_lock:
183
+ cached_model_for_params = page_params_cache.get("last_known_model_id_for_params")
184
+ if model_actually_switched_in_current_api_call or \
185
+ (current_ai_studio_model_id is not None and current_ai_studio_model_id != cached_model_for_params):
186
+ action_taken = "Invalidating" if page_params_cache else "Initializing"
187
+ logger.info(f"[{req_id}] {action_taken} parameter cache. Reason: Model context changed (switched this call: {model_actually_switched_in_current_api_call}, current model: {current_ai_studio_model_id}, cache model: {cached_model_for_params}).")
188
+ page_params_cache.clear()
189
+ if current_ai_studio_model_id:
190
+ page_params_cache["last_known_model_id_for_params"] = current_ai_studio_model_id
191
+ else:
192
+ logger.debug(f"[{req_id}] Parameter cache for model '{cached_model_for_params}' remains valid (current model: '{current_ai_studio_model_id}', switched this call: {model_actually_switched_in_current_api_call}).")
193
+
194
+ # 验证请求
195
+ try:
196
+ validate_chat_request(request.messages, req_id)
197
+ except ValueError as e:
198
+ raise HTTPException(status_code=400, detail=f"[{req_id}] 无效请求: {e}")
199
+
200
+ # 准备提示
201
+ prepared_prompt = prepare_combined_prompt(request.messages, req_id)
202
+ check_client_disconnected("After Prompt Prep: ")
203
+
204
+ # 这里需要添加完整的处理逻辑 - 由于函数太长,暂时返回简化响应
205
+ logger.info(f"[{req_id}] (Refactored Process) 处理完整逻辑 - 需要从备份恢复剩余部分")
206
+
207
+ # 简单响应用于测试
208
+ if is_streaming:
209
+ completion_event = Event()
210
+
211
+ async def create_simple_stream_generator():
212
+ try:
213
+ yield generate_sse_chunk("正在处理请求...", req_id, MODEL_NAME)
214
+ await asyncio.sleep(1)
215
+ yield generate_sse_chunk("处理完成", req_id, MODEL_NAME)
216
+ yield generate_sse_stop_chunk(req_id, MODEL_NAME)
217
+ yield "data: [DONE]\n\n"
218
+ finally:
219
+ if not completion_event.is_set():
220
+ completion_event.set()
221
+
222
+ if not result_future.done():
223
+ result_future.set_result(StreamingResponse(create_simple_stream_generator(), media_type="text/event-stream"))
224
+
225
+ return completion_event, submit_button_locator, check_client_disconnected
226
+ else:
227
+ response_payload = {
228
+ "id": f"{CHAT_COMPLETION_ID_PREFIX}{req_id}-{int(time.time())}",
229
+ "object": "chat.completion",
230
+ "created": int(time.time()),
231
+ "model": MODEL_NAME,
232
+ "choices": [{
233
+ "index": 0,
234
+ "message": {"role": "assistant", "content": "处理完成 - 需要完整逻辑"},
235
+ "finish_reason": "stop"
236
+ }],
237
+ "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
238
+ }
239
+
240
+ if not result_future.done():
241
+ result_future.set_result(JSONResponse(content=response_payload))
242
+
243
+ return None
244
+
245
+ except ClientDisconnectedError as disco_err:
246
+ logger.info(f"[{req_id}] (Refactored Process) 捕获到客户端断开连接信号: {disco_err}")
247
+ if not result_future.done():
248
+ result_future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] Client disconnected during processing."))
249
+ except HTTPException as http_err:
250
+ logger.warning(f"[{req_id}] (Refactored Process) 捕获到 HTTP 异常: {http_err.status_code} - {http_err.detail}")
251
+ if not result_future.done():
252
+ result_future.set_exception(http_err)
253
+ except Exception as e:
254
+ logger.exception(f"[{req_id}] (Refactored Process) 捕获到意外错误")
255
+ await save_error_snapshot(f"process_unexpected_error_{req_id}")
256
+ if not result_future.done():
257
+ result_future.set_exception(HTTPException(status_code=500, detail=f"[{req_id}] Unexpected server error: {e}"))
258
+ finally:
259
+ if disconnect_check_task and not disconnect_check_task.done():
260
+ disconnect_check_task.cancel()
261
+ try:
262
+ await disconnect_check_task
263
+ except asyncio.CancelledError:
264
+ pass
265
+ except Exception as task_clean_err:
266
+ logger.error(f"[{req_id}] 清理任务时出错: {task_clean_err}")
267
+
268
+ logger.info(f"[{req_id}] (Refactored Process) 处理完成。")
269
+
270
+ if is_streaming and completion_event and not completion_event.is_set() and (result_future.done() and result_future.exception() is not None):
271
+ logger.warning(f"[{req_id}] (Refactored Process) 流式请求异常,确保完成事件已设置。")
272
+ completion_event.set()
273
+
274
+ return completion_event, submit_button_locator, check_client_disconnected
api_utils/routes.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI路由处理器模块
3
+ 包含所有API端点的处理函数
4
+ """
5
+
6
+ import asyncio
7
+ import os
8
+ import random
9
+ import time
10
+ import uuid
11
+ from typing import Dict, List, Any, Set
12
+ from asyncio import Queue, Future, Lock, Event
13
+ import logging
14
+
15
+ from fastapi import HTTPException, Request, WebSocket, WebSocketDisconnect, Depends
16
+ from fastapi.responses import JSONResponse, FileResponse
17
+ from pydantic import BaseModel
18
+ from playwright.async_api import Page as AsyncPage
19
+
20
+ # --- 配置模块导入 ---
21
+ from config import *
22
+
23
+ # --- models模块导入 ---
24
+ from models import ChatCompletionRequest, WebSocketConnectionManager
25
+
26
+ # --- browser_utils模块导入 ---
27
+ from browser_utils import _handle_model_list_response
28
+
29
+ # --- 依赖项导入 ---
30
+ from .dependencies import *
31
+
32
+
33
+ # --- 静态文件端点 ---
34
async def read_index(logger: logging.Logger = Depends(get_logger)):
    """Serve the web UI entry page (index.html at the project root)."""
    page_path = os.path.join(os.path.dirname(__file__), "..", "index.html")
    if os.path.exists(page_path):
        return FileResponse(page_path)
    logger.error(f"index.html not found at {page_path}")
    raise HTTPException(status_code=404, detail="index.html not found")
41
+
42
+
43
async def get_css(logger: logging.Logger = Depends(get_logger)):
    """Serve the web UI stylesheet (webui.css at the project root)."""
    stylesheet = os.path.join(os.path.dirname(__file__), "..", "webui.css")
    if os.path.exists(stylesheet):
        return FileResponse(stylesheet, media_type="text/css")
    logger.error(f"webui.css not found at {stylesheet}")
    raise HTTPException(status_code=404, detail="webui.css not found")
50
+
51
+
52
async def get_js(logger: logging.Logger = Depends(get_logger)):
    """Serve the web UI script (webui.js at the project root)."""
    script_path = os.path.join(os.path.dirname(__file__), "..", "webui.js")
    if os.path.exists(script_path):
        return FileResponse(script_path, media_type="application/javascript")
    logger.error(f"webui.js not found at {script_path}")
    raise HTTPException(status_code=404, detail="webui.js not found")
59
+
60
+
61
+ # --- API信息端点 ---
62
async def get_api_info(request: Request, current_ai_studio_model_id: str = Depends(get_current_ai_studio_model_id)):
    """Report API connection info: base URLs, effective model and API-key requirements.

    The base URL is reconstructed from the incoming request (Host header plus
    x-forwarded-proto) so it is correct behind a reverse proxy.
    """
    from api_utils import auth_utils

    port = request.url.port or os.environ.get('SERVER_PORT_INFO', '8000')
    host_header = request.headers.get('host') or f"127.0.0.1:{port}"
    scheme = request.headers.get('x-forwarded-proto', 'http')
    root_url = f"{scheme}://{host_header}"

    key_count = len(auth_utils.API_KEYS)
    keys_required = bool(auth_utils.API_KEYS)
    if keys_required:
        message = f"API Key is required. {key_count} valid key(s) configured."
    else:
        message = "API Key is not required."

    return JSONResponse(content={
        "model_name": current_ai_studio_model_id or MODEL_NAME,
        "api_base_url": f"{root_url}/v1",
        "server_base_url": root_url,
        "api_key_required": keys_required,
        "api_key_count": key_count,
        "auth_header": "Authorization: Bearer <token> or X-API-Key: <token>" if keys_required else None,
        "openai_compatible": True,
        "supported_auth_methods": ["Authorization: Bearer", "X-API-Key"] if keys_required else [],
        "message": message
    })
92
+
93
+
94
+ # --- 健康检查端点 ---
95
async def health_check(
    server_state: Dict[str, Any] = Depends(get_server_state),
    worker_task = Depends(get_worker_task),
    request_queue: Queue = Depends(get_request_queue)
):
    """Readiness probe: 200 when core services and the queue worker are up, else 503.

    Browser/page readiness is only required when the launch mode actually
    drives a browser (i.e. not "direct_debug_no_browser").
    """
    worker_alive = bool(worker_task and not worker_task.done())
    launch_mode = os.environ.get('LAUNCH_MODE', 'unknown')
    needs_browser = launch_mode != "direct_debug_no_browser"

    # Core readiness: initialization finished and Playwright started.
    ready_checks = [not server_state["is_initializing"], server_state["is_playwright_ready"]]
    if needs_browser:
        ready_checks.append(server_state["is_browser_connected"])
        ready_checks.append(server_state["is_page_ready"])

    healthy = all(ready_checks) and worker_alive
    queue_len = request_queue.qsize() if request_queue else -1

    # Collect every individual failure so the 503 message names them all.
    problems = []
    if server_state["is_initializing"]:
        problems.append("初始化进行中")
    if not server_state["is_playwright_ready"]:
        problems.append("Playwright 未就绪")
    if needs_browser:
        if not server_state["is_browser_connected"]:
            problems.append("浏览器未连接")
        if not server_state["is_page_ready"]:
            problems.append("页面未就绪")
    if not worker_alive:
        problems.append("Worker 未运行")

    payload = {
        "status": "OK" if healthy else "Error",
        "message": "",
        "details": {**server_state, "workerRunning": worker_alive, "queueLength": queue_len, "launchMode": launch_mode, "browserAndPageCritical": needs_browser}
    }

    if healthy:
        payload["message"] = f"服务运行中;队列长度: {queue_len}。"
        return JSONResponse(content=payload, status_code=200)

    payload["message"] = f"服务不可用;问题: {(', '.join(problems) or '未知原因')}. 队列长度: {queue_len}."
    return JSONResponse(content=payload, status_code=503)
133
+
134
+
135
+ # --- 模型列表端点 ---
136
async def list_models(
    logger: logging.Logger = Depends(get_logger),
    model_list_fetch_event: Event = Depends(get_model_list_fetch_event),
    page_instance: AsyncPage = Depends(get_page_instance),
    parsed_model_list: List[Dict[str, Any]] = Depends(get_parsed_model_list),
    excluded_model_ids: Set[str] = Depends(get_excluded_model_ids)
):
    """Return the OpenAI-style model list (`/v1/models`).

    If the model list has not been captured yet, reload the AI Studio page to
    trigger the network interception that populates it, waiting up to ~10s.
    Falls back to a single hard-coded model entry when nothing was parsed.
    """
    logger.info("[API] 收到 /v1/models 请求。")

    # Only attempt a refresh when the list was never fetched and the page is usable.
    if not model_list_fetch_event.is_set() and page_instance and not page_instance.is_closed():
        logger.info("/v1/models: 模型列表事件未设置,尝试刷新页面...")
        try:
            await page_instance.reload(wait_until="domcontentloaded", timeout=20000)
            await asyncio.wait_for(model_list_fetch_event.wait(), timeout=10.0)
        except Exception as e:
            logger.error(f"/v1/models: 刷新或等待模型列表时出错: {e}")
        finally:
            # Always set the event so later requests don't re-trigger reloads.
            if not model_list_fetch_event.is_set():
                model_list_fetch_event.set()

    if parsed_model_list:
        # Filter out models explicitly excluded by configuration.
        final_model_list = [m for m in parsed_model_list if m.get("id") not in excluded_model_ids]
        return {"object": "list", "data": final_model_list}
    else:
        logger.warning("模型列表为空,返回默认后备模型。")
        return {"object": "list", "data": [{
            "id": DEFAULT_FALLBACK_MODEL_ID, "object": "model", "created": int(time.time()),
            "owned_by": "camoufox-proxy-fallback"
        }]}
166
+
167
+
168
+ # --- 聊天完成端点 ---
169
async def chat_completions(
    request: ChatCompletionRequest,
    http_request: Request,
    logger: logging.Logger = Depends(get_logger),
    request_queue: Queue = Depends(get_request_queue),
    server_state: Dict[str, Any] = Depends(get_server_state),
    worker_task = Depends(get_worker_task)
):
    """Handle `/v1/chat/completions`: enqueue the request and await the worker's result.

    The request gets a short random id, is placed on the shared queue, and this
    coroutine blocks on a Future that the queue worker resolves with either a
    JSONResponse or a StreamingResponse. Timeouts/cancellation map to 504/499.
    """
    req_id = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=7))
    logger.info(f"[{req_id}] 收到 /v1/chat/completions 请求 (Stream={request.stream})")

    launch_mode = os.environ.get('LAUNCH_MODE', 'unknown')
    # Browser/page state only gates availability when a browser is being driven.
    browser_page_critical = launch_mode != "direct_debug_no_browser"

    service_unavailable = server_state["is_initializing"] or \
                          not server_state["is_playwright_ready"] or \
                          (browser_page_critical and (not server_state["is_page_ready"] or not server_state["is_browser_connected"])) or \
                          not worker_task or worker_task.done()

    if service_unavailable:
        raise HTTPException(status_code=503, detail=f"[{req_id}] 服务当前不可用。请稍后重试。", headers={"Retry-After": "30"})

    # The worker resolves this Future with the final response object.
    result_future = Future()
    await request_queue.put({
        "req_id": req_id, "request_data": request, "http_request": http_request,
        "result_future": result_future, "enqueue_time": time.time(), "cancelled": False
    })

    try:
        # Overall budget: the page-level completion timeout (ms -> s) plus slack
        # for queue wait time.
        timeout_seconds = RESPONSE_COMPLETION_TIMEOUT / 1000 + 120
        return await asyncio.wait_for(result_future, timeout=timeout_seconds)
    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail=f"[{req_id}] 请求处理超时。")
    except asyncio.CancelledError:
        raise HTTPException(status_code=499, detail=f"[{req_id}] 请求被客户端取消。")
    except Exception as e:
        logger.exception(f"[{req_id}] 等待Worker响应时出错")
        raise HTTPException(status_code=500, detail=f"[{req_id}] 服务器内部错误: {e}")
208
+
209
+
210
+ # --- 取消请求相关 ---
211
async def cancel_queued_request(req_id: str, request_queue: Queue, logger: logging.Logger) -> bool:
    """Mark a still-queued request as cancelled.

    Drains the queue, flags the matching entry (failing its Future with 499),
    and re-queues every drained entry. Returns True when the id was found.
    """
    drained = []
    matched = False
    try:
        while not request_queue.empty():
            entry = request_queue.get_nowait()
            if entry.get("req_id") == req_id:
                logger.info(f"[{req_id}] 在队列中找到请求,标记为已取消。")
                entry["cancelled"] = True
                future = entry.get("result_future")
                if future and not future.done():
                    future.set_exception(HTTPException(status_code=499, detail=f"[{req_id}] Request cancelled."))
                matched = True
            drained.append(entry)
    finally:
        # Restore everything we pulled off, preserving relative order.
        for entry in drained:
            await request_queue.put(entry)
    return matched
229
+
230
+
231
async def cancel_request(
    req_id: str,
    logger: logging.Logger = Depends(get_logger),
    request_queue: Queue = Depends(get_request_queue)
):
    """Endpoint: attempt to cancel a queued request by id (404 if not queued)."""
    logger.info(f"[{req_id}] 收到取消请求。")
    was_cancelled = await cancel_queued_request(req_id, request_queue, logger)
    if not was_cancelled:
        return JSONResponse(status_code=404, content={"success": False, "message": f"Request {req_id} not found in queue."})
    return JSONResponse(content={"success": True, "message": f"Request {req_id} marked as cancelled."})
242
+
243
+
244
+ # --- 队列状态端点 ---
245
async def get_queue_status(
    request_queue: Queue = Depends(get_request_queue),
    processing_lock: Lock = Depends(get_processing_lock)
):
    """Snapshot the request queue for monitoring.

    Returns the queue length, whether the processing lock is held, and a
    per-item summary (id, enqueue/wait time, streaming flag, cancelled flag)
    sorted by enqueue time.
    """
    # NOTE(review): reads asyncio.Queue's private ``_queue`` deque — there is
    # no public non-destructive snapshot API; this is read-only, so it stays.
    queue_items = list(request_queue._queue)
    # Single timestamp so all wait times come from one consistent snapshot.
    now = time.time()
    return JSONResponse(content={
        "queue_length": len(queue_items),
        "is_processing_locked": processing_lock.locked(),
        "items": sorted([
            {
                "req_id": item.get("req_id", "unknown"),
                "enqueue_time": item.get("enqueue_time", 0),
                "wait_time_seconds": round(now - item.get("enqueue_time", 0), 2),
                # getattr guards against a malformed entry with no request_data;
                # the original `item.get("request_data").stream` raised
                # AttributeError and broke the whole status endpoint.
                "is_streaming": getattr(item.get("request_data"), "stream", None),
                "cancelled": item.get("cancelled", False)
            } for item in queue_items
        ], key=lambda x: x.get("enqueue_time", 0))
    })
264
+
265
+
266
+ # --- WebSocket日志端点 ---
267
async def websocket_log_endpoint(
    websocket: WebSocket,
    logger: logging.Logger = Depends(get_logger),
    log_ws_manager: WebSocketConnectionManager = Depends(get_log_ws_manager)
):
    """WebSocket endpoint that streams server logs to the web UI.

    Each client gets a UUID, is registered with the connection manager (which
    pushes log lines to it), and is always deregistered on exit.
    """
    if not log_ws_manager:
        # 1011 = internal error: the manager was never initialized.
        await websocket.close(code=1011)
        return

    client_id = str(uuid.uuid4())
    try:
        await log_ws_manager.connect(client_id, websocket)
        # Reads are discarded; the loop only keeps the connection open and
        # lets WebSocketDisconnect surface when the client goes away.
        while True:
            await websocket.receive_text()  # Keep connection alive
    except WebSocketDisconnect:
        pass  # normal client disconnect — not an error
    except Exception as e:
        logger.error(f"日志 WebSocket (客户端 {client_id}) 发生异常: {e}", exc_info=True)
    finally:
        log_ws_manager.disconnect(client_id)
288
+
289
+
290
+ # --- API密钥管理数据模型 ---
291
class ApiKeyRequest(BaseModel):
    """Request body for the add/delete API-key management endpoints."""
    key: str
293
+
294
class ApiKeyTestRequest(BaseModel):
    """Request body for the API-key validity test endpoint."""
    key: str
296
+
297
+
298
+ # --- API密钥管理端点 ---
299
async def get_api_keys(logger: logging.Logger = Depends(get_logger)):
    """List the currently configured API keys (reloaded from disk first)."""
    from api_utils import auth_utils
    try:
        auth_utils.initialize_keys()
        keys_info = [{"value": key_value, "status": "有效"} for key_value in auth_utils.API_KEYS]
        return JSONResponse(content={"success": True, "keys": keys_info, "total_count": len(keys_info)})
    except Exception as exc:
        logger.error(f"获取API密钥列表失败: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
309
+
310
+
311
async def add_api_key(request: ApiKeyRequest, logger: logging.Logger = Depends(get_logger)):
    """Append a new API key to key.txt and reload the in-memory key set.

    Rejects keys shorter than 8 characters (400) and duplicates (400);
    unexpected file errors surface as 500.
    """
    from api_utils import auth_utils
    key_value = request.key.strip()
    if not key_value or len(key_value) < 8:
        raise HTTPException(status_code=400, detail="无效的API密钥格式。")

    # Refresh from disk before the duplicate check so it reflects reality.
    auth_utils.initialize_keys()
    if key_value in auth_utils.API_KEYS:
        raise HTTPException(status_code=400, detail="该API密钥已存在。")

    try:
        key_file_path = os.path.join(os.path.dirname(__file__), "..", "key.txt")
        # 'a+' creates the file if missing; writes always append in this mode.
        # The seek/read pair only decides whether a separating newline is needed
        # (i.e. whether the file already has content).
        with open(key_file_path, 'a+', encoding='utf-8') as f:
            f.seek(0)
            if f.read(): f.write("\n")
            f.write(key_value)

        # Reload so the new key takes effect immediately.
        auth_utils.initialize_keys()
        logger.info(f"API密钥已添加: {key_value[:4]}...{key_value[-4:]}")
        return JSONResponse(content={"success": True, "message": "API密钥添加成功", "key_count": len(auth_utils.API_KEYS)})
    except Exception as e:
        logger.error(f"添加API密钥失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
335
+
336
+
337
async def test_api_key(request: ApiKeyTestRequest, logger: logging.Logger = Depends(get_logger)):
    """Check whether a submitted API key is currently valid."""
    from api_utils import auth_utils
    candidate = request.key.strip()
    if not candidate:
        raise HTTPException(status_code=400, detail="API密钥不能为空。")

    auth_utils.initialize_keys()
    is_valid = auth_utils.verify_api_key(candidate)
    logger.info(f"API密钥测试: {candidate[:4]}...{candidate[-4:]} - {'有效' if is_valid else '无效'}")
    return JSONResponse(content={"success": True, "valid": is_valid, "message": "密钥有效" if is_valid else "密钥无效或不存在"})
348
+
349
+
350
async def delete_api_key(request: ApiKeyRequest, logger: logging.Logger = Depends(get_logger)):
    """Remove an API key from key.txt and reload the in-memory key set."""
    from api_utils import auth_utils
    target = request.key.strip()
    if not target:
        raise HTTPException(status_code=400, detail="API密钥不能为空。")

    auth_utils.initialize_keys()
    if target not in auth_utils.API_KEYS:
        raise HTTPException(status_code=404, detail="API密钥不存在。")

    try:
        key_file_path = os.path.join(os.path.dirname(__file__), "..", "key.txt")
        with open(key_file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Rewrite the file keeping every line that is not the target key.
        with open(key_file_path, 'w', encoding='utf-8') as f:
            f.writelines(line for line in lines if line.strip() != target)

        auth_utils.initialize_keys()
        logger.info(f"API密钥已删除: {target[:4]}...{target[-4:]}")
        return JSONResponse(content={"success": True, "message": "API密钥删除成功", "key_count": len(auth_utils.API_KEYS)})
    except Exception as e:
        logger.error(f"删除API密钥失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
api_utils/utils.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API工具函数模块
3
+ 包含SSE生成、流处理、token统计和请求验证等工具函数
4
+ """
5
+
6
+ import asyncio
7
+ import json
8
+ import time
9
+ import datetime
10
+ from typing import Any, Dict, List, Optional, AsyncGenerator
11
+ from asyncio import Queue
12
+ from models import Message
13
+
14
+
15
+
16
+ # --- SSE生成函数 ---
17
def generate_sse_chunk(delta: str, req_id: str, model: str) -> str:
    """Build one OpenAI-style SSE delta chunk carrying *delta* as content."""
    payload = {
        "id": f"chatcmpl-{req_id}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {"index": 0, "delta": {"content": delta}, "finish_reason": None}
        ],
    }
    return "data: " + json.dumps(payload) + "\n\n"
27
+
28
+
29
def generate_sse_stop_chunk(req_id: str, model: str, reason: str = "stop", usage: Optional[dict] = None) -> str:
    """Build the terminating SSE chunk (empty delta + finish_reason) plus [DONE].

    Args:
        req_id: Request id embedded in the chunk id.
        model: Model name reported in the chunk.
        reason: OpenAI finish_reason (defaults to "stop").
        usage: Optional token-usage dict; included in the chunk when given.
            (Annotation fixed: the default is None, so the type is Optional.)

    Returns:
        Two SSE events: the stop chunk followed by the "[DONE]" sentinel.
    """
    stop_chunk_data = {
        "id": f"chatcmpl-{req_id}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": reason}]
    }

    # Attach usage stats only when the caller supplied them.
    if usage:
        stop_chunk_data["usage"] = usage

    return f"data: {json.dumps(stop_chunk_data)}\n\ndata: [DONE]\n\n"
44
+
45
+
46
def generate_sse_error_chunk(message: str, req_id: str, error_type: str = "server_error") -> str:
    """Build an SSE event carrying an OpenAI-style error object."""
    body = {
        "error": {
            "message": message,
            "type": error_type,
            "param": None,
            "code": req_id,
        }
    }
    return "data: " + json.dumps(body) + "\n\n"
50
+
51
+
52
+ # --- 流处理工具函数 ---
53
async def use_stream_response(req_id: str) -> AsyncGenerator[Any, None]:
    """Yield items from the server's global STREAM_QUEUE for one request.

    Polls the queue non-blockingly every 100ms. Termination protocol (any of):
    a raw ``None`` sentinel, a dict or JSON string with ``"done": True``, or
    ~30s (300 empty polls) without data, in which case a synthetic
    ``internal_timeout`` done-item is yielded before returning.
    JSON strings are decoded before being yielded; non-JSON strings and all
    other objects are yielded as-is.
    """
    from server import STREAM_QUEUE, logger
    import queue

    if STREAM_QUEUE is None:
        logger.warning(f"[{req_id}] STREAM_QUEUE is None, 无法使用流响应")
        return

    logger.info(f"[{req_id}] 开始使用流响应")

    empty_count = 0
    max_empty_retries = 300  # 300 polls * 100ms sleep = ~30s timeout
    data_received = False    # whether any payload arrived (affects timeout log level)

    try:
        while True:
            try:
                # Non-blocking read; emptiness is handled in the except below.
                data = STREAM_QUEUE.get_nowait()
                if data is None:  # raw end-of-stream sentinel
                    logger.info(f"[{req_id}] 接收到流结束标志")
                    break

                # Any payload resets the empty-poll timeout counter.
                empty_count = 0
                data_received = True
                logger.debug(f"[{req_id}] 接收到流数据: {type(data)} - {str(data)[:200]}...")

                # Strings may be JSON-encoded payloads, possibly carrying the
                # {"done": true} completion flag.
                if isinstance(data, str):
                    try:
                        parsed_data = json.loads(data)
                        if parsed_data.get("done") is True:
                            logger.info(f"[{req_id}] 接收到JSON格式的完成标志")
                            yield parsed_data
                            break
                        else:
                            yield parsed_data
                    except json.JSONDecodeError:
                        # Not JSON: pass the raw string through unchanged.
                        logger.debug(f"[{req_id}] 返回非JSON字符串数据")
                        yield data
                else:
                    # Non-string payloads are forwarded as-is.
                    yield data

                # Dict payloads can also carry the completion flag; note the
                # done-item has already been yielded above before breaking.
                if isinstance(data, dict) and data.get("done") is True:
                    logger.info(f"[{req_id}] 接收到字典格式的完成标志")
                    break

            except (queue.Empty, asyncio.QueueEmpty):
                empty_count += 1
                if empty_count % 50 == 0:  # progress log every ~5s of waiting
                    logger.info(f"[{req_id}] 等待流数据... ({empty_count}/{max_empty_retries})")

                if empty_count >= max_empty_retries:
                    if not data_received:
                        logger.error(f"[{req_id}] 流响应队列空读取次数达到上限且未收到任何数据,可能是辅助流未启动或出错")
                    else:
                        logger.warning(f"[{req_id}] 流响应队列空读取次数达到上限 ({max_empty_retries}),结束读取")

                    # Yield a synthetic completion so consumers terminate
                    # cleanly instead of hanging on a missing done-marker.
                    yield {"done": True, "reason": "internal_timeout", "body": "", "function": []}
                    return

                await asyncio.sleep(0.1)  # 100ms poll interval
                continue

    except Exception as e:
        logger.error(f"[{req_id}] 使用流响应时出错: {e}")
        raise
    finally:
        logger.info(f"[{req_id}] 流响应使用完成,数据接收状态: {data_received}")
128
+
129
+
130
async def clear_stream_queue():
    """Drain and discard everything in the server's global STREAM_QUEUE.

    Used between requests so a new request never consumes a previous
    request's leftover stream data. No-op when the queue is disabled.
    """
    from server import STREAM_QUEUE, logger
    import queue

    if STREAM_QUEUE is None:
        logger.info("流队列未初始化或已被禁用,跳过清空操作。")
        return

    while True:
        try:
            # get_nowait via to_thread: STREAM_QUEUE appears to be a
            # thread-safe queue.Queue (it raises queue.Empty) rather than an
            # asyncio queue — TODO confirm against server.py.
            data_chunk = await asyncio.to_thread(STREAM_QUEUE.get_nowait)
        except queue.Empty:
            # Queue fully drained — normal exit.
            logger.info("流式队列已清空 (捕获到 queue.Empty)。")
            break
        except Exception as e:
            logger.error(f"清空流式队列时发生意外错误: {e}", exc_info=True)
            break
    logger.info("流式队列缓存清空完毕。")
150
+
151
+
152
+ # --- Helper response generator ---
153
async def use_helper_get_response(helper_endpoint: str, helper_sapisid: str) -> AsyncGenerator[str, None]:
    """Stream the response body from an external helper service as text chunks.

    Sends a GET with an optional SAPISID cookie and yields 1KB chunks decoded
    as UTF-8 (undecodable bytes are dropped via errors='ignore'). On a non-200
    status or any exception the error is logged and the generator simply ends
    — callers see an empty/truncated stream rather than an exception.
    """
    from server import logger
    import aiohttp

    logger.info(f"正在尝试使用Helper端点: {helper_endpoint}")

    try:
        async with aiohttp.ClientSession() as session:
            headers = {
                'Content-Type': 'application/json',
                # Empty Cookie header when no SAPISID was configured.
                'Cookie': f'SAPISID={helper_sapisid}' if helper_sapisid else ''
            }

            async with session.get(helper_endpoint, headers=headers) as response:
                if response.status == 200:
                    async for chunk in response.content.iter_chunked(1024):
                        if chunk:
                            yield chunk.decode('utf-8', errors='ignore')
                else:
                    logger.error(f"Helper端点返回错误状态: {response.status}")

    except Exception as e:
        # Best-effort by design: swallow and log, ending the stream quietly.
        logger.error(f"使用Helper端点时出错: {e}")
177
+
178
+
179
+ # --- 请求验证函数 ---
180
def validate_chat_request(messages: List[Message], req_id: str) -> Dict[str, Optional[str]]:
    """Validate an incoming chat request's message list.

    Args:
        messages: Parsed request messages; each must expose a ``role`` attribute.
        req_id: Short request id used to tag error messages.

    Returns:
        A dict with ``error``/``warning`` keys (both ``None`` on success),
        kept for backward compatibility with callers inspecting the result.

    Raises:
        ValueError: If ``messages`` is empty or contains only system messages.
    """
    # (Fixed: removed unused `from server import logger` — dead import that
    # also coupled this pure validator to the server module at call time.)
    if not messages:
        raise ValueError(f"[{req_id}] 无效请求: 'messages' 数组缺失或为空。")

    # At least one non-system (user/assistant/tool) message is required.
    if not any(msg.role != 'system' for msg in messages):
        raise ValueError(f"[{req_id}] 无效请求: 所有消息都是系统消息。至少需要一条用户或助手消息。")

    return {
        "error": None,
        "warning": None
    }
195
+
196
+
197
+ # --- 提示准备函数 ---
198
def prepare_combined_prompt(messages: List[Message], req_id: str) -> str:
    """Flatten an OpenAI-style message list into one text prompt for the AI Studio UI.

    The first non-empty system message becomes a leading "系统指令:" section
    (later system messages are skipped). Remaining messages are rendered as
    "<role>:\\n<content>" turns separated by "---" lines; assistant tool calls
    are rendered as human-readable "请求调用函数" blocks. Returns the combined
    prompt with a trailing newline, or "" when no message produced output.
    """
    from server import logger

    logger.info(f"[{req_id}] (准备提示) 正在从 {len(messages)} 条消息准备组合提示 (包括历史)。")

    combined_parts = []
    system_prompt_content: Optional[str] = None
    # Indices consumed by the system-prompt pass, skipped in the main pass.
    processed_system_message_indices = set()

    # Pass 1: use only the FIRST system message encountered (break below).
    for i, msg in enumerate(messages):
        if msg.role == 'system':
            content = msg.content
            if isinstance(content, str) and content.strip():
                system_prompt_content = content.strip()
                processed_system_message_indices.add(i)
                logger.info(f"[{req_id}] (准备提示) 在索引 {i} 找到并使用系统提示: '{system_prompt_content[:80]}...'")
                system_instr_prefix = "系统指令:\n"
                combined_parts.append(f"{system_instr_prefix}{system_prompt_content}")
            else:
                # Non-string or empty system message: consume it without output.
                logger.info(f"[{req_id}] (准备提示) 在索引 {i} 忽略非字符串或空的系统消息。")
                processed_system_message_indices.add(i)
            break

    role_map_ui = {"user": "用户", "assistant": "助手", "system": "系统", "tool": "工具"}
    turn_separator = "\n---\n"

    # Pass 2: render every remaining message as a turn.
    for i, msg in enumerate(messages):
        if i in processed_system_message_indices:
            continue

        if msg.role == 'system':
            # Only the first system message is honored (handled in pass 1).
            logger.info(f"[{req_id}] (准备提示) 跳过在索引 {i} 的后续系统消息。")
            continue

        if combined_parts:
            combined_parts.append(turn_separator)

        role = msg.role or 'unknown'
        role_prefix_ui = f"{role_map_ui.get(role, role.capitalize())}:\n"
        current_turn_parts = [role_prefix_ui]

        content = msg.content or ''
        content_str = ""

        if isinstance(content, str):
            content_str = content.strip()
        elif isinstance(content, list):
            # Multimodal content: keep only text parts (object- or dict-shaped).
            text_parts = []
            for item in content:
                if hasattr(item, 'type') and item.type == 'text':
                    text_parts.append(item.text or '')
                elif isinstance(item, dict) and item.get('type') == 'text':
                    text_parts.append(item.get('text', ''))
                else:
                    logger.warning(f"[{req_id}] (准备提示) 警告: 在索引 {i} 的消息中忽略非文本或未知类型的 content item")
            content_str = "\n".join(text_parts).strip()
        else:
            # Unexpected content type: best-effort stringification.
            logger.warning(f"[{req_id}] (准备提示) 警告: 角色 {role} 在索引 {i} 的内容类型意外 ({type(content)}) 或为 None。")
            content_str = str(content or "").strip()

        if content_str:
            current_turn_parts.append(content_str)

        # Render assistant tool calls as readable "call function X" blocks.
        tool_calls = msg.tool_calls
        if role == 'assistant' and tool_calls:
            if content_str:
                current_turn_parts.append("\n")

            tool_call_visualizations = []
            for tool_call in tool_calls:
                if hasattr(tool_call, 'type') and tool_call.type == 'function':
                    function_call = tool_call.function
                    func_name = function_call.name if function_call else None
                    func_args_str = function_call.arguments if function_call else None

                    try:
                        # Pretty-print valid JSON arguments; fall back to raw text.
                        parsed_args = json.loads(func_args_str if func_args_str else '{}')
                        formatted_args = json.dumps(parsed_args, indent=2, ensure_ascii=False)
                    except (json.JSONDecodeError, TypeError):
                        formatted_args = func_args_str if func_args_str is not None else "{}"

                    tool_call_visualizations.append(
                        f"请求调用函数: {func_name}\n参数:\n{formatted_args}"
                    )

            if tool_call_visualizations:
                current_turn_parts.append("\n".join(tool_call_visualizations))

        # Emit the turn only when it has content beyond the role prefix
        # (or is an assistant turn consisting solely of tool calls).
        if len(current_turn_parts) > 1 or (role == 'assistant' and tool_calls):
            combined_parts.append("".join(current_turn_parts))
        elif not combined_parts and not current_turn_parts:
            logger.info(f"[{req_id}] (准备提示) 跳过角色 {role} 在索引 {i} 的空消息 (且无工具调用)。")
        elif len(current_turn_parts) == 1 and not combined_parts:
            logger.info(f"[{req_id}] (准备提示) 跳过角色 {role} 在索引 {i} 的空消息 (只有前缀)。")

    final_prompt = "".join(combined_parts)
    if final_prompt:
        final_prompt += "\n"

    preview_text = final_prompt[:300].replace('\n', '\\n')
    logger.info(f"[{req_id}] (准备提示) 组合提示长度: {len(final_prompt)}。预览: '{preview_text}...'")

    return final_prompt
306
+
307
+
308
def estimate_tokens(text: str) -> int:
    """Estimate the token count of *text* with a character-class heuristic.

    Non-CJK text is counted at roughly 4 characters per token and CJK
    characters (ideographs, CJK punctuation, fullwidth forms) at roughly
    1.5 characters per token; the two estimates are summed. Any non-empty
    input yields at least 1 token; empty input yields 0.
    """
    if not text:
        return 0

    def _is_cjk(ch: str) -> bool:
        # CJK Unified Ideographs, CJK Symbols and Punctuation,
        # and Halfwidth/Fullwidth Forms.
        return (
            '\u4e00' <= ch <= '\u9fff'
            or '\u3000' <= ch <= '\u303f'
            or '\uff00' <= ch <= '\uffef'
        )

    cjk_count = sum(map(_is_cjk, text))
    other_count = len(text) - cjk_count

    # Weighted estimate: CJK ~1.5 chars/token, everything else ~4 chars/token.
    estimate = cjk_count / 1.5 + other_count / 4.0
    return max(1, int(estimate))
330
+
331
+
332
def calculate_usage_stats(messages: List[dict], response_content: str, reasoning_content: str = None) -> dict:
    """Compute estimated token-usage statistics for one chat exchange.

    Args:
        messages: Request message dicts; each is read via its "role" and
            "content" keys (missing keys default to "").
        response_content: The model's reply text (may be None/empty).
        reasoning_content: Optional reasoning text; when truthy it is
            counted toward the completion tokens.

    Returns:
        Dict with "prompt_tokens", "completion_tokens" and "total_tokens",
        all estimated via estimate_tokens.
    """
    # Prompt side: serialize each message as a "role: content" line.
    prompt_lines = [
        f"{msg.get('role', '')}: {msg.get('content', '')}\n" for msg in messages
    ]
    prompt_tokens = estimate_tokens("".join(prompt_lines))

    # Completion side: reply text plus any reasoning text appended to it.
    completion_text = response_content or ""
    if reasoning_content:
        completion_text += reasoning_content
    completion_tokens = estimate_tokens(completion_text)

    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
368
+
369
+
370
def generate_sse_stop_chunk_with_usage(req_id: str, model: str, usage_stats: dict, reason: str = "stop") -> str:
    """Build the terminating SSE chunk, attaching token-usage statistics.

    Thin convenience wrapper around generate_sse_stop_chunk; note that the
    delegate takes the finish reason before the usage payload, unlike this
    signature.
    """
    return generate_sse_stop_chunk(req_id, model, reason, usage_stats)