diff --git a/.env.example b/.env.example
index 36ba42bd8666422b12c7d4003164353a7bd4ff36..07602eca08e83e9aa5d422a4f10c16951ed3e75d 100644
--- a/.env.example
+++ b/.env.example
@@ -10,9 +10,9 @@
# 数据库相关配置
# 数据库连接字符串
-# SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
+# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
# 日志数据库连接字符串
-# LOG_SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
+# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
# SQLite数据库路径
# SQLITE_PATH=/path/to/sqlite.db
# 数据库最大空闲连接数
diff --git a/README.en.md b/README.en.md
index feb4b0bb46f9167aa633168d37ccade2e7f29380..446c88f61455c2fbcbbbb0e4358214ae5d217e30 100644
--- a/README.en.md
+++ b/README.en.md
@@ -89,6 +89,8 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
- `CRYPTO_SECRET`: Encryption key for encrypting database content
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
+- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Length of the user-notification rate-limit window, in minutes, default `10`
+- `NOTIFY_LIMIT_COUNT`: Maximum number of notifications sent to a user within one rate-limit window, default `2`
## Deployment
diff --git a/Rerank.md b/Rerank.md
index 6a07287eff221bbe2a3259fb0d579341af7cd66a..dc57d99bb0b7020a31eb12f3fa8c5d96012a3ab9 100644
--- a/Rerank.md
+++ b/Rerank.md
@@ -13,7 +13,7 @@ Request:
```json
{
- "model": "rerank-multilingual-v3.0",
+ "model": "jina-reranker-v2-base-multilingual",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": [
diff --git a/VERSION b/VERSION
index c302e31b8afc35c818fbba2f2d1a927f255a5861..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +0,0 @@
-v0.4.7.2.1
\ No newline at end of file
diff --git a/common/constants.go b/common/constants.go
index f967d066e29cd04796272fc57cf53d2ea5d69c40..04fb1b9a632852f9b646c0dda946a1eaef386e95 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
var RetryTimes = 0
-var RootUserEmail = ""
+//var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
diff --git a/common/go-channel.go b/common/go-channel.go
index 3fc86537255a3e26e2ae110a07d49eb15b7f8200..f9168fc4674e5f53e8168663da366aaa10649bb9 100644
--- a/common/go-channel.go
+++ b/common/go-channel.go
@@ -1,22 +1,9 @@
package common
import (
- "fmt"
- "runtime/debug"
"time"
)
-func SafeGoroutine(f func()) {
- go func() {
- defer func() {
- if r := recover(); r != nil {
- SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
- }
- }()
- f()
- }()
-}
-
func SafeSendBool(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
diff --git a/common/logger.go b/common/logger.go
index 93d557d8c6eebf09963e5728d8e345e844b280ab..86d15fa4db7ac46a96abe1b2dfeda9839978ed2e 100644
--- a/common/logger.go
+++ b/common/logger.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io"
"log"
@@ -80,9 +81,9 @@ func logHelper(ctx context.Context, level string, msg string) {
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true
- go func() {
+ gopool.Go(func() {
SetupLogger()
- }()
+ })
}
}
@@ -100,6 +101,14 @@ func LogQuota(quota int) string {
}
}
+// FormatQuota renders quota as a dollar amount when currency display is enabled, otherwise as the raw integer.
+func FormatQuota(quota int) string {
+	if !DisplayInCurrencyEnabled {
+		return fmt.Sprintf("%d", quota)
+	}
+	return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
+}
+
// LogJson 仅供测试使用 only for test
func LogJson(ctx context.Context, msg string, obj any) {
jsonStr, err := json.Marshal(obj)
diff --git a/common/model-ratio.go b/common/model-ratio.go
index bb94ad36c17293a9647710b079487d3ee448d7cb..542cd93c6965310e95da0c7e045d1b26da95c3ae 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -233,7 +233,11 @@ var (
modelRatioMapMutex = sync.RWMutex{}
)
-var CompletionRatio map[string]float64 = nil
+var (
+ CompletionRatio map[string]float64 = nil
+ CompletionRatioMutex = sync.RWMutex{}
+)
+
var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
"gpt-4o-gizmo-*": 3,
@@ -334,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
-func CompletionRatio2JSONString() string {
+func GetCompletionRatioMap() map[string]float64 {
+ CompletionRatioMutex.Lock()
+ defer CompletionRatioMutex.Unlock()
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
+ return CompletionRatio
+}
+
+func CompletionRatio2JSONString() string {
+ GetCompletionRatioMap()
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
SysError("error marshalling completion ratio: " + err.Error())
@@ -346,11 +357,15 @@ func CompletionRatio2JSONString() string {
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
+ CompletionRatioMutex.Lock()
+ defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
+ GetCompletionRatioMap()
+
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
@@ -476,24 +491,3 @@ func GetAudioCompletionRatio(name string) float64 {
}
return 2
}
-
-//func GetAudioPricePerMinute(name string) float64 {
-// if strings.HasPrefix(name, "gpt-4o-realtime") {
-// return 0.06
-// }
-// return 0.06
-//}
-//
-//func GetAudioCompletionPricePerMinute(name string) float64 {
-// if strings.HasPrefix(name, "gpt-4o-realtime") {
-// return 0.24
-// }
-// return 0.24
-//}
-
-func GetCompletionRatioMap() map[string]float64 {
- if CompletionRatio == nil {
- CompletionRatio = defaultCompletionRatio
- }
- return CompletionRatio
-}
diff --git a/constant/env.go b/constant/env.go
index 4135e8c7c3659d2a3487fade91da20db8330f073..bffbfeea5ba1efbafdd20e176d5d0dfe2402c021 100644
--- a/constant/env.go
+++ b/constant/env.go
@@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
+var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
+var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
+
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
@@ -44,5 +47,5 @@ func InitEnv() {
}
}
-// 是否生成初始令牌,默认关闭。
+// GenerateDefaultToken controls whether an initial token is generated; disabled by default.
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
diff --git a/constant/user_setting.go b/constant/user_setting.go
new file mode 100644
index 0000000000000000000000000000000000000000..a5b921b2ffbcf24c0f72a309df49ce280d8a466f
--- /dev/null
+++ b/constant/user_setting.go
@@ -0,0 +1,14 @@
+package constant
+
+var (
+ UserSettingNotifyType = "notify_type" // QuotaWarningType: quota warning (notification) type
+ UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold: quota warning threshold
+ UserSettingWebhookUrl = "webhook_url" // WebhookUrl: webhook address
+ UserSettingWebhookSecret = "webhook_secret" // WebhookSecret: webhook secret
+ UserSettingNotificationEmail = "notification_email" // NotificationEmail: notification email address
+)
+
+var (
+ NotifyTypeEmail = "email" // Email: notify by email
+ NotifyTypeWebhook = "webhook" // Webhook: notify via webhook
+)
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 7e74bec23dbbf2a050e363ace4020cd313f00909..4b0cc169cb05c1d71b0443eff84ebe2681786c5b 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -238,9 +238,7 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error {
- if common.RootUserEmail == "" {
- common.RootUserEmail = model.GetRootUserEmail()
- }
+
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
@@ -295,10 +293,7 @@ func testAllChannels(notify bool) error {
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
- err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
- if err != nil {
- common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
- }
+ service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}
})
return nil
diff --git a/controller/pricing.go b/controller/pricing.go
index 36caff9d1dbe3451c3e64ee370511e0fb6afc055..d7af5a4c8341308e7d5f4347ba0a405426e509ba 100644
--- a/controller/pricing.go
+++ b/controller/pricing.go
@@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) {
}
var group string
if exists {
- user, err := model.GetUserById(userId.(int), false)
+ user, err := model.GetUserCache(userId.(int))
if err == nil {
group = user.Group
}
diff --git a/controller/relay.go b/controller/relay.go
index d7e0f00ad4916a9eb0cb2d72976fef82de39514a..0f7394156ce8d9a57f63fe64ed8e390b2c8373ca 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
case relayconstant.RelayModeImagesGenerations:
- err = relay.ImageHelper(c, relayMode)
+ err = relay.ImageHelper(c)
case relayconstant.RelayModeAudioSpeech:
fallthrough
case relayconstant.RelayModeAudioTranslation:
diff --git a/controller/user.go b/controller/user.go
index 7146f00e2d904378899f9d2e5e20a616d257a240..51e6f955c417fe333d0ebfe3217599216cc29b07 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
+ "net/url"
"one-api/common"
"one-api/model"
"one-api/setting"
@@ -471,7 +472,7 @@ func GetUserModels(c *gin.Context) {
if err != nil {
id = c.GetInt("id")
}
- user, err := model.GetUserById(id, true)
+ user, err := model.GetUserCache(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -869,9 +870,6 @@ func EmailBind(c *gin.Context) {
})
return
}
- if user.Role == common.RoleRootUser {
- common.RootUserEmail = email
- }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -913,3 +911,115 @@ func TopUp(c *gin.Context) {
})
return
}
+
+type UpdateUserSettingRequest struct {
+ QuotaWarningType string `json:"notify_type"`
+ QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
+ WebhookUrl string `json:"webhook_url,omitempty"`
+ WebhookSecret string `json:"webhook_secret,omitempty"`
+ NotificationEmail string `json:"notification_email,omitempty"`
+}
+
+// UpdateUserSetting validates the submitted notification settings and stores them on the current user.
+func UpdateUserSetting(c *gin.Context) {
+	var req UpdateUserSettingRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "无效的参数",
+		})
+		return
+	}
+
+	// The notify type must be one of the supported channels (email or webhook).
+	if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "无效的预警类型",
+		})
+		return
+	}
+
+	// The warning threshold must be strictly positive.
+	if req.QuotaWarningThreshold <= 0 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "预警阈值必须大于0",
+		})
+		return
+	}
+
+	// Webhook notifications need a usable target URL.
+	if req.QuotaWarningType == constant.NotifyTypeWebhook {
+		if req.WebhookUrl == "" {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "Webhook地址不能为空",
+			})
+			return
+		}
+		// Require an absolute http(s) URL with a host; ParseRequestURI alone also accepts bare paths and other schemes.
+		if u, err := url.ParseRequestURI(req.WebhookUrl); err != nil || (u.Scheme != "http" && u.Scheme != "https") || u.Host == "" {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无效的Webhook地址",
+			})
+			return
+		}
+	}
+
+	// For email notifications, an optional override address must at least contain "@".
+	if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
+		if !strings.Contains(req.NotificationEmail, "@") {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": "无效的邮箱地址",
+			})
+			return
+		}
+	}
+
+	userId := c.GetInt("id")
+	user, err := model.GetUserById(userId, true)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	// Build the settings map; these two keys are always written.
+	settings := map[string]interface{}{
+		constant.UserSettingNotifyType:            req.QuotaWarningType,
+		constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
+	}
+
+	// Webhook type: store the URL, and the secret only when one was supplied.
+	if req.QuotaWarningType == constant.NotifyTypeWebhook {
+		settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
+		if req.WebhookSecret != "" {
+			settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
+		}
+	}
+
+	// Email type: store the notification address only when one was supplied.
+	if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
+		settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
+	}
+
+	// Serialize the settings onto the user and persist the record.
+	user.SetSetting(settings)
+	if err := user.Update(false); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "更新设置失败: " + err.Error(),
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "设置已更新",
+	})
+}
diff --git a/docker-compose.yml b/docker-compose.yml
index 640cf074788a21567c8e339e6d137d4f69a03720..0f23cea27171d75a32a64f05ab1d211bddb50e44 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -24,7 +24,7 @@ services:
- redis
- mysql
healthcheck:
- test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
+ test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
interval: 30s
timeout: 10s
retries: 3
diff --git a/dto/notify.go b/dto/notify.go
new file mode 100644
index 0000000000000000000000000000000000000000..b75cec70cae493900ddbd31271fc5059efc1003f
--- /dev/null
+++ b/dto/notify.go
@@ -0,0 +1,25 @@
+package dto
+
+type Notify struct {
+ Type string `json:"type"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+ Values []interface{} `json:"values"`
+}
+
+const ContentValueParam = "{{value}}"
+
+const (
+ NotifyTypeQuotaExceed = "quota_exceed"
+ NotifyTypeChannelUpdate = "channel_update"
+ NotifyTypeChannelTest = "channel_test"
+)
+
+func NewNotify(t string, title string, content string, values []interface{}) Notify {
+ return Notify{
+ Type: t,
+ Title: title,
+ Content: content,
+ Values: values,
+ }
+}
diff --git a/dto/openai_request.go b/dto/openai_request.go
index 0f6411bb7a0ed792d722eb0072ccaeb96c225a97..028e0286cdb26188cc8c20b351830e70a46cfae6 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -18,6 +18,8 @@ type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
+ Prefix any `json:"prefix,omitempty"`
+ Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
@@ -86,18 +88,20 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
}
type Message struct {
- Role string `json:"role"`
- Content json.RawMessage `json:"content"`
- Name *string `json:"name,omitempty"`
- Prefix *bool `json:"prefix,omitempty"`
- ReasoningContent string `json:"reasoning_content,omitempty"`
- ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
- ToolCallId string `json:"tool_call_id,omitempty"`
+ Role string `json:"role"`
+ Content json.RawMessage `json:"content"`
+ Name *string `json:"name,omitempty"`
+ Prefix *bool `json:"prefix,omitempty"`
+ ReasoningContent string `json:"reasoning_content,omitempty"`
+ ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
+ ToolCallId string `json:"tool_call_id,omitempty"`
+ parsedContent []MediaContent
+ parsedStringContent *string
}
type MediaContent struct {
Type string `json:"type"`
- Text string `json:"text"`
+ Text string `json:"text,omitempty"`
ImageUrl any `json:"image_url,omitempty"`
InputAudio any `json:"input_audio,omitempty"`
}
@@ -146,6 +150,9 @@ func (m *Message) SetToolCalls(toolCalls any) {
}
func (m *Message) StringContent() string {
+ if m.parsedStringContent != nil {
+ return *m.parsedStringContent
+ }
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return stringContent
@@ -156,78 +163,113 @@ func (m *Message) StringContent() string {
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
+ m.parsedStringContent = &content
+ m.parsedContent = nil
+}
+
+func (m *Message) SetMediaContent(content []MediaContent) {
+ jsonContent, _ := json.Marshal(content)
+ m.Content = jsonContent
+ m.parsedContent = nil
+ m.parsedStringContent = nil
}
func (m *Message) IsStringContent() bool {
+ if m.parsedStringContent != nil {
+ return true
+ }
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+ m.parsedStringContent = &stringContent
return true
}
return false
}
func (m *Message) ParseContent() []MediaContent {
+ if m.parsedContent != nil {
+ return m.parsedContent
+ }
+
var contentList []MediaContent
+
+ // 先尝试解析为字符串
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
- contentList = append(contentList, MediaContent{
+ contentList = []MediaContent{{
Type: ContentTypeText,
Text: stringContent,
- })
+ }}
+ m.parsedContent = contentList
return contentList
}
- var arrayContent []json.RawMessage
+
+ // 尝试解析为数组
+ var arrayContent []map[string]interface{}
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
- var contentMap map[string]any
- if err := json.Unmarshal(contentItem, &contentMap); err != nil {
+ contentType, ok := contentItem["type"].(string)
+ if !ok {
continue
}
- switch contentMap["type"] {
+
+ switch contentType {
case ContentTypeText:
- if subStr, ok := contentMap["text"].(string); ok {
+ if text, ok := contentItem["text"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeText,
- Text: subStr,
+ Text: text,
})
}
+
case ContentTypeImageURL:
- if subObj, ok := contentMap["image_url"].(map[string]any); ok {
- detail, ok := subObj["detail"]
- if ok {
- subObj["detail"] = detail.(string)
- } else {
- subObj["detail"] = "high"
- }
+ imageUrl := contentItem["image_url"]
+ switch v := imageUrl.(type) {
+ case string:
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
- Url: subObj["url"].(string),
- Detail: subObj["detail"].(string),
- },
- })
- } else if url, ok := contentMap["image_url"].(string); ok {
- contentList = append(contentList, MediaContent{
- Type: ContentTypeImageURL,
- ImageUrl: MessageImageUrl{
- Url: url,
+ Url: v,
Detail: "high",
},
})
+ case map[string]interface{}:
+ url, ok1 := v["url"].(string)
+ detail, ok2 := v["detail"].(string)
+ if !ok2 {
+ detail = "high"
+ }
+ if ok1 {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeImageURL,
+ ImageUrl: MessageImageUrl{
+ Url: url,
+ Detail: detail,
+ },
+ })
+ }
}
+
case ContentTypeInputAudio:
- if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
- contentList = append(contentList, MediaContent{
- Type: ContentTypeInputAudio,
- InputAudio: MessageInputAudio{
- Data: subObj["data"].(string),
- Format: subObj["format"].(string),
- },
- })
+ if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
+ data, ok1 := audioData["data"].(string)
+ format, ok2 := audioData["format"].(string)
+ if ok1 && ok2 {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeInputAudio,
+ InputAudio: MessageInputAudio{
+ Data: data,
+ Format: format,
+ },
+ })
+ }
}
}
}
- return contentList
}
- return nil
+
+ if len(contentList) > 0 {
+ m.parsedContent = contentList
+ }
+ return contentList
}
diff --git a/dto/openai_response.go b/dto/openai_response.go
index 2e0e2221e1ce5e27d9aa9ae4afb3abf7021a3380..febf01ff0d58fbd66660c4d9a4258e449538b743 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -62,9 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
}
type ChatCompletionsStreamResponseChoiceDelta struct {
- Content *string `json:"content,omitempty"`
- Role string `json:"role,omitempty"`
- ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+ Content *string `json:"content,omitempty"`
+ ReasoningContent *string `json:"reasoning_content,omitempty"`
+ Role string `json:"role,omitempty"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -78,6 +79,13 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
return *c.Content
}
+// GetReasoningContent returns the delta's reasoning content, or "" when the field is nil.
+func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
+	if text := c.ReasoningContent; text != nil {
+		return *text
+	}
+	return ""
+}
+
type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
diff --git a/main.go b/main.go
index 68dae8f496c4cc455a4327a12b64094f67cfa4f4..495057cf1efc4bbf842378817c9b38dc11e41814 100644
--- a/main.go
+++ b/main.go
@@ -119,9 +119,9 @@ func main() {
}
if os.Getenv("ENABLE_PPROF") == "true" {
- go func() {
+ gopool.Go(func() {
log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
- }()
+ })
go common.Monitor()
common.SysLog("pprof enabled")
}
diff --git a/middleware/distributor.go b/middleware/distributor.go
index c90f3e5eeb286ae484668f07a27a989c2bd61ccf..e0f9342a84f2c3f4858af8789bdd6e3592ed7048 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -135,17 +135,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
- abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return nil, false, err
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
- abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return nil, false, fmt.Errorf(mjErr.Description)
}
if midjourneyModel == "" {
if !success {
- abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else {
// task fetch, task fetch by condition, notify
@@ -170,7 +167,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
- abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error())
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
diff --git a/model/option.go b/model/option.go
index 0c4114a42256be3bbb4b0d2fdbc11b2bc3fd41ae..24935c69d1f7d0dac57ab4d2d4b72ba0a46f0a4f 100644
--- a/model/option.go
+++ b/model/option.go
@@ -84,7 +84,7 @@ func InitOptionMap() {
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
- common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
+ common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -306,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
- case "PreConsumedQuota":
+ case "ShouldPreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
diff --git a/model/token.go b/model/token.go
index 3abd22cf6e45c79e351e935fcde62e906e04835a..8587ea62a9f3e24c6c9f8f8c24f095d08757af5b 100644
--- a/model/token.go
+++ b/model/token.go
@@ -3,13 +3,11 @@ package model
import (
"errors"
"fmt"
- "github.com/bytedance/gopkg/util/gopool"
- "gorm.io/gorm"
"one-api/common"
- relaycommon "one-api/relay/common"
- "one-api/setting"
- "strconv"
"strings"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "gorm.io/gorm"
)
type Token struct {
@@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
).Error
return err
}
-
-func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
- if quota < 0 {
- return errors.New("quota 不能为负数!")
- }
- if relayInfo.IsPlayground {
- return nil
- }
- //if relayInfo.TokenUnlimited {
- // return nil
- //}
- token, err := GetTokenById(relayInfo.TokenId)
- if err != nil {
- return err
- }
- if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
- return errors.New("令牌额度不足")
- }
- err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
- if err != nil {
- return err
- }
- return nil
-}
-
-func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
-
- if quota > 0 {
- err = DecreaseUserQuota(relayInfo.UserId, quota)
- } else {
- err = IncreaseUserQuota(relayInfo.UserId, -quota)
- }
- if err != nil {
- return err
- }
-
- if !relayInfo.IsPlayground {
- if quota > 0 {
- err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
- } else {
- err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
- }
- if err != nil {
- return err
- }
- }
-
- if sendEmail {
- if (quota + preConsumedQuota) != 0 {
- quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
- noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
- if quotaTooLow || noMoreQuota {
- go func() {
- email, err := GetUserEmail(relayInfo.UserId)
- if err != nil {
- common.SysError("failed to fetch user email: " + err.Error())
- }
- prompt := "您的额度即将用尽"
- if noMoreQuota {
- prompt = "您的额度已用尽"
- }
- if email != "" {
- topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
- err = common.SendEmail(prompt, email,
- fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
- if err != nil {
- common.SysError("failed to send email" + err.Error())
- }
- common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
- }
- }()
- }
- }
- }
-
- return nil
-}
diff --git a/model/token_cache.go b/model/token_cache.go
index 99b762f513ecb1aa42b2a48be104e2fb5993ca09..0fe02fea59058f4b5651bd927d7504d098f3d984 100644
--- a/model/token_cache.go
+++ b/model/token_cache.go
@@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error {
func cacheGetTokenByKey(key string) (*Token, error) {
hmacKey := common.GenerateHMAC(key)
if !common.RedisEnabled {
- return nil, nil
+ return nil, fmt.Errorf("redis is not enabled")
}
var token Token
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
diff --git a/model/user.go b/model/user.go
index 95123c2192f2c9ac4b27db729f964e8f5ec75a60..427b0625f4b2b998b1dd452e7bf969d984e21463 100644
--- a/model/user.go
+++ b/model/user.go
@@ -1,6 +1,7 @@
package model
import (
+ "encoding/json"
"errors"
"fmt"
"one-api/common"
@@ -38,6 +39,20 @@ type User struct {
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
+ Setting string `json:"setting" gorm:"type:text;column:setting"`
+}
+
+func (user *User) ToBaseUser() *UserBase {
+ cache := &UserBase{
+ Id: user.Id,
+ Group: user.Group,
+ Quota: user.Quota,
+ Status: user.Status,
+ Username: user.Username,
+ Setting: user.Setting,
+ Email: user.Email,
+ }
+ return cache
}
func (user *User) GetAccessToken() string {
@@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
+func (user *User) GetSetting() map[string]interface{} {
+ if user.Setting == "" {
+ return nil
+ }
+ return common.StrToMap(user.Setting)
+}
+
+func (user *User) SetSetting(setting map[string]interface{}) {
+ settingBytes, err := json.Marshal(setting)
+ if err != nil {
+ common.SysError("failed to marshal setting: " + err.Error())
+ return
+ }
+ user.Setting = string(settingBytes)
+}
+
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User
@@ -315,8 +346,8 @@ func (user *User) Update(updatePassword bool) error {
return err
}
- // 更新缓存
- return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
+ // Update cache
+ return updateUserCache(*user)
}
func (user *User) Edit(updatePassword bool) error {
@@ -344,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error {
return err
}
- // 更新缓存
- return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
+ // Update cache
+ return updateUserCache(*user)
}
func (user *User) Delete() error {
@@ -371,8 +402,8 @@ func (user *User) HardDelete() error {
// ValidateAndFill check password & user status
func (user *User) ValidateAndFill() (err error) {
// When querying with struct, GORM will only query with non-zero fields,
- // that means if your field’s value is 0, '', false or other zero values,
- // it won’t be used to build query conditions
+ // that means if your field's value is 0, '', false or other zero values,
+ // it won't be used to build query conditions
password := user.Password
username := strings.TrimSpace(user.Username)
if username == "" || password == "" {
@@ -531,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
return quota, nil
}
// Don't return error - fall through to DB
- //common.SysError("failed to get user quota from cache: " + err.Error())
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
@@ -580,6 +610,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
return group, nil
}
+// GetUserSetting gets setting from Redis first, falls back to DB if needed
+func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
+ var setting string
+ defer func() {
+ // Update Redis cache asynchronously on successful DB read
+ if shouldUpdateRedis(fromDB, err) {
+ gopool.Go(func() {
+ if err := updateUserSettingCache(id, setting); err != nil {
+ common.SysError("failed to update user setting cache: " + err.Error())
+ }
+ })
+ }
+ }()
+ if !fromDB && common.RedisEnabled {
+ setting, err := getUserSettingCache(id)
+ if err == nil {
+ return setting, nil
+ }
+ // Don't return error - fall through to DB
+ }
+ fromDB = true
+ err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
+ if err != nil {
+ return map[string]interface{}{}, err
+ }
+
+ return common.StrToMap(setting), nil
+}
+
func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
@@ -641,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
}
}
-func GetRootUserEmail() (email string) {
- DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
- return email
+//func GetRootUserEmail() (email string) {
+// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
+// return email
+//}
+
+func GetRootUser() (user *User) {
+ DB.Where("role = ?", common.RoleRootUser).First(&user)
+ return user
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
@@ -725,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
return !errors.Is(err, gorm.ErrRecordNotFound)
}
-func (u *User) FillUserByLinuxDOId() error {
- if u.LinuxDOId == "" {
+func (user *User) FillUserByLinuxDOId() error {
+ if user.LinuxDOId == "" {
return errors.New("linux do id is empty")
}
- err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error
+ err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
return err
}
diff --git a/model/user_cache.go b/model/user_cache.go
index 9dc7e899ebb2ddf99698dc1c5abd2e325ce787f8..cc08288d681e3cf77be4c071079101a8cc0726fa 100644
--- a/model/user_cache.go
+++ b/model/user_cache.go
@@ -1,206 +1,213 @@
package model
import (
+ "encoding/json"
"fmt"
"one-api/common"
"one-api/constant"
- "strconv"
"time"
+
+ "github.com/bytedance/gopkg/util/gopool"
)
-// Change UserCache struct to userCache
-type userCache struct {
+// UserBase struct remains the same as it represents the cached data structure
+type UserBase struct {
Id int `json:"id"`
Group string `json:"group"`
+ Email string `json:"email"`
Quota int `json:"quota"`
Status int `json:"status"`
- Role int `json:"role"`
Username string `json:"username"`
+ Setting string `json:"setting"`
}
-// Rename all exported functions to private ones
-// invalidateUserCache clears all user related cache
-func invalidateUserCache(userId int) error {
- if !common.RedisEnabled {
+func (user *UserBase) GetSetting() map[string]interface{} {
+ if user.Setting == "" {
return nil
}
-
- keys := []string{
- fmt.Sprintf(constant.UserGroupKeyFmt, userId),
- fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
- fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
- fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
- }
-
- for _, key := range keys {
- if err := common.RedisDel(key); err != nil {
- return fmt.Errorf("failed to delete cache key %s: %w", key, err)
- }
- }
- return nil
+ return common.StrToMap(user.Setting)
}
-// updateUserGroupCache updates user group cache
-func updateUserGroupCache(userId int, group string) error {
- if !common.RedisEnabled {
- return nil
+func (user *UserBase) SetSetting(setting map[string]interface{}) {
+ settingBytes, err := json.Marshal(setting)
+ if err != nil {
+ common.SysError("failed to marshal setting: " + err.Error())
+ return
}
- return common.RedisSet(
- fmt.Sprintf(constant.UserGroupKeyFmt, userId),
- group,
- time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
- )
+ user.Setting = string(settingBytes)
}
-// updateUserQuotaCache updates user quota cache
-func updateUserQuotaCache(userId int, quota int) error {
- if !common.RedisEnabled {
- return nil
- }
- return common.RedisSet(
- fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
- fmt.Sprintf("%d", quota),
- time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
- )
+// getUserCacheKey returns the key for user cache
+func getUserCacheKey(userId int) string {
+ return fmt.Sprintf("user:%d", userId)
}
-// updateUserStatusCache updates user status cache
-func updateUserStatusCache(userId int, userEnabled bool) error {
+// invalidateUserCache clears user cache
+func invalidateUserCache(userId int) error {
if !common.RedisEnabled {
return nil
}
- enabled := "0"
- if userEnabled {
- enabled = "1"
- }
- return common.RedisSet(
- fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
- enabled,
- time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
- )
+ return common.RedisHDelObj(getUserCacheKey(userId))
}
-// updateUserNameCache updates username cache
-func updateUserNameCache(userId int, username string) error {
+// updateUserCache updates all user cache fields using hash
+func updateUserCache(user User) error {
if !common.RedisEnabled {
return nil
}
- return common.RedisSet(
- fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
- username,
+
+ return common.RedisHSetObj(
+ getUserCacheKey(user.Id),
+ user.ToBaseUser(),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
-// updateUserCache updates all user cache fields
-func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
- if !common.RedisEnabled {
- return nil
+// GetUserCache gets complete user cache from hash
+func GetUserCache(userId int) (userCache *UserBase, err error) {
+ var user *User
+ var fromDB bool
+ defer func() {
+ // Update Redis cache asynchronously on successful DB read
+ if shouldUpdateRedis(fromDB, err) && user != nil {
+ gopool.Go(func() {
+ if err := updateUserCache(*user); err != nil {
+ common.SysError("failed to update user status cache: " + err.Error())
+ }
+ })
+ }
+ }()
+
+ // Try getting from Redis first
+ userCache, err = cacheGetUserBase(userId)
+ if err == nil {
+ return userCache, nil
}
- if err := updateUserGroupCache(userId, userGroup); err != nil {
- return fmt.Errorf("update group cache: %w", err)
+ // If Redis fails, get from DB
+ fromDB = true
+ user, err = GetUserById(userId, false)
+ if err != nil {
+ return nil, err // Return nil and error if DB lookup fails
}
- if err := updateUserQuotaCache(userId, quota); err != nil {
- return fmt.Errorf("update quota cache: %w", err)
+ // Create cache object from user data
+ userCache = &UserBase{
+ Id: user.Id,
+ Group: user.Group,
+ Quota: user.Quota,
+ Status: user.Status,
+ Username: user.Username,
+ Setting: user.Setting,
+ Email: user.Email,
}
- if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
- return fmt.Errorf("update status cache: %w", err)
+ return userCache, nil
+}
+
+// cacheGetUserBase reads the cached UserBase stored under getUserCacheKey(userId).
+//
+// It returns an error when Redis is disabled or when common.RedisHGetObj fails;
+// on success it returns a pointer to the decoded value.
+func cacheGetUserBase(userId int) (*UserBase, error) {
+ if !common.RedisEnabled {
+ return nil, fmt.Errorf("redis is not enabled")
}
+ var userCache UserBase
+ // Try getting from Redis first
+ err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
+ if err != nil {
+ return nil, err
+ }
+ return &userCache, nil
+}
- if err := updateUserNameCache(userId, username); err != nil {
- return fmt.Errorf("update username cache: %w", err)
+// Add atomic quota operations using hash fields
+func cacheIncrUserQuota(userId int, delta int64) error {
+ if !common.RedisEnabled {
+ return nil
}
+ return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
+}
- return nil
+func cacheDecrUserQuota(userId int, delta int64) error {
+ return cacheIncrUserQuota(userId, -delta)
}
-// getUserGroupCache gets user group from cache
+// Helper functions to get individual fields if needed
func getUserGroupCache(userId int) (string, error) {
- if !common.RedisEnabled {
- return "", nil
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return "", err
}
- return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
+ return cache.Group, nil
}
-// getUserQuotaCache gets user quota from cache
func getUserQuotaCache(userId int) (int, error) {
- if !common.RedisEnabled {
- return 0, nil
- }
- quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
+ cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
- return strconv.Atoi(quotaStr)
+ return cache.Quota, nil
}
-// getUserStatusCache gets user status from cache
func getUserStatusCache(userId int) (int, error) {
- if !common.RedisEnabled {
- return 0, nil
- }
- statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
+ cache, err := GetUserCache(userId)
if err != nil {
return 0, err
}
- return strconv.Atoi(statusStr)
+ return cache.Status, nil
}
-// getUserNameCache gets username from cache
func getUserNameCache(userId int) (string, error) {
- if !common.RedisEnabled {
- return "", nil
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return "", err
}
- return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
+ return cache.Username, nil
}
-// getUserCache gets complete user cache
-func getUserCache(userId int) (*userCache, error) {
- if !common.RedisEnabled {
- return nil, nil
- }
-
- group, err := getUserGroupCache(userId)
+func getUserSettingCache(userId int) (map[string]interface{}, error) {
+ setting := make(map[string]interface{})
+ cache, err := GetUserCache(userId)
if err != nil {
- return nil, fmt.Errorf("get group cache: %w", err)
+ return setting, err
}
+ return cache.GetSetting(), nil
+}
- quota, err := getUserQuotaCache(userId)
- if err != nil {
- return nil, fmt.Errorf("get quota cache: %w", err)
+// New functions for individual field updates
+func updateUserStatusCache(userId int, status bool) error {
+ if !common.RedisEnabled {
+ return nil
}
-
- status, err := getUserStatusCache(userId)
- if err != nil {
- return nil, fmt.Errorf("get status cache: %w", err)
+ statusInt := common.UserStatusEnabled
+ if !status {
+ statusInt = common.UserStatusDisabled
}
+ return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
+}
- username, err := getUserNameCache(userId)
- if err != nil {
- return nil, fmt.Errorf("get username cache: %w", err)
+func updateUserQuotaCache(userId int, quota int) error {
+ if !common.RedisEnabled {
+ return nil
}
+ return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
+}
- return &userCache{
- Id: userId,
- Group: group,
- Quota: quota,
- Status: status,
- Username: username,
- }, nil
+func updateUserGroupCache(userId int, group string) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
}
-// Add atomic quota operations
-func cacheIncrUserQuota(userId int, delta int64) error {
+func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
- key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
- return common.RedisIncr(key, delta)
+ return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
}
-func cacheDecrUserQuota(userId int, delta int64) error {
- return cacheIncrUserQuota(userId, -delta)
+func updateUserSettingCache(userId int, setting string) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
}
diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go
index 754000980249221a0aac7c0a4f06950650421894..5c2eadc20f5efe63333782eebd1449f39f236a6b 100644
--- a/relay/channel/cloudflare/adaptor.go
+++ b/relay/channel/cloudflare/adaptor.go
@@ -4,13 +4,14 @@ import (
"bytes"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
+
+ "github.com/gin-gonic/gin"
)
type Adaptor struct {
diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go
index 14dd74f03d762c8c65a9487d92322764839bee59..d779ee651772b06b471f56c6c0c08ac4078f8010 100644
--- a/relay/channel/deepseek/adaptor.go
+++ b/relay/channel/deepseek/adaptor.go
@@ -10,6 +10,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
)
type Adaptor struct {
@@ -29,7 +30,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+ switch info.RelayMode {
+ case constant.RelayModeCompletions:
+ return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil
+ default:
+ return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+ }
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 681e9988a7a42d1fe02738e7e6fd277adec62d5f..32513c42ab99813de3258cbd6d5bd88ffae66f61 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -1,15 +1,21 @@
package gemini
import (
+ "encoding/json"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
+ "one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
+ "one-api/service"
+
+ "strings"
+
+ "github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -21,8 +27,36 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
+ if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
+ return nil, errors.New("not supported model for image generation")
+ }
+
+ // convert size to aspect ratio
+ aspectRatio := "1:1" // default aspect ratio
+ switch request.Size {
+ case "1024x1024":
+ aspectRatio = "1:1"
+ case "1024x1792":
+ aspectRatio = "9:16"
+ case "1792x1024":
+ aspectRatio = "16:9"
+ }
+
+ // build gemini imagen request
+ geminiRequest := GeminiImageRequest{
+ Instances: []GeminiImageInstance{
+ {
+ Prompt: request.Prompt,
+ },
+ },
+ Parameters: GeminiImageParameters{
+ SampleCount: request.N,
+ AspectRatio: aspectRatio,
+ PersonGeneration: "allow_adult", // default allow adult
+ },
+ }
+
+ return geminiRequest, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -40,6 +74,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
+ if strings.HasPrefix(info.UpstreamModelName, "imagen") {
+ return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
+ }
+
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent?alt=sse"
@@ -73,12 +111,15 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
-
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+ if strings.HasPrefix(info.UpstreamModelName, "imagen") {
+ return GeminiImageHandler(c, resp, info)
+ }
+
if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info)
} else {
@@ -87,6 +128,60 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return
}
+// GeminiImageHandler converts an Imagen predict response into an OpenAI-style
+// image response, writes it to the client and returns the usage to account for.
+//
+// Predictions that carry a RaiFilteredReason are dropped from the output. An
+// upstream response without any predictions is reported as a "no_images" error.
+// Usage is a fixed 258 prompt tokens per returned image; completion tokens are
+// always 0.
+func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	// Close the upstream body on every path, including a failed read.
+	defer func() { _ = resp.Body.Close() }()
+
+	responseBody, readErr := io.ReadAll(resp.Body)
+	if readErr != nil {
+		return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
+	}
+
+	var geminiResponse GeminiImageResponse
+	if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
+		return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	if len(geminiResponse.Predictions) == 0 {
+		return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
+	}
+
+	// Convert to the OpenAI image response format.
+	openAIResponse := dto.ImageResponse{
+		Created: common.GetTimestamp(),
+		Data:    make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
+	}
+
+	for _, prediction := range geminiResponse.Predictions {
+		if prediction.RaiFilteredReason != "" {
+			continue // skip filtered image
+		}
+		openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
+			B64Json: prediction.BytesBase64Encoded,
+		})
+	}
+
+	jsonResponse, jsonErr := json.Marshal(openAIResponse)
+	if jsonErr != nil {
+		return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
+	}
+
+	// Relay the upstream status code together with the converted body.
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+
+	// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
+	// each image has fixed 258 tokens
+	const imageTokens = 258
+	// Only images that survived filtering are counted.
+	generatedImages := len(openAIResponse.Data)
+
+	usage = &dto.Usage{
+		PromptTokens:     imageTokens * generatedImages, // each generated image has fixed 258 tokens
+		CompletionTokens: 0,                             // image generation does not calculate completion tokens
+		TotalTokens:      imageTokens * generatedImages,
+	}
+
+	return usage, nil
+}
+
func (a *Adaptor) GetModelList() []string {
return ModelList
}
diff --git a/relay/channel/gemini/constant.go b/relay/channel/gemini/constant.go
index 9651bd607071007d37421456affa1843ddbfc3af..b7c1f0cf8758c55ade55ccacabfa6e61c978650d 100644
--- a/relay/channel/gemini/constant.go
+++ b/relay/channel/gemini/constant.go
@@ -16,6 +16,8 @@ var ModelList = []string{
"gemini-2.0-pro-exp",
// thinking exp
"gemini-2.0-flash-thinking-exp",
+ // imagen models
+ "imagen-3.0-generate-002",
}
var ChannelName = "google gemini"
diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go
index 08a5db8486e6f95a2adf39af6d400fe27dd24a38..bbcb1248d4825b440447c72583ab35f7e3acf3a3 100644
--- a/relay/channel/gemini/dto.go
+++ b/relay/channel/gemini/dto.go
@@ -109,3 +109,30 @@ type GeminiUsageMetadata struct {
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
+
+// Imagen related structs
+
+// GeminiImageRequest is the request body that ConvertImageRequest builds for
+// "imagen" models: a list of prompt instances plus generation parameters.
+type GeminiImageRequest struct {
+ Instances []GeminiImageInstance `json:"instances"`
+ Parameters GeminiImageParameters `json:"parameters"`
+}
+
+// GeminiImageInstance is one entry of GeminiImageRequest.Instances; it carries
+// the text prompt taken from the incoming image request.
+type GeminiImageInstance struct {
+ Prompt string `json:"prompt"`
+}
+
+// GeminiImageParameters holds the generation options of a GeminiImageRequest.
+// Every field is tagged omitempty, so zero values are left out of the JSON.
+type GeminiImageParameters struct {
+ // SampleCount is filled from the request's N by ConvertImageRequest.
+ SampleCount int `json:"sampleCount,omitempty"`
+ // AspectRatio is derived from the request's Size by ConvertImageRequest
+ // ("1:1", "9:16" or "16:9").
+ AspectRatio string `json:"aspectRatio,omitempty"`
+ // PersonGeneration is set to "allow_adult" by ConvertImageRequest.
+ PersonGeneration string `json:"personGeneration,omitempty"`
+}
+
+// GeminiImageResponse is the upstream response body decoded by
+// GeminiImageHandler; each prediction corresponds to one candidate image.
+type GeminiImageResponse struct {
+ Predictions []GeminiImagePrediction `json:"predictions"`
+}
+
+// GeminiImagePrediction is a single entry of GeminiImageResponse.Predictions.
+type GeminiImagePrediction struct {
+ // MimeType is decoded but not read by GeminiImageHandler.
+ MimeType string `json:"mimeType"`
+ // BytesBase64Encoded is copied into the B64Json field of the OpenAI-style response.
+ BytesBase64Encoded string `json:"bytesBase64Encoded"`
+ // RaiFilteredReason, when non-empty, makes GeminiImageHandler skip the prediction.
+ RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
+ // SafetyAttributes is decoded as-is and not read by GeminiImageHandler.
+ SafetyAttributes any `json:"safetyAttributes,omitempty"`
+}
diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go
index c99e539617371781f4e51590928611c2253f7506..fcea169a8d74aaf130f6d6138f76d479cbb9989a 100644
--- a/relay/channel/mistral/adaptor.go
+++ b/relay/channel/mistral/adaptor.go
@@ -41,9 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
- mistralReq := requestOpenAI2Mistral(*request)
- //common.LogJson(c, "body", mistralReq)
- return mistralReq, nil
+ return requestOpenAI2Mistral(request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -55,7 +53,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
-
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go
index 04add067517762a02126b831f4b551485f3c69ec..8987b8f0836d15b409176b2aa15855dd13f5c14a 100644
--- a/relay/channel/mistral/text.go
+++ b/relay/channel/mistral/text.go
@@ -1,25 +1,21 @@
package mistral
import (
- "encoding/json"
"one-api/dto"
)
-func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
+func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
- if !message.IsStringContent() {
- mediaMessages := message.ParseContent()
- for j, mediaMessage := range mediaMessages {
- if mediaMessage.Type == dto.ContentTypeImageURL {
- imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
- mediaMessage.ImageUrl = imageUrl.Url
- mediaMessages[j] = mediaMessage
- }
+ mediaMessages := message.ParseContent()
+ for j, mediaMessage := range mediaMessages {
+ if mediaMessage.Type == dto.ContentTypeImageURL {
+ imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
+ mediaMessage.ImageUrl = imageUrl.Url
+ mediaMessages[j] = mediaMessage
}
- messageRaw, _ := json.Marshal(mediaMessages)
- message.Content = messageRaw
}
+ message.SetMediaContent(mediaMessages)
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index 36889cb8d1728674f9fee57d92e52c73dc14644b..7e1c62377f27a7ede4dfc55d638474561d80da3d 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -39,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -46,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
- return requestOpenAI2Ollama(*request), nil
+ return requestOpenAI2Ollama(*request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go
index 080191151d9755ce63d8dc3ee8c698f6ae75f7db..a954c607a69104c4ae35cde8aaf2fb8c66a96460 100644
--- a/relay/channel/ollama/dto.go
+++ b/relay/channel/ollama/dto.go
@@ -3,18 +3,21 @@ package ollama
import "one-api/dto"
type OllamaRequest struct {
- Model string `json:"model,omitempty"`
- Messages []dto.Message `json:"messages,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- Seed float64 `json:"seed,omitempty"`
- Topp float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- Stop any `json:"stop,omitempty"`
- Tools []dto.ToolCall `json:"tools,omitempty"`
- ResponseFormat any `json:"response_format,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ Model string `json:"model,omitempty"`
+ Messages []dto.Message `json:"messages,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ Topp float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Stop any `json:"stop,omitempty"`
+ Tools []dto.ToolCall `json:"tools,omitempty"`
+ ResponseFormat any `json:"response_format,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ Suffix any `json:"suffix,omitempty"`
+ StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
}
type Options struct {
@@ -35,7 +38,7 @@ type OllamaEmbeddingRequest struct {
}
type OllamaEmbeddingResponse struct {
- Error string `json:"error,omitempty"`
- Model string `json:"model"`
+ Error string `json:"error,omitempty"`
+ Model string `json:"model"`
Embedding [][]float64 `json:"embeddings,omitempty"`
}
diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go
index 4ecdd19bef927d9941716199a9db31dab9e75808..8b53fbfb56ae899c1c40644ccfae49f46c6b0089 100644
--- a/relay/channel/ollama/relay-ollama.go
+++ b/relay/channel/ollama/relay-ollama.go
@@ -9,14 +9,36 @@ import (
"net/http"
"one-api/dto"
"one-api/service"
+ "strings"
)
-func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
+func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
+ if !message.IsStringContent() {
+ mediaMessages := message.ParseContent()
+ for j, mediaMessage := range mediaMessages {
+ if mediaMessage.Type == dto.ContentTypeImageURL {
+ imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
+ // check if not base64
+ if strings.HasPrefix(imageUrl.Url, "http") {
+ fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
+ if err != nil {
+ return nil, err
+ }
+ imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
+ }
+ mediaMessage.ImageUrl = imageUrl
+ mediaMessages[j] = mediaMessage
+ }
+ }
+ message.SetMediaContent(mediaMessages)
+ }
messages = append(messages, dto.Message{
- Role: message.Role,
- Content: message.Content,
+ Role: message.Role,
+ Content: message.Content,
+ ToolCalls: message.ToolCalls,
+ ToolCallId: message.ToolCallId,
})
}
str, ok := request.Stop.(string)
@@ -39,7 +61,10 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
- }
+ Prompt: request.Prompt,
+ StreamOptions: request.StreamOptions,
+ Suffix: request.Suffix,
+ }, nil
}
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index e94399eaa04417e5dc3a7ea67379ee57286f76b4..f927fa74f45f69aed47d5774481230604f29c9fc 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -119,7 +119,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
- if strings.HasPrefix(request.Model, "o3") {
+ if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
request.Temperature = nil
}
if strings.HasSuffix(request.Model, "-high") {
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 6c9359f334bd7af5bff02272a123a1d55d33670a..33cdea48639f8625e25d14662ae862900636e9eb 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -87,6 +87,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
+ if common.DebugEnabled {
+ println(data)
+ }
if len(data) < 6 { // ignore blank line or wrong format
continue
}
@@ -162,6 +165,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
+ responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -182,6 +186,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
+ responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
@@ -273,7 +278,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
- ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
+ ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent, model)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go
index c02d18a3e84fc016e92a95108e4d4b27dfa3c5c9..797f02442be9b9a233e72b086e708a0e70d91c03 100644
--- a/relay/channel/siliconflow/adaptor.go
+++ b/relay/channel/siliconflow/adaptor.go
@@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+ } else if info.RelayMode == constant.RelayModeCompletions {
+ return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
@@ -72,6 +74,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
+ case constant.RelayModeCompletions:
+ if info.IsStream {
+ err, usage = openai.OaiStreamHandler(c, resp, info)
+ } else {
+ err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+ }
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go
index 06f306f6b0c78a9bf38051d493a8f3437e42658c..97d82c718eaa86f6c995f35b558e4cf253333c12 100644
--- a/relay/channel/zhipu_4v/relay-zhipu_v4.go
+++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go
@@ -90,8 +90,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
mediaMessages[j] = mediaMessage
}
}
- messageRaw, _ := json.Marshal(mediaMessages)
- message.Content = messageRaw
+ message.SetMediaContent(mediaMessages)
}
messages = append(messages, dto.Message{
Role: message.Role,
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index 4978f84f09088834b7405b485d2a3ddb2aee7d20..1f4a3a42b6fd99f6344a9a11a86516c239d824e0 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -13,24 +13,24 @@ import (
)
type RelayInfo struct {
- ChannelType int
- ChannelId int
- TokenId int
- TokenKey string
- UserId int
- Group string
- TokenUnlimited bool
- StartTime time.Time
- FirstResponseTime time.Time
- setFirstResponse bool
- ApiType int
- IsStream bool
- IsPlayground bool
- UsePrice bool
- RelayMode int
- UpstreamModelName string
- OriginModelName string
- RecodeModelName string
+ ChannelType int
+ ChannelId int
+ TokenId int
+ TokenKey string
+ UserId int
+ Group string
+ TokenUnlimited bool
+ StartTime time.Time
+ FirstResponseTime time.Time
+ setFirstResponse bool
+ ApiType int
+ IsStream bool
+ IsPlayground bool
+ UsePrice bool
+ RelayMode int
+ UpstreamModelName string
+ OriginModelName string
+ //RecodeModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
@@ -39,6 +39,7 @@ type RelayInfo struct {
BaseUrl string
SupportStreamOptions bool
ShouldIncludeUsage bool
+ IsModelMapped bool
ClientWs *websocket.Conn
TargetWs *websocket.Conn
InputAudioFormat string
@@ -50,6 +51,18 @@ type RelayInfo struct {
ChannelSetting map[string]interface{}
}
+// streamSupportedChannels defines the channel types that support stream
+// options; GenRelayInfo sets SupportStreamOptions for channels listed here.
+var streamSupportedChannels = map[int]bool{
+ common.ChannelTypeOpenAI: true,
+ common.ChannelTypeAnthropic: true,
+ common.ChannelTypeAws: true,
+ common.ChannelTypeGemini: true,
+ common.ChannelCloudflare: true,
+ common.ChannelTypeAzure: true,
+ common.ChannelTypeVolcEngine: true,
+ common.ChannelTypeOllama: true,
+}
+
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
info := GenRelayInfo(c)
info.ClientWs = ws
@@ -89,12 +102,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"),
- RecodeModelName: c.GetString("recode_model"),
- ApiType: apiType,
- ApiVersion: c.GetString("api_version"),
- ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
- Organization: c.GetString("channel_organization"),
- ChannelSetting: channelSetting,
+ //RecodeModelName: c.GetString("original_model"),
+ IsModelMapped: false,
+ ApiType: apiType,
+ ApiVersion: c.GetString("api_version"),
+ ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+ Organization: c.GetString("channel_organization"),
+ ChannelSetting: channelSetting,
}
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
@@ -110,9 +124,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region")
}
- if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
- info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
- info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure {
+ if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true
}
return info
diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go
new file mode 100644
index 0000000000000000000000000000000000000000..948c5226b6917a123955c0c51d5707640e37fd28
--- /dev/null
+++ b/relay/helper/model_mapped.go
@@ -0,0 +1,25 @@
+package helper
+
+import (
+ "encoding/json"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "one-api/relay/common"
+)
+
+// ModelMappedHelper maps info.OriginModelName through the context's "model_mapping" JSON, setting UpstreamModelName and IsModelMapped on a non-empty hit.
+func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
+	modelMapping := c.GetString("model_mapping")
+	if modelMapping == "" || modelMapping == "{}" {
+		return nil
+	}
+	modelMap := make(map[string]string)
+	if err := json.Unmarshal([]byte(modelMapping), &modelMap); err != nil {
+		return fmt.Errorf("unmarshal_model_mapping_failed: %w", err)
+	}
+	if mapped := modelMap[info.OriginModelName]; mapped != "" {
+		info.UpstreamModelName = mapped
+		info.IsModelMapped = true
+	}
+	return nil
+}
diff --git a/relay/helper/price.go b/relay/helper/price.go
new file mode 100644
index 0000000000000000000000000000000000000000..d65b86aa5a307e51db3325d7cd907b62819e78b8
--- /dev/null
+++ b/relay/helper/price.go
@@ -0,0 +1,41 @@
+package helper
+
+import (
+ "github.com/gin-gonic/gin"
+ "one-api/common"
+ relaycommon "one-api/relay/common"
+ "one-api/setting"
+)
+
+type PriceData struct { // PriceData is the pricing snapshot returned by ModelPriceHelper.
+ ModelPrice float64 // first result of common.GetModelPrice, as set by ModelPriceHelper
+ ModelRatio float64 // from common.GetModelRatio; ModelPriceHelper leaves it at 0 when UsePrice is true
+ GroupRatio float64 // from setting.GetGroupRatio for the request's group
+ UsePrice bool // second result of common.GetModelPrice; true means the pre-consumed quota is price-based rather than token-based
+ ShouldPreConsumedQuota int // quota ModelPriceHelper computed for pre-consumption
+}
+
+// ModelPriceHelper gathers the price, model ratio and group ratio for info.OriginModelName / info.Group
+// and derives the quota that should be pre-consumed before the request is relayed.
+func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
+	price, byPrice := common.GetModelPrice(info.OriginModelName, false)
+	result := PriceData{
+		ModelPrice: price,
+		GroupRatio: setting.GetGroupRatio(info.Group),
+		UsePrice:   byPrice,
+	}
+	if byPrice {
+		// Price-based billing: the token counts play no part in the estimate.
+		result.ShouldPreConsumedQuota = int(price * common.QuotaPerUnit * result.GroupRatio)
+		return result
+	}
+	// Ratio-based billing: estimate prompt + maxTokens, or the configured default when maxTokens is 0.
+	estimatedTokens := common.PreConsumedQuota
+	if maxTokens != 0 {
+		estimatedTokens = promptTokens + maxTokens
+	}
+	result.ModelRatio = common.GetModelRatio(info.OriginModelName)
+	combinedRatio := result.ModelRatio * result.GroupRatio
+	result.ShouldPreConsumedQuota = int(float64(estimatedTokens) * combinedRatio)
+	return result
+}
diff --git a/relay/relay-audio.go b/relay/relay-audio.go
index 4c23a8f8d99136cfbb40cb22958345c3621f20b0..b95c1eb693f6ce30720d2e291789b4953ae40536 100644
--- a/relay/relay-audio.go
+++ b/relay/relay-audio.go
@@ -1,7 +1,6 @@
package relay
import (
- "encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
@@ -11,8 +10,10 @@ import (
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
"one-api/service"
"one-api/setting"
+ "strings"
)
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
@@ -27,8 +28,9 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return nil, errors.New("model is required")
}
if setting.ShouldCheckPromptSensitive() {
- err := service.CheckSensitiveInput(audioRequest.Input)
+ words, err := service.CheckSensitiveInput(audioRequest.Input)
if err != nil {
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
return nil, err
}
}
@@ -73,15 +75,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo.PromptTokens = promptTokens
}
- modelRatio := common.GetModelRatio(audioRequest.Model)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
- ratio := modelRatio * groupRatio
- preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+ priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
+
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
}
- preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -91,19 +91,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}()
- // map model name
- modelMapping := c.GetString("model_mapping")
- if modelMapping != "" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[audioRequest.Model] != "" {
- audioRequest.Model = modelMap[audioRequest.Model]
- }
+ err = helper.ModelMappedHelper(c, relayInfo)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- relayInfo.UpstreamModelName = audioRequest.Model
+
+ audioRequest.Model = relayInfo.UpstreamModelName
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
@@ -140,7 +133,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr
}
- postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
diff --git a/relay/relay-image.go b/relay/relay-image.go
index 207350da6b85cdf00cfebcdbc8027bd356e0846a..afa5b8e2ce160c6e69039b527a50b57bfd7375f9 100644
--- a/relay/relay-image.go
+++ b/relay/relay-image.go
@@ -12,6 +12,7 @@ import (
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
@@ -60,15 +61,16 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
//}
if setting.ShouldCheckPromptSensitive() {
- err := service.CheckSensitiveInput(imageRequest.Prompt)
+ words, err := service.CheckSensitiveInput(imageRequest.Prompt)
if err != nil {
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
return nil, err
}
}
return imageRequest, nil
}
-func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
+func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo)
@@ -77,29 +79,20 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
}
- // map model name
- modelMapping := c.GetString("model_mapping")
- if modelMapping != "" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[imageRequest.Model] != "" {
- imageRequest.Model = modelMap[imageRequest.Model]
- }
+ err = helper.ModelMappedHelper(c, relayInfo)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- relayInfo.UpstreamModelName = imageRequest.Model
- modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
- if !success {
- modelRatio := common.GetModelRatio(imageRequest.Model)
+ imageRequest.Model = relayInfo.UpstreamModelName
+
+ priceData := helper.ModelPriceHelper(c, relayInfo, 0, 0)
+ if !priceData.UsePrice {
// modelRatio 16 = modelPrice $0.04
// per 1 modelRatio = $0.04 / 16
- modelPrice = 0.0025 * modelRatio
+ priceData.ModelPrice = 0.0025 * priceData.ModelRatio
}
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
sizeRatio := 1.0
@@ -122,11 +115,11 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
}
- imageRatio := modelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
- quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
+ priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
+ quota := int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
if userQuota-quota < 0 {
- return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("image pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, quota)), "insufficient_user_quota", http.StatusBadRequest)
+ return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
}
adaptor := GetAdaptor(relayInfo.ApiType)
@@ -184,7 +177,6 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
- postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent)
-
+ postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
return nil
}
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 0facecabea1793da881d39815be68a2e5c24a968..766064cbc0db0f69dc5fc5f9488a4a70d9e476d8 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -194,7 +194,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
}
defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
- err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
+ err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func(ctx context.Context) {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
- err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
+ err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
diff --git a/relay/relay-text.go b/relay/relay-text.go
index f303ff6a678db257d1b68728e02956b2a5da7450..bfd91cdf9cce5331d67e19aa7d9874e639cd9726 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "github.com/bytedance/gopkg/util/gopool"
"io"
"math"
"net/http"
@@ -14,6 +15,7 @@ import (
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
"one-api/service"
"one-api/setting"
"strings"
@@ -75,40 +77,21 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
- // map model name
- //isModelMapped := false
- modelMapping := c.GetString("model_mapping")
- //isModelMapped := false
- if modelMapping != "" && modelMapping != "{}" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[textRequest.Model] != "" {
- //isModelMapped = true
- textRequest.Model = modelMap[textRequest.Model]
- // set upstream model name
- //isModelMapped = true
- }
- }
- relayInfo.UpstreamModelName = textRequest.Model
- relayInfo.RecodeModelName = textRequest.Model
- modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
-
- var preConsumedQuota int
- var ratio float64
- var modelRatio float64
- //err := service.SensitiveWordsCheck(textRequest)
-
if setting.ShouldCheckPromptSensitive() {
- err = checkRequestSensitive(textRequest, relayInfo)
+ words, err := checkRequestSensitive(textRequest, relayInfo)
if err != nil {
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
}
}
+ err = helper.ModelMappedHelper(c, relayInfo)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
+ }
+
+ textRequest.Model = relayInfo.UpstreamModelName
+
// 获取 promptTokens,如果上下文中已经存在,则直接使用
var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists {
@@ -123,20 +106,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
c.Set("prompt_tokens", promptTokens)
}
- if !getModelPriceSuccess {
- preConsumedTokens := common.PreConsumedQuota
- if textRequest.MaxTokens != 0 {
- preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
- }
- modelRatio = common.GetModelRatio(textRequest.Model)
- ratio = modelRatio * groupRatio
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
- } else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
- }
+ priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
// pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -219,10 +192,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return openaiErr
}
- if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") {
- service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+ if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
+ service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
} else {
- postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
}
return nil
}
@@ -247,19 +220,20 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
return promptTokens, err
}
-func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
+func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) { // runs the service sensitive-content check matching info.RelayMode
var err error
+ var words []string // whatever word list the selected service check returns; stays nil for other relay modes
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
- err = service.CheckSensitiveMessages(textRequest.Messages)
+ words, err = service.CheckSensitiveMessages(textRequest.Messages)
case relayconstant.RelayModeCompletions:
- err = service.CheckSensitiveInput(textRequest.Prompt)
+ words, err = service.CheckSensitiveInput(textRequest.Prompt)
case relayconstant.RelayModeModerations:
- err = service.CheckSensitiveInput(textRequest.Input)
+ words, err = service.CheckSensitiveInput(textRequest.Input)
case relayconstant.RelayModeEmbeddings:
- err = service.CheckSensitiveInput(textRequest.Input)
+ words, err = service.CheckSensitiveInput(textRequest.Input)
}
- return err
+ return words, err
}
// 预扣费并返回用户剩余配额
@@ -272,7 +246,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
if userQuota-preConsumedQuota < 0 {
- return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
+ return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
}
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
@@ -282,18 +256,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if tokenQuota > 100*preConsumedQuota {
// 令牌额度充足,信任令牌
preConsumedQuota = 0
- common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
+ common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
- common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
+ common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
}
}
if preConsumedQuota > 0 {
- err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
+ err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
@@ -307,20 +281,19 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
if preConsumedQuota != 0 {
- go func() {
+ gopool.Go(func() { // refund asynchronously via gopool; NOTE(review): relayInfo is dereferenced inside this task, possibly after the caller has moved on — confirm it is not mutated concurrently
relayInfoCopy := *relayInfo
- err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
+ err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) // negative delta hands the pre-consumed quota back
if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error())
}
- }()
+ })
}
}
-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
- usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
- modelPrice float64, usePrice bool, extraContent string) {
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
+ usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens,
@@ -332,12 +305,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
+ modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
+ ratio := priceData.ModelRatio * priceData.GroupRatio
+ modelRatio := priceData.ModelRatio
+ groupRatio := priceData.GroupRatio
+ modelPrice := priceData.ModelPrice
+ usePrice := priceData.UsePrice
quota := 0
- if !usePrice {
+ if !priceData.UsePrice {
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
@@ -368,7 +347,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
//}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
- err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
+ err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
diff --git a/relay/relay_embedding.go b/relay/relay_embedding.go
index 0a41c11d59738c30104fd1abfbb1c730efa509a5..18739d9f92369212e47ce77848def8ee6c3049e0 100644
--- a/relay/relay_embedding.go
+++ b/relay/relay_embedding.go
@@ -10,8 +10,8 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
"one-api/service"
- "one-api/setting"
)
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
@@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
}
- // map model name
- modelMapping := c.GetString("model_mapping")
- //isModelMapped := false
- if modelMapping != "" && modelMapping != "{}" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[embeddingRequest.Model] != "" {
- embeddingRequest.Model = modelMap[embeddingRequest.Model]
- // set upstream model name
- //isModelMapped = true
- }
+ err = helper.ModelMappedHelper(c, relayInfo)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- relayInfo.UpstreamModelName = embeddingRequest.Model
- modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
-
- var preConsumedQuota int
- var ratio float64
- var modelRatio float64
+ embeddingRequest.Model = relayInfo.UpstreamModelName
promptToken := getEmbeddingPromptToken(*embeddingRequest)
- if !success {
- preConsumedTokens := promptToken
- modelRatio = common.GetModelRatio(embeddingRequest.Model)
- ratio = modelRatio * groupRatio
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
- } else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
- }
relayInfo.PromptTokens = promptToken
+ priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
+
// pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
- postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go
index e53e37d480823d0604a679e92db0b3100e440b67..37178cad3a550c19f8aec4971d893b193e9543d0 100644
--- a/relay/relay_rerank.go
+++ b/relay/relay_rerank.go
@@ -9,8 +9,8 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
"one-api/service"
- "one-api/setting"
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
}
- // map model name
- modelMapping := c.GetString("model_mapping")
- //isModelMapped := false
- if modelMapping != "" && modelMapping != "{}" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[rerankRequest.Model] != "" {
- rerankRequest.Model = modelMap[rerankRequest.Model]
- // set upstream model name
- //isModelMapped = true
- }
+ err = helper.ModelMappedHelper(c, relayInfo)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
- relayInfo.UpstreamModelName = rerankRequest.Model
- modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
-
- var preConsumedQuota int
- var ratio float64
- var modelRatio float64
+ rerankRequest.Model = relayInfo.UpstreamModelName
promptToken := getRerankPromptToken(*rerankRequest)
- if !success {
- preConsumedTokens := promptToken
- modelRatio = common.GetModelRatio(rerankRequest.Model)
- ratio = modelRatio * groupRatio
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
- } else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
- }
relayInfo.PromptTokens = promptToken
+ priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
+
// pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
- postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
diff --git a/relay/relay_task.go b/relay/relay_task.go
index 61577faf7937e90e880e359bef106d476441818f..f03fcb2d912f4a1ec2a0e57e22379c45e396a31b 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -113,7 +113,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
- err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
+ err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
diff --git a/router/api-router.go b/router/api-router.go
index b00595af3256e9f4ddc1bd9dc90a7fd7d764f3b4..bf88449a56bc57256da2c0f31edf8a44bb279888 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.POST("/pay", controller.RequestEpay)
selfRoute.POST("/amount", controller.RequestAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
+ selfRoute.PUT("/setting", controller.UpdateUserSetting)
}
adminRoute := userRoute.Group("/")
diff --git a/service/cf_worker.go b/service/cf_worker.go
index afe65411b0eac6431a450e94e78e3345b202ad86..40a1e29450c5db52321bb5f05c3b9b763dc5f36f 100644
--- a/service/cf_worker.go
+++ b/service/cf_worker.go
@@ -2,6 +2,7 @@ package service
import (
"bytes"
+ "encoding/json"
"fmt"
"net/http"
"one-api/common"
@@ -9,19 +10,46 @@ import (
"strings"
)
+// WorkerRequest is the JSON payload that DoWorkerRequest posts to the worker.
+type WorkerRequest struct {
+ URL string `json:"url"`
+ Key string `json:"key"`
+ Method string `json:"method,omitempty"`
+ Headers map[string]string `json:"headers,omitempty"`
+ Body json.RawMessage `json:"body,omitempty"`
+}
+
+// DoWorkerRequest posts req as JSON to the configured worker URL and returns the worker's response; the caller must close resp.Body.
+func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
+	if !setting.EnableWorker() {
+		return nil, fmt.Errorf("worker not enabled")
+	}
+	if req == nil {
+		return nil, fmt.Errorf("worker request is nil")
+	}
+	// Require the full scheme separator so that e.g. "httpsfoo://" is rejected.
+	if !strings.HasPrefix(req.URL, "https://") {
+		return nil, fmt.Errorf("only support https url")
+	}
+	workerUrl := setting.WorkerUrl
+	if !strings.HasSuffix(workerUrl, "/") {
+		workerUrl += "/"
+	}
+	workerPayload, err := json.Marshal(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal worker payload: %w", err)
+	}
+	return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
+}
+
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
if setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
- if !strings.HasPrefix(originUrl, "https") {
- return nil, fmt.Errorf("only support https url")
- }
- workerUrl := setting.WorkerUrl
- if !strings.HasSuffix(workerUrl, "/") {
- workerUrl += "/"
+ req := &WorkerRequest{
+ URL: originUrl,
+ Key: setting.WorkerValidKey,
}
- // post request to worker
- data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
- return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
+ return DoWorkerRequest(req)
} else {
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
return http.Get(originUrl)
diff --git a/service/channel.go b/service/channel.go
index 73545b1e65b26e6d7156badf0cdc8aa68dcef0f1..76bcacf1bb1433d3716519f567de3b3ede97205f 100644
--- a/service/channel.go
+++ b/service/channel.go
@@ -4,7 +4,7 @@ import (
"fmt"
"net/http"
"one-api/common"
- relaymodel "one-api/dto"
+ "one-api/dto"
"one-api/model"
"one-api/setting"
"strings"
@@ -15,17 +15,17 @@ func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
- notifyRootUser(subject, content)
+ NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
}
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
- notifyRootUser(subject, content)
+ NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate) // passes the channel-update notify type along with the "enabled" subject and content
}
-func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
+func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
@@ -75,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
return false
}
-func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
+func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
diff --git a/service/log_info_generate.go b/service/log_info_generate.go
index 1ce09d92f7f1fe7408d0395e7cb22b38d406ec58..1e32d6f1ed5276d54f7a6522ff1b6a040efe807d 100644
--- a/service/log_info_generate.go
+++ b/service/log_info_generate.go
@@ -16,6 +16,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
if relayInfo.ReasoningEffort != "" {
other["reasoning_effort"] = relayInfo.ReasoningEffort
}
+ if relayInfo.IsModelMapped {
+ other["is_model_mapped"] = true
+ other["upstream_model_name"] = relayInfo.UpstreamModelName
+ }
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
other["admin_info"] = adminInfo
diff --git a/service/notify-limit.go b/service/notify-limit.go
new file mode 100644
index 0000000000000000000000000000000000000000..309ea54d2192058009e2992cbcc60f3e33a236d1
--- /dev/null
+++ b/service/notify-limit.go
@@ -0,0 +1,117 @@
+package service
+
+import (
+ "fmt"
+ "github.com/bytedance/gopkg/util/gopool"
+ "one-api/common"
+ "one-api/constant"
+ "strconv"
+ "sync"
+ "time"
+)
+
+// notifyLimitStore is used for in-memory rate limiting when Redis is disabled
+var (
+ notifyLimitStore sync.Map
+ cleanupOnce sync.Once
+)
+
+type limitCount struct { // value stored in notifyLimitStore for one limiter key
+ Count int // number of notification attempts recorded for the key
+ Timestamp time.Time // set to the current time whenever a fresh counter is created
+}
+
+func getDuration() time.Duration {
+	// The configured value is a number of minutes; convert it to a time.Duration.
+	return time.Duration(constant.NotificationLimitDurationMinute) * time.Minute
+}
+
+// startCleanupTask launches a background goroutine that wakes up once an hour
+// and deletes in-memory counters whose age has reached the limit window.
+func startCleanupTask() {
+	gopool.Go(func() {
+		for {
+			time.Sleep(time.Hour)
+			cutoff := time.Now()
+			notifyLimitStore.Range(func(key, value interface{}) bool {
+				entry, ok := value.(limitCount)
+				if ok && cutoff.Sub(entry.Timestamp) >= getDuration() {
+					notifyLimitStore.Delete(key)
+				}
+				return true // keep iterating over every stored entry
+			})
+		}
+	})
+}
+
+// CheckNotificationLimit reports whether the user may still be sent a notification of
+// the given type. It delegates to Redis when enabled and to process memory otherwise.
+func CheckNotificationLimit(userId int, notifyType string) (bool, error) {
+	if !common.RedisEnabled {
+		return checkMemoryLimit(userId, notifyType)
+	}
+	return checkRedisLimit(userId, notifyType)
+}
+
+// checkRedisLimit counts notifications per user, type and clock hour in Redis and reports
+// whether one more may be sent; the read and the increment are separate commands, not atomic.
+func checkRedisLimit(userId int, notifyType string) (bool, error) {
+	key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
+	limit := constant.NotifyLimitCount
+	count, err := common.RedisGet(key)
+	if err != nil && err.Error() != "redis: nil" {
+		return false, fmt.Errorf("failed to get notification count: %w", err)
+	}
+	// An empty value means no counter exists yet for this key.
+	if count == "" {
+		if limit <= 0 {
+			return false, nil // nothing is allowed, matching the in-memory limiter
+		}
+		if err = common.RedisSet(key, "1", getDuration()); err != nil {
+			return false, fmt.Errorf("failed to initialize notification count: %w", err)
+		}
+		return true, nil
+	}
+	currentCount, err := strconv.Atoi(count)
+	if err != nil {
+		return false, fmt.Errorf("invalid notification count %q: %w", count, err)
+	}
+	if currentCount >= limit {
+		return false, nil
+	}
+	if err = common.RedisIncr(key, 1); err != nil {
+		return false, fmt.Errorf("failed to increment notification count: %w", err)
+	}
+	return true, nil
+}
+
+// checkMemoryLimit is the in-process limiter used when Redis is disabled.
+// Every call is counted, rejected ones included; the result is true while the
+// running count for the bucket stays within constant.NotifyLimitCount.
+func checkMemoryLimit(userId int, notifyType string) (bool, error) {
+	// Lazily start the cleanup goroutine on first use.
+	cleanupOnce.Do(startCleanupTask)
+
+	// Buckets are keyed by user, notification type and clock hour.
+	bucket := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
+	current := time.Now()
+
+	// Begin with a fresh counter; a stored one replaces it only while still valid.
+	entry := limitCount{Count: 0, Timestamp: current}
+	if stored, found := notifyLimitStore.Load(bucket); found {
+		previous := stored.(limitCount)
+		if current.Sub(previous.Timestamp) < getDuration() {
+			entry = previous
+		}
+	}
+
+	// Record this attempt before comparing against the limit.
+	entry.Count++
+	notifyLimitStore.Store(bucket, entry)
+
+	// NOTE(review): Load and Store are separate steps, so concurrent callers for
+	// the same bucket can overwrite each other's increments.
+	limit := constant.NotifyLimitCount
+	allowed := entry.Count <= limit
+	return allowed, nil
+}
diff --git a/service/quota.go b/service/quota.go
index ab04800823bf7af79ec6f70134927d55a8cd20a5..2cae93def4228f8aa93f0479c5301813015b9ba8 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -3,11 +3,14 @@ package service
import (
"errors"
"fmt"
+ "github.com/bytedance/gopkg/util/gopool"
"math"
"one-api/common"
+ constant2 "one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
"one-api/setting"
"strings"
"time"
@@ -66,7 +69,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return err
}
- modelName := relayInfo.UpstreamModelName
+ modelName := relayInfo.OriginModelName
textInputTokens := usage.InputTokenDetails.TextTokens
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
@@ -92,14 +95,14 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
quota := calculateAudioQuota(quotaInfo)
if userQuota < quota {
- return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
+ return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
- return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
+ return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
}
- err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
+ err = PostConsumeQuota(relayInfo, quota, 0, false)
if err != nil {
return err
}
@@ -120,7 +123,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(modelName)
- audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
+ audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
quotaInfo := QuotaInfo{
@@ -171,8 +174,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
}
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
- usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
- modelPrice float64, usePrice bool, extraContent string) {
+ usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.PromptTokensDetails.TextTokens
@@ -182,9 +184,14 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
- completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName)
- audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName)
- audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName)
+ completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName)
+ audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
+ audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName)
+
+ modelRatio := priceData.ModelRatio
+ groupRatio := priceData.GroupRatio
+ modelPrice := priceData.ModelPrice
+ usePrice := priceData.UsePrice
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -195,7 +202,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
TextTokens: textOutTokens,
AudioTokens: audioOutTokens,
},
- ModelName: relayInfo.RecodeModelName,
+ ModelName: relayInfo.OriginModelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
@@ -218,11 +225,11 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota))
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
- err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
+ err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
@@ -231,7 +238,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
- logModel := relayInfo.RecodeModelName
+ logModel := relayInfo.OriginModelName
if extraContent != "" {
logContent += ", " + extraContent
}
@@ -239,3 +246,88 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
+
+// PreConsumeTokenQuota deducts quota from the request's token before the call is relayed.
+// Negative amounts are rejected and playground requests are skipped. Unlimited tokens
+// bypass the balance check but are still passed to model.DecreaseTokenQuota.
+func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
+	switch {
+	case quota < 0:
+		return errors.New("quota 不能为负数!")
+	case relayInfo.IsPlayground:
+		return nil
+	}
+
+	token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
+	if err != nil {
+		return err
+	}
+
+	insufficient := token.RemainQuota < quota
+	if insufficient && !relayInfo.TokenUnlimited {
+		return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
+	}
+
+	return model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
+}
+
+// PostConsumeQuota settles quota against the user and, outside the playground, the token.
+// A positive quota is deducted; zero or negative quota is added back as -quota. When
+// sendEmail is set and quota+preConsumedQuota is non-zero, the low-quota check is triggered.
+func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
+	charge := quota > 0
+	if charge {
+		err = model.DecreaseUserQuota(relayInfo.UserId, quota)
+	} else {
+		err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
+	}
+	if err != nil {
+		return err
+	}
+
+	if !relayInfo.IsPlayground {
+		if charge {
+			err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
+		} else {
+			err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
+		}
+		if err != nil {
+			return err
+		}
+	}
+
+	if sendEmail && quota+preConsumedQuota != 0 {
+		checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
+	}
+	return nil
+}
+
+// checkAndSendQuotaNotify asynchronously notifies the user when the cached quota minus this
+// consumption falls below the reminder threshold (the user's own setting if present).
+func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
+	gopool.Go(func() {
+		userCache, err := model.GetUserCache(userId)
+		if err != nil {
+			common.SysError("failed to get user cache: " + err.Error())
+			return // do not touch userCache after a failed lookup
+		}
+		threshold := common.QuotaRemindThreshold
+		// Checked assertion: a malformed stored value falls back to the global threshold.
+		if custom, ok := userCache.GetSetting()[constant2.UserSettingQuotaWarningThreshold].(float64); ok {
+			threshold = int(custom)
+		}
+
+		// Nothing to report while the projected balance stays at or above the threshold.
+		if userCache.Quota-(quota+preConsumedQuota) >= threshold {
+			return
+		}
+
+		prompt := "您的额度即将用尽"
+		topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
+		content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}"
+		err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
+		if err != nil {
+			common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
+		}
+	})
+}
diff --git a/service/sensitive.go b/service/sensitive.go
index 14ac94819e772e6dbc3fc568080e453413345be1..b3e3c4d664e98417f229ad4dd2301a7440fcdb88 100644
--- a/service/sensitive.go
+++ b/service/sensitive.go
@@ -8,48 +8,47 @@ import (
"strings"
)
-func CheckSensitiveMessages(messages []dto.Message) error {
+func CheckSensitiveMessages(messages []dto.Message) ([]string, error) { // returns the matched words together with a non-nil error on the first hit
+ if len(messages) == 0 {
+ return nil, nil
+ }
+
for _, message := range messages {
- if len(message.Content) > 0 {
- if message.IsStringContent() {
- stringContent := message.StringContent()
- if ok, words := SensitiveWordContains(stringContent); ok {
- return errors.New("sensitive words: " + strings.Join(words, ","))
- }
+ arrayContent := message.ParseContent()
+ for _, m := range arrayContent {
+ if m.Type == "image_url" {
+ // TODO: check image url
+ continue
+ }
+ // skip parts whose text is empty
+ if m.Text == "" {
+ continue
}
- } else {
- arrayContent := message.ParseContent()
- for _, m := range arrayContent {
- if m.Type == "image_url" {
- // TODO: check image url
- } else {
- if ok, words := SensitiveWordContains(m.Text); ok {
- return errors.New("sensitive words: " + strings.Join(words, ","))
- }
- }
+ if ok, words := SensitiveWordContains(m.Text); ok {
+ return words, errors.New("sensitive words detected")
}
}
}
- return nil
+ return nil, nil
}
-func CheckSensitiveText(text string) error {
+func CheckSensitiveText(text string) ([]string, error) { // returns the words reported by SensitiveWordContains and a non-nil error when any are found
if ok, words := SensitiveWordContains(text); ok {
- return errors.New("sensitive words: " + strings.Join(words, ","))
+ return words, errors.New("sensitive words detected")
}
- return nil
+ return nil, nil
}
-func CheckSensitiveInput(input any) error {
+func CheckSensitiveInput(input any) ([]string, error) { // accepts a string, a []string, or any other value (formatted with %v)
switch v := input.(type) {
case string:
return CheckSensitiveText(v)
case []string:
- text := ""
+ var builder strings.Builder // elements are concatenated without a separator before checking
for _, s := range v {
- text += s
+ builder.WriteString(s)
}
- return CheckSensitiveText(text)
+ return CheckSensitiveText(builder.String())
}
return CheckSensitiveText(fmt.Sprintf("%v", input))
}
@@ -59,8 +58,11 @@ func SensitiveWordContains(text string) (bool, []string) {
if len(setting.SensitiveWords) == 0 {
return false, nil
}
+ if len(text) == 0 {
+ return false, nil
+ }
checkText := strings.ToLower(text)
- return AcSearch(checkText, setting.SensitiveWords, false)
+ return AcSearch(checkText, setting.SensitiveWords, true)
}
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
@@ -72,14 +74,21 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
m := InitAc(setting.SensitiveWords)
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 {
- words := make([]string, 0)
+ words := make([]string, 0, len(hits))
+ var builder strings.Builder
+ builder.Grow(len(text))
+ lastPos := 0
+
for _, hit := range hits {
pos := hit.Pos
word := string(hit.Word)
- text = text[:pos] + "**###**" + text[pos+len(word):]
+ builder.WriteString(text[lastPos:pos])
+ builder.WriteString("**###**")
+ lastPos = pos + len(word)
words = append(words, word)
}
- return true, words, text
+ builder.WriteString(text[lastPos:])
+ return true, words, builder.String()
}
return false, nil, text
}
diff --git a/service/user_notify.go b/service/user_notify.go
index 7ae9062bce886eb00f87e0cb17af2a03e18bb292..e01b7aa9c04f3b990e9122470db1a0c596ca6919 100644
--- a/service/user_notify.go
+++ b/service/user_notify.go
@@ -3,15 +3,75 @@ package service
import (
"fmt"
"one-api/common"
+ "one-api/constant"
+ "one-api/dto"
"one-api/model"
+ "strings"
)
-func notifyRootUser(subject string, content string) {
- if common.RootUserEmail == "" {
- common.RootUserEmail = model.GetRootUserEmail()
+func NotifyRootUser(t string, subject string, content string) { // argument order: notify type first, then title, then body
+ user := model.GetRootUser().ToBaseUser()
+ _ = NotifyUser(user, dto.NewNotify(t, subject, content, nil)) // best effort: the error returned by NotifyUser is discarded
+}
+
+func NotifyUser(user *model.UserBase, data dto.Notify) error { // delivers data by email (default) or webhook, subject to the per-user notification limit
+ userSetting := user.GetSetting()
+ notifyType, ok := userSetting[constant.UserSettingNotifyType]
+ if !ok {
+ notifyType = constant.NotifyTypeEmail
}
- err := common.SendEmail(subject, common.RootUserEmail, content)
+
+ // Check notification limit
+ canSend, err := CheckNotificationLimit(user.Id, data.Type)
if err != nil {
- common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+ common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
+ return err
+ }
+ if !canSend {
+ return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
+ }
+
+ switch notifyType {
+ case constant.NotifyTypeEmail:
+ userEmail := user.Email
+ // a notification email stored in the settings overrides the account email; non-string values are ignored instead of panicking
+ if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail].(string); ok {
+ userEmail = settingEmail
+ }
+ if userEmail == "" {
+ common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
+ return nil
+ }
+ return sendEmailNotify(userEmail, data)
+ case constant.NotifyTypeWebhook:
+ webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
+ if !ok {
+ common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
+ return nil
+ }
+ webhookURLStr, ok := webhookURL.(string)
+ if !ok {
+ common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
+ return nil
+ }
+
+ // the webhook secret is optional; it stays empty when missing or not a string
+ var webhookSecret string
+ if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
+ webhookSecret, _ = secret.(string)
+ }
+
+ return SendWebhookNotify(webhookURLStr, webhookSecret, data)
+ }
+ return nil
+}
+
+func sendEmailNotify(userEmail string, data dto.Notify) error {
+ // start from the content template carried by the notification
+ content := data.Content
+ // fill the placeholders: each value replaces the first remaining dto.ContentValueParam
+ for _, value := range data.Values {
+ content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
}
+ return common.SendEmail(data.Title, userEmail, content)
}
diff --git a/service/webhook.go b/service/webhook.go
new file mode 100644
index 0000000000000000000000000000000000000000..ad2967eb0b142785976f0c30dcaf61e69f2d654f
--- /dev/null
+++ b/service/webhook.go
@@ -0,0 +1,118 @@
+package service
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/dto"
+ "one-api/setting"
+ "time"
+)
+
+// WebhookPayload is the JSON body of a webhook notification.
+type WebhookPayload struct {
+ Type string `json:"type"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+ Values []interface{} `json:"values,omitempty"`
+ Timestamp int64 `json:"timestamp"`
+}
+
+// generateSignature returns the hex-encoded HMAC-SHA256 of payload keyed with secret.
+func generateSignature(secret string, payload []byte) string {
+	mac := hmac.New(sha256.New, []byte(secret))
+	_, _ = mac.Write(payload) // hash.Hash writes never return an error
+	return hex.EncodeToString(mac.Sum(nil))
+}
+
+// SendWebhookNotify POSTs the notification as JSON to webhookURL, through the worker when it is enabled.
+func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error {
+	// The content template is forwarded untouched: its {{value}} placeholders are
+	// meant to be filled by the receiver from the "values" array. The former
+	// fmt.Sprintf pass appended "%!(EXTRA ...)" noise and mangled any literal '%'.
+	content := data.Content
+
+
+	// Build the webhook payload.
+	payload := WebhookPayload{
+		Type:      data.Type,
+		Title:     data.Title,
+		Content:   content,
+		Values:    data.Values,
+		Timestamp: time.Now().Unix(),
+	}
+
+	// Serialize the payload.
+	payloadBytes, err := json.Marshal(payload)
+	if err != nil {
+		return fmt.Errorf("failed to marshal webhook payload: %w", err)
+	}
+
+	// The request/response pair is filled by whichever delivery path is taken.
+	var req *http.Request
+	var resp *http.Response
+
+	if setting.EnableWorker() {
+		// Describe the request for the worker to perform.
+		workerReq := &WorkerRequest{
+			URL:    webhookURL,
+			Key:    setting.WorkerValidKey,
+			Method: http.MethodPost,
+			Headers: map[string]string{
+				"Content-Type": "application/json",
+			},
+			Body: payloadBytes,
+		}
+
+		// NOTE(review): only this path also sends the secret as a bearer token — confirm the direct path should differ.
+		if secret != "" {
+			signature := generateSignature(secret, payloadBytes)
+			workerReq.Headers["X-Webhook-Signature"] = signature
+			workerReq.Headers["Authorization"] = "Bearer " + secret
+		}
+
+		resp, err = DoWorkerRequest(workerReq)
+		if err != nil {
+			return fmt.Errorf("failed to send webhook request through worker: %w", err)
+		}
+		defer resp.Body.Close()
+
+		// Any non-2xx status counts as a failed delivery.
+		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+			return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
+		}
+	} else {
+		req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
+		if err != nil {
+			return fmt.Errorf("failed to create webhook request: %w", err)
+		}
+
+		// Set the request headers.
+		req.Header.Set("Content-Type", "application/json")
+
+		// Sign the body when a secret is configured.
+		if secret != "" {
+			signature := generateSignature(secret, payloadBytes)
+			req.Header.Set("X-Webhook-Signature", signature)
+		}
+
+		// Send the request directly.
+		client := GetImpatientHttpClient()
+		resp, err = client.Do(req)
+		if err != nil {
+			return fmt.Errorf("failed to send webhook request: %w", err)
+		}
+		defer resp.Body.Close()
+
+		// Any non-2xx status counts as a failed delivery.
+		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+			return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
+		}
+	}
+
+	return nil
+}
diff --git a/setting/operation_setting.go b/setting/operation_setting.go
index 9a28e987103d1c7ba9d422cb015f5dfc839e0890..4940d0fc6cc3718b9a633627f1323eb856bd1c53 100644
--- a/setting/operation_setting.go
+++ b/setting/operation_setting.go
@@ -23,6 +23,7 @@ func AutomaticDisableKeywordsFromString(s string) {
ak := strings.Split(s, "\n")
for _, k := range ak {
k = strings.TrimSpace(k)
+ k = strings.ToLower(k)
if k != "" {
AutomaticDisableKeywords = append(AutomaticDisableKeywords, k)
}
diff --git a/setting/system-setting.go b/setting/system_setting.go
similarity index 100%
rename from setting/system-setting.go
rename to setting/system_setting.go
diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js
index 605103ae258bd78c85115fdcf0fad5861706eaf6..71914c4e7c4f5261d16d9847d367a1d79f5100f7 100644
--- a/web/src/components/ChannelsTable.js
+++ b/web/src/components/ChannelsTable.js
@@ -357,6 +357,13 @@ const ChannelsTable = () => {
dataIndex: 'operate',
render: (text, record, index) => {
if (record.children === undefined) {
+ // 构建模型测试菜单
+ const modelMenuItems = record.models.split(',').map(model => ({
+ node: 'item',
+ name: model,
+ onClick: () => testChannel(record, model)
+ }));
+
return (
+{`{
+ "type": "quota_exceed", // 通知类型
+ "title": "标题", // 通知标题
+ "content": "通知内容", // 通知内容,支持 {{value}} 变量占位符
+ "values": ["值1", "值2"], // 按顺序替换content中的 {{value}} 占位符
+ "timestamp": 1739950503 // 时间戳
+}
+
+示例:
+{
+ "type": "quota_exceed",
+ "title": "额度预警通知",
+ "content": "您的额度即将用尽,当前剩余额度为 {{value}}",
+ "values": ["$0.99"],
+ "timestamp": 1739950503
+}`}
+
+