package filter import ( "encoding/json" "fmt" "regexp" "strings" "github.com/google/uuid" "zai-proxy/internal/model" ) // altToolCallPattern 匹配 [TOOL]...[/TOOL] 和 [TOOL_CALL]...[/TOOL_CALL] 格式 var altToolCallPattern = regexp.MustCompile(`\[TOOL(?:_CALL)?\]\s*([\s\S]*?)\s*\[/TOOL(?:_CALL)?\]`) // jsonBlockPattern 匹配 markdown JSON 代码块中的 tool call var jsonBlockPattern = regexp.MustCompile("```json\\s*\\n(\\{[\\s\\S]*?\"name\"[\\s\\S]*?\\})\\s*\\n```") // ExtractPromptToolCalls 从文本中提取所有 tool call 块(支持多种格式), // 返回清理后的文本和解析出的 tool calls。 func ExtractPromptToolCalls(content string) (cleanContent string, toolCalls []model.ToolCall) { var allCalls []model.ToolCall cleaned := content // 首先尝试 格式(使用 brace-counting 处理嵌套 JSON) if result, calls := extractToolCallTags(cleaned); len(calls) > 0 { allCalls = append(allCalls, calls...) cleaned = result } // 然后尝试 [TOOL]/[TOOL_CALL] 和 markdown JSON 格式 for _, pattern := range []*regexp.Regexp{altToolCallPattern, jsonBlockPattern} { matches := pattern.FindAllStringSubmatchIndex(cleaned, -1) if len(matches) == 0 { continue } for i := len(matches) - 1; i >= 0; i-- { match := matches[i] fullStart, fullEnd := match[0], match[1] groupStart, groupEnd := match[2], match[3] jsonStr := cleaned[groupStart:groupEnd] if calls := parsePromptToolCallJSON(jsonStr); len(calls) > 0 { allCalls = append(allCalls, calls...) } cleaned = cleaned[:fullStart] + cleaned[fullEnd:] } } if len(allCalls) == 0 { return content, nil } // 清理多余空行 cleaned = strings.TrimSpace(cleaned) for strings.Contains(cleaned, "\n\n\n") { cleaned = strings.ReplaceAll(cleaned, "\n\n\n", "\n\n") } // 为每个 tool call 分配 ID for i := range allCalls { if allCalls[i].ID == "" { allCalls[i].ID = fmt.Sprintf("call_%s", uuid.New().String()[:24]) } allCalls[i].Index = i allCalls[i].Type = "function" } return cleaned, allCalls } // extractToolCallTags 使用 brace-counting 提取 块中的 JSON, // 正确处理嵌套 JSON 对象和缺失/错误的闭合标签。 func extractToolCallTags(content string) (cleanContent string, toolCalls []model.ToolCall) { const openTag = "" var calls []model.ToolCall cleaned := content // 从后向前查找所有 标记,避免索引偏移 var tagPositions []int searchFrom := 0 for { idx := strings.Index(cleaned[searchFrom:], openTag) if idx == -1 { break } tagPositions = append(tagPositions, searchFrom+idx) searchFrom += idx + len(openTag) } // 从后向前处理 for i := len(tagPositions) - 1; i >= 0; i-- { tagStart := tagPositions[i] afterTag := cleaned[tagStart+len(openTag):] // 找到 JSON 对象的起始 { jsonStart := strings.Index(afterTag, "{") if jsonStart == -1 { continue } // 提取 { 之前可能存在的函数名前缀(如 Read{...}) funcNamePrefix := strings.TrimSpace(afterTag[:jsonStart]) // 使用 brace-counting 找到匹配的 } jsonEnd := findMatchingBrace(afterTag[jsonStart:]) if jsonEnd == -1 { continue } jsonStr := afterTag[jsonStart : jsonStart+jsonEnd+1] parsed := parsePromptToolCallJSON(jsonStr) // 如果标准格式解析失败,但有函数名前缀,尝试当作 FuncName{args} 格式 if len(parsed) == 0 && funcNamePrefix != "" { wrapped := fmt.Sprintf(`{"name": %q, "arguments": %s}`, funcNamePrefix, jsonStr) parsed = parsePromptToolCallJSON(wrapped) } if len(parsed) == 0 { continue } calls = append(parsed, calls...) // 计算要移除的范围:从 到 JSON 结束 + 可选的闭合标签 blockEnd := tagStart + len(openTag) + jsonStart + jsonEnd + 1 remaining := cleaned[blockEnd:] // 移除可选的闭合标签 for _, closeTag := range []string{"", ""} { trimmed := strings.TrimLeft(remaining, " \t\n") if strings.HasPrefix(trimmed, closeTag) { blockEnd = blockEnd + (len(remaining) - len(trimmed)) + len(closeTag) break } } cleaned = cleaned[:tagStart] + cleaned[blockEnd:] } return cleaned, calls } // findMatchingBrace 在以 { 开头的字符串中找到匹配的 } 的索引。 // 返回 -1 如果未找到匹配的闭合大括号。 func findMatchingBrace(s string) int { if len(s) == 0 || s[0] != '{' { return -1 } depth := 0 inString := false escape := false for i, ch := range s { if escape { escape = false continue } if ch == '\\' && inString { escape = true continue } if ch == '"' { inString = !inString continue } if inString { continue } if ch == '{' { depth++ } else if ch == '}' { depth-- if depth == 0 { return i } } } return -1 } // parsePromptToolCallJSON 解析 内的 JSON func parsePromptToolCallJSON(content string) []model.ToolCall { content = strings.TrimSpace(content) if content == "" { return nil } // 标准格式: {"name": "xxx", "arguments": {...}} var call struct { Name string `json:"name"` Arguments json.RawMessage `json:"arguments"` } if err := json.Unmarshal([]byte(content), &call); err == nil && call.Name != "" { argsStr := string(call.Arguments) // 如果 arguments 不是字符串,序列化为字符串 if len(argsStr) > 0 && argsStr[0] != '"' { // 已经是 JSON 对象/其他类型,直接用 } else { // 是 JSON 字符串,解引用 var s string if json.Unmarshal(call.Arguments, &s) == nil { argsStr = s } } return []model.ToolCall{{ Function: model.FunctionCall{ Name: call.Name, Arguments: argsStr, }, }} } return nil } // HasPromptToolCallOpen 检测文本中是否有未关闭的 tool call 标签 func HasPromptToolCallOpen(content string) bool { // if strings.Count(content, "") > strings.Count(content, "") { return true } // [TOOL] / [TOOL_CALL] if strings.Count(content, "[TOOL]") > strings.Count(content, "[/TOOL]") { return true } if strings.Count(content, "[TOOL_CALL]") > strings.Count(content, "[/TOOL_CALL]") { return true } return false }