diff --git a/README.en.md b/README.en.md index 51cf38bb252d65f9a27ef043d039605075f7957b..3885003f9bcd14554e73538587cb1f9c007cfe53 100644 --- a/README.en.md +++ b/README.en.md @@ -68,7 +68,7 @@ ## Model Support This version additionally supports: -1. Third-party model **gps** (gpt-4-gizmo-*) +1. Third-party model **gpts** (gpt-4-gizmo-*) 2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md) 3. Custom channels with full API URL support 4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md) @@ -162,7 +162,7 @@ docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtow ## Channel Retry Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**. -First retry uses same priority, second retry uses next priority, and so on. +If retry is enabled, the system will automatically use the next priority channel for the same request after a failed request. ### Cache Configuration 1. `REDIS_CONN_STRING`: Use Redis as cache diff --git a/VERSION b/VERSION index bab2e920752421bdf1b3550124d43709a8caabb6..c2a30d15aa262002b2038b12db9a8aeb51310379 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v0.4.8.8.3 \ No newline at end of file +v0.4.9.0 \ No newline at end of file diff --git a/common/gopool.go b/common/gopool.go new file mode 100644 index 0000000000000000000000000000000000000000..bf5df311986f822a61816db3bbed601587a14093 --- /dev/null +++ b/common/gopool.go @@ -0,0 +1,24 @@ +package common + +import ( + "context" + "fmt" + "github.com/bytedance/gopkg/util/gopool" + "math" +) + +var relayGoPool gopool.Pool + +func init() { + relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig()) + relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) { + if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok { + SafeSendBool(stopChan, true) + } + SysError(fmt.Sprintf("panic in gopool.RelayPool: %v", i)) + }) +} + +func RelayCtxGo(ctx context.Context, f func()) { + relayGoPool.CtxGo(ctx, f) +} diff --git a/controller/channel-test.go b/controller/channel-test.go index 23922073f0132db5dfc83d39c74194ae63b14b63..98623a76c13b92766e3e41ef77c004737b71875d 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -17,6 +17,7 @@ import ( "one-api/relay" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "strconv" "strings" @@ -72,18 +73,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } } - modelMapping := *channel.ModelMapping - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[testModel] != "" { - testModel = modelMap[testModel] - } - } - cache, err := model.GetUserCache(1) if err != nil { return err, nil @@ -97,7 +86,14 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr middleware.SetupContextForSelectedChannel(c, channel, testModel) - meta := relaycommon.GenRelayInfo(c) + info := relaycommon.GenRelayInfo(c) + + err = helper.ModelMappedHelper(c, info) + if err != nil { + return err, nil + } + testModel = info.UpstreamModelName + apiType, _ := constant.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) if adaptor == nil { @@ -105,12 +101,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } request := buildTestRequest(testModel) - meta.UpstreamModelName = testModel - common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta)) + common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info)) - adaptor.Init(meta) + adaptor.Init(info) - convertedRequest, err := adaptor.ConvertRequest(c, meta, request) + convertedRequest, err := adaptor.ConvertRequest(c, info, request) if err != nil { return err, nil } @@ -120,7 +115,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } requestBody := bytes.NewBuffer(jsonData) c.Request.Body = io.NopCloser(requestBody) - resp, err := adaptor.DoRequest(c, meta, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return err, nil } @@ -132,7 +127,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err } } - usageA, respErr := adaptor.DoResponse(c, httpResp, meta) + usageA, respErr := adaptor.DoResponse(c, httpResp, info) if respErr != nil { return fmt.Errorf("%s", respErr.Error.Message), respErr } @@ -145,29 +140,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr if err != nil { return err, nil } - modelPrice, usePrice := common.GetModelPrice(testModel, false) - modelRatio, success := common.GetModelRatio(testModel) - if !usePrice && !success { - return fmt.Errorf("模型 %s 倍率和价格均未设置", testModel), nil + info.PromptTokens = usage.PromptTokens + priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens)) + if err != nil { + return err, nil } - completionRatio := common.GetCompletionRatio(testModel) - ratio := modelRatio quota := 0 - if !usePrice { - quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio)) - quota = int(math.Round(float64(quota) * ratio)) - if ratio != 0 && quota <= 0 { + if !priceData.UsePrice { + quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio)) + quota = int(math.Round(float64(quota) * priceData.ModelRatio)) + if priceData.ModelRatio != 0 && quota <= 0 { quota = 1 } } else { - quota = int(modelPrice * common.QuotaPerUnit) + quota = int(priceData.ModelPrice * common.QuotaPerUnit) } tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() consumedTime := float64(milliseconds) / 1000.0 - other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice) - model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", - quota, "模型测试", 0, quota, int(consumedTime), false, "default", other) + other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, priceData.ModelPrice) + model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试", + quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) return nil, nil } diff --git a/controller/misc.go b/controller/misc.go index 1ea0c13317f3c4c205ab332ab4e784de483b9cb2..a451b5e3faa5172ae3bde0140ce4fb05f6fb4e05 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/model" "one-api/setting" + "one-api/setting/operation_setting" "strings" "github.com/gin-gonic/gin" @@ -66,7 +67,8 @@ func GetStatus(c *gin.Context) { "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "", "mj_notify_enabled": setting.MjNotifyEnabled, "chats": setting.Chats, - "demo_site_enabled": setting.DemoSiteEnabled, + "demo_site_enabled": operation_setting.DemoSiteEnabled, + "self_use_mode_enabled": operation_setting.SelfUseModeEnabled, }, }) return diff --git a/controller/pricing.go b/controller/pricing.go index d7af5a4c8341308e7d5f4347ba0a405426e509ba..97f27490956eb3119eca82eb8105392b91dc3751 100644 --- a/controller/pricing.go +++ b/controller/pricing.go @@ -2,7 +2,6 @@ package controller import ( "github.com/gin-gonic/gin" - "one-api/common" "one-api/model" "one-api/setting" ) @@ -40,7 +39,7 @@ func GetPricing(c *gin.Context) { } func ResetModelRatio(c *gin.Context) { - defaultStr := common.DefaultModelRatio2JSONString() + defaultStr := setting.DefaultModelRatio2JSONString() err := model.UpdateOption("ModelRatio", defaultStr) if err != nil { c.JSON(200, gin.H{ @@ -49,7 +48,7 @@ func ResetModelRatio(c *gin.Context) { }) return } - err = common.UpdateModelRatioByJSONString(defaultStr) + err = setting.UpdateModelRatioByJSONString(defaultStr) if err != nil { c.JSON(200, gin.H{ "success": false, diff --git a/controller/relay.go b/controller/relay.go index e27ebb80f5015fc5cdfd1b65a51f64618105e322..460599b54ee90e1be14959ad8d5a74758e52bba5 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -16,6 +16,7 @@ import ( "one-api/relay" "one-api/relay/constant" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "strings" ) @@ -41,15 +42,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode return err } -func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode { - var err *dto.OpenAIErrorWithStatusCode - switch relayMode { - default: - err = relay.TextHelper(c) - } - return err -} - func Relay(c *gin.Context) { relayMode := constant.Path2RelayMode(c.Request.URL.Path) requestId := c.GetString(common.RequestIdKey) @@ -110,7 +102,7 @@ func WssRelay(c *gin.Context) { if err != nil { openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError) - service.WssError(c, ws, openaiErr.Error) + helper.WssError(c, ws, openaiErr.Error) return } @@ -152,7 +144,7 @@ func WssRelay(c *gin.Context) { openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" } openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId) - service.WssError(c, ws, openaiErr.Error) + helper.WssError(c, ws, openaiErr.Error) } } diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 135e00058ebcdc34b5590acd47f566a505f2f393..bd5f9d2553963ed2bb0364372d8d957af3b977e3 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -51,7 +51,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max // 如果在时间窗口内已达到限制,拒绝请求 subTime := nowTime.Sub(oldTime).Seconds() if int64(subTime) < duration { - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) return false, nil } @@ -68,7 +68,7 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC now := time.Now().Format(timeFormat) rdb.LPush(ctx, key, now) rdb.LTrim(ctx, key, 0, int64(maxCount-1)) - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) } // Redis限流处理器 @@ -118,7 +118,7 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g // 内存限流处理器 func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { - inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) + inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) @@ -153,20 +153,23 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) // ModelRequestRateLimit 模型请求限流中间件 func ModelRequestRateLimit() func(c *gin.Context) { - // 如果未启用限流,直接放行 - if !setting.ModelRequestRateLimitEnabled { - return defNext - } + return func(c *gin.Context) { + // 在每个请求时检查是否启用限流 + if !setting.ModelRequestRateLimitEnabled { + c.Next() + return + } - // 计算限流参数 - duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) - totalMaxCount := setting.ModelRequestRateLimitCount - successMaxCount := setting.ModelRequestRateLimitSuccessCount + // 计算限流参数 + duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) + totalMaxCount := setting.ModelRequestRateLimitCount + successMaxCount := setting.ModelRequestRateLimitSuccessCount - // 根据存储类型选择限流处理器 - if common.RedisEnabled { - return redisRateLimitHandler(duration, totalMaxCount, successMaxCount) - } else { - return memoryRateLimitHandler(duration, totalMaxCount, successMaxCount) + // 根据存储类型选择并执行限流处理器 + if common.RedisEnabled { + redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) + } else { + memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) + } } } diff --git a/model/channel.go b/model/channel.go index 12186712d9c0b69e37f0a2fafda340fff87b2f4a..6ff0901d9c3086bfdc1d2d6118b0cfa7e1bb2e7b 100644 --- a/model/channel.go +++ b/model/channel.go @@ -290,35 +290,42 @@ func (channel *Channel) Delete() error { var channelStatusLock sync.Mutex -func UpdateChannelStatusById(id int, status int, reason string) { +func UpdateChannelStatusById(id int, status int, reason string) bool { if common.MemoryCacheEnabled { channelStatusLock.Lock() + defer channelStatusLock.Unlock() + channelCache, _ := CacheGetChannel(id) // 如果缓存渠道存在,且状态已是目标状态,直接返回 if channelCache != nil && channelCache.Status == status { - channelStatusLock.Unlock() - return + return false } // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回 if channelCache == nil && status != common.ChannelStatusEnabled { - channelStatusLock.Unlock() - return + return false } CacheUpdateChannelStatus(id, status) - channelStatusLock.Unlock() } err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) if err != nil { common.SysError("failed to update ability status: " + err.Error()) + return false } channel, err := GetChannelById(id, true) if err != nil { // find channel by id error, directly update status - err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error - if err != nil { - common.SysError("failed to update channel status: " + err.Error()) + result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status) + if result.Error != nil { + common.SysError("failed to update channel status: " + result.Error.Error()) + return false + } + if result.RowsAffected == 0 { + return false } } else { + if channel.Status == status { + return false + } // find channel by id success, update status and other info info := channel.GetOtherInfo() info["status_reason"] = reason @@ -328,9 +335,10 @@ func UpdateChannelStatusById(id int, status int, reason string) { err = channel.Save() if err != nil { common.SysError("failed to update channel status: " + err.Error()) + return false } } - + return true } func EnableChannelByTag(tag string) error { diff --git a/model/log.go b/model/log.go index ed7ec2c796ff54134ff529462bf71746f7856ce8..86850a55a44c60d998786c867bf62bbcd894c8bd 100644 --- a/model/log.go +++ b/model/log.go @@ -2,12 +2,13 @@ package model import ( "fmt" - "github.com/gin-gonic/gin" "one-api/common" "os" "strings" "time" + "github.com/gin-gonic/gin" + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" ) @@ -18,7 +19,7 @@ type Log struct { CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` Type int `json:"type" gorm:"index:idx_created_at_type"` Content string `json:"content"` - Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` + Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"` TokenName string `json:"token_name" gorm:"index;default:''"` ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` Quota int `json:"quota" gorm:"default:0"` diff --git a/model/option.go b/model/option.go index 64d15ca8e3a66f1147278c3d23c07e822c758d96..a184c069eaf543a5c3f5ffab3abc8957359f89b3 100644 --- a/model/option.go +++ b/model/option.go @@ -4,6 +4,7 @@ import ( "one-api/common" "one-api/setting" "one-api/setting/config" + "one-api/setting/operation_setting" "strconv" "strings" "time" @@ -87,15 +88,15 @@ func InitOptionMap() { common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) - common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) + common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) - common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() - common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() + common.OptionMap["ModelRatio"] = setting.ModelRatio2JSONString() + common.OptionMap["ModelPrice"] = setting.ModelPrice2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() - common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() + common.OptionMap["CompletionRatio"] = setting.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["ChatLink2"] = common.ChatLink2 @@ -110,13 +111,14 @@ func InitOptionMap() { common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled) common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled) common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) - common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled) + common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled) + common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled) common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled) common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) - common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString() + common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() // 自动添加所有注册的模型配置 modelConfigs := config.GlobalConfig.ExportAllConfigs() @@ -242,7 +244,9 @@ func updateOptionMap(key string, value string) (err error) { case "CheckSensitiveEnabled": setting.CheckSensitiveEnabled = boolValue case "DemoSiteEnabled": - setting.DemoSiteEnabled = boolValue + operation_setting.DemoSiteEnabled = boolValue + case "SelfUseModeEnabled": + operation_setting.SelfUseModeEnabled = boolValue case "CheckSensitiveOnPromptEnabled": setting.CheckSensitiveOnPromptEnabled = boolValue case "ModelRequestRateLimitEnabled": @@ -325,7 +329,7 @@ func updateOptionMap(key string, value string) (err error) { common.QuotaForInvitee, _ = strconv.Atoi(value) case "QuotaRemindThreshold": common.QuotaRemindThreshold, _ = strconv.Atoi(value) - case "ShouldPreConsumedQuota": + case "PreConsumedQuota": common.PreConsumedQuota, _ = strconv.Atoi(value) case "ModelRequestRateLimitCount": setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value) @@ -340,15 +344,15 @@ func updateOptionMap(key string, value string) (err error) { case "DataExportDefaultTime": common.DataExportDefaultTime = value case "ModelRatio": - err = common.UpdateModelRatioByJSONString(value) + err = setting.UpdateModelRatioByJSONString(value) case "GroupRatio": err = setting.UpdateGroupRatioByJSONString(value) case "UserUsableGroups": err = setting.UpdateUserUsableGroupsByJSONString(value) case "CompletionRatio": - err = common.UpdateCompletionRatioByJSONString(value) + err = setting.UpdateCompletionRatioByJSONString(value) case "ModelPrice": - err = common.UpdateModelPriceByJSONString(value) + err = setting.UpdateModelPriceByJSONString(value) case "TopUpLink": common.TopUpLink = value case "ChatLink": @@ -362,7 +366,7 @@ func updateOptionMap(key string, value string) (err error) { case "SensitiveWords": setting.SensitiveWordsFromString(value) case "AutomaticDisableKeywords": - setting.AutomaticDisableKeywordsFromString(value) + operation_setting.AutomaticDisableKeywordsFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) } diff --git a/model/pricing.go b/model/pricing.go index fc709ce4e5f1c73423f2efa5aede50d4bed596ec..2d0aa1b777c4ceaecc5f88c2bdff03c8c330ddd7 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/setting" "sync" "time" ) @@ -64,14 +65,14 @@ func updatePricing() { ModelName: model, EnableGroup: groups, } - modelPrice, findPrice := common.GetModelPrice(model, false) + modelPrice, findPrice := setting.GetModelPrice(model, false) if findPrice { pricing.ModelPrice = modelPrice pricing.QuotaType = 1 } else { - modelRatio, _ := common.GetModelRatio(model) + modelRatio, _ := setting.GetModelRatio(model) pricing.ModelRatio = modelRatio - pricing.CompletionRatio = common.GetCompletionRatio(model) + pricing.CompletionRatio = setting.GetCompletionRatio(model) pricing.QuotaType = 0 } pricingMap = append(pricingMap, pricing) diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index db4df0a9a1b59b3210be3d5d7cfa74ac60273ec6..3fe893b35177c5684fd55de9264a28327f9d4312 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strings" ) @@ -153,7 +154,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith } stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) lastResponseText := "" c.Stream(func(w io.Writer) bool { select { diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index e87ed6ecf94e78fa4ce5d5b27938330955965894..3b615134c73ba1cbeee26fb6ec94148818950028 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -14,7 +14,7 @@ type AwsClaudeRequest struct { TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` - Tools []claude.Tool `json:"tools,omitempty"` + Tools any `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` Thinking *claude.Thinking `json:"thinking,omitempty"` } diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 1b0882b3f016dd0c0d7e246a1851443ecefd7940..976f97cec951eaa12a360f709d2ca1f6da9f5aff 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -12,6 +12,7 @@ import ( relaymodel "one-api/dto" "one-api/relay/channel/claude" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "strings" "time" @@ -203,13 +204,13 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } }) if info.ShouldIncludeUsage { - response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage) - err := service.ObjectData(c, response) + response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage) + err := helper.ObjectData(c, response) if err != nil { common.SysError("send final response failed: " + err.Error()) } } - service.Done(c) + helper.Done(c) if resp != nil { err = resp.Body.Close() if err != nil { diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index d88f521205878e45d376ec329c87aa84a9a6d2a9..62b06413e9efe2aa537dbafd19e75fee13136d92 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strings" "sync" @@ -138,7 +139,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi } stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/relay/channel/claude/dto.go b/relay/channel/claude/dto.go index 90f06b265a1040ad268c8218824914c99377908f..9532ca7478045db9ae9003b0dd3b2b6d6a29905c 100644 --- a/relay/channel/claude/dto.go +++ b/relay/channel/claude/dto.go @@ -58,7 +58,7 @@ type ClaudeRequest struct { TopK int `json:"top_k,omitempty"` //ClaudeMetadata `json:"metadata,omitempty"` Stream bool `json:"stream,omitempty"` - Tools []Tool `json:"tools,omitempty"` + Tools any `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` Thinking *Thinking `json:"thinking,omitempty"` } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index e32ee817b3129ca14e1c94457290303fdbc63b24..09154bcbe51bb66e995fc35ccf0d5f9b5f78204a 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -1,7 +1,6 @@ package claude import ( - "bufio" "encoding/json" "fmt" "io" @@ -9,6 +8,7 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "one-api/setting/model_setting" "strings" @@ -443,28 +443,18 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. usage = &dto.Usage{} responseText := "" createdTime := common.GetTimestamp() - scanner := bufio.NewScanner(resp.Body) - scanner.Split(bufio.ScanLines) - service.SetEventStreamHeaders(c) - for scanner.Scan() { - data := scanner.Text() - info.SetFirstResponseTime() - if len(data) < 6 || !strings.HasPrefix(data, "data:") { - continue - } - data = strings.TrimPrefix(data, "data:") - data = strings.TrimSpace(data) + helper.StreamScannerHandler(c, resp, info, func(data string) bool { var claudeResponse ClaudeResponse err := json.Unmarshal([]byte(data), &claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) - continue + return true } response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) if response == nil { - continue + return true } if requestMode == RequestModeCompletion { responseText += claudeResponse.Completion @@ -481,9 +471,9 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. usage.CompletionTokens = claudeUsage.OutputTokens usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens } else if claudeResponse.Type == "content_block_start" { - + return true } else { - continue + return true } } //response.Id = responseId @@ -491,11 +481,12 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. response.Created = createdTime response.Model = info.UpstreamModelName - err = service.ObjectData(c, response) + err = helper.ObjectData(c, response) if err != nil { common.LogError(c, "send_stream_response_failed: "+err.Error()) } - } + return true + }) if requestMode == RequestModeCompletion { usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) @@ -508,14 +499,14 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } } if info.ShouldIncludeUsage { - response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) - err := service.ObjectData(c, response) + response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) + err := helper.ObjectData(c, response) if err != nil { common.SysError("send final response failed: " + err.Error()) } } - service.Done(c) - resp.Body.Close() + helper.Done(c) + //resp.Body.Close() return nil, usage } diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index d21e524db0e0a02327c8a4836452544a834d9085..a487429c18a751d8b29fb1ae5ae5c5ebe302c8b9 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "strings" "time" @@ -28,8 +29,8 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) - service.SetEventStreamHeaders(c) - id := service.GetResponseID(c) + helper.SetEventStreamHeaders(c) + id := helper.GetResponseID(c) var responseText string isFirst := true @@ -57,7 +58,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela } response.Id = id response.Model = info.UpstreamModelName - err = service.ObjectData(c, response) + err = helper.ObjectData(c, response) if isFirst { isFirst = false info.FirstResponseTime = time.Now() @@ -72,13 +73,13 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela } usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) if info.ShouldIncludeUsage { - response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) - err := service.ObjectData(c, response) + response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) + err := helper.ObjectData(c, response) if err != nil { common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) } } - service.Done(c) + helper.Done(c) err := resp.Body.Close() if err != nil { @@ -109,7 +110,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) } usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) response.Usage = *usage - response.Id = service.GetResponseID(c) + response.Id = helper.GetResponseID(c) jsonResponse, err := json.Marshal(response) if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index 132039b3b5bfa806a1941e0a7ea2722c65c10316..17b58dbc59dcad3e669cc57471cd9696fae6e694 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "strings" "time" @@ -103,7 +104,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) isFirst := true c.Stream(func(w io.Writer) bool { select { diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 5df34d35c7b273f2ed5f5f14ac370a95647b1d7c..3e62d41c9470e9af0847a75e6e0a8e331a616e54 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -10,6 +10,7 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "strings" ) @@ -66,7 +67,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) for scanner.Scan() { data := scanner.Text() @@ -92,7 +93,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re responseText += openaiResponse.Choices[0].Delta.GetContentString() } } - err = service.ObjectData(c, openaiResponse) + err = helper.ObjectData(c, openaiResponse) if err != nil { common.SysError(err.Error()) } @@ -100,7 +101,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re if err := scanner.Err(); err != nil { common.SysError("error reading stream: " + err.Error()) } - service.Done(c) + helper.Done(c) err := resp.Body.Close() if err != nil { //return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index d5103124edf6d8407913dc613b91349852604ac7..c1ce8219dccb397d1eaf965520e929f7e65f905c 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -1,7 +1,6 @@ package gemini import ( - "bufio" "encoding/json" "fmt" "io" @@ -10,6 +9,7 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "one-api/setting/model_setting" "strings" @@ -429,10 +429,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) { choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)) - is_stop := false + isStop := false for _, candidate := range geminiResponse.Candidates { if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" { - is_stop = true + isStop = true candidate.FinishReason = nil } choice := dto.ChatCompletionsStreamResponseChoice{ @@ -482,9 +482,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" - response.Model = "gemini" response.Choices = choices - return &response, is_stop + return &response, isStop } func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { @@ -492,27 +491,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) createAt := common.GetTimestamp() var usage = &dto.Usage{} - scanner := bufio.NewScanner(resp.Body) - scanner.Split(bufio.ScanLines) - - service.SetEventStreamHeaders(c) - for scanner.Scan() { - data := scanner.Text() - info.SetFirstResponseTime() - data = strings.TrimSpace(data) - if !strings.HasPrefix(data, "data: ") { - continue - } - data = strings.TrimPrefix(data, "data: ") - data = strings.TrimSuffix(data, "\"") + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse err := json.Unmarshal([]byte(data), &geminiResponse) if err != nil { common.LogError(c, "error unmarshalling stream response: "+err.Error()) - continue + return false } - response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse) + response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse) response.Id = id response.Created = createAt response.Model = info.UpstreamModelName @@ -521,15 +509,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount } - err = service.ObjectData(c, response) + err = helper.ObjectData(c, response) if err != nil { common.LogError(c, err.Error()) } - if is_stop { - response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop) - service.ObjectData(c, response) + if isStop { + response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop) + helper.ObjectData(c, response) } - } + return true + }) var response *dto.ChatCompletionsStreamResponse @@ -538,14 +527,14 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens if info.ShouldIncludeUsage { - response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) - err := service.ObjectData(c, response) + response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) + err := helper.ObjectData(c, response) if err != nil { common.SysError("send final response failed: " + err.Error()) } } - service.Done(c) - resp.Body.Close() + helper.Done(c) + //resp.Body.Close() return nil, usage } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index a5bd0e33d10751e3a6b5ecd5c1ea4e38b1c6f8e5..0afe3f51292bc41528473f03e0be7cdf41d6ca49 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -1,10 +1,13 @@ package openai import ( - "bufio" "bytes" "encoding/json" "fmt" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/pkg/errors" "io" "math" "mime/multipart" @@ -14,16 +17,10 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "os" "strings" - "sync" - "time" - - "github.com/bytedance/gopkg/util/gopool" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" - "github.com/pkg/errors" ) func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { @@ -32,7 +29,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo } if !forceFormat && !thinkToContent { - return service.StringData(c, data) + return helper.StringData(c, data) } var lastStreamResponse dto.ChatCompletionsStreamResponse @@ -41,34 +38,47 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo } if !thinkToContent { - return service.ObjectData(c, lastStreamResponse) + return helper.ObjectData(c, lastStreamResponse) + } + + hasThinkingContent := false + for _, choice := range lastStreamResponse.Choices { + if len(choice.Delta.GetReasoningContent()) > 0 { + hasThinkingContent = true + break + } } // Handle think to content conversion - if info.IsFirstResponse { - response := lastStreamResponse.Copy() - for i := range response.Choices { - response.Choices[i].Delta.SetContentString("\n") - response.Choices[i].Delta.SetReasoningContent("") + if info.ThinkingContentInfo.IsFirstThinkingContent { + if hasThinkingContent { + response := lastStreamResponse.Copy() + for i := range response.Choices { + response.Choices[i].Delta.SetContentString("\n") + response.Choices[i].Delta.SetReasoningContent("") + } + info.ThinkingContentInfo.IsFirstThinkingContent = false + return helper.ObjectData(c, response) + } else { + return helper.ObjectData(c, lastStreamResponse) } - service.ObjectData(c, response) } if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 { - return service.ObjectData(c, lastStreamResponse) + return helper.ObjectData(c, lastStreamResponse) } // Process each choice for i, choice := range lastStreamResponse.Choices { // Handle transition from thinking to content - if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse { + if len(choice.Delta.GetContentString()) > 0 && !info.ThinkingContentInfo.SendLastThinkingContent { response := lastStreamResponse.Copy() for j := range response.Choices { - response.Choices[j].Delta.SetContentString("\n") + response.Choices[j].Delta.SetContentString("\n\n\n") response.Choices[j].Delta.SetReasoningContent("") } - info.SendLastReasoningResponse = true - service.ObjectData(c, response) + info.ThinkingContentInfo.SendLastThinkingContent = true + helper.ObjectData(c, response) } // Convert reasoning content to regular content @@ -78,7 +88,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo } } - return service.ObjectData(c, lastStreamResponse) + return helper.ObjectData(c, lastStreamResponse) } func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { @@ -108,65 +118,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } toolCount := 0 - scanner := bufio.NewScanner(resp.Body) - scanner.Split(bufio.ScanLines) - - service.SetEventStreamHeaders(c) - streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second - if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") { - // twice timeout for o1 model - streamingTimeout *= 2 - } - ticker := time.NewTicker(streamingTimeout) - defer ticker.Stop() - stopChan := make(chan bool) - defer close(stopChan) var ( lastStreamData string - mu sync.Mutex ) - gopool.Go(func() { - for scanner.Scan() { - //info.SetFirstResponseTime() - ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) - data := scanner.Text() - if common.DebugEnabled { - println(data) - } - if len(data) < 6 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" && data[:6] != "[DONE]" { - continue - } - mu.Lock() - data = data[5:] - data = strings.TrimSpace(data) - if !strings.HasPrefix(data, "[DONE]") { - if lastStreamData != "" { - err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) - if err != nil { - common.LogError(c, "streaming error: "+err.Error()) - } - info.SetFirstResponseTime() - } - lastStreamData = data - streamItems = append(streamItems, data) + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + if lastStreamData != "" { + err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) + if err != nil { + common.LogError(c, "streaming error: "+err.Error()) } - mu.Unlock() } - common.SafeSendBool(stopChan, true) + lastStreamData = data + streamItems = append(streamItems, data) + return true }) - select { - case <-ticker.C: - // 超时处理逻辑 - common.LogError(c, "streaming timeout") - case <-stopChan: - // 正常结束 - } - shouldSendLastResp := true var lastStreamResponse dto.ChatCompletionsStreamResponse err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse) @@ -274,14 +242,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } if info.ShouldIncludeUsage && !containStreamUsage { - response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage) + response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response.SetSystemFingerprint(systemFingerprint) - service.ObjectData(c, response) + helper.ObjectData(c, response) } - service.Done(c) + helper.Done(c) - resp.Body.Close() + //resp.Body.Close() return nil, usage } @@ -512,7 +480,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op localUsage.InputTokenDetails.TextTokens += textToken localUsage.InputTokenDetails.AudioTokens += audioToken - err = service.WssString(c, targetConn, string(message)) + err = helper.WssString(c, targetConn, string(message)) if err != nil { errChan <- fmt.Errorf("error writing to target: %v", err) return @@ -618,7 +586,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op localUsage.OutputTokenDetails.AudioTokens += audioToken } - err = service.WssString(c, clientConn, string(message)) + err = helper.WssString(c, clientConn, string(message)) if err != nil { errChan <- fmt.Errorf("error writing to client: %v", err) return diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 02a3e382914a8129c6f402c25b074a5231401813..c8e337de6a98ead027e6b127f244a17ab9285527 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -9,6 +9,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/relay/helper" "one-api/service" ) @@ -112,7 +113,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit dataChan <- string(jsonResponse) stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index dd3ac93feac0e067a44a934b85aa39a88bb4d6c6..5630650f142d0a6a522eb5f312c7fb827f04ba7c 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -14,6 +14,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strconv" "strings" @@ -91,7 +92,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) for scanner.Scan() { data := scanner.Text() @@ -112,7 +113,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError responseText += response.Choices[0].Delta.GetContentString() } - err = service.ObjectData(c, response) + err = helper.ObjectData(c, response) if err != nil { common.SysError(err.Error()) } @@ -122,7 +123,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError common.SysError("error reading stream: " + err.Error()) } - service.Done(c) + helper.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 0d8ccef1fcd2c48d5b8092176dd251dcce72de8a..7ccd3f30dbdc60f3a0ed321cadfcb67a25315285 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/jinzhu/copier" "io" "net/http" "one-api/dto" @@ -28,6 +27,7 @@ var claudeModelMap = map[string]string{ "claude-3-opus-20240229": "claude-3-opus@20240229", "claude-3-haiku-20240307": "claude-3-haiku@20240307", "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620", + "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219", } @@ -86,15 +86,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } else { suffix = "rawPredict" } + model := info.UpstreamModelName if v, ok := claudeModelMap[info.UpstreamModelName]; ok { - info.UpstreamModelName = v + model = v } return fmt.Sprintf( "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", region, adc.ProjectID, region, - info.UpstreamModelName, + model, suffix, ), nil } else if a.RequestMode == RequestModeLlama { @@ -127,13 +128,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re if err != nil { return nil, err } - vertexClaudeReq := &VertexAIClaudeRequest{ - AnthropicVersion: anthropicVersion, - } - if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil { - return nil, errors.New("failed to copy claude request") - } + vertexClaudeReq := copyRequest(claudeReq, anthropicVersion) c.Set("request_model", claudeReq.Model) + info.UpstreamModelName = claudeReq.Model return vertexClaudeReq, nil } else if a.RequestMode == RequestModeGemini { geminiRequest, err := gemini.CovertGemini2OpenAI(*request) diff --git a/relay/channel/vertex/dto.go b/relay/channel/vertex/dto.go index 3889c343a5d101c546e9f9da0e3ef40618ab74c0..4ba570deac63cc16a7f35870d4216c8cea571726 100644 --- a/relay/channel/vertex/dto.go +++ b/relay/channel/vertex/dto.go @@ -1,17 +1,37 @@ package vertex -import "one-api/relay/channel/claude" +import ( + "one-api/relay/channel/claude" +) type VertexAIClaudeRequest struct { AnthropicVersion string `json:"anthropic_version"` Messages []claude.ClaudeMessage `json:"messages"` - System string `json:"system,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` + System any `json:"system,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Stream bool `json:"stream,omitempty"` Temperature *float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` - Tools []claude.Tool `json:"tools,omitempty"` + Tools any `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` + Thinking *claude.Thinking `json:"thinking,omitempty"` +} + +func copyRequest(req *claude.ClaudeRequest, version string) *VertexAIClaudeRequest { + return &VertexAIClaudeRequest{ + AnthropicVersion: version, + System: req.System, + Messages: req.Messages, + MaxTokens: req.MaxTokens, + Stream: req.Stream, + Temperature: req.Temperature, + TopP: req.TopP, + TopK: req.TopK, + StopSequences: req.StopSequences, + Tools: req.Tools, + ToolChoice: req.ToolChoice, + Thinking: req.Thinking, + } } diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 067ff6e47e7d50b2e92800e10ad03dde15b0a14d..15d33510e50a239733ab85c44d1ea9ec13dbb60a 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -14,6 +14,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strings" "time" @@ -132,7 +133,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a if err != nil { return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) var usage dto.Usage c.Stream(func(w io.Writer) bool { select { diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 6bdd1c2a26bb5c3841c9a267525272b808e3cbbc..b0cac858003d3a57c8ac08dff011326269272854 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -10,6 +10,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strings" "sync" @@ -177,7 +178,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi } stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go index 97d82c718eaa86f6c995f35b558e4cf253333c12..faffec6fcd63c58fec03df73aff7804a0f901e74 100644 --- a/relay/channel/zhipu_4v/relay-zhipu_v4.go +++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go @@ -10,6 +10,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/relay/helper" "one-api/service" "strings" "sync" @@ -197,7 +198,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi } stopChan <- true }() - service.SetEventStreamHeaders(c) + helper.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 022ab62800e7b03aeb8d9485f5bc35c5ca15345d..c1d3f4a4d4a497e68379f8ab0e7cf113fbc9523e 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -12,25 +12,30 @@ import ( "github.com/gorilla/websocket" ) +type ThinkingContentInfo struct { + IsFirstThinkingContent bool + SendLastThinkingContent bool +} + type RelayInfo struct { - ChannelType int - ChannelId int - TokenId int - TokenKey string - UserId int - Group string - TokenUnlimited bool - StartTime time.Time - FirstResponseTime time.Time - IsFirstResponse bool - SendLastReasoningResponse bool - ApiType int - IsStream bool - IsPlayground bool - UsePrice bool - RelayMode int - UpstreamModelName string - OriginModelName string + ChannelType int + ChannelId int + TokenId int + TokenKey string + UserId int + Group string + TokenUnlimited bool + StartTime time.Time + FirstResponseTime time.Time + isFirstResponse bool + //SendLastReasoningResponse bool + ApiType int + IsStream bool + IsPlayground bool + UsePrice bool + RelayMode int + UpstreamModelName string + OriginModelName string //RecodeModelName string RequestURLPath string ApiVersion string @@ -53,6 +58,7 @@ type RelayInfo struct { UserSetting map[string]interface{} UserEmail string UserQuota int + ThinkingContentInfo } // 定义支持流式选项的通道类型 @@ -95,7 +101,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { UserQuota: c.GetInt(constant.ContextKeyUserQuota), UserSetting: c.GetStringMap(constant.ContextKeyUserSetting), UserEmail: c.GetString(constant.ContextKeyUserEmail), - IsFirstResponse: true, + isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), BaseUrl: c.GetString("base_url"), RequestURLPath: c.Request.URL.String(), @@ -117,6 +123,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), Organization: c.GetString("channel_organization"), ChannelSetting: channelSetting, + ThinkingContentInfo: ThinkingContentInfo{ + IsFirstThinkingContent: true, + SendLastThinkingContent: false, + }, } if strings.HasPrefix(c.Request.URL.Path, "/pg") { info.IsPlayground = true @@ -147,9 +157,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) { } func (info *RelayInfo) SetFirstResponseTime() { - if info.IsFirstResponse { + if info.isFirstResponse { info.FirstResponseTime = time.Now() - info.IsFirstResponse = false + info.isFirstResponse = false } } diff --git a/service/relay.go b/relay/helper/common.go similarity index 99% rename from service/relay.go rename to relay/helper/common.go index 6ffed1e22ca2cc5f239cab9051b5b7ac8257e133..2a72d30a9ce34ec1742d53cd50e9e87dac2ffc4b 100644 --- a/service/relay.go +++ b/relay/helper/common.go @@ -1,4 +1,4 @@ -package service +package helper import ( "encoding/json" diff --git a/relay/helper/price.go b/relay/helper/price.go index 1f4a5b3c5cbc26337179f52baf054c7653521f79..51f640829795ab460082bc4dfc4e2ee069060ee6 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -11,26 +11,33 @@ import ( type PriceData struct { ModelPrice float64 ModelRatio float64 + CompletionRatio float64 GroupRatio float64 UsePrice bool ShouldPreConsumedQuota int } func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { - modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false) + modelPrice, usePrice := setting.GetModelPrice(info.OriginModelName, false) groupRatio := setting.GetGroupRatio(info.Group) var preConsumedQuota int var modelRatio float64 + var completionRatio float64 if !usePrice { preConsumedTokens := common.PreConsumedQuota if maxTokens != 0 { preConsumedTokens = promptTokens + maxTokens } var success bool - modelRatio, success = common.GetModelRatio(info.OriginModelName) + modelRatio, success = setting.GetModelRatio(info.OriginModelName) if !success { - return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName) + if info.UserId == 1 { + return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName) + } else { + return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName) + } } + completionRatio = setting.GetCompletionRatio(info.OriginModelName) ratio := modelRatio * groupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { @@ -39,6 +46,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens return PriceData{ ModelPrice: modelPrice, ModelRatio: modelRatio, + CompletionRatio: completionRatio, GroupRatio: groupRatio, UsePrice: usePrice, ShouldPreConsumedQuota: preConsumedQuota, diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go new file mode 100644 index 0000000000000000000000000000000000000000..7a7507f5cd3a05cc49f1f6a5e5a82f5f3ab2e656 --- /dev/null +++ b/relay/helper/stream_scanner.go @@ -0,0 +1,91 @@ +package helper + +import ( + "bufio" + "context" + "io" + "net/http" + "one-api/common" + "one-api/constant" + relaycommon "one-api/relay/common" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) { + + if resp == nil { + return + } + + defer resp.Body.Close() + + streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second + if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") { + // twice timeout for thinking model + streamingTimeout *= 2 + } + + var ( + stopChan = make(chan bool, 2) + scanner = bufio.NewScanner(resp.Body) + ticker = time.NewTicker(streamingTimeout) + ) + + defer func() { + ticker.Stop() + close(stopChan) + }() + + scanner.Split(bufio.ScanLines) + SetEventStreamHeaders(c) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ctx = context.WithValue(ctx, "stop_chan", stopChan) + common.RelayCtxGo(ctx, func() { + for scanner.Scan() { + ticker.Reset(streamingTimeout) + data := scanner.Text() + if common.DebugEnabled { + println(data) + } + + if len(data) < 6 { + continue + } + if data[:5] != "data:" && data[:6] != "[DONE]" { + continue + } + data = data[5:] + data = strings.TrimLeft(data, " ") + data = strings.TrimSuffix(data, "\"") + if !strings.HasPrefix(data, "[DONE]") { + info.SetFirstResponseTime() + success := dataHandler(data) + if !success { + break + } + } + } + + if err := scanner.Err(); err != nil { + if err != io.EOF { + common.LogError(c, "scanner error: "+err.Error()) + } + } + + common.SafeSendBool(stopChan, true) + }) + + select { + case <-ticker.C: + // 超时处理逻辑 + common.LogError(c, "streaming timeout") + case <-stopChan: + // 正常结束 + } +} diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 57de8d100dd5014e2c53d8a27f819fcc074d7ae0..8baf033aba21cd4dbd9ae0c514ffa17614164127 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -157,10 +157,10 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") } modelName := service.CoverActionToModelName(constant.MjActionSwapFace) - modelPrice, success := common.GetModelPrice(modelName, true) + modelPrice, success := setting.GetModelPrice(modelName, true) // 如果没有配置价格,则使用默认价格 if !success { - defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName] + defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName] if !ok { modelPrice = 0.1 } else { @@ -463,10 +463,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) modelName := service.CoverActionToModelName(midjRequest.Action) - modelPrice, success := common.GetModelPrice(modelName, true) + modelPrice, success := setting.GetModelPrice(modelName, true) // 如果没有配置价格,则使用默认价格 if !success { - defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName] + defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName] if !ok { modelPrice = 0.1 } else { diff --git a/relay/relay-text.go b/relay/relay-text.go index eb331e256481faab6796341ee6e2e75ca8eb2cef..bf6c5fd318cad362db76b0e51e59dfe76251eee0 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -311,7 +311,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") - completionRatio := common.GetCompletionRatio(modelName) + completionRatio := setting.GetCompletionRatio(modelName) ratio := priceData.ModelRatio * priceData.GroupRatio modelRatio := priceData.ModelRatio groupRatio := priceData.GroupRatio diff --git a/relay/relay_task.go b/relay/relay_task.go index 591ad3bb7207178df35e544b1d3aac7ff42ca7fc..ab35d3e8a6ab24bbe2a8f5642b676034441b672b 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -37,9 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action) - modelPrice, success := common.GetModelPrice(modelName, true) + modelPrice, success := setting.GetModelPrice(modelName, true) if !success { - defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName] + defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName] if !ok { modelPrice = 0.1 } else { diff --git a/relay/websocket.go b/relay/websocket.go index 2dac60afbc0a1b98330e79a61be27f75d73b29e2..b0636057b4e66b673e3a3b77367a896d05a8841f 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -39,7 +39,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi } } //relayInfo.UpstreamModelName = textRequest.Model - modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false) + modelPrice, getModelPriceSuccess := setting.GetModelPrice(relayInfo.UpstreamModelName, false) groupRatio := setting.GetGroupRatio(relayInfo.Group) var preConsumedQuota int @@ -65,7 +65,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi //if realtimeEvent.Session.MaxResponseOutputTokens != 0 { // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens) //} - modelRatio, _ = common.GetModelRatio(relayInfo.UpstreamModelName) + modelRatio, _ = setting.GetModelRatio(relayInfo.UpstreamModelName) ratio = modelRatio * groupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { diff --git a/service/channel.go b/service/channel.go index 76bcacf1bb1433d3716519f567de3b3ede97205f..e3a76af4120db366e22a433cc82ba2a80b860e35 100644 --- a/service/channel.go +++ b/service/channel.go @@ -6,23 +6,31 @@ import ( "one-api/common" "one-api/dto" "one-api/model" - "one-api/setting" + "one-api/setting/operation_setting" "strings" ) +func formatNotifyType(channelId int, status int) string { + return fmt.Sprintf("%s_%d_%d", dto.NotifyTypeChannelUpdate, channelId, status) +} + // disable & notify func DisableChannel(channelId int, channelName string, reason string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason) - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate) + success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason) + if success { + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content) + } } func EnableChannel(channelId int, channelName string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "") - subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate) + success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "") + if success { + subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusEnabled), subject, content) + } } func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool { @@ -67,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b } lowerMessage := strings.ToLower(err.Error.Message) - search, _ := AcSearch(lowerMessage, setting.AutomaticDisableKeywords, true) + search, _ := AcSearch(lowerMessage, operation_setting.AutomaticDisableKeywords, true) if search { return true } diff --git a/service/image.go b/service/image.go index 77a0cc7a7442bccbfb0a0af0ff4a57ee4cfe5a62..252093f1ffa1e6a08971b5a708f5d073a81547f1 100644 --- a/service/image.go +++ b/service/image.go @@ -7,7 +7,9 @@ import ( "fmt" "image" "io" + "net/http" "one-api/common" + "one-api/constant" "strings" "golang.org/x/image/webp" @@ -23,7 +25,7 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e decodedData, err := base64.StdEncoding.DecodeString(base64String) if err != nil { fmt.Println("Error: Failed to decode base64 string") - return image.Config{}, "", "", err + return image.Config{}, "", "", fmt.Errorf("failed to decode base64 string: %s", err.Error()) } // 创建一个bytes.Buffer用于存储解码后的数据 @@ -61,20 +63,51 @@ func DecodeBase64FileData(base64String string) (string, string, error) { func GetImageFromUrl(url string) (mimeType string, data string, err error) { resp, err := DoDownloadRequest(url) if err != nil { - return "", "", err - } - if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { - return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type")) + return "", "", fmt.Errorf("failed to download image: %w", err) } defer resp.Body.Close() - buffer := bytes.NewBuffer(nil) - _, err = buffer.ReadFrom(resp.Body) + + // Check HTTP status code + if resp.StatusCode != http.StatusOK { + return "", "", fmt.Errorf("failed to download image: HTTP %d", resp.StatusCode) + } + + contentType := resp.Header.Get("Content-Type") + if contentType != "application/octet-stream" && !strings.HasPrefix(contentType, "image/") { + return "", "", fmt.Errorf("invalid content type: %s, required image/*", contentType) + } + maxImageSize := int64(constant.MaxFileDownloadMB * 1024 * 1024) + + // Check Content-Length if available + if resp.ContentLength > maxImageSize { + return "", "", fmt.Errorf("image size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxImageSize) + } + + // Use LimitReader to prevent reading oversized images + limitReader := io.LimitReader(resp.Body, maxImageSize) + buffer := &bytes.Buffer{} + + written, err := io.Copy(buffer, limitReader) if err != nil { - return + return "", "", fmt.Errorf("failed to read image data: %w", err) + } + if written >= maxImageSize { + return "", "", fmt.Errorf("image size exceeds maximum allowed size of %d bytes", maxImageSize) } - mimeType = resp.Header.Get("Content-Type") + data = base64.StdEncoding.EncodeToString(buffer.Bytes()) - return + mimeType = contentType + + // Handle application/octet-stream type + if mimeType == "application/octet-stream" { + _, format, _, err := DecodeBase64ImageData(data) + if err != nil { + return "", "", err + } + mimeType = "image/" + format + } + + return mimeType, data, nil } func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { @@ -92,7 +125,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { mimeType := response.Header.Get("Content-Type") - if !strings.HasPrefix(mimeType, "image/") { + if mimeType != "application/octet-stream" && !strings.HasPrefix(mimeType, "image/") { return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType) } diff --git a/service/quota.go b/service/quota.go index 9ce2858d38bf6260ee7cba3634b4b608952c8b10..b3412c1ea337c85c7ba62b887966977fdc8383fb 100644 --- a/service/quota.go +++ b/service/quota.go @@ -38,9 +38,9 @@ func calculateAudioQuota(info QuotaInfo) int { return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio) } - completionRatio := common.GetCompletionRatio(info.ModelName) - audioRatio := common.GetAudioRatio(info.ModelName) - audioCompletionRatio := common.GetAudioCompletionRatio(info.ModelName) + completionRatio := setting.GetCompletionRatio(info.ModelName) + audioRatio := setting.GetAudioRatio(info.ModelName) + audioCompletionRatio := setting.GetAudioCompletionRatio(info.ModelName) ratio := info.GroupRatio * info.ModelRatio quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio)) @@ -75,7 +75,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag audioInputTokens := usage.InputTokenDetails.AudioTokens audioOutTokens := usage.OutputTokenDetails.AudioTokens groupRatio := setting.GetGroupRatio(relayInfo.Group) - modelRatio, _ := common.GetModelRatio(modelName) + modelRatio, _ := setting.GetModelRatio(modelName) quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -122,9 +122,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod audioOutTokens := usage.OutputTokenDetails.AudioTokens tokenName := ctx.GetString("token_name") - completionRatio := common.GetCompletionRatio(modelName) - audioRatio := common.GetAudioRatio(relayInfo.OriginModelName) - audioCompletionRatio := common.GetAudioCompletionRatio(modelName) + completionRatio := setting.GetCompletionRatio(modelName) + audioRatio := setting.GetAudioRatio(relayInfo.OriginModelName) + audioCompletionRatio := setting.GetAudioCompletionRatio(modelName) quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -184,9 +184,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioOutTokens := usage.CompletionTokenDetails.AudioTokens tokenName := ctx.GetString("token_name") - completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName) - audioRatio := common.GetAudioRatio(relayInfo.OriginModelName) - audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName) + completionRatio := setting.GetCompletionRatio(relayInfo.OriginModelName) + audioRatio := setting.GetAudioRatio(relayInfo.OriginModelName) + audioCompletionRatio := setting.GetAudioCompletionRatio(relayInfo.OriginModelName) modelRatio := priceData.ModelRatio groupRatio := priceData.GroupRatio diff --git a/service/token_counter.go b/service/token_counter.go index aa62bc6e7b75b83872ba453e621e8f902a5be65b..e868beb476f56bd6a9427f742eba6a174a84c106 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -10,6 +10,7 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/setting" "strings" "unicode/utf8" @@ -32,7 +33,7 @@ func InitTokenEncoders() { if err != nil { common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) } - for model, _ := range common.GetDefaultModelRatioMap() { + for model, _ := range setting.GetDefaultModelRatioMap() { if strings.HasPrefix(model, "gpt-3.5") { tokenEncoderMap[model] = cl100TokenEncoder } else if strings.HasPrefix(model, "gpt-4") { diff --git a/service/user_notify.go b/service/user_notify.go index db291f0fe73d2fc520529f1b8238a28fcf0b65cb..51f1ff9965528a9132636c57b4e0427776b78782 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -11,7 +11,10 @@ import ( func NotifyRootUser(t string, subject string, content string) { user := model.GetRootUser().ToBaseUser() - _ = NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil)) + err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil)) + if err != nil { + common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error())) + } } func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error { diff --git a/common/model-ratio.go b/setting/model-ratio.go similarity index 97% rename from common/model-ratio.go rename to setting/model-ratio.go index 036811720ddbfa4c8e5f24be06f144d97ed73f76..54b214f92dc71bbc84350bf6d1c9407213deed8c 100644 --- a/common/model-ratio.go +++ b/setting/model-ratio.go @@ -1,7 +1,9 @@ -package common +package setting import ( "encoding/json" + "one-api/common" + "one-api/setting/operation_setting" "strings" "sync" ) @@ -261,7 +263,7 @@ func ModelPrice2JSONString() string { GetModelPriceMap() jsonBytes, err := json.Marshal(modelPriceMap) if err != nil { - SysError("error marshalling model price: " + err.Error()) + common.SysError("error marshalling model price: " + err.Error()) } return string(jsonBytes) } @@ -285,7 +287,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) { price, ok := modelPriceMap[name] if !ok { if printErr { - SysError("model price not found: " + name) + common.SysError("model price not found: " + name) } return -1, false } @@ -305,7 +307,7 @@ func ModelRatio2JSONString() string { GetModelRatioMap() jsonBytes, err := json.Marshal(modelRatioMap) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + common.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -324,8 +326,7 @@ func GetModelRatio(name string) (float64, bool) { } ratio, ok := modelRatioMap[name] if !ok { - SysError("model ratio not found: " + name) - return 37.5, false + return 37.5, operation_setting.SelfUseModeEnabled } return ratio, true } @@ -333,7 +334,7 @@ func GetModelRatio(name string) (float64, bool) { func DefaultModelRatio2JSONString() string { jsonBytes, err := json.Marshal(defaultModelRatio) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + common.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -355,7 +356,7 @@ func CompletionRatio2JSONString() string { GetCompletionRatioMap() jsonBytes, err := json.Marshal(CompletionRatio) if err != nil { - SysError("error marshalling completion ratio: " + err.Error()) + common.SysError("error marshalling completion ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/operation_setting.go b/setting/operation_setting/operation_setting.go similarity index 92% rename from setting/operation_setting.go rename to setting/operation_setting/operation_setting.go index 4940d0fc6cc3718b9a633627f1323eb856bd1c53..ef330d1adb876a6a2e0e00a1e1d1261d7e9b4faa 100644 --- a/setting/operation_setting.go +++ b/setting/operation_setting/operation_setting.go @@ -1,8 +1,9 @@ -package setting +package operation_setting import "strings" var DemoSiteEnabled = false +var SelfUseModeEnabled = false var AutomaticDisableKeywords = []string{ "Your credit balance is too low", diff --git a/web/src/App.js b/web/src/App.js index 05fd597f076acb4b4a4ed5447bdfa8c08fd8a33c..9e9407c01cf6f20769d8dd27b84a411a17c013ca 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -30,6 +30,7 @@ import { useTranslation } from 'react-i18next'; import { StatusContext } from './context/Status'; import { setStatusData } from './helpers/data.js'; import { API, showError } from './helpers'; +import PersonalSetting from './components/PersonalSetting.js'; const Home = lazy(() => import('./pages/Home')); const Detail = lazy(() => import('./pages/Detail')); @@ -177,6 +178,16 @@ function App() { } /> + + }> + + + + } + /> { > {t('测试')} - - - + { const [enableTagMode, setEnableTagMode] = useState(false); const [showBatchSetTag, setShowBatchSetTag] = useState(false); const [batchSetTagValue, setBatchSetTagValue] = useState(''); + const [showModelTestModal, setShowModelTestModal] = useState(false); + const [currentTestChannel, setCurrentTestChannel] = useState(null); + const [modelSearchKeyword, setModelSearchKeyword] = useState(''); const removeRecord = (record) => { @@ -1289,6 +1290,77 @@ const ChannelsTable = () => { onChange={(v) => setBatchSetTagValue(v)} /> + + {/* 模型测试弹窗 */} + { + setShowModelTestModal(false); + setModelSearchKeyword(''); + }} + footer={null} + maskClosable={true} + centered={true} + width={600} + > +
+ {currentTestChannel && ( +
+ + {t('渠道')}: {currentTestChannel.name} + + + {/* 搜索框 */} + setModelSearchKeyword(value)} + style={{ marginBottom: '16px' }} + showClear + /> + +
+ {currentTestChannel.models.split(',') + .filter(model => model.toLowerCase().includes(modelSearchKeyword.toLowerCase())) + .map((model, index) => { + + return ( + + ); + })} +
+ + {/* 显示搜索结果数量 */} + {modelSearchKeyword && ( + + {t('找到')} {currentTestChannel.models.split(',').filter(model => + model.toLowerCase().includes(modelSearchKeyword.toLowerCase()) + ).length} {t('个模型')} + + )} +
+ )} +
+
); }; diff --git a/web/src/components/HeaderBar.js b/web/src/components/HeaderBar.js index c9105e7106eadb2108209aebd77a174adfd0b138..68169ed26a0c00e93a02ffa322acb589dc904604 100644 --- a/web/src/components/HeaderBar.js +++ b/web/src/components/HeaderBar.js @@ -21,15 +21,17 @@ import { IconUser, IconLanguage } from '@douyinfe/semi-icons'; -import { Avatar, Button, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui'; +import { Avatar, Button, Dropdown, Layout, Nav, Switch, Tag } from '@douyinfe/semi-ui'; import { stringToColor } from '../helpers/render'; import Text from '@douyinfe/semi-ui/lib/es/typography/text'; import { StyleContext } from '../context/Style/index.js'; +import { StatusContext } from '../context/Status/index.js'; const HeaderBar = () => { const { t, i18n } = useTranslation(); const [userState, userDispatch] = useContext(UserContext); const [styleState, styleDispatch] = useContext(StyleContext); + const [statusState, statusDispatch] = useContext(StatusContext); let navigate = useNavigate(); const [currentLang, setCurrentLang] = useState(i18n.language); @@ -40,6 +42,10 @@ const HeaderBar = () => { const isNewYear = (currentDate.getMonth() === 0 && currentDate.getDate() === 1); + // Check if self-use mode is enabled + const isSelfUseMode = statusState?.status?.self_use_mode_enabled || false; + const isDemoSiteMode = statusState?.status?.demo_site_enabled || false; + let buttons = [ { text: t('首页'), @@ -166,7 +172,7 @@ const HeaderBar = () => { onSelect={(key) => {}} header={styleState.isMobile?{ logo: ( - <> +
{ !styleState.showSider ?
), }:{ logo: ( logo ), - text: systemName, + text: ( +
+ {systemName} + {(isSelfUseMode || isDemoSiteMode) && ( + + {isSelfUseMode ? t('自用模式') : t('演示站点')} + + )} +
+ ), }} items={buttons} footer={ @@ -266,7 +311,8 @@ const HeaderBar = () => { icon={} /> { - !styleState.isMobile && ( + // Hide register option in self-use mode + !styleState.isMobile && !isSelfUseMode && ( { RetryTimes: 0, Chats: "[]", DemoSiteEnabled: false, + SelfUseModeEnabled: false, AutomaticDisableKeywords: '', }); diff --git a/web/src/components/OtherSetting.js b/web/src/components/OtherSetting.js index dad79fd1be271e46c466802c93ba8340fb7c96a4..e3295fb171c24aed80cd1786cc1120203e55d45d 100644 --- a/web/src/components/OtherSetting.js +++ b/web/src/components/OtherSetting.js @@ -1,8 +1,10 @@ -import React, { useEffect, useRef, useState } from 'react'; -import { Banner, Button, Col, Form, Row } from '@douyinfe/semi-ui'; -import { API, showError, showSuccess } from '../helpers'; +import React, { useContext, useEffect, useRef, useState } from 'react'; +import { Banner, Button, Col, Form, Row, Modal, Space } from '@douyinfe/semi-ui'; +import { API, showError, showSuccess, timestamp2string } from '../helpers'; import { marked } from 'marked'; import { useTranslation } from 'react-i18next'; +import { StatusContext } from '../context/Status/index.js'; +import Text from '@douyinfe/semi-ui/lib/es/typography/text'; const OtherSetting = () => { const { t } = useTranslation(); @@ -16,6 +18,7 @@ const OtherSetting = () => { }); let [loading, setLoading] = useState(false); const [showUpdateModal, setShowUpdateModal] = useState(false); + const [statusState, statusDispatch] = useContext(StatusContext); const [updateData, setUpdateData] = useState({ tag_name: '', content: '', @@ -43,6 +46,7 @@ const OtherSetting = () => { HomePageContent: false, About: false, Footer: false, + CheckUpdate: false }); const handleInputChange = async (value, e) => { const name = e.target.id; @@ -145,23 +149,48 @@ const OtherSetting = () => { } }; - const openGitHubRelease = () => { - window.location = 'https://github.com/songquanpeng/one-api/releases/latest'; - }; - const checkUpdate = async () => { - const res = await API.get( - 'https://api.github.com/repos/songquanpeng/one-api/releases/latest', - ); - const { tag_name, body } = res.data; - if (tag_name === process.env.REACT_APP_VERSION) { - showSuccess(`已是最新版本:${tag_name}`); - } else { - setUpdateData({ - tag_name: tag_name, - content: marked.parse(body), - }); - setShowUpdateModal(true); + try { + setLoadingInput((loadingInput) => ({ ...loadingInput, CheckUpdate: true })); + // Use a CORS proxy to avoid direct cross-origin requests to GitHub API + // Option 1: Use a public CORS proxy service + // const proxyUrl = 'https://cors-anywhere.herokuapp.com/'; + // const res = await API.get( + // `${proxyUrl}https://api.github.com/repos/Calcium-Ion/new-api/releases/latest`, + // ); + + // Option 2: Use the JSON proxy approach which often works better with GitHub API + const res = await fetch( + 'https://api.github.com/repos/Calcium-Ion/new-api/releases/latest', + { + headers: { + 'Accept': 'application/json', + 'Content-Type': 'application/json', + // Adding User-Agent which is often required by GitHub API + 'User-Agent': 'new-api-update-checker' + } + } + ).then(response => response.json()); + + // Option 3: Use a local proxy endpoint + // Create a cached version of the response to avoid frequent GitHub API calls + // const res = await API.get('/api/status/github-latest-release'); + + const { tag_name, body } = res; + if (tag_name === statusState?.status?.version) { + showSuccess(`已是最新版本:${tag_name}`); + } else { + setUpdateData({ + tag_name: tag_name, + content: marked.parse(body), + }); + setShowUpdateModal(true); + } + } catch (error) { + console.error('Failed to check for updates:', error); + showError('检查更新失败,请稍后再试'); + } finally { + setLoadingInput((loadingInput) => ({ ...loadingInput, CheckUpdate: false })); } }; const getOptions = async () => { @@ -186,9 +215,41 @@ const OtherSetting = () => { getOptions(); }, []); + // Function to open GitHub release page + const openGitHubRelease = () => { + window.open(`https://github.com/Calcium-Ion/new-api/releases/tag/${updateData.tag_name}`, '_blank'); + }; + + const getStartTimeString = () => { + const timestamp = statusState?.status?.start_time; + return statusState.status ? timestamp2string(timestamp) : ''; + }; + return ( + {/* 版本信息 */} +
+ + + + + + {t('当前版本')}:{statusState?.status?.version || t('未知')} + + + + + + + + {t('启动时间')}:{getStartTimeString()} + + + +
{/* 通用设置 */}
{
- {/* setShowUpdateModal(false)}*/} - {/* onOpen={() => setShowUpdateModal(true)}*/} - {/* open={showUpdateModal}*/} - {/*>*/} - {/* 新版本:{updateData.tag_name}*/} - {/* */} - {/* */} - {/*
*/} - {/*
*/} - {/*
*/} - {/* */} - {/* */} - {/* {*/} - {/* setShowUpdateModal(false);*/} - {/* openGitHubRelease();*/} - {/* }}*/} - {/* />*/} - {/* */} - {/**/} + setShowUpdateModal(false)} + footer={[ + + ]} + > +
+
); }; diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index 49a0784ca412c034c28b82f874fa906df681a5d0..5ca95397c3f1bc83b871dfd015315834f2f22559 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -69,7 +69,11 @@ const PersonalSetting = () => { const [models, setModels] = useState([]); const [openTransfer, setOpenTransfer] = useState(false); const [transferAmount, setTransferAmount] = useState(0); - const [isModelsExpanded, setIsModelsExpanded] = useState(false); + const [isModelsExpanded, setIsModelsExpanded] = useState(() => { + // Initialize from localStorage if available + const savedState = localStorage.getItem('modelsExpanded'); + return savedState ? JSON.parse(savedState) : false; + }); const MODELS_DISPLAY_COUNT = 10; // 默认显示的模型数量 const [notificationSettings, setNotificationSettings] = useState({ warningType: 'email', @@ -124,6 +128,11 @@ const PersonalSetting = () => { } }, [userState?.user?.setting]); + // Save models expanded state to localStorage whenever it changes + useEffect(() => { + localStorage.setItem('modelsExpanded', JSON.stringify(isModelsExpanded)); + }, [isModelsExpanded]); + const handleInputChange = (name, value) => { setInputs((inputs) => ({...inputs, [name]: value})); }; @@ -384,7 +393,7 @@ const PersonalSetting = () => { -
+
{ pricing: '/pricing', task: '/task', playground: '/playground', + personal: '/personal', }; - const headerButtons = useMemo( + const workspaceItems = useMemo( () => [ - { - text: 'Playground', - itemKey: 'playground', - to: '/playground', - icon: , - }, - { - text: t('渠道'), - itemKey: 'channel', - to: '/channel', - icon: , - className: isAdmin() ? '' : 'tableHiddle', - }, - { - text: t('聊天'), - itemKey: 'chat', - items: chatItems, - icon: , - }, - { - text: t('令牌'), - itemKey: 'token', - to: '/token', - icon: , - }, { text: t('数据看板'), itemKey: 'detail', @@ -105,33 +81,19 @@ const SiderBar = () => { : 'tableHiddle', }, { - text: t('兑换码'), - itemKey: 'redemption', - to: '/redemption', - icon: , - className: isAdmin() ? '' : 'tableHiddle', - }, - { - text: t('钱包'), - itemKey: 'topup', - to: '/topup', - icon: , - }, - { - text: t('用户管理'), - itemKey: 'user', - to: '/user', - icon: , - className: isAdmin() ? '' : 'tableHiddle', + text: t('API令牌'), + itemKey: 'token', + to: '/token', + icon: , }, { - text: t('日志'), + text: t('使用日志'), itemKey: 'log', to: '/log', icon: , }, { - text: t('绘图'), + text: t('绘图日志'), itemKey: 'midjourney', to: '/midjourney', icon: , @@ -141,31 +103,90 @@ const SiderBar = () => { : 'tableHiddle', }, { - text: t('异步任务'), + text: t('任务日志'), itemKey: 'task', to: '/task', icon: , className: - localStorage.getItem('enable_task') === 'true' - ? '' - : 'tableHiddle', + localStorage.getItem('enable_task') === 'true' + ? '' + : 'tableHiddle', + } + ], + [ + localStorage.getItem('enable_data_export'), + localStorage.getItem('enable_drawing'), + localStorage.getItem('enable_task'), + t, + ], + ); + + const financeItems = useMemo( + () => [ + { + text: t('钱包'), + itemKey: 'topup', + to: '/topup', + icon: , }, { - text: t('设置'), + text: t('个人设置'), + itemKey: 'personal', + to: '/personal', + icon: , + }, + ], + [t], + ); + + const adminItems = useMemo( + () => [ + { + text: t('渠道'), + itemKey: 'channel', + to: '/channel', + icon: , + className: isAdmin() ? '' : 'tableHiddle', + }, + { + text: t('兑换码'), + itemKey: 'redemption', + to: '/redemption', + icon: , + className: isAdmin() ? '' : 'tableHiddle', + }, + { + text: t('用户管理'), + itemKey: 'user', + to: '/user', + icon: , + }, + { + text: t('系统设置'), itemKey: 'setting', to: '/setting', icon: , }, ], - [ - localStorage.getItem('enable_data_export'), - localStorage.getItem('enable_drawing'), - localStorage.getItem('enable_task'), - localStorage.getItem('chat_link'), - chatItems, - isAdmin(), - t, + [isAdmin(), t], + ); + + const chatMenuItems = useMemo( + () => [ + { + text: 'Playground', + itemKey: 'playground', + to: '/playground', + icon: , + }, + { + text: t('聊天'), + itemKey: 'chat', + items: chatItems, + icon: , + }, ], + [chatItems, t], ); useEffect(() => { @@ -174,42 +195,56 @@ const SiderBar = () => { localKey = 'home'; } setSelectedKeys([localKey]); - + let chatLink = localStorage.getItem('chat_link'); if (!chatLink) { - let chats = localStorage.getItem('chats'); - if (chats) { - // console.log(chats); - try { - chats = JSON.parse(chats); - if (Array.isArray(chats)) { - let chatItems = []; - for (let i = 0; i < chats.length; i++) { - let chat = {}; - for (let key in chats[i]) { - chat.text = key; - chat.itemKey = 'chat' + i; - chat.to = '/chat/' + i; - } - // setRouterMap({ ...routerMap, chat: '/chat/' + i }) - chatItems.push(chat); - } - setChatItems(chatItems); - } - } catch (e) { - console.error(e); - showError('聊天数据解析失败') + let chats = localStorage.getItem('chats'); + if (chats) { + // console.log(chats); + try { + chats = JSON.parse(chats); + if (Array.isArray(chats)) { + let chatItems = []; + for (let i = 0; i < chats.length; i++) { + let chat = {}; + for (let key in chats[i]) { + chat.text = key; + chat.itemKey = 'chat' + i; + chat.to = '/chat/' + i; + } + // setRouterMap({ ...routerMap, chat: '/chat/' + i }) + chatItems.push(chat); } + setChatItems(chatItems); + } + } catch (e) { + console.error(e); + showError('聊天数据解析失败') } + } } - + setIsCollapsed(localStorage.getItem('default_collapse_sidebar') === 'true'); }, []); + // Custom divider style + const dividerStyle = { + margin: '8px 0', + opacity: 0.6, + }; + + // Custom group label style + const groupLabelStyle = { + padding: '8px 16px', + color: 'var(--semi-color-text-2)', + fontSize: '12px', + fontWeight: 'normal', + }; + return ( <> ); diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index dec74b06237a768e3126453fab2b720853e04492..5738d656bb1fca9c015344e5eeb0804861e6d655 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -82,7 +82,7 @@ export const CHANNEL_OPTIONS = [ { value: 45, color: 'blue', - label: '火山方舟(豆包)' + label: '字节火山方舟、豆包、DeepSeek通用' }, { value: 25, color: 'green', label: 'Moonshot' }, { value: 19, color: 'blue', label: '360 智脑' }, diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index aa2fb2d5eed6b9a89ec0a7aed66e2a9e7adeba18..89b2bcbb2a76297840b18cf42da33a05aedca81c 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -621,6 +621,7 @@ "窗口等待": "window wait", "失败": "Failed", "绘图": "Drawing", + "绘图日志": "Drawing log", "放大": "Upscalers", "微妙放大": "Upscale (Subtle)", "创造放大": "Upscale (Creative)", @@ -1120,7 +1121,7 @@ "知识库 ID": "Knowledge Base ID", "请输入知识库 ID,例如:123456": "Please enter knowledge base ID, e.g.: 123456", "可选值": "Optional value", - "异步任务": "Async task", + "任务日志": "Task log", "你好": "Hello", "你好,请问有什么可以帮助您的吗?": "Hello, how may I help you?", "用户分组": "Your default group", @@ -1317,5 +1318,25 @@ "当前设置类型: ": "Current setting type: ", "固定价格值": "Fixed Price Value", "未设置倍率模型": "Models without ratio settings", - "模型倍率和补全倍率同时设置": "Both model ratio and completion ratio are set" + "模型倍率和补全倍率同时设置": "Both model ratio and completion ratio are set", + "自用模式": "Self-use mode", + "开启后不限制:必须设置模型倍率": "After enabling, no limit: must set model ratio", + "演示站点模式": "Demo site mode", + "当前版本": "Current version", + "Gemini设置": "Gemini settings", + "Gemini安全设置": "Gemini safety settings", + "default为默认设置,可单独设置每个分类的安全等级": "\"default\" is the default setting, and each category can be set separately", + "Gemini版本设置": "Gemini version settings", + "default为默认设置,可单独设置每个模型的版本": "\"default\" is the default setting, and each model can be set separately", + "Claude设置": "Claude settings", + "Claude请求头覆盖": "Claude request header override", + "示例": "Example", + "缺省 MaxTokens": "Default MaxTokens", + "启用Claude思考适配(-thinking后缀)": "Enable Claude thinking adaptation (-thinking suffix)", + "Claude思考适配 BudgetTokens = MaxTokens * BudgetTokens 百分比": "Claude thinking adaptation BudgetTokens = MaxTokens * BudgetTokens percentage", + "思考适配 BudgetTokens 百分比": "Thinking adaptation BudgetTokens percentage", + "0.1-1之间的小数": "Decimal between 0.1 and 1", + "模型相关设置": "Model related settings", + "收起侧边栏": "Collapse sidebar", + "展开侧边栏": "Expand sidebar" } diff --git a/web/src/pages/Setting/Model/SettingClaudeModel.js b/web/src/pages/Setting/Model/SettingClaudeModel.js index 76ee8cfa351ebef3b87def51832d8fada9f1845d..1cddd8390b8878dc38dbed9c2d5a363b5e4a1a49 100644 --- a/web/src/pages/Setting/Model/SettingClaudeModel.js +++ b/web/src/pages/Setting/Model/SettingClaudeModel.js @@ -18,6 +18,8 @@ const CLAUDE_HEADER = { const CLAUDE_DEFAULT_MAX_TOKENS = { 'default': 8192, + "claude-3-haiku-20240307": 4096, + "claude-3-opus-20240229": 4096, 'claude-3-7-sonnet-20250219-thinking': 8192, } diff --git a/web/src/pages/Setting/Operation/SettingsGeneral.js b/web/src/pages/Setting/Operation/SettingsGeneral.js index 1c98d33eb7e6d088efbfb9870c268a406a3d06b8..e46e7db2aa234388e54d17af7190f46925878b06 100644 --- a/web/src/pages/Setting/Operation/SettingsGeneral.js +++ b/web/src/pages/Setting/Operation/SettingsGeneral.js @@ -22,6 +22,7 @@ export default function GeneralSettings(props) { DisplayTokenStatEnabled: false, DefaultCollapseSidebar: false, DemoSiteEnabled: false, + SelfUseModeEnabled: false, }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -205,6 +206,22 @@ export default function GeneralSettings(props) { } /> + + + setInputs({ + ...inputs, + SelfUseModeEnabled: value + }) + } + /> +