File size: 6,278 Bytes
8646505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9506c04
 
 
 
 
 
 
 
8646505
 
 
 
 
 
 
 
 
 
9506c04
8646505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9506c04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8646505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
package filter

import (
	"encoding/json"
	"fmt"
	"regexp"
	"strings"

	"github.com/google/uuid"

	"zai-proxy/internal/model"
)

// altToolCallPattern 匹配 [TOOL]...[/TOOL] 和 [TOOL_CALL]...[/TOOL_CALL] 格式
var altToolCallPattern = regexp.MustCompile(`\[TOOL(?:_CALL)?\]\s*([\s\S]*?)\s*\[/TOOL(?:_CALL)?\]`)

// jsonBlockPattern 匹配 markdown JSON 代码块中的 tool call
var jsonBlockPattern = regexp.MustCompile("```json\\s*\\n(\\{[\\s\\S]*?\"name\"[\\s\\S]*?\\})\\s*\\n```")

// ExtractPromptToolCalls 从文本中提取所有 tool call 块(支持多种格式),
// 返回清理后的文本和解析出的 tool calls。
func ExtractPromptToolCalls(content string) (cleanContent string, toolCalls []model.ToolCall) {
	var allCalls []model.ToolCall
	cleaned := content

	// 首先尝试 <tool_call> 格式(使用 brace-counting 处理嵌套 JSON)
	if result, calls := extractToolCallTags(cleaned); len(calls) > 0 {
		allCalls = append(allCalls, calls...)
		cleaned = result
	}

	// 然后尝试 [TOOL]/[TOOL_CALL] 和 markdown JSON 格式
	for _, pattern := range []*regexp.Regexp{altToolCallPattern, jsonBlockPattern} {
		matches := pattern.FindAllStringSubmatchIndex(cleaned, -1)
		if len(matches) == 0 {
			continue
		}
		for i := len(matches) - 1; i >= 0; i-- {
			match := matches[i]
			fullStart, fullEnd := match[0], match[1]
			groupStart, groupEnd := match[2], match[3]
			jsonStr := cleaned[groupStart:groupEnd]
			if calls := parsePromptToolCallJSON(jsonStr); len(calls) > 0 {
				allCalls = append(allCalls, calls...)
			}
			cleaned = cleaned[:fullStart] + cleaned[fullEnd:]
		}
	}

	if len(allCalls) == 0 {
		return content, nil
	}

	// 清理多余空行
	cleaned = strings.TrimSpace(cleaned)
	for strings.Contains(cleaned, "\n\n\n") {
		cleaned = strings.ReplaceAll(cleaned, "\n\n\n", "\n\n")
	}

	// 为每个 tool call 分配 ID
	for i := range allCalls {
		if allCalls[i].ID == "" {
			allCalls[i].ID = fmt.Sprintf("call_%s", uuid.New().String()[:24])
		}
		allCalls[i].Index = i
		allCalls[i].Type = "function"
	}

	return cleaned, allCalls
}

// extractToolCallTags 使用 brace-counting 提取 <tool_call> 块中的 JSON,
// 正确处理嵌套 JSON 对象和缺失/错误的闭合标签。
func extractToolCallTags(content string) (cleanContent string, toolCalls []model.ToolCall) {
	const openTag = "<tool_call>"
	var calls []model.ToolCall
	cleaned := content

	// 从后向前查找所有 <tool_call> 标记,避免索引偏移
	var tagPositions []int
	searchFrom := 0
	for {
		idx := strings.Index(cleaned[searchFrom:], openTag)
		if idx == -1 {
			break
		}
		tagPositions = append(tagPositions, searchFrom+idx)
		searchFrom += idx + len(openTag)
	}

	// 从后向前处理
	for i := len(tagPositions) - 1; i >= 0; i-- {
		tagStart := tagPositions[i]
		afterTag := cleaned[tagStart+len(openTag):]

		// 找到 JSON 对象的起始 {
		jsonStart := strings.Index(afterTag, "{")
		if jsonStart == -1 {
			continue
		}

		// 提取 { 之前可能存在的函数名前缀(如 <tool_call>Read{...})
		funcNamePrefix := strings.TrimSpace(afterTag[:jsonStart])

		// 使用 brace-counting 找到匹配的 }
		jsonEnd := findMatchingBrace(afterTag[jsonStart:])
		if jsonEnd == -1 {
			continue
		}

		jsonStr := afterTag[jsonStart : jsonStart+jsonEnd+1]
		parsed := parsePromptToolCallJSON(jsonStr)

		// 如果标准格式解析失败,但有函数名前缀,尝试当作 FuncName{args} 格式
		if len(parsed) == 0 && funcNamePrefix != "" {
			wrapped := fmt.Sprintf(`{"name": %q, "arguments": %s}`, funcNamePrefix, jsonStr)
			parsed = parsePromptToolCallJSON(wrapped)
		}

		if len(parsed) == 0 {
			continue
		}
		calls = append(parsed, calls...)

		// 计算要移除的范围:从 <tool_call> 到 JSON 结束 + 可选的闭合标签
		blockEnd := tagStart + len(openTag) + jsonStart + jsonEnd + 1
		remaining := cleaned[blockEnd:]
		// 移除可选的闭合标签
		for _, closeTag := range []string{"</tool_call>", "</think>"} {
			trimmed := strings.TrimLeft(remaining, " \t\n")
			if strings.HasPrefix(trimmed, closeTag) {
				blockEnd = blockEnd + (len(remaining) - len(trimmed)) + len(closeTag)
				break
			}
		}
		cleaned = cleaned[:tagStart] + cleaned[blockEnd:]
	}

	return cleaned, calls
}

// findMatchingBrace 在以 { 开头的字符串中找到匹配的 } 的索引。
// 返回 -1 如果未找到匹配的闭合大括号。
func findMatchingBrace(s string) int {
	if len(s) == 0 || s[0] != '{' {
		return -1
	}
	depth := 0
	inString := false
	escape := false
	for i, ch := range s {
		if escape {
			escape = false
			continue
		}
		if ch == '\\' && inString {
			escape = true
			continue
		}
		if ch == '"' {
			inString = !inString
			continue
		}
		if inString {
			continue
		}
		if ch == '{' {
			depth++
		} else if ch == '}' {
			depth--
			if depth == 0 {
				return i
			}
		}
	}
	return -1
}

// parsePromptToolCallJSON 解析 <tool_call> 内的 JSON
func parsePromptToolCallJSON(content string) []model.ToolCall {
	content = strings.TrimSpace(content)
	if content == "" {
		return nil
	}

	// 标准格式: {"name": "xxx", "arguments": {...}}
	var call struct {
		Name      string          `json:"name"`
		Arguments json.RawMessage `json:"arguments"`
	}
	if err := json.Unmarshal([]byte(content), &call); err == nil && call.Name != "" {
		argsStr := string(call.Arguments)
		// 如果 arguments 不是字符串,序列化为字符串
		if len(argsStr) > 0 && argsStr[0] != '"' {
			// 已经是 JSON 对象/其他类型,直接用
		} else {
			// 是 JSON 字符串,解引用
			var s string
			if json.Unmarshal(call.Arguments, &s) == nil {
				argsStr = s
			}
		}
		return []model.ToolCall{{
			Function: model.FunctionCall{
				Name:      call.Name,
				Arguments: argsStr,
			},
		}}
	}

	return nil
}

// HasPromptToolCallOpen 检测文本中是否有未关闭的 tool call 标签
func HasPromptToolCallOpen(content string) bool {
	// <tool_call>
	if strings.Count(content, "<tool_call>") > strings.Count(content, "</tool_call>") {
		return true
	}
	// [TOOL] / [TOOL_CALL]
	if strings.Count(content, "[TOOL]") > strings.Count(content, "[/TOOL]") {
		return true
	}
	if strings.Count(content, "[TOOL_CALL]") > strings.Count(content, "[/TOOL_CALL]") {
		return true
	}
	return false
}