ss22345 commited on
Commit
9506c04
·
1 Parent(s): ab5972f

fix: use brace-counting to extract tool_call JSON from nested structures

Browse files

Replace regex-based <tool_call> extraction with a brace-counting parser
that correctly handles nested JSON objects and missing/malformed close tags.
Also supports FuncName{args} shorthand format.

Files changed (1) hide show
  1. internal/filter/prompttool.go +116 -17
internal/filter/prompttool.go CHANGED
@@ -11,46 +11,38 @@ import (
11
  "zai-proxy/internal/model"
12
  )
13
 
14
- // promptToolCallPattern 匹配 <tool_call>...</tool_call> 块
15
- var promptToolCallPattern = regexp.MustCompile(`<tool_call>\s*([\s\S]*?)\s*</tool_call>`)
16
-
17
  // altToolCallPattern 匹配 [TOOL]...[/TOOL] 和 [TOOL_CALL]...[/TOOL_CALL] 格式
18
  var altToolCallPattern = regexp.MustCompile(`\[TOOL(?:_CALL)?\]\s*([\s\S]*?)\s*\[/TOOL(?:_CALL)?\]`)
19
 
20
  // jsonBlockPattern 匹配 markdown JSON 代码块中的 tool call
21
  var jsonBlockPattern = regexp.MustCompile("```json\\s*\\n(\\{[\\s\\S]*?\"name\"[\\s\\S]*?\\})\\s*\\n```")
22
 
23
- // allToolCallPatterns 按优先级排列的所有 tool call 模式
24
- var allToolCallPatterns = []*regexp.Regexp{
25
- promptToolCallPattern, // <tool_call> 最高优先级
26
- altToolCallPattern, // [TOOL] / [TOOL_CALL]
27
- jsonBlockPattern, // ```json ... ```
28
- }
29
-
30
  // ExtractPromptToolCalls 从文本中提取所有 tool call 块(支持多种格式),
31
  // 返回清理后的文本和解析出的 tool calls。
32
  func ExtractPromptToolCalls(content string) (cleanContent string, toolCalls []model.ToolCall) {
33
  var allCalls []model.ToolCall
34
  cleaned := content
35
 
36
- // 按优级依次尝试各种格式
37
- for _, pattern := range allToolCallPatterns {
 
 
 
 
 
 
38
  matches := pattern.FindAllStringSubmatchIndex(cleaned, -1)
39
  if len(matches) == 0 {
40
  continue
41
  }
42
-
43
- // 从后向前移除匹配块,避免索引偏移
44
  for i := len(matches) - 1; i >= 0; i-- {
45
  match := matches[i]
46
  fullStart, fullEnd := match[0], match[1]
47
  groupStart, groupEnd := match[2], match[3]
48
-
49
  jsonStr := cleaned[groupStart:groupEnd]
50
  if calls := parsePromptToolCallJSON(jsonStr); len(calls) > 0 {
51
- allCalls = append(calls, allCalls...)
52
  }
53
-
54
  cleaned = cleaned[:fullStart] + cleaned[fullEnd:]
55
  }
56
  }
@@ -77,6 +69,113 @@ func ExtractPromptToolCalls(content string) (cleanContent string, toolCalls []mo
77
  return cleaned, allCalls
78
  }
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  // parsePromptToolCallJSON 解析 <tool_call> 内的 JSON
81
  func parsePromptToolCallJSON(content string) []model.ToolCall {
82
  content = strings.TrimSpace(content)
 
11
  "zai-proxy/internal/model"
12
  )
13
 
 
 
 
14
  // altToolCallPattern 匹配 [TOOL]...[/TOOL] 和 [TOOL_CALL]...[/TOOL_CALL] 格式
15
  var altToolCallPattern = regexp.MustCompile(`\[TOOL(?:_CALL)?\]\s*([\s\S]*?)\s*\[/TOOL(?:_CALL)?\]`)
16
 
17
  // jsonBlockPattern 匹配 markdown JSON 代码块中的 tool call
18
  var jsonBlockPattern = regexp.MustCompile("```json\\s*\\n(\\{[\\s\\S]*?\"name\"[\\s\\S]*?\\})\\s*\\n```")
19
 
 
 
 
 
 
 
 
20
  // ExtractPromptToolCalls 从文本中提取所有 tool call 块(支持多种格式),
21
  // 返回清理后的文本和解析出的 tool calls。
22
  func ExtractPromptToolCalls(content string) (cleanContent string, toolCalls []model.ToolCall) {
23
  var allCalls []model.ToolCall
24
  cleaned := content
25
 
26
+ // 先尝试 <tool_call> 格式(使用 brace-counting 处理嵌套 JSON)
27
+ if result, calls := extractToolCallTags(cleaned); len(calls) > 0 {
28
+ allCalls = append(allCalls, calls...)
29
+ cleaned = result
30
+ }
31
+
32
+ // 然后尝试 [TOOL]/[TOOL_CALL] 和 markdown JSON 格式
33
+ for _, pattern := range []*regexp.Regexp{altToolCallPattern, jsonBlockPattern} {
34
  matches := pattern.FindAllStringSubmatchIndex(cleaned, -1)
35
  if len(matches) == 0 {
36
  continue
37
  }
 
 
38
  for i := len(matches) - 1; i >= 0; i-- {
39
  match := matches[i]
40
  fullStart, fullEnd := match[0], match[1]
41
  groupStart, groupEnd := match[2], match[3]
 
42
  jsonStr := cleaned[groupStart:groupEnd]
43
  if calls := parsePromptToolCallJSON(jsonStr); len(calls) > 0 {
44
+ allCalls = append(allCalls, calls...)
45
  }
 
46
  cleaned = cleaned[:fullStart] + cleaned[fullEnd:]
47
  }
48
  }
 
69
  return cleaned, allCalls
70
  }
71
 
72
+ // extractToolCallTags 使用 brace-counting 提取 <tool_call> 块中的 JSON,
73
+ // 正确处理嵌套 JSON 对象和缺失/错误的闭合标签。
74
+ func extractToolCallTags(content string) (cleanContent string, toolCalls []model.ToolCall) {
75
+ const openTag = "<tool_call>"
76
+ var calls []model.ToolCall
77
+ cleaned := content
78
+
79
+ // 从后向前查找所有 <tool_call> 标记,避免索引偏移
80
+ var tagPositions []int
81
+ searchFrom := 0
82
+ for {
83
+ idx := strings.Index(cleaned[searchFrom:], openTag)
84
+ if idx == -1 {
85
+ break
86
+ }
87
+ tagPositions = append(tagPositions, searchFrom+idx)
88
+ searchFrom += idx + len(openTag)
89
+ }
90
+
91
+ // 从后向前处理
92
+ for i := len(tagPositions) - 1; i >= 0; i-- {
93
+ tagStart := tagPositions[i]
94
+ afterTag := cleaned[tagStart+len(openTag):]
95
+
96
+ // 找到 JSON 对象的起始 {
97
+ jsonStart := strings.Index(afterTag, "{")
98
+ if jsonStart == -1 {
99
+ continue
100
+ }
101
+
102
+ // 提取 { 之前可能存在的函数名前缀(如 <tool_call>Read{...})
103
+ funcNamePrefix := strings.TrimSpace(afterTag[:jsonStart])
104
+
105
+ // 使用 brace-counting 找到匹配的 }
106
+ jsonEnd := findMatchingBrace(afterTag[jsonStart:])
107
+ if jsonEnd == -1 {
108
+ continue
109
+ }
110
+
111
+ jsonStr := afterTag[jsonStart : jsonStart+jsonEnd+1]
112
+ parsed := parsePromptToolCallJSON(jsonStr)
113
+
114
+ // 如果标准格式解析失败,但有函数名前缀,尝试当作 FuncName{args} 格式
115
+ if len(parsed) == 0 && funcNamePrefix != "" {
116
+ wrapped := fmt.Sprintf(`{"name": %q, "arguments": %s}`, funcNamePrefix, jsonStr)
117
+ parsed = parsePromptToolCallJSON(wrapped)
118
+ }
119
+
120
+ if len(parsed) == 0 {
121
+ continue
122
+ }
123
+ calls = append(parsed, calls...)
124
+
125
+ // 计算要移除的范围:从 <tool_call> 到 JSON 结束 + 可选的闭合标签
126
+ blockEnd := tagStart + len(openTag) + jsonStart + jsonEnd + 1
127
+ remaining := cleaned[blockEnd:]
128
+ // 移除可选的闭合标签
129
+ for _, closeTag := range []string{"</tool_call>", "</think>"} {
130
+ trimmed := strings.TrimLeft(remaining, " \t\n")
131
+ if strings.HasPrefix(trimmed, closeTag) {
132
+ blockEnd = blockEnd + (len(remaining) - len(trimmed)) + len(closeTag)
133
+ break
134
+ }
135
+ }
136
+ cleaned = cleaned[:tagStart] + cleaned[blockEnd:]
137
+ }
138
+
139
+ return cleaned, calls
140
+ }
141
+
142
+ // findMatchingBrace 在以 { 开头的字符串中找到匹配的 } 的索引。
143
+ // 返回 -1 如果未找到匹配的闭合大括号。
144
+ func findMatchingBrace(s string) int {
145
+ if len(s) == 0 || s[0] != '{' {
146
+ return -1
147
+ }
148
+ depth := 0
149
+ inString := false
150
+ escape := false
151
+ for i, ch := range s {
152
+ if escape {
153
+ escape = false
154
+ continue
155
+ }
156
+ if ch == '\\' && inString {
157
+ escape = true
158
+ continue
159
+ }
160
+ if ch == '"' {
161
+ inString = !inString
162
+ continue
163
+ }
164
+ if inString {
165
+ continue
166
+ }
167
+ if ch == '{' {
168
+ depth++
169
+ } else if ch == '}' {
170
+ depth--
171
+ if depth == 0 {
172
+ return i
173
+ }
174
+ }
175
+ }
176
+ return -1
177
+ }
178
+
179
  // parsePromptToolCallJSON 解析 <tool_call> 内的 JSON
180
  func parsePromptToolCallJSON(content string) []model.ToolCall {
181
  content = strings.TrimSpace(content)