diff --git a/.env.example b/.env.example index 36ba42bd8666422b12c7d4003164353a7bd4ff36..07602eca08e83e9aa5d422a4f10c16951ed3e75d 100644 --- a/.env.example +++ b/.env.example @@ -10,9 +10,9 @@ # 数据库相关配置 # 数据库连接字符串 -# SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true +# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true # 日志数据库连接字符串 -# LOG_SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true +# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true # SQLite数据库路径 # SQLITE_PATH=/path/to/sqlite.db # 数据库最大空闲连接数 diff --git a/README.en.md b/README.en.md index feb4b0bb46f9167aa633168d37ccade2e7f29380..446c88f61455c2fbcbbbb0e4358214ae5d217e30 100644 --- a/README.en.md +++ b/README.en.md @@ -89,6 +89,8 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model - `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20` - `CRYPTO_SECRET`: Encryption key for encrypting database content - `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview` +- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10` +- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2` ## Deployment diff --git a/Rerank.md b/Rerank.md index 6a07287eff221bbe2a3259fb0d579341af7cd66a..dc57d99bb0b7020a31eb12f3fa8c5d96012a3ab9 100644 --- a/Rerank.md +++ b/Rerank.md @@ -13,7 +13,7 @@ Request: ```json { - "model": "rerank-multilingual-v3.0", + "model": "jina-reranker-v2-base-multilingual", "query": "What is the capital of the United States?", "top_n": 3, "documents": [ diff --git a/VERSION b/VERSION index c302e31b8afc35c818fbba2f2d1a927f255a5861..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/VERSION +++ b/VERSION @@ -1 +0,0 @@ -v0.4.7.2.1 \ No newline at end of file diff --git a/common/constants.go b/common/constants.go index f967d066e29cd04796272fc57cf53d2ea5d69c40..04fb1b9a632852f9b646c0dda946a1eaef386e95 100644 --- a/common/constants.go +++ b/common/constants.go @@ -101,7 +101,7 @@ var PreConsumedQuota = 500 var RetryTimes = 0 -var RootUserEmail = "" +//var RootUserEmail = "" var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" diff --git a/common/go-channel.go b/common/go-channel.go index 3fc86537255a3e26e2ae110a07d49eb15b7f8200..f9168fc4674e5f53e8168663da366aaa10649bb9 100644 --- a/common/go-channel.go +++ b/common/go-channel.go @@ -1,22 +1,9 @@ package common import ( - "fmt" - "runtime/debug" "time" ) -func SafeGoroutine(f func()) { - go func() { - defer func() { - if r := recover(); r != nil { - SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack()))) - } - }() - f() - }() -} - func SafeSendBool(ch chan bool, value bool) (closed bool) { defer func() { // Recover from panic if one occured. A panic would mean the channel was closed. diff --git a/common/logger.go b/common/logger.go index 93d557d8c6eebf09963e5728d8e345e844b280ab..86d15fa4db7ac46a96abe1b2dfeda9839978ed2e 100644 --- a/common/logger.go +++ b/common/logger.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" "io" "log" @@ -80,9 +81,9 @@ func logHelper(ctx context.Context, level string, msg string) { if logCount > maxLogCount && !setupLogWorking { logCount = 0 setupLogWorking = true - go func() { + gopool.Go(func() { SetupLogger() - }() + }) } } @@ -100,6 +101,14 @@ func LogQuota(quota int) string { } } +func FormatQuota(quota int) string { + if DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit) + } else { + return fmt.Sprintf("%d", quota) + } +} + // LogJson 仅供测试使用 only for test func LogJson(ctx context.Context, msg string, obj any) { jsonStr, err := json.Marshal(obj) diff --git a/common/model-ratio.go b/common/model-ratio.go index bb94ad36c17293a9647710b079487d3ee448d7cb..542cd93c6965310e95da0c7e045d1b26da95c3ae 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -233,7 +233,11 @@ var ( modelRatioMapMutex = sync.RWMutex{} ) -var CompletionRatio map[string]float64 = nil +var ( + CompletionRatio map[string]float64 = nil + CompletionRatioMutex = sync.RWMutex{} +) + var defaultCompletionRatio = map[string]float64{ "gpt-4-gizmo-*": 2, "gpt-4o-gizmo-*": 3, @@ -334,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 { return defaultModelRatio } -func CompletionRatio2JSONString() string { +func GetCompletionRatioMap() map[string]float64 { + CompletionRatioMutex.Lock() + defer CompletionRatioMutex.Unlock() if CompletionRatio == nil { CompletionRatio = defaultCompletionRatio } + return CompletionRatio +} + +func CompletionRatio2JSONString() string { + GetCompletionRatioMap() jsonBytes, err := json.Marshal(CompletionRatio) if err != nil { SysError("error marshalling completion ratio: " + err.Error()) @@ -346,11 +357,15 @@ func CompletionRatio2JSONString() string { } func UpdateCompletionRatioByJSONString(jsonStr string) error { + CompletionRatioMutex.Lock() + defer CompletionRatioMutex.Unlock() CompletionRatio = make(map[string]float64) return json.Unmarshal([]byte(jsonStr), &CompletionRatio) } func GetCompletionRatio(name string) float64 { + GetCompletionRatioMap() + if strings.Contains(name, "/") { if ratio, ok := CompletionRatio[name]; ok { return ratio @@ -476,24 +491,3 @@ func GetAudioCompletionRatio(name string) float64 { } return 2 } - -//func GetAudioPricePerMinute(name string) float64 { -// if strings.HasPrefix(name, "gpt-4o-realtime") { -// return 0.06 -// } -// return 0.06 -//} -// -//func GetAudioCompletionPricePerMinute(name string) float64 { -// if strings.HasPrefix(name, "gpt-4o-realtime") { -// return 0.24 -// } -// return 0.24 -//} - -func GetCompletionRatioMap() map[string]float64 { - if CompletionRatio == nil { - CompletionRatio = defaultCompletionRatio - } - return CompletionRatio -} diff --git a/constant/env.go b/constant/env.go index 4135e8c7c3659d2a3487fade91da20db8330f073..bffbfeea5ba1efbafdd20e176d5d0dfe2402c021 100644 --- a/constant/env.go +++ b/constant/env.go @@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{ var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) +var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) +var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) + func InitEnv() { modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) if modelVersionMapStr == "" { @@ -44,5 +47,5 @@ func InitEnv() { } } -// 是否生成初始令牌,默认关闭。 +// GenerateDefaultToken 是否生成初始令牌,默认关闭。 var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false) diff --git a/constant/user_setting.go b/constant/user_setting.go new file mode 100644 index 0000000000000000000000000000000000000000..a5b921b2ffbcf24c0f72a309df49ce280d8a466f --- /dev/null +++ b/constant/user_setting.go @@ -0,0 +1,14 @@ +package constant + +var ( + UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型 + UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值 + UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址 + UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥 + UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址 +) + +var ( + NotifyTypeEmail = "email" // Email 邮件 + NotifyTypeWebhook = "webhook" // Webhook +) diff --git a/controller/channel-test.go b/controller/channel-test.go index 7e74bec23dbbf2a050e363ace4020cd313f00909..4b0cc169cb05c1d71b0443eff84ebe2681786c5b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -238,9 +238,7 @@ var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false func testAllChannels(notify bool) error { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() - } + testAllChannelsLock.Lock() if testAllChannelsRunning { testAllChannelsLock.Unlock() @@ -295,10 +293,7 @@ func testAllChannels(notify bool) error { testAllChannelsRunning = false testAllChannelsLock.Unlock() if notify { - err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") - if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) - } + service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成") } }) return nil diff --git a/controller/pricing.go b/controller/pricing.go index 36caff9d1dbe3451c3e64ee370511e0fb6afc055..d7af5a4c8341308e7d5f4347ba0a405426e509ba 100644 --- a/controller/pricing.go +++ b/controller/pricing.go @@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) { } var group string if exists { - user, err := model.GetUserById(userId.(int), false) + user, err := model.GetUserCache(userId.(int)) if err == nil { group = user.Group } diff --git a/controller/relay.go b/controller/relay.go index d7e0f00ad4916a9eb0cb2d72976fef82de39514a..0f7394156ce8d9a57f63fe64ed8e390b2c8373ca 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode var err *dto.OpenAIErrorWithStatusCode switch relayMode { case relayconstant.RelayModeImagesGenerations: - err = relay.ImageHelper(c, relayMode) + err = relay.ImageHelper(c) case relayconstant.RelayModeAudioSpeech: fallthrough case relayconstant.RelayModeAudioTranslation: diff --git a/controller/user.go b/controller/user.go index 7146f00e2d904378899f9d2e5e20a616d257a240..51e6f955c417fe333d0ebfe3217599216cc29b07 100644 --- a/controller/user.go +++ b/controller/user.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "one-api/common" "one-api/model" "one-api/setting" @@ -471,7 +472,7 @@ func GetUserModels(c *gin.Context) { if err != nil { id = c.GetInt("id") } - user, err := model.GetUserById(id, true) + user, err := model.GetUserCache(id) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -869,9 +870,6 @@ func EmailBind(c *gin.Context) { }) return } - if user.Role == common.RoleRootUser { - common.RootUserEmail = email - } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -913,3 +911,115 @@ func TopUp(c *gin.Context) { }) return } + +type UpdateUserSettingRequest struct { + QuotaWarningType string `json:"notify_type"` + QuotaWarningThreshold float64 `json:"quota_warning_threshold"` + WebhookUrl string `json:"webhook_url,omitempty"` + WebhookSecret string `json:"webhook_secret,omitempty"` + NotificationEmail string `json:"notification_email,omitempty"` +} + +func UpdateUserSetting(c *gin.Context) { + var req UpdateUserSettingRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + + // 验证预警类型 + if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的预警类型", + }) + return + } + + // 验证预警阈值 + if req.QuotaWarningThreshold <= 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "预警阈值必须大于0", + }) + return + } + + // 如果是webhook类型,验证webhook地址 + if req.QuotaWarningType == constant.NotifyTypeWebhook { + if req.WebhookUrl == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Webhook地址不能为空", + }) + return + } + // 验证URL格式 + if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的Webhook地址", + }) + return + } + } + + // 如果是邮件类型,验证邮箱地址 + if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { + // 验证邮箱格式 + if !strings.Contains(req.NotificationEmail, "@") { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的邮箱地址", + }) + return + } + } + + userId := c.GetInt("id") + user, err := model.GetUserById(userId, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + // 构建设置 + settings := map[string]interface{}{ + constant.UserSettingNotifyType: req.QuotaWarningType, + constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold, + } + + // 如果是webhook类型,添加webhook相关设置 + if req.QuotaWarningType == constant.NotifyTypeWebhook { + settings[constant.UserSettingWebhookUrl] = req.WebhookUrl + if req.WebhookSecret != "" { + settings[constant.UserSettingWebhookSecret] = req.WebhookSecret + } + } + + // 如果提供了通知邮箱,添加到设置中 + if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { + settings[constant.UserSettingNotificationEmail] = req.NotificationEmail + } + + // 更新用户设置 + user.SetSetting(settings) + if err := user.Update(false); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "更新设置失败: " + err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "设置已更新", + }) +} diff --git a/docker-compose.yml b/docker-compose.yml index 640cf074788a21567c8e339e6d137d4f69a03720..0f23cea27171d75a32a64f05ab1d211bddb50e44 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -24,7 +24,7 @@ services: - redis - mysql healthcheck: - test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] + test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"] interval: 30s timeout: 10s retries: 3 diff --git a/dto/notify.go b/dto/notify.go new file mode 100644 index 0000000000000000000000000000000000000000..b75cec70cae493900ddbd31271fc5059efc1003f --- /dev/null +++ b/dto/notify.go @@ -0,0 +1,25 @@ +package dto + +type Notify struct { + Type string `json:"type"` + Title string `json:"title"` + Content string `json:"content"` + Values []interface{} `json:"values"` +} + +const ContentValueParam = "{{value}}" + +const ( + NotifyTypeQuotaExceed = "quota_exceed" + NotifyTypeChannelUpdate = "channel_update" + NotifyTypeChannelTest = "channel_test" +) + +func NewNotify(t string, title string, content string, values []interface{}) Notify { + return Notify{ + Type: t, + Title: title, + Content: content, + Values: values, + } +} diff --git a/dto/openai_request.go b/dto/openai_request.go index 0f6411bb7a0ed792d722eb0072ccaeb96c225a97..028e0286cdb26188cc8c20b351830e70a46cfae6 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -18,6 +18,8 @@ type GeneralOpenAIRequest struct { Model string `json:"model,omitempty"` Messages []Message `json:"messages,omitempty"` Prompt any `json:"prompt,omitempty"` + Prefix any `json:"prefix,omitempty"` + Suffix any `json:"suffix,omitempty"` Stream bool `json:"stream,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"` @@ -86,18 +88,20 @@ func (r GeneralOpenAIRequest) ParseInput() []string { } type Message struct { - Role string `json:"role"` - Content json.RawMessage `json:"content"` - Name *string `json:"name,omitempty"` - Prefix *bool `json:"prefix,omitempty"` - ReasoningContent string `json:"reasoning_content,omitempty"` - ToolCalls json.RawMessage `json:"tool_calls,omitempty"` - ToolCallId string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content json.RawMessage `json:"content"` + Name *string `json:"name,omitempty"` + Prefix *bool `json:"prefix,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls json.RawMessage `json:"tool_calls,omitempty"` + ToolCallId string `json:"tool_call_id,omitempty"` + parsedContent []MediaContent + parsedStringContent *string } type MediaContent struct { Type string `json:"type"` - Text string `json:"text"` + Text string `json:"text,omitempty"` ImageUrl any `json:"image_url,omitempty"` InputAudio any `json:"input_audio,omitempty"` } @@ -146,6 +150,9 @@ func (m *Message) SetToolCalls(toolCalls any) { } func (m *Message) StringContent() string { + if m.parsedStringContent != nil { + return *m.parsedStringContent + } var stringContent string if err := json.Unmarshal(m.Content, &stringContent); err == nil { return stringContent @@ -156,78 +163,113 @@ func (m *Message) StringContent() string { func (m *Message) SetStringContent(content string) { jsonContent, _ := json.Marshal(content) m.Content = jsonContent + m.parsedStringContent = &content + m.parsedContent = nil +} + +func (m *Message) SetMediaContent(content []MediaContent) { + jsonContent, _ := json.Marshal(content) + m.Content = jsonContent + m.parsedContent = nil + m.parsedStringContent = nil } func (m *Message) IsStringContent() bool { + if m.parsedStringContent != nil { + return true + } var stringContent string if err := json.Unmarshal(m.Content, &stringContent); err == nil { + m.parsedStringContent = &stringContent return true } return false } func (m *Message) ParseContent() []MediaContent { + if m.parsedContent != nil { + return m.parsedContent + } + var contentList []MediaContent + + // 先尝试解析为字符串 var stringContent string if err := json.Unmarshal(m.Content, &stringContent); err == nil { - contentList = append(contentList, MediaContent{ + contentList = []MediaContent{{ Type: ContentTypeText, Text: stringContent, - }) + }} + m.parsedContent = contentList return contentList } - var arrayContent []json.RawMessage + + // 尝试解析为数组 + var arrayContent []map[string]interface{} if err := json.Unmarshal(m.Content, &arrayContent); err == nil { for _, contentItem := range arrayContent { - var contentMap map[string]any - if err := json.Unmarshal(contentItem, &contentMap); err != nil { + contentType, ok := contentItem["type"].(string) + if !ok { continue } - switch contentMap["type"] { + + switch contentType { case ContentTypeText: - if subStr, ok := contentMap["text"].(string); ok { + if text, ok := contentItem["text"].(string); ok { contentList = append(contentList, MediaContent{ Type: ContentTypeText, - Text: subStr, + Text: text, }) } + case ContentTypeImageURL: - if subObj, ok := contentMap["image_url"].(map[string]any); ok { - detail, ok := subObj["detail"] - if ok { - subObj["detail"] = detail.(string) - } else { - subObj["detail"] = "high" - } + imageUrl := contentItem["image_url"] + switch v := imageUrl.(type) { + case string: contentList = append(contentList, MediaContent{ Type: ContentTypeImageURL, ImageUrl: MessageImageUrl{ - Url: subObj["url"].(string), - Detail: subObj["detail"].(string), - }, - }) - } else if url, ok := contentMap["image_url"].(string); ok { - contentList = append(contentList, MediaContent{ - Type: ContentTypeImageURL, - ImageUrl: MessageImageUrl{ - Url: url, + Url: v, Detail: "high", }, }) + case map[string]interface{}: + url, ok1 := v["url"].(string) + detail, ok2 := v["detail"].(string) + if !ok2 { + detail = "high" + } + if ok1 { + contentList = append(contentList, MediaContent{ + Type: ContentTypeImageURL, + ImageUrl: MessageImageUrl{ + Url: url, + Detail: detail, + }, + }) + } } + case ContentTypeInputAudio: - if subObj, ok := contentMap["input_audio"].(map[string]any); ok { - contentList = append(contentList, MediaContent{ - Type: ContentTypeInputAudio, - InputAudio: MessageInputAudio{ - Data: subObj["data"].(string), - Format: subObj["format"].(string), - }, - }) + if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok { + data, ok1 := audioData["data"].(string) + format, ok2 := audioData["format"].(string) + if ok1 && ok2 { + contentList = append(contentList, MediaContent{ + Type: ContentTypeInputAudio, + InputAudio: MessageInputAudio{ + Data: data, + Format: format, + }, + }) + } } } } - return contentList } - return nil + + if len(contentList) > 0 { + m.parsedContent = contentList + } + return contentList } diff --git a/dto/openai_response.go b/dto/openai_response.go index 2e0e2221e1ce5e27d9aa9ae4afb3abf7021a3380..febf01ff0d58fbd66660c4d9a4258e449538b743 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -62,9 +62,10 @@ type ChatCompletionsStreamResponseChoice struct { } type ChatCompletionsStreamResponseChoiceDelta struct { - Content *string `json:"content,omitempty"` - Role string `json:"role,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Content *string `json:"content,omitempty"` + ReasoningContent *string `json:"reasoning_content,omitempty"` + Role string `json:"role,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) { @@ -78,6 +79,13 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string { return *c.Content } +func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string { + if c.ReasoningContent == nil { + return "" + } + return *c.ReasoningContent +} + type ToolCall struct { // Index is not nil only in chat completion chunk object Index *int `json:"index,omitempty"` diff --git a/main.go b/main.go index 68dae8f496c4cc455a4327a12b64094f67cfa4f4..495057cf1efc4bbf842378817c9b38dc11e41814 100644 --- a/main.go +++ b/main.go @@ -119,9 +119,9 @@ func main() { } if os.Getenv("ENABLE_PPROF") == "true" { - go func() { + gopool.Go(func() { log.Println(http.ListenAndServe("0.0.0.0:8005", nil)) - }() + }) go common.Monitor() common.SysLog("pprof enabled") } diff --git a/middleware/distributor.go b/middleware/distributor.go index c90f3e5eeb286ae484668f07a27a989c2bd61ccf..e0f9342a84f2c3f4858af8789bdd6e3592ed7048 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -135,17 +135,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { midjourneyRequest := dto.MidjourneyRequest{} err = common.UnmarshalBodyReusable(c, &midjourneyRequest) if err != nil { - abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error()) return nil, false, err } midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) if mjErr != nil { - abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description) return nil, false, fmt.Errorf(mjErr.Description) } if midjourneyModel == "" { if !success { - abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型") return nil, false, fmt.Errorf("无效的请求, 无法解析模型") } else { // task fetch, task fetch by condition, notify @@ -170,7 +167,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { - abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) return nil, false, errors.New("无效的请求, " + err.Error()) } if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") { diff --git a/model/option.go b/model/option.go index 0c4114a42256be3bbb4b0d2fdbc11b2bc3fd41ae..24935c69d1f7d0dac57ab4d2d4b72ba0a46f0a4f 100644 --- a/model/option.go +++ b/model/option.go @@ -84,7 +84,7 @@ func InitOptionMap() { common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) - common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) + common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() @@ -306,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) { common.QuotaForInvitee, _ = strconv.Atoi(value) case "QuotaRemindThreshold": common.QuotaRemindThreshold, _ = strconv.Atoi(value) - case "PreConsumedQuota": + case "ShouldPreConsumedQuota": common.PreConsumedQuota, _ = strconv.Atoi(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) diff --git a/model/token.go b/model/token.go index 3abd22cf6e45c79e351e935fcde62e906e04835a..8587ea62a9f3e24c6c9f8f8c24f095d08757af5b 100644 --- a/model/token.go +++ b/model/token.go @@ -3,13 +3,11 @@ package model import ( "errors" "fmt" - "github.com/bytedance/gopkg/util/gopool" - "gorm.io/gorm" "one-api/common" - relaycommon "one-api/relay/common" - "one-api/setting" - "strconv" "strings" + + "github.com/bytedance/gopkg/util/gopool" + "gorm.io/gorm" ) type Token struct { @@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) { ).Error return err } - -func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { - if quota < 0 { - return errors.New("quota 不能为负数!") - } - if relayInfo.IsPlayground { - return nil - } - //if relayInfo.TokenUnlimited { - // return nil - //} - token, err := GetTokenById(relayInfo.TokenId) - if err != nil { - return err - } - if !relayInfo.TokenUnlimited && token.RemainQuota < quota { - return errors.New("令牌额度不足") - } - err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) - if err != nil { - return err - } - return nil -} - -func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) { - - if quota > 0 { - err = DecreaseUserQuota(relayInfo.UserId, quota) - } else { - err = IncreaseUserQuota(relayInfo.UserId, -quota) - } - if err != nil { - return err - } - - if !relayInfo.IsPlayground { - if quota > 0 { - err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) - } else { - err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota) - } - if err != nil { - return err - } - } - - if sendEmail { - if (quota + preConsumedQuota) != 0 { - quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold - noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0 - if quotaTooLow || noMoreQuota { - go func() { - email, err := GetUserEmail(relayInfo.UserId) - if err != nil { - common.SysError("failed to fetch user email: " + err.Error()) - } - prompt := "您的额度即将用尽" - if noMoreQuota { - prompt = "您的额度已用尽" - } - if email != "" { - topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) - err = common.SendEmail(prompt, email, - fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) - if err != nil { - common.SysError("failed to send email" + err.Error()) - } - common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota)) - } - }() - } - } - } - - return nil -} diff --git a/model/token_cache.go b/model/token_cache.go index 99b762f513ecb1aa42b2a48be104e2fb5993ca09..0fe02fea59058f4b5651bd927d7504d098f3d984 100644 --- a/model/token_cache.go +++ b/model/token_cache.go @@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error { func cacheGetTokenByKey(key string) (*Token, error) { hmacKey := common.GenerateHMAC(key) if !common.RedisEnabled { - return nil, nil + return nil, fmt.Errorf("redis is not enabled") } var token Token err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token) diff --git a/model/user.go b/model/user.go index 95123c2192f2c9ac4b27db729f964e8f5ec75a60..427b0625f4b2b998b1dd452e7bf969d984e21463 100644 --- a/model/user.go +++ b/model/user.go @@ -1,6 +1,7 @@ package model import ( + "encoding/json" "errors" "fmt" "one-api/common" @@ -38,6 +39,20 @@ type User struct { InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` DeletedAt gorm.DeletedAt `gorm:"index"` LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"` + Setting string `json:"setting" gorm:"type:text;column:setting"` +} + +func (user *User) ToBaseUser() *UserBase { + cache := &UserBase{ + Id: user.Id, + Group: user.Group, + Quota: user.Quota, + Status: user.Status, + Username: user.Username, + Setting: user.Setting, + Email: user.Email, + } + return cache } func (user *User) GetAccessToken() string { @@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) { user.AccessToken = &token } +func (user *User) GetSetting() map[string]interface{} { + if user.Setting == "" { + return nil + } + return common.StrToMap(user.Setting) +} + +func (user *User) SetSetting(setting map[string]interface{}) { + settingBytes, err := json.Marshal(setting) + if err != nil { + common.SysError("failed to marshal setting: " + err.Error()) + return + } + user.Setting = string(settingBytes) +} + // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil func CheckUserExistOrDeleted(username string, email string) (bool, error) { var user User @@ -315,8 +346,8 @@ func (user *User) Update(updatePassword bool) error { return err } - // 更新缓存 - return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status) + // Update cache + return updateUserCache(*user) } func (user *User) Edit(updatePassword bool) error { @@ -344,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error { return err } - // 更新缓存 - return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status) + // Update cache + return updateUserCache(*user) } func (user *User) Delete() error { @@ -371,8 +402,8 @@ func (user *User) HardDelete() error { // ValidateAndFill check password & user status func (user *User) ValidateAndFill() (err error) { // When querying with struct, GORM will only query with non-zero fields, - // that means if your field’s value is 0, '', false or other zero values, - // it won’t be used to build query conditions + // that means if your field's value is 0, '', false or other zero values, + // it won't be used to build query conditions password := user.Password username := strings.TrimSpace(user.Username) if username == "" || password == "" { @@ -531,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) { return quota, nil } // Don't return error - fall through to DB - //common.SysError("failed to get user quota from cache: " + err.Error()) } fromDB = true err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error @@ -580,6 +610,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { return group, nil } +// GetUserSetting gets setting from Redis first, falls back to DB if needed +func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) { + var setting string + defer func() { + // Update Redis cache asynchronously on successful DB read + if shouldUpdateRedis(fromDB, err) { + gopool.Go(func() { + if err := updateUserSettingCache(id, setting); err != nil { + common.SysError("failed to update user setting cache: " + err.Error()) + } + }) + } + }() + if !fromDB && common.RedisEnabled { + setting, err := getUserSettingCache(id) + if err == nil { + return setting, nil + } + // Don't return error - fall through to DB + } + fromDB = true + err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error + if err != nil { + return map[string]interface{}{}, err + } + + return common.StrToMap(setting), nil +} + func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") @@ -641,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) { } } -func GetRootUserEmail() (email string) { - DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) - return email +//func GetRootUserEmail() (email string) { +// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) +// return email +//} + +func GetRootUser() (user *User) { + DB.Where("role = ?", common.RoleRootUser).First(&user) + return user } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { @@ -725,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool { return !errors.Is(err, gorm.ErrRecordNotFound) } -func (u *User) FillUserByLinuxDOId() error { - if u.LinuxDOId == "" { +func (user *User) FillUserByLinuxDOId() error { + if user.LinuxDOId == "" { return errors.New("linux do id is empty") } - err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error + err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error return err } diff --git a/model/user_cache.go b/model/user_cache.go index 9dc7e899ebb2ddf99698dc1c5abd2e325ce787f8..cc08288d681e3cf77be4c071079101a8cc0726fa 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -1,206 +1,213 @@ package model import ( + "encoding/json" "fmt" "one-api/common" "one-api/constant" - "strconv" "time" + + "github.com/bytedance/gopkg/util/gopool" ) -// Change UserCache struct to userCache -type userCache struct { +// UserBase struct remains the same as it represents the cached data structure +type UserBase struct { Id int `json:"id"` Group string `json:"group"` + Email string `json:"email"` Quota int `json:"quota"` Status int `json:"status"` - Role int `json:"role"` Username string `json:"username"` + Setting string `json:"setting"` } -// Rename all exported functions to private ones -// invalidateUserCache clears all user related cache -func invalidateUserCache(userId int) error { - if !common.RedisEnabled { +func (user *UserBase) GetSetting() map[string]interface{} { + if user.Setting == "" { return nil } - - keys := []string{ - fmt.Sprintf(constant.UserGroupKeyFmt, userId), - fmt.Sprintf(constant.UserQuotaKeyFmt, userId), - fmt.Sprintf(constant.UserEnabledKeyFmt, userId), - fmt.Sprintf(constant.UserUsernameKeyFmt, userId), - } - - for _, key := range keys { - if err := common.RedisDel(key); err != nil { - return fmt.Errorf("failed to delete cache key %s: %w", key, err) - } - } - return nil + return common.StrToMap(user.Setting) } -// updateUserGroupCache updates user group cache -func updateUserGroupCache(userId int, group string) error { - if !common.RedisEnabled { - return nil +func (user *UserBase) SetSetting(setting map[string]interface{}) { + settingBytes, err := json.Marshal(setting) + if err != nil { + common.SysError("failed to marshal setting: " + err.Error()) + return } - return common.RedisSet( - fmt.Sprintf(constant.UserGroupKeyFmt, userId), - group, - time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, - ) + user.Setting = string(settingBytes) } -// updateUserQuotaCache updates user quota cache -func updateUserQuotaCache(userId int, quota int) error { - if !common.RedisEnabled { - return nil - } - return common.RedisSet( - fmt.Sprintf(constant.UserQuotaKeyFmt, userId), - fmt.Sprintf("%d", quota), - time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, - ) +// getUserCacheKey returns the key for user cache +func getUserCacheKey(userId int) string { + return fmt.Sprintf("user:%d", userId) } -// updateUserStatusCache updates user status cache -func updateUserStatusCache(userId int, userEnabled bool) error { +// invalidateUserCache clears user cache +func invalidateUserCache(userId int) error { if !common.RedisEnabled { return nil } - enabled := "0" - if userEnabled { - enabled = "1" - } - return common.RedisSet( - fmt.Sprintf(constant.UserEnabledKeyFmt, userId), - enabled, - time.Duration(constant.UserId2StatusCacheSeconds)*time.Second, - ) + return common.RedisHDelObj(getUserCacheKey(userId)) } -// updateUserNameCache updates username cache -func updateUserNameCache(userId int, username string) error { +// updateUserCache updates all user cache fields using hash +func updateUserCache(user User) error { if !common.RedisEnabled { return nil } - return common.RedisSet( - fmt.Sprintf(constant.UserUsernameKeyFmt, userId), - username, + + return common.RedisHSetObj( + getUserCacheKey(user.Id), + user.ToBaseUser(), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, ) } -// updateUserCache updates all user cache fields -func updateUserCache(userId int, username string, userGroup string, quota int, status int) error { - if !common.RedisEnabled { - return nil +// GetUserCache gets complete user cache from hash +func GetUserCache(userId int) (userCache *UserBase, err error) { + var user *User + var fromDB bool + defer func() { + // Update Redis cache asynchronously on successful DB read + if shouldUpdateRedis(fromDB, err) && user != nil { + gopool.Go(func() { + if err := updateUserCache(*user); err != nil { + common.SysError("failed to update user status cache: " + err.Error()) + } + }) + } + }() + + // Try getting from Redis first + userCache, err = cacheGetUserBase(userId) + if err == nil { + return userCache, nil } - if err := updateUserGroupCache(userId, userGroup); err != nil { - return fmt.Errorf("update group cache: %w", err) + // If Redis fails, get from DB + fromDB = true + user, err = GetUserById(userId, false) + if err != nil { + return nil, err // Return nil and error if DB lookup fails } - if err := updateUserQuotaCache(userId, quota); err != nil { - return fmt.Errorf("update quota cache: %w", err) + // Create cache object from user data + userCache = &UserBase{ + Id: user.Id, + Group: user.Group, + Quota: user.Quota, + Status: user.Status, + Username: user.Username, + Setting: user.Setting, + Email: user.Email, } - if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil { - return fmt.Errorf("update status cache: %w", err) + return userCache, nil +} + +func cacheGetUserBase(userId int) (*UserBase, error) { + if !common.RedisEnabled { + return nil, fmt.Errorf("redis is not enabled") } + var userCache UserBase + // Try getting from Redis first + err := common.RedisHGetObj(getUserCacheKey(userId), &userCache) + if err != nil { + return nil, err + } + return &userCache, nil +} - if err := updateUserNameCache(userId, username); err != nil { - return fmt.Errorf("update username cache: %w", err) +// Add atomic quota operations using hash fields +func cacheIncrUserQuota(userId int, delta int64) error { + if !common.RedisEnabled { + return nil } + return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta) +} - return nil +func cacheDecrUserQuota(userId int, delta int64) error { + return cacheIncrUserQuota(userId, -delta) } -// getUserGroupCache gets user group from cache +// Helper functions to get individual fields if needed func getUserGroupCache(userId int) (string, error) { - if !common.RedisEnabled { - return "", nil + cache, err := GetUserCache(userId) + if err != nil { + return "", err } - return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId)) + return cache.Group, nil } -// getUserQuotaCache gets user quota from cache func getUserQuotaCache(userId int) (int, error) { - if !common.RedisEnabled { - return 0, nil - } - quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId)) + cache, err := GetUserCache(userId) if err != nil { return 0, err } - return strconv.Atoi(quotaStr) + return cache.Quota, nil } -// getUserStatusCache gets user status from cache func getUserStatusCache(userId int) (int, error) { - if !common.RedisEnabled { - return 0, nil - } - statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId)) + cache, err := GetUserCache(userId) if err != nil { return 0, err } - return strconv.Atoi(statusStr) + return cache.Status, nil } -// getUserNameCache gets username from cache func getUserNameCache(userId int) (string, error) { - if !common.RedisEnabled { - return "", nil + cache, err := GetUserCache(userId) + if err != nil { + return "", err } - return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId)) + return cache.Username, nil } -// getUserCache gets complete user cache -func getUserCache(userId int) (*userCache, error) { - if !common.RedisEnabled { - return nil, nil - } - - group, err := getUserGroupCache(userId) +func getUserSettingCache(userId int) (map[string]interface{}, error) { + setting := make(map[string]interface{}) + cache, err := GetUserCache(userId) if err != nil { - return nil, fmt.Errorf("get group cache: %w", err) + return setting, err } + return cache.GetSetting(), nil +} - quota, err := getUserQuotaCache(userId) - if err != nil { - return nil, fmt.Errorf("get quota cache: %w", err) +// New functions for individual field updates +func updateUserStatusCache(userId int, status bool) error { + if !common.RedisEnabled { + return nil } - - status, err := getUserStatusCache(userId) - if err != nil { - return nil, fmt.Errorf("get status cache: %w", err) + statusInt := common.UserStatusEnabled + if !status { + statusInt = common.UserStatusDisabled } + return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt)) +} - username, err := getUserNameCache(userId) - if err != nil { - return nil, fmt.Errorf("get username cache: %w", err) +func updateUserQuotaCache(userId int, quota int) error { + if !common.RedisEnabled { + return nil } + return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota)) +} - return &userCache{ - Id: userId, - Group: group, - Quota: quota, - Status: status, - Username: username, - }, nil +func updateUserGroupCache(userId int, group string) error { + if !common.RedisEnabled { + return nil + } + return common.RedisHSetField(getUserCacheKey(userId), "Group", group) } -// Add atomic quota operations -func cacheIncrUserQuota(userId int, delta int64) error { +func updateUserNameCache(userId int, username string) error { if !common.RedisEnabled { return nil } - key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId) - return common.RedisIncr(key, delta) + return common.RedisHSetField(getUserCacheKey(userId), "Username", username) } -func cacheDecrUserQuota(userId int, delta int64) error { - return cacheIncrUserQuota(userId, -delta) +func updateUserSettingCache(userId int, setting string) error { + if !common.RedisEnabled { + return nil + } + return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting) } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 754000980249221a0aac7c0a4f06950650421894..5c2eadc20f5efe63333782eebd1449f39f236a6b 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -4,13 +4,14 @@ import ( "bytes" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" + + "github.com/gin-gonic/gin" ) type Adaptor struct { diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index 14dd74f03d762c8c65a9487d92322764839bee59..d779ee651772b06b471f56c6c0c08ac4078f8010 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/relay/constant" ) type Adaptor struct { @@ -29,7 +30,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + switch info.RelayMode { + case constant.RelayModeCompletions: + return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil + default: + return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + } } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 681e9988a7a42d1fe02738e7e6fd277adec62d5f..32513c42ab99813de3258cbd6d5bd88ffae66f61 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -1,15 +1,21 @@ package gemini import ( + "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" + "one-api/common" "one-api/constant" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + "one-api/service" + + "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -21,8 +27,36 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + if !strings.HasPrefix(info.UpstreamModelName, "imagen") { + return nil, errors.New("not supported model for image generation") + } + + // convert size to aspect ratio + aspectRatio := "1:1" // default aspect ratio + switch request.Size { + case "1024x1024": + aspectRatio = "1:1" + case "1024x1792": + aspectRatio = "9:16" + case "1792x1024": + aspectRatio = "16:9" + } + + // build gemini imagen request + geminiRequest := GeminiImageRequest{ + Instances: []GeminiImageInstance{ + { + Prompt: request.Prompt, + }, + }, + Parameters: GeminiImageParameters{ + SampleCount: request.N, + AspectRatio: aspectRatio, + PersonGeneration: "allow_adult", // default allow adult + }, + } + + return geminiRequest, nil } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { @@ -40,6 +74,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } } + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil + } + action := "generateContent" if info.IsStream { action = "streamGenerateContent?alt=sse" @@ -73,12 +111,15 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + return GeminiImageHandler(c, resp, info) + } + if info.IsStream { err, usage = GeminiChatStreamHandler(c, resp, info) } else { @@ -87,6 +128,60 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return } +func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + responseBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError) + } + _ = resp.Body.Close() + + var geminiResponse GeminiImageResponse + if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { + return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + if len(geminiResponse.Predictions) == 0 { + return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest) + } + + // convert to openai format response + openAIResponse := dto.ImageResponse{ + Created: common.GetTimestamp(), + Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)), + } + + for _, prediction := range geminiResponse.Predictions { + if prediction.RaiFilteredReason != "" { + continue // skip filtered image + } + openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{ + B64Json: prediction.BytesBase64Encoded, + }) + } + + jsonResponse, jsonErr := json.Marshal(openAIResponse) + if jsonErr != nil { + return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb + // each image has fixed 258 tokens + const imageTokens = 258 + generatedImages := len(openAIResponse.Data) + + usage = &dto.Usage{ + PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens + CompletionTokens: 0, // image generation does not calculate completion tokens + TotalTokens: imageTokens * generatedImages, + } + + return usage, nil +} + func (a *Adaptor) GetModelList() []string { return ModelList } diff --git a/relay/channel/gemini/constant.go b/relay/channel/gemini/constant.go index 9651bd607071007d37421456affa1843ddbfc3af..b7c1f0cf8758c55ade55ccacabfa6e61c978650d 100644 --- a/relay/channel/gemini/constant.go +++ b/relay/channel/gemini/constant.go @@ -16,6 +16,8 @@ var ModelList = []string{ "gemini-2.0-pro-exp", // thinking exp "gemini-2.0-flash-thinking-exp", + // imagen models + "imagen-3.0-generate-002", } var ChannelName = "google gemini" diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index 08a5db8486e6f95a2adf39af6d400fe27dd24a38..bbcb1248d4825b440447c72583ab35f7e3acf3a3 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -109,3 +109,30 @@ type GeminiUsageMetadata struct { CandidatesTokenCount int `json:"candidatesTokenCount"` TotalTokenCount int `json:"totalTokenCount"` } + +// Imagen related structs +type GeminiImageRequest struct { + Instances []GeminiImageInstance `json:"instances"` + Parameters GeminiImageParameters `json:"parameters"` +} + +type GeminiImageInstance struct { + Prompt string `json:"prompt"` +} + +type GeminiImageParameters struct { + SampleCount int `json:"sampleCount,omitempty"` + AspectRatio string `json:"aspectRatio,omitempty"` + PersonGeneration string `json:"personGeneration,omitempty"` +} + +type GeminiImageResponse struct { + Predictions []GeminiImagePrediction `json:"predictions"` +} + +type GeminiImagePrediction struct { + MimeType string `json:"mimeType"` + BytesBase64Encoded string `json:"bytesBase64Encoded"` + RaiFilteredReason string `json:"raiFilteredReason,omitempty"` + SafetyAttributes any `json:"safetyAttributes,omitempty"` +} diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index c99e539617371781f4e51590928611c2253f7506..fcea169a8d74aaf130f6d6138f76d479cbb9989a 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -41,9 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re if request == nil { return nil, errors.New("request is nil") } - mistralReq := requestOpenAI2Mistral(*request) - //common.LogJson(c, "body", mistralReq) - return mistralReq, nil + return requestOpenAI2Mistral(request), nil } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { @@ -55,7 +53,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go index 04add067517762a02126b831f4b551485f3c69ec..8987b8f0836d15b409176b2aa15855dd13f5c14a 100644 --- a/relay/channel/mistral/text.go +++ b/relay/channel/mistral/text.go @@ -1,25 +1,21 @@ package mistral import ( - "encoding/json" "one-api/dto" ) -func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { +func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { messages := make([]dto.Message, 0, len(request.Messages)) for _, message := range request.Messages { - if !message.IsStringContent() { - mediaMessages := message.ParseContent() - for j, mediaMessage := range mediaMessages { - if mediaMessage.Type == dto.ContentTypeImageURL { - imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) - mediaMessage.ImageUrl = imageUrl.Url - mediaMessages[j] = mediaMessage - } + mediaMessages := message.ParseContent() + for j, mediaMessage := range mediaMessages { + if mediaMessage.Type == dto.ContentTypeImageURL { + imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) + mediaMessage.ImageUrl = imageUrl.Url + mediaMessages[j] = mediaMessage } - messageRaw, _ := json.Marshal(mediaMessages) - message.Content = messageRaw } + message.SetMediaContent(mediaMessages) messages = append(messages, dto.Message{ Role: message.Role, Content: message.Content, diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 36889cb8d1728674f9fee57d92e52c73dc14644b..7e1c62377f27a7ede4dfc55d638474561d80da3d 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -39,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) + req.Set("Authorization", "Bearer "+info.ApiKey) return nil } @@ -46,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re if request == nil { return nil, errors.New("request is nil") } - return requestOpenAI2Ollama(*request), nil + return requestOpenAI2Ollama(*request) } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go index 080191151d9755ce63d8dc3ee8c698f6ae75f7db..a954c607a69104c4ae35cde8aaf2fb8c66a96460 100644 --- a/relay/channel/ollama/dto.go +++ b/relay/channel/ollama/dto.go @@ -3,18 +3,21 @@ package ollama import "one-api/dto" type OllamaRequest struct { - Model string `json:"model,omitempty"` - Messages []dto.Message `json:"messages,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - Seed float64 `json:"seed,omitempty"` - Topp float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Stop any `json:"stop,omitempty"` - Tools []dto.ToolCall `json:"tools,omitempty"` - ResponseFormat any `json:"response_format,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` + Model string `json:"model,omitempty"` + Messages []dto.Message `json:"messages,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + Seed float64 `json:"seed,omitempty"` + Topp float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Stop any `json:"stop,omitempty"` + Tools []dto.ToolCall `json:"tools,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + Suffix any `json:"suffix,omitempty"` + StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"` + Prompt any `json:"prompt,omitempty"` } type Options struct { @@ -35,7 +38,7 @@ type OllamaEmbeddingRequest struct { } type OllamaEmbeddingResponse struct { - Error string `json:"error,omitempty"` - Model string `json:"model"` + Error string `json:"error,omitempty"` + Model string `json:"model"` Embedding [][]float64 `json:"embeddings,omitempty"` } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 4ecdd19bef927d9941716199a9db31dab9e75808..8b53fbfb56ae899c1c40644ccfae49f46c6b0089 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -9,14 +9,36 @@ import ( "net/http" "one-api/dto" "one-api/service" + "strings" ) -func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { +func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) { messages := make([]dto.Message, 0, len(request.Messages)) for _, message := range request.Messages { + if !message.IsStringContent() { + mediaMessages := message.ParseContent() + for j, mediaMessage := range mediaMessages { + if mediaMessage.Type == dto.ContentTypeImageURL { + imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl) + // check if not base64 + if strings.HasPrefix(imageUrl.Url, "http") { + fileData, err := service.GetFileBase64FromUrl(imageUrl.Url) + if err != nil { + return nil, err + } + imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data) + } + mediaMessage.ImageUrl = imageUrl + mediaMessages[j] = mediaMessage + } + } + message.SetMediaContent(mediaMessages) + } messages = append(messages, dto.Message{ - Role: message.Role, - Content: message.Content, + Role: message.Role, + Content: message.Content, + ToolCalls: message.ToolCalls, + ToolCallId: message.ToolCallId, }) } str, ok := request.Stop.(string) @@ -39,7 +61,10 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { ResponseFormat: request.ResponseFormat, FrequencyPenalty: request.FrequencyPenalty, PresencePenalty: request.PresencePenalty, - } + Prompt: request.Prompt, + StreamOptions: request.StreamOptions, + Suffix: request.Suffix, + }, nil } func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest { diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index e94399eaa04417e5dc3a7ea67379ee57286f76b4..f927fa74f45f69aed47d5774481230604f29c9fc 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -119,7 +119,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re request.MaxCompletionTokens = request.MaxTokens request.MaxTokens = 0 } - if strings.HasPrefix(request.Model, "o3") { + if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") { request.Temperature = nil } if strings.HasSuffix(request.Model, "-high") { diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 6c9359f334bd7af5bff02272a123a1d55d33670a..33cdea48639f8625e25d14662ae862900636e9eb 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -87,6 +87,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel info.SetFirstResponseTime() ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) data := scanner.Text() + if common.DebugEnabled { + println(data) + } if len(data) < 6 { // ignore blank line or wrong format continue } @@ -162,6 +165,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel //} for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) + responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) if choice.Delta.ToolCalls != nil { if len(choice.Delta.ToolCalls) > toolCount { toolCount = len(choice.Delta.ToolCalls) @@ -182,6 +186,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel //} for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) + responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) if choice.Delta.ToolCalls != nil { if len(choice.Delta.ToolCalls) > toolCount { toolCount = len(choice.Delta.ToolCalls) @@ -273,7 +278,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTextToken(string(choice.Message.Content), model) + ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent, model) completionTokens += ctkm } simpleResponse.Usage = dto.Usage{ diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index c02d18a3e84fc016e92a95108e4d4b27dfa3c5c9..797f02442be9b9a233e72b086e708a0e70d91c03 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil } else if info.RelayMode == constant.RelayModeChatCompletions { return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + } else if info.RelayMode == constant.RelayModeCompletions { + return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil } return "", errors.New("invalid relay mode") } @@ -72,6 +74,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } + case constant.RelayModeCompletions: + if info.IsStream { + err, usage = openai.OaiStreamHandler(c, resp, info) + } else { + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } case constant.RelayModeEmbeddings: err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go index 06f306f6b0c78a9bf38051d493a8f3437e42658c..97d82c718eaa86f6c995f35b558e4cf253333c12 100644 --- a/relay/channel/zhipu_4v/relay-zhipu_v4.go +++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go @@ -90,8 +90,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq mediaMessages[j] = mediaMessage } } - messageRaw, _ := json.Marshal(mediaMessages) - message.Content = messageRaw + message.SetMediaContent(mediaMessages) } messages = append(messages, dto.Message{ Role: message.Role, diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 4978f84f09088834b7405b485d2a3ddb2aee7d20..1f4a3a42b6fd99f6344a9a11a86516c239d824e0 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -13,24 +13,24 @@ import ( ) type RelayInfo struct { - ChannelType int - ChannelId int - TokenId int - TokenKey string - UserId int - Group string - TokenUnlimited bool - StartTime time.Time - FirstResponseTime time.Time - setFirstResponse bool - ApiType int - IsStream bool - IsPlayground bool - UsePrice bool - RelayMode int - UpstreamModelName string - OriginModelName string - RecodeModelName string + ChannelType int + ChannelId int + TokenId int + TokenKey string + UserId int + Group string + TokenUnlimited bool + StartTime time.Time + FirstResponseTime time.Time + setFirstResponse bool + ApiType int + IsStream bool + IsPlayground bool + UsePrice bool + RelayMode int + UpstreamModelName string + OriginModelName string + //RecodeModelName string RequestURLPath string ApiVersion string PromptTokens int @@ -39,6 +39,7 @@ type RelayInfo struct { BaseUrl string SupportStreamOptions bool ShouldIncludeUsage bool + IsModelMapped bool ClientWs *websocket.Conn TargetWs *websocket.Conn InputAudioFormat string @@ -50,6 +51,18 @@ type RelayInfo struct { ChannelSetting map[string]interface{} } +// 定义支持流式选项的通道类型 +var streamSupportedChannels = map[int]bool{ + common.ChannelTypeOpenAI: true, + common.ChannelTypeAnthropic: true, + common.ChannelTypeAws: true, + common.ChannelTypeGemini: true, + common.ChannelCloudflare: true, + common.ChannelTypeAzure: true, + common.ChannelTypeVolcEngine: true, + common.ChannelTypeOllama: true, +} + func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { info := GenRelayInfo(c) info.ClientWs = ws @@ -89,12 +102,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { FirstResponseTime: startTime.Add(-time.Second), OriginModelName: c.GetString("original_model"), UpstreamModelName: c.GetString("original_model"), - RecodeModelName: c.GetString("recode_model"), - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), - Organization: c.GetString("channel_organization"), - ChannelSetting: channelSetting, + //RecodeModelName: c.GetString("original_model"), + IsModelMapped: false, + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + Organization: c.GetString("channel_organization"), + ChannelSetting: channelSetting, } if strings.HasPrefix(c.Request.URL.Path, "/pg") { info.IsPlayground = true @@ -110,9 +124,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { if info.ChannelType == common.ChannelTypeVertexAi { info.ApiVersion = c.GetString("region") } - if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || - info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini || - info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure { + if streamSupportedChannels[info.ChannelType] { info.SupportStreamOptions = true } return info diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go new file mode 100644 index 0000000000000000000000000000000000000000..948c5226b6917a123955c0c51d5707640e37fd28 --- /dev/null +++ b/relay/helper/model_mapped.go @@ -0,0 +1,25 @@ +package helper + +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "one-api/relay/common" +) + +func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { + // map model name + modelMapping := c.GetString("model_mapping") + if modelMapping != "" && modelMapping != "{}" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return fmt.Errorf("unmarshal_model_mapping_failed") + } + if modelMap[info.OriginModelName] != "" { + info.UpstreamModelName = modelMap[info.OriginModelName] + info.IsModelMapped = true + } + } + return nil +} diff --git a/relay/helper/price.go b/relay/helper/price.go new file mode 100644 index 0000000000000000000000000000000000000000..d65b86aa5a307e51db3325d7cd907b62819e78b8 --- /dev/null +++ b/relay/helper/price.go @@ -0,0 +1,41 @@ +package helper + +import ( + "github.com/gin-gonic/gin" + "one-api/common" + relaycommon "one-api/relay/common" + "one-api/setting" +) + +type PriceData struct { + ModelPrice float64 + ModelRatio float64 + GroupRatio float64 + UsePrice bool + ShouldPreConsumedQuota int +} + +func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData { + modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false) + groupRatio := setting.GetGroupRatio(info.Group) + var preConsumedQuota int + var modelRatio float64 + if !usePrice { + preConsumedTokens := common.PreConsumedQuota + if maxTokens != 0 { + preConsumedTokens = promptTokens + maxTokens + } + modelRatio = common.GetModelRatio(info.OriginModelName) + ratio := modelRatio * groupRatio + preConsumedQuota = int(float64(preConsumedTokens) * ratio) + } else { + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + return PriceData{ + ModelPrice: modelPrice, + ModelRatio: modelRatio, + GroupRatio: groupRatio, + UsePrice: usePrice, + ShouldPreConsumedQuota: preConsumedQuota, + } +} diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 4c23a8f8d99136cfbb40cb22958345c3621f20b0..b95c1eb693f6ce30720d2e291789b4953ae40536 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -1,7 +1,6 @@ package relay import ( - "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" @@ -11,8 +10,10 @@ import ( "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "one-api/setting" + "strings" ) func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { @@ -27,8 +28,9 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. return nil, errors.New("model is required") } if setting.ShouldCheckPromptSensitive() { - err := service.CheckSensitiveInput(audioRequest.Input) + words, err := service.CheckSensitiveInput(audioRequest.Input) if err != nil { + common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) return nil, err } } @@ -73,15 +75,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo.PromptTokens = promptTokens } - modelRatio := common.GetModelRatio(audioRequest.Model) - groupRatio := setting.GetGroupRatio(relayInfo.Group) - ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(preConsumedTokens) * ratio) + priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) } - preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo) + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { return openaiErr } @@ -91,19 +91,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } }() - // map model name - modelMapping := c.GetString("model_mapping") - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[audioRequest.Model] != "" { - audioRequest.Model = modelMap[audioRequest.Model] - } + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - relayInfo.UpstreamModelName = audioRequest.Model + + audioRequest.Model = relayInfo.UpstreamModelName adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { @@ -140,7 +133,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { return openaiErr } - postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "") + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") return nil } diff --git a/relay/relay-image.go b/relay/relay-image.go index 207350da6b85cdf00cfebcdbc8027bd356e0846a..afa5b8e2ce160c6e69039b527a50b57bfd7375f9 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -12,6 +12,7 @@ import ( "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "one-api/setting" "strings" @@ -60,15 +61,16 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) //} if setting.ShouldCheckPromptSensitive() { - err := service.CheckSensitiveInput(imageRequest.Prompt) + words, err := service.CheckSensitiveInput(imageRequest.Prompt) if err != nil { + common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ","))) return nil, err } } return imageRequest, nil } -func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { +func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { relayInfo := relaycommon.GenRelayInfo(c) imageRequest, err := getAndValidImageRequest(c, relayInfo) @@ -77,29 +79,20 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } - // map model name - modelMapping := c.GetString("model_mapping") - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[imageRequest.Model] != "" { - imageRequest.Model = modelMap[imageRequest.Model] - } + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - relayInfo.UpstreamModelName = imageRequest.Model - modelPrice, success := common.GetModelPrice(imageRequest.Model, true) - if !success { - modelRatio := common.GetModelRatio(imageRequest.Model) + imageRequest.Model = relayInfo.UpstreamModelName + + priceData := helper.ModelPriceHelper(c, relayInfo, 0, 0) + if !priceData.UsePrice { // modelRatio 16 = modelPrice $0.04 // per 1 modelRatio = $0.04 / 16 - modelPrice = 0.0025 * modelRatio + priceData.ModelPrice = 0.0025 * priceData.ModelRatio } - groupRatio := setting.GetGroupRatio(relayInfo.Group) userQuota, err := model.GetUserQuota(relayInfo.UserId, false) sizeRatio := 1.0 @@ -122,11 +115,11 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { } } - imageRatio := modelPrice * sizeRatio * qualityRatio * float64(imageRequest.N) - quota := int(imageRatio * groupRatio * common.QuotaPerUnit) + priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N) + quota := int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit) if userQuota-quota < 0 { - return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("image pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, quota)), "insufficient_user_quota", http.StatusBadRequest) + return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden) } adaptor := GetAdaptor(relayInfo.ApiType) @@ -184,7 +177,6 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { } logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) - postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent) - + postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent) return nil } diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 0facecabea1793da881d39815be68a2e5c24a968..766064cbc0db0f69dc5fc5f9488a4a70d9e476d8 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -194,7 +194,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } defer func(ctx context.Context) { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { - err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true) + err := service.PostConsumeQuota(relayInfo, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } @@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons defer func(ctx context.Context) { if consumeQuota && midjResponseWithStatus.StatusCode == 200 { - err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true) + err := service.PostConsumeQuota(relayInfo, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } diff --git a/relay/relay-text.go b/relay/relay-text.go index f303ff6a678db257d1b68728e02956b2a5da7450..bfd91cdf9cce5331d67e19aa7d9874e639cd9726 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/bytedance/gopkg/util/gopool" "io" "math" "net/http" @@ -14,6 +15,7 @@ import ( "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "one-api/setting" "strings" @@ -75,40 +77,21 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) } - // map model name - //isModelMapped := false - modelMapping := c.GetString("model_mapping") - //isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[textRequest.Model] != "" { - //isModelMapped = true - textRequest.Model = modelMap[textRequest.Model] - // set upstream model name - //isModelMapped = true - } - } - relayInfo.UpstreamModelName = textRequest.Model - relayInfo.RecodeModelName = textRequest.Model - modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false) - groupRatio := setting.GetGroupRatio(relayInfo.Group) - - var preConsumedQuota int - var ratio float64 - var modelRatio float64 - //err := service.SensitiveWordsCheck(textRequest) - if setting.ShouldCheckPromptSensitive() { - err = checkRequestSensitive(textRequest, relayInfo) + words, err := checkRequestSensitive(textRequest, relayInfo) if err != nil { + common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest) } } + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) + } + + textRequest.Model = relayInfo.UpstreamModelName + // 获取 promptTokens,如果上下文中已经存在,则直接使用 var promptTokens int if value, exists := c.Get("prompt_tokens"); exists { @@ -123,20 +106,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { c.Set("prompt_tokens", promptTokens) } - if !getModelPriceSuccess { - preConsumedTokens := common.PreConsumedQuota - if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + int(textRequest.MaxTokens) - } - modelRatio = common.GetModelRatio(textRequest.Model) - ratio = modelRatio * groupRatio - preConsumedQuota = int(float64(preConsumedTokens) * ratio) - } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) - } + priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { return openaiErr } @@ -219,10 +192,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { return openaiErr } - if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") { - service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") { + service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") } else { - postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") } return nil } @@ -247,19 +220,20 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re return promptTokens, err } -func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error { +func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) { var err error + var words []string switch info.RelayMode { case relayconstant.RelayModeChatCompletions: - err = service.CheckSensitiveMessages(textRequest.Messages) + words, err = service.CheckSensitiveMessages(textRequest.Messages) case relayconstant.RelayModeCompletions: - err = service.CheckSensitiveInput(textRequest.Prompt) + words, err = service.CheckSensitiveInput(textRequest.Prompt) case relayconstant.RelayModeModerations: - err = service.CheckSensitiveInput(textRequest.Input) + words, err = service.CheckSensitiveInput(textRequest.Input) case relayconstant.RelayModeEmbeddings: - err = service.CheckSensitiveInput(textRequest.Input) + words, err = service.CheckSensitiveInput(textRequest.Input) } - return err + return words, err } // 预扣费并返回用户剩余配额 @@ -272,7 +246,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } if userQuota-preConsumedQuota < 0 { - return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest) + return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden) } if userQuota > 100*preConsumedQuota { // 用户额度充足,判断令牌额度是否充足 @@ -282,18 +256,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if tokenQuota > 100*preConsumedQuota { // 令牌额度充足,信任令牌 preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota)) + common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) } } else { // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 - common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota)) + common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota))) } } if preConsumedQuota > 0 { - err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) + err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } @@ -307,20 +281,19 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) { if preConsumedQuota != 0 { - go func() { + gopool.Go(func() { relayInfoCopy := *relayInfo - err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false) + err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) if err != nil { common.SysError("error return pre-consumed quota: " + err.Error()) } - }() + }) } } -func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64, usePrice bool, extraContent string) { +func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, + usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { if usage == nil { usage = &dto.Usage{ PromptTokens: relayInfo.PromptTokens, @@ -332,12 +305,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens + modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") completionRatio := common.GetCompletionRatio(modelName) + ratio := priceData.ModelRatio * priceData.GroupRatio + modelRatio := priceData.ModelRatio + groupRatio := priceData.GroupRatio + modelPrice := priceData.ModelPrice + usePrice := priceData.UsePrice quota := 0 - if !usePrice { + if !priceData.UsePrice { quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio)) quota = int(math.Round(float64(quota) * ratio)) if ratio != 0 && quota <= 0 { @@ -368,7 +347,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN //} quotaDelta := quota - preConsumedQuota if quotaDelta != 0 { - err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) if err != nil { common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } diff --git a/relay/relay_embedding.go b/relay/relay_embedding.go index 0a41c11d59738c30104fd1abfbb1c730efa509a5..18739d9f92369212e47ce77848def8ee6c3049e0 100644 --- a/relay/relay_embedding.go +++ b/relay/relay_embedding.go @@ -10,8 +10,8 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" - "one-api/setting" ) func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { @@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest) } - // map model name - modelMapping := c.GetString("model_mapping") - //isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[embeddingRequest.Model] != "" { - embeddingRequest.Model = modelMap[embeddingRequest.Model] - // set upstream model name - //isModelMapped = true - } + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - relayInfo.UpstreamModelName = embeddingRequest.Model - modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false) - groupRatio := setting.GetGroupRatio(relayInfo.Group) - - var preConsumedQuota int - var ratio float64 - var modelRatio float64 + embeddingRequest.Model = relayInfo.UpstreamModelName promptToken := getEmbeddingPromptToken(*embeddingRequest) - if !success { - preConsumedTokens := promptToken - modelRatio = common.GetModelRatio(embeddingRequest.Model) - ratio = modelRatio * groupRatio - preConsumedQuota = int(float64(preConsumedTokens) * ratio) - } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) - } relayInfo.PromptTokens = promptToken + priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) + // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { return openaiErr } @@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") return nil } diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index e53e37d480823d0604a679e92db0b3100e440b67..37178cad3a550c19f8aec4971d893b193e9543d0 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -9,8 +9,8 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" - "one-api/setting" ) func getRerankPromptToken(rerankRequest dto.RerankRequest) int { @@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) } - // map model name - modelMapping := c.GetString("model_mapping") - //isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[rerankRequest.Model] != "" { - rerankRequest.Model = modelMap[rerankRequest.Model] - // set upstream model name - //isModelMapped = true - } + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - relayInfo.UpstreamModelName = rerankRequest.Model - modelPrice, success := common.GetModelPrice(rerankRequest.Model, false) - groupRatio := setting.GetGroupRatio(relayInfo.Group) - - var preConsumedQuota int - var ratio float64 - var modelRatio float64 + rerankRequest.Model = relayInfo.UpstreamModelName promptToken := getRerankPromptToken(*rerankRequest) - if !success { - preConsumedTokens := promptToken - modelRatio = common.GetModelRatio(rerankRequest.Model) - ratio = modelRatio * groupRatio - preConsumedQuota = int(float64(preConsumedTokens) * ratio) - } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) - } relayInfo.PromptTokens = promptToken + priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) + // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { return openaiErr } @@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") return nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index 61577faf7937e90e880e359bef106d476441818f..f03fcb2d912f4a1ec2a0e57e22379c45e396a31b 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -113,7 +113,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { // release quota if relayInfo.ConsumeQuota && taskErr == nil { - err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true) + err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } diff --git a/router/api-router.go b/router/api-router.go index b00595af3256e9f4ddc1bd9dc90a7fd7d764f3b4..bf88449a56bc57256da2c0f31edf8a44bb279888 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.POST("/pay", controller.RequestEpay) selfRoute.POST("/amount", controller.RequestAmount) selfRoute.POST("/aff_transfer", controller.TransferAffQuota) + selfRoute.PUT("/setting", controller.UpdateUserSetting) } adminRoute := userRoute.Group("/") diff --git a/service/cf_worker.go b/service/cf_worker.go index afe65411b0eac6431a450e94e78e3345b202ad86..40a1e29450c5db52321bb5f05c3b9b763dc5f36f 100644 --- a/service/cf_worker.go +++ b/service/cf_worker.go @@ -2,6 +2,7 @@ package service import ( "bytes" + "encoding/json" "fmt" "net/http" "one-api/common" @@ -9,19 +10,46 @@ import ( "strings" ) +// WorkerRequest Worker请求的数据结构 +type WorkerRequest struct { + URL string `json:"url"` + Key string `json:"key"` + Method string `json:"method,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + Body json.RawMessage `json:"body,omitempty"` +} + +// DoWorkerRequest 通过Worker发送请求 +func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { + if !setting.EnableWorker() { + return nil, fmt.Errorf("worker not enabled") + } + if !strings.HasPrefix(req.URL, "https") { + return nil, fmt.Errorf("only support https url") + } + + workerUrl := setting.WorkerUrl + if !strings.HasSuffix(workerUrl, "/") { + workerUrl += "/" + } + + // 序列化worker请求数据 + workerPayload, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal worker payload: %v", err) + } + + return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload)) +} + func DoDownloadRequest(originUrl string) (resp *http.Response, err error) { if setting.EnableWorker() { common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) - if !strings.HasPrefix(originUrl, "https") { - return nil, fmt.Errorf("only support https url") - } - workerUrl := setting.WorkerUrl - if !strings.HasSuffix(workerUrl, "/") { - workerUrl += "/" + req := &WorkerRequest{ + URL: originUrl, + Key: setting.WorkerValidKey, } - // post request to worker - data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`) - return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data)) + return DoWorkerRequest(req) } else { common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) return http.Get(originUrl) diff --git a/service/channel.go b/service/channel.go index 73545b1e65b26e6d7156badf0cdc8aa68dcef0f1..76bcacf1bb1433d3716519f567de3b3ede97205f 100644 --- a/service/channel.go +++ b/service/channel.go @@ -4,7 +4,7 @@ import ( "fmt" "net/http" "one-api/common" - relaymodel "one-api/dto" + "one-api/dto" "one-api/model" "one-api/setting" "strings" @@ -15,17 +15,17 @@ func DisableChannel(channelId int, channelName string, reason string) { model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason) subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - notifyRootUser(subject, content) + NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate) } func EnableChannel(channelId int, channelName string) { model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "") subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - notifyRootUser(subject, content) + NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate) } -func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool { +func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool { if !common.AutomaticDisableChannelEnabled { return false } @@ -75,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus return false } -func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool { +func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool { if !common.AutomaticEnableChannelEnabled { return false } diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 1ce09d92f7f1fe7408d0395e7cb22b38d406ec58..1e32d6f1ed5276d54f7a6522ff1b6a040efe807d 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -16,6 +16,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m if relayInfo.ReasoningEffort != "" { other["reasoning_effort"] = relayInfo.ReasoningEffort } + if relayInfo.IsModelMapped { + other["is_model_mapped"] = true + other["upstream_model_name"] = relayInfo.UpstreamModelName + } adminInfo := make(map[string]interface{}) adminInfo["use_channel"] = ctx.GetStringSlice("use_channel") other["admin_info"] = adminInfo diff --git a/service/notify-limit.go b/service/notify-limit.go new file mode 100644 index 0000000000000000000000000000000000000000..309ea54d2192058009e2992cbcc60f3e33a236d1 --- /dev/null +++ b/service/notify-limit.go @@ -0,0 +1,117 @@ +package service + +import ( + "fmt" + "github.com/bytedance/gopkg/util/gopool" + "one-api/common" + "one-api/constant" + "strconv" + "sync" + "time" +) + +// notifyLimitStore is used for in-memory rate limiting when Redis is disabled +var ( + notifyLimitStore sync.Map + cleanupOnce sync.Once +) + +type limitCount struct { + Count int + Timestamp time.Time +} + +func getDuration() time.Duration { + minute := constant.NotificationLimitDurationMinute + return time.Duration(minute) * time.Minute +} + +// startCleanupTask starts a background task to clean up expired entries +func startCleanupTask() { + gopool.Go(func() { + for { + time.Sleep(time.Hour) + now := time.Now() + notifyLimitStore.Range(func(key, value interface{}) bool { + if limit, ok := value.(limitCount); ok { + if now.Sub(limit.Timestamp) >= getDuration() { + notifyLimitStore.Delete(key) + } + } + return true + }) + } + }) +} + +// CheckNotificationLimit checks if the user has exceeded their notification limit +// Returns true if the user can send notification, false if limit exceeded +func CheckNotificationLimit(userId int, notifyType string) (bool, error) { + if common.RedisEnabled { + return checkRedisLimit(userId, notifyType) + } + return checkMemoryLimit(userId, notifyType) +} + +func checkRedisLimit(userId int, notifyType string) (bool, error) { + key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215")) + + // Get current count + count, err := common.RedisGet(key) + if err != nil && err.Error() != "redis: nil" { + return false, fmt.Errorf("failed to get notification count: %w", err) + } + + // If key doesn't exist, initialize it + if count == "" { + err = common.RedisSet(key, "1", getDuration()) + return true, err + } + + currentCount, _ := strconv.Atoi(count) + limit := constant.NotifyLimitCount + + // Check if limit is already reached + if currentCount >= limit { + return false, nil + } + + // Only increment if under limit + err = common.RedisIncr(key, 1) + if err != nil { + return false, fmt.Errorf("failed to increment notification count: %w", err) + } + + return true, nil +} + +func checkMemoryLimit(userId int, notifyType string) (bool, error) { + // Ensure cleanup task is started + cleanupOnce.Do(startCleanupTask) + + key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215")) + now := time.Now() + + // Get current limit count or initialize new one + var currentLimit limitCount + if value, ok := notifyLimitStore.Load(key); ok { + currentLimit = value.(limitCount) + // Check if the entry has expired + if now.Sub(currentLimit.Timestamp) >= getDuration() { + currentLimit = limitCount{Count: 0, Timestamp: now} + } + } else { + currentLimit = limitCount{Count: 0, Timestamp: now} + } + + // Increment count + currentLimit.Count++ + + // Check against limits + limit := constant.NotifyLimitCount + + // Store updated count + notifyLimitStore.Store(key, currentLimit) + + return currentLimit.Count <= limit, nil +} diff --git a/service/quota.go b/service/quota.go index ab04800823bf7af79ec6f70134927d55a8cd20a5..2cae93def4228f8aa93f0479c5301813015b9ba8 100644 --- a/service/quota.go +++ b/service/quota.go @@ -3,11 +3,14 @@ package service import ( "errors" "fmt" + "github.com/bytedance/gopkg/util/gopool" "math" "one-api/common" + constant2 "one-api/constant" "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/setting" "strings" "time" @@ -66,7 +69,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag return err } - modelName := relayInfo.UpstreamModelName + modelName := relayInfo.OriginModelName textInputTokens := usage.InputTokenDetails.TextTokens textOutTokens := usage.OutputTokenDetails.TextTokens audioInputTokens := usage.InputTokenDetails.AudioTokens @@ -92,14 +95,14 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag quota := calculateAudioQuota(quotaInfo) if userQuota < quota { - return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) + return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)) } if !token.UnlimitedQuota && token.RemainQuota < quota { - return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota)) + return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota)) } - err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false) + err = PostConsumeQuota(relayInfo, quota, 0, false) if err != nil { return err } @@ -120,7 +123,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod tokenName := ctx.GetString("token_name") completionRatio := common.GetCompletionRatio(modelName) - audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) + audioRatio := common.GetAudioRatio(relayInfo.OriginModelName) audioCompletionRatio := common.GetAudioCompletionRatio(modelName) quotaInfo := QuotaInfo{ @@ -171,8 +174,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod } func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, - usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64, usePrice bool, extraContent string) { + usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.PromptTokensDetails.TextTokens @@ -182,9 +184,14 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioOutTokens := usage.CompletionTokenDetails.AudioTokens tokenName := ctx.GetString("token_name") - completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName) - audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName) - audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName) + completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName) + audioRatio := common.GetAudioRatio(relayInfo.OriginModelName) + audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName) + + modelRatio := priceData.ModelRatio + groupRatio := priceData.GroupRatio + modelPrice := priceData.ModelPrice + usePrice := priceData.UsePrice quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -195,7 +202,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, TextTokens: textOutTokens, AudioTokens: audioOutTokens, }, - ModelName: relayInfo.RecodeModelName, + ModelName: relayInfo.OriginModelName, UsePrice: usePrice, ModelRatio: modelRatio, GroupRatio: groupRatio, @@ -218,11 +225,11 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, quota = 0 logContent += fmt.Sprintf("(可能是上游超时)") common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota)) + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota)) } else { quotaDelta := quota - preConsumedQuota if quotaDelta != 0 { - err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true) if err != nil { common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } @@ -231,7 +238,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - logModel := relayInfo.RecodeModelName + logModel := relayInfo.OriginModelName if extraContent != "" { logContent += ", " + extraContent } @@ -239,3 +246,88 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } + +func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + if relayInfo.IsPlayground { + return nil + } + //if relayInfo.TokenUnlimited { + // return nil + //} + token, err := model.GetTokenByKey(relayInfo.TokenKey, false) + if err != nil { + return err + } + if !relayInfo.TokenUnlimited && token.RemainQuota < quota { + return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota)) + } + err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) + if err != nil { + return err + } + return nil +} + +func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) { + + if quota > 0 { + err = model.DecreaseUserQuota(relayInfo.UserId, quota) + } else { + err = model.IncreaseUserQuota(relayInfo.UserId, -quota) + } + if err != nil { + return err + } + + if !relayInfo.IsPlayground { + if quota > 0 { + err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota) + } else { + err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota) + } + if err != nil { + return err + } + } + + if sendEmail { + if (quota + preConsumedQuota) != 0 { + checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota) + } + } + + return nil +} + +func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) { + gopool.Go(func() { + userCache, err := model.GetUserCache(userId) + if err != nil { + common.SysError("failed to get user cache: " + err.Error()) + } + userSetting := userCache.GetSetting() + threshold := common.QuotaRemindThreshold + if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok { + threshold = int(userCustomThreshold.(float64)) + } + + //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0 + quotaTooLow := false + consumeQuota := quota + preConsumedQuota + if userCache.Quota-consumeQuota < threshold { + quotaTooLow = true + } + if quotaTooLow { + prompt := "您的额度即将用尽" + topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) + content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}" + err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink})) + if err != nil { + common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error())) + } + } + }) +} diff --git a/service/sensitive.go b/service/sensitive.go index 14ac94819e772e6dbc3fc568080e453413345be1..b3e3c4d664e98417f229ad4dd2301a7440fcdb88 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -8,48 +8,47 @@ import ( "strings" ) -func CheckSensitiveMessages(messages []dto.Message) error { +func CheckSensitiveMessages(messages []dto.Message) ([]string, error) { + if len(messages) == 0 { + return nil, nil + } + for _, message := range messages { - if len(message.Content) > 0 { - if message.IsStringContent() { - stringContent := message.StringContent() - if ok, words := SensitiveWordContains(stringContent); ok { - return errors.New("sensitive words: " + strings.Join(words, ",")) - } + arrayContent := message.ParseContent() + for _, m := range arrayContent { + if m.Type == "image_url" { + // TODO: check image url + continue + } + // 检查 text 是否为空 + if m.Text == "" { + continue } - } else { - arrayContent := message.ParseContent() - for _, m := range arrayContent { - if m.Type == "image_url" { - // TODO: check image url - } else { - if ok, words := SensitiveWordContains(m.Text); ok { - return errors.New("sensitive words: " + strings.Join(words, ",")) - } - } + if ok, words := SensitiveWordContains(m.Text); ok { + return words, errors.New("sensitive words detected") } } } - return nil + return nil, nil } -func CheckSensitiveText(text string) error { +func CheckSensitiveText(text string) ([]string, error) { if ok, words := SensitiveWordContains(text); ok { - return errors.New("sensitive words: " + strings.Join(words, ",")) + return words, errors.New("sensitive words detected") } - return nil + return nil, nil } -func CheckSensitiveInput(input any) error { +func CheckSensitiveInput(input any) ([]string, error) { switch v := input.(type) { case string: return CheckSensitiveText(v) case []string: - text := "" + var builder strings.Builder for _, s := range v { - text += s + builder.WriteString(s) } - return CheckSensitiveText(text) + return CheckSensitiveText(builder.String()) } return CheckSensitiveText(fmt.Sprintf("%v", input)) } @@ -59,8 +58,11 @@ func SensitiveWordContains(text string) (bool, []string) { if len(setting.SensitiveWords) == 0 { return false, nil } + if len(text) == 0 { + return false, nil + } checkText := strings.ToLower(text) - return AcSearch(checkText, setting.SensitiveWords, false) + return AcSearch(checkText, setting.SensitiveWords, true) } // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本 @@ -72,14 +74,21 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, m := InitAc(setting.SensitiveWords) hits := m.MultiPatternSearch([]rune(checkText), returnImmediately) if len(hits) > 0 { - words := make([]string, 0) + words := make([]string, 0, len(hits)) + var builder strings.Builder + builder.Grow(len(text)) + lastPos := 0 + for _, hit := range hits { pos := hit.Pos word := string(hit.Word) - text = text[:pos] + "**###**" + text[pos+len(word):] + builder.WriteString(text[lastPos:pos]) + builder.WriteString("**###**") + lastPos = pos + len(word) words = append(words, word) } - return true, words, text + builder.WriteString(text[lastPos:]) + return true, words, builder.String() } return false, nil, text } diff --git a/service/user_notify.go b/service/user_notify.go index 7ae9062bce886eb00f87e0cb17af2a03e18bb292..e01b7aa9c04f3b990e9122470db1a0c596ca6919 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -3,15 +3,75 @@ package service import ( "fmt" "one-api/common" + "one-api/constant" + "one-api/dto" "one-api/model" + "strings" ) -func notifyRootUser(subject string, content string) { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() +func NotifyRootUser(t string, subject string, content string) { + user := model.GetRootUser().ToBaseUser() + _ = NotifyUser(user, dto.NewNotify(t, subject, content, nil)) +} + +func NotifyUser(user *model.UserBase, data dto.Notify) error { + userSetting := user.GetSetting() + notifyType, ok := userSetting[constant.UserSettingNotifyType] + if !ok { + notifyType = constant.NotifyTypeEmail } - err := common.SendEmail(subject, common.RootUserEmail, content) + + // Check notification limit + canSend, err := CheckNotificationLimit(user.Id, data.Type) if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) + return err + } + if !canSend { + return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType) + } + + switch notifyType { + case constant.NotifyTypeEmail: + userEmail := user.Email + // check setting email + if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok { + userEmail = settingEmail.(string) + } + if userEmail == "" { + common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id)) + return nil + } + return sendEmailNotify(userEmail, data) + case constant.NotifyTypeWebhook: + webhookURL, ok := userSetting[constant.UserSettingWebhookUrl] + if !ok { + common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id)) + return nil + } + webhookURLStr, ok := webhookURL.(string) + if !ok { + common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id)) + return nil + } + + // 获取 webhook secret + var webhookSecret string + if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok { + webhookSecret, _ = secret.(string) + } + + return SendWebhookNotify(webhookURLStr, webhookSecret, data) + } + return nil +} + +func sendEmailNotify(userEmail string, data dto.Notify) error { + // make email content + content := data.Content + // 处理占位符 + for _, value := range data.Values { + content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1) } + return common.SendEmail(data.Title, userEmail, content) } diff --git a/service/webhook.go b/service/webhook.go new file mode 100644 index 0000000000000000000000000000000000000000..ad2967eb0b142785976f0c30dcaf61e69f2d654f --- /dev/null +++ b/service/webhook.go @@ -0,0 +1,118 @@ +package service + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "one-api/dto" + "one-api/setting" + "time" +) + +// WebhookPayload webhook 通知的负载数据 +type WebhookPayload struct { + Type string `json:"type"` + Title string `json:"title"` + Content string `json:"content"` + Values []interface{} `json:"values,omitempty"` + Timestamp int64 `json:"timestamp"` +} + +// generateSignature 生成 webhook 签名 +func generateSignature(secret string, payload []byte) string { + h := hmac.New(sha256.New, []byte(secret)) + h.Write(payload) + return hex.EncodeToString(h.Sum(nil)) +} + +// SendWebhookNotify 发送 webhook 通知 +func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error { + // 处理占位符 + content := data.Content + for _, value := range data.Values { + content = fmt.Sprintf(content, value) + } + + // 构建 webhook 负载 + payload := WebhookPayload{ + Type: data.Type, + Title: data.Title, + Content: content, + Values: data.Values, + Timestamp: time.Now().Unix(), + } + + // 序列化负载 + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal webhook payload: %v", err) + } + + // 创建 HTTP 请求 + var req *http.Request + var resp *http.Response + + if setting.EnableWorker() { + // 构建worker请求数据 + workerReq := &WorkerRequest{ + URL: webhookURL, + Key: setting.WorkerValidKey, + Method: http.MethodPost, + Headers: map[string]string{ + "Content-Type": "application/json", + }, + Body: payloadBytes, + } + + // 如果有secret,添加签名到headers + if secret != "" { + signature := generateSignature(secret, payloadBytes) + workerReq.Headers["X-Webhook-Signature"] = signature + workerReq.Headers["Authorization"] = "Bearer " + secret + } + + resp, err = DoWorkerRequest(workerReq) + if err != nil { + return fmt.Errorf("failed to send webhook request through worker: %v", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode) + } + } else { + req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes)) + if err != nil { + return fmt.Errorf("failed to create webhook request: %v", err) + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + + // 如果有 secret,生成签名 + if secret != "" { + signature := generateSignature(secret, payloadBytes) + req.Header.Set("X-Webhook-Signature", signature) + } + + // 发送请求 + client := GetImpatientHttpClient() + resp, err = client.Do(req) + if err != nil { + return fmt.Errorf("failed to send webhook request: %v", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode) + } + } + + return nil +} diff --git a/setting/operation_setting.go b/setting/operation_setting.go index 9a28e987103d1c7ba9d422cb015f5dfc839e0890..4940d0fc6cc3718b9a633627f1323eb856bd1c53 100644 --- a/setting/operation_setting.go +++ b/setting/operation_setting.go @@ -23,6 +23,7 @@ func AutomaticDisableKeywordsFromString(s string) { ak := strings.Split(s, "\n") for _, k := range ak { k = strings.TrimSpace(k) + k = strings.ToLower(k) if k != "" { AutomaticDisableKeywords = append(AutomaticDisableKeywords, k) } diff --git a/setting/system-setting.go b/setting/system_setting.go similarity index 100% rename from setting/system-setting.go rename to setting/system_setting.go diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 605103ae258bd78c85115fdcf0fad5861706eaf6..71914c4e7c4f5261d16d9847d367a1d79f5100f7 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -357,6 +357,13 @@ const ChannelsTable = () => { dataIndex: 'operate', render: (text, record, index) => { if (record.children === undefined) { + // 构建模型测试菜单 + const modelMenuItems = record.models.split(',').map(model => ({ + node: 'item', + name: model, + onClick: () => testChannel(record, model) + })); + return (
{
+ }> + { + copyText(event, record.model_name).then(r => {}); + }} + suffixIcon={} + > + {' '}{record.model_name}{' '} + + + {/**/} + {/* {*/} + {/* copyText(event, other.upstream_model_name).then(r => {});*/} + {/* }}*/} + {/* >*/} + {/* {' '}{other.upstream_model_name}{' '}*/} + {/* */} + {/**/} + + + ); + } + + } const columns = [ { @@ -272,18 +344,7 @@ const LogsTable = () => { dataIndex: 'model_name', render: (text, record, index) => { return record.type === 0 || record.type === 2 ? ( - <> - { - copyText(event, text); - }} - > - {' '} - {text}{' '} - - + <>{renderModelName(record)} ) : ( <> ); @@ -580,6 +641,17 @@ const LogsTable = () => { value: logs[i].content, }); if (logs[i].type === 2) { + let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== ''; + if (modelMapped) { + expandDataLocal.push({ + key: t('请求并计费模型'), + value: logs[i].model_name, + }); + expandDataLocal.push({ + key: t('实际模型'), + value: other.upstream_model_name, + }); + } let content = ''; if (other?.ws || other?.audio) { content = renderAudioModelPrice( diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index 2f112c37c1bb8117c54315c4a6236fe1d934367a..49a0784ca412c034c28b82f874fa906df681a5d0 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -26,6 +26,10 @@ import { Tag, Typography, Collapsible, + Select, + Radio, + RadioGroup, + AutoComplete, } from '@douyinfe/semi-ui'; import { getQuotaPerUnit, @@ -67,14 +71,16 @@ const PersonalSetting = () => { const [transferAmount, setTransferAmount] = useState(0); const [isModelsExpanded, setIsModelsExpanded] = useState(false); const MODELS_DISPLAY_COUNT = 10; // 默认显示的模型数量 + const [notificationSettings, setNotificationSettings] = useState({ + warningType: 'email', + warningThreshold: 100000, + webhookUrl: '', + webhookSecret: '', + notificationEmail: '' + }); + const [showWebhookDocs, setShowWebhookDocs] = useState(false); useEffect(() => { - // let user = localStorage.getItem('user'); - // if (user) { - // userDispatch({ type: 'login', payload: user }); - // } - // console.log(localStorage.getItem('user')) - let status = localStorage.getItem('status'); if (status) { status = JSON.parse(status); @@ -105,6 +111,19 @@ const PersonalSetting = () => { return () => clearInterval(countdownInterval); // Clean up on unmount }, [disableButton, countdown]); + useEffect(() => { + if (userState?.user?.setting) { + const settings = JSON.parse(userState.user.setting); + setNotificationSettings({ + warningType: settings.notify_type || 'email', + warningThreshold: settings.quota_warning_threshold || 500000, + webhookUrl: settings.webhook_url || '', + webhookSecret: settings.webhook_secret || '', + notificationEmail: settings.notification_email || '' + }); + } + }, [userState?.user?.setting]); + const handleInputChange = (name, value) => { setInputs((inputs) => ({...inputs, [name]: value})); }; @@ -300,7 +319,36 @@ const PersonalSetting = () => { } }; + const handleNotificationSettingChange = (type, value) => { + setNotificationSettings(prev => ({ + ...prev, + [type]: value.target ? value.target.value : value // 处理 Radio 事件对象 + })); + }; + + const saveNotificationSettings = async () => { + try { + const res = await API.put('/api/user/setting', { + notify_type: notificationSettings.warningType, + quota_warning_threshold: parseFloat(notificationSettings.warningThreshold), + webhook_url: notificationSettings.webhookUrl, + webhook_secret: notificationSettings.webhookSecret, + notification_email: notificationSettings.notificationEmail + }); + + if (res.data.success) { + showSuccess(t('通知设置已更新')); + await getUserData(); + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('更新通知设置失败')); + } + }; + return ( +
@@ -526,9 +574,7 @@ const PersonalSetting = () => {
{t('微信')} -
+
{
@@ -672,18 +722,8 @@ const PersonalSetting = () => { style={{marginTop: '10px'}} /> )} - {status.wechat_login && ( - - )} setShowWeChatBindModal(false)} - // onOpen={() => setShowWeChatBindModal(true)} visible={showWeChatBindModal} size={'small'} > @@ -707,9 +747,121 @@ const PersonalSetting = () => {
+ + {t('通知设置')} +
+ {t('通知方式')} +
+ handleNotificationSettingChange('warningType', value)} + > + {t('邮件通知')} + {t('Webhook通知')} + +
+
+ {notificationSettings.warningType === 'webhook' && ( + <> +
+ {t('Webhook地址')} +
+ handleNotificationSettingChange('webhookUrl', val)} + placeholder={t('请输入Webhook地址,例如: https://example.com/webhook')} + /> + + {t('只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')} + + +
setShowWebhookDocs(!showWebhookDocs)}> + {t('Webhook请求结构')} {showWebhookDocs ? '▼' : '▶'} +
+ +
+{`{
+    "type": "quota_exceed",      // 通知类型
+    "title": "标题",             // 通知标题
+    "content": "通知内容",       // 通知内容,支持 {{value}} 变量占位符
+    "values": ["值1", "值2"],    // 按顺序替换content中的 {{value}} 占位符
+    "timestamp": 1739950503      // 时间戳
+}
+
+示例:
+{
+    "type": "quota_exceed",
+    "title": "额度预警通知",
+    "content": "您的额度即将用尽,当前剩余额度为 {{value}}",
+    "values": ["$0.99"],
+    "timestamp": 1739950503
+}`}
+                                                    
+
+
+
+
+
+ {t('接口凭证(可选)')} +
+ handleNotificationSettingChange('webhookSecret', val)} + placeholder={t('请输入密钥')} + /> + + {t('密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性')} + + + {t('Authorization: Bearer your-secret-key')} + +
+
+ + )} + {notificationSettings.warningType === 'email' && ( +
+ {t('通知邮箱')} +
+ handleNotificationSettingChange('notificationEmail', val)} + placeholder={t('留空则使用账号绑定的邮箱')} + /> + + {t('设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱')} + +
+
+ )} +
+ {t('额度预警阈值')} {renderQuotaWithPrompt(notificationSettings.warningThreshold)} +
+ handleNotificationSettingChange('warningThreshold', val)} + style={{width: 200}} + placeholder={t('请输入预警额度')} + data={[ + { value: 100000, label: '0.2$' }, + { value: 500000, label: '1$' }, + { value: 1000000, label: '5$' }, + { value: 5000000, label: '10$' } + ]} + /> +
+ + {t('当剩余额度低于此数值时,系统将通过选择的方式发送通知')} + +
+
+ +
+
setShowEmailBindModal(false)} - // onOpen={() => setShowEmailBindModal(true)} onOk={bindEmail} visible={showEmailBindModal} size={'small'} diff --git a/web/src/components/SiderBar.js b/web/src/components/SiderBar.js index 503dc81a5a79b02c143f588f77184b7f1c2bec8d..46d728dafb4906171b450d22f13d1035586057ab 100644 --- a/web/src/components/SiderBar.js +++ b/web/src/components/SiderBar.js @@ -80,7 +80,7 @@ const SiderBar = () => { itemKey: 'channel', to: '/channel', icon: , - className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle', + className: isAdmin() ? '' : 'tableHiddle', }, { text: t('聊天'), @@ -101,7 +101,7 @@ const SiderBar = () => { icon: , className: localStorage.getItem('enable_data_export') === 'true' - ? 'semi-navigation-item-normal' + ? '' : 'tableHiddle', }, { @@ -109,7 +109,7 @@ const SiderBar = () => { itemKey: 'redemption', to: '/redemption', icon: , - className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle', + className: isAdmin() ? '' : 'tableHiddle', }, { text: t('钱包'), @@ -122,7 +122,7 @@ const SiderBar = () => { itemKey: 'user', to: '/user', icon: , - className: isAdmin() ? 'semi-navigation-item-normal' : 'tableHiddle', + className: isAdmin() ? '' : 'tableHiddle', }, { text: t('日志'), @@ -137,7 +137,7 @@ const SiderBar = () => { icon: , className: localStorage.getItem('enable_drawing') === 'true' - ? 'semi-navigation-item-normal' + ? '' : 'tableHiddle', }, { @@ -147,7 +147,7 @@ const SiderBar = () => { icon: , className: localStorage.getItem('enable_task') === 'true' - ? 'semi-navigation-item-normal' + ? '' : 'tableHiddle', }, { diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js index 1c953f6b93480c69af3ffad7490b7b787f14381a..3149f91e19d92fd66a682f1beca601471748f0ef 100644 --- a/web/src/components/SystemSetting.js +++ b/web/src/components/SystemSetting.js @@ -368,6 +368,17 @@ const SystemSetting = () => { ) + + 注意:代理功能仅对图片请求和 Webhook 请求生效,不会影响其他 API 请求。如需配置 API 请求代理,请参考 + + {' '}API 代理设置文档 + + 。 + limit) { @@ -67,6 +67,73 @@ export function renderRatio(ratio) { return {ratio}x {i18next.t('倍率')}; } +const measureTextWidth = (text, style = { + fontSize: '14px', + fontFamily: '-apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif' +}, containerWidth) => { + const span = document.createElement('span'); + + span.style.visibility = 'hidden'; + span.style.position = 'absolute'; + span.style.whiteSpace = 'nowrap'; + span.style.fontSize = style.fontSize; + span.style.fontFamily = style.fontFamily; + + span.textContent = text; + + document.body.appendChild(span); + const width = span.offsetWidth; + + document.body.removeChild(span); + + return width; +}; + +export function truncateText(text, maxWidth = 200) { + if (!isMobile()) { + return text; + } + if (!text) return text; + + try { + // Handle percentage-based maxWidth + let actualMaxWidth = maxWidth; + if (typeof maxWidth === 'string' && maxWidth.endsWith('%')) { + const percentage = parseFloat(maxWidth) / 100; + // Use window width as fallback container width + actualMaxWidth = window.innerWidth * percentage; + } + + const width = measureTextWidth(text); + if (width <= actualMaxWidth) return text; + + let left = 0; + let right = text.length; + let result = text; + + while (left <= right) { + const mid = Math.floor((left + right) / 2); + const truncated = text.slice(0, mid) + '...'; + const currentWidth = measureTextWidth(truncated); + + if (currentWidth <= actualMaxWidth) { + result = truncated; + left = mid + 1; + } else { + right = mid - 1; + } + } + + return result; + } catch (error) { + console.warn('Text measurement failed, falling back to character count', error); + if (text.length > 20) { + return text.slice(0, 17) + '...'; + } + return text; + } +} + export const renderGroupOption = (item) => { const { disabled, @@ -386,7 +453,7 @@ export function renderQuotaWithPrompt(quota, digits) { let displayInCurrency = localStorage.getItem('display_in_currency'); displayInCurrency = displayInCurrency === 'true'; if (displayInCurrency) { - return '|' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + ''; + return ' | ' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + ''; } return ''; } diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index 3d2c7a55a32efa9fd605a61a6fcde3a2f6d6ef3b..a2cb3b831f50be2043633f68d639e030c3567226 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1249,5 +1249,26 @@ "已注销": "Logged out", "自动禁用关键词": "Automatic disable keywords", "一行一个,不区分大小写": "One line per keyword, not case-sensitive", - "当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel" + "当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel", + "请求并计费模型": "Request and charge model", + "实际模型": "Actual model", + "渠道信息": "Channel information", + "通知设置": "Notification settings", + "Webhook地址": "Webhook URL", + "请输入Webhook地址,例如: https://example.com/webhook": "Please enter the Webhook URL, e.g.: https://example.com/webhook", + "邮件通知": "Email notification", + "Webhook通知": "Webhook notification", + "接口凭证(可选)": "Interface credentials (optional)", + "密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性": "The secret will be added to the request header as a Bearer token to verify the legitimacy of the webhook request", + "Authorization: Bearer your-secret-key": "Authorization: Bearer your-secret-key", + "额度预警阈值": "Quota warning threshold", + "当剩余额度低于此数值时,系统将通过选择的方式发送通知": "When the remaining quota is lower than this value, the system will send a notification through the selected method", + "Webhook请求结构": "Webhook request structure", + "只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求": "Only https is supported, the system will send a notification through POST, please ensure the address can receive POST requests", + "保存设置": "Save settings", + "通知邮箱": "Notification email", + "设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱": "Set the email address for receiving quota warning notifications, if not set, the email address bound to the account will be used", + "留空则使用账号绑定的邮箱": "If left blank, the email address bound to the account will be used", + "代理站地址": "Base URL", + "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in" } \ No newline at end of file diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index e98610594f7d604ea4c5a7a027846d5e57033a91..4720100aa9fec4c9f80decdb5417d9a71a588144 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -540,21 +540,23 @@ const EditChannel = (props) => { value={inputs.name} autoComplete="new-password" /> - {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && ( + {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && inputs.type !== 45 && ( <>
- {t('BaseURL')}: + {t('代理站地址')}:
- { - handleInputChange('base_url', value); - }} - value={inputs.base_url} - autoComplete="new-password" - /> + + { + handleInputChange('base_url', value); + }} + value={inputs.base_url} + autoComplete="new-password" + /> + )}
diff --git a/web/src/pages/Playground/Playground.js b/web/src/pages/Playground/Playground.js index 3468d2b176b2b9dfa415bab85581179c2dc7ea4e..8579d1cc662eb46401923fc8fa9baacf1c6eadf1 100644 --- a/web/src/pages/Playground/Playground.js +++ b/web/src/pages/Playground/Playground.js @@ -7,7 +7,7 @@ import { SSE } from 'sse'; import { IconSetting } from '@douyinfe/semi-icons'; import { StyleContext } from '../../context/Style/index.js'; import { useTranslation } from 'react-i18next'; -import { renderGroupOption } from '../../helpers/render.js'; +import { renderGroupOption, truncateText } from '../../helpers/render.js'; const roleInfo = { user: { @@ -99,9 +99,10 @@ const Playground = () => { const { success, message, data } = res.data; if (success) { let localGroupOptions = Object.entries(data).map(([group, info]) => ({ - label: info.desc, + label: truncateText(info.desc, "50%"), value: group, - ratio: info.ratio + ratio: info.ratio, + fullLabel: info.desc // 保存完整文本用于tooltip })); if (localGroupOptions.length === 0) {