BOHE commited on
Commit
dfd67c1
·
unverified ·
2 Parent(s): 8a85dd8 4c1a6c0

Merge pull request #7 from CurtainTears/main

Browse files
Files changed (1) hide show
  1. api/main.go +57 -2
api/main.go CHANGED
@@ -17,9 +17,41 @@ import (
17
  "github.com/google/uuid"
18
  )
19
 
 
 
 
20
  func init() {
21
  // 初始化随机数生成器
22
  rand.Seed(time.Now().UnixNano())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  }
24
 
25
  // TokenCount 定义了 token 计数的结构
@@ -206,6 +238,16 @@ func Handler(w http.ResponseWriter, r *http.Request) {
206
  })
207
  }
208
 
 
 
 
 
 
 
 
 
 
 
209
  response := ModelResponse{
210
  Object: "list",
211
  Data: models,
@@ -460,12 +502,25 @@ func Handler(w http.ResponseWriter, r *http.Request) {
460
  q.Add("chatId", chatId)
461
  q.Add("conversationTurnId", conversationTurnId)
462
  q.Add("pastChatLength", fmt.Sprintf("%d", len(chatHistory)))
463
- q.Add("selectedChatMode", "custom")
464
- q.Add("selectedAiModel", mapModelName(openAIReq.Model))
465
  q.Add("enable_agent_clarification_questions", "true")
466
  q.Add("traceId", traceId)
467
  q.Add("use_nested_youchat_updates", "true")
468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  // 如果最后一条消息超过限制,使用文件上传
470
  if lastMessageTokens > MaxContextTokens {
471
  // 获取 nonce - 不再需要nonce
 
17
  "github.com/google/uuid"
18
  )
19
 
20
+ // 新增:存储agent模型ID的全局变量
21
+ var agentModelIDs []string
22
+
23
  func init() {
24
  // 初始化随机数生成器
25
  rand.Seed(time.Now().UnixNano())
26
+
27
+ // 新增:初始化agent模型ID
28
+ initAgentModelIDs()
29
+ }
30
+
31
+ // 新增:初始化函数读取环境变量中的agent模型ID
32
+ func initAgentModelIDs() {
33
+ agentModelIDsStr := os.Getenv("AGENT_MODEL_IDS")
34
+ if agentModelIDsStr != "" {
35
+ // 分割字符串,获取所有agent模型ID
36
+ agentModelIDs = strings.Split(agentModelIDsStr, ",")
37
+ // 去除每个ID的空白字符
38
+ for i := range agentModelIDs {
39
+ agentModelIDs[i] = strings.TrimSpace(agentModelIDs[i])
40
+ }
41
+ fmt.Printf("已加载 %d 个Agent模型ID: %v\n", len(agentModelIDs), agentModelIDs)
42
+ } else {
43
+ fmt.Println("未设置Agent模型ID环境变量,仅使用默认模型")
44
+ }
45
+ }
46
+
47
+ // 新增:检查模型是否为agent模型
48
+ func isAgentModel(modelID string) bool {
49
+ for _, id := range agentModelIDs {
50
+ if id == modelID {
51
+ return true
52
+ }
53
+ }
54
+ return false
55
  }
56
 
57
  // TokenCount 定义了 token 计数的结构
 
238
  })
239
  }
240
 
241
+ // 新增:添加agent模型到模型列表
242
+ for _, agentID := range agentModelIDs {
243
+ models = append(models, ModelDetail{
244
+ ID: agentID,
245
+ Object: "model",
246
+ Created: created,
247
+ OwnedBy: "organization-owner",
248
+ })
249
+ }
250
+
251
  response := ModelResponse{
252
  Object: "list",
253
  Data: models,
 
502
  q.Add("chatId", chatId)
503
  q.Add("conversationTurnId", conversationTurnId)
504
  q.Add("pastChatLength", fmt.Sprintf("%d", len(chatHistory)))
505
+ //q.Add("selectedChatMode", "custom")
506
+ //q.Add("selectedAiModel", mapModelName(openAIReq.Model))
507
  q.Add("enable_agent_clarification_questions", "true")
508
  q.Add("traceId", traceId)
509
  q.Add("use_nested_youchat_updates", "true")
510
 
511
+ // 新增:根据模型类型设置不同的参数
512
+ isAgent := isAgentModel(openAIReq.Model)
513
+ if isAgent {
514
+ // 新增:Agent模型: 只使用selectedChatMode=agent模型ID
515
+ fmt.Printf("使用Agent模型: %s\n", openAIReq.Model)
516
+ q.Add("selectedChatMode", openAIReq.Model) // 修改:直接使用模型ID作为chatMode
517
+ } else {
518
+ // 修改:默认模型: 使用selectedAiModel和selectedChatMode=custom
519
+ fmt.Printf("使用默认模型: %s (映射为: %s)\n", openAIReq.Model, mapModelName(openAIReq.Model))
520
+ q.Add("selectedAiModel", mapModelName(openAIReq.Model))
521
+ q.Add("selectedChatMode", "custom")
522
+ }
523
+
524
  // 如果最后一条消息超过限制,使用文件上传
525
  if lastMessageTokens > MaxContextTokens {
526
  // 获取 nonce - 不再需要nonce