Spaces:
Paused
Paused
fix: use brace-counting to extract tool_call JSON from nested structures
Browse filesReplace regex-based <tool_call> extraction with a brace-counting parser
that correctly handles nested JSON objects and missing/malformed close tags.
Also supports FuncName{args} shorthand format.
- internal/filter/prompttool.go +116 -17
internal/filter/prompttool.go
CHANGED
|
@@ -11,46 +11,38 @@ import (
|
|
| 11 |
"zai-proxy/internal/model"
|
| 12 |
)
|
| 13 |
|
| 14 |
-
// promptToolCallPattern 匹配 <tool_call>...</tool_call> 块
|
| 15 |
-
var promptToolCallPattern = regexp.MustCompile(`<tool_call>\s*([\s\S]*?)\s*</tool_call>`)
|
| 16 |
-
|
| 17 |
// altToolCallPattern 匹配 [TOOL]...[/TOOL] 和 [TOOL_CALL]...[/TOOL_CALL] 格式
|
| 18 |
var altToolCallPattern = regexp.MustCompile(`\[TOOL(?:_CALL)?\]\s*([\s\S]*?)\s*\[/TOOL(?:_CALL)?\]`)
|
| 19 |
|
| 20 |
// jsonBlockPattern 匹配 markdown JSON 代码块中的 tool call
|
| 21 |
var jsonBlockPattern = regexp.MustCompile("```json\\s*\\n(\\{[\\s\\S]*?\"name\"[\\s\\S]*?\\})\\s*\\n```")
|
| 22 |
|
| 23 |
-
// allToolCallPatterns 按优先级排列的所有 tool call 模式
|
| 24 |
-
var allToolCallPatterns = []*regexp.Regexp{
|
| 25 |
-
promptToolCallPattern, // <tool_call> 最高优先级
|
| 26 |
-
altToolCallPattern, // [TOOL] / [TOOL_CALL]
|
| 27 |
-
jsonBlockPattern, // ```json ... ```
|
| 28 |
-
}
|
| 29 |
-
|
| 30 |
// ExtractPromptToolCalls 从文本中提取所有 tool call 块(支持多种格式),
|
| 31 |
// 返回清理后的文本和解析出的 tool calls。
|
| 32 |
func ExtractPromptToolCalls(content string) (cleanContent string, toolCalls []model.ToolCall) {
|
| 33 |
var allCalls []model.ToolCall
|
| 34 |
cleaned := content
|
| 35 |
|
| 36 |
-
//
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
matches := pattern.FindAllStringSubmatchIndex(cleaned, -1)
|
| 39 |
if len(matches) == 0 {
|
| 40 |
continue
|
| 41 |
}
|
| 42 |
-
|
| 43 |
-
// 从后向前移除匹配块,避免索引偏移
|
| 44 |
for i := len(matches) - 1; i >= 0; i-- {
|
| 45 |
match := matches[i]
|
| 46 |
fullStart, fullEnd := match[0], match[1]
|
| 47 |
groupStart, groupEnd := match[2], match[3]
|
| 48 |
-
|
| 49 |
jsonStr := cleaned[groupStart:groupEnd]
|
| 50 |
if calls := parsePromptToolCallJSON(jsonStr); len(calls) > 0 {
|
| 51 |
-
allCalls = append(
|
| 52 |
}
|
| 53 |
-
|
| 54 |
cleaned = cleaned[:fullStart] + cleaned[fullEnd:]
|
| 55 |
}
|
| 56 |
}
|
|
@@ -77,6 +69,113 @@ func ExtractPromptToolCalls(content string) (cleanContent string, toolCalls []mo
|
|
| 77 |
return cleaned, allCalls
|
| 78 |
}
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
// parsePromptToolCallJSON 解析 <tool_call> 内的 JSON
|
| 81 |
func parsePromptToolCallJSON(content string) []model.ToolCall {
|
| 82 |
content = strings.TrimSpace(content)
|
|
|
|
| 11 |
"zai-proxy/internal/model"
|
| 12 |
)
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
// altToolCallPattern 匹配 [TOOL]...[/TOOL] 和 [TOOL_CALL]...[/TOOL_CALL] 格式
|
| 15 |
var altToolCallPattern = regexp.MustCompile(`\[TOOL(?:_CALL)?\]\s*([\s\S]*?)\s*\[/TOOL(?:_CALL)?\]`)
|
| 16 |
|
| 17 |
// jsonBlockPattern 匹配 markdown JSON 代码块中的 tool call
|
| 18 |
var jsonBlockPattern = regexp.MustCompile("```json\\s*\\n(\\{[\\s\\S]*?\"name\"[\\s\\S]*?\\})\\s*\\n```")
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
// ExtractPromptToolCalls 从文本中提取所有 tool call 块(支持多种格式),
|
| 21 |
// 返回清理后的文本和解析出的 tool calls。
|
| 22 |
func ExtractPromptToolCalls(content string) (cleanContent string, toolCalls []model.ToolCall) {
|
| 23 |
var allCalls []model.ToolCall
|
| 24 |
cleaned := content
|
| 25 |
|
| 26 |
+
// 首先尝试 <tool_call> 格式(使用 brace-counting 处理嵌套 JSON)
|
| 27 |
+
if result, calls := extractToolCallTags(cleaned); len(calls) > 0 {
|
| 28 |
+
allCalls = append(allCalls, calls...)
|
| 29 |
+
cleaned = result
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
// 然后尝试 [TOOL]/[TOOL_CALL] 和 markdown JSON 格式
|
| 33 |
+
for _, pattern := range []*regexp.Regexp{altToolCallPattern, jsonBlockPattern} {
|
| 34 |
matches := pattern.FindAllStringSubmatchIndex(cleaned, -1)
|
| 35 |
if len(matches) == 0 {
|
| 36 |
continue
|
| 37 |
}
|
|
|
|
|
|
|
| 38 |
for i := len(matches) - 1; i >= 0; i-- {
|
| 39 |
match := matches[i]
|
| 40 |
fullStart, fullEnd := match[0], match[1]
|
| 41 |
groupStart, groupEnd := match[2], match[3]
|
|
|
|
| 42 |
jsonStr := cleaned[groupStart:groupEnd]
|
| 43 |
if calls := parsePromptToolCallJSON(jsonStr); len(calls) > 0 {
|
| 44 |
+
allCalls = append(allCalls, calls...)
|
| 45 |
}
|
|
|
|
| 46 |
cleaned = cleaned[:fullStart] + cleaned[fullEnd:]
|
| 47 |
}
|
| 48 |
}
|
|
|
|
| 69 |
return cleaned, allCalls
|
| 70 |
}
|
| 71 |
|
| 72 |
+
// extractToolCallTags 使用 brace-counting 提取 <tool_call> 块中的 JSON,
|
| 73 |
+
// 正确处理嵌套 JSON 对象和缺失/错误的闭合标签。
|
| 74 |
+
func extractToolCallTags(content string) (cleanContent string, toolCalls []model.ToolCall) {
|
| 75 |
+
const openTag = "<tool_call>"
|
| 76 |
+
var calls []model.ToolCall
|
| 77 |
+
cleaned := content
|
| 78 |
+
|
| 79 |
+
// 从后向前查找所有 <tool_call> 标记,避免索引偏移
|
| 80 |
+
var tagPositions []int
|
| 81 |
+
searchFrom := 0
|
| 82 |
+
for {
|
| 83 |
+
idx := strings.Index(cleaned[searchFrom:], openTag)
|
| 84 |
+
if idx == -1 {
|
| 85 |
+
break
|
| 86 |
+
}
|
| 87 |
+
tagPositions = append(tagPositions, searchFrom+idx)
|
| 88 |
+
searchFrom += idx + len(openTag)
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
// 从后向前处理
|
| 92 |
+
for i := len(tagPositions) - 1; i >= 0; i-- {
|
| 93 |
+
tagStart := tagPositions[i]
|
| 94 |
+
afterTag := cleaned[tagStart+len(openTag):]
|
| 95 |
+
|
| 96 |
+
// 找到 JSON 对象的起始 {
|
| 97 |
+
jsonStart := strings.Index(afterTag, "{")
|
| 98 |
+
if jsonStart == -1 {
|
| 99 |
+
continue
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
// 提取 { 之前可能存在的函数名前缀(如 <tool_call>Read{...})
|
| 103 |
+
funcNamePrefix := strings.TrimSpace(afterTag[:jsonStart])
|
| 104 |
+
|
| 105 |
+
// 使用 brace-counting 找到匹配的 }
|
| 106 |
+
jsonEnd := findMatchingBrace(afterTag[jsonStart:])
|
| 107 |
+
if jsonEnd == -1 {
|
| 108 |
+
continue
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
jsonStr := afterTag[jsonStart : jsonStart+jsonEnd+1]
|
| 112 |
+
parsed := parsePromptToolCallJSON(jsonStr)
|
| 113 |
+
|
| 114 |
+
// 如果标准格式解析失败,但有函数名前缀,尝试当作 FuncName{args} 格式
|
| 115 |
+
if len(parsed) == 0 && funcNamePrefix != "" {
|
| 116 |
+
wrapped := fmt.Sprintf(`{"name": %q, "arguments": %s}`, funcNamePrefix, jsonStr)
|
| 117 |
+
parsed = parsePromptToolCallJSON(wrapped)
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
if len(parsed) == 0 {
|
| 121 |
+
continue
|
| 122 |
+
}
|
| 123 |
+
calls = append(parsed, calls...)
|
| 124 |
+
|
| 125 |
+
// 计算要移除的范围:从 <tool_call> 到 JSON 结束 + 可选的闭合标签
|
| 126 |
+
blockEnd := tagStart + len(openTag) + jsonStart + jsonEnd + 1
|
| 127 |
+
remaining := cleaned[blockEnd:]
|
| 128 |
+
// 移除可选的闭合标签
|
| 129 |
+
for _, closeTag := range []string{"</tool_call>", "</think>"} {
|
| 130 |
+
trimmed := strings.TrimLeft(remaining, " \t\n")
|
| 131 |
+
if strings.HasPrefix(trimmed, closeTag) {
|
| 132 |
+
blockEnd = blockEnd + (len(remaining) - len(trimmed)) + len(closeTag)
|
| 133 |
+
break
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
cleaned = cleaned[:tagStart] + cleaned[blockEnd:]
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
return cleaned, calls
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
// findMatchingBrace 在以 { 开头的字符串中找到匹配的 } 的索引。
|
| 143 |
+
// 返回 -1 如果未找到匹配的闭合大括号。
|
| 144 |
+
func findMatchingBrace(s string) int {
|
| 145 |
+
if len(s) == 0 || s[0] != '{' {
|
| 146 |
+
return -1
|
| 147 |
+
}
|
| 148 |
+
depth := 0
|
| 149 |
+
inString := false
|
| 150 |
+
escape := false
|
| 151 |
+
for i, ch := range s {
|
| 152 |
+
if escape {
|
| 153 |
+
escape = false
|
| 154 |
+
continue
|
| 155 |
+
}
|
| 156 |
+
if ch == '\\' && inString {
|
| 157 |
+
escape = true
|
| 158 |
+
continue
|
| 159 |
+
}
|
| 160 |
+
if ch == '"' {
|
| 161 |
+
inString = !inString
|
| 162 |
+
continue
|
| 163 |
+
}
|
| 164 |
+
if inString {
|
| 165 |
+
continue
|
| 166 |
+
}
|
| 167 |
+
if ch == '{' {
|
| 168 |
+
depth++
|
| 169 |
+
} else if ch == '}' {
|
| 170 |
+
depth--
|
| 171 |
+
if depth == 0 {
|
| 172 |
+
return i
|
| 173 |
+
}
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
return -1
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
// parsePromptToolCallJSON 解析 <tool_call> 内的 JSON
|
| 180 |
func parsePromptToolCallJSON(content string) []model.ToolCall {
|
| 181 |
content = strings.TrimSpace(content)
|