Spaces:
Paused
Paused
fix: improve tool calling reliability with multi-format parsing and Delta pointer fix
Browse files- Enhance system prompt with Chinese instructions and few-shot examples for
more reliable <tool_call> output from GLM models
- Add fallback parsing for [TOOL]...[/TOOL], [TOOL_CALL]...[/TOOL_CALL],
and markdown JSON block formats in prompt injection mode
- Change Choice.Delta to *Delta pointer so omitempty correctly omits the
field in non-streaming responses (fixes extra "delta":{} in JSON)
- Convert tool_calls/tool messages to plain text for upstream z.ai API
which doesn't support native tools field
- Add comprehensive tests for all new parsing formats and serialization
- internal/filter/prompttool.go +129 -0
- internal/filter/prompttool_test.go +224 -0
- internal/handler/chat.go +130 -12
- internal/handler/chat_test.go +73 -0
- internal/model/types.go +1 -1
- internal/model/types_test.go +40 -1
- internal/tools/prompt.go +96 -0
- internal/tools/prompt_test.go +132 -0
- internal/upstream/client.go +51 -20
internal/filter/prompttool.go
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package filter
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"fmt"
|
| 6 |
+
"regexp"
|
| 7 |
+
"strings"
|
| 8 |
+
|
| 9 |
+
"github.com/google/uuid"
|
| 10 |
+
|
| 11 |
+
"zai-proxy/internal/model"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
// promptToolCallPattern 匹配 <tool_call>...</tool_call> 块
|
| 15 |
+
var promptToolCallPattern = regexp.MustCompile(`<tool_call>\s*([\s\S]*?)\s*</tool_call>`)
|
| 16 |
+
|
| 17 |
+
// altToolCallPattern 匹配 [TOOL]...[/TOOL] 和 [TOOL_CALL]...[/TOOL_CALL] 格式
|
| 18 |
+
var altToolCallPattern = regexp.MustCompile(`\[TOOL(?:_CALL)?\]\s*([\s\S]*?)\s*\[/TOOL(?:_CALL)?\]`)
|
| 19 |
+
|
| 20 |
+
// jsonBlockPattern 匹配 markdown JSON 代码块中的 tool call
|
| 21 |
+
var jsonBlockPattern = regexp.MustCompile("```json\\s*\\n(\\{[\\s\\S]*?\"name\"[\\s\\S]*?\\})\\s*\\n```")
|
| 22 |
+
|
| 23 |
+
// allToolCallPatterns 按优先级排列的所有 tool call 模式
|
| 24 |
+
var allToolCallPatterns = []*regexp.Regexp{
|
| 25 |
+
promptToolCallPattern, // <tool_call> 最高优先级
|
| 26 |
+
altToolCallPattern, // [TOOL] / [TOOL_CALL]
|
| 27 |
+
jsonBlockPattern, // ```json ... ```
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// ExtractPromptToolCalls 从文本中提取所有 tool call 块(支持多种格式),
|
| 31 |
+
// 返回清理后的文本和解析出的 tool calls。
|
| 32 |
+
func ExtractPromptToolCalls(content string) (cleanContent string, toolCalls []model.ToolCall) {
|
| 33 |
+
var allCalls []model.ToolCall
|
| 34 |
+
cleaned := content
|
| 35 |
+
|
| 36 |
+
// 按优先级依次尝试各种格式
|
| 37 |
+
for _, pattern := range allToolCallPatterns {
|
| 38 |
+
matches := pattern.FindAllStringSubmatchIndex(cleaned, -1)
|
| 39 |
+
if len(matches) == 0 {
|
| 40 |
+
continue
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
// 从后向前移除匹配块,避免索引偏移
|
| 44 |
+
for i := len(matches) - 1; i >= 0; i-- {
|
| 45 |
+
match := matches[i]
|
| 46 |
+
fullStart, fullEnd := match[0], match[1]
|
| 47 |
+
groupStart, groupEnd := match[2], match[3]
|
| 48 |
+
|
| 49 |
+
jsonStr := cleaned[groupStart:groupEnd]
|
| 50 |
+
if calls := parsePromptToolCallJSON(jsonStr); len(calls) > 0 {
|
| 51 |
+
allCalls = append(calls, allCalls...)
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
cleaned = cleaned[:fullStart] + cleaned[fullEnd:]
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
if len(allCalls) == 0 {
|
| 59 |
+
return content, nil
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
// 清理多余空行
|
| 63 |
+
cleaned = strings.TrimSpace(cleaned)
|
| 64 |
+
for strings.Contains(cleaned, "\n\n\n") {
|
| 65 |
+
cleaned = strings.ReplaceAll(cleaned, "\n\n\n", "\n\n")
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// 为每个 tool call 分配 ID
|
| 69 |
+
for i := range allCalls {
|
| 70 |
+
if allCalls[i].ID == "" {
|
| 71 |
+
allCalls[i].ID = fmt.Sprintf("call_%s", uuid.New().String()[:24])
|
| 72 |
+
}
|
| 73 |
+
allCalls[i].Index = i
|
| 74 |
+
allCalls[i].Type = "function"
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
return cleaned, allCalls
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// parsePromptToolCallJSON 解析 <tool_call> 内的 JSON
|
| 81 |
+
func parsePromptToolCallJSON(content string) []model.ToolCall {
|
| 82 |
+
content = strings.TrimSpace(content)
|
| 83 |
+
if content == "" {
|
| 84 |
+
return nil
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
// 标准格式: {"name": "xxx", "arguments": {...}}
|
| 88 |
+
var call struct {
|
| 89 |
+
Name string `json:"name"`
|
| 90 |
+
Arguments json.RawMessage `json:"arguments"`
|
| 91 |
+
}
|
| 92 |
+
if err := json.Unmarshal([]byte(content), &call); err == nil && call.Name != "" {
|
| 93 |
+
argsStr := string(call.Arguments)
|
| 94 |
+
// 如果 arguments 不是字符串,序列化为字符串
|
| 95 |
+
if len(argsStr) > 0 && argsStr[0] != '"' {
|
| 96 |
+
// 已经是 JSON 对象/其他类型,直接用
|
| 97 |
+
} else {
|
| 98 |
+
// 是 JSON 字符串,解引用
|
| 99 |
+
var s string
|
| 100 |
+
if json.Unmarshal(call.Arguments, &s) == nil {
|
| 101 |
+
argsStr = s
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
return []model.ToolCall{{
|
| 105 |
+
Function: model.FunctionCall{
|
| 106 |
+
Name: call.Name,
|
| 107 |
+
Arguments: argsStr,
|
| 108 |
+
},
|
| 109 |
+
}}
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
return nil
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
// HasPromptToolCallOpen 检测文本中是否有未关闭的 tool call 标签
|
| 116 |
+
func HasPromptToolCallOpen(content string) bool {
|
| 117 |
+
// <tool_call>
|
| 118 |
+
if strings.Count(content, "<tool_call>") > strings.Count(content, "</tool_call>") {
|
| 119 |
+
return true
|
| 120 |
+
}
|
| 121 |
+
// [TOOL] / [TOOL_CALL]
|
| 122 |
+
if strings.Count(content, "[TOOL]") > strings.Count(content, "[/TOOL]") {
|
| 123 |
+
return true
|
| 124 |
+
}
|
| 125 |
+
if strings.Count(content, "[TOOL_CALL]") > strings.Count(content, "[/TOOL_CALL]") {
|
| 126 |
+
return true
|
| 127 |
+
}
|
| 128 |
+
return false
|
| 129 |
+
}
|
internal/filter/prompttool_test.go
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package filter
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"testing"
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
func TestExtractPromptToolCalls_NoToolCall(t *testing.T) {
|
| 8 |
+
content := "Hello, this is a normal response."
|
| 9 |
+
clean, calls := ExtractPromptToolCalls(content)
|
| 10 |
+
if clean != content {
|
| 11 |
+
t.Errorf("expected content unchanged, got %q", clean)
|
| 12 |
+
}
|
| 13 |
+
if len(calls) != 0 {
|
| 14 |
+
t.Error("expected no tool calls")
|
| 15 |
+
}
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
func TestExtractPromptToolCalls_SingleCall(t *testing.T) {
|
| 19 |
+
content := `Here is the result:
|
| 20 |
+
<tool_call>{"name": "get_weather", "arguments": {"city": "Beijing"}}</tool_call>
|
| 21 |
+
Done.`
|
| 22 |
+
|
| 23 |
+
clean, calls := ExtractPromptToolCalls(content)
|
| 24 |
+
|
| 25 |
+
if len(calls) != 1 {
|
| 26 |
+
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
| 27 |
+
}
|
| 28 |
+
if calls[0].Function.Name != "get_weather" {
|
| 29 |
+
t.Errorf("expected name get_weather, got %s", calls[0].Function.Name)
|
| 30 |
+
}
|
| 31 |
+
if calls[0].Function.Arguments != `{"city":"Beijing"}` && calls[0].Function.Arguments != `{"city": "Beijing"}` {
|
| 32 |
+
t.Errorf("unexpected arguments: %s", calls[0].Function.Arguments)
|
| 33 |
+
}
|
| 34 |
+
if calls[0].ID == "" {
|
| 35 |
+
t.Error("expected auto-generated ID")
|
| 36 |
+
}
|
| 37 |
+
if calls[0].Type != "function" {
|
| 38 |
+
t.Errorf("expected type function, got %s", calls[0].Type)
|
| 39 |
+
}
|
| 40 |
+
// Clean content should not contain tool_call tags
|
| 41 |
+
if clean == content {
|
| 42 |
+
t.Error("expected content to be cleaned")
|
| 43 |
+
}
|
| 44 |
+
if contains := "Here is the result:"; !containsStr(clean, contains) {
|
| 45 |
+
t.Errorf("expected clean content to contain %q", contains)
|
| 46 |
+
}
|
| 47 |
+
if containsStr(clean, "<tool_call>") {
|
| 48 |
+
t.Error("clean content should not contain <tool_call>")
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
func TestExtractPromptToolCalls_MultipleCalls(t *testing.T) {
|
| 53 |
+
content := `<tool_call>{"name": "func_a", "arguments": {"x": 1}}</tool_call>
|
| 54 |
+
<tool_call>{"name": "func_b", "arguments": {"y": 2}}</tool_call>`
|
| 55 |
+
|
| 56 |
+
clean, calls := ExtractPromptToolCalls(content)
|
| 57 |
+
|
| 58 |
+
if len(calls) != 2 {
|
| 59 |
+
t.Fatalf("expected 2 tool calls, got %d", len(calls))
|
| 60 |
+
}
|
| 61 |
+
if calls[0].Function.Name != "func_a" {
|
| 62 |
+
t.Errorf("expected first call func_a, got %s", calls[0].Function.Name)
|
| 63 |
+
}
|
| 64 |
+
if calls[1].Function.Name != "func_b" {
|
| 65 |
+
t.Errorf("expected second call func_b, got %s", calls[1].Function.Name)
|
| 66 |
+
}
|
| 67 |
+
if calls[0].Index != 0 || calls[1].Index != 1 {
|
| 68 |
+
t.Error("expected sequential indices")
|
| 69 |
+
}
|
| 70 |
+
if clean != "" {
|
| 71 |
+
t.Errorf("expected empty clean content, got %q", clean)
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
func TestExtractPromptToolCalls_OnlyToolCall(t *testing.T) {
|
| 76 |
+
content := `<tool_call>{"name": "calculate", "arguments": {"expression": "2+2"}}</tool_call>`
|
| 77 |
+
|
| 78 |
+
clean, calls := ExtractPromptToolCalls(content)
|
| 79 |
+
if len(calls) != 1 {
|
| 80 |
+
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
| 81 |
+
}
|
| 82 |
+
if calls[0].Function.Name != "calculate" {
|
| 83 |
+
t.Errorf("expected calculate, got %s", calls[0].Function.Name)
|
| 84 |
+
}
|
| 85 |
+
if clean != "" {
|
| 86 |
+
t.Errorf("expected empty clean content, got %q", clean)
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
func TestExtractPromptToolCalls_WithWhitespace(t *testing.T) {
|
| 91 |
+
content := `<tool_call>
|
| 92 |
+
{"name": "test", "arguments": {}}
|
| 93 |
+
</tool_call>`
|
| 94 |
+
|
| 95 |
+
_, calls := ExtractPromptToolCalls(content)
|
| 96 |
+
if len(calls) != 1 {
|
| 97 |
+
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
| 98 |
+
}
|
| 99 |
+
if calls[0].Function.Name != "test" {
|
| 100 |
+
t.Errorf("expected test, got %s", calls[0].Function.Name)
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
func TestHasPromptToolCallOpen(t *testing.T) {
|
| 105 |
+
tests := []struct {
|
| 106 |
+
content string
|
| 107 |
+
expected bool
|
| 108 |
+
}{
|
| 109 |
+
{"hello", false},
|
| 110 |
+
{"<tool_call>{}", true},
|
| 111 |
+
{"<tool_call>{}</tool_call>", false},
|
| 112 |
+
{"text <tool_call>partial...", true},
|
| 113 |
+
{"<tool_call>a</tool_call><tool_call>b", true},
|
| 114 |
+
{"[TOOL]partial", true},
|
| 115 |
+
{"[TOOL]{\"name\":\"x\"}[/TOOL]", false},
|
| 116 |
+
{"[TOOL_CALL]partial", true},
|
| 117 |
+
{"[TOOL_CALL]{\"name\":\"x\"}[/TOOL_CALL]", false},
|
| 118 |
+
}
|
| 119 |
+
for _, tt := range tests {
|
| 120 |
+
if got := HasPromptToolCallOpen(tt.content); got != tt.expected {
|
| 121 |
+
t.Errorf("HasPromptToolCallOpen(%q) = %v, want %v", tt.content, got, tt.expected)
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// ===== [TOOL]...[/TOOL] 格式 =====
|
| 127 |
+
|
| 128 |
+
func TestExtractPromptToolCalls_AltToolFormat(t *testing.T) {
|
| 129 |
+
content := `[TOOL]{"name": "get_weather", "arguments": {"city": "上海"}}[/TOOL]`
|
| 130 |
+
|
| 131 |
+
clean, calls := ExtractPromptToolCalls(content)
|
| 132 |
+
if len(calls) != 1 {
|
| 133 |
+
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
| 134 |
+
}
|
| 135 |
+
if calls[0].Function.Name != "get_weather" {
|
| 136 |
+
t.Errorf("expected get_weather, got %s", calls[0].Function.Name)
|
| 137 |
+
}
|
| 138 |
+
if calls[0].Type != "function" {
|
| 139 |
+
t.Errorf("expected type function, got %s", calls[0].Type)
|
| 140 |
+
}
|
| 141 |
+
if clean != "" {
|
| 142 |
+
t.Errorf("expected empty clean, got %q", clean)
|
| 143 |
+
}
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
func TestExtractPromptToolCalls_AltToolCallFormat(t *testing.T) {
|
| 147 |
+
content := `好的,我来调用工具。
|
| 148 |
+
[TOOL_CALL]{"name": "create_file", "arguments": {"filename": "test.txt", "content": "hello"}}[/TOOL_CALL]`
|
| 149 |
+
|
| 150 |
+
clean, calls := ExtractPromptToolCalls(content)
|
| 151 |
+
if len(calls) != 1 {
|
| 152 |
+
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
| 153 |
+
}
|
| 154 |
+
if calls[0].Function.Name != "create_file" {
|
| 155 |
+
t.Errorf("expected create_file, got %s", calls[0].Function.Name)
|
| 156 |
+
}
|
| 157 |
+
if !containsStr(clean, "好的") {
|
| 158 |
+
t.Errorf("expected clean to contain surrounding text, got %q", clean)
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// ===== markdown JSON block 格式 =====
|
| 163 |
+
|
| 164 |
+
func TestExtractPromptToolCalls_JsonBlockFormat(t *testing.T) {
|
| 165 |
+
content := "我来调用工具:\n```json\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"北京\"}}\n```\n"
|
| 166 |
+
|
| 167 |
+
clean, calls := ExtractPromptToolCalls(content)
|
| 168 |
+
if len(calls) != 1 {
|
| 169 |
+
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
| 170 |
+
}
|
| 171 |
+
if calls[0].Function.Name != "get_weather" {
|
| 172 |
+
t.Errorf("expected get_weather, got %s", calls[0].Function.Name)
|
| 173 |
+
}
|
| 174 |
+
if containsStr(clean, "```") {
|
| 175 |
+
t.Errorf("expected clean to not contain code block, got %q", clean)
|
| 176 |
+
}
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
// ===== 混合格式 =====
|
| 180 |
+
|
| 181 |
+
func TestExtractPromptToolCalls_MixedFormats(t *testing.T) {
|
| 182 |
+
content := `<tool_call>{"name": "func_a", "arguments": {}}</tool_call>
|
| 183 |
+
[TOOL]{"name": "func_b", "arguments": {}}[/TOOL]`
|
| 184 |
+
|
| 185 |
+
_, calls := ExtractPromptToolCalls(content)
|
| 186 |
+
if len(calls) != 2 {
|
| 187 |
+
t.Fatalf("expected 2 tool calls, got %d", len(calls))
|
| 188 |
+
}
|
| 189 |
+
// <tool_call> 优先被解析
|
| 190 |
+
names := map[string]bool{}
|
| 191 |
+
for _, c := range calls {
|
| 192 |
+
names[c.Function.Name] = true
|
| 193 |
+
}
|
| 194 |
+
if !names["func_a"] || !names["func_b"] {
|
| 195 |
+
t.Error("expected both func_a and func_b to be extracted")
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
// ===== <tool_call> 优先于其他格式 =====
|
| 200 |
+
|
| 201 |
+
func TestExtractPromptToolCalls_ToolCallPriority(t *testing.T) {
|
| 202 |
+
content := `<tool_call>{"name": "correct", "arguments": {}}</tool_call>`
|
| 203 |
+
|
| 204 |
+
_, calls := ExtractPromptToolCalls(content)
|
| 205 |
+
if len(calls) != 1 {
|
| 206 |
+
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
| 207 |
+
}
|
| 208 |
+
if calls[0].Function.Name != "correct" {
|
| 209 |
+
t.Errorf("expected correct, got %s", calls[0].Function.Name)
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
func containsStr(s, substr string) bool {
|
| 214 |
+
return len(s) >= len(substr) && (s == substr || len(s) > 0 && findSubstr(s, substr))
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
func findSubstr(s, substr string) bool {
|
| 218 |
+
for i := 0; i <= len(s)-len(substr); i++ {
|
| 219 |
+
if s[i:i+len(substr)] == substr {
|
| 220 |
+
return true
|
| 221 |
+
}
|
| 222 |
+
}
|
| 223 |
+
return false
|
| 224 |
+
}
|
internal/handler/chat.go
CHANGED
|
@@ -94,6 +94,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 94 |
totalContentOutputLength := 0
|
| 95 |
hasToolCalls := false
|
| 96 |
var collectedToolCalls []model.ToolCall
|
|
|
|
| 97 |
|
| 98 |
for scanner.Scan() {
|
| 99 |
line := scanner.Text()
|
|
@@ -145,7 +146,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 145 |
Model: modelName,
|
| 146 |
Choices: []model.Choice{{
|
| 147 |
Index: 0,
|
| 148 |
-
Delta: model.Delta{ReasoningContent: reasoningContent},
|
| 149 |
FinishReason: nil,
|
| 150 |
}},
|
| 151 |
}
|
|
@@ -182,7 +183,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 182 |
Model: modelName,
|
| 183 |
Choices: []model.Choice{{
|
| 184 |
Index: 0,
|
| 185 |
-
Delta: model.Delta{Content: textBeforeBlock},
|
| 186 |
FinishReason: nil,
|
| 187 |
}},
|
| 188 |
}
|
|
@@ -209,7 +210,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 209 |
Model: modelName,
|
| 210 |
Choices: []model.Choice{{
|
| 211 |
Index: 0,
|
| 212 |
-
Delta: model.Delta{Content: textBeforeBlock},
|
| 213 |
FinishReason: nil,
|
| 214 |
}},
|
| 215 |
}
|
|
@@ -247,7 +248,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 247 |
Model: modelName,
|
| 248 |
Choices: []model.Choice{{
|
| 249 |
Index: 0,
|
| 250 |
-
Delta: model.Delta{
|
| 251 |
ToolCalls: []model.ToolCall{tc},
|
| 252 |
},
|
| 253 |
FinishReason: nil,
|
|
@@ -270,7 +271,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 270 |
Model: modelName,
|
| 271 |
Choices: []model.Choice{{
|
| 272 |
Index: 0,
|
| 273 |
-
Delta: model.Delta{Content: pendingSourcesMarkdown},
|
| 274 |
FinishReason: nil,
|
| 275 |
}},
|
| 276 |
}
|
|
@@ -288,7 +289,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 288 |
Model: modelName,
|
| 289 |
Choices: []model.Choice{{
|
| 290 |
Index: 0,
|
| 291 |
-
Delta: model.Delta{Content: pendingImageSearchMarkdown},
|
| 292 |
FinishReason: nil,
|
| 293 |
}},
|
| 294 |
}
|
|
@@ -313,7 +314,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 313 |
Model: modelName,
|
| 314 |
Choices: []model.Choice{{
|
| 315 |
Index: 0,
|
| 316 |
-
Delta: model.Delta{ReasoningContent: processedRemaining},
|
| 317 |
FinishReason: nil,
|
| 318 |
}},
|
| 319 |
}
|
|
@@ -332,7 +333,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 332 |
Model: modelName,
|
| 333 |
Choices: []model.Choice{{
|
| 334 |
Index: 0,
|
| 335 |
-
Delta: model.Delta{ReasoningContent: pendingSourcesMarkdown},
|
| 336 |
FinishReason: nil,
|
| 337 |
}},
|
| 338 |
}
|
|
@@ -382,7 +383,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 382 |
Model: modelName,
|
| 383 |
Choices: []model.Choice{{
|
| 384 |
Index: 0,
|
| 385 |
-
Delta: model.Delta{ReasoningContent: reasoningContent},
|
| 386 |
FinishReason: nil,
|
| 387 |
}},
|
| 388 |
}
|
|
@@ -405,6 +406,63 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 405 |
totalContentOutputLength += len([]rune(content))
|
| 406 |
}
|
| 407 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
chunk := model.ChatCompletionChunk{
|
| 409 |
ID: completionID,
|
| 410 |
Object: "chat.completion.chunk",
|
|
@@ -412,7 +470,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 412 |
Model: modelName,
|
| 413 |
Choices: []model.Choice{{
|
| 414 |
Index: 0,
|
| 415 |
-
Delta: model.Delta{Content: content},
|
| 416 |
FinishReason: nil,
|
| 417 |
}},
|
| 418 |
}
|
|
@@ -426,6 +484,39 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 426 |
logger.LogError("[Upstream] scanner error: %v", err)
|
| 427 |
}
|
| 428 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
if remaining := searchRefFilter.Flush(); remaining != "" {
|
| 430 |
hasContent = true
|
| 431 |
chunk := model.ChatCompletionChunk{
|
|
@@ -435,7 +526,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 435 |
Model: modelName,
|
| 436 |
Choices: []model.Choice{{
|
| 437 |
Index: 0,
|
| 438 |
-
Delta: model.Delta{Content: remaining},
|
| 439 |
FinishReason: nil,
|
| 440 |
}},
|
| 441 |
}
|
|
@@ -459,7 +550,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 459 |
Model: modelName,
|
| 460 |
Choices: []model.Choice{{
|
| 461 |
Index: 0,
|
| 462 |
-
Delta: model.Delta{},
|
| 463 |
FinishReason: &stopReason,
|
| 464 |
}},
|
| 465 |
}
|
|
@@ -616,6 +707,15 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
|
|
| 616 |
fullReasoning := strings.Join(reasoningChunks, "")
|
| 617 |
fullReasoning = searchRefFilter.Process(fullReasoning) + searchRefFilter.Flush()
|
| 618 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
if fullContent == "" && len(collectedToolCalls) == 0 {
|
| 620 |
logger.LogError("Non-stream response 200 but no content received")
|
| 621 |
}
|
|
@@ -644,3 +744,21 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
|
|
| 644 |
w.Header().Set("Content-Type", "application/json")
|
| 645 |
json.NewEncoder(w).Encode(response)
|
| 646 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
totalContentOutputLength := 0
|
| 95 |
hasToolCalls := false
|
| 96 |
var collectedToolCalls []model.ToolCall
|
| 97 |
+
promptToolBuffer := "" // 用于 prompt 注入模式下缓冲 answer 文本以检测 <tool_call>
|
| 98 |
|
| 99 |
for scanner.Scan() {
|
| 100 |
line := scanner.Text()
|
|
|
|
| 146 |
Model: modelName,
|
| 147 |
Choices: []model.Choice{{
|
| 148 |
Index: 0,
|
| 149 |
+
Delta: &model.Delta{ReasoningContent: reasoningContent},
|
| 150 |
FinishReason: nil,
|
| 151 |
}},
|
| 152 |
}
|
|
|
|
| 183 |
Model: modelName,
|
| 184 |
Choices: []model.Choice{{
|
| 185 |
Index: 0,
|
| 186 |
+
Delta: &model.Delta{Content: textBeforeBlock},
|
| 187 |
FinishReason: nil,
|
| 188 |
}},
|
| 189 |
}
|
|
|
|
| 210 |
Model: modelName,
|
| 211 |
Choices: []model.Choice{{
|
| 212 |
Index: 0,
|
| 213 |
+
Delta: &model.Delta{Content: textBeforeBlock},
|
| 214 |
FinishReason: nil,
|
| 215 |
}},
|
| 216 |
}
|
|
|
|
| 248 |
Model: modelName,
|
| 249 |
Choices: []model.Choice{{
|
| 250 |
Index: 0,
|
| 251 |
+
Delta: &model.Delta{
|
| 252 |
ToolCalls: []model.ToolCall{tc},
|
| 253 |
},
|
| 254 |
FinishReason: nil,
|
|
|
|
| 271 |
Model: modelName,
|
| 272 |
Choices: []model.Choice{{
|
| 273 |
Index: 0,
|
| 274 |
+
Delta: &model.Delta{Content: pendingSourcesMarkdown},
|
| 275 |
FinishReason: nil,
|
| 276 |
}},
|
| 277 |
}
|
|
|
|
| 289 |
Model: modelName,
|
| 290 |
Choices: []model.Choice{{
|
| 291 |
Index: 0,
|
| 292 |
+
Delta: &model.Delta{Content: pendingImageSearchMarkdown},
|
| 293 |
FinishReason: nil,
|
| 294 |
}},
|
| 295 |
}
|
|
|
|
| 314 |
Model: modelName,
|
| 315 |
Choices: []model.Choice{{
|
| 316 |
Index: 0,
|
| 317 |
+
Delta: &model.Delta{ReasoningContent: processedRemaining},
|
| 318 |
FinishReason: nil,
|
| 319 |
}},
|
| 320 |
}
|
|
|
|
| 333 |
Model: modelName,
|
| 334 |
Choices: []model.Choice{{
|
| 335 |
Index: 0,
|
| 336 |
+
Delta: &model.Delta{ReasoningContent: pendingSourcesMarkdown},
|
| 337 |
FinishReason: nil,
|
| 338 |
}},
|
| 339 |
}
|
|
|
|
| 383 |
Model: modelName,
|
| 384 |
Choices: []model.Choice{{
|
| 385 |
Index: 0,
|
| 386 |
+
Delta: &model.Delta{ReasoningContent: reasoningContent},
|
| 387 |
FinishReason: nil,
|
| 388 |
}},
|
| 389 |
}
|
|
|
|
| 406 |
totalContentOutputLength += len([]rune(content))
|
| 407 |
}
|
| 408 |
|
| 409 |
+
// prompt 注入模式:缓冲 answer 文本,检测 <tool_call> 块
|
| 410 |
+
if len(tools) > 0 {
|
| 411 |
+
promptToolBuffer += content
|
| 412 |
+
// 循环提取完整的 <tool_call>...</tool_call> 块
|
| 413 |
+
for {
|
| 414 |
+
openIdx := strings.Index(promptToolBuffer, "<tool_call>")
|
| 415 |
+
if openIdx == -1 {
|
| 416 |
+
// 无 <tool_call> 标签,全部安全输出
|
| 417 |
+
break
|
| 418 |
+
}
|
| 419 |
+
// 输出 <tool_call> 之前的安全文本
|
| 420 |
+
if openIdx > 0 {
|
| 421 |
+
safeContent := promptToolBuffer[:openIdx]
|
| 422 |
+
promptToolBuffer = promptToolBuffer[openIdx:]
|
| 423 |
+
if safeContent != "" {
|
| 424 |
+
sendContentChunk(w, flusher, completionID, modelName, safeContent)
|
| 425 |
+
}
|
| 426 |
+
}
|
| 427 |
+
// 检查是否有完整的闭合标签
|
| 428 |
+
closeIdx := strings.Index(promptToolBuffer, "</tool_call>")
|
| 429 |
+
if closeIdx == -1 {
|
| 430 |
+
// 未闭合,等待更多数据
|
| 431 |
+
break
|
| 432 |
+
}
|
| 433 |
+
// 提取完整块
|
| 434 |
+
blockEnd := closeIdx + len("</tool_call>")
|
| 435 |
+
block := promptToolBuffer[:blockEnd]
|
| 436 |
+
promptToolBuffer = promptToolBuffer[blockEnd:]
|
| 437 |
+
|
| 438 |
+
// 解析 tool call
|
| 439 |
+
_, toolCalls := filter.ExtractPromptToolCalls(block)
|
| 440 |
+
if len(toolCalls) > 0 {
|
| 441 |
+
collectedToolCalls = append(collectedToolCalls, toolCalls...)
|
| 442 |
+
hasToolCalls = true
|
| 443 |
+
for _, tc := range toolCalls {
|
| 444 |
+
chunk := model.ChatCompletionChunk{
|
| 445 |
+
ID: completionID,
|
| 446 |
+
Object: "chat.completion.chunk",
|
| 447 |
+
Created: time.Now().Unix(),
|
| 448 |
+
Model: modelName,
|
| 449 |
+
Choices: []model.Choice{{
|
| 450 |
+
Index: 0,
|
| 451 |
+
Delta: &model.Delta{
|
| 452 |
+
ToolCalls: []model.ToolCall{tc},
|
| 453 |
+
},
|
| 454 |
+
FinishReason: nil,
|
| 455 |
+
}},
|
| 456 |
+
}
|
| 457 |
+
data, _ := json.Marshal(chunk)
|
| 458 |
+
fmt.Fprintf(w, "data: %s\n\n", data)
|
| 459 |
+
flusher.Flush()
|
| 460 |
+
}
|
| 461 |
+
}
|
| 462 |
+
}
|
| 463 |
+
continue
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
chunk := model.ChatCompletionChunk{
|
| 467 |
ID: completionID,
|
| 468 |
Object: "chat.completion.chunk",
|
|
|
|
| 470 |
Model: modelName,
|
| 471 |
Choices: []model.Choice{{
|
| 472 |
Index: 0,
|
| 473 |
+
Delta: &model.Delta{Content: content},
|
| 474 |
FinishReason: nil,
|
| 475 |
}},
|
| 476 |
}
|
|
|
|
| 484 |
logger.LogError("[Upstream] scanner error: %v", err)
|
| 485 |
}
|
| 486 |
|
| 487 |
+
// prompt 注入模式:flush 缓冲区中剩余的文本
|
| 488 |
+
if promptToolBuffer != "" {
|
| 489 |
+
// 尝试最后一次提取 tool calls
|
| 490 |
+
cleanContent, toolCalls := filter.ExtractPromptToolCalls(promptToolBuffer)
|
| 491 |
+
if len(toolCalls) > 0 {
|
| 492 |
+
collectedToolCalls = append(collectedToolCalls, toolCalls...)
|
| 493 |
+
hasToolCalls = true
|
| 494 |
+
for _, tc := range toolCalls {
|
| 495 |
+
chunk := model.ChatCompletionChunk{
|
| 496 |
+
ID: completionID,
|
| 497 |
+
Object: "chat.completion.chunk",
|
| 498 |
+
Created: time.Now().Unix(),
|
| 499 |
+
Model: modelName,
|
| 500 |
+
Choices: []model.Choice{{
|
| 501 |
+
Index: 0,
|
| 502 |
+
Delta: &model.Delta{
|
| 503 |
+
ToolCalls: []model.ToolCall{tc},
|
| 504 |
+
},
|
| 505 |
+
FinishReason: nil,
|
| 506 |
+
}},
|
| 507 |
+
}
|
| 508 |
+
data, _ := json.Marshal(chunk)
|
| 509 |
+
fmt.Fprintf(w, "data: %s\n\n", data)
|
| 510 |
+
flusher.Flush()
|
| 511 |
+
}
|
| 512 |
+
}
|
| 513 |
+
if cleanContent != "" {
|
| 514 |
+
sendContentChunk(w, flusher, completionID, modelName, cleanContent)
|
| 515 |
+
hasContent = true
|
| 516 |
+
}
|
| 517 |
+
promptToolBuffer = ""
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
if remaining := searchRefFilter.Flush(); remaining != "" {
|
| 521 |
hasContent = true
|
| 522 |
chunk := model.ChatCompletionChunk{
|
|
|
|
| 526 |
Model: modelName,
|
| 527 |
Choices: []model.Choice{{
|
| 528 |
Index: 0,
|
| 529 |
+
Delta: &model.Delta{Content: remaining},
|
| 530 |
FinishReason: nil,
|
| 531 |
}},
|
| 532 |
}
|
|
|
|
| 550 |
Model: modelName,
|
| 551 |
Choices: []model.Choice{{
|
| 552 |
Index: 0,
|
| 553 |
+
Delta: &model.Delta{},
|
| 554 |
FinishReason: &stopReason,
|
| 555 |
}},
|
| 556 |
}
|
|
|
|
| 707 |
fullReasoning := strings.Join(reasoningChunks, "")
|
| 708 |
fullReasoning = searchRefFilter.Process(fullReasoning) + searchRefFilter.Flush()
|
| 709 |
|
| 710 |
+
// prompt 注入模式:从 answer 文本中提取 <tool_call> 块
|
| 711 |
+
if len(tools) > 0 && len(collectedToolCalls) == 0 {
|
| 712 |
+
cleanContent, promptToolCalls := filter.ExtractPromptToolCalls(fullContent)
|
| 713 |
+
if len(promptToolCalls) > 0 {
|
| 714 |
+
collectedToolCalls = promptToolCalls
|
| 715 |
+
fullContent = cleanContent
|
| 716 |
+
}
|
| 717 |
+
}
|
| 718 |
+
|
| 719 |
if fullContent == "" && len(collectedToolCalls) == 0 {
|
| 720 |
logger.LogError("Non-stream response 200 but no content received")
|
| 721 |
}
|
|
|
|
| 744 |
w.Header().Set("Content-Type", "application/json")
|
| 745 |
json.NewEncoder(w).Encode(response)
|
| 746 |
}
|
| 747 |
+
|
| 748 |
+
// sendContentChunk 发送一个 content SSE chunk
|
| 749 |
+
func sendContentChunk(w http.ResponseWriter, flusher http.Flusher, completionID, modelName, content string) {
|
| 750 |
+
chunk := model.ChatCompletionChunk{
|
| 751 |
+
ID: completionID,
|
| 752 |
+
Object: "chat.completion.chunk",
|
| 753 |
+
Created: time.Now().Unix(),
|
| 754 |
+
Model: modelName,
|
| 755 |
+
Choices: []model.Choice{{
|
| 756 |
+
Index: 0,
|
| 757 |
+
Delta: &model.Delta{Content: content},
|
| 758 |
+
FinishReason: nil,
|
| 759 |
+
}},
|
| 760 |
+
}
|
| 761 |
+
data, _ := json.Marshal(chunk)
|
| 762 |
+
fmt.Fprintf(w, "data: %s\n\n", data)
|
| 763 |
+
flusher.Flush()
|
| 764 |
+
}
|
internal/handler/chat_test.go
CHANGED
|
@@ -574,3 +574,76 @@ func TestNonStreamResponse_FullFormat(t *testing.T) {
|
|
| 574 |
t.Errorf("Role = %q", resp.Choices[0].Message.Role)
|
| 575 |
}
|
| 576 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
t.Errorf("Role = %q", resp.Choices[0].Message.Role)
|
| 575 |
}
|
| 576 |
}
|
| 577 |
+
|
| 578 |
+
// ===== 流式:prompt 注入模式 <tool_call> 在 answer 文本中 =====
|
| 579 |
+
|
| 580 |
+
func TestStreamResponse_PromptInjectionToolCall(t *testing.T) {
|
| 581 |
+
body := newFakeBody(
|
| 582 |
+
sseEvent("answer", "好的,我来查询。\n", ""),
|
| 583 |
+
sseEvent("answer", `<tool_call>{"name":"get_weather","arguments":{"city":"北京"}}</tool_call>`, ""),
|
| 584 |
+
sseEventDone(),
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
w := httptest.NewRecorder()
|
| 588 |
+
handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 589 |
+
|
| 590 |
+
result := w.Body.String()
|
| 591 |
+
|
| 592 |
+
if !strings.Contains(result, `"tool_calls"`) {
|
| 593 |
+
t.Error("missing tool_calls in prompt injection stream")
|
| 594 |
+
}
|
| 595 |
+
if !strings.Contains(result, `"get_weather"`) {
|
| 596 |
+
t.Error("missing function name")
|
| 597 |
+
}
|
| 598 |
+
if !strings.Contains(result, `"finish_reason":"tool_calls"`) {
|
| 599 |
+
t.Error("finish_reason should be tool_calls")
|
| 600 |
+
}
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
// ===== 非流式:prompt 注入模式 =====
|
| 604 |
+
|
| 605 |
+
func TestNonStreamResponse_PromptInjectionToolCall(t *testing.T) {
|
| 606 |
+
body := newFakeBody(
|
| 607 |
+
sseEvent("answer", "我来查询天气。\n<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"city\":\"上海\"}}</tool_call>", ""),
|
| 608 |
+
sseEventDone(),
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
w := httptest.NewRecorder()
|
| 612 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 613 |
+
|
| 614 |
+
var resp model.ChatCompletionResponse
|
| 615 |
+
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
| 616 |
+
t.Fatalf("decode: %v", err)
|
| 617 |
+
}
|
| 618 |
+
|
| 619 |
+
msg := resp.Choices[0].Message
|
| 620 |
+
if len(msg.ToolCalls) != 1 {
|
| 621 |
+
t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
|
| 622 |
+
}
|
| 623 |
+
if msg.ToolCalls[0].Function.Name != "get_weather" {
|
| 624 |
+
t.Errorf("Function.Name = %q", msg.ToolCalls[0].Function.Name)
|
| 625 |
+
}
|
| 626 |
+
if strings.Contains(msg.Content, "<tool_call>") {
|
| 627 |
+
t.Error("content should not contain <tool_call> tags")
|
| 628 |
+
}
|
| 629 |
+
if *resp.Choices[0].FinishReason != "tool_calls" {
|
| 630 |
+
t.Errorf("FinishReason = %q, want tool_calls", *resp.Choices[0].FinishReason)
|
| 631 |
+
}
|
| 632 |
+
}
|
| 633 |
+
|
| 634 |
+
// ===== 非流式:response 中不应有 delta 字段 =====
|
| 635 |
+
|
| 636 |
+
func TestNonStreamResponse_NoDeltaField(t *testing.T) {
|
| 637 |
+
body := newFakeBody(
|
| 638 |
+
sseEvent("answer", "hello", ""),
|
| 639 |
+
sseEventDone(),
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
w := httptest.NewRecorder()
|
| 643 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
|
| 644 |
+
|
| 645 |
+
result := w.Body.String()
|
| 646 |
+
if strings.Contains(result, `"delta"`) {
|
| 647 |
+
t.Error("non-streaming response should not contain delta field")
|
| 648 |
+
}
|
| 649 |
+
}
|
internal/model/types.go
CHANGED
|
@@ -160,7 +160,7 @@ type ChatCompletionChunk struct {
|
|
| 160 |
|
| 161 |
type Choice struct {
|
| 162 |
Index int `json:"index"`
|
| 163 |
-
Delta Delta
|
| 164 |
Message *MessageResp `json:"message,omitempty"`
|
| 165 |
FinishReason *string `json:"finish_reason"`
|
| 166 |
}
|
|
|
|
| 160 |
|
| 161 |
type Choice struct {
|
| 162 |
Index int `json:"index"`
|
| 163 |
+
Delta *Delta `json:"delta,omitempty"`
|
| 164 |
Message *MessageResp `json:"message,omitempty"`
|
| 165 |
FinishReason *string `json:"finish_reason"`
|
| 166 |
}
|
internal/model/types_test.go
CHANGED
|
@@ -2,6 +2,7 @@ package model
|
|
| 2 |
|
| 3 |
import (
|
| 4 |
"encoding/json"
|
|
|
|
| 5 |
"testing"
|
| 6 |
)
|
| 7 |
|
|
@@ -423,7 +424,7 @@ func TestChunkWithToolCallsFinishReason(t *testing.T) {
|
|
| 423 |
Model: "glm-4.7",
|
| 424 |
Choices: []Choice{{
|
| 425 |
Index: 0,
|
| 426 |
-
Delta: Delta{},
|
| 427 |
FinishReason: &reason,
|
| 428 |
}},
|
| 429 |
}
|
|
@@ -501,3 +502,41 @@ func TestCompletionResponseWithToolCalls(t *testing.T) {
|
|
| 501 |
t.Errorf("FinishReason = %q", *decoded.Choices[0].FinishReason)
|
| 502 |
}
|
| 503 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import (
|
| 4 |
"encoding/json"
|
| 5 |
+
"strings"
|
| 6 |
"testing"
|
| 7 |
)
|
| 8 |
|
|
|
|
| 424 |
Model: "glm-4.7",
|
| 425 |
Choices: []Choice{{
|
| 426 |
Index: 0,
|
| 427 |
+
Delta: &Delta{},
|
| 428 |
FinishReason: &reason,
|
| 429 |
}},
|
| 430 |
}
|
|
|
|
| 502 |
t.Errorf("FinishReason = %q", *decoded.Choices[0].FinishReason)
|
| 503 |
}
|
| 504 |
}
|
| 505 |
+
|
| 506 |
+
// ===== Delta 指针:nil 时不出现在 JSON 中 =====
|
| 507 |
+
|
| 508 |
+
func TestChoiceDeltaNil_OmittedInJSON(t *testing.T) {
|
| 509 |
+
reason := "stop"
|
| 510 |
+
choice := Choice{
|
| 511 |
+
Index: 0,
|
| 512 |
+
Message: &MessageResp{
|
| 513 |
+
Role: "assistant",
|
| 514 |
+
Content: "hello",
|
| 515 |
+
},
|
| 516 |
+
FinishReason: &reason,
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
data, _ := json.Marshal(choice)
|
| 520 |
+
s := string(data)
|
| 521 |
+
if strings.Contains(s, `"delta"`) {
|
| 522 |
+
t.Errorf("nil Delta should be omitted, got: %s", s)
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
// ===== Delta 指针:非 nil 时正常序列化 =====
|
| 527 |
+
|
| 528 |
+
func TestChoiceDeltaNotNil_SerializedInJSON(t *testing.T) {
|
| 529 |
+
choice := Choice{
|
| 530 |
+
Index: 0,
|
| 531 |
+
Delta: &Delta{Content: "test content"},
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
data, _ := json.Marshal(choice)
|
| 535 |
+
s := string(data)
|
| 536 |
+
if !strings.Contains(s, `"delta"`) {
|
| 537 |
+
t.Error("non-nil Delta should appear in JSON")
|
| 538 |
+
}
|
| 539 |
+
if !strings.Contains(s, `"test content"`) {
|
| 540 |
+
t.Error("Delta content should be serialized")
|
| 541 |
+
}
|
| 542 |
+
}
|
internal/tools/prompt.go
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package tools
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"fmt"
|
| 6 |
+
"strings"
|
| 7 |
+
|
| 8 |
+
"zai-proxy/internal/model"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
// BuildToolSystemPrompt 将工具定义列表转换为 system prompt 文本,
|
| 12 |
+
// 指示模型使用 <tool_call> 格式输出工具调用。
|
| 13 |
+
func BuildToolSystemPrompt(tools []model.Tool, toolChoice interface{}) string {
|
| 14 |
+
if len(tools) == 0 {
|
| 15 |
+
return ""
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
var sb strings.Builder
|
| 19 |
+
|
| 20 |
+
sb.WriteString("# 工具调用规则\n\n")
|
| 21 |
+
sb.WriteString("你可以使用下面列出的工具。当你需要调用工具时,**必须严格使用以下 XML 格式**输出调用请求(不要使用 markdown 代码块、不要使用 [TOOL] 或其他格式):\n\n")
|
| 22 |
+
sb.WriteString("<tool_call>{\"name\": \"函数名\", \"arguments\": {\"参数名\": \"参数值\"}}</tool_call>\n\n")
|
| 23 |
+
sb.WriteString("**重要规则:**\n")
|
| 24 |
+
sb.WriteString("- 你不能自行执行工具,只能输出 <tool_call> 标签,由系统执行后将结果返回给你\n")
|
| 25 |
+
sb.WriteString("- 每个工具调用必须独立包裹在 <tool_call></tool_call> 标签中\n")
|
| 26 |
+
sb.WriteString("- arguments 必须是合法 JSON 对象\n")
|
| 27 |
+
sb.WriteString("- 不要在 <tool_call> 标签外描述调用参数\n\n")
|
| 28 |
+
|
| 29 |
+
sb.WriteString("## 示例\n\n")
|
| 30 |
+
sb.WriteString("用户: 帮我创建一个文件 test.txt 内容为 hello\n")
|
| 31 |
+
sb.WriteString("助手: 好的,我来为您创建文件。\n")
|
| 32 |
+
sb.WriteString("<tool_call>{\"name\": \"create_file\", \"arguments\": {\"filename\": \"test.txt\", \"content\": \"hello\"}}</tool_call>\n\n")
|
| 33 |
+
sb.WriteString("用户: 查询北京和上海的天气\n")
|
| 34 |
+
sb.WriteString("助手: 我来查询这两个城市的天气。\n")
|
| 35 |
+
sb.WriteString("<tool_call>{\"name\": \"get_weather\", \"arguments\": {\"location\": \"北京\"}}</tool_call>\n")
|
| 36 |
+
sb.WriteString("<tool_call>{\"name\": \"get_weather\", \"arguments\": {\"location\": \"上海\"}}</tool_call>\n\n")
|
| 37 |
+
|
| 38 |
+
sb.WriteString("## 可用工具\n\n")
|
| 39 |
+
|
| 40 |
+
for _, tool := range tools {
|
| 41 |
+
sb.WriteString(fmt.Sprintf("### %s\n", tool.Function.Name))
|
| 42 |
+
if tool.Function.Description != "" {
|
| 43 |
+
sb.WriteString(fmt.Sprintf("%s\n", tool.Function.Description))
|
| 44 |
+
}
|
| 45 |
+
if tool.Function.Parameters != nil {
|
| 46 |
+
params, err := json.Marshal(tool.Function.Parameters)
|
| 47 |
+
if err == nil {
|
| 48 |
+
sb.WriteString(fmt.Sprintf("Parameters: %s\n", string(params)))
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
sb.WriteString("\n")
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
// 处理 tool_choice
|
| 55 |
+
if toolChoice != nil {
|
| 56 |
+
switch tc := toolChoice.(type) {
|
| 57 |
+
case string:
|
| 58 |
+
switch tc {
|
| 59 |
+
case "none":
|
| 60 |
+
sb.WriteString("**禁止调用任何工具,直接回答问题。**\n")
|
| 61 |
+
case "required":
|
| 62 |
+
sb.WriteString("**你的回复中必须包含至少一个 <tool_call> 标签。即使你认为不需要调用工具,也必须调用。**\n")
|
| 63 |
+
// "auto" is the default, no special instruction needed
|
| 64 |
+
}
|
| 65 |
+
case map[string]interface{}:
|
| 66 |
+
// tool_choice = {"type": "function", "function": {"name": "xxx"}}
|
| 67 |
+
if fn, ok := tc["function"].(map[string]interface{}); ok {
|
| 68 |
+
if name, ok := fn["name"].(string); ok {
|
| 69 |
+
sb.WriteString(fmt.Sprintf("**你必须调用工具 \"%s\",使用 <tool_call> 标签输出调用。**\n", name))
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
return sb.String()
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
// ConvertToolCallToText 将 assistant 消息中的 tool_calls 转换为 <tool_call> 文本格式,
|
| 79 |
+
// 用于在 prompt 注入模式下将历史 tool_calls 传给上游。
|
| 80 |
+
func ConvertToolCallToText(toolCalls []model.ToolCall) string {
|
| 81 |
+
var parts []string
|
| 82 |
+
for _, tc := range toolCalls {
|
| 83 |
+
callJSON, _ := json.Marshal(map[string]interface{}{
|
| 84 |
+
"name": tc.Function.Name,
|
| 85 |
+
"arguments": json.RawMessage(tc.Function.Arguments),
|
| 86 |
+
})
|
| 87 |
+
parts = append(parts, fmt.Sprintf("<tool_call>%s</tool_call>", string(callJSON)))
|
| 88 |
+
}
|
| 89 |
+
return strings.Join(parts, "\n")
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
// ConvertToolResultToText 将 tool 角色的消息转换为文本格式,
|
| 93 |
+
// 用于在 prompt 注入模式下传递工具执行结果。
|
| 94 |
+
func ConvertToolResultToText(toolCallID string, content string) string {
|
| 95 |
+
return fmt.Sprintf("<tool_result call_id=\"%s\">%s</tool_result>", toolCallID, content)
|
| 96 |
+
}
|
internal/tools/prompt_test.go
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package tools
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"strings"
|
| 5 |
+
"testing"
|
| 6 |
+
|
| 7 |
+
"zai-proxy/internal/model"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
func TestBuildToolSystemPrompt_Basic(t *testing.T) {
|
| 11 |
+
tools := []model.Tool{
|
| 12 |
+
{
|
| 13 |
+
Type: "function",
|
| 14 |
+
Function: model.ToolFunction{
|
| 15 |
+
Name: "get_weather",
|
| 16 |
+
Description: "Get current weather",
|
| 17 |
+
Parameters: map[string]interface{}{
|
| 18 |
+
"type": "object",
|
| 19 |
+
"properties": map[string]interface{}{
|
| 20 |
+
"city": map[string]interface{}{
|
| 21 |
+
"type": "string",
|
| 22 |
+
"description": "City name",
|
| 23 |
+
},
|
| 24 |
+
},
|
| 25 |
+
"required": []string{"city"},
|
| 26 |
+
},
|
| 27 |
+
},
|
| 28 |
+
},
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
result := BuildToolSystemPrompt(tools, nil)
|
| 32 |
+
|
| 33 |
+
if !strings.Contains(result, "get_weather") {
|
| 34 |
+
t.Error("should contain tool name")
|
| 35 |
+
}
|
| 36 |
+
if !strings.Contains(result, "Get current weather") {
|
| 37 |
+
t.Error("should contain description")
|
| 38 |
+
}
|
| 39 |
+
if !strings.Contains(result, "<tool_call>") {
|
| 40 |
+
t.Error("should contain format instruction")
|
| 41 |
+
}
|
| 42 |
+
if !strings.Contains(result, "city") {
|
| 43 |
+
t.Error("should contain parameter info")
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
func TestBuildToolSystemPrompt_Empty(t *testing.T) {
|
| 48 |
+
result := BuildToolSystemPrompt(nil, nil)
|
| 49 |
+
if result != "" {
|
| 50 |
+
t.Error("should return empty for nil tools")
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
func TestBuildToolSystemPrompt_ToolChoiceNone(t *testing.T) {
|
| 55 |
+
tools := []model.Tool{{
|
| 56 |
+
Type: "function",
|
| 57 |
+
Function: model.ToolFunction{Name: "test"},
|
| 58 |
+
}}
|
| 59 |
+
|
| 60 |
+
result := BuildToolSystemPrompt(tools, "none")
|
| 61 |
+
if !strings.Contains(result, "禁止调用任何工具") {
|
| 62 |
+
t.Error("should instruct not to call tools")
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
func TestBuildToolSystemPrompt_ToolChoiceRequired(t *testing.T) {
|
| 67 |
+
tools := []model.Tool{{
|
| 68 |
+
Type: "function",
|
| 69 |
+
Function: model.ToolFunction{Name: "test"},
|
| 70 |
+
}}
|
| 71 |
+
|
| 72 |
+
result := BuildToolSystemPrompt(tools, "required")
|
| 73 |
+
if !strings.Contains(result, "必须包含至少一个") {
|
| 74 |
+
t.Error("should instruct to call at least one tool")
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
func TestBuildToolSystemPrompt_ToolChoiceSpecific(t *testing.T) {
|
| 79 |
+
tools := []model.Tool{{
|
| 80 |
+
Type: "function",
|
| 81 |
+
Function: model.ToolFunction{Name: "get_weather"},
|
| 82 |
+
}}
|
| 83 |
+
|
| 84 |
+
choice := map[string]interface{}{
|
| 85 |
+
"type": "function",
|
| 86 |
+
"function": map[string]interface{}{
|
| 87 |
+
"name": "get_weather",
|
| 88 |
+
},
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
result := BuildToolSystemPrompt(tools, choice)
|
| 92 |
+
if !strings.Contains(result, `必须调用工具 "get_weather"`) {
|
| 93 |
+
t.Error("should instruct to call specific tool")
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
func TestConvertToolCallToText(t *testing.T) {
|
| 98 |
+
toolCalls := []model.ToolCall{
|
| 99 |
+
{
|
| 100 |
+
ID: "call_123",
|
| 101 |
+
Type: "function",
|
| 102 |
+
Function: model.FunctionCall{
|
| 103 |
+
Name: "get_weather",
|
| 104 |
+
Arguments: `{"city":"Beijing"}`,
|
| 105 |
+
},
|
| 106 |
+
},
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
result := ConvertToolCallToText(toolCalls)
|
| 110 |
+
if !strings.Contains(result, "<tool_call>") {
|
| 111 |
+
t.Error("should contain <tool_call> tag")
|
| 112 |
+
}
|
| 113 |
+
if !strings.Contains(result, "get_weather") {
|
| 114 |
+
t.Error("should contain function name")
|
| 115 |
+
}
|
| 116 |
+
if !strings.Contains(result, "Beijing") {
|
| 117 |
+
t.Error("should contain arguments")
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
func TestConvertToolResultToText(t *testing.T) {
|
| 122 |
+
result := ConvertToolResultToText("call_123", `{"temp": 25}`)
|
| 123 |
+
if !strings.Contains(result, "call_123") {
|
| 124 |
+
t.Error("should contain call ID")
|
| 125 |
+
}
|
| 126 |
+
if !strings.Contains(result, `{"temp": 25}`) {
|
| 127 |
+
t.Error("should contain result content")
|
| 128 |
+
}
|
| 129 |
+
if !strings.Contains(result, "<tool_result") {
|
| 130 |
+
t.Error("should contain <tool_result> tag")
|
| 131 |
+
}
|
| 132 |
+
}
|
internal/upstream/client.go
CHANGED
|
@@ -94,11 +94,62 @@ func MakeUpstreamRequest(token string, messages []model.Message, modelName strin
|
|
| 94 |
}
|
| 95 |
}
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
var upstreamMessages []map[string]interface{}
|
|
|
|
| 98 |
for _, msg := range messages {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
upstreamMessages = append(upstreamMessages, msg.ToUpstreamMessage(urlToFileID))
|
| 100 |
}
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
body := map[string]interface{}{
|
| 103 |
"stream": true,
|
| 104 |
"model": targetModel,
|
|
@@ -120,26 +171,6 @@ func MakeUpstreamRequest(token string, messages []model.Message, modelName strin
|
|
| 120 |
body["mcp_servers"] = mcpServers
|
| 121 |
}
|
| 122 |
|
| 123 |
-
// 当使用 -tools 模型时,自动注入内置工具(客户端自带工具优先)
|
| 124 |
-
if model.IsToolsModel(modelName) {
|
| 125 |
-
clientToolNames := make(map[string]bool)
|
| 126 |
-
for _, t := range tools {
|
| 127 |
-
clientToolNames[t.Function.Name] = true
|
| 128 |
-
}
|
| 129 |
-
for _, bt := range builtintools.GetBuiltinTools() {
|
| 130 |
-
if !clientToolNames[bt.Function.Name] {
|
| 131 |
-
tools = append(tools, bt)
|
| 132 |
-
}
|
| 133 |
-
}
|
| 134 |
-
}
|
| 135 |
-
|
| 136 |
-
if len(tools) > 0 {
|
| 137 |
-
body["tools"] = tools
|
| 138 |
-
if toolChoice != nil {
|
| 139 |
-
body["tool_choice"] = toolChoice
|
| 140 |
-
}
|
| 141 |
-
}
|
| 142 |
-
|
| 143 |
if len(filesData) > 0 {
|
| 144 |
body["files"] = filesData
|
| 145 |
body["current_user_message_id"] = userMsgID
|
|
|
|
| 94 |
}
|
| 95 |
}
|
| 96 |
|
| 97 |
+
// 当使用 -tools 模型时,自动注入内置工具(客户端自带工具优先)
|
| 98 |
+
if model.IsToolsModel(modelName) {
|
| 99 |
+
clientToolNames := make(map[string]bool)
|
| 100 |
+
for _, t := range tools {
|
| 101 |
+
clientToolNames[t.Function.Name] = true
|
| 102 |
+
}
|
| 103 |
+
for _, bt := range builtintools.GetBuiltinTools() {
|
| 104 |
+
if !clientToolNames[bt.Function.Name] {
|
| 105 |
+
tools = append(tools, bt)
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
var upstreamMessages []map[string]interface{}
|
| 111 |
+
hasPromptTools := len(tools) > 0
|
| 112 |
for _, msg := range messages {
|
| 113 |
+
if hasPromptTools {
|
| 114 |
+
// prompt 注入模式:将 tool_calls / tool 结果转为纯文本
|
| 115 |
+
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
| 116 |
+
text, _ := msg.ParseContent()
|
| 117 |
+
callText := builtintools.ConvertToolCallToText(msg.ToolCalls)
|
| 118 |
+
if text != "" {
|
| 119 |
+
text = text + "\n" + callText
|
| 120 |
+
} else {
|
| 121 |
+
text = callText
|
| 122 |
+
}
|
| 123 |
+
upstreamMessages = append(upstreamMessages, map[string]interface{}{
|
| 124 |
+
"role": "assistant",
|
| 125 |
+
"content": text,
|
| 126 |
+
})
|
| 127 |
+
continue
|
| 128 |
+
}
|
| 129 |
+
if msg.Role == "tool" {
|
| 130 |
+
text, _ := msg.ParseContent()
|
| 131 |
+
upstreamMessages = append(upstreamMessages, map[string]interface{}{
|
| 132 |
+
"role": "user",
|
| 133 |
+
"content": builtintools.ConvertToolResultToText(msg.ToolCallID, text),
|
| 134 |
+
})
|
| 135 |
+
continue
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
upstreamMessages = append(upstreamMessages, msg.ToUpstreamMessage(urlToFileID))
|
| 139 |
}
|
| 140 |
|
| 141 |
+
// 工具注入:通过 system prompt 注入工具定义(z.ai 不支持原生 tools 字段)
|
| 142 |
+
if len(tools) > 0 {
|
| 143 |
+
toolSystemPrompt := builtintools.BuildToolSystemPrompt(tools, toolChoice)
|
| 144 |
+
if toolSystemPrompt != "" {
|
| 145 |
+
systemMsg := map[string]interface{}{
|
| 146 |
+
"role": "system",
|
| 147 |
+
"content": toolSystemPrompt,
|
| 148 |
+
}
|
| 149 |
+
upstreamMessages = append([]map[string]interface{}{systemMsg}, upstreamMessages...)
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
body := map[string]interface{}{
|
| 154 |
"stream": true,
|
| 155 |
"model": targetModel,
|
|
|
|
| 171 |
body["mcp_servers"] = mcpServers
|
| 172 |
}
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
if len(filesData) > 0 {
|
| 175 |
body["files"] = filesData
|
| 176 |
body["current_user_message_id"] = userMsgID
|