xidu commited on
Commit
46bfd69
·
1 Parent(s): a71c0ec

fix(build): Refactor to use ChatSession and fix compiler errors

Browse files
Files changed (2) hide show
  1. Dockerfile +10 -2
  2. main.go +62 -63
Dockerfile CHANGED
@@ -3,12 +3,20 @@ FROM golang:1.21-alpine AS builder
3
 
4
  WORKDIR /app
5
 
6
- # 复制go.mod和go.sum文件并下载依赖项
7
  COPY go.mod ./
 
 
 
 
 
 
8
  COPY main.go ./
 
 
9
  RUN go mod tidy
10
 
11
- # 构建应用 (在同一个RUN指令中,这样go.sum会被找到)
12
  RUN CGO_ENABLED=0 GOOS=linux go build -o /go-api
13
 
14
  # 阶段 2: 运行
 
3
 
4
  WORKDIR /app
5
 
6
+ # 复制go.mod和go.sum文件
7
  COPY go.mod ./
8
+
9
+ # (此步骤在下一步的 go mod tidy 中已包含,为保持清晰而保留)
10
+ # 先下载依赖,可以利用Docker的层缓存
11
+ RUN go mod download
12
+
13
+ # 复制源代码
14
  COPY main.go ./
15
+
16
+ # tidy会确保go.sum文件是最新的,并移除不用的依赖
17
  RUN go mod tidy
18
 
19
+ # 构建应用
20
  RUN CGO_ENABLED=0 GOOS=linux go build -o /go-api
21
 
22
  # 阶段 2: 运行
main.go CHANGED
@@ -5,7 +5,6 @@ import (
5
  "context"
6
  "encoding/json"
7
  "fmt"
8
- "io"
9
  "log"
10
  "math/rand"
11
  "net/http"
@@ -61,12 +60,11 @@ var supportedModels = []ModelInfo{
61
  }
62
 
63
  // 将OpenAI模型名称映射到Gemini模型名称
64
- // 根据用户要求,键和值现在是相同的。
65
  var modelMapping = map[string]string{
66
- "gemini-2.5-flash-preview-05-20": "gemini-2.5-flash-preview-05-20",
67
- "gemini-2.5-flash": "gemini-2.5-flash",
68
  "gemini-1.5-pro-latest": "gemini-1.5-pro-latest",
69
- "gemini-2.5-pro": "gemini-2.5-pro",
70
  }
71
 
72
  // 配置安全设置 (全部禁用)
@@ -93,13 +91,11 @@ const maxRetries = 3
93
 
94
  // --- 数据结构 (用于JSON序列化/反序列化) ---
95
 
96
- // OpenAI格式的聊天消息
97
  type ChatMessage struct {
98
  Role string `json:"role"`
99
  Content string `json:"content"`
100
  }
101
 
102
- // OpenAI格式的聊天请求
103
  type ChatCompletionRequest struct {
104
  Model string `json:"model"`
105
  Messages []ChatMessage `json:"messages"`
@@ -109,7 +105,6 @@ type ChatCompletionRequest struct {
109
  TopP float32 `json:"top_p,omitempty"`
110
  }
111
 
112
- // OpenAI格式的标准聊天响应
113
  type ChatCompletionResponse struct {
114
  ID string `json:"id"`
115
  Object string `json:"object"`
@@ -131,7 +126,6 @@ type Usage struct {
131
  TotalTokens int `json:"total_tokens"`
132
  }
133
 
134
- // OpenAI格式的流式聊天响应
135
  type ChatCompletionStreamResponse struct {
136
  ID string `json:"id"`
137
  Object string `json:"object"`
@@ -146,7 +140,6 @@ type StreamChoice struct {
146
  FinishReason *string `json:"finish_reason,omitempty"`
147
  }
148
 
149
- // 模型信息结构
150
  type ModelInfo struct {
151
  ID string `json:"id"`
152
  Object string `json:"object"`
@@ -162,7 +155,6 @@ type ModelListResponse struct {
162
 
163
  // --- 核心逻辑 ---
164
 
165
- // 获取一个随机的API密钥
166
  func getRandomAPIKey() string {
167
  if len(apiKeys) == 0 {
168
  log.Fatal("API密钥列表为空,请在 `apiKeys` 变量中配置密钥。")
@@ -171,33 +163,38 @@ func getRandomAPIKey() string {
171
  return apiKeys[r.Intn(len(apiKeys))]
172
  }
173
 
174
- // 将OpenAI格式的消息转换为Gemini格式
175
- func convertMessages(messages []ChatMessage) ([]*genai.Content, *genai.Content) {
176
- var geminiContents []*genai.Content
177
- var systemInstruction *genai.Content
 
178
 
179
- for _, msg := range messages {
180
  var role string
181
- if msg.Role == "user" {
182
- role = "user"
183
- } else if msg.Role == "assistant" {
184
- role = "model"
185
- } else if msg.Role == "system" {
186
- // 将系统指令分开处理
187
  systemInstruction = &genai.Content{Parts: []genai.Part{genai.Text(msg.Content)}}
188
- continue // 系统指令不包含在主要内容中
 
 
 
 
 
 
 
 
 
189
  } else {
190
- role = "user" // 默认为用户
191
  }
192
- geminiContents = append(geminiContents, &genai.Content{
 
193
  Role: role,
194
  Parts: []genai.Part{genai.Text(msg.Content)},
195
  })
196
  }
197
- return geminiContents, systemInstruction
198
  }
199
 
200
- // chatCompletionsHandler 处理聊天请求
201
  func chatCompletionsHandler(w http.ResponseWriter, r *http.Request) {
202
  if r.Method != http.MethodPost {
203
  http.Error(w, "仅支持POST方法", http.StatusMethodNotAllowed)
@@ -210,18 +207,12 @@ func chatCompletionsHandler(w http.ResponseWriter, r *http.Request) {
210
  return
211
  }
212
 
213
- // 映射模型名称
214
  modelName, ok := modelMapping[req.Model]
215
  if !ok {
216
- // 如果在映射中找不到,则直接使用请求的模型名称,
217
- // 并选择一个默认的最新模型作为备用。
218
- modelName = req.Model
219
- log.Printf("警告: 模型 '%s' 不在预定义的映射中。将直接使用该名称。", req.Model)
220
  }
221
 
222
-
223
- // 转换消息格式
224
- contents, systemInstruction := convertMessages(req.Messages)
225
 
226
  var lastErr error
227
  usedKeys := make(map[string]bool)
@@ -230,7 +221,6 @@ func chatCompletionsHandler(w http.ResponseWriter, r *http.Request) {
230
  ctx := context.Background()
231
  apiKey := getRandomAPIKey()
232
 
233
- // 确保在一次重试中不使用重复的密钥
234
  if len(usedKeys) < len(apiKeys) {
235
  for usedKeys[apiKey] {
236
  apiKey = getRandomAPIKey()
@@ -256,32 +246,34 @@ func chatCompletionsHandler(w http.ResponseWriter, r *http.Request) {
256
  if req.MaxTokens > 0 {
257
  model.SetMaxOutputTokens(req.MaxTokens)
258
  }
 
 
 
259
 
260
  if req.Stream {
261
- err = handleStream(w, ctx, model, contents, req.Model)
262
  } else {
263
- err = handleNonStream(w, ctx, model, contents, req.Model)
264
  }
265
 
266
  if err == nil {
267
- return // 成功处理
268
  }
269
 
270
  lastErr = err
271
  log.Printf("第 %d 次尝试失败: %v", i+1, err)
272
- time.Sleep(1 * time.Second) // 等待1秒后重试
273
  }
274
 
275
  http.Error(w, fmt.Sprintf("所有重试均失败: %v", lastErr), http.StatusInternalServerError)
276
  }
277
 
278
- // handleStream 处理流式响应
279
- func handleStream(w http.ResponseWriter, ctx context.Context, model *genai.GenerativeModel, contents []*genai.Content, modelID string) error {
280
  w.Header().Set("Content-Type", "text/event-stream")
281
  w.Header().Set("Cache-Control", "no-cache")
282
  w.Header().Set("Connection", "keep-alive")
283
 
284
- iter := model.GenerateContentStream(ctx, contents...)
285
  for {
286
  resp, err := iter.Next()
287
  if err == iterator.Done {
@@ -297,7 +289,7 @@ func handleStream(w http.ResponseWriter, ctx context.Context, model *genai.Gener
297
  contentBuilder.WriteString(string(txt))
298
  }
299
  }
300
-
301
  chunk := ChatCompletionStreamResponse{
302
  ID: fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
303
  Object: "chat.completion.chunk",
@@ -325,7 +317,6 @@ func handleStream(w http.ResponseWriter, ctx context.Context, model *genai.Gener
325
  }
326
  }
327
 
328
- // 发送结束标志
329
  finishReason := "stop"
330
  doneChunk := ChatCompletionStreamResponse{
331
  ID: fmt.Sprintf("chatcmpl-%d-done", time.Now().Unix()),
@@ -350,23 +341,38 @@ func handleStream(w http.ResponseWriter, ctx context.Context, model *genai.Gener
350
  return nil
351
  }
352
 
353
- // handleNonStream 处理非流式响应
354
- func handleNonStream(w http.ResponseWriter, ctx context.Context, model *genai.GenerativeModel, contents []*genai.Content, modelID string) error {
355
- resp, err := model.GenerateContent(ctx, contents...)
356
  if err != nil {
357
  return fmt.Errorf("生成内容失败: %v", err)
358
  }
359
 
360
  var contentBuilder strings.Builder
361
- for _, part := range resp.Candidates[0].Content.Parts {
362
- if txt, ok := part.(genai.Text); ok {
363
- contentBuilder.WriteString(string(txt))
 
 
364
  }
365
  }
 
 
 
 
 
 
 
366
 
367
- promptTokens := int(model.CountTokens(ctx, contents...).TotalTokens)
368
- completionTokens := int(model.CountTokens(ctx, resp.Candidates[0].Content).TotalTokens)
 
 
369
 
 
 
 
 
 
370
  response := ChatCompletionResponse{
371
  ID: fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
372
  Object: "chat.completion",
@@ -383,9 +389,9 @@ func handleNonStream(w http.ResponseWriter, ctx context.Context, model *genai.Ge
383
  },
384
  },
385
  Usage: Usage{
386
- PromptTokens: promptTokens,
387
- CompletionTokens: completionTokens,
388
- TotalTokens: promptTokens + completionTokens,
389
  },
390
  }
391
 
@@ -393,7 +399,6 @@ func handleNonStream(w http.ResponseWriter, ctx context.Context, model *genai.Ge
393
  return json.NewEncoder(w).Encode(response)
394
  }
395
 
396
- // --- 辅助端点 ---
397
 
398
  func modelsHandler(w http.ResponseWriter, r *http.Request) {
399
  resp := ModelListResponse{
@@ -435,8 +440,6 @@ func healthHandler(w http.ResponseWriter, r *http.Request) {
435
  json.NewEncoder(w).Encode(health)
436
  }
437
 
438
- // --- Main函数 ---
439
-
440
  func main() {
441
  mux := http.NewServeMux()
442
 
@@ -444,10 +447,8 @@ func main() {
444
  mux.HandleFunc("/health", healthHandler)
445
  mux.HandleFunc("/v1/models", modelsHandler)
446
  mux.HandleFunc("/v1/chat/completions", chatCompletionsHandler)
447
- // 添加兼容路径
448
  mux.HandleFunc("/v1/chat/completions/v1/models", modelsHandler)
449
 
450
- // 配置CORS
451
  c := cors.New(cors.Options{
452
  AllowedOrigins: []string{"*"},
453
  AllowedMethods: []string{"GET", "POST", "OPTIONS"},
@@ -469,14 +470,12 @@ func main() {
469
  log.Println("🔄 支持自动重试和密钥轮换")
470
  log.Printf("🔗 服务器正在监听 http://0.0.0.0:%s", port)
471
 
472
- // 从环境变量中读取密钥
473
  envKey := os.Getenv("GEMINI_API_KEY")
474
  if envKey != "" {
475
  apiKeys = strings.Split(envKey, ",")
476
  log.Printf("从环境变量 GEMINI_API_KEY 加载了 %d 个密钥", len(apiKeys))
477
  }
478
 
479
-
480
  if err := http.ListenAndServe(":"+port, handler); err != nil {
481
  log.Fatalf("启动服务器失败: %v", err)
482
  }
 
5
  "context"
6
  "encoding/json"
7
  "fmt"
 
8
  "log"
9
  "math/rand"
10
  "net/http"
 
60
  }
61
 
62
  // 将OpenAI模型名称映射到Gemini模型名称
 
63
  var modelMapping = map[string]string{
64
+ "gemini-2.5-flash-preview-05-20": "gemini-1.5-flash-latest",
65
+ "gemini-2.5-flash": "gemini-1.5-flash-latest",
66
  "gemini-1.5-pro-latest": "gemini-1.5-pro-latest",
67
+ "gemini-2.5-pro": "gemini-1.5-pro-latest",
68
  }
69
 
70
  // 配置安全设置 (全部禁用)
 
91
 
92
  // --- 数据结构 (用于JSON序列化/反序列化) ---
93
 
 
94
  type ChatMessage struct {
95
  Role string `json:"role"`
96
  Content string `json:"content"`
97
  }
98
 
 
99
  type ChatCompletionRequest struct {
100
  Model string `json:"model"`
101
  Messages []ChatMessage `json:"messages"`
 
105
  TopP float32 `json:"top_p,omitempty"`
106
  }
107
 
 
108
  type ChatCompletionResponse struct {
109
  ID string `json:"id"`
110
  Object string `json:"object"`
 
126
  TotalTokens int `json:"total_tokens"`
127
  }
128
 
 
129
  type ChatCompletionStreamResponse struct {
130
  ID string `json:"id"`
131
  Object string `json:"object"`
 
140
  FinishReason *string `json:"finish_reason,omitempty"`
141
  }
142
 
 
143
  type ModelInfo struct {
144
  ID string `json:"id"`
145
  Object string `json:"object"`
 
155
 
156
  // --- 核心逻辑 ---
157
 
 
158
  func getRandomAPIKey() string {
159
  if len(apiKeys) == 0 {
160
  log.Fatal("API密钥列表为空,请在 `apiKeys` 变量中配置密钥。")
 
163
  return apiKeys[r.Intn(len(apiKeys))]
164
  }
165
 
166
+ // convertMessages 将OpenAI格式的消息转换为Gemini格式的历史记录和最后一个用户的提示
167
+ func convertMessages(messages []ChatMessage) (history []*genai.Content, lastPrompt []genai.Part, systemInstruction *genai.Content) {
168
+ if len(messages) == 0 {
169
+ return nil, nil, nil
170
+ }
171
 
172
+ for i, msg := range messages {
173
  var role string
174
+ if msg.Role == "system" {
 
 
 
 
 
175
  systemInstruction = &genai.Content{Parts: []genai.Part{genai.Text(msg.Content)}}
176
+ continue
177
+ }
178
+
179
+ if i == len(messages)-1 && msg.Role == "user" {
180
+ lastPrompt = append(lastPrompt, genai.Text(msg.Content))
181
+ continue
182
+ }
183
+
184
+ if msg.Role == "assistant" {
185
+ role = "model"
186
  } else {
187
+ role = "user"
188
  }
189
+
190
+ history = append(history, &genai.Content{
191
  Role: role,
192
  Parts: []genai.Part{genai.Text(msg.Content)},
193
  })
194
  }
195
+ return history, lastPrompt, systemInstruction
196
  }
197
 
 
198
  func chatCompletionsHandler(w http.ResponseWriter, r *http.Request) {
199
  if r.Method != http.MethodPost {
200
  http.Error(w, "仅支持POST方法", http.StatusMethodNotAllowed)
 
207
  return
208
  }
209
 
 
210
  modelName, ok := modelMapping[req.Model]
211
  if !ok {
212
+ modelName = "gemini-1.5-flash-latest" // 默认模型
 
 
 
213
  }
214
 
215
+ history, lastPrompt, systemInstruction := convertMessages(req.Messages)
 
 
216
 
217
  var lastErr error
218
  usedKeys := make(map[string]bool)
 
221
  ctx := context.Background()
222
  apiKey := getRandomAPIKey()
223
 
 
224
  if len(usedKeys) < len(apiKeys) {
225
  for usedKeys[apiKey] {
226
  apiKey = getRandomAPIKey()
 
246
  if req.MaxTokens > 0 {
247
  model.SetMaxOutputTokens(req.MaxTokens)
248
  }
249
+
250
+ chat := model.StartChat()
251
+ chat.History = history
252
 
253
  if req.Stream {
254
+ err = handleStream(w, ctx, chat, lastPrompt, req.Model)
255
  } else {
256
+ err = handleNonStream(w, ctx, model, chat, lastPrompt, req.Model)
257
  }
258
 
259
  if err == nil {
260
+ return
261
  }
262
 
263
  lastErr = err
264
  log.Printf("第 %d 次尝试失败: %v", i+1, err)
265
+ time.Sleep(1 * time.Second)
266
  }
267
 
268
  http.Error(w, fmt.Sprintf("所有重试均失败: %v", lastErr), http.StatusInternalServerError)
269
  }
270
 
271
+ func handleStream(w http.ResponseWriter, ctx context.Context, chat *genai.ChatSession, prompt []genai.Part, modelID string) error {
 
272
  w.Header().Set("Content-Type", "text/event-stream")
273
  w.Header().Set("Cache-Control", "no-cache")
274
  w.Header().Set("Connection", "keep-alive")
275
 
276
+ iter := chat.SendMessageStream(ctx, prompt...)
277
  for {
278
  resp, err := iter.Next()
279
  if err == iterator.Done {
 
289
  contentBuilder.WriteString(string(txt))
290
  }
291
  }
292
+
293
  chunk := ChatCompletionStreamResponse{
294
  ID: fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
295
  Object: "chat.completion.chunk",
 
317
  }
318
  }
319
 
 
320
  finishReason := "stop"
321
  doneChunk := ChatCompletionStreamResponse{
322
  ID: fmt.Sprintf("chatcmpl-%d-done", time.Now().Unix()),
 
341
  return nil
342
  }
343
 
344
+ func handleNonStream(w http.ResponseWriter, ctx context.Context, model *genai.GenerativeModel, chat *genai.ChatSession, prompt []genai.Part, modelID string) error {
345
+ resp, err := chat.SendMessage(ctx, prompt...)
 
346
  if err != nil {
347
  return fmt.Errorf("生成内容失败: %v", err)
348
  }
349
 
350
  var contentBuilder strings.Builder
351
+ if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
352
+ for _, part := range resp.Candidates[0].Content.Parts {
353
+ if txt, ok := part.(genai.Text); ok {
354
+ contentBuilder.WriteString(string(txt))
355
+ }
356
  }
357
  }
358
+
359
+ // 计算Token
360
+ var promptParts []genai.Part
361
+ for _, c := range chat.History {
362
+ promptParts = append(promptParts, c.Parts...)
363
+ }
364
+ promptParts = append(promptParts, prompt...)
365
 
366
+ promptTokenCount, err := model.CountTokens(ctx, promptParts...)
367
+ if err != nil {
368
+ return fmt.Errorf("计算prompt tokens失败: %v", err)
369
+ }
370
 
371
+ completionTokenCount, err := model.CountTokens(ctx, resp.Candidates[0].Content.Parts...)
372
+ if err != nil {
373
+ return fmt.Errorf("计算completion tokens失败: %v", err)
374
+ }
375
+
376
  response := ChatCompletionResponse{
377
  ID: fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
378
  Object: "chat.completion",
 
389
  },
390
  },
391
  Usage: Usage{
392
+ PromptTokens: int(promptTokenCount.TotalTokens),
393
+ CompletionTokens: int(completionTokenCount.TotalTokens),
394
+ TotalTokens: int(promptTokenCount.TotalTokens) + int(completionTokenCount.TotalTokens),
395
  },
396
  }
397
 
 
399
  return json.NewEncoder(w).Encode(response)
400
  }
401
 
 
402
 
403
  func modelsHandler(w http.ResponseWriter, r *http.Request) {
404
  resp := ModelListResponse{
 
440
  json.NewEncoder(w).Encode(health)
441
  }
442
 
 
 
443
  func main() {
444
  mux := http.NewServeMux()
445
 
 
447
  mux.HandleFunc("/health", healthHandler)
448
  mux.HandleFunc("/v1/models", modelsHandler)
449
  mux.HandleFunc("/v1/chat/completions", chatCompletionsHandler)
 
450
  mux.HandleFunc("/v1/chat/completions/v1/models", modelsHandler)
451
 
 
452
  c := cors.New(cors.Options{
453
  AllowedOrigins: []string{"*"},
454
  AllowedMethods: []string{"GET", "POST", "OPTIONS"},
 
470
  log.Println("🔄 支持自动重试和密钥轮换")
471
  log.Printf("🔗 服务器正在监听 http://0.0.0.0:%s", port)
472
 
 
473
  envKey := os.Getenv("GEMINI_API_KEY")
474
  if envKey != "" {
475
  apiKeys = strings.Split(envKey, ",")
476
  log.Printf("从环境变量 GEMINI_API_KEY 加载了 %d 个密钥", len(apiKeys))
477
  }
478
 
 
479
  if err := http.ListenAndServe(":"+port, handler); err != nil {
480
  log.Fatalf("启动服务器失败: %v", err)
481
  }