Spaces:
Paused
Paused
File size: 8,989 Bytes
fd21f34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
全面测试 ZAI Provider 修复效果
验证流式输出、工具调用、思考模式、重试机制等功能
"""
import asyncio
import json
import sys
import os
# 添加项目根目录到路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app.providers.zai_provider import ZAIProvider
from app.models.schemas import OpenAIRequest, Message
from app.core.config import settings
async def test_basic_stream():
"""测试基本流式输出"""
print("🧪 测试基本流式输出...")
provider = ZAIProvider()
request = OpenAIRequest(
model=settings.PRIMARY_MODEL,
messages=[
Message(role="user", content="你好,请简单介绍一下自己")
],
stream=True
)
try:
response = await provider.chat_completion(request)
if hasattr(response, '__aiter__'):
print("✅ 返回了异步生成器")
chunk_count = 0
content_chunks = []
async for chunk in response:
chunk_count += 1
if chunk.startswith("data: ") and not chunk.strip().endswith("[DONE]"):
try:
chunk_data = json.loads(chunk[6:].strip())
if "choices" in chunk_data and chunk_data["choices"]:
choice = chunk_data["choices"][0]
if "delta" in choice and "content" in choice["delta"]:
content = choice["delta"]["content"]
if content:
content_chunks.append(content)
except:
pass
if chunk_count >= 10: # 限制测试长度
break
full_content = "".join(content_chunks)
print(f"✅ 成功处理了 {chunk_count} 个数据块")
print(f"📝 内容预览: {full_content[:100]}...")
return len(content_chunks) > 0
else:
print("❌ 返回的不是流式响应")
return False
except Exception as e:
print(f"❌ 基本流式测试失败: {e}")
return False
async def test_thinking_mode():
"""测试思考模式"""
print("\n🧪 测试思考模式...")
provider = ZAIProvider()
request = OpenAIRequest(
model=settings.THINKING_MODEL,
messages=[
Message(role="user", content="请解释一下量子计算的基本原理")
],
stream=True
)
try:
response = await provider.chat_completion(request)
if hasattr(response, '__aiter__'):
print("✅ 返回了异步生成器")
chunk_count = 0
has_thinking = False
has_content = False
async for chunk in response:
chunk_count += 1
# 检查是否包含思考内容
if 'thinking' in chunk:
has_thinking = True
print("✅ 检测到思考内容")
# 检查是否包含普通内容
if '"content"' in chunk and '"thinking"' not in chunk:
has_content = True
print("✅ 检测到答案内容")
if chunk_count >= 20: # 限制测试长度
break
print(f"✅ 成功处理了 {chunk_count} 个数据块")
print(f"🤔 思考模式: {'正常' if has_thinking else '未检测到'}")
print(f"💬 答案内容: {'正常' if has_content else '未检测到'}")
return True
else:
print("❌ 返回的不是流式响应")
return False
except Exception as e:
print(f"❌ 思考模式测试失败: {e}")
return False
async def test_tool_support():
"""测试工具调用支持"""
print("\n🧪 测试工具调用支持...")
if not settings.TOOL_SUPPORT:
print("⚠️ 工具支持已禁用,跳过测试")
return True
provider = ZAIProvider()
# 简单的工具定义
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "获取指定城市的天气信息",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "城市名称"
}
},
"required": ["city"]
}
}
}
]
request = OpenAIRequest(
model=settings.PRIMARY_MODEL,
messages=[
Message(role="user", content="请帮我查询北京的天气")
],
tools=tools,
stream=True
)
try:
response = await provider.chat_completion(request)
if hasattr(response, '__aiter__'):
print("✅ 返回了异步生成器")
chunk_count = 0
has_tool_call = False
async for chunk in response:
chunk_count += 1
# 检查是否包含工具调用
if 'tool_calls' in chunk:
has_tool_call = True
print("✅ 检测到工具调用")
if chunk_count >= 30: # 限制测试长度
break
print(f"✅ 成功处理了 {chunk_count} 个数据块")
print(f"🔧 工具调用: {'正常' if has_tool_call else '未检测到'}")
return True
else:
print("❌ 返回的不是流式响应")
return False
except Exception as e:
print(f"❌ 工具调用测试失败: {e}")
return False
async def test_error_handling():
"""测试错误处理"""
print("\n🧪 测试错误处理...")
provider = ZAIProvider()
# 使用无效的消息来触发错误
request = OpenAIRequest(
model="invalid-model",
messages=[
Message(role="user", content="测试错误处理")
],
stream=True
)
try:
response = await provider.chat_completion(request)
if hasattr(response, '__aiter__'):
chunk_count = 0
has_error = False
async for chunk in response:
chunk_count += 1
# 检查是否包含错误信息
if 'error' in chunk:
has_error = True
print("✅ 检测到错误处理")
if chunk_count >= 5: # 限制测试长度
break
print(f"✅ 错误处理测试完成,处理了 {chunk_count} 个数据块")
return True
else:
print("✅ 返回了错误响应(非流式)")
return True
except Exception as e:
print(f"✅ 正确捕获了异常: {type(e).__name__}")
return True
async def main():
"""主测试函数"""
print("🚀 开始全面测试 ZAI Provider 修复效果\n")
# 显示配置信息
print("📋 当前配置:")
print(f" - 匿名模式: {settings.ANONYMOUS_MODE}")
print(f" - 工具支持: {settings.TOOL_SUPPORT}")
print(f" - 最大重试: {settings.MAX_RETRIES}")
print(f" - 重试延迟: {settings.RETRY_DELAY}s")
print()
tests = [
("基本流式输出", test_basic_stream),
("思考模式", test_thinking_mode),
("工具调用支持", test_tool_support),
("错误处理", test_error_handling),
]
passed = 0
total = len(tests)
for test_name, test_func in tests:
try:
print(f"{'='*50}")
result = await test_func()
if result:
passed += 1
print(f"✅ {test_name} 测试通过")
else:
print(f"❌ {test_name} 测试失败")
except Exception as e:
print(f"❌ {test_name} 测试异常: {e}")
print()
print(f"{'='*50}")
print(f"📊 测试结果: {passed}/{total} 通过")
if passed == total:
print("🎉 所有测试都通过了!ZAI Provider 修复成功")
elif passed >= total * 0.75:
print("✅ 大部分测试通过,ZAI Provider 基本修复成功")
else:
print("⚠️ 多个测试失败,需要进一步检查")
if __name__ == "__main__":
asyncio.run(main())
|