|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
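SSE Tool Handler

Tool-call handler that processes Z.AI SSE stream data and converts it into an
OpenAI-compatible format.

Main features:
- Parse tool calls delivered in glm_block format
- Extract the complete arguments from metadata.arguments
- Support multi-phase processing: thinking → tool_call → other → answer
- Emit streaming responses that follow the OpenAI API specification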
|
|
""" |
|
|
|
|
|
import json |
|
|
import time |
|
|
from typing import Dict, Any, Generator |
|
|
from enum import Enum |
|
|
|
|
|
from app.utils.logger import get_logger |
|
|
|
|
|
# Module-level logger obtained from app.utils.logger; used by the classes below.
logger = get_logger()
|
|
|
|
|
|
|
|
class SSEPhase(Enum):
    """Values of the ``phase`` field that SSEToolHandler dispatches on."""

    # Reasoning text; its delta_content is forwarded as ordinary content.
    THINKING = "thinking"
    # Tool invocation data carried in edit_content (glm_block markup).
    TOOL_CALL = "tool_call"
    # Usage statistics and the end-of-tool-call marker.
    OTHER = "other"
    # Final answer text, buffered and flushed in batches.
    ANSWER = "answer"
    # End of the conversation; triggers the closing chunk and state reset.
    DONE = "done"
|
|
|
|
|
|
|
|
class SSEToolHandler:
    """Translate Z.AI SSE chunks into OpenAI ``chat.completion.chunk`` lines.

    The handler is a small state machine driven by the ``phase`` field of each
    upstream chunk (see ``SSEPhase``). Tool calls are assembled from
    ``<glm_block ...>`` metadata plus follow-up argument fragments and are
    emitted once complete; answer text is buffered and flushed in batches.
    """

    def __init__(self, model: str, stream: bool = True):
        """Create a handler.

        Args:
            model: Model name copied into every emitted chunk.
            stream: When False, state is still tracked but no SSE lines are
                yielded.
        """
        self.model = model
        self.stream = stream

        # Phase tracking.
        self.current_phase = None
        self.has_tool_call = False

        # State of the tool call currently being assembled.
        self.tool_id = ""
        self.tool_name = ""
        self.tool_args = ""
        self.tool_call_usage = {}
        self.content_index = 0  # ``index`` of the next tool call in the response

        # Answer-text batching.
        self.content_buffer = ""
        self.buffer_size = 0
        self.last_flush_time = time.time()
        self.flush_interval = 0.05  # seconds
        self.max_buffer_size = 100  # characters

        logger.debug(f"🔧 初始化工具处理器: model={model}, stream={stream}")

    def process_sse_chunk(self, chunk_data: Dict[str, Any]) -> Generator[str, None, None]:
        """Process one upstream SSE chunk.

        Args:
            chunk_data: Decoded Z.AI SSE payload. The keys read are ``phase``,
                ``edit_content``, ``delta_content``, ``edit_index`` and
                ``usage``.

        Yields:
            str: OpenAI-format SSE lines (``data: ...`` followed by a blank line).

        Any exception raised while handling the chunk is logged and swallowed
        so that a single bad chunk does not terminate the stream.
        """
        try:
            phase = chunk_data.get("phase")
            edit_content = chunk_data.get("edit_content", "")
            delta_content = chunk_data.get("delta_content", "")
            edit_index = chunk_data.get("edit_index")
            usage = chunk_data.get("usage", {})

            if not phase:
                logger.warning("⚠️ 收到无效的 SSE 块:缺少 phase 字段")
                return

            if phase != self.current_phase:
                # Emit text buffered in the previous phase before switching.
                if self.content_buffer:
                    yield from self._flush_content_buffer()

                logger.info(f"📈 SSE 阶段变化: {self.current_phase} → {phase}")
                content_preview = edit_content or delta_content
                if content_preview:
                    logger.debug(f" 📝 内容预览: {content_preview[:1000]}{'...' if len(content_preview) > 1000 else ''}")
                if edit_index is not None:
                    logger.debug(f" 📍 edit_index: {edit_index}")
                self.current_phase = phase

            if phase == SSEPhase.THINKING.value:
                yield from self._process_thinking_phase(delta_content)
            elif phase == SSEPhase.TOOL_CALL.value:
                yield from self._process_tool_call_phase(edit_content)
            elif phase == SSEPhase.OTHER.value:
                yield from self._process_other_phase(usage, edit_content)
            elif phase == SSEPhase.ANSWER.value:
                yield from self._process_answer_phase(delta_content)
            elif phase == SSEPhase.DONE.value:
                yield from self._process_done_phase(chunk_data)
            else:
                logger.warning(f"⚠️ 未知的 SSE 阶段: {phase}")

        except Exception as e:
            # Deliberate best-effort boundary: log and keep the stream alive.
            logger.error(f"❌ 处理 SSE 块时发生错误: {e}")
            logger.debug(f" 📦 错误块数据: {chunk_data}")

    def _process_thinking_phase(self, delta_content: str) -> Generator[str, None, None]:
        """Forward thinking text immediately as a content chunk (stream mode only)."""
        if not delta_content:
            return

        logger.debug(f"🤔 思考内容: +{len(delta_content)} 字符")

        if self.stream:
            chunk = self._create_content_chunk(delta_content)
            yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    def _process_tool_call_phase(self, edit_content: str) -> Generator[str, None, None]:
        """Handle a tool_call chunk.

        Content containing ``<glm_block `` is delegated to
        ``_handle_glm_blocks``. Any other content is treated as a raw argument
        fragment of the active tool call and appended to ``tool_args``.
        """
        if not edit_content:
            return

        logger.debug(f"🔧 进入工具调用阶段,内容长度: {len(edit_content)}")

        if "<glm_block " in edit_content:
            yield from self._handle_glm_blocks(edit_content)
            return

        if not self.has_tool_call:
            return

        # Everything before the `", "result"` marker belongs to the arguments.
        # `>= 0` matters: when the chunk starts with the marker the fragment is
        # empty and the result payload must not be appended.
        result_pos = edit_content.find('", "result"')
        if result_pos >= 0:
            param_fragment = edit_content[:result_pos]
            self.tool_args += param_fragment
            logger.debug(f"📦 累积参数片段: {param_fragment}")
        else:
            self.tool_args += edit_content
            logger.debug(f"📦 累积参数片段: {edit_content[:100]}...")

    def _handle_glm_blocks(self, edit_content: str) -> Generator[str, None, None]:
        """Process content that contains one or more ``<glm_block `` sections.

        The text before the first block may carry the tail of the active tool
        call's arguments. Each complete block that follows starts a new tool
        call; the previously active call is finished and emitted first.
        """
        blocks = edit_content.split('<glm_block ')
        logger.debug(f"📦 分割得到 {len(blocks)} 个块")

        head = blocks[0]
        if head.strip() and self.has_tool_call:
            logger.debug("📦 从第一个块提取参数片段")
            # Search only the leading text: the block metadata that follows may
            # itself contain a "result" key and must never leak into the args.
            result_pos = head.find('"result"')
            if result_pos >= 0:
                # Step back three characters to drop the `", ` separator;
                # clamp so a marker near the start cannot produce a negative
                # slice bound.
                param_fragment = head[:max(result_pos - 3, 0)]
                self.tool_args += param_fragment
                logger.debug(f"📦 累积参数片段: {param_fragment}")

        for block in blocks[1:]:
            # Incomplete blocks (no closing tag) are skipped.
            if "</glm_block>" not in block:
                continue

            if self.has_tool_call:
                # Close the argument string of the previous call, then emit it.
                self.tool_args += '"'
                yield from self._finish_current_tool()

            yield from self._process_metadata_block(block)

    def _process_metadata_block(self, block: str) -> Generator[str, None, None]:
        """Start a new tool call from the JSON payload of one glm_block.

        Reads ``data.metadata`` from the payload and stores ``id``, ``name``
        and ``arguments`` on the handler. Parse failures are logged and the
        block is ignored. This generator never yields anything; it is a
        generator only so callers can use ``yield from`` uniformly.
        """
        try:
            start_pos = block.find('>')
            end_pos = block.rfind('</glm_block>')

            if start_pos == -1 or end_pos == -1:
                logger.warning(f"❌ 无法找到 JSON 内容边界: {block[:1000]}...")
                return

            json_content = block[start_pos + 1:end_pos]
            logger.debug(f"📦 提取的 JSON 内容: {json_content[:1000]}...")

            metadata_obj = json.loads(json_content)

            if "data" in metadata_obj and "metadata" in metadata_obj["data"]:
                metadata = metadata_obj["data"]["metadata"]

                # Fall back when the key is missing or holds an empty value.
                self.tool_id = metadata.get("id") or f"call_{int(time.time() * 1000000)}"
                self.tool_name = metadata.get("name") or "unknown"
                self.has_tool_call = True

                if "arguments" in metadata:
                    arguments_str = metadata["arguments"]
                    # Drop a trailing quote; later fragments continue the string.
                    self.tool_args = arguments_str[:-1] if arguments_str.endswith('"') else arguments_str
                    logger.debug(f"🎯 新工具调用: {self.tool_name}(id={self.tool_id}), 初始参数: {self.tool_args}")
                else:
                    self.tool_args = "{}"
                    logger.debug(f"🎯 新工具调用: {self.tool_name}(id={self.tool_id}), 空参数")

        except (json.JSONDecodeError, KeyError, AttributeError, TypeError) as e:
            # TypeError covers payloads whose "data"/"metadata" are not mappings.
            logger.error(f"❌ 解析工具元数据失败: {e}, 块内容: {block[:1000]}...")

        # Keeps this function a generator without emitting anything.
        yield from ()

    def _process_other_phase(self, usage: Dict[str, Any], edit_content: str = "") -> Generator[str, None, None]:
        """Handle an ``other`` chunk: record usage and detect the end of a tool call.

        When a tool call is active and ``edit_content`` starts with ``null,``
        the call is emitted, followed by ``[DONE]`` in stream mode, and all
        state is reset.
        """
        if usage:
            self.tool_call_usage = usage
            logger.debug(f"📊 保存使用统计: {usage}")

        if self.has_tool_call and edit_content and edit_content.startswith("null,"):
            logger.info("🏁 检测到工具调用结束标记")

            yield from self._finish_current_tool()

            if self.stream:
                yield "data: [DONE]\n\n"

            self._reset_all_state()

    def _process_answer_phase(self, delta_content: str) -> Generator[str, None, None]:
        """Buffer answer text and flush it in batches.

        A flush happens when the buffer reaches ``max_buffer_size``, when
        ``flush_interval`` seconds have passed since the last flush, or when
        the new text contains a newline or one of the characters 。 ! ?
        """
        if not delta_content:
            return

        logger.info(f"📝 工具处理器收到答案内容: {delta_content[:50]}...")

        self.content_buffer += delta_content
        self.buffer_size += len(delta_content)

        should_flush = (
            self.buffer_size >= self.max_buffer_size
            or time.time() - self.last_flush_time >= self.flush_interval
            or any(mark in delta_content for mark in ('\n', '。', '!', '?'))
        )

        if should_flush:
            yield from self._flush_content_buffer()

    def _flush_content_buffer(self) -> Generator[str, None, None]:
        """Emit the buffered answer text as one content chunk and clear the buffer.

        In non-stream mode the buffer is cleared without emitting anything.
        """
        if not self.content_buffer:
            return

        pending = self.content_buffer
        logger.info(f"💬 工具处理器刷新缓冲区: {self.buffer_size} 字符 - {pending[:50]}...")

        # Reset before yielding so the same text can never be emitted twice if
        # the consumer abandons this generator right after the yield.
        self.content_buffer = ""
        self.buffer_size = 0
        self.last_flush_time = time.time()

        if self.stream:
            chunk = self._create_content_chunk(pending)
            output_data = f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            logger.info(f"➡️ 工具处理器输出: {output_data[:100]}...")
            yield output_data

    def _process_done_phase(self, chunk_data: Dict[str, Any]) -> Generator[str, None, None]:
        """Finish the conversation.

        Flushes buffered text, emits any pending tool call, then (in stream
        mode) the closing ``stop`` chunk, carrying ``usage`` when the upstream
        chunk has it, followed by ``[DONE]``. All state is reset afterwards.
        """
        logger.info("🏁 对话完成")

        if self.content_buffer:
            yield from self._flush_content_buffer()

        if self.has_tool_call:
            yield from self._finish_current_tool()

        if self.stream:
            final_chunk = self._build_chunk({}, "stop")
            if "usage" in chunk_data:
                final_chunk["usage"] = chunk_data["usage"]

            yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        self._reset_all_state()

    def _finish_current_tool(self) -> Generator[str, None, None]:
        """Emit the active tool call (start, arguments, finish chunks) and clear it.

        Does nothing when no tool call is active. In non-stream mode nothing
        is yielded but the tool state is still cleared.
        """
        if not self.has_tool_call:
            return

        fixed_args = self._fix_tool_arguments(self.tool_args)
        logger.debug(f"✅ 完成工具调用: {self.tool_name}, 参数: {fixed_args}")

        if self.stream:
            start_chunk = self._create_tool_start_chunk()
            yield f"data: {json.dumps(start_chunk, ensure_ascii=False)}\n\n"

            args_chunk = self._create_tool_arguments_chunk(fixed_args)
            yield f"data: {json.dumps(args_chunk, ensure_ascii=False)}\n\n"

            finish_chunk = self._create_tool_finish_chunk()
            yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"

        self._reset_tool_state()
        # Give the next tool call in this response its own index; clients merge
        # tool_call deltas by index, so reusing 0 would fuse separate calls.
        self.content_index += 1

    def _fix_tool_arguments(self, raw_args: str) -> str:
        """Return ``raw_args`` repaired into a valid JSON object string.

        Uses the ``json_repair`` library (imported lazily). Returns ``"{}"``
        for empty input, when repair fails, or when the repaired value is not
        a JSON object.
        """
        if not raw_args or raw_args == "{}":
            return "{}"

        logger.debug(f"🔧 开始修复参数: {raw_args[:1000]}{'...' if len(raw_args) > 1000 else ''}")

        try:
            processed_args = self._preprocess_json_string(raw_args.strip())

            from json_repair import repair_json
            repaired_json = repair_json(processed_args)
            logger.debug(f"🔧 json-repair 修复结果: {repaired_json}")

            args_obj = json.loads(repaired_json)
            if not isinstance(args_obj, dict):
                # Tool arguments must be a JSON object; anything else is unusable.
                logger.warning(f"⚠️ 工具参数不是 JSON 对象: {type(args_obj).__name__}, 使用空参数")
                return "{}"

            args_obj = self._post_process_args(args_obj)
            return json.dumps(args_obj, ensure_ascii=False)

        except Exception as e:
            logger.error(f"❌ JSON 修复失败: {e}, 原始参数: {raw_args[:1000]}..., 使用空参数")
            return "{}"

    def _post_process_args(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
        """Apply the path and command clean-ups to parsed arguments.

        Non-dict values are returned unchanged.
        """
        if not isinstance(args_obj, dict):
            return args_obj

        args_obj = self._fix_path_escaping_in_args(args_obj)
        args_obj = self._fix_command_quotes(args_obj)
        return args_obj

    def _preprocess_json_string(self, text: str) -> str:
        """Pre-clean a raw argument string before it is passed to json-repair.

        Two fixes are applied: a missing opening brace is added when the text
        ends with ``}``, and a backslash before a quote that is followed by
        ``}``, whitespace or ``,`` is removed.
        """
        import re

        if text.endswith('}') and not text.startswith('{'):
            text = '{' + text
            logger.debug("🔧 补全开始括号")

        text, replaced = re.subn(r'([^\\])\\"([}\s,])', r'\1"\2', text)
        if replaced:
            logger.debug("🔧 修复末尾多余的反斜杠")

        return text

    def _fix_path_escaping_in_args(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
        """Collapse doubled backslashes in Windows drive paths.

        Only the ``file_path``, ``path``, ``directory`` and ``folder`` fields
        are inspected, and only string values that start with a drive letter
        (``X:``) and contain two consecutive backslashes are rewritten. The
        dict is modified in place and returned.
        """
        import re

        for field in ('file_path', 'path', 'directory', 'folder'):
            path_value = args_obj.get(field)
            if not isinstance(path_value, str):
                continue
            if not (re.match(r'[A-Za-z]:', path_value) and '\\\\' in path_value):
                continue

            logger.debug(f"🔍 检查路径字段 {field}: {repr(path_value)}")

            # Splitting on a single backslash and dropping empty parts removes
            # every run of repeated separators.
            clean_parts = [part for part in path_value.split('\\') if part]
            fixed_path = '\\'.join(clean_parts) if len(clean_parts) > 1 else path_value

            logger.debug(f"🔍 修复后路径: {repr(fixed_path)}")

            if fixed_path != path_value:
                args_obj[field] = fixed_path
                logger.debug(f"🔧 修复字段 {field} 的路径转义: {path_value} -> {fixed_path}")
            else:
                logger.debug(f"🔍 路径无需修复: {path_value}")

        return args_obj

    def _fix_command_quotes(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
        """Drop one trailing quote from a ``command`` string that ends with ``""``.

        The dict is modified in place and returned.
        """
        command = args_obj.get('command')
        if isinstance(command, str) and command.endswith('""'):
            logger.debug(f"🔧 发现命令末尾多余引号: {command}")

            fixed_command = command[:-1]
            args_obj['command'] = fixed_command
            logger.debug(f"🔧 修复命令引号: {command} -> {fixed_command}")

        return args_obj

    def _build_chunk(self, delta: Dict[str, Any], finish_reason: Any = None) -> Dict[str, Any]:
        """Build the ``chat.completion.chunk`` envelope shared by all emitted chunks."""
        created = int(time.time())
        return {
            "id": f"chatcmpl-{created}",
            "object": "chat.completion.chunk",
            "created": created,
            "model": self.model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }

    def _create_content_chunk(self, content: str) -> Dict[str, Any]:
        """Build a chunk carrying assistant text."""
        return self._build_chunk({"role": "assistant", "content": content})

    def _create_tool_start_chunk(self) -> Dict[str, Any]:
        """Build the chunk that announces the active tool call (id, name, empty arguments)."""
        return self._build_chunk({
            "role": "assistant",
            "tool_calls": [{
                "index": self.content_index,
                "id": self.tool_id,
                "type": "function",
                "function": {
                    "name": self.tool_name,
                    "arguments": "",
                },
            }],
        })

    def _create_tool_arguments_chunk(self, arguments: str) -> Dict[str, Any]:
        """Build the chunk carrying the arguments string of the active tool call."""
        return self._build_chunk({
            "tool_calls": [{
                "index": self.content_index,
                "id": self.tool_id,
                "function": {
                    "arguments": arguments,
                },
            }],
        })

    def _create_tool_finish_chunk(self) -> Dict[str, Any]:
        """Build the chunk with ``finish_reason="tool_calls"``; adds saved usage if any."""
        chunk = self._build_chunk({"tool_calls": []}, "tool_calls")

        if self.tool_call_usage:
            chunk["usage"] = self.tool_call_usage

        return chunk

    def _reset_tool_state(self):
        """Clear the state of the tool call being assembled."""
        self.tool_id = ""
        self.tool_name = ""
        self.tool_args = ""
        self.has_tool_call = False

    def _reset_all_state(self):
        """Reset the handler to its initial state.

        Any text still sitting in the content buffer is discarded; callers
        flush before calling this when that text must be emitted.
        """
        self._reset_tool_state()
        self.current_phase = None
        self.tool_call_usage = {}

        self.content_buffer = ""
        self.buffer_size = 0
        self.last_flush_time = time.time()

        self.content_index = 0
        logger.debug("🔄 重置所有处理器状态")
|
|
|