#!/usr/bin/env python # -*- coding: utf-8 -*- """ 全面测试 ZAI Provider 修复效果 验证流式输出、工具调用、思考模式、重试机制等功能 """ import asyncio import json import sys import os # 添加项目根目录到路径 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from app.providers.zai_provider import ZAIProvider from app.models.schemas import OpenAIRequest, Message from app.core.config import settings async def test_basic_stream(): """测试基本流式输出""" print("🧪 测试基本流式输出...") provider = ZAIProvider() request = OpenAIRequest( model=settings.PRIMARY_MODEL, messages=[ Message(role="user", content="你好,请简单介绍一下自己") ], stream=True ) try: response = await provider.chat_completion(request) if hasattr(response, '__aiter__'): print("✅ 返回了异步生成器") chunk_count = 0 content_chunks = [] async for chunk in response: chunk_count += 1 if chunk.startswith("data: ") and not chunk.strip().endswith("[DONE]"): try: chunk_data = json.loads(chunk[6:].strip()) if "choices" in chunk_data and chunk_data["choices"]: choice = chunk_data["choices"][0] if "delta" in choice and "content" in choice["delta"]: content = choice["delta"]["content"] if content: content_chunks.append(content) except: pass if chunk_count >= 10: # 限制测试长度 break full_content = "".join(content_chunks) print(f"✅ 成功处理了 {chunk_count} 个数据块") print(f"📝 内容预览: {full_content[:100]}...") return len(content_chunks) > 0 else: print("❌ 返回的不是流式响应") return False except Exception as e: print(f"❌ 基本流式测试失败: {e}") return False async def test_thinking_mode(): """测试思考模式""" print("\n🧪 测试思考模式...") provider = ZAIProvider() request = OpenAIRequest( model=settings.THINKING_MODEL, messages=[ Message(role="user", content="请解释一下量子计算的基本原理") ], stream=True ) try: response = await provider.chat_completion(request) if hasattr(response, '__aiter__'): print("✅ 返回了异步生成器") chunk_count = 0 has_thinking = False has_content = False async for chunk in response: chunk_count += 1 # 检查是否包含思考内容 if 'thinking' in chunk: has_thinking = True print("✅ 检测到思考内容") # 检查是否包含普通内容 if '"content"' in chunk and '"thinking"' not in chunk: has_content = True print("✅ 检测到答案内容") if chunk_count >= 20: # 限制测试长度 break print(f"✅ 成功处理了 {chunk_count} 个数据块") print(f"🤔 思考模式: {'正常' if has_thinking else '未检测到'}") print(f"💬 答案内容: {'正常' if has_content else '未检测到'}") return True else: print("❌ 返回的不是流式响应") return False except Exception as e: print(f"❌ 思考模式测试失败: {e}") return False async def test_tool_support(): """测试工具调用支持""" print("\n🧪 测试工具调用支持...") if not settings.TOOL_SUPPORT: print("⚠️ 工具支持已禁用,跳过测试") return True provider = ZAIProvider() # 简单的工具定义 tools = [ { "type": "function", "function": { "name": "get_weather", "description": "获取指定城市的天气信息", "parameters": { "type": "object", "properties": { "city": { "type": "string", "description": "城市名称" } }, "required": ["city"] } } } ] request = OpenAIRequest( model=settings.PRIMARY_MODEL, messages=[ Message(role="user", content="请帮我查询北京的天气") ], tools=tools, stream=True ) try: response = await provider.chat_completion(request) if hasattr(response, '__aiter__'): print("✅ 返回了异步生成器") chunk_count = 0 has_tool_call = False async for chunk in response: chunk_count += 1 # 检查是否包含工具调用 if 'tool_calls' in chunk: has_tool_call = True print("✅ 检测到工具调用") if chunk_count >= 30: # 限制测试长度 break print(f"✅ 成功处理了 {chunk_count} 个数据块") print(f"🔧 工具调用: {'正常' if has_tool_call else '未检测到'}") return True else: print("❌ 返回的不是流式响应") return False except Exception as e: print(f"❌ 工具调用测试失败: {e}") return False async def test_error_handling(): """测试错误处理""" print("\n🧪 测试错误处理...") provider = ZAIProvider() # 使用无效的消息来触发错误 request = OpenAIRequest( model="invalid-model", messages=[ Message(role="user", content="测试错误处理") ], stream=True ) try: response = await provider.chat_completion(request) if hasattr(response, '__aiter__'): chunk_count = 0 has_error = False async for chunk in response: chunk_count += 1 # 检查是否包含错误信息 if 'error' in chunk: has_error = True print("✅ 检测到错误处理") if chunk_count >= 5: # 限制测试长度 break print(f"✅ 错误处理测试完成,处理了 {chunk_count} 个数据块") return True else: print("✅ 返回了错误响应(非流式)") return True except Exception as e: print(f"✅ 正确捕获了异常: {type(e).__name__}") return True async def main(): """主测试函数""" print("🚀 开始全面测试 ZAI Provider 修复效果\n") # 显示配置信息 print("📋 当前配置:") print(f" - 匿名模式: {settings.ANONYMOUS_MODE}") print(f" - 工具支持: {settings.TOOL_SUPPORT}") print(f" - 最大重试: {settings.MAX_RETRIES}") print(f" - 重试延迟: {settings.RETRY_DELAY}s") print() tests = [ ("基本流式输出", test_basic_stream), ("思考模式", test_thinking_mode), ("工具调用支持", test_tool_support), ("错误处理", test_error_handling), ] passed = 0 total = len(tests) for test_name, test_func in tests: try: print(f"{'='*50}") result = await test_func() if result: passed += 1 print(f"✅ {test_name} 测试通过") else: print(f"❌ {test_name} 测试失败") except Exception as e: print(f"❌ {test_name} 测试异常: {e}") print() print(f"{'='*50}") print(f"📊 测试结果: {passed}/{total} 通过") if passed == total: print("🎉 所有测试都通过了!ZAI Provider 修复成功") elif passed >= total * 0.75: print("✅ 大部分测试通过,ZAI Provider 基本修复成功") else: print("⚠️ 多个测试失败,需要进一步检查") if __name__ == "__main__": asyncio.run(main())