YourUsername committed on
Commit
d9f2cee
·
1 Parent(s): 72eae22

25:02:22 23:01:24

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .env.example +2 -2
  2. README.en.md +2 -0
  3. Rerank.md +1 -1
  4. VERSION +0 -1
  5. common/constants.go +1 -1
  6. common/go-channel.go +0 -13
  7. common/logger.go +11 -2
  8. common/model-ratio.go +17 -23
  9. constant/env.go +4 -1
  10. constant/user_setting.go +14 -0
  11. controller/channel-test.go +2 -7
  12. controller/pricing.go +1 -1
  13. controller/relay.go +1 -1
  14. controller/user.go +114 -4
  15. docker-compose.yml +1 -1
  16. dto/notify.go +25 -0
  17. dto/openai_request.go +84 -42
  18. dto/openai_response.go +11 -3
  19. main.go +2 -2
  20. middleware/distributor.go +0 -4
  21. model/option.go +2 -2
  22. model/token.go +3 -82
  23. model/token_cache.go +1 -1
  24. model/user.go +77 -13
  25. model/user_cache.go +128 -121
  26. relay/channel/cloudflare/adaptor.go +2 -1
  27. relay/channel/deepseek/adaptor.go +7 -1
  28. relay/channel/gemini/adaptor.go +99 -4
  29. relay/channel/gemini/constant.go +2 -0
  30. relay/channel/gemini/dto.go +27 -0
  31. relay/channel/mistral/adaptor.go +1 -4
  32. relay/channel/mistral/text.go +8 -12
  33. relay/channel/ollama/adaptor.go +2 -1
  34. relay/channel/ollama/dto.go +17 -14
  35. relay/channel/ollama/relay-ollama.go +29 -4
  36. relay/channel/openai/adaptor.go +1 -1
  37. relay/channel/openai/relay-openai.go +6 -1
  38. relay/channel/siliconflow/adaptor.go +8 -0
  39. relay/channel/zhipu_4v/relay-zhipu_v4.go +1 -2
  40. relay/common/relay_info.go +39 -27
  41. relay/helper/model_mapped.go +25 -0
  42. relay/helper/price.go +41 -0
  43. relay/relay-audio.go +13 -20
  44. relay/relay-image.go +16 -24
  45. relay/relay-mj.go +2 -2
  46. relay/relay-text.go +40 -61
  47. relay/relay_embedding.go +9 -32
  48. relay/relay_rerank.go +9 -32
  49. relay/relay_task.go +1 -1
  50. router/api-router.go +1 -0
.env.example CHANGED
@@ -10,9 +10,9 @@
10
 
11
  # 数据库相关配置
12
  # 数据库连接字符串
13
- # SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
14
  # 日志数据库连接字符串
15
- # LOG_SQL_DSN=mysql://user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
16
  # SQLite数据库路径
17
  # SQLITE_PATH=/path/to/sqlite.db
18
  # 数据库最大空闲连接数
 
10
 
11
  # 数据库相关配置
12
  # 数据库连接字符串
13
+ # SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
14
  # 日志数据库连接字符串
15
+ # LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
16
  # SQLite数据库路径
17
  # SQLITE_PATH=/path/to/sqlite.db
18
  # 数据库最大空闲连接数
README.en.md CHANGED
@@ -89,6 +89,8 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
89
  - `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
90
  - `CRYPTO_SECRET`: Encryption key for encrypting database content
91
  - `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
 
 
92
 
93
  ## Deployment
94
 
 
89
  - `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
90
  - `CRYPTO_SECRET`: Encryption key for encrypting database content
91
  - `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
92
+ - `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
93
+ - `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
94
 
95
  ## Deployment
96
 
Rerank.md CHANGED
@@ -13,7 +13,7 @@ Request:
13
 
14
  ```json
15
  {
16
- "model": "rerank-multilingual-v3.0",
17
  "query": "What is the capital of the United States?",
18
  "top_n": 3,
19
  "documents": [
 
13
 
14
  ```json
15
  {
16
+ "model": "jina-reranker-v2-base-multilingual",
17
  "query": "What is the capital of the United States?",
18
  "top_n": 3,
19
  "documents": [
VERSION CHANGED
@@ -1 +0,0 @@
1
- v0.4.7.2.1
 
 
common/constants.go CHANGED
@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
101
 
102
  var RetryTimes = 0
103
 
104
- var RootUserEmail = ""
105
 
106
  var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
107
 
 
101
 
102
  var RetryTimes = 0
103
 
104
+ //var RootUserEmail = ""
105
 
106
  var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
107
 
common/go-channel.go CHANGED
@@ -1,22 +1,9 @@
1
  package common
2
 
3
  import (
4
- "fmt"
5
- "runtime/debug"
6
  "time"
7
  )
8
 
9
- func SafeGoroutine(f func()) {
10
- go func() {
11
- defer func() {
12
- if r := recover(); r != nil {
13
- SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
14
- }
15
- }()
16
- f()
17
- }()
18
- }
19
-
20
  func SafeSendBool(ch chan bool, value bool) (closed bool) {
21
  defer func() {
22
  // Recover from panic if one occured. A panic would mean the channel was closed.
 
1
  package common
2
 
3
  import (
 
 
4
  "time"
5
  )
6
 
 
 
 
 
 
 
 
 
 
 
 
7
  func SafeSendBool(ch chan bool, value bool) (closed bool) {
8
  defer func() {
9
  // Recover from panic if one occured. A panic would mean the channel was closed.
common/logger.go CHANGED
@@ -4,6 +4,7 @@ import (
4
  "context"
5
  "encoding/json"
6
  "fmt"
 
7
  "github.com/gin-gonic/gin"
8
  "io"
9
  "log"
@@ -80,9 +81,9 @@ func logHelper(ctx context.Context, level string, msg string) {
80
  if logCount > maxLogCount && !setupLogWorking {
81
  logCount = 0
82
  setupLogWorking = true
83
- go func() {
84
  SetupLogger()
85
- }()
86
  }
87
  }
88
 
@@ -100,6 +101,14 @@ func LogQuota(quota int) string {
100
  }
101
  }
102
 
 
 
 
 
 
 
 
 
103
  // LogJson 仅供测试使用 only for test
104
  func LogJson(ctx context.Context, msg string, obj any) {
105
  jsonStr, err := json.Marshal(obj)
 
4
  "context"
5
  "encoding/json"
6
  "fmt"
7
+ "github.com/bytedance/gopkg/util/gopool"
8
  "github.com/gin-gonic/gin"
9
  "io"
10
  "log"
 
81
  if logCount > maxLogCount && !setupLogWorking {
82
  logCount = 0
83
  setupLogWorking = true
84
+ gopool.Go(func() {
85
  SetupLogger()
86
+ })
87
  }
88
  }
89
 
 
101
  }
102
  }
103
 
104
+ func FormatQuota(quota int) string {
105
+ if DisplayInCurrencyEnabled {
106
+ return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
107
+ } else {
108
+ return fmt.Sprintf("%d", quota)
109
+ }
110
+ }
111
+
112
  // LogJson 仅供测试使用 only for test
113
  func LogJson(ctx context.Context, msg string, obj any) {
114
  jsonStr, err := json.Marshal(obj)
common/model-ratio.go CHANGED
@@ -233,7 +233,11 @@ var (
233
  modelRatioMapMutex = sync.RWMutex{}
234
  )
235
 
236
- var CompletionRatio map[string]float64 = nil
 
 
 
 
237
  var defaultCompletionRatio = map[string]float64{
238
  "gpt-4-gizmo-*": 2,
239
  "gpt-4o-gizmo-*": 3,
@@ -334,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 {
334
  return defaultModelRatio
335
  }
336
 
337
- func CompletionRatio2JSONString() string {
 
 
338
  if CompletionRatio == nil {
339
  CompletionRatio = defaultCompletionRatio
340
  }
 
 
 
 
 
341
  jsonBytes, err := json.Marshal(CompletionRatio)
342
  if err != nil {
343
  SysError("error marshalling completion ratio: " + err.Error())
@@ -346,11 +357,15 @@ func CompletionRatio2JSONString() string {
346
  }
347
 
348
  func UpdateCompletionRatioByJSONString(jsonStr string) error {
 
 
349
  CompletionRatio = make(map[string]float64)
350
  return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
351
  }
352
 
353
  func GetCompletionRatio(name string) float64 {
 
 
354
  if strings.Contains(name, "/") {
355
  if ratio, ok := CompletionRatio[name]; ok {
356
  return ratio
@@ -476,24 +491,3 @@ func GetAudioCompletionRatio(name string) float64 {
476
  }
477
  return 2
478
  }
479
-
480
- //func GetAudioPricePerMinute(name string) float64 {
481
- // if strings.HasPrefix(name, "gpt-4o-realtime") {
482
- // return 0.06
483
- // }
484
- // return 0.06
485
- //}
486
- //
487
- //func GetAudioCompletionPricePerMinute(name string) float64 {
488
- // if strings.HasPrefix(name, "gpt-4o-realtime") {
489
- // return 0.24
490
- // }
491
- // return 0.24
492
- //}
493
-
494
- func GetCompletionRatioMap() map[string]float64 {
495
- if CompletionRatio == nil {
496
- CompletionRatio = defaultCompletionRatio
497
- }
498
- return CompletionRatio
499
- }
 
233
  modelRatioMapMutex = sync.RWMutex{}
234
  )
235
 
236
+ var (
237
+ CompletionRatio map[string]float64 = nil
238
+ CompletionRatioMutex = sync.RWMutex{}
239
+ )
240
+
241
  var defaultCompletionRatio = map[string]float64{
242
  "gpt-4-gizmo-*": 2,
243
  "gpt-4o-gizmo-*": 3,
 
338
  return defaultModelRatio
339
  }
340
 
341
+ func GetCompletionRatioMap() map[string]float64 {
342
+ CompletionRatioMutex.Lock()
343
+ defer CompletionRatioMutex.Unlock()
344
  if CompletionRatio == nil {
345
  CompletionRatio = defaultCompletionRatio
346
  }
347
+ return CompletionRatio
348
+ }
349
+
350
+ func CompletionRatio2JSONString() string {
351
+ GetCompletionRatioMap()
352
  jsonBytes, err := json.Marshal(CompletionRatio)
353
  if err != nil {
354
  SysError("error marshalling completion ratio: " + err.Error())
 
357
  }
358
 
359
  func UpdateCompletionRatioByJSONString(jsonStr string) error {
360
+ CompletionRatioMutex.Lock()
361
+ defer CompletionRatioMutex.Unlock()
362
  CompletionRatio = make(map[string]float64)
363
  return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
364
  }
365
 
366
  func GetCompletionRatio(name string) float64 {
367
+ GetCompletionRatioMap()
368
+
369
  if strings.Contains(name, "/") {
370
  if ratio, ok := CompletionRatio[name]; ok {
371
  return ratio
 
491
  }
492
  return 2
493
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
constant/env.go CHANGED
@@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{
29
 
30
  var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
31
 
 
 
 
32
  func InitEnv() {
33
  modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
34
  if modelVersionMapStr == "" {
@@ -44,5 +47,5 @@ func InitEnv() {
44
  }
45
  }
46
 
47
- // 是否生成初始令牌,默认关闭。
48
  var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
 
29
 
30
  var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
31
 
32
+ var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
33
+ var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
34
+
35
  func InitEnv() {
36
  modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
37
  if modelVersionMapStr == "" {
 
47
  }
48
  }
49
 
50
+ // GenerateDefaultToken 是否生成初始令牌,默认关闭。
51
  var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
constant/user_setting.go ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package constant
2
+
3
+ var (
4
+ UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
5
+ UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
6
+ UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
7
+ UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
8
+ UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
9
+ )
10
+
11
+ var (
12
+ NotifyTypeEmail = "email" // Email 邮件
13
+ NotifyTypeWebhook = "webhook" // Webhook
14
+ )
controller/channel-test.go CHANGED
@@ -238,9 +238,7 @@ var testAllChannelsLock sync.Mutex
238
  var testAllChannelsRunning bool = false
239
 
240
  func testAllChannels(notify bool) error {
241
- if common.RootUserEmail == "" {
242
- common.RootUserEmail = model.GetRootUserEmail()
243
- }
244
  testAllChannelsLock.Lock()
245
  if testAllChannelsRunning {
246
  testAllChannelsLock.Unlock()
@@ -295,10 +293,7 @@ func testAllChannels(notify bool) error {
295
  testAllChannelsRunning = false
296
  testAllChannelsLock.Unlock()
297
  if notify {
298
- err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
299
- if err != nil {
300
- common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
301
- }
302
  }
303
  })
304
  return nil
 
238
  var testAllChannelsRunning bool = false
239
 
240
  func testAllChannels(notify bool) error {
241
+
 
 
242
  testAllChannelsLock.Lock()
243
  if testAllChannelsRunning {
244
  testAllChannelsLock.Unlock()
 
293
  testAllChannelsRunning = false
294
  testAllChannelsLock.Unlock()
295
  if notify {
296
+ service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
 
 
 
297
  }
298
  })
299
  return nil
controller/pricing.go CHANGED
@@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) {
17
  }
18
  var group string
19
  if exists {
20
- user, err := model.GetUserById(userId.(int), false)
21
  if err == nil {
22
  group = user.Group
23
  }
 
17
  }
18
  var group string
19
  if exists {
20
+ user, err := model.GetUserCache(userId.(int))
21
  if err == nil {
22
  group = user.Group
23
  }
controller/relay.go CHANGED
@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
24
  var err *dto.OpenAIErrorWithStatusCode
25
  switch relayMode {
26
  case relayconstant.RelayModeImagesGenerations:
27
- err = relay.ImageHelper(c, relayMode)
28
  case relayconstant.RelayModeAudioSpeech:
29
  fallthrough
30
  case relayconstant.RelayModeAudioTranslation:
 
24
  var err *dto.OpenAIErrorWithStatusCode
25
  switch relayMode {
26
  case relayconstant.RelayModeImagesGenerations:
27
+ err = relay.ImageHelper(c)
28
  case relayconstant.RelayModeAudioSpeech:
29
  fallthrough
30
  case relayconstant.RelayModeAudioTranslation:
controller/user.go CHANGED
@@ -4,6 +4,7 @@ import (
4
  "encoding/json"
5
  "fmt"
6
  "net/http"
 
7
  "one-api/common"
8
  "one-api/model"
9
  "one-api/setting"
@@ -471,7 +472,7 @@ func GetUserModels(c *gin.Context) {
471
  if err != nil {
472
  id = c.GetInt("id")
473
  }
474
- user, err := model.GetUserById(id, true)
475
  if err != nil {
476
  c.JSON(http.StatusOK, gin.H{
477
  "success": false,
@@ -869,9 +870,6 @@ func EmailBind(c *gin.Context) {
869
  })
870
  return
871
  }
872
- if user.Role == common.RoleRootUser {
873
- common.RootUserEmail = email
874
- }
875
  c.JSON(http.StatusOK, gin.H{
876
  "success": true,
877
  "message": "",
@@ -913,3 +911,115 @@ func TopUp(c *gin.Context) {
913
  })
914
  return
915
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  "encoding/json"
5
  "fmt"
6
  "net/http"
7
+ "net/url"
8
  "one-api/common"
9
  "one-api/model"
10
  "one-api/setting"
 
472
  if err != nil {
473
  id = c.GetInt("id")
474
  }
475
+ user, err := model.GetUserCache(id)
476
  if err != nil {
477
  c.JSON(http.StatusOK, gin.H{
478
  "success": false,
 
870
  })
871
  return
872
  }
 
 
 
873
  c.JSON(http.StatusOK, gin.H{
874
  "success": true,
875
  "message": "",
 
911
  })
912
  return
913
  }
914
+
915
+ type UpdateUserSettingRequest struct {
916
+ QuotaWarningType string `json:"notify_type"`
917
+ QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
918
+ WebhookUrl string `json:"webhook_url,omitempty"`
919
+ WebhookSecret string `json:"webhook_secret,omitempty"`
920
+ NotificationEmail string `json:"notification_email,omitempty"`
921
+ }
922
+
923
+ func UpdateUserSetting(c *gin.Context) {
924
+ var req UpdateUserSettingRequest
925
+ if err := c.ShouldBindJSON(&req); err != nil {
926
+ c.JSON(http.StatusOK, gin.H{
927
+ "success": false,
928
+ "message": "无效的参数",
929
+ })
930
+ return
931
+ }
932
+
933
+ // 验证预警类型
934
+ if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
935
+ c.JSON(http.StatusOK, gin.H{
936
+ "success": false,
937
+ "message": "无效的预警类型",
938
+ })
939
+ return
940
+ }
941
+
942
+ // 验证预警阈值
943
+ if req.QuotaWarningThreshold <= 0 {
944
+ c.JSON(http.StatusOK, gin.H{
945
+ "success": false,
946
+ "message": "预警阈值必须大于0",
947
+ })
948
+ return
949
+ }
950
+
951
+ // 如果是webhook类型,验证webhook地址
952
+ if req.QuotaWarningType == constant.NotifyTypeWebhook {
953
+ if req.WebhookUrl == "" {
954
+ c.JSON(http.StatusOK, gin.H{
955
+ "success": false,
956
+ "message": "Webhook地址不能为空",
957
+ })
958
+ return
959
+ }
960
+ // 验证URL格式
961
+ if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
962
+ c.JSON(http.StatusOK, gin.H{
963
+ "success": false,
964
+ "message": "无效的Webhook地址",
965
+ })
966
+ return
967
+ }
968
+ }
969
+
970
+ // 如果是邮件类型,验证邮箱地址
971
+ if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
972
+ // 验证邮箱格式
973
+ if !strings.Contains(req.NotificationEmail, "@") {
974
+ c.JSON(http.StatusOK, gin.H{
975
+ "success": false,
976
+ "message": "无效的邮箱地址",
977
+ })
978
+ return
979
+ }
980
+ }
981
+
982
+ userId := c.GetInt("id")
983
+ user, err := model.GetUserById(userId, true)
984
+ if err != nil {
985
+ c.JSON(http.StatusOK, gin.H{
986
+ "success": false,
987
+ "message": err.Error(),
988
+ })
989
+ return
990
+ }
991
+
992
+ // 构建设置
993
+ settings := map[string]interface{}{
994
+ constant.UserSettingNotifyType: req.QuotaWarningType,
995
+ constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
996
+ }
997
+
998
+ // 如果是webhook类型,添加webhook相关设置
999
+ if req.QuotaWarningType == constant.NotifyTypeWebhook {
1000
+ settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
1001
+ if req.WebhookSecret != "" {
1002
+ settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
1003
+ }
1004
+ }
1005
+
1006
+ // 如果提供了通知邮箱,添加到设置中
1007
+ if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
1008
+ settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
1009
+ }
1010
+
1011
+ // 更新用户设置
1012
+ user.SetSetting(settings)
1013
+ if err := user.Update(false); err != nil {
1014
+ c.JSON(http.StatusOK, gin.H{
1015
+ "success": false,
1016
+ "message": "更新设置失败: " + err.Error(),
1017
+ })
1018
+ return
1019
+ }
1020
+
1021
+ c.JSON(http.StatusOK, gin.H{
1022
+ "success": true,
1023
+ "message": "设置已更新",
1024
+ })
1025
+ }
docker-compose.yml CHANGED
@@ -24,7 +24,7 @@ services:
24
  - redis
25
  - mysql
26
  healthcheck:
27
- test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
28
  interval: 30s
29
  timeout: 10s
30
  retries: 3
 
24
  - redis
25
  - mysql
26
  healthcheck:
27
+ test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
28
  interval: 30s
29
  timeout: 10s
30
  retries: 3
dto/notify.go ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package dto
2
+
3
+ type Notify struct {
4
+ Type string `json:"type"`
5
+ Title string `json:"title"`
6
+ Content string `json:"content"`
7
+ Values []interface{} `json:"values"`
8
+ }
9
+
10
+ const ContentValueParam = "{{value}}"
11
+
12
+ const (
13
+ NotifyTypeQuotaExceed = "quota_exceed"
14
+ NotifyTypeChannelUpdate = "channel_update"
15
+ NotifyTypeChannelTest = "channel_test"
16
+ )
17
+
18
+ func NewNotify(t string, title string, content string, values []interface{}) Notify {
19
+ return Notify{
20
+ Type: t,
21
+ Title: title,
22
+ Content: content,
23
+ Values: values,
24
+ }
25
+ }
dto/openai_request.go CHANGED
@@ -18,6 +18,8 @@ type GeneralOpenAIRequest struct {
18
  Model string `json:"model,omitempty"`
19
  Messages []Message `json:"messages,omitempty"`
20
  Prompt any `json:"prompt,omitempty"`
 
 
21
  Stream bool `json:"stream,omitempty"`
22
  StreamOptions *StreamOptions `json:"stream_options,omitempty"`
23
  MaxTokens uint `json:"max_tokens,omitempty"`
@@ -86,18 +88,20 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
86
  }
87
 
88
  type Message struct {
89
- Role string `json:"role"`
90
- Content json.RawMessage `json:"content"`
91
- Name *string `json:"name,omitempty"`
92
- Prefix *bool `json:"prefix,omitempty"`
93
- ReasoningContent string `json:"reasoning_content,omitempty"`
94
- ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
95
- ToolCallId string `json:"tool_call_id,omitempty"`
 
 
96
  }
97
 
98
  type MediaContent struct {
99
  Type string `json:"type"`
100
- Text string `json:"text"`
101
  ImageUrl any `json:"image_url,omitempty"`
102
  InputAudio any `json:"input_audio,omitempty"`
103
  }
@@ -146,6 +150,9 @@ func (m *Message) SetToolCalls(toolCalls any) {
146
  }
147
 
148
  func (m *Message) StringContent() string {
 
 
 
149
  var stringContent string
150
  if err := json.Unmarshal(m.Content, &stringContent); err == nil {
151
  return stringContent
@@ -156,78 +163,113 @@ func (m *Message) StringContent() string {
156
  func (m *Message) SetStringContent(content string) {
157
  jsonContent, _ := json.Marshal(content)
158
  m.Content = jsonContent
 
 
 
 
 
 
 
 
 
159
  }
160
 
161
  func (m *Message) IsStringContent() bool {
 
 
 
162
  var stringContent string
163
  if err := json.Unmarshal(m.Content, &stringContent); err == nil {
 
164
  return true
165
  }
166
  return false
167
  }
168
 
169
  func (m *Message) ParseContent() []MediaContent {
 
 
 
 
170
  var contentList []MediaContent
 
 
171
  var stringContent string
172
  if err := json.Unmarshal(m.Content, &stringContent); err == nil {
173
- contentList = append(contentList, MediaContent{
174
  Type: ContentTypeText,
175
  Text: stringContent,
176
- })
 
177
  return contentList
178
  }
179
- var arrayContent []json.RawMessage
 
 
180
  if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
181
  for _, contentItem := range arrayContent {
182
- var contentMap map[string]any
183
- if err := json.Unmarshal(contentItem, &contentMap); err != nil {
184
  continue
185
  }
186
- switch contentMap["type"] {
 
187
  case ContentTypeText:
188
- if subStr, ok := contentMap["text"].(string); ok {
189
  contentList = append(contentList, MediaContent{
190
  Type: ContentTypeText,
191
- Text: subStr,
192
  })
193
  }
 
194
  case ContentTypeImageURL:
195
- if subObj, ok := contentMap["image_url"].(map[string]any); ok {
196
- detail, ok := subObj["detail"]
197
- if ok {
198
- subObj["detail"] = detail.(string)
199
- } else {
200
- subObj["detail"] = "high"
201
- }
202
  contentList = append(contentList, MediaContent{
203
  Type: ContentTypeImageURL,
204
  ImageUrl: MessageImageUrl{
205
- Url: subObj["url"].(string),
206
- Detail: subObj["detail"].(string),
207
- },
208
- })
209
- } else if url, ok := contentMap["image_url"].(string); ok {
210
- contentList = append(contentList, MediaContent{
211
- Type: ContentTypeImageURL,
212
- ImageUrl: MessageImageUrl{
213
- Url: url,
214
  Detail: "high",
215
  },
216
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  }
 
218
  case ContentTypeInputAudio:
219
- if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
220
- contentList = append(contentList, MediaContent{
221
- Type: ContentTypeInputAudio,
222
- InputAudio: MessageInputAudio{
223
- Data: subObj["data"].(string),
224
- Format: subObj["format"].(string),
225
- },
226
- })
 
 
 
 
227
  }
228
  }
229
  }
230
- return contentList
231
  }
232
- return nil
 
 
 
 
233
  }
 
18
  Model string `json:"model,omitempty"`
19
  Messages []Message `json:"messages,omitempty"`
20
  Prompt any `json:"prompt,omitempty"`
21
+ Prefix any `json:"prefix,omitempty"`
22
+ Suffix any `json:"suffix,omitempty"`
23
  Stream bool `json:"stream,omitempty"`
24
  StreamOptions *StreamOptions `json:"stream_options,omitempty"`
25
  MaxTokens uint `json:"max_tokens,omitempty"`
 
88
  }
89
 
90
  type Message struct {
91
+ Role string `json:"role"`
92
+ Content json.RawMessage `json:"content"`
93
+ Name *string `json:"name,omitempty"`
94
+ Prefix *bool `json:"prefix,omitempty"`
95
+ ReasoningContent string `json:"reasoning_content,omitempty"`
96
+ ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
97
+ ToolCallId string `json:"tool_call_id,omitempty"`
98
+ parsedContent []MediaContent
99
+ parsedStringContent *string
100
  }
101
 
102
  type MediaContent struct {
103
  Type string `json:"type"`
104
+ Text string `json:"text,omitempty"`
105
  ImageUrl any `json:"image_url,omitempty"`
106
  InputAudio any `json:"input_audio,omitempty"`
107
  }
 
150
  }
151
 
152
  func (m *Message) StringContent() string {
153
+ if m.parsedStringContent != nil {
154
+ return *m.parsedStringContent
155
+ }
156
  var stringContent string
157
  if err := json.Unmarshal(m.Content, &stringContent); err == nil {
158
  return stringContent
 
163
  func (m *Message) SetStringContent(content string) {
164
  jsonContent, _ := json.Marshal(content)
165
  m.Content = jsonContent
166
+ m.parsedStringContent = &content
167
+ m.parsedContent = nil
168
+ }
169
+
170
+ func (m *Message) SetMediaContent(content []MediaContent) {
171
+ jsonContent, _ := json.Marshal(content)
172
+ m.Content = jsonContent
173
+ m.parsedContent = nil
174
+ m.parsedStringContent = nil
175
  }
176
 
177
  func (m *Message) IsStringContent() bool {
178
+ if m.parsedStringContent != nil {
179
+ return true
180
+ }
181
  var stringContent string
182
  if err := json.Unmarshal(m.Content, &stringContent); err == nil {
183
+ m.parsedStringContent = &stringContent
184
  return true
185
  }
186
  return false
187
  }
188
 
189
  func (m *Message) ParseContent() []MediaContent {
190
+ if m.parsedContent != nil {
191
+ return m.parsedContent
192
+ }
193
+
194
  var contentList []MediaContent
195
+
196
+ // 先尝试解析为字符串
197
  var stringContent string
198
  if err := json.Unmarshal(m.Content, &stringContent); err == nil {
199
+ contentList = []MediaContent{{
200
  Type: ContentTypeText,
201
  Text: stringContent,
202
+ }}
203
+ m.parsedContent = contentList
204
  return contentList
205
  }
206
+
207
+ // 尝试解析为数组
208
+ var arrayContent []map[string]interface{}
209
  if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
210
  for _, contentItem := range arrayContent {
211
+ contentType, ok := contentItem["type"].(string)
212
+ if !ok {
213
  continue
214
  }
215
+
216
+ switch contentType {
217
  case ContentTypeText:
218
+ if text, ok := contentItem["text"].(string); ok {
219
  contentList = append(contentList, MediaContent{
220
  Type: ContentTypeText,
221
+ Text: text,
222
  })
223
  }
224
+
225
  case ContentTypeImageURL:
226
+ imageUrl := contentItem["image_url"]
227
+ switch v := imageUrl.(type) {
228
+ case string:
 
 
 
 
229
  contentList = append(contentList, MediaContent{
230
  Type: ContentTypeImageURL,
231
  ImageUrl: MessageImageUrl{
232
+ Url: v,
 
 
 
 
 
 
 
 
233
  Detail: "high",
234
  },
235
  })
236
+ case map[string]interface{}:
237
+ url, ok1 := v["url"].(string)
238
+ detail, ok2 := v["detail"].(string)
239
+ if !ok2 {
240
+ detail = "high"
241
+ }
242
+ if ok1 {
243
+ contentList = append(contentList, MediaContent{
244
+ Type: ContentTypeImageURL,
245
+ ImageUrl: MessageImageUrl{
246
+ Url: url,
247
+ Detail: detail,
248
+ },
249
+ })
250
+ }
251
  }
252
+
253
  case ContentTypeInputAudio:
254
+ if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
255
+ data, ok1 := audioData["data"].(string)
256
+ format, ok2 := audioData["format"].(string)
257
+ if ok1 && ok2 {
258
+ contentList = append(contentList, MediaContent{
259
+ Type: ContentTypeInputAudio,
260
+ InputAudio: MessageInputAudio{
261
+ Data: data,
262
+ Format: format,
263
+ },
264
+ })
265
+ }
266
  }
267
  }
268
  }
 
269
  }
270
+
271
+ if len(contentList) > 0 {
272
+ m.parsedContent = contentList
273
+ }
274
+ return contentList
275
  }
dto/openai_response.go CHANGED
@@ -62,9 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
62
  }
63
 
64
  type ChatCompletionsStreamResponseChoiceDelta struct {
65
- Content *string `json:"content,omitempty"`
66
- Role string `json:"role,omitempty"`
67
- ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 
68
  }
69
 
70
  func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -78,6 +79,13 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
78
  return *c.Content
79
  }
80
 
 
 
 
 
 
 
 
81
  type ToolCall struct {
82
  // Index is not nil only in chat completion chunk object
83
  Index *int `json:"index,omitempty"`
 
62
  }
63
 
64
  type ChatCompletionsStreamResponseChoiceDelta struct {
65
+ Content *string `json:"content,omitempty"`
66
+ ReasoningContent *string `json:"reasoning_content,omitempty"`
67
+ Role string `json:"role,omitempty"`
68
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
69
  }
70
 
71
  func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
 
79
  return *c.Content
80
  }
81
 
82
+ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
83
+ if c.ReasoningContent == nil {
84
+ return ""
85
+ }
86
+ return *c.ReasoningContent
87
+ }
88
+
89
  type ToolCall struct {
90
  // Index is not nil only in chat completion chunk object
91
  Index *int `json:"index,omitempty"`
main.go CHANGED
@@ -119,9 +119,9 @@ func main() {
119
  }
120
 
121
  if os.Getenv("ENABLE_PPROF") == "true" {
122
- go func() {
123
  log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
124
- }()
125
  go common.Monitor()
126
  common.SysLog("pprof enabled")
127
  }
 
119
  }
120
 
121
  if os.Getenv("ENABLE_PPROF") == "true" {
122
+ gopool.Go(func() {
123
  log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
124
+ })
125
  go common.Monitor()
126
  common.SysLog("pprof enabled")
127
  }
middleware/distributor.go CHANGED
@@ -135,17 +135,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
135
  midjourneyRequest := dto.MidjourneyRequest{}
136
  err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
137
  if err != nil {
138
- abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
139
  return nil, false, err
140
  }
141
  midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
142
  if mjErr != nil {
143
- abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
144
  return nil, false, fmt.Errorf(mjErr.Description)
145
  }
146
  if midjourneyModel == "" {
147
  if !success {
148
- abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
149
  return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
150
  } else {
151
  // task fetch, task fetch by condition, notify
@@ -170,7 +167,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
170
  err = common.UnmarshalBodyReusable(c, &modelRequest)
171
  }
172
  if err != nil {
173
- abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
174
  return nil, false, errors.New("无效的请求, " + err.Error())
175
  }
176
  if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
 
135
  midjourneyRequest := dto.MidjourneyRequest{}
136
  err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
137
  if err != nil {
 
138
  return nil, false, err
139
  }
140
  midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
141
  if mjErr != nil {
 
142
  return nil, false, fmt.Errorf(mjErr.Description)
143
  }
144
  if midjourneyModel == "" {
145
  if !success {
 
146
  return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
147
  } else {
148
  // task fetch, task fetch by condition, notify
 
167
  err = common.UnmarshalBodyReusable(c, &modelRequest)
168
  }
169
  if err != nil {
 
170
  return nil, false, errors.New("无效的请求, " + err.Error())
171
  }
172
  if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
model/option.go CHANGED
@@ -84,7 +84,7 @@ func InitOptionMap() {
84
  common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
85
  common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
86
  common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
87
- common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
88
  common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
89
  common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
90
  common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -306,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) {
306
  common.QuotaForInvitee, _ = strconv.Atoi(value)
307
  case "QuotaRemindThreshold":
308
  common.QuotaRemindThreshold, _ = strconv.Atoi(value)
309
- case "PreConsumedQuota":
310
  common.PreConsumedQuota, _ = strconv.Atoi(value)
311
  case "RetryTimes":
312
  common.RetryTimes, _ = strconv.Atoi(value)
 
84
  common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
85
  common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
86
  common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
87
+ common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
88
  common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
89
  common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
90
  common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
 
306
  common.QuotaForInvitee, _ = strconv.Atoi(value)
307
  case "QuotaRemindThreshold":
308
  common.QuotaRemindThreshold, _ = strconv.Atoi(value)
309
+ case "ShouldPreConsumedQuota":
310
  common.PreConsumedQuota, _ = strconv.Atoi(value)
311
  case "RetryTimes":
312
  common.RetryTimes, _ = strconv.Atoi(value)
model/token.go CHANGED
@@ -3,13 +3,11 @@ package model
3
  import (
4
  "errors"
5
  "fmt"
6
- "github.com/bytedance/gopkg/util/gopool"
7
- "gorm.io/gorm"
8
  "one-api/common"
9
- relaycommon "one-api/relay/common"
10
- "one-api/setting"
11
- "strconv"
12
  "strings"
 
 
 
13
  )
14
 
15
  type Token struct {
@@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
322
  ).Error
323
  return err
324
  }
325
-
326
- func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
327
- if quota < 0 {
328
- return errors.New("quota 不能为负数!")
329
- }
330
- if relayInfo.IsPlayground {
331
- return nil
332
- }
333
- //if relayInfo.TokenUnlimited {
334
- // return nil
335
- //}
336
- token, err := GetTokenById(relayInfo.TokenId)
337
- if err != nil {
338
- return err
339
- }
340
- if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
341
- return errors.New("令牌额度不足")
342
- }
343
- err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
344
- if err != nil {
345
- return err
346
- }
347
- return nil
348
- }
349
-
350
- func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
351
-
352
- if quota > 0 {
353
- err = DecreaseUserQuota(relayInfo.UserId, quota)
354
- } else {
355
- err = IncreaseUserQuota(relayInfo.UserId, -quota)
356
- }
357
- if err != nil {
358
- return err
359
- }
360
-
361
- if !relayInfo.IsPlayground {
362
- if quota > 0 {
363
- err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
364
- } else {
365
- err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
366
- }
367
- if err != nil {
368
- return err
369
- }
370
- }
371
-
372
- if sendEmail {
373
- if (quota + preConsumedQuota) != 0 {
374
- quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
375
- noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
376
- if quotaTooLow || noMoreQuota {
377
- go func() {
378
- email, err := GetUserEmail(relayInfo.UserId)
379
- if err != nil {
380
- common.SysError("failed to fetch user email: " + err.Error())
381
- }
382
- prompt := "您的额度即将用尽"
383
- if noMoreQuota {
384
- prompt = "您的额度已用尽"
385
- }
386
- if email != "" {
387
- topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
388
- err = common.SendEmail(prompt, email,
389
- fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
390
- if err != nil {
391
- common.SysError("failed to send email" + err.Error())
392
- }
393
- common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
394
- }
395
- }()
396
- }
397
- }
398
- }
399
-
400
- return nil
401
- }
 
3
  import (
4
  "errors"
5
  "fmt"
 
 
6
  "one-api/common"
 
 
 
7
  "strings"
8
+
9
+ "github.com/bytedance/gopkg/util/gopool"
10
+ "gorm.io/gorm"
11
  )
12
 
13
  type Token struct {
 
320
  ).Error
321
  return err
322
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/token_cache.go CHANGED
@@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error {
52
  func cacheGetTokenByKey(key string) (*Token, error) {
53
  hmacKey := common.GenerateHMAC(key)
54
  if !common.RedisEnabled {
55
- return nil, nil
56
  }
57
  var token Token
58
  err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
 
52
  func cacheGetTokenByKey(key string) (*Token, error) {
53
  hmacKey := common.GenerateHMAC(key)
54
  if !common.RedisEnabled {
55
+ return nil, fmt.Errorf("redis is not enabled")
56
  }
57
  var token Token
58
  err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
model/user.go CHANGED
@@ -1,6 +1,7 @@
1
  package model
2
 
3
  import (
 
4
  "errors"
5
  "fmt"
6
  "one-api/common"
@@ -38,6 +39,20 @@ type User struct {
38
  InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
39
  DeletedAt gorm.DeletedAt `gorm:"index"`
40
  LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  }
42
 
43
  func (user *User) GetAccessToken() string {
@@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) {
51
  user.AccessToken = &token
52
  }
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
55
  func CheckUserExistOrDeleted(username string, email string) (bool, error) {
56
  var user User
@@ -315,8 +346,8 @@ func (user *User) Update(updatePassword bool) error {
315
  return err
316
  }
317
 
318
- // 更新缓存
319
- return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
320
  }
321
 
322
  func (user *User) Edit(updatePassword bool) error {
@@ -344,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error {
344
  return err
345
  }
346
 
347
- // 更新缓存
348
- return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
349
  }
350
 
351
  func (user *User) Delete() error {
@@ -371,8 +402,8 @@ func (user *User) HardDelete() error {
371
  // ValidateAndFill check password & user status
372
  func (user *User) ValidateAndFill() (err error) {
373
  // When querying with struct, GORM will only query with non-zero fields,
374
- // that means if your fields value is 0, '', false or other zero values,
375
- // it wont be used to build query conditions
376
  password := user.Password
377
  username := strings.TrimSpace(user.Username)
378
  if username == "" || password == "" {
@@ -531,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
531
  return quota, nil
532
  }
533
  // Don't return error - fall through to DB
534
- //common.SysError("failed to get user quota from cache: " + err.Error())
535
  }
536
  fromDB = true
537
  err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
@@ -580,6 +610,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
580
  return group, nil
581
  }
582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
  func IncreaseUserQuota(id int, quota int) (err error) {
584
  if quota < 0 {
585
  return errors.New("quota 不能为负数!")
@@ -641,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
641
  }
642
  }
643
 
644
- func GetRootUserEmail() (email string) {
645
- DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
646
- return email
 
 
 
 
 
647
  }
648
 
649
  func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
@@ -725,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
725
  return !errors.Is(err, gorm.ErrRecordNotFound)
726
  }
727
 
728
- func (u *User) FillUserByLinuxDOId() error {
729
- if u.LinuxDOId == "" {
730
  return errors.New("linux do id is empty")
731
  }
732
- err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error
733
  return err
734
  }
 
1
  package model
2
 
3
  import (
4
+ "encoding/json"
5
  "errors"
6
  "fmt"
7
  "one-api/common"
 
39
  InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
40
  DeletedAt gorm.DeletedAt `gorm:"index"`
41
  LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
42
+ Setting string `json:"setting" gorm:"type:text;column:setting"`
43
+ }
44
+
45
+ func (user *User) ToBaseUser() *UserBase {
46
+ cache := &UserBase{
47
+ Id: user.Id,
48
+ Group: user.Group,
49
+ Quota: user.Quota,
50
+ Status: user.Status,
51
+ Username: user.Username,
52
+ Setting: user.Setting,
53
+ Email: user.Email,
54
+ }
55
+ return cache
56
  }
57
 
58
  func (user *User) GetAccessToken() string {
 
66
  user.AccessToken = &token
67
  }
68
 
69
+ func (user *User) GetSetting() map[string]interface{} {
70
+ if user.Setting == "" {
71
+ return nil
72
+ }
73
+ return common.StrToMap(user.Setting)
74
+ }
75
+
76
+ func (user *User) SetSetting(setting map[string]interface{}) {
77
+ settingBytes, err := json.Marshal(setting)
78
+ if err != nil {
79
+ common.SysError("failed to marshal setting: " + err.Error())
80
+ return
81
+ }
82
+ user.Setting = string(settingBytes)
83
+ }
84
+
85
  // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
86
  func CheckUserExistOrDeleted(username string, email string) (bool, error) {
87
  var user User
 
346
  return err
347
  }
348
 
349
+ // Update cache
350
+ return updateUserCache(*user)
351
  }
352
 
353
  func (user *User) Edit(updatePassword bool) error {
 
375
  return err
376
  }
377
 
378
+ // Update cache
379
+ return updateUserCache(*user)
380
  }
381
 
382
  func (user *User) Delete() error {
 
402
  // ValidateAndFill check password & user status
403
  func (user *User) ValidateAndFill() (err error) {
404
  // When querying with struct, GORM will only query with non-zero fields,
405
+ // that means if your field's value is 0, '', false or other zero values,
406
+ // it won't be used to build query conditions
407
  password := user.Password
408
  username := strings.TrimSpace(user.Username)
409
  if username == "" || password == "" {
 
562
  return quota, nil
563
  }
564
  // Don't return error - fall through to DB
 
565
  }
566
  fromDB = true
567
  err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
 
610
  return group, nil
611
  }
612
 
613
+ // GetUserSetting gets setting from Redis first, falls back to DB if needed
614
+ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
615
+ var setting string
616
+ defer func() {
617
+ // Update Redis cache asynchronously on successful DB read
618
+ if shouldUpdateRedis(fromDB, err) {
619
+ gopool.Go(func() {
620
+ if err := updateUserSettingCache(id, setting); err != nil {
621
+ common.SysError("failed to update user setting cache: " + err.Error())
622
+ }
623
+ })
624
+ }
625
+ }()
626
+ if !fromDB && common.RedisEnabled {
627
+ setting, err := getUserSettingCache(id)
628
+ if err == nil {
629
+ return setting, nil
630
+ }
631
+ // Don't return error - fall through to DB
632
+ }
633
+ fromDB = true
634
+ err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
635
+ if err != nil {
636
+ return map[string]interface{}{}, err
637
+ }
638
+
639
+ return common.StrToMap(setting), nil
640
+ }
641
+
642
  func IncreaseUserQuota(id int, quota int) (err error) {
643
  if quota < 0 {
644
  return errors.New("quota 不能为负数!")
 
700
  }
701
  }
702
 
703
+ //func GetRootUserEmail() (email string) {
704
+ // DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
705
+ // return email
706
+ //}
707
+
708
+ func GetRootUser() (user *User) {
709
+ DB.Where("role = ?", common.RoleRootUser).First(&user)
710
+ return user
711
  }
712
 
713
  func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
 
789
  return !errors.Is(err, gorm.ErrRecordNotFound)
790
  }
791
 
792
+ func (user *User) FillUserByLinuxDOId() error {
793
+ if user.LinuxDOId == "" {
794
  return errors.New("linux do id is empty")
795
  }
796
+ err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
797
  return err
798
  }
model/user_cache.go CHANGED
@@ -1,206 +1,213 @@
1
  package model
2
 
3
  import (
 
4
  "fmt"
5
  "one-api/common"
6
  "one-api/constant"
7
- "strconv"
8
  "time"
 
 
9
  )
10
 
11
- // Change UserCache struct to userCache
12
- type userCache struct {
13
  Id int `json:"id"`
14
  Group string `json:"group"`
 
15
  Quota int `json:"quota"`
16
  Status int `json:"status"`
17
- Role int `json:"role"`
18
  Username string `json:"username"`
 
19
  }
20
 
21
- // Rename all exported functions to private ones
22
- // invalidateUserCache clears all user related cache
23
- func invalidateUserCache(userId int) error {
24
- if !common.RedisEnabled {
25
  return nil
26
  }
27
-
28
- keys := []string{
29
- fmt.Sprintf(constant.UserGroupKeyFmt, userId),
30
- fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
31
- fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
32
- fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
33
- }
34
-
35
- for _, key := range keys {
36
- if err := common.RedisDel(key); err != nil {
37
- return fmt.Errorf("failed to delete cache key %s: %w", key, err)
38
- }
39
- }
40
- return nil
41
  }
42
 
43
- // updateUserGroupCache updates user group cache
44
- func updateUserGroupCache(userId int, group string) error {
45
- if !common.RedisEnabled {
46
- return nil
 
47
  }
48
- return common.RedisSet(
49
- fmt.Sprintf(constant.UserGroupKeyFmt, userId),
50
- group,
51
- time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
52
- )
53
  }
54
 
55
- // updateUserQuotaCache updates user quota cache
56
- func updateUserQuotaCache(userId int, quota int) error {
57
- if !common.RedisEnabled {
58
- return nil
59
- }
60
- return common.RedisSet(
61
- fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
62
- fmt.Sprintf("%d", quota),
63
- time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
64
- )
65
  }
66
 
67
- // updateUserStatusCache updates user status cache
68
- func updateUserStatusCache(userId int, userEnabled bool) error {
69
  if !common.RedisEnabled {
70
  return nil
71
  }
72
- enabled := "0"
73
- if userEnabled {
74
- enabled = "1"
75
- }
76
- return common.RedisSet(
77
- fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
78
- enabled,
79
- time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
80
- )
81
  }
82
 
83
- // updateUserNameCache updates username cache
84
- func updateUserNameCache(userId int, username string) error {
85
  if !common.RedisEnabled {
86
  return nil
87
  }
88
- return common.RedisSet(
89
- fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
90
- username,
 
91
  time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
92
  )
93
  }
94
 
95
- // updateUserCache updates all user cache fields
96
- func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
97
- if !common.RedisEnabled {
98
- return nil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  }
100
 
101
- if err := updateUserGroupCache(userId, userGroup); err != nil {
102
- return fmt.Errorf("update group cache: %w", err)
 
 
 
103
  }
104
 
105
- if err := updateUserQuotaCache(userId, quota); err != nil {
106
- return fmt.Errorf("update quota cache: %w", err)
 
 
 
 
 
 
 
107
  }
108
 
109
- if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
110
- return fmt.Errorf("update status cache: %w", err)
 
 
 
 
111
  }
 
 
 
 
 
 
 
 
112
 
113
- if err := updateUserNameCache(userId, username); err != nil {
114
- return fmt.Errorf("update username cache: %w", err)
 
 
115
  }
 
 
116
 
117
- return nil
 
118
  }
119
 
120
- // getUserGroupCache gets user group from cache
121
  func getUserGroupCache(userId int) (string, error) {
122
- if !common.RedisEnabled {
123
- return "", nil
 
124
  }
125
- return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
126
  }
127
 
128
- // getUserQuotaCache gets user quota from cache
129
  func getUserQuotaCache(userId int) (int, error) {
130
- if !common.RedisEnabled {
131
- return 0, nil
132
- }
133
- quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
134
  if err != nil {
135
  return 0, err
136
  }
137
- return strconv.Atoi(quotaStr)
138
  }
139
 
140
- // getUserStatusCache gets user status from cache
141
  func getUserStatusCache(userId int) (int, error) {
142
- if !common.RedisEnabled {
143
- return 0, nil
144
- }
145
- statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
146
  if err != nil {
147
  return 0, err
148
  }
149
- return strconv.Atoi(statusStr)
150
  }
151
 
152
- // getUserNameCache gets username from cache
153
  func getUserNameCache(userId int) (string, error) {
154
- if !common.RedisEnabled {
155
- return "", nil
 
156
  }
157
- return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
158
  }
159
 
160
- // getUserCache gets complete user cache
161
- func getUserCache(userId int) (*userCache, error) {
162
- if !common.RedisEnabled {
163
- return nil, nil
164
- }
165
-
166
- group, err := getUserGroupCache(userId)
167
  if err != nil {
168
- return nil, fmt.Errorf("get group cache: %w", err)
169
  }
 
 
170
 
171
- quota, err := getUserQuotaCache(userId)
172
- if err != nil {
173
- return nil, fmt.Errorf("get quota cache: %w", err)
 
174
  }
175
-
176
- status, err := getUserStatusCache(userId)
177
- if err != nil {
178
- return nil, fmt.Errorf("get status cache: %w", err)
179
  }
 
 
180
 
181
- username, err := getUserNameCache(userId)
182
- if err != nil {
183
- return nil, fmt.Errorf("get username cache: %w", err)
184
  }
 
 
185
 
186
- return &userCache{
187
- Id: userId,
188
- Group: group,
189
- Quota: quota,
190
- Status: status,
191
- Username: username,
192
- }, nil
193
  }
194
 
195
- // Add atomic quota operations
196
- func cacheIncrUserQuota(userId int, delta int64) error {
197
  if !common.RedisEnabled {
198
  return nil
199
  }
200
- key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
201
- return common.RedisIncr(key, delta)
202
  }
203
 
204
- func cacheDecrUserQuota(userId int, delta int64) error {
205
- return cacheIncrUserQuota(userId, -delta)
 
 
 
206
  }
 
1
  package model
2
 
3
  import (
4
+ "encoding/json"
5
  "fmt"
6
  "one-api/common"
7
  "one-api/constant"
 
8
  "time"
9
+
10
+ "github.com/bytedance/gopkg/util/gopool"
11
  )
12
 
13
+ // UserBase struct remains the same as it represents the cached data structure
14
+ type UserBase struct {
15
  Id int `json:"id"`
16
  Group string `json:"group"`
17
+ Email string `json:"email"`
18
  Quota int `json:"quota"`
19
  Status int `json:"status"`
 
20
  Username string `json:"username"`
21
+ Setting string `json:"setting"`
22
  }
23
 
24
+ func (user *UserBase) GetSetting() map[string]interface{} {
25
+ if user.Setting == "" {
 
 
26
  return nil
27
  }
28
+ return common.StrToMap(user.Setting)
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  }
30
 
31
+ func (user *UserBase) SetSetting(setting map[string]interface{}) {
32
+ settingBytes, err := json.Marshal(setting)
33
+ if err != nil {
34
+ common.SysError("failed to marshal setting: " + err.Error())
35
+ return
36
  }
37
+ user.Setting = string(settingBytes)
 
 
 
 
38
  }
39
 
40
+ // getUserCacheKey returns the key for user cache
41
+ func getUserCacheKey(userId int) string {
42
+ return fmt.Sprintf("user:%d", userId)
 
 
 
 
 
 
 
43
  }
44
 
45
+ // invalidateUserCache clears user cache
46
+ func invalidateUserCache(userId int) error {
47
  if !common.RedisEnabled {
48
  return nil
49
  }
50
+ return common.RedisHDelObj(getUserCacheKey(userId))
 
 
 
 
 
 
 
 
51
  }
52
 
53
+ // updateUserCache updates all user cache fields using hash
54
+ func updateUserCache(user User) error {
55
  if !common.RedisEnabled {
56
  return nil
57
  }
58
+
59
+ return common.RedisHSetObj(
60
+ getUserCacheKey(user.Id),
61
+ user.ToBaseUser(),
62
  time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
63
  )
64
  }
65
 
66
+ // GetUserCache gets complete user cache from hash
67
+ func GetUserCache(userId int) (userCache *UserBase, err error) {
68
+ var user *User
69
+ var fromDB bool
70
+ defer func() {
71
+ // Update Redis cache asynchronously on successful DB read
72
+ if shouldUpdateRedis(fromDB, err) && user != nil {
73
+ gopool.Go(func() {
74
+ if err := updateUserCache(*user); err != nil {
75
+ common.SysError("failed to update user status cache: " + err.Error())
76
+ }
77
+ })
78
+ }
79
+ }()
80
+
81
+ // Try getting from Redis first
82
+ userCache, err = cacheGetUserBase(userId)
83
+ if err == nil {
84
+ return userCache, nil
85
  }
86
 
87
+ // If Redis fails, get from DB
88
+ fromDB = true
89
+ user, err = GetUserById(userId, false)
90
+ if err != nil {
91
+ return nil, err // Return nil and error if DB lookup fails
92
  }
93
 
94
+ // Create cache object from user data
95
+ userCache = &UserBase{
96
+ Id: user.Id,
97
+ Group: user.Group,
98
+ Quota: user.Quota,
99
+ Status: user.Status,
100
+ Username: user.Username,
101
+ Setting: user.Setting,
102
+ Email: user.Email,
103
  }
104
 
105
+ return userCache, nil
106
+ }
107
+
108
+ func cacheGetUserBase(userId int) (*UserBase, error) {
109
+ if !common.RedisEnabled {
110
+ return nil, fmt.Errorf("redis is not enabled")
111
  }
112
+ var userCache UserBase
113
+ // Try getting from Redis first
114
+ err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
115
+ if err != nil {
116
+ return nil, err
117
+ }
118
+ return &userCache, nil
119
+ }
120
 
121
+ // Add atomic quota operations using hash fields
122
+ func cacheIncrUserQuota(userId int, delta int64) error {
123
+ if !common.RedisEnabled {
124
+ return nil
125
  }
126
+ return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
127
+ }
128
 
129
+ func cacheDecrUserQuota(userId int, delta int64) error {
130
+ return cacheIncrUserQuota(userId, -delta)
131
  }
132
 
133
+ // Helper functions to get individual fields if needed
134
  func getUserGroupCache(userId int) (string, error) {
135
+ cache, err := GetUserCache(userId)
136
+ if err != nil {
137
+ return "", err
138
  }
139
+ return cache.Group, nil
140
  }
141
 
 
142
  func getUserQuotaCache(userId int) (int, error) {
143
+ cache, err := GetUserCache(userId)
 
 
 
144
  if err != nil {
145
  return 0, err
146
  }
147
+ return cache.Quota, nil
148
  }
149
 
 
150
  func getUserStatusCache(userId int) (int, error) {
151
+ cache, err := GetUserCache(userId)
 
 
 
152
  if err != nil {
153
  return 0, err
154
  }
155
+ return cache.Status, nil
156
  }
157
 
 
158
  func getUserNameCache(userId int) (string, error) {
159
+ cache, err := GetUserCache(userId)
160
+ if err != nil {
161
+ return "", err
162
  }
163
+ return cache.Username, nil
164
  }
165
 
166
+ func getUserSettingCache(userId int) (map[string]interface{}, error) {
167
+ setting := make(map[string]interface{})
168
+ cache, err := GetUserCache(userId)
 
 
 
 
169
  if err != nil {
170
+ return setting, err
171
  }
172
+ return cache.GetSetting(), nil
173
+ }
174
 
175
+ // New functions for individual field updates
176
+ func updateUserStatusCache(userId int, status bool) error {
177
+ if !common.RedisEnabled {
178
+ return nil
179
  }
180
+ statusInt := common.UserStatusEnabled
181
+ if !status {
182
+ statusInt = common.UserStatusDisabled
 
183
  }
184
+ return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
185
+ }
186
 
187
+ func updateUserQuotaCache(userId int, quota int) error {
188
+ if !common.RedisEnabled {
189
+ return nil
190
  }
191
+ return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
192
+ }
193
 
194
+ func updateUserGroupCache(userId int, group string) error {
195
+ if !common.RedisEnabled {
196
+ return nil
197
+ }
198
+ return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
 
 
199
  }
200
 
201
+ func updateUserNameCache(userId int, username string) error {
 
202
  if !common.RedisEnabled {
203
  return nil
204
  }
205
+ return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
 
206
  }
207
 
208
+ func updateUserSettingCache(userId int, setting string) error {
209
+ if !common.RedisEnabled {
210
+ return nil
211
+ }
212
+ return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
213
  }
relay/channel/cloudflare/adaptor.go CHANGED
@@ -4,13 +4,14 @@ import (
4
  "bytes"
5
  "errors"
6
  "fmt"
7
- "github.com/gin-gonic/gin"
8
  "io"
9
  "net/http"
10
  "one-api/dto"
11
  "one-api/relay/channel"
12
  relaycommon "one-api/relay/common"
13
  "one-api/relay/constant"
 
 
14
  )
15
 
16
  type Adaptor struct {
 
4
  "bytes"
5
  "errors"
6
  "fmt"
 
7
  "io"
8
  "net/http"
9
  "one-api/dto"
10
  "one-api/relay/channel"
11
  relaycommon "one-api/relay/common"
12
  "one-api/relay/constant"
13
+
14
+ "github.com/gin-gonic/gin"
15
  )
16
 
17
  type Adaptor struct {
relay/channel/deepseek/adaptor.go CHANGED
@@ -10,6 +10,7 @@ import (
10
  "one-api/relay/channel"
11
  "one-api/relay/channel/openai"
12
  relaycommon "one-api/relay/common"
 
13
  )
14
 
15
  type Adaptor struct {
@@ -29,7 +30,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
29
  }
30
 
31
  func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
32
- return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
 
 
 
 
 
33
  }
34
 
35
  func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 
10
  "one-api/relay/channel"
11
  "one-api/relay/channel/openai"
12
  relaycommon "one-api/relay/common"
13
+ "one-api/relay/constant"
14
  )
15
 
16
  type Adaptor struct {
 
30
  }
31
 
32
  func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
33
+ switch info.RelayMode {
34
+ case constant.RelayModeCompletions:
35
+ return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil
36
+ default:
37
+ return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
38
+ }
39
  }
40
 
41
  func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
relay/channel/gemini/adaptor.go CHANGED
@@ -1,15 +1,21 @@
1
  package gemini
2
 
3
  import (
 
4
  "errors"
5
  "fmt"
6
- "github.com/gin-gonic/gin"
7
  "io"
8
  "net/http"
 
9
  "one-api/constant"
10
  "one-api/dto"
11
  "one-api/relay/channel"
12
  relaycommon "one-api/relay/common"
 
 
 
 
 
13
  )
14
 
15
  type Adaptor struct {
@@ -21,8 +27,36 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
21
  }
22
 
23
  func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
24
- //TODO implement me
25
- return nil, errors.New("not implemented")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  }
27
 
28
  func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -40,6 +74,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
40
  }
41
  }
42
 
 
 
 
 
43
  action := "generateContent"
44
  if info.IsStream {
45
  action = "streamGenerateContent?alt=sse"
@@ -73,12 +111,15 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
73
  return nil, errors.New("not implemented")
74
  }
75
 
76
-
77
  func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
78
  return channel.DoApiRequest(a, c, info, requestBody)
79
  }
80
 
81
  func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 
 
 
 
82
  if info.IsStream {
83
  err, usage = GeminiChatStreamHandler(c, resp, info)
84
  } else {
@@ -87,6 +128,60 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
87
  return
88
  }
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  func (a *Adaptor) GetModelList() []string {
91
  return ModelList
92
  }
 
1
  package gemini
2
 
3
  import (
4
+ "encoding/json"
5
  "errors"
6
  "fmt"
 
7
  "io"
8
  "net/http"
9
+ "one-api/common"
10
  "one-api/constant"
11
  "one-api/dto"
12
  "one-api/relay/channel"
13
  relaycommon "one-api/relay/common"
14
+ "one-api/service"
15
+
16
+ "strings"
17
+
18
+ "github.com/gin-gonic/gin"
19
  )
20
 
21
  type Adaptor struct {
 
27
  }
28
 
29
  func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
30
+ if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
31
+ return nil, errors.New("not supported model for image generation")
32
+ }
33
+
34
+ // convert size to aspect ratio
35
+ aspectRatio := "1:1" // default aspect ratio
36
+ switch request.Size {
37
+ case "1024x1024":
38
+ aspectRatio = "1:1"
39
+ case "1024x1792":
40
+ aspectRatio = "9:16"
41
+ case "1792x1024":
42
+ aspectRatio = "16:9"
43
+ }
44
+
45
+ // build gemini imagen request
46
+ geminiRequest := GeminiImageRequest{
47
+ Instances: []GeminiImageInstance{
48
+ {
49
+ Prompt: request.Prompt,
50
+ },
51
+ },
52
+ Parameters: GeminiImageParameters{
53
+ SampleCount: request.N,
54
+ AspectRatio: aspectRatio,
55
+ PersonGeneration: "allow_adult", // default allow adult
56
+ },
57
+ }
58
+
59
+ return geminiRequest, nil
60
  }
61
 
62
  func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 
74
  }
75
  }
76
 
77
+ if strings.HasPrefix(info.UpstreamModelName, "imagen") {
78
+ return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
79
+ }
80
+
81
  action := "generateContent"
82
  if info.IsStream {
83
  action = "streamGenerateContent?alt=sse"
 
111
  return nil, errors.New("not implemented")
112
  }
113
 
 
114
  func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
115
  return channel.DoApiRequest(a, c, info, requestBody)
116
  }
117
 
118
  func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
119
+ if strings.HasPrefix(info.UpstreamModelName, "imagen") {
120
+ return GeminiImageHandler(c, resp, info)
121
+ }
122
+
123
  if info.IsStream {
124
  err, usage = GeminiChatStreamHandler(c, resp, info)
125
  } else {
 
128
  return
129
  }
130
 
131
+ func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
132
+ responseBody, readErr := io.ReadAll(resp.Body)
133
+ if readErr != nil {
134
+ return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
135
+ }
136
+ _ = resp.Body.Close()
137
+
138
+ var geminiResponse GeminiImageResponse
139
+ if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
140
+ return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
141
+ }
142
+
143
+ if len(geminiResponse.Predictions) == 0 {
144
+ return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
145
+ }
146
+
147
+ // convert to openai format response
148
+ openAIResponse := dto.ImageResponse{
149
+ Created: common.GetTimestamp(),
150
+ Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
151
+ }
152
+
153
+ for _, prediction := range geminiResponse.Predictions {
154
+ if prediction.RaiFilteredReason != "" {
155
+ continue // skip filtered image
156
+ }
157
+ openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
158
+ B64Json: prediction.BytesBase64Encoded,
159
+ })
160
+ }
161
+
162
+ jsonResponse, jsonErr := json.Marshal(openAIResponse)
163
+ if jsonErr != nil {
164
+ return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
165
+ }
166
+
167
+ c.Writer.Header().Set("Content-Type", "application/json")
168
+ c.Writer.WriteHeader(resp.StatusCode)
169
+ _, _ = c.Writer.Write(jsonResponse)
170
+
171
+ // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
172
+ // each image has fixed 258 tokens
173
+ const imageTokens = 258
174
+ generatedImages := len(openAIResponse.Data)
175
+
176
+ usage = &dto.Usage{
177
+ PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
178
+ CompletionTokens: 0, // image generation does not calculate completion tokens
179
+ TotalTokens: imageTokens * generatedImages,
180
+ }
181
+
182
+ return usage, nil
183
+ }
184
+
185
  func (a *Adaptor) GetModelList() []string {
186
  return ModelList
187
  }
relay/channel/gemini/constant.go CHANGED
@@ -16,6 +16,8 @@ var ModelList = []string{
16
  "gemini-2.0-pro-exp",
17
  // thinking exp
18
  "gemini-2.0-flash-thinking-exp",
 
 
19
  }
20
 
21
  var ChannelName = "google gemini"
 
16
  "gemini-2.0-pro-exp",
17
  // thinking exp
18
  "gemini-2.0-flash-thinking-exp",
19
+ // imagen models
20
+ "imagen-3.0-generate-002",
21
  }
22
 
23
  var ChannelName = "google gemini"
relay/channel/gemini/dto.go CHANGED
@@ -109,3 +109,30 @@ type GeminiUsageMetadata struct {
109
  CandidatesTokenCount int `json:"candidatesTokenCount"`
110
  TotalTokenCount int `json:"totalTokenCount"`
111
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  CandidatesTokenCount int `json:"candidatesTokenCount"`
110
  TotalTokenCount int `json:"totalTokenCount"`
111
  }
112
+
113
+ // Imagen related structs
114
+ type GeminiImageRequest struct {
115
+ Instances []GeminiImageInstance `json:"instances"`
116
+ Parameters GeminiImageParameters `json:"parameters"`
117
+ }
118
+
119
+ type GeminiImageInstance struct {
120
+ Prompt string `json:"prompt"`
121
+ }
122
+
123
+ type GeminiImageParameters struct {
124
+ SampleCount int `json:"sampleCount,omitempty"`
125
+ AspectRatio string `json:"aspectRatio,omitempty"`
126
+ PersonGeneration string `json:"personGeneration,omitempty"`
127
+ }
128
+
129
+ type GeminiImageResponse struct {
130
+ Predictions []GeminiImagePrediction `json:"predictions"`
131
+ }
132
+
133
+ type GeminiImagePrediction struct {
134
+ MimeType string `json:"mimeType"`
135
+ BytesBase64Encoded string `json:"bytesBase64Encoded"`
136
+ RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
137
+ SafetyAttributes any `json:"safetyAttributes,omitempty"`
138
+ }
relay/channel/mistral/adaptor.go CHANGED
@@ -41,9 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
41
  if request == nil {
42
  return nil, errors.New("request is nil")
43
  }
44
- mistralReq := requestOpenAI2Mistral(*request)
45
- //common.LogJson(c, "body", mistralReq)
46
- return mistralReq, nil
47
  }
48
 
49
  func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -55,7 +53,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
55
  return nil, errors.New("not implemented")
56
  }
57
 
58
-
59
  func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
60
  return channel.DoApiRequest(a, c, info, requestBody)
61
  }
 
41
  if request == nil {
42
  return nil, errors.New("request is nil")
43
  }
44
+ return requestOpenAI2Mistral(request), nil
 
 
45
  }
46
 
47
  func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 
53
  return nil, errors.New("not implemented")
54
  }
55
 
 
56
  func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
57
  return channel.DoApiRequest(a, c, info, requestBody)
58
  }
relay/channel/mistral/text.go CHANGED
@@ -1,25 +1,21 @@
1
  package mistral
2
 
3
  import (
4
- "encoding/json"
5
  "one-api/dto"
6
  )
7
 
8
- func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
9
  messages := make([]dto.Message, 0, len(request.Messages))
10
  for _, message := range request.Messages {
11
- if !message.IsStringContent() {
12
- mediaMessages := message.ParseContent()
13
- for j, mediaMessage := range mediaMessages {
14
- if mediaMessage.Type == dto.ContentTypeImageURL {
15
- imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
16
- mediaMessage.ImageUrl = imageUrl.Url
17
- mediaMessages[j] = mediaMessage
18
- }
19
  }
20
- messageRaw, _ := json.Marshal(mediaMessages)
21
- message.Content = messageRaw
22
  }
 
23
  messages = append(messages, dto.Message{
24
  Role: message.Role,
25
  Content: message.Content,
 
1
  package mistral
2
 
3
  import (
 
4
  "one-api/dto"
5
  )
6
 
7
+ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
8
  messages := make([]dto.Message, 0, len(request.Messages))
9
  for _, message := range request.Messages {
10
+ mediaMessages := message.ParseContent()
11
+ for j, mediaMessage := range mediaMessages {
12
+ if mediaMessage.Type == dto.ContentTypeImageURL {
13
+ imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
14
+ mediaMessage.ImageUrl = imageUrl.Url
15
+ mediaMessages[j] = mediaMessage
 
 
16
  }
 
 
17
  }
18
+ message.SetMediaContent(mediaMessages)
19
  messages = append(messages, dto.Message{
20
  Role: message.Role,
21
  Content: message.Content,
relay/channel/ollama/adaptor.go CHANGED
@@ -39,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
39
 
40
  func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
41
  channel.SetupApiRequestHeader(info, c, req)
 
42
  return nil
43
  }
44
 
@@ -46,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
46
  if request == nil {
47
  return nil, errors.New("request is nil")
48
  }
49
- return requestOpenAI2Ollama(*request), nil
50
  }
51
 
52
  func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 
39
 
40
  func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
41
  channel.SetupApiRequestHeader(info, c, req)
42
+ req.Set("Authorization", "Bearer "+info.ApiKey)
43
  return nil
44
  }
45
 
 
47
  if request == nil {
48
  return nil, errors.New("request is nil")
49
  }
50
+ return requestOpenAI2Ollama(*request)
51
  }
52
 
53
  func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
relay/channel/ollama/dto.go CHANGED
@@ -3,18 +3,21 @@ package ollama
3
  import "one-api/dto"
4
 
5
  type OllamaRequest struct {
6
- Model string `json:"model,omitempty"`
7
- Messages []dto.Message `json:"messages,omitempty"`
8
- Stream bool `json:"stream,omitempty"`
9
- Temperature *float64 `json:"temperature,omitempty"`
10
- Seed float64 `json:"seed,omitempty"`
11
- Topp float64 `json:"top_p,omitempty"`
12
- TopK int `json:"top_k,omitempty"`
13
- Stop any `json:"stop,omitempty"`
14
- Tools []dto.ToolCall `json:"tools,omitempty"`
15
- ResponseFormat any `json:"response_format,omitempty"`
16
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
17
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
 
 
 
18
  }
19
 
20
  type Options struct {
@@ -35,7 +38,7 @@ type OllamaEmbeddingRequest struct {
35
  }
36
 
37
  type OllamaEmbeddingResponse struct {
38
- Error string `json:"error,omitempty"`
39
- Model string `json:"model"`
40
  Embedding [][]float64 `json:"embeddings,omitempty"`
41
  }
 
3
  import "one-api/dto"
4
 
5
  type OllamaRequest struct {
6
+ Model string `json:"model,omitempty"`
7
+ Messages []dto.Message `json:"messages,omitempty"`
8
+ Stream bool `json:"stream,omitempty"`
9
+ Temperature *float64 `json:"temperature,omitempty"`
10
+ Seed float64 `json:"seed,omitempty"`
11
+ Topp float64 `json:"top_p,omitempty"`
12
+ TopK int `json:"top_k,omitempty"`
13
+ Stop any `json:"stop,omitempty"`
14
+ Tools []dto.ToolCall `json:"tools,omitempty"`
15
+ ResponseFormat any `json:"response_format,omitempty"`
16
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
17
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
18
+ Suffix any `json:"suffix,omitempty"`
19
+ StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
20
+ Prompt any `json:"prompt,omitempty"`
21
  }
22
 
23
  type Options struct {
 
38
  }
39
 
40
  type OllamaEmbeddingResponse struct {
41
+ Error string `json:"error,omitempty"`
42
+ Model string `json:"model"`
43
  Embedding [][]float64 `json:"embeddings,omitempty"`
44
  }
relay/channel/ollama/relay-ollama.go CHANGED
@@ -9,14 +9,36 @@ import (
9
  "net/http"
10
  "one-api/dto"
11
  "one-api/service"
 
12
  )
13
 
14
- func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
15
  messages := make([]dto.Message, 0, len(request.Messages))
16
  for _, message := range request.Messages {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  messages = append(messages, dto.Message{
18
- Role: message.Role,
19
- Content: message.Content,
 
 
20
  })
21
  }
22
  str, ok := request.Stop.(string)
@@ -39,7 +61,10 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
39
  ResponseFormat: request.ResponseFormat,
40
  FrequencyPenalty: request.FrequencyPenalty,
41
  PresencePenalty: request.PresencePenalty,
42
- }
 
 
 
43
  }
44
 
45
  func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
 
9
  "net/http"
10
  "one-api/dto"
11
  "one-api/service"
12
+ "strings"
13
  )
14
 
15
+ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
16
  messages := make([]dto.Message, 0, len(request.Messages))
17
  for _, message := range request.Messages {
18
+ if !message.IsStringContent() {
19
+ mediaMessages := message.ParseContent()
20
+ for j, mediaMessage := range mediaMessages {
21
+ if mediaMessage.Type == dto.ContentTypeImageURL {
22
+ imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
23
+ // check if not base64
24
+ if strings.HasPrefix(imageUrl.Url, "http") {
25
+ fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
26
+ if err != nil {
27
+ return nil, err
28
+ }
29
+ imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
30
+ }
31
+ mediaMessage.ImageUrl = imageUrl
32
+ mediaMessages[j] = mediaMessage
33
+ }
34
+ }
35
+ message.SetMediaContent(mediaMessages)
36
+ }
37
  messages = append(messages, dto.Message{
38
+ Role: message.Role,
39
+ Content: message.Content,
40
+ ToolCalls: message.ToolCalls,
41
+ ToolCallId: message.ToolCallId,
42
  })
43
  }
44
  str, ok := request.Stop.(string)
 
61
  ResponseFormat: request.ResponseFormat,
62
  FrequencyPenalty: request.FrequencyPenalty,
63
  PresencePenalty: request.PresencePenalty,
64
+ Prompt: request.Prompt,
65
+ StreamOptions: request.StreamOptions,
66
+ Suffix: request.Suffix,
67
+ }, nil
68
  }
69
 
70
  func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
relay/channel/openai/adaptor.go CHANGED
@@ -119,7 +119,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
119
  request.MaxCompletionTokens = request.MaxTokens
120
  request.MaxTokens = 0
121
  }
122
- if strings.HasPrefix(request.Model, "o3") {
123
  request.Temperature = nil
124
  }
125
  if strings.HasSuffix(request.Model, "-high") {
 
119
  request.MaxCompletionTokens = request.MaxTokens
120
  request.MaxTokens = 0
121
  }
122
+ if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
123
  request.Temperature = nil
124
  }
125
  if strings.HasSuffix(request.Model, "-high") {
relay/channel/openai/relay-openai.go CHANGED
@@ -87,6 +87,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
87
  info.SetFirstResponseTime()
88
  ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
89
  data := scanner.Text()
 
 
 
90
  if len(data) < 6 { // ignore blank line or wrong format
91
  continue
92
  }
@@ -162,6 +165,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
162
  //}
163
  for _, choice := range streamResponse.Choices {
164
  responseTextBuilder.WriteString(choice.Delta.GetContentString())
 
165
  if choice.Delta.ToolCalls != nil {
166
  if len(choice.Delta.ToolCalls) > toolCount {
167
  toolCount = len(choice.Delta.ToolCalls)
@@ -182,6 +186,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
182
  //}
183
  for _, choice := range streamResponse.Choices {
184
  responseTextBuilder.WriteString(choice.Delta.GetContentString())
 
185
  if choice.Delta.ToolCalls != nil {
186
  if len(choice.Delta.ToolCalls) > toolCount {
187
  toolCount = len(choice.Delta.ToolCalls)
@@ -273,7 +278,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
273
  if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
274
  completionTokens := 0
275
  for _, choice := range simpleResponse.Choices {
276
- ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
277
  completionTokens += ctkm
278
  }
279
  simpleResponse.Usage = dto.Usage{
 
87
  info.SetFirstResponseTime()
88
  ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
89
  data := scanner.Text()
90
+ if common.DebugEnabled {
91
+ println(data)
92
+ }
93
  if len(data) < 6 { // ignore blank line or wrong format
94
  continue
95
  }
 
165
  //}
166
  for _, choice := range streamResponse.Choices {
167
  responseTextBuilder.WriteString(choice.Delta.GetContentString())
168
+ responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
169
  if choice.Delta.ToolCalls != nil {
170
  if len(choice.Delta.ToolCalls) > toolCount {
171
  toolCount = len(choice.Delta.ToolCalls)
 
186
  //}
187
  for _, choice := range streamResponse.Choices {
188
  responseTextBuilder.WriteString(choice.Delta.GetContentString())
189
+ responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
190
  if choice.Delta.ToolCalls != nil {
191
  if len(choice.Delta.ToolCalls) > toolCount {
192
  toolCount = len(choice.Delta.ToolCalls)
 
278
  if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
279
  completionTokens := 0
280
  for _, choice := range simpleResponse.Choices {
281
+ ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent, model)
282
  completionTokens += ctkm
283
  }
284
  simpleResponse.Usage = dto.Usage{
relay/channel/siliconflow/adaptor.go CHANGED
@@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
36
  return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
37
  } else if info.RelayMode == constant.RelayModeChatCompletions {
38
  return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
 
 
39
  }
40
  return "", errors.New("invalid relay mode")
41
  }
@@ -72,6 +74,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
72
  } else {
73
  err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
74
  }
 
 
 
 
 
 
75
  case constant.RelayModeEmbeddings:
76
  err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
77
  }
 
36
  return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
37
  } else if info.RelayMode == constant.RelayModeChatCompletions {
38
  return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
39
+ } else if info.RelayMode == constant.RelayModeCompletions {
40
+ return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
41
  }
42
  return "", errors.New("invalid relay mode")
43
  }
 
74
  } else {
75
  err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
76
  }
77
+ case constant.RelayModeCompletions:
78
+ if info.IsStream {
79
+ err, usage = openai.OaiStreamHandler(c, resp, info)
80
+ } else {
81
+ err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
82
+ }
83
  case constant.RelayModeEmbeddings:
84
  err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
85
  }
relay/channel/zhipu_4v/relay-zhipu_v4.go CHANGED
@@ -90,8 +90,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
90
  mediaMessages[j] = mediaMessage
91
  }
92
  }
93
- messageRaw, _ := json.Marshal(mediaMessages)
94
- message.Content = messageRaw
95
  }
96
  messages = append(messages, dto.Message{
97
  Role: message.Role,
 
90
  mediaMessages[j] = mediaMessage
91
  }
92
  }
93
+ message.SetMediaContent(mediaMessages)
 
94
  }
95
  messages = append(messages, dto.Message{
96
  Role: message.Role,
relay/common/relay_info.go CHANGED
@@ -13,24 +13,24 @@ import (
13
  )
14
 
15
  type RelayInfo struct {
16
- ChannelType int
17
- ChannelId int
18
- TokenId int
19
- TokenKey string
20
- UserId int
21
- Group string
22
- TokenUnlimited bool
23
- StartTime time.Time
24
- FirstResponseTime time.Time
25
- setFirstResponse bool
26
- ApiType int
27
- IsStream bool
28
- IsPlayground bool
29
- UsePrice bool
30
- RelayMode int
31
- UpstreamModelName string
32
- OriginModelName string
33
- RecodeModelName string
34
  RequestURLPath string
35
  ApiVersion string
36
  PromptTokens int
@@ -39,6 +39,7 @@ type RelayInfo struct {
39
  BaseUrl string
40
  SupportStreamOptions bool
41
  ShouldIncludeUsage bool
 
42
  ClientWs *websocket.Conn
43
  TargetWs *websocket.Conn
44
  InputAudioFormat string
@@ -50,6 +51,18 @@ type RelayInfo struct {
50
  ChannelSetting map[string]interface{}
51
  }
52
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
54
  info := GenRelayInfo(c)
55
  info.ClientWs = ws
@@ -89,12 +102,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
89
  FirstResponseTime: startTime.Add(-time.Second),
90
  OriginModelName: c.GetString("original_model"),
91
  UpstreamModelName: c.GetString("original_model"),
92
- RecodeModelName: c.GetString("recode_model"),
93
- ApiType: apiType,
94
- ApiVersion: c.GetString("api_version"),
95
- ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
96
- Organization: c.GetString("channel_organization"),
97
- ChannelSetting: channelSetting,
 
98
  }
99
  if strings.HasPrefix(c.Request.URL.Path, "/pg") {
100
  info.IsPlayground = true
@@ -110,9 +124,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
110
  if info.ChannelType == common.ChannelTypeVertexAi {
111
  info.ApiVersion = c.GetString("region")
112
  }
113
- if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
114
- info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
115
- info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure {
116
  info.SupportStreamOptions = true
117
  }
118
  return info
 
13
  )
14
 
15
  type RelayInfo struct {
16
+ ChannelType int
17
+ ChannelId int
18
+ TokenId int
19
+ TokenKey string
20
+ UserId int
21
+ Group string
22
+ TokenUnlimited bool
23
+ StartTime time.Time
24
+ FirstResponseTime time.Time
25
+ setFirstResponse bool
26
+ ApiType int
27
+ IsStream bool
28
+ IsPlayground bool
29
+ UsePrice bool
30
+ RelayMode int
31
+ UpstreamModelName string
32
+ OriginModelName string
33
+ //RecodeModelName string
34
  RequestURLPath string
35
  ApiVersion string
36
  PromptTokens int
 
39
  BaseUrl string
40
  SupportStreamOptions bool
41
  ShouldIncludeUsage bool
42
+ IsModelMapped bool
43
  ClientWs *websocket.Conn
44
  TargetWs *websocket.Conn
45
  InputAudioFormat string
 
51
  ChannelSetting map[string]interface{}
52
  }
53
 
54
+ // 定义支持流式选项的通道类型
55
+ var streamSupportedChannels = map[int]bool{
56
+ common.ChannelTypeOpenAI: true,
57
+ common.ChannelTypeAnthropic: true,
58
+ common.ChannelTypeAws: true,
59
+ common.ChannelTypeGemini: true,
60
+ common.ChannelCloudflare: true,
61
+ common.ChannelTypeAzure: true,
62
+ common.ChannelTypeVolcEngine: true,
63
+ common.ChannelTypeOllama: true,
64
+ }
65
+
66
  func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
67
  info := GenRelayInfo(c)
68
  info.ClientWs = ws
 
102
  FirstResponseTime: startTime.Add(-time.Second),
103
  OriginModelName: c.GetString("original_model"),
104
  UpstreamModelName: c.GetString("original_model"),
105
+ //RecodeModelName: c.GetString("original_model"),
106
+ IsModelMapped: false,
107
+ ApiType: apiType,
108
+ ApiVersion: c.GetString("api_version"),
109
+ ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
110
+ Organization: c.GetString("channel_organization"),
111
+ ChannelSetting: channelSetting,
112
  }
113
  if strings.HasPrefix(c.Request.URL.Path, "/pg") {
114
  info.IsPlayground = true
 
124
  if info.ChannelType == common.ChannelTypeVertexAi {
125
  info.ApiVersion = c.GetString("region")
126
  }
127
+ if streamSupportedChannels[info.ChannelType] {
 
 
128
  info.SupportStreamOptions = true
129
  }
130
  return info
relay/helper/model_mapped.go ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package helper
2
+
3
+ import (
4
+ "encoding/json"
5
+ "fmt"
6
+ "github.com/gin-gonic/gin"
7
+ "one-api/relay/common"
8
+ )
9
+
10
+ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
11
+ // map model name
12
+ modelMapping := c.GetString("model_mapping")
13
+ if modelMapping != "" && modelMapping != "{}" {
14
+ modelMap := make(map[string]string)
15
+ err := json.Unmarshal([]byte(modelMapping), &modelMap)
16
+ if err != nil {
17
+ return fmt.Errorf("unmarshal_model_mapping_failed")
18
+ }
19
+ if modelMap[info.OriginModelName] != "" {
20
+ info.UpstreamModelName = modelMap[info.OriginModelName]
21
+ info.IsModelMapped = true
22
+ }
23
+ }
24
+ return nil
25
+ }
relay/helper/price.go ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package helper
2
+
3
+ import (
4
+ "github.com/gin-gonic/gin"
5
+ "one-api/common"
6
+ relaycommon "one-api/relay/common"
7
+ "one-api/setting"
8
+ )
9
+
10
+ type PriceData struct {
11
+ ModelPrice float64
12
+ ModelRatio float64
13
+ GroupRatio float64
14
+ UsePrice bool
15
+ ShouldPreConsumedQuota int
16
+ }
17
+
18
+ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
19
+ modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false)
20
+ groupRatio := setting.GetGroupRatio(info.Group)
21
+ var preConsumedQuota int
22
+ var modelRatio float64
23
+ if !usePrice {
24
+ preConsumedTokens := common.PreConsumedQuota
25
+ if maxTokens != 0 {
26
+ preConsumedTokens = promptTokens + maxTokens
27
+ }
28
+ modelRatio = common.GetModelRatio(info.OriginModelName)
29
+ ratio := modelRatio * groupRatio
30
+ preConsumedQuota = int(float64(preConsumedTokens) * ratio)
31
+ } else {
32
+ preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
33
+ }
34
+ return PriceData{
35
+ ModelPrice: modelPrice,
36
+ ModelRatio: modelRatio,
37
+ GroupRatio: groupRatio,
38
+ UsePrice: usePrice,
39
+ ShouldPreConsumedQuota: preConsumedQuota,
40
+ }
41
+ }
relay/relay-audio.go CHANGED
@@ -1,7 +1,6 @@
1
  package relay
2
 
3
  import (
4
- "encoding/json"
5
  "errors"
6
  "fmt"
7
  "github.com/gin-gonic/gin"
@@ -11,8 +10,10 @@ import (
11
  "one-api/model"
12
  relaycommon "one-api/relay/common"
13
  relayconstant "one-api/relay/constant"
 
14
  "one-api/service"
15
  "one-api/setting"
 
16
  )
17
 
18
  func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
@@ -27,8 +28,9 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
27
  return nil, errors.New("model is required")
28
  }
29
  if setting.ShouldCheckPromptSensitive() {
30
- err := service.CheckSensitiveInput(audioRequest.Input)
31
  if err != nil {
 
32
  return nil, err
33
  }
34
  }
@@ -73,15 +75,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
73
  relayInfo.PromptTokens = promptTokens
74
  }
75
 
76
- modelRatio := common.GetModelRatio(audioRequest.Model)
77
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
78
- ratio := modelRatio * groupRatio
79
- preConsumedQuota := int(float64(preConsumedTokens) * ratio)
80
  userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
81
  if err != nil {
82
  return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
83
  }
84
- preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
85
  if openaiErr != nil {
86
  return openaiErr
87
  }
@@ -91,19 +91,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
91
  }
92
  }()
93
 
94
- // map model name
95
- modelMapping := c.GetString("model_mapping")
96
- if modelMapping != "" {
97
- modelMap := make(map[string]string)
98
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
99
- if err != nil {
100
- return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
101
- }
102
- if modelMap[audioRequest.Model] != "" {
103
- audioRequest.Model = modelMap[audioRequest.Model]
104
- }
105
  }
106
- relayInfo.UpstreamModelName = audioRequest.Model
 
107
 
108
  adaptor := GetAdaptor(relayInfo.ApiType)
109
  if adaptor == nil {
@@ -140,7 +133,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
140
  return openaiErr
141
  }
142
 
143
- postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
144
 
145
  return nil
146
  }
 
1
  package relay
2
 
3
  import (
 
4
  "errors"
5
  "fmt"
6
  "github.com/gin-gonic/gin"
 
10
  "one-api/model"
11
  relaycommon "one-api/relay/common"
12
  relayconstant "one-api/relay/constant"
13
+ "one-api/relay/helper"
14
  "one-api/service"
15
  "one-api/setting"
16
+ "strings"
17
  )
18
 
19
  func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
 
28
  return nil, errors.New("model is required")
29
  }
30
  if setting.ShouldCheckPromptSensitive() {
31
+ words, err := service.CheckSensitiveInput(audioRequest.Input)
32
  if err != nil {
33
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
34
  return nil, err
35
  }
36
  }
 
75
  relayInfo.PromptTokens = promptTokens
76
  }
77
 
78
+ priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
79
+
 
 
80
  userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
81
  if err != nil {
82
  return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
83
  }
84
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
85
  if openaiErr != nil {
86
  return openaiErr
87
  }
 
91
  }
92
  }()
93
 
94
+ err = helper.ModelMappedHelper(c, relayInfo)
95
+ if err != nil {
96
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 
 
 
 
 
 
 
 
97
  }
98
+
99
+ audioRequest.Model = relayInfo.UpstreamModelName
100
 
101
  adaptor := GetAdaptor(relayInfo.ApiType)
102
  if adaptor == nil {
 
133
  return openaiErr
134
  }
135
 
136
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
137
 
138
  return nil
139
  }
relay/relay-image.go CHANGED
@@ -12,6 +12,7 @@ import (
12
  "one-api/dto"
13
  "one-api/model"
14
  relaycommon "one-api/relay/common"
 
15
  "one-api/service"
16
  "one-api/setting"
17
  "strings"
@@ -60,15 +61,16 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
60
  // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
61
  //}
62
  if setting.ShouldCheckPromptSensitive() {
63
- err := service.CheckSensitiveInput(imageRequest.Prompt)
64
  if err != nil {
 
65
  return nil, err
66
  }
67
  }
68
  return imageRequest, nil
69
  }
70
 
71
- func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
72
  relayInfo := relaycommon.GenRelayInfo(c)
73
 
74
  imageRequest, err := getAndValidImageRequest(c, relayInfo)
@@ -77,29 +79,20 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
77
  return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
78
  }
79
 
80
- // map model name
81
- modelMapping := c.GetString("model_mapping")
82
- if modelMapping != "" {
83
- modelMap := make(map[string]string)
84
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
85
- if err != nil {
86
- return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
87
- }
88
- if modelMap[imageRequest.Model] != "" {
89
- imageRequest.Model = modelMap[imageRequest.Model]
90
- }
91
  }
92
- relayInfo.UpstreamModelName = imageRequest.Model
93
 
94
- modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
95
- if !success {
96
- modelRatio := common.GetModelRatio(imageRequest.Model)
 
97
  // modelRatio 16 = modelPrice $0.04
98
  // per 1 modelRatio = $0.04 / 16
99
- modelPrice = 0.0025 * modelRatio
100
  }
101
 
102
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
103
  userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
104
 
105
  sizeRatio := 1.0
@@ -122,11 +115,11 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
122
  }
123
  }
124
 
125
- imageRatio := modelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
126
- quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
127
 
128
  if userQuota-quota < 0 {
129
- return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("image pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, quota)), "insufficient_user_quota", http.StatusBadRequest)
130
  }
131
 
132
  adaptor := GetAdaptor(relayInfo.ApiType)
@@ -184,7 +177,6 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
184
  }
185
 
186
  logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
187
- postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent)
188
-
189
  return nil
190
  }
 
12
  "one-api/dto"
13
  "one-api/model"
14
  relaycommon "one-api/relay/common"
15
+ "one-api/relay/helper"
16
  "one-api/service"
17
  "one-api/setting"
18
  "strings"
 
61
  // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
62
  //}
63
  if setting.ShouldCheckPromptSensitive() {
64
+ words, err := service.CheckSensitiveInput(imageRequest.Prompt)
65
  if err != nil {
66
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
67
  return nil, err
68
  }
69
  }
70
  return imageRequest, nil
71
  }
72
 
73
+ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
74
  relayInfo := relaycommon.GenRelayInfo(c)
75
 
76
  imageRequest, err := getAndValidImageRequest(c, relayInfo)
 
79
  return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
80
  }
81
 
82
+ err = helper.ModelMappedHelper(c, relayInfo)
83
+ if err != nil {
84
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 
 
 
 
 
 
 
 
85
  }
 
86
 
87
+ imageRequest.Model = relayInfo.UpstreamModelName
88
+
89
+ priceData := helper.ModelPriceHelper(c, relayInfo, 0, 0)
90
+ if !priceData.UsePrice {
91
  // modelRatio 16 = modelPrice $0.04
92
  // per 1 modelRatio = $0.04 / 16
93
+ priceData.ModelPrice = 0.0025 * priceData.ModelRatio
94
  }
95
 
 
96
  userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
97
 
98
  sizeRatio := 1.0
 
115
  }
116
  }
117
 
118
+ priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
119
+ quota := int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
120
 
121
  if userQuota-quota < 0 {
122
+ return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
123
  }
124
 
125
  adaptor := GetAdaptor(relayInfo.ApiType)
 
177
  }
178
 
179
  logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
180
+ postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
 
181
  return nil
182
  }
relay/relay-mj.go CHANGED
@@ -194,7 +194,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
194
  }
195
  defer func(ctx context.Context) {
196
  if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
197
- err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
198
  if err != nil {
199
  common.SysError("error consuming token remain quota: " + err.Error())
200
  }
@@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
500
 
501
  defer func(ctx context.Context) {
502
  if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
503
- err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
504
  if err != nil {
505
  common.SysError("error consuming token remain quota: " + err.Error())
506
  }
 
194
  }
195
  defer func(ctx context.Context) {
196
  if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
197
+ err := service.PostConsumeQuota(relayInfo, quota, 0, true)
198
  if err != nil {
199
  common.SysError("error consuming token remain quota: " + err.Error())
200
  }
 
500
 
501
  defer func(ctx context.Context) {
502
  if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
503
+ err := service.PostConsumeQuota(relayInfo, quota, 0, true)
504
  if err != nil {
505
  common.SysError("error consuming token remain quota: " + err.Error())
506
  }
relay/relay-text.go CHANGED
@@ -5,6 +5,7 @@ import (
5
  "encoding/json"
6
  "errors"
7
  "fmt"
 
8
  "io"
9
  "math"
10
  "net/http"
@@ -14,6 +15,7 @@ import (
14
  "one-api/model"
15
  relaycommon "one-api/relay/common"
16
  relayconstant "one-api/relay/constant"
 
17
  "one-api/service"
18
  "one-api/setting"
19
  "strings"
@@ -75,40 +77,21 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
75
  return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
76
  }
77
 
78
- // map model name
79
- //isModelMapped := false
80
- modelMapping := c.GetString("model_mapping")
81
- //isModelMapped := false
82
- if modelMapping != "" && modelMapping != "{}" {
83
- modelMap := make(map[string]string)
84
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
85
- if err != nil {
86
- return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
87
- }
88
- if modelMap[textRequest.Model] != "" {
89
- //isModelMapped = true
90
- textRequest.Model = modelMap[textRequest.Model]
91
- // set upstream model name
92
- //isModelMapped = true
93
- }
94
- }
95
- relayInfo.UpstreamModelName = textRequest.Model
96
- relayInfo.RecodeModelName = textRequest.Model
97
- modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
98
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
99
-
100
- var preConsumedQuota int
101
- var ratio float64
102
- var modelRatio float64
103
- //err := service.SensitiveWordsCheck(textRequest)
104
-
105
  if setting.ShouldCheckPromptSensitive() {
106
- err = checkRequestSensitive(textRequest, relayInfo)
107
  if err != nil {
 
108
  return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
109
  }
110
  }
111
 
 
 
 
 
 
 
 
112
  // 获取 promptTokens,如果上下文中已经存在,则直接使用
113
  var promptTokens int
114
  if value, exists := c.Get("prompt_tokens"); exists {
@@ -123,20 +106,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
123
  c.Set("prompt_tokens", promptTokens)
124
  }
125
 
126
- if !getModelPriceSuccess {
127
- preConsumedTokens := common.PreConsumedQuota
128
- if textRequest.MaxTokens != 0 {
129
- preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
130
- }
131
- modelRatio = common.GetModelRatio(textRequest.Model)
132
- ratio = modelRatio * groupRatio
133
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
134
- } else {
135
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
136
- }
137
 
138
  // pre-consume quota 预消耗配额
139
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
140
  if openaiErr != nil {
141
  return openaiErr
142
  }
@@ -219,10 +192,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
219
  return openaiErr
220
  }
221
 
222
- if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") {
223
- service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
224
  } else {
225
- postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
226
  }
227
  return nil
228
  }
@@ -247,19 +220,20 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
247
  return promptTokens, err
248
  }
249
 
250
- func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
251
  var err error
 
252
  switch info.RelayMode {
253
  case relayconstant.RelayModeChatCompletions:
254
- err = service.CheckSensitiveMessages(textRequest.Messages)
255
  case relayconstant.RelayModeCompletions:
256
- err = service.CheckSensitiveInput(textRequest.Prompt)
257
  case relayconstant.RelayModeModerations:
258
- err = service.CheckSensitiveInput(textRequest.Input)
259
  case relayconstant.RelayModeEmbeddings:
260
- err = service.CheckSensitiveInput(textRequest.Input)
261
  }
262
- return err
263
  }
264
 
265
  // 预扣费并返回用户剩余配额
@@ -272,7 +246,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
272
  return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
273
  }
274
  if userQuota-preConsumedQuota < 0 {
275
- return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
276
  }
277
  if userQuota > 100*preConsumedQuota {
278
  // 用户额度充足,判断令牌额度是否充足
@@ -282,18 +256,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
282
  if tokenQuota > 100*preConsumedQuota {
283
  // 令牌额度充足,信任令牌
284
  preConsumedQuota = 0
285
- common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
286
  }
287
  } else {
288
  // in this case, we do not pre-consume quota
289
  // because the user has enough quota
290
  preConsumedQuota = 0
291
- common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
292
  }
293
  }
294
 
295
  if preConsumedQuota > 0 {
296
- err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
297
  if err != nil {
298
  return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
299
  }
@@ -307,20 +281,19 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
307
 
308
  func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
309
  if preConsumedQuota != 0 {
310
- go func() {
311
  relayInfoCopy := *relayInfo
312
 
313
- err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
314
  if err != nil {
315
  common.SysError("error return pre-consumed quota: " + err.Error())
316
  }
317
- }()
318
  }
319
  }
320
 
321
- func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
322
- usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
323
- modelPrice float64, usePrice bool, extraContent string) {
324
  if usage == nil {
325
  usage = &dto.Usage{
326
  PromptTokens: relayInfo.PromptTokens,
@@ -332,12 +305,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
332
  useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
333
  promptTokens := usage.PromptTokens
334
  completionTokens := usage.CompletionTokens
 
335
 
336
  tokenName := ctx.GetString("token_name")
337
  completionRatio := common.GetCompletionRatio(modelName)
 
 
 
 
 
338
 
339
  quota := 0
340
- if !usePrice {
341
  quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
342
  quota = int(math.Round(float64(quota) * ratio))
343
  if ratio != 0 && quota <= 0 {
@@ -368,7 +347,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
368
  //}
369
  quotaDelta := quota - preConsumedQuota
370
  if quotaDelta != 0 {
371
- err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
372
  if err != nil {
373
  common.LogError(ctx, "error consuming token remain quota: "+err.Error())
374
  }
 
5
  "encoding/json"
6
  "errors"
7
  "fmt"
8
+ "github.com/bytedance/gopkg/util/gopool"
9
  "io"
10
  "math"
11
  "net/http"
 
15
  "one-api/model"
16
  relaycommon "one-api/relay/common"
17
  relayconstant "one-api/relay/constant"
18
+ "one-api/relay/helper"
19
  "one-api/service"
20
  "one-api/setting"
21
  "strings"
 
77
  return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
78
  }
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  if setting.ShouldCheckPromptSensitive() {
81
+ words, err := checkRequestSensitive(textRequest, relayInfo)
82
  if err != nil {
83
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
84
  return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
85
  }
86
  }
87
 
88
+ err = helper.ModelMappedHelper(c, relayInfo)
89
+ if err != nil {
90
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
91
+ }
92
+
93
+ textRequest.Model = relayInfo.UpstreamModelName
94
+
95
  // 获取 promptTokens,如果上下文中已经存在,则直接使用
96
  var promptTokens int
97
  if value, exists := c.Get("prompt_tokens"); exists {
 
106
  c.Set("prompt_tokens", promptTokens)
107
  }
108
 
109
+ priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
 
 
 
 
 
 
 
 
 
 
110
 
111
  // pre-consume quota 预消耗配额
112
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
113
  if openaiErr != nil {
114
  return openaiErr
115
  }
 
192
  return openaiErr
193
  }
194
 
195
+ if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
196
+ service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
197
  } else {
198
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
199
  }
200
  return nil
201
  }
 
220
  return promptTokens, err
221
  }
222
 
223
+ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
224
  var err error
225
+ var words []string
226
  switch info.RelayMode {
227
  case relayconstant.RelayModeChatCompletions:
228
+ words, err = service.CheckSensitiveMessages(textRequest.Messages)
229
  case relayconstant.RelayModeCompletions:
230
+ words, err = service.CheckSensitiveInput(textRequest.Prompt)
231
  case relayconstant.RelayModeModerations:
232
+ words, err = service.CheckSensitiveInput(textRequest.Input)
233
  case relayconstant.RelayModeEmbeddings:
234
+ words, err = service.CheckSensitiveInput(textRequest.Input)
235
  }
236
+ return words, err
237
  }
238
 
239
  // 预扣费并返回用户剩余配额
 
246
  return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
247
  }
248
  if userQuota-preConsumedQuota < 0 {
249
+ return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
250
  }
251
  if userQuota > 100*preConsumedQuota {
252
  // 用户额度充足,判断令牌额度是否充足
 
256
  if tokenQuota > 100*preConsumedQuota {
257
  // 令牌额度充足,信任令牌
258
  preConsumedQuota = 0
259
+ common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
260
  }
261
  } else {
262
  // in this case, we do not pre-consume quota
263
  // because the user has enough quota
264
  preConsumedQuota = 0
265
+ common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
266
  }
267
  }
268
 
269
  if preConsumedQuota > 0 {
270
+ err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
271
  if err != nil {
272
  return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
273
  }
 
281
 
282
  func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
283
  if preConsumedQuota != 0 {
284
+ gopool.Go(func() {
285
  relayInfoCopy := *relayInfo
286
 
287
+ err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
288
  if err != nil {
289
  common.SysError("error return pre-consumed quota: " + err.Error())
290
  }
291
+ })
292
  }
293
  }
294
 
295
+ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
296
+ usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
 
297
  if usage == nil {
298
  usage = &dto.Usage{
299
  PromptTokens: relayInfo.PromptTokens,
 
305
  useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
306
  promptTokens := usage.PromptTokens
307
  completionTokens := usage.CompletionTokens
308
+ modelName := relayInfo.OriginModelName
309
 
310
  tokenName := ctx.GetString("token_name")
311
  completionRatio := common.GetCompletionRatio(modelName)
312
+ ratio := priceData.ModelRatio * priceData.GroupRatio
313
+ modelRatio := priceData.ModelRatio
314
+ groupRatio := priceData.GroupRatio
315
+ modelPrice := priceData.ModelPrice
316
+ usePrice := priceData.UsePrice
317
 
318
  quota := 0
319
+ if !priceData.UsePrice {
320
  quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
321
  quota = int(math.Round(float64(quota) * ratio))
322
  if ratio != 0 && quota <= 0 {
 
347
  //}
348
  quotaDelta := quota - preConsumedQuota
349
  if quotaDelta != 0 {
350
+ err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
351
  if err != nil {
352
  common.LogError(ctx, "error consuming token remain quota: "+err.Error())
353
  }
relay/relay_embedding.go CHANGED
@@ -10,8 +10,8 @@ import (
10
  "one-api/dto"
11
  relaycommon "one-api/relay/common"
12
  relayconstant "one-api/relay/constant"
 
13
  "one-api/service"
14
- "one-api/setting"
15
  )
16
 
17
  func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
@@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
47
  return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
48
  }
49
 
50
- // map model name
51
- modelMapping := c.GetString("model_mapping")
52
- //isModelMapped := false
53
- if modelMapping != "" && modelMapping != "{}" {
54
- modelMap := make(map[string]string)
55
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
56
- if err != nil {
57
- return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
58
- }
59
- if modelMap[embeddingRequest.Model] != "" {
60
- embeddingRequest.Model = modelMap[embeddingRequest.Model]
61
- // set upstream model name
62
- //isModelMapped = true
63
- }
64
  }
65
 
66
- relayInfo.UpstreamModelName = embeddingRequest.Model
67
- modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
68
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
69
-
70
- var preConsumedQuota int
71
- var ratio float64
72
- var modelRatio float64
73
 
74
  promptToken := getEmbeddingPromptToken(*embeddingRequest)
75
- if !success {
76
- preConsumedTokens := promptToken
77
- modelRatio = common.GetModelRatio(embeddingRequest.Model)
78
- ratio = modelRatio * groupRatio
79
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
80
- } else {
81
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
82
- }
83
  relayInfo.PromptTokens = promptToken
84
 
 
 
85
  // pre-consume quota 预消耗配额
86
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
87
  if openaiErr != nil {
88
  return openaiErr
89
  }
@@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
132
  service.ResetStatusCode(openaiErr, statusCodeMappingStr)
133
  return openaiErr
134
  }
135
- postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
136
  return nil
137
  }
 
10
  "one-api/dto"
11
  relaycommon "one-api/relay/common"
12
  relayconstant "one-api/relay/constant"
13
+ "one-api/relay/helper"
14
  "one-api/service"
 
15
  )
16
 
17
  func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
 
47
  return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
48
  }
49
 
50
+ err = helper.ModelMappedHelper(c, relayInfo)
51
+ if err != nil {
52
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 
 
 
 
 
 
 
 
 
 
 
53
  }
54
 
55
+ embeddingRequest.Model = relayInfo.UpstreamModelName
 
 
 
 
 
 
56
 
57
  promptToken := getEmbeddingPromptToken(*embeddingRequest)
 
 
 
 
 
 
 
 
58
  relayInfo.PromptTokens = promptToken
59
 
60
+ priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
61
+
62
  // pre-consume quota 预消耗配额
63
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
64
  if openaiErr != nil {
65
  return openaiErr
66
  }
 
109
  service.ResetStatusCode(openaiErr, statusCodeMappingStr)
110
  return openaiErr
111
  }
112
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
113
  return nil
114
  }
relay/relay_rerank.go CHANGED
@@ -9,8 +9,8 @@ import (
9
  "one-api/common"
10
  "one-api/dto"
11
  relaycommon "one-api/relay/common"
 
12
  "one-api/service"
13
- "one-api/setting"
14
  )
15
 
16
  func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
40
  return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
41
  }
42
 
43
- // map model name
44
- modelMapping := c.GetString("model_mapping")
45
- //isModelMapped := false
46
- if modelMapping != "" && modelMapping != "{}" {
47
- modelMap := make(map[string]string)
48
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
49
- if err != nil {
50
- return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
51
- }
52
- if modelMap[rerankRequest.Model] != "" {
53
- rerankRequest.Model = modelMap[rerankRequest.Model]
54
- // set upstream model name
55
- //isModelMapped = true
56
- }
57
  }
58
 
59
- relayInfo.UpstreamModelName = rerankRequest.Model
60
- modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
61
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
62
-
63
- var preConsumedQuota int
64
- var ratio float64
65
- var modelRatio float64
66
 
67
  promptToken := getRerankPromptToken(*rerankRequest)
68
- if !success {
69
- preConsumedTokens := promptToken
70
- modelRatio = common.GetModelRatio(rerankRequest.Model)
71
- ratio = modelRatio * groupRatio
72
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
73
- } else {
74
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
75
- }
76
  relayInfo.PromptTokens = promptToken
77
 
 
 
78
  // pre-consume quota 预消耗配额
79
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
80
  if openaiErr != nil {
81
  return openaiErr
82
  }
@@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
124
  service.ResetStatusCode(openaiErr, statusCodeMappingStr)
125
  return openaiErr
126
  }
127
- postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
128
  return nil
129
  }
 
9
  "one-api/common"
10
  "one-api/dto"
11
  relaycommon "one-api/relay/common"
12
+ "one-api/relay/helper"
13
  "one-api/service"
 
14
  )
15
 
16
  func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
 
40
  return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
41
  }
42
 
43
+ err = helper.ModelMappedHelper(c, relayInfo)
44
+ if err != nil {
45
+ return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 
 
 
 
 
 
 
 
 
 
 
46
  }
47
 
48
+ rerankRequest.Model = relayInfo.UpstreamModelName
 
 
 
 
 
 
49
 
50
  promptToken := getRerankPromptToken(*rerankRequest)
 
 
 
 
 
 
 
 
51
  relayInfo.PromptTokens = promptToken
52
 
53
+ priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
54
+
55
  // pre-consume quota 预消耗配额
56
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
57
  if openaiErr != nil {
58
  return openaiErr
59
  }
 
101
  service.ResetStatusCode(openaiErr, statusCodeMappingStr)
102
  return openaiErr
103
  }
104
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
105
  return nil
106
  }
relay/relay_task.go CHANGED
@@ -113,7 +113,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
113
  // release quota
114
  if relayInfo.ConsumeQuota && taskErr == nil {
115
 
116
- err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
117
  if err != nil {
118
  common.SysError("error consuming token remain quota: " + err.Error())
119
  }
 
113
  // release quota
114
  if relayInfo.ConsumeQuota && taskErr == nil {
115
 
116
+ err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
117
  if err != nil {
118
  common.SysError("error consuming token remain quota: " + err.Error())
119
  }
router/api-router.go CHANGED
@@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) {
56
  selfRoute.POST("/pay", controller.RequestEpay)
57
  selfRoute.POST("/amount", controller.RequestAmount)
58
  selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
 
59
  }
60
 
61
  adminRoute := userRoute.Group("/")
 
56
  selfRoute.POST("/pay", controller.RequestEpay)
57
  selfRoute.POST("/amount", controller.RequestAmount)
58
  selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
59
+ selfRoute.PUT("/setting", controller.UpdateUserSetting)
60
  }
61
 
62
  adminRoute := userRoute.Group("/")