ZyphrZero
committed on
Commit
·
c6f78e9
1
Parent(s):
8118659
✨ refactor(core): 重构工具调用处理逻辑
Browse files
- .env.example +1 -3
- app/core/config.py +0 -1
- app/core/openai.py +20 -58
- app/core/zai_transformer.py +24 -25
- app/models/schemas.py +0 -1
- app/utils/sse_tool_handler.py +207 -38
.env.example
CHANGED
|
@@ -47,6 +47,4 @@ SCAN_LIMIT=200000
|
|
| 47 |
# 重试次数
|
| 48 |
MAX_RETRIES=5
|
| 49 |
# 初始重试延迟
|
| 50 |
-
RETRY_DELAY=1
|
| 51 |
-
# 退避系数
|
| 52 |
-
RETRY_BACKOFF=2
|
|
|
|
| 47 |
# 重试次数
|
| 48 |
MAX_RETRIES=5
|
| 49 |
# 初始重试延迟
|
| 50 |
+
RETRY_DELAY=1
|
|
|
|
|
|
app/core/config.py
CHANGED
|
@@ -118,7 +118,6 @@ class Settings(BaseSettings):
|
|
| 118 |
# Retry Configuration
|
| 119 |
MAX_RETRIES: int = int(os.getenv("MAX_RETRIES", "5"))
|
| 120 |
RETRY_DELAY: float = float(os.getenv("RETRY_DELAY", "1.0")) # 初始重试延迟(秒)
|
| 121 |
-
RETRY_BACKOFF: float = float(os.getenv("RETRY_BACKOFF", "2.0")) # 退避系数
|
| 122 |
|
| 123 |
# Browser Headers
|
| 124 |
CLIENT_HEADERS: Dict[str, str] = {
|
|
|
|
| 118 |
# Retry Configuration
|
| 119 |
MAX_RETRIES: int = int(os.getenv("MAX_RETRIES", "5"))
|
| 120 |
RETRY_DELAY: float = float(os.getenv("RETRY_DELAY", "1.0")) # 初始重试延迟(秒)
|
|
|
|
| 121 |
|
| 122 |
# Browser Headers
|
| 123 |
CLIENT_HEADERS: Dict[str, str] = {
|
app/core/openai.py
CHANGED
|
@@ -43,47 +43,24 @@ async def list_models():
|
|
| 43 |
@router.post("/v1/chat/completions")
|
| 44 |
async def chat_completions(request: OpenAIRequest, authorization: str = Header(...)):
|
| 45 |
"""Handle chat completion requests with ZAI transformer"""
|
| 46 |
-
|
| 47 |
-
logger.
|
| 48 |
-
|
| 49 |
-
# 输出消息内容用于调试
|
| 50 |
-
for idx, msg in enumerate(request.messages):
|
| 51 |
-
content_preview = str(msg.content)[:1000] if msg.content else "None"
|
| 52 |
-
logger.debug(f" 消息[{idx}] - 角色: {msg.role}, 内容预览: {content_preview}...")
|
| 53 |
|
| 54 |
try:
|
| 55 |
# Validate API key (skip if SKIP_AUTH_TOKEN is enabled)
|
| 56 |
if not settings.SKIP_AUTH_TOKEN:
|
| 57 |
if not authorization.startswith("Bearer "):
|
| 58 |
-
logger.debug("缺少或无效的Authorization头")
|
| 59 |
raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
|
| 60 |
|
| 61 |
api_key = authorization[7:]
|
| 62 |
if api_key != settings.AUTH_TOKEN:
|
| 63 |
-
logger.debug(f"无效的API key: {api_key}")
|
| 64 |
raise HTTPException(status_code=401, detail="Invalid API key")
|
| 65 |
|
| 66 |
-
logger.debug(f"API key验证通过")
|
| 67 |
-
else:
|
| 68 |
-
logger.debug("SKIP_AUTH_TOKEN已启用,跳过API key验证")
|
| 69 |
-
|
| 70 |
-
# 输出原始请求体用于调试
|
| 71 |
-
request_dict = request.model_dump()
|
| 72 |
-
# logger.debug(f"🔄 原始 OpenAI 请求体: {json.dumps(request_dict, ensure_ascii=False, indent=2)}")
|
| 73 |
-
|
| 74 |
# 使用新的转换器转换请求
|
|
|
|
| 75 |
logger.info("🔄 开始转换请求格式: OpenAI -> Z.AI")
|
|
|
|
| 76 |
transformed = await transformer.transform_request_in(request_dict)
|
| 77 |
-
|
| 78 |
-
logger.info(
|
| 79 |
-
f"✅ 请求转换完成 - 上游模型: {transformed['body']['model']}, "
|
| 80 |
-
f"chat_id: {transformed['body']['chat_id']}"
|
| 81 |
-
)
|
| 82 |
-
logger.debug(
|
| 83 |
-
f" 特性配置 - enable_thinking: {transformed['body']['features']['enable_thinking']}, "
|
| 84 |
-
f"web_search: {transformed['body']['features']['web_search']}, "
|
| 85 |
-
f"mcp_servers: {transformed['body'].get('mcp_servers', [])}"
|
| 86 |
-
)
|
| 87 |
# logger.debug(f"🔄 转换后 Z.AI 请求体: {json.dumps(transformed['body'], ensure_ascii=False, indent=2)}")
|
| 88 |
|
| 89 |
# 调用上游API
|
|
@@ -97,11 +74,8 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 97 |
try:
|
| 98 |
# 如果是重试,重新获取令牌并更新请求
|
| 99 |
if retry_count > 0:
|
| 100 |
-
delay = settings.RETRY_DELAY
|
| 101 |
-
logger.warning(
|
| 102 |
-
f"🔄 重试请求 ({retry_count}/{settings.MAX_RETRIES}) - "
|
| 103 |
-
f"等待 {delay:.1f} 秒后重试..."
|
| 104 |
-
)
|
| 105 |
await asyncio.sleep(delay)
|
| 106 |
|
| 107 |
# 标记前一个token失败(如果不是匿名模式)
|
|
@@ -116,13 +90,10 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 116 |
raise Exception("重试时无法获取有效的认证令牌")
|
| 117 |
transformed["config"]["headers"]["Authorization"] = f"Bearer {new_token}"
|
| 118 |
current_token = new_token
|
| 119 |
-
logger.debug(f" 新令牌: {new_token[:20] if new_token else 'None'}...")
|
| 120 |
|
| 121 |
async with httpx.AsyncClient(timeout=60.0) as client:
|
| 122 |
# 发送请求到上游
|
| 123 |
logger.info(f"🎯 发送请求到 Z.AI: {transformed['config']['url']}")
|
| 124 |
-
logger.debug(f" 请求头数量: {len(transformed['config']['headers'])}")
|
| 125 |
-
|
| 126 |
async with client.stream(
|
| 127 |
"POST",
|
| 128 |
transformed["config"]["url"],
|
|
@@ -134,10 +105,7 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 134 |
# 400 错误,触发重试
|
| 135 |
error_text = await response.aread()
|
| 136 |
error_msg = error_text.decode('utf-8', errors='ignore')
|
| 137 |
-
logger.warning(
|
| 138 |
-
f"⚠️ 上游返回 400 错误 (尝试 {retry_count + 1}/{settings.MAX_RETRIES + 1})"
|
| 139 |
-
)
|
| 140 |
-
logger.debug(f" 错误详情: {error_msg}")
|
| 141 |
|
| 142 |
retry_count += 1
|
| 143 |
last_error = f"400 Bad Request: {error_msg}"
|
|
@@ -164,7 +132,7 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 164 |
logger.error(f"❌ 上游返回错误: {response.status_code}")
|
| 165 |
error_text = await response.aread()
|
| 166 |
error_msg = error_text.decode('utf-8', errors='ignore')
|
| 167 |
-
logger.error(f"错误详情: {error_msg}")
|
| 168 |
|
| 169 |
error_response = {
|
| 170 |
"error": {
|
|
@@ -193,7 +161,7 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 193 |
chat_id = transformed["body"]["chat_id"]
|
| 194 |
model = request.model
|
| 195 |
tool_handler = SSEToolHandler(chat_id, model)
|
| 196 |
-
logger.info(f"🔧
|
| 197 |
|
| 198 |
# 处理状态
|
| 199 |
has_thinking = False
|
|
@@ -207,11 +175,8 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 207 |
async for line in response.aiter_lines():
|
| 208 |
line_count += 1
|
| 209 |
if not line:
|
| 210 |
-
# logger.debug(f" 行[{line_count}]: 空行,跳过")
|
| 211 |
continue
|
| 212 |
|
| 213 |
-
logger.debug(f" 行[{line_count}]: 接收到数据 - {line[:1000]}..." if len(line) > 1000 else f" 行[{line_count}]: 接收到数据 - {line}")
|
| 214 |
-
|
| 215 |
# 累积到buffer处理完整的数据行
|
| 216 |
buffer += line + "\n"
|
| 217 |
|
|
@@ -225,11 +190,10 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 225 |
chunk_str = current_line[5:].strip()
|
| 226 |
if not chunk_str or chunk_str == "[DONE]":
|
| 227 |
if chunk_str == "[DONE]":
|
| 228 |
-
logger.debug("🏁 收到结束信号 [DONE]")
|
| 229 |
yield "data: [DONE]\n\n"
|
| 230 |
continue
|
| 231 |
|
| 232 |
-
logger.debug(f"
|
| 233 |
|
| 234 |
try:
|
| 235 |
chunk = json.loads(chunk_str)
|
|
@@ -240,7 +204,7 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 240 |
|
| 241 |
# 记录每个阶段(只在阶段变化时记录)
|
| 242 |
if phase and phase != getattr(stream_response, '_last_phase', None):
|
| 243 |
-
logger.info(f"📈 SSE
|
| 244 |
stream_response._last_phase = phase
|
| 245 |
|
| 246 |
# 处理工具调用
|
|
@@ -274,7 +238,6 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 274 |
"object": "chat.completion.chunk",
|
| 275 |
"system_fingerprint": "fp_zai_001",
|
| 276 |
}
|
| 277 |
-
logger.debug(" ➡️ 发送初始角色")
|
| 278 |
yield f"data: {json.dumps(role_chunk)}\n\n"
|
| 279 |
|
| 280 |
delta_content = data.get("delta_content", "")
|
|
@@ -367,7 +330,6 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 367 |
|
| 368 |
# 处理增量内容
|
| 369 |
elif delta_content:
|
| 370 |
-
logger.debug(f" 📝 答案内容片段: {delta_content[:1000]}...")
|
| 371 |
# 如果还没有发送角色
|
| 372 |
if not has_thinking:
|
| 373 |
role_chunk = {
|
|
@@ -406,7 +368,7 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 406 |
"system_fingerprint": "fp_zai_001",
|
| 407 |
}
|
| 408 |
output_data = f"data: {json.dumps(content_chunk)}\n\n"
|
| 409 |
-
logger.debug(f"
|
| 410 |
yield output_data
|
| 411 |
|
| 412 |
# 处理完成
|
|
@@ -432,15 +394,15 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 432 |
"system_fingerprint": "fp_zai_001",
|
| 433 |
}
|
| 434 |
finish_output = f"data: {json.dumps(finish_chunk)}\n\n"
|
| 435 |
-
logger.debug(f"
|
| 436 |
yield finish_output
|
| 437 |
-
logger.debug("
|
| 438 |
yield "data: [DONE]\n\n"
|
| 439 |
|
| 440 |
except json.JSONDecodeError as e:
|
| 441 |
-
logger.debug(f"JSON解析错误: {e}, 内容: {chunk_str[:1000]}")
|
| 442 |
except Exception as e:
|
| 443 |
-
logger.error(f"处理chunk错误: {e}")
|
| 444 |
|
| 445 |
# 确保发送结束信号
|
| 446 |
if not tool_handler or not tool_handler.has_tool_call:
|
|
@@ -452,7 +414,7 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 452 |
return
|
| 453 |
|
| 454 |
except Exception as e:
|
| 455 |
-
logger.error(f"流处理错误: {e}")
|
| 456 |
import traceback
|
| 457 |
logger.error(traceback.format_exc())
|
| 458 |
|
|
@@ -487,7 +449,7 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 487 |
logger.debug("📤 开始向客户端流式传输数据...")
|
| 488 |
async for chunk in stream_response():
|
| 489 |
chunk_count += 1
|
| 490 |
-
logger.debug(f"
|
| 491 |
yield chunk
|
| 492 |
logger.info(f"✅ 流式传输完成,共发送 {chunk_count} 个数据块")
|
| 493 |
except Exception as e:
|
|
@@ -506,10 +468,10 @@ async def chat_completions(request: OpenAIRequest, authorization: str = Header(.
|
|
| 506 |
except HTTPException:
|
| 507 |
raise
|
| 508 |
except Exception as e:
|
| 509 |
-
logger.error(f"处理请求时发生错误: {str(e)}")
|
| 510 |
import traceback
|
| 511 |
|
| 512 |
-
logger.error(f"错误堆栈: {traceback.format_exc()}")
|
| 513 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
| 514 |
|
| 515 |
|
|
|
|
| 43 |
@router.post("/v1/chat/completions")
|
| 44 |
async def chat_completions(request: OpenAIRequest, authorization: str = Header(...)):
|
| 45 |
"""Handle chat completion requests with ZAI transformer"""
|
| 46 |
+
role = request.messages[0].role if request.messages else "unknown"
|
| 47 |
+
logger.info(f"😶🌫️ 收到 客户端 请求 - 模型: {request.model}, 流式: {request.stream}, 消息数: {len(request.messages)}, 角色: {role}, 工具数: {len(request.tools) if request.tools else 0}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
try:
|
| 50 |
# Validate API key (skip if SKIP_AUTH_TOKEN is enabled)
|
| 51 |
if not settings.SKIP_AUTH_TOKEN:
|
| 52 |
if not authorization.startswith("Bearer "):
|
|
|
|
| 53 |
raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
|
| 54 |
|
| 55 |
api_key = authorization[7:]
|
| 56 |
if api_key != settings.AUTH_TOKEN:
|
|
|
|
| 57 |
raise HTTPException(status_code=401, detail="Invalid API key")
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# 使用新的转换器转换请求
|
| 60 |
+
request_dict = request.model_dump()
|
| 61 |
logger.info("🔄 开始转换请求格式: OpenAI -> Z.AI")
|
| 62 |
+
|
| 63 |
transformed = await transformer.transform_request_in(request_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
# logger.debug(f"🔄 转换后 Z.AI 请求体: {json.dumps(transformed['body'], ensure_ascii=False, indent=2)}")
|
| 65 |
|
| 66 |
# 调用上游API
|
|
|
|
| 74 |
try:
|
| 75 |
# 如果是重试,重新获取令牌并更新请求
|
| 76 |
if retry_count > 0:
|
| 77 |
+
delay = settings.RETRY_DELAY
|
| 78 |
+
logger.warning(f"重试请求 ({retry_count}/{settings.MAX_RETRIES}) - 等待 {delay:.1f}s")
|
|
|
|
|
|
|
|
|
|
| 79 |
await asyncio.sleep(delay)
|
| 80 |
|
| 81 |
# 标记前一个token失败(如果不是匿名模式)
|
|
|
|
| 90 |
raise Exception("重试时无法获取有效的认证令牌")
|
| 91 |
transformed["config"]["headers"]["Authorization"] = f"Bearer {new_token}"
|
| 92 |
current_token = new_token
|
|
|
|
| 93 |
|
| 94 |
async with httpx.AsyncClient(timeout=60.0) as client:
|
| 95 |
# 发送请求到上游
|
| 96 |
logger.info(f"🎯 发送请求到 Z.AI: {transformed['config']['url']}")
|
|
|
|
|
|
|
| 97 |
async with client.stream(
|
| 98 |
"POST",
|
| 99 |
transformed["config"]["url"],
|
|
|
|
| 105 |
# 400 错误,触发重试
|
| 106 |
error_text = await response.aread()
|
| 107 |
error_msg = error_text.decode('utf-8', errors='ignore')
|
| 108 |
+
logger.warning(f"❌ 上游返回 400 错误 (尝试 {retry_count + 1}/{settings.MAX_RETRIES + 1})")
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
retry_count += 1
|
| 111 |
last_error = f"400 Bad Request: {error_msg}"
|
|
|
|
| 132 |
logger.error(f"❌ 上游返回错误: {response.status_code}")
|
| 133 |
error_text = await response.aread()
|
| 134 |
error_msg = error_text.decode('utf-8', errors='ignore')
|
| 135 |
+
logger.error(f"❌ 错误详情: {error_msg}")
|
| 136 |
|
| 137 |
error_response = {
|
| 138 |
"error": {
|
|
|
|
| 161 |
chat_id = transformed["body"]["chat_id"]
|
| 162 |
model = request.model
|
| 163 |
tool_handler = SSEToolHandler(chat_id, model)
|
| 164 |
+
logger.info(f"🔧 初始化工具处理器: {len(transformed['body'].get('tools', []))} 个工具")
|
| 165 |
|
| 166 |
# 处理状态
|
| 167 |
has_thinking = False
|
|
|
|
| 175 |
async for line in response.aiter_lines():
|
| 176 |
line_count += 1
|
| 177 |
if not line:
|
|
|
|
| 178 |
continue
|
| 179 |
|
|
|
|
|
|
|
| 180 |
# 累积到buffer处理完整的数据行
|
| 181 |
buffer += line + "\n"
|
| 182 |
|
|
|
|
| 190 |
chunk_str = current_line[5:].strip()
|
| 191 |
if not chunk_str or chunk_str == "[DONE]":
|
| 192 |
if chunk_str == "[DONE]":
|
|
|
|
| 193 |
yield "data: [DONE]\n\n"
|
| 194 |
continue
|
| 195 |
|
| 196 |
+
logger.debug(f"📦 解析数据块: {chunk_str[:1000]}..." if len(chunk_str) > 1000 else f"📦 解析数据块: {chunk_str}")
|
| 197 |
|
| 198 |
try:
|
| 199 |
chunk = json.loads(chunk_str)
|
|
|
|
| 204 |
|
| 205 |
# 记录每个阶段(只在阶段变化时记录)
|
| 206 |
if phase and phase != getattr(stream_response, '_last_phase', None):
|
| 207 |
+
logger.info(f"📈 SSE 阶段: {phase}")
|
| 208 |
stream_response._last_phase = phase
|
| 209 |
|
| 210 |
# 处理工具调用
|
|
|
|
| 238 |
"object": "chat.completion.chunk",
|
| 239 |
"system_fingerprint": "fp_zai_001",
|
| 240 |
}
|
|
|
|
| 241 |
yield f"data: {json.dumps(role_chunk)}\n\n"
|
| 242 |
|
| 243 |
delta_content = data.get("delta_content", "")
|
|
|
|
| 330 |
|
| 331 |
# 处理增量内容
|
| 332 |
elif delta_content:
|
|
|
|
| 333 |
# 如果还没有发送角色
|
| 334 |
if not has_thinking:
|
| 335 |
role_chunk = {
|
|
|
|
| 368 |
"system_fingerprint": "fp_zai_001",
|
| 369 |
}
|
| 370 |
output_data = f"data: {json.dumps(content_chunk)}\n\n"
|
| 371 |
+
logger.debug(f"➡️ 输出内容块到客户端: {output_data[:1000]}...")
|
| 372 |
yield output_data
|
| 373 |
|
| 374 |
# 处理完成
|
|
|
|
| 394 |
"system_fingerprint": "fp_zai_001",
|
| 395 |
}
|
| 396 |
finish_output = f"data: {json.dumps(finish_chunk)}\n\n"
|
| 397 |
+
logger.debug(f"➡️ 发送完成信号: {finish_output[:1000]}...")
|
| 398 |
yield finish_output
|
| 399 |
+
logger.debug("➡️ 发送 [DONE]")
|
| 400 |
yield "data: [DONE]\n\n"
|
| 401 |
|
| 402 |
except json.JSONDecodeError as e:
|
| 403 |
+
logger.debug(f"❌ JSON解析错误: {e}, 内容: {chunk_str[:1000]}")
|
| 404 |
except Exception as e:
|
| 405 |
+
logger.error(f"❌ 处理chunk错误: {e}")
|
| 406 |
|
| 407 |
# 确保发送结束信号
|
| 408 |
if not tool_handler or not tool_handler.has_tool_call:
|
|
|
|
| 414 |
return
|
| 415 |
|
| 416 |
except Exception as e:
|
| 417 |
+
logger.error(f"❌ 流处理错误: {e}")
|
| 418 |
import traceback
|
| 419 |
logger.error(traceback.format_exc())
|
| 420 |
|
|
|
|
| 449 |
logger.debug("📤 开始向客户端流式传输数据...")
|
| 450 |
async for chunk in stream_response():
|
| 451 |
chunk_count += 1
|
| 452 |
+
logger.debug(f"📤 发送块[{chunk_count}]: {chunk[:1000]}..." if len(chunk) > 1000 else f" 📤 发送块[{chunk_count}]: {chunk}")
|
| 453 |
yield chunk
|
| 454 |
logger.info(f"✅ 流式传输完成,共发送 {chunk_count} 个数据块")
|
| 455 |
except Exception as e:
|
|
|
|
| 468 |
except HTTPException:
|
| 469 |
raise
|
| 470 |
except Exception as e:
|
| 471 |
+
logger.error(f"❌ 处理请求时发生错误: {str(e)}")
|
| 472 |
import traceback
|
| 473 |
|
| 474 |
+
logger.error(f"❌ 错误堆栈: {traceback.format_exc()}")
|
| 475 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
| 476 |
|
| 477 |
|
app/core/zai_transformer.py
CHANGED
|
@@ -92,7 +92,6 @@ def get_dynamic_headers(chat_id: str = "") -> Dict[str, str]:
|
|
| 92 |
else:
|
| 93 |
headers["Referer"] = "https://chat.z.ai/"
|
| 94 |
|
| 95 |
-
logger.debug(f"使用动态User-Agent: {user_agent[:80]}...")
|
| 96 |
return headers
|
| 97 |
|
| 98 |
|
|
@@ -105,14 +104,13 @@ def get_auth_token_sync() -> str:
|
|
| 105 |
"""同步获取认证令牌(用于非异步场景)"""
|
| 106 |
if settings.ANONYMOUS_MODE:
|
| 107 |
try:
|
| 108 |
-
logger.debug("匿名模式:获取新的访客令牌")
|
| 109 |
headers = get_dynamic_headers()
|
| 110 |
response = requests.get("https://chat.z.ai/api/v1/auths/", headers=headers, timeout=10)
|
| 111 |
if response.status_code == 200:
|
| 112 |
data = response.json()
|
| 113 |
token = data.get("token", "")
|
| 114 |
if token:
|
| 115 |
-
logger.debug(f"
|
| 116 |
return token
|
| 117 |
except Exception as e:
|
| 118 |
logger.warning(f"获取访客令牌失败: {e}")
|
|
@@ -152,7 +150,7 @@ class ZAITransformer:
|
|
| 152 |
"""异步获取认证令牌"""
|
| 153 |
if settings.ANONYMOUS_MODE:
|
| 154 |
try:
|
| 155 |
-
|
| 156 |
headers = get_dynamic_headers()
|
| 157 |
async with httpx.AsyncClient() as client:
|
| 158 |
response = await client.get(self.auth_url, headers=headers, timeout=10.0)
|
|
@@ -160,7 +158,7 @@ class ZAITransformer:
|
|
| 160 |
data = response.json()
|
| 161 |
token = data.get("token", "")
|
| 162 |
if token:
|
| 163 |
-
logger.debug(f"
|
| 164 |
return token
|
| 165 |
except Exception as e:
|
| 166 |
logger.warning(f"异步获取访客令牌失败: {e}")
|
|
@@ -194,8 +192,8 @@ class ZAITransformer:
|
|
| 194 |
转换OpenAI请求为z.ai格式
|
| 195 |
整合现有功能:模型映射、MCP服务器等
|
| 196 |
"""
|
| 197 |
-
logger.info("🔄 开始转换 OpenAI 请求到 Z.AI
|
| 198 |
-
|
| 199 |
# 获取认证令牌
|
| 200 |
token = await self.get_token()
|
| 201 |
logger.debug(f" 使用令牌: {token[:20] if token else 'None'}...")
|
|
@@ -210,12 +208,12 @@ class ZAITransformer:
|
|
| 210 |
is_thinking = requested_model == settings.THINKING_MODEL or request.get("reasoning", False)
|
| 211 |
is_search = requested_model == settings.SEARCH_MODEL
|
| 212 |
is_air = requested_model == settings.AIR_MODEL
|
| 213 |
-
|
| 214 |
-
logger.info(f" 模型分析 - 请求模型: {requested_model}, 思考模式: {is_thinking}, 搜索模式: {is_search}, Air模式: {is_air}")
|
| 215 |
|
| 216 |
# 获取上游模型ID(使用模型映射)
|
| 217 |
upstream_model_id = self.model_mapping.get(requested_model, "0727-360B-API")
|
| 218 |
logger.debug(f" 模型映射: {requested_model} -> {upstream_model_id}")
|
|
|
|
|
|
|
| 219 |
|
| 220 |
# 处理消息列表
|
| 221 |
logger.debug(f" 开始处理 {len(request.get('messages', []))} 条消息")
|
|
@@ -225,7 +223,7 @@ class ZAITransformer:
|
|
| 225 |
|
| 226 |
# 处理system角色转换
|
| 227 |
if msg.get("role") == "system":
|
| 228 |
-
|
| 229 |
msg["role"] = "user"
|
| 230 |
content = msg.get("content")
|
| 231 |
|
|
@@ -257,7 +255,7 @@ class ZAITransformer:
|
|
| 257 |
|
| 258 |
# 处理assistant消息中的reasoning_content
|
| 259 |
elif msg.get("role") == "assistant" and msg.get("reasoning_content"):
|
| 260 |
-
|
| 261 |
# 如果有reasoning_content,保留它
|
| 262 |
pass
|
| 263 |
|
|
@@ -267,11 +265,14 @@ class ZAITransformer:
|
|
| 267 |
mcp_servers = []
|
| 268 |
if is_search:
|
| 269 |
mcp_servers.append("deep-web-search")
|
| 270 |
-
logger.info("
|
|
|
|
|
|
|
| 271 |
|
|
|
|
|
|
|
| 272 |
# 构建上游请求体
|
| 273 |
chat_id = generate_uuid()
|
| 274 |
-
logger.info(f" 生成 chat_id: {chat_id}")
|
| 275 |
|
| 276 |
body = {
|
| 277 |
"stream": True, # 总是使用流式
|
|
@@ -303,7 +304,6 @@ class ZAITransformer:
|
|
| 303 |
"{{USER_LANGUAGE}}": "zh-CN",
|
| 304 |
},
|
| 305 |
"model_item": {},
|
| 306 |
-
"tool_servers": [], # 保留工具服务器字段
|
| 307 |
"chat_id": chat_id,
|
| 308 |
"id": generate_uuid(),
|
| 309 |
}
|
|
@@ -311,18 +311,12 @@ class ZAITransformer:
|
|
| 311 |
# 处理工具支持
|
| 312 |
if settings.TOOL_SUPPORT and not is_thinking and request.get("tools"):
|
| 313 |
body["tools"] = request["tools"]
|
| 314 |
-
logger.info(f"
|
| 315 |
-
for tool_idx, tool in enumerate(request["tools"]):
|
| 316 |
-
tool_name = tool.get("function", {}).get("name", "unknown")
|
| 317 |
-
logger.debug(f" 工具[{tool_idx}]: {tool_name}")
|
| 318 |
else:
|
| 319 |
body["tools"] = None
|
| 320 |
-
if request.get("tools"):
|
| 321 |
-
logger.debug(f" 工具支持已禁用或在思考模式下,忽略 {len(request.get('tools', []))} 个工具")
|
| 322 |
|
| 323 |
# 构建请求配置
|
| 324 |
dynamic_headers = get_dynamic_headers(chat_id)
|
| 325 |
-
logger.debug(f" 生成动态请求头 - User-Agent: {dynamic_headers.get('User-Agent', '')[:80]}...")
|
| 326 |
|
| 327 |
config = {
|
| 328 |
"url": self.api_url, # 使用原始URL
|
|
@@ -339,9 +333,14 @@ class ZAITransformer:
|
|
| 339 |
}
|
| 340 |
|
| 341 |
logger.info("✅ 请求转换完成")
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
logger.debug(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
return {"body": body, "config": config, "token": token}
|
| 347 |
|
|
@@ -718,7 +717,7 @@ class ZAITransformer:
|
|
| 718 |
yield "data: [DONE]\n\n"
|
| 719 |
|
| 720 |
except json.JSONDecodeError as e:
|
| 721 |
-
logger.debug(f"JSON解析错误: {e}
|
| 722 |
except Exception as e:
|
| 723 |
logger.error(f"处理chunk错误: {e}")
|
| 724 |
|
|
|
|
| 92 |
else:
|
| 93 |
headers["Referer"] = "https://chat.z.ai/"
|
| 94 |
|
|
|
|
| 95 |
return headers
|
| 96 |
|
| 97 |
|
|
|
|
| 104 |
"""同步获取认证令牌(用于非异步场景)"""
|
| 105 |
if settings.ANONYMOUS_MODE:
|
| 106 |
try:
|
|
|
|
| 107 |
headers = get_dynamic_headers()
|
| 108 |
response = requests.get("https://chat.z.ai/api/v1/auths/", headers=headers, timeout=10)
|
| 109 |
if response.status_code == 200:
|
| 110 |
data = response.json()
|
| 111 |
token = data.get("token", "")
|
| 112 |
if token:
|
| 113 |
+
logger.debug(f"获取访客令牌成功: {token[:20]}...")
|
| 114 |
return token
|
| 115 |
except Exception as e:
|
| 116 |
logger.warning(f"获取访客令牌失败: {e}")
|
|
|
|
| 150 |
"""异步获取认证令牌"""
|
| 151 |
if settings.ANONYMOUS_MODE:
|
| 152 |
try:
|
| 153 |
+
|
| 154 |
headers = get_dynamic_headers()
|
| 155 |
async with httpx.AsyncClient() as client:
|
| 156 |
response = await client.get(self.auth_url, headers=headers, timeout=10.0)
|
|
|
|
| 158 |
data = response.json()
|
| 159 |
token = data.get("token", "")
|
| 160 |
if token:
|
| 161 |
+
logger.debug(f"获取访客令牌成功: {token[:20]}...")
|
| 162 |
return token
|
| 163 |
except Exception as e:
|
| 164 |
logger.warning(f"异步获取访客令牌失败: {e}")
|
|
|
|
| 192 |
转换OpenAI请求为z.ai格式
|
| 193 |
整合现有功能:模型映射、MCP服务器等
|
| 194 |
"""
|
| 195 |
+
logger.info(f"🔄 开始转换 OpenAI 请求到 Z.AI 格式: {request.get('model', settings.PRIMARY_MODEL)} -> Z.AI")
|
| 196 |
+
|
| 197 |
# 获取认证令牌
|
| 198 |
token = await self.get_token()
|
| 199 |
logger.debug(f" 使用令牌: {token[:20] if token else 'None'}...")
|
|
|
|
| 208 |
is_thinking = requested_model == settings.THINKING_MODEL or request.get("reasoning", False)
|
| 209 |
is_search = requested_model == settings.SEARCH_MODEL
|
| 210 |
is_air = requested_model == settings.AIR_MODEL
|
|
|
|
|
|
|
| 211 |
|
| 212 |
# 获取上游模型ID(使用模型映射)
|
| 213 |
upstream_model_id = self.model_mapping.get(requested_model, "0727-360B-API")
|
| 214 |
logger.debug(f" 模型映射: {requested_model} -> {upstream_model_id}")
|
| 215 |
+
logger.debug(f" 模型特性检测: is_search={is_search}, is_thinking={is_thinking}, is_air={is_air}")
|
| 216 |
+
logger.debug(f" SEARCH_MODEL配置: {settings.SEARCH_MODEL}")
|
| 217 |
|
| 218 |
# 处理消息列表
|
| 219 |
logger.debug(f" 开始处理 {len(request.get('messages', []))} 条消息")
|
|
|
|
| 223 |
|
| 224 |
# 处理system角色转换
|
| 225 |
if msg.get("role") == "system":
|
| 226 |
+
|
| 227 |
msg["role"] = "user"
|
| 228 |
content = msg.get("content")
|
| 229 |
|
|
|
|
| 255 |
|
| 256 |
# 处理assistant消息中的reasoning_content
|
| 257 |
elif msg.get("role") == "assistant" and msg.get("reasoning_content"):
|
| 258 |
+
|
| 259 |
# 如果有reasoning_content,保留它
|
| 260 |
pass
|
| 261 |
|
|
|
|
| 265 |
mcp_servers = []
|
| 266 |
if is_search:
|
| 267 |
mcp_servers.append("deep-web-search")
|
| 268 |
+
logger.info(f"🔍 检测到搜索模型,添加 deep-web-search MCP 服务器")
|
| 269 |
+
else:
|
| 270 |
+
logger.debug(f" 非搜索模型,不添加 MCP 服务器")
|
| 271 |
|
| 272 |
+
logger.debug(f" MCP服务器列表: {mcp_servers}")
|
| 273 |
+
|
| 274 |
# 构建上游请求体
|
| 275 |
chat_id = generate_uuid()
|
|
|
|
| 276 |
|
| 277 |
body = {
|
| 278 |
"stream": True, # 总是使用流式
|
|
|
|
| 304 |
"{{USER_LANGUAGE}}": "zh-CN",
|
| 305 |
},
|
| 306 |
"model_item": {},
|
|
|
|
| 307 |
"chat_id": chat_id,
|
| 308 |
"id": generate_uuid(),
|
| 309 |
}
|
|
|
|
| 311 |
# 处理工具支持
|
| 312 |
if settings.TOOL_SUPPORT and not is_thinking and request.get("tools"):
|
| 313 |
body["tools"] = request["tools"]
|
| 314 |
+
logger.info(f"启用工具支持: {len(request['tools'])} 个工具")
|
|
|
|
|
|
|
|
|
|
| 315 |
else:
|
| 316 |
body["tools"] = None
|
|
|
|
|
|
|
| 317 |
|
| 318 |
# 构建请求配置
|
| 319 |
dynamic_headers = get_dynamic_headers(chat_id)
|
|
|
|
| 320 |
|
| 321 |
config = {
|
| 322 |
"url": self.api_url, # 使用原始URL
|
|
|
|
| 333 |
}
|
| 334 |
|
| 335 |
logger.info("✅ 请求转换完成")
|
| 336 |
+
|
| 337 |
+
# 记录关键的请求信息用于调试
|
| 338 |
+
logger.debug(f" 📋 发送到Z.AI的关键信息:")
|
| 339 |
+
logger.debug(f" - 上游模型: {body['model']}")
|
| 340 |
+
logger.debug(f" - MCP服务器: {body['mcp_servers']}")
|
| 341 |
+
logger.debug(f" - web_search: {body['features']['web_search']}")
|
| 342 |
+
logger.debug(f" - auto_web_search: {body['features']['auto_web_search']}")
|
| 343 |
+
logger.debug(f" - 消息数量: {len(body['messages'])}")
|
| 344 |
|
| 345 |
return {"body": body, "config": config, "token": token}
|
| 346 |
|
|
|
|
| 717 |
yield "data: [DONE]\n\n"
|
| 718 |
|
| 719 |
except json.JSONDecodeError as e:
|
| 720 |
+
logger.debug(f"JSON解析错误: {e}")
|
| 721 |
except Exception as e:
|
| 722 |
logger.error(f"处理chunk错误: {e}")
|
| 723 |
|
app/models/schemas.py
CHANGED
|
@@ -54,7 +54,6 @@ class UpstreamRequest(BaseModel):
|
|
| 54 |
id: Optional[str] = None
|
| 55 |
mcp_servers: Optional[List[str]] = None
|
| 56 |
model_item: Optional[Dict[str, Any]] = {} # Model item dictionary
|
| 57 |
-
tool_servers: Optional[List[str]] = None
|
| 58 |
tools: Optional[List[Dict[str, Any]]] = None # Add tools field for OpenAI compatibility
|
| 59 |
variables: Optional[Dict[str, str]] = None
|
| 60 |
model_config = {"protected_namespaces": ()}
|
|
|
|
| 54 |
id: Optional[str] = None
|
| 55 |
mcp_servers: Optional[List[str]] = None
|
| 56 |
model_item: Optional[Dict[str, Any]] = {} # Model item dictionary
|
|
|
|
| 57 |
tools: Optional[List[Dict[str, Any]]] = None # Add tools field for OpenAI compatibility
|
| 58 |
variables: Optional[Dict[str, str]] = None
|
| 59 |
model_config = {"protected_namespaces": ()}
|
app/utils/sse_tool_handler.py
CHANGED
|
@@ -28,7 +28,6 @@ class SSEToolHandler:
|
|
| 28 |
self.content_index = 0
|
| 29 |
self.has_thinking = False
|
| 30 |
|
| 31 |
-
# 原生内容重建机制 - 基于 Z.AI 的 edit_index 机制
|
| 32 |
self.content_buffer = bytearray() # 使用字节数组提高性能
|
| 33 |
self.last_edit_index = 0 # 上次编辑的位置
|
| 34 |
|
|
@@ -39,7 +38,7 @@ class SSEToolHandler:
|
|
| 39 |
|
| 40 |
def process_tool_call_phase(self, data: Dict[str, Any], is_stream: bool = True) -> Generator[str, None, None]:
|
| 41 |
"""
|
| 42 |
-
处理tool_call阶段
|
| 43 |
"""
|
| 44 |
if not self.has_tool_call:
|
| 45 |
self.has_tool_call = True
|
|
@@ -53,7 +52,7 @@ class SSEToolHandler:
|
|
| 53 |
|
| 54 |
# logger.debug(f"📦 接收内容片段 [index={edit_index}]: {edit_content[:1000]}...")
|
| 55 |
|
| 56 |
-
#
|
| 57 |
self._apply_edit_to_buffer(edit_index, edit_content)
|
| 58 |
|
| 59 |
# 尝试解析和处理工具调用
|
|
@@ -61,8 +60,7 @@ class SSEToolHandler:
|
|
| 61 |
|
| 62 |
def _apply_edit_to_buffer(self, edit_index: int, edit_content: str):
|
| 63 |
"""
|
| 64 |
-
|
| 65 |
-
这是Z.AI的核心机制:在指定位置替换/插入内容
|
| 66 |
"""
|
| 67 |
edit_bytes = edit_content.encode('utf-8')
|
| 68 |
required_length = edit_index + len(edit_bytes)
|
|
@@ -97,7 +95,6 @@ class SSEToolHandler:
|
|
| 97 |
def _extract_and_process_tools(self, content_str: str, is_stream: bool) -> Generator[str, None, None]:
|
| 98 |
"""
|
| 99 |
从内容字符串中提取和处理工具调用
|
| 100 |
-
使用更原生的方式解析 glm_block
|
| 101 |
"""
|
| 102 |
# 查找所有 glm_block,包括不完整的
|
| 103 |
pattern = r'<glm_block\s*>(.*?)(?:</glm_block>|$)'
|
|
@@ -162,7 +159,7 @@ class SSEToolHandler:
|
|
| 162 |
|
| 163 |
def _handle_tool_update(self, tool_id: str, tool_name: str, arguments_raw: str, is_stream: bool) -> Generator[str, None, None]:
|
| 164 |
"""
|
| 165 |
-
处理工具的创建或更新
|
| 166 |
"""
|
| 167 |
# 解析参数
|
| 168 |
try:
|
|
@@ -173,36 +170,157 @@ class SSEToolHandler:
|
|
| 173 |
else:
|
| 174 |
arguments = arguments_raw
|
| 175 |
except json.JSONDecodeError:
|
| 176 |
-
logger.debug(f"📦
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
# 检查是否是新工具
|
| 180 |
if tool_id not in self.active_tools:
|
| 181 |
-
logger.debug(f"🎯 发现新工具: {tool_name}(id={tool_id})")
|
| 182 |
|
| 183 |
self.active_tools[tool_id] = {
|
| 184 |
"id": tool_id,
|
| 185 |
"name": tool_name,
|
| 186 |
"arguments": arguments,
|
|
|
|
| 187 |
"status": "active",
|
| 188 |
"sent_start": False,
|
| 189 |
-
"
|
|
|
|
|
|
|
| 190 |
}
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
yield self._create_tool_start_chunk(tool_id, tool_name)
|
| 195 |
self.active_tools[tool_id]["sent_start"] = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
-
#
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
-
|
| 203 |
-
# 发送工具参数
|
| 204 |
-
yield self._create_tool_arguments_chunk(tool_id, arguments)
|
| 205 |
-
current_tool["sent_args"] = True
|
| 206 |
|
| 207 |
def _handle_partial_tool_block(self, block_content: str, is_stream: bool) -> Generator[str, None, None]:
|
| 208 |
"""
|
|
@@ -225,29 +343,38 @@ class SSEToolHandler:
|
|
| 225 |
|
| 226 |
# 如果是新工具,先创建记录
|
| 227 |
if tool_id not in self.active_tools:
|
|
|
|
|
|
|
|
|
|
| 228 |
self.active_tools[tool_id] = {
|
| 229 |
"id": tool_id,
|
| 230 |
"name": tool_name,
|
| 231 |
-
"arguments":
|
| 232 |
"status": "partial",
|
| 233 |
"sent_start": False,
|
| 234 |
-
"
|
|
|
|
| 235 |
"partial_args": partial_args
|
| 236 |
}
|
| 237 |
|
| 238 |
if is_stream:
|
| 239 |
-
yield self._create_tool_start_chunk(tool_id, tool_name)
|
| 240 |
self.active_tools[tool_id]["sent_start"] = True
|
|
|
|
| 241 |
else:
|
| 242 |
# 更新部分参数
|
| 243 |
self.active_tools[tool_id]["partial_args"] = partial_args
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
except Exception as e:
|
| 246 |
logger.debug(f"📦 部分块解析失败: {e}")
|
| 247 |
|
| 248 |
def _clean_arguments_string(self, arguments_raw: str) -> str:
|
| 249 |
"""
|
| 250 |
-
|
| 251 |
"""
|
| 252 |
if not arguments_raw:
|
| 253 |
return "{}"
|
|
@@ -266,6 +393,12 @@ class SSEToolHandler:
|
|
| 266 |
elif cleaned.startswith('"{\\"') and cleaned.endswith('\\"}'):
|
| 267 |
# 双重转义的情况
|
| 268 |
cleaned = cleaned[1:-1].replace('\\"', '"')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
# 标准化空格(移除JSON中的多余空格,但保留字符串值中的空格)
|
| 271 |
try:
|
|
@@ -276,10 +409,32 @@ class SSEToolHandler:
|
|
| 276 |
cleaned = json.dumps(parsed, ensure_ascii=False, separators=(',', ':'))
|
| 277 |
except json.JSONDecodeError:
|
| 278 |
# 如果解析失败,只做基本的空格清理
|
| 279 |
-
|
| 280 |
|
| 281 |
return cleaned
|
| 282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
def _parse_partial_arguments(self, arguments_raw: str) -> Dict[str, Any]:
|
| 284 |
"""
|
| 285 |
解析不完整的参数字符串,尽可能提取有效信息
|
|
@@ -364,16 +519,29 @@ class SSEToolHandler:
|
|
| 364 |
|
| 365 |
def _complete_active_tools(self, is_stream: bool) -> Generator[str, None, None]:
|
| 366 |
"""
|
| 367 |
-
完成所有活跃的工具调用
|
| 368 |
"""
|
|
|
|
|
|
|
| 369 |
for tool_id, tool in self.active_tools.items():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
tool["status"] = "completed"
|
| 371 |
self.completed_tools.append(tool)
|
| 372 |
logger.debug(f"✅ 完成工具调用: {tool['name']}(id={tool_id})")
|
| 373 |
|
| 374 |
self.active_tools.clear()
|
| 375 |
|
| 376 |
-
if is_stream and self.completed_tools:
|
| 377 |
# 发送工具完成信号
|
| 378 |
yield self._create_tool_finish_chunk()
|
| 379 |
|
|
@@ -405,7 +573,7 @@ class SSEToolHandler:
|
|
| 405 |
|
| 406 |
if is_stream:
|
| 407 |
logger.info("🏁 发送工具调用完成信号")
|
| 408 |
-
yield "data: [DONE]
|
| 409 |
|
| 410 |
# 重置工具调用状态
|
| 411 |
self.has_tool_call = False
|
|
@@ -446,8 +614,12 @@ class SSEToolHandler:
|
|
| 446 |
self.completed_tools.clear()
|
| 447 |
self.tool_blocks_cache.clear()
|
| 448 |
|
| 449 |
-
def _create_tool_start_chunk(self, tool_id: str, tool_name: str) -> str:
|
| 450 |
-
"""创建工具调用开始的chunk"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
chunk = {
|
| 452 |
"choices": [
|
| 453 |
{
|
|
@@ -458,7 +630,7 @@ class SSEToolHandler:
|
|
| 458 |
{
|
| 459 |
"id": tool_id,
|
| 460 |
"type": "function",
|
| 461 |
-
"function": {"name": tool_name, "arguments":
|
| 462 |
}
|
| 463 |
],
|
| 464 |
},
|
|
@@ -476,18 +648,15 @@ class SSEToolHandler:
|
|
| 476 |
return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
| 477 |
|
| 478 |
def _create_tool_arguments_chunk(self, tool_id: str, arguments: Dict) -> str:
|
| 479 |
-
"""创建工具参数的chunk"""
|
| 480 |
chunk = {
|
| 481 |
"choices": [
|
| 482 |
{
|
| 483 |
"delta": {
|
| 484 |
-
"role": "assistant",
|
| 485 |
-
"content": None,
|
| 486 |
"tool_calls": [
|
| 487 |
{
|
| 488 |
"id": tool_id,
|
| 489 |
-
"
|
| 490 |
-
"function": {"name": None, "arguments": json.dumps(arguments, ensure_ascii=False)},
|
| 491 |
}
|
| 492 |
],
|
| 493 |
},
|
|
|
|
| 28 |
self.content_index = 0
|
| 29 |
self.has_thinking = False
|
| 30 |
|
|
|
|
| 31 |
self.content_buffer = bytearray() # 使用字节数组提高性能
|
| 32 |
self.last_edit_index = 0 # 上次编辑的位置
|
| 33 |
|
|
|
|
| 38 |
|
| 39 |
def process_tool_call_phase(self, data: Dict[str, Any], is_stream: bool = True) -> Generator[str, None, None]:
|
| 40 |
"""
|
| 41 |
+
处理tool_call阶段
|
| 42 |
"""
|
| 43 |
if not self.has_tool_call:
|
| 44 |
self.has_tool_call = True
|
|
|
|
| 52 |
|
| 53 |
# logger.debug(f"📦 接收内容片段 [index={edit_index}]: {edit_content[:1000]}...")
|
| 54 |
|
| 55 |
+
# 更新内容缓冲区
|
| 56 |
self._apply_edit_to_buffer(edit_index, edit_content)
|
| 57 |
|
| 58 |
# 尝试解析和处理工具调用
|
|
|
|
| 60 |
|
| 61 |
def _apply_edit_to_buffer(self, edit_index: int, edit_content: str):
|
| 62 |
"""
|
| 63 |
+
在指定位置替换/插入内容更新内容缓冲区
|
|
|
|
| 64 |
"""
|
| 65 |
edit_bytes = edit_content.encode('utf-8')
|
| 66 |
required_length = edit_index + len(edit_bytes)
|
|
|
|
| 95 |
def _extract_and_process_tools(self, content_str: str, is_stream: bool) -> Generator[str, None, None]:
|
| 96 |
"""
|
| 97 |
从内容字符串中提取和处理工具调用
|
|
|
|
| 98 |
"""
|
| 99 |
# 查找所有 glm_block,包括不完整的
|
| 100 |
pattern = r'<glm_block\s*>(.*?)(?:</glm_block>|$)'
|
|
|
|
| 159 |
|
| 160 |
def _handle_tool_update(self, tool_id: str, tool_name: str, arguments_raw: str, is_stream: bool) -> Generator[str, None, None]:
|
| 161 |
"""
|
| 162 |
+
处理工具的创建或更新 - 更可靠的参数完整性检查
|
| 163 |
"""
|
| 164 |
# 解析参数
|
| 165 |
try:
|
|
|
|
| 170 |
else:
|
| 171 |
arguments = arguments_raw
|
| 172 |
except json.JSONDecodeError:
|
| 173 |
+
logger.debug(f"📦 参数解析失败,暂不处理: {arguments_raw}")
|
| 174 |
+
# 参数解析失败时,不创建或更新工具,等待更完整的数据
|
| 175 |
+
return
|
| 176 |
+
|
| 177 |
+
# 检查参数是否看起来完整(基本的完整性验证)
|
| 178 |
+
is_args_complete = self._is_arguments_complete(arguments, arguments_raw)
|
| 179 |
|
| 180 |
# 检查是否是新工具
|
| 181 |
if tool_id not in self.active_tools:
|
| 182 |
+
logger.debug(f"🎯 发现新工具: {tool_name}(id={tool_id}), 参数完整性: {is_args_complete}")
|
| 183 |
|
| 184 |
self.active_tools[tool_id] = {
|
| 185 |
"id": tool_id,
|
| 186 |
"name": tool_name,
|
| 187 |
"arguments": arguments,
|
| 188 |
+
"arguments_raw": arguments_raw,
|
| 189 |
"status": "active",
|
| 190 |
"sent_start": False,
|
| 191 |
+
"last_sent_args": {}, # 跟踪上次发送的参数
|
| 192 |
+
"args_complete": is_args_complete,
|
| 193 |
+
"pending_send": True # 标记需要发送
|
| 194 |
}
|
| 195 |
|
| 196 |
+
# 只有在参数看起来完整时才发送工具开始信号
|
| 197 |
+
if is_stream and is_args_complete:
|
| 198 |
+
yield self._create_tool_start_chunk(tool_id, tool_name, arguments)
|
| 199 |
self.active_tools[tool_id]["sent_start"] = True
|
| 200 |
+
self.active_tools[tool_id]["last_sent_args"] = arguments.copy()
|
| 201 |
+
self.active_tools[tool_id]["pending_send"] = False
|
| 202 |
+
logger.debug(f"📤 发送完整工具开始: {tool_name}(id={tool_id})")
|
| 203 |
+
|
| 204 |
+
else:
|
| 205 |
+
# 更新现有工具
|
| 206 |
+
current_tool = self.active_tools[tool_id]
|
| 207 |
+
|
| 208 |
+
# 检查是否有实质性改进
|
| 209 |
+
if self._is_significant_improvement(current_tool["arguments"], arguments,
|
| 210 |
+
current_tool["arguments_raw"], arguments_raw):
|
| 211 |
+
logger.debug(f"🔄 工具参数有实质性改进: {tool_name}(id={tool_id})")
|
| 212 |
+
|
| 213 |
+
current_tool["arguments"] = arguments
|
| 214 |
+
current_tool["arguments_raw"] = arguments_raw
|
| 215 |
+
current_tool["args_complete"] = is_args_complete
|
| 216 |
+
|
| 217 |
+
# 如果之前没有发送过开始信号,且现在参数完整,发送开始信号
|
| 218 |
+
if is_stream and not current_tool["sent_start"] and is_args_complete:
|
| 219 |
+
yield self._create_tool_start_chunk(tool_id, tool_name, arguments)
|
| 220 |
+
current_tool["sent_start"] = True
|
| 221 |
+
current_tool["last_sent_args"] = arguments.copy()
|
| 222 |
+
current_tool["pending_send"] = False
|
| 223 |
+
logger.debug(f"📤 发送延迟的工具开始: {tool_name}(id={tool_id})")
|
| 224 |
+
|
| 225 |
+
# 如果已经发送过开始信号,且参数有显著改进,发送参数更新
|
| 226 |
+
elif is_stream and current_tool["sent_start"] and is_args_complete:
|
| 227 |
+
if self._should_send_argument_update(current_tool["last_sent_args"], arguments):
|
| 228 |
+
yield self._create_tool_arguments_chunk(tool_id, arguments)
|
| 229 |
+
current_tool["last_sent_args"] = arguments.copy()
|
| 230 |
+
logger.debug(f"📤 发送参数更新: {tool_name}(id={tool_id})")
|
| 231 |
+
|
| 232 |
+
def _is_arguments_complete(self, arguments: Dict[str, Any], arguments_raw: str) -> bool:
|
| 233 |
+
"""
|
| 234 |
+
检查参数是否看起来完整
|
| 235 |
+
"""
|
| 236 |
+
if not arguments:
|
| 237 |
+
return False
|
| 238 |
+
|
| 239 |
+
# 检查原始字符串是否看起来完整
|
| 240 |
+
if not arguments_raw or not arguments_raw.strip():
|
| 241 |
+
return False
|
| 242 |
+
|
| 243 |
+
# 检查是否有明显的截断迹象
|
| 244 |
+
raw_stripped = arguments_raw.strip()
|
| 245 |
+
|
| 246 |
+
# 如果原始字符串不以}结尾,可能是截断的
|
| 247 |
+
if not raw_stripped.endswith('}') and not raw_stripped.endswith('"'):
|
| 248 |
+
return False
|
| 249 |
+
|
| 250 |
+
# 检查是否有不完整的URL(常见的截断情况)
|
| 251 |
+
for key, value in arguments.items():
|
| 252 |
+
if isinstance(value, str):
|
| 253 |
+
# 检查URL是否看起来完整
|
| 254 |
+
if 'http' in value.lower():
|
| 255 |
+
# 如果URL太短或以不完整的域名结尾,可能是截断的
|
| 256 |
+
if len(value) < 10 or value.endswith('.go') or value.endswith('.goo'):
|
| 257 |
+
return False
|
| 258 |
+
|
| 259 |
+
# 检查其他可能的截断迹象
|
| 260 |
+
if len(value) > 0 and value[-1] in ['.', '/', ':', '=']:
|
| 261 |
+
# 以这些字符结尾可能表示截断
|
| 262 |
+
return False
|
| 263 |
+
|
| 264 |
+
return True
|
| 265 |
+
|
| 266 |
+
def _is_significant_improvement(self, old_args: Dict[str, Any], new_args: Dict[str, Any],
|
| 267 |
+
old_raw: str, new_raw: str) -> bool:
|
| 268 |
+
"""
|
| 269 |
+
检查新参数是否比旧参数有显著改进
|
| 270 |
+
"""
|
| 271 |
+
# 如果新参数为空,不是改进
|
| 272 |
+
if not new_args:
|
| 273 |
+
return False
|
| 274 |
+
|
| 275 |
+
if len(new_args) > len(old_args):
|
| 276 |
+
return True
|
| 277 |
+
|
| 278 |
+
# 检查值的改进
|
| 279 |
+
for key, new_value in new_args.items():
|
| 280 |
+
old_value = old_args.get(key, "")
|
| 281 |
+
|
| 282 |
+
if isinstance(new_value, str) and isinstance(old_value, str):
|
| 283 |
+
# 如果新值明显更长且更完整,是改进
|
| 284 |
+
if len(new_value) > len(old_value) + 5: # 至少长5个字符才算显著改进
|
| 285 |
+
return True
|
| 286 |
+
|
| 287 |
+
# 如果旧值看起来是截断的,新值更完整,是改进
|
| 288 |
+
if old_value.endswith(('.go', '.goo', '.com/', 'http')) and len(new_value) > len(old_value):
|
| 289 |
+
return True
|
| 290 |
+
|
| 291 |
+
# 检查原始字符串的改进
|
| 292 |
+
if len(new_raw) > len(old_raw) + 10: # 原始字符串显著增长
|
| 293 |
+
return True
|
| 294 |
+
|
| 295 |
+
return False
|
| 296 |
+
|
| 297 |
+
def _should_send_argument_update(self, last_sent: Dict[str, Any], new_args: Dict[str, Any]) -> bool:
|
| 298 |
+
"""
|
| 299 |
+
判断是否应该发送参数更新 - 更严格的标准
|
| 300 |
+
"""
|
| 301 |
+
# 如果参数完全相同,不发送
|
| 302 |
+
if last_sent == new_args:
|
| 303 |
+
return False
|
| 304 |
+
|
| 305 |
+
# 如果新参数为空但之前有参数,不发送(避免倒退)
|
| 306 |
+
if not new_args and last_sent:
|
| 307 |
+
return False
|
| 308 |
+
|
| 309 |
+
# 如果新参数有更多键,发送更新
|
| 310 |
+
if len(new_args) > len(last_sent):
|
| 311 |
+
return True
|
| 312 |
|
| 313 |
+
# 检查是否有值变得显著更完整
|
| 314 |
+
for key, new_value in new_args.items():
|
| 315 |
+
last_value = last_sent.get(key, "")
|
| 316 |
+
if isinstance(new_value, str) and isinstance(last_value, str):
|
| 317 |
+
# 只有在值显著增长时才发送更新(避免微小变化)
|
| 318 |
+
if len(new_value) > len(last_value) + 5:
|
| 319 |
+
return True
|
| 320 |
+
elif new_value != last_value and new_value: # 确保新值不为空
|
| 321 |
+
return True
|
| 322 |
|
| 323 |
+
return False
|
|
|
|
|
|
|
|
|
|
| 324 |
|
| 325 |
def _handle_partial_tool_block(self, block_content: str, is_stream: bool) -> Generator[str, None, None]:
|
| 326 |
"""
|
|
|
|
| 343 |
|
| 344 |
# 如果是新工具,先创建记录
|
| 345 |
if tool_id not in self.active_tools:
|
| 346 |
+
# 尝试解析部分参数为字典
|
| 347 |
+
partial_args_dict = self._parse_partial_arguments(partial_args)
|
| 348 |
+
|
| 349 |
self.active_tools[tool_id] = {
|
| 350 |
"id": tool_id,
|
| 351 |
"name": tool_name,
|
| 352 |
+
"arguments": partial_args_dict,
|
| 353 |
"status": "partial",
|
| 354 |
"sent_start": False,
|
| 355 |
+
"last_sent_args": {},
|
| 356 |
+
"args_complete": False,
|
| 357 |
"partial_args": partial_args
|
| 358 |
}
|
| 359 |
|
| 360 |
if is_stream:
|
| 361 |
+
yield self._create_tool_start_chunk(tool_id, tool_name, partial_args_dict)
|
| 362 |
self.active_tools[tool_id]["sent_start"] = True
|
| 363 |
+
self.active_tools[tool_id]["last_sent_args"] = partial_args_dict.copy()
|
| 364 |
else:
|
| 365 |
# 更新部分参数
|
| 366 |
self.active_tools[tool_id]["partial_args"] = partial_args
|
| 367 |
+
# 尝试更新解析的参数
|
| 368 |
+
new_partial_dict = self._parse_partial_arguments(partial_args)
|
| 369 |
+
if new_partial_dict != self.active_tools[tool_id]["arguments"]:
|
| 370 |
+
self.active_tools[tool_id]["arguments"] = new_partial_dict
|
| 371 |
|
| 372 |
except Exception as e:
|
| 373 |
logger.debug(f"📦 部分块解析失败: {e}")
|
| 374 |
|
| 375 |
def _clean_arguments_string(self, arguments_raw: str) -> str:
|
| 376 |
"""
|
| 377 |
+
清理和标准化参数字符串,改进对不完整JSON的处理
|
| 378 |
"""
|
| 379 |
if not arguments_raw:
|
| 380 |
return "{}"
|
|
|
|
| 393 |
elif cleaned.startswith('"{\\"') and cleaned.endswith('\\"}'):
|
| 394 |
# 双重转义的情况
|
| 395 |
cleaned = cleaned[1:-1].replace('\\"', '"')
|
| 396 |
+
elif cleaned.startswith('"') and cleaned.endswith('"'):
|
| 397 |
+
# 简单的引号包围,去除外层引号
|
| 398 |
+
cleaned = cleaned[1:-1]
|
| 399 |
+
|
| 400 |
+
# 处理不完整的JSON字符串
|
| 401 |
+
cleaned = self._fix_incomplete_json(cleaned)
|
| 402 |
|
| 403 |
# 标准化空格(移除JSON中的多余空格,但保留字符串值中的空格)
|
| 404 |
try:
|
|
|
|
| 409 |
cleaned = json.dumps(parsed, ensure_ascii=False, separators=(',', ':'))
|
| 410 |
except json.JSONDecodeError:
|
| 411 |
# 如果解析失败,只做基本的空格清理
|
| 412 |
+
logger.debug(f"📦 JSON标准化失败,保持原样: {cleaned[:50]}...")
|
| 413 |
|
| 414 |
return cleaned
|
| 415 |
|
| 416 |
+
def _fix_incomplete_json(self, json_str: str) -> str:
|
| 417 |
+
"""
|
| 418 |
+
修复不完整的JSON字符串
|
| 419 |
+
"""
|
| 420 |
+
if not json_str:
|
| 421 |
+
return "{}"
|
| 422 |
+
|
| 423 |
+
# 确保以{开头
|
| 424 |
+
if not json_str.startswith('{'):
|
| 425 |
+
json_str = '{' + json_str
|
| 426 |
+
|
| 427 |
+
# 处理不完整的字符串值
|
| 428 |
+
if json_str.count('"') % 2 != 0:
|
| 429 |
+
# 奇数个引号,可能有未闭合的字符串
|
| 430 |
+
json_str += '"'
|
| 431 |
+
|
| 432 |
+
# 确保以}结尾
|
| 433 |
+
if not json_str.endswith('}'):
|
| 434 |
+
json_str += '}'
|
| 435 |
+
|
| 436 |
+
return json_str
|
| 437 |
+
|
| 438 |
def _parse_partial_arguments(self, arguments_raw: str) -> Dict[str, Any]:
|
| 439 |
"""
|
| 440 |
解析不完整的参数字符串,尽可能提取有效信息
|
|
|
|
| 519 |
|
| 520 |
def _complete_active_tools(self, is_stream: bool) -> Generator[str, None, None]:
    """
    Complete every active tool call, first flushing any deferred start chunks.

    In streaming mode, a tool that is still marked ``pending_send`` and has
    not had its start chunk sent gets one now, but only if its arguments were
    judged complete; otherwise it is logged as skipped. Every tool is then
    marked completed and moved to ``completed_tools``.

    Args:
        is_stream: Whether chunks should be yielded to the client.

    Yields:
        Deferred tool-start chunks, followed by one tool-finish chunk when
        streaming and at least one tool is recorded.
    """
    tools_to_send = []

    for tool_id, tool in self.active_tools.items():
        awaiting_start = (
            is_stream
            and tool.get("pending_send", False)
            and not tool.get("sent_start", False)
        )

        if awaiting_start and tool.get("args_complete", False):
            # Arguments look complete: send the deferred start chunk now.
            logger.debug(f"📤 完成时发送待发送工具: {tool['name']}(id={tool_id})")
            yield self._create_tool_start_chunk(tool_id, tool["name"], tool["arguments"])
            tool.update(sent_start=True, pending_send=False)
            tools_to_send.append(tool)
        elif awaiting_start:
            logger.debug(f"⚠️ 跳过不完整的工具: {tool['name']}(id={tool_id})")

        tool["status"] = "completed"
        self.completed_tools.append(tool)
        logger.debug(f"✅ 完成工具调用: {tool['name']}(id={tool_id})")

    self.active_tools.clear()

    if is_stream and (self.completed_tools or tools_to_send):
        # Emit the tool-finish signal.
        yield self._create_tool_finish_chunk()
|
| 547 |
|
|
|
|
| 573 |
|
| 574 |
if is_stream:
|
| 575 |
logger.info("🏁 发送工具调用完成信号")
|
| 576 |
+
yield "data: [DONE]"
|
| 577 |
|
| 578 |
# 重置工具调用状态
|
| 579 |
self.has_tool_call = False
|
|
|
|
| 614 |
self.completed_tools.clear()
|
| 615 |
self.tool_blocks_cache.clear()
|
| 616 |
|
| 617 |
+
def _create_tool_start_chunk(self, tool_id: str, tool_name: str, initial_args: Dict[str, Any] = None) -> str:
|
| 618 |
+
"""创建工具调用开始的chunk,支持初始参数"""
|
| 619 |
+
# 使用提供的初始参数,如果没有则使用空字典
|
| 620 |
+
args_dict = initial_args or {}
|
| 621 |
+
args_str = json.dumps(args_dict, ensure_ascii=False)
|
| 622 |
+
|
| 623 |
chunk = {
|
| 624 |
"choices": [
|
| 625 |
{
|
|
|
|
| 630 |
{
|
| 631 |
"id": tool_id,
|
| 632 |
"type": "function",
|
| 633 |
+
"function": {"name": tool_name, "arguments": args_str},
|
| 634 |
}
|
| 635 |
],
|
| 636 |
},
|
|
|
|
| 648 |
return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
| 649 |
|
| 650 |
def _create_tool_arguments_chunk(self, tool_id: str, arguments: Dict) -> str:
|
| 651 |
+
"""创建工具参数的chunk - 只包含参数更新,不包含函数名"""
|
| 652 |
chunk = {
|
| 653 |
"choices": [
|
| 654 |
{
|
| 655 |
"delta": {
|
|
|
|
|
|
|
| 656 |
"tool_calls": [
|
| 657 |
{
|
| 658 |
"id": tool_id,
|
| 659 |
+
"function": {"arguments": json.dumps(arguments, ensure_ascii=False)},
|
|
|
|
| 660 |
}
|
| 661 |
],
|
| 662 |
},
|