Spaces:
Configuration error
Configuration error
Merge pull request #7 from CurtainTears/main
Browse files- api/main.go +57 -2
api/main.go
CHANGED
|
@@ -17,9 +17,41 @@ import (
|
|
| 17 |
"github.com/google/uuid"
|
| 18 |
)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
func init() {
|
| 21 |
// 初始化随机数生成器
|
| 22 |
rand.Seed(time.Now().UnixNano())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
}
|
| 24 |
|
| 25 |
// TokenCount 定义了 token 计数的结构
|
|
@@ -206,6 +238,16 @@ func Handler(w http.ResponseWriter, r *http.Request) {
|
|
| 206 |
})
|
| 207 |
}
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
response := ModelResponse{
|
| 210 |
Object: "list",
|
| 211 |
Data: models,
|
|
@@ -460,12 +502,25 @@ func Handler(w http.ResponseWriter, r *http.Request) {
|
|
| 460 |
q.Add("chatId", chatId)
|
| 461 |
q.Add("conversationTurnId", conversationTurnId)
|
| 462 |
q.Add("pastChatLength", fmt.Sprintf("%d", len(chatHistory)))
|
| 463 |
-
q.Add("selectedChatMode", "custom")
|
| 464 |
-
q.Add("selectedAiModel", mapModelName(openAIReq.Model))
|
| 465 |
q.Add("enable_agent_clarification_questions", "true")
|
| 466 |
q.Add("traceId", traceId)
|
| 467 |
q.Add("use_nested_youchat_updates", "true")
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
// 如果最后一条消息超过限制,使用文件上传
|
| 470 |
if lastMessageTokens > MaxContextTokens {
|
| 471 |
// 获取 nonce - 不再需要nonce
|
|
|
|
| 17 |
"github.com/google/uuid"
|
| 18 |
)
|
| 19 |
|
| 20 |
+
// 新增:存储agent模型ID的全局变量
|
| 21 |
+
var agentModelIDs []string
|
| 22 |
+
|
| 23 |
func init() {
|
| 24 |
// 初始化随机数生成器
|
| 25 |
rand.Seed(time.Now().UnixNano())
|
| 26 |
+
|
| 27 |
+
// 新增:初始化agent模型ID
|
| 28 |
+
initAgentModelIDs()
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
// 新增:初始化函数读取环境变量中的agent模型ID
|
| 32 |
+
func initAgentModelIDs() {
|
| 33 |
+
agentModelIDsStr := os.Getenv("AGENT_MODEL_IDS")
|
| 34 |
+
if agentModelIDsStr != "" {
|
| 35 |
+
// 分割字符串,获取所有agent模型ID
|
| 36 |
+
agentModelIDs = strings.Split(agentModelIDsStr, ",")
|
| 37 |
+
// 去除每个ID的空白字符
|
| 38 |
+
for i := range agentModelIDs {
|
| 39 |
+
agentModelIDs[i] = strings.TrimSpace(agentModelIDs[i])
|
| 40 |
+
}
|
| 41 |
+
fmt.Printf("已加载 %d 个Agent模型ID: %v\n", len(agentModelIDs), agentModelIDs)
|
| 42 |
+
} else {
|
| 43 |
+
fmt.Println("未设置Agent模型ID环境变量,仅使用默认模型")
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
// 新增:检查模型是否为agent模型
|
| 48 |
+
func isAgentModel(modelID string) bool {
|
| 49 |
+
for _, id := range agentModelIDs {
|
| 50 |
+
if id == modelID {
|
| 51 |
+
return true
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
return false
|
| 55 |
}
|
| 56 |
|
| 57 |
// TokenCount 定义了 token 计数的结构
|
|
|
|
| 238 |
})
|
| 239 |
}
|
| 240 |
|
| 241 |
+
// 新增:添加agent模型到模型列表
|
| 242 |
+
for _, agentID := range agentModelIDs {
|
| 243 |
+
models = append(models, ModelDetail{
|
| 244 |
+
ID: agentID,
|
| 245 |
+
Object: "model",
|
| 246 |
+
Created: created,
|
| 247 |
+
OwnedBy: "organization-owner",
|
| 248 |
+
})
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
response := ModelResponse{
|
| 252 |
Object: "list",
|
| 253 |
Data: models,
|
|
|
|
| 502 |
q.Add("chatId", chatId)
|
| 503 |
q.Add("conversationTurnId", conversationTurnId)
|
| 504 |
q.Add("pastChatLength", fmt.Sprintf("%d", len(chatHistory)))
|
| 505 |
+
//q.Add("selectedChatMode", "custom")
|
| 506 |
+
//q.Add("selectedAiModel", mapModelName(openAIReq.Model))
|
| 507 |
q.Add("enable_agent_clarification_questions", "true")
|
| 508 |
q.Add("traceId", traceId)
|
| 509 |
q.Add("use_nested_youchat_updates", "true")
|
| 510 |
|
| 511 |
+
// 新增:根据模型类型设置不同的参数
|
| 512 |
+
isAgent := isAgentModel(openAIReq.Model)
|
| 513 |
+
if isAgent {
|
| 514 |
+
// 新增:Agent模型: 只使用selectedChatMode=agent模型ID
|
| 515 |
+
fmt.Printf("使用Agent模型: %s\n", openAIReq.Model)
|
| 516 |
+
q.Add("selectedChatMode", openAIReq.Model) // 修改:直接使用模型ID作为chatMode
|
| 517 |
+
} else {
|
| 518 |
+
// 修改:默认模型: 使用selectedAiModel和selectedChatMode=custom
|
| 519 |
+
fmt.Printf("使用默认模型: %s (映射为: %s)\n", openAIReq.Model, mapModelName(openAIReq.Model))
|
| 520 |
+
q.Add("selectedAiModel", mapModelName(openAIReq.Model))
|
| 521 |
+
q.Add("selectedChatMode", "custom")
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
// 如果最后一条消息超过限制,使用文件上传
|
| 525 |
if lastMessageTokens > MaxContextTokens {
|
| 526 |
// 获取 nonce - 不再需要nonce
|