ss22345 commited on
Commit
5a55e77
·
1 Parent(s): 48d903a

feat: support tool/function calling (OpenAI compatible)

Browse files

Add tools and tool_choice fields to chat requests, parse upstream
tool_call responses, and return them in OpenAI-compatible format
with finish_reason=tool_calls. Includes builtin tools, test script,
and unit tests.

README.md CHANGED
@@ -9,6 +9,7 @@ zai-proxy 是一个基于 Go 语言的代理服务,将 z.ai 网页聊天转换
9
  - 支持多种 GLM 模型
10
  - 支持思考模式 (thinking)
11
  - 支持联网搜索模式 (search)
 
12
  - 支持多模态图片输入
13
  - 支持匿名 Token(免登录)
14
  - **自动生成签名**
@@ -87,13 +88,18 @@ curl http://localhost:8000/v1/chat/completions \
87
 
88
  - `-thinking`: 启用思考模式,响应会包含 `reasoning_content` 字段
89
  - `-search`: 启用联网搜索模式
 
90
  - (TODO) `-deepsearch`: 启用多轮搜索,深入研究分析
91
 
 
 
92
  示例:
93
 
94
  - `GLM-4.7-thinking`
95
  - `GLM-4.7-search`
96
  - `GLM-4.7-thinking-search`
 
 
97
 
98
  ## 使用示例
99
 
@@ -130,3 +136,85 @@ curl http://localhost:8000/v1/chat/completions \
130
  ### 支持的图片格式:
131
  - HTTP/HTTPS URL
132
  - Base64 编码 (data:image/jpeg;base64,...)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  - 支持多种 GLM 模型
10
  - 支持思考模式 (thinking)
11
  - 支持联网搜索模式 (search)
12
+ - 支持内置工具调用 (tools)
13
  - 支持多模态图片输入
14
  - 支持匿名 Token(免登录)
15
  - **自动生成签名**
 
88
 
89
  - `-thinking`: 启用思考模式,响应会包含 `reasoning_content` 字段
90
  - `-search`: 启用联网搜索模式
91
+ - `-tools`: 自动注入内置工具定义,模型会返回 `tool_calls` 进行函数调用
92
  - (TODO) `-deepsearch`: 启用多轮搜索,深入研究分析
93
 
94
+ 标签可任意组合,顺序不限:
95
+
96
  示例:
97
 
98
  - `GLM-4.7-thinking`
99
  - `GLM-4.7-search`
100
  - `GLM-4.7-thinking-search`
101
+ - `GLM-4.7-tools`
102
+ - `GLM-4.7-tools-thinking`
103
 
104
  ## 使用示例
105
 
 
136
  ### 支持的图片格式:
137
  - HTTP/HTTPS URL
138
  - Base64 编码 (data:image/jpeg;base64,...)
139
+
140
+ ## 工具调用 (Function Calling)
141
+
142
+ 使用 `-tools` 后缀时,代理会自动注入 6 个内置工具定义。模型会根据用户输入决定是否调用工具。
143
+
144
+ ### 内置工具
145
+
146
+ | 工具名 | 描述 |
147
+ |--------|------|
148
+ | `get_current_time` | 获取当前时间 |
149
+ | `calculate` | 执行数学计算 |
150
+ | `search_web` | 搜索网络信息 |
151
+ | `query_database` | 执行SQL查询 |
152
+ | `file_operations` | 文件读写列表 |
153
+ | `call_external_api` | 调用外部API |
154
+
155
+ ### 基本调用
156
+
157
+ ```bash
158
+ curl http://localhost:8000/v1/chat/completions \
159
+ -H "Authorization: Bearer YOUR_ZAI_TOKEN" \
160
+ -H "Content-Type: application/json" \
161
+ -d '{
162
+ "model": "GLM-4.7-tools",
163
+ "messages": [{"role": "user", "content": "现在几点了?"}],
164
+ "stream": true
165
+ }'
166
+ ```
167
+
168
+ 模型会返回 `tool_calls`(`finish_reason` 为 `"tool_calls"`),由客户端自行执行工具并将结果发回。
169
+
170
+ ### 多轮调用流程
171
+
172
+ ```
173
+ 第1轮:用户提问 → 模型返回 tool_calls
174
+ 第2轮:发送工具执行结果 → 模型生成最终回答
175
+ ```
176
+
177
+ ```bash
178
+ curl http://localhost:8000/v1/chat/completions \
179
+ -H "Authorization: Bearer YOUR_ZAI_TOKEN" \
180
+ -H "Content-Type: application/json" \
181
+ -d '{
182
+ "model": "GLM-4.7-tools",
183
+ "messages": [
184
+ {"role": "user", "content": "现在几点了?"},
185
+ {"role": "assistant", "content": "", "tool_calls": [
186
+ {"id": "call_xxx", "type": "function", "function": {"name": "get_current_time", "arguments": "{}"}}
187
+ ]},
188
+ {"role": "tool", "tool_call_id": "call_xxx", "content": "{\"time\": \"2026-03-14 15:30:00\"}"}
189
+ ],
190
+ "stream": true
191
+ }'
192
+ ```
193
+
194
+ ### 自定义工具
195
+
196
+ 也可以不使用 `-tools` 后缀,直接在请求中传入 `tools` 字段(标准 OpenAI 格式):
197
+
198
+ ```json
199
+ {
200
+ "model": "GLM-4.7",
201
+ "messages": [{"role": "user", "content": "北京天气怎么样?"}],
202
+ "tools": [{
203
+ "type": "function",
204
+ "function": {
205
+ "name": "get_weather",
206
+ "description": "获取天气信息",
207
+ "parameters": {
208
+ "type": "object",
209
+ "properties": {
210
+ "city": {"type": "string", "description": "城市名称"}
211
+ },
212
+ "required": ["city"]
213
+ }
214
+ }
215
+ }],
216
+ "tool_choice": "auto"
217
+ }
218
+ ```
219
+
220
+ 两者可混合使用:`-tools` 模型名 + 自定义 `tools` 字段。**客户端自带的同名工具优先**,不会被内置工具覆盖。
internal/filter/toolcall.go ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package filter
2
+
3
+ import (
4
+ "encoding/json"
5
+ "regexp"
6
+ "strings"
7
+
8
+ "zai-proxy/internal/model"
9
+ )
10
+
11
+ var glmToolCallBlockPattern = regexp.MustCompile(`<glm_block[^>]*type="tool_call"[^>]*>([\s\S]*?)</glm_block>`)
12
+
13
+ // IsFunctionToolCall 判断 tool_call 阶段的内容是否是用户定义的函数调用(非 mcp/search)
14
+ func IsFunctionToolCall(editContent string, phase string) bool {
15
+ if phase != "tool_call" {
16
+ return false
17
+ }
18
+ // 排除 mcp / search 类型的 tool call
19
+ if strings.Contains(editContent, `"mcp"`) || strings.Contains(editContent, `mcp-server`) {
20
+ return false
21
+ }
22
+ if strings.Contains(editContent, `"search_result"`) || strings.Contains(editContent, `"search_image"`) {
23
+ return false
24
+ }
25
+ // 包含函数调用特征
26
+ return strings.Contains(editContent, `"function"`) || strings.Contains(editContent, `"arguments"`)
27
+ }
28
+
29
+ // ParseFunctionToolCalls 从上游 edit_content 解析函数调用
30
+ func ParseFunctionToolCalls(editContent string) []model.ToolCall {
31
+ // 尝试从 glm_block 中提取
32
+ matches := glmToolCallBlockPattern.FindAllStringSubmatch(editContent, -1)
33
+ if len(matches) > 0 {
34
+ var allCalls []model.ToolCall
35
+ for _, match := range matches {
36
+ if calls := parseToolCallJSON(match[1]); len(calls) > 0 {
37
+ allCalls = append(allCalls, calls...)
38
+ }
39
+ }
40
+ if len(allCalls) > 0 {
41
+ return allCalls
42
+ }
43
+ }
44
+
45
+ // 尝试直接解析为 JSON
46
+ return parseToolCallJSON(editContent)
47
+ }
48
+
49
+ // parseToolCallJSON 解析 tool call JSON 数据
50
+ func parseToolCallJSON(content string) []model.ToolCall {
51
+ content = strings.TrimSpace(content)
52
+ if content == "" {
53
+ return nil
54
+ }
55
+
56
+ // 尝试解析为单个 tool call 对象
57
+ var single struct {
58
+ ID string `json:"id"`
59
+ Type string `json:"type"`
60
+ Function struct {
61
+ Name string `json:"name"`
62
+ Arguments string `json:"arguments"`
63
+ } `json:"function"`
64
+ Name string `json:"name"`
65
+ Arguments string `json:"arguments"`
66
+ }
67
+ if err := json.Unmarshal([]byte(content), &single); err == nil {
68
+ if single.Function.Name != "" {
69
+ return []model.ToolCall{{
70
+ ID: single.ID,
71
+ Type: "function",
72
+ Function: model.FunctionCall{
73
+ Name: single.Function.Name,
74
+ Arguments: single.Function.Arguments,
75
+ },
76
+ }}
77
+ }
78
+ if single.Name != "" {
79
+ return []model.ToolCall{{
80
+ ID: single.ID,
81
+ Type: "function",
82
+ Function: model.FunctionCall{
83
+ Name: single.Name,
84
+ Arguments: single.Arguments,
85
+ },
86
+ }}
87
+ }
88
+ }
89
+
90
+ // 尝试解析为数组
91
+ var arr []json.RawMessage
92
+ if err := json.Unmarshal([]byte(content), &arr); err == nil {
93
+ var calls []model.ToolCall
94
+ for _, raw := range arr {
95
+ if parsed := parseToolCallJSON(string(raw)); len(parsed) > 0 {
96
+ calls = append(calls, parsed...)
97
+ }
98
+ }
99
+ return calls
100
+ }
101
+
102
+ return nil
103
+ }
internal/filter/toolcall_test.go ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package filter
2
+
3
+ import (
4
+ "testing"
5
+ )
6
+
7
+ // ===== IsFunctionToolCall =====
8
+
9
+ func TestIsFunctionToolCall_True(t *testing.T) {
10
+ tests := []struct {
11
+ name string
12
+ content string
13
+ phase string
14
+ }{
15
+ {
16
+ name: "标准 function 字段",
17
+ content: `{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{}"}}`,
18
+ phase: "tool_call",
19
+ },
20
+ {
21
+ name: "包含 arguments 字段",
22
+ content: `{"name":"get_weather","arguments":"{\"location\":\"北京\"}"}`,
23
+ phase: "tool_call",
24
+ },
25
+ {
26
+ name: "glm_block 包裹的函数调用",
27
+ content: `<glm_block type="tool_call">{"function":{"name":"fn1","arguments":"{}"}}</glm_block>`,
28
+ phase: "tool_call",
29
+ },
30
+ }
31
+
32
+ for _, tt := range tests {
33
+ t.Run(tt.name, func(t *testing.T) {
34
+ if !IsFunctionToolCall(tt.content, tt.phase) {
35
+ t.Error("expected true")
36
+ }
37
+ })
38
+ }
39
+ }
40
+
41
+ func TestIsFunctionToolCall_False(t *testing.T) {
42
+ tests := []struct {
43
+ name string
44
+ content string
45
+ phase string
46
+ }{
47
+ {
48
+ name: "非 tool_call 阶段",
49
+ content: `{"function":{"name":"get_weather","arguments":"{}"}}`,
50
+ phase: "answer",
51
+ },
52
+ {
53
+ name: "mcp tool call",
54
+ content: `{"type":"mcp","function":{"name":"mcp_tool","arguments":"{}"}}`,
55
+ phase: "tool_call",
56
+ },
57
+ {
58
+ name: "mcp-server tool call",
59
+ content: `mcp-server something with "arguments"`,
60
+ phase: "tool_call",
61
+ },
62
+ {
63
+ name: "search_result 内容",
64
+ content: `{"search_result":[...],"function":"x","arguments":"y"}`,
65
+ phase: "tool_call",
66
+ },
67
+ {
68
+ name: "search_image 内容",
69
+ content: `{"search_image":{},"function":"x","arguments":"y"}`,
70
+ phase: "tool_call",
71
+ },
72
+ {
73
+ name: "无函数调用特征",
74
+ content: `{"type":"tool_call","data":"hello world"}`,
75
+ phase: "tool_call",
76
+ },
77
+ {
78
+ name: "空阶段",
79
+ content: `{"function":{"name":"fn","arguments":"{}"}}`,
80
+ phase: "",
81
+ },
82
+ }
83
+
84
+ for _, tt := range tests {
85
+ t.Run(tt.name, func(t *testing.T) {
86
+ if IsFunctionToolCall(tt.content, tt.phase) {
87
+ t.Error("expected false")
88
+ }
89
+ })
90
+ }
91
+ }
92
+
93
+ // ===== ParseFunctionToolCalls =====
94
+
95
+ func TestParseFunctionToolCalls_StandardFormat(t *testing.T) {
96
+ content := `{"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"北京\"}"}}`
97
+
98
+ calls := ParseFunctionToolCalls(content)
99
+ if len(calls) != 1 {
100
+ t.Fatalf("len(calls) = %d, want 1", len(calls))
101
+ }
102
+ if calls[0].ID != "call_abc" {
103
+ t.Errorf("ID = %q, want %q", calls[0].ID, "call_abc")
104
+ }
105
+ if calls[0].Type != "function" {
106
+ t.Errorf("Type = %q, want %q", calls[0].Type, "function")
107
+ }
108
+ if calls[0].Function.Name != "get_weather" {
109
+ t.Errorf("Function.Name = %q, want %q", calls[0].Function.Name, "get_weather")
110
+ }
111
+ if calls[0].Function.Arguments != `{"location":"北京"}` {
112
+ t.Errorf("Function.Arguments = %q", calls[0].Function.Arguments)
113
+ }
114
+ }
115
+
116
+ func TestParseFunctionToolCalls_FlatFormat(t *testing.T) {
117
+ content := `{"id":"call_flat","name":"get_time","arguments":"{\"timezone\":\"UTC\"}"}`
118
+
119
+ calls := ParseFunctionToolCalls(content)
120
+ if len(calls) != 1 {
121
+ t.Fatalf("len(calls) = %d, want 1", len(calls))
122
+ }
123
+ if calls[0].Function.Name != "get_time" {
124
+ t.Errorf("Function.Name = %q, want %q", calls[0].Function.Name, "get_time")
125
+ }
126
+ if calls[0].Function.Arguments != `{"timezone":"UTC"}` {
127
+ t.Errorf("Function.Arguments = %q", calls[0].Function.Arguments)
128
+ }
129
+ if calls[0].Type != "function" {
130
+ t.Errorf("Type = %q, want %q", calls[0].Type, "function")
131
+ }
132
+ }
133
+
134
+ func TestParseFunctionToolCalls_GlmBlock(t *testing.T) {
135
+ content := `一些文本<glm_block type="tool_call">{"id":"call_glm","type":"function","function":{"name":"search","arguments":"{\"q\":\"test\"}"}}</glm_block>后续文本`
136
+
137
+ calls := ParseFunctionToolCalls(content)
138
+ if len(calls) != 1 {
139
+ t.Fatalf("len(calls) = %d, want 1", len(calls))
140
+ }
141
+ if calls[0].ID != "call_glm" {
142
+ t.Errorf("ID = %q, want %q", calls[0].ID, "call_glm")
143
+ }
144
+ if calls[0].Function.Name != "search" {
145
+ t.Errorf("Function.Name = %q, want %q", calls[0].Function.Name, "search")
146
+ }
147
+ }
148
+
149
+ func TestParseFunctionToolCalls_MultipleGlmBlocks(t *testing.T) {
150
+ content := `<glm_block type="tool_call">{"function":{"name":"fn1","arguments":"{}"}}</glm_block>` +
151
+ `<glm_block type="tool_call">{"function":{"name":"fn2","arguments":"{}"}}</glm_block>`
152
+
153
+ calls := ParseFunctionToolCalls(content)
154
+ if len(calls) != 2 {
155
+ t.Fatalf("len(calls) = %d, want 2", len(calls))
156
+ }
157
+ if calls[0].Function.Name != "fn1" {
158
+ t.Errorf("calls[0].Function.Name = %q, want %q", calls[0].Function.Name, "fn1")
159
+ }
160
+ if calls[1].Function.Name != "fn2" {
161
+ t.Errorf("calls[1].Function.Name = %q, want %q", calls[1].Function.Name, "fn2")
162
+ }
163
+ }
164
+
165
+ func TestParseFunctionToolCalls_Array(t *testing.T) {
166
+ content := `[{"id":"c1","type":"function","function":{"name":"fn1","arguments":"{}"}},{"id":"c2","type":"function","function":{"name":"fn2","arguments":"{}"}}]`
167
+
168
+ calls := ParseFunctionToolCalls(content)
169
+ if len(calls) != 2 {
170
+ t.Fatalf("len(calls) = %d, want 2", len(calls))
171
+ }
172
+ if calls[0].Function.Name != "fn1" {
173
+ t.Errorf("calls[0].Function.Name = %q", calls[0].Function.Name)
174
+ }
175
+ if calls[1].Function.Name != "fn2" {
176
+ t.Errorf("calls[1].Function.Name = %q", calls[1].Function.Name)
177
+ }
178
+ }
179
+
180
+ func TestParseFunctionToolCalls_NoID(t *testing.T) {
181
+ content := `{"type":"function","function":{"name":"get_weather","arguments":"{}"}}`
182
+
183
+ calls := ParseFunctionToolCalls(content)
184
+ if len(calls) != 1 {
185
+ t.Fatalf("len(calls) = %d, want 1", len(calls))
186
+ }
187
+ if calls[0].ID != "" {
188
+ t.Errorf("ID = %q, want empty (caller assigns ID)", calls[0].ID)
189
+ }
190
+ }
191
+
192
+ func TestParseFunctionToolCalls_EmptyContent(t *testing.T) {
193
+ calls := ParseFunctionToolCalls("")
194
+ if len(calls) != 0 {
195
+ t.Errorf("len(calls) = %d, want 0", len(calls))
196
+ }
197
+ }
198
+
199
+ func TestParseFunctionToolCalls_WhitespaceOnly(t *testing.T) {
200
+ calls := ParseFunctionToolCalls(" \n\t ")
201
+ if len(calls) != 0 {
202
+ t.Errorf("len(calls) = %d, want 0", len(calls))
203
+ }
204
+ }
205
+
206
+ func TestParseFunctionToolCalls_InvalidJSON(t *testing.T) {
207
+ calls := ParseFunctionToolCalls("not json at all {{{")
208
+ if len(calls) != 0 {
209
+ t.Errorf("len(calls) = %d, want 0", len(calls))
210
+ }
211
+ }
212
+
213
+ func TestParseFunctionToolCalls_JSONWithoutFunctionFields(t *testing.T) {
214
+ calls := ParseFunctionToolCalls(`{"type":"something","data":"hello"}`)
215
+ if len(calls) != 0 {
216
+ t.Errorf("len(calls) = %d, want 0", len(calls))
217
+ }
218
+ }
219
+
220
+ func TestParseFunctionToolCalls_EmptyArray(t *testing.T) {
221
+ calls := ParseFunctionToolCalls(`[]`)
222
+ if len(calls) != 0 {
223
+ t.Errorf("len(calls) = %d, want 0", len(calls))
224
+ }
225
+ }
226
+
227
+ func TestParseFunctionToolCalls_ComplexArguments(t *testing.T) {
228
+ content := `{"function":{"name":"create_order","arguments":"{\"items\":[{\"id\":1,\"qty\":2},{\"id\":3,\"qty\":1}],\"user\":\"张三\"}"}}`
229
+
230
+ calls := ParseFunctionToolCalls(content)
231
+ if len(calls) != 1 {
232
+ t.Fatalf("len(calls) = %d, want 1", len(calls))
233
+ }
234
+ if calls[0].Function.Name != "create_order" {
235
+ t.Errorf("Function.Name = %q", calls[0].Function.Name)
236
+ }
237
+ // 确保复杂 JSON 参数完整保留
238
+ if calls[0].Function.Arguments == "" {
239
+ t.Error("Function.Arguments is empty")
240
+ }
241
+ }
242
+
243
+ func TestParseFunctionToolCalls_GlmBlockWithExtraAttrs(t *testing.T) {
244
+ content := `<glm_block id="123" type="tool_call" status="pending">{"function":{"name":"fn1","arguments":"{}"}}</glm_block>`
245
+
246
+ calls := ParseFunctionToolCalls(content)
247
+ if len(calls) != 1 {
248
+ t.Fatalf("len(calls) = %d, want 1", len(calls))
249
+ }
250
+ if calls[0].Function.Name != "fn1" {
251
+ t.Errorf("Function.Name = %q, want %q", calls[0].Function.Name, "fn1")
252
+ }
253
+ }
254
+
255
+ func TestParseFunctionToolCalls_GlmBlockInvalidJSON(t *testing.T) {
256
+ content := `<glm_block type="tool_call">not valid json</glm_block>`
257
+
258
+ calls := ParseFunctionToolCalls(content)
259
+ if len(calls) != 0 {
260
+ t.Errorf("len(calls) = %d, want 0", len(calls))
261
+ }
262
+ }
263
+
264
+ // ===== 优先级:glm_block 优先于原始 JSON =====
265
+
266
+ func TestParseFunctionToolCalls_GlmBlockPriority(t *testing.T) {
267
+ // 如果同时存在 glm_block 和外层 JSON,优先从 glm_block 提取
268
+ content := `<glm_block type="tool_call">{"function":{"name":"from_block","arguments":"{}"}}</glm_block>`
269
+
270
+ calls := ParseFunctionToolCalls(content)
271
+ if len(calls) != 1 {
272
+ t.Fatalf("len(calls) = %d, want 1", len(calls))
273
+ }
274
+ if calls[0].Function.Name != "from_block" {
275
+ t.Errorf("Function.Name = %q, want %q", calls[0].Function.Name, "from_block")
276
+ }
277
+ }
internal/handler/chat.go CHANGED
@@ -45,7 +45,7 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
45
  req.Model = "GLM-4.6"
46
  }
47
 
48
- resp, modelName, err := upstream.MakeUpstreamRequest(token, req.Messages, req.Model)
49
  if err != nil {
50
  logger.LogError("Upstream request failed: %v", err)
51
  http.Error(w, "Upstream error", http.StatusBadGateway)
@@ -67,13 +67,13 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
67
  completionID := fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:29])
68
 
69
  if req.Stream {
70
- handleStreamResponse(w, resp.Body, completionID, modelName)
71
  } else {
72
- handleNonStreamResponse(w, resp.Body, completionID, modelName)
73
  }
74
  }
75
 
76
- func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionID, modelName string) {
77
  w.Header().Set("Content-Type", "text/event-stream")
78
  w.Header().Set("Cache-Control", "no-cache")
79
  w.Header().Set("Connection", "keep-alive")
@@ -92,6 +92,8 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
92
  pendingSourcesMarkdown := ""
93
  pendingImageSearchMarkdown := ""
94
  totalContentOutputLength := 0
 
 
95
 
96
  for scanner.Scan() {
97
  line := scanner.Text()
@@ -221,6 +223,43 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
221
  if editContent != "" && filter.IsSearchToolCall(editContent, upstreamData.Data.Phase) {
222
  continue
223
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
  if pendingSourcesMarkdown != "" {
226
  hasContent = true
@@ -410,6 +449,9 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
410
  }
411
 
412
  stopReason := "stop"
 
 
 
413
  finalChunk := model.ChatCompletionChunk{
414
  ID: completionID,
415
  Object: "chat.completion.chunk",
@@ -428,7 +470,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
428
  flusher.Flush()
429
  }
430
 
431
- func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionID, modelName string) {
432
  scanner := bufio.NewScanner(body)
433
  scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
434
  var chunks []string
@@ -438,6 +480,7 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
438
  hasThinking := false
439
  pendingSourcesMarkdown := ""
440
  pendingImageSearchMarkdown := ""
 
441
 
442
  for scanner.Scan() {
443
  line := scanner.Text()
@@ -510,6 +553,22 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
510
  if editContent != "" && filter.IsSearchToolCall(editContent, upstreamData.Data.Phase) {
511
  continue
512
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
 
514
  if pendingSourcesMarkdown != "" {
515
  if hasThinking {
@@ -557,11 +616,14 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
557
  fullReasoning := strings.Join(reasoningChunks, "")
558
  fullReasoning = searchRefFilter.Process(fullReasoning) + searchRefFilter.Flush()
559
 
560
- if fullContent == "" {
561
  logger.LogError("Non-stream response 200 but no content received")
562
  }
563
 
564
  stopReason := "stop"
 
 
 
565
  response := model.ChatCompletionResponse{
566
  ID: completionID,
567
  Object: "chat.completion",
@@ -573,6 +635,7 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
573
  Role: "assistant",
574
  Content: fullContent,
575
  ReasoningContent: fullReasoning,
 
576
  },
577
  FinishReason: &stopReason,
578
  }},
 
45
  req.Model = "GLM-4.6"
46
  }
47
 
48
+ resp, modelName, err := upstream.MakeUpstreamRequest(token, req.Messages, req.Model, req.Tools, req.ToolChoice)
49
  if err != nil {
50
  logger.LogError("Upstream request failed: %v", err)
51
  http.Error(w, "Upstream error", http.StatusBadGateway)
 
67
  completionID := fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:29])
68
 
69
  if req.Stream {
70
+ handleStreamResponse(w, resp.Body, completionID, modelName, req.Tools)
71
  } else {
72
+ handleNonStreamResponse(w, resp.Body, completionID, modelName, req.Tools)
73
  }
74
  }
75
 
76
+ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionID, modelName string, tools []model.Tool) {
77
  w.Header().Set("Content-Type", "text/event-stream")
78
  w.Header().Set("Cache-Control", "no-cache")
79
  w.Header().Set("Connection", "keep-alive")
 
92
  pendingSourcesMarkdown := ""
93
  pendingImageSearchMarkdown := ""
94
  totalContentOutputLength := 0
95
+ hasToolCalls := false
96
+ var collectedToolCalls []model.ToolCall
97
 
98
  for scanner.Scan() {
99
  line := scanner.Text()
 
223
  if editContent != "" && filter.IsSearchToolCall(editContent, upstreamData.Data.Phase) {
224
  continue
225
  }
226
+ // 检测用户定义的函数调用(tool_call 阶段,非 mcp/search)
227
+ if upstreamData.Data.Phase == "tool_call" && editContent != "" {
228
+ logger.LogInfo("[ToolCall] phase=%s edit_content=%s", upstreamData.Data.Phase, editContent)
229
+ }
230
+ if len(tools) > 0 && editContent != "" && filter.IsFunctionToolCall(editContent, upstreamData.Data.Phase) {
231
+ if toolCalls := filter.ParseFunctionToolCalls(editContent); len(toolCalls) > 0 {
232
+ for i := range toolCalls {
233
+ if toolCalls[i].ID == "" {
234
+ toolCalls[i].ID = fmt.Sprintf("call_%s", uuid.New().String()[:24])
235
+ }
236
+ toolCalls[i].Index = i
237
+ }
238
+ collectedToolCalls = toolCalls
239
+ hasToolCalls = true
240
+
241
+ for _, tc := range toolCalls {
242
+ hasContent = true
243
+ chunk := model.ChatCompletionChunk{
244
+ ID: completionID,
245
+ Object: "chat.completion.chunk",
246
+ Created: time.Now().Unix(),
247
+ Model: modelName,
248
+ Choices: []model.Choice{{
249
+ Index: 0,
250
+ Delta: model.Delta{
251
+ ToolCalls: []model.ToolCall{tc},
252
+ },
253
+ FinishReason: nil,
254
+ }},
255
+ }
256
+ data, _ := json.Marshal(chunk)
257
+ fmt.Fprintf(w, "data: %s\n\n", data)
258
+ flusher.Flush()
259
+ }
260
+ }
261
+ continue
262
+ }
263
 
264
  if pendingSourcesMarkdown != "" {
265
  hasContent = true
 
449
  }
450
 
451
  stopReason := "stop"
452
+ if hasToolCalls && len(collectedToolCalls) > 0 {
453
+ stopReason = "tool_calls"
454
+ }
455
  finalChunk := model.ChatCompletionChunk{
456
  ID: completionID,
457
  Object: "chat.completion.chunk",
 
470
  flusher.Flush()
471
  }
472
 
473
+ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionID, modelName string, tools []model.Tool) {
474
  scanner := bufio.NewScanner(body)
475
  scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
476
  var chunks []string
 
480
  hasThinking := false
481
  pendingSourcesMarkdown := ""
482
  pendingImageSearchMarkdown := ""
483
+ var collectedToolCalls []model.ToolCall
484
 
485
  for scanner.Scan() {
486
  line := scanner.Text()
 
553
  if editContent != "" && filter.IsSearchToolCall(editContent, upstreamData.Data.Phase) {
554
  continue
555
  }
556
+ // 检测用户定义的函数调用
557
+ if upstreamData.Data.Phase == "tool_call" && editContent != "" {
558
+ logger.LogInfo("[ToolCall] phase=%s edit_content=%s", upstreamData.Data.Phase, editContent)
559
+ }
560
+ if len(tools) > 0 && editContent != "" && filter.IsFunctionToolCall(editContent, upstreamData.Data.Phase) {
561
+ if toolCalls := filter.ParseFunctionToolCalls(editContent); len(toolCalls) > 0 {
562
+ for i := range toolCalls {
563
+ if toolCalls[i].ID == "" {
564
+ toolCalls[i].ID = fmt.Sprintf("call_%s", uuid.New().String()[:24])
565
+ }
566
+ toolCalls[i].Index = i
567
+ }
568
+ collectedToolCalls = toolCalls
569
+ }
570
+ continue
571
+ }
572
 
573
  if pendingSourcesMarkdown != "" {
574
  if hasThinking {
 
616
  fullReasoning := strings.Join(reasoningChunks, "")
617
  fullReasoning = searchRefFilter.Process(fullReasoning) + searchRefFilter.Flush()
618
 
619
+ if fullContent == "" && len(collectedToolCalls) == 0 {
620
  logger.LogError("Non-stream response 200 but no content received")
621
  }
622
 
623
  stopReason := "stop"
624
+ if len(collectedToolCalls) > 0 {
625
+ stopReason = "tool_calls"
626
+ }
627
  response := model.ChatCompletionResponse{
628
  ID: completionID,
629
  Object: "chat.completion",
 
635
  Role: "assistant",
636
  Content: fullContent,
637
  ReasoningContent: fullReasoning,
638
+ ToolCalls: collectedToolCalls,
639
  },
640
  FinishReason: &stopReason,
641
  }},
internal/handler/chat_test.go ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package handler
2
+
3
+ import (
4
+ "encoding/json"
5
+ "fmt"
6
+ "io"
7
+ "net/http/httptest"
8
+ "strings"
9
+ "testing"
10
+
11
+ "zai-proxy/internal/model"
12
+ )
13
+
14
+ // fakeReadCloser 将 string 包装为 io.ReadCloser
15
+ type fakeReadCloser struct {
16
+ io.Reader
17
+ }
18
+
19
+ func (f *fakeReadCloser) Close() error { return nil }
20
+
21
+ func newFakeBody(lines ...string) io.ReadCloser {
22
+ return &fakeReadCloser{Reader: strings.NewReader(strings.Join(lines, "\n"))}
23
+ }
24
+
25
+ // 构造上游 SSE 数据行
26
+ func sseEvent(phase, deltaContent, editContent string) string {
27
+ data := model.UpstreamData{}
28
+ data.Data.Phase = phase
29
+ data.Data.DeltaContent = deltaContent
30
+ data.Data.EditContent = editContent
31
+ b, _ := json.Marshal(data)
32
+ return fmt.Sprintf("data: %s", string(b))
33
+ }
34
+
35
+ func sseEventDone() string {
36
+ return sseEvent("done", "", "")
37
+ }
38
+
39
+ func dummyTools() []model.Tool {
40
+ return []model.Tool{{
41
+ Type: "function",
42
+ Function: model.ToolFunction{
43
+ Name: "get_weather",
44
+ Description: "获取天气",
45
+ },
46
+ }}
47
+ }
48
+
49
+ // ===== 流式:普通文本回复 =====
50
+
51
+ func TestStreamResponse_NormalContent(t *testing.T) {
52
+ body := newFakeBody(
53
+ sseEvent("answer", "Hello", ""),
54
+ sseEvent("answer", " World", ""),
55
+ sseEventDone(),
56
+ )
57
+
58
+ w := httptest.NewRecorder()
59
+ handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
60
+
61
+ result := w.Body.String()
62
+
63
+ // 应包含内容 chunk
64
+ if !strings.Contains(result, "Hello") {
65
+ t.Error("missing 'Hello' in stream output")
66
+ }
67
+ if !strings.Contains(result, "World") {
68
+ t.Error("missing 'World' in stream output")
69
+ }
70
+
71
+ // finish_reason 应该是 "stop"
72
+ if !strings.Contains(result, `"finish_reason":"stop"`) {
73
+ t.Error("finish_reason should be 'stop'")
74
+ }
75
+
76
+ // 应以 [DONE] 结尾
77
+ if !strings.Contains(result, "data: [DONE]") {
78
+ t.Error("missing [DONE]")
79
+ }
80
+ }
81
+
82
+ // ===== 流式:tool_call 回复 =====
83
+
84
+ func TestStreamResponse_ToolCall(t *testing.T) {
85
+ toolCallJSON := `{"id":"call_test123","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"北京\"}"}}`
86
+
87
+ body := newFakeBody(
88
+ sseEvent("tool_call", "", toolCallJSON),
89
+ sseEventDone(),
90
+ )
91
+
92
+ w := httptest.NewRecorder()
93
+ handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
94
+
95
+ result := w.Body.String()
96
+
97
+ // 应包含 tool_calls
98
+ if !strings.Contains(result, `"tool_calls"`) {
99
+ t.Error("missing tool_calls in stream output")
100
+ }
101
+ if !strings.Contains(result, `"get_weather"`) {
102
+ t.Error("missing function name in stream output")
103
+ }
104
+ if !strings.Contains(result, `call_test123`) {
105
+ t.Error("missing tool call ID in stream output")
106
+ }
107
+
108
+ // finish_reason 应该是 "tool_calls"
109
+ if !strings.Contains(result, `"finish_reason":"tool_calls"`) {
110
+ t.Error("finish_reason should be 'tool_calls'")
111
+ }
112
+ }
113
+
114
+ // ===== 流式:tool_call 无 ID(自动分配)=====
115
+
116
+ func TestStreamResponse_ToolCallAutoID(t *testing.T) {
117
+ toolCallJSON := `{"type":"function","function":{"name":"get_weather","arguments":"{}"}}`
118
+
119
+ body := newFakeBody(
120
+ sseEvent("tool_call", "", toolCallJSON),
121
+ sseEventDone(),
122
+ )
123
+
124
+ w := httptest.NewRecorder()
125
+ handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
126
+
127
+ result := w.Body.String()
128
+
129
+ // 应自动分配 call_ 前缀的 ID
130
+ if !strings.Contains(result, `"id":"call_`) {
131
+ t.Error("missing auto-generated tool call ID")
132
+ }
133
+ if !strings.Contains(result, `"finish_reason":"tool_calls"`) {
134
+ t.Error("finish_reason should be 'tool_calls'")
135
+ }
136
+ }
137
+
138
+ // ===== 流式:无 tools 时 tool_call 阶段被忽略 =====
139
+
140
+ func TestStreamResponse_ToolCallWithoutToolsDef(t *testing.T) {
141
+ toolCallJSON := `{"type":"function","function":{"name":"get_weather","arguments":"{}"}}`
142
+
143
+ body := newFakeBody(
144
+ sseEvent("answer", "text before", ""),
145
+ sseEvent("tool_call", "", toolCallJSON),
146
+ sseEventDone(),
147
+ )
148
+
149
+ w := httptest.NewRecorder()
150
+ // 不传 tools,tool_call 不应被解析为函数调用
151
+ handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
152
+
153
+ result := w.Body.String()
154
+
155
+ // finish_reason 应为 "stop"(没有检测到函数调用)
156
+ if !strings.Contains(result, `"finish_reason":"stop"`) {
157
+ t.Error("finish_reason should be 'stop' when no tools defined")
158
+ }
159
+ }
160
+
161
+ // ===== 流式:mcp tool_call 被跳过 =====
162
+
163
+ func TestStreamResponse_McpToolCallSkipped(t *testing.T) {
164
+ mcpContent := `{"type":"mcp","name":"mcp-server-xxx","arguments":"{}"}`
165
+
166
+ body := newFakeBody(
167
+ sseEvent("answer", "response text", ""),
168
+ sseEvent("tool_call", "", mcpContent),
169
+ sseEventDone(),
170
+ )
171
+
172
+ w := httptest.NewRecorder()
173
+ handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
174
+
175
+ result := w.Body.String()
176
+
177
+ // mcp 类型的 tool_call 不应出现在输出中
178
+ if strings.Contains(result, `mcp-server`) {
179
+ t.Error("mcp tool call should be filtered out")
180
+ }
181
+ // 应为 "stop"(mcp 不算用户函数调用)
182
+ if !strings.Contains(result, `"finish_reason":"stop"`) {
183
+ t.Error("finish_reason should be 'stop'")
184
+ }
185
+ }
186
+
187
+ // ===== 流式:混合内容 + tool_call =====
188
+
189
+ func TestStreamResponse_ContentThenToolCall(t *testing.T) {
190
+ toolCallJSON := `{"function":{"name":"get_weather","arguments":"{}"}}`
191
+
192
+ body := newFakeBody(
193
+ sseEvent("answer", "Let me check ", ""),
194
+ sseEvent("answer", "the weather.", ""),
195
+ sseEvent("tool_call", "", toolCallJSON),
196
+ sseEventDone(),
197
+ )
198
+
199
+ w := httptest.NewRecorder()
200
+ handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
201
+
202
+ result := w.Body.String()
203
+
204
+ if !strings.Contains(result, "Let me check") {
205
+ t.Error("missing content text")
206
+ }
207
+ if !strings.Contains(result, `"get_weather"`) {
208
+ t.Error("missing tool call")
209
+ }
210
+ if !strings.Contains(result, `"finish_reason":"tool_calls"`) {
211
+ t.Error("finish_reason should be 'tool_calls'")
212
+ }
213
+ }
214
+
215
+ // ===== 流式:多个 tool_call =====
216
+
217
+ func TestStreamResponse_MultipleToolCalls(t *testing.T) {
218
+ toolCallJSON := `[{"id":"c1","type":"function","function":{"name":"fn1","arguments":"{}"}},{"id":"c2","type":"function","function":{"name":"fn2","arguments":"{}"}}]`
219
+
220
+ body := newFakeBody(
221
+ sseEvent("tool_call", "", toolCallJSON),
222
+ sseEventDone(),
223
+ )
224
+
225
+ w := httptest.NewRecorder()
226
+ handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
227
+
228
+ result := w.Body.String()
229
+
230
+ if !strings.Contains(result, `"fn1"`) {
231
+ t.Error("missing fn1")
232
+ }
233
+ if !strings.Contains(result, `"fn2"`) {
234
+ t.Error("missing fn2")
235
+ }
236
+
237
+ // 验证 chunk 数量:每个 tool_call 一个 delta chunk(包含 "tool_calls" 在 delta 中)
238
+ chunks := strings.Split(result, "data: ")
239
+ toolCallDeltaChunks := 0
240
+ for _, chunk := range chunks {
241
+ // 只计算 delta 中包含 tool_calls 的 chunk,排除 finish_reason 中的
242
+ if strings.Contains(chunk, `"tool_calls":[{`) {
243
+ toolCallDeltaChunks++
244
+ }
245
+ }
246
+ if toolCallDeltaChunks != 2 {
247
+ t.Errorf("tool_call delta chunks = %d, want 2", toolCallDeltaChunks)
248
+ }
249
+ }
250
+
251
+ // ===== 非流式:普通文本回复 =====
252
+
253
+ func TestNonStreamResponse_NormalContent(t *testing.T) {
254
+ body := newFakeBody(
255
+ sseEvent("answer", "Hello World", ""),
256
+ sseEventDone(),
257
+ )
258
+
259
+ w := httptest.NewRecorder()
260
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
261
+
262
+ var resp model.ChatCompletionResponse
263
+ if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
264
+ t.Fatalf("decode response: %v", err)
265
+ }
266
+
267
+ if len(resp.Choices) != 1 {
268
+ t.Fatalf("len(Choices) = %d", len(resp.Choices))
269
+ }
270
+ if resp.Choices[0].Message == nil {
271
+ t.Fatal("Message is nil")
272
+ }
273
+ if resp.Choices[0].Message.Content != "Hello World" {
274
+ t.Errorf("Content = %q, want %q", resp.Choices[0].Message.Content, "Hello World")
275
+ }
276
+ if *resp.Choices[0].FinishReason != "stop" {
277
+ t.Errorf("FinishReason = %q, want %q", *resp.Choices[0].FinishReason, "stop")
278
+ }
279
+ if len(resp.Choices[0].Message.ToolCalls) != 0 {
280
+ t.Errorf("len(ToolCalls) = %d, want 0", len(resp.Choices[0].Message.ToolCalls))
281
+ }
282
+ }
283
+
284
+ // ===== 非流式:tool_call 回复 =====
285
+
286
+ func TestNonStreamResponse_ToolCall(t *testing.T) {
287
+ toolCallJSON := `{"id":"call_ns","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"上海\"}"}}`
288
+
289
+ body := newFakeBody(
290
+ sseEvent("tool_call", "", toolCallJSON),
291
+ sseEventDone(),
292
+ )
293
+
294
+ w := httptest.NewRecorder()
295
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
296
+
297
+ var resp model.ChatCompletionResponse
298
+ if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
299
+ t.Fatalf("decode: %v", err)
300
+ }
301
+
302
+ msg := resp.Choices[0].Message
303
+ if msg == nil {
304
+ t.Fatal("Message is nil")
305
+ }
306
+ if len(msg.ToolCalls) != 1 {
307
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
308
+ }
309
+ if msg.ToolCalls[0].Function.Name != "get_weather" {
310
+ t.Errorf("Function.Name = %q, want %q", msg.ToolCalls[0].Function.Name, "get_weather")
311
+ }
312
+ if msg.ToolCalls[0].Function.Arguments != `{"location":"上海"}` {
313
+ t.Errorf("Function.Arguments = %q", msg.ToolCalls[0].Function.Arguments)
314
+ }
315
+ if *resp.Choices[0].FinishReason != "tool_calls" {
316
+ t.Errorf("FinishReason = %q, want %q", *resp.Choices[0].FinishReason, "tool_calls")
317
+ }
318
+ }
319
+
320
+ // ===== 非流式:tool_call 无 ID =====
321
+
322
+ func TestNonStreamResponse_ToolCallAutoID(t *testing.T) {
323
+ toolCallJSON := `{"function":{"name":"fn1","arguments":"{}"}}`
324
+
325
+ body := newFakeBody(
326
+ sseEvent("tool_call", "", toolCallJSON),
327
+ sseEventDone(),
328
+ )
329
+
330
+ w := httptest.NewRecorder()
331
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
332
+
333
+ var resp model.ChatCompletionResponse
334
+ json.NewDecoder(w.Body).Decode(&resp)
335
+
336
+ msg := resp.Choices[0].Message
337
+ if len(msg.ToolCalls) != 1 {
338
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
339
+ }
340
+ if !strings.HasPrefix(msg.ToolCalls[0].ID, "call_") {
341
+ t.Errorf("ID = %q, should have 'call_' prefix", msg.ToolCalls[0].ID)
342
+ }
343
+ }
344
+
345
+ // ===== 非流式:无 tools 定义时不解析 tool_call =====
346
+
347
+ func TestNonStreamResponse_ToolCallWithoutToolsDef(t *testing.T) {
348
+ toolCallJSON := `{"function":{"name":"get_weather","arguments":"{}"}}`
349
+
350
+ body := newFakeBody(
351
+ sseEvent("tool_call", "", toolCallJSON),
352
+ sseEventDone(),
353
+ )
354
+
355
+ w := httptest.NewRecorder()
356
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
357
+
358
+ var resp model.ChatCompletionResponse
359
+ json.NewDecoder(w.Body).Decode(&resp)
360
+
361
+ if *resp.Choices[0].FinishReason != "stop" {
362
+ t.Errorf("FinishReason = %q, want %q", *resp.Choices[0].FinishReason, "stop")
363
+ }
364
+ if len(resp.Choices[0].Message.ToolCalls) != 0 {
365
+ t.Errorf("len(ToolCalls) = %d, want 0", len(resp.Choices[0].Message.ToolCalls))
366
+ }
367
+ }
368
+
369
+ // ===== 非流式:mcp tool_call 被跳过 =====
370
+
371
+ func TestNonStreamResponse_McpToolCallSkipped(t *testing.T) {
372
+ mcpContent := `{"type":"mcp","name":"mcp-server-xxx","arguments":"{}"}`
373
+
374
+ body := newFakeBody(
375
+ sseEvent("answer", "response", ""),
376
+ sseEvent("tool_call", "", mcpContent),
377
+ sseEventDone(),
378
+ )
379
+
380
+ w := httptest.NewRecorder()
381
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
382
+
383
+ var resp model.ChatCompletionResponse
384
+ json.NewDecoder(w.Body).Decode(&resp)
385
+
386
+ if *resp.Choices[0].FinishReason != "stop" {
387
+ t.Errorf("FinishReason = %q, want %q", *resp.Choices[0].FinishReason, "stop")
388
+ }
389
+ if len(resp.Choices[0].Message.ToolCalls) != 0 {
390
+ t.Errorf("should not have tool_calls for mcp")
391
+ }
392
+ }
393
+
394
+ // ===== 非流式:内容 + tool_call =====
395
+
396
+ func TestNonStreamResponse_ContentAndToolCall(t *testing.T) {
397
+ toolCallJSON := `{"function":{"name":"get_weather","arguments":"{}"}}`
398
+
399
+ body := newFakeBody(
400
+ sseEvent("answer", "checking weather...", ""),
401
+ sseEvent("tool_call", "", toolCallJSON),
402
+ sseEventDone(),
403
+ )
404
+
405
+ w := httptest.NewRecorder()
406
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
407
+
408
+ var resp model.ChatCompletionResponse
409
+ json.NewDecoder(w.Body).Decode(&resp)
410
+
411
+ msg := resp.Choices[0].Message
412
+ if msg.Content != "checking weather..." {
413
+ t.Errorf("Content = %q, want %q", msg.Content, "checking weather...")
414
+ }
415
+ if len(msg.ToolCalls) != 1 {
416
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
417
+ }
418
+ if *resp.Choices[0].FinishReason != "tool_calls" {
419
+ t.Errorf("FinishReason = %q, want %q", *resp.Choices[0].FinishReason, "tool_calls")
420
+ }
421
+ }
422
+
423
+ // ===== 非流式:多个 tool_call =====
424
+
425
+ func TestNonStreamResponse_MultipleToolCalls(t *testing.T) {
426
+ toolCallJSON := `[{"id":"c1","type":"function","function":{"name":"fn1","arguments":"{}"}},{"id":"c2","type":"function","function":{"name":"fn2","arguments":"{\"x\":1}"}}]`
427
+
428
+ body := newFakeBody(
429
+ sseEvent("tool_call", "", toolCallJSON),
430
+ sseEventDone(),
431
+ )
432
+
433
+ w := httptest.NewRecorder()
434
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
435
+
436
+ var resp model.ChatCompletionResponse
437
+ json.NewDecoder(w.Body).Decode(&resp)
438
+
439
+ msg := resp.Choices[0].Message
440
+ if len(msg.ToolCalls) != 2 {
441
+ t.Fatalf("len(ToolCalls) = %d, want 2", len(msg.ToolCalls))
442
+ }
443
+ if msg.ToolCalls[0].Function.Name != "fn1" {
444
+ t.Errorf("ToolCalls[0].Function.Name = %q", msg.ToolCalls[0].Function.Name)
445
+ }
446
+ if msg.ToolCalls[1].Function.Name != "fn2" {
447
+ t.Errorf("ToolCalls[1].Function.Name = %q", msg.ToolCalls[1].Function.Name)
448
+ }
449
+ if msg.ToolCalls[0].Index != 0 || msg.ToolCalls[1].Index != 1 {
450
+ t.Errorf("Indices = [%d, %d], want [0, 1]", msg.ToolCalls[0].Index, msg.ToolCalls[1].Index)
451
+ }
452
+ }
453
+
454
+ // ===== 非流式:glm_block 包裹的 tool_call =====
455
+
456
+ func TestNonStreamResponse_GlmBlockToolCall(t *testing.T) {
457
+ editContent := `<glm_block type="tool_call">{"id":"call_glm","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"深圳\"}"}}</glm_block>`
458
+
459
+ body := newFakeBody(
460
+ sseEvent("tool_call", "", editContent),
461
+ sseEventDone(),
462
+ )
463
+
464
+ w := httptest.NewRecorder()
465
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
466
+
467
+ var resp model.ChatCompletionResponse
468
+ json.NewDecoder(w.Body).Decode(&resp)
469
+
470
+ msg := resp.Choices[0].Message
471
+ if len(msg.ToolCalls) != 1 {
472
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
473
+ }
474
+ if msg.ToolCalls[0].ID != "call_glm" {
475
+ t.Errorf("ID = %q, want %q", msg.ToolCalls[0].ID, "call_glm")
476
+ }
477
+ if msg.ToolCalls[0].Function.Name != "get_weather" {
478
+ t.Errorf("Function.Name = %q", msg.ToolCalls[0].Function.Name)
479
+ }
480
+ if *resp.Choices[0].FinishReason != "tool_calls" {
481
+ t.Errorf("FinishReason = %q", *resp.Choices[0].FinishReason)
482
+ }
483
+ }
484
+
485
+ // ===== 流式:SSE headers 验证 =====
486
+
487
+ func TestStreamResponse_Headers(t *testing.T) {
488
+ body := newFakeBody(sseEventDone())
489
+
490
+ w := httptest.NewRecorder()
491
+ handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
492
+
493
+ if ct := w.Header().Get("Content-Type"); ct != "text/event-stream" {
494
+ t.Errorf("Content-Type = %q, want %q", ct, "text/event-stream")
495
+ }
496
+ if cc := w.Header().Get("Cache-Control"); cc != "no-cache" {
497
+ t.Errorf("Cache-Control = %q, want %q", cc, "no-cache")
498
+ }
499
+ }
500
+
501
+ // ===== 非流式:response headers 验证 =====
502
+
503
+ func TestNonStreamResponse_Headers(t *testing.T) {
504
+ body := newFakeBody(sseEventDone())
505
+
506
+ w := httptest.NewRecorder()
507
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
508
+
509
+ if ct := w.Header().Get("Content-Type"); ct != "application/json" {
510
+ t.Errorf("Content-Type = %q, want %q", ct, "application/json")
511
+ }
512
+ }
513
+
514
+ // ===== 流式:空数据 =====
515
+
516
+ func TestStreamResponse_EmptyBody(t *testing.T) {
517
+ body := newFakeBody(sseEventDone())
518
+
519
+ w := httptest.NewRecorder()
520
+ handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
521
+
522
+ result := w.Body.String()
523
+ if !strings.Contains(result, `"finish_reason":"stop"`) {
524
+ t.Error("should have stop finish_reason")
525
+ }
526
+ if !strings.Contains(result, "data: [DONE]") {
527
+ t.Error("missing [DONE]")
528
+ }
529
+ }
530
+
531
+ // ===== 流式:[DONE] 信号 =====
532
+
533
+ func TestStreamResponse_DoneSignal(t *testing.T) {
534
+ body := newFakeBody(
535
+ sseEvent("answer", "hello", ""),
536
+ "data: [DONE]",
537
+ )
538
+
539
+ w := httptest.NewRecorder()
540
+ handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
541
+
542
+ result := w.Body.String()
543
+ if !strings.Contains(result, "hello") {
544
+ t.Error("missing content")
545
+ }
546
+ }
547
+
548
+ // ===== 非流式:response 格式完整性 =====
549
+
550
+ func TestNonStreamResponse_FullFormat(t *testing.T) {
551
+ body := newFakeBody(
552
+ sseEvent("answer", "test response", ""),
553
+ sseEventDone(),
554
+ )
555
+
556
+ w := httptest.NewRecorder()
557
+ handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
558
+
559
+ var resp model.ChatCompletionResponse
560
+ if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
561
+ t.Fatalf("decode: %v", err)
562
+ }
563
+
564
+ if resp.ID != "chatcmpl-test" {
565
+ t.Errorf("ID = %q", resp.ID)
566
+ }
567
+ if resp.Object != "chat.completion" {
568
+ t.Errorf("Object = %q", resp.Object)
569
+ }
570
+ if resp.Model != "glm-4.7" {
571
+ t.Errorf("Model = %q", resp.Model)
572
+ }
573
+ if resp.Choices[0].Message.Role != "assistant" {
574
+ t.Errorf("Role = %q", resp.Choices[0].Message.Role)
575
+ }
576
+ }
internal/model/mapping.go CHANGED
@@ -20,6 +20,8 @@ var ModelList = []string{
20
  "GLM-4.7",
21
  "GLM-4.7-thinking",
22
  "GLM-4.7-thinking-search",
 
 
23
  "GLM-4.5-V",
24
  "GLM-4.6-V",
25
  "GLM-4.6-V-thinking",
@@ -28,13 +30,14 @@ var ModelList = []string{
28
  }
29
 
30
  // 解析模型名称,提取基础模型名和标签
31
- // 支持 -thinking 和 -search 标签的任意排列组合
32
- func ParseModelName(model string) (baseModel string, enableThinking bool, enableSearch bool) {
33
  enableThinking = false
34
  enableSearch = false
 
35
  baseModel = model
36
 
37
- // 检查并移除 -thinking 和 -search 标签(任意顺序)
38
  for {
39
  if strings.HasSuffix(baseModel, "-thinking") {
40
  enableThinking = true
@@ -42,26 +45,34 @@ func ParseModelName(model string) (baseModel string, enableThinking bool, enable
42
  } else if strings.HasSuffix(baseModel, "-search") {
43
  enableSearch = true
44
  baseModel = strings.TrimSuffix(baseModel, "-search")
 
 
 
45
  } else {
46
  break
47
  }
48
  }
49
 
50
- return baseModel, enableThinking, enableSearch
51
  }
52
 
53
  func IsThinkingModel(model string) bool {
54
- _, enableThinking, _ := ParseModelName(model)
55
  return enableThinking
56
  }
57
 
58
  func IsSearchModel(model string) bool {
59
- _, _, enableSearch := ParseModelName(model)
60
  return enableSearch
61
  }
62
 
 
 
 
 
 
63
  func GetTargetModel(model string) string {
64
- baseModel, _, _ := ParseModelName(model)
65
  if target, ok := BaseModelMapping[baseModel]; ok {
66
  return target
67
  }
 
20
  "GLM-4.7",
21
  "GLM-4.7-thinking",
22
  "GLM-4.7-thinking-search",
23
+ "GLM-4.7-tools",
24
+ "GLM-4.7-tools-thinking",
25
  "GLM-4.5-V",
26
  "GLM-4.6-V",
27
  "GLM-4.6-V-thinking",
 
30
  }
31
 
32
  // 解析模型名称,提取基础模型名和标签
33
+ // 支持 -thinking、-search 和 -tools 标签的任意排列组合
34
+ func ParseModelName(model string) (baseModel string, enableThinking bool, enableSearch bool, enableTools bool) {
35
  enableThinking = false
36
  enableSearch = false
37
+ enableTools = false
38
  baseModel = model
39
 
40
+ // 检查并移除 -thinking、-search 和 -tools 标签(任意顺序)
41
  for {
42
  if strings.HasSuffix(baseModel, "-thinking") {
43
  enableThinking = true
 
45
  } else if strings.HasSuffix(baseModel, "-search") {
46
  enableSearch = true
47
  baseModel = strings.TrimSuffix(baseModel, "-search")
48
+ } else if strings.HasSuffix(baseModel, "-tools") {
49
+ enableTools = true
50
+ baseModel = strings.TrimSuffix(baseModel, "-tools")
51
  } else {
52
  break
53
  }
54
  }
55
 
56
+ return baseModel, enableThinking, enableSearch, enableTools
57
  }
58
 
59
  func IsThinkingModel(model string) bool {
60
+ _, enableThinking, _, _ := ParseModelName(model)
61
  return enableThinking
62
  }
63
 
64
  func IsSearchModel(model string) bool {
65
+ _, _, enableSearch, _ := ParseModelName(model)
66
  return enableSearch
67
  }
68
 
69
+ func IsToolsModel(model string) bool {
70
+ _, _, _, enableTools := ParseModelName(model)
71
+ return enableTools
72
+ }
73
+
74
  func GetTargetModel(model string) string {
75
+ baseModel, _, _, _ := ParseModelName(model)
76
  if target, ok := BaseModelMapping[baseModel]; ok {
77
  return target
78
  }
internal/model/mapping_test.go ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package model
2
+
3
+ import "testing"
4
+
5
+ // ===== ParseModelName =====
6
+
7
+ func TestParseModelName_Plain(t *testing.T) {
8
+ base, thinking, search, tools := ParseModelName("GLM-4.7")
9
+ if base != "GLM-4.7" {
10
+ t.Errorf("base = %q, want %q", base, "GLM-4.7")
11
+ }
12
+ if thinking || search || tools {
13
+ t.Errorf("flags = (%v, %v, %v), want all false", thinking, search, tools)
14
+ }
15
+ }
16
+
17
+ func TestParseModelName_Thinking(t *testing.T) {
18
+ base, thinking, search, tools := ParseModelName("GLM-4.7-thinking")
19
+ if base != "GLM-4.7" {
20
+ t.Errorf("base = %q", base)
21
+ }
22
+ if !thinking {
23
+ t.Error("thinking should be true")
24
+ }
25
+ if search || tools {
26
+ t.Error("search and tools should be false")
27
+ }
28
+ }
29
+
30
+ func TestParseModelName_Search(t *testing.T) {
31
+ base, thinking, search, tools := ParseModelName("GLM-4.7-search")
32
+ if base != "GLM-4.7" {
33
+ t.Errorf("base = %q", base)
34
+ }
35
+ if !search {
36
+ t.Error("search should be true")
37
+ }
38
+ if thinking || tools {
39
+ t.Error("thinking and tools should be false")
40
+ }
41
+ }
42
+
43
+ func TestParseModelName_Tools(t *testing.T) {
44
+ base, thinking, search, tools := ParseModelName("GLM-4.7-tools")
45
+ if base != "GLM-4.7" {
46
+ t.Errorf("base = %q", base)
47
+ }
48
+ if !tools {
49
+ t.Error("tools should be true")
50
+ }
51
+ if thinking || search {
52
+ t.Error("thinking and search should be false")
53
+ }
54
+ }
55
+
56
+ func TestParseModelName_ThinkingSearch(t *testing.T) {
57
+ base, thinking, search, tools := ParseModelName("GLM-4.7-thinking-search")
58
+ if base != "GLM-4.7" {
59
+ t.Errorf("base = %q", base)
60
+ }
61
+ if !thinking || !search {
62
+ t.Error("thinking and search should both be true")
63
+ }
64
+ if tools {
65
+ t.Error("tools should be false")
66
+ }
67
+ }
68
+
69
+ func TestParseModelName_ToolsThinking(t *testing.T) {
70
+ base, thinking, search, tools := ParseModelName("GLM-4.7-tools-thinking")
71
+ if base != "GLM-4.7" {
72
+ t.Errorf("base = %q", base)
73
+ }
74
+ if !tools || !thinking {
75
+ t.Error("tools and thinking should both be true")
76
+ }
77
+ if search {
78
+ t.Error("search should be false")
79
+ }
80
+ }
81
+
82
+ func TestParseModelName_ToolsSearch(t *testing.T) {
83
+ base, thinking, search, tools := ParseModelName("GLM-4.7-tools-search")
84
+ if base != "GLM-4.7" {
85
+ t.Errorf("base = %q", base)
86
+ }
87
+ if !tools || !search {
88
+ t.Error("tools and search should both be true")
89
+ }
90
+ if thinking {
91
+ t.Error("thinking should be false")
92
+ }
93
+ }
94
+
95
+ func TestParseModelName_AllTags(t *testing.T) {
96
+ base, thinking, search, tools := ParseModelName("GLM-4.7-tools-thinking-search")
97
+ if base != "GLM-4.7" {
98
+ t.Errorf("base = %q", base)
99
+ }
100
+ if !thinking || !search || !tools {
101
+ t.Errorf("all flags should be true, got (%v, %v, %v)", thinking, search, tools)
102
+ }
103
+ }
104
+
105
+ func TestParseModelName_ReverseOrder(t *testing.T) {
106
+ base, thinking, search, tools := ParseModelName("GLM-4.7-search-thinking-tools")
107
+ if base != "GLM-4.7" {
108
+ t.Errorf("base = %q", base)
109
+ }
110
+ if !thinking || !search || !tools {
111
+ t.Errorf("all flags should be true, got (%v, %v, %v)", thinking, search, tools)
112
+ }
113
+ }
114
+
115
+ // ===== IsToolsModel =====
116
+
117
+ func TestIsToolsModel_True(t *testing.T) {
118
+ tests := []string{
119
+ "GLM-4.7-tools",
120
+ "GLM-4.7-tools-thinking",
121
+ "GLM-4.7-tools-search",
122
+ "GLM-4.7-thinking-tools",
123
+ "GLM-4.5-tools",
124
+ }
125
+ for _, m := range tests {
126
+ if !IsToolsModel(m) {
127
+ t.Errorf("IsToolsModel(%q) = false, want true", m)
128
+ }
129
+ }
130
+ }
131
+
132
+ func TestIsToolsModel_False(t *testing.T) {
133
+ tests := []string{
134
+ "GLM-4.7",
135
+ "GLM-4.7-thinking",
136
+ "GLM-4.7-search",
137
+ "GLM-4.7-thinking-search",
138
+ }
139
+ for _, m := range tests {
140
+ if IsToolsModel(m) {
141
+ t.Errorf("IsToolsModel(%q) = true, want false", m)
142
+ }
143
+ }
144
+ }
145
+
146
+ // ===== IsThinkingModel / IsSearchModel 不受 -tools 影响 =====
147
+
148
+ func TestIsThinkingModel_WithTools(t *testing.T) {
149
+ if !IsThinkingModel("GLM-4.7-tools-thinking") {
150
+ t.Error("IsThinkingModel should be true for GLM-4.7-tools-thinking")
151
+ }
152
+ if IsThinkingModel("GLM-4.7-tools") {
153
+ t.Error("IsThinkingModel should be false for GLM-4.7-tools")
154
+ }
155
+ }
156
+
157
+ func TestIsSearchModel_WithTools(t *testing.T) {
158
+ if !IsSearchModel("GLM-4.7-tools-search") {
159
+ t.Error("IsSearchModel should be true for GLM-4.7-tools-search")
160
+ }
161
+ if IsSearchModel("GLM-4.7-tools") {
162
+ t.Error("IsSearchModel should be false for GLM-4.7-tools")
163
+ }
164
+ }
165
+
166
+ // ===== GetTargetModel with -tools =====
167
+
168
+ func TestGetTargetModel_WithTools(t *testing.T) {
169
+ target := GetTargetModel("GLM-4.7-tools")
170
+ if target != "glm-4.7" {
171
+ t.Errorf("GetTargetModel(GLM-4.7-tools) = %q, want %q", target, "glm-4.7")
172
+ }
173
+ }
174
+
175
+ func TestGetTargetModel_WithToolsThinking(t *testing.T) {
176
+ target := GetTargetModel("GLM-4.7-tools-thinking")
177
+ if target != "glm-4.7" {
178
+ t.Errorf("GetTargetModel(GLM-4.7-tools-thinking) = %q, want %q", target, "glm-4.7")
179
+ }
180
+ }
181
+
182
+ // ===== ModelList 包含 -tools 变体 =====
183
+
184
+ func TestModelList_ContainsToolsVariants(t *testing.T) {
185
+ expected := map[string]bool{
186
+ "GLM-4.7-tools": false,
187
+ "GLM-4.7-tools-thinking": false,
188
+ }
189
+
190
+ for _, m := range ModelList {
191
+ if _, ok := expected[m]; ok {
192
+ expected[m] = true
193
+ }
194
+ }
195
+
196
+ for name, found := range expected {
197
+ if !found {
198
+ t.Errorf("ModelList missing %q", name)
199
+ }
200
+ }
201
+ }
internal/model/types.go CHANGED
@@ -13,10 +13,39 @@ type ImageURL struct {
13
  URL string `json:"url"`
14
  }
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  // Message 支持纯文本和多模态内容
17
  type Message struct {
18
- Role string `json:"role"`
19
- Content interface{} `json:"content"` // string 或 []ContentPart
 
 
20
  }
21
 
22
  // 解析消息内容,返回文本和图片URL列表
@@ -47,6 +76,37 @@ func (m *Message) ParseContent() (text string, imageURLs []string) {
47
 
48
  // 转换为上游消息格式,支持多模态
49
  func (m *Message) ToUpstreamMessage(urlToFileID map[string]string) map[string]interface{} {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  text, imageURLs := m.ParseContent()
51
 
52
  // 无图片,返回纯文本
@@ -83,9 +143,11 @@ func (m *Message) ToUpstreamMessage(urlToFileID map[string]string) map[string]in
83
  }
84
 
85
  type ChatRequest struct {
86
- Model string `json:"model"`
87
- Messages []Message `json:"messages"`
88
- Stream bool `json:"stream"`
 
 
89
  }
90
 
91
  type ChatCompletionChunk struct {
@@ -96,8 +158,6 @@ type ChatCompletionChunk struct {
96
  Choices []Choice `json:"choices"`
97
  }
98
 
99
- type
100
-
101
  type Choice struct {
102
  Index int `json:"index"`
103
  Delta Delta `json:"delta,omitempty"`
@@ -106,14 +166,16 @@ type Choice struct {
106
  }
107
 
108
  type Delta struct {
109
- Content string `json:"content,omitempty"`
110
- ReasoningContent string `json:"reasoning_content,omitempty"`
 
111
  }
112
 
113
  type MessageResp struct {
114
- Role string `json:"role"`
115
- Content string `json:"content"`
116
- ReasoningContent string `json:"reasoning_content,omitempty"`
 
117
  }
118
 
119
  type ChatCompletionResponse struct {
 
13
  URL string `json:"url"`
14
  }
15
 
16
+ // Tool 工具定义(OpenAI 兼容)
17
+ type Tool struct {
18
+ Type string `json:"type"`
19
+ Function ToolFunction `json:"function"`
20
+ }
21
+
22
+ // ToolFunction 函数定义
23
+ type ToolFunction struct {
24
+ Name string `json:"name"`
25
+ Description string `json:"description,omitempty"`
26
+ Parameters interface{} `json:"parameters,omitempty"`
27
+ }
28
+
29
+ // ToolCall 模型返回的工具调用
30
+ type ToolCall struct {
31
+ ID string `json:"id"`
32
+ Type string `json:"type"`
33
+ Function FunctionCall `json:"function"`
34
+ Index int `json:"index"`
35
+ }
36
+
37
+ // FunctionCall 函数调用(名称 + 参数 JSON 字符串)
38
+ type FunctionCall struct {
39
+ Name string `json:"name"`
40
+ Arguments string `json:"arguments"`
41
+ }
42
+
43
  // Message 支持纯文本和多模态内容
44
  type Message struct {
45
+ Role string `json:"role"`
46
+ Content interface{} `json:"content"` // string 或 []ContentPart
47
+ ToolCallID string `json:"tool_call_id,omitempty"` // role: "tool" 时使用
48
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"` // role: "assistant" 时使用
49
  }
50
 
51
  // 解析消息内容,返回文本和图片URL列表
 
76
 
77
  // 转换为上游消息格式,支持多模态
78
  func (m *Message) ToUpstreamMessage(urlToFileID map[string]string) map[string]interface{} {
79
+ // tool 消息:包含 tool_call_id
80
+ if m.Role == "tool" {
81
+ msg := map[string]interface{}{
82
+ "role": m.Role,
83
+ "content": m.Content,
84
+ "tool_call_id": m.ToolCallID,
85
+ }
86
+ return msg
87
+ }
88
+
89
+ // assistant 消息带 tool_calls
90
+ if m.Role == "assistant" && len(m.ToolCalls) > 0 {
91
+ msg := map[string]interface{}{
92
+ "role": m.Role,
93
+ "content": m.Content,
94
+ }
95
+ var toolCalls []map[string]interface{}
96
+ for _, tc := range m.ToolCalls {
97
+ toolCalls = append(toolCalls, map[string]interface{}{
98
+ "id": tc.ID,
99
+ "type": tc.Type,
100
+ "function": map[string]interface{}{
101
+ "name": tc.Function.Name,
102
+ "arguments": tc.Function.Arguments,
103
+ },
104
+ })
105
+ }
106
+ msg["tool_calls"] = toolCalls
107
+ return msg
108
+ }
109
+
110
  text, imageURLs := m.ParseContent()
111
 
112
  // 无图片,返回纯文本
 
143
  }
144
 
145
  type ChatRequest struct {
146
+ Model string `json:"model"`
147
+ Messages []Message `json:"messages"`
148
+ Stream bool `json:"stream"`
149
+ Tools []Tool `json:"tools,omitempty"`
150
+ ToolChoice interface{} `json:"tool_choice,omitempty"`
151
  }
152
 
153
  type ChatCompletionChunk struct {
 
158
  Choices []Choice `json:"choices"`
159
  }
160
 
 
 
161
  type Choice struct {
162
  Index int `json:"index"`
163
  Delta Delta `json:"delta,omitempty"`
 
166
  }
167
 
168
  type Delta struct {
169
+ Content string `json:"content,omitempty"`
170
+ ReasoningContent string `json:"reasoning_content,omitempty"`
171
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
172
  }
173
 
174
  type MessageResp struct {
175
+ Role string `json:"role"`
176
+ Content string `json:"content"`
177
+ ReasoningContent string `json:"reasoning_content,omitempty"`
178
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
179
  }
180
 
181
  type ChatCompletionResponse struct {
internal/model/types_test.go ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package model
2
+
3
+ import (
4
+ "encoding/json"
5
+ "testing"
6
+ )
7
+
8
+ // ===== Tool 类型序列化/反序列化 =====
9
+
10
+ func TestToolJSON(t *testing.T) {
11
+ tool := Tool{
12
+ Type: "function",
13
+ Function: ToolFunction{
14
+ Name: "get_weather",
15
+ Description: "获取天气信息",
16
+ Parameters: map[string]interface{}{
17
+ "type": "object",
18
+ "properties": map[string]interface{}{
19
+ "location": map[string]interface{}{
20
+ "type": "string",
21
+ "description": "城市名称",
22
+ },
23
+ },
24
+ "required": []string{"location"},
25
+ },
26
+ },
27
+ }
28
+
29
+ data, err := json.Marshal(tool)
30
+ if err != nil {
31
+ t.Fatalf("marshal Tool: %v", err)
32
+ }
33
+
34
+ var decoded Tool
35
+ if err := json.Unmarshal(data, &decoded); err != nil {
36
+ t.Fatalf("unmarshal Tool: %v", err)
37
+ }
38
+
39
+ if decoded.Type != "function" {
40
+ t.Errorf("Type = %q, want %q", decoded.Type, "function")
41
+ }
42
+ if decoded.Function.Name != "get_weather" {
43
+ t.Errorf("Function.Name = %q, want %q", decoded.Function.Name, "get_weather")
44
+ }
45
+ if decoded.Function.Description != "获取天气信息" {
46
+ t.Errorf("Function.Description = %q, want %q", decoded.Function.Description, "获取天气信息")
47
+ }
48
+ }
49
+
50
+ func TestToolCallJSON(t *testing.T) {
51
+ tc := ToolCall{
52
+ ID: "call_abc123",
53
+ Type: "function",
54
+ Function: FunctionCall{
55
+ Name: "get_weather",
56
+ Arguments: `{"location":"北京"}`,
57
+ },
58
+ Index: 0,
59
+ }
60
+
61
+ data, err := json.Marshal(tc)
62
+ if err != nil {
63
+ t.Fatalf("marshal ToolCall: %v", err)
64
+ }
65
+
66
+ var decoded ToolCall
67
+ if err := json.Unmarshal(data, &decoded); err != nil {
68
+ t.Fatalf("unmarshal ToolCall: %v", err)
69
+ }
70
+
71
+ if decoded.ID != "call_abc123" {
72
+ t.Errorf("ID = %q, want %q", decoded.ID, "call_abc123")
73
+ }
74
+ if decoded.Function.Name != "get_weather" {
75
+ t.Errorf("Function.Name = %q, want %q", decoded.Function.Name, "get_weather")
76
+ }
77
+ if decoded.Function.Arguments != `{"location":"北京"}` {
78
+ t.Errorf("Function.Arguments = %q, want %q", decoded.Function.Arguments, `{"location":"北京"}`)
79
+ }
80
+ }
81
+
82
+ // ===== ChatRequest 带 Tools 序列化 =====
83
+
84
+ func TestChatRequestWithTools(t *testing.T) {
85
+ reqJSON := `{
86
+ "model": "GLM-4.7",
87
+ "messages": [{"role": "user", "content": "北京天气怎么样?"}],
88
+ "stream": true,
89
+ "tools": [{
90
+ "type": "function",
91
+ "function": {
92
+ "name": "get_weather",
93
+ "description": "获取天气",
94
+ "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}
95
+ }
96
+ }],
97
+ "tool_choice": "auto"
98
+ }`
99
+
100
+ var req ChatRequest
101
+ if err := json.Unmarshal([]byte(reqJSON), &req); err != nil {
102
+ t.Fatalf("unmarshal ChatRequest: %v", err)
103
+ }
104
+
105
+ if req.Model != "GLM-4.7" {
106
+ t.Errorf("Model = %q, want %q", req.Model, "GLM-4.7")
107
+ }
108
+ if len(req.Tools) != 1 {
109
+ t.Fatalf("len(Tools) = %d, want 1", len(req.Tools))
110
+ }
111
+ if req.Tools[0].Function.Name != "get_weather" {
112
+ t.Errorf("Tools[0].Function.Name = %q, want %q", req.Tools[0].Function.Name, "get_weather")
113
+ }
114
+ if req.ToolChoice != "auto" {
115
+ t.Errorf("ToolChoice = %v, want %q", req.ToolChoice, "auto")
116
+ }
117
+ }
118
+
119
+ func TestChatRequestWithoutTools(t *testing.T) {
120
+ reqJSON := `{
121
+ "model": "GLM-4.6",
122
+ "messages": [{"role": "user", "content": "hello"}],
123
+ "stream": false
124
+ }`
125
+
126
+ var req ChatRequest
127
+ if err := json.Unmarshal([]byte(reqJSON), &req); err != nil {
128
+ t.Fatalf("unmarshal ChatRequest: %v", err)
129
+ }
130
+
131
+ if len(req.Tools) != 0 {
132
+ t.Errorf("len(Tools) = %d, want 0", len(req.Tools))
133
+ }
134
+ if req.ToolChoice != nil {
135
+ t.Errorf("ToolChoice = %v, want nil", req.ToolChoice)
136
+ }
137
+ }
138
+
139
+ func TestChatRequestToolChoiceObject(t *testing.T) {
140
+ reqJSON := `{
141
+ "model": "GLM-4.7",
142
+ "messages": [{"role": "user", "content": "test"}],
143
+ "stream": false,
144
+ "tools": [{"type": "function", "function": {"name": "fn1"}}],
145
+ "tool_choice": {"type": "function", "function": {"name": "fn1"}}
146
+ }`
147
+
148
+ var req ChatRequest
149
+ if err := json.Unmarshal([]byte(reqJSON), &req); err != nil {
150
+ t.Fatalf("unmarshal: %v", err)
151
+ }
152
+
153
+ tc, ok := req.ToolChoice.(map[string]interface{})
154
+ if !ok {
155
+ t.Fatalf("ToolChoice type = %T, want map[string]interface{}", req.ToolChoice)
156
+ }
157
+ if tc["type"] != "function" {
158
+ t.Errorf("ToolChoice.type = %v, want %q", tc["type"], "function")
159
+ }
160
+ }
161
+
162
+ // ===== Message 带 ToolCallID / ToolCalls 序列化 =====
163
+
164
+ func TestMessageWithToolCallID(t *testing.T) {
165
+ msgJSON := `{
166
+ "role": "tool",
167
+ "content": "{\"temperature\": 25}",
168
+ "tool_call_id": "call_abc123"
169
+ }`
170
+
171
+ var msg Message
172
+ if err := json.Unmarshal([]byte(msgJSON), &msg); err != nil {
173
+ t.Fatalf("unmarshal: %v", err)
174
+ }
175
+
176
+ if msg.Role != "tool" {
177
+ t.Errorf("Role = %q, want %q", msg.Role, "tool")
178
+ }
179
+ if msg.ToolCallID != "call_abc123" {
180
+ t.Errorf("ToolCallID = %q, want %q", msg.ToolCallID, "call_abc123")
181
+ }
182
+ }
183
+
184
+ func TestMessageWithToolCalls(t *testing.T) {
185
+ msgJSON := `{
186
+ "role": "assistant",
187
+ "content": "",
188
+ "tool_calls": [{
189
+ "id": "call_xyz",
190
+ "type": "function",
191
+ "function": {"name": "get_weather", "arguments": "{\"location\":\"上海\"}"},
192
+ "index": 0
193
+ }]
194
+ }`
195
+
196
+ var msg Message
197
+ if err := json.Unmarshal([]byte(msgJSON), &msg); err != nil {
198
+ t.Fatalf("unmarshal: %v", err)
199
+ }
200
+
201
+ if len(msg.ToolCalls) != 1 {
202
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
203
+ }
204
+ if msg.ToolCalls[0].Function.Name != "get_weather" {
205
+ t.Errorf("ToolCalls[0].Function.Name = %q, want %q", msg.ToolCalls[0].Function.Name, "get_weather")
206
+ }
207
+ }
208
+
209
+ // ===== ToUpstreamMessage =====
210
+
211
+ func TestToUpstreamMessage_ToolRole(t *testing.T) {
212
+ msg := Message{
213
+ Role: "tool",
214
+ Content: `{"temperature": 25}`,
215
+ ToolCallID: "call_abc",
216
+ }
217
+
218
+ result := msg.ToUpstreamMessage(nil)
219
+
220
+ if result["role"] != "tool" {
221
+ t.Errorf("role = %v, want %q", result["role"], "tool")
222
+ }
223
+ if result["tool_call_id"] != "call_abc" {
224
+ t.Errorf("tool_call_id = %v, want %q", result["tool_call_id"], "call_abc")
225
+ }
226
+ if result["content"] != `{"temperature": 25}` {
227
+ t.Errorf("content = %v, want %q", result["content"], `{"temperature": 25}`)
228
+ }
229
+ }
230
+
231
+ func TestToUpstreamMessage_AssistantWithToolCalls(t *testing.T) {
232
+ msg := Message{
233
+ Role: "assistant",
234
+ Content: "",
235
+ ToolCalls: []ToolCall{
236
+ {
237
+ ID: "call_1",
238
+ Type: "function",
239
+ Function: FunctionCall{
240
+ Name: "get_weather",
241
+ Arguments: `{"location":"北京"}`,
242
+ },
243
+ },
244
+ {
245
+ ID: "call_2",
246
+ Type: "function",
247
+ Function: FunctionCall{
248
+ Name: "get_time",
249
+ Arguments: `{"timezone":"Asia/Shanghai"}`,
250
+ },
251
+ },
252
+ },
253
+ }
254
+
255
+ result := msg.ToUpstreamMessage(nil)
256
+
257
+ if result["role"] != "assistant" {
258
+ t.Errorf("role = %v, want %q", result["role"], "assistant")
259
+ }
260
+
261
+ toolCalls, ok := result["tool_calls"].([]map[string]interface{})
262
+ if !ok {
263
+ t.Fatalf("tool_calls type = %T, want []map[string]interface{}", result["tool_calls"])
264
+ }
265
+ if len(toolCalls) != 2 {
266
+ t.Fatalf("len(tool_calls) = %d, want 2", len(toolCalls))
267
+ }
268
+ if toolCalls[0]["id"] != "call_1" {
269
+ t.Errorf("tool_calls[0].id = %v, want %q", toolCalls[0]["id"], "call_1")
270
+ }
271
+ fn, ok := toolCalls[0]["function"].(map[string]interface{})
272
+ if !ok {
273
+ t.Fatalf("function type = %T", toolCalls[0]["function"])
274
+ }
275
+ if fn["name"] != "get_weather" {
276
+ t.Errorf("function.name = %v, want %q", fn["name"], "get_weather")
277
+ }
278
+ }
279
+
280
+ func TestToUpstreamMessage_PlainUser(t *testing.T) {
281
+ msg := Message{
282
+ Role: "user",
283
+ Content: "hello",
284
+ }
285
+
286
+ result := msg.ToUpstreamMessage(nil)
287
+ if result["role"] != "user" {
288
+ t.Errorf("role = %v, want %q", result["role"], "user")
289
+ }
290
+ if result["content"] != "hello" {
291
+ t.Errorf("content = %v, want %q", result["content"], "hello")
292
+ }
293
+ if _, exists := result["tool_call_id"]; exists {
294
+ t.Error("tool_call_id should not be present for user messages")
295
+ }
296
+ if _, exists := result["tool_calls"]; exists {
297
+ t.Error("tool_calls should not be present for user messages")
298
+ }
299
+ }
300
+
301
+ func TestToUpstreamMessage_AssistantWithoutToolCalls(t *testing.T) {
302
+ msg := Message{
303
+ Role: "assistant",
304
+ Content: "你好!",
305
+ }
306
+
307
+ result := msg.ToUpstreamMessage(nil)
308
+ if result["role"] != "assistant" {
309
+ t.Errorf("role = %v, want %q", result["role"], "assistant")
310
+ }
311
+ if result["content"] != "你好!" {
312
+ t.Errorf("content = %v, want %q", result["content"], "你好!")
313
+ }
314
+ if _, exists := result["tool_calls"]; exists {
315
+ t.Error("tool_calls should not be present when empty")
316
+ }
317
+ }
318
+
319
+ // ===== Delta / MessageResp 带 ToolCalls =====
320
+
321
+ func TestDeltaWithToolCalls(t *testing.T) {
322
+ delta := Delta{
323
+ ToolCalls: []ToolCall{{
324
+ ID: "call_1",
325
+ Type: "function",
326
+ Index: 0,
327
+ Function: FunctionCall{
328
+ Name: "get_weather",
329
+ Arguments: `{"location":"北京"}`,
330
+ },
331
+ }},
332
+ }
333
+
334
+ data, err := json.Marshal(delta)
335
+ if err != nil {
336
+ t.Fatalf("marshal: %v", err)
337
+ }
338
+
339
+ var decoded Delta
340
+ if err := json.Unmarshal(data, &decoded); err != nil {
341
+ t.Fatalf("unmarshal: %v", err)
342
+ }
343
+
344
+ if len(decoded.ToolCalls) != 1 {
345
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(decoded.ToolCalls))
346
+ }
347
+ if decoded.ToolCalls[0].Function.Name != "get_weather" {
348
+ t.Errorf("Name = %q, want %q", decoded.ToolCalls[0].Function.Name, "get_weather")
349
+ }
350
+ }
351
+
352
+ func TestDeltaOmitsEmptyToolCalls(t *testing.T) {
353
+ delta := Delta{Content: "hello"}
354
+
355
+ data, err := json.Marshal(delta)
356
+ if err != nil {
357
+ t.Fatalf("marshal: %v", err)
358
+ }
359
+
360
+ // tool_calls 为空时应被 omitempty 省略
361
+ var raw map[string]interface{}
362
+ json.Unmarshal(data, &raw)
363
+ if _, exists := raw["tool_calls"]; exists {
364
+ t.Error("tool_calls should be omitted when empty")
365
+ }
366
+ }
367
+
368
+ func TestMessageRespWithToolCalls(t *testing.T) {
369
+ resp := MessageResp{
370
+ Role: "assistant",
371
+ Content: "",
372
+ ToolCalls: []ToolCall{{
373
+ ID: "call_1",
374
+ Type: "function",
375
+ Index: 0,
376
+ Function: FunctionCall{
377
+ Name: "search",
378
+ Arguments: `{"query":"test"}`,
379
+ },
380
+ }},
381
+ }
382
+
383
+ data, err := json.Marshal(resp)
384
+ if err != nil {
385
+ t.Fatalf("marshal: %v", err)
386
+ }
387
+
388
+ var decoded MessageResp
389
+ if err := json.Unmarshal(data, &decoded); err != nil {
390
+ t.Fatalf("unmarshal: %v", err)
391
+ }
392
+
393
+ if len(decoded.ToolCalls) != 1 {
394
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(decoded.ToolCalls))
395
+ }
396
+ if decoded.ToolCalls[0].Function.Arguments != `{"query":"test"}` {
397
+ t.Errorf("Arguments = %q", decoded.ToolCalls[0].Function.Arguments)
398
+ }
399
+ }
400
+
401
+ func TestMessageRespOmitsEmptyToolCalls(t *testing.T) {
402
+ resp := MessageResp{
403
+ Role: "assistant",
404
+ Content: "hello world",
405
+ }
406
+
407
+ data, _ := json.Marshal(resp)
408
+ var raw map[string]interface{}
409
+ json.Unmarshal(data, &raw)
410
+ if _, exists := raw["tool_calls"]; exists {
411
+ t.Error("tool_calls should be omitted when empty")
412
+ }
413
+ }
414
+
415
+ // ===== ChatCompletionChunk 带 tool_calls finish_reason =====
416
+
417
+ func TestChunkWithToolCallsFinishReason(t *testing.T) {
418
+ reason := "tool_calls"
419
+ chunk := ChatCompletionChunk{
420
+ ID: "chatcmpl-test",
421
+ Object: "chat.completion.chunk",
422
+ Created: 1000,
423
+ Model: "glm-4.7",
424
+ Choices: []Choice{{
425
+ Index: 0,
426
+ Delta: Delta{},
427
+ FinishReason: &reason,
428
+ }},
429
+ }
430
+
431
+ data, err := json.Marshal(chunk)
432
+ if err != nil {
433
+ t.Fatalf("marshal: %v", err)
434
+ }
435
+
436
+ var decoded ChatCompletionChunk
437
+ if err := json.Unmarshal(data, &decoded); err != nil {
438
+ t.Fatalf("unmarshal: %v", err)
439
+ }
440
+
441
+ if decoded.Choices[0].FinishReason == nil {
442
+ t.Fatal("FinishReason is nil")
443
+ }
444
+ if *decoded.Choices[0].FinishReason != "tool_calls" {
445
+ t.Errorf("FinishReason = %q, want %q", *decoded.Choices[0].FinishReason, "tool_calls")
446
+ }
447
+ }
448
+
449
+ // ===== ChatCompletionResponse 带 tool_calls =====
450
+
451
+ func TestCompletionResponseWithToolCalls(t *testing.T) {
452
+ reason := "tool_calls"
453
+ resp := ChatCompletionResponse{
454
+ ID: "chatcmpl-test",
455
+ Object: "chat.completion",
456
+ Created: 1000,
457
+ Model: "glm-4.7",
458
+ Choices: []Choice{{
459
+ Index: 0,
460
+ Message: &MessageResp{
461
+ Role: "assistant",
462
+ Content: "",
463
+ ToolCalls: []ToolCall{{
464
+ ID: "call_1",
465
+ Type: "function",
466
+ Index: 0,
467
+ Function: FunctionCall{
468
+ Name: "get_weather",
469
+ Arguments: `{"location":"北京"}`,
470
+ },
471
+ }},
472
+ },
473
+ FinishReason: &reason,
474
+ }},
475
+ }
476
+
477
+ data, err := json.Marshal(resp)
478
+ if err != nil {
479
+ t.Fatalf("marshal: %v", err)
480
+ }
481
+
482
+ var decoded ChatCompletionResponse
483
+ if err := json.Unmarshal(data, &decoded); err != nil {
484
+ t.Fatalf("unmarshal: %v", err)
485
+ }
486
+
487
+ if len(decoded.Choices) != 1 {
488
+ t.Fatalf("len(Choices) = %d", len(decoded.Choices))
489
+ }
490
+ msg := decoded.Choices[0].Message
491
+ if msg == nil {
492
+ t.Fatal("Message is nil")
493
+ }
494
+ if len(msg.ToolCalls) != 1 {
495
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
496
+ }
497
+ if msg.ToolCalls[0].Function.Name != "get_weather" {
498
+ t.Errorf("Function.Name = %q", msg.ToolCalls[0].Function.Name)
499
+ }
500
+ if *decoded.Choices[0].FinishReason != "tool_calls" {
501
+ t.Errorf("FinishReason = %q", *decoded.Choices[0].FinishReason)
502
+ }
503
+ }
internal/tools/builtin.go ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package tools
2
+
3
+ import "zai-proxy/internal/model"
4
+
5
+ // GetBuiltinTools 返回所有内置工具定义
6
+ func GetBuiltinTools() []model.Tool {
7
+ return []model.Tool{
8
+ // 多功能助手
9
+ {
10
+ Type: "function",
11
+ Function: model.ToolFunction{
12
+ Name: "get_current_time",
13
+ Description: "获取当前时间,支持不同时区和格式",
14
+ Parameters: map[string]interface{}{
15
+ "type": "object",
16
+ "properties": map[string]interface{}{
17
+ "timezone": map[string]interface{}{
18
+ "type": "string",
19
+ "description": "时区名称(如 Asia/Shanghai, America/New_York)",
20
+ },
21
+ "format": map[string]interface{}{
22
+ "type": "string",
23
+ "description": "时间格式(如 2006-01-02 15:04:05)",
24
+ },
25
+ },
26
+ "required": []string{},
27
+ },
28
+ },
29
+ },
30
+ {
31
+ Type: "function",
32
+ Function: model.ToolFunction{
33
+ Name: "calculate",
34
+ Description: "执行数学计算,支持基本运算和高级数学函数",
35
+ Parameters: map[string]interface{}{
36
+ "type": "object",
37
+ "properties": map[string]interface{}{
38
+ "expression": map[string]interface{}{
39
+ "type": "string",
40
+ "description": "数学表达式(如 2+3*4, sqrt(16), sin(pi/2))",
41
+ },
42
+ },
43
+ "required": []string{"expression"},
44
+ },
45
+ },
46
+ },
47
+ {
48
+ Type: "function",
49
+ Function: model.ToolFunction{
50
+ Name: "search_web",
51
+ Description: "搜索网络获取实时信息",
52
+ Parameters: map[string]interface{}{
53
+ "type": "object",
54
+ "properties": map[string]interface{}{
55
+ "query": map[string]interface{}{
56
+ "type": "string",
57
+ "description": "搜索关键词",
58
+ },
59
+ "num_results": map[string]interface{}{
60
+ "type": "integer",
61
+ "description": "返回结果数量,默认5",
62
+ },
63
+ },
64
+ "required": []string{"query"},
65
+ },
66
+ },
67
+ },
68
+ // 数据库查询
69
+ {
70
+ Type: "function",
71
+ Function: model.ToolFunction{
72
+ Name: "query_database",
73
+ Description: "执行SQL查询获取数据",
74
+ Parameters: map[string]interface{}{
75
+ "type": "object",
76
+ "properties": map[string]interface{}{
77
+ "sql": map[string]interface{}{
78
+ "type": "string",
79
+ "description": "SQL查询语句",
80
+ },
81
+ "database": map[string]interface{}{
82
+ "type": "string",
83
+ "description": "目标数据库名称",
84
+ },
85
+ },
86
+ "required": []string{"sql"},
87
+ },
88
+ },
89
+ },
90
+ // 文件操作
91
+ {
92
+ Type: "function",
93
+ Function: model.ToolFunction{
94
+ Name: "file_operations",
95
+ Description: "执行文件操作,支持读取、写入和列出文件",
96
+ Parameters: map[string]interface{}{
97
+ "type": "object",
98
+ "properties": map[string]interface{}{
99
+ "operation": map[string]interface{}{
100
+ "type": "string",
101
+ "enum": []string{"read", "write", "list"},
102
+ "description": "操作类型:read(读取)、write(写入)、list(列出)",
103
+ },
104
+ "path": map[string]interface{}{
105
+ "type": "string",
106
+ "description": "文件或目录路径",
107
+ },
108
+ "content": map[string]interface{}{
109
+ "type": "string",
110
+ "description": "写入内容(仅 write 操作需要)",
111
+ },
112
+ },
113
+ "required": []string{"operation", "path"},
114
+ },
115
+ },
116
+ },
117
+ // API集成
118
+ {
119
+ Type: "function",
120
+ Function: model.ToolFunction{
121
+ Name: "call_external_api",
122
+ Description: "调用外部API接口",
123
+ Parameters: map[string]interface{}{
124
+ "type": "object",
125
+ "properties": map[string]interface{}{
126
+ "url": map[string]interface{}{
127
+ "type": "string",
128
+ "description": "API请求URL",
129
+ },
130
+ "method": map[string]interface{}{
131
+ "type": "string",
132
+ "enum": []string{"GET", "POST", "PUT", "DELETE"},
133
+ "description": "HTTP请求方法",
134
+ },
135
+ "headers": map[string]interface{}{
136
+ "type": "object",
137
+ "description": "请求头",
138
+ },
139
+ "body": map[string]interface{}{
140
+ "type": "string",
141
+ "description": "请求体(JSON字符串)",
142
+ },
143
+ },
144
+ "required": []string{"url", "method"},
145
+ },
146
+ },
147
+ },
148
+ }
149
+ }
internal/tools/builtin_test.go ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package tools
2
+
3
+ import (
4
+ "testing"
5
+ )
6
+
7
+ func TestGetBuiltinTools_Count(t *testing.T) {
8
+ tools := GetBuiltinTools()
9
+ if len(tools) != 6 {
10
+ t.Errorf("len(GetBuiltinTools()) = %d, want 6", len(tools))
11
+ }
12
+ }
13
+
14
+ func TestGetBuiltinTools_AllFunction(t *testing.T) {
15
+ for _, tool := range GetBuiltinTools() {
16
+ if tool.Type != "function" {
17
+ t.Errorf("tool %q Type = %q, want %q", tool.Function.Name, tool.Type, "function")
18
+ }
19
+ }
20
+ }
21
+
22
+ func TestGetBuiltinTools_Names(t *testing.T) {
23
+ expected := map[string]bool{
24
+ "get_current_time": true,
25
+ "calculate": true,
26
+ "search_web": true,
27
+ "query_database": true,
28
+ "file_operations": true,
29
+ "call_external_api": true,
30
+ }
31
+
32
+ tools := GetBuiltinTools()
33
+ for _, tool := range tools {
34
+ name := tool.Function.Name
35
+ if !expected[name] {
36
+ t.Errorf("unexpected tool name: %q", name)
37
+ }
38
+ delete(expected, name)
39
+ }
40
+
41
+ for name := range expected {
42
+ t.Errorf("missing tool: %q", name)
43
+ }
44
+ }
45
+
46
+ func TestGetBuiltinTools_HaveDescriptions(t *testing.T) {
47
+ for _, tool := range GetBuiltinTools() {
48
+ if tool.Function.Description == "" {
49
+ t.Errorf("tool %q has empty description", tool.Function.Name)
50
+ }
51
+ }
52
+ }
53
+
54
+ func TestGetBuiltinTools_HaveParameters(t *testing.T) {
55
+ for _, tool := range GetBuiltinTools() {
56
+ if tool.Function.Parameters == nil {
57
+ t.Errorf("tool %q has nil parameters", tool.Function.Name)
58
+ }
59
+ params, ok := tool.Function.Parameters.(map[string]interface{})
60
+ if !ok {
61
+ t.Errorf("tool %q parameters is not a map", tool.Function.Name)
62
+ continue
63
+ }
64
+ if params["type"] != "object" {
65
+ t.Errorf("tool %q parameters.type = %v, want %q", tool.Function.Name, params["type"], "object")
66
+ }
67
+ if _, ok := params["properties"]; !ok {
68
+ t.Errorf("tool %q parameters missing 'properties'", tool.Function.Name)
69
+ }
70
+ }
71
+ }
72
+
73
+ func TestGetBuiltinTools_NoDuplicateNames(t *testing.T) {
74
+ seen := make(map[string]bool)
75
+ for _, tool := range GetBuiltinTools() {
76
+ if seen[tool.Function.Name] {
77
+ t.Errorf("duplicate tool name: %q", tool.Function.Name)
78
+ }
79
+ seen[tool.Function.Name] = true
80
+ }
81
+ }
82
+
83
+ func TestGetBuiltinTools_ReturnsNewSlice(t *testing.T) {
84
+ a := GetBuiltinTools()
85
+ b := GetBuiltinTools()
86
+ if &a[0] == &b[0] {
87
+ t.Error("GetBuiltinTools should return a new slice each call")
88
+ }
89
+ }
internal/upstream/client.go CHANGED
@@ -12,6 +12,7 @@ import (
12
 
13
  "zai-proxy/internal/auth"
14
  "zai-proxy/internal/model"
 
15
  "zai-proxy/internal/version"
16
  )
17
 
@@ -34,7 +35,7 @@ func ExtractAllImageURLs(messages []model.Message) []string {
34
  return allImageURLs
35
  }
36
 
37
- func MakeUpstreamRequest(token string, messages []model.Message, modelName string) (*http.Response, string, error) {
38
  payload, err := auth.DecodeJWTPayload(token)
39
  if err != nil || payload == nil {
40
  return nil, "", fmt.Errorf("invalid token")
@@ -119,6 +120,26 @@ func MakeUpstreamRequest(token string, messages []model.Message, modelName strin
119
  body["mcp_servers"] = mcpServers
120
  }
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  if len(filesData) > 0 {
123
  body["files"] = filesData
124
  body["current_user_message_id"] = userMsgID
 
12
 
13
  "zai-proxy/internal/auth"
14
  "zai-proxy/internal/model"
15
+ builtintools "zai-proxy/internal/tools"
16
  "zai-proxy/internal/version"
17
  )
18
 
 
35
  return allImageURLs
36
  }
37
 
38
+ func MakeUpstreamRequest(token string, messages []model.Message, modelName string, tools []model.Tool, toolChoice interface{}) (*http.Response, string, error) {
39
  payload, err := auth.DecodeJWTPayload(token)
40
  if err != nil || payload == nil {
41
  return nil, "", fmt.Errorf("invalid token")
 
120
  body["mcp_servers"] = mcpServers
121
  }
122
 
123
+ // 当使用 -tools 模型时,自动注入内置工具(客户端自带工具优先)
124
+ if model.IsToolsModel(modelName) {
125
+ clientToolNames := make(map[string]bool)
126
+ for _, t := range tools {
127
+ clientToolNames[t.Function.Name] = true
128
+ }
129
+ for _, bt := range builtintools.GetBuiltinTools() {
130
+ if !clientToolNames[bt.Function.Name] {
131
+ tools = append(tools, bt)
132
+ }
133
+ }
134
+ }
135
+
136
+ if len(tools) > 0 {
137
+ body["tools"] = tools
138
+ if toolChoice != nil {
139
+ body["tool_choice"] = toolChoice
140
+ }
141
+ }
142
+
143
  if len(filesData) > 0 {
144
  body["files"] = filesData
145
  body["current_user_message_id"] = userMsgID
scripts/test_tool_call.sh ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # 测试 tool/function calling 功能
3
+ # 用法: ./scripts/test_tool_call.sh [TOKEN] [BASE_URL]
4
+ #
5
+ # TOKEN 可以是你的 z.ai token 或 "free"(匿名)
6
+ # BASE_URL 默认 http://localhost:8000
7
+
8
+ TOKEN="${1:-free}"
9
+ BASE_URL="${2:-http://localhost:8000}"
10
+
11
+ echo "=== 测试 Tool/Function Calling ==="
12
+ echo "BASE_URL: $BASE_URL"
13
+ echo "TOKEN: ${TOKEN:0:10}..."
14
+ echo ""
15
+
16
+ # ===== 测试 1: 带 tools 的流式请求 =====
17
+ echo "--- 测试 1: 流式 tool calling ---"
18
+ curl -sS "${BASE_URL}/v1/chat/completions" \
19
+ -H "Authorization: Bearer ${TOKEN}" \
20
+ -H "Content-Type: application/json" \
21
+ -d '{
22
+ "model": "GLM-4.7",
23
+ "stream": true,
24
+ "messages": [
25
+ {"role": "user", "content": "北京今天天气怎么样?请调用 get_weather 函数查询。"}
26
+ ],
27
+ "tools": [{
28
+ "type": "function",
29
+ "function": {
30
+ "name": "get_weather",
31
+ "description": "获取指定城市的当前天气信息",
32
+ "parameters": {
33
+ "type": "object",
34
+ "properties": {
35
+ "location": {
36
+ "type": "string",
37
+ "description": "城市名称,如:北京"
38
+ }
39
+ },
40
+ "required": ["location"]
41
+ }
42
+ }
43
+ }],
44
+ "tool_choice": "auto"
45
+ }' 2>&1
46
+ echo ""
47
+ echo ""
48
+
49
+ # ===== 测试 2: 带 tools 的非流式请求 =====
50
+ echo "--- 测试 2: 非流式 tool calling ---"
51
+ curl -sS "${BASE_URL}/v1/chat/completions" \
52
+ -H "Authorization: Bearer ${TOKEN}" \
53
+ -H "Content-Type: application/json" \
54
+ -d '{
55
+ "model": "GLM-4.7",
56
+ "stream": false,
57
+ "messages": [
58
+ {"role": "user", "content": "帮我查一下上海的天气,用 get_weather 工具。"}
59
+ ],
60
+ "tools": [{
61
+ "type": "function",
62
+ "function": {
63
+ "name": "get_weather",
64
+ "description": "获取指定城市的当前天气信息",
65
+ "parameters": {
66
+ "type": "object",
67
+ "properties": {
68
+ "location": {
69
+ "type": "string",
70
+ "description": "城市名称"
71
+ }
72
+ },
73
+ "required": ["location"]
74
+ }
75
+ }
76
+ }],
77
+ "tool_choice": "auto"
78
+ }' 2>&1 | python3 -m json.tool 2>/dev/null || cat
79
+ echo ""
80
+ echo ""
81
+
82
+ # ===== 测试 3: 多工具 =====
83
+ echo "--- 测试 3: 多工具非流式 ---"
84
+ curl -sS "${BASE_URL}/v1/chat/completions" \
85
+ -H "Authorization: Bearer ${TOKEN}" \
86
+ -H "Content-Type: application/json" \
87
+ -d '{
88
+ "model": "GLM-4.7",
89
+ "stream": false,
90
+ "messages": [
91
+ {"role": "user", "content": "北京天气怎么样?现在几点了?请分别调用对应的工具。"}
92
+ ],
93
+ "tools": [
94
+ {
95
+ "type": "function",
96
+ "function": {
97
+ "name": "get_weather",
98
+ "description": "获取天气",
99
+ "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
100
+ }
101
+ },
102
+ {
103
+ "type": "function",
104
+ "function": {
105
+ "name": "get_current_time",
106
+ "description": "获取当前时间",
107
+ "parameters": {"type": "object", "properties": {"timezone": {"type": "string"}}, "required": ["timezone"]}
108
+ }
109
+ }
110
+ ],
111
+ "tool_choice": "auto"
112
+ }' 2>&1 | python3 -m json.tool 2>/dev/null || cat
113
+ echo ""
114
+ echo ""
115
+
116
+ # ===== 测试 4: 完整多轮对话(tool result 回传)=====
117
+ echo "--- 测试 4: 多轮对话 (tool result 回传) ---"
118
+ curl -sS "${BASE_URL}/v1/chat/completions" \
119
+ -H "Authorization: Bearer ${TOKEN}" \
120
+ -H "Content-Type: application/json" \
121
+ -d '{
122
+ "model": "GLM-4.7",
123
+ "stream": false,
124
+ "messages": [
125
+ {"role": "user", "content": "北京天气怎么样?"},
126
+ {
127
+ "role": "assistant",
128
+ "content": "",
129
+ "tool_calls": [{
130
+ "id": "call_abc123",
131
+ "type": "function",
132
+ "function": {"name": "get_weather", "arguments": "{\"location\":\"北京\"}"}
133
+ }]
134
+ },
135
+ {
136
+ "role": "tool",
137
+ "tool_call_id": "call_abc123",
138
+ "content": "{\"temperature\": 25, \"condition\": \"晴\", \"humidity\": 40}"
139
+ }
140
+ ],
141
+ "tools": [{
142
+ "type": "function",
143
+ "function": {
144
+ "name": "get_weather",
145
+ "description": "获取天气",
146
+ "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
147
+ }
148
+ }]
149
+ }' 2>&1 | python3 -m json.tool 2>/dev/null || cat
150
+ echo ""
151
+ echo ""
152
+
153
+ # ===== 测试 5: 不带 tools 的普通请求(回归测试)=====
154
+ echo "--- 测试 5: 不带 tools 的普通请求(回归)---"
155
+ curl -sS "${BASE_URL}/v1/chat/completions" \
156
+ -H "Authorization: Bearer ${TOKEN}" \
157
+ -H "Content-Type: application/json" \
158
+ -d '{
159
+ "model": "GLM-4.7",
160
+ "stream": false,
161
+ "messages": [
162
+ {"role": "user", "content": "你好,1+1等于几?"}
163
+ ]
164
+ }' 2>&1 | python3 -m json.tool 2>/dev/null || cat
165
+ echo ""
166
+
167
+ echo "=== 测试完成 ==="
168
+ echo ""
169
+ echo "检查要点:"
170
+ echo " 1. 测试 1/2: 查看响应中是否有 tool_calls 字段和 finish_reason=tool_calls"
171
+ echo " 2. 测试 3: 是否返回多个 tool_calls"
172
+ echo " 3. 测试 4: 模型是否基于 tool result 生成了自然语言回复"
173
+ echo " 4. 测试 5: 不带 tools 时是否正常返回文本(无 tool_calls 字段)"
174
+ echo " 5. 查看服务端日志中的 [ToolCall] 行,确认上游返回的原始格式"