#!/usr/bin/env python # -*- coding: utf-8 -*- """ 测试优化后的SSE工具调用处理器 基于真实的Z.AI响应格式和日志数据进行全面测试 """ import json import time import traceback from typing import List, Dict, Any from app.utils.sse_tool_handler import SSEToolHandler from app.utils.logger import get_logger logger = get_logger() class TestResult: """测试结果类""" def __init__(self, name: str): self.name = name self.passed = 0 self.failed = 0 self.errors = [] def add_pass(self): self.passed += 1 def add_fail(self, error: str): self.failed += 1 self.errors.append(error) def print_summary(self): total = self.passed + self.failed success_rate = (self.passed / total * 100) if total > 0 else 0 print(f"\n📊 {self.name} 测试结果:") print(f" ✅ 通过: {self.passed}") print(f" ❌ 失败: {self.failed}") print(f" 📈 成功率: {success_rate:.1f}%") if self.errors: print(f" 🔍 错误详情:") for i, error in enumerate(self.errors, 1): print(f" {i}. {error}") def parse_openai_chunk(chunk_data: str) -> Dict[str, Any]: """解析OpenAI格式的chunk数据""" try: if chunk_data.startswith("data: "): chunk_data = chunk_data[6:] # 移除 "data: " 前缀 if chunk_data.strip() == "[DONE]": return {"type": "done"} return json.loads(chunk_data) except json.JSONDecodeError: return {"type": "invalid", "raw": chunk_data} def extract_tool_calls(chunks: List[str]) -> List[Dict[str, Any]]: """从chunk列表中提取工具调用信息""" tools = [] current_tool = None for chunk in chunks: parsed = parse_openai_chunk(chunk) if parsed.get("type") == "invalid": continue choices = parsed.get("choices", []) if not choices: continue delta = choices[0].get("delta", {}) tool_calls = delta.get("tool_calls", []) for tc in tool_calls: if tc.get("function", {}).get("name"): # 新工具开始 current_tool = { "id": tc.get("id"), "name": tc["function"]["name"], "arguments": "" } tools.append(current_tool) elif tc.get("function", {}).get("arguments") and current_tool: # 参数累积 current_tool["arguments"] += tc["function"]["arguments"] # 解析最终参数 for tool in tools: try: tool["parsed_arguments"] = json.loads(tool["arguments"]) if tool["arguments"] else {} except json.JSONDecodeError: tool["parsed_arguments"] = {} return tools def test_real_world_scenarios(): """测试基于真实Z.AI响应的工具调用处理""" result = TestResult("真实场景测试") # 基于实际日志的测试数据 test_scenarios = [ { "name": "浏览器导航工具调用", "description": "模拟打开Google网站的工具调用", "expected_tools": [ { "name": "playwri-browser_navigate", "id": "call_fyh97tn03ow", "arguments": {"url": "https://www.google.com"} } ], "data_sequence": [ { "edit_index": 22, "edit_content": '\n\n{"type": "mcp", "data": {"metadata": {"id": "call_fyh97tn03ow", "name": "playwri-browser_navigate", "arguments": "{\\"url\\":\\"https://www.goo', "phase": "tool_call" }, { "edit_index": 176, "edit_content": 'gle.com\\"}", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}', "phase": "tool_call" }, { "edit_index": 199, "edit_content": 'null, "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}', "phase": "other" } ] }, { "name": "天气查询工具调用", "description": "模拟查询上海天气的工具调用", "expected_tools": [ { "name": "search", "id": "call_qsn2jby8al", "arguments": {"queries": ["今天上海天气", "上海天气预报 今天"]} } ], "data_sequence": [ { "edit_index": 16, "edit_content": '\n\n{"type": "mcp", "data": {"metadata": {"id": "call_qsn2jby8al", "name": "search", "arguments": "{\\"queries\\":[\\"今天上海天气\\", \\"', "phase": "tool_call" }, { "edit_index": 183, "edit_content": '上海天气预报 今天\\"]}", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}', "phase": "tool_call" } ] }, { "name": "多工具调用序列", "description": "模拟连续的多个工具调用", "expected_tools": [ { "name": "search", "id": "call_001", "arguments": {"query": "北京天气"} }, { "name": "visit_page", "id": "call_002", "arguments": {"url": "https://weather.com"} } ], "data_sequence": [ { "edit_index": 0, "edit_content": '{"type": "mcp", "data": {"metadata": {"id": "call_001", "name": "search", "arguments": "{\\"query\\":\\"北京天气\\"}", "result": "", "status": "completed"}}, "thought": null}}', "phase": "tool_call" }, { "edit_index": 200, "edit_content": '\n\n{"type": "mcp", "data": {"metadata": {"id": "call_002", "name": "visit_page", "arguments": "{\\"url\\":\\"https://weather.com\\"}", "result": "", "status": "completed"}}, "thought": null}}', "phase": "tool_call" } ] } ] print(f"\n🧪 开始执行 {len(test_scenarios)} 个真实场景测试...") # 执行每个测试场景 for i, scenario in enumerate(test_scenarios, 1): print(f"\n{'='*60}") print(f"测试 {i}: {scenario['name']}") print(f"描述: {scenario['description']}") print('='*60) try: # 创建新的处理器实例 handler = SSEToolHandler("test_chat_id", "GLM-4.5") # 处理数据序列 all_chunks = [] for j, data in enumerate(scenario["data_sequence"]): print(f"\n📦 处理数据块 {j+1}: phase={data['phase']}, edit_index={data['edit_index']}") if data["phase"] == "tool_call": chunks = list(handler.process_tool_call_phase(data, is_stream=True)) else: chunks = list(handler.process_other_phase(data, is_stream=True)) all_chunks.extend(chunks) # 提取工具调用信息 extracted_tools = extract_tool_calls(all_chunks) # 验证结果 expected_tools = scenario["expected_tools"] print(f"\n📊 验证结果:") print(f" 期望工具数: {len(expected_tools)}") print(f" 实际工具数: {len(extracted_tools)}") # 详细验证每个工具 for k, expected_tool in enumerate(expected_tools): if k < len(extracted_tools): actual_tool = extracted_tools[k] # 验证工具名称 name_match = actual_tool["name"] == expected_tool["name"] # 验证工具ID id_match = actual_tool["id"] == expected_tool["id"] # 验证参数 args_match = actual_tool["parsed_arguments"] == expected_tool["arguments"] if name_match and id_match and args_match: print(f" ✅ 工具 {k+1}: {expected_tool['name']} - 验证通过") result.add_pass() else: error_details = [] if not name_match: error_details.append(f"名称不匹配: 期望'{expected_tool['name']}', 实际'{actual_tool['name']}'") if not id_match: error_details.append(f"ID不匹配: 期望'{expected_tool['id']}', 实际'{actual_tool['id']}'") if not args_match: error_details.append(f"参数不匹配: 期望{expected_tool['arguments']}, 实际{actual_tool['parsed_arguments']}") error_msg = f"工具 {k+1} 验证失败: {'; '.join(error_details)}" print(f" ❌ {error_msg}") result.add_fail(error_msg) else: error_msg = f"缺少工具 {k+1}: {expected_tool['name']}" print(f" ❌ {error_msg}") result.add_fail(error_msg) # 显示提取的工具详情 if extracted_tools: print(f"\n🔍 提取的工具详情:") for tool in extracted_tools: print(f" - {tool['name']}(id={tool['id']})") print(f" 参数: {tool['parsed_arguments']}") except Exception as e: error_msg = f"测试 {scenario['name']} 执行失败: {str(e)}" print(f"❌ {error_msg}") result.add_fail(error_msg) logger.error(f"测试执行异常: {e}") result.print_summary() return result def test_edge_cases(): """测试边界情况和异常处理""" result = TestResult("边界情况测试") edge_cases = [ { "name": "空内容处理", "data": {"edit_index": 0, "edit_content": "", "phase": "tool_call"}, "should_pass": True }, { "name": "无效JSON处理", "data": {"edit_index": 0, "edit_content": '{"invalid": json}}', "phase": "tool_call"}, "should_pass": True # 应该优雅处理,不崩溃 }, { "name": "不完整的glm_block", "data": {"edit_index": 0, "edit_content": '{"type": "mcp", "data": {"metadata": {"id": "test"', "phase": "tool_call"}, "should_pass": True }, { "name": "超大edit_index", "data": {"edit_index": 999999, "edit_content": "test", "phase": "tool_call"}, "should_pass": True }, { "name": "特殊字符处理", "data": {"edit_index": 0, "edit_content": '{"type": "mcp", "data": {"metadata": {"id": "test", "name": "test", "arguments": "{\\"text\\":\\"测试\\u4e2d\\u6587\\"}"}}}', "phase": "tool_call"}, "should_pass": True } ] print(f"\n🧪 开始执行 {len(edge_cases)} 个边界情况测试...") for i, case in enumerate(edge_cases, 1): print(f"\n📦 测试 {i}: {case['name']}") try: handler = SSEToolHandler("test_chat_id", "GLM-4.5") # 处理数据 if case["data"]["phase"] == "tool_call": chunks = list(handler.process_tool_call_phase(case["data"], is_stream=True)) else: chunks = list(handler.process_other_phase(case["data"], is_stream=True)) # 检查是否按预期处理 if case["should_pass"]: print(f" ✅ 成功处理,生成 {len(chunks)} 个输出块") result.add_pass() else: print(f" ❌ 应该失败但成功了") result.add_fail(f"{case['name']}: 应该失败但成功了") except Exception as e: if case["should_pass"]: error_msg = f"{case['name']}: 意外异常 - {str(e)}" print(f" ❌ {error_msg}") result.add_fail(error_msg) else: print(f" ✅ 按预期失败: {str(e)}") result.add_pass() result.print_summary() return result def test_performance(): """测试性能表现""" result = TestResult("性能测试") print(f"\n🚀 开始性能测试...") # 测试大量小块数据的处理性能 handler = SSEToolHandler("test_chat_id", "GLM-4.5") start_time = time.time() # 模拟1000次小的编辑操作 for i in range(1000): data = { "edit_index": i * 5, "edit_content": f"chunk_{i}", "phase": "tool_call" } list(handler.process_tool_call_phase(data, is_stream=False)) end_time = time.time() duration = end_time - start_time print(f"⏱️ 处理1000次编辑操作耗时: {duration:.3f}秒") print(f"📊 平均每次操作耗时: {duration * 1000 / 1000:.3f}毫秒") # 性能基准:每次操作应该在1毫秒以内 if duration < 1.0: # 1秒内完成1000次操作 print("✅ 性能测试通过") result.add_pass() else: error_msg = f"性能测试失败: 耗时{duration:.3f}秒,超过1秒基准" print(f"❌ {error_msg}") result.add_fail(error_msg) result.print_summary() return result def test_argument_parsing(): """测试参数解析功能""" result = TestResult("参数解析测试") print(f"\n🧪 开始参数解析测试...") handler = SSEToolHandler("test", "test") test_cases = [ ('{"city": "北京"}', {"city": "北京"}), ('{"city": "北京', {"city": "北京"}), # 缺少闭合括号 ('{"city": "北京"', {"city": "北京"}), # 缺少闭合括号但有引号 ('{\\"city\\": \\"北京\\"}', {"city": "北京"}), # 转义的JSON ('{}', {}), # 空参数 ('null', {}), # null参数 ('{"array": [1,2,3], "nested": {"key": "value"}}', {"array": [1,2,3], "nested": {"key": "value"}}), # 复杂参数 ('{"url":"https://www.goo', {"url": "https://www.goo"}), # 不完整的URL ('', {}), # 空字符串 ('{', {}), # 只有开始括号 ] for i, (input_str, expected) in enumerate(test_cases, 1): try: parsed_result = handler._parse_partial_arguments(input_str) success = parsed_result == expected if success: print(f"✅ 测试 {i}: 解析成功") result.add_pass() else: error_msg = f"测试 {i} 失败: 输入'{input_str[:30]}...', 期望{expected}, 实际{parsed_result}" print(f"❌ {error_msg}") result.add_fail(error_msg) except Exception as e: error_msg = f"测试 {i} 异常: 输入'{input_str[:30]}...', 错误: {str(e)}" print(f"❌ {error_msg}") result.add_fail(error_msg) result.print_summary() return result def run_all_tests(): """运行所有测试""" print("🧪 SSE工具调用处理器优化测试套件") print("="*60) all_results = [] try: # 运行真实场景测试 print("\n1️⃣ 真实场景测试") all_results.append(test_real_world_scenarios()) # 运行边界情况测试 print("\n2️⃣ 边界情况测试") all_results.append(test_edge_cases()) # 运行参数解析测试 print("\n3️⃣ 参数解析测试") all_results.append(test_argument_parsing()) # 运行性能测试 print("\n4️⃣ 性能测试") all_results.append(test_performance()) # 汇总结果 print("\n" + "="*60) print("📊 测试汇总") print("="*60) total_passed = sum(r.passed for r in all_results) total_failed = sum(r.failed for r in all_results) total_tests = total_passed + total_failed print(f"总测试数: {total_tests}") print(f"✅ 通过: {total_passed}") print(f"❌ 失败: {total_failed}") if total_tests > 0: success_rate = (total_passed / total_tests) * 100 print(f"📈 总体成功率: {success_rate:.1f}%") if success_rate >= 90: print("🎉 测试结果优秀!") elif success_rate >= 70: print("👍 测试结果良好") else: print("⚠️ 需要改进") # 显示失败的测试 failed_tests = [] for result in all_results: failed_tests.extend(result.errors) if failed_tests: print(f"\n🔍 失败测试详情:") for i, error in enumerate(failed_tests, 1): print(f" {i}. {error}") return total_failed == 0 except Exception as e: print(f"❌ 测试套件执行失败: {e}") traceback.print_exc() return False if __name__ == "__main__": success = run_all_tests() exit(0 if success else 1)