Spaces:
Paused
Paused
| package filter | |
| import ( | |
| "encoding/json" | |
| "fmt" | |
| "regexp" | |
| "strings" | |
| "github.com/google/uuid" | |
| "zai-proxy/internal/model" | |
| ) | |
| // altToolCallPattern 匹配 [TOOL]...[/TOOL] 和 [TOOL_CALL]...[/TOOL_CALL] 格式 | |
| var altToolCallPattern = regexp.MustCompile(`\[TOOL(?:_CALL)?\]\s*([\s\S]*?)\s*\[/TOOL(?:_CALL)?\]`) | |
| // jsonBlockPattern 匹配 markdown JSON 代码块中的 tool call | |
| var jsonBlockPattern = regexp.MustCompile("```json\\s*\\n(\\{[\\s\\S]*?\"name\"[\\s\\S]*?\\})\\s*\\n```") | |
| // ExtractPromptToolCalls 从文本中提取所有 tool call 块(支持多种格式), | |
| // 返回清理后的文本和解析出的 tool calls。 | |
| func ExtractPromptToolCalls(content string) (cleanContent string, toolCalls []model.ToolCall) { | |
| var allCalls []model.ToolCall | |
| cleaned := content | |
| // 首先尝试 <tool_call> 格式(使用 brace-counting 处理嵌套 JSON) | |
| if result, calls := extractToolCallTags(cleaned); len(calls) > 0 { | |
| allCalls = append(allCalls, calls...) | |
| cleaned = result | |
| } | |
| // 然后尝试 [TOOL]/[TOOL_CALL] 和 markdown JSON 格式 | |
| for _, pattern := range []*regexp.Regexp{altToolCallPattern, jsonBlockPattern} { | |
| matches := pattern.FindAllStringSubmatchIndex(cleaned, -1) | |
| if len(matches) == 0 { | |
| continue | |
| } | |
| for i := len(matches) - 1; i >= 0; i-- { | |
| match := matches[i] | |
| fullStart, fullEnd := match[0], match[1] | |
| groupStart, groupEnd := match[2], match[3] | |
| jsonStr := cleaned[groupStart:groupEnd] | |
| if calls := parsePromptToolCallJSON(jsonStr); len(calls) > 0 { | |
| allCalls = append(allCalls, calls...) | |
| } | |
| cleaned = cleaned[:fullStart] + cleaned[fullEnd:] | |
| } | |
| } | |
| if len(allCalls) == 0 { | |
| return content, nil | |
| } | |
| // 清理多余空行 | |
| cleaned = strings.TrimSpace(cleaned) | |
| for strings.Contains(cleaned, "\n\n\n") { | |
| cleaned = strings.ReplaceAll(cleaned, "\n\n\n", "\n\n") | |
| } | |
| // 为每个 tool call 分配 ID | |
| for i := range allCalls { | |
| if allCalls[i].ID == "" { | |
| allCalls[i].ID = fmt.Sprintf("call_%s", uuid.New().String()[:24]) | |
| } | |
| allCalls[i].Index = i | |
| allCalls[i].Type = "function" | |
| } | |
| return cleaned, allCalls | |
| } | |
| // extractToolCallTags 使用 brace-counting 提取 <tool_call> 块中的 JSON, | |
| // 正确处理嵌套 JSON 对象和缺失/错误的闭合标签。 | |
| func extractToolCallTags(content string) (cleanContent string, toolCalls []model.ToolCall) { | |
| const openTag = "<tool_call>" | |
| var calls []model.ToolCall | |
| cleaned := content | |
| // 从后向前查找所有 <tool_call> 标记,避免索引偏移 | |
| var tagPositions []int | |
| searchFrom := 0 | |
| for { | |
| idx := strings.Index(cleaned[searchFrom:], openTag) | |
| if idx == -1 { | |
| break | |
| } | |
| tagPositions = append(tagPositions, searchFrom+idx) | |
| searchFrom += idx + len(openTag) | |
| } | |
| // 从后向前处理 | |
| for i := len(tagPositions) - 1; i >= 0; i-- { | |
| tagStart := tagPositions[i] | |
| afterTag := cleaned[tagStart+len(openTag):] | |
| // 找到 JSON 对象的起始 { | |
| jsonStart := strings.Index(afterTag, "{") | |
| if jsonStart == -1 { | |
| continue | |
| } | |
| // 提取 { 之前可能存在的函数名前缀(如 <tool_call>Read{...}) | |
| funcNamePrefix := strings.TrimSpace(afterTag[:jsonStart]) | |
| // 使用 brace-counting 找到匹配的 } | |
| jsonEnd := findMatchingBrace(afterTag[jsonStart:]) | |
| if jsonEnd == -1 { | |
| continue | |
| } | |
| jsonStr := afterTag[jsonStart : jsonStart+jsonEnd+1] | |
| parsed := parsePromptToolCallJSON(jsonStr) | |
| // 如果标准格式解析失败,但有函数名前缀,尝试当作 FuncName{args} 格式 | |
| if len(parsed) == 0 && funcNamePrefix != "" { | |
| wrapped := fmt.Sprintf(`{"name": %q, "arguments": %s}`, funcNamePrefix, jsonStr) | |
| parsed = parsePromptToolCallJSON(wrapped) | |
| } | |
| if len(parsed) == 0 { | |
| continue | |
| } | |
| calls = append(parsed, calls...) | |
| // 计算要移除的范围:从 <tool_call> 到 JSON 结束 + 可选的闭合标签 | |
| blockEnd := tagStart + len(openTag) + jsonStart + jsonEnd + 1 | |
| remaining := cleaned[blockEnd:] | |
| // 移除可选的闭合标签 | |
| for _, closeTag := range []string{"</tool_call>", "</think>"} { | |
| trimmed := strings.TrimLeft(remaining, " \t\n") | |
| if strings.HasPrefix(trimmed, closeTag) { | |
| blockEnd = blockEnd + (len(remaining) - len(trimmed)) + len(closeTag) | |
| break | |
| } | |
| } | |
| cleaned = cleaned[:tagStart] + cleaned[blockEnd:] | |
| } | |
| return cleaned, calls | |
| } | |
| // findMatchingBrace 在以 { 开头的字符串中找到匹配的 } 的索引。 | |
| // 返回 -1 如果未找到匹配的闭合大括号。 | |
| func findMatchingBrace(s string) int { | |
| if len(s) == 0 || s[0] != '{' { | |
| return -1 | |
| } | |
| depth := 0 | |
| inString := false | |
| escape := false | |
| for i, ch := range s { | |
| if escape { | |
| escape = false | |
| continue | |
| } | |
| if ch == '\\' && inString { | |
| escape = true | |
| continue | |
| } | |
| if ch == '"' { | |
| inString = !inString | |
| continue | |
| } | |
| if inString { | |
| continue | |
| } | |
| if ch == '{' { | |
| depth++ | |
| } else if ch == '}' { | |
| depth-- | |
| if depth == 0 { | |
| return i | |
| } | |
| } | |
| } | |
| return -1 | |
| } | |
| // parsePromptToolCallJSON 解析 <tool_call> 内的 JSON | |
| func parsePromptToolCallJSON(content string) []model.ToolCall { | |
| content = strings.TrimSpace(content) | |
| if content == "" { | |
| return nil | |
| } | |
| // 标准格式: {"name": "xxx", "arguments": {...}} | |
| var call struct { | |
| Name string `json:"name"` | |
| Arguments json.RawMessage `json:"arguments"` | |
| } | |
| if err := json.Unmarshal([]byte(content), &call); err == nil && call.Name != "" { | |
| argsStr := string(call.Arguments) | |
| // 如果 arguments 不是字符串,序列化为字符串 | |
| if len(argsStr) > 0 && argsStr[0] != '"' { | |
| // 已经是 JSON 对象/其他类型,直接用 | |
| } else { | |
| // 是 JSON 字符串,解引用 | |
| var s string | |
| if json.Unmarshal(call.Arguments, &s) == nil { | |
| argsStr = s | |
| } | |
| } | |
| return []model.ToolCall{{ | |
| Function: model.FunctionCall{ | |
| Name: call.Name, | |
| Arguments: argsStr, | |
| }, | |
| }} | |
| } | |
| return nil | |
| } | |
| // HasPromptToolCallOpen 检测文本中是否有未关闭的 tool call 标签 | |
| func HasPromptToolCallOpen(content string) bool { | |
| // <tool_call> | |
| if strings.Count(content, "<tool_call>") > strings.Count(content, "</tool_call>") { | |
| return true | |
| } | |
| // [TOOL] / [TOOL_CALL] | |
| if strings.Count(content, "[TOOL]") > strings.Count(content, "[/TOOL]") { | |
| return true | |
| } | |
| if strings.Count(content, "[TOOL_CALL]") > strings.Count(content, "[/TOOL_CALL]") { | |
| return true | |
| } | |
| return false | |
| } | |