|
|
package middleware |
|
|
|
|
|
import ( |
|
|
"errors" |
|
|
"fmt" |
|
|
"net/http" |
|
|
"slices" |
|
|
"strconv" |
|
|
"strings" |
|
|
"time" |
|
|
|
|
|
"github.com/QuantumNous/new-api/common" |
|
|
"github.com/QuantumNous/new-api/constant" |
|
|
"github.com/QuantumNous/new-api/dto" |
|
|
"github.com/QuantumNous/new-api/model" |
|
|
relayconstant "github.com/QuantumNous/new-api/relay/constant" |
|
|
"github.com/QuantumNous/new-api/service" |
|
|
"github.com/QuantumNous/new-api/setting/ratio_setting" |
|
|
"github.com/QuantumNous/new-api/types" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
) |
|
|
|
|
|
// ModelRequest holds the fields of a relay request that matter for routing:
// the requested model name and an optional group (read for playground
// requests in getModelRequest).
type ModelRequest struct {


	// Model is the model name requested by the client.
	Model string `json:"model"`


	// Group is the requested group; omitted from JSON output when empty.
	Group string `json:"group,omitempty"`


}
|
|
|
|
|
func Distribute() func(c *gin.Context) { |
|
|
return func(c *gin.Context) { |
|
|
var channel *model.Channel |
|
|
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId) |
|
|
modelRequest, shouldSelectChannel, err := getModelRequest(c) |
|
|
if err != nil { |
|
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) |
|
|
return |
|
|
} |
|
|
if ok { |
|
|
id, err := strconv.Atoi(channelId.(string)) |
|
|
if err != nil { |
|
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id") |
|
|
return |
|
|
} |
|
|
channel, err = model.GetChannelById(id, true) |
|
|
if err != nil { |
|
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id") |
|
|
return |
|
|
} |
|
|
if channel.Status != common.ChannelStatusEnabled { |
|
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用") |
|
|
return |
|
|
} |
|
|
} else { |
|
|
|
|
|
|
|
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) |
|
|
if modelLimitEnable { |
|
|
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit) |
|
|
if !ok { |
|
|
|
|
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型") |
|
|
return |
|
|
} |
|
|
var tokenModelLimit map[string]bool |
|
|
tokenModelLimit, ok = s.(map[string]bool) |
|
|
if !ok { |
|
|
tokenModelLimit = map[string]bool{} |
|
|
} |
|
|
matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) |
|
|
if _, ok := tokenModelLimit[matchName]; !ok { |
|
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model) |
|
|
return |
|
|
} |
|
|
} |
|
|
|
|
|
if shouldSelectChannel { |
|
|
if modelRequest.Model == "" { |
|
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空") |
|
|
return |
|
|
} |
|
|
var selectGroup string |
|
|
usingGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) |
|
|
|
|
|
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") { |
|
|
playgroundRequest := &dto.PlayGroundRequest{} |
|
|
err = common.UnmarshalBodyReusable(c, playgroundRequest) |
|
|
if err != nil { |
|
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的playground请求, "+err.Error()) |
|
|
return |
|
|
} |
|
|
if playgroundRequest.Group != "" { |
|
|
if !service.GroupInUserUsableGroups(usingGroup, playgroundRequest.Group) && playgroundRequest.Group != usingGroup { |
|
|
abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组") |
|
|
return |
|
|
} |
|
|
usingGroup = playgroundRequest.Group |
|
|
common.SetContextKey(c, constant.ContextKeyUsingGroup, usingGroup) |
|
|
} |
|
|
} |
|
|
channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(c, usingGroup, modelRequest.Model, 0) |
|
|
if err != nil { |
|
|
showGroup := usingGroup |
|
|
if usingGroup == "auto" { |
|
|
showGroup = fmt.Sprintf("auto(%s)", selectGroup) |
|
|
} |
|
|
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(distributor): %s", showGroup, modelRequest.Model, err.Error()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound)) |
|
|
return |
|
|
} |
|
|
if channel == nil { |
|
|
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", usingGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound)) |
|
|
return |
|
|
} |
|
|
} |
|
|
} |
|
|
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) |
|
|
SetupContextForSelectedChannel(c, channel, modelRequest.Model) |
|
|
c.Next() |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func getModelFromRequest(c *gin.Context) (*ModelRequest, error) { |
|
|
var modelRequest ModelRequest |
|
|
err := common.UnmarshalBodyReusable(c, &modelRequest) |
|
|
if err != nil { |
|
|
return nil, errors.New("无效的请求, " + err.Error()) |
|
|
} |
|
|
return &modelRequest, nil |
|
|
} |
|
|
|
|
|
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { |
|
|
var modelRequest ModelRequest |
|
|
shouldSelectChannel := true |
|
|
var err error |
|
|
if strings.Contains(c.Request.URL.Path, "/mj/") { |
|
|
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) |
|
|
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch || |
|
|
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition || |
|
|
relayMode == relayconstant.RelayModeMidjourneyNotify || |
|
|
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed { |
|
|
shouldSelectChannel = false |
|
|
} else { |
|
|
midjourneyRequest := dto.MidjourneyRequest{} |
|
|
err = common.UnmarshalBodyReusable(c, &midjourneyRequest) |
|
|
if err != nil { |
|
|
return nil, false, errors.New("无效的midjourney请求, " + err.Error()) |
|
|
} |
|
|
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) |
|
|
if mjErr != nil { |
|
|
return nil, false, fmt.Errorf(mjErr.Description) |
|
|
} |
|
|
if midjourneyModel == "" { |
|
|
if !success { |
|
|
return nil, false, fmt.Errorf("无效的请求, 无法解析模型") |
|
|
} else { |
|
|
|
|
|
shouldSelectChannel = false |
|
|
} |
|
|
} |
|
|
modelRequest.Model = midjourneyModel |
|
|
} |
|
|
c.Set("relay_mode", relayMode) |
|
|
} else if strings.Contains(c.Request.URL.Path, "/suno/") { |
|
|
relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path) |
|
|
if relayMode == relayconstant.RelayModeSunoFetch || |
|
|
relayMode == relayconstant.RelayModeSunoFetchByID { |
|
|
shouldSelectChannel = false |
|
|
} else { |
|
|
modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action")) |
|
|
modelRequest.Model = modelName |
|
|
} |
|
|
c.Set("platform", string(constant.TaskPlatformSuno)) |
|
|
c.Set("relay_mode", relayMode) |
|
|
} else if strings.Contains(c.Request.URL.Path, "/v1/videos") { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
relayMode := relayconstant.RelayModeUnknown |
|
|
if c.Request.Method == http.MethodPost { |
|
|
relayMode = relayconstant.RelayModeVideoSubmit |
|
|
req, err := getModelFromRequest(c) |
|
|
if err != nil { |
|
|
return nil, false, err |
|
|
} |
|
|
if req != nil { |
|
|
modelRequest.Model = req.Model |
|
|
} |
|
|
} else if c.Request.Method == http.MethodGet { |
|
|
relayMode = relayconstant.RelayModeVideoFetchByID |
|
|
shouldSelectChannel = false |
|
|
} |
|
|
c.Set("relay_mode", relayMode) |
|
|
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { |
|
|
relayMode := relayconstant.RelayModeUnknown |
|
|
if c.Request.Method == http.MethodPost { |
|
|
req, err := getModelFromRequest(c) |
|
|
if err != nil { |
|
|
return nil, false, err |
|
|
} |
|
|
modelRequest.Model = req.Model |
|
|
relayMode = relayconstant.RelayModeVideoSubmit |
|
|
} else if c.Request.Method == http.MethodGet { |
|
|
relayMode = relayconstant.RelayModeVideoFetchByID |
|
|
shouldSelectChannel = false |
|
|
} |
|
|
if _, ok := c.Get("relay_mode"); !ok { |
|
|
c.Set("relay_mode", relayMode) |
|
|
} |
|
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") { |
|
|
|
|
|
relayMode := relayconstant.RelayModeGemini |
|
|
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path) |
|
|
if modelName != "" { |
|
|
modelRequest.Model = modelName |
|
|
} |
|
|
c.Set("relay_mode", relayMode) |
|
|
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { |
|
|
req, err := getModelFromRequest(c) |
|
|
if err != nil { |
|
|
return nil, false, err |
|
|
} |
|
|
modelRequest.Model = req.Model |
|
|
} |
|
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") { |
|
|
|
|
|
modelRequest.Model = c.Query("model") |
|
|
} |
|
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { |
|
|
if modelRequest.Model == "" { |
|
|
modelRequest.Model = "text-moderation-stable" |
|
|
} |
|
|
} |
|
|
if strings.HasSuffix(c.Request.URL.Path, "embeddings") { |
|
|
if modelRequest.Model == "" { |
|
|
modelRequest.Model = c.Param("model") |
|
|
} |
|
|
} |
|
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { |
|
|
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") |
|
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { |
|
|
|
|
|
contentType := c.ContentType() |
|
|
if slices.Contains([]string{gin.MIMEPOSTForm, gin.MIMEMultipartPOSTForm}, contentType) { |
|
|
req, err := getModelFromRequest(c) |
|
|
if err == nil && req.Model != "" { |
|
|
modelRequest.Model = req.Model |
|
|
} |
|
|
} |
|
|
} |
|
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { |
|
|
relayMode := relayconstant.RelayModeAudioSpeech |
|
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { |
|
|
|
|
|
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1") |
|
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { |
|
|
|
|
|
if req, err := getModelFromRequest(c); err == nil && req.Model != "" { |
|
|
modelRequest.Model = req.Model |
|
|
} |
|
|
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") |
|
|
relayMode = relayconstant.RelayModeAudioTranslation |
|
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { |
|
|
|
|
|
if req, err := getModelFromRequest(c); err == nil && req.Model != "" { |
|
|
modelRequest.Model = req.Model |
|
|
} |
|
|
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") |
|
|
relayMode = relayconstant.RelayModeAudioTranscription |
|
|
} |
|
|
c.Set("relay_mode", relayMode) |
|
|
} |
|
|
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") { |
|
|
|
|
|
req, err := getModelFromRequest(c) |
|
|
if err != nil { |
|
|
return nil, false, err |
|
|
} |
|
|
modelRequest.Model = req.Model |
|
|
modelRequest.Group = req.Group |
|
|
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group) |
|
|
} |
|
|
return &modelRequest, shouldSelectChannel, nil |
|
|
} |
|
|
|
|
|
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError { |
|
|
c.Set("original_model", modelName) |
|
|
if channel == nil { |
|
|
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) |
|
|
} |
|
|
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id) |
|
|
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name) |
|
|
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type) |
|
|
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime) |
|
|
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting()) |
|
|
common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings()) |
|
|
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride()) |
|
|
common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride()) |
|
|
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" { |
|
|
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization) |
|
|
} |
|
|
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan()) |
|
|
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping()) |
|
|
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping()) |
|
|
|
|
|
key, index, newAPIError := channel.GetNextEnabledKey() |
|
|
if newAPIError != nil { |
|
|
return newAPIError |
|
|
} |
|
|
if channel.ChannelInfo.IsMultiKey { |
|
|
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true) |
|
|
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index) |
|
|
} else { |
|
|
|
|
|
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false) |
|
|
} |
|
|
|
|
|
common.SetContextKey(c, constant.ContextKeyChannelKey, key) |
|
|
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) |
|
|
|
|
|
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, false) |
|
|
|
|
|
|
|
|
switch channel.Type { |
|
|
case constant.ChannelTypeAzure: |
|
|
c.Set("api_version", channel.Other) |
|
|
case constant.ChannelTypeVertexAi: |
|
|
c.Set("region", channel.Other) |
|
|
case constant.ChannelTypeXunfei: |
|
|
c.Set("api_version", channel.Other) |
|
|
case constant.ChannelTypeGemini: |
|
|
c.Set("api_version", channel.Other) |
|
|
case constant.ChannelTypeAli: |
|
|
c.Set("plugin", channel.Other) |
|
|
case constant.ChannelCloudflare: |
|
|
c.Set("api_version", channel.Other) |
|
|
case constant.ChannelTypeMokaAI: |
|
|
c.Set("api_version", channel.Other) |
|
|
case constant.ChannelTypeCoze: |
|
|
c.Set("bot_id", channel.Other) |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// extractModelNameFromGeminiPath returns the model segment of a Gemini-style
// path such as "/v1beta/models/gemini-pro:generateContent".
//
// The model name is the text after the first "/models/", cut at the first ':'
// (the action separator) if present. It returns "" when the path contains no
// "/models/" or nothing follows it.
func extractModelNameFromGeminiPath(path string) string {
	_, afterModels, found := strings.Cut(path, "/models/")
	if !found {
		return ""
	}
	modelName, _, _ := strings.Cut(afterModels, ":")
	return modelName
}
|
|
|