diff --git a/.env.example b/.env.example index 07602eca08e83e9aa5d422a4f10c16951ed3e75d..bece06dbc8276aef137f8a16080e08579c99afdd 100644 --- a/.env.example +++ b/.env.example @@ -50,10 +50,6 @@ # CHANNEL_TEST_FREQUENCY=10 # 生成默认token # GENERATE_DEFAULT_TOKEN=false -# Gemini 安全设置 -# GEMINI_SAFETY_SETTING=BLOCK_NONE -# Gemini版本设置 -# GEMINI_MODEL_MAP=gemini-1.0-pro:v1 # Cohere 安全设置 # COHERE_SAFETY_SETTING=NONE # 是否统计图片token diff --git a/README.en.md b/README.en.md index 446c88f61455c2fbcbbbb0e4358214ae5d217e30..51cf38bb252d65f9a27ef043d039605075f7957b 100644 --- a/README.en.md +++ b/README.en.md @@ -63,6 +63,8 @@ - Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`) - Add suffix `-medium` to set medium reasoning effort - Add suffix `-low` to set low reasoning effort +17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `` tags and concatenated to the content returned. +18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings` ## Model Support This version additionally supports: diff --git a/VERSION b/VERSION index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..bab2e920752421bdf1b3550124d43709a8caabb6 100644 --- a/VERSION +++ b/VERSION @@ -0,0 +1 @@ +v0.4.8.8.3 \ No newline at end of file diff --git a/common/constants.go b/common/constants.go index 04fb1b9a632852f9b646c0dda946a1eaef386e95..bcab24fc0a367c9e52b3589e64e462a102900ef7 100644 --- a/common/constants.go +++ b/common/constants.go @@ -276,7 +276,7 @@ var ChannelBaseURLs = []string{ "https://api.cohere.ai", //34 "https://api.minimax.chat", //35 "", //36 - "", //37 + "https://api.dify.ai", //37 "https://api.jina.ai", //38 "https://api.cloudflare.com", //39 "https://api.siliconflow.cn", //40 diff --git a/common/model-ratio.go b/common/model-ratio.go index 542cd93c6965310e95da0c7e045d1b26da95c3ae..036811720ddbfa4c8e5f24be06f144d97ed73f76 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -50,24 +50,26 @@ var defaultModelRatio = map[string]float64{ "gpt-4o-realtime-preview-2024-12-17": 2.5, "gpt-4o-mini-realtime-preview": 0.3, "gpt-4o-mini-realtime-preview-2024-12-17": 0.3, - "o1": 7.5, - "o1-2024-12-17": 7.5, - "o1-preview": 7.5, - "o1-preview-2024-09-12": 7.5, - "o1-mini": 0.55, - "o1-mini-2024-09-12": 0.55, - "o3-mini": 0.55, - "o3-mini-2025-01-31": 0.55, - "o3-mini-high": 0.55, - "o3-mini-2025-01-31-high": 0.55, - "o3-mini-low": 0.55, - "o3-mini-2025-01-31-low": 0.55, - "o3-mini-medium": 0.55, - "o3-mini-2025-01-31-medium": 0.55, - "gpt-4o-mini": 0.075, - "gpt-4o-mini-2024-07-18": 0.075, - "gpt-4-turbo": 5, // $0.01 / 1K tokens - "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens + "o1": 7.5, + "o1-2024-12-17": 7.5, + "o1-preview": 7.5, + "o1-preview-2024-09-12": 7.5, + "o1-mini": 0.55, + "o1-mini-2024-09-12": 0.55, + "o3-mini": 0.55, + "o3-mini-2025-01-31": 0.55, + "o3-mini-high": 0.55, + "o3-mini-2025-01-31-high": 0.55, + "o3-mini-low": 0.55, + "o3-mini-2025-01-31-low": 0.55, + "o3-mini-medium": 0.55, + "o3-mini-2025-01-31-medium": 0.55, + "gpt-4o-mini": 0.075, + "gpt-4o-mini-2024-07-18": 0.075, + "gpt-4-turbo": 5, // $0.01 / 1K tokens + "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens + "gpt-4.5-preview": 37.5, + "gpt-4.5-preview-2025-02-27": 37.5, //"gpt-3.5-turbo-0301": 0.75, //deprecated "gpt-3.5-turbo": 0.25, "gpt-3.5-turbo-0613": 0.75, @@ -83,92 +85,94 @@ var defaultModelRatio = map[string]float64{ "text-curie-001": 1, //"text-davinci-002": 10, //"text-davinci-003": 10, - "text-davinci-edit-001": 10, - "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens - "tts-1": 7.5, // 1k characters -> $0.015 - "tts-1-1106": 7.5, // 1k characters -> $0.015 - "tts-1-hd": 15, // 1k characters -> $0.03 - "tts-1-hd-1106": 15, // 1k characters -> $0.03 - "davinci": 10, - "curie": 10, - "babbage": 10, - "ada": 10, - "text-embedding-3-small": 0.01, - "text-embedding-3-large": 0.065, - "text-embedding-ada-002": 0.05, - "text-search-ada-doc-001": 10, - "text-moderation-stable": 0.1, - "text-moderation-latest": 0.1, - "claude-instant-1": 0.4, // $0.8 / 1M tokens - "claude-2.0": 4, // $8 / 1M tokens - "claude-2.1": 4, // $8 / 1M tokens - "claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens - "claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens - "claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens - "claude-3-5-sonnet-20240620": 1.5, - "claude-3-5-sonnet-20241022": 1.5, - "claude-3-opus-20240229": 7.5, // $15 / 1M tokens - "ERNIE-4.0-8K": 0.120 * RMB, - "ERNIE-3.5-8K": 0.012 * RMB, - "ERNIE-3.5-8K-0205": 0.024 * RMB, - "ERNIE-3.5-8K-1222": 0.012 * RMB, - "ERNIE-Bot-8K": 0.024 * RMB, - "ERNIE-3.5-4K-0205": 0.012 * RMB, - "ERNIE-Speed-8K": 0.004 * RMB, - "ERNIE-Speed-128K": 0.004 * RMB, - "ERNIE-Lite-8K-0922": 0.008 * RMB, - "ERNIE-Lite-8K-0308": 0.003 * RMB, - "ERNIE-Tiny-8K": 0.001 * RMB, - "BLOOMZ-7B": 0.004 * RMB, - "Embedding-V1": 0.002 * RMB, - "bge-large-zh": 0.002 * RMB, - "bge-large-en": 0.002 * RMB, - "tao-8k": 0.002 * RMB, - "PaLM-2": 1, - "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-1.0-pro-vision-001": 1, - "gemini-1.0-pro-001": 1, - "gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens - "gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens - "gemini-1.5-flash-latest": 1, - "gemini-1.5-flash-exp-0827": 1, - "gemini-1.0-pro-latest": 1, - "gemini-1.0-pro-vision-latest": 1, - "gemini-ultra": 1, - "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens - "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens - "chatglm_std": 0.3572, // ¥0.005 / 1k tokens - "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "glm-4": 7.143, // ¥0.1 / 1k tokens - "glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens - "glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens - "glm-3-turbo": 0.3572, - "glm-4-plus": 0.05 * RMB, - "glm-4-0520": 0.1 * RMB, - "glm-4-air": 0.001 * RMB, - "glm-4-airx": 0.01 * RMB, - "glm-4-long": 0.001 * RMB, - "glm-4-flash": 0, - "glm-4v-plus": 0.01 * RMB, - "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens - "qwen-plus": 10, // ¥0.14 / 1k tokens - "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens - "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens - "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens - "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens - "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens - "SparkDesk-v4.0": 1.2858, - "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens - "360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens - "360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens - "360gpt-pro": 0.8572, // ¥0.012 / 1k tokens - "360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens - "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens - "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens - "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens - "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 + "text-davinci-edit-001": 10, + "code-davinci-edit-001": 10, + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": 7.5, // 1k characters -> $0.015 + "tts-1-1106": 7.5, // 1k characters -> $0.015 + "tts-1-hd": 15, // 1k characters -> $0.03 + "tts-1-hd-1106": 15, // 1k characters -> $0.03 + "davinci": 10, + "curie": 10, + "babbage": 10, + "ada": 10, + "text-embedding-3-small": 0.01, + "text-embedding-3-large": 0.065, + "text-embedding-ada-002": 0.05, + "text-search-ada-doc-001": 10, + "text-moderation-stable": 0.1, + "text-moderation-latest": 0.1, + "claude-instant-1": 0.4, // $0.8 / 1M tokens + "claude-2.0": 4, // $8 / 1M tokens + "claude-2.1": 4, // $8 / 1M tokens + "claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens + "claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens + "claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens + "claude-3-5-sonnet-20240620": 1.5, + "claude-3-5-sonnet-20241022": 1.5, + "claude-3-7-sonnet-20250219": 1.5, + "claude-3-7-sonnet-20250219-thinking": 1.5, + "claude-3-opus-20240229": 7.5, // $15 / 1M tokens + "ERNIE-4.0-8K": 0.120 * RMB, + "ERNIE-3.5-8K": 0.012 * RMB, + "ERNIE-3.5-8K-0205": 0.024 * RMB, + "ERNIE-3.5-8K-1222": 0.012 * RMB, + "ERNIE-Bot-8K": 0.024 * RMB, + "ERNIE-3.5-4K-0205": 0.012 * RMB, + "ERNIE-Speed-8K": 0.004 * RMB, + "ERNIE-Speed-128K": 0.004 * RMB, + "ERNIE-Lite-8K-0922": 0.008 * RMB, + "ERNIE-Lite-8K-0308": 0.003 * RMB, + "ERNIE-Tiny-8K": 0.001 * RMB, + "BLOOMZ-7B": 0.004 * RMB, + "Embedding-V1": 0.002 * RMB, + "bge-large-zh": 0.002 * RMB, + "bge-large-en": 0.002 * RMB, + "tao-8k": 0.002 * RMB, + "PaLM-2": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-1.0-pro-vision-001": 1, + "gemini-1.0-pro-001": 1, + "gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens + "gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens + "gemini-1.5-flash-latest": 1, + "gemini-1.5-flash-exp-0827": 1, + "gemini-1.0-pro-latest": 1, + "gemini-1.0-pro-vision-latest": 1, + "gemini-ultra": 1, + "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens + "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens + "chatglm_std": 0.3572, // ¥0.005 / 1k tokens + "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "glm-4": 7.143, // ¥0.1 / 1k tokens + "glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens + "glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens + "glm-3-turbo": 0.3572, + "glm-4-plus": 0.05 * RMB, + "glm-4-0520": 0.1 * RMB, + "glm-4-air": 0.001 * RMB, + "glm-4-airx": 0.01 * RMB, + "glm-4-long": 0.001 * RMB, + "glm-4-flash": 0, + "glm-4v-plus": 0.01 * RMB, + "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens + "qwen-plus": 10, // ¥0.14 / 1k tokens + "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens + "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v4.0": 1.2858, + "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens + "360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens + "360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens + "360gpt-pro": 0.8572, // ¥0.012 / 1k tokens + "360gpt2-pro": 0.8572, // ¥0.012 / 1k tokens + "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens + "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens + "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens + "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 // https://platform.lingyiwanwu.com/docs#-计费单元 // 已经按照 7.2 来换算美元价格 "yi-34b-chat-0205": 0.18, @@ -313,7 +317,7 @@ func UpdateModelRatioByJSONString(jsonStr string) error { return json.Unmarshal([]byte(jsonStr), &modelRatioMap) } -func GetModelRatio(name string) float64 { +func GetModelRatio(name string) (float64, bool) { GetModelRatioMap() if strings.HasPrefix(name, "gpt-4-gizmo") { name = "gpt-4-gizmo-*" @@ -321,9 +325,9 @@ func GetModelRatio(name string) float64 { ratio, ok := modelRatioMap[name] if !ok { SysError("model ratio not found: " + name) - return 30 + return 37.5, false } - return ratio + return ratio, true } func DefaultModelRatio2JSONString() string { @@ -385,6 +389,9 @@ func GetCompletionRatio(name string) float64 { } return 4 } + if strings.HasPrefix(name, "gpt-4.5") { + return 2 + } if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") { return 3 } diff --git a/common/redis.go b/common/redis.go index 02582ee2193f3f1c0108bf8b2c10d2fa93fe94e1..49d3ec78a9edd7b8125217d7e54fd292d68b8eaa 100644 --- a/common/redis.go +++ b/common/redis.go @@ -32,6 +32,7 @@ func InitRedisClient() (err error) { if err != nil { FatalLog("failed to parse Redis connection string: " + err.Error()) } + opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10) RDB = redis.NewClient(opt) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -41,6 +42,10 @@ func InitRedisClient() (err error) { if err != nil { FatalLog("Redis ping test failed: " + err.Error()) } + if DebugEnabled { + SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr)) + SysLog(fmt.Sprintf("Redis database: %d", opt.DB)) + } return err } @@ -53,13 +58,20 @@ func ParseRedisOption() *redis.Options { } func RedisSet(key string, value string, expiration time.Duration) error { + if DebugEnabled { + SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration)) + } ctx := context.Background() return RDB.Set(ctx, key, value, expiration).Err() } func RedisGet(key string) (string, error) { + if DebugEnabled { + SysLog(fmt.Sprintf("Redis GET: key=%s", key)) + } ctx := context.Background() - return RDB.Get(ctx, key).Result() + val, err := RDB.Get(ctx, key).Result() + return val, err } //func RedisExpire(key string, expiration time.Duration) error { @@ -73,16 +85,25 @@ func RedisGet(key string) (string, error) { //} func RedisDel(key string) error { + if DebugEnabled { + SysLog(fmt.Sprintf("Redis DEL: key=%s", key)) + } ctx := context.Background() return RDB.Del(ctx, key).Err() } func RedisHDelObj(key string) error { + if DebugEnabled { + SysLog(fmt.Sprintf("Redis HDEL: key=%s", key)) + } ctx := context.Background() return RDB.HDel(ctx, key).Err() } func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error { + if DebugEnabled { + SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration)) + } ctx := context.Background() data := make(map[string]interface{}) @@ -130,6 +151,9 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error { } func RedisHGetObj(key string, obj interface{}) error { + if DebugEnabled { + SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key)) + } ctx := context.Background() result, err := RDB.HGetAll(ctx, key).Result() @@ -208,6 +232,9 @@ func RedisHGetObj(key string, obj interface{}) error { // RedisIncr Add this function to handle atomic increments func RedisIncr(key string, delta int64) error { + if DebugEnabled { + SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta)) + } // 检查键的剩余生存时间 ttlCmd := RDB.TTL(context.Background(), key) ttl, err := ttlCmd.Result() @@ -238,6 +265,9 @@ func RedisIncr(key string, delta int64) error { } func RedisHIncrBy(key, field string, delta int64) error { + if DebugEnabled { + SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta)) + } ttlCmd := RDB.TTL(context.Background(), key) ttl, err := ttlCmd.Result() if err != nil && !errors.Is(err, redis.Nil) { @@ -262,6 +292,9 @@ func RedisHIncrBy(key, field string, delta int64) error { } func RedisHSetField(key, field string, value interface{}) error { + if DebugEnabled { + SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value)) + } ttlCmd := RDB.TTL(context.Background(), key) ttl, err := ttlCmd.Result() if err != nil && !errors.Is(err, redis.Nil) { diff --git a/common/utils.go b/common/utils.go index fb769a7ceea2c565fb61914bf485075bba568daf..e57801e35fcb8a609515fa974fa5dcf475bde361 100644 --- a/common/utils.go +++ b/common/utils.go @@ -5,6 +5,7 @@ import ( "context" crand "crypto/rand" "encoding/base64" + "encoding/json" "fmt" "github.com/pkg/errors" "html/template" @@ -213,6 +214,24 @@ func RandomSleep() { time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) } +func GetPointer[T any](v T) *T { + return &v +} + +func Any2Type[T any](data any) (T, error) { + var zero T + bytes, err := json.Marshal(data) + if err != nil { + return zero, err + } + var res T + err = json.Unmarshal(bytes, &res) + if err != nil { + return zero, err + } + return res, nil +} + // SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string. func SaveTmpFile(filename string, data io.Reader) (string, error) { f, err := os.CreateTemp(os.TempDir(), filename) diff --git a/constant/channel_setting.go b/constant/channel_setting.go index 6eccfb8433927e2633411b4035d10da488e0dd97..e06e7eb121a630ab29487cb02a37803e1661ea3f 100644 --- a/constant/channel_setting.go +++ b/constant/channel_setting.go @@ -1,6 +1,7 @@ package constant var ( - ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式 - ChanelSettingProxy = "proxy" // Proxy 代理 + ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式 + ChanelSettingProxy = "proxy" // Proxy 代理 + ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent ) diff --git a/constant/context_key.go b/constant/context_key.go index b02f2d43d911117f3a645b4c3ee3e30924c9186c..4b4d5cae0f0475666611c040e5a303bde6cc2a89 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -2,4 +2,9 @@ package constant const ( ContextKeyRequestStartTime = "request_start_time" + ContextKeyUserSetting = "user_setting" + ContextKeyUserQuota = "user_quota" + ContextKeyUserStatus = "user_status" + ContextKeyUserEmail = "user_email" + ContextKeyUserGroup = "user_group" ) diff --git a/constant/env.go b/constant/env.go index bffbfeea5ba1efbafdd20e176d5d0dfe2402c021..d2a1d04da0b95dc7b4bedaa4f84ee91346a5b495 100644 --- a/constant/env.go +++ b/constant/env.go @@ -1,10 +1,7 @@ package constant import ( - "fmt" "one-api/common" - "os" - "strings" ) var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60) @@ -23,9 +20,9 @@ var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview") -var GeminiModelMap = map[string]string{ - "gemini-1.0-pro": "v1", -} +//var GeminiModelMap = map[string]string{ +// "gemini-1.0-pro": "v1", +//} var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) @@ -33,18 +30,18 @@ var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) func InitEnv() { - modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) - if modelVersionMapStr == "" { - return - } - for _, pair := range strings.Split(modelVersionMapStr, ",") { - parts := strings.Split(pair, ":") - if len(parts) == 2 { - GeminiModelMap[parts[0]] = parts[1] - } else { - common.SysError(fmt.Sprintf("invalid model version map: %s", pair)) - } - } + //modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) + //if modelVersionMapStr == "" { + // return + //} + //for _, pair := range strings.Split(modelVersionMapStr, ",") { + // parts := strings.Split(pair, ":") + // if len(parts) == 2 { + // GeminiModelMap[parts[0]] = parts[1] + // } else { + // common.SysError(fmt.Sprintf("invalid model version map: %s", pair)) + // } + //} } // GenerateDefaultToken 是否生成初始令牌,默认关闭。 diff --git a/controller/channel-test.go b/controller/channel-test.go index 4b0cc169cb05c1d71b0443eff84ebe2681786c5b..23922073f0132db5dfc83d39c74194ae63b14b63 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -48,7 +48,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr if strings.Contains(strings.ToLower(testModel), "embedding") || strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 strings.Contains(testModel, "bge-") || // bge 系列模型 - testModel == "text-embedding-v1" || + strings.Contains(testModel, "embed") || channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型 requestPath = "/v1/embeddings" // 修改请求路径 } @@ -84,6 +84,12 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } } + cache, err := model.GetUserCache(1) + if err != nil { + return err, nil + } + cache.WriteContext(c) + c.Request.Header.Set("Authorization", "Bearer "+channel.Key) c.Request.Header.Set("Content-Type", "application/json") c.Set("channel", channel.Type) @@ -140,7 +146,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr return err, nil } modelPrice, usePrice := common.GetModelPrice(testModel, false) - modelRatio := common.GetModelRatio(testModel) + modelRatio, success := common.GetModelRatio(testModel) + if !usePrice && !success { + return fmt.Errorf("模型 %s 倍率和价格均未设置", testModel), nil + } completionRatio := common.GetCompletionRatio(testModel) ratio := modelRatio quota := 0 diff --git a/controller/midjourney.go b/controller/midjourney.go index 2e351535c23f45361c783cdebf68bd5aabc4f2f9..21027d8f45c56ca10edf32401f847b3805ec54cc 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -159,7 +159,7 @@ func UpdateMidjourneyTaskBulk() { common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) } else { if shouldReturnQuota { - err = model.IncreaseUserQuota(task.UserId, task.Quota) + err = model.IncreaseUserQuota(task.UserId, task.Quota, false) if err != nil { common.LogError(ctx, "fail to increase user quota: "+err.Error()) } diff --git a/controller/model.go b/controller/model.go index 8ec2c7c90b921abf88a71a20ba778325340f347a..df7e59a6eeaf5ae282a4a99c9e1ef870d5f53e4a 100644 --- a/controller/model.go +++ b/controller/model.go @@ -216,6 +216,13 @@ func DashboardListModels(c *gin.Context) { }) } +func EnabledListModels(c *gin.Context) { + c.JSON(200, gin.H{ + "success": true, + "data": model.GetEnabledModels(), + }) +} + func RetrieveModel(c *gin.Context) { modelId := c.Param("model") if aiModel, ok := openAIModelsMap[modelId]; ok { diff --git a/controller/relay.go b/controller/relay.go index 0f7394156ce8d9a57f63fe64ed8e390b2c8373ca..e27ebb80f5015fc5cdfd1b65a51f64618105e322 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -85,6 +85,7 @@ func Relay(c *gin.Context) { if openaiErr != nil { if openaiErr.StatusCode == http.StatusTooManyRequests { + common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message)) openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" } openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId) diff --git a/controller/task.go b/controller/task.go index 928f7ed7bef4433e442a883863765edc3edbf6f7..65f79ead252abe25e147e534b15064be655cdd70 100644 --- a/controller/task.go +++ b/controller/task.go @@ -159,7 +159,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas } else { quota := task.Quota if quota != 0 { - err = model.IncreaseUserQuota(task.UserId, quota) + err = model.IncreaseUserQuota(task.UserId, quota, false) if err != nil { common.LogError(ctx, "fail to increase user quota: "+err.Error()) } diff --git a/controller/topup.go b/controller/topup.go index fb51c545a99be0ca836920722032d376e58f62e5..a342ec3aeb59d86ca5ced84f86b862dd6b1ab609 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -210,7 +210,7 @@ func EpayNotify(c *gin.Context) { } //user, _ := model.GetUserById(topUp.UserId, false) //user.Quota += topUp.Amount * 500000 - err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit)) + err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit), true) if err != nil { log.Printf("易支付回调更新用户失败: %v", topUp) return diff --git a/docs/channel/other_setting.md b/docs/channel/other_setting.md index 775da5573775f98eb8641fe64f9283ad74ee6a5a..b3f4f969cd3f56eb1a0bba820104fab2d8379569 100644 --- a/docs/channel/other_setting.md +++ b/docs/channel/other_setting.md @@ -10,6 +10,10 @@ - 用于配置网络代理 - 类型为字符串,填写代理地址(例如 socks5 协议的代理地址) +3. thinking_to_content + - 用于标识是否将思考内容`reasoning_conetnt`转换为``标签拼接到内容中返回 + - 类型为布尔值,设置为 true 时启用思考内容转换 + -------------------------------------------------------------- ## JSON 格式示例 @@ -19,6 +23,7 @@ ```json { "force_format": true, + "thinking_to_content": true, "proxy": "socks5://xxxxxxx" } ``` diff --git a/dto/openai_request.go b/dto/openai_request.go index 028e0286cdb26188cc8c20b351830e70a46cfae6..6f12a19e5b01e632fbb1035b38bc7776a5f78103 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -1,6 +1,9 @@ package dto -import "encoding/json" +import ( + "encoding/json" + "strings" +) type ResponseFormat struct { Type string `json:"type,omitempty"` @@ -15,49 +18,52 @@ type FormatJsonSchema struct { } type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Prefix any `json:"prefix,omitempty"` - Suffix any `json:"suffix,omitempty"` - Stream bool `json:"stream,omitempty"` - StreamOptions *StreamOptions `json:"stream_options,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` - ReasoningEffort string `json:"reasoning_effort,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Stop any `json:"stop,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - EncodingFormat any `json:"encoding_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - Tools []ToolCall `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` - LogProbs bool `json:"logprobs,omitempty"` - TopLogProbs int `json:"top_logprobs,omitempty"` - Dimensions int `json:"dimensions,omitempty"` - Modalities any `json:"modalities,omitempty"` - Audio any `json:"audio,omitempty"` + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt any `json:"prompt,omitempty"` + Prefix any `json:"prefix,omitempty"` + Suffix any `json:"suffix,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` + MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Stop any `json:"stop,omitempty"` + N int `json:"n,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` + Functions any `json:"functions,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + EncodingFormat any `json:"encoding_format,omitempty"` + Seed float64 `json:"seed,omitempty"` + Tools []ToolCallRequest `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + User string `json:"user,omitempty"` + LogProbs bool `json:"logprobs,omitempty"` + TopLogProbs int `json:"top_logprobs,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + Modalities any `json:"modalities,omitempty"` + Audio any `json:"audio,omitempty"` + ExtraBody any `json:"extra_body,omitempty"` } -type OpenAITools struct { - Type string `json:"type"` - Function OpenAIFunction `json:"function"` +type ToolCallRequest struct { + ID string `json:"id,omitempty"` + Type string `json:"type"` + Function FunctionRequest `json:"function"` } -type OpenAIFunction struct { +type FunctionRequest struct { Description string `json:"description,omitempty"` Name string `json:"name"` Parameters any `json:"parameters,omitempty"` + Arguments string `json:"arguments,omitempty"` } type StreamOptions struct { @@ -133,11 +139,11 @@ func (m *Message) SetPrefix(prefix bool) { m.Prefix = &prefix } -func (m *Message) ParseToolCalls() []ToolCall { +func (m *Message) ParseToolCalls() []ToolCallRequest { if m.ToolCalls == nil { return nil } - var toolCalls []ToolCall + var toolCalls []ToolCallRequest if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil { return toolCalls } @@ -153,11 +159,24 @@ func (m *Message) StringContent() string { if m.parsedStringContent != nil { return *m.parsedStringContent } + var stringContent string if err := json.Unmarshal(m.Content, &stringContent); err == nil { + m.parsedStringContent = &stringContent return stringContent } - return string(m.Content) + + contentStr := new(strings.Builder) + arrayContent := m.ParseContent() + for _, content := range arrayContent { + if content.Type == ContentTypeText { + contentStr.WriteString(content.Text) + } + } + stringContent = contentStr.String() + m.parsedStringContent = &stringContent + + return stringContent } func (m *Message) SetStringContent(content string) { diff --git a/dto/openai_response.go b/dto/openai_response.go index febf01ff0d58fbd66660c4d9a4258e449538b743..56fac585fbfc4d5eb5bbd3bd1d0442a576c93148 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -62,10 +62,10 @@ type ChatCompletionsStreamResponseChoice struct { } type ChatCompletionsStreamResponseChoiceDelta struct { - Content *string `json:"content,omitempty"` - ReasoningContent *string `json:"reasoning_content,omitempty"` - Role string `json:"role,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Content *string `json:"content,omitempty"` + ReasoningContent *string `json:"reasoning_content,omitempty"` + Role string `json:"role,omitempty"` + ToolCalls []ToolCallResponse `json:"tool_calls,omitempty"` } func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) { @@ -86,24 +86,28 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string return *c.ReasoningContent } -type ToolCall struct { +func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) { + c.ReasoningContent = &s +} + +type ToolCallResponse struct { // Index is not nil only in chat completion chunk object - Index *int `json:"index,omitempty"` - ID string `json:"id,omitempty"` - Type any `json:"type"` - Function FunctionCall `json:"function"` + Index *int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type any `json:"type"` + Function FunctionResponse `json:"function"` } -func (c *ToolCall) SetIndex(i int) { +func (c *ToolCallResponse) SetIndex(i int) { c.Index = &i } -type FunctionCall struct { +type FunctionResponse struct { Description string `json:"description,omitempty"` Name string `json:"name,omitempty"` // call function with arguments in JSON format Parameters any `json:"parameters,omitempty"` // request - Arguments string `json:"arguments,omitempty"` + Arguments string `json:"arguments"` // response } type ChatCompletionsStreamResponse struct { @@ -116,6 +120,20 @@ type ChatCompletionsStreamResponse struct { Usage *Usage `json:"usage"` } +func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse { + choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices)) + copy(choices, c.Choices) + return &ChatCompletionsStreamResponse{ + Id: c.Id, + Object: c.Object, + Created: c.Created, + Model: c.Model, + SystemFingerprint: c.SystemFingerprint, + Choices: choices, + Usage: c.Usage, + } +} + func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string { if c.SystemFingerprint == nil { return "" diff --git a/middleware/auth.go b/middleware/auth.go index 4d879a6cf624737b82d6f757ef5fd84c3c33f6c2..a589f52ccf0472bd7b67ed4d82a7e059f3f83ffe 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -199,15 +199,19 @@ func TokenAuth() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return } - userEnabled, err := model.IsUserEnabled(token.UserId, false) + userCache, err := model.GetUserCache(token.UserId) if err != nil { abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error()) return } + userEnabled := userCache.Status == common.UserStatusEnabled if !userEnabled { abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁") return } + + userCache.WriteContext(c) + c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_key", token.Key) diff --git a/middleware/distributor.go b/middleware/distributor.go index e0f9342a84f2c3f4858af8789bdd6e3592ed7048..49fcf59b8a13676481b7ed5b3b04c11c894252b7 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -32,7 +32,6 @@ func Distribute() func(c *gin.Context) { return } } - userId := c.GetInt("id") var channel *model.Channel channelId, ok := c.Get("specific_channel_id") modelRequest, shouldSelectChannel, err := getModelRequest(c) @@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) return } - userGroup, _ := model.GetUserGroup(userId, false) + userGroup := c.GetString(constant.ContextKeyUserGroup) tokenGroup := c.GetString("token_group") if tokenGroup != "" { // check common.UserUsableGroups[userGroup] diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go new file mode 100644 index 0000000000000000000000000000000000000000..135e00058ebcdc34b5590acd47f566a505f2f393 --- /dev/null +++ b/middleware/model-rate-limit.go @@ -0,0 +1,172 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + "one-api/common" + "one-api/setting" + "strconv" + "time" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" +) + +const ( + ModelRequestRateLimitCountMark = "MRRL" + ModelRequestRateLimitSuccessCountMark = "MRRLS" +) + +// 检查Redis中的请求限制 +func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { + // 如果maxCount为0,表示不限制 + if maxCount == 0 { + return true, nil + } + + // 获取当前计数 + length, err := rdb.LLen(ctx, key).Result() + if err != nil { + return false, err + } + + // 如果未达到限制,允许请求 + if length < int64(maxCount) { + return true, nil + } + + // 检查时间窗口 + oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() + oldTime, err := time.Parse(timeFormat, oldTimeStr) + if err != nil { + return false, err + } + + nowTimeStr := time.Now().Format(timeFormat) + nowTime, err := time.Parse(timeFormat, nowTimeStr) + if err != nil { + return false, err + } + // 如果在时间窗口内已达到限制,拒绝请求 + subTime := nowTime.Sub(oldTime).Seconds() + if int64(subTime) < duration { + rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + return false, nil + } + + return true, nil +} + +// 记录Redis请求 +func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { + // 如果maxCount为0,不记录请求 + if maxCount == 0 { + return + } + + now := time.Now().Format(timeFormat) + rdb.LPush(ctx, key, now) + rdb.LTrim(ctx, key, 0, int64(maxCount-1)) + rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) +} + +// Redis限流处理器 +func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { + return func(c *gin.Context) { + userId := strconv.Itoa(c.GetInt("id")) + ctx := context.Background() + rdb := common.RDB + + // 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过) + totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId) + allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration) + if err != nil { + fmt.Println("检查总请求数限制失败:", err.Error()) + abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") + return + } + if !allowed { + abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) + } + + // 2. 检查成功请求数限制 + successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) + allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) + if err != nil { + fmt.Println("检查成功请求数限制失败:", err.Error()) + abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") + return + } + if !allowed { + abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount)) + return + } + + // 3. 记录总请求(当totalMaxCount为0时会自动跳过) + recordRedisRequest(ctx, rdb, totalKey, totalMaxCount) + + // 4. 处理请求 + c.Next() + + // 5. 如果请求成功,记录成功请求 + if c.Writer.Status() < 400 { + recordRedisRequest(ctx, rdb, successKey, successMaxCount) + } + } +} + +// 内存限流处理器 +func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { + inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) + + return func(c *gin.Context) { + userId := strconv.Itoa(c.GetInt("id")) + totalKey := ModelRequestRateLimitCountMark + userId + successKey := ModelRequestRateLimitSuccessCountMark + userId + + // 1. 检查总请求数限制(当totalMaxCount为0时跳过) + if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { + c.Status(http.StatusTooManyRequests) + c.Abort() + return + } + + // 2. 检查成功请求数限制 + // 使用一个临时key来检查限制,这样可以避免实际记录 + checkKey := successKey + "_check" + if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { + c.Status(http.StatusTooManyRequests) + c.Abort() + return + } + + // 3. 处理请求 + c.Next() + + // 4. 如果请求成功,记录到实际的成功请求计数中 + if c.Writer.Status() < 400 { + inMemoryRateLimiter.Request(successKey, successMaxCount, duration) + } + } +} + +// ModelRequestRateLimit 模型请求限流中间件 +func ModelRequestRateLimit() func(c *gin.Context) { + // 如果未启用限流,直接放行 + if !setting.ModelRequestRateLimitEnabled { + return defNext + } + + // 计算限流参数 + duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) + totalMaxCount := setting.ModelRequestRateLimitCount + successMaxCount := setting.ModelRequestRateLimitSuccessCount + + // 根据存储类型选择限流处理器 + if common.RedisEnabled { + return redisRateLimitHandler(duration, totalMaxCount, successMaxCount) + } else { + return memoryRateLimitHandler(duration, totalMaxCount, successMaxCount) + } +} diff --git a/model/log.go b/model/log.go index 82278c60b69a692b8748a3c32cc2a4ad2f0d3051..ed7ec2c796ff54134ff529462bf71746f7856ce8 100644 --- a/model/log.go +++ b/model/log.go @@ -1,8 +1,8 @@ package model import ( - "context" "fmt" + "github.com/gin-gonic/gin" "one-api/common" "os" "strings" @@ -87,14 +87,14 @@ func RecordLog(userId int, logType int, content string) { } } -func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, +func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool, group string, other map[string]interface{}) { - common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) + common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !common.LogConsumeEnabled { return } - username, _ := GetUsernameById(userId, false) + username := c.GetString("username") otherStr := common.MapToJsonStr(other) log := &Log{ UserId: userId, @@ -116,7 +116,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke } err := LOG_DB.Create(log).Error if err != nil { - common.LogError(ctx, "failed to record log: "+err.Error()) + common.LogError(c, "failed to record log: "+err.Error()) } if common.DataExportEnabled { gopool.Go(func() { diff --git a/model/option.go b/model/option.go index 24935c69d1f7d0dac57ab4d2d4b72ba0a46f0a4f..64d15ca8e3a66f1147278c3d23c07e822c758d96 100644 --- a/model/option.go +++ b/model/option.go @@ -3,6 +3,7 @@ package model import ( "one-api/common" "one-api/setting" + "one-api/setting/config" "strconv" "strings" "time" @@ -23,6 +24,8 @@ func AllOption() ([]*Option, error) { func InitOptionMap() { common.OptionMapRWMutex.Lock() common.OptionMap = make(map[string]string) + + // 添加原有的系统配置 common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) @@ -85,6 +88,9 @@ func InitOptionMap() { common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) + common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) + common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) + common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() @@ -105,13 +111,19 @@ func InitOptionMap() { common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled) common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled) + common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled) common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) - //common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString() + // 自动添加所有注册的模型配置 + modelConfigs := config.GlobalConfig.ExportAllConfigs() + for k, v := range modelConfigs { + common.OptionMap[k] = v + } + common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() } @@ -154,6 +166,13 @@ func updateOptionMap(key string, value string) (err error) { common.OptionMapRWMutex.Lock() defer common.OptionMapRWMutex.Unlock() common.OptionMap[key] = value + + // 检查是否是模型配置 - 使用更规范的方式处理 + if handleConfigUpdate(key, value) { + return nil // 已由配置系统处理 + } + + // 处理传统配置项... if strings.HasSuffix(key, "Permission") { intValue, _ := strconv.Atoi(value) switch key { @@ -226,8 +245,8 @@ func updateOptionMap(key string, value string) (err error) { setting.DemoSiteEnabled = boolValue case "CheckSensitiveOnPromptEnabled": setting.CheckSensitiveOnPromptEnabled = boolValue - //case "CheckSensitiveOnCompletionEnabled": - // constant.CheckSensitiveOnCompletionEnabled = boolValue + case "ModelRequestRateLimitEnabled": + setting.ModelRequestRateLimitEnabled = boolValue case "StopOnSensitiveEnabled": setting.StopOnSensitiveEnabled = boolValue case "SMTPSSLEnabled": @@ -308,6 +327,12 @@ func updateOptionMap(key string, value string) (err error) { common.QuotaRemindThreshold, _ = strconv.Atoi(value) case "ShouldPreConsumedQuota": common.PreConsumedQuota, _ = strconv.Atoi(value) + case "ModelRequestRateLimitCount": + setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value) + case "ModelRequestRateLimitDurationMinutes": + setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) + case "ModelRequestRateLimitSuccessCount": + setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) case "DataExportInterval": @@ -343,3 +368,28 @@ func updateOptionMap(key string, value string) (err error) { } return err } + +// handleConfigUpdate 处理分层配置更新,返回是否已处理 +func handleConfigUpdate(key, value string) bool { + parts := strings.SplitN(key, ".", 2) + if len(parts) != 2 { + return false // 不是分层配置 + } + + configName := parts[0] + configKey := parts[1] + + // 获取配置对象 + cfg := config.GlobalConfig.Get(configName) + if cfg == nil { + return false // 未注册的配置 + } + + // 更新配置 + configMap := map[string]string{ + configKey: value, + } + config.UpdateConfigFromMap(cfg, configMap) + + return true // 已处理 +} diff --git a/model/pricing.go b/model/pricing.go index 8ae5e32be7f0f84deeed57ae64368b606cf4bda3..fc709ce4e5f1c73423f2efa5aede50d4bed596ec 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -69,7 +69,8 @@ func updatePricing() { pricing.ModelPrice = modelPrice pricing.QuotaType = 1 } else { - pricing.ModelRatio = common.GetModelRatio(model) + modelRatio, _ := common.GetModelRatio(model) + pricing.ModelRatio = modelRatio pricing.CompletionRatio = common.GetCompletionRatio(model) pricing.QuotaType = 0 } diff --git a/model/user.go b/model/user.go index 427b0625f4b2b998b1dd452e7bf969d984e21463..524f56b68ce2f61b7d3968fe3eded25576e253c3 100644 --- a/model/user.go +++ b/model/user.go @@ -320,7 +320,7 @@ func (user *User) Insert(inviterId int) error { } if inviterId != 0 { if common.QuotaForInvitee > 0 { - _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) + _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) } if common.QuotaForInviter > 0 { @@ -502,35 +502,35 @@ func IsAdmin(userId int) bool { return user.Role >= common.RoleAdminUser } -// IsUserEnabled checks user status from Redis first, falls back to DB if needed -func IsUserEnabled(id int, fromDB bool) (status bool, err error) { - defer func() { - // Update Redis cache asynchronously on successful DB read - if shouldUpdateRedis(fromDB, err) { - gopool.Go(func() { - if err := updateUserStatusCache(id, status); err != nil { - common.SysError("failed to update user status cache: " + err.Error()) - } - }) - } - }() - if !fromDB && common.RedisEnabled { - // Try Redis first - status, err := getUserStatusCache(id) - if err == nil { - return status == common.UserStatusEnabled, nil - } - // Don't return error - fall through to DB - } - fromDB = true - var user User - err = DB.Where("id = ?", id).Select("status").Find(&user).Error - if err != nil { - return false, err - } - - return user.Status == common.UserStatusEnabled, nil -} +//// IsUserEnabled checks user status from Redis first, falls back to DB if needed +//func IsUserEnabled(id int, fromDB bool) (status bool, err error) { +// defer func() { +// // Update Redis cache asynchronously on successful DB read +// if shouldUpdateRedis(fromDB, err) { +// gopool.Go(func() { +// if err := updateUserStatusCache(id, status); err != nil { +// common.SysError("failed to update user status cache: " + err.Error()) +// } +// }) +// } +// }() +// if !fromDB && common.RedisEnabled { +// // Try Redis first +// status, err := getUserStatusCache(id) +// if err == nil { +// return status == common.UserStatusEnabled, nil +// } +// // Don't return error - fall through to DB +// } +// fromDB = true +// var user User +// err = DB.Where("id = ?", id).Select("status").Find(&user).Error +// if err != nil { +// return false, err +// } +// +// return user.Status == common.UserStatusEnabled, nil +//} func ValidateAccessToken(token string) (user *User) { if token == "" { @@ -639,7 +639,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err return common.StrToMap(setting), nil } -func IncreaseUserQuota(id int, quota int) (err error) { +func IncreaseUserQuota(id int, quota int, db bool) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -649,7 +649,7 @@ func IncreaseUserQuota(id int, quota int) (err error) { common.SysError("failed to increase user quota: " + err.Error()) } }) - if common.BatchUpdateEnabled { + if !db && common.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUserQuota, id, quota) return nil } @@ -694,7 +694,7 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) { return nil } if delta > 0 { - return IncreaseUserQuota(id, delta) + return IncreaseUserQuota(id, delta, false) } else { return DecreaseUserQuota(id, -delta) } diff --git a/model/user_cache.go b/model/user_cache.go index cc08288d681e3cf77be4c071079101a8cc0726fa..bc412e77eae6c70ba8855d53927cf1218f58322a 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -3,6 +3,7 @@ package model import ( "encoding/json" "fmt" + "github.com/gin-gonic/gin" "one-api/common" "one-api/constant" "time" @@ -21,6 +22,15 @@ type UserBase struct { Setting string `json:"setting"` } +func (user *UserBase) WriteContext(c *gin.Context) { + c.Set(constant.ContextKeyUserGroup, user.Group) + c.Set(constant.ContextKeyUserQuota, user.Quota) + c.Set(constant.ContextKeyUserStatus, user.Status) + c.Set(constant.ContextKeyUserEmail, user.Email) + c.Set("username", user.Username) + c.Set(constant.ContextKeyUserSetting, user.GetSetting()) +} + func (user *UserBase) GetSetting() map[string]interface{} { if user.Setting == "" { return nil diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index cd1b5153467c2c32b94e850d0b8af0010ea24b8f..a60bc6f1898be24c14d9567a3a1f010d6f328e1f 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -130,7 +130,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) } - resp, err := doRequest(c, req, info.ToRelayInfo()) + resp, err := doRequest(c, req, info.RelayInfo) if err != nil { return nil, fmt.Errorf("do request failed: %w", err) } diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 5a3d09b976623a7d2526c0395402a337e314491c..7f2a2841bc664e284685acc5d7b067620653acb2 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -8,6 +8,7 @@ import ( "one-api/dto" "one-api/relay/channel/claude" relaycommon "one-api/relay/common" + "one-api/setting/model_setting" ) const ( @@ -38,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { + model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req) return nil } @@ -49,8 +51,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re var claudeReq *claude.ClaudeRequest var err error claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request) - - c.Set("request_model", request.Model) + if err != nil { + return nil, err + } + c.Set("request_model", claudeReq.Model) c.Set("converted_request", claudeReq) return claudeReq, err } @@ -64,7 +68,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return nil, nil } diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 8454bf9d181f9c37b8ca72390f688d8415d44f9e..66dc7cd97e74318574b123e4c4896387ad263ab1 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -9,7 +9,8 @@ var awsModelIDMap = map[string]string{ "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", - "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0", + "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0", + "claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0", } var ChannelName = "aws" diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index 505967adfa338b41b418fc55f19bac32f42ff840..e87ed6ecf94e78fa4ce5d5b27938330955965894 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -16,6 +16,7 @@ type AwsClaudeRequest struct { StopSequences []string `json:"stop_sequences,omitempty"` Tools []claude.Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` + Thinking *claude.Thinking `json:"thinking,omitempty"` } func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest { @@ -30,5 +31,6 @@ func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest { StopSequences: req.StopSequences, Tools: req.Tools, ToolChoice: req.ToolChoice, + Thinking: req.Thinking, } } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 83168382d45dd7c21d66d2b63c6980fb4e415a17..bf03e5f5fdaff5ee71540527542feb7d3dd38a43 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + "one-api/setting/model_setting" "strings" ) @@ -55,6 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel anthropicVersion = "2023-06-01" } req.Set("anthropic-version", anthropicVersion) + model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req) return nil } diff --git a/relay/channel/claude/constants.go b/relay/channel/claude/constants.go index b6a33301ab59c04483b1005d41cd312f1321db1c..d7e0c8e36efc65dbeaa2fa5f7594a46fa456858f 100644 --- a/relay/channel/claude/constants.go +++ b/relay/channel/claude/constants.go @@ -11,6 +11,8 @@ var ModelList = []string{ "claude-3-5-haiku-20241022", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022", + "claude-3-7-sonnet-20250219", + "claude-3-7-sonnet-20250219-thinking", } var ChannelName = "claude" diff --git a/relay/channel/claude/dto.go b/relay/channel/claude/dto.go index 13a1430c420a8f052c9ef6a56ed9f42823b682c5..90f06b265a1040ad268c8218824914c99377908f 100644 --- a/relay/channel/claude/dto.go +++ b/relay/channel/claude/dto.go @@ -11,6 +11,9 @@ type ClaudeMediaMessage struct { Usage *ClaudeUsage `json:"usage,omitempty"` StopReason *string `json:"stop_reason,omitempty"` PartialJson string `json:"partial_json,omitempty"` + Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` + Delta string `json:"delta,omitempty"` // tool_calls Id string `json:"id,omitempty"` Name string `json:"name,omitempty"` @@ -54,9 +57,15 @@ type ClaudeRequest struct { TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` //ClaudeMetadata `json:"metadata,omitempty"` - Stream bool `json:"stream,omitempty"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Thinking *Thinking `json:"thinking,omitempty"` +} + +type Thinking struct { + Type string `json:"type"` + BudgetTokens int `json:"budget_tokens"` } type ClaudeError struct { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 317bf6047c80afd198459ea1615e2238de3628aa..e32ee817b3129ca14e1c94457290303fdbc63b24 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -10,6 +10,7 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" + "one-api/setting/model_setting" "strings" "github.com/gin-gonic/gin" @@ -92,9 +93,31 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR Stream: textRequest.Stream, Tools: claudeTools, } + if claudeRequest.MaxTokens == 0 { - claudeRequest.MaxTokens = 4096 + claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model)) } + + if model_setting.GetClaudeSettings().ThinkingAdapterEnabled && + strings.HasSuffix(textRequest.Model, "-thinking") { + + // 因为BudgetTokens 必须大于1024 + if claudeRequest.MaxTokens < 1280 { + claudeRequest.MaxTokens = 1280 + } + + // BudgetTokens 为 max_tokens 的 80% + claudeRequest.Thinking = &Thinking{ + Type: "enabled", + BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage), + } + // TODO: 临时处理 + // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking + claudeRequest.TopP = 0 + claudeRequest.Temperature = common.GetPointer[float64](1.0) + claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") + } + if textRequest.Stop != nil { // stop maybe string/array string, convert to array string switch textRequest.Stop.(type) { @@ -273,7 +296,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) - tools := make([]dto.ToolCall, 0) + tools := make([]dto.ToolCallResponse, 0) var choice dto.ChatCompletionsStreamResponseChoice if reqMode == RequestModeCompletion { choice.Delta.SetContentString(claudeResponse.Completion) @@ -292,10 +315,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* if claudeResponse.ContentBlock != nil { //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text) if claudeResponse.ContentBlock.Type == "tool_use" { - tools = append(tools, dto.ToolCall{ + tools = append(tools, dto.ToolCallResponse{ ID: claudeResponse.ContentBlock.Id, Type: "function", - Function: dto.FunctionCall{ + Function: dto.FunctionResponse{ Name: claudeResponse.ContentBlock.Name, Arguments: "", }, @@ -308,12 +331,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* if claudeResponse.Delta != nil { choice.Index = claudeResponse.Index choice.Delta.SetContentString(claudeResponse.Delta.Text) - if claudeResponse.Delta.Type == "input_json_delta" { - tools = append(tools, dto.ToolCall{ - Function: dto.FunctionCall{ + switch claudeResponse.Delta.Type { + case "input_json_delta": + tools = append(tools, dto.ToolCallResponse{ + Function: dto.FunctionResponse{ Arguments: claudeResponse.Delta.PartialJson, }, }) + case "signature_delta": + // 加密的不处理 + signatureContent := "\n" + choice.Delta.ReasoningContent = &signatureContent + case "thinking_delta": + thinkingContent := claudeResponse.Delta.Thinking + choice.Delta.ReasoningContent = &thinkingContent } } } else if claudeResponse.Type == "message_delta" { @@ -351,7 +382,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope if len(claudeResponse.Content) > 0 { responseText = claudeResponse.Content[0].Text } - tools := make([]dto.ToolCall, 0) + tools := make([]dto.ToolCallResponse, 0) + thinkingContent := "" + if reqMode == RequestModeCompletion { content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) choice := dto.OpenAITextResponseChoice{ @@ -367,16 +400,22 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope } else { fullTextResponse.Id = claudeResponse.Id for _, message := range claudeResponse.Content { - if message.Type == "tool_use" { + switch message.Type { + case "tool_use": args, _ := json.Marshal(message.Input) - tools = append(tools, dto.ToolCall{ + tools = append(tools, dto.ToolCallResponse{ ID: message.Id, Type: "function", // compatible with other OpenAI derivative applications - Function: dto.FunctionCall{ + Function: dto.FunctionResponse{ Name: message.Name, Arguments: string(args), }, }) + case "thinking": + // 加密的不管, 只输出明文的推理过程 + thinkingContent = message.Thinking + case "text": + responseText = message.Text } } } @@ -391,6 +430,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope if len(tools) > 0 { choice.Message.SetToolCalls(tools) } + choice.Message.ReasoningContent = thinkingContent fullTextResponse.Model = claudeResponse.Model choices = append(choices, choice) fullTextResponse.Choices = choices diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index ce73c78c194563b8f1e6ce8b65004cc3f7e425eb..2626dd7d41e16290e1e3e63faed33b95a7816a80 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -9,9 +9,18 @@ import ( "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + "strings" +) + +const ( + BotTypeChatFlow = 1 // chatflow default + BotTypeAgent = 2 + BotTypeWorkFlow = 3 + BotTypeCompletion = 4 ) type Adaptor struct { + BotType int } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -25,10 +34,28 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { + if strings.HasPrefix(info.UpstreamModelName, "agent") { + a.BotType = BotTypeAgent + } else if strings.HasPrefix(info.UpstreamModelName, "workflow") { + a.BotType = BotTypeWorkFlow + } else if strings.HasPrefix(info.UpstreamModelName, "chat") { + a.BotType = BotTypeCompletion + } else { + a.BotType = BotTypeChatFlow + } } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil + switch a.BotType { + case BotTypeWorkFlow: + return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil + case BotTypeCompletion: + return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil + case BotTypeAgent: + fallthrough + default: + return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil + } } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -53,7 +80,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 32513c42ab99813de3258cbd6d5bd88ffae66f61..37c6c9df512555b14ebce28fcab6b4e088aec102 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -7,11 +7,11 @@ import ( "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" + "one-api/setting/model_setting" "strings" @@ -64,15 +64,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1beta" - version, beta := constant.GeminiModelMap[info.UpstreamModelName] - if !beta { - if info.ApiVersion != "" { - version = info.ApiVersion - } else { - version = "v1beta" - } - } + version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName) if strings.HasPrefix(info.UpstreamModelName, "imagen") { return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil diff --git a/relay/channel/gemini/constant.go b/relay/channel/gemini/constant.go index b7c1f0cf8758c55ade55ccacabfa6e61c978650d..1f402cbc4e404fb798a02c3615540c53e61c36ad 100644 --- a/relay/channel/gemini/constant.go +++ b/relay/channel/gemini/constant.go @@ -20,4 +20,12 @@ var ModelList = []string{ "imagen-3.0-generate-002", } +var SafetySettingList = []string{ + "HARM_CATEGORY_HARASSMENT", + "HARM_CATEGORY_HATE_SPEECH", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "HARM_CATEGORY_DANGEROUS_CONTENT", + "HARM_CATEGORY_CIVIC_INTEGRITY", +} + var ChannelName = "google gemini" diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 8068709e9a07ac7f12cc52fe6388ed02b583f891..d5103124edf6d8407913dc613b91349852604ac7 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -11,6 +11,7 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" + "one-api/setting/model_setting" "strings" "unicode/utf8" @@ -22,28 +23,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque geminiRequest := GeminiChatRequest{ Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), - SafetySettings: []GeminiChatSafetySettings{ - { - Category: "HARM_CATEGORY_HARASSMENT", - Threshold: common.GeminiSafetySetting, - }, - { - Category: "HARM_CATEGORY_HATE_SPEECH", - Threshold: common.GeminiSafetySetting, - }, - { - Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", - Threshold: common.GeminiSafetySetting, - }, - { - Category: "HARM_CATEGORY_DANGEROUS_CONTENT", - Threshold: common.GeminiSafetySetting, - }, - { - Category: "HARM_CATEGORY_CIVIC_INTEGRITY", - Threshold: common.GeminiSafetySetting, - }, - }, + //SafetySettings: []GeminiChatSafetySettings{}, GenerationConfig: GeminiChatGenerationConfig{ Temperature: textRequest.Temperature, TopP: textRequest.TopP, @@ -52,9 +32,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque }, } + safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList)) + for _, category := range SafetySettingList { + safetySettings = append(safetySettings, GeminiChatSafetySettings{ + Category: category, + Threshold: model_setting.GetGeminiSafetySetting(category), + }) + } + geminiRequest.SafetySettings = safetySettings + // openaiContent.FuncToToolCalls() if textRequest.Tools != nil { - functions := make([]dto.FunctionCall, 0, len(textRequest.Tools)) + functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools)) googleSearch := false codeExecution := false for _, tool := range textRequest.Tools { @@ -349,7 +338,7 @@ func unescapeMapOrSlice(data interface{}) interface{} { return data } -func getToolCall(item *GeminiPart) *dto.ToolCall { +func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse { var argsBytes []byte var err error if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok { @@ -361,10 +350,10 @@ func getToolCall(item *GeminiPart) *dto.ToolCall { if err != nil { return nil } - return &dto.ToolCall{ + return &dto.ToolCallResponse{ ID: fmt.Sprintf("call_%s", common.GetUUID()), Type: "function", - Function: dto.FunctionCall{ + Function: dto.FunctionResponse{ Arguments: string(argsBytes), Name: item.FunctionCall.FunctionName, }, @@ -379,7 +368,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), } content, _ := json.Marshal("") - is_tool_call := false + isToolCall := false for _, candidate := range response.Candidates { choice := dto.OpenAITextResponseChoice{ Index: int(candidate.Index), @@ -391,12 +380,12 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp } if len(candidate.Content.Parts) > 0 { var texts []string - var tool_calls []dto.ToolCall + var toolCalls []dto.ToolCallResponse for _, part := range candidate.Content.Parts { if part.FunctionCall != nil { choice.FinishReason = constant.FinishReasonToolCalls - if call := getToolCall(&part); call != nil { - tool_calls = append(tool_calls, *call) + if call := getResponseToolCall(&part); call != nil { + toolCalls = append(toolCalls, *call) } } else { if part.ExecutableCode != nil { @@ -411,9 +400,9 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp } } } - if len(tool_calls) > 0 { - choice.Message.SetToolCalls(tool_calls) - is_tool_call = true + if len(toolCalls) > 0 { + choice.Message.SetToolCalls(toolCalls) + isToolCall = true } choice.Message.SetStringContent(strings.Join(texts, "\n")) @@ -429,7 +418,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp choice.FinishReason = constant.FinishReasonContentFilter } } - if is_tool_call { + if isToolCall { choice.FinishReason = constant.FinishReasonToolCalls } @@ -468,7 +457,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C for _, part := range candidate.Content.Parts { if part.FunctionCall != nil { isTools = true - if call := getToolCall(&part); call != nil { + if call := getResponseToolCall(&part); call != nil { call.SetIndex(len(choice.Delta.ToolCalls)) choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call) } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 3706e3b8e938a317eac2d9ee15db108844d5eafc..77076bd48101ad2f30c38575cb67e50fed6930b9 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -61,7 +61,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { - err, usage = jinaRerankHandler(c, resp) + err, usage = JinaRerankHandler(c, resp) } else if info.RelayMode == constant.RelayModeEmbeddings { err, usage = jinaEmbeddingHandler(c, resp) } diff --git a/relay/channel/jina/relay-jina.go b/relay/channel/jina/relay-jina.go index 6c339aee33fdca9c5b0962fecd4a533a16ce5a74..aee7b13175b910de15e3ace3315b0e9d9fc54b63 100644 --- a/relay/channel/jina/relay-jina.go +++ b/relay/channel/jina/relay-jina.go @@ -9,7 +9,7 @@ import ( "one-api/service" ) -func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go index a954c607a69104c4ae35cde8aaf2fb8c66a96460..15c64cdcd85e1f5055de9d33ff5f2ec8d602c0a3 100644 --- a/relay/channel/ollama/dto.go +++ b/relay/channel/ollama/dto.go @@ -3,21 +3,22 @@ package ollama import "one-api/dto" type OllamaRequest struct { - Model string `json:"model,omitempty"` - Messages []dto.Message `json:"messages,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - Seed float64 `json:"seed,omitempty"` - Topp float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Stop any `json:"stop,omitempty"` - Tools []dto.ToolCall `json:"tools,omitempty"` - ResponseFormat any `json:"response_format,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - Suffix any `json:"suffix,omitempty"` - StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"` - Prompt any `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + Messages []dto.Message `json:"messages,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + Seed float64 `json:"seed,omitempty"` + Topp float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Stop any `json:"stop,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` + Tools []dto.ToolCallRequest `json:"tools,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + Suffix any `json:"suffix,omitempty"` + StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"` + Prompt any `json:"prompt,omitempty"` } type Options struct { diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 8b53fbfb56ae899c1c40644ccfae49f46c6b0089..89e9c214c872802e947da801dce4210f276c5392 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -58,6 +58,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err TopK: request.TopK, Stop: Stop, Tools: request.Tools, + MaxTokens: request.MaxTokens, ResponseFormat: request.ResponseFormat, FrequencyPenalty: request.FrequencyPenalty, PresencePenalty: request.PresencePenalty, diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index f927fa74f45f69aed47d5774481230604f29c9fc..6dbbb17e21657a8198e5b79d8ba8caffd260bfdf 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -14,6 +14,7 @@ import ( "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/ai360" + "one-api/relay/channel/jina" "one-api/relay/channel/lingyiwanwu" "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" @@ -146,7 +147,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { - return nil, errors.New("not implemented") + return request, nil } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { @@ -228,6 +229,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) case constant.RelayModeImagesGenerations: err, usage = OpenaiTTSHandler(c, resp, info) + case constant.RelayModeRerank: + err, usage = jina.JinaRerankHandler(c, resp) default: if info.IsStream { err, usage = OaiStreamHandler(c, resp, info) diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go index d55242edc579b82228f24948f37e34e71edf2fe1..c703e414b5e1d866e5ab87ab4e6ba599aa8664ab 100644 --- a/relay/channel/openai/constant.go +++ b/relay/channel/openai/constant.go @@ -11,6 +11,7 @@ var ModelList = []string{ "chatgpt-4o-latest", "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "gpt-4o-mini", "gpt-4o-mini-2024-07-18", + "gpt-4.5-preview", "gpt-4.5-preview-2025-02-27", "o1-preview", "o1-preview-2024-09-12", "o1-mini", "o1-mini-2024-09-12", "o3-mini", "o3-mini-2025-01-31", diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 33cdea48639f8625e25d14662ae862900636e9eb..a5bd0e33d10751e3a6b5ecd5c1ea4e38b1c6f8e5 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -5,10 +5,6 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/bytedance/gopkg/util/gopool" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" - "github.com/pkg/errors" "io" "math" "mime/multipart" @@ -23,21 +19,66 @@ import ( "strings" "sync" "time" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/pkg/errors" ) -func sendStreamData(c *gin.Context, data string, forceFormat bool) error { +func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { if data == "" { return nil } - if forceFormat { - var lastStreamResponse dto.ChatCompletionsStreamResponse - if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil { - return err + if !forceFormat && !thinkToContent { + return service.StringData(c, data) + } + + var lastStreamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil { + return err + } + + if !thinkToContent { + return service.ObjectData(c, lastStreamResponse) + } + + // Handle think to content conversion + if info.IsFirstResponse { + response := lastStreamResponse.Copy() + for i := range response.Choices { + response.Choices[i].Delta.SetContentString("\n") + response.Choices[i].Delta.SetReasoningContent("") } + service.ObjectData(c, response) + } + + if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 { return service.ObjectData(c, lastStreamResponse) } - return service.StringData(c, data) + + // Process each choice + for i, choice := range lastStreamResponse.Choices { + // Handle transition from thinking to content + if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse { + response := lastStreamResponse.Copy() + for j := range response.Choices { + response.Choices[j].Delta.SetContentString("\n") + response.Choices[j].Delta.SetReasoningContent("") + } + info.SendLastReasoningResponse = true + service.ObjectData(c, response) + } + + // Convert reasoning content to regular content + if len(choice.Delta.GetReasoningContent()) > 0 { + lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent()) + lastStreamResponse.Choices[i].Delta.SetReasoningContent("") + } + } + + return service.ObjectData(c, lastStreamResponse) } func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { @@ -56,11 +97,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel var usage = &dto.Usage{} var streamItems []string // store stream items var forceFormat bool + var thinkToContent bool - if info.ChannelType == common.ChannelTypeCustom { - if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok { - forceFormat = forceFmt - } + if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { + forceFormat = forceFmt + } + + if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok { + thinkToContent = think2Content } toolCount := 0 @@ -84,7 +128,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel ) gopool.Go(func() { for scanner.Scan() { - info.SetFirstResponseTime() + //info.SetFirstResponseTime() ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) data := scanner.Text() if common.DebugEnabled { @@ -101,10 +145,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel data = strings.TrimSpace(data) if !strings.HasPrefix(data, "[DONE]") { if lastStreamData != "" { - err := sendStreamData(c, lastStreamData, forceFormat) + err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) if err != nil { common.LogError(c, "streaming error: "+err.Error()) } + info.SetFirstResponseTime() } lastStreamData = data streamItems = append(streamItems, data) @@ -144,7 +189,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } } if shouldSendLastResp { - sendStreamData(c, lastStreamData, forceFormat) + sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) } // 计算token diff --git a/relay/channel/openrouter/adaptor.go b/relay/channel/openrouter/adaptor.go new file mode 100644 index 0000000000000000000000000000000000000000..83afb6af7bade7e23dd5c697d4016ac294d1fc5e --- /dev/null +++ b/relay/channel/openrouter/adaptor.go @@ -0,0 +1,74 @@ +package openrouter + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/channel/openai" + relaycommon "one-api/relay/common" +) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api") + req.Set("X-Title", "New API") + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = openai.OaiStreamHandler(c, resp, info) + } else { + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/openrouter/constant.go b/relay/channel/openrouter/constant.go new file mode 100644 index 0000000000000000000000000000000000000000..0372eb9a2c8e876f6b8937cef1668fe3ace8895d --- /dev/null +++ b/relay/channel/openrouter/constant.go @@ -0,0 +1,5 @@ +package openrouter + +var ModelList = []string{} + +var ChannelName = "openrouter" diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 07659c20ab62e0e3bf00b3231a1702f2d932fe77..0d8ccef1fcd2c48d5b8092176dd251dcce72de8a 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -28,6 +28,7 @@ var claudeModelMap = map[string]string{ "claude-3-opus-20240229": "claude-3-opus@20240229", "claude-3-haiku-20240307": "claude-3-haiku@20240307", "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620", + "claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219", } const anthropicVersion = "vertex-2023-10-16" @@ -132,7 +133,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil { return nil, errors.New("failed to copy claude request") } - c.Set("request_model", request.Model) + c.Set("request_model", claudeReq.Model) return vertexClaudeReq, nil } else if a.RequestMode == RequestModeGemini { geminiRequest, err := gemini.CovertGemini2OpenAI(*request) @@ -156,7 +157,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 1f4a3a42b6fd99f6344a9a11a86516c239d824e0..022ab62800e7b03aeb8d9485f5bc35c5ca15345d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -13,23 +13,24 @@ import ( ) type RelayInfo struct { - ChannelType int - ChannelId int - TokenId int - TokenKey string - UserId int - Group string - TokenUnlimited bool - StartTime time.Time - FirstResponseTime time.Time - setFirstResponse bool - ApiType int - IsStream bool - IsPlayground bool - UsePrice bool - RelayMode int - UpstreamModelName string - OriginModelName string + ChannelType int + ChannelId int + TokenId int + TokenKey string + UserId int + Group string + TokenUnlimited bool + StartTime time.Time + FirstResponseTime time.Time + IsFirstResponse bool + SendLastReasoningResponse bool + ApiType int + IsStream bool + IsPlayground bool + UsePrice bool + RelayMode int + UpstreamModelName string + OriginModelName string //RecodeModelName string RequestURLPath string ApiVersion string @@ -49,6 +50,9 @@ type RelayInfo struct { AudioUsage bool ReasoningEffort string ChannelSetting map[string]interface{} + UserSetting map[string]interface{} + UserEmail string + UserQuota int } // 定义支持流式选项的通道类型 @@ -88,6 +92,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { apiType, _ := relayconstant.ChannelType2APIType(channelType) info := &RelayInfo{ + UserQuota: c.GetInt(constant.ContextKeyUserQuota), + UserSetting: c.GetStringMap(constant.ContextKeyUserSetting), + UserEmail: c.GetString(constant.ContextKeyUserEmail), + IsFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), BaseUrl: c.GetString("base_url"), RequestURLPath: c.Request.URL.String(), @@ -139,26 +147,14 @@ func (info *RelayInfo) SetIsStream(isStream bool) { } func (info *RelayInfo) SetFirstResponseTime() { - if !info.setFirstResponse { + if info.IsFirstResponse { info.FirstResponseTime = time.Now() - info.setFirstResponse = true + info.IsFirstResponse = false } } type TaskRelayInfo struct { - ChannelType int - ChannelId int - TokenId int - UserId int - Group string - StartTime time.Time - ApiType int - RelayMode int - UpstreamModelName string - RequestURLPath string - ApiKey string - BaseUrl string - + *RelayInfo Action string OriginTaskID string @@ -166,48 +162,8 @@ type TaskRelayInfo struct { } func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { - channelType := c.GetInt("channel_type") - channelId := c.GetInt("channel_id") - - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - group := c.GetString("group") - startTime := time.Now() - - apiType, _ := relayconstant.ChannelType2APIType(channelType) - info := &TaskRelayInfo{ - RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), - BaseUrl: c.GetString("base_url"), - RequestURLPath: c.Request.URL.String(), - ChannelType: channelType, - ChannelId: channelId, - TokenId: tokenId, - UserId: userId, - Group: group, - StartTime: startTime, - ApiType: apiType, - ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), - } - if info.BaseUrl == "" { - info.BaseUrl = common.ChannelBaseURLs[channelType] + RelayInfo: GenRelayInfo(c), } return info } - -func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo { - return &RelayInfo{ - ChannelType: info.ChannelType, - ChannelId: info.ChannelId, - TokenId: info.TokenId, - UserId: info.UserId, - Group: info.Group, - StartTime: info.StartTime, - ApiType: info.ApiType, - RelayMode: info.RelayMode, - UpstreamModelName: info.UpstreamModelName, - RequestURLPath: info.RequestURLPath, - ApiKey: info.ApiKey, - BaseUrl: info.BaseUrl, - } -} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index f7a875369a4aea07820b9d4b7bc858b78ace5f50..8ccfee03c1adf7c747c360a385c8926158cc9558 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -30,6 +30,7 @@ const ( APITypeMokaAI APITypeVolcEngine APITypeBaiduV2 + APITypeOpenRouter APITypeDummy // this one is only for count, do not add any channel after this ) @@ -86,6 +87,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeVolcEngine case common.ChannelTypeBaiduV2: apiType = APITypeBaiduV2 + case common.ChannelTypeOpenRouter: + apiType = APITypeOpenRouter } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/helper/price.go b/relay/helper/price.go index d65b86aa5a307e51db3325d7cd907b62819e78b8..1f4a5b3c5cbc26337179f52baf054c7653521f79 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -1,6 +1,7 @@ package helper import ( + "fmt" "github.com/gin-gonic/gin" "one-api/common" relaycommon "one-api/relay/common" @@ -15,7 +16,7 @@ type PriceData struct { ShouldPreConsumedQuota int } -func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData { +func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false) groupRatio := setting.GetGroupRatio(info.Group) var preConsumedQuota int @@ -25,7 +26,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens if maxTokens != 0 { preConsumedTokens = promptTokens + maxTokens } - modelRatio = common.GetModelRatio(info.OriginModelName) + var success bool + modelRatio, success = common.GetModelRatio(info.OriginModelName) + if !success { + return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName) + } ratio := modelRatio * groupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { @@ -37,5 +42,5 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens GroupRatio: groupRatio, UsePrice: usePrice, ShouldPreConsumedQuota: preConsumedQuota, - } + }, nil } diff --git a/relay/relay-audio.go b/relay/relay-audio.go index b95c1eb693f6ce30720d2e291789b4953ae40536..6263dcb955063e7fa9d3dbaaa36ea9340af0490d 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -75,7 +75,10 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo.PromptTokens = promptTokens } - priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) + priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + } userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { diff --git a/relay/relay-image.go b/relay/relay-image.go index afa5b8e2ce160c6e69039b527a50b57bfd7375f9..90b423f9707449299c2f1676face98ac4906455b 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -86,7 +86,10 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { imageRequest.Model = relayInfo.UpstreamModelName - priceData := helper.ModelPriceHelper(c, relayInfo, 0, 0) + priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + } if !priceData.UsePrice { // modelRatio 16 = modelPrice $0.04 // per 1 modelRatio = $0.04 / 16 diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 766064cbc0db0f69dc5fc5f9488a4a70d9e476d8..57de8d100dd5014e2c53d8a27f819fcc074d7ae0 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "context" "encoding/json" "fmt" "io" @@ -192,7 +191,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { if err != nil { return &mjResp.Response } - defer func(ctx context.Context) { + defer func() { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { err := service.PostConsumeQuota(relayInfo, quota, 0, true) if err != nil { @@ -208,14 +207,14 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { other := make(map[string]interface{}) other["model_price"] = modelPrice other["group_ratio"] = groupRatio - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, + model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, group, other) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } } - }(c.Request.Context()) + }() midjResponse := &mjResp.Response midjourneyTask := &model.Midjourney{ UserId: userId, @@ -498,7 +497,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } midjResponse := &midjResponseWithStatus.Response - defer func(ctx context.Context) { + defer func() { if consumeQuota && midjResponseWithStatus.StatusCode == 200 { err := service.PostConsumeQuota(relayInfo, quota, 0, true) if err != nil { @@ -510,14 +509,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons other := make(map[string]interface{}) other["model_price"] = modelPrice other["group_ratio"] = groupRatio - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, + model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, group, other) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } } - }(c.Request.Context()) + }() // 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md //1-提交成功 diff --git a/relay/relay-text.go b/relay/relay-text.go index bfd91cdf9cce5331d67e19aa7d9874e639cd9726..eb331e256481faab6796341ee6e2e75ca8eb2cef 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -106,8 +106,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { c.Set("prompt_tokens", promptTokens) } - priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) - + priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + } // pre-consume quota 预消耗配额 preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { @@ -248,6 +250,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if userQuota-preConsumedQuota < 0 { return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden) } + relayInfo.UserQuota = userQuota if userQuota > 100*preConsumedQuota { // 用户额度充足,判断令牌额度是否充足 if !relayInfo.TokenUnlimited { @@ -267,7 +270,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo } if preConsumedQuota > 0 { - err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) + err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index c9111106c923c3c3138d403d811c204f07a7f892..00cff3168b351cbcd602d6ae30c5744c8fe93c56 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -18,6 +18,7 @@ import ( "one-api/relay/channel/mokaai" "one-api/relay/channel/ollama" "one-api/relay/channel/openai" + "one-api/relay/channel/openrouter" "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" "one-api/relay/channel/siliconflow" @@ -83,6 +84,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &volcengine.Adaptor{} case constant.APITypeBaiduV2: return &baidu_v2.Adaptor{} + case constant.APITypeOpenRouter: + return &openrouter.Adaptor{} } return nil } diff --git a/relay/relay_embedding.go b/relay/relay_embedding.go index 18739d9f92369212e47ce77848def8ee6c3049e0..e5bfa8636deb88d792fc6242debc3bc32698e244 100644 --- a/relay/relay_embedding.go +++ b/relay/relay_embedding.go @@ -57,8 +57,10 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) promptToken := getEmbeddingPromptToken(*embeddingRequest) relayInfo.PromptTokens = promptToken - priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) - + priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + } // pre-consume quota 预消耗配额 preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index 37178cad3a550c19f8aec4971d893b193e9543d0..a376138711425a92e6c0a66324b4197fae8fd1ac 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -50,8 +50,10 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith promptToken := getRerankPromptToken(*rerankRequest) relayInfo.PromptTokens = promptToken - priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) - + priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + } // pre-consume quota 预消耗配额 preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { diff --git a/relay/relay_task.go b/relay/relay_task.go index f03fcb2d912f4a1ec2a0e57e22379c45e396a31b..591ad3bb7207178df35e544b1d3aac7ff42ca7fc 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "context" "encoding/json" "errors" "fmt" @@ -109,11 +108,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { return } - defer func(ctx context.Context) { + defer func() { // release quota if relayInfo.ConsumeQuota && taskErr == nil { - err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true) + err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } @@ -123,13 +122,13 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { other := make(map[string]interface{}) other["model_price"] = modelPrice other["group_ratio"] = groupRatio - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, + model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other) model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } } - }(c.Request.Context()) + }() taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo) if taskErr != nil { diff --git a/relay/websocket.go b/relay/websocket.go index 75a7d1f0c88411007eed67bcad51370dfd3be46e..2dac60afbc0a1b98330e79a61be27f75d73b29e2 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -65,7 +65,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi //if realtimeEvent.Session.MaxResponseOutputTokens != 0 { // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens) //} - modelRatio = common.GetModelRatio(relayInfo.UpstreamModelName) + modelRatio, _ = common.GetModelRatio(relayInfo.UpstreamModelName) ratio = modelRatio * groupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { diff --git a/router/api-router.go b/router/api-router.go index bf88449a56bc57256da2c0f31edf8a44bb279888..bc3f5d9fe3e2252dd274360daf2f0aeff692b53f 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -84,6 +84,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/models", controller.ChannelListModels) + channelRoute.GET("/models_enabled", controller.EnabledListModels) channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/test", controller.TestAllChannels) channelRoute.GET("/test/:id", controller.TestChannel) diff --git a/router/relay-router.go b/router/relay-router.go index 63f5c36dad1990b1cca5e1859cdd45c5315a3047..32e0c682686b9c01ae7b56d18f38bfd98ac39518 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -24,6 +24,7 @@ func SetRelayRouter(router *gin.Engine) { } relayV1Router := router.Group("/v1") relayV1Router.Use(middleware.TokenAuth()) + relayV1Router.Use(middleware.ModelRequestRateLimit()) { // WebSocket 路由 wsRouter := relayV1Router.Group("") diff --git a/service/quota.go b/service/quota.go index 2cae93def4228f8aa93f0479c5301813015b9ba8..9ce2858d38bf6260ee7cba3634b4b608952c8b10 100644 --- a/service/quota.go +++ b/service/quota.go @@ -75,7 +75,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag audioInputTokens := usage.InputTokenDetails.AudioTokens audioOutTokens := usage.OutputTokenDetails.AudioTokens groupRatio := setting.GetGroupRatio(relayInfo.Group) - modelRatio := common.GetModelRatio(modelName) + modelRatio, _ := common.GetModelRatio(modelName) quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -276,7 +276,7 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu if quota > 0 { err = model.DecreaseUserQuota(relayInfo.UserId, quota) } else { - err = model.IncreaseUserQuota(relayInfo.UserId, -quota) + err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false) } if err != nil { return err @@ -295,20 +295,16 @@ func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQu if sendEmail { if (quota + preConsumedQuota) != 0 { - checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota) + checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota) } } return nil } -func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) { +func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) { gopool.Go(func() { - userCache, err := model.GetUserCache(userId) - if err != nil { - common.SysError("failed to get user cache: " + err.Error()) - } - userSetting := userCache.GetSetting() + userSetting := relayInfo.UserSetting threshold := common.QuotaRemindThreshold if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok { threshold = int(userCustomThreshold.(float64)) @@ -317,16 +313,16 @@ func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) { //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0 quotaTooLow := false consumeQuota := quota + preConsumedQuota - if userCache.Quota-consumeQuota < threshold { + if relayInfo.UserQuota-consumeQuota < threshold { quotaTooLow = true } if quotaTooLow { prompt := "您的额度即将用尽" topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}" - err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink})) + err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink})) if err != nil { - common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error())) + common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error())) } } }) diff --git a/service/token_counter.go b/service/token_counter.go index 319c9b112f51052e9ba32fbf3b6e6b9833b306f8..aa62bc6e7b75b83872ba453e621e8f902a5be65b 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -1,7 +1,6 @@ package service import ( - "encoding/json" "errors" "fmt" "image" @@ -78,6 +77,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { } func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { + if text == "" { + return 0 + } return len(tokenEncoder.Encode(text, nil, nil)) } @@ -167,12 +169,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA } tkm += msgTokens if request.Tools != nil { - toolsData, _ := json.Marshal(request.Tools) - var openaiTools []dto.OpenAITools - err := json.Unmarshal(toolsData, &openaiTools) - if err != nil { - return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())) - } + openaiTools := request.Tools countStr := "" for _, tool := range openaiTools { countStr = tool.Function.Name @@ -282,30 +279,25 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod tokenNum += tokensPerMessage tokenNum += getTokenNum(tokenEncoder, message.Role) if len(message.Content) > 0 { - if message.IsStringContent() { - stringContent := message.StringContent() - tokenNum += getTokenNum(tokenEncoder, stringContent) - if message.Name != nil { - tokenNum += tokensPerName - tokenNum += getTokenNum(tokenEncoder, *message.Name) - } - } else { - arrayContent := message.ParseContent() - for _, m := range arrayContent { - if m.Type == dto.ContentTypeImageURL { - imageUrl := m.ImageUrl.(dto.MessageImageUrl) - imageTokenNum, err := getImageToken(info, &imageUrl, model, stream) - if err != nil { - return 0, err - } - tokenNum += imageTokenNum - log.Printf("image token num: %d", imageTokenNum) - } else if m.Type == dto.ContentTypeInputAudio { - // TODO: 音频token数量计算 - tokenNum += 100 - } else { - tokenNum += getTokenNum(tokenEncoder, m.Text) + if message.Name != nil { + tokenNum += tokensPerName + tokenNum += getTokenNum(tokenEncoder, *message.Name) + } + arrayContent := message.ParseContent() + for _, m := range arrayContent { + if m.Type == dto.ContentTypeImageURL { + imageUrl := m.ImageUrl.(dto.MessageImageUrl) + imageTokenNum, err := getImageToken(info, &imageUrl, model, stream) + if err != nil { + return 0, err } + tokenNum += imageTokenNum + log.Printf("image token num: %d", imageTokenNum) + } else if m.Type == dto.ContentTypeInputAudio { + // TODO: 音频token数量计算 + tokenNum += 100 + } else { + tokenNum += getTokenNum(tokenEncoder, m.Text) } } } diff --git a/service/user_notify.go b/service/user_notify.go index e01b7aa9c04f3b990e9122470db1a0c596ca6919..db291f0fe73d2fc520529f1b8238a28fcf0b65cb 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -11,47 +11,45 @@ import ( func NotifyRootUser(t string, subject string, content string) { user := model.GetRootUser().ToBaseUser() - _ = NotifyUser(user, dto.NewNotify(t, subject, content, nil)) + _ = NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil)) } -func NotifyUser(user *model.UserBase, data dto.Notify) error { - userSetting := user.GetSetting() +func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error { notifyType, ok := userSetting[constant.UserSettingNotifyType] if !ok { notifyType = constant.NotifyTypeEmail } // Check notification limit - canSend, err := CheckNotificationLimit(user.Id, data.Type) + canSend, err := CheckNotificationLimit(userId, data.Type) if err != nil { common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) return err } if !canSend { - return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType) + return fmt.Errorf("notification limit exceeded for user %d with type %s", userId, notifyType) } switch notifyType { case constant.NotifyTypeEmail: - userEmail := user.Email // check setting email if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok { userEmail = settingEmail.(string) } if userEmail == "" { - common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id)) + common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId)) return nil } return sendEmailNotify(userEmail, data) case constant.NotifyTypeWebhook: webhookURL, ok := userSetting[constant.UserSettingWebhookUrl] if !ok { - common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id)) + common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) return nil } webhookURLStr, ok := webhookURL.(string) if !ok { - common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id)) + common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId)) return nil } diff --git a/setting/config/config.go b/setting/config/config.go new file mode 100644 index 0000000000000000000000000000000000000000..3af51b146482b484638c1aa1b85c029e0eb700eb --- /dev/null +++ b/setting/config/config.go @@ -0,0 +1,259 @@ +package config + +import ( + "encoding/json" + "one-api/common" + "reflect" + "strconv" + "strings" + "sync" +) + +// ConfigManager 统一管理所有配置 +type ConfigManager struct { + configs map[string]interface{} + mutex sync.RWMutex +} + +var GlobalConfig = NewConfigManager() + +func NewConfigManager() *ConfigManager { + return &ConfigManager{ + configs: make(map[string]interface{}), + } +} + +// Register 注册一个配置模块 +func (cm *ConfigManager) Register(name string, config interface{}) { + cm.mutex.Lock() + defer cm.mutex.Unlock() + cm.configs[name] = config +} + +// Get 获取指定配置模块 +func (cm *ConfigManager) Get(name string) interface{} { + cm.mutex.RLock() + defer cm.mutex.RUnlock() + return cm.configs[name] +} + +// LoadFromDB 从数据库加载配置 +func (cm *ConfigManager) LoadFromDB(options map[string]string) error { + cm.mutex.Lock() + defer cm.mutex.Unlock() + + for name, config := range cm.configs { + prefix := name + "." + configMap := make(map[string]string) + + // 收集属于此配置的所有选项 + for key, value := range options { + if strings.HasPrefix(key, prefix) { + configKey := strings.TrimPrefix(key, prefix) + configMap[configKey] = value + } + } + + // 如果找到配置项,则更新配置 + if len(configMap) > 0 { + if err := updateConfigFromMap(config, configMap); err != nil { + common.SysError("failed to update config " + name + ": " + err.Error()) + continue + } + } + } + + return nil +} + +// SaveToDB 将配置保存到数据库 +func (cm *ConfigManager) SaveToDB(updateFunc func(key, value string) error) error { + cm.mutex.RLock() + defer cm.mutex.RUnlock() + + for name, config := range cm.configs { + configMap, err := configToMap(config) + if err != nil { + return err + } + + for key, value := range configMap { + dbKey := name + "." + key + if err := updateFunc(dbKey, value); err != nil { + return err + } + } + } + + return nil +} + +// 辅助函数:将配置对象转换为map +func configToMap(config interface{}) (map[string]string, error) { + result := make(map[string]string) + + val := reflect.ValueOf(config) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + if val.Kind() != reflect.Struct { + return nil, nil + } + + typ := val.Type() + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + fieldType := typ.Field(i) + + // 跳过未导出字段 + if !fieldType.IsExported() { + continue + } + + // 获取json标签作为键名 + key := fieldType.Tag.Get("json") + if key == "" || key == "-" { + key = fieldType.Name + } + + // 处理不同类型的字段 + var strValue string + switch field.Kind() { + case reflect.String: + strValue = field.String() + case reflect.Bool: + strValue = strconv.FormatBool(field.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + strValue = strconv.FormatInt(field.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + strValue = strconv.FormatUint(field.Uint(), 10) + case reflect.Float32, reflect.Float64: + strValue = strconv.FormatFloat(field.Float(), 'f', -1, 64) + case reflect.Map, reflect.Slice, reflect.Struct: + // 复杂类型使用JSON序列化 + bytes, err := json.Marshal(field.Interface()) + if err != nil { + return nil, err + } + strValue = string(bytes) + default: + // 跳过不支持的类型 + continue + } + + result[key] = strValue + } + + return result, nil +} + +// 辅助函数:从map更新配置对象 +func updateConfigFromMap(config interface{}, configMap map[string]string) error { + val := reflect.ValueOf(config) + if val.Kind() != reflect.Ptr { + return nil + } + val = val.Elem() + + if val.Kind() != reflect.Struct { + return nil + } + + typ := val.Type() + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + fieldType := typ.Field(i) + + // 跳过未导出字段 + if !fieldType.IsExported() { + continue + } + + // 获取json标签作为键名 + key := fieldType.Tag.Get("json") + if key == "" || key == "-" { + key = fieldType.Name + } + + // 检查map中是否有对应的值 + strValue, ok := configMap[key] + if !ok { + continue + } + + // 根据字段类型设置值 + if !field.CanSet() { + continue + } + + switch field.Kind() { + case reflect.String: + field.SetString(strValue) + case reflect.Bool: + boolValue, err := strconv.ParseBool(strValue) + if err != nil { + continue + } + field.SetBool(boolValue) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + intValue, err := strconv.ParseInt(strValue, 10, 64) + if err != nil { + continue + } + field.SetInt(intValue) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + uintValue, err := strconv.ParseUint(strValue, 10, 64) + if err != nil { + continue + } + field.SetUint(uintValue) + case reflect.Float32, reflect.Float64: + floatValue, err := strconv.ParseFloat(strValue, 64) + if err != nil { + continue + } + field.SetFloat(floatValue) + case reflect.Map, reflect.Slice, reflect.Struct: + // 复杂类型使用JSON反序列化 + err := json.Unmarshal([]byte(strValue), field.Addr().Interface()) + if err != nil { + continue + } + } + } + + return nil +} + +// ConfigToMap 将配置对象转换为map(导出函数) +func ConfigToMap(config interface{}) (map[string]string, error) { + return configToMap(config) +} + +// UpdateConfigFromMap 从map更新配置对象(导出函数) +func UpdateConfigFromMap(config interface{}, configMap map[string]string) error { + return updateConfigFromMap(config, configMap) +} + +// ExportAllConfigs 导出所有已注册的配置为扁平结构 +func (cm *ConfigManager) ExportAllConfigs() map[string]string { + cm.mutex.RLock() + defer cm.mutex.RUnlock() + + result := make(map[string]string) + + for name, cfg := range cm.configs { + configMap, err := ConfigToMap(cfg) + if err != nil { + continue + } + + // 使用 "模块名.配置项" 的格式添加到结果中 + for key, value := range configMap { + result[name+"."+key] = value + } + } + + return result +} diff --git a/setting/model_setting/claude.go b/setting/model_setting/claude.go new file mode 100644 index 0000000000000000000000000000000000000000..0498318217d61899c73094506a70bf7cb73b5f1d --- /dev/null +++ b/setting/model_setting/claude.go @@ -0,0 +1,65 @@ +package model_setting + +import ( + "net/http" + "one-api/setting/config" +) + +//var claudeHeadersSettings = map[string][]string{} +// +//var ClaudeThinkingAdapterEnabled = true +//var ClaudeThinkingAdapterMaxTokens = 8192 +//var ClaudeThinkingAdapterBudgetTokensPercentage = 0.8 + +// ClaudeSettings 定义Claude模型的配置 +type ClaudeSettings struct { + HeadersSettings map[string]map[string][]string `json:"model_headers_settings"` + DefaultMaxTokens map[string]int `json:"default_max_tokens"` + ThinkingAdapterEnabled bool `json:"thinking_adapter_enabled"` + ThinkingAdapterBudgetTokensPercentage float64 `json:"thinking_adapter_budget_tokens_percentage"` +} + +// 默认配置 +var defaultClaudeSettings = ClaudeSettings{ + HeadersSettings: map[string]map[string][]string{}, + ThinkingAdapterEnabled: true, + DefaultMaxTokens: map[string]int{ + "default": 8192, + }, + ThinkingAdapterBudgetTokensPercentage: 0.8, +} + +// 全局实例 +var claudeSettings = defaultClaudeSettings + +func init() { + // 注册到全局配置管理器 + config.GlobalConfig.Register("claude", &claudeSettings) +} + +// GetClaudeSettings 获取Claude配置 +func GetClaudeSettings() *ClaudeSettings { + // check default max tokens must have default key + if _, ok := claudeSettings.DefaultMaxTokens["default"]; !ok { + claudeSettings.DefaultMaxTokens["default"] = 8192 + } + return &claudeSettings +} + +func (c *ClaudeSettings) WriteHeaders(originModel string, httpHeader *http.Header) { + if headers, ok := c.HeadersSettings[originModel]; ok { + for headerKey, headerValues := range headers { + httpHeader.Del(headerKey) + for _, headerValue := range headerValues { + httpHeader.Add(headerKey, headerValue) + } + } + } +} + +func (c *ClaudeSettings) GetDefaultMaxTokens(model string) int { + if maxTokens, ok := c.DefaultMaxTokens[model]; ok { + return maxTokens + } + return c.DefaultMaxTokens["default"] +} diff --git a/setting/model_setting/gemini.go b/setting/model_setting/gemini.go new file mode 100644 index 0000000000000000000000000000000000000000..07e993bc1ff5eb87995e76d5eea77d85aa842af8 --- /dev/null +++ b/setting/model_setting/gemini.go @@ -0,0 +1,52 @@ +package model_setting + +import ( + "one-api/setting/config" +) + +// GeminiSettings 定义Gemini模型的配置 +type GeminiSettings struct { + SafetySettings map[string]string `json:"safety_settings"` + VersionSettings map[string]string `json:"version_settings"` +} + +// 默认配置 +var defaultGeminiSettings = GeminiSettings{ + SafetySettings: map[string]string{ + "default": "OFF", + "HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE", + }, + VersionSettings: map[string]string{ + "default": "v1beta", + "gemini-1.0-pro": "v1", + }, +} + +// 全局实例 +var geminiSettings = defaultGeminiSettings + +func init() { + // 注册到全局配置管理器 + config.GlobalConfig.Register("gemini", &geminiSettings) +} + +// GetGeminiSettings 获取Gemini配置 +func GetGeminiSettings() *GeminiSettings { + return &geminiSettings +} + +// GetGeminiSafetySetting 获取安全设置 +func GetGeminiSafetySetting(key string) string { + if value, ok := geminiSettings.SafetySettings[key]; ok { + return value + } + return geminiSettings.SafetySettings["default"] +} + +// GetGeminiVersionSetting 获取版本设置 +func GetGeminiVersionSetting(key string) string { + if value, ok := geminiSettings.VersionSettings[key]; ok { + return value + } + return geminiSettings.VersionSettings["default"] +} diff --git a/setting/rate_limit.go b/setting/rate_limit.go new file mode 100644 index 0000000000000000000000000000000000000000..4b2169489e2e388daadde669144250ce156603b2 --- /dev/null +++ b/setting/rate_limit.go @@ -0,0 +1,6 @@ +package setting + +var ModelRequestRateLimitEnabled = false +var ModelRequestRateLimitDurationMinutes = 1 +var ModelRequestRateLimitCount = 0 +var ModelRequestRateLimitSuccessCount = 1000 diff --git a/web/src/components/ModelSetting.js b/web/src/components/ModelSetting.js new file mode 100644 index 0000000000000000000000000000000000000000..904b40153c310b2f8578cd6fe304df3c55245c4d --- /dev/null +++ b/web/src/components/ModelSetting.js @@ -0,0 +1,83 @@ +import React, { useEffect, useState } from 'react'; +import { Card, Spin, Tabs } from '@douyinfe/semi-ui'; + + +import { API, showError, showSuccess } from '../helpers'; +import { useTranslation } from 'react-i18next'; +import SettingGeminiModel from '../pages/Setting/Model/SettingGeminiModel.js'; +import SettingClaudeModel from '../pages/Setting/Model/SettingClaudeModel.js'; + +const ModelSetting = () => { + const { t } = useTranslation(); + let [inputs, setInputs] = useState({ + 'gemini.safety_settings': '', + 'gemini.version_settings': '', + 'claude.model_headers_settings': '', + 'claude.thinking_adapter_enabled': true, + 'claude.default_max_tokens': '', + 'claude.thinking_adapter_budget_tokens_percentage': 0.8, + }); + + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if ( + item.key === 'gemini.safety_settings' || + item.key === 'gemini.version_settings' || + item.key === 'claude.model_headers_settings'|| + item.key === 'claude.default_max_tokens' + ) { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + if ( + item.key.endsWith('Enabled') + ) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; + } + }); + + setInputs(newInputs); + } else { + showError(message); + } + }; + async function onRefresh() { + try { + setLoading(true); + await getOptions(); + // showSuccess('刷新成功'); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } + } + + useEffect(() => { + onRefresh(); + }, []); + + return ( + <> + + {/* Gemini */} + + + + {/* Claude */} + + + + + + ); +}; + +export default ModelSetting; diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index caa9cc2ee25429c206ba8e39bfad51761afc51de..19f2dbe63fd87cd2607297160c0f382137fe91a4 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -16,6 +16,7 @@ import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js import { API, showError, showSuccess } from '../helpers'; import SettingsChats from '../pages/Setting/Operation/SettingsChats.js'; import { useTranslation } from 'react-i18next'; +import ModelRatioNotSetEditor from '../pages/Setting/Operation/ModelRationNotSetEditor.js'; const OperationSetting = () => { const { t } = useTranslation(); @@ -158,6 +159,9 @@ const OperationSetting = () => { + + + diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js new file mode 100644 index 0000000000000000000000000000000000000000..b6c92917a638d05de9ffeaaedea20d203b161d99 --- /dev/null +++ b/web/src/components/RateLimitSetting.js @@ -0,0 +1,80 @@ +import React, { useEffect, useState } from 'react'; +import { Card, Spin, Tabs } from '@douyinfe/semi-ui'; +import SettingsGeneral from '../pages/Setting/Operation/SettingsGeneral.js'; +import SettingsDrawing from '../pages/Setting/Operation/SettingsDrawing.js'; +import SettingsSensitiveWords from '../pages/Setting/Operation/SettingsSensitiveWords.js'; +import SettingsLog from '../pages/Setting/Operation/SettingsLog.js'; +import SettingsDataDashboard from '../pages/Setting/Operation/SettingsDataDashboard.js'; +import SettingsMonitoring from '../pages/Setting/Operation/SettingsMonitoring.js'; +import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.js'; +import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js'; +import ModelSettingsVisualEditor from '../pages/Setting/Operation/ModelSettingsVisualEditor.js'; +import GroupRatioSettings from '../pages/Setting/Operation/GroupRatioSettings.js'; +import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js'; + + +import { API, showError, showSuccess } from '../helpers'; +import SettingsChats from '../pages/Setting/Operation/SettingsChats.js'; +import { useTranslation } from 'react-i18next'; +import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimit.js'; + +const RateLimitSetting = () => { + const { t } = useTranslation(); + let [inputs, setInputs] = useState({ + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: 0, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + }); + + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if ( + item.key.endsWith('Enabled') + ) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; + } + }); + + setInputs(newInputs); + } else { + showError(message); + } + }; + async function onRefresh() { + try { + setLoading(true); + await getOptions(); + // showSuccess('刷新成功'); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } + } + + useEffect(() => { + onRefresh(); + }, []); + + return ( + <> + + {/* AI请求速率限制 */} + + + + + + ); +}; + +export default RateLimitSetting; diff --git a/web/src/components/UsersTable.js b/web/src/components/UsersTable.js index 00ad6de47ae02275d413cc6b54d52d76aab7ae2a..a1e43b455265d180eda1859e62b774a37c1cad79 100644 --- a/web/src/components/UsersTable.js +++ b/web/src/components/UsersTable.js @@ -376,7 +376,7 @@ const UsersTable = () => { if (searchKeyword === '') { await loadUsers(activePage, pageSize); } else { - await searchUsers(searchKeyword, searchGroup); + await searchUsers(activePage, pageSize, searchKeyword, searchGroup); } }; diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index a2cb3b831f50be2043633f68d639e030c3567226..aa2fb2d5eed6b9a89ec0a7aed66e2a9e7adeba18 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -856,7 +856,7 @@ "IP黑名单": "IP blacklist", "不允许的IP,一行一个": "IPs not allowed, one per line", "请选择该渠道所支持的模型": "Please select the model supported by this channel", - "次": "Second-rate", + "次": "times", "达到限速报错内容": "Error content when the speed limit is reached", "不填则使用默认报错": "If not filled in, the default error will be reported.", "Midjouney 设置 (可选)": "Midjouney settings (optional)", @@ -1074,10 +1074,9 @@ "删除所选通道": "Delete selected channels", "标签聚合模式": "Enable tag mode", "没有账户?": "No account? ", - "注意,模型部署名称必须和模型名称保持一致,因为 One API 会把请求体中的 model 参数替换为你的部署名称(模型名称中的点会被剔除)": "Note: The model deployment name must match the model name because One API will replace the model parameter in the request body with your deployment name (dots in the model name will be removed)", "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com": "Please enter AZURE_OPENAI_ENDPOINT, e.g.: https://docs-test-001.openai.azure.com", "默认 API 版本": "Default API Version", - "请输入默认 API 版本,例如:2023-06-01-preview,该配置可以被实际的请求查询参数所覆盖": "Please enter default API version, e.g.: 2023-06-01-preview. This configuration can be overridden by actual request query parameters", + "请输入默认 API 版本,例如:2024-12-01-preview": "Please enter default API version, e.g.: 2024-12-01-preview.", "请为渠道命名": "Please name the channel", "请选择可以使用该渠道的分组": "Please select groups that can use this channel", "请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit Group ratios in system settings to add new groups:", @@ -1270,5 +1269,53 @@ "设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱": "Set the email address for receiving quota warning notifications, if not set, the email address bound to the account will be used", "留空则使用账号绑定的邮箱": "If left blank, the email address bound to the account will be used", "代理站地址": "Base URL", - "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in" -} \ No newline at end of file + "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in", + "渠道额外设置": "Channel extra settings", + "模型请求速率限制": "Model request rate limit", + "启用用户模型请求速率限制(可能会影响高并发性能)": "Enable user model request rate limit (may affect high concurrency performance)", + "限制周期": "Limit period", + "用户每周期最多请求次数": "User max request times per period", + "用户每周期最多请求完成次数": "User max successful request times per period", + "包括失败请求的次数,0代表不限制": "Including failed request times, 0 means no limit", + "频率限制的周期(分钟)": "Rate limit period (minutes)", + "只包括请求成功的次数": "Only include successful request times", + "保存模型速率限制": "Save model rate limit settings", + "速率限制设置": "Rate limit settings", + "获取启用模型失败:": "Failed to get enabled models:", + "获取启用模型失败": "Failed to get enabled models", + "JSON解析错误:": "JSON parsing error:", + "保存失败:": "Save failed:", + "输入模型倍率": "Enter model ratio", + "输入补全倍率": "Enter completion ratio", + "请输入数字": "Please enter a number", + "模型名称已存在": "Model name already exists", + "添加成功": "Added successfully", + "请先选择需要批量设置的模型": "Please select models for batch setting first", + "请输入模型倍率和补全倍率": "Please enter model ratio and completion ratio", + "请输入有效的数字": "Please enter a valid number", + "请输入填充值": "Please enter a value", + "批量设置成功": "Batch setting successful", + "已为 {{count}} 个模型设置{{type}}": "Set {{type}} for {{count}} models", + "固定价格": "Fixed Price", + "模型倍率和补全倍率": "Model Ratio and Completion Ratio", + "批量设置": "Batch Setting", + "搜索模型名称": "Search model name", + "此页面仅显示未设置价格或倍率的模型,设置后将自动从列表中移除": "This page only shows models without price or ratio settings. After setting, they will be automatically removed from the list", + "没有未设置的模型": "No unconfigured models", + "定价模式": "Pricing Mode", + "固定价格(每次)": "Fixed Price (per use)", + "输入每次价格": "Enter per-use price", + "批量设置模型参数": "Batch Set Model Parameters", + "设置类型": "Setting Type", + "模型倍率值": "Model Ratio Value", + "补全倍率值": "Completion Ratio Value", + "请输入模型倍率": "Enter model ratio", + "请输入补全倍率": "Enter completion ratio", + "请输入数值": "Enter a value", + "将为选中的 ": "Will set for selected ", + " 个模型设置相同的值": " models with the same value", + "当前设置类型: ": "Current setting type: ", + "固定价格值": "Fixed Price Value", + "未设置倍率模型": "Models without ratio settings", + "模型倍率和补全倍率同时设置": "Both model ratio and completion ratio are set" +} diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 4720100aa9fec4c9f80decdb5417d9a71a588144..bfc611fe7fd52b5b9a6996dced8d11e577d17ede 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -327,9 +327,6 @@ const EditChannel = (props) => { localInputs.base_url.length - 1 ); } - if (localInputs.type === 3 && localInputs.other === '') { - localInputs.other = '2023-06-01-preview'; - } if (localInputs.type === 18 && localInputs.other === '') { localInputs.other = 'v2.1'; } @@ -494,7 +491,7 @@ const EditChannel = (props) => { { handleInputChange('other', value); }} diff --git a/web/src/pages/Setting/Model/SettingClaudeModel.js b/web/src/pages/Setting/Model/SettingClaudeModel.js new file mode 100644 index 0000000000000000000000000000000000000000..76ee8cfa351ebef3b87def51832d8fada9f1845d --- /dev/null +++ b/web/src/pages/Setting/Model/SettingClaudeModel.js @@ -0,0 +1,169 @@ +import React, { useEffect, useState, useRef } from 'react'; +import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui'; +import { + compareObjects, + API, + showError, + showSuccess, + showWarning, verifyJSON +} from '../../../helpers'; +import { useTranslation } from 'react-i18next'; +import Text from '@douyinfe/semi-ui/lib/es/typography/text'; + +const CLAUDE_HEADER = { + 'claude-3-7-sonnet-20250219-thinking': { + 'anthropic-beta': ['output-128k-2025-02-19', 'token-efficient-tools-2025-02-19'], + } +}; + +const CLAUDE_DEFAULT_MAX_TOKENS = { + 'default': 8192, + 'claude-3-7-sonnet-20250219-thinking': 8192, +} + +export default function SettingClaudeModel(props) { + const { t } = useTranslation(); + + const [loading, setLoading] = useState(false); + const [inputs, setInputs] = useState({ + 'claude.model_headers_settings': '', + 'claude.thinking_adapter_enabled': true, + 'claude.default_max_tokens': '', + 'claude.thinking_adapter_budget_tokens_percentage': 0.8, + }); + const refForm = useRef(); + const [inputsRow, setInputsRow] = useState(inputs); + + function onSubmit() { + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = String(inputs[item.key]); + + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + setLoading(true); + Promise.all(requestQueue) + .then((res) => { + if (requestQueue.length === 1) { + if (res.includes(undefined)) return; + } else if (requestQueue.length > 1) { + if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); + } + showSuccess(t('保存成功')); + props.refresh(); + }) + .catch(() => { + showError(t('保存失败,请重试')); + }) + .finally(() => { + setLoading(false); + }); + } + + useEffect(() => { + const currentInputs = {}; + for (let key in props.options) { + if (Object.keys(inputs).includes(key)) { + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + refForm.current.setValues(currentInputs); + }, [props.options]); + + return ( + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + + + verifyJSON(value), + message: t('不是合法的 JSON 字符串') + } + ]} + onChange={(value) => setInputs({ ...inputs, 'claude.model_headers_settings': value })} + /> + + + + + verifyJSON(value), + message: t('不是合法的 JSON 字符串') + } + ]} + onChange={(value) => setInputs({ ...inputs, 'claude.default_max_tokens': value })} + /> + + + + + setInputs({ ...inputs, 'claude.thinking_adapter_enabled': value })} + /> + + + + + {/*//展示MaxTokens和BudgetTokens的计算公式, 并展示实际数字*/} + + {t('Claude思考适配 BudgetTokens = MaxTokens * BudgetTokens 百分比')} + + + + + + setInputs({ ...inputs, 'claude.thinking_adapter_budget_tokens_percentage': value })} + /> + + + + + + + +
+
+ + ); +} diff --git a/web/src/pages/Setting/Model/SettingGeminiModel.js b/web/src/pages/Setting/Model/SettingGeminiModel.js new file mode 100644 index 0000000000000000000000000000000000000000..6fc08a87500fcd259f9818edad2cdae5ed1cb3d5 --- /dev/null +++ b/web/src/pages/Setting/Model/SettingGeminiModel.js @@ -0,0 +1,139 @@ +import React, { useEffect, useState, useRef } from 'react'; +import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui'; +import { + compareObjects, + API, + showError, + showSuccess, + showWarning, verifyJSON +} from '../../../helpers'; +import { useTranslation } from 'react-i18next'; + +const GEMINI_SETTING_EXAMPLE = { + 'default': 'OFF', + 'HARM_CATEGORY_CIVIC_INTEGRITY': 'BLOCK_NONE', +}; + +const GEMINI_VERSION_EXAMPLE = { + 'default': 'v1beta', +}; + + +export default function SettingGeminiModel(props) { + const { t } = useTranslation(); + + const [loading, setLoading] = useState(false); + const [inputs, setInputs] = useState({ + 'gemini.safety_settings': '', + 'gemini.version_settings': '', + }); + const refForm = useRef(); + const [inputsRow, setInputsRow] = useState(inputs); + + function onSubmit() { + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = ''; + if (typeof inputs[item.key] === 'boolean') { + value = String(inputs[item.key]); + } else { + value = inputs[item.key]; + } + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + setLoading(true); + Promise.all(requestQueue) + .then((res) => { + if (requestQueue.length === 1) { + if (res.includes(undefined)) return; + } else if (requestQueue.length > 1) { + if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); + } + showSuccess(t('保存成功')); + props.refresh(); + }) + .catch(() => { + showError(t('保存失败,请重试')); + }) + .finally(() => { + setLoading(false); + }); + } + + useEffect(() => { + const currentInputs = {}; + for (let key in props.options) { + if (Object.keys(inputs).includes(key)) { + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + refForm.current.setValues(currentInputs); + }, [props.options]); + + return ( + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + + + verifyJSON(value), + message: t('不是合法的 JSON 字符串') + } + ]} + onChange={(value) => setInputs({ ...inputs, 'gemini.safety_settings': value })} + /> + + + + + verifyJSON(value), + message: t('不是合法的 JSON 字符串') + } + ]} + onChange={(value) => setInputs({ ...inputs, 'gemini.version_settings': value })} + /> + + + + + + + +
+
+ + ); +} diff --git a/web/src/pages/Setting/Operation/ModelRationNotSetEditor.js b/web/src/pages/Setting/Operation/ModelRationNotSetEditor.js new file mode 100644 index 0000000000000000000000000000000000000000..e98e0a3c46d6535470ab34d12a4f0dc308876fca --- /dev/null +++ b/web/src/pages/Setting/Operation/ModelRationNotSetEditor.js @@ -0,0 +1,549 @@ +import React, { useEffect, useState } from 'react'; +import { Table, Button, Input, Modal, Form, Space, Typography, Radio, Notification } from '@douyinfe/semi-ui'; +import { IconDelete, IconPlus, IconSearch, IconSave, IconBolt } from '@douyinfe/semi-icons'; +import { showError, showSuccess } from '../../../helpers'; +import { API } from '../../../helpers'; +import { useTranslation } from 'react-i18next'; + +export default function ModelRatioNotSetEditor(props) { + const { t } = useTranslation(); + const [models, setModels] = useState([]); + const [visible, setVisible] = useState(false); + const [batchVisible, setBatchVisible] = useState(false); + const [currentModel, setCurrentModel] = useState(null); + const [searchText, setSearchText] = useState(''); + const [currentPage, setCurrentPage] = useState(1); + const [pageSize, setPageSize] = useState(10); + const [loading, setLoading] = useState(false); + const [enabledModels, setEnabledModels] = useState([]); + const [selectedRowKeys, setSelectedRowKeys] = useState([]); + const [batchFillType, setBatchFillType] = useState('ratio'); + const [batchFillValue, setBatchFillValue] = useState(''); + const [batchRatioValue, setBatchRatioValue] = useState(''); + const [batchCompletionRatioValue, setBatchCompletionRatioValue] = useState(''); + const { Text } = Typography; + // 定义可选的每页显示条数 + const pageSizeOptions = [10, 20, 50, 100]; + + const getAllEnabledModels = async () => { + try { + const res = await API.get('/api/channel/models_enabled'); + const { success, message, data } = res.data; + if (success) { + setEnabledModels(data); + } else { + showError(message); + } + } catch (error) { + console.error(t('获取启用模型失败:'), error); + showError(t('获取启用模型失败')); + } + } + + useEffect(() => { + // 获取所有启用的模型 + getAllEnabledModels(); + }, []); + + useEffect(() => { + try { + const modelPrice = JSON.parse(props.options.ModelPrice || '{}'); + const modelRatio = JSON.parse(props.options.ModelRatio || '{}'); + const completionRatio = JSON.parse(props.options.CompletionRatio || '{}'); + + // 找出所有未设置价格和倍率的模型 + const unsetModels = enabledModels.filter(modelName => { + const hasPrice = modelPrice[modelName] !== undefined; + const hasRatio = modelRatio[modelName] !== undefined; + + // 如果模型没有价格或者没有倍率设置,则显示 + return !hasPrice && !hasRatio; + }); + + // 创建模型数据 + const modelData = unsetModels.map(name => ({ + name, + price: modelPrice[name] || '', + ratio: modelRatio[name] || '', + completionRatio: completionRatio[name] || '' + })); + + setModels(modelData); + // 清空选择 + setSelectedRowKeys([]); + } catch (error) { + console.error(t('JSON解析错误:'), error); + } + }, [props.options, enabledModels]); + + // 首先声明分页相关的工具函数 + const getPagedData = (data, currentPage, pageSize) => { + const start = (currentPage - 1) * pageSize; + const end = start + pageSize; + return data.slice(start, end); + }; + + // 处理页面大小变化 + const handlePageSizeChange = (size) => { + setPageSize(size); + // 重新计算当前页,避免数据丢失 + const totalPages = Math.ceil(filteredModels.length / size); + if (currentPage > totalPages) { + setCurrentPage(totalPages || 1); + } + }; + + // 在 return 语句之前,先处理过滤和分页逻辑 + const filteredModels = models.filter(model => + searchText ? model.name.toLowerCase().includes(searchText.toLowerCase()) : true + ); + + // 然后基于过滤后的数据计算分页数据 + const pagedData = getPagedData(filteredModels, currentPage, pageSize); + + const SubmitData = async () => { + setLoading(true); + const output = { + ModelPrice: JSON.parse(props.options.ModelPrice || '{}'), + ModelRatio: JSON.parse(props.options.ModelRatio || '{}'), + CompletionRatio: JSON.parse(props.options.CompletionRatio || '{}') + }; + + try { + // 数据转换 - 只处理已修改的模型 + models.forEach(model => { + // 只有当用户设置了值时才更新 + if (model.price !== '') { + // 如果价格不为空,则转换为浮点数,忽略倍率参数 + output.ModelPrice[model.name] = parseFloat(model.price); + } else { + if (model.ratio !== '') output.ModelRatio[model.name] = parseFloat(model.ratio); + if (model.completionRatio !== '') output.CompletionRatio[model.name] = parseFloat(model.completionRatio); + } + }); + + // 准备API请求数组 + const finalOutput = { + ModelPrice: JSON.stringify(output.ModelPrice, null, 2), + ModelRatio: JSON.stringify(output.ModelRatio, null, 2), + CompletionRatio: JSON.stringify(output.CompletionRatio, null, 2) + }; + + const requestQueue = Object.entries(finalOutput).map(([key, value]) => { + return API.put('/api/option/', { + key, + value + }); + }); + + // 批量处理请求 + const results = await Promise.all(requestQueue); + + // 验证结果 + if (requestQueue.length === 1) { + if (results.includes(undefined)) return; + } else if (requestQueue.length > 1) { + if (results.includes(undefined)) { + return showError(t('部分保存失败,请重试')); + } + } + + // 检查每个请求的结果 + for (const res of results) { + if (!res.data.success) { + return showError(res.data.message); + } + } + + showSuccess(t('保存成功')); + props.refresh(); + // 重新获取未设置的模型 + getAllEnabledModels(); + + } catch (error) { + console.error(t('保存失败:'), error); + showError(t('保存失败,请重试')); + } finally { + setLoading(false); + } + }; + + const columns = [ + { + title: t('模型名称'), + dataIndex: 'name', + key: 'name', + }, + { + title: t('模型固定价格'), + dataIndex: 'price', + key: 'price', + render: (text, record) => ( + updateModel(record.name, 'price', value)} + /> + ) + }, + { + title: t('模型倍率'), + dataIndex: 'ratio', + key: 'ratio', + render: (text, record) => ( + updateModel(record.name, 'ratio', value)} + /> + ) + }, + { + title: t('补全倍率'), + dataIndex: 'completionRatio', + key: 'completionRatio', + render: (text, record) => ( + updateModel(record.name, 'completionRatio', value)} + /> + ) + } + ]; + + const updateModel = (name, field, value) => { + if (value !== '' && isNaN(value)) { + showError(t('请输入数字')); + return; + } + setModels(prev => + prev.map(model => + model.name === name + ? { ...model, [field]: value } + : model + ) + ); + }; + + const addModel = (values) => { + // 检查模型名称是否存在, 如果存在则拒绝添加 + if (models.some(model => model.name === values.name)) { + showError(t('模型名称已存在')); + return; + } + setModels(prev => [{ + name: values.name, + price: values.price || '', + ratio: values.ratio || '', + completionRatio: values.completionRatio || '' + }, ...prev]); + setVisible(false); + showSuccess(t('添加成功')); + }; + + // 批量填充功能 + const handleBatchFill = () => { + if (selectedRowKeys.length === 0) { + showError(t('请先选择需要批量设置的模型')); + return; + } + + if (batchFillType === 'bothRatio') { + if (batchRatioValue === '' || batchCompletionRatioValue === '') { + showError(t('请输入模型倍率和补全倍率')); + return; + } + if (isNaN(batchRatioValue) || isNaN(batchCompletionRatioValue)) { + showError(t('请输入有效的数字')); + return; + } + } else { + if (batchFillValue === '') { + showError(t('请输入填充值')); + return; + } + if (isNaN(batchFillValue)) { + showError(t('请输入有效的数字')); + return; + } + } + + // 根据选择的类型批量更新模型 + setModels(prev => + prev.map(model => { + if (selectedRowKeys.includes(model.name)) { + if (batchFillType === 'price') { + return { + ...model, + price: batchFillValue, + ratio: '', + completionRatio: '' + }; + } else if (batchFillType === 'ratio') { + return { + ...model, + price: '', + ratio: batchFillValue + }; + } else if (batchFillType === 'completionRatio') { + return { + ...model, + price: '', + completionRatio: batchFillValue + }; + } else if (batchFillType === 'bothRatio') { + return { + ...model, + price: '', + ratio: batchRatioValue, + completionRatio: batchCompletionRatioValue + }; + } + } + return model; + }) + ); + + setBatchVisible(false); + Notification.success({ + title: t('批量设置成功'), + content: t('已为 {{count}} 个模型设置{{type}}', { + count: selectedRowKeys.length, + type: batchFillType === 'price' ? t('固定价格') : + batchFillType === 'ratio' ? t('模型倍率') : + batchFillType === 'completionRatio' ? t('补全倍率') : t('模型倍率和补全倍率') + }), + duration: 3, + }); + }; + + const handleBatchTypeChange = (value) => { + console.log(t('Changing batch type to:'), value); + setBatchFillType(value); + + // 切换类型时清空对应的值 + if (value !== 'bothRatio') { + setBatchFillValue(''); + } else { + setBatchRatioValue(''); + setBatchCompletionRatioValue(''); + } + }; + + const rowSelection = { + selectedRowKeys, + onChange: (selectedKeys) => { + setSelectedRowKeys(selectedKeys); + }, + }; + + return ( + <> + + + + + + } + placeholder={t('搜索模型名称')} + value={searchText} + onChange={value => { + setSearchText(value) + setCurrentPage(1); + }} + style={{ width: 200 }} + /> + + + {t('此页面仅显示未设置价格或倍率的模型,设置后将自动从列表中移除')} + + setCurrentPage(page), + onPageSizeChange: handlePageSizeChange, + pageSizeOptions: pageSizeOptions, + formatPageText: (page) => + t('第 {{start}} - {{end}} 条,共 {{total}} 条', { + start: page.currentStart, + end: page.currentEnd, + total: filteredModels.length + }), + showTotal: true, + showSizeChanger: true + }} + empty={ +
+ {t('没有未设置的模型')} +
+ } + /> + + + {/* 添加模型弹窗 */} + setVisible(false)} + onOk={() => { + currentModel && addModel(currentModel); + }} + > +
+ setCurrentModel(prev => ({ ...prev, name: value }))} + /> + {t('定价模式')}:{currentModel?.priceMode ? t("固定价格") : t("倍率模式")}} + onChange={checked => { + setCurrentModel(prev => ({ + ...prev, + price: '', + ratio: '', + completionRatio: '', + priceMode: checked + })); + }} + /> + {currentModel?.priceMode ? ( + setCurrentModel(prev => ({ ...prev, price: value }))} + /> + ) : ( + <> + setCurrentModel(prev => ({ ...prev, ratio: value }))} + /> + setCurrentModel(prev => ({ ...prev, completionRatio: value }))} + /> + + )} + +
+ + {/* 批量设置弹窗 */} + setBatchVisible(false)} + onOk={handleBatchFill} + width={500} + > +
+ +
+ + handleBatchTypeChange('price')} + > + {t('固定价格')} + + handleBatchTypeChange('ratio')} + > + {t('模型倍率')} + + handleBatchTypeChange('completionRatio')} + > + {t('补全倍率')} + + handleBatchTypeChange('bothRatio')} + > + {t('模型倍率和补全倍率同时设置')} + + +
+
+ + {batchFillType === 'bothRatio' ? ( + <> + setBatchRatioValue(value)} + /> + setBatchCompletionRatioValue(value)} + /> + + ) : ( + setBatchFillValue(value)} + /> + )} + + + {t('将为选中的 ')} {selectedRowKeys.length} {t(' 个模型设置相同的值')} + +
+ + {t('当前设置类型: ')} { + batchFillType === 'price' ? t('固定价格') : + batchFillType === 'ratio' ? t('模型倍率') : + batchFillType === 'completionRatio' ? t('补全倍率') : t('模型倍率和补全倍率') + } + +
+ +
+ + ); +} diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js new file mode 100644 index 0000000000000000000000000000000000000000..6f4a5571d4342368c0e77a73fa3ce20ec5264106 --- /dev/null +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -0,0 +1,159 @@ +import React, { useEffect, useState, useRef } from 'react'; +import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui'; +import { + compareObjects, + API, + showError, + showSuccess, + showWarning, +} from '../../../helpers'; +import { useTranslation } from 'react-i18next'; + +export default function RequestRateLimit(props) { + const { t } = useTranslation(); + + const [loading, setLoading] = useState(false); + const [inputs, setInputs] = useState({ + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: -1, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1 + }); + const refForm = useRef(); + const [inputsRow, setInputsRow] = useState(inputs); + + function onSubmit() { + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = ''; + if (typeof inputs[item.key] === 'boolean') { + value = String(inputs[item.key]); + } else { + value = inputs[item.key]; + } + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + setLoading(true); + Promise.all(requestQueue) + .then((res) => { + if (requestQueue.length === 1) { + if (res.includes(undefined)) return; + } else if (requestQueue.length > 1) { + if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); + } + showSuccess(t('保存成功')); + props.refresh(); + }) + .catch(() => { + showError(t('保存失败,请重试')); + }) + .finally(() => { + setLoading(false); + }); + } + + useEffect(() => { + const currentInputs = {}; + for (let key in props.options) { + if (Object.keys(inputs).includes(key)) { + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + refForm.current.setValues(currentInputs); + }, [props.options]); + + return ( + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + +
+ { + setInputs({ + ...inputs, + ModelRequestRateLimitEnabled: value, + }); + }} + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitDurationMinutes: String(value), + }) + } + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitCount: String(value), + }) + } + /> + + + + setInputs({ + ...inputs, + ModelRequestRateLimitSuccessCount: String(value), + }) + } + /> + + + + + + + + + + ); +} diff --git a/web/src/pages/Setting/index.js b/web/src/pages/Setting/index.js index 385fbfebae49cb210ba1f2ac1eb6bba97163e451..17a850884c950797720af6c0570feedb0852f8cf 100644 --- a/web/src/pages/Setting/index.js +++ b/web/src/pages/Setting/index.js @@ -8,6 +8,8 @@ import { isRoot } from '../../helpers'; import OtherSetting from '../../components/OtherSetting'; import PersonalSetting from '../../components/PersonalSetting'; import OperationSetting from '../../components/OperationSetting'; +import RateLimitSetting from '../../components/RateLimitSetting.js'; +import ModelSetting from '../../components/ModelSetting.js'; const Setting = () => { const { t } = useTranslation(); @@ -28,6 +30,16 @@ const Setting = () => { content: , itemKey: 'operation', }); + panes.push({ + tab: t('速率限制设置'), + content: , + itemKey: 'ratelimit', + }); + panes.push({ + tab: t('模型相关设置'), + content: , + itemKey: 'models', + }); panes.push({ tab: t('系统设置'), content: ,