YourUsername commited on
Commit ·
d9f2cee
1
Parent(s): 72eae22
25:02:22 23:01:24
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .env.example +2 -2
- README.en.md +2 -0
- Rerank.md +1 -1
- VERSION +0 -1
- common/constants.go +1 -1
- common/go-channel.go +0 -13
- common/logger.go +11 -2
- common/model-ratio.go +17 -23
- constant/env.go +4 -1
- constant/user_setting.go +14 -0
- controller/channel-test.go +2 -7
- controller/pricing.go +1 -1
- controller/relay.go +1 -1
- controller/user.go +114 -4
- docker-compose.yml +1 -1
- dto/notify.go +25 -0
- dto/openai_request.go +84 -42
- dto/openai_response.go +11 -3
- main.go +2 -2
- middleware/distributor.go +0 -4
- model/option.go +2 -2
- model/token.go +3 -82
- model/token_cache.go +1 -1
- model/user.go +77 -13
- model/user_cache.go +128 -121
- relay/channel/cloudflare/adaptor.go +2 -1
- relay/channel/deepseek/adaptor.go +7 -1
- relay/channel/gemini/adaptor.go +99 -4
- relay/channel/gemini/constant.go +2 -0
- relay/channel/gemini/dto.go +27 -0
- relay/channel/mistral/adaptor.go +1 -4
- relay/channel/mistral/text.go +8 -12
- relay/channel/ollama/adaptor.go +2 -1
- relay/channel/ollama/dto.go +17 -14
- relay/channel/ollama/relay-ollama.go +29 -4
- relay/channel/openai/adaptor.go +1 -1
- relay/channel/openai/relay-openai.go +6 -1
- relay/channel/siliconflow/adaptor.go +8 -0
- relay/channel/zhipu_4v/relay-zhipu_v4.go +1 -2
- relay/common/relay_info.go +39 -27
- relay/helper/model_mapped.go +25 -0
- relay/helper/price.go +41 -0
- relay/relay-audio.go +13 -20
- relay/relay-image.go +16 -24
- relay/relay-mj.go +2 -2
- relay/relay-text.go +40 -61
- relay/relay_embedding.go +9 -32
- relay/relay_rerank.go +9 -32
- relay/relay_task.go +1 -1
- router/api-router.go +1 -0
.env.example
CHANGED
|
@@ -10,9 +10,9 @@
|
|
| 10 |
|
| 11 |
# 数据库相关配置
|
| 12 |
# 数据库连接字符串
|
| 13 |
-
# SQL_DSN=
|
| 14 |
# 日志数据库连接字符串
|
| 15 |
-
# LOG_SQL_DSN=
|
| 16 |
# SQLite数据库路径
|
| 17 |
# SQLITE_PATH=/path/to/sqlite.db
|
| 18 |
# 数据库最大空闲连接数
|
|
|
|
| 10 |
|
| 11 |
# 数据库相关配置
|
| 12 |
# 数据库连接字符串
|
| 13 |
+
# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
|
| 14 |
# 日志数据库连接字符串
|
| 15 |
+
# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
|
| 16 |
# SQLite数据库路径
|
| 17 |
# SQLITE_PATH=/path/to/sqlite.db
|
| 18 |
# 数据库最大空闲连接数
|
README.en.md
CHANGED
|
@@ -89,6 +89,8 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
|
|
| 89 |
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
|
| 90 |
- `CRYPTO_SECRET`: Encryption key for encrypting database content
|
| 91 |
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
|
|
|
|
|
|
|
| 92 |
|
| 93 |
## Deployment
|
| 94 |
|
|
|
|
| 89 |
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
|
| 90 |
- `CRYPTO_SECRET`: Encryption key for encrypting database content
|
| 91 |
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
|
| 92 |
+
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
|
| 93 |
+
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
|
| 94 |
|
| 95 |
## Deployment
|
| 96 |
|
Rerank.md
CHANGED
|
@@ -13,7 +13,7 @@ Request:
|
|
| 13 |
|
| 14 |
```json
|
| 15 |
{
|
| 16 |
-
"model": "
|
| 17 |
"query": "What is the capital of the United States?",
|
| 18 |
"top_n": 3,
|
| 19 |
"documents": [
|
|
|
|
| 13 |
|
| 14 |
```json
|
| 15 |
{
|
| 16 |
+
"model": "jina-reranker-v2-base-multilingual",
|
| 17 |
"query": "What is the capital of the United States?",
|
| 18 |
"top_n": 3,
|
| 19 |
"documents": [
|
VERSION
CHANGED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
v0.4.7.2.1
|
|
|
|
|
|
common/constants.go
CHANGED
|
@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
|
|
| 101 |
|
| 102 |
var RetryTimes = 0
|
| 103 |
|
| 104 |
-
var RootUserEmail = ""
|
| 105 |
|
| 106 |
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
| 107 |
|
|
|
|
| 101 |
|
| 102 |
var RetryTimes = 0
|
| 103 |
|
| 104 |
+
//var RootUserEmail = ""
|
| 105 |
|
| 106 |
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
| 107 |
|
common/go-channel.go
CHANGED
|
@@ -1,22 +1,9 @@
|
|
| 1 |
package common
|
| 2 |
|
| 3 |
import (
|
| 4 |
-
"fmt"
|
| 5 |
-
"runtime/debug"
|
| 6 |
"time"
|
| 7 |
)
|
| 8 |
|
| 9 |
-
func SafeGoroutine(f func()) {
|
| 10 |
-
go func() {
|
| 11 |
-
defer func() {
|
| 12 |
-
if r := recover(); r != nil {
|
| 13 |
-
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
|
| 14 |
-
}
|
| 15 |
-
}()
|
| 16 |
-
f()
|
| 17 |
-
}()
|
| 18 |
-
}
|
| 19 |
-
|
| 20 |
func SafeSendBool(ch chan bool, value bool) (closed bool) {
|
| 21 |
defer func() {
|
| 22 |
// Recover from panic if one occured. A panic would mean the channel was closed.
|
|
|
|
| 1 |
package common
|
| 2 |
|
| 3 |
import (
|
|
|
|
|
|
|
| 4 |
"time"
|
| 5 |
)
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
func SafeSendBool(ch chan bool, value bool) (closed bool) {
|
| 8 |
defer func() {
|
| 9 |
// Recover from panic if one occured. A panic would mean the channel was closed.
|
common/logger.go
CHANGED
|
@@ -4,6 +4,7 @@ import (
|
|
| 4 |
"context"
|
| 5 |
"encoding/json"
|
| 6 |
"fmt"
|
|
|
|
| 7 |
"github.com/gin-gonic/gin"
|
| 8 |
"io"
|
| 9 |
"log"
|
|
@@ -80,9 +81,9 @@ func logHelper(ctx context.Context, level string, msg string) {
|
|
| 80 |
if logCount > maxLogCount && !setupLogWorking {
|
| 81 |
logCount = 0
|
| 82 |
setupLogWorking = true
|
| 83 |
-
|
| 84 |
SetupLogger()
|
| 85 |
-
}
|
| 86 |
}
|
| 87 |
}
|
| 88 |
|
|
@@ -100,6 +101,14 @@ func LogQuota(quota int) string {
|
|
| 100 |
}
|
| 101 |
}
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
// LogJson 仅供测试使用 only for test
|
| 104 |
func LogJson(ctx context.Context, msg string, obj any) {
|
| 105 |
jsonStr, err := json.Marshal(obj)
|
|
|
|
| 4 |
"context"
|
| 5 |
"encoding/json"
|
| 6 |
"fmt"
|
| 7 |
+
"github.com/bytedance/gopkg/util/gopool"
|
| 8 |
"github.com/gin-gonic/gin"
|
| 9 |
"io"
|
| 10 |
"log"
|
|
|
|
| 81 |
if logCount > maxLogCount && !setupLogWorking {
|
| 82 |
logCount = 0
|
| 83 |
setupLogWorking = true
|
| 84 |
+
gopool.Go(func() {
|
| 85 |
SetupLogger()
|
| 86 |
+
})
|
| 87 |
}
|
| 88 |
}
|
| 89 |
|
|
|
|
| 101 |
}
|
| 102 |
}
|
| 103 |
|
| 104 |
+
func FormatQuota(quota int) string {
|
| 105 |
+
if DisplayInCurrencyEnabled {
|
| 106 |
+
return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
|
| 107 |
+
} else {
|
| 108 |
+
return fmt.Sprintf("%d", quota)
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
// LogJson 仅供测试使用 only for test
|
| 113 |
func LogJson(ctx context.Context, msg string, obj any) {
|
| 114 |
jsonStr, err := json.Marshal(obj)
|
common/model-ratio.go
CHANGED
|
@@ -233,7 +233,11 @@ var (
|
|
| 233 |
modelRatioMapMutex = sync.RWMutex{}
|
| 234 |
)
|
| 235 |
|
| 236 |
-
var
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
var defaultCompletionRatio = map[string]float64{
|
| 238 |
"gpt-4-gizmo-*": 2,
|
| 239 |
"gpt-4o-gizmo-*": 3,
|
|
@@ -334,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 {
|
|
| 334 |
return defaultModelRatio
|
| 335 |
}
|
| 336 |
|
| 337 |
-
func
|
|
|
|
|
|
|
| 338 |
if CompletionRatio == nil {
|
| 339 |
CompletionRatio = defaultCompletionRatio
|
| 340 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
jsonBytes, err := json.Marshal(CompletionRatio)
|
| 342 |
if err != nil {
|
| 343 |
SysError("error marshalling completion ratio: " + err.Error())
|
|
@@ -346,11 +357,15 @@ func CompletionRatio2JSONString() string {
|
|
| 346 |
}
|
| 347 |
|
| 348 |
func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
|
|
|
|
|
|
| 349 |
CompletionRatio = make(map[string]float64)
|
| 350 |
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
|
| 351 |
}
|
| 352 |
|
| 353 |
func GetCompletionRatio(name string) float64 {
|
|
|
|
|
|
|
| 354 |
if strings.Contains(name, "/") {
|
| 355 |
if ratio, ok := CompletionRatio[name]; ok {
|
| 356 |
return ratio
|
|
@@ -476,24 +491,3 @@ func GetAudioCompletionRatio(name string) float64 {
|
|
| 476 |
}
|
| 477 |
return 2
|
| 478 |
}
|
| 479 |
-
|
| 480 |
-
//func GetAudioPricePerMinute(name string) float64 {
|
| 481 |
-
// if strings.HasPrefix(name, "gpt-4o-realtime") {
|
| 482 |
-
// return 0.06
|
| 483 |
-
// }
|
| 484 |
-
// return 0.06
|
| 485 |
-
//}
|
| 486 |
-
//
|
| 487 |
-
//func GetAudioCompletionPricePerMinute(name string) float64 {
|
| 488 |
-
// if strings.HasPrefix(name, "gpt-4o-realtime") {
|
| 489 |
-
// return 0.24
|
| 490 |
-
// }
|
| 491 |
-
// return 0.24
|
| 492 |
-
//}
|
| 493 |
-
|
| 494 |
-
func GetCompletionRatioMap() map[string]float64 {
|
| 495 |
-
if CompletionRatio == nil {
|
| 496 |
-
CompletionRatio = defaultCompletionRatio
|
| 497 |
-
}
|
| 498 |
-
return CompletionRatio
|
| 499 |
-
}
|
|
|
|
| 233 |
modelRatioMapMutex = sync.RWMutex{}
|
| 234 |
)
|
| 235 |
|
| 236 |
+
var (
|
| 237 |
+
CompletionRatio map[string]float64 = nil
|
| 238 |
+
CompletionRatioMutex = sync.RWMutex{}
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
var defaultCompletionRatio = map[string]float64{
|
| 242 |
"gpt-4-gizmo-*": 2,
|
| 243 |
"gpt-4o-gizmo-*": 3,
|
|
|
|
| 338 |
return defaultModelRatio
|
| 339 |
}
|
| 340 |
|
| 341 |
+
func GetCompletionRatioMap() map[string]float64 {
|
| 342 |
+
CompletionRatioMutex.Lock()
|
| 343 |
+
defer CompletionRatioMutex.Unlock()
|
| 344 |
if CompletionRatio == nil {
|
| 345 |
CompletionRatio = defaultCompletionRatio
|
| 346 |
}
|
| 347 |
+
return CompletionRatio
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
func CompletionRatio2JSONString() string {
|
| 351 |
+
GetCompletionRatioMap()
|
| 352 |
jsonBytes, err := json.Marshal(CompletionRatio)
|
| 353 |
if err != nil {
|
| 354 |
SysError("error marshalling completion ratio: " + err.Error())
|
|
|
|
| 357 |
}
|
| 358 |
|
| 359 |
func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
| 360 |
+
CompletionRatioMutex.Lock()
|
| 361 |
+
defer CompletionRatioMutex.Unlock()
|
| 362 |
CompletionRatio = make(map[string]float64)
|
| 363 |
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
|
| 364 |
}
|
| 365 |
|
| 366 |
func GetCompletionRatio(name string) float64 {
|
| 367 |
+
GetCompletionRatioMap()
|
| 368 |
+
|
| 369 |
if strings.Contains(name, "/") {
|
| 370 |
if ratio, ok := CompletionRatio[name]; ok {
|
| 371 |
return ratio
|
|
|
|
| 491 |
}
|
| 492 |
return 2
|
| 493 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
constant/env.go
CHANGED
|
@@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{
|
|
| 29 |
|
| 30 |
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
func InitEnv() {
|
| 33 |
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
| 34 |
if modelVersionMapStr == "" {
|
|
@@ -44,5 +47,5 @@ func InitEnv() {
|
|
| 44 |
}
|
| 45 |
}
|
| 46 |
|
| 47 |
-
// 是否生成初始令牌,默认关闭。
|
| 48 |
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
|
|
|
| 29 |
|
| 30 |
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
| 31 |
|
| 32 |
+
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
| 33 |
+
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
| 34 |
+
|
| 35 |
func InitEnv() {
|
| 36 |
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
| 37 |
if modelVersionMapStr == "" {
|
|
|
|
| 47 |
}
|
| 48 |
}
|
| 49 |
|
| 50 |
+
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
| 51 |
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
constant/user_setting.go
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package constant
|
| 2 |
+
|
| 3 |
+
var (
|
| 4 |
+
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
|
| 5 |
+
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
|
| 6 |
+
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
|
| 7 |
+
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
|
| 8 |
+
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
var (
|
| 12 |
+
NotifyTypeEmail = "email" // Email 邮件
|
| 13 |
+
NotifyTypeWebhook = "webhook" // Webhook
|
| 14 |
+
)
|
controller/channel-test.go
CHANGED
|
@@ -238,9 +238,7 @@ var testAllChannelsLock sync.Mutex
|
|
| 238 |
var testAllChannelsRunning bool = false
|
| 239 |
|
| 240 |
func testAllChannels(notify bool) error {
|
| 241 |
-
|
| 242 |
-
common.RootUserEmail = model.GetRootUserEmail()
|
| 243 |
-
}
|
| 244 |
testAllChannelsLock.Lock()
|
| 245 |
if testAllChannelsRunning {
|
| 246 |
testAllChannelsLock.Unlock()
|
|
@@ -295,10 +293,7 @@ func testAllChannels(notify bool) error {
|
|
| 295 |
testAllChannelsRunning = false
|
| 296 |
testAllChannelsLock.Unlock()
|
| 297 |
if notify {
|
| 298 |
-
|
| 299 |
-
if err != nil {
|
| 300 |
-
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
| 301 |
-
}
|
| 302 |
}
|
| 303 |
})
|
| 304 |
return nil
|
|
|
|
| 238 |
var testAllChannelsRunning bool = false
|
| 239 |
|
| 240 |
func testAllChannels(notify bool) error {
|
| 241 |
+
|
|
|
|
|
|
|
| 242 |
testAllChannelsLock.Lock()
|
| 243 |
if testAllChannelsRunning {
|
| 244 |
testAllChannelsLock.Unlock()
|
|
|
|
| 293 |
testAllChannelsRunning = false
|
| 294 |
testAllChannelsLock.Unlock()
|
| 295 |
if notify {
|
| 296 |
+
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
|
|
|
|
|
|
|
|
|
| 297 |
}
|
| 298 |
})
|
| 299 |
return nil
|
controller/pricing.go
CHANGED
|
@@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) {
|
|
| 17 |
}
|
| 18 |
var group string
|
| 19 |
if exists {
|
| 20 |
-
user, err := model.
|
| 21 |
if err == nil {
|
| 22 |
group = user.Group
|
| 23 |
}
|
|
|
|
| 17 |
}
|
| 18 |
var group string
|
| 19 |
if exists {
|
| 20 |
+
user, err := model.GetUserCache(userId.(int))
|
| 21 |
if err == nil {
|
| 22 |
group = user.Group
|
| 23 |
}
|
controller/relay.go
CHANGED
|
@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|
| 24 |
var err *dto.OpenAIErrorWithStatusCode
|
| 25 |
switch relayMode {
|
| 26 |
case relayconstant.RelayModeImagesGenerations:
|
| 27 |
-
err = relay.ImageHelper(c
|
| 28 |
case relayconstant.RelayModeAudioSpeech:
|
| 29 |
fallthrough
|
| 30 |
case relayconstant.RelayModeAudioTranslation:
|
|
|
|
| 24 |
var err *dto.OpenAIErrorWithStatusCode
|
| 25 |
switch relayMode {
|
| 26 |
case relayconstant.RelayModeImagesGenerations:
|
| 27 |
+
err = relay.ImageHelper(c)
|
| 28 |
case relayconstant.RelayModeAudioSpeech:
|
| 29 |
fallthrough
|
| 30 |
case relayconstant.RelayModeAudioTranslation:
|
controller/user.go
CHANGED
|
@@ -4,6 +4,7 @@ import (
|
|
| 4 |
"encoding/json"
|
| 5 |
"fmt"
|
| 6 |
"net/http"
|
|
|
|
| 7 |
"one-api/common"
|
| 8 |
"one-api/model"
|
| 9 |
"one-api/setting"
|
|
@@ -471,7 +472,7 @@ func GetUserModels(c *gin.Context) {
|
|
| 471 |
if err != nil {
|
| 472 |
id = c.GetInt("id")
|
| 473 |
}
|
| 474 |
-
user, err := model.
|
| 475 |
if err != nil {
|
| 476 |
c.JSON(http.StatusOK, gin.H{
|
| 477 |
"success": false,
|
|
@@ -869,9 +870,6 @@ func EmailBind(c *gin.Context) {
|
|
| 869 |
})
|
| 870 |
return
|
| 871 |
}
|
| 872 |
-
if user.Role == common.RoleRootUser {
|
| 873 |
-
common.RootUserEmail = email
|
| 874 |
-
}
|
| 875 |
c.JSON(http.StatusOK, gin.H{
|
| 876 |
"success": true,
|
| 877 |
"message": "",
|
|
@@ -913,3 +911,115 @@ func TopUp(c *gin.Context) {
|
|
| 913 |
})
|
| 914 |
return
|
| 915 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"encoding/json"
|
| 5 |
"fmt"
|
| 6 |
"net/http"
|
| 7 |
+
"net/url"
|
| 8 |
"one-api/common"
|
| 9 |
"one-api/model"
|
| 10 |
"one-api/setting"
|
|
|
|
| 472 |
if err != nil {
|
| 473 |
id = c.GetInt("id")
|
| 474 |
}
|
| 475 |
+
user, err := model.GetUserCache(id)
|
| 476 |
if err != nil {
|
| 477 |
c.JSON(http.StatusOK, gin.H{
|
| 478 |
"success": false,
|
|
|
|
| 870 |
})
|
| 871 |
return
|
| 872 |
}
|
|
|
|
|
|
|
|
|
|
| 873 |
c.JSON(http.StatusOK, gin.H{
|
| 874 |
"success": true,
|
| 875 |
"message": "",
|
|
|
|
| 911 |
})
|
| 912 |
return
|
| 913 |
}
|
| 914 |
+
|
| 915 |
+
type UpdateUserSettingRequest struct {
|
| 916 |
+
QuotaWarningType string `json:"notify_type"`
|
| 917 |
+
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
|
| 918 |
+
WebhookUrl string `json:"webhook_url,omitempty"`
|
| 919 |
+
WebhookSecret string `json:"webhook_secret,omitempty"`
|
| 920 |
+
NotificationEmail string `json:"notification_email,omitempty"`
|
| 921 |
+
}
|
| 922 |
+
|
| 923 |
+
func UpdateUserSetting(c *gin.Context) {
|
| 924 |
+
var req UpdateUserSettingRequest
|
| 925 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 926 |
+
c.JSON(http.StatusOK, gin.H{
|
| 927 |
+
"success": false,
|
| 928 |
+
"message": "无效的参数",
|
| 929 |
+
})
|
| 930 |
+
return
|
| 931 |
+
}
|
| 932 |
+
|
| 933 |
+
// 验证预警类型
|
| 934 |
+
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
|
| 935 |
+
c.JSON(http.StatusOK, gin.H{
|
| 936 |
+
"success": false,
|
| 937 |
+
"message": "无效的预警类型",
|
| 938 |
+
})
|
| 939 |
+
return
|
| 940 |
+
}
|
| 941 |
+
|
| 942 |
+
// 验证预警阈值
|
| 943 |
+
if req.QuotaWarningThreshold <= 0 {
|
| 944 |
+
c.JSON(http.StatusOK, gin.H{
|
| 945 |
+
"success": false,
|
| 946 |
+
"message": "预警阈值必须大于0",
|
| 947 |
+
})
|
| 948 |
+
return
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
// 如果是webhook类型,验证webhook地址
|
| 952 |
+
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
| 953 |
+
if req.WebhookUrl == "" {
|
| 954 |
+
c.JSON(http.StatusOK, gin.H{
|
| 955 |
+
"success": false,
|
| 956 |
+
"message": "Webhook地址不能为空",
|
| 957 |
+
})
|
| 958 |
+
return
|
| 959 |
+
}
|
| 960 |
+
// 验证URL格式
|
| 961 |
+
if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
|
| 962 |
+
c.JSON(http.StatusOK, gin.H{
|
| 963 |
+
"success": false,
|
| 964 |
+
"message": "无效的Webhook地址",
|
| 965 |
+
})
|
| 966 |
+
return
|
| 967 |
+
}
|
| 968 |
+
}
|
| 969 |
+
|
| 970 |
+
// 如果是邮件类型,验证邮箱地址
|
| 971 |
+
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
| 972 |
+
// 验证邮箱格式
|
| 973 |
+
if !strings.Contains(req.NotificationEmail, "@") {
|
| 974 |
+
c.JSON(http.StatusOK, gin.H{
|
| 975 |
+
"success": false,
|
| 976 |
+
"message": "无效的邮箱地址",
|
| 977 |
+
})
|
| 978 |
+
return
|
| 979 |
+
}
|
| 980 |
+
}
|
| 981 |
+
|
| 982 |
+
userId := c.GetInt("id")
|
| 983 |
+
user, err := model.GetUserById(userId, true)
|
| 984 |
+
if err != nil {
|
| 985 |
+
c.JSON(http.StatusOK, gin.H{
|
| 986 |
+
"success": false,
|
| 987 |
+
"message": err.Error(),
|
| 988 |
+
})
|
| 989 |
+
return
|
| 990 |
+
}
|
| 991 |
+
|
| 992 |
+
// 构建设置
|
| 993 |
+
settings := map[string]interface{}{
|
| 994 |
+
constant.UserSettingNotifyType: req.QuotaWarningType,
|
| 995 |
+
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
|
| 996 |
+
}
|
| 997 |
+
|
| 998 |
+
// 如果是webhook类型,添加webhook相关设置
|
| 999 |
+
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
| 1000 |
+
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
|
| 1001 |
+
if req.WebhookSecret != "" {
|
| 1002 |
+
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
|
| 1003 |
+
}
|
| 1004 |
+
}
|
| 1005 |
+
|
| 1006 |
+
// 如果提供了通知邮箱,添加到设置中
|
| 1007 |
+
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
| 1008 |
+
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
|
| 1009 |
+
}
|
| 1010 |
+
|
| 1011 |
+
// 更新用户设置
|
| 1012 |
+
user.SetSetting(settings)
|
| 1013 |
+
if err := user.Update(false); err != nil {
|
| 1014 |
+
c.JSON(http.StatusOK, gin.H{
|
| 1015 |
+
"success": false,
|
| 1016 |
+
"message": "更新设置失败: " + err.Error(),
|
| 1017 |
+
})
|
| 1018 |
+
return
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
c.JSON(http.StatusOK, gin.H{
|
| 1022 |
+
"success": true,
|
| 1023 |
+
"message": "设置已更新",
|
| 1024 |
+
})
|
| 1025 |
+
}
|
docker-compose.yml
CHANGED
|
@@ -24,7 +24,7 @@ services:
|
|
| 24 |
- redis
|
| 25 |
- mysql
|
| 26 |
healthcheck:
|
| 27 |
-
test: [
|
| 28 |
interval: 30s
|
| 29 |
timeout: 10s
|
| 30 |
retries: 3
|
|
|
|
| 24 |
- redis
|
| 25 |
- mysql
|
| 26 |
healthcheck:
|
| 27 |
+
test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
|
| 28 |
interval: 30s
|
| 29 |
timeout: 10s
|
| 30 |
retries: 3
|
dto/notify.go
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package dto
|
| 2 |
+
|
| 3 |
+
type Notify struct {
|
| 4 |
+
Type string `json:"type"`
|
| 5 |
+
Title string `json:"title"`
|
| 6 |
+
Content string `json:"content"`
|
| 7 |
+
Values []interface{} `json:"values"`
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
const ContentValueParam = "{{value}}"
|
| 11 |
+
|
| 12 |
+
const (
|
| 13 |
+
NotifyTypeQuotaExceed = "quota_exceed"
|
| 14 |
+
NotifyTypeChannelUpdate = "channel_update"
|
| 15 |
+
NotifyTypeChannelTest = "channel_test"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
func NewNotify(t string, title string, content string, values []interface{}) Notify {
|
| 19 |
+
return Notify{
|
| 20 |
+
Type: t,
|
| 21 |
+
Title: title,
|
| 22 |
+
Content: content,
|
| 23 |
+
Values: values,
|
| 24 |
+
}
|
| 25 |
+
}
|
dto/openai_request.go
CHANGED
|
@@ -18,6 +18,8 @@ type GeneralOpenAIRequest struct {
|
|
| 18 |
Model string `json:"model,omitempty"`
|
| 19 |
Messages []Message `json:"messages,omitempty"`
|
| 20 |
Prompt any `json:"prompt,omitempty"`
|
|
|
|
|
|
|
| 21 |
Stream bool `json:"stream,omitempty"`
|
| 22 |
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
| 23 |
MaxTokens uint `json:"max_tokens,omitempty"`
|
|
@@ -86,18 +88,20 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
|
|
| 86 |
}
|
| 87 |
|
| 88 |
type Message struct {
|
| 89 |
-
Role
|
| 90 |
-
Content
|
| 91 |
-
Name
|
| 92 |
-
Prefix
|
| 93 |
-
ReasoningContent
|
| 94 |
-
ToolCalls
|
| 95 |
-
ToolCallId
|
|
|
|
|
|
|
| 96 |
}
|
| 97 |
|
| 98 |
type MediaContent struct {
|
| 99 |
Type string `json:"type"`
|
| 100 |
-
Text string `json:"text"`
|
| 101 |
ImageUrl any `json:"image_url,omitempty"`
|
| 102 |
InputAudio any `json:"input_audio,omitempty"`
|
| 103 |
}
|
|
@@ -146,6 +150,9 @@ func (m *Message) SetToolCalls(toolCalls any) {
|
|
| 146 |
}
|
| 147 |
|
| 148 |
func (m *Message) StringContent() string {
|
|
|
|
|
|
|
|
|
|
| 149 |
var stringContent string
|
| 150 |
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
| 151 |
return stringContent
|
|
@@ -156,78 +163,113 @@ func (m *Message) StringContent() string {
|
|
| 156 |
func (m *Message) SetStringContent(content string) {
|
| 157 |
jsonContent, _ := json.Marshal(content)
|
| 158 |
m.Content = jsonContent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
}
|
| 160 |
|
| 161 |
func (m *Message) IsStringContent() bool {
|
|
|
|
|
|
|
|
|
|
| 162 |
var stringContent string
|
| 163 |
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
|
|
|
| 164 |
return true
|
| 165 |
}
|
| 166 |
return false
|
| 167 |
}
|
| 168 |
|
| 169 |
func (m *Message) ParseContent() []MediaContent {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
var contentList []MediaContent
|
|
|
|
|
|
|
| 171 |
var stringContent string
|
| 172 |
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
| 173 |
-
contentList =
|
| 174 |
Type: ContentTypeText,
|
| 175 |
Text: stringContent,
|
| 176 |
-
}
|
|
|
|
| 177 |
return contentList
|
| 178 |
}
|
| 179 |
-
|
|
|
|
|
|
|
| 180 |
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
|
| 181 |
for _, contentItem := range arrayContent {
|
| 182 |
-
|
| 183 |
-
if
|
| 184 |
continue
|
| 185 |
}
|
| 186 |
-
|
|
|
|
| 187 |
case ContentTypeText:
|
| 188 |
-
if
|
| 189 |
contentList = append(contentList, MediaContent{
|
| 190 |
Type: ContentTypeText,
|
| 191 |
-
Text:
|
| 192 |
})
|
| 193 |
}
|
|
|
|
| 194 |
case ContentTypeImageURL:
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
subObj["detail"] = detail.(string)
|
| 199 |
-
} else {
|
| 200 |
-
subObj["detail"] = "high"
|
| 201 |
-
}
|
| 202 |
contentList = append(contentList, MediaContent{
|
| 203 |
Type: ContentTypeImageURL,
|
| 204 |
ImageUrl: MessageImageUrl{
|
| 205 |
-
Url:
|
| 206 |
-
Detail: subObj["detail"].(string),
|
| 207 |
-
},
|
| 208 |
-
})
|
| 209 |
-
} else if url, ok := contentMap["image_url"].(string); ok {
|
| 210 |
-
contentList = append(contentList, MediaContent{
|
| 211 |
-
Type: ContentTypeImageURL,
|
| 212 |
-
ImageUrl: MessageImageUrl{
|
| 213 |
-
Url: url,
|
| 214 |
Detail: "high",
|
| 215 |
},
|
| 216 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
}
|
|
|
|
| 218 |
case ContentTypeInputAudio:
|
| 219 |
-
if
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
}
|
| 228 |
}
|
| 229 |
}
|
| 230 |
-
return contentList
|
| 231 |
}
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
}
|
|
|
|
| 18 |
Model string `json:"model,omitempty"`
|
| 19 |
Messages []Message `json:"messages,omitempty"`
|
| 20 |
Prompt any `json:"prompt,omitempty"`
|
| 21 |
+
Prefix any `json:"prefix,omitempty"`
|
| 22 |
+
Suffix any `json:"suffix,omitempty"`
|
| 23 |
Stream bool `json:"stream,omitempty"`
|
| 24 |
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
| 25 |
MaxTokens uint `json:"max_tokens,omitempty"`
|
|
|
|
| 88 |
}
|
| 89 |
|
| 90 |
type Message struct {
|
| 91 |
+
Role string `json:"role"`
|
| 92 |
+
Content json.RawMessage `json:"content"`
|
| 93 |
+
Name *string `json:"name,omitempty"`
|
| 94 |
+
Prefix *bool `json:"prefix,omitempty"`
|
| 95 |
+
ReasoningContent string `json:"reasoning_content,omitempty"`
|
| 96 |
+
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
| 97 |
+
ToolCallId string `json:"tool_call_id,omitempty"`
|
| 98 |
+
parsedContent []MediaContent
|
| 99 |
+
parsedStringContent *string
|
| 100 |
}
|
| 101 |
|
| 102 |
type MediaContent struct {
|
| 103 |
Type string `json:"type"`
|
| 104 |
+
Text string `json:"text,omitempty"`
|
| 105 |
ImageUrl any `json:"image_url,omitempty"`
|
| 106 |
InputAudio any `json:"input_audio,omitempty"`
|
| 107 |
}
|
|
|
|
| 150 |
}
|
| 151 |
|
| 152 |
func (m *Message) StringContent() string {
|
| 153 |
+
if m.parsedStringContent != nil {
|
| 154 |
+
return *m.parsedStringContent
|
| 155 |
+
}
|
| 156 |
var stringContent string
|
| 157 |
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
| 158 |
return stringContent
|
|
|
|
| 163 |
func (m *Message) SetStringContent(content string) {
|
| 164 |
jsonContent, _ := json.Marshal(content)
|
| 165 |
m.Content = jsonContent
|
| 166 |
+
m.parsedStringContent = &content
|
| 167 |
+
m.parsedContent = nil
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
func (m *Message) SetMediaContent(content []MediaContent) {
|
| 171 |
+
jsonContent, _ := json.Marshal(content)
|
| 172 |
+
m.Content = jsonContent
|
| 173 |
+
m.parsedContent = nil
|
| 174 |
+
m.parsedStringContent = nil
|
| 175 |
}
|
| 176 |
|
| 177 |
func (m *Message) IsStringContent() bool {
|
| 178 |
+
if m.parsedStringContent != nil {
|
| 179 |
+
return true
|
| 180 |
+
}
|
| 181 |
var stringContent string
|
| 182 |
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
| 183 |
+
m.parsedStringContent = &stringContent
|
| 184 |
return true
|
| 185 |
}
|
| 186 |
return false
|
| 187 |
}
|
| 188 |
|
| 189 |
func (m *Message) ParseContent() []MediaContent {
|
| 190 |
+
if m.parsedContent != nil {
|
| 191 |
+
return m.parsedContent
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
var contentList []MediaContent
|
| 195 |
+
|
| 196 |
+
// 先尝试解析为字符串
|
| 197 |
var stringContent string
|
| 198 |
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
| 199 |
+
contentList = []MediaContent{{
|
| 200 |
Type: ContentTypeText,
|
| 201 |
Text: stringContent,
|
| 202 |
+
}}
|
| 203 |
+
m.parsedContent = contentList
|
| 204 |
return contentList
|
| 205 |
}
|
| 206 |
+
|
| 207 |
+
// 尝试解析为数组
|
| 208 |
+
var arrayContent []map[string]interface{}
|
| 209 |
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
|
| 210 |
for _, contentItem := range arrayContent {
|
| 211 |
+
contentType, ok := contentItem["type"].(string)
|
| 212 |
+
if !ok {
|
| 213 |
continue
|
| 214 |
}
|
| 215 |
+
|
| 216 |
+
switch contentType {
|
| 217 |
case ContentTypeText:
|
| 218 |
+
if text, ok := contentItem["text"].(string); ok {
|
| 219 |
contentList = append(contentList, MediaContent{
|
| 220 |
Type: ContentTypeText,
|
| 221 |
+
Text: text,
|
| 222 |
})
|
| 223 |
}
|
| 224 |
+
|
| 225 |
case ContentTypeImageURL:
|
| 226 |
+
imageUrl := contentItem["image_url"]
|
| 227 |
+
switch v := imageUrl.(type) {
|
| 228 |
+
case string:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
contentList = append(contentList, MediaContent{
|
| 230 |
Type: ContentTypeImageURL,
|
| 231 |
ImageUrl: MessageImageUrl{
|
| 232 |
+
Url: v,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
Detail: "high",
|
| 234 |
},
|
| 235 |
})
|
| 236 |
+
case map[string]interface{}:
|
| 237 |
+
url, ok1 := v["url"].(string)
|
| 238 |
+
detail, ok2 := v["detail"].(string)
|
| 239 |
+
if !ok2 {
|
| 240 |
+
detail = "high"
|
| 241 |
+
}
|
| 242 |
+
if ok1 {
|
| 243 |
+
contentList = append(contentList, MediaContent{
|
| 244 |
+
Type: ContentTypeImageURL,
|
| 245 |
+
ImageUrl: MessageImageUrl{
|
| 246 |
+
Url: url,
|
| 247 |
+
Detail: detail,
|
| 248 |
+
},
|
| 249 |
+
})
|
| 250 |
+
}
|
| 251 |
}
|
| 252 |
+
|
| 253 |
case ContentTypeInputAudio:
|
| 254 |
+
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
|
| 255 |
+
data, ok1 := audioData["data"].(string)
|
| 256 |
+
format, ok2 := audioData["format"].(string)
|
| 257 |
+
if ok1 && ok2 {
|
| 258 |
+
contentList = append(contentList, MediaContent{
|
| 259 |
+
Type: ContentTypeInputAudio,
|
| 260 |
+
InputAudio: MessageInputAudio{
|
| 261 |
+
Data: data,
|
| 262 |
+
Format: format,
|
| 263 |
+
},
|
| 264 |
+
})
|
| 265 |
+
}
|
| 266 |
}
|
| 267 |
}
|
| 268 |
}
|
|
|
|
| 269 |
}
|
| 270 |
+
|
| 271 |
+
if len(contentList) > 0 {
|
| 272 |
+
m.parsedContent = contentList
|
| 273 |
+
}
|
| 274 |
+
return contentList
|
| 275 |
}
|
dto/openai_response.go
CHANGED
|
@@ -62,9 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
|
|
| 62 |
}
|
| 63 |
|
| 64 |
type ChatCompletionsStreamResponseChoiceDelta struct {
|
| 65 |
-
Content
|
| 66 |
-
|
| 67 |
-
|
|
|
|
| 68 |
}
|
| 69 |
|
| 70 |
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
|
|
@@ -78,6 +79,13 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
|
|
| 78 |
return *c.Content
|
| 79 |
}
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
type ToolCall struct {
|
| 82 |
// Index is not nil only in chat completion chunk object
|
| 83 |
Index *int `json:"index,omitempty"`
|
|
|
|
| 62 |
}
|
| 63 |
|
| 64 |
type ChatCompletionsStreamResponseChoiceDelta struct {
|
| 65 |
+
Content *string `json:"content,omitempty"`
|
| 66 |
+
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
| 67 |
+
Role string `json:"role,omitempty"`
|
| 68 |
+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
| 69 |
}
|
| 70 |
|
| 71 |
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
|
|
|
|
| 79 |
return *c.Content
|
| 80 |
}
|
| 81 |
|
| 82 |
+
func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
|
| 83 |
+
if c.ReasoningContent == nil {
|
| 84 |
+
return ""
|
| 85 |
+
}
|
| 86 |
+
return *c.ReasoningContent
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
type ToolCall struct {
|
| 90 |
// Index is not nil only in chat completion chunk object
|
| 91 |
Index *int `json:"index,omitempty"`
|
main.go
CHANGED
|
@@ -119,9 +119,9 @@ func main() {
|
|
| 119 |
}
|
| 120 |
|
| 121 |
if os.Getenv("ENABLE_PPROF") == "true" {
|
| 122 |
-
|
| 123 |
log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
|
| 124 |
-
}
|
| 125 |
go common.Monitor()
|
| 126 |
common.SysLog("pprof enabled")
|
| 127 |
}
|
|
|
|
| 119 |
}
|
| 120 |
|
| 121 |
if os.Getenv("ENABLE_PPROF") == "true" {
|
| 122 |
+
gopool.Go(func() {
|
| 123 |
log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
|
| 124 |
+
})
|
| 125 |
go common.Monitor()
|
| 126 |
common.SysLog("pprof enabled")
|
| 127 |
}
|
middleware/distributor.go
CHANGED
|
@@ -135,17 +135,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|
| 135 |
midjourneyRequest := dto.MidjourneyRequest{}
|
| 136 |
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
|
| 137 |
if err != nil {
|
| 138 |
-
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
|
| 139 |
return nil, false, err
|
| 140 |
}
|
| 141 |
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
| 142 |
if mjErr != nil {
|
| 143 |
-
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
|
| 144 |
return nil, false, fmt.Errorf(mjErr.Description)
|
| 145 |
}
|
| 146 |
if midjourneyModel == "" {
|
| 147 |
if !success {
|
| 148 |
-
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
|
| 149 |
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
|
| 150 |
} else {
|
| 151 |
// task fetch, task fetch by condition, notify
|
|
@@ -170,7 +167,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|
| 170 |
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
| 171 |
}
|
| 172 |
if err != nil {
|
| 173 |
-
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
| 174 |
return nil, false, errors.New("无效的请求, " + err.Error())
|
| 175 |
}
|
| 176 |
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
|
|
|
|
| 135 |
midjourneyRequest := dto.MidjourneyRequest{}
|
| 136 |
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
|
| 137 |
if err != nil {
|
|
|
|
| 138 |
return nil, false, err
|
| 139 |
}
|
| 140 |
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
| 141 |
if mjErr != nil {
|
|
|
|
| 142 |
return nil, false, fmt.Errorf(mjErr.Description)
|
| 143 |
}
|
| 144 |
if midjourneyModel == "" {
|
| 145 |
if !success {
|
|
|
|
| 146 |
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
|
| 147 |
} else {
|
| 148 |
// task fetch, task fetch by condition, notify
|
|
|
|
| 167 |
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
| 168 |
}
|
| 169 |
if err != nil {
|
|
|
|
| 170 |
return nil, false, errors.New("无效的请求, " + err.Error())
|
| 171 |
}
|
| 172 |
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
|
model/option.go
CHANGED
|
@@ -84,7 +84,7 @@ func InitOptionMap() {
|
|
| 84 |
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
|
| 85 |
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
|
| 86 |
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
|
| 87 |
-
common.OptionMap["
|
| 88 |
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
| 89 |
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
| 90 |
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
|
@@ -306,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) {
|
|
| 306 |
common.QuotaForInvitee, _ = strconv.Atoi(value)
|
| 307 |
case "QuotaRemindThreshold":
|
| 308 |
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
|
| 309 |
-
case "
|
| 310 |
common.PreConsumedQuota, _ = strconv.Atoi(value)
|
| 311 |
case "RetryTimes":
|
| 312 |
common.RetryTimes, _ = strconv.Atoi(value)
|
|
|
|
| 84 |
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
|
| 85 |
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
|
| 86 |
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
|
| 87 |
+
common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
| 88 |
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
| 89 |
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
| 90 |
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
|
|
|
| 306 |
common.QuotaForInvitee, _ = strconv.Atoi(value)
|
| 307 |
case "QuotaRemindThreshold":
|
| 308 |
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
|
| 309 |
+
case "ShouldPreConsumedQuota":
|
| 310 |
common.PreConsumedQuota, _ = strconv.Atoi(value)
|
| 311 |
case "RetryTimes":
|
| 312 |
common.RetryTimes, _ = strconv.Atoi(value)
|
model/token.go
CHANGED
|
@@ -3,13 +3,11 @@ package model
|
|
| 3 |
import (
|
| 4 |
"errors"
|
| 5 |
"fmt"
|
| 6 |
-
"github.com/bytedance/gopkg/util/gopool"
|
| 7 |
-
"gorm.io/gorm"
|
| 8 |
"one-api/common"
|
| 9 |
-
relaycommon "one-api/relay/common"
|
| 10 |
-
"one-api/setting"
|
| 11 |
-
"strconv"
|
| 12 |
"strings"
|
|
|
|
|
|
|
|
|
|
| 13 |
)
|
| 14 |
|
| 15 |
type Token struct {
|
|
@@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
|
|
| 322 |
).Error
|
| 323 |
return err
|
| 324 |
}
|
| 325 |
-
|
| 326 |
-
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
|
| 327 |
-
if quota < 0 {
|
| 328 |
-
return errors.New("quota 不能为负数!")
|
| 329 |
-
}
|
| 330 |
-
if relayInfo.IsPlayground {
|
| 331 |
-
return nil
|
| 332 |
-
}
|
| 333 |
-
//if relayInfo.TokenUnlimited {
|
| 334 |
-
// return nil
|
| 335 |
-
//}
|
| 336 |
-
token, err := GetTokenById(relayInfo.TokenId)
|
| 337 |
-
if err != nil {
|
| 338 |
-
return err
|
| 339 |
-
}
|
| 340 |
-
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
|
| 341 |
-
return errors.New("令牌额度不足")
|
| 342 |
-
}
|
| 343 |
-
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
| 344 |
-
if err != nil {
|
| 345 |
-
return err
|
| 346 |
-
}
|
| 347 |
-
return nil
|
| 348 |
-
}
|
| 349 |
-
|
| 350 |
-
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
|
| 351 |
-
|
| 352 |
-
if quota > 0 {
|
| 353 |
-
err = DecreaseUserQuota(relayInfo.UserId, quota)
|
| 354 |
-
} else {
|
| 355 |
-
err = IncreaseUserQuota(relayInfo.UserId, -quota)
|
| 356 |
-
}
|
| 357 |
-
if err != nil {
|
| 358 |
-
return err
|
| 359 |
-
}
|
| 360 |
-
|
| 361 |
-
if !relayInfo.IsPlayground {
|
| 362 |
-
if quota > 0 {
|
| 363 |
-
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
| 364 |
-
} else {
|
| 365 |
-
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
|
| 366 |
-
}
|
| 367 |
-
if err != nil {
|
| 368 |
-
return err
|
| 369 |
-
}
|
| 370 |
-
}
|
| 371 |
-
|
| 372 |
-
if sendEmail {
|
| 373 |
-
if (quota + preConsumedQuota) != 0 {
|
| 374 |
-
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
|
| 375 |
-
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
|
| 376 |
-
if quotaTooLow || noMoreQuota {
|
| 377 |
-
go func() {
|
| 378 |
-
email, err := GetUserEmail(relayInfo.UserId)
|
| 379 |
-
if err != nil {
|
| 380 |
-
common.SysError("failed to fetch user email: " + err.Error())
|
| 381 |
-
}
|
| 382 |
-
prompt := "您的额度即将用尽"
|
| 383 |
-
if noMoreQuota {
|
| 384 |
-
prompt = "您的额度已用尽"
|
| 385 |
-
}
|
| 386 |
-
if email != "" {
|
| 387 |
-
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
| 388 |
-
err = common.SendEmail(prompt, email,
|
| 389 |
-
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
|
| 390 |
-
if err != nil {
|
| 391 |
-
common.SysError("failed to send email" + err.Error())
|
| 392 |
-
}
|
| 393 |
-
common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
|
| 394 |
-
}
|
| 395 |
-
}()
|
| 396 |
-
}
|
| 397 |
-
}
|
| 398 |
-
}
|
| 399 |
-
|
| 400 |
-
return nil
|
| 401 |
-
}
|
|
|
|
| 3 |
import (
|
| 4 |
"errors"
|
| 5 |
"fmt"
|
|
|
|
|
|
|
| 6 |
"one-api/common"
|
|
|
|
|
|
|
|
|
|
| 7 |
"strings"
|
| 8 |
+
|
| 9 |
+
"github.com/bytedance/gopkg/util/gopool"
|
| 10 |
+
"gorm.io/gorm"
|
| 11 |
)
|
| 12 |
|
| 13 |
type Token struct {
|
|
|
|
| 320 |
).Error
|
| 321 |
return err
|
| 322 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/token_cache.go
CHANGED
|
@@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error {
|
|
| 52 |
func cacheGetTokenByKey(key string) (*Token, error) {
|
| 53 |
hmacKey := common.GenerateHMAC(key)
|
| 54 |
if !common.RedisEnabled {
|
| 55 |
-
return nil,
|
| 56 |
}
|
| 57 |
var token Token
|
| 58 |
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
|
|
|
|
| 52 |
func cacheGetTokenByKey(key string) (*Token, error) {
|
| 53 |
hmacKey := common.GenerateHMAC(key)
|
| 54 |
if !common.RedisEnabled {
|
| 55 |
+
return nil, fmt.Errorf("redis is not enabled")
|
| 56 |
}
|
| 57 |
var token Token
|
| 58 |
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
|
model/user.go
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
package model
|
| 2 |
|
| 3 |
import (
|
|
|
|
| 4 |
"errors"
|
| 5 |
"fmt"
|
| 6 |
"one-api/common"
|
|
@@ -38,6 +39,20 @@ type User struct {
|
|
| 38 |
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
| 39 |
DeletedAt gorm.DeletedAt `gorm:"index"`
|
| 40 |
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
}
|
| 42 |
|
| 43 |
func (user *User) GetAccessToken() string {
|
|
@@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) {
|
|
| 51 |
user.AccessToken = &token
|
| 52 |
}
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
|
| 55 |
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
| 56 |
var user User
|
|
@@ -315,8 +346,8 @@ func (user *User) Update(updatePassword bool) error {
|
|
| 315 |
return err
|
| 316 |
}
|
| 317 |
|
| 318 |
-
//
|
| 319 |
-
return updateUserCache(user
|
| 320 |
}
|
| 321 |
|
| 322 |
func (user *User) Edit(updatePassword bool) error {
|
|
@@ -344,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error {
|
|
| 344 |
return err
|
| 345 |
}
|
| 346 |
|
| 347 |
-
//
|
| 348 |
-
return updateUserCache(user
|
| 349 |
}
|
| 350 |
|
| 351 |
func (user *User) Delete() error {
|
|
@@ -371,8 +402,8 @@ func (user *User) HardDelete() error {
|
|
| 371 |
// ValidateAndFill check password & user status
|
| 372 |
func (user *User) ValidateAndFill() (err error) {
|
| 373 |
// When querying with struct, GORM will only query with non-zero fields,
|
| 374 |
-
// that means if your field
|
| 375 |
-
// it won
|
| 376 |
password := user.Password
|
| 377 |
username := strings.TrimSpace(user.Username)
|
| 378 |
if username == "" || password == "" {
|
|
@@ -531,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
|
|
| 531 |
return quota, nil
|
| 532 |
}
|
| 533 |
// Don't return error - fall through to DB
|
| 534 |
-
//common.SysError("failed to get user quota from cache: " + err.Error())
|
| 535 |
}
|
| 536 |
fromDB = true
|
| 537 |
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
|
@@ -580,6 +610,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
|
| 580 |
return group, nil
|
| 581 |
}
|
| 582 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
func IncreaseUserQuota(id int, quota int) (err error) {
|
| 584 |
if quota < 0 {
|
| 585 |
return errors.New("quota 不能为负数!")
|
|
@@ -641,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
|
|
| 641 |
}
|
| 642 |
}
|
| 643 |
|
| 644 |
-
func GetRootUserEmail() (email string) {
|
| 645 |
-
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
|
| 646 |
-
return email
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 647 |
}
|
| 648 |
|
| 649 |
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
|
@@ -725,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
|
|
| 725 |
return !errors.Is(err, gorm.ErrRecordNotFound)
|
| 726 |
}
|
| 727 |
|
| 728 |
-
func (
|
| 729 |
-
if
|
| 730 |
return errors.New("linux do id is empty")
|
| 731 |
}
|
| 732 |
-
err := DB.Where("linux_do_id = ?",
|
| 733 |
return err
|
| 734 |
}
|
|
|
|
| 1 |
package model
|
| 2 |
|
| 3 |
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
"errors"
|
| 6 |
"fmt"
|
| 7 |
"one-api/common"
|
|
|
|
| 39 |
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
| 40 |
DeletedAt gorm.DeletedAt `gorm:"index"`
|
| 41 |
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
|
| 42 |
+
Setting string `json:"setting" gorm:"type:text;column:setting"`
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
func (user *User) ToBaseUser() *UserBase {
|
| 46 |
+
cache := &UserBase{
|
| 47 |
+
Id: user.Id,
|
| 48 |
+
Group: user.Group,
|
| 49 |
+
Quota: user.Quota,
|
| 50 |
+
Status: user.Status,
|
| 51 |
+
Username: user.Username,
|
| 52 |
+
Setting: user.Setting,
|
| 53 |
+
Email: user.Email,
|
| 54 |
+
}
|
| 55 |
+
return cache
|
| 56 |
}
|
| 57 |
|
| 58 |
func (user *User) GetAccessToken() string {
|
|
|
|
| 66 |
user.AccessToken = &token
|
| 67 |
}
|
| 68 |
|
| 69 |
+
func (user *User) GetSetting() map[string]interface{} {
|
| 70 |
+
if user.Setting == "" {
|
| 71 |
+
return nil
|
| 72 |
+
}
|
| 73 |
+
return common.StrToMap(user.Setting)
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
func (user *User) SetSetting(setting map[string]interface{}) {
|
| 77 |
+
settingBytes, err := json.Marshal(setting)
|
| 78 |
+
if err != nil {
|
| 79 |
+
common.SysError("failed to marshal setting: " + err.Error())
|
| 80 |
+
return
|
| 81 |
+
}
|
| 82 |
+
user.Setting = string(settingBytes)
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
|
| 86 |
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
| 87 |
var user User
|
|
|
|
| 346 |
return err
|
| 347 |
}
|
| 348 |
|
| 349 |
+
// Update cache
|
| 350 |
+
return updateUserCache(*user)
|
| 351 |
}
|
| 352 |
|
| 353 |
func (user *User) Edit(updatePassword bool) error {
|
|
|
|
| 375 |
return err
|
| 376 |
}
|
| 377 |
|
| 378 |
+
// Update cache
|
| 379 |
+
return updateUserCache(*user)
|
| 380 |
}
|
| 381 |
|
| 382 |
func (user *User) Delete() error {
|
|
|
|
| 402 |
// ValidateAndFill check password & user status
|
| 403 |
func (user *User) ValidateAndFill() (err error) {
|
| 404 |
// When querying with struct, GORM will only query with non-zero fields,
|
| 405 |
+
// that means if your field's value is 0, '', false or other zero values,
|
| 406 |
+
// it won't be used to build query conditions
|
| 407 |
password := user.Password
|
| 408 |
username := strings.TrimSpace(user.Username)
|
| 409 |
if username == "" || password == "" {
|
|
|
|
| 562 |
return quota, nil
|
| 563 |
}
|
| 564 |
// Don't return error - fall through to DB
|
|
|
|
| 565 |
}
|
| 566 |
fromDB = true
|
| 567 |
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
|
|
|
| 610 |
return group, nil
|
| 611 |
}
|
| 612 |
|
| 613 |
+
// GetUserSetting gets setting from Redis first, falls back to DB if needed
|
| 614 |
+
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
|
| 615 |
+
var setting string
|
| 616 |
+
defer func() {
|
| 617 |
+
// Update Redis cache asynchronously on successful DB read
|
| 618 |
+
if shouldUpdateRedis(fromDB, err) {
|
| 619 |
+
gopool.Go(func() {
|
| 620 |
+
if err := updateUserSettingCache(id, setting); err != nil {
|
| 621 |
+
common.SysError("failed to update user setting cache: " + err.Error())
|
| 622 |
+
}
|
| 623 |
+
})
|
| 624 |
+
}
|
| 625 |
+
}()
|
| 626 |
+
if !fromDB && common.RedisEnabled {
|
| 627 |
+
setting, err := getUserSettingCache(id)
|
| 628 |
+
if err == nil {
|
| 629 |
+
return setting, nil
|
| 630 |
+
}
|
| 631 |
+
// Don't return error - fall through to DB
|
| 632 |
+
}
|
| 633 |
+
fromDB = true
|
| 634 |
+
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
|
| 635 |
+
if err != nil {
|
| 636 |
+
return map[string]interface{}{}, err
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
return common.StrToMap(setting), nil
|
| 640 |
+
}
|
| 641 |
+
|
| 642 |
func IncreaseUserQuota(id int, quota int) (err error) {
|
| 643 |
if quota < 0 {
|
| 644 |
return errors.New("quota 不能为负数!")
|
|
|
|
| 700 |
}
|
| 701 |
}
|
| 702 |
|
| 703 |
+
//func GetRootUserEmail() (email string) {
|
| 704 |
+
// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
|
| 705 |
+
// return email
|
| 706 |
+
//}
|
| 707 |
+
|
| 708 |
+
func GetRootUser() (user *User) {
|
| 709 |
+
DB.Where("role = ?", common.RoleRootUser).First(&user)
|
| 710 |
+
return user
|
| 711 |
}
|
| 712 |
|
| 713 |
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
|
|
|
| 789 |
return !errors.Is(err, gorm.ErrRecordNotFound)
|
| 790 |
}
|
| 791 |
|
| 792 |
+
func (user *User) FillUserByLinuxDOId() error {
|
| 793 |
+
if user.LinuxDOId == "" {
|
| 794 |
return errors.New("linux do id is empty")
|
| 795 |
}
|
| 796 |
+
err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
|
| 797 |
return err
|
| 798 |
}
|
model/user_cache.go
CHANGED
|
@@ -1,206 +1,213 @@
|
|
| 1 |
package model
|
| 2 |
|
| 3 |
import (
|
|
|
|
| 4 |
"fmt"
|
| 5 |
"one-api/common"
|
| 6 |
"one-api/constant"
|
| 7 |
-
"strconv"
|
| 8 |
"time"
|
|
|
|
|
|
|
| 9 |
)
|
| 10 |
|
| 11 |
-
//
|
| 12 |
-
type
|
| 13 |
Id int `json:"id"`
|
| 14 |
Group string `json:"group"`
|
|
|
|
| 15 |
Quota int `json:"quota"`
|
| 16 |
Status int `json:"status"`
|
| 17 |
-
Role int `json:"role"`
|
| 18 |
Username string `json:"username"`
|
|
|
|
| 19 |
}
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
func invalidateUserCache(userId int) error {
|
| 24 |
-
if !common.RedisEnabled {
|
| 25 |
return nil
|
| 26 |
}
|
| 27 |
-
|
| 28 |
-
keys := []string{
|
| 29 |
-
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
|
| 30 |
-
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
|
| 31 |
-
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
|
| 32 |
-
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
|
| 33 |
-
}
|
| 34 |
-
|
| 35 |
-
for _, key := range keys {
|
| 36 |
-
if err := common.RedisDel(key); err != nil {
|
| 37 |
-
return fmt.Errorf("failed to delete cache key %s: %w", key, err)
|
| 38 |
-
}
|
| 39 |
-
}
|
| 40 |
-
return nil
|
| 41 |
}
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
if
|
| 46 |
-
|
|
|
|
| 47 |
}
|
| 48 |
-
|
| 49 |
-
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
|
| 50 |
-
group,
|
| 51 |
-
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
| 52 |
-
)
|
| 53 |
}
|
| 54 |
|
| 55 |
-
//
|
| 56 |
-
func
|
| 57 |
-
|
| 58 |
-
return nil
|
| 59 |
-
}
|
| 60 |
-
return common.RedisSet(
|
| 61 |
-
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
|
| 62 |
-
fmt.Sprintf("%d", quota),
|
| 63 |
-
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
| 64 |
-
)
|
| 65 |
}
|
| 66 |
|
| 67 |
-
//
|
| 68 |
-
func
|
| 69 |
if !common.RedisEnabled {
|
| 70 |
return nil
|
| 71 |
}
|
| 72 |
-
|
| 73 |
-
if userEnabled {
|
| 74 |
-
enabled = "1"
|
| 75 |
-
}
|
| 76 |
-
return common.RedisSet(
|
| 77 |
-
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
|
| 78 |
-
enabled,
|
| 79 |
-
time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
|
| 80 |
-
)
|
| 81 |
}
|
| 82 |
|
| 83 |
-
//
|
| 84 |
-
func
|
| 85 |
if !common.RedisEnabled {
|
| 86 |
return nil
|
| 87 |
}
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
| 91 |
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
| 92 |
)
|
| 93 |
}
|
| 94 |
|
| 95 |
-
//
|
| 96 |
-
func
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
}
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
| 103 |
}
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
}
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
| 115 |
}
|
|
|
|
|
|
|
| 116 |
|
| 117 |
-
|
|
|
|
| 118 |
}
|
| 119 |
|
| 120 |
-
//
|
| 121 |
func getUserGroupCache(userId int) (string, error) {
|
| 122 |
-
|
| 123 |
-
|
|
|
|
| 124 |
}
|
| 125 |
-
return
|
| 126 |
}
|
| 127 |
|
| 128 |
-
// getUserQuotaCache gets user quota from cache
|
| 129 |
func getUserQuotaCache(userId int) (int, error) {
|
| 130 |
-
|
| 131 |
-
return 0, nil
|
| 132 |
-
}
|
| 133 |
-
quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
|
| 134 |
if err != nil {
|
| 135 |
return 0, err
|
| 136 |
}
|
| 137 |
-
return
|
| 138 |
}
|
| 139 |
|
| 140 |
-
// getUserStatusCache gets user status from cache
|
| 141 |
func getUserStatusCache(userId int) (int, error) {
|
| 142 |
-
|
| 143 |
-
return 0, nil
|
| 144 |
-
}
|
| 145 |
-
statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
|
| 146 |
if err != nil {
|
| 147 |
return 0, err
|
| 148 |
}
|
| 149 |
-
return
|
| 150 |
}
|
| 151 |
|
| 152 |
-
// getUserNameCache gets username from cache
|
| 153 |
func getUserNameCache(userId int) (string, error) {
|
| 154 |
-
|
| 155 |
-
|
|
|
|
| 156 |
}
|
| 157 |
-
return
|
| 158 |
}
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
return nil, nil
|
| 164 |
-
}
|
| 165 |
-
|
| 166 |
-
group, err := getUserGroupCache(userId)
|
| 167 |
if err != nil {
|
| 168 |
-
return
|
| 169 |
}
|
|
|
|
|
|
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
| 174 |
}
|
| 175 |
-
|
| 176 |
-
status
|
| 177 |
-
|
| 178 |
-
return nil, fmt.Errorf("get status cache: %w", err)
|
| 179 |
}
|
|
|
|
|
|
|
| 180 |
|
| 181 |
-
|
| 182 |
-
if
|
| 183 |
-
return nil
|
| 184 |
}
|
|
|
|
|
|
|
| 185 |
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
Username: username,
|
| 192 |
-
}, nil
|
| 193 |
}
|
| 194 |
|
| 195 |
-
|
| 196 |
-
func cacheIncrUserQuota(userId int, delta int64) error {
|
| 197 |
if !common.RedisEnabled {
|
| 198 |
return nil
|
| 199 |
}
|
| 200 |
-
|
| 201 |
-
return common.RedisIncr(key, delta)
|
| 202 |
}
|
| 203 |
|
| 204 |
-
func
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
| 206 |
}
|
|
|
|
| 1 |
package model
|
| 2 |
|
| 3 |
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
"fmt"
|
| 6 |
"one-api/common"
|
| 7 |
"one-api/constant"
|
|
|
|
| 8 |
"time"
|
| 9 |
+
|
| 10 |
+
"github.com/bytedance/gopkg/util/gopool"
|
| 11 |
)
|
| 12 |
|
| 13 |
+
// UserBase struct remains the same as it represents the cached data structure
|
| 14 |
+
type UserBase struct {
|
| 15 |
Id int `json:"id"`
|
| 16 |
Group string `json:"group"`
|
| 17 |
+
Email string `json:"email"`
|
| 18 |
Quota int `json:"quota"`
|
| 19 |
Status int `json:"status"`
|
|
|
|
| 20 |
Username string `json:"username"`
|
| 21 |
+
Setting string `json:"setting"`
|
| 22 |
}
|
| 23 |
|
| 24 |
+
func (user *UserBase) GetSetting() map[string]interface{} {
|
| 25 |
+
if user.Setting == "" {
|
|
|
|
|
|
|
| 26 |
return nil
|
| 27 |
}
|
| 28 |
+
return common.StrToMap(user.Setting)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
}
|
| 30 |
|
| 31 |
+
func (user *UserBase) SetSetting(setting map[string]interface{}) {
|
| 32 |
+
settingBytes, err := json.Marshal(setting)
|
| 33 |
+
if err != nil {
|
| 34 |
+
common.SysError("failed to marshal setting: " + err.Error())
|
| 35 |
+
return
|
| 36 |
}
|
| 37 |
+
user.Setting = string(settingBytes)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
}
|
| 39 |
|
| 40 |
+
// getUserCacheKey returns the key for user cache
|
| 41 |
+
func getUserCacheKey(userId int) string {
|
| 42 |
+
return fmt.Sprintf("user:%d", userId)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
+
// invalidateUserCache clears user cache
|
| 46 |
+
func invalidateUserCache(userId int) error {
|
| 47 |
if !common.RedisEnabled {
|
| 48 |
return nil
|
| 49 |
}
|
| 50 |
+
return common.RedisHDelObj(getUserCacheKey(userId))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
}
|
| 52 |
|
| 53 |
+
// updateUserCache updates all user cache fields using hash
|
| 54 |
+
func updateUserCache(user User) error {
|
| 55 |
if !common.RedisEnabled {
|
| 56 |
return nil
|
| 57 |
}
|
| 58 |
+
|
| 59 |
+
return common.RedisHSetObj(
|
| 60 |
+
getUserCacheKey(user.Id),
|
| 61 |
+
user.ToBaseUser(),
|
| 62 |
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
| 63 |
)
|
| 64 |
}
|
| 65 |
|
| 66 |
+
// GetUserCache gets complete user cache from hash
|
| 67 |
+
func GetUserCache(userId int) (userCache *UserBase, err error) {
|
| 68 |
+
var user *User
|
| 69 |
+
var fromDB bool
|
| 70 |
+
defer func() {
|
| 71 |
+
// Update Redis cache asynchronously on successful DB read
|
| 72 |
+
if shouldUpdateRedis(fromDB, err) && user != nil {
|
| 73 |
+
gopool.Go(func() {
|
| 74 |
+
if err := updateUserCache(*user); err != nil {
|
| 75 |
+
common.SysError("failed to update user status cache: " + err.Error())
|
| 76 |
+
}
|
| 77 |
+
})
|
| 78 |
+
}
|
| 79 |
+
}()
|
| 80 |
+
|
| 81 |
+
// Try getting from Redis first
|
| 82 |
+
userCache, err = cacheGetUserBase(userId)
|
| 83 |
+
if err == nil {
|
| 84 |
+
return userCache, nil
|
| 85 |
}
|
| 86 |
|
| 87 |
+
// If Redis fails, get from DB
|
| 88 |
+
fromDB = true
|
| 89 |
+
user, err = GetUserById(userId, false)
|
| 90 |
+
if err != nil {
|
| 91 |
+
return nil, err // Return nil and error if DB lookup fails
|
| 92 |
}
|
| 93 |
|
| 94 |
+
// Create cache object from user data
|
| 95 |
+
userCache = &UserBase{
|
| 96 |
+
Id: user.Id,
|
| 97 |
+
Group: user.Group,
|
| 98 |
+
Quota: user.Quota,
|
| 99 |
+
Status: user.Status,
|
| 100 |
+
Username: user.Username,
|
| 101 |
+
Setting: user.Setting,
|
| 102 |
+
Email: user.Email,
|
| 103 |
}
|
| 104 |
|
| 105 |
+
return userCache, nil
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
func cacheGetUserBase(userId int) (*UserBase, error) {
|
| 109 |
+
if !common.RedisEnabled {
|
| 110 |
+
return nil, fmt.Errorf("redis is not enabled")
|
| 111 |
}
|
| 112 |
+
var userCache UserBase
|
| 113 |
+
// Try getting from Redis first
|
| 114 |
+
err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
|
| 115 |
+
if err != nil {
|
| 116 |
+
return nil, err
|
| 117 |
+
}
|
| 118 |
+
return &userCache, nil
|
| 119 |
+
}
|
| 120 |
|
| 121 |
+
// Add atomic quota operations using hash fields
|
| 122 |
+
func cacheIncrUserQuota(userId int, delta int64) error {
|
| 123 |
+
if !common.RedisEnabled {
|
| 124 |
+
return nil
|
| 125 |
}
|
| 126 |
+
return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
|
| 127 |
+
}
|
| 128 |
|
| 129 |
+
func cacheDecrUserQuota(userId int, delta int64) error {
|
| 130 |
+
return cacheIncrUserQuota(userId, -delta)
|
| 131 |
}
|
| 132 |
|
| 133 |
+
// Helper functions to get individual fields if needed
|
| 134 |
func getUserGroupCache(userId int) (string, error) {
|
| 135 |
+
cache, err := GetUserCache(userId)
|
| 136 |
+
if err != nil {
|
| 137 |
+
return "", err
|
| 138 |
}
|
| 139 |
+
return cache.Group, nil
|
| 140 |
}
|
| 141 |
|
|
|
|
| 142 |
func getUserQuotaCache(userId int) (int, error) {
|
| 143 |
+
cache, err := GetUserCache(userId)
|
|
|
|
|
|
|
|
|
|
| 144 |
if err != nil {
|
| 145 |
return 0, err
|
| 146 |
}
|
| 147 |
+
return cache.Quota, nil
|
| 148 |
}
|
| 149 |
|
|
|
|
| 150 |
func getUserStatusCache(userId int) (int, error) {
|
| 151 |
+
cache, err := GetUserCache(userId)
|
|
|
|
|
|
|
|
|
|
| 152 |
if err != nil {
|
| 153 |
return 0, err
|
| 154 |
}
|
| 155 |
+
return cache.Status, nil
|
| 156 |
}
|
| 157 |
|
|
|
|
| 158 |
func getUserNameCache(userId int) (string, error) {
|
| 159 |
+
cache, err := GetUserCache(userId)
|
| 160 |
+
if err != nil {
|
| 161 |
+
return "", err
|
| 162 |
}
|
| 163 |
+
return cache.Username, nil
|
| 164 |
}
|
| 165 |
|
| 166 |
+
func getUserSettingCache(userId int) (map[string]interface{}, error) {
|
| 167 |
+
setting := make(map[string]interface{})
|
| 168 |
+
cache, err := GetUserCache(userId)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
if err != nil {
|
| 170 |
+
return setting, err
|
| 171 |
}
|
| 172 |
+
return cache.GetSetting(), nil
|
| 173 |
+
}
|
| 174 |
|
| 175 |
+
// New functions for individual field updates
|
| 176 |
+
func updateUserStatusCache(userId int, status bool) error {
|
| 177 |
+
if !common.RedisEnabled {
|
| 178 |
+
return nil
|
| 179 |
}
|
| 180 |
+
statusInt := common.UserStatusEnabled
|
| 181 |
+
if !status {
|
| 182 |
+
statusInt = common.UserStatusDisabled
|
|
|
|
| 183 |
}
|
| 184 |
+
return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
|
| 185 |
+
}
|
| 186 |
|
| 187 |
+
func updateUserQuotaCache(userId int, quota int) error {
|
| 188 |
+
if !common.RedisEnabled {
|
| 189 |
+
return nil
|
| 190 |
}
|
| 191 |
+
return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
|
| 192 |
+
}
|
| 193 |
|
| 194 |
+
func updateUserGroupCache(userId int, group string) error {
|
| 195 |
+
if !common.RedisEnabled {
|
| 196 |
+
return nil
|
| 197 |
+
}
|
| 198 |
+
return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
|
|
|
|
|
|
|
| 199 |
}
|
| 200 |
|
| 201 |
+
func updateUserNameCache(userId int, username string) error {
|
|
|
|
| 202 |
if !common.RedisEnabled {
|
| 203 |
return nil
|
| 204 |
}
|
| 205 |
+
return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
|
|
|
|
| 206 |
}
|
| 207 |
|
| 208 |
+
func updateUserSettingCache(userId int, setting string) error {
|
| 209 |
+
if !common.RedisEnabled {
|
| 210 |
+
return nil
|
| 211 |
+
}
|
| 212 |
+
return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
|
| 213 |
}
|
relay/channel/cloudflare/adaptor.go
CHANGED
|
@@ -4,13 +4,14 @@ import (
|
|
| 4 |
"bytes"
|
| 5 |
"errors"
|
| 6 |
"fmt"
|
| 7 |
-
"github.com/gin-gonic/gin"
|
| 8 |
"io"
|
| 9 |
"net/http"
|
| 10 |
"one-api/dto"
|
| 11 |
"one-api/relay/channel"
|
| 12 |
relaycommon "one-api/relay/common"
|
| 13 |
"one-api/relay/constant"
|
|
|
|
|
|
|
| 14 |
)
|
| 15 |
|
| 16 |
type Adaptor struct {
|
|
|
|
| 4 |
"bytes"
|
| 5 |
"errors"
|
| 6 |
"fmt"
|
|
|
|
| 7 |
"io"
|
| 8 |
"net/http"
|
| 9 |
"one-api/dto"
|
| 10 |
"one-api/relay/channel"
|
| 11 |
relaycommon "one-api/relay/common"
|
| 12 |
"one-api/relay/constant"
|
| 13 |
+
|
| 14 |
+
"github.com/gin-gonic/gin"
|
| 15 |
)
|
| 16 |
|
| 17 |
type Adaptor struct {
|
relay/channel/deepseek/adaptor.go
CHANGED
|
@@ -10,6 +10,7 @@ import (
|
|
| 10 |
"one-api/relay/channel"
|
| 11 |
"one-api/relay/channel/openai"
|
| 12 |
relaycommon "one-api/relay/common"
|
|
|
|
| 13 |
)
|
| 14 |
|
| 15 |
type Adaptor struct {
|
|
@@ -29,7 +30,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
| 29 |
}
|
| 30 |
|
| 31 |
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
}
|
| 34 |
|
| 35 |
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
|
|
|
| 10 |
"one-api/relay/channel"
|
| 11 |
"one-api/relay/channel/openai"
|
| 12 |
relaycommon "one-api/relay/common"
|
| 13 |
+
"one-api/relay/constant"
|
| 14 |
)
|
| 15 |
|
| 16 |
type Adaptor struct {
|
|
|
|
| 30 |
}
|
| 31 |
|
| 32 |
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
| 33 |
+
switch info.RelayMode {
|
| 34 |
+
case constant.RelayModeCompletions:
|
| 35 |
+
return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil
|
| 36 |
+
default:
|
| 37 |
+
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
| 38 |
+
}
|
| 39 |
}
|
| 40 |
|
| 41 |
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
relay/channel/gemini/adaptor.go
CHANGED
|
@@ -1,15 +1,21 @@
|
|
| 1 |
package gemini
|
| 2 |
|
| 3 |
import (
|
|
|
|
| 4 |
"errors"
|
| 5 |
"fmt"
|
| 6 |
-
"github.com/gin-gonic/gin"
|
| 7 |
"io"
|
| 8 |
"net/http"
|
|
|
|
| 9 |
"one-api/constant"
|
| 10 |
"one-api/dto"
|
| 11 |
"one-api/relay/channel"
|
| 12 |
relaycommon "one-api/relay/common"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
)
|
| 14 |
|
| 15 |
type Adaptor struct {
|
|
@@ -21,8 +27,36 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
|
| 21 |
}
|
| 22 |
|
| 23 |
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
}
|
| 27 |
|
| 28 |
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
@@ -40,6 +74,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
| 40 |
}
|
| 41 |
}
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
action := "generateContent"
|
| 44 |
if info.IsStream {
|
| 45 |
action = "streamGenerateContent?alt=sse"
|
|
@@ -73,12 +111,15 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|
| 73 |
return nil, errors.New("not implemented")
|
| 74 |
}
|
| 75 |
|
| 76 |
-
|
| 77 |
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
| 78 |
return channel.DoApiRequest(a, c, info, requestBody)
|
| 79 |
}
|
| 80 |
|
| 81 |
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
if info.IsStream {
|
| 83 |
err, usage = GeminiChatStreamHandler(c, resp, info)
|
| 84 |
} else {
|
|
@@ -87,6 +128,60 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|
| 87 |
return
|
| 88 |
}
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
func (a *Adaptor) GetModelList() []string {
|
| 91 |
return ModelList
|
| 92 |
}
|
|
|
|
| 1 |
package gemini
|
| 2 |
|
| 3 |
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
"errors"
|
| 6 |
"fmt"
|
|
|
|
| 7 |
"io"
|
| 8 |
"net/http"
|
| 9 |
+
"one-api/common"
|
| 10 |
"one-api/constant"
|
| 11 |
"one-api/dto"
|
| 12 |
"one-api/relay/channel"
|
| 13 |
relaycommon "one-api/relay/common"
|
| 14 |
+
"one-api/service"
|
| 15 |
+
|
| 16 |
+
"strings"
|
| 17 |
+
|
| 18 |
+
"github.com/gin-gonic/gin"
|
| 19 |
)
|
| 20 |
|
| 21 |
type Adaptor struct {
|
|
|
|
| 27 |
}
|
| 28 |
|
| 29 |
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
| 30 |
+
if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
| 31 |
+
return nil, errors.New("not supported model for image generation")
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// convert size to aspect ratio
|
| 35 |
+
aspectRatio := "1:1" // default aspect ratio
|
| 36 |
+
switch request.Size {
|
| 37 |
+
case "1024x1024":
|
| 38 |
+
aspectRatio = "1:1"
|
| 39 |
+
case "1024x1792":
|
| 40 |
+
aspectRatio = "9:16"
|
| 41 |
+
case "1792x1024":
|
| 42 |
+
aspectRatio = "16:9"
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
// build gemini imagen request
|
| 46 |
+
geminiRequest := GeminiImageRequest{
|
| 47 |
+
Instances: []GeminiImageInstance{
|
| 48 |
+
{
|
| 49 |
+
Prompt: request.Prompt,
|
| 50 |
+
},
|
| 51 |
+
},
|
| 52 |
+
Parameters: GeminiImageParameters{
|
| 53 |
+
SampleCount: request.N,
|
| 54 |
+
AspectRatio: aspectRatio,
|
| 55 |
+
PersonGeneration: "allow_adult", // default allow adult
|
| 56 |
+
},
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
return geminiRequest, nil
|
| 60 |
}
|
| 61 |
|
| 62 |
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
|
|
| 74 |
}
|
| 75 |
}
|
| 76 |
|
| 77 |
+
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
| 78 |
+
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
action := "generateContent"
|
| 82 |
if info.IsStream {
|
| 83 |
action = "streamGenerateContent?alt=sse"
|
|
|
|
| 111 |
return nil, errors.New("not implemented")
|
| 112 |
}
|
| 113 |
|
|
|
|
| 114 |
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
| 115 |
return channel.DoApiRequest(a, c, info, requestBody)
|
| 116 |
}
|
| 117 |
|
| 118 |
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
| 119 |
+
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
| 120 |
+
return GeminiImageHandler(c, resp, info)
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
if info.IsStream {
|
| 124 |
err, usage = GeminiChatStreamHandler(c, resp, info)
|
| 125 |
} else {
|
|
|
|
| 128 |
return
|
| 129 |
}
|
| 130 |
|
| 131 |
+
func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
| 132 |
+
responseBody, readErr := io.ReadAll(resp.Body)
|
| 133 |
+
if readErr != nil {
|
| 134 |
+
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
|
| 135 |
+
}
|
| 136 |
+
_ = resp.Body.Close()
|
| 137 |
+
|
| 138 |
+
var geminiResponse GeminiImageResponse
|
| 139 |
+
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
| 140 |
+
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
if len(geminiResponse.Predictions) == 0 {
|
| 144 |
+
return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
// convert to openai format response
|
| 148 |
+
openAIResponse := dto.ImageResponse{
|
| 149 |
+
Created: common.GetTimestamp(),
|
| 150 |
+
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
for _, prediction := range geminiResponse.Predictions {
|
| 154 |
+
if prediction.RaiFilteredReason != "" {
|
| 155 |
+
continue // skip filtered image
|
| 156 |
+
}
|
| 157 |
+
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
|
| 158 |
+
B64Json: prediction.BytesBase64Encoded,
|
| 159 |
+
})
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
| 163 |
+
if jsonErr != nil {
|
| 164 |
+
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
c.Writer.Header().Set("Content-Type", "application/json")
|
| 168 |
+
c.Writer.WriteHeader(resp.StatusCode)
|
| 169 |
+
_, _ = c.Writer.Write(jsonResponse)
|
| 170 |
+
|
| 171 |
+
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
|
| 172 |
+
// each image has fixed 258 tokens
|
| 173 |
+
const imageTokens = 258
|
| 174 |
+
generatedImages := len(openAIResponse.Data)
|
| 175 |
+
|
| 176 |
+
usage = &dto.Usage{
|
| 177 |
+
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
|
| 178 |
+
CompletionTokens: 0, // image generation does not calculate completion tokens
|
| 179 |
+
TotalTokens: imageTokens * generatedImages,
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
return usage, nil
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
func (a *Adaptor) GetModelList() []string {
|
| 186 |
return ModelList
|
| 187 |
}
|
relay/channel/gemini/constant.go
CHANGED
|
@@ -16,6 +16,8 @@ var ModelList = []string{
|
|
| 16 |
"gemini-2.0-pro-exp",
|
| 17 |
// thinking exp
|
| 18 |
"gemini-2.0-flash-thinking-exp",
|
|
|
|
|
|
|
| 19 |
}
|
| 20 |
|
| 21 |
var ChannelName = "google gemini"
|
|
|
|
| 16 |
"gemini-2.0-pro-exp",
|
| 17 |
// thinking exp
|
| 18 |
"gemini-2.0-flash-thinking-exp",
|
| 19 |
+
// imagen models
|
| 20 |
+
"imagen-3.0-generate-002",
|
| 21 |
}
|
| 22 |
|
| 23 |
var ChannelName = "google gemini"
|
relay/channel/gemini/dto.go
CHANGED
|
@@ -109,3 +109,30 @@ type GeminiUsageMetadata struct {
|
|
| 109 |
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
| 110 |
TotalTokenCount int `json:"totalTokenCount"`
|
| 111 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
| 110 |
TotalTokenCount int `json:"totalTokenCount"`
|
| 111 |
}
|
| 112 |
+
|
| 113 |
+
// Imagen related structs
|
| 114 |
+
type GeminiImageRequest struct {
|
| 115 |
+
Instances []GeminiImageInstance `json:"instances"`
|
| 116 |
+
Parameters GeminiImageParameters `json:"parameters"`
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
type GeminiImageInstance struct {
|
| 120 |
+
Prompt string `json:"prompt"`
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
type GeminiImageParameters struct {
|
| 124 |
+
SampleCount int `json:"sampleCount,omitempty"`
|
| 125 |
+
AspectRatio string `json:"aspectRatio,omitempty"`
|
| 126 |
+
PersonGeneration string `json:"personGeneration,omitempty"`
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
type GeminiImageResponse struct {
|
| 130 |
+
Predictions []GeminiImagePrediction `json:"predictions"`
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
type GeminiImagePrediction struct {
|
| 134 |
+
MimeType string `json:"mimeType"`
|
| 135 |
+
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
| 136 |
+
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
|
| 137 |
+
SafetyAttributes any `json:"safetyAttributes,omitempty"`
|
| 138 |
+
}
|
relay/channel/mistral/adaptor.go
CHANGED
|
@@ -41,9 +41,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|
| 41 |
if request == nil {
|
| 42 |
return nil, errors.New("request is nil")
|
| 43 |
}
|
| 44 |
-
|
| 45 |
-
//common.LogJson(c, "body", mistralReq)
|
| 46 |
-
return mistralReq, nil
|
| 47 |
}
|
| 48 |
|
| 49 |
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
@@ -55,7 +53,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|
| 55 |
return nil, errors.New("not implemented")
|
| 56 |
}
|
| 57 |
|
| 58 |
-
|
| 59 |
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
| 60 |
return channel.DoApiRequest(a, c, info, requestBody)
|
| 61 |
}
|
|
|
|
| 41 |
if request == nil {
|
| 42 |
return nil, errors.New("request is nil")
|
| 43 |
}
|
| 44 |
+
return requestOpenAI2Mistral(request), nil
|
|
|
|
|
|
|
| 45 |
}
|
| 46 |
|
| 47 |
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
|
|
| 53 |
return nil, errors.New("not implemented")
|
| 54 |
}
|
| 55 |
|
|
|
|
| 56 |
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
| 57 |
return channel.DoApiRequest(a, c, info, requestBody)
|
| 58 |
}
|
relay/channel/mistral/text.go
CHANGED
|
@@ -1,25 +1,21 @@
|
|
| 1 |
package mistral
|
| 2 |
|
| 3 |
import (
|
| 4 |
-
"encoding/json"
|
| 5 |
"one-api/dto"
|
| 6 |
)
|
| 7 |
|
| 8 |
-
func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
|
| 9 |
messages := make([]dto.Message, 0, len(request.Messages))
|
| 10 |
for _, message := range request.Messages {
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
mediaMessages[j] = mediaMessage
|
| 18 |
-
}
|
| 19 |
}
|
| 20 |
-
messageRaw, _ := json.Marshal(mediaMessages)
|
| 21 |
-
message.Content = messageRaw
|
| 22 |
}
|
|
|
|
| 23 |
messages = append(messages, dto.Message{
|
| 24 |
Role: message.Role,
|
| 25 |
Content: message.Content,
|
|
|
|
| 1 |
package mistral
|
| 2 |
|
| 3 |
import (
|
|
|
|
| 4 |
"one-api/dto"
|
| 5 |
)
|
| 6 |
|
| 7 |
+
func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
|
| 8 |
messages := make([]dto.Message, 0, len(request.Messages))
|
| 9 |
for _, message := range request.Messages {
|
| 10 |
+
mediaMessages := message.ParseContent()
|
| 11 |
+
for j, mediaMessage := range mediaMessages {
|
| 12 |
+
if mediaMessage.Type == dto.ContentTypeImageURL {
|
| 13 |
+
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
| 14 |
+
mediaMessage.ImageUrl = imageUrl.Url
|
| 15 |
+
mediaMessages[j] = mediaMessage
|
|
|
|
|
|
|
| 16 |
}
|
|
|
|
|
|
|
| 17 |
}
|
| 18 |
+
message.SetMediaContent(mediaMessages)
|
| 19 |
messages = append(messages, dto.Message{
|
| 20 |
Role: message.Role,
|
| 21 |
Content: message.Content,
|
relay/channel/ollama/adaptor.go
CHANGED
|
@@ -39,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
| 39 |
|
| 40 |
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
| 41 |
channel.SetupApiRequestHeader(info, c, req)
|
|
|
|
| 42 |
return nil
|
| 43 |
}
|
| 44 |
|
|
@@ -46,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|
| 46 |
if request == nil {
|
| 47 |
return nil, errors.New("request is nil")
|
| 48 |
}
|
| 49 |
-
return requestOpenAI2Ollama(*request)
|
| 50 |
}
|
| 51 |
|
| 52 |
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
|
|
| 39 |
|
| 40 |
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
| 41 |
channel.SetupApiRequestHeader(info, c, req)
|
| 42 |
+
req.Set("Authorization", "Bearer "+info.ApiKey)
|
| 43 |
return nil
|
| 44 |
}
|
| 45 |
|
|
|
|
| 47 |
if request == nil {
|
| 48 |
return nil, errors.New("request is nil")
|
| 49 |
}
|
| 50 |
+
return requestOpenAI2Ollama(*request)
|
| 51 |
}
|
| 52 |
|
| 53 |
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
relay/channel/ollama/dto.go
CHANGED
|
@@ -3,18 +3,21 @@ package ollama
|
|
| 3 |
import "one-api/dto"
|
| 4 |
|
| 5 |
type OllamaRequest struct {
|
| 6 |
-
Model string
|
| 7 |
-
Messages []dto.Message
|
| 8 |
-
Stream bool
|
| 9 |
-
Temperature *float64
|
| 10 |
-
Seed float64
|
| 11 |
-
Topp float64
|
| 12 |
-
TopK int
|
| 13 |
-
Stop any
|
| 14 |
-
Tools []dto.ToolCall
|
| 15 |
-
ResponseFormat any
|
| 16 |
-
FrequencyPenalty float64
|
| 17 |
-
PresencePenalty float64
|
|
|
|
|
|
|
|
|
|
| 18 |
}
|
| 19 |
|
| 20 |
type Options struct {
|
|
@@ -35,7 +38,7 @@ type OllamaEmbeddingRequest struct {
|
|
| 35 |
}
|
| 36 |
|
| 37 |
type OllamaEmbeddingResponse struct {
|
| 38 |
-
Error string
|
| 39 |
-
Model string
|
| 40 |
Embedding [][]float64 `json:"embeddings,omitempty"`
|
| 41 |
}
|
|
|
|
| 3 |
import "one-api/dto"
|
| 4 |
|
| 5 |
type OllamaRequest struct {
|
| 6 |
+
Model string `json:"model,omitempty"`
|
| 7 |
+
Messages []dto.Message `json:"messages,omitempty"`
|
| 8 |
+
Stream bool `json:"stream,omitempty"`
|
| 9 |
+
Temperature *float64 `json:"temperature,omitempty"`
|
| 10 |
+
Seed float64 `json:"seed,omitempty"`
|
| 11 |
+
Topp float64 `json:"top_p,omitempty"`
|
| 12 |
+
TopK int `json:"top_k,omitempty"`
|
| 13 |
+
Stop any `json:"stop,omitempty"`
|
| 14 |
+
Tools []dto.ToolCall `json:"tools,omitempty"`
|
| 15 |
+
ResponseFormat any `json:"response_format,omitempty"`
|
| 16 |
+
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
| 17 |
+
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
| 18 |
+
Suffix any `json:"suffix,omitempty"`
|
| 19 |
+
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
|
| 20 |
+
Prompt any `json:"prompt,omitempty"`
|
| 21 |
}
|
| 22 |
|
| 23 |
type Options struct {
|
|
|
|
| 38 |
}
|
| 39 |
|
| 40 |
type OllamaEmbeddingResponse struct {
|
| 41 |
+
Error string `json:"error,omitempty"`
|
| 42 |
+
Model string `json:"model"`
|
| 43 |
Embedding [][]float64 `json:"embeddings,omitempty"`
|
| 44 |
}
|
relay/channel/ollama/relay-ollama.go
CHANGED
|
@@ -9,14 +9,36 @@ import (
|
|
| 9 |
"net/http"
|
| 10 |
"one-api/dto"
|
| 11 |
"one-api/service"
|
|
|
|
| 12 |
)
|
| 13 |
|
| 14 |
-
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
| 15 |
messages := make([]dto.Message, 0, len(request.Messages))
|
| 16 |
for _, message := range request.Messages {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
messages = append(messages, dto.Message{
|
| 18 |
-
Role:
|
| 19 |
-
Content:
|
|
|
|
|
|
|
| 20 |
})
|
| 21 |
}
|
| 22 |
str, ok := request.Stop.(string)
|
|
@@ -39,7 +61,10 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
|
| 39 |
ResponseFormat: request.ResponseFormat,
|
| 40 |
FrequencyPenalty: request.FrequencyPenalty,
|
| 41 |
PresencePenalty: request.PresencePenalty,
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
|
|
|
|
| 9 |
"net/http"
|
| 10 |
"one-api/dto"
|
| 11 |
"one-api/service"
|
| 12 |
+
"strings"
|
| 13 |
)
|
| 14 |
|
| 15 |
+
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
| 16 |
messages := make([]dto.Message, 0, len(request.Messages))
|
| 17 |
for _, message := range request.Messages {
|
| 18 |
+
if !message.IsStringContent() {
|
| 19 |
+
mediaMessages := message.ParseContent()
|
| 20 |
+
for j, mediaMessage := range mediaMessages {
|
| 21 |
+
if mediaMessage.Type == dto.ContentTypeImageURL {
|
| 22 |
+
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
| 23 |
+
// check if not base64
|
| 24 |
+
if strings.HasPrefix(imageUrl.Url, "http") {
|
| 25 |
+
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
|
| 26 |
+
if err != nil {
|
| 27 |
+
return nil, err
|
| 28 |
+
}
|
| 29 |
+
imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
|
| 30 |
+
}
|
| 31 |
+
mediaMessage.ImageUrl = imageUrl
|
| 32 |
+
mediaMessages[j] = mediaMessage
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
message.SetMediaContent(mediaMessages)
|
| 36 |
+
}
|
| 37 |
messages = append(messages, dto.Message{
|
| 38 |
+
Role: message.Role,
|
| 39 |
+
Content: message.Content,
|
| 40 |
+
ToolCalls: message.ToolCalls,
|
| 41 |
+
ToolCallId: message.ToolCallId,
|
| 42 |
})
|
| 43 |
}
|
| 44 |
str, ok := request.Stop.(string)
|
|
|
|
| 61 |
ResponseFormat: request.ResponseFormat,
|
| 62 |
FrequencyPenalty: request.FrequencyPenalty,
|
| 63 |
PresencePenalty: request.PresencePenalty,
|
| 64 |
+
Prompt: request.Prompt,
|
| 65 |
+
StreamOptions: request.StreamOptions,
|
| 66 |
+
Suffix: request.Suffix,
|
| 67 |
+
}, nil
|
| 68 |
}
|
| 69 |
|
| 70 |
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
|
relay/channel/openai/adaptor.go
CHANGED
|
@@ -119,7 +119,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
|
| 119 |
request.MaxCompletionTokens = request.MaxTokens
|
| 120 |
request.MaxTokens = 0
|
| 121 |
}
|
| 122 |
-
if strings.HasPrefix(request.Model, "o3") {
|
| 123 |
request.Temperature = nil
|
| 124 |
}
|
| 125 |
if strings.HasSuffix(request.Model, "-high") {
|
|
|
|
| 119 |
request.MaxCompletionTokens = request.MaxTokens
|
| 120 |
request.MaxTokens = 0
|
| 121 |
}
|
| 122 |
+
if strings.HasPrefix(request.Model, "o3") || strings.HasPrefix(request.Model, "o1") {
|
| 123 |
request.Temperature = nil
|
| 124 |
}
|
| 125 |
if strings.HasSuffix(request.Model, "-high") {
|
relay/channel/openai/relay-openai.go
CHANGED
|
@@ -87,6 +87,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
| 87 |
info.SetFirstResponseTime()
|
| 88 |
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
|
| 89 |
data := scanner.Text()
|
|
|
|
|
|
|
|
|
|
| 90 |
if len(data) < 6 { // ignore blank line or wrong format
|
| 91 |
continue
|
| 92 |
}
|
|
@@ -162,6 +165,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
| 162 |
//}
|
| 163 |
for _, choice := range streamResponse.Choices {
|
| 164 |
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
|
|
|
| 165 |
if choice.Delta.ToolCalls != nil {
|
| 166 |
if len(choice.Delta.ToolCalls) > toolCount {
|
| 167 |
toolCount = len(choice.Delta.ToolCalls)
|
|
@@ -182,6 +186,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
| 182 |
//}
|
| 183 |
for _, choice := range streamResponse.Choices {
|
| 184 |
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
|
|
|
| 185 |
if choice.Delta.ToolCalls != nil {
|
| 186 |
if len(choice.Delta.ToolCalls) > toolCount {
|
| 187 |
toolCount = len(choice.Delta.ToolCalls)
|
|
@@ -273,7 +278,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
|
| 273 |
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
| 274 |
completionTokens := 0
|
| 275 |
for _, choice := range simpleResponse.Choices {
|
| 276 |
-
ctkm, _ := service.CountTextToken(
|
| 277 |
completionTokens += ctkm
|
| 278 |
}
|
| 279 |
simpleResponse.Usage = dto.Usage{
|
|
|
|
| 87 |
info.SetFirstResponseTime()
|
| 88 |
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
|
| 89 |
data := scanner.Text()
|
| 90 |
+
if common.DebugEnabled {
|
| 91 |
+
println(data)
|
| 92 |
+
}
|
| 93 |
if len(data) < 6 { // ignore blank line or wrong format
|
| 94 |
continue
|
| 95 |
}
|
|
|
|
| 165 |
//}
|
| 166 |
for _, choice := range streamResponse.Choices {
|
| 167 |
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
| 168 |
+
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
| 169 |
if choice.Delta.ToolCalls != nil {
|
| 170 |
if len(choice.Delta.ToolCalls) > toolCount {
|
| 171 |
toolCount = len(choice.Delta.ToolCalls)
|
|
|
|
| 186 |
//}
|
| 187 |
for _, choice := range streamResponse.Choices {
|
| 188 |
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
| 189 |
+
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
| 190 |
if choice.Delta.ToolCalls != nil {
|
| 191 |
if len(choice.Delta.ToolCalls) > toolCount {
|
| 192 |
toolCount = len(choice.Delta.ToolCalls)
|
|
|
|
| 278 |
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
| 279 |
completionTokens := 0
|
| 280 |
for _, choice := range simpleResponse.Choices {
|
| 281 |
+
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent, model)
|
| 282 |
completionTokens += ctkm
|
| 283 |
}
|
| 284 |
simpleResponse.Usage = dto.Usage{
|
relay/channel/siliconflow/adaptor.go
CHANGED
|
@@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
| 36 |
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
|
| 37 |
} else if info.RelayMode == constant.RelayModeChatCompletions {
|
| 38 |
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
|
|
|
|
|
|
| 39 |
}
|
| 40 |
return "", errors.New("invalid relay mode")
|
| 41 |
}
|
|
@@ -72,6 +74,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|
| 72 |
} else {
|
| 73 |
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
| 74 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
case constant.RelayModeEmbeddings:
|
| 76 |
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
| 77 |
}
|
|
|
|
| 36 |
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
|
| 37 |
} else if info.RelayMode == constant.RelayModeChatCompletions {
|
| 38 |
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
| 39 |
+
} else if info.RelayMode == constant.RelayModeCompletions {
|
| 40 |
+
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
|
| 41 |
}
|
| 42 |
return "", errors.New("invalid relay mode")
|
| 43 |
}
|
|
|
|
| 74 |
} else {
|
| 75 |
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
| 76 |
}
|
| 77 |
+
case constant.RelayModeCompletions:
|
| 78 |
+
if info.IsStream {
|
| 79 |
+
err, usage = openai.OaiStreamHandler(c, resp, info)
|
| 80 |
+
} else {
|
| 81 |
+
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
| 82 |
+
}
|
| 83 |
case constant.RelayModeEmbeddings:
|
| 84 |
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
| 85 |
}
|
relay/channel/zhipu_4v/relay-zhipu_v4.go
CHANGED
|
@@ -90,8 +90,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
|
|
| 90 |
mediaMessages[j] = mediaMessage
|
| 91 |
}
|
| 92 |
}
|
| 93 |
-
|
| 94 |
-
message.Content = messageRaw
|
| 95 |
}
|
| 96 |
messages = append(messages, dto.Message{
|
| 97 |
Role: message.Role,
|
|
|
|
| 90 |
mediaMessages[j] = mediaMessage
|
| 91 |
}
|
| 92 |
}
|
| 93 |
+
message.SetMediaContent(mediaMessages)
|
|
|
|
| 94 |
}
|
| 95 |
messages = append(messages, dto.Message{
|
| 96 |
Role: message.Role,
|
relay/common/relay_info.go
CHANGED
|
@@ -13,24 +13,24 @@ import (
|
|
| 13 |
)
|
| 14 |
|
| 15 |
type RelayInfo struct {
|
| 16 |
-
ChannelType
|
| 17 |
-
ChannelId
|
| 18 |
-
TokenId
|
| 19 |
-
TokenKey
|
| 20 |
-
UserId
|
| 21 |
-
Group
|
| 22 |
-
TokenUnlimited
|
| 23 |
-
StartTime
|
| 24 |
-
FirstResponseTime
|
| 25 |
-
setFirstResponse
|
| 26 |
-
ApiType
|
| 27 |
-
IsStream
|
| 28 |
-
IsPlayground
|
| 29 |
-
UsePrice
|
| 30 |
-
RelayMode
|
| 31 |
-
UpstreamModelName
|
| 32 |
-
OriginModelName
|
| 33 |
-
RecodeModelName string
|
| 34 |
RequestURLPath string
|
| 35 |
ApiVersion string
|
| 36 |
PromptTokens int
|
|
@@ -39,6 +39,7 @@ type RelayInfo struct {
|
|
| 39 |
BaseUrl string
|
| 40 |
SupportStreamOptions bool
|
| 41 |
ShouldIncludeUsage bool
|
|
|
|
| 42 |
ClientWs *websocket.Conn
|
| 43 |
TargetWs *websocket.Conn
|
| 44 |
InputAudioFormat string
|
|
@@ -50,6 +51,18 @@ type RelayInfo struct {
|
|
| 50 |
ChannelSetting map[string]interface{}
|
| 51 |
}
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
| 54 |
info := GenRelayInfo(c)
|
| 55 |
info.ClientWs = ws
|
|
@@ -89,12 +102,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|
| 89 |
FirstResponseTime: startTime.Add(-time.Second),
|
| 90 |
OriginModelName: c.GetString("original_model"),
|
| 91 |
UpstreamModelName: c.GetString("original_model"),
|
| 92 |
-
RecodeModelName: c.GetString("
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
| 98 |
}
|
| 99 |
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
| 100 |
info.IsPlayground = true
|
|
@@ -110,9 +124,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|
| 110 |
if info.ChannelType == common.ChannelTypeVertexAi {
|
| 111 |
info.ApiVersion = c.GetString("region")
|
| 112 |
}
|
| 113 |
-
if info.ChannelType
|
| 114 |
-
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
|
| 115 |
-
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure {
|
| 116 |
info.SupportStreamOptions = true
|
| 117 |
}
|
| 118 |
return info
|
|
|
|
| 13 |
)
|
| 14 |
|
| 15 |
type RelayInfo struct {
|
| 16 |
+
ChannelType int
|
| 17 |
+
ChannelId int
|
| 18 |
+
TokenId int
|
| 19 |
+
TokenKey string
|
| 20 |
+
UserId int
|
| 21 |
+
Group string
|
| 22 |
+
TokenUnlimited bool
|
| 23 |
+
StartTime time.Time
|
| 24 |
+
FirstResponseTime time.Time
|
| 25 |
+
setFirstResponse bool
|
| 26 |
+
ApiType int
|
| 27 |
+
IsStream bool
|
| 28 |
+
IsPlayground bool
|
| 29 |
+
UsePrice bool
|
| 30 |
+
RelayMode int
|
| 31 |
+
UpstreamModelName string
|
| 32 |
+
OriginModelName string
|
| 33 |
+
//RecodeModelName string
|
| 34 |
RequestURLPath string
|
| 35 |
ApiVersion string
|
| 36 |
PromptTokens int
|
|
|
|
| 39 |
BaseUrl string
|
| 40 |
SupportStreamOptions bool
|
| 41 |
ShouldIncludeUsage bool
|
| 42 |
+
IsModelMapped bool
|
| 43 |
ClientWs *websocket.Conn
|
| 44 |
TargetWs *websocket.Conn
|
| 45 |
InputAudioFormat string
|
|
|
|
| 51 |
ChannelSetting map[string]interface{}
|
| 52 |
}
|
| 53 |
|
| 54 |
+
// 定义支持流式选项的通道类型
|
| 55 |
+
var streamSupportedChannels = map[int]bool{
|
| 56 |
+
common.ChannelTypeOpenAI: true,
|
| 57 |
+
common.ChannelTypeAnthropic: true,
|
| 58 |
+
common.ChannelTypeAws: true,
|
| 59 |
+
common.ChannelTypeGemini: true,
|
| 60 |
+
common.ChannelCloudflare: true,
|
| 61 |
+
common.ChannelTypeAzure: true,
|
| 62 |
+
common.ChannelTypeVolcEngine: true,
|
| 63 |
+
common.ChannelTypeOllama: true,
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
| 67 |
info := GenRelayInfo(c)
|
| 68 |
info.ClientWs = ws
|
|
|
|
| 102 |
FirstResponseTime: startTime.Add(-time.Second),
|
| 103 |
OriginModelName: c.GetString("original_model"),
|
| 104 |
UpstreamModelName: c.GetString("original_model"),
|
| 105 |
+
//RecodeModelName: c.GetString("original_model"),
|
| 106 |
+
IsModelMapped: false,
|
| 107 |
+
ApiType: apiType,
|
| 108 |
+
ApiVersion: c.GetString("api_version"),
|
| 109 |
+
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
| 110 |
+
Organization: c.GetString("channel_organization"),
|
| 111 |
+
ChannelSetting: channelSetting,
|
| 112 |
}
|
| 113 |
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
| 114 |
info.IsPlayground = true
|
|
|
|
| 124 |
if info.ChannelType == common.ChannelTypeVertexAi {
|
| 125 |
info.ApiVersion = c.GetString("region")
|
| 126 |
}
|
| 127 |
+
if streamSupportedChannels[info.ChannelType] {
|
|
|
|
|
|
|
| 128 |
info.SupportStreamOptions = true
|
| 129 |
}
|
| 130 |
return info
|
relay/helper/model_mapped.go
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package helper
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"fmt"
|
| 6 |
+
"github.com/gin-gonic/gin"
|
| 7 |
+
"one-api/relay/common"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
|
| 11 |
+
// map model name
|
| 12 |
+
modelMapping := c.GetString("model_mapping")
|
| 13 |
+
if modelMapping != "" && modelMapping != "{}" {
|
| 14 |
+
modelMap := make(map[string]string)
|
| 15 |
+
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
| 16 |
+
if err != nil {
|
| 17 |
+
return fmt.Errorf("unmarshal_model_mapping_failed")
|
| 18 |
+
}
|
| 19 |
+
if modelMap[info.OriginModelName] != "" {
|
| 20 |
+
info.UpstreamModelName = modelMap[info.OriginModelName]
|
| 21 |
+
info.IsModelMapped = true
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
return nil
|
| 25 |
+
}
|
relay/helper/price.go
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package helper
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"github.com/gin-gonic/gin"
|
| 5 |
+
"one-api/common"
|
| 6 |
+
relaycommon "one-api/relay/common"
|
| 7 |
+
"one-api/setting"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
type PriceData struct {
|
| 11 |
+
ModelPrice float64
|
| 12 |
+
ModelRatio float64
|
| 13 |
+
GroupRatio float64
|
| 14 |
+
UsePrice bool
|
| 15 |
+
ShouldPreConsumedQuota int
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
|
| 19 |
+
modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false)
|
| 20 |
+
groupRatio := setting.GetGroupRatio(info.Group)
|
| 21 |
+
var preConsumedQuota int
|
| 22 |
+
var modelRatio float64
|
| 23 |
+
if !usePrice {
|
| 24 |
+
preConsumedTokens := common.PreConsumedQuota
|
| 25 |
+
if maxTokens != 0 {
|
| 26 |
+
preConsumedTokens = promptTokens + maxTokens
|
| 27 |
+
}
|
| 28 |
+
modelRatio = common.GetModelRatio(info.OriginModelName)
|
| 29 |
+
ratio := modelRatio * groupRatio
|
| 30 |
+
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
| 31 |
+
} else {
|
| 32 |
+
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
| 33 |
+
}
|
| 34 |
+
return PriceData{
|
| 35 |
+
ModelPrice: modelPrice,
|
| 36 |
+
ModelRatio: modelRatio,
|
| 37 |
+
GroupRatio: groupRatio,
|
| 38 |
+
UsePrice: usePrice,
|
| 39 |
+
ShouldPreConsumedQuota: preConsumedQuota,
|
| 40 |
+
}
|
| 41 |
+
}
|
relay/relay-audio.go
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
package relay
|
| 2 |
|
| 3 |
import (
|
| 4 |
-
"encoding/json"
|
| 5 |
"errors"
|
| 6 |
"fmt"
|
| 7 |
"github.com/gin-gonic/gin"
|
|
@@ -11,8 +10,10 @@ import (
|
|
| 11 |
"one-api/model"
|
| 12 |
relaycommon "one-api/relay/common"
|
| 13 |
relayconstant "one-api/relay/constant"
|
|
|
|
| 14 |
"one-api/service"
|
| 15 |
"one-api/setting"
|
|
|
|
| 16 |
)
|
| 17 |
|
| 18 |
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
|
@@ -27,8 +28,9 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
|
| 27 |
return nil, errors.New("model is required")
|
| 28 |
}
|
| 29 |
if setting.ShouldCheckPromptSensitive() {
|
| 30 |
-
err := service.CheckSensitiveInput(audioRequest.Input)
|
| 31 |
if err != nil {
|
|
|
|
| 32 |
return nil, err
|
| 33 |
}
|
| 34 |
}
|
|
@@ -73,15 +75,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|
| 73 |
relayInfo.PromptTokens = promptTokens
|
| 74 |
}
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
ratio := modelRatio * groupRatio
|
| 79 |
-
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
| 80 |
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
| 81 |
if err != nil {
|
| 82 |
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
| 83 |
}
|
| 84 |
-
preConsumedQuota, userQuota, openaiErr
|
| 85 |
if openaiErr != nil {
|
| 86 |
return openaiErr
|
| 87 |
}
|
|
@@ -91,19 +91,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|
| 91 |
}
|
| 92 |
}()
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
modelMap := make(map[string]string)
|
| 98 |
-
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
| 99 |
-
if err != nil {
|
| 100 |
-
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
| 101 |
-
}
|
| 102 |
-
if modelMap[audioRequest.Model] != "" {
|
| 103 |
-
audioRequest.Model = modelMap[audioRequest.Model]
|
| 104 |
-
}
|
| 105 |
}
|
| 106 |
-
|
|
|
|
| 107 |
|
| 108 |
adaptor := GetAdaptor(relayInfo.ApiType)
|
| 109 |
if adaptor == nil {
|
|
@@ -140,7 +133,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|
| 140 |
return openaiErr
|
| 141 |
}
|
| 142 |
|
| 143 |
-
postConsumeQuota(c, relayInfo,
|
| 144 |
|
| 145 |
return nil
|
| 146 |
}
|
|
|
|
| 1 |
package relay
|
| 2 |
|
| 3 |
import (
|
|
|
|
| 4 |
"errors"
|
| 5 |
"fmt"
|
| 6 |
"github.com/gin-gonic/gin"
|
|
|
|
| 10 |
"one-api/model"
|
| 11 |
relaycommon "one-api/relay/common"
|
| 12 |
relayconstant "one-api/relay/constant"
|
| 13 |
+
"one-api/relay/helper"
|
| 14 |
"one-api/service"
|
| 15 |
"one-api/setting"
|
| 16 |
+
"strings"
|
| 17 |
)
|
| 18 |
|
| 19 |
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
|
|
|
| 28 |
return nil, errors.New("model is required")
|
| 29 |
}
|
| 30 |
if setting.ShouldCheckPromptSensitive() {
|
| 31 |
+
words, err := service.CheckSensitiveInput(audioRequest.Input)
|
| 32 |
if err != nil {
|
| 33 |
+
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
|
| 34 |
return nil, err
|
| 35 |
}
|
| 36 |
}
|
|
|
|
| 75 |
relayInfo.PromptTokens = promptTokens
|
| 76 |
}
|
| 77 |
|
| 78 |
+
priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
| 79 |
+
|
|
|
|
|
|
|
| 80 |
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
| 81 |
if err != nil {
|
| 82 |
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
| 83 |
}
|
| 84 |
+
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
| 85 |
if openaiErr != nil {
|
| 86 |
return openaiErr
|
| 87 |
}
|
|
|
|
| 91 |
}
|
| 92 |
}()
|
| 93 |
|
| 94 |
+
err = helper.ModelMappedHelper(c, relayInfo)
|
| 95 |
+
if err != nil {
|
| 96 |
+
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
}
|
| 98 |
+
|
| 99 |
+
audioRequest.Model = relayInfo.UpstreamModelName
|
| 100 |
|
| 101 |
adaptor := GetAdaptor(relayInfo.ApiType)
|
| 102 |
if adaptor == nil {
|
|
|
|
| 133 |
return openaiErr
|
| 134 |
}
|
| 135 |
|
| 136 |
+
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
| 137 |
|
| 138 |
return nil
|
| 139 |
}
|
relay/relay-image.go
CHANGED
|
@@ -12,6 +12,7 @@ import (
|
|
| 12 |
"one-api/dto"
|
| 13 |
"one-api/model"
|
| 14 |
relaycommon "one-api/relay/common"
|
|
|
|
| 15 |
"one-api/service"
|
| 16 |
"one-api/setting"
|
| 17 |
"strings"
|
|
@@ -60,15 +61,16 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
|
| 60 |
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
| 61 |
//}
|
| 62 |
if setting.ShouldCheckPromptSensitive() {
|
| 63 |
-
err := service.CheckSensitiveInput(imageRequest.Prompt)
|
| 64 |
if err != nil {
|
|
|
|
| 65 |
return nil, err
|
| 66 |
}
|
| 67 |
}
|
| 68 |
return imageRequest, nil
|
| 69 |
}
|
| 70 |
|
| 71 |
-
func ImageHelper(c *gin.Context
|
| 72 |
relayInfo := relaycommon.GenRelayInfo(c)
|
| 73 |
|
| 74 |
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
|
@@ -77,29 +79,20 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|
| 77 |
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
|
| 78 |
}
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
modelMap := make(map[string]string)
|
| 84 |
-
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
| 85 |
-
if err != nil {
|
| 86 |
-
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
| 87 |
-
}
|
| 88 |
-
if modelMap[imageRequest.Model] != "" {
|
| 89 |
-
imageRequest.Model = modelMap[imageRequest.Model]
|
| 90 |
-
}
|
| 91 |
}
|
| 92 |
-
relayInfo.UpstreamModelName = imageRequest.Model
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
| 97 |
// modelRatio 16 = modelPrice $0.04
|
| 98 |
// per 1 modelRatio = $0.04 / 16
|
| 99 |
-
|
| 100 |
}
|
| 101 |
|
| 102 |
-
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
| 103 |
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
| 104 |
|
| 105 |
sizeRatio := 1.0
|
|
@@ -122,11 +115,11 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|
| 122 |
}
|
| 123 |
}
|
| 124 |
|
| 125 |
-
|
| 126 |
-
quota := int(
|
| 127 |
|
| 128 |
if userQuota-quota < 0 {
|
| 129 |
-
return service.OpenAIErrorWrapperLocal(
|
| 130 |
}
|
| 131 |
|
| 132 |
adaptor := GetAdaptor(relayInfo.ApiType)
|
|
@@ -184,7 +177,6 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|
| 184 |
}
|
| 185 |
|
| 186 |
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
|
| 187 |
-
postConsumeQuota(c, relayInfo,
|
| 188 |
-
|
| 189 |
return nil
|
| 190 |
}
|
|
|
|
| 12 |
"one-api/dto"
|
| 13 |
"one-api/model"
|
| 14 |
relaycommon "one-api/relay/common"
|
| 15 |
+
"one-api/relay/helper"
|
| 16 |
"one-api/service"
|
| 17 |
"one-api/setting"
|
| 18 |
"strings"
|
|
|
|
| 61 |
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
| 62 |
//}
|
| 63 |
if setting.ShouldCheckPromptSensitive() {
|
| 64 |
+
words, err := service.CheckSensitiveInput(imageRequest.Prompt)
|
| 65 |
if err != nil {
|
| 66 |
+
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
|
| 67 |
return nil, err
|
| 68 |
}
|
| 69 |
}
|
| 70 |
return imageRequest, nil
|
| 71 |
}
|
| 72 |
|
| 73 |
+
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
| 74 |
relayInfo := relaycommon.GenRelayInfo(c)
|
| 75 |
|
| 76 |
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
|
|
|
| 79 |
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
|
| 80 |
}
|
| 81 |
|
| 82 |
+
err = helper.ModelMappedHelper(c, relayInfo)
|
| 83 |
+
if err != nil {
|
| 84 |
+
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
}
|
|
|
|
| 86 |
|
| 87 |
+
imageRequest.Model = relayInfo.UpstreamModelName
|
| 88 |
+
|
| 89 |
+
priceData := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
| 90 |
+
if !priceData.UsePrice {
|
| 91 |
// modelRatio 16 = modelPrice $0.04
|
| 92 |
// per 1 modelRatio = $0.04 / 16
|
| 93 |
+
priceData.ModelPrice = 0.0025 * priceData.ModelRatio
|
| 94 |
}
|
| 95 |
|
|
|
|
| 96 |
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
| 97 |
|
| 98 |
sizeRatio := 1.0
|
|
|
|
| 115 |
}
|
| 116 |
}
|
| 117 |
|
| 118 |
+
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
|
| 119 |
+
quota := int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
|
| 120 |
|
| 121 |
if userQuota-quota < 0 {
|
| 122 |
+
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
|
| 123 |
}
|
| 124 |
|
| 125 |
adaptor := GetAdaptor(relayInfo.ApiType)
|
|
|
|
| 177 |
}
|
| 178 |
|
| 179 |
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
|
| 180 |
+
postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
|
|
|
|
| 181 |
return nil
|
| 182 |
}
|
relay/relay-mj.go
CHANGED
|
@@ -194,7 +194,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|
| 194 |
}
|
| 195 |
defer func(ctx context.Context) {
|
| 196 |
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
| 197 |
-
err :=
|
| 198 |
if err != nil {
|
| 199 |
common.SysError("error consuming token remain quota: " + err.Error())
|
| 200 |
}
|
|
@@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|
| 500 |
|
| 501 |
defer func(ctx context.Context) {
|
| 502 |
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
| 503 |
-
err :=
|
| 504 |
if err != nil {
|
| 505 |
common.SysError("error consuming token remain quota: " + err.Error())
|
| 506 |
}
|
|
|
|
| 194 |
}
|
| 195 |
defer func(ctx context.Context) {
|
| 196 |
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
| 197 |
+
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
| 198 |
if err != nil {
|
| 199 |
common.SysError("error consuming token remain quota: " + err.Error())
|
| 200 |
}
|
|
|
|
| 500 |
|
| 501 |
defer func(ctx context.Context) {
|
| 502 |
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
| 503 |
+
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
| 504 |
if err != nil {
|
| 505 |
common.SysError("error consuming token remain quota: " + err.Error())
|
| 506 |
}
|
relay/relay-text.go
CHANGED
|
@@ -5,6 +5,7 @@ import (
|
|
| 5 |
"encoding/json"
|
| 6 |
"errors"
|
| 7 |
"fmt"
|
|
|
|
| 8 |
"io"
|
| 9 |
"math"
|
| 10 |
"net/http"
|
|
@@ -14,6 +15,7 @@ import (
|
|
| 14 |
"one-api/model"
|
| 15 |
relaycommon "one-api/relay/common"
|
| 16 |
relayconstant "one-api/relay/constant"
|
|
|
|
| 17 |
"one-api/service"
|
| 18 |
"one-api/setting"
|
| 19 |
"strings"
|
|
@@ -75,40 +77,21 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|
| 75 |
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
| 76 |
}
|
| 77 |
|
| 78 |
-
// map model name
|
| 79 |
-
//isModelMapped := false
|
| 80 |
-
modelMapping := c.GetString("model_mapping")
|
| 81 |
-
//isModelMapped := false
|
| 82 |
-
if modelMapping != "" && modelMapping != "{}" {
|
| 83 |
-
modelMap := make(map[string]string)
|
| 84 |
-
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
| 85 |
-
if err != nil {
|
| 86 |
-
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
| 87 |
-
}
|
| 88 |
-
if modelMap[textRequest.Model] != "" {
|
| 89 |
-
//isModelMapped = true
|
| 90 |
-
textRequest.Model = modelMap[textRequest.Model]
|
| 91 |
-
// set upstream model name
|
| 92 |
-
//isModelMapped = true
|
| 93 |
-
}
|
| 94 |
-
}
|
| 95 |
-
relayInfo.UpstreamModelName = textRequest.Model
|
| 96 |
-
relayInfo.RecodeModelName = textRequest.Model
|
| 97 |
-
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
|
| 98 |
-
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
| 99 |
-
|
| 100 |
-
var preConsumedQuota int
|
| 101 |
-
var ratio float64
|
| 102 |
-
var modelRatio float64
|
| 103 |
-
//err := service.SensitiveWordsCheck(textRequest)
|
| 104 |
-
|
| 105 |
if setting.ShouldCheckPromptSensitive() {
|
| 106 |
-
err
|
| 107 |
if err != nil {
|
|
|
|
| 108 |
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
| 109 |
}
|
| 110 |
}
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
| 113 |
var promptTokens int
|
| 114 |
if value, exists := c.Get("prompt_tokens"); exists {
|
|
@@ -123,20 +106,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|
| 123 |
c.Set("prompt_tokens", promptTokens)
|
| 124 |
}
|
| 125 |
|
| 126 |
-
|
| 127 |
-
preConsumedTokens := common.PreConsumedQuota
|
| 128 |
-
if textRequest.MaxTokens != 0 {
|
| 129 |
-
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
|
| 130 |
-
}
|
| 131 |
-
modelRatio = common.GetModelRatio(textRequest.Model)
|
| 132 |
-
ratio = modelRatio * groupRatio
|
| 133 |
-
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
| 134 |
-
} else {
|
| 135 |
-
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
| 136 |
-
}
|
| 137 |
|
| 138 |
// pre-consume quota 预消耗配额
|
| 139 |
-
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c,
|
| 140 |
if openaiErr != nil {
|
| 141 |
return openaiErr
|
| 142 |
}
|
|
@@ -219,10 +192,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|
| 219 |
return openaiErr
|
| 220 |
}
|
| 221 |
|
| 222 |
-
if strings.HasPrefix(relayInfo.
|
| 223 |
-
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota,
|
| 224 |
} else {
|
| 225 |
-
postConsumeQuota(c, relayInfo,
|
| 226 |
}
|
| 227 |
return nil
|
| 228 |
}
|
|
@@ -247,19 +220,20 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
|
| 247 |
return promptTokens, err
|
| 248 |
}
|
| 249 |
|
| 250 |
-
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
|
| 251 |
var err error
|
|
|
|
| 252 |
switch info.RelayMode {
|
| 253 |
case relayconstant.RelayModeChatCompletions:
|
| 254 |
-
err = service.CheckSensitiveMessages(textRequest.Messages)
|
| 255 |
case relayconstant.RelayModeCompletions:
|
| 256 |
-
err = service.CheckSensitiveInput(textRequest.Prompt)
|
| 257 |
case relayconstant.RelayModeModerations:
|
| 258 |
-
err = service.CheckSensitiveInput(textRequest.Input)
|
| 259 |
case relayconstant.RelayModeEmbeddings:
|
| 260 |
-
err = service.CheckSensitiveInput(textRequest.Input)
|
| 261 |
}
|
| 262 |
-
return err
|
| 263 |
}
|
| 264 |
|
| 265 |
// 预扣费并返回用户剩余配额
|
|
@@ -272,7 +246,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|
| 272 |
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
| 273 |
}
|
| 274 |
if userQuota-preConsumedQuota < 0 {
|
| 275 |
-
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %
|
| 276 |
}
|
| 277 |
if userQuota > 100*preConsumedQuota {
|
| 278 |
// 用户额度充足,判断令牌额度是否充足
|
|
@@ -282,18 +256,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|
| 282 |
if tokenQuota > 100*preConsumedQuota {
|
| 283 |
// 令牌额度充足,信任令牌
|
| 284 |
preConsumedQuota = 0
|
| 285 |
-
common.LogInfo(c, fmt.Sprintf("user %d quota %
|
| 286 |
}
|
| 287 |
} else {
|
| 288 |
// in this case, we do not pre-consume quota
|
| 289 |
// because the user has enough quota
|
| 290 |
preConsumedQuota = 0
|
| 291 |
-
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %
|
| 292 |
}
|
| 293 |
}
|
| 294 |
|
| 295 |
if preConsumedQuota > 0 {
|
| 296 |
-
err =
|
| 297 |
if err != nil {
|
| 298 |
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
| 299 |
}
|
|
@@ -307,20 +281,19 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|
| 307 |
|
| 308 |
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
|
| 309 |
if preConsumedQuota != 0 {
|
| 310 |
-
|
| 311 |
relayInfoCopy := *relayInfo
|
| 312 |
|
| 313 |
-
err :=
|
| 314 |
if err != nil {
|
| 315 |
common.SysError("error return pre-consumed quota: " + err.Error())
|
| 316 |
}
|
| 317 |
-
}
|
| 318 |
}
|
| 319 |
}
|
| 320 |
|
| 321 |
-
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
| 322 |
-
usage *dto.Usage,
|
| 323 |
-
modelPrice float64, usePrice bool, extraContent string) {
|
| 324 |
if usage == nil {
|
| 325 |
usage = &dto.Usage{
|
| 326 |
PromptTokens: relayInfo.PromptTokens,
|
|
@@ -332,12 +305,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
|
| 332 |
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
| 333 |
promptTokens := usage.PromptTokens
|
| 334 |
completionTokens := usage.CompletionTokens
|
|
|
|
| 335 |
|
| 336 |
tokenName := ctx.GetString("token_name")
|
| 337 |
completionRatio := common.GetCompletionRatio(modelName)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
|
| 339 |
quota := 0
|
| 340 |
-
if !
|
| 341 |
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
|
| 342 |
quota = int(math.Round(float64(quota) * ratio))
|
| 343 |
if ratio != 0 && quota <= 0 {
|
|
@@ -368,7 +347,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
|
| 368 |
//}
|
| 369 |
quotaDelta := quota - preConsumedQuota
|
| 370 |
if quotaDelta != 0 {
|
| 371 |
-
err :=
|
| 372 |
if err != nil {
|
| 373 |
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
| 374 |
}
|
|
|
|
| 5 |
"encoding/json"
|
| 6 |
"errors"
|
| 7 |
"fmt"
|
| 8 |
+
"github.com/bytedance/gopkg/util/gopool"
|
| 9 |
"io"
|
| 10 |
"math"
|
| 11 |
"net/http"
|
|
|
|
| 15 |
"one-api/model"
|
| 16 |
relaycommon "one-api/relay/common"
|
| 17 |
relayconstant "one-api/relay/constant"
|
| 18 |
+
"one-api/relay/helper"
|
| 19 |
"one-api/service"
|
| 20 |
"one-api/setting"
|
| 21 |
"strings"
|
|
|
|
| 77 |
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
| 78 |
}
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
if setting.ShouldCheckPromptSensitive() {
|
| 81 |
+
words, err := checkRequestSensitive(textRequest, relayInfo)
|
| 82 |
if err != nil {
|
| 83 |
+
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
| 84 |
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
| 85 |
}
|
| 86 |
}
|
| 87 |
|
| 88 |
+
err = helper.ModelMappedHelper(c, relayInfo)
|
| 89 |
+
if err != nil {
|
| 90 |
+
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
textRequest.Model = relayInfo.UpstreamModelName
|
| 94 |
+
|
| 95 |
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
| 96 |
var promptTokens int
|
| 97 |
if value, exists := c.Get("prompt_tokens"); exists {
|
|
|
|
| 106 |
c.Set("prompt_tokens", promptTokens)
|
| 107 |
}
|
| 108 |
|
| 109 |
+
priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
// pre-consume quota 预消耗配额
|
| 112 |
+
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
| 113 |
if openaiErr != nil {
|
| 114 |
return openaiErr
|
| 115 |
}
|
|
|
|
| 192 |
return openaiErr
|
| 193 |
}
|
| 194 |
|
| 195 |
+
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
|
| 196 |
+
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
| 197 |
} else {
|
| 198 |
+
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
| 199 |
}
|
| 200 |
return nil
|
| 201 |
}
|
|
|
|
| 220 |
return promptTokens, err
|
| 221 |
}
|
| 222 |
|
| 223 |
+
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
|
| 224 |
var err error
|
| 225 |
+
var words []string
|
| 226 |
switch info.RelayMode {
|
| 227 |
case relayconstant.RelayModeChatCompletions:
|
| 228 |
+
words, err = service.CheckSensitiveMessages(textRequest.Messages)
|
| 229 |
case relayconstant.RelayModeCompletions:
|
| 230 |
+
words, err = service.CheckSensitiveInput(textRequest.Prompt)
|
| 231 |
case relayconstant.RelayModeModerations:
|
| 232 |
+
words, err = service.CheckSensitiveInput(textRequest.Input)
|
| 233 |
case relayconstant.RelayModeEmbeddings:
|
| 234 |
+
words, err = service.CheckSensitiveInput(textRequest.Input)
|
| 235 |
}
|
| 236 |
+
return words, err
|
| 237 |
}
|
| 238 |
|
| 239 |
// 预扣费并返回用户剩余配额
|
|
|
|
| 246 |
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
| 247 |
}
|
| 248 |
if userQuota-preConsumedQuota < 0 {
|
| 249 |
+
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
|
| 250 |
}
|
| 251 |
if userQuota > 100*preConsumedQuota {
|
| 252 |
// 用户额度充足,判断令牌额度是否充足
|
|
|
|
| 256 |
if tokenQuota > 100*preConsumedQuota {
|
| 257 |
// 令牌额度充足,信任令牌
|
| 258 |
preConsumedQuota = 0
|
| 259 |
+
common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
|
| 260 |
}
|
| 261 |
} else {
|
| 262 |
// in this case, we do not pre-consume quota
|
| 263 |
// because the user has enough quota
|
| 264 |
preConsumedQuota = 0
|
| 265 |
+
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
|
| 266 |
}
|
| 267 |
}
|
| 268 |
|
| 269 |
if preConsumedQuota > 0 {
|
| 270 |
+
err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
| 271 |
if err != nil {
|
| 272 |
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
| 273 |
}
|
|
|
|
| 281 |
|
| 282 |
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
|
| 283 |
if preConsumedQuota != 0 {
|
| 284 |
+
gopool.Go(func() {
|
| 285 |
relayInfoCopy := *relayInfo
|
| 286 |
|
| 287 |
+
err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
|
| 288 |
if err != nil {
|
| 289 |
common.SysError("error return pre-consumed quota: " + err.Error())
|
| 290 |
}
|
| 291 |
+
})
|
| 292 |
}
|
| 293 |
}
|
| 294 |
|
| 295 |
+
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
| 296 |
+
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
|
|
|
| 297 |
if usage == nil {
|
| 298 |
usage = &dto.Usage{
|
| 299 |
PromptTokens: relayInfo.PromptTokens,
|
|
|
|
| 305 |
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
| 306 |
promptTokens := usage.PromptTokens
|
| 307 |
completionTokens := usage.CompletionTokens
|
| 308 |
+
modelName := relayInfo.OriginModelName
|
| 309 |
|
| 310 |
tokenName := ctx.GetString("token_name")
|
| 311 |
completionRatio := common.GetCompletionRatio(modelName)
|
| 312 |
+
ratio := priceData.ModelRatio * priceData.GroupRatio
|
| 313 |
+
modelRatio := priceData.ModelRatio
|
| 314 |
+
groupRatio := priceData.GroupRatio
|
| 315 |
+
modelPrice := priceData.ModelPrice
|
| 316 |
+
usePrice := priceData.UsePrice
|
| 317 |
|
| 318 |
quota := 0
|
| 319 |
+
if !priceData.UsePrice {
|
| 320 |
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
|
| 321 |
quota = int(math.Round(float64(quota) * ratio))
|
| 322 |
if ratio != 0 && quota <= 0 {
|
|
|
|
| 347 |
//}
|
| 348 |
quotaDelta := quota - preConsumedQuota
|
| 349 |
if quotaDelta != 0 {
|
| 350 |
+
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
|
| 351 |
if err != nil {
|
| 352 |
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
| 353 |
}
|
relay/relay_embedding.go
CHANGED
|
@@ -10,8 +10,8 @@ import (
|
|
| 10 |
"one-api/dto"
|
| 11 |
relaycommon "one-api/relay/common"
|
| 12 |
relayconstant "one-api/relay/constant"
|
|
|
|
| 13 |
"one-api/service"
|
| 14 |
-
"one-api/setting"
|
| 15 |
)
|
| 16 |
|
| 17 |
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
|
@@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
|
| 47 |
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
| 48 |
}
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
if modelMapping != "" && modelMapping != "{}" {
|
| 54 |
-
modelMap := make(map[string]string)
|
| 55 |
-
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
| 56 |
-
if err != nil {
|
| 57 |
-
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
| 58 |
-
}
|
| 59 |
-
if modelMap[embeddingRequest.Model] != "" {
|
| 60 |
-
embeddingRequest.Model = modelMap[embeddingRequest.Model]
|
| 61 |
-
// set upstream model name
|
| 62 |
-
//isModelMapped = true
|
| 63 |
-
}
|
| 64 |
}
|
| 65 |
|
| 66 |
-
|
| 67 |
-
modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
|
| 68 |
-
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
| 69 |
-
|
| 70 |
-
var preConsumedQuota int
|
| 71 |
-
var ratio float64
|
| 72 |
-
var modelRatio float64
|
| 73 |
|
| 74 |
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
| 75 |
-
if !success {
|
| 76 |
-
preConsumedTokens := promptToken
|
| 77 |
-
modelRatio = common.GetModelRatio(embeddingRequest.Model)
|
| 78 |
-
ratio = modelRatio * groupRatio
|
| 79 |
-
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
| 80 |
-
} else {
|
| 81 |
-
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
| 82 |
-
}
|
| 83 |
relayInfo.PromptTokens = promptToken
|
| 84 |
|
|
|
|
|
|
|
| 85 |
// pre-consume quota 预消耗配额
|
| 86 |
-
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c,
|
| 87 |
if openaiErr != nil {
|
| 88 |
return openaiErr
|
| 89 |
}
|
|
@@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
|
| 132 |
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
| 133 |
return openaiErr
|
| 134 |
}
|
| 135 |
-
postConsumeQuota(c, relayInfo,
|
| 136 |
return nil
|
| 137 |
}
|
|
|
|
| 10 |
"one-api/dto"
|
| 11 |
relaycommon "one-api/relay/common"
|
| 12 |
relayconstant "one-api/relay/constant"
|
| 13 |
+
"one-api/relay/helper"
|
| 14 |
"one-api/service"
|
|
|
|
| 15 |
)
|
| 16 |
|
| 17 |
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
|
|
|
| 47 |
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
| 48 |
}
|
| 49 |
|
| 50 |
+
err = helper.ModelMappedHelper(c, relayInfo)
|
| 51 |
+
if err != nil {
|
| 52 |
+
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
}
|
| 54 |
|
| 55 |
+
embeddingRequest.Model = relayInfo.UpstreamModelName
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
relayInfo.PromptTokens = promptToken
|
| 59 |
|
| 60 |
+
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
| 61 |
+
|
| 62 |
// pre-consume quota 预消耗配额
|
| 63 |
+
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
| 64 |
if openaiErr != nil {
|
| 65 |
return openaiErr
|
| 66 |
}
|
|
|
|
| 109 |
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
| 110 |
return openaiErr
|
| 111 |
}
|
| 112 |
+
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
| 113 |
return nil
|
| 114 |
}
|
relay/relay_rerank.go
CHANGED
|
@@ -9,8 +9,8 @@ import (
|
|
| 9 |
"one-api/common"
|
| 10 |
"one-api/dto"
|
| 11 |
relaycommon "one-api/relay/common"
|
|
|
|
| 12 |
"one-api/service"
|
| 13 |
-
"one-api/setting"
|
| 14 |
)
|
| 15 |
|
| 16 |
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
|
@@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
|
| 40 |
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
|
| 41 |
}
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
if modelMapping != "" && modelMapping != "{}" {
|
| 47 |
-
modelMap := make(map[string]string)
|
| 48 |
-
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
| 49 |
-
if err != nil {
|
| 50 |
-
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
| 51 |
-
}
|
| 52 |
-
if modelMap[rerankRequest.Model] != "" {
|
| 53 |
-
rerankRequest.Model = modelMap[rerankRequest.Model]
|
| 54 |
-
// set upstream model name
|
| 55 |
-
//isModelMapped = true
|
| 56 |
-
}
|
| 57 |
}
|
| 58 |
|
| 59 |
-
|
| 60 |
-
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
|
| 61 |
-
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
| 62 |
-
|
| 63 |
-
var preConsumedQuota int
|
| 64 |
-
var ratio float64
|
| 65 |
-
var modelRatio float64
|
| 66 |
|
| 67 |
promptToken := getRerankPromptToken(*rerankRequest)
|
| 68 |
-
if !success {
|
| 69 |
-
preConsumedTokens := promptToken
|
| 70 |
-
modelRatio = common.GetModelRatio(rerankRequest.Model)
|
| 71 |
-
ratio = modelRatio * groupRatio
|
| 72 |
-
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
| 73 |
-
} else {
|
| 74 |
-
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
| 75 |
-
}
|
| 76 |
relayInfo.PromptTokens = promptToken
|
| 77 |
|
|
|
|
|
|
|
| 78 |
// pre-consume quota 预消耗配额
|
| 79 |
-
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c,
|
| 80 |
if openaiErr != nil {
|
| 81 |
return openaiErr
|
| 82 |
}
|
|
@@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
|
| 124 |
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
| 125 |
return openaiErr
|
| 126 |
}
|
| 127 |
-
postConsumeQuota(c, relayInfo,
|
| 128 |
return nil
|
| 129 |
}
|
|
|
|
| 9 |
"one-api/common"
|
| 10 |
"one-api/dto"
|
| 11 |
relaycommon "one-api/relay/common"
|
| 12 |
+
"one-api/relay/helper"
|
| 13 |
"one-api/service"
|
|
|
|
| 14 |
)
|
| 15 |
|
| 16 |
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
|
|
|
| 40 |
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
|
| 41 |
}
|
| 42 |
|
| 43 |
+
err = helper.ModelMappedHelper(c, relayInfo)
|
| 44 |
+
if err != nil {
|
| 45 |
+
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
}
|
| 47 |
|
| 48 |
+
rerankRequest.Model = relayInfo.UpstreamModelName
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
promptToken := getRerankPromptToken(*rerankRequest)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
relayInfo.PromptTokens = promptToken
|
| 52 |
|
| 53 |
+
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
| 54 |
+
|
| 55 |
// pre-consume quota 预消耗配额
|
| 56 |
+
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
| 57 |
if openaiErr != nil {
|
| 58 |
return openaiErr
|
| 59 |
}
|
|
|
|
| 101 |
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
| 102 |
return openaiErr
|
| 103 |
}
|
| 104 |
+
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
| 105 |
return nil
|
| 106 |
}
|
relay/relay_task.go
CHANGED
|
@@ -113,7 +113,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|
| 113 |
// release quota
|
| 114 |
if relayInfo.ConsumeQuota && taskErr == nil {
|
| 115 |
|
| 116 |
-
err :=
|
| 117 |
if err != nil {
|
| 118 |
common.SysError("error consuming token remain quota: " + err.Error())
|
| 119 |
}
|
|
|
|
| 113 |
// release quota
|
| 114 |
if relayInfo.ConsumeQuota && taskErr == nil {
|
| 115 |
|
| 116 |
+
err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
|
| 117 |
if err != nil {
|
| 118 |
common.SysError("error consuming token remain quota: " + err.Error())
|
| 119 |
}
|
router/api-router.go
CHANGED
|
@@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) {
|
|
| 56 |
selfRoute.POST("/pay", controller.RequestEpay)
|
| 57 |
selfRoute.POST("/amount", controller.RequestAmount)
|
| 58 |
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
|
|
|
| 59 |
}
|
| 60 |
|
| 61 |
adminRoute := userRoute.Group("/")
|
|
|
|
| 56 |
selfRoute.POST("/pay", controller.RequestEpay)
|
| 57 |
selfRoute.POST("/amount", controller.RequestAmount)
|
| 58 |
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
| 59 |
+
selfRoute.PUT("/setting", controller.UpdateUserSetting)
|
| 60 |
}
|
| 61 |
|
| 62 |
adminRoute := userRoute.Group("/")
|