Spaces:
Paused
Paused
feat: support tool/function calling (OpenAI compatible)
Browse filesAdd tools and tool_choice fields to chat requests, parse upstream
tool_call responses, and return them in OpenAI-compatible format
with finish_reason=tool_calls. Includes builtin tools, test script,
and unit tests.
- README.md +88 -0
- internal/filter/toolcall.go +103 -0
- internal/filter/toolcall_test.go +277 -0
- internal/handler/chat.go +69 -6
- internal/handler/chat_test.go +576 -0
- internal/model/mapping.go +18 -7
- internal/model/mapping_test.go +201 -0
- internal/model/types.go +74 -12
- internal/model/types_test.go +503 -0
- internal/tools/builtin.go +149 -0
- internal/tools/builtin_test.go +89 -0
- internal/upstream/client.go +22 -1
- scripts/test_tool_call.sh +174 -0
README.md
CHANGED
|
@@ -9,6 +9,7 @@ zai-proxy 是一个基于 Go 语言的代理服务,将 z.ai 网页聊天转换
|
|
| 9 |
- 支持多种 GLM 模型
|
| 10 |
- 支持思考模式 (thinking)
|
| 11 |
- 支持联网搜索模式 (search)
|
|
|
|
| 12 |
- 支持多模态图片输入
|
| 13 |
- 支持匿名 Token(免登录)
|
| 14 |
- **自动生成签名**
|
|
@@ -87,13 +88,18 @@ curl http://localhost:8000/v1/chat/completions \
|
|
| 87 |
|
| 88 |
- `-thinking`: 启用思考模式,响应会包含 `reasoning_content` 字段
|
| 89 |
- `-search`: 启用联网搜索模式
|
|
|
|
| 90 |
- (TODO) `-deepsearch`: 启用多轮搜索,深入研究分析
|
| 91 |
|
|
|
|
|
|
|
| 92 |
示例:
|
| 93 |
|
| 94 |
- `GLM-4.7-thinking`
|
| 95 |
- `GLM-4.7-search`
|
| 96 |
- `GLM-4.7-thinking-search`
|
|
|
|
|
|
|
| 97 |
|
| 98 |
## 使用示例
|
| 99 |
|
|
@@ -130,3 +136,85 @@ curl http://localhost:8000/v1/chat/completions \
|
|
| 130 |
### 支持的图片格式:
|
| 131 |
- HTTP/HTTPS URL
|
| 132 |
- Base64 编码 (data:image/jpeg;base64,...)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
- 支持多种 GLM 模型
|
| 10 |
- 支持思考模式 (thinking)
|
| 11 |
- 支持联网搜索模式 (search)
|
| 12 |
+
- 支持内置工具调用 (tools)
|
| 13 |
- 支持多模态图片输入
|
| 14 |
- 支持匿名 Token(免登录)
|
| 15 |
- **自动生成签名**
|
|
|
|
| 88 |
|
| 89 |
- `-thinking`: 启用思考模式,响应会包含 `reasoning_content` 字段
|
| 90 |
- `-search`: 启用联网搜索模式
|
| 91 |
+
- `-tools`: 自动注入内置工具定义,模型会返回 `tool_calls` 进行函数调用
|
| 92 |
- (TODO) `-deepsearch`: 启用多轮搜索,深入研究分析
|
| 93 |
|
| 94 |
+
标签可任意组合,顺序不限:
|
| 95 |
+
|
| 96 |
示例:
|
| 97 |
|
| 98 |
- `GLM-4.7-thinking`
|
| 99 |
- `GLM-4.7-search`
|
| 100 |
- `GLM-4.7-thinking-search`
|
| 101 |
+
- `GLM-4.7-tools`
|
| 102 |
+
- `GLM-4.7-tools-thinking`
|
| 103 |
|
| 104 |
## 使用示例
|
| 105 |
|
|
|
|
| 136 |
### 支持的图片格式:
|
| 137 |
- HTTP/HTTPS URL
|
| 138 |
- Base64 编码 (data:image/jpeg;base64,...)
|
| 139 |
+
|
| 140 |
+
## 工具调用 (Function Calling)
|
| 141 |
+
|
| 142 |
+
使用 `-tools` 后缀时,代理会自动注入 6 个内置工具定义。模型会根据用户输入决定是否调用工具。
|
| 143 |
+
|
| 144 |
+
### 内置工具
|
| 145 |
+
|
| 146 |
+
| 工具名 | 描述 |
|
| 147 |
+
|--------|------|
|
| 148 |
+
| `get_current_time` | 获取当前时间 |
|
| 149 |
+
| `calculate` | 执行数学计算 |
|
| 150 |
+
| `search_web` | 搜索网络信息 |
|
| 151 |
+
| `query_database` | 执行SQL查询 |
|
| 152 |
+
| `file_operations` | 文件读写列表 |
|
| 153 |
+
| `call_external_api` | 调用外部API |
|
| 154 |
+
|
| 155 |
+
### 基本调用
|
| 156 |
+
|
| 157 |
+
```bash
|
| 158 |
+
curl http://localhost:8000/v1/chat/completions \
|
| 159 |
+
-H "Authorization: Bearer YOUR_ZAI_TOKEN" \
|
| 160 |
+
-H "Content-Type: application/json" \
|
| 161 |
+
-d '{
|
| 162 |
+
"model": "GLM-4.7-tools",
|
| 163 |
+
"messages": [{"role": "user", "content": "现在几点了?"}],
|
| 164 |
+
"stream": true
|
| 165 |
+
}'
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
模型会返回 `tool_calls`(`finish_reason` 为 `"tool_calls"`),由客户端自行执行工具并将结果发回。
|
| 169 |
+
|
| 170 |
+
### 多轮调用流程
|
| 171 |
+
|
| 172 |
+
```
|
| 173 |
+
第1轮:用户提问 → 模型返回 tool_calls
|
| 174 |
+
第2轮:发送工具执行结果 → 模型生成最终回答
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
curl http://localhost:8000/v1/chat/completions \
|
| 179 |
+
-H "Authorization: Bearer YOUR_ZAI_TOKEN" \
|
| 180 |
+
-H "Content-Type: application/json" \
|
| 181 |
+
-d '{
|
| 182 |
+
"model": "GLM-4.7-tools",
|
| 183 |
+
"messages": [
|
| 184 |
+
{"role": "user", "content": "现在几点了?"},
|
| 185 |
+
{"role": "assistant", "content": "", "tool_calls": [
|
| 186 |
+
{"id": "call_xxx", "type": "function", "function": {"name": "get_current_time", "arguments": "{}"}}
|
| 187 |
+
]},
|
| 188 |
+
{"role": "tool", "tool_call_id": "call_xxx", "content": "{\"time\": \"2026-03-14 15:30:00\"}"}
|
| 189 |
+
],
|
| 190 |
+
"stream": true
|
| 191 |
+
}'
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
### 自定义工具
|
| 195 |
+
|
| 196 |
+
也可以不使用 `-tools` 后缀,直接在请求中传入 `tools` 字段(标准 OpenAI 格式):
|
| 197 |
+
|
| 198 |
+
```json
|
| 199 |
+
{
|
| 200 |
+
"model": "GLM-4.7",
|
| 201 |
+
"messages": [{"role": "user", "content": "北京天气怎么样?"}],
|
| 202 |
+
"tools": [{
|
| 203 |
+
"type": "function",
|
| 204 |
+
"function": {
|
| 205 |
+
"name": "get_weather",
|
| 206 |
+
"description": "获取天气信息",
|
| 207 |
+
"parameters": {
|
| 208 |
+
"type": "object",
|
| 209 |
+
"properties": {
|
| 210 |
+
"city": {"type": "string", "description": "城市名称"}
|
| 211 |
+
},
|
| 212 |
+
"required": ["city"]
|
| 213 |
+
}
|
| 214 |
+
}
|
| 215 |
+
}],
|
| 216 |
+
"tool_choice": "auto"
|
| 217 |
+
}
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
两者可混合使用:`-tools` 模型名 + 自定义 `tools` 字段。**客户端自带的同名工具优先**,不会被内置工具覆盖。
|
internal/filter/toolcall.go
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package filter
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"regexp"
|
| 6 |
+
"strings"
|
| 7 |
+
|
| 8 |
+
"zai-proxy/internal/model"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
var glmToolCallBlockPattern = regexp.MustCompile(`<glm_block[^>]*type="tool_call"[^>]*>([\s\S]*?)</glm_block>`)
|
| 12 |
+
|
| 13 |
+
// IsFunctionToolCall 判断 tool_call 阶段的内容是否是用户定义的函数调用(非 mcp/search)
|
| 14 |
+
func IsFunctionToolCall(editContent string, phase string) bool {
|
| 15 |
+
if phase != "tool_call" {
|
| 16 |
+
return false
|
| 17 |
+
}
|
| 18 |
+
// 排除 mcp / search 类型的 tool call
|
| 19 |
+
if strings.Contains(editContent, `"mcp"`) || strings.Contains(editContent, `mcp-server`) {
|
| 20 |
+
return false
|
| 21 |
+
}
|
| 22 |
+
if strings.Contains(editContent, `"search_result"`) || strings.Contains(editContent, `"search_image"`) {
|
| 23 |
+
return false
|
| 24 |
+
}
|
| 25 |
+
// 包含函数调用特征
|
| 26 |
+
return strings.Contains(editContent, `"function"`) || strings.Contains(editContent, `"arguments"`)
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// ParseFunctionToolCalls 从上游 edit_content 解析函数调用
|
| 30 |
+
func ParseFunctionToolCalls(editContent string) []model.ToolCall {
|
| 31 |
+
// 尝试从 glm_block 中提取
|
| 32 |
+
matches := glmToolCallBlockPattern.FindAllStringSubmatch(editContent, -1)
|
| 33 |
+
if len(matches) > 0 {
|
| 34 |
+
var allCalls []model.ToolCall
|
| 35 |
+
for _, match := range matches {
|
| 36 |
+
if calls := parseToolCallJSON(match[1]); len(calls) > 0 {
|
| 37 |
+
allCalls = append(allCalls, calls...)
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
if len(allCalls) > 0 {
|
| 41 |
+
return allCalls
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
// 尝试直接解析为 JSON
|
| 46 |
+
return parseToolCallJSON(editContent)
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
// parseToolCallJSON 解析 tool call JSON 数据
|
| 50 |
+
func parseToolCallJSON(content string) []model.ToolCall {
|
| 51 |
+
content = strings.TrimSpace(content)
|
| 52 |
+
if content == "" {
|
| 53 |
+
return nil
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
// 尝试解析为单个 tool call 对象
|
| 57 |
+
var single struct {
|
| 58 |
+
ID string `json:"id"`
|
| 59 |
+
Type string `json:"type"`
|
| 60 |
+
Function struct {
|
| 61 |
+
Name string `json:"name"`
|
| 62 |
+
Arguments string `json:"arguments"`
|
| 63 |
+
} `json:"function"`
|
| 64 |
+
Name string `json:"name"`
|
| 65 |
+
Arguments string `json:"arguments"`
|
| 66 |
+
}
|
| 67 |
+
if err := json.Unmarshal([]byte(content), &single); err == nil {
|
| 68 |
+
if single.Function.Name != "" {
|
| 69 |
+
return []model.ToolCall{{
|
| 70 |
+
ID: single.ID,
|
| 71 |
+
Type: "function",
|
| 72 |
+
Function: model.FunctionCall{
|
| 73 |
+
Name: single.Function.Name,
|
| 74 |
+
Arguments: single.Function.Arguments,
|
| 75 |
+
},
|
| 76 |
+
}}
|
| 77 |
+
}
|
| 78 |
+
if single.Name != "" {
|
| 79 |
+
return []model.ToolCall{{
|
| 80 |
+
ID: single.ID,
|
| 81 |
+
Type: "function",
|
| 82 |
+
Function: model.FunctionCall{
|
| 83 |
+
Name: single.Name,
|
| 84 |
+
Arguments: single.Arguments,
|
| 85 |
+
},
|
| 86 |
+
}}
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
// 尝试解析为数组
|
| 91 |
+
var arr []json.RawMessage
|
| 92 |
+
if err := json.Unmarshal([]byte(content), &arr); err == nil {
|
| 93 |
+
var calls []model.ToolCall
|
| 94 |
+
for _, raw := range arr {
|
| 95 |
+
if parsed := parseToolCallJSON(string(raw)); len(parsed) > 0 {
|
| 96 |
+
calls = append(calls, parsed...)
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
return calls
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
return nil
|
| 103 |
+
}
|
internal/filter/toolcall_test.go
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package filter
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"testing"
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
// ===== IsFunctionToolCall =====
|
| 8 |
+
|
| 9 |
+
func TestIsFunctionToolCall_True(t *testing.T) {
|
| 10 |
+
tests := []struct {
|
| 11 |
+
name string
|
| 12 |
+
content string
|
| 13 |
+
phase string
|
| 14 |
+
}{
|
| 15 |
+
{
|
| 16 |
+
name: "标准 function 字段",
|
| 17 |
+
content: `{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{}"}}`,
|
| 18 |
+
phase: "tool_call",
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
name: "包含 arguments 字段",
|
| 22 |
+
content: `{"name":"get_weather","arguments":"{\"location\":\"北京\"}"}`,
|
| 23 |
+
phase: "tool_call",
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
name: "glm_block 包裹的函数调用",
|
| 27 |
+
content: `<glm_block type="tool_call">{"function":{"name":"fn1","arguments":"{}"}}</glm_block>`,
|
| 28 |
+
phase: "tool_call",
|
| 29 |
+
},
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
for _, tt := range tests {
|
| 33 |
+
t.Run(tt.name, func(t *testing.T) {
|
| 34 |
+
if !IsFunctionToolCall(tt.content, tt.phase) {
|
| 35 |
+
t.Error("expected true")
|
| 36 |
+
}
|
| 37 |
+
})
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
func TestIsFunctionToolCall_False(t *testing.T) {
|
| 42 |
+
tests := []struct {
|
| 43 |
+
name string
|
| 44 |
+
content string
|
| 45 |
+
phase string
|
| 46 |
+
}{
|
| 47 |
+
{
|
| 48 |
+
name: "非 tool_call 阶段",
|
| 49 |
+
content: `{"function":{"name":"get_weather","arguments":"{}"}}`,
|
| 50 |
+
phase: "answer",
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
name: "mcp tool call",
|
| 54 |
+
content: `{"type":"mcp","function":{"name":"mcp_tool","arguments":"{}"}}`,
|
| 55 |
+
phase: "tool_call",
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
name: "mcp-server tool call",
|
| 59 |
+
content: `mcp-server something with "arguments"`,
|
| 60 |
+
phase: "tool_call",
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
name: "search_result 内容",
|
| 64 |
+
content: `{"search_result":[...],"function":"x","arguments":"y"}`,
|
| 65 |
+
phase: "tool_call",
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
name: "search_image 内容",
|
| 69 |
+
content: `{"search_image":{},"function":"x","arguments":"y"}`,
|
| 70 |
+
phase: "tool_call",
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
name: "无函数调用特征",
|
| 74 |
+
content: `{"type":"tool_call","data":"hello world"}`,
|
| 75 |
+
phase: "tool_call",
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
name: "空阶段",
|
| 79 |
+
content: `{"function":{"name":"fn","arguments":"{}"}}`,
|
| 80 |
+
phase: "",
|
| 81 |
+
},
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
for _, tt := range tests {
|
| 85 |
+
t.Run(tt.name, func(t *testing.T) {
|
| 86 |
+
if IsFunctionToolCall(tt.content, tt.phase) {
|
| 87 |
+
t.Error("expected false")
|
| 88 |
+
}
|
| 89 |
+
})
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// ===== ParseFunctionToolCalls =====
|
| 94 |
+
|
| 95 |
+
func TestParseFunctionToolCalls_StandardFormat(t *testing.T) {
|
| 96 |
+
content := `{"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"北京\"}"}}`
|
| 97 |
+
|
| 98 |
+
calls := ParseFunctionToolCalls(content)
|
| 99 |
+
if len(calls) != 1 {
|
| 100 |
+
t.Fatalf("len(calls) = %d, want 1", len(calls))
|
| 101 |
+
}
|
| 102 |
+
if calls[0].ID != "call_abc" {
|
| 103 |
+
t.Errorf("ID = %q, want %q", calls[0].ID, "call_abc")
|
| 104 |
+
}
|
| 105 |
+
if calls[0].Type != "function" {
|
| 106 |
+
t.Errorf("Type = %q, want %q", calls[0].Type, "function")
|
| 107 |
+
}
|
| 108 |
+
if calls[0].Function.Name != "get_weather" {
|
| 109 |
+
t.Errorf("Function.Name = %q, want %q", calls[0].Function.Name, "get_weather")
|
| 110 |
+
}
|
| 111 |
+
if calls[0].Function.Arguments != `{"location":"北京"}` {
|
| 112 |
+
t.Errorf("Function.Arguments = %q", calls[0].Function.Arguments)
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
func TestParseFunctionToolCalls_FlatFormat(t *testing.T) {
|
| 117 |
+
content := `{"id":"call_flat","name":"get_time","arguments":"{\"timezone\":\"UTC\"}"}`
|
| 118 |
+
|
| 119 |
+
calls := ParseFunctionToolCalls(content)
|
| 120 |
+
if len(calls) != 1 {
|
| 121 |
+
t.Fatalf("len(calls) = %d, want 1", len(calls))
|
| 122 |
+
}
|
| 123 |
+
if calls[0].Function.Name != "get_time" {
|
| 124 |
+
t.Errorf("Function.Name = %q, want %q", calls[0].Function.Name, "get_time")
|
| 125 |
+
}
|
| 126 |
+
if calls[0].Function.Arguments != `{"timezone":"UTC"}` {
|
| 127 |
+
t.Errorf("Function.Arguments = %q", calls[0].Function.Arguments)
|
| 128 |
+
}
|
| 129 |
+
if calls[0].Type != "function" {
|
| 130 |
+
t.Errorf("Type = %q, want %q", calls[0].Type, "function")
|
| 131 |
+
}
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
func TestParseFunctionToolCalls_GlmBlock(t *testing.T) {
|
| 135 |
+
content := `一些文本<glm_block type="tool_call">{"id":"call_glm","type":"function","function":{"name":"search","arguments":"{\"q\":\"test\"}"}}</glm_block>后续文本`
|
| 136 |
+
|
| 137 |
+
calls := ParseFunctionToolCalls(content)
|
| 138 |
+
if len(calls) != 1 {
|
| 139 |
+
t.Fatalf("len(calls) = %d, want 1", len(calls))
|
| 140 |
+
}
|
| 141 |
+
if calls[0].ID != "call_glm" {
|
| 142 |
+
t.Errorf("ID = %q, want %q", calls[0].ID, "call_glm")
|
| 143 |
+
}
|
| 144 |
+
if calls[0].Function.Name != "search" {
|
| 145 |
+
t.Errorf("Function.Name = %q, want %q", calls[0].Function.Name, "search")
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
func TestParseFunctionToolCalls_MultipleGlmBlocks(t *testing.T) {
|
| 150 |
+
content := `<glm_block type="tool_call">{"function":{"name":"fn1","arguments":"{}"}}</glm_block>` +
|
| 151 |
+
`<glm_block type="tool_call">{"function":{"name":"fn2","arguments":"{}"}}</glm_block>`
|
| 152 |
+
|
| 153 |
+
calls := ParseFunctionToolCalls(content)
|
| 154 |
+
if len(calls) != 2 {
|
| 155 |
+
t.Fatalf("len(calls) = %d, want 2", len(calls))
|
| 156 |
+
}
|
| 157 |
+
if calls[0].Function.Name != "fn1" {
|
| 158 |
+
t.Errorf("calls[0].Function.Name = %q, want %q", calls[0].Function.Name, "fn1")
|
| 159 |
+
}
|
| 160 |
+
if calls[1].Function.Name != "fn2" {
|
| 161 |
+
t.Errorf("calls[1].Function.Name = %q, want %q", calls[1].Function.Name, "fn2")
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
func TestParseFunctionToolCalls_Array(t *testing.T) {
|
| 166 |
+
content := `[{"id":"c1","type":"function","function":{"name":"fn1","arguments":"{}"}},{"id":"c2","type":"function","function":{"name":"fn2","arguments":"{}"}}]`
|
| 167 |
+
|
| 168 |
+
calls := ParseFunctionToolCalls(content)
|
| 169 |
+
if len(calls) != 2 {
|
| 170 |
+
t.Fatalf("len(calls) = %d, want 2", len(calls))
|
| 171 |
+
}
|
| 172 |
+
if calls[0].Function.Name != "fn1" {
|
| 173 |
+
t.Errorf("calls[0].Function.Name = %q", calls[0].Function.Name)
|
| 174 |
+
}
|
| 175 |
+
if calls[1].Function.Name != "fn2" {
|
| 176 |
+
t.Errorf("calls[1].Function.Name = %q", calls[1].Function.Name)
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
func TestParseFunctionToolCalls_NoID(t *testing.T) {
|
| 181 |
+
content := `{"type":"function","function":{"name":"get_weather","arguments":"{}"}}`
|
| 182 |
+
|
| 183 |
+
calls := ParseFunctionToolCalls(content)
|
| 184 |
+
if len(calls) != 1 {
|
| 185 |
+
t.Fatalf("len(calls) = %d, want 1", len(calls))
|
| 186 |
+
}
|
| 187 |
+
if calls[0].ID != "" {
|
| 188 |
+
t.Errorf("ID = %q, want empty (caller assigns ID)", calls[0].ID)
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
func TestParseFunctionToolCalls_EmptyContent(t *testing.T) {
|
| 193 |
+
calls := ParseFunctionToolCalls("")
|
| 194 |
+
if len(calls) != 0 {
|
| 195 |
+
t.Errorf("len(calls) = %d, want 0", len(calls))
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
func TestParseFunctionToolCalls_WhitespaceOnly(t *testing.T) {
|
| 200 |
+
calls := ParseFunctionToolCalls(" \n\t ")
|
| 201 |
+
if len(calls) != 0 {
|
| 202 |
+
t.Errorf("len(calls) = %d, want 0", len(calls))
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
func TestParseFunctionToolCalls_InvalidJSON(t *testing.T) {
|
| 207 |
+
calls := ParseFunctionToolCalls("not json at all {{{")
|
| 208 |
+
if len(calls) != 0 {
|
| 209 |
+
t.Errorf("len(calls) = %d, want 0", len(calls))
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
func TestParseFunctionToolCalls_JSONWithoutFunctionFields(t *testing.T) {
|
| 214 |
+
calls := ParseFunctionToolCalls(`{"type":"something","data":"hello"}`)
|
| 215 |
+
if len(calls) != 0 {
|
| 216 |
+
t.Errorf("len(calls) = %d, want 0", len(calls))
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
func TestParseFunctionToolCalls_EmptyArray(t *testing.T) {
|
| 221 |
+
calls := ParseFunctionToolCalls(`[]`)
|
| 222 |
+
if len(calls) != 0 {
|
| 223 |
+
t.Errorf("len(calls) = %d, want 0", len(calls))
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
func TestParseFunctionToolCalls_ComplexArguments(t *testing.T) {
|
| 228 |
+
content := `{"function":{"name":"create_order","arguments":"{\"items\":[{\"id\":1,\"qty\":2},{\"id\":3,\"qty\":1}],\"user\":\"张三\"}"}}`
|
| 229 |
+
|
| 230 |
+
calls := ParseFunctionToolCalls(content)
|
| 231 |
+
if len(calls) != 1 {
|
| 232 |
+
t.Fatalf("len(calls) = %d, want 1", len(calls))
|
| 233 |
+
}
|
| 234 |
+
if calls[0].Function.Name != "create_order" {
|
| 235 |
+
t.Errorf("Function.Name = %q", calls[0].Function.Name)
|
| 236 |
+
}
|
| 237 |
+
// 确保复杂 JSON 参数完整保留
|
| 238 |
+
if calls[0].Function.Arguments == "" {
|
| 239 |
+
t.Error("Function.Arguments is empty")
|
| 240 |
+
}
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
func TestParseFunctionToolCalls_GlmBlockWithExtraAttrs(t *testing.T) {
|
| 244 |
+
content := `<glm_block id="123" type="tool_call" status="pending">{"function":{"name":"fn1","arguments":"{}"}}</glm_block>`
|
| 245 |
+
|
| 246 |
+
calls := ParseFunctionToolCalls(content)
|
| 247 |
+
if len(calls) != 1 {
|
| 248 |
+
t.Fatalf("len(calls) = %d, want 1", len(calls))
|
| 249 |
+
}
|
| 250 |
+
if calls[0].Function.Name != "fn1" {
|
| 251 |
+
t.Errorf("Function.Name = %q, want %q", calls[0].Function.Name, "fn1")
|
| 252 |
+
}
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
func TestParseFunctionToolCalls_GlmBlockInvalidJSON(t *testing.T) {
|
| 256 |
+
content := `<glm_block type="tool_call">not valid json</glm_block>`
|
| 257 |
+
|
| 258 |
+
calls := ParseFunctionToolCalls(content)
|
| 259 |
+
if len(calls) != 0 {
|
| 260 |
+
t.Errorf("len(calls) = %d, want 0", len(calls))
|
| 261 |
+
}
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
// ===== 优先级:glm_block 优先于原始 JSON =====
|
| 265 |
+
|
| 266 |
+
func TestParseFunctionToolCalls_GlmBlockPriority(t *testing.T) {
|
| 267 |
+
// 如果同时存在 glm_block 和外层 JSON,优先从 glm_block 提取
|
| 268 |
+
content := `<glm_block type="tool_call">{"function":{"name":"from_block","arguments":"{}"}}</glm_block>`
|
| 269 |
+
|
| 270 |
+
calls := ParseFunctionToolCalls(content)
|
| 271 |
+
if len(calls) != 1 {
|
| 272 |
+
t.Fatalf("len(calls) = %d, want 1", len(calls))
|
| 273 |
+
}
|
| 274 |
+
if calls[0].Function.Name != "from_block" {
|
| 275 |
+
t.Errorf("Function.Name = %q, want %q", calls[0].Function.Name, "from_block")
|
| 276 |
+
}
|
| 277 |
+
}
|
internal/handler/chat.go
CHANGED
|
@@ -45,7 +45,7 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
|
|
| 45 |
req.Model = "GLM-4.6"
|
| 46 |
}
|
| 47 |
|
| 48 |
-
resp, modelName, err := upstream.MakeUpstreamRequest(token, req.Messages, req.Model)
|
| 49 |
if err != nil {
|
| 50 |
logger.LogError("Upstream request failed: %v", err)
|
| 51 |
http.Error(w, "Upstream error", http.StatusBadGateway)
|
|
@@ -67,13 +67,13 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
|
|
| 67 |
completionID := fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:29])
|
| 68 |
|
| 69 |
if req.Stream {
|
| 70 |
-
handleStreamResponse(w, resp.Body, completionID, modelName)
|
| 71 |
} else {
|
| 72 |
-
handleNonStreamResponse(w, resp.Body, completionID, modelName)
|
| 73 |
}
|
| 74 |
}
|
| 75 |
|
| 76 |
-
func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionID, modelName string) {
|
| 77 |
w.Header().Set("Content-Type", "text/event-stream")
|
| 78 |
w.Header().Set("Cache-Control", "no-cache")
|
| 79 |
w.Header().Set("Connection", "keep-alive")
|
|
@@ -92,6 +92,8 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 92 |
pendingSourcesMarkdown := ""
|
| 93 |
pendingImageSearchMarkdown := ""
|
| 94 |
totalContentOutputLength := 0
|
|
|
|
|
|
|
| 95 |
|
| 96 |
for scanner.Scan() {
|
| 97 |
line := scanner.Text()
|
|
@@ -221,6 +223,43 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 221 |
if editContent != "" && filter.IsSearchToolCall(editContent, upstreamData.Data.Phase) {
|
| 222 |
continue
|
| 223 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
if pendingSourcesMarkdown != "" {
|
| 226 |
hasContent = true
|
|
@@ -410,6 +449,9 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 410 |
}
|
| 411 |
|
| 412 |
stopReason := "stop"
|
|
|
|
|
|
|
|
|
|
| 413 |
finalChunk := model.ChatCompletionChunk{
|
| 414 |
ID: completionID,
|
| 415 |
Object: "chat.completion.chunk",
|
|
@@ -428,7 +470,7 @@ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionI
|
|
| 428 |
flusher.Flush()
|
| 429 |
}
|
| 430 |
|
| 431 |
-
func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionID, modelName string) {
|
| 432 |
scanner := bufio.NewScanner(body)
|
| 433 |
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
| 434 |
var chunks []string
|
|
@@ -438,6 +480,7 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
|
|
| 438 |
hasThinking := false
|
| 439 |
pendingSourcesMarkdown := ""
|
| 440 |
pendingImageSearchMarkdown := ""
|
|
|
|
| 441 |
|
| 442 |
for scanner.Scan() {
|
| 443 |
line := scanner.Text()
|
|
@@ -510,6 +553,22 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
|
|
| 510 |
if editContent != "" && filter.IsSearchToolCall(editContent, upstreamData.Data.Phase) {
|
| 511 |
continue
|
| 512 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
|
| 514 |
if pendingSourcesMarkdown != "" {
|
| 515 |
if hasThinking {
|
|
@@ -557,11 +616,14 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
|
|
| 557 |
fullReasoning := strings.Join(reasoningChunks, "")
|
| 558 |
fullReasoning = searchRefFilter.Process(fullReasoning) + searchRefFilter.Flush()
|
| 559 |
|
| 560 |
-
if fullContent == "" {
|
| 561 |
logger.LogError("Non-stream response 200 but no content received")
|
| 562 |
}
|
| 563 |
|
| 564 |
stopReason := "stop"
|
|
|
|
|
|
|
|
|
|
| 565 |
response := model.ChatCompletionResponse{
|
| 566 |
ID: completionID,
|
| 567 |
Object: "chat.completion",
|
|
@@ -573,6 +635,7 @@ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completi
|
|
| 573 |
Role: "assistant",
|
| 574 |
Content: fullContent,
|
| 575 |
ReasoningContent: fullReasoning,
|
|
|
|
| 576 |
},
|
| 577 |
FinishReason: &stopReason,
|
| 578 |
}},
|
|
|
|
| 45 |
req.Model = "GLM-4.6"
|
| 46 |
}
|
| 47 |
|
| 48 |
+
resp, modelName, err := upstream.MakeUpstreamRequest(token, req.Messages, req.Model, req.Tools, req.ToolChoice)
|
| 49 |
if err != nil {
|
| 50 |
logger.LogError("Upstream request failed: %v", err)
|
| 51 |
http.Error(w, "Upstream error", http.StatusBadGateway)
|
|
|
|
| 67 |
completionID := fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:29])
|
| 68 |
|
| 69 |
if req.Stream {
|
| 70 |
+
handleStreamResponse(w, resp.Body, completionID, modelName, req.Tools)
|
| 71 |
} else {
|
| 72 |
+
handleNonStreamResponse(w, resp.Body, completionID, modelName, req.Tools)
|
| 73 |
}
|
| 74 |
}
|
| 75 |
|
| 76 |
+
func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionID, modelName string, tools []model.Tool) {
|
| 77 |
w.Header().Set("Content-Type", "text/event-stream")
|
| 78 |
w.Header().Set("Cache-Control", "no-cache")
|
| 79 |
w.Header().Set("Connection", "keep-alive")
|
|
|
|
| 92 |
pendingSourcesMarkdown := ""
|
| 93 |
pendingImageSearchMarkdown := ""
|
| 94 |
totalContentOutputLength := 0
|
| 95 |
+
hasToolCalls := false
|
| 96 |
+
var collectedToolCalls []model.ToolCall
|
| 97 |
|
| 98 |
for scanner.Scan() {
|
| 99 |
line := scanner.Text()
|
|
|
|
| 223 |
if editContent != "" && filter.IsSearchToolCall(editContent, upstreamData.Data.Phase) {
|
| 224 |
continue
|
| 225 |
}
|
| 226 |
+
// 检测用户定义的函数调用(tool_call 阶段,非 mcp/search)
|
| 227 |
+
if upstreamData.Data.Phase == "tool_call" && editContent != "" {
|
| 228 |
+
logger.LogInfo("[ToolCall] phase=%s edit_content=%s", upstreamData.Data.Phase, editContent)
|
| 229 |
+
}
|
| 230 |
+
if len(tools) > 0 && editContent != "" && filter.IsFunctionToolCall(editContent, upstreamData.Data.Phase) {
|
| 231 |
+
if toolCalls := filter.ParseFunctionToolCalls(editContent); len(toolCalls) > 0 {
|
| 232 |
+
for i := range toolCalls {
|
| 233 |
+
if toolCalls[i].ID == "" {
|
| 234 |
+
toolCalls[i].ID = fmt.Sprintf("call_%s", uuid.New().String()[:24])
|
| 235 |
+
}
|
| 236 |
+
toolCalls[i].Index = i
|
| 237 |
+
}
|
| 238 |
+
collectedToolCalls = toolCalls
|
| 239 |
+
hasToolCalls = true
|
| 240 |
+
|
| 241 |
+
for _, tc := range toolCalls {
|
| 242 |
+
hasContent = true
|
| 243 |
+
chunk := model.ChatCompletionChunk{
|
| 244 |
+
ID: completionID,
|
| 245 |
+
Object: "chat.completion.chunk",
|
| 246 |
+
Created: time.Now().Unix(),
|
| 247 |
+
Model: modelName,
|
| 248 |
+
Choices: []model.Choice{{
|
| 249 |
+
Index: 0,
|
| 250 |
+
Delta: model.Delta{
|
| 251 |
+
ToolCalls: []model.ToolCall{tc},
|
| 252 |
+
},
|
| 253 |
+
FinishReason: nil,
|
| 254 |
+
}},
|
| 255 |
+
}
|
| 256 |
+
data, _ := json.Marshal(chunk)
|
| 257 |
+
fmt.Fprintf(w, "data: %s\n\n", data)
|
| 258 |
+
flusher.Flush()
|
| 259 |
+
}
|
| 260 |
+
}
|
| 261 |
+
continue
|
| 262 |
+
}
|
| 263 |
|
| 264 |
if pendingSourcesMarkdown != "" {
|
| 265 |
hasContent = true
|
|
|
|
| 449 |
}
|
| 450 |
|
| 451 |
stopReason := "stop"
|
| 452 |
+
if hasToolCalls && len(collectedToolCalls) > 0 {
|
| 453 |
+
stopReason = "tool_calls"
|
| 454 |
+
}
|
| 455 |
finalChunk := model.ChatCompletionChunk{
|
| 456 |
ID: completionID,
|
| 457 |
Object: "chat.completion.chunk",
|
|
|
|
| 470 |
flusher.Flush()
|
| 471 |
}
|
| 472 |
|
| 473 |
+
func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionID, modelName string, tools []model.Tool) {
|
| 474 |
scanner := bufio.NewScanner(body)
|
| 475 |
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
|
| 476 |
var chunks []string
|
|
|
|
| 480 |
hasThinking := false
|
| 481 |
pendingSourcesMarkdown := ""
|
| 482 |
pendingImageSearchMarkdown := ""
|
| 483 |
+
var collectedToolCalls []model.ToolCall
|
| 484 |
|
| 485 |
for scanner.Scan() {
|
| 486 |
line := scanner.Text()
|
|
|
|
| 553 |
if editContent != "" && filter.IsSearchToolCall(editContent, upstreamData.Data.Phase) {
|
| 554 |
continue
|
| 555 |
}
|
| 556 |
+
// 检测用户定义的函数调用
|
| 557 |
+
if upstreamData.Data.Phase == "tool_call" && editContent != "" {
|
| 558 |
+
logger.LogInfo("[ToolCall] phase=%s edit_content=%s", upstreamData.Data.Phase, editContent)
|
| 559 |
+
}
|
| 560 |
+
if len(tools) > 0 && editContent != "" && filter.IsFunctionToolCall(editContent, upstreamData.Data.Phase) {
|
| 561 |
+
if toolCalls := filter.ParseFunctionToolCalls(editContent); len(toolCalls) > 0 {
|
| 562 |
+
for i := range toolCalls {
|
| 563 |
+
if toolCalls[i].ID == "" {
|
| 564 |
+
toolCalls[i].ID = fmt.Sprintf("call_%s", uuid.New().String()[:24])
|
| 565 |
+
}
|
| 566 |
+
toolCalls[i].Index = i
|
| 567 |
+
}
|
| 568 |
+
collectedToolCalls = toolCalls
|
| 569 |
+
}
|
| 570 |
+
continue
|
| 571 |
+
}
|
| 572 |
|
| 573 |
if pendingSourcesMarkdown != "" {
|
| 574 |
if hasThinking {
|
|
|
|
| 616 |
fullReasoning := strings.Join(reasoningChunks, "")
|
| 617 |
fullReasoning = searchRefFilter.Process(fullReasoning) + searchRefFilter.Flush()
|
| 618 |
|
| 619 |
+
if fullContent == "" && len(collectedToolCalls) == 0 {
|
| 620 |
logger.LogError("Non-stream response 200 but no content received")
|
| 621 |
}
|
| 622 |
|
| 623 |
stopReason := "stop"
|
| 624 |
+
if len(collectedToolCalls) > 0 {
|
| 625 |
+
stopReason = "tool_calls"
|
| 626 |
+
}
|
| 627 |
response := model.ChatCompletionResponse{
|
| 628 |
ID: completionID,
|
| 629 |
Object: "chat.completion",
|
|
|
|
| 635 |
Role: "assistant",
|
| 636 |
Content: fullContent,
|
| 637 |
ReasoningContent: fullReasoning,
|
| 638 |
+
ToolCalls: collectedToolCalls,
|
| 639 |
},
|
| 640 |
FinishReason: &stopReason,
|
| 641 |
}},
|
internal/handler/chat_test.go
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package handler
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"fmt"
|
| 6 |
+
"io"
|
| 7 |
+
"net/http/httptest"
|
| 8 |
+
"strings"
|
| 9 |
+
"testing"
|
| 10 |
+
|
| 11 |
+
"zai-proxy/internal/model"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
// fakeReadCloser 将 string 包装为 io.ReadCloser
|
| 15 |
+
type fakeReadCloser struct {
|
| 16 |
+
io.Reader
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
func (f *fakeReadCloser) Close() error { return nil }
|
| 20 |
+
|
| 21 |
+
func newFakeBody(lines ...string) io.ReadCloser {
|
| 22 |
+
return &fakeReadCloser{Reader: strings.NewReader(strings.Join(lines, "\n"))}
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
// 构造上游 SSE 数据行
|
| 26 |
+
func sseEvent(phase, deltaContent, editContent string) string {
|
| 27 |
+
data := model.UpstreamData{}
|
| 28 |
+
data.Data.Phase = phase
|
| 29 |
+
data.Data.DeltaContent = deltaContent
|
| 30 |
+
data.Data.EditContent = editContent
|
| 31 |
+
b, _ := json.Marshal(data)
|
| 32 |
+
return fmt.Sprintf("data: %s", string(b))
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
func sseEventDone() string {
|
| 36 |
+
return sseEvent("done", "", "")
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
func dummyTools() []model.Tool {
|
| 40 |
+
return []model.Tool{{
|
| 41 |
+
Type: "function",
|
| 42 |
+
Function: model.ToolFunction{
|
| 43 |
+
Name: "get_weather",
|
| 44 |
+
Description: "获取天气",
|
| 45 |
+
},
|
| 46 |
+
}}
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
// ===== 流式:普通文本回复 =====
|
| 50 |
+
|
| 51 |
+
func TestStreamResponse_NormalContent(t *testing.T) {
|
| 52 |
+
body := newFakeBody(
|
| 53 |
+
sseEvent("answer", "Hello", ""),
|
| 54 |
+
sseEvent("answer", " World", ""),
|
| 55 |
+
sseEventDone(),
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
w := httptest.NewRecorder()
|
| 59 |
+
handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
|
| 60 |
+
|
| 61 |
+
result := w.Body.String()
|
| 62 |
+
|
| 63 |
+
// 应包含内容 chunk
|
| 64 |
+
if !strings.Contains(result, "Hello") {
|
| 65 |
+
t.Error("missing 'Hello' in stream output")
|
| 66 |
+
}
|
| 67 |
+
if !strings.Contains(result, "World") {
|
| 68 |
+
t.Error("missing 'World' in stream output")
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
// finish_reason 应该是 "stop"
|
| 72 |
+
if !strings.Contains(result, `"finish_reason":"stop"`) {
|
| 73 |
+
t.Error("finish_reason should be 'stop'")
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// 应以 [DONE] 结尾
|
| 77 |
+
if !strings.Contains(result, "data: [DONE]") {
|
| 78 |
+
t.Error("missing [DONE]")
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
// ===== 流式:tool_call 回复 =====
|
| 83 |
+
|
| 84 |
+
func TestStreamResponse_ToolCall(t *testing.T) {
|
| 85 |
+
toolCallJSON := `{"id":"call_test123","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"北京\"}"}}`
|
| 86 |
+
|
| 87 |
+
body := newFakeBody(
|
| 88 |
+
sseEvent("tool_call", "", toolCallJSON),
|
| 89 |
+
sseEventDone(),
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
w := httptest.NewRecorder()
|
| 93 |
+
handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 94 |
+
|
| 95 |
+
result := w.Body.String()
|
| 96 |
+
|
| 97 |
+
// 应包含 tool_calls
|
| 98 |
+
if !strings.Contains(result, `"tool_calls"`) {
|
| 99 |
+
t.Error("missing tool_calls in stream output")
|
| 100 |
+
}
|
| 101 |
+
if !strings.Contains(result, `"get_weather"`) {
|
| 102 |
+
t.Error("missing function name in stream output")
|
| 103 |
+
}
|
| 104 |
+
if !strings.Contains(result, `call_test123`) {
|
| 105 |
+
t.Error("missing tool call ID in stream output")
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// finish_reason 应该是 "tool_calls"
|
| 109 |
+
if !strings.Contains(result, `"finish_reason":"tool_calls"`) {
|
| 110 |
+
t.Error("finish_reason should be 'tool_calls'")
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
// ===== 流式:tool_call 无 ID(自动分配)=====
|
| 115 |
+
|
| 116 |
+
func TestStreamResponse_ToolCallAutoID(t *testing.T) {
|
| 117 |
+
toolCallJSON := `{"type":"function","function":{"name":"get_weather","arguments":"{}"}}`
|
| 118 |
+
|
| 119 |
+
body := newFakeBody(
|
| 120 |
+
sseEvent("tool_call", "", toolCallJSON),
|
| 121 |
+
sseEventDone(),
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
w := httptest.NewRecorder()
|
| 125 |
+
handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 126 |
+
|
| 127 |
+
result := w.Body.String()
|
| 128 |
+
|
| 129 |
+
// 应自动分配 call_ 前缀的 ID
|
| 130 |
+
if !strings.Contains(result, `"id":"call_`) {
|
| 131 |
+
t.Error("missing auto-generated tool call ID")
|
| 132 |
+
}
|
| 133 |
+
if !strings.Contains(result, `"finish_reason":"tool_calls"`) {
|
| 134 |
+
t.Error("finish_reason should be 'tool_calls'")
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
// ===== 流式:无 tools 时 tool_call 阶段被忽略 =====
|
| 139 |
+
|
| 140 |
+
func TestStreamResponse_ToolCallWithoutToolsDef(t *testing.T) {
|
| 141 |
+
toolCallJSON := `{"type":"function","function":{"name":"get_weather","arguments":"{}"}}`
|
| 142 |
+
|
| 143 |
+
body := newFakeBody(
|
| 144 |
+
sseEvent("answer", "text before", ""),
|
| 145 |
+
sseEvent("tool_call", "", toolCallJSON),
|
| 146 |
+
sseEventDone(),
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
w := httptest.NewRecorder()
|
| 150 |
+
// 不传 tools,tool_call 不应被解析为函数调用
|
| 151 |
+
handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
|
| 152 |
+
|
| 153 |
+
result := w.Body.String()
|
| 154 |
+
|
| 155 |
+
// finish_reason 应为 "stop"(没有检测到函数调用)
|
| 156 |
+
if !strings.Contains(result, `"finish_reason":"stop"`) {
|
| 157 |
+
t.Error("finish_reason should be 'stop' when no tools defined")
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
// ===== 流式:mcp tool_call 被跳过 =====
|
| 162 |
+
|
| 163 |
+
func TestStreamResponse_McpToolCallSkipped(t *testing.T) {
|
| 164 |
+
mcpContent := `{"type":"mcp","name":"mcp-server-xxx","arguments":"{}"}`
|
| 165 |
+
|
| 166 |
+
body := newFakeBody(
|
| 167 |
+
sseEvent("answer", "response text", ""),
|
| 168 |
+
sseEvent("tool_call", "", mcpContent),
|
| 169 |
+
sseEventDone(),
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
w := httptest.NewRecorder()
|
| 173 |
+
handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 174 |
+
|
| 175 |
+
result := w.Body.String()
|
| 176 |
+
|
| 177 |
+
// mcp 类型的 tool_call 不应出现在输出中
|
| 178 |
+
if strings.Contains(result, `mcp-server`) {
|
| 179 |
+
t.Error("mcp tool call should be filtered out")
|
| 180 |
+
}
|
| 181 |
+
// 应为 "stop"(mcp 不算用户函数调用)
|
| 182 |
+
if !strings.Contains(result, `"finish_reason":"stop"`) {
|
| 183 |
+
t.Error("finish_reason should be 'stop'")
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
// ===== 流式:混合内容 + tool_call =====
|
| 188 |
+
|
| 189 |
+
func TestStreamResponse_ContentThenToolCall(t *testing.T) {
|
| 190 |
+
toolCallJSON := `{"function":{"name":"get_weather","arguments":"{}"}}`
|
| 191 |
+
|
| 192 |
+
body := newFakeBody(
|
| 193 |
+
sseEvent("answer", "Let me check ", ""),
|
| 194 |
+
sseEvent("answer", "the weather.", ""),
|
| 195 |
+
sseEvent("tool_call", "", toolCallJSON),
|
| 196 |
+
sseEventDone(),
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
w := httptest.NewRecorder()
|
| 200 |
+
handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 201 |
+
|
| 202 |
+
result := w.Body.String()
|
| 203 |
+
|
| 204 |
+
if !strings.Contains(result, "Let me check") {
|
| 205 |
+
t.Error("missing content text")
|
| 206 |
+
}
|
| 207 |
+
if !strings.Contains(result, `"get_weather"`) {
|
| 208 |
+
t.Error("missing tool call")
|
| 209 |
+
}
|
| 210 |
+
if !strings.Contains(result, `"finish_reason":"tool_calls"`) {
|
| 211 |
+
t.Error("finish_reason should be 'tool_calls'")
|
| 212 |
+
}
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
// ===== 流式:多个 tool_call =====
|
| 216 |
+
|
| 217 |
+
func TestStreamResponse_MultipleToolCalls(t *testing.T) {
|
| 218 |
+
toolCallJSON := `[{"id":"c1","type":"function","function":{"name":"fn1","arguments":"{}"}},{"id":"c2","type":"function","function":{"name":"fn2","arguments":"{}"}}]`
|
| 219 |
+
|
| 220 |
+
body := newFakeBody(
|
| 221 |
+
sseEvent("tool_call", "", toolCallJSON),
|
| 222 |
+
sseEventDone(),
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
w := httptest.NewRecorder()
|
| 226 |
+
handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 227 |
+
|
| 228 |
+
result := w.Body.String()
|
| 229 |
+
|
| 230 |
+
if !strings.Contains(result, `"fn1"`) {
|
| 231 |
+
t.Error("missing fn1")
|
| 232 |
+
}
|
| 233 |
+
if !strings.Contains(result, `"fn2"`) {
|
| 234 |
+
t.Error("missing fn2")
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
// 验证 chunk 数量:每个 tool_call 一个 delta chunk(包含 "tool_calls" 在 delta 中)
|
| 238 |
+
chunks := strings.Split(result, "data: ")
|
| 239 |
+
toolCallDeltaChunks := 0
|
| 240 |
+
for _, chunk := range chunks {
|
| 241 |
+
// 只计算 delta 中包含 tool_calls 的 chunk,排除 finish_reason 中的
|
| 242 |
+
if strings.Contains(chunk, `"tool_calls":[{`) {
|
| 243 |
+
toolCallDeltaChunks++
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
if toolCallDeltaChunks != 2 {
|
| 247 |
+
t.Errorf("tool_call delta chunks = %d, want 2", toolCallDeltaChunks)
|
| 248 |
+
}
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
// ===== 非流式:普通文本回复 =====
|
| 252 |
+
|
| 253 |
+
func TestNonStreamResponse_NormalContent(t *testing.T) {
|
| 254 |
+
body := newFakeBody(
|
| 255 |
+
sseEvent("answer", "Hello World", ""),
|
| 256 |
+
sseEventDone(),
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
w := httptest.NewRecorder()
|
| 260 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
|
| 261 |
+
|
| 262 |
+
var resp model.ChatCompletionResponse
|
| 263 |
+
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
| 264 |
+
t.Fatalf("decode response: %v", err)
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
if len(resp.Choices) != 1 {
|
| 268 |
+
t.Fatalf("len(Choices) = %d", len(resp.Choices))
|
| 269 |
+
}
|
| 270 |
+
if resp.Choices[0].Message == nil {
|
| 271 |
+
t.Fatal("Message is nil")
|
| 272 |
+
}
|
| 273 |
+
if resp.Choices[0].Message.Content != "Hello World" {
|
| 274 |
+
t.Errorf("Content = %q, want %q", resp.Choices[0].Message.Content, "Hello World")
|
| 275 |
+
}
|
| 276 |
+
if *resp.Choices[0].FinishReason != "stop" {
|
| 277 |
+
t.Errorf("FinishReason = %q, want %q", *resp.Choices[0].FinishReason, "stop")
|
| 278 |
+
}
|
| 279 |
+
if len(resp.Choices[0].Message.ToolCalls) != 0 {
|
| 280 |
+
t.Errorf("len(ToolCalls) = %d, want 0", len(resp.Choices[0].Message.ToolCalls))
|
| 281 |
+
}
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
// ===== 非流式:tool_call 回复 =====
|
| 285 |
+
|
| 286 |
+
func TestNonStreamResponse_ToolCall(t *testing.T) {
|
| 287 |
+
toolCallJSON := `{"id":"call_ns","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"上海\"}"}}`
|
| 288 |
+
|
| 289 |
+
body := newFakeBody(
|
| 290 |
+
sseEvent("tool_call", "", toolCallJSON),
|
| 291 |
+
sseEventDone(),
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
w := httptest.NewRecorder()
|
| 295 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 296 |
+
|
| 297 |
+
var resp model.ChatCompletionResponse
|
| 298 |
+
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
| 299 |
+
t.Fatalf("decode: %v", err)
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
msg := resp.Choices[0].Message
|
| 303 |
+
if msg == nil {
|
| 304 |
+
t.Fatal("Message is nil")
|
| 305 |
+
}
|
| 306 |
+
if len(msg.ToolCalls) != 1 {
|
| 307 |
+
t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
|
| 308 |
+
}
|
| 309 |
+
if msg.ToolCalls[0].Function.Name != "get_weather" {
|
| 310 |
+
t.Errorf("Function.Name = %q, want %q", msg.ToolCalls[0].Function.Name, "get_weather")
|
| 311 |
+
}
|
| 312 |
+
if msg.ToolCalls[0].Function.Arguments != `{"location":"上海"}` {
|
| 313 |
+
t.Errorf("Function.Arguments = %q", msg.ToolCalls[0].Function.Arguments)
|
| 314 |
+
}
|
| 315 |
+
if *resp.Choices[0].FinishReason != "tool_calls" {
|
| 316 |
+
t.Errorf("FinishReason = %q, want %q", *resp.Choices[0].FinishReason, "tool_calls")
|
| 317 |
+
}
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
// ===== 非流式:tool_call 无 ID =====
|
| 321 |
+
|
| 322 |
+
func TestNonStreamResponse_ToolCallAutoID(t *testing.T) {
|
| 323 |
+
toolCallJSON := `{"function":{"name":"fn1","arguments":"{}"}}`
|
| 324 |
+
|
| 325 |
+
body := newFakeBody(
|
| 326 |
+
sseEvent("tool_call", "", toolCallJSON),
|
| 327 |
+
sseEventDone(),
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
w := httptest.NewRecorder()
|
| 331 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 332 |
+
|
| 333 |
+
var resp model.ChatCompletionResponse
|
| 334 |
+
json.NewDecoder(w.Body).Decode(&resp)
|
| 335 |
+
|
| 336 |
+
msg := resp.Choices[0].Message
|
| 337 |
+
if len(msg.ToolCalls) != 1 {
|
| 338 |
+
t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
|
| 339 |
+
}
|
| 340 |
+
if !strings.HasPrefix(msg.ToolCalls[0].ID, "call_") {
|
| 341 |
+
t.Errorf("ID = %q, should have 'call_' prefix", msg.ToolCalls[0].ID)
|
| 342 |
+
}
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
// ===== 非流式:无 tools 定义时不解析 tool_call =====
|
| 346 |
+
|
| 347 |
+
func TestNonStreamResponse_ToolCallWithoutToolsDef(t *testing.T) {
|
| 348 |
+
toolCallJSON := `{"function":{"name":"get_weather","arguments":"{}"}}`
|
| 349 |
+
|
| 350 |
+
body := newFakeBody(
|
| 351 |
+
sseEvent("tool_call", "", toolCallJSON),
|
| 352 |
+
sseEventDone(),
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
w := httptest.NewRecorder()
|
| 356 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
|
| 357 |
+
|
| 358 |
+
var resp model.ChatCompletionResponse
|
| 359 |
+
json.NewDecoder(w.Body).Decode(&resp)
|
| 360 |
+
|
| 361 |
+
if *resp.Choices[0].FinishReason != "stop" {
|
| 362 |
+
t.Errorf("FinishReason = %q, want %q", *resp.Choices[0].FinishReason, "stop")
|
| 363 |
+
}
|
| 364 |
+
if len(resp.Choices[0].Message.ToolCalls) != 0 {
|
| 365 |
+
t.Errorf("len(ToolCalls) = %d, want 0", len(resp.Choices[0].Message.ToolCalls))
|
| 366 |
+
}
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
// ===== 非流式:mcp tool_call 被跳过 =====
|
| 370 |
+
|
| 371 |
+
func TestNonStreamResponse_McpToolCallSkipped(t *testing.T) {
|
| 372 |
+
mcpContent := `{"type":"mcp","name":"mcp-server-xxx","arguments":"{}"}`
|
| 373 |
+
|
| 374 |
+
body := newFakeBody(
|
| 375 |
+
sseEvent("answer", "response", ""),
|
| 376 |
+
sseEvent("tool_call", "", mcpContent),
|
| 377 |
+
sseEventDone(),
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
w := httptest.NewRecorder()
|
| 381 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 382 |
+
|
| 383 |
+
var resp model.ChatCompletionResponse
|
| 384 |
+
json.NewDecoder(w.Body).Decode(&resp)
|
| 385 |
+
|
| 386 |
+
if *resp.Choices[0].FinishReason != "stop" {
|
| 387 |
+
t.Errorf("FinishReason = %q, want %q", *resp.Choices[0].FinishReason, "stop")
|
| 388 |
+
}
|
| 389 |
+
if len(resp.Choices[0].Message.ToolCalls) != 0 {
|
| 390 |
+
t.Errorf("should not have tool_calls for mcp")
|
| 391 |
+
}
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
// ===== 非流式:内容 + tool_call =====
|
| 395 |
+
|
| 396 |
+
func TestNonStreamResponse_ContentAndToolCall(t *testing.T) {
|
| 397 |
+
toolCallJSON := `{"function":{"name":"get_weather","arguments":"{}"}}`
|
| 398 |
+
|
| 399 |
+
body := newFakeBody(
|
| 400 |
+
sseEvent("answer", "checking weather...", ""),
|
| 401 |
+
sseEvent("tool_call", "", toolCallJSON),
|
| 402 |
+
sseEventDone(),
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
w := httptest.NewRecorder()
|
| 406 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 407 |
+
|
| 408 |
+
var resp model.ChatCompletionResponse
|
| 409 |
+
json.NewDecoder(w.Body).Decode(&resp)
|
| 410 |
+
|
| 411 |
+
msg := resp.Choices[0].Message
|
| 412 |
+
if msg.Content != "checking weather..." {
|
| 413 |
+
t.Errorf("Content = %q, want %q", msg.Content, "checking weather...")
|
| 414 |
+
}
|
| 415 |
+
if len(msg.ToolCalls) != 1 {
|
| 416 |
+
t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
|
| 417 |
+
}
|
| 418 |
+
if *resp.Choices[0].FinishReason != "tool_calls" {
|
| 419 |
+
t.Errorf("FinishReason = %q, want %q", *resp.Choices[0].FinishReason, "tool_calls")
|
| 420 |
+
}
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
// ===== 非流式:多个 tool_call =====
|
| 424 |
+
|
| 425 |
+
func TestNonStreamResponse_MultipleToolCalls(t *testing.T) {
|
| 426 |
+
toolCallJSON := `[{"id":"c1","type":"function","function":{"name":"fn1","arguments":"{}"}},{"id":"c2","type":"function","function":{"name":"fn2","arguments":"{\"x\":1}"}}]`
|
| 427 |
+
|
| 428 |
+
body := newFakeBody(
|
| 429 |
+
sseEvent("tool_call", "", toolCallJSON),
|
| 430 |
+
sseEventDone(),
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
w := httptest.NewRecorder()
|
| 434 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 435 |
+
|
| 436 |
+
var resp model.ChatCompletionResponse
|
| 437 |
+
json.NewDecoder(w.Body).Decode(&resp)
|
| 438 |
+
|
| 439 |
+
msg := resp.Choices[0].Message
|
| 440 |
+
if len(msg.ToolCalls) != 2 {
|
| 441 |
+
t.Fatalf("len(ToolCalls) = %d, want 2", len(msg.ToolCalls))
|
| 442 |
+
}
|
| 443 |
+
if msg.ToolCalls[0].Function.Name != "fn1" {
|
| 444 |
+
t.Errorf("ToolCalls[0].Function.Name = %q", msg.ToolCalls[0].Function.Name)
|
| 445 |
+
}
|
| 446 |
+
if msg.ToolCalls[1].Function.Name != "fn2" {
|
| 447 |
+
t.Errorf("ToolCalls[1].Function.Name = %q", msg.ToolCalls[1].Function.Name)
|
| 448 |
+
}
|
| 449 |
+
if msg.ToolCalls[0].Index != 0 || msg.ToolCalls[1].Index != 1 {
|
| 450 |
+
t.Errorf("Indices = [%d, %d], want [0, 1]", msg.ToolCalls[0].Index, msg.ToolCalls[1].Index)
|
| 451 |
+
}
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
// ===== 非流式:glm_block 包裹的 tool_call =====
|
| 455 |
+
|
| 456 |
+
func TestNonStreamResponse_GlmBlockToolCall(t *testing.T) {
|
| 457 |
+
editContent := `<glm_block type="tool_call">{"id":"call_glm","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"深圳\"}"}}</glm_block>`
|
| 458 |
+
|
| 459 |
+
body := newFakeBody(
|
| 460 |
+
sseEvent("tool_call", "", editContent),
|
| 461 |
+
sseEventDone(),
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
w := httptest.NewRecorder()
|
| 465 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", dummyTools())
|
| 466 |
+
|
| 467 |
+
var resp model.ChatCompletionResponse
|
| 468 |
+
json.NewDecoder(w.Body).Decode(&resp)
|
| 469 |
+
|
| 470 |
+
msg := resp.Choices[0].Message
|
| 471 |
+
if len(msg.ToolCalls) != 1 {
|
| 472 |
+
t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
|
| 473 |
+
}
|
| 474 |
+
if msg.ToolCalls[0].ID != "call_glm" {
|
| 475 |
+
t.Errorf("ID = %q, want %q", msg.ToolCalls[0].ID, "call_glm")
|
| 476 |
+
}
|
| 477 |
+
if msg.ToolCalls[0].Function.Name != "get_weather" {
|
| 478 |
+
t.Errorf("Function.Name = %q", msg.ToolCalls[0].Function.Name)
|
| 479 |
+
}
|
| 480 |
+
if *resp.Choices[0].FinishReason != "tool_calls" {
|
| 481 |
+
t.Errorf("FinishReason = %q", *resp.Choices[0].FinishReason)
|
| 482 |
+
}
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
// ===== 流式:SSE headers 验证 =====
|
| 486 |
+
|
| 487 |
+
func TestStreamResponse_Headers(t *testing.T) {
|
| 488 |
+
body := newFakeBody(sseEventDone())
|
| 489 |
+
|
| 490 |
+
w := httptest.NewRecorder()
|
| 491 |
+
handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
|
| 492 |
+
|
| 493 |
+
if ct := w.Header().Get("Content-Type"); ct != "text/event-stream" {
|
| 494 |
+
t.Errorf("Content-Type = %q, want %q", ct, "text/event-stream")
|
| 495 |
+
}
|
| 496 |
+
if cc := w.Header().Get("Cache-Control"); cc != "no-cache" {
|
| 497 |
+
t.Errorf("Cache-Control = %q, want %q", cc, "no-cache")
|
| 498 |
+
}
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
// ===== 非流式:response headers 验证 =====
|
| 502 |
+
|
| 503 |
+
func TestNonStreamResponse_Headers(t *testing.T) {
|
| 504 |
+
body := newFakeBody(sseEventDone())
|
| 505 |
+
|
| 506 |
+
w := httptest.NewRecorder()
|
| 507 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
|
| 508 |
+
|
| 509 |
+
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
|
| 510 |
+
t.Errorf("Content-Type = %q, want %q", ct, "application/json")
|
| 511 |
+
}
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
// ===== 流式:空数据 =====
|
| 515 |
+
|
| 516 |
+
func TestStreamResponse_EmptyBody(t *testing.T) {
|
| 517 |
+
body := newFakeBody(sseEventDone())
|
| 518 |
+
|
| 519 |
+
w := httptest.NewRecorder()
|
| 520 |
+
handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
|
| 521 |
+
|
| 522 |
+
result := w.Body.String()
|
| 523 |
+
if !strings.Contains(result, `"finish_reason":"stop"`) {
|
| 524 |
+
t.Error("should have stop finish_reason")
|
| 525 |
+
}
|
| 526 |
+
if !strings.Contains(result, "data: [DONE]") {
|
| 527 |
+
t.Error("missing [DONE]")
|
| 528 |
+
}
|
| 529 |
+
}
|
| 530 |
+
|
| 531 |
+
// ===== 流式:[DONE] 信号 =====
|
| 532 |
+
|
| 533 |
+
func TestStreamResponse_DoneSignal(t *testing.T) {
|
| 534 |
+
body := newFakeBody(
|
| 535 |
+
sseEvent("answer", "hello", ""),
|
| 536 |
+
"data: [DONE]",
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
w := httptest.NewRecorder()
|
| 540 |
+
handleStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
|
| 541 |
+
|
| 542 |
+
result := w.Body.String()
|
| 543 |
+
if !strings.Contains(result, "hello") {
|
| 544 |
+
t.Error("missing content")
|
| 545 |
+
}
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
// ===== 非流式:response 格式完整性 =====
|
| 549 |
+
|
| 550 |
+
func TestNonStreamResponse_FullFormat(t *testing.T) {
|
| 551 |
+
body := newFakeBody(
|
| 552 |
+
sseEvent("answer", "test response", ""),
|
| 553 |
+
sseEventDone(),
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
w := httptest.NewRecorder()
|
| 557 |
+
handleNonStreamResponse(w, body, "chatcmpl-test", "glm-4.7", nil)
|
| 558 |
+
|
| 559 |
+
var resp model.ChatCompletionResponse
|
| 560 |
+
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
| 561 |
+
t.Fatalf("decode: %v", err)
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
if resp.ID != "chatcmpl-test" {
|
| 565 |
+
t.Errorf("ID = %q", resp.ID)
|
| 566 |
+
}
|
| 567 |
+
if resp.Object != "chat.completion" {
|
| 568 |
+
t.Errorf("Object = %q", resp.Object)
|
| 569 |
+
}
|
| 570 |
+
if resp.Model != "glm-4.7" {
|
| 571 |
+
t.Errorf("Model = %q", resp.Model)
|
| 572 |
+
}
|
| 573 |
+
if resp.Choices[0].Message.Role != "assistant" {
|
| 574 |
+
t.Errorf("Role = %q", resp.Choices[0].Message.Role)
|
| 575 |
+
}
|
| 576 |
+
}
|
internal/model/mapping.go
CHANGED
|
@@ -20,6 +20,8 @@ var ModelList = []string{
|
|
| 20 |
"GLM-4.7",
|
| 21 |
"GLM-4.7-thinking",
|
| 22 |
"GLM-4.7-thinking-search",
|
|
|
|
|
|
|
| 23 |
"GLM-4.5-V",
|
| 24 |
"GLM-4.6-V",
|
| 25 |
"GLM-4.6-V-thinking",
|
|
@@ -28,13 +30,14 @@ var ModelList = []string{
|
|
| 28 |
}
|
| 29 |
|
| 30 |
// 解析模型名称,提取基础模型名和标签
|
| 31 |
-
// 支持 -thinking 和 -
|
| 32 |
-
func ParseModelName(model string) (baseModel string, enableThinking bool, enableSearch bool) {
|
| 33 |
enableThinking = false
|
| 34 |
enableSearch = false
|
|
|
|
| 35 |
baseModel = model
|
| 36 |
|
| 37 |
-
// 检查并移除 -thinking 和 -
|
| 38 |
for {
|
| 39 |
if strings.HasSuffix(baseModel, "-thinking") {
|
| 40 |
enableThinking = true
|
|
@@ -42,26 +45,34 @@ func ParseModelName(model string) (baseModel string, enableThinking bool, enable
|
|
| 42 |
} else if strings.HasSuffix(baseModel, "-search") {
|
| 43 |
enableSearch = true
|
| 44 |
baseModel = strings.TrimSuffix(baseModel, "-search")
|
|
|
|
|
|
|
|
|
|
| 45 |
} else {
|
| 46 |
break
|
| 47 |
}
|
| 48 |
}
|
| 49 |
|
| 50 |
-
return baseModel, enableThinking, enableSearch
|
| 51 |
}
|
| 52 |
|
| 53 |
func IsThinkingModel(model string) bool {
|
| 54 |
-
_, enableThinking, _ := ParseModelName(model)
|
| 55 |
return enableThinking
|
| 56 |
}
|
| 57 |
|
| 58 |
func IsSearchModel(model string) bool {
|
| 59 |
-
_, _, enableSearch := ParseModelName(model)
|
| 60 |
return enableSearch
|
| 61 |
}
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
func GetTargetModel(model string) string {
|
| 64 |
-
baseModel, _, _ := ParseModelName(model)
|
| 65 |
if target, ok := BaseModelMapping[baseModel]; ok {
|
| 66 |
return target
|
| 67 |
}
|
|
|
|
| 20 |
"GLM-4.7",
|
| 21 |
"GLM-4.7-thinking",
|
| 22 |
"GLM-4.7-thinking-search",
|
| 23 |
+
"GLM-4.7-tools",
|
| 24 |
+
"GLM-4.7-tools-thinking",
|
| 25 |
"GLM-4.5-V",
|
| 26 |
"GLM-4.6-V",
|
| 27 |
"GLM-4.6-V-thinking",
|
|
|
|
| 30 |
}
|
| 31 |
|
| 32 |
// 解析模型名称,提取基础模型名和标签
|
| 33 |
+
// 支持 -thinking、-search 和 -tools 标签的任意排列组合
|
| 34 |
+
func ParseModelName(model string) (baseModel string, enableThinking bool, enableSearch bool, enableTools bool) {
|
| 35 |
enableThinking = false
|
| 36 |
enableSearch = false
|
| 37 |
+
enableTools = false
|
| 38 |
baseModel = model
|
| 39 |
|
| 40 |
+
// 检查并移除 -thinking、-search 和 -tools 标签(任意顺序)
|
| 41 |
for {
|
| 42 |
if strings.HasSuffix(baseModel, "-thinking") {
|
| 43 |
enableThinking = true
|
|
|
|
| 45 |
} else if strings.HasSuffix(baseModel, "-search") {
|
| 46 |
enableSearch = true
|
| 47 |
baseModel = strings.TrimSuffix(baseModel, "-search")
|
| 48 |
+
} else if strings.HasSuffix(baseModel, "-tools") {
|
| 49 |
+
enableTools = true
|
| 50 |
+
baseModel = strings.TrimSuffix(baseModel, "-tools")
|
| 51 |
} else {
|
| 52 |
break
|
| 53 |
}
|
| 54 |
}
|
| 55 |
|
| 56 |
+
return baseModel, enableThinking, enableSearch, enableTools
|
| 57 |
}
|
| 58 |
|
| 59 |
func IsThinkingModel(model string) bool {
|
| 60 |
+
_, enableThinking, _, _ := ParseModelName(model)
|
| 61 |
return enableThinking
|
| 62 |
}
|
| 63 |
|
| 64 |
func IsSearchModel(model string) bool {
|
| 65 |
+
_, _, enableSearch, _ := ParseModelName(model)
|
| 66 |
return enableSearch
|
| 67 |
}
|
| 68 |
|
| 69 |
+
func IsToolsModel(model string) bool {
|
| 70 |
+
_, _, _, enableTools := ParseModelName(model)
|
| 71 |
+
return enableTools
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
func GetTargetModel(model string) string {
|
| 75 |
+
baseModel, _, _, _ := ParseModelName(model)
|
| 76 |
if target, ok := BaseModelMapping[baseModel]; ok {
|
| 77 |
return target
|
| 78 |
}
|
internal/model/mapping_test.go
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package model
|
| 2 |
+
|
| 3 |
+
import "testing"
|
| 4 |
+
|
| 5 |
+
// ===== ParseModelName =====
|
| 6 |
+
|
| 7 |
+
func TestParseModelName_Plain(t *testing.T) {
|
| 8 |
+
base, thinking, search, tools := ParseModelName("GLM-4.7")
|
| 9 |
+
if base != "GLM-4.7" {
|
| 10 |
+
t.Errorf("base = %q, want %q", base, "GLM-4.7")
|
| 11 |
+
}
|
| 12 |
+
if thinking || search || tools {
|
| 13 |
+
t.Errorf("flags = (%v, %v, %v), want all false", thinking, search, tools)
|
| 14 |
+
}
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
func TestParseModelName_Thinking(t *testing.T) {
|
| 18 |
+
base, thinking, search, tools := ParseModelName("GLM-4.7-thinking")
|
| 19 |
+
if base != "GLM-4.7" {
|
| 20 |
+
t.Errorf("base = %q", base)
|
| 21 |
+
}
|
| 22 |
+
if !thinking {
|
| 23 |
+
t.Error("thinking should be true")
|
| 24 |
+
}
|
| 25 |
+
if search || tools {
|
| 26 |
+
t.Error("search and tools should be false")
|
| 27 |
+
}
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
func TestParseModelName_Search(t *testing.T) {
|
| 31 |
+
base, thinking, search, tools := ParseModelName("GLM-4.7-search")
|
| 32 |
+
if base != "GLM-4.7" {
|
| 33 |
+
t.Errorf("base = %q", base)
|
| 34 |
+
}
|
| 35 |
+
if !search {
|
| 36 |
+
t.Error("search should be true")
|
| 37 |
+
}
|
| 38 |
+
if thinking || tools {
|
| 39 |
+
t.Error("thinking and tools should be false")
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
func TestParseModelName_Tools(t *testing.T) {
|
| 44 |
+
base, thinking, search, tools := ParseModelName("GLM-4.7-tools")
|
| 45 |
+
if base != "GLM-4.7" {
|
| 46 |
+
t.Errorf("base = %q", base)
|
| 47 |
+
}
|
| 48 |
+
if !tools {
|
| 49 |
+
t.Error("tools should be true")
|
| 50 |
+
}
|
| 51 |
+
if thinking || search {
|
| 52 |
+
t.Error("thinking and search should be false")
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
func TestParseModelName_ThinkingSearch(t *testing.T) {
|
| 57 |
+
base, thinking, search, tools := ParseModelName("GLM-4.7-thinking-search")
|
| 58 |
+
if base != "GLM-4.7" {
|
| 59 |
+
t.Errorf("base = %q", base)
|
| 60 |
+
}
|
| 61 |
+
if !thinking || !search {
|
| 62 |
+
t.Error("thinking and search should both be true")
|
| 63 |
+
}
|
| 64 |
+
if tools {
|
| 65 |
+
t.Error("tools should be false")
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
func TestParseModelName_ToolsThinking(t *testing.T) {
|
| 70 |
+
base, thinking, search, tools := ParseModelName("GLM-4.7-tools-thinking")
|
| 71 |
+
if base != "GLM-4.7" {
|
| 72 |
+
t.Errorf("base = %q", base)
|
| 73 |
+
}
|
| 74 |
+
if !tools || !thinking {
|
| 75 |
+
t.Error("tools and thinking should both be true")
|
| 76 |
+
}
|
| 77 |
+
if search {
|
| 78 |
+
t.Error("search should be false")
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
func TestParseModelName_ToolsSearch(t *testing.T) {
|
| 83 |
+
base, thinking, search, tools := ParseModelName("GLM-4.7-tools-search")
|
| 84 |
+
if base != "GLM-4.7" {
|
| 85 |
+
t.Errorf("base = %q", base)
|
| 86 |
+
}
|
| 87 |
+
if !tools || !search {
|
| 88 |
+
t.Error("tools and search should both be true")
|
| 89 |
+
}
|
| 90 |
+
if thinking {
|
| 91 |
+
t.Error("thinking should be false")
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
func TestParseModelName_AllTags(t *testing.T) {
|
| 96 |
+
base, thinking, search, tools := ParseModelName("GLM-4.7-tools-thinking-search")
|
| 97 |
+
if base != "GLM-4.7" {
|
| 98 |
+
t.Errorf("base = %q", base)
|
| 99 |
+
}
|
| 100 |
+
if !thinking || !search || !tools {
|
| 101 |
+
t.Errorf("all flags should be true, got (%v, %v, %v)", thinking, search, tools)
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
func TestParseModelName_ReverseOrder(t *testing.T) {
|
| 106 |
+
base, thinking, search, tools := ParseModelName("GLM-4.7-search-thinking-tools")
|
| 107 |
+
if base != "GLM-4.7" {
|
| 108 |
+
t.Errorf("base = %q", base)
|
| 109 |
+
}
|
| 110 |
+
if !thinking || !search || !tools {
|
| 111 |
+
t.Errorf("all flags should be true, got (%v, %v, %v)", thinking, search, tools)
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
// ===== IsToolsModel =====
|
| 116 |
+
|
| 117 |
+
func TestIsToolsModel_True(t *testing.T) {
|
| 118 |
+
tests := []string{
|
| 119 |
+
"GLM-4.7-tools",
|
| 120 |
+
"GLM-4.7-tools-thinking",
|
| 121 |
+
"GLM-4.7-tools-search",
|
| 122 |
+
"GLM-4.7-thinking-tools",
|
| 123 |
+
"GLM-4.5-tools",
|
| 124 |
+
}
|
| 125 |
+
for _, m := range tests {
|
| 126 |
+
if !IsToolsModel(m) {
|
| 127 |
+
t.Errorf("IsToolsModel(%q) = false, want true", m)
|
| 128 |
+
}
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
func TestIsToolsModel_False(t *testing.T) {
|
| 133 |
+
tests := []string{
|
| 134 |
+
"GLM-4.7",
|
| 135 |
+
"GLM-4.7-thinking",
|
| 136 |
+
"GLM-4.7-search",
|
| 137 |
+
"GLM-4.7-thinking-search",
|
| 138 |
+
}
|
| 139 |
+
for _, m := range tests {
|
| 140 |
+
if IsToolsModel(m) {
|
| 141 |
+
t.Errorf("IsToolsModel(%q) = true, want false", m)
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
// ===== IsThinkingModel / IsSearchModel 不受 -tools 影响 =====
|
| 147 |
+
|
| 148 |
+
func TestIsThinkingModel_WithTools(t *testing.T) {
|
| 149 |
+
if !IsThinkingModel("GLM-4.7-tools-thinking") {
|
| 150 |
+
t.Error("IsThinkingModel should be true for GLM-4.7-tools-thinking")
|
| 151 |
+
}
|
| 152 |
+
if IsThinkingModel("GLM-4.7-tools") {
|
| 153 |
+
t.Error("IsThinkingModel should be false for GLM-4.7-tools")
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
func TestIsSearchModel_WithTools(t *testing.T) {
|
| 158 |
+
if !IsSearchModel("GLM-4.7-tools-search") {
|
| 159 |
+
t.Error("IsSearchModel should be true for GLM-4.7-tools-search")
|
| 160 |
+
}
|
| 161 |
+
if IsSearchModel("GLM-4.7-tools") {
|
| 162 |
+
t.Error("IsSearchModel should be false for GLM-4.7-tools")
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
// ===== GetTargetModel with -tools =====
|
| 167 |
+
|
| 168 |
+
func TestGetTargetModel_WithTools(t *testing.T) {
|
| 169 |
+
target := GetTargetModel("GLM-4.7-tools")
|
| 170 |
+
if target != "glm-4.7" {
|
| 171 |
+
t.Errorf("GetTargetModel(GLM-4.7-tools) = %q, want %q", target, "glm-4.7")
|
| 172 |
+
}
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
func TestGetTargetModel_WithToolsThinking(t *testing.T) {
|
| 176 |
+
target := GetTargetModel("GLM-4.7-tools-thinking")
|
| 177 |
+
if target != "glm-4.7" {
|
| 178 |
+
t.Errorf("GetTargetModel(GLM-4.7-tools-thinking) = %q, want %q", target, "glm-4.7")
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
// ===== ModelList 包含 -tools 变体 =====
|
| 183 |
+
|
| 184 |
+
func TestModelList_ContainsToolsVariants(t *testing.T) {
|
| 185 |
+
expected := map[string]bool{
|
| 186 |
+
"GLM-4.7-tools": false,
|
| 187 |
+
"GLM-4.7-tools-thinking": false,
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
for _, m := range ModelList {
|
| 191 |
+
if _, ok := expected[m]; ok {
|
| 192 |
+
expected[m] = true
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
for name, found := range expected {
|
| 197 |
+
if !found {
|
| 198 |
+
t.Errorf("ModelList missing %q", name)
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
}
|
internal/model/types.go
CHANGED
|
@@ -13,10 +13,39 @@ type ImageURL struct {
|
|
| 13 |
URL string `json:"url"`
|
| 14 |
}
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
// Message 支持纯文本和多模态内容
|
| 17 |
type Message struct {
|
| 18 |
-
Role
|
| 19 |
-
Content
|
|
|
|
|
|
|
| 20 |
}
|
| 21 |
|
| 22 |
// 解析消息内容,返回文本和图片URL列表
|
|
@@ -47,6 +76,37 @@ func (m *Message) ParseContent() (text string, imageURLs []string) {
|
|
| 47 |
|
| 48 |
// 转换为上游消息格式,支持多模态
|
| 49 |
func (m *Message) ToUpstreamMessage(urlToFileID map[string]string) map[string]interface{} {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
text, imageURLs := m.ParseContent()
|
| 51 |
|
| 52 |
// 无图片,返回纯文本
|
|
@@ -83,9 +143,11 @@ func (m *Message) ToUpstreamMessage(urlToFileID map[string]string) map[string]in
|
|
| 83 |
}
|
| 84 |
|
| 85 |
type ChatRequest struct {
|
| 86 |
-
Model
|
| 87 |
-
Messages
|
| 88 |
-
Stream
|
|
|
|
|
|
|
| 89 |
}
|
| 90 |
|
| 91 |
type ChatCompletionChunk struct {
|
|
@@ -96,8 +158,6 @@ type ChatCompletionChunk struct {
|
|
| 96 |
Choices []Choice `json:"choices"`
|
| 97 |
}
|
| 98 |
|
| 99 |
-
type
|
| 100 |
-
|
| 101 |
type Choice struct {
|
| 102 |
Index int `json:"index"`
|
| 103 |
Delta Delta `json:"delta,omitempty"`
|
|
@@ -106,14 +166,16 @@ type Choice struct {
|
|
| 106 |
}
|
| 107 |
|
| 108 |
type Delta struct {
|
| 109 |
-
Content string
|
| 110 |
-
ReasoningContent string
|
|
|
|
| 111 |
}
|
| 112 |
|
| 113 |
type MessageResp struct {
|
| 114 |
-
Role string
|
| 115 |
-
Content string
|
| 116 |
-
ReasoningContent string
|
|
|
|
| 117 |
}
|
| 118 |
|
| 119 |
type ChatCompletionResponse struct {
|
|
|
|
| 13 |
URL string `json:"url"`
|
| 14 |
}
|
| 15 |
|
| 16 |
+
// Tool 工具定义(OpenAI 兼容)
|
| 17 |
+
type Tool struct {
|
| 18 |
+
Type string `json:"type"`
|
| 19 |
+
Function ToolFunction `json:"function"`
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
// ToolFunction 函数定义
|
| 23 |
+
type ToolFunction struct {
|
| 24 |
+
Name string `json:"name"`
|
| 25 |
+
Description string `json:"description,omitempty"`
|
| 26 |
+
Parameters interface{} `json:"parameters,omitempty"`
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// ToolCall 模型返回的工具调用
|
| 30 |
+
type ToolCall struct {
|
| 31 |
+
ID string `json:"id"`
|
| 32 |
+
Type string `json:"type"`
|
| 33 |
+
Function FunctionCall `json:"function"`
|
| 34 |
+
Index int `json:"index"`
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// FunctionCall 函数调用(名称 + 参数 JSON 字符串)
|
| 38 |
+
type FunctionCall struct {
|
| 39 |
+
Name string `json:"name"`
|
| 40 |
+
Arguments string `json:"arguments"`
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
// Message 支持纯文本和多模态内容
|
| 44 |
type Message struct {
|
| 45 |
+
Role string `json:"role"`
|
| 46 |
+
Content interface{} `json:"content"` // string 或 []ContentPart
|
| 47 |
+
ToolCallID string `json:"tool_call_id,omitempty"` // role: "tool" 时使用
|
| 48 |
+
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // role: "assistant" 时使用
|
| 49 |
}
|
| 50 |
|
| 51 |
// 解析消息内容,返回文本和图片URL列表
|
|
|
|
| 76 |
|
| 77 |
// 转换为上游消息格式,支持多模态
|
| 78 |
func (m *Message) ToUpstreamMessage(urlToFileID map[string]string) map[string]interface{} {
|
| 79 |
+
// tool 消息:包含 tool_call_id
|
| 80 |
+
if m.Role == "tool" {
|
| 81 |
+
msg := map[string]interface{}{
|
| 82 |
+
"role": m.Role,
|
| 83 |
+
"content": m.Content,
|
| 84 |
+
"tool_call_id": m.ToolCallID,
|
| 85 |
+
}
|
| 86 |
+
return msg
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
// assistant 消息带 tool_calls
|
| 90 |
+
if m.Role == "assistant" && len(m.ToolCalls) > 0 {
|
| 91 |
+
msg := map[string]interface{}{
|
| 92 |
+
"role": m.Role,
|
| 93 |
+
"content": m.Content,
|
| 94 |
+
}
|
| 95 |
+
var toolCalls []map[string]interface{}
|
| 96 |
+
for _, tc := range m.ToolCalls {
|
| 97 |
+
toolCalls = append(toolCalls, map[string]interface{}{
|
| 98 |
+
"id": tc.ID,
|
| 99 |
+
"type": tc.Type,
|
| 100 |
+
"function": map[string]interface{}{
|
| 101 |
+
"name": tc.Function.Name,
|
| 102 |
+
"arguments": tc.Function.Arguments,
|
| 103 |
+
},
|
| 104 |
+
})
|
| 105 |
+
}
|
| 106 |
+
msg["tool_calls"] = toolCalls
|
| 107 |
+
return msg
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
text, imageURLs := m.ParseContent()
|
| 111 |
|
| 112 |
// 无图片,返回纯文本
|
|
|
|
| 143 |
}
|
| 144 |
|
| 145 |
type ChatRequest struct {
|
| 146 |
+
Model string `json:"model"`
|
| 147 |
+
Messages []Message `json:"messages"`
|
| 148 |
+
Stream bool `json:"stream"`
|
| 149 |
+
Tools []Tool `json:"tools,omitempty"`
|
| 150 |
+
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
| 151 |
}
|
| 152 |
|
| 153 |
type ChatCompletionChunk struct {
|
|
|
|
| 158 |
Choices []Choice `json:"choices"`
|
| 159 |
}
|
| 160 |
|
|
|
|
|
|
|
| 161 |
type Choice struct {
|
| 162 |
Index int `json:"index"`
|
| 163 |
Delta Delta `json:"delta,omitempty"`
|
|
|
|
| 166 |
}
|
| 167 |
|
| 168 |
type Delta struct {
|
| 169 |
+
Content string `json:"content,omitempty"`
|
| 170 |
+
ReasoningContent string `json:"reasoning_content,omitempty"`
|
| 171 |
+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
| 172 |
}
|
| 173 |
|
| 174 |
type MessageResp struct {
|
| 175 |
+
Role string `json:"role"`
|
| 176 |
+
Content string `json:"content"`
|
| 177 |
+
ReasoningContent string `json:"reasoning_content,omitempty"`
|
| 178 |
+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
| 179 |
}
|
| 180 |
|
| 181 |
type ChatCompletionResponse struct {
|
internal/model/types_test.go
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package model
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"testing"
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
// ===== Tool 类型序列化/反序列化 =====
|
| 9 |
+
|
| 10 |
+
func TestToolJSON(t *testing.T) {
|
| 11 |
+
tool := Tool{
|
| 12 |
+
Type: "function",
|
| 13 |
+
Function: ToolFunction{
|
| 14 |
+
Name: "get_weather",
|
| 15 |
+
Description: "获取天气信息",
|
| 16 |
+
Parameters: map[string]interface{}{
|
| 17 |
+
"type": "object",
|
| 18 |
+
"properties": map[string]interface{}{
|
| 19 |
+
"location": map[string]interface{}{
|
| 20 |
+
"type": "string",
|
| 21 |
+
"description": "城市名称",
|
| 22 |
+
},
|
| 23 |
+
},
|
| 24 |
+
"required": []string{"location"},
|
| 25 |
+
},
|
| 26 |
+
},
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
data, err := json.Marshal(tool)
|
| 30 |
+
if err != nil {
|
| 31 |
+
t.Fatalf("marshal Tool: %v", err)
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
var decoded Tool
|
| 35 |
+
if err := json.Unmarshal(data, &decoded); err != nil {
|
| 36 |
+
t.Fatalf("unmarshal Tool: %v", err)
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
if decoded.Type != "function" {
|
| 40 |
+
t.Errorf("Type = %q, want %q", decoded.Type, "function")
|
| 41 |
+
}
|
| 42 |
+
if decoded.Function.Name != "get_weather" {
|
| 43 |
+
t.Errorf("Function.Name = %q, want %q", decoded.Function.Name, "get_weather")
|
| 44 |
+
}
|
| 45 |
+
if decoded.Function.Description != "获取天气信息" {
|
| 46 |
+
t.Errorf("Function.Description = %q, want %q", decoded.Function.Description, "获取天气信息")
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
func TestToolCallJSON(t *testing.T) {
|
| 51 |
+
tc := ToolCall{
|
| 52 |
+
ID: "call_abc123",
|
| 53 |
+
Type: "function",
|
| 54 |
+
Function: FunctionCall{
|
| 55 |
+
Name: "get_weather",
|
| 56 |
+
Arguments: `{"location":"北京"}`,
|
| 57 |
+
},
|
| 58 |
+
Index: 0,
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
data, err := json.Marshal(tc)
|
| 62 |
+
if err != nil {
|
| 63 |
+
t.Fatalf("marshal ToolCall: %v", err)
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
var decoded ToolCall
|
| 67 |
+
if err := json.Unmarshal(data, &decoded); err != nil {
|
| 68 |
+
t.Fatalf("unmarshal ToolCall: %v", err)
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
if decoded.ID != "call_abc123" {
|
| 72 |
+
t.Errorf("ID = %q, want %q", decoded.ID, "call_abc123")
|
| 73 |
+
}
|
| 74 |
+
if decoded.Function.Name != "get_weather" {
|
| 75 |
+
t.Errorf("Function.Name = %q, want %q", decoded.Function.Name, "get_weather")
|
| 76 |
+
}
|
| 77 |
+
if decoded.Function.Arguments != `{"location":"北京"}` {
|
| 78 |
+
t.Errorf("Function.Arguments = %q, want %q", decoded.Function.Arguments, `{"location":"北京"}`)
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
// ===== ChatRequest 带 Tools 序列化 =====
|
| 83 |
+
|
| 84 |
+
func TestChatRequestWithTools(t *testing.T) {
|
| 85 |
+
reqJSON := `{
|
| 86 |
+
"model": "GLM-4.7",
|
| 87 |
+
"messages": [{"role": "user", "content": "北京天气怎么样?"}],
|
| 88 |
+
"stream": true,
|
| 89 |
+
"tools": [{
|
| 90 |
+
"type": "function",
|
| 91 |
+
"function": {
|
| 92 |
+
"name": "get_weather",
|
| 93 |
+
"description": "获取天气",
|
| 94 |
+
"parameters": {"type": "object", "properties": {"location": {"type": "string"}}}
|
| 95 |
+
}
|
| 96 |
+
}],
|
| 97 |
+
"tool_choice": "auto"
|
| 98 |
+
}`
|
| 99 |
+
|
| 100 |
+
var req ChatRequest
|
| 101 |
+
if err := json.Unmarshal([]byte(reqJSON), &req); err != nil {
|
| 102 |
+
t.Fatalf("unmarshal ChatRequest: %v", err)
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
if req.Model != "GLM-4.7" {
|
| 106 |
+
t.Errorf("Model = %q, want %q", req.Model, "GLM-4.7")
|
| 107 |
+
}
|
| 108 |
+
if len(req.Tools) != 1 {
|
| 109 |
+
t.Fatalf("len(Tools) = %d, want 1", len(req.Tools))
|
| 110 |
+
}
|
| 111 |
+
if req.Tools[0].Function.Name != "get_weather" {
|
| 112 |
+
t.Errorf("Tools[0].Function.Name = %q, want %q", req.Tools[0].Function.Name, "get_weather")
|
| 113 |
+
}
|
| 114 |
+
if req.ToolChoice != "auto" {
|
| 115 |
+
t.Errorf("ToolChoice = %v, want %q", req.ToolChoice, "auto")
|
| 116 |
+
}
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
func TestChatRequestWithoutTools(t *testing.T) {
|
| 120 |
+
reqJSON := `{
|
| 121 |
+
"model": "GLM-4.6",
|
| 122 |
+
"messages": [{"role": "user", "content": "hello"}],
|
| 123 |
+
"stream": false
|
| 124 |
+
}`
|
| 125 |
+
|
| 126 |
+
var req ChatRequest
|
| 127 |
+
if err := json.Unmarshal([]byte(reqJSON), &req); err != nil {
|
| 128 |
+
t.Fatalf("unmarshal ChatRequest: %v", err)
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
if len(req.Tools) != 0 {
|
| 132 |
+
t.Errorf("len(Tools) = %d, want 0", len(req.Tools))
|
| 133 |
+
}
|
| 134 |
+
if req.ToolChoice != nil {
|
| 135 |
+
t.Errorf("ToolChoice = %v, want nil", req.ToolChoice)
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
func TestChatRequestToolChoiceObject(t *testing.T) {
|
| 140 |
+
reqJSON := `{
|
| 141 |
+
"model": "GLM-4.7",
|
| 142 |
+
"messages": [{"role": "user", "content": "test"}],
|
| 143 |
+
"stream": false,
|
| 144 |
+
"tools": [{"type": "function", "function": {"name": "fn1"}}],
|
| 145 |
+
"tool_choice": {"type": "function", "function": {"name": "fn1"}}
|
| 146 |
+
}`
|
| 147 |
+
|
| 148 |
+
var req ChatRequest
|
| 149 |
+
if err := json.Unmarshal([]byte(reqJSON), &req); err != nil {
|
| 150 |
+
t.Fatalf("unmarshal: %v", err)
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
tc, ok := req.ToolChoice.(map[string]interface{})
|
| 154 |
+
if !ok {
|
| 155 |
+
t.Fatalf("ToolChoice type = %T, want map[string]interface{}", req.ToolChoice)
|
| 156 |
+
}
|
| 157 |
+
if tc["type"] != "function" {
|
| 158 |
+
t.Errorf("ToolChoice.type = %v, want %q", tc["type"], "function")
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// ===== Message 带 ToolCallID / ToolCalls 序列化 =====
|
| 163 |
+
|
| 164 |
+
func TestMessageWithToolCallID(t *testing.T) {
|
| 165 |
+
msgJSON := `{
|
| 166 |
+
"role": "tool",
|
| 167 |
+
"content": "{\"temperature\": 25}",
|
| 168 |
+
"tool_call_id": "call_abc123"
|
| 169 |
+
}`
|
| 170 |
+
|
| 171 |
+
var msg Message
|
| 172 |
+
if err := json.Unmarshal([]byte(msgJSON), &msg); err != nil {
|
| 173 |
+
t.Fatalf("unmarshal: %v", err)
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
if msg.Role != "tool" {
|
| 177 |
+
t.Errorf("Role = %q, want %q", msg.Role, "tool")
|
| 178 |
+
}
|
| 179 |
+
if msg.ToolCallID != "call_abc123" {
|
| 180 |
+
t.Errorf("ToolCallID = %q, want %q", msg.ToolCallID, "call_abc123")
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
func TestMessageWithToolCalls(t *testing.T) {
|
| 185 |
+
msgJSON := `{
|
| 186 |
+
"role": "assistant",
|
| 187 |
+
"content": "",
|
| 188 |
+
"tool_calls": [{
|
| 189 |
+
"id": "call_xyz",
|
| 190 |
+
"type": "function",
|
| 191 |
+
"function": {"name": "get_weather", "arguments": "{\"location\":\"上海\"}"},
|
| 192 |
+
"index": 0
|
| 193 |
+
}]
|
| 194 |
+
}`
|
| 195 |
+
|
| 196 |
+
var msg Message
|
| 197 |
+
if err := json.Unmarshal([]byte(msgJSON), &msg); err != nil {
|
| 198 |
+
t.Fatalf("unmarshal: %v", err)
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
if len(msg.ToolCalls) != 1 {
|
| 202 |
+
t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
|
| 203 |
+
}
|
| 204 |
+
if msg.ToolCalls[0].Function.Name != "get_weather" {
|
| 205 |
+
t.Errorf("ToolCalls[0].Function.Name = %q, want %q", msg.ToolCalls[0].Function.Name, "get_weather")
|
| 206 |
+
}
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
// ===== ToUpstreamMessage =====
|
| 210 |
+
|
| 211 |
+
func TestToUpstreamMessage_ToolRole(t *testing.T) {
|
| 212 |
+
msg := Message{
|
| 213 |
+
Role: "tool",
|
| 214 |
+
Content: `{"temperature": 25}`,
|
| 215 |
+
ToolCallID: "call_abc",
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
result := msg.ToUpstreamMessage(nil)
|
| 219 |
+
|
| 220 |
+
if result["role"] != "tool" {
|
| 221 |
+
t.Errorf("role = %v, want %q", result["role"], "tool")
|
| 222 |
+
}
|
| 223 |
+
if result["tool_call_id"] != "call_abc" {
|
| 224 |
+
t.Errorf("tool_call_id = %v, want %q", result["tool_call_id"], "call_abc")
|
| 225 |
+
}
|
| 226 |
+
if result["content"] != `{"temperature": 25}` {
|
| 227 |
+
t.Errorf("content = %v, want %q", result["content"], `{"temperature": 25}`)
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
func TestToUpstreamMessage_AssistantWithToolCalls(t *testing.T) {
|
| 232 |
+
msg := Message{
|
| 233 |
+
Role: "assistant",
|
| 234 |
+
Content: "",
|
| 235 |
+
ToolCalls: []ToolCall{
|
| 236 |
+
{
|
| 237 |
+
ID: "call_1",
|
| 238 |
+
Type: "function",
|
| 239 |
+
Function: FunctionCall{
|
| 240 |
+
Name: "get_weather",
|
| 241 |
+
Arguments: `{"location":"北京"}`,
|
| 242 |
+
},
|
| 243 |
+
},
|
| 244 |
+
{
|
| 245 |
+
ID: "call_2",
|
| 246 |
+
Type: "function",
|
| 247 |
+
Function: FunctionCall{
|
| 248 |
+
Name: "get_time",
|
| 249 |
+
Arguments: `{"timezone":"Asia/Shanghai"}`,
|
| 250 |
+
},
|
| 251 |
+
},
|
| 252 |
+
},
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
result := msg.ToUpstreamMessage(nil)
|
| 256 |
+
|
| 257 |
+
if result["role"] != "assistant" {
|
| 258 |
+
t.Errorf("role = %v, want %q", result["role"], "assistant")
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
toolCalls, ok := result["tool_calls"].([]map[string]interface{})
|
| 262 |
+
if !ok {
|
| 263 |
+
t.Fatalf("tool_calls type = %T, want []map[string]interface{}", result["tool_calls"])
|
| 264 |
+
}
|
| 265 |
+
if len(toolCalls) != 2 {
|
| 266 |
+
t.Fatalf("len(tool_calls) = %d, want 2", len(toolCalls))
|
| 267 |
+
}
|
| 268 |
+
if toolCalls[0]["id"] != "call_1" {
|
| 269 |
+
t.Errorf("tool_calls[0].id = %v, want %q", toolCalls[0]["id"], "call_1")
|
| 270 |
+
}
|
| 271 |
+
fn, ok := toolCalls[0]["function"].(map[string]interface{})
|
| 272 |
+
if !ok {
|
| 273 |
+
t.Fatalf("function type = %T", toolCalls[0]["function"])
|
| 274 |
+
}
|
| 275 |
+
if fn["name"] != "get_weather" {
|
| 276 |
+
t.Errorf("function.name = %v, want %q", fn["name"], "get_weather")
|
| 277 |
+
}
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
func TestToUpstreamMessage_PlainUser(t *testing.T) {
|
| 281 |
+
msg := Message{
|
| 282 |
+
Role: "user",
|
| 283 |
+
Content: "hello",
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
result := msg.ToUpstreamMessage(nil)
|
| 287 |
+
if result["role"] != "user" {
|
| 288 |
+
t.Errorf("role = %v, want %q", result["role"], "user")
|
| 289 |
+
}
|
| 290 |
+
if result["content"] != "hello" {
|
| 291 |
+
t.Errorf("content = %v, want %q", result["content"], "hello")
|
| 292 |
+
}
|
| 293 |
+
if _, exists := result["tool_call_id"]; exists {
|
| 294 |
+
t.Error("tool_call_id should not be present for user messages")
|
| 295 |
+
}
|
| 296 |
+
if _, exists := result["tool_calls"]; exists {
|
| 297 |
+
t.Error("tool_calls should not be present for user messages")
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
func TestToUpstreamMessage_AssistantWithoutToolCalls(t *testing.T) {
|
| 302 |
+
msg := Message{
|
| 303 |
+
Role: "assistant",
|
| 304 |
+
Content: "你好!",
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
result := msg.ToUpstreamMessage(nil)
|
| 308 |
+
if result["role"] != "assistant" {
|
| 309 |
+
t.Errorf("role = %v, want %q", result["role"], "assistant")
|
| 310 |
+
}
|
| 311 |
+
if result["content"] != "你好!" {
|
| 312 |
+
t.Errorf("content = %v, want %q", result["content"], "你好!")
|
| 313 |
+
}
|
| 314 |
+
if _, exists := result["tool_calls"]; exists {
|
| 315 |
+
t.Error("tool_calls should not be present when empty")
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
// ===== Delta / MessageResp 带 ToolCalls =====
|
| 320 |
+
|
| 321 |
+
func TestDeltaWithToolCalls(t *testing.T) {
|
| 322 |
+
delta := Delta{
|
| 323 |
+
ToolCalls: []ToolCall{{
|
| 324 |
+
ID: "call_1",
|
| 325 |
+
Type: "function",
|
| 326 |
+
Index: 0,
|
| 327 |
+
Function: FunctionCall{
|
| 328 |
+
Name: "get_weather",
|
| 329 |
+
Arguments: `{"location":"北京"}`,
|
| 330 |
+
},
|
| 331 |
+
}},
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
data, err := json.Marshal(delta)
|
| 335 |
+
if err != nil {
|
| 336 |
+
t.Fatalf("marshal: %v", err)
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
var decoded Delta
|
| 340 |
+
if err := json.Unmarshal(data, &decoded); err != nil {
|
| 341 |
+
t.Fatalf("unmarshal: %v", err)
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
if len(decoded.ToolCalls) != 1 {
|
| 345 |
+
t.Fatalf("len(ToolCalls) = %d, want 1", len(decoded.ToolCalls))
|
| 346 |
+
}
|
| 347 |
+
if decoded.ToolCalls[0].Function.Name != "get_weather" {
|
| 348 |
+
t.Errorf("Name = %q, want %q", decoded.ToolCalls[0].Function.Name, "get_weather")
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
func TestDeltaOmitsEmptyToolCalls(t *testing.T) {
|
| 353 |
+
delta := Delta{Content: "hello"}
|
| 354 |
+
|
| 355 |
+
data, err := json.Marshal(delta)
|
| 356 |
+
if err != nil {
|
| 357 |
+
t.Fatalf("marshal: %v", err)
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
// tool_calls 为空时应被 omitempty 省略
|
| 361 |
+
var raw map[string]interface{}
|
| 362 |
+
json.Unmarshal(data, &raw)
|
| 363 |
+
if _, exists := raw["tool_calls"]; exists {
|
| 364 |
+
t.Error("tool_calls should be omitted when empty")
|
| 365 |
+
}
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
func TestMessageRespWithToolCalls(t *testing.T) {
|
| 369 |
+
resp := MessageResp{
|
| 370 |
+
Role: "assistant",
|
| 371 |
+
Content: "",
|
| 372 |
+
ToolCalls: []ToolCall{{
|
| 373 |
+
ID: "call_1",
|
| 374 |
+
Type: "function",
|
| 375 |
+
Index: 0,
|
| 376 |
+
Function: FunctionCall{
|
| 377 |
+
Name: "search",
|
| 378 |
+
Arguments: `{"query":"test"}`,
|
| 379 |
+
},
|
| 380 |
+
}},
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
data, err := json.Marshal(resp)
|
| 384 |
+
if err != nil {
|
| 385 |
+
t.Fatalf("marshal: %v", err)
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
var decoded MessageResp
|
| 389 |
+
if err := json.Unmarshal(data, &decoded); err != nil {
|
| 390 |
+
t.Fatalf("unmarshal: %v", err)
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
if len(decoded.ToolCalls) != 1 {
|
| 394 |
+
t.Fatalf("len(ToolCalls) = %d, want 1", len(decoded.ToolCalls))
|
| 395 |
+
}
|
| 396 |
+
if decoded.ToolCalls[0].Function.Arguments != `{"query":"test"}` {
|
| 397 |
+
t.Errorf("Arguments = %q", decoded.ToolCalls[0].Function.Arguments)
|
| 398 |
+
}
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
func TestMessageRespOmitsEmptyToolCalls(t *testing.T) {
|
| 402 |
+
resp := MessageResp{
|
| 403 |
+
Role: "assistant",
|
| 404 |
+
Content: "hello world",
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
data, _ := json.Marshal(resp)
|
| 408 |
+
var raw map[string]interface{}
|
| 409 |
+
json.Unmarshal(data, &raw)
|
| 410 |
+
if _, exists := raw["tool_calls"]; exists {
|
| 411 |
+
t.Error("tool_calls should be omitted when empty")
|
| 412 |
+
}
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
// ===== ChatCompletionChunk 带 tool_calls finish_reason =====
|
| 416 |
+
|
| 417 |
+
func TestChunkWithToolCallsFinishReason(t *testing.T) {
|
| 418 |
+
reason := "tool_calls"
|
| 419 |
+
chunk := ChatCompletionChunk{
|
| 420 |
+
ID: "chatcmpl-test",
|
| 421 |
+
Object: "chat.completion.chunk",
|
| 422 |
+
Created: 1000,
|
| 423 |
+
Model: "glm-4.7",
|
| 424 |
+
Choices: []Choice{{
|
| 425 |
+
Index: 0,
|
| 426 |
+
Delta: Delta{},
|
| 427 |
+
FinishReason: &reason,
|
| 428 |
+
}},
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
data, err := json.Marshal(chunk)
|
| 432 |
+
if err != nil {
|
| 433 |
+
t.Fatalf("marshal: %v", err)
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
var decoded ChatCompletionChunk
|
| 437 |
+
if err := json.Unmarshal(data, &decoded); err != nil {
|
| 438 |
+
t.Fatalf("unmarshal: %v", err)
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
if decoded.Choices[0].FinishReason == nil {
|
| 442 |
+
t.Fatal("FinishReason is nil")
|
| 443 |
+
}
|
| 444 |
+
if *decoded.Choices[0].FinishReason != "tool_calls" {
|
| 445 |
+
t.Errorf("FinishReason = %q, want %q", *decoded.Choices[0].FinishReason, "tool_calls")
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
// ===== ChatCompletionResponse 带 tool_calls =====
|
| 450 |
+
|
| 451 |
+
func TestCompletionResponseWithToolCalls(t *testing.T) {
|
| 452 |
+
reason := "tool_calls"
|
| 453 |
+
resp := ChatCompletionResponse{
|
| 454 |
+
ID: "chatcmpl-test",
|
| 455 |
+
Object: "chat.completion",
|
| 456 |
+
Created: 1000,
|
| 457 |
+
Model: "glm-4.7",
|
| 458 |
+
Choices: []Choice{{
|
| 459 |
+
Index: 0,
|
| 460 |
+
Message: &MessageResp{
|
| 461 |
+
Role: "assistant",
|
| 462 |
+
Content: "",
|
| 463 |
+
ToolCalls: []ToolCall{{
|
| 464 |
+
ID: "call_1",
|
| 465 |
+
Type: "function",
|
| 466 |
+
Index: 0,
|
| 467 |
+
Function: FunctionCall{
|
| 468 |
+
Name: "get_weather",
|
| 469 |
+
Arguments: `{"location":"北京"}`,
|
| 470 |
+
},
|
| 471 |
+
}},
|
| 472 |
+
},
|
| 473 |
+
FinishReason: &reason,
|
| 474 |
+
}},
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
data, err := json.Marshal(resp)
|
| 478 |
+
if err != nil {
|
| 479 |
+
t.Fatalf("marshal: %v", err)
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
var decoded ChatCompletionResponse
|
| 483 |
+
if err := json.Unmarshal(data, &decoded); err != nil {
|
| 484 |
+
t.Fatalf("unmarshal: %v", err)
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
if len(decoded.Choices) != 1 {
|
| 488 |
+
t.Fatalf("len(Choices) = %d", len(decoded.Choices))
|
| 489 |
+
}
|
| 490 |
+
msg := decoded.Choices[0].Message
|
| 491 |
+
if msg == nil {
|
| 492 |
+
t.Fatal("Message is nil")
|
| 493 |
+
}
|
| 494 |
+
if len(msg.ToolCalls) != 1 {
|
| 495 |
+
t.Fatalf("len(ToolCalls) = %d, want 1", len(msg.ToolCalls))
|
| 496 |
+
}
|
| 497 |
+
if msg.ToolCalls[0].Function.Name != "get_weather" {
|
| 498 |
+
t.Errorf("Function.Name = %q", msg.ToolCalls[0].Function.Name)
|
| 499 |
+
}
|
| 500 |
+
if *decoded.Choices[0].FinishReason != "tool_calls" {
|
| 501 |
+
t.Errorf("FinishReason = %q", *decoded.Choices[0].FinishReason)
|
| 502 |
+
}
|
| 503 |
+
}
|
internal/tools/builtin.go
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package tools
|
| 2 |
+
|
| 3 |
+
import "zai-proxy/internal/model"
|
| 4 |
+
|
| 5 |
+
// GetBuiltinTools 返回所有内置工具定义
|
| 6 |
+
func GetBuiltinTools() []model.Tool {
|
| 7 |
+
return []model.Tool{
|
| 8 |
+
// 多功能助手
|
| 9 |
+
{
|
| 10 |
+
Type: "function",
|
| 11 |
+
Function: model.ToolFunction{
|
| 12 |
+
Name: "get_current_time",
|
| 13 |
+
Description: "获取当前时间,支持不同时区和格式",
|
| 14 |
+
Parameters: map[string]interface{}{
|
| 15 |
+
"type": "object",
|
| 16 |
+
"properties": map[string]interface{}{
|
| 17 |
+
"timezone": map[string]interface{}{
|
| 18 |
+
"type": "string",
|
| 19 |
+
"description": "时区名称(如 Asia/Shanghai, America/New_York)",
|
| 20 |
+
},
|
| 21 |
+
"format": map[string]interface{}{
|
| 22 |
+
"type": "string",
|
| 23 |
+
"description": "时间格式(如 2006-01-02 15:04:05)",
|
| 24 |
+
},
|
| 25 |
+
},
|
| 26 |
+
"required": []string{},
|
| 27 |
+
},
|
| 28 |
+
},
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
Type: "function",
|
| 32 |
+
Function: model.ToolFunction{
|
| 33 |
+
Name: "calculate",
|
| 34 |
+
Description: "执行数学计算,支持基本运算和高级数学函数",
|
| 35 |
+
Parameters: map[string]interface{}{
|
| 36 |
+
"type": "object",
|
| 37 |
+
"properties": map[string]interface{}{
|
| 38 |
+
"expression": map[string]interface{}{
|
| 39 |
+
"type": "string",
|
| 40 |
+
"description": "数学表达式(如 2+3*4, sqrt(16), sin(pi/2))",
|
| 41 |
+
},
|
| 42 |
+
},
|
| 43 |
+
"required": []string{"expression"},
|
| 44 |
+
},
|
| 45 |
+
},
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
Type: "function",
|
| 49 |
+
Function: model.ToolFunction{
|
| 50 |
+
Name: "search_web",
|
| 51 |
+
Description: "搜索网络获取实时信息",
|
| 52 |
+
Parameters: map[string]interface{}{
|
| 53 |
+
"type": "object",
|
| 54 |
+
"properties": map[string]interface{}{
|
| 55 |
+
"query": map[string]interface{}{
|
| 56 |
+
"type": "string",
|
| 57 |
+
"description": "搜索关键词",
|
| 58 |
+
},
|
| 59 |
+
"num_results": map[string]interface{}{
|
| 60 |
+
"type": "integer",
|
| 61 |
+
"description": "返回结果数量,默认5",
|
| 62 |
+
},
|
| 63 |
+
},
|
| 64 |
+
"required": []string{"query"},
|
| 65 |
+
},
|
| 66 |
+
},
|
| 67 |
+
},
|
| 68 |
+
// 数据库查询
|
| 69 |
+
{
|
| 70 |
+
Type: "function",
|
| 71 |
+
Function: model.ToolFunction{
|
| 72 |
+
Name: "query_database",
|
| 73 |
+
Description: "执行SQL查询获取数据",
|
| 74 |
+
Parameters: map[string]interface{}{
|
| 75 |
+
"type": "object",
|
| 76 |
+
"properties": map[string]interface{}{
|
| 77 |
+
"sql": map[string]interface{}{
|
| 78 |
+
"type": "string",
|
| 79 |
+
"description": "SQL查询语句",
|
| 80 |
+
},
|
| 81 |
+
"database": map[string]interface{}{
|
| 82 |
+
"type": "string",
|
| 83 |
+
"description": "目标数据库名称",
|
| 84 |
+
},
|
| 85 |
+
},
|
| 86 |
+
"required": []string{"sql"},
|
| 87 |
+
},
|
| 88 |
+
},
|
| 89 |
+
},
|
| 90 |
+
// 文件操作
|
| 91 |
+
{
|
| 92 |
+
Type: "function",
|
| 93 |
+
Function: model.ToolFunction{
|
| 94 |
+
Name: "file_operations",
|
| 95 |
+
Description: "执行文件操作,支持读取、写入和列出文件",
|
| 96 |
+
Parameters: map[string]interface{}{
|
| 97 |
+
"type": "object",
|
| 98 |
+
"properties": map[string]interface{}{
|
| 99 |
+
"operation": map[string]interface{}{
|
| 100 |
+
"type": "string",
|
| 101 |
+
"enum": []string{"read", "write", "list"},
|
| 102 |
+
"description": "操作类型:read(读取)、write(写入)、list(列出)",
|
| 103 |
+
},
|
| 104 |
+
"path": map[string]interface{}{
|
| 105 |
+
"type": "string",
|
| 106 |
+
"description": "文件或目录路径",
|
| 107 |
+
},
|
| 108 |
+
"content": map[string]interface{}{
|
| 109 |
+
"type": "string",
|
| 110 |
+
"description": "写入内容(仅 write 操作需要)",
|
| 111 |
+
},
|
| 112 |
+
},
|
| 113 |
+
"required": []string{"operation", "path"},
|
| 114 |
+
},
|
| 115 |
+
},
|
| 116 |
+
},
|
| 117 |
+
// API集成
|
| 118 |
+
{
|
| 119 |
+
Type: "function",
|
| 120 |
+
Function: model.ToolFunction{
|
| 121 |
+
Name: "call_external_api",
|
| 122 |
+
Description: "调用外部API接口",
|
| 123 |
+
Parameters: map[string]interface{}{
|
| 124 |
+
"type": "object",
|
| 125 |
+
"properties": map[string]interface{}{
|
| 126 |
+
"url": map[string]interface{}{
|
| 127 |
+
"type": "string",
|
| 128 |
+
"description": "API请求URL",
|
| 129 |
+
},
|
| 130 |
+
"method": map[string]interface{}{
|
| 131 |
+
"type": "string",
|
| 132 |
+
"enum": []string{"GET", "POST", "PUT", "DELETE"},
|
| 133 |
+
"description": "HTTP请求方法",
|
| 134 |
+
},
|
| 135 |
+
"headers": map[string]interface{}{
|
| 136 |
+
"type": "object",
|
| 137 |
+
"description": "请求头",
|
| 138 |
+
},
|
| 139 |
+
"body": map[string]interface{}{
|
| 140 |
+
"type": "string",
|
| 141 |
+
"description": "请求体(JSON字符串)",
|
| 142 |
+
},
|
| 143 |
+
},
|
| 144 |
+
"required": []string{"url", "method"},
|
| 145 |
+
},
|
| 146 |
+
},
|
| 147 |
+
},
|
| 148 |
+
}
|
| 149 |
+
}
|
internal/tools/builtin_test.go
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package tools
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"testing"
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
func TestGetBuiltinTools_Count(t *testing.T) {
|
| 8 |
+
tools := GetBuiltinTools()
|
| 9 |
+
if len(tools) != 6 {
|
| 10 |
+
t.Errorf("len(GetBuiltinTools()) = %d, want 6", len(tools))
|
| 11 |
+
}
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
func TestGetBuiltinTools_AllFunction(t *testing.T) {
|
| 15 |
+
for _, tool := range GetBuiltinTools() {
|
| 16 |
+
if tool.Type != "function" {
|
| 17 |
+
t.Errorf("tool %q Type = %q, want %q", tool.Function.Name, tool.Type, "function")
|
| 18 |
+
}
|
| 19 |
+
}
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
func TestGetBuiltinTools_Names(t *testing.T) {
|
| 23 |
+
expected := map[string]bool{
|
| 24 |
+
"get_current_time": true,
|
| 25 |
+
"calculate": true,
|
| 26 |
+
"search_web": true,
|
| 27 |
+
"query_database": true,
|
| 28 |
+
"file_operations": true,
|
| 29 |
+
"call_external_api": true,
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
tools := GetBuiltinTools()
|
| 33 |
+
for _, tool := range tools {
|
| 34 |
+
name := tool.Function.Name
|
| 35 |
+
if !expected[name] {
|
| 36 |
+
t.Errorf("unexpected tool name: %q", name)
|
| 37 |
+
}
|
| 38 |
+
delete(expected, name)
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
for name := range expected {
|
| 42 |
+
t.Errorf("missing tool: %q", name)
|
| 43 |
+
}
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
func TestGetBuiltinTools_HaveDescriptions(t *testing.T) {
|
| 47 |
+
for _, tool := range GetBuiltinTools() {
|
| 48 |
+
if tool.Function.Description == "" {
|
| 49 |
+
t.Errorf("tool %q has empty description", tool.Function.Name)
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
func TestGetBuiltinTools_HaveParameters(t *testing.T) {
|
| 55 |
+
for _, tool := range GetBuiltinTools() {
|
| 56 |
+
if tool.Function.Parameters == nil {
|
| 57 |
+
t.Errorf("tool %q has nil parameters", tool.Function.Name)
|
| 58 |
+
}
|
| 59 |
+
params, ok := tool.Function.Parameters.(map[string]interface{})
|
| 60 |
+
if !ok {
|
| 61 |
+
t.Errorf("tool %q parameters is not a map", tool.Function.Name)
|
| 62 |
+
continue
|
| 63 |
+
}
|
| 64 |
+
if params["type"] != "object" {
|
| 65 |
+
t.Errorf("tool %q parameters.type = %v, want %q", tool.Function.Name, params["type"], "object")
|
| 66 |
+
}
|
| 67 |
+
if _, ok := params["properties"]; !ok {
|
| 68 |
+
t.Errorf("tool %q parameters missing 'properties'", tool.Function.Name)
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
func TestGetBuiltinTools_NoDuplicateNames(t *testing.T) {
|
| 74 |
+
seen := make(map[string]bool)
|
| 75 |
+
for _, tool := range GetBuiltinTools() {
|
| 76 |
+
if seen[tool.Function.Name] {
|
| 77 |
+
t.Errorf("duplicate tool name: %q", tool.Function.Name)
|
| 78 |
+
}
|
| 79 |
+
seen[tool.Function.Name] = true
|
| 80 |
+
}
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
func TestGetBuiltinTools_ReturnsNewSlice(t *testing.T) {
|
| 84 |
+
a := GetBuiltinTools()
|
| 85 |
+
b := GetBuiltinTools()
|
| 86 |
+
if &a[0] == &b[0] {
|
| 87 |
+
t.Error("GetBuiltinTools should return a new slice each call")
|
| 88 |
+
}
|
| 89 |
+
}
|
internal/upstream/client.go
CHANGED
|
@@ -12,6 +12,7 @@ import (
|
|
| 12 |
|
| 13 |
"zai-proxy/internal/auth"
|
| 14 |
"zai-proxy/internal/model"
|
|
|
|
| 15 |
"zai-proxy/internal/version"
|
| 16 |
)
|
| 17 |
|
|
@@ -34,7 +35,7 @@ func ExtractAllImageURLs(messages []model.Message) []string {
|
|
| 34 |
return allImageURLs
|
| 35 |
}
|
| 36 |
|
| 37 |
-
func MakeUpstreamRequest(token string, messages []model.Message, modelName string) (*http.Response, string, error) {
|
| 38 |
payload, err := auth.DecodeJWTPayload(token)
|
| 39 |
if err != nil || payload == nil {
|
| 40 |
return nil, "", fmt.Errorf("invalid token")
|
|
@@ -119,6 +120,26 @@ func MakeUpstreamRequest(token string, messages []model.Message, modelName strin
|
|
| 119 |
body["mcp_servers"] = mcpServers
|
| 120 |
}
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
if len(filesData) > 0 {
|
| 123 |
body["files"] = filesData
|
| 124 |
body["current_user_message_id"] = userMsgID
|
|
|
|
| 12 |
|
| 13 |
"zai-proxy/internal/auth"
|
| 14 |
"zai-proxy/internal/model"
|
| 15 |
+
builtintools "zai-proxy/internal/tools"
|
| 16 |
"zai-proxy/internal/version"
|
| 17 |
)
|
| 18 |
|
|
|
|
| 35 |
return allImageURLs
|
| 36 |
}
|
| 37 |
|
| 38 |
+
func MakeUpstreamRequest(token string, messages []model.Message, modelName string, tools []model.Tool, toolChoice interface{}) (*http.Response, string, error) {
|
| 39 |
payload, err := auth.DecodeJWTPayload(token)
|
| 40 |
if err != nil || payload == nil {
|
| 41 |
return nil, "", fmt.Errorf("invalid token")
|
|
|
|
| 120 |
body["mcp_servers"] = mcpServers
|
| 121 |
}
|
| 122 |
|
| 123 |
+
// 当使用 -tools 模型时,自动注入内置工具(客户端自带工具优先)
|
| 124 |
+
if model.IsToolsModel(modelName) {
|
| 125 |
+
clientToolNames := make(map[string]bool)
|
| 126 |
+
for _, t := range tools {
|
| 127 |
+
clientToolNames[t.Function.Name] = true
|
| 128 |
+
}
|
| 129 |
+
for _, bt := range builtintools.GetBuiltinTools() {
|
| 130 |
+
if !clientToolNames[bt.Function.Name] {
|
| 131 |
+
tools = append(tools, bt)
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
if len(tools) > 0 {
|
| 137 |
+
body["tools"] = tools
|
| 138 |
+
if toolChoice != nil {
|
| 139 |
+
body["tool_choice"] = toolChoice
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
if len(filesData) > 0 {
|
| 144 |
body["files"] = filesData
|
| 145 |
body["current_user_message_id"] = userMsgID
|
scripts/test_tool_call.sh
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# 测试 tool/function calling 功能
|
| 3 |
+
# 用法: ./scripts/test_tool_call.sh [TOKEN] [BASE_URL]
|
| 4 |
+
#
|
| 5 |
+
# TOKEN 可以是你的 z.ai token 或 "free"(匿名)
|
| 6 |
+
# BASE_URL 默认 http://localhost:8000
|
| 7 |
+
|
| 8 |
+
TOKEN="${1:-free}"
|
| 9 |
+
BASE_URL="${2:-http://localhost:8000}"
|
| 10 |
+
|
| 11 |
+
echo "=== 测试 Tool/Function Calling ==="
|
| 12 |
+
echo "BASE_URL: $BASE_URL"
|
| 13 |
+
echo "TOKEN: ${TOKEN:0:10}..."
|
| 14 |
+
echo ""
|
| 15 |
+
|
| 16 |
+
# ===== 测试 1: 带 tools 的流式请求 =====
|
| 17 |
+
echo "--- 测试 1: 流式 tool calling ---"
|
| 18 |
+
curl -sS "${BASE_URL}/v1/chat/completions" \
|
| 19 |
+
-H "Authorization: Bearer ${TOKEN}" \
|
| 20 |
+
-H "Content-Type: application/json" \
|
| 21 |
+
-d '{
|
| 22 |
+
"model": "GLM-4.7",
|
| 23 |
+
"stream": true,
|
| 24 |
+
"messages": [
|
| 25 |
+
{"role": "user", "content": "北京今天天气怎么样?请调用 get_weather 函数查询。"}
|
| 26 |
+
],
|
| 27 |
+
"tools": [{
|
| 28 |
+
"type": "function",
|
| 29 |
+
"function": {
|
| 30 |
+
"name": "get_weather",
|
| 31 |
+
"description": "获取指定城市的当前天气信息",
|
| 32 |
+
"parameters": {
|
| 33 |
+
"type": "object",
|
| 34 |
+
"properties": {
|
| 35 |
+
"location": {
|
| 36 |
+
"type": "string",
|
| 37 |
+
"description": "城市名称,如:北京"
|
| 38 |
+
}
|
| 39 |
+
},
|
| 40 |
+
"required": ["location"]
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
}],
|
| 44 |
+
"tool_choice": "auto"
|
| 45 |
+
}' 2>&1
|
| 46 |
+
echo ""
|
| 47 |
+
echo ""
|
| 48 |
+
|
| 49 |
+
# ===== 测试 2: 带 tools 的非流式请求 =====
|
| 50 |
+
echo "--- 测试 2: 非流式 tool calling ---"
|
| 51 |
+
curl -sS "${BASE_URL}/v1/chat/completions" \
|
| 52 |
+
-H "Authorization: Bearer ${TOKEN}" \
|
| 53 |
+
-H "Content-Type: application/json" \
|
| 54 |
+
-d '{
|
| 55 |
+
"model": "GLM-4.7",
|
| 56 |
+
"stream": false,
|
| 57 |
+
"messages": [
|
| 58 |
+
{"role": "user", "content": "帮我查一下上海的天气,用 get_weather 工具。"}
|
| 59 |
+
],
|
| 60 |
+
"tools": [{
|
| 61 |
+
"type": "function",
|
| 62 |
+
"function": {
|
| 63 |
+
"name": "get_weather",
|
| 64 |
+
"description": "获取指定城市的当前天气信息",
|
| 65 |
+
"parameters": {
|
| 66 |
+
"type": "object",
|
| 67 |
+
"properties": {
|
| 68 |
+
"location": {
|
| 69 |
+
"type": "string",
|
| 70 |
+
"description": "城市名称"
|
| 71 |
+
}
|
| 72 |
+
},
|
| 73 |
+
"required": ["location"]
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
}],
|
| 77 |
+
"tool_choice": "auto"
|
| 78 |
+
}' 2>&1 | python3 -m json.tool 2>/dev/null || cat
|
| 79 |
+
echo ""
|
| 80 |
+
echo ""
|
| 81 |
+
|
| 82 |
+
# ===== 测试 3: 多工具 =====
|
| 83 |
+
echo "--- 测试 3: 多工具非流式 ---"
|
| 84 |
+
curl -sS "${BASE_URL}/v1/chat/completions" \
|
| 85 |
+
-H "Authorization: Bearer ${TOKEN}" \
|
| 86 |
+
-H "Content-Type: application/json" \
|
| 87 |
+
-d '{
|
| 88 |
+
"model": "GLM-4.7",
|
| 89 |
+
"stream": false,
|
| 90 |
+
"messages": [
|
| 91 |
+
{"role": "user", "content": "北京天气怎么样?现在几点了?请分别调用对应的工具。"}
|
| 92 |
+
],
|
| 93 |
+
"tools": [
|
| 94 |
+
{
|
| 95 |
+
"type": "function",
|
| 96 |
+
"function": {
|
| 97 |
+
"name": "get_weather",
|
| 98 |
+
"description": "获取天气",
|
| 99 |
+
"parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
|
| 100 |
+
}
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"type": "function",
|
| 104 |
+
"function": {
|
| 105 |
+
"name": "get_current_time",
|
| 106 |
+
"description": "获取当前时间",
|
| 107 |
+
"parameters": {"type": "object", "properties": {"timezone": {"type": "string"}}, "required": ["timezone"]}
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
],
|
| 111 |
+
"tool_choice": "auto"
|
| 112 |
+
}' 2>&1 | python3 -m json.tool 2>/dev/null || cat
|
| 113 |
+
echo ""
|
| 114 |
+
echo ""
|
| 115 |
+
|
| 116 |
+
# ===== 测试 4: 完整多轮对话(tool result 回传)=====
|
| 117 |
+
echo "--- 测试 4: 多轮对话 (tool result 回传) ---"
|
| 118 |
+
curl -sS "${BASE_URL}/v1/chat/completions" \
|
| 119 |
+
-H "Authorization: Bearer ${TOKEN}" \
|
| 120 |
+
-H "Content-Type: application/json" \
|
| 121 |
+
-d '{
|
| 122 |
+
"model": "GLM-4.7",
|
| 123 |
+
"stream": false,
|
| 124 |
+
"messages": [
|
| 125 |
+
{"role": "user", "content": "北京天气怎么样?"},
|
| 126 |
+
{
|
| 127 |
+
"role": "assistant",
|
| 128 |
+
"content": "",
|
| 129 |
+
"tool_calls": [{
|
| 130 |
+
"id": "call_abc123",
|
| 131 |
+
"type": "function",
|
| 132 |
+
"function": {"name": "get_weather", "arguments": "{\"location\":\"北京\"}"}
|
| 133 |
+
}]
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"role": "tool",
|
| 137 |
+
"tool_call_id": "call_abc123",
|
| 138 |
+
"content": "{\"temperature\": 25, \"condition\": \"晴\", \"humidity\": 40}"
|
| 139 |
+
}
|
| 140 |
+
],
|
| 141 |
+
"tools": [{
|
| 142 |
+
"type": "function",
|
| 143 |
+
"function": {
|
| 144 |
+
"name": "get_weather",
|
| 145 |
+
"description": "获取天气",
|
| 146 |
+
"parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
|
| 147 |
+
}
|
| 148 |
+
}]
|
| 149 |
+
}' 2>&1 | python3 -m json.tool 2>/dev/null || cat
|
| 150 |
+
echo ""
|
| 151 |
+
echo ""
|
| 152 |
+
|
| 153 |
+
# ===== 测试 5: 不带 tools 的普通请求(回归测试)=====
|
| 154 |
+
echo "--- 测试 5: 不带 tools 的普通请求(回归)---"
|
| 155 |
+
curl -sS "${BASE_URL}/v1/chat/completions" \
|
| 156 |
+
-H "Authorization: Bearer ${TOKEN}" \
|
| 157 |
+
-H "Content-Type: application/json" \
|
| 158 |
+
-d '{
|
| 159 |
+
"model": "GLM-4.7",
|
| 160 |
+
"stream": false,
|
| 161 |
+
"messages": [
|
| 162 |
+
{"role": "user", "content": "你好,1+1等于几?"}
|
| 163 |
+
]
|
| 164 |
+
}' 2>&1 | python3 -m json.tool 2>/dev/null || cat
|
| 165 |
+
echo ""
|
| 166 |
+
|
| 167 |
+
echo "=== 测试完成 ==="
|
| 168 |
+
echo ""
|
| 169 |
+
echo "检查要点:"
|
| 170 |
+
echo " 1. 测试 1/2: 查看响应中是否有 tool_calls 字段和 finish_reason=tool_calls"
|
| 171 |
+
echo " 2. 测试 3: 是否返回多个 tool_calls"
|
| 172 |
+
echo " 3. 测试 4: 模型是否基于 tool result 生成了自然语言回复"
|
| 173 |
+
echo " 4. 测试 5: 不带 tools 时是否正常返回文本(无 tool_calls 字段)"
|
| 174 |
+
echo " 5. 查看服务端日志中的 [ToolCall] 行,确认上游返回的原始格式"
|