Spaces:
Paused
Paused
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| """ | |
| SSE Tool Handler | |
| 处理 Z.AI SSE 流数据并转换为 OpenAI 兼容格式的工具调用处理器。 | |
| 主要功能: | |
| - 解析 glm_block 格式的工具调用 | |
| - 从 metadata.arguments 提取完整参数 | |
| - 支持多阶段处理:thinking → tool_call → other → answer | |
| - 输出符合 OpenAI API 规范的流式响应 | |
| """ | |
| import json | |
| import time | |
| from typing import Dict, Any, Generator | |
| from enum import Enum | |
| from app.utils.logger import get_logger | |
| # Module-wide logger obtained from the project's logging helper; shared by every class below. | |
| logger = get_logger() | |
class SSEPhase(Enum):
    """Names of the processing phases carried in a chunk's ``phase`` field.

    Each member's value is the literal string that is compared against the
    incoming chunk when deciding which phase handler to run.
    """

    # Declared in the order the module docstring lists the phases:
    # thinking -> tool_call -> other -> answer, followed by the terminator.
    THINKING = "thinking"
    TOOL_CALL = "tool_call"
    OTHER = "other"
    ANSWER = "answer"
    DONE = "done"
class SSEToolHandler:
    """Translate Z.AI SSE chunks into OpenAI-style ``chat.completion.chunk`` lines.

    The handler is a small state machine driven by the ``phase`` field of each
    upstream chunk. Tool calls arrive as ``<glm_block ...>`` fragments whose
    argument JSON is spread over several chunks; the pieces are accumulated in
    ``tool_args`` and emitted as one complete tool call when the call ends.
    Answer text is buffered briefly and flushed in larger pieces.
    """

    def __init__(self, model: str, stream: bool = True):
        """Create a handler.

        Args:
            model: Model name copied into every emitted chunk.
            stream: When False, no ``data:`` lines are produced; state is
                still tracked.
        """
        self.model = model
        self.stream = stream

        # Phase tracking.
        self.current_phase = None
        self.has_tool_call = False

        # State of the tool call currently being assembled.
        self.tool_id = ""
        self.tool_name = ""
        self.tool_args = ""
        self.tool_call_usage = {}
        self.content_index = 0  # index of the next tool call within this response

        # Answer-text buffering.
        self.content_buffer = ""
        self.buffer_size = 0
        self.last_flush_time = time.time()
        self.flush_interval = 0.05  # seconds between forced flushes
        self.max_buffer_size = 100  # characters held before a forced flush

        # Identity shared by all chunks of one response; created lazily by
        # _build_chunk and cleared by _reset_all_state.
        self._completion_id = None
        self._completion_created = None

        logger.debug(f"🔧 初始化工具处理器: model={model}, stream={stream}")

    def process_sse_chunk(self, chunk_data: Dict[str, Any]) -> Generator[str, None, None]:
        """Process one upstream SSE chunk and yield OpenAI-format SSE lines.

        Args:
            chunk_data: Decoded Z.AI chunk. The keys read here are ``phase``,
                ``edit_content``, ``delta_content``, ``edit_index`` and
                ``usage``.

        Yields:
            ``data: ...`` lines, each terminated by a blank line.

        Any exception raised while handling the chunk is logged and swallowed
        so that one bad chunk does not abort the stream.
        """
        try:
            phase = chunk_data.get("phase")
            edit_content = chunk_data.get("edit_content", "")
            delta_content = chunk_data.get("delta_content", "")
            edit_index = chunk_data.get("edit_index")
            usage = chunk_data.get("usage", {})

            if not phase:
                logger.warning("⚠️ 收到无效的 SSE 块:缺少 phase 字段")
                return

            if phase != self.current_phase:
                # Buffered answer text belongs to the previous phase; send it
                # before anything from the new phase.
                if self.content_buffer:
                    yield from self._flush_content_buffer()
                logger.info(f"📈 SSE 阶段变化: {self.current_phase} → {phase}")
                content_preview = edit_content or delta_content
                if content_preview:
                    logger.debug(f" 📝 内容预览: {content_preview[:1000]}{'...' if len(content_preview) > 1000 else ''}")
                if edit_index is not None:
                    logger.debug(f" 📍 edit_index: {edit_index}")
                self.current_phase = phase

            if phase == SSEPhase.THINKING.value:
                yield from self._process_thinking_phase(delta_content)
            elif phase == SSEPhase.TOOL_CALL.value:
                yield from self._process_tool_call_phase(edit_content)
            elif phase == SSEPhase.OTHER.value:
                yield from self._process_other_phase(usage, edit_content)
            elif phase == SSEPhase.ANSWER.value:
                yield from self._process_answer_phase(delta_content)
            elif phase == SSEPhase.DONE.value:
                yield from self._process_done_phase(chunk_data)
            else:
                logger.warning(f"⚠️ 未知的 SSE 阶段: {phase}")
        except Exception as e:
            # Deliberate best-effort: keep the stream alive for later chunks.
            logger.error(f"❌ 处理 SSE 块时发生错误: {e}")
            logger.debug(f" 📦 错误块数据: {chunk_data}")

    def _process_thinking_phase(self, delta_content: str) -> Generator[str, None, None]:
        """Forward thinking text immediately as a content chunk (stream mode only)."""
        if not delta_content:
            return
        logger.debug(f"🤔 思考内容: +{len(delta_content)} 字符")
        if self.stream:
            chunk = self._create_content_chunk(delta_content)
            yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    def _process_tool_call_phase(self, edit_content: str) -> Generator[str, None, None]:
        """Handle a tool_call chunk: either new glm_block(s) or an argument fragment."""
        if not edit_content:
            return
        logger.debug(f"🔧 进入工具调用阶段,内容长度: {len(edit_content)}")

        if "<glm_block " in edit_content:
            yield from self._handle_glm_blocks(edit_content)
            return

        # No glm_block marker: this can only be a continuation of the
        # arguments of a call that is already open.
        if not self.has_tool_call:
            return

        # Everything before '", "result"' is argument text. The marker may sit
        # at offset 0 (arguments ended exactly at the previous chunk), so the
        # test is against -1, not against 0.
        result_pos = edit_content.find('", "result"')
        if result_pos != -1:
            param_fragment = edit_content[:result_pos]
            self.tool_args += param_fragment
            logger.debug(f"📦 累积参数片段: {param_fragment}")
        else:
            # No end marker: treat the whole chunk as a middle fragment.
            self.tool_args += edit_content
            logger.debug(f"📦 累积参数片段: {edit_content[:100]}...")

    def _handle_glm_blocks(self, edit_content: str) -> Generator[str, None, None]:
        """Process content containing one or more ``<glm_block `` markers.

        The text before the first marker is the tail of the open call's
        arguments (if a call is open). Each later piece that contains a
        closing ``</glm_block>`` starts a new tool call; an open call is
        finished and emitted first.
        """
        blocks = edit_content.split('<glm_block ')
        logger.debug(f"📦 分割得到 {len(blocks)} 个块")

        for index, block in enumerate(blocks):
            if not block.strip():
                continue

            if index == 0:
                if not self.has_tool_call:
                    # Nothing open to attach the leading text to.
                    continue
                logger.debug("📦 从第一个块提取参数片段")
                # Search only the leading piece: a "result" key inside a later
                # glm_block must not pull block markup into the arguments.
                result_pos = block.find('"result"')
                if result_pos > 0:
                    # Step back over the '", ' that precedes "result"; clamp
                    # so a marker at offset 1-2 cannot turn into a negative
                    # (from-the-end) slice.
                    param_fragment = block[:max(result_pos - 3, 0)]
                    self.tool_args += param_fragment
                    logger.debug(f"📦 累积参数片段: {param_fragment}")
                continue

            if "</glm_block>" not in block:
                continue

            if self.has_tool_call:
                # Restore the closing quote that was trimmed while accumulating.
                self.tool_args += '"'
                yield from self._finish_current_tool()

            yield from self._process_metadata_block(block)

    def _process_metadata_block(self, block: str) -> Generator[str, None, None]:
        """Start a new tool call from the JSON payload of one glm_block piece.

        Reads ``data.metadata`` (``id``, ``name``, ``arguments``) from the JSON
        between the first ``>`` and the last ``</glm_block>``. Yields nothing;
        it is a generator only so callers can ``yield from`` it.
        """
        start_pos = block.find('>')
        end_pos = block.rfind('</glm_block>')
        if start_pos == -1 or end_pos == -1:
            logger.warning(f"❌ 无法找到 JSON 内容边界: {block[:1000]}...")
            return

        json_content = block[start_pos + 1:end_pos]
        logger.debug(f"📦 提取的 JSON 内容: {json_content[:1000]}...")

        try:
            metadata_obj = json.loads(json_content)
            if "data" in metadata_obj and "metadata" in metadata_obj["data"]:
                metadata = metadata_obj["data"]["metadata"]

                self.tool_id = metadata.get("id", f"call_{int(time.time() * 1000000)}")
                self.tool_name = metadata.get("name", "unknown")
                self.has_tool_call = True

                arguments = metadata.get("arguments")
                if arguments is None:
                    self.tool_args = "{}"
                    logger.debug(f"🎯 新工具调用: {self.tool_name}(id={self.tool_id}), 空参数")
                else:
                    if isinstance(arguments, str):
                        # The string is the opening part of the argument JSON;
                        # drop a trailing quote so later fragments can be
                        # appended to the still-open string value.
                        self.tool_args = arguments[:-1] if arguments.endswith('"') else arguments
                    else:
                        # Already-decoded arguments: serialise instead of
                        # failing on a missing str method.
                        self.tool_args = json.dumps(arguments, ensure_ascii=False)
                    logger.debug(f"🎯 新工具调用: {self.tool_name}(id={self.tool_id}), 初始参数: {self.tool_args}")
        except (json.JSONDecodeError, KeyError, AttributeError, TypeError) as e:
            logger.error(f"❌ 解析工具元数据失败: {e}, 块内容: {block[:1000]}...")

        # Keeps this function a generator without emitting anything.
        yield from ()

    def _process_other_phase(self, usage: Dict[str, Any], edit_content: str = "") -> Generator[str, None, None]:
        """Record usage and, on the ``null,`` end marker, emit the open tool call."""
        if usage:
            self.tool_call_usage = usage
            logger.debug(f"📊 保存使用统计: {usage}")

        # An edit_content starting with "null," marks the end of the tool call.
        if self.has_tool_call and edit_content and edit_content.startswith("null,"):
            logger.info("🏁 检测到工具调用结束标记")
            yield from self._finish_current_tool()
            if self.stream:
                yield "data: [DONE]\n\n"
            self._reset_all_state()

    def _process_answer_phase(self, delta_content: str) -> Generator[str, None, None]:
        """Buffer answer text and flush on size, elapsed time, or sentence end."""
        if not delta_content:
            return
        logger.info(f"📝 工具处理器收到答案内容: {delta_content[:50]}...")

        self.content_buffer += delta_content
        self.buffer_size += len(delta_content)

        time_since_last_flush = time.time() - self.last_flush_time
        should_flush = (
            self.buffer_size >= self.max_buffer_size
            or time_since_last_flush >= self.flush_interval
            or '\n' in delta_content
            or '。' in delta_content or '!' in delta_content or '?' in delta_content
        )
        if should_flush and self.content_buffer:
            yield from self._flush_content_buffer()

    def _flush_content_buffer(self) -> Generator[str, None, None]:
        """Emit the buffered answer text as one content chunk and clear the buffer."""
        if not self.content_buffer:
            return
        logger.info(f"💬 工具处理器刷新缓冲区: {self.buffer_size} 字符 - {self.content_buffer[:50]}...")

        if self.stream:
            chunk = self._create_content_chunk(self.content_buffer)
            output_data = f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
            logger.info(f"➡️ 工具处理器输出: {output_data[:100]}...")
            yield output_data

        self.content_buffer = ""
        self.buffer_size = 0
        self.last_flush_time = time.time()

    def _process_done_phase(self, chunk_data: Dict[str, Any]) -> Generator[str, None, None]:
        """Flush pending output, emit the final ``stop`` chunk and ``[DONE]``, then reset."""
        logger.info("🏁 对话完成")

        if self.content_buffer:
            yield from self._flush_content_buffer()
        if self.has_tool_call:
            yield from self._finish_current_tool()

        if self.stream:
            final_chunk = self._build_chunk({}, "stop")
            if "usage" in chunk_data:
                final_chunk["usage"] = chunk_data["usage"]
            yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        self._reset_all_state()

    def _finish_current_tool(self) -> Generator[str, None, None]:
        """Emit the open tool call (start, arguments, finish chunks) and close it."""
        if not self.has_tool_call:
            return

        fixed_args = self._fix_tool_arguments(self.tool_args)
        logger.debug(f"✅ 完成工具调用: {self.tool_name}, 参数: {fixed_args}")

        if self.stream:
            start_chunk = self._create_tool_start_chunk()
            yield f"data: {json.dumps(start_chunk, ensure_ascii=False)}\n\n"
            args_chunk = self._create_tool_arguments_chunk(fixed_args)
            yield f"data: {json.dumps(args_chunk, ensure_ascii=False)}\n\n"
            finish_chunk = self._create_tool_finish_chunk()
            yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"

        self._reset_tool_state()
        # The next tool call of this response needs its own index; clients
        # merge tool-call deltas that share an index.
        self.content_index += 1

    def _fix_tool_arguments(self, raw_args: str) -> str:
        """Return the accumulated argument text as a valid JSON object string.

        Already-valid JSON is used as is. Otherwise the text is preprocessed,
        and if it still does not parse it is handed to the ``json_repair``
        library. Any failure, or a result that is not a JSON object, yields
        ``"{}"``.
        """
        if not raw_args or raw_args == "{}":
            return "{}"
        logger.debug(f"🔧 开始修复参数: {raw_args[:1000]}{'...' if len(raw_args) > 1000 else ''}")

        try:
            args_obj = self._parse_tool_arguments(raw_args.strip())
            if not isinstance(args_obj, dict):
                raise ValueError(f"tool arguments are not a JSON object: {type(args_obj).__name__}")
            args_obj = self._post_process_args(args_obj)
            return json.dumps(args_obj, ensure_ascii=False)
        except Exception as e:
            logger.error(f"❌ JSON 修复失败: {e}, 原始参数: {raw_args[:1000]}..., 使用空参数")
            return "{}"

    def _parse_tool_arguments(self, text: str) -> Any:
        """Parse *text* as JSON, escalating from as-is to preprocessed to json_repair."""
        # 1. Valid JSON must not go through the preprocessing regex, which
        #    would rewrite legitimate escaped quotes inside string values.
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass

        # 2. Cheap textual fixes that json_repair does not cover.
        processed_args = self._preprocess_json_string(text)
        try:
            return json.loads(processed_args)
        except json.JSONDecodeError:
            pass

        # 3. Full repair; imported lazily so valid input never needs the package.
        from json_repair import repair_json
        repaired_json = repair_json(processed_args)
        logger.debug(f"🔧 json-repair 修复结果: {repaired_json}")
        return json.loads(repaired_json)

    def _post_process_args(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
        """Apply the value-level fixes (path escaping, command quotes) to parsed arguments."""
        args_obj = self._fix_path_escaping_in_args(args_obj)
        args_obj = self._fix_command_quotes(args_obj)
        return args_obj

    def _preprocess_json_string(self, text: str) -> str:
        """Apply textual fixes to broken JSON before it is repaired.

        - Adds a missing leading ``{`` when the text ends with ``}``.
        - Turns ``\\"`` into ``"`` when it is followed by ``}``, ``,`` or
          whitespace, e.g. ``{"url":"https://www.bilibili.com\\"}`` becomes
          ``{"url":"https://www.bilibili.com"}``.
        """
        import re

        if text.endswith('}') and not text.startswith('{'):
            text = '{' + text
            logger.debug("🔧 补全开始括号")

        text, replaced = re.subn(r'([^\\])\\"([}\s,])', r'\1"\2', text)
        if replaced:
            logger.debug("🔧 修复末尾多余的反斜杠")
        return text

    def _fix_path_escaping_in_args(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
        """Collapse repeated backslashes in Windows drive paths of known path fields.

        Only ``file_path``, ``path``, ``directory`` and ``folder`` are
        inspected, and only values that start with a drive letter plus ``:``
        and contain a doubled backslash. The dict is modified in place and
        returned.
        """
        for field in ('file_path', 'path', 'directory', 'folder'):
            path_value = args_obj.get(field)
            if not isinstance(path_value, str):
                continue

            has_drive_prefix = (
                len(path_value) >= 2
                and path_value[0].isascii()
                and path_value[0].isalpha()
                and path_value[1] == ':'
            )
            if not (has_drive_prefix and '\\\\' in path_value):
                continue

            logger.debug(f"🔍 检查路径字段 {field}: {repr(path_value)}")
            # Splitting on a single backslash turns every surplus backslash
            # into an empty component; dropping those and re-joining leaves
            # exactly one separator between components.
            clean_parts = [part for part in path_value.split('\\') if part]
            fixed_path = '\\'.join(clean_parts) if len(clean_parts) > 1 else path_value
            logger.debug(f"🔍 修复后路径: {repr(fixed_path)}")

            if fixed_path != path_value:
                args_obj[field] = fixed_path
                logger.debug(f"🔧 修复字段 {field} 的路径转义: {path_value} -> {fixed_path}")
            else:
                logger.debug(f"🔍 路径无需修复: {path_value}")
        return args_obj

    def _fix_command_quotes(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
        """Drop one surplus trailing double quote from ``args_obj['command']``.

        A command ending in two double quotes loses the last one. The dict is
        modified in place and returned.
        """
        command = args_obj.get('command')
        if isinstance(command, str) and command.endswith('""'):
            logger.debug(f"🔧 发现命令末尾多余引号: {command}")
            fixed_command = command[:-1]
            args_obj['command'] = fixed_command
            logger.debug(f"🔧 修复命令引号: {command} -> {fixed_command}")
        return args_obj

    def _build_chunk(self, delta: Dict[str, Any], finish_reason: Any = None) -> Dict[str, Any]:
        """Return a ``chat.completion.chunk`` dict carrying *delta*.

        All chunks produced between two resets share one ``id`` and one
        ``created`` timestamp, fixed when the first chunk is built.
        """
        if self._completion_id is None:
            now = int(time.time())
            self._completion_id = f"chatcmpl-{now}"
            self._completion_created = now
        return {
            "id": self._completion_id,
            "object": "chat.completion.chunk",
            "created": self._completion_created,
            "model": self.model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason
            }]
        }

    def _create_content_chunk(self, content: str) -> Dict[str, Any]:
        """Build a chunk whose delta carries assistant text."""
        return self._build_chunk({
            "role": "assistant",
            "content": content
        })

    def _create_tool_start_chunk(self) -> Dict[str, Any]:
        """Build the chunk that opens the current tool call (id, name, empty arguments)."""
        return self._build_chunk({
            "role": "assistant",
            "tool_calls": [{
                "index": self.content_index,
                "id": self.tool_id,
                "type": "function",
                "function": {
                    "name": self.tool_name,
                    "arguments": ""
                }
            }]
        })

    def _create_tool_arguments_chunk(self, arguments: str) -> Dict[str, Any]:
        """Build the chunk that carries the complete argument JSON of the current tool call."""
        return self._build_chunk({
            "tool_calls": [{
                "index": self.content_index,
                "id": self.tool_id,
                "function": {
                    "arguments": arguments
                }
            }]
        })

    def _create_tool_finish_chunk(self) -> Dict[str, Any]:
        """Build the chunk with ``finish_reason='tool_calls'``, plus saved usage if any."""
        chunk = self._build_chunk({"tool_calls": []}, "tool_calls")
        if self.tool_call_usage:
            chunk["usage"] = self.tool_call_usage
        return chunk

    def _reset_tool_state(self):
        """Clear the per-tool-call fields; ``content_index`` is left alone."""
        self.tool_id = ""
        self.tool_name = ""
        self.tool_args = ""
        self.has_tool_call = False

    def _reset_all_state(self):
        """Return the handler to its initial state for the next response."""
        # Callers flush before resetting. Anything still buffered here cannot
        # be delivered (this method yields nothing), so say so instead of
        # pretending to flush it.
        if self.content_buffer:
            logger.warning(f"⚠️ 重置时丢弃未发送的缓冲内容: {self.buffer_size} 字符")

        self._reset_tool_state()
        self.current_phase = None
        self.tool_call_usage = {}

        self.content_buffer = ""
        self.buffer_size = 0
        self.last_flush_time = time.time()

        # Next response: tool indices restart at 0 and chunks get a new id.
        self.content_index = 0
        self._completion_id = None
        self._completion_created = None
        logger.debug("🔄 重置所有处理器状态")