luowuyin commited on
Commit
0722d2c
·
1 Parent(s): b066e83

25:03:07 18:16:42 v0.4.9.0

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.en.md +2 -2
  2. VERSION +1 -1
  3. common/gopool.go +24 -0
  4. controller/channel-test.go +26 -33
  5. controller/misc.go +3 -1
  6. controller/pricing.go +2 -3
  7. controller/relay.go +3 -11
  8. middleware/model-rate-limit.go +19 -16
  9. model/channel.go +18 -10
  10. model/log.go +3 -2
  11. model/option.go +16 -12
  12. model/pricing.go +4 -3
  13. relay/channel/ali/text.go +2 -1
  14. relay/channel/aws/dto.go +1 -1
  15. relay/channel/aws/relay-aws.go +4 -3
  16. relay/channel/baidu/relay-baidu.go +2 -1
  17. relay/channel/claude/dto.go +1 -1
  18. relay/channel/claude/relay-claude.go +13 -22
  19. relay/channel/cloudflare/relay_cloudflare.go +8 -7
  20. relay/channel/cohere/relay-cohere.go +2 -1
  21. relay/channel/dify/relay-dify.go +4 -3
  22. relay/channel/gemini/relay-gemini.go +18 -29
  23. relay/channel/openai/relay-openai.go +47 -79
  24. relay/channel/palm/relay-palm.go +2 -1
  25. relay/channel/tencent/relay-tencent.go +4 -3
  26. relay/channel/vertex/adaptor.go +6 -9
  27. relay/channel/vertex/dto.go +24 -4
  28. relay/channel/xunfei/relay-xunfei.go +2 -1
  29. relay/channel/zhipu/relay-zhipu.go +2 -1
  30. relay/channel/zhipu_4v/relay-zhipu_v4.go +2 -1
  31. relay/common/relay_info.go +31 -21
  32. service/relay.go → relay/helper/common.go +1 -1
  33. relay/helper/price.go +11 -3
  34. relay/helper/stream_scanner.go +91 -0
  35. relay/relay-mj.go +4 -4
  36. relay/relay-text.go +1 -1
  37. relay/relay_task.go +2 -2
  38. relay/websocket.go +2 -2
  39. service/channel.go +18 -10
  40. service/image.go +44 -11
  41. service/quota.go +10 -10
  42. service/token_counter.go +2 -1
  43. service/user_notify.go +4 -1
  44. {common → setting}/model-ratio.go +9 -8
  45. setting/{operation_setting.go → operation_setting/operation_setting.go} +2 -1
  46. web/src/App.js +11 -0
  47. web/src/components/ChannelsTable.js +84 -12
  48. web/src/components/HeaderBar.js +51 -5
  49. web/src/components/OperationSetting.js +1 -0
  50. web/src/components/OtherSetting.js +99 -41
README.en.md CHANGED
@@ -68,7 +68,7 @@
68
 
69
  ## Model Support
70
  This version additionally supports:
71
- 1. Third-party model **gps** (gpt-4-gizmo-*)
72
  2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md)
73
  3. Custom channels with full API URL support
74
  4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md)
@@ -162,7 +162,7 @@ docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtow
162
 
163
  ## Channel Retry
164
  Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
165
- First retry uses same priority, second retry uses next priority, and so on.
166
 
167
  ### Cache Configuration
168
  1. `REDIS_CONN_STRING`: Use Redis as cache
 
68
 
69
  ## Model Support
70
  This version additionally supports:
71
+ 1. Third-party model **gpts** (gpt-4-gizmo-*)
72
  2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md)
73
  3. Custom channels with full API URL support
74
  4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md)
 
162
 
163
  ## Channel Retry
164
  Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
165
+ If retry is enabled, the system will automatically use the next priority channel for the same request after a failed request.
166
 
167
  ### Cache Configuration
168
  1. `REDIS_CONN_STRING`: Use Redis as cache
VERSION CHANGED
@@ -1 +1 @@
1
- v0.4.8.8.3
 
1
+ v0.4.9.0
common/gopool.go ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "context"
5
+ "fmt"
6
+ "github.com/bytedance/gopkg/util/gopool"
7
+ "math"
8
+ )
9
+
10
+ var relayGoPool gopool.Pool
11
+
12
+ func init() {
13
+ relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
14
+ relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
15
+ if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
16
+ SafeSendBool(stopChan, true)
17
+ }
18
+ SysError(fmt.Sprintf("panic in gopool.RelayPool: %v", i))
19
+ })
20
+ }
21
+
22
+ func RelayCtxGo(ctx context.Context, f func()) {
23
+ relayGoPool.CtxGo(ctx, f)
24
+ }
controller/channel-test.go CHANGED
@@ -17,6 +17,7 @@ import (
17
  "one-api/relay"
18
  relaycommon "one-api/relay/common"
19
  "one-api/relay/constant"
 
20
  "one-api/service"
21
  "strconv"
22
  "strings"
@@ -72,18 +73,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
72
  }
73
  }
74
 
75
- modelMapping := *channel.ModelMapping
76
- if modelMapping != "" && modelMapping != "{}" {
77
- modelMap := make(map[string]string)
78
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
79
- if err != nil {
80
- return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
81
- }
82
- if modelMap[testModel] != "" {
83
- testModel = modelMap[testModel]
84
- }
85
- }
86
-
87
  cache, err := model.GetUserCache(1)
88
  if err != nil {
89
  return err, nil
@@ -97,7 +86,14 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
97
 
98
  middleware.SetupContextForSelectedChannel(c, channel, testModel)
99
 
100
- meta := relaycommon.GenRelayInfo(c)
 
 
 
 
 
 
 
101
  apiType, _ := constant.ChannelType2APIType(channel.Type)
102
  adaptor := relay.GetAdaptor(apiType)
103
  if adaptor == nil {
@@ -105,12 +101,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
105
  }
106
 
107
  request := buildTestRequest(testModel)
108
- meta.UpstreamModelName = testModel
109
- common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
110
 
111
- adaptor.Init(meta)
112
 
113
- convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
114
  if err != nil {
115
  return err, nil
116
  }
@@ -120,7 +115,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
120
  }
121
  requestBody := bytes.NewBuffer(jsonData)
122
  c.Request.Body = io.NopCloser(requestBody)
123
- resp, err := adaptor.DoRequest(c, meta, requestBody)
124
  if err != nil {
125
  return err, nil
126
  }
@@ -132,7 +127,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
132
  return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
133
  }
134
  }
135
- usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
136
  if respErr != nil {
137
  return fmt.Errorf("%s", respErr.Error.Message), respErr
138
  }
@@ -145,29 +140,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
145
  if err != nil {
146
  return err, nil
147
  }
148
- modelPrice, usePrice := common.GetModelPrice(testModel, false)
149
- modelRatio, success := common.GetModelRatio(testModel)
150
- if !usePrice && !success {
151
- return fmt.Errorf("模型 %s 倍率和价格均未设置", testModel), nil
152
  }
153
- completionRatio := common.GetCompletionRatio(testModel)
154
- ratio := modelRatio
155
  quota := 0
156
- if !usePrice {
157
- quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
158
- quota = int(math.Round(float64(quota) * ratio))
159
- if ratio != 0 && quota <= 0 {
160
  quota = 1
161
  }
162
  } else {
163
- quota = int(modelPrice * common.QuotaPerUnit)
164
  }
165
  tok := time.Now()
166
  milliseconds := tok.Sub(tik).Milliseconds()
167
  consumedTime := float64(milliseconds) / 1000.0
168
- other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
169
- model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试",
170
- quota, "模型测试", 0, quota, int(consumedTime), false, "default", other)
171
  common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
172
  return nil, nil
173
  }
 
17
  "one-api/relay"
18
  relaycommon "one-api/relay/common"
19
  "one-api/relay/constant"
20
+ "one-api/relay/helper"
21
  "one-api/service"
22
  "strconv"
23
  "strings"
 
73
  }
74
  }
75
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  cache, err := model.GetUserCache(1)
77
  if err != nil {
78
  return err, nil
 
86
 
87
  middleware.SetupContextForSelectedChannel(c, channel, testModel)
88
 
89
+ info := relaycommon.GenRelayInfo(c)
90
+
91
+ err = helper.ModelMappedHelper(c, info)
92
+ if err != nil {
93
+ return err, nil
94
+ }
95
+ testModel = info.UpstreamModelName
96
+
97
  apiType, _ := constant.ChannelType2APIType(channel.Type)
98
  adaptor := relay.GetAdaptor(apiType)
99
  if adaptor == nil {
 
101
  }
102
 
103
  request := buildTestRequest(testModel)
104
+ common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
 
105
 
106
+ adaptor.Init(info)
107
 
108
+ convertedRequest, err := adaptor.ConvertRequest(c, info, request)
109
  if err != nil {
110
  return err, nil
111
  }
 
115
  }
116
  requestBody := bytes.NewBuffer(jsonData)
117
  c.Request.Body = io.NopCloser(requestBody)
118
+ resp, err := adaptor.DoRequest(c, info, requestBody)
119
  if err != nil {
120
  return err, nil
121
  }
 
127
  return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
128
  }
129
  }
130
+ usageA, respErr := adaptor.DoResponse(c, httpResp, info)
131
  if respErr != nil {
132
  return fmt.Errorf("%s", respErr.Error.Message), respErr
133
  }
 
140
  if err != nil {
141
  return err, nil
142
  }
143
+ info.PromptTokens = usage.PromptTokens
144
+ priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens))
145
+ if err != nil {
146
+ return err, nil
147
  }
 
 
148
  quota := 0
149
+ if !priceData.UsePrice {
150
+ quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
151
+ quota = int(math.Round(float64(quota) * priceData.ModelRatio))
152
+ if priceData.ModelRatio != 0 && quota <= 0 {
153
  quota = 1
154
  }
155
  } else {
156
+ quota = int(priceData.ModelPrice * common.QuotaPerUnit)
157
  }
158
  tok := time.Now()
159
  milliseconds := tok.Sub(tik).Milliseconds()
160
  consumedTime := float64(milliseconds) / 1000.0
161
+ other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, priceData.ModelPrice)
162
+ model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
163
+ quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
164
  common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
165
  return nil, nil
166
  }
controller/misc.go CHANGED
@@ -7,6 +7,7 @@ import (
7
  "one-api/common"
8
  "one-api/model"
9
  "one-api/setting"
 
10
  "strings"
11
 
12
  "github.com/gin-gonic/gin"
@@ -66,7 +67,8 @@ func GetStatus(c *gin.Context) {
66
  "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
67
  "mj_notify_enabled": setting.MjNotifyEnabled,
68
  "chats": setting.Chats,
69
- "demo_site_enabled": setting.DemoSiteEnabled,
 
70
  },
71
  })
72
  return
 
7
  "one-api/common"
8
  "one-api/model"
9
  "one-api/setting"
10
+ "one-api/setting/operation_setting"
11
  "strings"
12
 
13
  "github.com/gin-gonic/gin"
 
67
  "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
68
  "mj_notify_enabled": setting.MjNotifyEnabled,
69
  "chats": setting.Chats,
70
+ "demo_site_enabled": operation_setting.DemoSiteEnabled,
71
+ "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
72
  },
73
  })
74
  return
controller/pricing.go CHANGED
@@ -2,7 +2,6 @@ package controller
2
 
3
  import (
4
  "github.com/gin-gonic/gin"
5
- "one-api/common"
6
  "one-api/model"
7
  "one-api/setting"
8
  )
@@ -40,7 +39,7 @@ func GetPricing(c *gin.Context) {
40
  }
41
 
42
  func ResetModelRatio(c *gin.Context) {
43
- defaultStr := common.DefaultModelRatio2JSONString()
44
  err := model.UpdateOption("ModelRatio", defaultStr)
45
  if err != nil {
46
  c.JSON(200, gin.H{
@@ -49,7 +48,7 @@ func ResetModelRatio(c *gin.Context) {
49
  })
50
  return
51
  }
52
- err = common.UpdateModelRatioByJSONString(defaultStr)
53
  if err != nil {
54
  c.JSON(200, gin.H{
55
  "success": false,
 
2
 
3
  import (
4
  "github.com/gin-gonic/gin"
 
5
  "one-api/model"
6
  "one-api/setting"
7
  )
 
39
  }
40
 
41
  func ResetModelRatio(c *gin.Context) {
42
+ defaultStr := setting.DefaultModelRatio2JSONString()
43
  err := model.UpdateOption("ModelRatio", defaultStr)
44
  if err != nil {
45
  c.JSON(200, gin.H{
 
48
  })
49
  return
50
  }
51
+ err = setting.UpdateModelRatioByJSONString(defaultStr)
52
  if err != nil {
53
  c.JSON(200, gin.H{
54
  "success": false,
controller/relay.go CHANGED
@@ -16,6 +16,7 @@ import (
16
  "one-api/relay"
17
  "one-api/relay/constant"
18
  relayconstant "one-api/relay/constant"
 
19
  "one-api/service"
20
  "strings"
21
  )
@@ -41,15 +42,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
41
  return err
42
  }
43
 
44
- func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
45
- var err *dto.OpenAIErrorWithStatusCode
46
- switch relayMode {
47
- default:
48
- err = relay.TextHelper(c)
49
- }
50
- return err
51
- }
52
-
53
  func Relay(c *gin.Context) {
54
  relayMode := constant.Path2RelayMode(c.Request.URL.Path)
55
  requestId := c.GetString(common.RequestIdKey)
@@ -110,7 +102,7 @@ func WssRelay(c *gin.Context) {
110
 
111
  if err != nil {
112
  openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
113
- service.WssError(c, ws, openaiErr.Error)
114
  return
115
  }
116
 
@@ -152,7 +144,7 @@ func WssRelay(c *gin.Context) {
152
  openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
153
  }
154
  openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
155
- service.WssError(c, ws, openaiErr.Error)
156
  }
157
  }
158
 
 
16
  "one-api/relay"
17
  "one-api/relay/constant"
18
  relayconstant "one-api/relay/constant"
19
+ "one-api/relay/helper"
20
  "one-api/service"
21
  "strings"
22
  )
 
42
  return err
43
  }
44
 
 
 
 
 
 
 
 
 
 
45
  func Relay(c *gin.Context) {
46
  relayMode := constant.Path2RelayMode(c.Request.URL.Path)
47
  requestId := c.GetString(common.RequestIdKey)
 
102
 
103
  if err != nil {
104
  openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
105
+ helper.WssError(c, ws, openaiErr.Error)
106
  return
107
  }
108
 
 
144
  openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
145
  }
146
  openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
147
+ helper.WssError(c, ws, openaiErr.Error)
148
  }
149
  }
150
 
middleware/model-rate-limit.go CHANGED
@@ -51,7 +51,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
51
  // 如果在时间窗口内已达到限制,拒绝请求
52
  subTime := nowTime.Sub(oldTime).Seconds()
53
  if int64(subTime) < duration {
54
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
55
  return false, nil
56
  }
57
 
@@ -68,7 +68,7 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC
68
  now := time.Now().Format(timeFormat)
69
  rdb.LPush(ctx, key, now)
70
  rdb.LTrim(ctx, key, 0, int64(maxCount-1))
71
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
72
  }
73
 
74
  // Redis限流处理器
@@ -118,7 +118,7 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
118
 
119
  // 内存限流处理器
120
  func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
121
- inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
122
 
123
  return func(c *gin.Context) {
124
  userId := strconv.Itoa(c.GetInt("id"))
@@ -153,20 +153,23 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
153
 
154
  // ModelRequestRateLimit 模型请求限流中间件
155
  func ModelRequestRateLimit() func(c *gin.Context) {
156
- // 如果未启用限流,直接放行
157
- if !setting.ModelRequestRateLimitEnabled {
158
- return defNext
159
- }
 
 
160
 
161
- // 计算限流参数
162
- duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
163
- totalMaxCount := setting.ModelRequestRateLimitCount
164
- successMaxCount := setting.ModelRequestRateLimitSuccessCount
165
 
166
- // 根据存储类型选择限流处理器
167
- if common.RedisEnabled {
168
- return redisRateLimitHandler(duration, totalMaxCount, successMaxCount)
169
- } else {
170
- return memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)
 
171
  }
172
  }
 
51
  // 如果在时间窗口内已达到限制,拒绝请求
52
  subTime := nowTime.Sub(oldTime).Seconds()
53
  if int64(subTime) < duration {
54
+ rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
55
  return false, nil
56
  }
57
 
 
68
  now := time.Now().Format(timeFormat)
69
  rdb.LPush(ctx, key, now)
70
  rdb.LTrim(ctx, key, 0, int64(maxCount-1))
71
+ rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
72
  }
73
 
74
  // Redis限流处理器
 
118
 
119
  // 内存限流处理器
120
  func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
121
+ inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
122
 
123
  return func(c *gin.Context) {
124
  userId := strconv.Itoa(c.GetInt("id"))
 
153
 
154
  // ModelRequestRateLimit 模型请求限流中间件
155
  func ModelRequestRateLimit() func(c *gin.Context) {
156
+ return func(c *gin.Context) {
157
+ // 在每个请求时检查是否启用限流
158
+ if !setting.ModelRequestRateLimitEnabled {
159
+ c.Next()
160
+ return
161
+ }
162
 
163
+ // 计算限流参数
164
+ duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
165
+ totalMaxCount := setting.ModelRequestRateLimitCount
166
+ successMaxCount := setting.ModelRequestRateLimitSuccessCount
167
 
168
+ // 根据存储类型选择并执行限流处理器
169
+ if common.RedisEnabled {
170
+ redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
171
+ } else {
172
+ memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
173
+ }
174
  }
175
  }
model/channel.go CHANGED
@@ -290,35 +290,42 @@ func (channel *Channel) Delete() error {
290
 
291
  var channelStatusLock sync.Mutex
292
 
293
- func UpdateChannelStatusById(id int, status int, reason string) {
294
  if common.MemoryCacheEnabled {
295
  channelStatusLock.Lock()
 
 
296
  channelCache, _ := CacheGetChannel(id)
297
  // 如果缓存渠道存在,且状态已是目标状态,直接返回
298
  if channelCache != nil && channelCache.Status == status {
299
- channelStatusLock.Unlock()
300
- return
301
  }
302
  // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
303
  if channelCache == nil && status != common.ChannelStatusEnabled {
304
- channelStatusLock.Unlock()
305
- return
306
  }
307
  CacheUpdateChannelStatus(id, status)
308
- channelStatusLock.Unlock()
309
  }
310
  err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
311
  if err != nil {
312
  common.SysError("failed to update ability status: " + err.Error())
 
313
  }
314
  channel, err := GetChannelById(id, true)
315
  if err != nil {
316
  // find channel by id error, directly update status
317
- err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
318
- if err != nil {
319
- common.SysError("failed to update channel status: " + err.Error())
 
 
 
 
320
  }
321
  } else {
 
 
 
322
  // find channel by id success, update status and other info
323
  info := channel.GetOtherInfo()
324
  info["status_reason"] = reason
@@ -328,9 +335,10 @@ func UpdateChannelStatusById(id int, status int, reason string) {
328
  err = channel.Save()
329
  if err != nil {
330
  common.SysError("failed to update channel status: " + err.Error())
 
331
  }
332
  }
333
-
334
  }
335
 
336
  func EnableChannelByTag(tag string) error {
 
290
 
291
  var channelStatusLock sync.Mutex
292
 
293
+ func UpdateChannelStatusById(id int, status int, reason string) bool {
294
  if common.MemoryCacheEnabled {
295
  channelStatusLock.Lock()
296
+ defer channelStatusLock.Unlock()
297
+
298
  channelCache, _ := CacheGetChannel(id)
299
  // 如果缓存渠道存在,且状态已是目标状态,直接返回
300
  if channelCache != nil && channelCache.Status == status {
301
+ return false
 
302
  }
303
  // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
304
  if channelCache == nil && status != common.ChannelStatusEnabled {
305
+ return false
 
306
  }
307
  CacheUpdateChannelStatus(id, status)
 
308
  }
309
  err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
310
  if err != nil {
311
  common.SysError("failed to update ability status: " + err.Error())
312
+ return false
313
  }
314
  channel, err := GetChannelById(id, true)
315
  if err != nil {
316
  // find channel by id error, directly update status
317
+ result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
318
+ if result.Error != nil {
319
+ common.SysError("failed to update channel status: " + result.Error.Error())
320
+ return false
321
+ }
322
+ if result.RowsAffected == 0 {
323
+ return false
324
  }
325
  } else {
326
+ if channel.Status == status {
327
+ return false
328
+ }
329
  // find channel by id success, update status and other info
330
  info := channel.GetOtherInfo()
331
  info["status_reason"] = reason
 
335
  err = channel.Save()
336
  if err != nil {
337
  common.SysError("failed to update channel status: " + err.Error())
338
+ return false
339
  }
340
  }
341
+ return true
342
  }
343
 
344
  func EnableChannelByTag(tag string) error {
model/log.go CHANGED
@@ -2,12 +2,13 @@ package model
2
 
3
  import (
4
  "fmt"
5
- "github.com/gin-gonic/gin"
6
  "one-api/common"
7
  "os"
8
  "strings"
9
  "time"
10
 
 
 
11
  "github.com/bytedance/gopkg/util/gopool"
12
  "gorm.io/gorm"
13
  )
@@ -18,7 +19,7 @@ type Log struct {
18
  CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
19
  Type int `json:"type" gorm:"index:idx_created_at_type"`
20
  Content string `json:"content"`
21
- Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
22
  TokenName string `json:"token_name" gorm:"index;default:''"`
23
  ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
24
  Quota int `json:"quota" gorm:"default:0"`
 
2
 
3
  import (
4
  "fmt"
 
5
  "one-api/common"
6
  "os"
7
  "strings"
8
  "time"
9
 
10
+ "github.com/gin-gonic/gin"
11
+
12
  "github.com/bytedance/gopkg/util/gopool"
13
  "gorm.io/gorm"
14
  )
 
19
  CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
20
  Type int `json:"type" gorm:"index:idx_created_at_type"`
21
  Content string `json:"content"`
22
+ Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
23
  TokenName string `json:"token_name" gorm:"index;default:''"`
24
  ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
25
  Quota int `json:"quota" gorm:"default:0"`
model/option.go CHANGED
@@ -4,6 +4,7 @@ import (
4
  "one-api/common"
5
  "one-api/setting"
6
  "one-api/setting/config"
 
7
  "strconv"
8
  "strings"
9
  "time"
@@ -87,15 +88,15 @@ func InitOptionMap() {
87
  common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
88
  common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
89
  common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
90
- common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
91
  common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
92
  common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
93
  common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
94
- common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
95
- common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
96
  common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
97
  common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
98
- common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
99
  common.OptionMap["TopUpLink"] = common.TopUpLink
100
  common.OptionMap["ChatLink"] = common.ChatLink
101
  common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -110,13 +111,14 @@ func InitOptionMap() {
110
  common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
111
  common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
112
  common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
113
- common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
 
114
  common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
115
  common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
116
  common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
117
  common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
118
  common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
119
- common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
120
 
121
  // 自动添加所有注册的模型配置
122
  modelConfigs := config.GlobalConfig.ExportAllConfigs()
@@ -242,7 +244,9 @@ func updateOptionMap(key string, value string) (err error) {
242
  case "CheckSensitiveEnabled":
243
  setting.CheckSensitiveEnabled = boolValue
244
  case "DemoSiteEnabled":
245
- setting.DemoSiteEnabled = boolValue
 
 
246
  case "CheckSensitiveOnPromptEnabled":
247
  setting.CheckSensitiveOnPromptEnabled = boolValue
248
  case "ModelRequestRateLimitEnabled":
@@ -325,7 +329,7 @@ func updateOptionMap(key string, value string) (err error) {
325
  common.QuotaForInvitee, _ = strconv.Atoi(value)
326
  case "QuotaRemindThreshold":
327
  common.QuotaRemindThreshold, _ = strconv.Atoi(value)
328
- case "ShouldPreConsumedQuota":
329
  common.PreConsumedQuota, _ = strconv.Atoi(value)
330
  case "ModelRequestRateLimitCount":
331
  setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
@@ -340,15 +344,15 @@ func updateOptionMap(key string, value string) (err error) {
340
  case "DataExportDefaultTime":
341
  common.DataExportDefaultTime = value
342
  case "ModelRatio":
343
- err = common.UpdateModelRatioByJSONString(value)
344
  case "GroupRatio":
345
  err = setting.UpdateGroupRatioByJSONString(value)
346
  case "UserUsableGroups":
347
  err = setting.UpdateUserUsableGroupsByJSONString(value)
348
  case "CompletionRatio":
349
- err = common.UpdateCompletionRatioByJSONString(value)
350
  case "ModelPrice":
351
- err = common.UpdateModelPriceByJSONString(value)
352
  case "TopUpLink":
353
  common.TopUpLink = value
354
  case "ChatLink":
@@ -362,7 +366,7 @@ func updateOptionMap(key string, value string) (err error) {
362
  case "SensitiveWords":
363
  setting.SensitiveWordsFromString(value)
364
  case "AutomaticDisableKeywords":
365
- setting.AutomaticDisableKeywordsFromString(value)
366
  case "StreamCacheQueueLength":
367
  setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
368
  }
 
4
  "one-api/common"
5
  "one-api/setting"
6
  "one-api/setting/config"
7
+ "one-api/setting/operation_setting"
8
  "strconv"
9
  "strings"
10
  "time"
 
88
  common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
89
  common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
90
  common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
91
+ common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
92
  common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
93
  common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
94
  common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
95
+ common.OptionMap["ModelRatio"] = setting.ModelRatio2JSONString()
96
+ common.OptionMap["ModelPrice"] = setting.ModelPrice2JSONString()
97
  common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
98
  common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
99
+ common.OptionMap["CompletionRatio"] = setting.CompletionRatio2JSONString()
100
  common.OptionMap["TopUpLink"] = common.TopUpLink
101
  common.OptionMap["ChatLink"] = common.ChatLink
102
  common.OptionMap["ChatLink2"] = common.ChatLink2
 
111
  common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
112
  common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
113
  common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
114
+ common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled)
115
+ common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled)
116
  common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
117
  common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
118
  common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
119
  common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
120
  common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
121
+ common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
122
 
123
  // 自动添加所有注册的模型配置
124
  modelConfigs := config.GlobalConfig.ExportAllConfigs()
 
244
  case "CheckSensitiveEnabled":
245
  setting.CheckSensitiveEnabled = boolValue
246
  case "DemoSiteEnabled":
247
+ operation_setting.DemoSiteEnabled = boolValue
248
+ case "SelfUseModeEnabled":
249
+ operation_setting.SelfUseModeEnabled = boolValue
250
  case "CheckSensitiveOnPromptEnabled":
251
  setting.CheckSensitiveOnPromptEnabled = boolValue
252
  case "ModelRequestRateLimitEnabled":
 
329
  common.QuotaForInvitee, _ = strconv.Atoi(value)
330
  case "QuotaRemindThreshold":
331
  common.QuotaRemindThreshold, _ = strconv.Atoi(value)
332
+ case "PreConsumedQuota":
333
  common.PreConsumedQuota, _ = strconv.Atoi(value)
334
  case "ModelRequestRateLimitCount":
335
  setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
 
344
  case "DataExportDefaultTime":
345
  common.DataExportDefaultTime = value
346
  case "ModelRatio":
347
+ err = setting.UpdateModelRatioByJSONString(value)
348
  case "GroupRatio":
349
  err = setting.UpdateGroupRatioByJSONString(value)
350
  case "UserUsableGroups":
351
  err = setting.UpdateUserUsableGroupsByJSONString(value)
352
  case "CompletionRatio":
353
+ err = setting.UpdateCompletionRatioByJSONString(value)
354
  case "ModelPrice":
355
+ err = setting.UpdateModelPriceByJSONString(value)
356
  case "TopUpLink":
357
  common.TopUpLink = value
358
  case "ChatLink":
 
366
  case "SensitiveWords":
367
  setting.SensitiveWordsFromString(value)
368
  case "AutomaticDisableKeywords":
369
+ operation_setting.AutomaticDisableKeywordsFromString(value)
370
  case "StreamCacheQueueLength":
371
  setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
372
  }
model/pricing.go CHANGED
@@ -2,6 +2,7 @@ package model
2
 
3
  import (
4
  "one-api/common"
 
5
  "sync"
6
  "time"
7
  )
@@ -64,14 +65,14 @@ func updatePricing() {
64
  ModelName: model,
65
  EnableGroup: groups,
66
  }
67
- modelPrice, findPrice := common.GetModelPrice(model, false)
68
  if findPrice {
69
  pricing.ModelPrice = modelPrice
70
  pricing.QuotaType = 1
71
  } else {
72
- modelRatio, _ := common.GetModelRatio(model)
73
  pricing.ModelRatio = modelRatio
74
- pricing.CompletionRatio = common.GetCompletionRatio(model)
75
  pricing.QuotaType = 0
76
  }
77
  pricingMap = append(pricingMap, pricing)
 
2
 
3
  import (
4
  "one-api/common"
5
+ "one-api/setting"
6
  "sync"
7
  "time"
8
  )
 
65
  ModelName: model,
66
  EnableGroup: groups,
67
  }
68
+ modelPrice, findPrice := setting.GetModelPrice(model, false)
69
  if findPrice {
70
  pricing.ModelPrice = modelPrice
71
  pricing.QuotaType = 1
72
  } else {
73
+ modelRatio, _ := setting.GetModelRatio(model)
74
  pricing.ModelRatio = modelRatio
75
+ pricing.CompletionRatio = setting.GetCompletionRatio(model)
76
  pricing.QuotaType = 0
77
  }
78
  pricingMap = append(pricingMap, pricing)
relay/channel/ali/text.go CHANGED
@@ -8,6 +8,7 @@ import (
8
  "net/http"
9
  "one-api/common"
10
  "one-api/dto"
 
11
  "one-api/service"
12
  "strings"
13
  )
@@ -153,7 +154,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
153
  }
154
  stopChan <- true
155
  }()
156
- service.SetEventStreamHeaders(c)
157
  lastResponseText := ""
158
  c.Stream(func(w io.Writer) bool {
159
  select {
 
8
  "net/http"
9
  "one-api/common"
10
  "one-api/dto"
11
+ "one-api/relay/helper"
12
  "one-api/service"
13
  "strings"
14
  )
 
154
  }
155
  stopChan <- true
156
  }()
157
+ helper.SetEventStreamHeaders(c)
158
  lastResponseText := ""
159
  c.Stream(func(w io.Writer) bool {
160
  select {
relay/channel/aws/dto.go CHANGED
@@ -14,7 +14,7 @@ type AwsClaudeRequest struct {
14
  TopP float64 `json:"top_p,omitempty"`
15
  TopK int `json:"top_k,omitempty"`
16
  StopSequences []string `json:"stop_sequences,omitempty"`
17
- Tools []claude.Tool `json:"tools,omitempty"`
18
  ToolChoice any `json:"tool_choice,omitempty"`
19
  Thinking *claude.Thinking `json:"thinking,omitempty"`
20
  }
 
14
  TopP float64 `json:"top_p,omitempty"`
15
  TopK int `json:"top_k,omitempty"`
16
  StopSequences []string `json:"stop_sequences,omitempty"`
17
+ Tools any `json:"tools,omitempty"`
18
  ToolChoice any `json:"tool_choice,omitempty"`
19
  Thinking *claude.Thinking `json:"thinking,omitempty"`
20
  }
relay/channel/aws/relay-aws.go CHANGED
@@ -12,6 +12,7 @@ import (
12
  relaymodel "one-api/dto"
13
  "one-api/relay/channel/claude"
14
  relaycommon "one-api/relay/common"
 
15
  "one-api/service"
16
  "strings"
17
  "time"
@@ -203,13 +204,13 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
203
  }
204
  })
205
  if info.ShouldIncludeUsage {
206
- response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
207
- err := service.ObjectData(c, response)
208
  if err != nil {
209
  common.SysError("send final response failed: " + err.Error())
210
  }
211
  }
212
- service.Done(c)
213
  if resp != nil {
214
  err = resp.Body.Close()
215
  if err != nil {
 
12
  relaymodel "one-api/dto"
13
  "one-api/relay/channel/claude"
14
  relaycommon "one-api/relay/common"
15
+ "one-api/relay/helper"
16
  "one-api/service"
17
  "strings"
18
  "time"
 
204
  }
205
  })
206
  if info.ShouldIncludeUsage {
207
+ response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
208
+ err := helper.ObjectData(c, response)
209
  if err != nil {
210
  common.SysError("send final response failed: " + err.Error())
211
  }
212
  }
213
+ helper.Done(c)
214
  if resp != nil {
215
  err = resp.Body.Close()
216
  if err != nil {
relay/channel/baidu/relay-baidu.go CHANGED
@@ -11,6 +11,7 @@ import (
11
  "one-api/common"
12
  "one-api/constant"
13
  "one-api/dto"
 
14
  "one-api/service"
15
  "strings"
16
  "sync"
@@ -138,7 +139,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
138
  }
139
  stopChan <- true
140
  }()
141
- service.SetEventStreamHeaders(c)
142
  c.Stream(func(w io.Writer) bool {
143
  select {
144
  case data := <-dataChan:
 
11
  "one-api/common"
12
  "one-api/constant"
13
  "one-api/dto"
14
+ "one-api/relay/helper"
15
  "one-api/service"
16
  "strings"
17
  "sync"
 
139
  }
140
  stopChan <- true
141
  }()
142
+ helper.SetEventStreamHeaders(c)
143
  c.Stream(func(w io.Writer) bool {
144
  select {
145
  case data := <-dataChan:
relay/channel/claude/dto.go CHANGED
@@ -58,7 +58,7 @@ type ClaudeRequest struct {
58
  TopK int `json:"top_k,omitempty"`
59
  //ClaudeMetadata `json:"metadata,omitempty"`
60
  Stream bool `json:"stream,omitempty"`
61
- Tools []Tool `json:"tools,omitempty"`
62
  ToolChoice any `json:"tool_choice,omitempty"`
63
  Thinking *Thinking `json:"thinking,omitempty"`
64
  }
 
58
  TopK int `json:"top_k,omitempty"`
59
  //ClaudeMetadata `json:"metadata,omitempty"`
60
  Stream bool `json:"stream,omitempty"`
61
+ Tools any `json:"tools,omitempty"`
62
  ToolChoice any `json:"tool_choice,omitempty"`
63
  Thinking *Thinking `json:"thinking,omitempty"`
64
  }
relay/channel/claude/relay-claude.go CHANGED
@@ -1,7 +1,6 @@
1
  package claude
2
 
3
  import (
4
- "bufio"
5
  "encoding/json"
6
  "fmt"
7
  "io"
@@ -9,6 +8,7 @@ import (
9
  "one-api/common"
10
  "one-api/dto"
11
  relaycommon "one-api/relay/common"
 
12
  "one-api/service"
13
  "one-api/setting/model_setting"
14
  "strings"
@@ -443,28 +443,18 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
443
  usage = &dto.Usage{}
444
  responseText := ""
445
  createdTime := common.GetTimestamp()
446
- scanner := bufio.NewScanner(resp.Body)
447
- scanner.Split(bufio.ScanLines)
448
- service.SetEventStreamHeaders(c)
449
 
450
- for scanner.Scan() {
451
- data := scanner.Text()
452
- info.SetFirstResponseTime()
453
- if len(data) < 6 || !strings.HasPrefix(data, "data:") {
454
- continue
455
- }
456
- data = strings.TrimPrefix(data, "data:")
457
- data = strings.TrimSpace(data)
458
  var claudeResponse ClaudeResponse
459
  err := json.Unmarshal([]byte(data), &claudeResponse)
460
  if err != nil {
461
  common.SysError("error unmarshalling stream response: " + err.Error())
462
- continue
463
  }
464
 
465
  response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
466
  if response == nil {
467
- continue
468
  }
469
  if requestMode == RequestModeCompletion {
470
  responseText += claudeResponse.Completion
@@ -481,9 +471,9 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
481
  usage.CompletionTokens = claudeUsage.OutputTokens
482
  usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
483
  } else if claudeResponse.Type == "content_block_start" {
484
-
485
  } else {
486
- continue
487
  }
488
  }
489
  //response.Id = responseId
@@ -491,11 +481,12 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
491
  response.Created = createdTime
492
  response.Model = info.UpstreamModelName
493
 
494
- err = service.ObjectData(c, response)
495
  if err != nil {
496
  common.LogError(c, "send_stream_response_failed: "+err.Error())
497
  }
498
- }
 
499
 
500
  if requestMode == RequestModeCompletion {
501
  usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
@@ -508,14 +499,14 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
508
  }
509
  }
510
  if info.ShouldIncludeUsage {
511
- response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
512
- err := service.ObjectData(c, response)
513
  if err != nil {
514
  common.SysError("send final response failed: " + err.Error())
515
  }
516
  }
517
- service.Done(c)
518
- resp.Body.Close()
519
  return nil, usage
520
  }
521
 
 
1
  package claude
2
 
3
  import (
 
4
  "encoding/json"
5
  "fmt"
6
  "io"
 
8
  "one-api/common"
9
  "one-api/dto"
10
  relaycommon "one-api/relay/common"
11
+ "one-api/relay/helper"
12
  "one-api/service"
13
  "one-api/setting/model_setting"
14
  "strings"
 
443
  usage = &dto.Usage{}
444
  responseText := ""
445
  createdTime := common.GetTimestamp()
 
 
 
446
 
447
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 
 
 
 
 
 
 
448
  var claudeResponse ClaudeResponse
449
  err := json.Unmarshal([]byte(data), &claudeResponse)
450
  if err != nil {
451
  common.SysError("error unmarshalling stream response: " + err.Error())
452
+ return true
453
  }
454
 
455
  response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
456
  if response == nil {
457
+ return true
458
  }
459
  if requestMode == RequestModeCompletion {
460
  responseText += claudeResponse.Completion
 
471
  usage.CompletionTokens = claudeUsage.OutputTokens
472
  usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
473
  } else if claudeResponse.Type == "content_block_start" {
474
+ return true
475
  } else {
476
+ return true
477
  }
478
  }
479
  //response.Id = responseId
 
481
  response.Created = createdTime
482
  response.Model = info.UpstreamModelName
483
 
484
+ err = helper.ObjectData(c, response)
485
  if err != nil {
486
  common.LogError(c, "send_stream_response_failed: "+err.Error())
487
  }
488
+ return true
489
+ })
490
 
491
  if requestMode == RequestModeCompletion {
492
  usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 
499
  }
500
  }
501
  if info.ShouldIncludeUsage {
502
+ response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
503
+ err := helper.ObjectData(c, response)
504
  if err != nil {
505
  common.SysError("send final response failed: " + err.Error())
506
  }
507
  }
508
+ helper.Done(c)
509
+ //resp.Body.Close()
510
  return nil, usage
511
  }
512
 
relay/channel/cloudflare/relay_cloudflare.go CHANGED
@@ -9,6 +9,7 @@ import (
9
  "one-api/common"
10
  "one-api/dto"
11
  relaycommon "one-api/relay/common"
 
12
  "one-api/service"
13
  "strings"
14
  "time"
@@ -28,8 +29,8 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
28
  scanner := bufio.NewScanner(resp.Body)
29
  scanner.Split(bufio.ScanLines)
30
 
31
- service.SetEventStreamHeaders(c)
32
- id := service.GetResponseID(c)
33
  var responseText string
34
  isFirst := true
35
 
@@ -57,7 +58,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
57
  }
58
  response.Id = id
59
  response.Model = info.UpstreamModelName
60
- err = service.ObjectData(c, response)
61
  if isFirst {
62
  isFirst = false
63
  info.FirstResponseTime = time.Now()
@@ -72,13 +73,13 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
72
  }
73
  usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
74
  if info.ShouldIncludeUsage {
75
- response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
76
- err := service.ObjectData(c, response)
77
  if err != nil {
78
  common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
79
  }
80
  }
81
- service.Done(c)
82
 
83
  err := resp.Body.Close()
84
  if err != nil {
@@ -109,7 +110,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
109
  }
110
  usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
111
  response.Usage = *usage
112
- response.Id = service.GetResponseID(c)
113
  jsonResponse, err := json.Marshal(response)
114
  if err != nil {
115
  return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 
9
  "one-api/common"
10
  "one-api/dto"
11
  relaycommon "one-api/relay/common"
12
+ "one-api/relay/helper"
13
  "one-api/service"
14
  "strings"
15
  "time"
 
29
  scanner := bufio.NewScanner(resp.Body)
30
  scanner.Split(bufio.ScanLines)
31
 
32
+ helper.SetEventStreamHeaders(c)
33
+ id := helper.GetResponseID(c)
34
  var responseText string
35
  isFirst := true
36
 
 
58
  }
59
  response.Id = id
60
  response.Model = info.UpstreamModelName
61
+ err = helper.ObjectData(c, response)
62
  if isFirst {
63
  isFirst = false
64
  info.FirstResponseTime = time.Now()
 
73
  }
74
  usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
75
  if info.ShouldIncludeUsage {
76
+ response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
77
+ err := helper.ObjectData(c, response)
78
  if err != nil {
79
  common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
80
  }
81
  }
82
+ helper.Done(c)
83
 
84
  err := resp.Body.Close()
85
  if err != nil {
 
110
  }
111
  usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
112
  response.Usage = *usage
113
+ response.Id = helper.GetResponseID(c)
114
  jsonResponse, err := json.Marshal(response)
115
  if err != nil {
116
  return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
relay/channel/cohere/relay-cohere.go CHANGED
@@ -10,6 +10,7 @@ import (
10
  "one-api/common"
11
  "one-api/dto"
12
  relaycommon "one-api/relay/common"
 
13
  "one-api/service"
14
  "strings"
15
  "time"
@@ -103,7 +104,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
103
  }
104
  stopChan <- true
105
  }()
106
- service.SetEventStreamHeaders(c)
107
  isFirst := true
108
  c.Stream(func(w io.Writer) bool {
109
  select {
 
10
  "one-api/common"
11
  "one-api/dto"
12
  relaycommon "one-api/relay/common"
13
+ "one-api/relay/helper"
14
  "one-api/service"
15
  "strings"
16
  "time"
 
104
  }
105
  stopChan <- true
106
  }()
107
+ helper.SetEventStreamHeaders(c)
108
  isFirst := true
109
  c.Stream(func(w io.Writer) bool {
110
  select {
relay/channel/dify/relay-dify.go CHANGED
@@ -10,6 +10,7 @@ import (
10
  "one-api/constant"
11
  "one-api/dto"
12
  relaycommon "one-api/relay/common"
 
13
  "one-api/service"
14
  "strings"
15
  )
@@ -66,7 +67,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
66
  scanner := bufio.NewScanner(resp.Body)
67
  scanner.Split(bufio.ScanLines)
68
 
69
- service.SetEventStreamHeaders(c)
70
 
71
  for scanner.Scan() {
72
  data := scanner.Text()
@@ -92,7 +93,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
92
  responseText += openaiResponse.Choices[0].Delta.GetContentString()
93
  }
94
  }
95
- err = service.ObjectData(c, openaiResponse)
96
  if err != nil {
97
  common.SysError(err.Error())
98
  }
@@ -100,7 +101,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
100
  if err := scanner.Err(); err != nil {
101
  common.SysError("error reading stream: " + err.Error())
102
  }
103
- service.Done(c)
104
  err := resp.Body.Close()
105
  if err != nil {
106
  //return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 
10
  "one-api/constant"
11
  "one-api/dto"
12
  relaycommon "one-api/relay/common"
13
+ "one-api/relay/helper"
14
  "one-api/service"
15
  "strings"
16
  )
 
67
  scanner := bufio.NewScanner(resp.Body)
68
  scanner.Split(bufio.ScanLines)
69
 
70
+ helper.SetEventStreamHeaders(c)
71
 
72
  for scanner.Scan() {
73
  data := scanner.Text()
 
93
  responseText += openaiResponse.Choices[0].Delta.GetContentString()
94
  }
95
  }
96
+ err = helper.ObjectData(c, openaiResponse)
97
  if err != nil {
98
  common.SysError(err.Error())
99
  }
 
101
  if err := scanner.Err(); err != nil {
102
  common.SysError("error reading stream: " + err.Error())
103
  }
104
+ helper.Done(c)
105
  err := resp.Body.Close()
106
  if err != nil {
107
  //return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
relay/channel/gemini/relay-gemini.go CHANGED
@@ -1,7 +1,6 @@
1
  package gemini
2
 
3
  import (
4
- "bufio"
5
  "encoding/json"
6
  "fmt"
7
  "io"
@@ -10,6 +9,7 @@ import (
10
  "one-api/constant"
11
  "one-api/dto"
12
  relaycommon "one-api/relay/common"
 
13
  "one-api/service"
14
  "one-api/setting/model_setting"
15
  "strings"
@@ -429,10 +429,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
429
 
430
  func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
431
  choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
432
- is_stop := false
433
  for _, candidate := range geminiResponse.Candidates {
434
  if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
435
- is_stop = true
436
  candidate.FinishReason = nil
437
  }
438
  choice := dto.ChatCompletionsStreamResponseChoice{
@@ -482,9 +482,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
482
 
483
  var response dto.ChatCompletionsStreamResponse
484
  response.Object = "chat.completion.chunk"
485
- response.Model = "gemini"
486
  response.Choices = choices
487
- return &response, is_stop
488
  }
489
 
490
  func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -492,27 +491,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
492
  id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
493
  createAt := common.GetTimestamp()
494
  var usage = &dto.Usage{}
495
- scanner := bufio.NewScanner(resp.Body)
496
- scanner.Split(bufio.ScanLines)
497
-
498
- service.SetEventStreamHeaders(c)
499
- for scanner.Scan() {
500
- data := scanner.Text()
501
- info.SetFirstResponseTime()
502
- data = strings.TrimSpace(data)
503
- if !strings.HasPrefix(data, "data: ") {
504
- continue
505
- }
506
- data = strings.TrimPrefix(data, "data: ")
507
- data = strings.TrimSuffix(data, "\"")
508
  var geminiResponse GeminiChatResponse
509
  err := json.Unmarshal([]byte(data), &geminiResponse)
510
  if err != nil {
511
  common.LogError(c, "error unmarshalling stream response: "+err.Error())
512
- continue
513
  }
514
 
515
- response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
516
  response.Id = id
517
  response.Created = createAt
518
  response.Model = info.UpstreamModelName
@@ -521,15 +509,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
521
  usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
522
  usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
523
  }
524
- err = service.ObjectData(c, response)
525
  if err != nil {
526
  common.LogError(c, err.Error())
527
  }
528
- if is_stop {
529
- response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
530
- service.ObjectData(c, response)
531
  }
532
- }
 
533
 
534
  var response *dto.ChatCompletionsStreamResponse
535
 
@@ -538,14 +527,14 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
538
  usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
539
 
540
  if info.ShouldIncludeUsage {
541
- response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
542
- err := service.ObjectData(c, response)
543
  if err != nil {
544
  common.SysError("send final response failed: " + err.Error())
545
  }
546
  }
547
- service.Done(c)
548
- resp.Body.Close()
549
  return nil, usage
550
  }
551
 
 
1
  package gemini
2
 
3
  import (
 
4
  "encoding/json"
5
  "fmt"
6
  "io"
 
9
  "one-api/constant"
10
  "one-api/dto"
11
  relaycommon "one-api/relay/common"
12
+ "one-api/relay/helper"
13
  "one-api/service"
14
  "one-api/setting/model_setting"
15
  "strings"
 
429
 
430
  func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
431
  choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
432
+ isStop := false
433
  for _, candidate := range geminiResponse.Candidates {
434
  if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
435
+ isStop = true
436
  candidate.FinishReason = nil
437
  }
438
  choice := dto.ChatCompletionsStreamResponseChoice{
 
482
 
483
  var response dto.ChatCompletionsStreamResponse
484
  response.Object = "chat.completion.chunk"
 
485
  response.Choices = choices
486
+ return &response, isStop
487
  }
488
 
489
  func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 
491
  id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
492
  createAt := common.GetTimestamp()
493
  var usage = &dto.Usage{}
494
+
495
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 
 
 
 
 
 
 
 
 
 
 
496
  var geminiResponse GeminiChatResponse
497
  err := json.Unmarshal([]byte(data), &geminiResponse)
498
  if err != nil {
499
  common.LogError(c, "error unmarshalling stream response: "+err.Error())
500
+ return false
501
  }
502
 
503
+ response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
504
  response.Id = id
505
  response.Created = createAt
506
  response.Model = info.UpstreamModelName
 
509
  usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
510
  usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
511
  }
512
+ err = helper.ObjectData(c, response)
513
  if err != nil {
514
  common.LogError(c, err.Error())
515
  }
516
+ if isStop {
517
+ response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
518
+ helper.ObjectData(c, response)
519
  }
520
+ return true
521
+ })
522
 
523
  var response *dto.ChatCompletionsStreamResponse
524
 
 
527
  usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
528
 
529
  if info.ShouldIncludeUsage {
530
+ response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
531
+ err := helper.ObjectData(c, response)
532
  if err != nil {
533
  common.SysError("send final response failed: " + err.Error())
534
  }
535
  }
536
+ helper.Done(c)
537
+ //resp.Body.Close()
538
  return nil, usage
539
  }
540
 
relay/channel/openai/relay-openai.go CHANGED
@@ -1,10 +1,13 @@
1
  package openai
2
 
3
  import (
4
- "bufio"
5
  "bytes"
6
  "encoding/json"
7
  "fmt"
 
 
 
 
8
  "io"
9
  "math"
10
  "mime/multipart"
@@ -14,16 +17,10 @@ import (
14
  "one-api/dto"
15
  relaycommon "one-api/relay/common"
16
  relayconstant "one-api/relay/constant"
 
17
  "one-api/service"
18
  "os"
19
  "strings"
20
- "sync"
21
- "time"
22
-
23
- "github.com/bytedance/gopkg/util/gopool"
24
- "github.com/gin-gonic/gin"
25
- "github.com/gorilla/websocket"
26
- "github.com/pkg/errors"
27
  )
28
 
29
  func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
@@ -32,7 +29,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
32
  }
33
 
34
  if !forceFormat && !thinkToContent {
35
- return service.StringData(c, data)
36
  }
37
 
38
  var lastStreamResponse dto.ChatCompletionsStreamResponse
@@ -41,34 +38,47 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
41
  }
42
 
43
  if !thinkToContent {
44
- return service.ObjectData(c, lastStreamResponse)
 
 
 
 
 
 
 
 
45
  }
46
 
47
  // Handle think to content conversion
48
- if info.IsFirstResponse {
49
- response := lastStreamResponse.Copy()
50
- for i := range response.Choices {
51
- response.Choices[i].Delta.SetContentString("<think>\n")
52
- response.Choices[i].Delta.SetReasoningContent("")
 
 
 
 
 
 
53
  }
54
- service.ObjectData(c, response)
55
  }
56
 
57
  if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
58
- return service.ObjectData(c, lastStreamResponse)
59
  }
60
 
61
  // Process each choice
62
  for i, choice := range lastStreamResponse.Choices {
63
  // Handle transition from thinking to content
64
- if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
65
  response := lastStreamResponse.Copy()
66
  for j := range response.Choices {
67
- response.Choices[j].Delta.SetContentString("\n</think>")
68
  response.Choices[j].Delta.SetReasoningContent("")
69
  }
70
- info.SendLastReasoningResponse = true
71
- service.ObjectData(c, response)
72
  }
73
 
74
  // Convert reasoning content to regular content
@@ -78,7 +88,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
78
  }
79
  }
80
 
81
- return service.ObjectData(c, lastStreamResponse)
82
  }
83
 
84
  func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -108,65 +118,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
108
  }
109
 
110
  toolCount := 0
111
- scanner := bufio.NewScanner(resp.Body)
112
- scanner.Split(bufio.ScanLines)
113
-
114
- service.SetEventStreamHeaders(c)
115
- streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
116
- if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
117
- // twice timeout for o1 model
118
- streamingTimeout *= 2
119
- }
120
- ticker := time.NewTicker(streamingTimeout)
121
- defer ticker.Stop()
122
 
123
- stopChan := make(chan bool)
124
- defer close(stopChan)
125
  var (
126
  lastStreamData string
127
- mu sync.Mutex
128
  )
129
- gopool.Go(func() {
130
- for scanner.Scan() {
131
- //info.SetFirstResponseTime()
132
- ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
133
- data := scanner.Text()
134
- if common.DebugEnabled {
135
- println(data)
136
- }
137
- if len(data) < 6 { // ignore blank line or wrong format
138
- continue
139
- }
140
- if data[:5] != "data:" && data[:6] != "[DONE]" {
141
- continue
142
- }
143
- mu.Lock()
144
- data = data[5:]
145
- data = strings.TrimSpace(data)
146
- if !strings.HasPrefix(data, "[DONE]") {
147
- if lastStreamData != "" {
148
- err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
149
- if err != nil {
150
- common.LogError(c, "streaming error: "+err.Error())
151
- }
152
- info.SetFirstResponseTime()
153
- }
154
- lastStreamData = data
155
- streamItems = append(streamItems, data)
156
  }
157
- mu.Unlock()
158
  }
159
- common.SafeSendBool(stopChan, true)
 
 
160
  })
161
 
162
- select {
163
- case <-ticker.C:
164
- // 超时处理逻辑
165
- common.LogError(c, "streaming timeout")
166
- case <-stopChan:
167
- // 正常结束
168
- }
169
-
170
  shouldSendLastResp := true
171
  var lastStreamResponse dto.ChatCompletionsStreamResponse
172
  err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
@@ -274,14 +242,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
274
  }
275
 
276
  if info.ShouldIncludeUsage && !containStreamUsage {
277
- response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
278
  response.SetSystemFingerprint(systemFingerprint)
279
- service.ObjectData(c, response)
280
  }
281
 
282
- service.Done(c)
283
 
284
- resp.Body.Close()
285
  return nil, usage
286
  }
287
 
@@ -512,7 +480,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
512
  localUsage.InputTokenDetails.TextTokens += textToken
513
  localUsage.InputTokenDetails.AudioTokens += audioToken
514
 
515
- err = service.WssString(c, targetConn, string(message))
516
  if err != nil {
517
  errChan <- fmt.Errorf("error writing to target: %v", err)
518
  return
@@ -618,7 +586,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
618
  localUsage.OutputTokenDetails.AudioTokens += audioToken
619
  }
620
 
621
- err = service.WssString(c, clientConn, string(message))
622
  if err != nil {
623
  errChan <- fmt.Errorf("error writing to client: %v", err)
624
  return
 
1
  package openai
2
 
3
  import (
 
4
  "bytes"
5
  "encoding/json"
6
  "fmt"
7
+ "github.com/bytedance/gopkg/util/gopool"
8
+ "github.com/gin-gonic/gin"
9
+ "github.com/gorilla/websocket"
10
+ "github.com/pkg/errors"
11
  "io"
12
  "math"
13
  "mime/multipart"
 
17
  "one-api/dto"
18
  relaycommon "one-api/relay/common"
19
  relayconstant "one-api/relay/constant"
20
+ "one-api/relay/helper"
21
  "one-api/service"
22
  "os"
23
  "strings"
 
 
 
 
 
 
 
24
  )
25
 
26
  func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
 
29
  }
30
 
31
  if !forceFormat && !thinkToContent {
32
+ return helper.StringData(c, data)
33
  }
34
 
35
  var lastStreamResponse dto.ChatCompletionsStreamResponse
 
38
  }
39
 
40
  if !thinkToContent {
41
+ return helper.ObjectData(c, lastStreamResponse)
42
+ }
43
+
44
+ hasThinkingContent := false
45
+ for _, choice := range lastStreamResponse.Choices {
46
+ if len(choice.Delta.GetReasoningContent()) > 0 {
47
+ hasThinkingContent = true
48
+ break
49
+ }
50
  }
51
 
52
  // Handle think to content conversion
53
+ if info.ThinkingContentInfo.IsFirstThinkingContent {
54
+ if hasThinkingContent {
55
+ response := lastStreamResponse.Copy()
56
+ for i := range response.Choices {
57
+ response.Choices[i].Delta.SetContentString("<think>\n")
58
+ response.Choices[i].Delta.SetReasoningContent("")
59
+ }
60
+ info.ThinkingContentInfo.IsFirstThinkingContent = false
61
+ return helper.ObjectData(c, response)
62
+ } else {
63
+ return helper.ObjectData(c, lastStreamResponse)
64
  }
 
65
  }
66
 
67
  if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
68
+ return helper.ObjectData(c, lastStreamResponse)
69
  }
70
 
71
  // Process each choice
72
  for i, choice := range lastStreamResponse.Choices {
73
  // Handle transition from thinking to content
74
+ if len(choice.Delta.GetContentString()) > 0 && !info.ThinkingContentInfo.SendLastThinkingContent {
75
  response := lastStreamResponse.Copy()
76
  for j := range response.Choices {
77
+ response.Choices[j].Delta.SetContentString("\n</think>\n\n")
78
  response.Choices[j].Delta.SetReasoningContent("")
79
  }
80
+ info.ThinkingContentInfo.SendLastThinkingContent = true
81
+ helper.ObjectData(c, response)
82
  }
83
 
84
  // Convert reasoning content to regular content
 
88
  }
89
  }
90
 
91
+ return helper.ObjectData(c, lastStreamResponse)
92
  }
93
 
94
  func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 
118
  }
119
 
120
  toolCount := 0
 
 
 
 
 
 
 
 
 
 
 
121
 
 
 
122
  var (
123
  lastStreamData string
 
124
  )
125
+
126
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
127
+ if lastStreamData != "" {
128
+ err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
129
+ if err != nil {
130
+ common.LogError(c, "streaming error: "+err.Error())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  }
 
132
  }
133
+ lastStreamData = data
134
+ streamItems = append(streamItems, data)
135
+ return true
136
  })
137
 
 
 
 
 
 
 
 
 
138
  shouldSendLastResp := true
139
  var lastStreamResponse dto.ChatCompletionsStreamResponse
140
  err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
 
242
  }
243
 
244
  if info.ShouldIncludeUsage && !containStreamUsage {
245
+ response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
246
  response.SetSystemFingerprint(systemFingerprint)
247
+ helper.ObjectData(c, response)
248
  }
249
 
250
+ helper.Done(c)
251
 
252
+ //resp.Body.Close()
253
  return nil, usage
254
  }
255
 
 
480
  localUsage.InputTokenDetails.TextTokens += textToken
481
  localUsage.InputTokenDetails.AudioTokens += audioToken
482
 
483
+ err = helper.WssString(c, targetConn, string(message))
484
  if err != nil {
485
  errChan <- fmt.Errorf("error writing to target: %v", err)
486
  return
 
586
  localUsage.OutputTokenDetails.AudioTokens += audioToken
587
  }
588
 
589
+ err = helper.WssString(c, clientConn, string(message))
590
  if err != nil {
591
  errChan <- fmt.Errorf("error writing to client: %v", err)
592
  return
relay/channel/palm/relay-palm.go CHANGED
@@ -9,6 +9,7 @@ import (
9
  "one-api/common"
10
  "one-api/constant"
11
  "one-api/dto"
 
12
  "one-api/service"
13
  )
14
 
@@ -112,7 +113,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
112
  dataChan <- string(jsonResponse)
113
  stopChan <- true
114
  }()
115
- service.SetEventStreamHeaders(c)
116
  c.Stream(func(w io.Writer) bool {
117
  select {
118
  case data := <-dataChan:
 
9
  "one-api/common"
10
  "one-api/constant"
11
  "one-api/dto"
12
+ "one-api/relay/helper"
13
  "one-api/service"
14
  )
15
 
 
113
  dataChan <- string(jsonResponse)
114
  stopChan <- true
115
  }()
116
+ helper.SetEventStreamHeaders(c)
117
  c.Stream(func(w io.Writer) bool {
118
  select {
119
  case data := <-dataChan:
relay/channel/tencent/relay-tencent.go CHANGED
@@ -14,6 +14,7 @@ import (
14
  "one-api/common"
15
  "one-api/constant"
16
  "one-api/dto"
 
17
  "one-api/service"
18
  "strconv"
19
  "strings"
@@ -91,7 +92,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
91
  scanner := bufio.NewScanner(resp.Body)
92
  scanner.Split(bufio.ScanLines)
93
 
94
- service.SetEventStreamHeaders(c)
95
 
96
  for scanner.Scan() {
97
  data := scanner.Text()
@@ -112,7 +113,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
112
  responseText += response.Choices[0].Delta.GetContentString()
113
  }
114
 
115
- err = service.ObjectData(c, response)
116
  if err != nil {
117
  common.SysError(err.Error())
118
  }
@@ -122,7 +123,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
122
  common.SysError("error reading stream: " + err.Error())
123
  }
124
 
125
- service.Done(c)
126
 
127
  err := resp.Body.Close()
128
  if err != nil {
 
14
  "one-api/common"
15
  "one-api/constant"
16
  "one-api/dto"
17
+ "one-api/relay/helper"
18
  "one-api/service"
19
  "strconv"
20
  "strings"
 
92
  scanner := bufio.NewScanner(resp.Body)
93
  scanner.Split(bufio.ScanLines)
94
 
95
+ helper.SetEventStreamHeaders(c)
96
 
97
  for scanner.Scan() {
98
  data := scanner.Text()
 
113
  responseText += response.Choices[0].Delta.GetContentString()
114
  }
115
 
116
+ err = helper.ObjectData(c, response)
117
  if err != nil {
118
  common.SysError(err.Error())
119
  }
 
123
  common.SysError("error reading stream: " + err.Error())
124
  }
125
 
126
+ helper.Done(c)
127
 
128
  err := resp.Body.Close()
129
  if err != nil {
relay/channel/vertex/adaptor.go CHANGED
@@ -5,7 +5,6 @@ import (
5
  "errors"
6
  "fmt"
7
  "github.com/gin-gonic/gin"
8
- "github.com/jinzhu/copier"
9
  "io"
10
  "net/http"
11
  "one-api/dto"
@@ -28,6 +27,7 @@ var claudeModelMap = map[string]string{
28
  "claude-3-opus-20240229": "claude-3-opus@20240229",
29
  "claude-3-haiku-20240307": "claude-3-haiku@20240307",
30
  "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
 
31
  "claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
32
  }
33
 
@@ -86,15 +86,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
86
  } else {
87
  suffix = "rawPredict"
88
  }
 
89
  if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
90
- info.UpstreamModelName = v
91
  }
92
  return fmt.Sprintf(
93
  "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
94
  region,
95
  adc.ProjectID,
96
  region,
97
- info.UpstreamModelName,
98
  suffix,
99
  ), nil
100
  } else if a.RequestMode == RequestModeLlama {
@@ -127,13 +128,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
127
  if err != nil {
128
  return nil, err
129
  }
130
- vertexClaudeReq := &VertexAIClaudeRequest{
131
- AnthropicVersion: anthropicVersion,
132
- }
133
- if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
134
- return nil, errors.New("failed to copy claude request")
135
- }
136
  c.Set("request_model", claudeReq.Model)
 
137
  return vertexClaudeReq, nil
138
  } else if a.RequestMode == RequestModeGemini {
139
  geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
 
5
  "errors"
6
  "fmt"
7
  "github.com/gin-gonic/gin"
 
8
  "io"
9
  "net/http"
10
  "one-api/dto"
 
27
  "claude-3-opus-20240229": "claude-3-opus@20240229",
28
  "claude-3-haiku-20240307": "claude-3-haiku@20240307",
29
  "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
30
+ "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
31
  "claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
32
  }
33
 
 
86
  } else {
87
  suffix = "rawPredict"
88
  }
89
+ model := info.UpstreamModelName
90
  if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
91
+ model = v
92
  }
93
  return fmt.Sprintf(
94
  "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
95
  region,
96
  adc.ProjectID,
97
  region,
98
+ model,
99
  suffix,
100
  ), nil
101
  } else if a.RequestMode == RequestModeLlama {
 
128
  if err != nil {
129
  return nil, err
130
  }
131
+ vertexClaudeReq := copyRequest(claudeReq, anthropicVersion)
 
 
 
 
 
132
  c.Set("request_model", claudeReq.Model)
133
+ info.UpstreamModelName = claudeReq.Model
134
  return vertexClaudeReq, nil
135
  } else if a.RequestMode == RequestModeGemini {
136
  geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
relay/channel/vertex/dto.go CHANGED
@@ -1,17 +1,37 @@
1
  package vertex
2
 
3
- import "one-api/relay/channel/claude"
 
 
4
 
5
  type VertexAIClaudeRequest struct {
6
  AnthropicVersion string `json:"anthropic_version"`
7
  Messages []claude.ClaudeMessage `json:"messages"`
8
- System string `json:"system,omitempty"`
9
- MaxTokens int `json:"max_tokens,omitempty"`
10
  StopSequences []string `json:"stop_sequences,omitempty"`
11
  Stream bool `json:"stream,omitempty"`
12
  Temperature *float64 `json:"temperature,omitempty"`
13
  TopP float64 `json:"top_p,omitempty"`
14
  TopK int `json:"top_k,omitempty"`
15
- Tools []claude.Tool `json:"tools,omitempty"`
16
  ToolChoice any `json:"tool_choice,omitempty"`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  }
 
1
  package vertex
2
 
3
+ import (
4
+ "one-api/relay/channel/claude"
5
+ )
6
 
7
  type VertexAIClaudeRequest struct {
8
  AnthropicVersion string `json:"anthropic_version"`
9
  Messages []claude.ClaudeMessage `json:"messages"`
10
+ System any `json:"system,omitempty"`
11
+ MaxTokens uint `json:"max_tokens,omitempty"`
12
  StopSequences []string `json:"stop_sequences,omitempty"`
13
  Stream bool `json:"stream,omitempty"`
14
  Temperature *float64 `json:"temperature,omitempty"`
15
  TopP float64 `json:"top_p,omitempty"`
16
  TopK int `json:"top_k,omitempty"`
17
+ Tools any `json:"tools,omitempty"`
18
  ToolChoice any `json:"tool_choice,omitempty"`
19
+ Thinking *claude.Thinking `json:"thinking,omitempty"`
20
+ }
21
+
22
+ func copyRequest(req *claude.ClaudeRequest, version string) *VertexAIClaudeRequest {
23
+ return &VertexAIClaudeRequest{
24
+ AnthropicVersion: version,
25
+ System: req.System,
26
+ Messages: req.Messages,
27
+ MaxTokens: req.MaxTokens,
28
+ Stream: req.Stream,
29
+ Temperature: req.Temperature,
30
+ TopP: req.TopP,
31
+ TopK: req.TopK,
32
+ StopSequences: req.StopSequences,
33
+ Tools: req.Tools,
34
+ ToolChoice: req.ToolChoice,
35
+ Thinking: req.Thinking,
36
+ }
37
  }
relay/channel/xunfei/relay-xunfei.go CHANGED
@@ -14,6 +14,7 @@ import (
14
  "one-api/common"
15
  "one-api/constant"
16
  "one-api/dto"
 
17
  "one-api/service"
18
  "strings"
19
  "time"
@@ -132,7 +133,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
132
  if err != nil {
133
  return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
134
  }
135
- service.SetEventStreamHeaders(c)
136
  var usage dto.Usage
137
  c.Stream(func(w io.Writer) bool {
138
  select {
 
14
  "one-api/common"
15
  "one-api/constant"
16
  "one-api/dto"
17
+ "one-api/relay/helper"
18
  "one-api/service"
19
  "strings"
20
  "time"
 
133
  if err != nil {
134
  return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
135
  }
136
+ helper.SetEventStreamHeaders(c)
137
  var usage dto.Usage
138
  c.Stream(func(w io.Writer) bool {
139
  select {
relay/channel/zhipu/relay-zhipu.go CHANGED
@@ -10,6 +10,7 @@ import (
10
  "one-api/common"
11
  "one-api/constant"
12
  "one-api/dto"
 
13
  "one-api/service"
14
  "strings"
15
  "sync"
@@ -177,7 +178,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
177
  }
178
  stopChan <- true
179
  }()
180
- service.SetEventStreamHeaders(c)
181
  c.Stream(func(w io.Writer) bool {
182
  select {
183
  case data := <-dataChan:
 
10
  "one-api/common"
11
  "one-api/constant"
12
  "one-api/dto"
13
+ "one-api/relay/helper"
14
  "one-api/service"
15
  "strings"
16
  "sync"
 
178
  }
179
  stopChan <- true
180
  }()
181
+ helper.SetEventStreamHeaders(c)
182
  c.Stream(func(w io.Writer) bool {
183
  select {
184
  case data := <-dataChan:
relay/channel/zhipu_4v/relay-zhipu_v4.go CHANGED
@@ -10,6 +10,7 @@ import (
10
  "net/http"
11
  "one-api/common"
12
  "one-api/dto"
 
13
  "one-api/service"
14
  "strings"
15
  "sync"
@@ -197,7 +198,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
197
  }
198
  stopChan <- true
199
  }()
200
- service.SetEventStreamHeaders(c)
201
  c.Stream(func(w io.Writer) bool {
202
  select {
203
  case data := <-dataChan:
 
10
  "net/http"
11
  "one-api/common"
12
  "one-api/dto"
13
+ "one-api/relay/helper"
14
  "one-api/service"
15
  "strings"
16
  "sync"
 
198
  }
199
  stopChan <- true
200
  }()
201
+ helper.SetEventStreamHeaders(c)
202
  c.Stream(func(w io.Writer) bool {
203
  select {
204
  case data := <-dataChan:
relay/common/relay_info.go CHANGED
@@ -12,25 +12,30 @@ import (
12
  "github.com/gorilla/websocket"
13
  )
14
 
 
 
 
 
 
15
  type RelayInfo struct {
16
- ChannelType int
17
- ChannelId int
18
- TokenId int
19
- TokenKey string
20
- UserId int
21
- Group string
22
- TokenUnlimited bool
23
- StartTime time.Time
24
- FirstResponseTime time.Time
25
- IsFirstResponse bool
26
- SendLastReasoningResponse bool
27
- ApiType int
28
- IsStream bool
29
- IsPlayground bool
30
- UsePrice bool
31
- RelayMode int
32
- UpstreamModelName string
33
- OriginModelName string
34
  //RecodeModelName string
35
  RequestURLPath string
36
  ApiVersion string
@@ -53,6 +58,7 @@ type RelayInfo struct {
53
  UserSetting map[string]interface{}
54
  UserEmail string
55
  UserQuota int
 
56
  }
57
 
58
  // 定义支持流式选项的通道类型
@@ -95,7 +101,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
95
  UserQuota: c.GetInt(constant.ContextKeyUserQuota),
96
  UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
97
  UserEmail: c.GetString(constant.ContextKeyUserEmail),
98
- IsFirstResponse: true,
99
  RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
100
  BaseUrl: c.GetString("base_url"),
101
  RequestURLPath: c.Request.URL.String(),
@@ -117,6 +123,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
117
  ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
118
  Organization: c.GetString("channel_organization"),
119
  ChannelSetting: channelSetting,
 
 
 
 
120
  }
121
  if strings.HasPrefix(c.Request.URL.Path, "/pg") {
122
  info.IsPlayground = true
@@ -147,9 +157,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
147
  }
148
 
149
  func (info *RelayInfo) SetFirstResponseTime() {
150
- if info.IsFirstResponse {
151
  info.FirstResponseTime = time.Now()
152
- info.IsFirstResponse = false
153
  }
154
  }
155
 
 
12
  "github.com/gorilla/websocket"
13
  )
14
 
15
+ type ThinkingContentInfo struct {
16
+ IsFirstThinkingContent bool
17
+ SendLastThinkingContent bool
18
+ }
19
+
20
  type RelayInfo struct {
21
+ ChannelType int
22
+ ChannelId int
23
+ TokenId int
24
+ TokenKey string
25
+ UserId int
26
+ Group string
27
+ TokenUnlimited bool
28
+ StartTime time.Time
29
+ FirstResponseTime time.Time
30
+ isFirstResponse bool
31
+ //SendLastReasoningResponse bool
32
+ ApiType int
33
+ IsStream bool
34
+ IsPlayground bool
35
+ UsePrice bool
36
+ RelayMode int
37
+ UpstreamModelName string
38
+ OriginModelName string
39
  //RecodeModelName string
40
  RequestURLPath string
41
  ApiVersion string
 
58
  UserSetting map[string]interface{}
59
  UserEmail string
60
  UserQuota int
61
+ ThinkingContentInfo
62
  }
63
 
64
  // 定义支持流式选项的通道类型
 
101
  UserQuota: c.GetInt(constant.ContextKeyUserQuota),
102
  UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
103
  UserEmail: c.GetString(constant.ContextKeyUserEmail),
104
+ isFirstResponse: true,
105
  RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
106
  BaseUrl: c.GetString("base_url"),
107
  RequestURLPath: c.Request.URL.String(),
 
123
  ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
124
  Organization: c.GetString("channel_organization"),
125
  ChannelSetting: channelSetting,
126
+ ThinkingContentInfo: ThinkingContentInfo{
127
+ IsFirstThinkingContent: true,
128
+ SendLastThinkingContent: false,
129
+ },
130
  }
131
  if strings.HasPrefix(c.Request.URL.Path, "/pg") {
132
  info.IsPlayground = true
 
157
  }
158
 
159
  func (info *RelayInfo) SetFirstResponseTime() {
160
+ if info.isFirstResponse {
161
  info.FirstResponseTime = time.Now()
162
+ info.isFirstResponse = false
163
  }
164
  }
165
 
service/relay.go → relay/helper/common.go RENAMED
@@ -1,4 +1,4 @@
1
- package service
2
 
3
  import (
4
  "encoding/json"
 
1
+ package helper
2
 
3
  import (
4
  "encoding/json"
relay/helper/price.go CHANGED
@@ -11,26 +11,33 @@ import (
11
  type PriceData struct {
12
  ModelPrice float64
13
  ModelRatio float64
 
14
  GroupRatio float64
15
  UsePrice bool
16
  ShouldPreConsumedQuota int
17
  }
18
 
19
  func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
20
- modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false)
21
  groupRatio := setting.GetGroupRatio(info.Group)
22
  var preConsumedQuota int
23
  var modelRatio float64
 
24
  if !usePrice {
25
  preConsumedTokens := common.PreConsumedQuota
26
  if maxTokens != 0 {
27
  preConsumedTokens = promptTokens + maxTokens
28
  }
29
  var success bool
30
- modelRatio, success = common.GetModelRatio(info.OriginModelName)
31
  if !success {
32
- return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
 
 
 
 
33
  }
 
34
  ratio := modelRatio * groupRatio
35
  preConsumedQuota = int(float64(preConsumedTokens) * ratio)
36
  } else {
@@ -39,6 +46,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
39
  return PriceData{
40
  ModelPrice: modelPrice,
41
  ModelRatio: modelRatio,
 
42
  GroupRatio: groupRatio,
43
  UsePrice: usePrice,
44
  ShouldPreConsumedQuota: preConsumedQuota,
 
11
  type PriceData struct {
12
  ModelPrice float64
13
  ModelRatio float64
14
+ CompletionRatio float64
15
  GroupRatio float64
16
  UsePrice bool
17
  ShouldPreConsumedQuota int
18
  }
19
 
20
  func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
21
+ modelPrice, usePrice := setting.GetModelPrice(info.OriginModelName, false)
22
  groupRatio := setting.GetGroupRatio(info.Group)
23
  var preConsumedQuota int
24
  var modelRatio float64
25
+ var completionRatio float64
26
  if !usePrice {
27
  preConsumedTokens := common.PreConsumedQuota
28
  if maxTokens != 0 {
29
  preConsumedTokens = promptTokens + maxTokens
30
  }
31
  var success bool
32
+ modelRatio, success = setting.GetModelRatio(info.OriginModelName)
33
  if !success {
34
+ if info.UserId == 1 {
35
+ return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
36
+ } else {
37
+ return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
38
+ }
39
  }
40
+ completionRatio = setting.GetCompletionRatio(info.OriginModelName)
41
  ratio := modelRatio * groupRatio
42
  preConsumedQuota = int(float64(preConsumedTokens) * ratio)
43
  } else {
 
46
  return PriceData{
47
  ModelPrice: modelPrice,
48
  ModelRatio: modelRatio,
49
+ CompletionRatio: completionRatio,
50
  GroupRatio: groupRatio,
51
  UsePrice: usePrice,
52
  ShouldPreConsumedQuota: preConsumedQuota,
relay/helper/stream_scanner.go ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package helper
2
+
3
+ import (
4
+ "bufio"
5
+ "context"
6
+ "io"
7
+ "net/http"
8
+ "one-api/common"
9
+ "one-api/constant"
10
+ relaycommon "one-api/relay/common"
11
+ "strings"
12
+ "time"
13
+
14
+ "github.com/gin-gonic/gin"
15
+ )
16
+
17
+ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
18
+
19
+ if resp == nil {
20
+ return
21
+ }
22
+
23
+ defer resp.Body.Close()
24
+
25
+ streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
26
+ if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
27
+ // twice timeout for thinking model
28
+ streamingTimeout *= 2
29
+ }
30
+
31
+ var (
32
+ stopChan = make(chan bool, 2)
33
+ scanner = bufio.NewScanner(resp.Body)
34
+ ticker = time.NewTicker(streamingTimeout)
35
+ )
36
+
37
+ defer func() {
38
+ ticker.Stop()
39
+ close(stopChan)
40
+ }()
41
+
42
+ scanner.Split(bufio.ScanLines)
43
+ SetEventStreamHeaders(c)
44
+
45
+ ctx, cancel := context.WithCancel(context.Background())
46
+ defer cancel()
47
+
48
+ ctx = context.WithValue(ctx, "stop_chan", stopChan)
49
+ common.RelayCtxGo(ctx, func() {
50
+ for scanner.Scan() {
51
+ ticker.Reset(streamingTimeout)
52
+ data := scanner.Text()
53
+ if common.DebugEnabled {
54
+ println(data)
55
+ }
56
+
57
+ if len(data) < 6 {
58
+ continue
59
+ }
60
+ if data[:5] != "data:" && data[:6] != "[DONE]" {
61
+ continue
62
+ }
63
+ data = data[5:]
64
+ data = strings.TrimLeft(data, " ")
65
+ data = strings.TrimSuffix(data, "\"")
66
+ if !strings.HasPrefix(data, "[DONE]") {
67
+ info.SetFirstResponseTime()
68
+ success := dataHandler(data)
69
+ if !success {
70
+ break
71
+ }
72
+ }
73
+ }
74
+
75
+ if err := scanner.Err(); err != nil {
76
+ if err != io.EOF {
77
+ common.LogError(c, "scanner error: "+err.Error())
78
+ }
79
+ }
80
+
81
+ common.SafeSendBool(stopChan, true)
82
+ })
83
+
84
+ select {
85
+ case <-ticker.C:
86
+ // 超时处理逻辑
87
+ common.LogError(c, "streaming timeout")
88
+ case <-stopChan:
89
+ // 正常结束
90
+ }
91
+ }
relay/relay-mj.go CHANGED
@@ -157,10 +157,10 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
157
  return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
158
  }
159
  modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
160
- modelPrice, success := common.GetModelPrice(modelName, true)
161
  // 如果没有配置价格,则使用默认价格
162
  if !success {
163
- defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
164
  if !ok {
165
  modelPrice = 0.1
166
  } else {
@@ -463,10 +463,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
463
  fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
464
 
465
  modelName := service.CoverActionToModelName(midjRequest.Action)
466
- modelPrice, success := common.GetModelPrice(modelName, true)
467
  // 如果没有配置价格,则使用默认价格
468
  if !success {
469
- defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
470
  if !ok {
471
  modelPrice = 0.1
472
  } else {
 
157
  return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
158
  }
159
  modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
160
+ modelPrice, success := setting.GetModelPrice(modelName, true)
161
  // 如果没有配置价格,则使用默认价格
162
  if !success {
163
+ defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName]
164
  if !ok {
165
  modelPrice = 0.1
166
  } else {
 
463
  fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
464
 
465
  modelName := service.CoverActionToModelName(midjRequest.Action)
466
+ modelPrice, success := setting.GetModelPrice(modelName, true)
467
  // 如果没有配置价格,则使用默认价格
468
  if !success {
469
+ defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName]
470
  if !ok {
471
  modelPrice = 0.1
472
  } else {
relay/relay-text.go CHANGED
@@ -311,7 +311,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
311
  modelName := relayInfo.OriginModelName
312
 
313
  tokenName := ctx.GetString("token_name")
314
- completionRatio := common.GetCompletionRatio(modelName)
315
  ratio := priceData.ModelRatio * priceData.GroupRatio
316
  modelRatio := priceData.ModelRatio
317
  groupRatio := priceData.GroupRatio
 
311
  modelName := relayInfo.OriginModelName
312
 
313
  tokenName := ctx.GetString("token_name")
314
+ completionRatio := setting.GetCompletionRatio(modelName)
315
  ratio := priceData.ModelRatio * priceData.GroupRatio
316
  modelRatio := priceData.ModelRatio
317
  groupRatio := priceData.GroupRatio
relay/relay_task.go CHANGED
@@ -37,9 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
37
  }
38
 
39
  modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
40
- modelPrice, success := common.GetModelPrice(modelName, true)
41
  if !success {
42
- defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
43
  if !ok {
44
  modelPrice = 0.1
45
  } else {
 
37
  }
38
 
39
  modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
40
+ modelPrice, success := setting.GetModelPrice(modelName, true)
41
  if !success {
42
+ defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName]
43
  if !ok {
44
  modelPrice = 0.1
45
  } else {
relay/websocket.go CHANGED
@@ -39,7 +39,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
39
  }
40
  }
41
  //relayInfo.UpstreamModelName = textRequest.Model
42
- modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
43
  groupRatio := setting.GetGroupRatio(relayInfo.Group)
44
 
45
  var preConsumedQuota int
@@ -65,7 +65,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
65
  //if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
66
  // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
67
  //}
68
- modelRatio, _ = common.GetModelRatio(relayInfo.UpstreamModelName)
69
  ratio = modelRatio * groupRatio
70
  preConsumedQuota = int(float64(preConsumedTokens) * ratio)
71
  } else {
 
39
  }
40
  }
41
  //relayInfo.UpstreamModelName = textRequest.Model
42
+ modelPrice, getModelPriceSuccess := setting.GetModelPrice(relayInfo.UpstreamModelName, false)
43
  groupRatio := setting.GetGroupRatio(relayInfo.Group)
44
 
45
  var preConsumedQuota int
 
65
  //if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
66
  // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
67
  //}
68
+ modelRatio, _ = setting.GetModelRatio(relayInfo.UpstreamModelName)
69
  ratio = modelRatio * groupRatio
70
  preConsumedQuota = int(float64(preConsumedTokens) * ratio)
71
  } else {
service/channel.go CHANGED
@@ -6,23 +6,31 @@ import (
6
  "one-api/common"
7
  "one-api/dto"
8
  "one-api/model"
9
- "one-api/setting"
10
  "strings"
11
  )
12
 
 
 
 
 
13
  // disable & notify
14
  func DisableChannel(channelId int, channelName string, reason string) {
15
- model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
16
- subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
17
- content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
18
- NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
 
 
19
  }
20
 
21
  func EnableChannel(channelId int, channelName string) {
22
- model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
23
- subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
24
- content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
25
- NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
 
 
26
  }
27
 
28
  func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
@@ -67,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
67
  }
68
 
69
  lowerMessage := strings.ToLower(err.Error.Message)
70
- search, _ := AcSearch(lowerMessage, setting.AutomaticDisableKeywords, true)
71
  if search {
72
  return true
73
  }
 
6
  "one-api/common"
7
  "one-api/dto"
8
  "one-api/model"
9
+ "one-api/setting/operation_setting"
10
  "strings"
11
  )
12
 
13
+ func formatNotifyType(channelId int, status int) string {
14
+ return fmt.Sprintf("%s_%d_%d", dto.NotifyTypeChannelUpdate, channelId, status)
15
+ }
16
+
17
  // disable & notify
18
  func DisableChannel(channelId int, channelName string, reason string) {
19
+ success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
20
+ if success {
21
+ subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
22
+ content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
23
+ NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content)
24
+ }
25
  }
26
 
27
  func EnableChannel(channelId int, channelName string) {
28
+ success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
29
+ if success {
30
+ subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
31
+ content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
32
+ NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusEnabled), subject, content)
33
+ }
34
  }
35
 
36
  func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
 
75
  }
76
 
77
  lowerMessage := strings.ToLower(err.Error.Message)
78
+ search, _ := AcSearch(lowerMessage, operation_setting.AutomaticDisableKeywords, true)
79
  if search {
80
  return true
81
  }
service/image.go CHANGED
@@ -7,7 +7,9 @@ import (
7
  "fmt"
8
  "image"
9
  "io"
 
10
  "one-api/common"
 
11
  "strings"
12
 
13
  "golang.org/x/image/webp"
@@ -23,7 +25,7 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
23
  decodedData, err := base64.StdEncoding.DecodeString(base64String)
24
  if err != nil {
25
  fmt.Println("Error: Failed to decode base64 string")
26
- return image.Config{}, "", "", err
27
  }
28
 
29
  // 创建一个bytes.Buffer用于存储解码后的数据
@@ -61,20 +63,51 @@ func DecodeBase64FileData(base64String string) (string, string, error) {
61
  func GetImageFromUrl(url string) (mimeType string, data string, err error) {
62
  resp, err := DoDownloadRequest(url)
63
  if err != nil {
64
- return "", "", err
65
- }
66
- if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
67
- return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type"))
68
  }
69
  defer resp.Body.Close()
70
- buffer := bytes.NewBuffer(nil)
71
- _, err = buffer.ReadFrom(resp.Body)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  if err != nil {
73
- return
 
 
 
74
  }
75
- mimeType = resp.Header.Get("Content-Type")
76
  data = base64.StdEncoding.EncodeToString(buffer.Bytes())
77
- return
 
 
 
 
 
 
 
 
 
 
 
78
  }
79
 
80
  func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
@@ -92,7 +125,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
92
 
93
  mimeType := response.Header.Get("Content-Type")
94
 
95
- if !strings.HasPrefix(mimeType, "image/") {
96
  return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
97
  }
98
 
 
7
  "fmt"
8
  "image"
9
  "io"
10
+ "net/http"
11
  "one-api/common"
12
+ "one-api/constant"
13
  "strings"
14
 
15
  "golang.org/x/image/webp"
 
25
  decodedData, err := base64.StdEncoding.DecodeString(base64String)
26
  if err != nil {
27
  fmt.Println("Error: Failed to decode base64 string")
28
+ return image.Config{}, "", "", fmt.Errorf("failed to decode base64 string: %s", err.Error())
29
  }
30
 
31
  // 创建一个bytes.Buffer用于存储解码后的数据
 
63
  func GetImageFromUrl(url string) (mimeType string, data string, err error) {
64
  resp, err := DoDownloadRequest(url)
65
  if err != nil {
66
+ return "", "", fmt.Errorf("failed to download image: %w", err)
 
 
 
67
  }
68
  defer resp.Body.Close()
69
+
70
+ // Check HTTP status code
71
+ if resp.StatusCode != http.StatusOK {
72
+ return "", "", fmt.Errorf("failed to download image: HTTP %d", resp.StatusCode)
73
+ }
74
+
75
+ contentType := resp.Header.Get("Content-Type")
76
+ if contentType != "application/octet-stream" && !strings.HasPrefix(contentType, "image/") {
77
+ return "", "", fmt.Errorf("invalid content type: %s, required image/*", contentType)
78
+ }
79
+ maxImageSize := int64(constant.MaxFileDownloadMB * 1024 * 1024)
80
+
81
+ // Check Content-Length if available
82
+ if resp.ContentLength > maxImageSize {
83
+ return "", "", fmt.Errorf("image size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxImageSize)
84
+ }
85
+
86
+ // Use LimitReader to prevent reading oversized images
87
+ limitReader := io.LimitReader(resp.Body, maxImageSize)
88
+ buffer := &bytes.Buffer{}
89
+
90
+ written, err := io.Copy(buffer, limitReader)
91
  if err != nil {
92
+ return "", "", fmt.Errorf("failed to read image data: %w", err)
93
+ }
94
+ if written >= maxImageSize {
95
+ return "", "", fmt.Errorf("image size exceeds maximum allowed size of %d bytes", maxImageSize)
96
  }
97
+
98
  data = base64.StdEncoding.EncodeToString(buffer.Bytes())
99
+ mimeType = contentType
100
+
101
+ // Handle application/octet-stream type
102
+ if mimeType == "application/octet-stream" {
103
+ _, format, _, err := DecodeBase64ImageData(data)
104
+ if err != nil {
105
+ return "", "", err
106
+ }
107
+ mimeType = "image/" + format
108
+ }
109
+
110
+ return mimeType, data, nil
111
  }
112
 
113
  func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
 
125
 
126
  mimeType := response.Header.Get("Content-Type")
127
 
128
+ if mimeType != "application/octet-stream" && !strings.HasPrefix(mimeType, "image/") {
129
  return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
130
  }
131
 
service/quota.go CHANGED
@@ -38,9 +38,9 @@ func calculateAudioQuota(info QuotaInfo) int {
38
  return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
39
  }
40
 
41
- completionRatio := common.GetCompletionRatio(info.ModelName)
42
- audioRatio := common.GetAudioRatio(info.ModelName)
43
- audioCompletionRatio := common.GetAudioCompletionRatio(info.ModelName)
44
  ratio := info.GroupRatio * info.ModelRatio
45
 
46
  quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
@@ -75,7 +75,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
75
  audioInputTokens := usage.InputTokenDetails.AudioTokens
76
  audioOutTokens := usage.OutputTokenDetails.AudioTokens
77
  groupRatio := setting.GetGroupRatio(relayInfo.Group)
78
- modelRatio, _ := common.GetModelRatio(modelName)
79
 
80
  quotaInfo := QuotaInfo{
81
  InputDetails: TokenDetails{
@@ -122,9 +122,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
122
  audioOutTokens := usage.OutputTokenDetails.AudioTokens
123
 
124
  tokenName := ctx.GetString("token_name")
125
- completionRatio := common.GetCompletionRatio(modelName)
126
- audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
127
- audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
128
 
129
  quotaInfo := QuotaInfo{
130
  InputDetails: TokenDetails{
@@ -184,9 +184,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
184
  audioOutTokens := usage.CompletionTokenDetails.AudioTokens
185
 
186
  tokenName := ctx.GetString("token_name")
187
- completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName)
188
- audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
189
- audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName)
190
 
191
  modelRatio := priceData.ModelRatio
192
  groupRatio := priceData.GroupRatio
 
38
  return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
39
  }
40
 
41
+ completionRatio := setting.GetCompletionRatio(info.ModelName)
42
+ audioRatio := setting.GetAudioRatio(info.ModelName)
43
+ audioCompletionRatio := setting.GetAudioCompletionRatio(info.ModelName)
44
  ratio := info.GroupRatio * info.ModelRatio
45
 
46
  quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
 
75
  audioInputTokens := usage.InputTokenDetails.AudioTokens
76
  audioOutTokens := usage.OutputTokenDetails.AudioTokens
77
  groupRatio := setting.GetGroupRatio(relayInfo.Group)
78
+ modelRatio, _ := setting.GetModelRatio(modelName)
79
 
80
  quotaInfo := QuotaInfo{
81
  InputDetails: TokenDetails{
 
122
  audioOutTokens := usage.OutputTokenDetails.AudioTokens
123
 
124
  tokenName := ctx.GetString("token_name")
125
+ completionRatio := setting.GetCompletionRatio(modelName)
126
+ audioRatio := setting.GetAudioRatio(relayInfo.OriginModelName)
127
+ audioCompletionRatio := setting.GetAudioCompletionRatio(modelName)
128
 
129
  quotaInfo := QuotaInfo{
130
  InputDetails: TokenDetails{
 
184
  audioOutTokens := usage.CompletionTokenDetails.AudioTokens
185
 
186
  tokenName := ctx.GetString("token_name")
187
+ completionRatio := setting.GetCompletionRatio(relayInfo.OriginModelName)
188
+ audioRatio := setting.GetAudioRatio(relayInfo.OriginModelName)
189
+ audioCompletionRatio := setting.GetAudioCompletionRatio(relayInfo.OriginModelName)
190
 
191
  modelRatio := priceData.ModelRatio
192
  groupRatio := priceData.GroupRatio
service/token_counter.go CHANGED
@@ -10,6 +10,7 @@ import (
10
  "one-api/constant"
11
  "one-api/dto"
12
  relaycommon "one-api/relay/common"
 
13
  "strings"
14
  "unicode/utf8"
15
 
@@ -32,7 +33,7 @@ func InitTokenEncoders() {
32
  if err != nil {
33
  common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
34
  }
35
- for model, _ := range common.GetDefaultModelRatioMap() {
36
  if strings.HasPrefix(model, "gpt-3.5") {
37
  tokenEncoderMap[model] = cl100TokenEncoder
38
  } else if strings.HasPrefix(model, "gpt-4") {
 
10
  "one-api/constant"
11
  "one-api/dto"
12
  relaycommon "one-api/relay/common"
13
+ "one-api/setting"
14
  "strings"
15
  "unicode/utf8"
16
 
 
33
  if err != nil {
34
  common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
35
  }
36
+ for model, _ := range setting.GetDefaultModelRatioMap() {
37
  if strings.HasPrefix(model, "gpt-3.5") {
38
  tokenEncoderMap[model] = cl100TokenEncoder
39
  } else if strings.HasPrefix(model, "gpt-4") {
service/user_notify.go CHANGED
@@ -11,7 +11,10 @@ import (
11
 
12
  func NotifyRootUser(t string, subject string, content string) {
13
  user := model.GetRootUser().ToBaseUser()
14
- _ = NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
 
 
 
15
  }
16
 
17
  func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
 
11
 
12
  func NotifyRootUser(t string, subject string, content string) {
13
  user := model.GetRootUser().ToBaseUser()
14
+ err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
15
+ if err != nil {
16
+ common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error()))
17
+ }
18
  }
19
 
20
  func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
{common → setting}/model-ratio.go RENAMED
@@ -1,7 +1,9 @@
1
- package common
2
 
3
  import (
4
  "encoding/json"
 
 
5
  "strings"
6
  "sync"
7
  )
@@ -261,7 +263,7 @@ func ModelPrice2JSONString() string {
261
  GetModelPriceMap()
262
  jsonBytes, err := json.Marshal(modelPriceMap)
263
  if err != nil {
264
- SysError("error marshalling model price: " + err.Error())
265
  }
266
  return string(jsonBytes)
267
  }
@@ -285,7 +287,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
285
  price, ok := modelPriceMap[name]
286
  if !ok {
287
  if printErr {
288
- SysError("model price not found: " + name)
289
  }
290
  return -1, false
291
  }
@@ -305,7 +307,7 @@ func ModelRatio2JSONString() string {
305
  GetModelRatioMap()
306
  jsonBytes, err := json.Marshal(modelRatioMap)
307
  if err != nil {
308
- SysError("error marshalling model ratio: " + err.Error())
309
  }
310
  return string(jsonBytes)
311
  }
@@ -324,8 +326,7 @@ func GetModelRatio(name string) (float64, bool) {
324
  }
325
  ratio, ok := modelRatioMap[name]
326
  if !ok {
327
- SysError("model ratio not found: " + name)
328
- return 37.5, false
329
  }
330
  return ratio, true
331
  }
@@ -333,7 +334,7 @@ func GetModelRatio(name string) (float64, bool) {
333
  func DefaultModelRatio2JSONString() string {
334
  jsonBytes, err := json.Marshal(defaultModelRatio)
335
  if err != nil {
336
- SysError("error marshalling model ratio: " + err.Error())
337
  }
338
  return string(jsonBytes)
339
  }
@@ -355,7 +356,7 @@ func CompletionRatio2JSONString() string {
355
  GetCompletionRatioMap()
356
  jsonBytes, err := json.Marshal(CompletionRatio)
357
  if err != nil {
358
- SysError("error marshalling completion ratio: " + err.Error())
359
  }
360
  return string(jsonBytes)
361
  }
 
1
+ package setting
2
 
3
  import (
4
  "encoding/json"
5
+ "one-api/common"
6
+ "one-api/setting/operation_setting"
7
  "strings"
8
  "sync"
9
  )
 
263
  GetModelPriceMap()
264
  jsonBytes, err := json.Marshal(modelPriceMap)
265
  if err != nil {
266
+ common.SysError("error marshalling model price: " + err.Error())
267
  }
268
  return string(jsonBytes)
269
  }
 
287
  price, ok := modelPriceMap[name]
288
  if !ok {
289
  if printErr {
290
+ common.SysError("model price not found: " + name)
291
  }
292
  return -1, false
293
  }
 
307
  GetModelRatioMap()
308
  jsonBytes, err := json.Marshal(modelRatioMap)
309
  if err != nil {
310
+ common.SysError("error marshalling model ratio: " + err.Error())
311
  }
312
  return string(jsonBytes)
313
  }
 
326
  }
327
  ratio, ok := modelRatioMap[name]
328
  if !ok {
329
+ return 37.5, operation_setting.SelfUseModeEnabled
 
330
  }
331
  return ratio, true
332
  }
 
334
  func DefaultModelRatio2JSONString() string {
335
  jsonBytes, err := json.Marshal(defaultModelRatio)
336
  if err != nil {
337
+ common.SysError("error marshalling model ratio: " + err.Error())
338
  }
339
  return string(jsonBytes)
340
  }
 
356
  GetCompletionRatioMap()
357
  jsonBytes, err := json.Marshal(CompletionRatio)
358
  if err != nil {
359
+ common.SysError("error marshalling completion ratio: " + err.Error())
360
  }
361
  return string(jsonBytes)
362
  }
setting/{operation_setting.go → operation_setting/operation_setting.go} RENAMED
@@ -1,8 +1,9 @@
1
- package setting
2
 
3
  import "strings"
4
 
5
  var DemoSiteEnabled = false
 
6
 
7
  var AutomaticDisableKeywords = []string{
8
  "Your credit balance is too low",
 
1
+ package operation_setting
2
 
3
  import "strings"
4
 
5
  var DemoSiteEnabled = false
6
+ var SelfUseModeEnabled = false
7
 
8
  var AutomaticDisableKeywords = []string{
9
  "Your credit balance is too low",
web/src/App.js CHANGED
@@ -30,6 +30,7 @@ import { useTranslation } from 'react-i18next';
30
  import { StatusContext } from './context/Status';
31
  import { setStatusData } from './helpers/data.js';
32
  import { API, showError } from './helpers';
 
33
 
34
  const Home = lazy(() => import('./pages/Home'));
35
  const Detail = lazy(() => import('./pages/Detail'));
@@ -177,6 +178,16 @@ function App() {
177
  </PrivateRoute>
178
  }
179
  />
 
 
 
 
 
 
 
 
 
 
180
  <Route
181
  path='/topup'
182
  element={
 
30
  import { StatusContext } from './context/Status';
31
  import { setStatusData } from './helpers/data.js';
32
  import { API, showError } from './helpers';
33
+ import PersonalSetting from './components/PersonalSetting.js';
34
 
35
  const Home = lazy(() => import('./pages/Home'));
36
  const Detail = lazy(() => import('./pages/Detail'));
 
178
  </PrivateRoute>
179
  }
180
  />
181
+ <Route
182
+ path='/personal'
183
+ element={
184
+ <PrivateRoute>
185
+ <Suspense fallback={<Loading></Loading>}>
186
+ <PersonalSetting />
187
+ </Suspense>
188
+ </PrivateRoute>
189
+ }
190
+ />
191
  <Route
192
  path='/topup'
193
  element={
web/src/components/ChannelsTable.js CHANGED
@@ -15,7 +15,7 @@ import {
15
  getQuotaPerUnit,
16
  renderGroup,
17
  renderNumberWithPoint,
18
- renderQuota, renderQuotaWithPrompt
19
  } from '../helpers/render';
20
  import {
21
  Button, Divider,
@@ -378,17 +378,15 @@ const ChannelsTable = () => {
378
  >
379
  {t('测试')}
380
  </Button>
381
- <Dropdown
382
- trigger="click"
383
- position="bottomRight"
384
- menu={modelMenuItems} // 使用即时生成的菜单项
385
- >
386
- <Button
387
- style={{ padding: '8px 4px' }}
388
- type="primary"
389
- icon={<IconTreeTriangleDown />}
390
- ></Button>
391
- </Dropdown>
392
  </SplitButtonGroup>
393
  <Popconfirm
394
  title={t('确定是否要删除此渠道?')}
@@ -522,6 +520,9 @@ const ChannelsTable = () => {
522
  const [enableTagMode, setEnableTagMode] = useState(false);
523
  const [showBatchSetTag, setShowBatchSetTag] = useState(false);
524
  const [batchSetTagValue, setBatchSetTagValue] = useState('');
 
 
 
525
 
526
 
527
  const removeRecord = (record) => {
@@ -1289,6 +1290,77 @@ const ChannelsTable = () => {
1289
  onChange={(v) => setBatchSetTagValue(v)}
1290
  />
1291
  </Modal>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1292
  </>
1293
  );
1294
  };
 
15
  getQuotaPerUnit,
16
  renderGroup,
17
  renderNumberWithPoint,
18
+ renderQuota, renderQuotaWithPrompt, stringToColor
19
  } from '../helpers/render';
20
  import {
21
  Button, Divider,
 
378
  >
379
  {t('测试')}
380
  </Button>
381
+ <Button
382
+ style={{ padding: '8px 4px' }}
383
+ type="primary"
384
+ icon={<IconTreeTriangleDown />}
385
+ onClick={() => {
386
+ setCurrentTestChannel(record);
387
+ setShowModelTestModal(true);
388
+ }}
389
+ ></Button>
 
 
390
  </SplitButtonGroup>
391
  <Popconfirm
392
  title={t('确定是否要删除此渠道?')}
 
520
  const [enableTagMode, setEnableTagMode] = useState(false);
521
  const [showBatchSetTag, setShowBatchSetTag] = useState(false);
522
  const [batchSetTagValue, setBatchSetTagValue] = useState('');
523
+ const [showModelTestModal, setShowModelTestModal] = useState(false);
524
+ const [currentTestChannel, setCurrentTestChannel] = useState(null);
525
+ const [modelSearchKeyword, setModelSearchKeyword] = useState('');
526
 
527
 
528
  const removeRecord = (record) => {
 
1290
  onChange={(v) => setBatchSetTagValue(v)}
1291
  />
1292
  </Modal>
1293
+
1294
+ {/* 模型测试弹窗 */}
1295
+ <Modal
1296
+ title={t('选择模型进行测试')}
1297
+ visible={showModelTestModal && currentTestChannel !== null}
1298
+ onCancel={() => {
1299
+ setShowModelTestModal(false);
1300
+ setModelSearchKeyword('');
1301
+ }}
1302
+ footer={null}
1303
+ maskClosable={true}
1304
+ centered={true}
1305
+ width={600}
1306
+ >
1307
+ <div style={{ maxHeight: '500px', overflowY: 'auto', padding: '10px' }}>
1308
+ {currentTestChannel && (
1309
+ <div>
1310
+ <Typography.Title heading={6} style={{ marginBottom: '16px' }}>
1311
+ {t('渠道')}: {currentTestChannel.name}
1312
+ </Typography.Title>
1313
+
1314
+ {/* 搜索框 */}
1315
+ <Input
1316
+ placeholder={t('搜索模型...')}
1317
+ value={modelSearchKeyword}
1318
+ onChange={(value) => setModelSearchKeyword(value)}
1319
+ style={{ marginBottom: '16px' }}
1320
+ showClear
1321
+ />
1322
+
1323
+ <div style={{
1324
+ display: 'grid',
1325
+ gridTemplateColumns: 'repeat(auto-fill, minmax(180px, 1fr))',
1326
+ gap: '10px'
1327
+ }}>
1328
+ {currentTestChannel.models.split(',')
1329
+ .filter(model => model.toLowerCase().includes(modelSearchKeyword.toLowerCase()))
1330
+ .map((model, index) => {
1331
+
1332
+ return (
1333
+ <Button
1334
+ key={index}
1335
+ theme="light"
1336
+ type="tertiary"
1337
+ style={{
1338
+ height: 'auto',
1339
+ padding: '8px 12px',
1340
+ textAlign: 'center',
1341
+ }}
1342
+ onClick={() => {
1343
+ testChannel(currentTestChannel, model);
1344
+ }}
1345
+ >
1346
+ {model}
1347
+ </Button>
1348
+ );
1349
+ })}
1350
+ </div>
1351
+
1352
+ {/* 显示搜索结果数量 */}
1353
+ {modelSearchKeyword && (
1354
+ <Typography.Text type="secondary" style={{ marginTop: '16px', display: 'block' }}>
1355
+ {t('找到')} {currentTestChannel.models.split(',').filter(model =>
1356
+ model.toLowerCase().includes(modelSearchKeyword.toLowerCase())
1357
+ ).length} {t('个模型')}
1358
+ </Typography.Text>
1359
+ )}
1360
+ </div>
1361
+ )}
1362
+ </div>
1363
+ </Modal>
1364
  </>
1365
  );
1366
  };
web/src/components/HeaderBar.js CHANGED
@@ -21,15 +21,17 @@ import {
21
  IconUser,
22
  IconLanguage
23
  } from '@douyinfe/semi-icons';
24
- import { Avatar, Button, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui';
25
  import { stringToColor } from '../helpers/render';
26
  import Text from '@douyinfe/semi-ui/lib/es/typography/text';
27
  import { StyleContext } from '../context/Style/index.js';
 
28
 
29
  const HeaderBar = () => {
30
  const { t, i18n } = useTranslation();
31
  const [userState, userDispatch] = useContext(UserContext);
32
  const [styleState, styleDispatch] = useContext(StyleContext);
 
33
  let navigate = useNavigate();
34
  const [currentLang, setCurrentLang] = useState(i18n.language);
35
 
@@ -40,6 +42,10 @@ const HeaderBar = () => {
40
  const isNewYear =
41
  (currentDate.getMonth() === 0 && currentDate.getDate() === 1);
42
 
 
 
 
 
43
  let buttons = [
44
  {
45
  text: t('首页'),
@@ -166,7 +172,7 @@ const HeaderBar = () => {
166
  onSelect={(key) => {}}
167
  header={styleState.isMobile?{
168
  logo: (
169
- <>
170
  {
171
  !styleState.showSider ?
172
  <Button icon={<IconMenu />} theme="light" aria-label={t('展开侧边栏')} onClick={
@@ -176,13 +182,52 @@ const HeaderBar = () => {
176
  () => styleDispatch({ type: 'SET_SIDER', payload: false })
177
  } />
178
  }
179
- </>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  ),
181
  }:{
182
  logo: (
183
  <img src={logo} alt='logo' />
184
  ),
185
- text: systemName,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  }}
187
  items={buttons}
188
  footer={
@@ -266,7 +311,8 @@ const HeaderBar = () => {
266
  icon={<IconUser />}
267
  />
268
  {
269
- !styleState.isMobile && (
 
270
  <Nav.Item
271
  itemKey={'register'}
272
  text={t('注册')}
 
21
  IconUser,
22
  IconLanguage
23
  } from '@douyinfe/semi-icons';
24
+ import { Avatar, Button, Dropdown, Layout, Nav, Switch, Tag } from '@douyinfe/semi-ui';
25
  import { stringToColor } from '../helpers/render';
26
  import Text from '@douyinfe/semi-ui/lib/es/typography/text';
27
  import { StyleContext } from '../context/Style/index.js';
28
+ import { StatusContext } from '../context/Status/index.js';
29
 
30
  const HeaderBar = () => {
31
  const { t, i18n } = useTranslation();
32
  const [userState, userDispatch] = useContext(UserContext);
33
  const [styleState, styleDispatch] = useContext(StyleContext);
34
+ const [statusState, statusDispatch] = useContext(StatusContext);
35
  let navigate = useNavigate();
36
  const [currentLang, setCurrentLang] = useState(i18n.language);
37
 
 
42
  const isNewYear =
43
  (currentDate.getMonth() === 0 && currentDate.getDate() === 1);
44
 
45
+ // Check if self-use mode is enabled
46
+ const isSelfUseMode = statusState?.status?.self_use_mode_enabled || false;
47
+ const isDemoSiteMode = statusState?.status?.demo_site_enabled || false;
48
+
49
  let buttons = [
50
  {
51
  text: t('首页'),
 
172
  onSelect={(key) => {}}
173
  header={styleState.isMobile?{
174
  logo: (
175
+ <div style={{ display: 'flex', alignItems: 'center', position: 'relative' }}>
176
  {
177
  !styleState.showSider ?
178
  <Button icon={<IconMenu />} theme="light" aria-label={t('展开侧边栏')} onClick={
 
182
  () => styleDispatch({ type: 'SET_SIDER', payload: false })
183
  } />
184
  }
185
+ {(isSelfUseMode || isDemoSiteMode) && (
186
+ <Tag
187
+ color={isSelfUseMode ? 'purple' : 'blue'}
188
+ style={{
189
+ position: 'absolute',
190
+ top: '-8px',
191
+ right: '-15px',
192
+ fontSize: '0.7rem',
193
+ padding: '0 4px',
194
+ height: 'auto',
195
+ lineHeight: '1.2',
196
+ zIndex: 1,
197
+ pointerEvents: 'none'
198
+ }}
199
+ >
200
+ {isSelfUseMode ? t('自用模式') : t('演示站点')}
201
+ </Tag>
202
+ )}
203
+ </div>
204
  ),
205
  }:{
206
  logo: (
207
  <img src={logo} alt='logo' />
208
  ),
209
+ text: (
210
+ <div style={{ position: 'relative', display: 'inline-block' }}>
211
+ {systemName}
212
+ {(isSelfUseMode || isDemoSiteMode) && (
213
+ <Tag
214
+ color={isSelfUseMode ? 'purple' : 'blue'}
215
+ style={{
216
+ position: 'absolute',
217
+ top: '-10px',
218
+ right: '-25px',
219
+ fontSize: '0.7rem',
220
+ padding: '0 4px',
221
+ whiteSpace: 'nowrap',
222
+ zIndex: 1,
223
+ boxShadow: '0 0 3px rgba(255, 255, 255, 0.7)'
224
+ }}
225
+ >
226
+ {isSelfUseMode ? t('自用模式') : t('演示站点')}
227
+ </Tag>
228
+ )}
229
+ </div>
230
+ ),
231
  }}
232
  items={buttons}
233
  footer={
 
311
  icon={<IconUser />}
312
  />
313
  {
314
+ // Hide register option in self-use mode
315
+ !styleState.isMobile && !isSelfUseMode && (
316
  <Nav.Item
317
  itemKey={'register'}
318
  text={t('注册')}
web/src/components/OperationSetting.js CHANGED
@@ -60,6 +60,7 @@ const OperationSetting = () => {
60
  RetryTimes: 0,
61
  Chats: "[]",
62
  DemoSiteEnabled: false,
 
63
  AutomaticDisableKeywords: '',
64
  });
65
 
 
60
  RetryTimes: 0,
61
  Chats: "[]",
62
  DemoSiteEnabled: false,
63
+ SelfUseModeEnabled: false,
64
  AutomaticDisableKeywords: '',
65
  });
66
 
web/src/components/OtherSetting.js CHANGED
@@ -1,8 +1,10 @@
1
- import React, { useEffect, useRef, useState } from 'react';
2
- import { Banner, Button, Col, Form, Row } from '@douyinfe/semi-ui';
3
- import { API, showError, showSuccess } from '../helpers';
4
  import { marked } from 'marked';
5
  import { useTranslation } from 'react-i18next';
 
 
6
 
7
  const OtherSetting = () => {
8
  const { t } = useTranslation();
@@ -16,6 +18,7 @@ const OtherSetting = () => {
16
  });
17
  let [loading, setLoading] = useState(false);
18
  const [showUpdateModal, setShowUpdateModal] = useState(false);
 
19
  const [updateData, setUpdateData] = useState({
20
  tag_name: '',
21
  content: '',
@@ -43,6 +46,7 @@ const OtherSetting = () => {
43
  HomePageContent: false,
44
  About: false,
45
  Footer: false,
 
46
  });
47
  const handleInputChange = async (value, e) => {
48
  const name = e.target.id;
@@ -145,23 +149,48 @@ const OtherSetting = () => {
145
  }
146
  };
147
 
148
- const openGitHubRelease = () => {
149
- window.location = 'https://github.com/songquanpeng/one-api/releases/latest';
150
- };
151
-
152
  const checkUpdate = async () => {
153
- const res = await API.get(
154
- 'https://api.github.com/repos/songquanpeng/one-api/releases/latest',
155
- );
156
- const { tag_name, body } = res.data;
157
- if (tag_name === process.env.REACT_APP_VERSION) {
158
- showSuccess(`已是最新版本:${tag_name}`);
159
- } else {
160
- setUpdateData({
161
- tag_name: tag_name,
162
- content: marked.parse(body),
163
- });
164
- setShowUpdateModal(true);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  }
166
  };
167
  const getOptions = async () => {
@@ -186,9 +215,41 @@ const OtherSetting = () => {
186
  getOptions();
187
  }, []);
188
 
 
 
 
 
 
 
 
 
 
 
189
  return (
190
  <Row>
191
  <Col span={24}>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  {/* 通用设置 */}
193
  <Form
194
  values={inputs}
@@ -282,28 +343,25 @@ const OtherSetting = () => {
282
  </Form.Section>
283
  </Form>
284
  </Col>
285
- {/*<Modal*/}
286
- {/* onClose={() => setShowUpdateModal(false)}*/}
287
- {/* onOpen={() => setShowUpdateModal(true)}*/}
288
- {/* open={showUpdateModal}*/}
289
- {/*>*/}
290
- {/* <Modal.Header>新版本:{updateData.tag_name}</Modal.Header>*/}
291
- {/* <Modal.Content>*/}
292
- {/* <Modal.Description>*/}
293
- {/* <div dangerouslySetInnerHTML={{ __html: updateData.content }}></div>*/}
294
- {/* </Modal.Description>*/}
295
- {/* </Modal.Content>*/}
296
- {/* <Modal.Actions>*/}
297
- {/* <Button onClick={() => setShowUpdateModal(false)}>关闭</Button>*/}
298
- {/* <Button*/}
299
- {/* content='详情'*/}
300
- {/* onClick={() => {*/}
301
- {/* setShowUpdateModal(false);*/}
302
- {/* openGitHubRelease();*/}
303
- {/* }}*/}
304
- {/* />*/}
305
- {/* </Modal.Actions>*/}
306
- {/*</Modal>*/}
307
  </Row>
308
  );
309
  };
 
1
+ import React, { useContext, useEffect, useRef, useState } from 'react';
2
+ import { Banner, Button, Col, Form, Row, Modal, Space } from '@douyinfe/semi-ui';
3
+ import { API, showError, showSuccess, timestamp2string } from '../helpers';
4
  import { marked } from 'marked';
5
  import { useTranslation } from 'react-i18next';
6
+ import { StatusContext } from '../context/Status/index.js';
7
+ import Text from '@douyinfe/semi-ui/lib/es/typography/text';
8
 
9
  const OtherSetting = () => {
10
  const { t } = useTranslation();
 
18
  });
19
  let [loading, setLoading] = useState(false);
20
  const [showUpdateModal, setShowUpdateModal] = useState(false);
21
+ const [statusState, statusDispatch] = useContext(StatusContext);
22
  const [updateData, setUpdateData] = useState({
23
  tag_name: '',
24
  content: '',
 
46
  HomePageContent: false,
47
  About: false,
48
  Footer: false,
49
+ CheckUpdate: false
50
  });
51
  const handleInputChange = async (value, e) => {
52
  const name = e.target.id;
 
149
  }
150
  };
151
 
 
 
 
 
152
  const checkUpdate = async () => {
153
+ try {
154
+ setLoadingInput((loadingInput) => ({ ...loadingInput, CheckUpdate: true }));
155
+ // Use a CORS proxy to avoid direct cross-origin requests to GitHub API
156
+ // Option 1: Use a public CORS proxy service
157
+ // const proxyUrl = 'https://cors-anywhere.herokuapp.com/';
158
+ // const res = await API.get(
159
+ // `${proxyUrl}https://api.github.com/repos/Calcium-Ion/new-api/releases/latest`,
160
+ // );
161
+
162
+ // Option 2: Use the JSON proxy approach which often works better with GitHub API
163
+ const res = await fetch(
164
+ 'https://api.github.com/repos/Calcium-Ion/new-api/releases/latest',
165
+ {
166
+ headers: {
167
+ 'Accept': 'application/json',
168
+ 'Content-Type': 'application/json',
169
+ // Adding User-Agent which is often required by GitHub API
170
+ 'User-Agent': 'new-api-update-checker'
171
+ }
172
+ }
173
+ ).then(response => response.json());
174
+
175
+ // Option 3: Use a local proxy endpoint
176
+ // Create a cached version of the response to avoid frequent GitHub API calls
177
+ // const res = await API.get('/api/status/github-latest-release');
178
+
179
+ const { tag_name, body } = res;
180
+ if (tag_name === statusState?.status?.version) {
181
+ showSuccess(`已是最新版本:${tag_name}`);
182
+ } else {
183
+ setUpdateData({
184
+ tag_name: tag_name,
185
+ content: marked.parse(body),
186
+ });
187
+ setShowUpdateModal(true);
188
+ }
189
+ } catch (error) {
190
+ console.error('Failed to check for updates:', error);
191
+ showError('检查更新失败,请稍后再试');
192
+ } finally {
193
+ setLoadingInput((loadingInput) => ({ ...loadingInput, CheckUpdate: false }));
194
  }
195
  };
196
  const getOptions = async () => {
 
215
  getOptions();
216
  }, []);
217
 
218
+ // Function to open GitHub release page
219
+ const openGitHubRelease = () => {
220
+ window.open(`https://github.com/Calcium-Ion/new-api/releases/tag/${updateData.tag_name}`, '_blank');
221
+ };
222
+
223
+ const getStartTimeString = () => {
224
+ const timestamp = statusState?.status?.start_time;
225
+ return statusState.status ? timestamp2string(timestamp) : '';
226
+ };
227
+
228
  return (
229
  <Row>
230
  <Col span={24}>
231
+ {/* 版本信息 */}
232
+ <Form style={{ marginBottom: 15 }}>
233
+ <Form.Section text={t('系统信息')}>
234
+ <Row>
235
+ <Col span={16}>
236
+ <Space>
237
+ <Text>
238
+ {t('当前版本')}:{statusState?.status?.version || t('未知')}
239
+ </Text>
240
+ <Button type="primary" onClick={checkUpdate} loading={loadingInput['CheckUpdate']}>
241
+ {t('检查更新')}
242
+ </Button>
243
+ </Space>
244
+ </Col>
245
+ </Row>
246
+ <Row>
247
+ <Col span={16}>
248
+ <Text>{t('启动时间')}:{getStartTimeString()}</Text>
249
+ </Col>
250
+ </Row>
251
+ </Form.Section>
252
+ </Form>
253
  {/* 通用设置 */}
254
  <Form
255
  values={inputs}
 
343
  </Form.Section>
344
  </Form>
345
  </Col>
346
+ <Modal
347
+ title={t('新版本') + ':' + updateData.tag_name}
348
+ visible={showUpdateModal}
349
+ onCancel={() => setShowUpdateModal(false)}
350
+ footer={[
351
+ <Button
352
+ key="details"
353
+ type="primary"
354
+ onClick={() => {
355
+ setShowUpdateModal(false);
356
+ openGitHubRelease();
357
+ }}
358
+ >
359
+ {t('详情')}
360
+ </Button>
361
+ ]}
362
+ >
363
+ <div dangerouslySetInnerHTML={{ __html: updateData.content }}></div>
364
+ </Modal>
 
 
 
365
  </Row>
366
  );
367
  };