NtGdi commited on
Commit
553c00b
·
0 Parent(s):

first commit

Browse files
Files changed (15) hide show
  1. .env.example +2 -0
  2. .gitignore +3 -0
  3. Dockerfile +19 -0
  4. README.md +116 -0
  5. go.mod +10 -0
  6. go.sum +6 -0
  7. internal/chat.go +584 -0
  8. internal/config.go +26 -0
  9. internal/jwt.go +39 -0
  10. internal/logger.go +73 -0
  11. internal/models.go +238 -0
  12. internal/signature.go +33 -0
  13. internal/upload.go +197 -0
  14. internal/version.go +55 -0
  15. main.go +22 -0
.env.example ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ PORT=8000
2
+ LOG_LEVEL=info
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .env
2
+ *.exe
3
+ *.cmd
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM golang:1.21-alpine AS builder
2
+
3
+ WORKDIR /app
4
+
5
+ COPY go.mod go.sum ./
6
+ RUN go mod download
7
+
8
+ COPY . .
9
+ RUN CGO_ENABLED=0 GOOS=linux go build -o zai-proxy .
10
+
11
+ FROM alpine:latest
12
+
13
+ WORKDIR /app
14
+
15
+ COPY --from=builder /app/zai-proxy .
16
+
17
+ EXPOSE 8000
18
+
19
+ CMD ["./zai-proxy"]
README.md ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # zai-proxy
2
+
3
+ zai-proxy 是一个基于 Go 语言的代理服务,将 z.ai 网页聊天转换为 OpenAI API 兼容格式。用户使用自己的 z.ai token 进行调用。
4
+
5
+ ## 功能特性
6
+
7
+ - OpenAI API 兼容格式
8
+ - 支持流式和非流式响应
9
+ - 支持多种 GLM 模型
10
+ - 支持思考模式 (thinking)
11
+ - 支持联网搜索模式 (search)
12
+ - 支持多模态图片输入
13
+ - **自动生成签名**
14
+ - **自动更新签名版本号**
15
+
16
+ ## 快速开始
17
+
18
+ ### 安装运行
19
+
20
+ ```bash
21
+ # 克隆项目
22
+ git clone https://github.com/kao0312/zai-proxy.git
23
+ cd zai-proxy
24
+
25
+ # 安装依赖
26
+ go mod download
27
+
28
+ # 运行服务
29
+ go run main.go
30
+ ```
31
+
32
+ ### Docker 部署
33
+
34
+ ```bash
35
+ # 构建镜像
36
+ docker build -t zai-proxy .
37
+
38
+ # 运行容器
39
+ docker run -p 8000:8000 zai-proxy
40
+
41
+ # 使用环境变量
42
+ docker run -p 8000:8000 -e PORT=8080 -e LOG_LEVEL=debug zai-proxy
43
+ ```
44
+
45
+ ## 环境变量
46
+
47
+ | 变量名 | 说明 | 默认值 |
48
+ |--------|------|--------|
49
+ | PORT | 监听端口 | 8000 |
50
+ | LOG_LEVEL | 日志级别 | info |
51
+
52
+ ## 获取 z.ai Token
53
+
54
+ 1. 登录 https://chat.z.ai
55
+ 2. 打开浏览器开发者工具 (F12)
56
+ 3. 切换到 Application/Storage 标签
57
+ 4. 在 Cookies 中找到 `token` 字段
58
+ 5. 复制其值作为 API 调用的 Authorization
59
+
60
+ ## 支持的模型
61
+
62
+ | 模型名称 | 上游模型 |
63
+ |----------|----------|
64
+ | GLM-4.5 | 0727-360B-API |
65
+ | GLM-4.6 | GLM-4-6-API-V1 |
66
+ | GLM-4.5-V | glm-4.5v |
67
+ | GLM-4.5-Air | 0727-106B-API |
68
+
69
+ ### 模型标签
70
+
71
+ 模型名称支持以下后缀标签(可组合使用):
72
+
73
+ - `-thinking`: 启用思考模式,响应会包含 `reasoning_content` 字段
74
+ - `-search`: 启用联网搜索模式
75
+
76
+ 示例:
77
+
78
+ - `GLM-4.6-thinking`
79
+ - `GLM-4.6-search`
80
+ - `GLM-4.6-thinking-search`
81
+
82
+ ## 使用示例
83
+
84
+ ### curl 测试
85
+
86
+ ```bash
87
+ curl http://localhost:8000/v1/chat/completions \
88
+ -H "Authorization: Bearer YOUR_ZAI_TOKEN" \
89
+ -H "Content-Type: application/json" \
90
+ -d '{
91
+ "model": "GLM-4.6",
92
+ "messages": [{"role": "user", "content": "hello"}],
93
+ "stream": true
94
+ }'
95
+ ```
96
+
97
+ ### 多模态请求:
98
+
99
+ ```json
100
+ {
101
+ "model": "GLM-4.5-V",
102
+ "messages": [
103
+ {
104
+ "role": "user",
105
+ "content": [
106
+ {"type": "text", "text": "描述这张图片"},
107
+ {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
108
+ ]
109
+ }
110
+ ]
111
+ }
112
+ ```
113
+
114
+ ### 支持的图片格式:
115
+ - HTTP/HTTPS URL
116
+ - Base64 编码 (data:image/jpeg;base64,...)
go.mod ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ module zai-proxy
2
+
3
+ go 1.21
4
+
5
+ require (
6
+ github.com/google/uuid v1.6.0
7
+ github.com/joho/godotenv v1.5.1
8
+ )
9
+
10
+ require github.com/corpix/uarand v0.2.0 // indirect
go.sum ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ github.com/corpix/uarand v0.2.0 h1:U98xXwud/AVuCpkpgfPF7J5TQgr7R5tqT8VZP5KWbzE=
2
+ github.com/corpix/uarand v0.2.0/go.mod h1:/3Z1QIqWkDIhf6XWn/08/uMHoQ8JUoTIKc2iPchBOmM=
3
+ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
4
+ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
5
+ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
6
+ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
internal/chat.go ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package internal
2
+
3
+ import (
4
+ "bufio"
5
+ "bytes"
6
+ "encoding/json"
7
+ "fmt"
8
+ "io"
9
+ "net/http"
10
+ "strings"
11
+ "time"
12
+
13
+ "github.com/corpix/uarand"
14
+ "github.com/google/uuid"
15
+ )
16
+
17
+ func extractLatestUserContent(messages []Message) string {
18
+ for i := len(messages) - 1; i >= 0; i-- {
19
+ if messages[i].Role == "user" {
20
+ text, _ := messages[i].ParseContent()
21
+ return text
22
+ }
23
+ }
24
+ return ""
25
+ }
26
+
27
+ // 提取所有消息中的图片URL
28
+ func extractAllImageURLs(messages []Message) []string {
29
+ var allImageURLs []string
30
+ for _, msg := range messages {
31
+ _, imageURLs := msg.ParseContent()
32
+ allImageURLs = append(allImageURLs, imageURLs...)
33
+ }
34
+ return allImageURLs
35
+ }
36
+
37
+ func makeUpstreamRequest(token string, messages []Message, model string) (*http.Response, string, error) {
38
+ payload, err := DecodeJWTPayload(token)
39
+ if err != nil || payload == nil {
40
+ return nil, "", fmt.Errorf("invalid token")
41
+ }
42
+
43
+ userID := payload.ID
44
+ chatID := uuid.New().String()
45
+ timestamp := time.Now().UnixMilli()
46
+ requestID := uuid.New().String()
47
+ userMsgID := uuid.New().String()
48
+
49
+ targetModel := GetTargetModel(model)
50
+ latestUserContent := extractLatestUserContent(messages)
51
+ imageURLs := extractAllImageURLs(messages)
52
+
53
+ signature := GenerateSignature(userID, requestID, latestUserContent, timestamp)
54
+
55
+ url := fmt.Sprintf("https://chat.z.ai/api/v2/chat/completions?timestamp=%d&requestId=%s&user_id=%s&version=0.0.1&platform=web&token=%s&current_url=%s&pathname=%s&signature_timestamp=%d",
56
+ timestamp, requestID, userID, token,
57
+ fmt.Sprintf("https://chat.z.ai/c/%s", chatID),
58
+ fmt.Sprintf("/c/%s", chatID),
59
+ timestamp)
60
+
61
+ enableThinking := IsThinkingModel(model)
62
+ autoWebSearch := IsSearchModel(model)
63
+ // GLM-4.5-V 不支持 auto_web_search
64
+ if targetModel == "glm-4.5v" {
65
+ autoWebSearch = false
66
+ }
67
+
68
+ // 转换消息为上游格式
69
+ var upstreamMessages []map[string]string
70
+ for _, msg := range messages {
71
+ upstreamMessages = append(upstreamMessages, msg.ToUpstreamMessage())
72
+ }
73
+
74
+ body := map[string]interface{}{
75
+ "stream": true,
76
+ "model": targetModel,
77
+ "messages": upstreamMessages,
78
+ "signature_prompt": latestUserContent,
79
+ "params": map[string]interface{}{},
80
+ "features": map[string]interface{}{
81
+ "image_generation": false,
82
+ "web_search": false,
83
+ "auto_web_search": autoWebSearch,
84
+ "preview_mode": true,
85
+ "enable_thinking": enableThinking,
86
+ },
87
+ "chat_id": chatID,
88
+ "id": uuid.New().String(),
89
+ }
90
+
91
+ // 处理图片上传
92
+ if len(imageURLs) > 0 {
93
+ files, err := UploadImages(token, imageURLs)
94
+ if err != nil {
95
+ LogError("Failed to upload images: %v", err)
96
+ }
97
+ if len(files) > 0 {
98
+ // 设置 ref_user_msg_id
99
+ var filesData []map[string]interface{}
100
+ for _, f := range files {
101
+ fileMap := map[string]interface{}{
102
+ "type": f.Type,
103
+ "file": f.File,
104
+ "id": f.ID,
105
+ "url": f.URL,
106
+ "name": f.Name,
107
+ "status": f.Status,
108
+ "size": f.Size,
109
+ "error": f.Error,
110
+ "itemId": f.ItemID,
111
+ "media": f.Media,
112
+ "ref_user_msg_id": userMsgID,
113
+ }
114
+ filesData = append(filesData, fileMap)
115
+ }
116
+ body["files"] = filesData
117
+ body["current_user_message_id"] = userMsgID
118
+ }
119
+ }
120
+
121
+ bodyBytes, _ := json.Marshal(body)
122
+
123
+ req, err := http.NewRequest("POST", url, bytes.NewReader(bodyBytes))
124
+ if err != nil {
125
+ return nil, "", err
126
+ }
127
+
128
+ req.Header.Set("Authorization", "Bearer "+token)
129
+ req.Header.Set("X-FE-Version", GetFeVersion())
130
+ req.Header.Set("X-Signature", signature)
131
+ req.Header.Set("Content-Type", "application/json")
132
+ req.Header.Set("Connection", "keep-alive")
133
+ req.Header.Set("Origin", "https://chat.z.ai")
134
+ req.Header.Set("Referer", fmt.Sprintf("https://chat.z.ai/c/%s", uuid.New().String()))
135
+ req.Header.Set("User-Agent", uarand.GetRandom())
136
+
137
+ // LogDebug("[Request] URL: %s", url)
138
+ // LogDebug("[Request] Headers: %v", req.Header)
139
+
140
+ client := &http.Client{}
141
+ resp, err := client.Do(req)
142
+ if err != nil {
143
+ return nil, "", err
144
+ }
145
+
146
+ return resp, targetModel, nil
147
+ }
148
+
149
+ type UpstreamData struct {
150
+ Type string `json:"type"`
151
+ Data struct {
152
+ DeltaContent string `json:"delta_content"`
153
+ EditContent string `json:"edit_content"`
154
+ Phase string `json:"phase"`
155
+ Done bool `json:"done"`
156
+ } `json:"data"`
157
+ }
158
+
159
+ // 思考内容过滤器状态
160
+ type ThinkingFilter struct {
161
+ hasSeenFirstThinking bool
162
+ buffer string
163
+ }
164
+
165
+ // 处理思考阶段的内容
166
+ // 第一个 delta_content 包含 <details...>\n<summary>Thinking…</summary>\n> 前缀,需要过滤
167
+ // 后续 delta_content 需要替换 "\n> " 为 "\n"(跨块累积处理)
168
+ func (f *ThinkingFilter) ProcessThinking(deltaContent string) string {
169
+ if !f.hasSeenFirstThinking {
170
+ f.hasSeenFirstThinking = true
171
+ // 第一个 thinking 内容,查找 "> " 之后的内容
172
+ if idx := strings.Index(deltaContent, "> "); idx != -1 {
173
+ deltaContent = deltaContent[idx+2:]
174
+ } else {
175
+ return ""
176
+ }
177
+ }
178
+
179
+ // 合并缓冲区内容
180
+ content := f.buffer + deltaContent
181
+ f.buffer = ""
182
+
183
+ // 替换完整的 "\n> " 为 "\n"
184
+ content = strings.ReplaceAll(content, "\n> ", "\n")
185
+
186
+ // 检查末尾是否有可能是 "\n> " 的前缀
187
+ // 可能的前缀:"\n", "\n>"
188
+ if strings.HasSuffix(content, "\n>") {
189
+ f.buffer = "\n>"
190
+ return content[:len(content)-2]
191
+ }
192
+ if strings.HasSuffix(content, "\n") {
193
+ f.buffer = "\n"
194
+ return content[:len(content)-1]
195
+ }
196
+
197
+ return content
198
+ }
199
+
200
+ // Flush 返回缓冲区中剩余的内容
201
+ func (f *ThinkingFilter) Flush() string {
202
+ result := f.buffer
203
+ f.buffer = ""
204
+ return result
205
+ }
206
+
207
+ // 从 answer 阶段的 edit_content 中提取完整思考内容
208
+ // 格式:true" duration="0" ...>\n<summary>Thought for 0 seconds</summary>\n> 完整思考内容\n</details>\n你好
209
+ func (f *ThinkingFilter) ExtractCompleteThinking(editContent string) string {
210
+ // 查找 "> " 到 "</details>" 之间的内容
211
+ startIdx := strings.Index(editContent, "> ")
212
+ if startIdx == -1 {
213
+ return ""
214
+ }
215
+ startIdx += 2
216
+
217
+ endIdx := strings.Index(editContent, "\n</details>")
218
+ if endIdx == -1 {
219
+ return ""
220
+ }
221
+
222
+ content := editContent[startIdx:endIdx]
223
+ // 替换 "\n> " 为 "\n"
224
+ content = strings.ReplaceAll(content, "\n> ", "\n")
225
+ return content
226
+ }
227
+
228
+ func HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
229
+ token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
230
+ if token == "" {
231
+ http.Error(w, "Unauthorized", http.StatusUnauthorized)
232
+ return
233
+ }
234
+
235
+ var req ChatRequest
236
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
237
+ http.Error(w, "Invalid request", http.StatusBadRequest)
238
+ return
239
+ }
240
+
241
+ if req.Model == "" {
242
+ req.Model = "GLM-4.6"
243
+ }
244
+
245
+ resp, modelName, err := makeUpstreamRequest(token, req.Messages, req.Model)
246
+ if err != nil {
247
+ LogError("Upstream request failed: %v", err)
248
+ http.Error(w, "Upstream error", http.StatusBadGateway)
249
+ return
250
+ }
251
+ defer resp.Body.Close()
252
+
253
+ if resp.StatusCode != http.StatusOK {
254
+ body, _ := io.ReadAll(resp.Body)
255
+ bodyStr := string(body)
256
+ if len(bodyStr) > 500 {
257
+ bodyStr = bodyStr[:500]
258
+ }
259
+ LogError("Upstream error: status=%d, body=%s", resp.StatusCode, bodyStr)
260
+ http.Error(w, "Upstream error", resp.StatusCode)
261
+ return
262
+ }
263
+
264
+ completionID := fmt.Sprintf("chatcmpl-%s", uuid.New().String()[:29])
265
+
266
+ if req.Stream {
267
+ handleStreamResponse(w, resp.Body, completionID, modelName)
268
+ } else {
269
+ handleNonStreamResponse(w, resp.Body, completionID, modelName)
270
+ }
271
+ }
272
+
273
+ func handleStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionID, modelName string) {
274
+ w.Header().Set("Content-Type", "text/event-stream")
275
+ w.Header().Set("Cache-Control", "no-cache")
276
+ w.Header().Set("Connection", "keep-alive")
277
+
278
+ flusher, ok := w.(http.Flusher)
279
+ if !ok {
280
+ http.Error(w, "Streaming not supported", http.StatusInternalServerError)
281
+ return
282
+ }
283
+
284
+ scanner := bufio.NewScanner(body)
285
+ scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
286
+ hasContent := false
287
+ searchRefFilter := NewSearchRefFilter()
288
+ thinkingFilter := &ThinkingFilter{}
289
+
290
+ for scanner.Scan() {
291
+ line := scanner.Text()
292
+ LogDebug("[Upstream] %s", line)
293
+
294
+ if !strings.HasPrefix(line, "data: ") {
295
+ continue
296
+ }
297
+
298
+ payload := strings.TrimPrefix(line, "data: ")
299
+ if payload == "[DONE]" {
300
+ break
301
+ }
302
+
303
+ var upstream UpstreamData
304
+ if err := json.Unmarshal([]byte(payload), &upstream); err != nil {
305
+ continue
306
+ }
307
+
308
+ if upstream.Data.Phase == "done" {
309
+ break
310
+ }
311
+
312
+ // 处理思考阶段的增量内容
313
+ if upstream.Data.Phase == "thinking" && upstream.Data.DeltaContent != "" {
314
+ reasoningContent := thinkingFilter.ProcessThinking(upstream.Data.DeltaContent)
315
+ if reasoningContent != "" {
316
+ hasContent = true
317
+ chunk := ChatCompletionChunk{
318
+ ID: completionID,
319
+ Object: "chat.completion.chunk",
320
+ Created: time.Now().Unix(),
321
+ Model: modelName,
322
+ Choices: []Choice{{
323
+ Index: 0,
324
+ Delta: Delta{ReasoningContent: reasoningContent},
325
+ FinishReason: nil,
326
+ }},
327
+ }
328
+ data, _ := json.Marshal(chunk)
329
+ fmt.Fprintf(w, "data: %s\n\n", data)
330
+ flusher.Flush()
331
+ }
332
+ continue
333
+ }
334
+
335
+ // 跳过搜索结果内容和搜索工具调用
336
+ if upstream.Data.EditContent != "" && (IsSearchResultContent(upstream.Data.EditContent) || IsSearchToolCall(upstream.Data.EditContent, upstream.Data.Phase)) {
337
+ continue
338
+ }
339
+
340
+ // 解析 answer 阶段内容
341
+ content := ""
342
+ reasoningContent := ""
343
+
344
+ // 先输出 thinking 缓冲区剩余内容
345
+ if thinkingRemaining := thinkingFilter.Flush(); thinkingRemaining != "" {
346
+ hasContent = true
347
+ chunk := ChatCompletionChunk{
348
+ ID: completionID,
349
+ Object: "chat.completion.chunk",
350
+ Created: time.Now().Unix(),
351
+ Model: modelName,
352
+ Choices: []Choice{{
353
+ Index: 0,
354
+ Delta: Delta{ReasoningContent: thinkingRemaining},
355
+ FinishReason: nil,
356
+ }},
357
+ }
358
+ data, _ := json.Marshal(chunk)
359
+ fmt.Fprintf(w, "data: %s\n\n", data)
360
+ flusher.Flush()
361
+ }
362
+
363
+ if upstream.Data.Phase == "answer" && upstream.Data.DeltaContent != "" {
364
+ content = upstream.Data.DeltaContent
365
+ } else if upstream.Data.Phase == "answer" && upstream.Data.EditContent != "" {
366
+ // 思考模型首次 answer:提取完整思考内容 + 正常回复开头
367
+ if strings.Contains(upstream.Data.EditContent, "</details>") {
368
+ reasoningContent = thinkingFilter.ExtractCompleteThinking(upstream.Data.EditContent)
369
+ if idx := strings.Index(upstream.Data.EditContent, "</details>\n"); idx != -1 {
370
+ content = upstream.Data.EditContent[idx+len("</details>\n"):]
371
+ }
372
+ }
373
+ } else if (upstream.Data.Phase == "other" || upstream.Data.Phase == "tool_call") && upstream.Data.EditContent != "" {
374
+ // other: 普通最后一个 token; tool_call: 搜索模式最后一个 token
375
+ content = upstream.Data.EditContent
376
+ }
377
+
378
+ // 输出完整思考内容(如果有)
379
+ if reasoningContent != "" {
380
+ hasContent = true
381
+ chunk := ChatCompletionChunk{
382
+ ID: completionID,
383
+ Object: "chat.completion.chunk",
384
+ Created: time.Now().Unix(),
385
+ Model: modelName,
386
+ Choices: []Choice{{
387
+ Index: 0,
388
+ Delta: Delta{ReasoningContent: reasoningContent},
389
+ FinishReason: nil,
390
+ }},
391
+ }
392
+ data, _ := json.Marshal(chunk)
393
+ fmt.Fprintf(w, "data: %s\n\n", data)
394
+ flusher.Flush()
395
+ }
396
+
397
+ if content == "" {
398
+ continue
399
+ }
400
+
401
+ // 过滤搜索引用标记(跨流累积处理)
402
+ content = searchRefFilter.Process(content)
403
+ if content == "" {
404
+ continue
405
+ }
406
+
407
+ hasContent = true
408
+ chunk := ChatCompletionChunk{
409
+ ID: completionID,
410
+ Object: "chat.completion.chunk",
411
+ Created: time.Now().Unix(),
412
+ Model: modelName,
413
+ Choices: []Choice{{
414
+ Index: 0,
415
+ Delta: Delta{Content: content},
416
+ FinishReason: nil,
417
+ }},
418
+ }
419
+
420
+ data, _ := json.Marshal(chunk)
421
+ fmt.Fprintf(w, "data: %s\n\n", data)
422
+ flusher.Flush()
423
+ }
424
+
425
+ if err := scanner.Err(); err != nil {
426
+ LogError("[Upstream] scanner error: %v", err)
427
+ }
428
+
429
+ // 输出过滤器中剩余的内容(非引用标记的部分)
430
+ if remaining := searchRefFilter.Flush(); remaining != "" {
431
+ hasContent = true
432
+ chunk := ChatCompletionChunk{
433
+ ID: completionID,
434
+ Object: "chat.completion.chunk",
435
+ Created: time.Now().Unix(),
436
+ Model: modelName,
437
+ Choices: []Choice{{
438
+ Index: 0,
439
+ Delta: Delta{Content: remaining},
440
+ FinishReason: nil,
441
+ }},
442
+ }
443
+ data, _ := json.Marshal(chunk)
444
+ fmt.Fprintf(w, "data: %s\n\n", data)
445
+ flusher.Flush()
446
+ }
447
+
448
+ if !hasContent {
449
+ LogError("Stream response 200 but no content received")
450
+ }
451
+
452
+ // Final chunk
453
+ stopReason := "stop"
454
+ finalChunk := ChatCompletionChunk{
455
+ ID: completionID,
456
+ Object: "chat.completion.chunk",
457
+ Created: time.Now().Unix(),
458
+ Model: modelName,
459
+ Choices: []Choice{{
460
+ Index: 0,
461
+ Delta: Delta{},
462
+ FinishReason: &stopReason,
463
+ }},
464
+ }
465
+
466
+ data, _ := json.Marshal(finalChunk)
467
+ fmt.Fprintf(w, "data: %s\n\n", data)
468
+ fmt.Fprintf(w, "data: [DONE]\n\n")
469
+ flusher.Flush()
470
+ }
471
+
472
+ func handleNonStreamResponse(w http.ResponseWriter, body io.ReadCloser, completionID, modelName string) {
473
+ scanner := bufio.NewScanner(body)
474
+ scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
475
+ var chunks []string
476
+ var reasoningChunks []string
477
+ thinkingFilter := &ThinkingFilter{}
478
+
479
+ for scanner.Scan() {
480
+ line := scanner.Text()
481
+ if !strings.HasPrefix(line, "data: ") {
482
+ continue
483
+ }
484
+
485
+ payload := strings.TrimPrefix(line, "data: ")
486
+ if payload == "[DONE]" {
487
+ break
488
+ }
489
+
490
+ var upstream UpstreamData
491
+ if err := json.Unmarshal([]byte(payload), &upstream); err != nil {
492
+ continue
493
+ }
494
+
495
+ if upstream.Data.Phase == "done" {
496
+ break
497
+ }
498
+
499
+ // 处理思考阶段的增量内容
500
+ if upstream.Data.Phase == "thinking" && upstream.Data.DeltaContent != "" {
501
+ reasoningContent := thinkingFilter.ProcessThinking(upstream.Data.DeltaContent)
502
+ if reasoningContent != "" {
503
+ reasoningChunks = append(reasoningChunks, reasoningContent)
504
+ }
505
+ continue
506
+ }
507
+
508
+ // 跳过搜索结果内容和搜索工具调用
509
+ if upstream.Data.EditContent != "" && (IsSearchResultContent(upstream.Data.EditContent) || IsSearchToolCall(upstream.Data.EditContent, upstream.Data.Phase)) {
510
+ continue
511
+ }
512
+
513
+ // 解析 answer 阶段内容
514
+ content := ""
515
+ if upstream.Data.Phase == "answer" && upstream.Data.DeltaContent != "" {
516
+ content = upstream.Data.DeltaContent
517
+ } else if upstream.Data.Phase == "answer" && upstream.Data.EditContent != "" {
518
+ // 思考模型首次 answer:提取完整思考内容 + 正常回复开头
519
+ if strings.Contains(upstream.Data.EditContent, "</details>") {
520
+ reasoningContent := thinkingFilter.ExtractCompleteThinking(upstream.Data.EditContent)
521
+ if reasoningContent != "" {
522
+ reasoningChunks = append(reasoningChunks, reasoningContent)
523
+ }
524
+ if idx := strings.Index(upstream.Data.EditContent, "</details>\n"); idx != -1 {
525
+ content = upstream.Data.EditContent[idx+len("</details>\n"):]
526
+ }
527
+ }
528
+ } else if (upstream.Data.Phase == "other" || upstream.Data.Phase == "tool_call") && upstream.Data.EditContent != "" {
529
+ content = upstream.Data.EditContent
530
+ }
531
+
532
+ if content != "" {
533
+ chunks = append(chunks, content)
534
+ }
535
+ }
536
+
537
+ // 合并所有内容后统一过滤搜索引用标记
538
+ fullContent := strings.Join(chunks, "")
539
+ fullContent = searchRefPattern.ReplaceAllString(fullContent, "")
540
+ fullReasoning := strings.Join(reasoningChunks, "")
541
+
542
+ if fullContent == "" {
543
+ LogError("Non-stream response 200 but no content received")
544
+ }
545
+
546
+ stopReason := "stop"
547
+ response := ChatCompletionResponse{
548
+ ID: completionID,
549
+ Object: "chat.completion",
550
+ Created: time.Now().Unix(),
551
+ Model: modelName,
552
+ Choices: []Choice{{
553
+ Index: 0,
554
+ Message: &MessageResp{
555
+ Role: "assistant",
556
+ Content: fullContent,
557
+ ReasoningContent: fullReasoning,
558
+ },
559
+ FinishReason: &stopReason,
560
+ }},
561
+ }
562
+
563
+ w.Header().Set("Content-Type", "application/json")
564
+ json.NewEncoder(w).Encode(response)
565
+ }
566
+
567
+ func HandleModels(w http.ResponseWriter, r *http.Request) {
568
+ var models []ModelInfo
569
+ for _, id := range ModelList {
570
+ models = append(models, ModelInfo{
571
+ ID: id,
572
+ Object: "model",
573
+ OwnedBy: "z.ai",
574
+ })
575
+ }
576
+
577
+ response := ModelsResponse{
578
+ Object: "list",
579
+ Data: models,
580
+ }
581
+
582
+ w.Header().Set("Content-Type", "application/json")
583
+ json.NewEncoder(w).Encode(response)
584
+ }
internal/config.go ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package internal
2
+
3
+ import (
4
+ "os"
5
+
6
+ "github.com/joho/godotenv"
7
+ )
8
+
9
+ type Config struct {
10
+ Port string
11
+ }
12
+
13
+ var Cfg *Config
14
+
15
+ func LoadConfig() {
16
+ godotenv.Load()
17
+
18
+ port := os.Getenv("PORT")
19
+ if port == "" {
20
+ port = "8000"
21
+ }
22
+
23
+ Cfg = &Config{
24
+ Port: port,
25
+ }
26
+ }
internal/jwt.go ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package internal
2
+
3
+ import (
4
+ "encoding/base64"
5
+ "encoding/json"
6
+ "strings"
7
+ )
8
+
9
+ type JWTPayload struct {
10
+ ID string `json:"id"`
11
+ }
12
+
13
+ func DecodeJWTPayload(token string) (*JWTPayload, error) {
14
+ parts := strings.Split(token, ".")
15
+ if len(parts) < 2 {
16
+ return nil, nil
17
+ }
18
+
19
+ payload := parts[1]
20
+ // Add padding if needed
21
+ if padding := 4 - len(payload)%4; padding != 4 {
22
+ payload += strings.Repeat("=", padding)
23
+ }
24
+
25
+ decoded, err := base64.URLEncoding.DecodeString(payload)
26
+ if err != nil {
27
+ decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
28
+ if err != nil {
29
+ return nil, err
30
+ }
31
+ }
32
+
33
+ var result JWTPayload
34
+ if err := json.Unmarshal(decoded, &result); err != nil {
35
+ return nil, err
36
+ }
37
+
38
+ return &result, nil
39
+ }
internal/logger.go ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package internal
2
+
3
+ import (
4
+ "fmt"
5
+ "os"
6
+ "time"
7
+ )
8
+
9
+ type LogLevel int
10
+
11
+ const (
12
+ DEBUG LogLevel = iota
13
+ INFO
14
+ WARN
15
+ ERROR
16
+ )
17
+
18
+ var (
19
+ currentLevel LogLevel = INFO
20
+ levelNames = map[LogLevel]string{
21
+ DEBUG: "DEBUG",
22
+ INFO: "INFO",
23
+ WARN: "WARN",
24
+ ERROR: "ERROR",
25
+ }
26
+ // ANSI 颜色
27
+ levelColors = map[LogLevel]string{
28
+ DEBUG: "\033[36m", // 青色
29
+ INFO: "\033[32m", // 绿色
30
+ WARN: "\033[33m", // 黄色
31
+ ERROR: "\033[31m", // 红色
32
+ }
33
+ resetColor = "\033[0m"
34
+ )
35
+
36
+ func InitLogger() {
37
+ level := os.Getenv("LOG_LEVEL")
38
+ switch level {
39
+ case "debug", "DEBUG":
40
+ currentLevel = DEBUG
41
+ case "warn", "WARN":
42
+ currentLevel = WARN
43
+ case "error", "ERROR":
44
+ currentLevel = ERROR
45
+ default:
46
+ currentLevel = INFO
47
+ }
48
+ }
49
+
50
+ func log(level LogLevel, format string, v ...interface{}) {
51
+ if level < currentLevel {
52
+ return
53
+ }
54
+ timestamp := time.Now().Format("2006/01/02 15:04:05")
55
+ msg := fmt.Sprintf(format, v...)
56
+ fmt.Printf("%s[%s]%s %s %s\n", levelColors[level], levelNames[level], resetColor, timestamp, msg)
57
+ }
58
+
59
+ func LogDebug(format string, v ...interface{}) {
60
+ log(DEBUG, format, v...)
61
+ }
62
+
63
+ func LogInfo(format string, v ...interface{}) {
64
+ log(INFO, format, v...)
65
+ }
66
+
67
+ func LogWarn(format string, v ...interface{}) {
68
+ log(WARN, format, v...)
69
+ }
70
+
71
+ func LogError(format string, v ...interface{}) {
72
+ log(ERROR, format, v...)
73
+ }
internal/models.go ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package internal
2
+
3
+ import (
4
+ "regexp"
5
+ "strings"
6
+ )
7
+
8
+ // 基础模型映射(不包含标签后缀)
9
+ var BaseModelMapping = map[string]string{
10
+ "GLM-4.5": "0727-360B-API",
11
+ "GLM-4.6": "GLM-4-6-API-V1",
12
+ "GLM-4.5-V": "glm-4.5v",
13
+ "GLM-4.5-Air": "0727-106B-API",
14
+ }
15
+
16
+ // v1/models 返回的模型列表(不包含所有标签组合)
17
+ var ModelList = []string{
18
+ "GLM-4.5",
19
+ "GLM-4.6",
20
+ "GLM-4.5-thinking",
21
+ "GLM-4.6-thinking",
22
+ "GLM-4.5-V",
23
+ "GLM-4.5-Air",
24
+ }
25
+
26
+ // 解析模型名称,提取基础模型名和标签
27
+ // 支持 -thinking 和 -search 标签的任意排列组合
28
+ func ParseModelName(model string) (baseModel string, enableThinking bool, enableSearch bool) {
29
+ enableThinking = false
30
+ enableSearch = false
31
+ baseModel = model
32
+
33
+ // 检查并移除 -thinking 和 -search 标签(任意顺序)
34
+ for {
35
+ if strings.HasSuffix(baseModel, "-thinking") {
36
+ enableThinking = true
37
+ baseModel = strings.TrimSuffix(baseModel, "-thinking")
38
+ } else if strings.HasSuffix(baseModel, "-search") {
39
+ enableSearch = true
40
+ baseModel = strings.TrimSuffix(baseModel, "-search")
41
+ } else {
42
+ break
43
+ }
44
+ }
45
+
46
+ return baseModel, enableThinking, enableSearch
47
+ }
48
+
49
+ func IsThinkingModel(model string) bool {
50
+ _, enableThinking, _ := ParseModelName(model)
51
+ return enableThinking
52
+ }
53
+
54
+ func IsSearchModel(model string) bool {
55
+ _, _, enableSearch := ParseModelName(model)
56
+ return enableSearch
57
+ }
58
+
59
+ func GetTargetModel(model string) string {
60
+ baseModel, _, _ := ParseModelName(model)
61
+ if target, ok := BaseModelMapping[baseModel]; ok {
62
+ return target
63
+ }
64
+ return model
65
+ }
66
+
67
+ // OpenAI 格式的消息内容项
68
+ type ContentPart struct {
69
+ Type string `json:"type"`
70
+ Text string `json:"text,omitempty"`
71
+ ImageURL *ImageURL `json:"image_url,omitempty"`
72
+ }
73
+
74
+ type ImageURL struct {
75
+ URL string `json:"url"`
76
+ }
77
+
78
+ // Message 支持纯文本和多模态内容
79
+ type Message struct {
80
+ Role string `json:"role"`
81
+ Content interface{} `json:"content"` // string 或 []ContentPart
82
+ }
83
+
84
+ // 解析消息内容,返回文本和图片URL列表
85
+ func (m *Message) ParseContent() (text string, imageURLs []string) {
86
+ switch content := m.Content.(type) {
87
+ case string:
88
+ return content, nil
89
+ case []interface{}:
90
+ for _, item := range content {
91
+ if part, ok := item.(map[string]interface{}); ok {
92
+ partType, _ := part["type"].(string)
93
+ if partType == "text" {
94
+ if t, ok := part["text"].(string); ok {
95
+ text += t
96
+ }
97
+ } else if partType == "image_url" {
98
+ if imgURL, ok := part["image_url"].(map[string]interface{}); ok {
99
+ if url, ok := imgURL["url"].(string); ok {
100
+ imageURLs = append(imageURLs, url)
101
+ }
102
+ }
103
+ }
104
+ }
105
+ }
106
+ }
107
+ return text, imageURLs
108
+ }
109
+
110
+ // 转换为上游消息格式(纯文本)
111
+ func (m *Message) ToUpstreamMessage() map[string]string {
112
+ text, _ := m.ParseContent()
113
+ return map[string]string{
114
+ "role": m.Role,
115
+ "content": text,
116
+ }
117
+ }
118
+
119
+ type ChatRequest struct {
120
+ Model string `json:"model"`
121
+ Messages []Message `json:"messages"`
122
+ Stream bool `json:"stream"`
123
+ }
124
+
125
+ type ChatCompletionChunk struct {
126
+ ID string `json:"id"`
127
+ Object string `json:"object"`
128
+ Created int64 `json:"created"`
129
+ Model string `json:"model"`
130
+ Choices []Choice `json:"choices"`
131
+ }
132
+
133
+ type Choice struct {
134
+ Index int `json:"index"`
135
+ Delta Delta `json:"delta,omitempty"`
136
+ Message *MessageResp `json:"message,omitempty"`
137
+ FinishReason *string `json:"finish_reason"`
138
+ }
139
+
140
+ type Delta struct {
141
+ Content string `json:"content,omitempty"`
142
+ ReasoningContent string `json:"reasoning_content,omitempty"`
143
+ }
144
+
145
+ type MessageResp struct {
146
+ Role string `json:"role"`
147
+ Content string `json:"content"`
148
+ ReasoningContent string `json:"reasoning_content,omitempty"`
149
+ }
150
+
151
+ type ChatCompletionResponse struct {
152
+ ID string `json:"id"`
153
+ Object string `json:"object"`
154
+ Created int64 `json:"created"`
155
+ Model string `json:"model"`
156
+ Choices []Choice `json:"choices"`
157
+ }
158
+
159
+ type ModelsResponse struct {
160
+ Object string `json:"object"`
161
+ Data []ModelInfo `json:"data"`
162
+ }
163
+
164
+ type ModelInfo struct {
165
+ ID string `json:"id"`
166
+ Object string `json:"object"`
167
+ OwnedBy string `json:"owned_by"`
168
+ }
169
+
170
+ // 搜索引用标记正则:【turn数字search数字】
171
+ var searchRefPattern = regexp.MustCompile(`【turn\d+search\d+】`)
172
+
173
+ // 搜索引用标记可能的前缀模式
174
+ var searchRefPrefixPattern = regexp.MustCompile(`【(t(u(r(n(\d+(s(e(a(r(c(h(\d+)?)?)?)?)?)?)?)?)?)?)?)?$`)
175
+
176
+ // SearchRefFilter 用于跨流过滤搜索引用标记
177
+ type SearchRefFilter struct {
178
+ buffer string
179
+ }
180
+
181
+ // NewSearchRefFilter 创建新的过滤器
182
+ func NewSearchRefFilter() *SearchRefFilter {
183
+ return &SearchRefFilter{}
184
+ }
185
+
186
+ // Process 处理内容,返回可以安全输出的部分
187
+ // 如果末尾有可能是引用标记的前缀,会暂存起来
188
+ func (f *SearchRefFilter) Process(content string) string {
189
+ // 合并之前暂存的内容
190
+ content = f.buffer + content
191
+ f.buffer = ""
192
+
193
+ // 先移除完整的引用标记
194
+ content = searchRefPattern.ReplaceAllString(content, "")
195
+
196
+ if content == "" {
197
+ return ""
198
+ }
199
+
200
+ // 检查末尾是否有可能是引用标记的前缀
201
+ // 从末尾开始,最多检查【turn999search999】长度(约20字符)
202
+ maxPrefixLen := 20
203
+ if len(content) < maxPrefixLen {
204
+ maxPrefixLen = len(content)
205
+ }
206
+
207
+ for i := 1; i <= maxPrefixLen; i++ {
208
+ suffix := content[len(content)-i:]
209
+ if searchRefPrefixPattern.MatchString(suffix) {
210
+ // 找到可能的前缀,暂存起来
211
+ f.buffer = suffix
212
+ return content[:len(content)-i]
213
+ }
214
+ }
215
+
216
+ return content
217
+ }
218
+
219
+ // Flush 返回所有暂存的内容(流结束时调用)
220
+ func (f *SearchRefFilter) Flush() string {
221
+ result := f.buffer
222
+ f.buffer = ""
223
+ return result
224
+ }
225
+
226
+ // 检查是否为搜索结果内容(需要跳过)
227
+ func IsSearchResultContent(editContent string) bool {
228
+ return strings.Contains(editContent, `"search_result"`)
229
+ }
230
+
231
+ // 检查是否为搜索工具调用内容(需要跳过)
232
+ func IsSearchToolCall(editContent string, phase string) bool {
233
+ if phase != "tool_call" {
234
+ return false
235
+ }
236
+ // tool_call 阶段包含 mcp 相关内容的都跳过
237
+ return strings.Contains(editContent, `"mcp"`) || strings.Contains(editContent, `mcp-server`)
238
+ }
internal/signature.go ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package internal
2
+
3
+ import (
4
+ "crypto/hmac"
5
+ "crypto/sha256"
6
+ "encoding/base64"
7
+ "encoding/hex"
8
+ "fmt"
9
+ )
10
+
11
+ func hmacSha256Hex(key []byte, data string) string {
12
+ h := hmac.New(sha256.New, key)
13
+ h.Write([]byte(data))
14
+ return hex.EncodeToString(h.Sum(nil))
15
+ }
16
+
17
+ func GenerateSignature(userID, requestID, userContent string, timestamp int64) string {
18
+ requestInfo := fmt.Sprintf("requestId,%s,timestamp,%d,user_id,%s", requestID, timestamp, userID)
19
+ contentBase64 := base64.StdEncoding.EncodeToString([]byte(userContent))
20
+ signData := fmt.Sprintf("%s|%s|%d", requestInfo, contentBase64, timestamp)
21
+
22
+ period := timestamp / (5 * 60 * 1000)
23
+ // 两次加密均返回 hex 字符串
24
+ firstHmac := hmacSha256Hex([]byte("key-@@@@)))()((9))-xxxx&&&%%%%%"), fmt.Sprintf("%d", period))
25
+ signature := hmacSha256Hex([]byte(firstHmac), signData)
26
+
27
+ // LogDebug("[Signature] requestInfo=%s", requestInfo)
28
+ // LogDebug("[Signature] userContent=%s", userContent)
29
+ // LogDebug("[Signature] timestamp=%d", timestamp)
30
+ // LogDebug("[Signature] signature=%s", signature)
31
+
32
+ return signature
33
+ }
internal/upload.go ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package internal
2
+
3
+ import (
4
+ "bytes"
5
+ "encoding/base64"
6
+ "encoding/json"
7
+ "fmt"
8
+ "io"
9
+ "mime/multipart"
10
+ "net/http"
11
+ "path/filepath"
12
+ "strings"
13
+
14
+ "github.com/google/uuid"
15
+ )
16
+
17
+ // FileUploadResponse z.ai 文件上传响应
18
+ type FileUploadResponse struct {
19
+ ID string `json:"id"`
20
+ UserID string `json:"user_id"`
21
+ Filename string `json:"filename"`
22
+ Meta struct {
23
+ Name string `json:"name"`
24
+ ContentType string `json:"content_type"`
25
+ Size int64 `json:"size"`
26
+ CdnURL string `json:"cdn_url"`
27
+ } `json:"meta"`
28
+ }
29
+
30
+ // UpstreamFile 上游请求的文件格式
31
+ type UpstreamFile struct {
32
+ Type string `json:"type"`
33
+ File FileUploadResponse `json:"file"`
34
+ ID string `json:"id"`
35
+ URL string `json:"url"`
36
+ Name string `json:"name"`
37
+ Status string `json:"status"`
38
+ Size int64 `json:"size"`
39
+ Error string `json:"error"`
40
+ ItemID string `json:"itemId"`
41
+ Media string `json:"media"`
42
+ }
43
+
44
+ // UploadImageFromURL 从 URL 或 base64 上传图片到 z.ai
45
+ func UploadImageFromURL(token string, imageURL string) (*UpstreamFile, error) {
46
+ var imageData []byte
47
+ var filename string
48
+ var contentType string
49
+
50
+ if strings.HasPrefix(imageURL, "data:") {
51
+ // Base64 编码的图片
52
+ // 格式: data:image/jpeg;base64,/9j/4AAQ...
53
+ parts := strings.SplitN(imageURL, ",", 2)
54
+ if len(parts) != 2 {
55
+ return nil, fmt.Errorf("invalid base64 image format")
56
+ }
57
+
58
+ // 解析 MIME 类型
59
+ header := parts[0] // data:image/jpeg;base64
60
+ if idx := strings.Index(header, ":"); idx != -1 {
61
+ mimeAndEncoding := header[idx+1:]
62
+ if semiIdx := strings.Index(mimeAndEncoding, ";"); semiIdx != -1 {
63
+ contentType = mimeAndEncoding[:semiIdx]
64
+ }
65
+ }
66
+ if contentType == "" {
67
+ contentType = "image/png"
68
+ }
69
+
70
+ // 解码 base64
71
+ var err error
72
+ imageData, err = base64.StdEncoding.DecodeString(parts[1])
73
+ if err != nil {
74
+ return nil, fmt.Errorf("failed to decode base64: %v", err)
75
+ }
76
+
77
+ // 生成文件名
78
+ ext := ".png"
79
+ if strings.Contains(contentType, "jpeg") || strings.Contains(contentType, "jpg") {
80
+ ext = ".jpg"
81
+ } else if strings.Contains(contentType, "gif") {
82
+ ext = ".gif"
83
+ } else if strings.Contains(contentType, "webp") {
84
+ ext = ".webp"
85
+ }
86
+ filename = uuid.New().String()[:12] + ext
87
+ } else {
88
+ // 从 URL 下载图片
89
+ resp, err := http.Get(imageURL)
90
+ if err != nil {
91
+ return nil, fmt.Errorf("failed to download image: %v", err)
92
+ }
93
+ defer resp.Body.Close()
94
+
95
+ if resp.StatusCode != http.StatusOK {
96
+ return nil, fmt.Errorf("failed to download image: status %d", resp.StatusCode)
97
+ }
98
+
99
+ imageData, err = io.ReadAll(resp.Body)
100
+ if err != nil {
101
+ return nil, fmt.Errorf("failed to read image data: %v", err)
102
+ }
103
+
104
+ contentType = resp.Header.Get("Content-Type")
105
+ if contentType == "" {
106
+ contentType = "image/png"
107
+ }
108
+
109
+ // 从 URL 提取文件名
110
+ filename = filepath.Base(imageURL)
111
+ if filename == "" || filename == "." || filename == "/" {
112
+ ext := ".png"
113
+ if strings.Contains(contentType, "jpeg") || strings.Contains(contentType, "jpg") {
114
+ ext = ".jpg"
115
+ }
116
+ filename = uuid.New().String()[:12] + ext
117
+ }
118
+ }
119
+
120
+ // 构建 multipart form 请求
121
+ var buf bytes.Buffer
122
+ writer := multipart.NewWriter(&buf)
123
+
124
+ part, err := writer.CreateFormFile("file", filename)
125
+ if err != nil {
126
+ return nil, fmt.Errorf("failed to create form file: %v", err)
127
+ }
128
+
129
+ if _, err := part.Write(imageData); err != nil {
130
+ return nil, fmt.Errorf("failed to write image data: %v", err)
131
+ }
132
+
133
+ writer.Close()
134
+
135
+ // 发送上传请求
136
+ req, err := http.NewRequest("POST", "https://chat.z.ai/api/v1/files/", &buf)
137
+ if err != nil {
138
+ return nil, fmt.Errorf("failed to create upload request: %v", err)
139
+ }
140
+
141
+ req.Header.Set("Authorization", "Bearer "+token)
142
+ req.Header.Set("Content-Type", writer.FormDataContentType())
143
+ req.Header.Set("Origin", "https://chat.z.ai")
144
+ req.Header.Set("Referer", "https://chat.z.ai/")
145
+
146
+ client := &http.Client{}
147
+ resp, err := client.Do(req)
148
+ if err != nil {
149
+ return nil, fmt.Errorf("failed to upload image: %v", err)
150
+ }
151
+ defer resp.Body.Close()
152
+
153
+ if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
154
+ body, _ := io.ReadAll(resp.Body)
155
+ return nil, fmt.Errorf("upload failed: status %d, body: %s", resp.StatusCode, string(body))
156
+ }
157
+
158
+ var uploadResp FileUploadResponse
159
+ if err := json.NewDecoder(resp.Body).Decode(&uploadResp); err != nil {
160
+ return nil, fmt.Errorf("failed to parse upload response: %v", err)
161
+ }
162
+
163
+ // 构建上游文件格式
164
+ return &UpstreamFile{
165
+ Type: "image",
166
+ File: uploadResp,
167
+ ID: uploadResp.ID,
168
+ URL: fmt.Sprintf("/api/v1/files/%s/content", uploadResp.ID),
169
+ Name: uploadResp.Filename,
170
+ Status: "uploaded",
171
+ Size: uploadResp.Meta.Size,
172
+ Error: "",
173
+ ItemID: uuid.New().String(),
174
+ Media: "image",
175
+ }, nil
176
+ }
177
+
178
+ // UploadImages 批量上传图片
179
+ func UploadImages(token string, imageURLs []string) ([]*UpstreamFile, error) {
180
+ var files []*UpstreamFile
181
+ for _, url := range imageURLs {
182
+ file, err := UploadImageFromURL(token, url)
183
+ if err != nil {
184
+ LogError("Failed to upload image %s: %v", url[:min(50, len(url))], err)
185
+ continue
186
+ }
187
+ files = append(files, file)
188
+ }
189
+ return files, nil
190
+ }
191
+
192
+ func min(a, b int) int {
193
+ if a < b {
194
+ return a
195
+ }
196
+ return b
197
+ }
internal/version.go ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package internal
2
+
3
+ import (
4
+ "io"
5
+ "net/http"
6
+ "regexp"
7
+ "sync"
8
+ "time"
9
+ )
10
+
11
+ var (
12
+ feVersion string
13
+ versionLock sync.RWMutex
14
+ )
15
+
16
+ func GetFeVersion() string {
17
+ versionLock.RLock()
18
+ defer versionLock.RUnlock()
19
+ return feVersion
20
+ }
21
+
22
+ func fetchFeVersion() {
23
+ resp, err := http.Get("https://chat.z.ai/")
24
+ if err != nil {
25
+ LogError("Failed to fetch fe version: %v", err)
26
+ return
27
+ }
28
+ defer resp.Body.Close()
29
+
30
+ body, err := io.ReadAll(resp.Body)
31
+ if err != nil {
32
+ LogError("Failed to read fe version response: %v", err)
33
+ return
34
+ }
35
+
36
+ re := regexp.MustCompile(`prod-fe-[\.\d]+`)
37
+ match := re.FindString(string(body))
38
+ if match != "" {
39
+ versionLock.Lock()
40
+ feVersion = match
41
+ versionLock.Unlock()
42
+ LogInfo("Updated fe version: %s", match)
43
+ }
44
+ }
45
+
46
+ func StartVersionUpdater() {
47
+ fetchFeVersion()
48
+
49
+ ticker := time.NewTicker(1 * time.Hour)
50
+ go func() {
51
+ for range ticker.C {
52
+ fetchFeVersion()
53
+ }
54
+ }()
55
+ }
main.go ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "net/http"
5
+
6
+ "zai-proxy/internal"
7
+ )
8
+
9
+ func main() {
10
+ internal.LoadConfig()
11
+ internal.InitLogger()
12
+ internal.StartVersionUpdater()
13
+
14
+ http.HandleFunc("/v1/models", internal.HandleModels)
15
+ http.HandleFunc("/v1/chat/completions", internal.HandleChatCompletions)
16
+
17
+ addr := ":" + internal.Cfg.Port
18
+ internal.LogInfo("Server starting on %s", addr)
19
+ if err := http.ListenAndServe(addr, nil); err != nil {
20
+ internal.LogError("Server failed: %v", err)
21
+ }
22
+ }