ss22345 committed on
Commit
8646505
·
1 Parent(s): 5a55e77

fix: improve tool calling reliability with multi-format parsing and Delta pointer fix

Browse files

- Enhance system prompt with Chinese instructions and few-shot examples for
more reliable <tool_call> output from GLM models
- Add fallback parsing for [TOOL]...[/TOOL], [TOOL_CALL]...[/TOOL_CALL],
and markdown JSON block formats in prompt injection mode
- Change Choice.Delta to *Delta pointer so omitempty correctly omits the
field in non-streaming responses (fixes extra "delta":{} in JSON)
- Convert tool_calls/tool messages to plain text for upstream z.ai API
which doesn't support native tools field
- Add comprehensive tests for all new parsing formats and serialization

internal/filter/prompttool.go ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package filter
2
+
3
+ import (
4
+ "encoding/json"
5
+ "fmt"
6
+ "regexp"
7
+ "strings"
8
+
9
+ "github.com/google/uuid"
10
+
11
+ "zai-proxy/internal/model"
12
+ )
13
+
14
// Regular expressions recognising the tool-call notations a model may emit.
// Each pattern has exactly one capture group holding the JSON payload.
var (
	// promptToolCallPattern matches <tool_call>...</tool_call> blocks.
	promptToolCallPattern = regexp.MustCompile(`<tool_call>\s*([\s\S]*?)\s*</tool_call>`)

	// altToolCallPattern matches [TOOL]...[/TOOL] and
	// [TOOL_CALL]...[/TOOL_CALL] blocks.
	altToolCallPattern = regexp.MustCompile(`\[TOOL(?:_CALL)?\]\s*([\s\S]*?)\s*\[/TOOL(?:_CALL)?\]`)

	// jsonBlockPattern matches a markdown ```json fenced block whose object
	// contains a "name" key.
	jsonBlockPattern = regexp.MustCompile("```json\\s*\\n(\\{[\\s\\S]*?\"name\"[\\s\\S]*?\\})\\s*\\n```")

	// allToolCallPatterns lists every tool-call pattern in priority order.
	allToolCallPatterns = []*regexp.Regexp{
		promptToolCallPattern, // <tool_call> has the highest priority
		altToolCallPattern,    // [TOOL] / [TOOL_CALL]
		jsonBlockPattern,      // ```json ... ```
	}
)
29
+
30
+ // ExtractPromptToolCalls 从文本中提取所有 tool call 块(支持多种格式),
31
+ // 返回清理后的文本和解析出的 tool calls。
32
+ func ExtractPromptToolCalls(content string) (cleanContent string, toolCalls []model.ToolCall) {
33
+ var allCalls []model.ToolCall
34
+ cleaned := content
35
+
36
+ // 按优先级依次尝试各种格式
37
+ for _, pattern := range allToolCallPatterns {
38
+ matches := pattern.FindAllStringSubmatchIndex(cleaned, -1)
39
+ if len(matches) == 0 {
40
+ continue
41
+ }
42
+
43
+ // 从后向前移除匹配块,避免索引偏移
44
+ for i := len(matches) - 1; i >= 0; i-- {
45
+ match := matches[i]
46
+ fullStart, fullEnd := match[0], match[1]
47
+ groupStart, groupEnd := match[2], match[3]
48
+
49
+ jsonStr := cleaned[groupStart:groupEnd]
50
+ if calls := parsePromptToolCallJSON(jsonStr); len(calls) > 0 {
51
+ allCalls = append(calls, allCalls...)
52
+ }
53
+
54
+ cleaned = cleaned[:fullStart] + cleaned[fullEnd:]
55
+ }
56
+ }
57
+
58
+ if len(allCalls) == 0 {
59
+ return content, nil
60
+ }
61
+
62
+ // 清理多余空行
63
+ cleaned = strings.TrimSpace(cleaned)
64
+ for strings.Contains(cleaned, "\n\n\n") {
65
+ cleaned = strings.ReplaceAll(cleaned, "\n\n\n", "\n\n")
66
+ }
67
+
68
+ // 为每个 tool call 分配 ID
69
+ for i := range allCalls {
70
+ if allCalls[i].ID == "" {
71
+ allCalls[i].ID = fmt.Sprintf("call_%s", uuid.New().String()[:24])
72
+ }
73
+ allCalls[i].Index = i
74
+ allCalls[i].Type = "function"
75
+ }
76
+
77
+ return cleaned, allCalls
78
+ }
79
+
80
+ // parsePromptToolCallJSON 解析 <tool_call> 内的 JSON
81
+ func parsePromptToolCallJSON(content string) []model.ToolCall {
82
+ content = strings.TrimSpace(content)
83
+ if content == "" {
84
+ return nil
85
+ }
86
+
87
+ // 标准格式: {"name": "xxx", "arguments": {...}}
88
+ var call struct {
89
+ Name string `json:"name"`
90
+ Arguments json.RawMessage `json:"arguments"`
91
+ }
92
+ if err := json.Unmarshal([]byte(content), &call); err == nil && call.Name != "" {
93
+ argsStr := string(call.Arguments)
94
+ // 如果 arguments 不是字符串,序列化为字符串
95
+ if len(argsStr) > 0 && argsStr[0] != '"' {
96
+ // 已经是 JSON 对象/其他类型,直接用
97
+ } else {
98
+ // 是 JSON 字符串,解引用
99
+ var s string
100
+ if json.Unmarshal(call.Arguments, &s) == nil {
101
+ argsStr = s
102
+ }
103
+ }
104
+ return []model.ToolCall{{
105
+ Function: model.FunctionCall{
106
+ Name: call.Name,
107
+ Arguments: argsStr,
108
+ },
109
+ }}
110
+ }
111
+
112
+ return nil
113
+ }
114
+
115
// HasPromptToolCallOpen reports whether content ends inside an unclosed
// tool-call block, i.e. whether the last opening tag of a format appears after
// the last closing tag of that format.
//
// Two formats are checked: <tool_call>...</tool_call>, and the bracket format
// in which either of [TOOL] / [TOOL_CALL] may be closed by either of
// [/TOOL] / [/TOOL_CALL] (mirroring what altToolCallPattern accepts).
// Comparing positions rather than tag counts means a stray closing tag earlier
// in the text cannot hide a later block that is still open.
func HasPromptToolCallOpen(content string) bool {
	// lastTagIndex returns the highest index at which any tag occurs, or -1.
	lastTagIndex := func(tags ...string) int {
		last := -1
		for _, tag := range tags {
			if i := strings.LastIndex(content, tag); i > last {
				last = i
			}
		}
		return last
	}

	// <tool_call>
	if lastTagIndex("<tool_call>") > lastTagIndex("</tool_call>") {
		return true
	}
	// [TOOL] / [TOOL_CALL]
	return lastTagIndex("[TOOL]", "[TOOL_CALL]") > lastTagIndex("[/TOOL]", "[/TOOL_CALL]")
}
internal/filter/prompttool_test.go ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package filter
2
+
3
+ import (
4
+ "testing"
5
+ )
6
+
7
+ func TestExtractPromptToolCalls_NoToolCall(t *testing.T) {
8
+ content := "Hello, this is a normal response."
9
+ clean, calls := ExtractPromptToolCalls(content)
10
+ if clean != content {
11
+ t.Errorf("expected content unchanged, got %q", clean)
12
+ }
13
+ if len(calls) != 0 {
14
+ t.Error("expected no tool calls")
15
+ }
16
+ }
17
+
18
+ func TestExtractPromptToolCalls_SingleCall(t *testing.T) {
19
+ content := `Here is the result:
20
+ <tool_call>{"name": "get_weather", "arguments": {"city": "Beijing"}}</tool_call>
21
+ Done.`
22
+
23
+ clean, calls := ExtractPromptToolCalls(content)
24
+
25
+ if len(calls) != 1 {
26
+ t.Fatalf("expected 1 tool call, got %d", len(calls))
27
+ }
28
+ if calls[0].Function.Name != "get_weather" {
29
+ t.Errorf("expected name get_weather, got %s", calls[0].Function.Name)
30
+ }
31
+ if calls[0].Function.Arguments != `{"city":"Beijing"}` && calls[0].Function.Arguments != `{"city": "Beijing"}` {
32
+ t.Errorf("unexpected arguments: %s", calls[0].Function.Arguments)
33
+ }
34
+ if calls[0].ID == "" {
35
+ t.Error("expected auto-generated ID")
36
+ }
37
+ if calls[0].Type != "function" {
38
+ t.Errorf("expected type function, got %s", calls[0].Type)
39
+ }
40
+ // Clean content should not contain tool_call tags
41
+ if clean == content {
42
+ t.Error("expected content to be cleaned")
43
+ }
44
+ if contains := "Here is the result:"; !containsStr(clean, contains) {
45
+ t.Errorf("expected clean content to contain %q", contains)
46
+ }
47
+ if containsStr(clean, "<tool_call>") {
48
+ t.Error("clean content should not contain <tool_call>")
49
+ }
50
+ }
51
+
52
+ func TestExtractPromptToolCalls_MultipleCalls(t *testing.T) {
53
+ content := `<tool_call>{"name": "func_a", "arguments": {"x": 1}}</tool_call>
54
+ <tool_call>{"name": "func_b", "arguments": {"y": 2}}</tool_call>`
55
+
56
+ clean, calls := ExtractPromptToolCalls(content)
57
+
58
+ if len(calls) != 2 {
59
+ t.Fatalf("expected 2 tool calls, got %d", len(calls))
60
+ }
61
+ if calls[0].Function.Name != "func_a" {
62
+ t.Errorf("expected first call func_a, got %s", calls[0].Function.Name)
63
+ }
64
+ if calls[1].Function.Name != "func_b" {
65
+ t.Errorf("expected second call func_b, got %s", calls[1].Function.Name)
66
+ }
67
+ if calls[0].Index != 0 || calls[1].Index != 1 {
68
+ t.Error("expected sequential indices")
69
+ }
70
+ if clean != "" {
71
+ t.Errorf("expected empty clean content, got %q", clean)
72
+ }
73
+ }
74
+
75
+ func TestExtractPromptToolCalls_OnlyToolCall(t *testing.T) {
76
+ content := `<tool_call>{"name": "calculate", "arguments": {"expression": "2+2"}}</tool_call>`
77
+
78
+ clean, calls := ExtractPromptToolCalls(content)
79
+ if len(calls) != 1 {
80
+ t.Fatalf("expected 1 tool call, got %d", len(calls))
81
+ }
82
+ if calls[0].Function.Name != "calculate" {
83
+ t.Errorf("expected calculate, got %s", calls[0].Function.Name)
84
+ }
85
+ if clean != "" {
86
+ t.Errorf("expected empty clean content, got %q", clean)
87
+ }
88
+ }
89
+
90
+ func TestExtractPromptToolCalls_WithWhitespace(t *testing.T) {
91
+ content := `<tool_call>
92
+ {"name": "test", "arguments": {}}
93
+ </tool_call>`
94
+
95
+ _, calls := ExtractPromptToolCalls(content)
96
+ if len(calls) != 1 {
97
+ t.Fatalf("expected 1 tool call, got %d", len(calls))
98
+ }
99
+ if calls[0].Function.Name != "test" {
100
+ t.Errorf("expected test, got %s", calls[0].Function.Name)
101
+ }
102
+ }
103
+
104
+ func TestHasPromptToolCallOpen(t *testing.T) {
105
+ tests := []struct {
106
+ content string
107
+ expected bool
108
+ }{
109
+ {"hello", false},
110
+ {"<tool_call>{}", true},
111
+ {"<tool_call>{}</tool_call>", false},
112
+ {"text <tool_call>partial...", true},
113
+ {"<tool_call>a</tool_call><tool_call>b", true},
114
+ {"[TOOL]partial", true},
115
+ {"[TOOL]{\"name\":\"x\"}[/TOOL]", false},
116
+ {"[TOOL_CALL]partial", true},
117
+ {"[TOOL_CALL]{\"name\":\"x\"}[/TOOL_CALL]", false},
118
+ }
119
+ for _, tt := range tests {
120
+ if got := HasPromptToolCallOpen(tt.content); got != tt.expected {
121
+ t.Errorf("HasPromptToolCallOpen(%q) = %v, want %v", tt.content, got, tt.expected)
122
+ }
123
+ }
124
+ }
125
+
126
+ // ===== [TOOL]...[/TOOL] 格式 =====
127
+
128
+ func TestExtractPromptToolCalls_AltToolFormat(t *testing.T) {
129
+ content := `[TOOL]{"name": "get_weather", "arguments": {"city": "上海"}}[/TOOL]`
130
+
131
+ clean, calls := ExtractPromptToolCalls(content)
132
+ if len(calls) != 1 {
133
+ t.Fatalf("expected 1 tool call, got %d", len(calls))
134
+ }
135
+ if calls[0].Function.Name != "get_weather" {
136
+ t.Errorf("expected get_weather, got %s", calls[0].Function.Name)
137
+ }
138
+ if calls[0].Type != "function" {
139
+ t.Errorf("expected type function, got %s", calls[0].Type)
140
+ }
141
+ if clean != "" {
142
+ t.Errorf("expected empty clean, got %q", clean)
143
+ }
144
+ }
145
+
146
+ func TestExtractPromptToolCalls_AltToolCallFormat(t *testing.T) {
147
+ content := `好的,我来调用工具。
148
+ [TOOL_CALL]{"name": "create_file", "arguments": {"filename": "test.txt", "content": "hello"}}[/TOOL_CALL]`
149
+
150
+ clean, calls := ExtractPromptToolCalls(content)
151
+ if len(calls) != 1 {
152
+ t.Fatalf("expected 1 tool call, got %d", len(calls))
153
+ }
154
+ if calls[0].Function.Name != "create_file" {
155
+ t.Errorf("expected create_file, got %s", calls[0].Function.Name)
156
+ }
157
+ if !containsStr(clean, "好的") {
158
+ t.Errorf("expected clean to contain surrounding text, got %q", clean)
159
+ }
160
+ }
161
+
162
+ // ===== markdown JSON block 格式 =====
163
+
164
+ func TestExtractPromptToolCalls_JsonBlockFormat(t *testing.T) {
165
+ content := "我来调用工具:\n```json\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"北京\"}}\n```\n"
166
+
167
+ clean, calls := ExtractPromptToolCalls(content)
168
+ if len(calls) != 1 {
169
+ t.Fatalf("expected 1 tool call, got %d", len(calls))
170
+ }
171
+ if calls[0].Function.Name != "get_weather" {
172
+ t.Errorf("expected get_weather, got %s", calls[0].Function.Name)
173
+ }
174
+ if containsStr(clean, "```") {
175
+ t.Errorf("expected clean to not contain code block, got %q", clean)
176
+ }
177
+ }
178
+
179
+ // ===== 混合格式 =====
180
+
181
+ func TestExtractPromptToolCalls_MixedFormats(t *testing.T) {
182
+ content := `<tool_call>{"name": "func_a", "arguments": {}}</tool_call>
183
+ [TOOL]{"name": "func_b", "arguments": {}}[/TOOL]`
184
+
185
+ _, calls := ExtractPromptToolCalls(content)
186
+ if len(calls) != 2 {
187
+ t.Fatalf("expected 2 tool calls, got %d", len(calls))
188
+ }
189
+ // <tool_call> 优先被解析
190
+ names := map[string]bool{}
191
+ for _, c := range calls {
192
+ names[c.Function.Name] = true
193
+ }
194
+ if !names["func_a"] || !names["func_b"] {
195
+ t.Error("expected both func_a and func_b to be extracted")
196
+ }
197
+ }
198
+
199
+ // ===== <tool_call> 优先于其他格式 =====
200
+
201
+ func TestExtractPromptToolCalls_ToolCallPriority(t *testing.T) {
202
+ content := `<tool_call>{"name": "correct", "arguments": {}}</tool_call>`
203
+
204
+ _, calls := ExtractPromptToolCalls(content)
205
+ if len(calls) != 1 {
206
+ t.Fatalf("expected 1 tool call, got %d", len(calls))
207
+ }
208
+ if calls[0].Function.Name != "correct" {
209
+ t.Errorf("expected correct, got %s", calls[0].Function.Name)
210
+ }
211
+ }
212
+
213
// containsStr reports whether substr occurs anywhere within s.
func containsStr(s, substr string) bool {
	if len(substr) > len(s) {
		return false
	}
	if s == substr {
		return true
	}
	return s != "" && findSubstr(s, substr)
}

// findSubstr scans s from left to right and reports whether some window of
// len(substr) bytes equals substr.
func findSubstr(s, substr string) bool {
	width := len(substr)
	for start := 0; start+width <= len(s); start++ {
		if s[start:start+width] == substr {
			return true
		}
	}
	return false
}
internal/handler/chat.go CHANGED
@@ -94,6 +94,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
94
  totalContentOutputLength := 0
95
  hasToolCalls := false
96
  var collectedToolCalls []model.ToolCall
 
97
 
98
  for scanner.Scan() {
99
  line := scanner.Text()
@@ -145,7 +146,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
145
  Model: modelName,
146
  Choices: []model.Choice{{
147
  Index: 0,
148
- Delta: model.Delta{ReasoningContent: reasoningContent},
149
  FinishReason: nil,
150
  }},
151
  }
@@ -182,7 +183,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
182
  Model: modelName,
183
  Choices: []model.Choice{{
184
  Index: 0,
185
- Delta: model.Delta{Content: textBeforeBlock},
186
  FinishReason: nil,
187
  }},
188
  }
@@ -209,7 +210,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
209
  Model: modelName,
210
  Choices: []model.Choice{{
211
  Index: 0,
212
- Delta: model.Delta{Content: textBeforeBlock},
213
  FinishReason: nil,
214
  }},
215
  }
@@ -247,7 +248,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
247
  Model: modelName,
248
  Choices: []model.Choice{{
249
  Index: 0,
250
- Delta: model.Delta{
251
  ToolCalls: []model.ToolCall{tc},
252
  },
253
  FinishReason: nil,
@@ -270,7 +271,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
270
  Model: modelName,
271
  Choices: []model.Choice{{
272
  Index: 0,
273
- Delta: model.Delta{Content: pendingSourcesMarkdown},
274
  FinishReason: nil,
275
  }},
276
  }
@@ -288,7 +289,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
288
  Model: modelName,
289
  Choices: []model.Choice{{
290
  Index: 0,
291
- Delta: model.Delta{Content: pendingImageSearchMarkdown},
292
  FinishReason: nil,
293
  }},
294
  }
@@ -313,7 +314,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
313
  Model: modelName,
314
  Choices: []model.Choice{{
315
  Index: 0,
316
- Delta: model.Delta{ReasoningContent: processedRemaining},
317
  FinishReason: nil,
318
  }},
319
  }
@@ -332,7 +333,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
332
  Model: modelName,
333
  Choices: []model.Choice{{
334
  Index: 0,
335
- Delta: model.Delta{ReasoningContent: pendingSourcesMarkdown},
336
  FinishReason: nil,
337
  }},
338
  }
@@ -382,7 +383,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
382
  Model: modelName,
383
  Choices: []model.Choice{{
384
  Index: 0,
385
- Delta: model.Delta{ReasoningContent: reasoningContent},
386
  FinishReason: nil,
387
  }},
388
  }
@@ -405,6 +406,63 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
405
  totalContentOutputLength += len([]rune(content))
406
  }
407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  chunk := model.ChatCompletionChunk{
409
  ID: completionID,
410
  Object: "chat.completion.chunk",
@@ -412,7 +470,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
412
  Model: modelName,
413
  Choices: []model.Choice{{
414
  Index: 0,
415
- Delta: model.Delta{Content: content},
416
  FinishReason: nil,
417
  }},
418
  }
@@ -426,6 +484,39 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
426
  logger.LogError("[Upstream] scanner error: %v", err)
427
  }
428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  if remaining := searchRefFilter.Flush(); remaining != "" {
430
  hasContent = true
431
  chunk := model.ChatCompletionChunk{
@@ -435,7 +526,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
435
  Model: modelName,
436
  Choices: []model.Choice{{
437
  Index: 0,
438
- Delta: model.Delta{Content: remaining},
439
  FinishReason: nil,
440
  }},
441
  }
@@ -459,7 +550,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
459
  Model: modelName,
460
  Choices: []model.Choice{{
461
  Index: 0,
462
- Delta: model.Delta{},
463
  FinishReason: &stopReason,
464
  }},
465
  }
@@ -616,6 +707,15 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
616
  fullReasoning := strings.Join(reasoningChunks, "")
617
  fullReasoning = searchRefFilter.Process(fullReasoning) + searchRefFilter.Flush()
618
 
 
 
 
 
 
 
 
 
 
619
  if fullContent == "" && len(collectedToolCalls) == 0 {
620
  logger.LogError("Non-stream response 200 but no content received")
621
  }
@@ -644,3 +744,21 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
644
  w.Header().Set("Content-Type", "application/json")
645
  json.NewEncoder(w).Encode(response)
646
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  totalContentOutputLength := 0
95
  hasToolCalls := false
96
  var collectedToolCalls []model.ToolCall
97
+ promptToolBuffer := "" // 用于 prompt 注入模式下缓冲 answer 文本以检测 <tool_call>
98
 
99
  for scanner.Scan() {
100
  line := scanner.Text()
 
146
  Model: modelName,
147
  Choices: []model.Choice{{
148
  Index: 0,
149
+ Delta: &model.Delta{ReasoningContent: reasoningContent},
150
  FinishReason: nil,
151
  }},
152
  }
 
183
  Model: modelName,
184
  Choices: []model.Choice{{
185
  Index: 0,
186
+ Delta: &model.Delta{Content: textBeforeBlock},
187
  FinishReason: nil,
188
  }},
189
  }
 
210
  Model: modelName,
211
  Choices: []model.Choice{{
212
  Index: 0,
213
+ Delta: &model.Delta{Content: textBeforeBlock},
214
  FinishReason: nil,
215
  }},
216
  }
 
248
  Model: modelName,
249
  Choices: []model.Choice{{
250
  Index: 0,
251
+ Delta: &model.Delta{
252
  ToolCalls: []model.ToolCall{tc},
253
  },
254
  FinishReason: nil,
 
271
  Model: modelName,
272
  Choices: []model.Choice{{
273
  Index: 0,
274
+ Delta: &model.Delta{Content: pendingSourcesMarkdown},
275
  FinishReason: nil,
276
  }},
277
  }
 
289
  Model: modelName,
290
  Choices: []model.Choice{{
291
  Index: 0,
292
+ Delta: &model.Delta{Content: pendingImageSearchMarkdown},
293
  FinishReason: nil,
294
  }},
295
  }
 
314
  Model: modelName,
315
  Choices: []model.Choice{{
316
  Index: 0,
317
+ Delta: &model.Delta{ReasoningContent: processedRemaining},
318
  FinishReason: nil,
319
  }},
320
  }
 
333
  Model: modelName,
334
  Choices: []model.Choice{{
335
  Index: 0,
336
+ Delta: &model.Delta{ReasoningContent: pendingSourcesMarkdown},
337
  FinishReason: nil,
338
  }},
339
  }
 
383
  Model: modelName,
384
  Choices: []model.Choice{{
385
  Index: 0,
386
+ Delta: &model.Delta{ReasoningContent: reasoningContent},
387
  FinishReason: nil,
388
  }},
389
  }
 
406
  totalContentOutputLength += len([]rune(content))
407
  }
408
 
409
+ // prompt 注入模式:缓冲 answer 文本,检测 <tool_call> 块
410
+ if len(tools) > 0 {
411
+ promptToolBuffer += content
412
+ // 循环提取完整的 <tool_call>...</tool_call> 块
413
+ for {
414
+ openIdx := strings.Index(promptToolBuffer, "<tool_call>")
415
+ if openIdx == -1 {
416
+ // 无 <tool_call> 标签,全部安全输出
417
+ break
418
+ }
419
+ // 输出 <tool_call> 之前的安全文本
420
+ if openIdx > 0 {
421
+ safeContent := promptToolBuffer[:openIdx]
422
+ promptToolBuffer = promptToolBuffer[openIdx:]
423
+ if safeContent != "" {
424
+ sendContentChunk(w, flusher, completionID, modelName, safeContent)
425
+ }
426
+ }
427
+ // 检查是否有完整的闭合标签
428
+ closeIdx := strings.Index(promptToolBuffer, "</tool_call>")
429
+ if closeIdx == -1 {
430
+ // 未闭合,等待更多数据
431
+ break
432
+ }
433
+ // 提取完整块
434
+ blockEnd := closeIdx + len("</tool_call>")
435
+ block := promptToolBuffer[:blockEnd]
436
+ promptToolBuffer = promptToolBuffer[blockEnd:]
437
+
438
+ // 解析 tool call
439
+ _, toolCalls := filter.ExtractPromptToolCalls(block)
440
+ if len(toolCalls) > 0 {
441
+ collectedToolCalls = append(collectedToolCalls, toolCalls...)
442
+ hasToolCalls = true
443
+ for _, tc := range toolCalls {
444
+ chunk := model.ChatCompletionChunk{
445
+ ID: completionID,
446
+ Object: "chat.completion.chunk",
447
+ Created: time.Now().Unix(),
448
+ Model: modelName,
449
+ Choices: []model.Choice{{
450
+ Index: 0,
451
+ Delta: &model.Delta{
452
+ ToolCalls: []model.ToolCall{tc},
453
+ },
454
+ FinishReason: nil,
455
+ }},
456
+ }
457
+ data, _ := json.Marshal(chunk)
458
+ fmt.Fprintf(w, "data: %s\n\n", data)
459
+ flusher.Flush()
460
+ }
461
+ }
462
+ }
463
+ continue
464
+ }
465
+
466
  chunk := model.ChatCompletionChunk{
467
  ID: completionID,
468
  Object: "chat.completion.chunk",
 
470
  Model: modelName,
471
  Choices: []model.Choice{{
472
  Index: 0,
473
+ Delta: &model.Delta{Content: content},
474
  FinishReason: nil,
475
  }},
476
  }
 
484
  logger.LogError("[Upstream] scanner error: %v", err)
485
  }
486
 
487
+ // prompt 注入模式:flush 缓冲区中剩余的文本
488
+ if promptToolBuffer != "" {
489
+ // 尝试最后一次提取 tool calls
490
+ cleanContent, toolCalls := filter.ExtractPromptToolCalls(promptToolBuffer)
491
+ if len(toolCalls) > 0 {
492
+ collectedToolCalls = append(collectedToolCalls, toolCalls...)
493
+ hasToolCalls = true
494
+ for _, tc := range toolCalls {
495
+ chunk := model.ChatCompletionChunk{
496
+ ID: completionID,
497
+ Object: "chat.completion.chunk",
498
+ Created: time.Now().Unix(),
499
+ Model: modelName,
500
+ Choices: []model.Choice{{
501
+ Index: 0,
502
+ Delta: &model.Delta{
503
+ ToolCalls: []model.ToolCall{tc},
504
+ },
505
+ FinishReason: nil,
506
+ }},
507
+ }
508
+ data, _ := json.Marshal(chunk)
509
+ fmt.Fprintf(w, "data: %s\n\n", data)
510
+ flusher.Flush()
511
+ }
512
+ }
513
+ if cleanContent != "" {
514
+ sendContentChunk(w, flusher, completionID, modelName, cleanContent)
515
+ hasContent = true
516
+ }
517
+ promptToolBuffer = ""
518
+ }
519
+
520
  if remaining := searchRefFilter.Flush(); remaining != "" {
521
  hasContent = true
522
  chunk := model.ChatCompletionChunk{
 
526
  Model: modelName,
527
  Choices: []model.Choice{{
528
  Index: 0,
529
+ Delta: &model.Delta{Content: remaining},
530
  FinishReason: nil,
531
  }},
532
  }
 
550
  Model: modelName,
551
  Choices: []model.Choice{{
552
  Index: 0,
553
+ Delta: &model.Delta{},
554
  FinishReason: &stopReason,
555
  }},
556
  }
 
707
  fullReasoning := strings.Join(reasoningChunks, "")
708
  fullReasoning = searchRefFilter.Process(fullReasoning) + searchRefFilter.Flush()
709
 
710
+ // prompt 注入模式:从 answer 文本中提取 <tool_call> 块
711
+ if len(tools) > 0 && len(collectedToolCalls) == 0 {
712
+ cleanContent, promptToolCalls := filter.ExtractPromptToolCalls(fullContent)
713
+ if len(promptToolCalls) > 0 {
714
+ collectedToolCalls = promptToolCalls
715
+ fullContent = cleanContent
716
+ }
717
+ }
718
+
719
  if fullContent == "" && len(collectedToolCalls) == 0 {
720
  logger.LogError("Non-stream response 200 but no content received")
721
  }
 
744
  w.Header().Set("Content-Type", "application/json")
745
  json.NewEncoder(w).Encode(response)
746
  }
747
+
748
+ // sendContentChunk 发送一个 content SSE chunk
749
+ func sendContentChunk(w http.ResponseWriter, flusher http.Flusher, completionID, modelName, content string) {
750
+ chunk := model.ChatCompletionChunk{
751
+ ID: completionID,
752
+ Object: "chat.completion.chunk",
753
+ Created: time.Now().Unix(),
754
+ Model: modelName,
755
+ Choices: []model.Choice{{
756
+ Index: 0,
757
+ Delta: &model.Delta{Content: content},
758
+ FinishReason: nil,
759
+ }},
760
+ }
761
+ data, _ := json.Marshal(chunk)
762
+ fmt.Fprintf(w, "data: %s\n\n", data)
763
+ flusher.Flush()
764
+ }
internal/handler/chat_test.go CHANGED
@@ -574,3 +574,76 @@ func TestNonStreamResponse_FullFormat(t *testing.T) {
574
  t.Errorf("Role = %q", resp.Choices[0].Message.Role)
575
  }
576
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
  t.Errorf("Role = %q", resp.Choices[0].Message.Role)
575
  }
576
  }
577
+
578
+ // ===== 流式:prompt 注入模式 <tool_call> 在 answer 文本中 =====
579
+
580
+ func TestStreamResponse_PromptInjectionToolCall(t *testing.T) {
581
+ body := newFakeBody(
582
+ sseEvent("answer", "好的,我来查询。\n", ""),
583
+ sseEvent("answer", `<tool_call>{"name":"get_weather","arguments":{"city":"北京"}}</tool_call>`, ""),
584
+ sseEventDone(),
585
+ )
586
+
587
+ w := httptest.NewRecorder()
588
+ handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
589
+
590
+ result := w.Body.String()
591
+
592
+ if !strings.Contains(result, `"tool_calls"`) {
593
+ t.Error("missing tool_calls in prompt injection stream")
594
+ }
595
+ if !strings.Contains(result, `"get_weather"`) {
596
+ t.Error("missing function name")
597
+ }
598
+ if !strings.Contains(result, `"finish_reason":"tool_calls"`) {
599
+ t.Error("finish_reason should be tool_calls")
600
+ }
601
+ }
602
+
603
+ // ===== 非流式:prompt 注入模式 =====
604
+
605
+ func TestNonStreamResponse_PromptInjectionToolCall(t *testing.T) {
606
+ body := newFakeBody(
607
+ sseEvent("answer", "我来查询天气。\n<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"city\":\"上海\"}}</tool_call>", ""),
608
+ sseEventDone(),
609
+ )
610
+
611
+ w := httptest.NewRecorder()
612
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
613
+
614
+ var resp model.ChatCompletionResponse
615
+ if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
616
+ t.Fatalf("decode: %v", err)
617
+ }
618
+
619
+ msg := resp.Choices[0].Message
620
+ if len(msg.ToolCalls) != 1 {
621
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
622
+ }
623
+ if msg.ToolCalls[0].Function.Name != "get_weather" {
624
+ t.Errorf("Function.Name = %q", msg.ToolCalls[0].Function.Name)
625
+ }
626
+ if strings.Contains(msg.Content, "<tool_call>") {
627
+ t.Error("content should not contain <tool_call> tags")
628
+ }
629
+ if *resp.Choices[0].FinishReason != "tool_calls" {
630
+ t.Errorf("FinishReason = %q, want tool_calls", *resp.Choices[0].FinishReason)
631
+ }
632
+ }
633
+
634
+ // ===== 非流式:response 中不应有 delta 字段 =====
635
+
636
+ func TestNonStreamResponse_NoDeltaField(t *testing.T) {
637
+ body := newFakeBody(
638
+ sseEvent("answer", "hello", ""),
639
+ sseEventDone(),
640
+ )
641
+
642
+ w := httptest.NewRecorder()
643
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
644
+
645
+ result := w.Body.String()
646
+ if strings.Contains(result, `"delta"`) {
647
+ t.Error("non-streaming response should not contain delta field")
648
+ }
649
+ }
internal/model/types.go CHANGED
@@ -160,7 +160,7 @@ type ChatCompletionChunk struct {
160
 
161
  type Choice struct {
162
  Index int `json:"index"`
163
- Delta Delta `json:"delta,omitempty"`
164
  Message *MessageResp `json:"message,omitempty"`
165
  FinishReason *string `json:"finish_reason"`
166
  }
 
160
 
161
  type Choice struct {
162
  Index int `json:"index"`
163
+ Delta *Delta `json:"delta,omitempty"`
164
  Message *MessageResp `json:"message,omitempty"`
165
  FinishReason *string `json:"finish_reason"`
166
  }
internal/model/types_test.go CHANGED
@@ -2,6 +2,7 @@ package model
2
 
3
  import (
4
  "encoding/json"
 
5
  "testing"
6
  )
7
 
@@ -423,7 +424,7 @@ func TestChunkWithToolCallsFinishReason(t *testing.T) {
423
  Model: "glm-4.7",
424
  Choices: []Choice{{
425
  Index: 0,
426
- Delta: Delta{},
427
  FinishReason: &reason,
428
  }},
429
  }
@@ -501,3 +502,41 @@ func TestCompletionResponseWithToolCalls(t *testing.T) {
501
  t.Errorf("FinishReason = %q", *decoded.Choices[0].FinishReason)
502
  }
503
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import (
4
  "encoding/json"
5
+ "strings"
6
  "testing"
7
  )
8
 
 
424
  Model: "glm-4.7",
425
  Choices: []Choice{{
426
  Index: 0,
427
+ Delta: &Delta{},
428
  FinishReason: &reason,
429
  }},
430
  }
 
502
  t.Errorf("FinishReason = %q", *decoded.Choices[0].FinishReason)
503
  }
504
  }
505
+
506
+ // ===== Delta 指针:nil 时不出现在 JSON 中 =====
507
+
508
+ func TestChoiceDeltaNil_OmittedInJSON(t *testing.T) {
509
+ reason := "stop"
510
+ choice := Choice{
511
+ Index: 0,
512
+ Message: &MessageResp{
513
+ Role: "assistant",
514
+ Content: "hello",
515
+ },
516
+ FinishReason: &reason,
517
+ }
518
+
519
+ data, _ := json.Marshal(choice)
520
+ s := string(data)
521
+ if strings.Contains(s, `"delta"`) {
522
+ t.Errorf("nil Delta should be omitted, got: %s", s)
523
+ }
524
+ }
525
+
526
+ // ===== Delta 指针:非 nil 时正常序列化 =====
527
+
528
+ func TestChoiceDeltaNotNil_SerializedInJSON(t *testing.T) {
529
+ choice := Choice{
530
+ Index: 0,
531
+ Delta: &Delta{Content: "test content"},
532
+ }
533
+
534
+ data, _ := json.Marshal(choice)
535
+ s := string(data)
536
+ if !strings.Contains(s, `"delta"`) {
537
+ t.Error("non-nil Delta should appear in JSON")
538
+ }
539
+ if !strings.Contains(s, `"test content"`) {
540
+ t.Error("Delta content should be serialized")
541
+ }
542
+ }
internal/tools/prompt.go ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package tools
2
+
3
+ import (
4
+ "encoding/json"
5
+ "fmt"
6
+ "strings"
7
+
8
+ "zai-proxy/internal/model"
9
+ )
10
+
11
+ // BuildToolSystemPrompt 将工具定义列表转换为 system prompt 文本,
12
+ // 指示模型使用 <tool_call> 格式输出工具调用。
13
+ func BuildToolSystemPrompt(tools []model.Tool, toolChoice interface{}) string {
14
+ if len(tools) == 0 {
15
+ return ""
16
+ }
17
+
18
+ var sb strings.Builder
19
+
20
+ sb.WriteString("# 工具调用规则\n\n")
21
+ sb.WriteString("你可以使用下面列出的工具。当你需要调用工具时,**必须严格使用以下 XML 格式**输出调用请求(不要使用 markdown 代码块、不要使用 [TOOL] 或其他格式):\n\n")
22
+ sb.WriteString("<tool_call>{\"name\": \"函数名\", \"arguments\": {\"参数名\": \"参数值\"}}</tool_call>\n\n")
23
+ sb.WriteString("**重要规则:**\n")
24
+ sb.WriteString("- 你不能自行执行工具,只能输出 <tool_call> 标签,由系统执行后将结果返回给你\n")
25
+ sb.WriteString("- 每个工具调用必须独立包裹在 <tool_call></tool_call> 标签中\n")
26
+ sb.WriteString("- arguments 必须是合法 JSON 对象\n")
27
+ sb.WriteString("- 不要在 <tool_call> 标签外描述调用参数\n\n")
28
+
29
+ sb.WriteString("## 示例\n\n")
30
+ sb.WriteString("用户: 帮我创建一个文件 test.txt 内容为 hello\n")
31
+ sb.WriteString("助手: 好的,我来为您创建文件。\n")
32
+ sb.WriteString("<tool_call>{\"name\": \"create_file\", \"arguments\": {\"filename\": \"test.txt\", \"content\": \"hello\"}}</tool_call>\n\n")
33
+ sb.WriteString("用户: 查询北京和上海的天气\n")
34
+ sb.WriteString("助手: 我来查询这两个城市的天气。\n")
35
+ sb.WriteString("<tool_call>{\"name\": \"get_weather\", \"arguments\": {\"location\": \"北京\"}}</tool_call>\n")
36
+ sb.WriteString("<tool_call>{\"name\": \"get_weather\", \"arguments\": {\"location\": \"上海\"}}</tool_call>\n\n")
37
+
38
+ sb.WriteString("## 可用工具\n\n")
39
+
40
+ for _, tool := range tools {
41
+ sb.WriteString(fmt.Sprintf("### %s\n", tool.Function.Name))
42
+ if tool.Function.Description != "" {
43
+ sb.WriteString(fmt.Sprintf("%s\n", tool.Function.Description))
44
+ }
45
+ if tool.Function.Parameters != nil {
46
+ params, err := json.Marshal(tool.Function.Parameters)
47
+ if err == nil {
48
+ sb.WriteString(fmt.Sprintf("Parameters: %s\n", string(params)))
49
+ }
50
+ }
51
+ sb.WriteString("\n")
52
+ }
53
+
54
+ // 处理 tool_choice
55
+ if toolChoice != nil {
56
+ switch tc := toolChoice.(type) {
57
+ case string:
58
+ switch tc {
59
+ case "none":
60
+ sb.WriteString("**禁止调用任何工具,直接回答问题。**\n")
61
+ case "required":
62
+ sb.WriteString("**你的回复中必须包含至少一个 <tool_call> 标签。即使你认为不需要调用工具,也必须调用。**\n")
63
+ // "auto" is the default, no special instruction needed
64
+ }
65
+ case map[string]interface{}:
66
+ // tool_choice = {"type": "function", "function": {"name": "xxx"}}
67
+ if fn, ok := tc["function"].(map[string]interface{}); ok {
68
+ if name, ok := fn["name"].(string); ok {
69
+ sb.WriteString(fmt.Sprintf("**你必须调用工具 \"%s\",使用 <tool_call> 标签输出调用。**\n", name))
70
+ }
71
+ }
72
+ }
73
+ }
74
+
75
+ return sb.String()
76
+ }
77
+
78
+ // ConvertToolCallToText 将 assistant 消息中的 tool_calls 转换为 <tool_call> 文本格式,
79
+ // 用于在 prompt 注入模式下将历史 tool_calls 传给上游。
80
+ func ConvertToolCallToText(toolCalls []model.ToolCall) string {
81
+ var parts []string
82
+ for _, tc := range toolCalls {
83
+ callJSON, _ := json.Marshal(map[string]interface{}{
84
+ "name": tc.Function.Name,
85
+ "arguments": json.RawMessage(tc.Function.Arguments),
86
+ })
87
+ parts = append(parts, fmt.Sprintf("<tool_call>%s</tool_call>", string(callJSON)))
88
+ }
89
+ return strings.Join(parts, "\n")
90
+ }
91
+
92
// ConvertToolResultToText renders a tool-role message as text so that a tool's
// execution result can be passed along in prompt-injection mode. The result is
// content wrapped in a <tool_result> element whose call_id attribute is
// toolCallID; neither value is escaped.
func ConvertToolResultToText(toolCallID string, content string) string {
	openTag := "<tool_result call_id=\"" + toolCallID + "\">"
	return openTag + content + "</tool_result>"
}
internal/tools/prompt_test.go ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package tools
2
+
3
+ import (
4
+ "strings"
5
+ "testing"
6
+
7
+ "zai-proxy/internal/model"
8
+ )
9
+
10
+ func TestBuildToolSystemPrompt_Basic(t *testing.T) {
11
+ tools := []model.Tool{
12
+ {
13
+ Type: "function",
14
+ Function: model.ToolFunction{
15
+ Name: "get_weather",
16
+ Description: "Get current weather",
17
+ Parameters: map[string]interface{}{
18
+ "type": "object",
19
+ "properties": map[string]interface{}{
20
+ "city": map[string]interface{}{
21
+ "type": "string",
22
+ "description": "City name",
23
+ },
24
+ },
25
+ "required": []string{"city"},
26
+ },
27
+ },
28
+ },
29
+ }
30
+
31
+ result := BuildToolSystemPrompt(tools, nil)
32
+
33
+ if !strings.Contains(result, "get_weather") {
34
+ t.Error("should contain tool name")
35
+ }
36
+ if !strings.Contains(result, "Get current weather") {
37
+ t.Error("should contain description")
38
+ }
39
+ if !strings.Contains(result, "<tool_call>") {
40
+ t.Error("should contain format instruction")
41
+ }
42
+ if !strings.Contains(result, "city") {
43
+ t.Error("should contain parameter info")
44
+ }
45
+ }
46
+
47
+ func TestBuildToolSystemPrompt_Empty(t *testing.T) {
48
+ result := BuildToolSystemPrompt(nil, nil)
49
+ if result != "" {
50
+ t.Error("should return empty for nil tools")
51
+ }
52
+ }
53
+
54
+ func TestBuildToolSystemPrompt_ToolChoiceNone(t *testing.T) {
55
+ tools := []model.Tool{{
56
+ Type: "function",
57
+ Function: model.ToolFunction{Name: "test"},
58
+ }}
59
+
60
+ result := BuildToolSystemPrompt(tools, "none")
61
+ if !strings.Contains(result, "禁止调用任何工具") {
62
+ t.Error("should instruct not to call tools")
63
+ }
64
+ }
65
+
66
+ func TestBuildToolSystemPrompt_ToolChoiceRequired(t *testing.T) {
67
+ tools := []model.Tool{{
68
+ Type: "function",
69
+ Function: model.ToolFunction{Name: "test"},
70
+ }}
71
+
72
+ result := BuildToolSystemPrompt(tools, "required")
73
+ if !strings.Contains(result, "必须包含至少一个") {
74
+ t.Error("should instruct to call at least one tool")
75
+ }
76
+ }
77
+
78
+ func TestBuildToolSystemPrompt_ToolChoiceSpecific(t *testing.T) {
79
+ tools := []model.Tool{{
80
+ Type: "function",
81
+ Function: model.ToolFunction{Name: "get_weather"},
82
+ }}
83
+
84
+ choice := map[string]interface{}{
85
+ "type": "function",
86
+ "function": map[string]interface{}{
87
+ "name": "get_weather",
88
+ },
89
+ }
90
+
91
+ result := BuildToolSystemPrompt(tools, choice)
92
+ if !strings.Contains(result, `必须调用工具 "get_weather"`) {
93
+ t.Error("should instruct to call specific tool")
94
+ }
95
+ }
96
+
97
+ func TestConvertToolCallToText(t *testing.T) {
98
+ toolCalls := []model.ToolCall{
99
+ {
100
+ ID: "call_123",
101
+ Type: "function",
102
+ Function: model.FunctionCall{
103
+ Name: "get_weather",
104
+ Arguments: `{"city":"Beijing"}`,
105
+ },
106
+ },
107
+ }
108
+
109
+ result := ConvertToolCallToText(toolCalls)
110
+ if !strings.Contains(result, "<tool_call>") {
111
+ t.Error("should contain <tool_call> tag")
112
+ }
113
+ if !strings.Contains(result, "get_weather") {
114
+ t.Error("should contain function name")
115
+ }
116
+ if !strings.Contains(result, "Beijing") {
117
+ t.Error("should contain arguments")
118
+ }
119
+ }
120
+
121
+ func TestConvertToolResultToText(t *testing.T) {
122
+ result := ConvertToolResultToText("call_123", `{"temp": 25}`)
123
+ if !strings.Contains(result, "call_123") {
124
+ t.Error("should contain call ID")
125
+ }
126
+ if !strings.Contains(result, `{"temp": 25}`) {
127
+ t.Error("should contain result content")
128
+ }
129
+ if !strings.Contains(result, "<tool_result") {
130
+ t.Error("should contain <tool_result> tag")
131
+ }
132
+ }
internal/upstream/client.go CHANGED
@@ -94,11 +94,62 @@ func MakeUpstreamRequest(token string, messages []model.Message, modelName strin
94
  }
95
  }
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  var upstreamMessages []map[string]interface{}
 
98
  for _, msg := range messages {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  upstreamMessages = append(upstreamMessages, msg.ToUpstreamMessage(urlToFileID))
100
  }
101
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  body := map[string]interface{}{
103
  "stream": true,
104
  "model": targetModel,
@@ -120,26 +171,6 @@ func MakeUpstreamRequest(token string, messages []model.Message, modelName strin
120
  body["mcp_servers"] = mcpServers
121
  }
122
 
123
- // 当使用 -tools 模型时,自动注入内置工具(客户端自带工具优先)
124
- if model.IsToolsModel(modelName) {
125
- clientToolNames := make(map[string]bool)
126
- for _, t := range tools {
127
- clientToolNames[t.Function.Name] = true
128
- }
129
- for _, bt := range builtintools.GetBuiltinTools() {
130
- if !clientToolNames[bt.Function.Name] {
131
- tools = append(tools, bt)
132
- }
133
- }
134
- }
135
-
136
- if len(tools) > 0 {
137
- body["tools"] = tools
138
- if toolChoice != nil {
139
- body["tool_choice"] = toolChoice
140
- }
141
- }
142
-
143
  if len(filesData) > 0 {
144
  body["files"] = filesData
145
  body["current_user_message_id"] = userMsgID
 
94
  }
95
  }
96
 
97
+ // 当使用 -tools 模型时,自动注入内置工具(客户端自带工具优先)
98
+ if model.IsToolsModel(modelName) {
99
+ clientToolNames := make(map[string]bool)
100
+ for _, t := range tools {
101
+ clientToolNames[t.Function.Name] = true
102
+ }
103
+ for _, bt := range builtintools.GetBuiltinTools() {
104
+ if !clientToolNames[bt.Function.Name] {
105
+ tools = append(tools, bt)
106
+ }
107
+ }
108
+ }
109
+
110
  var upstreamMessages []map[string]interface{}
111
+ hasPromptTools := len(tools) > 0
112
  for _, msg := range messages {
113
+ if hasPromptTools {
114
+ // prompt 注入模式:将 tool_calls / tool 结果转为纯文本
115
+ if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
116
+ text, _ := msg.ParseContent()
117
+ callText := builtintools.ConvertToolCallToText(msg.ToolCalls)
118
+ if text != "" {
119
+ text = text + "\n" + callText
120
+ } else {
121
+ text = callText
122
+ }
123
+ upstreamMessages = append(upstreamMessages, map[string]interface{}{
124
+ "role": "assistant",
125
+ "content": text,
126
+ })
127
+ continue
128
+ }
129
+ if msg.Role == "tool" {
130
+ text, _ := msg.ParseContent()
131
+ upstreamMessages = append(upstreamMessages, map[string]interface{}{
132
+ "role": "user",
133
+ "content": builtintools.ConvertToolResultToText(msg.ToolCallID, text),
134
+ })
135
+ continue
136
+ }
137
+ }
138
  upstreamMessages = append(upstreamMessages, msg.ToUpstreamMessage(urlToFileID))
139
  }
140
 
141
+ // 工具注入:通过 system prompt 注入工具定义(z.ai 不支持原生 tools 字段)
142
+ if len(tools) > 0 {
143
+ toolSystemPrompt := builtintools.BuildToolSystemPrompt(tools, toolChoice)
144
+ if toolSystemPrompt != "" {
145
+ systemMsg := map[string]interface{}{
146
+ "role": "system",
147
+ "content": toolSystemPrompt,
148
+ }
149
+ upstreamMessages = append([]map[string]interface{}{systemMsg}, upstreamMessages...)
150
+ }
151
+ }
152
+
153
  body := map[string]interface{}{
154
  "stream": true,
155
  "model": targetModel,
 
171
  body["mcp_servers"] = mcpServers
172
  }
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  if len(filesData) > 0 {
175
  body["files"] = filesData
176
  body["current_user_message_id"] = userMsgID