ZyphrZero commited on
Commit
d88d316
·
unverified ·
2 Parent(s): c4a7c2d 04753fc

Merge pull request #10 from ZyphrZero/dev

Browse files
.env.example CHANGED
@@ -28,6 +28,9 @@ AUTH_TOKENS_FILE=tokens.txt
28
  # 服务监听端口
29
  LISTEN_PORT=8080
30
 
 
 
 
31
  # 调试日志
32
  DEBUG_LOGGING=true
33
 
@@ -45,6 +48,6 @@ SCAN_LIMIT=200000
45
  # ========== 错误码400处理 ==========
46
 
47
  # 重试次数
48
- MAX_RETRIES=5
49
  # 初始重试延迟
50
  RETRY_DELAY=1
 
28
  # 服务监听端口
29
  LISTEN_PORT=8080
30
 
31
+ # 服务名称(用于进程唯一性验证)
32
+ SERVICE_NAME=z-ai2api-server
33
+
34
  # 调试日志
35
  DEBUG_LOGGING=true
36
 
 
48
  # ========== 错误码400处理 ==========
49
 
50
  # 重试次数
51
+ MAX_RETRIES=6
52
  # 初始重试延迟
53
  RETRY_DELAY=1
.gitignore CHANGED
@@ -14,6 +14,7 @@ main.onefile-build/
14
  *.yaml
15
  logs/
16
  backup/
 
17
 
18
  # AI Toolset
19
  .augment/
 
14
  *.yaml
15
  logs/
16
  backup/
17
+ uv.lock
18
 
19
  # AI Toolset
20
  .augment/
README.md CHANGED
@@ -23,6 +23,7 @@
23
  - 📊 **多模型映射** - 智能上游模型路由
24
  - 🔄 **Token 池管理** - 自动轮询、容错恢复、动态更新
25
  - 🛡️ **错误处理** - 完善的异常捕获和重试机制
 
26
 
27
  ## 🚀 快速开始
28
 
@@ -245,6 +246,13 @@ A: 这通常是因为 Token 获取失败导致的。请检查:
245
  - Token 文件是否存在且格式正确(`tokens.txt`)
246
  - 网络连接是否正常,能否访问 Z.AI API
247
 
 
 
 
 
 
 
 
248
  **Q: 如何通过 Claude Code 使用本服务?**
249
 
250
  A: 创建 [zai.js](https://gist.githubusercontent.com/musistudio/b35402d6f9c95c64269c7666b8405348/raw/f108d66fa050f308387938f149a2b14a295d29e9/gistfile1.txt) 这个 ccr 插件放在`./.claude-code-router/plugins`目录下,配置 `./.claude-code-router/config.json` 指向本服务地址,使用 `AUTH_TOKEN` 进行认证。
 
23
  - 📊 **多模型映射** - 智能上游模型路由
24
  - 🔄 **Token 池管理** - 自动轮询、容错恢复、动态更新
25
  - 🛡️ **错误处理** - 完善的异常捕获和重试机制
26
+ - 🔒 **服务唯一性** - 基于进程名称(pname)的服务唯一性验证,防止重复启动
27
 
28
  ## 🚀 快速开始
29
 
 
246
  - Token 文件是否存在且格式正确(`tokens.txt`)
247
  - 网络连接是否正常,能否访问 Z.AI API
248
 
249
+ **Q: 启动时提示"服务已在运行"怎么办?**
250
+ A: 这是服务唯一性验证功能,防止重复启动。解决方法:
251
+ - 检查是否已有服务实例在运行:`ps aux | grep z-ai2api-server`
252
+ - 停止现有实例后再启动新的
253
+ - 如果确认没有实例运行,删除 PID 文件:`rm z-ai2api-server.pid`
254
+ - 可通过环境变量 `SERVICE_NAME` 自定义服务名称避免冲突
255
+
256
  **Q: 如何通过 Claude Code 使用本服务?**
257
 
258
  A: 创建 [zai.js](https://gist.githubusercontent.com/musistudio/b35402d6f9c95c64269c7666b8405348/raw/f108d66fa050f308387938f149a2b14a295d29e9/gistfile1.txt) 这个 ccr 插件放在`./.claude-code-router/plugins`目录下,配置 `./.claude-code-router/config.json` 指向本服务地址,使用 `AUTH_TOKEN` 进行认证。
app/core/config.py CHANGED
@@ -109,6 +109,7 @@ class Settings(BaseSettings):
109
  # Server Configuration
110
  LISTEN_PORT: int = int(os.getenv("LISTEN_PORT", "8080"))
111
  DEBUG_LOGGING: bool = os.getenv("DEBUG_LOGGING", "true").lower() == "true"
 
112
 
113
  ANONYMOUS_MODE: bool = os.getenv("ANONYMOUS_MODE", "true").lower() == "true"
114
  TOOL_SUPPORT: bool = os.getenv("TOOL_SUPPORT", "true").lower() == "true"
 
109
  # Server Configuration
110
  LISTEN_PORT: int = int(os.getenv("LISTEN_PORT", "8080"))
111
  DEBUG_LOGGING: bool = os.getenv("DEBUG_LOGGING", "true").lower() == "true"
112
+ SERVICE_NAME: str = os.getenv("SERVICE_NAME", "z-ai2api-server")
113
 
114
  ANONYMOUS_MODE: bool = os.getenv("ANONYMOUS_MODE", "true").lower() == "true"
115
  TOOL_SUPPORT: bool = os.getenv("TOOL_SUPPORT", "true").lower() == "true"
app/core/openai.py CHANGED
@@ -156,12 +156,21 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
156
 
157
  # 初始化工具处理器(如果需要)
158
  has_tools = transformed["body"].get("tools") is not None
 
159
  tool_handler = None
160
- if has_tools:
 
 
161
  chat_id = transformed["body"]["chat_id"]
162
  model = request.model
163
  tool_handler = SSEToolHandler(chat_id, model)
164
- logger.info(f"🔧 初始化工具处理器: {len(transformed['body'].get('tools', []))} 个工具")
 
 
 
 
 
 
165
 
166
  # 处理状态
167
  has_thinking = False
@@ -368,7 +377,7 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
368
  "system_fingerprint": "fp_zai_001",
369
  }
370
  output_data = f"data: {json.dumps(content_chunk)}\n\n"
371
- logger.debug(f"➡️ 输出内容块到客户端: {output_data[:1000]}...")
372
  yield output_data
373
 
374
  # 处理完成
 
156
 
157
  # 初始化工具处理器(如果需要)
158
  has_tools = transformed["body"].get("tools") is not None
159
+ has_mcp_servers = bool(transformed["body"].get("mcp_servers"))
160
  tool_handler = None
161
+
162
+ # 如果有工具定义或MCP服务器,都需要工具处理器
163
+ if has_tools or has_mcp_servers:
164
  chat_id = transformed["body"]["chat_id"]
165
  model = request.model
166
  tool_handler = SSEToolHandler(chat_id, model)
167
+
168
+ if has_tools and has_mcp_servers:
169
+ logger.info(f"🔧 初始化工具处理器: {len(transformed['body'].get('tools', []))} 个OpenAI工具 + {len(transformed['body'].get('mcp_servers', []))} 个MCP服务器")
170
+ elif has_tools:
171
+ logger.info(f"🔧 初始化工具处理器: {len(transformed['body'].get('tools', []))} 个OpenAI工具")
172
+ elif has_mcp_servers:
173
+ logger.info(f"🔧 初始化工具处理器: {len(transformed['body'].get('mcp_servers', []))} 个MCP服务器")
174
 
175
  # 处理状态
176
  has_thinking = False
 
377
  "system_fingerprint": "fp_zai_001",
378
  }
379
  output_data = f"data: {json.dumps(content_chunk)}\n\n"
380
+ logger.debug(f"➡️ 输出内容块到客户端: {output_data}")
381
  yield output_data
382
 
383
  # 处理完成
app/core/zai_transformer.py CHANGED
@@ -209,10 +209,19 @@ class ZAITransformer:
209
  is_search = requested_model == settings.SEARCH_MODEL
210
  is_air = requested_model == settings.AIR_MODEL
211
 
 
 
 
 
 
 
 
 
212
  # 获取上游模型ID(使用模型映射)
213
  upstream_model_id = self.model_mapping.get(requested_model, "0727-360B-API")
214
  logger.debug(f" 模型映射: {requested_model} -> {upstream_model_id}")
215
  logger.debug(f" 模型特性检测: is_search={is_search}, is_thinking={is_thinking}, is_air={is_air}")
 
216
  logger.debug(f" SEARCH_MODEL配置: {settings.SEARCH_MODEL}")
217
 
218
  # 处理消息列表
@@ -263,17 +272,28 @@ class ZAITransformer:
263
 
264
  # 构建MCP服务器列表
265
  mcp_servers = []
266
- if is_search:
267
  mcp_servers.append("deep-web-search")
268
- logger.info(f"🔍 检测到搜索模型,添加 deep-web-search MCP 服务器")
 
269
  else:
270
- logger.debug(f" 非搜索模型,不添加 MCP 服务器")
271
 
272
  logger.debug(f" MCP服务器列表: {mcp_servers}")
273
 
274
  # 构建上游请求体
275
  chat_id = generate_uuid()
276
-
 
 
 
 
 
 
 
 
 
 
277
  body = {
278
  "stream": True, # 总是使用流式
279
  "model": upstream_model_id, # 使用映射后的模型ID
@@ -281,11 +301,11 @@ class ZAITransformer:
281
  "params": {},
282
  "features": {
283
  "image_generation": False,
284
- "web_search": is_search,
285
- "auto_web_search": is_search,
286
- "preview_mode": False,
287
  "flags": [],
288
- "features": [],
289
  "enable_thinking": is_thinking,
290
  },
291
  "background_tasks": {
@@ -300,10 +320,14 @@ class ZAITransformer:
300
  "{{CURRENT_DATE}}": datetime.now().strftime("%Y-%m-%d"),
301
  "{{CURRENT_TIME}}": datetime.now().strftime("%H:%M:%S"),
302
  "{{CURRENT_WEEKDAY}}": datetime.now().strftime("%A"),
303
- "{{CURRENT_TIMEZONE}}": "UTC",
304
  "{{USER_LANGUAGE}}": "zh-CN",
305
  },
306
- "model_item": {},
 
 
 
 
307
  "chat_id": chat_id,
308
  "id": generate_uuid(),
309
  }
@@ -336,11 +360,24 @@ class ZAITransformer:
336
 
337
  # 记录关键的请求信息用于调试
338
  logger.debug(f" 📋 发送到Z.AI的关键信息:")
 
339
  logger.debug(f" - 上游模型: {body['model']}")
340
  logger.debug(f" - MCP服务器: {body['mcp_servers']}")
341
  logger.debug(f" - web_search: {body['features']['web_search']}")
342
  logger.debug(f" - auto_web_search: {body['features']['auto_web_search']}")
343
  logger.debug(f" - 消息数量: {len(body['messages'])}")
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
  return {"body": body, "config": config, "token": token}
346
 
 
209
  is_search = requested_model == settings.SEARCH_MODEL
210
  is_air = requested_model == settings.AIR_MODEL
211
 
212
+ # 检查是否需要搜索功能(更灵活的检测)
213
+ needs_search = (
214
+ is_search or # 明确的搜索模型
215
+ "search" in requested_model.lower() or # 模型名包含search
216
+ request.get("enable_search", False) or # 显式启用搜索
217
+ (request.get("tools") and any("search" in str(tool).lower() for tool in request.get("tools", []))) # 工具中包含搜索
218
+ )
219
+
220
  # 获取上游模型ID(使用模型映射)
221
  upstream_model_id = self.model_mapping.get(requested_model, "0727-360B-API")
222
  logger.debug(f" 模型映射: {requested_model} -> {upstream_model_id}")
223
  logger.debug(f" 模型特性检测: is_search={is_search}, is_thinking={is_thinking}, is_air={is_air}")
224
+ logger.debug(f" 需要搜索功能: {needs_search}")
225
  logger.debug(f" SEARCH_MODEL配置: {settings.SEARCH_MODEL}")
226
 
227
  # 处理消息列表
 
272
 
273
  # 构建MCP服务器列表
274
  mcp_servers = []
275
+ if needs_search:
276
  mcp_servers.append("deep-web-search")
277
+ logger.info(f"🔍 检测到需要搜索功能,添加 deep-web-search MCP 服务器")
278
+ logger.debug(f" 搜索检测原因: is_search={is_search}, model_name_contains_search={'search' in requested_model.lower()}")
279
  else:
280
+ logger.debug(f" 不需要搜索功能,不添加 MCP 服务器")
281
 
282
  logger.debug(f" MCP服务器列表: {mcp_servers}")
283
 
284
  # 构建上游请求体
285
  chat_id = generate_uuid()
286
+
287
+ # 根据参考文档构建更完整的features配置
288
+ features_list = []
289
+ if needs_search:
290
+ features_list.extend([
291
+ {"type": "mcp", "server": "deep-web-search", "status": "selected"},
292
+ {"type": "mcp", "server": "vibe-coding", "status": "hidden"},
293
+ {"type": "mcp", "server": "ppt-maker", "status": "hidden"},
294
+ {"type": "mcp", "server": "image-search", "status": "hidden"}
295
+ ])
296
+
297
  body = {
298
  "stream": True, # 总是使用流式
299
  "model": upstream_model_id, # 使用映射后的模型ID
 
301
  "params": {},
302
  "features": {
303
  "image_generation": False,
304
+ "web_search": needs_search,
305
+ "auto_web_search": needs_search,
306
+ "preview_mode": True if needs_search else False, # 搜索时启用预览模式
307
  "flags": [],
308
+ "features": features_list,
309
  "enable_thinking": is_thinking,
310
  },
311
  "background_tasks": {
 
320
  "{{CURRENT_DATE}}": datetime.now().strftime("%Y-%m-%d"),
321
  "{{CURRENT_TIME}}": datetime.now().strftime("%H:%M:%S"),
322
  "{{CURRENT_WEEKDAY}}": datetime.now().strftime("%A"),
323
+ "{{CURRENT_TIMEZONE}}": "Asia/Shanghai", # 使用更合适的时区
324
  "{{USER_LANGUAGE}}": "zh-CN",
325
  },
326
+ "model_item": {
327
+ "id": upstream_model_id,
328
+ "name": requested_model,
329
+ "owned_by": "z.ai"
330
+ },
331
  "chat_id": chat_id,
332
  "id": generate_uuid(),
333
  }
 
360
 
361
  # 记录关键的请求信息用于调试
362
  logger.debug(f" 📋 发送到Z.AI的关键信息:")
363
+ logger.debug(f" - 原始模型: {requested_model}")
364
  logger.debug(f" - 上游模型: {body['model']}")
365
  logger.debug(f" - MCP服务器: {body['mcp_servers']}")
366
  logger.debug(f" - web_search: {body['features']['web_search']}")
367
  logger.debug(f" - auto_web_search: {body['features']['auto_web_search']}")
368
  logger.debug(f" - 消息数量: {len(body['messages'])}")
369
+ tools_count = len(body.get('tools') or [])
370
+ logger.debug(f" - 工具数量: {tools_count}")
371
+
372
+ # 特别记录MCP相关信息
373
+ if body['mcp_servers']:
374
+ logger.info(f"🎯 MCP服务器配置成功: {body['mcp_servers']}")
375
+ logger.debug(f" 📋 完整的features配置: {json.dumps(body['features'], ensure_ascii=False, indent=2)}")
376
+ else:
377
+ logger.warning(f"⚠️ 未配置MCP服务器 - 检查模型: {requested_model}, 搜索需��: {needs_search}")
378
+
379
+ # 记录完整的请求体(用于调试)
380
+ logger.debug(f" 📋 完整请求体: {json.dumps(body, ensure_ascii=False, indent=2)}")
381
 
382
  return {"body": body, "config": config, "token": token}
383
 
app/utils/process_manager.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 进程管理模块
6
+ 提供服务唯一性验证和进程管理功能
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import time
12
+ import psutil
13
+ from typing import Optional, List
14
+ from pathlib import Path
15
+
16
+ from app.utils.logger import get_logger
17
+
18
+ logger = get_logger()
19
+
20
+
21
+ class ProcessManager:
22
+ """进程管理器 - 负责服务唯一性验证和进程管理"""
23
+
24
+ def __init__(self, service_name: str = "z-ai2api-server", port: int = 8080):
25
+ """
26
+ 初始化进程管理器
27
+
28
+ Args:
29
+ service_name: 服务名称,用于进程名称标识
30
+ port: 服务端口,用于唯一性检查
31
+ """
32
+ self.service_name = service_name
33
+ self.port = port
34
+ self.current_pid = os.getpid()
35
+ self.pid_file = Path(f"{service_name}.pid")
36
+
37
+ def check_service_uniqueness(self) -> bool:
38
+ """
39
+ 检查服务唯一性
40
+
41
+ 通过以下方式验证:
42
+ 1. 检查 PID 文件
43
+ 2. 检查端口是否被占用
44
+ 3. 检查进程名称 (pname) 是否已存在(可选)
45
+
46
+ Returns:
47
+ bool: True 表示可以启动服务,False 表示已有实例运行
48
+ """
49
+ logger.info(f"🔍 检查服务唯一性: {self.service_name} (端口: {self.port})")
50
+
51
+ # 1. 优先检查 PID 文件(最可靠)
52
+ if self._check_pid_file():
53
+ return False
54
+
55
+ # 2. 检查端口占用
56
+ if self._check_port_usage():
57
+ return False
58
+
59
+ # 3. 检查进程名称(作为额外保障)
60
+ if self._check_process_by_name():
61
+ return False
62
+
63
+ logger.info("✅ 服务唯一性检查通过,可以启动服务")
64
+ return True
65
+
66
+ def _check_process_by_name(self) -> bool:
67
+ """
68
+ 通过进程名称检查是否已有实例运行
69
+
70
+ 这是一个保守的检查,只检查明确的服务进程标识
71
+
72
+ Returns:
73
+ bool: True 表示发现同名进程,False 表示未发现
74
+ """
75
+ try:
76
+ running_processes = []
77
+
78
+ for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
79
+ try:
80
+ proc_info = proc.info
81
+
82
+ # 跳过当前进程
83
+ if proc_info['pid'] == self.current_pid:
84
+ continue
85
+
86
+ # 只检查进程名称直接匹配服务名称的情况
87
+ # 这通常发生在使用 Granian 的 process_name 参数时
88
+ if proc_info['name'] and proc_info['name'] == self.service_name:
89
+ running_processes.append(proc_info)
90
+ continue
91
+
92
+ # 检查命令行参数中是否包含明确的服务标识
93
+ cmdline = proc_info.get('cmdline', [])
94
+ if cmdline and len(cmdline) >= 2:
95
+ cmdline_str = ' '.join(cmdline)
96
+
97
+ # 只检查通过 Granian 启动且明确指定了进程名称的服务
98
+ if (f'--process-name={self.service_name}' in cmdline_str or
99
+ f'process_name={self.service_name}' in cmdline_str):
100
+ running_processes.append(proc_info)
101
+
102
+ except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
103
+ # 进程可能已经结束或无权限访问
104
+ continue
105
+
106
+ if running_processes:
107
+ logger.warning(f"⚠️ 发现 {len(running_processes)} 个同名进程正在运行:")
108
+ for proc_info in running_processes:
109
+ cmdline = proc_info.get('cmdline', [])
110
+ cmdline_preview = ' '.join(cmdline[:3]) + '...' if len(cmdline) > 3 else ' '.join(cmdline)
111
+ logger.warning(f" PID: {proc_info['pid']}, 名称: {proc_info['name']}, 命令: {cmdline_preview}")
112
+ logger.warning(f"❌ 服务 {self.service_name} 已在运行,请先停止现有实例")
113
+ return True
114
+
115
+ return False
116
+
117
+ except Exception as e:
118
+ logger.error(f"❌ 检查进程名称时发生错误: {e}")
119
+ return False
120
+
121
+ def _check_port_usage(self) -> bool:
122
+ """
123
+ 检查端口是否被占用
124
+
125
+ Returns:
126
+ bool: True 表示端口被占用,False 表示端口可用
127
+ """
128
+ try:
129
+ # 获取所有网络连接
130
+ connections = psutil.net_connections(kind='inet')
131
+
132
+ for conn in connections:
133
+ if (conn.laddr.port == self.port and
134
+ conn.status in [psutil.CONN_LISTEN, psutil.CONN_ESTABLISHED]):
135
+
136
+ # 尝试获取占用端口的进程信息
137
+ try:
138
+ proc = psutil.Process(conn.pid) if conn.pid else None
139
+ proc_name = proc.name() if proc else "未知进程"
140
+ logger.warning(f"⚠️ 端口 {self.port} 已被占用")
141
+ logger.warning(f" 占用进程: PID {conn.pid}, 名称: {proc_name}")
142
+ logger.warning(f"❌ 无法启动服务,端口 {self.port} 不可用")
143
+ return True
144
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
145
+ logger.warning(f"⚠️ 端口 {self.port} 已被占用(无法获取进程信息)")
146
+ return True
147
+
148
+ return False
149
+
150
+ except Exception as e:
151
+ logger.error(f"❌ 检查端口占用时发生错误: {e}")
152
+ return False
153
+
154
+ def _check_pid_file(self) -> bool:
155
+ """
156
+ 检查 PID 文件
157
+
158
+ Returns:
159
+ bool: True 表示发现有效的 PID 文件,False 表示无冲突
160
+ """
161
+ try:
162
+ if not self.pid_file.exists():
163
+ return False
164
+
165
+ # 读取 PID 文件
166
+ pid_content = self.pid_file.read_text().strip()
167
+ if not pid_content.isdigit():
168
+ logger.warning(f"⚠️ PID 文件格式无效: {self.pid_file}")
169
+ self._cleanup_pid_file()
170
+ return False
171
+
172
+ old_pid = int(pid_content)
173
+
174
+ # 检查进程是否仍在运行
175
+ try:
176
+ proc = psutil.Process(old_pid)
177
+ if proc.is_running():
178
+ logger.warning(f"⚠️ 发现有效的 PID 文件: {self.pid_file}")
179
+ logger.warning(f" 进程 PID {old_pid} 仍在运行: {proc.name()}")
180
+ logger.warning(f"❌ 服务可能已在运行,请检查进程或删除 PID 文件")
181
+ return True
182
+ else:
183
+ logger.info(f"🧹 清理无效的 PID 文件: {self.pid_file}")
184
+ self._cleanup_pid_file()
185
+ return False
186
+ except psutil.NoSuchProcess:
187
+ logger.info(f"🧹 清理过期的 PID 文件: {self.pid_file}")
188
+ self._cleanup_pid_file()
189
+ return False
190
+
191
+ except Exception as e:
192
+ logger.error(f"❌ 检查 PID 文件时发生错误: {e}")
193
+ return False
194
+
195
+ def _cleanup_pid_file(self):
196
+ """清理 PID 文件"""
197
+ try:
198
+ if self.pid_file.exists():
199
+ self.pid_file.unlink()
200
+ logger.debug(f"🧹 已删除 PID 文件: {self.pid_file}")
201
+ except Exception as e:
202
+ logger.error(f"❌ 删除 PID 文件失败: {e}")
203
+
204
+ def create_pid_file(self):
205
+ """创建 PID 文件"""
206
+ try:
207
+ self.pid_file.write_text(str(self.current_pid))
208
+ logger.info(f"📝 创建 PID 文件: {self.pid_file} (PID: {self.current_pid})")
209
+ except Exception as e:
210
+ logger.error(f"❌ 创建 PID 文件失败: {e}")
211
+
212
+ def cleanup_on_exit(self):
213
+ """退出时清理资源"""
214
+ logger.info(f"🧹 清理进程资源 (PID: {self.current_pid})")
215
+ self._cleanup_pid_file()
216
+
217
+ def get_running_instances(self) -> List[dict]:
218
+ """
219
+ 获取所有运行中的服务实例
220
+
221
+ Returns:
222
+ List[dict]: 运行中的实例信息列表
223
+ """
224
+ instances = []
225
+
226
+ try:
227
+ for proc in psutil.process_iter(['pid', 'name', 'cmdline', 'create_time']):
228
+ try:
229
+ proc_info = proc.info
230
+
231
+ # 跳过当前进程
232
+ if proc_info['pid'] == self.current_pid:
233
+ continue
234
+
235
+ # 使用与 _check_process_by_name 相同的保守逻辑
236
+ is_service = False
237
+
238
+ # 只检查进程名称直接匹配服务名称的情况
239
+ if proc_info['name'] and proc_info['name'] == self.service_name:
240
+ is_service = True
241
+
242
+ # 检查命令行参数中是否包含明确的服务标识
243
+ cmdline = proc_info.get('cmdline', [])
244
+ if cmdline and len(cmdline) >= 2:
245
+ cmdline_str = ' '.join(cmdline)
246
+
247
+ # 只检查通过 Granian 启动且明确指定了进程名称的服务
248
+ if (f'--process-name={self.service_name}' in cmdline_str or
249
+ f'process_name={self.service_name}' in cmdline_str):
250
+ is_service = True
251
+
252
+ if is_service:
253
+ instances.append({
254
+ 'pid': proc_info['pid'],
255
+ 'name': proc_info['name'],
256
+ 'cmdline': cmdline,
257
+ 'create_time': proc_info['create_time'],
258
+ 'start_time': time.strftime('%Y-%m-%d %H:%M:%S',
259
+ time.localtime(proc_info['create_time']))
260
+ })
261
+
262
+ except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
263
+ continue
264
+
265
+ except Exception as e:
266
+ logger.error(f"❌ 获取运行实例时发生错误: {e}")
267
+
268
+ return instances
269
+
270
+
271
+ def ensure_service_uniqueness(service_name: str = "z-ai2api-server", port: int = 8080) -> bool:
272
+ """
273
+ 确保服务唯一性的便捷函数
274
+
275
+ Args:
276
+ service_name: 服务名称
277
+ port: 服务端口
278
+
279
+ Returns:
280
+ bool: True 表示可以启动,False 表示应该退出
281
+ """
282
+ manager = ProcessManager(service_name, port)
283
+
284
+ if not manager.check_service_uniqueness():
285
+ logger.error("❌ 服务唯一性检查失败,程序退出")
286
+
287
+ # 显示运行中的实例
288
+ instances = manager.get_running_instances()
289
+ if instances:
290
+ logger.info("📋 当前运行的实例:")
291
+ for instance in instances:
292
+ logger.info(f" PID: {instance['pid']}, 启动时间: {instance['start_time']}")
293
+
294
+ return False
295
+
296
+ # 创建 PID 文件
297
+ manager.create_pid_file()
298
+
299
+ # 注册退出清理
300
+ import atexit
301
+ atexit.register(manager.cleanup_on_exit)
302
+
303
+ return True
main.py CHANGED
@@ -1,6 +1,9 @@
1
  #!/usr/bin/env python
2
  # -*- coding: utf-8 -*-
3
 
 
 
 
4
  from contextlib import asynccontextmanager
5
  from fastapi import FastAPI, Response
6
  from fastapi.middleware.cors import CORSMiddleware
@@ -10,6 +13,7 @@ from app.core import openai
10
  from app.utils.reload_config import RELOAD_CONFIG
11
  from app.utils.logger import setup_logger
12
  from app.utils.token_pool import initialize_token_pool
 
13
 
14
  from granian import Granian
15
 
@@ -62,14 +66,32 @@ async def root():
62
 
63
 
64
  def run_server():
65
- Granian(
66
- "main:app",
67
- interface="asgi",
68
- address="0.0.0.0",
69
- port=settings.LISTEN_PORT,
70
- reload=False, # 生产环境请关闭热重载
71
- **RELOAD_CONFIG,
72
- ).serve()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
 
75
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python
2
  # -*- coding: utf-8 -*-
3
 
4
+ import os
5
+ import sys
6
+ import psutil
7
  from contextlib import asynccontextmanager
8
  from fastapi import FastAPI, Response
9
  from fastapi.middleware.cors import CORSMiddleware
 
13
  from app.utils.reload_config import RELOAD_CONFIG
14
  from app.utils.logger import setup_logger
15
  from app.utils.token_pool import initialize_token_pool
16
+ from app.utils.process_manager import ensure_service_uniqueness
17
 
18
  from granian import Granian
19
 
 
66
 
67
 
68
  def run_server():
69
+ # 服务唯一性检查
70
+ service_name = settings.SERVICE_NAME
71
+ if not ensure_service_uniqueness(service_name=service_name, port=settings.LISTEN_PORT):
72
+ logger.error("❌ 服务已在运行,程序退出")
73
+ sys.exit(1)
74
+
75
+ logger.info(f"🚀 启动 {service_name} 服务...")
76
+ logger.info(f"📡 监听地址: 0.0.0.0:{settings.LISTEN_PORT}")
77
+ logger.info(f"🔧 调试模式: {'开启' if settings.DEBUG_LOGGING else '关闭'}")
78
+ logger.info(f"🔐 匿名模式: {'开启' if settings.ANONYMOUS_MODE else '关闭'}")
79
+
80
+ try:
81
+ Granian(
82
+ "main:app",
83
+ interface="asgi",
84
+ address="0.0.0.0",
85
+ port=settings.LISTEN_PORT,
86
+ reload=True, # 生产环境请关闭热重载
87
+ process_name=service_name, # 设置进程名称
88
+ **RELOAD_CONFIG,
89
+ ).serve()
90
+ except KeyboardInterrupt:
91
+ logger.info("🛑 收到中断信号,正在关闭服务...")
92
+ except Exception as e:
93
+ logger.error(f"❌ 服务启动失败: {e}")
94
+ sys.exit(1)
95
 
96
 
97
  if __name__ == "__main__":
pyproject.toml CHANGED
@@ -25,15 +25,16 @@ classifiers = [
25
  ]
26
  dependencies = [
27
  "fastapi==0.104.1",
28
- "granian[reload]==2.5.2",
29
  "requests==2.32.5",
 
30
  "pydantic==2.11.7",
31
  "pydantic-settings==2.10.1",
32
  "pydantic-core==2.33.2",
33
  "typing-inspection==0.4.1",
34
  "fake-useragent==2.2.0",
35
  "loguru==0.7.3",
36
- "httpx==0.27.0"
37
  ]
38
 
39
  [project.scripts]
 
25
  ]
26
  dependencies = [
27
  "fastapi==0.104.1",
28
+ "granian[reload,pname]==2.5.2",
29
  "requests==2.32.5",
30
+ "httpx==0.27.0",
31
  "pydantic==2.11.7",
32
  "pydantic-settings==2.10.1",
33
  "pydantic-core==2.33.2",
34
  "typing-inspection==0.4.1",
35
  "fake-useragent==2.2.0",
36
  "loguru==0.7.3",
37
+ "psutil>=7.0.0",
38
  ]
39
 
40
  [project.scripts]
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  fastapi==0.116.1
2
- granian[reload]==2.5.2
3
  requests==2.32.5
4
  httpx==0.27.0
5
  pydantic==2.11.7
@@ -7,4 +7,5 @@ pydantic-settings==2.10.1
7
  pydantic-core==2.33.2
8
  typing-inspection==0.4.1
9
  fake-useragent==2.2.0
10
- loguru==0.7.3
 
 
1
  fastapi==0.116.1
2
+ granian[reload,pname]==2.5.2
3
  requests==2.32.5
4
  httpx==0.27.0
5
  pydantic==2.11.7
 
7
  pydantic-core==2.33.2
8
  typing-inspection==0.4.1
9
  fake-useragent==2.2.0
10
+ loguru==0.7.3
11
+ psutil>=7.0.0
tests/test_comprehensive_tool_calls.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 全面的工具调用测试套件
6
+ 覆盖各种工具类型、参数格式、传输模式和边界情况
7
+ """
8
+
9
+ import json
10
+ import time
11
+ from typing import Dict, Any, List
12
+ from app.utils.sse_tool_handler import SSEToolHandler
13
+ from app.utils.logger import get_logger
14
+
15
+ logger = get_logger()
16
+
17
+ class TestResult:
18
+ """测试结果统计"""
19
+ def __init__(self, test_name: str):
20
+ self.test_name = test_name
21
+ self.passed = 0
22
+ self.failed = 0
23
+ self.errors = []
24
+
25
+ def add_pass(self):
26
+ self.passed += 1
27
+
28
+ def add_fail(self, error_msg: str):
29
+ self.failed += 1
30
+ self.errors.append(error_msg)
31
+
32
+ def print_summary(self):
33
+ total = self.passed + self.failed
34
+ success_rate = (self.passed / total * 100) if total > 0 else 0
35
+
36
+ print(f"\n📊 {self.test_name} 测试汇总:")
37
+ print(f" 总测试数: {total}")
38
+ print(f" ✅ 通过: {self.passed}")
39
+ print(f" ❌ 失败: {self.failed}")
40
+ print(f" 📈 成功率: {success_rate:.1f}%")
41
+
42
+ if self.errors:
43
+ print(f"\n❌ 失败详情:")
44
+ for i, error in enumerate(self.errors, 1):
45
+ print(f" {i}. {error}")
46
+
47
+ def test_various_tool_types():
48
+ """测试各种类型的工具调用"""
49
+
50
+ result = TestResult("工具类型测试")
51
+
52
+ # 定义各种工具类型的测试用例
53
+ tool_scenarios = [
54
+ {
55
+ "name": "浏览器导航工具",
56
+ "tool_name": "browser_navigate",
57
+ "arguments": '{"url": "https://www.google.com"}',
58
+ "expected_args": {"url": "https://www.google.com"},
59
+ "description": "测试浏览器导航工具的URL参数"
60
+ },
61
+ {
62
+ "name": "天气查询工具",
63
+ "tool_name": "get_weather",
64
+ "arguments": '{"city": "北京", "unit": "celsius"}',
65
+ "expected_args": {"city": "北京", "unit": "celsius"},
66
+ "description": "测试天气查询工具的城市和单位参数"
67
+ },
68
+ {
69
+ "name": "文件操作工具",
70
+ "tool_name": "file_write",
71
+ "arguments": '{"path": "/tmp/test.txt", "content": "Hello World", "encoding": "utf-8"}',
72
+ "expected_args": {"path": "/tmp/test.txt", "content": "Hello World", "encoding": "utf-8"},
73
+ "description": "测试文件写入工具的多参数"
74
+ },
75
+ {
76
+ "name": "搜索工具",
77
+ "tool_name": "web_search",
78
+ "arguments": '{"query": "Python编程", "limit": 10, "safe_search": true}',
79
+ "expected_args": {"query": "Python编程", "limit": 10, "safe_search": True},
80
+ "description": "测试搜索工具的混合类型参数"
81
+ },
82
+ {
83
+ "name": "数据库查询工具",
84
+ "tool_name": "db_query",
85
+ "arguments": '{"sql": "SELECT * FROM users WHERE age > ?", "params": [18], "timeout": 30.5}',
86
+ "expected_args": {"sql": "SELECT * FROM users WHERE age > ?", "params": [18], "timeout": 30.5},
87
+ "description": "测试数据库工具的复杂参数结构"
88
+ },
89
+ {
90
+ "name": "API调用工具",
91
+ "tool_name": "api_call",
92
+ "arguments": '{"method": "POST", "url": "https://api.example.com/data", "headers": {"Content-Type": "application/json"}, "body": {"key": "value"}}',
93
+ "expected_args": {"method": "POST", "url": "https://api.example.com/data", "headers": {"Content-Type": "application/json"}, "body": {"key": "value"}},
94
+ "description": "测试API调用工具的嵌套对象参数"
95
+ },
96
+ {
97
+ "name": "图像处理工具",
98
+ "tool_name": "image_resize",
99
+ "arguments": '{"input_path": "image.jpg", "output_path": "resized.jpg", "width": 800, "height": 600, "maintain_aspect": false}',
100
+ "expected_args": {"input_path": "image.jpg", "output_path": "resized.jpg", "width": 800, "height": 600, "maintain_aspect": False},
101
+ "description": "测试图像处理工具的数值和布尔参数"
102
+ },
103
+ {
104
+ "name": "邮件发送工具",
105
+ "tool_name": "send_email",
106
+ "arguments": '{"to": ["user1@example.com", "user2@example.com"], "subject": "测试邮件", "body": "这是一封测试邮件\\n包含换行符", "attachments": []}',
107
+ "expected_args": {"to": ["user1@example.com", "user2@example.com"], "subject": "测试邮件", "body": "这是一封测试邮件\n包含换行符", "attachments": []},
108
+ "description": "测试邮件工具的数组参数和转义字符"
109
+ }
110
+ ]
111
+
112
+ print("🔧 测试各种类型的工具调用")
113
+ print("=" * 80)
114
+
115
+ for i, scenario in enumerate(tool_scenarios, 1):
116
+ print(f"\n测试 {i}: {scenario['name']}")
117
+ print(f"描述: {scenario['description']}")
118
+
119
+ try:
120
+ handler = SSEToolHandler("test_chat_id", "GLM-4.5")
121
+
122
+ # 构造完整的工具调用数据
123
+ tool_data = {
124
+ "edit_index": 0,
125
+ "edit_content": f'<glm_block >{{"type": "mcp", "data": {{"metadata": {{"id": "call_{i}", "name": "{scenario["tool_name"]}", "arguments": "{scenario["arguments"]}", "result": "", "status": "completed"}}}}, "thought": null}}</glm_block>',
126
+ "phase": "tool_call"
127
+ }
128
+
129
+ # 处理工具调用
130
+ chunks = list(handler.process_tool_call_phase(tool_data, is_stream=False))
131
+
132
+ # 验证结果
133
+ if handler.active_tools:
134
+ tool = list(handler.active_tools.values())[0]
135
+ actual_args = tool["arguments"]
136
+ expected_args = scenario["expected_args"]
137
+
138
+ if actual_args == expected_args:
139
+ print(f" ✅ 参数解析正确: {actual_args}")
140
+ result.add_pass()
141
+ else:
142
+ error_msg = f"{scenario['name']}: 参数不匹配 - 期望: {expected_args}, 实际: {actual_args}"
143
+ print(f" ❌ {error_msg}")
144
+ result.add_fail(error_msg)
145
+ else:
146
+ error_msg = f"{scenario['name']}: 未检测到工具调用"
147
+ print(f" ❌ {error_msg}")
148
+ result.add_fail(error_msg)
149
+
150
+ except Exception as e:
151
+ error_msg = f"{scenario['name']}: 处理异常 - {str(e)}"
152
+ print(f" ❌ {error_msg}")
153
+ result.add_fail(error_msg)
154
+
155
+ result.print_summary()
156
+ return result
157
+
158
+ def test_parameter_formats():
159
+ """测试各种参数格式"""
160
+
161
+ result = TestResult("参数格式测试")
162
+
163
+ # 定义各种参数格式的测试用例
164
+ format_scenarios = [
165
+ {
166
+ "name": "空参数",
167
+ "arguments": "{}",
168
+ "expected": {},
169
+ "description": "测试空参数对象"
170
+ },
171
+ {
172
+ "name": "null参数",
173
+ "arguments": "null",
174
+ "expected": {},
175
+ "description": "测试null参数值"
176
+ },
177
+ {
178
+ "name": "转义JSON字符串",
179
+ "arguments": '{\\"key\\": \\"value\\"}',
180
+ "expected": {"key": "value"},
181
+ "description": "测试转义的JSON字符串"
182
+ },
183
+ {
184
+ "name": "包含特殊字符",
185
+ "arguments": '{"text": "Hello\\nWorld\\t!", "emoji": "😀🎉", "unicode": "中文测试"}',
186
+ "expected": {"text": "Hello\nWorld\t!", "emoji": "😀🎉", "unicode": "中文测试"},
187
+ "description": "测试包含换行符、制表符、emoji和中文的参数"
188
+ },
189
+ {
190
+ "name": "数值类型",
191
+ "arguments": '{"int": 42, "float": 3.14159, "negative": -100, "zero": 0}',
192
+ "expected": {"int": 42, "float": 3.14159, "negative": -100, "zero": 0},
193
+ "description": "测试各种数值类型参数"
194
+ },
195
+ {
196
+ "name": "布尔类型",
197
+ "arguments": '{"true_val": true, "false_val": false}',
198
+ "expected": {"true_val": True, "false_val": False},
199
+ "description": "测试布尔类型参数"
200
+ },
201
+ {
202
+ "name": "数组参数",
203
+ "arguments": '{"empty_array": [], "string_array": ["a", "b", "c"], "mixed_array": [1, "two", true, null]}',
204
+ "expected": {"empty_array": [], "string_array": ["a", "b", "c"], "mixed_array": [1, "two", True, None]},
205
+ "description": "测试各种数组类型参数"
206
+ },
207
+ {
208
+ "name": "嵌套对象",
209
+ "arguments": '{"nested": {"level1": {"level2": {"value": "deep"}}}, "array_of_objects": [{"id": 1}, {"id": 2}]}',
210
+ "expected": {"nested": {"level1": {"level2": {"value": "deep"}}}, "array_of_objects": [{"id": 1}, {"id": 2}]},
211
+ "description": "测试深度嵌套的对象和对象数组"
212
+ },
213
+ {
214
+ "name": "长字符串",
215
+ "arguments": '{"long_text": "' + "A" * 1000 + '"}',
216
+ "expected": {"long_text": "A" * 1000},
217
+ "description": "测试长字符串参数"
218
+ },
219
+ {
220
+ "name": "包含引号的字符串",
221
+ "arguments": '{"quoted": "He said \\"Hello\\" to me", "single_quote": "It\'s working"}',
222
+ "expected": {"quoted": 'He said "Hello" to me', "single_quote": "It's working"},
223
+ "description": "测试包含引号的字符串参数"
224
+ }
225
+ ]
226
+
227
+ print("\n📝 测试各种参数格式")
228
+ print("=" * 80)
229
+
230
+ for i, scenario in enumerate(format_scenarios, 1):
231
+ print(f"\n测试 {i}: {scenario['name']}")
232
+ print(f"描述: {scenario['description']}")
233
+
234
+ try:
235
+ handler = SSEToolHandler("test_chat_id", "GLM-4.5")
236
+
237
+ # 直接测试参数解析
238
+ result_args = handler._parse_partial_arguments(scenario["arguments"])
239
+
240
+ if result_args == scenario["expected"]:
241
+ print(f" ✅ 参数解析正确")
242
+ result.add_pass()
243
+ else:
244
+ error_msg = f"{scenario['name']}: 参数解析错误 - 期望: {scenario['expected']}, 实际: {result_args}"
245
+ print(f" ❌ {error_msg}")
246
+ result.add_fail(error_msg)
247
+
248
+ except Exception as e:
249
+ error_msg = f"{scenario['name']}: 解析异常 - {str(e)}"
250
+ print(f" ❌ {error_msg}")
251
+ result.add_fail(error_msg)
252
+
253
+ result.print_summary()
254
+ return result
tests/test_live_server.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 测试当前运行的服务器是否正确处理GLM-4.5-Search模型
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ import httpx
11
+ from app.core.config import settings
12
+
13
+ async def test_live_server():
14
+ """测试实际运行的服务器"""
15
+
16
+ print("🧪 测试当前运行的服务器...")
17
+ print(f"服务器地址: http://localhost:{settings.LISTEN_PORT}")
18
+ print()
19
+
20
+ try:
21
+ async with httpx.AsyncClient() as client:
22
+ # 测试搜索模型请求
23
+ search_request = {
24
+ "model": "GLM-4.5-Search",
25
+ "messages": [
26
+ {"role": "user", "content": "请搜索今天北京的天气"}
27
+ ],
28
+ "stream": True # 使用流式以便观察日志
29
+ }
30
+
31
+ headers = {
32
+ "Content-Type": "application/json",
33
+ "Authorization": f"Bearer {settings.AUTH_TOKEN}"
34
+ }
35
+
36
+ print(f"📤 发送GLM-4.5-Search请求...")
37
+ print(f"请求内容: {json.dumps(search_request, ensure_ascii=False, indent=2)}")
38
+ print()
39
+
40
+ # 发送请求并接收流式响应
41
+ async with client.stream(
42
+ "POST",
43
+ f"http://localhost:{settings.LISTEN_PORT}/v1/chat/completions",
44
+ json=search_request,
45
+ headers=headers,
46
+ timeout=30.0
47
+ ) as response:
48
+
49
+ print(f"📥 响应状态: {response.status_code}")
50
+
51
+ if response.status_code == 200:
52
+ print(f"✅ 请求成功,开始接收流式响应...")
53
+ print(f"💡 请查看服务器日志以确认是否正确添加了 deep-web-search MCP 服务器")
54
+ print()
55
+
56
+ # 读取前几个响应块
57
+ chunk_count = 0
58
+ async for line in response.aiter_lines():
59
+ if line.startswith("data: "):
60
+ chunk_count += 1
61
+ if chunk_count <= 3: # 只显示前3个块
62
+ data = line[6:] # 去掉 "data: " 前缀
63
+ if data.strip() and data.strip() != "[DONE]":
64
+ try:
65
+ chunk_data = json.loads(data)
66
+ content = chunk_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
67
+ if content:
68
+ print(f"📦 响应块 {chunk_count}: {content}")
69
+ except:
70
+ pass
71
+ elif chunk_count > 10: # 读取足够的块后停止
72
+ break
73
+
74
+ print(f"\n✅ 流式响应正常,共接收 {chunk_count} 个数据块")
75
+ print(f"🔍 请检查服务器日志中是否包含以下信息:")
76
+ print(f" - '模型特性检测: is_search=True'")
77
+ print(f" - '🔍 检测到搜索模型,添加 deep-web-search MCP 服务器'")
78
+ print(f" - 'MCP服务器列表: [\"deep-web-search\"]'")
79
+
80
+ else:
81
+ error_text = await response.aread()
82
+ print(f"❌ 请求失败: {response.status_code}")
83
+ print(f"错误信息: {error_text.decode('utf-8', errors='ignore')}")
84
+
85
+ except httpx.ConnectError:
86
+ print(f"❌ 无法连接到服务器 localhost:{settings.LISTEN_PORT}")
87
+ print(f" 请确保服务器正在运行: python main.py")
88
+ except Exception as e:
89
+ print(f"❌ 请求异常: {e}")
90
+
91
+ async def main():
92
+ """主函数"""
93
+ print("=" * 60)
94
+ print("GLM-4.5-Search 实时服务器测试")
95
+ print("=" * 60)
96
+ print()
97
+
98
+ await test_live_server()
99
+
100
+ print()
101
+ print("=" * 60)
102
+ print("测试完成")
103
+ print("=" * 60)
104
+ print()
105
+ print("📋 检查清单:")
106
+ print("1. 服务器是否正常响应 GLM-4.5-Search 请求?")
107
+ print("2. 日志中是否显示 'is_search=True'?")
108
+ print("3. 日志中是否显示添加 deep-web-search MCP 服务器?")
109
+ print("4. 如果以上信息缺失,请重启服务器以加载最新代码")
110
+
111
+ if __name__ == "__main__":
112
+ asyncio.run(main())
tests/test_model_comparison.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 对比不同模型的搜索行为
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ import httpx
11
+ from app.core.config import settings
12
+
13
+ async def test_model(model_name: str, question: str):
14
+ """测试特定模型的响应"""
15
+
16
+ print(f"🧪 测试模型: {model_name}")
17
+ print(f"问题: {question}")
18
+ print()
19
+
20
+ try:
21
+ async with httpx.AsyncClient() as client:
22
+ request_data = {
23
+ "model": model_name,
24
+ "messages": [
25
+ {"role": "user", "content": question}
26
+ ],
27
+ "stream": False # 使用非流式以便完整查看响应
28
+ }
29
+
30
+ headers = {
31
+ "Content-Type": "application/json",
32
+ "Authorization": f"Bearer {settings.AUTH_TOKEN}"
33
+ }
34
+
35
+ response = await client.post(
36
+ f"http://localhost:{settings.LISTEN_PORT}/v1/chat/completions",
37
+ json=request_data,
38
+ headers=headers,
39
+ timeout=60.0
40
+ )
41
+
42
+ if response.status_code == 200:
43
+ result = response.json()
44
+ content = result["choices"][0]["message"]["content"]
45
+ print(f"✅ 响应成功:")
46
+ print(f"内容: {content[:200]}...")
47
+ print()
48
+
49
+ # 检查是否包含搜索相关的内容
50
+ search_indicators = [
51
+ "搜索", "查询", "实时", "最新", "网络", "互联网",
52
+ "search", "query", "real-time", "latest", "web", "internet"
53
+ ]
54
+
55
+ has_search_content = any(indicator in content.lower() for indicator in search_indicators)
56
+ if has_search_content:
57
+ print(f"🔍 检测到搜索相关内容")
58
+ else:
59
+ print(f"❌ 未检测到搜索相关内容")
60
+
61
+ return content
62
+ else:
63
+ print(f"❌ 请求失败: {response.status_code}")
64
+ print(f"错误: {response.text}")
65
+ return None
66
+
67
+ except Exception as e:
68
+ print(f"❌ 请求异常: {e}")
69
+ return None
70
+
71
+ async def main():
72
+ """主测试函数"""
73
+ print("=" * 80)
74
+ print("GLM模型搜索能力对比测试")
75
+ print("=" * 80)
76
+ print()
77
+
78
+ # 测试问题
79
+ search_question = "请搜索今天北京的天气情况"
80
+ general_question = "你好,请介绍一下自己"
81
+
82
+ models_to_test = [
83
+ "GLM-4.5",
84
+ "GLM-4.5-Search",
85
+ "GLM-4.5-Thinking",
86
+ "GLM-4.5-Air"
87
+ ]
88
+
89
+ print("🔍 测试搜索相关问题:")
90
+ print(f"问题: {search_question}")
91
+ print("-" * 80)
92
+
93
+ for model in models_to_test:
94
+ await test_model(model, search_question)
95
+ print("-" * 40)
96
+
97
+ print()
98
+ print("💬 测试一般问题:")
99
+ print(f"问题: {general_question}")
100
+ print("-" * 80)
101
+
102
+ for model in models_to_test:
103
+ await test_model(model, general_question)
104
+ print("-" * 40)
105
+
106
+ print()
107
+ print("=" * 80)
108
+ print("测试完成")
109
+ print("=" * 80)
110
+ print()
111
+ print("📋 分析要点:")
112
+ print("1. GLM-4.5-Search 是否表现出不同的搜索行为?")
113
+ print("2. 其他模型是否都拒绝搜索请求?")
114
+ print("3. 模型响应中是否包含实际的搜索结果?")
115
+ print("4. 检查服务器日志中的MCP服务器配置是否正确")
116
+
117
+ if __name__ == "__main__":
118
+ asyncio.run(main())
tests/test_search_model.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 测试GLM-4.5-Search模型的deep-web-search MCP服务器功能
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ import httpx
11
+ from app.core.config import settings
12
+ from app.core.zai_transformer import ZAITransformer
13
+ from app.utils.logger import setup_logger
14
+
15
+ # 设置日志
16
+ logger = setup_logger(log_dir="logs", debug_mode=True)
17
+
18
+ async def test_search_model_mcp():
19
+ """测试搜索模型的MCP服务器配置"""
20
+
21
+ # 创建转换器实例
22
+ transformer = ZAITransformer()
23
+
24
+ # 模拟OpenAI请求 - 使用GLM-4.5-Search模型
25
+ openai_request = {
26
+ "model": "GLM-4.5-Search",
27
+ "messages": [
28
+ {"role": "user", "content": "请搜索一下今天的新闻"}
29
+ ],
30
+ "stream": True
31
+ }
32
+
33
+ print(f"🧪 测试请求:")
34
+ print(f" 模型: {openai_request['model']}")
35
+ print(f" SEARCH_MODEL配置: {settings.SEARCH_MODEL}")
36
+ print(f" 模型匹配: {openai_request['model'] == settings.SEARCH_MODEL}")
37
+ print()
38
+
39
+ try:
40
+ # 转换请求
41
+ transformed = await transformer.transform_request_in(openai_request)
42
+
43
+ print(f"✅ 转换成功!")
44
+ print(f" 上游模型: {transformed['body']['model']}")
45
+ print(f" MCP服务器: {transformed['body']['mcp_servers']}")
46
+ print(f" web_search特性: {transformed['body']['features']['web_search']}")
47
+ print(f" auto_web_search特性: {transformed['body']['features']['auto_web_search']}")
48
+ print()
49
+
50
+ # 检查是否正确添加了deep-web-search
51
+ mcp_servers = transformed['body']['mcp_servers']
52
+ if "deep-web-search" in mcp_servers:
53
+ print("✅ deep-web-search MCP服务器已正确添加!")
54
+ else:
55
+ print("❌ deep-web-search MCP服务器未添加!")
56
+ print(f" 实际MCP服务器列表: {mcp_servers}")
57
+
58
+ return transformed
59
+
60
+ except Exception as e:
61
+ print(f"❌ 转换失败: {e}")
62
+ return None
63
+
64
+ async def test_non_search_model():
65
+ """测试非搜索模型不应该添加MCP服务器"""
66
+
67
+ transformer = ZAITransformer()
68
+
69
+ # 模拟OpenAI请求 - 使用普通GLM-4.5模型
70
+ openai_request = {
71
+ "model": "GLM-4.5",
72
+ "messages": [
73
+ {"role": "user", "content": "你好"}
74
+ ],
75
+ "stream": True
76
+ }
77
+
78
+ print(f"🧪 测试普通模型:")
79
+ print(f" 模型: {openai_request['model']}")
80
+ print()
81
+
82
+ try:
83
+ # 转换请求
84
+ transformed = await transformer.transform_request_in(openai_request)
85
+
86
+ print(f"✅ 转换成功!")
87
+ print(f" 上游模型: {transformed['body']['model']}")
88
+ print(f" MCP服务器: {transformed['body']['mcp_servers']}")
89
+ print(f" web_search特性: {transformed['body']['features']['web_search']}")
90
+ print()
91
+
92
+ # 检查MCP服务器列表应该为空
93
+ mcp_servers = transformed['body']['mcp_servers']
94
+ if not mcp_servers:
95
+ print("✅ 普通模型正确地没有添加MCP服务器!")
96
+ else:
97
+ print(f"❌ 普通模型意外添加了MCP服务器: {mcp_servers}")
98
+
99
+ return transformed
100
+
101
+ except Exception as e:
102
+ print(f"❌ 转换失败: {e}")
103
+ return None
104
+
105
+ async def test_actual_request():
106
+ """测试实际的HTTP请求"""
107
+
108
+ print(f"🌐 测试实际HTTP请求到本地服务器...")
109
+
110
+ # 检查服务器是否运行
111
+ try:
112
+ async with httpx.AsyncClient() as client:
113
+ # 测试服务器是否可达
114
+ response = await client.get(f"http://localhost:{settings.LISTEN_PORT}/v1/models", timeout=5.0)
115
+ if response.status_code != 200:
116
+ print(f"❌ 服务器未运行或不可达: {response.status_code}")
117
+ return
118
+
119
+ print(f"✅ 服务器运行正常")
120
+
121
+ # 发送搜索模型请求
122
+ search_request = {
123
+ "model": "GLM-4.5-Search",
124
+ "messages": [
125
+ {"role": "user", "content": "搜索今天的天气"}
126
+ ],
127
+ "stream": False
128
+ }
129
+
130
+ headers = {
131
+ "Content-Type": "application/json",
132
+ "Authorization": f"Bearer {settings.AUTH_TOKEN}"
133
+ }
134
+
135
+ print(f"📤 发送搜索请求...")
136
+ response = await client.post(
137
+ f"http://localhost:{settings.LISTEN_PORT}/v1/chat/completions",
138
+ json=search_request,
139
+ headers=headers,
140
+ timeout=30.0
141
+ )
142
+
143
+ print(f"📥 响应状态: {response.status_code}")
144
+ if response.status_code == 200:
145
+ print(f"✅ 请求成功!")
146
+ # 不打印完整响应,只显示状态
147
+ else:
148
+ print(f"❌ 请求失败: {response.text}")
149
+
150
+ except httpx.ConnectError:
151
+ print(f"❌ 无法连接到服务器 localhost:{settings.LISTEN_PORT}")
152
+ print(f" 请确保服务器正在运行: python main.py")
153
+ except Exception as e:
154
+ print(f"❌ 请求异常: {e}")
155
+
156
+ async def main():
157
+ """主测试函数"""
158
+ print("=" * 60)
159
+ print("GLM-4.5-Search MCP服务器测试")
160
+ print("=" * 60)
161
+ print()
162
+
163
+ # 测试1: 搜索模型应该添加MCP服务器
164
+ await test_search_model_mcp()
165
+ print()
166
+
167
+ # 测试2: 普通模型不应该添加MCP服务器
168
+ await test_non_search_model()
169
+ print()
170
+
171
+ # 测试3: 实际HTTP请求(如果服务器运行)
172
+ await test_actual_request()
173
+ print()
174
+
175
+ print("=" * 60)
176
+ print("测试完成")
177
+ print("=" * 60)
178
+
179
+ if __name__ == "__main__":
180
+ asyncio.run(main())
tests/test_service_uniqueness.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 测试服务唯一性验证功能
6
+ """
7
+
8
+ import time
9
+ import subprocess
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ from app.core.config import settings
14
+ from app.utils.process_manager import ProcessManager, ensure_service_uniqueness
15
+ from app.utils.logger import setup_logger
16
+
17
+ # 设置日志
18
+ logger = setup_logger(log_dir="logs", debug_mode=True)
19
+
20
+
21
+ def test_process_manager():
22
+ """测试进程管理器功能"""
23
+ print("=" * 60)
24
+ print("测试进程管理器功能")
25
+ print("=" * 60)
26
+
27
+ service_name = "test-z-ai2api-server"
28
+ port = 8081
29
+
30
+ # 创建进程管理器
31
+ manager = ProcessManager(service_name=service_name, port=port)
32
+
33
+ print(f"\n1. 测试服务唯一性检查...")
34
+ print(f" 服务名称: {service_name}")
35
+ print(f" 端口: {port}")
36
+
37
+ # 第一次检查应该通过
38
+ result1 = manager.check_service_uniqueness()
39
+ print(f" 第一次检查结果: {'✅ 通过' if result1 else '❌ 失败'}")
40
+
41
+ if result1:
42
+ # 创建 PID 文件
43
+ manager.create_pid_file()
44
+ print(f" 已创建 PID 文件: {manager.pid_file}")
45
+
46
+ # 第二次检查应该失败(因为 PID 文件存在且进程运行中)
47
+ manager2 = ProcessManager(service_name=service_name, port=port)
48
+ result2 = manager2.check_service_uniqueness()
49
+ print(f" 第二次检查结果: {'✅ 通过' if result2 else '❌ 失败(预期)'}")
50
+
51
+ # 清理
52
+ manager.cleanup_on_exit()
53
+ print(f" 已清理 PID 文件")
54
+
55
+ # 第三次检查应该通过
56
+ manager3 = ProcessManager(service_name=service_name, port=port)
57
+ result3 = manager3.check_service_uniqueness()
58
+ print(f" 第三次检查结果: {'✅ 通过' if result3 else '❌ 失败'}")
59
+
60
+
61
+ def test_convenience_function():
62
+ """测试便捷函数"""
63
+ print("\n" + "=" * 60)
64
+ print("测试便捷函数")
65
+ print("=" * 60)
66
+
67
+ service_name = "test-convenience-server"
68
+ port = 8082
69
+
70
+ print(f"\n2. 测试便捷函数...")
71
+ print(f" 服务名称: {service_name}")
72
+ print(f" 端口: {port}")
73
+
74
+ # 第一次调用应该成功
75
+ result1 = ensure_service_uniqueness(service_name=service_name, port=port)
76
+ print(f" 第一次调用结果: {'✅ 成功' if result1 else '❌ 失败'}")
77
+
78
+ if result1:
79
+ # 第二次调用应该失败
80
+ result2 = ensure_service_uniqueness(service_name=service_name, port=port)
81
+ print(f" 第二次调用结果: {'✅ 成功' if result2 else '❌ 失败(预期)'}")
82
+
83
+ # 手动清理
84
+ pid_file = Path(f"{service_name}.pid")
85
+ if pid_file.exists():
86
+ pid_file.unlink()
87
+ print(f" 已手动清理 PID 文件")
88
+
89
+
90
+ def test_real_service():
91
+ """测试真实服务场景"""
92
+ print("\n" + "=" * 60)
93
+ print("测试真实服务场景")
94
+ print("=" * 60)
95
+
96
+ service_name = settings.SERVICE_NAME
97
+ port = settings.LISTEN_PORT
98
+
99
+ print(f"\n3. 测试真实服务场景...")
100
+ print(f" 服务名称: {service_name}")
101
+ print(f" 端口: {port}")
102
+
103
+ # 检查当前是否有服务运行
104
+ manager = ProcessManager(service_name=service_name, port=port)
105
+ instances = manager.get_running_instances()
106
+
107
+ if instances:
108
+ print(f" 发现 {len(instances)} 个运行中的实例:")
109
+ for instance in instances:
110
+ print(f" PID: {instance['pid']}, 启动时间: {instance['start_time']}")
111
+ else:
112
+ print(" 未发现运行中的实例")
113
+
114
+ # 测试唯一性检查
115
+ result = manager.check_service_uniqueness()
116
+ print(f" 唯一性检查结果: {'✅ 可以启动' if result else '❌ 已有实例运行'}")
117
+
118
+
119
+ def test_port_conflict():
120
+ """测试端口冲突检测"""
121
+ print("\n" + "=" * 60)
122
+ print("测试端口冲突检测")
123
+ print("=" * 60)
124
+
125
+ print(f"\n4. 测试端口冲突检测...")
126
+
127
+ # 尝试检测一些常用端口
128
+ test_ports = [80, 443, 8080, 3000, 5000]
129
+
130
+ for port in test_ports:
131
+ manager = ProcessManager(service_name="test-port-check", port=port)
132
+ is_occupied = manager._check_port_usage()
133
+ print(f" 端口 {port}: {'❌ 被占用' if is_occupied else '✅ 可用'}")
134
+
135
+
136
+ def main():
137
+ """主测试函数"""
138
+ print("🧪 Z.AI2API 服务唯一性验证测试")
139
+ print("=" * 60)
140
+ print("此测试将验证以下功能:")
141
+ print("1. 进程管理器基本功能")
142
+ print("2. 便捷函数功能")
143
+ print("3. 真实服务场景")
144
+ print("4. 端口冲突检测")
145
+ print("=" * 60)
146
+
147
+ try:
148
+ # 运行所有测试
149
+ test_process_manager()
150
+ test_convenience_function()
151
+ test_real_service()
152
+ test_port_conflict()
153
+
154
+ print("\n" + "=" * 60)
155
+ print("✅ 所有测试完成")
156
+ print("=" * 60)
157
+
158
+ print("\n📋 使用说��:")
159
+ print("1. 启动服务时会自动进行唯一性检查")
160
+ print("2. 如果检测到已有实例运行,新实例将拒绝启动")
161
+ print("3. 可以通过环境变量 SERVICE_NAME 自定义服务名称")
162
+ print("4. PID 文件会在服务正常退出时自动清理")
163
+ print("5. 异常退出的 PID 文件会在下次启动时自动清理")
164
+
165
+ except Exception as e:
166
+ logger.error(f"❌ 测试过程中发生错误: {e}")
167
+ import traceback
168
+ traceback.print_exc()
169
+ sys.exit(1)
170
+
171
+
172
+ if __name__ == "__main__":
173
+ main()
tests/test_sse_optimization.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 测试 SSE 工具调用处理器的优化效果
6
+ """
7
+
8
+ import json
9
+ import time
10
+ from app.utils.sse_tool_handler import SSEToolHandler
11
+ from app.utils.logger import get_logger
12
+
13
+ logger = get_logger()
14
+
15
+ def test_tool_call_processing():
16
+ """测试工具调用处理的优化效果"""
17
+
18
+ # 创建处理器
19
+ handler = SSEToolHandler("test_chat_id", "GLM-4.5")
20
+
21
+ # 模拟 Z.AI 的原始响应数据(基于文档中的示例)
22
+ test_data_sequence = [
23
+ # 第一个数据块 - 工具调用开始
24
+ {
25
+ "edit_index": 22,
26
+ "edit_content": '\n\n<glm_block >{"type": "mcp", "data": {"metadata": {"id": "call_fyh97tn03ow", "name": "playwri-browser_navigate", "arguments": "{\\"url\\":\\"https://www.goo',
27
+ "phase": "tool_call"
28
+ },
29
+ # 第二个数据块 - 参数补全
30
+ {
31
+ "edit_index": 176,
32
+ "edit_content": 'gle.com\\"}", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
33
+ "phase": "tool_call"
34
+ },
35
+ # 第三个数据块 - 工具调用结束
36
+ {
37
+ "edit_index": 199,
38
+ "edit_content": 'null, "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
39
+ "phase": "other"
40
+ }
41
+ ]
42
+
43
+ print("🧪 开始测试 SSE 工具调用处理器优化...")
44
+
45
+ # 处理数据序列
46
+ all_chunks = []
47
+ for i, data in enumerate(test_data_sequence):
48
+ print(f"\n📦 处理数据块 {i+1}: phase={data['phase']}, edit_index={data['edit_index']}")
49
+
50
+ if data["phase"] == "tool_call":
51
+ chunks = list(handler.process_tool_call_phase(data, is_stream=True))
52
+ else:
53
+ chunks = list(handler.process_other_phase(data, is_stream=True))
54
+
55
+ all_chunks.extend(chunks)
56
+
57
+ # 打印生成的块
58
+ for j, chunk in enumerate(chunks):
59
+ if chunk.strip():
60
+ print(f" 📤 输出块 {j+1}: {chunk[:100]}...")
61
+
62
+ print(f"\n✅ 测试完成,共生成 {len(all_chunks)} 个输出块")
63
+
64
+ # 验证工具调用是否正确解析
65
+ print(f"🔧 活跃工具数: {len(handler.active_tools)}")
66
+ print(f"✅ 完成工具数: {len(handler.completed_tools)}")
67
+
68
+ # 打印最终的内容缓冲区
69
+ try:
70
+ final_content = handler.content_buffer.decode('utf-8', errors='ignore')
71
+ print(f"\n📝 最终内容缓冲区长度: {len(final_content)}")
72
+ print(f"📝 内容预览: {final_content[:200]}...")
73
+ except Exception as e:
74
+ print(f"❌ 内容缓冲区解析失败: {e}")
75
+
76
+ def test_partial_arguments_parsing():
77
+ """测试部分参数解析功能"""
78
+
79
+ handler = SSEToolHandler("test_chat_id", "GLM-4.5")
80
+
81
+ # 测试各种不完整的参数
82
+ test_cases = [
83
+ '{"url":"https://www.goo', # 不完整的URL
84
+ '{"city":"北京', # 缺少引号和括号
85
+ '{"query":"test", "limit":', # 不完整的数值
86
+ '{"name":"test"', # 缺少结束括号
87
+ '', # 空字符串
88
+ '{', # 只有开始括号
89
+ ]
90
+
91
+ print("\n🧪 测试部分参数解析...")
92
+
93
+ for i, test_arg in enumerate(test_cases):
94
+ print(f"\n📦 测试用例 {i+1}: {test_arg}")
95
+ result = handler._parse_partial_arguments(test_arg)
96
+ print(f" ✅ 解析结果: {result}")
97
+
98
+ def test_performance():
99
+ """测试性能优化效果"""
100
+
101
+ print("\n🚀 测试性能优化效果...")
102
+
103
+ # 创建大量数据进行性能测试
104
+ handler = SSEToolHandler("test_chat_id", "GLM-4.5")
105
+
106
+ # 模拟大量的编辑操作
107
+ start_time = time.time()
108
+
109
+ for i in range(1000):
110
+ edit_data = {
111
+ "edit_index": i * 10,
112
+ "edit_content": f"test_content_{i}",
113
+ "phase": "tool_call"
114
+ }
115
+ list(handler.process_tool_call_phase(edit_data, is_stream=False))
116
+
117
+ end_time = time.time()
118
+
119
+ print(f"⏱️ 处理1000次编辑操作耗时: {end_time - start_time:.3f}秒")
120
+ print(f"📊 平均每次操作耗时: {(end_time - start_time) * 1000 / 1000:.3f}毫秒")
121
+
122
+ if __name__ == "__main__":
123
+ try:
124
+ test_tool_call_processing()
125
+ test_partial_arguments_parsing()
126
+ test_performance()
127
+ print("\n🎉 所有测试完成!")
128
+ except Exception as e:
129
+ logger.error(f"❌ 测试失败: {e}")
130
+ import traceback
131
+ traceback.print_exc()
tests/test_tool_handler_optimized.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 测试优化后的SSE工具调用处理器
6
+ 基于真实的Z.AI响应格式和日志数据进行全面测试
7
+ """
8
+
9
+ import json
10
+ import time
11
+ import traceback
12
+ from typing import List, Dict, Any
13
+ from app.utils.sse_tool_handler import SSEToolHandler
14
+ from app.utils.logger import get_logger
15
+
16
+ logger = get_logger()
17
+
18
+
19
+ class TestResult:
20
+ """测试结果类"""
21
+ def __init__(self, name: str):
22
+ self.name = name
23
+ self.passed = 0
24
+ self.failed = 0
25
+ self.errors = []
26
+
27
+ def add_pass(self):
28
+ self.passed += 1
29
+
30
+ def add_fail(self, error: str):
31
+ self.failed += 1
32
+ self.errors.append(error)
33
+
34
+ def print_summary(self):
35
+ total = self.passed + self.failed
36
+ success_rate = (self.passed / total * 100) if total > 0 else 0
37
+
38
+ print(f"\n📊 {self.name} 测试结果:")
39
+ print(f" ✅ 通过: {self.passed}")
40
+ print(f" ❌ 失败: {self.failed}")
41
+ print(f" 📈 成功率: {success_rate:.1f}%")
42
+
43
+ if self.errors:
44
+ print(f" 🔍 错误详情:")
45
+ for i, error in enumerate(self.errors, 1):
46
+ print(f" {i}. {error}")
47
+
48
+
49
+ def parse_openai_chunk(chunk_data: str) -> Dict[str, Any]:
50
+ """解析OpenAI格式的chunk数据"""
51
+ try:
52
+ if chunk_data.startswith("data: "):
53
+ chunk_data = chunk_data[6:] # 移除 "data: " 前缀
54
+ if chunk_data.strip() == "[DONE]":
55
+ return {"type": "done"}
56
+ return json.loads(chunk_data)
57
+ except json.JSONDecodeError:
58
+ return {"type": "invalid", "raw": chunk_data}
59
+
60
+
61
+ def extract_tool_calls(chunks: List[str]) -> List[Dict[str, Any]]:
62
+ """从chunk列表中提取工具调用信息"""
63
+ tools = []
64
+ current_tool = None
65
+
66
+ for chunk in chunks:
67
+ parsed = parse_openai_chunk(chunk)
68
+ if parsed.get("type") == "invalid":
69
+ continue
70
+
71
+ choices = parsed.get("choices", [])
72
+ if not choices:
73
+ continue
74
+
75
+ delta = choices[0].get("delta", {})
76
+ tool_calls = delta.get("tool_calls", [])
77
+
78
+ for tc in tool_calls:
79
+ if tc.get("function", {}).get("name"): # 新工具开始
80
+ current_tool = {
81
+ "id": tc.get("id"),
82
+ "name": tc["function"]["name"],
83
+ "arguments": ""
84
+ }
85
+ tools.append(current_tool)
86
+ elif tc.get("function", {}).get("arguments") and current_tool: # 参数累积
87
+ current_tool["arguments"] += tc["function"]["arguments"]
88
+
89
+ # 解析最终参数
90
+ for tool in tools:
91
+ try:
92
+ tool["parsed_arguments"] = json.loads(tool["arguments"]) if tool["arguments"] else {}
93
+ except json.JSONDecodeError:
94
+ tool["parsed_arguments"] = {}
95
+
96
+ return tools
97
+
98
+
99
+ def test_real_world_scenarios():
100
+ """测试基于真实Z.AI响应的工具调用处理"""
101
+
102
+ result = TestResult("真实场景测试")
103
+
104
+ # 基于实际日志的测试数据
105
+ test_scenarios = [
106
+ {
107
+ "name": "浏览器导航工具调用",
108
+ "description": "模拟打开Google网站的工具调用",
109
+ "expected_tools": [
110
+ {
111
+ "name": "playwri-browser_navigate",
112
+ "id": "call_fyh97tn03ow",
113
+ "arguments": {"url": "https://www.google.com"}
114
+ }
115
+ ],
116
+ "data_sequence": [
117
+ {
118
+ "edit_index": 22,
119
+ "edit_content": '\n\n<glm_block >{"type": "mcp", "data": {"metadata": {"id": "call_fyh97tn03ow", "name": "playwri-browser_navigate", "arguments": "{\\"url\\":\\"https://www.goo',
120
+ "phase": "tool_call"
121
+ },
122
+ {
123
+ "edit_index": 176,
124
+ "edit_content": 'gle.com\\"}", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
125
+ "phase": "tool_call"
126
+ },
127
+ {
128
+ "edit_index": 199,
129
+ "edit_content": 'null, "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
130
+ "phase": "other"
131
+ }
132
+ ]
133
+ },
134
+ {
135
+ "name": "天气查询工具调用",
136
+ "description": "模拟查询上海天气的工具调用",
137
+ "expected_tools": [
138
+ {
139
+ "name": "search",
140
+ "id": "call_qsn2jby8al",
141
+ "arguments": {"queries": ["今天上海天气", "上海天气预报 今天"]}
142
+ }
143
+ ],
144
+ "data_sequence": [
145
+ {
146
+ "edit_index": 16,
147
+ "edit_content": '\n\n<glm_block >{"type": "mcp", "data": {"metadata": {"id": "call_qsn2jby8al", "name": "search", "arguments": "{\\"queries\\":[\\"今天上海天气\\", \\"',
148
+ "phase": "tool_call"
149
+ },
150
+ {
151
+ "edit_index": 183,
152
+ "edit_content": '上海天气预报 今天\\"]}", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
153
+ "phase": "tool_call"
154
+ }
155
+ ]
156
+ },
157
+ {
158
+ "name": "多工具调用序列",
159
+ "description": "模拟连续的多个工具调用",
160
+ "expected_tools": [
161
+ {
162
+ "name": "search",
163
+ "id": "call_001",
164
+ "arguments": {"query": "北京天气"}
165
+ },
166
+ {
167
+ "name": "visit_page",
168
+ "id": "call_002",
169
+ "arguments": {"url": "https://weather.com"}
170
+ }
171
+ ],
172
+ "data_sequence": [
173
+ {
174
+ "edit_index": 0,
175
+ "edit_content": '<glm_block >{"type": "mcp", "data": {"metadata": {"id": "call_001", "name": "search", "arguments": "{\\"query\\":\\"北京天气\\"}", "result": "", "status": "completed"}}, "thought": null}}</glm_block>',
176
+ "phase": "tool_call"
177
+ },
178
+ {
179
+ "edit_index": 200,
180
+ "edit_content": '\n\n<glm_block >{"type": "mcp", "data": {"metadata": {"id": "call_002", "name": "visit_page", "arguments": "{\\"url\\":\\"https://weather.com\\"}", "result": "", "status": "completed"}}, "thought": null}}</glm_block>',
181
+ "phase": "tool_call"
182
+ }
183
+ ]
184
+ }
185
+ ]
186
+
187
+ print(f"\n🧪 开始执行 {len(test_scenarios)} 个真实场景测试...")
188
+
189
+ # 执行每个测试场景
190
+ for i, scenario in enumerate(test_scenarios, 1):
191
+ print(f"\n{'='*60}")
192
+ print(f"测试 {i}: {scenario['name']}")
193
+ print(f"描述: {scenario['description']}")
194
+ print('='*60)
195
+
196
+ try:
197
+ # 创建新的处理器实例
198
+ handler = SSEToolHandler("test_chat_id", "GLM-4.5")
199
+
200
+ # 处理数据序列
201
+ all_chunks = []
202
+ for j, data in enumerate(scenario["data_sequence"]):
203
+ print(f"\n📦 处理数据块 {j+1}: phase={data['phase']}, edit_index={data['edit_index']}")
204
+
205
+ if data["phase"] == "tool_call":
206
+ chunks = list(handler.process_tool_call_phase(data, is_stream=True))
207
+ else:
208
+ chunks = list(handler.process_other_phase(data, is_stream=True))
209
+
210
+ all_chunks.extend(chunks)
211
+
212
+ # 提取工具调用信息
213
+ extracted_tools = extract_tool_calls(all_chunks)
214
+
215
+ # 验证结果
216
+ expected_tools = scenario["expected_tools"]
217
+
218
+ print(f"\n📊 验证结果:")
219
+ print(f" 期望工具数: {len(expected_tools)}")
220
+ print(f" 实际工具数: {len(extracted_tools)}")
221
+
222
+ # 详细验证每个工具
223
+ for k, expected_tool in enumerate(expected_tools):
224
+ if k < len(extracted_tools):
225
+ actual_tool = extracted_tools[k]
226
+
227
+ # 验证工具名称
228
+ name_match = actual_tool["name"] == expected_tool["name"]
229
+ # 验证工具ID
230
+ id_match = actual_tool["id"] == expected_tool["id"]
231
+ # 验证参数
232
+ args_match = actual_tool["parsed_arguments"] == expected_tool["arguments"]
233
+
234
+ if name_match and id_match and args_match:
235
+ print(f" ✅ 工具 {k+1}: {expected_tool['name']} - 验证通过")
236
+ result.add_pass()
237
+ else:
238
+ error_details = []
239
+ if not name_match:
240
+ error_details.append(f"名称不匹配: 期望'{expected_tool['name']}', 实际'{actual_tool['name']}'")
241
+ if not id_match:
242
+ error_details.append(f"ID不匹配: 期望'{expected_tool['id']}', 实际'{actual_tool['id']}'")
243
+ if not args_match:
244
+ error_details.append(f"参数不匹配: 期望{expected_tool['arguments']}, 实际{actual_tool['parsed_arguments']}")
245
+
246
+ error_msg = f"工具 {k+1} 验证失败: {'; '.join(error_details)}"
247
+ print(f" ❌ {error_msg}")
248
+ result.add_fail(error_msg)
249
+ else:
250
+ error_msg = f"缺少工具 {k+1}: {expected_tool['name']}"
251
+ print(f" ❌ {error_msg}")
252
+ result.add_fail(error_msg)
253
+
254
+ # 显示提取的工具详情
255
+ if extracted_tools:
256
+ print(f"\n🔍 提取的工具详情:")
257
+ for tool in extracted_tools:
258
+ print(f" - {tool['name']}(id={tool['id']})")
259
+ print(f" 参数: {tool['parsed_arguments']}")
260
+
261
+ except Exception as e:
262
+ error_msg = f"测试 {scenario['name']} 执行失败: {str(e)}"
263
+ print(f"❌ {error_msg}")
264
+ result.add_fail(error_msg)
265
+ logger.error(f"测试执行异常: {e}")
266
+
267
+ result.print_summary()
268
+ return result
269
+
270
+
271
+ def test_edge_cases():
272
+ """测试边界情况和异常处理"""
273
+
274
+ result = TestResult("边界情况测试")
275
+
276
+ edge_cases = [
277
+ {
278
+ "name": "空内容处理",
279
+ "data": {"edit_index": 0, "edit_content": "", "phase": "tool_call"},
280
+ "should_pass": True
281
+ },
282
+ {
283
+ "name": "无效JSON处理",
284
+ "data": {"edit_index": 0, "edit_content": '<glm_block >{"invalid": json}}</glm_block>', "phase": "tool_call"},
285
+ "should_pass": True # 应该优雅处理,不崩溃
286
+ },
287
+ {
288
+ "name": "不完整的glm_block",
289
+ "data": {"edit_index": 0, "edit_content": '<glm_block >{"type": "mcp", "data": {"metadata": {"id": "test"', "phase": "tool_call"},
290
+ "should_pass": True
291
+ },
292
+ {
293
+ "name": "超大edit_index",
294
+ "data": {"edit_index": 999999, "edit_content": "test", "phase": "tool_call"},
295
+ "should_pass": True
296
+ },
297
+ {
298
+ "name": "特殊字符处理",
299
+ "data": {"edit_index": 0, "edit_content": '<glm_block >{"type": "mcp", "data": {"metadata": {"id": "test", "name": "test", "arguments": "{\\"text\\":\\"测试\\u4e2d\\u6587\\"}"}}}</glm_block>', "phase": "tool_call"},
300
+ "should_pass": True
301
+ }
302
+ ]
303
+
304
+ print(f"\n🧪 开始执行 {len(edge_cases)} 个边界情况测试...")
305
+
306
+ for i, case in enumerate(edge_cases, 1):
307
+ print(f"\n📦 测试 {i}: {case['name']}")
308
+
309
+ try:
310
+ handler = SSEToolHandler("test_chat_id", "GLM-4.5")
311
+
312
+ # 处理数据
313
+ if case["data"]["phase"] == "tool_call":
314
+ chunks = list(handler.process_tool_call_phase(case["data"], is_stream=True))
315
+ else:
316
+ chunks = list(handler.process_other_phase(case["data"], is_stream=True))
317
+
318
+ # 检查是否按预期处理
319
+ if case["should_pass"]:
320
+ print(f" ✅ 成功处理,生成 {len(chunks)} 个输出块")
321
+ result.add_pass()
322
+ else:
323
+ print(f" ❌ 应该失败但成功了")
324
+ result.add_fail(f"{case['name']}: 应该失败但成功了")
325
+
326
+ except Exception as e:
327
+ if case["should_pass"]:
328
+ error_msg = f"{case['name']}: 意外异常 - {str(e)}"
329
+ print(f" ❌ {error_msg}")
330
+ result.add_fail(error_msg)
331
+ else:
332
+ print(f" ✅ 按预期失败: {str(e)}")
333
+ result.add_pass()
334
+
335
+ result.print_summary()
336
+ return result
337
+
338
+
339
+ def test_performance():
340
+ """测试性能表现"""
341
+
342
+ result = TestResult("性能测试")
343
+
344
+ print(f"\n🚀 开始性能测试...")
345
+
346
+ # 测试大量小块数据的处理性能
347
+ handler = SSEToolHandler("test_chat_id", "GLM-4.5")
348
+
349
+ start_time = time.time()
350
+
351
+ # 模拟1000次小的编辑操作
352
+ for i in range(1000):
353
+ data = {
354
+ "edit_index": i * 5,
355
+ "edit_content": f"chunk_{i}",
356
+ "phase": "tool_call"
357
+ }
358
+ list(handler.process_tool_call_phase(data, is_stream=False))
359
+
360
+ end_time = time.time()
361
+ duration = end_time - start_time
362
+
363
+ print(f"⏱️ 处理1000次编辑操作耗时: {duration:.3f}秒")
364
+ print(f"📊 平均每次操作耗时: {duration * 1000 / 1000:.3f}毫秒")
365
+
366
+ # 性能基准:每次操作应该在1毫秒以内
367
+ if duration < 1.0: # 1秒内完成1000次操作
368
+ print("✅ 性能测试通过")
369
+ result.add_pass()
370
+ else:
371
+ error_msg = f"性能测试失败: 耗时{duration:.3f}秒,超过1秒基准"
372
+ print(f"❌ {error_msg}")
373
+ result.add_fail(error_msg)
374
+
375
+ result.print_summary()
376
+ return result
377
+
378
+
379
+ def test_argument_parsing():
380
+ """测试参数解析功能"""
381
+
382
+ result = TestResult("参数解析测试")
383
+
384
+ print(f"\n🧪 开始参数解析测试...")
385
+
386
+ handler = SSEToolHandler("test", "test")
387
+
388
+ test_cases = [
389
+ ('{"city": "北京"}', {"city": "北京"}),
390
+ ('{"city": "北京', {"city": "北京"}), # 缺少闭合括号
391
+ ('{"city": "北京"', {"city": "北京"}), # 缺少闭合括号但有引号
392
+ ('{\\"city\\": \\"北京\\"}', {"city": "北京"}), # 转义的JSON
393
+ ('{}', {}), # 空参数
394
+ ('null', {}), # null参数
395
+ ('{"array": [1,2,3], "nested": {"key": "value"}}', {"array": [1,2,3], "nested": {"key": "value"}}), # 复杂参数
396
+ ('{"url":"https://www.goo', {"url": "https://www.goo"}), # 不完整的URL
397
+ ('', {}), # 空字符串
398
+ ('{', {}), # 只有开始括号
399
+ ]
400
+
401
+ for i, (input_str, expected) in enumerate(test_cases, 1):
402
+ try:
403
+ parsed_result = handler._parse_partial_arguments(input_str)
404
+ success = parsed_result == expected
405
+
406
+ if success:
407
+ print(f"✅ 测试 {i}: 解析成功")
408
+ result.add_pass()
409
+ else:
410
+ error_msg = f"测试 {i} 失败: 输入'{input_str[:30]}...', 期望{expected}, 实际{parsed_result}"
411
+ print(f"❌ {error_msg}")
412
+ result.add_fail(error_msg)
413
+
414
+ except Exception as e:
415
+ error_msg = f"测试 {i} 异常: 输入'{input_str[:30]}...', 错误: {str(e)}"
416
+ print(f"❌ {error_msg}")
417
+ result.add_fail(error_msg)
418
+
419
+ result.print_summary()
420
+ return result
421
+
422
+
423
+ def run_all_tests():
424
+ """运行所有测试"""
425
+
426
+ print("🧪 SSE工具调用处理器优化测试套件")
427
+ print("="*60)
428
+
429
+ all_results = []
430
+
431
+ try:
432
+ # 运行真实场景测试
433
+ print("\n1️⃣ 真实场景测试")
434
+ all_results.append(test_real_world_scenarios())
435
+
436
+ # 运行边界情况测试
437
+ print("\n2️⃣ 边界情况测试")
438
+ all_results.append(test_edge_cases())
439
+
440
+ # 运行参数解析测试
441
+ print("\n3️⃣ 参数解析测试")
442
+ all_results.append(test_argument_parsing())
443
+
444
+ # 运行性能测试
445
+ print("\n4️⃣ 性能测试")
446
+ all_results.append(test_performance())
447
+
448
+ # 汇总结果
449
+ print("\n" + "="*60)
450
+ print("📊 测试汇总")
451
+ print("="*60)
452
+
453
+ total_passed = sum(r.passed for r in all_results)
454
+ total_failed = sum(r.failed for r in all_results)
455
+ total_tests = total_passed + total_failed
456
+
457
+ print(f"总测试数: {total_tests}")
458
+ print(f"✅ 通过: {total_passed}")
459
+ print(f"❌ 失败: {total_failed}")
460
+
461
+ if total_tests > 0:
462
+ success_rate = (total_passed / total_tests) * 100
463
+ print(f"📈 总体成功率: {success_rate:.1f}%")
464
+
465
+ if success_rate >= 90:
466
+ print("🎉 测试结果优秀!")
467
+ elif success_rate >= 70:
468
+ print("👍 测试结果良好")
469
+ else:
470
+ print("⚠️ 需要改进")
471
+
472
+ # 显示失败的测试
473
+ failed_tests = []
474
+ for result in all_results:
475
+ failed_tests.extend(result.errors)
476
+
477
+ if failed_tests:
478
+ print(f"\n🔍 失败测试详情:")
479
+ for i, error in enumerate(failed_tests, 1):
480
+ print(f" {i}. {error}")
481
+
482
+ return total_failed == 0
483
+
484
+ except Exception as e:
485
+ print(f"❌ 测试套件执行失败: {e}")
486
+ traceback.print_exc()
487
+ return False
488
+
489
+
490
+ if __name__ == "__main__":
491
+ success = run_all_tests()
492
+ exit(0 if success else 1)