25:03:07 18:16:42 v0.4.9.0
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.en.md +2 -2
- VERSION +1 -1
- common/gopool.go +24 -0
- controller/channel-test.go +26 -33
- controller/misc.go +3 -1
- controller/pricing.go +2 -3
- controller/relay.go +3 -11
- middleware/model-rate-limit.go +19 -16
- model/channel.go +18 -10
- model/log.go +3 -2
- model/option.go +16 -12
- model/pricing.go +4 -3
- relay/channel/ali/text.go +2 -1
- relay/channel/aws/dto.go +1 -1
- relay/channel/aws/relay-aws.go +4 -3
- relay/channel/baidu/relay-baidu.go +2 -1
- relay/channel/claude/dto.go +1 -1
- relay/channel/claude/relay-claude.go +13 -22
- relay/channel/cloudflare/relay_cloudflare.go +8 -7
- relay/channel/cohere/relay-cohere.go +2 -1
- relay/channel/dify/relay-dify.go +4 -3
- relay/channel/gemini/relay-gemini.go +18 -29
- relay/channel/openai/relay-openai.go +47 -79
- relay/channel/palm/relay-palm.go +2 -1
- relay/channel/tencent/relay-tencent.go +4 -3
- relay/channel/vertex/adaptor.go +6 -9
- relay/channel/vertex/dto.go +24 -4
- relay/channel/xunfei/relay-xunfei.go +2 -1
- relay/channel/zhipu/relay-zhipu.go +2 -1
- relay/channel/zhipu_4v/relay-zhipu_v4.go +2 -1
- relay/common/relay_info.go +31 -21
- service/relay.go → relay/helper/common.go +1 -1
- relay/helper/price.go +11 -3
- relay/helper/stream_scanner.go +91 -0
- relay/relay-mj.go +4 -4
- relay/relay-text.go +1 -1
- relay/relay_task.go +2 -2
- relay/websocket.go +2 -2
- service/channel.go +18 -10
- service/image.go +44 -11
- service/quota.go +10 -10
- service/token_counter.go +2 -1
- service/user_notify.go +4 -1
- {common → setting}/model-ratio.go +9 -8
- setting/{operation_setting.go → operation_setting/operation_setting.go} +2 -1
- web/src/App.js +11 -0
- web/src/components/ChannelsTable.js +84 -12
- web/src/components/HeaderBar.js +51 -5
- web/src/components/OperationSetting.js +1 -0
- web/src/components/OtherSetting.js +99 -41
README.en.md
CHANGED
|
@@ -68,7 +68,7 @@
|
|
| 68 |
|
| 69 |
## Model Support
|
| 70 |
This version additionally supports:
|
| 71 |
-
1. Third-party model **
|
| 72 |
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md)
|
| 73 |
3. Custom channels with full API URL support
|
| 74 |
4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md)
|
|
@@ -162,7 +162,7 @@ docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtow
|
|
| 162 |
|
| 163 |
## Channel Retry
|
| 164 |
Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
|
| 165 |
-
|
| 166 |
|
| 167 |
### Cache Configuration
|
| 168 |
1. `REDIS_CONN_STRING`: Use Redis as cache
|
|
|
|
| 68 |
|
| 69 |
## Model Support
|
| 70 |
This version additionally supports:
|
| 71 |
+
1. Third-party model **gpts** (gpt-4-gizmo-*)
|
| 72 |
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md)
|
| 73 |
3. Custom channels with full API URL support
|
| 74 |
4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md)
|
|
|
|
| 162 |
|
| 163 |
## Channel Retry
|
| 164 |
Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
|
| 165 |
+
If retry is enabled, the system will automatically use the next priority channel for the same request after a failed request.
|
| 166 |
|
| 167 |
### Cache Configuration
|
| 168 |
1. `REDIS_CONN_STRING`: Use Redis as cache
|
VERSION
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
v0.4.
|
|
|
|
| 1 |
+
v0.4.9.0
|
common/gopool.go
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package common
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"fmt"
|
| 6 |
+
"github.com/bytedance/gopkg/util/gopool"
|
| 7 |
+
"math"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
var relayGoPool gopool.Pool
|
| 11 |
+
|
| 12 |
+
func init() {
|
| 13 |
+
relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
|
| 14 |
+
relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
|
| 15 |
+
if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
|
| 16 |
+
SafeSendBool(stopChan, true)
|
| 17 |
+
}
|
| 18 |
+
SysError(fmt.Sprintf("panic in gopool.RelayPool: %v", i))
|
| 19 |
+
})
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
func RelayCtxGo(ctx context.Context, f func()) {
|
| 23 |
+
relayGoPool.CtxGo(ctx, f)
|
| 24 |
+
}
|
controller/channel-test.go
CHANGED
|
@@ -17,6 +17,7 @@ import (
|
|
| 17 |
"one-api/relay"
|
| 18 |
relaycommon "one-api/relay/common"
|
| 19 |
"one-api/relay/constant"
|
|
|
|
| 20 |
"one-api/service"
|
| 21 |
"strconv"
|
| 22 |
"strings"
|
|
@@ -72,18 +73,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|
| 72 |
}
|
| 73 |
}
|
| 74 |
|
| 75 |
-
modelMapping := *channel.ModelMapping
|
| 76 |
-
if modelMapping != "" && modelMapping != "{}" {
|
| 77 |
-
modelMap := make(map[string]string)
|
| 78 |
-
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
| 79 |
-
if err != nil {
|
| 80 |
-
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
| 81 |
-
}
|
| 82 |
-
if modelMap[testModel] != "" {
|
| 83 |
-
testModel = modelMap[testModel]
|
| 84 |
-
}
|
| 85 |
-
}
|
| 86 |
-
|
| 87 |
cache, err := model.GetUserCache(1)
|
| 88 |
if err != nil {
|
| 89 |
return err, nil
|
|
@@ -97,7 +86,14 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|
| 97 |
|
| 98 |
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
| 99 |
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
| 102 |
adaptor := relay.GetAdaptor(apiType)
|
| 103 |
if adaptor == nil {
|
|
@@ -105,12 +101,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|
| 105 |
}
|
| 106 |
|
| 107 |
request := buildTestRequest(testModel)
|
| 108 |
-
|
| 109 |
-
common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
|
| 110 |
|
| 111 |
-
adaptor.Init(
|
| 112 |
|
| 113 |
-
convertedRequest, err := adaptor.ConvertRequest(c,
|
| 114 |
if err != nil {
|
| 115 |
return err, nil
|
| 116 |
}
|
|
@@ -120,7 +115,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|
| 120 |
}
|
| 121 |
requestBody := bytes.NewBuffer(jsonData)
|
| 122 |
c.Request.Body = io.NopCloser(requestBody)
|
| 123 |
-
resp, err := adaptor.DoRequest(c,
|
| 124 |
if err != nil {
|
| 125 |
return err, nil
|
| 126 |
}
|
|
@@ -132,7 +127,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|
| 132 |
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
| 133 |
}
|
| 134 |
}
|
| 135 |
-
usageA, respErr := adaptor.DoResponse(c, httpResp,
|
| 136 |
if respErr != nil {
|
| 137 |
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
| 138 |
}
|
|
@@ -145,29 +140,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|
| 145 |
if err != nil {
|
| 146 |
return err, nil
|
| 147 |
}
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
if
|
| 151 |
-
return
|
| 152 |
}
|
| 153 |
-
completionRatio := common.GetCompletionRatio(testModel)
|
| 154 |
-
ratio := modelRatio
|
| 155 |
quota := 0
|
| 156 |
-
if !
|
| 157 |
-
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*
|
| 158 |
-
quota = int(math.Round(float64(quota) *
|
| 159 |
-
if
|
| 160 |
quota = 1
|
| 161 |
}
|
| 162 |
} else {
|
| 163 |
-
quota = int(
|
| 164 |
}
|
| 165 |
tok := time.Now()
|
| 166 |
milliseconds := tok.Sub(tik).Milliseconds()
|
| 167 |
consumedTime := float64(milliseconds) / 1000.0
|
| 168 |
-
other := service.GenerateTextOtherInfo(c,
|
| 169 |
-
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens,
|
| 170 |
-
quota, "模型测试", 0, quota, int(consumedTime), false,
|
| 171 |
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
| 172 |
return nil, nil
|
| 173 |
}
|
|
|
|
| 17 |
"one-api/relay"
|
| 18 |
relaycommon "one-api/relay/common"
|
| 19 |
"one-api/relay/constant"
|
| 20 |
+
"one-api/relay/helper"
|
| 21 |
"one-api/service"
|
| 22 |
"strconv"
|
| 23 |
"strings"
|
|
|
|
| 73 |
}
|
| 74 |
}
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
cache, err := model.GetUserCache(1)
|
| 77 |
if err != nil {
|
| 78 |
return err, nil
|
|
|
|
| 86 |
|
| 87 |
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
| 88 |
|
| 89 |
+
info := relaycommon.GenRelayInfo(c)
|
| 90 |
+
|
| 91 |
+
err = helper.ModelMappedHelper(c, info)
|
| 92 |
+
if err != nil {
|
| 93 |
+
return err, nil
|
| 94 |
+
}
|
| 95 |
+
testModel = info.UpstreamModelName
|
| 96 |
+
|
| 97 |
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
| 98 |
adaptor := relay.GetAdaptor(apiType)
|
| 99 |
if adaptor == nil {
|
|
|
|
| 101 |
}
|
| 102 |
|
| 103 |
request := buildTestRequest(testModel)
|
| 104 |
+
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
|
|
|
|
| 105 |
|
| 106 |
+
adaptor.Init(info)
|
| 107 |
|
| 108 |
+
convertedRequest, err := adaptor.ConvertRequest(c, info, request)
|
| 109 |
if err != nil {
|
| 110 |
return err, nil
|
| 111 |
}
|
|
|
|
| 115 |
}
|
| 116 |
requestBody := bytes.NewBuffer(jsonData)
|
| 117 |
c.Request.Body = io.NopCloser(requestBody)
|
| 118 |
+
resp, err := adaptor.DoRequest(c, info, requestBody)
|
| 119 |
if err != nil {
|
| 120 |
return err, nil
|
| 121 |
}
|
|
|
|
| 127 |
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
| 128 |
}
|
| 129 |
}
|
| 130 |
+
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
| 131 |
if respErr != nil {
|
| 132 |
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
| 133 |
}
|
|
|
|
| 140 |
if err != nil {
|
| 141 |
return err, nil
|
| 142 |
}
|
| 143 |
+
info.PromptTokens = usage.PromptTokens
|
| 144 |
+
priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens))
|
| 145 |
+
if err != nil {
|
| 146 |
+
return err, nil
|
| 147 |
}
|
|
|
|
|
|
|
| 148 |
quota := 0
|
| 149 |
+
if !priceData.UsePrice {
|
| 150 |
+
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
| 151 |
+
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
|
| 152 |
+
if priceData.ModelRatio != 0 && quota <= 0 {
|
| 153 |
quota = 1
|
| 154 |
}
|
| 155 |
} else {
|
| 156 |
+
quota = int(priceData.ModelPrice * common.QuotaPerUnit)
|
| 157 |
}
|
| 158 |
tok := time.Now()
|
| 159 |
milliseconds := tok.Sub(tik).Milliseconds()
|
| 160 |
consumedTime := float64(milliseconds) / 1000.0
|
| 161 |
+
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, priceData.ModelPrice)
|
| 162 |
+
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
| 163 |
+
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
|
| 164 |
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
| 165 |
return nil, nil
|
| 166 |
}
|
controller/misc.go
CHANGED
|
@@ -7,6 +7,7 @@ import (
|
|
| 7 |
"one-api/common"
|
| 8 |
"one-api/model"
|
| 9 |
"one-api/setting"
|
|
|
|
| 10 |
"strings"
|
| 11 |
|
| 12 |
"github.com/gin-gonic/gin"
|
|
@@ -66,7 +67,8 @@ func GetStatus(c *gin.Context) {
|
|
| 66 |
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
| 67 |
"mj_notify_enabled": setting.MjNotifyEnabled,
|
| 68 |
"chats": setting.Chats,
|
| 69 |
-
"demo_site_enabled":
|
|
|
|
| 70 |
},
|
| 71 |
})
|
| 72 |
return
|
|
|
|
| 7 |
"one-api/common"
|
| 8 |
"one-api/model"
|
| 9 |
"one-api/setting"
|
| 10 |
+
"one-api/setting/operation_setting"
|
| 11 |
"strings"
|
| 12 |
|
| 13 |
"github.com/gin-gonic/gin"
|
|
|
|
| 67 |
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
| 68 |
"mj_notify_enabled": setting.MjNotifyEnabled,
|
| 69 |
"chats": setting.Chats,
|
| 70 |
+
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
| 71 |
+
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
| 72 |
},
|
| 73 |
})
|
| 74 |
return
|
controller/pricing.go
CHANGED
|
@@ -2,7 +2,6 @@ package controller
|
|
| 2 |
|
| 3 |
import (
|
| 4 |
"github.com/gin-gonic/gin"
|
| 5 |
-
"one-api/common"
|
| 6 |
"one-api/model"
|
| 7 |
"one-api/setting"
|
| 8 |
)
|
|
@@ -40,7 +39,7 @@ func GetPricing(c *gin.Context) {
|
|
| 40 |
}
|
| 41 |
|
| 42 |
func ResetModelRatio(c *gin.Context) {
|
| 43 |
-
defaultStr :=
|
| 44 |
err := model.UpdateOption("ModelRatio", defaultStr)
|
| 45 |
if err != nil {
|
| 46 |
c.JSON(200, gin.H{
|
|
@@ -49,7 +48,7 @@ func ResetModelRatio(c *gin.Context) {
|
|
| 49 |
})
|
| 50 |
return
|
| 51 |
}
|
| 52 |
-
err =
|
| 53 |
if err != nil {
|
| 54 |
c.JSON(200, gin.H{
|
| 55 |
"success": false,
|
|
|
|
| 2 |
|
| 3 |
import (
|
| 4 |
"github.com/gin-gonic/gin"
|
|
|
|
| 5 |
"one-api/model"
|
| 6 |
"one-api/setting"
|
| 7 |
)
|
|
|
|
| 39 |
}
|
| 40 |
|
| 41 |
func ResetModelRatio(c *gin.Context) {
|
| 42 |
+
defaultStr := setting.DefaultModelRatio2JSONString()
|
| 43 |
err := model.UpdateOption("ModelRatio", defaultStr)
|
| 44 |
if err != nil {
|
| 45 |
c.JSON(200, gin.H{
|
|
|
|
| 48 |
})
|
| 49 |
return
|
| 50 |
}
|
| 51 |
+
err = setting.UpdateModelRatioByJSONString(defaultStr)
|
| 52 |
if err != nil {
|
| 53 |
c.JSON(200, gin.H{
|
| 54 |
"success": false,
|
controller/relay.go
CHANGED
|
@@ -16,6 +16,7 @@ import (
|
|
| 16 |
"one-api/relay"
|
| 17 |
"one-api/relay/constant"
|
| 18 |
relayconstant "one-api/relay/constant"
|
|
|
|
| 19 |
"one-api/service"
|
| 20 |
"strings"
|
| 21 |
)
|
|
@@ -41,15 +42,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|
| 41 |
return err
|
| 42 |
}
|
| 43 |
|
| 44 |
-
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
| 45 |
-
var err *dto.OpenAIErrorWithStatusCode
|
| 46 |
-
switch relayMode {
|
| 47 |
-
default:
|
| 48 |
-
err = relay.TextHelper(c)
|
| 49 |
-
}
|
| 50 |
-
return err
|
| 51 |
-
}
|
| 52 |
-
|
| 53 |
func Relay(c *gin.Context) {
|
| 54 |
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
| 55 |
requestId := c.GetString(common.RequestIdKey)
|
|
@@ -110,7 +102,7 @@ func WssRelay(c *gin.Context) {
|
|
| 110 |
|
| 111 |
if err != nil {
|
| 112 |
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
| 113 |
-
|
| 114 |
return
|
| 115 |
}
|
| 116 |
|
|
@@ -152,7 +144,7 @@ func WssRelay(c *gin.Context) {
|
|
| 152 |
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
| 153 |
}
|
| 154 |
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
| 155 |
-
|
| 156 |
}
|
| 157 |
}
|
| 158 |
|
|
|
|
| 16 |
"one-api/relay"
|
| 17 |
"one-api/relay/constant"
|
| 18 |
relayconstant "one-api/relay/constant"
|
| 19 |
+
"one-api/relay/helper"
|
| 20 |
"one-api/service"
|
| 21 |
"strings"
|
| 22 |
)
|
|
|
|
| 42 |
return err
|
| 43 |
}
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
func Relay(c *gin.Context) {
|
| 46 |
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
| 47 |
requestId := c.GetString(common.RequestIdKey)
|
|
|
|
| 102 |
|
| 103 |
if err != nil {
|
| 104 |
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
| 105 |
+
helper.WssError(c, ws, openaiErr.Error)
|
| 106 |
return
|
| 107 |
}
|
| 108 |
|
|
|
|
| 144 |
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
| 145 |
}
|
| 146 |
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
| 147 |
+
helper.WssError(c, ws, openaiErr.Error)
|
| 148 |
}
|
| 149 |
}
|
| 150 |
|
middleware/model-rate-limit.go
CHANGED
|
@@ -51,7 +51,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
|
|
| 51 |
// 如果在时间窗口内已达到限制,拒绝请求
|
| 52 |
subTime := nowTime.Sub(oldTime).Seconds()
|
| 53 |
if int64(subTime) < duration {
|
| 54 |
-
rdb.Expire(ctx, key,
|
| 55 |
return false, nil
|
| 56 |
}
|
| 57 |
|
|
@@ -68,7 +68,7 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC
|
|
| 68 |
now := time.Now().Format(timeFormat)
|
| 69 |
rdb.LPush(ctx, key, now)
|
| 70 |
rdb.LTrim(ctx, key, 0, int64(maxCount-1))
|
| 71 |
-
rdb.Expire(ctx, key,
|
| 72 |
}
|
| 73 |
|
| 74 |
// Redis限流处理器
|
|
@@ -118,7 +118,7 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
|
|
| 118 |
|
| 119 |
// 内存限流处理器
|
| 120 |
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
|
| 121 |
-
inMemoryRateLimiter.Init(
|
| 122 |
|
| 123 |
return func(c *gin.Context) {
|
| 124 |
userId := strconv.Itoa(c.GetInt("id"))
|
|
@@ -153,20 +153,23 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
|
|
| 153 |
|
| 154 |
// ModelRequestRateLimit 模型请求限流中间件
|
| 155 |
func ModelRequestRateLimit() func(c *gin.Context) {
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
| 171 |
}
|
| 172 |
}
|
|
|
|
| 51 |
// 如果在时间窗口内已达到限制,拒绝请求
|
| 52 |
subTime := nowTime.Sub(oldTime).Seconds()
|
| 53 |
if int64(subTime) < duration {
|
| 54 |
+
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
|
| 55 |
return false, nil
|
| 56 |
}
|
| 57 |
|
|
|
|
| 68 |
now := time.Now().Format(timeFormat)
|
| 69 |
rdb.LPush(ctx, key, now)
|
| 70 |
rdb.LTrim(ctx, key, 0, int64(maxCount-1))
|
| 71 |
+
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
|
| 72 |
}
|
| 73 |
|
| 74 |
// Redis限流处理器
|
|
|
|
| 118 |
|
| 119 |
// 内存限流处理器
|
| 120 |
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
|
| 121 |
+
inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
|
| 122 |
|
| 123 |
return func(c *gin.Context) {
|
| 124 |
userId := strconv.Itoa(c.GetInt("id"))
|
|
|
|
| 153 |
|
| 154 |
// ModelRequestRateLimit 模型请求限流中间件
|
| 155 |
func ModelRequestRateLimit() func(c *gin.Context) {
|
| 156 |
+
return func(c *gin.Context) {
|
| 157 |
+
// 在每个请求时检查是否启用限流
|
| 158 |
+
if !setting.ModelRequestRateLimitEnabled {
|
| 159 |
+
c.Next()
|
| 160 |
+
return
|
| 161 |
+
}
|
| 162 |
|
| 163 |
+
// 计算限流参数
|
| 164 |
+
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
|
| 165 |
+
totalMaxCount := setting.ModelRequestRateLimitCount
|
| 166 |
+
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
| 167 |
|
| 168 |
+
// 根据存储类型选择并执行限流处理器
|
| 169 |
+
if common.RedisEnabled {
|
| 170 |
+
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
| 171 |
+
} else {
|
| 172 |
+
memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
| 173 |
+
}
|
| 174 |
}
|
| 175 |
}
|
model/channel.go
CHANGED
|
@@ -290,35 +290,42 @@ func (channel *Channel) Delete() error {
|
|
| 290 |
|
| 291 |
var channelStatusLock sync.Mutex
|
| 292 |
|
| 293 |
-
func UpdateChannelStatusById(id int, status int, reason string) {
|
| 294 |
if common.MemoryCacheEnabled {
|
| 295 |
channelStatusLock.Lock()
|
|
|
|
|
|
|
| 296 |
channelCache, _ := CacheGetChannel(id)
|
| 297 |
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
| 298 |
if channelCache != nil && channelCache.Status == status {
|
| 299 |
-
|
| 300 |
-
return
|
| 301 |
}
|
| 302 |
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
| 303 |
if channelCache == nil && status != common.ChannelStatusEnabled {
|
| 304 |
-
|
| 305 |
-
return
|
| 306 |
}
|
| 307 |
CacheUpdateChannelStatus(id, status)
|
| 308 |
-
channelStatusLock.Unlock()
|
| 309 |
}
|
| 310 |
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
| 311 |
if err != nil {
|
| 312 |
common.SysError("failed to update ability status: " + err.Error())
|
|
|
|
| 313 |
}
|
| 314 |
channel, err := GetChannelById(id, true)
|
| 315 |
if err != nil {
|
| 316 |
// find channel by id error, directly update status
|
| 317 |
-
|
| 318 |
-
if
|
| 319 |
-
common.SysError("failed to update channel status: " +
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
}
|
| 321 |
} else {
|
|
|
|
|
|
|
|
|
|
| 322 |
// find channel by id success, update status and other info
|
| 323 |
info := channel.GetOtherInfo()
|
| 324 |
info["status_reason"] = reason
|
|
@@ -328,9 +335,10 @@ func UpdateChannelStatusById(id int, status int, reason string) {
|
|
| 328 |
err = channel.Save()
|
| 329 |
if err != nil {
|
| 330 |
common.SysError("failed to update channel status: " + err.Error())
|
|
|
|
| 331 |
}
|
| 332 |
}
|
| 333 |
-
|
| 334 |
}
|
| 335 |
|
| 336 |
func EnableChannelByTag(tag string) error {
|
|
|
|
| 290 |
|
| 291 |
var channelStatusLock sync.Mutex
|
| 292 |
|
| 293 |
+
func UpdateChannelStatusById(id int, status int, reason string) bool {
|
| 294 |
if common.MemoryCacheEnabled {
|
| 295 |
channelStatusLock.Lock()
|
| 296 |
+
defer channelStatusLock.Unlock()
|
| 297 |
+
|
| 298 |
channelCache, _ := CacheGetChannel(id)
|
| 299 |
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
| 300 |
if channelCache != nil && channelCache.Status == status {
|
| 301 |
+
return false
|
|
|
|
| 302 |
}
|
| 303 |
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
|
| 304 |
if channelCache == nil && status != common.ChannelStatusEnabled {
|
| 305 |
+
return false
|
|
|
|
| 306 |
}
|
| 307 |
CacheUpdateChannelStatus(id, status)
|
|
|
|
| 308 |
}
|
| 309 |
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
| 310 |
if err != nil {
|
| 311 |
common.SysError("failed to update ability status: " + err.Error())
|
| 312 |
+
return false
|
| 313 |
}
|
| 314 |
channel, err := GetChannelById(id, true)
|
| 315 |
if err != nil {
|
| 316 |
// find channel by id error, directly update status
|
| 317 |
+
result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
|
| 318 |
+
if result.Error != nil {
|
| 319 |
+
common.SysError("failed to update channel status: " + result.Error.Error())
|
| 320 |
+
return false
|
| 321 |
+
}
|
| 322 |
+
if result.RowsAffected == 0 {
|
| 323 |
+
return false
|
| 324 |
}
|
| 325 |
} else {
|
| 326 |
+
if channel.Status == status {
|
| 327 |
+
return false
|
| 328 |
+
}
|
| 329 |
// find channel by id success, update status and other info
|
| 330 |
info := channel.GetOtherInfo()
|
| 331 |
info["status_reason"] = reason
|
|
|
|
| 335 |
err = channel.Save()
|
| 336 |
if err != nil {
|
| 337 |
common.SysError("failed to update channel status: " + err.Error())
|
| 338 |
+
return false
|
| 339 |
}
|
| 340 |
}
|
| 341 |
+
return true
|
| 342 |
}
|
| 343 |
|
| 344 |
func EnableChannelByTag(tag string) error {
|
model/log.go
CHANGED
|
@@ -2,12 +2,13 @@ package model
|
|
| 2 |
|
| 3 |
import (
|
| 4 |
"fmt"
|
| 5 |
-
"github.com/gin-gonic/gin"
|
| 6 |
"one-api/common"
|
| 7 |
"os"
|
| 8 |
"strings"
|
| 9 |
"time"
|
| 10 |
|
|
|
|
|
|
|
| 11 |
"github.com/bytedance/gopkg/util/gopool"
|
| 12 |
"gorm.io/gorm"
|
| 13 |
)
|
|
@@ -18,7 +19,7 @@ type Log struct {
|
|
| 18 |
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
|
| 19 |
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
| 20 |
Content string `json:"content"`
|
| 21 |
-
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
|
| 22 |
TokenName string `json:"token_name" gorm:"index;default:''"`
|
| 23 |
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
| 24 |
Quota int `json:"quota" gorm:"default:0"`
|
|
|
|
| 2 |
|
| 3 |
import (
|
| 4 |
"fmt"
|
|
|
|
| 5 |
"one-api/common"
|
| 6 |
"os"
|
| 7 |
"strings"
|
| 8 |
"time"
|
| 9 |
|
| 10 |
+
"github.com/gin-gonic/gin"
|
| 11 |
+
|
| 12 |
"github.com/bytedance/gopkg/util/gopool"
|
| 13 |
"gorm.io/gorm"
|
| 14 |
)
|
|
|
|
| 19 |
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
|
| 20 |
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
| 21 |
Content string `json:"content"`
|
| 22 |
+
Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
|
| 23 |
TokenName string `json:"token_name" gorm:"index;default:''"`
|
| 24 |
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
| 25 |
Quota int `json:"quota" gorm:"default:0"`
|
model/option.go
CHANGED
|
@@ -4,6 +4,7 @@ import (
|
|
| 4 |
"one-api/common"
|
| 5 |
"one-api/setting"
|
| 6 |
"one-api/setting/config"
|
|
|
|
| 7 |
"strconv"
|
| 8 |
"strings"
|
| 9 |
"time"
|
|
@@ -87,15 +88,15 @@ func InitOptionMap() {
|
|
| 87 |
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
|
| 88 |
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
|
| 89 |
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
|
| 90 |
-
common.OptionMap["
|
| 91 |
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
|
| 92 |
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
| 93 |
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
| 94 |
-
common.OptionMap["ModelRatio"] =
|
| 95 |
-
common.OptionMap["ModelPrice"] =
|
| 96 |
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
| 97 |
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
| 98 |
-
common.OptionMap["CompletionRatio"] =
|
| 99 |
common.OptionMap["TopUpLink"] = common.TopUpLink
|
| 100 |
common.OptionMap["ChatLink"] = common.ChatLink
|
| 101 |
common.OptionMap["ChatLink2"] = common.ChatLink2
|
|
@@ -110,13 +111,14 @@ func InitOptionMap() {
|
|
| 110 |
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
|
| 111 |
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
|
| 112 |
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
|
| 113 |
-
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(
|
|
|
|
| 114 |
common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
|
| 115 |
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
|
| 116 |
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
|
| 117 |
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
| 118 |
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
| 119 |
-
common.OptionMap["AutomaticDisableKeywords"] =
|
| 120 |
|
| 121 |
// 自动添加所有注册的模型配置
|
| 122 |
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
|
@@ -242,7 +244,9 @@ func updateOptionMap(key string, value string) (err error) {
|
|
| 242 |
case "CheckSensitiveEnabled":
|
| 243 |
setting.CheckSensitiveEnabled = boolValue
|
| 244 |
case "DemoSiteEnabled":
|
| 245 |
-
|
|
|
|
|
|
|
| 246 |
case "CheckSensitiveOnPromptEnabled":
|
| 247 |
setting.CheckSensitiveOnPromptEnabled = boolValue
|
| 248 |
case "ModelRequestRateLimitEnabled":
|
|
@@ -325,7 +329,7 @@ func updateOptionMap(key string, value string) (err error) {
|
|
| 325 |
common.QuotaForInvitee, _ = strconv.Atoi(value)
|
| 326 |
case "QuotaRemindThreshold":
|
| 327 |
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
|
| 328 |
-
case "
|
| 329 |
common.PreConsumedQuota, _ = strconv.Atoi(value)
|
| 330 |
case "ModelRequestRateLimitCount":
|
| 331 |
setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
|
|
@@ -340,15 +344,15 @@ func updateOptionMap(key string, value string) (err error) {
|
|
| 340 |
case "DataExportDefaultTime":
|
| 341 |
common.DataExportDefaultTime = value
|
| 342 |
case "ModelRatio":
|
| 343 |
-
err =
|
| 344 |
case "GroupRatio":
|
| 345 |
err = setting.UpdateGroupRatioByJSONString(value)
|
| 346 |
case "UserUsableGroups":
|
| 347 |
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
| 348 |
case "CompletionRatio":
|
| 349 |
-
err =
|
| 350 |
case "ModelPrice":
|
| 351 |
-
err =
|
| 352 |
case "TopUpLink":
|
| 353 |
common.TopUpLink = value
|
| 354 |
case "ChatLink":
|
|
@@ -362,7 +366,7 @@ func updateOptionMap(key string, value string) (err error) {
|
|
| 362 |
case "SensitiveWords":
|
| 363 |
setting.SensitiveWordsFromString(value)
|
| 364 |
case "AutomaticDisableKeywords":
|
| 365 |
-
|
| 366 |
case "StreamCacheQueueLength":
|
| 367 |
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
| 368 |
}
|
|
|
|
| 4 |
"one-api/common"
|
| 5 |
"one-api/setting"
|
| 6 |
"one-api/setting/config"
|
| 7 |
+
"one-api/setting/operation_setting"
|
| 8 |
"strconv"
|
| 9 |
"strings"
|
| 10 |
"time"
|
|
|
|
| 88 |
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
|
| 89 |
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
|
| 90 |
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
|
| 91 |
+
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
| 92 |
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
|
| 93 |
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
| 94 |
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
| 95 |
+
common.OptionMap["ModelRatio"] = setting.ModelRatio2JSONString()
|
| 96 |
+
common.OptionMap["ModelPrice"] = setting.ModelPrice2JSONString()
|
| 97 |
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
| 98 |
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
| 99 |
+
common.OptionMap["CompletionRatio"] = setting.CompletionRatio2JSONString()
|
| 100 |
common.OptionMap["TopUpLink"] = common.TopUpLink
|
| 101 |
common.OptionMap["ChatLink"] = common.ChatLink
|
| 102 |
common.OptionMap["ChatLink2"] = common.ChatLink2
|
|
|
|
| 111 |
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
|
| 112 |
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
|
| 113 |
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
|
| 114 |
+
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled)
|
| 115 |
+
common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled)
|
| 116 |
common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
|
| 117 |
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
|
| 118 |
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
|
| 119 |
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
| 120 |
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
| 121 |
+
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
| 122 |
|
| 123 |
// 自动添加所有注册的模型配置
|
| 124 |
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
|
|
|
| 244 |
case "CheckSensitiveEnabled":
|
| 245 |
setting.CheckSensitiveEnabled = boolValue
|
| 246 |
case "DemoSiteEnabled":
|
| 247 |
+
operation_setting.DemoSiteEnabled = boolValue
|
| 248 |
+
case "SelfUseModeEnabled":
|
| 249 |
+
operation_setting.SelfUseModeEnabled = boolValue
|
| 250 |
case "CheckSensitiveOnPromptEnabled":
|
| 251 |
setting.CheckSensitiveOnPromptEnabled = boolValue
|
| 252 |
case "ModelRequestRateLimitEnabled":
|
|
|
|
| 329 |
common.QuotaForInvitee, _ = strconv.Atoi(value)
|
| 330 |
case "QuotaRemindThreshold":
|
| 331 |
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
|
| 332 |
+
case "PreConsumedQuota":
|
| 333 |
common.PreConsumedQuota, _ = strconv.Atoi(value)
|
| 334 |
case "ModelRequestRateLimitCount":
|
| 335 |
setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
|
|
|
|
| 344 |
case "DataExportDefaultTime":
|
| 345 |
common.DataExportDefaultTime = value
|
| 346 |
case "ModelRatio":
|
| 347 |
+
err = setting.UpdateModelRatioByJSONString(value)
|
| 348 |
case "GroupRatio":
|
| 349 |
err = setting.UpdateGroupRatioByJSONString(value)
|
| 350 |
case "UserUsableGroups":
|
| 351 |
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
| 352 |
case "CompletionRatio":
|
| 353 |
+
err = setting.UpdateCompletionRatioByJSONString(value)
|
| 354 |
case "ModelPrice":
|
| 355 |
+
err = setting.UpdateModelPriceByJSONString(value)
|
| 356 |
case "TopUpLink":
|
| 357 |
common.TopUpLink = value
|
| 358 |
case "ChatLink":
|
|
|
|
| 366 |
case "SensitiveWords":
|
| 367 |
setting.SensitiveWordsFromString(value)
|
| 368 |
case "AutomaticDisableKeywords":
|
| 369 |
+
operation_setting.AutomaticDisableKeywordsFromString(value)
|
| 370 |
case "StreamCacheQueueLength":
|
| 371 |
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
| 372 |
}
|
model/pricing.go
CHANGED
|
@@ -2,6 +2,7 @@ package model
|
|
| 2 |
|
| 3 |
import (
|
| 4 |
"one-api/common"
|
|
|
|
| 5 |
"sync"
|
| 6 |
"time"
|
| 7 |
)
|
|
@@ -64,14 +65,14 @@ func updatePricing() {
|
|
| 64 |
ModelName: model,
|
| 65 |
EnableGroup: groups,
|
| 66 |
}
|
| 67 |
-
modelPrice, findPrice :=
|
| 68 |
if findPrice {
|
| 69 |
pricing.ModelPrice = modelPrice
|
| 70 |
pricing.QuotaType = 1
|
| 71 |
} else {
|
| 72 |
-
modelRatio, _ :=
|
| 73 |
pricing.ModelRatio = modelRatio
|
| 74 |
-
pricing.CompletionRatio =
|
| 75 |
pricing.QuotaType = 0
|
| 76 |
}
|
| 77 |
pricingMap = append(pricingMap, pricing)
|
|
|
|
| 2 |
|
| 3 |
import (
|
| 4 |
"one-api/common"
|
| 5 |
+
"one-api/setting"
|
| 6 |
"sync"
|
| 7 |
"time"
|
| 8 |
)
|
|
|
|
| 65 |
ModelName: model,
|
| 66 |
EnableGroup: groups,
|
| 67 |
}
|
| 68 |
+
modelPrice, findPrice := setting.GetModelPrice(model, false)
|
| 69 |
if findPrice {
|
| 70 |
pricing.ModelPrice = modelPrice
|
| 71 |
pricing.QuotaType = 1
|
| 72 |
} else {
|
| 73 |
+
modelRatio, _ := setting.GetModelRatio(model)
|
| 74 |
pricing.ModelRatio = modelRatio
|
| 75 |
+
pricing.CompletionRatio = setting.GetCompletionRatio(model)
|
| 76 |
pricing.QuotaType = 0
|
| 77 |
}
|
| 78 |
pricingMap = append(pricingMap, pricing)
|
relay/channel/ali/text.go
CHANGED
|
@@ -8,6 +8,7 @@ import (
|
|
| 8 |
"net/http"
|
| 9 |
"one-api/common"
|
| 10 |
"one-api/dto"
|
|
|
|
| 11 |
"one-api/service"
|
| 12 |
"strings"
|
| 13 |
)
|
|
@@ -153,7 +154,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
|
|
| 153 |
}
|
| 154 |
stopChan <- true
|
| 155 |
}()
|
| 156 |
-
|
| 157 |
lastResponseText := ""
|
| 158 |
c.Stream(func(w io.Writer) bool {
|
| 159 |
select {
|
|
|
|
| 8 |
"net/http"
|
| 9 |
"one-api/common"
|
| 10 |
"one-api/dto"
|
| 11 |
+
"one-api/relay/helper"
|
| 12 |
"one-api/service"
|
| 13 |
"strings"
|
| 14 |
)
|
|
|
|
| 154 |
}
|
| 155 |
stopChan <- true
|
| 156 |
}()
|
| 157 |
+
helper.SetEventStreamHeaders(c)
|
| 158 |
lastResponseText := ""
|
| 159 |
c.Stream(func(w io.Writer) bool {
|
| 160 |
select {
|
relay/channel/aws/dto.go
CHANGED
|
@@ -14,7 +14,7 @@ type AwsClaudeRequest struct {
|
|
| 14 |
TopP float64 `json:"top_p,omitempty"`
|
| 15 |
TopK int `json:"top_k,omitempty"`
|
| 16 |
StopSequences []string `json:"stop_sequences,omitempty"`
|
| 17 |
-
Tools
|
| 18 |
ToolChoice any `json:"tool_choice,omitempty"`
|
| 19 |
Thinking *claude.Thinking `json:"thinking,omitempty"`
|
| 20 |
}
|
|
|
|
| 14 |
TopP float64 `json:"top_p,omitempty"`
|
| 15 |
TopK int `json:"top_k,omitempty"`
|
| 16 |
StopSequences []string `json:"stop_sequences,omitempty"`
|
| 17 |
+
Tools any `json:"tools,omitempty"`
|
| 18 |
ToolChoice any `json:"tool_choice,omitempty"`
|
| 19 |
Thinking *claude.Thinking `json:"thinking,omitempty"`
|
| 20 |
}
|
relay/channel/aws/relay-aws.go
CHANGED
|
@@ -12,6 +12,7 @@ import (
|
|
| 12 |
relaymodel "one-api/dto"
|
| 13 |
"one-api/relay/channel/claude"
|
| 14 |
relaycommon "one-api/relay/common"
|
|
|
|
| 15 |
"one-api/service"
|
| 16 |
"strings"
|
| 17 |
"time"
|
|
@@ -203,13 +204,13 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
| 203 |
}
|
| 204 |
})
|
| 205 |
if info.ShouldIncludeUsage {
|
| 206 |
-
response :=
|
| 207 |
-
err :=
|
| 208 |
if err != nil {
|
| 209 |
common.SysError("send final response failed: " + err.Error())
|
| 210 |
}
|
| 211 |
}
|
| 212 |
-
|
| 213 |
if resp != nil {
|
| 214 |
err = resp.Body.Close()
|
| 215 |
if err != nil {
|
|
|
|
| 12 |
relaymodel "one-api/dto"
|
| 13 |
"one-api/relay/channel/claude"
|
| 14 |
relaycommon "one-api/relay/common"
|
| 15 |
+
"one-api/relay/helper"
|
| 16 |
"one-api/service"
|
| 17 |
"strings"
|
| 18 |
"time"
|
|
|
|
| 204 |
}
|
| 205 |
})
|
| 206 |
if info.ShouldIncludeUsage {
|
| 207 |
+
response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
|
| 208 |
+
err := helper.ObjectData(c, response)
|
| 209 |
if err != nil {
|
| 210 |
common.SysError("send final response failed: " + err.Error())
|
| 211 |
}
|
| 212 |
}
|
| 213 |
+
helper.Done(c)
|
| 214 |
if resp != nil {
|
| 215 |
err = resp.Body.Close()
|
| 216 |
if err != nil {
|
relay/channel/baidu/relay-baidu.go
CHANGED
|
@@ -11,6 +11,7 @@ import (
|
|
| 11 |
"one-api/common"
|
| 12 |
"one-api/constant"
|
| 13 |
"one-api/dto"
|
|
|
|
| 14 |
"one-api/service"
|
| 15 |
"strings"
|
| 16 |
"sync"
|
|
@@ -138,7 +139,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
|
| 138 |
}
|
| 139 |
stopChan <- true
|
| 140 |
}()
|
| 141 |
-
|
| 142 |
c.Stream(func(w io.Writer) bool {
|
| 143 |
select {
|
| 144 |
case data := <-dataChan:
|
|
|
|
| 11 |
"one-api/common"
|
| 12 |
"one-api/constant"
|
| 13 |
"one-api/dto"
|
| 14 |
+
"one-api/relay/helper"
|
| 15 |
"one-api/service"
|
| 16 |
"strings"
|
| 17 |
"sync"
|
|
|
|
| 139 |
}
|
| 140 |
stopChan <- true
|
| 141 |
}()
|
| 142 |
+
helper.SetEventStreamHeaders(c)
|
| 143 |
c.Stream(func(w io.Writer) bool {
|
| 144 |
select {
|
| 145 |
case data := <-dataChan:
|
relay/channel/claude/dto.go
CHANGED
|
@@ -58,7 +58,7 @@ type ClaudeRequest struct {
|
|
| 58 |
TopK int `json:"top_k,omitempty"`
|
| 59 |
//ClaudeMetadata `json:"metadata,omitempty"`
|
| 60 |
Stream bool `json:"stream,omitempty"`
|
| 61 |
-
Tools
|
| 62 |
ToolChoice any `json:"tool_choice,omitempty"`
|
| 63 |
Thinking *Thinking `json:"thinking,omitempty"`
|
| 64 |
}
|
|
|
|
| 58 |
TopK int `json:"top_k,omitempty"`
|
| 59 |
//ClaudeMetadata `json:"metadata,omitempty"`
|
| 60 |
Stream bool `json:"stream,omitempty"`
|
| 61 |
+
Tools any `json:"tools,omitempty"`
|
| 62 |
ToolChoice any `json:"tool_choice,omitempty"`
|
| 63 |
Thinking *Thinking `json:"thinking,omitempty"`
|
| 64 |
}
|
relay/channel/claude/relay-claude.go
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
package claude
|
| 2 |
|
| 3 |
import (
|
| 4 |
-
"bufio"
|
| 5 |
"encoding/json"
|
| 6 |
"fmt"
|
| 7 |
"io"
|
|
@@ -9,6 +8,7 @@ import (
|
|
| 9 |
"one-api/common"
|
| 10 |
"one-api/dto"
|
| 11 |
relaycommon "one-api/relay/common"
|
|
|
|
| 12 |
"one-api/service"
|
| 13 |
"one-api/setting/model_setting"
|
| 14 |
"strings"
|
|
@@ -443,28 +443,18 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|
| 443 |
usage = &dto.Usage{}
|
| 444 |
responseText := ""
|
| 445 |
createdTime := common.GetTimestamp()
|
| 446 |
-
scanner := bufio.NewScanner(resp.Body)
|
| 447 |
-
scanner.Split(bufio.ScanLines)
|
| 448 |
-
service.SetEventStreamHeaders(c)
|
| 449 |
|
| 450 |
-
|
| 451 |
-
data := scanner.Text()
|
| 452 |
-
info.SetFirstResponseTime()
|
| 453 |
-
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
|
| 454 |
-
continue
|
| 455 |
-
}
|
| 456 |
-
data = strings.TrimPrefix(data, "data:")
|
| 457 |
-
data = strings.TrimSpace(data)
|
| 458 |
var claudeResponse ClaudeResponse
|
| 459 |
err := json.Unmarshal([]byte(data), &claudeResponse)
|
| 460 |
if err != nil {
|
| 461 |
common.SysError("error unmarshalling stream response: " + err.Error())
|
| 462 |
-
|
| 463 |
}
|
| 464 |
|
| 465 |
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
| 466 |
if response == nil {
|
| 467 |
-
|
| 468 |
}
|
| 469 |
if requestMode == RequestModeCompletion {
|
| 470 |
responseText += claudeResponse.Completion
|
|
@@ -481,9 +471,9 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|
| 481 |
usage.CompletionTokens = claudeUsage.OutputTokens
|
| 482 |
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
|
| 483 |
} else if claudeResponse.Type == "content_block_start" {
|
| 484 |
-
|
| 485 |
} else {
|
| 486 |
-
|
| 487 |
}
|
| 488 |
}
|
| 489 |
//response.Id = responseId
|
|
@@ -491,11 +481,12 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|
| 491 |
response.Created = createdTime
|
| 492 |
response.Model = info.UpstreamModelName
|
| 493 |
|
| 494 |
-
err =
|
| 495 |
if err != nil {
|
| 496 |
common.LogError(c, "send_stream_response_failed: "+err.Error())
|
| 497 |
}
|
| 498 |
-
|
|
|
|
| 499 |
|
| 500 |
if requestMode == RequestModeCompletion {
|
| 501 |
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
|
@@ -508,14 +499,14 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|
| 508 |
}
|
| 509 |
}
|
| 510 |
if info.ShouldIncludeUsage {
|
| 511 |
-
response :=
|
| 512 |
-
err :=
|
| 513 |
if err != nil {
|
| 514 |
common.SysError("send final response failed: " + err.Error())
|
| 515 |
}
|
| 516 |
}
|
| 517 |
-
|
| 518 |
-
resp.Body.Close()
|
| 519 |
return nil, usage
|
| 520 |
}
|
| 521 |
|
|
|
|
| 1 |
package claude
|
| 2 |
|
| 3 |
import (
|
|
|
|
| 4 |
"encoding/json"
|
| 5 |
"fmt"
|
| 6 |
"io"
|
|
|
|
| 8 |
"one-api/common"
|
| 9 |
"one-api/dto"
|
| 10 |
relaycommon "one-api/relay/common"
|
| 11 |
+
"one-api/relay/helper"
|
| 12 |
"one-api/service"
|
| 13 |
"one-api/setting/model_setting"
|
| 14 |
"strings"
|
|
|
|
| 443 |
usage = &dto.Usage{}
|
| 444 |
responseText := ""
|
| 445 |
createdTime := common.GetTimestamp()
|
|
|
|
|
|
|
|
|
|
| 446 |
|
| 447 |
+
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
var claudeResponse ClaudeResponse
|
| 449 |
err := json.Unmarshal([]byte(data), &claudeResponse)
|
| 450 |
if err != nil {
|
| 451 |
common.SysError("error unmarshalling stream response: " + err.Error())
|
| 452 |
+
return true
|
| 453 |
}
|
| 454 |
|
| 455 |
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
| 456 |
if response == nil {
|
| 457 |
+
return true
|
| 458 |
}
|
| 459 |
if requestMode == RequestModeCompletion {
|
| 460 |
responseText += claudeResponse.Completion
|
|
|
|
| 471 |
usage.CompletionTokens = claudeUsage.OutputTokens
|
| 472 |
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
|
| 473 |
} else if claudeResponse.Type == "content_block_start" {
|
| 474 |
+
return true
|
| 475 |
} else {
|
| 476 |
+
return true
|
| 477 |
}
|
| 478 |
}
|
| 479 |
//response.Id = responseId
|
|
|
|
| 481 |
response.Created = createdTime
|
| 482 |
response.Model = info.UpstreamModelName
|
| 483 |
|
| 484 |
+
err = helper.ObjectData(c, response)
|
| 485 |
if err != nil {
|
| 486 |
common.LogError(c, "send_stream_response_failed: "+err.Error())
|
| 487 |
}
|
| 488 |
+
return true
|
| 489 |
+
})
|
| 490 |
|
| 491 |
if requestMode == RequestModeCompletion {
|
| 492 |
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
|
|
|
| 499 |
}
|
| 500 |
}
|
| 501 |
if info.ShouldIncludeUsage {
|
| 502 |
+
response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
|
| 503 |
+
err := helper.ObjectData(c, response)
|
| 504 |
if err != nil {
|
| 505 |
common.SysError("send final response failed: " + err.Error())
|
| 506 |
}
|
| 507 |
}
|
| 508 |
+
helper.Done(c)
|
| 509 |
+
//resp.Body.Close()
|
| 510 |
return nil, usage
|
| 511 |
}
|
| 512 |
|
relay/channel/cloudflare/relay_cloudflare.go
CHANGED
|
@@ -9,6 +9,7 @@ import (
|
|
| 9 |
"one-api/common"
|
| 10 |
"one-api/dto"
|
| 11 |
relaycommon "one-api/relay/common"
|
|
|
|
| 12 |
"one-api/service"
|
| 13 |
"strings"
|
| 14 |
"time"
|
|
@@ -28,8 +29,8 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|
| 28 |
scanner := bufio.NewScanner(resp.Body)
|
| 29 |
scanner.Split(bufio.ScanLines)
|
| 30 |
|
| 31 |
-
|
| 32 |
-
id :=
|
| 33 |
var responseText string
|
| 34 |
isFirst := true
|
| 35 |
|
|
@@ -57,7 +58,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|
| 57 |
}
|
| 58 |
response.Id = id
|
| 59 |
response.Model = info.UpstreamModelName
|
| 60 |
-
err =
|
| 61 |
if isFirst {
|
| 62 |
isFirst = false
|
| 63 |
info.FirstResponseTime = time.Now()
|
|
@@ -72,13 +73,13 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|
| 72 |
}
|
| 73 |
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
| 74 |
if info.ShouldIncludeUsage {
|
| 75 |
-
response :=
|
| 76 |
-
err :=
|
| 77 |
if err != nil {
|
| 78 |
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
|
| 79 |
}
|
| 80 |
}
|
| 81 |
-
|
| 82 |
|
| 83 |
err := resp.Body.Close()
|
| 84 |
if err != nil {
|
|
@@ -109,7 +110,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
|
| 109 |
}
|
| 110 |
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
| 111 |
response.Usage = *usage
|
| 112 |
-
response.Id =
|
| 113 |
jsonResponse, err := json.Marshal(response)
|
| 114 |
if err != nil {
|
| 115 |
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
|
| 9 |
"one-api/common"
|
| 10 |
"one-api/dto"
|
| 11 |
relaycommon "one-api/relay/common"
|
| 12 |
+
"one-api/relay/helper"
|
| 13 |
"one-api/service"
|
| 14 |
"strings"
|
| 15 |
"time"
|
|
|
|
| 29 |
scanner := bufio.NewScanner(resp.Body)
|
| 30 |
scanner.Split(bufio.ScanLines)
|
| 31 |
|
| 32 |
+
helper.SetEventStreamHeaders(c)
|
| 33 |
+
id := helper.GetResponseID(c)
|
| 34 |
var responseText string
|
| 35 |
isFirst := true
|
| 36 |
|
|
|
|
| 58 |
}
|
| 59 |
response.Id = id
|
| 60 |
response.Model = info.UpstreamModelName
|
| 61 |
+
err = helper.ObjectData(c, response)
|
| 62 |
if isFirst {
|
| 63 |
isFirst = false
|
| 64 |
info.FirstResponseTime = time.Now()
|
|
|
|
| 73 |
}
|
| 74 |
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
| 75 |
if info.ShouldIncludeUsage {
|
| 76 |
+
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
| 77 |
+
err := helper.ObjectData(c, response)
|
| 78 |
if err != nil {
|
| 79 |
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
|
| 80 |
}
|
| 81 |
}
|
| 82 |
+
helper.Done(c)
|
| 83 |
|
| 84 |
err := resp.Body.Close()
|
| 85 |
if err != nil {
|
|
|
|
| 110 |
}
|
| 111 |
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
| 112 |
response.Usage = *usage
|
| 113 |
+
response.Id = helper.GetResponseID(c)
|
| 114 |
jsonResponse, err := json.Marshal(response)
|
| 115 |
if err != nil {
|
| 116 |
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
relay/channel/cohere/relay-cohere.go
CHANGED
|
@@ -10,6 +10,7 @@ import (
|
|
| 10 |
"one-api/common"
|
| 11 |
"one-api/dto"
|
| 12 |
relaycommon "one-api/relay/common"
|
|
|
|
| 13 |
"one-api/service"
|
| 14 |
"strings"
|
| 15 |
"time"
|
|
@@ -103,7 +104,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|
| 103 |
}
|
| 104 |
stopChan <- true
|
| 105 |
}()
|
| 106 |
-
|
| 107 |
isFirst := true
|
| 108 |
c.Stream(func(w io.Writer) bool {
|
| 109 |
select {
|
|
|
|
| 10 |
"one-api/common"
|
| 11 |
"one-api/dto"
|
| 12 |
relaycommon "one-api/relay/common"
|
| 13 |
+
"one-api/relay/helper"
|
| 14 |
"one-api/service"
|
| 15 |
"strings"
|
| 16 |
"time"
|
|
|
|
| 104 |
}
|
| 105 |
stopChan <- true
|
| 106 |
}()
|
| 107 |
+
helper.SetEventStreamHeaders(c)
|
| 108 |
isFirst := true
|
| 109 |
c.Stream(func(w io.Writer) bool {
|
| 110 |
select {
|
relay/channel/dify/relay-dify.go
CHANGED
|
@@ -10,6 +10,7 @@ import (
|
|
| 10 |
"one-api/constant"
|
| 11 |
"one-api/dto"
|
| 12 |
relaycommon "one-api/relay/common"
|
|
|
|
| 13 |
"one-api/service"
|
| 14 |
"strings"
|
| 15 |
)
|
|
@@ -66,7 +67,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|
| 66 |
scanner := bufio.NewScanner(resp.Body)
|
| 67 |
scanner.Split(bufio.ScanLines)
|
| 68 |
|
| 69 |
-
|
| 70 |
|
| 71 |
for scanner.Scan() {
|
| 72 |
data := scanner.Text()
|
|
@@ -92,7 +93,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|
| 92 |
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
| 93 |
}
|
| 94 |
}
|
| 95 |
-
err =
|
| 96 |
if err != nil {
|
| 97 |
common.SysError(err.Error())
|
| 98 |
}
|
|
@@ -100,7 +101,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|
| 100 |
if err := scanner.Err(); err != nil {
|
| 101 |
common.SysError("error reading stream: " + err.Error())
|
| 102 |
}
|
| 103 |
-
|
| 104 |
err := resp.Body.Close()
|
| 105 |
if err != nil {
|
| 106 |
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
|
|
| 10 |
"one-api/constant"
|
| 11 |
"one-api/dto"
|
| 12 |
relaycommon "one-api/relay/common"
|
| 13 |
+
"one-api/relay/helper"
|
| 14 |
"one-api/service"
|
| 15 |
"strings"
|
| 16 |
)
|
|
|
|
| 67 |
scanner := bufio.NewScanner(resp.Body)
|
| 68 |
scanner.Split(bufio.ScanLines)
|
| 69 |
|
| 70 |
+
helper.SetEventStreamHeaders(c)
|
| 71 |
|
| 72 |
for scanner.Scan() {
|
| 73 |
data := scanner.Text()
|
|
|
|
| 93 |
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
| 94 |
}
|
| 95 |
}
|
| 96 |
+
err = helper.ObjectData(c, openaiResponse)
|
| 97 |
if err != nil {
|
| 98 |
common.SysError(err.Error())
|
| 99 |
}
|
|
|
|
| 101 |
if err := scanner.Err(); err != nil {
|
| 102 |
common.SysError("error reading stream: " + err.Error())
|
| 103 |
}
|
| 104 |
+
helper.Done(c)
|
| 105 |
err := resp.Body.Close()
|
| 106 |
if err != nil {
|
| 107 |
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
relay/channel/gemini/relay-gemini.go
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
package gemini
|
| 2 |
|
| 3 |
import (
|
| 4 |
-
"bufio"
|
| 5 |
"encoding/json"
|
| 6 |
"fmt"
|
| 7 |
"io"
|
|
@@ -10,6 +9,7 @@ import (
|
|
| 10 |
"one-api/constant"
|
| 11 |
"one-api/dto"
|
| 12 |
relaycommon "one-api/relay/common"
|
|
|
|
| 13 |
"one-api/service"
|
| 14 |
"one-api/setting/model_setting"
|
| 15 |
"strings"
|
|
@@ -429,10 +429,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
|
| 429 |
|
| 430 |
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
| 431 |
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
| 432 |
-
|
| 433 |
for _, candidate := range geminiResponse.Candidates {
|
| 434 |
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
| 435 |
-
|
| 436 |
candidate.FinishReason = nil
|
| 437 |
}
|
| 438 |
choice := dto.ChatCompletionsStreamResponseChoice{
|
|
@@ -482,9 +482,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
|
| 482 |
|
| 483 |
var response dto.ChatCompletionsStreamResponse
|
| 484 |
response.Object = "chat.completion.chunk"
|
| 485 |
-
response.Model = "gemini"
|
| 486 |
response.Choices = choices
|
| 487 |
-
return &response,
|
| 488 |
}
|
| 489 |
|
| 490 |
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
@@ -492,27 +491,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
|
| 492 |
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
| 493 |
createAt := common.GetTimestamp()
|
| 494 |
var usage = &dto.Usage{}
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
service.SetEventStreamHeaders(c)
|
| 499 |
-
for scanner.Scan() {
|
| 500 |
-
data := scanner.Text()
|
| 501 |
-
info.SetFirstResponseTime()
|
| 502 |
-
data = strings.TrimSpace(data)
|
| 503 |
-
if !strings.HasPrefix(data, "data: ") {
|
| 504 |
-
continue
|
| 505 |
-
}
|
| 506 |
-
data = strings.TrimPrefix(data, "data: ")
|
| 507 |
-
data = strings.TrimSuffix(data, "\"")
|
| 508 |
var geminiResponse GeminiChatResponse
|
| 509 |
err := json.Unmarshal([]byte(data), &geminiResponse)
|
| 510 |
if err != nil {
|
| 511 |
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
| 512 |
-
|
| 513 |
}
|
| 514 |
|
| 515 |
-
response,
|
| 516 |
response.Id = id
|
| 517 |
response.Created = createAt
|
| 518 |
response.Model = info.UpstreamModelName
|
|
@@ -521,15 +509,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
|
| 521 |
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
| 522 |
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
| 523 |
}
|
| 524 |
-
err =
|
| 525 |
if err != nil {
|
| 526 |
common.LogError(c, err.Error())
|
| 527 |
}
|
| 528 |
-
if
|
| 529 |
-
response :=
|
| 530 |
-
|
| 531 |
}
|
| 532 |
-
|
|
|
|
| 533 |
|
| 534 |
var response *dto.ChatCompletionsStreamResponse
|
| 535 |
|
|
@@ -538,14 +527,14 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
|
| 538 |
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
|
| 539 |
|
| 540 |
if info.ShouldIncludeUsage {
|
| 541 |
-
response =
|
| 542 |
-
err :=
|
| 543 |
if err != nil {
|
| 544 |
common.SysError("send final response failed: " + err.Error())
|
| 545 |
}
|
| 546 |
}
|
| 547 |
-
|
| 548 |
-
resp.Body.Close()
|
| 549 |
return nil, usage
|
| 550 |
}
|
| 551 |
|
|
|
|
| 1 |
package gemini
|
| 2 |
|
| 3 |
import (
|
|
|
|
| 4 |
"encoding/json"
|
| 5 |
"fmt"
|
| 6 |
"io"
|
|
|
|
| 9 |
"one-api/constant"
|
| 10 |
"one-api/dto"
|
| 11 |
relaycommon "one-api/relay/common"
|
| 12 |
+
"one-api/relay/helper"
|
| 13 |
"one-api/service"
|
| 14 |
"one-api/setting/model_setting"
|
| 15 |
"strings"
|
|
|
|
| 429 |
|
| 430 |
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
| 431 |
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
| 432 |
+
isStop := false
|
| 433 |
for _, candidate := range geminiResponse.Candidates {
|
| 434 |
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
| 435 |
+
isStop = true
|
| 436 |
candidate.FinishReason = nil
|
| 437 |
}
|
| 438 |
choice := dto.ChatCompletionsStreamResponseChoice{
|
|
|
|
| 482 |
|
| 483 |
var response dto.ChatCompletionsStreamResponse
|
| 484 |
response.Object = "chat.completion.chunk"
|
|
|
|
| 485 |
response.Choices = choices
|
| 486 |
+
return &response, isStop
|
| 487 |
}
|
| 488 |
|
| 489 |
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
|
| 491 |
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
| 492 |
createAt := common.GetTimestamp()
|
| 493 |
var usage = &dto.Usage{}
|
| 494 |
+
|
| 495 |
+
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
var geminiResponse GeminiChatResponse
|
| 497 |
err := json.Unmarshal([]byte(data), &geminiResponse)
|
| 498 |
if err != nil {
|
| 499 |
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
| 500 |
+
return false
|
| 501 |
}
|
| 502 |
|
| 503 |
+
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
| 504 |
response.Id = id
|
| 505 |
response.Created = createAt
|
| 506 |
response.Model = info.UpstreamModelName
|
|
|
|
| 509 |
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
| 510 |
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
| 511 |
}
|
| 512 |
+
err = helper.ObjectData(c, response)
|
| 513 |
if err != nil {
|
| 514 |
common.LogError(c, err.Error())
|
| 515 |
}
|
| 516 |
+
if isStop {
|
| 517 |
+
response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
|
| 518 |
+
helper.ObjectData(c, response)
|
| 519 |
}
|
| 520 |
+
return true
|
| 521 |
+
})
|
| 522 |
|
| 523 |
var response *dto.ChatCompletionsStreamResponse
|
| 524 |
|
|
|
|
| 527 |
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
|
| 528 |
|
| 529 |
if info.ShouldIncludeUsage {
|
| 530 |
+
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
| 531 |
+
err := helper.ObjectData(c, response)
|
| 532 |
if err != nil {
|
| 533 |
common.SysError("send final response failed: " + err.Error())
|
| 534 |
}
|
| 535 |
}
|
| 536 |
+
helper.Done(c)
|
| 537 |
+
//resp.Body.Close()
|
| 538 |
return nil, usage
|
| 539 |
}
|
| 540 |
|
relay/channel/openai/relay-openai.go
CHANGED
|
@@ -1,10 +1,13 @@
|
|
| 1 |
package openai
|
| 2 |
|
| 3 |
import (
|
| 4 |
-
"bufio"
|
| 5 |
"bytes"
|
| 6 |
"encoding/json"
|
| 7 |
"fmt"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"io"
|
| 9 |
"math"
|
| 10 |
"mime/multipart"
|
|
@@ -14,16 +17,10 @@ import (
|
|
| 14 |
"one-api/dto"
|
| 15 |
relaycommon "one-api/relay/common"
|
| 16 |
relayconstant "one-api/relay/constant"
|
|
|
|
| 17 |
"one-api/service"
|
| 18 |
"os"
|
| 19 |
"strings"
|
| 20 |
-
"sync"
|
| 21 |
-
"time"
|
| 22 |
-
|
| 23 |
-
"github.com/bytedance/gopkg/util/gopool"
|
| 24 |
-
"github.com/gin-gonic/gin"
|
| 25 |
-
"github.com/gorilla/websocket"
|
| 26 |
-
"github.com/pkg/errors"
|
| 27 |
)
|
| 28 |
|
| 29 |
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
|
@@ -32,7 +29,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
|
| 32 |
}
|
| 33 |
|
| 34 |
if !forceFormat && !thinkToContent {
|
| 35 |
-
return
|
| 36 |
}
|
| 37 |
|
| 38 |
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
|
@@ -41,34 +38,47 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
|
| 41 |
}
|
| 42 |
|
| 43 |
if !thinkToContent {
|
| 44 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
}
|
| 46 |
|
| 47 |
// Handle think to content conversion
|
| 48 |
-
if info.
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
response.Choices
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
}
|
| 54 |
-
service.ObjectData(c, response)
|
| 55 |
}
|
| 56 |
|
| 57 |
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
|
| 58 |
-
return
|
| 59 |
}
|
| 60 |
|
| 61 |
// Process each choice
|
| 62 |
for i, choice := range lastStreamResponse.Choices {
|
| 63 |
// Handle transition from thinking to content
|
| 64 |
-
if len(choice.Delta.GetContentString()) > 0 && !info.
|
| 65 |
response := lastStreamResponse.Copy()
|
| 66 |
for j := range response.Choices {
|
| 67 |
-
response.Choices[j].Delta.SetContentString("\n</think
|
| 68 |
response.Choices[j].Delta.SetReasoningContent("")
|
| 69 |
}
|
| 70 |
-
info.
|
| 71 |
-
|
| 72 |
}
|
| 73 |
|
| 74 |
// Convert reasoning content to regular content
|
|
@@ -78,7 +88,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
|
| 78 |
}
|
| 79 |
}
|
| 80 |
|
| 81 |
-
return
|
| 82 |
}
|
| 83 |
|
| 84 |
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
@@ -108,65 +118,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
| 108 |
}
|
| 109 |
|
| 110 |
toolCount := 0
|
| 111 |
-
scanner := bufio.NewScanner(resp.Body)
|
| 112 |
-
scanner.Split(bufio.ScanLines)
|
| 113 |
-
|
| 114 |
-
service.SetEventStreamHeaders(c)
|
| 115 |
-
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
|
| 116 |
-
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
|
| 117 |
-
// twice timeout for o1 model
|
| 118 |
-
streamingTimeout *= 2
|
| 119 |
-
}
|
| 120 |
-
ticker := time.NewTicker(streamingTimeout)
|
| 121 |
-
defer ticker.Stop()
|
| 122 |
|
| 123 |
-
stopChan := make(chan bool)
|
| 124 |
-
defer close(stopChan)
|
| 125 |
var (
|
| 126 |
lastStreamData string
|
| 127 |
-
mu sync.Mutex
|
| 128 |
)
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
println(data)
|
| 136 |
-
}
|
| 137 |
-
if len(data) < 6 { // ignore blank line or wrong format
|
| 138 |
-
continue
|
| 139 |
-
}
|
| 140 |
-
if data[:5] != "data:" && data[:6] != "[DONE]" {
|
| 141 |
-
continue
|
| 142 |
-
}
|
| 143 |
-
mu.Lock()
|
| 144 |
-
data = data[5:]
|
| 145 |
-
data = strings.TrimSpace(data)
|
| 146 |
-
if !strings.HasPrefix(data, "[DONE]") {
|
| 147 |
-
if lastStreamData != "" {
|
| 148 |
-
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
| 149 |
-
if err != nil {
|
| 150 |
-
common.LogError(c, "streaming error: "+err.Error())
|
| 151 |
-
}
|
| 152 |
-
info.SetFirstResponseTime()
|
| 153 |
-
}
|
| 154 |
-
lastStreamData = data
|
| 155 |
-
streamItems = append(streamItems, data)
|
| 156 |
}
|
| 157 |
-
mu.Unlock()
|
| 158 |
}
|
| 159 |
-
|
|
|
|
|
|
|
| 160 |
})
|
| 161 |
|
| 162 |
-
select {
|
| 163 |
-
case <-ticker.C:
|
| 164 |
-
// 超时处理逻辑
|
| 165 |
-
common.LogError(c, "streaming timeout")
|
| 166 |
-
case <-stopChan:
|
| 167 |
-
// 正常结束
|
| 168 |
-
}
|
| 169 |
-
|
| 170 |
shouldSendLastResp := true
|
| 171 |
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
| 172 |
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
|
|
@@ -274,14 +242,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
| 274 |
}
|
| 275 |
|
| 276 |
if info.ShouldIncludeUsage && !containStreamUsage {
|
| 277 |
-
response :=
|
| 278 |
response.SetSystemFingerprint(systemFingerprint)
|
| 279 |
-
|
| 280 |
}
|
| 281 |
|
| 282 |
-
|
| 283 |
|
| 284 |
-
resp.Body.Close()
|
| 285 |
return nil, usage
|
| 286 |
}
|
| 287 |
|
|
@@ -512,7 +480,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
|
| 512 |
localUsage.InputTokenDetails.TextTokens += textToken
|
| 513 |
localUsage.InputTokenDetails.AudioTokens += audioToken
|
| 514 |
|
| 515 |
-
err =
|
| 516 |
if err != nil {
|
| 517 |
errChan <- fmt.Errorf("error writing to target: %v", err)
|
| 518 |
return
|
|
@@ -618,7 +586,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
|
| 618 |
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
| 619 |
}
|
| 620 |
|
| 621 |
-
err =
|
| 622 |
if err != nil {
|
| 623 |
errChan <- fmt.Errorf("error writing to client: %v", err)
|
| 624 |
return
|
|
|
|
| 1 |
package openai
|
| 2 |
|
| 3 |
import (
|
|
|
|
| 4 |
"bytes"
|
| 5 |
"encoding/json"
|
| 6 |
"fmt"
|
| 7 |
+
"github.com/bytedance/gopkg/util/gopool"
|
| 8 |
+
"github.com/gin-gonic/gin"
|
| 9 |
+
"github.com/gorilla/websocket"
|
| 10 |
+
"github.com/pkg/errors"
|
| 11 |
"io"
|
| 12 |
"math"
|
| 13 |
"mime/multipart"
|
|
|
|
| 17 |
"one-api/dto"
|
| 18 |
relaycommon "one-api/relay/common"
|
| 19 |
relayconstant "one-api/relay/constant"
|
| 20 |
+
"one-api/relay/helper"
|
| 21 |
"one-api/service"
|
| 22 |
"os"
|
| 23 |
"strings"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
)
|
| 25 |
|
| 26 |
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
|
|
|
| 29 |
}
|
| 30 |
|
| 31 |
if !forceFormat && !thinkToContent {
|
| 32 |
+
return helper.StringData(c, data)
|
| 33 |
}
|
| 34 |
|
| 35 |
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
|
|
|
| 38 |
}
|
| 39 |
|
| 40 |
if !thinkToContent {
|
| 41 |
+
return helper.ObjectData(c, lastStreamResponse)
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
hasThinkingContent := false
|
| 45 |
+
for _, choice := range lastStreamResponse.Choices {
|
| 46 |
+
if len(choice.Delta.GetReasoningContent()) > 0 {
|
| 47 |
+
hasThinkingContent = true
|
| 48 |
+
break
|
| 49 |
+
}
|
| 50 |
}
|
| 51 |
|
| 52 |
// Handle think to content conversion
|
| 53 |
+
if info.ThinkingContentInfo.IsFirstThinkingContent {
|
| 54 |
+
if hasThinkingContent {
|
| 55 |
+
response := lastStreamResponse.Copy()
|
| 56 |
+
for i := range response.Choices {
|
| 57 |
+
response.Choices[i].Delta.SetContentString("<think>\n")
|
| 58 |
+
response.Choices[i].Delta.SetReasoningContent("")
|
| 59 |
+
}
|
| 60 |
+
info.ThinkingContentInfo.IsFirstThinkingContent = false
|
| 61 |
+
return helper.ObjectData(c, response)
|
| 62 |
+
} else {
|
| 63 |
+
return helper.ObjectData(c, lastStreamResponse)
|
| 64 |
}
|
|
|
|
| 65 |
}
|
| 66 |
|
| 67 |
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
|
| 68 |
+
return helper.ObjectData(c, lastStreamResponse)
|
| 69 |
}
|
| 70 |
|
| 71 |
// Process each choice
|
| 72 |
for i, choice := range lastStreamResponse.Choices {
|
| 73 |
// Handle transition from thinking to content
|
| 74 |
+
if len(choice.Delta.GetContentString()) > 0 && !info.ThinkingContentInfo.SendLastThinkingContent {
|
| 75 |
response := lastStreamResponse.Copy()
|
| 76 |
for j := range response.Choices {
|
| 77 |
+
response.Choices[j].Delta.SetContentString("\n</think>\n\n")
|
| 78 |
response.Choices[j].Delta.SetReasoningContent("")
|
| 79 |
}
|
| 80 |
+
info.ThinkingContentInfo.SendLastThinkingContent = true
|
| 81 |
+
helper.ObjectData(c, response)
|
| 82 |
}
|
| 83 |
|
| 84 |
// Convert reasoning content to regular content
|
|
|
|
| 88 |
}
|
| 89 |
}
|
| 90 |
|
| 91 |
+
return helper.ObjectData(c, lastStreamResponse)
|
| 92 |
}
|
| 93 |
|
| 94 |
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
|
| 118 |
}
|
| 119 |
|
| 120 |
toolCount := 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
|
|
|
|
|
|
| 122 |
var (
|
| 123 |
lastStreamData string
|
|
|
|
| 124 |
)
|
| 125 |
+
|
| 126 |
+
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
| 127 |
+
if lastStreamData != "" {
|
| 128 |
+
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
| 129 |
+
if err != nil {
|
| 130 |
+
common.LogError(c, "streaming error: "+err.Error())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
}
|
|
|
|
| 132 |
}
|
| 133 |
+
lastStreamData = data
|
| 134 |
+
streamItems = append(streamItems, data)
|
| 135 |
+
return true
|
| 136 |
})
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
shouldSendLastResp := true
|
| 139 |
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
| 140 |
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
|
|
|
|
| 242 |
}
|
| 243 |
|
| 244 |
if info.ShouldIncludeUsage && !containStreamUsage {
|
| 245 |
+
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
| 246 |
response.SetSystemFingerprint(systemFingerprint)
|
| 247 |
+
helper.ObjectData(c, response)
|
| 248 |
}
|
| 249 |
|
| 250 |
+
helper.Done(c)
|
| 251 |
|
| 252 |
+
//resp.Body.Close()
|
| 253 |
return nil, usage
|
| 254 |
}
|
| 255 |
|
|
|
|
| 480 |
localUsage.InputTokenDetails.TextTokens += textToken
|
| 481 |
localUsage.InputTokenDetails.AudioTokens += audioToken
|
| 482 |
|
| 483 |
+
err = helper.WssString(c, targetConn, string(message))
|
| 484 |
if err != nil {
|
| 485 |
errChan <- fmt.Errorf("error writing to target: %v", err)
|
| 486 |
return
|
|
|
|
| 586 |
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
| 587 |
}
|
| 588 |
|
| 589 |
+
err = helper.WssString(c, clientConn, string(message))
|
| 590 |
if err != nil {
|
| 591 |
errChan <- fmt.Errorf("error writing to client: %v", err)
|
| 592 |
return
|
relay/channel/palm/relay-palm.go
CHANGED
|
@@ -9,6 +9,7 @@ import (
|
|
| 9 |
"one-api/common"
|
| 10 |
"one-api/constant"
|
| 11 |
"one-api/dto"
|
|
|
|
| 12 |
"one-api/service"
|
| 13 |
)
|
| 14 |
|
|
@@ -112,7 +113,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
|
| 112 |
dataChan <- string(jsonResponse)
|
| 113 |
stopChan <- true
|
| 114 |
}()
|
| 115 |
-
|
| 116 |
c.Stream(func(w io.Writer) bool {
|
| 117 |
select {
|
| 118 |
case data := <-dataChan:
|
|
|
|
| 9 |
"one-api/common"
|
| 10 |
"one-api/constant"
|
| 11 |
"one-api/dto"
|
| 12 |
+
"one-api/relay/helper"
|
| 13 |
"one-api/service"
|
| 14 |
)
|
| 15 |
|
|
|
|
| 113 |
dataChan <- string(jsonResponse)
|
| 114 |
stopChan <- true
|
| 115 |
}()
|
| 116 |
+
helper.SetEventStreamHeaders(c)
|
| 117 |
c.Stream(func(w io.Writer) bool {
|
| 118 |
select {
|
| 119 |
case data := <-dataChan:
|
relay/channel/tencent/relay-tencent.go
CHANGED
|
@@ -14,6 +14,7 @@ import (
|
|
| 14 |
"one-api/common"
|
| 15 |
"one-api/constant"
|
| 16 |
"one-api/dto"
|
|
|
|
| 17 |
"one-api/service"
|
| 18 |
"strconv"
|
| 19 |
"strings"
|
|
@@ -91,7 +92,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
|
| 91 |
scanner := bufio.NewScanner(resp.Body)
|
| 92 |
scanner.Split(bufio.ScanLines)
|
| 93 |
|
| 94 |
-
|
| 95 |
|
| 96 |
for scanner.Scan() {
|
| 97 |
data := scanner.Text()
|
|
@@ -112,7 +113,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
|
| 112 |
responseText += response.Choices[0].Delta.GetContentString()
|
| 113 |
}
|
| 114 |
|
| 115 |
-
err =
|
| 116 |
if err != nil {
|
| 117 |
common.SysError(err.Error())
|
| 118 |
}
|
|
@@ -122,7 +123,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
|
| 122 |
common.SysError("error reading stream: " + err.Error())
|
| 123 |
}
|
| 124 |
|
| 125 |
-
|
| 126 |
|
| 127 |
err := resp.Body.Close()
|
| 128 |
if err != nil {
|
|
|
|
| 14 |
"one-api/common"
|
| 15 |
"one-api/constant"
|
| 16 |
"one-api/dto"
|
| 17 |
+
"one-api/relay/helper"
|
| 18 |
"one-api/service"
|
| 19 |
"strconv"
|
| 20 |
"strings"
|
|
|
|
| 92 |
scanner := bufio.NewScanner(resp.Body)
|
| 93 |
scanner.Split(bufio.ScanLines)
|
| 94 |
|
| 95 |
+
helper.SetEventStreamHeaders(c)
|
| 96 |
|
| 97 |
for scanner.Scan() {
|
| 98 |
data := scanner.Text()
|
|
|
|
| 113 |
responseText += response.Choices[0].Delta.GetContentString()
|
| 114 |
}
|
| 115 |
|
| 116 |
+
err = helper.ObjectData(c, response)
|
| 117 |
if err != nil {
|
| 118 |
common.SysError(err.Error())
|
| 119 |
}
|
|
|
|
| 123 |
common.SysError("error reading stream: " + err.Error())
|
| 124 |
}
|
| 125 |
|
| 126 |
+
helper.Done(c)
|
| 127 |
|
| 128 |
err := resp.Body.Close()
|
| 129 |
if err != nil {
|
relay/channel/vertex/adaptor.go
CHANGED
|
@@ -5,7 +5,6 @@ import (
|
|
| 5 |
"errors"
|
| 6 |
"fmt"
|
| 7 |
"github.com/gin-gonic/gin"
|
| 8 |
-
"github.com/jinzhu/copier"
|
| 9 |
"io"
|
| 10 |
"net/http"
|
| 11 |
"one-api/dto"
|
|
@@ -28,6 +27,7 @@ var claudeModelMap = map[string]string{
|
|
| 28 |
"claude-3-opus-20240229": "claude-3-opus@20240229",
|
| 29 |
"claude-3-haiku-20240307": "claude-3-haiku@20240307",
|
| 30 |
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
|
|
|
|
| 31 |
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
|
| 32 |
}
|
| 33 |
|
|
@@ -86,15 +86,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
| 86 |
} else {
|
| 87 |
suffix = "rawPredict"
|
| 88 |
}
|
|
|
|
| 89 |
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
| 90 |
-
|
| 91 |
}
|
| 92 |
return fmt.Sprintf(
|
| 93 |
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
| 94 |
region,
|
| 95 |
adc.ProjectID,
|
| 96 |
region,
|
| 97 |
-
|
| 98 |
suffix,
|
| 99 |
), nil
|
| 100 |
} else if a.RequestMode == RequestModeLlama {
|
|
@@ -127,13 +128,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|
| 127 |
if err != nil {
|
| 128 |
return nil, err
|
| 129 |
}
|
| 130 |
-
vertexClaudeReq :=
|
| 131 |
-
AnthropicVersion: anthropicVersion,
|
| 132 |
-
}
|
| 133 |
-
if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
|
| 134 |
-
return nil, errors.New("failed to copy claude request")
|
| 135 |
-
}
|
| 136 |
c.Set("request_model", claudeReq.Model)
|
|
|
|
| 137 |
return vertexClaudeReq, nil
|
| 138 |
} else if a.RequestMode == RequestModeGemini {
|
| 139 |
geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
|
|
|
|
| 5 |
"errors"
|
| 6 |
"fmt"
|
| 7 |
"github.com/gin-gonic/gin"
|
|
|
|
| 8 |
"io"
|
| 9 |
"net/http"
|
| 10 |
"one-api/dto"
|
|
|
|
| 27 |
"claude-3-opus-20240229": "claude-3-opus@20240229",
|
| 28 |
"claude-3-haiku-20240307": "claude-3-haiku@20240307",
|
| 29 |
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
|
| 30 |
+
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
|
| 31 |
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
|
| 32 |
}
|
| 33 |
|
|
|
|
| 86 |
} else {
|
| 87 |
suffix = "rawPredict"
|
| 88 |
}
|
| 89 |
+
model := info.UpstreamModelName
|
| 90 |
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
| 91 |
+
model = v
|
| 92 |
}
|
| 93 |
return fmt.Sprintf(
|
| 94 |
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
| 95 |
region,
|
| 96 |
adc.ProjectID,
|
| 97 |
region,
|
| 98 |
+
model,
|
| 99 |
suffix,
|
| 100 |
), nil
|
| 101 |
} else if a.RequestMode == RequestModeLlama {
|
|
|
|
| 128 |
if err != nil {
|
| 129 |
return nil, err
|
| 130 |
}
|
| 131 |
+
vertexClaudeReq := copyRequest(claudeReq, anthropicVersion)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
c.Set("request_model", claudeReq.Model)
|
| 133 |
+
info.UpstreamModelName = claudeReq.Model
|
| 134 |
return vertexClaudeReq, nil
|
| 135 |
} else if a.RequestMode == RequestModeGemini {
|
| 136 |
geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
|
relay/channel/vertex/dto.go
CHANGED
|
@@ -1,17 +1,37 @@
|
|
| 1 |
package vertex
|
| 2 |
|
| 3 |
-
import
|
|
|
|
|
|
|
| 4 |
|
| 5 |
type VertexAIClaudeRequest struct {
|
| 6 |
AnthropicVersion string `json:"anthropic_version"`
|
| 7 |
Messages []claude.ClaudeMessage `json:"messages"`
|
| 8 |
-
System
|
| 9 |
-
MaxTokens
|
| 10 |
StopSequences []string `json:"stop_sequences,omitempty"`
|
| 11 |
Stream bool `json:"stream,omitempty"`
|
| 12 |
Temperature *float64 `json:"temperature,omitempty"`
|
| 13 |
TopP float64 `json:"top_p,omitempty"`
|
| 14 |
TopK int `json:"top_k,omitempty"`
|
| 15 |
-
Tools
|
| 16 |
ToolChoice any `json:"tool_choice,omitempty"`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
}
|
|
|
|
| 1 |
package vertex
|
| 2 |
|
| 3 |
+
import (
|
| 4 |
+
"one-api/relay/channel/claude"
|
| 5 |
+
)
|
| 6 |
|
| 7 |
type VertexAIClaudeRequest struct {
|
| 8 |
AnthropicVersion string `json:"anthropic_version"`
|
| 9 |
Messages []claude.ClaudeMessage `json:"messages"`
|
| 10 |
+
System any `json:"system,omitempty"`
|
| 11 |
+
MaxTokens uint `json:"max_tokens,omitempty"`
|
| 12 |
StopSequences []string `json:"stop_sequences,omitempty"`
|
| 13 |
Stream bool `json:"stream,omitempty"`
|
| 14 |
Temperature *float64 `json:"temperature,omitempty"`
|
| 15 |
TopP float64 `json:"top_p,omitempty"`
|
| 16 |
TopK int `json:"top_k,omitempty"`
|
| 17 |
+
Tools any `json:"tools,omitempty"`
|
| 18 |
ToolChoice any `json:"tool_choice,omitempty"`
|
| 19 |
+
Thinking *claude.Thinking `json:"thinking,omitempty"`
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
func copyRequest(req *claude.ClaudeRequest, version string) *VertexAIClaudeRequest {
|
| 23 |
+
return &VertexAIClaudeRequest{
|
| 24 |
+
AnthropicVersion: version,
|
| 25 |
+
System: req.System,
|
| 26 |
+
Messages: req.Messages,
|
| 27 |
+
MaxTokens: req.MaxTokens,
|
| 28 |
+
Stream: req.Stream,
|
| 29 |
+
Temperature: req.Temperature,
|
| 30 |
+
TopP: req.TopP,
|
| 31 |
+
TopK: req.TopK,
|
| 32 |
+
StopSequences: req.StopSequences,
|
| 33 |
+
Tools: req.Tools,
|
| 34 |
+
ToolChoice: req.ToolChoice,
|
| 35 |
+
Thinking: req.Thinking,
|
| 36 |
+
}
|
| 37 |
}
|
relay/channel/xunfei/relay-xunfei.go
CHANGED
|
@@ -14,6 +14,7 @@ import (
|
|
| 14 |
"one-api/common"
|
| 15 |
"one-api/constant"
|
| 16 |
"one-api/dto"
|
|
|
|
| 17 |
"one-api/service"
|
| 18 |
"strings"
|
| 19 |
"time"
|
|
@@ -132,7 +133,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
|
|
| 132 |
if err != nil {
|
| 133 |
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
| 134 |
}
|
| 135 |
-
|
| 136 |
var usage dto.Usage
|
| 137 |
c.Stream(func(w io.Writer) bool {
|
| 138 |
select {
|
|
|
|
| 14 |
"one-api/common"
|
| 15 |
"one-api/constant"
|
| 16 |
"one-api/dto"
|
| 17 |
+
"one-api/relay/helper"
|
| 18 |
"one-api/service"
|
| 19 |
"strings"
|
| 20 |
"time"
|
|
|
|
| 133 |
if err != nil {
|
| 134 |
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
| 135 |
}
|
| 136 |
+
helper.SetEventStreamHeaders(c)
|
| 137 |
var usage dto.Usage
|
| 138 |
c.Stream(func(w io.Writer) bool {
|
| 139 |
select {
|
relay/channel/zhipu/relay-zhipu.go
CHANGED
|
@@ -10,6 +10,7 @@ import (
|
|
| 10 |
"one-api/common"
|
| 11 |
"one-api/constant"
|
| 12 |
"one-api/dto"
|
|
|
|
| 13 |
"one-api/service"
|
| 14 |
"strings"
|
| 15 |
"sync"
|
|
@@ -177,7 +178,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
|
| 177 |
}
|
| 178 |
stopChan <- true
|
| 179 |
}()
|
| 180 |
-
|
| 181 |
c.Stream(func(w io.Writer) bool {
|
| 182 |
select {
|
| 183 |
case data := <-dataChan:
|
|
|
|
| 10 |
"one-api/common"
|
| 11 |
"one-api/constant"
|
| 12 |
"one-api/dto"
|
| 13 |
+
"one-api/relay/helper"
|
| 14 |
"one-api/service"
|
| 15 |
"strings"
|
| 16 |
"sync"
|
|
|
|
| 178 |
}
|
| 179 |
stopChan <- true
|
| 180 |
}()
|
| 181 |
+
helper.SetEventStreamHeaders(c)
|
| 182 |
c.Stream(func(w io.Writer) bool {
|
| 183 |
select {
|
| 184 |
case data := <-dataChan:
|
relay/channel/zhipu_4v/relay-zhipu_v4.go
CHANGED
|
@@ -10,6 +10,7 @@ import (
|
|
| 10 |
"net/http"
|
| 11 |
"one-api/common"
|
| 12 |
"one-api/dto"
|
|
|
|
| 13 |
"one-api/service"
|
| 14 |
"strings"
|
| 15 |
"sync"
|
|
@@ -197,7 +198,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
|
| 197 |
}
|
| 198 |
stopChan <- true
|
| 199 |
}()
|
| 200 |
-
|
| 201 |
c.Stream(func(w io.Writer) bool {
|
| 202 |
select {
|
| 203 |
case data := <-dataChan:
|
|
|
|
| 10 |
"net/http"
|
| 11 |
"one-api/common"
|
| 12 |
"one-api/dto"
|
| 13 |
+
"one-api/relay/helper"
|
| 14 |
"one-api/service"
|
| 15 |
"strings"
|
| 16 |
"sync"
|
|
|
|
| 198 |
}
|
| 199 |
stopChan <- true
|
| 200 |
}()
|
| 201 |
+
helper.SetEventStreamHeaders(c)
|
| 202 |
c.Stream(func(w io.Writer) bool {
|
| 203 |
select {
|
| 204 |
case data := <-dataChan:
|
relay/common/relay_info.go
CHANGED
|
@@ -12,25 +12,30 @@ import (
|
|
| 12 |
"github.com/gorilla/websocket"
|
| 13 |
)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
type RelayInfo struct {
|
| 16 |
-
ChannelType
|
| 17 |
-
ChannelId
|
| 18 |
-
TokenId
|
| 19 |
-
TokenKey
|
| 20 |
-
UserId
|
| 21 |
-
Group
|
| 22 |
-
TokenUnlimited
|
| 23 |
-
StartTime
|
| 24 |
-
FirstResponseTime
|
| 25 |
-
|
| 26 |
-
SendLastReasoningResponse bool
|
| 27 |
-
ApiType
|
| 28 |
-
IsStream
|
| 29 |
-
IsPlayground
|
| 30 |
-
UsePrice
|
| 31 |
-
RelayMode
|
| 32 |
-
UpstreamModelName
|
| 33 |
-
OriginModelName
|
| 34 |
//RecodeModelName string
|
| 35 |
RequestURLPath string
|
| 36 |
ApiVersion string
|
|
@@ -53,6 +58,7 @@ type RelayInfo struct {
|
|
| 53 |
UserSetting map[string]interface{}
|
| 54 |
UserEmail string
|
| 55 |
UserQuota int
|
|
|
|
| 56 |
}
|
| 57 |
|
| 58 |
// 定义支持流式选项的通道类型
|
|
@@ -95,7 +101,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|
| 95 |
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
|
| 96 |
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
|
| 97 |
UserEmail: c.GetString(constant.ContextKeyUserEmail),
|
| 98 |
-
|
| 99 |
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
| 100 |
BaseUrl: c.GetString("base_url"),
|
| 101 |
RequestURLPath: c.Request.URL.String(),
|
|
@@ -117,6 +123,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|
| 117 |
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
| 118 |
Organization: c.GetString("channel_organization"),
|
| 119 |
ChannelSetting: channelSetting,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
}
|
| 121 |
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
| 122 |
info.IsPlayground = true
|
|
@@ -147,9 +157,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
|
|
| 147 |
}
|
| 148 |
|
| 149 |
func (info *RelayInfo) SetFirstResponseTime() {
|
| 150 |
-
if info.
|
| 151 |
info.FirstResponseTime = time.Now()
|
| 152 |
-
info.
|
| 153 |
}
|
| 154 |
}
|
| 155 |
|
|
|
|
| 12 |
"github.com/gorilla/websocket"
|
| 13 |
)
|
| 14 |
|
| 15 |
+
type ThinkingContentInfo struct {
|
| 16 |
+
IsFirstThinkingContent bool
|
| 17 |
+
SendLastThinkingContent bool
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
type RelayInfo struct {
|
| 21 |
+
ChannelType int
|
| 22 |
+
ChannelId int
|
| 23 |
+
TokenId int
|
| 24 |
+
TokenKey string
|
| 25 |
+
UserId int
|
| 26 |
+
Group string
|
| 27 |
+
TokenUnlimited bool
|
| 28 |
+
StartTime time.Time
|
| 29 |
+
FirstResponseTime time.Time
|
| 30 |
+
isFirstResponse bool
|
| 31 |
+
//SendLastReasoningResponse bool
|
| 32 |
+
ApiType int
|
| 33 |
+
IsStream bool
|
| 34 |
+
IsPlayground bool
|
| 35 |
+
UsePrice bool
|
| 36 |
+
RelayMode int
|
| 37 |
+
UpstreamModelName string
|
| 38 |
+
OriginModelName string
|
| 39 |
//RecodeModelName string
|
| 40 |
RequestURLPath string
|
| 41 |
ApiVersion string
|
|
|
|
| 58 |
UserSetting map[string]interface{}
|
| 59 |
UserEmail string
|
| 60 |
UserQuota int
|
| 61 |
+
ThinkingContentInfo
|
| 62 |
}
|
| 63 |
|
| 64 |
// 定义支持流式选项的通道类型
|
|
|
|
| 101 |
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
|
| 102 |
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
|
| 103 |
UserEmail: c.GetString(constant.ContextKeyUserEmail),
|
| 104 |
+
isFirstResponse: true,
|
| 105 |
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
| 106 |
BaseUrl: c.GetString("base_url"),
|
| 107 |
RequestURLPath: c.Request.URL.String(),
|
|
|
|
| 123 |
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
| 124 |
Organization: c.GetString("channel_organization"),
|
| 125 |
ChannelSetting: channelSetting,
|
| 126 |
+
ThinkingContentInfo: ThinkingContentInfo{
|
| 127 |
+
IsFirstThinkingContent: true,
|
| 128 |
+
SendLastThinkingContent: false,
|
| 129 |
+
},
|
| 130 |
}
|
| 131 |
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
| 132 |
info.IsPlayground = true
|
|
|
|
| 157 |
}
|
| 158 |
|
| 159 |
func (info *RelayInfo) SetFirstResponseTime() {
|
| 160 |
+
if info.isFirstResponse {
|
| 161 |
info.FirstResponseTime = time.Now()
|
| 162 |
+
info.isFirstResponse = false
|
| 163 |
}
|
| 164 |
}
|
| 165 |
|
service/relay.go → relay/helper/common.go
RENAMED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
package
|
| 2 |
|
| 3 |
import (
|
| 4 |
"encoding/json"
|
|
|
|
| 1 |
+
package helper
|
| 2 |
|
| 3 |
import (
|
| 4 |
"encoding/json"
|
relay/helper/price.go
CHANGED
|
@@ -11,26 +11,33 @@ import (
|
|
| 11 |
type PriceData struct {
|
| 12 |
ModelPrice float64
|
| 13 |
ModelRatio float64
|
|
|
|
| 14 |
GroupRatio float64
|
| 15 |
UsePrice bool
|
| 16 |
ShouldPreConsumedQuota int
|
| 17 |
}
|
| 18 |
|
| 19 |
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
| 20 |
-
modelPrice, usePrice :=
|
| 21 |
groupRatio := setting.GetGroupRatio(info.Group)
|
| 22 |
var preConsumedQuota int
|
| 23 |
var modelRatio float64
|
|
|
|
| 24 |
if !usePrice {
|
| 25 |
preConsumedTokens := common.PreConsumedQuota
|
| 26 |
if maxTokens != 0 {
|
| 27 |
preConsumedTokens = promptTokens + maxTokens
|
| 28 |
}
|
| 29 |
var success bool
|
| 30 |
-
modelRatio, success =
|
| 31 |
if !success {
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
}
|
|
|
|
| 34 |
ratio := modelRatio * groupRatio
|
| 35 |
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
| 36 |
} else {
|
|
@@ -39,6 +46,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
|
| 39 |
return PriceData{
|
| 40 |
ModelPrice: modelPrice,
|
| 41 |
ModelRatio: modelRatio,
|
|
|
|
| 42 |
GroupRatio: groupRatio,
|
| 43 |
UsePrice: usePrice,
|
| 44 |
ShouldPreConsumedQuota: preConsumedQuota,
|
|
|
|
| 11 |
type PriceData struct {
|
| 12 |
ModelPrice float64
|
| 13 |
ModelRatio float64
|
| 14 |
+
CompletionRatio float64
|
| 15 |
GroupRatio float64
|
| 16 |
UsePrice bool
|
| 17 |
ShouldPreConsumedQuota int
|
| 18 |
}
|
| 19 |
|
| 20 |
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
| 21 |
+
modelPrice, usePrice := setting.GetModelPrice(info.OriginModelName, false)
|
| 22 |
groupRatio := setting.GetGroupRatio(info.Group)
|
| 23 |
var preConsumedQuota int
|
| 24 |
var modelRatio float64
|
| 25 |
+
var completionRatio float64
|
| 26 |
if !usePrice {
|
| 27 |
preConsumedTokens := common.PreConsumedQuota
|
| 28 |
if maxTokens != 0 {
|
| 29 |
preConsumedTokens = promptTokens + maxTokens
|
| 30 |
}
|
| 31 |
var success bool
|
| 32 |
+
modelRatio, success = setting.GetModelRatio(info.OriginModelName)
|
| 33 |
if !success {
|
| 34 |
+
if info.UserId == 1 {
|
| 35 |
+
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
|
| 36 |
+
} else {
|
| 37 |
+
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
|
| 38 |
+
}
|
| 39 |
}
|
| 40 |
+
completionRatio = setting.GetCompletionRatio(info.OriginModelName)
|
| 41 |
ratio := modelRatio * groupRatio
|
| 42 |
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
| 43 |
} else {
|
|
|
|
| 46 |
return PriceData{
|
| 47 |
ModelPrice: modelPrice,
|
| 48 |
ModelRatio: modelRatio,
|
| 49 |
+
CompletionRatio: completionRatio,
|
| 50 |
GroupRatio: groupRatio,
|
| 51 |
UsePrice: usePrice,
|
| 52 |
ShouldPreConsumedQuota: preConsumedQuota,
|
relay/helper/stream_scanner.go
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package helper
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bufio"
|
| 5 |
+
"context"
|
| 6 |
+
"io"
|
| 7 |
+
"net/http"
|
| 8 |
+
"one-api/common"
|
| 9 |
+
"one-api/constant"
|
| 10 |
+
relaycommon "one-api/relay/common"
|
| 11 |
+
"strings"
|
| 12 |
+
"time"
|
| 13 |
+
|
| 14 |
+
"github.com/gin-gonic/gin"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
|
| 18 |
+
|
| 19 |
+
if resp == nil {
|
| 20 |
+
return
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
defer resp.Body.Close()
|
| 24 |
+
|
| 25 |
+
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
|
| 26 |
+
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
|
| 27 |
+
// twice timeout for thinking model
|
| 28 |
+
streamingTimeout *= 2
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
var (
|
| 32 |
+
stopChan = make(chan bool, 2)
|
| 33 |
+
scanner = bufio.NewScanner(resp.Body)
|
| 34 |
+
ticker = time.NewTicker(streamingTimeout)
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
defer func() {
|
| 38 |
+
ticker.Stop()
|
| 39 |
+
close(stopChan)
|
| 40 |
+
}()
|
| 41 |
+
|
| 42 |
+
scanner.Split(bufio.ScanLines)
|
| 43 |
+
SetEventStreamHeaders(c)
|
| 44 |
+
|
| 45 |
+
ctx, cancel := context.WithCancel(context.Background())
|
| 46 |
+
defer cancel()
|
| 47 |
+
|
| 48 |
+
ctx = context.WithValue(ctx, "stop_chan", stopChan)
|
| 49 |
+
common.RelayCtxGo(ctx, func() {
|
| 50 |
+
for scanner.Scan() {
|
| 51 |
+
ticker.Reset(streamingTimeout)
|
| 52 |
+
data := scanner.Text()
|
| 53 |
+
if common.DebugEnabled {
|
| 54 |
+
println(data)
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
if len(data) < 6 {
|
| 58 |
+
continue
|
| 59 |
+
}
|
| 60 |
+
if data[:5] != "data:" && data[:6] != "[DONE]" {
|
| 61 |
+
continue
|
| 62 |
+
}
|
| 63 |
+
data = data[5:]
|
| 64 |
+
data = strings.TrimLeft(data, " ")
|
| 65 |
+
data = strings.TrimSuffix(data, "\"")
|
| 66 |
+
if !strings.HasPrefix(data, "[DONE]") {
|
| 67 |
+
info.SetFirstResponseTime()
|
| 68 |
+
success := dataHandler(data)
|
| 69 |
+
if !success {
|
| 70 |
+
break
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
if err := scanner.Err(); err != nil {
|
| 76 |
+
if err != io.EOF {
|
| 77 |
+
common.LogError(c, "scanner error: "+err.Error())
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
common.SafeSendBool(stopChan, true)
|
| 82 |
+
})
|
| 83 |
+
|
| 84 |
+
select {
|
| 85 |
+
case <-ticker.C:
|
| 86 |
+
// 超时处理逻辑
|
| 87 |
+
common.LogError(c, "streaming timeout")
|
| 88 |
+
case <-stopChan:
|
| 89 |
+
// 正常结束
|
| 90 |
+
}
|
| 91 |
+
}
|
relay/relay-mj.go
CHANGED
|
@@ -157,10 +157,10 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|
| 157 |
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
| 158 |
}
|
| 159 |
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
| 160 |
-
modelPrice, success :=
|
| 161 |
// 如果没有配置价格,则使用默认价格
|
| 162 |
if !success {
|
| 163 |
-
defaultPrice, ok :=
|
| 164 |
if !ok {
|
| 165 |
modelPrice = 0.1
|
| 166 |
} else {
|
|
@@ -463,10 +463,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
| 463 |
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
| 464 |
|
| 465 |
modelName := service.CoverActionToModelName(midjRequest.Action)
|
| 466 |
-
modelPrice, success :=
|
| 467 |
// 如果没有配置价格,则使用默认价格
|
| 468 |
if !success {
|
| 469 |
-
defaultPrice, ok :=
|
| 470 |
if !ok {
|
| 471 |
modelPrice = 0.1
|
| 472 |
} else {
|
|
|
|
| 157 |
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
| 158 |
}
|
| 159 |
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
| 160 |
+
modelPrice, success := setting.GetModelPrice(modelName, true)
|
| 161 |
// 如果没有配置价格,则使用默认价格
|
| 162 |
if !success {
|
| 163 |
+
defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName]
|
| 164 |
if !ok {
|
| 165 |
modelPrice = 0.1
|
| 166 |
} else {
|
|
|
|
| 463 |
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
| 464 |
|
| 465 |
modelName := service.CoverActionToModelName(midjRequest.Action)
|
| 466 |
+
modelPrice, success := setting.GetModelPrice(modelName, true)
|
| 467 |
// 如果没有配置价格,则使用默认价格
|
| 468 |
if !success {
|
| 469 |
+
defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName]
|
| 470 |
if !ok {
|
| 471 |
modelPrice = 0.1
|
| 472 |
} else {
|
relay/relay-text.go
CHANGED
|
@@ -311,7 +311,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|
| 311 |
modelName := relayInfo.OriginModelName
|
| 312 |
|
| 313 |
tokenName := ctx.GetString("token_name")
|
| 314 |
-
completionRatio :=
|
| 315 |
ratio := priceData.ModelRatio * priceData.GroupRatio
|
| 316 |
modelRatio := priceData.ModelRatio
|
| 317 |
groupRatio := priceData.GroupRatio
|
|
|
|
| 311 |
modelName := relayInfo.OriginModelName
|
| 312 |
|
| 313 |
tokenName := ctx.GetString("token_name")
|
| 314 |
+
completionRatio := setting.GetCompletionRatio(modelName)
|
| 315 |
ratio := priceData.ModelRatio * priceData.GroupRatio
|
| 316 |
modelRatio := priceData.ModelRatio
|
| 317 |
groupRatio := priceData.GroupRatio
|
relay/relay_task.go
CHANGED
|
@@ -37,9 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|
| 37 |
}
|
| 38 |
|
| 39 |
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
| 40 |
-
modelPrice, success :=
|
| 41 |
if !success {
|
| 42 |
-
defaultPrice, ok :=
|
| 43 |
if !ok {
|
| 44 |
modelPrice = 0.1
|
| 45 |
} else {
|
|
|
|
| 37 |
}
|
| 38 |
|
| 39 |
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
| 40 |
+
modelPrice, success := setting.GetModelPrice(modelName, true)
|
| 41 |
if !success {
|
| 42 |
+
defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName]
|
| 43 |
if !ok {
|
| 44 |
modelPrice = 0.1
|
| 45 |
} else {
|
relay/websocket.go
CHANGED
|
@@ -39,7 +39,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
|
| 39 |
}
|
| 40 |
}
|
| 41 |
//relayInfo.UpstreamModelName = textRequest.Model
|
| 42 |
-
modelPrice, getModelPriceSuccess :=
|
| 43 |
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
| 44 |
|
| 45 |
var preConsumedQuota int
|
|
@@ -65,7 +65,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
|
| 65 |
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
|
| 66 |
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
|
| 67 |
//}
|
| 68 |
-
modelRatio, _ =
|
| 69 |
ratio = modelRatio * groupRatio
|
| 70 |
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
| 71 |
} else {
|
|
|
|
| 39 |
}
|
| 40 |
}
|
| 41 |
//relayInfo.UpstreamModelName = textRequest.Model
|
| 42 |
+
modelPrice, getModelPriceSuccess := setting.GetModelPrice(relayInfo.UpstreamModelName, false)
|
| 43 |
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
| 44 |
|
| 45 |
var preConsumedQuota int
|
|
|
|
| 65 |
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
|
| 66 |
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
|
| 67 |
//}
|
| 68 |
+
modelRatio, _ = setting.GetModelRatio(relayInfo.UpstreamModelName)
|
| 69 |
ratio = modelRatio * groupRatio
|
| 70 |
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
| 71 |
} else {
|
service/channel.go
CHANGED
|
@@ -6,23 +6,31 @@ import (
|
|
| 6 |
"one-api/common"
|
| 7 |
"one-api/dto"
|
| 8 |
"one-api/model"
|
| 9 |
-
"one-api/setting"
|
| 10 |
"strings"
|
| 11 |
)
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
// disable & notify
|
| 14 |
func DisableChannel(channelId int, channelName string, reason string) {
|
| 15 |
-
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
}
|
| 20 |
|
| 21 |
func EnableChannel(channelId int, channelName string) {
|
| 22 |
-
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
| 26 |
}
|
| 27 |
|
| 28 |
func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
|
|
@@ -67,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
|
|
| 67 |
}
|
| 68 |
|
| 69 |
lowerMessage := strings.ToLower(err.Error.Message)
|
| 70 |
-
search, _ := AcSearch(lowerMessage,
|
| 71 |
if search {
|
| 72 |
return true
|
| 73 |
}
|
|
|
|
| 6 |
"one-api/common"
|
| 7 |
"one-api/dto"
|
| 8 |
"one-api/model"
|
| 9 |
+
"one-api/setting/operation_setting"
|
| 10 |
"strings"
|
| 11 |
)
|
| 12 |
|
| 13 |
+
func formatNotifyType(channelId int, status int) string {
|
| 14 |
+
return fmt.Sprintf("%s_%d_%d", dto.NotifyTypeChannelUpdate, channelId, status)
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
// disable & notify
|
| 18 |
func DisableChannel(channelId int, channelName string, reason string) {
|
| 19 |
+
success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
|
| 20 |
+
if success {
|
| 21 |
+
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
| 22 |
+
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
| 23 |
+
NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content)
|
| 24 |
+
}
|
| 25 |
}
|
| 26 |
|
| 27 |
func EnableChannel(channelId int, channelName string) {
|
| 28 |
+
success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
|
| 29 |
+
if success {
|
| 30 |
+
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
| 31 |
+
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
| 32 |
+
NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusEnabled), subject, content)
|
| 33 |
+
}
|
| 34 |
}
|
| 35 |
|
| 36 |
func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
|
|
|
|
| 75 |
}
|
| 76 |
|
| 77 |
lowerMessage := strings.ToLower(err.Error.Message)
|
| 78 |
+
search, _ := AcSearch(lowerMessage, operation_setting.AutomaticDisableKeywords, true)
|
| 79 |
if search {
|
| 80 |
return true
|
| 81 |
}
|
service/image.go
CHANGED
|
@@ -7,7 +7,9 @@ import (
|
|
| 7 |
"fmt"
|
| 8 |
"image"
|
| 9 |
"io"
|
|
|
|
| 10 |
"one-api/common"
|
|
|
|
| 11 |
"strings"
|
| 12 |
|
| 13 |
"golang.org/x/image/webp"
|
|
@@ -23,7 +25,7 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
|
|
| 23 |
decodedData, err := base64.StdEncoding.DecodeString(base64String)
|
| 24 |
if err != nil {
|
| 25 |
fmt.Println("Error: Failed to decode base64 string")
|
| 26 |
-
return image.Config{}, "", "", err
|
| 27 |
}
|
| 28 |
|
| 29 |
// 创建一个bytes.Buffer用于存储解码后的数据
|
|
@@ -61,20 +63,51 @@ func DecodeBase64FileData(base64String string) (string, string, error) {
|
|
| 61 |
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
| 62 |
resp, err := DoDownloadRequest(url)
|
| 63 |
if err != nil {
|
| 64 |
-
return "", "", err
|
| 65 |
-
}
|
| 66 |
-
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
| 67 |
-
return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type"))
|
| 68 |
}
|
| 69 |
defer resp.Body.Close()
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
if err != nil {
|
| 73 |
-
return
|
|
|
|
|
|
|
|
|
|
| 74 |
}
|
| 75 |
-
|
| 76 |
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
}
|
| 79 |
|
| 80 |
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
|
@@ -92,7 +125,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
|
| 92 |
|
| 93 |
mimeType := response.Header.Get("Content-Type")
|
| 94 |
|
| 95 |
-
if !strings.HasPrefix(mimeType, "image/") {
|
| 96 |
return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
|
| 97 |
}
|
| 98 |
|
|
|
|
| 7 |
"fmt"
|
| 8 |
"image"
|
| 9 |
"io"
|
| 10 |
+
"net/http"
|
| 11 |
"one-api/common"
|
| 12 |
+
"one-api/constant"
|
| 13 |
"strings"
|
| 14 |
|
| 15 |
"golang.org/x/image/webp"
|
|
|
|
| 25 |
decodedData, err := base64.StdEncoding.DecodeString(base64String)
|
| 26 |
if err != nil {
|
| 27 |
fmt.Println("Error: Failed to decode base64 string")
|
| 28 |
+
return image.Config{}, "", "", fmt.Errorf("failed to decode base64 string: %s", err.Error())
|
| 29 |
}
|
| 30 |
|
| 31 |
// 创建一个bytes.Buffer用于存储解码后的数据
|
|
|
|
| 63 |
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
| 64 |
resp, err := DoDownloadRequest(url)
|
| 65 |
if err != nil {
|
| 66 |
+
return "", "", fmt.Errorf("failed to download image: %w", err)
|
|
|
|
|
|
|
|
|
|
| 67 |
}
|
| 68 |
defer resp.Body.Close()
|
| 69 |
+
|
| 70 |
+
// Check HTTP status code
|
| 71 |
+
if resp.StatusCode != http.StatusOK {
|
| 72 |
+
return "", "", fmt.Errorf("failed to download image: HTTP %d", resp.StatusCode)
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
contentType := resp.Header.Get("Content-Type")
|
| 76 |
+
if contentType != "application/octet-stream" && !strings.HasPrefix(contentType, "image/") {
|
| 77 |
+
return "", "", fmt.Errorf("invalid content type: %s, required image/*", contentType)
|
| 78 |
+
}
|
| 79 |
+
maxImageSize := int64(constant.MaxFileDownloadMB * 1024 * 1024)
|
| 80 |
+
|
| 81 |
+
// Check Content-Length if available
|
| 82 |
+
if resp.ContentLength > maxImageSize {
|
| 83 |
+
return "", "", fmt.Errorf("image size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxImageSize)
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// Use LimitReader to prevent reading oversized images
|
| 87 |
+
limitReader := io.LimitReader(resp.Body, maxImageSize)
|
| 88 |
+
buffer := &bytes.Buffer{}
|
| 89 |
+
|
| 90 |
+
written, err := io.Copy(buffer, limitReader)
|
| 91 |
if err != nil {
|
| 92 |
+
return "", "", fmt.Errorf("failed to read image data: %w", err)
|
| 93 |
+
}
|
| 94 |
+
if written >= maxImageSize {
|
| 95 |
+
return "", "", fmt.Errorf("image size exceeds maximum allowed size of %d bytes", maxImageSize)
|
| 96 |
}
|
| 97 |
+
|
| 98 |
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
|
| 99 |
+
mimeType = contentType
|
| 100 |
+
|
| 101 |
+
// Handle application/octet-stream type
|
| 102 |
+
if mimeType == "application/octet-stream" {
|
| 103 |
+
_, format, _, err := DecodeBase64ImageData(data)
|
| 104 |
+
if err != nil {
|
| 105 |
+
return "", "", err
|
| 106 |
+
}
|
| 107 |
+
mimeType = "image/" + format
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
return mimeType, data, nil
|
| 111 |
}
|
| 112 |
|
| 113 |
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
|
|
|
| 125 |
|
| 126 |
mimeType := response.Header.Get("Content-Type")
|
| 127 |
|
| 128 |
+
if mimeType != "application/octet-stream" && !strings.HasPrefix(mimeType, "image/") {
|
| 129 |
return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
|
| 130 |
}
|
| 131 |
|
service/quota.go
CHANGED
|
@@ -38,9 +38,9 @@ func calculateAudioQuota(info QuotaInfo) int {
|
|
| 38 |
return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
|
| 39 |
}
|
| 40 |
|
| 41 |
-
completionRatio :=
|
| 42 |
-
audioRatio :=
|
| 43 |
-
audioCompletionRatio :=
|
| 44 |
ratio := info.GroupRatio * info.ModelRatio
|
| 45 |
|
| 46 |
quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
|
|
@@ -75,7 +75,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
|
| 75 |
audioInputTokens := usage.InputTokenDetails.AudioTokens
|
| 76 |
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
| 77 |
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
| 78 |
-
modelRatio, _ :=
|
| 79 |
|
| 80 |
quotaInfo := QuotaInfo{
|
| 81 |
InputDetails: TokenDetails{
|
|
@@ -122,9 +122,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
|
| 122 |
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
| 123 |
|
| 124 |
tokenName := ctx.GetString("token_name")
|
| 125 |
-
completionRatio :=
|
| 126 |
-
audioRatio :=
|
| 127 |
-
audioCompletionRatio :=
|
| 128 |
|
| 129 |
quotaInfo := QuotaInfo{
|
| 130 |
InputDetails: TokenDetails{
|
|
@@ -184,9 +184,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|
| 184 |
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
|
| 185 |
|
| 186 |
tokenName := ctx.GetString("token_name")
|
| 187 |
-
completionRatio :=
|
| 188 |
-
audioRatio :=
|
| 189 |
-
audioCompletionRatio :=
|
| 190 |
|
| 191 |
modelRatio := priceData.ModelRatio
|
| 192 |
groupRatio := priceData.GroupRatio
|
|
|
|
| 38 |
return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
|
| 39 |
}
|
| 40 |
|
| 41 |
+
completionRatio := setting.GetCompletionRatio(info.ModelName)
|
| 42 |
+
audioRatio := setting.GetAudioRatio(info.ModelName)
|
| 43 |
+
audioCompletionRatio := setting.GetAudioCompletionRatio(info.ModelName)
|
| 44 |
ratio := info.GroupRatio * info.ModelRatio
|
| 45 |
|
| 46 |
quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
|
|
|
|
| 75 |
audioInputTokens := usage.InputTokenDetails.AudioTokens
|
| 76 |
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
| 77 |
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
| 78 |
+
modelRatio, _ := setting.GetModelRatio(modelName)
|
| 79 |
|
| 80 |
quotaInfo := QuotaInfo{
|
| 81 |
InputDetails: TokenDetails{
|
|
|
|
| 122 |
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
| 123 |
|
| 124 |
tokenName := ctx.GetString("token_name")
|
| 125 |
+
completionRatio := setting.GetCompletionRatio(modelName)
|
| 126 |
+
audioRatio := setting.GetAudioRatio(relayInfo.OriginModelName)
|
| 127 |
+
audioCompletionRatio := setting.GetAudioCompletionRatio(modelName)
|
| 128 |
|
| 129 |
quotaInfo := QuotaInfo{
|
| 130 |
InputDetails: TokenDetails{
|
|
|
|
| 184 |
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
|
| 185 |
|
| 186 |
tokenName := ctx.GetString("token_name")
|
| 187 |
+
completionRatio := setting.GetCompletionRatio(relayInfo.OriginModelName)
|
| 188 |
+
audioRatio := setting.GetAudioRatio(relayInfo.OriginModelName)
|
| 189 |
+
audioCompletionRatio := setting.GetAudioCompletionRatio(relayInfo.OriginModelName)
|
| 190 |
|
| 191 |
modelRatio := priceData.ModelRatio
|
| 192 |
groupRatio := priceData.GroupRatio
|
service/token_counter.go
CHANGED
|
@@ -10,6 +10,7 @@ import (
|
|
| 10 |
"one-api/constant"
|
| 11 |
"one-api/dto"
|
| 12 |
relaycommon "one-api/relay/common"
|
|
|
|
| 13 |
"strings"
|
| 14 |
"unicode/utf8"
|
| 15 |
|
|
@@ -32,7 +33,7 @@ func InitTokenEncoders() {
|
|
| 32 |
if err != nil {
|
| 33 |
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
|
| 34 |
}
|
| 35 |
-
for model, _ := range
|
| 36 |
if strings.HasPrefix(model, "gpt-3.5") {
|
| 37 |
tokenEncoderMap[model] = cl100TokenEncoder
|
| 38 |
} else if strings.HasPrefix(model, "gpt-4") {
|
|
|
|
| 10 |
"one-api/constant"
|
| 11 |
"one-api/dto"
|
| 12 |
relaycommon "one-api/relay/common"
|
| 13 |
+
"one-api/setting"
|
| 14 |
"strings"
|
| 15 |
"unicode/utf8"
|
| 16 |
|
|
|
|
| 33 |
if err != nil {
|
| 34 |
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
|
| 35 |
}
|
| 36 |
+
for model, _ := range setting.GetDefaultModelRatioMap() {
|
| 37 |
if strings.HasPrefix(model, "gpt-3.5") {
|
| 38 |
tokenEncoderMap[model] = cl100TokenEncoder
|
| 39 |
} else if strings.HasPrefix(model, "gpt-4") {
|
service/user_notify.go
CHANGED
|
@@ -11,7 +11,10 @@ import (
|
|
| 11 |
|
| 12 |
func NotifyRootUser(t string, subject string, content string) {
|
| 13 |
user := model.GetRootUser().ToBaseUser()
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
| 15 |
}
|
| 16 |
|
| 17 |
func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
|
|
|
|
| 11 |
|
| 12 |
func NotifyRootUser(t string, subject string, content string) {
|
| 13 |
user := model.GetRootUser().ToBaseUser()
|
| 14 |
+
err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
|
| 15 |
+
if err != nil {
|
| 16 |
+
common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error()))
|
| 17 |
+
}
|
| 18 |
}
|
| 19 |
|
| 20 |
func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
|
{common → setting}/model-ratio.go
RENAMED
|
@@ -1,7 +1,9 @@
|
|
| 1 |
-
package
|
| 2 |
|
| 3 |
import (
|
| 4 |
"encoding/json"
|
|
|
|
|
|
|
| 5 |
"strings"
|
| 6 |
"sync"
|
| 7 |
)
|
|
@@ -261,7 +263,7 @@ func ModelPrice2JSONString() string {
|
|
| 261 |
GetModelPriceMap()
|
| 262 |
jsonBytes, err := json.Marshal(modelPriceMap)
|
| 263 |
if err != nil {
|
| 264 |
-
SysError("error marshalling model price: " + err.Error())
|
| 265 |
}
|
| 266 |
return string(jsonBytes)
|
| 267 |
}
|
|
@@ -285,7 +287,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
|
|
| 285 |
price, ok := modelPriceMap[name]
|
| 286 |
if !ok {
|
| 287 |
if printErr {
|
| 288 |
-
SysError("model price not found: " + name)
|
| 289 |
}
|
| 290 |
return -1, false
|
| 291 |
}
|
|
@@ -305,7 +307,7 @@ func ModelRatio2JSONString() string {
|
|
| 305 |
GetModelRatioMap()
|
| 306 |
jsonBytes, err := json.Marshal(modelRatioMap)
|
| 307 |
if err != nil {
|
| 308 |
-
SysError("error marshalling model ratio: " + err.Error())
|
| 309 |
}
|
| 310 |
return string(jsonBytes)
|
| 311 |
}
|
|
@@ -324,8 +326,7 @@ func GetModelRatio(name string) (float64, bool) {
|
|
| 324 |
}
|
| 325 |
ratio, ok := modelRatioMap[name]
|
| 326 |
if !ok {
|
| 327 |
-
|
| 328 |
-
return 37.5, false
|
| 329 |
}
|
| 330 |
return ratio, true
|
| 331 |
}
|
|
@@ -333,7 +334,7 @@ func GetModelRatio(name string) (float64, bool) {
|
|
| 333 |
func DefaultModelRatio2JSONString() string {
|
| 334 |
jsonBytes, err := json.Marshal(defaultModelRatio)
|
| 335 |
if err != nil {
|
| 336 |
-
SysError("error marshalling model ratio: " + err.Error())
|
| 337 |
}
|
| 338 |
return string(jsonBytes)
|
| 339 |
}
|
|
@@ -355,7 +356,7 @@ func CompletionRatio2JSONString() string {
|
|
| 355 |
GetCompletionRatioMap()
|
| 356 |
jsonBytes, err := json.Marshal(CompletionRatio)
|
| 357 |
if err != nil {
|
| 358 |
-
SysError("error marshalling completion ratio: " + err.Error())
|
| 359 |
}
|
| 360 |
return string(jsonBytes)
|
| 361 |
}
|
|
|
|
| 1 |
+
package setting
|
| 2 |
|
| 3 |
import (
|
| 4 |
"encoding/json"
|
| 5 |
+
"one-api/common"
|
| 6 |
+
"one-api/setting/operation_setting"
|
| 7 |
"strings"
|
| 8 |
"sync"
|
| 9 |
)
|
|
|
|
| 263 |
GetModelPriceMap()
|
| 264 |
jsonBytes, err := json.Marshal(modelPriceMap)
|
| 265 |
if err != nil {
|
| 266 |
+
common.SysError("error marshalling model price: " + err.Error())
|
| 267 |
}
|
| 268 |
return string(jsonBytes)
|
| 269 |
}
|
|
|
|
| 287 |
price, ok := modelPriceMap[name]
|
| 288 |
if !ok {
|
| 289 |
if printErr {
|
| 290 |
+
common.SysError("model price not found: " + name)
|
| 291 |
}
|
| 292 |
return -1, false
|
| 293 |
}
|
|
|
|
| 307 |
GetModelRatioMap()
|
| 308 |
jsonBytes, err := json.Marshal(modelRatioMap)
|
| 309 |
if err != nil {
|
| 310 |
+
common.SysError("error marshalling model ratio: " + err.Error())
|
| 311 |
}
|
| 312 |
return string(jsonBytes)
|
| 313 |
}
|
|
|
|
| 326 |
}
|
| 327 |
ratio, ok := modelRatioMap[name]
|
| 328 |
if !ok {
|
| 329 |
+
return 37.5, operation_setting.SelfUseModeEnabled
|
|
|
|
| 330 |
}
|
| 331 |
return ratio, true
|
| 332 |
}
|
|
|
|
| 334 |
func DefaultModelRatio2JSONString() string {
|
| 335 |
jsonBytes, err := json.Marshal(defaultModelRatio)
|
| 336 |
if err != nil {
|
| 337 |
+
common.SysError("error marshalling model ratio: " + err.Error())
|
| 338 |
}
|
| 339 |
return string(jsonBytes)
|
| 340 |
}
|
|
|
|
| 356 |
GetCompletionRatioMap()
|
| 357 |
jsonBytes, err := json.Marshal(CompletionRatio)
|
| 358 |
if err != nil {
|
| 359 |
+
common.SysError("error marshalling completion ratio: " + err.Error())
|
| 360 |
}
|
| 361 |
return string(jsonBytes)
|
| 362 |
}
|
setting/{operation_setting.go → operation_setting/operation_setting.go}
RENAMED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
-
package
|
| 2 |
|
| 3 |
import "strings"
|
| 4 |
|
| 5 |
var DemoSiteEnabled = false
|
|
|
|
| 6 |
|
| 7 |
var AutomaticDisableKeywords = []string{
|
| 8 |
"Your credit balance is too low",
|
|
|
|
| 1 |
+
package operation_setting
|
| 2 |
|
| 3 |
import "strings"
|
| 4 |
|
| 5 |
var DemoSiteEnabled = false
|
| 6 |
+
var SelfUseModeEnabled = false
|
| 7 |
|
| 8 |
var AutomaticDisableKeywords = []string{
|
| 9 |
"Your credit balance is too low",
|
web/src/App.js
CHANGED
|
@@ -30,6 +30,7 @@ import { useTranslation } from 'react-i18next';
|
|
| 30 |
import { StatusContext } from './context/Status';
|
| 31 |
import { setStatusData } from './helpers/data.js';
|
| 32 |
import { API, showError } from './helpers';
|
|
|
|
| 33 |
|
| 34 |
const Home = lazy(() => import('./pages/Home'));
|
| 35 |
const Detail = lazy(() => import('./pages/Detail'));
|
|
@@ -177,6 +178,16 @@ function App() {
|
|
| 177 |
</PrivateRoute>
|
| 178 |
}
|
| 179 |
/>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
<Route
|
| 181 |
path='/topup'
|
| 182 |
element={
|
|
|
|
| 30 |
import { StatusContext } from './context/Status';
|
| 31 |
import { setStatusData } from './helpers/data.js';
|
| 32 |
import { API, showError } from './helpers';
|
| 33 |
+
import PersonalSetting from './components/PersonalSetting.js';
|
| 34 |
|
| 35 |
const Home = lazy(() => import('./pages/Home'));
|
| 36 |
const Detail = lazy(() => import('./pages/Detail'));
|
|
|
|
| 178 |
</PrivateRoute>
|
| 179 |
}
|
| 180 |
/>
|
| 181 |
+
<Route
|
| 182 |
+
path='/personal'
|
| 183 |
+
element={
|
| 184 |
+
<PrivateRoute>
|
| 185 |
+
<Suspense fallback={<Loading></Loading>}>
|
| 186 |
+
<PersonalSetting />
|
| 187 |
+
</Suspense>
|
| 188 |
+
</PrivateRoute>
|
| 189 |
+
}
|
| 190 |
+
/>
|
| 191 |
<Route
|
| 192 |
path='/topup'
|
| 193 |
element={
|
web/src/components/ChannelsTable.js
CHANGED
|
@@ -15,7 +15,7 @@ import {
|
|
| 15 |
getQuotaPerUnit,
|
| 16 |
renderGroup,
|
| 17 |
renderNumberWithPoint,
|
| 18 |
-
renderQuota, renderQuotaWithPrompt
|
| 19 |
} from '../helpers/render';
|
| 20 |
import {
|
| 21 |
Button, Divider,
|
|
@@ -378,17 +378,15 @@ const ChannelsTable = () => {
|
|
| 378 |
>
|
| 379 |
{t('测试')}
|
| 380 |
</Button>
|
| 381 |
-
<
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
></Button>
|
| 391 |
-
</Dropdown>
|
| 392 |
</SplitButtonGroup>
|
| 393 |
<Popconfirm
|
| 394 |
title={t('确定是否要删除此渠道?')}
|
|
@@ -522,6 +520,9 @@ const ChannelsTable = () => {
|
|
| 522 |
const [enableTagMode, setEnableTagMode] = useState(false);
|
| 523 |
const [showBatchSetTag, setShowBatchSetTag] = useState(false);
|
| 524 |
const [batchSetTagValue, setBatchSetTagValue] = useState('');
|
|
|
|
|
|
|
|
|
|
| 525 |
|
| 526 |
|
| 527 |
const removeRecord = (record) => {
|
|
@@ -1289,6 +1290,77 @@ const ChannelsTable = () => {
|
|
| 1289 |
onChange={(v) => setBatchSetTagValue(v)}
|
| 1290 |
/>
|
| 1291 |
</Modal>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1292 |
</>
|
| 1293 |
);
|
| 1294 |
};
|
|
|
|
| 15 |
getQuotaPerUnit,
|
| 16 |
renderGroup,
|
| 17 |
renderNumberWithPoint,
|
| 18 |
+
renderQuota, renderQuotaWithPrompt, stringToColor
|
| 19 |
} from '../helpers/render';
|
| 20 |
import {
|
| 21 |
Button, Divider,
|
|
|
|
| 378 |
>
|
| 379 |
{t('测试')}
|
| 380 |
</Button>
|
| 381 |
+
<Button
|
| 382 |
+
style={{ padding: '8px 4px' }}
|
| 383 |
+
type="primary"
|
| 384 |
+
icon={<IconTreeTriangleDown />}
|
| 385 |
+
onClick={() => {
|
| 386 |
+
setCurrentTestChannel(record);
|
| 387 |
+
setShowModelTestModal(true);
|
| 388 |
+
}}
|
| 389 |
+
></Button>
|
|
|
|
|
|
|
| 390 |
</SplitButtonGroup>
|
| 391 |
<Popconfirm
|
| 392 |
title={t('确定是否要删除此渠道?')}
|
|
|
|
| 520 |
const [enableTagMode, setEnableTagMode] = useState(false);
|
| 521 |
const [showBatchSetTag, setShowBatchSetTag] = useState(false);
|
| 522 |
const [batchSetTagValue, setBatchSetTagValue] = useState('');
|
| 523 |
+
const [showModelTestModal, setShowModelTestModal] = useState(false);
|
| 524 |
+
const [currentTestChannel, setCurrentTestChannel] = useState(null);
|
| 525 |
+
const [modelSearchKeyword, setModelSearchKeyword] = useState('');
|
| 526 |
|
| 527 |
|
| 528 |
const removeRecord = (record) => {
|
|
|
|
| 1290 |
onChange={(v) => setBatchSetTagValue(v)}
|
| 1291 |
/>
|
| 1292 |
</Modal>
|
| 1293 |
+
|
| 1294 |
+
{/* 模型测试弹窗 */}
|
| 1295 |
+
<Modal
|
| 1296 |
+
title={t('选择模型进行测试')}
|
| 1297 |
+
visible={showModelTestModal && currentTestChannel !== null}
|
| 1298 |
+
onCancel={() => {
|
| 1299 |
+
setShowModelTestModal(false);
|
| 1300 |
+
setModelSearchKeyword('');
|
| 1301 |
+
}}
|
| 1302 |
+
footer={null}
|
| 1303 |
+
maskClosable={true}
|
| 1304 |
+
centered={true}
|
| 1305 |
+
width={600}
|
| 1306 |
+
>
|
| 1307 |
+
<div style={{ maxHeight: '500px', overflowY: 'auto', padding: '10px' }}>
|
| 1308 |
+
{currentTestChannel && (
|
| 1309 |
+
<div>
|
| 1310 |
+
<Typography.Title heading={6} style={{ marginBottom: '16px' }}>
|
| 1311 |
+
{t('渠道')}: {currentTestChannel.name}
|
| 1312 |
+
</Typography.Title>
|
| 1313 |
+
|
| 1314 |
+
{/* 搜索框 */}
|
| 1315 |
+
<Input
|
| 1316 |
+
placeholder={t('搜索模型...')}
|
| 1317 |
+
value={modelSearchKeyword}
|
| 1318 |
+
onChange={(value) => setModelSearchKeyword(value)}
|
| 1319 |
+
style={{ marginBottom: '16px' }}
|
| 1320 |
+
showClear
|
| 1321 |
+
/>
|
| 1322 |
+
|
| 1323 |
+
<div style={{
|
| 1324 |
+
display: 'grid',
|
| 1325 |
+
gridTemplateColumns: 'repeat(auto-fill, minmax(180px, 1fr))',
|
| 1326 |
+
gap: '10px'
|
| 1327 |
+
}}>
|
| 1328 |
+
{currentTestChannel.models.split(',')
|
| 1329 |
+
.filter(model => model.toLowerCase().includes(modelSearchKeyword.toLowerCase()))
|
| 1330 |
+
.map((model, index) => {
|
| 1331 |
+
|
| 1332 |
+
return (
|
| 1333 |
+
<Button
|
| 1334 |
+
key={index}
|
| 1335 |
+
theme="light"
|
| 1336 |
+
type="tertiary"
|
| 1337 |
+
style={{
|
| 1338 |
+
height: 'auto',
|
| 1339 |
+
padding: '8px 12px',
|
| 1340 |
+
textAlign: 'center',
|
| 1341 |
+
}}
|
| 1342 |
+
onClick={() => {
|
| 1343 |
+
testChannel(currentTestChannel, model);
|
| 1344 |
+
}}
|
| 1345 |
+
>
|
| 1346 |
+
{model}
|
| 1347 |
+
</Button>
|
| 1348 |
+
);
|
| 1349 |
+
})}
|
| 1350 |
+
</div>
|
| 1351 |
+
|
| 1352 |
+
{/* 显示搜索结果数量 */}
|
| 1353 |
+
{modelSearchKeyword && (
|
| 1354 |
+
<Typography.Text type="secondary" style={{ marginTop: '16px', display: 'block' }}>
|
| 1355 |
+
{t('找到')} {currentTestChannel.models.split(',').filter(model =>
|
| 1356 |
+
model.toLowerCase().includes(modelSearchKeyword.toLowerCase())
|
| 1357 |
+
).length} {t('个模型')}
|
| 1358 |
+
</Typography.Text>
|
| 1359 |
+
)}
|
| 1360 |
+
</div>
|
| 1361 |
+
)}
|
| 1362 |
+
</div>
|
| 1363 |
+
</Modal>
|
| 1364 |
</>
|
| 1365 |
);
|
| 1366 |
};
|
web/src/components/HeaderBar.js
CHANGED
|
@@ -21,15 +21,17 @@ import {
|
|
| 21 |
IconUser,
|
| 22 |
IconLanguage
|
| 23 |
} from '@douyinfe/semi-icons';
|
| 24 |
-
import { Avatar, Button, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui';
|
| 25 |
import { stringToColor } from '../helpers/render';
|
| 26 |
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
|
| 27 |
import { StyleContext } from '../context/Style/index.js';
|
|
|
|
| 28 |
|
| 29 |
const HeaderBar = () => {
|
| 30 |
const { t, i18n } = useTranslation();
|
| 31 |
const [userState, userDispatch] = useContext(UserContext);
|
| 32 |
const [styleState, styleDispatch] = useContext(StyleContext);
|
|
|
|
| 33 |
let navigate = useNavigate();
|
| 34 |
const [currentLang, setCurrentLang] = useState(i18n.language);
|
| 35 |
|
|
@@ -40,6 +42,10 @@ const HeaderBar = () => {
|
|
| 40 |
const isNewYear =
|
| 41 |
(currentDate.getMonth() === 0 && currentDate.getDate() === 1);
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
let buttons = [
|
| 44 |
{
|
| 45 |
text: t('首页'),
|
|
@@ -166,7 +172,7 @@ const HeaderBar = () => {
|
|
| 166 |
onSelect={(key) => {}}
|
| 167 |
header={styleState.isMobile?{
|
| 168 |
logo: (
|
| 169 |
-
|
| 170 |
{
|
| 171 |
!styleState.showSider ?
|
| 172 |
<Button icon={<IconMenu />} theme="light" aria-label={t('展开侧边栏')} onClick={
|
|
@@ -176,13 +182,52 @@ const HeaderBar = () => {
|
|
| 176 |
() => styleDispatch({ type: 'SET_SIDER', payload: false })
|
| 177 |
} />
|
| 178 |
}
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
),
|
| 181 |
}:{
|
| 182 |
logo: (
|
| 183 |
<img src={logo} alt='logo' />
|
| 184 |
),
|
| 185 |
-
text:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
}}
|
| 187 |
items={buttons}
|
| 188 |
footer={
|
|
@@ -266,7 +311,8 @@ const HeaderBar = () => {
|
|
| 266 |
icon={<IconUser />}
|
| 267 |
/>
|
| 268 |
{
|
| 269 |
-
|
|
|
|
| 270 |
<Nav.Item
|
| 271 |
itemKey={'register'}
|
| 272 |
text={t('注册')}
|
|
|
|
| 21 |
IconUser,
|
| 22 |
IconLanguage
|
| 23 |
} from '@douyinfe/semi-icons';
|
| 24 |
+
import { Avatar, Button, Dropdown, Layout, Nav, Switch, Tag } from '@douyinfe/semi-ui';
|
| 25 |
import { stringToColor } from '../helpers/render';
|
| 26 |
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
|
| 27 |
import { StyleContext } from '../context/Style/index.js';
|
| 28 |
+
import { StatusContext } from '../context/Status/index.js';
|
| 29 |
|
| 30 |
const HeaderBar = () => {
|
| 31 |
const { t, i18n } = useTranslation();
|
| 32 |
const [userState, userDispatch] = useContext(UserContext);
|
| 33 |
const [styleState, styleDispatch] = useContext(StyleContext);
|
| 34 |
+
const [statusState, statusDispatch] = useContext(StatusContext);
|
| 35 |
let navigate = useNavigate();
|
| 36 |
const [currentLang, setCurrentLang] = useState(i18n.language);
|
| 37 |
|
|
|
|
| 42 |
const isNewYear =
|
| 43 |
(currentDate.getMonth() === 0 && currentDate.getDate() === 1);
|
| 44 |
|
| 45 |
+
// Check if self-use mode is enabled
|
| 46 |
+
const isSelfUseMode = statusState?.status?.self_use_mode_enabled || false;
|
| 47 |
+
const isDemoSiteMode = statusState?.status?.demo_site_enabled || false;
|
| 48 |
+
|
| 49 |
let buttons = [
|
| 50 |
{
|
| 51 |
text: t('首页'),
|
|
|
|
| 172 |
onSelect={(key) => {}}
|
| 173 |
header={styleState.isMobile?{
|
| 174 |
logo: (
|
| 175 |
+
<div style={{ display: 'flex', alignItems: 'center', position: 'relative' }}>
|
| 176 |
{
|
| 177 |
!styleState.showSider ?
|
| 178 |
<Button icon={<IconMenu />} theme="light" aria-label={t('展开侧边栏')} onClick={
|
|
|
|
| 182 |
() => styleDispatch({ type: 'SET_SIDER', payload: false })
|
| 183 |
} />
|
| 184 |
}
|
| 185 |
+
{(isSelfUseMode || isDemoSiteMode) && (
|
| 186 |
+
<Tag
|
| 187 |
+
color={isSelfUseMode ? 'purple' : 'blue'}
|
| 188 |
+
style={{
|
| 189 |
+
position: 'absolute',
|
| 190 |
+
top: '-8px',
|
| 191 |
+
right: '-15px',
|
| 192 |
+
fontSize: '0.7rem',
|
| 193 |
+
padding: '0 4px',
|
| 194 |
+
height: 'auto',
|
| 195 |
+
lineHeight: '1.2',
|
| 196 |
+
zIndex: 1,
|
| 197 |
+
pointerEvents: 'none'
|
| 198 |
+
}}
|
| 199 |
+
>
|
| 200 |
+
{isSelfUseMode ? t('自用模式') : t('演示站点')}
|
| 201 |
+
</Tag>
|
| 202 |
+
)}
|
| 203 |
+
</div>
|
| 204 |
),
|
| 205 |
}:{
|
| 206 |
logo: (
|
| 207 |
<img src={logo} alt='logo' />
|
| 208 |
),
|
| 209 |
+
text: (
|
| 210 |
+
<div style={{ position: 'relative', display: 'inline-block' }}>
|
| 211 |
+
{systemName}
|
| 212 |
+
{(isSelfUseMode || isDemoSiteMode) && (
|
| 213 |
+
<Tag
|
| 214 |
+
color={isSelfUseMode ? 'purple' : 'blue'}
|
| 215 |
+
style={{
|
| 216 |
+
position: 'absolute',
|
| 217 |
+
top: '-10px',
|
| 218 |
+
right: '-25px',
|
| 219 |
+
fontSize: '0.7rem',
|
| 220 |
+
padding: '0 4px',
|
| 221 |
+
whiteSpace: 'nowrap',
|
| 222 |
+
zIndex: 1,
|
| 223 |
+
boxShadow: '0 0 3px rgba(255, 255, 255, 0.7)'
|
| 224 |
+
}}
|
| 225 |
+
>
|
| 226 |
+
{isSelfUseMode ? t('自用模式') : t('演示站点')}
|
| 227 |
+
</Tag>
|
| 228 |
+
)}
|
| 229 |
+
</div>
|
| 230 |
+
),
|
| 231 |
}}
|
| 232 |
items={buttons}
|
| 233 |
footer={
|
|
|
|
| 311 |
icon={<IconUser />}
|
| 312 |
/>
|
| 313 |
{
|
| 314 |
+
// Hide register option in self-use mode
|
| 315 |
+
!styleState.isMobile && !isSelfUseMode && (
|
| 316 |
<Nav.Item
|
| 317 |
itemKey={'register'}
|
| 318 |
text={t('注册')}
|
web/src/components/OperationSetting.js
CHANGED
|
@@ -60,6 +60,7 @@ const OperationSetting = () => {
|
|
| 60 |
RetryTimes: 0,
|
| 61 |
Chats: "[]",
|
| 62 |
DemoSiteEnabled: false,
|
|
|
|
| 63 |
AutomaticDisableKeywords: '',
|
| 64 |
});
|
| 65 |
|
|
|
|
| 60 |
RetryTimes: 0,
|
| 61 |
Chats: "[]",
|
| 62 |
DemoSiteEnabled: false,
|
| 63 |
+
SelfUseModeEnabled: false,
|
| 64 |
AutomaticDisableKeywords: '',
|
| 65 |
});
|
| 66 |
|
web/src/components/OtherSetting.js
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
-
import React, { useEffect, useRef, useState } from 'react';
|
| 2 |
-
import { Banner, Button, Col, Form, Row } from '@douyinfe/semi-ui';
|
| 3 |
-
import { API, showError, showSuccess } from '../helpers';
|
| 4 |
import { marked } from 'marked';
|
| 5 |
import { useTranslation } from 'react-i18next';
|
|
|
|
|
|
|
| 6 |
|
| 7 |
const OtherSetting = () => {
|
| 8 |
const { t } = useTranslation();
|
|
@@ -16,6 +18,7 @@ const OtherSetting = () => {
|
|
| 16 |
});
|
| 17 |
let [loading, setLoading] = useState(false);
|
| 18 |
const [showUpdateModal, setShowUpdateModal] = useState(false);
|
|
|
|
| 19 |
const [updateData, setUpdateData] = useState({
|
| 20 |
tag_name: '',
|
| 21 |
content: '',
|
|
@@ -43,6 +46,7 @@ const OtherSetting = () => {
|
|
| 43 |
HomePageContent: false,
|
| 44 |
About: false,
|
| 45 |
Footer: false,
|
|
|
|
| 46 |
});
|
| 47 |
const handleInputChange = async (value, e) => {
|
| 48 |
const name = e.target.id;
|
|
@@ -145,23 +149,48 @@ const OtherSetting = () => {
|
|
| 145 |
}
|
| 146 |
};
|
| 147 |
|
| 148 |
-
const openGitHubRelease = () => {
|
| 149 |
-
window.location = 'https://github.com/songquanpeng/one-api/releases/latest';
|
| 150 |
-
};
|
| 151 |
-
|
| 152 |
const checkUpdate = async () => {
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
}
|
| 166 |
};
|
| 167 |
const getOptions = async () => {
|
|
@@ -186,9 +215,41 @@ const OtherSetting = () => {
|
|
| 186 |
getOptions();
|
| 187 |
}, []);
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
return (
|
| 190 |
<Row>
|
| 191 |
<Col span={24}>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
{/* 通用设置 */}
|
| 193 |
<Form
|
| 194 |
values={inputs}
|
|
@@ -282,28 +343,25 @@ const OtherSetting = () => {
|
|
| 282 |
</Form.Section>
|
| 283 |
</Form>
|
| 284 |
</Col>
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
{/* />*/}
|
| 305 |
-
{/* </Modal.Actions>*/}
|
| 306 |
-
{/*</Modal>*/}
|
| 307 |
</Row>
|
| 308 |
);
|
| 309 |
};
|
|
|
|
| 1 |
+
import React, { useContext, useEffect, useRef, useState } from 'react';
|
| 2 |
+
import { Banner, Button, Col, Form, Row, Modal, Space } from '@douyinfe/semi-ui';
|
| 3 |
+
import { API, showError, showSuccess, timestamp2string } from '../helpers';
|
| 4 |
import { marked } from 'marked';
|
| 5 |
import { useTranslation } from 'react-i18next';
|
| 6 |
+
import { StatusContext } from '../context/Status/index.js';
|
| 7 |
+
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
|
| 8 |
|
| 9 |
const OtherSetting = () => {
|
| 10 |
const { t } = useTranslation();
|
|
|
|
| 18 |
});
|
| 19 |
let [loading, setLoading] = useState(false);
|
| 20 |
const [showUpdateModal, setShowUpdateModal] = useState(false);
|
| 21 |
+
const [statusState, statusDispatch] = useContext(StatusContext);
|
| 22 |
const [updateData, setUpdateData] = useState({
|
| 23 |
tag_name: '',
|
| 24 |
content: '',
|
|
|
|
| 46 |
HomePageContent: false,
|
| 47 |
About: false,
|
| 48 |
Footer: false,
|
| 49 |
+
CheckUpdate: false
|
| 50 |
});
|
| 51 |
const handleInputChange = async (value, e) => {
|
| 52 |
const name = e.target.id;
|
|
|
|
| 149 |
}
|
| 150 |
};
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
const checkUpdate = async () => {
|
| 153 |
+
try {
|
| 154 |
+
setLoadingInput((loadingInput) => ({ ...loadingInput, CheckUpdate: true }));
|
| 155 |
+
// Use a CORS proxy to avoid direct cross-origin requests to GitHub API
|
| 156 |
+
// Option 1: Use a public CORS proxy service
|
| 157 |
+
// const proxyUrl = 'https://cors-anywhere.herokuapp.com/';
|
| 158 |
+
// const res = await API.get(
|
| 159 |
+
// `${proxyUrl}https://api.github.com/repos/Calcium-Ion/new-api/releases/latest`,
|
| 160 |
+
// );
|
| 161 |
+
|
| 162 |
+
// Option 2: Use the JSON proxy approach which often works better with GitHub API
|
| 163 |
+
const res = await fetch(
|
| 164 |
+
'https://api.github.com/repos/Calcium-Ion/new-api/releases/latest',
|
| 165 |
+
{
|
| 166 |
+
headers: {
|
| 167 |
+
'Accept': 'application/json',
|
| 168 |
+
'Content-Type': 'application/json',
|
| 169 |
+
// Adding User-Agent which is often required by GitHub API
|
| 170 |
+
'User-Agent': 'new-api-update-checker'
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
).then(response => response.json());
|
| 174 |
+
|
| 175 |
+
// Option 3: Use a local proxy endpoint
|
| 176 |
+
// Create a cached version of the response to avoid frequent GitHub API calls
|
| 177 |
+
// const res = await API.get('/api/status/github-latest-release');
|
| 178 |
+
|
| 179 |
+
const { tag_name, body } = res;
|
| 180 |
+
if (tag_name === statusState?.status?.version) {
|
| 181 |
+
showSuccess(`已是最新版本:${tag_name}`);
|
| 182 |
+
} else {
|
| 183 |
+
setUpdateData({
|
| 184 |
+
tag_name: tag_name,
|
| 185 |
+
content: marked.parse(body),
|
| 186 |
+
});
|
| 187 |
+
setShowUpdateModal(true);
|
| 188 |
+
}
|
| 189 |
+
} catch (error) {
|
| 190 |
+
console.error('Failed to check for updates:', error);
|
| 191 |
+
showError('检查更新失败,请稍后再试');
|
| 192 |
+
} finally {
|
| 193 |
+
setLoadingInput((loadingInput) => ({ ...loadingInput, CheckUpdate: false }));
|
| 194 |
}
|
| 195 |
};
|
| 196 |
const getOptions = async () => {
|
|
|
|
| 215 |
getOptions();
|
| 216 |
}, []);
|
| 217 |
|
| 218 |
+
// Function to open GitHub release page
|
| 219 |
+
const openGitHubRelease = () => {
|
| 220 |
+
window.open(`https://github.com/Calcium-Ion/new-api/releases/tag/${updateData.tag_name}`, '_blank');
|
| 221 |
+
};
|
| 222 |
+
|
| 223 |
+
const getStartTimeString = () => {
|
| 224 |
+
const timestamp = statusState?.status?.start_time;
|
| 225 |
+
return statusState.status ? timestamp2string(timestamp) : '';
|
| 226 |
+
};
|
| 227 |
+
|
| 228 |
return (
|
| 229 |
<Row>
|
| 230 |
<Col span={24}>
|
| 231 |
+
{/* 版本信息 */}
|
| 232 |
+
<Form style={{ marginBottom: 15 }}>
|
| 233 |
+
<Form.Section text={t('系统信息')}>
|
| 234 |
+
<Row>
|
| 235 |
+
<Col span={16}>
|
| 236 |
+
<Space>
|
| 237 |
+
<Text>
|
| 238 |
+
{t('当前版本')}:{statusState?.status?.version || t('未知')}
|
| 239 |
+
</Text>
|
| 240 |
+
<Button type="primary" onClick={checkUpdate} loading={loadingInput['CheckUpdate']}>
|
| 241 |
+
{t('检查更新')}
|
| 242 |
+
</Button>
|
| 243 |
+
</Space>
|
| 244 |
+
</Col>
|
| 245 |
+
</Row>
|
| 246 |
+
<Row>
|
| 247 |
+
<Col span={16}>
|
| 248 |
+
<Text>{t('启动时间')}:{getStartTimeString()}</Text>
|
| 249 |
+
</Col>
|
| 250 |
+
</Row>
|
| 251 |
+
</Form.Section>
|
| 252 |
+
</Form>
|
| 253 |
{/* 通用设置 */}
|
| 254 |
<Form
|
| 255 |
values={inputs}
|
|
|
|
| 343 |
</Form.Section>
|
| 344 |
</Form>
|
| 345 |
</Col>
|
| 346 |
+
<Modal
|
| 347 |
+
title={t('新版本') + ':' + updateData.tag_name}
|
| 348 |
+
visible={showUpdateModal}
|
| 349 |
+
onCancel={() => setShowUpdateModal(false)}
|
| 350 |
+
footer={[
|
| 351 |
+
<Button
|
| 352 |
+
key="details"
|
| 353 |
+
type="primary"
|
| 354 |
+
onClick={() => {
|
| 355 |
+
setShowUpdateModal(false);
|
| 356 |
+
openGitHubRelease();
|
| 357 |
+
}}
|
| 358 |
+
>
|
| 359 |
+
{t('详情')}
|
| 360 |
+
</Button>
|
| 361 |
+
]}
|
| 362 |
+
>
|
| 363 |
+
<div dangerouslySetInnerHTML={{ __html: updateData.content }}></div>
|
| 364 |
+
</Modal>
|
|
|
|
|
|
|
|
|
|
| 365 |
</Row>
|
| 366 |
);
|
| 367 |
};
|