ZyphrZero commited on
Commit
8d0dafa
·
1 Parent(s): a11ee1d

♻️ refactor(utils): 改进工具调用提取和移除的逻辑

Browse files

- 移除了 TOOL_CALL_INLINE_PATTERN 正则表达式,因为它会导致过度匹配
- 实现了基于括号平衡的方法来提取工具调用,提高了准确性
- 改进了 remove_tool_json_content 函数,使用更精确的括号匹配算法
- 添加了多个测试用例来验证工具调用移除功能的正确性
- 优化了处理复杂嵌套 JSON 结构的能力

app/utils/tools.py CHANGED
@@ -156,7 +156,8 @@ def process_messages_with_tools(
156
 
157
  # Tool Extraction Patterns
158
  TOOL_CALL_FENCE_PATTERN = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)
159
- TOOL_CALL_INLINE_PATTERN = re.compile(r"(\{[^{}]{0,10000}\"tool_calls\".*?\})", re.DOTALL)
 
160
  FUNCTION_CALL_PATTERN = re.compile(r"调用函数\s*[::]\s*([\w\-\.]+)\s*(?:参数|arguments)[::]\s*(\{.*?\})", re.DOTALL)
161
 
162
 
@@ -189,27 +190,55 @@ def extract_tool_invocations(text: str) -> Optional[List[Dict[str, Any]]]:
189
  except (json.JSONDecodeError, AttributeError):
190
  continue
191
 
192
- # Attempt 2: Extract inline JSON objects
193
- inline_match = TOOL_CALL_INLINE_PATTERN.search(scannable_text)
194
- if inline_match:
195
- try:
196
- inline_json = inline_match.group(1)
197
- parsed_data = json.loads(inline_json)
198
- tool_calls = parsed_data.get("tool_calls")
199
- if tool_calls and isinstance(tool_calls, list):
200
- # Ensure arguments field is a string
201
- for tc in tool_calls:
202
- if "function" in tc:
203
- func = tc["function"]
204
- if "arguments" in func:
205
- if isinstance(func["arguments"], dict):
206
- # Convert dict to JSON string
207
- func["arguments"] = json.dumps(func["arguments"], ensure_ascii=False)
208
- elif not isinstance(func["arguments"], str):
209
- func["arguments"] = json.dumps(func["arguments"], ensure_ascii=False)
210
- return tool_calls
211
- except (json.JSONDecodeError, AttributeError):
212
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
  # Attempt 3: Parse natural language function calls
215
  natural_lang_match = FUNCTION_CALL_PATTERN.search(scannable_text)
@@ -233,8 +262,8 @@ def extract_tool_invocations(text: str) -> Optional[List[Dict[str, Any]]]:
233
 
234
 
235
  def remove_tool_json_content(text: str) -> str:
236
- """Remove tool JSON content from response text"""
237
-
238
  def remove_tool_call_block(match: re.Match) -> str:
239
  json_content = match.group(1)
240
  try:
@@ -244,9 +273,53 @@ def remove_tool_json_content(text: str) -> str:
244
  except (json.JSONDecodeError, AttributeError):
245
  pass
246
  return match.group(0)
247
-
248
- # Remove fenced tool JSON blocks
249
  cleaned_text = TOOL_CALL_FENCE_PATTERN.sub(remove_tool_call_block, text)
250
- # Remove inline tool JSON
251
- cleaned_text = TOOL_CALL_INLINE_PATTERN.sub("", cleaned_text)
252
- return cleaned_text.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  # Tool Extraction Patterns
158
  TOOL_CALL_FENCE_PATTERN = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)
159
+ # 注意:TOOL_CALL_INLINE_PATTERN 已被移除,因为它会导致过度匹配
160
+ # 现在在 remove_tool_json_content 函数中使用基于括号平衡的方法
161
  FUNCTION_CALL_PATTERN = re.compile(r"调用函数\s*[::]\s*([\w\-\.]+)\s*(?:参数|arguments)[::]\s*(\{.*?\})", re.DOTALL)
162
 
163
 
 
190
  except (json.JSONDecodeError, AttributeError):
191
  continue
192
 
193
+ # Attempt 2: Extract inline JSON objects using bracket balance method
194
+ # 查找包含 "tool_calls" 的 JSON 对象
195
+ i = 0
196
+ while i < len(scannable_text):
197
+ if scannable_text[i] == '{':
198
+ # 尝试找到匹配的右括号
199
+ brace_count = 1
200
+ j = i + 1
201
+ in_string = False
202
+ escape_next = False
203
+
204
+ while j < len(scannable_text) and brace_count > 0:
205
+ if escape_next:
206
+ escape_next = False
207
+ elif scannable_text[j] == '\\':
208
+ escape_next = True
209
+ elif scannable_text[j] == '"' and not escape_next:
210
+ in_string = not in_string
211
+ elif not in_string:
212
+ if scannable_text[j] == '{':
213
+ brace_count += 1
214
+ elif scannable_text[j] == '}':
215
+ brace_count -= 1
216
+ j += 1
217
+
218
+ if brace_count == 0:
219
+ # 找到了完整的 JSON 对象
220
+ json_str = scannable_text[i:j]
221
+ try:
222
+ parsed_data = json.loads(json_str)
223
+ tool_calls = parsed_data.get("tool_calls")
224
+ if tool_calls and isinstance(tool_calls, list):
225
+ # Ensure arguments field is a string
226
+ for tc in tool_calls:
227
+ if "function" in tc:
228
+ func = tc["function"]
229
+ if "arguments" in func:
230
+ if isinstance(func["arguments"], dict):
231
+ # Convert dict to JSON string
232
+ func["arguments"] = json.dumps(func["arguments"], ensure_ascii=False)
233
+ elif not isinstance(func["arguments"], str):
234
+ func["arguments"] = json.dumps(func["arguments"], ensure_ascii=False)
235
+ return tool_calls
236
+ except (json.JSONDecodeError, AttributeError):
237
+ pass
238
+
239
+ i += 1
240
+ else:
241
+ i += 1
242
 
243
  # Attempt 3: Parse natural language function calls
244
  natural_lang_match = FUNCTION_CALL_PATTERN.search(scannable_text)
 
262
 
263
 
264
  def remove_tool_json_content(text: str) -> str:
265
+ """Remove tool JSON content from response text - using bracket balance method"""
266
+
267
  def remove_tool_call_block(match: re.Match) -> str:
268
  json_content = match.group(1)
269
  try:
 
273
  except (json.JSONDecodeError, AttributeError):
274
  pass
275
  return match.group(0)
276
+
277
+ # Step 1: Remove fenced tool JSON blocks
278
  cleaned_text = TOOL_CALL_FENCE_PATTERN.sub(remove_tool_call_block, text)
279
+
280
+ # Step 2: Remove inline tool JSON - 使用基于括号平衡的智能方法
281
+ # 查找所有可能的 JSON 对象并精确删除包含 tool_calls 的对象
282
+ result = []
283
+ i = 0
284
+ while i < len(cleaned_text):
285
+ if cleaned_text[i] == '{':
286
+ # 尝试找到匹配的右括号
287
+ brace_count = 1
288
+ j = i + 1
289
+ in_string = False
290
+ escape_next = False
291
+
292
+ while j < len(cleaned_text) and brace_count > 0:
293
+ if escape_next:
294
+ escape_next = False
295
+ elif cleaned_text[j] == '\\':
296
+ escape_next = True
297
+ elif cleaned_text[j] == '"' and not escape_next:
298
+ in_string = not in_string
299
+ elif not in_string:
300
+ if cleaned_text[j] == '{':
301
+ brace_count += 1
302
+ elif cleaned_text[j] == '}':
303
+ brace_count -= 1
304
+ j += 1
305
+
306
+ if brace_count == 0:
307
+ # 找到了完整的 JSON 对象
308
+ json_str = cleaned_text[i:j]
309
+ try:
310
+ parsed = json.loads(json_str)
311
+ if "tool_calls" in parsed:
312
+ # 这是一个工具调用,跳过它
313
+ i = j
314
+ continue
315
+ except:
316
+ pass
317
+
318
+ # 不是工具调用或无法解析,保留这个字符
319
+ result.append(cleaned_text[i])
320
+ i += 1
321
+ else:
322
+ result.append(cleaned_text[i])
323
+ i += 1
324
+
325
+ return ''.join(result).strip()
tests/test_final_verification.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """验证 tools.py 修复后的功能"""
2
+
3
+ import sys
4
+ sys.path.append('E:\\GitHub\\z.ai2api_python')
5
+
6
+ from app.utils.tools import remove_tool_json_content
7
+
8
+ def test_remove_tool_json():
9
+ print("=" * 60)
10
+ print("验证 tools.py 中的 remove_tool_json_content 函数")
11
+ print("=" * 60)
12
+
13
+ # 测试案例 1: 纯工具调用 JSON(应该被完全移除)
14
+ test1 = '{"tool_calls": [{"id": "call_1", "type": "function"}]}'
15
+ result1 = remove_tool_json_content(test1)
16
+ print(f"\n测试1 - 纯工具调用:")
17
+ print(f"输入: {test1}")
18
+ print(f"输出: '{result1}'")
19
+ print("[PASS] 通过" if result1 == "" else "[FAIL] 失败")
20
+
21
+ # 测试案例 2: 混合内容
22
+ test2 = '''这是开始文本
23
+ {"tool_calls": [{"id": "call_2", "type": "function"}]}
24
+ 这是结束文本'''
25
+ result2 = remove_tool_json_content(test2)
26
+ print(f"\n测试2 - 混合内容:")
27
+ print(f"输入: {repr(test2)}")
28
+ print(f"输出: {repr(result2)}")
29
+ expected2 = "这是开始文本\n\n这是结束文本"
30
+ print("[PASS] 通过" if result2 == expected2 else "[FAIL] 失败")
31
+
32
+ # 测试案例 3: 普通 JSON(不应被删除)
33
+ test3 = '{"data": {"result": "success"}}'
34
+ result3 = remove_tool_json_content(test3)
35
+ print(f"\n测试3 - 普通JSON:")
36
+ print(f"输入: {test3}")
37
+ print(f"输出: '{result3}'")
38
+ print("[PASS] 通过" if result3 == test3 else "[FAIL] 失败")
39
+
40
+ # 测试案例 4: 代码块中的工具调用
41
+ test4 = '''正常文本
42
+ ```json
43
+ {"tool_calls": [{"id": "call_3"}]}
44
+ ```
45
+ 保留文本'''
46
+ result4 = remove_tool_json_content(test4)
47
+ print(f"\n测试4 - 代码块中的工具调用:")
48
+ print(f"输入: {repr(test4)}")
49
+ print(f"输出: {repr(result4)}")
50
+ print("[PASS] 通过" if "保留文本" in result4 and "tool_calls" not in result4 else "[FAIL] 失败")
51
+
52
+ if __name__ == "__main__":
53
+ test_remove_tool_json()
54
+ print("\n" + "=" * 60)
55
+ print("所有测试完成!正则表达式问题已成功修复。")
56
+ print("=" * 60)
tests/test_re.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """测试和修复正则表达式问题"""
2
+
3
+ import json
4
+ import re
5
+
6
+ # 原始的正则表达式(来自 tools.py)
7
+ TOOL_CALL_FENCE_PATTERN = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)
8
+ TOOL_CALL_INLINE_PATTERN_OLD = re.compile(r"(\{[^{}]{0,10000}\"tool_calls\".*?\})", re.DOTALL)
9
+
10
+ # 改进的正则表达式
11
+ # 方案1:更精确的匹配 - 只匹配包含 tool_calls 的完整 JSON 对象
12
+ TOOL_CALL_INLINE_PATTERN_NEW = re.compile(
13
+ r'\{(?:[^{}]|\{[^{}]*\})*"tool_calls"\s*:\s*\[[^\]]*\](?:[^{}]|\{[^{}]*\})*\}',
14
+ re.MULTILINE
15
+ )
16
+
17
+ def remove_tool_json_content_old(text: str) -> str:
18
+ """原始的移除工具JSON内容函数"""
19
+
20
+ def remove_tool_call_block(match: re.Match) -> str:
21
+ json_content = match.group(1)
22
+ try:
23
+ parsed_data = json.loads(json_content)
24
+ if "tool_calls" in parsed_data:
25
+ return ""
26
+ except (json.JSONDecodeError, AttributeError):
27
+ pass
28
+ return match.group(0)
29
+
30
+ # Remove fenced tool JSON blocks
31
+ cleaned_text = TOOL_CALL_FENCE_PATTERN.sub(remove_tool_call_block, text)
32
+ # Remove inline tool JSON
33
+ cleaned_text = TOOL_CALL_INLINE_PATTERN_OLD.sub("", cleaned_text)
34
+ return cleaned_text.strip()
35
+
36
+ def remove_tool_json_content_new(text: str) -> str:
37
+ """改进的移除工具JSON内容函数 - 使用基于括号平衡的方法"""
38
+
39
+ def remove_tool_call_block(match: re.Match) -> str:
40
+ json_content = match.group(1)
41
+ try:
42
+ parsed_data = json.loads(json_content)
43
+ if "tool_calls" in parsed_data:
44
+ return ""
45
+ except (json.JSONDecodeError, AttributeError):
46
+ pass
47
+ return match.group(0)
48
+
49
+ # Step 1: Remove fenced tool JSON blocks
50
+ cleaned_text = TOOL_CALL_FENCE_PATTERN.sub(remove_tool_call_block, text)
51
+
52
+ # Step 2: Remove inline tool JSON - 使用更智能的方法
53
+ # 查找所有可能的 JSON 对象
54
+ result = []
55
+ i = 0
56
+ while i < len(cleaned_text):
57
+ if cleaned_text[i] == '{':
58
+ # 尝试找到匹配的右括号
59
+ brace_count = 1
60
+ j = i + 1
61
+ in_string = False
62
+ escape_next = False
63
+
64
+ while j < len(cleaned_text) and brace_count > 0:
65
+ if escape_next:
66
+ escape_next = False
67
+ elif cleaned_text[j] == '\\':
68
+ escape_next = True
69
+ elif cleaned_text[j] == '"' and not escape_next:
70
+ in_string = not in_string
71
+ elif not in_string:
72
+ if cleaned_text[j] == '{':
73
+ brace_count += 1
74
+ elif cleaned_text[j] == '}':
75
+ brace_count -= 1
76
+ j += 1
77
+
78
+ if brace_count == 0:
79
+ # 找到了完整的 JSON 对象
80
+ json_str = cleaned_text[i:j]
81
+ try:
82
+ parsed = json.loads(json_str)
83
+ if "tool_calls" in parsed:
84
+ # 这是一个工具调用,跳过它
85
+ i = j
86
+ continue
87
+ except:
88
+ pass
89
+
90
+ # 不是工具调用或无法解析,保留这个字符
91
+ result.append(cleaned_text[i])
92
+ i += 1
93
+ else:
94
+ result.append(cleaned_text[i])
95
+ i += 1
96
+
97
+ return ''.join(result).strip()
98
+
99
+ # 测试用例
100
+ test_cases = [
101
+ # 测试案例 1: 只有工具调用JSON,应该被完全删除
102
+ {
103
+ "name": "纯工具调用JSON",
104
+ "input": """{"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "test", "arguments": "{}"}}]}""",
105
+ "expected": ""
106
+ },
107
+
108
+ # 测试案例 2: 包含工具调用的 JSON 代码块
109
+ {
110
+ "name": "代码块中的工具调用",
111
+ "input": """这是一些正常的文本内容。
112
+
113
+ ```json
114
+ {
115
+ "tool_calls": [
116
+ {
117
+ "id": "call_123",
118
+ "type": "function",
119
+ "function": {
120
+ "name": "test_function",
121
+ "arguments": "{\\"param\\": \\"value\\"}"
122
+ }
123
+ }
124
+ ]
125
+ }
126
+ ```
127
+
128
+ 这部分内容应该被保留。""",
129
+ "expected": """这是一些正常的文本内容。
130
+
131
+
132
+
133
+ 这部分内容应该被保留。"""
134
+ },
135
+
136
+ # 测试案例 3: 混合内容
137
+ {
138
+ "name": "混合内容",
139
+ "input": """让我为您执行一个函数调用:
140
+
141
+ {"tool_calls": [{"id": "call_789", "type": "function", "function": {"name": "search", "arguments": "{\\"query\\": \\"test\\"}"}}]}
142
+
143
+ 函数执行结果如下:
144
+ - 找到了相关内容
145
+ - 处理完成
146
+
147
+ 这里还有其他重要信息需要保留。""",
148
+ "expected": """让我为您执行一个函数调用:
149
+
150
+
151
+
152
+ 函数执行结果如下:
153
+ - 找到了相关内容
154
+ - 处理完成
155
+
156
+ 这里还有其他重要信息需要保留。"""
157
+ },
158
+
159
+ # 测试案例 4: 不应该被删除的普通 JSON
160
+ {
161
+ "name": "普通JSON(应保留)",
162
+ "input": """这是一个普通的 JSON 示例:
163
+ {"data": {"result": "success"}}
164
+
165
+ 这不是工具调用,应该保留。""",
166
+ "expected": """这是一个普通的 JSON 示例:
167
+ {"data": {"result": "success"}}
168
+
169
+ 这不是工具调用,应该保留。"""
170
+ },
171
+
172
+ # 测试案例 5: 嵌套的复杂JSON
173
+ {
174
+ "name": "嵌套复杂JSON",
175
+ "input": """开始文本
176
+ {"tool_calls": [{"id": "call_1", "function": {"name": "test", "arguments": "{\\"nested\\": {\\"deep\\": \\"value\\"}}"}}]}
177
+ 中间文本
178
+ {"normal": {"data": "keep this"}}
179
+ 结束文本""",
180
+ "expected": """开始文本
181
+
182
+ 中间文本
183
+ {"normal": {"data": "keep this"}}
184
+ 结束文本"""
185
+ }
186
+ ]
187
+
188
+ def run_tests():
189
+ print("=" * 80)
190
+ print("测试正则表达式处理")
191
+ print("=" * 80)
192
+
193
+ passed = 0
194
+ failed = 0
195
+
196
+ for test_case in test_cases:
197
+ print(f"\n测试案例: {test_case['name']}")
198
+ print("-" * 40)
199
+ print("输入文本:")
200
+ print(repr(test_case['input']))
201
+
202
+ print("\n使用原始函数处理后:")
203
+ result_old = remove_tool_json_content_old(test_case['input'])
204
+ print(repr(result_old))
205
+
206
+ print("\n使用改进函数处理后:")
207
+ result_new = remove_tool_json_content_new(test_case['input'])
208
+ print(repr(result_new))
209
+
210
+ print("\n期望结果:")
211
+ print(repr(test_case['expected']))
212
+
213
+ # 检查新函数是否正确
214
+ if result_new == test_case['expected']:
215
+ print("[PASS] 新函数通过测试")
216
+ passed += 1
217
+ else:
218
+ print("[FAIL] 新函数测试失败")
219
+ failed += 1
220
+
221
+ print("-" * 40)
222
+
223
+ print(f"\n\n总结: {passed} 个通过, {failed} 个失败")
224
+
225
+ if __name__ == "__main__":
226
+ run_tests()