ZyphrZero commited on
Commit
e985dd1
·
1 Parent(s): 9730d78

refactor(app/utils): 重构工具调用处理逻辑

Browse files

- 完善了工具调用结束的处理逻辑,包括参数解析和状态重置
- 调整了工具调用完成后的响应格式,使其更加规范

app/utils/sse_tool_handler.py CHANGED
@@ -35,12 +35,14 @@ class SSEToolHandler:
35
  """
36
  if not self.has_tool_call:
37
  self.has_tool_call = True
38
- logger.debug("进入工具调用阶段")
39
 
40
  edit_content = data.get("edit_content", "")
41
  if not edit_content:
42
  return
43
 
 
 
44
  # 分割glm_block块
45
  blocks = edit_content.split("<glm_block >")
46
 
@@ -48,28 +50,34 @@ class SSEToolHandler:
48
  if not block:
49
  continue
50
 
 
 
51
  if "</glm_block>" not in block:
52
  # 这个块不完整,可能是参数片段
53
  if index == 0:
54
  # 第一个块的参数片段
55
  self.tool_args += block
 
56
  continue
57
 
58
  if index == 0:
59
  # 第一个块:提取参数片段(到"result"之前)
 
60
  if '"result"' in edit_content:
61
- args_fragment = edit_content[: edit_content.index('"result"') - 3]
 
62
  self.tool_args += args_fragment
63
- logger.debug(f"从第一个块提取参数片段: {args_fragment}")
64
  else:
65
  # 后续块:新的工具调用
66
  # 如果当前有工具正在处理,先完成它
67
  if self.tool_id:
 
68
  yield from self._finish_current_tool(is_stream)
69
 
70
  # 解析新工具信息
71
  try:
72
- block_content = block[: block.index("</glm_block>")]
73
  content = json.loads(block_content)
74
  metadata = content.get("data", {}).get("metadata", {})
75
 
@@ -81,8 +89,8 @@ class SSEToolHandler:
81
  # 累积参数(去掉最后的}以便后续累积)
82
  self.tool_args = json.dumps(arguments, ensure_ascii=False)[:-1]
83
 
84
- logger.debug(f"新工具调用: {self.tool_name}(id={self.tool_id})")
85
- logger.debug(f"初始参数: {self.tool_args}")
86
 
87
  if is_stream:
88
  yield self._create_tool_start_chunk()
@@ -90,40 +98,32 @@ class SSEToolHandler:
90
  self.content_index += 1
91
 
92
  except (json.JSONDecodeError, KeyError) as e:
93
- logger.error(f"解析工具块失败: {e}")
 
94
 
95
  def _finish_current_tool(self, is_stream: bool) -> Generator[str, None, None]:
96
- """完成当前工具调用"""
97
  if not self.tool_id:
98
  return
99
 
100
  try:
101
- # 处理不同的参数状态
102
- if not self.tool_args or self.tool_args == "{":
103
- # 空参数或只有开始括号
104
- params = {}
105
- else:
106
- # 尝试补充结束符
107
- test_args = self.tool_args
108
-
109
- # 检查是否需要补充结束引号
110
- quote_count = test_args.count('"')
111
- if quote_count % 2 != 0:
112
- test_args += '"'
113
 
114
- # 检查是否需要补充结束括号
115
- if not test_args.endswith("}"):
116
- test_args += "}"
117
 
118
- params = json.loads(test_args)
 
119
 
120
- logger.debug(f"完成工具调用: {self.tool_name} with params: {params}")
121
 
122
  if is_stream:
123
  yield self._create_tool_arguments_chunk(params)
124
 
125
  except json.JSONDecodeError as e:
126
- logger.error(f"工具参数解析失败: {e}, 原始参数: {self.tool_args[:200]}")
 
 
 
127
  params = {}
128
  if is_stream:
129
  yield self._create_tool_arguments_chunk(params)
@@ -144,21 +144,81 @@ class SSEToolHandler:
144
  # 保存usage信息
145
  if self.has_tool_call and usage:
146
  self.tool_call_usage = usage
147
- logger.debug(f"保存工具调用usage: {usage}")
148
 
149
  # 检测工具调用结束标记 "null,"
150
  if self.has_tool_call and edit_content and edit_content.startswith("null,"):
151
- logger.debug("检测到工具调用结束标记: null,")
152
-
153
- # 完成最后一个工具调用
154
- if self.tool_id:
155
- yield from self._finish_current_tool(is_stream)
156
-
157
- # 发送结束信号
158
- if is_stream:
159
- logger.info(" 发送工具调用完成信号")
160
- yield self._create_tool_finish_chunk()
161
- yield "data: [DONE]\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  # 重置所有状态
164
  self._reset_all_state()
@@ -218,7 +278,7 @@ class SSEToolHandler:
218
  ],
219
  },
220
  "finish_reason": None,
221
- "index": 0,
222
  "logprobs": None,
223
  }
224
  ],
@@ -243,7 +303,7 @@ class SSEToolHandler:
243
  ],
244
  "created": int(time.time()),
245
  "id": self.chat_id,
246
- "usage": self.tool_call_usage,
247
  "model": self.model,
248
  "object": "chat.completion.chunk",
249
  "system_fingerprint": "fp_zai_001",
 
35
  """
36
  if not self.has_tool_call:
37
  self.has_tool_call = True
38
+ logger.debug("🔧 进入工具调用阶段")
39
 
40
  edit_content = data.get("edit_content", "")
41
  if not edit_content:
42
  return
43
 
44
+ logger.debug(f"📦 解析数据块: {edit_content}")
45
+
46
  # 分割glm_block块
47
  blocks = edit_content.split("<glm_block >")
48
 
 
50
  if not block:
51
  continue
52
 
53
+ logger.debug(f" 📦 处理块 {index}: {block[:200]}...")
54
+
55
  if "</glm_block>" not in block:
56
  # 这个块不完整,可能是参数片段
57
  if index == 0:
58
  # 第一个块的参数片段
59
  self.tool_args += block
60
+ logger.debug(f" 📦 累积参数片段: {block}")
61
  continue
62
 
63
  if index == 0:
64
  # 第一个块:提取参数片段(到"result"之前)
65
+ # 提取到 '"result"' 之前的内容
66
  if '"result"' in edit_content:
67
+ result_index = edit_content.index('"result"')
68
+ args_fragment = edit_content[:result_index - 3]
69
  self.tool_args += args_fragment
70
+ logger.debug(f"📦 从第一个块提取参数片段: {args_fragment}")
71
  else:
72
  # 后续块:新的工具调用
73
  # 如果当前有工具正在处理,先完成它
74
  if self.tool_id:
75
+ logger.debug(f" 🎯 完成当前工具: {self.tool_name}")
76
  yield from self._finish_current_tool(is_stream)
77
 
78
  # 解析新工具信息
79
  try:
80
+ block_content = block[:block.index("</glm_block>")]
81
  content = json.loads(block_content)
82
  metadata = content.get("data", {}).get("metadata", {})
83
 
 
89
  # 累积参数(去掉最后的}以便后续累积)
90
  self.tool_args = json.dumps(arguments, ensure_ascii=False)[:-1]
91
 
92
+ logger.debug(f"🎯 新工具调用: {self.tool_name}(id={self.tool_id})")
93
+ logger.debug(f" 📦 初始参数: {self.tool_args}")
94
 
95
  if is_stream:
96
  yield self._create_tool_start_chunk()
 
98
  self.content_index += 1
99
 
100
  except (json.JSONDecodeError, KeyError) as e:
101
+ logger.error(f"解析工具块失败: {e}")
102
+ logger.error(f" 📦 块内容: {block[:500]}")
103
 
104
  def _finish_current_tool(self, is_stream: bool) -> Generator[str, None, None]:
 
105
  if not self.tool_id:
106
  return
107
 
108
  try:
109
+ test_args = self.tool_args + '"'
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ logger.debug(f"✅ 工具参数解析成功: {self.tool_name}")
112
+ logger.debug(f" 📦 最终参数字符串: {test_args}")
 
113
 
114
+ # 解析参数
115
+ params = json.loads(test_args)
116
 
117
+ logger.debug(f"完成工具调用: {self.tool_name} with params: {params}")
118
 
119
  if is_stream:
120
  yield self._create_tool_arguments_chunk(params)
121
 
122
  except json.JSONDecodeError as e:
123
+ logger.error(f"工具参数解析失败: {e}")
124
+ logger.error(f" 📦 原始参数: {self.tool_args[:200]}")
125
+ logger.error(f" 📦 测试参数: {test_args[:200] if 'test_args' in locals() else 'N/A'}")
126
+ # 解析失败时使用空参数
127
  params = {}
128
  if is_stream:
129
  yield self._create_tool_arguments_chunk(params)
 
144
  # 保存usage信息
145
  if self.has_tool_call and usage:
146
  self.tool_call_usage = usage
147
+ logger.debug(f"💾 保存工具调用usage: {usage}")
148
 
149
  # 检测工具调用结束标记 "null,"
150
  if self.has_tool_call and edit_content and edit_content.startswith("null,"):
151
+ logger.debug("🏁 检测到工具调用结束标记: null,")
152
+
153
+ # 补充引号并完成最后一个工具调用
154
+ self.tool_args += '"'
155
+ self.has_tool_call = False
156
+
157
+ try:
158
+ # 解析最终参数
159
+ params = json.loads(self.tool_args)
160
+ logger.debug(f"✅ 最终工具参数解析成功: {params}")
161
+
162
+ if is_stream:
163
+ # 创建工具参数块
164
+ tool_call_delta = {
165
+ "id": self.tool_id,
166
+ "type": "function",
167
+ "function": {
168
+ "name": None,
169
+ "arguments": json.dumps(params, ensure_ascii=False),
170
+ },
171
+ }
172
+ delta_res = {
173
+ "choices": [
174
+ {
175
+ "delta": {
176
+ "role": "assistant",
177
+ "content": None,
178
+ "tool_calls": [tool_call_delta],
179
+ },
180
+ "finish_reason": None,
181
+ "index": 0,
182
+ "logprobs": None,
183
+ }
184
+ ],
185
+ "created": int(time.time()),
186
+ "id": self.chat_id,
187
+ "model": self.model,
188
+ "object": "chat.completion.chunk",
189
+ "system_fingerprint": "fp_zai_001",
190
+ }
191
+ yield f"data: {json.dumps(delta_res, ensure_ascii=False)}\n\n"
192
+
193
+ # 发送工具完成信号
194
+ finish_res = {
195
+ "choices": [
196
+ {
197
+ "delta": {
198
+ "role": "assistant",
199
+ "content": None,
200
+ "tool_calls": [],
201
+ },
202
+ "finish_reason": "tool_calls",
203
+ "index": 0,
204
+ "logprobs": None,
205
+ }
206
+ ],
207
+ "created": int(time.time()),
208
+ "id": self.chat_id,
209
+ "usage": self.tool_call_usage or None,
210
+ "model": self.model,
211
+ "object": "chat.completion.chunk",
212
+ "system_fingerprint": "fp_zai_001",
213
+ }
214
+
215
+ logger.info("🏁 发送工具调用完成信号")
216
+ yield f"data: {json.dumps(finish_res, ensure_ascii=False)}\n\n"
217
+ yield "data: [DONE]\n\n"
218
+
219
+ except json.JSONDecodeError as e:
220
+ logger.error(f"❌ 最终参数解析失败: {e}")
221
+ logger.error(f" 📦 参数内容: {self.tool_args}")
222
 
223
  # 重置所有状态
224
  self._reset_all_state()
 
278
  ],
279
  },
280
  "finish_reason": None,
281
+ "index": self.content_index, # 使用正确的索引
282
  "logprobs": None,
283
  }
284
  ],
 
303
  ],
304
  "created": int(time.time()),
305
  "id": self.chat_id,
306
+ "usage": self.tool_call_usage or None,
307
  "model": self.model,
308
  "object": "chat.completion.chunk",
309
  "system_fingerprint": "fp_zai_001",
tests/test_tool_call_fix.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 测试修复后的工具调用功能
6
+ """
7
+
8
+ import json
9
+ import asyncio
10
+ import httpx
11
+ from typing import Dict, Any
12
+
13
+ # 测试配置
14
+ TEST_URL = "http://localhost:8080/v1/chat/completions"
15
+ TEST_AUTH_TOKEN = "sk-test-key"
16
+
17
+ # 测试工具定义
18
+ TEST_TOOLS = [
19
+ {
20
+ "type": "function",
21
+ "function": {
22
+ "name": "get_weather",
23
+ "description": "获取指定城市的天气信息",
24
+ "parameters": {
25
+ "type": "object",
26
+ "properties": {
27
+ "city": {
28
+ "type": "string",
29
+ "description": "城市名称"
30
+ },
31
+ "unit": {
32
+ "type": "string",
33
+ "enum": ["celsius", "fahrenheit"],
34
+ "description": "温度单位"
35
+ }
36
+ },
37
+ "required": ["city"]
38
+ }
39
+ }
40
+ }
41
+ ]
42
+
43
+ async def test_tool_call_streaming():
44
+ """测试流式工具调用"""
45
+ print("🧪 开始测试流式工具调用...")
46
+
47
+ payload = {
48
+ "model": "glm-4.5",
49
+ "messages": [
50
+ {
51
+ "role": "user",
52
+ "content": "请帮我查询北京的天气,使用摄氏度"
53
+ }
54
+ ],
55
+ "tools": TEST_TOOLS,
56
+ "stream": True,
57
+ "temperature": 0.7
58
+ }
59
+
60
+ headers = {
61
+ "Content-Type": "application/json",
62
+ "Authorization": f"Bearer {TEST_AUTH_TOKEN}"
63
+ }
64
+
65
+ try:
66
+ async with httpx.AsyncClient(timeout=30.0) as client:
67
+ async with client.stream(
68
+ "POST",
69
+ TEST_URL,
70
+ json=payload,
71
+ headers=headers
72
+ ) as response:
73
+ print(f"📡 响应状态: {response.status_code}")
74
+ print(f"📡 响应头: {dict(response.headers)}")
75
+
76
+ if response.status_code != 200:
77
+ error_text = await response.aread()
78
+ print(f"❌ 请求失败: {error_text.decode()}")
79
+ return
80
+
81
+ print("\n📦 开始接收流式数据:")
82
+ print("-" * 80)
83
+
84
+ chunk_count = 0
85
+ tool_calls_found = False
86
+
87
+ async for line in response.aiter_lines():
88
+ if not line:
89
+ continue
90
+
91
+ if line.startswith("data: "):
92
+ chunk_count += 1
93
+ data_str = line[6:].strip()
94
+
95
+ if data_str == "[DONE]":
96
+ print(f"🏁 [{chunk_count:03d}] 流结束: [DONE]")
97
+ break
98
+
99
+ try:
100
+ chunk = json.loads(data_str)
101
+
102
+ # 检查是否包含工具调用
103
+ choices = chunk.get("choices", [])
104
+ if choices:
105
+ choice = choices[0]
106
+ delta = choice.get("delta", {})
107
+ tool_calls = delta.get("tool_calls", [])
108
+
109
+ if tool_calls:
110
+ tool_calls_found = True
111
+ print(f"🔧 [{chunk_count:03d}] 工具调用块:")
112
+ for tool_call in tool_calls:
113
+ print(f" ID: {tool_call.get('id', 'N/A')}")
114
+ print(f" 类型: {tool_call.get('type', 'N/A')}")
115
+ function = tool_call.get('function', {})
116
+ print(f" 函数名: {function.get('name', 'N/A')}")
117
+ print(f" 参数: {function.get('arguments', 'N/A')}")
118
+ print(f" 参数类型: {type(function.get('arguments', 'N/A'))}")
119
+
120
+ finish_reason = choice.get("finish_reason")
121
+ if finish_reason:
122
+ print(f"🏁 [{chunk_count:03d}] 完成原因: {finish_reason}")
123
+
124
+ # 显示其他内容
125
+ content = delta.get("content")
126
+ if content:
127
+ print(f"💬 [{chunk_count:03d}] 内容: {content}")
128
+
129
+ # 显示usage信息
130
+ usage = chunk.get("usage")
131
+ if usage:
132
+ print(f"📊 [{chunk_count:03d}] 使用统计: {usage}")
133
+
134
+ except json.JSONDecodeError as e:
135
+ print(f"❌ [{chunk_count:03d}] JSON解析错误: {e}")
136
+ print(f" 原始数据: {data_str[:200]}...")
137
+
138
+ print("-" * 80)
139
+ print(f"✅ 测试完成,共处理 {chunk_count} 个数据块")
140
+ print(f"🔧 工具调用检测: {'成功' if tool_calls_found else '失败'}")
141
+
142
+ except Exception as e:
143
+ print(f"❌ 测试异常: {e}")
144
+ import traceback
145
+ traceback.print_exc()
146
+
147
+ async def test_tool_call_non_streaming():
148
+ """测试非流式工具调用"""
149
+ print("\n🧪 开始测试非流式工具调用...")
150
+
151
+ payload = {
152
+ "model": "glm-4.5",
153
+ "messages": [
154
+ {
155
+ "role": "user",
156
+ "content": "请帮我查询上海的天气"
157
+ }
158
+ ],
159
+ "tools": TEST_TOOLS,
160
+ "stream": False,
161
+ "temperature": 0.7
162
+ }
163
+
164
+ headers = {
165
+ "Content-Type": "application/json",
166
+ "Authorization": f"Bearer {TEST_AUTH_TOKEN}"
167
+ }
168
+
169
+ try:
170
+ async with httpx.AsyncClient(timeout=30.0) as client:
171
+ response = await client.post(TEST_URL, json=payload, headers=headers)
172
+
173
+ print(f"📡 响应状态: {response.status_code}")
174
+
175
+ if response.status_code == 200:
176
+ result = response.json()
177
+ print("📦 响应结果:")
178
+ print(json.dumps(result, indent=2, ensure_ascii=False))
179
+
180
+ # 检查工具调用
181
+ choices = result.get("choices", [])
182
+ if choices:
183
+ message = choices[0].get("message", {})
184
+ tool_calls = message.get("tool_calls", [])
185
+ if tool_calls:
186
+ print(f"🔧 发现 {len(tool_calls)} 个工具调用")
187
+ for i, tool_call in enumerate(tool_calls):
188
+ print(f" 工具 {i+1}: {tool_call}")
189
+ else:
190
+ print("❌ 未发现工具调用")
191
+ else:
192
+ print(f"❌ 请求失败: {response.text}")
193
+
194
+ except Exception as e:
195
+ print(f"❌ 测试异常: {e}")
196
+
197
+ async def main():
198
+ """主测试函数"""
199
+ print("🚀 开始工具调用修复验证测试")
200
+ print("=" * 80)
201
+
202
+ # 测试流式工具调用
203
+ await test_tool_call_streaming()
204
+
205
+ # 等待一下
206
+ await asyncio.sleep(2)
207
+
208
+ # 测试非流式工具调用
209
+ await test_tool_call_non_streaming()
210
+
211
+ print("\n" + "=" * 80)
212
+ print("🎯 测试完成")
213
+
214
+ if __name__ == "__main__":
215
+ asyncio.run(main())